str_indices/byte_chunk.rs

#[cfg(target_arch = "x86_64")]
use core::arch::x86_64;

#[cfg(target_arch = "aarch64")]
use core::arch::aarch64;

// Which type to actually use at build time.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
pub(crate) type Chunk = x86_64::__m128i;
#[cfg(all(feature = "simd", target_arch = "aarch64"))]
pub(crate) type Chunk = aarch64::uint8x16_t;
#[cfg(any(
    not(feature = "simd"),
    not(any(target_arch = "x86_64", target_arch = "aarch64"))
))]
pub(crate) type Chunk = usize;
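// Downstream code (e.g. str_utils) is expected to use `Chunk` only through
// the `ByteChunk` trait below, so the scalar fallback and the SIMD-backed
// types share the same code paths.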

/// Interface for working with chunks of bytes at a time, providing the
/// operations needed for the functionality in str_utils.
pub(crate) trait ByteChunk: Copy + Clone {
    /// Size of the chunk in bytes.
    const SIZE: usize;

    /// Maximum number of iterations the chunk can accumulate
    /// before sum_bytes() becomes inaccurate.
    const MAX_ACC: usize;

    /// Creates a new chunk with all bytes set to zero.
    fn zero() -> Self;

    /// Creates a new chunk with all bytes set to n.
    fn splat(n: u8) -> Self;

    /// Returns whether all bytes are zero or not.
    fn is_zero(&self) -> bool;

    /// Shifts bytes back lexicographically by n bytes.
    fn shift_back_lex(&self, n: usize) -> Self;

    /// Shifts the bottom byte of self into the top byte of n.
    fn shift_across(&self, n: Self) -> Self;

    /// Shifts bits to the right by n bits.
    fn shr(&self, n: usize) -> Self;

    /// Compares bytes for equality with the given byte.
    ///
    /// Bytes that are equal are set to 1, bytes that are not
    /// are set to 0.
    fn cmp_eq_byte(&self, byte: u8) -> Self;

    /// Compares bytes to see if they're in the non-inclusive range (a, b),
    /// where a < b <= 127.
    ///
    /// Bytes in the range are set to 1, bytes not in the range are set to 0.
    fn bytes_between_127(&self, a: u8, b: u8) -> Self;

    /// Performs a bitwise and on two chunks.
    fn bitand(&self, other: Self) -> Self;

    /// Adds the bytes of two chunks together.
    fn add(&self, other: Self) -> Self;

    /// Subtracts other's bytes from this chunk.
    fn sub(&self, other: Self) -> Self;

    /// Increments the nth-from-last lexicographic byte by 1.
    fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self;

    /// Decrements the last lexicographic byte by 1.
    fn dec_last_lex_byte(&self) -> Self;

    /// Returns the sum of all bytes in the chunk.
    fn sum_bytes(&self) -> usize;
}
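
// A minimal usage sketch (added for illustration; not part of the original
// crate): the counting code in str_utils presumably uses this trait roughly
// as follows: flag matching bytes with `cmp_eq_byte()`, accumulate the
// per-byte flags with `add()` for at most `MAX_ACC` iterations, then reduce
// with `sum_bytes()`. The sketch sticks to the always-available `usize`
// fallback so it compiles on every target.
#[cfg(test)]
mod usage_example {
    use super::*;

    fn count_byte(words: &[usize], needle: u8) -> usize {
        let mut count = 0;
        let mut acc = <usize as ByteChunk>::zero();
        for (i, &word) in words.iter().enumerate() {
            acc = acc.add(word.cmp_eq_byte(needle));
            // Flush before any byte lane of the accumulator can overflow.
            if (i + 1) % <usize as ByteChunk>::MAX_ACC == 0 {
                count += acc.sum_bytes();
                acc = <usize as ByteChunk>::zero();
            }
        }
        count + acc.sum_bytes()
    }

    #[test]
    fn count_byte_01() {
        // 0x2E appears twice in this word, 0x20 once, 0x5A never.
        let words = [0x2E_20_41_2E_42_43_44_45usize];
        assert_eq!(count_byte(&words, 0x2E), 2);
        assert_eq!(count_byte(&words, 0x20), 1);
        assert_eq!(count_byte(&words, 0x5A), 0);
    }
}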

impl ByteChunk for usize {
    const SIZE: usize = core::mem::size_of::<usize>();
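    // cmp_eq_byte()/bytes_between_127() produce 0/1 flags, so each
    // accumulation adds at most 1 to every byte lane. The scalar sum_bytes()
    // below folds all lanes into a single byte-wide field, so the running
    // total must stay under 256, hence at most (256 / SIZE) - 1 iterations.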
    const MAX_ACC: usize = (256 / core::mem::size_of::<usize>()) - 1;

    #[inline(always)]
    fn zero() -> Self {
        0
    }

    #[inline(always)]
    fn splat(n: u8) -> Self {
        const ONES: usize = core::usize::MAX / 0xFF;
        ONES * n as usize
    }

    #[inline(always)]
    fn is_zero(&self) -> bool {
        *self == 0
    }

    #[inline(always)]
    fn shift_back_lex(&self, n: usize) -> Self {
        if cfg!(target_endian = "little") {
            *self >> (n * 8)
        } else {
            *self << (n * 8)
        }
    }

    #[inline(always)]
    fn shift_across(&self, n: Self) -> Self {
        let shift_distance = (Self::SIZE - 1) * 8;
        if cfg!(target_endian = "little") {
            (*self >> shift_distance) | (n << 8)
        } else {
            (*self << shift_distance) | (n >> 8)
        }
    }

    #[inline(always)]
    fn shr(&self, n: usize) -> Self {
        *self >> n
    }

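    // SWAR trick: XOR-ing with the splatted byte zeroes exactly the matching
    // bytes. Adding 0x7F to the low 7 bits of each byte then sets the high
    // bit of every nonzero byte (OR-ing the original back in covers bytes
    // whose own high bit was set), so inverting and masking the high bits
    // leaves 1 only in bytes that matched, and the final `>> 7` moves each
    // flag down to the bottom of its byte.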
    #[inline(always)]
    fn cmp_eq_byte(&self, byte: u8) -> Self {
        const ONES: usize = core::usize::MAX / 0xFF;
        const ONES_HIGH: usize = ONES << 7;
        let word = *self ^ (byte as usize * ONES);
        (!(((word & !ONES_HIGH) + !ONES_HIGH) | word) & ONES_HIGH) >> 7
    }

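    // Same SWAR approach, using the classic "bytes in range" trick: for
    // a < b <= 127 and an ASCII byte x, the high bit of (127 + b) - x is set
    // iff x < b, and the high bit of x + (127 - a) is set iff x > a; `!*self`
    // contributes a set high bit only for bytes below 128. AND-ing the three
    // and shifting the high bits down gives the 0/1 flags. Working on the
    // low 7 bits (`tmp`) keeps the arithmetic from carrying across bytes.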
    #[inline(always)]
    fn bytes_between_127(&self, a: u8, b: u8) -> Self {
        const ONES: usize = core::usize::MAX / 0xFF;
        const ONES_HIGH: usize = ONES << 7;
        let tmp = *self & (ONES * 127);
        (((ONES * (127 + b as usize) - tmp) & !*self & (tmp + (ONES * (127 - a as usize))))
            & ONES_HIGH)
            >> 7
    }

    #[inline(always)]
    fn bitand(&self, other: Self) -> Self {
        *self & other
    }

    #[inline(always)]
    fn add(&self, other: Self) -> Self {
        *self + other
    }

    #[inline(always)]
    fn sub(&self, other: Self) -> Self {
        *self - other
    }

    #[inline(always)]
    fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self {
        if cfg!(target_endian = "little") {
            *self + (1 << ((Self::SIZE - 1 - n) * 8))
        } else {
            *self + (1 << (n * 8))
        }
    }

    #[inline(always)]
    fn dec_last_lex_byte(&self) -> Self {
        if cfg!(target_endian = "little") {
            *self - (1 << ((Self::SIZE - 1) * 8))
        } else {
            *self - 1
        }
    }

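    // Multiplying by ONES (0x0101...01) makes the top byte of the product
    // equal to the sum of all the bytes, provided that sum (and every
    // partial sum) fits in a byte, which is exactly what MAX_ACC guarantees.
    // Shifting that top byte down yields the total.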
    #[inline(always)]
    fn sum_bytes(&self) -> usize {
        const ONES: usize = core::usize::MAX / 0xFF;
        self.wrapping_mul(ONES) >> ((Self::SIZE - 1) * 8)
    }
}

// Note: use only SSE2 and older instructions, since these are
// guaranteed on all x86_64 platforms.
#[cfg(target_arch = "x86_64")]
impl ByteChunk for x86_64::__m128i {
    const SIZE: usize = core::mem::size_of::<x86_64::__m128i>();
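    // Unlike the scalar fallback, sum_bytes() here sums the byte lanes
    // exactly, so the only limit is keeping each individual lane from
    // overflowing past 255.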
    const MAX_ACC: usize = 255;

    #[inline(always)]
    fn zero() -> Self {
        unsafe { x86_64::_mm_setzero_si128() }
    }

    #[inline(always)]
    fn splat(n: u8) -> Self {
        unsafe { x86_64::_mm_set1_epi8(n as i8) }
    }

    #[inline(always)]
    fn is_zero(&self) -> bool {
        let tmp = unsafe { core::mem::transmute::<Self, (u64, u64)>(*self) };
        tmp.0 == 0 && tmp.1 == 0
    }

    #[inline(always)]
    fn shift_back_lex(&self, n: usize) -> Self {
        match n {
            0 => *self,
            1 => unsafe { x86_64::_mm_srli_si128(*self, 1) },
            2 => unsafe { x86_64::_mm_srli_si128(*self, 2) },
            3 => unsafe { x86_64::_mm_srli_si128(*self, 3) },
            4 => unsafe { x86_64::_mm_srli_si128(*self, 4) },
            _ => unreachable!(),
        }
    }

    #[inline(always)]
    fn shift_across(&self, n: Self) -> Self {
        unsafe {
            let bottom_byte = x86_64::_mm_srli_si128(*self, 15);
            let rest_shifted = x86_64::_mm_slli_si128(n, 1);
            x86_64::_mm_or_si128(bottom_byte, rest_shifted)
        }
    }

    #[inline(always)]
    fn shr(&self, n: usize) -> Self {
        match n {
            0 => *self,
            1 => unsafe { x86_64::_mm_srli_epi64(*self, 1) },
            2 => unsafe { x86_64::_mm_srli_epi64(*self, 2) },
            3 => unsafe { x86_64::_mm_srli_epi64(*self, 3) },
            4 => unsafe { x86_64::_mm_srli_epi64(*self, 4) },
            _ => unreachable!(),
        }
    }

    #[inline(always)]
    fn cmp_eq_byte(&self, byte: u8) -> Self {
        let tmp = unsafe { x86_64::_mm_cmpeq_epi8(*self, Self::splat(byte)) };
        unsafe { x86_64::_mm_and_si128(tmp, Self::splat(1)) }
    }

    #[inline(always)]
    fn bytes_between_127(&self, a: u8, b: u8) -> Self {
        let tmp1 = unsafe { x86_64::_mm_cmpgt_epi8(*self, Self::splat(a)) };
        let tmp2 = unsafe { x86_64::_mm_cmplt_epi8(*self, Self::splat(b)) };
        let tmp3 = unsafe { x86_64::_mm_and_si128(tmp1, tmp2) };
        unsafe { x86_64::_mm_and_si128(tmp3, Self::splat(1)) }
    }

    #[inline(always)]
    fn bitand(&self, other: Self) -> Self {
        unsafe { x86_64::_mm_and_si128(*self, other) }
    }

    #[inline(always)]
    fn add(&self, other: Self) -> Self {
        unsafe { x86_64::_mm_add_epi8(*self, other) }
    }

    #[inline(always)]
    fn sub(&self, other: Self) -> Self {
        unsafe { x86_64::_mm_sub_epi8(*self, other) }
    }

    #[inline(always)]
    fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self {
        let mut tmp = unsafe { core::mem::transmute::<Self, [u8; 16]>(*self) };
        tmp[15 - n] += 1;
        unsafe { core::mem::transmute::<[u8; 16], Self>(tmp) }
    }

    #[inline(always)]
    fn dec_last_lex_byte(&self) -> Self {
        let mut tmp = unsafe { core::mem::transmute::<Self, [u8; 16]>(*self) };
        tmp[15] -= 1;
        unsafe { core::mem::transmute::<[u8; 16], Self>(tmp) }
    }

    #[inline(always)]
    fn sum_bytes(&self) -> usize {
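        // _mm_sad_epu8 against zero computes the plain byte sum of each
        // 8-byte half into the low bits of the corresponding u64 lane;
        // adding the two lanes gives the sum of all 16 bytes.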
        let half_sum = unsafe { x86_64::_mm_sad_epu8(*self, x86_64::_mm_setzero_si128()) };
        let (low, high) = unsafe { core::mem::transmute::<Self, (u64, u64)>(half_sum) };
        (low + high) as usize
    }
}

#[cfg(target_arch = "aarch64")]
impl ByteChunk for aarch64::uint8x16_t {
    const SIZE: usize = core::mem::size_of::<Self>();
    const MAX_ACC: usize = 255;

    #[inline(always)]
    fn zero() -> Self {
        unsafe { aarch64::vdupq_n_u8(0) }
    }

    #[inline(always)]
    fn splat(n: u8) -> Self {
        unsafe { aarch64::vdupq_n_u8(n) }
    }

    #[inline(always)]
    fn is_zero(&self) -> bool {
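        // The horizontal max across all lanes is zero iff every lane is zero.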
        unsafe { aarch64::vmaxvq_u8(*self) == 0 }
    }

    #[inline(always)]
    fn shift_back_lex(&self, n: usize) -> Self {
        unsafe {
            match n {
                1 => aarch64::vextq_u8(*self, Self::zero(), 1),
                2 => aarch64::vextq_u8(*self, Self::zero(), 2),
                _ => unreachable!(),
            }
        }
    }

    #[inline(always)]
    fn shift_across(&self, n: Self) -> Self {
        unsafe { aarch64::vextq_u8(*self, n, 15) }
    }

    #[inline(always)]
    fn shr(&self, n: usize) -> Self {
        unsafe {
            let u64_vec = aarch64::vreinterpretq_u64_u8(*self);
            let result = match n {
                1 => aarch64::vshrq_n_u64(u64_vec, 1),
                _ => unreachable!(),
            };
            aarch64::vreinterpretq_u8_u64(result)
        }
    }

    #[inline(always)]
    fn cmp_eq_byte(&self, byte: u8) -> Self {
        unsafe {
            let equal = aarch64::vceqq_u8(*self, Self::splat(byte));
            aarch64::vshrq_n_u8(equal, 7)
        }
    }

    #[inline(always)]
    fn bytes_between_127(&self, a: u8, b: u8) -> Self {
        use aarch64::vreinterpretq_s8_u8 as cast;
        unsafe {
            let a_gt = aarch64::vcgtq_s8(cast(*self), cast(Self::splat(a)));
            let b_gt = aarch64::vcltq_s8(cast(*self), cast(Self::splat(b)));
            let in_range = aarch64::vandq_u8(a_gt, b_gt);
            aarch64::vshrq_n_u8(in_range, 7)
        }
    }

    #[inline(always)]
    fn bitand(&self, other: Self) -> Self {
        unsafe { aarch64::vandq_u8(*self, other) }
    }

    #[inline(always)]
    fn add(&self, other: Self) -> Self {
        unsafe { aarch64::vaddq_u8(*self, other) }
    }

    #[inline(always)]
    fn sub(&self, other: Self) -> Self {
        unsafe { aarch64::vsubq_u8(*self, other) }
    }

    #[inline(always)]
    fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self {
        const END: i32 = core::mem::size_of::<aarch64::uint8x16_t>() as i32 - 1;
        match n {
            0 => unsafe {
                let lane = aarch64::vgetq_lane_u8(*self, END);
                aarch64::vsetq_lane_u8(lane + 1, *self, END)
            },
            1 => unsafe {
                let lane = aarch64::vgetq_lane_u8(*self, END - 1);
                aarch64::vsetq_lane_u8(lane + 1, *self, END - 1)
            },
            _ => unreachable!(),
        }
    }

    #[inline(always)]
    fn dec_last_lex_byte(&self) -> Self {
        const END: i32 = core::mem::size_of::<aarch64::uint8x16_t>() as i32 - 1;
        unsafe {
            let last = aarch64::vgetq_lane_u8(*self, END);
            aarch64::vsetq_lane_u8(last - 1, *self, END)
        }
    }

    #[inline(always)]
    fn sum_bytes(&self) -> usize {
        unsafe { aarch64::vaddlvq_u8(*self).into() }
    }
}

//=============================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn usize_flag_bytes_01() {
        let v: usize = 0xE2_09_08_A6_E2_A6_E2_09;
        assert_eq!(0x00_00_00_00_00_00_00_00, v.cmp_eq_byte(0x07));
        assert_eq!(0x00_00_01_00_00_00_00_00, v.cmp_eq_byte(0x08));
        assert_eq!(0x00_01_00_00_00_00_00_01, v.cmp_eq_byte(0x09));
        assert_eq!(0x00_00_00_01_00_01_00_00, v.cmp_eq_byte(0xA6));
        assert_eq!(0x01_00_00_00_01_00_01_00, v.cmp_eq_byte(0xE2));
    }

    #[test]
    fn usize_bytes_between_127_01() {
        let v: usize = 0x7E_09_00_A6_FF_7F_08_07;
        assert_eq!(0x01_01_00_00_00_00_01_01, v.bytes_between_127(0x00, 0x7F));
        assert_eq!(0x00_01_00_00_00_00_01_00, v.bytes_between_127(0x07, 0x7E));
        assert_eq!(0x00_01_00_00_00_00_00_00, v.bytes_between_127(0x08, 0x7E));
    }

    #[cfg(all(feature = "simd", any(target_arch = "x86_64", target_arch = "aarch64")))]
    #[test]
    fn sum_bytes_simd() {
        let ones = Chunk::splat(1);
        let mut acc = Chunk::zero();
        for _ in 0..Chunk::MAX_ACC {
            acc = acc.add(ones);
        }

        assert_eq!(acc.sum_bytes(), Chunk::SIZE * Chunk::MAX_ACC);
    }
}