一、论文相关信息
1.论文题目
Distilling the Knowledge in a Neural Network
2.论文时间
2015年
3.论文文献
https://arxiv.org/abs/1503.02531
二、论文背景及简介
提高几乎所有机器学习算法性能的一个非常简单的方法是在相同的数据上训练许多不同的模型,然后对它们的预测进行平均。不幸的是,当使用集成模型做预测时,是十分笨重的,部署时,会花费巨大的计算代价。Caruana等人已经证明,可以将集成模型的知识压缩到单个模型,这样就可以让部署更加容易。这篇文章就是使用了另一种不同的压缩方法。作者表明,通过将模型集合中的知识提取到单个模型中,可以显著地改进大量使用的商业系统的声学模型。作者也引入了一种由一个或多个完整模型和许多特有模型组成的新的集成的类型,其中特有模型能有区分在完整模型中混淆的细粒度类。与那些expert模型不同,这些特有模型能够快速并行的进行训练。
三、论文内容总结
- Introduction总结
- 生物在生命的不同阶段具有不同的需求,也因此具有不同的结构与功能。机器学习也一样,在训练与部署阶段具有不同的需求,因此我们需要在不同的阶段做不同的事情。本文所提出的蒸馏,便是将训练阶段的模型通过变形使其适应部署阶段的需求。
- Caruana等人提出,可以将一个大的冗杂的模型的知识转移到一个小的更适合部署的模型中去,这种方法就叫做蒸馏(“distillation”)
- 我们很难去评价一个模型是否成功的转化到了另一个模型,特别是对于该模型的泛化能力来说,因为我们并没有这样的先验知识。而将其泛化能力进行转化,就是我们要讲到的重点,对对抗样本的防御方法。
- 将大模型的泛化能力转化为小模型的一个方法是利用大模型产生的类概率作为训练小模型的“soft targets”(就是将其label转换成大模型生成的类概率,进行训练)。
- 在Caruana等人做蒸馏时,出现了一些问题(见Intorduction第5段)。作者为了解决这个问题,提出,增加softmax的温度,直到其soft target比较合适位置。
- 我们对小模型进行训练时,使用的是tranfer集,该数据集可以是原本的数据集,也可以是一些不带标签的数据组成的
- Distillation方法
- 利用大模型产生的类概率作为训练小模型的“soft targets”(就是将其label转换成大模型生成的类概率,进行训练)
- 使用大模型产生的logits作为目标,是蒸馏的一个特例
- 作者在MINST数据集和语音识别任务上都进行了实验
- 作者提出了一个训练大型集成模型的方法
- 集成模型包括一个通用模型和多个”specialist“模型,每一个”specialist“模型专注于一部分易于区分的类别,把剩下的类别作为垃圾类。采用了一个独特的推断方式进行预测。
- 使用soft targets进行训练时,可以提高模型的泛化能力,防止过拟合,而且只用一小部分训练集就可以训练的很好。
附:如需继续学习对抗样本其他内容,请查阅对抗样本学习目录
四、论文主要内容
1、Introduction
许多昆虫会有一个用于从外界环境吸取养分的幼体模式,以及一个完全不同的用于旅行和繁殖的成年模式,这两种模式对应于不同的需求。但是在大量的机器学习方法中,我们都是使用相似的模型用于训练阶段和部署阶段。对于像语音和目标识别这样的任务,训练必须从非常大、高度冗余的数据集中提取结构,但不需要实时操作,而且需要大量的计算。但是,当部署给大量用户时,需要对延迟和计算资源有着更加严格的要求。从昆虫的思想考虑,我们在训练阶段要训练一个复杂的模型,以便为了我么能够很轻松的从数据中提取结构,这个模型可能是一个集成模型,当模型训练完成后,我们采用另一种训练方法,在这里我们称之为蒸馏(“distillation”),将模型学习到的知识转移到另一个小的更适合部署的模型中去。这种策略最早Caruana和他的合作者们所提出。在他们的论文中,该策略也已经被证实。
一个概念块可能阻止了对这种非常有前途的方法的更多研究,那就是我们倾向于用学习到的参数值来识别训练模型中的知识,这使得我们很难看到如何改变模型的形式,但保持相同的知识。将知识从任何特定实例化中解放出来的更抽象的知识视图是,它是从输入向量到输出向量的学习映射。对于学习区分大量类的大模型,通常的训练目标是使正确答案的平均对数概率最大化,但学习的副作用是训练模型为所有错误答案分配概率,即使这些概率非常小,也有一些比其他的大很多。不正确答案的相对概率告诉我们很多关于大模型如何趋向于泛化的信息。例如,一辆宝马车的图片,被误认为是垃圾车的可能性很小,但这个错误的可能性仍然比误认为是胡萝卜的可能性大很多倍。
众所周知,目标函数应该尽可能反映用户的真实目标。比如,当我们的目标函数目的是为了使模型更好的泛化到新数据时,我们就能够训练出一个泛化能力好的模型,但是,这需要我们对泛化要有足够的认识、信息。而我们目前并没有这些信息。当我们将大模型蒸馏到小模型时,我们可以训练这个小模型使其具有与大模型一样的泛化能力。而且,我们用训练大模型的方式去训练小模型,其小模型的泛化能力通常比用正常训练方式得到的小模型的泛化能力要好得多。
将大模型的泛化能力转化为小模型的一个方法是利用大模型产生的类概率作为训练小模型的“soft targets”(就是将其label转换成大模型生成的类概率,进行训练)。在这个转移阶段,我们可以使用同样的训练集或者使用一个单独“transfer”数据集。当大模型是集成模型时,我们根据每个模型的贡献率计算类概率的算数均值或者几何均值作为”soft targets”。当”soft targets”具有很高的熵时,当对于”hard targets”,其在训练的情况下,能够提供更多的信息量,且梯度变化也要小得多。因此小模型往往比原始的大模型能够在更少的数据集上及逆行训练,并且可以使用更高的学习率。
对于MNIST这样的任务来说,大模型总能够以很高的置信度生成正确的答案。关于所学习函数的大部分信息都存在于软目标中非常小的概率比中。比如:一张2的图片,在一个版本中,可能给予类别3 10^-6的概率,给与类别7 10^-9的概率,但是在另一个训练版本中,可能是相反的。这是有价值的信息,定义了数据上丰富的相似结构,但是这对于交叉熵损失函数来说却有着很小的影响,因为对交叉上来说,概率太小了,接近0了。Carunan等人通过使用logits损失函数规避了这个问题。他们最小化 大模型生成的logits和小模型生成的logits的平方差。我们更普遍的解决方案,称为“蒸馏”,是提高最终softmax的温度(下面会讲到),直到大模型产生一组合适的软目标。然后我们在训练小模型时使用相同的高温来匹配这些软目标。我们稍后将展示,匹配大模型的logits实际上是蒸馏的一个特例。
“transfer”集,可能是由没有标签的数据组成的,也可能时使用原始训练集。我们发现使用原始训练集效果很好,特别是如果我们在目标函数中加入一个小项,鼓励小模型预测真实目标,以及匹配笨重模型提供的软目标。通常,小模型不能精确地匹配软目标,朝着正确答案的方向出错是有帮助的。
2、Distillation
神经网络通常使用softmax来生成类别概率,其将logits(神经网络输出的值)$ z_i$ 转换成概率$ q_i$ :
T表示的softmax的温度,通常设为1,T越大,则softmax就能够生成更软的概率(更软的意思我在AI小知识系列第一讲中讲到)
蒸馏网络就是我们要转移到的那个小模型
在最简单的蒸馏形式中,通过使用transfer集和每个类别的软目标分布(使用 带有高温的softmax的大模型 生成的,就是指上面提到的各个类别的soft target的分布),对蒸馏网络进行训练从而将知识传递到蒸馏网络中。训练蒸馏模型时使用相同的高温,但训练后,进行部署时,使用的温度为1。
当所有或部分transfer集都有正确的标签时,上面的方法还可以通过训练蒸馏模型来生成正确的标签来得到显著改进。而让所有或者部分tranfer集有正确的标签的一种方法,是使用正确的标签(原本的离散的标签)来调整soft targets,但是我们发现了一种更好的方式,就是简单的使用两种不同目标函数的加权平均值。第一个目标函数是带有soft targets的交叉熵,该交叉熵使用跟蒸馏网络相同的高温来进行计算,其用于在大模型中生成软目标。第二个目标函数是带有正确标签(hard targets)的交叉熵,该交叉熵使用的温度为1,其使用的是蒸馏网络中softmax的logits。作者发现,当第二个目标函数具有相对较低的权重时,能够得到最好的结果。因为,通过soft targets训练而得到的网络的梯度的大小大概在$ 1/T^2$ ,因此,我们要把hard targets 与 soft targets都乘以$ T^2$ 。当蒸馏网络的温度改变(调参)时,这保证了hard targets 与 soft targets 的相对分布不会发生改变。
2.1 匹配logits是蒸馏的一个特例
在tranfer集上的每一个样本对logits都会有一个交叉熵梯度$ dC/dz_i $ ,如果大模型由logits$ v_i$ ,其对应的soft target 概率为$ p_i$ ,softmax的温度为T,则梯度可以表示为:
如果T比logits大,那么:
如果我们假设在transfer集上的所有logits的均值为0,即:$ \sum_j z_j = \sum_j v_j=0$ ,则:
因此,在高的T值以及logits的0均值的限制下,蒸馏等价于最小化$ \frac{1}{2}(z_i-v_i)^2$ 。而在低的温度下,蒸馏则不会特别的注重于匹配比平均值负得多logits,这是蒸馏的优势,因为被损失函数用作训练大模型的logits几乎没有限制条件,所以可能会存在很大的噪声。另一方面,非常负的logits,可能会传递由大模型产生的有用的信息。这些影响中哪一个占主导地位是一个经验问题。我们表明,当蒸馏模型太小而无法捕获大模型中的所有知识时,中间温度的效果最好,这强烈表明忽略较大的负对数是有帮助的。
3、在MNIST数据集上进行初步实验
作者使用了一个具有两层隐藏层的网络(1200个神经元),使用了relu,dropout,权重约束。在测试集(10000张图片)上,只有67张图片测试错误,之后作者用了一个更小的网络,两个隐藏层,800个神经元,没有正则化手段,有146张测试错误。然后作者尝试对更小的网络进行蒸馏,作者将一个使用20温度的大网络迁移到这个小网络,迁移后,这个小网络只有74张测试错误。这表明soft target可以迁移大量的知识到蒸馏网络,包括如何从翻译后的训练数据中学习到的知识,即使转移集不包含任何翻译。
当蒸馏网络,每层隐藏层有300+神经元时,将温度设置为8+,有着相当相似的结果。但是,当每层隐藏层的神经元降到每层30个的时候,将温度设置为2.5~4,得到的效果比4+或者2.5-的温度要好。
之后作者尝试将数字3从transfer集中删除,对蒸馏网络来说,3是个神秘的数字,因为训练集里没有3。这样试验后,蒸馏网络将206张图片分类错误(133张为数字3,测试集上总共1010张3)。大部分分类错误的原因是,给类别3 分的权重太低了,如果将这个权重乘以3.5,那么就只有109张错误的,其中,只有14张是数字3。所以,当有着正确的权重时,蒸馏网络能够得到98.6%的对3的正确率。如果我们继续增大权重,增加到7倍或者8倍,那么蒸馏网络的错误率就上升到了47.3%。
4、在语音识别上进行试验
作者在语音识别任务上进行了实验,使用的ASR模型(Automatic Speech Recognition),发现使用蒸馏策略得到的模型要比直接训练得到的相同大小的模型的效果好得多。
5、在大数据集上进行训练出集成模型
训练集成模型是非常简单的,我们可以通过并行化的方式来训练。但是集成模型也有一些缺点,集成模型在测试阶段需要大量的计算(可以通过蒸馏方法来解决),另一个缺点就是,当单个模型是大的神经网络模型,并且数据集很大时,集成模型的训练需要消耗大量的计算资源。
在这一节,我们给除了这样的例子,我们展示了,如何学习到这样一个集成模型,在集成模型中,每个单独的模型都关注一个不同的子集,这样就减少了计算量,但这样也带来了问题,就是很容易过拟合。而这个过拟合可以通过使用soft targets来解决。
5.1 JFT数据集
JFT是谷歌内部的数据集,有1亿张图片,15000个类。Google基于JFT的backbone是一个深度卷积网络,在数据集上训练了6个月,用了两种并行化的方法。第一种是我们有很多份相同的神经网络分布在不同的核上,每一份神经网络处理一份batch,得到梯度,根据这些梯度得到一个新的梯度,使用新的梯度更新神经网络的梯度。第二种是将神经网络拆分,然后将其放置到不同的核中,就像AlexNet那样。集成训练时第三种并行化训练的方法。
很明显,用几个月的时间来训练一个模型并不是一个好的选择,所以我们需要找到一个更快的训练baseline的方法。
5.2 Specialist Models
当类别数目很大时,训练一个大型的集成网络(包含一个在全部数据集上训练的通用模型以及许多“specialist”模型)是很有意义的。这些“specialist”模型是由一个特别的训练集训练来的,这个训练集包含的是互相之间差别较大的图片。这样的方式可以让softmax损失最小,因为分类器不用去区分那些非常相似的图片,而分类器没有关注的类别会被分类到垃圾类别中。
为了减少过拟合,并且让这些“specialist”的模型能够共享一些低阶特征,每一个“specialist”模型都会用通用模型的权重来进行初始化,然后在他们特别的训练集上进行fine-tuning。这些特别的训练集是这样得到的,他们一半来自于特殊的子集,一半来自于训练集的补集(lable为垃圾类别)。在训练结束后,我们可以通过将垃圾类别的logits乘以抽样比的对数来修正。
5.3 将类别分配给模型
上面我们简单的介绍了训练集怎么得到,那么我们怎样得到这些特殊的自己呢?作者将目标转向于那些经常被分错的样本。尽管我们可以计算混淆矩阵并将其用作查找此类集群的方法,但我们选择了一种更简单的方法,不需要真正的标签来构造集群。‘
我们将聚类算法应用于我们的广义模型预测的协方差矩阵,一组经常一起预测的类$ S^m$ 将被用作我们的一个“specialist”模型m的目标,我们对协方差矩阵的列应用了在线版本的K-均值算法,并获得了合理的簇。
5.4 集成模型的推断
上面所讲到的集成模型的预测过程如下,给定一个输入图片$ x$ :
- 第一步,我们首先根据通用模型找到$ n$ 个最可能的类别,我们把它称之为类别集合$ k$ 。作者取n=1
- 第二步,我们找到所有的“specialist”模型,这些模型的训练集的混淆类别$ S^m$ 与$ k$ 相交不为空。把这些模型集合称之为$ A_k$ ,我们的目标就是找到一个概率分布$ q$ ,使得下面的式子最小:
其中$ KL$ 表示$ KL$ 散度,$ p^g$ 表示通用模型的类别概率,$ p^m$ 表示第$ m$ 个“specialist”模型得到的类别概率(包括一个垃圾类别)
我们会参数化$ q=softmax(z)(T=1)$ ,然后使用梯度下降来优化logits$ z$ 。
5.5 结果
使用上面的集成训练方法速度很快,在训练JFT训练集时只需要几天的时间(而不是几周),而且,所有的”specialist“模型的训练过程都是独立的。作者进行了实验,当使用61个”specialist“模型时,准确率还提高了4.4%。
在这次实验过程中,使用了61个”specialist“模型,每一个模型包括300个类别(包括垃圾类别)。作者发现,当一个类别被更多的”specialist“模型覆盖时,它预测的准确率就提升的越高。
6、Soft Targets as Regularizers
soft tartgets中带有一些有用的信息,这些是hard targets所不具备的。当我们使用soft targets的时候,我们可以用更少的数据集就能够训练出一个baseline,而且能够减少模型参数的大小,但是当我们使用hard targets的时候,我们用同样大小的数据集训练时会出现过拟合的情况,下面的图表表示了这几种情况:
6.1 利用soft targets来防止过拟合
我们之前讲过的集成策略,在”specialist“模型训练时,会将大量的样本分类为垃圾类别,而有用的样本数是比较少的,这样就很容易让”specialist“模型过拟合,若是我们丢弃一些垃圾样本,这会让我们丢失掉很多的信息,所以我们就是用soft targets策略训练来防止其过拟合。这里的soft targets可以来自于集成策略中的通用模型。
7、Relationship on Mixtures of Experts
集成训练的策略中”specialist“模型的融合过程可能跟 门网络是相似的,利用门网络给样例分配模型时,门网络会根据模型的相对判别性能来进行选择,这固然很好,比简单的对输入向量进行聚类,根据聚类结果进行跟配模型的方法要好得多,但这会使得训练很难并行化,同时也会有一些其他的训练困难,因为这个方案不太可行性。
8、Discussion
- 我们展示了蒸馏网络可以很好的将一个大的模型的能力转移到一个小的模型中
- 使用我们的集成训练方法,可以有效的减少一个大模型训练的时间