コード例 #1
0
def apply_dense_on_grasssmann(grad_clip, grad_on_grassmann, grad_on_oblique,
                              var, learning_rate, times, delta):
    """Apply one dense update to `var` on the Grassmann manifold.

    Blends the Grassmann gradient with a direction derived from the oblique
    gradient, optionally clips the step by norm, and retracts the result
    back onto the manifold.

    Args:
        grad_clip: Maximum step norm, or None to disable clipping.
        grad_on_grassmann: Gradient projected onto the Grassmann manifold.
        grad_on_oblique: Gradient projected onto the oblique manifold.
        var: The variable (point on the manifold) to update.
        learning_rate: Step size.
        times: Iteration counter used to anneal the mixing weight.
        delta: Lower bound on the mixing weight `a`.

    Returns:
        The updated variable after retraction.
    """
    # Mixing weight anneals toward `delta` as `times` grows.
    # NOTE(review): the oblique counterpart uses a single log
    # (1 / log(times + 2)); confirm the nested log here is intentional.
    a = tf.maximum(delta, 1 / tf.log((tf.log((times + 2)))))
    # Cross-manifold direction, rescaled to the Grassmann gradient's norm.
    n = gutils.unit(gutils.grassmann_project(
        var, grad_on_oblique)) * gutils.norm(grad_on_grassmann)
    b_1 = 2 * (1 - a) * gutils.xTy(grad_on_grassmann, n)
    b_2 = gutils.norm(grad_on_grassmann)
    b = b_1 / (b_2 + 1e-5)  # epsilon guards against division by zero

    if grad_clip is not None:  # fixed: identity comparison with None
        h = learning_rate * (a * grad_on_grassmann + b * n)
        h = -1 * gutils.clip_by_norm(h, grad_clip)
    else:
        h = -1 * learning_rate * (a * grad_on_grassmann + b * n)

    var_update = gutils.grassmann_retrction(var, h)
    return var_update
コード例 #2
0
def _apply_dense_on_oblique(grad_clip, grad_on_grassmann, grad_on_oblique, var,
                            learning_rate, times, delta):
    """Apply one dense update to `var` on the oblique manifold.

    Mirror image of the Grassmann variant: blends the oblique gradient with
    a direction derived from the Grassmann gradient, optionally clips the
    step by norm, and retracts back onto the manifold.

    Args:
        grad_clip: Maximum step norm, or None to disable clipping.
        grad_on_grassmann: Gradient projected onto the Grassmann manifold.
        grad_on_oblique: Gradient projected onto the oblique manifold.
        var: The variable (point on the manifold) to update.
        learning_rate: Step size.
        times: Iteration counter used to anneal the mixing weight.
        delta: Lower bound on the mixing weight `a`.

    Returns:
        The updated variable after retraction.
    """
    # Mixing weight anneals toward `delta` as `times` grows.
    a = torch.max(delta, 1 / torch.log(times + 2))
    # Cross-manifold direction, rescaled to the oblique gradient's norm.
    n = gutils.unit(gutils.oblique_project(
        var, grad_on_grassmann)) * gutils.norm(grad_on_oblique)
    b_1 = 2 * (1 - a) * gutils.xTy(grad_on_oblique, n)
    b_2 = gutils.norm(grad_on_oblique)
    b = b_1 / (b_2 + 1e-5)  # epsilon guards against division by zero

    if grad_clip is not None:  # fixed: identity comparison with None
        h = -1 * learning_rate * (a * grad_on_oblique + b * n)
        h = gutils.clip_by_norm(h, grad_clip)
    else:
        h = -1 * learning_rate * (a * grad_on_oblique + b * n)

    var_update = gutils.oblique_retrction(var, h)
    return var_update
コード例 #3
0
    def _apply_dense(self, grad, var):
        """Apply one Adam-style dense update to `var` on the manifold.

        Projects the gradient onto the tangent space at `var`, maintains
        first/second moment slots, and moves along the manifold via the
        exponential map. The momentum slot is parallel-transported to the
        new point.

        Args:
            grad: Raw (ambient) gradient for `var`.
            var: Variable being optimized.

        Returns:
            A grouped op performing the variable and slot updates.
        """
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        unity, _ = unit(var)  # for numerical stability
        h = gproj(unity, grad)  # tangent-space projection of the gradient

        if self._grad_clip_t is not None:  # fixed: identity comparison with None
            h_hat = clip_by_norm(h, self._grad_clip_t)
        else:
            h_hat = h

        # Adam moment updates on the projected gradient.
        mnew = self._beta1_t * m + (1.0 - self._beta1_t) * h_hat
        vnew = self._beta2_t * v + (1.0 - self._beta2_t) * xTy(h_hat, h_hat)

        # Bias-correction factor folded into the step size.
        alpha = tf.sqrt(1 - self._beta2_power) / (1. - self._beta1_power)
        deltas = (-alpha * self._lr_t) * mnew / tf.sqrt(vnew + self._epsilon_t)

        var_update = tf.assign(var, gexp(unity, deltas))
        # Parallel-transport the momentum so it stays in the new tangent space.
        m_update = tf.assign(m, gpt2(unity, mnew, deltas))
        v_update = tf.assign(v, vnew)

        return tf.group(*[var_update, m_update, v_update])
    def step(self, closure=None):
        """Performs a single optimization step.

        For parameter groups flagged `grassmann`, runs an Adam-style update
        on the Grassmann manifold (projection, moment tracking, exponential
        map, parallel transport of momentum). Otherwise falls back to plain
        SGD with optional weight decay, momentum, and Nesterov correction.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss from `closure`, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            grassmann = group['grassmann']

            if grassmann:
                beta1 = group['momentum']
                beta2 = group['beta2']
                epsilon = group['epsilon']
                grad_clip = group['grad_clip']
                omega = group['omega']

                for p in group['params']:
                    if p.grad is None:
                        continue

                    # Flatten to (rows, -1) and normalize for numerical stability.
                    unity, _ = unit(p.data.view(p.size()[0], -1))
                    g = p.grad.data.view(p.size()[0], -1)

                    if omega != 0:
                        # Orthogonality regularizer:
                        # L = |Y'Y-I|^2/2 = |YY'-I|^2/2 + c, dL/dY = 2(YY'Y - Y)
                        # fixed: use the `alpha=` keyword form of add_
                        # (the positional-scalar signature is deprecated/removed)
                        g.add_(torch.mm(torch.mm(unity, unity.t()), unity) - unity,
                               alpha=2 * omega)

                    h = gproj(unity, g)  # tangent-space projection

                    if grad_clip is not None:
                        h_hat = clip_by_norm(h, grad_clip)
                    else:
                        h_hat = h

                    param_state = self.state[p]
                    if 'm_buffer' not in param_state:
                        # Lazily create Adam moment buffers on first touch.
                        size = p.size()
                        param_state['m_buffer'] = torch.zeros(
                            [size[0], int(np.prod(size[1:]))])
                        param_state['v_buffer'] = torch.zeros([size[0], 1])
                        if p.is_cuda:
                            param_state['m_buffer'] = param_state['m_buffer'].cuda()
                            param_state['v_buffer'] = param_state['v_buffer'].cuda()

                        param_state['beta1_power'] = beta1
                        param_state['beta2_power'] = beta2

                    m = param_state['m_buffer']
                    v = param_state['v_buffer']
                    beta1_power = param_state['beta1_power']
                    beta2_power = param_state['beta2_power']

                    mnew = beta1 * m + (1.0 - beta1) * h_hat
                    vnew = beta2 * v + (1.0 - beta2) * xTy(h_hat, h_hat)

                    # Bias-correction factor folded into the step size.
                    alpha = np.sqrt(1. - beta2_power) / (1. - beta1_power)
                    deltas = mnew / vnew.add(epsilon).sqrt()
                    deltas.mul_(-alpha * group['lr'])

                    p.data.copy_(gexp(unity, deltas).view(p.size()))
                    # Parallel-transport momentum into the new tangent space.
                    m.copy_(gpt2(unity, mnew, deltas))
                    v.copy_(vnew)

                    param_state['beta1_power'] *= beta1
                    param_state['beta2_power'] *= beta2
            else:
                # Plain SGD branch (matches torch.optim.SGD semantics).
                momentum = group['momentum']
                weight_decay = group['weight_decay']
                dampening = group['dampening']
                nesterov = group['nesterov']
                for p in group['params']:
                    if p.grad is None:
                        continue
                    d_p = p.grad.data
                    if weight_decay != 0:
                        # fixed: alpha= keyword form (positional-scalar add_
                        # is deprecated/removed in modern PyTorch)
                        d_p.add_(p.data, alpha=weight_decay)
                    if momentum != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            buf = param_state['momentum_buffer'] = d_p.clone()
                        else:
                            buf = param_state['momentum_buffer']
                            buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                        if nesterov:
                            d_p = d_p.add(buf, alpha=momentum)
                        else:
                            d_p = buf

                    p.data.add_(d_p, alpha=-group['lr'])

        return loss