pytorch-frame

pytorch-frame

模块化深度学习框架用于异构表格数据

PyTorch Frame是一个为异构表格数据设计的深度学习框架,支持数值、分类、时间、文本和图像等多种列类型。它采用模块化架构,实现了先进的深度表格模型,并可与大型语言模型集成。该框架提供了便捷的mini-batch加载器、基准数据集和自定义数据接口,简化了表格数据的深度学习研究过程,适用于各层次研究人员。框架内置多个预实现的深度表格模型,如Trompt、FTTransformer和TabNet等,并提供与XGBoost等GBDT模型的性能对比基准。PyTorch Frame无缝集成于PyTorch生态系统,便于与其他PyTorch库协同使用,为端到端的深度学习研究提供了便利。

PyTorch Frame深度学习表格数据神经网络模块化框架Github开源项目
<div align="center"> <img height="175" src="https://yellow-cdn.veclightyear.com/835a84d5/e99ea417-9818-4fd1-90ff-62febdad0f2c.png?sanitize=true" /> <br> <br>

一个用于在异构表格数据上构建神经网络模型的模块化深度学习框架。


arXiv PyPI 版本 测试状态 文档状态 贡献 Slack

</div>

文档 | 论文

PyTorch Frame 是 PyTorch 的深度学习扩展,专为包含数值、类别、时间、文本和图像等不同列类型的异构表格数据设计。它提供了一个模块化框架,用于实现现有和未来的方法。该库包含最先进模型的方法、用户友好的小批量加载器、基准数据集以及自定义数据集成接口。

PyTorch Frame 让表格数据的深度学习研究变得更加普及,既适合新手也适合专家。我们的目标是:

  1. 促进表格数据的深度学习: 历史上,基于树的模型(如 GBDT)在表格学习方面表现出色,但存在一些明显的局限性,例如与下游模型的集成困难,以及处理复杂列类型(如文本、序列和嵌入)的能力。深度表格模型有望解决这些局限性。我们的目标是通过模块化实现并支持多样的列类型,来促进表格数据的深度学习研究。

  2. 与大型语言模型等多样化模型架构集成: PyTorch Frame 支持与各种不同的架构集成,包括大型语言模型。使用任何下载的模型或嵌入 API 端点,您可以为文本数据生成嵌入,并与其他复杂语义类型一起使用深度学习模型进行训练。我们支持以下(但不限于):

<table> <tr> <td align="center"> <a href="https://platform.openai.com/docs/guides/embeddings"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/9d61a92f-d8ab-4ad7-bcc4-a686bb6d9c83.png" alt="OpenAI" width="100px"/> </a> <br /><a href="https://github.com/pyg-team/pytorch-frame/blob/master/examples/llm_embedding.py">OpenAI 嵌入代码示例</a> </td> <td align="center"> <a href="https://cohere.com/embeddings"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/a4f0154b-4636-4eac-8df0-e46159b868d1.png" alt="Cohere" width="100px"/> </a> <br /><a href="https://github.com/pyg-team/pytorch-frame/blob/master/examples/llm_embedding.py">Cohere Embed v3 代码示例</a> </td> <td align="center"> <a href="https://huggingface.co/"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/1adfecf3-86dd-4b67-add8-60110e8ebd55.png" alt="Hugging Face" width="100px"/> </a> <br /><a href="https://github.com/pyg-team/pytorch-frame/blob/master/examples/transformers_text.py">Hugging Face 代码示例</a> </td> <td align="center"> <a href="https://www.voyageai.com/"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/079d0720-c739-4596-9358-cdfbf1d57324.webp" alt="Voyage AI" width="100px"/> </a> <br /><a href="https://github.com/pyg-team/pytorch-frame/blob/master/examples/llm_embedding.py">Voyage AI 代码示例</a> </td> </tr> </table> <hr style="border: 0.5px solid #ccc;">

库亮点

PyTorch Frame 直接构建在 PyTorch 之上,确保现有 PyTorch 用户能够顺利过渡。主要特性包括:

  • 多样化列类型: PyTorch Frame 支持跨各种列类型的学习:numericalcategoricalmulticategoricaltext_embeddedtext_tokenizedtimestampimage_embeddedembedding。详细教程请参见此处
  • 模块化模型设计: 支持模块化深度学习模型实现,促进代码重用、清晰编码和实验灵活性。更多详情请参见架构概览
  • 模型 实现了许多最先进的深度表格模型以及强大的 GBDT(XGBoost、CatBoost 和 LightGBM),并支持超参数调优。
  • 数据集: 提供了一系列可直接使用的表格数据集。同时支持自定义数据集以解决您自己的问题。 我们对深度表格模型与 GBDT 进行了基准测试
  • PyTorch 集成: 与其他 PyTorch 库无缝集成,便于 PyTorch Frame 与下游 PyTorch 模型的端到端训练。例如,通过与 PyTorch 图神经网络库 PyG 集成,我们可以对关系数据库进行深度学习。更多信息请参见 RelBench示例代码(进行中)

架构概览

PyTorch Frame 中的模型遵循 FeatureEncoderTableConvDecoder 的模块化设计,如下图所示:

<p align="center"> <img width="50%" src="https://yellow-cdn.veclightyear.com/835a84d5/c3afa3c9-b2bc-4b15-8964-dc3ce4371827.png" /> </p>

本质上,这种模块化设置使用户能够轻松尝试各种架构:

  • Materialization 处理将原始 pandas DataFrame 转换为适合 PyTorch 训练和建模的 TensorFrame
  • FeatureEncoderTensorFrame 编码为大小为 [batch_size, num_cols, channels] 的隐藏列嵌入。
  • TableConv 对隐藏嵌入进行列间交互建模。
  • Decoder 为每行生成嵌入/预测。

快速上手

在这个快速上手中,我们将展示如何仅用几行代码就能轻松创建和训练一个深度表格模型。

构建并训练您自己的深度表格模型

作为示例,我们将按照 PyTorch Frame 的模块化架构实现一个简单的 ExampleTransformer。 在下面的示例中:

  • self.encoder 将输入的 TensorFrame 映射到大小为 [batch_size, num_cols, channels] 的嵌入。
  • self.convs 迭代地将大小为 [batch_size, num_cols, channels] 的嵌入转换为相同大小的嵌入。
  • self.decoder 将大小为 [batch_size, num_cols, channels] 的嵌入池化为 [batch_size, out_channels]
from torch import Tensor from torch.nn import Linear, Module, ModuleList from torch_frame import TensorFrame, stype from torch_frame.nn.conv import TabTransformerConv from torch_frame.nn.encoder import ( EmbeddingEncoder, LinearEncoder, StypeWiseFeatureEncoder, ) class ExampleTransformer(Module): def __init__( self, channels, out_channels, num_layers, num_heads, col_stats, col_names_dict, ): super().__init__() self.encoder = StypeWiseFeatureEncoder( out_channels=channels, col_stats=col_stats, col_names_dict=col_names_dict, stype_encoder_dict={ stype.categorical: EmbeddingEncoder(), stype.numerical: LinearEncoder() }, ) self.convs = ModuleList([ TabTransformerConv( channels=channels, num_heads=num_heads, ) for _ in range(num_layers) ]) self.decoder = Linear(channels, out_channels) def forward(self, tf: TensorFrame) -> Tensor: x, _ = self.encoder(tf) for conv in self.convs: x = conv(x) out = self.decoder(x.mean(dim=1)) return out

要准备数据,我们可以快速实例化一个预定义的数据集并创建一个与 PyTorch 兼容的数据加载器,如下所示:

from torch_frame.datasets import Yandex from torch_frame.data import DataLoader dataset = Yandex(root='/tmp/adult', name='adult') dataset.materialize() train_dataset = dataset[:0.8] train_loader = DataLoader(train_dataset.tensor_frame, batch_size=128, shuffle=True)

然后,我们只需按照<a href="https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html#full-implementation">标准PyTorch训练流程</a>来优化模型参数。就这么简单!

import torch import torch.nn.functional as F device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = ExampleTransformer( channels=32, out_channels=dataset.num_classes, num_layers=2, num_heads=8, col_stats=train_dataset.col_stats, col_names_dict=train_dataset.tensor_frame.col_names_dict, ).to(device) optimizer = torch.optim.Adam(model.parameters()) for epoch in range(50): for tf in train_loader: tf = tf.to(device) pred = model.forward(tf) loss = F.cross_entropy(pred, tf.y) optimizer.zero_grad() loss.backward()

已实现的深度表格模型

以下是目前支持的深度表格模型列表:

此外,我们还为想要将模型性能与GBDT进行比较的用户实现了XGBoostCatBoostLightGBM示例,这些示例使用Optuna进行超参数调优。

基准测试

我们在多种规模和任务类型的公共数据集上对最近的表格深度学习模型与GBDT进行了基准测试。

下图展示了各种模型在小型回归数据集上的性能,其中行代表模型名称,列代表数据集索引(这里我们有13个数据集)。有关分类和更大数据集的更多结果,请查看基准测试文档

模型名称数据集0数据集1数据集2数据集3数据集4数据集5数据集6数据集7数据集8数据集9数据集10数据集11数据集12
XGBoost0.250±0.0000.038±0.0000.187±0.0000.475±0.0000.328±0.0000.401±0.0000.249±0.0000.363±0.0000.904±0.0000.056±0.0000.820±0.0000.857±0.0000.418±0.000
CatBoost0.265±0.0000.062±0.0000.128±0.0000.336±0.0000.346±0.0000.443±0.0000.375±0.0000.273±0.0000.881±0.0000.040±0.0000.756±0.0000.876±0.0000.439±0.000
LightGBM0.253±0.0000.054±0.0000.112±0.0000.302±0.0000.325±0.0000.384±0.0000.295±0.0000.272±0.0000.877±0.0000.011±0.0000.702±0.0000.863±0.0000.395±0.000
Trompt0.261±0.0030.015±0.0050.118±0.0010.262±0.0010.323±0.0010.418±0.0030.329±0.0090.312±0.002OOM0.008±0.0010.779±0.0060.874±0.0040.424±0.005
ResNet0.288±0.0060.018±0.0030.124±0.0010.268±0.0010.335±0.0010.434±0.0040.325±0.0120.324±0.0040.895±0.0050.036±0.0020.794±0.0060.875±0.0040.468±0.004
FTTransformerBucket0.325±0.0080.096±0.0050.360±0.3540.284±0.0050.342±0.0040.441±0.0030.345±0.0070.339±0.003OOM0.105±0.0110.807±0.0100.885±0.0080.468±0.006
ExcelFormer0.262±0.0040.099±0.0030.128±0.0000.264±0.0030.331±0.0030.411±0.0050.298±0.0120.308±0.007OOM0.011±0.0010.785±0.0110.890±0.0030.431±0.006
FTTransformer0.335±0.0100.161±0.0220.140±0.0020.277±0.0040.335±0.0030.445±0.0030.361±0.0180.345±0.005OOM0.106±0.0120.826±0.0050.896±0.0070.461±0.003
TabNet0.279±0.0030.224±0.0160.141±0.0100.275±0.0020.348±0.0030.451±0.0070.355±0.0300.332±0.0040.992±0.1820.015±0.0020.805±0.0140.885±0.0130.544±0.011
TabTransformer0.624±0.0030.229±0.0030.369±0.0050.340±0.0040.388±0.0020.539±0.0030.619±0.0050.351±0.0010.893±0.0050.431±0.0010.819±0.0020.886±0.0050.545±0.004

我们可以看到,一些最新的深度表格模型能够达到与强大的GBDT相当的模型性能(尽管训练速度慢5-100倍)。使深度表格模型在更少计算资源下表现更好是未来研究的一个富有成果的方向。

我们还在一个带有一列文本的真实世界表格数据集(葡萄酒评论)上对不同的文本编码器进行了基准测试。下表显示了性能结果:

测试准确率方法模型名称来源
0.7926预训练sentence-transformers/all-distilroberta-v1 (1.25亿参数)Hugging Face
0.7998预训练embed-english-v3.0 (维度大小: 1024)Cohere
0.8102预训练text-embedding-ada-002 (维度大小: 1536)OpenAI
0.8147预训练voyage-01 (维度大小: 1024)Voyage AI
0.8203预训练intfloat/e5-mistral-7b-instruct (70亿参数)Hugging Face
0.8230LoRA微调DistilBERT (6600万参数)Hugging Face

Hugging Face文本编码器的基准测试脚本在这个文件中,其他文本编码器的基准测试脚本在这个文件中。

安装

PyTorch Frame支持Python 3.8到Python 3.11版本。

pip install pytorch_frame

查看安装指南了解其他安装选项。

引用

如果您在工作中使用了PyTorch Frame,请引用我们的论文(Bibtex如下)。

@article{hu2024pytorch,
  title={PyTorch Frame: A Modular Framework for Multi-Modal Tabular Learning},
  author={Hu, Weihua and Yuan, Yiwen and Zhang, Zecheng and Nitta, Akihiro and Cao, Kaidi and Kocijan, Vid and Leskovec, Jure and Fey, Matthias},
  journal={arXiv preprint arXiv:2404.00776},
  year={2024}
}

编辑推荐精选

Vora

Vora

免费创建高清无水印Sora视频

Vora是一个免费创建高清无水印Sora视频的AI工具

Refly.AI

Refly.AI

最适合小白的AI自动化工作流平台

无需编码,轻松生成可复用、可变现的AI自动化工作流

酷表ChatExcel

酷表ChatExcel

大模型驱动的Excel数据处理工具

基于大模型交互的表格处理系统,允许用户通过对话方式完成数据整理和可视化分析。系统采用机器学习算法解析用户指令,自动执行排序、公式计算和数据透视等操作,支持多种文件格式导入导出。数据处理响应速度保持在0.8秒以内,支持超过100万行数据的即时分析。

AI工具使用教程AI营销产品酷表ChatExcelAI智能客服
TRAE编程

TRAE编程

AI辅助编程,代码自动修复

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

热门AI工具生产力协作转型TraeAI IDE
AIWritePaper论文写作

AIWritePaper论文写作

AI论文写作指导平台

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

数据安全AI助手热门AI工具AI辅助写作AI论文工具论文写作智能生成大纲
博思AIPPT

博思AIPPT

AI一键生成PPT,就用博思AIPPT!

博思AIPPT,新一代的AI生成PPT平台,支持智能生成PPT、AI美化PPT、文本&链接生成PPT、导入Word/PDF/Markdown文档生成PPT等,内置海量精美PPT模板,涵盖商务、教育、科技等不同风格,同时针对每个页面提供多种版式,一键自适应切换,完美适配各种办公场景。

热门AI工具AI办公办公工具智能排版AI生成PPT博思AIPPT海量精品模板AI创作
潮际好麦

潮际好麦

AI赋能电商视觉革命,一站式智能商拍平台

潮际好麦深耕服装行业,是国内AI试衣效果最好的软件。使用先进AIGC能力为电商卖家批量提供优质的、低成本的商拍图。合作品牌有Shein、Lazada、安踏、百丽等65个国内外头部品牌,以及国内10万+淘宝、天猫、京东等主流平台的品牌商家,为卖家节省将近85%的出图成本,提升约3倍出图效率,让品牌能够快速上架。

iTerms

iTerms

企业专属的AI法律顾问

iTerms是法大大集团旗下法律子品牌,基于最先进的大语言模型(LLM)、专业的法律知识库和强大的智能体架构,帮助企业扫清合规障碍,筑牢风控防线,成为您企业专属的AI法律顾问。

SimilarWeb流量提升

SimilarWeb流量提升

稳定高效的流量提升解决方案,助力品牌曝光

稳定高效的流量提升解决方案,助力品牌曝光

Sora2视频免费生成

Sora2视频免费生成

最新版Sora2模型免费使用,一键生成无水印视频

最新版Sora2模型免费使用,一键生成无水印视频

下拉加载更多