def test_epoch_checker():
    """With only max_epochs set, the checker should run through epoch 5."""
    manager = ModelManager("test_model")
    checker = AcceptanceChecker(manager, max_epochs=5)

    # Replace the save hook with a mock so its calls can be inspected.
    save_spy = MagicMock()
    checker._save_model = save_spy

    # Feed a metric that grows by 0.1 per checkpoint until the checker stops.
    metric = 0.1
    while not checker.done:
        checker.add_checkpoint(None, metric)
        metric += 0.1

    # An epoch-only run should go all the way to max_epochs, and the most
    # recent save call should have been made with None (no revision).
    assert checker.current_epoch == 5
    save_spy.assert_called_with(None)
def test_threshold_epoch_checker():
    """With threshold=100 and max_epochs=10, the run should end at epoch 10."""
    manager = ModelManager("test_model")
    checker = AcceptanceChecker(
        manager,
        max_epochs=10,
        threshold=100,
        comparator=acc_fun,
    )

    # Replace the save hook with a mock so its calls can be inspected.
    save_spy = MagicMock()
    checker._save_model = save_spy

    # Metric starts at 0.1 and grows by 0.1 per checkpoint, staying far
    # below the configured threshold of 100.
    metric = 0.1
    while not checker.done:
        checker.add_checkpoint(None, metric)
        metric += 0.1

    # Stopping is expected at the TENTH epoch, i.e. because of max_epochs.
    assert checker.current_epoch == 10
    save_spy.assert_called_with(None)
def test_comparator():
    """Exercise a loss-style comparator: a falling metric with threshold=0.7."""
    manager = ModelManager("test_model")
    checker = AcceptanceChecker(
        manager,
        max_epochs=10,
        threshold=0.7,
        comparator=loss_fn,
    )

    # Replace the save hook with a mock so its calls can be inspected.
    save_spy = MagicMock()
    checker._save_model = save_spy

    # Metric starts at 1 and shrinks by 0.1 per checkpoint.
    metric = 1
    while not checker.done:
        checker.add_checkpoint(None, metric)
        metric -= 0.1

    # The run is expected to stop at epoch 4, well before max_epochs=10 --
    # presumably because the THRESHOLD was met (depends on loss_fn; confirm).
    assert checker.current_epoch == 4
    save_spy.assert_called_with(None)
def test_plateau_epoch():
    """With plateau_count equal to max_epochs (both 5), expect a stop at epoch 5."""
    manager = ModelManager("test_model")
    checker = AcceptanceChecker(
        manager,
        max_epochs=5,
        plateau_count=5,
        comparator=acc_fun,
    )

    # Replace the save hook with a mock so its calls can be inspected.
    save_spy = MagicMock()
    checker._save_model = save_spy

    # Checkpoint:  1    2    3    4    5    6    7
    metrics = [0.1, 0.2, 0.3, 0.4, 0.4, 0.4, 0.4]
    cursor = 0
    while not checker.done:
        # Indexing past the end raises IndexError if the checker never stops.
        checker.add_checkpoint(None, metrics[cursor])
        cursor += 1

    # Expected to stop at the 5th epoch. With plateau_count=5 the epoch limit
    # is presumably what ends the run here rather than the plateau -- confirm
    # against AcceptanceChecker.
    assert checker.current_epoch == 5
    save_spy.assert_called_with(None)
# Example #5
# 0
 def setup_acceptance_checker(self) -> None:
     """
     Builds self.acceptance_checker from the 'stopping' section of the training config.

     'threshold' and 'plateau_count' are optional (None is passed when absent).
     'plateau_abs_tol' is read by direct indexing, so it must be present in the
     'stopping' section; with a plain dict a missing key raises KeyError.
     """
     stopping = self.training_config.get('stopping', {})
     abs_tol = stopping['plateau_abs_tol']

     def compare(x, y):
         # Supply the configured tolerance as the comparator's third argument.
         return acceptance_loss_comparator(x, y, abs_tol)

     self.acceptance_checker = AcceptanceChecker(
         self.model_manager,
         max_epochs=self.max_epochs,
         threshold=stopping.get('threshold'),
         plateau_count=stopping.get('plateau_count'),
         comparator=compare,
     )