项目简介
该存储库包含 OpenAI 的 Whisper 模型的优化 JAX代码,主要基于 Hugging Face Transformers Whisper 实现。与 OpenAI的 PyTorch 代码相比,Whisper JAX 的运行速度快了70 倍以上,使其成为可用的最快的 Whisper 运行。
JAX 代码在 CPU、GPU 和 TPU 上兼容,并且可以独立运行或作为推理端点。
安装
Whisper JAX 使用 Python 3.9 和 JAX 版本 0.4.5 进行了测试。安装假定您的设备上已安装最新版本的 JAX 包。您可以使用官方 JAX 安装指南来执行此操作:
https://github.com/google/jax#installation
一旦安装了适当版本的 JAX,就可以通过 pip 安装 Whisper JAX
pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
要将 Whisper JAX 包更新到最新版本,只需运行:
pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
管道使用
运行 Whisper JAX 的推荐方式是通过FlaxWhisperPipline抽象类。此类处理所有必要的预处理和后处理,并包装生成方法以实现跨加速器设备的数据并行性。
Whisper JAX 利用 JAX 的pmap功能实现跨 GPU/TPU 设备的数据并行性。该函数在第一次调用时进行即时 (JIT)编译。此后,该函数将被缓存,使其能够以超快的速度运行:
from whisper_jax import FlaxWhisperPipline
# instantiate pipeline
pipeline = FlaxWhisperPipline("openai/whisper-large-v2")
# JIT compile the forward call - slow, but we only do once
text = pipeline("audio.mp3")
# used cached function thereafter - super fast!!
text = pipeline("audio.mp3")
半精度
通过在实例化管道时传递 dtype 参数,可以以半精度运行模型计算。通过以半精度存储中间张量,这将大大加快计算速度。模型权重的精度没有变化。
对于大多数 GPU,dtype 应设置为jnp.float16. 对于 A100 GPU 或 TPU,dtype 应设置为jnp.bfloat16:
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
# instantiate pipeline in bfloat16
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
批处理
Whisper JAX 还提供跨加速器设备批量处理单个音频输入的选项。音频首先被分成 30 秒的片段,然后将片段分派到模型进行并行转录。所得到的转录在边界处缝合在一起以给出单个、统一的转录。实际上,如果选择的批处理大小足够大,则与顺序转录音频样本相比,批处理提供了 10 倍的加速,并且 WER 1的损失不到 1%。
要启用批处理,请batch_size在实例化管道时传递参数:
from whisper_jax import FlaxWhisperPipline
# instantiate pipeline with batching
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16)
任务
默认情况下,管道以所说的语言转录音频文件。对于语音翻译,请将参数设置 task为"translate":
# translate
text = pipeline("audio.mp3", task="translate")
时间戳
FlaxWhisperPipline还支持时间戳预测。请注意,启用时间戳将需要对前向调用进行第二次 JIT 编译,包括时间戳输出:
# transcribe and return timestamps
outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
text = outputs["text"] # transcription
chunks = outputs["chunks"] # transcription + timestamps
高级用法
更高级的用户可能希望探索不同的并行化技术。Whisper JAX 代码构建在T5x 代码库之上,这意味着它可以使用 T5x 分区约定的模型、激活和数据并行性来运行。要使用 T5x 分区,必须定义逻辑轴规则和模型分区数量。更多详细信息,用户可参考官方T5x分区指南:https://github.com/google-research/t5x/blob/main/docs/usage/partitioning.md
基准测试
我们将 Whisper JAX 与官方OpenAI 实现和🤗 Transformers 实现进行比较。我们在长度不断增加的音频样本上对模型进行基准测试,并报告 10 次重复运行的平均推理时间(以秒为单位)。对于所有三个系统,我们将预加载的音频文件传递给模型并测量前向传递的时间。将加载音频文件的任务留给系统会增加所有基准时间的相等偏移,因此加载和转录音频文件的实际时间将高于报告的数字。
OpenAI 和 Transformers 都在 GPU 上的 PyTorch 中运行。Whisper JAX 在 GPU 和 TPU 上的 JAX 中运行。OpenAI 按照说话的顺序依次转录音频。Transformers 和 Whisper JAX 都使用批处理算法,其中音频块被一起批处理并并行转录(请参阅批处理部分)。
表 :长度增加的音频文件的平均推理时间(以秒为单位)。GPU设备是单个A100 40GB GPU。TPU 设备是单个 TPU v4-8。