2023年11月更新:SePiCo被选为 :trophy: <span style="color:red">ESI高被引论文</span>!!
2023年2月15日更新:发布Cityscapes → Dark Zurich的代码。
2023年1月14日更新:🥳 我们很高兴地宣布SePiCo已被TPAMI接收并将在即将出版的一期中发表。
2022年9月24日更新:所有检查点均已可用。
2022年9月4日更新:代码发布。
2022年4月20日更新:SePiCo的ArXiv版本已发布。
在这项工作中,我们提出了语义引导的像素对比学习(SePiCo),这是一种新颖的单阶段适应框架,它突出了单个像素的语义概念,以促进跨域类别区分性和类别平衡的像素嵌入空间的学习,最终提升自训练方法的性能。
<img src="https://yellow-cdn.veclightyear.com/835a84d5/8b12dafb-7541-461f-a63d-bc14f65a3866.png" width=50% height=50%> <div align="right"> <b><a href="#概述">↥</a></b> </div>本代码使用Python 3.8.5
和PyTorch 1.7.1
在CUDA 11.0
上实现。
要尝试这个项目,建议先设置一个虚拟环境:
# 创建并激活环境 conda create --name sepico -y python=3.8.5 conda activate sepico # 为新的Python环境安装正确的pip和依赖项 conda install -y ipython pip
然后,可以通过以下方式安装依赖项:
# 安装所需的包 pip install -r requirements.txt # 安装mmcv-full,此命令在本地编译mmcv,可能需要一些时间 pip install mmcv-full==1.3.7 # 需要先安装其他包
或者,可以使用官方预构建的包更快地安装mmcv-full
,例如:
# 另一种安装mmcv-full的方法,更快 pip install mmcv-full==1.3.7 -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.0/index.html
现在环境已经完全准备好了。
<div align="right"> <b><a href="#概述">↥</a></b> </div>创建所需数据集的符号链接:
ln -s /path/to/gta5/dataset data/gta ln -s /path/to/cityscapes/dataset data/cityscapes ln -s /path/to/dark_zurich/dataset data/dark_zurich
进行预处理以将标签ID转换为训练ID并收集数据集统计信息:
python tools/convert_datasets/gta.py data/gta --nproc 8 python tools/convert_datasets/cityscapes.py data/cityscapes --nproc 8
最终,数据结构应该如下所示:
<div align="right"> <b><a href="#overview">↥</a></b> </div>SePiCo ├── ... ├── data │ ├── cityscapes │ │ ├── gtFine │ │ ├── leftImg8bit │ ├── dark_zurich │ │ ├── corresp │ │ ├── gt │ │ ├── rgb_anon │ ├── gta │ │ ├── images │ │ ├── labels ├── ...
我们通过Google Drive和百度网盘(访问码:pico
)提供了两个域适应语义分割任务的预训练模型。
变体 | 模型名称 | mIoU | 检查点下载 |
---|---|---|---|
DistCL | sepico_distcl_gta2city_dlv2.pth | 61.0 | Google / 百度 (提取码: pico ) |
BankCL | sepico_bankcl_gta2city_dlv2.pth | 59.8 | Google / 百度 (提取码: pico ) |
ProtoCL | sepico_protocl_gta2city_dlv2.pth | 58.8 | Google / 百度 (提取码: pico ) |
变体 | 模型名称 | mIoU | 检查点下载 |
---|---|---|---|
DistCL | sepico_distcl_gta2city_daformer.pth | 70.3 | Google / 百度 (提取码: pico ) |
BankCL | sepico_bankcl_gta2city_daformer.pth | 68.7 | Google / 百度 (提取码: pico ) |
ProtoCL | sepico_protocl_gta2city_daformer.pth | 68.5 | Google / 百度 (提取码: pico ) |
变体 | 模型名称 | mIoU | 检查点下载 |
---|---|---|---|
DistCL | sepico_distcl_syn2city_dlv2.pth | 58.1 | Google / 百度 (提取码: pico ) |
BankCL | sepico_bankcl_syn2city_dlv2.pth | 57.4 | Google / 百度 (提取码: pico ) |
ProtoCL | sepico_protocl_syn2city_dlv2.pth | 56.8 | Google / 百度 (提取码: pico ) |
变体 | 模型名称 | mIoU | 检查点下载 |
---|---|---|---|
DistCL | sepico_distcl_syn2city_daformer.pth | 64.3 | 谷歌 / 百度 (提取码: pico ) |
BankCL | sepico_bankcl_syn2city_daformer.pth | 63.3 | 谷歌 / 百度 (提取码: pico ) |
ProtoCL | sepico_protocl_syn2city_daformer.pth | 62.9 | 谷歌 / 百度 (提取码: pico ) |
变体 | 模型名称 | mIoU | 检查点下载 |
---|---|---|---|
DistCL | sepico_distcl_city2dark_dlv2.pth | 45.4 | 谷歌 / 百度 (提取码: pico ) |
BankCL | sepico_bankcl_city2dark_dlv2.pth | 44.1 | 谷歌 / 百度 (提取码: pico ) |
ProtoCL | sepico_protocl_city2dark_dlv2.pth | 42.6 | 谷歌 / 百度 (提取码: pico ) |
变体 | 模型名称 | mIoU | 检查点下载 |
---|---|---|---|
DistCL | sepico_distcl_city2dark_daformer.pth | 54.2 | 谷歌 / 百度 (提取码: pico ) |
BankCL | sepico_distcl_city2dark_daformer.pth | 53.3 | 谷歌 / 百度 (提取码: pico ) |
ProtoCL | sepico_distcl_city2dark_daformer.pth | 52.7 | 谷歌 / 百度 (提取码: pico ) |
我们训练的模型(sepico_distcl_city2dark_daformer.pth)也在Nighttime Driving和BDD100k-night测试集上进行了泛化性能测试。
方法 | 模型名称 | Dark Zurich-test | Nighttime Driving | BDD100k-night | 检查点下载 |
---|---|---|---|---|---|
SePiCo | sepico_distcl_city2dark_daformer.pth | 54.2 | 56.9 | 40.6 | 谷歌 / 百度 (提取码: pico ) |
要在Cityscapes上评估预训练模型,请按如下方式运行:
<details> <summary>示例</summary>python -m tools.test /path/to/config /path/to/checkpoint --eval mIoU