示例#1
0
    def test_get(self):
        """registry.get returns DataLoaders with the expected lengths.

        The training loader (expected len 1000) is also checked to yield a first
        batch whose leading dimension equals ``batch_size``; the test loader
        (expected len 200) is only checked for type and length.
        """
        hparams = self.dataset_hparams

        train_loader = registry.get(hparams, train=True)
        self.assertIsInstance(train_loader, base.DataLoader)
        self.assertEqual(len(train_loader), 1000)

        inputs, _ = next(iter(train_loader))
        batch_dim = inputs.numpy().shape[0]
        self.assertEqual(batch_dim, hparams.batch_size)

        test_loader = registry.get(hparams, train=False)
        self.assertIsInstance(test_loader, base.DataLoader)
        self.assertEqual(len(test_loader), 200)
示例#2
0
    def test_do_not_augment(self):
        """A training loader can be built with ``do_not_augment`` enabled.

        Only the loader type and the first batch's leading dimension are
        checked; whether augmentation is actually skipped is not inspected.
        """
        # NOTE(review): mutates the shared hparams object; presumably setUp
        # rebuilds it for each test — confirm.
        hparams = self.dataset_hparams
        hparams.do_not_augment = True

        train_loader = registry.get(hparams, train=True)
        self.assertIsInstance(train_loader, base.DataLoader)

        inputs, _ = next(iter(train_loader))
        self.assertEqual(inputs.numpy().shape[0], hparams.batch_size)
示例#3
0
    def test_integration(self):
        """Smoke-test one training step of the default cifar_resnet_20 setup.

        Builds the model from its default hyperparameters and checks its
        parameter count. Then runs one CIFAR-10 training minibatch through the
        model and backpropagates the loss. Asserts the loss is a finite scalar
        and that the backward pass produced gradients.
        """
        import math

        default_hparams = registry.get_default_hparams('cifar_resnet_20')
        model = registry.get(default_hparams.model_hparams)
        self.assertEqual(self.count_parameters(model), 272474)

        cifar10 = datasets.registry.get(default_hparams.dataset_hparams, train=True)
        minibatch, labels = next(iter(cifar10))
        loss = model.loss_criterion(model(minibatch), labels)

        # Without these checks a NaN/inf loss, or a backward pass that reaches
        # no parameters, would let the test pass silently.
        # .item() requires a single-element tensor, which backward() without a
        # gradient argument also requires.
        self.assertTrue(math.isfinite(loss.item()))
        loss.backward()
        # Assumes model is a torch.nn.Module (it is backpropagated through and
        # passed to count_parameters).
        self.assertTrue(any(p.grad is not None for p in model.parameters()))
示例#4
0
    def test_get_subsample(self):
        """A training loader built with ``subsample_fraction=0.1`` has len 100.

        The first batch's leading dimension must still equal ``batch_size``.
        """
        hparams = self.dataset_hparams
        # NOTE(review): presumably a fixed seed makes the subsample
        # reproducible — confirm against the dataset implementation.
        hparams.transformation_seed = 0
        hparams.subsample_fraction = 0.1

        train_loader = registry.get(hparams, train=True)
        self.assertIsInstance(train_loader, base.DataLoader)
        self.assertEqual(len(train_loader), 100)

        inputs, _ = next(iter(train_loader))
        self.assertEqual(inputs.numpy().shape[0], hparams.batch_size)
示例#5
0
    def test_get_unsupervised_labels(self):
        """A training loader with ``unsupervised_labels='rotation'`` works.

        Checks the loader type and length (1000). The largest label in the
        first batch must be 3, and the batch's leading dimension must equal
        ``batch_size``.
        """
        hparams = self.dataset_hparams
        hparams.transformation_seed = 0
        hparams.unsupervised_labels = 'rotation'

        train_loader = registry.get(hparams, train=True)
        self.assertIsInstance(train_loader, base.DataLoader)
        self.assertEqual(len(train_loader), 1000)

        inputs, targets = next(iter(train_loader))
        # NOTE(review): presumably rotation labels range over 0..3. This relies
        # on the first batch containing the top label — confirm batch_size is
        # large enough for that to be reliable.
        self.assertEqual(targets.numpy().max(), 3)
        self.assertEqual(inputs.numpy().shape[0], hparams.batch_size)
示例#6
0
    def test_get_random_labels(self):
        """A training loader builds with fully randomized labels.

        The loader uses ``random_labels_fraction=1.0`` with augmentation
        disabled. Checks the loader type and length (1000), and that the first
        batch's leading dimension equals ``batch_size``. Label values are not
        inspected.
        """
        hparams = self.dataset_hparams
        hparams.transformation_seed = 0
        hparams.random_labels_fraction = 1.0
        hparams.do_not_augment = True

        train_loader = registry.get(hparams, train=True)
        self.assertIsInstance(train_loader, base.DataLoader)
        self.assertEqual(len(train_loader), 1000)

        inputs, _ = next(iter(train_loader))
        self.assertEqual(inputs.numpy().shape[0], hparams.batch_size)
示例#7
0
    def test_integration(self):
        """Smoke-test one training step of the default mnist_lenet_300_100 setup.

        Builds the model from its default hyperparameters and checks its
        parameter count. Then runs one MNIST training minibatch through the
        model and backpropagates the loss. Asserts the loss is a finite scalar
        and that the backward pass produced gradients.
        """
        import math

        default_hparams = registry.get_default_hparams('mnist_lenet_300_100')
        model = registry.get(default_hparams.model_hparams)
        self.assertEqual(self.count_parameters(model), 266610)

        mnist = datasets.registry.get(default_hparams.dataset_hparams,
                                      train=True)
        minibatch, labels = next(iter(mnist))
        loss = model.loss_criterion(model(minibatch), labels)

        # Without these checks a NaN/inf loss, or a backward pass that reaches
        # no parameters, would let the test pass silently.
        # .item() requires a single-element tensor, which backward() without a
        # gradient argument also requires.
        self.assertTrue(math.isfinite(loss.item()))
        loss.backward()
        # Assumes model is a torch.nn.Module (it is backpropagated through and
        # passed to count_parameters).
        self.assertTrue(any(p.grad is not None for p in model.parameters()))