这是一个基于PyTorch的开源在线测试时适应仓库。它是Robert A. Marsden和Mario Döbler的联合作品。这也是以下作品的官方仓库:
@article{marsden2022gradual,
title={Gradual test-time adaptation by self-training and style transfer},
author={Marsden, Robert A and D{\"o}bler, Mario and Yang, Bin},
journal={arXiv preprint arXiv:2208.07736},
year={2022}
}
@inproceedings{dobler2023robust,
title={Robust mean teacher for continual and gradual test-time adaptation},
author={D{\"o}bler, Mario and Marsden, Robert A and Yang, Bin},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={7704--7714},
year={2023}
}
@inproceedings{marsden2024universal,
title={Universal Test-time Adaptation through Weight Ensembling, Diversity Weighting, and Prior Correction},
author={Marsden, Robert A and D{\"o}bler, Mario and Yang, Bin},
booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
pages={2555--2565},
year={2024}
}
@article{dobler2024lost,
title={A Lost Opportunity for Vision-Language Models: A Comparative Study of Online Test-time Adaptation for Vision-Language Models},
author={D{\"o}bler, Mario and Marsden, Robert A and Raichle, Tobias and Yang, Bin},
journal={arXiv preprint arXiv:2405.14977},
year={2024}
}
</details>
我们欢迎贡献!非常欢迎并感谢添加方法的拉取请求。
要使用这个仓库,我们提供了一个conda环境。
conda update conda conda env create -f environment.yml conda activate tta
这个仓库包含了一系列不同的方法、数据集、模型和设置,我们在一个全面的基准测试中对其进行了评估(见下文)。我们还提供了一个关于如何将此仓库与CLIP类模型结合使用的教程,可以在这里找到。 以下是仓库主要特性的简要概述:
数据集
cifar10_c
CIFAR10-Ccifar100_c
CIFAR100-Cimagenet_c
ImageNet-Cimagenet_a
ImageNet-Aimagenet_r
ImageNet-Rimagenet_v2
ImageNet-V2imagenet_k
ImageNet-Sketchimagenet_d
ImageNet-Dimagenet_d109
domainnet126
DomainNet (清洗后)持续变化的损坏
CCC模型
设置
reset_each_shift
适应一个域后重置模型状态。continual
在一系列域上训练模型,不知道域转移何时发生。gradual
在一系列逐 渐增加/减少的域转移上训练模型,不知道域转移何时发生。mixed_domains
在一个长测试序列上训练模型,其中连续的测试样本可能来自不同的域。correlated
与持续设置相同,但每个域的样本进一步按类别标签排序。mixed_domains_correlated
混合域并按类别标签排序。gradual_correlated
或reset_each_shift_correlated
。方法
混合精度训练
模块化设计
要运行以下基准测试之一,需要下载相应的数据集。
下载缺失的数据集后,您可能需要调整位于conf.py
文件中的根目录路径_C.DATA_DIR = "./data"
。对于各个数据集,目录名称在conf.py
中以字典形式指定(参见complete_data_dir_path
函数)。如果您的目录名称与映射字典中指定的不同,您可以简单地修改它们。
我们为所有实验和方法提供了配置文件。只需使用相应的配置文件运行以下Python文件。
python test_time.py --cfg cfgs/[ccc/cifar10_c/cifar100_c/imagenet_c/imagenet_others/domainnet126]/[source/norm_test/norm_alpha/tent/memo/rpl/eta/eata/rdumb/sar/cotta/rotta/adacontrast/lame/gtta/rmt/roid/tpt].yaml
对于imagenet_others,需要传递CORRUPTION.DATASET
参数:
python test_time.py --cfg cfgs/imagenet_others/[source/norm_test/norm_alpha/tent/memo/rpl/eta/eata/rdumb/sar/cotta/rotta/adacontrast/lame/gtta/rmt/roid/tpt].yaml CORRUPTION.DATASET [imagenet_a/imagenet_r/imagenet_k/imagenet_v2/imagenet_d109]
例如,要运行ROID进行ImageNet-to-ImageNet-R基准测试,请运行以下命令。
python test_time.py --cfg cfgs/imagenet_others/roid.yaml CORRUPTION.DATASET imagenet_r
或者,您可以通过运行classification/scripts
子目录中的run.sh
来重现我们的实验。对于不同的设置,修改run.sh
中的setting
。
要运行不同的连续DomainNet-126序列,您必须传递MODEL.CKPT_PATH
参数。如果不指定CKPT_PATH
,将使用以real域作为源域的序列。这些检查点由AdaContrast提供,可以在这里下载。从结构上讲,最好将它们下载到./ckpt/domainnet126
目录中。
python test_time.py --cfg cfgs/domainnet126/rmt.yaml MODEL.CKPT_PATH ./ckpt/domainnet126/best_clipart_2020.pth
对于GTTA,我们提供了风格转换网络的检查点文件。这些检查点可在
Google-Drive(下载);
将zip文件解压到classification
子目录中。
更改评估配置非常简单。例如,要在reset_each_shift
设置下使用ResNet-50和IMAGENET1K_V1
初始化在ImageNet-to-ImageNet-C上运行TENT,需要传递以下参数。
更多模型和初始化可以在这里(torchvision)或这里(timm)找到。
python test_time.py --cfg cfgs/imagenet_c/tent.yaml MODEL.ARCH resnet50 MODEL.WEIGHTS IMAGENET1K_V1 SETTING reset_each_shift
对于ImageNet-C,robustbench提供的默认图像列表每个域考虑5000个样本
(参见这里)。如果你有兴趣在全部
50,000个测试样本上运行实验,只需设置CORRUPTION.NUM_EX 50000
,即
python test_time.py --cfg cfgs/imagenet_c/roid.yaml CORRUPTION.NUM_EX 50000
我们支持大多数方法使用损失缩放的自动混合精度更新。
默认情况下混合精度设置为false。要激活混合精度,设置参数MIXED_PRECISION True
。
我们在这里提供了每种方法使用不同模型和设置的详细结果, 基准测试会定期更新,随着新方法、数据集或设置添加到仓库中。 关于设置或模型的更多信息也可以在我们的论文中找到。
要运行基于CarlaTTA的实验,你首先需要下载如下提供的数据集分割。同样,你可能需要在conf.py
中更改数据目录_C.DATA_DIR = "./data"
。此外,你需要下载预训练的源检查点(下载