Diffrax 是一个基于 JAX 的库,提供数值微分方程求解器。
功能包括:
Tsit5
、Dopri8
、辛求解器、隐式求解器);从技术角度看,该库的内部结构非常酷——所有类型的方程(ODEs、SDEs、CDEs)都以统一的方式被求解(而不是分别处理),生成一个小巧紧凑的库。
pip install diffrax
需要 Python 3.9+、JAX 0.4.13+ 和 Equinox 0.10.11+。
可在 https://docs.kidger.site/diffrax 获取。
from diffrax import diffeqsolve, ODETerm, Dopri5 import jax.numpy as jnp def f(t, y, args): return -y term = ODETerm(f) solver = Dopri5() y0 = jnp.array([2., 3.]) solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
这里,Dopri5
是指 Dormand--Prince 5(4) 数值微分方程求解器,这是许多问题的标准选择。
如果您在学术研究中发现这个库有用,请引用:(arXiv链接)
@phdthesis{kidger2021on, title={{O}n {N}eural {D}ifferential {E}quations}, author={Patrick Kidger}, year={2021}, school={University of Oxford}, }
(也请考虑在GitHub上为该项目加星。)
总是有用
Equinox: 神经网络和核心JAX中还没有的所有东西!
jaxtyping: 数组形状/数据类型的类型注解。
深度学习
Optax: 一阶梯度(SGD、Adam等)优化器。
Orbax: 检查点(异步/多宿主/多设备)。
Levanter: 可扩展且可靠的基础模型(如LLMs)训练。
科学计算
Optimistix: 求根、最小化、固定点和最小二乘法。
Lineax: 线性求解器 。
BlackJAX: 概率+贝叶斯采样。
sympy2jax: SymPy<->JAX 转换;通过梯度下降训练符号表达式。
PySR: 符号回归。(非JAX荣誉提名!)
Awesome JAX
Awesome JAX: 更长的其他JAX项目列表。