1use crate::{udouble, ModularInteger, ModularUnaryOps, Reducer};
2use core::ops::*;
3#[cfg(feature = "num_traits")]
4use num_traits::{Inv, Pow};
5
6#[derive(Debug, Clone, Copy)]
8pub struct ReducedInt<T, R: Reducer<T>> {
9 a: T,
11
12 r: R,
14}
15
16impl<T, R: Reducer<T>> ReducedInt<T, R> {
17 #[inline]
19 pub fn new(n: T, m: &T) -> Self {
20 let r = R::new(m);
21 let a = r.transform(n);
22 Self { a, r }
23 }
24
25 #[inline(always)]
26 fn check_modulus_eq(&self, rhs: &Self)
27 where
28 T: PartialEq,
29 {
30 if cfg!(debug_assertions) && self.r.modulus() != rhs.r.modulus() {
32 panic!("The modulus of two operators should be the same!");
33 }
34 }
35
36 #[inline(always)]
37 pub fn repr(&self) -> &T {
38 &self.a
39 }
40
41 #[inline(always)]
42 pub fn inv(self) -> Option<Self> {
43 Some(Self {
44 a: self.r.inv(self.a)?,
45 r: self.r,
46 })
47 }
48
49 #[inline(always)]
50 pub fn pow(self, exp: &T) -> Self {
51 Self {
52 a: self.r.pow(self.a, exp),
53 r: self.r,
54 }
55 }
56}
57
58impl<T: PartialEq, R: Reducer<T>> PartialEq for ReducedInt<T, R> {
59 #[inline]
60 fn eq(&self, other: &Self) -> bool {
61 self.check_modulus_eq(other);
62 self.a == other.a
63 }
64}
65
66macro_rules! impl_binops {
67 ($method:ident, impl $op:ident) => {
68 impl<T: PartialEq, R: Reducer<T>> $op for ReducedInt<T, R> {
69 type Output = Self;
70 fn $method(self, rhs: Self) -> Self::Output {
71 self.check_modulus_eq(&rhs);
72 let Self { a, r } = self;
73 let a = r.$method(&a, &rhs.a);
74 Self { a, r }
75 }
76 }
77
78 impl<T: PartialEq + Clone, R: Reducer<T>> $op<&Self> for ReducedInt<T, R> {
79 type Output = Self;
80 #[inline]
81 fn $method(self, rhs: &Self) -> Self::Output {
82 self.check_modulus_eq(&rhs);
83 let Self { a, r } = self;
84 let a = r.$method(&a, &rhs.a);
85 Self { a, r }
86 }
87 }
88
89 impl<T: PartialEq + Clone, R: Reducer<T>> $op<ReducedInt<T, R>> for &ReducedInt<T, R> {
90 type Output = ReducedInt<T, R>;
91 #[inline]
92 fn $method(self, rhs: ReducedInt<T, R>) -> Self::Output {
93 self.check_modulus_eq(&rhs);
94 let ReducedInt { a, r } = rhs;
95 let a = r.$method(&self.a, &a);
96 ReducedInt { a, r }
97 }
98 }
99
100 impl<T: PartialEq + Clone, R: Reducer<T> + Clone> $op<&ReducedInt<T, R>>
101 for &ReducedInt<T, R>
102 {
103 type Output = ReducedInt<T, R>;
104 #[inline]
105 fn $method(self, rhs: &ReducedInt<T, R>) -> Self::Output {
106 self.check_modulus_eq(&rhs);
107 let a = self.r.$method(&self.a, &rhs.a);
108 ReducedInt {
109 a,
110 r: self.r.clone(),
111 }
112 }
113 }
114
115 impl<T: PartialEq, R: Reducer<T>> $op<T> for ReducedInt<T, R> {
116 type Output = Self;
117 fn $method(self, rhs: T) -> Self::Output {
118 let Self { a, r } = self;
119 let rhs = r.transform(rhs);
120 let a = r.$method(&a, &rhs);
121 Self { a, r }
122 }
123 }
124 };
125}
126impl_binops!(add, impl Add);
127impl_binops!(sub, impl Sub);
128impl_binops!(mul, impl Mul);
129
130impl<T: PartialEq, R: Reducer<T>> Neg for ReducedInt<T, R> {
131 type Output = Self;
132 #[inline]
133 fn neg(self) -> Self::Output {
134 let Self { a, r } = self;
135 let a = r.neg(a);
136 Self { a, r }
137 }
138}
139impl<T: PartialEq + Clone, R: Reducer<T> + Clone> Neg for &ReducedInt<T, R> {
140 type Output = ReducedInt<T, R>;
141 #[inline]
142 fn neg(self) -> Self::Output {
143 let a = self.r.neg(self.a.clone());
144 ReducedInt {
145 a,
146 r: self.r.clone(),
147 }
148 }
149}
150
/// Panic message used by the `Inv` and `Div` implementations when the required
/// modular inverse does not exist.
const INV_ERR_MSG: &str = "the modular inverse doesn't exist!";
152
/// `num_traits::Inv` for an owned value.
///
/// # Panics
///
/// Panics with `INV_ERR_MSG` when the inherent `inv` returns `None`.
#[cfg(feature = "num_traits")]
impl<T: PartialEq, R: Reducer<T>> Inv for ReducedInt<T, R> {
    type Output = Self;
    #[inline]
    fn inv(self) -> Self::Output {
        // The path form names the inherent, `Option`-returning `inv` explicitly.
        let inverse = ReducedInt::inv(self);
        inverse.expect(INV_ERR_MSG)
    }
}
/// `num_traits::Inv` for a borrowed value; the value is cloned before inversion.
///
/// # Panics
///
/// Panics with `INV_ERR_MSG` when the inherent `inv` returns `None`.
#[cfg(feature = "num_traits")]
impl<T: PartialEq + Clone, R: Reducer<T> + Clone> Inv for &ReducedInt<T, R> {
    type Output = ReducedInt<T, R>;
    #[inline]
    fn inv(self) -> Self::Output {
        let owned = self.clone();
        // The path form names the inherent, `Option`-returning `inv` explicitly.
        ReducedInt::inv(owned).expect(INV_ERR_MSG)
    }
}
169
170impl<T: PartialEq, R: Reducer<T>> Div for ReducedInt<T, R> {
171 type Output = Self;
172 #[inline]
173 fn div(self, rhs: Self) -> Self::Output {
174 self.check_modulus_eq(&rhs);
175 let ReducedInt { a, r } = rhs;
176 let a = r.mul(&self.a, &r.inv(a).expect(INV_ERR_MSG));
177 ReducedInt { a, r }
178 }
179}
180impl<T: PartialEq + Clone, R: Reducer<T>> Div<&ReducedInt<T, R>> for ReducedInt<T, R> {
181 type Output = Self;
182 #[inline]
183 fn div(self, rhs: &Self) -> Self::Output {
184 self.check_modulus_eq(rhs);
185 let Self { a, r } = self;
186 let a = r.mul(&a, &r.inv(rhs.a.clone()).expect(INV_ERR_MSG));
187 ReducedInt { a, r }
188 }
189}
190impl<T: PartialEq + Clone, R: Reducer<T>> Div<ReducedInt<T, R>> for &ReducedInt<T, R> {
191 type Output = ReducedInt<T, R>;
192 #[inline]
193 fn div(self, rhs: ReducedInt<T, R>) -> Self::Output {
194 self.check_modulus_eq(&rhs);
195 let ReducedInt { a, r } = rhs;
196 let a = r.mul(&self.a, &r.inv(a).expect(INV_ERR_MSG));
197 ReducedInt { a, r }
198 }
199}
200impl<T: PartialEq + Clone, R: Reducer<T> + Clone> Div<&ReducedInt<T, R>> for &ReducedInt<T, R> {
201 type Output = ReducedInt<T, R>;
202 #[inline]
203 fn div(self, rhs: &ReducedInt<T, R>) -> Self::Output {
204 self.check_modulus_eq(rhs);
205 let a = self
206 .r
207 .mul(&self.a, &self.r.inv(rhs.a.clone()).expect(INV_ERR_MSG));
208 ReducedInt {
209 a,
210 r: self.r.clone(),
211 }
212 }
213}
214
/// `num_traits::Pow` for an owned value, taking the exponent by value.
///
/// Delegates to the inherent `ReducedInt::pow`, which takes the exponent by reference.
#[cfg(feature = "num_traits")]
impl<T: PartialEq, R: Reducer<T>> Pow<T> for ReducedInt<T, R> {
    type Output = Self;
    #[inline]
    fn pow(self, rhs: T) -> Self::Output {
        // The inherent `pow` is declared as `pow(self, exp: &T)`; handing it `rhs`
        // by value is a type mismatch, so borrow the exponent here.
        ReducedInt::pow(self, &rhs)
    }
}
/// `num_traits::Pow` for a borrowed value, taking the exponent by value.
///
/// The stored representation and the reducer are cloned; `self` is left untouched.
#[cfg(feature = "num_traits")]
impl<T: PartialEq + Clone, R: Reducer<T> + Clone> Pow<T> for &ReducedInt<T, R> {
    type Output = ReducedInt<T, R>;
    #[inline]
    fn pow(self, rhs: T) -> Self::Output {
        // `Reducer::pow` takes the exponent as `&T`; handing it `rhs` by value is a
        // type mismatch, so borrow the exponent here.
        let a = self.r.pow(self.a.clone(), &rhs);
        ReducedInt {
            a,
            r: self.r.clone(),
        }
    }
}
235
236impl<T: PartialEq + Clone, R: Reducer<T> + Clone> ModularInteger for ReducedInt<T, R> {
237 type Base = T;
238
239 #[inline]
240 fn modulus(&self) -> T {
241 self.r.modulus()
242 }
243
244 #[inline(always)]
245 fn residue(&self) -> T {
246 debug_assert!(self.r.check(&self.a));
247 self.r.residue(self.a.clone())
248 }
249
250 #[inline(always)]
251 fn is_zero(&self) -> bool {
252 self.r.is_zero(&self.a)
253 }
254
255 #[inline]
256 fn convert(&self, n: T) -> Self {
257 Self {
258 a: self.r.transform(n),
259 r: self.r.clone(),
260 }
261 }
262
263 #[inline]
264 fn double(self) -> Self {
265 let Self { a, r } = self;
266 let a = r.dbl(a);
267 Self { a, r }
268 }
269
270 #[inline]
271 fn square(self) -> Self {
272 let Self { a, r } = self;
273 let a = r.sqr(a);
274 Self { a, r }
275 }
276}
277
/// A reducer that stores nothing but the modulus itself.
#[derive(Debug, Clone, Copy)]
pub struct Vanilla<T>(T);

/// Generates `const` modular helpers (`add`, `dbl`, `sub`, `neg`) on `Vanilla<$T>` for
/// each listed unsigned primitive type.
///
/// Every helper receives the modulus explicitly as `m`. The arithmetic is written for
/// operands that are already below `*m`.
macro_rules! impl_uprim_vanilla_core_const {
    ($($T:ty)*) => {$(
        impl Vanilla<$T> {
            /// Modular addition that tolerates wrap-around of the raw sum.
            #[inline]
            pub(crate) const fn add(m: &$T, lhs: $T, rhs: $T) -> $T {
                let (raw, overflow) = lhs.overflowing_add(rhs);
                if !overflow && raw < *m {
                    return raw;
                }
                // A single subtraction of `m` brings the sum back into range; the
                // subtraction is expected to wrap exactly when the addition did.
                let (reduced, overflow2) = raw.overflowing_sub(*m);
                debug_assert!(overflow == overflow2);
                reduced
            }

            /// Modular doubling, expressed as `add(target, target)`.
            #[inline]
            pub(crate) const fn dbl(m: &$T, target: $T) -> $T {
                Self::add(m, target, target)
            }

            /// Modular subtraction that never forms a negative intermediate.
            #[inline]
            pub(crate) const fn sub(m: &$T, lhs: $T, rhs: $T) -> $T {
                if lhs < rhs {
                    *m - (rhs - lhs)
                } else {
                    lhs - rhs
                }
            }

            /// Modular negation; zero maps to zero.
            #[inline]
            pub(crate) const fn neg(m: &$T, target: $T) -> $T {
                if target == 0 {
                    0
                } else {
                    *m - target
                }
            }
        }
    )*};
}
impl_uprim_vanilla_core_const!(u8 u16 u32 u64 u128 usize);
326
/// Expands to a `pow` method that uses binary (square-and-multiply) exponentiation.
///
/// The surrounding `impl` must provide `transform`, `mul` and `sqr` on `self`, and `$T`
/// must be a primitive integer type.
macro_rules! impl_reduced_binary_pow {
    ($T:ty) => {
        fn pow(&self, base: $T, exp: &$T) -> $T {
            // Small exponents are answered directly.
            if *exp == 1 {
                return base;
            }
            if *exp == 2 {
                return self.sqr(base);
            }

            // Scan the exponent from its least significant bit upwards.
            let mut acc = self.transform(1);
            let mut square = base;
            let mut bits = *exp;
            while bits > 0 {
                if bits & 1 != 0 {
                    acc = self.mul(&acc, &square);
                }
                square = self.sqr(square);
                bits >>= 1;
            }
            acc
        }
    };
}
350
351pub(crate) use impl_reduced_binary_pow;
352
/// Expands to the `Reducer` methods of `Vanilla<$T>` that are written identically for
/// every unsigned primitive type; `mul` and `sqr` are left to the surrounding `impl`.
macro_rules! impl_uprim_vanilla_core {
    ($T:ty) => {
        /// Wraps the modulus `m`; panics when `m` is zero.
        #[inline(always)]
        fn new(m: &$T) -> Self {
            assert!(m > &0);
            Self(*m)
        }

        /// Reduces `x` modulo the stored modulus.
        #[inline(always)]
        fn transform(&self, x: $T) -> $T {
            x % self.0
        }

        /// A value is valid when it is strictly below the modulus.
        #[inline(always)]
        fn check(&self, x: &$T) -> bool {
            self.0 > *x
        }

        /// The stored form already is the residue, so it is returned unchanged.
        #[inline(always)]
        fn residue(&self, x: $T) -> $T {
            x
        }

        /// Returns the stored modulus.
        #[inline(always)]
        fn modulus(&self) -> $T {
            self.0
        }

        /// Tests the stored form against zero.
        #[inline(always)]
        fn is_zero(&self, x: &$T) -> bool {
            *x == 0
        }

        // The four methods below forward to the inherent `const` helpers on `Vanilla<$T>`.

        #[inline(always)]
        fn add(&self, lhs: &$T, rhs: &$T) -> $T {
            <Vanilla<$T>>::add(&self.0, *lhs, *rhs)
        }

        #[inline(always)]
        fn dbl(&self, x: $T) -> $T {
            <Vanilla<$T>>::dbl(&self.0, x)
        }

        #[inline(always)]
        fn sub(&self, lhs: &$T, rhs: &$T) -> $T {
            <Vanilla<$T>>::sub(&self.0, *lhs, *rhs)
        }

        #[inline(always)]
        fn neg(&self, x: $T) -> $T {
            <Vanilla<$T>>::neg(&self.0, x)
        }

        /// Modular inverse via the primitive's `invm`; `None` when `invm` returns `None`.
        #[inline(always)]
        fn inv(&self, x: $T) -> Option<$T> {
            x.invm(&self.0)
        }

        impl_reduced_binary_pow!($T);
    };
}
409
410macro_rules! impl_uprim_vanilla {
411 ($t:ident, $ns:ident) => {
412 mod $ns {
413 use super::*;
414 use crate::word::$t::*;
415
416 impl Reducer<$t> for Vanilla<$t> {
417 impl_uprim_vanilla_core!($t);
418
419 #[inline]
420 fn mul(&self, lhs: &$t, rhs: &$t) -> $t {
421 (wmul(*lhs, *rhs) % extend(self.0)) as $t
422 }
423
424 #[inline]
425 fn sqr(&self, target: $t) -> $t {
426 (wsqr(target) % extend(self.0)) as $t
427 }
428 }
429 }
430 };
431}
432
433impl_uprim_vanilla!(u8, u8_impl);
434impl_uprim_vanilla!(u16, u16_impl);
435impl_uprim_vanilla!(u32, u32_impl);
436impl_uprim_vanilla!(u64, u64_impl);
437impl_uprim_vanilla!(usize, usize_impl);
438
439impl Reducer<u128> for Vanilla<u128> {
440 impl_uprim_vanilla_core!(u128);
441
442 #[inline]
443 fn mul(&self, lhs: &u128, rhs: &u128) -> u128 {
444 udouble::widening_mul(*lhs, *rhs) % self.0
445 }
446
447 #[inline]
448 fn sqr(&self, target: u128) -> u128 {
449 udouble::widening_square(target) % self.0
450 }
451}
452
/// A `ReducedInt` whose arithmetic is performed by the `Vanilla` reducer.
pub type VanillaInt<T> = ReducedInt<T, Vanilla<T>>;
455
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow, ModularUnaryOps};
    use core::marker::PhantomData;
    use rand::random;

    /// Type-indexed holder for the randomized reducer checks generated below.
    pub(crate) struct ReducedTester<T>(PhantomData<T>);

    macro_rules! impl_reduced_test_for {
        ($($T:ty)*) => {$(
            impl ReducedTester<$T> {
                /// Draws one random modulus and operand pair, then compares every
                /// `ReducedInt` operation against the matching primitive `*m` operation.
                ///
                /// With `odd_only` set, the lowest bit of the modulus is forced to 1.
                pub fn test_against_modops<R: Reducer<$T> + Copy>(odd_only: bool) {
                    // `saturating_add(1)` keeps the modulus non-zero.
                    let nonzero = random::<$T>().saturating_add(1);
                    let m = if odd_only { nonzero | 1 } else { nonzero };

                    let a = random::<$T>();
                    let b = random::<$T>();
                    let lhs = ReducedInt::<$T, R>::new(a, &m);
                    let rhs = ReducedInt::<$T, R>::new(b, &m);
                    assert_eq!((lhs + rhs).residue(), a.addm(b, &m), "incorrect add");
                    assert_eq!((lhs - rhs).residue(), a.subm(b, &m), "incorrect sub");
                    assert_eq!((lhs * rhs).residue(), a.mulm(b, &m), "incorrect mul");
                    assert_eq!(lhs.neg().residue(), a.negm(&m), "incorrect neg");
                    assert_eq!(lhs.double().residue(), a.dblm(&m), "incorrect dbl");
                    assert_eq!(lhs.square().residue(), a.sqm(&m), "incorrect sqr");

                    let e = random::<u8>() as $T;
                    assert_eq!(lhs.pow(&e).residue(), a.powm(e, &m), "incorrect pow");
                    // The inverse is only compared when the primitive reports one.
                    if let Some(expected) = a.invm(&m) {
                        assert_eq!(lhs.inv().unwrap().residue(), expected, "incorrect inv");
                    }
                }
            }
        )*};
    }
    impl_reduced_test_for!(u8 u16 u32 u64 u128 usize);

    #[test]
    fn test_against_modops() {
        (0..10).for_each(|_| {
            ReducedTester::<u8>::test_against_modops::<Vanilla<u8>>(false);
            ReducedTester::<u16>::test_against_modops::<Vanilla<u16>>(false);
            ReducedTester::<u32>::test_against_modops::<Vanilla<u32>>(false);
            ReducedTester::<u64>::test_against_modops::<Vanilla<u64>>(false);
            ReducedTester::<u128>::test_against_modops::<Vanilla<u128>>(false);
            ReducedTester::<usize>::test_against_modops::<Vanilla<usize>>(false);
        });
    }
}