def _resource_apply_sparse(self, grad, var, indices):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
    epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
    lr_t = lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)

    # Learning rate multipliers
    if self.lr_multipliers is not None:
        lr_t = _apply_lr_multiplier(self, lr_t, var)

    m_scaled_g_values = grad * (1 - beta_1_t)
    m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
        m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

    v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
    v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
        v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

    if self.amsgrad:
        vhat = self.get_slot(var, 'vhat')
        vhat_t = state_ops.assign(vhat,
                                  math_ops.maximum(vhat, v_t),
                                  use_locking=self._use_locking)
        var_delta = m_t / (math_ops.sqrt(vhat_t) + epsilon_t)
    else:
        var_delta = m_t / (math_ops.sqrt(v_t) + epsilon_t)
    var_t = math_ops.sub(var, self.eta_t * lr_t * var_delta)

    # Weight decays
    if var.name in self.weight_decays.keys():
        var_t = _apply_weight_decays(self, var, var_t)
    var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

    # Cosine annealing
    (iteration_done, t_cur_update, eta_t_update
     ) = _update_t_cur_eta_t_v2(self, lr_t, var)
    if iteration_done and not self._init_notified:
        self._init_notified = True

    updates = [var_update, m_t, v_t]
    if iteration_done:
        updates += [t_cur_update]
    if self.use_cosine_annealing and iteration_done:
        updates += [eta_t_update]
    if self.amsgrad:
        updates.append(vhat_t)
    return control_flow_ops.group(*updates)
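
# For reference, a minimal NumPy sketch (illustrative only; `_adamw_step_demo`
# and its arguments are hypothetical names, not part of this class) of the
# dense update the sparse ops above implement: bias-corrected Adam plus
# decoupled weight decay, with both the step and the decay scaled by eta_t
# as in _apply_weight_decays.
def _adamw_step_demo(var, grad, m, v, t, lr=1e-3, beta_1=0.9, beta_2=0.999,
                     epsilon=1e-7, wd=1e-4, eta_t=1.0):
    import numpy as np
    m = beta_1 * m + (1. - beta_1) * grad                       # first moment
    v = beta_2 * v + (1. - beta_2) * grad * grad                # second moment
    lr_t = lr * np.sqrt(1. - beta_2 ** t) / (1. - beta_1 ** t)  # bias correction
    var_new = var - eta_t * lr_t * m / (np.sqrt(v) + epsilon)   # Adam step
    var_new = var_new - eta_t * wd * var  # decay uses the pre-step weights
    return var_new, m, v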
def _resource_apply_sparse(self, grad, var, indices):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)

    # Learning rate multipliers
    if self.lr_multipliers is not None:
        lr_t = _apply_lr_multiplier(self, lr_t, var)

    if self._momentum:
        momentum = array_ops.identity(self._get_hyper('momentum',
                                                      var_dtype))
        m = self.get_slot(var, 'momentum')

        v = momentum * m - self.eta_t * lr_t * grad
        m = state_ops.assign(m, v, use_locking=self._use_locking)

        if self.nesterov:
            var_t = self._resource_scatter_add(
                var, indices, momentum * v - (self.eta_t * lr_t * grad))
        else:
            var_t = self._resource_scatter_add(var, indices, v)
    else:
        v = -self.eta_t * lr_t * grad
        var_t = var + v

    # Weight decays
    if var.name in self.weight_decays.keys():
        var_t = _apply_weight_decays(self, var, var_t)
    var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

    # Cosine annealing
    (iteration_done, t_cur_update, eta_t_update
     ) = _update_t_cur_eta_t_v2(self, lr_t, var)
    if iteration_done and not self._init_notified:
        self._init_notified = True

    updates = [var_update]
    if self._momentum:
        updates += [m]
    if iteration_done:
        updates += [t_cur_update]
    if self.use_cosine_annealing and iteration_done:
        updates += [eta_t_update]
    return control_flow_ops.group(*updates)
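
# Illustrative only: a plain-Python sketch of the (dense) momentum update
# mirrored by the sparse branch above; `_sgdw_step_demo` is a hypothetical
# helper, not part of this class. With Nesterov, the lookahead term
# momentum * v is applied on top of the refreshed velocity v.
def _sgdw_step_demo(var, grad, m, lr=1e-2, momentum=0.9, nesterov=True,
                    eta_t=1.0):
    v = momentum * m - eta_t * lr * grad  # new velocity (kept in the slot)
    if nesterov:
        var = var + momentum * v - eta_t * lr * grad
    else:
        var = var + v
    return var, v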
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_dtype = var.dtype.base_dtype
    lr_t = array_ops.identity(self._get_hyper('learning_rate', var_dtype))
    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    next_step = math_ops.cast(self.iterations + 2, var_dtype)
    decay_base = math_ops.cast(0.96, var_dtype)

    # Learning rate multipliers
    if self.lr_multipliers is not None:
        lr_t = _apply_lr_multiplier(self, lr_t, var)

    momentum_cache_t = beta_1_t * (1. - 0.5 * (
        math_ops.pow(decay_base, self._initial_decay * local_step)))
    momentum_cache_t_1 = beta_1_t * (1. - 0.5 * (
        math_ops.pow(decay_base, self._initial_decay * next_step)))
    m_schedule_new = math_ops.cast(self._m_cache_read,
                                   var_dtype) * momentum_cache_t
    if var_dtype is self._m_cache.dtype:
        m_schedule_new = array_ops.identity(state_ops.assign(
            self._m_cache, m_schedule_new, use_locking=self._use_locking))
    m_schedule_next = m_schedule_new * momentum_cache_t_1

    m_scaled_g_values = grad * (1. - beta_1_t)
    m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
        m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
        m_t_slice = array_ops.gather(m_t, indices)

    m_t_prime = m_t_slice / (1. - m_schedule_next)
    g_prime = grad / (1. - m_schedule_new)
    m_t_bar = (1. - momentum_cache_t) * g_prime + (
        momentum_cache_t_1 * m_t_prime)

    v_scaled_g_values = (grad * grad) * (1. - beta_2_t)
    v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
        v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
        v_t_slice = array_ops.gather(v_t, indices)

    v_t_prime_denominator = 1. - math_ops.pow(beta_2_t, local_step)
    v_t_prime = v_t_slice / v_t_prime_denominator
    v_prime_sqrt_plus_eps = math_ops.sqrt(v_t_prime) + epsilon_t

    var_t = self._resource_scatter_add(
        var, indices, -self.eta_t * lr_t * m_t_bar / v_prime_sqrt_plus_eps)

    # Weight decays
    if var.name in self.weight_decays.keys():
        var_t = _apply_weight_decays(self, var, var_t)
    var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

    # Cosine annealing
    (iteration_done, t_cur_update, eta_t_update
     ) = _update_t_cur_eta_t_v2(self, lr_t, var)
    if iteration_done and not self._init_notified:
        self._init_notified = True

    updates = [var_update, m_t_bar, v_t]
    if iteration_done:
        updates += [t_cur_update]
    if self.use_cosine_annealing and iteration_done:
        updates += [eta_t_update]
    return control_flow_ops.group(*updates)
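
# Illustrative only: the "warming" momentum schedule computed above, in plain
# Python; `_nadam_momentum_cache_demo` is a hypothetical helper, not part of
# this class. Each step contributes a factor mu_t to the running product
# kept in self._m_cache (m_schedule_new = m_cache * mu_t).
def _nadam_momentum_cache_demo(beta_1, step, initial_decay=1.0,
                               decay_base=0.96):
    mu_t = beta_1 * (1. - 0.5 * decay_base ** (initial_decay * step))
    mu_t_1 = beta_1 * (1. - 0.5 * decay_base ** (initial_decay * (step + 1)))
    return mu_t, mu_t_1  # momentum_cache_t, momentum_cache_t_1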
def _resource_apply_dense(self, grad, var):
    var_dtype = var.dtype.base_dtype
    lr_t = array_ops.identity(self._get_hyper('learning_rate', var_dtype))
    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    next_step = math_ops.cast(self.iterations + 2, var_dtype)
    decay_base = math_ops.cast(0.96, var_dtype)

    # Learning rate multipliers
    if self.lr_multipliers is not None:
        lr_t = _apply_lr_multiplier(self, lr_t, var)

    # Due to the recommendations in [2], i.e. warming momentum schedule
    momentum_cache_t = beta_1_t * (1. - 0.5 * (
        math_ops.pow(decay_base, self._initial_decay * local_step)))
    momentum_cache_t_1 = beta_1_t * (1. - 0.5 * (
        math_ops.pow(decay_base, self._initial_decay * next_step)))
    m_schedule_new = math_ops.cast(self._m_cache_read,
                                   var_dtype) * momentum_cache_t
    if var_dtype is self._m_cache.dtype:
        m_schedule_new = array_ops.identity(state_ops.assign(
            self._m_cache, m_schedule_new, use_locking=self._use_locking))
    m_schedule_next = m_schedule_new * momentum_cache_t_1

    # the following equations given in [1]
    g_prime = grad / (1. - m_schedule_new)
    m_t = beta_1_t * m + (1. - beta_1_t) * grad
    m_t_prime = m_t / (1. - m_schedule_next)
    v_t = beta_2_t * v + (1. - beta_2_t) * math_ops.square(grad)
    v_t_prime = v_t / (1. - math_ops.pow(beta_2_t, local_step))
    # Next-step momentum cache weights m_t_prime, matching the sparse path
    m_t_bar = (1. - momentum_cache_t) * g_prime + (
        momentum_cache_t_1 * m_t_prime)

    m_t = state_ops.assign(m, m_t, use_locking=self._use_locking)
    v_t = state_ops.assign(v, v_t, use_locking=self._use_locking)

    # epsilon belongs outside the sqrt, matching the sparse path
    var_t = math_ops.sub(var, self.eta_t * lr_t * m_t_bar / (
        math_ops.sqrt(v_t_prime) + epsilon_t))

    # Weight decays
    if var.name in self.weight_decays.keys():
        var_t = _apply_weight_decays(self, var, var_t)
    var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

    # Cosine annealing
    (iteration_done, t_cur_update, eta_t_update
     ) = _update_t_cur_eta_t_v2(self, lr_t, var)
    if iteration_done and not self._init_notified:
        self._init_notified = True

    updates = [var_update, m_t, v_t]
    if iteration_done:
        updates += [t_cur_update]
    if self.use_cosine_annealing and iteration_done:
        updates += [eta_t_update]
    return control_flow_ops.group(*updates)
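
# Illustrative only: the cosine-annealing schedule (Loshchilov & Hutter,
# SGDR) assumed to produce eta_t inside _update_t_cur_eta_t_v2; `_eta_t_demo`
# is a hypothetical helper, not part of this module. eta_t in
# [eta_min, eta_max] scales every weight update above, and t_cur counts
# iterations since the last restart.
def _eta_t_demo(t_cur, total_iterations, eta_min=0.0, eta_max=1.0):
    import math
    return eta_min + 0.5 * (eta_max - eta_min) * (
        1. + math.cos(math.pi * t_cur / total_iterations))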