base64/engine/general_purpose/
decode.rs
1use crate::{
2 engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode},
3 DecodeError, DecodeSliceError, PAD_BYTE,
4};
5
/// Decode-length estimate used by the general-purpose engine.
///
/// It is computed from the encoded length alone, before any input byte is inspected.
#[doc(hidden)]
pub struct GeneralPurposeEstimate {
    /// `encoded_len % 4`: how many bytes of the encoded input fall outside complete quads.
    rem: usize,
    /// 3 bytes for every started 4-byte quad (a trailing partial quad counts as a whole one),
    /// so the value is never smaller than what the quads can decode to.
    conservative_decoded_len: usize,
}

impl GeneralPurposeEstimate {
    /// Builds the estimate for an encoded input that is `encoded_len` bytes long.
    ///
    /// This cannot overflow: even for `usize::MAX` the result is
    /// `(usize::MAX / 4 + 1) * 3`, which is below `usize::MAX`.
    pub(crate) fn new(encoded_len: usize) -> Self {
        let full_quads = encoded_len / 4;
        let rem = encoded_len % 4;
        let started_quads = if rem == 0 { full_quads } else { full_quads + 1 };

        Self {
            rem,
            conservative_decoded_len: started_quads * 3,
        }
    }
}
22
23impl DecodeEstimate for GeneralPurposeEstimate {
24 fn decoded_len_estimate(&self) -> usize {
25 self.conservative_decoded_len
26 }
27}
28
29#[inline]
35pub(crate) fn decode_helper(
36 input: &[u8],
37 estimate: GeneralPurposeEstimate,
38 output: &mut [u8],
39 decode_table: &[u8; 256],
40 decode_allow_trailing_bits: bool,
41 padding_mode: DecodePaddingMode,
42) -> Result<DecodeMetadata, DecodeSliceError> {
43 let input_complete_nonterminal_quads_len =
44 complete_quads_len(input, estimate.rem, output.len(), decode_table)?;
45
46 const UNROLLED_INPUT_CHUNK_SIZE: usize = 32;
47 const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3;
48
49 let input_complete_quads_after_unrolled_chunks_len =
50 input_complete_nonterminal_quads_len % UNROLLED_INPUT_CHUNK_SIZE;
51
52 let input_unrolled_loop_len =
53 input_complete_nonterminal_quads_len - input_complete_quads_after_unrolled_chunks_len;
54
55 for (chunk_index, chunk) in input[..input_unrolled_loop_len]
57 .chunks_exact(UNROLLED_INPUT_CHUNK_SIZE)
58 .enumerate()
59 {
60 let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE;
61 let chunk_output = &mut output[chunk_index * UNROLLED_OUTPUT_CHUNK_SIZE
62 ..(chunk_index + 1) * UNROLLED_OUTPUT_CHUNK_SIZE];
63
64 decode_chunk_8(
65 &chunk[0..8],
66 input_index,
67 decode_table,
68 &mut chunk_output[0..6],
69 )?;
70 decode_chunk_8(
71 &chunk[8..16],
72 input_index + 8,
73 decode_table,
74 &mut chunk_output[6..12],
75 )?;
76 decode_chunk_8(
77 &chunk[16..24],
78 input_index + 16,
79 decode_table,
80 &mut chunk_output[12..18],
81 )?;
82 decode_chunk_8(
83 &chunk[24..32],
84 input_index + 24,
85 decode_table,
86 &mut chunk_output[18..24],
87 )?;
88 }
89
90 let output_unrolled_loop_len = input_unrolled_loop_len / 4 * 3;
92 let output_complete_quad_len = input_complete_nonterminal_quads_len / 4 * 3;
93 {
94 let output_after_unroll = &mut output[output_unrolled_loop_len..output_complete_quad_len];
95
96 for (chunk_index, chunk) in input
97 [input_unrolled_loop_len..input_complete_nonterminal_quads_len]
98 .chunks_exact(4)
99 .enumerate()
100 {
101 let chunk_output = &mut output_after_unroll[chunk_index * 3..chunk_index * 3 + 3];
102
103 decode_chunk_4(
104 chunk,
105 input_unrolled_loop_len + chunk_index * 4,
106 decode_table,
107 chunk_output,
108 )?;
109 }
110 }
111
112 super::decode_suffix::decode_suffix(
113 input,
114 input_complete_nonterminal_quads_len,
115 output,
116 output_complete_quad_len,
117 decode_table,
118 decode_allow_trailing_bits,
119 padding_mode,
120 )
121}
122
123pub(crate) fn complete_quads_len(
132 input: &[u8],
133 input_len_rem: usize,
134 output_len: usize,
135 decode_table: &[u8; 256],
136) -> Result<usize, DecodeSliceError> {
137 debug_assert!(input.len() % 4 == input_len_rem);
138
139 if input_len_rem == 1 {
141 let last_byte = input[input.len() - 1];
142 if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE {
144 return Err(DecodeError::InvalidByte(input.len() - 1, last_byte).into());
145 }
146 };
147
148 let input_complete_nonterminal_quads_len = input
150 .len()
151 .saturating_sub(input_len_rem)
152 .saturating_sub((input_len_rem == 0) as usize * 4);
154 debug_assert!(
155 input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len))
156 );
157
158 if output_len < input_complete_nonterminal_quads_len / 4 * 3 {
160 return Err(DecodeSliceError::OutputSliceTooSmall);
161 };
162 Ok(input_complete_nonterminal_quads_len)
163}
164
165#[inline(always)]
174fn decode_chunk_8(
175 input: &[u8],
176 index_at_start_of_input: usize,
177 decode_table: &[u8; 256],
178 output: &mut [u8],
179) -> Result<(), DecodeError> {
180 let morsel = decode_table[usize::from(input[0])];
181 if morsel == INVALID_VALUE {
182 return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
183 }
184 let mut accum = u64::from(morsel) << 58;
185
186 let morsel = decode_table[usize::from(input[1])];
187 if morsel == INVALID_VALUE {
188 return Err(DecodeError::InvalidByte(
189 index_at_start_of_input + 1,
190 input[1],
191 ));
192 }
193 accum |= u64::from(morsel) << 52;
194
195 let morsel = decode_table[usize::from(input[2])];
196 if morsel == INVALID_VALUE {
197 return Err(DecodeError::InvalidByte(
198 index_at_start_of_input + 2,
199 input[2],
200 ));
201 }
202 accum |= u64::from(morsel) << 46;
203
204 let morsel = decode_table[usize::from(input[3])];
205 if morsel == INVALID_VALUE {
206 return Err(DecodeError::InvalidByte(
207 index_at_start_of_input + 3,
208 input[3],
209 ));
210 }
211 accum |= u64::from(morsel) << 40;
212
213 let morsel = decode_table[usize::from(input[4])];
214 if morsel == INVALID_VALUE {
215 return Err(DecodeError::InvalidByte(
216 index_at_start_of_input + 4,
217 input[4],
218 ));
219 }
220 accum |= u64::from(morsel) << 34;
221
222 let morsel = decode_table[usize::from(input[5])];
223 if morsel == INVALID_VALUE {
224 return Err(DecodeError::InvalidByte(
225 index_at_start_of_input + 5,
226 input[5],
227 ));
228 }
229 accum |= u64::from(morsel) << 28;
230
231 let morsel = decode_table[usize::from(input[6])];
232 if morsel == INVALID_VALUE {
233 return Err(DecodeError::InvalidByte(
234 index_at_start_of_input + 6,
235 input[6],
236 ));
237 }
238 accum |= u64::from(morsel) << 22;
239
240 let morsel = decode_table[usize::from(input[7])];
241 if morsel == INVALID_VALUE {
242 return Err(DecodeError::InvalidByte(
243 index_at_start_of_input + 7,
244 input[7],
245 ));
246 }
247 accum |= u64::from(morsel) << 16;
248
249 output[..6].copy_from_slice(&accum.to_be_bytes()[..6]);
250
251 Ok(())
252}
253
254#[inline(always)]
256fn decode_chunk_4(
257 input: &[u8],
258 index_at_start_of_input: usize,
259 decode_table: &[u8; 256],
260 output: &mut [u8],
261) -> Result<(), DecodeError> {
262 let morsel = decode_table[usize::from(input[0])];
263 if morsel == INVALID_VALUE {
264 return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
265 }
266 let mut accum = u32::from(morsel) << 26;
267
268 let morsel = decode_table[usize::from(input[1])];
269 if morsel == INVALID_VALUE {
270 return Err(DecodeError::InvalidByte(
271 index_at_start_of_input + 1,
272 input[1],
273 ));
274 }
275 accum |= u32::from(morsel) << 20;
276
277 let morsel = decode_table[usize::from(input[2])];
278 if morsel == INVALID_VALUE {
279 return Err(DecodeError::InvalidByte(
280 index_at_start_of_input + 2,
281 input[2],
282 ));
283 }
284 accum |= u32::from(morsel) << 14;
285
286 let morsel = decode_table[usize::from(input[3])];
287 if morsel == INVALID_VALUE {
288 return Err(DecodeError::InvalidByte(
289 index_at_start_of_input + 3,
290 input[3],
291 ));
292 }
293 accum |= u32::from(morsel) << 8;
294
295 output[..3].copy_from_slice(&accum.to_be_bytes()[..3]);
296
297 Ok(())
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    use crate::engine::general_purpose::STANDARD;

    #[test]
    fn decode_chunk_8_writes_only_6_bytes() {
        let input = b"Zm9vYmFy";
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk_8(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        // Compare arrays directly; a heap-allocated `vec!` is unnecessary (clippy::useless_vec).
        assert_eq!([b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], output);
    }

    #[test]
    fn decode_chunk_8_reports_absolute_index_of_first_invalid_byte() {
        let mut output = [0_u8; 6];

        assert_eq!(
            Err(DecodeError::InvalidByte(14, b'*')),
            decode_chunk_8(b"Zm9v*mF!", 10, &STANDARD.decode_table, &mut output)
        );
    }

    #[test]
    fn decode_chunk_4_writes_only_3_bytes() {
        let input = b"Zm9v";
        let mut output = [0_u8, 1, 2, 3];

        decode_chunk_4(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!([b'f', b'o', b'o', 3], output);
    }

    #[test]
    fn decode_chunk_4_reports_absolute_index_of_first_invalid_byte() {
        let mut output = [0_u8; 3];

        assert_eq!(
            Err(DecodeError::InvalidByte(9, b'#')),
            decode_chunk_4(b"Zm#v", 7, &STANDARD.decode_table, &mut output)
        );
    }

    #[test]
    fn complete_quads_len_excludes_final_quad() {
        let table = &STANDARD.decode_table;

        assert_eq!(Ok(0), complete_quads_len(b"", 0, 0, table));
        assert_eq!(Ok(0), complete_quads_len(b"Zm9v", 0, 0, table));
        assert_eq!(Ok(4), complete_quads_len(b"Zm9vYmFy", 0, 3, table));
        assert_eq!(Ok(4), complete_quads_len(b"Zm9vYm", 2, 3, table));
        assert_eq!(Ok(4), complete_quads_len(b"Zm9vYmF", 3, 3, table));
    }

    #[test]
    fn complete_quads_len_rejects_lone_trailing_invalid_byte_but_not_pad() {
        let table = &STANDARD.decode_table;

        assert_eq!(
            Err(DecodeSliceError::DecodeError(DecodeError::InvalidByte(
                4, b'\n'
            ))),
            complete_quads_len(b"Zm9v\n", 1, 3, table)
        );
        assert_eq!(Ok(4), complete_quads_len(b"Zm9v=", 1, 3, table));
        assert_eq!(Ok(4), complete_quads_len(b"Zm9vY", 1, 3, table));
    }

    #[test]
    fn complete_quads_len_rejects_too_small_output() {
        assert_eq!(
            Err(DecodeSliceError::OutputSliceTooSmall),
            complete_quads_len(b"Zm9vYmFy", 0, 2, &STANDARD.decode_table)
        );
    }

    #[test]
    fn estimate_short_lengths() {
        for (range, decoded_len_estimate) in [
            (0..=0, 0),
            (1..=4, 3),
            (5..=8, 6),
            (9..=12, 9),
            (13..=16, 12),
            (17..=20, 15),
        ] {
            for encoded_len in range {
                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate());
            }
        }
    }

    #[test]
    fn estimate_via_u128_inflation() {
        // Cross-check against exact ceil(len / 4) * 3 computed in a type that cannot overflow.
        (0..1000)
            .chain(usize::MAX - 1000..=usize::MAX)
            .for_each(|encoded_len| {
                let len_128 = encoded_len as u128;

                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(
                    (len_128 + 3) / 4 * 3,
                    estimate.conservative_decoded_len as u128
                );
            })
    }
}