Code example #1
File: power_iteration.py Project: sIncerass/QBERT
    def get_eigens():
        vs = []
        lambdass = []
        for block_id in range(12):
            model_block = model.module.bert.encoder.layer[block_id]

            v = [
                torch.randn(p.size()).to(device)
                for p in model_block.parameters()
            ]
            v = de_variable(v)

            lambda_old, lambdas = 0., 1.
            i = 0
            while (abs((lambdas - lambda_old) / lambdas) >= 0.01):

                lambda_old = lambdas

                acc_Hv = [
                    torch.zeros(p.size()).to(device)
                    for p in model_block.parameters()
                ]
                for step, batch in enumerate(
                        tqdm(train_dataloader, desc="Iteration")):
                    if step < percentage_index:

                        loss = get_loss(batch)

                        loss.backward(create_graph=True)
                        grads = [
                            param.grad for param in model_block.parameters()
                        ]
                        params = model_block.parameters()

                        Hv = torch.autograd.grad(grads,
                                                 params,
                                                 grad_outputs=v,
                                                 only_inputs=True,
                                                 retain_graph=True)
                        acc_Hv = [
                            acc_Hv_p + Hv_p
                            for acc_Hv_p, Hv_p in zip(acc_Hv, Hv)
                        ]
                        model.zero_grad()
                # calculate Rayleigh quotient
                lambdas = group_product(acc_Hv, v).item() / percentage_index

                v = de_variable(acc_Hv)
                logger.info(f'lambda: {lambdas}')
                writer.writerow([f'{block_id}', f'{i}', f'{lambdas}'])
                csv_file.flush()

                i += 1
            vs.append(v)
            lambdass.append(lambdas)
            break  # NOTE: only the first block is processed
        return vs, lambdass
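
These snippets call two helpers, group_product and de_variable, that are not shown on this page. A minimal sketch of plausible implementations, reconstructed from how they are used above (an assumption, not the project's actual code):

import torch

def group_product(xs, ys):
    # Inner product over two lists of tensors: sum_i <x_i, y_i>.
    return sum((x * y).sum() for x, y in zip(xs, ys))

def de_variable(v):
    # Detach a list of tensors from the graph and normalize it to unit norm.
    norm = torch.sqrt(group_product(v, v))
    return [x.detach() / (norm + 1e-6) for x in v]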
Code example #2
File: frank_wolfe_update.py Project: kylasa/FITRE
def getFrankWolfeUpdate(pk, gradient, network, delta, maxiters, epsilon, X, Y):

    tk = 0
    for i in range(maxiters):

        print('Iteration %d of Frank-Wolfe update' % i)

        # Scale pk so that it lies on the trust-region boundary (radius delta)
        pk_norm = math.sqrt(group_product(pk, pk))
        pk = [p * (delta / (pk_norm + 1e-16)) for p in pk]

        # vk = g + H pk: gradient of the quadratic model at pk
        vk = [(g + p) for g, p in zip(gradient, network.computeHv_r(X, Y, pk))]
        vnorm = math.sqrt(group_product(vk, vk))
        # Linear minimizer over the ball; use out-of-place mul so vk is not mutated
        qk = [v * (-delta / vnorm) for v in vk]

        rk = group_add(pk, qk, -1)
        rk = [r.type(torch.cuda.DoubleTensor) for r in rk]

        first_opt = group_product(rk, vk)
        if (first_opt.item() <= epsilon):
            print('Frank-Wolfe gap below epsilon: ', first_opt.item())
            break

        hrk = network.computeHv_r(X, Y, rk)

        h0 = 0
        h1 = getStepModelReduction(1, vk, rk, hrk)
        if (group_product(rk, hrk) <= 0):

            if (h0 <= h1):
                tk = 0
            else:
                tk = 1

        else:
            tk = (group_product(gradient, rk) +
                  group_product(pk, hrk)) / group_product(rk, hrk)
            ht = getStepModelReduction(tk, vk, rk, hrk)

            index = np.argmax([h0, h1, ht])
            tk = [0, 1, tk][index]

        # Convex combination: pk <- (1 - tk) * pk + tk * qk
        pk = [(p - tk * q) for p, q in zip(pk, group_add(pk, qk, -1))]

    gv = group_product(gradient, pk)
    hv = network.computeHv_r(X, Y, pk)
    m = gv + 0.5 * group_product(pk, hv)

    return pk, m
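
getFrankWolfeUpdate also relies on a group_add helper that is not defined on this page. Judging from calls such as group_add(pk, qk, -1), a plausible sketch (an assumption):

def group_add(xs, ys, alpha=1.0):
    # Elementwise xs + alpha * ys over two lists of tensors.
    return [x + alpha * y for x, y in zip(xs, ys)]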
Code example #3
File: power_iteration.py Project: sIncerass/QBERT
def get_hessian_trace(train_dataloader, model, get_loss, args, device):
    """
    Compute trace(H) / num_params of the model-parameter Hessian over the full dataset.
    """
    percentage_index = len(train_dataloader.dataset) * \
        args.data_percentage / args.train_batch_size
    print(f'percentage_index: {percentage_index}')

    # Switch the model to evaluation mode; otherwise the batch-normalization
    # statistics will change. If you call this function during training,
    # remember to switch the mode back to training afterwards.
    model.eval()
    model.zero_grad()

    parent_dir = join(dirname(__file__), os.pardir)

    results_dir = join(parent_dir, 'results')

    csv_path = join(
        results_dir,
        f'{args.task_name}-{args.data_percentage}-seed-{args.seed}-trace.csv')
    csv_file = open(csv_path, 'w', newline='')
    writer = csv.writer(csv_file,
                        delimiter=',',
                        quotechar='|',
                        quoting=csv.QUOTE_MINIMAL)
    writer.writerow(['block', 'iters', 'eigenvalue'])

    # make sure requires_grad are true
    for module in model.modules():
        for param in module.parameters():
            param.requires_grad = True

    for block_id in range(12):
        model_block = model.module.bert.encoder.layer[block_id]

        eigenvalue_sum = 0
        for i in range(args.num_iter):
            v = [
                torch.randn(p.size()).to(device)
                for p in model_block.parameters()
            ]
            v = de_variable(v)

            acc_Hv = [
                torch.zeros(p.size()).to(device)
                for p in model_block.parameters()
            ]
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                if step < percentage_index:

                    loss = get_loss(batch)

                    loss.backward(create_graph=True)
                    grads = [param.grad for param in model_block.parameters()]
                    params = model_block.parameters()

                    Hv = torch.autograd.grad(grads,
                                             params,
                                             grad_outputs=v,
                                             only_inputs=True,
                                             retain_graph=True)
                    acc_Hv = [
                        acc_Hv_p + Hv_p for acc_Hv_p, Hv_p in zip(acc_Hv, Hv)
                    ]
                    model.zero_grad()
            # calculate Rayleigh quotient
            eigenvalue = group_product(acc_Hv, v).item() / percentage_index
            eigenvalue_sum += eigenvalue

            writer.writerow([f'{block_id}', f'{i}', f'{eigenvalue}'])
            csv_file.flush()

        print(
            f'average eigenvalue over {args.num_iter} iterations: {eigenvalue_sum/args.num_iter}'
        )
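
get_hessian_trace is a Hutchinson-style estimator: for unit-norm random probes v, the expected Rayleigh quotient E[v^T H v] equals trace(H)/d, where d is the number of parameters. A self-contained sanity check on an explicit matrix (hypothetical illustration, not project code):

import torch

def hutchinson_trace(H, num_samples=2000):
    # Estimate trace(H)/d with unit-norm Gaussian probes, mirroring the
    # normalization that de_variable applies in the snippets above.
    d = H.shape[0]
    est = 0.0
    for _ in range(num_samples):
        v = torch.randn(d)
        v = v / v.norm()
        est += (v @ H @ v).item()
    return est / num_samples

A = torch.diag(torch.tensor([1.0, 2.0, 3.0]))
print(hutchinson_trace(A), A.trace().item() / 3)  # both close to 2.0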
Code example #4
File: power_iteration.py Project: sIncerass/QBERT
def power_iteration_eigenvecs(train_dataloader, model, get_loss, args, device):
    # Also eig vectors

    percentage_index = len(train_dataloader.dataset) * \
        args.data_percentage / args.train_batch_size
    print(f'percentage_index: {percentage_index}')

    model.eval()

    parent_dir = join(dirname(__file__), os.pardir)

    results_dir = join(parent_dir, 'results')

    csv_path = join(
        results_dir,
        f'{args.task_name}-{args.data_percentage}-seed-{args.seed}-eigens.csv')
    csv_file = open(csv_path, 'w', newline='')
    writer = csv.writer(csv_file,
                        delimiter=',',
                        quotechar='|',
                        quoting=csv.QUOTE_MINIMAL)
    writer.writerow(['block', 'iters', 'max_eigenvalue'])

    # make sure requires_grad are true
    for module in model.modules():
        for param in module.parameters():
            param.requires_grad = True

    block_id = 0

    def get_eigens():
        vs = []
        lambdass = []
        for block_id in range(12):
            model_block = model.module.bert.encoder.layer[block_id]

            v = [
                torch.randn(p.size()).to(device)
                for p in model_block.parameters()
            ]
            v = de_variable(v)

            lambda_old, lambdas = 0., 1.
            i = 0
            while (abs((lambdas - lambda_old) / lambdas) >= 0.01):

                lambda_old = lambdas

                acc_Hv = [
                    torch.zeros(p.size()).to(device)
                    for p in model_block.parameters()
                ]
                for step, batch in enumerate(
                        tqdm(train_dataloader, desc="Iteration")):
                    if step < percentage_index:

                        loss = get_loss(batch)

                        loss.backward(create_graph=True)
                        grads = [
                            param.grad for param in model_block.parameters()
                        ]
                        params = model_block.parameters()

                        Hv = torch.autograd.grad(grads,
                                                 params,
                                                 grad_outputs=v,
                                                 only_inputs=True,
                                                 retain_graph=True)
                        acc_Hv = [
                            acc_Hv_p + Hv_p
                            for acc_Hv_p, Hv_p in zip(acc_Hv, Hv)
                        ]
                        model.zero_grad()
                # calculate Rayleigh quotient
                lambdas = group_product(acc_Hv, v).item() / percentage_index

                v = de_variable(acc_Hv)
                logger.info(f'lambda: {lambdas}')
                writer.writerow([f'{block_id}', f'{i}', f'{lambdas}'])
                csv_file.flush()

                i += 1
            vs.append(v)
            lambdass.append(lambdas)
            break  # NOTE: only the first block is processed
        return vs, lambdass

    v1, lambdas1 = get_eigens()
    # calculate second eig vec
    logger.info("Calculate the second eig vec")
    block_id = 0
    vs = []
    lambdass = []
    for block_id in range(12):
        model_block = model.module.bert.encoder.layer[block_id]

        v = [
            torch.randn(p.size()).to(device) for p in model_block.parameters()
        ]
        v = de_variable(v)

        lambda_old, lambdas = 0., 1.
        i = 0
        while (abs((lambdas - lambda_old) / lambdas) >= 0.01):

            lambda_old = lambdas

            acc_Hv = [
                torch.zeros(p.size()).to(device) for p in model_block.parameters()
            ]
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                if step < percentage_index:

                    loss = get_loss(batch)

                    loss.backward(create_graph=True)
                    grads = [param.grad for param in model_block.parameters()]
                    params = model_block.parameters()

                    Hv = torch.autograd.grad(grads,
                                             params,
                                             grad_outputs=v,
                                             only_inputs=True,
                                             retain_graph=True)
                    acc_Hv = [
                        acc_Hv_p + Hv_p for acc_Hv_p, Hv_p in zip(acc_Hv, Hv)
                    ]
                    model.zero_grad()
            acc_Hv = [acc_hv / percentage_index for acc_hv in acc_Hv]
            # make the product be (H - lambda1 * v1 * v1^T) v
            tmp = lambdas1[block_id] * group_product(v1[block_id], v).item()

            acc_Hv = group_add(acc_Hv, v1[block_id], alpha=-tmp)
            # calculate Rayleigh quotient; acc_Hv is already averaged above,
            # so do not divide by percentage_index again
            lambdas = group_product(acc_Hv, v).item()

            v = de_variable(acc_Hv)
            logger.info(f'lambda: {lambdas}')
            writer.writerow([f'{block_id}', f'{i}', f'{lambdas}'])
            csv_file.flush()

            i += 1

        vs.append(v)
        lambdass.append(lambdas)
        break  # NOTE: only the first block is processed
    v2, lambdas2 = vs, lambdass
    landscape_data = {
        'v1': v1,
        'v2': v2,
        'lambdas1': lambdas1,
        'lambdas2': lambdas2
    }
    import pickle
    outfile = open(f'results/{args.task_name}_landscape_data', 'wb')
    pickle.dump(landscape_data, outfile)
    outfile.close()
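
The second eigenpair above is found by deflation: once (lambda1, v1) is known, power iteration is run against H - lambda1 * v1 v1^T, whose dominant eigenvector is the second eigenvector of H. The same idea on a dense matrix (a hypothetical illustration):

import torch

def top_two_eigs(H, iters=200):
    # Power iteration for (lambda1, v1), then deflated power iteration
    # on H - lambda1 * v1 v1^T for (lambda2, v2).
    d = H.shape[0]
    v1 = torch.randn(d); v1 = v1 / v1.norm()
    for _ in range(iters):
        w = H @ v1
        lam1 = (v1 @ w).item()
        v1 = w / w.norm()
    v2 = torch.randn(d); v2 = v2 / v2.norm()
    for _ in range(iters):
        w = H @ v2 - lam1 * (v1 @ v2) * v1  # deflated product
        lam2 = (v2 @ w).item()
        v2 = w / w.norm()
    return lam1, v1, lam2, v2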
Code example #5
File: power_iteration.py Project: sIncerass/QBERT
def power_iteration(train_dataloader, model, get_loss, args, device):

    percentage_index = len(train_dataloader.dataset) * \
        args.data_percentage / args.train_batch_size
    print(f'percentage_index: {percentage_index}')

    model.eval()

    parent_dir = join(dirname(__file__), os.pardir)

    results_dir = join(parent_dir, 'results')

    csv_path = join(
        results_dir,
        f'{args.task_name}-{args.data_percentage}-seed-{args.seed}-eigens.csv')
    csv_file = open(csv_path, 'w', newline='')
    writer = csv.writer(csv_file,
                        delimiter=',',
                        quotechar='|',
                        quoting=csv.QUOTE_MINIMAL)
    writer.writerow(['block', 'iters', 'max_eigenvalue'])

    # make sure requires_grad are true
    for module in model.modules():
        for param in module.parameters():
            param.requires_grad = True

    block_id = 0
    for block_id in range(12):
        model_block = model.module.bert.encoder.layer[block_id]

        v = [
            torch.randn(p.size()).to(device) for p in model_block.parameters()
        ]
        v = de_variable(v)

        lambda_old, lambdas = 0., 1.
        i = 0
        while (abs((lambdas - lambda_old) / lambdas) >= 0.01):

            lambda_old = lambdas

            acc_Hv = [
                torch.zeros(p.size()).to(device) for p in model_block.parameters()
            ]
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                if step < percentage_index:

                    loss = get_loss(batch)

                    loss.backward(create_graph=True)
                    grads = [param.grad for param in model_block.parameters()]
                    params = model_block.parameters()

                    Hv = torch.autograd.grad(grads,
                                             params,
                                             grad_outputs=v,
                                             only_inputs=True,
                                             retain_graph=True)
                    acc_Hv = [
                        acc_Hv_p + Hv_p for acc_Hv_p, Hv_p in zip(acc_Hv, Hv)
                    ]
                    model.zero_grad()
            # calculate Rayleigh quotient
            lambdas = group_product(acc_Hv, v).item() / percentage_index

            v = de_variable(acc_Hv)
            logger.info(f'lambda: {lambdas}')
            writer.writerow([f'{block_id}', f'{i}', f'{lambdas}'])
            csv_file.flush()

            i += 1
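
The while loop above stops once the eigenvalue estimate changes by less than 1% between iterations. The same stopping rule as a standalone predicate (a sketch of the criterion only):

def converged(lam_new, lam_old, tol=0.01):
    # Relative-change stopping rule used by the power-iteration loops.
    return abs((lam_new - lam_old) / lam_new) < tol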
Code example #6
File: power_iteration.py Project: sIncerass/QBERT
def get_trace_family(train_dataloader, model, get_loss, args, device, n_v=1):
    """
    Compute trace(H) / num_params of the model-parameter Hessian over the full dataset.
    """
    percentage_index = len(train_dataloader.dataset) * \
        args.data_percentage / args.train_batch_size
    print(f'percentage_index: {percentage_index}')

    # Switch the model to evaluation mode; otherwise the batch-normalization
    # statistics will change. If you call this function during training,
    # remember to switch the mode back to training afterwards.
    model.eval()
    model.zero_grad()

    parent_dir = join(dirname(__file__), os.pardir)

    results_dir = join(parent_dir, 'results')

    csv_path = join(
        results_dir,
        f'{args.task_name}-{args.data_percentage}-seed-{args.seed}-{args.method}.csv'
    )
    csv_file = open(csv_path, 'w', newline='')
    writer = csv.writer(csv_file,
                        delimiter=',',
                        quotechar='|',
                        quoting=csv.QUOTE_MINIMAL)
    writer.writerow(
        ['block', 'iters', 'slq-trace', 'slq-abstrace', 'slq-quadtrace'])

    # make sure requires_grad are true
    for module in model.modules():
        for param in module.parameters():
            param.requires_grad = True

    for block_id in range(12):
        slq_trace = 0
        slq_quadtrace = 0
        slq_abstrace = 0
        for _ in range(n_v):
            model_block = model.module.bert.encoder.layer[block_id]

            alpha_list = []
            beta_list = []
            v_list = []
            beta = 0
            w = None

            v = [
                torch.randn(p.size()).to(device)
                for p in model_block.parameters()
            ]
            v = de_variable(v)
            for i in range(args.num_iter):

                acc_Hv = [
                    torch.zeros(p.size()).to(device)
                    for p in model_block.parameters()
                ]
                for step, batch in enumerate(
                        tqdm(train_dataloader, desc="Iteration")):
                    if step < percentage_index:

                        loss = get_loss(batch)

                        loss.backward(create_graph=True)
                        grads = [
                            param.grad for param in model_block.parameters()
                        ]
                        params = model_block.parameters()

                        Hv = torch.autograd.grad(grads,
                                                 params,
                                                 grad_outputs=v,
                                                 only_inputs=True,
                                                 retain_graph=True)
                        acc_Hv = [
                            acc_Hv_p + Hv_p
                            for acc_Hv_p, Hv_p in zip(acc_Hv, Hv)
                        ]
                        model.zero_grad()
                acc_Hv = [
                    hv * total_number_parameters(model_block) for hv in acc_Hv
                ]
                if i == 0:
                    # First iteration: step 2 of the Lanczos algorithm
                    # (as described on Wikipedia).
                    alpha = group_product(acc_Hv, v).item() / percentage_index
                    # Reviewer double-check: the Hessian-vector product is averaged here.
                    w = [hv / percentage_index for hv in acc_Hv]

                    w = [w_i - alpha * v_i for w_i, v_i in zip(w, v)]
                    v_list.append(v)

                    alpha_list.append(alpha)

                    # Prepare for next step
                    beta = (group_product(w, w)**0.5).detach()
                    beta_list.append(beta)
                    # v = [w_i / beta for w_i in w]
                    v = orthonormal(w, v_list)
                    v_list.append(v)

                else:

                    # calculate Rayleigh quotient
                    eigenvalue = group_product(acc_Hv,
                                               v).item() / percentage_index
                    w = [hv / percentage_index for hv in acc_Hv]
                    alpha = eigenvalue
                    alpha_list.append(alpha)

                    # Lanczos three-term recurrence: w = Hv - alpha*v - beta*v_old
                    w = [
                        w_i - alpha * v_i - beta * v_old_i
                        for w_i, v_i, v_old_i in zip(w, v_list[-1], v_list[-2])
                    ]

                    beta = (group_product(w, w)**0.5).detach()
                    beta_list.append(beta)
                    logger.info(f'num_iter: {i}, beta: {beta}')
                    assert beta != 0
                    # v = [w_i / beta for w_i in w]
                    v = orthonormal(w, v_list)
                    v_list.append(v)

            assert len(alpha_list) == args.num_iter
            beta_list.pop()  # the last beta is unnecessary
            assert len(beta_list) == (args.num_iter - 1)

            m = args.num_iter
            T = torch.zeros(m, m).to(device)
            for i in range(len(alpha_list)):
                T[i, i] = alpha_list[i]
                if i < len(alpha_list) - 1:
                    T[i + 1, i] = beta_list[i]
                    T[i, i + 1] = beta_list[i]
            # NOTE: torch.eig was removed in recent PyTorch; for the
            # symmetric T, torch.linalg.eigh gives the same spectrum.
            a_, b_ = torch.eig(T, eigenvectors=True)

            eigen_list = a_[:, 0]
            weight_list = b_[0, :]**2
            slq_trace += eigen_list @ weight_list
            slq_quadtrace += (eigen_list**2) @ weight_list
            slq_abstrace += torch.abs(eigen_list) @ weight_list

        slq_trace /= n_v
        slq_quadtrace /= n_v
        slq_abstrace /= n_v
        writer.writerow([
            f'{block_id}', f'{i}', f'{slq_trace}', f'{slq_abstrace}',
            f'{slq_quadtrace}'
        ])
        csv_file.flush()
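
Two details of get_trace_family are worth spelling out. The orthonormal helper is not shown; from its call sites it presumably Gram-Schmidt-orthogonalizes w against the stored Lanczos vectors and normalizes the result. Also, since torch.eig is gone from recent PyTorch and T is symmetric, torch.linalg.eigh is the modern replacement. A hedged sketch of both (assumptions reconstructed from usage, building on the group_product / group_add / de_variable sketches above):

import torch

def orthonormal(w, v_list):
    # Orthogonalize w against all previous Lanczos vectors, then normalize.
    for v in v_list:
        w = group_add(w, v, alpha=-group_product(w, v).item())
    return de_variable(w)

def slq_quantities(alpha_list, beta_list, device='cpu'):
    # Build the Lanczos tridiagonal T and read off the SLQ estimates.
    m = len(alpha_list)
    T = torch.zeros(m, m, device=device)
    for i in range(m):
        T[i, i] = alpha_list[i]
        if i < m - 1:
            T[i + 1, i] = T[i, i + 1] = beta_list[i]
    eigvals, eigvecs = torch.linalg.eigh(T)
    weights = eigvecs[0, :] ** 2  # squared first components of eigenvectors
    return (eigvals @ weights,          # slq-trace
            eigvals.abs() @ weights,    # slq-abstrace
            (eigvals ** 2) @ weights)   # slq-quadtrace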
Code example #7
    def fullPass(self, train):

        for l in train:

            torch.cuda.empty_cache()
            print('Starting the full pass... ')

            X_sample_var = l[0]
            Y_sample_var = l[1]
            tr_ll, tr_accu = self.network.evalModel(X_sample_var, Y_sample_var)

            # With regularization
            self.network.zero_grad()
            grad = self.network.computeGradientIter2(X_sample_var,
                                                     Y_sample_var)

            #KFAC here.
            self.network.zero_grad()
            self.network.startRecording()
            self.network.computeTempsForKFAC(X_sample_var)
            self.network.stopRecording()
            x_kfac = self.network.computeKFACHv(X_sample_var, grad)

            # assert here, KFAC is PSD
            if (group_product(grad, x_kfac) < 0):
                print('Trouble with KFAC direction... it is NOT PSD ')
                exit()

            #Use Frank-Wolfe Direction here.
            fw_direction = []
            for t in grad:
                fw_direction.append(
                    torch.zeros(t.size()).type(torch.cuda.DoubleTensor))

            fw_direction, fw_model = getFrankWolfeUpdate(
                fw_direction, x_kfac, self.network, self.params.delta, 20,
                1e-8, X_sample_var, Y_sample_var)

            for idx, p in enumerate(self.prevWeights):
                fw_direction[idx].add_(self.momentum, p)

            for idx, w in enumerate(self.network.parameters()):
                w.data.add_(-1., fw_direction[idx])
                self.prevWeights[idx].copy_(fw_direction[idx])

            new_ll, new_accu = self.network.evalModel(X_sample_var,
                                                      Y_sample_var)
            rho = (new_ll - tr_ll) / (fw_model - 1e-16)

            if (rho > 0.75):  #1e-4
                self.params.delta = min(self.params.max_delta,
                                        2. * self.params.delta)  # 2
            if (rho < 0.25):
                self.params.delta = max(self.params.min_delta,
                                        0.5 * self.params.delta)  # 2
            if ((rho < 1e-4) or (new_ll > (10. * tr_ll))):
                if (self.debug):
                    print(
                        'Trouble... rejecting this step: no visible decrease'
                    )
                # Reject the step: add the update back to undo it.
                for idx, w in enumerate(self.network.parameters()):
                    w.data.add_(1., self.prevWeights[idx])
            if (self.debug):
                print(rho, self.params.delta, new_ll, tr_ll, fw_model)

            del grad
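
The delta adjustment at the end of fullPass is a standard trust-region radius update: grow the radius when the actual-to-predicted reduction ratio rho is high, shrink it when rho is low. Condensed into one function (a sketch of the logic above, not the project's API):

def update_delta(delta, rho, min_delta, max_delta):
    # Grow on strong agreement with the model, shrink on weak agreement.
    if rho > 0.75:
        delta = min(max_delta, 2.0 * delta)
    if rho < 0.25:
        delta = max(min_delta, 0.5 * delta)
    return delta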
Code example #8
File: frank_wolfe_update.py Project: kylasa/FITRE
def getStepModelReduction(t, vk, rk, hrk):
    return -t * group_product(vk, rk) + 0.5 * t * t * group_product(rk, hrk)
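
getStepModelReduction evaluates h(t) = -t (vk . rk) + 0.5 t^2 (rk . H rk). Its stationary point is t* = (vk . rk) / (rk . H rk), which is the tk used in getFrankWolfeUpdate: with vk = g + H pk and symmetric H, vk . rk = g . rk + pk . (H rk). A quick numeric check of that identity (hypothetical data):

import torch

H = torch.tensor([[2.0, 0.5], [0.5, 1.0]])  # symmetric
g = torch.tensor([1.0, -1.0])
p = torch.tensor([0.5, 2.0])
r = torch.tensor([0.3, -0.7])
v = g + H @ p
lhs = (v @ r) / (r @ (H @ r))
rhs = (g @ r + p @ (H @ r)) / (r @ (H @ r))
assert torch.isclose(lhs, rhs)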
Code example #9
    def fullPass(self, train, curIteration):

        offset = 0
        for l in train:

            print()
            print('Iteration: %d, Offset: %d' % (curIteration, offset))
            torch.cuda.empty_cache()
            #torch.no_grad ()
            if (self.debug):
                print()
                print()
                print()
                print('Starting the full pass... ')
                print(l[0][0, 0, :, :])
                print(l[0][0, 1, :, :])
                print(l[0][0, 2, :, :])
                print()
                print()
                print()

                print(l[1])

                print(l[0][255, 0, :, :])
                print(l[0][255, 1, :, :])
                print(l[0][255, 2, :, :])

            X_sample_var = l[0].type(TYPE).cuda()
            Y_sample_var = l[1].type(torch.LongTensor).cuda()

            # With regularization
            self.network.zero_grad()
            grad = self.network.computeGradientIter2(X_sample_var,
                                                     Y_sample_var)

            #if (self.debug):
            print('grad norm', math.sqrt(group_product(grad, grad)))
            print(
                'weights norm',
                math.sqrt(
                    group_product(self.network.parameters(),
                                  self.network.parameters())))

            #for p in self.network.parameters ():
            #	print( 'layer Norm: ', torch.norm( p.grad.data ) )

            #print( 'conv1: ', self.network.conv1.bias.grad.data )
            #print( 'conv2: ', self.network.conv2.bias.grad.data )
            #import pdb;pdb.set_trace();

            tr_ll, tr_accu = self.network.evalModel(X_sample_var, Y_sample_var)
            #print( grad )
            #print( 'Model: ', tr_ll )

            # KFAC here:
            # compute F^{-1} * gradient = the approximate natural gradient,
            # then convert the vector to structures.
            self.network.zero_grad()
            #print( 'grad norm', math.sqrt( group_product( grad, grad ) ) )

            self.network.startRecording()
            self.network.computeTempsForKFAC(X_sample_var)
            self.network.stopRecording()

            x_kfac = self.network.computeKFACHv(X_sample_var, grad)
            print('Norm of x_kfac: ', math.sqrt(group_product(x_kfac, x_kfac)))

            # assert here, KFAC is PSD
            if (group_product(grad, x_kfac) < 0):
                print('Trouble with KFAC direction... it is NOT PSD ')
                exit()

            #compute the model reduction here.
            # m = eta * grad * x_kfac + eta * eta * 0.5 * x_kfac * Hessian * x_kfac
            # with regularization
            Hv = self.network.computeHv_r(X_sample_var, Y_sample_var, x_kfac)
            #Hv = self.network.computeHv( X_sample_var, Y_sample_var, x_kfac )
            print('Norm of Hv: ', math.sqrt(group_product(Hv, Hv)))
            #s = 0
            #for l in Hv:
            #	s += torch.norm( l ) * torch.norm( l )
            #print( s, torch.sqrt( s ) )

            vHv = group_product(Hv, x_kfac).item()
            print('KFAC : vHv ', vHv)

            vnorm = math.sqrt(group_product(x_kfac, x_kfac))
            if (self.debug):
                print('KFAC Norm: ', vnorm)
            #handle Negative Curvature here.
            if (vHv < 0):
                #step = (self.params.delta) / vnorm
                #x_kfac = [ v * step for v in x_kfac ]
                #m_kfac = vHv * 0.5 * step * step - group_product( grad, x_kfac ).item ()
                x_kfac = [v * self.params.delta / vnorm for v in x_kfac]
                #print( 'grad * x_kfac: ', group_product( grad, x_kfac ).item () )
                #print( 'vHv term : ', vHv * 0.5 * self.params.delta * self.params.delta / vnorm / vnorm )
                m_kfac = vHv * 0.5 * self.params.delta * self.params.delta / vnorm / vnorm - group_product(
                    grad, x_kfac).item()
                print('Model Reduction kfac direction (Negative): ', m_kfac)
                if (self.debug):
                    print('alpha (NC): ', self.params.delta / vnorm)
            else:
                #import pdb;pdb.set_trace();
                gv = group_product(x_kfac, grad).item()
                if (self.debug):
                    print('group product: ', gv)
                step = gv / (vHv + 1e-6)
                step = min(step, (self.params.delta / (vnorm + 1e-16)))
                x_kfac = [v * step for v in x_kfac]
                if (self.debug):
                    print('alpha ', step)
                m_kfac = vHv * 0.5 * step * step - gv * step
                print('Model Reduction kfac direction (PSD) : ', m_kfac)

            if (self.check_grad == True):
                self.network.zero_grad()

                grad_dot = group_product(grad, grad).item()
                #print( 'Grad norm: ', math.sqrt( grad_dot ) )
                t = self.network.computeHv_r(X_sample_var, Y_sample_var, grad)
                print('Grad Hv Norm: ', math.sqrt(group_product(t, t)))
                #print( 'Grad v norm: ', math.sqrt( group_product( grad, grad ).item ()) )

                vHv = group_product(t, grad).item()
                #print( 'Grad: vHv ', vHv )
                print('Grad v norm: ',
                      math.sqrt(group_product(grad, grad).item()))
                vnorm = math.sqrt(grad_dot)

                if (vHv < 0):
                    #step = (self.params.delta) / vnorm
                    #grad = [ g * step for g in grad ]
                    #m_g_kfac = 0.5 * step * step * vHv - group_product( grad, grad )
                    #print( 'Grad grad * x_kfac: ', group_product( grad, grad).item () )
                    m_g_kfac = vHv * 0.5 * self.params.delta * self.params.delta / vnorm / vnorm - group_product(
                        grad, grad).item() * self.params.delta / vnorm
                    grad = [v * self.params.delta / vnorm for v in grad]

                    #print( 'Grad grad * x_kfac: ', group_product( grad, grad).item () )
                    #print( 'Grad vHv term : ', vHv * 0.5 * self.params.delta * self.params.delta / vnorm / vnorm )
                    #print( 'Grad NC alpha:  ', self.params.delta/vnorm)
                    print('Model reduction gradient kfac (negative): ',
                          m_g_kfac)
                else:
                    gv = grad_dot
                    step = gv / (vHv + 1e-6)
                    #print( 'Grad alpha: ', step )
                    step = min(step, self.params.delta / (vnorm + 1e-16))
                    grad = [g * step for g in grad]
                    m_g_kfac = vHv * 0.5 * step * step - gv * step
                    print('Model reduction gradient kfac (psd): ', m_g_kfac)

            #print( m_g_kfac, m_kfac )
            #import pdb;pdb.set_trace();

            if ((not self.check_grad) or (m_kfac < m_g_kfac)):
                #use Natural Gradient
                #Momentum Here
                #Momentum
                for idx, p in enumerate(self.prevWeights):
                    x_kfac[idx].add_(self.momentum, p)

                #self.network.updateWeights( x_kfac.data.mul_( kfac_step.item () ) )
                for idx, w in enumerate(self.network.parameters()):
                    w.data.add_(-1., x_kfac[idx])
                    self.prevWeights[idx].copy_(x_kfac[idx])

                new_ll, new_accu = self.network.evalModel(
                    X_sample_var, Y_sample_var)

                rho = (new_ll - tr_ll) / (m_kfac - 1e-16)
                if (self.debug):
                    print(rho)

            else:
                #tr_ll, tr_accu = self.network.evalModel( X_sample_var, Y_sample_var )

                #grad.add_ (self.momentum, self.prevWeights )
                #group_add( grad, self.prevWeights, self.momentum )
                for v, mom in zip(grad, self.prevWeights):
                    v.data.add_(self.momentum, mom)

                #self.network.updateWeights( grad.data.mul_( grad_step.item () ) )
                #self.prevWeights.copy_( grad.data )
                for idx, w in enumerate(self.network.parameters()):
                    w.data.add_(-1., grad[idx])
                    self.prevWeights[idx].copy_(grad[idx])

                new_ll, new_accu = self.network.evalModel(
                    X_sample_var, Y_sample_var)
                rho = (new_ll - tr_ll) / (m_g_kfac - 1e-16)

            if (rho > 0.75):  #1e-4
                self.params.delta = min(self.params.max_delta,
                                        2. * self.params.delta)  # 2
            if (rho < 0.25):
                self.params.delta = max(self.params.min_delta,
                                        0.5 * self.params.delta)  # 2
            if ((rho < 1e-4) or (new_ll > (10. * tr_ll))):
                if (self.debug):
                    print(
                        'Trouble... rejecting this step: no visible decrease'
                    )
                #self.network.updateWeights( -1 * self.prevWeights );
                for idx, w in enumerate(self.network.parameters()):
                    w.data.add_(1., self.prevWeights[idx])
            if (self.debug):
                #print( m_g_kfac, m_kfac )
                if ((not self.check_grad) or (m_kfac < m_g_kfac)):
                    print(rho, self.params.delta, new_ll, tr_ll, m_kfac)
                else:
                    print(rho, self.params.delta, new_ll, tr_ll, m_g_kfac)

            if ((not self.check_grad) or (m_kfac < m_g_kfac)):
                print('%4.10e  %4.10e  %4.10e  %4.10e %4.10e  %6s  %3.6e' %
                      (tr_ll, new_ll, rho, m_kfac,
                       math.sqrt(
                           group_product(self.network.parameters(),
                                         self.network.parameters())), 'kfac',
                       self.params.delta))
            else:
                print('%4.10e  %4.10e  %4.10e  %4.10e %4.10e  %6s  %3.6e' %
                      (tr_ll, new_ll, rho, m_g_kfac,
                       math.sqrt(
                           group_product(self.network.parameters(),
                                         self.network.parameters())), 'grad',
                       self.params.delta))

            del grad
            offset += l[1].size()[0]
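
Both branches of fullPass pick a step length along a search direction v the same way: a quasi-Newton step gv / vHv, clipped to the trust-region boundary delta / ||v||, with negative curvature sent straight to the boundary; the predicted model reduction m = 0.5 t^2 vHv - t gv then feeds the rho test. The shared pattern as one function (a condensed sketch, not the project's API):

def trust_region_step(gv, vHv, vnorm, delta):
    # gv = g . v, vHv = v . H v, vnorm = ||v||.
    if vHv < 0:
        t = delta / vnorm  # negative curvature: go to the boundary
    else:
        t = min(gv / (vHv + 1e-6), delta / (vnorm + 1e-16))
    m = 0.5 * vHv * t * t - gv * t  # predicted model reduction
    return t, m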