Rubin's Blog

  • 首页
  • 关于作者
  • 隐私政策
享受恬静与美好~~~
分享生活的点点滴滴~~~
  1. 首页
  2. 并发编程
  3. 正文

java并发编程之同步工具类

2021年 12月 31日 520点热度 0人点赞 0条评论

Semaphore

Semaphore也就是信号量,提供了资源数量的并发访问控制,其使用代码很简单,如下所示:

package com.rubin.concurrent.semaphore;

import java.util.Random;
import java.util.concurrent.Semaphore;

public class MyThread extends Thread {

    private final Semaphore semaphore;
    private final Random random = new Random();

    public MyThread(String name, Semaphore semaphore) {
        super(name);
        this.semaphore = semaphore;
    }

    @Override
    public void run() {
        try {
            // 获取信标:抢座
            semaphore.acquire();
            // 抢到之后开始写作业
            System.out.println(Thread.currentThread().getName() + " - 抢到了座位,开始写作业");
            Thread.sleep(random.nextInt(1000));
            System.out.println(Thread.currentThread().getName() + " - 作业写完,腾出座位");
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        // 释放信标:腾出座位
        semaphore.release();
    }

}
package com.rubin.concurrent.semaphore;
import java.util.concurrent.Semaphore;

public class Main {
    public static void main(String[] args) {
        Semaphore semaphore = new Semaphore(1);
        for (int i = 0; i < 5; i++) {
            new MyThread("学生-" + (i + 1), semaphore).start();
        }
    }
}

如下图所示,假设有n个线程来获取Semaphore里面的10份资源(n > 10),n个线程中只有10个线程能获取到,其他线程都会阻塞。直到有线程释放了资源,其他线程才能获取到。

当初始的资源个数为1的时候,Semaphore退化为排他锁。正因为如此,Semaphone的实现原理和锁十分类似,是基于AQS,有公平和非公平之分。Semaphore相关类的继承体系如下图所示:

public void acquire() throws InterruptedException {
  sync.acquireSharedInterruptibly(1);
}

public void release() {
  sync.releaseShared(1);
}

上述代码我们可以看到,是调用的sync的方法。这个sync是什么呢?

由上图我们可以看到,Sync其实就是AbstractQueuedSynchronizer的实现类。而且有两个子类,实现了公平唤醒和非公平唤醒。由构造方法可以看出,默认是使用非公平唤醒的方式。

AbstractQueuedSynchronizer就是我们常说的AQS,这个类很重要,因为大部分的同步工具都是基于该类实现的并发同步。

我们看一下sync的acquireSharedInterruptibly方法的实现:

/**
  * Acquires in shared mode, aborting if interrupted.  Implemented
  * by first checking interrupt status, then invoking at least once
  * {@link #tryAcquireShared}, returning on success.  Otherwise the
  * thread is queued, possibly repeatedly blocking and unblocking,
  * invoking {@link #tryAcquireShared} until success or the thread
  * is interrupted.
  * @param arg the acquire argument.
  * This value is conveyed to {@link #tryAcquireShared} but is
  * otherwise uninterpreted and can represent anything
  * you like.
  * @throws InterruptedException if the current thread is interrupted
  */
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}

上述代码的逻辑,就是先检测一下线程是否接收过中断信号,如果接收过,直接中断线程。没有的话调用tryAcquireShared方法,该方法是一个抽象方法,我们看一下非公平唤醒(公平唤醒的逻辑大同小异,又想去的可以看一下)的实现:

protected int tryAcquireShared(int acquires) {
    return nonfairTryAcquireShared(acquires);
}

final int nonfairTryAcquireShared(int acquires) {
    for (;;) {
        int available = getState();
        int remaining = available - acquires;
        if (remaining < 0 ||
            compareAndSetState(available, remaining))
            return remaining;
    }
}

代码逻辑很简单,就是查看剩余的许可还有多少,如果大于0,直接获取许可(也就是CAS remaining成功);如果小于0或者CAS失败,返回一个负数,返回负数之后会进入AQS的doAcquireSharedInterruptibly方法,我们看一下该方法:

/**
  * Acquires in shared interruptible mode.
  * @param arg the acquire argument
  */
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

由于进入该方法的都是未获取许可但等待获取许可的线程,所以调用addWaiter将当前线程加入到阻塞队列(双向队列)中,之后进入死循环阻塞。

循环的逻辑很简单,就是查看一下自己是不是head节点的后续节点,不是的话检查是否需要阻塞,需要的话阻塞该线程;是的话调用tryAcquireShared方法继续尝试获取许可,未获取成功的话检查是否需要阻塞,需要的话阻塞该线程,获取成功的话后移head节点并进行一系列检查后唤醒该阻塞线程(也就是setHeadAndPropagate方法的逻辑)。

了解了获取许可,那释放许可就很简单了。 我们看一下sync的releaseShared方法的实现:

/**
  * Releases in shared mode.  Implemented by unblocking one or more
  * threads if {@link #tryReleaseShared} returns true.
  *
  * @param arg the release argument.  This value is conveyed to
  *        {@link #tryReleaseShared} but is otherwise uninterpreted
  *        and can represent anything you like.
  * @return the value returned from {@link #tryReleaseShared}
  */
public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}

逻辑很简单,就是先尝试释放许可,成功之后再执行释放许可的动作。tryReleaseShared在AQS里面也是一个抽象方法,我们看一下实现:

protected final boolean tryReleaseShared(int releases) {
    for (;;) {
        int current = getState();
        int next = current + releases;
        if (next < current) // overflow
            throw new Error("Maximum permit count exceeded");
        if (compareAndSetState(current, next))
            return true;
    }
}

尝试的过程,也就是不断的自旋尝试CAS许可数目,将其数量加上传过来要释放的许可数量。我们再来来看一下doReleaseShared方法的实现:

/**
  * Release action for shared mode -- signals successor and ensures
  * propagation. (Note: For exclusive mode, release just amounts
  * to calling unparkSuccessor of head if it needs signal.)
  */
private void doReleaseShared() {
    /*
     * Ensure that a release propagates, even if there are other
     * in-progress acquires/releases.  This proceeds in the usual
     * way of trying to unparkSuccessor of head if it needs
     * signal. But if it does not, status is set to PROPAGATE to
     * ensure that upon release, propagation continues.
     * Additionally, we must loop in case a new node is added
     * while we are doing this. Also, unlike other uses of
     * unparkSuccessor, we need to know if CAS to reset status
     * fails, if so rechecking.
     */
    for (;;) {
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
            break;
    }
}

其实就是唤醒阻塞队列的head节点的线程。

CountDownLatch

CountDownLatch使用场景

假设一个主线程要等待4个Worker线程执行完才能退出,可以使用CountDownLatch来实现:

package com.rubin.concurrent.countdownlatch;

import java.util.Random;
import java.util.concurrent.CountDownLatch;

public class MyThread extends Thread {

    private final CountDownLatch latch;
    private final Random random = new Random();

    public MyThread(String name, CountDownLatch latch) {
        super(name);
        this.latch = latch;
    }

    @Override
    public void run() {
        try {
            Thread.sleep(random.nextInt(2000));
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println(Thread.currentThread().getName() + " - 执行完毕");
        // latch计数减一
        latch.countDown();
    }

}

package com.rubin.concurrent.countdownlatch;

import java.util.concurrent.CountDownLatch;

public class Main {

    public static void main(String[] args) throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(4);
        for (int i = 0; i < 4; i++) {
            new MyThread("线程" + (i + 1), latch).start();
        }
        // main线程等待
        latch.await();
        System.out.println("main线程执行结束");
    }

}

下图为CountDownLatch相关类的继承层次,CountDownLatch原理和Semaphore原理类似,同样是基于AQS,不过没有公平和非公平之分。

await()实现分析

如下所示,await()调用的是AQS的模板方法,这个方法在前面已经介绍过。CountDownLatch.Sync重新实现了tryAccuqireShared方法:

public void await() throws InterruptedException {
    // AQS的模板方法
    sync.acquireSharedInterruptibly(1);
}
public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
    if (Thread.interrupted())
      throw new InterruptedException();
    // 被CountDownLatch.Sync实现
    if (tryAcquireShared(arg) < 0)
      doAcquireSharedInterruptibly(arg);
}
protected int tryAcquireShared(int acquires) {
   return (getState() == 0) ? 1 : -1;
}

从tryAcquireShared(…)方法的实现来看,只要state != 0,调用await()方法的线程便会被放入AQS的阻塞队列,进入阻塞状态。

countDown()实现分析

public void countDown() {
    sync.releaseShared(1);
}

public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}

protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

countDown()调用的AQS的模板方法releaseShared(),里面的tryReleaseShared(…)由CountDownLatch.Sync实现。从上面的代码可以看出,只有state=0,tryReleaseShared(…)才会返回true,然后执行doReleaseShared(…),一次性唤醒队列中所有阻塞的线程。

总结:由于是基于AQS阻塞队列来实现的,所以可以让多个线程都阻塞在state=0条件上,通过countDown()一直减state,减到0后一次性唤醒所有线程。如下图所示,假设初始总数为M,N个线程await(),M个线程countDown(),减到0之后,N个线程被唤醒。

CyclicBarrier

CyclicBarrier使用场景

使用场景:10个工程师一起来公司应聘,招聘方式分为笔试和面试。首先,要等人到齐后,开始笔试;笔试结束之后,再一起参加面试。把10个人看作10个线程,10个线程之间的同步过程如下图所示:

代码实现:

package com.rubin.concurrent.cyclicbarrier;

import java.util.Random;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

public class MyThread extends Thread {

    private final CyclicBarrier barrier;
    private final Random random = new Random();

    public MyThread(String name, CyclicBarrier barrier) {
        super(name);
        this.barrier = barrier;
    }

    @Override
    public void run() {
        try {
            System.out.println(Thread.currentThread().getName() + " - 向公司出发");
            Thread.sleep(random.nextInt(5000));
            System.out.println(Thread.currentThread().getName() + " - 已经到达公司");
            // 等待其他线程该阶段结束
            barrier.await();

            System.out.println(Thread.currentThread().getName() + " - 开始笔试");
            Thread.sleep(random.nextInt(5000));
            System.out.println(Thread.currentThread().getName() + " - 笔试结束");
            // 等待其他线程该阶段结束
            barrier.await();

            System.out.println(Thread.currentThread().getName() + " - 开始面试");
            Thread.sleep(random.nextInt(5000));
            System.out.println(Thread.currentThread().getName() + " - 面试结束");

        } catch (InterruptedException e) {
            e.printStackTrace();
        } catch (BrokenBarrierException e) {
            e.printStackTrace();
        }
    }
}
package com.rubin.concurrent.cyclicbarrier;

import java.util.concurrent.CyclicBarrier;

public class Main {

    public static void main(String[] args) {
//        CyclicBarrier barrier = new CyclicBarrier(5);
        CyclicBarrier barrier = new CyclicBarrier(5, new Runnable() {
            @Override
            public void run() {
                System.out.println("该阶段结束");
            }
        });

        for (int i = 0; i < 5; i++) {
            new MyThread("线程-" + (i + 1), barrier).start();
        }
    }

}

CyclicBarrier实现原理

CyclicBarrier基于ReentrantLock+Condition实现。

public class CyclicBarrier {
   private final ReentrantLock lock = new ReentrantLock();
   // 用于线程之间相互唤醒
   private final Condition trip = lock.newCondition();
   // 线程总数
   private final int parties;
   private int count;
   private Generation generation = new Generation();
   // ...
}

下面详细介绍CyclicBarrier的实现原理。先看构造方法:

public CyclicBarrier(int parties, Runnable barrierAction) {
  if (parties <= 0) throw new IllegalArgumentException();
  // 参与方数量
  this.parties = parties;
  this.count = parties;
  // 当所有线程被唤醒时,执行barrierCommand表示的Runnable。
  this.barrierCommand = barrierAction;
}

接下来看一下await()方法的实现过程:

public int await() throws InterruptedException, BrokenBarrierException {
    try {
        return dowait(false, 0L);
    } catch (TimeoutException toe) {
        throw new Error(toe); // cannot happen
    }
}

private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
           TimeoutException {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        final Generation g = generation;

        if (g.broken)
            throw new BrokenBarrierException();

        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }

        int index = --count;
        if (index == 0) {  // tripped
            boolean ranAction = false;
            try {
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run();
                ranAction = true;
                nextGeneration();
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }

        // loop until tripped, broken, interrupted, or timed out
        for (;;) {
            try {
                if (!timed)
                    trip.await();
                else if (nanos > 0L)
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    // We're about to finish waiting even if we had not
                    // been interrupted, so this interrupt is deemed to
                    // "belong" to subsequent execution.
                    Thread.currentThread().interrupt();
                }
            }

            if (g.broken)
                throw new BrokenBarrierException();

            if (g != generation)
                return index;

            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}

关于上面的方法,有几点说明:

  1. CyclicBarrier是可以被重用的。以上一节的应聘场景为例,来了10个线程,这10个线程互相等待,到齐后一起被唤醒,各自执行接下来的逻辑;然后,这10个线程继续互相等待,到齐后再一起被唤醒。每一轮被称为一个Generation,就是一次同步点
  2. CyclicBarrier会响应中断。10 个线程没有到齐,如果有线程收到了中断信号,所有阻塞的线程也会被唤醒,就是上面的breakBarrier()方法。然后count被重置为初始值(parties),重新开始
  3. 上面的回调方法,barrierAction只会被第10个线程执行1次(在唤醒其他9个线程之前),而不是10个线程每个都执行1次

Exchanger

使用场景

Exchanger用于线程之间交换数据,其使用代码很简单,是一个exchange(…)方法,使用示例如下:

package com.rubin.concurrent.exchanger;

import java.util.Random;
import java.util.concurrent.Exchanger;

public class Main {

    private static final Random random = new Random();

    public static void main(String[] args) {
        // 建一个多线程共用的exchange对象
        // 把exchange对象传给3个线程对象。每个线程在自己的run方法中调用exchange,把自己的数据作为参数
        // 传递进去,返回值是另外一个线程调用exchange传进去的参数
        Exchanger<String> exchanger = new Exchanger<>();

        new Thread("线程1") {
            @Override
            public void run() {
                while (true) {
                    try {
                        // 如果没有其他线程调用exchange,线程阻塞,直到有其他线程调用exchange为止。
                        String otherData = exchanger.exchange("交换数据1");
                        System.out.println(Thread.currentThread().getName() + "得到<==" + otherData);
                        Thread.sleep(random.nextInt(2000));
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            }
        }.start();

        new Thread("线程2") {
            @Override
            public void run() {
                while (true) {
                    try {
                        String otherData = exchanger.exchange("交换数据2");
                        System.out.println(Thread.currentThread().getName() + "得到<==" + otherData);
                        Thread.sleep(random.nextInt(2000));
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            }
        }.start();

        new Thread("线程3") {
            @Override
            public void run() {
                while (true) {
                    try {
                        String otherData = exchanger.exchange("交换数据3");
                        System.out.println(Thread.currentThread().getName() + "得到<==" + otherData);
                        Thread.sleep(random.nextInt(2000));
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            }
        }.start();
    }
}

在上面的例子中,3个线程并发地调用exchange(…),会两两交互数据,如1/2、1/3和2/3。

实现原理

Exchanger的核心机制和Lock一样,也是CAS+park/unpark。

首先,在Exchanger内部,有两个内部类:Participant和Node,代码如下:

public class Exchanger<V> {
  // ...
  // 添加了Contended注解,表示伪共享与缓存行填充
  @jdk.internal.vm.annotation.Contended static final class Node {
    int index;        // Arena index
    int bound;        // Last recorded value of Exchanger.bound
    int collides;      // 本次绑定中,CAS操作失败次数
    int hash;        // 自旋伪随机
    Object item;       // 本线程要交换的数据
    volatile Object match;  // 对方线程交换来的数据
    // 当前线程
    volatile Thread parked; // 当前线程阻塞的时候设置该属性,不阻塞为null。
   }
 
  static final class Participant extends ThreadLocal<Node> {
    public Node initialValue() { return new Node(); }
   }
  // ...
}

每个线程在调用exchange(…)方法交换数据的时候,会先创建一个Node对象。

这个Node对象就是对该线程的包装,里面包含了3个重要字段:第一个是该线程要交互的数据,第二个是对方线程交换来的数据,最后一个是该线程自身。

一个Node只能支持2个线程之间交换数据,要实现多个线程并行地交换数据,需要多个Node,因此在Exchanger里面定义了Node数组:

exchange(V x)实现分析

明白了大致思路,下面来看exchange(V x)方法的详细实现:

public V exchange(V x) throws InterruptedException {
    Object v;
    Object item = (x == null) ? NULL_ITEM : x; // translate null args
    if ((arena != null ||
         (v = slotExchange(item, false, 0L)) == null) &&
        ((Thread.interrupted() || // disambiguates null return
          (v = arenaExchange(item, false, 0L)) == null)))
        throw new InterruptedException();
    return (v == NULL_ITEM) ? null : (V)v;
}

上面方法中,如果arena不是null,表示启用了arena方式交换数据。如果arena不是null,并且线程被中断,则抛异常。

如果arena不是null,并且arenaExchange的返回值为null,则抛异常。对方线程交换来的null值是封装为NULL_ITEM对象的,而不是null。

如果arena是null,slotExchange的返回值是null,并且线程被中断,则抛异常。

如果arena是null,slotExchange的返回值是null,并且areaExchange的返回值是null,则抛异常。

slotExchange的实现:

private final Object slotExchange(Object item, boolean timed, long ns) {
    Node p = participant.get();
    Thread t = Thread.currentThread();
    if (t.isInterrupted()) // preserve interrupt status so caller can recheck
        return null;

    for (Node q;;) {
        if ((q = slot) != null) {
            if (U.compareAndSwapObject(this, SLOT, q, null)) {
                Object v = q.item;
                q.match = item;
                Thread w = q.parked;
                if (w != null)
                    U.unpark(w);
                return v;
            }
            // create arena on contention, but continue until slot null
            if (NCPU > 1 && bound == 0 &&
                U.compareAndSwapInt(this, BOUND, 0, SEQ))
                arena = new Node[(FULL + 2) << ASHIFT];
        }
        else if (arena != null)
            return null; // caller must reroute to arenaExchange
        else {
            p.item = item;
            if (U.compareAndSwapObject(this, SLOT, null, p))
                break;
            p.item = null;
        }
    }

    // await release
    int h = p.hash;
    long end = timed ? System.nanoTime() + ns : 0L;
    int spins = (NCPU > 1) ? SPINS : 1;
    Object v;
    while ((v = p.match) == null) {
        if (spins > 0) {
            h ^= h << 1; h ^= h >>> 3; h ^= h << 10;
            if (h == 0)
                h = SPINS | (int)t.getId();
            else if (h < 0 && (--spins & ((SPINS >>> 1) - 1)) == 0)
                Thread.yield();
        }
        else if (slot != p)
            spins = SPINS;
        else if (!t.isInterrupted() && arena == null &&
                 (!timed || (ns = end - System.nanoTime()) > 0L)) {
            U.putObject(t, BLOCKER, this);
            p.parked = t;
            if (slot == p)
                U.park(false, ns);
            p.parked = null;
            U.putObject(t, BLOCKER, null);
        }
        else if (U.compareAndSwapObject(this, SLOT, p, null)) {
            v = timed && ns <= 0L && !t.isInterrupted() ? TIMED_OUT : null;
            break;
        }
    }
    U.putOrderedObject(p, MATCH, null);
    p.item = null;
    p.hash = h;
    return v;
}

arenaExchange的实现:

private final Object arenaExchange(Object item, boolean timed, long ns) {
    Node[] a = arena;
    Node p = participant.get();
    for (int i = p.index;;) {                      // access slot at i
        int b, m, c; long j;                       // j is raw array offset
        Node q = (Node)U.getObjectVolatile(a, j = (i << ASHIFT) + ABASE);
        if (q != null && U.compareAndSwapObject(a, j, q, null)) {
            Object v = q.item;                     // release
            q.match = item;
            Thread w = q.parked;
            if (w != null)
                U.unpark(w);
            return v;
        }
        else if (i <= (m = (b = bound) & MMASK) && q == null) {
            p.item = item;                         // offer
            if (U.compareAndSwapObject(a, j, null, p)) {
                long end = (timed && m == 0) ? System.nanoTime() + ns : 0L;
                Thread t = Thread.currentThread(); // wait
                for (int h = p.hash, spins = SPINS;;) {
                    Object v = p.match;
                    if (v != null) {
                        U.putOrderedObject(p, MATCH, null);
                        p.item = null;             // clear for next use
                        p.hash = h;
                        return v;
                    }
                    else if (spins > 0) {
                        h ^= h << 1; h ^= h >>> 3; h ^= h << 10; // xorshift
                        if (h == 0)                // initialize hash
                            h = SPINS | (int)t.getId();
                        else if (h < 0 &&          // approx 50% true
                                 (--spins & ((SPINS >>> 1) - 1)) == 0)
                            Thread.yield();        // two yields per wait
                    }
                    else if (U.getObjectVolatile(a, j) != p)
                        spins = SPINS;       // releaser hasn't set match yet
                    else if (!t.isInterrupted() && m == 0 &&
                             (!timed ||
                              (ns = end - System.nanoTime()) > 0L)) {
                        U.putObject(t, BLOCKER, this); // emulate LockSupport
                        p.parked = t;              // minimize window
                        if (U.getObjectVolatile(a, j) == p)
                            U.park(false, ns);
                        p.parked = null;
                        U.putObject(t, BLOCKER, null);
                    }
                    else if (U.getObjectVolatile(a, j) == p &&
                             U.compareAndSwapObject(a, j, p, null)) {
                        if (m != 0)                // try to shrink
                            U.compareAndSwapInt(this, BOUND, b, b + SEQ - 1);
                        p.item = null;
                        p.hash = h;
                        i = p.index >>>= 1;        // descend
                        if (Thread.interrupted())
                            return null;
                        if (timed && m == 0 && ns <= 0L)
                            return TIMED_OUT;
                        break;                     // expired; restart
                    }
                }
            }
            else
                p.item = null;                     // clear offer
        }
        else {
            if (p.bound != b) {                    // stale; reset
                p.bound = b;
                p.collides = 0;
                i = (i != m || m == 0) ? m : m - 1;
            }
            else if ((c = p.collides) < m || m == FULL ||
                     !U.compareAndSwapInt(this, BOUND, b, b + SEQ + 1)) {
                p.collides = c + 1;
                i = (i == 0) ? m : i - 1;          // cyclically traverse
            }
            else
                i = m + 1;                         // grow
            p.index = i;
        }
    }
}

Phaser

用Phaser替代CyclicBarrier和CountDownLatch

从JDK7开始,新增了一个同步工具类Phaser,其功能比CyclicBarrier和CountDownLatch更加强大。

用Phaser替代CountDownLatch

考虑我们使用CountDownLatch的例子,1个主线程要等4个Worker线程完成之后,才能做接下来的事情,也可以用Phaser来实现此功能。在CountDownLatch中,主要是2个方法:await()和countDown()。在Phaser中,与之相对应的方法是awaitAdance(int n)和arrive()。

package com.rubin.concurrent.phaser;

import java.util.Random;
import java.util.concurrent.Phaser;

public class Main {

    public static void main(String[] args) {
        Phaser phaser = new Phaser(5);
        for (int i = 0; i < 5; i++) {
            new Thread("线程-" + (i + 1)) {
                private final Random random = new Random();
                @Override
                public void run() {
                    System.out.println(getName() + " - 开始运行");
                    try {
                        Thread.sleep(random.nextInt(1000));
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                    System.out.println(getName() + " - 运行结束");
                    phaser.arrive();
                }
            }.start();
        }
        System.out.println("线程启动完毕");
        System.out.println(phaser.getPhase());
//        phaser.awaitAdvance(phaser.getPhase());
        phaser.awaitAdvance(0);
        System.out.println("线程运行结束");
    }

}

用Phaser替代CyclicBarrier

考虑前面我们使用CyclicBarrier时,10个工程师去公司应聘的例子,也可以用Phaser实现,代码基本类似:

package com.rubin.concurrent.phaser;

import java.util.Random;
import java.util.concurrent.Phaser;

public class MyThread extends Thread {

    private final Phaser phaser;
    private final Random random = new Random();

    public MyThread(String name, Phaser phaser) {
        super(name);
        this.phaser = phaser;
    }

    @Override
    public void run() {
        System.out.println(getName() + " - 开始向公司出发");
        slowly();
        System.out.println(getName() + " - 已经到达公司");
        // 到达同步点,等待其他线程
        phaser.arriveAndAwaitAdvance();

        System.out.println(getName() + " - 开始笔试");
        slowly();
        System.out.println(getName() + " - 笔试结束");
        // 到达同步点,等待其他线程
        phaser.arriveAndAwaitAdvance();

        System.out.println(getName() + " - 开始面试");
        slowly();
        System.out.println(getName() + " - 面试结束");
    }

    private void slowly() {
        try {
            Thread.sleep(random.nextInt(1000));
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

}
package com.rubin.concurrent.phaser;

import java.util.concurrent.Phaser;

public class Main1 {

    public static void main(String[] args) {
        Phaser phaser = new Phaser(5);
        for (int i = 0; i < 5; i++) {
            new MyThread("线程-" + (i + 1), phaser).start();
        }
        phaser.awaitAdvance(0);
        System.out.println("运行结束");
    }

}

arriveAndAwaitAdance()就是arrive()与awaitAdvance(int)的组合,表示“我自己已到达这个同步点,同时要等待所有人都到达这个同步点,然后再一起前行”。

Phaser新特性

特性1:动态调整线程个数

CyclicBarrier所要同步的线程个数是在构造方法中指定的,之后不能更改,而Phaser可以在运行期间动态地调整要同步的线程个数。Phaser提供了下面这些方法来增加、减少所要同步的线程个数。

register() // 注册一个
bulkRegister(int parties) // 注册多个
arriveAndDeregister()  // 解除注册

特性2:层次Phaser

多个Phaser可以组成如下图所示的树状结构,可以通过在构造方法中传入父Phaser来实现。

public Phaser(Phaser parent, int parties) {
  // ...
}

先简单看一下Phaser内部关于树状结构的存储,如下所示:

可以发现,在Phaser的内部结构中,每个Phaser记录了自己的父节点,但并没有记录自己的子节点列表。所以,每个 Phaser知道自己的父节点是谁,但父节点并不知道自己有多少个子节点,对父节点的操作,是通过子节点来实现的。

树状的Phaser怎么使用呢?考虑如下代码,会组成下图的树状Phaser。

Phaser root = new Phaser(2);
Phaser c1 = new Phaser(root, 3);
Phaser c2 = new Phaser(root, 2);
Phaser c3 = new Phaser(c1, 0);

本来root有两个参与者,然后为其加入了两个子Phaser(c1,c2),每个子Phaser会算作1个参与者,root的参与者就变成2+2=4个。c1本来有3个参与者,为其加入了一个子Phaser c3,参与者数量变成3+1=4个。c3的参与者初始为0,后续可以通过调用register()方法加入。

对于树状Phaser上的每个节点来说,可以当作一个独立的Phaser来看待,其运作机制和一个单独的Phaser是一样的。

父Phaser并不用感知子Phaser的存在,当子Phaser中注册的参与者数量大于0时,会把自己向父节点注册;当子Phaser中注册的参与者数量等于0时,会自动向父节点解除注册。父Phaser把子Phaser当作一个正常参与的线程就即可。

state变量解析

大致了解了Phaser的用法和新特性之后,下面仔细剖析其实现原理。Phaser没有基于AQS来实现,但具备AQS的核心特性:state变量、CAS操作、阻塞队列。先从state变量说起。

这个64位的state变量被拆成4部分,下图为state变量各部分:

最高位0表示未同步完成,1表示同步完成,初始最高位为0。

Phaser提供了一系列的成员方法来从state中获取上图中的几个数字,如下所示:

下面再看一下state变量在构造方法中是如何被赋值的:

public Phaser(Phaser parent, int parties) {
  if (parties >>> PARTIES_SHIFT != 0)
    // 如果parties数超出了最大个数(2的16次方),抛异常
    throw new IllegalArgumentException("Illegal number of parties");
  // 初始化轮数为0
  int phase = 0;
  this.parent = parent;
  if (parent != null) {
    final Phaser root = parent.root;
    // 父节点的根节点就是自己的根节点
    this.root = root;
    // 父节点的evenQ就是自己的evenQ
    this.evenQ = root.evenQ;
    // 父节点的oddQ就是自己的oddQ
    this.oddQ = root.oddQ;
      // 如果参与者不是0,则向父节点注册自己
    if (parties != 0)
      phase = parent.doRegister(1);
   }
 else {
    // 如果父节点为null,则自己就是root节点
    this.root = this;
    // 创建奇数节点
    this.evenQ = new AtomicReference<QNode>();
    // 创建偶数节点
    this.oddQ = new AtomicReference<QNode>();
   }
  this.state = (parties == 0) ? (long)EMPTY :
   ((long)phase << PHASE_SHIFT) |       // 位或操作,赋值state 最高位
为0,表示同步未完成
   ((long)parties << PARTIES_SHIFT) |
   ((long)parties);
}

当parties=0时,state被赋予一个EMPTY常量,常量为1。

当parties != 0时,把phase值左移32位;把parties左移16位;然后parties也作为最低的16位,3个值做或操作,赋值给state。

阻塞与唤醒(Treiber Stack)

基于上述的state变量,对其执行CAS操作,并进行相应的阻塞与唤醒。主线程会调用awaitAdvance()进行阻塞;子线程调用arrive()会对state进行CAS的累减操作,当未到达的线程数减到0时,唤醒阻塞的主线程。

在这里,阻塞使用的是一个称为Treiber Stack的数据结构,而不是AQS的双向链表。Treiber Stack是一个无锁的栈,它是一个单向链表,出栈、入栈都在链表头部,所以只需要一个head指针,而不需要tail指针,如下的实现:

为了减少并发冲突,这里定义了2个链表,也就是2个Treiber Stack。当phase为奇数轮的时候,阻塞线程放在oddQ里面;当phase为偶数轮的时候,阻塞线程放在evenQ里面。代码如下所示:

arrive()方法分析

下面看arrive()方法是如何对state变量进行操作,又是如何唤醒线程的。

public int arrive() {
    return doArrive(ONE_ARRIVAL);
}

private int doArrive(int adjust) {
    final Phaser root = this.root;
    for (;;) {
        long s = (root == this) ? state : reconcileState();
        int phase = (int)(s >>> PHASE_SHIFT);
        if (phase < 0)
            return phase;
        int counts = (int)s;
        int unarrived = (counts == EMPTY) ? 0 : (counts & UNARRIVED_MASK);
        if (unarrived <= 0)
            throw new IllegalStateException(badArrive(s));
        if (UNSAFE.compareAndSwapLong(this, stateOffset, s, s-=adjust)) {
            if (unarrived == 1) {
                long n = s & PARTIES_MASK;  // base of next state
                int nextUnarrived = (int)n >>> PARTIES_SHIFT;
                if (root == this) {
                    if (onAdvance(phase, nextUnarrived))
                        n |= TERMINATION_BIT;
                    else if (nextUnarrived == 0)
                        n |= EMPTY;
                    else
                        n |= nextUnarrived;
                    int nextPhase = (phase + 1) & MAX_PHASE;
                    n |= (long)nextPhase << PHASE_SHIFT;
                    UNSAFE.compareAndSwapLong(this, stateOffset, s, n);
                    releaseWaiters(phase);
                }
                else if (nextUnarrived == 0) { // propagate deregistration
                    phase = parent.doArrive(ONE_DEREGISTER);
                    UNSAFE.compareAndSwapLong(this, stateOffset,
                                              s, s | EMPTY);
                }
                else
                    phase = parent.doArrive(ONE_ARRIVAL);
            }
            return phase;
        }
    }
}

关于上面的方法,有以下几点说明:

  • 定义了2个常量如下:

当deregister=false时,只最低的16位需要减 1,adj=ONE_ARRIVAL;当deregister=true时,低32位中的2个16位都需要减1,adj=ONE_ARRIVAL|ONE_PARTY。

  • 把未到达线程数减1。减了之后,如果还未到0,什么都不做,直接返回。如果到0,会做2件事情:第1,重置state,把state的未到达线程个数重置到总的注册的线程数中,同时phase加1;第2,唤醒队列中的线程

下面看一下唤醒方法:

private void releaseWaiters(int phase) {
    QNode q;   // first element of queue
    Thread t;  // its thread
    AtomicReference<QNode> head = (phase & 1) == 0 ? evenQ : oddQ;
    while ((q = head.get()) != null &&
           q.phase != (int)(root.state >>> PHASE_SHIFT)) {
        if (head.compareAndSet(q, q.next) &&
            (t = q.thread) != null) {
            q.thread = null;
            LockSupport.unpark(t);
        }
    }
}

遍历整个栈,只要栈当中节点的phase不等于当前Phaser的phase,说明该节点不是当前轮的,而是前一轮的,应该被释放并唤醒。

awaitAdvance()方法分析

查看internalAwaitAdvance方法如下:

/**
  * Possibly blocks and waits for phase to advance unless aborted.
  * Call only on root phaser.
  *
  * @param phase current phase
  * @param node if non-null, the wait node to track interrupt and timeout;
  * if null, denotes noninterruptible wait
  * @return current phase
  */
private int internalAwaitAdvance(int phase, QNode node) {
    // assert root == this;
    releaseWaiters(phase-1);          // ensure old queue clean
    boolean queued = false;           // true when node is enqueued
    int lastUnarrived = 0;            // to increase spins upon change
    int spins = SPINS_PER_ARRIVAL;
    long s;
    int p;
    while ((p = (int)((s = state) >>> PHASE_SHIFT)) == phase) {
        if (node == null) {           // spinning in noninterruptible mode
            int unarrived = (int)s & UNARRIVED_MASK;
            if (unarrived != lastUnarrived &&
                (lastUnarrived = unarrived) < NCPU)
                spins += SPINS_PER_ARRIVAL;
            boolean interrupted = Thread.interrupted();
            if (interrupted || --spins < 0) { // need node to record intr
                node = new QNode(this, phase, false, false, 0L);
                node.wasInterrupted = interrupted;
            }
        }
        else if (node.isReleasable()) // done or aborted
            break;
        else if (!queued) {           // push onto queue
            AtomicReference<QNode> head = (phase & 1) == 0 ? evenQ : oddQ;
            QNode q = node.next = head.get();
            if ((q == null || q.phase == phase) &&
                (int)(state >>> PHASE_SHIFT) == phase) // avoid stale enq
                queued = head.compareAndSet(q, node);
        }
        else {
            try {
                ForkJoinPool.managedBlock(node);
            } catch (InterruptedException ie) {
                node.wasInterrupted = true;
            }
        }
    }

    if (node != null) {
        if (node.thread != null)
            node.thread = null;       // avoid need for unpark()
        if (node.wasInterrupted && !node.interruptible)
            Thread.currentThread().interrupt();
        if (p == phase && (p = (int)(state >>> PHASE_SHIFT)) == phase)
            return abortWait(phase); // possibly clean up on abort
    }
    releaseWaiters(phase);
    return p;
}

上面的while循环中有4个分支:

  • 初始的时候,node==null,进入第1个分支进行自旋,自旋次数满足之后,会新建一个QNode节点
  • 之后执行第3、第4个分支,分别把该节点入栈并阻塞

这里调用了ForkJoinPool.managedBlock(ManagedBlocker blocker)方法,目的是把node对应的线程阻塞。ManagerdBlocker是ForkJoinPool里面的一个接口,定义如下:

public static interface ManagedBlocker {
  boolean block() throws InterruptedException;
  boolean isReleasable();
}

QNode实现了该接口,实现原理还是park(),如下所示。之所以没有直接使用park()/unpark()来实现阻塞、唤醒,而是封装了ManagedBlocker这一层,主要是出于使用上的方便考虑。一方面是park()可能被中断唤醒,另一方面是带超时时间的park(),把这二者都封装在一起。

/**
  * Wait nodes for Treiber stack representing wait queue
  */
static final class QNode implements ForkJoinPool.ManagedBlocker {
    final Phaser phaser;
    final int phase;
    final boolean interruptible;
    final boolean timed;
    boolean wasInterrupted;
    long nanos;
    final long deadline;
    volatile Thread thread; // nulled to cancel wait
    QNode next;

    QNode(Phaser phaser, int phase, boolean interruptible,
          boolean timed, long nanos) {
        this.phaser = phaser;
        this.phase = phase;
        this.interruptible = interruptible;
        this.nanos = nanos;
        this.timed = timed;
        this.deadline = timed ? System.nanoTime() + nanos : 0L;
        thread = Thread.currentThread();
    }

    public boolean isReleasable() {
        if (thread == null)
            return true;
        if (phaser.getPhase() != phase) {
            thread = null;
            return true;
        }
        if (Thread.interrupted())
            wasInterrupted = true;
        if (wasInterrupted && interruptible) {
            thread = null;
            return true;
        }
        if (timed) {
            if (nanos > 0L) {
                nanos = deadline - System.nanoTime();
            }
            if (nanos <= 0L) {
                thread = null;
                return true;
            }
        }
        return false;
    }

    public boolean block() {
        if (isReleasable())
            return true;
        else if (!timed)
            LockSupport.park(this);
        else if (nanos > 0L)
            LockSupport.parkNanos(this, nanos);
        return isReleasable();
    }
}

理解了arrive()和awaitAdvance(),arriveAndAwaitAdvance()就是二者的一个组合版本。

以上就是本文的全部内容。欢迎小伙伴们积极留言交流~~~

本作品采用 知识共享署名 4.0 国际许可协议 进行许可
标签: 并发编程
最后更新:2022年 6月 9日

RubinChu

一个快乐的小逗比~~~

打赏 点赞
< 上一篇
下一篇 >

文章评论

razz evil exclaim smile redface biggrin eek confused idea lol mad twisted rolleyes wink cool arrow neutral cry mrgreen drooling persevering
取消回复
文章目录
  • Semaphore
  • CountDownLatch
    • CountDownLatch使用场景
    • await()实现分析
    • countDown()实现分析
  • CyclicBarrier
    • CyclicBarrier使用场景
    • CyclicBarrier实现原理
  • Exchanger
    • 使用场景
    • 实现原理
      • exchange(V x)实现分析
  • Phaser
    • 用Phaser替代CyclicBarrier和CountDownLatch
      • 用Phaser替代CountDownLatch
      • 用Phaser替代CyclicBarrier
    • Phaser新特性
      • 特性1:动态调整线程个数
      • 特性2:层次Phaser
    • state变量解析
    • 阻塞与唤醒(Treiber Stack)
    • arrive()方法分析
    • awaitAdvance()方法分析
最新 热点 随机
最新 热点 随机
问题记录之Chrome设置屏蔽Https禁止调用Http行为 问题记录之Mac设置软链接 问题记录之JDK8连接MySQL数据库失败 面试系列之自我介绍 面试总结 算法思维
MySQL之Sharding-JDBC编排治理剖析 MongoDB之数据备份与恢复 Tomcat之性能优化 Docker之安装 Netty源码环境搭建 Kafka高级特性之生产者

COPYRIGHT © 2021 rubinchu.com. ALL RIGHTS RESERVED.

Theme Kratos Made By Seaton Jiang

京ICP备19039146号-1