class TestEarlyStopping(unittest.TestCase):
    """Exercise the warm-up, bookkeeping and stopping decision of ``EarlyStopper``."""

    #: The window size used by the early stopper
    patience: int = 2
    #: The mock losses the mock evaluator will return
    mock_losses: List[float] = [10.0, 9.0, 8.0, 8.0, 8.0, 8.0]
    #: The (zeroed) index - 1 at which stopping will occur
    stop_constant: int = 4
    #: The minimum improvement
    delta: float = 0.0

    def setUp(self):
        """Wire a mock evaluator and a mock model on Nations into a fresh early stopper."""
        self.mock_evaluator = MockEvaluator(self.mock_losses)
        # Set automatic_memory_optimization to false for tests
        dataset = Nations()
        self.model = MockModel(
            triples_factory=dataset.training,
            automatic_memory_optimization=False,
        )
        self.stopper = EarlyStopper(
            model=self.model,
            evaluator=self.mock_evaluator,
            evaluation_triples_factory=dataset.validation,
            patience=self.patience,
            delta=self.delta,
            larger_is_better=False,
        )

    def test_initialization(self):
        """During the first ``patience`` evaluations the stopper must never ask to stop."""
        for expected_evaluations in range(1, self.patience + 1):
            stopped = self.stopper.should_stop()
            assert self.stopper.number_evaluations == expected_evaluations
            assert not stopped

    def test_result_processing(self):
        """Each evaluated loss must be recorded, and the buffer must hold the latest window."""
        for seen in range(1, len(self.mock_losses) + 1):
            # Step early stopper; nothing is checked on a step that requests stopping
            if self.stopper.should_stop():
                continue
            # check storing of results
            assert self.stopper.results == self.mock_losses[:seen]
            # check ring buffer, once at least ``patience`` losses have been seen
            if seen < self.patience:
                continue
            window = self.mock_losses[seen - self.patience:seen]
            assert set(self.stopper.buffer) == set(window)

    def test_should_stop(self):
        """The stopper keeps going for ``stop_constant`` evaluations, then asks to stop."""
        remaining = self.stop_constant
        while remaining > 0:
            self.assertFalse(self.stopper.should_stop())
            remaining -= 1
        self.assertTrue(self.stopper.should_stop())
def test_serialization(self):
    """Check that a summary dict round-trips into a fresh ``EarlyStopper``.

    The ``...`` placeholder arguments are never read by this test; only the
    attributes named in the summary dict are compared between the original
    stopper and the one restored via ``_write_from_summary_dict``.
    """
    summary = self.stopper.get_summary_dict()
    new_stopper = EarlyStopper(
        # not needed for test
        model=...,
        evaluator=...,
        training_triples_factory=...,
        evaluation_triples_factory=...,
    )
    new_stopper._write_from_summary_dict(**summary)
    # subTest reports every mismatching key instead of aborting at the first one,
    # and assertEqual (unlike a bare assert) is not stripped under ``python -O``.
    for key in summary:
        with self.subTest(key=key):
            self.assertEqual(getattr(self.stopper, key), getattr(new_stopper, key))
def setUp(self):
    """Create the mock evaluator, mock model and early stopper shared by each test."""
    # Set automatic_memory_optimization to false for tests
    evaluator = MockEvaluator(self.mock_losses, automatic_memory_optimization=False)
    self.mock_evaluator = evaluator
    dataset = Nations()
    self.model = MockModel(triples_factory=dataset.training)
    self.stopper = EarlyStopper(
        model=self.model,
        evaluator=evaluator,
        evaluation_triples_factory=dataset.validation,
        patience=self.patience,
        relative_delta=self.delta,
        larger_is_better=False,
    )
def setUp(self):
    """Build an early stopper that tracks ``hits_at_10`` from a mock evaluator on Nations."""
    # Set automatic_memory_optimization to false for tests
    metric_key = ("hits_at_10", SIDE_BOTH, RANK_REALISTIC)
    evaluator = MockEvaluator(
        key=metric_key,
        values=self.mock_losses,
        automatic_memory_optimization=False,
    )
    self.mock_evaluator = evaluator
    dataset = Nations()
    training = dataset.training
    self.model = FixedModel(triples_factory=training)
    self.stopper = EarlyStopper(
        model=self.model,
        evaluator=evaluator,
        training_triples_factory=training,
        evaluation_triples_factory=dataset.validation,
        patience=self.patience,
        relative_delta=self.delta,
        larger_is_better=False,
    )
def setUp(self):
    """Build an early stopper that evaluates every epoch (``frequency=1``) with no explicit metric."""
    # Set automatic_memory_optimization to false for tests
    evaluator = MockEvaluator(
        key=None,
        values=self.mock_losses,
        automatic_memory_optimization=False,
    )
    self.mock_evaluator = evaluator
    dataset = Nations()
    # kept under the historical attribute name; it holds the whole Nations dataset
    self.triples_factory = dataset
    self.model = FixedModel(triples_factory=dataset.training)
    self.stopper = EarlyStopper(
        metric=None,
        model=self.model,
        evaluator=evaluator,
        training_triples_factory=dataset.training,
        evaluation_triples_factory=dataset.validation,
        patience=self.patience,
        relative_delta=self.delta,
        larger_is_better=False,
        frequency=1,
    )
class TestEarlyStopper(unittest.TestCase):
    """Tests for early stopping."""

    #: The window size used by the early stopper
    patience: int = 2
    #: The mock losses the mock evaluator will return
    mock_losses: List[float] = [10.0, 9.0, 8.0, 9.0, 8.0, 8.0]
    #: The zero-based epoch at which ``should_stop`` is expected to first return True
    stop_constant: int = 4
    #: The minimum improvement
    delta: float = 0.0
    #: The best metric expected after each epoch 0..stop_constant (one entry per epoch)
    best_results: List[float] = [10.0, 9.0, 8.0, 8.0, 8.0]

    def setUp(self):
        """Prepare for testing the early stopper."""
        # Set automatic_memory_optimization to false for tests
        self.mock_evaluator = MockEvaluator(
            key=("hits_at_10", SIDE_BOTH, RANK_REALISTIC),
            values=self.mock_losses,
            automatic_memory_optimization=False,
        )
        nations = Nations()
        self.model = FixedModel(triples_factory=nations.training)
        self.stopper = EarlyStopper(
            model=self.model,
            evaluator=self.mock_evaluator,
            training_triples_factory=nations.training,
            evaluation_triples_factory=nations.validation,
            patience=self.patience,
            relative_delta=self.delta,
            larger_is_better=False,
        )

    def test_initialization(self):
        """Test warm-up phase."""
        for epoch in range(self.patience):
            self.assertFalse(self.stopper.should_stop(epoch=epoch))

    def test_result_processing(self):
        """Test that the mock evaluation of the early stopper always gives the right loss."""
        for epoch in range(len(self.mock_losses)):
            # Step early stopper
            if self.stopper.should_stop(epoch=epoch):
                # A training loop would halt here. ``best_results`` only covers epochs up to
                # the stopping point, so stepping further would index past its end.
                break
            # check storing of results
            self.assertEqual(self.mock_losses[: epoch + 1], self.stopper.results)
            self.assertEqual(self.best_results[epoch], self.stopper.best_metric)

    def test_should_stop(self):
        """Test that the stopper knows when to stop."""
        for epoch in range(self.stop_constant):
            self.assertFalse(self.stopper.should_stop(epoch=epoch))
        self.assertTrue(self.stopper.should_stop(epoch=self.stop_constant))

    def test_result_logging(self):
        """Test whether result logger is called properly."""
        self.stopper.result_tracker = mock_tracker = Mock()
        self.stopper.should_stop(epoch=0)
        log_metrics = mock_tracker.log_metrics
        self.assertIsInstance(log_metrics, Mock)
        log_metrics.assert_called_once()
        # Only the keyword arguments of the single call are inspected.
        _, call_args = log_metrics.call_args_list[0]
        self.assertIn("step", call_args)
        self.assertEqual(0, call_args["step"])
        self.assertIn("prefix", call_args)
        self.assertEqual("validation", call_args["prefix"])

    def test_serialization(self):
        """Test for serialization."""
        summary = self.stopper.get_summary_dict()
        new_stopper = EarlyStopper(
            # not needed for test
            model=...,
            evaluator=...,
            training_triples_factory=...,
            evaluation_triples_factory=...,
        )
        new_stopper._write_from_summary_dict(**summary)
        # subTest reports every mismatching key instead of aborting at the first one.
        for key in summary:
            with self.subTest(key=key):
                self.assertEqual(getattr(self.stopper, key), getattr(new_stopper, key))