Example #1
def load_mod_dl(args):
    """

    :param args:
    :return:
    """
    if args.dataset == 'cifar10':
        imsize, in_channel, num_classes = 32, 3, 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation,
            subset=[args.train_size, args.val_size, args.test_size])
    elif args.dataset == 'cifar100':
        imsize, in_channel, num_classes = 32, 3, 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation,
            subset=[args.train_size, args.val_size, args.test_size])
    elif args.dataset == 'mnist':
        imsize, in_channel, num_classes = 28, 1, 10
        num_train = 50000
        train_loader, val_loader, test_loader = data_loaders.load_mnist(
            args.batch_size,
            subset=[args.train_size, args.val_size, args.test_size],
            num_train=num_train,
            only_split_train=False)

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'cbr':
        cnn = CBRStudent(in_channel, num_classes)

    # This essentially does no mixup.
    mixup_mat = -100 * torch.ones([num_classes, num_classes]).cuda()
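    # (Presumably the grid is squashed downstream, e.g. by a sigmoid, so a large negative entry
    # saturates the mixing coefficient; this is an assumption about the calling code, which is not shown here.)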

    checkpoint = None
    if args.load_checkpoint:
        checkpoint = torch.load(args.load_checkpoint)
        mixup_mat = checkpoint['mixup_grid']
        print(f"loaded mixupmat from {args.load_checkpoint}")

        if args.rand_mixup:
            # Randomise mixup grid
            rng = np.random.RandomState(args.seed)
            mixup_mat = rng.uniform(
                0.5, 1.0, (num_classes, num_classes)).astype(np.float32)
            print("Randomised the mixup mat")
        mixup_mat = torch.from_numpy(
            mixup_mat.reshape(num_classes, num_classes)).cuda()

    model = cnn.cuda()
    model.train()

    return model, mixup_mat, train_loader, val_loader, test_loader, checkpoint
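A minimal, hypothetical call site for the function above (the argparse.Namespace fields mirror the attributes it reads; data_loaders, ResNet18, and CBRStudent are assumed to be importable from the surrounding project, and a CUDA device is required):

import argparse

args = argparse.Namespace(
    dataset='cifar10', model='resnet18', batch_size=128,
    data_augmentation=True, train_size=45000, val_size=5000, test_size=10000,
    load_checkpoint='', rand_mixup=False, seed=0)
model, mixup_mat, train_loader, val_loader, test_loader, ckpt = load_mod_dl(args)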
Example #2
def load_mod_dl(args):
    """
    :param args:
    :return:
    """
    if args.dataset == 'cifar10':
        imsize, in_channel, num_classes = 32, 3, 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(args.batch_size, val_split=True,
                                                                          augmentation=args.data_augmentation,
                                                                          subset=[args.train_size, args.val_size,
                                                                                  args.test_size])
    elif args.dataset == 'cifar100':
        imsize, in_channel, num_classes = 32, 3, 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(args.batch_size, val_split=True,
                                                                           augmentation=args.data_augmentation,
                                                                           subset=[args.train_size, args.val_size,
                                                                                   args.test_size])
    elif args.dataset == 'mnist':
        imsize, in_channel, num_classes = 28, 1, 10
        num_train = 50000
        train_loader, val_loader, test_loader = data_loaders.load_mnist(args.batch_size,
                                                           subset=[args.train_size, args.val_size, args.test_size],
                                                           num_train=num_train, only_split_train=False)


    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes, num_channels=in_channel)
    elif args.model == 'cbr':
        cnn = CBRStudent(in_channel, num_classes)
        
    mixup_mat = -1 * torch.ones([num_classes, num_classes]).cuda()
    mixup_mat.requires_grad = True

    checkpoint = None
    if args.load_baseline_checkpoint:
        checkpoint = torch.load(args.load_baseline_checkpoint)
        cnn.load_state_dict(checkpoint['model_state_dict'])

    model = cnn.cuda()
    model.train()
    return model, mixup_mat, train_loader, val_loader, test_loader, checkpoint
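Because mixup_mat is created with requires_grad=True, it can be updated by its own hyperparameter optimizer, separate from the model's optimizer. A minimal sketch of that idea (the Adam optimizer, its learning rate, and the CPU-only tensor are illustrative assumptions, not taken from this listing):

import torch

num_classes = 10
mixup_mat = torch.full((num_classes, num_classes), -1.0, requires_grad=True)
hyper_optimizer = torch.optim.Adam([mixup_mat], lr=1e-3)
# hyper_loss = some validation-based objective that depends on mixup_mat
# hyper_loss.backward(); hyper_optimizer.step()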
Example #3
def load_baseline_model(args):
    """

    :param args:
    :return:
    """
    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(args.batch_size, val_split=True,
                                                                          augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(args.batch_size, val_split=True,
                                                                           augmentation=args.data_augmentation)
    elif args.dataset == 'mnist':
        args.datasize, args.valsize, args.testsize = 100, 100, 100
        num_train = args.datasize
        if args.datasize == -1:
            num_train = 50000

        from data_loaders import load_mnist
        train_loader, val_loader, test_loader = load_mnist(args.batch_size,
                                                           subset=[args.datasize, args.valsize, args.testsize],
                                                           num_train=num_train)

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'wideresnet':
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10, dropRate=0.3)

    checkpoint = None
    if args.load_baseline_checkpoint:
        checkpoint = torch.load(args.load_baseline_checkpoint)
        cnn.load_state_dict(checkpoint['model_state_dict'])

    model = cnn.cuda()
    model.train()
    return model, train_loader, val_loader, test_loader, checkpoint
Example #4
train_transform.transforms.append(transforms.RandomHorizontalFlip())
train_transform.transforms.append(transforms.ToTensor())
train_transform.transforms.append(normalize)
if args.cutout:
    train_transform.transforms.append(
        Cutout(n_holes=args.n_holes, length=args.length))

test_transform = transforms.Compose([transforms.ToTensor(), normalize])

if args.dataset == 'cifar10':
    num_classes = 10
    train_loader, val_loader, test_loader = data_loaders.load_cifar10(
        args.batch_size, val_split=True, augmentation=args.data_augmentation)
elif args.dataset == 'cifar100':
    num_classes = 100
    train_loader, val_loader, test_loader = data_loaders.load_cifar100(
        args.batch_size, val_split=True, augmentation=args.data_augmentation)

if args.model == 'resnet18':
    cnn = ResNet18(num_classes=num_classes)
elif args.model == 'wideresnet':
    cnn = WideResNet(depth=28,
                     num_classes=num_classes,
                     widen_factor=10,
                     dropRate=0.3)

cnn = cnn.cuda()
criterion = nn.CrossEntropyLoss().cuda()

if args.optimizer == 'sgdm':
    cnn_optimizer = torch.optim.SGD(cnn.parameters(),
                                    lr=args.lr,
Example #5
def cnn_val_loss(config={}, reporter=None, callback=None, return_all=False):
    print("Starting cnn_val_loss...")

    ###############################################################################
    # Arguments
    ###############################################################################
    dataset_options = ['cifar10', 'cifar100', 'fashion']

    ## Tuning parameters: all of the dropouts
    parser = argparse.ArgumentParser(description='CNN')
    parser.add_argument('--dataset',
                        default='cifar10',
                        choices=dataset_options,
                        help='Choose a dataset (cifar10, cifar100, fashion)')
    parser.add_argument(
        '--model',
        default='resnet32',
        choices=['resnet32', 'wideresnet', 'simpleconvnet'],
        help='Choose a model (resnet32, wideresnet, simpleconvnet)')

    #### Optimization hyperparameters
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='Input batch size for training (default: 128)')
    parser.add_argument('--epochs',
                        type=int,
                        default=int(config['epochs']),
                        help='Number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=float(config['lr']),
                        help='Learning rate')
    parser.add_argument('--momentum',
                        type=float,
                        default=float(config['momentum']),
                        help='Nesterov momentum')
    parser.add_argument('--lr_decay',
                        type=float,
                        default=float(config['lr_decay']),
                        help='Factor by which to multiply the learning rate.')

    # parser.add_argument('--weight_decay', type=float, default=float(config['weight_decay']),
    #                     help='Amount of weight decay to use.')
    # parser.add_argument('--dropout', type=float, default=config['dropout'] if 'dropout' in config else 0.0,
    #                     help='Amount of dropout for wideresnet')
    # parser.add_argument('--dropout1', type=float, default=config['dropout1'] if 'dropout1' in config else -1,
    #                     help='Amount of dropout for wideresnet')
    # parser.add_argument('--dropout2', type=float, default=config['dropout2'] if 'dropout2' in config else -1,
    #                     help='Amount of dropout for wideresnet')
    # parser.add_argument('--dropout3', type=float, default=config['dropout3'] if 'dropout3' in config else -1,
    #                     help='Amount of dropout for wideresnet')
    parser.add_argument('--dropout_type',
                        type=str,
                        default=config['dropout_type'],
                        help='Type of dropout (bernoulli or gaussian)')

    # Data augmentation hyperparameters
    parser.add_argument(
        '--inscale',
        type=float,
        default=0 if 'inscale' not in config else config['inscale'],
        help='defines input scaling factor')
    parser.add_argument('--hue',
                        type=float,
                        default=0. if 'hue' not in config else config['hue'],
                        help='hue jitter rate')
    parser.add_argument(
        '--brightness',
        type=float,
        default=0. if 'brightness' not in config else config['brightness'],
        help='brightness jitter rate')
    parser.add_argument(
        '--saturation',
        type=float,
        default=0. if 'saturation' not in config else config['saturation'],
        help='saturation jitter rate')
    parser.add_argument(
        '--contrast',
        type=float,
        default=0. if 'contrast' not in config else config['contrast'],
        help='contrast jitter rate')

    # Weight decay and dropout hyperparameters for each layer
    parser.add_argument(
        '--weight_decays',
        type=str,
        default='0.0',
        help=
        'Amount of weight decay to use for each layer, represented as a comma-separated string of floats.'
    )
    parser.add_argument(
        '--dropouts',
        type=str,
        default='0.0',
        help=
        'Dropout rates for each layer, represented as a comma-separated string of floats'
    )

    parser.add_argument(
        '--nonmono',
        '-nonm',
        type=int,
        default=60,
        help='how many previous epochs to consider for nonmonotonic criterion')
    parser.add_argument(
        '--patience',
        type=int,
        default=75,
        help=
        'How long to wait for the val loss to improve before early stopping.')

    parser.add_argument(
        '--data_augmentation',
        action='store_true',
        default=config['data_augmentation'],
        help='Augment data by cropping and horizontal flipping')

    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='how many steps before logging stats from training set')
    parser.add_argument(
        '--valid_log_interval',
        type=int,
        default=50,
        help='how many steps before logging stats from the validation set')
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--save',
                        action='store_true',
                        default=False,
                        help='whether to save current run')
    parser.add_argument('--seed',
                        type=int,
                        default=11,
                        help='random seed (default: 11)')
    parser.add_argument(
        '--save_dir',
        default=config['save_dir'],
        help=
        'subdirectory of logdir/savedir to save in (default changes to date/time)'
    )

    args, unknown = parser.parse_known_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")
    cudnn.benchmark = True  # Should make training go faster for large models

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print(args)
    sys.stdout.flush()

    # args.dropout1 = args.dropout1 if args.dropout1 != -1 else args.dropout
    # args.dropout2 = args.dropout2 if args.dropout2 != -1 else args.dropout
    # args.dropout3 = args.dropout3 if args.dropout3 != -1 else args.dropout

    ###############################################################################
    # Saving
    ###############################################################################
    timestamp = '{:%Y-%m-%d}'.format(datetime.datetime.now())
    random_hash = random.getrandbits(16)
    exp_name = '{}-dset:{}-model:{}-seed:{}-hash:{}'.format(
        timestamp, args.dataset, args.model,
        args.seed if args.seed else 'None', random_hash)

    dropout_rates = [float(value) for value in args.dropouts.split(',')]
    weight_decays = [float(value) for value in args.weight_decays.split(',')]

    # Create log folder
    BASE_SAVE_DIR = 'experiments'
    save_dir = os.path.join(BASE_SAVE_DIR, args.save_dir, exp_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Check whether the result.csv file exists already
    if os.path.exists(os.path.join(save_dir, 'result.csv')):
        if not getattr(args, 'overwrite', False):  # --overwrite is not defined by this parser, so default to not overwriting
            print(
                'The result file {} exists! Run with --overwrite to overwrite this experiment.'
                .format(os.path.join(save_dir, 'result.csv')))
            sys.exit(0)

    # Save command-line arguments
    with open(os.path.join(save_dir, 'args.yaml'), 'w') as f:
        yaml.dump(vars(args), f)

    epoch_csv_logger = CSVLogger(
        fieldnames=['epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc'],
        filename=os.path.join(save_dir, 'epoch_log.csv'))

    ###############################################################################
    # Data Loading/Model/Optimizer
    ###############################################################################

    if args.dataset == 'cifar10':
        train_loader, valid_loader, test_loader = data_loaders.load_cifar10(
            args,
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_loader, valid_loader, test_loader = data_loaders.load_cifar100(
            args,
            args.batch_size,
            val_split=True,
            augmentation=args.data_augmentation)
        num_classes = 100
    elif args.dataset == 'fashion':
        train_loader, valid_loader, test_loader = data_loaders.load_fashion_mnist(
            args.batch_size, val_split=True)
        num_classes = 10

    if args.model == 'resnet32':
        cnn = resnet_cifar.resnet32(dropRates=dropout_rates)
    elif args.model == 'wideresnet':
        cnn = wide_resnet.WideResNet(depth=16,
                                     num_classes=num_classes,
                                     widen_factor=8,
                                     dropRates=dropout_rates,
                                     dropType=args.dropout_type)
        # cnn = wide_resnet.WideResNet(depth=28, num_classes=num_classes, widen_factor=10, dropRate=args.dropout)
    elif args.model == 'simpleconvnet':
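        # NOTE: args.dropout1/2/3 are only defined by the commented-out arguments above, so this branch assumes they are set elsewhere.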
        cnn = models.SimpleConvNet(dropType=args.dropout_type,
                                   conv_drop1=args.dropout1,
                                   conv_drop2=args.dropout2,
                                   fc_drop=args.dropout3)

    def optim_parameters(model):
        module_list = [
            m for m in model.modules()
            if type(m) == nn.Linear or type(m) == nn.Conv2d
        ]
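        # NOTE: this hard-codes 1e-4 for every layer, shadowing the weight_decays parsed from --weight_decays above.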
        weight_decays = [1e-4] * len(module_list)
        return [{
            'params': layer.parameters(),
            'weight_decay': wdecay
        } for (layer, wdecay) in zip(module_list, weight_decays)]

    cnn = cnn.to(device)
    criterion = nn.CrossEntropyLoss()
    # cnn_optimizer = torch.optim.SGD(cnn.parameters(),
    #                                 lr=args.lr,
    #                                 momentum=args.momentum,
    #                                 nesterov=True,
    #                                 weight_decay=args.weight_decay)
    cnn_optimizer = torch.optim.SGD(optim_parameters(cnn),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    nesterov=True)

    ###############################################################################
    # Training/Evaluation
    ###############################################################################
    def evaluate(loader):
        """Returns the loss and accuracy on the entire validation/test set."""
        cnn.eval()
        correct = total = loss = 0.
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                pred = cnn(images)
                loss += F.cross_entropy(pred, labels, reduction='sum').item()
                hard_pred = torch.max(pred, 1)[1]
                total += labels.size(0)
                correct += (hard_pred == labels).sum().item()

        accuracy = correct / total
        mean_loss = loss / total
        cnn.train()
        return mean_loss, accuracy

    epoch = 1
    global_step = 0
    patience_elapsed = 0
    stored_loss = 1e8
    best_val_loss = []
    start_time = time.time()

    # This is based on the schedule used for WideResNets. The gamma (decay factor) can also be 0.2 (= 5x decay)
    # Right now we're not using the scheduler because we use nonmonotonic lr decay (based on validation performance)
    # scheduler = MultiStepLR(cnn_optimizer, milestones=[60,120,160], gamma=args.lr_decay)

    while epoch < args.epochs + 1 and patience_elapsed < args.patience:

        running_xentropy = correct = total = 0.

        progress_bar = tqdm(train_loader)
        for i, (images, labels) in enumerate(progress_bar):
            progress_bar.set_description('Epoch ' + str(epoch))
            images, labels = images.to(device), labels.to(device)

            if args.inscale > 0:
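                # Draw a per-image scale factor uniformly from [1/(1+inscale), 1+inscale).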
                noise = torch.rand(images.size(0), device=device)
                scaled_noise = (
                    (1 + args.inscale) -
                    (1 / (1 + args.inscale))) * noise + (1 /
                                                         (1 + args.inscale))
                images = images * scaled_noise[:, None, None, None]

            # images = F.dropout(images, p=args.indropout, training=True)  # TODO: Incorporate input dropout
            cnn.zero_grad()
            pred = cnn(images)

            xentropy_loss = criterion(pred, labels)
            xentropy_loss.backward()
            cnn_optimizer.step()

            running_xentropy += xentropy_loss.item()

            # Calculate running average of accuracy
            _, hard_pred = torch.max(pred, 1)
            total += labels.size(0)
            correct += (hard_pred == labels).sum().item()
            accuracy = correct / float(total)

            global_step += 1
            progress_bar.set_postfix(
                xentropy='%.3f' % (running_xentropy / (i + 1)),
                acc='%.3f' % accuracy,
                lr='%.3e' % cnn_optimizer.param_groups[0]['lr'])

        val_loss, val_acc = evaluate(valid_loader)
        print('Val loss: {:6.4f} | Val acc: {:6.4f}'.format(val_loss, val_acc))
        sys.stdout.flush()
        stats = {
            'global_step': global_step,
            'time': time.time() - start_time,
            'loss': val_loss,
            'acc': val_acc
        }
        # logger.write('valid', stats)
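        # Nonmonotonic criterion: decay the lr when the current val loss is worse than the best loss seen more than args.nonmono epochs ago.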

        if (len(best_val_loss) > args.nonmono
                and val_loss > min(best_val_loss[:-args.nonmono])):
            cnn_optimizer.param_groups[0]['lr'] *= args.lr_decay
            print('Decaying the learning rate to {}'.format(
                cnn_optimizer.param_groups[0]['lr']))
            sys.stdout.flush()

        if val_loss < stored_loss:
            with open(os.path.join(save_dir, 'best_checkpoint.pt'), 'wb') as f:
                torch.save(cnn.state_dict(), f)
            print('Saving model (new best validation)')
            sys.stdout.flush()
            stored_loss = val_loss
            patience_elapsed = 0
        else:
            patience_elapsed += 1

        best_val_loss.append(val_loss)

        # scheduler.step(epoch)

        avg_xentropy = running_xentropy / (i + 1)
        train_acc = correct / float(total)

        if callback is not None:
            callback(epoch, avg_xentropy, train_acc, val_loss, val_acc, config)

        if reporter is not None:
            reporter(timesteps_total=epoch, mean_loss=val_loss)

        # Another stopping criterion, based on the decayed learning rate
        if cnn_optimizer.param_groups[0]['lr'] < 1e-7:
            break

        epoch += 1

        epoch_row = {
            'epoch': str(epoch),
            'train_loss': avg_xentropy,
            'train_acc': str(train_acc),
            'val_loss': str(val_loss),
            'val_acc': str(val_acc)
        }
        epoch_csv_logger.writerow(epoch_row)

    # Load best model and run on test
    with open(os.path.join(save_dir, 'best_checkpoint.pt'), 'rb') as f:
        cnn.load_state_dict(torch.load(f))

    train_loss = avg_xentropy
    train_acc = correct / float(total)

    # Run on val and test data.
    val_loss, val_acc = evaluate(valid_loader)
    test_loss, test_acc = evaluate(test_loader)

    print('=' * 89)
    print(
        '| End of training | trn loss: {:8.5f} | trn acc {:8.5f} | val loss {:8.5f} | val acc {:8.5f} | test loss {:8.5f} | test acc {:8.5f}'
        .format(train_loss, train_acc, val_loss, val_acc, test_loss, test_acc))
    print('=' * 89)
    sys.stdout.flush()

    # Save the final val and test performance to a results CSV file
    with open(os.path.join(save_dir, 'result_{}.csv'.format(time.time())),
              'w') as result_file:
        result_writer = csv.DictWriter(result_file,
                                       fieldnames=[
                                           'train_loss', 'train_acc',
                                           'val_loss', 'val_acc', 'test_loss',
                                           'test_acc'
                                       ])
        result_writer.writeheader()
        result_writer.writerow({
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'test_loss': test_loss,
            'test_acc': test_acc
        })
        result_file.flush()

    if return_all:
        print("RETURNING ", train_loss, train_acc, val_loss, val_acc,
              test_loss, test_acc)
        sys.stdout.flush()
        return train_loss, train_acc, val_loss, val_acc, test_loss, test_acc
    else:
        print("RETURNING ", stored_loss)
        sys.stdout.flush()
        return stored_loss
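A hedged call sketch for cnn_val_loss: the config dict must supply at least the keys indexed above (epochs, lr, momentum, lr_decay, dropout_type, data_augmentation, save_dir); the values below are purely illustrative:

config = {
    'epochs': 200,
    'lr': 0.1,
    'momentum': 0.9,
    'lr_decay': 0.2,
    'dropout_type': 'bernoulli',
    'data_augmentation': True,
    'save_dir': 'tune_run',
}
best_val = cnn_val_loss(config=config)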
def experiment():
    parser = argparse.ArgumentParser(description='CNN Hyperparameter Fine-tuning')
    parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100'],
                        help='Choose a dataset')
    parser.add_argument('--model', default='resnet18', choices=['resnet18', 'wideresnet'],
                        help='Choose a model')
    parser.add_argument('--num_finetune_epochs', type=int, default=200,
                        help='Number of fine-tuning epochs')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='Learning rate')
    parser.add_argument('--optimizer', type=str, default='sgdm',
                        help='Choose an optimizer')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Mini-batch size')
    parser.add_argument('--data_augmentation', action='store_true', default=True,
                        help='Whether to use data augmentation')
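    # NOTE: with action='store_true' and default=True, this flag is effectively always on.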
    parser.add_argument('--wdecay', type=float, default=5e-4,
                        help='Amount of weight decay')
    parser.add_argument('--load_checkpoint', type=str,
                        help='Path to pre-trained checkpoint to load and finetune')
    parser.add_argument('--save_dir', type=str, default='finetuned_checkpoints',
                        help='Save directory for the fine-tuned checkpoint')
    args = parser.parse_args()
    args.load_checkpoint = '/h/lorraine/PycharmProjects/CG_IFT_test/baseline_checkpoints/cifar10_resnet18_sgdm_lr0.1_wd0.0005_aug0.pt'
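    # This hard-coded path overrides whatever was passed via --load_checkpoint.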

    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(args.batch_size, val_split=True,
                                                                          augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(args.batch_size, val_split=True,
                                                                           augmentation=args.data_augmentation)

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'wideresnet':
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10, dropRate=0.3)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    test_id = '{}_{}_{}_lr{}_wd{}_aug{}'.format(args.dataset, args.model, args.optimizer, args.lr, args.wdecay,
                                                int(args.data_augmentation))
    filename = os.path.join(args.save_dir, test_id + '.csv')
    csv_logger = CSVLogger(
        fieldnames=['epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc', 'test_loss', 'test_acc'],
        filename=filename)

    checkpoint = torch.load(args.load_checkpoint)
    init_epoch = checkpoint['epoch']
    cnn.load_state_dict(checkpoint['model_state_dict'])
    model = cnn.cuda()
    model.train()

    args.hyper_train = 'augment'  # 'all_weight'  # 'weight'

    def init_hyper_train(model):
        """

        :return:
        """
        init_hyper = None
        if args.hyper_train == 'weight':
            init_hyper = np.sqrt(args.wdecay)
            model.weight_decay = Variable(torch.FloatTensor([init_hyper]).cuda(), requires_grad=True)
            model.weight_decay = model.weight_decay.cuda()
        elif args.hyper_train == 'all_weight':
            num_p = sum(p.numel() for p in model.parameters())
            weights = np.ones(num_p) * np.sqrt(args.wdecay)
            model.weight_decay = Variable(torch.FloatTensor(weights).cuda(), requires_grad=True)
            model.weight_decay = model.weight_decay.cuda()
        model = model.cuda()
        return init_hyper

    if args.hyper_train == 'augment':  # Don't build this inside the function above, else the scope is wrong
        augment_net = UNet(in_channels=3,
                           n_classes=3,
                           depth=5,
                           wf=6,
                           padding=True,
                           batch_norm=False,
                           up_mode='upconv')  # TODO(PV): Initialize UNet properly
        augment_net = augment_net.cuda()

    def get_hyper_train():
        """

        :return:
        """
        if args.hyper_train == 'weight' or args.hyper_train == 'all_weight':
            return [model.weight_decay]
        if args.hyper_train == 'augment':
            return augment_net.parameters()

    def get_hyper_train_flat():
        return torch.cat([p.view(-1) for p in get_hyper_train()])

    # TODO: Check this size

    init_hyper_train(model)

    if args.hyper_train == 'all_weight':
        wdecay = 0.0
    else:
        wdecay = args.wdecay
    optimizer = optim.SGD(model.parameters(), lr=args.lr * 0.2 * 0.2, momentum=0.9, nesterov=True,
                          weight_decay=wdecay)  # args.wdecay)
    # print(checkpoint['optimizer_state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler = MultiStepLR(optimizer, milestones=[60, 120], gamma=0.2)  # [60, 120, 160]
    hyper_optimizer = torch.optim.Adam(get_hyper_train(), lr=1e-3)  # try 0.1 as lr

    # Set random regularization hyperparameters
    # data_augmentation_hparams = {}  # Random values for hue, saturation, brightness, contrast, rotation, etc.
    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(args.batch_size, val_split=True,
                                                                          augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(args.batch_size, val_split=True,
                                                                           augmentation=args.data_augmentation)

    def test(loader):
        model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
        correct = 0.
        total = 0.
        losses = []
        for images, labels in loader:
            images = images.cuda()
            labels = labels.cuda()

            with torch.no_grad():
                pred = model(images)

            xentropy_loss = F.cross_entropy(pred, labels)
            losses.append(xentropy_loss.item())

            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels).sum().item()

        avg_loss = float(np.mean(losses))
        acc = correct / total
        model.train()
        return avg_loss, acc

    def prepare_data(x, y):
        """

        :param x:
        :param y:
        :return:
        """
        x, y = x.cuda(), y.cuda()

        # x, y = Variable(x), Variable(y)
        return x, y

    def train_loss_func(x, y):
        """

        :param x:
        :param y:
        :return:
        """
        x, y = prepare_data(x, y)

        reg_loss = 0.0
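        # model.weight_decay stores sqrt(wdecay) (see init_hyper_train), so squaring it below keeps the penalty non-negative and differentiable in the hyperparameter.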
        if args.hyper_train == 'weight':
            pred = model(x)
            xentropy_loss = F.cross_entropy(pred, y)
            # print(f"weight_decay: {torch.exp(model.weight_decay).shape}")
            for p in model.parameters():
                # print(f"weight_decay: {torch.exp(model.weight_decay).shape}")
                # print(f"shape: {p.shape}")
                reg_loss = reg_loss + .5 * (model.weight_decay ** 2) * torch.sum(p ** 2)
                # print(f"reg_loss: {reg_loss}")
        elif args.hyper_train == 'all_weight':
            pred = model(x)
            xentropy_loss = F.cross_entropy(pred, y)
            count = 0
            for p in model.parameters():
                reg_loss = reg_loss + .5 * torch.sum(
                    (model.weight_decay[count: count + p.numel()] ** 2) * torch.flatten(p ** 2))
                count += p.numel()
        elif args.hyper_train == 'augment':
            augmented_x = augment_net(x)
            pred = model(augmented_x)
            xentropy_loss = F.cross_entropy(pred, y)
        return xentropy_loss + reg_loss, pred

    def val_loss_func(x, y):
        """

        :param x:
        :param y:
        :return:
        """
        x, y = prepare_data(x, y)
        pred = model(x)
        xentropy_loss = F.cross_entropy(pred, y)
        return xentropy_loss

    for epoch in range(init_epoch, init_epoch + args.num_finetune_epochs):
        xentropy_loss_avg = 0.
        total_val_loss = 0.
        correct = 0.
        total = 0.

        progress_bar = tqdm(train_loader)
        for i, (images, labels) in enumerate(progress_bar):
            progress_bar.set_description('Finetune Epoch ' + str(epoch))

            # TODO: Take a hyperparameter step here
            optimizer.zero_grad(), hyper_optimizer.zero_grad()
            val_loss, weight_norm, grad_norm = hyper_step(
                1, 1, get_hyper_train, get_hyper_train_flat, model,
                val_loss_func, val_loader, train_loss_func, train_loader,
                hyper_optimizer)
            # del val_loss
            # print(f"hyper: {get_hyper_train()}")

            images, labels = images.cuda(), labels.cuda()
            # pred = model(images)
            # xentropy_loss = F.cross_entropy(pred, labels)
            xentropy_loss, pred = train_loss_func(images, labels)

            optimizer.zero_grad(), hyper_optimizer.zero_grad()
            xentropy_loss.backward()
            optimizer.step()

            xentropy_loss_avg += xentropy_loss.item()

            # Calculate running average of accuracy
            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels.data).sum().item()
            accuracy = correct / total

            progress_bar.set_postfix(
                train='%.5f' % (xentropy_loss_avg / (i + 1)),
                val='%.4f' % (total_val_loss / (i + 1)),
                acc='%.4f' % accuracy,
                weight='%.2f' % weight_norm,
                update='%.3f' % grad_norm)

        val_loss, val_acc = test(val_loader)
        test_loss, test_acc = test(test_loader)
        tqdm.write('val loss: {:6.4f} | val acc: {:6.4f} | test loss: {:6.4f} | test_acc: {:6.4f}'.format(
            val_loss, val_acc, test_loss, test_acc))

        scheduler.step(epoch)

        row = {'epoch': str(epoch),
               'train_loss': str(xentropy_loss_avg / (i + 1)), 'train_acc': str(accuracy),
               'val_loss': str(val_loss), 'val_acc': str(val_acc),
               'test_loss': str(test_loss), 'test_acc': str(test_acc)}
        csv_logger.writerow(row)