def test_get_variable_values(self):
    """`get_variable_values` returns values in the same container shape.

    A list of variables produces a list of their values in the same order;
    a dict of variables produces a dict with the same keys mapped to values.
    """
    with self.get_session() as sess:
        a = tf.get_variable('a', dtype=tf.int32, initializer=1)
        b = tf.get_variable('b', dtype=tf.int32, initializer=2)
        c = tf.get_variable('c', dtype=tf.int32, initializer=3)
        sess.run(tf.variables_initializer([a, b, c]))

        # Each entry pairs the argument passed in with the value expected back.
        cases = [
            ([a, b, c], [1, 2, 3]),
            ([a, b], [1, 2]),
            ({'a': a, 'b': b, 'c': c}, {'a': 1, 'b': 2, 'c': 3}),
            ({'a': a, 'c': c}, {'a': 1, 'c': 3}),
        ]
        for variables, expected in cases:
            self.assertEqual(get_variable_values(variables), expected)
def test_load_save(self):
    """Round-trip model variables through `save_model` / `load_model`.

    Covers:
    * saving into the temporary directory (a 'latest' file appears);
    * refusing to save over an existing checkpoint directory or a plain
      file unless ``overwrite=True`` is given (an IOError otherwise);
    * restoring previously saved values from two checkpoint directories;
    * the IOError raised when loading from a directory never written.
    """
    with self.get_session():
        model = _MyModel()
        model.ensure_variables_initialized()

        def model_values():
            # Re-read the model attributes on every call instead of caching
            # them, exactly as each inline expression used to.
            return get_variable_values([model.model_var, model.nested_var])

        def set_model_values(values):
            set_variable_values([model.model_var, model.nested_var], values)

        with TemporaryDirectory() as tempdir:
            dir_1 = os.path.join(tempdir, '1')
            path_2 = os.path.join(tempdir, '2')

            # saving into tempdir writes a 'latest' file there
            model.save_model(tempdir)
            self.assertTrue(os.path.isfile(os.path.join(tempdir, 'latest')))

            # a second save into an existing checkpoint directory is
            # rejected unless overwrite=True
            model.save_model(dir_1)
            with self.assertRaisesRegex(IOError, '.*already exists.'):
                model.save_model(dir_1)
            model.save_model(dir_1, overwrite=True)
            self.assertTrue(os.path.isfile(os.path.join(dir_1, 'latest')))

            # a plain file at the target path also blocks saving unless
            # overwrite=True; create the empty file with a context manager
            # so the handle is always closed
            with open(path_2, 'wb'):
                pass
            with self.assertRaisesRegex(IOError, '.*already exists.'):
                model.save_model(path_2)
            model.save_model(path_2, overwrite=True)
            self.assertTrue(os.path.isfile(os.path.join(path_2, 'latest')))

            # loading from tempdir brings back the saved values (1, 3)
            self.assertEqual(model_values(), [1, 3])
            set_model_values([10, 30])
            self.assertEqual(model_values(), [10, 30])
            model.load_model(tempdir)
            self.assertEqual(model_values(), [1, 3])

            # loading from tempdir/1 does the same
            set_model_values([10, 30])
            self.assertEqual(model_values(), [10, 30])
            model.load_model(dir_1)
            self.assertEqual(model_values(), [1, 3])

            # loading from a directory that was never written must fail
            with self.assertRaisesRegex(
                    IOError, 'Checkpoint file does not exist.*'):
                model.load_model(os.path.join(tempdir, '3'))
def test_bigger_is_better(self):
    """With ``smaller_is_better=False`` the largest metric is the best one.

    Only the watched variables ``a`` and ``b`` are rolled back to the
    best-metric snapshot on exit; ``c`` keeps the last value assigned to it.
    """
    with self.get_session():
        a, b, c = _populate_variables()

        def values():
            return get_variable_values([a, b, c])

        self.assertEqual(values(), [1, 2, 3])

        # start from a known state, then track the best (largest) metric
        set_variable_values([a, b, c], [1, 2, 3])
        self.assertEqual(values(), [1, 2, 3])
        with early_stopping([a, b], smaller_is_better=False) as es:
            set_variable_values([a], [10])
            self.assertTrue(es.update(.5))
            self.assertAlmostEqual(es.best_metric, .5)

            set_variable_values([a, b], [100, 20])
            self.assertTrue(es.update(1.))
            self.assertAlmostEqual(es.best_metric, 1.)

            # .8 is smaller than the best so far (1.), so it is not kept
            set_variable_values([a, b, c], [1000, 200, 30])
            self.assertFalse(es.update(.8))
            self.assertAlmostEqual(es.best_metric, 1.)

        # after exit: a, b are back at the snapshot taken for metric 1.
        self.assertAlmostEqual(es.best_metric, 1.)
        self.assertEqual(values(), [100, 20, 30])
def test_restore_on_error(self):
    """`restore_on_error` controls whether an exception triggers a rollback."""
    with self.get_session():
        a, b, c = _populate_variables()

        def values():
            return get_variable_values([a, b, c])

        def fail_inside(restore_on_error):
            # Record a best metric, mutate the watched variables, then raise.
            # The ValueError is caught here; the context object is returned
            # so the caller can inspect it.
            with self.assertRaisesRegex(ValueError, 'value error'):
                with early_stopping(
                        [a, b], restore_on_error=restore_on_error) as es:
                    self.assertTrue(es.update(1.))
                    set_variable_values([a, b], [10, 20])
                    raise ValueError('value error')
            return es

        self.assertEqual(values(), [1, 2, 3])

        # restore_on_error=False: the mutated values survive the error
        es = fail_inside(False)
        self.assertAlmostEqual(es.best_metric, 1.)
        self.assertEqual(values(), [10, 20, 3])

        # restore_on_error=True: the snapshot taken at update(1.) comes back
        set_variable_values([a, b, c], [1, 2, 3])
        self.assertEqual(values(), [1, 2, 3])
        es = fail_inside(True)
        self.assertAlmostEqual(es.best_metric, 1.)
        self.assertEqual(values(), [1, 2, 3])
def test_basic(self):
    """Core `early_stopping` behaviour with the default metric direction.

    ``a`` and ``b`` are the watched variables; ``c`` is never included in
    the snapshot and therefore never rolled back.
    """
    with self.get_session():
        a, b, c = _populate_variables()

        def values():
            return get_variable_values([a, b, c])

        def reset():
            set_variable_values([a, b, c], [1, 2, 3])
            self.assertEqual(values(), [1, 2, 3])

        self.assertEqual(values(), [1, 2, 3])

        # an empty variable list is rejected
        with self.assertRaisesRegex(
                ValueError, '`param_vars` must not be empty.'):
            with early_stopping([]):
                pass

        # with no update() call, nothing is restored on exit
        with early_stopping([a, b]):
            set_variable_values([a], [10])
        self.assertEqual(values(), [10, 2, 3])

        # the first update() always counts as an improvement
        with early_stopping([a, b]) as es:
            set_variable_values([a], [10])
            self.assertTrue(es.update(1.))
            set_variable_values([a, b], [100, 20])
        self.assertAlmostEqual(es.best_metric, 1.)
        self.assertEqual(values(), [10, 2, 3])

        # the best (smallest) metric is remembered across updates
        reset()
        with early_stopping([a, b]) as es:
            set_variable_values([a], [10])
            self.assertTrue(es.update(1.))
            self.assertAlmostEqual(es.best_metric, 1.)

            set_variable_values([a, b], [100, 20])
            self.assertTrue(es.update(.5))
            self.assertAlmostEqual(es.best_metric, .5)

            # .8 is worse than .5, so it is not kept
            set_variable_values([a, b, c], [1000, 200, 30])
            self.assertFalse(es.update(.8))
            self.assertAlmostEqual(es.best_metric, .5)
        self.assertAlmostEqual(es.best_metric, .5)
        self.assertEqual(values(), [100, 20, 30])

        # `initial_metric` is the bar the first update() has to beat
        reset()
        with early_stopping([a, b], initial_metric=.6) as es:
            set_variable_values([a], [10])
            self.assertFalse(es.update(1.))
            self.assertAlmostEqual(es.best_metric, .6)

            set_variable_values([a, b], [100, 20])
            self.assertTrue(es.update(.5))
            self.assertAlmostEqual(es.best_metric, .5)
        self.assertEqual(values(), [100, 20, 3])
def test_save_dir(self):
    """Check what `early_stopping` leaves in an explicitly given `save_dir`."""
    with self.get_session():
        a, b, c = _populate_variables()
        self.assertEqual(get_variable_values([a, b, c]), [1, 2, 3])

        with TemporaryDirectory() as tempdir:
            def latest_exists(directory):
                return os.path.exists(os.path.join(directory, 'latest'))

            # without `cleanup` given, the save directory is removed on exit
            cleaned_dir = os.path.join(tempdir, '1')
            with early_stopping([a, b], save_dir=cleaned_dir) as es:
                self.assertTrue(es.update(1.))
                self.assertTrue(latest_exists(cleaned_dir))
            self.assertFalse(os.path.exists(cleaned_dir))

            # cleanup=False: the checkpoint outlives the context
            kept_dir = os.path.join(tempdir, '2')
            with early_stopping([a, b], save_dir=kept_dir,
                                cleanup=False) as es:
                self.assertTrue(es.update(1.))
                self.assertTrue(latest_exists(kept_dir))
            self.assertTrue(latest_exists(kept_dir))
def test_early_stopping(self):
    """`train_loop(..., early_stopping=True)` restores the best ``a`` on exit.

    Only ``a`` is handed to the loop, so ``b`` always keeps the value
    written at the last step.
    """
    with self.get_session():
        a = tf.get_variable('a', shape=(), dtype=tf.int32)
        b = tf.get_variable('b', shape=(), dtype=tf.int32)

        def reset():
            set_variable_values([a, b], [1, 2])
            self.assertEqual(get_variable_values([a, b]), [1, 2])

        def run_one_epoch(metric_name, **loop_kwargs):
            # One epoch over three steps. At each step, write 10+step / 20+step
            # into a / b and report the step's metric under `metric_name`.
            with train_loop([a], max_epoch=1, early_stopping=True,
                            **loop_kwargs) as loop:
                for _ in loop.iter_epochs():
                    for step, metric in loop.iter_steps([0.7, 0.6, 0.8]):
                        set_variable_values([a, b], [10 + step, 20 + step])
                        loop.add_metrics(**{metric_name: metric})
            return loop

        # no validation metric is ever recorded: nothing gets restored
        reset()
        with train_loop([a], early_stopping=True):
            set_variable_values([a, b], [10, 20])
        self.assertEqual(get_variable_values([a, b]), [10, 20])

        # smaller-is-better `valid_loss`: best is 0.6, so a is restored to 12
        reset()
        loop = run_one_epoch('valid_loss')
        self.assertAlmostEqual(loop.best_valid_metric, 0.6)
        self.assertEqual(get_variable_values([a, b]), [12, 23])

        # ('y', False) makes larger values better here: best is 0.8 (a = 13)
        reset()
        loop = run_one_epoch('y', valid_metric=('y', False))
        self.assertAlmostEqual(loop.best_valid_metric, 0.8)
        self.assertEqual(get_variable_values([a, b]), [13, 23])