예제 #1
0
def get_data_loader(args):
    """Build shuffled train/test DataLoaders for the dataset selected by ``args``.

    Every branch resizes images with ``transforms.Resize(32)`` and converts
    them to tensors; every branch except ``'stl10'`` also applies
    ``transforms.Normalize``.

    Args:
        args: Options object. Always read: ``dataset`` (one of ``'mnist'``,
            ``'fashion-mnist'``, ``'cifar'``, ``'stl10'``), ``dataroot``,
            ``download`` and ``batch_size``. The ``'mnist'`` branch also
            reads ``few_shot_class``, ``test_emnist`` and ``max_test_sample``.

    Returns:
        tuple: ``(train_dataloader, test_dataloader)``, both built with
        ``batch_size=args.batch_size`` and ``shuffle=True``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    # ``Resize`` is used throughout: ``Scale`` was only a deprecated alias of
    # it and is gone from newer torchvision releases.
    if args.dataset == 'mnist':
        trans = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        # This MNIST class takes few-shot / EMNIST options, so it is
        # presumably a project-local variant of torchvision's MNIST.
        mnist_options = dict(
            download=args.download,
            transform=trans,
            few_shot_class=args.few_shot_class,
            test_emnist=args.test_emnist,
            max_test_sample=args.max_test_sample,
        )
        train_dataset = MNIST(root=args.dataroot, train=True, **mnist_options)
        test_dataset = MNIST(root=args.dataroot, train=False, **mnist_options)

    elif args.dataset == 'fashion-mnist':
        # One-element mean/std: Fashion-MNIST images are grayscale (assumed to
        # hold for this FashionMNIST class too), and three-element statistics
        # cannot be applied in place to a 1-channel tensor.
        trans = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = FashionMNIST(root=args.dataroot, train=True,
                                     download=args.download, transform=trans)
        test_dataset = FashionMNIST(root=args.dataroot, train=False,
                                    download=args.download, transform=trans)

    elif args.dataset == 'cifar':
        trans = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        train_dataset = dset.CIFAR10(root=args.dataroot, train=True,
                                     download=args.download, transform=trans)
        test_dataset = dset.CIFAR10(root=args.dataroot, train=False,
                                    download=args.download, transform=trans)

    elif args.dataset == 'stl10':
        # NOTE(review): unlike the other branches there is no Normalize here,
        # so STL-10 tensors keep ToTensor's value range - confirm intended.
        trans = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
        ])
        # torchvision's STL10 selects its subset with ``split``; it has no
        # ``train`` keyword.
        train_dataset = dset.STL10(root=args.dataroot, split='train',
                                   download=args.download, transform=trans)
        test_dataset = dset.STL10(root=args.dataroot, split='test',
                                  download=args.download, transform=trans)

    else:
        # Without this, an unknown name surfaced as an UnboundLocalError below.
        raise ValueError('Unsupported dataset: %r' % (args.dataset,))

    # Sanity check: fails if a dataset reports a length of zero.
    assert train_dataset
    assert test_dataset

    train_dataloader = data_utils.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    test_dataloader = data_utils.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True)

    return train_dataloader, test_dataloader
예제 #2
0
def get_data_loader(args):
    """Build shuffled train/test DataLoaders for the dataset selected by ``args``.

    Supported ``args.dataset`` values: ``'mnist'``, ``'fashion-mnist'``,
    ``'cifar'``, ``'celeba'``, ``'stl10'``, ``'lsun'``, ``'imagenet'`` and
    ``'custom'``. Every branch resizes with ``Resize(32)``, converts to a
    tensor and normalizes each channel with mean 0.5 / std 0.5.

    Args:
        args: Options object. Reads ``dataset``, ``dataroot`` and
            ``batch_size``. ``download`` is read by the mnist,
            fashion-mnist, cifar, celeba and stl10 branches.

    Returns:
        tuple: ``(train_dataloader, test_dataloader)``, both built with
        ``batch_size=args.batch_size`` and ``shuffle=True``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    def _make_transform(num_channels):
        # Shared recipe: resize, tensorize, then (x - 0.5) / 0.5 per channel
        # (maps ToTensor's [0, 1] output for 8-bit images onto [-1, 1]).
        stats = (0.5,) * num_channels
        return transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize(stats, stats),
        ])

    if args.dataset == 'mnist':
        trans = _make_transform(1)
        train_dataset = MNIST(root=args.dataroot,
                              train=True,
                              download=args.download,
                              transform=trans)
        test_dataset = MNIST(root=args.dataroot,
                             train=False,
                             download=args.download,
                             transform=trans)

    elif args.dataset == 'fashion-mnist':
        trans = _make_transform(1)
        train_dataset = FashionMNIST(root=args.dataroot,
                                     train=True,
                                     download=args.download,
                                     transform=trans)
        test_dataset = FashionMNIST(root=args.dataroot,
                                    train=False,
                                    download=args.download,
                                    transform=trans)

    elif args.dataset == 'cifar':
        trans = _make_transform(3)
        train_dataset = dset.CIFAR10(root=args.dataroot,
                                     train=True,
                                     download=args.download,
                                     transform=trans)
        test_dataset = dset.CIFAR10(root=args.dataroot,
                                    train=False,
                                    download=args.download,
                                    transform=trans)

    elif args.dataset == 'celeba':
        trans = _make_transform(3)
        train_dataset = dset.CelebA(root=args.dataroot,
                                    split='train',
                                    download=args.download,
                                    transform=trans)
        test_dataset = dset.CelebA(root=args.dataroot,
                                   split='test',
                                   download=args.download,
                                   transform=trans)

    elif args.dataset == 'stl10':
        trans = _make_transform(3)
        # torchvision's STL10 selects its subset with ``split``; it has no
        # ``train`` keyword.
        train_dataset = dset.STL10(root=args.dataroot,
                                   split='train',
                                   download=args.download,
                                   transform=trans)
        test_dataset = dset.STL10(root=args.dataroot,
                                  split='test',
                                  download=args.download,
                                  transform=trans)

    elif args.dataset == 'lsun':
        trans = _make_transform(3)
        train_dataset = dset.LSUN(root=args.dataroot,
                                  classes=['bedroom_train'],
                                  transform=trans)
        test_dataset = dset.LSUN(root=args.dataroot,
                                 classes=['bedroom_val'],
                                 transform=trans)

    elif args.dataset == 'imagenet':
        trans = _make_transform(3)
        # Keyword was previously misspelled ``transfrom``, which ImageFolder
        # rejects with a TypeError.
        # NOTE(review): train and test both read the same ``args.dataroot``
        # folder - confirm that is intended.
        train_dataset = dset.ImageFolder(root=args.dataroot, transform=trans)
        test_dataset = dset.ImageFolder(root=args.dataroot, transform=trans)

    elif args.dataset == 'custom':
        trans = _make_transform(3)
        train_dataset = PreProcessDataset(args.dataroot,
                                          train=True,
                                          transform=trans)
        test_dataset = PreProcessDataset(args.dataroot,
                                         train=False,
                                         transform=trans)

    else:
        # Without this, an unknown name surfaced as an UnboundLocalError below.
        raise ValueError('Unsupported dataset: %r' % (args.dataset,))

    # Sanity check: fails if a dataset reports a length of zero.
    assert train_dataset
    assert test_dataset

    train_dataloader = data_utils.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True)
    test_dataloader = data_utils.DataLoader(test_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=True)

    return train_dataloader, test_dataloader
예제 #3
0
def get_data_loader(args):
    """Build shuffled train/test DataLoaders for the dataset selected by ``args``.

    MNIST and Fashion-MNIST are expanded to three channels with
    ``Grayscale(3)`` so that they get the same 3-channel tensors as CIFAR-10.
    All branches resize with ``Resize(32)``.

    Args:
        args: Options object. Reads ``dataset`` (one of ``'mnist'``,
            ``'fashion-mnist'``, ``'cifar'``, ``'stl10'``), ``dataroot``,
            ``download`` and ``batch_size``.

    Returns:
        tuple: ``(train_dataloader, test_dataloader)``. Both use
        ``batch_size=args.batch_size``, ``shuffle=True``, 10 worker
        processes, pinned memory and ``drop_last=True``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    if args.dataset == 'mnist':
        trans = transforms.Compose([
            transforms.Grayscale(3),
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        train_dataset = MNIST(root=args.dataroot,
                              train=True,
                              download=args.download,
                              transform=trans)
        test_dataset = MNIST(root=args.dataroot,
                             train=False,
                             download=args.download,
                             transform=trans)

    elif args.dataset == 'fashion-mnist':
        trans = transforms.Compose([
            transforms.Grayscale(3),
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        train_dataset = FashionMNIST(root=args.dataroot,
                                     train=True,
                                     download=args.download,
                                     transform=trans)
        test_dataset = FashionMNIST(root=args.dataroot,
                                    train=False,
                                    download=args.download,
                                    transform=trans)

    elif args.dataset == 'cifar':
        trans = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        train_dataset = dset.CIFAR10(root=args.dataroot,
                                     train=True,
                                     download=args.download,
                                     transform=trans)
        test_dataset = dset.CIFAR10(root=args.dataroot,
                                    train=False,
                                    download=args.download,
                                    transform=trans)

    elif args.dataset == 'stl10':
        # NOTE(review): unlike the other branches there is no Normalize here,
        # so STL-10 tensors keep ToTensor's value range - confirm intended.
        trans = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
        ])
        # torchvision's STL10 selects its subset with ``split``; it has no
        # ``train`` keyword.
        train_dataset = dset.STL10(root=args.dataroot,
                                   split='train',
                                   download=args.download,
                                   transform=trans)
        test_dataset = dset.STL10(root=args.dataroot,
                                  split='test',
                                  download=args.download,
                                  transform=trans)

    else:
        # Without this, an unknown name surfaced as an UnboundLocalError below.
        raise ValueError('Unsupported dataset: %r' % (args.dataset,))

    # Sanity check: fails if a dataset reports a length of zero.
    assert train_dataset
    assert test_dataset

    # Both loaders share one configuration.
    # NOTE(review): drop_last=True also discards the final partial *test*
    # batch, so evaluation may skip up to batch_size - 1 samples.
    loader_options = dict(
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=10,
        pin_memory=True,
        drop_last=True,
    )
    train_dataloader = data_utils.DataLoader(train_dataset, **loader_options)
    test_dataloader = data_utils.DataLoader(test_dataset, **loader_options)

    return train_dataloader, test_dataloader
예제 #4
0
파일: aae.py 프로젝트: ygoshu/PyTorch-GAN
# Preprocessing shared by every MNIST/EMNIST loader below.
# - ``Resize`` replaces the deprecated ``Scale`` alias, which newer
#   torchvision releases removed.
# - MNIST digits are single-channel, so mean/std have one element.
#   Three-element statistics cannot be applied in place to a 1x28x28 tensor.
trans = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Use a loader worker and pinned host memory only when running on CUDA.
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

train_dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('./data/train_mnist', train=True, download=True,
                   transform=trans),
    batch_size=opt.batch_size,
    shuffle=True,
    **kwargs)

# This MNIST (not ``datasets.MNIST``) accepts few-shot / EMNIST options, so
# it is presumably a project-local dataset class.
test_emnist_dataset = MNIST(root='./data/test_emnist',
                            train=False,
                            download=True,
                            transform=trans,
                            few_shot_class=5,
                            test_emnist=True)
# Now gets the same CUDA-dependent loader options as the other two loaders.
test_emnist_loader = torch.utils.data.DataLoader(
    test_emnist_dataset,
    batch_size=opt.test_batch_size,
    shuffle=True,
    **kwargs)

test_mnist_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data/test_mnist', train=False, download=True,
                   transform=trans),
    batch_size=opt.test_batch_size,
    shuffle=True,
    **kwargs)

# Optimizers
optimizer_G = torch.optim.Adam(itertools.chain(encoder.parameters(),
                                               decoder.parameters()),
                               lr=opt.lr,