tokio/runtime/task/core.rs

//! Core task module.
//!
//! # Safety
//!
//! The functions in this module are private to the `task` module. All of them
//! should be considered `unsafe` to use, but are not marked as such since it
//! would be too noisy.
//!
//! Make sure to consult the relevant safety section of each function before
//! use.

use crate::future::Future;
use crate::loom::cell::UnsafeCell;
use crate::runtime::context;
use crate::runtime::task::raw::{self, Vtable};
use crate::runtime::task::state::State;
use crate::runtime::task::{Id, Schedule, TaskHarnessScheduleHooks};
use crate::util::linked_list;

use std::num::NonZeroU64;
use std::pin::Pin;
use std::ptr::NonNull;
use std::task::{Context, Poll, Waker};

/// The task cell. Contains the components of the task.
///
/// It is critical for `Header` to be the first field as the task structure will
/// be referenced by both `*mut Cell` and `*mut Header`.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// `const` fns in raw.rs.
///
// This struct should be cache padded to avoid false sharing. The cache padding rules are copied
// from crossbeam-utils/src/cache_padded.rs
//
// Starting from Intel's Sandy Bridge, the spatial prefetcher pulls pairs of 64-byte cache
// lines at a time, so we have to align to 128 bytes rather than 64.
//
// Sources:
// - https://www.intel.com/content/dam/www/public/us/en/documents/manuals/64-ia-32-architectures-optimization-manual.pdf
// - https://github.com/facebook/folly/blob/1b5288e6eea6df074758f877c849b6e73bbb9fbb/folly/lang/Align.h#L107
//
// ARM's big.LITTLE architecture has asymmetric cores and "big" cores have 128-byte cache line size.
//
// Sources:
// - https://www.mono-project.com/news/2016/09/12/arm64-icache/
//
// powerpc64 has 128-byte cache line size.
//
// Sources:
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_ppc64x.go#L9
#[cfg_attr(
    any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "powerpc64",
    ),
    repr(align(128))
)]
// arm, mips, mips64, sparc, and hexagon have 32-byte cache line size.
//
// Sources:
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_arm.go#L7
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mips.go#L7
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mipsle.go#L7
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mips64x.go#L9
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/sparc/include/asm/cache.h#L17
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/hexagon/include/asm/cache.h#L12
#[cfg_attr(
    any(
        target_arch = "arm",
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "sparc",
        target_arch = "hexagon",
    ),
    repr(align(32))
)]
// m68k has 16-byte cache line size.
//
// Sources:
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/m68k/include/asm/cache.h#L9
#[cfg_attr(target_arch = "m68k", repr(align(16)))]
// s390x has 256-byte cache line size.
//
// Sources:
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_s390x.go#L7
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/s390/include/asm/cache.h#L13
#[cfg_attr(target_arch = "s390x", repr(align(256)))]
// x86, riscv, wasm, and sparc64 have 64-byte cache line size.
//
// Sources:
// - https://github.com/golang/go/blob/dda2991c2ea0c5914714469c4defc2562a907230/src/internal/cpu/cpu_x86.go#L9
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_wasm.go#L7
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/sparc/include/asm/cache.h#L19
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/riscv/include/asm/cache.h#L10
//
// All others are assumed to have 64-byte cache line size.
#[cfg_attr(
    not(any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "powerpc64",
        target_arch = "arm",
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "sparc",
        target_arch = "hexagon",
        target_arch = "m68k",
        target_arch = "s390x",
    )),
    repr(align(64))
)]
#[repr(C)]
pub(super) struct Cell<T: Future, S> {
    /// Hot task state data.
    pub(super) header: Header,

    /// Either the future or output, depending on the execution stage.
    pub(super) core: Core<T, S>,

    /// Cold data.
    pub(super) trailer: Trailer,
}
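// A hedged sketch (names `T`/`S` as above, pointer values illustrative) of
// the invariant just documented: because `Cell` is `#[repr(C)]` with `header`
// as its first field, a pointer to the cell and a pointer to its header share
// an address, so either view can be recovered from one task allocation:
//
//     let cell: NonNull<Cell<T, S>> = /* the task allocation */;
//     let header: NonNull<Header> = cell.cast(); // valid: `header` is at offset 0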

/// Unsynchronized cell holding the task's `Stage` (the future or its output).
/// Access goes through `with_mut`; callers must guarantee mutual exclusion.
pub(super) struct CoreStage<T: Future> {
    stage: UnsafeCell<Stage<T>>,
}

/// The core of the task.
///
/// Holds the future or output, depending on the stage of execution.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// `const` fns in raw.rs.
#[repr(C)]
pub(super) struct Core<T: Future, S> {
    /// Scheduler used to drive this future.
    pub(super) scheduler: S,

    /// The task's ID, used for populating `JoinError`s.
    pub(super) task_id: Id,

    /// Either the future or the output.
    pub(super) stage: CoreStage<T>,
}

/// Crate-public as this is also needed by the pool.
#[repr(C)]
pub(crate) struct Header {
    /// Task state.
    pub(super) state: State,

    /// Pointer to next task, used with the injection queue.
    pub(super) queue_next: UnsafeCell<Option<NonNull<Header>>>,

    /// Table of function pointers for executing actions on the task.
    pub(super) vtable: &'static Vtable,

    /// The id of the `OwnedTasks` or `LocalOwnedTasks` that this task is
    /// stored in. If the task is not in any list, this is the id of the list
    /// it was previously in, or `None` if it has never been in any list.
    ///
    /// Once a task has been bound to a list, it can never be bound to another
    /// list, even if removed from the first list.
    ///
    /// The id is not unset when removed from a list because we want to be able
    /// to read the id without synchronization, even if it is concurrently being
    /// removed from the list.
    pub(super) owner_id: UnsafeCell<Option<NonZeroU64>>,

    /// The tracing ID for this instrumented task.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    pub(super) tracing_id: Option<tracing::Id>,
}

unsafe impl Send for Header {}
unsafe impl Sync for Header {}

/// Cold data is stored after the future. Data is considered cold if it is only
/// used during creation or shutdown of the task.
pub(super) struct Trailer {
    /// Pointers for the linked list in the `OwnedTasks` that owns this task.
    pub(super) owned: linked_list::Pointers<Header>,
    /// Consumer task waiting on completion of this task.
    pub(super) waker: UnsafeCell<Option<Waker>>,
    /// Optional hooks needed in the harness.
    pub(super) hooks: TaskHarnessScheduleHooks,
}

generate_addr_of_methods! {
    impl<> Trailer {
        pub(super) unsafe fn addr_of_owned(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Header>> {
            &self.owned
        }
    }
}

/// Either the future or the output.
#[repr(C)] // https://github.com/rust-lang/miri/issues/3780
pub(super) enum Stage<T: Future> {
    Running(T),
    Finished(super::Result<T::Output>),
    Consumed,
}
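// An illustrative sketch of the stage lifecycle, spelled out in terms of the
// `Core` methods defined below (no new API, only the transitions made
// explicit; the callers are assumed to live in the harness):
//
//     Running(future)  --drop_future_or_output-->  Consumed
//         (the future is dropped once `poll` returns `Ready`, or on abort)
//     Consumed         --store_output---------->   Finished(output)
//         (the poll result is stored for the `JoinHandle`)
//     Finished(output) --take_output----------->   Consumed
//         (the `JoinHandle` side consumes the output)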

impl<T: Future, S: Schedule> Cell<T, S> {
    /// Allocates a new task cell, containing the header, trailer, and core
    /// structures.
    pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box<Cell<T, S>> {
        // Separated into a non-generic function to reduce LLVM codegen.
        fn new_header(
            state: State,
            vtable: &'static Vtable,
            #[cfg(all(tokio_unstable, feature = "tracing"))] tracing_id: Option<tracing::Id>,
        ) -> Header {
            Header {
                state,
                queue_next: UnsafeCell::new(None),
                vtable,
                owner_id: UnsafeCell::new(None),
                #[cfg(all(tokio_unstable, feature = "tracing"))]
                tracing_id,
            }
        }

        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let tracing_id = future.id();
        let vtable = raw::vtable::<T, S>();
        let result = Box::new(Cell {
            trailer: Trailer::new(scheduler.hooks()),
            header: new_header(
                state,
                vtable,
                #[cfg(all(tokio_unstable, feature = "tracing"))]
                tracing_id,
            ),
            core: Core {
                scheduler,
                stage: CoreStage {
                    stage: UnsafeCell::new(Stage::Running(future)),
                },
                task_id,
            },
        });

        #[cfg(debug_assertions)]
        {
            // Using a separate function for this code avoids instantiating it separately for every `T`.
            unsafe fn check<S>(header: &Header, trailer: &Trailer, scheduler: &S, task_id: &Id) {
                let trailer_addr = trailer as *const Trailer as usize;
                let trailer_ptr = unsafe { Header::get_trailer(NonNull::from(header)) };
                assert_eq!(trailer_addr, trailer_ptr.as_ptr() as usize);

                let scheduler_addr = scheduler as *const S as usize;
                let scheduler_ptr = unsafe { Header::get_scheduler::<S>(NonNull::from(header)) };
                assert_eq!(scheduler_addr, scheduler_ptr.as_ptr() as usize);

                let id_addr = task_id as *const Id as usize;
                let id_ptr = unsafe { Header::get_id_ptr(NonNull::from(header)) };
                assert_eq!(id_addr, id_ptr.as_ptr() as usize);
            }
            unsafe {
                check(
                    &result.header,
                    &result.trailer,
                    &result.core.scheduler,
                    &result.core.task_id,
                );
            }
        }

        result
    }
}
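// Hedged sketch of what the `debug_assertions` block above establishes
// (values illustrative): the offsets baked into the vtable by raw.rs must
// agree with the real field addresses of a freshly allocated cell, e.g. for
// the trailer:
//
//     let cell = Cell::new(future, scheduler, state, task_id);
//     let header = NonNull::from(&cell.header);
//     assert_eq!(
//         unsafe { Header::get_trailer(header).as_ptr() as usize },
//         &cell.trailer as *const Trailer as usize,
//     );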

impl<T: Future> CoreStage<T> {
    /// Gives unsynchronized mutable access to the stage. The caller must
    /// guarantee mutual exclusion with all other accesses to the stage.
    pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R {
        self.stage.with_mut(f)
    }
}

/// Sets the current task id in the runtime context while the future is
/// executed or dropped, or while the output produced by the future is
/// dropped. The previous id is restored when the guard is dropped.
pub(crate) struct TaskIdGuard {
    parent_task_id: Option<Id>,
}

impl TaskIdGuard {
    fn enter(id: Id) -> Self {
        TaskIdGuard {
            parent_task_id: context::set_current_task_id(Some(id)),
        }
    }
}

impl Drop for TaskIdGuard {
    fn drop(&mut self) {
        context::set_current_task_id(self.parent_task_id);
    }
}
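// Illustrative sketch of how the guard nests (`id_a`/`id_b` are made-up ids):
// `enter` saves the previous current-task id and `drop` restores it, so
// re-entrant polls always observe the innermost task as "current":
//
//     let _a = TaskIdGuard::enter(id_a); // current task id = id_a
//     {
//         let _b = TaskIdGuard::enter(id_b); // current task id = id_b
//     } // `_b` dropped: current task id = id_a again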

impl<T: Future, S: Schedule> Core<T, S> {
    /// Polls the future.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `stage` field. This
    /// requires ensuring mutual exclusion between any concurrent thread that
    /// might modify the future or output field.
    ///
    /// The mutual exclusion is implemented by `Harness` and the `Lifecycle`
    /// component of the task state.
    ///
    /// `self` must also be pinned. This is handled by storing the task on the
    /// heap.
    pub(super) fn poll(&self, mut cx: Context<'_>) -> Poll<T::Output> {
        let res = {
            self.stage.stage.with_mut(|ptr| {
                // Safety: The caller ensures mutual exclusion to the field.
                let future = match unsafe { &mut *ptr } {
                    Stage::Running(future) => future,
                    _ => unreachable!("unexpected stage"),
                };

                // Safety: The caller ensures the future is pinned.
                let future = unsafe { Pin::new_unchecked(future) };

                let _guard = TaskIdGuard::enter(self.task_id);
                future.poll(&mut cx)
            })
        };

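        // The output has been moved out through `res`, but the completed
        // future still sits in `Stage::Running`; drop it eagerly here (with
        // the task id set by the guard inside `drop_future_or_output`).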
        if res.is_ready() {
            self.drop_future_or_output();
        }

        res
    }

    /// Drops the future or output, replacing the stage with `Consumed`.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `stage` field.
    pub(super) fn drop_future_or_output(&self) {
        // Safety: the caller ensures mutual exclusion to the field.
        unsafe {
            self.set_stage(Stage::Consumed);
        }
    }

    /// Stores the task output.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `stage` field.
    pub(super) fn store_output(&self, output: super::Result<T::Output>) {
        // Safety: the caller ensures mutual exclusion to the field.
        unsafe {
            self.set_stage(Stage::Finished(output));
        }
    }

    /// Takes the task output.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `stage` field.
    pub(super) fn take_output(&self) -> super::Result<T::Output> {
        use std::mem;

        self.stage.stage.with_mut(|ptr| {
            // Safety: the caller ensures mutual exclusion to the field.
            match mem::replace(unsafe { &mut *ptr }, Stage::Consumed) {
                Stage::Finished(output) => output,
                _ => panic!("JoinHandle polled after completion"),
            }
        })
    }

    unsafe fn set_stage(&self, stage: Stage<T>) {
        let _guard = TaskIdGuard::enter(self.task_id);
        self.stage.stage.with_mut(|ptr| *ptr = stage);
    }
}

impl Header {
    pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) {
        self.queue_next.with_mut(|ptr| *ptr = next);
    }

    // Safety: the caller must guarantee exclusive access to this field, and
    // must ensure that the id is the id of the `OwnedTasks` or
    // `LocalOwnedTasks` containing this task.
    pub(super) unsafe fn set_owner_id(&self, owner: NonZeroU64) {
        self.owner_id.with_mut(|ptr| *ptr = Some(owner));
    }

    pub(super) fn get_owner_id(&self) -> Option<NonZeroU64> {
        // Safety: if there are concurrent writes, then those writes have
        // violated the safety requirements on `set_owner_id`.
        unsafe { self.owner_id.with(|ptr| *ptr) }
    }

    /// Gets a pointer to the `Trailer` of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    pub(super) unsafe fn get_trailer(me: NonNull<Header>) -> NonNull<Trailer> {
        let offset = me.as_ref().vtable.trailer_offset;
        let trailer = me.as_ptr().cast::<u8>().add(offset).cast::<Trailer>();
        NonNull::new_unchecked(trailer)
    }

    /// Gets a pointer to the scheduler of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    ///
    /// The generic type `S` must be set to the correct scheduler type for this
    /// task.
    pub(super) unsafe fn get_scheduler<S>(me: NonNull<Header>) -> NonNull<S> {
        let offset = me.as_ref().vtable.scheduler_offset;
        let scheduler = me.as_ptr().cast::<u8>().add(offset).cast::<S>();
        NonNull::new_unchecked(scheduler)
    }

    /// Gets a pointer to the id of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    pub(super) unsafe fn get_id_ptr(me: NonNull<Header>) -> NonNull<Id> {
        let offset = me.as_ref().vtable.id_offset;
        let id = me.as_ptr().cast::<u8>().add(offset).cast::<Id>();
        NonNull::new_unchecked(id)
    }

    /// Gets the id of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    pub(super) unsafe fn get_id(me: NonNull<Header>) -> Id {
        let ptr = Header::get_id_ptr(me).as_ptr();
        *ptr
    }

    /// Gets the tracing id of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    pub(super) unsafe fn get_tracing_id(me: &NonNull<Header>) -> Option<&tracing::Id> {
        me.as_ref().tracing_id.as_ref()
    }
}
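// A hedged sketch of the offset scheme the accessors above rely on: the
// vtable records byte offsets (computed by the `const` fns in raw.rs) from
// the start of the `#[repr(C)]` cell, so any sibling field can be reached
// from a type-erased `NonNull<Header>` without knowing the concrete `T`/`S`:
//
//     let header: NonNull<Header> = /* type-erased task pointer */;
//     let offset = unsafe { header.as_ref() }.vtable.trailer_offset;
//     let trailer = header.as_ptr().cast::<u8>().wrapping_add(offset).cast::<Trailer>();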

impl Trailer {
    fn new(hooks: TaskHarnessScheduleHooks) -> Self {
        Trailer {
            waker: UnsafeCell::new(None),
            owned: linked_list::Pointers::new(),
            hooks,
        }
    }

    pub(super) unsafe fn set_waker(&self, waker: Option<Waker>) {
        self.waker.with_mut(|ptr| {
            *ptr = waker;
        });
    }

    pub(super) unsafe fn will_wake(&self, waker: &Waker) -> bool {
        self.waker
            .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker))
    }

    pub(super) fn wake_join(&self) {
        self.waker.with(|ptr| match unsafe { &*ptr } {
            Some(waker) => waker.wake_by_ref(),
            None => panic!("waker missing"),
        });
    }
}

#[test]
#[cfg(not(loom))]
fn header_lte_cache_line() {
    assert!(std::mem::size_of::<Header>() <= 8 * std::mem::size_of::<*const ()>());
}
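
// A companion check in the same spirit as `header_lte_cache_line`. The exact
// value asserted here is derived from the `cfg_attr`s on `Cell` above (128
// bytes on x86_64/aarch64 because of the paired-cache-line prefetcher); the
// `Ready<()>`/`()` type parameters are arbitrary stand-ins for `T`/`S`.
#[test]
#[cfg(not(loom))]
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
fn cell_is_cache_padded() {
    assert_eq!(
        std::mem::align_of::<Cell<std::future::Ready<()>, ()>>(),
        128
    );
}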