def _resource_apply(self, grad, var, indices=None):
    """Build the op that applies one Adam-style update step to ``var``.

    The first and second moment slots (``'m'`` and ``'v'``) are updated as
    exponential moving averages of ``grad`` and ``grad**2``; the variable is
    then moved by ``lr * m / (sqrt(v) + epsilon)``. When
    ``self.bias_correction`` is truthy, both moments are divided by
    ``1 - beta**step`` before being used.

    Args:
        grad: Gradient for ``var``. On the sparse path it is forwarded,
            together with ``indices``, to ``self._resource_scatter_add``.
        var: Variable to update; its base dtype is used for the learning
            rate, the betas and epsilon.
        indices: ``None`` selects the dense update. Any other value selects
            the sparse path (decay the whole slot, then scatter-add the new
            contribution).

    Returns:
        The result of ``K.update(var, var_t)``.
    """
    # Prepare variables: everything is expressed in the variable's dtype.
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')
    beta_1_t = self._get_hyper('beta_1', var_dtype)
    beta_2_t = self._get_hyper('beta_2', var_dtype)
    epsilon_t = K.cast(self.epsilon, var_dtype)
    local_step = K.cast(self.iterations + 1, var_dtype)
    beta_1_t_power = K.pow(beta_1_t, local_step)
    beta_2_t_power = K.pow(beta_2_t, local_step)

    # Update formulas for the moment estimates.
    if indices is None:
        m_t = K.update(m, beta_1_t * m + (1 - beta_1_t) * grad)
        v_t = K.update(v, beta_2_t * v + (1 - beta_2_t) * grad**2)
    else:
        # Decay the full slots first; the scatter-adds must only run after
        # the decay has been applied, hence the control dependency.
        mv_ops = [K.update(m, beta_1_t * m), K.update(v, beta_2_t * v)]
        with tf.control_dependencies(mv_ops):
            m_t = self._resource_scatter_add(
                m, indices, (1 - beta_1_t) * grad
            )
            v_t = self._resource_scatter_add(
                v, indices, (1 - beta_2_t) * grad**2
            )

    # Return the op that updates the variable once both moments are written.
    with tf.control_dependencies([m_t, v_t]):
        if self.bias_correction:
            m_t = m_t / (1.0 - beta_1_t_power)
            v_t = v_t / (1.0 - beta_2_t_power)
        # Use the dtype-cast epsilon so the denominator stays in var_dtype.
        var_t = var - lr_t * m_t / (K.sqrt(v_t) + epsilon_t)
        return K.update(var, var_t)
def beta2(self):
    """Return the second-moment decay rate.

    If ``self._beta2`` is not ``None`` it is returned unchanged. Otherwise
    the rate is derived from the step counter as
    ``1 - (iterations + 1) ** -0.8``, computed in ``K.floatx()`` precision.
    """
    fixed_rate = self._beta2
    if fixed_rate is not None:
        return fixed_rate

    # No fixed rate configured: fall back to the step-dependent schedule.
    step = K.cast(self.iterations + 1, K.floatx())
    return 1.0 - K.pow(step, -0.8)