    def test_all_at_once(self):
        readers = {
            "f": PlainTextReader(),
            "g": PlainTextReader(),
            "h": PlainTextReader(),
        }
        reader = InterleavingDatasetReader(
            readers, dataset_field_name="source", scheme="all_at_once"
        )
        data_dir = self.FIXTURES_ROOT / "data"
        file_path = f"""{{
            "f": "{data_dir / 'babi.txt'}",
            "g": "{data_dir / 'conll2000.txt'}",
            "h": "{data_dir / 'conll2003.txt'}"
        }}"""

        buckets = []
        last_source = None

        # Fill up a bucket until the source changes, then start a new one.
        for instance in reader.read(file_path):
            source = instance.fields["source"].metadata
            if source != last_source:
                buckets.append([])
                last_source = source
            buckets[-1].append(instance)

        # Each reader is exhausted in turn, so there should be exactly
        # one contiguous bucket per reader: 3 in total.
        assert len(buckets) == 3
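    # Reference sketch of the "all_at_once" scheme as the test above assumes
    # it behaves (an illustration with a made-up helper name, not AllenNLP's
    # actual implementation): drain each underlying stream completely before
    # starting the next, which is why the combined stream forms one
    # contiguous bucket per reader.
    @staticmethod
    def _all_at_once_sketch(streams):
        for key, stream in streams.items():
            for item in stream:
                yield key, item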
    def test_round_robin(self):
        readers = {
            "a": PlainTextReader(),
            "b": PlainTextReader(),
            "c": PlainTextReader(),
        }
        reader = InterleavingDatasetReader(readers)
        data_dir = self.FIXTURES_ROOT / "data"
        file_path = f"""{{
            "a": "{data_dir / 'babi.txt'}",
            "b": "{data_dir / 'conll2000.txt'}",
            "c": "{data_dir / 'conll2003.txt'}"
        }}"""

        instances = list(reader.read(file_path))

        # With no scheme specified, the reader defaults to round robin,
        # so each consecutive window of three instances should contain
        # one instance from each reader.
        first_three_keys = {
            instance.fields["dataset"].metadata for instance in instances[:3]
        }
        assert first_three_keys == {"a", "b", "c"}

        next_three_keys = {
            instance.fields["dataset"].metadata for instance in instances[3:6]
        }
        assert next_three_keys == {"a", "b", "c"}
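    # Reference sketch of the round-robin scheme exercised above (an
    # illustration of the expected semantics with a made-up helper name, not
    # AllenNLP's actual code): take one item from each stream in turn,
    # dropping a stream once it is exhausted.
    @staticmethod
    def _round_robin_sketch(streams):
        iterators = {key: iter(stream) for key, stream in streams.items()}
        while iterators:
            for key in list(iterators):
                try:
                    yield key, next(iterators[key])
                except StopIteration:
                    del iterators[key]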
    def test_batches(self):
        readers = {
            "a": PlainTextReader(),
            "b": PlainTextReader(),
            "c": PlainTextReader(),
        }
        reader = InterleavingDatasetReader(readers)
        data_dir = self.FIXTURES_ROOT / "data"
        file_path = f"""{{
            "a": "{data_dir / 'babi.txt'}",
            "b": "{data_dir / 'conll2000.txt'}",
            "c": "{data_dir / 'conll2003.txt'}"
        }}"""

        instances = list(reader.read(file_path))
        vocab = Vocabulary.from_instances(instances)

        actual_instance_type_counts = Counter(
            instance.fields["dataset"].metadata for instance in instances
        )

        iterator = HomogeneousBatchIterator(batch_size=3)
        iterator.index_with(vocab)

        observed_instance_type_counts = Counter()

        for batch in iterator(instances, num_epochs=1, shuffle=True):
            # Each batch should be homogeneous: all of its instances
            # should come from the same reader.
            instance_types = set(batch["dataset"])
            assert len(instance_types) == 1

            observed_instance_type_counts.update(batch["dataset"])

        # Batching should neither drop nor duplicate instances.
        assert observed_instance_type_counts == actual_instance_type_counts
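    # Reference sketch of what a homogeneous batcher is expected to do (an
    # illustration with a made-up helper name, not HomogeneousBatchIterator's
    # actual code): group instances by their "dataset" key, then chunk each
    # group, so no batch ever mixes instances from different readers. The
    # skip_smaller_batches flag is exercised by the next test.
    @staticmethod
    def _homogeneous_batches_sketch(instances, batch_size, skip_smaller_batches=False):
        from collections import defaultdict

        by_key = defaultdict(list)
        for instance in instances:
            by_key[instance.fields["dataset"].metadata].append(instance)
        for group in by_key.values():
            for start in range(0, len(group), batch_size):
                batch = group[start : start + batch_size]
                if skip_smaller_batches and len(batch) < batch_size:
                    continue
                yield batch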
    def test_skip_smaller_batches(self):
        readers = {
            "a": PlainTextReader(),
            "b": PlainTextReader(),
            "c": PlainTextReader(),
        }
        reader = InterleavingDatasetReader(readers)
        data_dir = self.FIXTURES_ROOT / "data"
        file_path = f"""{{
            "a": "{data_dir / 'babi.txt'}",
            "b": "{data_dir / 'conll2000.txt'}",
            "c": "{data_dir / 'conll2003.txt'}"
        }}"""

        instances = list(reader.read(file_path))
        vocab = Vocabulary.from_instances(instances)

        iterator = HomogeneousBatchIterator(batch_size=3, skip_smaller_batches=True)
        iterator.index_with(vocab)

        for batch in iterator(instances, num_epochs=1, shuffle=True):
            # Every batch should have exactly batch_size (3) instances;
            # smaller leftover batches are skipped.
            assert len(batch["dataset"]) == 3
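    # Note on the sketch above: with skip_smaller_batches=True the trailing
    # partial chunk of each group is dropped rather than yielded, which is
    # the behavior this test checks against HomogeneousBatchIterator.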
    @pytest.mark.parametrize("lazy", (True, False))
    def test_with_multi_process_loading(self, lazy):
        readers = {"a": PlainTextReader(), "b": PlainTextReader(), "c": PlainTextReader()}
        reader = InterleavingDatasetReader(readers)
        data_dir = self.FIXTURES_ROOT / "data"
        file_path = {
            "a": data_dir / "babi.txt",
            "b": data_dir / "conll2003.txt",
            "c": data_dir / "conll2003.txt",
        }
        vocab = Vocabulary.from_instances(reader.read(file_path))
        loader = MultiProcessDataLoader(
            reader,
            file_path,
            num_workers=1,
            batch_size=1,
            max_instances_in_memory=2 if lazy else None,
        )
        loader.index_with(vocab)

        # Smoke test: both the raw-instance iterator and the tensor-batch
        # iterator should run to completion without errors.
        list(loader.iter_instances())
        list(loader)
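    # Note: max_instances_in_memory is what toggles lazy loading above. As
    # this test assumes, passing None makes the loader read and index every
    # instance up front, while a small positive value keeps only a bounded
    # window of instances in memory and streams the rest on demand.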