该存储库包含 OpenAI 的 Whisper 模型的优化 JAX代码,主要基于 Hugging Face Transformers Whisper 实现。与 OpenAI的 PyTorch 代码相比,Whisper JAX 的运行速度快了70 倍以上,使其成为可用的最快的 Whisper 运行。
JAX 代码在 CPU、GPU 和 TPU 上兼容,并且可以独立运行或作为推理端点。
Whisper JAX 使用 Python 3.9 和 JAX 版本 0.4.5 进行了测试。安装假定您的设备上已安装最新版本的 JAX 包。您可以使用官方 JAX 安装指南来执行此操作:
https://github.com/google/jax#installation
一旦安装了适当版本的 JAX,就可以通过 pip 安装 Whisper JAX
pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
要将 Whisper JAX 包更新到最新版本,只需运行:
pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
运行 Whisper JAX 的推荐方式是通过FlaxWhisperPipline抽象类。此类处理所有必要的预处理和后处理,并包装生成方法以实现跨加速器设备的数据并行性。
Whisper JAX 利用 JAX 的pmap功能实现跨 GPU/TPU 设备的数据并行性。该函数在第一次调用时进行即时 (JIT)编译。此后,该函数将被缓存,使其能够以超快的速度运行:
from whisper_jax import FlaxWhisperPipline # instantiate pipeline pipeline = FlaxWhisperPipline("openai/whisper-large-v2") # JIT compile the forward call - slow, but we only do once text = pipeline("audio.mp3") # used cached function thereafter - super fast!! text = pipeline("audio.mp3")
通过在实例化管道时传递 dtype 参数,可以以半精度运行模型计算。通过以半精度存储中间张量,这将大大加快计算速度。模型权重的精度没有变化。
对于大多数 GPU,dtype 应设置为jnp.float16. 对于 A100 GPU 或 TPU,dtype 应设置为jnp.bfloat16:
from whisper_jax import FlaxWhisperPipline import jax.numpy as jnp # instantiate pipeline in bfloat16 pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
Whisper JAX 还提供跨加速器设备批量处理单个音频输入的选项。音频首先被分成 30 秒的片段,然后将片段分派到模型进行并行转录。所得到的转录在边界处缝合在一起以给出单个、统一的转录。实际上,如果选择的批处理大小足够大,则与顺序转录音频样本相比,批处理提供了 10 倍的加速,并且 WER 1的损失不到 1%。
要启用批处理,请batch_size在实例化管道时传递参数:
from whisper_jax import FlaxWhisperPipline # instantiate pipeline with batching pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16)
默认情况下,管道以所说的语言转录音频文件。对于语音翻译,请将参数设置 task为"translate":
# translate text = pipeline("audio.mp3", task="translate")
FlaxWhisperPipline还支持时间戳预测。请注意,启用时间戳将需要对前向调用进行第二次 JIT 编译,包括时间戳输出:
# transcribe and return timestamps outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True) text = outputs["text"] # transcription chunks = outputs["chunks"] # transcription + timestamps
更高级的用户可能希望探索不同的并行化技术。Whisper JAX 代码构建在T5x 代码库之上,这意味着它可以使用 T5x 分区约定的模型、激活和数据并行性来运行。要使用 T5x 分区,必须定义逻辑轴规则和模型分区数量。更多详细信息,用户可参考官方T5x分区指南:https://github.com/google-research/t5x/blob/main/docs/usage/partitioning.md
我们将 Whisper JAX 与官方OpenAI 实现和🤗 Transformers 实现进行比较。我们在长度不断增加的音频样本上对模型进行基准测试,并报告 10 次重复运行的平均推理时间(以秒为 单位)。对于所有三个系统,我们将预加载的音频文件传递给模型并测量前向传递的时间。将加载音频文件的任务留给系统会增加所有基准时间的相等偏移,因此加载和转录音频文件的实际时间将高于报告的数字。
OpenAI 和 Transformers 都在 GPU 上的 PyTorch 中运行。Whisper JAX 在 GPU 和 TPU 上的 JAX 中运行。OpenAI 按照说话的顺序依次转录音频。Transformers 和 Whisper JAX 都使用批处理算法,其中音频块被一起批处理并并行转录(请参阅批处理部分)。
表 :长度增加的音频文件的平均推理时间(以秒为单位)。GPU设备是单个A100 40GB GPU。TPU 设备是单个 TPU v4-8。



阿里Qoder团队推出的桌面端AI智能体
QoderWork 是阿里推出的本地优先桌面 AI 智能体,适配 macOS14+/Windows10+,以自然语言交互实现文件管理、数据分析、AI 视觉生成、浏览器自动化等办公任务,自主拆解执行复杂工作流,数据本地运行零上传,技能市场可无限扩展,是高效的 Agentic 生产力办公助手。


全球首个AI音乐 社区
音述AI是全球首个AI音乐社区,致力让每个人都能用音乐表达自我。音述AI提供零门槛AI创作工具,独创GETI法则帮助用户精准定义音乐风格,AI润色功能支持自动优化作品质感。音述AI支持交流讨论、二次创作与价值变现。针对中文用户的语言习惯与文化背景进行专门优化,支持国风融合、C-pop等本土音乐标签,让技术更好地承载人文表达。


一站式搞定所有学习需求
不再被海量信息淹没,开始真正理解知识。Lynote 可摘要 YouTube 视频、PDF、文章等内容。即时创建笔记,检测 AI 内容并下载资料,将您的学习效率提升 10 倍。


为AI短剧协作而生
专为AI短剧协作而生的AniShort正式发布,深度重构AI短剧全流程生产模式,整合创意策划、制作执行、实时协作、在线审片、资产复用等全链路功能,独创无限画布、双轨并行工业化工作流与Ani智能体助手,集成多款主流AI大模型,破解素材零散、版本混乱、沟通低效等行业痛点,助力3人团队效率提升800%,打造标准化、可追溯的AI短剧量产体系,是AI短剧团队协同创作、提升制作效率的核心工具。


能听懂你表达的视频模型
Seedance two是基于seedance2.0的中国大模型,支持图像、视频、音频、文本四种模态输入,表达方式更丰富,生成也更可控。


国内直接访问,限时3折
输入简单文字,生成想要的图片,纳米香蕉中文站基于 Google 模型的 AI 图片生成网站,支持文字生图、图生图。官网价格限时3折活动


职场AI,就用扣子
AI办公助手,复杂任务高效处理。办公效率低?扣子空间AI助手支持播客生成、PPT制作、网页开发及报告写作,覆盖科研、商业、舆情等领域的专家Agent 7x24小时响应,生活工作无缝切换,提升50%效率!


多风格AI绘画神器
堆友平台由阿里巴巴设计团队创建,作为一款AI驱动的设计工具,专为设计师提供一站式增长服务。功能覆盖海量3D素材、AI绘画、实时渲染以及专业抠图,显著提升设计品质和效率。平台不仅提供工具,还是一个促进创意交流和个人发展的空间,界面友好,适合所有级别的设计师和创意工作者。


零代码AI应用开发平台
零代码AI应用开发平台,用户只需一句话简单描述需求,AI能自动生成小程序、APP或H5网页应用,无需编写代码。


免费创建高清无水印Sora视频
Vora是一个免费创建高清无水印Sora视频的AI工具
最新AI工具、AI资讯
独家AI资源、AI项目落地

微信扫一扫关注公众号