def _resource_apply_dense(self, grad, var):
    """Dense Adam update with gradient accumulation.

    Micro-batch gradients are folded into a running mean held in the `g`
    slot; the Adam moment and parameter updates are only applied on the
    last micro-batch of each `accumulation_steps`-long window (elsewhere
    the step size is forced to 0 and `m`/`v` are written back unchanged).

    Args:
        grad: dense gradient tensor for `var`.
        var: the variable being optimized.

    Returns:
        A grouped op covering the variable and slot updates.
    """
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    beta_1_t = self._get_hyper('beta_1', var_dtype)
    beta_2_t = self._get_hyper('beta_2', var_dtype)
    accumulation_steps = self._get_hyper('accumulation_steps', 'int64')
    # True only on the final micro-batch of the current accumulation window.
    update_cond = tf.equal((self.iterations + 1) % accumulation_steps, 0)
    # 1-based position inside the window; drives the running mean of `g`.
    sub_step = self.iterations % accumulation_steps + 1
    # 1-based count of *applied* (effective) steps, used for bias correction.
    local_step = math_ops.cast(self.iterations // accumulation_steps + 1, var_dtype)
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
    epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
    # Fold both Adam bias corrections into the step size.
    lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))
    # Zero the step size on non-update micro-batches so assign_sub is a no-op.
    # NOTE(review): the 0.0 literal is float32 — confirm tf.where handles
    # float16/bfloat16 variables here.
    lr = tf.where(update_cond, lr, 0.0)
    g = self.get_slot(var, 'g')
    g_a = grad / math_ops.cast(accumulation_steps, var_dtype)
    # Incremental running mean of g_a: reset on the first micro-batch,
    # otherwise g_t = g + (g_a - g) / sub_step.
    # NOTE(review): grad is divided by accumulation_steps *and* then
    # averaged over the window — verify this double scaling is intended.
    g_t = tf.where(tf.equal(sub_step, 1),
                   g_a,
                   g + (g_a - g) / math_ops.cast(sub_step, var_dtype))
    g_t = state_ops.assign(g, g_t, use_locking=self._use_locking)
    m = self.get_slot(var, 'm')
    # First moment advances only on update steps; otherwise kept as-is.
    m_t = tf.where(update_cond, m * beta_1_t + g_t * (1 - beta_1_t), m)
    m_t = state_ops.assign(m, m_t, use_locking=self._use_locking)
    v = self.get_slot(var, 'v')
    # Second moment gets the same conditional treatment as m.
    v_t = tf.where(update_cond, v * beta_2_t + (g_t * g_t) * (1 - beta_2_t), v)
    v_t = state_ops.assign(v, v_t, use_locking=self._use_locking)
    if not self.amsgrad:
        v_sqrt = math_ops.sqrt(v_t)
        var_update = state_ops.assign_sub(var,
                                          lr * m_t / (v_sqrt + epsilon_t),
                                          use_locking=self._use_locking)
        return control_flow_ops.group(*[var_update, m_t, v_t])
    else:
        v_hat = self.get_slot(var, 'vhat')
        # AMSGrad: running maximum of v, again advanced only on update steps.
        v_hat_t = tf.where(update_cond, math_ops.maximum(v_hat, v_t), v_hat)
        with ops.control_dependencies([v_hat_t]):
            # Read the max before persisting it back into the slot.
            v_hat_t = state_ops.assign(v_hat, v_hat_t, use_locking=self._use_locking)
        v_hat_sqrt = math_ops.sqrt(v_hat_t)
        var_update = state_ops.assign_sub(var,
                                          lr * m_t / (v_hat_sqrt + epsilon_t),
                                          use_locking=self._use_locking)
        return control_flow_ops.group(*[var_update, m_t, v_t, v_hat_t])
def _resource_apply_dense(self, grad, var):
    """Dense RAdam step: rectified Adam with optional linear warmup/decay,
    AMSGrad, and decoupled weight decay.

    Args:
        grad: dense gradient tensor for `var`.
        var: the variable being optimized.

    Returns:
        A grouped op covering the variable and slot updates.
    """
    dtype = var.dtype.base_dtype
    lr = self._decayed_lr(dtype)
    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')
    b1 = self._get_hyper('beta_1', dtype)
    b2 = self._get_hyper('beta_2', dtype)
    eps = ops.convert_to_tensor(self.epsilon, dtype)
    step = math_ops.cast(self.iterations + 1, dtype)
    b1_pow = math_ops.pow(b1, step)
    b2_pow = math_ops.pow(b2, step)

    # Optional schedule: linear warmup to the base rate, then linear decay
    # towards min_lr over the remaining total_steps.
    if self._initial_total_steps > 0:
        total_steps = self._get_hyper('total_steps', dtype)
        warmup_steps = total_steps * self._get_hyper('warmup_proportion', dtype)
        min_lr = self._get_hyper('min_lr', dtype)
        warm = lr * (step / warmup_steps)
        decayed = min_lr + (lr - min_lr) * (
            1.0 - tf.minimum(step, total_steps) / total_steps)
        lr = tf.where(step <= warmup_steps, warm, decayed)

    # SMA length at t -> infinity and its current finite-sample value.
    rho_inf = 2.0 / (1.0 - b2) - 1.0
    rho_t = rho_inf - 2.0 * step * b2_pow / (1.0 - b2_pow)

    m_new = state_ops.assign(m, b1 * m + (1.0 - b1) * grad,
                             use_locking=self._use_locking)
    m_hat = m_new / (1.0 - b1_pow)
    v_new = state_ops.assign(v, b2 * v + (1.0 - b2) * math_ops.square(grad),
                             use_locking=self._use_locking)

    if self.amsgrad:
        vhat = self.get_slot(var, 'vhat')
        vhat_new = state_ops.assign(vhat, math_ops.maximum(vhat, v_new),
                                    use_locking=self._use_locking)
        denom = math_ops.sqrt(vhat_new / (1.0 - b2_pow) + eps)
    else:
        denom = math_ops.sqrt(v_new / (1.0 - b2_pow) + eps)

    # Variance rectification term; fall back to bias-corrected momentum
    # while the SMA is short (rho_t < 5).
    rect = math_ops.sqrt((rho_t - 4.0) / (rho_inf - 4.0) *
                         (rho_t - 2.0) / (rho_inf - 2.0) *
                         rho_inf / rho_t)
    delta = tf.where(rho_t >= 5.0, rect * m_hat / denom, m_hat)

    # Decoupled (AdamW-style) weight decay.
    if self._initial_weight_decay > 0.0:
        delta += self._get_hyper('weight_decay', dtype) * var

    new_var = state_ops.assign_sub(var, lr * delta,
                                   use_locking=self._use_locking)
    deps = [new_var, m_new, v_new]
    if self.amsgrad:
        deps.append(vhat_new)
    return control_flow_ops.group(*deps)
def _resource_apply_sparse(self, grad, var, indices):
    """Sparse Adam update with linear warmup/decay, optional AMSGrad, and
    decoupled weight decay.

    Fix: in the AMSGrad branch the running maximum of `v` was computed but
    never written back to the `vhat` slot, so the maximum was forgotten on
    every step. It is now persisted with `state_ops.assign` under a control
    dependency, matching the dense implementation in this file.

    Args:
        grad: gradient values for the rows selected by `indices`.
        var: the variable being optimized.
        indices: row indices into `var` that `grad` applies to.

    Returns:
        A grouped op covering the variable and slot updates.
    """
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    beta_1_t = self._get_hyper('beta_1', var_dtype)
    beta_2_t = self._get_hyper('beta_2', var_dtype)
    epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
    decay_steps = self._get_hyper('decay_steps', var_dtype)
    warmup_steps = self._get_hyper('warmup_steps', var_dtype)
    min_lr = self._get_hyper('min_lr', var_dtype)
    # Linear warmup to the base rate, then linear decay down to min_lr.
    lr_t = tf.where(
        local_step <= warmup_steps,
        lr_t * (local_step / warmup_steps),
        min_lr + (lr_t - min_lr) *
        (1.0 - tf.minimum(local_step, decay_steps) / decay_steps),
    )
    # Fold both Adam bias corrections into the learning rate.
    lr_t = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))
    m = self.get_slot(var, 'm')
    m_scaled_g_values = grad * (1 - beta_1_t)
    # Decay the full first-moment tensor, then scatter-add the sparse part.
    m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
        m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
    v = self.get_slot(var, 'v')
    v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
    v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
        v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
    if self.amsgrad:
        v_hat = self.get_slot(var, 'vhat')
        v_hat_t = math_ops.maximum(v_hat, v_t)
        # Persist the running maximum (previously dropped — bug fix).
        with ops.control_dependencies([v_hat_t]):
            v_hat_t = state_ops.assign(v_hat, v_hat_t,
                                       use_locking=self._use_locking)
        var_update = m_t / (math_ops.sqrt(v_hat_t) + epsilon_t)
    else:
        var_update = m_t / (math_ops.sqrt(v_t) + epsilon_t)
    # Decoupled (AdamW-style) weight decay.
    if self._initial_weight_decay > 0.0:
        weight_decay = self._get_hyper('weight_decay', var_dtype)
        var_update += weight_decay * var
    var_update = state_ops.assign_sub(var, lr_t * var_update,
                                      use_locking=self._use_locking)
    updates = [var_update, m_t, v_t]
    if self.amsgrad:
        updates.append(v_hat_t)
    return control_flow_ops.group(*updates)
def _apply_dense_shared(self, grad, var):
    """Shared dense RAdam update for the V1-style optimizer (optional
    AMSGrad), reading state from the beta/niter accumulators.

    Args:
        grad: dense gradient tensor for `var`.
        var: the variable being optimized.

    Returns:
        A grouped op covering the variable and slot updates.
    """
    dtype = var.dtype.base_dtype
    b1_pow, b2_pow = self._get_beta_accumulators()
    b1_pow = math_ops.cast(b1_pow, dtype)
    b2_pow = math_ops.cast(b2_pow, dtype)
    t = math_ops.cast(self._get_niter(), dtype)
    lr = math_ops.cast(self._lr_t, dtype)
    b1 = math_ops.cast(self._beta1_t, dtype)
    b2 = math_ops.cast(self._beta2_t, dtype)
    eps = math_ops.cast(self._epsilon_t, dtype)

    # SMA length at t -> infinity and its current finite-sample value.
    rho_inf = 2.0 / (1.0 - b2) - 1.0
    rho_t = rho_inf - 2.0 * t * b2_pow / (1.0 - b2_pow)

    m = self.get_slot(var, 'm')
    m_new = state_ops.assign(m, b1 * m + (1.0 - b1) * grad,
                             use_locking=self._use_locking)
    m_hat = m_new / (1.0 - b1_pow)

    v = self.get_slot(var, 'v')
    v_new = state_ops.assign(v, b2 * v + (1.0 - b2) * math_ops.square(grad),
                             use_locking=self._use_locking)

    if self._amsgrad:
        vhat = self.get_slot(var, 'vhat')
        vhat_new = state_ops.assign(vhat, math_ops.maximum(vhat, v_new),
                                    use_locking=self._use_locking)
        denom = math_ops.sqrt(vhat_new / (1.0 - b2_pow) + eps)
    else:
        denom = math_ops.sqrt(v_new / (1.0 - b2_pow) + eps)

    # Variance rectification; note the strictly-greater threshold (rho_t > 5)
    # used here, preserved as-is.
    rect = math_ops.sqrt((rho_t - 4.0) / (rho_inf - 4.0) *
                         (rho_t - 2.0) / (rho_inf - 2.0) *
                         rho_inf / rho_t)
    delta = tf.where(rho_t > 5.0, rect * m_hat / denom, m_hat)

    new_var = state_ops.assign_sub(var, lr * delta,
                                   use_locking=self._use_locking)
    deps = [new_var, m_new, v_new]
    if self._amsgrad:
        deps.append(vhat_new)
    return control_flow_ops.group(*deps)
def _resource_apply_sparse(self, grad, var, indices):
    """Sparse RAdam update with optional warmup/decay schedule, AMSGrad,
    and decoupled weight decay.

    Args:
        grad: gradient values for the rows selected by `indices`.
        var: the variable being optimized.
        indices: row indices into `var` that `grad` applies to.

    Returns:
        A grouped op covering the variable and slot updates.
    """
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    beta_1_t = self._get_hyper('beta_1', var_dtype)
    beta_2_t = self._get_hyper('beta_2', var_dtype)
    epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
    # Optional schedule: linear warmup, then linear decay towards min_lr
    # over the steps remaining after warmup.
    if self._initial_total_steps > 0:
        total_steps = self._get_hyper('total_steps', var_dtype)
        warmup_steps = total_steps * self._get_hyper('warmup_proportion', var_dtype)
        min_lr = self._get_hyper('min_lr', var_dtype)
        decay_steps = K.maximum(total_steps - warmup_steps, 1)
        # Negative slope: lr decreases from lr_t towards min_lr.
        decay_rate = (min_lr - lr_t) / decay_steps
        lr_t = tf.where(
            local_step <= warmup_steps,
            lr_t * (local_step / warmup_steps),
            lr_t + decay_rate * K.minimum(local_step - warmup_steps, decay_steps),
        )
    # SMA length at t -> infinity and its current finite-sample value.
    sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0
    sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)
    m = self.get_slot(var, 'm')
    m_scaled_g_values = grad * (1 - beta_1_t)
    # Decay the full first-moment tensor, then scatter-add the sparse part.
    m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
        m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
    m_corr_t = m_t / (1.0 - beta_1_power)
    v = self.get_slot(var, 'v')
    v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
    v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
        v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
    if self.amsgrad:
        vhat = self.get_slot(var, 'vhat')
        vhat_t = state_ops.assign(vhat, math_ops.maximum(vhat, v_t), use_locking=self._use_locking)
        v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta_2_power))
    else:
        v_corr_t = math_ops.sqrt(v_t / (1.0 - beta_2_power))
    # Variance rectification term; plain bias-corrected momentum is used
    # while the SMA is short (sma_t < 5).
    r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) *
                        (sma_t - 2.0) / (sma_inf - 2.0) *
                        sma_inf / sma_t)
    # NOTE(review): epsilon is added *outside* the sqrt here, whereas the
    # dense RAdam in this file adds it inside — confirm the asymmetry is
    # intended (or that these belong to different optimizer classes).
    var_t = tf.where(sma_t >= 5.0, r_t * m_corr_t / (v_corr_t + epsilon_t), m_corr_t)
    # Decoupled (AdamW-style) weight decay.
    if self._initial_weight_decay > 0.0:
        var_t += self._get_hyper('weight_decay', var_dtype) * var
    var_update = state_ops.assign_sub(var, lr_t * var_t, use_locking=self._use_locking)
    updates = [var_update, m_t, v_t]
    if self.amsgrad:
        updates.append(vhat_t)
    return control_flow_ops.group(*updates)