CountDownLatch及AQS共享锁源码分析

首先说AQS有2种锁的机制,也就是共享锁和独占锁(也叫排它锁)。上一次通过读ReentrantLock把AQS的独占锁分析了一遍。那其实AQS还有另外一半,就是共享锁。共享锁在ReentrantReadWriteLock的读锁中有体现:

1
2
3
4
5
6
7
8
9
10
11
12
13
public static class ReadLock implements Lock, java.io.Serializable {
public void lock() {
// 这里是acquireShared方法
sync.acquireShared(1);
}
}

public static class WriteLock implements Lock, java.io.Serializable {
public void lock() {
// 这里是acquire方法
sync.acquire(1);
}
}

ReadLock和WriteLock在调用lock()方法时,在AQS中调用的分别是acquireShared()和acquire()方法。ReentrantReadWriteLock放在这里暂且不谈,本文想通过CountDownLatch(闭锁)来研究一下AQS中的共享锁。我们一起分析之后,结合上次写的独占锁,相信聪明的读者自己会很容易读懂ReentrantReadWriteLock的。

CountDownLatch用法

其实用法我是一直不想详细说的,到了看源码这个程度,用法总不会是什么问题。我就贴一下Doug Lea在源码中给的示例吧:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class Driver {
void main() throws InterruptedException {
CountDownLatch startSignal = new CountDownLatch(1);
CountDownLatch doneSignal = new CountDownLatch(N);

for (int i = 0; i < N; ++i)
// 起N个线程
new Thread(new Worker(startSignal, doneSignal)).start();

// main线程先做事情,先不让其他线程跑
doSomethingElse();
// N个线程可以跑起来了
startSignal.countDown();
doSomethingElse();
// main线程等待N个线程跑完
doneSignal.await();
}
}

class Worker implements Runnable {
private final CountDownLatch startSignal;
private final CountDownLatch doneSignal;

Worker(CountDownLatch startSignal, CountDownLatch doneSignal) {
this.startSignal = startSignal;
this.doneSignal = doneSignal;
}

public void run() {
try {
// 阻塞住,等待开始信号
startSignal.await();
doWork();
doneSignal.countDown();
} catch (InterruptedException ex) {
}
}

void doWork() { ...}
}

其实CountDownLatch可以理解为火箭发射的倒计时,所有的准备和检查步骤都结束后,开始发射!

这里涉及到CountDownLatch中2个比较重要的方法:

  • await():调用这个方法的线程会被阻塞住,当计数器为0时,这个线程被唤醒,继续执行;
  • countDown():调用这个方法会让计数器 - 1。

所以下面针对源码的分析也主要是这两个方法。

CountDownLatch源码分析

其实共享锁和独占锁有很多代码是复用的,但是为了完整性,下面的分析我尽量还是都写上,温故而知新嘛。

CountDownLatch的构造方法

1
2
3
4
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}

把它单拎出来就是想说,它最开始就对state(count)做了验证,所以下面分析的state只有2种情况,要么等于0,要么大于0。

然后就是把这个count(火箭发射前需要完成的步骤数)扔进Sync内部类。

await()方法

上面我们说了await()主要作用是阻塞住当前线程,直到计数器为0。

1
2
3
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}

sync.acquireSharedInterruptibly(1)

1
2
3
4
5
6
7
8
9
10
11
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
// 如果线程被interrupt了,抛异常
throw new InterruptedException();
if (tryAcquireShared(arg) < 0)
// 如果当前计数器的值不为0,就继续往下走
// (发射前有步骤还没完成谁敢发射?)
// 接下来去park线程
doAcquireSharedInterruptibly(arg);
}
  • 如果tryAcquireShared为负,发令枪没响,进去park线程,自己阻塞;
  • 如果tryAcquireShared为正,state归0了,就不用阻塞了,直接返回。

所以CountDownLatch解阻塞本质上就是await()方法的正常返回(想起来ReentrantLock成功拿到锁的标志就是lock()方法的正常返回)。

tryAcquireShared(arg)

这个方法在CountDownLatch的Sync内部类中是被重写了的,多次被调用:

1
2
3
4
5
6
protected int tryAcquireShared(int acquires) {
// 如果state大于0,步骤没完成,返回负数
// 如果state等于0,步骤全完成了,返回正数
// 这里别纠结正负1了,记正负数就行了
return (getState() == 0) ? 1 : -1;
}

doAcquireSharedInterruptibly(arg)

这个方法的代码就似曾相识了对吧:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
// 先创建一个队列(链表),把所有的线程加进去park住
// 这里就是CountDownLatch的设计理念了,它不是抢锁
// 而是先把所有线程阻塞(park)住,等待state归0
// 发令枪一响,阻塞的线程再继续往下跑,火箭升空
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (; ; ) {
// 死循环
// 上半部分逻辑是unpark后执行的
// 后面分析countDown()的时候再说
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
// 先看下半部分的逻辑,这里的判断和独占锁中抢不到锁的逻辑一样
// 1. 先修改前一个node的ws
if (shouldParkAfterFailedAcquire(p, node) &&
// 2. 再park自己
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}

接下来回顾一下下面3个方法吧:

  • addWaiter(Node.SHARED)
  • shouldParkAfterFailedAcquire(p, node)
  • parkAndCheckInterrupt()

addWaiter(Node.SHARED)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
private Node addWaiter(Node mode) {
// 把当前线程封装为Node对象,mode为共享(ReentrantLock中的读锁才是独占mode)
Node node = new Node(Thread.currentThread(), mode);
// Try the fast path of enq; backup to full enq on failure
Node pred = tail;
// 判断队尾是否为null,其实只要队列被初始化了,队尾一定不为null
// 换而言之这里判断的就是队列有没有被初始化,也就是上面我们说的那2种情况
if (pred != null) {
// 队尾不为空,队列已经初始化了
// 这种情况比较简单,把当前Node设置为队尾,维护链表关系即可
node.prev = pred;
if (compareAndSetTail(pred, node)) {
pred.next = node;
return node;
}
}
// 队列没有被初始化,下面来分析这个enq
enq(node);
// 返回当前node
return node;
}

// enq每次执行的情况都不同,这里分次来分析
private Node enq(final Node node) {
// 死循环
for (;;) {
// 第一次进入,t为null,队列没有被初始化
Node t = tail;
if (t == null) { // Must initialize
// 调用无参的Node构造方法,也就是加了一个thread为null的虚拟Node
// 并把这个虚拟Node设为头部
if (compareAndSetHead(new Node()))
// 此时AQS中只有一个元素,就是这个虚拟Node
// 然后将尾部指向它,第一次循环结束
tail = head;
} else {
node.prev = t;
if (compareAndSetTail(t, node)) {
t.next = node;
return t;
}
}
}
}

// 下面我们来看第二次循环
// 这个代码写的比较高效,处理了不同的情况
private Node enq(final Node node) {
// 死循环
for (;;) {
// 这个时候t指向的是那个虚拟Node了,不为null
Node t = tail;
if (t == null) { // Must initialize
if (compareAndSetHead(new Node()))
tail = head;
} else {
// 当前Node入队,前驱为虚拟Node
node.prev = t;
// cas设置尾部
if (compareAndSetTail(t, node)) {
// 维护链表关系
t.next = node;
// 返回虚拟节点?这个返回其实就是终止死循环
// 返回出去的t没啥意义,外面的addWaiter没有接收enq的返回
return t;
}
}
}
}

总结下来这个addWaiter()就是创建AQS队列,把所有await()的线程封装为Node,再通过CAS维护这个队列。

shouldParkAfterFailedAcquire(p, node)

这里以第一个线程为例(pred为虚拟节点,node为当前线程对应的节点):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
// 拿到虚拟线程的ws,此时应该为0
int ws = pred.waitStatus;
if (ws == Node.SIGNAL)
// Node.SIGNAL为-1,ws为0,不进这里
/*
* This node has already set status asking a release
* to signal it, so it can safely park.
*/
return true;
if (ws > 0) {
/*
* Predecessor was cancelled. Skip over predecessors and
* indicate retry.
*/
do {
node.prev = pred = pred.prev;
} while (pred.waitStatus > 0);
pred.next = node;
} else {
/*
* waitStatus must be 0 or PROPAGATE. Indicate that we
* need a signal, but don't park yet. Caller will need to
* retry to make sure it cannot acquire before parking.
*/
// 将前驱节点的ws改为-1,原因在另一篇里面说过了
compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
}
return false;
}

这个方法由于外面是死循环,会进2次:

  • 第一次将前面node的ws改为-1,返回false;
    • 先返回false进入第二次循环的主要目的是多自旋一次,如果在此期间,state被归0了,可以不park线程,直接返回了;
    • 多自旋一次的目的就是想从java层面优化效率,避免park()操作的用户态到内核态的切换;
  • 第二次判断前驱node的ws为-1小于0,直接返回true,接下来去park线程。

parkAndCheckInterrupt()

1
2
3
4
5
6
7
8
private final boolean parkAndCheckInterrupt() {
// park自己
LockSupport.park(this);
// 执行完park()之后,线程就阻塞在这里了
// 这也是CountDownLatch调用await()的终点
// 等待countDown()中一旦state归0的unpark()
return Thread.interrupted();
}

countDown()方法

countDown()的入口就是下面的代码,语义是将AQS中的state减1:

1
2
3
public void countDown() {
sync.releaseShared(1);
}

sync.releaseShared(1)

releaseShared()无论怎样都会成功返回:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
public final boolean releaseShared(int arg) {
// 这个tryReleaseShared()就是将state减1
// 并且返回要不要unpark所有线程
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}

protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
// 死循环的目的是避免CAS操作失败
int c = getState();
// 拿到当前state
if (c == 0)
// 如果已经是0了,啥都不用做
// 外面也不用unpark
// 因为上一个归0的线程已经全部unpark好了,不需要重复调用
return false;
int nextc = c-1;
// CAS减1
if (compareAndSetState(c, nextc))
// 返回操作完的state是不是0
// 如果是0,出去unpark线程
// 如果不是,state单纯减1
return nextc == 0;
}
}

doReleaseShared()

这是个新的方法,我们重点看一下,以下说的都是state归0后,unpark线程的逻辑,要注意此时我们AQS队列中应该有节点了,至少有2个(别问我为啥),或者茫茫多。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
private void doReleaseShared() {
/*
* Ensure that a release propagates, even if there are other
* in-progress acquires/releases. This proceeds in the usual
* way of trying to unparkSuccessor of head if it needs
* signal. But if it does not, status is set to PROPAGATE to
* ensure that upon release, propagation continues.
* Additionally, we must loop in case a new node is added
* while we are doing this. Also, unlike other uses of
* unparkSuccessor, we need to know if CAS to reset status
* fails, if so rechecking.
*/
for (;;) {
// 取到头结点,第一次调用应该为虚拟Node
Node h = head;
if (h != null && h != tail) {
// 至少有2个Node,h不为null,也不为tail
int ws = h.waitStatus;
// 在await()中,h的ws已经被改为-1了
if (ws == Node.SIGNAL) {
// 将h的ws改为0
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
// 如果改失败,说明其他线程已经改过了
// 此时的h.next应该指向null了to help gc(之后会分析为啥h出队了)
// 那么跳过这轮检查,重新拿head
continue; // loop to recheck cases
// unpark h的下一个Node对应的线程
unparkSuccessor(h);
}
// 这个else和本流程无关
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
// 这个break非常巧妙,这里先留一个坑,下面再说等于head的退出条件
if (h == head) // loop if head changed
break;
}
}

这个方法从head这个虚拟节点开始unpark AQS队列中的所有线程,这里有2个问题:

  1. unparkSuccessor(h)做了啥?
  2. 这死循环是咋退出的?

unparkSuccessor(h)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
private void unparkSuccessor(Node node) {
/*
* If status is negative (i.e., possibly needing signal) try
* to clear in anticipation of signalling. It is OK if this
* fails or if status is changed by waiting thread.
*/
int ws = node.waitStatus;
// ws已经为0了,外面改过了
if (ws < 0)
compareAndSetWaitStatus(node, ws, 0);

/*
* Thread to unpark is held in successor, which is normally
* just the next node. But if cancelled or apparently null,
* traverse backwards from tail to find the actual
* non-cancelled successor.
*/
// 取到“排队”的第一个线程为s
Node s = node.next;
if (s == null || s.waitStatus > 0) {
// ws大于0这种情况我们先不讨论
s = null;
for (Node t = tail; t != null && t != node; t = t.prev)
if (t.waitStatus <= 0)
s = t;
}
// 如果有排队线程,unpark唤醒它
if (s != null)
LockSupport.unpark(s.thread);
}

唤醒线程之后,我们要回到await()的阻塞代码中看了。这本来不是countDown()的代码,但流程上是连通的,就在这里继续说吧。

await()中的阻塞唤醒

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (; ; ) {
// 2. p是node的前驱,此时就是head节点
final Node p = node.predecessor();
if (p == head) {
// state为0了,tryAcquireShared(arg)返回正数
int r = tryAcquireShared(arg);
if (r >= 0) {
// 将当前node设为head
// 并且一连串的unpark下去
// 往下看,详细分析
setHeadAndPropagate(node, r);
// 将head断链,也就是上面我们说的h.next指向null了
p.next = null; // help GC
failed = false;
// await() return,最外面阻塞的线程开始运行
// 火箭升空!掌声鼓励
return;
}
}

if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
// 1. unpark之后,判断没有被interrupt,继续下一轮循环
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}

每一个被unpark的线程都会尝试去CAS设置自己为head,并且去unpark下一个线程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
private void setHeadAndPropagate(Node node, int propagate) {
Node h = head; // Record old head for check below
// 将当前node设为head,head仍然为虚拟节点
setHead(node);
/*
* Try to signal next queued node if:
* Propagation was indicated by caller,
* or was recorded (as h.waitStatus either before
* or after setHead) by a previous operation
* (note: this uses sign-check of waitStatus because
* PROPAGATE status may transition to SIGNAL.)
* and
* The next node is waiting in shared mode,
* or we don't know, because it appears null
*
* The conservatism in both of these checks may cause
* unnecessary wake-ups, but only when there are multiple
* racing acquires/releases, so most need signals now or soon
* anyway.
*/
if (propagate > 0 || h == null || h.waitStatus < 0 ||
// 下面这些||操作是避免中间有别的线程改了head
// 于是重新获取head和判断ws
(h = head) == null || h.waitStatus < 0) {
// node已经unpark完了,取node的下一个节点
Node s = node.next;
// 下一个节点为null,或者不为空且为shared
if (s == null || s.isShared())
doReleaseShared();
}
}

最下面进入doReleaseShared()的条件为:

  • 下一个节点为null,那么此时所有的线程已经unpark了,按理来说再进入doReleaseShared()的时候应该退出那个死循环了;
  • 下一个节点不为null,并且是shared的,那么进死循环去unpark下一个Node。

再看一下doReleaseShared()方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
private void doReleaseShared() {
/*
* Ensure that a release propagates, even if there are other
* in-progress acquires/releases. This proceeds in the usual
* way of trying to unparkSuccessor of head if it needs
* signal. But if it does not, status is set to PROPAGATE to
* ensure that upon release, propagation continues.
* Additionally, we must loop in case a new node is added
* while we are doing this. Also, unlike other uses of
* unparkSuccessor, we need to know if CAS to reset status
* fails, if so rechecking.
*/
for (;;) {
Node h = head;
// 如果h为null了,所有的都unpark了,跳到最底下break
if (h != null && h != tail) {
// 如果队列中还有线程,继续unpark下一个
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
// 全部unpark结束,退出,所有线程的await()停止阻塞
if (h == head) // loop if head changed
break;
}
}

所以如此往复,直到所有线程都被唤醒。

CountDownLatch和Thread.join()对比

首先说CountDownLatch和join()的机制不同:

  • CountDownLatch是通过park()将线程挂起;
  • join()的原理是,假如在main线程中,调用某个thread的join()方法,那么main线程就会阻塞,直到thread执行完毕,main才继续执行;join()是不断检查thread是否存活,如果存活,那么让当前线程一直wait(),直到thread线程终止,线程的this.notifyAll()就会被调用

另外,join()一定要等thread完全执行完毕才能继续向下执行,而CountDownLatch只需要计数器为0就能继续执行。根据不同的业务场景,CountDownLatch更加灵活一些。

总结

看完这篇分析之后,尝试自行看看ReentrantReadWriteLock的源码吧。