Exemplo n.º 1
0
    def test_load_save(self):
        """Check the ``save_model`` / ``load_model`` round trip of ``_MyModel``.

        The test asserts that saving writes a ``latest`` file into the target
        directory, that saving onto an existing directory or file raises
        ``IOError`` unless ``overwrite=True`` is passed, that loading brings
        the two model variables back to the values they had when saved, and
        that loading from a path nothing was saved to raises ``IOError``.
        """
        with self.get_session():
            model = _MyModel()
            model.ensure_variables_initialized()

            def tracked_vars():
                # Look the attributes up on every call so each check works
                # on a freshly built list of the two variables.
                return [model.model_var, model.nested_var]

            def latest_exists(directory):
                # True when `directory` contains a regular file `latest`.
                return os.path.isfile(os.path.join(directory, 'latest'))

            with TemporaryDirectory() as tempdir:
                dir_target = os.path.join(tempdir, '1')
                file_target = os.path.join(tempdir, '2')
                missing_target = os.path.join(tempdir, '3')
                exists_pattern = '.*already exists.'

                # test save into the (empty) temporary directory
                model.save_model(tempdir)
                self.assertTrue(latest_exists(tempdir))

                # test overwriting directory
                model.save_model(dir_target)
                with self.assertRaisesRegex(IOError, exists_pattern):
                    model.save_model(dir_target)
                model.save_model(dir_target, overwrite=True)
                self.assertTrue(latest_exists(dir_target))

                # test overwriting file: create an empty file at the target
                with open(file_target, 'wb'):
                    pass
                with self.assertRaisesRegex(IOError, exists_pattern):
                    model.save_model(file_target)
                model.save_model(file_target, overwrite=True)
                self.assertTrue(latest_exists(file_target))

                # test load from tempdir, then from tempdir + '/1': change
                # the variables, load, and expect the saved values back
                self.assertEqual(get_variable_values(tracked_vars()), [1, 3])
                for checkpoint in (tempdir, dir_target):
                    set_variable_values(tracked_vars(), [10, 30])
                    self.assertEqual(
                        get_variable_values(tracked_vars()), [10, 30])
                    model.load_model(checkpoint)
                    self.assertEqual(
                        get_variable_values(tracked_vars()), [1, 3])

                # test load from non-exist
                with self.assertRaisesRegex(
                        IOError, 'Checkpoint file does not exist.*'):
                    model.load_model(missing_target)
Exemplo n.º 2
0
    def test_bigger_is_better(self):
        """Check ``early_stopping`` when a larger metric counts as better.

        With ``smaller_is_better=False`` the test expects ``update`` to
        report an improvement for 0.5 and then 1.0 but not for the later 0.8.
        After the context exits it expects ``a`` and ``b`` to hold the values
        they had at the 1.0 update, while ``c`` (not passed to
        ``early_stopping``) keeps its last assigned value.
        """
        with self.get_session():
            a, b, c = _populate_variables()

            def snapshot():
                # Current values of all three variables, in a/b/c order.
                return get_variable_values([a, b, c])

            self.assertEqual(snapshot(), [1, 2, 3])

            # test: memorize the best metric
            set_variable_values([a, b, c], [1, 2, 3])
            self.assertEqual(snapshot(), [1, 2, 3])

            # Each row: (variables to assign, new values, reported metric,
            #            whether `update` should report an improvement,
            #            expected best metric afterwards)
            schedule = [
                ([a], [10], .5, True, .5),
                ([a, b], [100, 20], 1., True, 1.),
                ([a, b, c], [1000, 200, 30], .8, False, 1.),
            ]
            with early_stopping([a, b], smaller_is_better=False) as es:
                for targets, new_values, metric, improved, best in schedule:
                    set_variable_values(targets, new_values)
                    expect = self.assertTrue if improved else self.assertFalse
                    expect(es.update(metric))
                    self.assertAlmostEqual(es.best_metric, best)
            self.assertAlmostEqual(es.best_metric, 1.)
            self.assertEqual(snapshot(), [100, 20, 30])
Exemplo n.º 3
0
    def test_restore_on_error(self):
        """Check the ``restore_on_error`` switch of ``early_stopping``.

        A ``ValueError`` is raised inside the context after one ``update``
        call and a change to ``a`` and ``b``.  With ``restore_on_error=False``
        the test expects the changed values to remain; with
        ``restore_on_error=True`` it expects the values from the time of the
        ``update`` call.  In both cases ``best_metric`` is expected to be 1.
        """
        with self.get_session():
            a, b, c = _populate_variables()

            def snapshot():
                # Current values of all three variables, in a/b/c order.
                return get_variable_values([a, b, c])

            def run_failing_context(restore):
                # Enter early-stopping, record one metric, change a and b,
                # then fail; hand back the context object for inspection.
                with self.assertRaisesRegex(ValueError, 'value error'):
                    with early_stopping(
                            [a, b], restore_on_error=restore) as es:
                        self.assertTrue(es.update(1.))
                        set_variable_values([a, b], [10, 20])
                        raise ValueError('value error')
                return es

            self.assertEqual(snapshot(), [1, 2, 3])

            # test: do not restore on error
            stopper = run_failing_context(False)
            self.assertAlmostEqual(stopper.best_metric, 1.)
            self.assertEqual(snapshot(), [10, 20, 3])

            # test: restore on error
            set_variable_values([a, b, c], [1, 2, 3])
            self.assertEqual(snapshot(), [1, 2, 3])
            stopper = run_failing_context(True)
            self.assertAlmostEqual(stopper.best_metric, 1.)
            self.assertEqual(snapshot(), [1, 2, 3])
Exemplo n.º 4
0
    def test_basic(self):
        """Check the default (smaller-is-better) behaviour of ``early_stopping``.

        Scenarios, in order: an empty variable list is rejected with
        ``ValueError``; a context in which ``update`` is never called leaves
        the variables as last assigned; the first ``update`` is expected to
        report an improvement; across several updates the values from the
        best one are expected back on exit; and with ``initial_metric`` set,
        an ``update`` with a larger metric is expected to report no
        improvement.
        """
        with self.get_session():
            a, b, c = _populate_variables()

            def snapshot():
                # Current values of all three variables, in a/b/c order.
                return get_variable_values([a, b, c])

            def reset():
                # Put the variables back to 1, 2, 3 and confirm it worked.
                set_variable_values([a, b, c], [1, 2, 3])
                self.assertEqual(snapshot(), [1, 2, 3])

            self.assertEqual(snapshot(), [1, 2, 3])

            # test: param-vars must not be empty
            with self.assertRaisesRegex(
                    ValueError, '`param_vars` must not be empty.'):
                with early_stopping([]):
                    pass

            # test: early-stopping context without updating loss
            with early_stopping([a, b]):
                set_variable_values([a], [10])
            self.assertEqual(snapshot(), [10, 2, 3])

            # test: the first loss will always cause saving
            with early_stopping([a, b]) as es:
                set_variable_values([a], [10])
                self.assertTrue(es.update(1.))
                set_variable_values([a, b], [100, 20])
            self.assertAlmostEqual(es.best_metric, 1.)
            self.assertEqual(snapshot(), [10, 2, 3])

            # test: memorize the best loss
            reset()
            # Each row: (variables to assign, new values, reported loss,
            #            whether `update` should report an improvement,
            #            expected best metric afterwards)
            schedule = [
                ([a], [10], 1., True, 1.),
                ([a, b], [100, 20], .5, True, .5),
                ([a, b, c], [1000, 200, 30], .8, False, .5),
            ]
            with early_stopping([a, b]) as es:
                for targets, new_values, loss, improved, best in schedule:
                    set_variable_values(targets, new_values)
                    expect = self.assertTrue if improved else self.assertFalse
                    expect(es.update(loss))
                    self.assertAlmostEqual(es.best_metric, best)
            self.assertAlmostEqual(es.best_metric, .5)
            self.assertEqual(snapshot(), [100, 20, 30])

            # test: initial_loss
            reset()
            with early_stopping([a, b], initial_metric=.6) as es:
                set_variable_values([a], [10])
                self.assertFalse(es.update(1.))
                self.assertAlmostEqual(es.best_metric, .6)
                set_variable_values([a, b], [100, 20])
                self.assertTrue(es.update(.5))
                self.assertAlmostEqual(es.best_metric, .5)
            self.assertEqual(snapshot(), [100, 20, 3])
Exemplo n.º 5
0
def _populate_variables():
    """Create the scalar int32 variables ``a``, ``b`` and ``c``.

    The variables are obtained through ``tf.get_variable`` in that order and
    then assigned the values 1, 2 and 3 via ``set_variable_values``.

    Returns:
        list: The three variables, ordered ``[a, b, c]``.
    """
    variables = [
        tf.get_variable(name, shape=(), dtype=tf.int32)
        for name in ('a', 'b', 'c')
    ]
    set_variable_values(variables, [1, 2, 3])
    return variables
Exemplo n.º 6
0
    def test_early_stopping(self):
        """Check ``train_loop`` with ``early_stopping=True``.

        Only ``a`` is handed to ``train_loop``.  Three scenarios are run: no
        metric is ever added (both variables are expected to keep their last
        assigned values); ``valid_loss`` is added per step (``a`` is expected
        to end at its value from the step with the smallest loss); and a
        metric ``y`` declared via ``valid_metric=('y', False)`` is added per
        step (``a`` is expected to end at its value from the step with the
        largest ``y``).  In the last two scenarios ``b`` is expected to keep
        the value from the final step.
        """
        with self.get_session():
            a = tf.get_variable('a', shape=(), dtype=tf.int32)
            b = tf.get_variable('b', shape=(), dtype=tf.int32)

            def snapshot():
                # Current values of both variables, in a/b order.
                return get_variable_values([a, b])

            def reset():
                # Put the variables back to 1, 2 and confirm it worked.
                set_variable_values([a, b], [1, 2])
                self.assertEqual(snapshot(), [1, 2])

            # test early-stopping with no valid metric committed
            reset()
            with train_loop([a], early_stopping=True):
                set_variable_values([a, b], [10, 20])
            self.assertEqual(snapshot(), [10, 20])

            # test early-stopping with smaller-better metric
            reset()
            with train_loop([a], max_epoch=1, early_stopping=True) as loop:
                for _ in loop.iter_epochs():
                    for step, loss in loop.iter_steps([0.7, 0.6, 0.8]):
                        # Make the variables depend on the step so the
                        # final assertion can tell which step was kept.
                        set_variable_values([a, b], [10 + step, 20 + step])
                        loop.add_metrics(valid_loss=loss)
            self.assertAlmostEqual(loop.best_valid_metric, 0.6)
            self.assertEqual(snapshot(), [12, 23])

            # test early-stopping with larger-better metric
            reset()
            loop_options = dict(
                max_epoch=1,
                valid_metric=('y', False),
                early_stopping=True,
            )
            with train_loop([a], **loop_options) as loop:
                for _ in loop.iter_epochs():
                    for step, score in loop.iter_steps([0.7, 0.6, 0.8]):
                        set_variable_values([a, b], [10 + step, 20 + step])
                        loop.add_metrics(y=score)
            self.assertAlmostEqual(loop.best_valid_metric, 0.8)
            self.assertEqual(snapshot(), [13, 23])
Exemplo n.º 7
0
    def test_set_variable_values(self):
        """Check ``set_variable_values`` with list and dict arguments.

        Covers assigning through parallel lists and through matching dicts,
        the exceptions expected for mismatched container types, a
        non-variable entry, mismatched list lengths and a missing dict key,
        and finally that a values dict carrying an extra key is accepted and
        only the listed variable changes.
        """
        with self.get_session() as session:
            a = tf.get_variable('a', dtype=tf.int32, initializer=1)
            b = tf.get_variable('b', dtype=tf.int32, initializer=2)
            c = tf.get_variable('c', dtype=tf.int32, initializer=3)

            def read_all():
                # Evaluate all three variables in a/b/c order.
                return session.run([a, b, c])

            set_variable_values([a, b], [10, 20])
            set_variable_values({'c': c}, {'c': 30})
            self.assertEqual(read_all(), [10, 20, 30])

            # Re-running the initializer brings back the initial values.
            session.run(tf.variables_initializer([a, b, c]))
            self.assertEqual(read_all(), [1, 2, 3])

            set_variable_values([a], [100])
            set_variable_values({'b': b, 'c': c}, {'b': 200, 'c': 300})
            self.assertEqual(read_all(), [100, 200, 300])

            # Each row: (expected exception, variables arg, values arg)
            rejected = [
                (TypeError, [a, b], {'a': 10, 'b': 20}),
                (TypeError, {'a': a, 'b': b}, [10, 20]),
                (TypeError, {'a': a, 'b': 20}, {'a': 10, 'b': 20}),
                (IndexError, [a, b, c], [1, 2]),
                (IndexError, [a, b], [1, 2, 3]),
                (KeyError, {'a': a, 'b': b}, {'b': 20}),
            ]
            for error_type, variables, values in rejected:
                with self.assertRaises(error_type):
                    set_variable_values(variables, values)

            # An extra key in the values dict is tolerated; only `a` changes.
            set_variable_values({'a': a}, {'a': 10, 'b': 20})
            self.assertEqual(read_all(), [10, 200, 300])