Fast Fourier Transform

为什么需要FFT

我们都知道对于两个多项式f(x)g(x),朴素的方法计算它们的乘积h(x) = f(x) * g(x)需要O(n^2)的时间。这里隐含了一个条件:f(x)g(x)都是系数形式表示的多项式,比如说: f(x) = a0x^0 + a1x^1 + a2x^2 + ... + an-1x^n-1

实际上多项式还有一种点值表示法,就是 f = {(p1, f(p1)), (p2, f(p2)), (p3, f(p3)), ..., (pn, f(pn))} ,也就是初中熟悉的待定系数法求多项式。这种看起来十分反人类的表示方法其实有一个非常大的好处:在两个多项式的采样点相同的时候,它们的四则运算都是O(n)的复杂度。比如说对于f = {(pi, f(pi))}g = {(pi, g(pi))},求h = f*g,那么很容易得到h = {(pi, f(pi)*g(pi))},也就是x不变,y相乘。

于是,就有了一种比较曲折的计算f(x)*g(x)的方法:

  1. f(x)g(x)求值(系数形式转成点值形式 DFT
  2. fg点值形式上进行乘法得到h
  3. h插值,得到h(x)(点值形式转成系数形式 IDFT

对于朴素方法,上面三步的复杂度分别为O(n^2), O(n), O(n^3),然而巧妙的使用单位负根可以使复杂度变成O(nlogn), O(n), O(nlogn),赶快去膜拜傅立叶大神吧!

什么时候需要多项式乘法呢?比如高精乘,比如卷积h[i] = f[j]*g[i-j]

Discrete Fourier Transform

DFT的作用就是把一个系数形式的多项式转换成点值形式。IDFT就是反过来。FFT和IFFT实际上就是用O(nlogn)的时间完成DFT和IDFT。

单位复根

FFT利用的是单位复根的一些非常有意思的性质。

欧拉恒等式告诉我们e^ix = cos(x) + isin(x),而n次单位复根Wn则是指e^(2πi/n),可以发现Wn0n-1次幂正好把一个复平面上的单位圆平均分成了n份。有一些性质:

  • Wn^j = -Wn^(j+n/2) 这个是因为它们在复平面上方向相反
  • W(an)^(aj) = Wn^j 这个叫做相消引理
  • (Wn^j)^2 = W(n/2)^j 这个由相消引理容易得到
  • {(Wn^j)^2 | j = 0..n-1} = {W(n/2)^k | k = 0..n/2-1} 这个叫做折半引理,实际上就是上面的集合形式

Fast Fourier Transform

FFT的精髓在于,选取n个n次单位复根作为采样点,由折半引理可以知道,只需要计算n/2个采样值,递归下去做,由主定理就可以得到总的复杂度为O(nlogn)。为了方便起见,一般把n补充为2^k的形式。

1
2
3
4
5
6
7
如果把
    f(x)  = a0x^0 + a1x^1 + a2x^2 + a3x^3 + ... + an-1x^n-1
拆成
    f0(x) = a0x^0 + a2x^1 + a4x^2 + ... + an-2x^n/2-1
    f1(x) = a1x^0 + a3x^1 + a5x^2 + ... + an-1x^n/2-1
那么就有
    f(x)  = f0(x^2) + x*f1(x^2)

写成程序大概就是

1
2
3
4
5
6
7
8
9
Recursive_FFT(A[n])
    Wn, w = e^(2πi/n), 1
    Y0 = Recursive_FFT(A[0], A[2], A[4], ...)
    Y1 = Recursive_FFT(A[1], A[3], A[5], ...)
    For (k=0;k<n/2;k++)
        Y[k]     = Y0[k] + w*Y1[k]
        Y[k+n/2] = Y0[k] - w*Y1[k]
        w=w*wn
    Return(Y)

Inversed Fast Fourier Transform

转回来的方法就是:单位复根Wn转成e^(-2πi/n),然后把y[i] /= n

避免递归

上面那个递归版本的FFT实际上常数很大,因为要复制数组。我们观察一下递归树,假设这是一个长度为8的多项式。

1
2
3
4
5
6
7
8
9
+===============================================+
| 000   001   010   011   100   101   110   111 |
+-----------------------------------------------+
| 000   010   100   110 | 001   011   101   111 |
+-----------------------------------------------+
| 000   100 | 010   110 | 001   101 | 011   111 |
+-----------------------------------------------+
| 000 | 100 | 010 | 110 | 001 | 101 | 011 | 111 |
+===============================================+

实际上,从叶子往上做的话就是把每个数的二进制位颠倒过来。

C++ Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
struct Complex
{
  double real, imag;
  Complex(const double r = .0, const double i = .0): real(r), imag(i) { }
};
const Complex ZERO(.0, .0);
inline Complex operator+(const Complex& lhs, const Complex& rhs) { return Complex(lhs.real+rhs.real, lhs.imag+rhs.imag); }
inline Complex operator-(const Complex& lhs, const Complex& rhs) { return Complex(lhs.real-rhs.real, lhs.imag-rhs.imag); }
inline Complex operator*(const Complex& lhs, const Complex& rhs) { return Complex(lhs.real*rhs.real-lhs.imag*rhs.imag, lhs.real*rhs.imag+lhs.imag*rhs.real); }
inline int bit_reverse(const int x, const int n)
{
  int res = 0;
  for (int i = 0; i < n; ++i)
    res |= (x>>i&1)<<(n-i-1);
  return res;
}
void fft(Complex y[], const Complex a[], const int n, const int rev)
{
  const int len = 1<<n;
  for (int i = 0; i < len; ++i) y[i] = a[bit_reverse(i, n)];
  for (int d = 1; d <= n; ++d)
  {
    const int m = 1<<d;
    const Complex wn(std::cos(2*PI/m*rev), std::sin(2*PI/m*rev));
    for (int k = 0; k < len; k += m)
    {
      Complex w(1., .0);
      for (int j = 0; j < m/2; ++j)
      {
        const Complex u = y[k+j], t = w*y[k+j+m/2];
        y[k+j] = u+t, y[k+j+m/2] = u-t;
        w = w*wn;
      }
    }
  }
  if (rev == -1)
    for (int i = 0; i < len; ++i) y[i].real /= len, y[i].imag = .0;
}
void convolution(Complex c[], Complex a[], Complex b[], const int la, const int lb)
{
  static Complex tmp1[MAXN], tmp2[MAXN];
  int n = 0, len = 1<<n;
  for (; len < 2*la || len < 2*lb; ++n) len <<= 1;
  std::fill(a+la, a+len, ZERO);
  std::fill(b+lb, b+len, ZERO);
  fft(tmp1, a, n, 1);
  fft(tmp2, b, n, 1);
  for (int i = 0; i < len; ++i) tmp1[i] = tmp1[i]*tmp2[i];
  fft(c, tmp1, n, -1);
}

FNT

有时候答案不会很大,可以用mod p的剩余系来代替复数。因为对于质数p的原根g,有g^(p-1) mod p = 1,所以可以把g^(p-1)/n作为单位根使用。p可以选取形如c*2^k + 1的质数。

IFNT的时候,只需要把1~n-1翻过来,并且乘上n的乘法逆元即可。

C++ Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
const long long P = 1004535809LL, G = 3;
inline long long powmod(long long base, long long k, const long long mod)
{
  long long ans = 1;
  for (; k; k >>= 1)
  {
    if (k & 1) ans = ans*base%mod;
    base = base*base%mod;
  }
  return ans;
}
inline int bit_reverse(const int x, const int n)
{
  int res = 0;
  for (int i = 0; i < n; ++i)
    res |= (x>>i&1)<<(n-i-1);
  return res;
}
void fnt(long long y[], const long long a[], const int n, const int rev = 1)
{
  const int len = 1<<n;
  for (int i = 0; i < len; ++i) y[i] = a[bit_reverse(i, n)];
  for (int d = 1; d <= n; ++d)
  {
    const int m = 1<<d;
    const long long wn = powmod(G, (P-1)/m, P);
    for (int k = 0; k < len; k += m)
    {
      long long w = 1;
      for (int j = 0; j < m/2; ++j)
      {
        const long long u = y[k+j], t = w*y[k+j+m/2];
        y[k+j] = (u+t)%P, y[k+j+m/2] = (u-t)%P;
        w = w*wn%P;
      }
    }
  }
  if (rev == 1) return;
  const long long inv = powmod(len, P-2, P);
  for (int i = 0; i < len; ++i) y[i] = (y[i]*inv%P+P)%P;
  std::reverse(y+1, y+len);
}
void convolution(long long c[], long long a[], long long b[], const int la, const int lb)
{
  static long long tmp1[MAXN], tmp2[MAXN];
  int n = 0, len = 1<<n;
  for (; len < 2*la || len < 2*lb; ++n) len <<= 1;
  std::fill(a+la, a+len, 0);
  std::fill(b+lb, b+len, 0);
  fnt(tmp1, a, n);
  fnt(tmp2, b, n);
  for (int i = 0; i < len; ++i) tmp1[i] = tmp1[i]*tmp2[i]%P;
  fnt(c, tmp1, n, -1);
}

Comments