Skip to main content

musli/
reader.rs

//! Traits and types governing how bytes are read from a source.
2
3use core::fmt;
4use core::marker;
5use core::mem::MaybeUninit;
6use core::ops::Range;
7use core::ptr;
8use core::slice;
9
10use crate::Context;
11use crate::de::UnsizedVisitor;
12
13mod sealed {
14    use super::{Limit, Reader};
15
16    pub trait Sealed {}
17
18    impl Sealed for &[u8] {}
19    impl Sealed for super::SliceReader<'_> {}
20    impl<'de, R> Sealed for Limit<R> where R: Reader<'de> {}
21    impl<'de, R> Sealed for &mut R where R: ?Sized + Reader<'de> {}
22}
23
24/// Coerce a type into a [`Reader`].
25pub trait IntoReader<'de>: self::sealed::Sealed {
26    /// The reader type.
27    type Reader: Reader<'de>;
28
29    /// Convert the type into a reader.
30    fn into_reader(self) -> Self::Reader;
31}
32
33/// Trait governing how a source of bytes is read.
34///
35/// This requires the reader to be able to hand out contiguous references to the
36/// byte source through [`Reader::read_bytes`].
37pub trait Reader<'de>: self::sealed::Sealed {
38    /// Type borrowed from self.
39    ///
40    /// Why oh why would we want to do this over having a simple `&'this mut T`?
41    ///
42    /// We want to avoid recursive types, which will blow up the compiler. And
43    /// the above is a typical example of when that can go wrong. This ensures
44    /// that each call to `borrow_mut` dereferences the [`Reader`] at each step
45    /// to avoid constructing a large muted type, like `&mut &mut &mut
46    /// SliceReader<'de>`.
47    type Mut<'this>: Reader<'de>
48    where
49        Self: 'this;
50
51    /// Type that can be cloned from the reader.
52    type TryClone: Reader<'de>;
53
54    /// Borrow the current reader.
55    fn borrow_mut(&mut self) -> Self::Mut<'_>;
56
57    /// Try to clone the reader.
58    fn try_clone(&self) -> Option<Self::TryClone>;
59
60    /// Test if the reader is at end of input.
61    fn is_eof(&mut self) -> bool;
62
63    /// Skip over the given number of bytes.
64    fn skip<C>(&mut self, cx: C, n: usize) -> Result<(), C::Error>
65    where
66        C: Context;
67
68    /// Peek the next value.
69    fn peek(&mut self) -> Option<u8>;
70
71    /// Read a slice into the given buffer.
72    #[inline]
73    fn read<C>(&mut self, cx: C, buf: &mut [u8]) -> Result<(), C::Error>
74    where
75        C: Context,
76    {
77        struct Visitor<'a>(&'a mut [u8]);
78
79        #[crate::trait_defaults(crate)]
80        impl<C> UnsizedVisitor<'_, C, [u8]> for Visitor<'_>
81        where
82            C: Context,
83        {
84            type Ok = ();
85
86            #[inline]
87            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88                write!(f, "bytes")
89            }
90
91            #[inline]
92            fn visit_ref(self, _: C, bytes: &[u8]) -> Result<Self::Ok, C::Error> {
93                self.0.copy_from_slice(bytes);
94                Ok(())
95            }
96        }
97
98        self.read_bytes(cx, buf.len(), Visitor(buf))
99    }
100
101    /// Read a slice out of the current reader.
102    fn read_bytes<C, V>(&mut self, cx: C, n: usize, visitor: V) -> Result<V::Ok, V::Error>
103    where
104        C: Context,
105        V: UnsizedVisitor<'de, C, [u8], Error = C::Error, Allocator = C::Allocator>;
106
107    /// Read into the given buffer which might not have been initialized.
108    ///
109    /// # Safety
110    ///
111    /// The caller must ensure that the buffer points to valid memory of length
112    /// `len`.
113    unsafe fn read_bytes_uninit<C>(
114        &mut self,
115        cx: C,
116        ptr: *mut u8,
117        len: usize,
118    ) -> Result<(), C::Error>
119    where
120        C: Context;
121
122    /// Read a single byte.
123    #[inline]
124    fn read_byte<C>(&mut self, cx: C) -> Result<u8, C::Error>
125    where
126        C: Context,
127    {
128        let [byte] = self.read_array::<C, 1>(cx)?;
129        Ok(byte)
130    }
131
132    /// Read an array out of the current reader.
133    #[inline]
134    fn read_array<C, const N: usize>(&mut self, cx: C) -> Result<[u8; N], C::Error>
135    where
136        C: Context,
137    {
138        struct Visitor<const N: usize>([u8; N]);
139
140        #[crate::trait_defaults(crate)]
141        impl<const N: usize, C> UnsizedVisitor<'_, C, [u8]> for Visitor<N>
142        where
143            C: Context,
144        {
145            type Ok = [u8; N];
146
147            #[inline]
148            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149                write!(f, "bytes")
150            }
151
152            #[inline]
153            fn visit_ref(mut self, cx: C, bytes: &[u8]) -> Result<Self::Ok, C::Error> {
154                self.0.copy_from_slice(bytes);
155                cx.advance(bytes.len());
156                Ok(self.0)
157            }
158        }
159
160        self.read_bytes(cx, N, Visitor([0u8; N]))
161    }
162
163    /// Keep an accurate record of the position within the reader.
164    #[inline]
165    fn limit(self, limit: usize) -> Limit<Self>
166    where
167        Self: Sized,
168    {
169        Limit {
170            remaining: limit,
171            reader: self,
172        }
173    }
174}
175
176impl<'de> IntoReader<'de> for &'de [u8] {
177    type Reader = &'de [u8];
178
179    #[inline]
180    fn into_reader(self) -> Self::Reader {
181        self
182    }
183}
184
185impl<'de> Reader<'de> for &'de [u8] {
186    type Mut<'this>
187        = &'this mut &'de [u8]
188    where
189        Self: 'this;
190
191    type TryClone = &'de [u8];
192
193    #[inline]
194    fn borrow_mut(&mut self) -> Self::Mut<'_> {
195        self
196    }
197
198    #[inline]
199    fn try_clone(&self) -> Option<Self::TryClone> {
200        Some(self)
201    }
202
203    #[inline]
204    fn is_eof(&mut self) -> bool {
205        self.is_empty()
206    }
207
208    #[inline]
209    fn skip<C>(&mut self, cx: C, n: usize) -> Result<(), C::Error>
210    where
211        C: Context,
212    {
213        if self.len() < n {
214            return Err(cx.message(SliceUnderflow {
215                n,
216                remaining: self.len(),
217            }));
218        }
219
220        let (_, tail) = self.split_at(n);
221        *self = tail;
222        cx.advance(n);
223        Ok(())
224    }
225
226    #[inline]
227    fn read<C>(&mut self, cx: C, buf: &mut [u8]) -> Result<(), C::Error>
228    where
229        C: Context,
230    {
231        if self.len() < buf.len() {
232            return Err(cx.custom(SliceUnderflow::new(buf.len(), self.len())));
233        }
234
235        let (head, tail) = self.split_at(buf.len());
236        buf.copy_from_slice(head);
237        *self = tail;
238        cx.advance(buf.len());
239        Ok(())
240    }
241
242    #[inline]
243    fn read_bytes<C, V>(&mut self, cx: C, n: usize, visitor: V) -> Result<V::Ok, V::Error>
244    where
245        C: Context,
246        V: UnsizedVisitor<'de, C, [u8], Error = C::Error, Allocator = C::Allocator>,
247    {
248        if self.len() < n {
249            return Err(cx.custom(SliceUnderflow::new(n, self.len())));
250        }
251
252        let (head, tail) = self.split_at(n);
253        *self = tail;
254        let ok = visitor.visit_borrowed(cx, head)?;
255        cx.advance(n);
256        Ok(ok)
257    }
258
259    #[inline]
260    unsafe fn read_bytes_uninit<C>(&mut self, cx: C, ptr: *mut u8, n: usize) -> Result<(), C::Error>
261    where
262        C: Context,
263    {
264        if self.len() < n {
265            return Err(cx.custom(SliceUnderflow::new(n, self.len())));
266        }
267
268        unsafe {
269            ptr.copy_from_nonoverlapping(self.as_ptr(), n);
270        }
271
272        *self = &self[n..];
273        cx.advance(n);
274        Ok(())
275    }
276
277    #[inline]
278    fn read_byte<C>(&mut self, cx: C) -> Result<u8, C::Error>
279    where
280        C: Context,
281    {
282        let &[first, ref tail @ ..] = *self else {
283            return Err(cx.custom(SliceUnderflow::new(1, self.len())));
284        };
285
286        *self = tail;
287        cx.advance(1);
288        Ok(first)
289    }
290
291    #[inline]
292    fn read_array<C, const N: usize>(&mut self, cx: C) -> Result<[u8; N], C::Error>
293    where
294        C: Context,
295    {
296        if self.len() < N {
297            return Err(cx.custom(SliceUnderflow::new(N, self.len())));
298        }
299
300        cx.advance(N);
301
302        let mut array: MaybeUninit<[u8; N]> = MaybeUninit::uninit();
303
304        // SAFETY: We've checked the length of the current buffer just above.
305        // PERFORMANCE: This generates better code than `array::from_fn`, and
306        // `read_array` is performance sensitive.
307        unsafe {
308            array
309                .as_mut_ptr()
310                .cast::<u8>()
311                .copy_from_nonoverlapping(self.as_ptr(), N);
312            *self = self.get_unchecked(N..);
313            Ok(array.assume_init())
314        }
315    }
316
317    #[inline]
318    fn peek(&mut self) -> Option<u8> {
319        self.first().copied()
320    }
321}
322
/// An efficient [`Reader`] wrapper around a slice.
///
/// The unread part of the slice is tracked as a half-open pointer range, and
/// reading only ever moves `range.start` forward.
pub struct SliceReader<'de> {
    /// Pointer to the first unread byte up to one past the last byte.
    range: Range<*const u8>,
    /// Ties the raw pointers to the lifetime of the borrowed slice.
    _marker: marker::PhantomData<&'de [u8]>,
}

// SAFETY: A `SliceReader` only reads through its pointers, exactly like a
// `&'de [u8]`, which is `Send`.
unsafe impl Send for SliceReader<'_> {}

// SAFETY: A `SliceReader` has no interior mutability and only reads through its
// pointers, exactly like a `&'de [u8]`, which is `Sync`.
unsafe impl Sync for SliceReader<'_> {}
333
334impl<'de> SliceReader<'de> {
335    /// Construct a new instance around the specified slice.
336    ///
337    /// # Examples
338    ///
339    /// ```
340    /// use musli::reader::SliceReader;
341    ///
342    /// let data = &[1, 2, 3, 4];
343    /// let reader = SliceReader::new(data);
344    /// assert_eq!(reader.as_slice(), &[1, 2, 3, 4]);
345    /// ```
346    #[inline]
347    pub fn new(slice: &'de [u8]) -> Self {
348        Self {
349            range: slice.as_ptr_range(),
350            _marker: marker::PhantomData,
351        }
352    }
353
354    /// Get the remaining contents of the reader as a slice.
355    ///
356    /// # Examples
357    ///
358    /// ```
359    /// use musli::Context;
360    /// use musli::reader::{Reader, SliceReader};
361    ///
362    /// fn process<C>(cx: C) -> Result<(), C::Error>
363    /// where
364    ///     C: Context
365    /// {
366    ///     let mut reader = SliceReader::new(&[1, 2, 3, 4]);
367    ///     assert_eq!(reader.as_slice(), &[1, 2, 3, 4]);
368    ///     reader.skip(cx, 2)?;
369    ///     assert_eq!(reader.as_slice(), &[3, 4]);
370    ///     Ok(())
371    /// }
372    /// ```
373    #[inline]
374    pub fn as_slice(&self) -> &'de [u8] {
375        unsafe { slice::from_raw_parts(self.range.start, self.remaining()) }
376    }
377
378    /// Get remaining bytes in the reader.
379    ///
380    /// # Examples
381    ///
382    /// ```
383    /// use musli::Context;
384    /// use musli::reader::{Reader, SliceReader};
385    ///
386    /// fn process<C>(cx: C) -> Result<(), C::Error>
387    /// where
388    ///     C: Context
389    /// {
390    ///     let mut reader = SliceReader::new(&[1, 2, 3, 4]);
391    ///     assert_eq!(reader.remaining(), 4);
392    ///     reader.skip(cx, 2);
393    ///     assert_eq!(reader.remaining(), 2);
394    ///     Ok(())
395    /// }
396    /// ```
397    pub fn remaining(&self) -> usize {
398        self.range.end as usize - self.range.start as usize
399    }
400}
401
402impl<'de> Reader<'de> for SliceReader<'de> {
403    type Mut<'this>
404        = &'this mut Self
405    where
406        Self: 'this;
407
408    type TryClone = Self;
409
410    #[inline]
411    fn borrow_mut(&mut self) -> Self::Mut<'_> {
412        self
413    }
414
415    #[inline]
416    fn try_clone(&self) -> Option<Self::TryClone> {
417        Some(Self {
418            range: self.range.clone(),
419            _marker: marker::PhantomData,
420        })
421    }
422
423    #[inline]
424    fn is_eof(&mut self) -> bool {
425        self.range.start == self.range.end
426    }
427
428    #[inline]
429    fn skip<C>(&mut self, cx: C, n: usize) -> Result<(), C::Error>
430    where
431        C: Context,
432    {
433        self.range.start = bounds_check_add(cx, &self.range, n)?;
434        cx.advance(n);
435        Ok(())
436    }
437
438    #[inline]
439    fn read_bytes<C, V>(&mut self, cx: C, n: usize, visitor: V) -> Result<V::Ok, V::Error>
440    where
441        C: Context,
442        V: UnsizedVisitor<'de, C, [u8], Error = C::Error, Allocator = C::Allocator>,
443    {
444        let outcome = bounds_check_add(cx, &self.range, n)?;
445
446        let ok = unsafe {
447            let bytes = slice::from_raw_parts(self.range.start, n);
448            self.range.start = outcome;
449            visitor.visit_borrowed(cx, bytes)?
450        };
451
452        cx.advance(n);
453        Ok(ok)
454    }
455
456    #[inline]
457    unsafe fn read_bytes_uninit<C>(&mut self, cx: C, ptr: *mut u8, n: usize) -> Result<(), C::Error>
458    where
459        C: Context,
460    {
461        let outcome = bounds_check_add(cx, &self.range, n)?;
462
463        unsafe {
464            ptr.copy_from_nonoverlapping(self.range.start, n);
465        }
466
467        self.range.start = outcome;
468        cx.advance(n);
469        Ok(())
470    }
471
472    #[inline]
473    fn peek(&mut self) -> Option<u8> {
474        if self.range.start == self.range.end {
475            return None;
476        }
477
478        // SAFETY: we've checked that the elements are in bound above.
479        unsafe { Some(ptr::read(self.range.start)) }
480    }
481
482    #[inline]
483    fn read<C>(&mut self, cx: C, buf: &mut [u8]) -> Result<(), C::Error>
484    where
485        C: Context,
486    {
487        let outcome = bounds_check_add(cx, &self.range, buf.len())?;
488
489        // SAFETY: We've checked that the updated pointer is in bounds.
490        unsafe {
491            self.range
492                .start
493                .copy_to_nonoverlapping(buf.as_mut_ptr(), buf.len());
494            self.range.start = outcome;
495        }
496
497        cx.advance(buf.len());
498        Ok(())
499    }
500}
501
502#[inline]
503fn bounds_check_add<C>(cx: C, range: &Range<*const u8>, len: usize) -> Result<*const u8, C::Error>
504where
505    C: Context,
506{
507    let outcome = range.start.wrapping_add(len);
508
509    if outcome > range.end || outcome < range.start {
510        Err(cx.message(SliceUnderflow {
511            n: len,
512            remaining: (range.end as usize).wrapping_sub(range.start as usize),
513        }))
514    } else {
515        Ok(outcome)
516    }
517}
518
/// Limit the number of bytes that can be read out of a reader to the specified limit.
///
/// This type wraps another reader and keeps a budget of how many bytes may
/// still be read from it. A read that would exceed the budget fails instead of
/// being forwarded to the wrapped reader, which is useful for bounded reads in
/// serialization contexts.
///
/// Constructed through [`Reader::limit`].
///
/// # Examples
///
/// ```
/// use musli::Context;
/// use musli::reader::{Reader, SliceReader};
///
/// let cx = musli::context::new();
/// let data = &[1, 2, 3, 4, 5];
/// let mut reader = SliceReader::new(data);
/// let mut limited = reader.limit(3);
///
/// // Can read from the limited reader
/// let byte = limited.read_byte(&cx)?;
/// assert_eq!(byte, 1);
/// assert_eq!(limited.remaining(), 2);
///
/// // Read two more bytes
/// limited.read_byte(&cx)?;
/// limited.read_byte(&cx)?;
/// assert_eq!(limited.remaining(), 0);
/// # Ok::<_, musli::context::ErrorMarker>(())
/// ```
pub struct Limit<R> {
    /// Number of bytes that may still be read.
    remaining: usize,
    /// The wrapped reader.
    reader: R,
}
553
554impl<R> Limit<R> {
555    /// Get the remaining data in the limited reader.
556    ///
557    /// Returns the number of bytes that can still be read from this limited reader
558    /// before the limit is reached.
559    ///
560    /// # Examples
561    ///
562    /// ```
563    /// use musli::reader::{Reader, SliceReader};
564    ///
565    /// let data = &[1, 2, 3, 4, 5];
566    /// let mut reader = SliceReader::new(data);
567    /// let limited = reader.limit(3);
568    ///
569    /// assert_eq!(limited.remaining(), 3);
570    /// ```
571    #[inline]
572    pub fn remaining(&self) -> usize {
573        self.remaining
574    }
575}
576
577impl<'de, R> Limit<R>
578where
579    R: Reader<'de>,
580{
581    #[inline]
582    fn bounds_check<C>(&mut self, cx: C, n: usize) -> Result<(), C::Error>
583    where
584        C: Context,
585    {
586        match self.remaining.checked_sub(n) {
587            Some(remaining) => {
588                self.remaining = remaining;
589                Ok(())
590            }
591            None => Err(cx.message("Reader out of bounds")),
592        }
593    }
594}
595
596impl<'de, R> Reader<'de> for Limit<R>
597where
598    R: Reader<'de>,
599{
600    type Mut<'this>
601        = &'this mut Self
602    where
603        Self: 'this;
604
605    type TryClone = Limit<R::TryClone>;
606
607    #[inline]
608    fn borrow_mut(&mut self) -> Self::Mut<'_> {
609        self
610    }
611
612    #[inline]
613    fn try_clone(&self) -> Option<Self::TryClone> {
614        Some(Limit {
615            remaining: self.remaining,
616            reader: self.reader.try_clone()?,
617        })
618    }
619
620    #[inline]
621    fn is_eof(&mut self) -> bool {
622        self.remaining == 0 || self.reader.is_eof()
623    }
624
625    #[inline]
626    fn skip<C>(&mut self, cx: C, n: usize) -> Result<(), C::Error>
627    where
628        C: Context,
629    {
630        self.bounds_check(cx, n)?;
631        self.reader.skip(cx, n)
632    }
633
634    #[inline]
635    fn read_bytes<C, V>(&mut self, cx: C, n: usize, visitor: V) -> Result<V::Ok, V::Error>
636    where
637        C: Context,
638        V: UnsizedVisitor<'de, C, [u8], Error = C::Error, Allocator = C::Allocator>,
639    {
640        self.bounds_check(cx, n)?;
641        self.reader.read_bytes(cx, n, visitor)
642    }
643
644    #[inline]
645    unsafe fn read_bytes_uninit<C>(&mut self, cx: C, ptr: *mut u8, n: usize) -> Result<(), C::Error>
646    where
647        C: Context,
648    {
649        self.bounds_check(cx, n)?;
650
651        unsafe { self.reader.read_bytes_uninit(cx, ptr, n) }
652    }
653
654    #[inline]
655    fn peek(&mut self) -> Option<u8> {
656        if self.remaining > 0 {
657            self.reader.peek()
658        } else {
659            None
660        }
661    }
662
663    #[inline]
664    fn read<C>(&mut self, cx: C, buf: &mut [u8]) -> Result<(), C::Error>
665    where
666        C: Context,
667    {
668        self.bounds_check(cx, buf.len())?;
669        self.reader.read(cx, buf)
670    }
671
672    #[inline]
673    fn read_byte<C>(&mut self, cx: C) -> Result<u8, C::Error>
674    where
675        C: Context,
676    {
677        self.bounds_check(cx, 1)?;
678        self.reader.read_byte(cx)
679    }
680
681    #[inline]
682    fn read_array<C, const N: usize>(&mut self, cx: C) -> Result<[u8; N], C::Error>
683    where
684        C: Context,
685    {
686        self.bounds_check(cx, N)?;
687        self.reader.read_array(cx)
688    }
689}
690
691impl<'a, 'de, R> IntoReader<'de> for &'a mut R
692where
693    R: ?Sized + Reader<'de>,
694{
695    type Reader = &'a mut R;
696
697    #[inline]
698    fn into_reader(self) -> Self::Reader {
699        self
700    }
701}
702
703impl<'de, R> Reader<'de> for &mut R
704where
705    R: ?Sized + Reader<'de>,
706{
707    type Mut<'this>
708        = &'this mut R
709    where
710        Self: 'this;
711
712    type TryClone = R::TryClone;
713
714    #[inline]
715    fn borrow_mut(&mut self) -> Self::Mut<'_> {
716        self
717    }
718
719    #[inline]
720    fn try_clone(&self) -> Option<Self::TryClone> {
721        (**self).try_clone()
722    }
723
724    #[inline]
725    fn is_eof(&mut self) -> bool {
726        (**self).is_eof()
727    }
728
729    #[inline]
730    fn skip<C>(&mut self, cx: C, n: usize) -> Result<(), C::Error>
731    where
732        C: Context,
733    {
734        (**self).skip(cx, n)
735    }
736
737    #[inline]
738    fn read_bytes<C, V>(&mut self, cx: C, n: usize, visitor: V) -> Result<V::Ok, V::Error>
739    where
740        C: Context,
741        V: UnsizedVisitor<'de, C, [u8], Error = C::Error, Allocator = C::Allocator>,
742    {
743        (**self).read_bytes(cx, n, visitor)
744    }
745
746    #[inline]
747    unsafe fn read_bytes_uninit<C>(&mut self, cx: C, ptr: *mut u8, n: usize) -> Result<(), C::Error>
748    where
749        C: Context,
750    {
751        unsafe { (**self).read_bytes_uninit(cx, ptr, n) }
752    }
753
754    #[inline]
755    fn peek(&mut self) -> Option<u8> {
756        (**self).peek()
757    }
758
759    #[inline]
760    fn read<C>(&mut self, cx: C, buf: &mut [u8]) -> Result<(), C::Error>
761    where
762        C: Context,
763    {
764        (**self).read(cx, buf)
765    }
766
767    #[inline]
768    fn read_byte<C>(&mut self, cx: C) -> Result<u8, C::Error>
769    where
770        C: Context,
771    {
772        (**self).read_byte(cx)
773    }
774
775    #[inline]
776    fn read_array<C, const N: usize>(&mut self, cx: C) -> Result<[u8; N], C::Error>
777    where
778        C: Context,
779    {
780        (**self).read_array(cx)
781    }
782}
783
/// Underflow when trying to read from a slice.
///
/// Records how many bytes were requested and how many were actually left.
#[derive(Debug)]
pub(crate) struct SliceUnderflow {
    /// Number of bytes the caller tried to read.
    n: usize,
    /// Number of bytes that were still available.
    remaining: usize,
}

impl SliceUnderflow {
    /// Construct an underflow error for a read of `n` bytes when only
    /// `remaining` bytes are available.
    #[inline]
    pub(crate) fn new(n: usize, remaining: usize) -> Self {
        Self { n, remaining }
    }
}

impl fmt::Display for SliceUnderflow {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Use the singular noun only for an exact count of one, so messages
        // read "1 byte" but "0 bytes" / "2 bytes".
        fn unit(count: usize) -> &'static str {
            if count == 1 { "byte" } else { "bytes" }
        }

        let SliceUnderflow { n, remaining } = *self;

        write!(
            f,
            "Tried to read {n} {} from slice, with {remaining} {} remaining",
            unit(n),
            unit(remaining)
        )
    }
}

impl core::error::Error for SliceUnderflow {}