def test_data_normalized(self):
    """Verify that image data lies in [0.0, 1.0] under ToTensor().

    Each image emitted by the dataset is converted to a flat NumPy array
    and passed to ``_is_normalized``; the test fails on the first image
    for which that check is falsy.
    """
    to_tensor = transforms.ToTensor()
    dataset = regression_dataset.RegressionDataset(_DATASET_PATH, to_tensor)
    for image_tensor, _ in dataset:
        flat_pixels = image_tensor.numpy().flatten()
        self.assertTrue(_is_normalized(flat_pixels))
def test_order_of_entries(self):
    """Check that dataset[k] yields the kth (img_data, score) tuple.

    Only the scores are compared here. The expected values are taken from
    the dataset.json file in _DATASET_PATH.
    """
    dataset = regression_dataset.RegressionDataset(_DATASET_PATH)
    # Scores listed in the order the entries are expected to be indexed.
    expected_scores = (0.998, 0.734, 0.343, 0.123)
    for index, expected in enumerate(expected_scores):
        entry = dataset[index]
        self.assertEqual(entry[1], expected)
def _loader(dataset_path, batch_size):
    """Build a DataLoader that emits (image_data, score) tuples.

    The data is not shuffled and every image is center cropped.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Applied in order: 224x224 center crop, tensor conversion, then
    # normalization with the mean/std constants above.
    preprocessing_steps = [
        transforms.CenterCrop([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
    preprocessing = transforms.Compose(preprocessing_steps)

    dataset = regression_dataset.RegressionDataset(dataset_path, preprocessing)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
def _loader(dataset_path, batch_size):
    """Build a DataLoader that emits (image_data, score) tuples.

    The dataset is randomly shuffled and every image is randomly cropped.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Applied in order: 224x224 random crop, tensor conversion, then
    # normalization with the mean/std constants above.
    preprocessing_steps = [
        # TODO (carlo): Figure out a sensible way to do data augmentations here.
        transforms.RandomCrop([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
    preprocessing = transforms.Compose(preprocessing_steps)

    dataset = regression_dataset.RegressionDataset(dataset_path, preprocessing)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )
def test_length(self):
    """Check that the dataset loaded from _DATASET_PATH reports four entries."""
    expected_entry_count = 4
    dataset = regression_dataset.RegressionDataset(_DATASET_PATH)
    self.assertEqual(len(dataset), expected_entry_count)