1use crate::reduced::impl_reduced_binary_pow;
2use crate::{ModularUnaryOps, Reducer, Vanilla};
3
mod neg_mod_inv {
    //! Negated modular inverses `-m⁻¹ mod 2ⁿ` of odd words, where `n` is the
    //! bit width of the word type.
    //!
    //! A byte-wide inverse is read from a table and then lifted to the full
    //! word width with the Newton–Hensel step `x ← x·(2 − m·x)`. Each step
    //! doubles the number of correct low-order bits: 8 → 16 → 32 → 64 → 128.

    /// `BINV_TABLE[k]` is the inverse of the odd byte `2k + 1` modulo 2⁸.
    #[rustfmt::skip]
    const BINV_TABLE: [u8; 128] = [
        0x01, 0xAB, 0xCD, 0xB7, 0x39, 0xA3, 0xC5, 0xEF, 0xF1, 0x1B, 0x3D, 0xA7, 0x29, 0x13, 0x35, 0xDF,
        0xE1, 0x8B, 0xAD, 0x97, 0x19, 0x83, 0xA5, 0xCF, 0xD1, 0xFB, 0x1D, 0x87, 0x09, 0xF3, 0x15, 0xBF,
        0xC1, 0x6B, 0x8D, 0x77, 0xF9, 0x63, 0x85, 0xAF, 0xB1, 0xDB, 0xFD, 0x67, 0xE9, 0xD3, 0xF5, 0x9F,
        0xA1, 0x4B, 0x6D, 0x57, 0xD9, 0x43, 0x65, 0x8F, 0x91, 0xBB, 0xDD, 0x47, 0xC9, 0xB3, 0xD5, 0x7F,
        0x81, 0x2B, 0x4D, 0x37, 0xB9, 0x23, 0x45, 0x6F, 0x71, 0x9B, 0xBD, 0x27, 0xA9, 0x93, 0xB5, 0x5F,
        0x61, 0x0B, 0x2D, 0x17, 0x99, 0x03, 0x25, 0x4F, 0x51, 0x7B, 0x9D, 0x07, 0x89, 0x73, 0x95, 0x3F,
        0x41, 0xEB, 0x0D, 0xF7, 0x79, 0xE3, 0x05, 0x2F, 0x31, 0x5B, 0x7D, 0xE7, 0x69, 0x53, 0x75, 0x1F,
        0x21, 0xCB, 0xED, 0xD7, 0x59, 0xC3, 0xE5, 0x0F, 0x11, 0x3B, 0x5D, 0xC7, 0x49, 0x33, 0x55, 0xFF,
    ];

    pub mod u8 {
        use super::BINV_TABLE;

        /// Returns `-m⁻¹ mod 2⁸`. `m` must be odd (not checked).
        pub const fn neginv(m: u8) -> u8 {
            // The table is already exact at 8 bits, so no Newton step is needed.
            BINV_TABLE[(m >> 1) as usize].wrapping_neg()
        }
    }

    pub mod u16 {
        use super::BINV_TABLE;

        /// Returns `-m⁻¹ mod 2¹⁶`. `m` must be odd (not checked).
        pub const fn neginv(m: u16) -> u16 {
            let mut inv = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u16;
            // `bits` tracks how many low-order bits of `inv` are correct.
            let mut bits = 8;
            while bits < 16 {
                inv = inv.wrapping_mul(2u16.wrapping_sub(m.wrapping_mul(inv)));
                bits *= 2;
            }
            inv.wrapping_neg()
        }
    }

    pub mod u32 {
        use super::BINV_TABLE;

        /// Returns `-m⁻¹ mod 2³²`. `m` must be odd (not checked).
        pub const fn neginv(m: u32) -> u32 {
            let mut inv = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u32;
            let mut bits = 8;
            while bits < 32 {
                inv = inv.wrapping_mul(2u32.wrapping_sub(m.wrapping_mul(inv)));
                bits *= 2;
            }
            inv.wrapping_neg()
        }
    }

    pub mod u64 {
        use super::BINV_TABLE;

        /// Returns `-m⁻¹ mod 2⁶⁴`. `m` must be odd (not checked).
        pub const fn neginv(m: u64) -> u64 {
            let mut inv = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u64;
            let mut bits = 8;
            while bits < 64 {
                inv = inv.wrapping_mul(2u64.wrapping_sub(m.wrapping_mul(inv)));
                bits *= 2;
            }
            inv.wrapping_neg()
        }
    }

    pub mod u128 {
        use super::BINV_TABLE;

        /// Returns `-m⁻¹ mod 2¹²⁸`. `m` must be odd (not checked).
        pub const fn neginv(m: u128) -> u128 {
            let mut inv = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u128;
            let mut bits = 8;
            while bits < 128 {
                inv = inv.wrapping_mul(2u128.wrapping_sub(m.wrapping_mul(inv)));
                bits *= 2;
            }
            inv.wrapping_neg()
        }
    }

    pub mod usize {
        /// Returns `-m⁻¹ mod 2ⁿ` for the pointer width `n`, by delegating to the
        /// fixed-width routine of the same size. `m` must be odd (not checked).
        #[inline]
        pub const fn neginv(m: usize) -> usize {
            #[cfg(target_pointer_width = "16")]
            let narrowed = super::u16::neginv(m as u16) as usize;
            #[cfg(target_pointer_width = "32")]
            let narrowed = super::u32::neginv(m as u32) as usize;
            #[cfg(target_pointer_width = "64")]
            let narrowed = super::u64::neginv(m as u64) as usize;
            narrowed
        }
    }
}
83
/// Precomputed state for modular arithmetic with an odd modulus in Montgomery form.
///
/// A value `a` is represented as `a·R mod m`, where `R` is the word radix
/// `2^(bits of T)`. Construct it with `Montgomery::<T>::new(m)`; only odd `m` is accepted.
#[derive(Debug, Clone, Copy)]
pub struct Montgomery<T> {
    /// The odd modulus `m`.
    m: T,
    /// `-m⁻¹ mod R`, consumed by the reduction step.
    inv: T,
}
94
/// Generates `Montgomery::<$t>::new` and the `Reducer<$t>` implementation inside a
/// private module `$ns`.
///
/// The expansion glob-imports `crate::word::$t` for the double-word helpers it uses
/// (`DoubleWord`, `wmul`, `wsqr`, `high`, `low`, `merge`, `extend`, `nrem`). It takes
/// the precomputed `-m⁻¹ mod R` from `neg_mod_inv::$t::neginv`.
macro_rules! impl_montgomery_for {
    ($t:ident, $ns:ident) => {
        mod $ns {
            use super::*;
            use crate::word::$t::*;
            use neg_mod_inv::$t::neginv;

            impl Montgomery<$t> {
                /// Builds a reducer for the odd modulus `m`, precomputing `-m⁻¹ mod R`.
                ///
                /// # Panics
                ///
                /// Panics if `m` is even.
                pub const fn new(m: $t) -> Self {
                    assert!(
                        m & 1 != 0,
                        "Only odd modulus are supported by the Montgomery form"
                    );
                    let inv = neginv(m);
                    Self { m, inv }
                }

                /// Montgomery reduction (REDC).
                ///
                /// Maps a double word `T` with `T < m·R` to `T·R⁻¹ mod m`, fully
                /// reduced into `[0, m)`.
                const fn reduce(&self, monty: DoubleWord) -> $t {
                    debug_assert!(high(monty) < self.m);

                    // q ≡ T·(-m⁻¹) (mod R), so T + q·m is an exact multiple of R.
                    let q = low(monty).wrapping_mul(self.inv);
                    let (sum, carry) = monty.overflowing_add(wmul(q, self.m));
                    // (T + q·m) / R, missing the top bit if the addition carried.
                    let quotient = high(sum);

                    if carry {
                        // The true quotient is R + quotient, which lies in [m, 2m).
                        // Subtracting m yields quotient + (R - m), which cannot overflow.
                        return quotient + self.m.wrapping_neg();
                    }
                    // Otherwise quotient < 2m, so one conditional subtraction suffices.
                    if quotient < self.m {
                        quotient
                    } else {
                        quotient - self.m
                    }
                }
            }

            impl Reducer<$t> for Montgomery<$t> {
                #[inline]
                fn new(m: &$t) -> Self {
                    Self::new(*m)
                }

                /// Moves `target` into Montgomery form, reducing `merge(0, target)` modulo `m`.
                #[inline]
                fn transform(&self, target: $t) -> $t {
                    if target == 0 {
                        0
                    } else {
                        nrem(merge(0, target), self.m)
                    }
                }

                #[inline]
                fn check(&self, target: &$t) -> bool {
                    self.m > *target
                }

                /// Leaves Montgomery form by reducing the zero-extended `target`.
                #[inline]
                fn residue(&self, target: $t) -> $t {
                    let widened = extend(target);
                    self.reduce(widened)
                }

                #[inline(always)]
                fn modulus(&self) -> $t {
                    self.m
                }

                #[inline(always)]
                fn is_zero(&self, target: &$t) -> bool {
                    *target == 0
                }

                // Addition-like operations are identical in Montgomery form, so they
                // delegate to the plain modular routines.
                #[inline(always)]
                fn add(&self, lhs: &$t, rhs: &$t) -> $t {
                    Vanilla::<$t>::add(&self.m, *lhs, *rhs)
                }

                #[inline(always)]
                fn dbl(&self, target: $t) -> $t {
                    Vanilla::<$t>::dbl(&self.m, target)
                }

                #[inline(always)]
                fn sub(&self, lhs: &$t, rhs: &$t) -> $t {
                    Vanilla::<$t>::sub(&self.m, *lhs, *rhs)
                }

                #[inline(always)]
                fn neg(&self, target: $t) -> $t {
                    Vanilla::<$t>::neg(&self.m, target)
                }

                /// Multiplies two Montgomery-form values: REDC of the full double-word product.
                #[inline]
                fn mul(&self, lhs: &$t, rhs: &$t) -> $t {
                    let product = wmul(*lhs, *rhs);
                    self.reduce(product)
                }

                #[inline]
                fn sqr(&self, target: $t) -> $t {
                    let square = wsqr(target);
                    self.reduce(square)
                }

                /// Inverts a Montgomery-form value, or returns `None` if it is not invertible.
                #[inline(always)]
                fn inv(&self, target: $t) -> Option<$t> {
                    // Leave Montgomery form, invert the plain residue, then re-enter it.
                    let plain = self.residue(target);
                    let plain_inverse = plain.invm(&self.m)?;
                    Some(self.transform(plain_inverse))
                }

                impl_reduced_binary_pow!(Word);
            }
        }
    };
}
202impl_montgomery_for!(u8, u8_impl);
203impl_montgomery_for!(u16, u16_impl);
204impl_montgomery_for!(u32, u32_impl);
205impl_montgomery_for!(u64, u64_impl);
206impl_montgomery_for!(u128, u128_impl);
207impl_montgomery_for!(usize, usize_impl);
208
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;

    /// Number of random trials performed by each randomized test.
    const NRANDOM: u32 = 10;

    /// `neginv(m)` must satisfy `m · neginv(m) ≡ -1 (mod 2^BITS)` for every odd `m`.
    #[test]
    fn neginv_test() {
        for _ in 0..NRANDOM {
            let m = random::<u8>() | 1;
            assert_eq!(m.wrapping_mul(neg_mod_inv::u8::neginv(m)).wrapping_add(1), 0);

            let m = random::<u16>() | 1;
            assert_eq!(m.wrapping_mul(neg_mod_inv::u16::neginv(m)).wrapping_add(1), 0);

            let m = random::<u32>() | 1;
            assert_eq!(m.wrapping_mul(neg_mod_inv::u32::neginv(m)).wrapping_add(1), 0);

            let m = random::<u64>() | 1;
            assert_eq!(m.wrapping_mul(neg_mod_inv::u64::neginv(m)).wrapping_add(1), 0);

            let m = random::<u128>() | 1;
            assert_eq!(m.wrapping_mul(neg_mod_inv::u128::neginv(m)).wrapping_add(1), 0);

            let m = random::<usize>() | 1;
            assert_eq!(m.wrapping_mul(neg_mod_inv::usize::neginv(m)).wrapping_add(1), 0);
        }
    }

    #[test]
    fn creation_test() {
        // A 128-bit value with an odd 127-bit modulus exercises the widest word.
        let a = (0x81u128 << 120) - 1;
        let m = (0x81u128 << 119) - 1;
        let m = m >> m.trailing_zeros();
        let r = Montgomery::<u128>::new(m);
        assert_eq!(r.residue(r.transform(a)), a % m);

        // Small modulus: zero stays zero and 5 + 6 ≡ 0 (mod 11).
        let r = Montgomery::<u8>::new(11u8);
        assert!(r.is_zero(&r.transform(0)));
        let five = r.transform(5u8);
        let six = r.transform(6u8);
        assert!(r.is_zero(&r.add(&five, &six)));
        // 5 · 6 = 30 ≡ 8 (mod 11): multiplication must keep values in Montgomery form.
        assert_eq!(r.residue(r.mul(&five, &six)), 8);
        // 5 · 9 = 45 ≡ 1 (mod 11), so the inverse of 5 is 9.
        assert_eq!(r.inv(five).map(|v| r.residue(v)), Some(9));

        // Round trip through Montgomery form for random values and odd moduli.
        for _ in 0..NRANDOM {
            let a = random::<u8>();
            let m = random::<u8>() | 1;
            let r = Montgomery::<u8>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u16>();
            let m = random::<u16>() | 1;
            let r = Montgomery::<u16>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u32>();
            let m = random::<u32>() | 1;
            let r = Montgomery::<u32>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u64>();
            let m = random::<u64>() | 1;
            let r = Montgomery::<u64>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u128>();
            let m = random::<u128>() | 1;
            let r = Montgomery::<u128>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<usize>();
            let m = random::<usize>() | 1;
            let r = Montgomery::<usize>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);
        }
    }

    #[test]
    fn test_against_modops() {
        use crate::reduced::tests::ReducedTester;
        for _ in 0..NRANDOM {
            ReducedTester::<u8>::test_against_modops::<Montgomery<u8>>(true);
            ReducedTester::<u16>::test_against_modops::<Montgomery<u16>>(true);
            ReducedTester::<u32>::test_against_modops::<Montgomery<u32>>(true);
            ReducedTester::<u64>::test_against_modops::<Montgomery<u64>>(true);
            ReducedTester::<u128>::test_against_modops::<Montgomery<u128>>(true);
            ReducedTester::<usize>::test_against_modops::<Montgomery<usize>>(true);
        }
    }
}