论文链接: https://openreview.net/pdf?id=luGXvawYWJ
代码链接: https://github.com/Huage001/DatasetFactorization
简介
深度学习取得了巨大成功,训练一般需要大量的数据。存储、传输和数据集预处理成为大数据集使用的阻碍。另外发布原始数据可能会有隐私版权等问题。
数据集蒸馏(Dataset Distillation)是一种解决方案,通过蒸馏一个数据集形成一个只包含少量样本的合成数据集,同时训练成本显著降低。数据集蒸馏可以用于持续学习、神经网络架构搜索等领域。
最早提出的数据集蒸馏算法核心思想即优化合成数据集,在下游任务中最小化损失函数。DSA( Dataset condensation with differentiable siamese augmentation)、GM( Dataset condensation with gradient matching)、CS(Dataset condensation with contrastive signals)等方法提出匹配真实数据集和合成数据集的梯度信息的算法。MTT(Dataset distillation by matching training trajectories)指出由于跨多个步骤的误差累计,单次迭代的训练误差可能导致较差的性能,提出在真实数据集上匹配模型的长期动态训练过程。除了匹配梯度信息的方法,DM(Dataset condensation with distribution matching)提出了匹配数据集分布,具体方法是添加最大平均差异约束( Maximum Mean Discrepancy,MMD)。
本文方法将合成数据集分解为两个部分:数据幻觉器网络(Data Hallucination Network)和基础数据(Bases)。数据幻觉器网络将基础数据作为输入,输出幻觉图像(合成图像)。在数据幻觉器网络训练过程中,本文考虑添加特殊设计的对比学习损失和一致性损失。本文方法得到的合成数据集在跨架构任务中比基准方法取得了精度10%的提升。
方法
传统的数据集蒸馏方法将合成样本独立处理,忽略了不同样本间的内部关系,可能导致较差的数据效率。本文方法提出将数据集蒸馏定义为包含H个幻觉器和B个基础样本的幻觉器-基(hallucinator-basis)的分解问题:
训练过程时,训练数据通过传入第i个基础数据在线生成。合成数据可以表示为:\( \hat{x}_{ij}=H_{\theta_j}(\hat{x}_{I}),\hat{y}_{ij}=\hat{y}_{j} \)。
基与幻觉器
先前数据集蒸馏方法中,为了在下游模型中输入和输出的形状保持一致,合成数据的形状需要与真实数据相同。由于幻觉器网络可以使用空间和通道变换,本文方法没有形状相同限制。
给了基数据\( \hat{x} \in R^{h' \times w' \times c'} \),一个幻觉器网络,目标创建一个输出\( \hat{x} \in R^{h \times w \times c} \)。该任务可以视作为一个条件图像生成问题。借鉴于风格迁移任务,本文的幻觉器网络设计为encoder-transform-decoder架构。编码器由若干卷积层组成,将输入非线性映射。之后经过尺度\( \sigma \)和位移\( \mu \)的仿射变化。\( \sigma \)和\( \mu \)是网络参数。解码器是和编码器对称的CNN网络将特征映射到图像空间。
对抗性对比约束
本文的幻觉器网络训练过程是一个最小-最大博弈(min-max game)过程。最大化过程即最大化不同幻觉器间的差异。输入\( x_{ij} \)在幻觉器最后一层的输出定义为\( F_{-1}(x_{ij}) \)。损失函数类似于对比学习,可以描述为:
对于图像分类任务,
另一个损失函数关注于减少幻觉器网络输出\( \hat{x}_{ij} \)与\( \hat{x}_{ik} \)间的差异,核心目标是增加合成数据集的数据多样性。损失函数可以描述为:
分解训练方法
与先前的数据集蒸馏方法训练范式类似,合成数据集按照迭代算法更新。每一个迭代周期,随机选取幻觉器和基,形成若干幻觉器-基组合。训练的损失函数包含知识蒸馏损失与一致性损失:
本文的数据集蒸馏损失函数采用MTT方法。核心思想是使用训练周期为的模型权重,使用合成数据集训练次,使用真实数据集训练次,通过损失函数使合成数据集更新的参数与真实数据集更新的参数保持一致:
实验
与SOTA方法的比较结果。比较的方法包括核心集算法(Coreset),数据集蒸馏方法(元学习方法DD、LD,训练匹配方法DC、DSA、DSA,分布匹配方法DM、CAFE)和本文方法Factorization。超参数,每一类合成样本数(IPC)[1,10,50],本文的每一类基数量(BPC)[1,9,49]。
下图给出了实验结果。可以看出本文方法取得了最高的精度,在合成数据集样本数小于1%时性能差异最为显著。
与不同合成数据集生成算法和不同卷积神经网络模型组合的比较实验。在AlexNet网络的实验中,本文的方法与MTT相比最高取得了17.57%的性能提升。
不同类别是否共享幻觉器的Ablation实验。在相同的BPC条件下,较少的合成样本数情况下不共享幻觉器的方法(w/o share)可以获得更好的性能。较多的BPC情况下,不共享幻觉器方法不能获得更好的性能。主要原因:1)共享幻觉器方法可以获得数据集的全局信息。2)不共享幻觉器的方法给优化过程较大的负担
本文方法基和幻觉器生成图像的可视化如下:
内容中包含的图片若涉及版权问题,请及时与我们联系删除
评论
沙发等你来抢