Example No. 1
def apply_dense_on_grasssmann(grad_clip, grad_on_grassmann, grad_on_oblique,
                              var, learning_rate, times, delta):
    a = tf.maximum(delta, 1 / tf.log(tf.log(times + 2)))
    n = gutils.unit(gutils.grassmann_project(
        var, grad_on_oblique)) * gutils.norm(grad_on_grassmann)
    b_1 = 2 * (1 - a) * gutils.xTy(grad_on_grassmann, n)
    b_2 = gutils.norm(grad_on_grassmann)
    b = b_1 / (b_2 + 1e-5)

    if grad_clip is not None:
        h = learning_rate * (a * grad_on_grassmann + b * n)
        h = -1 * gutils.clip_by_norm(h, grad_clip)
    else:
        h = -1 * learning_rate * (a * grad_on_grassmann + b * n)

    var_update = gutils.grassmann_retrction(var, h)
    return var_update
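
The helper functions called here (gutils.unit, gutils.norm, gutils.xTy, gutils.grassmann_project, gutils.clip_by_norm) are not shown. Below is a minimal NumPy sketch of what they are assumed to compute, inferred only from how they are used above; it is not the actual gutils module.

import numpy as np

def norm(x):
    # Frobenius norm of x
    return np.sqrt(np.sum(x ** 2))

def unit(x, eps=1e-8):
    # rescale x to unit Frobenius norm
    return x / (norm(x) + eps)

def xTy(x, y):
    # Euclidean inner product <x, y>
    return np.sum(x * y)

def grassmann_project(x, g):
    # tangent projection for G(1, n): remove the component of g along the unit vector x
    return g - x * xTy(x, g)

def clip_by_norm(x, clip):
    # shrink x so its norm does not exceed clip
    n = norm(x)
    return x if n <= clip else x * (clip / n)
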
Example No. 2
    def _apply_dense(self, grad, var):
        mom = self.get_slot(var, "momentum")

        unity, _ = unit(var)  # for numerical stability
        h = gproj(unity, grad)

        if self._grad_clip_t is not None:
            h_hat = clip_by_norm(h, self._grad_clip_t)
        else:
            h_hat = h

        mom_new = self._momentum_t * mom - self._learning_rate_t * h_hat

        var_update = tf.assign(var, gexp(unity, mom_new))
        mom_update = tf.assign(mom, gpt(unity, mom_new))

        return tf.group(*[var_update, mom_update])
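
gproj, gexp and gpt are again external helpers. For G(1, n) represented by unit vectors, one plausible realization (an assumption, not the repository's actual code) is the sphere's exponential map and parallel transport along the same geodesic, sketched in NumPy:

import numpy as np

def gexp(x, v, eps=1e-8):
    # exponential map on the unit sphere at x, in tangent direction v
    nv = np.sqrt(np.sum(v ** 2))
    if nv < eps:
        return x
    return np.cos(nv) * x + np.sin(nv) * v / nv

def gpt(x, v, eps=1e-8):
    # parallel transport of v along the geodesic exp_x(t * v), evaluated at t = 1
    nv = np.sqrt(np.sum(v ** 2))
    if nv < eps:
        return v
    u = v / nv
    return nv * (np.cos(nv) * u - np.sin(nv) * x)
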
Example No. 3
def _apply_dense_on_oblique(grad_clip, grad_on_grassmann, grad_on_oblique, var,
                            learning_rate, times, delta):
    a = torch.max(delta, 1 / torch.log(times + 2))
    # a=0.5
    n = gutils.unit(gutils.oblique_project(
        var, grad_on_grassmann)) * gutils.norm(grad_on_oblique)
    b_1 = 2 * (1 - a) * gutils.xTy(grad_on_oblique, n)
    b_2 = gutils.norm(grad_on_oblique)
    b = b_1 / (b_2 + 1e-5)

    if grad_clip is not None:
        h = -1 * learning_rate * (a * grad_on_oblique + b * n)
        h = gutils.clip_by_norm(h, grad_clip)
    else:
        h = -1 * learning_rate * (a * grad_on_oblique + b * n)

    var_update = gutils.oblique_retrction(var, h)
    return var_update
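
The oblique-manifold helpers (gutils.oblique_project and the retraction spelled gutils.oblique_retrction above) are assumed to act column-wise on matrices with unit-norm columns. A hedged NumPy sketch of that assumption:

import numpy as np

def oblique_project(X, G):
    # tangent projection on the oblique manifold: per column, remove the component of G along X
    return G - X * np.sum(X * G, axis=0, keepdims=True)

def oblique_retraction(X, H):
    # retraction: renormalize every column of X + H back to unit norm
    Y = X + H
    return Y / np.linalg.norm(Y, axis=0, keepdims=True)
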
Example No. 4
    def _apply_dense(self, grad, var):
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        unity, _ = unit(var)  # for numerical stability
        h = gproj(unity, grad)

        if self._grad_clip_t is not None:
            h_hat = clip_by_norm(h, self._grad_clip_t)
        else:
            h_hat = h

        mnew = self._beta1_t * m + (1.0 - self._beta1_t) * h_hat
        vnew = self._beta2_t * v + (1.0 - self._beta2_t) * xTy(h_hat, h_hat)

        alpha = tf.sqrt(1 - self._beta2_power) / (1. - self._beta1_power)
        deltas = (-alpha * self._lr_t) * mnew / tf.sqrt(vnew + self._epsilon_t)

        var_update = tf.assign(var, gexp(unity, deltas))
        m_update = tf.assign(m, gpt2(unity, mnew, deltas))
        v_update = tf.assign(v, vnew)

        return tf.group(*[var_update, m_update, v_update])
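
Compared with Example No. 2, the only extra geometric ingredient is gpt2, which transports the first moment m along the geodesic actually taken (direction deltas) rather than along m itself. A sketch under the same unit-sphere assumption as above:

import numpy as np

def gpt2(x, v, w, eps=1e-8):
    # parallel transport of the tangent vector v from x along the geodesic exp_x(t * w)
    nw = np.sqrt(np.sum(w ** 2))
    if nw < eps:
        return v
    u = w / nw
    return v + ((np.cos(nw) - 1.0) * u - np.sin(nw) * x) * np.sum(u * v)
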
Example No. 5
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            stiefel = group['stiefel']

            for p in group['params']:
                if p.grad is None:
                    continue

                unity, _ = unit(p.data.view(p.size()[0], -1))
                if stiefel and unity.size()[0] <= unity.size()[1]:

                    weight_decay = group['weight_decay']
                    dampening = group['dampening']
                    nesterov = group['nesterov']

                    rand_num = random.randint(1, 101)
                    if rand_num == 1:
                        unity = qr_retraction(unity)

                    g = p.grad.data.view(p.size()[0], -1)

                    lr = group['lr']

                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        param_state['momentum_buffer'] = torch.zeros(
                            g.t().size())
                        if p.is_cuda:
                            param_state['momentum_buffer'] = param_state[
                                'momentum_buffer'].cuda()

                    V = param_state['momentum_buffer']
                    V = momentum * V - g.t()
                    MX = torch.mm(V, unity)
                    XMX = torch.mm(unity, MX)
                    XXMX = torch.mm(unity.t(), XMX)
                    W_hat = MX - 0.5 * XXMX
                    W = W_hat - W_hat.t()
                    t = 0.5 * 2 / (matrix_norm_one(W) + episilon)
                    alpha = min(t, lr)

                    p_new = Cayley_loop(unity.t(), W, V, alpha)
                    V_new = torch.mm(W, unity.t())  # n-by-p
                    #                     check_identity(p_new.t())
                    p.data.copy_(p_new.view(p.size()))
                    V.copy_(V_new)

                else:
                    weight_decay = group['weight_decay']
                    dampening = group['dampening']
                    nesterov = group['nesterov']
                    d_p = p.grad.data
                    if weight_decay != 0:
                        d_p.add_(weight_decay, p.data)
                    if momentum != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            buf = param_state['momentum_buffer'] = d_p.clone()
                        else:
                            buf = param_state['momentum_buffer']
                            buf.mul_(momentum).add_(1 - dampening, d_p)
                        if nesterov:
                            d_p = d_p.add(momentum, buf)
                        else:
                            d_p = buf

                    p.data.add_(-group['lr'], d_p)

        return loss
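
Cayley_loop and matrix_norm_one come from the surrounding module and are not shown. Cayley_loop is assumed to implement the fixed-point iteration for the Cayley retraction, Y_{k+1} = X + (alpha/2) * W @ (X + Y_k), which stays on the Stiefel manifold for skew-symmetric W. A NumPy sketch of that assumption (the iteration count and exact signature are guesses, not the repository's implementation):

import numpy as np

def cayley_loop(X, W, tan_vec, alpha, iters=5):
    # X: n-by-p with orthonormal columns, W: n-by-n skew-symmetric,
    # tan_vec: n-by-p tangent direction used only for the initial guess
    Y = X + alpha * tan_vec
    for _ in range(iters):
        Y = X + 0.5 * alpha * W @ (X + Y)
    return Y
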
Example No. 6
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            stiefel = group['stiefel']

            for p in group['params']:
                if p.grad is None:
                    continue

                beta1 = group['momentum']
                beta2 = group['beta2']
                epsilon = group['epsilon']

                unity, _ = unit(p.data.view(p.size()[0], -1))
                if stiefel and unity.size()[0] <= unity.size()[1]:
                    rand_num = random.randint(1, 101)
                    if rand_num == 1:
                        unity = qr_retraction(unity)

                    g = p.grad.data.view(p.size()[0], -1)

                    param_state = self.state[p]
                    if 'm_buffer' not in param_state:
                        size = p.size()
                        param_state['m_buffer'] = torch.zeros(
                            [int(np.prod(size[1:])), size[0]])
                        param_state['v_buffer'] = torch.zeros([1])
                        if p.is_cuda:
                            param_state['m_buffer'] = param_state[
                                'm_buffer'].cuda()
                            param_state['v_buffer'] = param_state[
                                'v_buffer'].cuda()

                        param_state['beta1_power'] = beta1
                        param_state['beta2_power'] = beta2

                    m = param_state['m_buffer']
                    v = param_state['v_buffer']
                    beta1_power = param_state['beta1_power']
                    beta2_power = param_state['beta2_power']

                    mnew = beta1 * m + (1.0 - beta1) * g.t()  # p by n
                    vnew = beta2 * v + (1.0 - beta2) * (torch.norm(g)**2)

                    mnew_hat = mnew / (1 - beta1_power)
                    vnew_hat = vnew / (1 - beta2_power)

                    MX = torch.matmul(mnew_hat, unity)
                    XMX = torch.matmul(unity, MX)
                    XXMX = torch.matmul(unity.t(), XMX)
                    W_hat = MX - 0.5 * XXMX
                    W = (W_hat - W_hat.t()) / vnew_hat.add(epsilon).sqrt()

                    t = 0.5 * 2 / (matrix_norm_one(W) + episilon)
                    alpha = min(t, group['lr'])

                    p_new = Cayley_loop(unity.t(), W, mnew, -alpha)

                    p.data.copy_(p_new.view(p.size()))
                    mnew = torch.matmul(W, unity.t()) * vnew_hat.add(
                        epsilon).sqrt() * (1 - beta1_power)
                    m.copy_(mnew)
                    v.copy_(vnew)

                    param_state['beta1_power'] *= beta1
                    param_state['beta2_power'] *= beta2

                else:
                    momentum = group['momentum']
                    weight_decay = group['weight_decay']
                    dampening = group['dampening']
                    nesterov = group['nesterov']
                    d_p = p.grad.data
                    if weight_decay != 0:
                        d_p.add_(weight_decay, p.data)
                    if momentum != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            buf = param_state['momentum_buffer'] = d_p.clone()
                        else:
                            buf = param_state['momentum_buffer']
                            buf.mul_(momentum).add_(1 - dampening, d_p)
                        if nesterov:
                            d_p = d_p.add(momentum, buf)
                        else:
                            d_p = buf

                    p.data.add_(-group['lr'], d_p)

        return loss
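
The step-size bound t = 0.5 * 2 / (matrix_norm_one(W) + episilon) in both Cayley examples relies on matrix_norm_one, which is assumed to be the induced 1-norm of W (maximum absolute column sum), a cheap surrogate upper bound used to keep the Cayley step stable:

import numpy as np

def matrix_norm_one(W):
    # induced 1-norm: maximum absolute column sum (assumed implementation)
    return np.max(np.sum(np.abs(W), axis=0))
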
Example No. 7
def inference(input_tensor, train, regularizer):
    input_tensor = batch_norm.batch_norm(input_tensor)

    with tf.variable_scope('layer1-conv1_oblique'):
        conv1_weights_o = tf.get_variable(
            "weight_o",
            shape=[CONV1_SIZE, CONV1_SIZE, NUM_CHANNELS, CONV1_DEEP],
            initializer=tf.truncated_normal_initializer(stddev=0.1))
        tf.assign(conv1_weights_o, gutils.unit(conv1_weights_o))
        conv1_biases_o = tf.get_variable(
            "biases_o",
            shape=[CONV1_DEEP],
            initializer=tf.constant_initializer(0.0))

        conv1_weights_o_tmp = tf.get_variable(
            "weight_o_tmp",
            shape=[CONV1_SIZE, CONV1_SIZE, NUM_CHANNELS, CONV1_DEEP],
            initializer=tf.truncated_normal_initializer(stddev=1))
        conv1_biases_o_tmp = tf.get_variable(
            "biases_o_tmp",
            shape=[CONV1_DEEP],
            initializer=tf.constant_initializer(0.0))

        # Convolution forward pass: stride 1 with zero ('SAME') padding, giving a 28*28*32 output; the stride is set in the strides argument.
        conv1_o = tf.nn.conv2d(input_tensor,
                               conv1_weights_o,
                               strides=[1, 1, 1, 1],
                               padding='SAME')
        conv1_batch_o = batch_norm.batch_norm(conv1_o, scale=None)
        relu1_o_oblique = tf.nn.relu(
            tf.nn.bias_add(conv1_batch_o, conv1_biases_o))

    with tf.name_scope('layer2-pool1_oblique'):
        pool1_o = tf.nn.max_pool(relu1_o_oblique,
                                 ksize=[1, 3, 3, 1],
                                 strides=[1, 2, 2, 1],
                                 padding='SAME')  # ksize is the pooling window, strides is the step size
        pool1_batch_o_oblique = batch_norm.batch_norm(pool1_o, scale=None)

    with tf.variable_scope('layer3-conv2_oblique'):
        conv2_weights_o = tf.get_variable(
            'weight_o',
            shape=[CONV2_SIZE, CONV2_SIZE, CONV1_DEEP, CONV2_DEEP],
            initializer=tf.truncated_normal_initializer(stddev=0.1))
        tf.assign(conv2_weights_o, gutils.unit(conv2_weights_o))
        conv2_biases_o = tf.get_variable(
            'biases_o',
            shape=[CONV2_DEEP],
            initializer=tf.constant_initializer(0.0))

        conv2_weights_o_tmp = tf.get_variable(
            "weight_o_tmp",
            shape=[CONV2_SIZE, CONV2_SIZE, CONV1_DEEP, CONV2_DEEP],
            initializer=tf.truncated_normal_initializer(stddev=1))

        conv2_biases_o_tmp = tf.get_variable(
            "biases_o_tmp",
            shape=[CONV2_DEEP],
            initializer=tf.constant_initializer(0.0))

        # Convolution forward pass
        conv2_o = tf.nn.conv2d(pool1_batch_o_oblique,
                               conv2_weights_o,
                               strides=[1, 1, 1, 1],
                               padding='SAME')
        conv2_batch_o = batch_norm.batch_norm(conv2_o, scale=None)
        relu2_o_oblique = tf.nn.relu(
            tf.nn.bias_add(conv2_batch_o, conv2_biases_o))

    with tf.name_scope('layer4-pool2_oblique'):
        pool2_o = tf.nn.max_pool(relu2_o_oblique,
                                 ksize=[1, 3, 3, 1],
                                 strides=[1, 2, 2, 1],
                                 padding='SAME')  # ksize is the pooling window, strides is the step size
        pool2_batch_o_oblique = batch_norm.batch_norm(pool2_o, scale=None)

    with tf.variable_scope('layer5-fc1_oblique'):
        pool_shape = pool2_batch_o_oblique.get_shape().as_list()
        # pool_shape[0] is the batch size
        nodes = pool_shape[1] * pool_shape[2] * pool_shape[3]
        # Flatten the pooled feature map into a vector for the fully connected layers
        reshaped_o = tf.reshape(pool2_batch_o_oblique, [-1, nodes])

        fc1_weights_o = tf.get_variable(
            'weight_o',
            shape=[nodes, FC_SIZE],
            initializer=tf.truncated_normal_initializer(stddev=0.1))

        tf.assign(fc1_weights_o, gutils.unit(fc1_weights_o))

        fc1_weights_o_tmp = tf.get_variable(
            'weight_o_tmp',
            shape=[nodes, FC_SIZE],
            initializer=tf.truncated_normal_initializer(stddev=0.1))

        # Regularize only the fully connected weights
        if regularizer is not None:
            tf.add_to_collection('losses_o_oblique',
                                 regularizer(fc1_weights_o))
        fc1_biases_o = tf.get_variable(
            'biases_o',
            shape=[FC_SIZE],
            initializer=tf.constant_initializer(0.0))
        fc1_biases_o_tmp = tf.get_variable(
            "biases_o_tmp",
            shape=[FC_SIZE],
            initializer=tf.constant_initializer(0.0))

        fc1_o_oblique = tf.nn.relu(
            tf.matmul(reshaped_o, fc1_weights_o) + fc1_biases_o)

        if train:
            fc1_o_oblique = tf.nn.dropout(fc1_o_oblique, 0.5)

    with tf.variable_scope('layer6-fc2_oblique'):
        fc2_weights_o = tf.get_variable(
            'weight_o',
            shape=[FC_SIZE, NUM_LABELS],
            initializer=tf.truncated_normal_initializer(stddev=0.1))
        tf.assign(fc2_weights_o, gutils.unit(fc2_weights_o))

        fc2_weights_o_tmp = tf.get_variable(
            'weight_o_tmp',
            shape=[FC_SIZE, NUM_LABELS],
            initializer=tf.truncated_normal_initializer(stddev=0.1))

        if regularizer is not None:
            tf.add_to_collection('losses_o_oblique',
                                 regularizer(fc2_weights_o))

        fc2_biases_o = tf.get_variable(
            'biases_o',
            shape=[NUM_LABELS],
            initializer=tf.constant_initializer(0.0))
        fc2_biases_o_tmp = tf.get_variable(
            "biases_o_tmp",
            shape=[NUM_LABELS],
            initializer=tf.constant_initializer(0.0))

        logit_o_oblique = tf.matmul(fc1_o_oblique,
                                    fc2_weights_o) + fc2_biases_o

        return logit_o_oblique
Example No. 8
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            #momentum = group['momentum']
            manifold = group['manifold']
            learning_rate = group['lr']
            variance = group['variance']
            times = group['times']

            if manifold != 'None':
                grad_clip = group['grad_clip']

                length = len(group['params'])

                # the first half of group['params'] holds the Grassmann copies,
                # the second half the matching oblique copies
                for i in range(length // 2):

                    p_grassmann = group['params'][i]
                    p_oblique = group['params'][i + length // 2]

                    if p_grassmann.grad is None or p_oblique.grad is None:
                        continue

                    unity_grassmann, _ = gutils.unit(
                        p_grassmann.data.view(p_grassmann.size()[0], -1))
                    unity_oblique, _ = gutils.unit(
                        p_oblique.data.view(p_oblique.size()[0], -1))

                    grad_grassmann = p_grassmann.grad.data.view(
                        p_grassmann.size()[0], -1)
                    grad_oblique = p_oblique.grad.data.view(
                        p_oblique.size()[0], -1)

                    # if omega != 0:
                    # L=|Y'Y-I|^2/2=|YY'-I|^2/2+c
                    # dL/dY=2(YY'Y-Y)
                    # g.add_(2*omega, torch.mm(torch.mm(unity, unity.t()), unity) - unity)

                    h_grassmann = gutils.grassmann_project(
                        unity_grassmann, grad_grassmann)
                    h_oblique = gutils.oblique_project(unity_oblique,
                                                       grad_oblique)

                    # param_state = self.state[p]
                    # if 'momentum_buffer' not in param_state:
                    #    param_state['momentum_buffer'] = torch.zeros(h_hat.size())
                    #    if p.is_cuda:
                    #      param_state['momentum_buffer'] = param_state['momentum_buffer'].cuda()

                    # mom = param_state['momentum_buffer']
                    # mom_new = momentum*mom - group['lr']*h_hat

                    p_grassmann.data.copy_(
                        _apply_dense_on_grassmann_with_noise(
                            grad_clip, h_grassmann, unity_grassmann,
                            learning_rate, times,
                            variance).view(p_grassmann.size()))

                    p_oblique.data.copy_(
                        _apply_dense_on_oblique_with_noise(
                            grad_clip, h_oblique, unity_oblique, learning_rate,
                            times, variance).view(p_oblique.size()))

            elif manifold == "None":
                # This routine is from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
                weight_decay = group['weight_decay']
                # dampening = group['dampening']
                # nesterov = group['nesterov']
                for p in group['params']:
                    if p.grad is None:
                        continue
                    d_p = p.grad.data
                    if weight_decay != 0:
                        d_p.add_(weight_decay, p.data)
                    # if momentum != 0:
                    #    param_state = self.state[p]
                    #    if 'momentum_buffer' not in param_state:
                    #        buf = param_state['momentum_buffer'] = d_p.clone()
                    #    else:
                    #        buf = param_state['momentum_buffer']
                    #        buf.mul_(momentum).add_(1 - dampening, d_p)
                    #    if nesterov:
                    #        d_p = d_p.add(momentum, buf)
                    #    else:
                    #        d_p = buf

                    p.data.add_(-group['lr'], d_p)

            else:
                raise ValueError("There is no such a manifold")

        return loss
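
This step() assumes a specific layout of the manifold parameter group: the Grassmann copies of the convolution weights occupy the first half of group['params'] and the matching oblique copies the second half, which is how params_total is built in Example No. 9. A minimal, hypothetical sketch of that layout (names, shapes, and values are illustrative only):

import torch

# Grassmann copies first, oblique copies second, so index i and
# i + length // 2 address the same layer on the two manifolds.
grassmann_copies = [torch.nn.Parameter(torch.randn(16, 27)) for _ in range(2)]
oblique_copies = [torch.nn.Parameter(torch.randn(16, 27)) for _ in range(2)]

group = {
    'params': grassmann_copies + oblique_copies,
    'lr': 0.2,             # manifold learning rate (illustrative value)
    'manifold': 'True',
    'grad_clip': 0.1,
    'variance': 0.01,
    'times': 0,
}
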
Example No. 9
def main():
    opt = parser.parse_args()
    print('parsed options:', vars(opt))
    epoch_step = json.loads(opt.epoch_step)
    num_classes = 10 if opt.dataset == 'CIFAR10' else 100

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    # to prevent opencv from initializing CUDA in workers
    torch.randn(8).cuda()
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

    def create_iterator(mode):
        ds = create_dataset(opt, mode)
        return ds.parallel(batch_size=opt.batchSize,
                           shuffle=mode,
                           num_workers=opt.nthread,
                           pin_memory=True)

    train_loader = create_iterator(True)
    test_loader = create_iterator(False)

    if opt.optim_method in ('SGDM', 'SGDE', 'SGDN'):

        f_grassmann, params_grassmann, stats_grassmann = resnet.resnet_grassmann(
            opt.depth, opt.width, num_classes)
        f_oblique, params_oblique, stats_oblique = resnet.resnet_oblique(
            opt.depth, opt.width, num_classes)

        key_g = []
        key_o = []

        param_g = []
        param_g_e0 = []
        param_g_e1 = []

        param_o = []
        param_o_e0 = []
        param_o_e1 = []

        params_total = []

        for key, value in params_grassmann.items():
            if 'conv' in key and value.size()[0] < np.prod(value.size()[1:]):
                params_total.append(value)
                key_g.append(key)
                # initialize to scale 1
                unitp, _ = unit(value.data.view(value.size(0), -1))
                value.data.copy_(unitp.view(value.size()))
            elif 'bn' in key or 'bias' in key:
                param_g_e0.append(value)
            else:
                param_g_e1.append(value)

        for key, value in params_oblique.items():
            if 'conv' in key and value.size()[0] < np.prod(value.size()[1:]):
                params_total.append(value)
                key_o.append(key)
                # initialize to scale 1
                unitp, _ = unit(value.data.view(value.size(0), -1))
                value.data.copy_(unitp.view(value.size()))
            elif 'bn' in key or 'bias' in key:
                param_o_e0.append(value)
            else:
                param_o_e1.append(value)

        def create_optimizer(opt, lr, lrm, times):
            print('creating optimizer with lr = ', lrm)

            if opt.optim_method == 'SGDM':
                dict_total = {
                    'params': params_total,
                    'lr': lrm,
                    'manifold': 'True',
                    'grad_clip': opt.grad_clip
                }
                dict_g_e0 = {
                    'params': param_g_e0,
                    'lr': lr,
                    'weight_decay': opt.bnDecay,
                    'manifold': 'None'
                }
                dict_g_e1 = {
                    'params': param_g_e1,
                    'lr': lr,
                    'weight_decay': opt.bnDecay,
                    'manifold': 'None'
                }

                dict_o_e0 = {
                    'params': param_o_e0,
                    'lr': lr,
                    'weight_decay': opt.bnDecay,
                    'manifold': 'None',
                    'label': 'oblique'
                }
                dict_o_e1 = {
                    'params': param_o_e1,
                    'lr': lr,
                    'weight_decay': opt.bnDecay,
                    'manifold': 'None',
                    'label': 'oblique'
                }

                return optimize_function.SGDM(
                    [dict_total, dict_g_e0, dict_g_e1, dict_o_e0, dict_o_e1])

            elif opt.optim_method == 'SGDE':
                dict_total = {
                    'params': params_total,
                    'times': times,
                    'lr': lrm,
                    'manifold': 'True',
                    'grad_clip': opt.grad_clip
                }
                dict_g_e0 = {
                    'params': param_g_e0,
                    'lr': lr,
                    'weight_decay': opt.bnDecay,
                    'manifold': 'None'
                }
                dict_g_e1 = {
                    'params': param_g_e1,
                    'lr': lr,
                    'weight_decay': opt.bnDecay,
                    'manifold': 'None'
                }
                dict_o_e0 = {
                    'params': param_o_e0,
                    'lr': lr,
                    'weight_decay': opt.bnDecay,
                    'manifold': 'None',
                    'label': 'oblique'
                }
                dict_o_e1 = {
                    'params': param_o_e1,
                    'lr': lr,
                    'weight_decay': opt.bnDecay,
                    'manifold': 'None',
                    'label': 'oblique'
                }

                return optimize_function.SGDM(
                    [dict_total, dict_g_e0, dict_g_e1, dict_o_e0, dict_o_e1])

            elif opt.optim_method == 'SGDN':
                dict_total = {
                    'params': params_total,
                    'lr': lrm,
                    'times': times,
                    'manifold': 'True',
                    'grad_clip': opt.grad_clip
                }
                dict_g_e0 = {
                    'params': param_g_e0,
                    'lr': lr,
                    'weight_decay': opt.bnDecay,
                    'manifold': 'None'
                }
                dict_g_e1 = {
                    'params': param_g_e1,
                    'lr': lr,
                    'weight_decay': opt.bnDecay,
                    'manifold': 'None'
                }

                dict_o_e0 = {
                    'params': param_o_e0,
                    'lr': lr,
                    'weight_decay': opt.bnDecay,
                    'manifold': 'None',
                    'label': 'oblique'
                }
                dict_o_e1 = {
                    'params': param_o_e1,
                    'lr': lr,
                    'weight_decay': opt.bnDecay,
                    'manifold': 'None',
                    'label': 'oblique'
                }

                return optimize_function.SGDM(
                    [dict_total, dict_g_e0, dict_g_e1, dict_o_e0, dict_o_e1])

        epoch = 0
        optimizer = create_optimizer(opt, opt.lr, opt.lrm, epoch)

        if opt.resume != '':
            state_dict = torch.load(opt.resume)
            epoch = state_dict['epoch']
            if opt.optim_method != 'SGD':
                params_tensor, stats, manifold, label = state_dict['params'], state_dict['stats'], state_dict['manifold'], \
                                                        state_dict['label']
                size = manifold.size()[0]

                for i in range(size):

                    if state_dict['manifold'][i] != 'None':
                        length = params_tensor[i].size()[0] // 2

                        tmp_grassmann = list(params_grassmann.values())
                        tmp_grassmann[i].data.copy_(params_tensor[i][0:length])

                        tmp_oblique = list(params_oblique.values())
                        tmp_oblique[i].data.copy_(params_tensor[i][length:])

                    elif state_dict['label'][i] == 'grassmann':

                        tmp_grassmann = list(params_grassmann.values())
                        tmp_grassmann[i].data.copy_(params_tensor[i])

                    else:
                        tmp_oblique = list(params_oblique.values())
                        tmp_oblique[i].data.copy_(params_tensor[i])

                optimizer.load_state_dict(state_dict['optimizer'])

        print('\nParameters:')
        kmax = max(len(key) for key in params_grassmann.keys())
        for i, (key, v) in enumerate(params_grassmann.items()):
            print(str(i).ljust(5),
                  key.ljust(kmax + 3),
                  str(tuple(v.size())).ljust(23),
                  torch.typename(v.data),
                  end='')
            print(' on G(1,n)' if key in key_g else '')

        meter_loss_ensemble = tnt.meter.AverageValueMeter()
        classacc_ensemble = tnt.meter.ClassErrorMeter(accuracy=True)

        timer_train = tnt.meter.TimeMeter('s')
        timer_test = tnt.meter.TimeMeter('s')

        #print('\nAdditional buffers:')
        #kmax = max(len(key) for key in stats.keys())
        #for i, (key, v) in enumerate(stats.items()):
        #    print(str(i).ljust(5), key.ljust(kmax + 3), str(tuple(v.size())).ljust(23), torch.typename(v))

        #    n_parameters = sum(p.numel() for p in params.values() + stats.values())
        #n_training_params = sum(p.numel() for p in params.values())
        #n_parameters = sum(p.numel() for p in params.values()) + sum(p.numel() for p in stats.values())
        #print('Total number of parameters:', n_parameters, '(%d)'%n_training_params)

        if not os.path.exists(opt.save):
            os.mkdir(opt.save)

        def h_ensemble(sample):
            inputs = Variable(cast(sample[0], opt.dtype))
            targets = Variable(cast(sample[1], 'long'))
            y_grassmann = data_parallel(f_grassmann, inputs, params_grassmann,
                                        stats_grassmann, sample[2],
                                        np.arange(opt.ngpu))
            y_oblique = data_parallel(f_oblique, inputs, params_oblique,
                                      stats_oblique, sample[2],
                                      np.arange(opt.ngpu))
            y_ensemble = y_grassmann + y_oblique
            return F.cross_entropy(y_ensemble, targets), y_ensemble

        def log_grassmann(t, state):
            #        torch.save(dict(params={k: v.data for k, v in params.iteritems()},
            torch.save(
                dict(params_grassmann={
                    k: v.data
                    for k, v in list(params_grassmann.items())
                },
                     stats_grassmann=stats_grassmann,
                     optimizer=state['optimizer'].state_dict(),
                     epoch=t['epoch']),
                open(os.path.join(opt.save_grassmann, 'model.pt7'), 'wb'))
            z = vars(opt).copy()
            z.update(t)
            logname = os.path.join(opt.save_grassmann, 'log.txt')
            with open(logname, 'a') as f:
                f.write('json_stats: ' + json.dumps(z) + '\n')
            print(z)

        def log_oblique(t, state):
            #        torch.save(dict(params={k: v.data for k, v in params.iteritems()},
            torch.save(
                dict(params_oblique={
                    k: v.data
                    for k, v in list(params_oblique.items())
                },
                     stats_oblique=stats_oblique,
                     optimizer=state['optimizer'].state_dict(),
                     epoch=t['epoch']),
                open(os.path.join(opt.save_oblique, 'model.pt7'), 'wb'))
            z = vars(opt).copy()
            z.update(t)
            logname = os.path.join(opt.save_oblique, 'log.txt')
            with open(logname, 'a') as f:
                f.write('json_stats: ' + json.dumps(z) + '\n')
            print(z)

        def on_sample(state):
            state['sample'].append(state['train'])

        def on_forward(state):
            classacc_ensemble.add(state['output'].data,
                                  torch.LongTensor(state['sample'][1]))
            meter_loss_ensemble.add(state['loss'].data[0])

        def on_start(state):
            state['epoch'] = epoch

        def on_start_epoch(state):
            classacc_ensemble.reset()
            meter_loss_ensemble.reset()

            timer_train.reset()
            state['iterator'] = tqdm(train_loader)

            if epoch in epoch_step:
                power = sum(epoch >= i for i in epoch_step)
                lr = opt.lr * pow(opt.lr_decay_ratio, power)
                lrm = opt.lrm * pow(opt.lr_decay_ratio, power)
                times = opt.times + 1
                state['optimizer'] = create_optimizer(opt, lr, lrm, times)

        def on_end_epoch(state):

            train_loss_ensemble = meter_loss_ensemble.value()
            train_acc_ensemble = classacc_ensemble.value()

            train_time = timer_train.value()

            meter_loss_ensemble.reset()
            classacc_ensemble.reset()

            timer_test.reset()

            engine.test(h_ensemble, test_loader)

            test_acc_ensemble = classacc_ensemble.value()[0]

            print(
                log_grassmann(
                    {
                        "train_loss_total": train_loss_ensemble[0],
                        "train_acc_ensemble": train_acc_ensemble[0],
                        "test_loss_ensemble": meter_loss_ensemble.value()[0],
                        "test_acc_ensemble": test_acc_ensemble,
                        "epoch": state['epoch'],
                        "num_classes": num_classes,
                        # "n_parameters": n_parameters,
                        "train_time": train_time,
                        "test_time": timer_test.value(),
                    },
                    state))

            print(
                log_oblique(
                    {
                        "train_loss_ensemble": train_loss_ensemble[0],
                        "train_acc_ensemble": train_acc_ensemble[0],
                        "test_loss_ensemble": meter_loss_ensemble.value()[0],
                        "test_acc_ensemble": test_acc_ensemble,
                        "epoch": state['epoch'],
                        "num_classes": num_classes,
                        # "n_parameters": n_parameters,
                        "train_time": train_time,
                        "test_time": timer_test.value(),
                    },
                    state))
            print(
                '==> id: %s (%d/%d), test_acc_ensemble: \33[91m%.2f\033[0m' % \
                (opt.save, state['epoch'], opt.epochs, test_acc_ensemble))

        engine = Engine()
        engine.hooks['on_sample'] = on_sample
        engine.hooks['on_forward'] = on_forward
        engine.hooks['on_start_epoch'] = on_start_epoch
        engine.hooks['on_end_epoch'] = on_end_epoch
        engine.hooks['on_start'] = on_start
        engine.train(h_ensemble, train_loader, opt.epochs, optimizer)

    else:
        f, params, stats = resnet.resnet(opt.depth, opt.width, num_classes)

        def create_optimizer(opt, lr):
            print('creating optimizer with lr = ', lr)
            if opt.optim_method == 'SGD':
                return torch.optim.SGD(params.values(),
                                       lr,
                                       weight_decay=opt.weightDecay)

        epoch = 0
        optimizer = create_optimizer(opt, opt.lr)

        if opt.resume != '':
            state_dict = torch.load(opt.resume)
            epoch = state_dict['epoch']
            params_tensors, stats = state_dict['params'], state_dict['stats']
            #        for k, v in params.iteritems():
            for k, v in list(params.items()):
                v.data.copy_(params_tensors[k])
            optimizer.load_state_dict(state_dict['optimizer'])

        meter_loss = tnt.meter.AverageValueMeter()
        classacc = tnt.meter.ClassErrorMeter(accuracy=True)

        timer_train = tnt.meter.TimeMeter('s')
        timer_test = tnt.meter.TimeMeter('s')

        if not os.path.exists(opt.save):
            os.mkdir(opt.save)

        def h(sample):
            inputs = Variable(cast(sample[0], opt.dtype))
            targets = Variable(cast(sample[1], 'long'))
            y = data_parallel(f, inputs, params, stats, sample[2],
                              np.arange(opt.ngpu))
            return F.cross_entropy(y, targets), y

        def log(t, state):
            #        torch.save(dict(params={k: v.data for k, v in params.iteritems()},
            torch.save(
                dict(params={k: v.data
                             for k, v in list(params.items())},
                     stats=stats,
                     optimizer=state['optimizer'].state_dict(),
                     epoch=t['epoch']),
                open(os.path.join(opt.save, 'model.pt7'), 'wb'))
            z = vars(opt).copy()
            z.update(t)
            logname = os.path.join(opt.save, 'log.txt')
            with open(logname, 'a') as f:
                f.write('json_stats: ' + json.dumps(z) + '\n')
            print(z)

        def on_sample(state):
            state['sample'].append(state['train'])

        def on_forward(state):
            classacc.add(state['output'].data,
                         torch.LongTensor(state['sample'][1]))
            meter_loss.add(state['loss'].data[0])

        def on_start(state):
            state['epoch'] = epoch

        def on_start_epoch(state):

            classacc.reset()
            meter_loss.reset()
            timer_train.reset()
            state['iterator'] = tqdm(train_loader)

            epoch = state['epoch'] + 1

            if epoch in epoch_step:
                power = sum(epoch >= i for i in epoch_step)
                lr = opt.lr * pow(opt.lr_decay_ratio, power)
                #lrg = opt.lrg * pow(opt.lr_decay_ratio, power)
                state['optimizer'] = create_optimizer(opt, lr)


#            lr = state['optimizer'].param_groups[0]['lr']
#            lrm = state['optimizer'].param_groups[0]['lrm']
#            state['optimizer'] = create_optimizer(opt,
#                                          lr * opt.lr_decay_ratio,
#                                          lrm * opt.lr_decay_ratio)

        def on_end_epoch(state):

            train_loss = meter_loss.value()
            train_acc = classacc.value()
            train_time = timer_train.value()
            meter_loss.reset()
            classacc.reset()
            timer_test.reset()

            engine.test(h, test_loader)
            test_acc = classacc.value()[0]
            print(
                log(
                    {
                        "train_loss": train_loss[0],
                        "train_acc": train_acc[0],
                        "test_loss": meter_loss.value()[0],
                        "test_acc": test_acc,
                        "epoch": state['epoch'],
                        "num_classes": num_classes,
                        #"n_parameters": n_parameters,
                        "train_time": train_time,
                        "test_time": timer_test.value(),
                    },
                    state))
            print('==> id: %s (%d/%d), test_acc: \33[91m%.2f\033[0m' % \
                  (opt.save, state['epoch'], opt.epochs, test_acc))

        engine = Engine()
        engine.hooks['on_sample'] = on_sample
        engine.hooks['on_forward'] = on_forward
        engine.hooks['on_start_epoch'] = on_start_epoch
        engine.hooks['on_end_epoch'] = on_end_epoch
        engine.hooks['on_start'] = on_start
        engine.train(h, train_loader, opt.epochs, optimizer)
Example No. 10
def main():
    opt = parser.parse_args()
    print('parsed options:', vars(opt))
    epoch_step = json.loads(opt.epoch_step)
    num_classes = 100 if opt.dataset == 'CIFAR100' else 10

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    # to prevent opencv from initializing CUDA in workers
    torch.randn(8).cuda()
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

    def create_iterator(mode):
        ds = create_dataset(opt, mode)
        return ds.parallel(batch_size=opt.batchSize,
                           shuffle=mode,
                           num_workers=opt.nthread,
                           pin_memory=True)

    train_loader = create_iterator(True)
    test_loader = create_iterator(False)

    if opt.model == 'resnet':
        model = resnet
    elif opt.model == 'vgg':
        model = vgg

    f, params, stats = model(opt.depth, opt.width, num_classes)

    key_g = []
    if opt.optim_method in ['SGDG', 'AdamG', 'Cayley_SGD', 'Cayley_Adam']:
        param_g = []
        param_e0 = []
        param_e1 = []

        for key, value in params.items():
            if 'conv' in key and value.size()[0] <= np.prod(value.size()[1:]):
                param_g.append(value)
                key_g.append(key)
                if opt.optim_method in ['SGDG', 'AdamG']:
                    # initialize to scale 1
                    unitp, _ = unit(value.data.view(value.size(0), -1))
                    value.data.copy_(unitp.view(value.size()))
                elif opt.optim_method in ['Cayley_SGD', 'Cayley_Adam']:
                    # initialize to an orthogonal matrix
                    q = qr_retraction(value.data.view(value.size(0), -1))
                    value.data.copy_(q.view(value.size()))
            elif 'bn' in key or 'bias' in key:
                param_e0.append(value)
            else:
                param_e1.append(value)

    def create_optimizer(opt, lr, lrg):
        print('creating optimizer with lr = ', lr, ' lrg = ', lrg)
        if opt.optim_method == 'SGD':
            return torch.optim.SGD(params.values(),
                                   lr,
                                   0.9,
                                   weight_decay=opt.weightDecay)

        elif opt.optim_method == 'SGDG':
            dict_g = {
                'params': param_g,
                'lr': lrg,
                'momentum': 0.9,
                'grassmann': True,
                'omega': opt.omega,
                'grad_clip': opt.grad_clip
            }
            dict_e0 = {
                'params': param_e0,
                'lr': lr,
                'momentum': 0.9,
                'grassmann': False,
                'weight_decay': opt.bnDecay,
                'nesterov': True
            }
            dict_e1 = {
                'params': param_e1,
                'lr': lr,
                'momentum': 0.9,
                'grassmann': False,
                'weight_decay': opt.weightDecay,
                'nesterov': True
            }
            return grassmann_optimizer.SGDG([dict_g, dict_e0, dict_e1])

        elif opt.optim_method == 'AdamG':
            dict_g = {
                'params': param_g,
                'lr': lrg,
                'momentum': 0.9,
                'grassmann': True,
                'omega': opt.omega,
                'grad_clip': opt.grad_clip
            }
            dict_e0 = {
                'params': param_e0,
                'lr': lr,
                'momentum': 0.9,
                'grassmann': False,
                'weight_decay': opt.bnDecay,
                'nesterov': True
            }
            dict_e1 = {
                'params': param_e1,
                'lr': lr,
                'momentum': 0.9,
                'grassmann': False,
                'weight_decay': opt.weightDecay,
                'nesterov': True
            }
            return grassmann_optimizer.AdamG([dict_g, dict_e0, dict_e1])

        elif opt.optim_method == 'Cayley_SGD':
            dict_g = {
                'params': param_g,
                'lr': lrg,
                'momentum': 0.9,
                'stiefel': True
            }
            dict_e0 = {
                'params': param_e0,
                'lr': lr,
                'momentum': 0.9,
                'stiefel': False,
                'weight_decay': opt.bnDecay,
                'nesterov': True
            }
            dict_e1 = {
                'params': param_e1,
                'lr': lr,
                'momentum': 0.9,
                'stiefel': False,
                'weight_decay': opt.weightDecay,
                'nesterov': True
            }
            return stiefel_optimizer.SGDG([dict_g, dict_e0, dict_e1])

        elif opt.optim_method == 'Cayley_Adam':
            dict_g = {
                'params': param_g,
                'lr': lrg,
                'momentum': 0.9,
                'stiefel': True
            }
            dict_e0 = {
                'params': param_e0,
                'lr': lr,
                'momentum': 0.9,
                'stiefel': False,
                'weight_decay': opt.bnDecay,
                'nesterov': True
            }
            dict_e1 = {
                'params': param_e1,
                'lr': lr,
                'momentum': 0.9,
                'stiefel': False,
                'weight_decay': opt.weightDecay,
                'nesterov': True
            }
            return stiefel_optimizer.AdamG([dict_g, dict_e0, dict_e1])

    optimizer = create_optimizer(opt, opt.lr, opt.lrg)

    epoch = 0
    if opt.resume != '':
        state_dict = torch.load(opt.resume)
        epoch = state_dict['epoch']
        params_tensors, stats = state_dict['params'], state_dict['stats']
        for k, v in list(params.items()):
            v.data.copy_(params_tensors[k])
        optimizer.load_state_dict(state_dict['optimizer'])

    print('\nParameters:')
    kmax = max(len(key) for key in params.keys())
    for i, (key, v) in enumerate(params.items()):
        print(str(i).ljust(5),
              key.ljust(kmax + 3),
              str(tuple(v.size())).ljust(23),
              torch.typename(v.data),
              end='')
        print(' on G(1,n)' if key in key_g else '')

    print('\nAdditional buffers:')
    kmax = max(len(key) for key in stats.keys())
    for i, (key, v) in enumerate(stats.items()):
        print(
            str(i).ljust(5), key.ljust(kmax + 3),
            str(tuple(v.size())).ljust(23), torch.typename(v))

    n_training_params = sum(p.numel() for p in params.values())
    n_parameters = sum(p.numel()
                       for p in params.values()) + sum(p.numel()
                                                       for p in stats.values())
    print('Total number of parameters:', n_parameters,
          '(%d)' % n_training_params)

    meter_loss = tnt.meter.AverageValueMeter()
    classacc = tnt.meter.ClassErrorMeter(accuracy=True)
    timer_train = tnt.meter.TimeMeter('s')
    timer_test = tnt.meter.TimeMeter('s')

    if not os.path.exists(opt.save):
        os.mkdir(opt.save)

    def h(sample):
        inputs = Variable(cast(sample[0], opt.dtype))
        targets = Variable(cast(sample[1], 'long'))
        y = data_parallel(f, inputs, params, stats, sample[2],
                          np.arange(opt.ngpu))
        return F.cross_entropy(y, targets), y

    def log(t, state):
        torch.save(
            dict(params={k: v.data
                         for k, v in list(params.items())},
                 stats=stats,
                 optimizer=state['optimizer'].state_dict(),
                 epoch=t['epoch']),
            open(os.path.join(opt.save, 'model.pt7'), 'wb'))
        z = vars(opt).copy()
        z.update(t)
        logname = os.path.join(opt.save, 'log.txt')
        with open(logname, 'a') as f:
            f.write('json_stats: ' + json.dumps(z) + '\n')
        print(z)

    def on_sample(state):
        state['sample'].append(state['train'])

    def on_forward(state):
        classacc.add(state['output'].data,
                     torch.LongTensor(state['sample'][1]))
        meter_loss.add(state['loss'].data.item())

    def on_start(state):
        state['epoch'] = epoch

    def on_start_epoch(state):
        classacc.reset()
        meter_loss.reset()
        timer_train.reset()
        state['iterator'] = tqdm(train_loader)

        epoch = state['epoch'] + 1
        if epoch in epoch_step:
            power = sum(epoch >= i for i in epoch_step)
            lr = opt.lr * pow(opt.lr_decay_ratio, power)
            lrg = opt.lrg * pow(opt.lr_decay_ratio, power)
            state['optimizer'] = create_optimizer(opt, lr, lrg)

    def on_end_epoch(state):
        train_loss = meter_loss.value()
        train_acc = classacc.value()
        train_time = timer_train.value()
        meter_loss.reset()
        classacc.reset()
        timer_test.reset()

        engine.test(h, test_loader)

        test_acc = classacc.value()[0]
        print(
            log(
                {
                    "train_loss": train_loss[0],
                    "train_acc": train_acc[0],
                    "test_loss": meter_loss.value()[0],
                    "test_acc": test_acc,
                    "epoch": state['epoch'],
                    "num_classes": num_classes,
                    "n_parameters": n_parameters,
                    "train_time": train_time,
                    "test_time": timer_test.value(),
                }, state))
        print('==> id: %s (%d/%d), test_acc: \33[91m%.2f\033[0m' % \
                (opt.save, state['epoch'], opt.epochs, test_acc))

    engine = Engine()
    engine.hooks['on_sample'] = on_sample
    engine.hooks['on_forward'] = on_forward
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_epoch'] = on_end_epoch
    engine.hooks['on_start'] = on_start
    engine.train(h, train_loader, opt.epochs, optimizer)
Example No. 11
def train(LEARNING_RATE_BASE, MODEL_SAVE_PATH, FILE_SAVE_PATH):
    data, labels = reader.unpickle(reader.file)
    file_path_loss_g = os.path.join(
        FILE_SAVE_PATH, ('loss_g_' + str(LEARNING_RATE_GRASSMANN) + '.txt'))
    file_path_loss_o = os.path.join(
        FILE_SAVE_PATH, ('loss_o_' + str(LEARNING_RATE_OBLIQUE) + '.txt'))

    file_path_norm = os.path.join(FILE_SAVE_PATH,
                                  ('norm' + str(LEARNING_RATE_BASE) + '.txt'))

    file1_path = os.path.join(FILE_SAVE_PATH,
                              ('accuracy_' + str(LEARNING_RATE_BASE) + '.txt'))
    file1_path_g = os.path.join(
        FILE_SAVE_PATH,
        ('accuracy_g_' + str(LEARNING_RATE_GRASSMANN) + '.txt'))
    file1_path_o = os.path.join(
        FILE_SAVE_PATH, ('accuracy_o_' + str(LEARNING_RATE_OBLIQUE) + '.txt'))

    file_loss_g = open(file_path_loss_g, 'w')
    file_loss_o = open(file_path_loss_o, 'w')

    file_norm = open(file_path_norm, 'w')

    file_accuracy = open(file1_path, 'w')
    file_accuracy_g = open(file1_path_g, 'w')
    file_accuracy_o = open(file1_path_o, 'w')

    x = tf.placeholder(tf.float32,
                       shape=[None, cifar10_ensemble.INPUT_NODE],
                       name="x-input")
    y_ = tf.placeholder(tf.float32,
                        shape=[None, cifar10_ensemble.OUTPUT_NODE],
                        name="y-output")
    x_reshaped = tf.reshape(x, [
        -1, cifar10_ensemble.IMAGE_SIZE, cifar10_ensemble.IMAGE_SIZE,
        cifar10_ensemble.NUM_CHANNELS
    ])
    times = tf.placeholder(tf.float32, shape=None, name="times")

    #GRAD_CLIP=tf.constant(1.0,dtype=tf.float32)

    # L2 regularization
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)

    y_g, y_o = cifar10_ensemble.inference(x_reshaped, False, regularizer)
    global_step = tf.Variable(0, trainable=False)

    # Define the loss function, moving averages, etc.
    variable_averages = tf.train.ExponentialMovingAverage(
        MOVING_AVERAGE_DECAY, global_step)
    #variable_averages_op=variable_averages.apply(tf.trainable_variables())
    cross_entropy_g = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.argmax(y_, 1), logits=y_g)
    cross_entropy_o = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.argmax(y_, 1), logits=y_o)

    cross_entropy_mean_g = tf.reduce_mean(cross_entropy_g)
    cross_entropy_mean_o = tf.reduce_mean(cross_entropy_o)

    # Loss; the commented-out term would add the regularization losses collected above
    loss_g = cross_entropy_mean_g  #+tf.add_n(tf.get_collection('losses_g'))
    loss_o = cross_entropy_mean_o  #+tf.add_n(tf.get_collection('losses_o'))

    learning_rate_g = tf.train.exponential_decay(LEARNING_RATE_GRASSMANN,
                                                 global_step,
                                                 50000 / BATCH_SIZE,
                                                 LEARNING_RATE_DECAY)
    learning_rate_o = tf.train.exponential_decay(LEARNING_RATE_OBLIQUE,
                                                 global_step,
                                                 50000 / BATCH_SIZE,
                                                 LEARNING_RATE_DECAY)
    #learning_rate=LEARNING_RATE_BASE
    # Update the parameters

    #train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step)
    # Run the moving-average update alongside the training step
    #with tf.control_dependencies([train_step,variable_averages_op]):
    #train_op=tf.no_op(name='train')
    correct_prediction_g = tf.equal(tf.argmax(y_, 1), tf.argmax(y_g, 1))
    correct_prediction_o = tf.equal(tf.argmax(y_, 1), tf.argmax(y_o, 1))

    correct_prediction = tf.equal(tf.argmax(y_, 1),
                                  tf.argmax(tf.add(y_g, y_o), 1))

    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    accuracy_g = tf.reduce_mean(tf.cast(correct_prediction_g, tf.float32))
    accuracy_o = tf.reduce_mean(tf.cast(correct_prediction_o, tf.float32))
    #########################################################################################################
    with tf.variable_scope('layer1-conv1', reuse=True):
        conv1_weights_g = tf.get_variable("weight_g")
        conv1_biases_g = tf.get_variable('biases_g')

        conv1_weights_o = tf.get_variable("weight_o")
        conv1_biases_o = tf.get_variable('biases_o')

        conv1_weights_g_tmp_layer1 = tf.get_variable("weight_g_tmp")
        conv1_weights_o_tmp_layer1 = tf.get_variable("weight_o_tmp")

        conv1_biases_g_tmp = tf.get_variable("biases_g_tmp")
        conv1_biases_o_tmp = tf.get_variable("biases_o_tmp")

        dim_layer1 = conv1_weights_g.get_shape()

        weights_grad_g_base_layer1 = tf.gradients(
            loss_g, conv1_weights_g, stop_gradients=conv1_weights_g)
        weights_grad_o_base_layer1 = tf.gradients(
            loss_o, conv1_weights_o, stop_gradients=conv1_weights_o)

        weights_grad_g_base_biases_layer1 = tf.gradients(
            loss_g, conv1_biases_g, stop_gradients=conv1_biases_g)
        weights_grad_o_base_biases_layer1 = tf.gradients(
            loss_o, conv1_biases_o, stop_gradients=conv1_biases_o)

        weights_g_1 = tf.reshape(conv1_weights_g, shape=[-1, 1])
        weights_o_1 = tf.reshape(conv1_weights_o, shape=[-1, 1])

        tf.convert_to_tensor(weights_grad_g_base_layer1[0], dtype=tf.float32)
        tf.convert_to_tensor(weights_grad_o_base_layer1[0], dtype=tf.float32)

        weights_grad_g_base_1 = tf.reshape(weights_grad_g_base_layer1[0],
                                           shape=[-1, 1])
        weights_grad_o_base_l = tf.reshape(weights_grad_o_base_layer1[0],
                                           shape=[-1, 1])

        grad_on_grassmann_1 = gutils.grassmann_project(weights_g_1,
                                                       weights_grad_g_base_1)
        grad_on_oblique_1 = gutils.oblique_project(weights_o_1,
                                                   weights_grad_o_base_l)

        weights_g_layer1 = optimize_function.apply_dense_on_grasssmann(
            GRAD_CLIP, grad_on_grassmann_1, grad_on_oblique_1, weights_g_1,
            learning_rate_g, times, DELTA)

        weights_o_layer1 = optimize_function._apply_dense_on_oblique(
            GRAD_CLIP, grad_on_grassmann_1, grad_on_oblique_1, weights_o_1,
            learning_rate_o, times, DELTA)

        #weights_g_layer1 = weights_g_1 - learning_rate_g * weights_grad_g_base_1
        #weights_o_layer1 = weights_o_1 - learning_rate_o * weights_grad_o_base_l

        weights_biases_g_layer1 = tf.add(
            -1 * learning_rate_g * tf.convert_to_tensor(
                weights_grad_g_base_biases_layer1[0], tf.float32),
            conv1_biases_g)
        weights_biases_o_layer1 = tf.add(
            -1 * learning_rate_o * tf.convert_to_tensor(
                weights_grad_o_base_biases_layer1[0], tf.float32),
            conv1_biases_o)

        norm_g_1 = tf.square(gutils.norm(grad_on_grassmann_1))
        norm_o_1 = tf.square(gutils.norm(grad_on_oblique_1))

    with tf.variable_scope('layer3-conv2', reuse=True):
        conv2_weights_g = tf.get_variable("weight_g")
        conv2_biases_g = tf.get_variable('biases_g')
        conv2_weights_o = tf.get_variable("weight_o")
        conv2_biases_o = tf.get_variable('biases_o')

        conv2_weights_g_tmp_layer3 = tf.get_variable("weight_g_tmp")
        conv2_weights_o_tmp_layer3 = tf.get_variable("weight_o_tmp")

        conv2_biases_g_tmp = tf.get_variable("biases_g_tmp")
        conv2_biases_o_tmp = tf.get_variable("biases_o_tmp")

        dim_layer3 = conv2_weights_g.get_shape()

        weights_grad_g_base_3 = tf.gradients(loss_g,
                                             conv2_weights_g,
                                             stop_gradients=conv2_weights_g)
        weights_grad_o_base_3 = tf.gradients(loss_o,
                                             conv2_weights_o,
                                             stop_gradients=conv2_weights_o)

        weights_grad_g_base_biases_layer3 = tf.gradients(
            loss_g, conv2_biases_g, stop_gradients=conv2_biases_g)
        weights_grad_o_base_biases_layer3 = tf.gradients(
            loss_o, conv2_biases_o, stop_gradients=conv2_biases_o)

        weights_g_3 = tf.reshape(conv2_weights_g, shape=[-1, 1])
        weights_o_3 = tf.reshape(conv2_weights_o, shape=[-1, 1])

        tf.convert_to_tensor(weights_grad_g_base_3[0], dtype=tf.float32)
        tf.convert_to_tensor(weights_grad_o_base_3[0], dtype=tf.float32)

        weights_grad_g_3 = tf.reshape(weights_grad_g_base_3[0], shape=[-1, 1])
        weights_grad_o_3 = tf.reshape(weights_grad_o_base_3[0], shape=[-1, 1])

        grad_on_grassmann_3 = gutils.grassmann_project(weights_g_3,
                                                       weights_grad_g_3)
        grad_on_oblique_3 = gutils.oblique_project(weights_o_3,
                                                   weights_grad_o_3)

        weights_g_layer3 = optimize_function.apply_dense_on_grasssmann(
            GRAD_CLIP, grad_on_grassmann_3, grad_on_oblique_3, weights_g_3,
            learning_rate_g, times, DELTA)
        weights_o_layer3 = optimize_function._apply_dense_on_oblique(
            GRAD_CLIP, grad_on_grassmann_3, grad_on_oblique_3, weights_o_3,
            learning_rate_o, times, DELTA)

        #weights_g_layer3 = weights_g_3 - learning_rate_g * weights_grad_g_3
        #weights_o_layer3 = weights_o_3 - learning_rate_o * weights_grad_o_3

        weights_biases_g_layer3 = tf.add(
            -1 * learning_rate_g * tf.convert_to_tensor(
                weights_grad_g_base_biases_layer3[0], tf.float32),
            conv2_biases_g)
        weights_biases_o_layer3 = tf.add(
            -1 * learning_rate_o * tf.convert_to_tensor(
                weights_grad_o_base_biases_layer3[0], tf.float32),
            conv2_biases_o)

        norm_g_3 = tf.square(gutils.norm(grad_on_grassmann_3))
        norm_o_3 = tf.square(gutils.norm(grad_on_oblique_3))

    with tf.variable_scope('layer5-fc1', reuse=True):
        fc1_weights_g = tf.get_variable("weight_g")
        fc1_biases_g = tf.get_variable("biases_g")
        fc1_weights_o = tf.get_variable("weight_o")
        fc1_biases_o = tf.get_variable("biases_o")

        fc1_weights_g_tmp_layer5 = tf.get_variable("weight_g_tmp")
        fc1_weights_o_tmp_layer5 = tf.get_variable("weight_o_tmp")

        fc1_biases_g_tmp = tf.get_variable("biases_g_tmp")
        fc1_biases_o_tmp = tf.get_variable("biases_o_tmp")

        dim_layer5 = fc1_weights_g.get_shape()

        weights_grad_g_base_5 = tf.gradients(loss_g,
                                             fc1_weights_g,
                                             stop_gradients=fc1_weights_g)
        weights_grad_o_base_5 = tf.gradients(loss_o,
                                             fc1_weights_o,
                                             stop_gradients=fc1_weights_o)

        weights_grad_g_base_biases_layer5 = tf.gradients(
            loss_g, fc1_biases_g, stop_gradients=fc1_biases_g)
        weights_grad_o_base_biases_layer5 = tf.gradients(
            loss_o, fc1_biases_o, stop_gradients=fc1_biases_o)

        weights_g_5 = tf.reshape(fc1_weights_g, shape=[-1, 1])
        weights_o_5 = tf.reshape(fc1_weights_o, shape=[-1, 1])

        tf.convert_to_tensor(weights_grad_g_base_5[0], dtype=tf.float32)
        tf.convert_to_tensor(weights_grad_o_base_5[0], dtype=tf.float32)

        weights_grad_g_5 = tf.reshape(weights_grad_g_base_5[0], shape=[-1, 1])
        weights_grad_o_5 = tf.reshape(weights_grad_o_base_5[0], shape=[-1, 1])

        grad_on_grassmann_5 = gutils.grassmann_project(weights_g_5,
                                                       weights_grad_g_5)
        grad_on_oblique_5 = gutils.oblique_project(weights_o_5,
                                                   weights_grad_o_5)

        weights_g_layer5 = optimize_function.apply_dense_on_grasssmann(
            GRAD_CLIP, grad_on_grassmann_5, grad_on_oblique_5, weights_g_5,
            learning_rate_g, times, DELTA)
        weights_o_layer5 = optimize_function._apply_dense_on_oblique(
            GRAD_CLIP, grad_on_grassmann_5, grad_on_oblique_5, weights_o_5,
            learning_rate_o, times, DELTA)

        #weights_g_layer5 = weights_g_5 - learning_rate_g * weights_grad_g_5
        #weights_o_layer5 = weights_o_5 - learning_rate_o * weights_grad_o_5

        weights_biases_g_layer5 = tf.add(
            -1 * learning_rate_g * tf.convert_to_tensor(
                weights_grad_g_base_biases_layer5[0], tf.float32),
            fc1_biases_g)
        weights_biases_o_layer5 = tf.add(
            -1 * learning_rate_o * tf.convert_to_tensor(
                weights_grad_o_base_biases_layer5[0], tf.float32),
            fc1_biases_o)

        norm_g_5 = tf.square(gutils.norm(grad_on_grassmann_5))
        norm_o_5 = tf.square(gutils.norm(grad_on_oblique_5))

    with tf.variable_scope('layer6-fc2', reuse=True):
        fc2_weights_g = tf.get_variable("weight_g")
        fc2_biases_g = tf.get_variable("biases_g")
        fc2_weights_o = tf.get_variable("weight_o")
        fc2_biases_o = tf.get_variable("biases_o")

        fc2_weights_g_tmp_layer6 = tf.get_variable("weight_g_tmp")
        fc2_weights_o_tmp_layer6 = tf.get_variable("weight_o_tmp")

        fc2_biases_g_tmp = tf.get_variable("biases_g_tmp")
        fc2_biases_o_tmp = tf.get_variable("biases_o_tmp")

        dim_layer6 = fc2_weights_g.get_shape()

        weights_grad_g_base_6 = tf.gradients(loss_g,
                                             fc2_weights_g,
                                             stop_gradients=fc2_weights_g)
        weights_grad_o_base_6 = tf.gradients(loss_o,
                                             fc2_weights_o,
                                             stop_gradients=fc2_weights_o)

        weights_grad_g_base_biases_layer6 = tf.gradients(
            loss_g, fc2_biases_g, stop_gradients=fc2_biases_g)
        weights_grad_o_base_biases_layer6 = tf.gradients(
            loss_o, fc2_biases_o, stop_gradients=fc2_biases_o)

        weights_g_6 = tf.reshape(fc2_weights_g, shape=[-1, 1])
        weights_o_6 = tf.reshape(fc2_weights_o, shape=[-1, 1])

        tf.convert_to_tensor(weights_grad_g_base_6[0], dtype=tf.float32)
        tf.convert_to_tensor(weights_grad_o_base_6[0], dtype=tf.float32)

        weights_grad_g_6 = tf.reshape(weights_grad_g_base_6[0], shape=[-1, 1])
        weights_grad_o_6 = tf.reshape(weights_grad_o_base_6[0], shape=[-1, 1])

        grad_on_grassmann_6 = gutils.grassmann_project(weights_g_6,
                                                       weights_grad_g_6)
        grad_on_oblique_6 = gutils.oblique_project(weights_o_6, weights_grad_o_6)

        weights_g_layer6 = optimize_function.apply_dense_on_grasssmann(
            GRAD_CLIP, grad_on_grassmann_6, grad_on_oblique_6, weights_g_6,
            learning_rate_g, times, DELTA)
        weights_o_layer6 = optimize_function._apply_dense_on_oblique(
            GRAD_CLIP, grad_on_grassmann_6, grad_on_oblique_6, weights_o_6,
            learning_rate_o, times, DELTA)

        #weights_g_layer6 = weights_g_6 - learning_rate_g * weights_grad_g_6
        #weights_o_layer6 = weights_o_6 - learning_rate_o * weights_grad_o_6

        weights_biases_g_layer6 = tf.add(
            -1 * learning_rate_g * tf.convert_to_tensor(
                weights_grad_g_base_biases_layer6[0], tf.float32),
            fc2_biases_g)
        weights_biases_o_layer6 = tf.add(
            -1 * learning_rate_o * tf.convert_to_tensor(
                weights_grad_o_base_biases_layer6[0], tf.float32),
            fc2_biases_o)

        norm_g_6 = tf.square(gutils.norm(grad_on_grassmann_6))
        norm_o_6 = tf.square(gutils.norm(grad_on_oblique_6))

    n = norm_g_1 + norm_g_3 + norm_g_5 + norm_g_6 + norm_o_1 + norm_o_3 + norm_o_5 + norm_o_6
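    # n aggregates the squared norms of the projected gradients over all layers (logging only).
    # The assignments below stage the update in two passes: the new weights/biases are first
    # written into the *_tmp variables (so every layer's update is computed from the pre-step
    # values) and are then copied back into the trainable variables.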

    _1 = tf.assign(conv1_weights_g_tmp_layer1,
                   gutils.unit(tf.reshape(weights_g_layer1, shape=dim_layer1)))
    _2 = tf.assign(conv1_weights_o_tmp_layer1,
                   gutils.unit(tf.reshape(weights_o_layer1, shape=dim_layer1)))
    _3 = tf.assign(conv2_weights_g_tmp_layer3,
                   gutils.unit(tf.reshape(weights_g_layer3, shape=dim_layer3)))
    _4 = tf.assign(conv2_weights_o_tmp_layer3,
                   gutils.unit(tf.reshape(weights_o_layer3, shape=dim_layer3)))
    _5 = tf.assign(fc1_weights_g_tmp_layer5,
                   gutils.unit(tf.reshape(weights_g_layer5, shape=dim_layer5)))
    _6 = tf.assign(fc1_weights_o_tmp_layer5,
                   gutils.unit(tf.reshape(weights_o_layer5, shape=dim_layer5)))
    _7 = tf.assign(fc2_weights_g_tmp_layer6,
                   gutils.unit(tf.reshape(weights_g_layer6, shape=dim_layer6)))
    _8 = tf.assign(fc2_weights_o_tmp_layer6,
                   gutils.unit(tf.reshape(weights_o_layer6, shape=dim_layer6)))

    _11 = tf.assign(conv1_biases_g_tmp, weights_biases_g_layer1)
    _12 = tf.assign(conv1_biases_o_tmp, weights_biases_o_layer1)
    _13 = tf.assign(conv2_biases_g_tmp, weights_biases_g_layer3)
    _14 = tf.assign(conv2_biases_o_tmp, weights_biases_o_layer3)
    _15 = tf.assign(fc1_biases_g_tmp, weights_biases_g_layer5)
    _16 = tf.assign(fc1_biases_o_tmp, weights_biases_o_layer5)
    _17 = tf.assign(fc2_biases_g_tmp, weights_biases_g_layer6)
    _18 = tf.assign(fc2_biases_o_tmp, weights_biases_o_layer6)

    _21 = tf.assign(conv1_weights_g, conv1_weights_g_tmp_layer1)
    _22 = tf.assign(conv1_weights_o, conv1_weights_o_tmp_layer1)
    _23 = tf.assign(conv2_weights_g, conv2_weights_g_tmp_layer3)
    _24 = tf.assign(conv2_weights_o, conv2_weights_o_tmp_layer3)
    _25 = tf.assign(fc1_weights_g, fc1_weights_g_tmp_layer5)
    _26 = tf.assign(fc1_weights_o, fc1_weights_o_tmp_layer5)
    _27 = tf.assign(fc2_weights_g, fc2_weights_g_tmp_layer6)
    _28 = tf.assign(fc2_weights_o, fc2_weights_o_tmp_layer6)

    _31 = tf.assign(conv1_biases_g, conv1_biases_g_tmp)
    _32 = tf.assign(conv1_biases_o, conv1_biases_o_tmp)
    _33 = tf.assign(conv2_biases_g, conv2_biases_g_tmp)
    _34 = tf.assign(conv2_biases_o, conv2_biases_o_tmp)
    _35 = tf.assign(fc1_biases_g, fc1_biases_g_tmp)
    _36 = tf.assign(fc1_biases_o, fc1_biases_o_tmp)
    _37 = tf.assign(fc2_biases_g, fc2_biases_g_tmp)
    _38 = tf.assign(fc2_biases_o, fc2_biases_o_tmp)

    norm_1 = gutils.norm(conv1_weights_g)
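    # (norm_1 is diagnostic only; it feeds the commented-out debug print inside the training loop)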
    ######################################################################################################################
    # Initialize the persistence (checkpoint Saver) class
    #saver=tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # Train the model, periodically saving the training results.
        i = 0
        while i <= EPOCH:
            for u in range(TRAINING_STEPS):
                if u * BATCH_SIZE >= 50000:
                    print("run out of all data")
                    break
                xs = data[(u * BATCH_SIZE):((u + 1) * BATCH_SIZE)]
                ys = labels[(u * BATCH_SIZE):((u + 1) * BATCH_SIZE)]
                loss_value_g, loss_value_o, accuracy_value, \
                    accuracy_g_value, accuracy_o_value, step = sess.run(
                        [loss_g, loss_o, accuracy, accuracy_g, accuracy_o,
                         global_step], feed_dict={x: xs, y_: ys})
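                # The losses and accuracies above are evaluated on the current (pre-update) parameters.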
                #****************************************************************
                #print(sess.run(norm_1,feed_dict={x: xs, y_: ys, times: float(u)}))

                sess.run([_1, _2, _3, _4, _5, _6, _7, _8],
                         feed_dict={
                             x: xs,
                             y_: ys,
                             times: float(u)
                         })
                sess.run([_11, _12, _13, _14, _15, _16, _17, _18],
                         feed_dict={
                             x: xs,
                             y_: ys,
                             times: float(u)
                         })
                sess.run([_21, _22, _23, _24, _25, _26, _27, _28])
                sess.run([_31, _32, _33, _34, _35, _36, _37, _38])

                n_value = sess.run(n,
                                   feed_dict={
                                       x: xs,
                                       y_: ys,
                                       times: float(u)
                                   })

                #print(n_value)
                ##########################################################################################################
                file_loss_g.write(str(u) + ' ' + str(loss_value_g) + '\n')
                file_loss_o.write(str(u) + ' ' + str(loss_value_o) + '\n')

                file_accuracy.write(str(u) + ' ' + str(accuracy_value) + '\n')
                file_accuracy_g.write(str(u) + ' ' + str(accuracy_g_value) + '\n')
                file_accuracy_o.write(str(u) + ' ' + str(accuracy_o_value) + '\n')
                # Log the evaluated gradient-norm sum rather than the tensor object.
                file_norm.write(str(u) + ' ' + str(n_value) + '\n')

                if u % 100 == 0:
                    print(
                        "After %d training steps, loss_g and loss_o on the training batch are %g and %g, accuracy is %g"
                        % (u, loss_value_g, loss_value_o, accuracy_value))

                    print(
                        "After %d training steps, accuracy_g and accuracy_o on the training batch are %g and %g"
                        % (u, accuracy_g_value, accuracy_o_value))
                    print(time.localtime(time.time()))
            i = i + 1
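
A minimal sketch of how the four near-identical per-layer blocks above could be factored into one
helper is given below. It is illustrative only: it reuses the same gutils/optimize_function calls and
the GRAD_CLIP and DELTA constants used above, and the helper name manifold_layer_update is
hypothetical, not part of the original code.

def manifold_layer_update(loss_g, loss_o, var_g, var_o,
                          learning_rate_g, learning_rate_o, times):
    """Hypothetical helper: one joint Grassmann/Oblique update for a single layer."""
    dim = var_g.get_shape()

    # Gradients of the two losses with respect to this layer's weight copies.
    grad_g = tf.gradients(loss_g, var_g, stop_gradients=var_g)[0]
    grad_o = tf.gradients(loss_o, var_o, stop_gradients=var_o)[0]

    # Flatten weights and gradients into single columns, as in the per-layer blocks above.
    w_g = tf.reshape(var_g, shape=[-1, 1])
    w_o = tf.reshape(var_o, shape=[-1, 1])
    g_g = tf.reshape(grad_g, shape=[-1, 1])
    g_o = tf.reshape(grad_o, shape=[-1, 1])

    # Project onto the respective tangent spaces and apply the manifold updates.
    grad_on_grassmann = gutils.grassmann_project(w_g, g_g)
    grad_on_oblique = gutils.oblique_project(w_o, g_o)

    new_g = optimize_function.apply_dense_on_grasssmann(
        GRAD_CLIP, grad_on_grassmann, grad_on_oblique, w_g,
        learning_rate_g, times, DELTA)
    new_o = optimize_function._apply_dense_on_oblique(
        GRAD_CLIP, grad_on_grassmann, grad_on_oblique, w_o,
        learning_rate_o, times, DELTA)

    # Reshape back to the original kernel shape and re-normalize, mirroring the assignments above.
    new_g = gutils.unit(tf.reshape(new_g, shape=dim))
    new_o = gutils.unit(tf.reshape(new_o, shape=dim))
    return new_g, new_o

With such a helper, each variable-scope block would shrink to a single call per weight pair, and the
bias updates could be handled the same way with a plain gradient step.
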
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            grassmann = group['grassmann']

            if grassmann:
                grad_clip = group['grad_clip']
                omega = group['omega']

                for p in group['params']:
                    if p.grad is None:
                        continue

                    unity, _ = unit(p.data.view(p.size()[0], -1))
                    g = p.grad.data.view(p.size()[0], -1)

                    if omega != 0:
                        # Orthogonality penalty: L = |Y'Y - I|^2 / 2 = |YY' - I|^2 / 2 + c
                        # dL/dY = 2 (YY'Y - Y)
                        g.add_(2 * omega, torch.mm(torch.mm(unity, unity.t()), unity) - unity)
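                        # (legacy in-place signature: g.add_(value, other) computes g += value * other)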

                    h = gproj(unity, g)

                    if grad_clip is not None:
                        h_hat = clip_by_norm(h, grad_clip)
                    else:
                        h_hat = h

                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        param_state['momentum_buffer'] = torch.zeros(h_hat.size())
                        if p.is_cuda:
                            param_state['momentum_buffer'] = param_state['momentum_buffer'].cuda()

                    mom = param_state['momentum_buffer']
                    mom_new = momentum*mom - group['lr']*h_hat

                    p.data.copy_(gexp(unity, mom_new).view(p.size()))
                    mom.copy_(gpt(unity, mom_new))

            else:
                # This routine is from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
                weight_decay = group['weight_decay']
                dampening = group['dampening']
                nesterov = group['nesterov']
                for p in group['params']:
                    if p.grad is None:
                        continue
                    d_p = p.grad.data
                    if weight_decay != 0:
                        d_p.add_(weight_decay, p.data)
                    if momentum != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            buf = param_state['momentum_buffer'] = d_p.clone()
                        else:
                            buf = param_state['momentum_buffer']
                            buf.mul_(momentum).add_(1 - dampening, d_p)
                        if nesterov:
                            d_p = d_p.add(momentum, buf)
                        else:
                            d_p = buf

                    p.data.add_(-group['lr'], d_p)

        return loss
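
A minimal usage sketch for this step() is shown below. Everything in it is an assumption made for
illustration: the optimizer class name SGDG, the toy model, and the data are placeholders; only the
per-group keys mirror the ones actually read in step() above.

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(128, 10)
param_groups = [
    # Grassmann group: the weight matrix is updated on the manifold.
    {'params': [model.weight], 'lr': 0.1, 'momentum': 0.9,
     'grassmann': True, 'grad_clip': 0.5, 'omega': 0.1},
    # Euclidean group: the bias falls through to the plain SGD branch.
    {'params': [model.bias], 'lr': 0.1, 'momentum': 0.9,
     'grassmann': False, 'weight_decay': 0, 'dampening': 0, 'nesterov': False},
]
optimizer = SGDG(param_groups)  # SGDG: assumed name of the class that defines step() above

inputs = torch.randn(32, 128)
targets = torch.randint(0, 10, (32,))
loss = F.cross_entropy(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
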
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            grassmann = group['grassmann']

            if grassmann:
                beta1 = group['momentum']
                beta2 = group['beta2']
                epsilon = group['epsilon']
                grad_clip = group['grad_clip']
                omega = group['omega']

                for p in group['params']:
                    if p.grad is None:
                        continue

                    unity, _ = unit(p.data.view(p.size()[0], -1))
                    g = p.grad.data.view(p.size()[0], -1)

                    if omega != 0:
                        # Orthogonality penalty: L = |Y'Y - I|^2 / 2 = |YY' - I|^2 / 2 + c
                        # dL/dY = 2 (YY'Y - Y)
                        g.add_(2 * omega, torch.mm(torch.mm(unity, unity.t()), unity) - unity)

                    h = gproj(unity, g)

                    if grad_clip is not None:
                        h_hat = clip_by_norm(h, grad_clip)
                    else:
                        h_hat = h

                    param_state = self.state[p]
                    if 'm_buffer' not in param_state:
                        size = p.size()
                        param_state['m_buffer'] = torch.zeros([size[0], int(np.prod(size[1:]))])
                        param_state['v_buffer'] = torch.zeros([size[0], 1])
                        if p.is_cuda:
                            param_state['m_buffer'] = param_state['m_buffer'].cuda()
                            param_state['v_buffer'] = param_state['v_buffer'].cuda()

                        param_state['beta1_power'] = beta1
                        param_state['beta2_power'] = beta2

                    m = param_state['m_buffer']
                    v = param_state['v_buffer']
                    beta1_power = param_state['beta1_power']
                    beta2_power = param_state['beta2_power']

                    mnew = beta1 * m + (1.0 - beta1) * h_hat
                    vnew = beta2 * v + (1.0 - beta2) * xTy(h_hat, h_hat)

                    alpha = np.sqrt(1.-beta2_power) / (1.-beta1_power)
                    deltas = mnew / vnew.add(epsilon).sqrt()
                    deltas.mul_(-alpha*group['lr'])

                    p.data.copy_(gexp(unity, deltas).view(p.size()))
                    m.copy_(gpt2(unity, mnew, deltas))
                    v.copy_(vnew)

                    param_state['beta1_power'] *= beta1
                    param_state['beta2_power'] *= beta2
            else:
                momentum = group['momentum']
                weight_decay = group['weight_decay']
                dampening = group['dampening']
                nesterov = group['nesterov']
                for p in group['params']:
                    if p.grad is None:
                        continue
                    d_p = p.grad.data
                    if weight_decay != 0:
                        d_p.add_(weight_decay, p.data)
                    if momentum != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            buf = param_state['momentum_buffer'] = d_p.clone()
                        else:
                            buf = param_state['momentum_buffer']
                            buf.mul_(momentum).add_(1 - dampening, d_p)
                        if nesterov:
                            d_p = d_p.add(momentum, buf)
                        else:
                            d_p = buf

                    p.data.add_(-group['lr'], d_p)

        return loss
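
For reference, the alpha/deltas computation in the Grassmann branch above is the usual Adam bias
correction folded into one scalar: with m_hat = mnew / (1 - beta1_power) and
v_hat = vnew / (1 - beta2_power), the standard step -lr * m_hat / sqrt(v_hat) equals
-lr * sqrt(1 - beta2_power) / (1 - beta1_power) * mnew / sqrt(vnew), which is exactly what the code
computes, up to where epsilon is added inside the square root.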