HOW TO TRAIN YOUR MAML
MAML的问题
- 训练不稳定性
训练时很不稳定,依赖于神经网络的架构和外层超参数的设定。 可能导致梯度爆炸或梯度消失。
- 二阶导的成本
梯度更新需要计算二阶导,非常昂贵,MAML中提出使用一阶近似来加速这个过程,但这样做会对最终泛化误差产生负面的影响。 在不牺牲泛化性能的情况下减少计算时间的方法还没有被提出。
- 批处理归一化(Batch Normalization)统计量累积的缺失
在原MAML论文中实验里批处理归一化使用方式是有问题的,其没有累积得到的统计量,而是直接将当前统计量用于批处理归一化,这样学习到的偏差需要适应不同的均值和标准偏差,而如果使用的是累积的统计量,那么最终会收敛到某个全局的均值和标准偏差,并且这样做收敛更快、更稳定、泛化性能更好。
- 共享(跨步)批处理归一化偏差
另一个问题是其批处理归一化偏差并不是在inner-loop中更新的,而在整个base-model迭代中使用了同样的bias,这含蓄地假设了所有的base-model在inner-loop中更新时都一样,即其特征有同样的分布,但这个假设是错误的。在每个inner-loop中,会有新的base-model产生,并且其和之前的模型有足够的不同,从bias的估计上应该被认为是一个新的模型。
- 共享内循环(跨步和跨参数)学习率
还有个同时影响泛化性能和收敛速度(用训练迭代次数表示)的问题是所有参数和更新都使用同一个共享的学习率。这样需要一个固定的学习率,需要进行多次超参数调整,可能会非常昂贵。 Learning to learn quickly for few shot learning. arXiv preprint arXiv:1707.09835, 2017.的作者提出可以对网络每个参数的学习率进行学习,可以解决上述问题,但也有自己的问题,这样会增加计算和内存空间的成本。
- 固定的外循环学习率
MAML中作者使用固定学习率的Adam来训练元优化器,使用阶跃函数或余弦函数对学习速率进行退火已被证明是在多种设置下实现先进的泛化性能的关键,因此证明了使用静态学习率有可能降低了MAML的泛化性能和优化速度,并且固定的学习率也意味着须花费更多的时间调整。
MAML++
一个一个解决上述问题。
- 梯度不稳定性 -> 多步损失优化(Multi-Step Loss Optimization, MSL)
在每个base-model的inner-loop中,对support-set中的每一步都对目标进行更新: $$ \theta = \theta - \beta \nabla_\theta \sum_{b=1}^B \sum_{i=0}^N v_i L_{T_b}(f_{\theta_i^b}) $$ b表示任务,i表示每个任务中第i步,\(v_i\)表示第i步后的权重。 还需要对每步损失的权重进行退火,一开始所有损失贡献相同,但随着训练迭代,我们减少靠前步骤的权重,缓慢增加靠后步骤的权重,这保证随着训练的进行,优化器会更重视最终的步骤从而达到最低可能的损失。
- 二阶导成本 -> 导数退火(Derivative-Order Annealing, DA)
原MAML算法中需要计算二阶导,其作者提出使用一阶近似来计算,但他是在整个训练过程中都使用一阶近似。MAML++提出,可以在前50个epochs中使用一阶近似,而之后都使用二阶导计算,这样经过经验可以得到前50个epochs可以被极大地加速,并且后面使用二阶导计算可以得到很强的泛化性能。 并且还可以观察到,这样做可以避免梯度爆炸和梯度消失,而全部直接使用二阶导会更加不稳定。说明DA比单独使用标准MAML算法更加稳定,使用一阶近似进行训练相当于标准MAML的一种预训练,让后期MAML对模型的训练更好之外,可以避免标准MAML算法出现梯度衰减、梯度爆炸现象。
- 批处理归一化统计量累积的缺失 -> 每步批处理归一化处理统计量(Per-Step Batch Normalization Running Statistics, BNRS)
MAML++提出在每步中收集统计量,需要在网络的每一批归一化层都实例化N组running均值和标准偏差,然后随着优化的每一步分别更新这些统计量。
- 共享(跨步)批处理归一化偏差 -> 每步的批处理归一化权重和偏差(Per-Step Batch Normalization Weights and Biases, BNWB)
MAML++提出在inner-loop更新过程中每步都学习一组偏差。这样做意味着Batch Normalization将学习特定于在每个集合处看到的特征分布的偏差,这将提高收敛速度,稳定性和泛化性能。
- 共享内循环(跨步和跨参数)学习率 -> 每层每步的学习率和梯度方向的学习(Learning Per-Layer Per-Step Learning Rates and Gradient Directions, LSLR)
之前提到,Learning to learn quickly for few shot learning. arXiv preprint arXiv:1707.09835, 2017.的作者提出可以对网络每个参数的学习率进行学习,可以解决上述问题,但也有自己的问题,这样会增加计算和内存空间的成本。 MAML++提出对网络中每一层学习一个学习率和一个搜索方向。
- 固定的外循环学习率 -> 元优化器学习率的余弦退火(Cosine Annealing of Meta-Optimizer Learning Rate, CA)
原MAML模型采用固定的外循环学习率,MAML++提出对元学习优化器(外循环)学习率使用余弦退火。