def test_epoch_checker():
    """Run a checker configured only with ``max_epochs=5`` until it is done.

    Feeds steadily increasing values (0.1, 0.2, ...) and expects the checker
    to finish on epoch 5 with ``_save_model`` last invoked with ``None``.
    """
    manager = ModelManager("test_model")
    checker = AcceptanceChecker(manager, max_epochs=5)
    # Replace the real save routine so the test can inspect how it is called.
    checker._save_model = MagicMock()

    metric = 0.1
    while not checker.done:
        checker.add_checkpoint(None, metric)
        metric += 0.1

    # The epoch-limited run should reach the final epoch, and the save hook
    # should not be handed a revision. Note that assert_called_with only
    # inspects the most recent call.
    assert checker.current_epoch == 5
    checker._save_model.assert_called_with(None)
def test_threshold_epoch_checker():
    """Run a checker with both a threshold and an epoch limit until done.

    The threshold (100) is far larger than any value fed in (0.1, 0.2, ...),
    so the run is expected to end because of ``max_epochs=10``.
    """
    manager = ModelManager("test_model")
    checker = AcceptanceChecker(
        manager,
        max_epochs=10,
        threshold=100,
        comparator=acc_fun,
    )
    # Replace the real save routine so the test can inspect how it is called.
    checker._save_model = MagicMock()

    metric = 0.1
    while not checker.done:
        checker.add_checkpoint(None, metric)
        metric += 0.1

    # Expected to stop on the TENTH epoch because of the EPOCH limit
    # (presumably the threshold is never satisfied -- depends on acc_fun).
    assert checker.current_epoch == 10
    checker._save_model.assert_called_with(None)
def test_comparator():
    """Exercise a checker with a custom comparator (``loss_fn``).

    Used for loss-like metrics and other comparator-driven checks: values
    start at 1 and shrink by 0.1 per checkpoint, with ``threshold=0.7``.
    """
    manager = ModelManager("test_model")
    checker = AcceptanceChecker(
        manager,
        max_epochs=10,
        threshold=0.7,
        comparator=loss_fn,
    )
    # Replace the real save routine so the test can inspect how it is called.
    checker._save_model = MagicMock()

    metric = 1
    while not checker.done:
        checker.add_checkpoint(None, metric)
        metric -= 0.1

    # The run is expected to stop early because of the THRESHOLD, at epoch 4.
    # NOTE(review): an earlier comment said "sixth epoch"; the assertion below
    # checks 4 -- confirm which one is intended.
    assert checker.current_epoch == 4
    checker._save_model.assert_called_with(None)
def test_plateau_epoch():
    """Run a checker with ``plateau_count=5`` and ``max_epochs=5`` until done.

    Feeds a fixed sequence that rises and then flattens at 0.4, and expects
    the checker to finish on epoch 5.
    """
    manager = ModelManager("test_model")
    checker = AcceptanceChecker(
        manager,
        max_epochs=5,
        plateau_count=5,
        comparator=acc_fun,
    )
    # Replace the real save routine so the test can inspect how it is called.
    checker._save_model = MagicMock()

    # epoch:    1    2    3    4    5    6    7
    samples = [0.1, 0.2, 0.3, 0.4, 0.4, 0.4, 0.4]
    position = 0
    while not checker.done:
        checker.add_checkpoint(None, samples[position])
        position += 1

    # The run should end on the 5th epoch.
    # NOTE(review): an earlier comment spoke of "a plateau of 3" followed by
    # "a rename and some deletes", but plateau_count is 5 and the assertion
    # below only checks that the last save call received None -- confirm.
    assert checker.current_epoch == 5
    checker._save_model.assert_called_with(None)