library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kk2a/library

:heavy_check_mark: math_mod/butterfly.hpp

Depends on

Required by

Verified with

Code

#ifndef KK2_MATH_MOD_BUTTERFLY_HPP
#define KK2_MATH_MOD_BUTTERFLY_HPP 1

#include <algorithm>

#include "primitive_root.hpp"

namespace kk2 {

// In-place forward butterfly transform over `a`, whose elements have type `mint`.
//
// h is the smallest integer with 2^h >= a.size(); the loops below index
// a[0 .. 2^h - 1], so the size is presumably expected to be a power of two
// (this is not checked here).
// The root tables are filled once, on the first call of each template
// instantiation, from primitive_root<mint::getmod()>.
// The main loop consumes two levels per pass (radix-4 step) and uses a single
// radix-2 step when exactly one level is left.
// NOTE(review): the structure looks like the AtCoder Library NTT butterfly
// (no explicit bit-reversal pass) — confirm before relying on the output order.
template <class FPS, class mint = typename FPS::value_type> void butterfly(FPS &a) {
    // Generator from which the power-of-two roots of unity below are derived.
    static int g = primitive_root<mint::getmod()>;
    int n = int(a.size());
    // h = ceil(log2(n)).
    int h = 0;
    while ((1U << h) < (unsigned int)(n)) h++;
    // Plain static flag guarding the one-time table setup (no synchronization is visible here).
    static bool first = true;
    static mint sum_e2[30]; // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i]
    static mint sum_e3[30]; // sum_e3[i] = ies[1] * ... * ies[i] * es[i + 1]
    static mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
    if (first) {
        first = false;
        // Number of trailing zero bits of mod - 1, i.e. the largest k with 2^k dividing mod - 1.
        int cnt2 = __builtin_ctz(mint::getmod() - 1);
        // e = g^((mod - 1) / 2^cnt2) and its inverse.
        mint e = mint(g).pow((mint::getmod() - 1) >> cnt2), ie = e.inv();
        for (int i = cnt2; i >= 2; i--) {
            // e^(2^i) == 1
            es[i - 2] = e;
            ies[i - 2] = ie;
            // Squaring halves the exponent needed to reach 1 for the next slot.
            e *= e;
            ie *= ie;
        }
        // `now` accumulates ies[0] * ... * ies[i - 1].
        mint now = 1;
        for (int i = 0; i <= cnt2 - 2; i++) {
            sum_e2[i] = es[i] * now;
            now *= ies[i];
        }
        // Same construction shifted by one index: `now` accumulates ies[1] * ... * ies[i].
        now = 1;
        for (int i = 0; i <= cnt2 - 3; i++) {
            sum_e3[i] = es[i + 1] * now;
            now *= ies[i + 1];
        }
    }

    // `len` counts the levels already processed; the array is viewed as 2^len blocks of size 2^(h - len).
    int len = 0;
    while (len < h) {
        if (h - len == 1) {
            // Radix-2 step: exactly one level remains; p is half of the block size.
            int p = 1 << (h - len - 1);
            mint rot = 1;
            for (int s = 0; s < (1 << len); s++) {
                int offset = s << (h - len);
                for (int i = 0; i < p; i++) {
                    auto l = a[i + offset];
                    auto r = a[i + offset + p] * rot;
                    a[i + offset] = l + r;
                    a[i + offset + p] = l - r;
                }
                // Advance the twiddle for the next block; the table index is the
                // number of trailing 1 bits of s. Skipped after the last block.
                if (s + 1 != (1 << len)) rot *= sum_e2[__builtin_ctz(~(unsigned int)(s))];
            }
            len++;
        } else {
            // Radix-4 step: two levels at once; p is a quarter of the block size.
            int p = 1 << (h - len - 2);
            // imag = es[0], which satisfies es[0]^4 == 1 per the table definition above.
            mint rot = 1, imag = es[0];
            for (int s = 0; s < (1 << len); s++) {
                mint rot2 = rot * rot;
                mint rot3 = rot2 * rot;
                int offset = s << (h - len);
                for (int i = 0; i < p; i++) {
                    // Twiddle the four quarter-block inputs by rot^0 .. rot^3, then combine.
                    auto a0 = a[i + offset];
                    auto a1 = a[i + offset + p] * rot;
                    auto a2 = a[i + offset + p * 2] * rot2;
                    auto a3 = a[i + offset + p * 3] * rot3;
                    auto a1na3imag = (a1 - a3) * imag;
                    a[i + offset] = a0 + a2 + a1 + a3;
                    a[i + offset + p] = a0 + a2 - a1 - a3;
                    a[i + offset + p * 2] = a0 - a2 + a1na3imag;
                    a[i + offset + p * 3] = a0 - a2 - a1na3imag;
                }
                // Same incremental twiddle update as the radix-2 case, using the sum_e3 table.
                if (s + 1 != (1 << len)) rot *= sum_e3[__builtin_ctz(~(unsigned int)(s))];
            }
            len += 2;
        }
    }
}

// In-place inverse butterfly transform over `a`, whose elements have type `mint`.
//
// h is the smallest integer with 2^h >= a.size(); the loops below index
// a[0 .. 2^h - 1], so the size is presumably expected to be a power of two
// (this is not checked here).
// The levels are processed from len = h down to 0, two at a time (radix-4 step)
// with a single radix-2 step when len == 1. Unlike butterfly(), the twiddle
// factors are applied after the additions/subtractions.
// The final loop multiplies every element by invn[h], i.e. (mod / 2 + 1)^h.
// The tables are filled once, on the first call of each template instantiation.
template <class FPS, class mint = typename FPS::value_type> void butterfly_inv(FPS &a) {
    // Generator from which the power-of-two roots of unity below are derived.
    static constexpr int g = primitive_root<mint::getmod()>;
    int n = int(a.size());
    // h = ceil(log2(n)).
    int h = 0;
    while ((1U << h) < (unsigned int)(n)) h++;
    // Plain static flag guarding the one-time table setup (no synchronization is visible here).
    static bool first = true;
    static mint sum_ie2[30]; // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i]
    static mint sum_ie3[30]; // sum_ie3[i] = es[1] * ... * es[i] * ies[i + 1]
    static mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
    static mint invn[30]; // invn[i] = invn[1]^i
    if (first) {
        first = false;
        // Number of trailing zero bits of mod - 1, i.e. the largest k with 2^k dividing mod - 1.
        int cnt2 = __builtin_ctz(mint::getmod() - 1);
        // e = g^((mod - 1) / 2^cnt2) and its inverse.
        mint e = mint(g).pow((mint::getmod() - 1) >> cnt2), ie = e.inv();
        for (int i = cnt2; i >= 2; i--) {
            // e^(2^i) == 1
            es[i - 2] = e;
            ies[i - 2] = ie;
            e *= e;
            ie *= ie;
        }
        // `now` accumulates es[0] * ... * es[i - 1].
        mint now = 1;
        for (int i = 0; i <= cnt2 - 2; i++) {
            sum_ie2[i] = ies[i] * now;
            now *= es[i];
        }
        // Same construction shifted by one index: `now` accumulates es[1] * ... * es[i].
        now = 1;
        for (int i = 0; i <= cnt2 - 3; i++) {
            sum_ie3[i] = ies[i + 1] * now;
            now *= es[i + 1];
        }

        // invn[1] = mod / 2 + 1, which is the inverse of 2 when the modulus is odd.
        invn[0] = 1;
        invn[1] = mint::getmod() / 2 + 1;
        for (int i = 2; i < 30; i++) invn[i] = invn[i - 1] * invn[1];
    }
    // `len` counts the levels still to be undone.
    int len = h;
    while (len) {
        if (len == 1) {
            // Radix-2 step; p is the distance between the two elements of a pair.
            int p = 1 << (h - len);
            mint irot = 1;
            for (int s = 0; s < (1 << (len - 1)); s++) {
                int offset = s << (h - len + 1);
                for (int i = 0; i < p; i++) {
                    auto l = a[i + offset];
                    auto r = a[i + offset + p];
                    a[i + offset] = l + r;
                    a[i + offset + p] = (l - r) * irot;
                }
                // Advance the twiddle for the next block; the table index is the
                // number of trailing 1 bits of s. Skipped after the last block.
                if (s + 1 != (1 << (len - 1))) irot *= sum_ie2[__builtin_ctz(~(unsigned int)(s))];
            }
            len--;
        } else {
            // Radix-4 step; p is the distance between consecutive elements of a quadruple.
            int p = 1 << (h - len);
            // iimag = ies[0], the inverse of es[0] (es[0]^4 == 1 per the table definition above).
            mint irot = 1, iimag = ies[0];
            for (int s = 0; s < (1 << ((len - 2))); s++) {
                mint irot2 = irot * irot;
                mint irot3 = irot2 * irot;
                int offset = s << (h - len + 2);
                for (int i = 0; i < p; i++) {
                    auto a0 = a[i + offset];
                    auto a1 = a[i + offset + p];
                    auto a2 = a[i + offset + p * 2];
                    auto a3 = a[i + offset + p * 3];
                    auto a2na3iimag = (a2 - a3) * iimag;

                    // Combine first, then scale outputs 1..3 by irot^1 .. irot^3.
                    a[i + offset] = a0 + a1 + a2 + a3;
                    a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot;
                    a[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2;
                    a[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3;
                }
                // Same incremental twiddle update as the radix-2 case, using the sum_ie3 table.
                if (s + 1 != (1 << (len - 2))) irot *= sum_ie3[__builtin_ctz(~(unsigned int)(s))];
            }
            len -= 2;
        }
    }

    // Final scaling by invn[h] = (mod / 2 + 1)^h.
    for (int i = 0; i < n; i++) a[i] *= invn[h];
}

// Appends n new elements to `a` (n = a.size() on entry), doubling its length.
//
// A copy of `a` is passed through butterfly_inv(), its i-th element is
// multiplied by zeta^i with zeta = g^((mod - 1) / (2n)) and
// g = primitive_root<mint::getmod()>, and the result is passed through
// butterfly() and appended to `a`. The first n elements of `a` are unchanged.
//
// An empty `a` is left untouched; without the early return the exponent
// computation below would divide by (n << 1) == 0.
// NOTE(review): the butterflies index a power-of-two range, so n is presumably
// expected to be a power of two — confirm against callers.
template <class FPS, class mint = typename FPS::value_type> void doubling(FPS &a) {
    int n = int(a.size());
    if (n == 0) return;
    auto b = a;
    butterfly_inv(b);
    mint r = 1, zeta = mint(primitive_root<mint::getmod()>).pow((mint::getmod() - 1) / (n << 1));
    // Scale element i by zeta^i; r holds the running power.
    for (int i = 0; i < n; i++) {
        b[i] *= r;
        r *= zeta;
    }
    butterfly(b);
    std::copy(b.begin(), b.end(), std::back_inserter(a));
}

} // namespace kk2

#endif // KK2_MATH_MOD_BUTTERFLY_HPP
#line 1 "math_mod/butterfly.hpp"



#include <algorithm>

#line 1 "math_mod/primitive_root.hpp"



#line 1 "math_mod/pow_mod.hpp"



#include <cassert>

namespace kk2 {

// Modular exponentiation by repeated squaring: returns x^n mod m as type S.
//   S - type used for the intermediate products (must be wide enough to hold them)
//   x - base; a negative remainder is shifted into [0, m)
//   n - exponent, asserted to be non-negative
//   m - modulus; m == 1 yields 0
template <class S, class T, class U> constexpr S pow_mod(T x, U n, T m) {
    assert(n >= 0);
    if (m == 1) return S(0);
    S mod = m;
    S base = x % mod;
    if (base < 0) base += mod;
    S acc = 1;
    while (n != 0) {
        // Multiply in the current power when the lowest exponent bit is set.
        if ((n & 1) != 0) acc = (acc * base) % mod;
        n >>= 1;
        // Square only if more exponent bits remain.
        if (n != 0) base = (base * base) % mod;
    }
    return acc;
}

} // namespace kk2


#line 5 "math_mod/primitive_root.hpp"

namespace kk2 {

// Returns a generator candidate g for modulus m, usable in constant expressions.
// A handful of moduli are answered from a fixed table; otherwise the distinct
// prime factors of m - 1 are collected and the smallest g >= 2 with
// g^((m - 1) / q) != 1 (mod m) for every collected factor q is returned.
constexpr int primitive_root_constexpr(int m) {
    switch (m) {
        case 2: return 1;
        case 167772161: return 3;
        case 469762049: return 3;
        case 754974721: return 11;
        case 998244353: return 3;
        case 1107296257: return 10;
        default: break;
    }

    // Distinct prime factors of m - 1; 2 is always recorded first.
    int prime_factors[20] = {};
    int num_factors = 0;
    prime_factors[num_factors++] = 2;

    // Strip every factor of two, then trial-divide by odd numbers.
    int rest = (m - 1) / 2;
    while (rest % 2 == 0) rest /= 2;
    for (int d = 3; (long long)(d)*d <= rest; d += 2) {
        if (rest % d != 0) continue;
        prime_factors[num_factors++] = d;
        do { rest /= d; } while (rest % d == 0);
    }
    // Whatever is left above 1 is recorded as one more factor.
    if (rest > 1) prime_factors[num_factors++] = rest;

    int candidate = 2;
    while (true) {
        bool is_generator = true;
        for (int k = 0; k < num_factors && is_generator; k++) {
            is_generator = pow_mod<long long>(candidate, (m - 1) / prime_factors[k], m) != 1;
        }
        if (is_generator) return candidate;
        candidate++;
    }
}

// primitive_root<m>: constant holding primitive_root_constexpr(m) for the modulus m.
template <int m> static constexpr int primitive_root = primitive_root_constexpr(m);

} // namespace kk2


#line 7 "math_mod/butterfly.hpp"

namespace kk2 {

// In-place forward butterfly transform over `a`, whose elements have type `mint`.
//
// h is the smallest integer with 2^h >= a.size(); the loops below index
// a[0 .. 2^h - 1], so the size is presumably expected to be a power of two
// (this is not checked here).
// The root tables are filled once, on the first call of each template
// instantiation, from primitive_root<mint::getmod()>.
// The main loop consumes two levels per pass (radix-4 step) and uses a single
// radix-2 step when exactly one level is left.
// NOTE(review): the structure looks like the AtCoder Library NTT butterfly
// (no explicit bit-reversal pass) — confirm before relying on the output order.
template <class FPS, class mint = typename FPS::value_type> void butterfly(FPS &a) {
    // Generator from which the power-of-two roots of unity below are derived.
    static int g = primitive_root<mint::getmod()>;
    int n = int(a.size());
    // h = ceil(log2(n)).
    int h = 0;
    while ((1U << h) < (unsigned int)(n)) h++;
    // Plain static flag guarding the one-time table setup (no synchronization is visible here).
    static bool first = true;
    static mint sum_e2[30]; // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i]
    static mint sum_e3[30]; // sum_e3[i] = ies[1] * ... * ies[i] * es[i + 1]
    static mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
    if (first) {
        first = false;
        // Number of trailing zero bits of mod - 1, i.e. the largest k with 2^k dividing mod - 1.
        int cnt2 = __builtin_ctz(mint::getmod() - 1);
        // e = g^((mod - 1) / 2^cnt2) and its inverse.
        mint e = mint(g).pow((mint::getmod() - 1) >> cnt2), ie = e.inv();
        for (int i = cnt2; i >= 2; i--) {
            // e^(2^i) == 1
            es[i - 2] = e;
            ies[i - 2] = ie;
            // Squaring halves the exponent needed to reach 1 for the next slot.
            e *= e;
            ie *= ie;
        }
        // `now` accumulates ies[0] * ... * ies[i - 1].
        mint now = 1;
        for (int i = 0; i <= cnt2 - 2; i++) {
            sum_e2[i] = es[i] * now;
            now *= ies[i];
        }
        // Same construction shifted by one index: `now` accumulates ies[1] * ... * ies[i].
        now = 1;
        for (int i = 0; i <= cnt2 - 3; i++) {
            sum_e3[i] = es[i + 1] * now;
            now *= ies[i + 1];
        }
    }

    // `len` counts the levels already processed; the array is viewed as 2^len blocks of size 2^(h - len).
    int len = 0;
    while (len < h) {
        if (h - len == 1) {
            // Radix-2 step: exactly one level remains; p is half of the block size.
            int p = 1 << (h - len - 1);
            mint rot = 1;
            for (int s = 0; s < (1 << len); s++) {
                int offset = s << (h - len);
                for (int i = 0; i < p; i++) {
                    auto l = a[i + offset];
                    auto r = a[i + offset + p] * rot;
                    a[i + offset] = l + r;
                    a[i + offset + p] = l - r;
                }
                // Advance the twiddle for the next block; the table index is the
                // number of trailing 1 bits of s. Skipped after the last block.
                if (s + 1 != (1 << len)) rot *= sum_e2[__builtin_ctz(~(unsigned int)(s))];
            }
            len++;
        } else {
            // Radix-4 step: two levels at once; p is a quarter of the block size.
            int p = 1 << (h - len - 2);
            // imag = es[0], which satisfies es[0]^4 == 1 per the table definition above.
            mint rot = 1, imag = es[0];
            for (int s = 0; s < (1 << len); s++) {
                mint rot2 = rot * rot;
                mint rot3 = rot2 * rot;
                int offset = s << (h - len);
                for (int i = 0; i < p; i++) {
                    // Twiddle the four quarter-block inputs by rot^0 .. rot^3, then combine.
                    auto a0 = a[i + offset];
                    auto a1 = a[i + offset + p] * rot;
                    auto a2 = a[i + offset + p * 2] * rot2;
                    auto a3 = a[i + offset + p * 3] * rot3;
                    auto a1na3imag = (a1 - a3) * imag;
                    a[i + offset] = a0 + a2 + a1 + a3;
                    a[i + offset + p] = a0 + a2 - a1 - a3;
                    a[i + offset + p * 2] = a0 - a2 + a1na3imag;
                    a[i + offset + p * 3] = a0 - a2 - a1na3imag;
                }
                // Same incremental twiddle update as the radix-2 case, using the sum_e3 table.
                if (s + 1 != (1 << len)) rot *= sum_e3[__builtin_ctz(~(unsigned int)(s))];
            }
            len += 2;
        }
    }
}

// In-place inverse butterfly transform over `a`, whose elements have type `mint`.
//
// h is the smallest integer with 2^h >= a.size(); the loops below index
// a[0 .. 2^h - 1], so the size is presumably expected to be a power of two
// (this is not checked here).
// The levels are processed from len = h down to 0, two at a time (radix-4 step)
// with a single radix-2 step when len == 1. Unlike butterfly(), the twiddle
// factors are applied after the additions/subtractions.
// The final loop multiplies every element by invn[h], i.e. (mod / 2 + 1)^h.
// The tables are filled once, on the first call of each template instantiation.
template <class FPS, class mint = typename FPS::value_type> void butterfly_inv(FPS &a) {
    // Generator from which the power-of-two roots of unity below are derived.
    static constexpr int g = primitive_root<mint::getmod()>;
    int n = int(a.size());
    // h = ceil(log2(n)).
    int h = 0;
    while ((1U << h) < (unsigned int)(n)) h++;
    // Plain static flag guarding the one-time table setup (no synchronization is visible here).
    static bool first = true;
    static mint sum_ie2[30]; // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i]
    static mint sum_ie3[30]; // sum_ie3[i] = es[1] * ... * es[i] * ies[i + 1]
    static mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
    static mint invn[30]; // invn[i] = invn[1]^i
    if (first) {
        first = false;
        // Number of trailing zero bits of mod - 1, i.e. the largest k with 2^k dividing mod - 1.
        int cnt2 = __builtin_ctz(mint::getmod() - 1);
        // e = g^((mod - 1) / 2^cnt2) and its inverse.
        mint e = mint(g).pow((mint::getmod() - 1) >> cnt2), ie = e.inv();
        for (int i = cnt2; i >= 2; i--) {
            // e^(2^i) == 1
            es[i - 2] = e;
            ies[i - 2] = ie;
            e *= e;
            ie *= ie;
        }
        // `now` accumulates es[0] * ... * es[i - 1].
        mint now = 1;
        for (int i = 0; i <= cnt2 - 2; i++) {
            sum_ie2[i] = ies[i] * now;
            now *= es[i];
        }
        // Same construction shifted by one index: `now` accumulates es[1] * ... * es[i].
        now = 1;
        for (int i = 0; i <= cnt2 - 3; i++) {
            sum_ie3[i] = ies[i + 1] * now;
            now *= es[i + 1];
        }

        // invn[1] = mod / 2 + 1, which is the inverse of 2 when the modulus is odd.
        invn[0] = 1;
        invn[1] = mint::getmod() / 2 + 1;
        for (int i = 2; i < 30; i++) invn[i] = invn[i - 1] * invn[1];
    }
    // `len` counts the levels still to be undone.
    int len = h;
    while (len) {
        if (len == 1) {
            // Radix-2 step; p is the distance between the two elements of a pair.
            int p = 1 << (h - len);
            mint irot = 1;
            for (int s = 0; s < (1 << (len - 1)); s++) {
                int offset = s << (h - len + 1);
                for (int i = 0; i < p; i++) {
                    auto l = a[i + offset];
                    auto r = a[i + offset + p];
                    a[i + offset] = l + r;
                    a[i + offset + p] = (l - r) * irot;
                }
                // Advance the twiddle for the next block; the table index is the
                // number of trailing 1 bits of s. Skipped after the last block.
                if (s + 1 != (1 << (len - 1))) irot *= sum_ie2[__builtin_ctz(~(unsigned int)(s))];
            }
            len--;
        } else {
            // Radix-4 step; p is the distance between consecutive elements of a quadruple.
            int p = 1 << (h - len);
            // iimag = ies[0], the inverse of es[0] (es[0]^4 == 1 per the table definition above).
            mint irot = 1, iimag = ies[0];
            for (int s = 0; s < (1 << ((len - 2))); s++) {
                mint irot2 = irot * irot;
                mint irot3 = irot2 * irot;
                int offset = s << (h - len + 2);
                for (int i = 0; i < p; i++) {
                    auto a0 = a[i + offset];
                    auto a1 = a[i + offset + p];
                    auto a2 = a[i + offset + p * 2];
                    auto a3 = a[i + offset + p * 3];
                    auto a2na3iimag = (a2 - a3) * iimag;

                    // Combine first, then scale outputs 1..3 by irot^1 .. irot^3.
                    a[i + offset] = a0 + a1 + a2 + a3;
                    a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot;
                    a[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2;
                    a[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3;
                }
                // Same incremental twiddle update as the radix-2 case, using the sum_ie3 table.
                if (s + 1 != (1 << (len - 2))) irot *= sum_ie3[__builtin_ctz(~(unsigned int)(s))];
            }
            len -= 2;
        }
    }

    // Final scaling by invn[h] = (mod / 2 + 1)^h.
    for (int i = 0; i < n; i++) a[i] *= invn[h];
}

// Appends n new elements to `a` (n = a.size() on entry), doubling its length.
//
// A copy of `a` is passed through butterfly_inv(), its i-th element is
// multiplied by zeta^i with zeta = g^((mod - 1) / (2n)) and
// g = primitive_root<mint::getmod()>, and the result is passed through
// butterfly() and appended to `a`. The first n elements of `a` are unchanged.
//
// An empty `a` is left untouched; without the early return the exponent
// computation below would divide by (n << 1) == 0.
// NOTE(review): the butterflies index a power-of-two range, so n is presumably
// expected to be a power of two — confirm against callers.
template <class FPS, class mint = typename FPS::value_type> void doubling(FPS &a) {
    int n = int(a.size());
    if (n == 0) return;
    auto b = a;
    butterfly_inv(b);
    mint r = 1, zeta = mint(primitive_root<mint::getmod()>).pow((mint::getmod() - 1) / (n << 1));
    // Scale element i by zeta^i; r holds the running power.
    for (int i = 0; i < n; i++) {
        b[i] *= r;
        r *= zeta;
    }
    butterfly(b);
    std::copy(b.begin(), b.end(), std::back_inserter(a));
}

} // namespace kk2
Back to top page