这是使用Patchout高效训练音频Transformer的实现
Patchout显著减少了在音频频谱图上训练Transformer所需的训练时间和GPU内存需求,同时提高了它们的性能。
<p align="center"><img src="https://yellow-cdn.veclightyear.com/0a4dffa0/2662410b-bd7b-4ff7-a79c-f2bf03186b40.png?raw=true" width="600"/></p>Patchout通过在训练过程中丢弃一些输入块来实现。 可以是非结构化方式(随机丢弃,类似于dropout), 也可以是提取的块的整个时间帧或频率bin(类似于SpecAugment), 这对应于下图步骤3中的行/列。
<p align="center"><img src="https://yellow-cdn.veclightyear.com/0a4dffa0/2b3fcda0-f351-4216-a0b5-20f302ce3050.png?raw=true" width="600"/></p>如果你只想使用预训练模型生成的嵌入,使用自己的微调框架,或只需要进行推理,可以在这里找到该repo的精简版本。 该包遵循HEAR 2021 NeurIPS Challenge API,可以通过以下方式安装:
pip install hear21passt
这个repo是一个完整的框架,用于训练模型并在下游任务上微调Audioset的预训练模型。
from hear21passt.base import get_basic_model,get_model_passt import torch # 获取PaSST模型包装器,包括Melspectrogram和默认预训练的transformer model = get_basic_model(mode="logits") print(model.mel) # 从原始波形中提取mel频谱图 print(model.net) # transformer网络 # 推理示例 model.eval() model = model.cuda() with torch.no_grad(): # audio_wave的形状为[batch, seconds*32000],采样率为32k # 示例:batch=3,10秒的音频 audio = torch.ones((3, 32000 * 10))*0.5 audio_wave = audio.cuda() logits=model(audio_wave)
from hear21passt.base import get_basic_model,get_model_passt import torch # 获取PaSST模型包装器,包括Melspectrogram和默认预训练的transformer model = get_basic_model(mode="logits") print(model.mel) # 从原始波形中提取mel频谱图 # 可选:将transformer替换为具有所需类别数的transformer,例如50个类别 model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50) print(model.net) # transformer网络 # 现在model包含mel + 预训练的transformer模型,可以进行微调 # 它仍然期望输入的形状为[batch, seconds*32000],采样率为32k model.train() model = model.cuda()
如果你想使用与论文中相同的环境,可以按照以下说明进行操作。
对于从头开始训练模型或使用与论文相同的设置进行微调:
conda create -n passt python=3.8 conda activate passt
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt
或者,你可以使用导出的conda环境environment.yml
来创建环境。
建议使用Mamba进行设置,因为它比conda
工作得更快:
conda install mamba -n base -c conda-forge
现在你可以从environment.yml
导入环境
mamba env create -f environment.yml
现在你有了一个名为ba3l
的环境。
为了检查你的环境是否与我们在运行中使用的环境匹配,请查看environment.yml
和pip_list.txt
文件,这些文件是使用以下命令导出的:
conda env export --no-builds | grep -v "prefix" > environment.yml pip list > pip_list.txt
如果你想使用自己的设置,只使用这个repo中的模型,你可以如上所述用于推理和提取嵌入的预训练模型获取模型,从头开始训练或在自己的数据集上进行微调。本节的其余部分解释了如何使用这个repo来训练和微调模型。为此,首先需要按照上面的说明设置开发环境。
该存储库使用sacred进行实验管理和配置, 使用pytorch-lightning进行训练,使用wandb进行日志记录。
每个数据集都有一个主要的实验文件,如ex_audioset.py
和ex_openmic.py
,以及一个数据集文件夹。实验文件包含主要的训练和验证逻辑。数据集文件夹包含下载、预处理和加载数据集进行训练所需的特定代码。
通常,你可以通过以下命令获取实验文件的帮助信息,这将打印可用的命令和基本选项:
python ex_audioset.py help
每个实验都有一组默认的配置选项,定义在实验文件中,例如ex_audioset.py
。你可以使用sacred语法覆盖任何配置。你可以使用print_config
命令打印配置值,而不训练模型:
python ex_audioset.py print_config
然后你可以使用命令行界面覆盖任何配置选项(sacred语法),使用with
,例如:
python ex_audioset.py with trainer.precision=16
这将使用16位精度在Audioset上进行训练。
整体配置如下所示:
... seed = 542198583 # 此实验的随机种子 slurm_job_id = '' speed_test_batch_size = 100 swa = True swa_epoch_start = 50 swa_freq = 5 use_mixup = True warm_up_len = 5 weight_decay = 0.0001 basedataset: base_dir = 'audioset_hdf5s/' # 数据集的基础目录 ,更改它或创建一个链接 eval_hdf5 = 'audioset_hdf5s/mp3/eval_segments_mp3.hdf' wavmix = 1 .... roll_conf: axis = 1 shift = None shift_range = 50 datasets: test: batch_size = 20 dataset = {CMD!}'/basedataset.get_test_set' num_workers = 16 validate = True training: batch_size = 12 dataset = {CMD!}'/basedataset.get_full_training_set' num_workers = 16 sampler = {CMD!}'/basedataset.get_ft_weighted_sampler' shuffle = None train = True models: mel: freqm = 48 timem = 192 hopsize = 320 htk = False n_fft = 1024 n_mels = 128 norm = 1 sr = 32000 ... net: arch = 'passt_s_swa_p16_128_ap476' fstride = 10 in_channels = 1 input_fdim = 128 input_tdim = 998 n_classes = 527 s_patchout_f = 4 s_patchout_t = 40 tstride = 10 u_patchout = 0 ... trainer: accelerator = None accumulate_grad_batches = 1 amp_backend = 'native' amp_level = 'O2' auto_lr_find = False auto_scale_batch_size = False ...
有很多内容可以从命令行更新。 简而言之:
trainer
下的所有配置选项都是pytorch lightning训练器的api。例如,要关闭cuda基准测试,在命令行中添加trainer.benchmark=False
。wandb
是wandb配置。例如,要更改wandb项目,在命 令行中添加wandb.project="test_project"
。models.net
是PaSST(或选择的神经网络)的选项。例如:models.net.u_patchout
、models.net.s_patchout_f
、models.net.s_patchout_t
控制非结构化patchout和在频率和时间上的结构化patchout。input_fdim
和input_tdim
是输入频谱图在频率和时间上的维度。models.net.fstride
和models.net.tstride
是输入patches在频率和时间上的步幅,将这些设置为16意味着没有patch重叠。models.mel
是预处理选项(梅尔频谱图)。mel.sr
是采样率,mel.hopsize
是STFT窗口的跳跃大小,mel.n_mels
是梅尔滤波器组的数量,mel.freqm
和mel.timem
是spec-augment的频率和时间掩蔽参数。在config_updates.py
中有许多预定义的配置包(称为named_configs)。这些包括不同的模型、设置等...
你可以使用以下命令列出这些配置:
python ex_audioset.py print_named_configs
例如,passt_s_20sec
是一个配置包,它将模型设置为在Audioset上预训练的PaSST-S,并接受最长20秒的片段。
按照audioset页面中的说明下载和准备数据集。
例如,可以这样训练基础PaSST模型:
python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p
例如,仅使用400的非结构化patchout:
python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 models.net.u_patchout=400 models.net.s_patchout_f=0 models.net.s_patchout_t=0 -p
通过设置环境变量DDP
可以启用多GPU训练,例如使用2个GPU:
DDP=2 python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -c "PaSST base 2 GPU"
请查看发布页面以下载预训练模型。 通常,你可以获取在Audioset上预训练的模型
from models.passt import get_model model = get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527, in_channels=1, fstride=10, tstride=10,input_fdim=128, input_tdim=998, u_patchout=0, s_patchout_t=40, s_patchout_f=4)
这将自动下载在 Audioset 上预训练的 PaSST 模型,其 mAP 为 0.476
。该模型使用 s_patchout_t=40, s_patchout_f=4
进行训练,但你可以根据你的任务或计算需求来调整这些参数。
有几个可用的预训练模型,它们具有不同的步长(重叠)并且使用/不使用 SWA:passt_s_p16_s16_128_ap468, passt_s_swa_p16_s16_128_ap473, passt_s_swa_p16_s14_128_ap471, passt_s_p16_s14_128_ap469, passt_s_swa_p16_s12_128_ap473, passt_s_p16_s12_128_ap470
。
例如,在 passt_s_swa_p16_s16_128_ap473
中:p16
表示补丁大小为 16x16
,s16
表示无重叠(步长=16),128 个梅尔频带,ap473
指的是该模型在 Audioset 上的性能 mAP=0.479。
通常,你可以使用以下方式获取预训练模型:
from models.passt import get_model passt = get_model(arch="passt_s_swa_p16_s16_128_ap473", fstride=16, tstride=16)
使用该框架,你可以通过以下方式评估此模型:
python ex_audioset.py evaluate_only with trainer.precision=16 passt_s_swa_p16_s16_128_ap473 -p
这些模型的集成也已提供:
一个大型集成模型,其 mAP=.4956
python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
一个由 stride=14
和 stride=16
的两个模型组成的集成,其 mAP=.4858
python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s16_14
还有其他集成模型 ensemble_4
,ensemble_5
引用 Interspeech 2022 接受的论文:
@inproceedings{koutini22passt, author = {Khaled Koutini and Jan Schl{\"{u}}ter and Hamid Eghbal{-}zadeh and Gerhard Widmer}, title = {Efficient Training of Audio Transformers with Patchout}, booktitle = {Interspeech 2022, 23rd Annual Conference of the International Speech Communication Association, Incheon, Korea, 18-22 September 2022}, pages = {2753--2757}, publisher = {{ISCA}}, year = {2022}, url = {https://doi.org/10.21437/Interspeech.2022-227}, doi = {10.21437/Interspeech.2022-227}, }
该仓库将会更新,同时如果 你有任何问题或遇到任何问题,请随时在 GitHub 上开启一个 issue,或直接联系作者。
字节跳动发布的AI编程神器IDE
Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升 开发效率的理想工具。
全能AI智能助手,随时解答生活与工作的多样问题
问小白,由元石科技研发的AI智能助手,快速准确地解答各种生活和工作问题,包括但不限于搜索、规划和社交互动,帮助用户在日常生活 中提高效率,轻松管理个人事务。
实时语音翻译/同声传译工具
Transly是一个多场景的AI大语言模型驱动的同声传译、专业翻译助手,它拥有超精准的音频识别翻译能力,几乎零延迟的使用体验和支持多国语言可以让你带它走遍全球,无论你是留学生、商务人士、韩剧美剧爱好者,还是出国游玩、多国会议、跨国追星等等,都可以满足你所有 需要同传的场景需求,线上线下通用,扫除语言障碍,让全世界的语言交流不再有国界。
一键生成PPT和Word,让学习生活更轻松
讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。
深度推理能力全新升级,全面对标OpenAI o1
科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科 研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。
一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型
Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。
AI助力,做PPT更简单!
咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导 入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。
选题、配图、成文,一站式创作,让内容运营更高效
讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。
专业的AI公文写作平台,公文写作神器
AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。
OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。
openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。
最新AI工具、AI资讯
独家AI资源、AI项目落地
微信扫一扫关注公众号