RecStudio

RecStudio

基于PyTorch的模块化推荐系统库 支持多任务多模型

RecStudio是一个基于PyTorch的模块化推荐系统库。它支持通用、序列、知识、特征和社交等多种推荐任务。该框架提供灵活的模型结构、统一的数据处理、GPU加速、简洁的模型分类和多种负采样方法。RecStudio为推荐系统研究和开发提供了高效便捷的工具。

RecStudio推荐系统深度学习机器学习PyTorchGithub开源项目

RecStudio

<p float="left"> <img src="https://img.shields.io/badge/python-v3.7+-blue"> <img src="https://img.shields.io/badge/pytorch-v1.9+-blue"> <img src="https://img.shields.io/badge/License-MIT-blue"> </p> <p align="left"> <img src="https://yellow-cdn.veclightyear.com/0a4dffa0/55ea8e0a-c973-4ea8-a24b-c2d2af071a71.png" alt="RecStudio logo" width="300"> <br> </p>

RecStudio是一个基于PyTorch的统一、高度模块化和推荐效率高的推荐系统库。所有算法根据推荐任务分类如下:

  • 通用推荐
  • 序列推荐
  • 基于知识的推荐
  • 基于特征的推荐
  • 社交推荐

描述

模型结构

在库的核心部分,所有推荐模型被分为三个基类:

  • TowerFreeRecommender:最灵活的基类,可以实现任何复杂的特征交互建模。
  • ItemTowerRecommender:物品编码器与推荐器分离,支持快速ANN和基于模型的负采样。
  • TwoTowerRecommenderItemTowerRecommender的子类,推荐器仅由用户编码器和物品编码器组成。

数据集结构

对于数据集结构,数据集被分为五类:

数据集应用示例
TripletDataset提供用户-物品-评分三元组的数据集BPR, NCF, CML等
UserDataset用于基于AutoEncoder的ItemTowerRecommender的数据集MultiVAE, RecVAE等
SeqDataset用于因果预测的序列推荐器的数据集GRU4Rec, SASRec等
Seq2SeqDataset用于掩码预测的序列推荐器的数据集Bert4Rec等
ALSDataset用于交替最小二乘法优化的推荐器的数据集WRMF等

为了加速数据集处理,处理后的数据集会自动缓存,以便快速重复训练。

模型评估

RecStudio基于PyTorch实现了推荐系统中几乎所有常用的指标,如NDCGRecallPrecision等。所有指标函数具有相同的接口,完全使用张量运算符实现。因此,评估过程可以移至GPU上,从而显著加快评估速度。

ANNs与采样器

为了加速训练和评估,RecStudio集成了各种近似最近邻搜索(ANNs)和负采样器。通过使用ANNs构建索引,基于欧氏距离、内积和余弦相似度的topk运算可以显著加速。负采样器包括静态采样器和RecStudio团队开发的基于模型的采样器。静态采样器包括均匀采样器流行度采样器。基于模型的采样器基于物品向量的量化或重要性重采样。此外,我们还在数据集中实现了静态采样,这使我们能够在加载数据时生成负样本。

损失函数与评分函数

在RecStudio中,损失函数分为三类: - FullScoreLoss:在所有物品上计算分数,如SoftmaxLoss。 - PairwiseLoss:在正样本和负样本上计算分数,如BPRLossBinaryCrossEntropyLoss等。 - PointwiseLoss:为单个(用户,物品)交互计算分数,如HingeLoss

评分函数用于建模用户对物品的偏好。RecStudio实现了各种常用的评分函数,如InnerProductEuclideanDistanceCosineDistanceMLPScorer等。

损失函数数学类型采样分布计算复杂度采样复杂度收敛速度相关指标
Softmax<!-- $-\log \frac{\exp f_{\theta}(c, {\color{red}k})}{\sum_{i=1}^{N} \exp f_{\theta}(c, i)}$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;-\log&space;\frac{\exp&space;f_{\theta}(c,&space;{\color{red}k})}{\sum_{i=1}^{N}&space;\exp&space;f_{\theta}(c,&space;i)}">无采样<!-- $O(N)$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;O(N)">-非常快NDCG
采样Softmax<!-- $-\log \frac{\exp \left(f_{\theta}(c, {\color{red}k})-\log Q({\color{red}k} \mid c)\right)}{\sum_{i \in S \cup\{k\}} \exp \left(f_{\theta}(c, i)-\log Q(i \mid c)\right)}$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;-\log&space;\frac{\exp&space;\left(f_{\theta}(c,&space;{\color{red}k})-\log&space;Q({\color{red}k}&space;\mid&space;c)\right)}{\sum_{i&space;\in&space;S&space;\cup\{k\}}&space;\exp&space;\left(f_{\theta}(c,&space;i)-\log&space;Q(i&space;\mid&space;c)\right)}">无采样<!-- $O(N)$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;O(N)">-NDCG
BPR<!-- $-\log \left(\sigma\left(f_{\theta}(c, {\color{red}k})-f_{\theta}(c, {\color{red}j})\right)\right)$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;-\log&space;\left(\sigma\left(f_{\theta}(c,&space;{\color{red}k})-f_{\theta}(c,&space;{\color{red}j})\right)\right)">均匀采样<!-- $O(1)$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;O(1)"><!-- $O(1)$ --> <img style="transform: translateY(0.1em);" src="https://render.githubusercontent.com/render/math?math=O(1)">AUC
WARP<!-- $L\left(\left\lfloor \frac{Y-1}{N}\right\rfloor\right)\lvert 1-f_{\theta}(c, {\color{red}k})+f_{\theta}(c, {\color{red}j})\rvert_{+}$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;L\left(\left\lfloor&space;\frac{Y-1}{N}\right\rfloor\right)\lvert&space;1-f_{\theta}(c,&space;{\color{red}k})&plus;f_{\theta}(c,&space;{\color{red}j})\rvert_{&plus;}">拒绝采样<!-- $O(1)$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;O(1)">越来越慢精确率
InfoNCE<!-- $-\log \frac{\exp \left(f_{\theta}(c, {\color{red}k})\right)}{\sum_{i \in S \cup\{k\}} \exp \left(f_{\theta}(c, i)\right)}$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;-\log&space;\frac{\exp&space;\left(f_{\theta}(c,&space;{\color{red}k})\right)}{\sum_{i&space;\in&space;S&space;\cup\{k\}}&space;\exp&space;\left(f_{\theta}(c,&space;i)\right)}">流行度采样<!-- $O(\lvert S\rvert)$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;O(\lvert&space;S\rvert)"><!-- $O(1)$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;O(1)">DCG
WRMF<!-- $\sum_{j} {\color{red}w_{c j}}\left(f_{\theta}(c, j)-y(c, j)\right)^{2}$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;\sum_{j}&space;{\color{red}w_{c&space;j}}\left(f_{\theta}(c,&space;j)-y(c,&space;j)\right)^{2}">无采样<!-- $O(N)$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;O(N)">-非常快-
PRIS<!-- $-\sum_{j \in S} \frac{\exp \left(f_{\theta}(c, {\color{red}j})-\log Q({\color{red}j} \mid c)\right)}{\sum_{{j^{\prime}} \in S} \exp \left(f_{\theta}\left(c, {j^{\prime}}\right)-\log Q\left({j^{\prime}} \mid c\right)\right)} \log \left(\sigma\left(f_{\theta}(c, {\color{red}k})-f_{\theta}(c, {\color{red}j})\right)\right)$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?-\sum_{j&space;\in&space;S}&space;\frac{\exp&space;\left(f_{\theta}(c,&space;{\color{red}j})-\log&space;Q({\color{red}j}&space;\mid&space;c)\right)}{\sum_{{j^{\prime}}&space;\in&space;S}&space;\exp&space;\left(f_{\theta}\left(c,&space;{j^{\prime}}\right)-\log&space;Q\left({j^{\prime}}&space;\mid&space;c\right)\right)}&space;\log&space;\left(\sigma\left(f_{\theta}(c,&space;{\color{red}k})-f_{\theta}(c,&space;{\color{red}j})\right)\right)">聚类采样<!-- $O(\lvert S\rvert)$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;O(\lvert&space;S\rvert)"><!-- $O(K)$ --> <img style="transform: translateY(0.1em);" src="https://latex.codecogs.com/svg.image?\small&space;O(K)">非常快DCG
<p align="center"> <img src="https://yellow-cdn.veclightyear.com/0a4dffa0/c83117fd-9b9c-4bba-8ef4-af59a06e1bc4.png" alt="RecStudio v0.2 框架" width="600"> <br> <b>图片</b>:RecStudio 框架 </p>

特性

  • 通用数据集结构 RecStudio 支持基于原子数据文件和自动数据缓存的统一数据集配置。
  • 模块化模型结构 通过将整个推荐器组织成不同的模块、损失函数、评分函数、采样器和人工神经网络,您可以像搭建积木一样定制您的模型。
  • GPU 加速 从模型训练到模型评估的整个操作可以轻松地在 GPU 和分布式 GPU 上运行。
  • 简单的模型分类 RecStudio 根据编码器的数量对所有模型进行分类,易于理解和使用。这种分类方法可以涵盖所有模型。
  • 简单和复杂的负采样器 RecStudio 仅使用张量运算符集成了静态和基于模型的采样器。

快速开始

通过下载源代码,您可以运行提供的脚本 run.py 来初步使用 RecStudio。

python run.py

初始配置将在 MovieLens-100k(ml-100k)数据集上训练和评估 BPR 模型。

一般来说,这个简单的示例在 GPU 上运行不到一分钟。输出将类似于以下内容:

[2023-08-24 10:51:41] INFO 日志保存在 /home/RecStudio/log/BPR/ml-100k/2023-08-24-10-51-41-738329.log。 [2023-08-24 10:51:41] INFO 全局种子设置为 2022 [2023-08-24 10:51:41] INFO 数据集从 /home/RecStudio/recstudio/dataset_demo/ml-100k 读取。 [2023-08-24 10:51:42] INFO 数据集信息: ============================================================================= 交互信息: 字段 用户ID 物品ID 评分 时间戳 类型 标记 标记 浮点数 浮点数 ## 944 1575 - - ============================================================================= 用户信息: 字段 用户ID 年龄 性别 职业 邮编 类型 标记 标记 标记 标记 标记 ## 944 62 3 22 795 ============================================================================= 物品信息: 字段 物品ID 类型 标记 ## 1575 ============================================================================= 总交互数:82520 稀疏度:0.944404 ============================================================================= 时间戳=StandardScaler() [2023-08-24 10:51:42] INFO 模型配置: 数据: 二值化评分阈值= fm评估=False 负样本数=0 采样器= 随机打乱=True 划分模式=用户条目 划分比例=[0.8, 0.1, 0.1] fm评估=False 二值化评分阈值=0.0 评估: 批量大小=20 截断=[5, 10, 20] 验证指标=['ndcg', '召回率'] 验证周期=1 测试指标=['ndcg', '召回率', '精确率', 'map', 'mrr', '命中率'] topk=100 保存路径=./saved/ 模型: 嵌入维度=64 物品偏置=False 训练: 加速器=gpu 近似最近邻= 批量大小=512 早停模式=最大化 早停耐心值=10 训练轮数=1000 gpu=1 梯度裁剪范数= 初始化方法=xavier_normal 物品批量大小=1024 优化器=adam 学习率=0.001 线程数=10 采样方法= 采样器=均匀 负样本数=1 排除历史=False 学习率调度器= 随机种子=2022 权重衰减=0.0 tensorboard路径=[2023-08-24 10:51:42] 信息 Tensorboard日志保存在 ./tensorboard/BPR/ml-100k/2023-08-24-10-51-41-738329。 [2023-08-24 10:51:42] 信息 默认使用的字段设置为[用户ID, 物品ID, 评分]。如需更多字段,请使用"self._set_data_field()"重新设置。 [2023-08-24 10:51:42] 信息 保存目录:./saved/ [2023-08-24 10:51:42] 信息 BPR( (得分函数): 内积评分器() (损失函数): BPR损失() (物品编码器): 嵌入(1575, 64, padding_idx=0) (查询编码器): 嵌入(944, 64, padding_idx=0) (采样器): 均匀采样器() ) [2023-08-24 10:51:42] 信息 选择了GPU [8][2023-08-24 10:51:45] 信息 训练: 轮次= 0 [ndcg@5=0.0111 召回率@5=0.0044 训练损失_0=0.6931] [2023-08-24 10:51:45] 信息 训练时间: 0.88524秒。验证时间: 0.18036秒。GPU内存: 0.03/10.76 GB [2023-08-24 10:51:45] 信息 ndcg@5有所提升。最佳值: 0.0111 [2023-08-24 10:51:45] 信息 最佳模型检查点保存在 ./saved/BPR/ml-100k/2023-08-24-10-51-41-738329.ckpt。 ... [2023-08-24 10:52:08] 信息 训练: 轮次= 34 [ndcg@5=0.1802 召回率@5=0.1260 训练损失_0=0.1901] [2023-08-24 10:52:08] 信息 训练时间: 0.41784秒。验证时间: 0.32394秒。GPU内存: 0.03/10.76 GB [2023-08-24 10:52:08] 信息 提前停止。由于指标ndcg@5在10轮内未有改善。 [2023-08-24 10:52:08] 信息 ndcg@5的最佳分数是0.1807,出现在第24轮 [2023-08-24 10:52:08] 信息 最佳模型检查点保存在 ./saved/BPR/ml-100k/2023-08-24-10-51-41-738329.ckpt。 [2023-08-24 10:52:08] 信息 测试: [ndcg@5=0.2389 召回率@5=0.1550 精确率@5=0.1885 map@5=0.1629 mrr@5=0.3845 命中率@5=0.5705 ndcg@10=0.2442 召回率@10=0.2391 精确率@10=0.1498 map@10=0.1447 mrr@10=0.4021 命中率@10=0.6999 ndcg@20=0.2701 召回率@20=0.3530 精确率@20=0.1170 map@20=0.1429 mrr@20=0.4109 命中率@20=0.8240] 如果你想更改模型或数据集,命令行已经准备就绪。 ```bash python run.py -m=NCF -d=ml-1m
  • 支持的命令行参数:

    参数类型描述默认值可选项
    -m,--model字符串模型名称BPRRecStudio中的所有模型
    -d,--dataset字符串数据集名称ml-100kRecStudio支持的所有数据集
    --data_dir字符串数据集文件夹datasetsRecStudio可以读取的文件夹
    mode字符串训练模式light['light','detail','tune']
    --learning_rate浮点数学习率0.001
    --learner字符串优化器名称adam['adam','sgd','adasgd','rmsprop','sparse_adam']
    --weight_decay浮点数优化器的权重衰减0
    --epochs整数训练轮数20,50
    --batch_size整数训练时的小批量大小2048
    --eval_batch_size整数评估时的小批量大小128
    --embed_dim整数嵌入层的输出大小64
  • 对于"ItemTowerRecommender",还支持一些额外的参数:

    参数类型描述默认值可选项
    --sampler字符串采样器名称uniform['uniform','popularity','midx_uni','midx_pop','cluster_uni','cluster_pop']
    --negative_count整数负样本数量1正整数
  • 对于"TwoTowerRecommender",在"ItemTowerRecommender"的基础上还支持一些额外的参数:

    参数类型描述默认值可选项
    --split_mode字符串数据集的划分方法user_entry['user','entry','user_entry']

以下是一些不明确参数的详细说明。

  1. mode:在light模式和detail模式下,输出将显示在终端上,后者提供更详细的信息。tune模式将使用神经网络智能(NNI)显示一个漂亮的可视化界面。你可以使用类似config.yaml的配置文件运行tune.sh。有关NNI的更多详情,请参阅NNI文档

  2. sampleruniform表示使用均匀采样器。popularity表示根据物品流行度进行采样(更受欢迎的物品被采样的概率更高)。midx_unimidx_popmidx动态采样器,更多详情请参阅FastVAEcluster_unicluster_popcluster动态采样器,更多详情请参阅PRIS

  3. split_modeuser表示将所有用户分成训练/验证/测试数据集,这些数据集中的用户是互不相交的。entry表示将所有交互分成这三个数据集。user_entry表示将每个用户的交互分成三部分。

此外,你可以通过PyPi安装RecStudio:

pip install recstudio

基本用法如下:

import recstudio recstudio.run(model="BPR", data_dir="./datasets/", dataset='ml-100k')

更详细的信息,请参阅我们的文档 http://recstudio.org.cn/docs/。

自动超参数调优

RecStudio集成了NNI模块,用于自动调优超参数。为了简单使用,你可以在bash中运行以下命令:

nnictl create --config ./nni-experiments/config/bpr.yaml --port 2023

根据个人需求配置nni-experiments/config/bpr.yamlnni-experiments/search_space/bpr.yaml

有关NNI的更多详细信息,请参阅NNI文档

贡献

如果你遇到bug或有任何建议,请通过提交问题让我们知道。

我们欢迎所有贡献,从修复bug到新功能和扩展。

我们希望所有贡献首先在问题追踪器中讨论,然后通过PR进行。

团队

RecStudio由USTC BigData Lab开发和维护。

用户贡献
@DefuLian框架设计和构建
@AngusHuang17序列模型、文档、修复bug
@Xiuchen519基于知识的模型、修复bug
@JennahFNCF、CML、logisticMF模型
@HERECJAutoEncoder模型
@BinbinJinIRGAN模型
@pepsi2222排序模型
@echobelbo文档
@jinbaobaojhr文档

许可证

RecStudio使用MIT许可证

编辑推荐精选

Trae

Trae

字节跳动发布的AI编程神器IDE

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
问小白

问小白

全能AI智能助手,随时解答生活与工作的多样问题

问小白,由元石科技研发的AI智能助手,快速准确地解答各种生活和工作问题,包括但不限于搜索、规划和社交互动,帮助用户在日常生活中提高效率,轻松管理个人事务。

热门AI助手AI对话AI工具聊天机器人
Transly

Transly

实时语音翻译/同声传译工具

Transly是一个多场景的AI大语言模型驱动的同声传译、专业翻译助手,它拥有超精准的音频识别翻译能力,几乎零延迟的使用体验和支持多国语言可以让你带它走遍全球,无论你是留学生、商务人士、韩剧美剧爱好者,还是出国游玩、多国会议、跨国追星等等,都可以满足你所有需要同传的场景需求,线上线下通用,扫除语言障碍,让全世界的语言交流不再有国界。

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

AI办公办公工具AI工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图热门
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

热门AI开发模型训练AI工具讯飞星火大模型智能问答内容创作多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

热门AI辅助写作AI工具讯飞绘文内容运营AI创作个性化文章多平台分发AI助手
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

openai-agents-python

openai-agents-python

OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。

openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。

下拉加载更多