模块化深度学习框架用于异构表格数据
PyTorch Frame是一个为异构表格数据设计的深度学习框架,支持数值、分类、时间、文本和图像等多种列类型。它采用模块化架构,实现了先进的深度表格模型,并可与大型语言模型集成。该框架提供了便捷的mini-batch加载器、基准数据集和自定义数据接口,简化了表格数据的深度学习研究过程,适用于各层次研究人员。框架内置多个预实现的深度表格模型,如Trompt、FTTransformer和TabNet等,并提供与XGBoost等GBDT模型的性能对比基准。PyTorch Frame无缝集成于PyTorch生态系统,便于与其他PyTorch库协同使用,为端到端的深度学习研究提供了便利。
一个用于在异构表格数据上构建神经网络模型的模块化深度学习框架。
PyTorch Frame 是 PyTorch 的深度学习扩展,专为包含数值、类别、时间、文本和图像等不同列类型的异构表格数据设计。它提供了一个模块化框架,用于实现现有和未来的方法。该库包含最先进模型的方法、用户友好的小批量加载器、基准数据集以及自定义数据集成接口。
PyTorch Frame 让表格数据的深度学习研究变得更加普及,既适合新手也适合专家。我们的目标是:
促进表格数据的深度学习: 历史上,基于树的模型(如 GBDT)在表格学习方面表现出色,但存在一些明显的局限性,例如与下游模型的集成困难,以及处理复杂列类型(如文本、序列和嵌入)的能力。深度表格模型有望解决这些局限性。我们的目标是通过模块化实现并支持多样的列类型,来促进表格数据的深度学习研究。
与大型语言模型等多样化模型架构集成: PyTorch Frame 支持与各种不同的架构集成,包括大型语言模型。使用任何下载的模型或嵌入 API 端点,您可以为文本数据生成嵌入,并与其他复杂语义类型一起使用深度学习模型进行训练。我们支持以下(但不限于):
PyTorch Frame 直接构建在 PyTorch 之上,确保现有 PyTorch 用户能够顺利过渡。主要特性包括:
numerical
、categorical
、multicategorical
、text_embedded
、text_tokenized
、timestamp
、image_embedded
和 embedding
。详细教程请参见此处。PyTorch Frame 中的模型遵循 FeatureEncoder
、TableConv
和 Decoder
的模块化设计,如下图所示:
本质上,这种模块化设置使用户能够轻松尝试各种架构:
Materialization
处理将原始 pandas DataFrame
转换为适合 PyTorch 训练和建模的 TensorFrame
。FeatureEncoder
将 TensorFrame
编码为大小为 [batch_size, num_cols, channels]
的隐藏列嵌入。TableConv
对隐藏嵌入进行列间交互建模。Decoder
为每行生成嵌入/预测。在这个快速上手中,我们将展示如何仅用几行代码就能轻松创建和训练一个深度表格模型。
作为示例,我们将按照 PyTorch Frame 的模块化架构实现一个简单的 ExampleTransformer
。
在下面的示例中:
self.encoder
将输入的 TensorFrame
映射到大小为 [batch_size, num_cols, channels]
的嵌入。self.convs
迭代地将大小为 [batch_size, num_cols, channels]
的嵌入转换为相同大小的嵌入。self.decoder
将大小 为 [batch_size, num_cols, channels]
的嵌入池化为 [batch_size, out_channels]
。from torch import Tensor from torch.nn import Linear, Module, ModuleList from torch_frame import TensorFrame, stype from torch_frame.nn.conv import TabTransformerConv from torch_frame.nn.encoder import ( EmbeddingEncoder, LinearEncoder, StypeWiseFeatureEncoder, ) class ExampleTransformer(Module): def __init__( self, channels, out_channels, num_layers, num_heads, col_stats, col_names_dict, ): super().__init__() self.encoder = StypeWiseFeatureEncoder( out_channels=channels, col_stats=col_stats, col_names_dict=col_names_dict, stype_encoder_dict={ stype.categorical: EmbeddingEncoder(), stype.numerical: LinearEncoder() }, ) self.convs = ModuleList([ TabTransformerConv( channels=channels, num_heads=num_heads, ) for _ in range(num_layers) ]) self.decoder = Linear(channels, out_channels) def forward(self, tf: TensorFrame) -> Tensor: x, _ = self.encoder(tf) for conv in self.convs: x = conv(x) out = self.decoder(x.mean(dim=1)) return out
要准备数据,我们可以快速实例化一个预定义的数据集并创建一个与 PyTorch 兼容的数据加载器,如下所示:
from torch_frame.datasets import Yandex from torch_frame.data import DataLoader dataset = Yandex(root='/tmp/adult', name='adult') dataset.materialize() train_dataset = dataset[:0.8] train_loader = DataLoader(train_dataset.tensor_frame, batch_size=128, shuffle=True)
然后,我们只需按照<a href="https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html#full-implementation">标准PyTorch训练流程</a>来优化模型参数。就这么简单!
import torch import torch.nn.functional as F device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = ExampleTransformer( channels=32, out_channels=dataset.num_classes, num_layers=2, num_heads=8, col_stats=train_dataset.col_stats, col_names_dict=train_dataset.tensor_frame.col_names_dict, ).to(device) optimizer = torch.optim.Adam(model.parameters()) for epoch in range(50): for tf in train_loader: tf = tf.to(device) pred = model.forward(tf) loss = F.cross_entropy(pred, tf.y) optimizer.zero_grad() loss.backward()
以下是目前支持的深度表格模型列表:
此外,我们还为想要将模型性能与GBDT进行比较的用户实现了XGBoost
、CatBoost
和LightGBM
的示例,这些示例使用Optuna进行超参数调优。
我们在多种规模和任务类型的公共数据集上对最近的表格深度学习模型与GBDT进行了基准测试。
下图展示了各种模型在小型回归数据集上的性能,其中行代表模型名称,列代表数据集索引(这里我们有13个数据集)。有关分类和更大数据集的更多结果,请查看基准测试文档。
模型名称 | 数据集0 | 数据集1 | 数据集2 | 数据集3 | 数据集4 | 数据集5 | 数据集6 | 数据集7 | 数据集8 | 数据集9 | 数据集10 | 数据集11 | 数据集12 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
XGBoost | 0.250±0.000 | 0.038±0.000 | 0.187±0.000 | 0.475±0.000 | 0.328±0.000 | 0.401±0.000 | 0.249±0.000 | 0.363±0.000 | 0.904±0.000 | 0.056±0.000 | 0.820±0.000 | 0.857±0.000 | 0.418±0.000 |
CatBoost | 0.265±0.000 | 0.062±0.000 | 0.128±0.000 | 0.336±0.000 | 0.346±0.000 | 0.443±0.000 | 0.375±0.000 | 0.273±0.000 | 0.881±0.000 | 0.040±0.000 | 0.756±0.000 | 0.876±0.000 | 0.439±0.000 |
LightGBM | 0.253±0.000 | 0.054±0.000 | 0.112±0.000 | 0.302±0.000 | 0.325±0.000 | 0.384±0.000 | 0.295±0.000 | 0.272±0.000 | 0.877±0.000 | 0.011±0.000 | 0.702±0.000 | 0.863±0.000 | 0.395±0.000 |
Trompt | 0.261±0.003 | 0.015±0.005 | 0.118±0.001 | 0.262±0.001 | 0.323±0.001 | 0.418±0.003 | 0.329±0.009 | 0.312±0.002 | OOM | 0.008±0.001 | 0.779±0.006 | 0.874±0.004 | 0.424±0.005 |
ResNet | 0.288±0.006 | 0.018±0.003 | 0.124±0.001 | 0.268±0.001 | 0.335±0.001 | 0.434±0.004 | 0.325±0.012 | 0.324±0.004 | 0.895±0.005 | 0.036±0.002 | 0.794±0.006 | 0.875±0.004 | 0.468±0.004 |
FTTransformerBucket | 0.325±0.008 | 0.096±0.005 | 0.360±0.354 | 0.284±0.005 | 0.342±0.004 | 0.441±0.003 | 0.345±0.007 | 0.339±0.003 | OOM | 0.105±0.011 | 0.807±0.010 | 0.885±0.008 | 0.468±0.006 |
ExcelFormer | 0.262±0.004 | 0.099±0.003 | 0.128±0.000 | 0.264±0.003 | 0.331±0.003 | 0.411±0.005 | 0.298±0.012 | 0.308±0.007 | OOM | 0.011±0.001 | 0.785±0.011 | 0.890±0.003 | 0.431±0.006 |
FTTransformer | 0.335±0.010 | 0.161±0.022 | 0.140±0.002 | 0.277±0.004 | 0.335±0.003 | 0.445±0.003 | 0.361±0.018 | 0.345±0.005 | OOM | 0.106±0.012 | 0.826±0.005 | 0.896±0.007 | 0.461±0.003 |
TabNet | 0.279±0.003 | 0.224±0.016 | 0.141±0.010 | 0.275±0.002 | 0.348±0.003 | 0.451±0.007 | 0.355±0.030 | 0.332±0.004 | 0.992±0.182 | 0.015±0.002 | 0.805±0.014 | 0.885±0.013 | 0.544±0.011 |
TabTransformer | 0.624±0.003 | 0.229±0.003 | 0.369±0.005 | 0.340±0.004 | 0.388±0.002 | 0.539±0.003 | 0.619±0.005 | 0.351±0.001 | 0.893±0.005 | 0.431±0.001 | 0.819±0.002 | 0.886±0.005 | 0.545±0.004 |
我们可以看到,一些最新的深度表格模型能够达到与强大的GBDT相当的模型性能(尽管训练速度慢5-100倍)。使深度表格模型在更少计算资源下表现更好是未来研究的一个富有成果的方向。
我们还在一个带有一列文本的真实世界表格数据集(葡萄酒评论)上对不同的文本编码器进行了基准测试。下表显示了性能结果:
测试准确率 | 方法 | 模型名称 | 来源 |
---|---|---|---|
0.7926 | 预训练 | sentence-transformers/all-distilroberta-v1 (1.25亿参数) | Hugging Face |
0.7998 | 预训练 | embed-english-v3.0 (维度大小: 1024) | Cohere |
0.8102 | 预训练 | text-embedding-ada-002 (维度大小: 1536) | OpenAI |
0.8147 | 预训练 | voyage-01 (维度大小: 1024) | Voyage AI |
0.8203 | 预训练 | intfloat/e5-mistral-7b-instruct (70亿参数) | Hugging Face |
0.8230 | LoRA微调 | DistilBERT (6600万参数) | Hugging Face |
Hugging Face文本编码器的基准测试脚本在这个文件中,其他文本编码器的基准测试脚本在这个文件中。
PyTorch Frame支持Python 3.8到Python 3.11版本。
pip install pytorch_frame
查看安装指南了解其他安装选项。
如果您在工作中使用了PyTorch Frame,请引用我们的论文(Bibtex如下)。
@article{hu2024pytorch,
title={PyTorch Frame: A Modular Framework for Multi-Modal Tabular Learning},
author={Hu, Weihua and Yuan, Yiwen and Zhang, Zecheng and Nitta, Akihiro and Cao, Kaidi and Kocijan, Vid and Leskovec, Jure and Fey, Matthias},
journal={arXiv preprint arXiv:2404.00776},
year={2024}
}
字节跳动发布的AI编程神器IDE
Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。
全能AI智能助手,随时解答生活与工作的多样问题
问小白,由元石科技研发的AI智能助手,快速准确地解答各种生活和工作问题,包括但不限于搜索、规划和社交互动,帮助用户在日常生活中提高效率,轻松管理个人事务。
实时语音翻译/同声传译工具
Transly是一个多场景的AI大语言模型驱动的同声传译、专业翻译助手,它拥有超精准的音频识别翻译能力,几乎零延迟的使用体验和支持多国语言可以让你带它走遍全球,无论你是留学生、商务人士、韩剧美剧爱好者,还是出国游玩、多国会议、跨国追星等等,都可以满足你所有需要同传的场景需求,线上线下通用,扫除语言障碍,让全世界的语言交流不再有国界。
一键生成PPT和Word,让学习生活更轻松
讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。
深度推理能力全新升级,全面对标OpenAI o1
科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。
一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型
Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成 高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。
AI助力,做PPT更简单!
咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。
选题、配图、成文,一站式创作,让内容运营更高效
讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。
专业的AI公文写作平台,公文写作神器
AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。
OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。
openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。
最新AI工具、AI资讯
独家AI资源、AI项目落地
微信扫一扫关注公众号