torch-imle

torch-imle

将离散优化算法融入深度学习的创新方法

torch-imle是一个PyTorch库,通过I-MLE梯度估计器将离散优化算法融入深度学习。它使用创新的采样和分布方法,实现了离散优化问题在深度学习中的应用,如最短路径学习。该库采用Perturb-and-MAP方法和新颖的噪声扰动来近似采样复杂分布,并提供替代经验分布。torch-imle通过梯度下降学习最优路径权重,为深度学习中的离散优化问题提供强大的解决方案。

I-MLE深度学习梯度估计组合优化PyTorchGithub开源项目

torch-imle

这是一个简洁且独立的PyTorch库,实现了我们在NeurIPS 2021论文《Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions》中提出的I-MLE梯度估计器。

该仓库包含一个库,用于将任何组合黑盒求解器转换为可微分层。NeurIPS论文中所有实验的复现代码都可在NEC欧洲实验室的官方仓库中找到。

概述

隐式最大似然估计(I-MLE)使得在标准深度学习架构中包含离散组合优化算法(如Dijkstra算法或整数线性规划求解器)成为可能。I-MLE的核心思想是定义一个隐式的最大似然目标,其梯度用于更新模型的上游参数。每个I-MLE实例需要两个要素:

  1. 一种从复杂且难以处理的分布中近似采样的方法,该分布由组合求解器在解空间上诱导,其中最优解具有最高的概率质量。为此,我们使用扰动映射(又称Gumbel-max技巧),并提出了一种针对特定问题的新型噪声扰动家族。

  2. 一种计算替代经验分布的方法:普通MLE减少当前分布与经验分布之间的KL散度。由于在我们的设置中无法获得经验分布,我们必须设计替代经验分布。这里我们提出了两种广泛适用且实践效果良好的替代分布家族。

示例

例如,让我们考虑一个简单游戏的地图,任务是找到从左上角到右下角的最短路径。较暗的区域成本较高,较亮的区域成本较低。 在中间,你可以看到当我们使用提出的伽玛噪声分布之和来采样路径时会发生什么。 在右侧,你可以看到每个格子的边际概率结果(每个格子成为采样路径一部分的概率)。

[图片1] [图片2] [图片3]

梯度和学习

假设最优最短路径是左侧的路径。 从随机权重开始,模型可以通过梯度下降学习产生将导致最优最短路径的权重,方法是最小化生成路径与黄金路径之间的汉明损失。 这里我们展示了训练过程中产生的路径(中间)和相应的地图权重(右侧)。

输入噪声温度设置为0.0,目标噪声温度设置为0.0

[图片4] [图片5] [图片6]

输入噪声温度设置为1.0,目标噪声温度设置为1.0

[图片7] [图片8] [图片9]

输入噪声温度设置为2.0,目标噪声温度设置为2.0

[图片10] [图片11] [图片12]

输入噪声温度设置为5.0,目标噪声温度设置为5.0

[图片13] [图片14] [图片15]

输入噪声温度设置为5.0,目标噪声温度设置为0.0

[图片16] [图片17] [图片18]

所有动画都由这个脚本生成。

代码

使用这个库非常简单 -- 请参考这个示例。假设我们有一个实现黑盒组合求解器(如Dijkstra算法)的方法:

import numpy as np import torch from torch import Tensor def torch_solver(weights_batch: Tensor) -> Tensor: weights_batch = weights_batch.detach().cpu().numpy() y_batch = np.asarray([solver(w) for w in list(weights_batch)]) return torch.tensor(y_batch, requires_grad=False)

我们可以通过以下方式获得相应的分布和梯度:

from imle.wrapper import imle from imle.target import TargetDistribution from imle.noise import SumOfGammaNoiseDistribution target_distribution = TargetDistribution(alpha=0.0, beta=10.0) noise_distribution = SumOfGammaNoiseDistribution(k=k, nb_iterations=100) def torch_solver(weights_batch: Tensor) -> Tensor: weights_batch = weights_batch.detach().cpu().numpy() y_batch = np.asarray([solver(w) for w in list(weights_batch)]) return torch.tensor(y_batch, requires_grad=False) imle_solver = imle(torch_solver, target_distribution=target_distribution, noise_distribution=noise_distribution, nb_samples=10, input_noise_temperature=input_noise_temperature, target_noise_temperature=target_noise_temperature)

或者,使用简单的函数注解:

@imle(target_distribution=target_distribution, noise_distribution=noise_distribution, nb_samples=10, input_noise_temperature=input_noise_temperature, target_noise_temperature=target_noise_temperature) def imle_solver(weights_batch: Tensor) -> Tensor: return torch_solver(weights_batch)

使用I-MLE的论文

  • Patrick Betz, Mathias Niepert, Pasquale Minervini, 和 Heiner Stuckenschmidt:《Backpropagating through Markov Logic Networks》,NeSy'20/21 @ IJCLR:第15届神经符号学习与推理国际研讨会

参考文献

@inproceedings{niepert21imle, author = {Mathias Niepert and Pasquale Minervini and Luca Franceschi}, title = {Implicit {MLE:} Backpropagating Through Discrete Exponential Family Distributions}, booktitle = {NeurIPS}, series = {Proceedings of Machine Learning Research}, publisher = {{PMLR}}, year = {2021} }

编辑推荐精选

商汤小浣熊

商汤小浣熊

最强AI数据分析助手

小浣熊家族Raccoon,您的AI智能助手,致力于通过先进的人工智能技术,为用户提供高效、便捷的智能服务。无论是日常咨询还是专业问题解答,小浣熊都能以快速、准确的响应满足您的需求,让您的生活更加智能便捷。

imini AI

imini AI

像人一样思考的AI智能体

imini 是一款超级AI智能体,能根据人类指令,自主思考、自主完成、并且交付结果的AI智能体。

Keevx

Keevx

AI数字人视频创作平台

Keevx 一款开箱即用的AI数字人视频创作平台,广泛适用于电商广告、企业培训与社媒宣传,让全球企业与个人创作者无需拍摄剪辑,就能快速生成多语言、高质量的专业视频。

即梦AI

即梦AI

一站式AI创作平台

提供 AI 驱动的图片、视频生成及数字人等功能,助力创意创作

扣子-AI办公

扣子-AI办公

AI办公助手,复杂任务高效处理

AI办公助手,复杂任务高效处理。办公效率低?扣子空间AI助手支持播客生成、PPT制作、网页开发及报告写作,覆盖科研、商业、舆情等领域的专家Agent 7x24小时响应,生活工作无缝切换,提升50%效率!

TRAE编程

TRAE编程

AI辅助编程,代码自动修复

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
蛙蛙写作

蛙蛙写作

AI小说写作助手,一站式润色、改写、扩写

蛙蛙写作—国内先进的AI写作平台,涵盖小说、学术、社交媒体等多场景。提供续写、改写、润色等功能,助力创作者高效优化写作流程。界面简洁,功能全面,适合各类写作者提升内容品质和工作效率。

AI辅助写作AI工具蛙蛙写作AI写作工具学术助手办公助手营销助手AI助手
问小白

问小白

全能AI智能助手,随时解答生活与工作的多样问题

问小白,由元石科技研发的AI智能助手,快速准确地解答各种生活和工作问题,包括但不限于搜索、规划和社交互动,帮助用户在日常生活中提高效率,轻松管理个人事务。

热门AI助手AI对话AI工具聊天机器人
Transly

Transly

实时语音翻译/同声传译工具

Transly是一个多场景的AI大语言模型驱动的同声传译、专业翻译助手,它拥有超精准的音频识别翻译能力,几乎零延迟的使用体验和支持多国语言可以让你带它走遍全球,无论你是留学生、商务人士、韩剧美剧爱好者,还是出国游玩、多国会议、跨国追星等等,都可以满足你所有需要同传的场景需求,线上线下通用,扫除语言障碍,让全世界的语言交流不再有国界。

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

AI办公办公工具AI工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图热门
下拉加载更多