def set_model(self, model):
     """Bind the model and set up its exponential-moving-average state.

     After delegating to the parent ``set_model``, this:

     * creates one shadow variable per entry of ``model.weights`` and seeds
       it with that weight's current value;
     * keeps that snapshot of current values in ``self.old_weights``;
     * builds ``self.updates``, one ``K.moving_average_update`` per
       (shadow, weight) pair, using ``self.momentum`` as the decay.
     """
     super(ExponentialMovingAverage, self).set_model(model)
     live = model.weights
     self.ema_weights = [K.zeros(K.shape(var)) for var in live]
     self.old_weights = K.batch_get_value(live)
     # Seed the shadows with the current weights rather than leaving them at 0.
     K.batch_set_value(zip(self.ema_weights, self.old_weights))
     # NOTE(review): the update ops are only built here; presumably a
     # per-batch hook elsewhere evaluates them — confirm.
     self.updates = [
         K.moving_average_update(shadow, var, self.momentum)
         for shadow, var in zip(self.ema_weights, live)
     ]
예제 #2
0
        def apply_ema_weights(self, bias_correction=True):
            """Back up the model weights, then load the EMA weights into the model.

            The current values of ``self.model_weights`` are saved in
            ``self.old_weights`` before being overwritten.

            Args:
                bias_correction: if True, divide every EMA weight by
                    ``1 - ema_momentum ** iterations``. The correction is
                    skipped when that factor is exactly zero. This happens
                    before the first optimizer step (``iterations == 0``) or
                    when ``ema_momentum == 1``. Dividing would otherwise
                    write inf/NaN into the model.
            """
            self.old_weights = K.batch_get_value(self.model_weights)
            ema_weights = K.batch_get_value(self.ema_weights)

            if bias_correction:
                iterations = K.eval(self.iterations)
                scale = 1.0 - np.power(self.ema_momentum, iterations)
                # A zero factor makes the correction undefined; numpy would
                # silently produce inf/NaN arrays instead of raising.
                if scale != 0:
                    ema_weights = [weight / scale for weight in ema_weights]

            K.batch_set_value(zip(self.model_weights, ema_weights))
예제 #3
0
        def get_updates(self, loss, params):
            """Return EMA update ops chained after the base optimizer updates.

            The parent optimizer's updates are built first. Then:

            * one shadow variable is created per parameter and seeded with
              that parameter's current value;
            * for each pair, an ``ema_momentum``-weighted update op is
              returned.

            The EMA ops are created under ``tf.control_dependencies`` on the
            base updates. Only the EMA ops are returned; the base updates run
            through that dependency.
            """
            base_updates = super(NewOptimizer, self).get_updates(loss, params)
            self.model_weights = params
            self.ema_weights = [K.zeros(K.shape(p)) for p in params]
            self.old_weights = K.batch_get_value(params)
            K.batch_set_value(zip(self.ema_weights, self.old_weights))

            decay = self.ema_momentum
            with tf.control_dependencies(base_updates):
                # Ops must be created inside this context to get the dependency.
                ema_updates = [
                    K.update(shadow, decay * shadow + (1 - decay) * param)
                    for shadow, param in zip(self.ema_weights, params)
                ]

            return ema_updates
예제 #4
0
 def reset_old_weights(self):
     """Restore the model to the weight values saved in ``self.old_weights``.

     The values are written into ``self.model_weights`` when that attribute
     exists (optimizer-style EMA). Otherwise they go into
     ``self.model.weights`` (callback-style EMA, which only holds
     ``self.model``). Without this fallback, such objects raised
     ``AttributeError`` here.
     """
     target = getattr(self, 'model_weights', None)
     if target is None:
         target = self.model.weights
     K.batch_set_value(zip(target, self.old_weights))
 def apply_ema_weights(self):
     """Back up the live model weights, then load the EMA weights in their place.

     The current values of ``self.model.weights`` are stored in
     ``self.old_weights`` before being overwritten with the values of
     ``self.ema_weights``.
     """
     live_weights = self.model.weights
     self.old_weights = K.batch_get_value(live_weights)
     averaged = K.batch_get_value(self.ema_weights)
     K.batch_set_value(zip(live_weights, averaged))