class EpochTrackingBucketIteratorTest(IteratorTest):
    """Tests that ``EpochTrackingBucketIterator`` reports epoch numbers separately per dataset."""

    def setUp(self):
        """Create the iterator under test and a second set of instances."""
        # The parent ``setUp`` provides ``self.instances`` (instances holding TextFields).
        super(EpochTrackingBucketIteratorTest, self).setUp()

        bucket_iterator = EpochTrackingBucketIterator(
            sorting_keys=[["text", "num_tokens"]]
        )
        self.iterator = bucket_iterator
        bucket_iterator.index_with(self.vocab)

        # Token lists for the additional instances that act as a second dataset.
        second_dataset_sentences = [
            ["this", "is", "a", "sentence"],
            ["this", "is", "in", "the", "second", "dataset"],
            ["so", "is", "this", "one"],
        ]
        self.more_instances = [
            self.create_instance(tokens) for tokens in second_dataset_sentences
        ]

    def test_iterator_tracks_epochs_per_dataset(self):
        """Run two epochs over each dataset and check the ``epoch_num`` of each batch."""
        # Both datasets are fully iterated before anything is asserted.
        first_batches = list(self.iterator(self.instances, num_epochs=2))
        second_batches = list(self.iterator(self.more_instances, num_epochs=2))

        # The first dataset has five sentences (see ``IteratorTest.setUp``), so each
        # batch is expected to carry five copies of its epoch number.
        for epoch in (0, 1):
            assert first_batches[epoch]["epoch_num"] == [epoch] * 5

        # The second dataset has three sentences, and its epoch count starts at 0 again.
        for epoch in (0, 1):
            assert second_batches[epoch]["epoch_num"] == [epoch] * 3
def setUp(self):
    """Create the iterator under test and a second set of instances."""
    # The parent ``setUp`` provides ``self.instances`` (instances holding TextFields).
    super(EpochTrackingBucketIteratorTest, self).setUp()

    bucket_iterator = EpochTrackingBucketIterator(
        sorting_keys=[["text", "num_tokens"]]
    )
    self.iterator = bucket_iterator
    bucket_iterator.index_with(self.vocab)

    # Token lists for the additional instances that act as a second dataset.
    second_dataset_sentences = [
        ["this", "is", "a", "sentence"],
        ["this", "is", "in", "the", "second", "dataset"],
        ["so", "is", "this", "one"],
    ]
    self.more_instances = [
        self.create_instance(tokens) for tokens in second_dataset_sentences
    ]
def test_forward_with_epoch_num_changes_cost_weight(self):
    """Check that the model's checklist cost weight changes across epochs during ``forward``."""
    # Build a separate model from the config so that ``self.model`` keeps its state.
    config = Params.from_file(self.param_file)
    fresh_model = Model.from_params(vocab=self.vocab, params=config[u'model'])

    # Cost weight before any call to ``forward``.
    assert fresh_model._checklist_cost_weight == 0.8

    epoch_iterator = EpochTrackingBucketIterator(
        sorting_keys=[[u'sentence', u'num_tokens']]
    )

    # Record the cost weight after every forward pass over four epochs.
    observed_weights = []
    for batch in epoch_iterator(self.dataset, num_epochs=4):
        fresh_model.forward(**batch)
        observed_weights.append(fresh_model._checklist_cost_weight)

    # The config file has ``wait_num_epochs`` set to 0, so the cost weight starts
    # decreasing at epoch 0 itself.
    expected_weights = [0.72, 0.648, 0.5832, 0.52488]
    assert_almost_equal(observed_weights, expected_weights)