コード例 #1
0
ファイル: trainer_test.py プロジェクト: vincentqb/allennlp
    def test_should_stop_early_with_flat_lining_metric(self):
        """Check that a metric which never changes triggers early stopping.

        Six identical values are fed to the trainer's metric tracker with
        ``patience=5``, once with ``validation_metric="+test"`` and once
        with ``validation_metric="-test"``; early stopping is expected in
        both configurations.
        """
        flatline = [0.2] * 6
        for validation_metric in ("+test", "-test"):
            tracker = Trainer(
                self.model,
                self.optimizer,
                self.data_loader,
                validation_data_loader=self.validation_data_loader,
                num_epochs=3,
                serialization_dir=self.TEST_DIR,
                patience=5,
                validation_metric=validation_metric,
            )._metric_tracker
            tracker.add_metrics(flatline)
            # should_stop_early is a method on the metric tracker: it has to
            # be called. Asserting the bare attribute only checks the
            # truthiness of a bound method, which passes unconditionally.
            assert tracker.should_stop_early(), validation_metric
コード例 #2
0
    def test_should_stop_early_with_flat_lining_metric(self):
        """Check that a metric which never changes triggers early stopping.

        Six identical values are fed to the trainer's metric tracker with
        ``patience=5``, once with ``validation_metric="+test"`` and once
        with ``validation_metric="-test"``; early stopping is expected in
        both configurations.
        """
        # pylint: disable=protected-access
        flatline = [.2] * 6
        for validation_metric in ("+test", "-test"):
            tracker = Trainer(self.model,
                              self.optimizer,
                              self.iterator,
                              self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3,
                              serialization_dir=self.TEST_DIR,
                              patience=5,
                              validation_metric=validation_metric)._metric_tracker
            tracker.add_metrics(flatline)
            # should_stop_early is a method on the metric tracker: it has to
            # be called. Asserting the bare attribute only checks the
            # truthiness of a bound method, which passes unconditionally.
            assert tracker.should_stop_early(), validation_metric