num_modular/
double.rs

//! This module implements a double-width integer type built on the largest built-in integer (u128).
//! Some of the optimizations are adapted from the `ethnum` and `zkp-u256` crates.
3
4use core::ops::*;
5
/// The widest built-in unsigned integer type (currently [u128]).
#[allow(non_camel_case_types)]
pub type umax = u128;

/// Number of bits in one half of a [umax].
const HALF_BITS: u32 = umax::BITS >> 1;
11
12// Split umax into hi and lo parts. Tt's critical to use inline here
13#[inline(always)]
14const fn split(v: umax) -> (umax, umax) {
15    (v >> HALF_BITS, v & (umax::MAX >> HALF_BITS))
16}
17
18#[inline(always)]
19const fn div_rem(n: umax, d: umax) -> (umax, umax) {
20    (n / d, n % d)
21}
22
23#[allow(non_camel_case_types)]
24#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
25/// A double width integer type based on the largest built-in integer type [umax] (currently [u128]), and
26/// to support double-width operations on it is the only goal for this type.
27///
28/// Although it can be regarded as u256, it's not as feature-rich as in other crates
29/// since it's only designed to support this crate and few other crates (will be noted in comments).
30pub struct udouble {
31    /// Most significant part
32    pub hi: umax,
33    /// Least significant part
34    pub lo: umax,
35}
36
37impl udouble {
38    pub const MAX: Self = Self {
39        lo: umax::MAX,
40        hi: umax::MAX,
41    };
42
43    //> (used in u128::addm)
44    #[inline]
45    pub const fn widening_add(lhs: umax, rhs: umax) -> Self {
46        let (sum, carry) = lhs.overflowing_add(rhs);
47        udouble {
48            hi: carry as umax,
49            lo: sum,
50        }
51    }
52
53    /// Calculate multiplication of two [umax] integers with result represented in double width integer
54    // equivalent to umul_ppmm, can be implemented efficiently with carrying_mul and widening_mul implemented (rust#85532)
55    //> (used in u128::mulm, MersenneInt, Montgomery::<u128>::{reduce, mul}, num-order::NumHash)
56    #[inline]
57    pub const fn widening_mul(lhs: umax, rhs: umax) -> Self {
58        let ((x1, x0), (y1, y0)) = (split(lhs), split(rhs));
59
60        let z2 = x1 * y1;
61        let (c0, z0) = split(x0 * y0); // c0 <= umax::MAX - 1
62        let (c1, z1) = split(x1 * y0 + c0);
63        let z2 = z2 + c1;
64        let (c1, z1) = split(x0 * y1 + z1);
65        Self {
66            hi: z2 + c1,
67            lo: z0 | z1 << HALF_BITS,
68        }
69    }
70
71    /// Optimized squaring function for [umax] integers
72    //> (used in Montgomery::<u128>::{square})
73    #[inline]
74    pub const fn widening_square(x: umax) -> Self {
75        // the algorithm here is basically the same as widening_mul
76        let (x1, x0) = split(x);
77
78        let z2 = x1 * x1;
79        let m = x1 * x0;
80        let (c0, z0) = split(x0 * x0);
81        let (c1, z1) = split(m + c0);
82        let z2 = z2 + c1;
83        let (c1, z1) = split(m + z1);
84        Self {
85            hi: z2 + c1,
86            lo: z0 | z1 << HALF_BITS,
87        }
88    }
89
90    //> (used in Montgomery::<u128>::reduce)
91    #[inline]
92    pub const fn overflowing_add(&self, rhs: Self) -> (Self, bool) {
93        let (lo, carry) = self.lo.overflowing_add(rhs.lo);
94        let (hi, of1) = self.hi.overflowing_add(rhs.hi);
95        let (hi, of2) = hi.overflowing_add(carry as umax);
96        (Self { lo, hi }, of1 || of2)
97    }
98
99    // double by double multiplication, listed here in case of future use
100    #[allow(dead_code)]
101    fn overflowing_mul(&self, rhs: Self) -> (Self, bool) {
102        let c2 = self.hi != 0 && rhs.hi != 0;
103        let Self { lo: z0, hi: c0 } = Self::widening_mul(self.lo, rhs.lo);
104        let (z1x, c1x) = umax::overflowing_mul(self.lo, rhs.hi);
105        let (z1y, c1y) = umax::overflowing_mul(self.hi, rhs.lo);
106        let (z1z, c1z) = umax::overflowing_add(z1x, z1y);
107        let (z1, c1) = z1z.overflowing_add(c0);
108        (Self { hi: z1, lo: z0 }, c1x | c1y | c1z | c1 | c2)
109    }
110
111    /// Multiplication of double width and single width
112    //> (used in num-order:NumHash)
113    #[inline]
114    pub const fn overflowing_mul1(&self, rhs: umax) -> (Self, bool) {
115        let Self { lo: z0, hi: c0 } = Self::widening_mul(self.lo, rhs);
116        let (z1, c1) = self.hi.overflowing_mul(rhs);
117        let (z1, cs1) = z1.overflowing_add(c0);
118        (Self { hi: z1, lo: z0 }, c1 | cs1)
119    }
120
121    /// Multiplication of double width and single width
122    //> (used in Self::mul::<umax>)
123    #[inline]
124    pub fn checked_mul1(&self, rhs: umax) -> Option<Self> {
125        let Self { lo: z0, hi: c0 } = Self::widening_mul(self.lo, rhs);
126        let z1 = self.hi.checked_mul(rhs)?.checked_add(c0)?;
127        Some(Self { hi: z1, lo: z0 })
128    }
129
130    //> (used in num-order::NumHash)
131    #[inline]
132    pub fn checked_shl(self, rhs: u32) -> Option<Self> {
133        if rhs < umax::BITS * 2 {
134            Some(self << rhs)
135        } else {
136            None
137        }
138    }
139
140    //> (not used yet)
141    #[inline]
142    pub fn checked_shr(self, rhs: u32) -> Option<Self> {
143        if rhs < umax::BITS * 2 {
144            Some(self >> rhs)
145        } else {
146            None
147        }
148    }
149}
150
151impl From<umax> for udouble {
152    #[inline]
153    fn from(v: umax) -> Self {
154        Self { lo: v, hi: 0 }
155    }
156}
157
158impl Add for udouble {
159    type Output = udouble;
160
161    // equivalent to add_ssaaaa
162    #[inline]
163    fn add(self, rhs: Self) -> Self::Output {
164        let (lo, carry) = self.lo.overflowing_add(rhs.lo);
165        let hi = self.hi + rhs.hi + carry as umax;
166        Self { lo, hi }
167    }
168}
169//> (used in Self::div_rem)
170impl Add<umax> for udouble {
171    type Output = udouble;
172    #[inline]
173    fn add(self, rhs: umax) -> Self::Output {
174        let (lo, carry) = self.lo.overflowing_add(rhs);
175        let hi = if carry { self.hi + 1 } else { self.hi };
176        Self { lo, hi }
177    }
178}
179impl AddAssign for udouble {
180    #[inline]
181    fn add_assign(&mut self, rhs: Self) {
182        let (lo, carry) = self.lo.overflowing_add(rhs.lo);
183        self.lo = lo;
184        self.hi += rhs.hi + carry as umax;
185    }
186}
187impl AddAssign<umax> for udouble {
188    #[inline]
189    fn add_assign(&mut self, rhs: umax) {
190        let (lo, carry) = self.lo.overflowing_add(rhs);
191        self.lo = lo;
192        if carry {
193            self.hi += 1
194        }
195    }
196}
197
198//> (used in test of Add)
199impl Sub for udouble {
200    type Output = Self;
201    #[inline]
202    fn sub(self, rhs: Self) -> Self::Output {
203        let carry = self.lo < rhs.lo;
204        let lo = self.lo.wrapping_sub(rhs.lo);
205        let hi = self.hi - rhs.hi - carry as umax;
206        Self { lo, hi }
207    }
208}
209impl Sub<umax> for udouble {
210    type Output = Self;
211    #[inline]
212    fn sub(self, rhs: umax) -> Self::Output {
213        let carry = self.lo < rhs;
214        let lo = self.lo.wrapping_sub(rhs);
215        let hi = if carry { self.hi - 1 } else { self.hi };
216        Self { lo, hi }
217    }
218}
219//> (used in test of AddAssign)
220impl SubAssign for udouble {
221    #[inline]
222    fn sub_assign(&mut self, rhs: Self) {
223        let carry = self.lo < rhs.lo;
224        self.lo = self.lo.wrapping_sub(rhs.lo);
225        self.hi -= rhs.hi + carry as umax;
226    }
227}
228impl SubAssign<umax> for udouble {
229    #[inline]
230    fn sub_assign(&mut self, rhs: umax) {
231        let carry = self.lo < rhs;
232        self.lo = self.lo.wrapping_sub(rhs);
233        if carry {
234            self.hi -= 1;
235        }
236    }
237}
238
239macro_rules! impl_sh_ops {
240    ($t:ty) => {
241        //> (used in Self::checked_shl)
242        impl Shl<$t> for udouble {
243            type Output = Self;
244            #[inline]
245            fn shl(self, rhs: $t) -> Self::Output {
246                match rhs {
247                    0 => self, // avoid shifting by full bits, which is UB
248                    s if s >= umax::BITS as $t => Self {
249                        hi: self.lo << (s - umax::BITS as $t),
250                        lo: 0,
251                    },
252                    s => Self {
253                        lo: self.lo << s,
254                        hi: (self.hi << s) | (self.lo >> (umax::BITS as $t - s)),
255                    },
256                }
257            }
258        }
259        //> (not used yet)
260        impl ShlAssign<$t> for udouble {
261            #[inline]
262            fn shl_assign(&mut self, rhs: $t) {
263                match rhs {
264                    0 => {}
265                    s if s >= umax::BITS as $t => {
266                        self.hi = self.lo << (s - umax::BITS as $t);
267                        self.lo = 0;
268                    }
269                    s => {
270                        self.hi <<= s;
271                        self.hi |= self.lo >> (umax::BITS as $t - s);
272                        self.lo <<= s;
273                    }
274                }
275            }
276        }
277        //> (used in Self::checked_shr)
278        impl Shr<$t> for udouble {
279            type Output = Self;
280            #[inline]
281            fn shr(self, rhs: $t) -> Self::Output {
282                match rhs {
283                    0 => self,
284                    s if s >= umax::BITS as $t => Self {
285                        lo: self.hi >> (rhs - umax::BITS as $t),
286                        hi: 0,
287                    },
288                    s => Self {
289                        hi: self.hi >> s,
290                        lo: (self.lo >> s) | (self.hi << (umax::BITS as $t - s)),
291                    },
292                }
293            }
294        }
295        //> (not used yet)
296        impl ShrAssign<$t> for udouble {
297            #[inline]
298            fn shr_assign(&mut self, rhs: $t) {
299                match rhs {
300                    0 => {}
301                    s if s >= umax::BITS as $t => {
302                        self.lo = self.hi >> (rhs - umax::BITS as $t);
303                        self.hi = 0;
304                    }
305                    s => {
306                        self.lo >>= s;
307                        self.lo |= self.hi << (umax::BITS as $t - s);
308                        self.hi >>= s;
309                    }
310                }
311            }
312        }
313    };
314}
315
316// only implement most useful ones, so that we don't need to optimize so many variants
317impl_sh_ops!(u8);
318impl_sh_ops!(u16);
319impl_sh_ops!(u32);
320
321//> (not used yet)
322impl BitAnd for udouble {
323    type Output = Self;
324    #[inline]
325    fn bitand(self, rhs: Self) -> Self::Output {
326        Self {
327            lo: self.lo & rhs.lo,
328            hi: self.hi & rhs.hi,
329        }
330    }
331}
332//> (not used yet)
333impl BitAndAssign for udouble {
334    #[inline]
335    fn bitand_assign(&mut self, rhs: Self) {
336        self.lo &= rhs.lo;
337        self.hi &= rhs.hi;
338    }
339}
340//> (not used yet)
341impl BitOr for udouble {
342    type Output = Self;
343    #[inline]
344    fn bitor(self, rhs: Self) -> Self::Output {
345        Self {
346            lo: self.lo | rhs.lo,
347            hi: self.hi | rhs.hi,
348        }
349    }
350}
351//> (not used yet)
352impl BitOrAssign for udouble {
353    #[inline]
354    fn bitor_assign(&mut self, rhs: Self) {
355        self.lo |= rhs.lo;
356        self.hi |= rhs.hi;
357    }
358}
359//> (not used yet)
360impl BitXor for udouble {
361    type Output = Self;
362    #[inline]
363    fn bitxor(self, rhs: Self) -> Self::Output {
364        Self {
365            lo: self.lo ^ rhs.lo,
366            hi: self.hi ^ rhs.hi,
367        }
368    }
369}
370//> (not used yet)
371impl BitXorAssign for udouble {
372    #[inline]
373    fn bitxor_assign(&mut self, rhs: Self) {
374        self.lo ^= rhs.lo;
375        self.hi ^= rhs.hi;
376    }
377}
378//> (not used yet)
379impl Not for udouble {
380    type Output = Self;
381    #[inline]
382    fn not(self) -> Self::Output {
383        Self {
384            lo: !self.lo,
385            hi: !self.hi,
386        }
387    }
388}
389
390impl udouble {
391    //> (used in Self::div_rem)
392    #[inline]
393    pub const fn leading_zeros(self) -> u32 {
394        if self.hi == 0 {
395            self.lo.leading_zeros() + umax::BITS
396        } else {
397            self.hi.leading_zeros()
398        }
399    }
400
401    // double by double division (long division), it's not the most efficient algorithm.
402    // listed here in case of future use
403    #[allow(dead_code)]
404    fn div_rem_2by2(self, other: Self) -> (Self, Self) {
405        let mut n = self; // numerator
406        let mut d = other; // denominator
407        let mut q = Self { lo: 0, hi: 0 }; // quotient
408
409        let nbits = (2 * umax::BITS - n.leading_zeros()) as u16; // assuming umax = u128
410        let dbits = (2 * umax::BITS - d.leading_zeros()) as u16;
411        assert!(dbits != 0, "division by zero");
412
413        // Early return in case we are dividing by a larger number than us
414        if nbits < dbits {
415            return (q, n);
416        }
417
418        // Bitwise long division
419        let mut shift = nbits - dbits;
420        d <<= shift;
421        loop {
422            if n >= d {
423                q += 1;
424                n -= d;
425            }
426            if shift == 0 {
427                break;
428            }
429
430            d >>= 1u8;
431            q <<= 1u8;
432            shift -= 1;
433        }
434        (q, n)
435    }
436
437    // double by single to single division.
438    // equivalent to `udiv_qrnnd` in C or `divq` in assembly.
439    //> (used in Self::{div, rem}::<umax>)
440    fn div_rem_2by1(self, other: umax) -> (umax, umax) {
441        // the following algorithm comes from `ethnum` crate
442        const B: umax = 1 << HALF_BITS; // number base (64 bits)
443
444        // Normalize the divisor.
445        let s = other.leading_zeros();
446        let (n, d) = (self << s, other << s); // numerator, denominator
447        let (d1, d0) = split(d);
448        let (n1, n0) = split(n.lo); // split lower part of dividend
449
450        // Compute the first quotient digit q1.
451        let (mut q1, mut rhat) = div_rem(n.hi, d1);
452
453        // q1 has at most error 2. No more than 2 iterations.
454        while q1 >= B || q1 * d0 > B * rhat + n1 {
455            q1 -= 1;
456            rhat += d1;
457            if rhat >= B {
458                break;
459            }
460        }
461
462        let r21 =
463            n.hi.wrapping_mul(B)
464                .wrapping_add(n1)
465                .wrapping_sub(q1.wrapping_mul(d));
466
467        // Compute the second quotient digit q0.
468        let (mut q0, mut rhat) = div_rem(r21, d1);
469
470        // q0 has at most error 2. No more than 2 iterations.
471        while q0 >= B || q0 * d0 > B * rhat + n0 {
472            q0 -= 1;
473            rhat += d1;
474            if rhat >= B {
475                break;
476            }
477        }
478
479        let r = (r21
480            .wrapping_mul(B)
481            .wrapping_add(n0)
482            .wrapping_sub(q0.wrapping_mul(d)))
483            >> s;
484        let q = q1 * B + q0;
485        (q, r)
486    }
487}
488
489impl Mul<umax> for udouble {
490    type Output = Self;
491    #[inline]
492    fn mul(self, rhs: umax) -> Self::Output {
493        self.checked_mul1(rhs).expect("multiplication overflow!")
494    }
495}
496
497impl Div<umax> for udouble {
498    type Output = Self;
499    #[inline]
500    fn div(self, rhs: umax) -> Self::Output {
501        // self.div_rem(rhs.into()).0
502        if self.hi < rhs {
503            // The result fits in 128 bits.
504            Self {
505                lo: self.div_rem_2by1(rhs).0,
506                hi: 0,
507            }
508        } else {
509            let (q, r) = div_rem(self.hi, rhs);
510            Self {
511                lo: Self { lo: self.lo, hi: r }.div_rem_2by1(rhs).0,
512                hi: q,
513            }
514        }
515    }
516}
517
518//> (used in Montgomery::<u128>::transform)
519impl Rem<umax> for udouble {
520    type Output = umax;
521    #[inline]
522    fn rem(self, rhs: umax) -> Self::Output {
523        if self.hi < rhs {
524            // The result fits in 128 bits.
525            self.div_rem_2by1(rhs).1
526        } else {
527            Self {
528                lo: self.lo,
529                hi: self.hi % rhs,
530            }
531            .div_rem_2by1(rhs)
532            .1
533        }
534    }
535}
536
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;

    /// Shorthand constructor used throughout these tests.
    const fn ud(hi: umax, lo: umax) -> udouble {
        udouble { hi, lo }
    }

    #[test]
    fn test_construction() {
        // widening addition, without and with a carry into the high half
        assert_eq!(ud(0, 2), udouble::widening_add(1, 1));
        assert_eq!(
            ud(1, umax::MAX - 1),
            udouble::widening_add(umax::MAX, umax::MAX)
        );

        // widening_mul(x, x) and widening_square(x) must both produce the expected square
        let squares: [(umax, udouble); 4] = [
            (1, ud(0, 1)),
            (1 << 80, ud(1 << 32, 0)),
            (1 << 80 | 1 << 40, ud(1 << 32, 2 << 120 | 1 << 80)),
            (umax::MAX, ud(umax::MAX - 1, 1)),
        ];
        for &(x, expected) in squares.iter() {
            assert_eq!(expected, udouble::widening_mul(x, x));
            assert_eq!(expected, udouble::widening_square(x));
        }
    }

    #[test]
    fn test_ops() {
        const ONE: udouble = ud(0, 1);
        const TWO: udouble = ud(0, 2);
        const MAX: udouble = ud(0, umax::MAX);
        const ONEZERO: udouble = ud(1, 0);
        const ONEMAX: udouble = ud(1, umax::MAX);
        const TWOZERO: udouble = ud(2, 0);

        // addition and subtraction across the half boundary
        assert_eq!(ONE + MAX, ONEZERO);
        assert_eq!(ONE + ONEMAX, TWOZERO);
        assert_eq!(ONEZERO - ONE, MAX);
        assert_eq!(ONEZERO - MAX, ONE);
        assert_eq!(TWOZERO - ONE, ONEMAX);
        assert_eq!(TWOZERO - ONEMAX, ONE);

        // shifts
        assert_eq!(ONE << umax::BITS, ONEZERO);
        assert_eq!((MAX << 1u8) + 1, ONEMAX);
        assert_eq!(ONE << 200u8, ud(1 << (200 - umax::BITS), 0));
        assert_eq!(ONEZERO >> umax::BITS, ONE);
        assert_eq!(ONEMAX >> 1u8, MAX);

        // multiplication by a single-width integer
        assert_eq!(ONE * MAX.lo, MAX);
        assert_eq!(ONEMAX * ONE.lo, ONEMAX);
        assert_eq!(ONEMAX * TWO.lo, ONEMAX + ONEMAX);

        // division by a single-width integer
        assert_eq!(MAX / ONE.lo, MAX);
        assert_eq!(MAX / MAX.lo, ONE);
        assert_eq!(ONE / MAX.lo, ud(0, 0));
        assert_eq!(ONEMAX / ONE.lo, ONEMAX);
        assert_eq!(ONEMAX / MAX.lo, TWO);
        assert_eq!(ONEMAX / TWO.lo, MAX);

        // remainder by a single-width integer
        assert_eq!(ONE % MAX.lo, 1);
        assert_eq!(TWO % MAX.lo, 2);
        assert_eq!(ONEMAX % MAX.lo, 1);
        assert_eq!(ONEMAX % TWO.lo, 1);

        // products that do not fit are reported as None
        assert_eq!(ONEMAX.checked_mul1(MAX.lo), None);
        assert_eq!(TWOZERO.checked_mul1(MAX.lo), None);
    }

    #[test]
    fn test_assign_ops() {
        // the high half is kept within u32 range so that the addition cannot overflow
        let sample = || ud(random::<u32>() as umax, random());

        for _ in 0..10 {
            let x = sample();
            let y = sample();
            let mut z = x;

            z += y;
            assert_eq!(z, x + y);
            z -= y;
            assert_eq!(z, x);
        }
    }
}