JDK源码阅读之CountDownLatch

写在前面

_作者Doug Lea_如此描述这个类:A synchronization aid that allows one or more threads to wait until a set of operations being performed in other threads completes.

分析自JDK 1.8.0_171
这是一个多线程协调的辅助类。源码中给出的示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class Driver { // ...
void main() throws InterruptedException {
CountDownLatch startSignal = new CountDownLatch(1);
CountDownLatch doneSignal = new CountDownLatch(N);

for (int i = 0; i < N; ++i) // create and start threads
new Thread(new Worker(startSignal, doneSignal)).start();

doSomethingElse(); // don't let run yet
startSignal.countDown(); // let all threads proceed
doSomethingElse();
doneSignal.await(); // wait for all to finish
}
}

class Worker implements Runnable {
private final CountDownLatch startSignal;
private final CountDownLatch doneSignal;
Worker(CountDownLatch startSignal, CountDownLatch doneSignal) {
this.startSignal = startSignal;
this.doneSignal = doneSignal;
}
public void run() {
try {
startSignal.await();
doWork();
doneSignal.countDown();
} catch (InterruptedException ex) {} // return;
}

void doWork() { ... }
}}

通过示例对CountDownLatch的使用场景应该有个清晰的认识。即当有需要线程等待,直到在其他线程的一系列操作完成之后,再接着往下执行。

Sync变量

Sync类是CountDownLatch的一个内部类,继承自 AbstractQueuedSynchronizer ,也就是常说的AQS。内部类重写了AQS的 tryAcquireSharedtryReleaseShared 两个方法。此外Sync的构造函数带一个int参数,在构造函数内调用了AQS的 setState 方法,这个方法是对AQS的内部一个int volatile变量赋值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;

Sync(int count) {
setState(count);
}

int getCount() {
return getState();
}

protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}

protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}

在上边的示例中,我们用到了 CountDownLatchawait 方法和 countDown 方法,CountDownLatch的功能也就是通过这两个方法实现。这两个方法其实也就是调用sync。

CountDownLatch.await

功能:当前线程等待,直到state等于0,或该线程interrupt。
该方法就是调用 sync.acquireSharedInterruptibly(1)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
// 调用实现类Sync的tryAcquireShared
if (tryAcquireShared(arg) < 0)
// state 不等于 0 则执行
doAcquireSharedInterruptibly(arg);
}

private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
// 用一个链表来表示需要"wait"的线程,在链表尾部加入一个node
// 注意 如果head == null ,则初始化一个head,令head.next = node
// 所以能理解state==0,一一唤醒所有等待的线程时,是唤醒头结点的下一节点所表示的线程
// 这样就达到了唤醒所有线程的目的
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
// 前驱
final Node p = node.predecessor();
// 如果当前节点为头节点
if (p == head) {
// 调用CountDownLatch实现类里的tryAcquireShared
int r = tryAcquireShared(arg);
if (r >= 0) {
// 当state == 0时,设置node为头结点
// 并且唤醒(LockSupport.unpark)node下一节点的所表示的线程
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
// wait,直到state == 0 时被唤醒(LockSupport.unpark)
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}

总结:CountDownLatch.await就是调用LockSupport.park阻塞当前线程,并用一链表表示所有阻塞的线程,方便唤醒时一一唤醒。

CountDownLatch.countDown

功能:令state减1,但state为0时,唤醒所有等待线程。
该方法是调用 sync.releaseShared(1)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
   public final boolean releaseShared(int arg) {
// 调用实现类Sync的tryReleaseShared
// state减1
if (tryReleaseShared(arg)) {
// 当state减为0时
doReleaseShared();
return true;
}
return false;
}

// 唤醒等待链表的头结点的下一节点
// 下一节点唤醒后在doAcquireSharedInterruptibly这个方法中继续循环
// 直到执行setHeadAndPropagate,在此方法中又会调用doReleaseShared,唤醒接下来的节点
// 依次一一唤醒,直到唤醒所有等待线程节点
private void doReleaseShared() {
/*
* Ensure that a release propagates, even if there are other
* in-progress acquires/releases. This proceeds in the usual
* way of trying to unparkSuccessor of head if it needs
* signal. But if it does not, status is set to PROPAGATE to
* ensure that upon release, propagation continues.
* Additionally, we must loop in case a new node is added
* while we are doing this. Also, unlike other uses of
* unparkSuccessor, we need to know if CAS to reset status
* fails, if so rechecking.
*/
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
// cas头结点状态为初始状态
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
// 此方法中是唤醒头结点的下一线程节点
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
// 如果线程节点改变重新循环
if (h == head) // loop if head changed
break;
}
}

// 唤醒node下一线程节点
private void unparkSuccessor(Node node) {
/*
* If status is negative (i.e., possibly needing signal) try
* to clear in anticipation of signalling. It is OK if this
* fails or if status is changed by waiting thread.
*/
int ws = node.waitStatus;
if (ws < 0)
compareAndSetWaitStatus(node, ws, 0);

/*
* Thread to unpark is held in successor, which is normally
* just the next node. But if cancelled or apparently null,
* traverse backwards from tail to find the actual
* non-cancelled successor.
*/
Node s = node.next;
if (s == null || s.waitStatus > 0) {
s = null;
// node后继节点可能为null或cancel
for (Node t = tail; t != null && t != node; t = t.prev)
if (t.waitStatus <= 0)
s = t;
}
if (s != null)
// 唤醒节点s表示线程
LockSupport.unpark(s.thread);
}

CountDownLatch源码分析的整个流程就是这样。CountDownLatch的功能是基于AQS展开,在后续的JUC的分析文章中还可以看到AQS的身影。谢谢。

-------------The End-------------