Example #1
    def _apply_sparse(self, grad, var):
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)

        # the following equations are given in [1]
        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_t = state_ops.scatter_update(m, grad.indices,
                                       beta1_t * array_ops.gather(m, grad.indices) +
                                       (1. - beta1_t) * grad.values,
                                       use_locking=self._use_locking)
        m_t_slice = tf.gather(m_t, grad.indices)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_t = state_ops.scatter_update(v, grad.indices,
                                       beta2_t * array_ops.gather(v, grad.indices) +
                                       (1. - beta2_t) * tf.square(grad.values),
                                       use_locking=self._use_locking)
        v_prime = self.get_slot(var, "v_prime")
        v_t_slice = tf.gather(v_t, grad.indices)
        v_prime_slice = tf.gather(v_prime, grad.indices)
        v_t_prime = state_ops.scatter_update(v_prime, grad.indices, tf.maximum(v_prime_slice, v_t_slice))

        v_t_prime_slice = array_ops.gather(v_t_prime, grad.indices)
        var_update = state_ops.scatter_sub(var, grad.indices,
                                           lr_t * m_t_slice / (math_ops.sqrt(v_t_prime_slice) + epsilon_t),
                                           use_locking=self._use_locking)

        return control_flow_ops.group(*[var_update, m_t, v_t, v_t_prime])
Example #2
  def _apply_sparse(self, grad, var):
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

    # m := beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_t = state_ops.scatter_update(m, grad.indices,
                                   beta1_t * array_ops.gather(m, grad.indices) +
                                   (1 - beta1_t) * grad.values,
                                   use_locking=self._use_locking)

    # v := beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_t = state_ops.scatter_update(v, grad.indices,
                                   beta2_t * array_ops.gather(v, grad.indices) +
                                   (1 - beta2_t) * math_ops.square(grad.values),
                                   use_locking=self._use_locking)

    # variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))
    m_t_slice = array_ops.gather(m_t, grad.indices)
    v_t_slice = array_ops.gather(v_t, grad.indices)
    denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
    var_update = state_ops.scatter_sub(var, grad.indices,
                                       lr * m_t_slice / denominator_slice,
                                       use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t)
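Several of these Adam-style examples compute the bias-corrected step size lr = lr_t * sqrt(1 - beta2_power) / (1 - beta1_power) before scattering the update. A minimal sketch, in plain Python with illustrative hyperparameters (not taken from the source), of how that correction behaves over the first steps:

# Minimal sketch of the bias-corrected step size used above:
# lr = lr_t * sqrt(1 - beta2**t) / (1 - beta1**t); values are illustrative.
import math

def bias_corrected_lr(lr_t, beta1, beta2, t):
    """Return the effective Adam step size after t updates."""
    return lr_t * math.sqrt(1.0 - beta2 ** t) / (1.0 - beta1 ** t)

for t in (1, 10, 100, 1000):
    # The correction is large early on and approaches lr_t as t grows.
    print(t, bias_corrected_lr(0.001, 0.9, 0.999, t))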
Example #3
 def _apply_sparse(self, grad, var):
   lr = (self._lr_t *
         math_ops.sqrt(1 - self._beta2_power)
         / (1 - self._beta1_power))
   # m_t = beta1 * m + (1 - beta1) * g_t
   m = self.get_slot(var, "m")
   m_scaled_g_values = grad.values * (1 - self._beta1_t)
   m_scaled = gen_array_ops.gather(m, grad.indices) * self._beta1_t
   m_t = state_ops.scatter_update(m, grad.indices,
                                  m_scaled + m_scaled_g_values,
                                  use_locking=self._use_locking)
   m_tp = gen_array_ops.gather(m_t, grad.indices)
   
   # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
   v = self.get_slot(var, "v")
   v_scaled_g_values = (grad.values * grad.values) * (1 - self._beta2_t)
   v_scaled = gen_array_ops.gather(v, grad.indices) * self._beta2_t
   v_t = state_ops.scatter_update(v, grad.indices,
                                  v_scaled + v_scaled_g_values,
                                  use_locking=self._use_locking)
   v_tp = gen_array_ops.gather(v_t, grad.indices)
   v_sqrtp = math_ops.sqrt(v_tp)
   
   var_update = state_ops.scatter_sub(var, grad.indices,
                                      lr * m_tp / (v_sqrtp + self._epsilon_t),
                                      use_locking=self._use_locking)    
   return control_flow_ops.group(*[var_update, m_t, v_t])
Example #4
 def _apply_sparse(self, grad, var):
   if len(grad.indices.get_shape()) == 1:
     grad_indices = grad.indices
     grad_values = grad.values
   else:
     grad_indices = array_ops.reshape(grad.indices, [-1])
     grad_values = array_ops.reshape(grad.values, [-1, grad.values.get_shape()[-1].value])
   gidxs, metagidxs = array_ops.unique(grad_indices)
   sizegidxs = array_ops.size(gidxs)
   gvals = math_ops.unsorted_segment_sum(grad_values, metagidxs, sizegidxs)
   # m_t = mu * m + (1 - mu) * g_t
   m = self.get_slot(var, "m")
   m_scaled_g_values = gvals * (1 - self._mu_t)
   m_t = state_ops.scatter_update(m, gidxs,
                                  array_ops.gather(m, gidxs) * self._mu_t,
                                  use_locking=self._use_locking)
   m_t = state_ops.scatter_add(m_t, gidxs, m_scaled_g_values,
                               use_locking=self._use_locking)
   m_t_ = array_ops.gather(m_t, gidxs) / (1 - self._mu2_t * self._mu_power)
   # m_bar = mu * m_t + (1 - mu) * g_t
   m_bar = self._mu2_t * m_t_ + m_scaled_g_values / (1 - self._mu_power)
   var_update = state_ops.scatter_sub(var, gidxs,
                                    self._lr_t * m_bar,
                                    use_locking=self._use_locking)
   return control_flow_ops.group(*[var_update, m_t])
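Example #4 first collapses duplicate gradient indices with unique and unsorted_segment_sum so that each row is scattered exactly once. A minimal sketch of that pattern, written against the public tf.* API with illustrative values:

# Minimal sketch of the index-deduplication pattern used above (TF 2.x);
# the tensor values are illustrative assumptions.
import tensorflow as tf

indices = tf.constant([3, 1, 3, 0])                  # duplicate row index 3
values = tf.constant([[1.0], [2.0], [4.0], [8.0]])   # one gradient row each

unique_idx, positions = tf.unique(indices)           # [3, 1, 0], [0, 1, 0, 2]
summed = tf.math.unsorted_segment_sum(values, positions, tf.size(unique_idx))
# summed[0] == 5.0 accumulates both contributions to row 3, so a following
# scatter_update/scatter_sub touches each row exactly once.
print(unique_idx.numpy(), summed.numpy())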
Example #5
  def _apply_sparse(self, grad, var):
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

    # \\(m := beta1 * m + (1 - beta1) * g_t\\)
    m = self.get_slot(var, "m")
    m_t = state_ops.scatter_update(m, grad.indices,
                                   beta1_t * array_ops.gather(m, grad.indices) +
                                   (1 - beta1_t) * grad.values,
                                   use_locking=self._use_locking)

    # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
    v = self.get_slot(var, "v")
    v_t = state_ops.scatter_update(v, grad.indices,
                                   beta2_t * array_ops.gather(v, grad.indices) +
                                   (1 - beta2_t) * math_ops.square(grad.values),
                                   use_locking=self._use_locking)

    # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
    m_t_slice = array_ops.gather(m_t, grad.indices)
    v_t_slice = array_ops.gather(v_t, grad.indices)
    denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
    var_update = state_ops.scatter_sub(var, grad.indices,
                                       lr * m_t_slice / denominator_slice,
                                       use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t)
Example #6
  def _apply_sparse(self, grad, var):
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    alpha_t = math_ops.cast(self._alpha_t, var.dtype.base_dtype)
    beta_t = math_ops.cast(self._beta_t, var.dtype.base_dtype)

    m = self.get_slot(var, 'm')
    m_t = state_ops.assign(
        m, (m * beta_t) + (grad * (1 - beta_t)), use_locking=self._use_locking)

    sign_g = ops.IndexedSlices(
        math_ops.sign(grad.values), grad.indices, dense_shape=grad.dense_shape)
    sign_gm = ops.IndexedSlices(
        array_ops.gather(math_ops.sign(m_t), sign_g.indices) * sign_g.values,
        sign_g.indices,
        dense_shape=sign_g.dense_shape)

    sign_decayed = math_ops.cast(
        self._sign_decay_t, var.dtype.base_dtype)
    multiplier_values = alpha_t + sign_decayed * sign_gm.values
    multiplier = ops.IndexedSlices(
        multiplier_values, sign_gm.indices, dense_shape=sign_gm.dense_shape)

    final_update = ops.IndexedSlices(
        lr_t * multiplier.values * grad.values,
        multiplier.indices,
        dense_shape=multiplier.dense_shape)

    var_update = state_ops.scatter_sub(
        var,
        final_update.indices,
        final_update.values,
        use_locking=self._use_locking)

    return control_flow_ops.group(*[var_update, m_t])
Example #7
    def _apply_sparse(self, grad, var):
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        alpha_t = math_ops.cast(self._alpha_t, var.dtype.base_dtype)
        beta_t = math_ops.cast(self._beta_t, var.dtype.base_dtype)

        m = self.get_slot(var, 'm')
        m_t = state_ops.assign(m, (m * beta_t) + (grad * (1 - beta_t)),
                               use_locking=self._use_locking)

        sign_g = ops.IndexedSlices(math_ops.sign(grad.values),
                                   grad.indices,
                                   dense_shape=grad.dense_shape)
        sign_gm = ops.IndexedSlices(
            array_ops.gather(math_ops.sign(m_t), sign_g.indices) *
            sign_g.values,
            sign_g.indices,
            dense_shape=sign_g.dense_shape)

        sign_decayed = math_ops.cast(self._sign_decay_t, var.dtype.base_dtype)
        multiplier_values = alpha_t + sign_decayed * sign_gm.values
        multiplier = ops.IndexedSlices(multiplier_values,
                                       sign_gm.indices,
                                       dense_shape=sign_gm.dense_shape)

        final_update = ops.IndexedSlices(lr_t * multiplier.values *
                                         grad.values,
                                         multiplier.indices,
                                         dense_shape=multiplier.dense_shape)

        var_update = state_ops.scatter_sub(var,
                                           final_update.indices,
                                           final_update.values,
                                           use_locking=self._use_locking)

        return control_flow_ops.group(*[var_update, m_t])
Example #8
 def scatter_sub(self, sparse_delta, use_locking=False):
   if not isinstance(sparse_delta, ops.IndexedSlices):
     raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
   return state_ops.scatter_sub(
       self._variable,
       sparse_delta.indices,
       sparse_delta.values,
       use_locking=use_locking)
Example #9
    def _apply_sparse(self, grad, var):
        t = math_ops.cast(self._iterations, var.dtype.base_dtype) + 1.
        m_schedule = math_ops.cast(self._m_schedule, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
        schedule_decay_t = math_ops.cast(self._schedule_decay_t,
                                         var.dtype.base_dtype)

        # warming momentum schedule, as recommended in [2]
        momentum_cache_power = self._get_momentum_cache(schedule_decay_t, t)
        momentum_cache_t = beta1_t * (1. - 0.5 * momentum_cache_power)
        momentum_cache_t_1 = beta1_t * (
            1. - 0.5 * momentum_cache_power * self._momentum_cache_const)
        m_schedule_new = m_schedule * momentum_cache_t
        m_schedule_next = m_schedule_new * momentum_cache_t_1

        # the following equations are given in [1]
        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_t = state_ops.scatter_update(
            m,
            grad.indices,
            beta1_t * array_ops.gather(m, grad.indices) +
            (1. - beta1_t) * grad.values,
            use_locking=self._use_locking)
        g_prime_slice = grad.values / (1. - m_schedule_new)
        m_t_prime_slice = array_ops.gather(
            m_t, grad.indices) / (1. - m_schedule_next)
        m_t_bar_slice = (
            1. - momentum_cache_t
        ) * g_prime_slice + momentum_cache_t_1 * m_t_prime_slice

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_t = state_ops.scatter_update(
            v,
            grad.indices,
            beta2_t * array_ops.gather(v, grad.indices) +
            (1. - beta2_t) * tf.square(grad.values),
            use_locking=self._use_locking)
        v_t_prime_slice = array_ops.gather(
            v_t, grad.indices) / (1. - tf.pow(beta2_t, t))

        var_update = state_ops.scatter_sub(
            var,
            grad.indices,
            lr_t * m_t_bar_slice /
            (math_ops.sqrt(v_t_prime_slice) + epsilon_t),
            use_locking=self._use_locking)

        return control_flow_ops.group(*[var_update, m_t, v_t])
Example #10
  def _apply_sparse(self, grad, var):

    max_learning_rate = array_ops.where(self._counter < self._burnin,
                                        self._burnin_max_learning_rate,
                                        self._max_learning_rate)

    learn_rate = clip_ops.clip_by_value(
        self._get_coordinatewise_learning_rate(grad, var), 0.0,
        math_ops.cast(max_learning_rate, var.dtype))
    delta = grad.values * learn_rate

    return state_ops.scatter_sub(var, grad.indices, delta,
                                 use_locking=self._use_locking)
Example #11
    def _apply_sparse(self, grad, var):
        beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
        beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
        lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

        # m := beta1 * m + (1 - beta1) * g_t
        # We use a slightly different version of the moving-average update formula
        # that does a better job of handling concurrent lockless updates:
        # m -= (1 - beta1) * (m - g_t)
        m = self.get_slot(var, "m")
        m_t_delta = array_ops.gather(m, grad.indices) - grad.values
        m_t = state_ops.scatter_sub(m,
                                    grad.indices, (1 - beta1_t) * m_t_delta,
                                    use_locking=self._use_locking)

        # v := beta2 * v + (1 - beta2) * (g_t * g_t)
        # We reformulate the update as:
        # v -= (1 - beta2) * (v - g_t * g_t)
        v = self.get_slot(var, "v")
        v_t_delta = array_ops.gather(v, grad.indices) - math_ops.square(
            grad.values)
        v_t = state_ops.scatter_sub(v,
                                    grad.indices, (1 - beta2_t) * v_t_delta,
                                    use_locking=self._use_locking)

        # variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))
        m_t_slice = array_ops.gather(m_t, grad.indices)
        v_t_slice = array_ops.gather(v_t, grad.indices)
        denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
        var_update = state_ops.scatter_sub(var,
                                           grad.indices,
                                           lr * m_t_slice / denominator_slice,
                                           use_locking=self._use_locking)
        return control_flow_ops.group(var_update, m_t, v_t)
Example #12
    def testSub(self):
        variable = variables.Variable(array_ops.ones([8], dtype=dtypes.int32))
        resource_variable = resource_variable_ops.ResourceVariable(
            array_ops.ones([8], dtype=dtypes.int32))
        indices = constant_op.constant([4, 3, 1, 7])
        updates = constant_op.constant([0, 2, -1, 2], dtype=dtypes.int32)

        for ref in (variable, resource_variable):
            sub_result = state_ops.scatter_sub(ref, indices, updates)
            self.evaluate(ref.initializer)

            expected_result = constant_op.constant([1, 2, 1, -1, 1, 1, 1, -1])
            self.assertAllEqual(self.evaluate(sub_result), expected_result)
            self.assertAllEqual(self.evaluate(ref), expected_result)
Example #13
def _center_loss(logit, labels, alpha, lam, num_classes, dtype=dtypes.float32):
    """
    compute the center loss and update the centers,
    following 'A Discriminative Feature Learning Approach for Deep Face Recognition', ECCV 2016

    :param logit: output of the NN fully connected layer, a [batch_size, feature_dimension] tensor
    :param labels: true label of every sample, a [batch_size] tensor (not one-hot)
    :param alpha: learning rate controlling how fast the centers are updated, a float in [0, 1]
    :param lam: weight of the center loss relative to the softmax loss and other losses
    :param num_classes: number of classes, int
    :param dtype: dtype of the centers variable
    :return:
        loss: the computed center loss
        centers: tensor of all centers, [num_classes, feature_dimension]
        centers_update_op: op that must be run while training the model to update the centers
    """

    # get feature dimension
    fea_dimension = array_ops.shape(logit)[1]

    # initialize centers
    centers = variable_scope.get_variable(
        'centers', [num_classes, fea_dimension],
        dtype=dtype,
        initializer=init_ops.constant_initializer(0),
        trainable=False)

    labels = array_ops.reshape(labels, [-1])

    # get the centers for the current batch
    centers_batch = array_ops.gather(centers, labels)

    # compute the l2 loss
    loss = nn_ops.l2_loss(logit - centers_batch) * lam

    # compute the difference between each sample and their corresponding center
    diff = centers_batch - logit

    # compute delta of corresponding center
    unique_label, unique_idx, unique_count = array_ops.unique_with_counts(
        labels)
    appear_times = array_ops.gather(unique_count, unique_idx)
    appear_times = array_ops.reshape(appear_times, [-1, 1])
    delta_centers = diff / math_ops.cast(1 + appear_times, dtype)
    delta_centers = delta_centers * alpha

    # update centers
    center_update_op = state_ops.scatter_sub(centers, labels, delta_centers)

    return loss, centers, center_update_op
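A hedged sketch of how a helper like _center_loss might be wired into a TF1-style graph. Everything other than the call to _center_loss (the input placeholders, the projection weight, the stand-in softmax loss, and the hyperparameter values) is an illustrative assumption, not part of the source:

# Hypothetical wiring; assumes _center_loss and its TF ops imports are in scope.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

inputs = tf.placeholder(tf.float32, [None, 4], name="inputs")
labels = tf.placeholder(tf.int64, [None], name="labels")

# hypothetical embedding layer producing the "logit" features of _center_loss
w = tf.get_variable("proj", [4, 2])
features = tf.matmul(inputs, w)

softmax_loss = tf.constant(0.0)   # stand-in for the usual classification loss
center_loss, centers, centers_update_op = _center_loss(
    features, labels, alpha=0.5, lam=0.003, num_classes=10)
total_loss = softmax_loss + center_loss

# The centers live in their own non-trainable variable; their scatter_sub
# update op must be run together with the gradient step.
with tf.control_dependencies([centers_update_op]):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(total_loss)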
Example #14
  def _apply_sparse(self, grad, var):
    beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

    # m := beta1 * m + (1 - beta1) * g_t
    # We use a slightly different version of the moving-average update formula
    # that does a better job of handling concurrent lockless updates:
    # m -= (1 - beta1) * (m - g_t)
    m = self.get_slot(var, "m")
    m_t_delta = array_ops.gather(m, grad.indices) - grad.values
    m_t = state_ops.scatter_sub(m, grad.indices,
                                (1 - beta1_t) * m_t_delta,
                                use_locking=self._use_locking)

    # v := beta2 * v + (1 - beta2) * (g_t * g_t)
    # We reformulate the update as:
    # v -= (1 - beta2) * (v - g_t * g_t)
    v = self.get_slot(var, "v")
    v_t_delta = array_ops.gather(v, grad.indices) - math_ops.square(grad.values)
    v_t = state_ops.scatter_sub(v, grad.indices,
                                (1 - beta2_t) * v_t_delta,
                                use_locking=self._use_locking)

    # variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))
    m_t_slice = array_ops.gather(m_t, grad.indices)
    v_t_slice = array_ops.gather(v_t, grad.indices)
    denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
    var_update = state_ops.scatter_sub(var, grad.indices,
                                       lr * m_t_slice / denominator_slice,
                                       use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t)
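The comments above rewrite the moving average as a subtraction so it can be expressed with scatter_sub. A quick plain-Python spot check, with arbitrary illustrative numbers, that the two forms are algebraically identical:

# m_new = beta1 * m + (1 - beta1) * g
#       = m - (1 - beta1) * (m - g)
beta1, m, g = 0.9, 0.5, -2.0
direct = beta1 * m + (1 - beta1) * g
reformulated = m - (1 - beta1) * (m - g)
assert abs(direct - reformulated) < 1e-12
print(direct, reformulated)  # both print 0.25 (up to float rounding)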
Example #15
    def _apply_sparse(self, grad, var):

        max_learning_rate = array_ops.where(self._counter < self._burnin,
                                            self._burnin_max_learning_rate,
                                            self._max_learning_rate)

        learn_rate = clip_ops.clip_by_value(
            self._get_coordinatewise_learning_rate(grad, var), 0.0,
            math_ops.cast(max_learning_rate, var.dtype))
        delta = grad.values * learn_rate

        return state_ops.scatter_sub(var,
                                     grad.indices,
                                     delta,
                                     use_locking=self._use_locking)
Example #16
    def _apply_sparse(self, grad, var):
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
        clip_multiplier_t = math_ops.cast(self.clip_multiplier_t,
                                          var.dtype.base_dtype)
        clip_epsilon_t = math_ops.cast(self.clip_epsilon_t,
                                       var.dtype.base_dtype)

        v = self.get_slot(var, "v")
        v_slice = array_ops.gather(v, grad.indices)

        # clip the gradient so that each value exceeds its previous maximum by no more than clip_multiplier
        clipped_values = grad.values
        if self.clip_gradients:
            clipVal = v_slice * clip_multiplier_t + clip_epsilon_t
            clipped_values = clip_ops.clip_by_value(grad.values, -clipVal,
                                                    clipVal)

        # m := beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_t_values = beta1_t * array_ops.gather(
            m, grad.indices) + (1 - beta1_t) * clipped_values
        m_t = state_ops.scatter_update(m,
                                       grad.indices,
                                       m_t_values,
                                       use_locking=self._use_locking)

        # v := max(beta2 * v , abs(grad))
        v_t_values = math_ops.maximum(beta2_t * v_slice,
                                      math_ops.abs(clipped_values))
        v_t = state_ops.scatter_update(v,
                                       grad.indices,
                                       v_t_values,
                                       use_locking=self._use_locking)

        # variable -= learning_rate * m_t / (epsilon_t + v_t)
        # we do not use a bias-correction term for the first moment; it gives no observable benefit
        var_update = state_ops.scatter_sub(var,
                                           grad.indices,
                                           lr_t * m_t_values /
                                           (v_t_values + epsilon_t),
                                           use_locking=self._use_locking)
        return control_flow_ops.group(var_update, v_t, m_t)
Example #17
    def _apply_sparse(self, grad, var):
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        alpha_t = math_ops.cast(self._alpha_t, var.dtype.base_dtype)
        beta_t = math_ops.cast(self._beta_t, var.dtype.base_dtype)

        eps = 1e-7  # cap for moving average

        m = self.get_slot(var, "m")
        m_slice = tf.gather(m, grad.indices)
        m_t = state_ops.scatter_update(m, grad.indices,
                                       tf.maximum(beta_t * m_slice + eps, tf.abs(grad.values)))
        m_t_slice = tf.gather(m_t, grad.indices)

        var_update = state_ops.scatter_sub(var, grad.indices, lr_t * grad.values * tf.exp(
            tf.log(alpha_t) * tf.sign(grad.values) * tf.sign(m_t_slice)))  # Update 'ref' by subtracting 'value'
        # Create an op that groups multiple operations.
        # When this op finishes, all ops in input have finished
        return control_flow_ops.group(*[var_update, m_t])
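The factor tf.exp(tf.log(alpha_t) * tf.sign(grad) * tf.sign(m_t)) above is just alpha_t raised to sign(grad) * sign(m_t): the step is multiplied by alpha_t when gradient and memory agree in sign and divided by alpha_t when they disagree. A plain-Python spot check with an illustrative alpha:

import math

alpha = 2.0
for sg, sm in [(1, 1), (1, -1), (-1, -1), (1, 0)]:
    factor = math.exp(math.log(alpha) * sg * sm)
    # signs agree -> alpha, disagree -> 1/alpha, either sign zero -> 1.0
    print(sg, sm, factor)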
Example #18
    def scatter_sub(self, sparse_delta, use_locking=False):
        """Subtracts `IndexedSlices` from this variable.

    This is essentially a shortcut for `scatter_sub(self, sparse_delta.indices,
    sparse_delta.values)`.

    Args:
      sparse_delta: `IndexedSlices` to be subtracted from this variable.
      use_locking: If `True`, use locking during the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered subtraction has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
        if not isinstance(sparse_delta, ops.IndexedSlices):
            raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
        return state_ops.scatter_sub(self._variable, sparse_delta.indices, sparse_delta.values, use_locking=use_locking)
Example #19
    def _apply_sparse(self, grad, var):
        t = math_ops.cast(self._iterations, var.dtype.base_dtype) + 1.
        m_schedule = math_ops.cast(self._m_schedule, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
        schedule_decay_t = math_ops.cast(self._schedule_decay_t, var.dtype.base_dtype)

        # warming momentum schedule, as recommended in [2]
        momentum_cache_power = self._get_momentum_cache(schedule_decay_t, t)
        momentum_cache_t = beta1_t * (1. - 0.5 * momentum_cache_power)
        momentum_cache_t_1 = beta1_t * (1. - 0.5 * momentum_cache_power * self._momentum_cache_const)
        m_schedule_new = m_schedule * momentum_cache_t
        m_schedule_next = m_schedule_new * momentum_cache_t_1

        # the following equations are given in [1]
        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_t = state_ops.scatter_update(m, grad.indices,
                                       beta1_t * array_ops.gather(m, grad.indices) +
                                       (1. - beta1_t) * grad.values,
                                       use_locking=self._use_locking)
        g_prime_slice = grad.values / (1. - m_schedule_new)
        m_t_prime_slice = array_ops.gather(m_t, grad.indices) / (1. - m_schedule_next)
        m_t_bar_slice = (1. - momentum_cache_t) * g_prime_slice + momentum_cache_t_1 * m_t_prime_slice

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_t = state_ops.scatter_update(v, grad.indices,
                                       beta2_t * array_ops.gather(v, grad.indices) +
                                       (1. - beta2_t) * tf.square(grad.values),
                                       use_locking=self._use_locking)
        v_t_prime_slice = array_ops.gather(v_t, grad.indices) / (1. - tf.pow(beta2_t, t))

        var_update = state_ops.scatter_sub(var, grad.indices,
                                           lr_t * m_t_bar_slice / (math_ops.sqrt(v_t_prime_slice) + epsilon_t),
                                           use_locking=self._use_locking)

        return control_flow_ops.group(*[var_update, m_t, v_t])
Example #20
  def scatter_sub(self, sparse_delta, use_locking=False):
    """Subtracts `IndexedSlices` from this variable.

    This is essentially a shortcut for `scatter_sub(self, sparse_delta.indices,
    sparse_delta.values)`.

    Args:
      sparse_delta: `IndexedSlices` to be subtracted from this variable.
      use_locking: If `True`, use locking during the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered subtraction has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, ops.IndexedSlices):
      raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return state_ops.scatter_sub(self._variable,
                                 sparse_delta.indices,
                                 sparse_delta.values,
                                 use_locking=use_locking)
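A hedged usage sketch for this kind of scatter_sub wrapper, using the public TF 2.x variable API rather than the internal class above; the values are illustrative assumptions:

import tensorflow as tf

var = tf.Variable([10.0, 10.0, 10.0, 10.0])
delta = tf.IndexedSlices(values=tf.constant([1.0, 3.0]),
                         indices=tf.constant([0, 2]))
var.scatter_sub(delta)        # rows 0 and 2 shrink by 1 and 3
print(var.numpy())            # [ 9. 10.  7. 10.]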
Example #21
 def testScatterSubStateOps(self):
   with context.eager_mode():
     v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="sub")
     state_ops.scatter_sub(v, [1], [3])
     self.assertAllEqual([1.0, -1.0], v.numpy())
Example #22
    def _apply_gradient(self, grad, var, indices=None):
        """The main function to update a variable.

    Args:
      grad: A Tensor containing gradient to apply.
      var: A Tensor containing the variable to update.
      indices: An array of integers, for sparse update.

    Returns:
      Updated variable var = var - learning_rate * preconditioner * grad

    If the gradient is dense, var and grad have the same shape.
    If the update is sparse, then the first dimension of the gradient and var
    may differ, others are all the same. In this case the indices array
    provides the set of indices of the variable which are to be updated with
    each row of the gradient.
    """
        global_step = self._global_step + 1

        # Update accumulated weighted average of gradients
        gbar = self.get_slot(var, "gbar")
        gbar_decay_t = GetParam(self._gbar_decay, global_step)
        gbar_weight_t = GetParam(self._gbar_weight, global_step)
        if indices is not None:
            # Note - the sparse update is not easily implemented, since the
            # algorithm needs all indices of gbar to be updated
            # if mat_gbar_decay != 1 or mat_gbar_decay != 0.
            # One way to make mat_gbar_decay = 1 is by rescaling.
            # If we want the update:
            #         G_{t+1} = a_{t+1} G_t + b_{t+1} w_t
            # define:
            #         r_{t+1} = a_{t+1} * r_t
            #         h_t = G_t / r_t
            # Then:
            #         h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t
            # So we get the mat_gbar_decay = 1 as desired.
            # We can implement this in a future version as needed.
            # However we still need gbar_decay = 0, otherwise all indices
            # of the variable will need to be updated.
            if self._gbar_decay != 0.0:
                tf_logging.warning("Not applying momentum for variable: %s" %
                                   var.name)
            gbar_updated = grad
        else:
            gbar_updated = self._weighted_average(gbar, self._gbar_decay,
                                                  gbar_decay_t,
                                                  gbar_weight_t * grad)

        # Update the preconditioners and compute the preconditioned gradient
        shape = var.get_shape()
        mat_g_list = []
        for i in range(len(shape)):
            mat_g_list.append(self.get_slot(var, "Gbar_" + str(i)))
        mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step)
        mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step)

        preconditioned_grad = gbar_updated
        v_rank = len(mat_g_list)
        neg_alpha = -GetParam(self._alpha, global_step) / v_rank
        svd_interval = GetParam(self._svd_interval, global_step)
        precond_update_interval = GetParam(self._precond_update_interval,
                                           global_step)
        for i, mat_g in enumerate(mat_g_list):
            # axes is the list of indices to reduce - everything but the current i.
            axes = list(range(i)) + list(range(i + 1, v_rank))
            if shape[i] < self._max_matrix_size:
                # If the tensor size is sufficiently small perform full Shampoo update
                # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this
                # is not strictly correct. However we will use it for now, and
                # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg)

                # pylint: disable=g-long-lambda,cell-var-from-loop
                mat_g_updated = control_flow_ops.cond(
                    math_ops.mod(global_step, precond_update_interval) < 1,
                    lambda: self._update_mat_g(
                        mat_g, grad, axes, mat_gbar_decay_t, mat_gbar_weight_t
                        * precond_update_interval, i), lambda: mat_g)

                if self._svd_interval == 1:
                    mat_h = self._compute_power(var, mat_g_updated, shape[i],
                                                neg_alpha)
                else:
                    mat_h = control_flow_ops.cond(
                        math_ops.mod(global_step, svd_interval) < 1,
                        lambda: self._compute_power(var, mat_g_updated, shape[
                            i], neg_alpha, "H_" + str(i)),
                        lambda: self.get_slot(var, "H_" + str(i)))

                # mat_h is a square matrix of size d_i x d_i
                # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor
                # After contraction with a d_i x d_i tensor
                # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor
                # (the first dimension is contracted out, and the second dimension of
                # mat_h is appended).  After going through all the indices, it becomes
                # a d_0 x ... x d_n tensor again.
                preconditioned_grad = math_ops.tensordot(preconditioned_grad,
                                                         mat_h,
                                                         axes=([0], [0]),
                                                         name="precond_" +
                                                         str(i))
            else:
                # Tensor size is too large -- perform diagonal Shampoo update
                grad_outer = math_ops.reduce_sum(grad * grad, axis=axes)
                if i == 0 and indices is not None:
                    assert self._mat_gbar_decay == 1.0
                    mat_g_updated = state_ops.scatter_add(
                        mat_g, indices, mat_gbar_weight_t * grad_outer)
                    mat_h = math_ops.pow(
                        array_ops.gather(mat_g_updated, indices) +
                        self._epsilon, neg_alpha)
                else:
                    mat_g_updated = self._weighted_average(
                        mat_g, self._mat_gbar_decay, mat_gbar_decay_t,
                        mat_gbar_weight_t * grad_outer)
                    mat_h = math_ops.pow(mat_g_updated + self._epsilon,
                                         neg_alpha)

                # Need to do the transpose to ensure that the tensor becomes
                # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above.
                preconditioned_grad = array_ops.transpose(
                    preconditioned_grad,
                    perm=list(range(1, v_rank)) + [0]) * mat_h

        # Update the variable based on the Shampoo update
        learning_rate_t = GetParam(self._learning_rate, global_step)
        if indices is not None:
            var_updated = state_ops.scatter_sub(
                var, indices, learning_rate_t * preconditioned_grad)
        else:
            var_updated = state_ops.assign_sub(
                var, learning_rate_t * preconditioned_grad)
        return var_updated
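The comment block above relies on math_ops.tensordot cycling the tensor's axes: each contraction removes the current leading dimension and appends mat_h's second dimension, so after all ranks the result is a d_0 x ... x d_n tensor again. A minimal shape-only sketch with illustrative sizes (identity matrices standing in for the computed mat_h):

import tensorflow as tf

g = tf.zeros([2, 3, 4])          # stands in for a dense gradient tensor
for dim in g.shape.as_list():    # one d_i x d_i preconditioner per dimension
    mat_h = tf.eye(dim)          # placeholder for the computed mat_h
    # contract the current leading axis; mat_h's second axis is appended last
    g = tf.tensordot(g, mat_h, axes=([0], [0]))
    print(g.shape)               # (3, 4, 2) -> (4, 2, 3) -> (2, 3, 4)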
Example #23
 def _scatter_sub(self, x, i, v):
   return state_ops.scatter_sub(
       x, i, v, use_locking=self._use_locking)
Example #24
    def _finish(self, update_ops, name_scope):
        """"""

        caches = [update_op[0] for update_op in update_ops]
        update_ops = [update_op[1:] for update_op in update_ops]
        if self._noise is not None:
            for cache in caches:
                s_t, x_tm1 = cache[:2]
                s_t += random_ops.random_normal(
                    x_tm1.initialized_value().get_shape(), stddev=self._noise)
                cache[0] = s_t

        if self._clip > 0:
            S_t = [cache[0] for cache in caches]
            S_t, _ = clip_ops.clip_by_global_norm(S_t, self._clip)
            for cache, s_t in zip(caches, S_t):
                cache[0] = s_t

        new_update_ops = []
        for cache, update_op in zip(caches, update_ops):
            if len(cache) == 3:
                s_t, x_tm1 = cache[:2]
                with ops.name_scope('update_' + x_tm1.op.name), ops.device(
                        x_tm1.device):
                    x_t = state_ops.assign_sub(x_tm1,
                                               s_t,
                                               use_locking=self._use_locking)
                    cache.append(x_t)
            else:
                s_t_, x_tm1, idxs = cache[:3]
                with ops.name_scope('update_' + x_tm1.op.name), ops.device(
                        x_tm1.device):
                    x_t = state_ops.scatter_sub(x_tm1,
                                                idxs,
                                                s_t_,
                                                use_locking=self._use_locking)
                    cache.append(x_t)
            new_update_ops.append(control_flow_ops.group(*([x_t] + update_op)))

        with ops.control_dependencies(new_update_ops):
            more_update_ops = []
            if self._save_step:
                for cache in caches:
                    if len(cache) == 4:
                        s_t, x_tm1 = cache[:2]
                        s_tm1 = self.get_slot(x_tm1, 's')
                        with ops.name_scope('update_' +
                                            x_tm1.op.name), ops.device(
                                                x_tm1.device):
                            new_step_and_grads = []
                            s_t = state_ops.assign(
                                s_tm1, -s_t, use_locking=self._use_locking)
                    else:
                        s_t_, x_tm1, idxs = cache[:3]
                        s_tm1 = self.get_slot(x_tm1, 's')
                        with ops.name_scope('update_' +
                                            x_tm1.op.name), ops.device(
                                                x_tm1.device):
                            s_t = state_ops.scatter_update(
                                s_tm1,
                                idxs,
                                -s_t_,
                                use_locking=self._use_locking)
                    more_update_ops.append(s_t)
            if self._save_grad:
                for cache in caches:
                    if len(cache) == 4:
                        x_tm1, g_t = cache[1:3]
                        g_tm1 = self.get_slot(x_tm1, 'g')
                        with ops.name_scope('update_' +
                                            x_tm1.op.name), ops.device(
                                                x_tm1.device):
                            new_step_and_grads = []
                            g_t = state_ops.assign(
                                g_tm1, g_t, use_locking=self._use_locking)
                    else:
                        x_tm1, idxs, g_t_ = cache[1:4]
                        g_tm1 = self.get_slot(x_tm1, 'g')
                        with ops.name_scope('update_' +
                                            x_tm1.op.name), ops.device(
                                                x_tm1.device):
                            g_t = state_ops.scatter_update(
                                g_tm1,
                                idxs,
                                g_t_,
                                use_locking=self._use_locking)
                    more_update_ops.append(g_t)

            if self._chi > 0:
                for cache in caches:
                    if len(cache) == 4:
                        _, x_tm1, _, x_t = cache
                        with ops.name_scope('update_' +
                                            x_tm1.op.name), ops.device(
                                                x_tm1.device):
                            x_and_t = self._dense_moving_average(
                                x_tm1, x_t, 'x', self._chi)
                            more_update_ops.append(
                                control_flow_ops.group(*x_and_t))
                    else:
                        _, x_tm1, idxs, _, x_t = cache
                        with ops.name_scope('update_' +
                                            x_tm1.op.name), ops.device(
                                                x_tm1.device):
                            x_t_ = array_ops.gather(x_t, idxs)
                            x_and_t = self._sparse_moving_average(
                                x_tm1, idxs, x_t_, 'x', self._chi)
                            more_update_ops.append(
                                control_flow_ops.group(*x_and_t))

        return control_flow_ops.group(*(new_update_ops + more_update_ops),
                                      name=name_scope)
Example #25
    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
        step, beta1_power, beta2_power = self._get_beta_accumulators()
        beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
        beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)

        if self._initial_total_steps > 0:
            total_steps = math_ops.cast(self._total_steps_t,
                                        var.dtype.base_dtype)
            warmup_proportion = math_ops.cast(self._warmup_proportion_t,
                                              var.dtype.base_dtype)
            min_lr = math_ops.cast(self._min_lr_t, var.dtype.base_dtype)
            warmup_steps = total_steps * warmup_proportion
            decay_steps = math_ops.maximum(total_steps - warmup_steps, 1)
            decay_rate = (min_lr - lr_t) / decay_steps
            lr_t = tf.where(
                step <= warmup_steps,
                lr_t * (step / warmup_steps),
                lr_t + decay_rate *
                math_ops.minimum(step - warmup_steps, decay_steps),
            )

        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)

        sma_inf = 2.0 / (1.0 - beta2_t) - 1.0
        sma_t = sma_inf - 2.0 * step * beta2_power / (1.0 - beta2_power)

        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * (1 - beta1_t)
        m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
        with ops.control_dependencies([m_t]):
            m_t = scatter_add(m, indices, m_scaled_g_values)
        m_corr_t = m_t / (1.0 - beta1_power)

        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
        v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
        with ops.control_dependencies([v_t]):
            v_t = scatter_add(v, indices, v_scaled_g_values)
        if self._amsgrad:
            vhat = self.get_slot(var, 'vhat')
            vhat_t = state_ops.assign(vhat,
                                      math_ops.maximum(vhat, v_t),
                                      use_locking=self._use_locking)
            v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta2_power))
        else:
            v_corr_t = math_ops.sqrt(v_t / (1.0 - beta2_power))

        r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * (sma_t - 2.0) /
                            (sma_inf - 2.0) * sma_inf / sma_t)

        var_t = tf.where(sma_t >= 5.0, r_t * m_corr_t / (v_corr_t + epsilon_t),
                         m_corr_t)

        if self._initial_weight_decay > 0.0:
            param_name = self._get_variable_name(var.name)
            if self._do_use_weight_decay(param_name):
                var_t += math_ops.cast(self._weight_decay_t,
                                       var.dtype.base_dtype) * var

        var_t = lr_t * var_t
        var_update = state_ops.scatter_sub(var,
                                           indices,
                                           array_ops.gather(var_t, indices),
                                           use_locking=self._use_locking)

        updates = [var_update, m_t, v_t]
        if self._amsgrad:
            updates.append(vhat_t)
        return control_flow_ops.group(*updates)
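This example gates the adaptive step on an SMA-based rectification term: the second-moment branch is only taken once sma_t reaches 5.0. A plain-Python sketch of how sma_t and r_t evolve, with an illustrative beta2 that is not taken from the source:

import math

def sma_and_rect(beta2, step):
    """Approximated SMA length and rectification factor at a given step."""
    sma_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_power = beta2 ** step
    sma_t = sma_inf - 2.0 * step * beta2_power / (1.0 - beta2_power)
    r_t = None
    if sma_t >= 5.0:   # variance estimate considered usable only here
        r_t = math.sqrt((sma_t - 4.0) / (sma_inf - 4.0)
                        * (sma_t - 2.0) / (sma_inf - 2.0)
                        * sma_inf / sma_t)
    return sma_t, r_t

for step in (1, 4, 5, 50, 500):
    print(step, sma_and_rect(0.999, step))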
Example #26
 def _finish(self, update_ops, name_scope):
   """"""
   
   caches = [update_op[0] for update_op in update_ops]
   update_ops = [update_op[1:] for update_op in update_ops]
   if self._noise is not None:
     for cache in caches:
       s_t, x_tm1 = cache[:2]
       s_t += random_ops.random_normal(x_tm1.initialized_value().get_shape(), stddev=self._noise)
       cache[0] = s_t
   
   if self._clip is not None:
     S_t = [cache[0] for cache in caches]
     S_t, _ = clip_ops.clip_by_global_norm(S_t, self._clip)
     for cache, s_t in zip(caches, S_t):
       cache[0] = s_t
   
   new_update_ops = []
   for cache, update_op in zip(caches, update_ops):
     if len(cache) == 3:
       s_t, x_tm1 = cache[:2]
       with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
         x_t = state_ops.assign_sub(x_tm1, s_t, use_locking=self._use_locking)
         cache.append(x_t)
     else:
       s_t_, x_tm1, idxs = cache[:3]
       with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
         x_t = state_ops.scatter_sub(x_tm1, idxs, s_t_, use_locking=self._use_locking)
         cache.append(x_t)
     new_update_ops.append(control_flow_ops.group(*([x_t] + update_op)))
   
   with ops.control_dependencies(new_update_ops):
     more_update_ops = []
     if self._save_step:
       for cache in caches:
         if len(cache) == 4:
           s_t, x_tm1 = cache[:2]
           s_tm1 = self.get_slot(x_tm1, 's')
           with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
             new_step_and_grads = []
             s_t = state_ops.assign(s_tm1, -s_t, use_locking=self._use_locking)
         else:
           s_t_, x_tm1, idxs = cache[:3]
           s_tm1 = self.get_slot(x_tm1, 's')
           with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
             s_t = state_ops.scatter_update(s_tm1, idxs, -s_t_, use_locking=self._use_locking)
         more_update_ops.append(s_t)
     if self._save_grad:
       for cache in caches:
         if len(cache) == 4:
           x_tm1, g_t = cache[1:3]
           g_tm1 = self.get_slot(x_tm1, 'g')
           with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
             new_step_and_grads = []
             g_t = state_ops.assign(g_tm1, g_t, use_locking=self._use_locking)
         else:
           x_tm1, idxs, g_t_ = cache[1:4]
           g_tm1 = self.get_slot(x_tm1, 'g')
           with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
             g_t = state_ops.scatter_update(g_tm1, idxs, g_t_, use_locking=self._use_locking)
         more_update_ops.append(g_t)
     
     if self._chi > 0:
       for cache in caches:
         if len(cache) == 4:
           _, x_tm1, _, x_t = cache
           with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
             x_and_t = self._dense_moving_average(x_tm1, x_t, 'x', self._chi)
             more_update_ops.append(control_flow_ops.group(*x_and_t))
         else:
           _, x_tm1, idxs, _, x_t = cache
           with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
             x_t_ = array_ops.gather(x_t, idxs)
             x_and_t = self._sparse_moving_average(x_tm1, idxs, x_t_, 'x', self._chi)
             more_update_ops.append(control_flow_ops.group(*x_and_t))
   
   return control_flow_ops.group(*(new_update_ops + more_update_ops), name=name_scope)
Example #27
    def _apply_sparse(self, grad, var):
        beta1_power, beta2_power = self._get_beta_accumulators()
        beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
        beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
        global_step = self._get_step_accumulators()
        global_step = math_ops.cast(global_step, var.dtype.base_dtype)
        pre_step = self.get_slot(var, "pre_step")

        indices = grad.indices
        pre_step_slice = array_ops.gather(pre_step, indices)
        skipped_steps = global_step - pre_step_slice

        m = self.get_slot(var, "m")
        m_slice = array_ops.gather(m, indices)
        v = self.get_slot(var, "v")
        v_slice = array_ops.gather(v, indices)
        # for performance reasons, use math_ops.exp(a * math_ops.log(b))
        # in place of math_ops.pow(b, a)
        # \\(lr := lr_t * sqrt(1 - beta2 ** pre_step) /
        #          (1 - beta1 ** pre_step) * (1 - beta1 ** skipped_step) /
        #          (1 - beta1)\\)
        lr = ((lr_t * math_ops.sqrt(1 - math_ops.exp(pre_step_slice *
                                                     math_ops.log(beta2_t))) /
               (1 - math_ops.exp(pre_step_slice * math_ops.log(beta1_t)))) *
              (1 - math_ops.exp(math_ops.log(beta1_t) * skipped_steps)) /
              (1 - beta1_t))
        # \\(variable -= learning_rate * m /(epsilon + sqrt(v))\\)
        var_slice = lr * m_slice / (math_ops.sqrt(v_slice) + epsilon_t)
        var_update_op = state_ops.scatter_sub(var,
                                              indices,
                                              var_slice,
                                              use_locking=self._use_locking)

        with ops.control_dependencies([var_update_op]):
            # \\(m := m * beta1 ** skipped_step + (1 - beta1) * g_t\\)
            # for performance reasons, use math_ops.exp(a * math_ops.log(b))
            # in place of math_ops.pow(b, a)
            m_t_slice = (
                math_ops.exp(math_ops.log(beta1_t) * skipped_steps) * m_slice +
                (1 - beta1_t) * grad.values)
            m_update_op = state_ops.scatter_update(
                m, indices, m_t_slice, use_locking=self._use_locking)

            # \\(v := v * beta2 ** skipped_step + (1 - beta2) * (g_t * g_t)\\)
            # for performance reasons, use math_ops.exp(a * math_ops.log(b))
            # in place of math_ops.pow(b, a)
            v_t_slice = (
                math_ops.exp(math_ops.log(beta2_t) * skipped_steps) * v_slice +
                (1 - beta2_t) * math_ops.square(grad.values))
            v_update_op = state_ops.scatter_update(
                v, indices, v_t_slice, use_locking=self._use_locking)

        with ops.control_dependencies([m_update_op, v_update_op]):
            pre_step_update_op = state_ops.scatter_update(
                pre_step, indices, global_step, use_locking=self._use_locking)

        return control_flow_ops.group(var_update_op, m_update_op, v_update_op,
                                      pre_step_update_op)
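The comments in this last example replace math_ops.pow(b, a) with math_ops.exp(a * math_ops.log(b)), which is equivalent only for a strictly positive base such as beta1_t or beta2_t. A minimal equivalence sketch with illustrative values:

import tensorflow as tf

base = tf.constant([0.9, 0.999])        # beta1_t / beta2_t lie in (0, 1)
exponent = tf.constant(7.0)             # e.g. the number of skipped steps

via_pow = tf.pow(base, exponent)
via_exp_log = tf.exp(exponent * tf.math.log(base))  # same value for base > 0

# Both print ~[0.4782969, 0.99302...]; the two forms agree to float precision.
print(via_pow.numpy(), via_exp_log.numpy())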