num_modular/preinv.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
use crate::{DivExact, ModularUnaryOps};
/// Precomputed data for fast divisibility checks and exact division by a fixed odd divisor.
///
/// For an odd divisor `d` of a `k`-bit unsigned word, `n` is divisible by `d` exactly when
/// `n * d_inv mod 2^k` (where `d_inv` is the inverse of `d` modulo `2^k`) does not exceed
/// `(2^k - 1) / d`; in that case the product is the quotient `n / d` itself.
/// See <https://math.stackexchange.com/a/1251328> for the explanation of the trick.
#[derive(Debug, Clone, Copy)]
pub struct PreModInv<T> {
    /// Inverse of the divisor modulo `2^k`.
    d_inv: T,
    /// Largest representable quotient, i.e. `T::MAX / d`.
    q_lim: T,
}
// Implement `PreModInv` construction and `DivExact` (for the word `$t` and its double word)
// inside a private module `$ns`, so the glob import of `crate::word::$t` stays local.
macro_rules! impl_preinv_for_prim_int {
    ($t:ident, $ns:ident) => {
        mod $ns {
            use super::*;
            use crate::word::$t::*;

            impl PreModInv<$t> {
                /// Construct the preinv instance with raw values.
                ///
                /// This function can be used to initialize a preinv in a constant context.
                /// `d_inv` must be the inverse of the (odd) divisor `d` modulo 2^BITS and
                /// `q_lim` must equal `MAX / d`. The divisor itself is not stored: it is passed
                /// again to `div_exact`, which (in debug builds only) uses it to verify
                /// `d_inv` and `q_lim`.
                #[inline]
                pub const fn new(d_inv: $t, q_lim: $t) -> Self {
                    Self { d_inv, q_lim }
                }

                /// Check (in debug builds only) that this instance was built for divisor `d`.
                ///
                /// `q_lim` is compared against `MAX / d` directly instead of evaluating
                /// `q_lim * d`, so an inconsistent `q_lim` passed to `new` fails with the
                /// assertion message rather than an arithmetic-overflow panic.
                #[inline]
                fn debug_check(&self, d: $t) {
                    debug_assert!(d % 2 != 0, "only odd divisors are supported");
                    debug_assert!(
                        d.wrapping_mul(self.d_inv) == 1,
                        "d_inv is not the modular inverse of the divisor"
                    );
                    debug_assert!(
                        self.q_lim == <$t>::MAX / d,
                        "q_lim is inconsistent with the divisor"
                    );
                }
            }

            impl From<$t> for PreModInv<$t> {
                /// Precompute the modular inverse and quotient limit for the odd divisor `v`.
                ///
                /// # Panics
                ///
                /// Panics if `v` is even (including zero). In debug builds the assertion
                /// fires; otherwise the modular inverse computation fails, since an even
                /// number has no inverse modulo a power of two.
                #[inline]
                fn from(v: $t) -> Self {
                    debug_assert!(v % 2 != 0, "only odd divisors are supported");
                    // inverse of v modulo 2^BITS (`merge(0, 1)` is 2^BITS as a double word)
                    let d_inv = extend(v)
                        .invm(&merge(0, 1))
                        .expect("only odd divisors are supported") as $t;
                    // largest quotient q for which q * v still fits in a single word
                    let q_lim = <$t>::MAX / v;
                    Self { d_inv, q_lim }
                }
            }

            impl DivExact<$t, PreModInv<$t>> for $t {
                type Output = $t;

                /// Return `Some(self / d)` if `d` divides `self`, otherwise `None`.
                ///
                /// For odd `d`, `self * d_inv mod 2^BITS` equals the true quotient exactly
                /// when the division is exact, and that quotient is at most `MAX / d`; a larger
                /// product therefore means `d` does not divide `self`.
                #[inline]
                fn div_exact(self, d: $t, pre: &PreModInv<$t>) -> Option<Self> {
                    pre.debug_check(d);
                    let q = self.wrapping_mul(pre.d_inv);
                    if q <= pre.q_lim {
                        Some(q)
                    } else {
                        None
                    }
                }
            }

            impl DivExact<$t, PreModInv<$t>> for DoubleWord {
                type Output = DoubleWord;

                /// Return `Some(self / d)` if `d` divides the double word `self`, otherwise `None`.
                #[inline]
                fn div_exact(self, d: $t, pre: &PreModInv<$t>) -> Option<Self::Output> {
                    pre.debug_check(d);
                    // this implementation comes from GNU factor,
                    // see https://math.stackexchange.com/q/4436380/815652 for explanation
                    let (n0, n1) = split(self);
                    // low word of the quotient: the unique q0 with q0 * d ≡ n0 (mod 2^BITS)
                    let q0 = n0.wrapping_mul(pre.d_inv);
                    // high word of q0 * d (its low word equals n0 by construction)
                    let nr0 = split(wmul(q0, d)).1;
                    // self - q0 * d == (n1 - nr0) * 2^BITS must be non-negative
                    if nr0 > n1 {
                        return None;
                    }
                    // high word of the quotient: n1 - nr0 must itself be divisible by d
                    let nr1 = n1 - nr0;
                    let q1 = nr1.wrapping_mul(pre.d_inv);
                    if q1 > pre.q_lim {
                        return None;
                    }
                    Some(merge(q0, q1))
                }
            }
        }
    };
}
// Instantiate the preinv implementation for every supported primitive word size.
impl_preinv_for_prim_int! { u8, u8_impl }
impl_preinv_for_prim_int! { u16, u16_impl }
impl_preinv_for_prim_int! { u32, u32_impl }
impl_preinv_for_prim_int! { u64, u64_impl }
impl_preinv_for_prim_int! { usize, usize_impl }
// XXX: an unchecked `div_exact` (one that skips the `q_lim` comparison and assumes the
// division is exact) could be added; revisit this once `exact_div` has either been
// stabilized in or removed from the core library, see
// https://github.com/rust-lang/rust/issues/85122
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;

    #[test]
    fn div_exact_test() {
        const N: u8 = 100;
        for _ in 0..N {
            // u8 test
            let d = random::<u8>() | 1;
            let pre: PreModInv<_> = d.into();
            let n: u8 = random();
            let expect = if n % d == 0 { Some(n / d) } else { None };
            assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);
            let n: u16 = random();
            let expect = if n % (d as u16) == 0 {
                Some(n / (d as u16))
            } else {
                None
            };
            assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);
            // u16 test
            let d = random::<u16>() | 1;
            let pre: PreModInv<_> = d.into();
            let n: u16 = random();
            let expect = if n % d == 0 { Some(n / d) } else { None };
            assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);
            let n: u32 = random();
            let expect = if n % (d as u32) == 0 {
                Some(n / (d as u32))
            } else {
                None
            };
            assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);
            // u32 test
            let d = random::<u32>() | 1;
            let pre: PreModInv<_> = d.into();
            let n: u32 = random();
            let expect = if n % d == 0 { Some(n / d) } else { None };
            assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);
            let n: u64 = random();
            let expect = if n % (d as u64) == 0 {
                Some(n / (d as u64))
            } else {
                None
            };
            assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);
            // u64 test
            let d = random::<u64>() | 1;
            let pre: PreModInv<_> = d.into();
            let n: u64 = random();
            let expect = if n % d == 0 { Some(n / d) } else { None };
            assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);
            let n: u128 = random();
            let expect = if n % (d as u128) == 0 {
                Some(n / (d as u128))
            } else {
                None
            };
            assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);
        }
    }

    /// Check that exact multiples `q * d` are recognized for a random odd divisor `d` of
    /// word type `$t`, both as single words and as double words of type `$dt`.
    macro_rules! check_multiples {
        ($t:ty, $dt:ty) => {{
            let d = random::<$t>() | 1;
            let pre: PreModInv<$t> = d.into();
            // q <= MAX / d, so q * d cannot overflow
            let q = random::<$t>() / d;
            assert_eq!((q * d).div_exact(d, &pre), Some(q), "{} * {}", q, d);
            let q = random::<$dt>() / (d as $dt);
            assert_eq!((q * (d as $dt)).div_exact(d, &pre), Some(q), "{} * {}", q, d);
        }};
    }

    // Random numerators are almost never divisible by a large random divisor, so the
    // `Some` branch is exercised explicitly here with exact multiples.
    #[test]
    fn div_exact_multiples_test() {
        const N: u8 = 100;
        for _ in 0..N {
            check_multiples!(u8, u16);
            check_multiples!(u16, u32);
            check_multiples!(u32, u64);
            check_multiples!(u64, u128);
        }
    }

    #[test]
    fn div_exact_edge_cases_test() {
        // d = 1 divides everything
        let d = 1u8;
        let pre: PreModInv<u8> = d.into();
        assert_eq!(0u8.div_exact(d, &pre), Some(0));
        assert_eq!(u8::MAX.div_exact(d, &pre), Some(u8::MAX));
        assert_eq!(u16::MAX.div_exact(d, &pre), Some(u16::MAX));

        // d = MAX is the largest odd divisor
        let d = u8::MAX;
        let pre: PreModInv<u8> = d.into();
        assert_eq!(0u8.div_exact(d, &pre), Some(0));
        assert_eq!(u8::MAX.div_exact(d, &pre), Some(1));
        assert_eq!(254u8.div_exact(d, &pre), None);
        assert_eq!(u16::MAX.div_exact(d, &pre), Some(257));
        assert_eq!((u16::MAX - 1).div_exact(d, &pre), None);
    }
}