musli/
reader.rs

1//! Trait governing how to read bytes.
2
3use core::fmt;
4use core::marker;
5use core::mem::MaybeUninit;
6use core::ops::Range;
7use core::ptr;
8use core::slice;
9
10use crate::de::UnsizedVisitor;
11use crate::Context;
12
13mod sealed {
14    use super::{Limit, Reader};
15
16    pub trait Sealed {}
17
18    impl Sealed for &[u8] {}
19    impl Sealed for super::SliceReader<'_> {}
20    impl<'de, R> Sealed for Limit<R> where R: Reader<'de> {}
21    impl<'de, R> Sealed for &mut R where R: ?Sized + Reader<'de> {}
22}
23
24/// Coerce a type into a [`Reader`].
25pub trait IntoReader<'de>: self::sealed::Sealed {
26    /// The reader type.
27    type Reader: Reader<'de>;
28
29    /// Convert the type into a reader.
30    fn into_reader(self) -> Self::Reader;
31}
32
33/// Trait governing how a source of bytes is read.
34///
35/// This requires the reader to be able to hand out contiguous references to the
36/// byte source through [`Reader::read_bytes`].
37pub trait Reader<'de>: self::sealed::Sealed {
38    /// Type borrowed from self.
39    ///
40    /// Why oh why would we want to do this over having a simple `&'this mut T`?
41    ///
42    /// We want to avoid recursive types, which will blow up the compiler. And
43    /// the above is a typical example of when that can go wrong. This ensures
44    /// that each call to `borrow_mut` dereferences the [`Reader`] at each step
45    /// to avoid constructing a large muted type, like `&mut &mut &mut
46    /// SliceReader<'de>`.
47    type Mut<'this>: Reader<'de>
48    where
49        Self: 'this;
50
51    /// Type that can be cloned from the reader.
52    type TryClone: Reader<'de>;
53
54    /// Borrow the current reader.
55    fn borrow_mut(&mut self) -> Self::Mut<'_>;
56
57    /// Try to clone the reader.
58    fn try_clone(&self) -> Option<Self::TryClone>;
59
60    /// Test if the reader is at end of input.
61    fn is_eof(&mut self) -> bool;
62
63    /// Skip over the given number of bytes.
64    fn skip<C>(&mut self, cx: C, n: usize) -> Result<(), C::Error>
65    where
66        C: Context;
67
68    /// Peek the next value.
69    fn peek(&mut self) -> Option<u8>;
70
71    /// Read a slice into the given buffer.
72    #[inline]
73    fn read<C>(&mut self, cx: C, buf: &mut [u8]) -> Result<(), C::Error>
74    where
75        C: Context,
76    {
77        struct Visitor<'a>(&'a mut [u8]);
78
79        #[crate::de::unsized_visitor(crate)]
80        impl<C> UnsizedVisitor<'_, C, [u8]> for Visitor<'_>
81        where
82            C: Context,
83        {
84            type Ok = ();
85
86            #[inline]
87            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88                write!(f, "bytes")
89            }
90
91            #[inline]
92            fn visit_ref(self, _: C, bytes: &[u8]) -> Result<Self::Ok, C::Error> {
93                self.0.copy_from_slice(bytes);
94                Ok(())
95            }
96        }
97
98        self.read_bytes(cx, buf.len(), Visitor(buf))
99    }
100
101    /// Read a slice out of the current reader.
102    fn read_bytes<C, V>(&mut self, cx: C, n: usize, visitor: V) -> Result<V::Ok, V::Error>
103    where
104        C: Context,
105        V: UnsizedVisitor<'de, C, [u8], Error = C::Error, Allocator = C::Allocator>;
106
107    /// Read into the given buffer which might not have been initialized.
108    ///
109    /// # Safety
110    ///
111    /// The caller must ensure that the buffer points to valid memory of length
112    /// `len`.
113    unsafe fn read_bytes_uninit<C>(
114        &mut self,
115        cx: C,
116        ptr: *mut u8,
117        len: usize,
118    ) -> Result<(), C::Error>
119    where
120        C: Context;
121
122    /// Read a single byte.
123    #[inline]
124    fn read_byte<C>(&mut self, cx: C) -> Result<u8, C::Error>
125    where
126        C: Context,
127    {
128        let [byte] = self.read_array::<C, 1>(cx)?;
129        Ok(byte)
130    }
131
132    /// Read an array out of the current reader.
133    #[inline]
134    fn read_array<C, const N: usize>(&mut self, cx: C) -> Result<[u8; N], C::Error>
135    where
136        C: Context,
137    {
138        struct Visitor<const N: usize>([u8; N]);
139
140        #[crate::de::unsized_visitor(crate)]
141        impl<const N: usize, C> UnsizedVisitor<'_, C, [u8]> for Visitor<N>
142        where
143            C: Context,
144        {
145            type Ok = [u8; N];
146
147            #[inline]
148            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149                write!(f, "bytes")
150            }
151
152            #[inline]
153            fn visit_ref(mut self, cx: C, bytes: &[u8]) -> Result<Self::Ok, C::Error> {
154                self.0.copy_from_slice(bytes);
155                cx.advance(bytes.len());
156                Ok(self.0)
157            }
158        }
159
160        self.read_bytes(cx, N, Visitor([0u8; N]))
161    }
162
163    /// Keep an accurate record of the position within the reader.
164    #[inline]
165    fn limit(self, limit: usize) -> Limit<Self>
166    where
167        Self: Sized,
168    {
169        Limit {
170            remaining: limit,
171            reader: self,
172        }
173    }
174}
175
176impl<'de> IntoReader<'de> for &'de [u8] {
177    type Reader = &'de [u8];
178
179    #[inline]
180    fn into_reader(self) -> Self::Reader {
181        self
182    }
183}
184
185impl<'de> Reader<'de> for &'de [u8] {
186    type Mut<'this>
187        = &'this mut &'de [u8]
188    where
189        Self: 'this;
190
191    type TryClone = &'de [u8];
192
193    #[inline]
194    fn borrow_mut(&mut self) -> Self::Mut<'_> {
195        self
196    }
197
198    #[inline]
199    fn try_clone(&self) -> Option<Self::TryClone> {
200        Some(self)
201    }
202
203    #[inline]
204    fn is_eof(&mut self) -> bool {
205        self.is_empty()
206    }
207
208    #[inline]
209    fn skip<C>(&mut self, cx: C, n: usize) -> Result<(), C::Error>
210    where
211        C: Context,
212    {
213        if self.len() < n {
214            return Err(cx.message(SliceUnderflow {
215                n,
216                remaining: self.len(),
217            }));
218        }
219
220        let (_, tail) = self.split_at(n);
221        *self = tail;
222        cx.advance(n);
223        Ok(())
224    }
225
226    #[inline]
227    fn read<C>(&mut self, cx: C, buf: &mut [u8]) -> Result<(), C::Error>
228    where
229        C: Context,
230    {
231        if self.len() < buf.len() {
232            return Err(cx.custom(SliceUnderflow::new(buf.len(), self.len())));
233        }
234
235        let (head, tail) = self.split_at(buf.len());
236        buf.copy_from_slice(head);
237        *self = tail;
238        cx.advance(buf.len());
239        Ok(())
240    }
241
242    #[inline]
243    fn read_bytes<C, V>(&mut self, cx: C, n: usize, visitor: V) -> Result<V::Ok, V::Error>
244    where
245        C: Context,
246        V: UnsizedVisitor<'de, C, [u8], Error = C::Error, Allocator = C::Allocator>,
247    {
248        if self.len() < n {
249            return Err(cx.custom(SliceUnderflow::new(n, self.len())));
250        }
251
252        let (head, tail) = self.split_at(n);
253        *self = tail;
254        let ok = visitor.visit_borrowed(cx, head)?;
255        cx.advance(n);
256        Ok(ok)
257    }
258
259    #[inline]
260    unsafe fn read_bytes_uninit<C>(&mut self, cx: C, ptr: *mut u8, n: usize) -> Result<(), C::Error>
261    where
262        C: Context,
263    {
264        if self.len() < n {
265            return Err(cx.custom(SliceUnderflow::new(n, self.len())));
266        }
267
268        ptr.copy_from_nonoverlapping(self.as_ptr(), n);
269        *self = &self[n..];
270        cx.advance(n);
271        Ok(())
272    }
273
274    #[inline]
275    fn read_byte<C>(&mut self, cx: C) -> Result<u8, C::Error>
276    where
277        C: Context,
278    {
279        let &[first, ref tail @ ..] = *self else {
280            return Err(cx.custom(SliceUnderflow::new(1, self.len())));
281        };
282
283        *self = tail;
284        cx.advance(1);
285        Ok(first)
286    }
287
288    #[inline]
289    fn read_array<C, const N: usize>(&mut self, cx: C) -> Result<[u8; N], C::Error>
290    where
291        C: Context,
292    {
293        if self.len() < N {
294            return Err(cx.custom(SliceUnderflow::new(N, self.len())));
295        }
296
297        cx.advance(N);
298
299        let mut array: MaybeUninit<[u8; N]> = MaybeUninit::uninit();
300
301        // SAFETY: We've checked the length of the current buffer just above.
302        // PERFORMANCE: This generates better code than `array::from_fn`, and
303        // `read_array` is performance sensitive.
304        unsafe {
305            array
306                .as_mut_ptr()
307                .cast::<u8>()
308                .copy_from_nonoverlapping(self.as_ptr(), N);
309            *self = self.get_unchecked(N..);
310            Ok(array.assume_init())
311        }
312    }
313
314    #[inline]
315    fn peek(&mut self) -> Option<u8> {
316        self.first().copied()
317    }
318}
319
/// An efficient [`Reader`] wrapper around a slice.
///
/// The unread part of the slice is tracked as a pair of raw pointers rather
/// than as a pointer and a length.
pub struct SliceReader<'de> {
    // `start` points at the next unread byte, `end` one past the last byte.
    range: Range<*const u8>,
    // Ties the raw pointers to the lifetime of the slice they were taken from.
    _marker: marker::PhantomData<&'de [u8]>,
}

// SAFETY: `SliceReader` is effectively equivalent to `&'de [u8]`.
unsafe impl Send for SliceReader<'_> {}
// SAFETY: `SliceReader` is effectively equivalent to `&'de [u8]`.
unsafe impl Sync for SliceReader<'_> {}

impl<'de> SliceReader<'de> {
    /// Construct a new instance around the specified slice.
    #[inline]
    pub fn new(slice: &'de [u8]) -> Self {
        Self {
            range: slice.as_ptr_range(),
            _marker: marker::PhantomData,
        }
    }

    /// Get the remaining contents of the reader as a slice.
    ///
    /// # Examples
    ///
    /// ```
    /// use musli::Context;
    /// use musli::reader::{Reader, SliceReader};
    ///
    /// fn process<C>(cx: C) -> Result<(), C::Error>
    /// where
    ///     C: Context
    /// {
    ///     let mut reader = SliceReader::new(&[1, 2, 3, 4]);
    ///     assert_eq!(reader.as_slice(), &[1, 2, 3, 4]);
    ///     reader.skip(cx, 2)?;
    ///     assert_eq!(reader.as_slice(), &[3, 4]);
    ///     Ok(())
    /// }
    /// ```
    #[inline]
    pub fn as_slice(&self) -> &'de [u8] {
        // SAFETY: `range` was derived from a slice which is valid for `'de`,
        // and `remaining` is the distance in bytes between its two ends.
        unsafe { slice::from_raw_parts(self.range.start, self.remaining()) }
    }

    /// Get remaining bytes in the reader.
    ///
    /// # Examples
    ///
    /// ```
    /// use musli::Context;
    /// use musli::reader::{Reader, SliceReader};
    ///
    /// fn process<C>(cx: C) -> Result<(), C::Error>
    /// where
    ///     C: Context
    /// {
    ///     let mut reader = SliceReader::new(&[1, 2, 3, 4]);
    ///     assert_eq!(reader.remaining(), 4);
    ///     reader.skip(cx, 2)?;
    ///     assert_eq!(reader.remaining(), 2);
    ///     Ok(())
    /// }
    /// ```
    #[inline]
    pub fn remaining(&self) -> usize {
        self.range.end as usize - self.range.start as usize
    }
}
388
389impl<'de> Reader<'de> for SliceReader<'de> {
390    type Mut<'this>
391        = &'this mut Self
392    where
393        Self: 'this;
394
395    type TryClone = Self;
396
397    #[inline]
398    fn borrow_mut(&mut self) -> Self::Mut<'_> {
399        self
400    }
401
402    #[inline]
403    fn try_clone(&self) -> Option<Self::TryClone> {
404        Some(Self {
405            range: self.range.clone(),
406            _marker: marker::PhantomData,
407        })
408    }
409
410    #[inline]
411    fn is_eof(&mut self) -> bool {
412        self.range.start == self.range.end
413    }
414
415    #[inline]
416    fn skip<C>(&mut self, cx: C, n: usize) -> Result<(), C::Error>
417    where
418        C: Context,
419    {
420        self.range.start = bounds_check_add(cx, &self.range, n)?;
421        cx.advance(n);
422        Ok(())
423    }
424
425    #[inline]
426    fn read_bytes<C, V>(&mut self, cx: C, n: usize, visitor: V) -> Result<V::Ok, V::Error>
427    where
428        C: Context,
429        V: UnsizedVisitor<'de, C, [u8], Error = C::Error, Allocator = C::Allocator>,
430    {
431        let outcome = bounds_check_add(cx, &self.range, n)?;
432
433        let ok = unsafe {
434            let bytes = slice::from_raw_parts(self.range.start, n);
435            self.range.start = outcome;
436            visitor.visit_borrowed(cx, bytes)?
437        };
438
439        cx.advance(n);
440        Ok(ok)
441    }
442
443    #[inline]
444    unsafe fn read_bytes_uninit<C>(&mut self, cx: C, ptr: *mut u8, n: usize) -> Result<(), C::Error>
445    where
446        C: Context,
447    {
448        let outcome = bounds_check_add(cx, &self.range, n)?;
449        ptr.copy_from_nonoverlapping(self.range.start, n);
450        self.range.start = outcome;
451        cx.advance(n);
452        Ok(())
453    }
454
455    #[inline]
456    fn peek(&mut self) -> Option<u8> {
457        if self.range.start == self.range.end {
458            return None;
459        }
460
461        // SAFETY: we've checked that the elements are in bound above.
462        unsafe { Some(ptr::read(self.range.start)) }
463    }
464
465    #[inline]
466    fn read<C>(&mut self, cx: C, buf: &mut [u8]) -> Result<(), C::Error>
467    where
468        C: Context,
469    {
470        let outcome = bounds_check_add(cx, &self.range, buf.len())?;
471
472        // SAFETY: We've checked that the updated pointer is in bounds.
473        unsafe {
474            self.range
475                .start
476                .copy_to_nonoverlapping(buf.as_mut_ptr(), buf.len());
477            self.range.start = outcome;
478        }
479
480        cx.advance(buf.len());
481        Ok(())
482    }
483}
484
485#[inline]
486fn bounds_check_add<C>(cx: C, range: &Range<*const u8>, len: usize) -> Result<*const u8, C::Error>
487where
488    C: Context,
489{
490    let outcome = range.start.wrapping_add(len);
491
492    if outcome > range.end || outcome < range.start {
493        Err(cx.message(SliceUnderflow {
494            n: len,
495            remaining: (range.end as usize).wrapping_sub(range.start as usize),
496        }))
497    } else {
498        Ok(outcome)
499    }
500}
501
/// A reader adapter which caps how many bytes may be read out of the wrapped
/// reader.
///
/// Constructed through [`Reader::limit`].
pub struct Limit<R> {
    // Number of bytes which may still be read before the limit is hit.
    remaining: usize,
    // The wrapped reader.
    reader: R,
}

impl<R> Limit<R> {
    /// Get the number of bytes which may still be read through the limited
    /// reader.
    #[inline]
    pub fn remaining(&self) -> usize {
        let Self { remaining, .. } = self;
        *remaining
    }
}
517
518impl<'de, R> Limit<R>
519where
520    R: Reader<'de>,
521{
522    #[inline]
523    fn bounds_check<C>(&mut self, cx: C, n: usize) -> Result<(), C::Error>
524    where
525        C: Context,
526    {
527        match self.remaining.checked_sub(n) {
528            Some(remaining) => {
529                self.remaining = remaining;
530                Ok(())
531            }
532            None => Err(cx.message("Reader out of bounds")),
533        }
534    }
535}
536
537impl<'de, R> Reader<'de> for Limit<R>
538where
539    R: Reader<'de>,
540{
541    type Mut<'this>
542        = &'this mut Self
543    where
544        Self: 'this;
545
546    type TryClone = Limit<R::TryClone>;
547
548    #[inline]
549    fn borrow_mut(&mut self) -> Self::Mut<'_> {
550        self
551    }
552
553    #[inline]
554    fn try_clone(&self) -> Option<Self::TryClone> {
555        Some(Limit {
556            remaining: self.remaining,
557            reader: self.reader.try_clone()?,
558        })
559    }
560
561    #[inline]
562    fn is_eof(&mut self) -> bool {
563        self.remaining == 0 || self.reader.is_eof()
564    }
565
566    #[inline]
567    fn skip<C>(&mut self, cx: C, n: usize) -> Result<(), C::Error>
568    where
569        C: Context,
570    {
571        self.bounds_check(cx, n)?;
572        self.reader.skip(cx, n)
573    }
574
575    #[inline]
576    fn read_bytes<C, V>(&mut self, cx: C, n: usize, visitor: V) -> Result<V::Ok, V::Error>
577    where
578        C: Context,
579        V: UnsizedVisitor<'de, C, [u8], Error = C::Error, Allocator = C::Allocator>,
580    {
581        self.bounds_check(cx, n)?;
582        self.reader.read_bytes(cx, n, visitor)
583    }
584
585    #[inline]
586    unsafe fn read_bytes_uninit<C>(&mut self, cx: C, ptr: *mut u8, n: usize) -> Result<(), C::Error>
587    where
588        C: Context,
589    {
590        self.bounds_check(cx, n)?;
591        self.reader.read_bytes_uninit(cx, ptr, n)
592    }
593
594    #[inline]
595    fn peek(&mut self) -> Option<u8> {
596        if self.remaining > 0 {
597            self.reader.peek()
598        } else {
599            None
600        }
601    }
602
603    #[inline]
604    fn read<C>(&mut self, cx: C, buf: &mut [u8]) -> Result<(), C::Error>
605    where
606        C: Context,
607    {
608        self.bounds_check(cx, buf.len())?;
609        self.reader.read(cx, buf)
610    }
611
612    #[inline]
613    fn read_byte<C>(&mut self, cx: C) -> Result<u8, C::Error>
614    where
615        C: Context,
616    {
617        self.bounds_check(cx, 1)?;
618        self.reader.read_byte(cx)
619    }
620
621    #[inline]
622    fn read_array<C, const N: usize>(&mut self, cx: C) -> Result<[u8; N], C::Error>
623    where
624        C: Context,
625    {
626        self.bounds_check(cx, N)?;
627        self.reader.read_array(cx)
628    }
629}
630
631impl<'a, 'de, R> IntoReader<'de> for &'a mut R
632where
633    R: ?Sized + Reader<'de>,
634{
635    type Reader = &'a mut R;
636
637    #[inline]
638    fn into_reader(self) -> Self::Reader {
639        self
640    }
641}
642
643impl<'de, R> Reader<'de> for &mut R
644where
645    R: ?Sized + Reader<'de>,
646{
647    type Mut<'this>
648        = &'this mut R
649    where
650        Self: 'this;
651
652    type TryClone = R::TryClone;
653
654    #[inline]
655    fn borrow_mut(&mut self) -> Self::Mut<'_> {
656        self
657    }
658
659    #[inline]
660    fn try_clone(&self) -> Option<Self::TryClone> {
661        (**self).try_clone()
662    }
663
664    #[inline]
665    fn is_eof(&mut self) -> bool {
666        (**self).is_eof()
667    }
668
669    #[inline]
670    fn skip<C>(&mut self, cx: C, n: usize) -> Result<(), C::Error>
671    where
672        C: Context,
673    {
674        (**self).skip(cx, n)
675    }
676
677    #[inline]
678    fn read_bytes<C, V>(&mut self, cx: C, n: usize, visitor: V) -> Result<V::Ok, V::Error>
679    where
680        C: Context,
681        V: UnsizedVisitor<'de, C, [u8], Error = C::Error, Allocator = C::Allocator>,
682    {
683        (**self).read_bytes(cx, n, visitor)
684    }
685
686    #[inline]
687    unsafe fn read_bytes_uninit<C>(&mut self, cx: C, ptr: *mut u8, n: usize) -> Result<(), C::Error>
688    where
689        C: Context,
690    {
691        (**self).read_bytes_uninit(cx, ptr, n)
692    }
693
694    #[inline]
695    fn peek(&mut self) -> Option<u8> {
696        (**self).peek()
697    }
698
699    #[inline]
700    fn read<C>(&mut self, cx: C, buf: &mut [u8]) -> Result<(), C::Error>
701    where
702        C: Context,
703    {
704        (**self).read(cx, buf)
705    }
706
707    #[inline]
708    fn read_byte<C>(&mut self, cx: C) -> Result<u8, C::Error>
709    where
710        C: Context,
711    {
712        (**self).read_byte(cx)
713    }
714
715    #[inline]
716    fn read_array<C, const N: usize>(&mut self, cx: C) -> Result<[u8; N], C::Error>
717    where
718        C: Context,
719    {
720        (**self).read_array(cx)
721    }
722}
723
/// Error raised when more bytes are requested from a slice than it has left.
#[derive(Debug)]
pub(crate) struct SliceUnderflow {
    // Number of bytes that was requested.
    n: usize,
    // Number of bytes that was left at the time of the request.
    remaining: usize,
}

impl SliceUnderflow {
    /// Construct an underflow error for a request of `n` bytes when only
    /// `remaining` bytes were left.
    #[inline]
    pub(crate) fn new(n: usize, remaining: usize) -> Self {
        SliceUnderflow { n, remaining }
    }
}

impl fmt::Display for SliceUnderflow {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Tried to read {n} bytes from slice, with {remaining} byte remaining",
            n = self.n,
            remaining = self.remaining,
        )
    }
}

impl core::error::Error for SliceUnderflow {}