[Rust] Stockham型FFT
高速 Fourier 変換 (fast Fourier transformation, FFT) のアルゴリズムは何度勉強しても忘れるんですよね...
離散 Fourier 変換
本記事を通じて $i$ は虚数単位を表すものとします. $N$ 個の複素数 $a_j$ ($j = 0, 1, \cdots, N-1$) の離散 Fourier 変換 (discrete Fourier transformation, DFT) $c_k$ ($k = 0, 1, \cdots, N-1$) は $$a_j = \frac{ 1 }{ N } \sum_k c_k \exp \left( \frac{ 2 \pi i j k }{ N } \right)$$ $$c_k = \sum_k a_j \exp \left( - \frac{ 2 \pi i j k }{ N } \right)$$ により定義されます. 以下では $\omega_N = \exp ( - 2 \pi / N )$ と略記します. この記法では, DFT とはベクトル $a = ( a_j )$ に行列 $[ W ]_{k j} = \omega_N^{k j}$ を作用させることに過ぎません.
直接計算
DFT の定義 $$c_k = \sum_{j = 0}^{N - 1} a_j ( \omega_N^k )^j$$ をそのままコードに起こすとこうなると思われます.
use num_complex::Complex;
pub fn fft(a: &[Complex<f64>]) -> Vec<Complex<f64>> {
let n = a.len();
let omega = crate::omega(n);
(0..n).scan(Complex::new(1., 0.), |r, _k| {
let (c, _) = a.iter()
.fold((Complex::new(0.,0.), Complex::new(1., 0.)), |(acc, mode), a| {
(acc + a*mode, mode * *r)
});
*r *= omega;
Some(c)
}).collect()
}
無論このコードは計算時間が $O ( N^2 )$ を要するため, 大きな $N$ に対してこれを用いることはできません. 手元の環境では $N = 2 \times 1024$ を超えると秒単位で待たされるので実用には厳しいと思います.
Stockham 型 FFT
しかし DFT は極めて性質が良いため, 計算アルゴリズムを工夫すれば $O ( N \ln N )$ で計算できます. そのようなアルゴリズムを総称して FFT と呼んでいます. ここでは Stockham 型 FFT を取り上げます.
$N = 2^L$ である場合に議論を限定します. $\alpha_l = 2^l$, $\beta_l = 2^{L - l - 1}$ として, $N$ 個の複素数の組 $X$ を $$X_{l+1} [ 2 \alpha j + k ] = X_l ( \alpha j + k ) + \omega_N^{k \beta} X_l ( \alpha j + n/2 + k )$$ $$X_{l+1} [ 2 \alpha j + k + \alpha ] = X_l ( \alpha j + k ) - \omega_N^{k \beta} X_l ( \alpha j + n/2 + k )$$ (ただし $j$ は $0..\beta$ を, $k$ は $0..\alpha$ を走る) という漸化式に従って更新するとき, $X_{L-1} ( j ) = c_j$ は $X_0 ( j ) = a_j$ の DFT を与えます.
use std::f64::consts::PI;
use num_complex::Complex;
pub fn fft_recur(a: &[Complex<f64>]) -> Vec<Complex<f64>> {
let n = a.len();
let lmax = n.trailing_zeros() as usize;
assert_eq!(n, 2usize.pow(lmax as u32));
let buf: Vec<_> = a.iter().cloned().collect();
let tmp: Vec<_> = vec![ Complex::new(0., 0.); n ];
stockham_recur(n, lmax, 0, buf, tmp)
}
fn stockham_recur(n: usize, lmax: usize, l: usize, buf: Vec<Complex<f64>>, mut tmp: Vec<Complex<f64>>) -> Vec<Complex<f64>> {
let alpha = 1 << l;
let beta = n >> l >> 1;
let omega_beta = {
let phase = - 2. * PI * (beta as f64) / (n as f64);
Complex::from_polar(1., phase)
};
for j in 0..beta {
let mut mode = Complex::new(1., 0.);
for k in 0..alpha {
let a = buf[alpha*j + k];
let b = buf[alpha*j + n/2 + k] * mode;
tmp[2*alpha*j + k] = a + b;
tmp[2*alpha*j + k + alpha] = a - b;
mode *= omega_beta;
}
}
if l+1 == lmax {
tmp
} else {
stockham_recur(n, lmax, l+1, tmp, buf)
}
}
漸化式を解く際に, 不要なメモリコピーを避けるため,
読み出し側 (buf
) と書き込み側 (tmp
) の役割を $l$ が進む度に入れ替えると良さそうです (手元ではこれで 25% 高速化できました).
$N = 4 \times 1024$ では定義通りだと2.6秒, Stockham 型 FFT だと0.4秒でした.
まだ最適化できるはずですが, 力尽きたのでここで一旦終わりとします.