Example #1
    def setup_method(self) -> None:
        super().setup_method()

        # use SequenceTaggingDatasetReader as the base reader
        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"

        # Make 100 copies of the data
        with open(base_file_path) as f:
            raw_data = f.read()
        for i in range(100):
            file_path = self.TEST_DIR / f"identical_{i}.tsv"
            with open(file_path, "w") as f:
                f.write(raw_data)

        self.identical_files_glob = str(self.TEST_DIR / "identical_*.tsv")

        # Also create an archive containing all of these files to ensure that
        # we can pass the archive directly.
        current_dir = os.getcwd()
        os.chdir(self.TEST_DIR)
        self.archive_filename = self.TEST_DIR / "all_data.tar.gz"
        with tarfile.open(self.archive_filename, "w:gz") as archive:
            for file_path in glob.glob("identical_*.tsv"):
                archive.add(file_path)
        os.chdir(current_dir)

        self.reader = ShardedDatasetReader(base_reader=self.base_reader)
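
The os.chdir round trip above exists only so that the archive members are stored under relative names. A sketch of an equivalent block for setup_method (with glob, os, and tarfile imported at module level) that passes arcname instead of changing the working directory:

        # Equivalent archive construction without os.chdir: store each shard
        # under its base name so the members stay relative inside the tarball.
        self.archive_filename = self.TEST_DIR / "all_data.tar.gz"
        with tarfile.open(self.archive_filename, "w:gz") as archive:
            for file_path in glob.glob(str(self.TEST_DIR / "identical_*.tsv")):
                archive.add(file_path, arcname=os.path.basename(file_path))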
Example #2
    def test_sharded_read(self):
        reader = ShardedDatasetReader(base_reader=self.base_reader)

        all_instances = []

        for instance in reader.read(self.identical_files_glob):
            all_instances.append(instance)

        # 100 files * 4 sentences / file
        assert len(all_instances) == 100 * 4

        counts = Counter(fingerprint(instance) for instance in all_instances)

        # should have the exact same data 100 times
        assert len(counts) == 4
        assert counts[("cats", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("dogs", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("snakes", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("birds", "are", "animals", ".", "N", "V", "N", "N")] == 100
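
Both this example and the next rely on a fingerprint helper that lives in the test module itself, not in the AllenNLP library. A minimal sketch of such a helper, assuming each instance carries a "tokens" TextField and a "tags" SequenceLabelField, as SequenceTaggingDatasetReader produces:

from typing import Tuple

from allennlp.data.instance import Instance


def fingerprint(instance: Instance) -> Tuple[str, ...]:
    # Hashable view of a sequence tagging instance: token texts followed by
    # their tags, e.g. ("cats", "are", "animals", ".", "N", "V", "N", "N").
    text_tuple = tuple(t.text for t in instance.fields["tokens"].tokens)
    labels_tuple = tuple(instance.fields["tags"].labels)
    return text_tuple + labels_tuple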
Example #3
import glob
import os
import tarfile
from collections import Counter

from allennlp.common.testing import AllenNlpTestCase
from allennlp.data.dataset_readers import (
    SequenceTaggingDatasetReader,
    ShardedDatasetReader,
)

# `fingerprint` is the test-local helper sketched after Example #2.


class TestShardedDatasetReader(AllenNlpTestCase):
    def setup_method(self) -> None:
        super().setup_method()

        # use SequenceTaggingDatasetReader as the base reader
        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"

        # Make 100 copies of the data
        with open(base_file_path) as f:
            raw_data = f.read()
        for i in range(100):
            file_path = self.TEST_DIR / f"identical_{i}.tsv"
            with open(file_path, "w") as f:
                f.write(raw_data)

        self.identical_files_glob = str(self.TEST_DIR / "identical_*.tsv")

        # Also create an archive containing all of these files to ensure that
        # we can pass the archive directly.
        current_dir = os.getcwd()
        os.chdir(self.TEST_DIR)
        self.archive_filename = self.TEST_DIR / "all_data.tar.gz"
        with tarfile.open(self.archive_filename, "w:gz") as archive:
            for file_path in glob.glob("identical_*.tsv"):
                archive.add(file_path)
        os.chdir(current_dir)

        self.reader = ShardedDatasetReader(base_reader=self.base_reader)

    def read_and_check_instances(self, filepath: str):
        all_instances = []
        for instance in self.reader.read(filepath):
            all_instances.append(instance)

        # 100 files * 4 sentences / file
        assert len(all_instances) == 100 * 4

        counts = Counter(fingerprint(instance) for instance in all_instances)

        # should have the exact same data 100 times
        assert len(counts) == 4
        assert counts[("cats", "are", "animals", ".", "N", "V", "N",
                       "N")] == 100
        assert counts[("dogs", "are", "animals", ".", "N", "V", "N",
                       "N")] == 100
        assert counts[("snakes", "are", "animals", ".", "N", "V", "N",
                       "N")] == 100
        assert counts[("birds", "are", "animals", ".", "N", "V", "N",
                       "N")] == 100

    def test_sharded_read_glob(self):
        self.read_and_check_instances(self.identical_files_glob)

    def test_sharded_read_archive(self):
        self.read_and_check_instances(str(self.archive_filename))
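
Outside of tests, the same wrapper can also be built from configuration rather than constructed directly. A sketch, assuming the registered names "sharded" and "sequence_tagging" as used in recent AllenNLP releases:

from allennlp.common import Params
from allennlp.data.dataset_readers import DatasetReader

# Config-driven construction: the "sharded" reader wraps the
# "sequence_tagging" base reader, mirroring the setup above.
params = Params({"type": "sharded", "base_reader": {"type": "sequence_tagging"}})
reader = DatasetReader.from_params(params)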
Example #4
import pytest

from allennlp.data.dataset_readers import DatasetReader, ShardedDatasetReader


def test_exception_raised_when_base_reader_implements_sharding():
    class ManuallyShardedBaseReader(DatasetReader):
        def __init__(self, **kwargs):
            super().__init__(manual_distributed_sharding=True, **kwargs)

        def _read(self, file_path: str):
            pass

        def text_to_instance(self, text: str):  # type: ignore
            pass

    with pytest.raises(
        ValueError, match="should not implement manual distributed sharding"
    ):
        ShardedDatasetReader(ManuallyShardedBaseReader())
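
For contrast, a base reader that leaves sharding to the wrapper is accepted without complaint. A minimal sketch, where PlainBaseReader is a hypothetical stand-in rather than an AllenNLP class:

class PlainBaseReader(DatasetReader):
    # Leaves manual_distributed_sharding at its default of False, so
    # ShardedDatasetReader stays free to distribute the shards itself.
    def _read(self, file_path: str):
        pass

    def text_to_instance(self, text: str):  # type: ignore
        pass


reader = ShardedDatasetReader(PlainBaseReader())  # no ValueError raised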