Example #1
def hyper_step(train_grad):
# def hyper_step():
    """Estimate the hypergradient, and take an update with it.
    """
    zero_hypergrad(get_hyper_train)

    # xentropy_loss, regularized_loss = train_loss_func()
    # train_grad = grad(regularized_loss, model.parameters(), create_graph=True)
    d_train_loss_d_w = gather_flat_grad(train_grad)

    # Compute gradients of the validation loss w.r.t. the weights/hypers
    num_weights, num_hypers = sum(p.numel() for p in model.parameters()), sum(p.numel() for p in get_hyper_train())
    d_val_loss_d_theta = torch.zeros(num_weights).cuda()
    # model.eval()
    model.zero_grad()

    val_loss = val_loss_func(hyperval_data)  # eval() is used in here
    d_val_loss_d_theta += gather_flat_grad(grad(val_loss, model.parameters()))  # Do we need create_graph=True or retain_graph=True ?

    # compute d / d lambda (partial Lv / partial w * partial Lt / partial w)
    # = (partial Lv / partial w * partial^2 Lt / (partial w partial lambda))
    # indirect_grad = gather_flat_grad(grad(d_train_loss_d_w, get_hyper_train(), grad_outputs=preconditioner.view(-1)))
    hypergrad = gather_flat_grad(grad(d_train_loss_d_w, get_hyper_train(), grad_outputs=d_val_loss_d_theta.detach().view(-1)))
    get_hyper_train()[0].grad = -hypergrad
    # get_hyper_train()[0].grad = hypergrad
    # store_hypergrad(get_hyper_train, hypergrad)
    return val_loss, hypergrad.norm()
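
Written out, every hyper_step in these snippets is approximating the implicit-function-theorem hypergradient that the two-line comment above abbreviates (notation mine: L_V is the validation loss, L_T the training loss, w the elementary parameters, \lambda the hyperparameters):

\frac{d L_V}{d \lambda}
  = \frac{\partial L_V}{\partial \lambda}
  - \frac{\partial L_V}{\partial w}
    \left[ \frac{\partial^2 L_T}{\partial w \, \partial w^\top} \right]^{-1}
    \frac{\partial^2 L_T}{\partial w \, \partial \lambda}

The first term is the direct gradient (direct_grad in later examples, often zero) and the second is the indirect gradient: a "preconditioner" vector standing in for (\partial L_V / \partial w)[\partial^2 L_T / \partial w \partial w^\top]^{-1} is contracted against \partial^2 L_T / (\partial w \partial \lambda) via the grad_outputs call. Example #1 passes d_val_loss_d_theta straight through as that vector, i.e. it replaces the inverse Hessian with the identity; the later examples restore it with a Neumann series, conjugate gradient, an exact pseudo-inverse, or KFAC.
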
Example #2
def hyper_step(get_hyper_train, model, val_loss_func, val_loader, d_train_loss_d_w, elementary_lr, use_reg, args):
    """Estimate the hypergradient, and take an update with it.

    :param get_hyper_train:  A function which returns the hyperparameters we want to tune.
    :param model:  The model whose parameters are the elementary parameters we want to tune.
    :param val_loss_func:  A function which takes input x and target y, then returns the scalar valued validation loss.
    :param val_loader: A generator for input x, target y tuples from the validation set.
    :param d_train_loss_d_w:  The derivative of the training loss with respect to the elementary parameters.
    :param elementary_lr:  The elementary optimizer's learning rate, used to scale the Neumann series.
    :param use_reg:  Whether the validation loss has a term that depends directly on the hyperparameters.
    :param args:  Options such as use_cg and num_neumann_terms, selecting how the inverse Hessian is approximated.
    :return: The scalar valued validation loss and the norm of the hypergradient.
    """
    zero_hypergrad(get_hyper_train)

    d_train_loss_d_w = gather_flat_grad(d_train_loss_d_w)

    # Compute gradients of the validation loss w.r.t. the weights/hypers
    num_weights, num_hypers = sum(p.numel() for p in model.parameters()), sum(p.numel() for p in get_hyper_train())
    d_val_loss_d_theta, direct_grad = torch.zeros(num_weights).cuda(), torch.zeros(num_hypers).cuda()
    model.train(), model.zero_grad()
    for batch_idx, (x, y) in enumerate(val_loader):
        val_loss = val_loss_func(x, y)
        d_val_loss_d_theta += gather_flat_grad(grad(val_loss, model.parameters(), retain_graph=use_reg))
        if use_reg:
            direct_grad += gather_flat_grad(grad(val_loss, get_hyper_train()))
            # x != x is only true for NaN entries; zero them out of the direct gradient
            direct_grad[direct_grad != direct_grad] = 0
        break

    # Initialize the preconditioner and counter
    if not args.use_cg:
        preconditioner = neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w,
                                                          elementary_lr, args.num_neumann_terms)
    else:
        def A_vector_multiply_func(vec):
            p1 = d_val_loss_d_theta.view(1, -1) @ vec.view(-1, 1)
            p2 = d_val_loss_d_theta.view(-1, 1) @ p1
            return p2.view(1, -1, 1)

        preconditioner, _ = cg_batch(A_vector_multiply_func, d_val_loss_d_theta.view(1, -1, 1))
    # conjugate_grad(A_vector_multiply_func, d_val_loss_d_theta)

    # compute d / d lambda (partial Lv / partial w * partial Lt / partial w)
    # = (partial Lv / partial w * partial^2 Lt / (partial w partial lambda))
    indirect_grad = gather_flat_grad(grad(d_train_loss_d_w, get_hyper_train(), grad_outputs=preconditioner.view(-1)))
    hypergrad = direct_grad + indirect_grad

    store_hypergrad(get_hyper_train, hypergrad)
    return val_loss, hypergrad.norm()
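
The snippets call zero_hypergrad, gather_flat_grad, and store_hypergrad without ever defining them. A minimal sketch of what they are assumed to do, inferred from the call sites rather than taken from the repo (later examples forward extra arguments to get_hyper_train, hence the *args):

import torch


def gather_flat_grad(grads):
    """Concatenate a sequence of per-parameter gradients into one flat vector."""
    return torch.cat([g.contiguous().view(-1) for g in grads])


def zero_hypergrad(get_hyper_train, *args):
    """Zero the .grad buffers of every hyperparameter tensor."""
    for p in get_hyper_train(*args):
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()


def store_hypergrad(get_hyper_train, flat_hypergrad, *args):
    """Unflatten a flat hypergradient back into the hyperparameters' .grad buffers."""
    current = 0
    for p in get_hyper_train(*args):
        n = p.numel()
        p.grad = flat_hypergrad[current:current + n].view(p.shape)
        current += n
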
def A_vector_multiply_func(vec):
    # Note: d_train_loss_d_w and model come from the enclosing scope in the original file.
    val = gather_flat_grad(
        grad(d_train_loss_d_w,
             model.parameters(),
             grad_outputs=vec.view(-1),
             retain_graph=True))
    # val_2 = gather_flat_grad(grad(d_train_loss_d_w, model.parameters(), grad_outputs=val.view(-1), retain_graph=True))
    # return val_2.view(1, -1, 1)
    return val.view(1, -1, 1)
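
This A_vector_multiply_func is a Hessian-vector product: because d_train_loss_d_w was built with create_graph=True, differentiating it again with grad_outputs=vec contracts the training Hessian against vec without ever forming it. A tiny self-contained check of that pattern on a quadratic with a known Hessian (toy tensors, not the repo's objects):

import torch
from torch.autograd import grad

# Toy quadratic with a known symmetric positive-definite Hessian H.
torch.manual_seed(0)
n = 5
w = torch.randn(n, requires_grad=True)
A = torch.randn(n, n)
H = A @ A.t() + torch.eye(n)

loss = 0.5 * w @ H @ w                       # gradient is H @ w, Hessian is H
d_loss_d_w = grad(loss, w, create_graph=True)[0]

vec = torch.randn(n)
# Differentiating the gradient with grad_outputs=vec yields the Hessian-vector product H @ vec.
hvp = grad(d_loss_d_w, w, grad_outputs=vec, retain_graph=True)[0]
print(torch.allclose(hvp, H @ vec, atol=1e-5))   # True
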
Example #4
    def hyper_step(elementary_lr, do_true_inverse=False):
        """Estimate the hypergradient, and take an update with it.
        """
        zero_hypergrad(get_hyper_train)
        num_weights, num_hypers = sum(p.numel() for p in model.parameters()), sum(p.numel() for p in get_hyper_train())
        d_train_loss_d_w = torch.zeros(num_weights).cuda()
        model.train(), model.zero_grad()

        # First compute train loss on a batch
        for batch_idx, (x, y) in enumerate(train_loader):
            train_loss, _ = train_loss_func(x, y)
            optimizer.zero_grad()
            d_train_loss_d_w += gather_flat_grad(grad(train_loss, model.parameters(), create_graph=True))
            break
        optimizer.zero_grad()

        # Compute gradients of the validation loss w.r.t. the weights
        d_val_loss_d_theta, direct_grad = torch.zeros(num_weights).cuda(), torch.zeros(num_hypers).cuda()
        model.train(), model.zero_grad()
        for batch_idx, (x, y) in enumerate(val_loader):
            val_loss = val_loss_func(x, y)
            optimizer.zero_grad()
            d_val_loss_d_theta += gather_flat_grad(grad(val_loss, model.parameters(), retain_graph=False))
            break

        # Initialize the preconditioner and counter
        preconditioner = d_val_loss_d_theta
        # Neumann series to do hessian inversion
        preconditioner = neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w, elementary_lr,
                                                              args.num_neumann_terms, model)

        # compute d / d lambda (partial Lv / partial w * partial Lt / partial w)
        # = (partial Lv / partial w * partial^2 Lt / (partial w partial lambda))
        indirect_grad = gather_flat_grad(
            grad(d_train_loss_d_w, get_hyper_train(), grad_outputs=preconditioner.view(-1)))

        # Direct grad is zero here due to no data augmentation for val data.
        hypergrad = direct_grad + indirect_grad

        zero_hypergrad(get_hyper_train)
        store_hypergrad(get_hyper_train, -hypergrad)
        return val_loss, hypergrad.norm()
def neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w,
                                     elementary_lr, num_neumann_terms, model):
    preconditioner = d_val_loss_d_theta.detach()
    counter = preconditioner
    # Do the fixed point iteration to approximate the vector-inverseHessian product
    for i in range(num_neumann_terms):
        old_counter = counter
        # This increments counter to counter * (I - hessian) = counter - counter * hessian
        hessian_term = gather_flat_grad(
            grad(d_train_loss_d_w,
                 model.parameters(),
                 grad_outputs=counter.contiguous().view(-1),
                 retain_graph=True))
        counter = old_counter - elementary_lr * hessian_term
        preconditioner = preconditioner + counter
    return elementary_lr * preconditioner
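
The value returned above is the truncated Neumann series elementary_lr * sum_{j=0..K} (I - elementary_lr * H)^j v, which approaches H^{-1} v whenever the spectral radius of (I - elementary_lr * H) is below one. A small dense sanity check of the same recursion, with an explicit SPD matrix standing in for the implicit Hessian-vector products (illustrative values, not from the repo):

import torch

# Dense stand-ins: H plays the role of the training Hessian, v of d_val_loss_d_theta.
torch.manual_seed(0)
n, elementary_lr, num_neumann_terms = 6, 0.1, 300
A = torch.randn(n, n)
H = 0.1 * A @ A.t() + torch.eye(n)        # SPD and well-conditioned, so the series converges
v = torch.randn(n)

preconditioner = v.clone()
counter = v.clone()
for _ in range(num_neumann_terms):
    counter = counter - elementary_lr * (H @ counter)   # counter <- (I - lr*H) @ counter
    preconditioner = preconditioner + counter

approx_inverse_hvp = elementary_lr * preconditioner     # ~= H^{-1} @ v
print(torch.allclose(approx_inverse_hvp, torch.linalg.solve(H, v), atol=1e-4))  # True
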
def KFAC_optimize(args, model, train_loader, val_loader, hyper_optimizer,
                  kfac_opt, KFAC_damping, epoch_h):
    # set up placeholder for the partial derivative in each batch
    total_d_val_loss_d_lambda = torch.zeros(
        get_hyper_train(args, model).size(0))
    if args.cuda: total_d_val_loss_d_lambda = total_d_val_loss_d_lambda.cuda()

    ######################## Calculate v1 from the paper, i.e. dL_v / dw
    num_weights = sum(p.numel() for p in model.parameters())
    d_val_loss_d_theta = torch.zeros(num_weights).cuda()
    model.train()
    for batch_idx, (x, y) in enumerate(val_loader):
        model.zero_grad()
        x, y = prepare_data(args, x, y)
        val_loss, _ = batch_loss(args, model, x, y, model, val_loss_func)
        val_loss_grad = grad(val_loss, model.parameters())
        d_val_loss_d_theta += gather_flat_grad(val_loss_grad)
        if batch_idx >= args.val_batch_num: break
    d_val_loss_d_theta /= (
        batch_idx + 1
    )  # TODO (@Mo): This is very bad, because it does not account for a potentially uneven batch at the end

    ######################## Calculate preconditioner, i.e. v1*(inverse Hessian approximation) [orange term in Figure 2]
    assert args.hessian == 'KFAC', f"Passed {args.hessian}, not a valid choice. Need to choose KFAC"
    # model.zero_grad()
    flat_pre_conditioner = torch.zeros(num_weights).cuda()
    for batch_idx, (x, y) in enumerate(train_loader):
        model.train()
        model.zero_grad(), hyper_optimizer.zero_grad()
        x, y = prepare_data(args, x, y)
        train_loss, _ = batch_loss(args, model, x, y, model, train_loss_func)
        # TODO (JON): Probably don't recompute - use create_graph and retain_graph?
        d_train_loss_d_theta = grad(train_loss,
                                    model.parameters(),
                                    create_graph=True)
        flat_d_train_loss_d_theta = gather_flat_grad(d_train_loss_d_theta)

        current = 0
        for m in model.modules():
            if m.__class__.__name__ in ['Linear', 'Conv2d']:
                # kfac_opt.zero_grad()
                if m.__class__.__name__ == 'Conv2d':
                    size0, size1 = m.weight.size(0), m.weight.view(
                        m.weight.size(0), -1).size(1)
                else:
                    size0, size1 = m.weight.size(0), m.weight.size(1)
                mod_size1 = size1 + 1 if m.bias is not None else size1
                shape = (size0, (mod_size1))
                size = size0 * mod_size1
                pre_conditioner = kfac_opt._get_natural_grad(
                    m, d_val_loss_d_theta[current:current + size].view(shape),
                    KFAC_damping)
                flat_pre_conditioner[current:current +
                                     size] = gather_flat_grad(pre_conditioner)
                current += size
        model.zero_grad(), hyper_optimizer.zero_grad()
        flat_d_train_loss_d_theta.backward(flat_pre_conditioner)
        total_d_val_loss_d_lambda -= get_hyper_train(args, model).grad
        if batch_idx >= args.train_batch_num:
            break
    total_d_val_loss_d_lambda /= (
        batch_idx + 1
    )  # TODO (@Mo): This is very bad, because it does not account for a potentially uneven batch at the end

    ##################### Compute direct gradient of val loss w.r.t. lambda. This is usually 0
    direct_d_val_loss_d_lambda = torch.zeros(
        get_hyper_train(args, model).size(0))
    if args.cuda:
        direct_d_val_loss_d_lambda = direct_d_val_loss_d_lambda.cuda()
    model.train()
    for batch_idx, (x_val, y_val) in enumerate(val_loader):
        model.zero_grad(), hyper_optimizer.zero_grad()
        x_val, y_val = prepare_data(args, x_val, y_val)
        val_loss, _ = batch_loss(args, model, x_val, y_val, model,
                                 val_loss_func)
        val_loss_grad = grad(val_loss,
                             get_hyper_train(args, model),
                             allow_unused=True)
        if val_loss_grad is not None and val_loss_grad[0] is not None:
            direct_d_val_loss_d_lambda += gather_flat_grad(val_loss_grad)
        else:
            break
        if batch_idx >= args.val_batch_num: break
    direct_d_val_loss_d_lambda /= (
        batch_idx + 1
    )  # TODO (@Mo): This is very bad, because it does not account for a potentially uneven batch at the end

    get_hyper_train(
        args,
        model).grad = direct_d_val_loss_d_lambda + total_d_val_loss_d_lambda
    print("weight={}, update={}".format(
        get_hyper_train(args, model).norm(),
        get_hyper_train(args, model).grad.norm()))

    hyper_optimizer.step()
    model.zero_grad(), hyper_optimizer.zero_grad()
    return get_hyper_train(args, model), get_hyper_train(args, model).grad
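
KFAC_optimize replaces the Neumann/CG preconditioner with a layer-wise Kronecker-factored inverse: kfac_opt._get_natural_grad applies an approximate inverse curvature to each layer's slice of d_val_loss_d_theta, reshaped to the (out, in + bias) weight shape computed above. The identity being exploited, in my notation (A is the layer's input-activation second-moment factor, G the pre-activation-gradient factor, V the reshaped gradient slice; vec-ordering conventions vary between implementations):

H_\ell \approx A \otimes G
\quad\Longrightarrow\quad
H_\ell^{-1}\,\mathrm{vec}(V) = \mathrm{vec}\!\left(G^{-1} V A^{-1}\right)

with KFAC_damping regularizing the inversion, so each layer needs only two small matrix inverses instead of one num_weights x num_weights inverse.
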
def hyper_step(elementary_lr,
               args,
               model,
               train_loader,
               val_loader,
               augment_net,
               reweighting_net,
               optimizer,
               use_reg,
               reg_anneal_epoch,
               stop_reg_epoch,
               graph_iter,
               device,
               do_true_inverse=False):
    # hyper_step(get_hyper_train, model, val_loss_func, val_loader, old_d_train_loss_d_w, elementary_lr, use_reg, args, train_loader, train_loss_func, elementary_optimizer):
    """
    Estimate the hypergradient, and take an update with it.

    :param elementary_lr:  The elementary optimizer's learning rate, used to scale the Neumann series.
    :param args:  Options such as use_cg, num_neumann_terms, and save_hessian.
    :param model:  The model whose parameters are the elementary parameters we want to tune.
    :param train_loader: A generator for training (x, y) batches.
    :param val_loader: A generator for validation (x, y) batches.
    :param augment_net, reweighting_net:  The hypernetworks whose parameters are the hyperparameters being tuned.
    :param optimizer:  The elementary optimizer; its gradient buffers are zeroed here.
    :param use_reg:  Whether the validation loss has a term that depends directly on the hyperparameters.
    :param reg_anneal_epoch, stop_reg_epoch:  Control annealing of that regularizer.
    :param graph_iter:  Counter threaded through train_loss_func for periodic plotting.
    :param device:  The torch device for the gradient buffers.
    :param do_true_inverse:  If True, invert the training Hessian exactly (pseudo-inverse) instead of approximating it.
    :return: The scalar valued validation loss, the norm of the hypergradient, and the updated graph_iter.
    """
    zero_hypergrad(get_hyper_train, args, model, augment_net, reweighting_net)
    num_weights, num_hypers = sum(p.numel() for p in model.parameters()), sum(
        p.numel()
        for p in get_hyper_train(args, model, augment_net, reweighting_net))
    print(f"num_weights : {num_weights}, num_hypers : {num_hypers}")

    # d_train_loss_d_w = gather_flat_grad(d_train_loss_d_w)  # TODO: Commented this out!
    d_train_loss_d_w = torch.zeros(num_weights).to(device)
    model.train(), model.zero_grad()
    for batch_idx, (x, y) in enumerate(train_loader):
        train_loss, _, graph_iter = train_loss_func(x, y, args, model,
                                                    augment_net,
                                                    reweighting_net,
                                                    graph_iter, device)
        optimizer.zero_grad()
        d_train_loss_d_w += gather_flat_grad(
            grad(train_loss, model.parameters(), create_graph=True))
        break  # TODO (@Mo): Huh?
    optimizer.zero_grad()

    # Compute gradients of the validation loss w.r.t. the weights/hypers
    d_val_loss_d_theta, direct_grad = torch.zeros(num_weights).to(
        device), torch.zeros(num_hypers).to(device)
    model.train(), model.zero_grad()
    for batch_idx, (x, y) in enumerate(val_loader):
        val_loss = val_loss_func(x, y, args, model, augment_net, use_reg,
                                 reg_anneal_epoch, stop_reg_epoch, device)
        optimizer.zero_grad()
        d_val_loss_d_theta += gather_flat_grad(
            grad(val_loss, model.parameters(), retain_graph=use_reg))
        if use_reg:
            direct_grad += gather_flat_grad(
                grad(val_loss, get_hyper_train(args, model, augment_net, reweighting_net), allow_unused=True))
            direct_grad[direct_grad != direct_grad] = 0
        break  # TODO (@Mo): Huh?

    # Initialize the preconditioner and counter
    preconditioner = d_val_loss_d_theta
    if do_true_inverse:
        hessian = torch.zeros(num_weights, num_weights).to(device)
        for i in range(num_weights):
            hess_row = gather_flat_grad(
                grad(d_train_loss_d_w[i],
                     model.parameters(),
                     retain_graph=True))
            hessian[i] = hess_row
            # hessian[-i] = hess_row
        '''
        hessian = hessian.t()
        final_hessian = torch.zeros(num_weights, num_weights).to(device)
        for i in range(num_weights):
            final_hessian[-i] = hessian[i]
        hessian = final_hessian
        '''
        # hessian = hessian  #hessian @ hessian
        # chol = torch.cholesky(hessian.view(1, num_weights, num_weights))[0] + 1e-3*torch.eye(num_weights).to(device)
        inv_hessian = torch.pinverse(hessian)
        # inv_hessian = inv_hessian @ inv_hessian
        preconditioner = d_val_loss_d_theta @ inv_hessian
    elif not args.use_cg:
        preconditioner = neumann_hyperstep_preconditioner(
            d_val_loss_d_theta, d_train_loss_d_w, elementary_lr,
            args.num_neumann_terms, model)
    else:

        def A_vector_multiply_func(vec):
            val = gather_flat_grad(
                grad(d_train_loss_d_w,
                     model.parameters(),
                     grad_outputs=vec.view(-1),
                     retain_graph=True))
            # val_2 = gather_flat_grad(grad(d_train_loss_d_w, model.parameters(), grad_outputs=val.view(-1), retain_graph=True))
            # return val_2.view(1, -1, 1)
            return val.view(1, -1, 1)

        if args.num_neumann_terms > 0:
            preconditioner, _ = cg_batch(A_vector_multiply_func,
                                         d_val_loss_d_theta.view(1, -1, 1),
                                         maxiter=args.num_neumann_terms)

    if args.save_hessian and do_true_inverse:
        '''
        if do_true_inverse:
            name = 'true_inv'
        elif args.use_cg:
            name = 'cg'
        else:
            name = 'neumann_' + str(args.num_neumann_terms)
        '''

        save_hessian(inv_hessian, name='true_inv')
        new_hessian = torch.zeros(inv_hessian.shape).to(device)
        for param_group in optimizer.param_groups:
            cur_step_size = param_group['step_size']
            cur_bias_correction = param_group['bias_correction']
            print(f'size: {cur_step_size}')
            break
        for i in range(10):
            hess_term = torch.eye(inv_hessian.shape[0]).to(device)
            norm_1, norm_2 = torch.norm(torch.eye(
                inv_hessian.shape[0]).to(device),
                                        p=2), torch.norm(hessian, p=2)
            for j in range(i):
                # norm_2 = torch.norm(hessian@hessian, p=2)
                hess_term = hess_term @ (
                    torch.eye(inv_hessian.shape[0]).to(device) -
                    norm_1 / norm_2 * hessian)
            new_hessian += hess_term  # (torch.eye(inv_hessian.shape[0]).to(device) - elementary_lr*0.1*hessian)
            # if (i+1) % 10 == 0 or i == 0:
            save_hessian(new_hessian, name='neumann_' + str(i))
    # conjugate_grad(A_vector_multiply_func, d_val_loss_d_theta)

    # compute d / d lambda (partial Lv / partial w * partial Lt / partial w)
    # = (partial Lv / partial w * partial^2 Lt / (partial w partial lambda))
    indirect_grad = gather_flat_grad(
        grad(d_train_loss_d_w,
             get_hyper_train(args, model, augment_net, reweighting_net),
             grad_outputs=preconditioner.view(-1)))
    hypergrad = direct_grad + indirect_grad

    zero_hypergrad(get_hyper_train, args, model, augment_net, reweighting_net)
    store_hypergrad(get_hyper_train, -hypergrad, args, model, augment_net,
                    reweighting_net)
    # get_hyper_train()[0].grad = hypergrad
    return val_loss, hypergrad.norm(), graph_iter
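
The use_cg branch above hands the same Hessian-vector product to cg_batch instead of the Neumann recursion; both approximate H^{-1} d_val_loss_d_theta without materializing H. For reference, a minimal unbatched conjugate-gradient solve written against a matrix-vector callback analogous to A_vector_multiply_func (an illustrative sketch, not the repo's cg_batch):

import torch


def conjugate_gradient(A_mv, b, num_iters=50, tol=1e-8):
    """Solve A x = b for SPD A, given only the matrix-vector product A_mv."""
    x = torch.zeros_like(b)
    r = b - A_mv(x)                 # residual
    p = r.clone()                   # search direction
    rs_old = r @ r
    for _ in range(num_iters):
        Ap = A_mv(p)
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if rs_new.sqrt() < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


# Tiny sanity check with an explicit SPD matrix standing in for the Hessian.
torch.manual_seed(0)
A = torch.randn(8, 8)
H = A @ A.t() + torch.eye(8)
v = torch.randn(8)
x = conjugate_gradient(lambda u: H @ u, v)
print(torch.allclose(H @ x, v, atol=1e-4))  # True
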
Example #8
def main():
    # Loop over epochs.
    lr = args.lr
    start_time = time.time()

    train_epoch = 0
    global_step = 0

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        # param_optimizer = optim.SGD(model.parameters(), lr=args.lr)
        # param_optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
        param_optimizer = optim.Adam(model.parameters(), lr=args.lr)
        hyper_optimizer = optim.Adam(get_hyper_train(), lr=args.hyper_lr)

        while train_epoch < args.epochs:
            epoch_start_time = time.time()

            model.train()
            xentropy_loss, regularized_loss = train_loss_func()

            # ---------Zero grad------------
            current_index = 0
            for p in model.parameters():
                p_num_params = np.prod(p.shape)
                if p.grad is not None:
                    p.grad = p.grad * 0  # Explicitly zeroing the gradients -- why is this required? Why not model.zero_grad() ?
                current_index += p_num_params
            param_optimizer.zero_grad()
            # -----End of zero grad---------

            train_grad = grad(regularized_loss, model.parameters(), create_graph=True)

            hyper_optimizer.zero_grad()
            val_loss, grad_norm = hyper_step(train_grad)
            # val_loss, grad_norm = hyper_step()
            hyper_optimizer.step()
            # val_loss = torch.zeros(1)

            # Replace the original gradient for the elementary optimizer step.
            current_index = 0
            flat_train_grad = gather_flat_grad(train_grad)
            for p in model.parameters():
                p_num_params = np.prod(p.shape)
                p.grad = flat_train_grad[current_index: current_index + p_num_params].view(p.shape)
                current_index += p_num_params

            param_optimizer.step()
            cur_loss = xentropy_loss
            elapsed = time.time() - start_time

            if global_step % 100 == 0:
                val_loss = evaluate(val_data, test_batch_size)
                test_loss = evaluate(test_data, test_batch_size)

                hparam_dict = make_hparam_dict()
                iteration_dict = { 'iteration': global_step, 'time': time.time() - start_time, 'train_loss': cur_loss.item(),
                                   'val_loss': val_loss, 'train_ppl': math.exp(cur_loss.item()), 'val_ppl': math.exp(val_loss),
                                   'test_loss': test_loss, 'test_ppl': math.exp(test_loss),
                                   **hparam_dict }
                logger.write('iteration', iteration_dict)

                hparam_string = ' | '.join(['{}: {}'.format(key, value) for (key, value) in hparam_dict.items()])

                print('| epoch {:3d} | step {} | lr {:05.5f} | ms/batch {:5.2f} | '
                      'loss {:5.2f} | ppl {:8.2f} | val_loss: {:6.2f} | val ppl: {:6.2f} | test_loss: {:6.2f} | test ppl: {:6.2f} | {}'.format(
                      train_epoch, global_step, param_optimizer.param_groups[0]['lr'], elapsed * 1000 / args.log_interval,
                      cur_loss, math.exp(cur_loss), val_loss, math.exp(val_loss), test_loss, math.exp(test_loss), hparam_string))

            # print('| epoch {:3d} | batches | lr {:05.5f} | ms/batch {:5.2f} | '
            #       'loss {:5.2f} | ppl {:8.2f} | val_loss: {:6.2f} | val ppl: {:6.2f} | wdecay mean: {:6.4e} | wdecay std: {:6.4e} | wdecay: {}'.format(
            #       train_epoch, param_optimizer.param_groups[0]['lr'], elapsed * 1000 / args.log_interval,
            #       cur_loss, math.exp(cur_loss), val_loss, math.exp(val_loss), torch.exp(model.weight_decay).mean().item(),
            #       torch.exp(model.weight_decay).std().item(), model.weight_decay))

            start_time = time.time()
            global_step += 1
            train_epoch += 1
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')
        sys.stdout.flush()
Example #9
def experiment(args):
    # Setup the random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    # Load the baseline model
    args.load_baseline_checkpoint = '/h/lorraine/PycharmProjects/CG_IFT_test/baseline_checkpoints/cifar10_resnet18_sgdm_lr0.1_wd0.0005_aug1.pt'
    args.load_finetune_checkpoint = None  # TODO: Make it load the augment net if this is provided
    model, train_loader, val_loader, test_loader, augment_net, reweighting_net, checkpoint = get_models(args)

    # Load the logger
    from train_augment_net_multiple import load_logger, get_id
    csv_logger, test_id = load_logger(args)
    args.save_loc = './finetuned_checkpoints/' + get_id(args)

    # Hyperparameter access functions
    def get_hyper_train():
        # return torch.cat([p.view(-1) for p in augment_net.parameters()])
        if args.use_augment_net and args.use_reweighting_net:
            return list(augment_net.parameters()) + list(reweighting_net.parameters())
        elif args.use_augment_net:
            return augment_net.parameters()
        elif args.use_reweighting_net:
            return reweighting_net.parameters()

    def get_hyper_train_flat():
        if args.use_augment_net and args.use_reweighting_net:
            return torch.cat([torch.cat([p.view(-1) for p in augment_net.parameters()]),
                              torch.cat([p.view(-1) for p in reweighting_net.parameters()])])
        elif args.use_reweighting_net:
            return torch.cat([p.view(-1) for p in reweighting_net.parameters()])
        elif args.use_augment_net:
            return torch.cat([p.view(-1) for p in augment_net.parameters()])

    # Setup the optimizers
    if args.load_baseline_checkpoint is not None:
        args.lr = args.lr * 0.2 * 0.2 * 0.2
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True, weight_decay=args.wdecay)
    scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)  # [60, 120, 160]
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    hyper_optimizer = optim.Adam(get_hyper_train(), lr=1e-3)  # Adam(get_hyper_train())
    hyper_scheduler = MultiStepLR(hyper_optimizer, milestones=[40, 100, 140], gamma=0.2)

    graph_iter = 0
    def train_loss_func(x, y):
        x, y = x.cuda(), y.cuda()
        reg = 0.

        if args.use_augment_net:
            # old_x = x
            x = augment_net(x, class_label=y)
            '''num_sample = 10
            xs =torch.zeros(num_sample, x.shape[0], x.shape[1], x.shape[2], x.shape[3]).cuda()
            for i in range(num_sample):
                xs[i] = augment_net(x, class_label=y)
            xs_diffs = (torch.mean(xs, dim=0) - old_x) ** 2
            diff_loss = torch.mean(xs_diffs)
            entrop_loss = -torch.mean(torch.std(xs, dim=0) ** 2)
            reg = 10 * diff_loss + entrop_loss'''

        pred = model(x)
        xentropy_loss = F.cross_entropy(pred, y, reduction='none')

        if args.use_reweighting_net:
            loss_weights = reweighting_net(x)  # TODO: Or reweighting_net(augment_x) ??
            loss_weights = loss_weights.squeeze()
            loss_weights = torch.sigmoid(loss_weights / 10.0) * 2.0 + 0.1  # bound the weights to (0.1, 2.1)
            # loss_weights = (loss_weights - torch.mean(loss_weights)) / torch.std(loss_weights)
            # loss_weights = F.softmax(loss_weights)
            # loss_weights = loss_weights * args.batch_size
            # TODO: Want loss_weight vs x_entropy_loss

            nonlocal graph_iter
            if graph_iter % 100 == 0:
                import matplotlib.pyplot as plt
                np_loss = xentropy_loss.data.cpu().numpy()
                np_weight = loss_weights.data.cpu().numpy()
                for i in range(10):
                    # Indices of this class's examples in the batch
                    class_indices = np.where((y == i).cpu().numpy())[0]
                    plt.scatter(np_loss[class_indices], np_weight[class_indices], alpha=0.5, label=str(i))
                # plt.scatter((xentropy_loss*loss_weights).data.cpu().numpy(), loss_weights.data.cpu().numpy(), alpha=0.5, label='weighted')
                # print(np_loss)
                plt.ylim([np.min(np_weight) / 2.0, np.max(np_weight) * 2.0])
                plt.xlim([np.min(np_loss) / 2.0, np.max(np_loss) * 2.0])
                plt.yscale('log')
                plt.xscale('log')
                plt.axhline(1.0, c='k')
                plt.ylabel("loss_weights")
                plt.xlabel("xentropy_loss")
                plt.legend()
                plt.savefig("images/aaaa_lossWeightvsEntropy.pdf")
                plt.clf()

            xentropy_loss = xentropy_loss * loss_weights
        graph_iter += 1

        xentropy_loss = xentropy_loss.mean() + reg
        return xentropy_loss, pred

    use_reg = args.use_augment_net
    reg_anneal_epoch = 0
    stop_reg_epoch = 200
    if args.reg_weight == 0:
        use_reg = False

    def val_loss_func(x, y):
        x, y = x.cuda(), y.cuda()
        pred = model(x)
        xentropy_loss = F.cross_entropy(pred, y)

        reg = 0
        if args.use_augment_net:
            if use_reg:
                num_sample = 10
                xs = torch.zeros(num_sample, x.shape[0], x.shape[1], x.shape[2], x.shape[3]).cuda()
                for i in range(num_sample):
                    xs[i] = augment_net(x, class_label=y)
                xs_diffs = (torch.abs(torch.mean(xs, dim=0) - x))
                diff_loss = torch.mean(xs_diffs)
                stds = torch.std(xs, dim=0)
                entrop_loss = -torch.mean(stds)
                # TODO : Remember to add direct grad back in to hyper_step
                reg = args.reg_weight * (diff_loss + entrop_loss)
            else:
                reg = 0

        # reg *= (args.num_finetune_epochs - reg_anneal_epoch) / (args.num_finetune_epochs + 2)
        if reg_anneal_epoch >= stop_reg_epoch:
            reg *= 0
        return xentropy_loss + reg

    def test(loader, do_test_augment=True, num_augment=10):
        model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
        correct, total = 0., 0.
        losses = []
        for images, labels in loader:
            images, labels = images.cuda(), labels.cuda()

            with torch.no_grad():
                pred = model(images)
                if do_test_augment:
                    if args.use_augment_net and args.num_neumann_terms >= 0:
                        shape_0, shape_1 = pred.shape[0], pred.shape[1]
                        pred = pred.view(1, shape_0, shape_1)  # Batch size, num_classes
                        for _ in range(num_augment):
                            pred = torch.cat((pred, model(augment_net(images)).view(1, shape_0, shape_1)))
                        pred = torch.mean(pred, dim=0)
                xentropy_loss = F.cross_entropy(pred, labels)
                losses.append(xentropy_loss.item())

            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels).sum().item()

        avg_loss = float(np.mean(losses))
        acc = correct / total
        model.train()
        return avg_loss, acc

    # print(f"Initial Val Loss: {test(val_loader)}")
    # print(f"Initial Test Loss: {test(test_loader)}")

    init_time = time.time()
    val_loss, val_acc = test(val_loader)
    test_loss, test_acc = test(test_loader)
    print(f"Initial Val Loss: {val_loss, val_acc}")
    print(f"Initial Test Loss: {test_loss, test_acc}")
    iteration = 0
    for epoch in range(0, args.num_finetune_epochs):
        reg_anneal_epoch = epoch
        xentropy_loss_avg = 0.
        total_val_loss, val_loss = 0., 0.
        correct = 0.
        total = 0.
        weight_norm, grad_norm = .0, .0

        progress_bar = tqdm(train_loader)
        num_tune_hyper = 45000 / 5000  # Train set is 9x the val set, so take one hyper step every 9 elementary steps
        hyper_num = 0
        for i, (images, labels) in enumerate(progress_bar):
            progress_bar.set_description('Finetune Epoch ' + str(epoch))

            images, labels = images.cuda(), labels.cuda()
            # pred = model(images)
            xentropy_loss, pred = train_loss_func(images, labels)  # F.cross_entropy(pred, labels)
            xentropy_loss_avg += xentropy_loss.item()

            current_index = 0
            for p in model.parameters():
                p_num_params = np.prod(p.shape)
                if p.grad is not None:
                    p.grad = p.grad * 0
                current_index += p_num_params
            # optimizer.zero_grad()
            train_grad = grad(xentropy_loss, model.parameters(), create_graph=True)  #

            if args.num_neumann_terms >= 0:  # if this is less than 0, then don't do hyper_steps
                if i % num_tune_hyper == 0:
                    cur_lr = 1.0
                    for param_group in optimizer.param_groups:
                        cur_lr = param_group['lr']
                        break
                    val_loss, grad_norm = hyper_step(get_hyper_train, model, val_loss_func, val_loader,
                                                     train_grad, cur_lr, use_reg, args)
                    hyper_optimizer.step()

                    weight_norm = get_hyper_train_flat().norm()
                    total_val_loss += val_loss.item()
                    hyper_num += 1

            # Replace the original gradient for the elementary optimizer step.
            current_index = 0
            flat_train_grad = gather_flat_grad(train_grad)
            for p in model.parameters():
                p_num_params = np.prod(p.shape)
                # if p.grad is not None:
                p.grad = flat_train_grad[current_index: current_index + p_num_params].view(p.shape)
                current_index += p_num_params
            optimizer.step()

            iteration += 1

            # Calculate running average of accuracy
            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels.data).sum().item()
            accuracy = correct / total

            progress_bar.set_postfix(
                train='%.4f' % (xentropy_loss_avg / (i + 1)),
                val='%.4f' % (total_val_loss / max(hyper_num, 1)),
                acc='%.4f' % accuracy,
                weight='%.3f' % weight_norm,
                update='%.3f' % grad_norm
            )
            if i % (num_tune_hyper ** 2) == 0 and args.use_augment_net:
                from train_augment_net_graph import save_images
                if args.do_diagnostic:
                    save_images(images, labels, augment_net, args)
                saver(epoch, model, optimizer, augment_net, reweighting_net, hyper_optimizer, args.save_loc)
                val_loss, val_acc = test(val_loader)
                csv_logger.writerow({'epoch': str(epoch),
                                     'train_loss': str(xentropy_loss_avg / (i + 1)), 'train_acc': str(accuracy),
                                     'val_loss': str(val_loss), 'val_acc': str(val_acc),
                                     'test_loss': str(test_loss), 'test_acc': str(test_acc),
                                     'run_time': time.time() - init_time,
                                     'iteration': iteration})

        val_loss, val_acc = test(val_loader)
        test_loss, test_acc = test(test_loader)
        tqdm.write('val loss: {:6.4f} | val acc: {:6.4f} | test loss: {:6.4f} | test_acc: {:6.4f}'.format(
            val_loss, val_acc, test_loss, test_acc))

        scheduler.step(epoch)  # , hyper_scheduler.step(epoch)
        csv_logger.writerow({'epoch': str(epoch),
                             'train_loss': str(xentropy_loss_avg / (i + 1)), 'train_acc': str(accuracy),
                             'val_loss': str(val_loss), 'val_acc': str(val_acc),
                             'test_loss': str(test_loss), 'test_acc': str(test_acc),
                             'run_time': time.time() - init_time, 'iteration': iteration})