rune_macros/
inst_display.rs

1use core::mem::take;
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::punctuated::Punctuated;
6use syn::Token;
7
8/// The `InstDisplay` derive.
9pub struct Derive {
10    input: syn::DeriveInput,
11}
12
13impl syn::parse::Parse for Derive {
14    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
15        Ok(Self {
16            input: input.parse()?,
17        })
18    }
19}
20
21impl Derive {
22    pub(super) fn expand(self) -> Result<TokenStream, Vec<syn::Error>> {
23        let mut errors = Vec::new();
24
25        let syn::Data::Enum(en) = &self.input.data else {
26            errors.push(syn::Error::new_spanned(
27                &self.input.ident,
28                "InstDisplay is only supported for enums",
29            ));
30            return Err(errors);
31        };
32
33        let fmt = syn::Ident::new("fmt", Span::call_site());
34        let ident = self.input.ident;
35
36        let mut variants = Vec::new();
37
38        for variant in &en.variants {
39            let variant_ident = &variant.ident;
40            let mut patterns = Vec::new();
41            let mut fmt_call = Vec::new();
42
43            for (index, f) in variant.fields.iter().enumerate() {
44                let mut display_with = None::<syn::Path>;
45
46                for a in &f.attrs {
47                    if a.path().is_ident("inst_display") {
48                        let result = a.parse_nested_meta(|meta| {
49                            if meta.path.is_ident("display_with") {
50                                meta.input.parse::<Token![=]>()?;
51                                display_with = Some(meta.input.parse()?);
52                            } else {
53                                return Err(syn::Error::new(
54                                    meta.input.span(),
55                                    "Unsupported attribute",
56                                ));
57                            }
58
59                            Ok(())
60                        });
61
62                        if let Err(error) = result {
63                            errors.push(error);
64                            continue;
65                        }
66                    }
67                }
68
69                let member = match &f.ident {
70                    Some(ident) => syn::Member::Named(ident.clone()),
71                    None => syn::Member::Unnamed(syn::Index::from(index)),
72                };
73
74                let (assign, var) = match &f.ident {
75                    Some(ident) => (false, ident.clone()),
76                    None => (true, quote::format_ident!("_{index}")),
77                };
78
79                let mut path = syn::Path {
80                    leading_colon: None,
81                    segments: Punctuated::default(),
82                };
83
84                path.segments.push(syn::PathSegment::from(var.clone()));
85
86                patterns.push(syn::FieldValue {
87                    attrs: Vec::new(),
88                    member,
89                    colon_token: assign.then(<Token![:]>::default),
90                    expr: syn::Expr::Path(syn::ExprPath {
91                        attrs: Vec::new(),
92                        qself: None,
93                        path,
94                    }),
95                });
96
97                let var_name = syn::LitStr::new(&var.to_string(), var.span());
98
99                let var = syn::Expr::Path(syn::ExprPath {
100                    attrs: Vec::new(),
101                    qself: None,
102                    path: syn::Path::from(var),
103                });
104
105                let arg = if let Some(display_with) = display_with {
106                    let mut call = syn::ExprCall {
107                        attrs: Vec::new(),
108                        func: Box::new(syn::Expr::Path(syn::ExprPath {
109                            attrs: Vec::new(),
110                            qself: None,
111                            path: display_with.clone(),
112                        })),
113                        paren_token: syn::token::Paren::default(),
114                        args: Punctuated::new(),
115                    };
116
117                    call.args.push(var);
118                    let call = syn::Expr::Call(call);
119
120                    syn::Expr::Reference(syn::ExprReference {
121                        attrs: Vec::new(),
122                        and_token: <Token![&]>::default(),
123                        mutability: None,
124                        expr: Box::new(call),
125                    })
126                } else {
127                    var
128                };
129
130                if fmt_call.is_empty() {
131                    fmt_call.push(quote! {
132                        #fmt::Formatter::write_str(f, " ")?;
133                    });
134                } else {
135                    fmt_call.push(quote! {
136                        #fmt::Formatter::write_str(f, ", ")?;
137                    });
138                }
139
140                fmt_call.push(quote! {
141                    #fmt::Formatter::write_str(f, #var_name)?;
142                    #fmt::Formatter::write_str(f, "=")?;
143                    #fmt::Display::fmt(#arg, f)?
144                });
145            }
146
147            let variant_name = variant_name(&variant.ident.to_string());
148
149            variants.push(quote! {
150                #ident::#variant_ident { #(#patterns,)* } => {
151                    #fmt::Formatter::write_str(f, #variant_name)?;
152                    #(#fmt_call;)*
153                    Ok(())
154                }
155            });
156        }
157
158        if !errors.is_empty() {
159            return Err(errors);
160        }
161
162        let (impl_g, ty_g, where_g) = self.input.generics.split_for_impl();
163
164        Ok(quote! {
165            impl #impl_g #fmt::Display for #ident #ty_g #where_g {
166                fn fmt(&self, f: &mut #fmt::Formatter<'_>) -> #fmt::Result {
167                    match self {
168                        #(#variants,)*
169                    }
170                }
171            }
172        })
173    }
174}
175
/// Converts a `CamelCase` variant identifier into its `kebab-case` display
/// name, e.g. `LoadConst` becomes `load-const`.
///
/// The leading character is always lowercased. Every later uppercase
/// character is lowercased and preceded by `-`. All other characters are
/// copied unchanged.
fn variant_name(name: &str) -> String {
    let mut kebab = String::with_capacity(name.len());
    let mut leading = true;

    for ch in name.chars() {
        let is_first = take(&mut leading);

        match (is_first, ch.is_uppercase()) {
            (true, _) => kebab.extend(ch.to_lowercase()),
            (false, true) => {
                kebab.push('-');
                kebab.extend(ch.to_lowercase());
            }
            (false, false) => kebab.push(ch),
        }
    }

    kebab
}