// bstr/byteset/scalar.rs

// This is adapted from `fallback.rs` from rust-memchr. It's modified to return
// the 'inverse' query of memchr, e.g. finding the first byte not in the
// provided set. This is simple for the 1-byte case.

use core::{cmp, usize};

// Size of a machine word in bytes; all bulk reads below are word-sized.
const USIZE_BYTES: usize = core::mem::size_of::<usize>();
// Low bits of a pointer that must be zero for it to be word-aligned.
const ALIGN_MASK: usize = core::mem::align_of::<usize>() - 1;

// The number of bytes to loop at in one iteration of memchr/memrchr.
const LOOP_SIZE: usize = 2 * USIZE_BYTES;
12
/// Broadcast `b` into every byte lane of a `usize`. For example, if `b` is
/// `\x4E` (`01001110` in binary), the result on a 32-bit system is
/// `01001110_01001110_01001110_01001110`.
#[inline(always)]
fn repeat_byte(b: u8) -> usize {
    // `usize::MAX / 255` is 0x0101…01 — a one in the low bit of each byte
    // lane — so the multiplication copies `b` into every lane. The product
    // cannot overflow: 255 * 0x0101…01 == usize::MAX exactly.
    const LOW_BITS: usize = usize::MAX / 255;
    LOW_BITS * (b as usize)
}
21
/// Return the index of the first byte in `haystack` that is *not* equal to
/// `n1`, or `None` if every byte equals `n1`.
pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    // `n1` broadcast into every byte of a word: a word XORed with `vn1` is
    // zero iff all of its bytes equal `n1`.
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        let mut ptr = start_ptr;

        // Too short for even one word-sized read: scan byte by byte.
        if haystack.len() < USIZE_BYTES {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Check the first (possibly unaligned) word. If it contains any
        // byte != n1, the answer is in it, so finish with a byte scan.
        let chunk = read_unaligned_usize(ptr);
        if (chunk ^ vn1) != 0 {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Advance to the next word-aligned address. Any bytes skipped here
        // were covered by the unaligned read above and all equaled `n1`.
        ptr = ptr.add(USIZE_BYTES - (start_ptr as usize & ALIGN_MASK));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(USIZE_BYTES) >= start_ptr);
        // Main loop: inspect two aligned words per iteration, breaking out
        // as soon as either contains a byte != n1.
        while loop_size == LOOP_SIZE && ptr <= end_ptr.sub(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

            let a = *(ptr as *const usize);
            let b = *(ptr.add(USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.add(LOOP_SIZE);
        }
        // Pin down the exact index (and cover any unscanned tail).
        forward_search(start_ptr, end_ptr, ptr, confirm)
    }
}
59
/// Return the last index not matching the byte `x` in `text`, or `None` if
/// every byte equals `n1`.
pub fn inv_memrchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    // `n1` broadcast into every byte of a word: a word XORed with `vn1` is
    // zero iff all of its bytes equal `n1`.
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        // For the reverse scan `ptr` is an *exclusive* upper bound: bytes at
        // `ptr..end_ptr` have already been ruled out.
        let mut ptr = end_ptr;

        // Too short for even one word-sized read: scan byte by byte.
        if haystack.len() < USIZE_BYTES {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Check the last (possibly unaligned) word. If it contains any byte
        // != n1, the answer is in it, so finish with a byte scan.
        let chunk = read_unaligned_usize(ptr.sub(USIZE_BYTES));
        if (chunk ^ vn1) != 0 {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Round `ptr` down to a word-aligned address. Any bytes skipped here
        // were covered by the unaligned read above and all equaled `n1`.
        ptr = ptr.sub(end_ptr as usize & ALIGN_MASK);
        debug_assert!(start_ptr <= ptr && ptr <= end_ptr);
        // Main loop: inspect the two aligned words just below `ptr` per
        // iteration, breaking out as soon as either contains a byte != n1.
        while loop_size == LOOP_SIZE && ptr >= start_ptr.add(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

            let a = *(ptr.sub(2 * USIZE_BYTES) as *const usize);
            let b = *(ptr.sub(1 * USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.sub(loop_size);
        }
        // Pin down the exact index (and cover any unscanned head).
        reverse_search(start_ptr, end_ptr, ptr, confirm)
    }
}
97
98#[inline(always)]
99unsafe fn forward_search<F: Fn(u8) -> bool>(
100    start_ptr: *const u8,
101    end_ptr: *const u8,
102    mut ptr: *const u8,
103    confirm: F,
104) -> Option<usize> {
105    debug_assert!(start_ptr <= ptr);
106    debug_assert!(ptr <= end_ptr);
107
108    while ptr < end_ptr {
109        if confirm(*ptr) {
110            return Some(sub(ptr, start_ptr));
111        }
112        ptr = ptr.offset(1);
113    }
114    None
115}
116
117#[inline(always)]
118unsafe fn reverse_search<F: Fn(u8) -> bool>(
119    start_ptr: *const u8,
120    end_ptr: *const u8,
121    mut ptr: *const u8,
122    confirm: F,
123) -> Option<usize> {
124    debug_assert!(start_ptr <= ptr);
125    debug_assert!(ptr <= end_ptr);
126
127    while ptr > start_ptr {
128        ptr = ptr.offset(-1);
129        if confirm(*ptr) {
130            return Some(sub(ptr, start_ptr));
131        }
132    }
133    None
134}
135
/// Load a `usize` from `ptr` without requiring `ptr` to be word-aligned.
///
/// # Safety
///
/// `ptr` must be valid for a read of `size_of::<usize>()` bytes.
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    let word_ptr = ptr as *const usize;
    word_ptr.read_unaligned()
}
139
/// Distance in bytes from `b` up to `a`. The caller must guarantee that
/// `a >= b`; in debug builds a violation is caught by the assertion.
fn sub(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    let (hi, lo) = (a as usize, b as usize);
    hi - lo
}
146
147/// Safe wrapper around `forward_search`
148#[inline]
149pub(crate) fn forward_search_bytes<F: Fn(u8) -> bool>(
150    s: &[u8],
151    confirm: F,
152) -> Option<usize> {
153    unsafe {
154        let start = s.as_ptr();
155        let end = start.add(s.len());
156        forward_search(start, end, start, confirm)
157    }
158}
159
160/// Safe wrapper around `reverse_search`
161#[inline]
162pub(crate) fn reverse_search_bytes<F: Fn(u8) -> bool>(
163    s: &[u8],
164    confirm: F,
165) -> Option<usize> {
166    unsafe {
167        let start = s.as_ptr();
168        let end = start.add(s.len());
169        reverse_search(start, end, end, confirm)
170    }
171}
172
#[cfg(all(test, feature = "std"))]
mod tests {
    use alloc::{vec, vec::Vec};

    use super::{inv_memchr, inv_memrchr};

    // search string, search byte, inv_memchr result, inv_memrchr result.
    // these are expanded into a much larger set of tests in build_tests
    const TESTS: &[(&[u8], u8, usize, usize)] = &[
        (b"z", b'a', 0, 0),
        (b"zz", b'a', 0, 1),
        (b"aza", b'a', 1, 1),
        (b"zaz", b'a', 0, 2),
        (b"zza", b'a', 0, 1),
        (b"zaa", b'a', 0, 0),
        (b"zzz", b'a', 0, 2),
    ];

    // (haystack, needle, expected (forward, reverse) indices or None).
    type TestCase = (Vec<u8>, u8, Option<(usize, usize)>);

    /// Expand `TESTS` by padding each base case with up to `MAX_PER` copies
    /// of the needle byte at the start, end, or both. The padding can never
    /// match the "not equal" query, so only the expected indices shift.
    fn build_tests() -> Vec<TestCase> {
        // Large enough to push the haystack well past the word-at-a-time
        // loop threshold; kept small under miri, which runs much slower.
        #[cfg(not(miri))]
        const MAX_PER: usize = 515;
        #[cfg(miri)]
        const MAX_PER: usize = 10;

        let mut result = vec![];
        for &(search, byte, fwd_pos, rev_pos) in TESTS {
            result.push((search.to_vec(), byte, Some((fwd_pos, rev_pos))));
            for i in 1..MAX_PER {
                // add a bunch of copies of the search byte to the end.
                let mut suffixed: Vec<u8> = search.into();
                suffixed.extend(std::iter::repeat(byte).take(i));
                result.push((suffixed, byte, Some((fwd_pos, rev_pos))));

                // add a bunch of copies of the search byte to the start.
                let mut prefixed: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                prefixed.extend(search);
                result.push((
                    prefixed,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));

                // add a bunch of copies of the search byte to both ends.
                let mut surrounded: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                surrounded.extend(search);
                surrounded.extend(std::iter::repeat(byte).take(i));
                result.push((
                    surrounded,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));
            }
        }

        // build non-matching tests for several sizes
        for i in 0..MAX_PER {
            result.push((
                std::iter::repeat(b'\0').take(i).collect(),
                b'\0',
                None,
            ));
        }

        result
    }

    #[test]
    fn test_inv_memchr() {
        use crate::{ByteSlice, B};

        // Number of leading bytes sliced off when re-running each case, to
        // exercise different alignments; smaller under miri.
        #[cfg(not(miri))]
        const MAX_OFFSET: usize = 130;
        #[cfg(miri)]
        const MAX_OFFSET: usize = 13;

        for (search, byte, matching) in build_tests() {
            assert_eq!(
                inv_memchr(byte, &search),
                matching.map(|m| m.0),
                "inv_memchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            assert_eq!(
                inv_memrchr(byte, &search),
                matching.map(|m| m.1),
                "inv_memrchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            // Test a rather large number of offsets for potential alignment
            // issues.
            for offset in 1..MAX_OFFSET {
                if offset >= search.len() {
                    break;
                }
                // If this would cause us to shift the results off the end,
                // skip it so that we don't have to recompute them.
                if let Some((f, r)) = matching {
                    if offset > f || offset > r {
                        break;
                    }
                }
                let realigned = &search[offset..];

                let forward_pos = matching.map(|m| m.0 - offset);
                let reverse_pos = matching.map(|m| m.1 - offset);

                assert_eq!(
                    inv_memchr(byte, &realigned),
                    forward_pos,
                    "inv_memchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
                assert_eq!(
                    inv_memrchr(byte, &realigned),
                    reverse_pos,
                    "inv_memrchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
            }
        }
    }
}