CountDownLatch是基于AQS实现的,AQS是一个抽象的队列同步器,通过维护一个共享的资源状态(state)和一个先进先出的等待队列来实现一个多线程访问共享资源的同步框架, CountDownLatch的sync 类实现了AQS。
public class CountDownLatch {
/**
* Synchronization control For CountDownLatch.
* Uses AQS state to represent count.
*/
// 实现AQS
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
// 初始化state的值,该值由volatile关键字修饰。
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
// 获取共享资源锁
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
// 尝试释放共享资源锁
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
private final Sync sync;
.....
}
上述代码中tryAcquireShared(int acquires)方法表示被唤醒的线程尝试去拿共享锁,如果state的值为0时,返回的1,如果大于0返回-1。
public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
return tryAcquireShared(arg) >= 0 ||
doAcquireSharedNanos(arg, nanosTimeout);
}
此方法的官方解释为Attempts to acquire in shared mode, aborting if interrupted, and failing if the given timeout elapses, 意思是尝试获取共享锁,获取到了返回true, 如果被中断了那么抛出InterruptedException,如果没有获取到,那么进入到doAcquireSharedNanos(int arg, long nanosTimeOut)方法进行阻塞,知道被唤醒超时了那么返回false。
private boolean doAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
if (nanosTimeout <= 0L)
return false;
final long deadline = System.nanoTime() + nanosTimeout;
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return true;
}
}
nanosTimeout = deadline - System.nanoTime();
if (nanosTimeout <= 0L)
return false;
if (shouldParkAfterFailedAcquire(p, node) &&
nanosTimeout > spinForTimeoutThreshold)
LockSupport.parkNanos(this, nanosTimeout);
if (Thread.interrupted())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
在队列里的头结点那么会出队去执行tryAcquireShared(arg)方法,如果成功了,那么返回true,如果失败了,那么执行cacelAcquire(node)方法,主要的逻辑是取消争抢资源的动作。
countDownLatch()构造方法
在countDownLatch的构造方法里,可以指定等待的子线程数量,参数值必须>0, 具体是交给了AQS的实现类sync去初始化。
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
在Sync的构造方法里通过setState()方法将AQS里state变量赋值,AQS里的state状态属性由volatile关键字修饰,我们知道,在多线程环境下,volatile关键字能够保证多线程环境下的顺序性和可见性,state字段会被多个子线程共用。
countDown()方法详解
每调用一次countDown() 方法,state值就会-1, 看一下实现
public void countDown() {
sync.releaseShared(1);
}
在Sync的构造方法里通过setState()方法将AQS里state变量赋值,AQS里的state状态属性由volatile关键字修饰,我们知道,在多线程环境下,volatile关键字能够保证多线程环境下的顺序性和可见性,state字段会被多个子线程共用。
countDown()方法详解
每调用一次countDown() 方法,state值就会-1, 看一下实现
public void countDown() {
sync.releaseShared(1);
}
public final boolean releaseShared(int arg) {
// 通过自旋,尝试将state-1
if (tryReleaseShared(arg)) {
// 释放资源锁,交给其他线程去争取
doReleaseShared();
return true;
}
return false;
}
追一下tryReleaseShared(args)方法, 可以发现是在AQS里定义的
具体实现在CountDownLatch的 静态内部类Sync里 tryReleaseShared(int releasees)方法里。通过自旋的形式 ,判断当前的state是否为c, 如果是那么就将其当前的state -1,如果不是c,那么表示其他线程已经修改了state值,那么当前线程会重新进入到下一次循环中,直到设置成功,当nextc==0 时,才会返回true,释放共享锁,也就意味着子线程执行完毕。
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
// 直到nextc=0, 释放资源锁。
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
Await()方法详解
await()方法的主要是根据state值是否为0进行操作,如果state==0,那么表示指定数量的线程拿到了锁并执行完毕,如果state>0,那么所有线程就会阻塞到await()方法。
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
如果 state> 0,如下代码即返回-1
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
执行doAcquireSharedInterruptibly(arg):
/**
* Acquires in shared interruptible mode.
* @param arg the acquire argument
*/
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
上述方法和countdown()方法都的在for(;;)里调用了tryAcquireShared(arg)方法,用意是直到state的值等于0时,await()方法执行结束,也就意味着所有线程能通过, 另外采用await()方法的好处是能够设置任务执行的超时时间。
CountDownLatch在RocketMq源码中的应用
批量发送消息等所有回调确认后才返回
Rocketmq中有一个异步批量发送消息的Producer, 如果是批量发送,那么就需要等待所有消息都确认后才返回,然后执行主线程
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.rocketmq.example.simple;
import java.io.UnsupportedEncodingException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.apache.rocketmq.client.exception.MQClientException;
import org.apache.rocketmq.client.producer.DefaultMQProducer;
import org.apache.rocketmq.client.producer.SendCallback;
import org.apache.rocketmq.client.producer.SendResult;
import org.apache.rocketmq.common.message.Message;
import org.apache.rocketmq.remoting.common.RemotingHelper;
public class AsyncProducer {
// 异步批量发送消息
public static void main(
String[] args) throws MQClientException, InterruptedException, UnsupportedEncodingException {
DefaultMQProducer producer = new DefaultMQProducer("Jodie_Daily_test", AclClient.getAclRPCHook());
producer.setNamesrvAddr("127.0.0.1:9876");
producer.start();
producer.setRetryTimesWhenSendAsyncFailed(0);
int messageCount = 100;
final CountDownLatch countDownLatch = new CountDownLatch(messageCount);
for (int i = 0; i < messageCount; i++) {
try {
final int index = i;
Message msg = new Message("Jodie_topic_1023",
"TagA",
"OrderID188",
"Hello world".getBytes(RemotingHelper.DEFAULT_CHARSET));
producer.send(msg, new SendCallback() {
@Override
public void onSuccess(SendResult sendResult) {
countDownLatch.countDown();
System.out.printf("%-10d OK %s %n", index, sendResult.getMsgId());
}
@Override
public void onException(Throwable e) {
countDownLatch.countDown();
System.out.printf("%-10d Exception %s %n", index, e);
e.printStackTrace();
}
});
} catch (Exception e) {
e.printStackTrace();
}
}
countDownLatch.await(5, TimeUnit.SECONDS);
producer.shutdown();
}
}
结果:
broker向所有的NameServer注册信息完毕后才返回
broker向所有的NameServer 注册信息,必须等待所有的NameServer都接收到broker的信息后才返回。
public List<RegisterBrokerResult> registerBrokerAll(
final String clusterName,
final String brokerAddr,
final String brokerName,
final long brokerId,
final String haServerAddr,
final TopicConfigSerializeWrapper topicConfigWrapper,
final List<String> filterServerList,
final boolean oneway,
final int timeoutMills,
final boolean compressed) {
final List<RegisterBrokerResult> registerBrokerResultList = new CopyOnWriteArrayList<>();
List<String> nameServerAddressList = this.remotingClient.getNameServerAddressList();
if (nameServerAddressList != null && nameServerAddressList.size() > 0) {
final RegisterBrokerRequestHeader requestHeader = new RegisterBrokerRequestHeader();
requestHeader.setBrokerAddr(brokerAddr);
requestHeader.setBrokerId(brokerId);
requestHeader.setBrokerName(brokerName);
requestHeader.setClusterName(clusterName);
requestHeader.setHaServerAddr(haServerAddr);
requestHeader.setCompressed(compressed);
RegisterBrokerBody requestBody = new RegisterBrokerBody();
requestBody.setTopicConfigSerializeWrapper(topicConfigWrapper);
requestBody.setFilterServerList(filterServerList);
final byte[] body = requestBody.encode(compressed);
final int bodyCrc32 = UtilAll.crc32(body);
requestHeader.setBodyCrc32(bodyCrc32);
final CountDownLatch countDownLatch = new CountDownLatch(nameServerAddressList.size());
for (final String namesrvAddr : nameServerAddressList) {
brokerOuterExecutor.execute(new Runnable() {
@Override
public void run() {
try {
RegisterBrokerResult result = registerBroker(namesrvAddr,oneway, timeoutMills,requestHeader,body);
if (result != null) {
registerBrokerResultList.add(result);
}
log.info("register broker[{}]to name server {} OK", brokerId, namesrvAddr);
} catch (Exception e) {
log.warn("registerBroker Exception, {}", namesrvAddr, e);
} finally {
countDownLatch.countDown();
}
}
});
}
try {
countDownLatch.await(timeoutMills, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
}
}
return registerBrokerResultList;
}
评论区