四年前写过一篇介绍 FFT 的博客,但其中漏洞百出,应该坑了不少人,故现在写一份新的,希望对算法竞赛选手有帮助。我们先回顾复数的一些知识,再介绍多项式的点值表达,最后给出 FFT 的详细过程和代码实现。

复数的概念

  我们每个人都很熟悉一元二次方程 $ax^2+bx+c = 0$. 它有个判别式 $\Delta = b ^ 2 - 4ac$,如果 $\Delta \geq 0$,方程就有两个互不相等的实根;如果 $\Delta = 0$,则方程有一个实根(或者说,两个相等的实根);如果 $\Delta < 0$,则方程没有实根。经过一番推导,我们可以给出方程的解:$$x _ 1 = \frac{-b + \sqrt{\Delta}}{2a}, \quad x _ 2 = \frac{-b - \sqrt{\Delta}}{2a}$$

  显然,如果 $\Delta < 0$,上面的 $x _ 1, x_2$ 都是无意义的。例如 $x^2 + x + 1$ 的判别式 $\Delta = -3$,于是求不出 $x _ 1, x_ 2$.

  我们把 $ x _ 1, x _ 2$ 相加、相乘,于是可以得到韦达定理:$$x _ 1 + x _ 2 = -\frac{b}{a}, \quad x _1 \cdot x _ 2 = \frac{c}{a}$$

  韦达定理的式子里面并没有根号!拿着韦达定理去套 $x ^ 2 + x + 1$,我们会获得 $$x _ 1 + x_ 2 = -1,\quad  x _1 \cdot x _ 2 = 1$$

  但 $x _ 1, x _ 2$ 在实数域上是不存在的,它们怎么会有实数的和、积呢?现在有两条路:要么限定韦达定理不准用于 $\Delta > 0$ 的情况,要么承认 $x_ 1, x _ 2$ 都是存在的,只是不存在于实数域上。

  前者是初中生干的事,我们选择后一种。引入一个数 $i$,它满足性质 $ i ^ 2 = -1$,称之为“虚数单位”。相应地,借助虚数单位,立刻可以把实数域 $R$ 扩展到复数域 $C$:$$C: \left\{ a + b\cdot i ~ | ~ a, b \in R\right\}$$

  那么,对于 $x ^ 2 + x + 1$,就有了两个复根。把实部作为横轴坐标,虚部作为纵轴坐标,如图:

▲ 方程 $x^2 + x +1 = 0$ 的根

复数的运算

  复数相对于实数,无非引入了虚数单位;其余的运算法则没有什么影响。显然

  • $ (a + bi) + (c + di) = (a+c) + (b+d)i$
  • $ (a+bi)(c+di) = (ac - bd) + (ad+bc)i$

  另外,如果把坐标系改为以 $x$ 轴正方向为极轴的极坐标,那么复数有另一种表达方式——模长、辐角的二元组。模长就是极坐标系上的极径,辐角就是极角。

  上图展示了两个复数在直角坐标系、极坐标系上的情形。其中,$x_1 = -\frac12 - \frac {\sqrt{3}}{2}i$ 的极坐标形式是 $(1, \frac{2\pi}{3})$;$x _ 2$ 则是 $(1, \frac{4\pi}{3})$.

  经过繁杂的运算,不难验证,复数的乘法也可以表述为“模长相乘、辐角相加”。例如 $$x_1 \cdot x _2 = (1\cdot 1, \frac{2\pi}3 + \frac{4\pi}{3}) = (1, 2\pi) = (1, 0)$$

  也就是 $x_ 1 \cdot  x_2 = 1$,这与韦达定理算出 $x_1 \cdot x _2 = \frac{c}{a}=1$ 是吻合的。

欧拉公式

  以下这个著名的公式,被称为欧拉公式:$$e ^ {i \theta} = \cos \theta + i\sin \theta$$

  它是怎么来的?还记得 $\sin(x), \cos(x)$ 的泰勒展开吗?$$\begin{aligned}\sin(x) &= x - \frac{x ^ 3}{3!} + \frac{x ^ 5}{5!} - \frac{x ^ 7}{7!} + \frac{x ^ 9}{9!} \cdots \\ \cos(x) &= 1 - \frac{x ^ 2}{2!} + \frac{x ^ 4}{4!} - \frac{x ^ 6}{6!} + \frac{x ^ 8}{8!}\cdots\end{aligned}$$

好的,现在写出 $\cos \theta + i\sin \theta$ 的展开式:$$1 + i\theta - \frac{\theta ^ 2}{2!} - i\frac{\theta ^ 3}{3!} + \frac{\theta ^ 4}{4!} + i\frac{\theta ^ 5}{5!} \cdots$$

立刻发现这和 $e ^ x$ 的展开非常相似。我们有$$e ^ x = 1 + x + \frac{x ^ 2}{2!} + \frac{x ^ 3}{3!} + \frac{x ^ 4}{4!} + \frac{x ^ 5}{5!}\cdots$$

现在代入 $x = i\theta$,来看式子变成什么样吧:$$\begin{aligned}e ^ {i\theta} &= 1 + i\theta + \frac{i ^ 2 \theta ^ 2}{2!}+ \frac{i ^ 3 \theta ^ 3}{3!}+ \frac{i ^ 4 \theta ^ 4}{4!}+ \frac{i ^ 5 \theta ^ 5}{5!} \\ &=1 + i\theta - \frac{\theta ^ 2}{2!} - i\frac{\theta ^ 3}{3!} + \frac{\theta ^ 4}{4!} + i\frac{\theta ^ 5}{5!} \cdots \end{aligned}$$

注意到 $e ^ {i\theta}$ 的展开式与 $\cos \theta + i\sin \theta$ 的一模一样,于是我们有了欧拉公式 $$e ^ {i\theta} = \cos\theta + i\sin\theta$$

  进一步地,欧拉公式还能给出一个推论:$e ^ {\pi i} = -1$.

复平面上的单位圆

  直角坐标系上的单位圆,上面的点可以表示成 $(\cos \theta, \sin \theta)$,每一个点与一个角度一一对应。复平面上也有“单位圆”,也就是模长为 1 的全体复数的集,可以记为 $\cos \theta + i\sin\theta$,其中 $\theta$ 是辐角。立刻注意到,这个单位圆上的点 $\cos \theta + i\sin\theta$,也可以写成 $e ^ {i\theta}$. 例如我们上文举过的例子 $x_1, x _2$ 都在单位元上,有$$x _ 1 = -\frac12 - \frac {\sqrt{3}}{2}i = (1, \frac{2\pi}{3}) = e ^ {\frac{2\pi}{3}i}$$

  接下来,我们在 FFT 中,要讨论的求值点都是单位圆上的点。


多项式的点值表达

  一个 $n$ 次多项式,是形如 $A(x) = a_0 + a_1 x + a_2 x ^ 2 +\cdots + a_n x ^n$ 的函数。其中的 $n$ 称为这个多项式的;显然多项式是 $R \to R$ 的一个映射。

  接下来我们要给出一个结论:

给定 $n+1$ 个互不相同的点,可以拟合出一条 $n$ 阶的多项式经过这些点。这个过程称为“插值”。
例如:两个点可以确定一条直线,三个点可以确定一条抛物线。

  最朴素的插值算法(也是最适合手算的方法)是高斯消元。对于所有的点列出方程,最后解这个 $n+1$ 元一次方程组,即可得到这个多项式的全部 $n+1$ 个系数。复杂度是 $O(n ^ 3)$.

  存在效率更高的算法,例如拉格朗日插值法是 $O(n ^ 2)$ 的时间复杂度。原理很简单,本博客不再赘述。现在我们利用拉格朗日插值法找到一条经过 $(1,3), (2,4), (3, -1), (4, 2)$ 的多项式:

R.<x> = RR[x]
f = R.lagrange_polynomial([(1,3), (2,4), (3,-1), (4,2)])
f # 2.33333333333333*x^3 - 17.0000000000000*x^2 + 35.6666666666667*x - 18.0000000000000
▲ 用 SageMath 执行拉格朗日插值法
▲ 拉格朗日插值法得到的多项式

多项式乘法与 FFT

  朴素的多项式乘法,需要 $O(n ^ 2)$ 的复杂度。分块乘法有一点提升,大概是 $O(n ^ {1.59})$ 的复杂度。

  现在我们手上有两个 $n$ 阶多项式 $A, B$,该如何快速求出它们的乘积呢?考虑下面的算法:

  • 指定至少 $2n+1$ 个 x 轴坐标点 $x_i$,求出 $A, B$ 在这些点上的值。
  • 记多项式 $C = A \cdot B$,不难发现 $C(x_i) = A(x _ i) \cdot B(x _ i)$.
  • 于是我们把 $C( x _i)$ 的值全都求出来,由这些点插值得到 $C$ 的系数,就是 $A\cdot B$ 的系数了。

  下面给出一个例子。计算 $A = 1+2x+3x^2$ 与 $B=3+x+x ^ 2$ 的乘积:

R.<x> = RR[x]

A = 1+2*x+3*x^2
B = 3+x+x^2

p = [1, 2, 3, 4, 5] # 求值点

Ay = [A(xi) for xi in p]
By = [B(xi) for xi in p]

Cy = [(xi, a*b) for xi, a, b in zip(p, Ay, By)]
Cy
# [(1, 30.0000000000000),
#  (2, 153.000000000000),
#  (3, 510.000000000000),
#  (4, 1311.00000000000),
#  (5, 2838.00000000000)]
C = R.lagrange_polynomial(Cy)
C
# 3*x^4 + 5*x^3 + 12*x^2 + 7*x + 3
▲ 先求值,再点乘,再插值,得到 $C=A*B$

  多项式的单点求值过程是 $O(n)$ 的,采用霍纳法则(或称秦九韶算法)实现。一共要求出 $n+1$ 个点的值,总代价是 $O(n ^ 2)$;点乘过程是 $O(n)$ 的;插值得到 $C$ 的过程是 $O(n^2)$的。总复杂度仍然是 $O(n ^ 2)$,相当于没有改进。

  瓶颈是在求值、插值上。如果我们能提出一个快速的求值、插值算法,就能快速完成多项式乘法。这些 $2n+1$ 个求值点是可以由我们算法决定的。

  快速傅里叶变换(FFT)的原理就是:我们选择一组非常特殊的求值点,然后利用各种性质简化运算,最终在 $O(n \log n)$ 的时间复杂度内完成求值和插值。

单位根

  回到那个复平面上的圆。我们把这个圆 $n$ 等分,其中固定一个点在 $(1, 0)$ 处。于是这个圆上有了 $n$ 个等分点,称为 $n$ 次单位根。显然,$n$ 次单位根的 $n$ 次方就等于 1. 下面展示了 8 次、4 次单位根:

  我们把 $n$ 次单位根的第 $k$ 个记为 $\omega_n ^ k$. 显然有 $$\omega_n ^ k = e ^ {\frac{k}{n}2\pi i}$$

  用上面的式子,立刻可以得到几条性质。

  • 消去引理:$\omega _ {dn} ^ {dk} = \omega_n ^ k$
  • 折半引理:$2n$ 次单位根的平方的集合,就是 $n$ 次单位根的集合。
    证明:对于 $(\omega_{2n} ^ {k}) ^ 2$,其中 $k < n$ 的情况有 $(\omega_{2n} ^ {k}) ^ 2 = \omega_{2n} ^ {2k} = \omega _n ^k$. 对于其他情况:
    $$(\omega_{2n} ^ {k + n})^2 = \omega_{2n}^{2k+2n} = \omega _ {2n} ^ {2k} \cdot \omega_{2n}^{2n} = \omega_{2n} ^ {2k} = \omega_n ^ k$$

  取 $n$ 次单位根为求值点,对多项式求值,这个过程称为“离散傅里叶变换(DFT)”。逆过程(插值)称为“离散傅里叶反变换(IDFT)”。FFT 算法是 DFT、IDFT 的快速实现,折半引理是 FFT 分治的基础。

分治求值

  现在我们的目标是快速完成 DFT,也就是求出多项式 $A$ 在 $n$ 次单位根上的值。按传统方法肯定得 $O(n^2)$,但这一次求值点有特殊性质。我们考虑多项式$$A = a_0 + a_1x + a_2x ^2 + a_3 x ^ 3 + a_4 x ^ 4 +\cdots$$拆成两个多项式:$$\begin{aligned} A(x) &= A_ 0 (x^2) + x\cdot A _1 (x ^ 2)  \\ A _ 0 &= a _0 + a _ 2 x + a _ 4 x ^ 2 + \cdots \\ A _ 1 &= a_1 + a _3 x + a _5 x ^ 2+\cdots \end{aligned}$$

  注意到 $A_0, A _ 1$ 的阶都只有 $A$ 的一半,而且有另一个性质:虽然一共有 $n$ 个 $A(x)$ 需要求值,但 $x ^ 2$ 一共只有 $n /2$ 种取值,于是我们只需要对 $n/2$ 个数求 $A_0$ 的值、$n/2$ 个数求 $A_1$ 的值!而这些值可以直接交由递归计算。

  具体来讲,我们先递归求出 $A_0 (\omega _{n/2} ^ k), A_1 (\omega _{n/2} ^ k)$ 的值,其中 $k < n/2$;再枚举 $n/2$ 以内的 $k$:

  • 求 $A\left(\omega _ n ^ k\right)$ $$\begin{aligned}A(\omega _ n ^ k) &= A_0 \left((\omega _n ^k) ^ 2\right ) + \omega _ n ^ k \cdot A_1 \left((\omega _n ^k) ^ 2\right )  \\ &= A_0 \left(\omega _{n/2} ^k \right ) + \omega _ n ^ k \cdot A_1 \left(\omega _{n/2} ^k \right )\end{aligned}$$
  • 求 $A\left(\omega _ n ^ {k + n/2}\right)$ $$\begin{aligned}A\left(\omega _ n ^ {k + n/2}\right) &= A_0 \left((\omega _n ^{k + n/2}) ^ 2\right ) + \omega _ n ^ {k + n/2} \cdot A_1 \left((\omega _n ^{k + n/2}) ^ 2\right )  \\ &= A_0 \left(\omega _{n/2} ^k \right ) + \omega _ n ^ {k + n/2} \cdot A_1 \left(\omega _{n/2} ^k \right ) \\ &= A_0 \left(\omega _{n/2} ^k \right ) - \omega _ n ^ {k} \cdot A_1 \left(\omega _{n/2} ^k \right )\end{aligned}$$

  于是每轮循环可以求出两个 $A$ 值,我们只需要枚举 $n/2$ 个 $k$,就可以求出 $A$ 的 DFT!

def fft(A):    # A 为多项式的点值表达,求出 A 在 n 次单位根的值
    n = len(A)
    
    if n==1:
        return [A[0]]
    
    A0 = fft(A[0::2])  # A0[k] 是 A0((W_nk)^2) 的值
    A1 = fft(A[1::2])
    
    res = np.zeros(n, dtype = np.complex)
    
    for k in range(n//2):
        x = e^(k/n * 2 * pi * I)
        
        res[k] = A0[k] + x * A1[k]
        res[k+n//2] = A0[k] - x * A1[k]
    
    return res

r = fft([1, 2, 3, 4, 5, 6, 0, 0])
print(r.round(3))
# [21.   +0.j -9.657+3.j  3.   +4.j  1.657-3.j -3.   +0.j  1.657+3.j
#  3.   -4.j -9.657-3.j]

f(x) = 1 + 2*x + 3*x^2 + 4*x^3 + 5*x^4 + 6*x^5
y = [f(e^(k/8 * 2 * pi * I)) for k in range(8)]
print(np.array(y, dtype=complex).round(3))
# [21.   +0.j -9.657+3.j  3.   +4.j  1.657-3.j -3.   +0.j  1.657+3.j
#  3.   -4.j -9.657-3.j]
▲ 利用 FFT 计算 $1+2x+3x^2 + 4 x ^ 3 + 5x ^ 4 + 6 x ^ 5$ 的 DFT,并验证正确性

  来计算时间复杂度。对于阶为 $n$ 的多项式求 $n$ 个点的 DFT,把大问题拆分成子问题的代价是 $O(n)$;递归为两个子问题,每个子问题需要求 $(n/2)$ 阶多项式的 $(n/2)$ 个点的 DFT,问题规模是 $(n/2)$;合并子问题的解得到大问题的解,代价是 $O(n)$. 于是复杂度满足递推式:$$T(n) = 2T(n/2) + O(n)$$

  解得 $T(n) = O(n \log n)$. 我们实现了快速的 DFT 算法。

逆变换

  现在的问题是,如何在 $O(n\log n)$ 的时间复杂度内,完成插值。我们不难注意到,以上对于多项式 $A$ 求值的过程,实际上可以视为下面的矩阵乘法:

$$\begin{bmatrix}1 & 1 & 1 & \cdots & 1 \\1 & x_1 & x_1^2 & \cdots & x_1 ^ n\\1 & x_2 & x_2^2 & \cdots & x_2 ^ n \\ \vdots & \vdots &\vdots & \ddots & \vdots \\  1 & x_{n-1} & x_{n-1}^2 & \cdots & x_{n-1} ^ n\end{bmatrix} \cdot \begin{bmatrix}a_0 \\ a _ 1 \\ a _ 2 \\ \vdots \\ a _ n\\ \end{bmatrix} = \begin{bmatrix}y_0 \\ y _ 1 \\ y _ 2 \\ \vdots \\ y _ n\\ \end{bmatrix}$$

  其中 $x _ k$ 是 $\omega _ n ^ k$. 在 IDFT 中,我们的任务是:已知向量 $\boldsymbol{Y}$,需要推出向量 $\boldsymbol{A}$. 这只需要把左边矩阵的逆矩阵乘以 $\boldsymbol{Y}$ 就能得到。注意到左边的矩阵是一个范德蒙德矩阵,由于单位复数根的性质,逆矩阵形式非常漂亮,具体过程请看下面的文章,我实在懒得打字了orz:

浅谈范德蒙德(Vandermonde)方阵的逆矩阵的求法以及快速傅里叶变换(FFT)中IDFT的原理 - Deadecho - 博客园
浅谈范德蒙德(Vandermonde)方阵的逆矩阵与拉格朗日(Lagrange)插值的关系以及快速傅里叶变换(FFT)中IDFT的原理 标签: 行列式 矩阵 线性代数 FFT 拉格朗日插值 只要稍微看

  体现到代码里面,只需要稍微修改 fft:

def ifft(A):
    n = len(A)
    
    if n==1:
        return [A[0]]
    
    A0 = ifft(A[0::2])
    A1 = ifft(A[1::2])

    res = np.zeros(n, dtype = np.complex)
    
    for k in range(n//2):
        x = e^((n-k)/n * 2 * pi * I)
        
        res[k] = A0[k] + x * A1[k]
        res[k+n//2] = A0[k] - x * A1[k]
    
    return res

print((ifft(r) / 8).round())
# [1.+0.j 2.+0.j 3.-0.j 4.-0.j 5.-0.j 6.-0.j 0.+0.j 0.+0.j]
▲ 稍微修改 FFT 的代码,使之变成 IFFT

  可见成功复原了我们原多项式 $1+2x+3x^2 + 4 x ^ 3 + 5x ^ 4 + 6 x ^ 5$ 的系数。

多项式乘法的代码实现

  现在考虑把两个多项式 $A, B$ 相乘。我们的 FFT 只能处理 2 的整次幂的情况,所以需要把 $A, B$ 的阶都扩展到 2 的某个幂,高位用 0 填充。此外,由于 $C = A\cdot B$ 有 $\deg A + \deg B$ 的阶,我们还得再把求值点的个数扩充得比它多(否则就没法恢复 $C$ 的那么多系数了)。

  多项式乘法的模板题是 洛谷P3803 【模板】多项式乘法(FFT)。由于洛谷没有 SageMath,代码稍微改动了一下,仅依赖 numpy.

import numpy as np

PI = 3.1415926535

def fft(A):
    n = len(A)
    
    if n==1:
        return [A[0]]
    
    A0 = fft(A[0::2])
    A1 = fft(A[1::2])
    
    res = np.zeros(n, dtype = np.complex)
    
    omega = np.cos(2.0*PI/n) + np.sin(2.0*PI/n)*1.0j
    x = 1
    
    for k in range(n//2):        
        res[k] = A0[k] + x * A1[k]
        res[k+n//2] = A0[k] - x * A1[k]
    
        x *= omega
    
    return res

def ifft(A):
    n = len(A)
    
    if n==1:
        return [A[0]]
    
    A0 = ifft(A[0::2])
    A1 = ifft(A[1::2])
    
    res = np.zeros(n, dtype = np.complex)
    
    omega = np.cos(2.0*PI/n) - np.sin(2.0*PI/n)*1.0j
    x = 1
    
    for k in range(n//2):
        res[k] = A0[k] + x * A1[k]
        res[k+n//2] = A0[k] - x * A1[k]
        
        x *= omega
    
    return res

def times(A, B):
    lenC = len(A) + len(B) - 1
    pts = 1
    
    while pts < lenC:
        pts *= 2
    
    X = np.zeros(pts, dtype=float)
    X[0:len(A)] = A
    Y = np.zeros(pts, dtype=float)
    Y[0:len(B)] = B
    
    Z = ifft(fft(X) * fft(Y))[:lenC] / pts
    
    print(' '.join([str(int(x)) for x in Z.real.round()]))

_, _ = input().split()
A = [int(x) for x in input().split()]
B = [int(x) for x in input().split()]

times(A, B)