快速入门 | 转换 | 安装指南 | 神经网络库 | 更新日志 | 参考文档
JAX 是一个面向加速器的数组计算和程序转换 Python 库,专为高性能数值计算和大规模机器学习而设计。
通过其更新版本的 Autograd,JAX 可以自动微分原生 Python 和 NumPy 函数。它可以对循环、分支、递归和闭包进行微分,还可以求导数的导数的导数。它支持通过 grad
进行反向模式微分(又称反向传播),以及前向模式微分,两者可以任意组合到任何阶数。
JAX 的新特性是使用 XLA 来编译和在 GPU 和 TPU 上运行 NumPy 程序。编译默认在后台进行,库调用会即时编译并执行。但 JAX 还允许您使用单函数 API jit
将自己的 Python 函数即时编译成 XLA 优化的内核。编译和自动微分可以任意组合,因此您可以表达复杂的算法并获得最大性能,而无需离开 Python。您甚至可以使用 pmap
同时编程多个 GPU 或 TPU 核心,并对整个过程进行微分。
深入一点,您会发现 JAX 实际上是一个用于可组合函数转换的可扩展系统。grad
和 jit
都是此类转换的实例。其他转换包括用于自动向量化的 vmap
和用于多个加速器单程序多数据 (SPMD) 并行编程的 pmap
,未来还会有更多。
这是一个研究项目,而不是 Google 的官方产品。请预期会有错误和棘手问题。请通过尝试使用、报告错误并让我们知道您的想法来提供帮助!
import jax.numpy as jnp from jax import grad, jit, vmap def predict(params, inputs): for W, b in params: outputs = jnp.dot(inputs, W) + b inputs = jnp.tanh(outputs) # 下一层的输入 return outputs # 最后一层没有激活函数 def loss(params, inputs, targets): preds = predict(params, inputs) return jnp.sum((preds - targets)**2) grad_loss = jit(grad(loss)) # 编译后的梯度评估函数 perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # 快速逐样本梯度
使用浏览器中的笔记本直 接开始,连接到 Google Cloud GPU。 以下是一些入门笔记本:
JAX 现在可以在 Cloud TPU 上运行。 要试用预览版,请参阅 Cloud TPU Colabs。
深入了解 JAX:
JAX 的核心是一个用于转换数值函数的可扩展系统。以下是四个主要的转换:grad
、jit
、vmap
和 pmap
。
grad
进行自动微分JAX 的 API 与 Autograd 大致相同。最常用的函数是用于反向模式梯度的 grad
:
from jax import grad import jax.numpy as jnp def tanh(x): # 定义一个函数 y = jnp.exp(-2.0 * x) return (1.0 - y) / (1.0 + y) grad_tanh = grad(tanh) # 获取其梯度函数 print(grad_tanh(1.0)) # 在 x = 1.0 处求值 # 输出 0.4199743
您可以使用 grad
求任意阶导数。
print(grad(grad(grad(tanh)))(1.0)) # 输出 0.62162673
对于更高级的自动微分,您可以使用 jax.vjp
进行反向模式向量-雅可比积,使用 jax.jvp
进行前向模式雅可比-向量积。这两者可以与其他 JAX 转换任意组合。以下是一种组合它们以创建高效计算完整 Hessian 矩阵的函数的方法:
from jax import jit, jacfwd, jacrev def hessian(fun): return jit(jacfwd(jacrev(fun)))
与 Autograd 一样,您可以自由地在 Python 控制结构中使用微分:
def abs_val(x): if x > 0: return x else: return -x abs_val_grad = grad(abs_val) print(abs_val_grad(1.0)) # 输出 1.0 print(abs_val_grad(-1.0)) # 输出 -1.0(abs_val 被重新求值)
有关更多信息,请参阅自动微分参考文档和 JAX 自动微分食谱。
jit
进行编译您可以使用 XLA 通过 jit
对函数进行端到端编译,可以用作 @jit
装饰器或高阶函数。
import jax.numpy as jnp from jax import jit def slow_f(x): # 元素级操作从融合中获得巨大收益 return x * x + x * 2.0 x = jnp.ones((5000, 5000)) fast_f = jit(slow_f) %timeit -n10 -r3 fast_f(x) # 在 Titan X 上约 4.5 ms/循环 %timeit -n10 -r3 slow_f(x) # 约 14.5 ms/循环(通过 JAX 也在 GPU 上)
您可以随意组合 jit
、grad
和任何其他 JAX 转换。
使用 jit
会对函数可以使用的 Python 控制流类型施加限制;更多信息请参阅注意事项笔记本。
vmap
进行自动向量化vmap
是向量化映射。它具有沿数组轴映射函数的熟悉语义,但不是将循环保持在外部,而是将循环下推到函数的基本操作中以获得更好的性能。
使用 vmap
可以避免在代码中携带批次维度。例如,考虑这个简单的非批处理神经网络预测函数:
def predict(params, input_vec): assert input_vec.ndim == 1 activations = input_vec for W, b in params: outputs = jnp.dot(W, activations) + b # `activations` 在右侧! activations = jnp.tanh(outputs) # 下一层的输入 return outputs # 最后一层没有激活函数
我们通常会写成 jnp.dot(activations, W)
以允许 activations
左侧有一个批次维度,但这个特定的预测函数只适用于单个输入向量。如果我们想一次性对一批输入应用这个函数,从语义上讲,我们可以这样写:
from functools import partial predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
但是一次只处理一个样本会很慢!最好将计算向量化,这样在每一层我们都在进行矩阵-矩阵乘法,而不是矩阵-向量乘法。
vmap
函数为我们完成了这种转换。也就是说,如果我们写:
from jax import vmap predictions = vmap(partial(predict, params))(input_batch) # 或者,另一种写法 predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
那么 vmap
函数会将外部循环推到函数内部,我们的机器最终会执行矩阵-矩阵乘法,就好像我们手动进行了批处理一样。
不使用 vmap
手动批处理一个简单的神经网络很容易,但在其他情况下,手动向量化可能不切实际或不可能。比如高效计算每个样本梯度的问题:对于一组固定的参数,我们想要计算损失函数在批次中每个样本上单独评估的梯度。使用 vmap
,这很容易:
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
当然,vmap
可以任意组合 jit
、grad
和任何其他 JAX 变换!我们在 jax.jacfwd
、jax.jacrev
和 jax.hessian
中使用 vmap
进行正向和反向自动微分,以快速计算雅可比矩阵和海森矩阵。
pmap
进行 SPMD 编程对于多个加速器(如多个 GPU)的并行编程,使用 pmap
。使用 pmap
,你可以编写单程序多数据(SPMD)程序,包括快速并行集体通信操作。应用 pmap
意味着你编写的函数将由 XLA 编译(类似于 jit
),然后在设备上复制并并行执行。
以下是在 8 GPU 机器上的示例:
from jax import random, pmap import jax.numpy as jnp # 创建 8 个随机 5000 x 6000 矩阵,每个 GPU 一个 keys = random.split(random.key(0), 8) mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys) # 在每个设备上并行运行本地矩阵乘法(无数据传输) result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape 是 (8, 5000, 5000) # 在每个设备上并行计算平均值并打印结果 print(pmap(jnp.mean)(result)) # 打印 [1.1566595 1.1805978 ... 1.2321935 1.2015157]
除了表达纯映射外,你还可以使用设备间的快速集体通信操作:
from functools import partial from jax import lax @partial(pmap, axis_name='i') def normalize(x): return x / lax.psum(x, 'i') print(normalize(jnp.arange(4.))) # 打印 [0. 0.16666667 0.33333334 0.5 ]
你甚至可以嵌套 pmap
函数以实现更复杂的通信模式。
所有这些都可以组合,所以你可以自由地对并行计算进行微分:
from jax import grad @pmap def f(x): y = jnp.sin(x) @pmap def g(z): return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum() return grad(lambda w: jnp.sum(g(w)))(x) print(f(x)) # [[ 0. , -0.7170853 ], # [-3.1085174 , -0.4824318 ], # [10.366636 , 13.135289 ], # [ 0.22163185, -0.52112055]] print(grad(lambda x: jnp.sum(f(x)))(x)) # [[ -3.2369726, -1.6356447], # [ 4.7572474, 11.606951 ], # [-98.524414 , 42.76499 ], # [ -1.6007166, -1.2568436]]
当对 pmap
函数进行反向模式微分(例如使用 grad
)时,计算的反向传播会像正向传播一样并行化。
更多信息请参见 SPMD Cookbook 和 SPMD MNIST 分类器从头开始示例。
要更全面地了解当前的注意事项,包括示例和解释,我们强烈建议阅读 Gotchas Notebook。一些突出的问题包括:
JAX 变换仅适用于纯函数,这些函数没有副作用并遵守引用透明性(即使用is
进行对象身份测试不会保留)。如果您对非纯Python函数使用JAX变换,可能会看到类似Exception: Can't lift Traced...
或Exception: Different traces at same level
的错误。
数组的原地突变更新,如x[i] += y
,不受支持,但有函数式替代方案。在jit
下,这些函数式替代方案会自动在原地重用缓冲区。
如果您在寻找卷积运算符,它们在jax.lax
包中。
JAX默认强制使用单精度(32位,例如float32
)值,要启用双精度(64位,例如float64
),需要在启动时设置jax_enable_x64
变量(或 设置环境变量JAX_ENABLE_X64=True
)。在TPU上,JAX默认对所有内容使用32位值,除了"类矩阵乘法"操作(如jax.numpy.dot
和lax.conv
)中的内部临时变量。这些操作有一个precision
参数,可以通过三次bfloat16传递来近似32位操作,可能会导致运行时间变慢。TPU上的非矩阵乘法操作会转换为通常强调速度而非精度的实现,因此实际上TPU上的计算会比其他后端上的类似计算精度更低。
NumPy的一些涉及Python标量和NumPy类型混合的dtype提升语义没有保留,即np.add(1, np.array([2], np.float32)).dtype
是float64
而不是float32
。
一些转换,如jit
,限制了您使用Python控制流的方式。如果出现问题,您总会收到明确的错误提示。您可能需要使用jit
的static_argnums
参数,结构化控制流原语如lax.scan
,或者仅对较小的子函数使用jit
。
Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac ARM | Windows x86_64 | Windows WSL2 x86_64 | |
---|---|---|---|---|---|---|
CPU | 是 | 是 | 是 | 是 | 是 | 是 |
NVIDIA GPU | 是 | 是 | 否 | 不适用 | 否 | 实验性 |
Google TPU | 是 | 不适用 | 不适用 | 不适用 | 不适用 | 不适用 |
AMD GPU | 实验性 | 否 | 否 | 不适用 | 否 | 否 |
Apple GPU | 不适用 | 否 | 实验性 | 实验性 | 不适用 | 不适用 |
硬件 | 指令 |
---|---|
CPU | pip install -U jax |
NVIDIA GPU | pip install -U "jax[cuda12]" |
Google TPU | pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
AMD GPU | 使用Docker或从源代码构建。 |
Apple GPU | 按照Apple的说明进行操作。 |
有关其他安装策略的信息,请参阅文档。这些包括从源代码编译、使用Docker安装、使用其他版本的CUDA、社区支持的conda构建,以及一些常见问题的答案。
多个Google研究团队开发并分享了用JAX训练神经网络的库。如果您想要一个功能齐全的神经网络训练库,并附有示例和操作指南,可以尝试Flax。查看新的NNX API以获得简化的开发体验。
Google X维护神经网络库Equinox。这被用作JAX生态系统中几个其他库的基础。
此外,DeepMind已开源了围绕JAX的一系列库,包括用于梯度处理和优化的Optax,用于强化学习算法的RLax,以及用于可靠代码和测试的chex。(观看NeurIPS 2020 JAX Ecosystem在DeepMind的演讲点击这里)
要引用此仓库:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/google/jax},
version = {0.3.13},
year = {2018},
}
在上述bibtex条目中,名字按字母顺序排列,版本号应为jax/version.py中的版本,年份对应项目的开源发布。
JAX的一个初始版本,仅支持自动微分和编译到XLA,在2018年SysML会议上的一篇论文中有描述。我们目前正在撰写一篇更全面和最新的论文,涵盖JAX的理念和功能。
有关JAX API的详细信息,请参阅参考文档。
对于希望成为JAX开发者的入门指南,请参阅开发者文档。
一键生成PPT和Word,让学习生活更轻松
讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。
深度推理能力全新升级,全面对标OpenAI o1
科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。
一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型
Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。
字节跳动发布的AI编程神器IDE
Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。
AI助力,做PPT更简单!
咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美 化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。
选题、配图、成文,一站式创作,让内容运营更高效
讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。
专业的AI公文写作平台,公文写作神器
AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。
OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。
openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。
高分辨率纹理 3D 资产生成
Hunyuan3D-2 是腾讯开发的用于 3D 资产生成的强大工具,支持从文本描述、单张图片或多视角图片生成 3D 模型,具备快速形状生成能力,可生成带纹理的高质量 3D 模型,适用于多个领域,为 3D 创作提供了高效解决方案。
一个具备存储、管理和客户端操作等多种功能的分布式文件系统相关项目。
3FS 是一个功能强大的分布式文件系统项目,涵盖了存储引擎、元数据管理、客户端工具等多个模块。它支持多种文件操作,如创建文件和目录、设置布局等,同时具备高效的事件循环、节点选择和协程池管理等特性。适用于需要大规模数据存储和管理的场景,能够提高系统的性能和可靠性,是分布式存储领域的优质解决方案。
最新AI工具、AI资讯
独家AI资源、AI项目落地
微信扫一扫关注公众号