aw_nas

aw_nas

模块化设计实现多种NAS算法

aw_nas是一个模块化的神经架构搜索框架,实现了ENAS、DARTS等多种主流NAS算法。框架将NAS系统分解为搜索空间、控制器等组件,通过接口实现灵活组合。支持分类、检测等多种应用场景,并提供硬件分析接口。aw_nas采用插件机制便于扩展,已应用于容错性、对抗鲁棒性等研究方向。

NAS框架神经架构搜索模块化可扩展硬件相关Github开源项目

aw_nas:模块化可扩展的神经架构搜索框架

<p align="middle"> <img src="https://yellow-cdn.veclightyear.com/0a4dffa0/f35cd8d4-b16b-45c6-9e89-8f68f09d31c7.jpg" width="35%" hspace="30" /> <img src="https://yellow-cdn.veclightyear.com/0a4dffa0/68373265-9a2a-4230-bbaf-efb76f26db74.jpg" width="35%" hspace="20" /> </p>

清华大学NICS-EFC实验室北京诺瓦奥图科技有限公司维护。

简介

神经架构搜索(NAS)因其能够以自动化方式发现神经网络架构而受到广泛关注。aw_nas是一个以模块化方式实现各种NAS算法的框架。目前,aw_nas可用于重现许多主流NAS算法的结果,如ENAS、DARTS、SNAS、FBNet、OFA、基于预测器的NAS等。我们已将aw_nas应用于各种应用场景,包括用于分类、检测、文本建模、硬件容错、对抗鲁棒性、硬件推理效率等的NAS。

此外,硬件相关的性能分析和解析接口设计得通用且易用。aw_nas还提供了多种硬件的延迟表和一些校正模型。详情请参见硬件相关

欢迎各种贡献,包括新的NAS组件实现、新的NAS应用、错误修复、文档等。

NAS系统的组成部分

NAS系统中有多个相互协作的参与者,可以分为以下几个组成部分:

  • 搜索空间
  • 控制器
  • 权重管理器
  • 评估器
  • 目标函数

这些组件之间的接口是明确定义的。我们使用awnas.rollout.base.BaseRollout类来表示所有这些组件之间的接口对象。通常,一个搜索空间定义一个或多个rollout类型(BaseRollout的子类)。例如,基本的基于单元的搜索空间cnnawnas.common.CNNSearchSpace类)对应两种rollout类型:discrete离散rollout,用于基于强化学习、进化算法的控制器等(awnas.rollout.base.Rollout类);differentiable可微rollout,用于基于梯度的NAS(awnas.rollout.base.DifferentiableRollout类)。

NAS框架

这是NAS流程和相应方法调用的图示。这里是aw_nas的简要技术概述,包括一些复现结果和硬件成本预测模型的描述。该技术概述也可在arXiv上获取(GitHub/ArXiv版本可能略有不同)。

安装

建议使用虚拟Python环境。例如,使用Anaconda,你可以先运行conda create -n awnas python==3.7.3 pip

  • 支持的Python版本:2.7、3.6、3.7
  • 支持的PyTorch版本:>=1.0.0,<1.5.0(目前,DataParallel复制中的一些补丁在1.5.0之后不兼容)

要安装awnas,运行pip install -r requirements.txt。如果你不想安装检测相关的额外内容(运行在VOC/COCO检测数据集上搜索时需要),在安装时省略",det"额外内容(参见requirements文件的最后一行)。注意,对于RTX 3090,requirements.txt中的torch==1.2.0不再适用:使用torch会导致永久卡住。请查看requirements.cu110.txt中的注释。

架构绘图依赖于graphviz包,确保安装了graphviz,例如在Ubuntu上,你可以运行sudo apt-get install graphviz

使用

安装后,你可以运行awnas --help查看可用的子命令。

示例运行输出(版本0.3.dev3):

07/04 11:41:44 PM plugin              INFO: Check plugins under /home/foxfi/awnas/plugins
07/04 11:41:44 PM plugin              INFO: Loaded plugins:
Usage: awnas [OPTIONS] COMMAND [ARGS]...

  awnas NAS框架命令行接口。使用`AWNAS_LOG_LEVEL`环境变量修改日志级别。

Options:
  --version             显示版本并退出
  --local_rank INTEGER  此进程的等级  [默认:-1]
  --help                显示此消息并退出

Commands:
  search                   搜索架构
  mpsearch                 多进程搜索架构
  random-sample            随机采样架构
  sample                   采样架构,加载pickle控制器
  eval-arch                从文件评估架构
  derive                   派生架构
  mptrain                  多进程最终训练架构
  train                    训练一个架构
  test                     测试最终训练的模型
  gen-sample-config        导出采样配置
  gen-final-sample-config  导出最终训练的采样配置
  registry                 打印注册信息

准备数据

运行awnas程序时,它会假设名为<NAME>的数据集位于AWNAS_DATA/<NAME>下,其中AWNAS_DATA基础目录从环境变量AWNAS_DATA中读取。如果未指定环境变量,默认为AWNAS_HOME/data,其中AWNAS_HOME是默认为~/awnas的环境变量。

  • Cifar-10/Cifar-100:无需特殊准备。
  • PTB:执行bash scripts/get_data.sh ptb,PTB数据将下载到${DATA_BASE}/ptb目录下。默认情况下${DATA_BASE}~/awnas/data
  • Tiny-ImageNet:执行bash scripts/get_data.sh tiny-imagenet,Tiny-ImageNet数据将下载到${DATA_BASE}/tiny-imagenet目录下。
  • 目标检测数据集VOC/COCO:执行bash scripts/get_data.sh vocbash scripts/get_data.sh coco

运行NAS搜索

ENAS 尝试运行ENAS [Pham et. al., ICML 2018]搜索(结果包括配置备份、搜索日志,保存在<TRAIN_DIR>中):

awnas search examples/basic/enas.yaml --gpu 0 --save-every <SAVE_EVERY> --train-dir <TRAIN_DIR>

配置文件中包含几个部分,描述了NAS框架中不同组件的配置。例如,在example/basic/enas.yaml中,不同的配置部分组织如下:

  1. 基于单元的CNN搜索空间:这是原始ENAS论文中5个原语微搜索空间的扩展版本。
  2. cifar-10数据集
  3. 使用embed_lstm RNN网络的RL学习控制器
  4. 基于共享权重的评估器
  5. 基于共享权重的权重管理器:超网络
  6. 分类目标
  7. 训练器:整体NAS搜索流程的编排

有关ENAS搜索配置的详细分解,请参阅配置说明

DARTS 此外,你可以通过运行以下命令来执行DARTS [Liu et. al., ICLR 2018]搜索的改进版本:

awnas search examples/basic/darts.yaml --gpu 0 --save-every <SAVE_EVERY> --train-dir <TRAIN_DIR>

我们在这里提供了组件和流程的详细说明。请注意,该配置与原始DARTS略有不同:1) entropy_coeff: 0.01:使用0.01的熵正则化系数,鼓励操作分布更接近于one-hot;2) use_prob: false:使用Gumbel-softmax采样,而不是直接使用概率。

结果复现 关于各种流行方法的精确结果复现,请参阅examples/mloss/下的文档、配置和结果。

生成样例搜索配置

要生成用于搜索的样例配置文件,可以尝试使用awnas gen-sample-config工具。例如,如果你想要一个用于在NAS-Bench-101上搜索的样例配置,运行:

awnas gen-sample-config -r nasbench-101 -d image ./sample_nb101.yaml

然后,检查sample_nb101.yaml文件,对于每种组件类型,所有声明支持nasbench-101展开类型的类都会列在文件中。删除不需要的,取消注释需要的,更改默认设置,然后该配置就可以用于在NAS-Bench-101上运行NAS。

导出与评估架构

awnas derive工具使用训练好的NAS组件采样架构。如果--test标志关闭(默认),只加载控制器来采样展开;否则,还会加载权重管理器和训练器来测试这些展开,并根据性能对采样的基因型进行排序,保存在输出文件中。

示例运行是采样10个基因型,并将它们保存到sampled_genotypes.yaml中。

awnas derive search_cfg.yaml --load <awnas搜索期间保存的检查点目录> -o sampled_genotypes.yaml -n 10 --test --gpu 0 --seed 123

注意,<TRAIN_DIR>/<EPOCH>/文件夹中的"controller/evaluator/trainer"文件包含组件的状态字典,可以加载(每<SAVE_EVERY>个周期保存一次),而"<TRAIN_DIR>/final/"文件夹中的最终检查点"controller.pt/evaluator.pt"包含整个组件对象的pickle,不能直接加载。如果你忘记指定--save-every命令行参数而没有获得状态字典检查点,你可以加载最终检查点,然后通过cd <TRAIN_DIR>/final/; python -c "controller = torch.load('./controller.pt'); controller.save('controller')"导出所需的状态字典检查点。

awnas eval-arch工具使用训练好的NAS组件评估基因型。给定一个包含基因型列表的yaml文件,可以使用保存的NAS检查点评估这些基因型:

awnas eval-arch search_cfg.yaml sampled_genotypes.yaml --load <awnas搜索期间保存的检查点目录> --gpu 0 --seed 123

基于单元架构的最终训练

awnas.final 子包提供了基于单元的架构的最终训练功能。examples/basic/final_templates/final_template.yaml 是一个常用的配置模板,用于在类 ENAS 搜索空间中进行架构的最终训练。要使用该模板,请在 final_model_cfg.genotypes 字段中填入从搜索过程中得到的基因型字符串。基因型字符串示例如下:

CNNGenotype(normal_0=[('dil_conv_3x3', 1, 2), ('skip_connect', 1, 2), ('sep_conv_3x3', 0, 3), ('sep_conv_3x3', 2, 3), ('skip_connect', 3, 4), ('sep_conv_3x3', 0, 4), ('sep_conv_5x5', 1, 5), ('sep_conv_5x5', 0, 5)], reduce_1=[('max_pool_3x3', 0, 2), ('dil_conv_5x5', 0, 2), ('avg_pool_3x3', 1, 3), ('avg_pool_3x3', 2, 3), ('sep_conv_5x5', 1, 4), ('avg_pool_3x3', 1, 4), ('sep_conv_3x3', 1, 5), ('dil_conv_5x5', 3, 5)], normal_0_concat=[2, 3, 4, 5], reduce_1_concat=[2, 3, 4, 5])

插件机制

aw_nas 提供了一个简单的插件机制,支持在包外添加额外组件或扩展现有组件。在初始化过程中,~/awnas/plugins/ 目录下的所有 Python 脚本(文件名以 .py 结尾,不包括以 test_ 开头的文件)都会被导入。因此,这些文件中定义的组件将自动注册。

例如,为了复现 FBNet [Wu et. al., CVPR 2019],我们在 examples/plugins/fbnet/fbnet_plugin.py 中添加了 FBNet 原始块的实现,并使用 aw_nas.ops.register_primitive 注册这些原始操作。为了重用 DiffSuperNet 实现的大部分代码(用于 DARTS [Liu et. al., ICLR 2018]、SNAS [Xie et. al., ICLR 2018] 等),我们创建了一个继承自 DiffSuperNetWeightInitDiffSuperNet 类,唯一的区别是添加了一个为 FBNet 量身定制的权重初始化。此外,还实现了一个 LatencyObjective 目标函数,它将延迟损失和交叉熵损失的加权和作为损失计算。

examples/plugins/robustness 目录下是用于实现对抗鲁棒性神经架构搜索的插件模块。例如,定义了各种用于评估对抗鲁棒性的目标函数。由于密集连接是对抗鲁棒性的一个重要特性,而 ENAS/DARTS 搜索空间将节点输入度限制为小于或等于 2,因此定义了一个具有可变节点输入度的新搜索空间。实现了几个具有对抗样本缓存的超网络(weights_manager),以避免多次为同一子网络重新生成对抗样本。

除了定义新组件外,你还可以使用这种机制来进行猴子补丁技巧。例如,在 examples/research/ftt-nas/fixed_point_plugins/ 下有各种定点插件。在这些插件中,诸如 nn.Conv2dnn.Linear 等原始操作被修补为具有量化和故障注入功能的模块。

硬件相关:硬件分析和解析

有关硬件分析和解析的流程和示例,请参阅 Hardware related

开发新组件

有关开发新组件的指南,请参阅 Develop New Components

研究

本代码库与以下研究相关(*: 贡献相同; ^: 共同通讯作者)

更多详情请参见examples/research/下的子目录。

如果您发现本代码库有帮助,可以引用以下研究:

@misc{ning2020awnas,
      title={aw_nas: A Modularized and Extensible NAS framework},
      author={Xuefei Ning and Changcheng Tang and Wenshuo Li and Songyi Yang and Tianchen Zhao and Niansong Zhang and Tianyi Lu and Shuang Liang and Huazhong Yang and Yu Wang},
      year={2020},
      eprint={2012.10388},
      archivePrefix={arXiv},
      primaryClass={cs.NE}
}

参考文献

  • FBNet Wu, Bichen等人。"FBNet:通过可微分神经架构搜索进行硬件感知的高效卷积网络设计"。发表于IEEE计算机视觉与模式识别会议论文集,第10734-10742页。2019年。
  • ENAS Pham, Hieu等人。"通过参数共享实现高效神经架构搜索"。发表于国际机器学习会议,第4095-4104页。2018年。
  • DARTS Liu, Hanxiao等人。"DARTS:可微分架构搜索"。发表于国际学习表示会议。2018年。
  • SNAS Xie, Sirui等人。"SNAS:随机神经架构搜索"。发表于国际学习表示会议。2018年。
  • OFA Cai, Han等人。"一劳永逸:训练一个网络并针对高效部署进行专门化"。发表于国际学习表示会议。2019年。

单元测试

覆盖率百分比(版本0.4.0-dev1)

运行pytest -x ./tests来执行单元测试。

NAS-Bench-101和NAS-Bench-201的测试默认被跳过,设置AWNAS_TEST_NASBENCH环境变量并运行pytest来执行这些测试:AWNAS_TEST_NASBENCH=1 pytest -x ./tests/test_nasbench*。还有一些其他测试由于可能非常耗时而被跳过(参见测试输出(标记为"s")和tests/下的测试用例)。

联系我们

  • 如有技术问题或改进建议,请在Github上提交问题,我们是一个小团队,但会尽最大努力及时回复。
  • 如果想讨论NAS或高效深度学习,请通过foxdoraame@gmail.com(宁学飞)和yu-wang@tsinghua.edu.cn(王玉)联系我们。
  • 我们的团队正在招募访问学生和工程师,如果您感兴趣,请查看我们网站上的信息。

编辑推荐精选

Trae

Trae

字节跳动发布的AI编程神器IDE

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
问小白

问小白

全能AI智能助手,随时解答生活与工作的多样问题

问小白,由元石科技研发的AI智能助手,快速准确地解答各种生活和工作问题,包括但不限于搜索、规划和社交互动,帮助用户在日常生活中提高效率,轻松管理个人事务。

热门AI助手AI对话AI工具聊天机器人
Transly

Transly

实时语音翻译/同声传译工具

Transly是一个多场景的AI大语言模型驱动的同声传译、专业翻译助手,它拥有超精准的音频识别翻译能力,几乎零延迟的使用体验和支持多国语言可以让你带它走遍全球,无论你是留学生、商务人士、韩剧美剧爱好者,还是出国游玩、多国会议、跨国追星等等,都可以满足你所有需要同传的场景需求,线上线下通用,扫除语言障碍,让全世界的语言交流不再有国界。

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

AI办公办公工具AI工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图热门
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

热门AI开发模型训练AI工具讯飞星火大模型智能问答内容创作多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

热门AI辅助写作AI工具讯飞绘文内容运营AI创作个性化文章多平台分发AI助手
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

openai-agents-python

openai-agents-python

OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。

openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。

下拉加载更多