1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::punctuated::Punctuated;
4use syn::Token;
5
6use crate::internals::attr::{EnumTagging, Packing};
7use crate::internals::build::{Body, Build, BuildData, Enum, Variant};
8use crate::internals::tokens::Tokens;
9use crate::internals::Result;
10
11struct Ctxt<'a> {
12 ctx_var: &'a syn::Ident,
13 encoder_var: &'a syn::Ident,
14 trace: bool,
15}
16
17pub(crate) fn expand_insert_entry(e: Build<'_>) -> Result<TokenStream> {
18 e.validate_encode()?;
19 e.cx.reset();
20
21 let type_ident = &e.input.ident;
22
23 let encoder_var = e.cx.ident("encoder");
24 let ctx_var = e.cx.ident("ctx");
25 let e_param = e.cx.type_with_span("E", Span::call_site());
26
27 let cx = Ctxt {
28 ctx_var: &ctx_var,
29 encoder_var: &encoder_var,
30 trace: true,
31 };
32
33 let Tokens {
34 encode_t,
35 encoder_t,
36 result,
37 ..
38 } = e.tokens;
39
40 let body = match &e.data {
41 BuildData::Struct(st) => encode_map(&cx, &e, st)?,
42 BuildData::Enum(en) => encode_enum(&cx, &e, en)?,
43 };
44
45 if e.cx.has_errors() {
46 return Err(());
47 }
48
49 let mut impl_generics = e.input.generics.clone();
50
51 if !e.bounds.is_empty() {
52 let where_clause = impl_generics.make_where_clause();
53
54 where_clause
55 .predicates
56 .extend(e.bounds.iter().map(|(_, v)| v.clone()));
57 }
58
59 let (impl_generics, _, where_clause) = impl_generics.split_for_impl();
60 let (_, type_generics, _) = e.input.generics.split_for_impl();
61
62 let mut attributes = Vec::<syn::Attribute>::new();
63
64 if cfg!(not(feature = "verbose")) {
65 attributes.push(syn::parse_quote!(#[allow(clippy::just_underscores_and_digits)]));
66 }
67
68 let mode_ident = e.expansion.mode_path(e.tokens).as_path();
69
70 Ok(quote! {
71 const _: () = {
72 #[automatically_derived]
73 #(#attributes)*
74 impl #impl_generics #encode_t<#mode_ident> for #type_ident #type_generics #where_clause {
75 #[inline]
76 fn encode<#e_param>(&self, #ctx_var: &#e_param::Cx, #encoder_var: #e_param) -> #result<<#e_param as #encoder_t>::Ok, <#e_param as #encoder_t>::Error>
77 where
78 #e_param: #encoder_t<Mode = #mode_ident>,
79 {
80 #body
81 }
82 }
83 };
84 })
85}
86
87fn encode_map(cx: &Ctxt<'_>, b: &Build<'_>, st: &Body<'_>) -> Result<TokenStream> {
89 let Ctxt {
90 ctx_var,
91 encoder_var,
92 ..
93 } = *cx;
94
95 let Tokens {
96 context_t,
97 encoder_t,
98 result_ok,
99 ..
100 } = b.tokens;
101
102 let pack_var = b.cx.ident("pack");
103 let output_var = b.cx.ident("output");
104
105 let (encoders, tests) = insert_fields(cx, b, st, &pack_var)?;
106
107 let type_name = &st.name;
108
109 let enter = cx
110 .trace
111 .then(|| quote!(#context_t::enter_struct(#ctx_var, #type_name);));
112 let leave = cx
113 .trace
114 .then(|| quote!(#context_t::leave_struct(#ctx_var);));
115
116 let encode;
117
118 match st.packing {
119 Packing::Transparent => {
120 let f = &st.unskipped_fields[0];
121
122 let access = &f.self_access;
123 let encode_path = &f.encode_path.1;
124
125 encode = quote! {{
126 #enter
127 let #output_var = #encode_path(#access, #ctx_var, #encoder_var)?;
128 #leave
129 #output_var
130 }};
131 }
132 Packing::Tagged => {
133 let decls = tests.iter().map(|t| &t.decl);
134 let (build_hint, hint) = length_test(st.unskipped_fields.len(), &tests).build(b);
135
136 encode = quote! {{
137 #enter
138 #(#decls)*
139 #build_hint
140
141 let #output_var = #encoder_t::encode_map_fn(#encoder_var, &#hint, move |#encoder_var| {
142 #(#encoders)*
143 #result_ok(())
144 })?;
145 #leave
146 #output_var
147 }};
148 }
149 Packing::Packed => {
150 let decls = tests.iter().map(|t| &t.decl);
151
152 encode = quote! {{
153 #enter
154 let #output_var = #encoder_t::encode_pack_fn(#encoder_var, move |#pack_var| {
155 #(#decls)*
156 #(#encoders)*
157 #result_ok(())
158 })?;
159 #leave
160 #output_var
161 }};
162 }
163 }
164
165 Ok(quote!(#result_ok(#encode)))
166}
167
168struct FieldTest<'st> {
169 decl: syn::Stmt,
170 var: &'st syn::Ident,
171}
172
173fn insert_fields<'st>(
174 cx: &Ctxt<'_>,
175 b: &Build<'_>,
176 st: &'st Body<'_>,
177 pack_var: &syn::Ident,
178) -> Result<(Vec<TokenStream>, Vec<FieldTest<'st>>)> {
179 let Ctxt {
180 ctx_var,
181 encoder_var,
182 ..
183 } = *cx;
184
185 let Tokens {
186 context_t,
187 sequence_encoder_t,
188 result_ok,
189
190 map_encoder_t,
191 map_entry_encoder_t,
192 ..
193 } = b.tokens;
194
195 let encode_t_encode = &b.encode_t_encode;
196
197 let sequence_decoder_next_var = b.cx.ident("sequence_decoder_next");
198 let pair_encoder_var = b.cx.ident("pair_encoder");
199 let field_encoder_var = b.cx.ident("field_encoder");
200 let value_encoder_var = b.cx.ident("value_encoder");
201 let field_name_static = b.cx.ident("FIELD_NAME");
202
203 let mut encoders = Vec::with_capacity(st.all_fields.len());
204 let mut tests = Vec::with_capacity(st.all_fields.len());
205
206 for f in &st.unskipped_fields {
207 let encode_path = &f.encode_path.1;
208 let access = &f.self_access;
209 let name = &f.name;
210 let name_type = st.name_local_type();
211
212 let mut encode;
213
214 let enter = match &f.member {
215 syn::Member::Named(ident) => {
216 let field_name = syn::LitStr::new(&ident.to_string(), ident.span());
217
218 cx.trace.then(|| {
219 let name = st.name_format(name);
220 quote!(#context_t::enter_named_field(#ctx_var, #field_name, #name);)
221 })
222 }
223 syn::Member::Unnamed(index) => {
224 let index = index.index;
225 cx.trace.then(|| {
226 let name = st.name_format(name);
227 quote!(#context_t::enter_unnamed_field(#ctx_var, #index, #name);)
228 })
229 }
230 };
231
232 let leave = cx.trace.then(|| quote!(#context_t::leave_field(#ctx_var);));
233
234 match f.packing {
235 Packing::Tagged | Packing::Transparent => {
236 encode = quote! {
237 #enter
238
239 #map_encoder_t::encode_entry_fn(#encoder_var, move |#pair_encoder_var| {
240 static #field_name_static: #name_type = #name;
241 let #field_encoder_var = #map_entry_encoder_t::encode_key(#pair_encoder_var)?;
242 #encode_t_encode(&#field_name_static, #ctx_var, #field_encoder_var)?;
243 let #value_encoder_var = #map_entry_encoder_t::encode_value(#pair_encoder_var)?;
244 #encode_path(#access, #ctx_var, #value_encoder_var)?;
245 #result_ok(())
246 })?;
247
248 #leave
249 };
250 }
251 Packing::Packed => {
252 encode = quote! {
253 #enter
254 let #sequence_decoder_next_var = #sequence_encoder_t::encode_next(#pack_var)?;
255 #encode_path(#access, #ctx_var, #sequence_decoder_next_var)?;
256 #leave
257 };
258 }
259 };
260
261 if let Some((_, skip_encoding_if_path)) = f.skip_encoding_if.as_ref() {
262 let var = &f.var;
263
264 let decl = syn::parse_quote! {
265 let #var = !#skip_encoding_if_path(#access);
266 };
267
268 encode = quote! {
269 if #var {
270 #encode
271 }
272 };
273
274 tests.push(FieldTest { decl, var })
275 }
276
277 encoders.push(encode);
278 }
279
280 Ok((encoders, tests))
281}
282
283fn encode_enum(cx: &Ctxt<'_>, b: &Build<'_>, en: &Enum<'_>) -> Result<TokenStream> {
285 let Ctxt { ctx_var, .. } = *cx;
286
287 let Tokens {
288 context_t,
289 result_ok,
290 result_err,
291 ..
292 } = b.tokens;
293
294 let type_name = en.name;
295
296 if let Some(&(span, Packing::Transparent)) = en.packing_span {
297 b.encode_transparent_enum_diagnostics(span);
298 return Err(());
299 }
300
301 let mut variants = Vec::with_capacity(en.variants.len());
302
303 for v in &en.variants {
304 let Ok((pattern, encode)) = encode_variant(cx, b, en, v) else {
305 continue;
306 };
307
308 variants.push(quote!(#pattern => #encode));
309 }
310
311 Ok(if variants.is_empty() {
313 quote!(#result_err(#context_t::uninhabitable(#ctx_var, #type_name)))
314 } else {
315 quote!(#result_ok(match self { #(#variants),* }))
316 })
317}
318
319fn encode_variant(
321 cx: &Ctxt<'_>,
322 b: &Build<'_>,
323 en: &Enum<'_>,
324 v: &Variant<'_>,
325) -> Result<(syn::PatStruct, TokenStream)> {
326 let pack_var = b.cx.ident("pack");
327
328 let (encoders, tests) = insert_fields(cx, b, &v.st, &pack_var)?;
329
330 let Ctxt {
331 ctx_var,
332 encoder_var,
333 ..
334 } = *cx;
335
336 let Tokens {
337 context_t,
338 encoder_t,
339 result_ok,
340 map_encoder_t,
341 map_entry_encoder_t,
342 variant_encoder_t,
343 map_hint,
344 ..
345 } = b.tokens;
346
347 let content_static = b.cx.ident("CONTENT");
348 let hint = b.cx.ident("STRUCT_HINT");
349 let name_static = b.cx.ident("NAME");
350 let tag_encoder = b.cx.ident("tag_encoder");
351 let tag_static = b.cx.ident("TAG");
352 let variant_encoder = b.cx.ident("variant_encoder");
353
354 let type_name = v.st.name;
355
356 let mut encode;
357
358 match en.enum_tagging {
359 EnumTagging::Empty => {
360 let static_type = en.static_type();
361 let encode_t_encode = &b.encode_t_encode;
362 let name = &v.name;
363
364 encode = quote! {{
365 static #name_static: #static_type = #name;
366 #encode_t_encode(&#name_static, #ctx_var, #encoder_var)?
367 }};
368 }
369 EnumTagging::Default => {
370 match v.st.packing {
371 Packing::Transparent => {
372 let f = &v.st.unskipped_fields[0];
373
374 let encode_path = &f.encode_path.1;
375 let var = &f.self_access;
376 encode = quote!(#encode_path(#var, #ctx_var, #encoder_var)?);
377 }
378 Packing::Packed => {
379 let decls = tests.iter().map(|t| &t.decl);
380
381 encode = quote! {{
382 #encoder_t::encode_pack_fn(#encoder_var, move |#pack_var| {
383 #(#decls)*
384 #(#encoders)*
385 #result_ok(())
386 })?
387 }};
388 }
389 Packing::Tagged => {
390 let decls = tests.iter().map(|t| &t.decl);
391 let (build_hint, hint) =
392 length_test(v.st.unskipped_fields.len(), &tests).build(b);
393
394 encode = quote! {{
395 #build_hint
396
397 #encoder_t::encode_map_fn(#encoder_var, &#hint, move |#encoder_var| {
398 #(#decls)*
399 #(#encoders)*
400 #result_ok(())
401 })?
402 }};
403 }
404 }
405
406 if let Packing::Tagged = en.enum_packing {
407 let encode_t_encode = &b.encode_t_encode;
408 let name = &v.name;
409 let static_type = en.static_type();
410
411 encode = quote! {{
412 #encoder_t::encode_variant_fn(#encoder_var, move |#variant_encoder| {
413 let #tag_encoder = #variant_encoder_t::encode_tag(#variant_encoder)?;
414 static #name_static: #static_type = #name;
415
416 #encode_t_encode(&#name_static, #ctx_var, #tag_encoder)?;
417
418 let #encoder_var = #variant_encoder_t::encode_data(#variant_encoder)?;
419 #encode;
420 #result_ok(())
421 })?
422 }};
423 }
424 }
425 EnumTagging::Internal { tag } => {
426 let name = &v.name;
427
428 let static_type = en.static_type();
429
430 let decls = tests.iter().map(|t| &t.decl);
431 let mut len = length_test(v.st.unskipped_fields.len(), &tests);
432
433 len.expressions.push(quote!(1));
435
436 let (build_hint, hint) = len.build(b);
437
438 encode = quote! {{
439 #build_hint
440
441 #encoder_t::encode_map_fn(#encoder_var, &#hint, move |#encoder_var| {
442 static #tag_static: #static_type = #tag;
443 static #name_static: #static_type = #name;
444 #map_encoder_t::insert_entry(#encoder_var, #tag_static, #name_static)?;
445 #(#decls)*
446 #(#encoders)*
447 #result_ok(())
448 })?
449 }};
450 }
451 EnumTagging::Adjacent { tag, content } => {
452 let encode_t_encode = &b.encode_t_encode;
453
454 let name = &v.name;
455 let static_type = en.static_type();
456
457 let decls = tests.iter().map(|t| &t.decl);
458
459 let (build_hint, inner_hint) =
460 length_test(v.st.unskipped_fields.len(), &tests).build(b);
461 let struct_encoder = b.cx.ident("struct_encoder");
462 let content_struct = b.cx.ident("content_struct");
463 let pair = b.cx.ident("pair");
464 let content_tag = b.cx.ident("content_tag");
465
466 encode = quote! {{
467 static #hint: #map_hint = #map_hint::with_size(2);
468 #build_hint
469
470 #encoder_t::encode_map_fn(#encoder_var, &#hint, move |#struct_encoder| {
471 static #tag_static: #static_type = #tag;
472 static #name_static: #static_type = #name;
473 static #content_static: #static_type = #content;
474
475 #map_encoder_t::insert_entry(#struct_encoder, #tag_static, #name_static)?;
476
477 #map_encoder_t::encode_entry_fn(#struct_encoder, move |#pair| {
478 let #content_tag = #map_entry_encoder_t::encode_key(#pair)?;
479 #encode_t_encode(&#content_static, #ctx_var, #content_tag)?;
480
481 let #content_struct = #map_entry_encoder_t::encode_value(#pair)?;
482
483 #encoder_t::encode_map_fn(#content_struct, &#inner_hint, move |#encoder_var| {
484 #(#decls)*
485 #(#encoders)*
486 #result_ok(())
487 })?;
488
489 #result_ok(())
490 })?;
491
492 #result_ok(())
493 })?
494 }};
495 }
496 }
497
498 let pattern = syn::PatStruct {
499 attrs: Vec::new(),
500 qself: None,
501 path: v.st.path.clone(),
502 brace_token: syn::token::Brace::default(),
503 fields: v.patterns.clone(),
504 rest: None,
505 };
506
507 if cx.trace {
508 let output_var = b.cx.ident("output");
509
510 let (decl, name) = en.name_format(&name_static, &v.name);
511 let enter = quote!(#context_t::enter_variant(#ctx_var, #type_name, #name));
512 let leave = quote!(#context_t::leave_variant(#ctx_var));
513
514 encode = quote! {{
515 #decl
516 #enter;
517 let #output_var = #encode;
518 #leave;
519 #output_var
520 }};
521 }
522
523 Ok((pattern, encode))
524}
525
526struct LengthTest {
527 kind: LengthTestKind,
528 expressions: Punctuated<TokenStream, Token![+]>,
529}
530
531impl LengthTest {
532 fn build(&self, b: &Build<'_>) -> (syn::Stmt, syn::Ident) {
533 let Tokens { map_hint, .. } = b.tokens;
534
535 match self.kind {
536 LengthTestKind::Static => {
537 let hint = b.cx.ident("HINT");
538 let len = &self.expressions;
539 let item = syn::parse_quote!(static #hint: #map_hint = #map_hint::with_size(#len););
540 (item, hint)
541 }
542 LengthTestKind::Dynamic => {
543 let hint = b.cx.ident("hint");
544 let len = &self.expressions;
545 let item = syn::parse_quote!(let #hint: #map_hint = #map_hint::with_size(#len););
546 (item, hint)
547 }
548 }
549 }
550}
551
/// How a [`LengthTest`] is materialized in generated code.
enum LengthTestKind {
    /// Every term is constant, so the hint is emitted as a `static` item.
    Static,
    /// At least one term depends on a runtime test, so the hint is a `let` binding.
    Dynamic,
}
556
557fn length_test(count: usize, tests: &[FieldTest<'_>]) -> LengthTest {
558 let mut kind = LengthTestKind::Static;
559
560 let mut expressions = Punctuated::<_, Token![+]>::new();
561 let count = count.saturating_sub(tests.len());
562 expressions.push(quote!(#count));
563
564 for FieldTest { var, .. } in tests {
565 kind = LengthTestKind::Dynamic;
566 expressions.push(quote!(if #var { 1 } else { 0 }))
567 }
568
569 LengthTest { kind, expressions }
570}