def get_updates(self, loss, params):
    """Build the list of backend update ops for one optimisation step.

    Creates per-parameter accumulators (`m_*`, `v_*`, `vhat_*`), stores
    them together with the iteration counter in `self.weights`, and
    collects the ops that advance the counter, the accumulators and the
    parameters in `self.updates`.

    The first moment is an exponential moving average of the gradient.
    The second moment is an exponential moving average of the squared
    difference between the gradient and the *updated* first moment.
    NOTE(review): that differs from plain Adam's squared gradient and
    looks like an AdaBelief-style variance term -- confirm it is intended.

    Args:
        loss: Loss passed through to `self.get_gradients`.
        params: Sequence of parameters to optimise.

    Returns:
        `self.updates`, the list of update ops.
    """
    def _zero_slots(prefix):
        # One zero accumulator per parameter, matching its shape and dtype.
        return [
            K.zeros(K.int_shape(param), dtype=K.dtype(param),
                    name=prefix + str(idx))
            for idx, param in enumerate(params)
        ]

    gradients = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]

    # Optional inverse-time decay of the base learning rate.
    rate = self.learning_rate
    if self.initial_decay > 0:
        elapsed = K.cast(self.iterations, K.dtype(self.decay))
        rate = rate * (1. / (1. + self.decay * elapsed))

    # Bias-corrected step size for the current (1-based) time step.
    t = K.cast(self.iterations, K.floatx()) + 1
    correction_num = K.sqrt(1. - K.pow(self.beta_2, t))
    correction_den = 1. - K.pow(self.beta_1, t)
    step_size = rate * (correction_num / correction_den)

    use_amsgrad = self.amsgrad
    first_moments = _zero_slots('m_')
    second_moments = _zero_slots('v_')
    if use_amsgrad:
        max_moments = _zero_slots('vhat_')
    else:
        # Placeholders so the weight list keeps the same layout even
        # when the running maximum is not tracked.
        max_moments = [
            K.zeros(1, name='vhat_' + str(idx))
            for idx, _ in enumerate(params)
        ]
    self.weights = (
        [self.iterations] + first_moments + second_moments + max_moments
    )

    slots = zip(params, gradients, first_moments, second_moments, max_moments)
    for param, grad, m, v, vhat in slots:
        m_t = (self.beta_1 * m) + (1. - self.beta_1) * grad
        v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(grad - m_t)

        # With AMSGrad the denominator uses the running maximum of v_t.
        scale = K.maximum(vhat, v_t) if use_amsgrad else v_t
        candidate = param - step_size * m_t / (K.sqrt(scale) + self.epsilon)
        if use_amsgrad:
            self.updates.append(K.update(vhat, scale))

        self.updates.extend([K.update(m, m_t), K.update(v, v_t)])

        # Apply the parameter's constraint, if it has one.
        constraint = getattr(param, 'constraint', None)
        if constraint is not None:
            candidate = constraint(candidate)

        self.updates.append(K.update(param, candidate))

    return self.updates
def get_updates(self, loss, params):
    """Build update ops that apply the parent optimiser only periodically.

    A float gate `cond` is 1.0 when `self.iterations` is a multiple of
    `self.grad_accum_steps` and 0.0 otherwise. While the parent's
    `get_updates` runs, `K.update` is temporarily replaced by a gated
    version that writes the new value only when `cond` is 1.0 and keeps
    the old value otherwise. The gradient accumulators are then updated
    under a control dependency on the parent's updates.

    Args:
        loss: Loss passed through to `get_gradients` / parent `get_updates`.
        params: Sequence of parameters to optimise.

    Returns:
        The list of accumulator update ops, which depend on the parent's
        update ops through `K.control_dependencies`.
    """
    # Update gate: 1.0 on steps that apply the update, 0.0 otherwise.
    cond = K.equal(self.iterations % self.grad_accum_steps, 0)
    cond = K.cast(cond, K.floatx())

    # Fetch the gradients.
    # NOTE(review): if this class overrides `get_gradients` to return the
    # accumulated gradients, this call presumably should go to the parent
    # implementation instead -- confirm against the class definition.
    grads = self.get_gradients(loss, params)
    self.accum_grads = [
        K.zeros(
            shape=K.int_shape(p),
            dtype=K.dtype(p),
            name='accum_grad_{}'.format(i),
        )
        for i, p in enumerate(params)
    ]

    old_update = K.update

    def new_update(x, new_x):
        # Keep the old value unless the gate is open.
        new_x = cond * new_x + (1 - cond) * x
        return old_update(x, new_x)

    # `K.update` is patched on the shared backend module, so it must be
    # restored even if the parent raises; otherwise every later caller of
    # `K.update` would silently get the gated version.
    K.update = new_update
    try:
        updates = super(NewOptimizer, self).get_updates(loss, params)
    finally:
        K.update = old_update

    # Accumulate gradients: when the gate is open the accumulator restarts
    # from the current gradient, otherwise the gradient is added to it.
    with K.control_dependencies(updates):
        acc_updates = [
            K.update(ag, g + (1 - cond) * ag)
            for ag, g in zip(self.accum_grads, grads)
        ]

    return acc_updates