1use proc_macro2::{Span, TokenStream, TokenTree};
2use quote::{quote, quote_spanned, ToTokens};
3use syn::parse::{Parse, ParseStream, Parser};
4use syn::{braced, Attribute, Ident, Path, Signature, Visibility};
5
// Comma-separated list of `syn::Meta` items. `main` and `test` parse the
// attribute's argument tokens into this type via `parse_terminated`.
type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
8
/// Scheduler kind selected through the `flavor` macro argument.
#[derive(Clone, Copy, PartialEq)]
enum RuntimeFlavor {
    /// Spelled `"current_thread"` in the attribute.
    CurrentThread,
    /// Spelled `"multi_thread"` in the attribute.
    Threaded,
}

impl RuntimeFlavor {
    /// Maps the textual flavor name to a [`RuntimeFlavor`].
    ///
    /// Only `"current_thread"` and `"multi_thread"` are accepted. A few known
    /// legacy or mistaken spellings get a dedicated hint; anything else yields
    /// a generic "no such flavor" message. The error is the message text.
    fn from_str(s: &str) -> Result<RuntimeFlavor, String> {
        let message = match s {
            "current_thread" => return Ok(Self::CurrentThread),
            "multi_thread" => return Ok(Self::Threaded),
            "single_thread" => {
                String::from("The single threaded runtime flavor is called `current_thread`.")
            }
            "basic_scheduler" => String::from(
                "The `basic_scheduler` runtime flavor has been renamed to `current_thread`.",
            ),
            "threaded_scheduler" => String::from(
                "The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`.",
            ),
            _ => format!("No such runtime flavor `{s}`. The runtime flavors are `current_thread` and `multi_thread`."),
        };
        Err(message)
    }
}
27
28#[derive(Clone, Copy, PartialEq)]
29enum UnhandledPanic {
30 Ignore,
31 ShutdownRuntime,
32}
33
34impl UnhandledPanic {
35 fn from_str(s: &str) -> Result<UnhandledPanic, String> {
36 match s {
37 "ignore" => Ok(UnhandledPanic::Ignore),
38 "shutdown_runtime" => Ok(UnhandledPanic::ShutdownRuntime),
39 _ => Err(format!("No such unhandled panic behavior `{s}`. The unhandled panic behaviors are `ignore` and `shutdown_runtime`.")),
40 }
41 }
42
43 fn into_tokens(self, crate_path: &TokenStream) -> TokenStream {
44 match self {
45 UnhandledPanic::Ignore => quote! { #crate_path::runtime::UnhandledPanic::Ignore },
46 UnhandledPanic::ShutdownRuntime => {
47 quote! { #crate_path::runtime::UnhandledPanic::ShutdownRuntime }
48 }
49 }
50 }
51}
52
/// Validated macro configuration: produced by `Configuration::build` and
/// consumed by `parse_knobs` when generating the runtime builder expression.
struct FinalConfig {
    /// Chooses between `Builder::new_current_thread()` and
    /// `Builder::new_multi_thread()` in the generated code.
    flavor: RuntimeFlavor,
    /// When set, a `.worker_threads(n)` call is added to the builder chain.
    worker_threads: Option<usize>,
    /// When set, a `.start_paused(v)` call is added to the builder chain.
    start_paused: Option<bool>,
    /// Path used in place of `tokio` in the generated code, when provided.
    crate_name: Option<Path>,
    /// When set, an `.unhandled_panic(..)` call is added to the builder chain.
    unhandled_panic: Option<UnhandledPanic>,
}
60
/// Fallback configuration used by `main` and `test` when no valid
/// configuration could be built: the function is still expanded with these
/// settings and the error is appended to the output as a compile error.
const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig {
    flavor: RuntimeFlavor::CurrentThread,
    worker_threads: None,
    start_paused: None,
    crate_name: None,
    unhandled_panic: None,
};
69
/// Accumulates the macro arguments as they are parsed, before validation.
///
/// Options that can be rejected later keep the `Span` of their value so that
/// `build` can attach the error to the offending argument.
struct Configuration {
    /// Supplied by the caller of `Configuration::new`; `build` rejects the
    /// `multi_thread` flavor when this is `false`.
    rt_multi_thread_available: bool,
    /// Flavor used when no `flavor` argument was given.
    default_flavor: RuntimeFlavor,
    /// Explicitly requested flavor, if any.
    flavor: Option<RuntimeFlavor>,
    /// Requested worker thread count (never 0) and the span of its literal.
    worker_threads: Option<(usize, Span)>,
    /// Requested `start_paused` value and the span of its literal.
    start_paused: Option<(bool, Span)>,
    /// Whether this configuration is for the test macro; selects the name
    /// reported by `macro_name`.
    is_test: bool,
    /// Path given through the `crate` argument, if any.
    crate_name: Option<Path>,
    /// Requested `unhandled_panic` behavior and the span of its literal.
    unhandled_panic: Option<(UnhandledPanic, Span)>,
}
80
81impl Configuration {
82 fn new(is_test: bool, rt_multi_thread: bool) -> Self {
83 Configuration {
84 rt_multi_thread_available: rt_multi_thread,
85 default_flavor: match is_test {
86 true => RuntimeFlavor::CurrentThread,
87 false => RuntimeFlavor::Threaded,
88 },
89 flavor: None,
90 worker_threads: None,
91 start_paused: None,
92 is_test,
93 crate_name: None,
94 unhandled_panic: None,
95 }
96 }
97
98 fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> {
99 if self.flavor.is_some() {
100 return Err(syn::Error::new(span, "`flavor` set multiple times."));
101 }
102
103 let runtime_str = parse_string(runtime, span, "flavor")?;
104 let runtime =
105 RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?;
106 self.flavor = Some(runtime);
107 Ok(())
108 }
109
110 fn set_worker_threads(
111 &mut self,
112 worker_threads: syn::Lit,
113 span: Span,
114 ) -> Result<(), syn::Error> {
115 if self.worker_threads.is_some() {
116 return Err(syn::Error::new(
117 span,
118 "`worker_threads` set multiple times.",
119 ));
120 }
121
122 let worker_threads = parse_int(worker_threads, span, "worker_threads")?;
123 if worker_threads == 0 {
124 return Err(syn::Error::new(span, "`worker_threads` may not be 0."));
125 }
126 self.worker_threads = Some((worker_threads, span));
127 Ok(())
128 }
129
130 fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> {
131 if self.start_paused.is_some() {
132 return Err(syn::Error::new(span, "`start_paused` set multiple times."));
133 }
134
135 let start_paused = parse_bool(start_paused, span, "start_paused")?;
136 self.start_paused = Some((start_paused, span));
137 Ok(())
138 }
139
140 fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> {
141 if self.crate_name.is_some() {
142 return Err(syn::Error::new(span, "`crate` set multiple times."));
143 }
144 let name_path = parse_path(name, span, "crate")?;
145 self.crate_name = Some(name_path);
146 Ok(())
147 }
148
149 fn set_unhandled_panic(
150 &mut self,
151 unhandled_panic: syn::Lit,
152 span: Span,
153 ) -> Result<(), syn::Error> {
154 if self.unhandled_panic.is_some() {
155 return Err(syn::Error::new(
156 span,
157 "`unhandled_panic` set multiple times.",
158 ));
159 }
160
161 let unhandled_panic = parse_string(unhandled_panic, span, "unhandled_panic")?;
162 let unhandled_panic =
163 UnhandledPanic::from_str(&unhandled_panic).map_err(|err| syn::Error::new(span, err))?;
164 self.unhandled_panic = Some((unhandled_panic, span));
165 Ok(())
166 }
167
168 fn macro_name(&self) -> &'static str {
169 if self.is_test {
170 "tokio::test"
171 } else {
172 "tokio::main"
173 }
174 }
175
176 fn build(&self) -> Result<FinalConfig, syn::Error> {
177 use RuntimeFlavor as F;
178
179 let flavor = self.flavor.unwrap_or(self.default_flavor);
180 let worker_threads = match (flavor, self.worker_threads) {
181 (F::CurrentThread, Some((_, worker_threads_span))) => {
182 let msg = format!(
183 "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[{}(flavor = \"multi_thread\")]`",
184 self.macro_name(),
185 );
186 return Err(syn::Error::new(worker_threads_span, msg));
187 }
188 (F::CurrentThread, None) => None,
189 (F::Threaded, worker_threads) if self.rt_multi_thread_available => {
190 worker_threads.map(|(val, _span)| val)
191 }
192 (F::Threaded, _) => {
193 let msg = if self.flavor.is_none() {
194 "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled."
195 } else {
196 "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature."
197 };
198 return Err(syn::Error::new(Span::call_site(), msg));
199 }
200 };
201
202 let start_paused = match (flavor, self.start_paused) {
203 (F::Threaded, Some((_, start_paused_span))) => {
204 let msg = format!(
205 "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`",
206 self.macro_name(),
207 );
208 return Err(syn::Error::new(start_paused_span, msg));
209 }
210 (F::CurrentThread, Some((start_paused, _))) => Some(start_paused),
211 (_, None) => None,
212 };
213
214 let unhandled_panic = match (flavor, self.unhandled_panic) {
215 (F::Threaded, Some((_, unhandled_panic_span))) => {
216 let msg = format!(
217 "The `unhandled_panic` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`",
218 self.macro_name(),
219 );
220 return Err(syn::Error::new(unhandled_panic_span, msg));
221 }
222 (F::CurrentThread, Some((unhandled_panic, _))) => Some(unhandled_panic),
223 (_, None) => None,
224 };
225
226 Ok(FinalConfig {
227 crate_name: self.crate_name.clone(),
228 flavor,
229 worker_threads,
230 start_paused,
231 unhandled_panic,
232 })
233 }
234}
235
236fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
237 match int {
238 syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
239 Ok(value) => Ok(value),
240 Err(e) => Err(syn::Error::new(
241 span,
242 format!("Failed to parse value of `{field}` as integer: {e}"),
243 )),
244 },
245 _ => Err(syn::Error::new(
246 span,
247 format!("Failed to parse value of `{field}` as integer."),
248 )),
249 }
250}
251
252fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
253 match int {
254 syn::Lit::Str(s) => Ok(s.value()),
255 syn::Lit::Verbatim(s) => Ok(s.to_string()),
256 _ => Err(syn::Error::new(
257 span,
258 format!("Failed to parse value of `{field}` as string."),
259 )),
260 }
261}
262
263fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result<Path, syn::Error> {
264 match lit {
265 syn::Lit::Str(s) => {
266 let err = syn::Error::new(
267 span,
268 format!(
269 "Failed to parse value of `{}` as path: \"{}\"",
270 field,
271 s.value()
272 ),
273 );
274 s.parse::<syn::Path>().map_err(|_| err.clone())
275 }
276 _ => Err(syn::Error::new(
277 span,
278 format!("Failed to parse value of `{field}` as path."),
279 )),
280 }
281}
282
283fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
284 match bool {
285 syn::Lit::Bool(b) => Ok(b.value),
286 _ => Err(syn::Error::new(
287 span,
288 format!("Failed to parse value of `{field}` as bool."),
289 )),
290 }
291}
292
293fn build_config(
294 input: &ItemFn,
295 args: AttributeArgs,
296 is_test: bool,
297 rt_multi_thread: bool,
298) -> Result<FinalConfig, syn::Error> {
299 if input.sig.asyncness.is_none() {
300 let msg = "the `async` keyword is missing from the function declaration";
301 return Err(syn::Error::new_spanned(input.sig.fn_token, msg));
302 }
303
304 let mut config = Configuration::new(is_test, rt_multi_thread);
305 let macro_name = config.macro_name();
306
307 for arg in args {
308 match arg {
309 syn::Meta::NameValue(namevalue) => {
310 let ident = namevalue
311 .path
312 .get_ident()
313 .ok_or_else(|| {
314 syn::Error::new_spanned(&namevalue, "Must have specified ident")
315 })?
316 .to_string()
317 .to_lowercase();
318 let lit = match &namevalue.value {
319 syn::Expr::Lit(syn::ExprLit { lit, .. }) => lit,
320 expr => return Err(syn::Error::new_spanned(expr, "Must be a literal")),
321 };
322 match ident.as_str() {
323 "worker_threads" => {
324 config.set_worker_threads(lit.clone(), syn::spanned::Spanned::span(lit))?;
325 }
326 "flavor" => {
327 config.set_flavor(lit.clone(), syn::spanned::Spanned::span(lit))?;
328 }
329 "start_paused" => {
330 config.set_start_paused(lit.clone(), syn::spanned::Spanned::span(lit))?;
331 }
332 "core_threads" => {
333 let msg = "Attribute `core_threads` is renamed to `worker_threads`";
334 return Err(syn::Error::new_spanned(namevalue, msg));
335 }
336 "crate" => {
337 config.set_crate_name(lit.clone(), syn::spanned::Spanned::span(lit))?;
338 }
339 "unhandled_panic" => {
340 config
341 .set_unhandled_panic(lit.clone(), syn::spanned::Spanned::span(lit))?;
342 }
343 name => {
344 let msg = format!(
345 "Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`",
346 );
347 return Err(syn::Error::new_spanned(namevalue, msg));
348 }
349 }
350 }
351 syn::Meta::Path(path) => {
352 let name = path
353 .get_ident()
354 .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))?
355 .to_string()
356 .to_lowercase();
357 let msg = match name.as_str() {
358 "threaded_scheduler" | "multi_thread" => {
359 format!(
360 "Set the runtime flavor with #[{macro_name}(flavor = \"multi_thread\")]."
361 )
362 }
363 "basic_scheduler" | "current_thread" | "single_threaded" => {
364 format!(
365 "Set the runtime flavor with #[{macro_name}(flavor = \"current_thread\")]."
366 )
367 }
368 "flavor" | "worker_threads" | "start_paused" | "crate" | "unhandled_panic" => {
369 format!("The `{name}` attribute requires an argument.")
370 }
371 name => {
372 format!("Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`.")
373 }
374 };
375 return Err(syn::Error::new_spanned(path, msg));
376 }
377 other => {
378 return Err(syn::Error::new_spanned(
379 other,
380 "Unknown attribute inside the macro",
381 ));
382 }
383 }
384 }
385
386 config.build()
387}
388
389fn parse_knobs(mut input: ItemFn, is_test: bool, config: FinalConfig) -> TokenStream {
390 input.sig.asyncness = None;
391
392 let (last_stmt_start_span, last_stmt_end_span) = {
394 let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter();
395
396 let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
401 let end = last_stmt.last().map_or(start, |t| t.span());
402 (start, end)
403 };
404
405 let crate_path = config
406 .crate_name
407 .map(ToTokens::into_token_stream)
408 .unwrap_or_else(|| Ident::new("tokio", last_stmt_start_span).into_token_stream());
409
410 let mut rt = match config.flavor {
411 RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=>
412 #crate_path::runtime::Builder::new_current_thread()
413 },
414 RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=>
415 #crate_path::runtime::Builder::new_multi_thread()
416 },
417 };
418 if let Some(v) = config.worker_threads {
419 rt = quote_spanned! {last_stmt_start_span=> #rt.worker_threads(#v) };
420 }
421 if let Some(v) = config.start_paused {
422 rt = quote_spanned! {last_stmt_start_span=> #rt.start_paused(#v) };
423 }
424 if let Some(v) = config.unhandled_panic {
425 let unhandled_panic = v.into_tokens(&crate_path);
426 rt = quote_spanned! {last_stmt_start_span=> #rt.unhandled_panic(#unhandled_panic) };
427 }
428
429 let generated_attrs = if is_test {
430 quote! {
431 #[::core::prelude::v1::test]
432 }
433 } else {
434 quote! {}
435 };
436
437 let body_ident = quote! { body };
438 let last_block = quote_spanned! {last_stmt_end_span=>
440 #[allow(clippy::expect_used, clippy::diverging_sub_expression, clippy::needless_return)]
441 {
442 return #rt
443 .enable_all()
444 .build()
445 .expect("Failed building the Runtime")
446 .block_on(#body_ident);
447 }
448 };
449
450 let body = input.body();
451
452 let body = if is_test {
462 let output_type = match &input.sig.output {
463 syn::ReturnType::Default => quote! { () },
467 syn::ReturnType::Type(_, ret_type) => quote! { #ret_type },
468 };
469 quote! {
470 let body = async #body;
471 #crate_path::pin!(body);
472 let body: ::core::pin::Pin<&mut dyn ::core::future::Future<Output = #output_type>> = body;
473 }
474 } else {
475 quote! {
476 let body = async #body;
477 }
478 };
479
480 input.into_tokens(generated_attrs, body, last_block)
481}
482
483fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
484 tokens.extend(error.into_compile_error());
485 tokens
486}
487
488pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
489 let input: ItemFn = match syn::parse2(item.clone()) {
493 Ok(it) => it,
494 Err(e) => return token_stream_with_error(item, e),
495 };
496
497 let config = if input.sig.ident == "main" && !input.sig.inputs.is_empty() {
498 let msg = "the main function cannot accept arguments";
499 Err(syn::Error::new_spanned(&input.sig.ident, msg))
500 } else {
501 AttributeArgs::parse_terminated
502 .parse2(args)
503 .and_then(|args| build_config(&input, args, false, rt_multi_thread))
504 };
505
506 match config {
507 Ok(config) => parse_knobs(input, false, config),
508 Err(e) => token_stream_with_error(parse_knobs(input, false, DEFAULT_ERROR_CONFIG), e),
509 }
510}
511
512fn is_test_attribute(attr: &Attribute) -> bool {
517 let path = match &attr.meta {
518 syn::Meta::Path(path) => path,
519 _ => return false,
520 };
521 let candidates = [
522 ["core", "prelude", "*", "test"],
523 ["std", "prelude", "*", "test"],
524 ];
525 if path.leading_colon.is_none()
526 && path.segments.len() == 1
527 && path.segments[0].arguments.is_none()
528 && path.segments[0].ident == "test"
529 {
530 return true;
531 } else if path.segments.len() != candidates[0].len() {
532 return false;
533 }
534 candidates.into_iter().any(|segments| {
535 path.segments.iter().zip(segments).all(|(segment, path)| {
536 segment.arguments.is_none() && (path == "*" || segment.ident == path)
537 })
538 })
539}
540
541pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
542 let input: ItemFn = match syn::parse2(item.clone()) {
546 Ok(it) => it,
547 Err(e) => return token_stream_with_error(item, e),
548 };
549 let config = if let Some(attr) = input.attrs().find(|attr| is_test_attribute(attr)) {
550 let msg = "second test attribute is supplied, consider removing or changing the order of your test attributes";
551 Err(syn::Error::new_spanned(attr, msg))
552 } else {
553 AttributeArgs::parse_terminated
554 .parse2(args)
555 .and_then(|args| build_config(&input, args, true, rt_multi_thread))
556 };
557
558 match config {
559 Ok(config) => parse_knobs(input, true, config),
560 Err(e) => token_stream_with_error(parse_knobs(input, true, DEFAULT_ERROR_CONFIG), e),
561 }
562}
563
/// A parsed function item whose body is kept as raw token streams rather than
/// a fully parsed syntax tree.
struct ItemFn {
    /// Attributes written before the function.
    outer_attrs: Vec<Attribute>,
    /// Visibility qualifier of the function.
    vis: Visibility,
    /// Function signature.
    sig: Signature,
    /// The braces delimiting the function body.
    brace_token: syn::token::Brace,
    /// Inner attributes found at the start of the body.
    inner_attrs: Vec<Attribute>,
    /// Body tokens, split after each top-level `;` (the `;` stays with the
    /// preceding chunk); trailing tokens without a `;` form the last entry.
    stmts: Vec<proc_macro2::TokenStream>,
}
572
573impl ItemFn {
574 fn attrs(&self) -> impl Iterator<Item = &Attribute> {
576 self.outer_attrs.iter().chain(self.inner_attrs.iter())
577 }
578
579 fn body(&self) -> Body<'_> {
582 Body {
583 brace_token: self.brace_token,
584 stmts: &self.stmts,
585 }
586 }
587
588 fn into_tokens(
590 self,
591 generated_attrs: proc_macro2::TokenStream,
592 body: proc_macro2::TokenStream,
593 last_block: proc_macro2::TokenStream,
594 ) -> TokenStream {
595 let mut tokens = proc_macro2::TokenStream::new();
596 for attr in self.outer_attrs {
598 attr.to_tokens(&mut tokens);
599 }
600
601 for mut attr in self.inner_attrs {
605 attr.style = syn::AttrStyle::Outer;
606 attr.to_tokens(&mut tokens);
607 }
608
609 generated_attrs.to_tokens(&mut tokens);
611
612 self.vis.to_tokens(&mut tokens);
613 self.sig.to_tokens(&mut tokens);
614
615 self.brace_token.surround(&mut tokens, |tokens| {
616 body.to_tokens(tokens);
617 last_block.to_tokens(tokens);
618 });
619
620 tokens
621 }
622}
623
624impl Parse for ItemFn {
625 #[inline]
626 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
627 let outer_attrs = input.call(Attribute::parse_outer)?;
635 let vis: Visibility = input.parse()?;
636 let sig: Signature = input.parse()?;
637
638 let content;
639 let brace_token = braced!(content in input);
640 let inner_attrs = Attribute::parse_inner(&content)?;
641
642 let mut buf = proc_macro2::TokenStream::new();
643 let mut stmts = Vec::new();
644
645 while !content.is_empty() {
646 if let Some(semi) = content.parse::<Option<syn::Token![;]>>()? {
647 semi.to_tokens(&mut buf);
648 stmts.push(buf);
649 buf = proc_macro2::TokenStream::new();
650 continue;
651 }
652
653 buf.extend([content.parse::<TokenTree>()?]);
656 }
657
658 if !buf.is_empty() {
659 stmts.push(buf);
660 }
661
662 Ok(Self {
663 outer_attrs,
664 vis,
665 sig,
666 brace_token,
667 inner_attrs,
668 stmts,
669 })
670 }
671}
672
/// Borrowed view of a function body, created by `ItemFn::body`.
struct Body<'a> {
    /// The braces that surround the statements when tokenized.
    brace_token: syn::token::Brace,
    /// The body's token chunks, emitted in order.
    stmts: &'a [TokenStream],
}
678
679impl ToTokens for Body<'_> {
680 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
681 self.brace_token.surround(tokens, |tokens| {
682 for stmt in self.stmts {
683 stmt.to_tokens(tokens);
684 }
685 });
686 }
687}