transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5)),
                        ]))
# Remaining branches of the dataset selection (the matching `if` precedes this chunk).
elif opt.dataset == 'cifar10':
    # NOTE(review): transforms.Scale is the deprecated alias of transforms.Resize --
    # confirm the installed torchvision version still accepts it.
    dataset = dset.CIFAR10(root=opt.dataroot,
                           download=True,
                           transform=transforms.Compose([
                               transforms.Scale(opt.imageSize),
                               transforms.ToTensor(),
                               # Map pixel values from [0, 1] to [-1, 1].
                               transforms.Normalize((0.5, 0.5, 0.5),
                                                    (0.5, 0.5, 0.5)),
                           ]))
elif opt.dataset == 'fake':
    # Synthetic random images; useful for smoke-testing the pipeline.
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
# NOTE(review): assert is stripped under `python -O`; an explicit raise would be safer.
assert dataset
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=opt.batchSize,
                                         shuffle=True,
                                         num_workers=int(opt.workers))

# Network hyper-parameters pulled from the parsed CLI options.
ngpu = int(opt.ngpu)  # number of GPUs to use
nz = int(opt.nz)      # size of the latent z vector
ngf = int(opt.ngf)    # presumably generator feature-map width -- confirm against netG
ndf = int(opt.ndf)    # presumably discriminator feature-map width -- confirm against netD
nc = 3                # number of image channels (RGB)


# custom weights initialization called on netG and netD
def weights_init(m):
Beispiel #2
0
def get_datasets(name, root, cutout):
  """Create train/test datasets plus input shape and class count for `name`.

  Args:
    name   : dataset identifier, e.g. 'cifar10', 'cifar100', 'fake',
             'imagenet-1k', 'imagenet-1k-s', 'imagenette*', 'ImageNet16',
             'ImageNet16-120/150/200', 'tiered'.
    root   : directory holding (or receiving) the dataset files.
    cutout : cutout length; when > 0 a CUTOUT transform is appended to the
             training pipeline (CIFAR / fake / ImageNet16 branches only).

  Returns:
    (train_data, test_data, xshape, class_num) where xshape is the
    (1, C, H, W) shape of one input sample and class_num comes from the
    module-level Dataset2Class table.

  Raises:
    TypeError : for an unrecognised dataset name.
    ValueError: for an invalid 'imagenet-1k' variant.
  """
  # Per-dataset channel mean/std consumed by transforms.Normalize below.
  if name == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std  = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif name == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std  = [x / 255 for x in [68.2, 65.4, 70.4]]
  elif name == 'fake':
    # Reuses the CIFAR-100 statistics; the data is random anyway.
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std  = [x / 255 for x in [68.2, 65.4, 70.4]]
  elif name.startswith('imagenet-1k'):
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
  elif name.startswith('imagenette'):
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
  elif name.startswith('ImageNet16'):
    mean = [x / 255 for x in [122.68, 116.66, 104.01]]
    std  = [x / 255 for x in [63.22,  61.26 , 65.09]]
  else:
    raise TypeError("Unknow dataset : {:}".format(name))

  # Data Argumentation
  if name == 'cifar10' or name == 'cifar100':
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
    if cutout > 0 : lists += [CUTOUT(cutout)]
    train_transform = transforms.Compose(lists)
    test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    xshape = (1, 3, 32, 32)
  elif name == 'fake':
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
    if cutout > 0 : lists += [CUTOUT(cutout)]
    train_transform = transforms.Compose(lists)
    test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    xshape = (1, 3, 32, 32)
  elif name.startswith('ImageNet16'):
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
    if cutout > 0 : lists += [CUTOUT(cutout)]
    train_transform = transforms.Compose(lists)
    test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    xshape = (1, 3, 16, 16)
  elif name == 'tiered':
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
    if cutout > 0 : lists += [CUTOUT(cutout)]
    train_transform = transforms.Compose(lists)
    test_transform  = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
    xshape = (1, 3, 32, 32)
  elif name.startswith('imagenette'):
    # BUGFIX: the original pipeline applied `normalize` twice *before* ToTensor.
    # transforms.Normalize only accepts tensors, so that Compose raised at
    # runtime on the PIL input.  Normalization must follow ToTensor, once.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
    test_transform  = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
    xshape = (1, 3, 224, 224)
  elif name.startswith('imagenet-1k'):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    if name == 'imagenet-1k':
      xlists    = [transforms.RandomResizedCrop(224)]
      xlists.append(
        transforms.ColorJitter(
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.2))
      xlists.append( Lighting(0.1))
    elif name == 'imagenet-1k-s':
      xlists    = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))]
    else: raise ValueError('invalid name : {:}'.format(name))
    xlists.append( transforms.RandomHorizontalFlip(p=0.5) )
    xlists.append( transforms.ToTensor() )
    xlists.append( normalize )
    train_transform = transforms.Compose(xlists)
    test_transform  = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
    xshape = (1, 3, 224, 224)
  else:
    raise TypeError("Unknow dataset : {:}".format(name))

  # Dataset construction; the length asserts guard against corrupted downloads.
  if name == 'cifar10':
    train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True)
    test_data  = dset.CIFAR10 (root, train=False, transform=test_transform , download=True)
    assert len(train_data) == 50000 and len(test_data) == 10000
  elif name == 'cifar100':
    train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
    test_data  = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
    assert len(train_data) == 50000 and len(test_data) == 10000
  elif name == 'fake':
    train_data = dset.FakeData(size=50000, image_size=(3, 32, 32), transform=train_transform)
    test_data = dset.FakeData(size=10000, image_size=(3, 32, 32), transform=test_transform)
  elif name.startswith('imagenette'):
    # CONSISTENCY FIX: was startswith('imagenette2'), which let plain
    # 'imagenette' names pass the transform stage above but fail here.
    train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
    test_data  = dset.ImageFolder(osp.join(root, 'val'),   test_transform)
  elif name.startswith('imagenet-1k'):
    train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
    test_data  = dset.ImageFolder(osp.join(root, 'val'),   test_transform)
    assert len(train_data) == 1281167 and len(test_data) == 50000, 'invalid number of images : {:} & {:} vs {:} & {:}'.format(len(train_data), len(test_data), 1281167, 50000)
  elif name == 'ImageNet16':
    train_data = ImageNet16(root, True , train_transform)
    test_data  = ImageNet16(root, False, test_transform)
    assert len(train_data) == 1281167 and len(test_data) == 50000
  elif name == 'ImageNet16-120':
    train_data = ImageNet16(root, True , train_transform, 120)
    test_data  = ImageNet16(root, False, test_transform , 120)
    assert len(train_data) == 151700 and len(test_data) == 6000
  elif name == 'ImageNet16-150':
    train_data = ImageNet16(root, True , train_transform, 150)
    test_data  = ImageNet16(root, False, test_transform , 150)
    assert len(train_data) == 190272 and len(test_data) == 7500
  elif name == 'ImageNet16-200':
    train_data = ImageNet16(root, True , train_transform, 200)
    test_data  = ImageNet16(root, False, test_transform , 200)
    assert len(train_data) == 254775 and len(test_data) == 10000
  else: raise TypeError("Unknow dataset : {:}".format(name))

  class_num = Dataset2Class[name]
  return train_data, test_data, xshape, class_num
Beispiel #3
0
def load_data(dataset: str, iid: str):
    """Loads a dataset.

    :param dataset: Name of the dataset ('fake', 'cifar10', 'mnist',
        'fashion_mnist', 'femnist', 'a9a', 'mushroom', 'phishing', 'quantum')
    :param iid: True if the dataset must not be splitted by target value
    :return: Train dataset, test dataset
    :raises ValueError: if the dataset name is not recognised
    """
    path_to_dataset = '{0}/dataset/'.format(get_path_to_datasets())

    if dataset == "fake":
        # Synthetic random images; handy for smoke tests.
        transform = transforms.ToTensor()
        train_data = datasets.FakeData(size=200, transform=transform)
        test_data = datasets.FakeData(size=200, transform=transform)

    elif dataset == 'cifar10':
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        # Standard CIFAR augmentation (flip + padded crop) on training only.
        transform_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        train_data = datasets.CIFAR10(root=path_to_dataset, train=True, download=True, transform=transform_train)
        test_data = datasets.CIFAR10(root=path_to_dataset, train=False, download=True, transform=transform_test)

    elif dataset == 'mnist':
        # Normalization see : https://stackoverflow.com/a/67233938
        # NOTE(review): download=False assumes MNIST already exists on disk.
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        train_data = datasets.MNIST(root=path_to_dataset, train=True, download=False, transform=transform)
        test_data = datasets.MNIST(root=path_to_dataset, train=False, download=False, transform=transform)

    elif dataset == "fashion_mnist":
        train_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        val_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # Download and load the training data
        train_data = datasets.FashionMNIST(path_to_dataset, download=True, train=True, transform=train_transforms)

        # Download and load the test data
        test_data = datasets.FashionMNIST(path_to_dataset, download=True, train=False, transform=val_transforms)

    elif dataset == "femnist":
        transform = transforms.Compose([transforms.ToTensor()])
        train_data = FEMNISTDataset(path_to_dataset, download=True, train=True, transform=transform)
        test_data = FEMNISTDataset(path_to_dataset, download=True, train=False, transform=transform)

    # The remaining datasets are tabular and honour the iid flag themselves.
    elif dataset == "a9a":
        train_data = A9ADataset(train=True, iid=iid)
        test_data = A9ADataset(train=False, iid=iid)

    elif dataset == "mushroom":
        train_data = MushroomDataset(train=True, iid=iid)
        test_data = MushroomDataset(train=False, iid=iid)

    elif dataset == "phishing":
        train_data = PhishingDataset(train=True, iid=iid)
        test_data = PhishingDataset(train=False, iid=iid)

    elif dataset == "quantum":
        train_data = QuantumDataset(train=True, iid=iid)
        test_data = QuantumDataset(train=False, iid=iid)

    else:
        # BUGFIX: an unknown name previously fell through every branch and
        # raised UnboundLocalError on the return statement below.
        raise ValueError("Unknown dataset: {0}".format(dataset))

    return train_data, test_data
Beispiel #4
0
        return len(self.data)

    def __getitem__(self, index):
        """Return the (sample, label) pair stored at position *index*."""
        # self.data / self.target are parallel sequences; they are filled by the
        # dataset's constructor, which is outside this chunk -- confirm there.
        data, target = self.data[index], self.target[index]
        return data, target


# Loader over gradient-derived examples built from MNIST by GradientDataset
# (partially visible above); batch_size=1 so each example is scored on its own.
gradient_loader = torch.utils.data.DataLoader(GradientDataset(
    epsilon=1, datasets=MNIST_datasets, model=model),
                                              batch_size=1,
                                              shuffle=True)

# %% [markdown]
# ---
# ### 🤠🤠🤠Experiment of rubbish class examples
# ### and similarity with adversarial examples
# 10k random ("rubbish") 1x28x28 images with random labels from 10 classes.
rubbish_loader = torch.utils.data.DataLoader(datasets.FakeData(
    size=10000,
    image_size=(1, 28, 28),
    num_classes=10,
    transform=transforms.Compose([transforms.ToTensor()])),
                                             batch_size=1,
                                             shuffle=True)

# Evaluate the model on both synthetic sources via get_rubbishExamples
# (defined elsewhere in this file).
print('Test in randomly generated images datasets')
get_rubbishExamples(model, rubbish_loader)
print('Test in gradient of MNIST dataset datasets')
get_rubbishExamples(model, gradient_loader)

# %%
Beispiel #5
0
def main_worker(gpu, ngpus_per_node, args):
    """One training worker: build the model, data pipeline and optimizer,
    optionally resume from a checkpoint, then run the epoch loop.

    :param gpu: GPU index assigned to this worker (currently ignored, see note)
    :param ngpus_per_node: number of GPUs available on this node
    :param args: parsed command-line argument namespace
    """
    global best_acc1
    # NOTE(review): forcing args.gpu to None makes every `args.gpu is not None`
    # branch below dead and leaves the `gpu` parameter unused except for rank
    # computation -- this looks like a deliberate CPU-only modification
    # (model.cuda() calls are commented out too); confirm before re-enabling.
    args.gpu = None

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            #model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            #model.cuda()
        else:
            model = torch.nn.DataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss() #.cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    # NOTE(review): prefer `args.data is not None` / `is None` over !=/== None.
    if args.data != None:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        pass
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        # you can add other transformations in this list
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    if args.data == None:
        # Synthetic path: 1e8 lazily generated fake images, so no dataset on
        # disk is required (benchmark/smoke-test mode).
        train_dataset = datasets.FakeData(size=100000000, image_size=(
            3, 224, 224), num_classes=200, transform=transform)
        val_loader = torch.utils.data.DataLoader(datasets.FakeData(
            size=1001, image_size=(3, 224, 224), num_classes=200, transform=transform))
        pass
    else:
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)

        pass

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(
            train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)
        # --so_one_shot: stop after a single training epoch (benchmark mode).
        if args.so_one_shot:
            sys.stdout.flush()
            return
def get_DataLoader(config):
    """Build a DataLoader for the dataset selected by `config`.

    The dataset root is chosen from config.platform (the machine the job runs
    on) and the dataset class from config.dataset.  A random subset of at most
    config.data_limit samples is drawn, and batches are assembled with either
    the seq2seq collate function or the default one.

    :param config: configuration namespace (platform, dataset, subset,
        batch_size, num_workers, data_limit, collate_fn, trim_size, ...)
    :return: torch.utils.data.DataLoader over the selected random subset
    :raises ValueError: for an unknown platform or dataset name
    """
    train_transforms = get_Transforms()

    # Map the execution platform to the dataset root directory.
    if config.platform == 0:
        root = '/Volumes/scratch/work/falconr1/datasets/mtg-jamendo-dataset-master'
    elif config.platform == 2:
        root = '/scratch/work/falconr1/datasets/mtg-jamendo-dataset-master'
    elif config.platform == 3:
        root = '/m/cs/work/falconr1/datasets/mtg-jamendo-dataset-master'
    else:
        # BUGFIX: `root` was left unbound for unrecognised platforms, causing a
        # confusing UnboundLocalError further down.
        raise ValueError('Unknown platform: {}'.format(config.platform))

    subset = config.subset
    split = 0
    mode = 'train'

    if config.dataset == 'JamendoSpecFolder':
        dataset = JamendoSpecFolder(root,
                                    subset,
                                    split,
                                    mode,
                                    spec_folder='data/processed/spec_npy',
                                    transform=train_transforms)
    elif config.dataset == 'JamendoSpecHDF5':
        dataset = JamendoSpecHDF5(root,
                                  subset,
                                  split,
                                  mode,
                                  train_transforms,
                                  hdf5_filename='data/processed/jamendo.hdf5')
    elif config.dataset == 'JamendoSpecLMDB':
        dataset = JamendoSpecLMDB(root,
                                  subset,
                                  split,
                                  mode,
                                  train_transforms,
                                  lmdb_path='data/processed/triton')
    elif config.dataset == 'JamendoSpecLMDBsubdir':
        dataset = JamendoSpecLMDBsubdir(root,
                                        subset,
                                        split,
                                        mode,
                                        train_transforms,
                                        lmdb_path='data/processed/chunks')
    elif config.dataset == 'fake':
        # Random spectrogram-shaped images for pipeline smoke tests.
        dataset = dset.FakeData(image_size=(1, 96, 1366),
                                transform=transforms.Compose([
                                    transforms.RandomCrop((96, 256), pad_if_needed=True, padding_mode='reflect'),
                                    transforms.ToTensor()
                                ]))
    elif config.dataset == 'SVHN':
        dataset = dset.SVHN(root='/m/cs/work/falconr1/datasets/SVHN',
                            transform=transforms.Compose([
                                    transforms.RandomCrop((96, 256), pad_if_needed=True, padding_mode='reflect'),
                                    transforms.ToTensor()
                                ]),
                            download=True)
    elif config.dataset == 'JamendoAudioFolder_torchaudio':
        # Audio -> mel spectrogram -> 96x256 crop pipeline via torchaudio.
        dataset = JamendoAudioFolder_torchaudio(root,
                                                subset,
                                                split,
                                                mode,
                                                transform=transforms.Compose([
                                                    tforms.MelSpectrogram(sr=44100,
                                                                          n_fft=512,
                                                                          ws=256,
                                                                          hop=256,
                                                                          f_min=20.0,
                                                                          f_max=8000,
                                                                          pad=0,
                                                                          n_mels=96),
                                                    transforms.ToPILImage(),
                                                    transforms.RandomCrop((96, 256), pad_if_needed=True, padding_mode='reflect'),
                                                    transforms.ToTensor(),
                                                ])
                                                )
    elif config.dataset == 'JamendoAudioFolder_audtorch':
        # Raw-audio variant, no transform (the original carried a large
        # commented-out spectrogram pipeline here; removed as dead code).
        dataset = JamendoAudioFolder_audtorch(root,
                                              subset,
                                              split,
                                              mode)
    elif config.dataset == 'JamendoAudioFolder_npy':
        dataset = JamendoAudioFolder_npy(root,
                                         subset,
                                         split,
                                         mode,
                                         trim_to_size=config.trim_size)
    elif config.dataset == 'JamendoAudioFolder_torch':
        dataset = JamendoAudioFolder_torch(root,
                                           subset,
                                           split,
                                           mode,
                                           transform=tforms2.RandomCrop(size=30*44100))
    else:
        # BUGFIX: `dataset` was left unbound for unrecognised dataset names.
        raise ValueError('Unknown dataset: {}'.format(config.dataset))

    # Draw a random subset without replacement; clamp the request so a
    # data_limit larger than the dataset no longer makes np.random.choice raise.
    n_samples = min(config.data_limit, len(dataset))
    subset_indices = np.random.choice(range(len(dataset)), n_samples, replace=False)

    print('------ Dataset length = {}, using {} samples.'.format(len(dataset), len(subset_indices)))

    if config.collate_fn == 'seq2seq':
        collate = Seq2Seq([-1,-1], batch_first=None, sort_sequences=False)
    else:
        collate = torch.utils.data.dataloader.default_collate

    # shuffle is intentionally omitted: SubsetRandomSampler already randomises
    # the iteration order over the chosen indices.
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=config.batch_size,
                                             num_workers=config.num_workers,
                                             pin_memory=True,
                                             sampler=torch.utils.data.sampler.SubsetRandomSampler(subset_indices),
                                             collate_fn=collate,
                                             drop_last=True,
                                             )

    return dataloader
Beispiel #7
0
def main():
    """Train an LSGAN-style generator/discriminator pair.

    Command-line options select the dataset and hyper-parameters.  Each
    epoch alternates a discriminator update (real batch vs. detached fake
    batch) with a generator update, both driven by an MSE loss
    (least-squares GAN).  A fixed-noise sample grid and model checkpoints
    are written to ``--outf`` after every epoch.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset',
                        required=True,
                        help='cifar10 | lsun | imagenet | folder | lfw | fake')
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=2)
    parser.add_argument('--batchSize',
                        type=int,
                        default=50,
                        help='input batch size')
    parser.add_argument(
        '--imageSize',
        type=int,
        default=64,
        help='the height / width of the input image to network')
    parser.add_argument('--nz',
                        type=int,
                        default=100,
                        help='size of the latent z vector')
    parser.add_argument('--nch_gen', type=int, default=512)
    parser.add_argument('--nch_dis', type=int, default=512)
    parser.add_argument('--nepoch',
                        type=int,
                        default=1000,
                        help='number of epochs to train for')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate, default=0.001')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.9,
                        help='beta1 for adam. default=0.9')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--ngpu',
                        type=int,
                        default=1,
                        help='number of GPUs to use')
    parser.add_argument('--gen',
                        default='',
                        help="path to gen (to continue training)")
    parser.add_argument('--dis',
                        default='',
                        help="path to dis (to continue training)")
    parser.add_argument('--outf',
                        default='./result',
                        help='folder to output images and model checkpoints')
    parser.add_argument('--manualSeed', type=int, help='manual seed')

    args = parser.parse_args()
    print(args)

    # exist_ok keeps the intended "already there" tolerance while still
    # surfacing genuine failures (e.g. permission errors).
    os.makedirs(args.outf, exist_ok=True)

    # Seed both RNGs so runs are reproducible when --manualSeed is given.
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", args.manualSeed)
    random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)

    cudnn.benchmark = True

    if torch.cuda.is_available() and not args.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    # All image pipelines normalize to [-1, +1] to match a tanh generator.
    if args.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=args.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(args.imageSize),
                                       transforms.CenterCrop(args.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    elif args.dataset == 'lsun':
        dataset = dset.LSUN(root=args.dataroot,
                            classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(args.imageSize),
                                transforms.CenterCrop(args.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5)),
                            ]))
    elif args.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=args.dataroot,
                               download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(args.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))  # [0, +1] -> [-1, +1]
    elif args.dataset == 'fake':
        dataset = dset.FakeData(image_size=(3, args.imageSize, args.imageSize),
                                transform=transforms.ToTensor())
    else:
        # previously an unknown name fell through to a NameError at `assert`
        raise ValueError('unknown dataset: {}'.format(args.dataset))

    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batchSize,
                                             shuffle=True,
                                             num_workers=int(args.workers))

    device = torch.device("cuda:0" if args.cuda else "cpu")
    nch_img = 3

    # custom weights initialization called on gen and dis
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
            # conv layers created with bias=False expose bias as None
            if m.bias is not None:
                m.bias.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)

    gen = Generator(args.ngpu, args.nz, args.nch_gen, nch_img).to(device)
    gen.apply(weights_init)
    if args.gen != '':
        gen.load_state_dict(torch.load(args.gen))

    dis = Discriminator(args.ngpu, args.nch_dis, nch_img).to(device)
    dis.apply(weights_init)
    if args.dis != '':
        dis.load_state_dict(torch.load(args.dis))

    # LSGAN objective: MSE between discriminator scores and targets.
    criterion = nn.MSELoss()

    # fixed noise vector so sample grids are comparable across epochs
    fixed_z = torch.randn(8 * 8, args.nz, 1, 1, device=device)
    # float targets: torch.full with an int fill yields a long tensor on
    # torch >= 1.5, which MSELoss rejects
    a_label = 0.0  # discriminator target for fake images
    b_label = 1.0  # discriminator target for real images
    c_label = 1.0  # generator target (fakes should score as real)

    # setup optimizer
    optim_dis = optim.Adam(dis.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, 0.999))
    optim_gen = optim.Adam(gen.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, 0.999))

    for epoch in range(args.nepoch):
        for itr, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            dis.zero_grad()
            real_img = data[0].to(device)
            batch_size = real_img.size(0)
            label = torch.full((batch_size, ), b_label, device=device)

            dis_real = dis(real_img)
            loss_dis_real = criterion(dis_real, label)
            loss_dis_real.backward()

            # train with fake; detach so gradients stop at the generator
            z = torch.randn(batch_size, args.nz, 1, 1, device=device)
            fake_img = gen(z)
            label.fill_(a_label)

            dis_fake1 = dis(fake_img.detach())
            loss_dis_fake = criterion(dis_fake1, label)
            loss_dis_fake.backward()

            loss_dis = loss_dis_real + loss_dis_fake
            optim_dis.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            gen.zero_grad()
            label.fill_(c_label)  # fake labels are real for generator cost
            dis_fake2 = dis(fake_img)
            loss_gen = criterion(dis_fake2, label)
            loss_gen.backward()
            optim_gen.step()

            if (itr + 1) % 100 == 0:
                print(
                    '[{}/{}][{}/{}] LossD:{:.4f} LossG:{:.4f} D(x):{:.4f} D(G(z)):{:.4f}/{:.4f}'
                    .format(epoch + 1, args.nepoch, itr + 1, len(dataloader),
                            loss_dis.item(), loss_gen.item(),
                            dis_real.mean().item(),
                            dis_fake1.mean().item(),
                            dis_fake2.mean().item()))
            # loop end iteration

        if epoch == 0:
            vutils.save_image(real_img,
                              '{}/real_samples.png'.format(args.outf),
                              normalize=True)

        fake_img = gen(fixed_z)
        vutils.save_image(fake_img.detach(),
                          '{}/fake_samples_epoch_{:04}.png'.format(
                              args.outf, epoch),
                          normalize=True)

        # do checkpointing
        torch.save(gen.state_dict(),
                   '{}/gen_epoch_{}.pth'.format(args.outf, epoch))
        torch.save(dis.state_dict(),
                   '{}/dis_epoch_{}.pth'.format(args.outf, epoch))
def getData(dset_name, batch_size, data_transform):
    """Build shuffled train/test DataLoaders for a named torchvision dataset.

    @param dset_name: one of CIFAR10, LSUN, FakeData, CocoCaptions, MNIST,
        CIFAR100, SVHN, Flickr8k, Cityscapes
    @param batch_size: minibatch size for both loaders
    @param data_transform: transform applied to every image
    @return: (train DataLoader, test DataLoader)
    @raise ValueError: if dset_name is not one of the supported names

    Datasets without torchvision download support (LSUN, CocoCaptions,
    Flickr8k, Cityscapes) must already be extracted under ``../data``.
    """
    dataPath = "../data"
    os.makedirs(dataPath, exist_ok=True)
    if dset_name == "CIFAR10":
        trainset = dset.CIFAR10(dataPath,
                                train=True,
                                download=True,
                                transform=data_transform)
        testset = dset.CIFAR10(dataPath,
                               train=False,
                               download=True,
                               transform=data_transform)
    elif dset_name == "LSUN":
        # LSUN selects splits via `classes` and has no download/train args.
        trainset = dset.LSUN(dataPath,
                             classes='train',
                             transform=data_transform)
        testset = dset.LSUN(dataPath,
                            classes='test',
                            transform=data_transform)
    elif dset_name == "FakeData":
        # FakeData is generated on the fly: it takes no root/train/download.
        trainset = dset.FakeData(transform=data_transform)
        testset = dset.FakeData(transform=data_transform)
    elif dset_name == "CocoCaptions":
        # CocoCaptions needs an image root plus a captions annotation file.
        cocoRoot = os.path.join(dataPath, "coco")
        trainset = dset.CocoCaptions(
            root=os.path.join(cocoRoot, "train2017"),
            annFile=os.path.join(cocoRoot, "annotations",
                                 "captions_train2017.json"),
            transform=data_transform)
        testset = dset.CocoCaptions(
            root=os.path.join(cocoRoot, "val2017"),
            annFile=os.path.join(cocoRoot, "annotations",
                                 "captions_val2017.json"),
            transform=data_transform)
    elif dset_name == "MNIST":
        trainset = dset.MNIST(dataPath,
                              train=True,
                              download=True,
                              transform=data_transform)
        testset = dset.MNIST(dataPath,
                             train=False,
                             download=True,
                             transform=data_transform)
    elif dset_name == "CIFAR100":
        trainset = dset.CIFAR100(dataPath,
                                 train=True,
                                 download=True,
                                 transform=data_transform)
        testset = dset.CIFAR100(dataPath,
                                train=False,
                                download=True,
                                transform=data_transform)
    elif dset_name == "SVHN":
        # SVHN uses `split`, not `train`.
        trainset = dset.SVHN(dataPath,
                             split='train',
                             download=True,
                             transform=data_transform)
        testset = dset.SVHN(dataPath,
                            split='test',
                            download=True,
                            transform=data_transform)
    elif dset_name == "Flickr8k":
        # Flickr8k needs the image directory and the caption file on disk;
        # it ships no official train/test split, so both loaders share it.
        flickrRoot = os.path.join(dataPath, "flickr8k")
        trainset = dset.Flickr8k(root=os.path.join(flickrRoot, "images"),
                                 ann_file=os.path.join(flickrRoot,
                                                       "captions.txt"),
                                 transform=data_transform)
        testset = trainset
    elif dset_name == "Cityscapes":
        # Cityscapes has no download support; the labelled "val" split
        # stands in for a test set (official test labels are withheld).
        trainset = dset.Cityscapes(dataPath,
                                   split='train',
                                   transform=data_transform)
        testset = dset.Cityscapes(dataPath,
                                  split='val',
                                  transform=data_transform)
    else:
        # previously an unknown name fell through to a NameError on return
        raise ValueError("unsupported dataset: {}".format(dset_name))
    return torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True),\
           torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)
Beispiel #9
0
    random_index = np.load(data_path + '/random_index.npy')
    train_size = 1000
    train_Sampler = SubsetRandomSampler(random_index[range(train_size)])
    test_Sampler = SubsetRandomSampler(random_index[range(
        train_size, len(test_data))])
    Shuffle = False
elif args.dataset == 'fakedata':
    nh = 24
    nw = 24
    nc = 3
    num_class = 10
    end_epoch = 50
    train_size = 1000
    test_size = 1000
    train_data = datasets.FakeData(size=train_size + test_size,
                                   image_size=(nc, nh, nw),
                                   num_classes=num_class,
                                   transform=train_transform)
    test_data = train_data
    train_Sampler = SubsetRandomSampler(range(train_size))
    test_Sampler = SubsetRandomSampler(range(train_size, len(test_data)))
    Shuffle = False
else:
    print('specify dataset')
    exit()
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=args.batch_size,
                                           sampler=train_Sampler,
                                           shuffle=Shuffle,
                                           **kwargs)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=args.test_batch_size,
    def __init__(self,
                 id_dataset,
                 ood_dataset,
                 data_dir=None,
                 individual_normalization=False,
                 training_sample_size=None,
                 test_sample_size=None):
        """
        @param id_dataset: in-domain-dataset name (currently only 'CIFAR10')
        @param ood_dataset: out-of-domain-dataset name ('SVHN', 'TIM', 'Random')
        @param data_dir: directory where datasets should be stored; defaults
            to ``$CC_HOME/resources/data``, falling back to the current
            directory when ``CC_HOME`` is unset
        @param individual_normalization: whether or not to normalize ood individually
        @param training_sample_size: maximum amount of training samples to provide in datasets
        @param test_sample_size: maximum amount of test samples to provide in datasets
        """
        # Resolve the default lazily: the former f-string default was
        # evaluated once at import time and yielded "None/resources/data"
        # whenever CC_HOME was not set in the environment.
        if data_dir is None:
            data_dir = f"{os.getenv('CC_HOME', '.')}/resources/data"
        os.makedirs(data_dir, exist_ok=True)

        IMAGE_SIZE = 32

        # In-domain pipeline: resize/crop to 32x32, normalize with the
        # id-dataset statistics.
        id_transform = transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(NORMALIZATION_PARAMETERS[id_dataset]['mean'],
                                 NORMALIZATION_PARAMETERS[id_dataset]['std'])
        ])

        # OOD data is either normalized with its own statistics or reuses
        # the in-domain normalization.
        if individual_normalization:
            ood_transform = transforms.Compose([
                transforms.Resize(IMAGE_SIZE),
                transforms.CenterCrop((IMAGE_SIZE, IMAGE_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(
                    NORMALIZATION_PARAMETERS[ood_dataset]['mean'],
                    NORMALIZATION_PARAMETERS[ood_dataset]['std'])
            ])
        else:
            ood_transform = id_transform

        if id_dataset == "CIFAR10":
            id_train_dataset = datasets.CIFAR10(root=f"{data_dir}/CIFAR10",
                                                download=True,
                                                transform=id_transform,
                                                target_transform=None,
                                                train=True)
            id_test_dataset = datasets.CIFAR10(root=f"{data_dir}/CIFAR10",
                                               download=True,
                                               transform=id_transform,
                                               target_transform=None,
                                               train=False)
        else:
            sys.exit(
                f"{id_dataset} is not a valid name for an id_dataset. Options are: 'CIFAR10'"
            )

        if ood_dataset == "SVHN":
            ood_train_dataset = datasets.SVHN(root=f"{data_dir}/SVHN",
                                              download=True,
                                              transform=ood_transform,
                                              target_transform=None,
                                              split='train')
            ood_test_dataset = datasets.SVHN(root=f"{data_dir}/SVHN",
                                             download=True,
                                             transform=ood_transform,
                                             target_transform=None,
                                             split='test')
        elif ood_dataset == "TIM":
            ood_train_dataset = TinyImageNet(root=f"{data_dir}/TIM",
                                             transform=ood_transform,
                                             target_transform=None,
                                             split="train")
            ood_test_dataset = TinyImageNet(root=f"{data_dir}/TIM",
                                            transform=ood_transform,
                                            target_transform=None,
                                            split="test")
        elif ood_dataset == "Random":
            # Synthetic noise images, sized to mirror the id splits.
            ood_train_dataset = datasets.FakeData(size=len(id_train_dataset),
                                                  image_size=(3, IMAGE_SIZE,
                                                              IMAGE_SIZE),
                                                  num_classes=10,
                                                  transform=ood_transform,
                                                  target_transform=None)
            ood_test_dataset = datasets.FakeData(size=len(id_test_dataset),
                                                 image_size=(3, IMAGE_SIZE,
                                                             IMAGE_SIZE),
                                                 num_classes=10,
                                                 transform=ood_transform,
                                                 target_transform=None)
        else:
            sys.exit(
                f"{ood_dataset} is not a valid name for an ood_dataset. Options are: 'SVHN', 'TIM', 'Random'"
            )

        # Cap both splits at the smaller of the paired datasets unless the
        # caller requested explicit sample counts.
        self.training_sample_size = training_sample_size or min(
            len(id_train_dataset), len(ood_train_dataset))
        self.test_sample_size = test_sample_size or min(
            len(id_test_dataset), len(ood_test_dataset))

        self.id_train_dataset = random_subset(id_train_dataset,
                                              self.training_sample_size)
        self.ood_train_dataset = random_subset(ood_train_dataset,
                                               self.training_sample_size)
        self.id_test_dataset = random_subset(id_test_dataset,
                                             self.test_sample_size)
        self.ood_test_dataset = random_subset(ood_test_dataset,
                                              self.test_sample_size)
Beispiel #11
0
def get_loader(_dataset, dataroot, batch_size, num_workers, image_size):
    """Return a shuffled DataLoader for the requested dataset.

    All pipelines normalize to [-1, +1] (except 'fake', which stays in
    [0, 1]).  Raises ValueError for unknown dataset names.
    """
    # folder dataset
    if _dataset in ['imagenet', 'folder', 'lfw']:
        dataroot += '/resized_celebA'
        dataset = dset.ImageFolder(root=dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(image_size),
                                       transforms.CenterCrop(image_size),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    elif _dataset == 'lsun':
        dataset = dset.LSUN(db_path=dataroot,
                            classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(image_size),
                                transforms.CenterCrop(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5)),
                            ]))

    elif _dataset == 'cifar10':
        dataset = dset.CIFAR10(root=dataroot,
                               download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))

    elif _dataset == 'mnist':
        # MNIST is single-channel; Normalize needs 1-element mean/std
        # (3-channel tuples raise at transform time).
        dataset = dset.MNIST(root=dataroot,
                             download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(image_size),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5,), (0.5,)),
                             ]))

    elif _dataset == 'fashion_mnist':
        # FashionMNIST is single-channel as well.
        dataset = dset.FashionMNIST(root=dataroot,
                                    download=True,
                                    transform=transforms.Compose([
                                        transforms.Resize(image_size),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (0.5,)),
                                    ]))

    elif _dataset == 'fake':
        dataset = dset.FakeData(image_size=(3, image_size, image_size),
                                transform=transforms.ToTensor())

    else:
        # previously an unknown name fell through to a NameError at `assert`
        raise ValueError('unknown dataset: {}'.format(_dataset))

    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=num_workers)

    return dataloader
Beispiel #12
0
def dataloader(args):
    """Return the dataloader for selected dataset.
    Now have:
    - MNIST
    - FashionMNIST
    - CIFAR10
    - CIFAR100
    - SVHN
    - CelebA (https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZ
      zg?resourcekey=0-rJlzl934LzC-Xp28GeIBzQ)
    - STL10
    - LSUN
    - Fake data

    Parameters
    ----------
    args : argparse.Namespace
        Must provide `img_resize`, `img_size`, `dataset`, `batch_size`
        (and `lsun_classes` for LSUN).  `img_channels` / `img_size` are
        written back onto it as a side effect.

    Returns
    -------
    tr_set:
        Dataloader for training set.
    te_set:
        Dataloader for test set.

    Raises
    ------
    ValueError
        If args.dataset is not one of the supported names.
    """

    # resize images or not; all pipelines normalize to [-1, +1]
    if args.img_resize:
        transform3c = transforms.Compose([
            transforms.Resize(args.img_size),
            transforms.CenterCrop(args.img_size),  # if H != W
            transforms.ToTensor(),
            transforms.Normalize((.5, .5, .5), (.5, .5, .5))])
        # single-channel mean/std must be 1-element sequences: (.5) is
        # just the float 0.5, not a tuple
        transform1c = transforms.Compose([
            transforms.Resize(args.img_size),
            transforms.CenterCrop(args.img_size),  # if H != W
            transforms.ToTensor(), transforms.Normalize((.5,), (.5,))])
    else:
        transform3c = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize((.5, .5, .5),
                                                              (.5, .5, .5))])
        transform1c = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize((.5,), (.5,))])
    # create dataloaders
    datapath, dataset_name, batch_size = 'data', args.dataset, args.batch_size
    if dataset_name == 'mnist':  # handwritten digits, (1, 28, 28)
        tr_set = thv.datasets.MNIST(datapath, train=True, download=True,
                                    transform=transform1c)
        te_set = thv.datasets.MNIST(datapath, train=False, download=True,
                                    transform=transform1c)
    elif dataset_name == 'fashion-mnist':  # fashion (Zalando), (1, 28, 28)
        tr_set = thv.datasets.FashionMNIST(datapath, train=True, download=True,
                                           transform=transform1c)
        te_set = thv.datasets.FashionMNIST(datapath, train=False,
                                           download=True,
                                           transform=transform1c)
    elif dataset_name == 'cifar10':  # 10-class image recognition, (3, 32 32)
        tr_set = thv.datasets.CIFAR10(datapath, train=True, download=True,
                                      transform=transform3c)
        te_set = thv.datasets.CIFAR10(datapath, train=False, download=True,
                                      transform=transform3c)
    elif dataset_name == 'cifar100':  # 100-class image recognition, (3, 32 32)
        tr_set = thv.datasets.CIFAR100(datapath, train=True, download=True,
                                       transform=transform3c)
        te_set = thv.datasets.CIFAR100(datapath, train=False, download=True,
                                       transform=transform3c)
    elif dataset_name == 'svhn':  # digit recognition, (3, 32, 32)
        tr_set = thv.datasets.SVHN(os.path.join(datapath, 'SVHN'),
                                   split='train', download=True,
                                   transform=transform3c)
        te_set = thv.datasets.SVHN(os.path.join(datapath, 'SVHN'),
                                   split='test', download=True,
                                   transform=transform3c)
    elif dataset_name == 'celeba':  # celebrity face, (3, 218, 178)
        celeba = dset.ImageFolder(root='data/celeba', transform=transform3c)
        tr_len = int(len(celeba) * 0.8)
        te_len = len(celeba) - tr_len
        tr_set, te_set = torch.utils.data.random_split(celeba,
                                                       [tr_len, te_len])
    elif dataset_name == 'stl10':  # 10-class image recognition, (3, 96, 96)
        tr_set = thv.datasets.STL10(datapath, split='train', download=True,
                                    transform=transform3c)
        te_set = thv.datasets.STL10(datapath, split='test', download=True,
                                    transform=transform3c)
    elif dataset_name == 'lsun':
        tr_classes = [c + '_train' for c in args.lsun_classes.split(',')]
        te_classes = [c + '_test' for c in args.lsun_classes.split(',')]
        # without a transform the default collate would receive PIL
        # images and fail at iteration time
        tr_set = dset.LSUN(root='data/lsun', classes=tr_classes,
                           transform=transform3c)
        te_set = dset.LSUN(root='data/lsun', classes=te_classes,
                           transform=transform3c)
    elif dataset_name == 'fake':
        tr_set = dset.FakeData(
                               image_size=(3, args.img_size, args.img_size),
                               transform=transforms.ToTensor())
        te_set = dset.FakeData(size=1024,
                               image_size=(3, args.img_size, args.img_size),
                               transform=transforms.ToTensor())
    else:
        # previously an unknown name fell through to a NameError below
        raise ValueError('unknown dataset: {}'.format(dataset_name))
    tr_set = DataLoader(tr_set, batch_size=batch_size, shuffle=True,
                        drop_last=True)
    te_set = DataLoader(te_set, batch_size=batch_size, shuffle=True,
                        drop_last=True)
    args.img_channels = 1 if dataset_name in ['mnist', 'fashion-mnist'] else 3
    if not args.img_resize:  # use original size
        if dataset_name in ['mnist', 'fashion-mnist']:
            args.img_size = 28
        elif dataset_name in ['cifar10', 'cifar100', 'svhn']:
            args.img_size = 32
        elif dataset_name == 'celeba':
            args.img_size = [218, 178]
        elif dataset_name == 'stl10':
            args.img_size = 96
    return tr_set, te_set
Beispiel #13
0
def main():
    """Train a torchvision model on real ImageNet-style data or on synthetic
    (FakeData) images, optionally with distributed data parallelism.

    Reads configuration from the module-level ``parser`` and updates the
    module globals ``args`` and ``best_prec1``. Prints per-epoch throughput
    (samples/sec) as a custom metric.
    """
    ### Synthetic
    global args, best_prec1, MODELS
    ###
    args = parser.parse_args()

    args.distributed = args.world_size > 1

    args.cuda = torch.cuda.is_available()

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if not args.distributed:
        # AlexNet/VGG: data-parallelize only the convolutional features;
        # the large fully-connected layers stay on a single GPU
        # (standard PyTorch ImageNet example recipe).
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    ### Synthetic
    # BUG FIX: the original used `args.dataset is "real"`, which tests object
    # identity (and only works by accident of string interning); string
    # comparison must use `==`. Fixed here and in the three checks below.
    if args.dataset == "real":
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
    else:
        model_shape, num_classes, model_fn = MODELS[args.arch]
        dataset_shape = (args.samples, ) + model_shape
        #train_dataset = dataset.SyntheticDataset(dataset_shape, num_classes)
        # NOTE(review): `num_classes` from MODELS is ignored here; FakeData is
        # hard-coded to 1000 classes -- confirm this is intentional.
        train_dataset = datasets.FakeData(args.samples,
                                          num_classes=1000,
                                          transform=transforms.Compose(
                                              [transforms.ToTensor()]))
    ###

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)
    ### Synthetic - no validation needed
    val_loader = None  # only built for the real dataset
    if args.dataset == "real":
        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
    ###
    if args.evaluate:
        # Guard: evaluation needs the real validation set. The original code
        # raised NameError here when --evaluate was combined with a synthetic
        # dataset, because val_loader was never bound.
        if val_loader is None:
            print("=> evaluation requires the real dataset; skipping")
            return
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        ### Custom metric
        epoch_start = time.time()
        if args.distributed:
            # Re-seed the sampler so each epoch shuffles shards differently.
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set (real data only; synthetic runs skip it)
        if args.dataset == "real":
            prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        ### Synthetic
        if args.dataset == "real":
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
        else:
            is_best = False
        ###
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)

        ### Custom metric
        epoch_time = time.time() - epoch_start
        print(
            'Epoch: [{}] \t Speed: {image_persec:.3f} samples/sec \t Time cost={time:.3f}'
            .format(epoch,
                    image_persec=(args.samples / epoch_time),
                    time=epoch_time))