num_modular/
reduced.rs

1use crate::{udouble, ModularInteger, ModularUnaryOps, Reducer};
2use core::ops::*;
3#[cfg(feature = "num_traits")]
4use num_traits::{Inv, Pow};
5
6/// An integer in a modulo ring
7#[derive(Debug, Clone, Copy)]
8pub struct ReducedInt<T, R: Reducer<T>> {
9    /// The reduced representation of the integer in a modulo ring.
10    a: T,
11
12    /// The reducer for the integer
13    r: R,
14}
15
16impl<T, R: Reducer<T>> ReducedInt<T, R> {
17    /// Convert n into the modulo ring ℤ/mℤ (i.e. `n % m`)
18    #[inline]
19    pub fn new(n: T, m: &T) -> Self {
20        let r = R::new(m);
21        let a = r.transform(n);
22        Self { a, r }
23    }
24
25    #[inline(always)]
26    fn check_modulus_eq(&self, rhs: &Self)
27    where
28        T: PartialEq,
29    {
30        // we don't directly compare m because m could be empty in case of Mersenne modular integer
31        if cfg!(debug_assertions) && self.r.modulus() != rhs.r.modulus() {
32            panic!("The modulus of two operators should be the same!");
33        }
34    }
35
36    #[inline(always)]
37    pub fn repr(&self) -> &T {
38        &self.a
39    }
40
41    #[inline(always)]
42    pub fn inv(self) -> Option<Self> {
43        Some(Self {
44            a: self.r.inv(self.a)?,
45            r: self.r,
46        })
47    }
48
49    #[inline(always)]
50    pub fn pow(self, exp: &T) -> Self {
51        Self {
52            a: self.r.pow(self.a, exp),
53            r: self.r,
54        }
55    }
56}
57
58impl<T: PartialEq, R: Reducer<T>> PartialEq for ReducedInt<T, R> {
59    #[inline]
60    fn eq(&self, other: &Self) -> bool {
61        self.check_modulus_eq(other);
62        self.a == other.a
63    }
64}
65
66macro_rules! impl_binops {
67    ($method:ident, impl $op:ident) => {
68        impl<T: PartialEq, R: Reducer<T>> $op for ReducedInt<T, R> {
69            type Output = Self;
70            fn $method(self, rhs: Self) -> Self::Output {
71                self.check_modulus_eq(&rhs);
72                let Self { a, r } = self;
73                let a = r.$method(&a, &rhs.a);
74                Self { a, r }
75            }
76        }
77
78        impl<T: PartialEq + Clone, R: Reducer<T>> $op<&Self> for ReducedInt<T, R> {
79            type Output = Self;
80            #[inline]
81            fn $method(self, rhs: &Self) -> Self::Output {
82                self.check_modulus_eq(&rhs);
83                let Self { a, r } = self;
84                let a = r.$method(&a, &rhs.a);
85                Self { a, r }
86            }
87        }
88
89        impl<T: PartialEq + Clone, R: Reducer<T>> $op<ReducedInt<T, R>> for &ReducedInt<T, R> {
90            type Output = ReducedInt<T, R>;
91            #[inline]
92            fn $method(self, rhs: ReducedInt<T, R>) -> Self::Output {
93                self.check_modulus_eq(&rhs);
94                let ReducedInt { a, r } = rhs;
95                let a = r.$method(&self.a, &a);
96                ReducedInt { a, r }
97            }
98        }
99
100        impl<T: PartialEq + Clone, R: Reducer<T> + Clone> $op<&ReducedInt<T, R>>
101            for &ReducedInt<T, R>
102        {
103            type Output = ReducedInt<T, R>;
104            #[inline]
105            fn $method(self, rhs: &ReducedInt<T, R>) -> Self::Output {
106                self.check_modulus_eq(&rhs);
107                let a = self.r.$method(&self.a, &rhs.a);
108                ReducedInt {
109                    a,
110                    r: self.r.clone(),
111                }
112            }
113        }
114
115        impl<T: PartialEq, R: Reducer<T>> $op<T> for ReducedInt<T, R> {
116            type Output = Self;
117            fn $method(self, rhs: T) -> Self::Output {
118                let Self { a, r } = self;
119                let rhs = r.transform(rhs);
120                let a = r.$method(&a, &rhs);
121                Self { a, r }
122            }
123        }
124    };
125}
126impl_binops!(add, impl Add);
127impl_binops!(sub, impl Sub);
128impl_binops!(mul, impl Mul);
129
130impl<T: PartialEq, R: Reducer<T>> Neg for ReducedInt<T, R> {
131    type Output = Self;
132    #[inline]
133    fn neg(self) -> Self::Output {
134        let Self { a, r } = self;
135        let a = r.neg(a);
136        Self { a, r }
137    }
138}
139impl<T: PartialEq + Clone, R: Reducer<T> + Clone> Neg for &ReducedInt<T, R> {
140    type Output = ReducedInt<T, R>;
141    #[inline]
142    fn neg(self) -> Self::Output {
143        let a = self.r.neg(self.a.clone());
144        ReducedInt {
145            a,
146            r: self.r.clone(),
147        }
148    }
149}
150
/// Panic message used when a modular inverse is needed (division, or the
/// panicking `Inv` trait impls) but does not exist.
const INV_ERR_MSG: &str = "the modular inverse doesn't exist!";

#[cfg(feature = "num_traits")]
impl<T: PartialEq, R: Reducer<T>> Inv for ReducedInt<T, R> {
    type Output = Self;
    /// Panicking counterpart of the inherent `ReducedInt::inv`.
    #[inline]
    fn inv(self) -> Self::Output {
        ReducedInt::inv(self).expect(INV_ERR_MSG)
    }
}
#[cfg(feature = "num_traits")]
impl<T: PartialEq + Clone, R: Reducer<T> + Clone> Inv for &ReducedInt<T, R> {
    type Output = ReducedInt<T, R>;
    /// Panicking inverse of a borrowed value; works on a clone of it.
    #[inline]
    fn inv(self) -> Self::Output {
        let owned: ReducedInt<T, R> = self.clone();
        ReducedInt::inv(owned).expect(INV_ERR_MSG)
    }
}
169
170impl<T: PartialEq, R: Reducer<T>> Div for ReducedInt<T, R> {
171    type Output = Self;
172    #[inline]
173    fn div(self, rhs: Self) -> Self::Output {
174        self.check_modulus_eq(&rhs);
175        let ReducedInt { a, r } = rhs;
176        let a = r.mul(&self.a, &r.inv(a).expect(INV_ERR_MSG));
177        ReducedInt { a, r }
178    }
179}
180impl<T: PartialEq + Clone, R: Reducer<T>> Div<&ReducedInt<T, R>> for ReducedInt<T, R> {
181    type Output = Self;
182    #[inline]
183    fn div(self, rhs: &Self) -> Self::Output {
184        self.check_modulus_eq(rhs);
185        let Self { a, r } = self;
186        let a = r.mul(&a, &r.inv(rhs.a.clone()).expect(INV_ERR_MSG));
187        ReducedInt { a, r }
188    }
189}
190impl<T: PartialEq + Clone, R: Reducer<T>> Div<ReducedInt<T, R>> for &ReducedInt<T, R> {
191    type Output = ReducedInt<T, R>;
192    #[inline]
193    fn div(self, rhs: ReducedInt<T, R>) -> Self::Output {
194        self.check_modulus_eq(&rhs);
195        let ReducedInt { a, r } = rhs;
196        let a = r.mul(&self.a, &r.inv(a).expect(INV_ERR_MSG));
197        ReducedInt { a, r }
198    }
199}
200impl<T: PartialEq + Clone, R: Reducer<T> + Clone> Div<&ReducedInt<T, R>> for &ReducedInt<T, R> {
201    type Output = ReducedInt<T, R>;
202    #[inline]
203    fn div(self, rhs: &ReducedInt<T, R>) -> Self::Output {
204        self.check_modulus_eq(rhs);
205        let a = self
206            .r
207            .mul(&self.a, &self.r.inv(rhs.a.clone()).expect(INV_ERR_MSG));
208        ReducedInt {
209            a,
210            r: self.r.clone(),
211        }
212    }
213}
214
#[cfg(feature = "num_traits")]
impl<T: PartialEq, R: Reducer<T>> Pow<T> for ReducedInt<T, R> {
    type Output = Self;
    /// Raise to the integer power `rhs`; forwards to the inherent
    /// `ReducedInt::pow`.
    #[inline]
    fn pow(self, rhs: T) -> Self::Output {
        // The inherent `pow` takes its exponent by reference.
        ReducedInt::pow(self, &rhs)
    }
}
#[cfg(feature = "num_traits")]
impl<T: PartialEq + Clone, R: Reducer<T> + Clone> Pow<T> for &ReducedInt<T, R> {
    type Output = ReducedInt<T, R>;
    /// Raise a borrowed value to the integer power `rhs`; the representation
    /// and the reducer are cloned for the result.
    #[inline]
    fn pow(self, rhs: T) -> Self::Output {
        // `Reducer::pow` takes the exponent by reference as well.
        let a = self.r.pow(self.a.clone(), &rhs);
        ReducedInt {
            a,
            r: self.r.clone(),
        }
    }
}
235
236impl<T: PartialEq + Clone, R: Reducer<T> + Clone> ModularInteger for ReducedInt<T, R> {
237    type Base = T;
238
239    #[inline]
240    fn modulus(&self) -> T {
241        self.r.modulus()
242    }
243
244    #[inline(always)]
245    fn residue(&self) -> T {
246        debug_assert!(self.r.check(&self.a));
247        self.r.residue(self.a.clone())
248    }
249
250    #[inline(always)]
251    fn is_zero(&self) -> bool {
252        self.r.is_zero(&self.a)
253    }
254
255    #[inline]
256    fn convert(&self, n: T) -> Self {
257        Self {
258            a: self.r.transform(n),
259            r: self.r.clone(),
260        }
261    }
262
263    #[inline]
264    fn double(self) -> Self {
265        let Self { a, r } = self;
266        let a = r.dbl(a);
267        Self { a, r }
268    }
269
270    #[inline]
271    fn square(self) -> Self {
272        let Self { a, r } = self;
273        let a = r.sqr(a);
274        Self { a, r }
275    }
276}
277
// A vanilla reducer is also provided here.
/// A plain reducer built on the ordinary [Rem] operator (`%`). After every
/// operation, the integer is kept in the range `[0, modulus)`.
#[derive(Clone, Copy, Debug)]
pub struct Vanilla<T>(T);
283
284macro_rules! impl_uprim_vanilla_core_const {
285    ($($T:ty)*) => {$(
286        // These methods are for internal use only, wait for the introduction of const Trait in Rust
287        impl Vanilla<$T> {
288            #[inline]
289            pub(crate) const fn add(m: &$T, lhs: $T, rhs: $T) -> $T {
290                let (sum, overflow) = lhs.overflowing_add(rhs);
291                if overflow || sum >= *m {
292                    let (sum2, overflow2) = sum.overflowing_sub(*m);
293                    debug_assert!(overflow == overflow2);
294                    sum2
295                } else {
296                    sum
297                }
298            }
299
300            #[inline]
301            pub(crate) const fn dbl(m: &$T, target: $T) -> $T {
302                Self::add(m, target, target)
303            }
304
305            #[inline]
306            pub(crate) const fn sub(m: &$T, lhs: $T, rhs: $T) -> $T {
307                // this implementation should be equivalent to using overflowing_add and _sub after optimization.
308                if lhs >= rhs {
309                    lhs - rhs
310                } else {
311                    *m - (rhs - lhs)
312                }
313            }
314
315            #[inline]
316            pub(crate) const fn neg(m: &$T, target: $T) -> $T {
317                match target {
318                    0 => 0,
319                    x => *m - x
320                }
321            }
322        }
323    )*};
324}
325impl_uprim_vanilla_core_const!(u8 u16 u32 u64 u128 usize);
326
// Generates `fn pow(&self, base, exp)` for a reducer over the primitive type
// `$T`. The enclosing impl must provide `transform`, `mul` and `sqr`.
macro_rules! impl_reduced_binary_pow {
    ($T:ty) => {
        // Square-and-multiply scanning the exponent from its least significant
        // bit, with shortcuts for exponents 1 and 2. `base` is in the reducer's
        // representation; `exp` is an ordinary integer exponent.
        fn pow(&self, base: $T, exp: &$T) -> $T {
            match *exp {
                1 => base,
                2 => self.sqr(base),
                e => {
                    let mut multi = base;
                    let mut exp = e;
                    // The reducer's representation of one; also the result
                    // for a zero exponent.
                    let mut result = self.transform(1);
                    while exp > 0 {
                        if exp & 1 != 0 {
                            result = self.mul(&result, &multi);
                        }
                        exp >>= 1;
                        // Only square while higher bits remain: the square
                        // after the top bit would be discarded.
                        if exp > 0 {
                            multi = self.sqr(multi);
                        }
                    }
                    result
                }
            }
        }
    };
}

pub(crate) use impl_reduced_binary_pow;
352
// Shared body of `Reducer<$single> for Vanilla<$single>` for the primitive
// unsigned types. `mul` and `sqr` are written by each caller, since they need a
// wider intermediate product whose type depends on the width.
macro_rules! impl_uprim_vanilla_core {
    ($single:ty) => {
        // --- construction and representation ---------------------------------
        // Values are stored directly as their residue in [0, m).

        #[inline(always)]
        fn new(m: &$single) -> Self {
            assert!(m > &0);
            Self(*m)
        }
        #[inline(always)]
        fn transform(&self, target: $single) -> $single {
            target % self.0
        }
        #[inline(always)]
        fn check(&self, target: &$single) -> bool {
            self.0 > *target
        }
        #[inline(always)]
        fn residue(&self, target: $single) -> $single {
            target
        }
        #[inline(always)]
        fn modulus(&self) -> $single {
            self.0
        }
        #[inline(always)]
        fn is_zero(&self, target: &$single) -> bool {
            *target == 0
        }

        // --- ring arithmetic, forwarded to the const helpers on Vanilla ------

        #[inline(always)]
        fn add(&self, lhs: &$single, rhs: &$single) -> $single {
            let (l, r) = (*lhs, *rhs);
            Vanilla::<$single>::add(&self.0, l, r)
        }

        #[inline(always)]
        fn dbl(&self, target: $single) -> $single {
            Vanilla::<$single>::dbl(&self.0, target)
        }

        #[inline(always)]
        fn sub(&self, lhs: &$single, rhs: &$single) -> $single {
            let (l, r) = (*lhs, *rhs);
            Vanilla::<$single>::sub(&self.0, l, r)
        }

        #[inline(always)]
        fn neg(&self, target: $single) -> $single {
            Vanilla::<$single>::neg(&self.0, target)
        }

        #[inline(always)]
        fn inv(&self, target: $single) -> Option<$single> {
            target.invm(&self.0)
        }

        impl_reduced_binary_pow!($single);
    };
}
409
410macro_rules! impl_uprim_vanilla {
411    ($t:ident, $ns:ident) => {
412        mod $ns {
413            use super::*;
414            use crate::word::$t::*;
415
416            impl Reducer<$t> for Vanilla<$t> {
417                impl_uprim_vanilla_core!($t);
418
419                #[inline]
420                fn mul(&self, lhs: &$t, rhs: &$t) -> $t {
421                    (wmul(*lhs, *rhs) % extend(self.0)) as $t
422                }
423
424                #[inline]
425                fn sqr(&self, target: $t) -> $t {
426                    (wsqr(target) % extend(self.0)) as $t
427                }
428            }
429        }
430    };
431}
432
433impl_uprim_vanilla!(u8, u8_impl);
434impl_uprim_vanilla!(u16, u16_impl);
435impl_uprim_vanilla!(u32, u32_impl);
436impl_uprim_vanilla!(u64, u64_impl);
437impl_uprim_vanilla!(usize, usize_impl);
438
439impl Reducer<u128> for Vanilla<u128> {
440    impl_uprim_vanilla_core!(u128);
441
442    #[inline]
443    fn mul(&self, lhs: &u128, rhs: &u128) -> u128 {
444        udouble::widening_mul(*lhs, *rhs) % self.0
445    }
446
447    #[inline]
448    fn sqr(&self, target: u128) -> u128 {
449        udouble::widening_square(target) % self.0
450    }
451}
452
453/// An integer in modulo ring based on conventional [Rem] operations
454pub type VanillaInt<T> = ReducedInt<T, Vanilla<T>>;
455
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow, ModularUnaryOps};
    use core::marker::PhantomData;
    use rand::random;

    /// Drives randomized checks of a reducer for the integer type `T`.
    pub(crate) struct ReducedTester<T>(PhantomData<T>);

    macro_rules! impl_reduced_test_for {
        ($($T:ty)*) => {$(
            impl ReducedTester<$T> {
                /// Compare `ReducedInt<$T, R>` arithmetic against the primitive
                /// `*m` modular operations, using a random modulus (forced odd
                /// when `odd_only` is set) and random operands.
                pub fn test_against_modops<R: Reducer<$T> + Copy>(odd_only: bool) {
                    let mut m = random::<$T>().saturating_add(1);
                    if odd_only {
                        m |= 1;
                    }

                    let (a, b) = (random::<$T>(), random::<$T>());
                    let am = ReducedInt::<$T, R>::new(a, &m);
                    let bm = ReducedInt::<$T, R>::new(b, &m);
                    assert_eq!((am + bm).residue(), a.addm(b, &m), "incorrect add");
                    assert_eq!((am - bm).residue(), a.subm(b, &m), "incorrect sub");
                    assert_eq!((am * bm).residue(), a.mulm(b, &m), "incorrect mul");
                    assert_eq!(am.neg().residue(), a.negm(&m), "incorrect neg");
                    assert_eq!(am.double().residue(), a.dblm(&m), "incorrect dbl");
                    assert_eq!(am.square().residue(), a.sqm(&m), "incorrect sqr");

                    // The borrowed-operand impls keep a different reducer than the
                    // owned ones, so exercise them separately.
                    assert_eq!((&am + &bm).residue(), a.addm(b, &m), "incorrect add (ref, ref)");
                    assert_eq!((am - &bm).residue(), a.subm(b, &m), "incorrect sub (val, ref)");
                    assert_eq!((&am - bm).residue(), a.subm(b, &m), "incorrect sub (ref, val)");
                    assert_eq!((&am * &bm).residue(), a.mulm(b, &m), "incorrect mul (ref, ref)");
                    assert_eq!((-&am).residue(), a.negm(&m), "incorrect neg (ref)");
                    // A plain integer on the right is transformed into the ring first.
                    assert_eq!((am + b).residue(), a.addm(b, &m), "incorrect add (int)");
                    assert_eq!((am * b).residue(), a.mulm(b, &m), "incorrect mul (int)");

                    let e = random::<u8>() as $T;
                    assert_eq!(am.pow(&e).residue(), a.powm(e, &m), "incorrect pow");
                    if let Some(v) = a.invm(&m) {
                        assert_eq!(am.inv().unwrap().residue(), v, "incorrect inv");
                    }
                    // Division is multiplication by the divisor's inverse.
                    if let Some(bi) = b.invm(&m) {
                        let expected = a.mulm(bi, &m);
                        assert_eq!((am / bm).residue(), expected, "incorrect div");
                        assert_eq!((&am / &bm).residue(), expected, "incorrect div (ref, ref)");
                    }
                }
            }
        )*};
    }
    impl_reduced_test_for!(u8 u16 u32 u64 u128 usize);

    #[test]
    fn test_against_modops() {
        for _ in 0..10 {
            ReducedTester::<u8>::test_against_modops::<Vanilla<u8>>(false);
            ReducedTester::<u16>::test_against_modops::<Vanilla<u16>>(false);
            ReducedTester::<u32>::test_against_modops::<Vanilla<u32>>(false);
            ReducedTester::<u64>::test_against_modops::<Vanilla<u64>>(false);
            ReducedTester::<u128>::test_against_modops::<Vanilla<u128>>(false);
            ReducedTester::<usize>::test_against_modops::<Vanilla<usize>>(false);
        }
    }
}