蒙哥马利乘模

蒙哥马利快速乘模运算
蒙哥马利运算是一种新的运算,他把乘模简化为对进制数的除法,以及简单加法,这就使得乘模避开了大量的取模试除。
 
蒙哥马利表示法是蒙哥马利运算的基础,
 
例如a%mod的蒙哥马利表示法为a*p%mod,p是一个进制数
于是我们很容易发现对于蒙哥马利表示法变会普通表示法时只需要
(a*p)*invp%mod
这里的invp是乘模运算下p对mod的逆元
 
如果我们要计算a*b%mod;
 
蒙哥马利表示法会找一个进制数使得普通数字转化为蒙哥马利表示法
 
一般来说这是一个刚好大于模数的进制数
 
但是对于大数运算,例如64位整形,我们会直接取2**64
我们令p=2**64;
所以上述a*b%mod会表示为(a*p)*(b*p)%mod=a*b*p**2%mod
 
这里出现了问题,多了一个p,所以乘模不是这样写的
应该是a*b*p%mod
 
对比我们发现,只需要微调即可,即(((a*p)*(b*p))/p)%mod【/p表示逆元法】
 
于是这里出现了除法,这正是我们需要的,我们就是靠着这个除以p使得乘模避开了对非进制数取模,也避开了很多高复杂度无谓的除法
 
对于ACM中,如果要求大量的乘模,且刚好范围在int128以内,我们就可以直接暴算出
(a*p)*(b*p),如此留下的问题就是大数除法,把它变回蒙哥马利表示法标准式子
所以现在问题简化到给出一个大数(小于int128),快速除法,我们假设这个大整数是BIG
 
由拓展欧几里得算法我们可以找到这样一个数字,使得inv*mod与1对与2**128同余,找到这样的数字以后,我们可以这样化简
BIG/p%mod=(BIG-(inv*mod)*BIG)/p%mod
首先,有一个问题:三个数相乘已经爆了精度,但是这并不影响,因为自然溢出使得inv*mod直接变成了1
然后inv*mod*BIG/p还是很大
所以这样写错了
 
应该是这样
BIG/p%mod=(BIG-(inv*BIG)*mod))/p%mod
可以证明这样的溢出不会影响结果,并且溢出导致inv*mod*BIG/p结果就比较小了
 
 
于是我们可以预测结果,BIG-(inv*BIG)*mod)的低64位全部变成了0,直接移位就行了
于是我们又避开了逆元的选取,
 
由此受到启发,我们从一开始便可以避开逆元,
 
于是给出蒙哥马利类
struct Mod64 {
    
        static u64 mod,inv,r2;
    
        u64 n_;
    
};
 
        由拓展欧几里得算法根据模数算出inv
        然后有一个削减函数
static u64 reduce(u128 x) {
    
            u64 y=u64(x>>64)-u64((u128(u64(x)*inv)*mod)>>64);
    
            return ll(y)<0?y+mod:y;
    
        }
用与蒙哥马利乘模运算之后的大数除法(快速除法)
我们尝试通过这一个函数完成所有的操作
构造函数
 Mod64(u64 n):n_(init(n)) {}
 static u64 init(u64 w) { return reduce(u128(w) * r2); }
这里的r2就是(2**128)**2对于mod的模,其实也就是r**2
所以我们发现这个代码里面的r便是上文中的p
 
蒙哥马利类内部乘法
 Mod64& operator *= (Mod64 rhs) { n_=reduce(u128(n_)*rhs.n_); return *this; }
    Mod64 operator * (Mod64 rhs) const { return Mod64(*this)*=rhs; }
 
很明显了他是先做除法,再做乘法
 
 
蒙哥马利类还原至整形类
   u64 get() const { return reduce(n_); }
直接调用削减函数,除以p(r)
最后给出dls的代码
typedef long long ll;
typedef unsigned long long u64;
typedef __int128_t i128;
typedef __uint128_t u128;

struct Mod64 {
    static u64 mod,inv,r2;
    u64 n_;
    Mod64():n_(0) {}
    Mod64(u64 n):n_(init(n)) {}
    static u64 init(u64 w) { return reduce(u128(w) * r2); }
    static void set_mod(u64 m) {
        mod=m; //assert(mod&1);
        inv=m; for(int i=0;i<5;i++) inv*=2-inv*m;
        r2=-u128(m)%m;
    }
    static u64 reduce(u128 x) {
        u64 y=u64(x>>64)-u64((u128(u64(x)*inv)*mod)>>64);
        return ll(y)<0?y+mod:y;
    }
    Mod64& operator += (Mod64 rhs) { n_+=rhs.n_-mod; if (ll(n_)<0) n_+=mod; return *this; }
    Mod64 operator + (Mod64 rhs) const { return Mod64(*this)+=rhs; }
    Mod64& operator -= (Mod64 rhs) { n_-=rhs.n_; if (ll(n_)<0) n_+=mod; return *this; }
    Mod64 operator - (Mod64 rhs) const { return Mod64(*this)-=rhs; }
    Mod64& operator *= (Mod64 rhs) { n_=reduce(u128(n_)*rhs.n_); return *this; }
    Mod64 operator * (Mod64 rhs) const { return Mod64(*this)*=rhs; }
    u64 get() const { return reduce(n_); }
};
u64 Mod64::mod,Mod64::inv,Mod64::r2;