Example #1
    # Assumed reconstruction of the truncated call: the code below implies a
    # WideResNet; depth, num_classes, and widen_factor are hypothetical names,
    # and the torch/nn/MultiStepLR imports come from the surrounding script.
    model = WideResNet(depth=depth,
                       num_classes=num_classes,
                       widen_factor=widen_factor,
                       dropRate=drop_rate)

    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                nesterov=True,
                                weight_decay=0.0005)

    scheduler = MultiStepLR(optimizer, milestones=LR_MILESTONES, gamma=gamma)
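    # MultiStepLR multiplies the learning rate by gamma at each milestone epoch.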

    try:
        checkpoint_fpath = 'cifar-10/cifar10_wideresnet79.pt'
        checkpoint = torch.load(checkpoint_fpath)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler = MultiStepLR(optimizer,
                                milestones=LR_MILESTONES,
                                gamma=gamma,  # keep the decay factor consistent with the fresh scheduler above
                                last_epoch=checkpoint['epoch'])
        begin = checkpoint['epoch']
        # print('test_acc :', checkpoint['test_acc'], 'train_acc :', checkpoint['train_acc'])
        # print('last_lr :', checkpoint['scheduler']['_last_lr'])
    except FileNotFoundError:
        # print('starting over..')
        begin = -1

    best_acc = 0
    for epoch in range(epochs):
        if epoch <= begin:
            continue  # fast-forward past epochs already covered by the checkpoint
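A minimal sketch of how this resume-aware loop could continue, assuming hypothetical train_one_epoch/evaluate helpers and reusing the checkpoint keys loaded above:

        train_acc = train_one_epoch(model, optimizer, criterion)  # hypothetical helper
        scheduler.step()
        test_acc = evaluate(model, criterion)  # hypothetical helper

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save({'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch,
                        'train_acc': train_acc,
                        'test_acc': test_acc},
                       checkpoint_fpath)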
Example #2
# Imports assumed for this example; project-local modules (WideResNet,
# conv_init, Config, ResNet18, UNet, CSVLogger, data_loaders, hyper_step)
# are expected to come from the surrounding repository.
import argparse
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from tqdm import tqdm


def main(args):
    np.random.seed(0)
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Constructing Model

    if args.resume != "":
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
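            # Restore the saved options, but keep the test_only/resume values
            # that were just passed on the command line.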
            test_only = args.test_only
            resume = args.resume
            args = checkpoint["opt"]
            args.test_only = test_only
            args.resume = resume
        else:
            checkpoint = None
            print("=> No checkpoint found at '{}'".format(args.resume))

    model = WideResNet(args.depth, args.widen_factor, args.dropout_rate,
                       args.num_classes)

    if torch.cuda.is_available():
        model.cuda()
        model = torch.nn.DataParallel(model, device_ids=args.gpu)

    if args.resume != "":
        model.load_state_dict(checkpoint["model"])
        args.start_epoch = checkpoint["epoch"] + 1
        print("=> Loaded successfully '{}' (epoch {})".format(
            args.resume, checkpoint["epoch"]))
        del checkpoint
        torch.cuda.empty_cache()
    else:
        model.apply(conv_init)

    # Loading Dataset

    if args.augment == "meanstd":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(Config.CIFAR10_mean, Config.CIFAR10_std),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(Config.CIFAR10_mean, Config.CIFAR10_std),
        ])
    elif args.augment == "zac":
        # To Do: ZCA whitening
        pass
    else:
        raise NotImplementedError

    print("| Preparing CIFAR-10 dataset...")
    sys.stdout.write("| ")
    trainset = CIFAR10(root="./data",
                       train=True,
                       download=True,
                       transform=transform_train)
    testset = CIFAR10(root="./data",
                      train=False,
                      download=False,
                      transform=transform_test)

    train_loader = DataLoader(trainset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=2)
    test_loader = DataLoader(testset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=2)

    # Test only

    if args.test_only:
        if args.resume != "":
            test(args, test_loader, model)
            sys.exit(0)
        else:
            raise RuntimeError("Test-only mode requires resuming from a checkpoint")

    train(args, train_loader, test_loader, model)
    test(args, test_loader, model)
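A minimal sketch of the matching save that train() would be expected to perform, mirroring the keys main() reads back on --resume (the helper name and path are assumptions):

def save_checkpoint(args, model, epoch, path="checkpoint.pt"):
    # Keys must match what main() reads: "model", "epoch", "opt".
    torch.save({"model": model.state_dict(),
                "epoch": epoch,
                "opt": args},
               path)
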
def experiment():
    parser = argparse.ArgumentParser(description='CNN Hyperparameter Fine-tuning')
    parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100'],
                        help='Choose a dataset')
    parser.add_argument('--model', default='resnet18', choices=['resnet18', 'wideresnet'],
                        help='Choose a model')
    parser.add_argument('--num_finetune_epochs', type=int, default=200,
                        help='Number of fine-tuning epochs')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='Learning rate')
    parser.add_argument('--optimizer', type=str, default='sgdm',
                        help='Choose an optimizer')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Mini-batch size')
    # NOTE: with action='store_true' and default=True this flag is always on;
    # it cannot be disabled from the command line as written.
    parser.add_argument('--data_augmentation', action='store_true', default=True,
                        help='Whether to use data augmentation')
    parser.add_argument('--wdecay', type=float, default=5e-4,
                        help='Amount of weight decay')
    parser.add_argument('--load_checkpoint', type=str,
                        help='Path to pre-trained checkpoint to load and finetune')
    parser.add_argument('--save_dir', type=str, default='finetuned_checkpoints',
                        help='Save directory for the fine-tuned checkpoint')
    args = parser.parse_args()
    # NOTE: this hardcoded path overrides whatever --load_checkpoint was passed on the CLI.
    args.load_checkpoint = '/h/lorraine/PycharmProjects/CG_IFT_test/baseline_checkpoints/cifar10_resnet18_sgdm_lr0.1_wd0.0005_aug0.pt'

    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(args.batch_size, val_split=True,
                                                                          augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(args.batch_size, val_split=True,
                                                                           augmentation=args.data_augmentation)

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'wideresnet':
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10, dropRate=0.3)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    test_id = '{}_{}_{}_lr{}_wd{}_aug{}'.format(args.dataset, args.model, args.optimizer, args.lr, args.wdecay,
                                                int(args.data_augmentation))
    filename = os.path.join(args.save_dir, test_id + '.csv')
    csv_logger = CSVLogger(
        fieldnames=['epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc', 'test_loss', 'test_acc'],
        filename=filename)

    checkpoint = torch.load(args.load_checkpoint)
    init_epoch = checkpoint['epoch']
    cnn.load_state_dict(checkpoint['model_state_dict'])
    model = cnn.cuda()
    model.train()

    args.hyper_train = 'augment'  # 'all_weight'  # 'weight'

    def init_hyper_train(model):
        """

        :return:
        """
        init_hyper = None
        if args.hyper_train == 'weight':
            # A single learnable scalar, stored as sqrt(wdecay).
            init_hyper = np.sqrt(args.wdecay)
            model.weight_decay = torch.tensor([init_hyper], dtype=torch.float,
                                              device='cuda', requires_grad=True)
        elif args.hyper_train == 'all_weight':
            # One learnable decay entry per model parameter.
            num_p = sum(p.numel() for p in model.parameters())
            weights = np.ones(num_p) * np.sqrt(args.wdecay)
            model.weight_decay = torch.tensor(weights, dtype=torch.float,
                                              device='cuda', requires_grad=True)
        model = model.cuda()
        return init_hyper

    if args.hyper_train == 'augment':  # Don't build this inside init_hyper_train, or the closure scope is wrong
        augment_net = UNet(in_channels=3,
                           n_classes=3,
                           depth=5,
                           wf=6,
                           padding=True,
                           batch_norm=False,
                           up_mode='upconv')  # TODO(PV): Initialize UNet properly
        augment_net = augment_net.cuda()

    def get_hyper_train():
        """

        :return:
        """
        if args.hyper_train == 'weight' or args.hyper_train == 'all_weight':
            return [model.weight_decay]
        if args.hyper_train == 'augment':
            return augment_net.parameters()

    def get_hyper_train_flat():
        return torch.cat([p.view(-1) for p in get_hyper_train()])

    # TODO: Check this size
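    # (Assumed sanity check) in 'all_weight' mode the flattened vector should
    # have one entry per model parameter:
    #   assert get_hyper_train_flat().numel() == sum(p.numel() for p in model.parameters())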

    init_hyper_train(model)

    if args.hyper_train == 'all_weight':
        wdecay = 0.0
    else:
        wdecay = args.wdecay
    optimizer = optim.SGD(model.parameters(), lr=args.lr * 0.2 * 0.2, momentum=0.9, nesterov=True,
                          weight_decay=wdecay)  # presumably the LR after both milestone decays (gamma=0.2 twice)
    # print(checkpoint['optimizer_state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler = MultiStepLR(optimizer, milestones=[60, 120], gamma=0.2)  # [60, 120, 160]
    hyper_optimizer = torch.optim.Adam(get_hyper_train(), lr=1e-3)  # try 0.1 as lr

    # Set random regularization hyperparameters
    # data_augmentation_hparams = {}  # Random values for hue, saturation, brightness, contrast, rotation, etc.

    def test(loader):
        model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
        correct = 0.
        total = 0.
        losses = []
        for images, labels in loader:
            images = images.cuda()
            labels = labels.cuda()

            with torch.no_grad():
                pred = model(images)

            xentropy_loss = F.cross_entropy(pred, labels)
            losses.append(xentropy_loss.item())

            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels).sum().item()

        avg_loss = float(np.mean(losses))
        acc = correct / total
        model.train()
        return avg_loss, acc

    def prepare_data(x, y):
        """

        :param x:
        :param y:
        :return:
        """
        x, y = x.cuda(), y.cuda()

        # x, y = Variable(x), Variable(y)
        return x, y

    def train_loss_func(x, y):
        """

        :param x:
        :param y:
        :return:
        """
        x, y = prepare_data(x, y)

        reg_loss = 0.0
        if args.hyper_train == 'weight':
            pred = model(x)
            xentropy_loss = F.cross_entropy(pred, y)
            # print(f"weight_decay: {torch.exp(model.weight_decay).shape}")
            for p in model.parameters():
                # print(f"weight_decay: {torch.exp(model.weight_decay).shape}")
                # print(f"shape: {p.shape}")
                reg_loss = reg_loss + .5 * (model.weight_decay ** 2) * torch.sum(p ** 2)
                # print(f"reg_loss: {reg_loss}")
        elif args.hyper_train == 'all_weight':
            pred = model(x)
            xentropy_loss = F.cross_entropy(pred, y)
            count = 0
            for p in model.parameters():
                reg_loss = reg_loss + .5 * torch.sum(
                    (model.weight_decay[count: count + p.numel()] ** 2) * torch.flatten(p ** 2))
                count += p.numel()
        elif args.hyper_train == 'augment':
            augmented_x = augment_net(x)
            pred = model(augmented_x)
            xentropy_loss = F.cross_entropy(pred, y)
        return xentropy_loss + reg_loss, pred
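    # Note: model.weight_decay stores sqrt(lambda), so squaring it inside the
    # penalty recovers the usual 0.5 * lambda * ||p||^2 decay while keeping
    # the effective coefficient nonnegative under hypergradient updates.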

    def val_loss_func(x, y):
        """

        :param x:
        :param y:
        :return:
        """
        x, y = prepare_data(x, y)
        pred = model(x)
        xentropy_loss = F.cross_entropy(pred, y)
        return xentropy_loss

    for epoch in range(init_epoch, init_epoch + args.num_finetune_epochs):
        xentropy_loss_avg = 0.
        total_val_loss = 0.
        correct = 0.
        total = 0.

        progress_bar = tqdm(train_loader)
        for i, (images, labels) in enumerate(progress_bar):
            progress_bar.set_description('Finetune Epoch ' + str(epoch))

            # TODO: Take a hyperparameter step here
            optimizer.zero_grad()
            hyper_optimizer.zero_grad()
            val_loss, weight_norm, grad_norm = hyper_step(1, 1, get_hyper_train, get_hyper_train_flat,
                                                          model, val_loss_func, val_loader,
                                                          train_loss_func, train_loader, hyper_optimizer)
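            # hyper_step is assumed (from this call site) to take one
            # hypergradient step on get_hyper_train() via hyper_optimizer and
            # to return (val_loss, weight_norm, grad_norm); its definition is
            # not part of this snippet.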
            # del val_loss
            # print(f"hyper: {get_hyper_train()}")

            images, labels = images.cuda(), labels.cuda()
            # pred = model(images)
            # xentropy_loss = F.cross_entropy(pred, labels)
            xentropy_loss, pred = train_loss_func(images, labels)

            optimizer.zero_grad()
            hyper_optimizer.zero_grad()
            xentropy_loss.backward()
            optimizer.step()

            xentropy_loss_avg += xentropy_loss.item()

            # Calculate running average of accuracy
            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels.data).sum().item()
            accuracy = correct / total

            progress_bar.set_postfix(
                train='%.5f' % (xentropy_loss_avg / (i + 1)),
                val='%.4f' % (total_val_loss / (i + 1)),
                acc='%.4f' % accuracy,
                weight='%.2f' % weight_norm,
                update='%.3f' % grad_norm)

        val_loss, val_acc = test(val_loader)
        test_loss, test_acc = test(test_loader)
        tqdm.write('val loss: {:6.4f} | val acc: {:6.4f} | test loss: {:6.4f} | test_acc: {:6.4f}'.format(
            val_loss, val_acc, test_loss, test_acc))

        scheduler.step(epoch)

        row = {'epoch': str(epoch),
               'train_loss': str(xentropy_loss_avg / (i + 1)), 'train_acc': str(accuracy),
               'val_loss': str(val_loss), 'val_acc': str(val_acc),
               'test_loss': str(test_loss), 'test_acc': str(test_acc)}
        csv_logger.writerow(row)
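
A minimal entry point, assuming experiment() is meant to run directly:

if __name__ == '__main__':
    experiment()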