# Imports assumed by the snippets in this file (TF 1.x-style internal
# modules plus the public tf / Keras backend namespaces).
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops


def _resource_apply_dense(self, grad, var):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')
    beta_1_t = self._get_hyper('beta_1', var_dtype)
    beta_2_t = self._get_hyper('beta_2', var_dtype)
    epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
    decay_steps = self._get_hyper('decay_steps', var_dtype)
    warmup_steps = self._get_hyper('warmup_steps', var_dtype)
    min_lr = self._get_hyper('min_lr', var_dtype)
    # Linear warmup, then linear decay towards min_lr.
    lr_t = tf.where(
        local_step <= warmup_steps,
        lr_t * (local_step / warmup_steps),
        min_lr + (lr_t - min_lr) *
        (1.0 - tf.minimum(local_step, decay_steps) / decay_steps),
    )
    # Fold the Adam bias correction into the learning rate.
    lr_t = lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)
    m_t = state_ops.assign(m,
                           beta_1_t * m + (1.0 - beta_1_t) * grad,
                           use_locking=self._use_locking)
    v_t = state_ops.assign(v,
                           beta_2_t * v + (1.0 - beta_2_t) * math_ops.square(grad),
                           use_locking=self._use_locking)
    if self.amsgrad:
        # Persist the running maximum of v into its slot (AMSGrad).
        v_hat = self.get_slot(var, 'vhat')
        v_hat_t = state_ops.assign(v_hat,
                                   math_ops.maximum(v_hat, v_t),
                                   use_locking=self._use_locking)
        var_update = m_t / (math_ops.sqrt(v_hat_t) + epsilon_t)
    else:
        var_update = m_t / (math_ops.sqrt(v_t) + epsilon_t)
    if self._initial_weight_decay > 0.0:
        # Decoupled (AdamW-style) weight decay.
        weight_decay = self._get_hyper('weight_decay', var_dtype)
        var_update += weight_decay * var
    var_update = state_ops.assign_sub(var,
                                      lr_t * var_update,
                                      use_locking=self._use_locking)
    updates = [var_update, m_t, v_t]
    if self.amsgrad:
        updates.append(v_hat_t)
    return control_flow_ops.group(*updates)
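
# A minimal standalone sketch of the warmup/decay schedule used above,
# written as plain Python so it can be sanity-checked in isolation. The
# argument names mirror the hypers above; the function itself is an
# illustration, not part of the optimizer.
def _sketch_warmup_decay_lr(step, base_lr, min_lr, warmup_steps, decay_steps):
    """Linear warmup to base_lr, then linear decay towards min_lr."""
    if step <= warmup_steps:
        return base_lr * step / warmup_steps
    progress = min(step, decay_steps) / decay_steps
    return min_lr + (base_lr - min_lr) * (1.0 - progress)

# For example, base_lr=1e-3, min_lr=1e-5, warmup_steps=100, decay_steps=1000
# gives 5e-4 at step 50 (mid-warmup) and min_lr from step 1000 onwards.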
def _apply_dense_shared(self, grad, var):
    var_dtype = var.dtype.base_dtype
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var_dtype)
    beta2_power = math_ops.cast(beta2_power, var_dtype)
    niter = self._get_niter()
    niter = math_ops.cast(niter, var_dtype)
    lr_t = math_ops.cast(self._lr_t, var_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var_dtype)
    # Maximum and current length of the approximated SMA (RAdam).
    sma_inf = 2.0 / (1.0 - beta2_t) - 1.0
    sma_t = sma_inf - 2.0 * niter * beta2_power / (1.0 - beta2_power)
    m = self.get_slot(var, 'm')
    m_t = state_ops.assign(m,
                           beta1_t * m + (1.0 - beta1_t) * grad,
                           use_locking=self._use_locking)
    m_corr_t = m_t / (1.0 - beta1_power)
    v = self.get_slot(var, 'v')
    v_t = state_ops.assign(v,
                           beta2_t * v + (1.0 - beta2_t) * math_ops.square(grad),
                           use_locking=self._use_locking)
    if self._amsgrad:
        vhat = self.get_slot(var, 'vhat')
        vhat_t = state_ops.assign(vhat,
                                  math_ops.maximum(vhat, v_t),
                                  use_locking=self._use_locking)
        v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta2_power) + epsilon_t)
    else:
        v_corr_t = math_ops.sqrt(v_t / (1.0 - beta2_power) + epsilon_t)
    # Variance rectification term.
    r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) *
                        (sma_t - 2.0) / (sma_inf - 2.0) *
                        sma_inf / sma_t)
    # Rectified adaptive update once the variance is tractable; plain
    # bias-corrected momentum otherwise.
    var_t = tf.where(sma_t > 5.0, r_t * m_corr_t / v_corr_t, m_corr_t)
    var_update = state_ops.assign_sub(var,
                                      lr_t * var_t,
                                      use_locking=self._use_locking)
    updates = [var_update, m_t, v_t]
    if self._amsgrad:
        updates.append(vhat_t)
    return control_flow_ops.group(*updates)
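
# A quick numeric sketch (not part of the optimizer) of the rectification
# logic above: sma_t grows with the step count, and only once it exceeds
# 5.0 does the rectified adaptive branch fire. With beta2=0.999, sma_t
# stays at or below 5 for roughly the first five steps, so the rectified
# branch first fires around step six.
def _sketch_radam_rectifier(step, beta2=0.999):
    """Return (sma_t, r_t or None) for a given step, mirroring the math above."""
    sma_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_power = beta2 ** step
    sma_t = sma_inf - 2.0 * step * beta2_power / (1.0 - beta2_power)
    if sma_t <= 5.0:
        return sma_t, None  # falls back to bias-corrected momentum
    r_t = ((sma_t - 4.0) / (sma_inf - 4.0) *
           (sma_t - 2.0) / (sma_inf - 2.0) *
           sma_inf / sma_t) ** 0.5
    return sma_t, r_t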
def _finish(self, update_ops, name_scope):
    # Update the power accumulators.
    with ops.control_dependencies(update_ops):
        beta1_power, beta2_power = self._get_beta_accumulators()
        niter = self._get_niter()
        with ops.colocate_with(beta1_power):
            update_beta1 = beta1_power.assign(
                beta1_power * self._beta1_t, use_locking=self._use_locking)
            update_beta2 = beta2_power.assign(
                beta2_power * self._beta2_t, use_locking=self._use_locking)
            update_niter = niter.assign(niter + 1,
                                        use_locking=self._use_locking)
    return control_flow_ops.group(
        *update_ops + [update_beta1, update_beta2, update_niter],
        name=name_scope)
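
# _finish runs once per optimizer step, after all per-variable updates, so
# the accumulators track beta1 ** t and beta2 ** t across steps; those are
# exactly the powers the bias corrections in the apply methods divide by.
# A plain-Python sketch of that bookkeeping (initialization details vary
# between TF versions; illustrative only):
beta1, beta2 = 0.9, 0.999
beta1_power, beta2_power, niter = 1.0, 1.0, 0
for _ in range(3):  # three optimizer steps
    beta1_power *= beta1
    beta2_power *= beta2
    niter += 1
assert abs(beta1_power - beta1 ** niter) < 1e-12
assert abs(beta2_power - beta2 ** niter) < 1e-12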
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    learning_rate_t = math_ops.cast(self.learning_rate_t, var.dtype.base_dtype)
    beta_1_t = math_ops.cast(self.beta_1_t, var.dtype.base_dtype)
    beta_2_t = math_ops.cast(self.beta_2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self.epsilon_t, var.dtype.base_dtype)
    weight_decay_rate_t = math_ops.cast(self.weight_decay_rate_t,
                                        var.dtype.base_dtype)
    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    # Fold the Adam bias correction into the learning rate.
    learning_rate_t = (learning_rate_t * math_ops.sqrt(1 - beta2_power) /
                       (1 - beta1_power))
    # m_t = beta_1 * m + (1 - beta_1) * g, applied sparsely: decay the whole
    # slot, then scatter-add the gradient contribution at `indices`.
    m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
    m_scaled_g_values = grad * (1 - beta_1_t)
    with ops.control_dependencies([m_t]):
        m_t = scatter_add(m, indices, m_scaled_g_values)
    # v_t = beta_2 * v + (1 - beta_2) * g^2, same pattern.
    v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
    v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
        v_t = scatter_add(v, indices, v_scaled_g_values)
    update = m_t / (math_ops.sqrt(v_t) + epsilon_t)
    # Decoupled (AdamW-style) weight decay: added to the update, not to the
    # gradient, so it does not flow through m and v.
    if self._do_use_weight_decay(var.name):
        update += weight_decay_rate_t * var
    update_with_lr = learning_rate_t * update
    var_update = state_ops.assign_sub(var,
                                      update_with_lr,
                                      use_locking=self._use_locking)
    return control_flow_ops.group(*[var_update, m_t, v_t])
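
# The sparse moment updates above use a "decay everything, then scatter-add
# the new gradient contribution" pattern. A NumPy sketch of the same idea
# (np.add.at plays the role of scatter_add; illustrative only):
import numpy as np

def _sketch_sparse_moment_update(m, grad_values, indices, beta1=0.9):
    m *= beta1                                           # decay all rows
    np.add.at(m, indices, (1.0 - beta1) * grad_values)   # touch only `indices`
    return m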
def _resource_apply_dense(self, grad, var):
    learning_rate_t = math_ops.cast(self.learning_rate_t, var.dtype.base_dtype)
    beta_1_t = math_ops.cast(self.beta_1_t, var.dtype.base_dtype)
    beta_2_t = math_ops.cast(self.beta_2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self.epsilon_t, var.dtype.base_dtype)
    weight_decay_rate_t = math_ops.cast(self.weight_decay_rate_t,
                                        var.dtype.base_dtype)
    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    # Fold the Adam bias correction into the learning rate.
    learning_rate_t = (learning_rate_t * math_ops.sqrt(1 - beta2_power) /
                       (1 - beta1_power))
    # Standard Adam update.
    next_m = tf.multiply(beta_1_t, m) + tf.multiply(1.0 - beta_1_t, grad)
    next_v = tf.multiply(beta_2_t, v) + tf.multiply(1.0 - beta_2_t,
                                                    tf.square(grad))
    update = next_m / (tf.sqrt(next_v) + epsilon_t)
    # Decoupled (AdamW-style) weight decay.
    if self._do_use_weight_decay(var.name):
        update += weight_decay_rate_t * var
    update_with_lr = learning_rate_t * update
    next_param = var - update_with_lr
    return control_flow_ops.group(
        *[var.assign(next_param), m.assign(next_m), v.assign(next_v)])
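
# A self-contained NumPy sketch of the dense step above, showing where
# decoupled weight decay enters: it is added to the Adam update rather than
# to the gradient, so it never contaminates the moment estimates m and v.
# Illustrative only; default values mirror common AdamW settings.
import numpy as np

def _sketch_adamw_dense_step(param, grad, m, v, t, lr=1e-3, beta1=0.9,
                             beta2=0.999, eps=1e-6, weight_decay=0.01):
    lr_t = lr * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)  # bias correction
    m[:] = beta1 * m + (1 - beta1) * grad
    v[:] = beta2 * v + (1 - beta2) * grad ** 2
    update = m / (np.sqrt(v) + eps) + weight_decay * param  # decoupled decay
    param -= lr_t * update
    return param, m, v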
def _resource_apply_sparse(self, grad, var, indices):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    beta_1_t = self._get_hyper('beta_1', var_dtype)
    beta_2_t = self._get_hyper('beta_2', var_dtype)
    epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
    if self._initial_total_steps > 0:
        total_steps = self._get_hyper('total_steps', var_dtype)
        warmup_steps = total_steps * self._get_hyper('warmup_proportion',
                                                     var_dtype)
        min_lr = self._get_hyper('min_lr', var_dtype)
        decay_steps = K.maximum(total_steps - warmup_steps, 1)
        decay_rate = (min_lr - lr_t) / decay_steps
        # Linear warmup, then linear decay towards min_lr.
        lr_t = tf.where(
            local_step <= warmup_steps,
            lr_t * (local_step / warmup_steps),
            lr_t + decay_rate * K.minimum(local_step - warmup_steps,
                                          decay_steps),
        )
    # Maximum and current length of the approximated SMA (RAdam).
    sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0
    sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)
    m = self.get_slot(var, 'm')
    m_scaled_g_values = grad * (1 - beta_1_t)
    m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
        m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
    m_corr_t = m_t / (1.0 - beta_1_power)
    v = self.get_slot(var, 'v')
    v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
    v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
        v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
    if self.amsgrad:
        vhat = self.get_slot(var, 'vhat')
        vhat_t = state_ops.assign(vhat,
                                  math_ops.maximum(vhat, v_t),
                                  use_locking=self._use_locking)
        v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta_2_power))
    else:
        vhat_t = None
        v_corr_t = math_ops.sqrt(v_t / (1.0 - beta_2_power))
    # Variance rectification term.
    r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) *
                        (sma_t - 2.0) / (sma_inf - 2.0) *
                        sma_inf / sma_t)
    var_t = tf.where(sma_t >= 5.0,
                     r_t * m_corr_t / (v_corr_t + epsilon_t),
                     m_corr_t)
    if self._initial_weight_decay > 0.0:
        var_t += self._get_hyper('weight_decay', var_dtype) * var
    # var_t is dense; write back only the rows touched by this gradient.
    var_update = self._resource_scatter_add(
        var, indices, tf.gather(-lr_t * var_t, indices))
    updates = [var_update, m_t, v_t]
    if self.amsgrad:
        updates.append(vhat_t)
    return control_flow_ops.group(*updates)
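
# In the variable update above, var_t is computed densely but only the rows
# named by `indices` are written back, via a scatter-add of
# tf.gather(-lr_t * var_t, indices). A NumPy sketch of that pattern
# (np.add.at again standing in for the scatter-add; illustrative only):
import numpy as np

def _sketch_scatter_var_update(var, dense_update, indices, lr):
    np.add.at(var, indices, np.take(-lr * dense_update, indices, axis=0))
    return var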
def _resource_apply_dense(self, grad, var):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')
    beta_1_t = self._get_hyper('beta_1', var_dtype)
    beta_2_t = self._get_hyper('beta_2', var_dtype)
    epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
    if self._initial_total_steps > 0:
        total_steps = self._get_hyper('total_steps', var_dtype)
        warmup_steps = total_steps * self._get_hyper('warmup_proportion',
                                                     var_dtype)
        # Assumes warmup_proportion < 1, so decay_steps stays positive.
        decay_steps = total_steps - warmup_steps
        # Linear warmup, then linear decay to zero, clamped at decay_steps.
        lr_t = tf.where(
            local_step <= warmup_steps,
            lr_t * (local_step / warmup_steps),
            lr_t * (1.0 - tf.minimum(local_step, decay_steps) / decay_steps),
        )
    # Maximum and current length of the approximated SMA (RAdam).
    sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0
    sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)
    m_t = state_ops.assign(m,
                           beta_1_t * m + (1.0 - beta_1_t) * grad,
                           use_locking=self._use_locking)
    m_corr_t = m_t / (1.0 - beta_1_power)
    v_t = state_ops.assign(v,
                           beta_2_t * v + (1.0 - beta_2_t) * math_ops.square(grad),
                           use_locking=self._use_locking)
    if self.amsgrad:
        vhat = self.get_slot(var, 'vhat')
        vhat_t = state_ops.assign(vhat,
                                  math_ops.maximum(vhat, v_t),
                                  use_locking=self._use_locking)
        v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta_2_power) + epsilon_t)
    else:
        v_corr_t = math_ops.sqrt(v_t / (1.0 - beta_2_power) + epsilon_t)
    # Variance rectification term.
    r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) *
                        (sma_t - 2.0) / (sma_inf - 2.0) *
                        sma_inf / sma_t)
    var_t = tf.where(sma_t > 5.0, r_t * m_corr_t / v_corr_t, m_corr_t)
    if self._initial_weight_decay > 0.0:
        var_t += self._get_hyper('weight_decay', var_dtype) * var
    var_update = state_ops.assign_sub(var,
                                      lr_t * var_t,
                                      use_locking=self._use_locking)
    updates = [var_update, m_t, v_t]
    if self.amsgrad:
        updates.append(vhat_t)
    return control_flow_ops.group(*updates)
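
# Usage sketch: methods like the ones above are the per-variable kernels of
# a tf.keras optimizer; the framework routes dense gradients to
# _resource_apply_dense and sparse ones to _resource_apply_sparse. With a
# hypothetical RAdamWarmup class wrapping them:
#
#     opt = RAdamWarmup(learning_rate=1e-3, total_steps=10000,
#                       warmup_proportion=0.1, weight_decay=0.01)
#     model.compile(optimizer=opt, loss='mse')
#
# The class name and constructor arguments here are assumptions for
# illustration; only the apply methods are shown in this file.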