Lineax
Lineax 是一个用于线性求解和线性最小二乘的 JAX 库。也就是说,Lineax 提供了求解 $Ax = b$ 中 $x$ 的例程。(即使 $A$ 可能是病态或矩形的。)
特性包括:
- 支持 PyTree 值的矩阵和向量;
- 用于雅可比矩阵、转置等的通用线性算子;
- 高效的线性最小二乘(如 QR 求解器);
- 通过线性最小二乘的数值稳定梯度;
- 支持结构化(如对称)矩阵;
- 改进的编译时间;
- 部分算法的运行时间优化;
- 支持实值和复值输入;
- 使用 JAX 的所有优势:自动微分、自动并行、GPU/TPU 支持等。
安装
pip install lineax
需要 Python 3.9+、JAX 0.4.13+ 和 Equinox 0.11.0+。
文档
可在 https://docs.kidger.site/lineax 获取。
快速示例
Lineax 可以使用显式矩阵算子解决最小二乘问题:
import jax.random as jr
import lineax as lx
matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 8))
vector = jr.normal(vector_key, (10,))
operator = lx.MatrixLinearOperator(matrix)
solution = lx.linear_solve(operator, vector, solver=lx.QR())
或者 Lineax 可以在不具体化矩阵的情况下解决问题,如在这个二次求解中所做的:
import jax
import lineax as lx
key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (10,))
def quadratic_fn(y, args):
return jax.numpy.sum((y - 1)**2)
gradient_fn = jax.grad(quadratic_fn)
hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-6, atol=1e-6)
out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)
minimum = y - out.value
引用
如果您发现这个库在学术工作中有用,请引用:(arXiv 链接)
@article{lineax2023,
title={Lineax: unified linear solves and linear least-squares in JAX and Equinox},
author={Jason Rader and Terry Lyons and Patrick Kidger},
journal={
AI for science workshop at Neural Information Processing Systems 2023,
arXiv:2311.17283
},
year={2023},
}
(也请考虑在 GitHub 上给项目加星。)
另请参阅:JAX 生态系统中的其他库
始终有用
Equinox:神经网络和核心 JAX 中未包含的所有内容!
jaxtyping:数组形状/数据类型的类型注解。
深度学习
Optax:一阶梯度(SGD、Adam 等)优化器。
Orbax:检查点(异步/多主机/多设备)。
Levanter:可扩展且可靠的基础模型(如 LLM)训练。
科学计算
Diffrax:数值微分方程求解器。
Optimistix:寻根、最小化、不动点和最小二乘。
BlackJAX:概率和贝叶斯采样。
sympy2jax:SymPy<->JAX 转换;通过梯度下降训练符号表达式。
PySR:符号回归。(非 JAX 值得一提的项目!)
Awesome JAX
Awesome JAX:更多 JAX 项目的长列表。