musli_macros/
types.rs

1use std::collections::BTreeMap;
2
3use proc_macro2::{Span, TokenStream};
4use quote::ToTokens;
5use syn::parse::Parse;
6use syn::punctuated::Punctuated;
7use syn::Token;
8
/// Name of the type parameter introduced on generated generic associated
/// types, and passed on to the `Never` placeholder they are defined as.
///
/// The double underscore prefix keeps it apart from ordinary user-chosen
/// parameter names.
const U_PARAM: &str = "__U";
10
11#[derive(Debug, Clone, Copy)]
12pub(super) enum Ty {
13    /// `str`.
14    Str,
15    /// `[u8]`.
16    Bytes,
17}
18
19#[derive(Debug, Clone, Copy)]
20pub(super) enum Extra {
21    /// `type Type = Never;`
22    None,
23    /// `type Error = <Self::Cx as Context>::Error;`
24    Error,
25    /// `type Mode = <Self::Cx as Context>::Mode;`
26    Mode,
27    Context,
28    Visitor(Ty),
29}
30
/// Functions which can be generated on behalf of the user when they are not
/// defined in the annotated `impl` block.
///
/// Derives the same basic traits as the other enums in this module so that
/// values can be copied out of lookup tables and printed while debugging.
#[derive(Debug, Clone, Copy)]
pub(crate) enum Fn {
    /// Generates `decode`, forwarding to `self.cx.decode(self)`.
    Decode,
    /// Generates `decode_unsized`, forwarding to
    /// `self.cx.decode_unsized(self, f)`.
    DecodeUnsized,
    /// Generates `decode_unsized_bytes`, forwarding to
    /// `self.cx.decode_unsized_bytes(self, f)`.
    DecodeUnsizedBytes,
}
36
37pub(super) const ENCODER_TYPES: &[(&str, Extra)] = &[
38    ("Error", Extra::Error),
39    ("Mode", Extra::Mode),
40    ("WithContext", Extra::Context),
41    ("EncodeSome", Extra::None),
42    ("EncodePack", Extra::None),
43    ("EncodeSequence", Extra::None),
44    ("EncodeMap", Extra::None),
45    ("EncodeMapEntries", Extra::None),
46    ("EncodeVariant", Extra::None),
47    ("EncodeSequenceVariant", Extra::None),
48    ("EncodeMapVariant", Extra::None),
49];
50
51pub(super) const DECODER_TYPES: &[(&str, Extra)] = &[
52    ("Error", Extra::Error),
53    ("Mode", Extra::Mode),
54    ("WithContext", Extra::Context),
55    ("DecodeBuffer", Extra::None),
56    ("DecodeSome", Extra::None),
57    ("DecodePack", Extra::None),
58    ("DecodeSequence", Extra::None),
59    ("DecodeMap", Extra::None),
60    ("DecodeMapEntries", Extra::None),
61    ("DecodeVariant", Extra::None),
62];
63
64pub(super) const DECODER_FNS: &[(&str, Fn)] = &[
65    ("decode", Fn::Decode),
66    ("decode_unsized", Fn::DecodeUnsized),
67    ("decode_unsized_bytes", Fn::DecodeUnsizedBytes),
68];
69
70pub(super) const VISITOR_TYPES: &[(&str, Extra)] = &[
71    ("String", Extra::Visitor(Ty::Str)),
72    ("Bytes", Extra::Visitor(Ty::Bytes)),
73];
74
75#[derive(Clone, Copy)]
76pub(super) enum Kind {
77    SelfCx,
78    GenericCx,
79}
80
81pub(super) struct Attr {
82    crate_path: Option<syn::Path>,
83}
84
85impl Parse for Attr {
86    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
87        let mut crate_path = None;
88
89        while !input.is_empty() {
90            let path = input.parse::<syn::Path>()?;
91
92            if path.is_ident("crate") {
93                if input.parse::<Option<Token![=]>>()?.is_some() {
94                    crate_path = Some(input.parse()?);
95                } else {
96                    crate_path = Some(path);
97                }
98            } else {
99                return Err(syn::Error::new_spanned(
100                    path,
101                    format_args!("Unexpected attribute"),
102                ));
103            }
104
105            if !input.is_empty() {
106                input.parse::<Token![,]>()?;
107            }
108        }
109
110        Ok(Self { crate_path })
111    }
112}
113
114pub(super) struct Types {
115    item_impl: syn::ItemImpl,
116}
117
118impl Parse for Types {
119    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
120        Ok(Self {
121            item_impl: input.parse()?,
122        })
123    }
124}
125
126impl Types {
127    /// Expand encoder types.
128    pub(crate) fn expand(
129        mut self,
130        default_crate: &str,
131        attr: &Attr,
132        what: &str,
133        types: &[(&str, Extra)],
134        fns: &[(&str, Fn)],
135        argument: Option<&str>,
136        hint: &str,
137        kind: Kind,
138    ) -> syn::Result<TokenStream> {
139        let default_crate_path;
140
141        let crate_path = match &attr.crate_path {
142            Some(path) => path,
143            None => {
144                default_crate_path = ident_path(syn::Ident::new(default_crate, Span::call_site()));
145                &default_crate_path
146            }
147        };
148
149        let mut missing = types
150            .iter()
151            .map(|(ident, extra)| (syn::Ident::new(ident, Span::call_site()), *extra))
152            .collect::<BTreeMap<_, _>>();
153
154        let mut missing_fns = fns
155            .iter()
156            .map(|(name, f)| (syn::Ident::new(name, Span::call_site()), f))
157            .collect::<BTreeMap<_, _>>();
158
159        // List of associated types which are specified, but under a `cfg`
160        // attribute so its conditions need to be inverted.
161        let mut not_attribute_ty = Vec::new();
162
163        for item in &self.item_impl.items {
164            match item {
165                syn::ImplItem::Type(impl_type) => {
166                    let Some(extra) = missing.remove(&impl_type.ident) else {
167                        continue;
168                    };
169
170                    let mut has_cfg = false;
171
172                    for attr in &impl_type.attrs {
173                        if !attr.path().is_ident("cfg") {
174                            continue;
175                        }
176
177                        if has_cfg {
178                            return Err(syn::Error::new_spanned(
179                                attr,
180                                format_args!(
181                                    "#[rune::{what}]: only one cfg attribute is supported"
182                                ),
183                            ));
184                        }
185
186                        not_attribute_ty.push((impl_type.clone(), extra));
187                        has_cfg = true;
188                    }
189                }
190                syn::ImplItem::Fn(f) => {
191                    missing_fns.remove(&f.sig.ident);
192                }
193                _ => continue,
194            }
195        }
196
197        for (mut impl_type, extra) in not_attribute_ty {
198            for attr in &mut impl_type.attrs {
199                if !attr.path().is_ident("cfg") {
200                    continue;
201                }
202
203                if let syn::Meta::List(m) = &mut attr.meta {
204                    let tokens = syn::Meta::List(syn::MetaList {
205                        path: ident_path(syn::Ident::new("not", Span::call_site())),
206                        delimiter: syn::MacroDelimiter::Paren(syn::token::Paren::default()),
207                        tokens: m.tokens.clone(),
208                    })
209                    .into_token_stream();
210
211                    m.tokens = tokens;
212                }
213            }
214
215            impl_type.ty = syn::Type::Path(syn::TypePath {
216                qself: None,
217                path: self.never_type(crate_path, argument, extra, kind)?,
218            });
219
220            self.item_impl.items.push(syn::ImplItem::Type(impl_type));
221        }
222
223        for (_, f) in missing_fns {
224            match f {
225                Fn::Decode => {
226                    self.item_impl
227                        .items
228                        .push(syn::ImplItem::Verbatim(quote::quote! {
229                            #[inline(always)]
230                            fn decode<T>(self) -> Result<T, Self::Error>
231                            where
232                                T: #crate_path::de::Decode<'de, Self::Mode>
233                            {
234                                self.cx.decode(self)
235                            }
236                        }));
237                }
238                Fn::DecodeUnsized => {
239                    self.item_impl.items.push(syn::ImplItem::Verbatim(quote::quote! {
240                        #[inline(always)]
241                        fn decode_unsized<T, F, O>(self, f: F) -> Result<O, Self::Error>
242                        where
243                            T: ?Sized + #crate_path::de::DecodeUnsized<'de, Self::Mode>,
244                            F: FnOnce(&T) -> Result<O, <Self::Cx as #crate_path::Context>::Error>
245                        {
246                            self.cx.decode_unsized(self, f)
247                        }
248                    }));
249                }
250                Fn::DecodeUnsizedBytes => {
251                    self.item_impl.items.push(syn::ImplItem::Verbatim(quote::quote! {
252                        #[inline(always)]
253                        fn decode_unsized_bytes<T, F, O>(self, f: F) -> Result<O, Self::Error>
254                        where
255                            T: ?Sized + #crate_path::de::DecodeUnsizedBytes<'de, Self::Mode>,
256                            F: FnOnce(&T) -> Result<O, <Self::Cx as #crate_path::Context>::Error>
257                        {
258                            self.cx.decode_unsized_bytes(self, f)
259                        }
260                    }));
261                }
262            }
263        }
264
265        for (ident, extra) in missing {
266            let ty;
267            let generics;
268
269            match extra {
270                Extra::Mode => {
271                    ty = syn::parse_quote!(<Self::Cx as #crate_path::Context>::Mode);
272                    generics = syn::Generics::default();
273                }
274                Extra::Error => {
275                    ty = syn::parse_quote!(<Self::Cx as #crate_path::Context>::Error);
276                    generics = syn::Generics::default();
277                }
278                Extra::Context => {
279                    let u_param = syn::Ident::new(U_PARAM, Span::call_site());
280
281                    let mut params = Punctuated::default();
282
283                    let this_lifetime = syn::Lifetime::new("'this", Span::call_site());
284
285                    params.push(syn::GenericParam::Lifetime(syn::LifetimeParam {
286                        attrs: Vec::new(),
287                        lifetime: this_lifetime.clone(),
288                        colon_token: None,
289                        bounds: Punctuated::default(),
290                    }));
291
292                    params.push(syn::GenericParam::Type(syn::TypeParam {
293                        attrs: Vec::new(),
294                        ident: u_param.clone(),
295                        colon_token: None,
296                        bounds: Punctuated::default(),
297                        eq_token: None,
298                        default: None,
299                    }));
300
301                    ty = syn::Type::Path(syn::TypePath {
302                        qself: None,
303                        path: self.never_type(crate_path, argument, extra, kind)?,
304                    });
305
306                    let mut where_clause = syn::WhereClause {
307                        where_token: <Token![where]>::default(),
308                        predicates: Punctuated::default(),
309                    };
310
311                    where_clause
312                        .predicates
313                        .push(syn::parse_quote!(#u_param: #this_lifetime + #crate_path::Context));
314
315                    generics = syn::Generics {
316                        lt_token: Some(<Token![<]>::default()),
317                        params,
318                        gt_token: Some(<Token![>]>::default()),
319                        where_clause: Some(where_clause),
320                    };
321                }
322                _ => {
323                    ty = syn::Type::Path(syn::TypePath {
324                        qself: None,
325                        path: self.never_type(crate_path, argument, extra, kind)?,
326                    });
327
328                    generics = syn::Generics::default();
329                }
330            };
331
332            let ty = syn::ImplItemType {
333                attrs: Vec::new(),
334                vis: syn::Visibility::Inherited,
335                defaultness: None,
336                type_token: <Token![type]>::default(),
337                ident,
338                generics,
339                eq_token: <Token![=]>::default(),
340                ty,
341                semi_token: <Token![;]>::default(),
342            };
343
344            self.item_impl.items.push(syn::ImplItem::Type(ty));
345        }
346
347        self.item_impl
348            .items
349            .push(syn::ImplItem::Type(syn::ImplItemType {
350                attrs: Vec::new(),
351                vis: syn::Visibility::Inherited,
352                defaultness: None,
353                type_token: <Token![type]>::default(),
354                ident: syn::Ident::new(hint, Span::call_site()),
355                generics: syn::Generics::default(),
356                eq_token: <Token![=]>::default(),
357                ty: syn::Type::Tuple(syn::TypeTuple {
358                    paren_token: <syn::token::Paren>::default(),
359                    elems: Punctuated::default(),
360                }),
361                semi_token: <Token![;]>::default(),
362            }));
363
364        Ok(self.item_impl.into_token_stream())
365    }
366
367    fn never_type(
368        &self,
369        crate_path: &syn::Path,
370        argument: Option<&str>,
371        extra: Extra,
372        kind: Kind,
373    ) -> syn::Result<syn::Path> {
374        let mut never = crate_path.clone();
375
376        never.segments.push(syn::PathSegment::from(syn::Ident::new(
377            "__priv",
378            Span::call_site(),
379        )));
380
381        never.segments.push({
382            let mut s = syn::PathSegment::from(syn::Ident::new("Never", Span::call_site()));
383
384            let mut args = Punctuated::default();
385
386            if let Some(arg) = argument {
387                args.push(syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
388                    qself: None,
389                    path: self_type(arg),
390                })));
391            } else {
392                args.push(syn::parse_quote!(()));
393            }
394
395            match extra {
396                Extra::Visitor(ty) => match ty {
397                    Ty::Str => {
398                        args.push(syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
399                            qself: None,
400                            path: ident_path(syn::Ident::new("str", Span::call_site())),
401                        })));
402                    }
403                    Ty::Bytes => {
404                        let mut path = syn::Path {
405                            leading_colon: None,
406                            segments: Punctuated::default(),
407                        };
408
409                        path.segments.push(syn::PathSegment::from(syn::Ident::new(
410                            "u8",
411                            Span::call_site(),
412                        )));
413
414                        args.push(syn::GenericArgument::Type(syn::Type::Slice(
415                            syn::TypeSlice {
416                                bracket_token: syn::token::Bracket::default(),
417                                elem: Box::new(syn::Type::Path(syn::TypePath {
418                                    qself: None,
419                                    path,
420                                })),
421                            },
422                        )));
423                    }
424                },
425                Extra::Context => {
426                    let u_param = syn::Ident::new(U_PARAM, Span::call_site());
427                    args.push(syn::parse_quote!(#u_param));
428                }
429                Extra::None => match kind {
430                    Kind::SelfCx => {
431                        args.push(syn::parse_quote!(Self::Cx));
432                    }
433                    Kind::GenericCx => {}
434                },
435                _ => {}
436            }
437
438            if !args.is_empty() {
439                s.arguments =
440                    syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
441                        colon2_token: None,
442                        lt_token: <Token![<]>::default(),
443                        args,
444                        gt_token: <Token![>]>::default(),
445                    });
446            }
447
448            s
449        });
450
451        Ok(never)
452    }
453}
454
455fn ident_path(ident: syn::Ident) -> syn::Path {
456    let mut not_path = syn::Path {
457        leading_colon: None,
458        segments: Punctuated::default(),
459    };
460
461    not_path.segments.push(syn::PathSegment::from(ident));
462
463    not_path
464}
465
466fn self_type(what: &str) -> syn::Path {
467    let mut self_error = syn::Path {
468        leading_colon: None,
469        segments: Punctuated::default(),
470    };
471
472    self_error
473        .segments
474        .push(syn::PathSegment::from(syn::Ident::new(
475            "Self",
476            Span::call_site(),
477        )));
478
479    self_error
480        .segments
481        .push(syn::PathSegment::from(syn::Ident::new(
482            what,
483            Span::call_site(),
484        )));
485
486    self_error
487}