class EpochTrackingBucketIteratorTest(IteratorTest):
    """Tests for how ``EpochTrackingBucketIterator`` numbers epochs for each dataset."""

    def setUp(self):
        # ``IteratorTest.setUp`` builds ``self.instances`` (instances with TextFields).
        # ``self.vocab`` is read below and presumably comes from there as well.
        super(EpochTrackingBucketIteratorTest, self).setUp()
        self.iterator = EpochTrackingBucketIterator(
            sorting_keys=[["text", "num_tokens"]])
        self.iterator.index_with(self.vocab)
        # A second, independent dataset of three sentences. It lets the test check
        # that epochs are counted separately for each dataset.
        second_dataset_tokens = [
            ["this", "is", "a", "sentence"],
            ["this", "is", "in", "the", "second", "dataset"],
            ["so", "is", "this", "one"],
        ]
        self.more_instances = [self.create_instance(tokens)
                               for tokens in second_dataset_tokens]

    def test_iterator_tracks_epochs_per_dataset(self):
        """Each dataset's batches are tagged with epoch numbers starting from 0."""
        first_batches = list(self.iterator(self.instances, num_epochs=2))
        second_batches = list(self.iterator(self.more_instances, num_epochs=2))

        # Pairs of (batches, instance count). The first dataset has five sentences
        # (see ``IteratorTest.setUp``); the second has three. Batch ``epoch`` must
        # carry that epoch number once per instance.
        for batches, size in ((first_batches, 5), (second_batches, 3)):
            for epoch in range(2):
                assert batches[epoch]["epoch_num"] == [epoch] * size
 def setUp(self):
     # The superclass populates ``self.instances`` with TextField instances.
     # ``self.vocab`` is read below and presumably comes from there as well.
     super(EpochTrackingBucketIteratorTest, self).setUp()
     iterator = EpochTrackingBucketIterator(sorting_keys=[["text", "num_tokens"]])
     self.iterator = iterator
     iterator.index_with(self.vocab)
     # Token sequences for a second dataset of three sentences.
     second_dataset_tokens = (
         ["this", "is", "a", "sentence"],
         ["this", "is", "in", "the", "second", "dataset"],
         ["so", "is", "this", "one"],
     )
     self.more_instances = [self.create_instance(tokens)
                            for tokens in second_dataset_tokens]
Example #3
0
 def test_forward_with_epoch_num_changes_cost_weight(self):
     """Check that ``forward`` updates the checklist cost weight as epochs advance.

     A fresh model is built from ``self.param_file``. It is run over ``self.dataset``
     for four epochs, and its ``_checklist_cost_weight`` is recorded after every
     ``forward`` call. Four expected values for four epochs means the test relies
     on one batch per epoch.
     """
     # Build a separate model so that ``self.model`` keeps its state.
     model_params = Params.from_file(self.param_file)
     fresh_model = Model.from_params(vocab=self.vocab, params=model_params[u'model'])
     # Cost weight before any forward pass has run.
     assert fresh_model._checklist_cost_weight == 0.8

     # The config sets ``wait_num_epochs`` to 0, so the weight starts decreasing at
     # epoch 0. Each value below is the previous one times 0.9, starting from 0.8.
     expected_weights = [0.72, 0.648, 0.5832, 0.52488]

     epoch_iterator = EpochTrackingBucketIterator(sorting_keys=[[u'sentence', u'num_tokens']])
     observed_weights = []
     for batch in epoch_iterator(self.dataset, num_epochs=4):
         fresh_model.forward(**batch)
         observed_weights.append(fresh_model._checklist_cost_weight)
     assert_almost_equal(observed_weights, expected_weights)