admin 管理员组

文章数量: 888526

【深度学习】关于EMA:指数移动平均

什么是EMA

        指数移动平均(exponential moving average),也叫做权重移动平均(weighted moving average),可以用来估计变量的局部均值,使得变量的更新与一段时间内的历史取值有关。在采用 SGD 或者其他的一些优化算法 (Adam, Momentum) 训练神经网络时,通常会使用EMA的方法。 它的意义在于利用滑动平均的参数来提高模型在测试数据上的健壮性。(在SGD优化算法中,也会通过使用动量或者改变学习率的方式加快收敛速度)。

EMA公式

        

        shadowVariable 为最后经过 EMA 处理后得到的参数值,Variable 为当前 epoch 轮次的参数值。EMA 对每一个待更新训练学习的变量 (variable) 都会维护一个影子变量 (shadow variable)。影子变量的初始值就是这个变量的初始值。由上述公式可知, decay 控制着模型更新的速度,越大越趋于稳定。实际运用中,通常会设为一个十分接近 1 的常数 (0.999 或 0.9999)。

EMA为什么可以提升模型性能

        EMA可以使得模型在测试数据上更健壮,“采用随机梯度下降算法训练神经网络时,使用滑动平均在很多应用中都可以在一定程度上提高最终模型在测试数据上的表现。”

        对神经网络边的的权重进行移动平均,得到对应的影子权重:shadow_weights,在训练过程中仍然使用不带滑动平均的权重(原始weights),以得到 weights 下一步更新的值,进而求下一步 weights 的影子权重 shadow_weights。在测试过程中,则使用影子权重代替原始weights,这样在测试数据上的效果更好,因为shadow_weights的更新更加平滑。

  • 随机梯度下降:更平滑的更新说明不会偏离最优点很远
  • batch gradient decent:影子变量作用可能不大,因为梯度下降的方向已经是最优的了,loss 一定减小
  • mini-batch gradient decent:可以尝试滑动平均,因为mini-batch gradient decent 对参数的更  新也存在抖动

pytorch实现 :

class EMA():def __init__(self, model, decay):self.model = modelself.decay = decayself.shadow = {}self.backup = {}def register(self):for name, param in self.model.named_parameters():if param.requires_grad:self.shadow[name] = param.data.clone()def update(self):for name, param in self.model.named_parameters():if param.requires_grad:assert name in self.shadownew_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]self.shadow[name] = new_average.clone()def apply_shadow(self):for name, param in self.model.named_parameters():if param.requires_grad:assert name in self.shadowself.backup[name] = param.dataparam.data = self.shadow[name]def restore(self):for name, param in self.model.named_parameters():if param.requires_grad:assert name in self.backupparam.data = self.backup[name]self.backup = {}# 初始化
ema = EMA(model, 0.999)
ema.register()# 训练过程中,更新完参数后,同步update shadow weights
def train():optimizer.step()ema.update()# eval前,apply shadow weights;eval之后,恢复原来模型的参数
def evaluate():ema.apply_shadow()# evaluateema.restore()

本文标签: 深度学习关于EMA指数移动平均