def test_epoch_checker():
    """Only ``max_epochs`` is configured, so the checker must run to the last epoch.

    Checkpoint values climb from 0.1 in steps of 0.1 until the checker reports
    ``done``. The model-saving hook is replaced with a mock so nothing is written.
    """
    checker = AcceptanceChecker(ModelManager("test_model"), max_epochs=5)
    checker._save_model = MagicMock()

    metric = 0.1
    while not checker.done:
        checker.add_checkpoint(None, metric)
        metric = metric + 0.1

    # Stopping is driven purely by the epoch limit.
    assert checker.current_epoch == 5
    # assert_called_with inspects only the most recent call. The final save is
    # expected to receive None, the same value passed as the first argument of
    # add_checkpoint (presumably "no revision").
    checker._save_model.assert_called_with(None)
def test_threshold_epoch_checker():
    """A threshold that is never reached must leave ``max_epochs`` in charge.

    The checkpoint values grow from 0.1 by 0.1 per epoch, staying far below the
    threshold of 100. The checker should therefore stop on the epoch limit (10).
    """
    checker = AcceptanceChecker(
        ModelManager("test_model"),
        max_epochs=10,
        threshold=100,
        comparator=acc_fun,
    )
    checker._save_model = MagicMock()

    metric = 0.1
    while not checker.done:
        checker.add_checkpoint(None, metric)
        metric = metric + 0.1

    # Tenth epoch: the epoch limit, not the threshold, ended the run.
    assert checker.current_epoch == 10
    # Only the most recent save call is checked.
    checker._save_model.assert_called_with(None)
def test_comparator():
    """Check that a loss-style comparator stops the checker at its threshold.

    ``loss_fn`` is the comparator (presumably lower-is-better, as for a loss).
    The checkpoint value starts at 1 and drops by 0.1 each epoch. The checker
    should stop via the 0.7 threshold well before ``max_epochs`` (10).
    """
    mm = ModelManager("test_model")
    ac = AcceptanceChecker(mm, max_epochs=10, threshold=0.7, comparator=loss_fn)
    ac._save_model = MagicMock()
    value = 1
    while not ac.done:
        ac.add_checkpoint(None, value)
        # Repeated float subtraction yields 1, 0.9, 0.8, 0.7000000000000001, ...
        value -= 0.1
    # The fourth checkpoint (value ~0.7) is where THRESHOLD is expected to end the run.
    # NOTE(review): 0.8 - 0.1 == 0.7000000000000001, slightly above 0.7; whether
    # that counts as meeting the threshold depends on loss_fn — confirm.
    assert ac.current_epoch == 4
    # Only the most recent save call is checked.
    ac._save_model.assert_called_with(None)
def test_plateau_epoch():
    """Feed a series that plateaus at 0.4 and check the epoch the checker stops on.

    NOTE(review): with ``max_epochs=5`` and ``plateau_count=5``, only epoch 5 is
    non-improving by the time the epoch limit is hit. Under the usual
    "N checkpoints without improvement" reading, plateau detection cannot fire
    here, so this test may only exercise the epoch limit. Confirm the intended
    configuration.
    """
    mm = ModelManager("test_model")
    ac = AcceptanceChecker(mm, max_epochs=5, plateau_count=5, comparator=acc_fun)
    ac._save_model = MagicMock()
    # Checkpoint values for epochs 1-7: rising until epoch 4, then flat at 0.4.
    values = [0.1, 0.2, 0.3, 0.4, 0.4, 0.4, 0.4]
    for value in values:
        if ac.done:
            break
        ac.add_checkpoint(None, value)
    # Fail with a clear message, rather than an IndexError, if the checker
    # never finishes within the supplied values.
    assert ac.done, "AcceptanceChecker did not finish within the supplied values"
    assert ac.current_epoch == 5
    # NOTE(review): the original comment also expected "a rename and some
    # deletes" at this point; nothing in this test asserts on that.
    ac._save_model.assert_called_with(None)
def setup_acceptance_checker(self) -> None:
    """
    Creates an acceptance checker based on the parameters in the training config.

    Reads the ``stopping`` section of ``self.training_config`` (defaulting to an
    empty dict when the section is absent):

    * ``plateau_abs_tol`` (required) - forwarded as the third argument of
      ``acceptance_loss_comparator`` on every comparison.
    * ``threshold`` / ``plateau_count`` (optional) - passed straight through to
      ``AcceptanceChecker``; ``None`` when absent.

    The resulting checker is stored on ``self.acceptance_checker``.

    :raises KeyError: if ``stopping.plateau_abs_tol`` is missing from the config.
    """
    stopping_options = self.training_config.get('stopping', {})
    try:
        tol = stopping_options['plateau_abs_tol']
    except KeyError as err:
        # The 'stopping' section is read with a {} default, so without this a
        # missing section surfaces as a bare KeyError('plateau_abs_tol') that
        # gives no hint where the key belongs in the config.
        raise KeyError(
            "training config is missing required 'stopping.plateau_abs_tol' "
            "(absolute tolerance passed to acceptance_loss_comparator)"
        ) from err
    self.acceptance_checker = AcceptanceChecker(
        self.model_manager,
        max_epochs=self.max_epochs,
        threshold=stopping_options.get('threshold'),
        plateau_count=stopping_options.get('plateau_count'),
        comparator=lambda x, y: acceptance_loss_comparator(x, y, tol),
    )