def test_get(self):
    # Request the loader with train=True and check its type and reported length.
    train_loader = registry.get(self.dataset_hparams, train=True)
    self.assertIsInstance(train_loader, base.DataLoader)
    self.assertEqual(len(train_loader), 1000)

    # The first minibatch's leading dimension must equal the configured batch size.
    first_batch, _ = next(iter(train_loader))
    leading_dim = first_batch.numpy().shape[0]
    self.assertEqual(leading_dim, self.dataset_hparams.batch_size)

    # Same type/length checks for the loader requested with train=False.
    test_loader = registry.get(self.dataset_hparams, train=False)
    self.assertIsInstance(test_loader, base.DataLoader)
    self.assertEqual(len(test_loader), 200)
def test_do_not_augment(self):
    # Switch the do_not_augment flag on before asking the registry for a loader.
    self.dataset_hparams.do_not_augment = True

    train_loader = registry.get(self.dataset_hparams, train=True)
    self.assertIsInstance(train_loader, base.DataLoader)

    # With the flag set, the first minibatch must still have batch_size examples
    # along its leading dimension.
    first_batch, _ = next(iter(train_loader))
    leading_dim = first_batch.numpy().shape[0]
    self.assertEqual(leading_dim, self.dataset_hparams.batch_size)
def test_integration(self):
    # Build the model from the default hparams registered as 'cifar_resnet_20'
    # and pin its parameter count.
    hparams = registry.get_default_hparams('cifar_resnet_20')
    net = registry.get(hparams.model_hparams)
    self.assertEqual(self.count_parameters(net), 272474)

    # Pull a single minibatch from the matching dataset (train=True).
    train_loader = datasets.registry.get(hparams.dataset_hparams, train=True)
    examples, targets = next(iter(train_loader))

    # One forward/backward pass; the test passes as long as nothing raises.
    criterion = net.loss_criterion
    outputs = net(examples)
    loss = criterion(outputs, targets)
    loss.backward()
def test_get_subsample(self):
    # Fix the transformation seed and request a 0.1 subsample fraction.
    hparams = self.dataset_hparams
    hparams.transformation_seed = 0
    hparams.subsample_fraction = 0.1

    train_loader = registry.get(hparams, train=True)
    self.assertIsInstance(train_loader, base.DataLoader)
    # With a 0.1 fraction the loader is expected to report a length of 100.
    self.assertEqual(len(train_loader), 100)

    # The first minibatch's leading dimension must equal the configured batch size.
    first_batch, _ = next(iter(train_loader))
    leading_dim = first_batch.numpy().shape[0]
    self.assertEqual(leading_dim, self.dataset_hparams.batch_size)
def test_get_unsupervised_labels(self):
    # Fix the transformation seed and select the 'rotation' unsupervised labels.
    hparams = self.dataset_hparams
    hparams.transformation_seed = 0
    hparams.unsupervised_labels = 'rotation'

    train_loader = registry.get(hparams, train=True)
    self.assertIsInstance(train_loader, base.DataLoader)
    self.assertEqual(len(train_loader), 1000)

    first_batch, first_labels = next(iter(train_loader))

    # The largest label in the first minibatch must be 3
    # (presumably labels are rotation classes 0-3 -- confirm against the dataset code).
    largest_label = np.max(first_labels.numpy())
    self.assertEqual(largest_label, 3)

    # The first minibatch's leading dimension must equal the configured batch size.
    leading_dim = first_batch.numpy().shape[0]
    self.assertEqual(leading_dim, self.dataset_hparams.batch_size)
def test_get_random_labels(self):
    # Fix the transformation seed, set random_labels_fraction to 1.0 and
    # turn the do_not_augment flag on.
    hparams = self.dataset_hparams
    hparams.transformation_seed = 0
    hparams.random_labels_fraction = 1.0
    hparams.do_not_augment = True

    train_loader = registry.get(hparams, train=True)
    self.assertIsInstance(train_loader, base.DataLoader)
    self.assertEqual(len(train_loader), 1000)

    # Only the minibatch size is checked here; the labels themselves are not inspected.
    first_batch, _ = next(iter(train_loader))
    leading_dim = first_batch.numpy().shape[0]
    self.assertEqual(leading_dim, self.dataset_hparams.batch_size)
def test_integration(self):
    # Build the model from the default hparams registered as 'mnist_lenet_300_100'
    # and pin its parameter count.
    hparams = registry.get_default_hparams('mnist_lenet_300_100')
    net = registry.get(hparams.model_hparams)
    self.assertEqual(self.count_parameters(net), 266610)

    # Pull a single minibatch from the matching dataset (train=True).
    train_loader = datasets.registry.get(hparams.dataset_hparams, train=True)
    examples, targets = next(iter(train_loader))

    # One forward/backward pass; the test passes as long as nothing raises.
    criterion = net.loss_criterion
    outputs = net(examples)
    loss = criterion(outputs, targets)
    loss.backward()