bstr/ascii.rs

// The following ~400 lines of code exist for exactly one purpose, which is
// to optimize this code:
//
//     byte_slice.iter().position(|&b| b > 0x7F).unwrap_or(byte_slice.len())
//
// Yes... Overengineered is a word that comes to mind, but this is effectively
// a very similar problem to memchr, and virtually nobody has been able to
// resist optimizing the crap out of that (except for perhaps the BSD and MUSL
// folks). In particular, this routine makes a very common case (ASCII) very
// fast, which seems worth it. We do stop short of adding AVX variants of the
// code below in order to retain our sanity and also to avoid needing to deal
// with runtime target feature detection. RESIST!
//
// In order to understand the SIMD version below, it would be good to read this
// comment describing how my memchr routine works:
// https://github.com/BurntSushi/rust-memchr/blob/b0a29f267f4a7fad8ffcc8fe8377a06498202883/src/x86/sse2.rs#L19-L106
//
// The primary difference with memchr is that for ASCII, we can do a bit less
// work. In particular, we don't need to detect the presence of a specific
// byte, but rather, whether any byte has its most significant bit set. That
// means we can effectively skip the _mm_cmpeq_epi8 step and jump straight to
// _mm_movemask_epi8.

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const USIZE_BYTES: usize = core::mem::size_of::<usize>();
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ALIGN_MASK: usize = core::mem::align_of::<usize>() - 1;
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const FALLBACK_LOOP_SIZE: usize = 2 * USIZE_BYTES;

// This is a mask where the most significant bit of each byte in the usize
// is set. We test this bit to determine whether a byte is ASCII or not.
// Namely, a single byte is regarded as an ASCII codepoint if and only if its
// most significant bit is not set.
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK_U64: u64 = 0x8080808080808080;
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK: usize = ASCII_MASK_U64 as usize;
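
// As a worked example of the mask trick (bytes chosen for illustration): on
// a little-endian target, reading b"ab\xC3\xA9cdef" as a usize yields
// 0x6665_6463_A9C3_6261. ANDing with ASCII_MASK leaves 0x0000_0000_8080_0000,
// whose least significant set bit is bit 23, and 23 / 8 = 2: the index of the
// first non-ASCII byte (0xC3). Big-endian targets use leading_zeros instead;
// see first_non_ascii_byte_mask below.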

/// Returns the index of the first non-ASCII byte in the given slice.
///
/// If the slice contains only ASCII bytes, then the length of the slice is
/// returned.
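///
/// # Example
///
/// A minimal sketch of intended usage (marked `ignore` because the exact
/// import path depends on how this module is exposed):
///
/// ```ignore
/// // "☃" is encoded as the three non-ASCII bytes E2 98 83.
/// assert_eq!(3, first_non_ascii_byte("abc☃".as_bytes()));
/// // An all-ASCII slice yields its length.
/// assert_eq!(4, first_non_ascii_byte(b"abcd"));
/// ```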
pub fn first_non_ascii_byte(slice: &[u8]) -> usize {
    #[cfg(any(miri, not(target_arch = "x86_64")))]
    {
        first_non_ascii_byte_fallback(slice)
    }

    #[cfg(all(not(miri), target_arch = "x86_64"))]
    {
        first_non_ascii_byte_sse2(slice)
    }
}

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_fallback(slice: &[u8]) -> usize {
    let start_ptr = slice.as_ptr();
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        if slice.len() < USIZE_BYTES {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

        let chunk = read_unaligned_usize(ptr);
        let mask = chunk & ASCII_MASK;
        if mask != 0 {
            return first_non_ascii_byte_mask(mask);
        }

        ptr = ptr_add(ptr, USIZE_BYTES - (start_ptr as usize & ALIGN_MASK));
        debug_assert!(ptr > start_ptr);
        debug_assert!(ptr_sub(end_ptr, USIZE_BYTES) >= start_ptr);
        if slice.len() >= FALLBACK_LOOP_SIZE {
            while ptr <= ptr_sub(end_ptr, FALLBACK_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

                let a = *(ptr as *const usize);
                let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);
                if (a | b) & ASCII_MASK != 0 {
                    // What a kludge. We wrap the position finding code into
                    // a non-inlineable function, which makes the codegen in
                    // the tight loop above a bit better by avoiding a
                    // couple extra movs. We pay for it by two additional
                    // stores, but only in the case of finding a non-ASCII
                    // byte.
                    #[inline(never)]
                    unsafe fn findpos(
                        start_ptr: *const u8,
                        ptr: *const u8,
                    ) -> usize {
                        let a = *(ptr as *const usize);
                        let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);

                        let mut at = sub(ptr, start_ptr);
                        let maska = a & ASCII_MASK;
                        if maska != 0 {
                            return at + first_non_ascii_byte_mask(maska);
                        }

                        at += USIZE_BYTES;
                        let maskb = b & ASCII_MASK;
                        debug_assert!(maskb != 0);
                        return at + first_non_ascii_byte_mask(maskb);
                    }
                    return findpos(start_ptr, ptr);
                }
                ptr = ptr_add(ptr, FALLBACK_LOOP_SIZE);
            }
        }
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}

#[cfg(all(not(miri), target_arch = "x86_64"))]
fn first_non_ascii_byte_sse2(slice: &[u8]) -> usize {
    use core::arch::x86_64::*;

    const VECTOR_SIZE: usize = core::mem::size_of::<__m128i>();
    const VECTOR_ALIGN: usize = VECTOR_SIZE - 1;
    const VECTOR_LOOP_SIZE: usize = 4 * VECTOR_SIZE;

    let start_ptr = slice.as_ptr();
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        if slice.len() < VECTOR_SIZE {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

        let chunk = _mm_loadu_si128(ptr as *const __m128i);
        let mask = _mm_movemask_epi8(chunk);
        if mask != 0 {
            return mask.trailing_zeros() as usize;
        }

        ptr = ptr.add(VECTOR_SIZE - (start_ptr as usize & VECTOR_ALIGN));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(VECTOR_SIZE) >= start_ptr);
        if slice.len() >= VECTOR_LOOP_SIZE {
            while ptr <= ptr_sub(end_ptr, VECTOR_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % VECTOR_SIZE);

                let a = _mm_load_si128(ptr as *const __m128i);
                let b = _mm_load_si128(ptr.add(VECTOR_SIZE) as *const __m128i);
                let c =
                    _mm_load_si128(ptr.add(2 * VECTOR_SIZE) as *const __m128i);
                let d =
                    _mm_load_si128(ptr.add(3 * VECTOR_SIZE) as *const __m128i);

                let or1 = _mm_or_si128(a, b);
                let or2 = _mm_or_si128(c, d);
                let or3 = _mm_or_si128(or1, or2);
                if _mm_movemask_epi8(or3) != 0 {
                    let mut at = sub(ptr, start_ptr);
                    let mask = _mm_movemask_epi8(a);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(b);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(c);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(d);
                    debug_assert!(mask != 0);
                    return at + mask.trailing_zeros() as usize;
                }
                ptr = ptr_add(ptr, VECTOR_LOOP_SIZE);
            }
        }
        while ptr <= end_ptr.sub(VECTOR_SIZE) {
            debug_assert!(sub(end_ptr, ptr) >= VECTOR_SIZE);

            let chunk = _mm_loadu_si128(ptr as *const __m128i);
            let mask = _mm_movemask_epi8(chunk);
            if mask != 0 {
                return sub(ptr, start_ptr) + mask.trailing_zeros() as usize;
            }
            ptr = ptr.add(VECTOR_SIZE);
        }
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}

#[inline(always)]
unsafe fn first_non_ascii_byte_slow(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
) -> usize {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr < end_ptr {
        if *ptr > 0x7F {
            return sub(ptr, start_ptr);
        }
        ptr = ptr.offset(1);
    }
    sub(end_ptr, start_ptr)
}

/// Compute the position of the first non-ASCII byte in the given mask.
///
/// The mask should be computed by `chunk & ASCII_MASK`, where `chunk` is
/// 8 contiguous bytes of the slice being checked where *at least* one of those
/// bytes is not an ASCII byte.
///
/// The position returned is always in the inclusive range [0, 7].
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_mask(mask: usize) -> usize {
    #[cfg(target_endian = "little")]
    {
        mask.trailing_zeros() as usize / 8
    }
    #[cfg(target_endian = "big")]
    {
        mask.leading_zeros() as usize / 8
    }
}

/// Increment the given pointer by the given amount.
unsafe fn ptr_add(ptr: *const u8, amt: usize) -> *const u8 {
    ptr.add(amt)
}

/// Decrement the given pointer by the given amount.
unsafe fn ptr_sub(ptr: *const u8, amt: usize) -> *const u8 {
    ptr.sub(amt)
}

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    use core::ptr;

    let mut n: usize = 0;
    ptr::copy_nonoverlapping(ptr, &mut n as *mut _ as *mut u8, USIZE_BYTES);
    n
}

/// Subtract `b` from `a` and return the difference. `a` should be greater than
/// or equal to `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    (a as usize) - (b as usize)
}

#[cfg(test)]
mod tests {
    use super::*;

    // Our testing approach here is to try and exhaustively test every case.
    // This includes the position at which a non-ASCII byte occurs in addition
    // to the alignment of the slice that we're searching.

    #[test]
    fn positive_fallback_forward() {
        for i in 0..517 {
            let s = "a".repeat(i);
            assert_eq!(
                i,
                first_non_ascii_byte_fallback(s.as_bytes()),
                "i: {:?}, len: {:?}, s: {:?}",
                i,
                s.len(),
                s
            );
        }
    }
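
    // A worked example of the fallback's mask trick with illustrative bytes:
    // read `usize`-many bytes as a native-endian word, AND with ASCII_MASK,
    // and locate the first set high bit. Byte index 2 holds the first
    // non-ASCII byte (0xC3), regardless of the host's endianness or pointer
    // width.
    #[test]
    fn example_mask_trick() {
        let bytes = b"ab\xC3\xA9cdef";
        let chunk = unsafe { read_unaligned_usize(bytes.as_ptr()) };
        let mask = chunk & ASCII_MASK;
        assert_eq!(2, first_non_ascii_byte_mask(mask));
    }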

    #[test]
    #[cfg(target_arch = "x86_64")]
    #[cfg(not(miri))]
    fn positive_sse2_forward() {
        for i in 0..517 {
            let b = "a".repeat(i).into_bytes();
            assert_eq!(b.len(), first_non_ascii_byte_sse2(&b));
        }
    }
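
    // A small end-to-end example of the public entry point: "☃" is encoded
    // as the three non-ASCII bytes E2 98 83, so the first non-ASCII byte of
    // "abc☃" is at index 3, while an all-ASCII slice yields its length.
    #[test]
    fn example_first_non_ascii_byte() {
        assert_eq!(3, first_non_ascii_byte("abc☃".as_bytes()));
        assert_eq!(4, first_non_ascii_byte(b"abcd"));
    }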

    #[test]
    #[cfg(not(miri))]
    fn negative_fallback_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_fallback(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    #[cfg(not(miri))]
    fn negative_sse2_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_sse2(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }
}