def _apply_sparse(self, grad, var):
    """Apply a sparse update that divides by a running maximum of `v`.

    Only the rows of `var` (and of its "m", "v" and "v_prime" slots) listed
    in `grad.indices` are touched.

    Args:
        grad: Sparse gradient; its `indices` and `values` attributes are read.
        var: Variable to update.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and the three slot updates.
    """
    dtype = var.dtype.base_dtype
    lr_t = math_ops.cast(self._lr_t, dtype)
    beta1_t = math_ops.cast(self._beta1_t, dtype)
    beta2_t = math_ops.cast(self._beta2_t, dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, dtype)

    # the following equations given in [1]
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_t = state_ops.scatter_update(
        m, grad.indices,
        beta1_t * array_ops.gather(m, grad.indices) +
        (1. - beta1_t) * grad.values,
        use_locking=self._use_locking)
    m_t_slice = array_ops.gather(m_t, grad.indices)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_t = state_ops.scatter_update(
        v, grad.indices,
        beta2_t * array_ops.gather(v, grad.indices) +
        (1. - beta2_t) * math_ops.square(grad.values),
        use_locking=self._use_locking)

    # v_prime keeps the element-wise maximum of every v_t seen so far.
    v_prime = self.get_slot(var, "v_prime")
    v_t_slice = array_ops.gather(v_t, grad.indices)
    v_prime_slice = array_ops.gather(v_prime, grad.indices)
    # Honour use_locking here as well, like the other slot updates do.
    v_t_prime = state_ops.scatter_update(
        v_prime, grad.indices,
        math_ops.maximum(v_prime_slice, v_t_slice),
        use_locking=self._use_locking)
    v_t_prime_slice = array_ops.gather(v_t_prime, grad.indices)

    var_update = state_ops.scatter_sub(
        var, grad.indices,
        lr_t * m_t_slice / (math_ops.sqrt(v_t_prime_slice) + epsilon_t),
        use_locking=self._use_locking)
    return control_flow_ops.group(*[var_update, m_t, v_t, v_t_prime])
def _apply_sparse(self, grad, var):
    """Update the rows of `var` addressed by the sparse gradient `grad`.

    Refreshes the "m" and "v" slot rows at `grad.indices` and subtracts the
    resulting step from the same rows of `var`.

    Args:
        grad: Sparse gradient; its `indices` and `values` attributes are read.
        var: Variable being optimized.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and both slot updates.
    """
    dtype = var.dtype.base_dtype
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, dtype)
    beta2_power = math_ops.cast(beta2_power, dtype)
    lr_t = math_ops.cast(self._lr_t, dtype)
    beta1_t = math_ops.cast(self._beta1_t, dtype)
    beta2_t = math_ops.cast(self._beta2_t, dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, dtype)
    # Learning rate with the bias-correction factors folded in.
    step_size = lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)

    rows = grad.indices

    # m := beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_rows = beta1_t * array_ops.gather(m, rows) + (1 - beta1_t) * grad.values
    m_t = state_ops.scatter_update(
        m, rows, m_rows, use_locking=self._use_locking)

    # v := beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_rows = (beta2_t * array_ops.gather(v, rows) +
              (1 - beta2_t) * math_ops.square(grad.values))
    v_t = state_ops.scatter_update(
        v, rows, v_rows, use_locking=self._use_locking)

    # variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))
    m_t_rows = array_ops.gather(m_t, rows)
    v_t_rows = array_ops.gather(v_t, rows)
    denominator = math_ops.sqrt(v_t_rows) + epsilon_t
    var_update = state_ops.scatter_sub(
        var, rows, step_size * m_t_rows / denominator,
        use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t)
def _apply_sparse(self, grad, var):
    """Sparse moment update applied to the rows listed in `grad.indices`.

    Args:
        grad: Sparse gradient; its `indices` and `values` attributes are read.
        var: Variable being optimized.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and the "m" and "v" slot updates.
    """
    idx = grad.indices
    bias_corrected_lr = (
        self._lr_t * math_ops.sqrt(1 - self._beta2_power) /
        (1 - self._beta1_power))

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    grad_part_m = grad.values * (1 - self._beta1_t)
    decayed_m = gen_array_ops.gather(m, idx) * self._beta1_t
    m_t = state_ops.scatter_update(
        m, idx, decayed_m + grad_part_m, use_locking=self._use_locking)
    m_rows = gen_array_ops.gather(m_t, idx)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    grad_part_v = (grad.values * grad.values) * (1 - self._beta2_t)
    decayed_v = gen_array_ops.gather(v, idx) * self._beta2_t
    v_t = state_ops.scatter_update(
        v, idx, decayed_v + grad_part_v, use_locking=self._use_locking)
    v_rows = gen_array_ops.gather(v_t, idx)
    root_v_rows = math_ops.sqrt(v_rows)

    step = bias_corrected_lr * m_rows / (root_v_rows + self._epsilon_t)
    var_update = state_ops.scatter_sub(
        var, idx, step, use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t)
def _apply_sparse(self, grad, var):
    """Sparse momentum update that first merges duplicate indices.

    Indices are flattened when they are not rank 1, repeated indices are
    collapsed with `array_ops.unique`, and their gradient rows are summed
    before the "m" slot and `var` are updated.

    Args:
        grad: Sparse gradient; its `indices` and `values` attributes are read.
        var: Variable being optimized.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and the "m" slot update.
    """
    if len(grad.indices.get_shape()) != 1:
        flat_indices = array_ops.reshape(grad.indices, [-1])
        row_width = grad.values.get_shape()[-1].value
        flat_values = array_ops.reshape(grad.values, [-1, row_width])
    else:
        flat_indices = grad.indices
        flat_values = grad.values

    # Sum the gradient rows that share an index.
    uniq_idx, segment_ids = array_ops.unique(flat_indices)
    num_uniq = array_ops.size(uniq_idx)
    summed_grad = math_ops.unsorted_segment_sum(
        flat_values, segment_ids, num_uniq)

    # m_t = mu * m + (1 - mu) * g_t
    m = self.get_slot(var, "m")
    grad_part = summed_grad * (1 - self._mu_t)
    decayed_rows = array_ops.gather(m, uniq_idx) * self._mu_t
    m_t = state_ops.scatter_update(
        m, uniq_idx, decayed_rows, use_locking=self._use_locking)
    m_t = state_ops.scatter_add(
        m_t, uniq_idx, grad_part, use_locking=self._use_locking)
    m_hat = array_ops.gather(m_t, uniq_idx) / (
        1 - self._mu2_t * self._mu_power)

    # m_bar = mu * m_t + (1 - mu) * g_t
    m_bar = self._mu2_t * m_hat + grad_part / (1 - self._mu_power)
    var_update = state_ops.scatter_sub(
        var, uniq_idx, self._lr_t * m_bar, use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t)
def _apply_sparse(self, grad, var):
    """Apply the moment-based step to the rows selected by `grad.indices`.

    Args:
        grad: Sparse gradient; its `indices` and `values` attributes are read.
        var: Variable being optimized.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and the "m" and "v" slot updates.
    """
    base_dtype = var.dtype.base_dtype

    def _as_var_dtype(value):
        # Every hyper-parameter is cast to the variable's base dtype.
        return math_ops.cast(value, base_dtype)

    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = _as_var_dtype(beta1_power)
    beta2_power = _as_var_dtype(beta2_power)
    lr_t = _as_var_dtype(self._lr_t)
    beta1_t = _as_var_dtype(self._beta1_t)
    beta2_t = _as_var_dtype(self._beta2_t)
    epsilon_t = _as_var_dtype(self._epsilon_t)
    lr = lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)

    # \\(m := beta1 * m + (1 - beta1) * g_t\\)
    m = self.get_slot(var, "m")
    m_old_rows = array_ops.gather(m, grad.indices)
    m_t = state_ops.scatter_update(
        m, grad.indices,
        beta1_t * m_old_rows + (1 - beta1_t) * grad.values,
        use_locking=self._use_locking)

    # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
    v = self.get_slot(var, "v")
    v_old_rows = array_ops.gather(v, grad.indices)
    v_t = state_ops.scatter_update(
        v, grad.indices,
        beta2_t * v_old_rows + (1 - beta2_t) * math_ops.square(grad.values),
        use_locking=self._use_locking)

    # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
    m_new_rows = array_ops.gather(m_t, grad.indices)
    v_new_rows = array_ops.gather(v_t, grad.indices)
    denom_rows = math_ops.sqrt(v_new_rows) + epsilon_t
    var_update = state_ops.scatter_sub(
        var, grad.indices, lr * m_new_rows / denom_rows,
        use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t)
def _apply_sparse(self, grad, var):
    """Sparse update scaled by gradient/momentum sign agreement.

    The "m" slot is updated densely with `state_ops.assign`. `var` is
    updated only at `grad.indices`, with a per-element multiplier of
    `alpha + sign_decay * sign(g) * sign(m)`.

    Args:
        grad: `IndexedSlices` gradient.
        var: Variable being optimized.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and the "m" slot update.
    """
    dtype = var.dtype.base_dtype
    lr_t = math_ops.cast(self._lr_t, dtype)
    alpha_t = math_ops.cast(self._alpha_t, dtype)
    beta_t = math_ops.cast(self._beta_t, dtype)

    def _like_grad(values):
        # Wrap `values` with the gradient's own indices and dense shape.
        return ops.IndexedSlices(
            values, grad.indices, dense_shape=grad.dense_shape)

    m = self.get_slot(var, 'm')
    m_t = state_ops.assign(
        m, (m * beta_t) + (grad * (1 - beta_t)),
        use_locking=self._use_locking)

    sign_g = _like_grad(math_ops.sign(grad.values))
    sign_gm = _like_grad(
        array_ops.gather(math_ops.sign(m_t), sign_g.indices) * sign_g.values)

    sign_decayed = math_ops.cast(self._sign_decay_t, dtype)
    multiplier = _like_grad(alpha_t + sign_decayed * sign_gm.values)

    final_update = _like_grad(lr_t * multiplier.values * grad.values)

    var_update = state_ops.scatter_sub(
        var, final_update.indices, final_update.values,
        use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t)
def _apply_sparse(self, grad, var):
    """Apply a sign-agreement-scaled step to the rows in `grad.indices`.

    Args:
        grad: `IndexedSlices` gradient.
        var: Variable being optimized.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and the (dense) "m" slot update.
    """
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    alpha_t = math_ops.cast(self._alpha_t, var.dtype.base_dtype)
    beta_t = math_ops.cast(self._beta_t, var.dtype.base_dtype)

    # Moving average of the gradient, assigned over the whole slot.
    m = self.get_slot(var, 'm')
    averaged = (m * beta_t) + (grad * (1 - beta_t))
    m_t = state_ops.assign(m, averaged, use_locking=self._use_locking)

    # Sign of the gradient at the touched rows.
    grad_sign = ops.IndexedSlices(
        math_ops.sign(grad.values),
        grad.indices,
        dense_shape=grad.dense_shape)

    # Product of the momentum sign and the gradient sign at the same rows.
    momentum_sign_rows = array_ops.gather(
        math_ops.sign(m_t), grad_sign.indices)
    agreement = ops.IndexedSlices(
        momentum_sign_rows * grad_sign.values,
        grad_sign.indices,
        dense_shape=grad_sign.dense_shape)

    sign_decayed = math_ops.cast(self._sign_decay_t, var.dtype.base_dtype)
    scale = ops.IndexedSlices(
        alpha_t + sign_decayed * agreement.values,
        agreement.indices,
        dense_shape=agreement.dense_shape)

    step = ops.IndexedSlices(
        lr_t * scale.values * grad.values,
        scale.indices,
        dense_shape=scale.dense_shape)

    var_update = state_ops.scatter_sub(
        var, step.indices, step.values, use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t)
def scatter_sub(self, sparse_delta, use_locking=False):
    """Subtract an `IndexedSlices` from the wrapped variable.

    Args:
        sparse_delta: `IndexedSlices` whose `indices` and `values` are
            forwarded to `state_ops.scatter_sub`.
        use_locking: Forwarded to `state_ops.scatter_sub`.

    Returns:
        Whatever `state_ops.scatter_sub` returns.

    Raises:
        ValueError: If `sparse_delta` is not an `IndexedSlices`.
    """
    if isinstance(sparse_delta, ops.IndexedSlices):
        return state_ops.scatter_sub(
            self._variable,
            sparse_delta.indices,
            sparse_delta.values,
            use_locking=use_locking)
    raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
def _apply_sparse(self, grad, var):
    """Sparse update using a warming momentum schedule.

    Updates the "m" and "v" slot rows at `grad.indices`, then subtracts a
    step built from the schedule-corrected gradient and first moment,
    divided by the bias-corrected second moment.

    Args:
        grad: Sparse gradient; its `indices` and `values` attributes are read.
        var: Variable being optimized.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and both slot updates.
    """
    dtype = var.dtype.base_dtype
    t = math_ops.cast(self._iterations, dtype) + 1.
    m_schedule = math_ops.cast(self._m_schedule, dtype)
    lr_t = math_ops.cast(self._lr_t, dtype)
    beta1_t = math_ops.cast(self._beta1_t, dtype)
    beta2_t = math_ops.cast(self._beta2_t, dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, dtype)
    schedule_decay_t = math_ops.cast(self._schedule_decay_t, dtype)

    # Due to the recommendations in [2], i.e. warming momentum schedule
    cache_power = self._get_momentum_cache(schedule_decay_t, t)
    mu_t = beta1_t * (1. - 0.5 * cache_power)
    mu_t_1 = beta1_t * (
        1. - 0.5 * cache_power * self._momentum_cache_const)
    m_schedule_new = m_schedule * mu_t
    m_schedule_next = m_schedule_new * mu_t_1

    rows = grad.indices

    # the following equations given in [1]
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_rows = beta1_t * array_ops.gather(m, rows) + (1. - beta1_t) * grad.values
    m_t = state_ops.scatter_update(
        m, rows, m_rows, use_locking=self._use_locking)
    g_prime_rows = grad.values / (1. - m_schedule_new)
    m_prime_rows = array_ops.gather(m_t, rows) / (1. - m_schedule_next)
    m_bar_rows = (1. - mu_t) * g_prime_rows + mu_t_1 * m_prime_rows

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_rows = (beta2_t * array_ops.gather(v, rows) +
              (1. - beta2_t) * tf.square(grad.values))
    v_t = state_ops.scatter_update(
        v, rows, v_rows, use_locking=self._use_locking)
    v_prime_rows = array_ops.gather(v_t, rows) / (1. - tf.pow(beta2_t, t))

    var_update = state_ops.scatter_sub(
        var, rows,
        lr_t * m_bar_rows / (math_ops.sqrt(v_prime_rows) + epsilon_t),
        use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t)
def _apply_sparse(self, grad, var):
    """Subtract a clipped, coordinate-wise-scaled sparse gradient from `var`.

    The upper clip bound is `self._burnin_max_learning_rate` while
    `self._counter < self._burnin`, and `self._max_learning_rate` otherwise.

    Args:
        grad: Sparse gradient; its `indices` and `values` attributes are read.
        var: Variable being optimized.

    Returns:
        The result of `state_ops.scatter_sub` on `var`.
    """
    in_burnin = self._counter < self._burnin
    rate_cap = array_ops.where(
        in_burnin, self._burnin_max_learning_rate, self._max_learning_rate)
    raw_rate = self._get_coordinatewise_learning_rate(grad, var)
    clipped_rate = clip_ops.clip_by_value(
        raw_rate, 0.0, math_ops.cast(rate_cap, var.dtype))
    return state_ops.scatter_sub(
        var, grad.indices, grad.values * clipped_rate,
        use_locking=self._use_locking)
def _apply_sparse(self, grad, var):
    """Sparse moment update expressed as subtractions on the slots.

    Both moving averages are updated with `scatter_sub`, not with a gather
    followed by `scatter_update`.

    Args:
        grad: Sparse gradient; its `indices` and `values` attributes are read.
        var: Variable being optimized.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and the "m" and "v" slot updates.
    """
    dtype = var.dtype.base_dtype
    beta1_power = math_ops.cast(self._beta1_power, dtype)
    beta2_power = math_ops.cast(self._beta2_power, dtype)
    lr_t = math_ops.cast(self._lr_t, dtype)
    beta1_t = math_ops.cast(self._beta1_t, dtype)
    beta2_t = math_ops.cast(self._beta2_t, dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, dtype)
    lr = lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)

    rows = grad.indices

    # m := beta1 * m + (1 - beta1) * g_t
    # We use a slightly different version of the moving-average update formula
    # that does a better job of handling concurrent lockless updates:
    # m -= (1 - beta1) * (m - g_t)
    m = self.get_slot(var, "m")
    m_gap = array_ops.gather(m, rows) - grad.values
    m_t = state_ops.scatter_sub(
        m, rows, (1 - beta1_t) * m_gap, use_locking=self._use_locking)

    # v := beta2 * v + (1 - beta2) * (g_t * g_t)
    # We reformulate the update as:
    # v -= (1 - beta2) * (v - g_t * g_t)
    v = self.get_slot(var, "v")
    v_gap = array_ops.gather(v, rows) - math_ops.square(grad.values)
    v_t = state_ops.scatter_sub(
        v, rows, (1 - beta2_t) * v_gap, use_locking=self._use_locking)

    # variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))
    m_rows = array_ops.gather(m_t, rows)
    v_rows = array_ops.gather(v_t, rows)
    denom = math_ops.sqrt(v_rows) + epsilon_t
    var_update = state_ops.scatter_sub(
        var, rows, lr * m_rows / denom, use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t)
def testSub(self):
    """Check `state_ops.scatter_sub` on a ref variable and a resource variable.

    Starting from eight ones, subtracting [0, 2, -1, 2] at indices
    [4, 3, 1, 7] must leave [1, 2, 1, -1, 1, 1, 1, -1] both in the op result
    and in the variable itself.
    """
    ref_var = variables.Variable(array_ops.ones([8], dtype=dtypes.int32))
    res_var = resource_variable_ops.ResourceVariable(
        array_ops.ones([8], dtype=dtypes.int32))
    idx = constant_op.constant([4, 3, 1, 7])
    deltas = constant_op.constant([0, 2, -1, 2], dtype=dtypes.int32)

    for target in (ref_var, res_var):
        # The scatter op is built before the variable is initialized.
        subtracted = state_ops.scatter_sub(target, idx, deltas)
        self.evaluate(target.initializer)

        want = constant_op.constant([1, 2, 1, -1, 1, 1, 1, -1])
        self.assertAllEqual(self.evaluate(subtracted), want)
        self.assertAllEqual(self.evaluate(target), want)
def _center_loss(logit, labels, alpha, lam, num_classes, dtype=dtypes.float32):
    """Compute the center loss and the op that updates the class centers.

    Follows 'A Discriminative Feature Learning Approach for Deep Face
    Recognition', ECCV 2016.

    :param logit: output of the fully connected layer,
        [batch_size, feature_dimension] tensor
    :param labels: true label of every sample, [batch_size] tensor
        (class indices, not one-hot)
    :param alpha: learning rate controlling how fast centers move, 0-1 float
    :param lam: weight of the center loss relative to the other losses
    :param num_classes: number of classes, int
    :param dtype: dtype of the `centers` variable
    :return:
        loss: the computed center loss
        centers: tensor of all centers, [num_classes, feature_dimension]
        center_update_op: op to run while training so the centers are updated
    """
    # A variable needs a statically known shape, so prefer the static feature
    # dimension and only fall back to the dynamic one when it is unavailable.
    fea_dimension = None
    if hasattr(logit, "get_shape"):
        static_shape = logit.get_shape()
        if static_shape.ndims is not None and static_shape.ndims > 1:
            fea_dimension = static_shape.as_list()[1]
    if fea_dimension is None:
        fea_dimension = array_ops.shape(logit)[1]

    # initialize centers
    centers = variable_scope.get_variable(
        'centers', [num_classes, fea_dimension],
        dtype=dtype,
        initializer=init_ops.constant_initializer(0),
        trainable=False)
    labels = array_ops.reshape(labels, [-1])

    # centers of the classes present in the current batch
    centers_batch = array_ops.gather(centers, labels)

    # compute the l2 loss
    loss = nn_ops.l2_loss(logit - centers_batch) * lam

    # difference between each sample's center and the sample itself
    diff = centers_batch - logit

    # Scale every sample's contribution by 1 / (1 + occurrences of its label).
    _, unique_idx, unique_count = array_ops.unique_with_counts(labels)
    appear_times = array_ops.gather(unique_count, unique_idx)
    appear_times = array_ops.reshape(appear_times, [-1, 1])
    # Cast to the dtype of `diff` so non-float32 centers divide cleanly.
    delta_centers = diff / math_ops.cast(1 + appear_times, diff.dtype)
    delta_centers = delta_centers * alpha

    # update centers
    center_update_op = state_ops.scatter_sub(centers, labels, delta_centers)
    return loss, centers, center_update_op
def _apply_sparse(self, grad, var):
    """Update the rows in `grad.indices` using subtractive slot updates.

    Args:
        grad: Sparse gradient; its `indices` and `values` attributes are read.
        var: Variable being optimized.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and the "m" and "v" slot updates.
    """
    base_dtype = var.dtype.base_dtype

    def _cast(value):
        # Bring a hyper-parameter to the variable's base dtype.
        return math_ops.cast(value, base_dtype)

    beta1_power = _cast(self._beta1_power)
    beta2_power = _cast(self._beta2_power)
    lr_t = _cast(self._lr_t)
    beta1_t = _cast(self._beta1_t)
    beta2_t = _cast(self._beta2_t)
    epsilon_t = _cast(self._epsilon_t)
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

    # m := beta1 * m + (1 - beta1) * g_t
    # We use a slightly different version of the moving-average update formula
    # that does a better job of handling concurrent lockless updates:
    # m -= (1 - beta1) * (m - g_t)
    m = self.get_slot(var, "m")
    m_minus_g = array_ops.gather(m, grad.indices) - grad.values
    m_t = state_ops.scatter_sub(
        m, grad.indices, (1 - beta1_t) * m_minus_g,
        use_locking=self._use_locking)

    # v := beta2 * v + (1 - beta2) * (g_t * g_t)
    # We reformulate the update as:
    # v -= (1 - beta2) * (v - g_t * g_t)
    v = self.get_slot(var, "v")
    v_minus_g2 = (
        array_ops.gather(v, grad.indices) - math_ops.square(grad.values))
    v_t = state_ops.scatter_sub(
        v, grad.indices, (1 - beta2_t) * v_minus_g2,
        use_locking=self._use_locking)

    # variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))
    first_moment = array_ops.gather(m_t, grad.indices)
    second_moment = array_ops.gather(v_t, grad.indices)
    scale = math_ops.sqrt(second_moment) + epsilon_t
    var_update = state_ops.scatter_sub(
        var, grad.indices, lr * first_moment / scale,
        use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t)
def _apply_sparse(self, grad, var):
    """Sparse update with a max-based second moment and optional clipping.

    When `self.clip_gradients` is truthy, each gradient value is clipped to
    `+/-(v * clip_multiplier + clip_epsilon)` before it is used.

    Args:
        grad: Sparse gradient; its `indices` and `values` attributes are read.
        var: Variable being optimized.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and the "v" and "m" slot updates.
    """
    dtype = var.dtype.base_dtype
    lr_t = math_ops.cast(self._lr_t, dtype)
    beta1_t = math_ops.cast(self._beta1_t, dtype)
    beta2_t = math_ops.cast(self._beta2_t, dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, dtype)
    clip_multiplier_t = math_ops.cast(self.clip_multiplier_t, dtype)
    clip_epsilon_t = math_ops.cast(self.clip_epsilon_t, dtype)

    rows = grad.indices
    v = self.get_slot(var, "v")
    v_rows = array_ops.gather(v, rows)

    # clip gradient so that each value exceeds its previous maximum by no
    # more than clip_multiplier
    if self.clip_gradients:
        bound = v_rows * clip_multiplier_t + clip_epsilon_t
        g_rows = clip_ops.clip_by_value(grad.values, -bound, bound)
    else:
        g_rows = grad.values

    # m := beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_new_rows = beta1_t * array_ops.gather(m, rows) + (1 - beta1_t) * g_rows
    m_t = state_ops.scatter_update(
        m, rows, m_new_rows, use_locking=self._use_locking)

    # v := max(beta2 * v , abs(grad))
    v_new_rows = math_ops.maximum(beta2_t * v_rows, math_ops.abs(g_rows))
    v_t = state_ops.scatter_update(
        v, rows, v_new_rows, use_locking=self._use_locking)

    # variable -= learning_rate * m_t / (epsilon_t + v_t)
    # we do not use bias-correction term for the first moment; it does not
    # give observable benefit
    var_update = state_ops.scatter_sub(
        var, rows, lr_t * m_new_rows / (v_new_rows + epsilon_t),
        use_locking=self._use_locking)
    return control_flow_ops.group(var_update, v_t, m_t)
def _apply_sparse(self, grad, var):
    """Sparse update whose step is scaled by `alpha ** (sign(g) * sign(m))`.

    Args:
        grad: Sparse gradient; its `indices` and `values` attributes are read.
        var: Variable being optimized.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and the "m" slot update.
    """
    dtype = var.dtype.base_dtype
    lr_t = math_ops.cast(self._lr_t, dtype)
    alpha_t = math_ops.cast(self._alpha_t, dtype)
    beta_t = math_ops.cast(self._beta_t, dtype)  # not used in this method

    eps = 1e-7  # cap for moving average
    rows = grad.indices

    m = self.get_slot(var, "m")
    m_rows = tf.gather(m, rows)
    # NOTE(review): if the slot starts non-negative this maximum is always
    # positive, which would make sign(m_t) below always 1 -- confirm intent.
    m_t = state_ops.scatter_update(
        m, rows, tf.maximum(beta_t * m_rows + eps, tf.abs(grad.values)))
    m_t_rows = tf.gather(m_t, rows)

    # exp(log(alpha) * s) equals alpha ** s for the sign product s.
    exponent = tf.log(alpha_t) * tf.sign(grad.values) * tf.sign(m_t_rows)
    # Subtract the scaled gradient from the selected rows of `var`.
    var_update = state_ops.scatter_sub(
        var, rows, lr_t * grad.values * tf.exp(exponent))

    # Group both ops so that running the result runs each of them.
    return control_flow_ops.group(var_update, m_t)
def scatter_sub(self, sparse_delta, use_locking=False):
    """Subtracts `IndexedSlices` from this variable.

    Shorthand for calling `state_ops.scatter_sub` with this object's
    `_variable` and the `indices` and `values` of `sparse_delta`.

    Args:
        sparse_delta: `IndexedSlices` to be subtracted from this variable.
        use_locking: If `True`, use locking during the operation.

    Returns:
        The value returned by `state_ops.scatter_sub` for this variable.

    Raises:
        ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    is_indexed_slices = isinstance(sparse_delta, ops.IndexedSlices)
    if not is_indexed_slices:
        raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)

    indices = sparse_delta.indices
    values = sparse_delta.values
    return state_ops.scatter_sub(
        self._variable, indices, values, use_locking=use_locking)
def _apply_sparse(self, grad, var):
    """Apply the momentum-schedule step to the rows in `grad.indices`.

    Args:
        grad: Sparse gradient; its `indices` and `values` attributes are read.
        var: Variable being optimized.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and the "m" and "v" slot updates.
    """
    base_dtype = var.dtype.base_dtype

    def _cast(value):
        # Bring a hyper-parameter to the variable's base dtype.
        return math_ops.cast(value, base_dtype)

    t = _cast(self._iterations) + 1.
    m_schedule = _cast(self._m_schedule)
    lr_t = _cast(self._lr_t)
    beta1_t = _cast(self._beta1_t)
    beta2_t = _cast(self._beta2_t)
    epsilon_t = _cast(self._epsilon_t)
    schedule_decay_t = _cast(self._schedule_decay_t)

    # Due to the recommendations in [2], i.e. warming momentum schedule
    momentum_cache_power = self._get_momentum_cache(schedule_decay_t, t)
    momentum_now = beta1_t * (1. - 0.5 * momentum_cache_power)
    momentum_next = beta1_t * (
        1. - 0.5 * momentum_cache_power * self._momentum_cache_const)
    schedule_now = m_schedule * momentum_now
    schedule_next = schedule_now * momentum_next

    # the following equations given in [1]
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_t = state_ops.scatter_update(
        m, grad.indices,
        beta1_t * array_ops.gather(m, grad.indices) +
        (1. - beta1_t) * grad.values,
        use_locking=self._use_locking)
    g_prime = grad.values / (1. - schedule_now)
    m_prime = array_ops.gather(m_t, grad.indices) / (1. - schedule_next)
    m_bar = (1. - momentum_now) * g_prime + momentum_next * m_prime

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_t = state_ops.scatter_update(
        v, grad.indices,
        beta2_t * array_ops.gather(v, grad.indices) +
        (1. - beta2_t) * tf.square(grad.values),
        use_locking=self._use_locking)
    v_prime = array_ops.gather(v_t, grad.indices) / (1. - tf.pow(beta2_t, t))

    step = lr_t * m_bar / (math_ops.sqrt(v_prime) + epsilon_t)
    var_update = state_ops.scatter_sub(
        var, grad.indices, step, use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t)
def testScatterSubStateOps(self):
    """In eager mode, `state_ops.scatter_sub` mutates a resource variable."""
    with context.eager_mode():
        var = resource_variable_ops.ResourceVariable([1.0, 2.0], name="sub")
        # Subtract 3 from element 1: [1.0, 2.0] -> [1.0, -1.0].
        state_ops.scatter_sub(var, [1], [3])
        expected = [1.0, -1.0]
        self.assertAllEqual(expected, var.numpy())
def _apply_gradient(self, grad, var, indices=None):
    """The main function to update a variable.

    Args:
      grad: A Tensor containing gradient to apply.
      var: A Tensor containing the variable to update.
      indices: An array of integers, for sparse update.

    Returns:
      Updated variable var = var - learning_rate * preconditioner * grad

    If the gradient is dense, var and grad have the same shape.
    If the update is sparse, then the first dimension of the gradient and var
    may differ, others are all the same. In this case the indices array
    provides the set of indices of the variable which are to be updated with
    each row of the gradient.
    """
    # All schedule-dependent parameters below are evaluated at the step that
    # follows the stored global step.
    global_step = self._global_step + 1

    # Update accumulated weighted average of gradients
    gbar = self.get_slot(var, "gbar")
    gbar_decay_t = GetParam(self._gbar_decay, global_step)
    gbar_weight_t = GetParam(self._gbar_weight, global_step)
    if indices is not None:
        # Note - the sparse update is not easily implemented, since the
        # algorithm needs all indices of gbar to be updated
        # if mat_gbar_decay != 1 or mat_gbar_decay != 0.
        # One way to make mat_gbar_decay = 1 is by rescaling.
        # If we want the update:
        #         G_{t+1} = a_{t+1} G_t + b_{t+1} w_t
        # define:
        #         r_{t+1} = a_{t+1} * r_t
        #         h_t = G_t / r_t
        # Then:
        #         h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t
        # So we get the mat_gbar_decay = 1 as desired.
        # We can implement this in a future version as needed.
        # However we still need gbar_decay = 0, otherwise all indices
        # of the variable will need to be updated.
        if self._gbar_decay != 0.0:
            tf_logging.warning("Not applying momentum for variable: %s" % var.name)
        # Sparse path: the raw gradient is used, without gradient averaging.
        gbar_updated = grad
    else:
        gbar_updated = self._weighted_average(gbar, self._gbar_decay,
                                              gbar_decay_t,
                                              gbar_weight_t * grad)

    # Update the preconditioners and compute the preconditioned gradient
    shape = var.get_shape()
    # One "Gbar_<i>" statistics slot per dimension of the variable.
    mat_g_list = []
    for i in range(len(shape)):
        mat_g_list.append(self.get_slot(var, "Gbar_" + str(i)))
    mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step)
    mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step)

    preconditioned_grad = gbar_updated
    v_rank = len(mat_g_list)
    # Exponent applied to every per-dimension statistic: -alpha / rank.
    neg_alpha = -GetParam(self._alpha, global_step) / v_rank
    svd_interval = GetParam(self._svd_interval, global_step)
    precond_update_interval = GetParam(self._precond_update_interval,
                                       global_step)
    for i, mat_g in enumerate(mat_g_list):
        # axes is the list of indices to reduce - everything but the current i.
        axes = list(range(i)) + list(range(i + 1, v_rank))
        if shape[i] < self._max_matrix_size:
            # If the tensor size is sufficiently small perform full Shampoo update
            # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this
            # is not strictly correct. However we will use it for now, and
            # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg)

            # The lambdas below capture the loop variables; they are passed
            # straight to `cond` within the same iteration.
            # pylint: disable=g-long-lambda,cell-var-from-loop
            mat_g_updated = control_flow_ops.cond(
                math_ops.mod(global_step, precond_update_interval) < 1,
                lambda: self._update_mat_g(
                    mat_g, grad, axes, mat_gbar_decay_t,
                    mat_gbar_weight_t * precond_update_interval, i),
                lambda: mat_g)

            if self._svd_interval == 1:
                mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha)
            else:
                # Recompute the power only when global_step is a multiple of
                # svd_interval; otherwise read it from the "H_<i>" slot.
                mat_h = control_flow_ops.cond(
                    math_ops.mod(global_step, svd_interval) < 1,
                    lambda: self._compute_power(var, mat_g_updated, shape[i],
                                                neg_alpha, "H_" + str(i)),
                    lambda: self.get_slot(var, "H_" + str(i)))

            # mat_h is a square matrix of size d_i x d_i
            # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor
            # After contraction with a d_i x d_i tensor
            # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor
            # (the first dimension is contracted out, and the second dimension of
            # mat_h is appended).  After going through all the indices, it becomes
            # a d_0 x ... x d_n tensor again.
            preconditioned_grad = math_ops.tensordot(preconditioned_grad, mat_h,
                                                     axes=([0], [0]),
                                                     name="precond_" + str(i))
        else:
            # Tensor size is too large -- perform diagonal Shampoo update
            grad_outer = math_ops.reduce_sum(grad * grad, axis=axes)
            if i == 0 and indices is not None:
                # The sparse diagonal update only accumulates (scatter_add),
                # so it requires a decay of exactly 1.
                assert self._mat_gbar_decay == 1.0
                mat_g_updated = state_ops.scatter_add(
                    mat_g, indices, mat_gbar_weight_t * grad_outer)
                mat_h = math_ops.pow(
                    array_ops.gather(mat_g_updated, indices) + self._epsilon,
                    neg_alpha)
            else:
                mat_g_updated = self._weighted_average(
                    mat_g, self._mat_gbar_decay, mat_gbar_decay_t,
                    mat_gbar_weight_t * grad_outer)
                mat_h = math_ops.pow(mat_g_updated + self._epsilon, neg_alpha)

            # Need to do the transpose to ensure that the tensor becomes
            # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above.
            preconditioned_grad = array_ops.transpose(
                preconditioned_grad, perm=list(range(1, v_rank)) + [0]) * mat_h

    # Update the variable based on the Shampoo update
    learning_rate_t = GetParam(self._learning_rate, global_step)
    if indices is not None:
        var_updated = state_ops.scatter_sub(
            var, indices, learning_rate_t * preconditioned_grad)
    else:
        var_updated = state_ops.assign_sub(
            var, learning_rate_t * preconditioned_grad)
    return var_updated
def _scatter_sub(self, x, i, v):
    """Subtract `v` from `x` at indices `i`, using this object's locking flag."""
    locking = self._use_locking
    return state_ops.scatter_sub(x, i, v, use_locking=locking)
def _finish(self, update_ops, name_scope):
    """Apply the cached steps to their variables and save optional state.

    Each entry of `update_ops` is split into a leading "cache" (element 0)
    and the remaining ops. A cache of length 3 is treated as dense, with the
    step and the variable first. Any other length is treated as sparse, with
    the step, the variable and the indices first.

    Args:
        update_ops: Sequence of per-variable entries. Each cache must be a
            mutable list, and each entry's tail must be a list so it can be
            concatenated with `[x_t]`.
        name_scope: Name given to the final grouped op.

    Returns:
        The op built by `control_flow_ops.group` over all update ops.
    """
    # NOTE(review): block nesting was reconstructed from whitespace-collapsed
    # source -- confirm against the upstream file.
    caches = [update_op[0] for update_op in update_ops]
    update_ops = [update_op[1:] for update_op in update_ops]

    # Optionally add Gaussian noise to every step.
    if self._noise is not None:
        for cache in caches:
            s_t, x_tm1 = cache[:2]
            s_t += random_ops.random_normal(
                x_tm1.initialized_value().get_shape(), stddev=self._noise)
            cache[0] = s_t

    # Optionally clip all steps jointly by their global norm.
    if self._clip > 0:
        S_t = [cache[0] for cache in caches]
        S_t, _ = clip_ops.clip_by_global_norm(S_t, self._clip)
        for cache, s_t in zip(caches, S_t):
            cache[0] = s_t

    # Subtract each step from its variable; the updated variable is appended
    # to the cache, so dense caches grow to length 4 and sparse ones to 5.
    new_update_ops = []
    for cache, update_op in zip(caches, update_ops):
        if len(cache) == 3:
            s_t, x_tm1 = cache[:2]
            with ops.name_scope('update_' + x_tm1.op.name), ops.device(
                    x_tm1.device):
                x_t = state_ops.assign_sub(x_tm1, s_t,
                                           use_locking=self._use_locking)
                cache.append(x_t)
        else:
            s_t_, x_tm1, idxs = cache[:3]
            with ops.name_scope('update_' + x_tm1.op.name), ops.device(
                    x_tm1.device):
                x_t = state_ops.scatter_sub(x_tm1, idxs, s_t_,
                                            use_locking=self._use_locking)
                cache.append(x_t)
        new_update_ops.append(control_flow_ops.group(*([x_t] + update_op)))

    # Everything below is created under control dependencies on the variable
    # updates above.
    with ops.control_dependencies(new_update_ops):
        more_update_ops = []
        # Store the negated step in the 's' slot.
        if self._save_step:
            for cache in caches:
                if len(cache) == 4:
                    s_t, x_tm1 = cache[:2]
                    s_tm1 = self.get_slot(x_tm1, 's')
                    with ops.name_scope('update_' + x_tm1.op.name), ops.device(
                            x_tm1.device):
                        # NOTE(review): new_step_and_grads is never read.
                        new_step_and_grads = []
                        s_t = state_ops.assign(
                            s_tm1, -s_t, use_locking=self._use_locking)
                else:
                    s_t_, x_tm1, idxs = cache[:3]
                    s_tm1 = self.get_slot(x_tm1, 's')
                    with ops.name_scope('update_' + x_tm1.op.name), ops.device(
                            x_tm1.device):
                        s_t = state_ops.scatter_update(
                            s_tm1, idxs, -s_t_, use_locking=self._use_locking)
                more_update_ops.append(s_t)
        # Store the gradient in the 'g' slot.
        if self._save_grad:
            for cache in caches:
                if len(cache) == 4:
                    x_tm1, g_t = cache[1:3]
                    g_tm1 = self.get_slot(x_tm1, 'g')
                    with ops.name_scope('update_' + x_tm1.op.name), ops.device(
                            x_tm1.device):
                        # NOTE(review): new_step_and_grads is never read.
                        new_step_and_grads = []
                        g_t = state_ops.assign(
                            g_tm1, g_t, use_locking=self._use_locking)
                else:
                    x_tm1, idxs, g_t_ = cache[1:4]
                    g_tm1 = self.get_slot(x_tm1, 'g')
                    with ops.name_scope('update_' + x_tm1.op.name), ops.device(
                            x_tm1.device):
                        g_t = state_ops.scatter_update(
                            g_tm1, idxs, g_t_, use_locking=self._use_locking)
                more_update_ops.append(g_t)
        # Maintain a moving average of the updated variable under key 'x'.
        if self._chi > 0:
            for cache in caches:
                if len(cache) == 4:
                    _, x_tm1, _, x_t = cache
                    with ops.name_scope('update_' + x_tm1.op.name), ops.device(
                            x_tm1.device):
                        x_and_t = self._dense_moving_average(
                            x_tm1, x_t, 'x', self._chi)
                        more_update_ops.append(
                            control_flow_ops.group(*x_and_t))
                else:
                    _, x_tm1, idxs, _, x_t = cache
                    with ops.name_scope('update_' + x_tm1.op.name), ops.device(
                            x_tm1.device):
                        x_t_ = array_ops.gather(x_t, idxs)
                        x_and_t = self._sparse_moving_average(
                            x_tm1, idxs, x_t_, 'x', self._chi)
                        more_update_ops.append(
                            control_flow_ops.group(*x_and_t))

    return control_flow_ops.group(*(new_update_ops + more_update_ops),
                                  name=name_scope)
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    """Shared sparse update with optional warmup, decay and max-v tracking.

    The "m" and "v" slots are first decayed densely, then the scaled gradient
    is scattered into them through `scatter_add`. The step switches between
    a variance-rectified form and the plain bias-corrected first moment
    depending on whether `sma_t >= 5.0`.

    Args:
        grad: Gradient values for the rows listed in `indices`.
        var: Variable being optimized.
        indices: Rows of `var` that `grad` applies to.
        scatter_add: Callable `(slot, indices, values)` that adds `values`
            into `slot`.

    Returns:
        The op built by `control_flow_ops.group` over the variable update
        and the slot updates (including "vhat" when `self._amsgrad` is set).
    """
    dtype = var.dtype.base_dtype
    step, beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, dtype)
    beta2_power = math_ops.cast(beta2_power, dtype)
    lr_t = math_ops.cast(self._lr_t, dtype)

    # Warmup then decay towards min_lr, only when a total step count is set.
    if self._initial_total_steps > 0:
        total_steps = math_ops.cast(self._total_steps_t, dtype)
        warmup_proportion = math_ops.cast(self._warmup_proportion_t, dtype)
        min_lr = math_ops.cast(self._min_lr_t, dtype)
        warmup_steps = total_steps * warmup_proportion
        decay_steps = math_ops.maximum(total_steps - warmup_steps, 1)
        decay_rate = (min_lr - lr_t) / decay_steps
        in_warmup = step <= warmup_steps
        warmup_lr = lr_t * (step / warmup_steps)
        decayed_lr = lr_t + decay_rate * math_ops.minimum(
            step - warmup_steps, decay_steps)
        lr_t = tf.where(in_warmup, warmup_lr, decayed_lr)

    beta1_t = math_ops.cast(self._beta1_t, dtype)
    beta2_t = math_ops.cast(self._beta2_t, dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, dtype)

    sma_inf = 2.0 / (1.0 - beta2_t) - 1.0
    sma_t = sma_inf - 2.0 * step * beta2_power / (1.0 - beta2_power)

    # First moment: decay the whole slot, then add the scaled gradient rows.
    m = self.get_slot(var, "m")
    m_increment = grad * (1 - beta1_t)
    m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
        m_t = scatter_add(m, indices, m_increment)
    m_corr_t = m_t / (1.0 - beta1_power)

    # Second moment: same pattern with the squared gradient.
    v = self.get_slot(var, "v")
    v_increment = (grad * grad) * (1 - beta2_t)
    v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
        v_t = scatter_add(v, indices, v_increment)

    if self._amsgrad:
        vhat = self.get_slot(var, 'vhat')
        vhat_t = state_ops.assign(
            vhat, math_ops.maximum(vhat, v_t), use_locking=self._use_locking)
        v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta2_power))
    else:
        v_corr_t = math_ops.sqrt(v_t / (1.0 - beta2_power))

    r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) *
                        (sma_t - 2.0) / (sma_inf - 2.0) *
                        sma_inf / sma_t)

    rectified_step = r_t * m_corr_t / (v_corr_t + epsilon_t)
    var_t = tf.where(sma_t >= 5.0, rectified_step, m_corr_t)

    if self._initial_weight_decay > 0.0:
        param_name = self._get_variable_name(var.name)
        if self._do_use_weight_decay(param_name):
            var_t += math_ops.cast(self._weight_decay_t, dtype) * var

    var_t = lr_t * var_t
    var_update = state_ops.scatter_sub(
        var, indices, array_ops.gather(var_t, indices),
        use_locking=self._use_locking)

    updates = [var_update, m_t, v_t]
    if self._amsgrad:
        updates.append(vhat_t)
    return control_flow_ops.group(*updates)
def _finish(self, update_ops, name_scope):
    """Apply the pending steps to their variables and refresh saved state.

    Args:
      update_ops: One list per variable. Element 0 is a mutable cache list,
        `[s_t, x_tm1, g_t]` for a dense update or
        `[s_t_, x_tm1, idxs, g_t_]` for a sparse one, where `s_t` is the
        step subtracted from variable `x_tm1`. The remaining elements are
        ops grouped together with that variable's update.
      name_scope: Name given to the returned grouped op.

    Returns:
      A grouped op containing the variable updates and, depending on
      `self._save_step`, `self._save_grad` and `self._chi`, the updates of
      the 's' and 'g' slots and of the moving average of the variables.
    """
    caches = [update_op[0] for update_op in update_ops]
    update_ops = [update_op[1:] for update_op in update_ops]

    if self._noise is not None:
        for cache in caches:
            s_t = cache[0]
            # Sample the noise with the step's own shape. A sparse step is
            # shaped like the gathered rows, not like the whole variable, so
            # using the variable's shape would not match it.
            cache[0] = s_t + random_ops.random_normal(
                array_ops.shape(s_t), stddev=self._noise)

    if self._clip is not None:
        # Clip all steps jointly by their global norm.
        S_t = [cache[0] for cache in caches]
        S_t, _ = clip_ops.clip_by_global_norm(S_t, self._clip)
        for cache, s_t in zip(caches, S_t):
            cache[0] = s_t

    new_update_ops = []
    for cache, update_op in zip(caches, update_ops):
        if len(cache) == 3:
            # Dense cache: subtract the step from the whole variable.
            s_t, x_tm1 = cache[:2]
            with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
                x_t = state_ops.assign_sub(x_tm1, s_t, use_locking=self._use_locking)
        else:
            # Sparse cache: subtract the step from the indexed rows only.
            s_t_, x_tm1, idxs = cache[:3]
            with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
                x_t = state_ops.scatter_sub(x_tm1, idxs, s_t_, use_locking=self._use_locking)
        # From here on dense caches have length 4 and sparse caches length 5.
        cache.append(x_t)
        new_update_ops.append(control_flow_ops.group(*([x_t] + update_op)))

    # Everything below runs only after the variables have been updated.
    with ops.control_dependencies(new_update_ops):
        more_update_ops = []

        if self._save_step:
            # Store the negated step in the 's' slot.
            for cache in caches:
                if len(cache) == 4:
                    s_t, x_tm1 = cache[:2]
                    s_tm1 = self.get_slot(x_tm1, 's')
                    with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
                        s_t = state_ops.assign(s_tm1, -s_t, use_locking=self._use_locking)
                else:
                    s_t_, x_tm1, idxs = cache[:3]
                    s_tm1 = self.get_slot(x_tm1, 's')
                    with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
                        s_t = state_ops.scatter_update(s_tm1, idxs, -s_t_, use_locking=self._use_locking)
                more_update_ops.append(s_t)

        if self._save_grad:
            # Store the gradient in the 'g' slot.
            for cache in caches:
                if len(cache) == 4:
                    x_tm1, g_t = cache[1:3]
                    g_tm1 = self.get_slot(x_tm1, 'g')
                    with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
                        g_t = state_ops.assign(g_tm1, g_t, use_locking=self._use_locking)
                else:
                    x_tm1, idxs, g_t_ = cache[1:4]
                    g_tm1 = self.get_slot(x_tm1, 'g')
                    with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
                        g_t = state_ops.scatter_update(g_tm1, idxs, g_t_, use_locking=self._use_locking)
                more_update_ops.append(g_t)

        if self._chi > 0:
            # Maintain a moving average of the updated variables under 'x'.
            for cache in caches:
                if len(cache) == 4:
                    _, x_tm1, _, x_t = cache
                    with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
                        x_and_t = self._dense_moving_average(x_tm1, x_t, 'x', self._chi)
                        more_update_ops.append(control_flow_ops.group(*x_and_t))
                else:
                    _, x_tm1, idxs, _, x_t = cache
                    with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
                        x_t_ = array_ops.gather(x_t, idxs)
                        x_and_t = self._sparse_moving_average(x_tm1, idxs, x_t_, 'x', self._chi)
                        more_update_ops.append(control_flow_ops.group(*x_and_t))

    return control_flow_ops.group(*(new_update_ops + more_update_ops), name=name_scope)
def _apply_sparse(self, grad, var):
    """Build the ops that apply a sparse gradient, catching up skipped steps.

    Each touched row records in the "pre_step" slot the global step at which
    it was last updated. The number of steps elapsed since then is used to
    decay the row's moments before the new gradient is folded in.

    Args:
      grad: Sparse gradient exposing `.indices` (rows of `var` to update)
        and `.values` (the gradient for those rows).
      var: Variable being optimized; owns the "m", "v" and "pre_step" slots.

    Returns:
      A grouped op covering the variable, "m", "v" and "pre_step" updates.
    """
    dtype = var.dtype.base_dtype
    lr_t = math_ops.cast(self._lr_t, dtype)
    beta1_t = math_ops.cast(self._beta1_t, dtype)
    beta2_t = math_ops.cast(self._beta2_t, dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, dtype)
    global_step = math_ops.cast(self._get_step_accumulators(), dtype)

    indices = grad.indices
    # Use the per-row gradient values: every tensor they are combined with
    # below is a slice gathered at `indices`.
    grad_values = grad.values

    pre_step = self.get_slot(var, "pre_step")
    pre_step_slice = array_ops.gather(pre_step, indices)
    skipped_steps = global_step - pre_step_slice

    m = self.get_slot(var, "m")
    m_slice = array_ops.gather(m, indices)
    v = self.get_slot(var, "v")
    v_slice = array_ops.gather(v, indices)

    # For performance, exp(a * log(b)) is used in place of pow(b, a).
    log_beta1 = math_ops.log(beta1_t)
    log_beta2 = math_ops.log(beta2_t)
    beta1_decay = math_ops.exp(log_beta1 * skipped_steps)  # beta1 ** skipped
    beta2_decay = math_ops.exp(log_beta2 * skipped_steps)  # beta2 ** skipped

    # lr := lr_t * sqrt(1 - beta2 ** pre_step) / (1 - beta1 ** pre_step)
    #            * (1 - beta1 ** skipped_steps) / (1 - beta1)
    # NOTE(review): a pre_step of 0 makes both bias-correction factors 0
    # (0 / 0); presumably the slot is initialised to a positive value
    # elsewhere -- confirm.
    lr = ((lr_t * math_ops.sqrt(1 - math_ops.exp(pre_step_slice * log_beta2))
           / (1 - math_ops.exp(pre_step_slice * log_beta1)))
          * (1 - beta1_decay) / (1 - beta1_t))

    # variable -= lr * m / (epsilon + sqrt(v)), using the moments as they
    # were before this gradient is applied.
    var_slice = lr * m_slice / (math_ops.sqrt(v_slice) + epsilon_t)
    var_update_op = state_ops.scatter_sub(var, indices, var_slice,
                                          use_locking=self._use_locking)

    with ops.control_dependencies([var_update_op]):
        # m := m * beta1 ** skipped_steps + (1 - beta1) * g_t
        m_t_slice = beta1_decay * m_slice + (1 - beta1_t) * grad_values
        m_update_op = state_ops.scatter_update(
            m, indices, m_t_slice, use_locking=self._use_locking)

        # v := v * beta2 ** skipped_steps + (1 - beta2) * (g_t * g_t)
        v_t_slice = (beta2_decay * v_slice
                     + (1 - beta2_t) * math_ops.square(grad_values))
        v_update_op = state_ops.scatter_update(
            v, indices, v_t_slice, use_locking=self._use_locking)

    # Record the current step for the touched rows only once their moments
    # have been written.
    with ops.control_dependencies([m_update_op, v_update_op]):
        pre_step_update_op = state_ops.scatter_update(
            pre_step, indices, global_step, use_locking=self._use_locking)

    return control_flow_ops.group(var_update_op, m_update_op, v_update_op,
                                  pre_step_update_op)