musli_macros/internals/
build.rs

1use std::rc::Rc;
2
3use proc_macro2::Span;
4use syn::punctuated::Punctuated;
5use syn::Token;
6
7use crate::de::{build_call, build_reference};
8use crate::expander::{
9    self, Data, EnumData, Expander, FieldData, NameMethod, StructData, StructKind, UnsizedMethod,
10    VariantData,
11};
12
13use super::attr::{EnumTagging, FieldEncoding, ModeKind, Packing};
14use super::name::NameAll;
15use super::tokens::Tokens;
16use super::ATTR;
17use super::{Ctxt, Expansion, Mode, Only, Result};
18
/// Everything needed to expand a derive for one type in one mode.
///
/// Instances are produced by [`setup`].
pub(crate) struct Build<'a> {
    /// The derive input being expanded.
    pub(crate) input: &'a syn::DeriveInput,
    /// Context through which diagnostics are reported.
    pub(crate) cx: &'a Ctxt,
    /// Token set taken from the expander.
    pub(crate) tokens: &'a Tokens,
    /// Where predicates (with their spans) returned by the type attribute
    /// for the current mode.
    pub(crate) bounds: &'a [(Span, syn::WherePredicate)],
    /// Decode-specific where predicates (with their spans) returned by the
    /// type attribute for the current mode.
    pub(crate) decode_bounds: &'a [(Span, syn::WherePredicate)],
    /// The expansion this build was set up for.
    pub(crate) expansion: Expansion<'a>,
    /// The struct or enum model that was built.
    pub(crate) data: BuildData<'a>,
    /// Path produced by the mode's `decode_t_decode` for the default field
    /// encoding.
    pub(crate) decode_t_decode: syn::Path,
    /// Path produced by the mode's `encode_t_encode` for the default field
    /// encoding.
    pub(crate) encode_t_encode: syn::Path,
    /// Span of the enum tagging attributes, if any were specified. When this
    /// is set on a struct, validation reports an error.
    pub(crate) enum_tagging_span: Option<Span>,
}
31
32impl Build<'_> {
33    /// Emit diagnostics for when we try to implement `Decode` for an enum which
34    /// is marked as `#[musli(transparent)]`.
35    pub(crate) fn encode_transparent_enum_diagnostics(&self, span: Span) {
36        self.cx.error_span(
37            span,
38            format_args!("#[{ATTR}(transparent)] cannot be used to encode enums",),
39        );
40    }
41
42    /// Emit diagnostics indicating that we tried to implement decode for a
43    /// packed enum.
44    pub(crate) fn decode_packed_enum_diagnostics(&self, span: Span) {
45        self.cx.error_span(
46            span,
47            format_args!("#[{ATTR}(packed)] cannot be used to decode enums"),
48        );
49    }
50
51    /// Emit diagnostics indicating that we tried to use a `#[musli(default)]`
52    /// annotation on a packed container.
53    pub(crate) fn packed_default_diagnostics(&self, span: Span) {
54        self.cx.error_span(
55            span,
56            format_args!("#[{ATTR}(default)] fields cannot be used in an packed container",),
57        );
58    }
59
60    /// Validate encode attributes.
61    pub(crate) fn validate_encode(&self) -> Result<()> {
62        self.validate()
63    }
64
65    /// Validate set of legal attributes.
66    pub(crate) fn validate_decode(&self) -> Result<()> {
67        self.validate()
68    }
69
70    fn validate(&self) -> Result<()> {
71        match &self.data {
72            BuildData::Struct(..) => {
73                if let Some(span) = self.enum_tagging_span {
74                    self.cx.error_span(
75                        span,
76                        format_args!(
77                            "#[{ATTR}(tag)] and #[{ATTR}(content)] are only supported on enums"
78                        ),
79                    );
80
81                    return Err(());
82                }
83            }
84            BuildData::Enum(..) => (),
85        }
86
87        Ok(())
88    }
89}
90
/// Build model for enums and structs.
pub(crate) enum BuildData<'a> {
    /// The derive input is a struct, described by a single body.
    Struct(Body<'a>),
    /// The derive input is an enum, described by its variants.
    Enum(Enum<'a>),
}
96
/// The body of a struct or of a single enum variant.
pub(crate) struct Body<'a> {
    /// Span of the struct or variant this body was built from.
    pub(crate) span: Span,
    /// Name literal of the struct or variant.
    pub(crate) name: &'a syn::LitStr,
    /// The fields which are not marked as skipped, in declaration order.
    pub(crate) unskipped_fields: Vec<Rc<Field<'a>>>,
    /// All fields, including skipped ones, in declaration order.
    pub(crate) all_fields: Vec<Rc<Field<'a>>>,
    /// The type used for field names.
    pub(crate) name_type: syn::Type,
    /// How field names are handled; an unsized method causes
    /// `name_local_type` to produce a reference to `name_type`.
    pub(crate) name_method: NameMethod,
    /// Optional path to a function which is applied to a reference to the
    /// name, see `name_format`.
    pub(crate) name_format_with: Option<&'a (Span, syn::Path)>,
    /// The packing in effect for this body.
    pub(crate) packing: Packing,
    /// The kind of the underlying struct or variant.
    pub(crate) kind: StructKind,
    /// Path used to refer to the body: `Self` for structs and
    /// `Self::<variant>` for enum variants.
    pub(crate) path: syn::Path,
}
109
110impl Body<'_> {
111    pub(crate) fn validate(&self, cx: &Ctxt) {
112        if self.packing == Packing::Transparent && !matches!(&self.unskipped_fields[..], [_]) {
113            cx.transparent_diagnostics(self.span, &self.unskipped_fields);
114        }
115    }
116
117    pub(crate) fn name_format(&self, value: &syn::Expr) -> syn::Expr {
118        match self.name_format_with {
119            Some((_, path)) => build_call(path, [build_reference(value.clone())]),
120            None => build_reference(value.clone()),
121        }
122    }
123
124    pub(crate) fn name_local_type(&self) -> syn::Type {
125        match self.name_method {
126            NameMethod::Unsized(..) => syn::Type::Reference(syn::TypeReference {
127                and_token: <Token![&]>::default(),
128                lifetime: None,
129                mutability: None,
130                elem: Box::new(self.name_type.clone()),
131            }),
132            NameMethod::Value => self.name_type.clone(),
133        }
134    }
135}
136
/// The build model of an enum.
pub(crate) struct Enum<'a> {
    /// Span of the enum this model was built from.
    pub(crate) span: Span,
    /// Name literal of the enum.
    pub(crate) name: &'a syn::LitStr,
    /// The tagging in use. When not specified through attributes this is
    /// `EnumTagging::Empty` if every variant is `StructKind::Empty` or
    /// `StructKind::Indexed(0)`, and `EnumTagging::Default` otherwise.
    pub(crate) enum_tagging: EnumTagging<'a>,
    /// Packing specified on the enum, or the default packing if unspecified.
    pub(crate) enum_packing: Packing,
    /// The variants of the enum, in declaration order.
    pub(crate) variants: Vec<Variant<'a>>,
    /// Identifier of the variant marked with `#[musli(default)]`, if any.
    pub(crate) fallback: Option<&'a syn::Ident>,
    /// The type used for variant names.
    pub(crate) name_type: syn::Type,
    /// How variant names are handled; an unsized method causes `static_type`
    /// to produce a reference to `name_type`.
    pub(crate) name_method: NameMethod,
    /// Optional path to a function used when formatting names, see
    /// `name_format`.
    pub(crate) name_format_with: Option<&'a (Span, syn::Path)>,
    /// The packing attribute on the enum together with its span, if one was
    /// specified.
    pub(crate) packing_span: Option<&'a (Span, Packing)>,
}
149
150impl Enum<'_> {
151    pub(crate) fn name_format(
152        &self,
153        static_var: &syn::Ident,
154        value: &syn::Expr,
155    ) -> (Option<syn::ItemStatic>, syn::Expr) {
156        match self.name_format_with {
157            Some((_, path)) => (None, syn::parse_quote!(&#path(#static_var, #value))),
158            None => {
159                let static_type = self.static_type();
160
161                (
162                    Some(syn::parse_quote!(static #static_var: #static_type = #value;)),
163                    syn::parse_quote!(&#static_var),
164                )
165            }
166        }
167    }
168
169    pub(crate) fn static_type(&self) -> syn::Type {
170        match self.name_method {
171            NameMethod::Unsized(..) => syn::Type::Reference(syn::TypeReference {
172                and_token: <Token![&]>::default(),
173                lifetime: None,
174                mutability: None,
175                elem: Box::new(self.name_type.clone()),
176            }),
177            NameMethod::Value => self.name_type.clone(),
178        }
179    }
180}
181
/// The build model of a single enum variant.
pub(crate) struct Variant<'a> {
    /// Span of the variant.
    pub(crate) span: Span,
    /// Index of the variant as recorded in the variant data.
    pub(crate) index: usize,
    /// Expression for the name of the variant.
    pub(crate) name: syn::Expr,
    /// Pattern specified through attributes for this variant, if any.
    pub(crate) pattern: Option<&'a syn::Pat>,
    /// The body of the variant.
    pub(crate) st: Body<'a>,
    /// Field patterns binding every field of the variant: named fields bind
    /// to their own identifier, unnamed fields bind to `v<index>`.
    pub(crate) patterns: Punctuated<syn::FieldPat, Token![,]>,
}
190
/// The build model of a single field in a struct or variant.
pub(crate) struct Field<'a> {
    /// Span of the field.
    pub(crate) span: Span,
    /// Index of the field as recorded in the field data.
    pub(crate) index: usize,
    /// Expanded path used to encode the field, with its span.
    pub(crate) encode_path: (Span, syn::Path),
    /// Expanded path used to decode the field, with its span.
    pub(crate) decode_path: (Span, syn::Path),
    /// Expression for the name of the field.
    pub(crate) name: syn::Expr,
    /// Pattern specified through attributes for this field, if any.
    pub(crate) pattern: Option<&'a syn::Pat>,
    /// Skip field entirely and always initialize with the specified expression,
    /// or default value through `default_attr`.
    pub(crate) skip: Option<Span>,
    /// Path from the `skip_encoding_if` attribute with its span, if specified.
    pub(crate) skip_encoding_if: Option<&'a (Span, syn::Path)>,
    /// Fill with default value, if missing.
    pub(crate) default_attr: Option<(Span, Option<&'a syn::Path>)>,
    /// Expression used to access the field: `&self.<member>` for struct
    /// fields, or the variable bound by the variant pattern for variant
    /// fields.
    pub(crate) self_access: syn::Expr,
    /// The member (named or indexed) corresponding to the field.
    pub(crate) member: syn::Member,
    /// The packing handed down from the containing struct or variant.
    pub(crate) packing: Packing,
    /// Identifier for a variable associated with the field, created through
    /// the context from the member name or index and the `_f` tag.
    pub(crate) var: syn::Ident,
    /// The type of the field.
    pub(crate) ty: &'a syn::Type,
}
210
211/// Setup a build.
212///
213/// Handles mode decoding, and construction of parameters which might give rise to errors.
214pub(crate) fn setup<'a>(
215    e: &'a Expander,
216    expansion: Expansion<'a>,
217    only: Only,
218) -> Result<Build<'a>> {
219    let mode = expansion.as_mode(&e.tokens, only);
220
221    let data = match &e.data {
222        Data::Struct(data) => BuildData::Struct(setup_struct(e, mode, data)),
223        Data::Enum(data) => BuildData::Enum(setup_enum(e, mode, data)),
224        Data::Union => {
225            e.cx.error_span(e.input.ident.span(), "musli: not supported for unions");
226            return Err(());
227        }
228    };
229
230    if e.cx.has_errors() {
231        return Err(());
232    }
233
234    Ok(Build {
235        input: e.input,
236        cx: &e.cx,
237        tokens: &e.tokens,
238        bounds: e.type_attr.bounds(mode),
239        decode_bounds: e.type_attr.decode_bounds(mode),
240        expansion,
241        data,
242        decode_t_decode: mode.decode_t_decode(FieldEncoding::Default),
243        encode_t_encode: mode.encode_t_encode(FieldEncoding::Default),
244        enum_tagging_span: e.type_attr.enum_tagging_span(mode),
245    })
246}
247
248fn setup_struct<'a>(e: &'a Expander, mode: Mode<'_>, data: &'a StructData<'a>) -> Body<'a> {
249    let mut unskipped_fields = Vec::with_capacity(data.fields.len());
250    let mut all_fields = Vec::with_capacity(data.fields.len());
251
252    let packing = e
253        .type_attr
254        .packing(mode)
255        .map(|&(_, p)| p)
256        .unwrap_or_default();
257
258    let (name_all, name_type, name_method) = match data.kind {
259        StructKind::Indexed(..) if e.type_attr.is_name_type_ambiguous(mode) => {
260            let name_all = NameAll::Index;
261            (name_all, name_all.ty(), NameMethod::Value)
262        }
263        _ => split_name(
264            mode.kind,
265            e.type_attr.name_type(mode),
266            e.type_attr.name_all(mode),
267            e.type_attr.name_method(mode),
268        ),
269    };
270
271    let path = syn::Path::from(syn::Ident::new("Self", e.input.ident.span()));
272
273    for f in &data.fields {
274        let field = Rc::new(setup_field(e, mode, f, name_all, packing, None));
275
276        if field.skip.is_none() {
277            unskipped_fields.push(field.clone());
278        }
279
280        all_fields.push(field);
281    }
282
283    let body = Body {
284        span: data.span,
285        name: &data.name,
286        unskipped_fields,
287        all_fields,
288        name_type,
289        name_method,
290        name_format_with: e.type_attr.name_format_with(mode),
291        packing,
292        kind: data.kind,
293        path,
294    };
295
296    body.validate(&e.cx);
297    body
298}
299
300fn setup_enum<'a>(e: &'a Expander, mode: Mode<'_>, data: &'a EnumData<'a>) -> Enum<'a> {
301    let mut variants = Vec::with_capacity(data.variants.len());
302    let mut fallback = None;
303
304    let packing_span = e.type_attr.packing(mode);
305
306    let enum_tagging = match e.type_attr.enum_tagging(mode) {
307        Some(enum_tagging) => enum_tagging,
308        None => {
309            if data
310                .variants
311                .iter()
312                .all(|v| matches!(v.kind, StructKind::Indexed(0) | StructKind::Empty))
313            {
314                EnumTagging::Empty
315            } else {
316                EnumTagging::Default
317            }
318        }
319    };
320
321    if !matches!(enum_tagging, EnumTagging::Default | EnumTagging::Empty) {
322        match packing_span {
323            Some((_, Packing::Tagged)) => (),
324            Some(&(span, Packing::Packed)) => {
325                e.cx.error_span(span, format_args!("#[{ATTR}(packed)] cannot be combined with #[{ATTR}(tag)] or #[{ATTR}(content)]"));
326            }
327            Some(&(span, Packing::Transparent)) => {
328                e.cx.error_span(span, format_args!("#[{ATTR}(transparent)] cannot be combined with #[{ATTR}(tag)] or #[{ATTR}(content)]"));
329            }
330            _ => (),
331        }
332    }
333
334    let enum_packing = e
335        .type_attr
336        .packing(mode)
337        .map(|&(_, p)| p)
338        .unwrap_or_default();
339
340    let (_, name_type, name_method) = split_name(
341        mode.kind,
342        e.type_attr.name_type(mode),
343        e.type_attr.name_all(mode),
344        e.type_attr.name_method(mode),
345    );
346
347    for v in &data.variants {
348        variants.push(setup_variant(e, mode, v, &mut fallback));
349    }
350
351    Enum {
352        span: data.span,
353        name: &data.name,
354        enum_tagging,
355        enum_packing,
356        variants,
357        fallback,
358        name_type,
359        name_method,
360        name_format_with: e.type_attr.name_format_with(mode),
361        packing_span,
362    }
363}
364
365fn setup_variant<'a>(
366    e: &'a Expander<'_>,
367    mode: Mode<'_>,
368    data: &'a VariantData<'a>,
369    fallback: &mut Option<&'a syn::Ident>,
370) -> Variant<'a> {
371    let mut unskipped_fields = Vec::with_capacity(data.fields.len());
372    let mut all_fields = Vec::with_capacity(data.fields.len());
373
374    let variant_packing = data
375        .attr
376        .packing(mode)
377        .or_else(|| e.type_attr.packing(mode))
378        .map(|&(_, v)| v)
379        .unwrap_or_default();
380
381    let (name_all, name_type, name_method) = match data.kind {
382        StructKind::Indexed(..) if data.attr.is_name_type_ambiguous(mode) => {
383            let name_all = NameAll::Index;
384            (name_all, name_all.ty(), NameMethod::Value)
385        }
386        _ => split_name(
387            mode.kind,
388            data.attr.name_type(mode),
389            data.attr.name_all(mode),
390            data.attr.name_method(mode),
391        ),
392    };
393
394    let (type_name_all, _, _) = split_name(
395        mode.kind,
396        e.type_attr.name_type(mode),
397        e.type_attr.name_all(mode),
398        e.type_attr.name_method(mode),
399    );
400
401    let name = expander::expand_name(data, mode, type_name_all, Some(data.ident));
402
403    let pattern = data.attr.pattern(mode).map(|(_, p)| p);
404
405    let mut path = syn::Path::from(syn::Ident::new("Self", data.span));
406    path.segments.push(data.ident.clone().into());
407
408    if let Some((span, _)) = data.attr.default_variant(mode) {
409        if !data.fields.is_empty() {
410            e.cx.error_span(
411                *span,
412                format_args!("#[{ATTR}(default)] variant must be empty"),
413            );
414        } else if fallback.is_some() {
415            e.cx.error_span(
416                *span,
417                format_args!("#[{ATTR}(default)] only one fallback variant is supported",),
418            );
419        } else {
420            *fallback = Some(data.ident);
421        }
422    }
423
424    let mut patterns = Punctuated::default();
425
426    for f in &data.fields {
427        let field = Rc::new(setup_field(
428            e,
429            mode,
430            f,
431            name_all,
432            variant_packing,
433            Some(&mut patterns),
434        ));
435
436        if field.skip.is_none() {
437            unskipped_fields.push(field.clone());
438        }
439
440        all_fields.push(field);
441    }
442
443    let st = Body {
444        span: data.span,
445        name: &data.name,
446        unskipped_fields,
447        all_fields,
448        packing: variant_packing,
449        kind: data.kind,
450        name_type,
451        name_method,
452        name_format_with: data.attr.name_format_with(mode),
453        path,
454    };
455
456    st.validate(&e.cx);
457
458    Variant {
459        span: data.span,
460        index: data.index,
461        name,
462        pattern,
463        patterns,
464        st,
465    }
466}
467
468fn setup_field<'a>(
469    e: &'a Expander,
470    mode: Mode<'_>,
471    data: &'a FieldData<'a>,
472    name_all: NameAll,
473    packing: Packing,
474    patterns: Option<&mut Punctuated<syn::FieldPat, Token![,]>>,
475) -> Field<'a> {
476    let encode_path = data.attr.encode_path_expanded(mode, data.span);
477    let decode_path = data.attr.decode_path_expanded(mode, data.span);
478
479    let name = expander::expand_name(data, mode, name_all, data.ident);
480    let pattern = data.attr.pattern(mode).map(|(_, p)| p);
481
482    let skip = data.attr.skip(mode).map(|&(s, ())| s);
483    let skip_encoding_if = data.attr.skip_encoding_if(mode);
484    let default_attr = data
485        .attr
486        .is_default(mode)
487        .map(|(s, path)| (*s, path.as_ref()));
488
489    let member = match data.ident {
490        Some(ident) => syn::Member::Named(ident.clone()),
491        None => syn::Member::Unnamed(syn::Index {
492            index: data.index as u32,
493            span: data.span,
494        }),
495    };
496
497    let self_access = if let Some(patterns) = patterns {
498        match data.ident {
499            Some(ident) => {
500                patterns.push(syn::FieldPat {
501                    attrs: Vec::new(),
502                    member: syn::Member::Named(ident.clone()),
503                    colon_token: None,
504                    pat: Box::new(syn::Pat::Path(syn::PatPath {
505                        attrs: Vec::new(),
506                        qself: None,
507                        path: syn::Path::from(ident.clone()),
508                    })),
509                });
510
511                syn::Expr::Path(syn::ExprPath {
512                    attrs: Vec::new(),
513                    qself: None,
514                    path: ident.clone().into(),
515                })
516            }
517            None => {
518                let var = quote::format_ident!("v{}", data.index);
519
520                patterns.push(syn::FieldPat {
521                    attrs: Vec::new(),
522                    member: syn::Member::Unnamed(syn::Index::from(data.index)),
523                    colon_token: Some(<Token![:]>::default()),
524                    pat: Box::new(syn::Pat::Path(syn::PatPath {
525                        attrs: Vec::new(),
526                        qself: None,
527                        path: syn::Path::from(var.clone()),
528                    })),
529                });
530
531                syn::Expr::Path(syn::ExprPath {
532                    attrs: Vec::new(),
533                    qself: None,
534                    path: var.into(),
535                })
536            }
537        }
538    } else {
539        let expr = syn::Expr::Field(syn::ExprField {
540            attrs: Vec::new(),
541            base: Box::new(syn::Expr::Path(syn::ExprPath {
542                attrs: Vec::new(),
543                qself: None,
544                path: <Token![self]>::default().into(),
545            })),
546            dot_token: <Token![.]>::default(),
547            member: member.clone(),
548        });
549
550        syn::Expr::Reference(syn::ExprReference {
551            attrs: Vec::new(),
552            and_token: <Token![&]>::default(),
553            mutability: None,
554            expr: Box::new(expr),
555        })
556    };
557
558    let var = match &member {
559        syn::Member::Named(ident) => e.cx.ident_with_span(&ident.to_string(), ident.span(), "_f"),
560        syn::Member::Unnamed(index) => {
561            e.cx.ident_with_span(&index.index.to_string(), index.span, "_f")
562        }
563    };
564
565    Field {
566        span: data.span,
567        index: data.index,
568        encode_path,
569        decode_path,
570        name,
571        pattern,
572        skip,
573        skip_encoding_if,
574        default_attr,
575        self_access,
576        member,
577        packing,
578        var,
579        ty: data.ty,
580    }
581}
582
583fn split_name(
584    kind: Option<&ModeKind>,
585    name_type: Option<&(Span, syn::Type)>,
586    name_all: Option<&(Span, NameAll)>,
587    name_method: Option<&(Span, NameMethod)>,
588) -> (NameAll, syn::Type, NameMethod) {
589    let kind_name_all = kind.and_then(ModeKind::default_name_all);
590
591    let name_all = name_all.map(|&(_, v)| v);
592    let name_method = name_method.map(|&(_, v)| v);
593
594    let Some((_, name_type)) = name_type else {
595        let name_all = name_all.or(kind_name_all).unwrap_or_default();
596        let name_method = name_method.unwrap_or_else(|| name_all.name_method());
597        return (name_all, name_all.ty(), name_method);
598    };
599
600    let (name_method, default_name_all) = match name_method {
601        Some(name_method) => (name_method, name_method.name_all()),
602        None => determine_name_method(name_type),
603    };
604
605    let name_all = name_all.or(default_name_all).unwrap_or_default();
606    (name_all, name_type.clone(), name_method)
607}
608
609fn determine_name_method(ty: &syn::Type) -> (NameMethod, Option<NameAll>) {
610    match ty {
611        syn::Type::Path(syn::TypePath { qself: None, path }) if path.is_ident("str") => {
612            return (
613                NameMethod::Unsized(UnsizedMethod::Default),
614                Some(NameAll::Name),
615            );
616        }
617        syn::Type::Slice(syn::TypeSlice { elem, .. }) => match &**elem {
618            syn::Type::Path(syn::TypePath { qself: None, path }) if path.is_ident("u8") => {
619                return (
620                    NameMethod::Unsized(UnsizedMethod::Bytes),
621                    Some(NameAll::Name),
622                );
623            }
624            _ => {}
625        },
626        _ => {}
627    }
628
629    (NameMethod::Value, None)
630}