模版

FFT
const int maxn = (1<<22) + 5;  // 注意这里需要是 > 2^k

struct Complex {
    double x,y;
    Complex(double _x = 0.0, double _y = 0.0) {
        x = _x; y = _y;
    }
    Complex operator+(const Complex& b) const {
        return Complex(x + b.x, y + b.y);
    }
    Complex operator-(const Complex& b) const {
        return Complex(x - b.x, y - b.y);
    }
    Complex operator*(const Complex &b) const {
        return Complex(x * b.x - y * b.y, x * b.y + y * b.x);
    }
    Complex operator/(const long double& b) const {
        return Complex(x/b, y/b);
    }
};

struct FastFourierTransform {
    void rearrange(Complex a[], const int n) {
        static int rev[maxn];  // maxn > deg(h) 且 maxn 为 2的k次方 + 5
        for (int i = 1; i <= n; i++) {
            rev[i] = rev[i >> 1] >> 1;
            if (i & 1) rev[i] |= (n >> 1);
        }
        for (int i = 1; i < n; i++) {
            if (i < rev[i]) swap(a[i], a[rev[i]]);  // 保证每对数字只翻转一次
        }
    }

    void fft(Complex a[], const int n, int on) {
        rearrange(a, n);
        for (int k = 2; k <= n; k <<= 1) {   // 模拟分治的合并过程
            Complex wn(cos(2*pi / k), on * sin(2*pi / k));
            for (int i = 0; i < n; i += k) {
                Complex w(1, 0);
                for (int j = i; j < i + (k>>1); j++) {
                    Complex x = a[j], y = w * a[j+(k>>1)];
                    a[j] = x + y;
                    a[j+(k>>1)] = x - y;
                    w = w * wn;
                }
            }
        }
        if (on == -1) {
            for (int i = 0; i < n; i++) a[i] = a[i] / n;
        }
    }
} fft;

// calculate h(x) = f(x) * g(x), n1 = deg(f) + 1, n2 = deg(g) + 1
void poly_multiply(const double f[], int n1, const double g[], int n2, double h[]) {
    int n = 1;
    n1--, n2--;
    while (n <= n1 + n2) n <<= 1;  // deg(h) = n1 + n2
    static Complex a[maxn], b[maxn];
    memset(a, 0, sizeof(a));
    memset(b, 0, sizeof(b));
    for (int i = 0; i <= n1; i++) a[i] = Complex(f[i], 0);
    for (int i = 0; i <= n2; i++) b[i] = Complex(g[i], 0);
    fft.fft(a, n, 1);  // 注意这里用的是 n (不是 n1)
    fft.fft(b, n, 1);
    for (int i = 0; i <= n; i++) a[i] = a[i] * b[i];
    fft.fft(a, n, -1);
    for (int i = 0; i <= n; i++) h[i] = a[i].x;
}

int n1, n2;
double f[maxn/2], g[maxn/2], h[maxn];

int main() {
    cin >> n1 >> n2;  // deg(f) = n1 - 1, deg(g) = n2 - 1
    for (int i = 0; i < n1; i++) cin >> f[i];
    for (int i = 0; i < n2; i++) cin >> g[i];
    poly_multiply(f, n1, g, n2, h);
    for (int i = 0; i <= n1+n2-2; i++) cout << h[i] << " ";
}

使用板子的注意事项:

  1. maxn 的值必须大于一个 $2^k$,且 $2^k > n_1+n_2$

介绍

FFT (Fast Fourier Transform, 快速傅立叶变换) 可以在 $O(n \log n)$ 的时间内计算两个多项式的乘法,也可以用来计算卷积。

• 一般卷积的形式为 $$C_k = \sum\limits_{i=0}^kA_iB_{k-i} = A_0B_k + A_1B_{k-1} + … + A_kB_0$$

在开始介绍 FFT 之前,我们需要一些基本的前置知识。

多项式

一个多项式有两种表示方法:系数表示法点值表示法

系数表示法

$$f(x) = a_0 + a_1x + a_2x^2 + … + a_nx^n$$

则 $\{a_0,a_1,…,a_n\}$ 即可表示这个多项式 $f(x)$。

点值表示法

选取多项式函数上的 $(n+1)$ 个点 $(x_0,y_0), (x_1,y_1), …, (x_n,y_n)$ 也可以唯一的表示这个多项式 $f(x)$。

证明:使用高斯消元即可,$(n+1)$ 个未知量,对应 $(n+1)$ 个方程。

复数

$n$ 次单位根

若 $\omega$ 满足 $\omega^n = 1$,那么 $\omega$ 就是 $n$ 次单位根 (n-th root of unity)。

若 $\omega$ 满足 $\omega^n = 1$,且满足 $\omega ^ k \neq 1, \forall k \in [1, n-1]$,则 $\omega$ 就是 $n$ 次本原单位根 (primitive n-th root of unity)

$\omega$ 的取值

定理

$$\omega_n^k = e^{i\frac{2k\pi}{n}} = \cos(\frac{2k\pi}{n}) + i \sin(\frac{2k\pi}{n})$$

$$\omega = e^{i \frac{2\pi}{n}} = \cos(\frac{2\pi}{n}) + i \sin(\frac{2\pi}{n})$$

为了方便,下文我们就写 $\omega = e^{i \frac{2\pi}{n}}$

$\omega$ 有几个优秀的性质:

  1. $\omega$ 为 $n$ 次本原单位根,这说明 $\forall k \in [1, n-1], \omega^n = 1$,且 $\omega ^ k \neq 1$
  2. $\omega^{k+\frac{n}{2}} = -\omega^{k}$

FFT原理

问题

给定两个已知多项式 $f(x)$ 和 $g(x)$,$deg(f) = n_1, deg(g) = n_2$,求 $h(x) = f(x) * g(x)$?

FFT 分为三步:

  1. DFT:令 $n = n_1 + n_2$,在 $x$ 轴上选择 $(n + 1)$ 个特殊点 $x_0,x_1,…,x_{n}$,求出 $f(x_0), f(x_1), …, f(x_n)$ 与 $g(x_0), g(x_1), …, g(x_n)$ 的值。

    复杂度:$O(n\log n)$

  2. 点值乘法:求出 $(n+1)$ 个点上,点值的乘积,即 $h(x_i) = f(x_i) * g(x_i), i \in [0, n]$

    复杂度:$O(n)$

  3. IDFT:给定一个未知 $h(x)$ 上的 $(n+1)$ 个特殊点 $\{(x_0,y_0), (x_1,y_1), …, (x_n,y_n)\}$,求出 $h(x)$ 的系数表达式。

    复杂度:$O(n\log n)$

第一步 DFT

问题

假设现在有一个 $deg(f) = n$ 的多项式,且 $n$ 为偶数。我们想要求 $f(x)$ 在 $n$ 个不同的特殊点 $x_0,x_1,…,x_{n-1}$ 上的值。

注意到这个问题也等价于一个矩阵乘法,即给定 一个矩阵 $X$ 和一个向量 $\vec a$,在我们可以自由选择 $x_0,x_1,…,x_{n-1}$ 的情况下,用 $O(n \log n)$ 的时间求出 $\vec y = X \vec a$ 的值?

$$\begin{bmatrix} 1 & x_0 & x_0^2 & … & x_0^{n-1} \\
1 & x_1 & x_1^2 & … & x_1^{n-1} \\
…\\
1 & x_{n-1} & x_{n-1}^2 & … & x_{n-1}^{n-1} \\
\end{bmatrix} \begin{bmatrix} a_0 \\
a_1 \\
…\\
a_{n-1} \\
\end{bmatrix} = \begin{bmatrix} y_0 \\
y_1 \\
…\\
y_{n-1} \\
\end{bmatrix} $$

那么,如何快速求出 $f(x_0), f(x_1), …, f(x_{n-1})$呢?我们只要选一些特殊的 $x_0, x_1, …, x_{n-1}$ 即可。

$X = \{x_0, x_1, …, x_{n-1}\}$ 需要有的性质:$|X^2| = \frac{|X|}{2}$

结论:$x_i = \omega^i$


证明:

因为

$$f(x) = a_0 + a_1x + a_2x^2 + … + a_nx^n$$

$$=(a_0 + a_2x^2 + a_4x^4 + … + a_{n-2}x^{n-2}) + (a_1x + a_3x^3 + a_5x^5 + … + a_{n-1}x^{n-1})$$

$$f_{even}(x) = a_0 + a_2x + a_4x^2 + … + a_{n-2}x^{\frac{n}{2} - 1}$$ $$f_{odd}(x) = a_1 + a_3x + a_5x^2 + … + a_{n-1}x^{\frac{n}{2} - 1}$$

则有

$$f(x) = f_{even}(x^2) + xf_{odd}(x^2)$$

现在,我们选择 $x_i = \omega^i$,则我们求 $f(x_i)$ 和 $f(x_{i+\frac{n}{2}})$,可以发现

$$x_{\frac{n}{2}} = \omega^{\frac{n}{2}} = -1, ~~x_{i+\frac{n}{2}} = \omega^{i+\frac{n}{2}} = -\omega^i = -x_i$$

所以

$$\forall i \in [0,\frac{n}{2}-1], ~~f(x_i) = f_{even}(x_i^2) + x_if_{odd}(x_i^2), ~~f(x_{i+\frac{n}{2}}) = f_{even}(x_i^2) - x_if_{odd}(x_i^2)$$

又因为,求 $\forall i \in [0,n], f(x_i)$ 的值,等价于求 $\forall i \in [0,\frac{n}{2}-1], f(x_i) 和 f(x_{i+\frac{n}{2}})$ 的值,

所以这个问题就可以等价转化为求 $f_{even}(x_i^2) + x_if_{odd}(x_i^2)$ 的值。而这就是一个规模减半了的子问题。

所以设 $T(n)$ 为:求出 $\forall i \in [0,n], f(x_i)$ 的值所需的时间,则可以列出方程:

$$T(n) = 2T(\frac{n}{2}) + O(n)$$

所以 $T(n) = O(n\log n)$

第二步 点值乘法

第一步DFT以后,我们有了 $f(x_0), f(x_1), …, f(x_{n-1})$ 与 $g(x_0), g(x_1), …, g(x_{n-1})$ 的值,我们就可以直接将这些值乘起来,得到

$$h(x_0), h(x_1), …, h(x_{n-1})$$

的值。

第三步 IDFT

问题

现在给定一个未知的多项式 $h(x)$,且 $deg(h) = n-1$,已知 $h(x)$ 上的 $n$ 个点 $(x_0,h(x_0)), (x_1,h(x_1)), …, (x_{n-1},h(x_{n-1}))$,求 $h(x)$ 的表达式?

注意到这个问题也等价于一个矩阵乘法,即给定 一个矩阵 $X$ 和一个向量 $\vec y$,已知 $X\vec a = \vec y$,如何用 $O(n \log n)$ 的时间求出 $\vec a$ 的值?

$$\begin{bmatrix} 1 & x_0 & x_0^2 & … & x_0^{n-1} \\
1 & x_1 & x_1^2 & … & x_1^{n-1} \\
…\\
1 & x_{n-1} & x_{n-1}^2 & … & x_{n-1}^{n-1} \\
\end{bmatrix} \begin{bmatrix} a_0 \\
a_1 \\
…\\
a_{n-1} \\
\end{bmatrix} = \begin{bmatrix} y_0 \\
y_1 \\
…\\
y_{n-1} \\
\end{bmatrix} $$

发现了什么?这实际上和第一步 DFT 的过程是一样的,如果我们把 $X^{-1}$ 给两边乘上 ,就变成了

$$\begin{bmatrix} 1 & x_0 & x_0^2 & … & x_0^{n-1} \\
1 & x_1 & x_1^2 & … & x_1^{n-1} \\
…\\
1 & x_{n-1} & x_{n-1}^2 & … & x_{n-1}^{n-1} \\
\end{bmatrix}^{-1} \begin{bmatrix} y_0 \\
y_1 \\
…\\
y_{n-1} \\
\end{bmatrix} = \begin{bmatrix} a_0 \\
a_1 \\
…\\
a_{n-1} \\
\end{bmatrix}$$

那么,令

$$X=\begin{bmatrix}1 & x_0 & x_0^2 & … & x_0^{n-1} \\
1 & x_1 & x_1^2 & … & x_1^{n-1} \\
…\\
1 & x_{n-1} & x_{n-1}^2 & … & x_{n-1}^{n-1} \\
\end{bmatrix}$$

把 $x_i = \omega^i$ 代进去,就可以得到:

$$X = \begin{bmatrix}1 & 1 & 1 & … & 1 \\
1 & \omega^1 & \omega^{1\cdot 2} & … & \omega^{1\cdot (n-1)} \\
…\\
1 & \omega^{1} & \omega^{(n-1)\cdot 2} & … & \omega^{(n-1)\cdot (n-1)} \\
\end{bmatrix}$$

$$V(\omega) = X = \begin{bmatrix}1 & 1 & 1 & … & 1 \\
1 & \omega^1 & \omega^{1\cdot 2} & … & \omega^{1\cdot (n-1)} \\
…\\
1 & \omega^{1} & \omega^{(n-1)\cdot 2} & … & \omega^{(n-1)\cdot (n-1)} \\
\end{bmatrix}, ~~~V(\omega^{-1}) = \begin{bmatrix}1 & 1 & 1 & … & 1 \\
1 & \omega^{-1} & \omega^{-1\cdot 2} & … & \omega^{-1\cdot (n-1)} \\
…\\
1 & \omega^{-1} & \omega^{-(n-1)\cdot 2} & … & \omega^{-(n-1)\cdot (n-1)} \\
\end{bmatrix}$$

则有:

$$V(\omega) \cdot V(\omega^{-1}) = nI$$

证明:直接进行矩阵运算即可,注意要用到 $\omega^n = 1$ 的特性。

这说明

$$X^{-1} = V(w)^{-1} = \frac{1}{n} V(\omega^{-1})$$

所以上面的 $X^{-1}\vec y = \vec a$ 就可以表示为:

$$V(\omega)^{-1} \begin{bmatrix} y_0 \\
y_1 \\
…\\
y_{n-1} \\
\end{bmatrix} = \begin{bmatrix} a_0 \\
a_1 \\
…\\
a_{n-1} \\
\end{bmatrix} ~~\Rightarrow~~ \frac{1}{n}V(\omega^{-1})\begin{bmatrix} y_0 \\
y_1 \\
…\\
y_{n-1} \\
\end{bmatrix} = \begin{bmatrix} a_0 \\
a_1 \\
…\\
a_{n-1} \\
\end{bmatrix}$$

这就等价于我们有一个系数为 $y_0, y_1, … y_{n}$ 的多项式,然后要在 $x_0^{-1}, x_1^{-1}, …, x_n^{-1}$ 这些特殊点上进行求值。

我们发现 $x_0^{-1} = \omega^{-1}$ 仍然是一个 $n$ 次本原单位根 (primitive n-th root of unity)(证明略),所以它仍然有上述优秀的性质。

所以这个问题可以用 DFT 解决,求出来的值,就是我们想要的 $h(x)$ 的多项式系数。

注意一下,上述所有过程都是完美的分治过程,所以我们需要把 $deg(h)$ 补成一个 $2^k$ 形式,其中 $deg(h) = 2^k \geq 2(n_1+n_2)$

非递归 FFT

递归 FFT 常数过大,不好用,所以我们可以写一个非递归版本的。

我们手推一下 FFT 的分治过程,每一次分治都是拆分奇数项和偶数项,模拟拆分过程有:

初始序列:$\{a_0,a_1,a_2,a_3,a_4,a_5,a_6,a_7\}$

第一次拆分:$\{a_0,a_2,a_4,a_6\},\{a_1,a_3,a_5,a_7\}$

第二次拆分:$\{a_0,a_4\},\{a_2,a_6\},\{a_1,a_5\},\{a_3,a_7\}$

第三次拆分:$\{a_0\},\{a_4\},\{a_2\},\{a_6\},\{a_1\},\{a_5\},\{a_3\},\{a_7\}$

发现了什么?

拆分前:0, 1, 2, 3, 4, 5, 6, 7

拆分后:0, 4, 2, 6, 1, 5, 3, 7

化成二进制:

拆分前:000, 001, 010, 011, 100, 101, 110, 111

拆分后:000, 100, 010, 110, 001, 101, 011, 111

机智的你或许发现了,拆分前后,每个数字对应的二进制刚好就是翻转了一下,最左边的bit跑到最右边去了,反之亦然。

那么这个二进制翻转就可以直接 $O(n \log n)$ 实现了,但是这样不够高效,有一个 $O(n)$ 的基于 DP 的方法:

设 $R(x)$ 为 $x$ 翻转后的结果,可知 $R(0) = 0$。

我们从小到大求 $R(x)$,在求 $R(x)$ 时,我们已知了 $R(\frac{x}{2})$ 的值,相当于:

$$x = abc[0/1], ~ \frac{x}{2} = abc, ~ R(\frac{x}{2}) = cba, ~ R(x) = [0/1]cba$$

则我们只要判断一下 $x$ 的最低位为 $0$ 还是 $1$,然后给 $R(\frac{x}{2})$ 的最高位补上一个 $0$ 或者 $1$ 就可以得到 $R(x)$ 了。

void rearrange(Complex a[], const int n) {
    static int rev[maxn];  
    for (int i = 0; i <= n; i++) {
        rev[i] = rev[i >> 1] >> 1;
        if (i & 1) rev[i] |= (n >> 1);
    }
    for (int i = 0; i <= n; i++) {
        if (i < rev[i]) swap(a[i], a[rev[i]]);  // 保证每对数字只翻转一次
    }
}

至于非递归写法的原理,我懒得写了先鸽着吧(

优化(三步变两步)

暂时没遇到需要优化的情况,先鸽着。

FFT 字符串通配符匹配

题意

给定一个长度为 $m$ 的模式串 $A$ 和一个长度为 $n$ 的文本串 $B$,字符串中有通配符,通配符可以匹配任何的字符。

问模式串在文本串中出现的所有位置?

其中,$n, m \leq 3 \times 10^5$。

如果定义字符串的匹配函数 $C(x,y) = A_x - B_y$,所以如果 $C(x,y) = 0$ 那么 $A$ 的第 $x$ 个字符和 $B$ 的第 $y$ 个字符匹配,如果我们定义完全匹配函数 $P(x) = \sum\limits_{i=0}^{m-1} C(i, x+i)$,那么若 $P(x) = 0$ 那么意味着 $A$ 出现在了 $B$ 的位置 $x$ 处。

但是,这样的话 "ab""ba" 会匹配在一起。所以我们重新定义匹配函数为 $C(x,y) = (A_x - B_y)^2$,保证每一项均 $\geq 0$,那么展开后:

$$P(x) = \sum\limits_{i=0}^{m-1} C(i, x+i) = \sum\limits_{i=0}^{m-1} (A_i - B_{x+i})^2$$

$$= \sum\limits_{i=0}^{m-1} A_i^2 + \sum\limits_{i=0}^{m-1} B_{x+i}^2 - 2 \sum\limits_{i=0}^{m-1} A_i B_{x+i}$$

注意第一项是个常数,第二项可以前缀和预处理,第三项呢?

FFT经典套路:翻转多项式,转成卷积。设 $A_i’ = A_{m-1-i}$,那么有第三项等于

$$\sum\limits_{i=0}^{m-1} A_i B_{x+i} = \sum\limits_{i=0}^{m-1} A_{m-1-i}’ B_{x+i}$$

$$ = A_{m-1}’ B_x + A_{m-2}’ B_{x+1} + … + A_0’ B_{x+m-1}$$

注意,下标之和等于 $x+m-1$。虽然这个不是严格意义上的卷积,因为按理说还要加上 $A_{x+m-1}’ B_0 + A_{x+m-2}’ B_1 + … + A_m’ B_{x-1}$ 才对,但是由于 $A'$ 的长度只有 $m-1$,这些多出来的项就设为 $0$ 即可,所以这个式子的结果是等于

$$\sum\limits_{i+j=x+m-1} A_i’ B_j$$

的。也就是多项式 $A’ * B$ 的第 $(x+m-1)$ 项。

• 这样的话可以在 $O(n\log n)$ 的时间内求出 $A$ 在 $B$ 中出现的所有位置,但并没有处理通配符问题。

现在考虑有通配符的情况,我们修改匹配函数:$C(x,y) = 0$ 既然代表 $A$ 的第 $x$ 个字符和 $B$ 的第 $y$ 个字符匹配,那么匹配时有三种情况:

  1. $A_x$ 是通配符。
  2. $B_y$ 是通配符。
  3. $A_x = B_y$。

只要三个条件满足其一即可。所以不妨令通配符等于 $0$,也就是 $A_x = 0 \iff $ A[x] = '*'

然后用乘法来表示 “OR” 的逻辑,重新定义匹配函数:

$$C(x,y) = A_x B_y (A_x-B_y)^2$$

这样完全匹配函数就等于

$$P(x) = \sum\limits_{i=0}^{m-1} C(i, x+i) = \sum\limits_{i=0}^{m-1} A_i B_{x+i} (A_i-B_{x+i})^2$$

$$= \sum\limits_{i=0}^{m-1} (A_i^3 B_{x+i} - 2 A_i^2 B_{x+i}^2 + A_i B_{x+i}^3)$$

设 $A_i’ = A_{m-1-i}$,有:

$$P(x)= \sum\limits_{i=0}^{m-1} (A_{m-1-i}’ ^3 B_{x+i} - 2 A_{m-1-i}’ ^2 B_{x+i}^2 + A_{m-1-i}’ B_{x+i}^3)$$

• 对于 $A_{m-1-i}’ ^3 B_{x+i}$ 这种类型的怎么计算呢?只要多创建一个多项式 $C$,其中 $C_i = A_i’ ^3$ 即可。

所以总计创建 $6$ 个多项式,乘起来然后相加后得到多项式 $F$,就可以得到:

$A$ 出现在 $B$ 的位置 $x$ 处 $\iff P(x) = 0 \iff F_{x+m-1} = 0$

例题

例1 洛谷P1919 【模板】A*B Problem升级版(FFT快速傅里叶)

题意

给定正整数 $a,b$,求 $a*b$?

其中,$1\leq a,b \leq 10^{10^6}$

题解

把大整数看成多项式,例如 $356$ 就看作 $f(x) = 3x^2 + 5x + 6$

两个大整数相乘就看作多项式相乘,最后代入 $x = 10$,处理一下进位和前缀零即可。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = (1<<21) + 5;

struct Complex {
    long double x,y;
    Complex(long double _x = 0.0, long double _y = 0.0) {
        x = _x; y = _y;
    }
    Complex operator+(const Complex& b) const {
        return Complex(x + b.x, y + b.y);
    }
    Complex operator-(const Complex& b) const {
        return Complex(x - b.x, y - b.y);
    }
    Complex operator*(const Complex &b) const {
        return Complex(x * b.x - y * b.y, x * b.y + y * b.x);
    }
    Complex operator/(const long double& b) const {
        return Complex(x/b, y/b);
    }
};

struct FastFourierTransform {
    void rearrange(Complex a[], const int n) {
        static int rev[maxn];  // maxn > deg(h) 且 maxn 为 2的k次方 + 5
        for (int i = 0; i <= n; i++) {
            rev[i] = rev[i >> 1] >> 1;
            if (i & 1) rev[i] |= (n >> 1);
        }
        for (int i = 0; i <= n; i++) {
            if (i < rev[i]) swap(a[i], a[rev[i]]);  // 保证每对数字只翻转一次
        }
    }

    void fft(Complex a[], const int n, int on) {
        rearrange(a, n);
        for (int k = 2; k <= n; k <<= 1) {   // 模拟分治的合并过程
            Complex wn(cos(2*pi / k), on * sin(2*pi / k));
            for (int i = 0; i < n; i += k) {
                Complex w(1, 0);
                for (int j = i; j < i + (k>>1); j++) {
                    Complex x = a[j], y = w * a[j+(k>>1)];
                    a[j] = x + y;
                    a[j+(k>>1)] = x - y;
                    w = w * wn;
                }
            }
        }
        if (on == -1) {
            for (int i = 0; i < n; i++) a[i] = a[i] / n;
        }
    }
} fft;


// calculate h(x) = f(x) * g(x), deg(f) = n1, deg(g) = n2
void poly_multiply(int f[], int n1, int g[], int n2, int h[]) {
    int n = 1;
    while (n <= n1 + n2) n <<= 1;  // deg(h) = n1 + n2
    static Complex a[maxn], b[maxn], c[maxn];
    for (int i = 0; i <= n1; i++) a[i] = Complex(f[i], 0);
    for (int i = 0; i <= n2; i++) b[i] = Complex(g[i], 0);
    fft.fft(a, n, 1);  // 注意这里用的是 n (不是 n1)
    fft.fft(b, n, 1);
    for (int i = 0; i <= n; i++) a[i] = a[i] * b[i];
    fft.fft(a, n, -1);
    for (int i = 0; i <= n; i++) h[i] = (int)(a[i].x + 0.5);

}

int n1, n2;
int f[maxn/2], g[maxn/2], h[maxn];

int main() {

    string s1, s2;
    cin >> s1 >> s2;
    int n1 = s1.size() - 1, n2 = s2.size() - 1;
    for (int i = 0; i <= n1; i++) f[i] = s1[n1-i] - '0';
    for (int i = 0; i <= n2; i++) g[i] = s2[n2-i] - '0';
    poly_multiply(f, n1, g, n2, h);

    int i = 0;
    for (i = 0; i < maxn-2; i++) {
        h[i+1] += h[i] / 10;
        h[i] %= 10;
    }
    string ans = "";
    for (int i = 0; i < maxn-2; i++) ans += (char)(h[i] + '0');

    for (int j = ans.size()-1; j >= 0; j--) {
        if (ans[j] == '0') {
            ans.erase(ans.begin() + j);
        } else break;
    }

    reverse(ans.begin(), ans.end());
    cout << ans << endl;
}

例2 洛谷P3338 [ZJOI2014]力

题意

给定 $n$ 个实数 $q_1,q_2,…,q_n$,定义:

$$F_j = \sum\limits_{i=1}^{j-1}\frac{q_i \times q_j}{(i-j)^2} - \sum\limits_{i=j+1}^{n}\frac{q_i \times q_j}{(i-j)^2}$$

$$E_j = \frac{F_j}{q_j}$$

求 $E_1, E_2, …, E_n$ 的值?

其中,$n \leq 10^5$

题解

FFT 可以用于处理卷积问题,我们先化简一下式子:

$$E_j = \frac{F_j}{q_j} = \sum\limits_{i=1}^{j-1}\frac{q_i}{(i-j)^2} - \sum\limits_{i=j+1}^{n}\frac{q_i}{(i-j)^2}$$

$$f_i = q_i, ~~g_i = \frac{1}{i^2}$$

$$E_j = \sum\limits_{i=1}^{j-1}f_ig_{j-i} - \sum\limits_{i=j+1}^{n}f_ig_{i-j}$$

注意到第一项 $\sum\limits_{i=1}^{j-1}f_ig_{j-i}$ 可以写成 $i=0$ 开始,到 $j$ 结束(令 $f_0 = g_0 = 0$ 即可),所以第一项等于 $\sum\limits_{i=0}^{j}f_ig_{j-i}$,也就是一个卷积形式,可以利用 FFT 计算。


第二项 $\sum\limits_{i=j+1}^{n}f_ig_{i-j}$ 怎么写成卷积呢?

注意到,如果我们令 $f_j’ = f_{n-j}$,则有

$$\sum\limits_{i=j+1}^{n}f_ig_{i-j} = \sum\limits_{i=j}^{n}f’_{n-i}g_{i-j} = g_0f’_{n-j} + g_1f’_{n-j-1} + … + g_{n-j}f’_0$$

那么,这就又是一个卷积了,所以最后我们的结果可以表示为:

$$E_j = \sum\limits_{i=0}^{j}f_ig_{j-i} - \sum\limits_{i=j}^{n}f’_{n-i}g_{i-j}$$


第一项 $\sum\limits_{i=0}^{j}f_ig_{j-i}$:计算多项式 $f(x) * g(x)$,取第 $j$ 项的系数即可。

第二项 $\sum\limits_{i=j}^{n}f’_{n-i}g_{i-j}$:计算多项式 $f’(x) * g(x)$,取第 $n-j$ 项的系数即可。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = (1<<21) + 5;

struct Complex {
    double x,y;
    Complex(double _x = 0.0, double _y = 0.0) {
        x = _x; y = _y;
    }
    Complex operator+(const Complex& b) const {
        return Complex(x + b.x, y + b.y);
    }
    Complex operator-(const Complex& b) const {
        return Complex(x - b.x, y - b.y);
    }
    Complex operator*(const Complex &b) const {
        return Complex(x * b.x - y * b.y, x * b.y + y * b.x);
    }
    Complex operator/(const long double& b) const {
        return Complex(x/b, y/b);
    }
};

struct FastFourierTransform {
    void rearrange(Complex a[], const int n) {
        static int rev[maxn];  // maxn > deg(h) 且 maxn 为 2的k次方 + 5
        for (int i = 0; i <= n; i++) {
            rev[i] = rev[i >> 1] >> 1;
            if (i & 1) rev[i] |= (n >> 1);
        }
        for (int i = 0; i <= n; i++) {
            if (i < rev[i]) swap(a[i], a[rev[i]]);  // 保证每对数字只翻转一次
        }
    }

    void fft(Complex a[], const int n, int on) {
        rearrange(a, n);
        for (int k = 2; k <= n; k <<= 1) {   // 模拟分治的合并过程
            Complex wn(cos(2*pi / k), on * sin(2*pi / k));
            for (int i = 0; i < n; i += k) {
                Complex w(1, 0);
                for (int j = i; j < i + (k>>1); j++) {
                    Complex x = a[j], y = w * a[j+(k>>1)];
                    a[j] = x + y;
                    a[j+(k>>1)] = x - y;
                    w = w * wn;
                }
            }
        }
        if (on == -1) {
            for (int i = 0; i < n; i++) a[i] = a[i] / n;
        }
    }
} fft;

// calculate h(x) = f(x) * g(x), deg(f) = n1, deg(g) = n2
void poly_multiply(const double f[], int n1, const double g[], int n2, double h[]) {
    int n = 1;
    while (n <= n1 + n2) n <<= 1;  // deg(h) = n1 + n2
    static Complex a[maxn], b[maxn];
    memset(a, 0, sizeof(a));
    memset(b, 0, sizeof(b));
    for (int i = 0; i <= n1; i++) a[i] = Complex(f[i], 0);
    for (int i = 0; i <= n2; i++) b[i] = Complex(g[i], 0);
    fft.fft(a, n, 1);  // 注意这里用的是 n (不是 n1)
    fft.fft(b, n, 1);
    for (int i = 0; i <= n; i++) a[i] = a[i] * b[i];
    fft.fft(a, n, -1);
    for (int i = 0; i <= n; i++) h[i] = a[i].x;
}

double f[maxn/2], g[maxn/2], h[maxn];
double f2[maxn], h2[maxn];

int main() {
    int n; cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> f[i];
        g[i] = (1.0/i/i);
    }

    poly_multiply(f, n, g, n, h);
    for (int i = 0; i <= n; i++) f2[i] = f[n-i];  // f'[i] = f[n-i]
    poly_multiply(f2, n, g, n, h2);

    for (int i = 1; i <= n; i++) {
        printf("%.3f\n", h[i] - h2[n-i]);
    }
}

例3 洛谷P3723 [AH2017/HNOI2017]礼物

题意

给定两个长度为 $n$ 的正整数序列 $x,y$,我们进行一次以下操作:

第一步:给其中一个序列的所有数加上一个非负整数 $c$。

第二步:旋转一个序列,可以旋转任意长度。

序列的旋转意思是全部元素左移,或者右移,超出界限的部分从另外一边补齐。

例如: $[1,2,3,4,5]$ 向右旋转 $2$ 个长度,得到 $[4,5,1,2,3]$。

现在我们想知道,如何选择 $c$ 和旋转序列的方式,使得

$$\sum\limits_{i=1}^n(x_i-y_i)^2$$

最小?

其中,$1\leq n \leq 50000, 1\leq x_i, y_i \leq 100$

题解

首先由于第一步中,我们可以选择任意一个序列来加上 $c$,所以可以看作选择一个序列,然后加上一个 任意整数 $c$

最后我们设旋转后,对齐的两个序列为 $a,b$。

$$\sum\limits_{i=1}^n(x_i-y_i)^2$$

可以表示为:

$$\sum\limits_{i=1}^n(a_i+c-b_i)^2$$

$$=\sum\limits_{i=1}^n ((a_i^2+b_i^2+c^2)+2c(a_i-b_i) - 2a_ib_i)$$

$$=nc^2 + \sum\limits_{i=1}^n (a_i^2+b_i^2)+2c\sum\limits_{i=1}^n(a_i-b_i) - \sum\limits_{i=1}^n 2a_ib_i$$

如果我们暂时不考虑 $c$,则会发现,无论最终旋转的 $a,b$ 是什么样的,前三项均为定值,只有 $\sum\limits_{i=1}^n 2a_ib_i$ 的值是个变量。

所以我们只需要 最大化 $\sum\limits_{i=1}^n a_ib_i$ 的值即可。


现在问题转化为,如何求一种旋转方式,使得 $\sum\limits_{i=1}^n a_ib_i$ 的值最大?

我们不妨直接把每一种旋转方式的值都求出来,然后取最大即可。

怎么全部一次性求出呢?卷积

$$E_j = \sum\limits_{i=j}^na_ib_{i-j+1} + \sum\limits_{i=1}^{j-1}a_ib_{n+1+i-j}$$

则 $E_j$ 所代表的就是,如果我们如此排列:

$a_j$ $a_{j+1}$ $a_n$ $a_1$ $a_2$ $a_{j-1}$
$b_1$ $b_2$ $b_{n+1-j}$ $b_{n+2-j}$ $b_{n+3-j}$ $b_{n}$

最后得出来对应位置的乘积的和,就是 $E_j$。

所以我们只要求出 $E_1, E_2, …, E_n$,然后取最小值即可。

如何求呢?


先考虑第一项 $\sum\limits_{i=j}^na_ib_{i-j+1}$

把它转成卷积的形式,令 $c_i = a_{n-i}$,则有:

$$\sum\limits_{i=j}^na_ib_{i-j+1} = \sum\limits_{i=j}^nc_{n-i}b_{i-j+1} = b_1c_{n-j} + b_2c_{n-j-1} + … + b_{n-j+1}c_{0}$$

令 $a_0 = b_0 = 0$,就可以再给上式加一个 $b_0c_{n-j+1}$,得到:

$$b_0c_{n-j+1} + b_1c_{n-j} + b_2c_{n-j-1} + … + b_{n-j+1}c_{0}$$

那么这就是一个卷积了,相当于 $b(x) * c(x)$ 的第 $(n-j+1)$ 项。


再考虑第二项 $\sum\limits_{i=1}^{j-1}a_ib_{n+1+i-j}$

仍然转成卷积形式,令 $d_i = b_{n-i}$,则有:

$$\sum\limits_{i=1}^{j-1}a_ib_{n+1+i-j} = \sum\limits_{i=1}^{j-1}a_id_{j-i-1} = a_1d_{j-2} + a_2d_{j-3} + … + a_{j-1}d_{0}$$

由于 $a_0 = 0$,给上式加上 $a_0d_{j-1}$,得到:

$$a_0d_{j-1} + a_1d_{j-2} + a_2d_{j-3} + … + a_{j-1}d_{0}$$

这又是一个卷积,相当于 $a(x) * d(x)$ 的第 $(j-1)$ 项。


所以 $\sum\limits_{i=1}^n a_ib_i$ 的最大值求出来了,剩下的变量只有 $c$ 了。

注意到 $1 \leq x_i,y_i \leq 100$,所以 $c$ 只要取遍所有的可能值,然后判断最小值即可。

在代码中,$c \in [-200, 200]$。

代码
#include <bits/stdc++.h>
using namespace std;

const int mod = 998244353;
const int maxn = (1<<20) + 5;

struct NTT {
    const ll g = 3, invg = inv(g);  // mod = 998244353
    inline ll qpow(ll a, ll b) {
        ll res = 1;
        while (b) {
            if (b & 1) res = res * a % mod;
            a = a * a % mod;
            b >>= 1;
        }
        return res;
    }
    inline ll inv(ll a) {
        return qpow(a, mod-2);
    }

    void rearrange(ll a[], const int n) {
        static int rev[maxn];  // maxn > deg(h) 且 maxn 为 2的k次方 + 5
        for (int i = 0; i <= n; i++) {
            rev[i] = rev[i >> 1] >> 1;
            if (i & 1) rev[i] |= (n >> 1);
        }
        for (int i = 0; i <= n; i++) {
            if (i < rev[i]) swap(a[i], a[rev[i]]);  // 保证每对数字只翻转一次
        }
    }

    void ntt(ll a[], const int n, int on) {
        rearrange(a, n);
        for (int k = 2; k <= n; k <<= 1) {   // 模拟分治的合并过程
            ll wn = qpow(on == 1 ? g : invg, (mod-1)/k);
            for (int i = 0; i < n; i += k) {
                ll w = 1;
                for (int j = i; j < i + (k>>1); j++) {
                    ll x = a[j], y = w * a[j+(k>>1)] % mod;
                    a[j] = (x + y) % mod;
                    a[j+(k>>1)] = (x - y + mod) % mod;
                    w = w * wn % mod;
                }
            }
        }
        if (on == -1) {
            ll invn = inv(n);
            for (int i = 0; i < n; i++) a[i] = a[i] * invn % mod;
        }
    }
} ntt;

// calculate h(x) = f(x) * g(x), deg(f) = n1, deg(g) = n2
void poly_multiply(ll f[], int n1, ll g[], int n2, ll h[]) {
    int n = 1;
    while (n <= n1 + n2) n <<= 1;  // deg(h) = n1 + n2
    memset(h, 0, sizeof(ll) * n);
    ntt.ntt(f, n, 1);  // 注意这里用的是 n (不是 n1)
    ntt.ntt(g, n, 1);
    for (int i = 0; i <= n; i++) h[i] = f[i] * g[i] % mod;
    ntt.ntt(h, n, -1);
}

int n1, n2;
ll a[maxn/2], b[maxn/2], c[maxn/2], d[maxn/2], r1[maxn], r2[maxn];

int n,m;
ll ans = 0;

int main() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> a[i], c[n-i] = a[i];
    for (int i = 1; i <= n; i++) cin >> b[i], d[n-i] = b[i];
    ll absum = 0;  // a_i - b_i
    for (int i = 1; i <= n; i++) absum += a[i] - b[i];
    ll res = 1e18;
    for (ll c = -200; c <= 200; c++) {
        res = min(res, n * c * c + 2LL * c * absum);
    }

    ans += res;

    for (int i = 1; i <= n; i++) {
        ans += (a[i] * a[i]) + (b[i] * b[i]);
    }

    poly_multiply(a, n, d, n, r1);  // r1 = a*d
    poly_multiply(b, n, c, n, r2);  // r2 = b*c

    ll ab = 0;
    for (int i = 1; i <= n; i++) {
        ab = max(ab, (r1[i-1] + r2[n-i+1]));
    }

    ans -= 2LL * ab;

    cout << ans << endl;
}

例4 洛谷P4173 残缺的字符串

题意

给定一个长度为 $m$ 的模式串 $A$ 和一个长度为 $n$ 的文本串 $B$,字符串中有通配符,通配符可以匹配任何的字符。

问模式串在文本串中出现的所有位置?

其中,$n, m \leq 3 \times 10^5$。

代码
#include <bits/stdc++.h>
using namespace std;
std::vector<int> rev;
std::vector<Z> roots{0, 1};
void dft(std::vector<Z> &a) {
    int n = a.size();
    
    if (int(rev.size()) != n) {
        int k = __builtin_ctz(n) - 1;
        rev.resize(n);
        for (int i = 0; i < n; i++) {
            rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
        }
    }
    
    for (int i = 0; i < n; i++) {
        if (rev[i] < i) {
            std::swap(a[i], a[rev[i]]);
        }
    }
    if (int(roots.size()) < n) {
        int k = __builtin_ctz(roots.size());
        roots.resize(n);
        while ((1 << k) < n) {
            Z e = qpow(Z(3), (mod - 1) >> (k + 1));
            for (int i = 1 << (k - 1); i < (1 << k); i++) {
                roots[2 * i] = roots[i];
                roots[2 * i + 1] = roots[i] * e;
            }
            k++;
        }
    }
    for (int k = 1; k < n; k *= 2) {
        for (int i = 0; i < n; i += 2 * k) {
            for (int j = 0; j < k; j++) {
                Z u = a[i + j];
                Z v = a[i + j + k] * roots[k + j];
                a[i + j] = u + v;
                a[i + j + k] = u - v;
            }
        }
    }
}
void idft(std::vector<Z> &a) {
    int n = a.size();
    std::reverse(a.begin() + 1, a.end());
    dft(a);
    Z inv = (1 - mod) / n;
    for (int i = 0; i < n; i++) {
        a[i] *= inv;
    }
}
struct Poly {
    std::vector<Z> a;
    Poly() {}
    Poly(const int n) { resize(n); }
    Poly(const std::vector<Z> &a) : a(a) {}
    Poly(const std::initializer_list<Z> &a) : a(a) {}
    int size() const {
        return a.size();
    }
    void resize(int n) {
        a.resize(n);
    }
    Z operator[](int idx) const {
        if (idx < size()) {
            return a[idx];
        } else {
            return 0;
        }
    }
    Z &operator[](int idx) {
        return a[idx];
    }
    Poly mulxk(int k) const {
        auto b = a;
        b.insert(b.begin(), k, 0);
        return Poly(b);
    }
    Poly modxk(int k) const {
        k = std::min(k, size());
        return Poly(std::vector<Z>(a.begin(), a.begin() + k));
    }
    Poly divxk(int k) const {
        if (size() <= k) {
            return Poly();
        }
        return Poly(std::vector<Z>(a.begin() + k, a.end()));
    }
    friend Poly operator+(const Poly &a, const Poly &b) {
        std::vector<Z> res(std::max(a.size(), b.size()));
        for (int i = 0; i < int(res.size()); i++) {
            res[i] = a[i] + b[i];
        }
        return Poly(res);
    }
    friend Poly operator-(const Poly &a, const Poly &b) {
        std::vector<Z> res(std::max(a.size(), b.size()));
        for (int i = 0; i < int(res.size()); i++) {
            res[i] = a[i] - b[i];
        }
        return Poly(res);
    }
    friend Poly operator*(Poly a, Poly b) {
        if (a.size() == 0 || b.size() == 0) {
            return Poly();
        }
        int sz = 1, tot = a.size() + b.size() - 1;
        while (sz < tot) {
            sz *= 2;
        }
        a.a.resize(sz);
        b.a.resize(sz);
        dft(a.a);
        dft(b.a);
        for (int i = 0; i < sz; ++i) {
            a.a[i] = a[i] * b[i];
        }
        idft(a.a);
        a.resize(tot);
        return a;
    }
    friend Poly operator*(Z a, Poly b) {
        for (int i = 0; i < int(b.size()); i++) {
            b[i] *= a;
        }
        return b;
    }
    friend Poly operator*(Poly a, Z b) {
        for (int i = 0; i < int(a.size()); i++) {
            a[i] *= b;
        }
        return a;
    }
    friend Poly operator*(Poly a, int b) {
        return operator*(a, Z(b));
    }
    Poly &operator+=(Poly b) {
        return (*this) = (*this) + b;
    }
    Poly &operator-=(Poly b) {
        return (*this) = (*this) - b;
    }
    Poly &operator*=(Poly b) {
        return (*this) = (*this) * b;
    }
    Poly deriv() const {
        if (a.empty()) {
            return Poly();
        }
        std::vector<Z> res(size() - 1);
        for (int i = 0; i < size() - 1; ++i) {
            res[i] = (i + 1) * a[i + 1];
        }
        return Poly(res);
    }
    Poly integr() const {
        std::vector<Z> res(size() + 1);
        for (int i = 0; i < size(); ++i) {
            res[i + 1] = a[i] / (i + 1);
        }
        return Poly(res);
    }
    Poly inv(int m) const {
        Poly x{a[0].inv()};
        int k = 1;
        while (k < m) {
            k *= 2;
            x = (x * (Poly{2} - modxk(k) * x)).modxk(k);
        }
        return x.modxk(m);
    }
    Poly log(int m) const {
        return (deriv() * inv(m)).integr().modxk(m);
    }
    Poly exp(int m) const {
        Poly x{1};
        int k = 1;
        while (k < m) {
            k *= 2;
            x = (x * (Poly{1} - x.log(k) + modxk(k))).modxk(k);
        }
        return x.modxk(m);
    }
    Poly pow(int k, int m) const {
        int i = 0;
        while (i < size() && a[i].val() == 0) {
            i++;
        }
        if (i == size() || 1LL * i * k >= m) {
            return Poly(std::vector<Z>(m));
        }
        Z v = a[i];
        auto f = divxk(i) * v.inv();
        return (f.log(m - i * k) * k).exp(m - i * k).mulxk(i * k) * qpow(v, k);
    }
    Poly sqrt(int m) const {
        Poly x{1};
        int k = 1;
        while (k < m) {
            k *= 2;
            x = (x + (modxk(k) * x.inv(k)).modxk(k)) * ((mod + 1) / 2);
        }
        return x.modxk(m);
    }
    Poly mulT(Poly b) const {
        if (b.size() == 0) {
            return Poly();
        }
        int n = b.size();
        std::reverse(b.a.begin(), b.a.end());
        return ((*this) * b).divxk(n - 1);
    }
    std::vector<Z> eval(std::vector<Z> x) const {
        if (size() == 0) {
            return std::vector<Z>(x.size(), 0);
        }
        const int n = std::max(int(x.size()), size());
        std::vector<Poly> q(4 * n);
        std::vector<Z> ans(x.size());
        x.resize(n);
        std::function<void(int, int, int)> build = [&](int p, int l, int r) {
            if (r - l == 1) {
                q[p] = Poly{1, -x[l]};
            } else {
                int m = (l + r) / 2;
                build(2 * p, l, m);
                build(2 * p + 1, m, r);
                q[p] = q[2 * p] * q[2 * p + 1];
            }
        };
        build(1, 0, n);
        std::function<void(int, int, int, const Poly &)> work = [&](int p, int l, int r, const Poly &num) {
            if (r - l == 1) {
                if (l < int(ans.size())) {
                    ans[l] = num[0];
                }
            } else {
                int m = (l + r) / 2;
                work(2 * p, l, m, num.mulT(q[2 * p + 1]).modxk(m - l));
                work(2 * p + 1, m, r, num.mulT(q[2 * p]).modxk(r - m));
            }
        };
        work(1, 0, n, mulT(q[1].inv(n)));
        return ans;
    }
};

int m, n;
string s, t;
int main() {
    fastio;
    cin >> m >> n;
    cin >> s >> t;
    // s 是短串(模式串), t是文本串

    Poly A(m), A2(m), A3(m), B(n), B2(n), B3(n);
    for (int i = 0; i < m; i++) {
        if (s[i] == '*') A[i] = 0;
        else A[i] = s[i] - 'a' + 1;
    }

    reverse(A.a.begin(), A.a.end());
    for (int i = 0; i < m; i++) {
        A2[i] = A[i] * A[i];
        A3[i] = A2[i] * A[i];
    }

    for (int i = 0; i < n; i++) {
        if (t[i] == '*') B[i] = 0;
        else B[i] = t[i] - 'a' + 1;
        B2[i] = B[i] * B[i];
        B3[i] = B2[i] * B[i];
    }
    Poly F = (A3 * B) - (A2 * B2 * 2) + (A * B3);
    int cnt = 0;
    for (int i = 0; i < n-m+1; i++) {
        if (!F[i + m - 1].val()) {
            cnt++;
        }
    }
    cout << cnt << endl;
    for (int i = 0; i < n-m+1; i++) {
        if (!F[i+m-1].val()) {
            cout << i+1 << " ";
        }
    }
    cout << "\n";
}

例5 CF528D Fuzzy Search

题意

给定两个字符串 $s,t$,长度分别为 $n,m$。字符串均只包含 AGCT 四种字符。同时给定一个非负整数 $k$。

定义 $t$ 在 $s$ 的第 $i$ 位出现当前仅当,在将 $t$ 的首个字符与 $s$ 的第 $i$ 位对齐后,$t$ 的第 $j$ 个字符都能在 $s$ 的第 $[i+j-k, i+j+k]$ 中找到一个相同的字符。(这里的字符 index 是从 0 开始的)。

输出 $t$ 在 $s$ 中出现的所有位置。

其中,$m \leq n \leq 2 \times 10^5, k \in [0, 2 \times 10^5]$。

题解

这种字符只有 $4$ 种的还是要注意一下的,虽然不能直接用通配符匹配。不过观察一下可以发现:

不妨将 AGCT 四种情况分开来看,每一种都分开进行匹配,最后合在一起,四种字符都能匹配上的位置才算 $t$ 在 $s$ 中出现了。

那么怎么分开匹配?

假如只考虑字符 A,那么可以将原字符串的所有 A 向左右分别延伸 $k$ 个位置。形成一个新的字符串。

例如 s = "AGTTAG",若 $k=1$ 则有 s = "AA_AAA"(非 A 的直接看作其他字符即可)。

而 t 则同理,只不过这时非 A 的字符看作通配符,那么如 t = "AGTA" 就可以看作 "A**A"

然后用通配符匹配,来判断 $t$ 在 $s$ 的哪些位置能够出现,最后把四种字符的情况全部 AND 起来即可。

• 顺带一提,因为这个题只有 $t$ 中有通配符,实际上匹配函数可以写作 $C(x,y) = B_y(A_x-B_y)^2$。

代码
#include <bits/stdc++.h>
using namespace std;

int n, m, k;
string s, t;
bool bad[maxn];
 
int main() {
    cin >> n >> m >> k;
    cin >> s >> t;
 
    for (char cur : "AGCT") {
        vector<int> a(n+1, 0);
        for (int i = 0; i < n; i++) {
            if (s[i] == cur) {
                // [i-k, i+k]
                a[max(0, i-k)]++, a[min(n, i+k+1)]--;
            }
        }
        for (int i = 1; i < n; i++) a[i] += a[i-1];
 
        Poly S(n), S2(n), T(m), T2(m);
        Z sum = 0;
        for (int i = 0; i < m; i++) {
            if (t[i] == cur) T[m-1-i] = 1;
            T2[m-1-i] = T[m-1-i] * T[m-1-i];
            sum += T[m-1-i] * T2[m-1-i];
        }
 
        for (int i = 0; i < n; i++) {
            if (a[i]) S[i] = 1;
            else S[i] = 2;
            S2[i] = S[i] * S[i];
        }
 
        Poly F = (S2 * T) - (S * T2 * 2);
        for (int i = 0; i < n-m+1; i++) {
            if ((F[i+m-1] + sum).val()) {
                bad[i] = 1;
            }
        }
    }
 
    int ans = 0;
    for (int i = 0; i < n-m+1; i++) {
        ans += (!bad[i]);
    }
    cout << ans << endl;
}

例5 ICPC WF2021 G. Mosaic Browsing

题意

给定一个 $r_p \times c_p$ 的模式矩阵 $T$,和一个 $r_q \times c_q$ 的文本矩阵 $S$。

模式矩阵里每个数在 $[0,100]$ 之间,文本矩阵每个数在 $[1,100]$ 之间。

现在问模式矩阵在文本矩阵的哪些位置出现过,输出所有的位置。

一个模式矩阵在文本矩阵的 $(x,y)$ 出现过,当且仅当把模式矩阵的左上角放在 $(x,y)$ 处时,整个模式矩阵可以完全匹配文本矩阵对应的子矩阵。

• 模式矩阵中的 $0$ 可以匹配任何数字,相当于通配符。

其中,$r_p,c_p,r_q,c_q \in [1,1000]$。

题解

一开始想着是分行进行匹配,但这样的话复杂度不对。

于是考虑到把矩阵拍平成数组进行匹配,直接开始列式子:

由于只有 $T$ 中具有通配符,那么我们定义二维的匹配函数:

$$C(x_1,y_1,x_2,y_2) = (S_{x_1,y_1} - T_{x_2,y_2})^2 T_{x_2,y_2} = S_{x_1,y_1}^2 T_{x_2,y_2} - 2S_{x_1,y_1} T_{x_2,y_2}^2 + T_{x_2,y_2}^3$$

有 $C(x_1,y_1,x_2,y_2) = 0 \iff S$ 的 $(x_1,y_1)$ 与 $T$ 的 $(x_2,y_2)$ 匹配上。

那么完全匹配函数有

$$P(x,y) = \sum\limits_{i=0}^{r_p-1} \sum\limits_{j=0}^{c_p-1} C(x+i,y+j,i,j)$$

$$= \sum\limits_{i=0}^{r_p-1} \sum\limits_{j=0}^{c_p-1} (S_{x+i,y+j}^2 T_{i,j} - 2 S_{x+i,y+j} T_{i,j}^2 + T_{i,j}^3)$$

显然第三项是个定值,对于前两项,可以发现如果把 $x,y$ 坐标都翻转一下,分别加起来还是个定值。

为了方便,我们把 $T$ 扩展到和 $S$ 一样的大小,然后翻转一下,令:

$$T_{i,j}’ = T_{r_q-1-i, c_q-1-j}$$

然后就有:

$$P(x,y) = \sum\limits_{i=0}^{r_p-1} \sum\limits_{j=0}^{c_p-1} (S_{x+i,y+j}^2 T_{r_q-1-i,c_q-1-j}’ - 2 S_{x+i,y+j} T_{r_q-1-i,c_q-1-j}‘^2 + T_{i,j}^3)$$

可以发现 $x$ 坐标之和为 $x+i+r_q-1-i = x+r_q-1$,而 $y$ 坐标之和为 $y+j+c_q-1-j = y+c_q-1$。

然后我们把 $S,T$ 都拍成一个数组,使得:

$$S_{x,y} \rightarrow S_{x*c_q + y}$$

那么按照这个方法拍平了以后,求出 $S * T'$ 的第 $((x+r_q-1) * c_q + y+c_q-1)$ 项即可。

代码
#include <bits/stdc++.h>
using namespace std;
int rp, cp, rq, cq;
int tmp[maxn][maxn];
Z tmppoly[maxn*maxn];
int main() {
    cin >> rp >> cp;
    Z sum = 0;
    for (int i = 0; i < rp; i++) {
        for (int j = 0; j < cp; j++) {
            int x; cin >> x;
            tmp[i][j] = x;
            sum += (x * x * x);
        }
    }
    cin >> rq >> cq;

    Poly S((int)(rq*cq)), T((int)(rq*cq));
    Poly T2 = T, S2 = S;
    for (int i = 0; i < rq; i++) {
        for (int j = 0; j < cq; j++) {
            tmppoly[i*cq + j] = tmp[i][j];
            int x; cin >> x;
            S[i*cq + j] = x;
            S2[i*cq + j] = x * x;
        }
    }

    for (int i = 0; i < rq; i++) {
        for (int j = 0; j < cq; j++) {
            int idx1 = i * cq + j, idx2 = (rq - 1 - i) * cq + cq - 1 - j;
            T[idx1] = tmppoly[idx2];
            T2[idx1] = T[idx1] * T[idx1];
        }
    }

    if (rp > rq || cp > cq) {
        cout << 0 << endl;
        return 0;
    }

    Poly F = (S2 * T) - (S * T2 * 2);
    vector<pii> vec;
    for (int i = 0; i < rq - rp + 1; i++) {
        for (int j = 0; j < cq - cp + 1; j++) {
            if (!(F[(rq - 1 + i) * cq + j + cq - 1] + sum).val()) vec.push_back({i+1, j+1});
        }
    }
    cout << vec.size() << "\n";
    for (auto [x, y] : vec) cout << x << " " << y << "\n";
}

例6 ABC291G. OR Sum

题意

给定两个长度为 $n$ 的数组 $A, B$。

我们可以将 $A$ 整体右移几位(最右边的换到左边去)。

求出所有移位方案中,$\sum\limits_{i=0}^{n-1} (A_i|B_i)$ 的最大值。

其中,$n \leq 5 \times 10^5, A_i, B_i \in [0,31]$。

题解

$[0,31]$ 太过明显了,拆位处理。

然后发现 $\sum\limits_{i=0}^{n-1} (A_i|B_i)$,这不是个经典翻转后卷积吗?

于是把 $B$ 翻转一下,得到 $\sum\limits_{i=0}^{n-1} (A_i|B_{n-1-i}')$。

不过这是 OR 操作,不是乘法。

但是注意到,拆位后只有 0 和 1,而对于只有 $0$ 和 $1$ 的情况下,有:

$$a|b = a+b - ab$$

其中 $a, b \in \{0,1\}$。

于是

$$\sum\limits_{i=0}^{n-1} (A_i|B_{n-1-i}') = sum(A) + sum(B’) - \sum\limits_{i=0}^{n-1} (A_iB_{n-1-i}')$$

FFT 即可。

代码
#include <bits/stdc++.h>
using namespace std;
int n, a[maxn], b[maxn];
int main() {
    cin >> n;
    Poly A(n*2), B(n*2);
    ll ans = 0;
    for (int i = 0; i < n; i++) {
        int x; cin >> x; 
        a[i] = a[i+n] = x;
        ans += x;
    }
    for (int i = 0; i < n; i++) {
        cin >> b[n-1-i];
        ans += b[n-1-i];
    }
    vector<ll> res(2*n, 0);
    for (int j = 0; j < 5; j++) {
        for (int i = 0; i < 2*n; i++) {
            if (a[i] & (1<<j)) A[i] = 1;
            else A[i] = 0;

            if (b[i] & (1<<j)) B[i] = 1;
            else B[i] = 0;
        }
        Poly F = A * B;
        for (int i = n-1; i <= 2*n-1; i++) {
            res[i] += F[i].val() * (1<<j);
        }
    }
    ll mn = 1e18;
    for (int i = n-1; i <= 2*n-1; i++) {
        mn = min(mn, res[i]);
    }
    cout << ans - mn << endl;
}

参考资料

  1. HKU COMP3250 FFT课件
  2. https://oi-wiki.org/math/poly/fft/
  3. https://oi.men.ci/fft-notes/
  4. https://zhuanlan.zhihu.com/p/128661674