use crate::job::{JobFifo, JobRef, StackJob};
use crate::latch::{AsCoreLatch, CoreLatch, CountLatch, Latch, LockLatch, SpinLatch};
use crate::log::Event::*;
use crate::log::Logger;
use crate::sleep::Sleep;
use crate::unwind;
use crate::{
ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder,
};
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use std::any::Any;
use std::cell::Cell;
use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::Hasher;
use std::io;
use std::mem;
use std::ptr;
#[allow(deprecated)]
use std::sync::atomic::ATOMIC_USIZE_INIT;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Once};
use std::thread;
use std::usize;
/// Description of one worker thread, handed to a [`ThreadSpawn`] implementation.
///
/// Carries the optional OS-thread name and stack size, plus everything the worker
/// needs to enter its main loop: its own job deque, its registry, and its index.
pub struct ThreadBuilder {
    name: Option<String>,
    stack_size: Option<usize>,
    worker: Worker<JobRef>,
    registry: Arc<Registry>,
    index: usize,
}

impl ThreadBuilder {
    /// Index of this thread within its pool.
    pub fn index(&self) -> usize {
        self.index
    }

    /// Name requested for this thread, if one was configured.
    pub fn name(&self) -> Option<&str> {
        self.name.as_ref().map(|name| &name[..])
    }

    /// Stack size requested for this thread, if one was configured.
    pub fn stack_size(&self) -> Option<usize> {
        self.stack_size
    }

    /// Runs the worker's main loop on the calling thread, consuming the builder.
    ///
    /// Returns after the pool has signalled this worker to terminate and its exit
    /// handler (if any) has run.
    pub fn run(self) {
        let ThreadBuilder {
            worker,
            registry,
            index,
            ..
        } = self;
        // SAFETY: NOTE(review): `main_loop` is an `unsafe fn`; these values are
        // assembled by `Registry::new`, which pairs each deque with its own index.
        unsafe { main_loop(worker, registry, index) }
    }
}
/// Debug output shows the owning pool's id rather than the whole registry.
impl fmt::Debug for ThreadBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ThreadBuilder");
        out.field("pool", &self.registry.id());
        out.field("index", &self.index);
        out.field("name", &self.name);
        out.field("stack_size", &self.stack_size);
        out.finish()
    }
}
/// Strategy for starting the OS threads of a pool.
///
/// `Registry::new` calls `spawn` once per worker. `DefaultSpawn` runs
/// [`ThreadBuilder::run`] on a new `std::thread`; custom strategies are
/// presumably expected to arrange for `run` to be called as well.
pub trait ThreadSpawn {
    // Crate-local macro; NOTE(review): presumably seals the trait so it cannot be
    // implemented outside this crate — confirm against the macro definition.
    private_decl! {}
    /// Starts a thread for `thread`; an `Err` aborts pool construction.
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()>;
}
/// Default spawn strategy: one detached `std::thread` per worker, honouring the
/// requested thread name and stack size.
#[derive(Debug, Default)]
pub struct DefaultSpawn;
impl ThreadSpawn for DefaultSpawn {
    private_impl! {}
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        let named = match thread.name() {
            Some(name) => thread::Builder::new().name(name.to_owned()),
            None => thread::Builder::new(),
        };
        let configured = match thread.stack_size() {
            Some(size) => named.stack_size(size),
            None => named,
        };
        // The JoinHandle is dropped, detaching the thread.
        configured.spawn(move || thread.run()).map(|_handle| ())
    }
}
/// Spawn strategy that forwards every [`ThreadBuilder`] to a user closure.
#[derive(Debug)]
pub struct CustomSpawn<F>(F);
impl<F> CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    /// Wraps the closure that will start each worker thread.
    pub(super) fn new(spawn: F) -> Self {
        CustomSpawn(spawn)
    }
}
impl<F> ThreadSpawn for CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    private_impl! {}
    #[inline]
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        let user_spawn = &mut self.0;
        user_spawn(thread)
    }
}
/// Shared state of one thread pool: per-worker bookkeeping, the queue of jobs
/// injected from outside, the sleep coordinator, user callbacks and the
/// termination counter.
pub(super) struct Registry {
    logger: Logger,
    // One entry per worker, indexed by worker index.
    thread_infos: Vec<ThreadInfo>,
    sleep: Sleep,
    // Jobs submitted via `inject` (from non-worker threads or other pools).
    injected_jobs: Injector<JobRef>,
    panic_handler: Option<Box<PanicHandler>>,
    start_handler: Option<Box<StartHandler>>,
    exit_handler: Option<Box<ExitHandler>>,
    // Starts at 1 in `Registry::new`; `terminate` decrements it and, when it reaches
    // zero, sets every worker's terminate latch.
    terminate_count: AtomicUsize,
}
// Storage for the global registry. In this file it is written only inside
// `THE_REGISTRY_SET.call_once` (see `set_global_registry`) and read afterwards.
static mut THE_REGISTRY: Option<Arc<Registry>> = None;
// Guards the one-time initialization attempt of `THE_REGISTRY`.
static THE_REGISTRY_SET: Once = Once::new();
/// Returns the global registry, creating it with default settings on first use.
///
/// # Panics
///
/// Panics if the one-time initialization already ran but failed, so no global
/// registry exists.
pub(super) fn global_registry() -> &'static Arc<Registry> {
    let attempt = set_global_registry(|| Registry::new(ThreadPoolBuilder::new()));
    match attempt {
        Ok(fresh) => fresh,
        // SAFETY: `set_global_registry` only returns after `call_once` has completed,
        // so no write to `THE_REGISTRY` can still be in progress.
        Err(err) => unsafe { THE_REGISTRY.as_ref() }
            .ok_or(err)
            .expect("The global thread pool has not been initialized."),
    }
}
/// Explicitly initializes the global registry from `builder`.
///
/// # Errors
///
/// Fails with `GlobalPoolAlreadyInitialized` if initialization was already
/// attempted, or with the error produced by `Registry::new`.
pub(super) fn init_global_registry<S>(
    builder: ThreadPoolBuilder<S>,
) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    S: ThreadSpawn,
{
    let build = move || Registry::new(builder);
    set_global_registry(build)
}
/// Runs `registry` at most once per process to create the global registry.
///
/// Returns the stored registry if this call performed the initialization. If
/// initialization had already been attempted, returns `GlobalPoolAlreadyInitialized`.
/// If the constructor failed, returns its error.
fn set_global_registry<F>(registry: F) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    F: FnOnce() -> Result<Arc<Registry>, ThreadPoolBuildError>,
{
    let already = ErrorKind::GlobalPoolAlreadyInitialized;
    let mut outcome = Err(ThreadPoolBuildError::new(already));
    THE_REGISTRY_SET.call_once(|| {
        outcome = match registry() {
            // SAFETY: `call_once` guarantees this is the only writer of `THE_REGISTRY`.
            Ok(fresh) => Ok(unsafe { &*THE_REGISTRY.get_or_insert(fresh) }),
            Err(failure) => Err(failure),
        };
    });
    outcome
}
/// Drop guard that calls `Registry::terminate` on the wrapped registry; disarmed
/// with `mem::forget` once it is no longer wanted.
struct Terminator<'a>(&'a Arc<Registry>);
impl<'a> Drop for Terminator<'a> {
    fn drop(&mut self) {
        let Terminator(target) = *self;
        target.terminate();
    }
}
impl Registry {
    /// Creates the registry for a new pool and starts its worker threads.
    ///
    /// One work-stealing deque is created per thread: FIFO when the builder requests
    /// breadth-first scheduling, LIFO otherwise. Each worker is handed to the
    /// builder's spawn handler wrapped in a [`ThreadBuilder`].
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::IOError` build error if the spawn handler fails. Before
    /// returning, a drop guard calls [`Registry::terminate`], which sets every
    /// worker's terminate latch so that threads that were already spawned shut down.
    pub(super) fn new<S>(
        mut builder: ThreadPoolBuilder<S>,
    ) -> Result<Arc<Self>, ThreadPoolBuildError>
    where
        S: ThreadSpawn,
    {
        let n_threads = builder.get_num_threads();
        let breadth_first = builder.get_breadth_first();
        let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = if breadth_first {
                    Worker::new_fifo()
                } else {
                    Worker::new_lifo()
                };
                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();
        let logger = Logger::new(n_threads);
        let registry = Arc::new(Registry {
            logger: logger.clone(),
            thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(),
            sleep: Sleep::new(logger, n_threads),
            injected_jobs: Injector::new(),
            // Starts at 1; the `terminate` call that brings it to zero stops all workers.
            terminate_count: AtomicUsize::new(1),
            panic_handler: builder.take_panic_handler(),
            start_handler: builder.take_start_handler(),
            exit_handler: builder.take_exit_handler(),
        });

        // If we return early or panic, make sure the threads already started are
        // told to terminate.
        let t1000 = Terminator(&registry);
        for (index, worker) in workers.into_iter().enumerate() {
            let thread = ThreadBuilder {
                name: builder.get_thread_name(index),
                stack_size: builder.get_stack_size(),
                registry: registry.clone(),
                worker,
                index,
            };
            if let Err(e) = builder.get_spawn_handler().spawn(thread) {
                return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
            }
        }

        // Every thread spawned successfully: disarm the guard.
        mem::forget(t1000);
        Ok(registry)
    }

    /// Returns the registry of the calling worker thread. From outside any pool it
    /// returns the global registry, creating it with default settings if needed.
    pub(super) fn current() -> Arc<Registry> {
        // SAFETY: a non-null pointer from `WorkerThread::current` refers to the live
        // `WorkerThread` registered by this thread in `main_loop`.
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                global_registry().clone()
            } else {
                (*worker_thread).registry.clone()
            }
        }
    }

    /// Number of threads in the current pool, or in the global pool when called
    /// from outside any pool.
    pub(super) fn current_num_threads() -> usize {
        // SAFETY: see `current`.
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                global_registry().num_threads()
            } else {
                (*worker_thread).registry.num_threads()
            }
        }
    }

    /// Returns the calling thread's `WorkerThread` if it is a worker of *this*
    /// registry, otherwise `None`.
    pub(super) fn current_thread(&self) -> Option<&WorkerThread> {
        // SAFETY: see `current`.
        unsafe {
            let worker = WorkerThread::current().as_ref()?;
            if worker.registry().id() == self.id() {
                Some(worker)
            } else {
                None
            }
        }
    }

    /// Identity of this registry, derived from its address.
    pub(super) fn id(&self) -> RegistryId {
        RegistryId {
            addr: self as *const Self as usize,
        }
    }

    /// Forwards a lazily-constructed event to this registry's logger.
    #[inline]
    pub(super) fn log(&self, event: impl FnOnce() -> crate::log::Event) {
        self.logger.log(event)
    }

    /// Number of worker threads in this pool.
    pub(super) fn num_threads(&self) -> usize {
        self.thread_infos.len()
    }

    /// Reports a panic caught on a pool thread.
    ///
    /// With a user panic handler, the handler is called under an `AbortIfPanic`
    /// guard, so a panic inside the handler itself aborts. Without one, the guard is
    /// dropped immediately, so the default reaction is to abort.
    /// NOTE(review): this relies on `AbortIfPanic`'s `Drop` aborting the process.
    pub(super) fn handle_panic(&self, err: Box<dyn Any + Send>) {
        match self.panic_handler {
            Some(ref handler) => {
                let abort_guard = unwind::AbortIfPanic;
                handler(err);
                mem::forget(abort_guard);
            }
            None => {
                // Deliberately dropped right away: default panic handling aborts.
                let _ = unwind::AbortIfPanic;
            }
        }
    }

    /// Blocks until every worker has set its `primed` latch, which `main_loop` does
    /// right after registering the thread as a worker.
    pub(super) fn wait_until_primed(&self) {
        for info in &self.thread_infos {
            info.primed.wait();
        }
    }

    /// Blocks until every worker has set its `stopped` latch (test-only).
    #[cfg(test)]
    pub(super) fn wait_until_stopped(&self) {
        for info in &self.thread_infos {
            info.stopped.wait();
        }
    }

    /// Pushes onto the calling worker's local deque when the caller is a worker of
    /// this registry; otherwise injects the job into the shared queue.
    pub(super) fn inject_or_push(&self, job_ref: JobRef) {
        let worker_thread = WorkerThread::current();
        // SAFETY: `worker_thread` is dereferenced only after the null check.
        unsafe {
            if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() {
                (*worker_thread).push(job_ref);
            } else {
                self.inject(&[job_ref]);
            }
        }
    }

    /// Adds jobs to the shared injector queue and notifies the sleep coordinator.
    pub(super) fn inject(&self, injected_jobs: &[JobRef]) {
        self.log(|| JobsInjected {
            count: injected_jobs.len(),
        });
        // Injecting into a registry whose terminate count already reached zero is a bug.
        debug_assert_ne!(
            self.terminate_count.load(Ordering::Acquire),
            0,
            "inject() sees state.terminate as true"
        );
        let queue_was_empty = self.injected_jobs.is_empty();
        for &job_ref in injected_jobs {
            self.injected_jobs.push(job_ref);
        }
        // NOTE(review): `usize::MAX` presumably means "not pushed by a worker".
        self.sleep
            .new_injected_jobs(usize::MAX, injected_jobs.len() as u32, queue_was_empty);
    }

    /// Whether the shared injector queue currently holds any jobs.
    fn has_injected_job(&self) -> bool {
        !self.injected_jobs.is_empty()
    }

    /// Takes one job from the shared injector queue, retrying on `Steal::Retry`.
    fn pop_injected_job(&self, worker_index: usize) -> Option<JobRef> {
        loop {
            match self.injected_jobs.steal() {
                Steal::Success(job) => {
                    self.log(|| JobUninjected {
                        worker: worker_index,
                    });
                    return Some(job);
                }
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    /// Runs `op` on a worker thread of this registry.
    ///
    /// If the caller already is such a worker, `op` runs inline with `false`.
    /// Otherwise it is injected and runs on a worker with `true` while the caller
    /// blocks (non-worker caller) or keeps working for its own pool (worker of a
    /// different pool).
    pub(super) fn in_worker<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        // SAFETY: the pointer is dereferenced only after the null check.
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                self.in_worker_cold(op)
            } else if (*worker_thread).registry().id() != self.id() {
                self.in_worker_cross(&*worker_thread, op)
            } else {
                op(&*worker_thread, false)
            }
        }
    }

    /// `in_worker` path for a caller that is not a worker of any pool: injects a
    /// stack job and blocks on a thread-local `LockLatch` until a worker has run it.
    #[cold]
    unsafe fn in_worker_cold<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new());
        LOCK_LATCH.with(|l| {
            debug_assert!(WorkerThread::current().is_null());
            let job = StackJob::new(
                |injected| {
                    let worker_thread = WorkerThread::current();
                    assert!(injected && !worker_thread.is_null());
                    op(&*worker_thread, true)
                },
                l,
            );
            self.inject(&[job.as_job_ref()]);
            // The latch is reset so the thread-local can be reused by the next call.
            job.latch.wait_and_reset();
            self.logger.log(|| Flush);
            job.into_result()
        })
    }

    /// `in_worker` path for a worker of a *different* registry: injects the job here
    /// and lets the current worker keep processing its own pool's jobs
    /// (`wait_until`) until the cross-registry latch is set.
    #[cold]
    unsafe fn in_worker_cross<OP, R>(&self, current_thread: &WorkerThread, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        debug_assert!(current_thread.registry().id() != self.id());
        let latch = SpinLatch::cross(current_thread);
        let job = StackJob::new(
            |injected| {
                let worker_thread = WorkerThread::current();
                assert!(injected && !worker_thread.is_null());
                op(&*worker_thread, true)
            },
            latch,
        );
        self.inject(&[job.as_job_ref()]);
        current_thread.wait_until(&job.latch);
        job.into_result()
    }

    /// Adds one reference to the termination count.
    ///
    /// # Panics
    ///
    /// Panics if the count would overflow. In debug builds it also panics if the
    /// count was already zero, i.e. the registry was already terminated.
    pub(super) fn increment_terminate_count(&self) {
        let previous = self.terminate_count.fetch_add(1, Ordering::AcqRel);
        debug_assert!(previous != 0, "registry ref count incremented from zero");
        assert!(
            previous != std::usize::MAX,
            "overflow in registry ref count"
        );
    }

    /// Releases one reference to the termination count; the call that brings it to
    /// zero sets every worker's terminate latch, ending their main loops.
    pub(super) fn terminate(&self) {
        if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
            for (i, thread_info) in self.thread_infos.iter().enumerate() {
                thread_info.terminate.set_and_tickle_one(self, i);
            }
        }
    }

    /// Tells the sleep coordinator that a latch the given worker waits on was set.
    pub(super) fn notify_worker_latch_is_set(&self, target_worker_index: usize) {
        self.sleep.notify_worker_latch_is_set(target_worker_index);
    }
}
/// Opaque identity of a [`Registry`], built from the registry's address by
/// `Registry::id`. Comparisons are only meaningful while the registries involved
/// are alive, since an address may be reused after a registry is freed.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct RegistryId {
    addr: usize,
}
/// Per-worker bookkeeping kept in the registry and visible to all threads.
struct ThreadInfo {
    /// Set by the worker once it has registered itself (see `main_loop`).
    primed: LockLatch,
    /// Set by the worker after its main loop has finished waiting.
    stopped: LockLatch,
    /// Set by `Registry::terminate` to end the worker's main loop.
    terminate: CountLatch,
    /// Handle other threads use to steal from this worker's deque.
    stealer: Stealer<JobRef>,
}
impl ThreadInfo {
    /// Fresh info with all latches unset, wrapping the worker's stealer.
    fn new(stealer: Stealer<JobRef>) -> ThreadInfo {
        let primed = LockLatch::new();
        let stopped = LockLatch::new();
        let terminate = CountLatch::new();
        ThreadInfo {
            primed,
            stopped,
            terminate,
            stealer,
        }
    }
}
/// State of a running worker thread. It is created in `main_loop`, kept alive until
/// that function returns, and a pointer to it is published in `WORKER_THREAD_STATE`.
pub(super) struct WorkerThread {
    // This worker's own job deque (FIFO or LIFO, chosen in `Registry::new`).
    worker: Worker<JobRef>,
    // Used by `push_fifo` to wrap jobs before they are pushed.
    fifo: JobFifo,
    // Position of this worker in `registry.thread_infos`.
    index: usize,
    // Picks the starting victim in `steal`.
    rng: XorShift64Star,
    registry: Arc<Registry>,
}
// Per-thread pointer to the current `WorkerThread`; null on threads that are not
// (or are no longer) pool workers. Set by `WorkerThread::set_current` and cleared
// in `WorkerThread`'s `Drop`.
thread_local! {
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = Cell::new(ptr::null());
}
/// Unregisters the worker by clearing this thread's `WORKER_THREAD_STATE` slot.
impl Drop for WorkerThread {
    fn drop(&mut self) {
        WORKER_THREAD_STATE.with(|t| {
            // Panics unless the slot still points at this very `WorkerThread`.
            assert!(t.get().eq(&(self as *const _)));
            t.set(ptr::null());
        });
    }
}
impl WorkerThread {
    /// Raw pointer to the calling thread's `WorkerThread`, or null if the calling
    /// thread is not currently registered as a worker.
    #[inline]
    pub(super) fn current() -> *const WorkerThread {
        WORKER_THREAD_STATE.with(Cell::get)
    }
    /// Registers `thread` as the calling thread's worker.
    ///
    /// # Safety
    ///
    /// NOTE(review): the pointee must stay alive until it is unregistered (the
    /// `Drop` impl clears the slot). Panics if a worker is already registered.
    unsafe fn set_current(thread: *const WorkerThread) {
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().is_null());
            t.set(thread);
        });
    }
    /// The registry this worker belongs to.
    #[inline]
    pub(super) fn registry(&self) -> &Arc<Registry> {
        &self.registry
    }
    /// Forwards a lazily-constructed event to the registry's logger.
    #[inline]
    pub(super) fn log(&self, event: impl FnOnce() -> crate::log::Event) {
        self.registry.logger.log(event)
    }
    /// Index of this worker within its registry.
    #[inline]
    pub(super) fn index(&self) -> usize {
        self.index
    }
    /// Pushes a job onto this worker's local deque and tells the sleep coordinator,
    /// including whether the deque was empty beforehand.
    #[inline]
    pub(super) unsafe fn push(&self, job: JobRef) {
        self.log(|| JobPushed { worker: self.index });
        let queue_was_empty = self.worker.is_empty();
        self.worker.push(job);
        self.registry
            .sleep
            .new_internal_jobs(self.index, 1, queue_was_empty);
    }
    /// Pushes the job through `self.fifo` first; NOTE(review): presumably so that
    /// such jobs run in FIFO order even on a LIFO deque — confirm in `JobFifo`.
    #[inline]
    pub(super) unsafe fn push_fifo(&self, job: JobRef) {
        self.push(self.fifo.push(job));
    }
    /// Whether this worker's local deque is currently empty.
    #[inline]
    pub(super) fn local_deque_is_empty(&self) -> bool {
        self.worker.is_empty()
    }
    /// Pops a job from this worker's own deque, logging on success.
    #[inline]
    pub(super) unsafe fn take_local_job(&self) -> Option<JobRef> {
        let popped_job = self.worker.pop();
        if popped_job.is_some() {
            self.log(|| JobPopped { worker: self.index });
        }
        popped_job
    }
    /// Returns once `latch` is set, executing other jobs while waiting.
    /// The already-set case is handled inline; the rest goes to `wait_until_cold`.
    #[inline]
    pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) {
        let latch = latch.as_core_latch();
        if !latch.probe() {
            self.wait_until_cold(latch);
        }
    }
    /// Slow path of `wait_until`. Until the latch is set, it repeatedly looks for
    /// work: first the local deque, then stealing from peers, then the injector
    /// queue. When nothing is found it hands control to the sleep coordinator.
    #[cold]
    unsafe fn wait_until_cold(&self, latch: &CoreLatch) {
        // A panic while waiting is not recoverable here; this guard is disarmed
        // with `mem::forget` on the normal exit path below.
        let abort_guard = unwind::AbortIfPanic;
        let mut idle_state = self.registry.sleep.start_looking(self.index, latch);
        while !latch.probe() {
            if let Some(job) = self
                .take_local_job()
                .or_else(|| self.steal())
                .or_else(|| self.registry.pop_injected_job(self.index))
            {
                self.registry.sleep.work_found(idle_state);
                self.execute(job);
                // Executing the job ended the idle phase; start a new one.
                idle_state = self.registry.sleep.start_looking(self.index, latch);
            } else {
                // NOTE(review): presumably may put this thread to sleep; the closure
                // lets the sleep module re-check for injected jobs first.
                self.registry
                    .sleep
                    .no_work_found(&mut idle_state, latch, || self.registry.has_injected_job())
            }
        }
        self.registry.sleep.work_found(idle_state);
        self.log(|| ThreadSawLatchSet {
            worker: self.index,
            latch_addr: latch.addr(),
        });
        mem::forget(abort_guard);
    }
    /// Runs a job on this thread.
    #[inline]
    pub(super) unsafe fn execute(&self, job: JobRef) {
        job.execute();
    }
    /// Tries to steal one job from another worker of the same registry.
    ///
    /// Victims are visited round-robin from a random start, skipping this worker.
    /// If a full pass finds nothing but some victim reported `Steal::Retry`, the
    /// pass is repeated; otherwise `None` is returned. A single-thread pool never
    /// steals. Expects the local deque to be empty (debug-asserted).
    unsafe fn steal(&self) -> Option<JobRef> {
        debug_assert!(self.local_deque_is_empty());
        let thread_infos = &self.registry.thread_infos.as_slice();
        let num_threads = thread_infos.len();
        if num_threads <= 1 {
            return None;
        }
        loop {
            let mut retry = false;
            let start = self.rng.next_usize(num_threads);
            let job = (start..num_threads)
                .chain(0..start)
                .filter(move |&i| i != self.index)
                .find_map(|victim_index| {
                    let victim = &thread_infos[victim_index];
                    match victim.stealer.steal() {
                        Steal::Success(job) => {
                            self.log(|| JobStolen {
                                worker: self.index,
                                victim: victim_index,
                            });
                            Some(job)
                        }
                        Steal::Empty => None,
                        Steal::Retry => {
                            retry = true;
                            None
                        }
                    }
                });
            if job.is_some() || !retry {
                return job;
            }
        }
    }
}
/// Body of every pool thread.
///
/// Registers the thread as worker `index` of `registry` and sets its `primed`
/// latch. It then runs the optional start handler, processes jobs until its
/// terminate latch is set, sets its `stopped` latch, and finally runs the optional
/// exit handler. Panics escaping either handler are routed to
/// `Registry::handle_panic`.
///
/// # Safety
///
/// NOTE(review): `index` must be a valid index into `registry.thread_infos`, since
/// it is used to index that vector, and `worker` should be the deque whose stealer
/// sits at that index.
unsafe fn main_loop(worker: Worker<JobRef>, registry: Arc<Registry>, index: usize) {
    // Lives until the end of this function; its `Drop` unregisters the thread.
    let worker_thread = &WorkerThread {
        worker,
        fifo: JobFifo::new(),
        index,
        rng: XorShift64Star::new(),
        registry: registry.clone(),
    };
    WorkerThread::set_current(worker_thread);

    // Let `Registry::wait_until_primed` know this worker is registered.
    registry.thread_infos[index].primed.set();

    // Guard against unexpected panics while the worker is active; disarmed below.
    let abort_guard = unwind::AbortIfPanic;

    // Inform the user callback that we started a thread.
    if let Some(ref handler) = registry.start_handler {
        if let Err(err) = unwind::halt_unwinding(|| handler(index)) {
            registry.handle_panic(err);
        }
    }

    let my_terminate_latch = &registry.thread_infos[index].terminate;
    worker_thread.log(|| ThreadStart {
        worker: index,
        terminate_addr: my_terminate_latch.as_core_latch().addr(),
    });

    // Process jobs until `Registry::terminate` sets this worker's latch.
    worker_thread.wait_until(my_terminate_latch);

    // Should not be any work left in our queue.
    debug_assert!(worker_thread.take_local_job().is_none());

    // Let `Registry::wait_until_stopped` know this worker is done.
    registry.thread_infos[index].stopped.set();

    // Normal termination: disarm the guard.
    mem::forget(abort_guard);

    worker_thread.log(|| ThreadTerminate { worker: index });

    // Inform the user callback that we exited a thread.
    if let Some(ref handler) = registry.exit_handler {
        if let Err(err) = unwind::halt_unwinding(|| handler(index)) {
            registry.handle_panic(err);
        }
    }
}
/// Runs `op` on a pool worker.
///
/// If the caller already is a worker of any pool, `op` runs inline on it with
/// `false`. Otherwise it is injected into the global pool and runs there with
/// `true` while the caller blocks.
pub(super) fn in_worker<OP, R>(op: OP) -> R
where
    OP: FnOnce(&WorkerThread, bool) -> R + Send,
    R: Send,
{
    // SAFETY: a non-null `WorkerThread::current` pointer refers to the live worker
    // registered on this thread; `as_ref` maps null to `None`.
    unsafe {
        match WorkerThread::current().as_ref() {
            Some(owner) => op(owner, false),
            None => global_registry().in_worker_cold(op),
        }
    }
}
/// Minimal xorshift64* pseudo-random generator, used to pick steal victims.
/// Uses interior mutability so it can be advanced through `&self`.
struct XorShift64Star {
    state: Cell<u64>,
}
impl XorShift64Star {
    /// Seeds from a hash of a process-wide counter, so each instance gets a
    /// different seed; re-hashes until the seed is non-zero, because xorshift
    /// never leaves the all-zero state.
    fn new() -> Self {
        #[allow(deprecated)]
        static COUNTER: AtomicUsize = ATOMIC_USIZE_INIT;
        let seed = loop {
            let mut hasher = DefaultHasher::new();
            hasher.write_usize(COUNTER.fetch_add(1, Ordering::Relaxed));
            let candidate = hasher.finish();
            if candidate != 0 {
                break candidate;
            }
        };
        XorShift64Star {
            state: Cell::new(seed),
        }
    }
    /// Advances the state (shifts 12/25/27) and returns the scrambled output.
    fn next(&self) -> u64 {
        let s0 = self.state.get();
        debug_assert_ne!(s0, 0);
        let s1 = s0 ^ (s0 >> 12);
        let s2 = s1 ^ (s1 << 25);
        let s3 = s2 ^ (s2 >> 27);
        self.state.set(s3);
        s3.wrapping_mul(0x2545_f491_4f6c_dd1d)
    }
    /// Returns a value in `0..n`; `n` must be non-zero.
    fn next_usize(&self, n: usize) -> usize {
        let bound = n as u64;
        (self.next() % bound) as usize
    }
}