Example #1
def main(args):
    np.random.seed(0)
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Constructing Model

    if args.resume != "":
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            test_only = args.test_only
            resume = args.resume
            args = checkpoint["opt"]
            args.test_only = test_only
            args.resume = resume
        else:
            checkpoint = None
            print("=> No checkpoint found at '{}'".format(args.resume))

    model = WideResNet(args.depth, args.widen_factor, args.dropout_rate,
                       args.num_classes)

    if torch.cuda.is_available():
        model.cuda()
        model = torch.nn.DataParallel(model, device_ids=args.gpu)

    if args.resume != "" and checkpoint is not None:
        model.load_state_dict(checkpoint["model"])
        args.start_epoch = checkpoint["epoch"] + 1
        print("=> Loaded successfully '{}' (epoch {})".format(
            args.resume, checkpoint["epoch"]))
        del checkpoint
        torch.cuda.empty_cache()
    else:
        model.apply(conv_init)

    # Loading Dataset

    if args.augment == "meanstd":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(Config.CIFAR10_mean, Config.CIFAR10_std),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(Config.CIFAR10_mean, Config.CIFAR10_std),
        ])
    elif args.augment == "zac":
        # TODO: ZCA whitening; raise for now so transform_train is never left undefined
        raise NotImplementedError("ZCA whitening is not implemented yet")
    else:
        raise NotImplementedError

    print("| Preparing CIFAR-10 dataset...")
    sys.stdout.write("| ")
    trainset = CIFAR10(root="./data",
                       train=True,
                       download=True,
                       transform=transform_train)
    testset = CIFAR10(root="./data",
                      train=False,
                      download=False,
                      transform=transform_test)

    train_loader = DataLoader(trainset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=2)
    test_loader = DataLoader(testset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=2)

    # Test only

    if args.test_only:
        if args.resume != "":
            test(args, test_loader, model)
            sys.exit(0)
        else:
            print("=> Test only model need to resume from a checkpoint")
            raise RuntimeError

    train(args, train_loader, test_loader, model)
    test(args, test_loader, model)
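
When no checkpoint is given, Example 1 initializes the network with model.apply(conv_init) but does not show the helper itself. A minimal sketch of such an initializer, assuming a common Kaiming/constant scheme (the original conv_init may differ), could look like:

import torch.nn as nn

def conv_init(m):
    # Kaiming initialization for conv layers, constants for batch norm (a common choice)
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)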
Example #2
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=2)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=2)

    model = WideResNet(depth=depth,
                       num_classes=num_classes,
                       widen_factor=widen_factor,
                       dropRate=drop_rate)

    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                nesterov=True,
                                weight_decay=0.0005)

    scheduler = MultiStepLR(optimizer, milestones=LR_MILESTONES, gamma=gamma)

    try:
        checkpoint_fpath = 'cifar-10/cifar10_wideresnet79.pt'
        checkpoint = torch.load(checkpoint_fpath)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler = MultiStepLR(optimizer,
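
The checkpoint loaded above stores the model and optimizer under the 'state_dict' and 'optimizer' keys. A minimal sketch of how such a file would be written (the saving code is not part of this excerpt):

torch.save({
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, 'cifar-10/cifar10_wideresnet79.pt')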
Example #3
def main():
    args = parse_args()

    if args.name is None:
        args.name = '%s_WideResNet%s-%s' %(args.dataset, args.depth, args.width)
        if args.cutout:
            args.name += '_wCutout'
        if args.auto_augment:
            args.name += '_wAutoAugment'

    if not os.path.exists('models/%s' %args.name):
        os.makedirs('models/%s' %args.name)

    print('Config -----')
    for arg in vars(args):
        print('%s: %s' %(arg, getattr(args, arg)))
    print('------------')

    with open('models/%s/args.txt' %args.name, 'w') as f:
        for arg in vars(args):
            print('%s: %s' %(arg, getattr(args, arg)), file=f)

    joblib.dump(args, 'models/%s/args.pkl' %args.name)

    criterion = nn.CrossEntropyLoss().cuda()

    cudnn.benchmark = True

    # data loading code
    if args.dataset == 'cifar10':
        transform_train = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
        if args.auto_augment:
            transform_train.append(AutoAugment())
        if args.cutout:
            transform_train.append(Cutout())
        transform_train.extend([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        transform_train = transforms.Compose(transform_train)

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            # use the same CIFAR-10 statistics as the training transform
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        train_set = datasets.CIFAR10(
            root='~/data',
            train=True,
            download=True,
            transform=transform_train)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=128,
            shuffle=True,
            num_workers=8)

        test_set = datasets.CIFAR10(
            root='~/data',
            train=False,
            download=True,
            transform=transform_test)
        test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=128,
            shuffle=False,
            num_workers=8)

        num_classes = 10

    elif args.dataset == 'cifar100':
        transform_train = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
        if args.auto_augment:
            transform_train.append(AutoAugment())
        if args.cutout:
            transform_train.append(Cutout())
        transform_train.extend([
            transforms.ToTensor(),
            # CIFAR-100 channel means and standard deviations
            transforms.Normalize((0.5071, 0.4867, 0.4408),
                                 (0.2675, 0.2565, 0.2761)),
        ])
        transform_train = transforms.Compose(transform_train)

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408),
                                 (0.2675, 0.2565, 0.2761)),
        ])

        train_set = datasets.CIFAR100(
            root='~/data',
            train=True,
            download=True,
            transform=transform_train)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=128,
            shuffle=True,
            num_workers=8)

        test_set = datasets.CIFAR100(
            root='~/data',
            train=False,
            download=True,
            transform=transform_test)
        test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=128,
            shuffle=False,
            num_workers=8)

        num_classes = 100

    # create model
    model = WideResNet(args.depth, args.width, num_classes=num_classes)
    model = model.cuda()

    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr,
            momentum=args.momentum, weight_decay=args.weight_decay)

    scheduler = lr_scheduler.MultiStepLR(optimizer,
            milestones=[int(e) for e in args.milestones.split(',')], gamma=args.gamma)

    log = pd.DataFrame(index=[], columns=[
        'epoch', 'lr', 'loss', 'acc', 'val_loss', 'val_acc'
    ])

    best_acc = 0
    for epoch in range(args.epochs):
        print('Epoch [%d/%d]' %(epoch+1, args.epochs))

        # record the learning rate in use for this epoch
        current_lr = optimizer.param_groups[0]['lr']

        # train for one epoch
        train_log = train(args, train_loader, model, criterion, optimizer, epoch)
        # evaluate on validation set
        val_log = validate(args, test_loader, model, criterion)

        # step the LR schedule once per epoch, after the optimizer updates
        scheduler.step()

        print('loss %.4f - acc %.4f - val_loss %.4f - val_acc %.4f'
            %(train_log['loss'], train_log['acc'], val_log['loss'], val_log['acc']))

        tmp = pd.Series([
            epoch,
            current_lr,
            train_log['loss'],
            train_log['acc'],
            val_log['loss'],
            val_log['acc'],
        ], index=['epoch', 'lr', 'loss', 'acc', 'val_loss', 'val_acc'])

        # DataFrame.append was removed in pandas 2.0; concat is the supported replacement
        log = pd.concat([log, tmp.to_frame().T], ignore_index=True)
        log.to_csv('models/%s/log.csv' %args.name, index=False)

        if val_log['acc'] > best_acc:
            torch.save(model.state_dict(), 'models/%s/model.pth' %args.name)
            best_acc = val_log['acc']
            print("=> saved best model")
Example #4
def experiment():
    parser = argparse.ArgumentParser(description='CNN Hyperparameter Fine-tuning')
    parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100'],
                        help='Choose a dataset')
    parser.add_argument('--model', default='resnet18', choices=['resnet18', 'wideresnet'],
                        help='Choose a model')
    parser.add_argument('--num_finetune_epochs', type=int, default=200,
                        help='Number of fine-tuning epochs')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='Learning rate')
    parser.add_argument('--optimizer', type=str, default='sgdm',
                        help='Choose an optimizer')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Mini-batch size')
    parser.add_argument('--data_augmentation', action='store_true', default=True,
                        help='Whether to use data augmentation')
    parser.add_argument('--wdecay', type=float, default=5e-4,
                        help='Amount of weight decay')
    parser.add_argument('--load_checkpoint', type=str,
                        help='Path to pre-trained checkpoint to load and finetune')
    parser.add_argument('--save_dir', type=str, default='finetuned_checkpoints',
                        help='Save directory for the fine-tuned checkpoint')
    args = parser.parse_args()
    # NOTE: this hard-coded path overrides whatever was passed via --load_checkpoint
    args.load_checkpoint = '/h/lorraine/PycharmProjects/CG_IFT_test/baseline_checkpoints/cifar10_resnet18_sgdm_lr0.1_wd0.0005_aug0.pt'

    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(args.batch_size, val_split=True,
                                                                          augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(args.batch_size, val_split=True,
                                                                           augmentation=args.data_augmentation)

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'wideresnet':
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10, dropRate=0.3)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    test_id = '{}_{}_{}_lr{}_wd{}_aug{}'.format(args.dataset, args.model, args.optimizer, args.lr, args.wdecay,
                                                int(args.data_augmentation))
    filename = os.path.join(args.save_dir, test_id + '.csv')
    csv_logger = CSVLogger(
        fieldnames=['epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc', 'test_loss', 'test_acc'],
        filename=filename)

    checkpoint = torch.load(args.load_checkpoint)
    init_epoch = checkpoint['epoch']
    cnn.load_state_dict(checkpoint['model_state_dict'])
    model = cnn.cuda()
    model.train()

    args.hyper_train = 'augment'  # 'all_weight'  # 'weight'

    def init_hyper_train(model):
        """

        :return:
        """
        init_hyper = None
        if args.hyper_train == 'weight':
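            # store sqrt(wdecay): it is squared in the training loss, so the effective
            # weight-decay coefficient stays nonnegative while being optimized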
            init_hyper = np.sqrt(args.wdecay)
            model.weight_decay = Variable(torch.FloatTensor([init_hyper]).cuda(), requires_grad=True)
            model.weight_decay = model.weight_decay.cuda()
        elif args.hyper_train == 'all_weight':
            num_p = sum(p.numel() for p in model.parameters())
            weights = np.ones(num_p) * np.sqrt(args.wdecay)
            model.weight_decay = Variable(torch.FloatTensor(weights).cuda(), requires_grad=True)
            model.weight_decay = model.weight_decay.cuda()
        model = model.cuda()
        return init_hyper

    if args.hyper_train == 'augment':  # don't create this inside init_hyper_train, or the closure scope is wrong
        augment_net = UNet(in_channels=3,
                           n_classes=3,
                           depth=5,
                           wf=6,
                           padding=True,
                           batch_norm=False,
                           up_mode='upconv')  # TODO(PV): Initialize UNet properly
        augment_net = augment_net.cuda()

    def get_hyper_train():
        """

        :return:
        """
        if args.hyper_train == 'weight' or args.hyper_train == 'all_weight':
            return [model.weight_decay]
        if args.hyper_train == 'augment':
            return augment_net.parameters()

    def get_hyper_train_flat():
        return torch.cat([p.view(-1) for p in get_hyper_train()])

    # TODO: Check this size

    init_hyper_train(model)

    if args.hyper_train == 'all_weight':
        wdecay = 0.0
    else:
        wdecay = args.wdecay
    optimizer = optim.SGD(model.parameters(), lr=args.lr * 0.2 * 0.2, momentum=0.9, nesterov=True,
                          weight_decay=wdecay)  # args.wdecay)
    # print(checkpoint['optimizer_state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler = MultiStepLR(optimizer, milestones=[60, 120], gamma=0.2)  # [60, 120, 160]
    hyper_optimizer = torch.optim.Adam(get_hyper_train(), lr=1e-3)  # try 0.1 as lr

    # Set random regularization hyperparameters
    # data_augmentation_hparams = {}  # Random values for hue, saturation, brightness, contrast, rotation, etc.
    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(args.batch_size, val_split=True,
                                                                          augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(args.batch_size, val_split=True,
                                                                           augmentation=args.data_augmentation)

    def test(loader):
        model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
        correct = 0.
        total = 0.
        losses = []
        for images, labels in loader:
            images = images.cuda()
            labels = labels.cuda()

            with torch.no_grad():
                pred = model(images)

            xentropy_loss = F.cross_entropy(pred, labels)
            losses.append(xentropy_loss.item())

            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels).sum().item()

        avg_loss = float(np.mean(losses))
        acc = correct / total
        model.train()
        return avg_loss, acc

    def prepare_data(x, y):
        """

        :param x:
        :param y:
        :return:
        """
        x, y = x.cuda(), y.cuda()

        # x, y = Variable(x), Variable(y)
        return x, y

    def train_loss_func(x, y):
        """

        :param x:
        :param y:
        :return:
        """
        x, y = prepare_data(x, y)

        reg_loss = 0.0
        if args.hyper_train == 'weight':
            pred = model(x)
            xentropy_loss = F.cross_entropy(pred, y)
            # print(f"weight_decay: {torch.exp(model.weight_decay).shape}")
            for p in model.parameters():
                # print(f"weight_decay: {torch.exp(model.weight_decay).shape}")
                # print(f"shape: {p.shape}")
                reg_loss = reg_loss + .5 * (model.weight_decay ** 2) * torch.sum(p ** 2)
                # print(f"reg_loss: {reg_loss}")
        elif args.hyper_train == 'all_weight':
            pred = model(x)
            xentropy_loss = F.cross_entropy(pred, y)
            count = 0
            for p in model.parameters():
                reg_loss = reg_loss + .5 * torch.sum(
                    (model.weight_decay[count: count + p.numel()] ** 2) * torch.flatten(p ** 2))
                count += p.numel()
        elif args.hyper_train == 'augment':
            augmented_x = augment_net(x)
            pred = model(augmented_x)
            xentropy_loss = F.cross_entropy(pred, y)
        return xentropy_loss + reg_loss, pred
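
    # In 'weight' / 'all_weight' mode the regularizer above is 0.5 * wd**2 * ||p||**2 with
    # a learnable square-root weight decay; in 'augment' mode no explicit penalty is added
    # and the UNet's parameters play the role of the hyperparameters.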

    def val_loss_func(x, y):
        """

        :param x:
        :param y:
        :return:
        """
        x, y = prepare_data(x, y)
        pred = model(x)
        xentropy_loss = F.cross_entropy(pred, y)
        return xentropy_loss

    for epoch in range(init_epoch, init_epoch + args.num_finetune_epochs):
        xentropy_loss_avg = 0.
        total_val_loss = 0.
        correct = 0.
        total = 0.

        progress_bar = tqdm(train_loader)
        for i, (images, labels) in enumerate(progress_bar):
            progress_bar.set_description('Finetune Epoch ' + str(epoch))

            # TODO: Take a hyperparameter step here
            optimizer.zero_grad()
            hyper_optimizer.zero_grad()
            val_loss, weight_norm, grad_norm = hyper_step(1, 1, get_hyper_train, get_hyper_train_flat,
                                                                model, val_loss_func,
                                                                val_loader, train_loss_func, train_loader,
                                                                hyper_optimizer)
            # del val_loss
            # print(f"hyper: {get_hyper_train()}")

            images, labels = images.cuda(), labels.cuda()
            # pred = model(images)
            # xentropy_loss = F.cross_entropy(pred, labels)
            xentropy_loss, pred = train_loss_func(images, labels)

            optimizer.zero_grad()
            hyper_optimizer.zero_grad()
            xentropy_loss.backward()
            optimizer.step()

            xentropy_loss_avg += xentropy_loss.item()

            # Calculate running average of accuracy
            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels.data).sum().item()
            accuracy = correct / total

            progress_bar.set_postfix(
                train='%.5f' % (xentropy_loss_avg / (i + 1)),
                val='%.4f' % (total_val_loss / (i + 1)),
                acc='%.4f' % accuracy,
                weight='%.2f' % weight_norm,
                update='%.3f' % grad_norm)

        val_loss, val_acc = test(val_loader)
        test_loss, test_acc = test(test_loader)
        tqdm.write('val loss: {:6.4f} | val acc: {:6.4f} | test loss: {:6.4f} | test_acc: {:6.4f}'.format(
            val_loss, val_acc, test_loss, test_acc))

        scheduler.step(epoch)

        row = {'epoch': str(epoch),
               'train_loss': str(xentropy_loss_avg / (i + 1)), 'train_acc': str(accuracy),
               'val_loss': str(val_loss), 'val_acc': str(val_acc),
               'test_loss': str(test_loss), 'test_acc': str(test_acc)}
        csv_logger.writerow(row)
Example #5
        testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
        nclasses = 100

train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=8)

model = WideResNet(16, 8, .0, nclasses)

# Enable sparsification for Conv and Linear layers above pthres; the initial pruning threshold is set lower than the requested target.
iter_sparsify(model, erfinv(.6 * args.starget) * math.sqrt(2), True, args.pthres)

# Alternative: non-adaptive, fixed per-layer sparsity:
#iter_sparsify(model, erfinv(1-args.starget) * math.sqrt(2), False)

if args.cuda:
    model.cuda(gpu_id)

parameters, sparameters = [], []
for name, p in model.named_parameters():
    if ".r" in name:
        sparameters += [p]
    else:
        parameters += [p]

# To help convergence, use a much slower learning rate (and no weight decay) for the sparsity-controlling pruning parameters.
optimizer = optim.SGD([{"params":parameters}, {"params":sparameters, "lr":args.lr/100.0, "weight_decay":0}],
                      lr=args.lr, momentum=.9, weight_decay=5e-4)

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, args.epochs//2, T_mult=1)
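
The threshold passed to iter_sparsify follows the Gaussian quantile rule: for roughly normal weights with standard deviation sigma, a fraction s of them has magnitude below sigma * sqrt(2) * erfinv(s), so thresholding at that value prunes about s of a layer. A quick standalone check (erfinv is taken from scipy.special here, and the 0.9 target sparsity is only an example value):

import math
from scipy.special import erf, erfinv

starget = 0.9                        # hypothetical requested sparsity
s0 = 0.6 * starget                   # lower initial target, as in the call above
thr = erfinv(s0) * math.sqrt(2)      # threshold in units of the weight standard deviation
print(thr, erf(thr / math.sqrt(2)))  # the second value recovers s0 (= 0.54)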