# Excerpts from a pytest suite for the torchvision prototype datasets API.
# `test_home`, `dataset_mock`, and `config` are assumed to be fixtures
# provided by the surrounding test module.
    def test_save_load(self, test_home, dataset_mock, config):
        dataset_mock.prepare(test_home, config)
        dataset = datasets.load(dataset_mock.name, **config)
        sample = next(iter(dataset))

        with io.BytesIO() as buffer:
            torch.save(sample, buffer)
            buffer.seek(0)
            assert_samples_equal(torch.load(buffer), sample)

    def test_smoke(self, test_home, dataset_mock, config):
        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        if not isinstance(dataset, IterDataPipe):
            raise AssertionError(
                f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead."
            )

    def test_num_samples(self, test_home, dataset_mock, config):
        mock_info = dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        num_samples = 0
        for _ in dataset:
            num_samples += 1

        assert num_samples == mock_info["num_samples"]

    def test_label_matches_path(self, test_home, dataset_mock, config):
        # The labels are read from the CSV files, but for the train split they
        # are also encoded in the path. This test makes sure both sources agree.
        if config.split != "train":
            return

        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        for sample in dataset:
            label_from_path = int(Path(sample["path"]).parent.name)
            assert sample["label"] == label_from_path
Example #5
    def test_sample_content(self, test_home, dataset_mock, config):
        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        for sample in dataset:
            assert "image" in sample
            assert "label" in sample

            assert isinstance(sample["image"], Image)
            assert isinstance(sample["label"], Label)

            assert sample["image"].shape == (1, 16, 16)

    def test_decoding(self, test_home, dataset_mock, config):
        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        undecoded_features = {
            key
            for key, value in next(iter(dataset)).items()
            if isinstance(value, io.IOBase)
        }
        if undecoded_features:
            raise AssertionError(
                f"The values of key(s) "
                f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
            )

    def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        vanilla_tensors = {
            key
            for key, value in next(iter(dataset)).items()
            if type(value) is torch.Tensor
        }
        if vanilla_tensors:
            raise AssertionError(
                f"The values of key(s) "
                f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
            )

    def test_extra_label(self, test_home, dataset_mock, config):
        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        sample = next(iter(dataset))
        for key, expected_type in (
            ("nist_hsf_series", int),
            ("nist_writer_id", int),
            ("digit_index", int),
            ("nist_label", int),
            ("global_digit_index", int),
            ("duplicate", bool),
            ("unused", bool),
        ):
            assert key in sample and isinstance(sample[key], expected_type)

    def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
        def scan(graph):
            # Recursively yield every datapipe in the (nested) traversal graph.
            for node, sub_graph in graph.items():
                yield node
                yield from scan(sub_graph)

        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
            raise AssertionError(
                f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe."
            )
Example #10
# (Re)generate the `<name>.categories` file for each given builtin dataset,
# skipping datasets whose file already exists unless `force` is set.
def main(*names, force=False):
    for name in names:
        path = BUILTIN_DIR / f"{name}.categories"
        if path.exists() and not force:
            continue

        dataset = datasets.load(name)
        try:
            categories = dataset._generate_categories()
        except NotImplementedError:
            continue

        with open(path, "w") as file:
            writer = csv.writer(file, lineterminator="\n")
            for category in categories:
                writer.writerow((category,) if isinstance(category, str) else category)
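
# A hedged usage sketch for the helper above; "mnist" is an illustrative
# dataset name, not one the snippet guarantees to exist in the registry.
if __name__ == "__main__":
    main("mnist", force=True)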

    def test_sample(self, test_home, dataset_mock, config):
        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        try:
            sample = next(iter(dataset))
        except Exception as error:
            raise AssertionError(
                "Drawing a sample raised the error above.") from error

        if not isinstance(sample, dict):
            raise AssertionError(
                f"Samples should be dictionaries, but got {type(sample)} instead."
            )

        if not sample:
            raise AssertionError("Sample dictionary is empty.")
Example #12
    def new_dataset(self, *, num_workers=0):
        # Wrap the freshly loaded datapipe in a DataLoader2 so it can be
        # consumed with the requested number of workers.
        return DataLoader2(
            new_datasets.load(self.name, **self.new_config), num_workers=num_workers
        )

    def test_traversable(self, test_home, dataset_mock, config):
        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        # Building the traversal graph should not raise.
        traverse(dataset)

    def test_transformable(self, test_home, dataset_mock, config):
        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        # Mapping a no-op transform over the dataset should not raise.
        next(iter(dataset.map(transforms.Identity())))
Example #15
    def test_serializable(self, test_home, dataset_mock, config):
        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        # The datapipe graph must be picklable, e.g. for multiprocessing workers.
        pickle.dumps(dataset)
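
# A minimal end-to-end sketch of the loading pattern every test above exercises,
# assuming the torchvision prototype datasets API (the import path and the
# dataset name "mnist" are assumptions, not taken from the snippets above):
from torchvision.prototype import datasets

dataset = datasets.load("mnist", split="train")
sample = next(iter(dataset))
print(sorted(sample.keys()))  # e.g. the feature keys checked in the tests above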