pytorch-frame

pytorch-frame

模块化深度学习框架用于异构表格数据

PyTorch Frame是一个为异构表格数据设计的深度学习框架,支持数值、分类、时间、文本和图像等多种列类型。它采用模块化架构,实现了先进的深度表格模型,并可与大型语言模型集成。该框架提供了便捷的mini-batch加载器、基准数据集和自定义数据接口,简化了表格数据的深度学习研究过程,适用于各层次研究人员。框架内置多个预实现的深度表格模型,如Trompt、FTTransformer和TabNet等,并提供与XGBoost等GBDT模型的性能对比基准。PyTorch Frame无缝集成于PyTorch生态系统,便于与其他PyTorch库协同使用,为端到端的深度学习研究提供了便利。

PyTorch Frame深度学习表格数据神经网络模块化框架Github开源项目
<div align="center"> <img height="175" src="https://yellow-cdn.veclightyear.com/835a84d5/e99ea417-9818-4fd1-90ff-62febdad0f2c.png?sanitize=true" /> <br> <br>

一个用于在异构表格数据上构建神经网络模型的模块化深度学习框架。


arXiv PyPI 版本 测试状态 文档状态 贡献 Slack

</div>

文档 | 论文

PyTorch Frame 是 PyTorch 的深度学习扩展,专为包含数值、类别、时间、文本和图像等不同列类型的异构表格数据设计。它提供了一个模块化框架,用于实现现有和未来的方法。该库包含最先进模型的方法、用户友好的小批量加载器、基准数据集以及自定义数据集成接口。

PyTorch Frame 让表格数据的深度学习研究变得更加普及,既适合新手也适合专家。我们的目标是:

  1. 促进表格数据的深度学习: 历史上,基于树的模型(如 GBDT)在表格学习方面表现出色,但存在一些明显的局限性,例如与下游模型的集成困难,以及处理复杂列类型(如文本、序列和嵌入)的能力。深度表格模型有望解决这些局限性。我们的目标是通过模块化实现并支持多样的列类型,来促进表格数据的深度学习研究。

  2. 与大型语言模型等多样化模型架构集成: PyTorch Frame 支持与各种不同的架构集成,包括大型语言模型。使用任何下载的模型或嵌入 API 端点,您可以为文本数据生成嵌入,并与其他复杂语义类型一起使用深度学习模型进行训练。我们支持以下(但不限于):

<table> <tr> <td align="center"> <a href="https://platform.openai.com/docs/guides/embeddings"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/9d61a92f-d8ab-4ad7-bcc4-a686bb6d9c83.png" alt="OpenAI" width="100px"/> </a> <br /><a href="https://github.com/pyg-team/pytorch-frame/blob/master/examples/llm_embedding.py">OpenAI 嵌入代码示例</a> </td> <td align="center"> <a href="https://cohere.com/embeddings"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/a4f0154b-4636-4eac-8df0-e46159b868d1.png" alt="Cohere" width="100px"/> </a> <br /><a href="https://github.com/pyg-team/pytorch-frame/blob/master/examples/llm_embedding.py">Cohere Embed v3 代码示例</a> </td> <td align="center"> <a href="https://huggingface.co/"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/1adfecf3-86dd-4b67-add8-60110e8ebd55.png" alt="Hugging Face" width="100px"/> </a> <br /><a href="https://github.com/pyg-team/pytorch-frame/blob/master/examples/transformers_text.py">Hugging Face 代码示例</a> </td> <td align="center"> <a href="https://www.voyageai.com/"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/079d0720-c739-4596-9358-cdfbf1d57324.webp" alt="Voyage AI" width="100px"/> </a> <br /><a href="https://github.com/pyg-team/pytorch-frame/blob/master/examples/llm_embedding.py">Voyage AI 代码示例</a> </td> </tr> </table> <hr style="border: 0.5px solid #ccc;">

库亮点

PyTorch Frame 直接构建在 PyTorch 之上,确保现有 PyTorch 用户能够顺利过渡。主要特性包括:

  • 多样化列类型: PyTorch Frame 支持跨各种列类型的学习:numericalcategoricalmulticategoricaltext_embeddedtext_tokenizedtimestampimage_embeddedembedding。详细教程请参见此处
  • 模块化模型设计: 支持模块化深度学习模型实现,促进代码重用、清晰编码和实验灵活性。更多详情请参见架构概览
  • 模型 实现了许多最先进的深度表格模型以及强大的 GBDT(XGBoost、CatBoost 和 LightGBM),并支持超参数调优。
  • 数据集: 提供了一系列可直接使用的表格数据集。同时支持自定义数据集以解决您自己的问题。 我们对深度表格模型与 GBDT 进行了基准测试
  • PyTorch 集成: 与其他 PyTorch 库无缝集成,便于 PyTorch Frame 与下游 PyTorch 模型的端到端训练。例如,通过与 PyTorch 图神经网络库 PyG 集成,我们可以对关系数据库进行深度学习。更多信息请参见 RelBench示例代码(进行中)

架构概览

PyTorch Frame 中的模型遵循 FeatureEncoderTableConvDecoder 的模块化设计,如下图所示:

<p align="center"> <img width="50%" src="https://yellow-cdn.veclightyear.com/835a84d5/c3afa3c9-b2bc-4b15-8964-dc3ce4371827.png" /> </p>

本质上,这种模块化设置使用户能够轻松尝试各种架构:

  • Materialization 处理将原始 pandas DataFrame 转换为适合 PyTorch 训练和建模的 TensorFrame
  • FeatureEncoderTensorFrame 编码为大小为 [batch_size, num_cols, channels] 的隐藏列嵌入。
  • TableConv 对隐藏嵌入进行列间交互建模。
  • Decoder 为每行生成嵌入/预测。

快速上手

在这个快速上手中,我们将展示如何仅用几行代码就能轻松创建和训练一个深度表格模型。

构建并训练您自己的深度表格模型

作为示例,我们将按照 PyTorch Frame 的模块化架构实现一个简单的 ExampleTransformer。 在下面的示例中:

  • self.encoder 将输入的 TensorFrame 映射到大小为 [batch_size, num_cols, channels] 的嵌入。
  • self.convs 迭代地将大小为 [batch_size, num_cols, channels] 的嵌入转换为相同大小的嵌入。
  • self.decoder 将大小为 [batch_size, num_cols, channels] 的嵌入池化为 [batch_size, out_channels]
from torch import Tensor from torch.nn import Linear, Module, ModuleList from torch_frame import TensorFrame, stype from torch_frame.nn.conv import TabTransformerConv from torch_frame.nn.encoder import ( EmbeddingEncoder, LinearEncoder, StypeWiseFeatureEncoder, ) class ExampleTransformer(Module): def __init__( self, channels, out_channels, num_layers, num_heads, col_stats, col_names_dict, ): super().__init__() self.encoder = StypeWiseFeatureEncoder( out_channels=channels, col_stats=col_stats, col_names_dict=col_names_dict, stype_encoder_dict={ stype.categorical: EmbeddingEncoder(), stype.numerical: LinearEncoder() }, ) self.convs = ModuleList([ TabTransformerConv( channels=channels, num_heads=num_heads, ) for _ in range(num_layers) ]) self.decoder = Linear(channels, out_channels) def forward(self, tf: TensorFrame) -> Tensor: x, _ = self.encoder(tf) for conv in self.convs: x = conv(x) out = self.decoder(x.mean(dim=1)) return out

要准备数据,我们可以快速实例化一个预定义的数据集并创建一个与 PyTorch 兼容的数据加载器,如下所示:

from torch_frame.datasets import Yandex from torch_frame.data import DataLoader dataset = Yandex(root='/tmp/adult', name='adult') dataset.materialize() train_dataset = dataset[:0.8] train_loader = DataLoader(train_dataset.tensor_frame, batch_size=128, shuffle=True)

然后,我们只需按照<a href="https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html#full-implementation">标准PyTorch训练流程</a>来优化模型参数。就这么简单!

import torch import torch.nn.functional as F device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = ExampleTransformer( channels=32, out_channels=dataset.num_classes, num_layers=2, num_heads=8, col_stats=train_dataset.col_stats, col_names_dict=train_dataset.tensor_frame.col_names_dict, ).to(device) optimizer = torch.optim.Adam(model.parameters()) for epoch in range(50): for tf in train_loader: tf = tf.to(device) pred = model.forward(tf) loss = F.cross_entropy(pred, tf.y) optimizer.zero_grad() loss.backward()

已实现的深度表格模型

以下是目前支持的深度表格模型列表:

此外,我们还为想要将模型性能与GBDT进行比较的用户实现了XGBoostCatBoostLightGBM示例,这些示例使用Optuna进行超参数调优。

基准测试

我们在多种规模和任务类型的公共数据集上对最近的表格深度学习模型与GBDT进行了基准测试。

下图展示了各种模型在小型回归数据集上的性能,其中行代表模型名称,列代表数据集索引(这里我们有13个数据集)。有关分类和更大数据集的更多结果,请查看基准测试文档

模型名称数据集0数据集1数据集2数据集3数据集4数据集5数据集6数据集7数据集8数据集9数据集10数据集11数据集12
XGBoost0.250±0.0000.038±0.0000.187±0.0000.475±0.0000.328±0.0000.401±0.0000.249±0.0000.363±0.0000.904±0.0000.056±0.0000.820±0.0000.857±0.0000.418±0.000
CatBoost0.265±0.0000.062±0.0000.128±0.0000.336±0.0000.346±0.0000.443±0.0000.375±0.0000.273±0.0000.881±0.0000.040±0.0000.756±0.0000.876±0.0000.439±0.000
LightGBM0.253±0.0000.054±0.0000.112±0.0000.302±0.0000.325±0.0000.384±0.0000.295±0.0000.272±0.0000.877±0.0000.011±0.0000.702±0.0000.863±0.0000.395±0.000
Trompt0.261±0.0030.015±0.0050.118±0.0010.262±0.0010.323±0.0010.418±0.0030.329±0.0090.312±0.002OOM0.008±0.0010.779±0.0060.874±0.0040.424±0.005
ResNet0.288±0.0060.018±0.0030.124±0.0010.268±0.0010.335±0.0010.434±0.0040.325±0.0120.324±0.0040.895±0.0050.036±0.0020.794±0.0060.875±0.0040.468±0.004
FTTransformerBucket0.325±0.0080.096±0.0050.360±0.3540.284±0.0050.342±0.0040.441±0.0030.345±0.0070.339±0.003OOM0.105±0.0110.807±0.0100.885±0.0080.468±0.006
ExcelFormer0.262±0.0040.099±0.0030.128±0.0000.264±0.0030.331±0.0030.411±0.0050.298±0.0120.308±0.007OOM0.011±0.0010.785±0.0110.890±0.0030.431±0.006
FTTransformer0.335±0.0100.161±0.0220.140±0.0020.277±0.0040.335±0.0030.445±0.0030.361±0.0180.345±0.005OOM0.106±0.0120.826±0.0050.896±0.0070.461±0.003
TabNet0.279±0.0030.224±0.0160.141±0.0100.275±0.0020.348±0.0030.451±0.0070.355±0.0300.332±0.0040.992±0.1820.015±0.0020.805±0.0140.885±0.0130.544±0.011
TabTransformer0.624±0.0030.229±0.0030.369±0.0050.340±0.0040.388±0.0020.539±0.0030.619±0.0050.351±0.0010.893±0.0050.431±0.0010.819±0.0020.886±0.0050.545±0.004

我们可以看到,一些最新的深度表格模型能够达到与强大的GBDT相当的模型性能(尽管训练速度慢5-100倍)。使深度表格模型在更少计算资源下表现更好是未来研究的一个富有成果的方向。

我们还在一个带有一列文本的真实世界表格数据集(葡萄酒评论)上对不同的文本编码器进行了基准测试。下表显示了性能结果:

测试准确率方法模型名称来源
0.7926预训练sentence-transformers/all-distilroberta-v1 (1.25亿参数)Hugging Face
0.7998预训练embed-english-v3.0 (维度大小: 1024)Cohere
0.8102预训练text-embedding-ada-002 (维度大小: 1536)OpenAI
0.8147预训练voyage-01 (维度大小: 1024)Voyage AI
0.8203预训练intfloat/e5-mistral-7b-instruct (70亿参数)Hugging Face
0.8230LoRA微调DistilBERT (6600万参数)Hugging Face

Hugging Face文本编码器的基准测试脚本在这个文件中,其他文本编码器的基准测试脚本在这个文件中。

安装

PyTorch Frame支持Python 3.8到Python 3.11版本。

pip install pytorch_frame

查看安装指南了解其他安装选项。

引用

如果您在工作中使用了PyTorch Frame,请引用我们的论文(Bibtex如下)。

@article{hu2024pytorch,
  title={PyTorch Frame: A Modular Framework for Multi-Modal Tabular Learning},
  author={Hu, Weihua and Yuan, Yiwen and Zhang, Zecheng and Nitta, Akihiro and Cao, Kaidi and Kocijan, Vid and Leskovec, Jure and Fey, Matthias},
  journal={arXiv preprint arXiv:2404.00776},
  year={2024}
}

编辑推荐精选

Trae

Trae

字节跳动发布的AI编程神器IDE

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
问小白

问小白

全能AI智能助手,随时解答生活与工作的多样问题

问小白,由元石科技研发的AI智能助手,快速准确地解答各种生活和工作问题,包括但不限于搜索、规划和社交互动,帮助用户在日常生活中提高效率,轻松管理个人事务。

热门AI助手AI对话AI工具聊天机器人
Transly

Transly

实时语音翻译/同声传译工具

Transly是一个多场景的AI大语言模型驱动的同声传译、专业翻译助手,它拥有超精准的音频识别翻译能力,几乎零延迟的使用体验和支持多国语言可以让你带它走遍全球,无论你是留学生、商务人士、韩剧美剧爱好者,还是出国游玩、多国会议、跨国追星等等,都可以满足你所有需要同传的场景需求,线上线下通用,扫除语言障碍,让全世界的语言交流不再有国界。

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

AI办公办公工具AI工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图热门
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

热门AI开发模型训练AI工具讯飞星火大模型智能问答内容创作多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

热门AI辅助写作AI工具讯飞绘文内容运营AI创作个性化文章多平台分发AI助手
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

openai-agents-python

openai-agents-python

OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。

openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。

下拉加载更多