#[cfg(target_arch = "x86_64")]
use core::arch::x86_64;

#[cfg(target_arch = "aarch64")]
use core::arch::aarch64;

// The chunk type used for the SIMD fast paths, falling back to `usize`
// (SWAR) when the `simd` feature is disabled or the architecture has no
// dedicated implementation.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
pub(crate) type Chunk = x86_64::__m128i;
#[cfg(all(feature = "simd", target_arch = "aarch64"))]
pub(crate) type Chunk = aarch64::uint8x16_t;
#[cfg(any(
    not(feature = "simd"),
    not(any(target_arch = "x86_64", target_arch = "aarch64"))
))]
pub(crate) type Chunk = usize;

/// Interface for working with chunks of bytes at a time, with both SIMD and
/// plain-`usize` (SWAR) implementations.
///
/// Byte positions are described in "lexicographic" order, i.e. the order in
/// which the bytes appear in memory.
pub(crate) trait ByteChunk: Copy + Clone {
    /// Size of the chunk in bytes.
    const SIZE: usize;

    /// Maximum number of times a chunk of 0/1 flags can be accumulated with
    /// `add` before `sum_bytes` may become inaccurate.
    const MAX_ACC: usize;

    /// Returns a chunk with all bytes set to zero.
    fn zero() -> Self;

    /// Returns a chunk with all bytes set to `n`.
    fn splat(n: u8) -> Self;

    /// Returns whether all bytes are zero.
    fn is_zero(&self) -> bool;

    /// Shifts the bytes back lexicographically by `n` bytes, filling the end
    /// with zeroes.
    fn shift_back_lex(&self, n: usize) -> Self;

    /// Shifts `n`'s bytes forward lexicographically by one, filling the
    /// front with the lexicographically last byte of `self`.
    fn shift_across(&self, n: Self) -> Self;

    /// Shifts the bits right by `n` bits (within 64-bit lanes for the SIMD
    /// implementations).
    fn shr(&self, n: usize) -> Self;

    /// Compares each byte with `byte`, setting equal bytes to 1 and all
    /// others to 0.
    fn cmp_eq_byte(&self, byte: u8) -> Self;

    /// Flags the bytes in the exclusive range `(a, b)`, where `a < b <= 127`.
    /// Bytes in the range are set to 1, all others to 0.
    fn bytes_between_127(&self, a: u8, b: u8) -> Self;

    /// Performs a bitwise AND of two chunks.
    fn bitand(&self, other: Self) -> Self;

    /// Adds the bytes of two chunks together.
    fn add(&self, other: Self) -> Self;

    /// Subtracts `other`'s bytes from this chunk.
    fn sub(&self, other: Self) -> Self;

    /// Increments the `n`th-from-last lexicographic byte by 1.
    fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self;

    /// Decrements the last lexicographic byte by 1.
    fn dec_last_lex_byte(&self) -> Self;

    /// Returns the sum of all bytes in the chunk.
    fn sum_bytes(&self) -> usize;
}

impl ByteChunk for usize {
    const SIZE: usize = core::mem::size_of::<usize>();
    const MAX_ACC: usize = (256 / core::mem::size_of::<usize>()) - 1;

    #[inline(always)]
    fn zero() -> Self {
        0
    }

    #[inline(always)]
    fn splat(n: u8) -> Self {
        const ONES: usize = usize::MAX / 0xFF;
        ONES * n as usize
    }

    #[inline(always)]
    fn is_zero(&self) -> bool {
        *self == 0
    }

    #[inline(always)]
    fn shift_back_lex(&self, n: usize) -> Self {
        if cfg!(target_endian = "little") {
            *self >> (n * 8)
        } else {
            *self << (n * 8)
        }
    }

    #[inline(always)]
    fn shift_across(&self, n: Self) -> Self {
        let shift_distance = (Self::SIZE - 1) * 8;
        if cfg!(target_endian = "little") {
            (*self >> shift_distance) | (n << 8)
        } else {
            (*self << shift_distance) | (n >> 8)
        }
    }

    #[inline(always)]
    fn shr(&self, n: usize) -> Self {
        *self >> n
    }

    #[inline(always)]
    fn cmp_eq_byte(&self, byte: u8) -> Self {
        // SWAR trick: XOR-ing with the splatted byte zeroes out the matching
        // bytes, then the zero bytes are detected and mapped to 0x01 while
        // all other bytes map to 0x00.
        const ONES: usize = usize::MAX / 0xFF;
        const ONES_HIGH: usize = ONES << 7;
        let word = *self ^ (byte as usize * ONES);
        (!(((word & !ONES_HIGH) + !ONES_HIGH) | word) & ONES_HIGH) >> 7
    }

    #[inline(always)]
    fn bytes_between_127(&self, a: u8, b: u8) -> Self {
        // SWAR range check: flags the bytes in the exclusive range (a, b),
        // relying on a < b <= 127 so the per-byte additions can't carry
        // across byte boundaries.
        const ONES: usize = usize::MAX / 0xFF;
        const ONES_HIGH: usize = ONES << 7;
        let tmp = *self & (ONES * 127);
        (((ONES * (127 + b as usize) - tmp) & !*self & (tmp + (ONES * (127 - a as usize))))
            & ONES_HIGH)
            >> 7
    }

    #[inline(always)]
    fn bitand(&self, other: Self) -> Self {
        *self & other
    }

    #[inline(always)]
    fn add(&self, other: Self) -> Self {
        *self + other
    }

    #[inline(always)]
    fn sub(&self, other: Self) -> Self {
        *self - other
    }

    #[inline(always)]
    fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self {
        if cfg!(target_endian = "little") {
            *self + (1 << ((Self::SIZE - 1 - n) * 8))
        } else {
            *self + (1 << (n * 8))
        }
    }

    #[inline(always)]
    fn dec_last_lex_byte(&self) -> Self {
        if cfg!(target_endian = "little") {
            *self - (1 << ((Self::SIZE - 1) * 8))
        } else {
            *self - 1
        }
    }

    #[inline(always)]
    fn sum_bytes(&self) -> usize {
        // Multiplying by ONES gathers the sum of all bytes into the top byte,
        // which is then shifted down. Only valid while the true sum is at
        // most 255, which `MAX_ACC` is chosen to guarantee.
        const ONES: usize = usize::MAX / 0xFF;
        self.wrapping_mul(ONES) >> ((Self::SIZE - 1) * 8)
    }
}

#[cfg(target_arch = "x86_64")]
impl ByteChunk for x86_64::__m128i {
    const SIZE: usize = core::mem::size_of::<x86_64::__m128i>();
    const MAX_ACC: usize = 255;

    #[inline(always)]
    fn zero() -> Self {
        unsafe { x86_64::_mm_setzero_si128() }
    }

    #[inline(always)]
    fn splat(n: u8) -> Self {
        unsafe { x86_64::_mm_set1_epi8(n as i8) }
    }

    #[inline(always)]
    fn is_zero(&self) -> bool {
        let tmp = unsafe { core::mem::transmute::<Self, (u64, u64)>(*self) };
        tmp.0 == 0 && tmp.1 == 0
    }

    #[inline(always)]
    fn shift_back_lex(&self, n: usize) -> Self {
        // The shift amount of `_mm_srli_si128` must be a constant, so only
        // the handful of shifts that are actually needed are supported.
        match n {
            0 => *self,
            1 => unsafe { x86_64::_mm_srli_si128(*self, 1) },
            2 => unsafe { x86_64::_mm_srli_si128(*self, 2) },
            3 => unsafe { x86_64::_mm_srli_si128(*self, 3) },
            4 => unsafe { x86_64::_mm_srli_si128(*self, 4) },
            _ => unreachable!(),
        }
    }

    #[inline(always)]
    fn shift_across(&self, n: Self) -> Self {
        unsafe {
            let bottom_byte = x86_64::_mm_srli_si128(*self, 15);
            let rest_shifted = x86_64::_mm_slli_si128(n, 1);
            x86_64::_mm_or_si128(bottom_byte, rest_shifted)
        }
    }

    #[inline(always)]
    fn shr(&self, n: usize) -> Self {
        match n {
            0 => *self,
            1 => unsafe { x86_64::_mm_srli_epi64(*self, 1) },
            2 => unsafe { x86_64::_mm_srli_epi64(*self, 2) },
            3 => unsafe { x86_64::_mm_srli_epi64(*self, 3) },
            4 => unsafe { x86_64::_mm_srli_epi64(*self, 4) },
            _ => unreachable!(),
        }
    }

    #[inline(always)]
    fn cmp_eq_byte(&self, byte: u8) -> Self {
        // `_mm_cmpeq_epi8` yields 0xFF for equal bytes; mask that down to
        // the 0x01 flags the rest of the code expects.
        let tmp = unsafe { x86_64::_mm_cmpeq_epi8(*self, Self::splat(byte)) };
        unsafe { x86_64::_mm_and_si128(tmp, Self::splat(1)) }
    }

    #[inline(always)]
    fn bytes_between_127(&self, a: u8, b: u8) -> Self {
        let tmp1 = unsafe { x86_64::_mm_cmpgt_epi8(*self, Self::splat(a)) };
        let tmp2 = unsafe { x86_64::_mm_cmplt_epi8(*self, Self::splat(b)) };
        let tmp3 = unsafe { x86_64::_mm_and_si128(tmp1, tmp2) };
        unsafe { x86_64::_mm_and_si128(tmp3, Self::splat(1)) }
    }

    #[inline(always)]
    fn bitand(&self, other: Self) -> Self {
        unsafe { x86_64::_mm_and_si128(*self, other) }
    }

    #[inline(always)]
    fn add(&self, other: Self) -> Self {
        unsafe { x86_64::_mm_add_epi8(*self, other) }
    }

    #[inline(always)]
    fn sub(&self, other: Self) -> Self {
        unsafe { x86_64::_mm_sub_epi8(*self, other) }
    }

    #[inline(always)]
    fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self {
        let mut tmp = unsafe { core::mem::transmute::<Self, [u8; 16]>(*self) };
        tmp[15 - n] += 1;
        unsafe { core::mem::transmute::<[u8; 16], Self>(tmp) }
    }

    #[inline(always)]
    fn dec_last_lex_byte(&self) -> Self {
        let mut tmp = unsafe { core::mem::transmute::<Self, [u8; 16]>(*self) };
        tmp[15] -= 1;
        unsafe { core::mem::transmute::<[u8; 16], Self>(tmp) }
    }

    #[inline(always)]
    fn sum_bytes(&self) -> usize {
        // `_mm_sad_epu8` against zero sums the bytes of each 64-bit half.
        let half_sum = unsafe { x86_64::_mm_sad_epu8(*self, x86_64::_mm_setzero_si128()) };
        let (low, high) = unsafe { core::mem::transmute::<Self, (u64, u64)>(half_sum) };
        (low + high) as usize
    }
}

#[cfg(target_arch = "aarch64")]
impl ByteChunk for aarch64::uint8x16_t {
    const SIZE: usize = core::mem::size_of::<Self>();
    const MAX_ACC: usize = 255;

    #[inline(always)]
    fn zero() -> Self {
        unsafe { aarch64::vdupq_n_u8(0) }
    }

    #[inline(always)]
    fn splat(n: u8) -> Self {
        unsafe { aarch64::vdupq_n_u8(n) }
    }

    #[inline(always)]
    fn is_zero(&self) -> bool {
        unsafe { aarch64::vmaxvq_u8(*self) == 0 }
    }

    #[inline(always)]
    fn shift_back_lex(&self, n: usize) -> Self {
        // `vextq_u8` takes a constant byte count, so only the shifts that
        // are actually needed are supported.
        unsafe {
            match n {
                1 => aarch64::vextq_u8(*self, Self::zero(), 1),
                2 => aarch64::vextq_u8(*self, Self::zero(), 2),
                _ => unreachable!(),
            }
        }
    }

    #[inline(always)]
    fn shift_across(&self, n: Self) -> Self {
        unsafe { aarch64::vextq_u8(*self, n, 15) }
    }

    #[inline(always)]
    fn shr(&self, n: usize) -> Self {
        unsafe {
            let u64_vec = aarch64::vreinterpretq_u64_u8(*self);
            let result = match n {
                1 => aarch64::vshrq_n_u64(u64_vec, 1),
                _ => unreachable!(),
            };
            aarch64::vreinterpretq_u8_u64(result)
        }
    }

    #[inline(always)]
    fn cmp_eq_byte(&self, byte: u8) -> Self {
        // `vceqq_u8` yields 0xFF for equal bytes; shifting right by 7 turns
        // that into the 0x01 flags the rest of the code expects.
        unsafe {
            let equal = aarch64::vceqq_u8(*self, Self::splat(byte));
            aarch64::vshrq_n_u8(equal, 7)
        }
    }

    #[inline(always)]
    fn bytes_between_127(&self, a: u8, b: u8) -> Self {
        use aarch64::vreinterpretq_s8_u8 as cast;
        unsafe {
            let a_gt = aarch64::vcgtq_s8(cast(*self), cast(Self::splat(a)));
            let b_gt = aarch64::vcltq_s8(cast(*self), cast(Self::splat(b)));
            let in_range = aarch64::vandq_u8(a_gt, b_gt);
            aarch64::vshrq_n_u8(in_range, 7)
        }
    }

    #[inline(always)]
    fn bitand(&self, other: Self) -> Self {
        unsafe { aarch64::vandq_u8(*self, other) }
    }

    #[inline(always)]
    fn add(&self, other: Self) -> Self {
        unsafe { aarch64::vaddq_u8(*self, other) }
    }

    #[inline(always)]
    fn sub(&self, other: Self) -> Self {
        unsafe { aarch64::vsubq_u8(*self, other) }
    }

    #[inline(always)]
    fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self {
        // Index of the lexicographically last lane, derived from the concrete
        // vector type so it stays correct even in builds where `Chunk` falls
        // back to `usize`.
        const END: i32 = core::mem::size_of::<aarch64::uint8x16_t>() as i32 - 1;
        match n {
            0 => unsafe {
                let lane = aarch64::vgetq_lane_u8(*self, END);
                aarch64::vsetq_lane_u8(lane + 1, *self, END)
            },
            1 => unsafe {
                let lane = aarch64::vgetq_lane_u8(*self, END - 1);
                aarch64::vsetq_lane_u8(lane + 1, *self, END - 1)
            },
            _ => unreachable!(),
        }
    }

    #[inline(always)]
    fn dec_last_lex_byte(&self) -> Self {
        // Index of the lexicographically last lane.
        const END: i32 = core::mem::size_of::<aarch64::uint8x16_t>() as i32 - 1;
        unsafe {
            let last = aarch64::vgetq_lane_u8(*self, END);
            aarch64::vsetq_lane_u8(last - 1, *self, END)
        }
    }

    #[inline(always)]
    fn sum_bytes(&self) -> usize {
        unsafe { aarch64::vaddlvq_u8(*self).into() }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

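    // Illustrative helper, not part of the original source: a sketch of the
    // intended usage pattern for `ByteChunk`, assuming unaligned chunk loads
    // are acceptable. Matching bytes are flagged with `cmp_eq_byte`, the 0/1
    // flags are accumulated with `add`, and the accumulator is flushed with
    // `sum_bytes()` at least every `MAX_ACC` iterations so no byte of it can
    // overflow.
    fn count_byte_chunked(bytes: &[u8], byte: u8) -> usize {
        let mut count = 0;
        let mut acc = Chunk::zero();
        let mut acc_iters = 0;

        let mut chunks = bytes.chunks_exact(Chunk::SIZE);
        for chunk in chunks.by_ref() {
            // `Chunk` is plain old data, so an unaligned read of the right
            // width is fine here.
            let c = unsafe { core::ptr::read_unaligned(chunk.as_ptr() as *const Chunk) };
            acc = acc.add(c.cmp_eq_byte(byte));
            acc_iters += 1;

            if acc_iters == Chunk::MAX_ACC {
                count += acc.sum_bytes();
                acc = Chunk::zero();
                acc_iters = 0;
            }
        }
        count += acc.sum_bytes();

        // Scalar fallback for the tail that doesn't fill a whole chunk.
        count + chunks.remainder().iter().filter(|&&b| b == byte).count()
    }
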
    #[test]
    fn usize_flag_bytes_01() {
        let v: usize = 0xE2_09_08_A6_E2_A6_E2_09;
        assert_eq!(0x00_00_00_00_00_00_00_00, v.cmp_eq_byte(0x07));
        assert_eq!(0x00_00_01_00_00_00_00_00, v.cmp_eq_byte(0x08));
        assert_eq!(0x00_01_00_00_00_00_00_01, v.cmp_eq_byte(0x09));
        assert_eq!(0x00_00_00_01_00_01_00_00, v.cmp_eq_byte(0xA6));
        assert_eq!(0x01_00_00_00_01_00_01_00, v.cmp_eq_byte(0xE2));
    }

    #[test]
    fn usize_bytes_between_127_01() {
        let v: usize = 0x7E_09_00_A6_FF_7F_08_07;
        assert_eq!(0x01_01_00_00_00_00_01_01, v.bytes_between_127(0x00, 0x7F));
        assert_eq!(0x00_01_00_00_00_00_01_00, v.bytes_between_127(0x07, 0x7E));
        assert_eq!(0x00_01_00_00_00_00_00_00, v.bytes_between_127(0x08, 0x7E));
    }
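
    // Illustrative test, not from the original source: `sum_bytes` reports
    // the horizontal sum of all bytes as long as the total stays below 256.
    #[test]
    fn usize_sum_bytes_01() {
        let v = <usize as ByteChunk>::splat(3);
        assert_eq!(v.sum_bytes(), 3 * <usize as ByteChunk>::SIZE);
    }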

    #[cfg(all(feature = "simd", any(target_arch = "x86_64", target_arch = "aarch64")))]
    #[test]
    fn sum_bytes_simd() {
        let ones = Chunk::splat(1);
        let mut acc = Chunk::zero();
        for _ in 0..Chunk::MAX_ACC {
            acc = acc.add(ones);
        }

        assert_eq!(acc.sum_bytes(), Chunk::SIZE * Chunk::MAX_ACC);
    }
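
    // Illustrative tests, not from the original source. The first checks the
    // byte-carrying behaviour of `shift_across`; the second exercises the
    // `count_byte_chunked` sketch above, including the `MAX_ACC` flush.
    #[test]
    fn chunk_shift_across_01() {
        assert_eq!(Chunk::splat(1).shift_across(Chunk::zero()).sum_bytes(), 1);
        assert_eq!(
            Chunk::zero().shift_across(Chunk::splat(1)).sum_bytes(),
            Chunk::SIZE - 1
        );
    }

    #[test]
    fn count_byte_chunked_01() {
        let mut bytes = [0u8; 4223];
        for (i, b) in bytes.iter_mut().enumerate() {
            *b = (i % 7) as u8;
        }
        let expected = bytes.iter().filter(|&&b| b == 3).count();
        assert_eq!(count_byte_chunked(&bytes, 3), expected);
        assert_eq!(count_byte_chunked(&bytes, 9), 0);
    }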
}