def test_early_stopping_context_without_updating_loss(self):
    """Exit the context without ever calling `update`.

    Checks that `ever_updated` is false and that the value written to the
    first variable inside the context is still present afterwards.
    """
    with self.test_session():
        var_a, var_b, var_c = _populate_variables()

        with EarlyStopping([var_a, var_b]) as stopper:
            set_variable_values([var_a], [10])

        # No metric was reported, so the in-context write is expected
        # to survive the exit of the early-stopping context.
        self.assertFalse(stopper.ever_updated)
        final_values = get_variable_values([var_a, var_b, var_c])
        self.assertEqual(final_values, [10, 2, 3])
def test_do_not_restore_on_error(self):
    """With `restore_on_error=False`, an error leaves variables untouched.

    A `ValueError` is raised inside the context after the variables have
    been overwritten; the overwritten values must still be in place once
    the error has propagated, while the best metric stays recorded.
    """
    with self.test_session():
        var_a, var_b, var_c = _populate_variables()

        with pytest.raises(ValueError, match='value error'):
            with EarlyStopping([var_a, var_b],
                               restore_on_error=False) as stopper:
                self.assertTrue(stopper.update(1.))
                set_variable_values([var_a, var_b], [10, 20])
                raise ValueError('value error')

        # These checks run only after the error has left both contexts.
        self.assertAlmostEqual(stopper.best_metric, 1.)
        final_values = get_variable_values([var_a, var_b, var_c])
        self.assertEqual(final_values, [10, 20, 3])
def test_the_first_loss_will_always_cause_saving(self):
    """The first reported metric is accepted and its state is restored.

    Variables are changed again after the single `update` call; on exit
    the tracked variables must be back at the values they had when the
    metric was reported.
    """
    with self.test_session():
        var_a, var_b, var_c = _populate_variables()

        with EarlyStopping([var_a, var_b]) as stopper:
            set_variable_values([var_a], [10])
            self.assertTrue(stopper.update(1.))
            # Later writes that should be rolled back on exit.
            set_variable_values([var_a, var_b], [100, 20])

        self.assertTrue(stopper.ever_updated)
        self.assertAlmostEqual(stopper.best_metric, 1.)
        final_values = get_variable_values([var_a, var_b, var_c])
        self.assertEqual(final_values, [10, 2, 3])
def test_restore_on_keyboard_interrupt(self):
    """A `KeyboardInterrupt` inside the context still triggers a restore.

    The metric is reported before the variables are overwritten, so after
    the interrupt the tracked variables must hold their earlier values.
    """
    with self.test_session():
        var_a, var_b, var_c = _populate_variables()

        with pytest.raises(KeyboardInterrupt):
            with EarlyStopping([var_a, var_b]) as stopper:
                self.assertTrue(stopper.update(1.))
                set_variable_values([var_a, var_b], [10, 20])
                raise KeyboardInterrupt()

        # These checks run only after the interrupt has left both contexts.
        self.assertAlmostEqual(stopper.best_metric, 1.)
        final_values = get_variable_values([var_a, var_b, var_c])
        self.assertEqual(final_values, [1, 2, 3])
def test_initial_loss(self):
    """An `initial_metric` acts as the baseline that updates must beat.

    With a baseline of .6, reporting 1. is rejected and reporting .5 is
    accepted; the values present at the accepted update are the ones
    expected after the context exits.
    """
    with self.test_session():
        var_a, var_b, var_c = _populate_variables()

        with EarlyStopping([var_a, var_b], initial_metric=.6) as stopper:
            # A metric worse than the baseline is not accepted.
            set_variable_values([var_a], [10])
            self.assertFalse(stopper.update(1.))
            self.assertAlmostEqual(stopper.best_metric, .6)

            # A metric better than the baseline is accepted.
            set_variable_values([var_a, var_b], [100, 20])
            self.assertTrue(stopper.update(.5))
            self.assertAlmostEqual(stopper.best_metric, .5)

        final_values = get_variable_values([var_a, var_b, var_c])
        self.assertEqual(final_values, [100, 20, 3])
def test_cleanup_checkpoint_dir(self):
    """The checkpoint directory exists during the context and not after.

    While the context is active and a metric has been reported, a
    `latest` entry must exist under the checkpoint directory; once the
    context exits the directory itself must be gone.
    """
    with self.test_session():
        var_a, var_b, _ = _populate_variables()

        with TemporaryDirectory() as tempdir:
            checkpoint_dir = os.path.join(tempdir, '1')
            latest_path = os.path.join(checkpoint_dir, 'latest')

            with EarlyStopping([var_a, var_b],
                               checkpoint_dir=checkpoint_dir) as stopper:
                self.assertTrue(stopper.update(1.))
                self.assertTrue(os.path.exists(latest_path))

            self.assertFalse(os.path.exists(checkpoint_dir))
def test_bigger_is_better(self):
    """With `smaller_is_better=False`, larger metrics are the better ones.

    Reports .5, then 1., then .8: the first two are accepted, the last is
    rejected. After exit the tracked variables hold the values from the
    1. update, while the untracked third variable keeps its last write.
    """
    with self.test_session():
        var_a, var_b, var_c = _populate_variables()

        with EarlyStopping([var_a, var_b],
                           smaller_is_better=False) as stopper:
            set_variable_values([var_a], [10])
            self.assertTrue(stopper.update(.5))
            self.assertAlmostEqual(stopper.best_metric, .5)

            set_variable_values([var_a, var_b], [100, 20])
            self.assertTrue(stopper.update(1.))
            self.assertAlmostEqual(stopper.best_metric, 1.)

            # A smaller metric must not replace the current best.
            set_variable_values([var_a, var_b, var_c], [1000, 200, 30])
            self.assertFalse(stopper.update(.8))
            self.assertAlmostEqual(stopper.best_metric, 1.)

        self.assertAlmostEqual(stopper.best_metric, 1.)
        final_values = get_variable_values([var_a, var_b, var_c])
        self.assertEqual(final_values, [100, 20, 30])
def test_memorize_the_best_loss(self):
    """Only the best (smallest) reported metric is remembered.

    Reports 1., then .5, then .8: the first two are accepted, the last is
    rejected. After exit the tracked variables hold the values from the
    .5 update, while the untracked third variable keeps its last write.
    """
    with self.test_session():
        var_a, var_b, var_c = _populate_variables()

        with EarlyStopping([var_a, var_b]) as stopper:
            set_variable_values([var_a], [10])
            self.assertTrue(stopper.update(1.))
            self.assertAlmostEqual(stopper.best_metric, 1.)

            set_variable_values([var_a, var_b], [100, 20])
            self.assertTrue(stopper.update(.5))
            self.assertAlmostEqual(stopper.best_metric, .5)

            # A larger metric must not replace the current best.
            set_variable_values([var_a, var_b, var_c], [1000, 200, 30])
            self.assertFalse(stopper.update(.8))
            self.assertAlmostEqual(stopper.best_metric, .5)

        self.assertTrue(stopper.ever_updated)
        self.assertAlmostEqual(stopper.best_metric, .5)
        final_values = get_variable_values([var_a, var_b, var_c])
        self.assertEqual(final_values, [100, 20, 30])
def test_initial_loss_is_tensor(self):
    """`initial_metric` may be given as a `tf.constant`.

    Checks that `best_metric` compares equal to .5 when the baseline is
    supplied as a tensor rather than a Python float.
    """
    with self.test_session():
        var_a, var_b, _ = _populate_variables()
        baseline = tf.constant(.5)

        with EarlyStopping([var_a, var_b],
                           initial_metric=baseline) as stopper:
            np.testing.assert_equal(stopper.best_metric, .5)
def test_param_vars_must_not_be_empty(self):
    """An empty variable list is rejected with a `ValueError`."""
    with self.test_session():
        expect_error = pytest.raises(
            ValueError, match='`param_vars` must not be empty')
        with expect_error:
            with EarlyStopping([]):
                pass