// tokio/runtime/task/raw.rs

1use crate::future::Future;
2use crate::runtime::task::core::{Core, Trailer};
3use crate::runtime::task::{Cell, Harness, Header, Id, Schedule, State};
4
5use std::ptr::NonNull;
6use std::task::{Poll, Waker};
7
8/// Raw task handle
9#[derive(Clone)]
10pub(crate) struct RawTask {
11    ptr: NonNull<Header>,
12}
13
14pub(super) struct Vtable {
15    /// Polls the future.
16    pub(super) poll: unsafe fn(NonNull<Header>),
17
18    /// Schedules the task for execution on the runtime.
19    pub(super) schedule: unsafe fn(NonNull<Header>),
20
21    /// Deallocates the memory.
22    pub(super) dealloc: unsafe fn(NonNull<Header>),
23
24    /// Reads the task output, if complete.
25    pub(super) try_read_output: unsafe fn(NonNull<Header>, *mut (), &Waker),
26
27    /// The join handle has been dropped.
28    pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>),
29
30    /// An abort handle has been dropped.
31    pub(super) drop_abort_handle: unsafe fn(NonNull<Header>),
32
33    /// Scheduler is being shutdown.
34    pub(super) shutdown: unsafe fn(NonNull<Header>),
35
36    /// The number of bytes that the `trailer` field is offset from the header.
37    pub(super) trailer_offset: usize,
38
39    /// The number of bytes that the `scheduler` field is offset from the header.
40    pub(super) scheduler_offset: usize,
41
42    /// The number of bytes that the `id` field is offset from the header.
43    pub(super) id_offset: usize,
44}
45
46/// Get the vtable for the requested `T` and `S` generics.
47pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable {
48    &Vtable {
49        poll: poll::<T, S>,
50        schedule: schedule::<S>,
51        dealloc: dealloc::<T, S>,
52        try_read_output: try_read_output::<T, S>,
53        drop_join_handle_slow: drop_join_handle_slow::<T, S>,
54        drop_abort_handle: drop_abort_handle::<T, S>,
55        shutdown: shutdown::<T, S>,
56        trailer_offset: OffsetHelper::<T, S>::TRAILER_OFFSET,
57        scheduler_offset: OffsetHelper::<T, S>::SCHEDULER_OFFSET,
58        id_offset: OffsetHelper::<T, S>::ID_OFFSET,
59    }
60}
61
62/// Calling `get_trailer_offset` directly in vtable doesn't work because it
63/// prevents the vtable from being promoted to a static reference.
64///
65/// See this thread for more info:
66/// <https://users.rust-lang.org/t/custom-vtables-with-integers/78508>
67struct OffsetHelper<T, S>(T, S);
68impl<T: Future, S: Schedule> OffsetHelper<T, S> {
69    // Pass `size_of`/`align_of` as arguments rather than calling them directly
70    // inside `get_trailer_offset` because trait bounds on generic parameters
71    // of const fn are unstable on our MSRV.
72    const TRAILER_OFFSET: usize = get_trailer_offset(
73        std::mem::size_of::<Header>(),
74        std::mem::size_of::<Core<T, S>>(),
75        std::mem::align_of::<Core<T, S>>(),
76        std::mem::align_of::<Trailer>(),
77    );
78
79    // The `scheduler` is the first field of `Core`, so it has the same
80    // offset as `Core`.
81    const SCHEDULER_OFFSET: usize = get_core_offset(
82        std::mem::size_of::<Header>(),
83        std::mem::align_of::<Core<T, S>>(),
84    );
85
86    const ID_OFFSET: usize = get_id_offset(
87        std::mem::size_of::<Header>(),
88        std::mem::align_of::<Core<T, S>>(),
89        std::mem::size_of::<S>(),
90        std::mem::align_of::<Id>(),
91    );
92}
93
94/// Compute the offset of the `Trailer` field in `Cell<T, S>` using the
95/// `#[repr(C)]` algorithm.
96///
97/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
98/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
99const fn get_trailer_offset(
100    header_size: usize,
101    core_size: usize,
102    core_align: usize,
103    trailer_align: usize,
104) -> usize {
105    let mut offset = header_size;
106
107    let core_misalign = offset % core_align;
108    if core_misalign > 0 {
109        offset += core_align - core_misalign;
110    }
111    offset += core_size;
112
113    let trailer_misalign = offset % trailer_align;
114    if trailer_misalign > 0 {
115        offset += trailer_align - trailer_misalign;
116    }
117
118    offset
119}
120
/// Computes the offset of the `Core<T, S>` field in `Cell<T, S>` using the
/// `#[repr(C)]` algorithm.
///
/// The header occupies the first `header_size` bytes. `Core` starts at the
/// smallest multiple of `core_align` that is at least `header_size`.
///
/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
const fn get_core_offset(header_size: usize, core_align: usize) -> usize {
    match header_size % core_align {
        // Already aligned: `Core` follows the header with no padding.
        0 => header_size,
        // Otherwise pad forward to the next `core_align` boundary.
        overshoot => header_size + (core_align - overshoot),
    }
}
136
137/// Compute the offset of the `Id` field in `Cell<T, S>` using the
138/// `#[repr(C)]` algorithm.
139///
140/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
141/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
142const fn get_id_offset(
143    header_size: usize,
144    core_align: usize,
145    scheduler_size: usize,
146    id_align: usize,
147) -> usize {
148    let mut offset = get_core_offset(header_size, core_align);
149    offset += scheduler_size;
150
151    let id_misalign = offset % id_align;
152    if id_misalign > 0 {
153        offset += id_align - id_misalign;
154    }
155
156    offset
157}
158
159impl RawTask {
160    pub(super) fn new<T, S>(task: T, scheduler: S, id: Id) -> RawTask
161    where
162        T: Future,
163        S: Schedule,
164    {
165        let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new(), id));
166        let ptr = unsafe { NonNull::new_unchecked(ptr.cast()) };
167
168        RawTask { ptr }
169    }
170
171    pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> RawTask {
172        RawTask { ptr }
173    }
174
175    pub(super) fn header_ptr(&self) -> NonNull<Header> {
176        self.ptr
177    }
178
179    pub(super) fn trailer_ptr(&self) -> NonNull<Trailer> {
180        unsafe { Header::get_trailer(self.ptr) }
181    }
182
183    /// Returns a reference to the task's header.
184    pub(super) fn header(&self) -> &Header {
185        unsafe { self.ptr.as_ref() }
186    }
187
188    /// Returns a reference to the task's trailer.
189    pub(super) fn trailer(&self) -> &Trailer {
190        unsafe { &*self.trailer_ptr().as_ptr() }
191    }
192
193    /// Returns a reference to the task's state.
194    pub(super) fn state(&self) -> &State {
195        &self.header().state
196    }
197
198    /// Safety: mutual exclusion is required to call this function.
199    pub(crate) fn poll(self) {
200        let vtable = self.header().vtable;
201        unsafe { (vtable.poll)(self.ptr) }
202    }
203
204    pub(super) fn schedule(self) {
205        let vtable = self.header().vtable;
206        unsafe { (vtable.schedule)(self.ptr) }
207    }
208
209    pub(super) fn dealloc(self) {
210        let vtable = self.header().vtable;
211        unsafe {
212            (vtable.dealloc)(self.ptr);
213        }
214    }
215
216    /// Safety: `dst` must be a `*mut Poll<super::Result<T::Output>>` where `T`
217    /// is the future stored by the task.
218    pub(super) unsafe fn try_read_output(self, dst: *mut (), waker: &Waker) {
219        let vtable = self.header().vtable;
220        (vtable.try_read_output)(self.ptr, dst, waker);
221    }
222
223    pub(super) fn drop_join_handle_slow(self) {
224        let vtable = self.header().vtable;
225        unsafe { (vtable.drop_join_handle_slow)(self.ptr) }
226    }
227
228    pub(super) fn drop_abort_handle(self) {
229        let vtable = self.header().vtable;
230        unsafe { (vtable.drop_abort_handle)(self.ptr) }
231    }
232
233    pub(super) fn shutdown(self) {
234        let vtable = self.header().vtable;
235        unsafe { (vtable.shutdown)(self.ptr) }
236    }
237
238    /// Increment the task's reference count.
239    ///
240    /// Currently, this is used only when creating an `AbortHandle`.
241    pub(super) fn ref_inc(self) {
242        self.header().state.ref_inc();
243    }
244
245    /// Get the queue-next pointer
246    ///
247    /// This is for usage by the injection queue
248    ///
249    /// Safety: make sure only one queue uses this and access is synchronized.
250    pub(crate) unsafe fn get_queue_next(self) -> Option<RawTask> {
251        self.header()
252            .queue_next
253            .with(|ptr| *ptr)
254            .map(|p| RawTask::from_raw(p))
255    }
256
257    /// Sets the queue-next pointer
258    ///
259    /// This is for usage by the injection queue
260    ///
261    /// Safety: make sure only one queue uses this and access is synchronized.
262    pub(crate) unsafe fn set_queue_next(self, val: Option<RawTask>) {
263        self.header().set_next(val.map(|task| task.ptr));
264    }
265}
266
267impl Copy for RawTask {}
268
269unsafe fn poll<T: Future, S: Schedule>(ptr: NonNull<Header>) {
270    let harness = Harness::<T, S>::from_raw(ptr);
271    harness.poll();
272}
273
274unsafe fn schedule<S: Schedule>(ptr: NonNull<Header>) {
275    use crate::runtime::task::{Notified, Task};
276
277    let scheduler = Header::get_scheduler::<S>(ptr);
278    scheduler
279        .as_ref()
280        .schedule(Notified(Task::from_raw(ptr.cast())));
281}
282
283unsafe fn dealloc<T: Future, S: Schedule>(ptr: NonNull<Header>) {
284    let harness = Harness::<T, S>::from_raw(ptr);
285    harness.dealloc();
286}
287
288unsafe fn try_read_output<T: Future, S: Schedule>(
289    ptr: NonNull<Header>,
290    dst: *mut (),
291    waker: &Waker,
292) {
293    let out = &mut *(dst as *mut Poll<super::Result<T::Output>>);
294
295    let harness = Harness::<T, S>::from_raw(ptr);
296    harness.try_read_output(out, waker);
297}
298
299unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) {
300    let harness = Harness::<T, S>::from_raw(ptr);
301    harness.drop_join_handle_slow();
302}
303
304unsafe fn drop_abort_handle<T: Future, S: Schedule>(ptr: NonNull<Header>) {
305    let harness = Harness::<T, S>::from_raw(ptr);
306    harness.drop_reference();
307}
308
309unsafe fn shutdown<T: Future, S: Schedule>(ptr: NonNull<Header>) {
310    let harness = Harness::<T, S>::from_raw(ptr);
311    harness.shutdown();
312}