tokio/runtime/scheduler/multi_thread/queue.rs

//! Run-queue structures to support a work-stealing scheduler

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::Arc;
use crate::runtime::scheduler::multi_thread::{Overflow, Stats};
use crate::runtime::task;

use std::mem::{self, MaybeUninit};
use std::ptr;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};

// Use wider integers when possible to increase ABA resilience.
//
// See issue #5041: <https://github.com/tokio-rs/tokio/issues/5041>.
cfg_has_atomic_u64! {
    type UnsignedShort = u32;
    type UnsignedLong = u64;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU32;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU64;
}
cfg_not_has_atomic_u64! {
    type UnsignedShort = u16;
    type UnsignedLong = u32;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU16;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU32;
}

/// Producer handle. May only be used from a single thread.
pub(crate) struct Local<T: 'static> {
    inner: Arc<Inner<T>>,
}

/// Consumer handle. May be used from many threads.
pub(crate) struct Steal<T: 'static>(Arc<Inner<T>>);

pub(crate) struct Inner<T: 'static> {
    /// Concurrently updated by many threads.
    ///
    /// Contains two `UnsignedShort` values. The `UnsignedShort` in the LSB
    /// half is the "real" head of the queue. The `UnsignedShort` in the MSB
    /// half is set by a stealer in the process of stealing values. It
    /// represents the first value being stolen in the batch. The
    /// `UnsignedShort` indices are intentionally wider than strictly required
    /// for buffer indexing in order to provide ABA mitigation and to make it
    /// possible to distinguish between full and empty buffers.
    ///
    /// When both `UnsignedShort` values are the same, there is no active
    /// stealer.
    ///
    /// Tracking an in-progress stealer prevents a wrapping scenario.
    head: AtomicUnsignedLong,

    /// Only updated by the producer thread but read by many threads.
    tail: AtomicUnsignedShort,

    /// Elements
    buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>,
}

unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}

#[cfg(not(loom))]
const LOCAL_QUEUE_CAPACITY: usize = 256;

// Shrink the size of the local queue when using loom. This shouldn't impact
// logic, but allows loom to test more edge cases in a reasonable amount of
// time.
#[cfg(loom)]
const LOCAL_QUEUE_CAPACITY: usize = 4;

const MASK: usize = LOCAL_QUEUE_CAPACITY - 1;
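// Note: indexing positions with `pos as usize & MASK` is equivalent to
// `pos % LOCAL_QUEUE_CAPACITY` only because the capacity is a power of two;
// both values above (256 and 4) satisfy this.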

// Constructing the fixed size array directly is very awkward. The only way to
// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as
// the contents are not Copy. The trick with defining a const doesn't work for
// generic types.
fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> {
    assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY);

    // safety: We check that the length is correct.
    unsafe { Box::from_raw(Box::into_raw(buffer).cast()) }
}

/// Create a new local run-queue
pub(crate) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
    let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY);

    for _ in 0..LOCAL_QUEUE_CAPACITY {
        buffer.push(UnsafeCell::new(MaybeUninit::uninit()));
    }

    let inner = Arc::new(Inner {
        head: AtomicUnsignedLong::new(0),
        tail: AtomicUnsignedShort::new(0),
        buffer: make_fixed_size(buffer.into_boxed_slice()),
    });

    let local = Local {
        inner: inner.clone(),
    };

    let remote = Steal(inner);

    (remote, local)
}
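
// Usage sketch (illustrative, based on the handle docs above rather than any
// specific scheduler wiring): a worker keeps the returned `Local` as its
// single-threaded producer handle, while clones of the `Steal` half are shared
// so that other workers can steal from this queue when they run out of work.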

impl<T> Local<T> {
    /// Returns the number of entries in the queue.
    pub(crate) fn len(&self) -> usize {
        self.inner.len() as usize
    }

    /// Returns how many tasks can still be pushed into the queue.
    pub(crate) fn remaining_slots(&self) -> usize {
        self.inner.remaining_slots()
    }

    pub(crate) fn max_capacity(&self) -> usize {
        LOCAL_QUEUE_CAPACITY
    }

    /// Returns true if there are any entries in the queue.
    ///
    /// Kept separate from `is_stealable` so that refactors of `is_stealable`
    /// to "protect" some tasks from stealing won't affect this.
    pub(crate) fn has_tasks(&self) -> bool {
        !self.inner.is_empty()
    }

    /// Pushes a batch of tasks to the back of the queue. All tasks must fit
    /// in the local queue.
    ///
    /// # Panics
    ///
    /// The method panics if there is not enough capacity in the queue to fit
    /// all the provided tasks.
    pub(crate) fn push_back(&mut self, tasks: impl ExactSizeIterator<Item = task::Notified<T>>) {
        let len = tasks.len();
        assert!(len <= LOCAL_QUEUE_CAPACITY);

        if len == 0 {
            // Nothing to do
            return;
        }

        let head = self.inner.head.load(Acquire);
        let (steal, _) = unpack(head);

        // safety: this is the **only** thread that updates this cell.
        let mut tail = unsafe { self.inner.tail.unsync_load() };

        if tail.wrapping_sub(steal) <= (LOCAL_QUEUE_CAPACITY - len) as UnsignedShort {
            // Yes, this `if` condition is structured a bit oddly (the first
            // block does nothing and the second panics). It is written this
            // way to mirror `push_back_or_overflow`.
        } else {
            panic!()
        }

        for task in tasks {
            let idx = tail as usize & MASK;

            self.inner.buffer[idx].with_mut(|ptr| {
                // Write the task to the slot
                //
                // Safety: There is only one producer and the above `if`
                // condition ensures we don't touch a cell if there is a
                // value, thus no consumer.
                unsafe {
                    ptr::write((*ptr).as_mut_ptr(), task);
                }
            });

            tail = tail.wrapping_add(1);
        }

        self.inner.tail.store(tail, Release);
    }

    /// Pushes a task to the back of the local queue. If there is not enough
    /// capacity in the queue, this triggers the overflow operation.
    ///
    /// When the queue overflows, half of the current contents of the queue is
    /// moved to the given injection queue. This frees up capacity for more
    /// tasks to be pushed into the local queue.
    pub(crate) fn push_back_or_overflow<O: Overflow<T>>(
        &mut self,
        mut task: task::Notified<T>,
        overflow: &O,
        stats: &mut Stats,
    ) {
        let tail = loop {
            let head = self.inner.head.load(Acquire);
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as UnsignedShort {
                // There is capacity for the task
                break tail;
            } else if steal != real {
                // A steal is in progress and will free up capacity, so only
                // push the task onto the inject queue.
                overflow.push(task);
                return;
            } else {
                // Push the current task and half of the queue into the
                // inject queue.
                match self.push_overflow(task, real, tail, overflow, stats) {
                    Ok(_) => return,
                    // Lost the race, try again
                    Err(v) => {
                        task = v;
                    }
                }
            }
        };

        self.push_back_finish(task, tail);
    }

    // Second half of `push_back`
    fn push_back_finish(&self, task: task::Notified<T>, tail: UnsignedShort) {
        // Map the position to a slot index.
        let idx = tail as usize & MASK;

        self.inner.buffer[idx].with_mut(|ptr| {
            // Write the task to the slot
            //
            // Safety: There is only one producer and the above `if`
            // condition ensures we don't touch a cell if there is a
            // value, thus no consumer.
            unsafe {
                ptr::write((*ptr).as_mut_ptr(), task);
            }
        });

        // Make the task available. Synchronizes with a load in
        // `steal_into2`.
        self.inner.tail.store(tail.wrapping_add(1), Release);
    }

    /// Moves a batch of tasks into the inject queue.
    ///
    /// This will temporarily make some of the tasks unavailable to stealers.
    /// Once `push_overflow` is done, a notification is sent out, so if other
    /// workers "missed" some of the tasks during a steal, they will get
    /// another opportunity.
    #[inline(never)]
    fn push_overflow<O: Overflow<T>>(
        &mut self,
        task: task::Notified<T>,
        head: UnsignedShort,
        tail: UnsignedShort,
        overflow: &O,
        stats: &mut Stats,
    ) -> Result<(), task::Notified<T>> {
        /// The number of elements we take from the local queue.
        ///
        /// This is one less than the number of tasks pushed to the inject
        /// queue, as we are also inserting the `task` argument.
        const NUM_TASKS_TAKEN: UnsignedShort = (LOCAL_QUEUE_CAPACITY / 2) as UnsignedShort;

        assert_eq!(
            tail.wrapping_sub(head) as usize,
            LOCAL_QUEUE_CAPACITY,
            "queue is not full; tail = {tail}; head = {head}"
        );

        let prev = pack(head, head);

        // Claim a bunch of tasks
        //
        // We are claiming the tasks **before** reading them out of the buffer.
        // This is safe because only the **current** thread is able to push new
        // tasks.
        //
        // There isn't really any need for memory ordering... Relaxed would
        // work. This is because all tasks are pushed into the queue from the
        // current thread (or memory has been acquired if the local queue handle
        // moved).
        if self
            .inner
            .head
            .compare_exchange(
                prev,
                pack(
                    head.wrapping_add(NUM_TASKS_TAKEN),
                    head.wrapping_add(NUM_TASKS_TAKEN),
                ),
                Release,
                Relaxed,
            )
            .is_err()
        {
            // We failed to claim the tasks, losing the race. Return out of
            // this function and try the full `push` routine again. The queue
            // may not be full anymore.
            return Err(task);
        }

        /// An iterator that takes elements out of the run queue.
        struct BatchTaskIter<'a, T: 'static> {
            buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY],
            head: UnsignedLong,
            i: UnsignedLong,
        }
        impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
            type Item = task::Notified<T>;

            #[inline]
            fn next(&mut self) -> Option<task::Notified<T>> {
                if self.i == UnsignedLong::from(NUM_TASKS_TAKEN) {
                    None
                } else {
                    let i_idx = self.i.wrapping_add(self.head) as usize & MASK;
                    let slot = &self.buffer[i_idx];

                    // safety: Our CAS from before has assumed exclusive ownership
                    // of the task pointers in this range.
                    let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

                    self.i += 1;
                    Some(task)
                }
            }
        }

        // safety: The CAS above ensures that no consumer will look at these
        // values again, and we are the only producer.
        let batch_iter = BatchTaskIter {
            buffer: &self.inner.buffer,
            head: head as UnsignedLong,
            i: 0,
        };
        overflow.push_batch(batch_iter.chain(std::iter::once(task)));

        // Add 1 to factor in the task currently being scheduled.
        stats.incr_overflow_count();

        Ok(())
    }

    /// Pops a task from the local queue.
    pub(crate) fn pop(&mut self) -> Option<task::Notified<T>> {
        let mut head = self.inner.head.load(Acquire);

        let idx = loop {
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if real == tail {
                // queue is empty
                return None;
            }

            let next_real = real.wrapping_add(1);

            // If `steal == real` there are no concurrent stealers. Both `steal`
            // and `real` are updated.
            let next = if steal == real {
                pack(next_real, next_real)
            } else {
                assert_ne!(steal, next_real);
                pack(steal, next_real)
            };

            // Attempt to claim a task.
            let res = self
                .inner
                .head
                .compare_exchange(head, next, AcqRel, Acquire);

            match res {
                Ok(_) => break real as usize & MASK,
                Err(actual) => head = actual,
            }
        };

        Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
    }
}

impl<T> Steal<T> {
    pub(crate) fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Steals half the tasks from `self` and places them into `dst`.
    pub(crate) fn steal_into(
        &self,
        dst: &mut Local<T>,
        dst_stats: &mut Stats,
    ) -> Option<task::Notified<T>> {
        // Safety: the caller is the only thread that mutates `dst.tail` and
        // holds a mutable reference.
        let dst_tail = unsafe { dst.inner.tail.unsync_load() };

        // To the caller, `dst` may **look** empty but still have values
        // contained in the buffer. If another thread is concurrently stealing
        // from `dst` there may not be enough capacity to steal.
        let (steal, _) = unpack(dst.inner.head.load(Acquire));

        if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as UnsignedShort / 2 {
            // we *could* try to steal less here, but for simplicity, we're just
            // going to abort.
            return None;
        }

        // Steal the tasks into `dst`'s buffer. This does not yet expose the
        // tasks in `dst`.
        let mut n = self.steal_into2(dst, dst_tail);

        if n == 0 {
            // No tasks were stolen
            return None;
        }

        dst_stats.incr_steal_count(n as u16);
        dst_stats.incr_steal_operations();

        // One stolen task is returned directly to the caller rather than being
        // made available in `dst`, so decrement the count.
        n -= 1;

        let ret_pos = dst_tail.wrapping_add(n);
        let ret_idx = ret_pos as usize & MASK;

        // safety: the value was written as part of `steal_into2` and not
        // exposed to stealers, so no other thread can access it.
        let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

        if n == 0 {
            // The `dst` queue is empty, but a single task was stolen
            return Some(ret);
        }

        // Make the stolen items available to consumers
        dst.inner.tail.store(dst_tail.wrapping_add(n), Release);

        Some(ret)
    }

    // Steal tasks from `self`, placing them into `dst`. Returns the number of
    // tasks that were stolen.
    fn steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort {
        let mut prev_packed = self.0.head.load(Acquire);
        let mut next_packed;

        let n = loop {
            let (src_head_steal, src_head_real) = unpack(prev_packed);
            let src_tail = self.0.tail.load(Acquire);

            // If these two do not match, another thread is concurrently
            // stealing from the queue.
            if src_head_steal != src_head_real {
                return 0;
            }

            // Number of available tasks to steal
            let n = src_tail.wrapping_sub(src_head_real);
            // Take half of them, rounding up.
            let n = n - n / 2;

            if n == 0 {
                // No tasks available to steal
                return 0;
            }

            // Update the real head index to acquire the tasks.
            let steal_to = src_head_real.wrapping_add(n);
            assert_ne!(src_head_steal, steal_to);
            next_packed = pack(src_head_steal, steal_to);

            // Claim all those tasks. This is done by incrementing the "real"
            // head but not the steal. By doing this, no other thread is able to
            // steal from this queue until the current thread completes.
            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => break n,
                Err(actual) => prev_packed = actual,
            }
        };

        assert!(
            n <= LOCAL_QUEUE_CAPACITY as UnsignedShort / 2,
            "actual = {n}"
        );

        let (first, _) = unpack(next_packed);

        // Take all the tasks
        for i in 0..n {
            // Compute the positions
            let src_pos = first.wrapping_add(i);
            let dst_pos = dst_tail.wrapping_add(i);

            // Map to slots
            let src_idx = src_pos as usize & MASK;
            let dst_idx = dst_pos as usize & MASK;

            // Read the task
            //
            // safety: We acquired the task with the atomic exchange above.
            let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

            // Write the task to the new slot
            //
            // safety: `dst` queue is empty and we are the only producer to
            // this queue.
            dst.inner.buffer[dst_idx]
                .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
        }

        let mut prev_packed = next_packed;

        // Update `src_head_steal` to match `src_head_real`, signalling that
        // the stealing routine is complete.
        loop {
            let head = unpack(prev_packed).1;
            next_packed = pack(head, head);

            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => return n,
                Err(actual) => {
                    let (actual_steal, actual_real) = unpack(actual);

                    assert_ne!(actual_steal, actual_real);

                    prev_packed = actual;
                }
            }
        }
    }
}

cfg_unstable_metrics! {
    impl<T> Steal<T> {
        pub(crate) fn len(&self) -> usize {
            self.0.len() as _
        }
    }
}

impl<T> Clone for Steal<T> {
    fn clone(&self) -> Steal<T> {
        Steal(self.0.clone())
    }
}

impl<T> Drop for Local<T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            assert!(self.pop().is_none(), "queue not empty");
        }
    }
}

impl<T> Inner<T> {
    fn remaining_slots(&self) -> usize {
        let (steal, _) = unpack(self.head.load(Acquire));
        let tail = self.tail.load(Acquire);

        LOCAL_QUEUE_CAPACITY - (tail.wrapping_sub(steal) as usize)
    }

    fn len(&self) -> UnsignedShort {
        let (_, head) = unpack(self.head.load(Acquire));
        let tail = self.tail.load(Acquire);

        tail.wrapping_sub(head)
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

/// Split the head value into the real head and the index a stealer is working
/// on.
fn unpack(n: UnsignedLong) -> (UnsignedShort, UnsignedShort) {
    let real = n & UnsignedShort::MAX as UnsignedLong;
    let steal = n >> (mem::size_of::<UnsignedShort>() * 8);

    (steal as UnsignedShort, real as UnsignedShort)
}

/// Join the two head values
fn pack(steal: UnsignedShort, real: UnsignedShort) -> UnsignedLong {
    (real as UnsignedLong) | ((steal as UnsignedLong) << (mem::size_of::<UnsignedShort>() * 8))
}

#[test]
fn test_local_queue_capacity() {
    assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize);
}
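
// Illustrative round-trip check for `pack`/`unpack` (a minimal added sketch):
// packing a `(steal, real)` pair and unpacking the result should yield the
// same pair, including values near the wrap-around boundary.
#[test]
fn test_pack_unpack_roundtrip() {
    let values = [
        0,
        1,
        UnsignedShort::MAX / 2,
        UnsignedShort::MAX - 1,
        UnsignedShort::MAX,
    ];

    for &steal in &values {
        for &real in &values {
            assert_eq!((steal, real), unpack(pack(steal, real)));
        }
    }
}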