Code example #1
    def _apply_dense_on_obilique_with_noise(self, grad, var, seed):
        # Project the Euclidean gradient onto the tangent space.
        g = gutils.obilique_project(var, grad)
        g_norm = gutils.norm(g)
        # Weight the gradient (a) and the noise (b); the noise weight
        # decays as 1 / self._times**2.
        if g_norm >= 1 / self._times:
            a = 1 - 1 / (tf.square(self._times) * tf.square(g_norm))
        else:
            a = 1 / tf.square(self._times)
        b = 1 / tf.square(self._times)

        dim = grad.get_shape()[0]
        noise = tf.truncated_normal([dim, dim],
                                    mean=0.0,
                                    stddev=1.0,
                                    dtype=tf.float32,
                                    seed=seed,
                                    name="random_noise")

        # Signed descent step; clip its norm only when a threshold is set.
        h = -self._learning_rate_t * (a * g + b * noise)
        if self._grad_clip is not None:
            h = gutils.clip_by_norm(h, self._grad_clip_t)

        var_new = gutils.grassmann_retrction(var, h)

        return var_new
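
Nearly every snippet here routes the step through gutils.clip_by_norm. A minimal sketch of what such a norm-clipping helper typically looks like; the real gutils implementation is not shown above and may differ:

import torch

def clip_by_norm(h, clip_norm, eps=1e-12):
    # Rescale h so its Frobenius norm is at most clip_norm, keeping its
    # direction; a no-op when the norm is already below the threshold.
    scale = torch.clamp(clip_norm / (h.norm() + eps), max=1.0)
    return h * scale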
Code example #2
def _apply_dense_on_oblique_with_noise(grad_clip, grad, var, seed,
                                       learning_rate, times):
    g = gutils.oblique_project(var, grad)
    g_norm = gutils.norm(g)
    #a = tf.minimum(1 - 1 / (tf.square(times + 1) * tf.square(g_norm) + 1e-5), 1 / tf.square(times + 1))
    a = 1.0
    b = 1 / (tf.square(times + 1))

    dim = tf.convert_to_tensor(grad.get_shape()[0], dtype=tf.int32)
    noise = tf.truncated_normal([dim, 1],
                                mean=0.0,
                                stddev=0.0001,
                                dtype=tf.float32,
                                seed=seed,
                                name="random_noise")

    # Signed descent step; clip only when a threshold is given.
    h = -1 * learning_rate * (a * g + b * noise)
    if grad_clip is not None:
        h = gutils.clip_by_norm(h, grad_clip)

    # Note: the projection above is oblique, but the retraction used
    # here is the Grassmann one.
    var_new = gutils.grassmann_retrction(var, h)

    return var_new
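
The oblique manifold is, in the usual convention, the set of matrices with unit-norm columns. A minimal sketch of a tangent-space projection under that assumption (the actual gutils.oblique_project may be defined differently):

import torch

def oblique_project(X, G):
    # Remove, column by column, the component of G along X so the result
    # is tangent to the set of matrices with unit-norm columns.
    inner = (X * G).sum(dim=0, keepdim=True)  # per-column <x_i, g_i>
    return G - X * inner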
Code example #3
def _apply_dense_on_oblique_o(grad_clip, grad_on_oblique, var, learning_rate,
                              times, delta):
    a = tf.maximum(delta, 1)  #/(tf.log(times+2))

    # Signed descent step; clip only when a threshold is given.
    h = -1 * learning_rate * (a * grad_on_oblique)
    if grad_clip is not None:
        h = gutils.clip_by_norm(h, grad_clip)

    var_update = gutils.oblique_retrction(var, h)
    return var_update
Code example #4
def apply_dense_on_grasssmann_g(grad_clip, grad_on_grassmann, var,
                                learning_rate, times, delta):
    a = tf.maximum(delta, 1)  #/ (tf.log(times+2))

    # Clipping before or after the sign flip is equivalent, since
    # clip_by_norm only rescales; compute the signed step once.
    h = -1 * learning_rate * (a * grad_on_grassmann)
    if grad_clip is not None:
        h = gutils.clip_by_norm(h, grad_clip)

    var_update = gutils.grassmann_retrction(var, h)
    return var_update
Code example #5
    def _apply_dense(self, grad, var):
        mom = self.get_slot(var, "momentum")

        unity, _ = unit(var)  # for numerical stability
        h = gproj(unity, grad)

        if self._grad_clip_t is not None:
            h_hat = clip_by_norm(h, self._grad_clip_t)
        else:
            h_hat = h

        # Heavy-ball update in the tangent space; map back with gexp and
        # parallel-transport the momentum with gpt.
        mom_new = self._momentum_t * mom - self._learning_rate_t * h_hat

        var_update = tf.assign(var, gexp(unity, mom_new))
        mom_update = tf.assign(mom, gpt(unity, mom_new))

        return tf.group(*[var_update, mom_update])
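
The unit/gproj/gexp/gpt helpers are not shown. A minimal sketch of the standard sphere-geodesic versions for row-normalized parameters, which is one common way to implement them (an assumption, not this snippet's actual code):

import torch

def gexp(y, h, eps=1e-12):
    # Exponential map: follow the geodesic from unit rows y along the
    # tangent rows h, i.e. y * cos(|h|) + (h / |h|) * sin(|h|).
    u = h.norm(dim=1, keepdim=True)
    return y * torch.cos(u) + h * torch.sin(u) / (u + eps)

def gpt(y, h, eps=1e-12):
    # Parallel transport of h along its own geodesic to the new point.
    u = h.norm(dim=1, keepdim=True)
    w = h / (u + eps)
    return (-y * torch.sin(u) + w * torch.cos(u)) * u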
Code example #6
def apply_dense_on_grasssmann(grad_clip, grad_on_grassmann, grad_on_oblique,
                              var, learning_rate, times, delta):
    a = tf.maximum(delta, 1 / tf.log(tf.log(times + 2)))
    # Direction of the oblique gradient mapped onto the Grassmann
    # tangent space, rescaled to the Grassmann gradient's norm.
    n = gutils.unit(gutils.grassmann_project(
        var, grad_on_oblique)) * gutils.norm(grad_on_grassmann)
    b_1 = 2 * (1 - a) * gutils.xTy(grad_on_grassmann, n)
    b_2 = gutils.norm(grad_on_grassmann)
    b = b_1 / (b_2 + 1e-5)

    # Signed descent step; clip only when a threshold is given.
    h = -1 * learning_rate * (a * grad_on_grassmann + b * n)
    if grad_clip is not None:
        h = gutils.clip_by_norm(h, grad_clip)

    var_update = gutils.grassmann_retrction(var, h)
    return var_update
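
gutils.xTy is used above as an inner product between tangent directions. A plausible minimal version, assuming row-shaped tangents (the real helper is not shown):

import torch

def xTy(x, y):
    # Row-wise inner product <x_i, y_i>, kept 2-D so the result
    # broadcasts against row-shaped tensors.
    return (x * y).sum(dim=1, keepdim=True)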
Code example #7
def _apply_dense_on_oblique(grad_clip, grad_on_grassmann, grad_on_oblique, var,
                            learning_rate, times, delta):
    a = torch.max(delta, 1 / torch.log(times + 2))
    # a=0.5
    n = gutils.unit(gutils.oblique_project(
        var, grad_on_grassmann)) * gutils.norm(grad_on_oblique)
    b_1 = 2 * (1 - a) * gutils.xTy(grad_on_oblique, n)
    b_2 = gutils.norm(grad_on_oblique)
    b = b_1 / (b_2 + 1e-5)

    # Signed descent step; clip only when a threshold is given.
    h = -1 * learning_rate * (a * grad_on_oblique + b * n)
    if grad_clip is not None:
        h = gutils.clip_by_norm(h, grad_clip)

    var_update = gutils.oblique_retrction(var, h)
    return var_update
Code example #8
def _apply_dense_on_oblique_with_noise(grad_clip, grad, var, learning_rate,
                                       times, variance):
    g = gutils.oblique_project(var, grad)
    #g_norm = gutils.norm(g)
    #a = tf.minimum(1 - 1 / (tf.square(times + 1) * tf.square(g_norm) + 1e-5), 1 / tf.square(times + 1))

    a = 1.0
    b = 1 / torch.square(times + 1)

    noise = variance * gutils.oblique_project(var, torch.randn(var.size()[0]))

    # Signed descent step; clip only when a threshold is given.
    h = -1 * learning_rate * (a * g + b * noise)
    if grad_clip is not None:
        h = gutils.clip_by_norm(h, grad_clip)

    var_new = gutils.grassmann_retrction(var, h)

    return var_new
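
In both noise variants the gradient keeps full weight (a = 1.0) while the noise weight b = 1 / (times + 1)^2 decays quadratically, so the injected noise only matters during the first few steps. A quick check of the schedule:

# b at steps 0, 1, 9, 99:
for t in (0, 1, 9, 99):
    print(t, 1.0 / (t + 1) ** 2)  # 1.0, 0.25, 0.01, 0.0001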
Code example #9
    def _apply_dense_on_obilique(self, grad_on_grassmann, grad_on_obilique,
                                 var):
        a = tf.maximum(self._delta_t, 1 / tf.square(self._times))
        # Direction of the Grassmann gradient, projected onto the
        # oblique tangent space; reused in both b and the step below.
        d = gutils.obilique_project(var, grad_on_grassmann)
        b_1 = 2 * (1 - a) * tf.matmul(tf.transpose(grad_on_obilique), d)
        b_2 = gutils.norm(d)
        b = b_1 / b_2

        # Signed descent step; clip only when a threshold is set.
        h = -self._learning_rate_t * (a * grad_on_obilique + b * d)
        if self._grad_clip is not None:
            h = gutils.clip_by_norm(h, self._grad_clip_t)

        var_update = gutils.obilique_retrction(var, h)
        return var_update
Code example #10
    def _apply_dense(self, grad, var):
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        unity, _ = unit(var)  # for numerical stability
        h = gproj(unity, grad)

        if self._grad_clip_t is not None:
            h_hat = clip_by_norm(h, self._grad_clip_t)
        else:
            h_hat = h

        # Riemannian Adam-style moments: v tracks the squared norm of the
        # projected gradient (xTy) rather than elementwise squares.
        mnew = self._beta1_t * m + (1.0 - self._beta1_t) * h_hat
        vnew = self._beta2_t * v + (1.0 - self._beta2_t) * xTy(h_hat, h_hat)

        # Bias correction, as in standard Adam.
        alpha = tf.sqrt(1 - self._beta2_power) / (1. - self._beta1_power)
        deltas = (-alpha * self._lr_t) * mnew / tf.sqrt(vnew + self._epsilon_t)

        var_update = tf.assign(var, gexp(unity, deltas))
        m_update = tf.assign(m, gpt2(unity, mnew, deltas))
        v_update = tf.assign(v, vnew)

        return tf.group(*[var_update, m_update, v_update])
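
The alpha above is the usual Adam bias-correction factor. A standalone version with a hypothetical helper name, where beta1 ** t and beta2 ** t play the role of the beta1_power/beta2_power accumulators:

import math

def bias_corrected_lr(lr, beta1, beta2, t):
    # Rescale lr so the zero-initialized m and v buffers do not damp
    # the first few steps.
    return lr * math.sqrt(1.0 - beta2 ** t) / (1.0 - beta1 ** t)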
Code example #11
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            #momentum = group['momentum']
            manifold = group['manifold']

            if manifold != "None":
                grad_clip = group['grad_clip']

                # Parameters are assumed to come in matched halves: the
                # first half lives on the Grassmann manifold, the second
                # half on the oblique manifold.
                length = len(group['params'])

                for i in range(length // 2):

                    p_grassmann = group['params'][i]
                    p_oblique = group['params'][i + length // 2]

                    if p_grassmann.grad is None or p_oblique.grad is None:
                        continue

                    unity_grassmann, _ = gutils.unit(
                        p_grassmann.data.view(p_grassmann.size()[0], -1))
                    unity_oblique, _ = gutils.unit(
                        p_oblique.data.view(p_oblique.size()[0], -1))

                    grad_grassmann = p_grassmann.grad.data.view(
                        p_grassmann.size()[0], -1)
                    grad_oblique = p_oblique.grad.data.view(
                        p_oblique.size()[0], -1)

                    # if omega != 0:
                    # L=|Y'Y-I|^2/2=|YY'-I|^2/2+c
                    # dL/dY=2(YY'Y-Y)
                    # g.add_(2*omega, torch.mm(torch.mm(unity, unity.t()), unity) - unity)

                    h_grassmann = gutils.grassmann_project(
                        unity_grassmann, grad_grassmann)
                    h_oblique = gutils.oblique_project(unity_oblique,
                                                       grad_oblique)

                    if grad_clip is not None:
                        h_hat_grassmann = gutils.clip_by_norm(
                            h_grassmann, grad_clip)
                        h_hat_oblique = gutils.clip_by_norm(
                            h_oblique, grad_clip)
                    else:
                        h_hat_grassmann = h_grassmann
                        h_hat_oblique = h_oblique

                        # param_state = self.state[p]
                        # if 'momentum_buffer' not in param_state:
                        #    param_state['momentum_buffer'] = torch.zeros(h_hat.size())
                        #    if p.is_cuda:
                        #      param_state['momentum_buffer'] = param_state['momentum_buffer'].cuda()

                        # mom = param_state['momentum_buffer']
                        # mom_new = momentum*mom - group['lr']*h_hat

                    # Retract along the negative (descent) direction, as in
                    # the standalone helpers above.
                    p_grassmann.data.copy_(
                        gutils.grassmann_retrction(
                            unity_grassmann, -group['lr'] *
                            h_hat_grassmann).view(p_grassmann.size()))
                    p_oblique.data.copy_(
                        gutils.oblique_retrction(unity_oblique,
                                                 -group['lr'] *
                                                 h_hat_oblique).view(
                                                     p_oblique.size()))

            elif manifold == "None":
                # This routine is from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
                weight_decay = group['weight_decay']
                #dampening = group['dampening']
                #nesterov = group['nesterov']
                for p in group['params']:
                    if p.grad is None:
                        continue
                    d_p = p.grad.data
                    if weight_decay != 0:
                        d_p.add_(weight_decay, p.data)
                    #if momentum != 0:
                    #    param_state = self.state[p]
                    #    if 'momentum_buffer' not in param_state:
                    #        buf = param_state['momentum_buffer'] = d_p.clone()
                    #    else:
                    #        buf = param_state['momentum_buffer']
                    #        buf.mul_(momentum).add_(1 - dampening, d_p)
                    #    if nesterov:
                    #        d_p = d_p.add(momentum, buf)
                    #    else:
                    #        d_p = buf

                    p.data.add_(-group['lr'], d_p)

            else:
                raise ValueError("There is no such manifold")

        return loss

Code example #12
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            grassmann = group['grassmann']

            if grassmann:
                grad_clip = group['grad_clip']
                omega = group['omega']

                for p in group['params']:
                    if p.grad is None:
                        continue

                    unity, _ = unit(p.data.view(p.size()[0], -1))
                    g = p.grad.data.view(p.size()[0], -1)

                    if omega != 0:
                        # Orthogonality penalty:
                        # L = |Y'Y - I|^2/2 = |YY' - I|^2/2 + c
                        # dL/dY = 2(YY'Y - Y)
                        g.add_(2 * omega, torch.mm(torch.mm(unity, unity.t()), unity) - unity)

                    h = gproj(unity, g)

                    if grad_clip is not None:
                        h_hat = clip_by_norm(h, grad_clip)
                    else:
                        h_hat = h

                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        param_state['momentum_buffer'] = torch.zeros(h_hat.size())
                        if p.is_cuda:
                            param_state['momentum_buffer'] = param_state['momentum_buffer'].cuda()

                    mom = param_state['momentum_buffer']
                    mom_new = momentum * mom - group['lr'] * h_hat

                    p.data.copy_(gexp(unity, mom_new).view(p.size()))
                    mom.copy_(gpt(unity, mom_new))

            else:
                # This routine is from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
                weight_decay = group['weight_decay']
                dampening = group['dampening']
                nesterov = group['nesterov']
                for p in group['params']:
                    if p.grad is None:
                        continue
                    d_p = p.grad.data
                    if weight_decay != 0:
                        d_p.add_(weight_decay, p.data)
                    if momentum != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            buf = param_state['momentum_buffer'] = d_p.clone()
                        else:
                            buf = param_state['momentum_buffer']
                            buf.mul_(momentum).add_(1 - dampening, d_p)
                        if nesterov:
                            d_p = d_p.add(momentum, buf)
                        else:
                            d_p = buf

                    p.data.add_(-group['lr'], d_p)

        return loss

Code example #13
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            grassmann = group['grassmann']

            if grassmann:
                beta1 = group['momentum']
                beta2 = group['beta2']
                epsilon = group['epsilon']
                grad_clip = group['grad_clip']
                omega = group['omega']

                for p in group['params']:
                    if p.grad is None:
                        continue

                    unity, _ = unit(p.data.view(p.size()[0], -1))
                    g = p.grad.data.view(p.size()[0], -1)

                    if omega != 0:
                        # Orthogonality penalty:
                        # L = |Y'Y - I|^2/2 = |YY' - I|^2/2 + c
                        # dL/dY = 2(YY'Y - Y)
                        g.add_(2 * omega, torch.mm(torch.mm(unity, unity.t()), unity) - unity)

                    h = gproj(unity, g)

                    if grad_clip is not None:
                        h_hat = clip_by_norm(h, grad_clip)
                    else:
                        h_hat = h

                    param_state = self.state[p]
                    if 'm_buffer' not in param_state:
                        size = p.size()
                        param_state['m_buffer'] = torch.zeros([size[0], int(np.prod(size[1:]))])
                        param_state['v_buffer'] = torch.zeros([size[0], 1])
                        if p.is_cuda:
                            param_state['m_buffer'] = param_state['m_buffer'].cuda()
                            param_state['v_buffer'] = param_state['v_buffer'].cuda()

                        param_state['beta1_power'] = beta1
                        param_state['beta2_power'] = beta2

                    m = param_state['m_buffer']
                    v = param_state['v_buffer']
                    beta1_power = param_state['beta1_power']
                    beta2_power = param_state['beta2_power']

                    mnew = beta1 * m + (1.0 - beta1) * h_hat
                    vnew = beta2 * v + (1.0 - beta2) * xTy(h_hat, h_hat)

                    # Bias-corrected step size, as in standard Adam.
                    alpha = np.sqrt(1. - beta2_power) / (1. - beta1_power)
                    deltas = mnew / vnew.add(epsilon).sqrt()
                    deltas.mul_(-alpha * group['lr'])

                    p.data.copy_(gexp(unity, deltas).view(p.size()))
                    m.copy_(gpt2(unity, mnew, deltas))
                    v.copy_(vnew)

                    param_state['beta1_power'] *= beta1
                    param_state['beta2_power'] *= beta2
            else:
                momentum = group['momentum']
                weight_decay = group['weight_decay']
                dampening = group['dampening']
                nesterov = group['nesterov']
                for p in group['params']:
                    if p.grad is None:
                        continue
                    d_p = p.grad.data
                    if weight_decay != 0:
                        d_p.add_(weight_decay, p.data)
                    if momentum != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            buf = param_state['momentum_buffer'] = d_p.clone()
                        else:
                            buf = param_state['momentum_buffer']
                            buf.mul_(momentum).add_(1 - dampening, d_p)
                        if nesterov:
                            d_p = d_p.add(momentum, buf)
                        else:
                            d_p = buf

                    p.data.add_(-group['lr'], d_p)

        return loss
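
All three step() variants follow the standard torch.optim.Optimizer contract, including the optional loss closure. A minimal driver loop, using the stock SGD optimizer as a stand-in since the snippets' class names are not shown:

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.01)  # stand-in optimizer
x, y = torch.randn(8, 4), torch.randn(8, 2)

def closure():
    # Reevaluate the model and return the loss, as the docstrings describe.
    opt.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = opt.step(closure)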