def graph_single_args(args, save_loc=None):
    """

    :param args:
    :return:
    """
    from train_augment_net_multiple import get_id
    if save_loc is None:
        args.save_loc = finetune_location + get_id(args)
    else:
        args.save_loc = save_loc

    args.load_baseline_checkpoint = None
    args.load_finetune_checkpoint = args.save_loc + '/checkpoint.pt'

    args.data_augmentation = False  # Don't use data augmentation for constructing graphs

    from train_augment_net2 import get_models
    model, train_loader, val_loader, test_loader, augment_net, reweighting_net, checkpoint = get_models(
        args)

    progress_bar = tqdm(train_loader)
    for i, (images, labels) in enumerate(progress_bar):
        images, labels = images.cuda(), labels.cuda()

        save_images(images, labels, augment_net, args)
    def run_val_prop_compare(self):
        # TODO (@Mo): Use itertools.product to flatten these loops; see the sketch after this method
        for seed in self.seeds:
            for dataset in self.datasets:
                for hyperparam in self.hyperparams:
                    for data_size in self.data_sizes:
                        data_to_save = {
                            'val_losses': [],
                            'val_accs': [],
                            'test_losses': [],
                            'test_accs': [],
                            'val_losses_re': [],
                            'val_accs_re': [],
                            'test_losses_re': [],
                            'test_accs_re': [],
                            'info': ''
                        }
                        for val_prop in self.val_props:
                            print(
                                f"seed:{seed}, dataset:{dataset}, hyperparam:{hyperparam}, data_size:{data_size}, prop:{val_prop}"
                            )
                            args = make_val_size_compare_finetune_params(
                                hyperparam, val_prop, data_size, dataset,
                                self.model, self.num_finetune_epochs, self.lr)
                            args.seed = seed
                            train_loss, accuracy, val_loss, val_acc, test_loss, test_acc = experiment(
                                args, self.device)
                            data_to_save['val_losses'] += [val_loss]
                            data_to_save['val_accs'] += [val_acc]
                            data_to_save['test_losses'] += [test_loss]
                            data_to_save['test_accs'] += [test_acc]

                            second_args = make_val_size_compare_finetune_params(
                                hyperparam, 0, data_size, dataset, self.model,
                                self.num_finetune_epochs, self.lr)
                            second_args.seed = seed
                            second_args.num_neumann_terms = -1
                            loc = '/sailhome/motiwari/data-augmentation/implicit-hyper-opt/CG_IFT_test/finetuned_checkpoints/'
                            loc += get_id(args) + '/'
                            loc += 'checkpoint.pt'
                            second_args.load_finetune_checkpoint = loc
                            train_loss_re, accuracy_re, val_loss_re, val_acc_re, test_loss_re, test_acc_re = experiment(
                                second_args, self.device)
                            data_to_save['val_losses_re'] += [val_loss_re]
                            data_to_save['val_accs_re'] += [val_acc_re]
                            data_to_save['test_losses_re'] += [test_loss_re]
                            data_to_save['test_accs_re'] += [test_acc_re]
                        '''print(f"Data size = {data_size}")
                        print(f"Proportions: {val_props}")
                        print(f"val_losses: {data_to_save['val_losses']}")
                        print(f"val_accuracies: {data_to_save['val_accs']}")
                        print(f"test_losses: {data_to_save['test_losses']}")
                        print(f"test_accuracies: {data_to_save['test_accs']}")'''
                        with open(
                                f'finetuned_checkpoints/dataset:{dataset}_datasize:{data_size}_hyperparam:{hyperparam}_seed:{seed}.pkl',
                                'wb') as f:
                            pickle.dump(data_to_save, f)
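    # Sketch of the itertools.product refactor flagged in the TODO above.  It is
    # illustrative only: it assumes the same attributes as run_val_prop_compare and
    # elides the per-configuration body rather than duplicating it.
    def _run_val_prop_compare_with_product_sketch(self):
        from itertools import product
        for seed, dataset, hyperparam, data_size in product(
                self.seeds, self.datasets, self.hyperparams, self.data_sizes):
            for val_prop in self.val_props:
                print(f"seed:{seed}, dataset:{dataset}, hyperparam:{hyperparam}, "
                      f"data_size:{data_size}, prop:{val_prop}")
                # ... same per-configuration body as run_val_prop_compare ...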
def experiment(args, device):
    if args.do_print:
        print(args)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    # Load the baseline model
    args.load_baseline_checkpoint = None  # '/h/lorraine/PycharmProjects/CG_IFT_test/baseline_checkpoints/cifar10_resnet18_sgdm_lr0.1_wd0.0005_aug1.pt'
    args.load_finetune_checkpoint = ''  # TODO: Make it load the augment net if this is provided

    train_loader, val_loader, test_loader = DataLoaders.get_data_loaders(
        dataset=args.dataset,
        batch_size=args.batch_size,
        train_size=args.train_size,
        val_size=args.val_size,
        test_size=args.test_size,
        num_train=50000,
        data_augment=args.data_augmentation)
    model_loader = ModelLoader(args, device)
    model, augment_net, reweighting_net, checkpoint = model_loader.get_models()

    # Load the logger
    csv_logger, test_id = load_logger(args)
    args.save_loc = './finetuned_checkpoints/' + get_id(args)

    # Setup the optimizers
    if args.load_baseline_checkpoint is not None:
        args.lr = args.lr * 0.2 * 0.2 * 0.2  # TODO (@Mo): oh my god no
    if args.use_weight_decay:
        # optimizer = optim.Adam(model.parameters(), lr=1e-3)
        args.wdecay = 0

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          nesterov=True,
                          weight_decay=args.wdecay)
    if args.dataset == DATASET_BOSTON:
        optimizer = optim.Adam(model.parameters())
    use_scheduler = not args.do_simple
    scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160],
                            gamma=0.2)  # [60, 120, 160]
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    use_hyper_scheduler = False
    hyper_optimizer = optim.RMSprop(
        get_hyper_train(args, model, augment_net, reweighting_net))
    if not args.do_simple:
        hyper_optimizer = optim.SGD(get_hyper_train(args, model, augment_net,
                                                    reweighting_net),
                                    lr=args.lr,
                                    momentum=0.9,
                                    nesterov=True)
        use_hyper_scheduler = True
    hyper_scheduler = MultiStepLR(hyper_optimizer,
                                  milestones=[40, 100, 140],
                                  gamma=0.2)

    graph_iter = 0
    use_reg = args.use_augment_net and not args.use_reweighting_net
    reg_anneal_epoch = 0
    stop_reg_epoch = 200
    if args.reg_weight == 0:
        use_reg = False

    init_time = time.time()
    val_loss, val_acc = test(val_loader, args, model, augment_net, device)
    test_loss, test_acc = test(test_loader, args, model, augment_net, device)
    if args.do_print:
        print(f"Initial Val Loss: {val_loss, val_acc}")
        print(f"Initial Test Loss: {test_loss, test_acc}")
    iteration = 0
    hypergradient_cos_diff, hypergradient_l2_diff = -1, -1
    for epoch in range(0, args.num_finetune_epochs):
        reg_anneal_epoch = epoch
        xentropy_loss_avg = 0.
        total_val_loss, val_loss = 0., 0.
        correct = 0.
        total = 0.
        weight_norm, grad_norm = .0, .0

        if args.do_print:
            progress_bar = tqdm(train_loader)
        else:
            progress_bar = train_loader
        num_tune_hyper = 45000 // 5000  # train/val size ratio: one hyper step every 9 train batches
        if args.do_simple:
            num_tune_hyper = 1
        hyper_num = 0
        for i, (images, labels) in enumerate(progress_bar):
            if args.do_print:
                progress_bar.set_description('Finetune Epoch ' + str(epoch))

            images, labels = images.to(device), labels.to(device)
            # pred = model(images)
            optimizer.zero_grad()  # TODO: ADDED
            xentropy_loss, pred, graph_iter = train_loss_func(
                images, labels, args, model, augment_net, reweighting_net,
                graph_iter, device)  # F.cross_entropy(pred, labels)
            xentropy_loss.backward()  # TODO: ADDED
            optimizer.step()  # TODO: ADDED
            optimizer.zero_grad()  # TODO: ADDED
            xentropy_loss_avg += xentropy_loss.item()

            if epoch > args.warmup_epochs and args.num_neumann_terms >= 0 and args.load_finetune_checkpoint == '':  # if this is less than 0, then don't do hyper_steps
                if i % num_tune_hyper == 0:
                    cur_lr = 1.0
                    for param_group in optimizer.param_groups:
                        cur_lr = param_group['lr']
                        break
                    train_grad = None  # TODO: ADDED
                    val_loss, grad_norm, graph_iter = hyper_step(
                        cur_lr, args, model, train_loader, val_loader,
                        augment_net, reweighting_net, optimizer, use_reg,
                        reg_anneal_epoch, stop_reg_epoch, graph_iter, device)

                    # Optionally compare this approximate hypergradient against one
                    # computed with the true inverse, logging the L2 and cosine gaps.
                    if args.do_inverse_compare:
                        approx_hypergradient = get_hyper_train_flat(
                            args, model, augment_net, reweighting_net).grad
                        # TODO: Call hyper_step with the true inverse
                        _, _, graph_iter = hyper_step(cur_lr,
                                                      args,
                                                      model,
                                                      train_loader,
                                                      val_loader,
                                                      augment_net,
                                                      reweighting_net,
                                                      optimizer,
                                                      use_reg,
                                                      reg_anneal_epoch,
                                                      stop_reg_epoch,
                                                      graph_iter,
                                                      device,
                                                      do_true_inverse=True)
                        true_hypergradient = get_hyper_train_flat(
                            args, model, augment_net, reweighting_net).grad
                        hypergradient_l2_norm = torch.norm(
                            true_hypergradient - approx_hypergradient, p=2)
                        norm_1, norm_2 = torch.norm(true_hypergradient,
                                                    p=2), torch.norm(
                                                        approx_hypergradient,
                                                        p=2)
                        hypergradient_cos_norm = (
                            true_hypergradient @ approx_hypergradient) / (
                                norm_1 * norm_2)
                        hypergradient_cos_diff = hypergradient_cos_norm.item()
                        hypergradient_l2_diff = hypergradient_l2_norm.item()
                        print(
                            f"hypergrad_diff, l2: {hypergradient_l2_norm}, cos: {hypergradient_cos_norm}"
                        )
                    # get_hyper_train, model, val_loss_func, val_loader, train_grad, cur_lr, use_reg, args, train_loader, train_loss_func, optimizer)
                    hyper_optimizer.step()

                    weight_norm = get_hyper_train_flat(args, model,
                                                       augment_net,
                                                       reweighting_net).norm()
                    total_val_loss += val_loss.item()
                    hyper_num += 1

            # Replace the original gradient for the elementary optimizer step.
            '''
            current_index = 0
            flat_train_grad = gather_flat_grad(train_grad)
            for p in model.parameters():
                p_num_params = np.prod(p.shape)
                # if p.grad is not None:
                p.grad = flat_train_grad[current_index: current_index + p_num_params].view(p.shape)
                current_index += p_num_params
            optimizer.step()
            '''

            iteration += 1

            # Calculate running average of accuracy
            if args.do_classification:
                pred = torch.max(pred.data, 1)[1]
                total += labels.size(0)
                correct += (pred == labels.data).sum().item()
                accuracy = correct / total
            else:
                total = 1
                accuracy = 0

            if args.do_print:
                progress_bar.set_postfix(
                    train='%.4f' % (xentropy_loss_avg / (i + 1)),
                    val='%.4f' % (total_val_loss / max(hyper_num, 1)),
                    acc='%.4f' % accuracy,
                    weight='%.3f' % weight_norm,
                    update='%.3f' % grad_norm)
            if i % (num_tune_hyper**2) == 0:
                if args.use_augment_net:
                    if args.do_diagnostic:
                        save_images(images, labels, augment_net, args)
                if not args.do_simple or args.do_inverse_compare:
                    if not args.do_simple:
                        save_models(epoch, model, optimizer, augment_net,
                                    reweighting_net, hyper_optimizer,
                                    args.save_loc)
                    val_loss, val_acc = test(val_loader, args, model,
                                             augment_net, device)
                    csv_logger.writerow(epoch, xentropy_loss_avg / (i + 1),
                                        accuracy, val_loss, val_acc, test_loss,
                                        test_acc, hypergradient_cos_diff,
                                        hypergradient_l2_diff,
                                        time.time() - init_time, iteration)
        if use_scheduler:
            scheduler.step(epoch)
        if use_hyper_scheduler:
            hyper_scheduler.step(epoch)
        train_loss = xentropy_loss_avg / (i + 1)

        if not args.only_print_final_vals:
            val_loss, val_acc = test(val_loader, args, model, augment_net,
                                     device)
            # if val_acc >= 0.99 and accuracy >= 0.99 and epoch >= 50: break
            test_loss, test_acc = test(test_loader, args, model, augment_net,
                                       device)
            tqdm.write(
                'epoch: {:d} | val loss: {:6.4f} | val acc: {:6.4f} | test loss: {:6.4f} | test_acc: {:6.4f}'
                .format(epoch, val_loss, val_acc, test_loss, test_acc))

            csv_logger.writerow(epoch, train_loss, accuracy, val_loss, val_acc,
                                test_loss, test_acc, hypergradient_cos_diff,
                                hypergradient_l2_diff,
                                time.time() - init_time, iteration)
        elif args.do_print:
            val_loss, val_acc = test(val_loader,
                                     args,
                                     model,
                                     augment_net,
                                     device,
                                     do_test_augment=False)
            tqdm.write('val loss: {:6.4f} | val acc: {:6.4f}'.format(
                val_loss, val_acc))

    val_loss, val_acc = test(val_loader, args, model, augment_net, device)
    test_loss, test_acc = test(test_loader, args, model, augment_net, device)
    save_models(args.num_finetune_epochs, model, optimizer, augment_net,
                reweighting_net, hyper_optimizer, args.save_loc)
    return train_loss, accuracy, val_loss, val_acc, test_loss, test_acc
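# ---------------------------------------------------------------------------
# Illustrative sketch (not this repo's hyper_step): args.num_neumann_terms and
# args.use_cg suggest hyper_step approximates the implicit-function-theorem
# hypergradient with a truncated Neumann series (or CG) for the
# inverse-Hessian-vector product.  The helper below shows the Neumann
# approximation in isolation; every name is local to this example and the step
# size `alpha` is an assumption, not a value taken from the repo.
# ---------------------------------------------------------------------------
def _neumann_inverse_hvp_sketch(train_loss, params, v, alpha=0.1, num_terms=3):
    """Approximate H^{-1} v by alpha * sum_{k=0}^{K} (I - alpha * H)^k v, where
    H is the Hessian of `train_loss` w.r.t. `params` and v matches params' shapes."""
    from torch.autograd import grad as torch_grad
    dtrain = torch_grad(train_loss, params, create_graph=True)
    p = [vi.detach().clone() for vi in v]    # current term (I - alpha*H)^k v
    acc = [pi.clone() for pi in p]           # running sum, starts with the k=0 term
    for _ in range(num_terms):
        hvp = torch_grad(dtrain, params, grad_outputs=p, retain_graph=True)  # H p
        p = [pi - alpha * hi for pi, hi in zip(p, hvp)]
        acc = [ai + pi for ai, pi in zip(acc, p)]
    return [alpha * ai for ai in acc]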
def graph_final_multiple_args(argss,
                              ylabels,
                              name_supplement='',
                              xmin=0,
                              legend_loc=None,
                              fontsize=12,
                              handlelength=None):
    fig, axs = init_ax(fontsize=fontsize, nrows=len(ylabels))

    cg_data = {ylabel: [] for ylabel in ylabels}
    neumann_data = {ylabel: [] for ylabel in ylabels}
    for args in argss:
        temp_args = copy.deepcopy(args)
        # Copy with the seed fixed so runs differing only by seed share one id
        temp_args.seed = float(1)

        from train_augment_net_multiple import get_id
        args.save_loc = finetune_location + get_id(args)
        try:
            args_data = load_from_csv(args.save_loc, do_val=True, do_test=True)
        except FileNotFoundError:
            print(f"Can't load {args.save_loc}")
            break

        for ylabel in ylabels:
            smoothed_data = smooth_data(args_data[ylabel],
                                        num_smooth=10)[xmin:]
            print(f"{ylabel} final smoothed value: {smoothed_data[-1]}")
            if args.use_cg:
                cg_data[ylabel] += [smoothed_data[-1]]
            else:
                neumann_data[ylabel] += [smoothed_data[-1]]

    num_smooth = 10
    for label_index, ylabel in enumerate(ylabels):
        # axs[label_index].set_ylabel(ylabel)
        if ylabel == 'hypergradient_l2_diff':
            axs[label_index].semilogy(range(len(cg_data[ylabel])),
                                      smooth_data(cg_data[ylabel], num_smooth),
                                      label='CG',
                                      linestyle='--',
                                      linewidth=linewidth)
            axs[label_index].semilogy(range(len(neumann_data[ylabel])),
                                      smooth_data(neumann_data[ylabel],
                                                  num_smooth),
                                      label='Neumann',
                                      linestyle='--',
                                      linewidth=linewidth)
        else:
            axs[label_index].plot(range(len(cg_data[ylabel])),
                                  smooth_data(cg_data[ylabel], num_smooth),
                                  label='CG',
                                  linestyle='--',
                                  linewidth=linewidth)
            axs[label_index].plot(range(len(neumann_data[ylabel])),
                                  smooth_data(neumann_data[ylabel],
                                              num_smooth),
                                  label='Neumann',
                                  linestyle='--',
                                  linewidth=linewidth)

    for label_index, ylabel in enumerate(ylabels):
        # if label_index % len(xlabels) == 0:  # Only label the left side with y-labels
        #    axs[label_index].set_ylabel(ylabel)
        # else:
        axs[label_index].set_yticks([])

        if label_index < (len(ylabels) - 1):
            axs[label_index].set_xticks([])

        if ylabel == 'hypergradient_cos_diff':
            axs[label_index].set_ylim([0.0, 1.0])
        elif ylabel == 'hypergradient_l2_diff':
            axs[label_index].set_ylim([10e-4, 10e3])

    axs = [
        setup_ax(ax,
                 alpha=0.75,
                 fontsize=fontsize,
                 legend_loc=legend_loc,
                 handlelength=handlelength) for ax in axs[-1:]
    ]

    name = f"./images/graph_final_multiple_args_{name_supplement}"
    fig.savefig(name + ".pdf", bbox_inches='tight')
    plt.close(fig)
def graph_multiple_args(argss,
                        ylabels,
                        name_supplement='',
                        xmin=0,
                        legend_loc=None,
                        fontsize=12,
                        handlelength=None):
    xlabels = ['iteration']  # , 'run_time']  # , 'epoch']

    fig, axs = init_ax(fontsize=fontsize,
                       ncols=len(xlabels),
                       nrows=len(ylabels))

    color_dict = {}
    data_dict = {ylabel: {} for ylabel in ylabels}
    min_xlim = 10e32
    for args in argss:
        temp_args = copy.deepcopy(args)
        # Copy with the seed fixed so runs differing only by seed share one id
        temp_args.seed = float(1)

        from train_augment_net_multiple import get_id
        args.save_loc = finetune_location + get_id(args)
        try:
            args_data = load_from_csv(args.save_loc, do_val=True, do_test=True)
        except FileNotFoundError:
            print(f"Can't load {args.save_loc}")
            break
        label_index = 0
        for ylabel in ylabels:
            all_smoothed_data = []
            for xlabel in xlabels:
                smoothed_data = smooth_data(args_data[ylabel],
                                            num_smooth=20)[xmin:]
                if args.seed == 1:
                    label = get_id(args)
                    if ylabel in ('hypergradient_l2_diff', 'hypergradient_cos_diff'):
                        if args.use_cg:
                            label = str(args.num_neumann_terms) + ' CG Steps'
                        else:
                            label = str(args.num_neumann_terms) + ' Neumann'
                    if ylabel == 'hypergradient_l2_diff':
                        plot = axs[label_index].semilogy(
                            args_data[xlabel][xmin:],
                            smoothed_data,
                            label=label,
                            alpha=1.0,
                            linewidth=linewidth)
                    else:
                        plot = axs[label_index].plot(args_data[xlabel][xmin:],
                                                     smoothed_data,
                                                     label=label,
                                                     alpha=1.0,
                                                     linewidth=linewidth)
                    plot = plot[0]
                    color_dict[get_id(args)] = plot.get_color()
                    data_dict[ylabel][get_id(args)] = []
                else:
                    if ylabel == 'hypergradient_l2_diff':
                        plot = axs[label_index].semilogy(
                            args_data[xlabel][xmin:],
                            smoothed_data,
                            alpha=1.0,
                            color=color_dict[get_id(temp_args)],
                            linewidth=linewidth)
                    else:
                        plot = axs[label_index].plot(
                            args_data[xlabel][xmin:],
                            smoothed_data,
                            alpha=1.0,
                            color=color_dict[get_id(temp_args)],
                            linewidth=linewidth)
                # smoothed_data = smooth_data(args_data[ylabel], num_smooth=20)
                '''all_smoothed_data += [smoothed_data]
                if args.seed == 1:
                    data_dict[ylabel][get_id(args)] = []
                elif args.seed == 3:
                    all_smoothed_data = np.array(all_smoothed_data)
                    mean = np.mean(all_smoothed_data, axis=0)
                    std = np.std(all_smoothed_data, axis=0)
                    plot = axs[label_index].errorbar(args_data[xlabel],
                                                 mean, std,
                                                 label=get_id(args), alpha=0.5)'''
                min_xlim = min(min_xlim, args_data[xlabel][-1])
                # if args.seed == 1:
                #    data_dict[ylabel][get_id(args)] = []
                data_dict[ylabel][get_id(temp_args)] += [smoothed_data[-1]]
                if False:  # ylabel[:3] == 'val' and xlabel == 'iteration':
                    diagnostic = f"y: {ylabel}"
                    # diagnostic += f", x: {xlabel}"
                    diagnostic += f", num_neumann: {args.num_neumann_terms}"
                    diagnostic += f", reg: {args.reg_weight}"
                    diagnostic += f", cg: {args.use_cg}"
                    diagnostic += f", seed: {args.seed}"
                    diagnostic += f", final value: {args_data[ylabel][-1]}"
                    print(diagnostic)

                label_index += 1
    # TODO: Store in array based on seed.
    # TODO: Compute mean and std_dev for each method
    # print(data_dict)
    label_index = 0
    for ylabel in data_dict:
        print(f"Value: {ylabel}")
        for id in data_dict[ylabel]:
            # mean = np.mean(data_dict[ylabel][id])
            # print(mean)
            # entry = np.asarray([np.array(x) for x in pre_entry])
            print(f"    ID: {id}")
            print(
                f"        mean = {np.mean(data_dict[ylabel][id]):.4f}, std = {np.std(data_dict[ylabel][id]):.4f}"
            )
            print(
                f"        max = {np.max(data_dict[ylabel][id]):.4f}, min = {np.min(data_dict[ylabel][id]):.4f}"
            )
        label_index += 1

    label_index = 0
    for ylabel in ylabels:
        for xlabel in xlabels:
            if label_index % len(xlabels) == 0:  # Only label the left side with y-labels
                # axs[label_index].set_ylabel(ylabel)
                pass
            else:
                axs[label_index].set_yticks([])

            if label_index > (len(ylabels) - 1) * (
                    len(xlabels)) - 1:  # Only label the bottom with x-labels
                # axs[label_index].set_xlabel(xlabel)
                pass
            else:
                axs[label_index].set_xticks([])

            if ylabel[-3:] == 'acc':
                axs[label_index].set_ylim([0.92, 0.94])  # [.93, .965])
            elif ylabel[:3] in ['val', 'tes'] and ylabel[-4:] == 'loss':
                axs[label_index].set_ylim([.23, .35])  # [.17, .25])
            elif ylabel[-4:] == 'loss':
                axs[label_index].set_ylim([0.11, 0.17])
            elif ylabel == 'hypergradient_cos_diff':
                axs[label_index].set_ylim([0.0, 1.0])
            elif ylabel == 'hypergradient_l2_diff':
                axs[label_index].set_ylim([10e-4, 10e3])

            # if xmin is not None:
            #
            #    axs[label_index].set_xlim([xmin, axs[label_index].get_xlim()[-1]])
            label_index += 1

    axs = [
        setup_ax(ax,
                 alpha=0.75,
                 fontsize=fontsize,
                 legend_loc=legend_loc,
                 handlelength=handlelength) for ax in axs[-1:]
    ]

    name = f"./images/graph_multiple_args_{name_supplement}"
    fig.savefig(name + ".pdf", bbox_inches='tight')
    plt.close(fig)
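# Sketch (assumption): smooth_data is used above as a 1-D smoothing helper over the
# logged curves; a trailing moving average like the one below plays that role, but the
# real helper lives elsewhere in the repo and may differ in detail.
def _smooth_data_sketch(values, num_smooth=10):
    import numpy as np
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    window = max(1, min(num_smooth, values.size))
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode='valid')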
def experiment(args):
    # Setup the random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    # Load the baseline model
    args.load_baseline_checkpoint = '/h/lorraine/PycharmProjects/CG_IFT_test/baseline_checkpoints/cifar10_resnet18_sgdm_lr0.1_wd0.0005_aug1.pt'
    args.load_finetune_checkpoint = None  # TODO: Make it load the augment net if this is provided
    model, train_loader, val_loader, test_loader, augment_net, reweighting_net, checkpoint = get_models(args)

    # Load the logger
    from train_augment_net_multiple import load_logger, get_id
    csv_logger, test_id = load_logger(args)
    args.save_loc = './finetuned_checkpoints/' + get_id(args)

    # Hyperparameter access functions
    def get_hyper_train():
        # return torch.cat([p.view(-1) for p in augment_net.parameters()])
        if args.use_augment_net and args.use_reweighting_net:
            return list(augment_net.parameters()) + list(reweighting_net.parameters())
        elif args.use_augment_net:
            return augment_net.parameters()
        elif args.use_reweighting_net:
            return reweighting_net.parameters()

    def get_hyper_train_flat():
        if args.use_augment_net and args.use_reweighting_net:
            return torch.cat([torch.cat([p.view(-1) for p in augment_net.parameters()]),
                              torch.cat([p.view(-1) for p in reweighting_net.parameters()])])
        elif args.use_reweighting_net:
            return torch.cat([p.view(-1) for p in reweighting_net.parameters()])
        elif args.use_augment_net:
            return torch.cat([p.view(-1) for p in augment_net.parameters()])

    # Setup the optimizers
    if args.load_baseline_checkpoint is not None:
        args.lr = args.lr * 0.2 * 0.2 * 0.2
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True, weight_decay=args.wdecay)
    scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)  # [60, 120, 160]
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    hyper_optimizer = optim.Adam(get_hyper_train(), lr=1e-3)  # Adam(get_hyper_train())
    hyper_scheduler = MultiStepLR(hyper_optimizer, milestones=[40, 100, 140], gamma=0.2)

    graph_iter = 0
    def train_loss_func(x, y):
        x, y = x.cuda(), y.cuda()
        reg = 0.

        if args.use_augment_net:
            # old_x = x
            x = augment_net(x, class_label=y)
            '''num_sample = 10
            xs =torch.zeros(num_sample, x.shape[0], x.shape[1], x.shape[2], x.shape[3]).cuda()
            for i in range(num_sample):
                xs[i] = augment_net(x, class_label=y)
            xs_diffs = (torch.mean(xs, dim=0) - old_x) ** 2
            diff_loss = torch.mean(xs_diffs)
            entrop_loss = -torch.mean(torch.std(xs, dim=0) ** 2)
            reg = 10 * diff_loss + entrop_loss'''

        pred = model(x)
        xentropy_loss = F.cross_entropy(pred, y, reduction='none')

        if args.use_reweighting_net:
            loss_weights = reweighting_net(x)  # TODO: Or reweighting_net(augment_x) ??
            loss_weights = loss_weights.squeeze()
            # Squash the raw weights into (0.1, 2.1) so no example is fully ignored
            loss_weights = torch.sigmoid(loss_weights / 10.0) * 2.0 + 0.1
            # loss_weights = (loss_weights - torch.mean(loss_weights)) / torch.std(loss_weights)
            # loss_weights = F.softmax(loss_weights)
            # loss_weights = loss_weights * args.batch_size
            # TODO: Want loss_weight vs x_entropy_loss

            nonlocal graph_iter
            if graph_iter % 100 == 0:
                import matplotlib.pyplot as plt
                np_loss = xentropy_loss.data.cpu().numpy()
                np_weight = loss_weights.data.cpu().numpy()
                for i in range(10):
                    class_indices = (y == i).cpu().numpy()
                    class_indices = [idx for idx, is_class in enumerate(class_indices) if is_class]
                    plt.scatter(np_loss[class_indices], np_weight[class_indices], alpha=0.5, label=str(i))
                # plt.scatter((xentropy_loss*loss_weights).data.cpu().numpy(), loss_weights.data.cpu().numpy(), alpha=0.5, label='weighted')
                # print(np_loss)
                plt.ylim([np.min(np_weight) / 2.0, np.max(np_weight) * 2.0])
                plt.xlim([np.min(np_loss) / 2.0, np.max(np_loss) * 2.0])
                plt.yscale('log')
                plt.xscale('log')
                plt.axhline(1.0, c='k')
                plt.ylabel("loss_weights")
                plt.xlabel("xentropy_loss")
                plt.legend()
                plt.savefig("images/aaaa_lossWeightvsEntropy.pdf")
                plt.clf()

            xentropy_loss = xentropy_loss * loss_weights
        graph_iter += 1

        xentropy_loss = xentropy_loss.mean() + reg
        return xentropy_loss, pred

    use_reg = args.use_augment_net
    reg_anneal_epoch = 0
    stop_reg_epoch = 200
    if args.reg_weight == 0:
        use_reg = False

    def val_loss_func(x, y):
        x, y = x.cuda(), y.cuda()
        pred = model(x)
        xentropy_loss = F.cross_entropy(pred, y)

        reg = 0
        if args.use_augment_net:
            if use_reg:
                num_sample = 10
                xs = torch.zeros(num_sample, x.shape[0], x.shape[1], x.shape[2], x.shape[3]).cuda()
                for i in range(num_sample):
                    xs[i] = augment_net(x, class_label=y)
                xs_diffs = (torch.abs(torch.mean(xs, dim=0) - x))
                diff_loss = torch.mean(xs_diffs)
                stds = torch.std(xs, dim=0)
                entrop_loss = -torch.mean(stds)
                # TODO : Remember to add direct grad back in to hyper_step
                reg = args.reg_weight * (diff_loss + entrop_loss)
            else:
                reg = 0

        # reg *= (args.num_finetune_epochs - reg_anneal_epoch) / (args.num_finetune_epochs + 2)
        if reg_anneal_epoch >= stop_reg_epoch:
            reg *= 0
        return xentropy_loss + reg

    def test(loader, do_test_augment=True, num_augment=10):
        model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
        correct, total = 0., 0.
        losses = []
        for images, labels in loader:
            images, labels = images.cuda(), labels.cuda()

            with torch.no_grad():
                pred = model(images)
                if do_test_augment:
                    if args.use_augment_net and args.num_neumann_terms >= 0:
                        shape_0, shape_1 = pred.shape[0], pred.shape[1]
                        pred = pred.view(1, shape_0, shape_1)  # Batch size, num_classes
                        for _ in range(num_augment):
                            pred = torch.cat((pred, model(augment_net(images)).view(1, shape_0, shape_1)))
                        pred = torch.mean(pred, dim=0)
                xentropy_loss = F.cross_entropy(pred, labels)
                losses.append(xentropy_loss.item())

            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels).sum().item()

        avg_loss = float(np.mean(losses))
        acc = correct / total
        model.train()
        return avg_loss, acc
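    # Note: with do_test_augment enabled, the reported loss/accuracy average the clean
    # prediction with num_augment forward passes through the learned augmentation
    # network, so they reflect the augmented ensemble rather than a single clean pass.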

    # print(f"Initial Val Loss: {test(val_loader)}")
    # print(f"Initial Test Loss: {test(test_loader)}")

    init_time = time.time()
    val_loss, val_acc = test(val_loader)
    test_loss, test_acc = test(test_loader)
    print(f"Initial Val Loss: {val_loss, val_acc}")
    print(f"Initial Test Loss: {test_loss, test_acc}")
    iteration = 0
    for epoch in range(0, args.num_finetune_epochs):
        reg_anneal_epoch = epoch
        xentropy_loss_avg = 0.
        total_val_loss, val_loss = 0., 0.
        correct = 0.
        total = 0.
        weight_norm, grad_norm = .0, .0

        progress_bar = tqdm(train_loader)
        num_tune_hyper = 45000 // 5000  # train/val size ratio: one hyper step every 9 train batches
        hyper_num = 0
        for i, (images, labels) in enumerate(progress_bar):
            progress_bar.set_description('Finetune Epoch ' + str(epoch))

            images, labels = images.cuda(), labels.cuda()
            # pred = model(images)
            xentropy_loss, pred = train_loss_func(images, labels)  # F.cross_entropy(pred, labels)
            xentropy_loss_avg += xentropy_loss.item()

            current_index = 0
            for p in model.parameters():
                p_num_params = np.prod(p.shape)
                if p.grad is not None:
                    p.grad = p.grad * 0
                current_index += p_num_params
            # optimizer.zero_grad()
            train_grad = grad(xentropy_loss, model.parameters(), create_graph=True)  #

            if args.num_neumann_terms >= 0:  # if this is less than 0, then don't do hyper_steps
                if i % num_tune_hyper == 0:
                    cur_lr = 1.0
                    for param_group in optimizer.param_groups:
                        cur_lr = param_group['lr']
                        break
                    val_loss, grad_norm = hyper_step(get_hyper_train, model, val_loss_func, val_loader,
                                                     train_grad, cur_lr, use_reg, args)
                    hyper_optimizer.step()

                    weight_norm = get_hyper_train_flat().norm()
                    total_val_loss += val_loss.item()
                    hyper_num += 1

            # Replace the original gradient for the elementary optimizer step.
            current_index = 0
            flat_train_grad = gather_flat_grad(train_grad)
            for p in model.parameters():
                p_num_params = np.prod(p.shape)
                # if p.grad is not None:
                p.grad = flat_train_grad[current_index: current_index + p_num_params].view(p.shape)
                current_index += p_num_params
            optimizer.step()

            iteration += 1

            # Calculate running average of accuracy
            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels.data).sum().item()
            accuracy = correct / total

            progress_bar.set_postfix(
                train='%.4f' % (xentropy_loss_avg / (i + 1)),
                val='%.4f' % (total_val_loss / max(hyper_num, 1)),
                acc='%.4f' % accuracy,
                weight='%.3f' % weight_norm,
                update='%.3f' % grad_norm
            )
            if i % (num_tune_hyper ** 2) == 0 and args.use_augment_net:
                from train_augment_net_graph import save_images
                if args.do_diagnostic:
                    save_images(images, labels, augment_net, args)
                saver(epoch, model, optimizer, augment_net, reweighting_net, hyper_optimizer, args.save_loc)
                val_loss, val_acc = test(val_loader)
                csv_logger.writerow({'epoch': str(epoch),
                                     'train_loss': str(xentropy_loss_avg / (i + 1)), 'train_acc': str(accuracy),
                                     'val_loss': str(val_loss), 'val_acc': str(val_acc),
                                     'test_loss': str(test_loss), 'test_acc': str(test_acc),
                                     'run_time': time.time() - init_time,
                                     'iteration': iteration})

        val_loss, val_acc = test(val_loader)
        test_loss, test_acc = test(test_loader)
        tqdm.write('val loss: {:6.4f} | val acc: {:6.4f} | test loss: {:6.4f} | test_acc: {:6.4f}'.format(
            val_loss, val_acc, test_loss, test_acc))

        scheduler.step(epoch)  # , hyper_scheduler.step(epoch)
        csv_logger.writerow({'epoch': str(epoch),
                             'train_loss': str(xentropy_loss_avg / (i + 1)), 'train_acc': str(accuracy),
                             'val_loss': str(val_loss), 'val_acc': str(val_acc),
                             'test_loss': str(test_loss), 'test_acc': str(test_acc),
                             'run_time': time.time() - init_time, 'iteration': iteration})
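# Sketch (assumption): gather_flat_grad is used above to flatten a tuple of
# per-parameter gradients into a single vector before copying it back into
# model.parameters(); a minimal version could look like the helper below, though
# the repo's actual implementation may differ.
def _gather_flat_grad_sketch(grads):
    import torch
    return torch.cat([g.contiguous().view(-1) for g in grads])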