musli_macros/
en.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::punctuated::Punctuated;
4use syn::Token;
5
6use crate::internals::attr::{EnumTagging, Packing};
7use crate::internals::build::{Body, Build, BuildData, Enum, Variant};
8use crate::internals::tokens::Tokens;
9use crate::internals::Result;
10
11struct Ctxt<'a> {
12    ctx_var: &'a syn::Ident,
13    encoder_var: &'a syn::Ident,
14    trace: bool,
15}
16
17pub(crate) fn expand_insert_entry(e: Build<'_>) -> Result<TokenStream> {
18    e.validate_encode()?;
19    e.cx.reset();
20
21    let type_ident = &e.input.ident;
22
23    let encoder_var = e.cx.ident("encoder");
24    let ctx_var = e.cx.ident("ctx");
25    let e_param = e.cx.type_with_span("E", Span::call_site());
26
27    let cx = Ctxt {
28        ctx_var: &ctx_var,
29        encoder_var: &encoder_var,
30        trace: true,
31    };
32
33    let Tokens {
34        encode_t,
35        encoder_t,
36        result,
37        ..
38    } = e.tokens;
39
40    let body = match &e.data {
41        BuildData::Struct(st) => encode_map(&cx, &e, st)?,
42        BuildData::Enum(en) => encode_enum(&cx, &e, en)?,
43    };
44
45    if e.cx.has_errors() {
46        return Err(());
47    }
48
49    let mut impl_generics = e.input.generics.clone();
50
51    if !e.bounds.is_empty() {
52        let where_clause = impl_generics.make_where_clause();
53
54        where_clause
55            .predicates
56            .extend(e.bounds.iter().map(|(_, v)| v.clone()));
57    }
58
59    let (impl_generics, _, where_clause) = impl_generics.split_for_impl();
60    let (_, type_generics, _) = e.input.generics.split_for_impl();
61
62    let mut attributes = Vec::<syn::Attribute>::new();
63
64    if cfg!(not(feature = "verbose")) {
65        attributes.push(syn::parse_quote!(#[allow(clippy::just_underscores_and_digits)]));
66    }
67
68    let mode_ident = e.expansion.mode_path(e.tokens).as_path();
69
70    Ok(quote! {
71        const _: () = {
72            #[automatically_derived]
73            #(#attributes)*
74            impl #impl_generics #encode_t<#mode_ident> for #type_ident #type_generics #where_clause {
75                #[inline]
76                fn encode<#e_param>(&self, #ctx_var: &#e_param::Cx, #encoder_var: #e_param) -> #result<<#e_param as #encoder_t>::Ok, <#e_param as #encoder_t>::Error>
77                where
78                    #e_param: #encoder_t<Mode = #mode_ident>,
79                {
80                    #body
81                }
82            }
83        };
84    })
85}
86
87/// Encode a struct.
88fn encode_map(cx: &Ctxt<'_>, b: &Build<'_>, st: &Body<'_>) -> Result<TokenStream> {
89    let Ctxt {
90        ctx_var,
91        encoder_var,
92        ..
93    } = *cx;
94
95    let Tokens {
96        context_t,
97        encoder_t,
98        result_ok,
99        ..
100    } = b.tokens;
101
102    let pack_var = b.cx.ident("pack");
103    let output_var = b.cx.ident("output");
104
105    let (encoders, tests) = insert_fields(cx, b, st, &pack_var)?;
106
107    let type_name = &st.name;
108
109    let enter = cx
110        .trace
111        .then(|| quote!(#context_t::enter_struct(#ctx_var, #type_name);));
112    let leave = cx
113        .trace
114        .then(|| quote!(#context_t::leave_struct(#ctx_var);));
115
116    let encode;
117
118    match st.packing {
119        Packing::Transparent => {
120            let f = &st.unskipped_fields[0];
121
122            let access = &f.self_access;
123            let encode_path = &f.encode_path.1;
124
125            encode = quote! {{
126                #enter
127                let #output_var = #encode_path(#access, #ctx_var, #encoder_var)?;
128                #leave
129                #output_var
130            }};
131        }
132        Packing::Tagged => {
133            let decls = tests.iter().map(|t| &t.decl);
134            let (build_hint, hint) = length_test(st.unskipped_fields.len(), &tests).build(b);
135
136            encode = quote! {{
137                #enter
138                #(#decls)*
139                #build_hint
140
141                let #output_var = #encoder_t::encode_map_fn(#encoder_var, &#hint, move |#encoder_var| {
142                    #(#encoders)*
143                    #result_ok(())
144                })?;
145                #leave
146                #output_var
147            }};
148        }
149        Packing::Packed => {
150            let decls = tests.iter().map(|t| &t.decl);
151
152            encode = quote! {{
153                #enter
154                let #output_var = #encoder_t::encode_pack_fn(#encoder_var, move |#pack_var| {
155                    #(#decls)*
156                    #(#encoders)*
157                    #result_ok(())
158                })?;
159                #leave
160                #output_var
161            }};
162        }
163    }
164
165    Ok(quote!(#result_ok(#encode)))
166}
167
168struct FieldTest<'st> {
169    decl: syn::Stmt,
170    var: &'st syn::Ident,
171}
172
173fn insert_fields<'st>(
174    cx: &Ctxt<'_>,
175    b: &Build<'_>,
176    st: &'st Body<'_>,
177    pack_var: &syn::Ident,
178) -> Result<(Vec<TokenStream>, Vec<FieldTest<'st>>)> {
179    let Ctxt {
180        ctx_var,
181        encoder_var,
182        ..
183    } = *cx;
184
185    let Tokens {
186        context_t,
187        sequence_encoder_t,
188        result_ok,
189
190        map_encoder_t,
191        map_entry_encoder_t,
192        ..
193    } = b.tokens;
194
195    let encode_t_encode = &b.encode_t_encode;
196
197    let sequence_decoder_next_var = b.cx.ident("sequence_decoder_next");
198    let pair_encoder_var = b.cx.ident("pair_encoder");
199    let field_encoder_var = b.cx.ident("field_encoder");
200    let value_encoder_var = b.cx.ident("value_encoder");
201    let field_name_static = b.cx.ident("FIELD_NAME");
202
203    let mut encoders = Vec::with_capacity(st.all_fields.len());
204    let mut tests = Vec::with_capacity(st.all_fields.len());
205
206    for f in &st.unskipped_fields {
207        let encode_path = &f.encode_path.1;
208        let access = &f.self_access;
209        let name = &f.name;
210        let name_type = st.name_local_type();
211
212        let mut encode;
213
214        let enter = match &f.member {
215            syn::Member::Named(ident) => {
216                let field_name = syn::LitStr::new(&ident.to_string(), ident.span());
217
218                cx.trace.then(|| {
219                    let name = st.name_format(name);
220                    quote!(#context_t::enter_named_field(#ctx_var, #field_name, #name);)
221                })
222            }
223            syn::Member::Unnamed(index) => {
224                let index = index.index;
225                cx.trace.then(|| {
226                    let name = st.name_format(name);
227                    quote!(#context_t::enter_unnamed_field(#ctx_var, #index, #name);)
228                })
229            }
230        };
231
232        let leave = cx.trace.then(|| quote!(#context_t::leave_field(#ctx_var);));
233
234        match f.packing {
235            Packing::Tagged | Packing::Transparent => {
236                encode = quote! {
237                    #enter
238
239                    #map_encoder_t::encode_entry_fn(#encoder_var, move |#pair_encoder_var| {
240                        static #field_name_static: #name_type = #name;
241                        let #field_encoder_var = #map_entry_encoder_t::encode_key(#pair_encoder_var)?;
242                        #encode_t_encode(&#field_name_static, #ctx_var, #field_encoder_var)?;
243                        let #value_encoder_var = #map_entry_encoder_t::encode_value(#pair_encoder_var)?;
244                        #encode_path(#access, #ctx_var, #value_encoder_var)?;
245                        #result_ok(())
246                    })?;
247
248                    #leave
249                };
250            }
251            Packing::Packed => {
252                encode = quote! {
253                    #enter
254                    let #sequence_decoder_next_var = #sequence_encoder_t::encode_next(#pack_var)?;
255                    #encode_path(#access, #ctx_var, #sequence_decoder_next_var)?;
256                    #leave
257                };
258            }
259        };
260
261        if let Some((_, skip_encoding_if_path)) = f.skip_encoding_if.as_ref() {
262            let var = &f.var;
263
264            let decl = syn::parse_quote! {
265                let #var = !#skip_encoding_if_path(#access);
266            };
267
268            encode = quote! {
269                if #var {
270                    #encode
271                }
272            };
273
274            tests.push(FieldTest { decl, var })
275        }
276
277        encoders.push(encode);
278    }
279
280    Ok((encoders, tests))
281}
282
283/// Encode an internally tagged enum.
284fn encode_enum(cx: &Ctxt<'_>, b: &Build<'_>, en: &Enum<'_>) -> Result<TokenStream> {
285    let Ctxt { ctx_var, .. } = *cx;
286
287    let Tokens {
288        context_t,
289        result_ok,
290        result_err,
291        ..
292    } = b.tokens;
293
294    let type_name = en.name;
295
296    if let Some(&(span, Packing::Transparent)) = en.packing_span {
297        b.encode_transparent_enum_diagnostics(span);
298        return Err(());
299    }
300
301    let mut variants = Vec::with_capacity(en.variants.len());
302
303    for v in &en.variants {
304        let Ok((pattern, encode)) = encode_variant(cx, b, en, v) else {
305            continue;
306        };
307
308        variants.push(quote!(#pattern => #encode));
309    }
310
311    // Special case: uninhabitable types.
312    Ok(if variants.is_empty() {
313        quote!(#result_err(#context_t::uninhabitable(#ctx_var, #type_name)))
314    } else {
315        quote!(#result_ok(match self { #(#variants),* }))
316    })
317}
318
319/// Setup encoding for a single variant. that is externally tagged.
320fn encode_variant(
321    cx: &Ctxt<'_>,
322    b: &Build<'_>,
323    en: &Enum<'_>,
324    v: &Variant<'_>,
325) -> Result<(syn::PatStruct, TokenStream)> {
326    let pack_var = b.cx.ident("pack");
327
328    let (encoders, tests) = insert_fields(cx, b, &v.st, &pack_var)?;
329
330    let Ctxt {
331        ctx_var,
332        encoder_var,
333        ..
334    } = *cx;
335
336    let Tokens {
337        context_t,
338        encoder_t,
339        result_ok,
340        map_encoder_t,
341        map_entry_encoder_t,
342        variant_encoder_t,
343        map_hint,
344        ..
345    } = b.tokens;
346
347    let content_static = b.cx.ident("CONTENT");
348    let hint = b.cx.ident("STRUCT_HINT");
349    let name_static = b.cx.ident("NAME");
350    let tag_encoder = b.cx.ident("tag_encoder");
351    let tag_static = b.cx.ident("TAG");
352    let variant_encoder = b.cx.ident("variant_encoder");
353
354    let type_name = v.st.name;
355
356    let mut encode;
357
358    match en.enum_tagging {
359        EnumTagging::Empty => {
360            let static_type = en.static_type();
361            let encode_t_encode = &b.encode_t_encode;
362            let name = &v.name;
363
364            encode = quote! {{
365                static #name_static: #static_type = #name;
366                #encode_t_encode(&#name_static, #ctx_var, #encoder_var)?
367            }};
368        }
369        EnumTagging::Default => {
370            match v.st.packing {
371                Packing::Transparent => {
372                    let f = &v.st.unskipped_fields[0];
373
374                    let encode_path = &f.encode_path.1;
375                    let var = &f.self_access;
376                    encode = quote!(#encode_path(#var, #ctx_var, #encoder_var)?);
377                }
378                Packing::Packed => {
379                    let decls = tests.iter().map(|t| &t.decl);
380
381                    encode = quote! {{
382                        #encoder_t::encode_pack_fn(#encoder_var, move |#pack_var| {
383                            #(#decls)*
384                            #(#encoders)*
385                            #result_ok(())
386                        })?
387                    }};
388                }
389                Packing::Tagged => {
390                    let decls = tests.iter().map(|t| &t.decl);
391                    let (build_hint, hint) =
392                        length_test(v.st.unskipped_fields.len(), &tests).build(b);
393
394                    encode = quote! {{
395                        #build_hint
396
397                        #encoder_t::encode_map_fn(#encoder_var, &#hint, move |#encoder_var| {
398                            #(#decls)*
399                            #(#encoders)*
400                            #result_ok(())
401                        })?
402                    }};
403                }
404            }
405
406            if let Packing::Tagged = en.enum_packing {
407                let encode_t_encode = &b.encode_t_encode;
408                let name = &v.name;
409                let static_type = en.static_type();
410
411                encode = quote! {{
412                    #encoder_t::encode_variant_fn(#encoder_var, move |#variant_encoder| {
413                        let #tag_encoder = #variant_encoder_t::encode_tag(#variant_encoder)?;
414                        static #name_static: #static_type = #name;
415
416                        #encode_t_encode(&#name_static, #ctx_var, #tag_encoder)?;
417
418                        let #encoder_var = #variant_encoder_t::encode_data(#variant_encoder)?;
419                        #encode;
420                        #result_ok(())
421                    })?
422                }};
423            }
424        }
425        EnumTagging::Internal { tag } => {
426            let name = &v.name;
427
428            let static_type = en.static_type();
429
430            let decls = tests.iter().map(|t| &t.decl);
431            let mut len = length_test(v.st.unskipped_fields.len(), &tests);
432
433            // Add one for the tag field.
434            len.expressions.push(quote!(1));
435
436            let (build_hint, hint) = len.build(b);
437
438            encode = quote! {{
439                #build_hint
440
441                #encoder_t::encode_map_fn(#encoder_var, &#hint, move |#encoder_var| {
442                    static #tag_static: #static_type = #tag;
443                    static #name_static: #static_type = #name;
444                    #map_encoder_t::insert_entry(#encoder_var, #tag_static, #name_static)?;
445                    #(#decls)*
446                    #(#encoders)*
447                    #result_ok(())
448                })?
449            }};
450        }
451        EnumTagging::Adjacent { tag, content } => {
452            let encode_t_encode = &b.encode_t_encode;
453
454            let name = &v.name;
455            let static_type = en.static_type();
456
457            let decls = tests.iter().map(|t| &t.decl);
458
459            let (build_hint, inner_hint) =
460                length_test(v.st.unskipped_fields.len(), &tests).build(b);
461            let struct_encoder = b.cx.ident("struct_encoder");
462            let content_struct = b.cx.ident("content_struct");
463            let pair = b.cx.ident("pair");
464            let content_tag = b.cx.ident("content_tag");
465
466            encode = quote! {{
467                static #hint: #map_hint = #map_hint::with_size(2);
468                #build_hint
469
470                #encoder_t::encode_map_fn(#encoder_var, &#hint, move |#struct_encoder| {
471                    static #tag_static: #static_type = #tag;
472                    static #name_static: #static_type = #name;
473                    static #content_static: #static_type = #content;
474
475                    #map_encoder_t::insert_entry(#struct_encoder, #tag_static, #name_static)?;
476
477                    #map_encoder_t::encode_entry_fn(#struct_encoder, move |#pair| {
478                        let #content_tag = #map_entry_encoder_t::encode_key(#pair)?;
479                        #encode_t_encode(&#content_static, #ctx_var, #content_tag)?;
480
481                        let #content_struct = #map_entry_encoder_t::encode_value(#pair)?;
482
483                        #encoder_t::encode_map_fn(#content_struct, &#inner_hint, move |#encoder_var| {
484                            #(#decls)*
485                            #(#encoders)*
486                            #result_ok(())
487                        })?;
488
489                        #result_ok(())
490                    })?;
491
492                    #result_ok(())
493                })?
494            }};
495        }
496    }
497
498    let pattern = syn::PatStruct {
499        attrs: Vec::new(),
500        qself: None,
501        path: v.st.path.clone(),
502        brace_token: syn::token::Brace::default(),
503        fields: v.patterns.clone(),
504        rest: None,
505    };
506
507    if cx.trace {
508        let output_var = b.cx.ident("output");
509
510        let (decl, name) = en.name_format(&name_static, &v.name);
511        let enter = quote!(#context_t::enter_variant(#ctx_var, #type_name, #name));
512        let leave = quote!(#context_t::leave_variant(#ctx_var));
513
514        encode = quote! {{
515            #decl
516            #enter;
517            let #output_var = #encode;
518            #leave;
519            #output_var
520        }};
521    }
522
523    Ok((pattern, encode))
524}
525
526struct LengthTest {
527    kind: LengthTestKind,
528    expressions: Punctuated<TokenStream, Token![+]>,
529}
530
531impl LengthTest {
532    fn build(&self, b: &Build<'_>) -> (syn::Stmt, syn::Ident) {
533        let Tokens { map_hint, .. } = b.tokens;
534
535        match self.kind {
536            LengthTestKind::Static => {
537                let hint = b.cx.ident("HINT");
538                let len = &self.expressions;
539                let item = syn::parse_quote!(static #hint: #map_hint = #map_hint::with_size(#len););
540                (item, hint)
541            }
542            LengthTestKind::Dynamic => {
543                let hint = b.cx.ident("hint");
544                let len = &self.expressions;
545                let item = syn::parse_quote!(let #hint: #map_hint = #map_hint::with_size(#len););
546                (item, hint)
547            }
548        }
549    }
550}
551
/// How the length of a [`LengthTest`] is determined.
enum LengthTestKind {
    /// No field test takes part in the length, so the hint is declared as a
    /// `static` item.
    Static,
    /// At least one field test takes part in the length, so the hint is
    /// declared as a `let` binding.
    Dynamic,
}
556
557fn length_test(count: usize, tests: &[FieldTest<'_>]) -> LengthTest {
558    let mut kind = LengthTestKind::Static;
559
560    let mut expressions = Punctuated::<_, Token![+]>::new();
561    let count = count.saturating_sub(tests.len());
562    expressions.push(quote!(#count));
563
564    for FieldTest { var, .. } in tests {
565        kind = LengthTestKind::Dynamic;
566        expressions.push(quote!(if #var { 1 } else { 0 }))
567    }
568
569    LengthTest { kind, expressions }
570}