Example n. 1
0
def fetch_dataloader(args, train=True, download=True, mini_size=128):
    """Build a DataLoader over the MNIST dataset.

    Args:
        args: namespace with `data_dir`, `mini_data`, `device` and
            `batch_size` attributes.
        train (bool): load the training split if True, else the test split.
        download (bool): download MNIST to `args.data_dir` if missing.
        mini_size (int): number of examples kept when `args.mini_data` is set.

    Returns:
        torch.utils.data.DataLoader over the (possibly truncated) dataset.
    """
    transforms = T.Compose([T.ToTensor()])
    dataset = MNIST(root=args.data_dir,
                    train=train,
                    download=download,
                    transform=transforms)

    # Optionally truncate to a small subset for quick debugging runs.
    if args.mini_data:
        if train:
            dataset.train_data = dataset.train_data[:mini_size]
            dataset.train_labels = dataset.train_labels[:mini_size]
        else:
            dataset.test_data = dataset.test_data[:mini_size]
            dataset.test_labels = dataset.test_labels[:mini_size]

    # Bug fix: `is 'cuda'` compared object identity, which only worked by
    # accident of CPython string interning. Use equality instead.
    kwargs = {
        'num_workers': 1,
        'pin_memory': True
    } if args.device.type == 'cuda' else {}

    # Shuffle only during training; drop the last partial batch.
    dl = DataLoader(dataset,
                    batch_size=args.batch_size,
                    shuffle=train,
                    drop_last=True,
                    **kwargs)

    return dl
def fetch_dataloader(params, train=True, mini_size=128):
    """Build a DataLoader over MNIST, with optional debugging subsets.

    Args:
        params: object with `data_dir`, `device`, `batch_size` attributes and
            a `dict` mapping that may contain `mini_data` / `mini_ones` flags.
        train (bool): load the training split if True, else the test split.
        mini_size (int): number of examples kept when a mini flag is set.

    Returns:
        torch.utils.data.DataLoader over the (possibly truncated) dataset.
    """
    transforms = T.Compose([T.ToTensor()])
    dataset = MNIST(root=params.data_dir,
                    train=train,
                    download=True,
                    transform=transforms)

    # Keep only the first `mini_size` examples for quick debugging.
    if params.dict.get('mini_data'):
        if train:
            dataset.train_data = dataset.train_data[:mini_size]
            dataset.train_labels = dataset.train_labels[:mini_size]
        else:
            dataset.test_data = dataset.test_data[:mini_size]
            dataset.test_labels = dataset.test_labels[:mini_size]

    # Keep only examples of the digit "1", drawn from the first 2000 items.
    if params.dict.get('mini_ones'):
        if train:
            labels = dataset.train_labels[:2000]
            mask = labels == 1
            dataset.train_labels = labels[mask][:mini_size]
            dataset.train_data = dataset.train_data[:2000][mask][:mini_size]
        else:
            labels = dataset.test_labels[:2000]
            mask = labels == 1
            dataset.test_labels = labels[mask][:mini_size]
            dataset.test_data = dataset.test_data[:2000][mask][:mini_size]

    # Bug fix: `is 'cuda'` compared object identity, which only worked by
    # accident of CPython string interning. Use equality instead.
    kwargs = {
        'num_workers': 1,
        'pin_memory': True
    } if torch.cuda.is_available() and params.device.type == 'cuda' else {}

    # NOTE(review): shuffling is unconditional here (even for the test split);
    # preserved as-is, but confirm this is intended for evaluation runs.
    return DataLoader(dataset,
                      batch_size=params.batch_size,
                      shuffle=True,
                      drop_last=True,
                      **kwargs)
Example n. 3
0
def get_dataloaders(data='mnist',
                    train_bs=128,
                    test_bs=500,
                    root='./data',
                    ohe_labels=False,
                    train_fraction=1.):
    """Build train/test DataLoaders for MNIST (zeros only) or not-MNIST.

    Args:
        data (str): dataset name, 'mnist' or 'not-mnist'.
        train_bs (int): training batch size.
        test_bs (int): test batch size.
        root (str): dataset root directory.
        ohe_labels (bool): one-hot encode labels as float32 if True.
        train_fraction (float): stratified fraction of the training set kept.

    Returns:
        (trainloader, testloader) tuple of DataLoaders.

    Raises:
        NotImplementedError: for any unrecognized `data` value.
    """
    to_tensor = transforms.ToTensor()
    if data == 'mnist':
        trainset = MNIST(root, train=True, download=True, transform=to_tensor)
        if train_fraction < 1.:
            # Fix: the locals previously shadowed the `data` parameter; use
            # distinct names so the argument keeps a single meaning throughout.
            x_kept, _, y_kept, _ = train_test_split(
                trainset.train_data.numpy(),
                trainset.train_labels.numpy(),
                stratify=trainset.train_labels.numpy(),
                train_size=train_fraction)
            trainset.train_data = torch.ByteTensor(x_kept)
            trainset.train_labels = torch.LongTensor(y_kept)

        # Restrict the training set to the digit 0 only.
        idx = torch.LongTensor(np.where(trainset.train_labels.numpy() == 0)[0])
        trainset.train_data = trainset.train_data[idx]
        trainset.train_labels = trainset.train_labels[idx]

        if ohe_labels:
            x = trainset.train_labels.numpy()
            ohe = np.zeros((len(x), 10))
            ohe[np.arange(ohe.shape[0]), x] = 1
            trainset.train_labels = torch.from_numpy(ohe.astype(np.float32))

        testset = MNIST(root, train=False, download=True, transform=to_tensor)
        if ohe_labels:
            x = testset.test_labels.numpy()
            ohe = np.zeros((len(x), 10))
            ohe[np.arange(ohe.shape[0]), x] = 1
            testset.test_labels = torch.from_numpy(ohe.astype(np.float32))
    elif data == 'not-mnist':
        # NOTE(review): both loaders use train=False here — presumably the
        # not-MNIST copy only ships a test split; confirm against the data dir.
        trainset = MNIST(root=os.path.join(root, 'not-mnist'),
                         train=False,
                         download=True,
                         transform=to_tensor)
        testset = MNIST(root=os.path.join(root, 'not-mnist'),
                        train=False,
                        download=True,
                        transform=to_tensor)
    else:
        raise NotImplementedError('unknown dataset: {!r}'.format(data))

    # No shuffling in either loader (matches the original behavior).
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_bs)
    testloader = torch.utils.data.DataLoader(testset, batch_size=test_bs)

    return trainloader, testloader
Example n. 4
0
def torchvision_dataset(transform=None,
                        target_transform=None,
                        train=True,
                        subset=None):
    """Creates a dataset from torchvision, configured using Command Line Arguments.

    Args:
        transform (callable, optional): A function that transforms an image (default None).
        target_transform (callable, optional): A function that transforms a label (default None).
        train (bool, optional): Training set or validation - if applicable (default True).
        subset (string, optional): Specifies the subset of the relevant
            categories, if any of them was split (default, None).

    Raises:
        ValueError: if --torchvision_dataset is missing or not one of
            'mnist', 'fashionmnist', 'cifar10', 'cifar100'.

    Relevant Command Line Arguments:

        - **dataset**: `--data`, `--torchvision_dataset`.

    Note:
        Settings are automatically acquired from a call to :func:`dlt.config.parse`
        from the built-in ones. If :func:`dlt.config.parse` was not called in the 
        main script, this function will call it.

    Warning:
        Unlike the torchvision datasets, this function returns a dataset that
        uses NumPy Arrays instead of a PIL Images.
    """
    opts = fetch_opts(['dataset'], subset)

    if opts.torchvision_dataset is None:
        if subset is not None:
            apnd = '_' + subset
        else:
            apnd = ''
        raise ValueError(
            'No value given for --torchvision_dataset{0}.'.format(apnd))

    if opts.torchvision_dataset == 'mnist':
        from torchvision.datasets import MNIST
        MNIST.__getitem__ = _custom_get_item
        ret_dataset = MNIST(opts.data,
                            train=train,
                            download=True,
                            transform=transform,
                            target_transform=target_transform)
        # Add channel dimension and make numpy for consistency
        if train:
            ret_dataset.train_data = ret_dataset.train_data.unsqueeze(
                3).numpy()
            ret_dataset.train_labels = ret_dataset.train_labels.numpy()
        else:
            ret_dataset.test_data = ret_dataset.test_data.unsqueeze(3).numpy()
            ret_dataset.test_labels = ret_dataset.test_labels.numpy()
    elif opts.torchvision_dataset == 'fashionmnist':
        from torchvision.datasets import FashionMNIST
        FashionMNIST.__getitem__ = _custom_get_item
        ret_dataset = FashionMNIST(opts.data,
                                   train=train,
                                   download=True,
                                   transform=transform,
                                   target_transform=target_transform)
        # Add channel dimension and make numpy for consistency
        if train:
            ret_dataset.train_data = ret_dataset.train_data.unsqueeze(
                3).numpy()
            ret_dataset.train_labels = ret_dataset.train_labels.numpy()
        else:
            ret_dataset.test_data = ret_dataset.test_data.unsqueeze(3).numpy()
            ret_dataset.test_labels = ret_dataset.test_labels.numpy()
    elif opts.torchvision_dataset == 'cifar10':
        from torchvision.datasets import CIFAR10
        CIFAR10.__getitem__ = _custom_get_item
        ret_dataset = CIFAR10(opts.data,
                              train=train,
                              download=True,
                              transform=transform,
                              target_transform=target_transform)
    elif opts.torchvision_dataset == 'cifar100':
        from torchvision.datasets import CIFAR100
        CIFAR100.__getitem__ = _custom_get_item
        ret_dataset = CIFAR100(opts.data,
                               train=train,
                               download=True,
                               transform=transform,
                               target_transform=target_transform)
    else:
        # Bug fix: an unrecognized value previously fell through every branch
        # and the final `return` raised UnboundLocalError; fail explicitly.
        raise ValueError(
            'Unknown --torchvision_dataset value: {0!r}.'.format(
                opts.torchvision_dataset))
    return ret_dataset