musli_macros/
de.rs

1use proc_macro2::{Ident, Literal, Span, TokenStream};
2use quote::{quote, quote_spanned, ToTokens};
3use syn::punctuated::Punctuated;
4use syn::spanned::Spanned;
5use syn::Token;
6
7use crate::expander::{NameMethod, StructKind};
8use crate::internals::apply;
9use crate::internals::attr::{EnumTagging, Packing};
10use crate::internals::build::{Body, Build, BuildData, Enum, Field, Variant};
11use crate::internals::tokens::Tokens;
12use crate::internals::Result;
13
14struct Ctxt<'a> {
15    ctx_var: &'a Ident,
16    decoder_var: &'a Ident,
17    name_var: &'a Ident,
18    trace: bool,
19    trace_body: bool,
20}
21
22pub(crate) fn expand_decode_entry(e: Build<'_>) -> Result<TokenStream> {
23    e.validate_decode()?;
24    e.cx.reset();
25
26    let ctx_var = e.cx.ident("ctx");
27    let root_decoder_var = e.cx.ident("decoder");
28    let tag_var = e.cx.ident("tag");
29    let d_param = e.cx.type_with_span("D", Span::call_site());
30
31    let cx = Ctxt {
32        ctx_var: &ctx_var,
33        decoder_var: &root_decoder_var,
34        name_var: &tag_var,
35        trace: true,
36        trace_body: true,
37    };
38
39    let body = match &e.data {
40        BuildData::Struct(st) => decode_struct(&cx, &e, st)?,
41        BuildData::Enum(en) => decode_enum(&cx, &e, en)?,
42    };
43
44    if e.cx.has_errors() {
45        return Err(());
46    }
47
48    // Figure out which lifetime to use for what. We use the first lifetime in
49    // the type (if any is available) as the decoder lifetime. Else we generate
50    // a new anonymous lifetime `'de` to use for the `Decode` impl.
51    let mut generics = e.input.generics.clone();
52    let type_ident = &e.input.ident;
53
54    let (lt, exists) = if let Some(existing) = generics.lifetimes().next() {
55        (existing.clone(), true)
56    } else {
57        let lt = syn::LifetimeParam::new(syn::Lifetime::new("'de", e.input.span()));
58        (lt, false)
59    };
60
61    if !exists {
62        generics.params.push(lt.clone().into());
63    }
64
65    let Tokens {
66        context_t,
67        result,
68        decode_t,
69        decoder_t,
70        ..
71    } = e.tokens;
72
73    if !e.bounds.is_empty() && !e.decode_bounds.is_empty() {
74        generics.make_where_clause().predicates.extend(
75            e.bounds
76                .iter()
77                .chain(e.decode_bounds.iter())
78                .map(|(_, v)| v.clone()),
79        );
80    }
81
82    let (impl_generics, _, where_clause) = generics.split_for_impl();
83    let (_, type_generics, _) = e.input.generics.split_for_impl();
84
85    let mut attributes = Vec::<syn::Attribute>::new();
86
87    if cfg!(not(feature = "verbose")) {
88        attributes.push(syn::parse_quote!(#[allow(clippy::just_underscores_and_digits)]));
89    }
90
91    let mode_ident = e.expansion.mode_path(e.tokens).as_path();
92
93    Ok(quote! {
94        const _: () = {
95            #[automatically_derived]
96            #(#attributes)*
97            impl #impl_generics #decode_t<#lt, #mode_ident> for #type_ident #type_generics #where_clause {
98                #[inline]
99                fn decode<#d_param>(#ctx_var: &#d_param::Cx, #root_decoder_var: #d_param) -> #result<Self, <#d_param::Cx as #context_t>::Error>
100                where
101                    #d_param: #decoder_t<#lt, Mode = #mode_ident>,
102                {
103                    #body
104                }
105            }
106        };
107    })
108}
109
110fn decode_struct(cx: &Ctxt<'_>, b: &Build<'_>, st: &Body<'_>) -> Result<TokenStream> {
111    let Tokens { result_ok, .. } = b.tokens;
112
113    let body = match (st.kind, st.packing) {
114        (_, Packing::Transparent) => decode_transparent(cx, b, st)?,
115        (_, Packing::Packed) => decode_packed(cx, b, st)?,
116        (StructKind::Empty, _) => decode_empty(cx, b, st)?,
117        (_, Packing::Tagged) => decode_tagged(cx, b, st, None)?,
118    };
119
120    Ok(quote!(#result_ok({ #body })))
121}
122
123fn decode_enum(cx: &Ctxt<'_>, b: &Build<'_>, en: &Enum) -> Result<TokenStream> {
124    let Ctxt {
125        ctx_var,
126        name_var,
127        decoder_var,
128        ..
129    } = *cx;
130
131    let Tokens {
132        as_decoder_t,
133        context_t,
134        decoder_t,
135        fmt,
136        option_none,
137        option_some,
138        option,
139        result_err,
140        result_ok,
141        skip_field,
142        skip,
143        map_decoder_t,
144        struct_field_decoder_t,
145        map_hint,
146        variant_decoder_t,
147        ..
148    } = b.tokens;
149
150    if let Some(&(span, Packing::Packed)) = en.packing_span {
151        b.decode_packed_enum_diagnostics(span);
152        return Err(());
153    }
154
155    let type_name = en.name;
156
157    // Trying to decode an uninhabitable type.
158    if en.variants.is_empty() {
159        return Ok(quote!(#result_err(#context_t::uninhabitable(#ctx_var, #type_name))));
160    }
161
162    let binding_var = b.cx.ident("value");
163    let body_decoder_var = b.cx.ident("body_decoder");
164    let buffer_decoder_var = b.cx.ident("buffer_decoder");
165    let buffer_var = b.cx.ident("buffer");
166    let entry_var = b.cx.ident("entry");
167    let field_name = b.cx.ident("field_name");
168    let field_name_var = b.cx.ident("field_name");
169    let field_var = b.cx.ident("field");
170    let outcome_type = b.cx.type_with_span("Outcome", Span::call_site());
171    let buf_type = b.cx.type_with_span("B", Span::call_site());
172    let outcome_var = b.cx.ident("outcome");
173    let output_var = b.cx.ident("output");
174    let struct_decoder_var = b.cx.ident("struct_decoder");
175    let struct_hint_static = b.cx.ident("STRUCT_HINT");
176    let struct_var = b.cx.ident("st");
177    let value_var = b.cx.ident("value");
178    let variant_decoder_var = b.cx.ident("variant_decoder");
179    let variant_tag_var = b.cx.ident("variant_tag");
180    let tag_static = b.cx.ident("TAG");
181    let content_static = b.cx.ident("CONTENT");
182
183    let mut output_arms = Vec::new();
184
185    let mut fallback = match en.fallback {
186        Some(ident) => {
187            quote! {{
188                if #skip(#variant_decoder_t::decode_value(#variant_decoder_var)?)? {
189                    return #result_err(#context_t::invalid_variant_tag(#ctx_var, #type_name, &#variant_tag_var));
190                }
191
192                Self::#ident {}
193            }}
194        }
195        None => quote! {
196            return #result_err(#context_t::invalid_variant_tag(#ctx_var, #type_name, &#variant_tag_var))
197        },
198    };
199
200    let decode_name;
201    let output_enum;
202    let name_type;
203
204    match en.name_method {
205        NameMethod::Value => {
206            for v in &en.variants {
207                let arm = output_arm(v.pattern, &v.name, &binding_var);
208                output_arms.push((v, arm, &v.name));
209            }
210
211            let decode_t_decode = &b.decode_t_decode;
212
213            decode_name = quote!(#decode_t_decode(#ctx_var, #variant_decoder_var));
214            output_enum = None;
215            fallback = quote!(_ => #fallback);
216            name_type = en.name_type.clone();
217        }
218        NameMethod::Unsized(method) => {
219            let mut variants = Vec::new();
220            let output_type = b.cx.type_with_span("VariantTag", en.span);
221
222            for v in &en.variants {
223                let (pat, variant) =
224                    unsized_arm(b, v.span, v.index, &v.name, v.pattern, &output_type);
225
226                output_arms.push((v, OutputArm { pat, cond: None }, &v.name));
227                variants.push(variant);
228            }
229
230            let arms = variants.iter().map(|o| o.as_arm(&binding_var, option_some));
231
232            let visit_type = &en.name_type;
233            let method = method.as_method_name();
234
235            decode_name = quote! {
236                #decoder_t::#method(#variant_decoder_var, |#value_var: &#visit_type| {
237                    #result_ok(match #value_var {
238                        #(#arms,)*
239                        _ => #option_none,
240                    })
241                })
242            };
243
244            let fmt_patterns = variants.iter().map(|o| {
245                let variant = &o.variant;
246                let name = o.name;
247                quote!(#output_type::#variant => #fmt::Debug::fmt(&#name, f))
248            });
249
250            let fmt_patterns2 = variants.iter().map(|o| {
251                let variant = &o.variant;
252                let name = o.name;
253                quote!(#output_type::#variant => #fmt::Display::fmt(&#name, f))
254            });
255
256            let variants = variants.iter().map(|o| &o.variant);
257
258            output_enum = Some(quote! {
259                enum #output_type { #(#variants,)* }
260
261                impl #fmt::Debug for #output_type {
262                    #[inline]
263                    fn fmt(&self, f: &mut #fmt::Formatter<'_>) -> #fmt::Result {
264                        match *self { #(#fmt_patterns,)* }
265                    }
266                }
267
268                impl #fmt::Display for #output_type {
269                    #[inline]
270                    fn fmt(&self, f: &mut #fmt::Formatter<'_>) -> #fmt::Result {
271                        match *self { #(#fmt_patterns2,)* }
272                    }
273                }
274            });
275
276            fallback = quote!(#option_none => { #fallback });
277            name_type = syn::parse_quote!(#option<#output_type>);
278        }
279    }
280
281    match en.enum_tagging {
282        EnumTagging::Empty => {
283            let mut arms = Vec::new();
284
285            for v in &en.variants {
286                let path = &v.st.path;
287                let pat = output_arm(v.pattern, &v.name, &binding_var);
288                arms.push(quote!(#pat => #result_ok(#path {})));
289            }
290
291            match en.fallback {
292                Some(ident) => {
293                    arms.push(quote!(_ => #result_ok(Self::#ident {})));
294                }
295                None => {
296                    arms.push(quote!(#value_var => #result_err(#context_t::invalid_variant_tag(#ctx_var, #type_name, &#value_var))));
297                }
298            }
299
300            match en.name_method {
301                NameMethod::Value => {
302                    let decode_t_decode = &b.decode_t_decode;
303                    let name_type = &en.name_type;
304
305                    Ok(quote! {{
306                        let #value_var: #name_type = #decode_t_decode(#ctx_var, #decoder_var)?;
307
308                        match #value_var { #(#arms,)* }
309                    }})
310                }
311                NameMethod::Unsized(method) => {
312                    let method = method.as_method_name();
313                    let visit_type = &en.name_type;
314
315                    Ok(quote! {
316                        #decoder_t::#method(#decoder_var, |#value_var: &#visit_type| {
317                            match #value_var { #(#arms,)* }
318                        })
319                    })
320                }
321            }
322        }
323        EnumTagging::Default => {
324            let arms = output_arms.iter().flat_map(|(v, pat, tag_value)| {
325                let name = &v.st.name;
326
327                let decode = decode_variant(cx, b, v, &body_decoder_var, &variant_tag_var).ok()?;
328
329                let enter = cx.trace.then(|| {
330                    let (tag_decl, formatted_tag) = en.name_format(&tag_static, tag_value);
331
332                    quote! {
333                        #tag_decl
334                        #context_t::enter_variant(#ctx_var, #name, #formatted_tag);
335                    }
336                });
337
338                let leave = cx.trace.then(|| quote! {
339                    #context_t::leave_variant(#ctx_var);
340                });
341
342                Some(quote! {
343                    #pat => {
344                        #enter
345
346                        let #body_decoder_var = #variant_decoder_t::decode_value(#variant_decoder_var)?;
347                        let #output_var = #decode;
348
349                        #leave
350                        #output_var
351                    }
352                })
353            });
354
355            let enter = cx.trace.then(|| {
356                quote! {
357                    #context_t::enter_enum(#ctx_var, #type_name);
358                }
359            });
360
361            let leave = cx.trace.then(|| {
362                quote! {
363                    #context_t::leave_enum(#ctx_var);
364                }
365            });
366
367            Ok(quote! {{
368                #output_enum
369                #enter
370
371                let #output_var = #decoder_t::decode_variant(#decoder_var, move |#variant_decoder_var| {
372                    let #variant_tag_var: #name_type = {
373                        let mut #variant_decoder_var = #variant_decoder_t::decode_tag(#variant_decoder_var)?;
374                        #decode_name?
375                    };
376
377                    let #output_var = match #variant_tag_var {
378                        #(#arms,)*
379                        #fallback
380                    };
381
382                    #result_ok(#output_var)
383                })?;
384
385                #leave
386                Ok(#output_var)
387            }})
388        }
389        EnumTagging::Internal { tag } => {
390            let arms = output_arms.iter().flat_map(|(v, pat, tag_value)| {
391                let name = &v.st.name;
392
393                let decode =
394                    decode_variant(cx, b, v, &buffer_decoder_var, &variant_tag_var).ok()?;
395
396                let enter = cx.trace.then(|| {
397                    let (tag_decl, formatted_tag) = en.name_format(&tag_static, tag_value);
398
399                    quote! {
400                        #tag_decl
401                        #context_t::enter_variant(#ctx_var, #name, #formatted_tag);
402                    }
403                });
404
405                let leave = cx.trace.then(|| {
406                    quote! {
407                        #context_t::leave_variant(#ctx_var);
408                    }
409                });
410
411                Some(quote! {
412                    #pat => {
413                        #enter
414
415                        let #buffer_decoder_var = #as_decoder_t::as_decoder(&#buffer_var)?;
416                        let #output_var = #decode;
417
418                        #leave
419                        #output_var
420                    }
421                })
422            });
423
424            let outcome_enum;
425            let decode_match;
426
427            match en.name_method {
428                NameMethod::Value => {
429                    let decode_t_decode = &b.decode_t_decode;
430
431                    outcome_enum = None;
432
433                    let name_type = &en.name_type;
434                    let tag_arm = output_arm(None, tag, &binding_var);
435
436                    decode_match = quote! {
437                        let #value_var: #name_type = #decode_t_decode(#ctx_var, #field_name_var)?;
438
439                        match #value_var {
440                            #tag_arm => {
441                                break #struct_field_decoder_t::decode_value(#entry_var)?;
442                            }
443                            #field_var => {
444                                if #skip_field(#entry_var)? {
445                                    return #result_err(#context_t::invalid_field_tag(#ctx_var, #type_name, &#field_var));
446                                }
447                            }
448                        }
449                    };
450                }
451                NameMethod::Unsized(method) => {
452                    outcome_enum = Some(quote! {
453                        enum #outcome_type<#buf_type> { Tag, Skip(#buf_type) }
454                    });
455
456                    let visit_type = &en.name_type;
457                    let method = method.as_method_name();
458
459                    let tag_arm = output_arm(None, tag, &binding_var);
460
461                    let decode_outcome = quote! {
462                        #decoder_t::#method(#field_name_var, |#value_var: &#visit_type| {
463                            #result_ok(match #value_var {
464                                #tag_arm => #outcome_type::Tag,
465                                #value_var => {
466                                    #outcome_type::Skip(#context_t::collect_string(#ctx_var, #value_var)?)
467                                }
468                            })
469                        })?
470                    };
471
472                    decode_match = quote! {{
473                        let #field_name_var = #decode_outcome;
474
475                        match #field_name_var {
476                            #outcome_type::Tag => {
477                                break #struct_field_decoder_t::decode_value(#entry_var)?;
478                            }
479                            #outcome_type::Skip(#field_name) => {
480                                if #skip_field(#entry_var)? {
481                                    return #result_err(#context_t::invalid_field_string_tag(#ctx_var, #type_name, #field_name));
482                                }
483                            }
484                        }
485                    }};
486                }
487            };
488
489            let enter = cx.trace.then(|| {
490                quote! {
491                    #context_t::enter_enum(#ctx_var, #type_name);
492                }
493            });
494
495            let leave = cx.trace.then(|| {
496                quote! {
497                    #context_t::leave_enum(#ctx_var);
498                }
499            });
500
501            let static_type = en.static_type();
502
503            Ok(quote! {{
504                static #tag_static: #static_type = #tag;
505
506                #output_enum
507                #outcome_enum
508
509                #enter
510                let #buffer_var = #decoder_t::decode_buffer(#decoder_var)?;
511                let #struct_var = #as_decoder_t::as_decoder(&#buffer_var)?;
512
513                let #variant_tag_var: #name_type = #decoder_t::decode_map(#struct_var, |#struct_var| {
514                    let #variant_decoder_var = loop {
515                        let #option_some(mut #entry_var) = #map_decoder_t::decode_entry(#struct_var)? else {
516                            return #result_err(#context_t::missing_variant_field(#ctx_var, #type_name, &#tag_static));
517                        };
518
519                        let #field_name_var = #struct_field_decoder_t::decode_key(&mut #entry_var)?;
520
521                        #decode_match
522                    };
523
524                    #decode_name
525                })?;
526
527                let #output_var = match #variant_tag_var {
528                    #(#arms,)*
529                    #fallback
530                };
531
532                #leave
533                #result_ok(#output_var)
534            }})
535        }
536        EnumTagging::Adjacent { tag, content } => {
537            let arms = output_arms.iter().flat_map(|(v, pat, tag_value)| {
538                let name = &v.st.name;
539
540                let decode = decode_variant(cx, b, v, &body_decoder_var, &variant_tag_var).ok()?;
541
542                let enter = cx.trace.then(|| {
543                    let (tag_decl, formatted_tag) = en.name_format(&tag_static, tag_value);
544
545                    quote! {
546                        #tag_decl
547                        #context_t::enter_variant(#ctx_var, #name, #formatted_tag);
548                    }
549                });
550
551                let leave = cx.trace.then(|| {
552                    quote! {
553                        #context_t::leave_variant(#ctx_var);
554                    }
555                });
556
557                Some(quote! {
558                    #pat => {
559                        #enter
560                        let #output_var = #decode;
561                        #leave
562                        #output_var
563                    }
564                })
565            });
566
567            let decode_t_decode = &b.decode_t_decode;
568
569            let outcome_enum;
570            let decode_match;
571
572            match en.name_method {
573                NameMethod::Value => {
574                    outcome_enum = None;
575
576                    let name_type = &en.name_type;
577                    let tag_arm = output_arm(None, tag, &binding_var);
578                    let content_arm = output_arm(None, content, &binding_var);
579
580                    decode_match = quote! {
581                        let #value_var: #name_type = #decode_t_decode(#ctx_var, #field_name_var)?;
582
583                        match #value_var {
584                            #tag_arm => {
585                                let #variant_decoder_var = #struct_field_decoder_t::decode_value(#entry_var)?;
586                                let #variant_tag_var: #name_type = #decode_name?;
587                                #name_var = #option_some(#variant_tag_var);
588                            }
589                            #content_arm => {
590                                let #option_some(#variant_tag_var) = #name_var else {
591                                    return #result_err(#context_t::missing_adjacent_tag(#ctx_var, #type_name, &#content));
592                                };
593
594                                let #body_decoder_var = #struct_field_decoder_t::decode_value(#entry_var)?;
595
596                                break #result_ok(match #variant_tag_var {
597                                    #(#arms,)*
598                                    #fallback
599                                });
600                            }
601                            #field_var => {
602                                if #skip_field(#entry_var)? {
603                                    return #result_err(#context_t::invalid_field_tag(#ctx_var, #type_name, &#field_var));
604                                }
605                            }
606                        }
607                    };
608                }
609                NameMethod::Unsized(method) => {
610                    outcome_enum = Some(quote! {
611                        enum #outcome_type<#buf_type> { Tag, Content, Skip(#buf_type) }
612                    });
613
614                    let visit_type = &en.name_type;
615                    let method = method.as_method_name();
616
617                    let tag_arm = output_arm(None, tag, &binding_var);
618                    let content_arm = output_arm(None, content, &binding_var);
619
620                    decode_match = quote! {
621                        let #outcome_var = #decoder_t::#method(#field_name_var, |#value_var: &#visit_type| {
622                            #result_ok(match #value_var {
623                                #tag_arm => #outcome_type::Tag,
624                                #content_arm => #outcome_type::Content,
625                                #value_var => {
626                                    #outcome_type::Skip(#context_t::collect_string(#ctx_var, #value_var)?)
627                                }
628                            })
629                        })?;
630
631                        match #outcome_var {
632                            #outcome_type::Tag => {
633                                let #variant_decoder_var = #struct_field_decoder_t::decode_value(#entry_var)?;
634                                #name_var = #option_some(#decode_name?);
635                            }
636                            #outcome_type::Content => {
637                                let #option_some(#variant_tag_var) = #name_var else {
638                                    return #result_err(#context_t::invalid_field_tag(#ctx_var, #type_name, &#tag));
639                                };
640
641                                let #body_decoder_var = #struct_field_decoder_t::decode_value(#entry_var)?;
642
643                                break #result_ok(match #variant_tag_var {
644                                    #(#arms,)*
645                                    #fallback
646                                });
647                            }
648                            #outcome_type::Skip(#field_name) => {
649                                if #skip_field(#entry_var)? {
650                                    return #result_err(#context_t::invalid_field_string_tag(#ctx_var, #type_name, #field_name));
651                                }
652                            }
653                        }
654                    };
655                }
656            };
657
658            let enter = cx.trace.then(|| {
659                quote! {
660                    #context_t::enter_enum(#ctx_var, #type_name);
661                }
662            });
663
664            let leave = cx.trace.then(|| {
665                quote! {
666                    #context_t::leave_enum(#ctx_var);
667                }
668            });
669
670            let static_type = en.static_type();
671
672            Ok(quote! {{
673                static #tag_static: #static_type = #tag;
674                static #content_static: #static_type = #content;
675
676                #output_enum
677                #outcome_enum
678
679                static #struct_hint_static: #map_hint = #map_hint::with_size(2);
680
681                #enter
682
683                #decoder_t::decode_map_hint(#decoder_var, &#struct_hint_static, move |#struct_decoder_var| {
684                    let mut #name_var = #option_none;
685
686                    let #output_var = loop {
687                        let #option_some(mut #entry_var) = #map_decoder_t::decode_entry(#struct_decoder_var)? else {
688                            return #result_err(#context_t::expected_field_adjacent(#ctx_var, #type_name, &#tag_static, &#content_static));
689                        };
690
691                        let #field_name_var = #struct_field_decoder_t::decode_key(&mut #entry_var)?;
692
693                        #decode_match
694                    };
695
696                    #leave
697                    #result_ok(#output_var)
698                })?
699            }})
700        }
701    }
702}
703
704fn decode_variant(
705    cx: &Ctxt<'_>,
706    b: &Build,
707    v: &Variant<'_>,
708    decoder_var: &Ident,
709    variant_tag: &Ident,
710) -> Result<TokenStream, ()> {
711    let cx = Ctxt {
712        decoder_var,
713        trace_body: false,
714        ..*cx
715    };
716
717    Ok(match (v.st.kind, v.st.packing) {
718        (_, Packing::Transparent) => decode_transparent(&cx, b, &v.st)?,
719        (_, Packing::Packed) => decode_packed(&cx, b, &v.st)?,
720        (StructKind::Empty, _) => decode_empty(&cx, b, &v.st)?,
721        (_, Packing::Tagged) => decode_tagged(&cx, b, &v.st, Some(variant_tag))?,
722    })
723}
724
725/// Decode something empty.
726fn decode_empty(cx: &Ctxt, b: &Build<'_>, st: &Body<'_>) -> Result<TokenStream> {
727    let Ctxt {
728        ctx_var,
729        decoder_var,
730        ..
731    } = *cx;
732
733    let Tokens {
734        context_t,
735        decoder_t,
736        result_ok,
737        map_hint,
738        ..
739    } = b.tokens;
740
741    let Body { path, name, .. } = st;
742
743    let output_var = b.cx.ident("output");
744    let struct_hint_static = b.cx.ident("STRUCT_HINT");
745
746    let enter = (cx.trace && cx.trace_body).then(|| {
747        quote! {
748            #context_t::enter_struct(#ctx_var, #name);
749        }
750    });
751
752    let leave = (cx.trace && cx.trace_body).then(|| {
753        quote! {
754            #context_t::leave_struct(#ctx_var);
755        }
756    });
757
758    Ok(quote! {{
759        #enter
760        static #struct_hint_static: #map_hint = #map_hint::with_size(0);
761        let #output_var = #decoder_t::decode_map_hint(#decoder_var, &#struct_hint_static, |_| #result_ok(()))?;
762        #leave
763        #path
764    }})
765}
766
767/// Decode something tagged.
768///
769/// If `variant_name` is specified it implies that a tagged enum is being
770/// decoded.
771fn decode_tagged(
772    cx: &Ctxt,
773    b: &Build<'_>,
774    st: &Body<'_>,
775    variant_tag: Option<&Ident>,
776) -> Result<TokenStream> {
777    let Ctxt {
778        ctx_var,
779        decoder_var,
780        name_var,
781        ..
782    } = *cx;
783
784    let Tokens {
785        context_t,
786        decoder_t,
787        default_function,
788        fmt,
789        option_none,
790        option_some,
791        option,
792        result_err,
793        result_ok,
794        skip_field,
795        map_decoder_t,
796        struct_field_decoder_t,
797        map_hint,
798        ..
799    } = b.tokens;
800
801    let struct_decoder_var = b.cx.ident("struct_decoder");
802    let struct_hint_static = b.cx.ident("STRUCT_HINT");
803    let type_decoder_var = b.cx.ident("type_decoder");
804    let value_var = b.cx.ident("value");
805    let binding_var = b.cx.ident("value");
806
807    let type_name = &st.name;
808
809    let mut assigns = Punctuated::<_, Token![,]>::new();
810
811    let mut fields_with = Vec::new();
812
813    for f in &st.all_fields {
814        let tag = &f.name;
815        let var = &f.var;
816        let decode_path = &f.decode_path.1;
817
818        let expr = match &f.skip {
819            Some(span) => {
820                let ty = f.ty;
821
822                match &f.default_attr {
823                    Some((_, Some(path))) => syn::Expr::Verbatim(quote_spanned!(*span => #path())),
824                    _ => syn::Expr::Verbatim(quote_spanned!(*span => #default_function::<#ty>())),
825                }
826            }
827            None => {
828                let formatted_tag = match &st.name_format_with {
829                    Some((_, path)) => quote!(&#path(&#tag)),
830                    None => quote!(&#tag),
831                };
832
833                let enter = cx.trace.then(|| {
834                    let (name, enter) = match &f.member {
835                        syn::Member::Named(name) => (
836                            syn::Lit::Str(syn::LitStr::new(&name.to_string(), name.span())),
837                            Ident::new("enter_named_field", Span::call_site()),
838                        ),
839                        syn::Member::Unnamed(index) => (
840                            syn::Lit::Int(syn::LitInt::from(Literal::u32_suffixed(index.index))),
841                            Ident::new("enter_unnamed_field", Span::call_site()),
842                        ),
843                    };
844
845                    quote! {
846                        #context_t::#enter(#ctx_var, #name, #formatted_tag);
847                    }
848                });
849
850                let leave = cx.trace.then(|| {
851                    quote! {
852                        #context_t::leave_field(#ctx_var);
853                    }
854                });
855
856                let decode = quote! {
857                    #var = #option_some(#decode_path(#ctx_var, #struct_decoder_var)?);
858                };
859
860                fields_with.push((f, decode, (enter, leave)));
861
862                let fallback = match f.default_attr {
863                    Some((span, None)) => quote_spanned!(span => #default_function()),
864                    Some((_, Some(path))) => quote!(#path()),
865                    None => quote! {
866                        return #result_err(#context_t::expected_tag(#ctx_var, #type_name, &#tag))
867                    },
868                };
869
870                let var = &f.var;
871
872                syn::Expr::Verbatim(quote! {
873                    match #var {
874                        #option_some(#var) => #var,
875                        #option_none => #fallback,
876                    }
877                })
878            }
879        };
880
881        assigns.push(syn::FieldValue {
882            attrs: Vec::new(),
883            member: f.member.clone(),
884            colon_token: Some(<Token![:]>::default()),
885            expr,
886        });
887    }
888
889    let decode_tag;
890    let mut output_enum = quote!();
891
892    let unsupported = match variant_tag {
893        Some(variant_tag) => quote! {
894            #context_t::invalid_variant_field_tag(#ctx_var, #type_name, &#variant_tag, &#name_var)
895        },
896        None => quote! {
897            #context_t::invalid_field_tag(#ctx_var, #type_name, &#name_var)
898        },
899    };
900
901    let skip_field = quote! {
902        if #skip_field(#struct_decoder_var)? {
903            return #result_err(#unsupported);
904        }
905    };
906
907    let body;
908    let name_type: syn::Type;
909
910    match st.name_method {
911        NameMethod::Value => {
912            let mut arms = Vec::with_capacity(fields_with.len());
913
914            for (f, decode, (enter, leave)) in fields_with {
915                let arm = output_arm(f.pattern, &f.name, &binding_var);
916
917                arms.push(quote! {
918                    #arm => {
919                        #enter
920                        let #struct_decoder_var = #struct_field_decoder_t::decode_value(#struct_decoder_var)?;
921                        #decode
922                        #leave
923                    }
924                });
925            }
926
927            body = quote!(match #name_var { #(#arms,)* _ => { #skip_field } });
928
929            let decode_t_decode = &b.decode_t_decode;
930
931            decode_tag = quote! {
932                #decode_t_decode(#ctx_var, #struct_decoder_var)?
933            };
934
935            name_type = st.name_type.clone();
936        }
937        NameMethod::Unsized(method) => {
938            let output_type =
939                b.cx.type_with_span("TagVisitorOutput", b.input.ident.span());
940
941            let mut outputs = Vec::with_capacity(fields_with.len());
942            let mut name_arms = Vec::with_capacity(fields_with.len());
943
944            for (f, decode, trace) in fields_with {
945                let (name_pat, name_variant) =
946                    unsized_arm(b, f.span, f.index, &f.name, f.pattern, &output_type);
947
948                outputs.push(name_variant);
949                name_arms.push((name_pat, decode, trace));
950            }
951
952            if !name_arms.is_empty() {
953                let arms = name_arms
954                    .into_iter()
955                    .map(|(name_pat, decode, (enter, leave))| {
956                        quote! {
957                            #name_pat => {
958                                #enter
959                                let #struct_decoder_var = #struct_field_decoder_t::decode_value(#struct_decoder_var)?;
960                                #decode
961                                #leave
962                            }
963                        }
964                    });
965
966                body = quote! {
967                    match #name_var { #(#arms,)* #name_var => { #skip_field } }
968                }
969            } else {
970                body = skip_field;
971            }
972
973            let arms = outputs.iter().map(|o| o.as_arm(&binding_var, option_some));
974
975            let visit_type = &st.name_type;
976            let method = method.as_method_name();
977
978            decode_tag = quote! {
979                #decoder_t::#method(#struct_decoder_var, |#value_var: &#visit_type| {
980                    #result_ok(match #value_var {
981                        #(#arms,)*
982                        #value_var => {
983                            #option_none
984                        }
985                    })
986                })?
987            };
988
989            let variants = outputs.iter().map(|o| &o.variant);
990
991            let fmt_patterns = outputs.iter().map(|o| {
992                let variant = &o.variant;
993                let tag = o.name;
994                quote!(#output_type::#variant => #fmt::Debug::fmt(&#tag, f))
995            });
996
997            output_enum = quote! {
998                enum #output_type {
999                    #(#variants,)*
1000                }
1001
1002                impl #fmt::Debug for #output_type {
1003                    #[inline]
1004                    fn fmt(&self, f: &mut #fmt::Formatter<'_>) -> #fmt::Result {
1005                        match *self { #(#fmt_patterns,)* }
1006                    }
1007                }
1008            };
1009
1010            name_type = syn::parse_quote!(#option<#output_type>);
1011        }
1012    }
1013
1014    let path = &st.path;
1015    let fields_len = st.unskipped_fields.len();
1016
1017    let decls = st
1018        .unskipped_fields
1019        .iter()
1020        .map(|f| &**f)
1021        .map(|Field { var, ty, .. }| quote!(let mut #var: #option<#ty> = #option_none;));
1022
1023    let enter = (cx.trace && cx.trace_body).then(|| {
1024        quote! {
1025            #context_t::enter_struct(#ctx_var, #type_name);
1026        }
1027    });
1028
1029    let leave = (cx.trace && cx.trace_body).then(|| {
1030        quote! {
1031            #context_t::leave_struct(#ctx_var);
1032        }
1033    });
1034
1035    Ok(quote! {{
1036        #output_enum
1037        #(#decls)*
1038
1039        #enter
1040
1041        static #struct_hint_static: #map_hint = #map_hint::with_size(#fields_len);
1042
1043        #decoder_t::decode_map_hint(#decoder_var, &#struct_hint_static, move |#type_decoder_var| {
1044            while let #option_some(mut #struct_decoder_var) = #map_decoder_t::decode_entry(#type_decoder_var)? {
1045                let #name_var: #name_type = {
1046                    let #struct_decoder_var = #struct_field_decoder_t::decode_key(&mut #struct_decoder_var)?;
1047                    #decode_tag
1048                };
1049
1050                #body
1051            }
1052
1053            #leave
1054            #result_ok(#path { #assigns })
1055        })?
1056    }})
1057}
1058
1059/// Decode a transparent value.
1060fn decode_transparent(cx: &Ctxt<'_>, b: &Build<'_>, st: &Body<'_>) -> Result<TokenStream> {
1061    let Ctxt {
1062        decoder_var,
1063        ctx_var,
1064        ..
1065    } = *cx;
1066
1067    let output_var = b.cx.ident("output");
1068
1069    let Tokens { context_t, .. } = b.tokens;
1070
1071    let f = &st.unskipped_fields[0];
1072
1073    let type_name = &st.name;
1074    let path = &st.path;
1075    let decode_path = &f.decode_path.1;
1076    let member = &f.member;
1077
1078    let enter = (cx.trace && cx.trace_body).then(|| {
1079        quote! {
1080            #context_t::enter_struct(#ctx_var, #type_name);
1081        }
1082    });
1083
1084    let leave = (cx.trace && cx.trace_body).then(|| {
1085        quote! {
1086            #context_t::leave_struct(#ctx_var);
1087        }
1088    });
1089
1090    Ok(quote! {{
1091        #enter
1092
1093        let #output_var = #path {
1094            #member: #decode_path(#ctx_var, #decoder_var)?
1095        };
1096
1097        #leave
1098        #output_var
1099    }})
1100}
1101
1102/// Decode something packed.
1103fn decode_packed(cx: &Ctxt<'_>, b: &Build<'_>, st_: &Body<'_>) -> Result<TokenStream> {
1104    let Ctxt {
1105        decoder_var,
1106        ctx_var,
1107        ..
1108    } = *cx;
1109
1110    let Tokens {
1111        context_t,
1112        decoder_t,
1113        pack_decoder_t,
1114        ..
1115    } = b.tokens;
1116
1117    let type_name = &st_.name;
1118    let output_var = b.cx.ident("output");
1119    let field_decoder = b.cx.ident("field_decoder");
1120
1121    let mut assign = Vec::new();
1122
1123    for f in &st_.unskipped_fields {
1124        if let Some((span, _)) = f.default_attr {
1125            b.packed_default_diagnostics(span);
1126        }
1127
1128        let (_, decode_path) = &f.decode_path;
1129        let member = &f.member;
1130        let field_decoder = &field_decoder;
1131
1132        assign.push(move |ident: &syn::Ident, tokens: &mut TokenStream| {
1133            tokens.extend(quote! {
1134                #member: {
1135                    let #field_decoder = #pack_decoder_t::decode_next(#ident)?;
1136                    #decode_path(#ctx_var, #field_decoder)?
1137                }
1138            })
1139        });
1140    }
1141
1142    let enter = (cx.trace && cx.trace_body).then(|| {
1143        quote! {
1144            #context_t::enter_struct(#ctx_var, #type_name);
1145        }
1146    });
1147
1148    let leave = (cx.trace && cx.trace_body).then(|| {
1149        quote! {
1150            #context_t::leave_struct(#ctx_var);
1151        }
1152    });
1153
1154    let pack = b.cx.ident("pack");
1155    let assign = apply::iter(assign, &pack);
1156    let path = &st_.path;
1157
1158    Ok(quote! {{
1159        #enter
1160
1161        let #output_var = #decoder_t::decode_pack(#decoder_var, move |#pack| {
1162            Ok(#path { #(#assign),* })
1163        })?;
1164
1165        #leave
1166        #output_var
1167    }})
1168}
1169
1170/// Output type used when indirectly encoding a variant or field as type which
1171/// might require special handling. Like a string.
1172pub(crate) struct NameVariant<'a> {
1173    /// The path of the variant this output should generate.
1174    path: syn::Path,
1175    /// The identified of the variant this path generates.
1176    variant: Ident,
1177    /// The tag this variant corresponds to.
1178    name: &'a syn::Expr,
1179    /// The pattern being matched.
1180    pattern: Option<&'a syn::Pat>,
1181}
1182
1183impl NameVariant<'_> {
1184    /// Generate the pattern for this output.
1185    pub(crate) fn as_arm(&self, binding_var: &syn::Ident, option_some: &syn::Path) -> syn::Arm {
1186        let body = syn::Expr::Path(syn::ExprPath {
1187            attrs: Vec::new(),
1188            qself: None,
1189            path: self.path.clone(),
1190        });
1191
1192        let arm = output_arm(self.pattern, self.name, binding_var);
1193
1194        syn::Arm {
1195            attrs: Vec::new(),
1196            pat: arm.pat,
1197            guard: arm.cond.map(|_| {
1198                let name = self.name;
1199
1200                (
1201                    <syn::Token![if]>::default(),
1202                    syn::parse_quote!(*#binding_var == #name),
1203                )
1204            }),
1205            fat_arrow_token: <Token![=>]>::default(),
1206            body: Box::new(build_call(option_some, [body])),
1207            comma: None,
1208        }
1209    }
1210}
1211
1212pub(crate) fn build_call<A>(path: &syn::Path, it: A) -> syn::Expr
1213where
1214    A: IntoIterator<Item = syn::Expr>,
1215{
1216    let mut args = Punctuated::default();
1217
1218    for arg in it {
1219        args.push(arg);
1220    }
1221
1222    syn::Expr::Call(syn::ExprCall {
1223        attrs: Vec::new(),
1224        func: Box::new(syn::Expr::Path(syn::ExprPath {
1225            attrs: Vec::new(),
1226            qself: None,
1227            path: path.clone(),
1228        })),
1229        paren_token: syn::token::Paren::default(),
1230        args,
1231    })
1232}
1233
1234pub(crate) fn build_reference(expr: syn::Expr) -> syn::Expr {
1235    syn::Expr::Reference(syn::ExprReference {
1236        attrs: Vec::new(),
1237        and_token: <Token![&]>::default(),
1238        mutability: None,
1239        expr: Box::new(expr),
1240    })
1241}
1242
1243fn unsized_arm<'a>(
1244    b: &Build<'_>,
1245    span: Span,
1246    index: usize,
1247    name: &'a syn::Expr,
1248    pattern: Option<&'a syn::Pat>,
1249    output: &Ident,
1250) -> (syn::Pat, NameVariant<'a>) {
1251    let variant = b.cx.type_with_span(format_args!("Variant{}", index), span);
1252
1253    let mut path = syn::Path::from(output.clone());
1254    path.segments.push(syn::PathSegment::from(variant.clone()));
1255
1256    let output = NameVariant {
1257        path: path.clone(),
1258        variant,
1259        name,
1260        pattern,
1261    };
1262
1263    let option_some = &b.tokens.option_some;
1264    (syn::parse_quote!(#option_some(#path)), output)
1265}
1266
1267struct Condition<'a> {
1268    if_: syn::Token![if],
1269    star: syn::Token![*],
1270    ident: &'a syn::Ident,
1271    equals: syn::Token![==],
1272    expr: &'a syn::Expr,
1273}
1274
1275impl ToTokens for Condition<'_> {
1276    fn to_tokens(&self, tokens: &mut TokenStream) {
1277        self.if_.to_tokens(tokens);
1278        self.star.to_tokens(tokens);
1279        self.ident.to_tokens(tokens);
1280        self.equals.to_tokens(tokens);
1281        self.expr.to_tokens(tokens);
1282    }
1283}
1284
1285fn condition<'a>(ident: &'a syn::Ident, expr: &'a syn::Expr) -> Condition<'a> {
1286    Condition {
1287        if_: <syn::Token![if]>::default(),
1288        star: <syn::Token![*]>::default(),
1289        ident,
1290        equals: <syn::Token![==]>::default(),
1291        expr,
1292    }
1293}
1294
1295fn ref_pattern(ident: &syn::Ident) -> syn::Pat {
1296    syn::Pat::Ident(syn::PatIdent {
1297        attrs: Vec::new(),
1298        by_ref: Some(<syn::Token![ref]>::default()),
1299        mutability: None,
1300        ident: ident.clone(),
1301        subpat: None,
1302    })
1303}
1304
1305fn output_arm<'a>(
1306    pat: Option<&'a syn::Pat>,
1307    name: &'a syn::Expr,
1308    binding: &'a syn::Ident,
1309) -> OutputArm<'a> {
1310    if let Some(pat) = pat {
1311        return OutputArm {
1312            pat: pat.clone(),
1313            cond: None,
1314        };
1315    }
1316
1317    if let Some(pat) = expr_to_pat(name) {
1318        return OutputArm { pat, cond: None };
1319    }
1320
1321    OutputArm {
1322        pat: ref_pattern(binding),
1323        cond: Some(condition(binding, name)),
1324    }
1325}
1326
1327fn expr_to_pat(expr: &syn::Expr) -> Option<syn::Pat> {
1328    match expr {
1329        syn::Expr::Lit(lit) => {
1330            let pat = syn::Pat::Lit(syn::PatLit {
1331                attrs: Vec::new(),
1332                lit: lit.lit.clone(),
1333            });
1334
1335            Some(pat)
1336        }
1337        syn::Expr::Array(expr) => {
1338            let mut elems = Punctuated::new();
1339
1340            for e in &expr.elems {
1341                elems.push(expr_to_pat(e)?);
1342            }
1343
1344            Some(syn::Pat::Slice(syn::PatSlice {
1345                attrs: Vec::new(),
1346                bracket_token: expr.bracket_token,
1347                elems,
1348            }))
1349        }
1350        _ => None,
1351    }
1352}
1353
1354struct OutputArm<'a> {
1355    pat: syn::Pat,
1356    cond: Option<Condition<'a>>,
1357}
1358
1359impl ToTokens for OutputArm<'_> {
1360    fn to_tokens(&self, tokens: &mut TokenStream) {
1361        self.pat.to_tokens(tokens);
1362
1363        if let Some(cond) = &self.cond {
1364            cond.to_tokens(tokens);
1365        }
1366    }
1367}