    def test_all_at_once(self):
        readers = {
            "f": PlainTextReader(),
            "g": PlainTextReader(),
            "h": PlainTextReader()
        }

        reader = InterleavingDatasetReader(readers,
                                           dataset_field_name="source",
                                           scheme="all_at_once")
        data_dir = self.FIXTURES_ROOT / "data"

        file_path = f"""{{
            "f": "{data_dir / 'babi.txt'}",
            "g": "{data_dir / 'conll2000.txt'}",
            "h": "{data_dir / 'conll2003.txt'}"
        }}"""

        buckets = []
        last_source = None

        # Fill up a bucket until the source changes, then start a new one
        for instance in reader.read(file_path):
            source = instance.fields["source"].metadata
            if source != last_source:
                buckets.append([])
                last_source = source
            buckets[-1].append(instance)

        # the instances should fall into exactly three buckets, one per reader
        assert len(buckets) == 3
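# Every example on this page instantiates PlainTextReader, which is a test
# fixture rather than an AllenNLP library class. A minimal sketch of such a
# reader, assuming one instance per line with a single-id token indexer (the
# "line" field name is an illustrative choice, not part of the library API):
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token


class PlainTextReader(DatasetReader):
    def __init__(self):
        super().__init__()
        self._token_indexers = {"tokens": SingleIdTokenIndexer()}

    def _read(self, file_path: str):
        # Yield one instance per line of the raw text file.
        with open(file_path) as input_file:
            for line in input_file:
                yield self.text_to_instance(line)

    def text_to_instance(self, line: str) -> Instance:
        tokens = [Token(t) for t in line.split()]
        return Instance({"line": TextField(tokens, self._token_indexers)})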

    def test_round_robin(self):
        readers = {
            "a": PlainTextReader(),
            "b": PlainTextReader(),
            "c": PlainTextReader()
        }

        reader = InterleavingDatasetReader(readers)
        data_dir = self.FIXTURES_ROOT / "data"

        file_path = f"""{{
            "a": "{data_dir / 'babi.txt'}",
            "b": "{data_dir / 'conll2000.txt'}",
            "c": "{data_dir / 'conll2003.txt'}"
        }}"""

        instances = list(reader.read(file_path))
        first_three_keys = {
            instance.fields["dataset"].metadata
            for instance in instances[:3]
        }
        assert first_three_keys == {"a", "b", "c"}

        next_three_keys = {
            instance.fields["dataset"].metadata
            for instance in instances[3:6]
        }
        assert next_three_keys == {"a", "b", "c"}
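        # The file_path handed to reader.read() is just a JSON object mapping
        # each reader's key to its underlying file; InterleavingDatasetReader
        # parses it back into a dict. A sketch of building the same string
        # with json.dumps, which handles quoting and path escaping for free:
        import json

        same_file_path = json.dumps({
            "a": str(data_dir / "babi.txt"),
            "b": str(data_dir / "conll2000.txt"),
            "c": str(data_dir / "conll2003.txt"),
        })
        assert list(reader.read(same_file_path))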

    def test_batches(self):
        readers = {
                "a": PlainTextReader(),
                "b": PlainTextReader(),
                "c": PlainTextReader()
        }

        reader = InterleavingDatasetReader(readers)
        data_dir = self.FIXTURES_ROOT / "data"

        file_path = f"""{{
            "a": "{data_dir / 'babi.txt'}",
            "b": "{data_dir / 'conll2000.txt'}",
            "c": "{data_dir / 'conll2003.txt'}"
        }}"""

        instances = list(reader.read(file_path))
        vocab = Vocabulary.from_instances(instances)

        actual_instance_type_counts = Counter(instance.fields["dataset"].metadata
                                              for instance in instances)

        iterator = HomogeneousBatchIterator(batch_size=3)
        iterator.index_with(vocab)

        observed_instance_type_counts = Counter()

        for batch in iterator(instances, num_epochs=1, shuffle=True):
            # batch should be homogeneous
            instance_types = set(batch["dataset"])
            assert len(instance_types) == 1

            observed_instance_type_counts.update(batch["dataset"])

        assert observed_instance_type_counts == actual_instance_type_counts
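        # HomogeneousBatchIterator partitions instances by a metadata field
        # before batching; its default partition field is "dataset", which is
        # exactly what InterleavingDatasetReader attaches to each instance.
        # Spelling out the default (the partition_key argument is assumed from
        # the same AllenNLP release; a reader built with
        # dataset_field_name="source" would need partition_key="source"):
        explicit_iterator = HomogeneousBatchIterator(batch_size=3, partition_key="dataset")
        explicit_iterator.index_with(vocab)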

    def test_skip_smaller_batches(self):
        readers = {
            "a": PlainTextReader(),
            "b": PlainTextReader(),
            "c": PlainTextReader()
        }

        reader = InterleavingDatasetReader(readers)
        data_dir = self.FIXTURES_ROOT / "data"

        file_path = f"""{{
            "a": "{data_dir / 'babi.txt'}",
            "b": "{data_dir / 'conll2000.txt'}",
            "c": "{data_dir / 'conll2003.txt'}"
        }}"""

        instances = list(reader.read(file_path))
        vocab = Vocabulary.from_instances(instances)

        iterator = HomogeneousBatchIterator(batch_size=3,
                                            skip_smaller_batches=True)
        iterator.index_with(vocab)

        for batch in iterator(instances, num_epochs=1, shuffle=True):
            # every batch should have length 3 (batch size)
            assert len(batch["dataset"]) == 3

    # `lazy` is supplied by pytest parametrization (the test runs both ways).
    @pytest.mark.parametrize("lazy", (True, False))
    def test_with_multi_process_loading(self, lazy):
        readers = {"a": PlainTextReader(), "b": PlainTextReader(), "c": PlainTextReader()}
        reader = InterleavingDatasetReader(readers)
        data_dir = self.FIXTURES_ROOT / "data"
        file_path = {
            "a": data_dir / "babi.txt",
            "b": data_dir / "conll2003.txt",
            "c": data_dir / "conll2003.txt",
        }
        vocab = Vocabulary.from_instances(reader.read(file_path))
        loader = MultiProcessDataLoader(
            reader,
            file_path,
            num_workers=1,
            batch_size=1,
            max_instances_in_memory=2 if lazy else None,
        )
        loader.index_with(vocab)

        # Exercise both instance-level iteration and batched (tensorized) iteration.
        list(loader.iter_instances())
        list(loader)
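        # For debugging, an equivalent loader can run entirely in-process:
        # num_workers defaults to 0 in MultiProcessDataLoader, in which case
        # no worker processes are spawned. A sketch:
        in_process_loader = MultiProcessDataLoader(reader, file_path, batch_size=1)
        in_process_loader.index_with(vocab)
        list(in_process_loader)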