Esempio n. 1
0
    def _resource_apply_dense(self, grad, var):
        """Build the update ops for one dense gradient step on ``var``.

        Refreshes the ``m`` and ``v`` moment slots (plus ``vhat`` when
        ``self.amsgrad`` is truthy) from ``grad`` and subtracts the resulting
        step from ``var``. When ``self._initial_total_steps`` is positive, the
        learning rate is scaled linearly during warmup and afterwards decays
        linearly towards ``min_lr``. The rectification factor is applied only
        when ``sma_t >= 5``; otherwise the bias-corrected first moment is used
        on its own. The formulas look like the rectified-Adam (RAdam) update.

        Args:
            grad: Gradient tensor for ``var``.
            var: Variable being updated.

        Returns:
            A grouped op covering the variable update and all slot updates.
        """
        dtype = var.dtype.base_dtype
        step_size = self._decayed_lr(dtype)
        first_moment = self.get_slot(var, 'm')
        second_moment = self.get_slot(var, 'v')
        b1 = self._get_hyper('beta_1', dtype)
        b2 = self._get_hyper('beta_2', dtype)
        eps = ops.convert_to_tensor(self.epsilon, dtype)
        step = math_ops.cast(self.iterations + 1, dtype)

        # Bias-correction denominators, 1 - beta ** step.
        b1_correction = 1.0 - math_ops.pow(b1, step)
        b2_power = math_ops.pow(b2, step)
        b2_correction = 1.0 - b2_power

        if self._initial_total_steps > 0:
            horizon = self._get_hyper('total_steps', dtype)
            warmup_end = horizon * self._get_hyper('warmup_proportion', dtype)
            floor_lr = self._get_hyper('min_lr', dtype)
            # Linear ramp while step <= warmup_end, linear decay afterwards.
            warmup_lr = step_size * (step / warmup_end)
            decay_left = 1.0 - tf.minimum(step, horizon) / horizon
            decayed_lr = floor_lr + (step_size - floor_lr) * decay_left
            step_size = tf.where(step <= warmup_end, warmup_lr, decayed_lr)

        # Length of the approximated simple moving average: limit and current.
        sma_max = 2.0 / (1.0 - b2) - 1.0
        sma_now = sma_max - 2.0 * step * b2_power / b2_correction

        # Rectification factor, built up left to right in three stages.
        rect = (sma_now - 4.0) / (sma_max - 4.0)
        rect = rect * (sma_now - 2.0) / (sma_max - 2.0)
        rect = rect * sma_max / sma_now
        rectifier = math_ops.sqrt(rect)

        m_t = state_ops.assign(
            first_moment,
            b1 * first_moment + (1.0 - b1) * grad,
            use_locking=self._use_locking,
        )
        m_hat = m_t / b1_correction

        v_t = state_ops.assign(
            second_moment,
            b2 * second_moment + (1.0 - b2) * math_ops.square(grad),
            use_locking=self._use_locking,
        )
        slot_updates = [m_t, v_t]

        if self.amsgrad:
            running_max = self.get_slot(var, 'vhat')
            vhat_t = state_ops.assign(
                running_max,
                math_ops.maximum(running_max, v_t),
                use_locking=self._use_locking,
            )
            slot_updates.append(vhat_t)
            denom_base = vhat_t
        else:
            denom_base = v_t
        v_hat = math_ops.sqrt(denom_base / b2_correction + eps)

        delta = tf.where(sma_now >= 5.0, rectifier * m_hat / v_hat, m_hat)

        if self._initial_weight_decay > 0.0:
            delta = delta + self._get_hyper('weight_decay', dtype) * var

        var_update = state_ops.assign_sub(
            var, step_size * delta, use_locking=self._use_locking)

        return control_flow_ops.group(var_update, *slot_updates)
Esempio n. 2
0
    def _resource_apply_dense(self, grad, var):
        """Build the update ops for one dense Adam step with LR warmup/decay.

        The learning rate is scaled linearly while the step is at most
        ``warmup_steps``, then decays linearly towards ``min_lr`` until
        ``decay_steps``; the Adam bias correction is folded into it. The ``m``
        and ``v`` slots are updated from ``grad``. When ``self.amsgrad`` is
        truthy, the ``vhat`` slot is updated with the element-wise maximum of
        its previous value and the new ``v``, and that maximum is used in the
        denominator. Weight decay is added to the step when
        ``self._initial_weight_decay`` is positive.

        Args:
            grad: Gradient tensor for ``var``.
            var: Variable being updated.

        Returns:
            A grouped op covering the variable update and all slot updates.
        """
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        m = self.get_slot(var, 'm')
        v = self.get_slot(var, 'v')
        beta_1_t = self._get_hyper('beta_1', var_dtype)
        beta_2_t = self._get_hyper('beta_2', var_dtype)
        epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
        local_step = math_ops.cast(self.iterations + 1, var_dtype)
        beta_1_power = math_ops.pow(beta_1_t, local_step)
        beta_2_power = math_ops.pow(beta_2_t, local_step)

        decay_steps = self._get_hyper('decay_steps', var_dtype)
        warmup_steps = self._get_hyper('warmup_steps', var_dtype)
        min_lr = self._get_hyper('min_lr', var_dtype)
        # Linear ramp during warmup, linear decay towards min_lr afterwards.
        lr_t = tf.where(
            local_step <= warmup_steps,
            lr_t * (local_step / warmup_steps),
            min_lr + (lr_t - min_lr) *
            (1.0 - tf.minimum(local_step, decay_steps) / decay_steps),
        )
        # Fold the bias correction of both moments into the step size.
        lr_t = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))

        m_t = state_ops.assign(m,
                               beta_1_t * m + (1.0 - beta_1_t) * grad,
                               use_locking=self._use_locking)

        v_t = state_ops.assign(v,
                               beta_2_t * v +
                               (1.0 - beta_2_t) * math_ops.square(grad),
                               use_locking=self._use_locking)

        if self.amsgrad:
            v_hat = self.get_slot(var, 'vhat')
            # Write the running maximum back to the slot. Without the assign
            # the slot is never updated here, so the maximum would not carry
            # over between steps.
            v_hat_t = state_ops.assign(v_hat,
                                       math_ops.maximum(v_hat, v_t),
                                       use_locking=self._use_locking)
            var_update = m_t / (math_ops.sqrt(v_hat_t) + epsilon_t)
        else:
            var_update = m_t / (math_ops.sqrt(v_t) + epsilon_t)

        if self._initial_weight_decay > 0.0:
            weight_decay = self._get_hyper('weight_decay', var_dtype)
            var_update += weight_decay * var
        var_update = state_ops.assign_sub(var,
                                          lr_t * var_update,
                                          use_locking=self._use_locking)

        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(v_hat_t)
        return control_flow_ops.group(*updates)
Esempio n. 3
0
    def _apply_dense_shared(self, grad, var):
        """Build the update ops for one dense gradient step on ``var``.

        Updates the ``m`` and ``v`` slots (plus ``vhat`` when ``self._amsgrad``
        is truthy) from ``grad`` and subtracts the learning-rate-scaled step
        from ``var``. The rectification factor is applied only when
        ``sma_t > 5``; otherwise the bias-corrected first moment is used on
        its own. The formulas look like the rectified-Adam (RAdam) update.

        Args:
            grad: Gradient tensor for ``var``.
            var: Variable being updated.

        Returns:
            A grouped op covering the variable update and all slot updates.
        """
        dtype = var.dtype.base_dtype

        # Cast the shared accumulators and hyper-parameters to var's dtype.
        b1_power, b2_power = self._get_beta_accumulators()
        b1_power = math_ops.cast(b1_power, dtype)
        b2_power = math_ops.cast(b2_power, dtype)
        step = math_ops.cast(self._get_niter(), dtype)
        step_size = math_ops.cast(self._lr_t, dtype)
        b1 = math_ops.cast(self._beta1_t, dtype)
        b2 = math_ops.cast(self._beta2_t, dtype)
        eps = math_ops.cast(self._epsilon_t, dtype)

        b2_correction = 1.0 - b2_power

        # Length of the approximated simple moving average: limit and current.
        sma_max = 2.0 / (1.0 - b2) - 1.0
        sma_now = sma_max - 2.0 * step * b2_power / b2_correction

        # Rectification factor, built up left to right in three stages.
        rect = (sma_now - 4.0) / (sma_max - 4.0)
        rect = rect * (sma_now - 2.0) / (sma_max - 2.0)
        rect = rect * sma_max / sma_now
        rectifier = math_ops.sqrt(rect)

        first_moment = self.get_slot(var, 'm')
        m_t = state_ops.assign(
            first_moment,
            b1 * first_moment + (1.0 - b1) * grad,
            use_locking=self._use_locking,
        )
        m_hat = m_t / (1.0 - b1_power)

        second_moment = self.get_slot(var, 'v')
        v_t = state_ops.assign(
            second_moment,
            b2 * second_moment + (1.0 - b2) * math_ops.square(grad),
            use_locking=self._use_locking,
        )
        slot_updates = [m_t, v_t]

        if self._amsgrad:
            running_max = self.get_slot(var, 'vhat')
            vhat_t = state_ops.assign(
                running_max,
                math_ops.maximum(running_max, v_t),
                use_locking=self._use_locking,
            )
            slot_updates.append(vhat_t)
            denom_base = vhat_t
        else:
            denom_base = v_t
        v_hat = math_ops.sqrt(denom_base / b2_correction + eps)

        delta = tf.where(sma_now > 5.0, rectifier * m_hat / v_hat, m_hat)

        var_update = state_ops.assign_sub(
            var, step_size * delta, use_locking=self._use_locking)

        return control_flow_ops.group(var_update, *slot_updates)