// num_modular/monty.rs

1use crate::reduced::impl_reduced_binary_pow;
2use crate::{ModularUnaryOps, Reducer, Vanilla};
3
/// Negated modular inverse on binary bases.
///
/// Each `neginv(m)` below calculates `-(m^-1) mod R`, where `R = 2^k` and `k` is the bit width
/// of the integer type (all arithmetic is wrapping on that type). `m` must be odd, since an even
/// number has no inverse modulo a power of two. For an even `m` the result is not meaningful:
/// only the `u8` variant happens to return the value for `m + 1` (it is a pure table lookup).
mod neg_mod_inv {
    // Entry i contains (2i+1)^(-1) mod 256.
    #[rustfmt::skip]
    const BINV_TABLE: [u8; 128] = [
        0x01, 0xAB, 0xCD, 0xB7, 0x39, 0xA3, 0xC5, 0xEF, 0xF1, 0x1B, 0x3D, 0xA7, 0x29, 0x13, 0x35, 0xDF,
        0xE1, 0x8B, 0xAD, 0x97, 0x19, 0x83, 0xA5, 0xCF, 0xD1, 0xFB, 0x1D, 0x87, 0x09, 0xF3, 0x15, 0xBF,
        0xC1, 0x6B, 0x8D, 0x77, 0xF9, 0x63, 0x85, 0xAF, 0xB1, 0xDB, 0xFD, 0x67, 0xE9, 0xD3, 0xF5, 0x9F,
        0xA1, 0x4B, 0x6D, 0x57, 0xD9, 0x43, 0x65, 0x8F, 0x91, 0xBB, 0xDD, 0x47, 0xC9, 0xB3, 0xD5, 0x7F,
        0x81, 0x2B, 0x4D, 0x37, 0xB9, 0x23, 0x45, 0x6F, 0x71, 0x9B, 0xBD, 0x27, 0xA9, 0x93, 0xB5, 0x5F,
        0x61, 0x0B, 0x2D, 0x17, 0x99, 0x03, 0x25, 0x4F, 0x51, 0x7B, 0x9D, 0x07, 0x89, 0x73, 0x95, 0x3F,
        0x41, 0xEB, 0x0D, 0xF7, 0x79, 0xE3, 0x05, 0x2F, 0x31, 0x5B, 0x7D, 0xE7, 0x69, 0x53, 0x75, 0x1F,
        0x21, 0xCB, 0xED, 0xD7, 0x59, 0xC3, 0xE5, 0x0F, 0x11, 0x3B, 0x5D, 0xC7, 0x49, 0x33, 0x55, 0xFF,
    ];

    // The table gives an inverse that is correct in the low 8 bits. Every Hensel step
    // `i = i * (2 - i * m)` doubles the number of correct low bits (8 -> 16 -> 32 -> 64 -> 128),
    // so each width performs exactly log2(bits / 8) steps before negating.

    pub mod u8 {
        use super::*;
        /// Returns `-(m^-1) mod 2^8` for odd `m`.
        pub const fn neginv(m: u8) -> u8 {
            let i = BINV_TABLE[((m >> 1) & 0x7F) as usize];
            i.wrapping_neg()
        }
    }

    pub mod u16 {
        use super::*;
        /// Returns `-(m^-1) mod 2^16` for odd `m`.
        pub const fn neginv(m: u16) -> u16 {
            let mut i = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u16;
            // hensel lifting: 8 -> 16 correct bits
            i = 2u16.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i.wrapping_neg()
        }
    }

    pub mod u32 {
        use super::*;
        /// Returns `-(m^-1) mod 2^32` for odd `m`.
        pub const fn neginv(m: u32) -> u32 {
            let mut i = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u32;
            // hensel lifting: 8 -> 16 -> 32 correct bits
            i = 2u32.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i = 2u32.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i.wrapping_neg()
        }
    }

    pub mod u64 {
        use super::*;
        /// Returns `-(m^-1) mod 2^64` for odd `m`.
        pub const fn neginv(m: u64) -> u64 {
            let mut i = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u64;
            // hensel lifting: 8 -> 16 -> 32 -> 64 correct bits
            i = 2u64.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i = 2u64.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i = 2u64.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i.wrapping_neg()
        }
    }

    pub mod u128 {
        use super::*;
        /// Returns `-(m^-1) mod 2^128` for odd `m`.
        pub const fn neginv(m: u128) -> u128 {
            let mut i = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u128;
            // hensel lifting: 8 -> 16 -> 32 -> 64 -> 128 correct bits
            i = 2u128.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i = 2u128.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i = 2u128.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i = 2u128.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i.wrapping_neg()
        }
    }

    pub mod usize {
        /// Returns `-(m^-1) mod 2^B` for odd `m`, where `B` is the pointer width.
        ///
        /// Dispatches to the fixed-width implementation matching `target_pointer_width`;
        /// only 16-, 32- and 64-bit targets are handled.
        #[inline]
        pub const fn neginv(m: usize) -> usize {
            #[cfg(target_pointer_width = "16")]
            return super::u16::neginv(m as _) as _;
            #[cfg(target_pointer_width = "32")]
            return super::u32::neginv(m as _) as _;
            #[cfg(target_pointer_width = "64")]
            return super::u64::neginv(m as _) as _;
        }
    }
}
83
/// A modular reducer based on [Montgomery form](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication#Montgomery_form), only supports odd modulus.
///
/// The generic type `T` is the integer type used for both the modulus and the precomputed
/// constant `-m^-1 mod R`. The auxiliary modulus is `R = 2^B`, where `B` is selected
/// automatically as the bit width of `T`.
#[derive(Debug, Clone, Copy)]
pub struct Montgomery<T> {
    /// The (odd) modulus `m`.
    m: T,
    /// The *negated* modular inverse of the modulus, `-m^-1 mod R`, as used by REDC.
    inv: T,
}
94
95macro_rules! impl_montgomery_for {
96    ($t:ident, $ns:ident) => {
97        mod $ns {
98            use super::*;
99            use crate::word::$t::*;
100            use neg_mod_inv::$t::neginv;
101
102            impl Montgomery<$t> {
103                pub const fn new(m: $t) -> Self {
104                    assert!(
105                        m & 1 != 0,
106                        "Only odd modulus are supported by the Montgomery form"
107                    );
108                    Self { m, inv: neginv(m) }
109                }
110                const fn reduce(&self, monty: DoubleWord) -> $t {
111                    debug_assert!(high(monty) < self.m);
112
113                    // REDC algorithm
114                    let tm = low(monty).wrapping_mul(self.inv);
115                    let (t, overflow) = monty.overflowing_add(wmul(tm, self.m));
116                    let t = high(t);
117
118                    if overflow {
119                        t + self.m.wrapping_neg()
120                    } else if t >= self.m {
121                        t - self.m
122                    } else {
123                        t
124                    }
125                }
126            }
127
128            impl Reducer<$t> for Montgomery<$t> {
129                #[inline]
130                fn new(m: &$t) -> Self {
131                    Self::new(*m)
132                }
133                #[inline]
134                fn transform(&self, target: $t) -> $t {
135                    if target == 0 {
136                        return 0;
137                    }
138                    nrem(merge(0, target), self.m)
139                }
140                #[inline]
141                fn check(&self, target: &$t) -> bool {
142                    *target < self.m
143                }
144
145                #[inline]
146                fn residue(&self, target: $t) -> $t {
147                    self.reduce(extend(target))
148                }
149                #[inline(always)]
150                fn modulus(&self) -> $t {
151                    self.m
152                }
153                #[inline(always)]
154                fn is_zero(&self, target: &$t) -> bool {
155                    *target == 0
156                }
157
158                #[inline(always)]
159                fn add(&self, lhs: &$t, rhs: &$t) -> $t {
160                    Vanilla::<$t>::add(&self.m, *lhs, *rhs)
161                }
162
163                #[inline(always)]
164                fn dbl(&self, target: $t) -> $t {
165                    Vanilla::<$t>::dbl(&self.m, target)
166                }
167
168                #[inline(always)]
169                fn sub(&self, lhs: &$t, rhs: &$t) -> $t {
170                    Vanilla::<$t>::sub(&self.m, *lhs, *rhs)
171                }
172
173                #[inline(always)]
174                fn neg(&self, target: $t) -> $t {
175                    Vanilla::<$t>::neg(&self.m, target)
176                }
177
178                #[inline]
179                fn mul(&self, lhs: &$t, rhs: &$t) -> $t {
180                    self.reduce(wmul(*lhs, *rhs))
181                }
182
183                #[inline]
184                fn sqr(&self, target: $t) -> $t {
185                    self.reduce(wsqr(target))
186                }
187
188                #[inline(always)]
189                fn inv(&self, target: $t) -> Option<$t> {
190                    // TODO: support direct montgomery inverse
191                    // REF: http://cetinkayakoc.net/docs/j82.pdf
192                    self.residue(target)
193                        .invm(&self.m)
194                        .map(|v| self.transform(v))
195                }
196
197                impl_reduced_binary_pow!(Word);
198            }
199        }
200    };
201}
202impl_montgomery_for!(u8, u8_impl);
203impl_montgomery_for!(u16, u16_impl);
204impl_montgomery_for!(u32, u32_impl);
205impl_montgomery_for!(u64, u64_impl);
206impl_montgomery_for!(u128, u128_impl);
207impl_montgomery_for!(usize, usize_impl);
208
// TODO(v0.6.x): accept even numbers by removing the factors of 2 from m and storing their exponent.
// Requirements: 1. A separate class to perform modular arithmetic with 2^n as the modulus
//               2. An algorithm for constructing a residue from the two components (see http://koclab.cs.ucsb.edu/teaching/cs154/docx/Notes7-Montgomery.pdf)
// Alternatively, we could just provide a CRT function and ship a monty int implementation with full modulus support as example code.
213
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;

    const NRANDOM: u32 = 10;

    #[test]
    fn creation_test() {
        // a deterministic test case for u128
        let a = (0x81u128 << 120) - 1;
        let m = (0x81u128 << 119) - 1;
        let m = m >> m.trailing_zeros();
        let r = Montgomery::<u128>::new(m);
        assert_eq!(r.residue(r.transform(a)), a % m);

        // is_zero test: 5 + 6 == 0 (mod 11)
        let r = Montgomery::<u8>::new(11u8);
        assert!(r.is_zero(&r.transform(0)));
        let five = r.transform(5u8);
        let six = r.transform(6u8);
        assert!(r.is_zero(&r.add(&five, &six)));

        // random creation test: transform followed by residue must give `a % m`
        for _ in 0..NRANDOM {
            let a = random::<u8>();
            let m = random::<u8>() | 1;
            let r = Montgomery::<u8>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u16>();
            let m = random::<u16>() | 1;
            let r = Montgomery::<u16>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u32>();
            let m = random::<u32>() | 1;
            let r = Montgomery::<u32>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u64>();
            let m = random::<u64>() | 1;
            let r = Montgomery::<u64>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u128>();
            let m = random::<u128>() | 1;
            let r = Montgomery::<u128>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            // usize has its own impl too, so cover it like the fixed-width types
            let a = random::<usize>();
            let m = random::<usize>() | 1;
            let r = Montgomery::<usize>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);
        }
    }

    #[test]
    fn neginv_test() {
        // `neginv(m)` is `-(m^-1) mod 2^k`, so `neginv(m) * m` wraps to -1 (all bits set).
        for m in (1..=u8::MAX).step_by(2) {
            assert_eq!(neg_mod_inv::u8::neginv(m).wrapping_mul(m), u8::MAX);
        }
        for m in (1..=u16::MAX).step_by(2) {
            assert_eq!(neg_mod_inv::u16::neginv(m).wrapping_mul(m), u16::MAX);
        }
        for _ in 0..NRANDOM {
            let m = random::<u32>() | 1;
            assert_eq!(neg_mod_inv::u32::neginv(m).wrapping_mul(m), u32::MAX);

            let m = random::<u64>() | 1;
            assert_eq!(neg_mod_inv::u64::neginv(m).wrapping_mul(m), u64::MAX);

            let m = random::<u128>() | 1;
            assert_eq!(neg_mod_inv::u128::neginv(m).wrapping_mul(m), u128::MAX);

            let m = random::<usize>() | 1;
            assert_eq!(neg_mod_inv::usize::neginv(m).wrapping_mul(m), usize::MAX);
        }
    }

    #[test]
    fn test_against_modops() {
        use crate::reduced::tests::ReducedTester;
        for _ in 0..NRANDOM {
            ReducedTester::<u8>::test_against_modops::<Montgomery<u8>>(true);
            ReducedTester::<u16>::test_against_modops::<Montgomery<u16>>(true);
            ReducedTester::<u32>::test_against_modops::<Montgomery<u32>>(true);
            ReducedTester::<u64>::test_against_modops::<Montgomery<u64>>(true);
            ReducedTester::<u128>::test_against_modops::<Montgomery<u128>>(true);
            ReducedTester::<usize>::test_against_modops::<Montgomery<usize>>(true);
        }
    }
}