예제 #1
0
def get_dataset(name,
                split='train',
                transform=None,
                target_transform=None,
                download=True,
                datasets_path=__DATASETS_DEFAULT_PATH):
    """Build the dataset object for ``name`` / ``split``.

    Args:
        name: dataset identifier; one of ``'cifar10_whitened'``,
            ``'tinyImagenet'``, ``'cifar10'``, ``'cifar100'``, ``'mnist'``
            or ``'imagenet'``.
        split: ``'train'`` selects the training split; any other value
            selects the test / validation split.
        transform: input transform forwarded to the underlying dataset
            constructor (unused by the tensor-backed entries).
        target_transform: label transform, forwarded likewise.
        download: forwarded to the CIFAR10 / CIFAR100 / MNIST constructors.
        datasets_path: base directory; the dataset root is
            ``os.path.join(datasets_path, name)``. Entries that load from a
            hard-coded absolute path ignore it.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: if ``name`` is not one of the supported datasets
            (previously the function silently returned ``None``).
    """
    train = (split == 'train')
    root = os.path.join(datasets_path, name)

    if name == 'cifar10_whitened':
        # NOTE(review): hard-coded absolute path; datasets_path is ignored.
        x = load_lua('/media/SSD/Datasets/cifar10/cifar10_whitened.t7')
        part = x['trainData'] if train else x['testData']
        # Labels are shifted by 5.5 and passed through sign(), turning them
        # into a two-way target around that threshold (presumably labels
        # are 1..10 -> -1/+1; TODO confirm).
        return dataset.TensorDataset([
            part['data'],
            (part['labels'] - 5.5).sign()
        ])
    elif name == 'tinyImagenet':
        if train:
            # NOTE(review): hard-coded absolute path; datasets_path is ignored.
            return datasets.ImageFolder(
                root='/home/ehoffer/Datasets/ImageNet/tiny',
                transform=transform,
                target_transform=target_transform)
        # Placeholder for the non-train split: 100 random 3x64x64 inputs with
        # random float targets -- not real data.
        return dataset.TensorDataset(
            [torch.rand(100, 3, 64, 64).float(),
             torch.rand(100).float()])
    elif name in ('cifar10', 'cifar100', 'mnist'):
        # These three constructors share the same keyword signature.
        dataset_cls = {
            'cifar10': datasets.CIFAR10,
            'cifar100': datasets.CIFAR100,
            'mnist': datasets.MNIST,
        }[name]
        return dataset_cls(root=root,
                           train=train,
                           transform=transform,
                           target_transform=target_transform,
                           download=download)
    elif name == 'imagenet':
        # ImageNet is laid out as <root>/train and <root>/val class folders.
        root = os.path.join(root, 'train' if train else 'val')
        return datasets.ImageFolder(root=root,
                                    transform=transform,
                                    target_transform=target_transform)

    # Fail fast instead of handing the caller a None dataset.
    raise ValueError('unsupported dataset name: %r' % (name,))
예제 #2
0
    def testTensorDataset(self):
        """TensorDataset accepts a dict, a single tensor, or a list of tensors."""
        # Dict input: each index returns a dict with one entry per key.
        columns = {key: np.arange(0, 8) for key in ('input', 'target')}
        ds = dataset.TensorDataset(columns)
        self.assertEqual(len(ds), 8)
        self.assertEqual(ds[2], {'input': 2, 'target': 2})

        # Bare tensor input: indexing returns the tensor's own element.
        values = torch.randn(8)
        ds = dataset.TensorDataset(values)
        self.assertEqual(len(values), len(ds))
        self.assertEqual(values[1], ds[1])

        # List-of-tensors input: each sample is indexable by list position.
        ds = dataset.TensorDataset([values])
        self.assertEqual(len(values), len(ds))
        self.assertEqual(values[1], ds[1][0])
예제 #3
0
 def testShuffleDataset(self):
     """ShuffleDataset over a 5-element TensorDataset reports length 5."""
     source = dataset.TensorDataset(np.asarray(list(range(5))))
     shuffled = dataset.ShuffleDataset(source)
     self.assertEqual(len(shuffled), 5)
예제 #4
0
 def testResampleDataset(self):
     """ResampleDataset routes each index through the supplied sampler."""
     source = dataset.TensorDataset(np.asarray([0, 1, 2]))
     # Sampler takes (dataset, index); `index % 2` maps index 2 back to 0.
     # The first parameter is renamed so it no longer shadows the module.
     resampled = dataset.ResampleDataset(source, lambda _ds, idx: idx % 2)
     self.assertEqual(len(resampled), 3)
     self.assertEqual(resampled[0], 0)
     self.assertEqual(resampled[2], 0)