def test_save_load(self, test_home, dataset_mock, config):
    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)
    sample = next(iter(dataset))

    with io.BytesIO() as buffer:
        torch.save(sample, buffer)
        buffer.seek(0)
        assert_samples_equal(torch.load(buffer), sample)

def test_smoke(self, test_home, dataset_mock, config):
    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    if not isinstance(dataset, IterDataPipe):
        raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")

def test_num_samples(self, test_home, dataset_mock, config):
    mock_info = dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    num_samples = 0
    for _ in dataset:
        num_samples += 1

    assert num_samples == mock_info["num_samples"]

def test_label_matches_path(self, test_home, dataset_mock, config):
    # The labels are read from the csv files. For the train split, however, the label is also
    # part of the path. This test makes sure that the two match.
    if config.split != "train":
        return

    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    for sample in dataset:
        label_from_path = int(Path(sample["path"]).parent.name)
        assert sample["label"] == label_from_path

def test_sample_content(self, test_home, dataset_mock, config):
    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    for sample in dataset:
        assert "image" in sample
        assert "label" in sample

        assert isinstance(sample["image"], Image)
        assert isinstance(sample["label"], Label)

        assert sample["image"].shape == (1, 16, 16)

def test_decoding(self, test_home, dataset_mock, config):
    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
    if undecoded_features:
        raise AssertionError(
            f"The values of key(s) "
            f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
        )

def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
    if vanilla_tensors:
        raise AssertionError(
            f"The values of key(s) "
            f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
        )

def test_extra_label(self, test_home, dataset_mock, config):
    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    sample = next(iter(dataset))
    # `type_` avoids shadowing the `type` builtin.
    for key, type_ in (
        ("nist_hsf_series", int),
        ("nist_writer_id", int),
        ("digit_index", int),
        ("nist_label", int),
        ("global_digit_index", int),
        ("duplicate", bool),
        ("unused", bool),
    ):
        assert key in sample and isinstance(sample[key], type_)

def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
    def scan(graph):
        # Depth-first flattening of the nested graph returned by traverse().
        for node, sub_graph in graph.items():
            yield node
            yield from scan(sub_graph)

    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
        raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")

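# Toy illustration (not used by the test above): traverse() returns the datapipe
# graph as a nested mapping of node -> sub-graph, which scan() flattens
# depth-first. The string keys below are hypothetical stand-ins for datapipe
# instances.
def _scan_example():
    def scan(graph):
        for node, sub_graph in graph.items():
            yield node
            yield from scan(sub_graph)

    graph = {"Mapper": {"Filter": {"FileLister": {}}}}
    assert list(scan(graph)) == ["Mapper", "Filter", "FileLister"]
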
def main(*names, force=False):
    for name in names:
        path = BUILTIN_DIR / f"{name}.categories"
        if path.exists() and not force:
            continue

        dataset = datasets.load(name)
        try:
            categories = dataset._generate_categories()
        except NotImplementedError:
            continue

        with open(path, "w") as file:
            writer = csv.writer(file, lineterminator="\n")
            for category in categories:
                writer.writerow((category,) if isinstance(category, str) else category)

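# Minimal entry-point sketch (an assumption, not necessarily how the real
# script is invoked): regenerate the category files for the given dataset
# names, e.g. `python generate_category_files.py mnist cifar10 --force`.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("names", nargs="+", help="dataset names to (re)generate category files for")
    parser.add_argument("--force", action="store_true", help="overwrite existing .categories files")
    args = parser.parse_args()
    main(*args.names, force=args.force)
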
def test_sample(self, test_home, dataset_mock, config):
    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    try:
        sample = next(iter(dataset))
    except Exception as error:
        raise AssertionError("Drawing a sample raised the error above.") from error

    if not isinstance(sample, dict):
        raise AssertionError(f"Samples should be dictionaries, but got {type(sample)} instead.")

    if not sample:
        raise AssertionError("Sample dictionary is empty.")

def new_dataset(self, *, num_workers=0):
    return DataLoader2(new_datasets.load(self.name, **self.new_config), num_workers=num_workers)

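# Illustrative usage (hypothetical; assumes an instance with `name` and
# `new_config` attributes set, e.g. in a benchmark harness):
#
#   loader = benchmark.new_dataset(num_workers=4)
#   for sample in loader:
#       ...  # consume samples, e.g. to measure throughput
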
def test_traversable(self, test_home, dataset_mock, config):
    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    traverse(dataset)

def test_transformable(self, test_home, dataset_mock, config):
    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    next(iter(dataset.map(transforms.Identity())))

def test_serializable(self, test_home, dataset_mock, config):
    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    pickle.dumps(dataset)

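# Context sketch (an assumption about why serializability matters, not part of
# the test suite): multiprocessing data loading sends the datapipe graph to
# worker processes by pickling it, so a round trip must yield a working pipeline.
def _pickle_roundtrip_example(dataset):
    # Hypothetical helper illustrating the pickle round trip.
    return pickle.loads(pickle.dumps(dataset))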