Example #1
if args.dataset == 'cifar10':
    # standard CIFAR-10 augmentation: pad-and-crop plus horizontal flip
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
elif args.dataset == 'svhn':
    # the WRN paper uses no augmentation on SVHN: flipping digits is obviously
    # a bad idea, and it also makes sense not to crop, because distractor
    # digits often appear at the edges of the image
    transform_train = transforms.ToTensor()

transform_test = transforms.Compose([transforms.ToTensor()])
# ------------------------------------------------------------------------------

# ----------------- DATASET WITH AUX PSEUDO-LABELED DATA -----------------------
trainset = SemiSupervisedDataset(base_dataset=args.dataset,
                                 add_svhn_extra=args.svhn_extra,
                                 root=args.data_dir,
                                 train=True,
                                 transform=transform_train,
                                 aux_data_filename=args.aux_data_filename,
                                 add_aux_labels=not args.remove_pseudo_labels,
                                 aux_take_amount=args.aux_take_amount)

# num_batches = ceil(50000 / batch_size) enforces the definition of an "epoch"
# as passing through 50K datapoints
# TODO: make sure that this code works also when trainset.unsup_indices=[]
train_batch_sampler = SemiSupervisedSampler(
    trainset.sup_indices,
    trainset.unsup_indices,
    args.batch_size,
    args.unsup_fraction,
    num_batches=int(np.ceil(50000 / args.batch_size)))
# epoch_size = len(train_batch_sampler) * args.batch_size
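
For context, the batching logic that a sampler like SemiSupervisedSampler needs
can be sketched as follows. The class below is an illustrative stand-in, not the
project's actual implementation, assuming each batch draws a fixed
unsup_fraction of its indices from the pseudo-labeled pool; the guard for an
empty unsup_indices list covers exactly the edge case flagged in the TODO above.

import numpy as np
from torch.utils.data import Sampler

class MixedBatchSampler(Sampler):
    """Illustrative stand-in for SemiSupervisedSampler: each batch mixes
    supervised and pseudo-labeled indices at a fixed ratio."""

    def __init__(self, sup_indices, unsup_indices, batch_size,
                 unsup_fraction, num_batches):
        self.sup = np.asarray(sup_indices)
        self.unsup = np.asarray(unsup_indices)
        # fall back to fully supervised batches when there is no unlabeled data
        self.n_unsup = (int(batch_size * unsup_fraction)
                        if len(self.unsup) > 0 else 0)
        self.n_sup = batch_size - self.n_unsup
        self.num_batches = num_batches

    def __iter__(self):
        for _ in range(self.num_batches):
            batch = np.random.choice(self.sup, self.n_sup, replace=False)
            if self.n_unsup > 0:
                unsup_part = np.random.choice(self.unsup, self.n_unsup,
                                              replace=False)
                batch = np.concatenate([batch, unsup_part])
            yield batch.tolist()

    def __len__(self):
        return self.num_batches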
Example #2
    results_dir = os.path.join(output_dir, args.output_suffix)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)

    logging.info('Attack evaluation')
    logging.info('Args: %s' % args)

    # settings
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    dl_kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # set up data loader
    transform_test = transforms.Compose([transforms.ToTensor()])
    testset = SemiSupervisedDataset(base_dataset=args.dataset,
                                    train=False, root='data',
                                    download=True,
                                    transform=transform_test)

    if args.shuffle_testset:
        np.random.seed(123)
        logging.info("Permuting testset")
        permutation = np.random.permutation(len(testset))
        testset.data = testset.data[permutation, :]
        testset.targets = [testset.targets[i] for i in permutation]

    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=args.batch_size,
                                              shuffle=False, **dl_kwargs)

    checkpoint = torch.load(args.model_path)
    state_dict = checkpoint.get('state_dict', checkpoint)
    # base_classifier is assumed to have been constructed earlier in the script
    base_classifier = torch.nn.DataParallel(base_classifier).cuda()
    # setting loader to be non-strict so we can load Cohen et al.'s model
    base_classifier.load_state_dict(state_dict,
                                    strict=(args.model != 'resnet-110'))

    # create the smoothed classifier g
    smoothed_classifier = Smooth(base_classifier, num_classes, args.sigma)

    # iterate through the dataset
    transform_test = transforms.ToTensor()
    # dataset = datasets.CIFAR10(root='data', train=False,
    #                            download=True,
    #                            transform=transform_test)
    dataset = SemiSupervisedDataset(base_dataset=args.dataset,
                                    train=False,
                                    root=args.data_dir,
                                    download=True,
                                    transform=transform_test)

    # Shuffle the dataset if a random seed is given; re-seeding before each
    # shuffle applies the same permutation to data and targets, keeping them aligned
    if args.random_seed is not None:
        np.random.seed(args.random_seed)
        np.random.shuffle(dataset.targets)
        np.random.seed(args.random_seed)
        np.random.shuffle(dataset.data)
        filename = args.output_name + '_seed_' + str(args.random_seed) + '.csv'
    else:
        filename = args.output_name + '.csv'

    if os.path.exists(os.path.join(output_dir, filename)):
        logging.info('Output file exists, resuming...')
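
The Smooth wrapper above implements randomized smoothing (Cohen et al., 2019).
As a rough sketch of the prediction rule it builds on, the hypothetical helper
below takes a Monte Carlo majority vote over Gaussian-noised copies of the
input; the real class additionally abstains on close votes and computes
certified radii.

import torch

def smoothed_predict(base_classifier, x, sigma, num_classes,
                     num_samples=100, batch_size=100):
    # Estimate g(x) = argmax_c P[f(x + eps) = c], eps ~ N(0, sigma^2 I),
    # by sampling noisy copies of x in batches and counting class votes.
    counts = torch.zeros(num_classes, dtype=torch.long, device=x.device)
    base_classifier.eval()
    with torch.no_grad():
        remaining = num_samples
        while remaining > 0:
            n = min(batch_size, remaining)
            remaining -= n
            noisy = x.unsqueeze(0) + sigma * torch.randn(n, *x.shape,
                                                         device=x.device)
            preds = base_classifier(noisy).argmax(dim=1)
            counts += torch.bincount(preds, minlength=num_classes)
    return counts.argmax().item()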
Example #3

if args.no_aug:
    # Override: disable any augmentation configured above
    transform_train = transforms.ToTensor()

transform_test = transforms.Compose([transforms.ToTensor()])

trainset = SemiSupervisedDataset(
    base_dataset=args.dataset,
    downsample=args.train_downsample,
    take_fraction=args.train_take_fraction,
    semisupervised=args.semisup,
    add_cifar100=args.add_cifar100,
    add_svhn_extra=False,
    sup_labels=args.sup_labels,
    unsup_labels=args.unsup_labels,
    root=args.data_dir,
    train=True,
    download=True,
    transform=transform_train,
    aux_data_filename=args.aux_data_filename,
    aux_targets_filename=args.aux_targets_filename,
    add_aux_labels=args.add_aux_labels,
    aux_label_noise=args.aux_label_noise,
    aux_take_amount=args.aux_take_amount,
    take_amount_seed=args.take_amount_seed)

if args.semisup or args.add_cifar100 or (args.aux_data_filename is not None):
    # the repeat option makes sure that the number of gradient steps per 'epoch'
    # is roughly the same as the number of gradient steps in an epoch over full
    # CIFAR-10
    train_batch_sampler = SemiSupervisedSampler(
        trainset.sup_indices,
        trainset.unsup_indices,
        args.batch_size,
        args.unsup_fraction,
        num_batches=int(np.ceil(50000 / args.batch_size)))
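
The batch sampler is then typically consumed through the DataLoader's
batch_sampler argument; a plausible continuation (the excerpt ends before the
loader is built):

train_loader = torch.utils.data.DataLoader(trainset,
                                           batch_sampler=train_batch_sampler,
                                           num_workers=1,
                                           pin_memory=True)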