コード例 #1
0
class TestEarlyStopping(unittest.TestCase):
    """Tests for early stopping.

    Every test drives an ``EarlyStopper`` that is wired to a ``MockEvaluator``
    built from :attr:`mock_losses` and configured with
    ``larger_is_better=False``.
    """

    #: The window size used by the early stopper
    patience: int = 2
    #: The mock losses the mock evaluator will return
    mock_losses: List[float] = [10.0, 9.0, 8.0, 8.0, 8.0, 8.0]
    #: The number of initial ``should_stop()`` calls that must return False;
    #: the call right after them (zero-based index ``stop_constant``) must return True
    stop_constant: int = 4
    #: The minimum improvement
    delta: float = 0.0

    def setUp(self):
        """Prepare for testing the early stopper."""
        self.mock_evaluator = MockEvaluator(self.mock_losses)
        nations = Nations()
        # Set automatic_memory_optimization to false for tests
        self.model = MockModel(
            triples_factory=nations.training,
            automatic_memory_optimization=False,
        )
        self.stopper = EarlyStopper(
            model=self.model,
            evaluator=self.mock_evaluator,
            evaluation_triples_factory=nations.validation,
            patience=self.patience,
            delta=self.delta,
            larger_is_better=False,
        )

    def test_initialization(self):
        """Test warm-up phase: the first ``patience`` evaluations never request a stop."""
        for it in range(self.patience):
            should_stop = self.stopper.should_stop()
            # unittest assertions (instead of bare ``assert``) still run under
            # ``python -O`` and report both operands on failure.
            self.assertEqual(self.stopper.number_evaluations, it + 1)
            self.assertFalse(should_stop)

    def test_result_processing(self):
        """Test that the mock evaluation of the early stopper always gives the right loss."""
        for stop in range(1, 1 + len(self.mock_losses)):
            # Step early stopper
            should_stop = self.stopper.should_stop()

            # The stored results are only checked for steps that do not request a stop.
            if should_stop:
                continue

            # check storing of results
            self.assertEqual(self.stopper.results, self.mock_losses[:stop])

            # check ring buffer: once at least ``patience`` evaluations were made,
            # it must hold the same values as the last ``patience`` losses
            if stop >= self.patience:
                self.assertEqual(
                    set(self.stopper.buffer),
                    set(self.mock_losses[stop - self.patience:stop]),
                )

    def test_should_stop(self):
        """Test that the stopper knows when to stop."""
        for _ in range(self.stop_constant):
            self.assertFalse(self.stopper.should_stop())
        self.assertTrue(self.stopper.should_stop())
コード例 #2
0
ファイル: test_early_stopping.py プロジェクト: pykeen/pykeen
 def test_serialization(self):
     """Check that a summary dict written into a fresh stopper restores the same attribute values."""
     summary = self.stopper.get_summary_dict()
     # The constructor arguments are irrelevant here, so Ellipsis placeholders are passed.
     placeholder = ...
     restored = EarlyStopper(
         model=placeholder,
         evaluator=placeholder,
         training_triples_factory=placeholder,
         evaluation_triples_factory=placeholder,
     )
     restored._write_from_summary_dict(**summary)
     # Every key of the summary must name an attribute that is equal on both stoppers.
     for attribute in summary.keys():
         expected = getattr(self.stopper, attribute)
         actual = getattr(restored, attribute)
         assert expected == actual
コード例 #3
0
 def setUp(self):
     """Build the mock evaluator, the mock model, and the early stopper under test."""
     # Automatic memory optimization is switched off on the evaluator for tests
     evaluator = MockEvaluator(self.mock_losses, automatic_memory_optimization=False)
     self.mock_evaluator = evaluator
     dataset = Nations()
     self.model = MockModel(triples_factory=dataset.training)
     # Collect the stopper configuration first, then construct it in one call.
     stopper_options = dict(
         model=self.model,
         evaluator=self.mock_evaluator,
         evaluation_triples_factory=dataset.validation,
         patience=self.patience,
         relative_delta=self.delta,
         larger_is_better=False,
     )
     self.stopper = EarlyStopper(**stopper_options)
コード例 #4
0
ファイル: test_early_stopping.py プロジェクト: pykeen/pykeen
 def setUp(self):
     """Build the mock evaluator, the fixed model, and the early stopper under test."""
     # Automatic memory optimization is switched off on the evaluator for tests
     metric_key = ("hits_at_10", SIDE_BOTH, RANK_REALISTIC)
     self.mock_evaluator = MockEvaluator(
         key=metric_key,
         values=self.mock_losses,
         automatic_memory_optimization=False,
     )
     dataset = Nations()
     self.model = FixedModel(triples_factory=dataset.training)
     # Collect the stopper configuration first, then construct it in one call.
     stopper_options = dict(
         model=self.model,
         evaluator=self.mock_evaluator,
         training_triples_factory=dataset.training,
         evaluation_triples_factory=dataset.validation,
         patience=self.patience,
         relative_delta=self.delta,
         larger_is_better=False,
     )
     self.stopper = EarlyStopper(**stopper_options)
コード例 #5
0
 def setUp(self):
     """Build the mock evaluator, the fixed model, and an early stopper with ``frequency=1``."""
     # Automatic memory optimization is switched off on the evaluator for tests
     evaluator_options = dict(
         key=None,
         values=self.mock_losses,
         automatic_memory_optimization=False,
     )
     self.mock_evaluator = MockEvaluator(**evaluator_options)
     # NOTE: despite its name, this attribute holds the ``Nations()`` object itself;
     # its ``training`` and ``validation`` members are what get passed on below.
     self.triples_factory = Nations()
     self.model = FixedModel(triples_factory=self.triples_factory.training)
     # Collect the stopper configuration first, then construct it in one call.
     stopper_options = dict(
         metric=None,
         model=self.model,
         evaluator=self.mock_evaluator,
         training_triples_factory=self.triples_factory.training,
         evaluation_triples_factory=self.triples_factory.validation,
         patience=self.patience,
         relative_delta=self.delta,
         larger_is_better=False,
         frequency=1,
     )
     self.stopper = EarlyStopper(**stopper_options)
コード例 #6
0
ファイル: test_early_stopping.py プロジェクト: pykeen/pykeen
class TestEarlyStopper(unittest.TestCase):
    """Tests for early stopping.

    Every test drives an ``EarlyStopper`` that is wired to a ``MockEvaluator``
    built from :attr:`mock_losses` and configured with
    ``larger_is_better=False``.
    """

    #: The window size used by the early stopper
    patience: int = 2
    #: The mock losses the mock evaluator will return
    mock_losses: List[float] = [10.0, 9.0, 8.0, 9.0, 8.0, 8.0]
    #: The number of initial epochs for which ``should_stop`` must return False;
    #: the call for epoch ``stop_constant`` (zero-based) must return True
    stop_constant: int = 4
    #: The minimum improvement
    delta: float = 0.0
    #: The expected ``best_metric`` after each epoch, indexed by epoch
    #: (the running minimum of ``mock_losses``)
    best_results: List[float] = [10.0, 9.0, 8.0, 8.0, 8.0]

    def setUp(self):
        """Prepare for testing the early stopper."""
        # Set automatic_memory_optimization to false for tests
        self.mock_evaluator = MockEvaluator(
            key=("hits_at_10", SIDE_BOTH, RANK_REALISTIC),
            values=self.mock_losses,
            automatic_memory_optimization=False,
        )
        nations = Nations()
        self.model = FixedModel(triples_factory=nations.training)
        self.stopper = EarlyStopper(
            model=self.model,
            evaluator=self.mock_evaluator,
            training_triples_factory=nations.training,
            evaluation_triples_factory=nations.validation,
            patience=self.patience,
            relative_delta=self.delta,
            larger_is_better=False,
        )

    def test_initialization(self):
        """Test warm-up phase: the first ``patience`` epochs never request a stop."""
        for epoch in range(self.patience):
            # unittest assertions (instead of bare ``assert``) still run under
            # ``python -O`` and match the style of the other tests in this class.
            self.assertFalse(self.stopper.should_stop(epoch=epoch))

    def test_result_processing(self):
        """Test that the mock evaluation of the early stopper always gives the right loss."""
        # NOTE(review): ``best_results`` has one entry fewer than ``mock_losses``, so this
        # loop relies on ``should_stop`` returning True for the last epoch - confirm.
        for epoch in range(len(self.mock_losses)):
            # Step early stopper
            should_stop = self.stopper.should_stop(epoch=epoch)

            # The stored results are only checked for epochs that do not request a stop.
            if should_stop:
                continue

            # check storing of results
            self.assertEqual(self.stopper.results, self.mock_losses[:epoch + 1])
            self.assertEqual(self.stopper.best_metric, self.best_results[epoch])

    def test_should_stop(self):
        """Test that the stopper knows when to stop."""
        for epoch in range(self.stop_constant):
            self.assertFalse(self.stopper.should_stop(epoch=epoch))
        self.assertTrue(self.stopper.should_stop(epoch=self.stop_constant))

    def test_result_logging(self):
        """Test whether result logger is called properly."""
        self.stopper.result_tracker = mock_tracker = Mock()
        self.stopper.should_stop(epoch=0)
        log_metrics = mock_tracker.log_metrics
        self.assertIsInstance(log_metrics, Mock)
        log_metrics.assert_called_once()
        # Each recorded call unpacks into (positional args, keyword args);
        # only the keyword arguments are inspected here.
        _, call_kwargs = log_metrics.call_args_list[0]
        self.assertIn("step", call_kwargs)
        self.assertEqual(0, call_kwargs["step"])
        self.assertIn("prefix", call_kwargs)
        self.assertEqual("validation", call_kwargs["prefix"])

    def test_serialization(self):
        """Test that a summary dict written into a fresh stopper restores the same attribute values."""
        summary = self.stopper.get_summary_dict()
        new_stopper = EarlyStopper(
            # not needed for test
            model=...,
            evaluator=...,
            training_triples_factory=...,
            evaluation_triples_factory=...,
        )
        new_stopper._write_from_summary_dict(**summary)
        for key in summary:
            self.assertEqual(
                getattr(self.stopper, key),
                getattr(new_stopper, key),
                msg=f"attribute {key!r} differs after deserialization",
            )