Euler法, Heun法, Runge-Kutta法の比較
Euler 法, Heun 法, 古典的 Runge-Kutta 法について誤差の収束性のプロットを作成したので, コードを残しておきます.
コメント
例によって主な計算は Rust で実装し, それを PyO3 により Python から呼び出して可視化しています.
Rust 側は汎用のライブラリとして Python 側から ODE をコールバック関数として渡そうとしたのですが,
PyO3 が (現時点ではまだ) クロージャに陽に対応しておらず (doc),
PyAny
を触るのはさすがに時間がかかりすぎるので諦めました.
なので解くべき微分方程式は src/lib.rs
にハードコードしてあります.
そしていつのまにか PyO3 が version 0.13 になっていました.
Rust
Cargo.toml
[lib]
name = "odeint"
crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.13", features = ["extension-module"] }
src/solver.rs
pub fn euler<F>(f: &F, t: f64, x: f64, h: f64) -> f64
where
F: Fn(f64, f64) -> f64,
{
x + h*f(t, x)
}
pub fn heun<F>(f: &F, t: f64, x: f64, h: f64) -> f64
where
F: Fn(f64, f64) -> f64,
{
let k = f(t, x);
x + h*( k + f(t+h, x+h*k))*0.5
}
pub fn rk4<F>(f: &F, t: f64, x: f64, h: f64) -> f64
where
F: Fn(f64, f64) -> f64,
{
let k1 = f(t, x);
let k2 = f(t + 0.5*h, x + 0.5*h*k1);
let k3 = f(t + 0.5*h, x + 0.5*h*k2);
let k4 = f(t + h, x + h*k3);
x + h*(k1 + 2.*(k2 + k3) + k4)/6.
}
src/lib.rs
mod solver;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
fn f(_: f64, x: f64) -> f64 { x.cos() }
#[pyfunction]
fn solve(method: String, n: usize) -> PyResult<f64> {
let method = match method.as_str() {
"euler" => solver::euler,
"heun" => solver::heun,
"rk4" => solver::rk4,
_ => unimplemented!(),
};
let h = 1. / n as f64;
let mut t = 0.;
let mut x = 0.;
for _ in 0..n {
x = method(&f, t, x, h);
t += h;
}
Ok(x)
}
#[pymodule]
fn odeint(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!( solve ))?;
Ok(())
}
Python
plot.py
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Source Han Serif JP"
plt.rcParams["mathtext.fontset"] = "cm"
plt.rcParams["font.size"] = 16
import odeint
Ns = np.array([ 3, 10, 30, 100, 301, 1_000, ], dtype=int)
exact = np.arcsin(np.tanh(1.))
def error(method):
err = []
for N in Ns:
x = odeint.solve(method, N)
err.append(x)
return np.fabs(np.array(err) / exact - 1.)
fig = plt.figure()
plt.subplots_adjust(left=0.15, right=0.9, bottom=0.14, top=0.92)
ax = fig.add_subplot(111)
ax.plot( 1./Ns, error("euler"), "^-", label="Euler")
ax.plot( 1./Ns, error("heun"), "d-", label="Heun")
ax.plot( 1./Ns, error("rk4"), "o-", label="RK4")
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_ylim([1e-13, 1e-1])
ax.grid()
ax.grid(which="minor", alpha=0.3)
ax.legend()
ax.set_xlabel(r"刻み幅", fontsize=18)
ax.set_ylabel(r"相対誤差", fontsize=18)
plt.savefig("odesolver-comparison.svg")
plt.show()
plt.close()