Make the Most of Your Data: Changing the Training Data Distribution to Improve In-distribution Generalization Performance

2024年04月27日
  • 简介
    我们能否修改训练数据分布,以鼓励底层优化方法找到在分布数据上具有优越泛化性能的解决方案?在这项工作中,我们首次比较了梯度下降(GD)的归纳偏差和锐度感知最小化(SAM)的归纳偏差,来探讨这个问题。通过研究一个两层的CNN,我们证明了SAM更均匀地学习易于和困难的特征,特别是在早期阶段。也就是说,相比GD,SAM更不容易受到简单性偏差的影响。基于这个观察结果,我们提出了一种算法USEFUL,它根据训练早期的网络输出对示例进行聚类,并对没有易于特征的示例进行上采样,以减轻简单性偏差的缺陷。我们通过模仿SAM的训练动态,实证地证明了这种修改训练数据分布的方法在使用(S)GD进行训练时有效地提高了原始数据分布上的泛化性能。值得注意的是,我们证明了我们的方法可以与SAM和现有的数据增强策略相结合,实现在CIFAR10、STL10、CINIC10、Tiny-ImageNet上训练ResNet18,在CIFAR100上训练ResNet34,以及在CIFAR10上训练VGG19和DenseNet121等方面达到我们所知道的最先进的性能。
  • 图表
  • 解决问题
    论文旨在探讨如何修改训练数据分布以提高梯度下降等优化方法在分布数据上的泛化性能。研究比较了梯度下降和锐度感知最小化两种算法的归纳偏差,并提出了一种基于锐度感知最小化的算法USEFUL,通过聚类和上采样的方式改变训练数据分布,以减轻简单性偏差的影响。
  • 关键思路
    论文的关键思路是使用锐度感知最小化算法来改善优化算法在分布数据上的泛化性能,并通过USEFUL算法改变训练数据分布,以减轻简单性偏差的影响。
  • 其它亮点
    论文通过实验验证了USEFUL算法的有效性,使ResNet18在CIFAR10、STL10、CINIC10、Tiny-ImageNet数据集上达到了目前最好的性能。论文还提出了一种新的算法SAM,相比梯度下降算法,SAM更加均匀地学习简单和复杂的特征。论文开源了代码。
  • 相关研究
    相关研究包括使用数据增强和模型集成来提高模型的泛化性能,以及使用不同的优化算法来改善模型的训练效果。
PDF
原文
点赞 收藏 评论 分享到Link

沙发等你来抢

去评论