Crossformer

Crossformer

高效利用跨维度依赖的多变量时间序列预测模型

Crossformer是一种新型Transformer模型,针对多变量时间序列预测设计。该模型采用维度分段嵌入、两阶段注意力机制和层次编码器-解码器结构,有效捕捉时间和维度间的依赖关系。Crossformer在多个基准数据集上表现优异,为长序列预测和高维数据处理提供新思路。其开源实现便于研究人员和实践者探索应用。

Crossformer时间序列预测注意力机制深度学习TransformerGithub开源项目

Crossformer:利用跨维度依赖关系的Transformer用于多变量时间序列预测(ICLR 2023)

这是Crossformer: 利用跨维度依赖关系的Transformer用于多变量时间序列预测的原始Pytorch实现。

Crossformer的关键点

1. 维度段嵌入(DSW)

<p align="center"> <img src="https://raw.githubusercontent.com/Thinklab-SJTU/Crossformer/master/.\pic\DSW.PNG" height = "200" alt="" align=center />

<b>图1.</b> DSW嵌入。<b></b>:之前基于Transformer模型的嵌入方法:同一时间步不同维度的数据点被嵌入到一个向量中;<b></b>:Crossformer的DSW嵌入:在每个维度中,时间上相邻的点形成一个段进行嵌入。

</p>

2. 两阶段注意力(TSA)层

<p align="center"> <img src="https://raw.githubusercontent.com/Thinklab-SJTU/Crossformer/master/.\pic\TSA.PNG" height = "200" alt="" align=center />

<b>图2.</b> TSA层。<b></b>:整体结构:2D向量数组通过跨时间阶段和跨维度阶段获得相应的依赖关系;<b></b>:在跨维度阶段直接使用MSA建立$D$到$D$的连接会导致$O(D^2)$的复杂度。<b></b>:跨维度阶段的路由机制:固定数量($c$)的"路由器"在维度间收集和分发信息。复杂度降低到$O(2cD) = O(D)$。

</p>

3. 分层编码器-解码器(HED)

<p align="center"> <img src="https://raw.githubusercontent.com/Thinklab-SJTU/Crossformer/master/.\pic\HED.PNG" height = "200" alt="" align=center />

<b>图3.</b> HED。编码器(左)使用TSA层和段合并来捕获不同尺度的依赖关系;解码器(右)通过在每个尺度上进行预测并将它们相加来做出最终预测。

</p>

环境要求

  • Python 3.7.10
  • numpy==1.20.3
  • pandas==1.3.2
  • torch==1.8.1
  • einops==0.4.1

复现步骤

  1. 将用于实验的数据集放入datasets/文件夹。我们已经将ETTh1ETTm1放入其中。WTHECL可以从https://github.com/zhouhaoyi/Informer2020 下载。ILITraffic可以从https://github.com/thuml/Autoformer 下载。请注意,我们在论文中使用的WTH是来自Informer的12维数据集,而不是来自Autoformer的21维数据集。

  2. 要在ETTh1数据集上获得$T=168, \tau = 24, L_{seg} = 6$的Crossformer结果,运行:

python main_crossformer.py --data ETTh1 --in_len 168 --out_len 24 --seg_len 6 --itr 1

模型将自动训练和测试。训练好的模型将保存在checkpoints/文件夹中,评估指标将保存在results/文件夹中。

  1. 你也可以通过运行以下命令来评估已训练的模型:
python eval_crossformer.py --checkpoint_root ./checkpoints --setting_name Crossformer_ETTh1_il168_ol24_sl6_win2_fa10_dm256_nh4_el3_itr0
  1. 要复现论文中的所有结果,运行以下脚本以获得相应结果:
bash scripts/ETTh1.sh
bash scripts/ETTm1.sh
bash scripts/WTH.sh
bash scripts/ECL.sh
bash scripts/ILI.sh
bash scripts/Traffic.sh

自定义使用

我们使用AirQuality数据集来展示如何使用自己的数据训练和评估Crossformer。

  1. AirQualityUCI.csv数据集修改为以下格式,其中第一列是日期(或者你可以将第一列留空),其他13列是要预测的多变量时间序列。并将修改后的文件放入datasets/文件夹。
<p align="center"> <img src="https://raw.githubusercontent.com/Thinklab-SJTU/Crossformer/master/.\pic\Data_format.PNG" height = "120" alt="" align=center /> <br> <b>图4.</b> 自定义数据集的示例。 </p>
  1. 这是一个每小时采样的13维数据集。我们将使用过去一周(168小时)的数据来预测下一天(24小时),段长度设置为6。因此,我们需要运行:
python main_crossformer.py --data AirQuality --data_path AirQualityUCI.csv --data_dim 13 --in_len 168 --out_len 24 --seg_len 6
  1. 我们可以通过运行以下命令来评估训练好的模型:
python eval_crossformer.py --setting_name Crossformer_AirQuality_il168_ol24_sl6_win2_fa10_dm256_nh4_el3_itr0 --save_pred

模型将被评估,预测和真实序列将被保存在results/Crossformer_AirQuality_il168_ol24_sl6_win2_fa10_dm256_nh4_el3_itr0main_crossformer 是我们模型的入口点,还有其他可调参数。以下是它们的详细说明:

参数名称参数描述
data数据集名称
root_path数据文件的根路径(默认为 ./datasets/
data_path数据文件名(默认为 ETTh1.csv
data_split训练/验证/测试集划分,可以是比例(如 0.7,0.1,0.2)或具体数量(如 16800,2880,2880),(默认为 0.7,0.1,0.2
checkpoints存储训练模型的位置(默认为 ./checkpoints/
in_len输入/历史序列长度,即论文中的 $T$(默认为 96)
out_len输出/未来序列长度,即论文中的 $\tau$(默认为 24)
seg_lenDSW 嵌入中每个段的长度,即论文中的 $L_{seg}$(默认为 6)
win_sizeHED 段合并中合并相邻段的数量(默认为 2)
factorTSA 的跨维度阶段中路由器的数量,即论文中的 $c$(默认为 10)
data_dimMTS 数据的维度数,即论文中的 $D$(ETTh 和 ETTm 默认为 7)
d_model隐藏状态的维度,即论文中的 $d_{model}$(默认为 256)
d_ffMSA 中 MLP 的维度(默认为 512)
n_headsMSA 中的头数(默认为 4)
e_layers编码器层数,即论文中的 $N$(默认为 3)
dropout随机失活概率(默认为 0.2)
num_workers数据加载器的工作线程数(默认为 0)
batch_size训练和测试的批量大小(默认为 32)
train_epochs训练轮数(默认为 20)
patience早停耐心值(默认为 3)
learning_rate优化器的初始学习率(默认为 1e-4)
lradj调整学习率的方式(默认为 type1
itr实验次数(默认为 1)
save_pred是否保存预测结果。如果为 True,预测结果将以 numpy 数组形式保存在 results 文件夹中。对于 $D$ 较大的数据集,这将耗费大量时间和内存。(默认为 False
use_gpu是否使用 GPU(默认为 True
gpu用于训练和推理的 GPU 编号(默认为 0)
use_multi_gpu是否使用多个 GPU(默认为 False
devices多个 GPU 的设备 ID(默认为 0,1,2,3

引用

如果您在研究中发现本仓库有用,请引用:

@inproceedings{
zhang2023crossformer,
title={Crossformer: Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting},
author={Yunhao Zhang and Junchi Yan},
booktitle={International Conference on Learning Representations},
year={2023},
}

致谢

我们感谢以下工作为时间序列预测提供的宝贵代码和数据:

https://github.com/zhouhaoyi/Informer2020

https://github.com/thuml/Autoformer

https://github.com/alipay/Pyraformer

https://github.com/MAZiqing/FEDformer

以下两个视觉 Transformer 工作也启发了我们的 DSW 嵌入和 HED 设计:

https://github.com/google-research/vision_transformer

https://github.com/microsoft/Swin-Transformer

编辑推荐精选

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

AI办公办公工具AI工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图热门
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

热门AI开发模型训练AI工具讯飞星火大模型智能问答内容创作多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

Trae

Trae

字节跳动发布的AI编程神器IDE

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

热门AI辅助写作AI工具讯飞绘文内容运营AI创作个性化文章多平台分发AI助手
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

openai-agents-python

openai-agents-python

OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。

openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。

Hunyuan3D-2

Hunyuan3D-2

高分辨率纹理 3D 资产生成

Hunyuan3D-2 是腾讯开发的用于 3D 资产生成的强大工具,支持从文本描述、单张图片或多视角图片生成 3D 模型,具备快速形状生成能力,可生成带纹理的高质量 3D 模型,适用于多个领域,为 3D 创作提供了高效解决方案。

3FS

3FS

一个具备存储、管理和客户端操作等多种功能的分布式文件系统相关项目。

3FS 是一个功能强大的分布式文件系统项目,涵盖了存储引擎、元数据管理、客户端工具等多个模块。它支持多种文件操作,如创建文件和目录、设置布局等,同时具备高效的事件循环、节点选择和协程池管理等特性。适用于需要大规模数据存储和管理的场景,能够提高系统的性能和可靠性,是分布式存储领域的优质解决方案。

下拉加载更多