
网站链接: https://tinytraining.mit.edu/
论文链接: https://arxiv.org/abs/2206.15472
Demo: https://www.bilibili.com/video/BV1qv4y1d7MV
代码链接: https://github.com/mit-han-lab/tiny-training
导读
说到神经网络训练,大家的第一印象都是 GPU + 服务器 + 云平台。传统的训练由于其巨大的内存开销,往往是云端进行训练而边缘平台仅负责推理。然而,这样的设计使得 AI 模型很难适应新的数据:毕竟现实世界是一个动态的,变化的,发展的场景,一次训练怎么能覆盖所有场景呢?
为了使得模型能够不断的适应新数据,我们能否在边缘进行训练(on-device training),使设备不断的自我学习?在这项工作中,我们仅用了不到 256KB 内存就实现了设备上的训练,开销不到 PyTorch 的 1/1000,同时在视觉唤醒词任务上(VWW)达到了云端训练的准确率。该项技术使得模型能够适应新传感器数据。用户在享受定制的服务的同时而无需将数据上传到云端,从而保护隐私。
贡献
设备上的训练(On-device Training)允许预训练的模型在部署后适应新环境。通过在移动端进行本地训练和适应,模型可以不断改进其结果并为用户定制模型。通过让训练更在终端进行而不是云端,我们能有效在提升模型质量的同时保护用户隐私,尤其是在处理医疗数据、输入历史记录这类隐私信息时。
然而,在小型的 IoT 设备进行训练与云训练有着本质的区别,非常具有挑战性,首先, AIoT 设备(MCU)的 SRAM 大小通常有限(256KB)。这种级别的内存做推理都十分勉强,更不用说训练了。再者,现有的低成本高效转移学习算法,例如只训练最后一层分类器 (last FC),只进行学习 bias 项,往往准确率都不尽如人意,无法用于实践,更不用说现有的深度学习框架无法将这些算法的理论数字转化为实测的节省。
最后,现代深度训练框架(PyTorch,TensorFlow)通常是为云服务器设计的,即便把 batch-size 设置为 1,训练小模型(MobileNetV2-w0.35)也需要大量的内存占用。因此,我们需要协同设计算法和系统,以实现智能终端设备上的训练。

▲ 传统框架训练需要的内存大大超过了智能终端设备的资源,我们所提出的协同设计,有效的将内存开销从几百 MB 降低至 256KB 以内。
方法
我们发现设备上训练有两个独特的挑战:(1)模型在边缘设备上是量化的。一个真正的量化图(如下图所示)由于低精度的张量和缺乏批量归一化层而难以优化;(2)小型硬件的有限硬件资源(内存和计算)不允许完全反向传播,其内存用量很容易超过微控制器的 SRAM 的限制(一个数量级以上),但如果只更新最后一层,最后的精度又难免差强人意。

为了应对优化的困难,我们提出了 Quantization-Aware Scaling(QAS)来自动缩放不同位精度的张量的梯度(如下左图所示)。QAS 在不需要额外超参数的同时,可以自动匹配梯度和参数 scale 并稳定训练。在 8 个数据集上,QAS 均可以达到与浮点训练一致的性能(如下右图)。

为了减少反向传播所需要的内存占用,我们提出了 Sparse Update,以跳过不太重要的层和子张的梯度计算。我们开发了一种基于贡献分析的自动方法来寻找最佳更新方案。对比以往的 bias-only, last-k layers update, 我们搜索到的 sparse update 方案拥有 4.5 倍到 7.5 倍的内存节省,在 8 个下游数据集上的平均精度甚至更高。

为了将算法中的理论减少转换为实际数值,我们设计了 Tiny Training Engine(TTE):它将自动微分的工作转到编译时,并使用 codegen 来减少运行时开销。它还支持 graph pruning 和 reordering,以实现真正的节省与加速。
与 Full Update 相比,Sparse Update 有效地减少了 7-9 倍的峰值内存,并且可以通过 reorder 进一步提升至 20-21 倍的总内存节省。相比于 TF-Lite,TTE 里经过优化的内核和 sparse update 使整体训练速度提高了23-25倍。

实验
本文主要采用MobileNetV2、ProxylessNAS、MCUNet三个三类模型进行了实验。我们在ImageNet [22]上对模型进行了预训练,并进行了训练后量化[34]。量化模型在下游数据集上进行微调,以评估迁移学习能力。我们在一个微控制器STM32F746(320KB SRAM,1MB Flash)上进行训练和内存/延迟测量。为了更快地获得多个下游数据集上的精度统计数据,我们在gpu上模拟了训练结果,并验证了仿真与在微控制器上的训练相比获得了相同的精度水平。
我们对MCUNet的最后两个块(模拟低成本微调)进行微调,以各种下游数据集(表1)。对于动量SGD,由于优化困难,量化模型(int8)的训练精度落后于浮点模型。像Adam [36]这样的自适应学习速率优化器可以提高准确性,但仍然低于fp32的微调结果;由于二阶动量,它还会消耗3×的内存,这对于tinyML设置是不需要的。尽管有广泛的超参数调优,但LARS [69]仍然不能很好地收敛于大多数数据集(在学习率和“信任系数”上)。我们假设LARS的侵略性梯度尺度规则使训练不稳定。当我们应用QAS时,精度差距缩小了,在没有额外内存成本的情况下匹配浮点训练的精度。图8还提供了MCUNet在汽车数据集上和无QAS上的学习曲线(微调)。因此,QAS有效地有助于优化。
如图9所示,由于学习能力有限,仅进行分类器更新的准确性较低。这表明仅更新分类器是不够的,还需要更新主干。我们进一步测量了STM32FM746MCU上每张图像的训练延迟:TF-Lite微内核的完全更新,TF-Lite微内核的稀疏更新,TTE内核的稀疏更新(图10(c))。
我们在图11(下图(b),包含10个类)中搜索到100KB额外内存(分析)。它更新了最后22层的偏差,并稀疏地更新了6层的权重(有些是子张量更新)。最初的20层被冻结,只向前运行。为了理解为什么这个方案是有意义的,我们还绘制了在上子图(a).中更新每一层时的激活和权重所产生的内存成本我们看到了一个清晰的模式:初始层的激活成本很高;结束层的权重成本很高;而当我们更新中间层(层索引18-30)时,总内存成本很低。
内容中包含的图片若涉及版权问题,请及时与我们联系删除
评论
沙发等你来抢