Example #1
def main():

    global best_prec1
    best_prec1 = 0

    global val_acc
    val_acc = []

    global class_num

    class_num = 10 if args.dataset == 'cifar10' else 100

    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        if args.autoaugment:
            print('Autoaugment')
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                                  mode='reflect').squeeze()),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                CIFAR10Policy(),
                transforms.ToTensor(),
                Cutout(n_holes=args.n_holes, length=args.length),
                normalize,
            ])
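            # The ToTensor -> F.pad -> ToPILImage round trip above applies
            # 4-pixel reflection padding before the random crop; RandomCrop's
            # own padding argument pads with zeros by default.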

        elif args.cutout:
            print('Cutout')
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                                  mode='reflect').squeeze()),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                Cutout(n_holes=args.n_holes, length=args.length),
                normalize,
            ])

        else:
            print('Standard Augmentation!')
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                                  mode='reflect').squeeze()),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    assert args.dataset in ('cifar10', 'cifar100')
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=True,
                                                download=True,
                                                transform=transform_train),
        batch_size=training_configurations[args.model]['batch_size'],
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=False,
                                                transform=transform_test),
        batch_size=training_configurations[args.model]['batch_size'],
        shuffle=True,
        **kwargs)

    # create model
    if args.model == 'resnet':
        model = getattr(networks.resnet,
                        'resnet{}_cifar'.format(args.layers))(
                            dropout_rate=args.droprate)
    elif args.model == 'se_resnet':
        model = getattr(networks.se_resnet,
                        'resnet{}_cifar'.format(args.layers))(
                            dropout_rate=args.droprate)
    elif args.model == 'wideresnet':
        model = networks.wideresnet.WideResNet(args.layers,
                                               class_num,
                                               args.widen_factor,
                                               dropRate=args.droprate)
    elif args.model == 'se_wideresnet':
        model = networks.se_wideresnet.WideResNet(args.layers,
                                                  class_num,
                                                  args.widen_factor,
                                                  dropRate=args.droprate)

    elif args.model == 'densenet_bc':
        model = networks.densenet_bc.DenseNet(
            growth_rate=args.growth_rate,
            block_config=((args.layers - 4) // 6, ) * 3,
            compression=args.compression_rate,
            num_init_features=24,
            bn_size=args.bn_size,
            drop_rate=args.droprate,
            small_inputs=True,
            efficient=False)
    elif args.model == 'shake_pyramidnet':
        model = networks.shake_pyramidnet.PyramidNet(dataset=args.dataset,
                                                     depth=args.layers,
                                                     alpha=args.alpha,
                                                     num_classes=class_num,
                                                     bottleneck=True)

    elif args.model == 'resnext':
        if args.cardinality == 8:
            model = networks.resnext.resnext29_8_64(class_num)
        elif args.cardinality == 16:
            model = networks.resnext.resnext29_16_64(class_num)

    elif args.model == 'shake_shake':
        if args.widen_factor == 112:
            model = networks.shake_shake.shake_resnet26_2x112d(class_num)
        elif args.widen_factor == 32:
            model = networks.shake_shake.shake_resnet26_2x32d(class_num)
        elif args.widen_factor == 96:
            model = networks.shake_shake.shake_resnet26_2x96d(class_num)

    elif args.model == 'shake_shake_x':

        model = networks.shake_shake.shake_resnext29_2x4x64d(class_num)

    if not os.path.isdir(check_point):
        mkdir_p(check_point)

    fc = Full_layer(int(model.feature_num), class_num)

    print('Number of final features: {}'.format(int(model.feature_num)))

    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()]) +
        sum([p.data.nelement() for p in fc.parameters()])))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    isda_criterion = ISDALoss(int(model.feature_num), class_num).cuda()
    ce_criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(
        [{
            'params': model.parameters()
        }, {
            'params': fc.parameters()
        }],
        lr=training_configurations[args.model]['initial_learning_rate'],
        momentum=training_configurations[args.model]['momentum'],
        nesterov=training_configurations[args.model]['nesterov'],
        weight_decay=training_configurations[args.model]['weight_decay'])

    model = torch.nn.DataParallel(model).cuda()
    fc = nn.DataParallel(fc).cuda()

    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        fc.load_state_dict(checkpoint['fc'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        isda_criterion = checkpoint['isda_criterion']
        val_acc = checkpoint['val_acc']
        best_prec1 = checkpoint['best_acc']
        np.savetxt(accuracy_file, np.array(val_acc))
    else:
        start_epoch = 0

    for epoch in range(start_epoch,
                       training_configurations[args.model]['epochs']):

        adjust_learning_rate(optimizer, epoch + 1)

        # train for one epoch
        train(train_loader, model, fc, isda_criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, fc, ce_criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'fc': fc.state_dict(),
                'best_acc': best_prec1,
                'optimizer': optimizer.state_dict(),
                'isda_criterion': isda_criterion,
                'val_acc': val_acc,
            },
            is_best,
            checkpoint=check_point)
        print('Best accuracy: ', best_prec1)
        np.savetxt(accuracy_file, np.array(val_acc))

    print('Best accuracy: ', best_prec1)
    print('Average accuracy of last 10 epochs:', sum(val_acc[-10:]) / 10)
    # val_acc.append(sum(val_acc[len(val_acc) - 10:]) / 10)
    # np.savetxt(val_acc, np.array(val_acc))
    np.savetxt(accuracy_file, np.array(val_acc))
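
The helpers train, validate, save_checkpoint and adjust_learning_rate are defined elsewhere in this repository. A minimal step-decay sketch of adjust_learning_rate, consistent with how it is called above (the 'changing_lr' milestone list and 'lr_decay_rate' factor are hypothetical keys, not shown in this example):

def adjust_learning_rate(optimizer, epoch):
    # Step decay: shrink the learning rate once per milestone epoch passed.
    config = training_configurations[args.model]
    lr = config['initial_learning_rate']
    for milestone in config.get('changing_lr', []):  # hypothetical key
        if epoch >= milestone:
            lr *= config.get('lr_decay_rate', 0.1)  # hypothetical key
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr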

Example #2
transform_weak_10_compose = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_weak_100_compose = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
])

transform_strong_10_compose = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    CIFAR10Policy(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_strong_100_compose = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    CIFAR10Policy(),
    transforms.ToTensor(),
    transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
])
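
The composes above pair a weak pipeline (crop + flip) and a strong pipeline (crop + flip + AutoAugment via CIFAR10Policy) with per-dataset normalization statistics; the 10/100 suffix marks CIFAR-10 versus CIFAR-100 values. A small dispatch helper of this shape (hypothetical, not part of the original example) makes the selection explicit:

def get_cifar_transform(dataset='CIFAR10', strong=False):
    # Hypothetical helper: picks one of the composes defined above.
    table = {
        ('CIFAR10', False): transform_weak_10_compose,
        ('CIFAR10', True): transform_strong_10_compose,
        ('CIFAR100', False): transform_weak_100_compose,
        ('CIFAR100', True): transform_strong_100_compose,
    }
    return table[(dataset, strong)]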

transform_strong_randaugment_10_compose = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
Example #3
def get_iters(dataset='CIFAR10',
              root_path='.',
              data_transforms=None,
              n_labeled=4000,
              valid_size=1000,
              l_batch_size=32,
              ul_batch_size=128,
              test_batch_size=256,
              workers=8,
              pseudo_label=None):

    train_path = '{}/data/{}/train/'.format(root_path, dataset)
    test_path = '{}/data/{}/test/'.format(root_path, dataset)

    if dataset == 'CIFAR10':
        train_dataset = datasets.CIFAR10(train_path,
                                         download=True,
                                         train=True,
                                         transform=None)
        test_dataset = datasets.CIFAR10(test_path,
                                        download=True,
                                        train=False,
                                        transform=None)
    elif dataset == 'CIFAR100':
        train_dataset = datasets.CIFAR100(train_path,
                                          download=True,
                                          train=True,
                                          transform=None)
        test_dataset = datasets.CIFAR100(test_path,
                                         download=True,
                                         train=False,
                                         transform=None)
    else:
        raise ValueError('Unsupported dataset: {}'.format(dataset))

    if data_transforms is None:
        data_transforms = {
            'train':
            transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(),
                CIFAR10Policy(),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]),
            'eval':
            transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]),
        }

    # Newer torchvision exposes the raw arrays as .data / .targets; the old
    # .train_data / .train_labels / .test_data / .test_labels attributes
    # were removed.
    x_train, y_train = train_dataset.data, np.array(train_dataset.targets)
    x_test, y_test = test_dataset.data, np.array(test_dataset.targets)

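    # Randomly partition the training images into labeled, validation and
    # unlabeled pools (n_labeled / valid_size / remainder).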
    randperm = np.random.permutation(len(x_train))
    labeled_idx = randperm[:n_labeled]
    validation_idx = randperm[n_labeled:n_labeled + valid_size]
    unlabeled_idx = randperm[n_labeled + valid_size:]

    x_labeled = x_train[labeled_idx]
    x_validation = x_train[validation_idx]
    x_unlabeled = x_train[unlabeled_idx]

    y_labeled = y_train[labeled_idx]
    y_validation = y_train[validation_idx]
    if pseudo_label is None:
        y_unlabeled = y_train[unlabeled_idx]
    else:
        assert isinstance(pseudo_label, np.ndarray)
        y_unlabeled = pseudo_label

    data_iterators = {
        'labeled':
        iter(
            DataLoader(
                SimpleDataset(x_labeled, y_labeled, data_transforms['train']),
                batch_size=l_batch_size,
                num_workers=workers,
                sampler=InfiniteSampler(len(x_labeled)),
            )),
        'unlabeled':
        iter(
            DataLoader(
                MultiDataset(x_unlabeled, y_unlabeled, data_transforms['eval'],
                             data_transforms['train']),
                batch_size=ul_batch_size,
                num_workers=workers,
                sampler=InfiniteSampler(len(x_unlabeled)),
            )),
        'val':
        iter(
            DataLoader(SimpleDataset(x_validation, y_validation,
                                     data_transforms['eval']),
                       batch_size=len(x_validation),
                       num_workers=workers,
                       shuffle=False)),
        'test':
        iter(
            DataLoader(SimpleDataset(x_test, y_test, data_transforms['eval']),
                       batch_size=test_batch_size,
                       num_workers=workers,
                       shuffle=False))
    }

    return data_iterators
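
A minimal consumption sketch (hypothetical driver code, not part of the example; it assumes SimpleDataset yields (image, label) pairs):

iters = get_iters(dataset='CIFAR10', root_path='.', n_labeled=4000)

# 'labeled' and 'unlabeled' never exhaust thanks to InfiniteSampler,
# so a step-driven loop just calls next() on them every iteration.
for step in range(1000):
    x, y = next(iters['labeled'])
    # ... forward pass, loss and optimizer step go here ...

# 'val' serves the whole validation split as a single batch.
x_val, y_val = next(iters['val'])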