rune_macros/
function.rs

1use proc_macro2::TokenStream;
2use quote::{quote, quote_spanned, ToTokens};
3use syn::parse::ParseStream;
4use syn::punctuated::Punctuated;
5use syn::spanned::Spanned;
6use syn::Token;
7
/// Where a function is registered, as selected by the `path` or `protocol`
/// option of [`FunctionAttrs`].
#[derive(Default)]
enum Path {
    /// No option given: the function is registered under its own name.
    #[default]
    None,
    /// Register under the last segment given with `path = ...`. Generic
    /// arguments on this segment, if any, have their type hashes passed as
    /// `Params` alongside the name.
    Rename(syn::PathSegment),
    /// Register as a protocol, from `protocol = ...`.
    Protocol(syn::Path),
}
15
/// Options given as arguments to the function attribute, as read by
/// `FunctionAttrs::parse`.
#[derive(Default)]
pub(crate) struct FunctionAttrs {
    /// Register as an instance function. This is also implied when the
    /// function takes a `self` receiver.
    instance: bool,
    /// A free function: even with a self type from `path = Type::name`, the
    /// real function is referred to without the type prefix.
    free: bool,
    /// Keep the existing function in place, and generate a separate hidden meta function.
    keep: bool,
    /// Path (or protocol) to register in.
    path: Path,
    /// The type segment from `path = Type::name`, or `Self` from `path = Self`.
    /// Non-instance functions with a self type are built as associated
    /// functions of it.
    self_type: Option<syn::PathSegment>,
    /// Defines a fallible function which can make use of the `?` operator.
    vm_result: bool,
    /// The function is deprecated, with the given message.
    deprecated: Option<syn::LitStr>,
}
32
33impl FunctionAttrs {
34    /// Parse the given parse stream.
35    pub(crate) fn parse(input: ParseStream) -> syn::Result<Self> {
36        let mut out = Self::default();
37
38        while !input.is_empty() {
39            let ident = input.parse::<syn::Ident>()?;
40
41            if ident == "instance" {
42                out.instance = true;
43            } else if ident == "free" {
44                out.free = true;
45            } else if ident == "keep" {
46                out.keep = true;
47            } else if ident == "vm_result" {
48                out.vm_result = true;
49            } else if ident == "protocol" {
50                input.parse::<Token![=]>()?;
51                let protocol: syn::Path = input.parse()?;
52                out.path = Path::Protocol(if let Some(protocol) = protocol.get_ident() {
53                    syn::Path {
54                        leading_colon: None,
55                        segments: ["rune", "runtime", "Protocol"]
56                            .into_iter()
57                            .map(|i| syn::Ident::new(i, protocol.span()))
58                            .chain(Some(protocol.clone()))
59                            .map(syn::PathSegment::from)
60                            .collect(),
61                    }
62                } else {
63                    protocol
64                })
65            } else if ident == "path" {
66                input.parse::<Token![=]>()?;
67
68                let path = input.parse::<syn::Path>()?;
69
70                if path.segments.len() > 2 {
71                    return Err(syn::Error::new_spanned(
72                        path,
73                        "Expected at most two path segments",
74                    ));
75                }
76
77                let mut it = path.segments.into_iter();
78
79                let Some(first) = it.next() else {
80                    return Err(syn::Error::new(
81                        input.span(),
82                        "Expected at least one path segment",
83                    ));
84                };
85
86                if let Some(second) = it.next() {
87                    let syn::PathArguments::None = &first.arguments else {
88                        return Err(syn::Error::new_spanned(
89                            first.arguments,
90                            "Unsupported arguments",
91                        ));
92                    };
93
94                    out.self_type = Some(first);
95                    out.path = Path::Rename(second);
96                } else if first.ident == "Self" {
97                    out.self_type = Some(first);
98                } else {
99                    out.path = Path::Rename(first);
100                }
101            } else if ident == "deprecated" {
102                input.parse::<Token![=]>()?;
103                out.deprecated = Some(input.parse()?);
104            } else {
105                return Err(syn::Error::new_spanned(ident, "Unsupported option"));
106            }
107
108            if input.parse::<Option<Token![,]>>()?.is_none() {
109                break;
110            }
111        }
112
113        let stream = input.parse::<TokenStream>()?;
114
115        if !stream.is_empty() {
116            return Err(syn::Error::new_spanned(stream, "Unexpected input"));
117        }
118
119        Ok(out)
120    }
121}
122
/// A function item parsed by `Function::parse`, ready to be expanded.
pub(crate) struct Function {
    /// All outer attributes on the function, doc comments included.
    attributes: Vec<syn::Attribute>,
    /// The function's declared visibility.
    vis: syn::Visibility,
    /// The function's signature.
    sig: syn::Signature,
    /// Every token after the signature (normally the body), kept unparsed.
    remainder: TokenStream,
    /// Array of the values of the function's `#[doc = ...]` attributes, in order.
    docs: syn::ExprArray,
    /// Array of string literals naming each argument, in order (`"self"` for
    /// a receiver).
    arguments: syn::ExprArray,
    /// Whether the signature has a `self` receiver.
    takes_self: bool,
}
132
133impl Function {
134    /// Parse the given parse stream.
135    pub(crate) fn parse(input: ParseStream) -> syn::Result<Self> {
136        let parsed_attributes = input.call(syn::Attribute::parse_outer)?;
137        let vis = input.parse::<syn::Visibility>()?;
138        let sig = input.parse::<syn::Signature>()?;
139
140        let mut attributes = Vec::new();
141
142        let mut docs = syn::ExprArray {
143            attrs: Vec::new(),
144            bracket_token: syn::token::Bracket::default(),
145            elems: Punctuated::default(),
146        };
147
148        for attr in parsed_attributes {
149            if attr.path().is_ident("doc") {
150                if let syn::Meta::NameValue(name_value) = &attr.meta {
151                    docs.elems.push(name_value.value.clone());
152                }
153            }
154
155            attributes.push(attr);
156        }
157
158        let mut arguments = syn::ExprArray {
159            attrs: Vec::new(),
160            bracket_token: syn::token::Bracket::default(),
161            elems: Punctuated::default(),
162        };
163
164        let mut takes_self = false;
165
166        for arg in &sig.inputs {
167            let argument_name = match arg {
168                syn::FnArg::Typed(ty) => argument_ident(&ty.pat),
169                syn::FnArg::Receiver(..) => {
170                    takes_self = true;
171                    syn::LitStr::new("self", arg.span())
172                }
173            };
174
175            arguments.elems.push(syn::Expr::Lit(syn::ExprLit {
176                attrs: Vec::new(),
177                lit: syn::Lit::Str(argument_name),
178            }));
179        }
180
181        let remainder = input.parse::<TokenStream>()?;
182
183        Ok(Self {
184            attributes,
185            vis,
186            sig,
187            remainder,
188            docs,
189            arguments,
190            takes_self,
191        })
192    }
193
194    /// Expand the function declaration.
195    pub(crate) fn expand(mut self, attrs: FunctionAttrs) -> syn::Result<TokenStream> {
196        let instance = attrs.instance || self.takes_self;
197
198        let (meta_fn, real_fn, mut sig, real_fn_mangled) = if attrs.keep {
199            let meta_fn =
200                syn::Ident::new(&format!("{}__meta", self.sig.ident), self.sig.ident.span());
201            let real_fn = self.sig.ident.clone();
202            (meta_fn, real_fn, self.sig.clone(), false)
203        } else {
204            let meta_fn = self.sig.ident.clone();
205            let real_fn = syn::Ident::new(
206                &format!("__rune_fn__{}", self.sig.ident),
207                self.sig.ident.span(),
208            );
209            let mut sig = self.sig.clone();
210            sig.ident = real_fn.clone();
211            (meta_fn, real_fn, sig, true)
212        };
213
214        let mut path = syn::Path {
215            leading_colon: None,
216            segments: Punctuated::default(),
217        };
218
219        match (self.takes_self, attrs.free, &attrs.self_type) {
220            (true, _, _) => {
221                path.segments
222                    .push(syn::PathSegment::from(<Token![Self]>::default()));
223                path.segments.push(syn::PathSegment::from(real_fn));
224            }
225            (_, false, Some(self_type)) => {
226                path.segments.push(self_type.clone());
227                path.segments.push(syn::PathSegment::from(real_fn));
228            }
229            _ => {
230                path.segments.push(syn::PathSegment::from(real_fn));
231            }
232        }
233
234        let real_fn_path = path;
235
236        let name_string = syn::LitStr::new(&self.sig.ident.to_string(), self.sig.ident.span());
237
238        let name = if instance {
239            'out: {
240                syn::Expr::Lit(syn::ExprLit {
241                    attrs: Vec::new(),
242                    lit: syn::Lit::Str(match &attrs.path {
243                        Path::Protocol(protocol) => {
244                            break 'out syn::parse_quote!(&#protocol);
245                        }
246                        Path::None => name_string.clone(),
247                        Path::Rename(last) => {
248                            syn::LitStr::new(&last.ident.to_string(), last.ident.span())
249                        }
250                    }),
251                })
252            }
253        } else {
254            match &attrs.path {
255                Path::None => expr_lit(&self.sig.ident),
256                Path::Rename(last) => expr_lit(&last.ident),
257                Path::Protocol(protocol) => syn::parse_quote!(&#protocol),
258            }
259        };
260
261        let arguments = match &attrs.path {
262            Path::None | Path::Protocol(_) => Punctuated::default(),
263            Path::Rename(last) => match &last.arguments {
264                syn::PathArguments::AngleBracketed(arguments) => arguments.args.clone(),
265                syn::PathArguments::None => Punctuated::default(),
266                arguments => {
267                    return Err(syn::Error::new_spanned(
268                        arguments,
269                        "Unsupported path segments",
270                    ));
271                }
272            },
273        };
274
275        let name = if !arguments.is_empty() {
276            let mut array = syn::ExprArray {
277                attrs: Vec::new(),
278                bracket_token: <syn::token::Bracket>::default(),
279                elems: Punctuated::default(),
280            };
281
282            for argument in arguments {
283                array.elems.push(syn::Expr::Verbatim(quote! {
284                    <#argument as rune::__private::TypeHash>::HASH
285                }));
286            }
287
288            quote!(rune::__private::Params::new(#name, #array))
289        } else {
290            quote!(#name)
291        };
292
293        if instance {
294            // Ensure that the first argument is called `self`.
295            if let Some(argument) = self.arguments.elems.first_mut() {
296                let span = argument.span();
297
298                *argument = syn::Expr::Lit(syn::ExprLit {
299                    attrs: Vec::new(),
300                    lit: syn::Lit::Str(syn::LitStr::new("self", span)),
301                });
302            }
303        }
304
305        let meta_kind = syn::Ident::new(
306            if instance { "instance" } else { "function" },
307            self.sig.span(),
308        );
309
310        let mut stream = TokenStream::new();
311
312        for attr in self.attributes {
313            stream.extend(attr.into_token_stream());
314        }
315
316        if real_fn_mangled {
317            stream.extend(quote!(#[allow(non_snake_case)]));
318            stream.extend(quote!(#[doc(hidden)]));
319        }
320
321        stream.extend(self.vis.to_token_stream());
322
323        let vm_result = VmResult::new();
324
325        if attrs.vm_result {
326            let vm_result = &vm_result.vm_result;
327
328            sig.output = match sig.output {
329                syn::ReturnType::Default => syn::ReturnType::Type(
330                    <Token![->]>::default(),
331                    Box::new(syn::Type::Verbatim(quote!(#vm_result<()>))),
332                ),
333                syn::ReturnType::Type(arrow, ty) => syn::ReturnType::Type(
334                    arrow,
335                    Box::new(syn::Type::Verbatim(quote!(#vm_result<#ty>))),
336                ),
337            };
338        }
339
340        let generics = sig.generics.clone();
341        stream.extend(sig.into_token_stream());
342
343        if attrs.vm_result {
344            let mut block: syn::Block = syn::parse2(self.remainder)?;
345            vm_result.block(&mut block, true)?;
346            block.to_tokens(&mut stream);
347        } else {
348            stream.extend(self.remainder);
349        }
350
351        let arguments = &self.arguments;
352        let docs = &self.docs;
353
354        let build_with = if instance {
355            None
356        } else if let Some(self_type) = &attrs.self_type {
357            Some(quote!(.build_associated::<#self_type>()?))
358        } else {
359            Some(quote!(.build()?))
360        };
361
362        let attributes = (!real_fn_mangled).then(|| quote!(#[allow(non_snake_case)]));
363
364        let deprecated = match &attrs.deprecated {
365            Some(message) => quote!(Some(#message)),
366            None => quote!(None),
367        };
368
369        let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
370        let type_generics = type_generics.as_turbofish();
371
372        stream.extend(quote! {
373            /// Get function metadata.
374            #[automatically_derived]
375            #attributes
376            #[doc(hidden)]
377            pub(crate) fn #meta_fn #impl_generics() -> rune::alloc::Result<rune::__private::FunctionMetaData>
378            #where_clause
379            {
380                Ok(rune::__private::FunctionMetaData {
381                    kind: rune::__private::FunctionMetaKind::#meta_kind(#name, #real_fn_path #type_generics)?#build_with,
382                    statics: rune::__private::FunctionMetaStatics {
383                        name: #name_string,
384                        deprecated: #deprecated,
385                        docs: &#docs[..],
386                        arguments: &#arguments[..],
387                    },
388                })
389            }
390        });
391
392        Ok(stream)
393    }
394}
395
396/// The identifier of an argument.
397fn argument_ident(pat: &syn::Pat) -> syn::LitStr {
398    match pat {
399        syn::Pat::Type(pat) => argument_ident(&pat.pat),
400        syn::Pat::Path(pat) => argument_path_ident(&pat.path),
401        syn::Pat::Ident(pat) => syn::LitStr::new(&pat.ident.to_string(), pat.span()),
402        _ => syn::LitStr::new(&pat.to_token_stream().to_string(), pat.span()),
403    }
404}
405
406/// Argument path identifier.
407fn argument_path_ident(path: &syn::Path) -> syn::LitStr {
408    match path.get_ident() {
409        Some(ident) => syn::LitStr::new(&ident.to_string(), path.span()),
410        None => syn::LitStr::new(&path.to_token_stream().to_string(), path.span()),
411    }
412}
413
414fn expr_lit(ident: &syn::Ident) -> syn::Expr {
415    syn::Expr::Lit(syn::ExprLit {
416        attrs: Vec::new(),
417        lit: syn::Lit::Str(syn::LitStr::new(&ident.to_string(), ident.span())),
418    })
419}
420
/// Paths used when rewriting a function body so that it returns a `VmResult`.
struct VmResult {
    /// Path to the `VmResult` type that results are wrapped in.
    vm_result: syn::Path,
    /// Path to the `From` trait, used to convert errors propagated with `?`.
    from: syn::Path,
    /// Path to the standard `Result` type.
    result: syn::Path,
}
426
427impl VmResult {
428    fn new() -> Self {
429        Self {
430            vm_result: syn::parse_quote!(rune::runtime::VmResult),
431            from: syn::parse_quote!(core::convert::From),
432            result: syn::parse_quote!(core::result::Result),
433        }
434    }
435
436    /// Modify the block so that it is fallible.
437    fn block(&self, ast: &mut syn::Block, top_level: bool) -> syn::Result<()> {
438        let vm_result = &self.vm_result;
439
440        for stmt in &mut ast.stmts {
441            match stmt {
442                syn::Stmt::Expr(expr, _) => {
443                    self.expr(expr)?;
444                }
445                syn::Stmt::Local(local) => {
446                    let Some(init) = &mut local.init else {
447                        continue;
448                    };
449
450                    self.expr(&mut init.expr)?;
451
452                    let Some((_, expr)) = &mut init.diverge else {
453                        continue;
454                    };
455
456                    self.expr(expr)?;
457                }
458                _ => {}
459            };
460        }
461
462        if top_level {
463            let mut found = false;
464
465            for stmt in ast.stmts.iter_mut().rev() {
466                if let syn::Stmt::Expr(expr, semi) = stmt {
467                    if semi.is_none() {
468                        found = true;
469
470                        *expr = syn::Expr::Verbatim(quote_spanned! {
471                            expr.span() => #vm_result::Ok(#expr)
472                        });
473                    }
474
475                    break;
476                }
477            }
478
479            if !found {
480                ast.stmts.push(syn::Stmt::Expr(
481                    syn::Expr::Verbatim(quote!(#vm_result::Ok(()))),
482                    None,
483                ));
484            }
485        }
486
487        Ok(())
488    }
489
490    fn expr(&self, ast: &mut syn::Expr) -> syn::Result<()> {
491        let Self {
492            vm_result,
493            from,
494            result,
495        } = self;
496
497        let outcome = 'outcome: {
498            match ast {
499                syn::Expr::Array(expr) => {
500                    for expr in &mut expr.elems {
501                        self.expr(expr)?;
502                    }
503                }
504                syn::Expr::Assign(expt) => {
505                    self.expr(&mut expt.right)?;
506                }
507                syn::Expr::Async(..) => {}
508                syn::Expr::Await(expr) => {
509                    self.expr(&mut expr.base)?;
510                }
511                syn::Expr::Binary(expr) => {
512                    self.expr(&mut expr.left)?;
513                    self.expr(&mut expr.right)?;
514                }
515                syn::Expr::Block(block) => {
516                    self.block(&mut block.block, false)?;
517                }
518                syn::Expr::Break(expr) => {
519                    if let Some(expr) = &mut expr.expr {
520                        self.expr(expr)?;
521                    }
522                }
523                syn::Expr::Call(expr) => {
524                    self.expr(&mut expr.func)?;
525
526                    for expr in &mut expr.args {
527                        self.expr(expr)?;
528                    }
529                }
530                syn::Expr::Field(expr) => {
531                    self.expr(&mut expr.base)?;
532                }
533                syn::Expr::ForLoop(expr) => {
534                    self.expr(&mut expr.expr)?;
535                    self.block(&mut expr.body, false)?;
536                }
537                syn::Expr::Group(expr) => {
538                    self.expr(&mut expr.expr)?;
539                }
540                syn::Expr::If(expr) => {
541                    self.expr(&mut expr.cond)?;
542                    self.block(&mut expr.then_branch, false)?;
543
544                    if let Some((_, expr)) = &mut expr.else_branch {
545                        self.expr(expr)?;
546                    }
547                }
548                syn::Expr::Index(expr) => {
549                    self.expr(&mut expr.expr)?;
550                    self.expr(&mut expr.index)?;
551                }
552                syn::Expr::Let(expr) => {
553                    self.expr(&mut expr.expr)?;
554                }
555                syn::Expr::Loop(expr) => {
556                    self.block(&mut expr.body, false)?;
557                }
558                syn::Expr::Match(expr) => {
559                    self.expr(&mut expr.expr)?;
560
561                    for arm in &mut expr.arms {
562                        if let Some((_, expr)) = &mut arm.guard {
563                            self.expr(expr)?;
564                        }
565
566                        self.expr(&mut arm.body)?;
567                    }
568                }
569                syn::Expr::MethodCall(expr) => {
570                    self.expr(&mut expr.receiver)?;
571
572                    for expr in &mut expr.args {
573                        self.expr(expr)?;
574                    }
575                }
576                syn::Expr::Paren(expr) => {
577                    self.expr(&mut expr.expr)?;
578                }
579                syn::Expr::Range(expr) => {
580                    if let Some(expr) = &mut expr.start {
581                        self.expr(expr)?;
582                    }
583
584                    if let Some(expr) = &mut expr.end {
585                        self.expr(expr)?;
586                    }
587                }
588                syn::Expr::Reference(expr) => {
589                    self.expr(&mut expr.expr)?;
590                }
591                syn::Expr::Repeat(expr) => {
592                    self.expr(&mut expr.expr)?;
593                    self.expr(&mut expr.len)?;
594                }
595                syn::Expr::Return(expr) => {
596                    if let Some(expr) = &mut expr.expr {
597                        self.expr(expr)?;
598                    }
599
600                    expr.expr = Some(Box::new(match expr.expr.take() {
601                        Some(expr) => syn::Expr::Verbatim(quote_spanned! {
602                            expr.span() =>
603                            #vm_result::Ok(#expr)
604                        }),
605                        None => syn::Expr::Verbatim(quote!(#vm_result::Ok(()))),
606                    }));
607                }
608                syn::Expr::Struct(expr) => {
609                    for field in &mut expr.fields {
610                        self.expr(&mut field.expr)?;
611                    }
612                }
613                syn::Expr::Try(expr) => {
614                    let span = expr.span();
615
616                    self.expr(&mut expr.expr)?;
617
618                    break 'outcome if let Some((expr, ident)) = as_vm_expr(&mut expr.expr) {
619                        let vm_try = syn::Ident::new("vm_try", ident.span());
620                        quote_spanned!(span => rune::#vm_try!(#expr))
621                    } else {
622                        let value = &mut expr.expr;
623                        let from = quote_spanned!(expr.question_token.span() => #from::from);
624
625                        quote_spanned! {
626                            span =>
627                            match #value {
628                                #result::Ok(value) => value,
629                                #result::Err(error) => return #vm_result::Ok(#result::Err(#[allow(clippy::useless_conversion)] #from(error))),
630                            }
631                        }
632                    };
633                }
634                syn::Expr::Tuple(expr) => {
635                    for expr in &mut expr.elems {
636                        self.expr(expr)?;
637                    }
638                }
639                syn::Expr::Unary(expr) => {
640                    self.expr(&mut expr.expr)?;
641                }
642                syn::Expr::Unsafe(expr) => {
643                    self.block(&mut expr.block, false)?;
644                }
645                syn::Expr::While(expr) => {
646                    self.expr(&mut expr.cond)?;
647                    self.block(&mut expr.body, false)?;
648                }
649                syn::Expr::Yield(expr) => {
650                    if let Some(expr) = &mut expr.expr {
651                        self.expr(expr)?;
652                    }
653                }
654                _ => {}
655            }
656
657            return Ok(());
658        };
659
660        *ast = syn::Expr::Verbatim(outcome);
661        Ok(())
662    }
663}
664
665/// If this is a field expression like `<expr>.vm`.
666fn as_vm_expr(expr: &mut syn::Expr) -> Option<(&mut syn::Expr, &syn::Ident)> {
667    let syn::Expr::Field(expr) = expr else {
668        return None;
669    };
670
671    let syn::Member::Named(ident) = &expr.member else {
672        return None;
673    };
674
675    (ident == "vm").then_some((&mut expr.base, ident))
676}