def _resource_apply_sparse(self, grad, var, indices):
        """Run one sparse AdamW step for `var`.

        The 'm' and 'v' slots are first scaled by beta_1 / beta_2 as a whole,
        then the scaled gradient contributions are scatter-added into the rows
        named by `indices`. Both bias corrections for the current step are
        folded into the learning rate. When `self.amsgrad` is set, the running
        maximum of the second moment ('vhat' slot) is used in the denominator
        instead of `v`.

        Args:
            grad: gradient values for the rows of `var` at `indices`
                (presumably IndexedSlices values -- confirm with the caller).
            var: variable being optimized.
            indices: row indices that `grad` refers to.

        Returns:
            A grouped op of the variable, slot and schedule updates.
        """
        dtype = var.dtype.base_dtype
        step_lr = self._decayed_lr(dtype)
        step = math_ops.cast(self.iterations + 1, dtype)
        m_slot = self.get_slot(var, 'm')
        v_slot = self.get_slot(var, 'v')
        b1 = array_ops.identity(self._get_hyper('beta_1', dtype))
        b2 = array_ops.identity(self._get_hyper('beta_2', dtype))
        b1_pow = math_ops.pow(b1, step)
        b2_pow = math_ops.pow(b2, step)
        eps = ops.convert_to_tensor(self.epsilon, dtype)

        # Fold both bias corrections into the step size.
        step_lr = step_lr * math_ops.sqrt(1 - b2_pow) / (1 - b1_pow)
        if self.lr_multipliers is not None:
            step_lr = _apply_lr_multiplier(self, step_lr, var)

        def decay_then_scatter(slot, decay, increment):
            # slot <- slot * decay everywhere, then slot[indices] += increment;
            # the scatter must observe the decayed value.
            decayed = state_ops.assign(slot, slot * decay,
                                       use_locking=self._use_locking)
            with ops.control_dependencies([decayed]):
                return self._resource_scatter_add(slot, indices, increment)

        m_new = decay_then_scatter(m_slot, b1, grad * (1 - b1))
        v_new = decay_then_scatter(v_slot, b2, (grad * grad) * (1 - b2))

        if self.amsgrad:
            vhat_slot = self.get_slot(var, 'vhat')
            vhat_new = state_ops.assign(vhat_slot,
                                        math_ops.maximum(vhat_slot, v_new),
                                        use_locking=self._use_locking)
            second_moment = vhat_new
        else:
            second_moment = v_new
        step_dir = m_new / (math_ops.sqrt(second_moment) + eps)
        var_t = math_ops.sub(var, self.eta_t * step_lr * step_dir)

        # Decoupled weight decay, only for variables keyed by name in
        # `self.weight_decays`.
        if var.name in self.weight_decays.keys():
            var_t = _apply_weight_decays(self, var, var_t)
        var_update = state_ops.assign(var, var_t,
                                      use_locking=self._use_locking)

        # Cosine-annealing bookkeeping.
        iteration_done, t_cur_update, eta_t_update = (
            _update_t_cur_eta_t_v2(self, step_lr, var))
        if iteration_done and not self._init_notified:
            self._init_notified = True

        updates = [var_update, m_new, v_new]
        if iteration_done:
            updates.append(t_cur_update)
            if self.use_cosine_annealing:
                updates.append(eta_t_update)
        if self.amsgrad:
            updates.append(vhat_new)
        return control_flow_ops.group(*updates)
    def _resource_apply_sparse(self, grad, var, indices):
        """Apply one SGDW step for a sparse (row-indexed) gradient.

        `grad` holds gradient values only for the rows of `var` selected by
        `indices`, so anything combined with it must first be gathered to the
        same rows. Only those rows of `var` (and of the momentum slot) are
        touched, following the per-row Keras momentum rule:

            v               = momentum * m[indices] - eta_t * lr_t * grad
            m[indices]      = v
            var[indices]   += v                                   # plain
            var[indices]   += momentum * v - eta_t * lr_t * grad  # nesterov

        Without momentum: var[indices] -= eta_t * lr_t * grad.

        Args:
            grad: gradient values for the rows in `indices`.
            var: variable being optimized.
            indices: row indices of `var` that `grad` applies to.

        Returns:
            A grouped op of the variable, slot and schedule updates.
        """
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)

        # Learning rate multipliers
        if self.lr_multipliers is not None:
            lr_t = _apply_lr_multiplier(self, lr_t, var)

        # Step for the selected rows only (same shape as `grad`).
        scaled_grad = self.eta_t * lr_t * grad

        if self._momentum:
            momentum = array_ops.identity(
                self._get_hyper('momentum', var_dtype))
            m = self.get_slot(var, 'momentum')
            # `grad` covers only `indices`; combining it with the full slot
            # (as `momentum * m - lr * grad`) mixes shapes, so read the
            # matching momentum rows instead.
            m_slice = array_ops.gather(m, indices)
            v = momentum * m_slice - scaled_grad
            m = self._resource_scatter_update(m, indices, v)

            if self.nesterov:
                var_t = self._resource_scatter_add(
                    var, indices, momentum * v - scaled_grad)
            else:
                var_t = self._resource_scatter_add(var, indices, v)
        else:
            # `var + v` would add row-slice values to the whole variable;
            # scatter the step into the selected rows instead.
            var_t = self._resource_scatter_add(var, indices, -scaled_grad)

        # Weight decays
        if var.name in self.weight_decays.keys():
            var_t = _apply_weight_decays(self, var, var_t)

        var_update = state_ops.assign(var,
                                      var_t,
                                      use_locking=self._use_locking)

        # Cosine annealing
        (iteration_done, t_cur_update,
         eta_t_update) = _update_t_cur_eta_t_v2(self, lr_t, var)
        if iteration_done and not self._init_notified:
            self._init_notified = True

        updates = [var_update]
        if self._momentum:
            updates += [m]
        if iteration_done:
            updates += [t_cur_update]
        if self.use_cosine_annealing and iteration_done:
            updates += [eta_t_update]
        return control_flow_ops.group(*updates)
    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Apply one NadamW update for a sparse (row-indexed) gradient.

        Args:
            grad: gradient values for the rows of `var` selected by `indices`
                (presumably the `.values` of an IndexedSlices -- confirm with
                the OptimizerV2 base class).
            var: the variable being trained; its 'm' and 'v' slots are read
                and updated.
            indices: row indices of `var` that `grad` applies to.
            apply_state: accepted for signature compatibility; not read here.

        Returns:
            A grouped op over the variable update, the first/second-moment
            results and, when an iteration completes, the cosine-annealing
            updates.
        """
        var_dtype = var.dtype.base_dtype
        lr_t = array_ops.identity(self._get_hyper('learning_rate', var_dtype))
        beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
        beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
        epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
        m = self.get_slot(var, 'm')
        v = self.get_slot(var, 'v')
        # Current step t and next step t + 1 (1-based).
        local_step = math_ops.cast(self.iterations + 1, var_dtype)
        next_step = math_ops.cast(self.iterations + 2, var_dtype)
        decay_base = math_ops.cast(0.96, var_dtype)

        # Learning rate multipliers
        if self.lr_multipliers is not None:
            lr_t = _apply_lr_multiplier(self, lr_t, var)

        # Warming momentum schedule for steps t and t + 1:
        # mu = beta_1 * (1 - 0.5 * 0.96 ** (self._initial_decay * step)).
        momentum_cache_t = beta_1_t * (
            1. - 0.5 *
            (math_ops.pow(decay_base, self._initial_decay * local_step)))
        momentum_cache_t_1 = beta_1_t * (
            1. - 0.5 *
            (math_ops.pow(decay_base, self._initial_decay * next_step)))
        # Running product of momentum coefficients through step t.
        # NOTE(review): `_m_cache_read` is presumably a per-step snapshot of
        # `_m_cache` taken elsewhere -- confirm.
        m_schedule_new = math_ops.cast(self._m_cache_read,
                                       var_dtype) * momentum_cache_t
        # Write the product back to the shared cache only when the variable's
        # dtype matches the cache's; otherwise the cast value is used locally.
        if var_dtype is self._m_cache.dtype:
            m_schedule_new = array_ops.identity(
                state_ops.assign(self._m_cache,
                                 m_schedule_new,
                                 use_locking=self._use_locking))
        # Product through step t + 1, used to bias-correct the first moment.
        m_schedule_next = m_schedule_new * momentum_cache_t_1

        # m <- beta_1 * m over the whole slot, then add (1 - beta_1) * grad at
        # `indices`; read back only the touched rows.
        m_scaled_g_values = grad * (1. - beta_1_t)
        m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
        with ops.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
            m_t_slice = array_ops.gather(m_t, indices)

        # Nadam look-ahead: bias-corrected first moment and gradient,
        # weighted by mu_{t+1} and (1 - mu_t) respectively.
        m_t_prime = m_t_slice / (1. - m_schedule_next)
        g_prime = grad / (1. - m_schedule_new)
        m_t_bar = (1. - momentum_cache_t) * g_prime + (momentum_cache_t_1 *
                                                       m_t_prime)

        # v <- beta_2 * v over the whole slot, then add (1 - beta_2) * grad**2
        # at `indices`; read back only the touched rows.
        v_scaled_g_values = (grad * grad) * (1. - beta_2_t)
        v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)

        with ops.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
            v_t_slice = array_ops.gather(v_t, indices)

        # Bias-correct the second moment; epsilon is added outside the sqrt.
        v_t_prime_denominator = 1. - math_ops.pow(beta_2_t, local_step)
        v_t_prime = v_t_slice / v_t_prime_denominator
        v_prime_sqrt_plus_eps = math_ops.sqrt(v_t_prime) + epsilon_t

        # Scatter the step into the selected rows; this writes `var` in place
        # and `var_t` is the resulting full variable value.
        var_t = self._resource_scatter_add(
            var, indices, -self.eta_t * lr_t * m_t_bar / v_prime_sqrt_plus_eps)

        # Weight decays
        if var.name in self.weight_decays.keys():
            var_t = _apply_weight_decays(self, var, var_t)

        var_update = state_ops.assign(var,
                                      var_t,
                                      use_locking=self._use_locking)

        # Cosine annealing
        (iteration_done, t_cur_update,
         eta_t_update) = _update_t_cur_eta_t_v2(self, lr_t, var)
        if iteration_done and not self._init_notified:
            self._init_notified = True

        # NOTE(review): groups `m_t_bar` rather than `m_t`; the 'm' slot update
        # still runs because `m_t_bar` is computed from `m_t`.
        updates = [var_update, m_t_bar, v_t]
        if iteration_done:
            updates += [t_cur_update]
        if self.use_cosine_annealing and iteration_done:
            updates += [eta_t_update]
        return control_flow_ops.group(*updates)
    def _resource_apply_dense(self, grad, var):
        """Apply one NadamW update for a dense gradient.

        Nadam (Dozat, 2016) with a warming momentum schedule
        mu_t = beta_1 * (1 - 0.5 * 0.96 ** (self._initial_decay * t)):

            g'    = g / (1 - prod_{i<=t} mu_i)
            m     = beta_1 * m + (1 - beta_1) * g
            m'    = m / (1 - prod_{i<=t+1} mu_i)
            v     = beta_2 * v + (1 - beta_2) * g ** 2
            v'    = v / (1 - beta_2 ** t)
            m_bar = (1 - mu_t) * g' + mu_{t+1} * m'
            var  -= eta_t * lr_t * m_bar / (sqrt(v') + epsilon)

        The update is followed by decoupled weight decay (for variables listed
        in `self.weight_decays`) and cosine-annealing bookkeeping. The
        `m_bar` and denominator terms are the same as in the sparse Nadam
        update in this file.

        Args:
            grad: dense gradient for `var`.
            var: variable being optimized; its 'm' and 'v' slots are updated.

        Returns:
            A grouped op of the variable, slot and schedule updates.
        """
        var_dtype = var.dtype.base_dtype
        lr_t = array_ops.identity(self._get_hyper('learning_rate', var_dtype))
        beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
        beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
        epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
        m = self.get_slot(var, 'm')
        v = self.get_slot(var, 'v')
        local_step = math_ops.cast(self.iterations + 1, var_dtype)
        next_step = math_ops.cast(self.iterations + 2, var_dtype)
        decay_base = math_ops.cast(0.96, var_dtype)

        # Learning rate multipliers
        if self.lr_multipliers is not None:
            lr_t = _apply_lr_multiplier(self, lr_t, var)

        # Due to the recommendations in [2], i.e. warming momentum schedule
        momentum_cache_t = beta_1_t * (
            1. - 0.5 *
            (math_ops.pow(decay_base, self._initial_decay * local_step)))
        momentum_cache_t_1 = beta_1_t * (
            1. - 0.5 *
            (math_ops.pow(decay_base, self._initial_decay * next_step)))
        m_schedule_new = math_ops.cast(self._m_cache_read,
                                       var_dtype) * momentum_cache_t
        # Persist the running product only when dtypes match the cache.
        if var_dtype is self._m_cache.dtype:
            m_schedule_new = array_ops.identity(
                state_ops.assign(self._m_cache,
                                 m_schedule_new,
                                 use_locking=self._use_locking))
        m_schedule_next = m_schedule_new * momentum_cache_t_1

        # the following equations given in [1]
        g_prime = grad / (1. - m_schedule_new)
        m_t = beta_1_t * m + (1. - beta_1_t) * grad
        m_t_prime = m_t / (1. - m_schedule_next)
        v_t = beta_2_t * v + (1. - beta_2_t) * math_ops.square(grad)
        v_t_prime = v_t / (1. - math_ops.pow(beta_2_t, local_step))
        # Nadam's look-ahead weights the corrected moment by the *next*
        # step's momentum coefficient mu_{t+1}, not mu_t.
        m_t_bar = (1. - momentum_cache_t) * g_prime + (momentum_cache_t_1 *
                                                       m_t_prime)

        m_t = state_ops.assign(m, m_t, use_locking=self._use_locking)
        v_t = state_ops.assign(v, v_t, use_locking=self._use_locking)

        # Epsilon belongs outside the square root.
        var_t = math_ops.sub(
            var, self.eta_t * lr_t * m_t_bar /
            (math_ops.sqrt(v_t_prime) + epsilon_t))

        # Weight decays
        if var.name in self.weight_decays.keys():
            var_t = _apply_weight_decays(self, var, var_t)

        var_update = state_ops.assign(var,
                                      var_t,
                                      use_locking=self._use_locking)

        # Cosine annealing
        (iteration_done, t_cur_update,
         eta_t_update) = _update_t_cur_eta_t_v2(self, lr_t, var)
        if iteration_done and not self._init_notified:
            self._init_notified = True

        updates = [var_update, m_t, v_t]
        if iteration_done:
            updates += [t_cur_update]
        if self.use_cosine_annealing and iteration_done:
            updates += [eta_t_update]
        return control_flow_ops.group(*updates)