Example #1
    def setUp(self) -> None:
        super().setUp()

        # use SequenceTaggingDatasetReader as the base reader
        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv'


        # Make 100 copies of the data
        raw_data = open(base_file_path).read()
        for i in range(100):
            file_path = self.TEST_DIR / f'identical_{i}.tsv'
            with open(file_path, 'w') as f:
                f.write(raw_data)

        self.all_distinct_path = str(self.TEST_DIR / 'all_distinct.tsv')
        with open(self.all_distinct_path, 'w') as all_distinct:
            for i in range(100):
                file_path = self.TEST_DIR / f'distinct_{i}.tsv'
                line = f"This###DT\tis###VBZ\tsentence###NN\t{i}###CD\t.###.\n"
                with open(file_path, 'w') as f:
                    f.write(line)
                all_distinct.write(line)

        self.identical_files_glob = str(self.TEST_DIR / 'identical_*.tsv')
        self.distinct_files_glob = str(self.TEST_DIR / 'distinct_*.tsv')

        # For some of the tests we need a vocab, we'll just use the base_reader for that.
        self.vocab = Vocabulary.from_instances(self.base_reader.read(str(base_file_path)))
Example #2
 def setUp(self):
     super().setUp()
     self.instances = SequenceTaggingDatasetReader().read(
         self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"
     )
     self.instances_lazy = SequenceTaggingDatasetReader(lazy=True).read(
         self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"
     )
     vocab = Vocabulary.from_instances(self.instances)
     self.vocab = vocab
     self.model_params = Params(
         {
             "text_field_embedder": {
                 "token_embedders": {"tokens": {"type": "embedding", "embedding_dim": 5}}
             },
             "encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2},
         }
     )
     self.model = SimpleTagger.from_params(vocab=self.vocab, params=self.model_params)
     self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01, momentum=0.9)
     self.data_loader = DataLoader(self.instances, batch_size=2, collate_fn=allennlp_collate)
     self.data_loader_lazy = DataLoader(
         self.instances_lazy, batch_size=2, collate_fn=allennlp_collate
     )
     self.validation_data_loader = DataLoader(
         self.instances, batch_size=2, collate_fn=allennlp_collate
     )
     self.instances.index_with(vocab)
     self.instances_lazy.index_with(vocab)
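A hedged follow-on sketch, not part of the original example: one way a test in this class might use the objects built in setUp, reusing the GradientDescentTrainer pattern that appears in Example #5. The method name and the final check are assumptions.

 def test_trainer_can_run_with_these_loaders(self):
     trainer = GradientDescentTrainer(
         model=self.model,
         optimizer=self.optimizer,
         data_loader=self.data_loader,
         validation_data_loader=self.validation_data_loader,
         num_epochs=2,
     )
     # train() returns a metrics dict once the two epochs complete.
     metrics = trainer.train()
     assert metrics is not None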
Example #3
    def setup_method(self) -> None:
        super().setup_method()

        # use SequenceTaggingDatasetReader as the base reader
        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        self.base_reader_multi_process = SequenceTaggingDatasetReader(
            lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"

        # Make 100 copies of the data
        raw_data = open(base_file_path).read()

        for i in range(100):
            file_path = self.TEST_DIR / f"identical_{i}.tsv"
            with open(file_path, "w") as f:
                f.write(raw_data)

        self.identical_files_glob = str(self.TEST_DIR / "identical_*.tsv")

        # Also create an archive with all of these files to ensure that we can
        # pass the archive directory.
        current_dir = os.getcwd()
        os.chdir(self.TEST_DIR)
        self.archive_filename = self.TEST_DIR / "all_data.tar.gz"
        with tarfile.open(self.archive_filename, "w:gz") as archive:
            for file_path in glob.glob("identical_*.tsv"):
                archive.add(file_path)
        os.chdir(current_dir)

        self.reader = ShardedDatasetReader(base_reader=self.base_reader)
        self.reader_multi_process = ShardedDatasetReader(
            base_reader=self.base_reader_multi_process, multi_process=True)
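A hedged sketch, not part of the original example: how the sharded readers configured above might be exercised. It assumes ShardedDatasetReader.read accepts the glob built in setup_method and that the fixture file holds four tagged sentences, as the MultiprocessDatasetReader tests later in this listing assert.

    def test_sharded_read_covers_every_shard(self):
        # 100 shard files * 4 sentences per file.
        instances = list(self.reader.read(self.identical_files_glob))
        assert len(instances) == 100 * 4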
Example #4
    def setUp(self):
        super(SimpleTaggerTest, self).setUp()
        dataset = SequenceTaggingDatasetReader().read(
            'tests/fixtures/data/sequence_tagging.tsv')
        vocab = Vocabulary.from_dataset(dataset)
        self.vocab = vocab
        dataset.index_instances(vocab)
        self.dataset = dataset

        params = Params({
            "text_field_embedder": {
                "tokens": {
                    "type": "embedding",
                    "embedding_dim": 5
                }
            },
            "stacked_encoder": {
                "type": "lstm",
                "input_size": 5,
                "hidden_size": 7,
                "num_layers": 2
            }
        })

        self.model = SimpleTagger.from_params(self.vocab, params)
Example #5
class TestDenseSparseAdam(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()
        self.instances = SequenceTaggingDatasetReader().read(
            self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv")
        self.vocab = Vocabulary.from_instances(self.instances)
        self.model_params = Params({
            "text_field_embedder": {
                "token_embedders": {
                    "tokens": {
                        "type": "embedding",
                        "embedding_dim": 5,
                        "sparse": True
                    }
                }
            },
            "encoder": {
                "type": "lstm",
                "input_size": 5,
                "hidden_size": 7,
                "num_layers": 2
            },
        })
        self.model = SimpleTagger.from_params(vocab=self.vocab,
                                              params=self.model_params)

    def test_can_optimise_model_with_dense_and_sparse_params(self):
        optimizer_params = Params({"type": "dense_sparse_adam"})
        parameters = [[n, p] for n, p in self.model.named_parameters()
                      if p.requires_grad]
        optimizer = Optimizer.from_params(model_parameters=parameters,
                                          params=optimizer_params)
        self.instances.index_with(self.vocab)
        GradientDescentTrainer(self.model, optimizer,
                               PyTorchDataLoader(self.instances, 2)).train()
Example #6
    def setUp(self):
        super().setUp()
        # TODO make this a set of dataset readers
        # Classification may be easier in this case. Same dataset reader but with different paths 
        self.instances_list = []
        self.instances_list.append(SequenceTaggingDatasetReader().read(self.FIXTURES_ROOT / 'data' / 'meta_seq' / 'sequence_tagging.tsv'))
        self.instances_list.append(SequenceTaggingDatasetReader().read(self.FIXTURES_ROOT / 'data' / 'meta_seq' / 'sequence_tagging1.tsv'))
        self.instances_list.append(SequenceTaggingDatasetReader().read(self.FIXTURES_ROOT / 'data' / 'meta_seq' / 'sequence_tagging2.tsv'))
        # loop through dataset readers and extend vocab
        combined_vocab = Vocabulary.from_instances(self.instances_list[0])

        for instance in self.instances_list:
            combined_vocab.extend_from_instances(Params({}), instances=instance)
        self.vocab = combined_vocab
        # Figure out params TODO 
        self.model_params = Params({
                "text_field_embedder": {
                        "token_embedders": {
                                "tokens": {
                                        "type": "embedding",
                                        "embedding_dim": 5
                                        }
                                }
                        },
                "encoder": {
                        "type": "lstm",
                        "input_size": 5,
                        "hidden_size": 7,
                        "num_layers": 2
                        }
                })
        self.model = SimpleTagger.from_params(vocab=self.vocab, params=self.model_params)
        self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01, momentum=0.9)
        self.iterator = BasicIterator(batch_size=2)
        self.iterator.index_with(combined_vocab)
Example #7
 def setUp(self):
     super(TestTrainer, self).setUp()
     dataset = SequenceTaggingDatasetReader().read(
         'tests/fixtures/data/sequence_tagging.tsv')
     vocab = Vocabulary.from_instances(dataset)
     self.vocab = vocab
     dataset.index_instances(vocab)
     self.dataset = dataset
     self.model_params = Params({
         "text_field_embedder": {
             "tokens": {
                 "type": "embedding",
                 "embedding_dim": 5
             }
         },
         "stacked_encoder": {
             "type": "lstm",
             "input_size": 5,
             "hidden_size": 7,
             "num_layers": 2
         }
     })
     self.model = SimpleTagger.from_params(self.vocab, self.model_params)
     self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01)
     self.iterator = BasicIterator(batch_size=2)
Example #8
 def setUp(self):
     super(TestTrainer, self).setUp()
     self.instances = SequenceTaggingDatasetReader().read(
         self.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv')
     vocab = Vocabulary.from_instances(self.instances)
     self.vocab = vocab
     self.model_params = Params({
         "text_field_embedder": {
             "tokens": {
                 "type": "embedding",
                 "embedding_dim": 5
             }
         },
         "encoder": {
             "type": "lstm",
             "input_size": 5,
             "hidden_size": 7,
             "num_layers": 2
         }
     })
     self.model = SimpleTagger.from_params(vocab=self.vocab,
                                           params=self.model_params)
     self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01)
     self.iterator = BasicIterator(batch_size=2)
     self.iterator.index_with(vocab)
Example #9
 def setUp(self):
     super().setUp()
     self.instances = SequenceTaggingDatasetReader().read(
         self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv")
     vocab = Vocabulary.from_instances(self.instances)
     self.vocab = vocab
     self.model = ConstantModel(vocab)
Example #10
    def setUp(self):
        super().setUp()

        # A lot of the tests want access to the metric tracker
        # so we add a property that gets it by grabbing it from
        # the relevant callback.
        def metric_tracker(self: CallbackTrainer):
            for callback in self.handler.callbacks():
                if isinstance(callback, TrackMetrics):
                    return callback.metric_tracker
            return None

        setattr(CallbackTrainer, 'metric_tracker', property(metric_tracker))

        self.instances = SequenceTaggingDatasetReader().read(self.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv')
        vocab = Vocabulary.from_instances(self.instances)
        self.vocab = vocab
        self.model_params = Params({
                "text_field_embedder": {
                        "token_embedders": {
                                "tokens": {
                                        "type": "embedding",
                                        "embedding_dim": 5
                                        }
                                }
                        },
                "encoder": {
                        "type": "lstm",
                        "input_size": 5,
                        "hidden_size": 7,
                        "num_layers": 2
                        }
                })
        self.model = SimpleTagger.from_params(vocab=self.vocab, params=self.model_params)
        self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01, momentum=0.9)
Example #11
 def setup_method(self):
     super().setup_method()
     self.data_path = str(self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv")
     self.reader = SequenceTaggingDatasetReader()
     self.data_loader = MultiProcessDataLoader(self.reader, self.data_path, batch_size=2)
     self.data_loader_lazy = MultiProcessDataLoader(
         self.reader, self.data_path, batch_size=2, max_instances_in_memory=10
     )
     self.instances = list(self.data_loader.iter_instances())
     self.vocab = Vocabulary.from_instances(self.instances)
     self.data_loader.index_with(self.vocab)
     self.data_loader_lazy.index_with(self.vocab)
     self.model_params = Params(
         {
             "text_field_embedder": {
                 "token_embedders": {"tokens": {"type": "embedding", "embedding_dim": 5}}
             },
             "encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2},
         }
     )
     self.model = SimpleTagger.from_params(vocab=self.vocab, params=self.model_params)
     self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01, momentum=0.9)
     self.validation_data_loader = MultiProcessDataLoader(
         self.reader, self.data_path, batch_size=2
     )
     self.validation_data_loader.index_with(self.vocab)
Example #12
class TestMultiprocessIterator(IteratorTest):
    def setUp(self):
        super().setUp()

        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"

        # Make 100 copies of the data
        raw_data = open(base_file_path).read()
        for i in range(100):
            file_path = self.TEST_DIR / f"sequence_tagging_{i}.tsv"
            with open(file_path, "w") as f:
                f.write(raw_data)

        self.glob = str(self.TEST_DIR / "sequence_tagging_*.tsv")

        # For some of the tests we need a vocab, we'll just use the base_reader for that.
        self.vocab = Vocabulary.from_instances(self.base_reader.read(str(base_file_path)))

    def test_yield_one_epoch_iterates_over_the_data_once(self):
        for test_instances in (self.instances, self.lazy_instances):
            base_iterator = BasicIterator(batch_size=2, max_instances_in_memory=1024)
            iterator = MultiprocessIterator(base_iterator, num_workers=4)
            iterator.index_with(self.vocab)
            batches = list(iterator(test_instances, num_epochs=1))
            # We just want to get the single-token array for the text field in the instance.
            instances = [
                tuple(instance.detach().cpu().numpy())
                for batch in batches
                for instance in batch["text"]["tokens"]["tokens"]
            ]
            assert len(instances) == 5

    def test_multiprocess_iterate_partial_does_not_hang(self):
        for test_instances in (self.instances, self.lazy_instances):
            base_iterator = BasicIterator(batch_size=2, max_instances_in_memory=1024)
            iterator = MultiprocessIterator(base_iterator, num_workers=4)
            iterator.index_with(self.vocab)
            generator = iterator(test_instances, num_epochs=1)
            # We only iterate through 3 of the 5 instances causing the
            # processes generating the tensors to remain active.
            for _ in range(3):
                next(generator)
            # The real test here is that we exit normally and don't hang due to
            # the still active processes.

    def test_multiprocess_reader_with_multiprocess_iterator(self):
        # use SequenceTaggingDatasetReader as the base reader
        reader = MultiprocessDatasetReader(base_reader=self.base_reader, num_workers=2)
        base_iterator = BasicIterator(batch_size=32, max_instances_in_memory=1024)

        iterator = MultiprocessIterator(base_iterator, num_workers=2)
        iterator.index_with(self.vocab)

        instances = reader.read(self.glob)

        tensor_dicts = iterator(instances, num_epochs=1)
        sizes = [len(tensor_dict["tags"]) for tensor_dict in tensor_dicts]
        assert sum(sizes) == 400
Example #13
    def setUp(self):
        super().setUp()

        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"

        # Make 100 copies of the data
        raw_data = open(base_file_path).read()
        for i in range(100):
            file_path = self.TEST_DIR / f"sequence_tagging_{i}.tsv"
            with open(file_path, "w") as f:
                f.write(raw_data)

        self.glob = str(self.TEST_DIR / "sequence_tagging_*.tsv")

        # For some of the tests we need a vocab, we'll just use the base_reader for that.
        self.vocab = Vocabulary.from_instances(self.base_reader.read(str(base_file_path)))
Example #14
 def setup_method(self):
     super().setup_method()
     self.instances = SequenceTaggingDatasetReader().read(
         self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"
     )
     self.vocab = Vocabulary.from_instances(self.instances)
     self.model_params = Params(
         {
             "text_field_embedder": {
                 "token_embedders": {
                     "tokens": {"type": "embedding", "embedding_dim": 5, "sparse": True}
                 }
             },
             "encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2},
         }
     )
     self.model = SimpleTagger.from_params(vocab=self.vocab, params=self.model_params)
Example #15
    def test_brown_corpus_format(self):
        reader = SequenceTaggingDatasetReader(word_tag_delimiter='/')
        dataset = reader.read('tests/fixtures/data/brown_corpus.txt')

        assert len(dataset.instances) == 4
        fields = dataset.instances[0].fields
        assert fields["tokens"].tokens == ["cats", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = dataset.instances[1].fields
        assert fields["tokens"].tokens == ["dogs", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = dataset.instances[2].fields
        assert fields["tokens"].tokens == ["snakes", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = dataset.instances[3].fields
        assert fields["tokens"].tokens == ["birds", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
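The brown_corpus.txt fixture itself is not reproduced on this page; given word_tag_delimiter='/' and the assertions above, each line presumably pairs word and tag with a slash and separates tokens with whitespace, one sentence per line (an assumption about the fixture, not its verbatim contents):

    cats/N are/V animals/N ./N
    dogs/N are/V animals/N ./N
    snakes/N are/V animals/N ./N
    birds/N are/V animals/N ./N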
Example #16
    def test_default_format(self):
        reader = SequenceTaggingDatasetReader()
        dataset = reader.read('tests/fixtures/data/sequence_tagging.tsv')

        assert len(dataset.instances) == 4
        fields = dataset.instances[0].fields
        assert fields["tokens"].tokens == ["cats", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = dataset.instances[1].fields
        assert fields["tokens"].tokens == ["dogs", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = dataset.instances[2].fields
        assert fields["tokens"].tokens == ["snakes", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = dataset.instances[3].fields
        assert fields["tokens"].tokens == ["birds", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
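For comparison, the default sequence_tagging.tsv format joins each word to its tag with "###" and separates the pairs with tabs, the same convention as the line written out in Example #1, so the fixture presumably contains lines of the form (tab characters shown as \t):

    cats###N\tare###V\tanimals###N\t.###N
    dogs###N\tare###V\tanimals###N\t.###N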
Example #17
    def test_brown_corpus_format(self):
        reader = SequenceTaggingDatasetReader(word_tag_delimiter=u'/')
        instances = reader.read(AllenNlpTestCase.FIXTURES_ROOT / u'data' / u'brown_corpus.txt')
        instances = ensure_list(instances)

        assert len(instances) == 4
        fields = instances[0].fields
        assert [t.text for t in fields[u"tokens"].tokens] == [u"cats", u"are", u"animals", u"."]
        assert fields[u"tags"].labels == [u"N", u"V", u"N", u"N"]
        fields = instances[1].fields
        assert [t.text for t in fields[u"tokens"].tokens] == [u"dogs", u"are", u"animals", u"."]
        assert fields[u"tags"].labels == [u"N", u"V", u"N", u"N"]
        fields = instances[2].fields
        assert [t.text for t in fields[u"tokens"].tokens] == [u"snakes", u"are", u"animals", u"."]
        assert fields[u"tags"].labels == [u"N", u"V", u"N", u"N"]
        fields = instances[3].fields
        assert [t.text for t in fields[u"tokens"].tokens] == [u"birds", u"are", u"animals", u"."]
        assert fields[u"tags"].labels == [u"N", u"V", u"N", u"N"]
Example #18
    def test_default_format(self, lazy):
        reader = SequenceTaggingDatasetReader(lazy=lazy)
        instances = reader.read('tests/fixtures/data/sequence_tagging.tsv')
        instances = ensure_list(instances)

        assert len(instances) == 4
        fields = instances[0].fields
        assert [t.text for t in fields["tokens"].tokens] == ["cats", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = instances[1].fields
        assert [t.text for t in fields["tokens"].tokens] == ["dogs", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = instances[2].fields
        assert [t.text for t in fields["tokens"].tokens] == ["snakes", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = instances[3].fields
        assert [t.text for t in fields["tokens"].tokens] == ["birds", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
Example #19
    def test_default_format(self, lazy):
        reader = SequenceTaggingDatasetReader(lazy=lazy)
        instances = reader.read(AllenNlpTestCase.FIXTURES_ROOT / u'data' / u'sequence_tagging.tsv')
        instances = ensure_list(instances)

        assert len(instances) == 4
        fields = instances[0].fields
        assert [t.text for t in fields[u"tokens"].tokens] == [u"cats", u"are", u"animals", u"."]
        assert fields[u"tags"].labels == [u"N", u"V", u"N", u"N"]
        fields = instances[1].fields
        assert [t.text for t in fields[u"tokens"].tokens] == [u"dogs", u"are", u"animals", u"."]
        assert fields[u"tags"].labels == [u"N", u"V", u"N", u"N"]
        fields = instances[2].fields
        assert [t.text for t in fields[u"tokens"].tokens] == [u"snakes", u"are", u"animals", u"."]
        assert fields[u"tags"].labels == [u"N", u"V", u"N", u"N"]
        fields = instances[3].fields
        assert [t.text for t in fields[u"tokens"].tokens] == [u"birds", u"are", u"animals", u"."]
        assert fields[u"tags"].labels == [u"N", u"V", u"N", u"N"]
Example #20
    def test_read_from_file(self):

        reader = SequenceTaggingDatasetReader()
        dataset = reader.read(self.TRAIN_FILE)

        assert len(dataset.instances) == 4
        fields = dataset.instances[0].fields()
        assert fields["tokens"].tokens() == ["cats", "are", "animals", "."]
        assert fields["tags"].tags() == ["N", "V", "N", "N"]
        fields = dataset.instances[1].fields()
        assert fields["tokens"].tokens() == ["dogs", "are", "animals", "."]
        assert fields["tags"].tags() == ["N", "V", "N", "N"]
        fields = dataset.instances[2].fields()
        assert fields["tokens"].tokens() == ["snakes", "are", "animals", "."]
        assert fields["tags"].tags() == ["N", "V", "N", "N"]
        fields = dataset.instances[3].fields()
        assert fields["tokens"].tokens() == ["birds", "are", "animals", "."]
        assert fields["tags"].tags() == ["N", "V", "N", "N"]
Example #21
    def test_brown_corpus_format(self):
        reader = SequenceTaggingDatasetReader(word_tag_delimiter='/')
        instances = reader.read('tests/fixtures/data/brown_corpus.txt')
        instances = ensure_list(instances)

        assert len(instances) == 4
        fields = instances[0].fields
        assert [t.text for t in fields["tokens"].tokens] == ["cats", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = instances[1].fields
        assert [t.text for t in fields["tokens"].tokens] == ["dogs", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = instances[2].fields
        assert [t.text for t in fields["tokens"].tokens] == ["snakes", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = instances[3].fields
        assert [t.text for t in fields["tokens"].tokens] == ["birds", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
Example #22
    def setUp(self):
        super(SimpleTaggerTest, self).setUp()
        self.write_sequence_tagging_data()

        dataset = SequenceTaggingDatasetReader().read(self.TRAIN_FILE)
        vocab = Vocabulary.from_dataset(dataset)
        self.vocab = vocab
        dataset.index_instances(vocab)
        self.dataset = dataset

        params = Params({
            "text_field_embedder": {
                "tokens": {
                    "type": "embedding",
                    "embedding_dim": 5
                }
            },
            "hidden_size": 7,
            "num_layers": 2
        })

        self.model = SimpleTagger.from_params(self.vocab, params)
Example #23
    def test_brown_corpus_format(self):
        reader = SequenceTaggingDatasetReader(word_tag_delimiter="/")
        instances = list(
            reader.read(AllenNlpTestCase.FIXTURES_ROOT / "data" /
                        "brown_corpus.txt"))

        assert len(instances) == 4
        fields = instances[0].fields
        assert [t.text for t in fields["tokens"].tokens
                ] == ["cats", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = instances[1].fields
        assert [t.text for t in fields["tokens"].tokens
                ] == ["dogs", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = instances[2].fields
        assert [t.text for t in fields["tokens"].tokens
                ] == ["snakes", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = instances[3].fields
        assert [t.text for t in fields["tokens"].tokens
                ] == ["birds", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
Example #24
    def test_default_format(self):
        reader = SequenceTaggingDatasetReader(max_instances=4)
        instances = list(
            reader.read(AllenNlpTestCase.FIXTURES_ROOT / "data" /
                        "sequence_tagging.tsv"))

        assert len(instances) == 4
        fields = instances[0].fields
        assert [t.text for t in fields["tokens"].tokens
                ] == ["cats", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = instances[1].fields
        assert [t.text for t in fields["tokens"].tokens
                ] == ["dogs", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = instances[2].fields
        assert [t.text for t in fields["tokens"].tokens
                ] == ["snakes", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
        fields = instances[3].fields
        assert [t.text for t in fields["tokens"].tokens
                ] == ["birds", "are", "animals", "."]
        assert fields["tags"].labels == ["N", "V", "N", "N"]
Example #25
class TestMultiprocessIterator(IteratorTest):
    def setUp(self):
        super().setUp()

        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv'

        # Make 100 copies of the data
        raw_data = open(base_file_path).read()
        for i in range(100):
            file_path = self.TEST_DIR / f'sequence_tagging_{i}.tsv'
            with open(file_path, 'w') as f:
                f.write(raw_data)

        self.glob = str(self.TEST_DIR / 'sequence_tagging_*.tsv')

        # For some of the tests we need a vocab, we'll just use the base_reader for that.
        self.vocab = Vocabulary.from_instances(
            self.base_reader.read(str(base_file_path)))

    def test_yield_one_epoch_iterates_over_the_data_once(self):
        for test_instances in (self.instances, self.lazy_instances):
            base_iterator = BasicIterator(batch_size=2,
                                          max_instances_in_memory=1024)
            iterator = MultiprocessIterator(base_iterator, num_workers=4)
            iterator.index_with(self.vocab)
            batches = list(iterator(test_instances, num_epochs=1))
            # We just want to get the single-token array for the text field in the instance.
            instances = [
                tuple(instance.detach().cpu().numpy()) for batch in batches
                for instance in batch['text']["tokens"]
            ]
            assert len(instances) == 5

    def test_multiprocess_reader_with_multiprocess_iterator(self):
        # use SequenceTaggingDatasetReader as the base reader
        reader = MultiprocessDatasetReader(base_reader=self.base_reader,
                                           num_workers=2)
        base_iterator = BasicIterator(batch_size=32,
                                      max_instances_in_memory=1024)

        iterator = MultiprocessIterator(base_iterator, num_workers=2)
        iterator.index_with(self.vocab)

        instances = reader.read(self.glob)

        tensor_dicts = iterator(instances, num_epochs=1)
        sizes = [len(tensor_dict['tags']) for tensor_dict in tensor_dicts]
        assert sum(sizes) == 400
Example #26
    def setUp(self) -> None:
        super().setUp()

        # use SequenceTaggingDatasetReader as the base reader
        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"

        # Make 100 copies of the data
        raw_data = open(base_file_path).read()
        for i in range(100):
            file_path = self.TEST_DIR / f"identical_{i}.tsv"
            with open(file_path, "w") as f:
                f.write(raw_data)

        self.identical_files_glob = str(self.TEST_DIR / "identical_*.tsv")
Example #27
    def setUp(self):
        super().setUp()

        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv'

        # Make 100 copies of the data
        raw_data = open(base_file_path).read()
        for i in range(100):
            file_path = self.TEST_DIR / f'sequence_tagging_{i}.tsv'
            with open(file_path, 'w') as f:
                f.write(raw_data)

        self.glob = str(self.TEST_DIR / 'sequence_tagging_*.tsv')

        # For some of the tests we need a vocab, we'll just use the base_reader for that.
        self.vocab = Vocabulary.from_instances(self.base_reader.read(str(base_file_path)))
Example #28
class TestMultiprocessIterator(IteratorTest):
    def setUp(self):
        super().setUp()

        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv'

        # Make 100 copies of the data
        raw_data = open(base_file_path).read()
        for i in range(100):
            file_path = self.TEST_DIR / f'sequence_tagging_{i}.tsv'
            with open(file_path, 'w') as f:
                f.write(raw_data)

        self.glob = str(self.TEST_DIR / 'sequence_tagging_*.tsv')

        # For some of the tests we need a vocab, we'll just use the base_reader for that.
        self.vocab = Vocabulary.from_instances(self.base_reader.read(str(base_file_path)))


    def test_yield_one_epoch_iterates_over_the_data_once(self):
        for test_instances in (self.instances, self.lazy_instances):
            base_iterator = BasicIterator(batch_size=2, max_instances_in_memory=1024)
            iterator = MultiprocessIterator(base_iterator, num_workers=4)
            iterator.index_with(self.vocab)
            batches = list(iterator(test_instances, num_epochs=1))
            # We just want to get the single-token array for the text field in the instance.
            instances = [tuple(instance.detach().cpu().numpy())
                         for batch in batches
                         for instance in batch['text']["tokens"]]
            assert len(instances) == 5

    def test_multiprocess_reader_with_multiprocess_iterator(self):
        # use SequenceTaggingDatasetReader as the base reader
        reader = MultiprocessDatasetReader(base_reader=self.base_reader, num_workers=2)
        base_iterator = BasicIterator(batch_size=32, max_instances_in_memory=1024)

        iterator = MultiprocessIterator(base_iterator, num_workers=2)
        iterator.index_with(self.vocab)

        instances = reader.read(self.glob)

        tensor_dicts = iterator(instances, num_epochs=1)
        sizes = [len(tensor_dict['tags']) for tensor_dict in tensor_dicts]
        assert sum(sizes) == 400
Example #29
 def setUp(self):
     super(TestOptimizer, self).setUp()
     self.instances = SequenceTaggingDatasetReader().read(self.FIXTURES_ROOT / u'data' / u'sequence_tagging.tsv')
     vocab = Vocabulary.from_instances(self.instances)
     self.model_params = Params({
             u"text_field_embedder": {
                     u"tokens": {
                             u"type": u"embedding",
                             u"embedding_dim": 5
                             }
                     },
             u"encoder": {
                     u"type": u"lstm",
                     u"input_size": 5,
                     u"hidden_size": 7,
                     u"num_layers": 2
                     }
             })
     self.model = SimpleTagger.from_params(vocab=vocab, params=self.model_params)
Example #30
 def setUp(self):
     super(TestOptimizer, self).setUp()
     self.instances = SequenceTaggingDatasetReader().read(
         'tests/fixtures/data/sequence_tagging.tsv')
     vocab = Vocabulary.from_instances(self.instances)
     self.model_params = Params({
         "text_field_embedder": {
             "tokens": {
                 "type": "embedding",
                 "embedding_dim": 5
             }
         },
         "encoder": {
             "type": "lstm",
             "input_size": 5,
             "hidden_size": 7,
             "num_layers": 2
         }
     })
     self.model = SimpleTagger.from_params(vocab, self.model_params)
Example #31
class TestMultiprocessDatasetReader(AllenNlpTestCase):
    def setUp(self) -> None:
        super().setUp()

        # use SequenceTaggingDatasetReader as the base reader
        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv'


        # Make 100 copies of the data
        raw_data = open(base_file_path).read()
        for i in range(100):
            file_path = self.TEST_DIR / f'identical_{i}.tsv'
            with open(file_path, 'w') as f:
                f.write(raw_data)

        self.all_distinct_path = str(self.TEST_DIR / 'all_distinct.tsv')
        with open(self.all_distinct_path, 'w') as all_distinct:
            for i in range(100):
                file_path = self.TEST_DIR / f'distinct_{i}.tsv'
                line = f"This###DT\tis###VBZ\tsentence###NN\t{i}###CD\t.###.\n"
                with open(file_path, 'w') as f:
                    f.write(line)
                all_distinct.write(line)

        self.identical_files_glob = str(self.TEST_DIR / 'identical_*.tsv')
        self.distinct_files_glob = str(self.TEST_DIR / 'distinct_*.tsv')

        # For some of the tests we need a vocab, we'll just use the base_reader for that.
        self.vocab = Vocabulary.from_instances(self.base_reader.read(str(base_file_path)))

    def test_multiprocess_read(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader, num_workers=4)

        all_instances = []

        for instance in reader.read(self.identical_files_glob):
            all_instances.append(instance)

        # 100 files * 4 sentences / file
        assert len(all_instances) == 100 * 4

        counts = Counter(fingerprint(instance) for instance in all_instances)

        # should have the exact same data 100 times
        assert len(counts) == 4
        assert counts[("cats", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("dogs", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("snakes", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("birds", "are", "animals", ".", "N", "V", "N", "N")] == 100

    def test_multiprocess_read_in_subprocess_is_deterministic(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader, num_workers=1)
        q = Queue()
        def read():
            for instance in reader.read(self.distinct_files_glob):
                q.put(fingerprint(instance))

        # Ensure deterministic shuffling.
        np.random.seed(0)
        p = Process(target=read)
        p.start()
        p.join()

        # Convert queue to list.
        actual_fingerprints = []
        while not q.empty():
            actual_fingerprints.append(q.get(block=False))

        assert len(actual_fingerprints) == 100

        expected_fingerprints = []
        for instance in self.base_reader.read(self.all_distinct_path):
            expected_fingerprints.append(fingerprint(instance))

        np.random.seed(0)
        expected_fingerprints.sort()
        # This should be shuffled into exactly the same order as actual_fingerprints.
        np.random.shuffle(expected_fingerprints)

        assert actual_fingerprints == expected_fingerprints

    def test_multiple_epochs(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader,
                                           num_workers=2,
                                           epochs_per_read=3)

        all_instances = []

        for instance in reader.read(self.identical_files_glob):
            all_instances.append(instance)

        # 100 files * 4 sentences per file * 3 epochs
        assert len(all_instances) == 100 * 4 * 3

        counts = Counter(fingerprint(instance) for instance in all_instances)

        # should have the exact same data 100 * 3 times
        assert len(counts) == 4
        assert counts[("cats", "are", "animals", ".", "N", "V", "N", "N")] == 300
        assert counts[("dogs", "are", "animals", ".", "N", "V", "N", "N")] == 300
        assert counts[("snakes", "are", "animals", ".", "N", "V", "N", "N")] == 300
        assert counts[("birds", "are", "animals", ".", "N", "V", "N", "N")] == 300

    def test_with_iterator(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader, num_workers=2)
        instances = reader.read(self.identical_files_glob)

        iterator = BasicIterator(batch_size=32)
        iterator.index_with(self.vocab)

        batches = [batch for batch in iterator(instances, num_epochs=1)]

        # 400 instances / batch_size 32 = 12 full batches + 1 batch of 16
        sizes = sorted([len(batch['tags']) for batch in batches])
        assert sizes == [16] + 12 * [32]
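The fingerprint helper used in these Counter-based assertions is not shown anywhere in the listing. A minimal sketch consistent with keys such as ("cats", "are", "animals", ".", "N", "V", "N", "N"), offered as an assumption rather than the original implementation:

def fingerprint(instance):
    # Flatten token texts and tag labels into one hashable tuple,
    # e.g. ("cats", "are", "animals", ".", "N", "V", "N", "N").
    texts = tuple(token.text for token in instance.fields["tokens"].tokens)
    labels = tuple(instance.fields["tags"].labels)
    return texts + labels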
Example #32
class TestMultiprocessDatasetReader(AllenNlpTestCase):
    def setUp(self) -> None:
        super().setUp()

        # use SequenceTaggingDatasetReader as the base reader
        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"

        # Make 100 copies of the data
        raw_data = open(base_file_path).read()
        for i in range(100):
            file_path = self.TEST_DIR / f"identical_{i}.tsv"
            with open(file_path, "w") as f:
                f.write(raw_data)

        self.all_distinct_path = str(self.TEST_DIR / "all_distinct.tsv")
        with open(self.all_distinct_path, "w") as all_distinct:
            for i in range(100):
                file_path = self.TEST_DIR / f"distinct_{i}.tsv"
                line = f"This###DT\tis###VBZ\tsentence###NN\t{i}###CD\t.###.\n"
                with open(file_path, "w") as f:
                    f.write(line)
                all_distinct.write(line)

        self.identical_files_glob = str(self.TEST_DIR / "identical_*.tsv")
        self.distinct_files_glob = str(self.TEST_DIR / "distinct_*.tsv")

        # For some of the tests we need a vocab, we'll just use the base_reader for that.
        self.vocab = Vocabulary.from_instances(self.base_reader.read(str(base_file_path)))

    def test_multiprocess_read(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader, num_workers=4)

        all_instances = []

        for instance in reader.read(self.identical_files_glob):
            all_instances.append(instance)

        # 100 files * 4 sentences / file
        assert len(all_instances) == 100 * 4

        counts = Counter(fingerprint(instance) for instance in all_instances)

        # should have the exact same data 100 times
        assert len(counts) == 4
        assert counts[("cats", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("dogs", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("snakes", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("birds", "are", "animals", ".", "N", "V", "N", "N")] == 100

    def test_multiprocess_read_partial_does_not_hang(self):
        # Use a small queue size such that the processes generating the data will block.
        reader = MultiprocessDatasetReader(
            base_reader=self.base_reader, num_workers=4, output_queue_size=10
        )

        all_instances = []

        # Half of 100 files * 4 sentences / file
        i = 0
        for instance in reader.read(self.identical_files_glob):
            # Stop early such that the processes generating the data remain
            # active (given the small queue size).
            if i == 200:
                break
            i += 1
            all_instances.append(instance)

        # This should be trivially true. The real test here is that we exit
        # normally and don't hang due to the still active processes.
        assert len(all_instances) == 200

    def test_multiprocess_read_with_qiterable(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader, num_workers=4)

        all_instances = []
        qiterable = reader.read(self.identical_files_glob)
        assert isinstance(qiterable, QIterable)

        # Essentially QIterable.__iter__. Broken out here as we intend it to be
        # a public interface.
        qiterable.start()
        while qiterable.num_active_workers.value > 0 or qiterable.num_inflight_items.value > 0:
            while True:
                try:
                    all_instances.append(qiterable.output_queue.get(block=False, timeout=1.0))
                    with qiterable.num_inflight_items.get_lock():
                        qiterable.num_inflight_items.value -= 1
                except Empty:
                    break
        qiterable.join()

        # 100 files * 4 sentences / file
        assert len(all_instances) == 100 * 4

        counts = Counter(fingerprint(instance) for instance in all_instances)

        # should have the exact same data 100 times
        assert len(counts) == 4
        assert counts[("cats", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("dogs", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("snakes", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("birds", "are", "animals", ".", "N", "V", "N", "N")] == 100

    def test_multiprocess_read_in_subprocess_is_deterministic(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader, num_workers=1)
        q = Queue()

        def read():
            for instance in reader.read(self.distinct_files_glob):
                q.put(fingerprint(instance))

        # Ensure deterministic shuffling.
        np.random.seed(0)
        p = Process(target=read)
        p.start()
        p.join()

        # Convert queue to list.
        actual_fingerprints = []
        while not q.empty():
            actual_fingerprints.append(q.get(block=False))

        assert len(actual_fingerprints) == 100

        expected_fingerprints = []
        for instance in self.base_reader.read(self.all_distinct_path):
            expected_fingerprints.append(fingerprint(instance))

        np.random.seed(0)
        expected_fingerprints.sort()
        # This should be shuffled into exactly the same order as actual_fingerprints.
        np.random.shuffle(expected_fingerprints)

        assert actual_fingerprints == expected_fingerprints

    def test_multiple_epochs(self):
        reader = MultiprocessDatasetReader(
            base_reader=self.base_reader, num_workers=2, epochs_per_read=3
        )

        all_instances = []

        for instance in reader.read(self.identical_files_glob):
            all_instances.append(instance)

        # 100 files * 4 sentences per file * 3 epochs
        assert len(all_instances) == 100 * 4 * 3

        counts = Counter(fingerprint(instance) for instance in all_instances)

        # should have the exact same data 100 * 3 times
        assert len(counts) == 4
        assert counts[("cats", "are", "animals", ".", "N", "V", "N", "N")] == 300
        assert counts[("dogs", "are", "animals", ".", "N", "V", "N", "N")] == 300
        assert counts[("snakes", "are", "animals", ".", "N", "V", "N", "N")] == 300
        assert counts[("birds", "are", "animals", ".", "N", "V", "N", "N")] == 300

    def test_with_iterator(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader, num_workers=2)
        instances = reader.read(self.identical_files_glob)

        iterator = BasicIterator(batch_size=32)
        iterator.index_with(self.vocab)

        batches = [batch for batch in iterator(instances, num_epochs=1)]

        # 400 instances / batch_size 32 = 12 full batches + 1 batch of 16
        sizes = sorted([len(batch["tags"]) for batch in batches])
        assert sizes == [16] + 12 * [32]
Example #33
class TestMultiprocessDatasetReader(AllenNlpTestCase):
    def setUp(self) -> None:
        super().setUp()

        # use SequenceTaggingDatasetReader as the base reader
        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv'

        # Make 100 copies of the data
        raw_data = open(base_file_path).read()
        for i in range(100):
            file_path = self.TEST_DIR / f'sequence_tagging_{i}.tsv'
            with open(file_path, 'w') as f:
                f.write(raw_data)

        self.glob = str(self.TEST_DIR / 'sequence_tagging_*.tsv')

        # For some of the tests we need a vocab, we'll just use the base_reader for that.
        self.vocab = Vocabulary.from_instances(
            self.base_reader.read(str(base_file_path)))

    def test_multiprocess_read(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader,
                                           num_workers=4)

        all_instances = []

        for instance in reader.read(self.glob):
            all_instances.append(instance)

        # 100 files * 4 sentences / file
        assert len(all_instances) == 100 * 4

        counts = Counter(fingerprint(instance) for instance in all_instances)

        # should have the exact same data 100 times
        assert len(counts) == 4
        assert counts[("cats", "are", "animals", ".", "N", "V", "N",
                       "N")] == 100
        assert counts[("dogs", "are", "animals", ".", "N", "V", "N",
                       "N")] == 100
        assert counts[("snakes", "are", "animals", ".", "N", "V", "N",
                       "N")] == 100
        assert counts[("birds", "are", "animals", ".", "N", "V", "N",
                       "N")] == 100

    def test_multiple_epochs(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader,
                                           num_workers=2,
                                           epochs_per_read=3)

        all_instances = []

        for instance in reader.read(self.glob):
            all_instances.append(instance)

        # 100 files * 4 sentences per file * 3 epochs
        assert len(all_instances) == 100 * 4 * 3

        counts = Counter(fingerprint(instance) for instance in all_instances)

        # should have the exact same data 100 * 3 times
        assert len(counts) == 4
        assert counts[("cats", "are", "animals", ".", "N", "V", "N",
                       "N")] == 300
        assert counts[("dogs", "are", "animals", ".", "N", "V", "N",
                       "N")] == 300
        assert counts[("snakes", "are", "animals", ".", "N", "V", "N",
                       "N")] == 300
        assert counts[("birds", "are", "animals", ".", "N", "V", "N",
                       "N")] == 300

    def test_with_iterator(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader,
                                           num_workers=2)
        instances = reader.read(self.glob)

        iterator = BasicIterator(batch_size=32)
        iterator.index_with(self.vocab)

        batches = [batch for batch in iterator(instances, num_epochs=1)]

        # 400 instances / batch_size 32 = 12 full batches + 1 batch of 16
        sizes = sorted([len(batch['tags']) for batch in batches])
        assert sizes == [16] + 12 * [32]
Example #34
class TestMultiprocessDatasetReader(AllenNlpTestCase):
    def setUp(self) -> None:
        super().setUp()

        # use SequenceTaggingDatasetReader as the base reader
        self.base_reader = SequenceTaggingDatasetReader(lazy=True)
        base_file_path = AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv'


        # Make 100 copies of the data
        raw_data = open(base_file_path).read()
        for i in range(100):
            file_path = self.TEST_DIR / f'sequence_tagging_{i}.tsv'
            with open(file_path, 'w') as f:
                f.write(raw_data)

        self.glob = str(self.TEST_DIR / 'sequence_tagging_*.tsv')

        # For some of the tests we need a vocab, we'll just use the base_reader for that.
        self.vocab = Vocabulary.from_instances(self.base_reader.read(str(base_file_path)))

    def test_multiprocess_read(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader, num_workers=4)

        all_instances = []

        for instance in reader.read(self.glob):
            all_instances.append(instance)

        # 100 files * 4 sentences / file
        assert len(all_instances) == 100 * 4

        counts = Counter(fingerprint(instance) for instance in all_instances)

        # should have the exact same data 100 times
        assert len(counts) == 4
        assert counts[("cats", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("dogs", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("snakes", "are", "animals", ".", "N", "V", "N", "N")] == 100
        assert counts[("birds", "are", "animals", ".", "N", "V", "N", "N")] == 100

    def test_multiple_epochs(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader,
                                           num_workers=2,
                                           epochs_per_read=3)

        all_instances = []

        for instance in reader.read(self.glob):
            all_instances.append(instance)

        # 100 files * 4 sentences per file * 3 epochs
        assert len(all_instances) == 100 * 4 * 3

        counts = Counter(fingerprint(instance) for instance in all_instances)

        # should have the exact same data 100 * 3 times
        assert len(counts) == 4
        assert counts[("cats", "are", "animals", ".", "N", "V", "N", "N")] == 300
        assert counts[("dogs", "are", "animals", ".", "N", "V", "N", "N")] == 300
        assert counts[("snakes", "are", "animals", ".", "N", "V", "N", "N")] == 300
        assert counts[("birds", "are", "animals", ".", "N", "V", "N", "N")] == 300

    def test_with_iterator(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader, num_workers=2)
        instances = reader.read(self.glob)

        iterator = BasicIterator(batch_size=32)
        iterator.index_with(self.vocab)

        batches = [batch for batch in iterator(instances, num_epochs=1)]

        # 400 instances / batch_size 32 = 12 full batches + 1 batch of 16
        sizes = sorted([len(batch['tags']) for batch in batches])
        assert sizes == [16] + 12 * [32]