概述 | 为什么选择俳句? | 快速开始 | 安装 | 示例 | 用户手册 | 文档 | 引用俳句
[!重要] 📣 截至2023年7月,Google DeepMind建议新项目采用Flax而不是俳句。Flax是由Google Brain最初开发,现在由Google DeepMind开发的神经网络库。 📣
在撰写本文时,Flax拥有比俳句更多的特性,较大的开发团队 和更活跃的社区。Flax在用户群体中有更高的接受度,具有更广泛的文档, 示例和一个活跃的社区来创建从头到尾的示例。
俳句将 保持尽力支持,但项目将进入维护模式,这意味着开发工作的重点将放在错误修复和与JAX新版本的兼容性上。
为了保持俳句与Python和JAX的新版本的兼容,我们将继续发布新的版本,但不会添加(或接受PR)新功能。
我们在Google DeepMind内部广泛使用俳句,并计划无限期地以这种模式支持俳句。
俳句是一种工具<br> 用于构建神经网络<br> 想象一下:“{JAX}的{十四行诗}”
俳句是由Sonnet的一些作者为JAX开发的简单神经网络库,Sonnet是一个为TensorFlow开发的神经网络库。
俳句的文档可以在https://dm-haiku.readthedocs.io/找到。
**说明:**如果您在寻找操作系统俳句,请访问https://haiku-os.org/。
JAX是一个结合NumPy、自动微分和一流GPU/TPU支持的数值计算库。
俳句是一个简单的JAX神经网络库,允许用户使用熟悉的面向对象编程模型,同时允许完全访问JAX的纯函数转换。
俳句提供了两个核心工具:模块抽象hk.Module
和一个简单的函数转换hk.transform
。
hk.Module
是持有自身参数、其他模块和应用用户输入函数的方法的Python对象。
hk.transform
将使用这些面向 对象、功能上“纯净”模块的函数转换成可以与jax.jit
、jax.grad
、jax.pmap
等一起使用的纯函数。
有很多为JAX开发的神经网络库。为什么要选择俳句?
Module
(模块)编程模型,同时保留了对JAX的函数转换的访问。hk.transform
)外,俳句旨在匹配Sonnet 2的API。模块、方法、参数名称、默认值和初始化方案应该匹配。hk.next_rng_key()
返回一个唯一的随机数生成器键。让我们来看一个示例神经网络、损失函数和训练循环。(欲了解更多示例, 请参见我们的示例目录。 MNIST 示例 是一个很好的起点。)
import haiku as hk import jax.numpy as jnp def softmax_cross_entropy(logits, labels): one_hot = jax.nn.one_hot(labels, logits.shape[-1]) return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1) def loss_fn(images, labels): mlp = hk.Sequential([ hk.Linear(300), jax.nn.relu, hk.Linear(100), jax.nn.relu, hk.Linear(10), ]) logits = mlp(images) return jnp.mean(softmax_cross_entropy(logits, labels)) loss_fn_t = hk.transform(loss_fn) loss_fn_t = hk.without_apply_rng(loss_fn_t) rng = jax.random.PRNGKey(42) dummy_images, dummy_labels = next(input_dataset) params = loss_fn_t.init(rng, dummy_images, dummy_labels) def update_rule(param, update): return param - 0.01 * update for images, labels in input_dataset: grads = jax.grad(loss_fn_t.apply)(params, images, labels) params = jax.tree_util.tree_map(update_rule, params, grads)
俳句的核心是hk.transform
。transform
函数允许您编写依赖于参数的神经网络函数(此处是Linear
层的权重)而不需要明确编写初始化那些参数的样板代 码。transform
通过将函数转换成纯函数对init
和apply
实现的形式来做到这点。
init
init
函数,签名为params = init(rng, ...)
(其中...
是未转换函数的参数),允许您收集网络中任何参数的初始值。俳句通过运行您的函数,跟踪任何通过hk.get_parameter
(由如hk.Linear
调用)请求的参数并返回给您。
params
对象是您的网络中所有参数的嵌套数据结构,设计供您检查和操作。
具体地,它是模块名称到模块参数的映射,其中模块参数是参数名称到参数值的映射。例如:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
apply
函数,签名为result = apply(params, rng, ...)
,允许您向函数注入参数值。当调用hk.get_parameter
时,返回的值将来自您提供给apply
作为输入的params
:
loss = loss_fn_t.apply(params, rng, images, labels)
注意,由于我们的损失函数执行的实际计算不依赖于随机数,因此传入一个随机数生成器是没有必要的,因此我们也可以传入None
给rng
参数。 (请注意,如果您的计算确实使用了随机数,传入None
给rng
将引发错误)。在上面的示例中,我们让俳句自动对我们执行此操作:
loss_fn_t = hk.without_apply_rng(loss_fn_t)
既然apply
是一个纯函数,我们可以将其传递给jax.grad
(或其他JAX变换):
grads = jax.grad(loss_fn_t.apply)(params, images, labels)
此示例中的训练循环非常简单。一个需要注意的细节是使用jax.tree_util.tree_map
来将sgd
函数应用于params
和grads
中的所有匹配条目。结果具有与原始params
相同的结构,并可以再次用于apply
。
俳句是用纯Python编写的,但依赖于通过JAX的C++代码。
由于JAX的安装因CUDA版本而异,俳句没有在requirements.txt
中列出JAX作为依赖项。
首先,按照这些说明来安装带有相关加速器支持的JAX。
然后,使用pip安装俳句:
$ pip install git+https://github.com/deepmind/dm-haiku
或者,您可以通过PyPI安装:
$ pip install -U dm-haiku
我们的示例依赖于额外的库(例如bsuite)。您可以使用pip安装所有额外的依赖:
$ pip install -r examples/requirements.txt
在俳句中,所有模块都是hk.Module
的子类。您可以实现任何方法(没有被特殊处理),但通常模块实现__init__
和__call__
。
让我们一起实现一个线性层:
class MyLinear(hk.Module): def __init__(self, output_size, name=None): super().__init__(name=name) self.output_size = output_size def __call__(self, x): j, k = x.shape[-1], self.output_size w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j)) w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init) b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros) return jnp.dot(x, w) + b
所有模块都有一个名称。如果没有传递name
参数,模块名将从Python类名中推断(例如MyLinear
变为my_linear
)。模块可以有通过hk.get_parameter(param_name, ...)
访问的命名参数。我们使用这个API(而不是直接使用对象属性),这样我们可以使用hk.transform
将您的代码转换为纯函数。
使用模块时,您需要定义函数并使用hk.transform
将其转换为一对纯函数。有关转换后函数返回的函数的详细信息,请参见我们的快速开始:
def forward_fn(x): model = MyLinear(10) return model(x) # 将`forward_fn`转换为具有`init`和`apply`方法的对象。 默认情况下, # `apply`将需要一个rng(可以是None),用于 # `hk.next_rng_key`。 forward = hk.transform(forward_fn) x = jnp.ones([1, 1]) # 当我们运行`forward.init`时,俳句将运行`forward_fn(x)`并收集初始参数值。由于参数 # 通常是随机初始化的,因此俳句要求您传递一个随机数生成器键给`init`: key = hk.PRNGSequence(42) params = forward.init(next(key), x) # 当我们运行`forward.apply`时,俳句将运行`forward_fn(x)`并从作为第一个参数传递的`params`中注入参数 # 值。 请注意,通过`hk.transform(f)`转换的模型必须使用额外的 # `rng`参数调用:`forward.apply(params, rng, x)`。 使用 # `hk.without_apply_rng(hk.transform(f))`如果不需要这样做。 y = forward.apply(params, None, x)
一些模型可能需要随机采样作为计算的一部分。例如,在使用重参数化技巧的变分自编码器中,需要从标准正态分布中进行随机采样。对于dropout,我们需要一个随机掩码来从输入中丢弃单元。使其与JAX一起工作的主要障碍在于PRNG键的管理。
在Haiku中,我们提供了一个简单的API来维护与模块关联的PRNG键序列:hk.next_rng_key()
(或对于多个键使用next_rng_keys()
):
class MyDropout(hk.Module): def __init__(self, rate=0.5, name=None): super().__init__(name=name) self.rate = rate def __call__(self, x): key = hk.next_rng_key() p = jax.random.bernoulli(key, 1.0 - self.rate, shape=x.shape) return x * p / (1.0 - self.rate) forward = hk.transform(lambda x: MyDropout()(x)) key1, key2 = jax.random.split(jax.random.PRNGKey(42), 2) params = forward.init(key1, x) prediction = forward.apply(params, key2, x)
要更全面地了解与随机模型一起工作的情况,请参阅我们的VAE示例。
注意: hk.next_rng_key()
不是功能纯的,这意味着你应该避免在hk.transform
内使用它与JAX变换一起使用。欲了解更多信息和可能的解决方法,请查阅Haiku变换的文档和可用的Haiku网络中的JAX变换包装器。
一些模型可能希望维护一些内部的、可变的状态。例如,在批量归一化中,训练过程中遇到的值的移动平均值是维护的。
在Haiku中,我们提供了一个简单的API来维护与模 块关联的可变状态:hk.set_state
和hk.get_state
。使用这些函数时,您需要使用hk.transform_with_state
转换您的函数,因为返回的函数对的签名是不同的:
def forward(x, is_training): net = hk.nets.ResNet50(1000) return net(x, is_training) forward = hk.transform_with_state(forward) # `init`函数现在返回参数和状态。状态包含使用`hk.set_state`创建的任何内容。结构与参数相同(例如,这是一个按模块命名值的映射)。 params, state = forward.init(rng, x, is_training=True) # `apply`函数现在接受参数和状态。此外,它将返回更新的状态值。在resnet示例中,这将是用于批量规范化层中的移动平均值的更新值。 logits, state = forward.apply(params, state, rng, x, is_training=True)
如果你忘记使用hk.transform_with_state
,不要担心,我们会打印一个明确的错误,指向你hk.transform_with_state
,而不是默默地丢弃你的状态。
jax.pmap
进行分布式训练从hk.transform
(或hk.transform_with_state
)返回的纯函数完全兼容jax.pmap
。有关使用jax.pmap
进行SPMD编程的更多详细信息,请查看此处。
在Haiku中使用jax.pmap
的一个常见用途是对许多加速器进行数据并行训练,可能跨多个主机。在Haiku中,这可能如下所示:
def loss_fn(inputs, labels): logits = hk.nets.MLP([8, 4, 2])(x) return jnp.mean(softmax_cross_entropy(logits, labels)) loss_fn_t = hk.transform(loss_fn) loss_fn_t = hk.without_apply_rng(loss_fn_t) # 在单个设备上初始化模型。 rng = jax.random.PRNGKey(428) sample_image, sample_label = next(input_dataset) params = loss_fn_t.init(rng, sample_image, sample_label) # 将参数复制到所有设备上。 num_devices = jax.local_device_count() params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params) def make_superbatch(): """构造一个超级批次,即每个设备一个数据批次""" # 获取N个批次,然后拆分成图像列表和标签列表。 superbatch = [next(input_dataset) for _ in range(num_devices)] superbatch_images, superbatch_labels = zip(*superbatch) # 将超级批次堆叠为一个具有前导维度的数组,而不是一个Python列表。这是`jax.pmap`期望的输入。 superbatch_images = np.stack(superbatch_images) superbatch_labels = np.stack(superbatch_labels) return superbatch_images, superbatch_labels def update(params, inputs, labels, axis_name='i'): """基于输入和标签的表现更新参数。""" grads = jax.grad(loss_fn_t.apply)(params, inputs, labels) # 在所有数据并行副本之间取梯度的平均值。 grads = jax.lax.pmean(grads, axis_name) # 使用SGD或Adam或...更新参数 new_params = my_update_rule(params, grads) return new_params # 进行几次训练更新。 for _ in range(10): superbatch_images, superbatch_labels = make_superbatch() params = jax.pmap(update, axis_name='i')(params, superbatch_images, superbatch_labels)
要更全面地 了解分布式Haiku训练,请查看我们的ImageNet上的ResNet-50示例。
要引用此存储库:
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.10},
year = {2020},
}
在此bibtex条目中,版本号应取自haiku/__init__.py
,年份对应于项目的开源发布年份。
一键生成PPT和Word,让学习生活更轻松
讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。
深度推理能力全新升级,全面对标OpenAI o1
科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。
一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型
Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。
字节跳动发布的AI编程神器IDE
Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。
AI助力,做PPT更简单!
咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。
选题、配图、成文,一站式创作,让内容运营更高效
讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。
专业的AI公文写作平台,公文写作神器
AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。
OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。
openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。
高分辨率纹理 3D 资产生成
Hunyuan3D-2 是腾讯开发的用于 3D 资产生成的强大工具,支持从文本描述、单张图片或多视角图片生成 3D 模型,具备快速形状生成能力,可生成带纹理的高质量 3D 模型,适用于多个领 域,为 3D 创作提供了高效解决方案。
一个具备存储、管理和客户端操作等多种功能的分布式文件系统相关项目。
3FS 是一个功能强大的分布式文件系统项目,涵盖了存储引擎、元数据管理、客户端工具等多个模块。它支持多种文件操作,如创建文件和目录、设置布局等,同时具备高效的事件循环、节点选择和协程池管理等特性。适用于需要大规模数据存储和管理的场景,能够提高系统的性能和可靠性,是分布式存储领域的优质解决方案。
最新AI工具、AI资讯
独家AI资源、AI项目落地
微信扫一扫关注公众号