コード例 #1
0
ファイル: trainer_test.py プロジェクト: sanyu12/Bert_Attempt
 def test_should_stop_early_with_increasing_metric(self):
     """Check early-stopping decisions for an increasing ("+test") metric with patience=5.

     Two histories are checked: one whose early best value is never beaten
     (a stop is expected) and one whose best value arrives late (no stop).
     """
     trainer = Trainer(self.model,
                       self.optimizer,
                       self.iterator,
                       self.instances,
                       validation_dataset=self.instances,
                       num_epochs=3,
                       serialization_dir=self.TEST_DIR,
                       patience=5,
                       validation_metric="+test")
     should_stop = trainer._should_stop_early  # pylint: disable=protected-access

     # The first value (.5) is the highest and is never exceeded later on.
     stalled_history = [.5, .3, .2, .1, .4, .4]
     # The highest value (.5) appears near the end, above everything before it.
     improving_history = [.3, .3, .3, .2, .5, .1]

     assert should_stop(stalled_history)
     assert not should_stop(improving_history)
コード例 #2
0
ファイル: trainer_test.py プロジェクト: sanyu12/Bert_Attempt
    def test_should_stop_early_with_early_stopping_disabled(self):
        """With ``patience=None`` no early stop is expected for either metric direction.

        Each metric direction is paired with a strictly monotonic 20-epoch
        history running opposite to it. ``_should_stop_early`` must return a
        falsy value in both cases.
        """
        scenarios = (
            # "+test" (increasing metric) with a strictly decreasing history.
            ("+test", [float(epoch) for epoch in reversed(range(20))]),
            # "-test" (decreasing metric) with a strictly increasing history.
            ("-test", [float(epoch) for epoch in range(20)]),
        )
        for metric_spec, history in scenarios:
            trainer = Trainer(self.model,
                              self.optimizer,
                              self.iterator,
                              self.instances,
                              validation_dataset=self.instances,
                              num_epochs=100,
                              patience=None,
                              validation_metric=metric_spec)
            assert not trainer._should_stop_early(history)  # pylint: disable=protected-access