def test_should_stop_early_with_increasing_metric(self):
    """Exercise ``Trainer._should_stop_early`` with a ``"+test"`` metric and ``patience=5``.

    Two hand-written metric histories are checked:

    * one whose highest value is the very first entry, followed by five
      lower values -- the trainer is expected to report a stop;
    * one whose highest value is the second-to-last entry -- the trainer
      is expected not to report a stop.
    """
    trainer_options = {
        "validation_dataset": self.instances,
        "num_epochs": 3,
        "serialization_dir": self.TEST_DIR,
        "patience": 5,
        "validation_metric": "+test",
    }
    trainer = Trainer(self.model, self.optimizer, self.iterator, self.instances,
                      **trainer_options)

    # Best value (.5) comes first; none of the five later entries beats it.
    peak_at_start = [.5, .3, .2, .1, .4, .4]
    # Best value (.5) shows up as the second-to-last entry.
    peak_near_end = [.3, .3, .3, .2, .5, .1]

    assert trainer._should_stop_early(peak_at_start)  # pylint: disable=protected-access
    assert not trainer._should_stop_early(peak_near_end)  # pylint: disable=protected-access
def test_should_stop_early_with_early_stopping_disabled(self):
    """Exercise ``Trainer._should_stop_early`` when ``patience=None``.

    A fresh trainer is built for each metric spec and handed a 20-entry
    history that moves steadily against that metric's direction; the test
    expects no stop to be reported in either case.
    """
    history_length = 20
    cases = (
        # Increasing metric, paired with a strictly decreasing history (19.0 .. 0.0).
        ("+test", [float(step) for step in range(history_length - 1, -1, -1)]),
        # Decreasing metric, paired with a strictly increasing history (0.0 .. 19.0).
        ("-test", [float(step) for step in range(history_length)]),
    )

    for metric_spec, history in cases:
        trainer = Trainer(self.model, self.optimizer, self.iterator, self.instances,
                          validation_dataset=self.instances,
                          num_epochs=100,
                          patience=None,
                          validation_metric=metric_spec)
        assert not trainer._should_stop_early(history)  # pylint: disable=protected-access