torch-imle

torch-imle

将离散优化算法融入深度学习的创新方法

torch-imle是一个PyTorch库,通过I-MLE梯度估计器将离散优化算法融入深度学习。它使用创新的采样和分布方法,实现了离散优化问题在深度学习中的应用,如最短路径学习。该库采用Perturb-and-MAP方法和新颖的噪声扰动来近似采样复杂分布,并提供替代经验分布。torch-imle通过梯度下降学习最优路径权重,为深度学习中的离散优化问题提供强大的解决方案。

I-MLE深度学习梯度估计组合优化PyTorchGithub开源项目

torch-imle

这是一个简洁且独立的PyTorch库,实现了我们在NeurIPS 2021论文《Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions》中提出的I-MLE梯度估计器。

该仓库包含一个库,用于将任何组合黑盒求解器转换为可微分层。NeurIPS论文中所有实验的复现代码都可在NEC欧洲实验室的官方仓库中找到。

概述

隐式最大似然估计(I-MLE)使得在标准深度学习架构中包含离散组合优化算法(如Dijkstra算法或整数线性规划求解器)成为可能。I-MLE的核心思想是定义一个隐式的最大似然目标,其梯度用于更新模型的上游参数。每个I-MLE实例需要两个要素:

  1. 一种从复杂且难以处理的分布中近似采样的方法,该分布由组合求解器在解空间上诱导,其中最优解具有最高的概率质量。为此,我们使用扰动映射(又称Gumbel-max技巧),并提出了一种针对特定问题的新型噪声扰动家族。

  2. 一种计算替代经验分布的方法:普通MLE减少当前分布与经验分布之间的KL散度。由于在我们的设置中无法获得经验分布,我们必须设计替代经验分布。这里我们提出了两种广泛适用且实践效果良好的替代分布家族。

示例

例如,让我们考虑一个简单游戏的地图,任务是找到从左上角到右下角的最短路径。较暗的区域成本较高,较亮的区域成本较低。 在中间,你可以看到当我们使用提出的伽玛噪声分布之和来采样路径时会发生什么。 在右侧,你可以看到每个格子的边际概率结果(每个格子成为采样路径一部分的概率)。

[图片1] [图片2] [图片3]

梯度和学习

假设最优最短路径是左侧的路径。 从随机权重开始,模型可以通过梯度下降学习产生将导致最优最短路径的权重,方法是最小化生成路径与黄金路径之间的汉明损失。 这里我们展示了训练过程中产生的路径(中间)和相应的地图权重(右侧)。

输入噪声温度设置为0.0,目标噪声温度设置为0.0

[图片4] [图片5] [图片6]

输入噪声温度设置为1.0,目标噪声温度设置为1.0

[图片7] [图片8] [图片9]

输入噪声温度设置为2.0,目标噪声温度设置为2.0

[图片10] [图片11] [图片12]

输入噪声温度设置为5.0,目标噪声温度设置为5.0

[图片13] [图片14] [图片15]

输入噪声温度设置为5.0,目标噪声温度设置为0.0

[图片16] [图片17] [图片18]

所有动画都由这个脚本生成。

代码

使用这个库非常简单 -- 请参考这个示例。假设我们有一个实现黑盒组合求解器(如Dijkstra算法)的方法:

import numpy as np import torch from torch import Tensor def torch_solver(weights_batch: Tensor) -> Tensor: weights_batch = weights_batch.detach().cpu().numpy() y_batch = np.asarray([solver(w) for w in list(weights_batch)]) return torch.tensor(y_batch, requires_grad=False)

我们可以通过以下方式获得相应的分布和梯度:

from imle.wrapper import imle from imle.target import TargetDistribution from imle.noise import SumOfGammaNoiseDistribution target_distribution = TargetDistribution(alpha=0.0, beta=10.0) noise_distribution = SumOfGammaNoiseDistribution(k=k, nb_iterations=100) def torch_solver(weights_batch: Tensor) -> Tensor: weights_batch = weights_batch.detach().cpu().numpy() y_batch = np.asarray([solver(w) for w in list(weights_batch)]) return torch.tensor(y_batch, requires_grad=False) imle_solver = imle(torch_solver, target_distribution=target_distribution, noise_distribution=noise_distribution, nb_samples=10, input_noise_temperature=input_noise_temperature, target_noise_temperature=target_noise_temperature)

或者,使用简单的函数注解:

@imle(target_distribution=target_distribution, noise_distribution=noise_distribution, nb_samples=10, input_noise_temperature=input_noise_temperature, target_noise_temperature=target_noise_temperature) def imle_solver(weights_batch: Tensor) -> Tensor: return torch_solver(weights_batch)

使用I-MLE的论文

  • Patrick Betz, Mathias Niepert, Pasquale Minervini, 和 Heiner Stuckenschmidt:《Backpropagating through Markov Logic Networks》,NeSy'20/21 @ IJCLR:第15届神经符号学习与推理国际研讨会

参考文献

@inproceedings{niepert21imle, author = {Mathias Niepert and Pasquale Minervini and Luca Franceschi}, title = {Implicit {MLE:} Backpropagating Through Discrete Exponential Family Distributions}, booktitle = {NeurIPS}, series = {Proceedings of Machine Learning Research}, publisher = {{PMLR}}, year = {2021} }

编辑推荐精选

Vora

Vora

免费创建高清无水印Sora视频

Vora是一个免费创建高清无水印Sora视频的AI工具

Refly.AI

Refly.AI

最适合小白的AI自动化工作流平台

无需编码,轻松生成可复用、可变现的AI自动化工作流

酷表ChatExcel

酷表ChatExcel

大模型驱动的Excel数据处理工具

基于大模型交互的表格处理系统,允许用户通过对话方式完成数据整理和可视化分析。系统采用机器学习算法解析用户指令,自动执行排序、公式计算和数据透视等操作,支持多种文件格式导入导出。数据处理响应速度保持在0.8秒以内,支持超过100万行数据的即时分析。

AI工具使用教程AI营销产品酷表ChatExcelAI智能客服
TRAE编程

TRAE编程

AI辅助编程,代码自动修复

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

热门AI工具生产力协作转型TraeAI IDE
AIWritePaper论文写作

AIWritePaper论文写作

AI论文写作指导平台

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

数据安全AI助手热门AI工具AI辅助写作AI论文工具论文写作智能生成大纲
博思AIPPT

博思AIPPT

AI一键生成PPT,就用博思AIPPT!

博思AIPPT,新一代的AI生成PPT平台,支持智能生成PPT、AI美化PPT、文本&链接生成PPT、导入Word/PDF/Markdown文档生成PPT等,内置海量精美PPT模板,涵盖商务、教育、科技等不同风格,同时针对每个页面提供多种版式,一键自适应切换,完美适配各种办公场景。

热门AI工具AI办公办公工具智能排版AI生成PPT博思AIPPT海量精品模板AI创作
潮际好麦

潮际好麦

AI赋能电商视觉革命,一站式智能商拍平台

潮际好麦深耕服装行业,是国内AI试衣效果最好的软件。使用先进AIGC能力为电商卖家批量提供优质的、低成本的商拍图。合作品牌有Shein、Lazada、安踏、百丽等65个国内外头部品牌,以及国内10万+淘宝、天猫、京东等主流平台的品牌商家,为卖家节省将近85%的出图成本,提升约3倍出图效率,让品牌能够快速上架。

iTerms

iTerms

企业专属的AI法律顾问

iTerms是法大大集团旗下法律子品牌,基于最先进的大语言模型(LLM)、专业的法律知识库和强大的智能体架构,帮助企业扫清合规障碍,筑牢风控防线,成为您企业专属的AI法律顾问。

SimilarWeb流量提升

SimilarWeb流量提升

稳定高效的流量提升解决方案,助力品牌曝光

稳定高效的流量提升解决方案,助力品牌曝光

Sora2��视频免费生成

Sora2视频免费生成

最新版Sora2模型免费使用,一键生成无水印视频

最新版Sora2模型免费使用,一键生成无水印视频

下拉加载更多