def test_get_variable_values(self):
        """get_variable_values mirrors its input container.

        A list of variables yields a list of their values (same order). A
        dict of variables yields a dict of values under the same keys.
        """
        with self.get_session() as session:
            var_a = tf.get_variable('a', dtype=tf.int32, initializer=1)
            var_b = tf.get_variable('b', dtype=tf.int32, initializer=2)
            var_c = tf.get_variable('c', dtype=tf.int32, initializer=3)
            session.run(tf.variables_initializer([var_a, var_b, var_c]))

            # List input: values come back in the same order.
            self.assertEqual(
                get_variable_values([var_a, var_b, var_c]), [1, 2, 3])
            self.assertEqual(get_variable_values([var_a, var_b]), [1, 2])

            # Dict input: values come back under the same keys.
            self.assertEqual(
                get_variable_values({'a': var_a, 'b': var_b, 'c': var_c}),
                {'a': 1, 'b': 2, 'c': 3})
            self.assertEqual(
                get_variable_values({'a': var_a, 'c': var_c}),
                {'a': 1, 'c': 3})
示例#2
0
    def test_load_save(self):
        """Round-trip _MyModel variables through save_model / load_model.

        Covers four cases:
        - saving into the just-created temp directory;
        - refusing to overwrite an existing directory or plain file unless
          ``overwrite=True`` is passed;
        - restoring variable values from a saved checkpoint directory;
        - the error raised when loading from a path that holds no checkpoint.
        """
        with self.get_session():
            model = _MyModel()
            model.ensure_variables_initialized()

            with TemporaryDirectory() as tempdir:
                path_1 = os.path.join(tempdir, '1')
                path_2 = os.path.join(tempdir, '2')
                path_3 = os.path.join(tempdir, '3')

                # Saving into the temp directory writes a 'latest' file.
                model.save_model(tempdir)
                self.assertTrue(
                    os.path.isfile(os.path.join(tempdir, 'latest')))

                # Existing directory: rejected first, accepted with overwrite.
                model.save_model(path_1)
                with self.assertRaisesRegex(IOError, '.*already exists.'):
                    model.save_model(path_1)
                model.save_model(path_1, overwrite=True)
                self.assertTrue(
                    os.path.isfile(os.path.join(path_1, 'latest')))

                # Existing plain file: rejected first, then replaced by a
                # checkpoint directory when overwrite=True.
                with open(path_2, 'wb'):
                    pass
                with self.assertRaisesRegex(IOError, '.*already exists.'):
                    model.save_model(path_2)
                model.save_model(path_2, overwrite=True)
                self.assertTrue(
                    os.path.isfile(os.path.join(path_2, 'latest')))

                # Loading a checkpoint undoes in-memory changes. The same
                # check runs for the temp directory and for '<tempdir>/1'.
                tracked = [model.model_var, model.nested_var]
                self.assertEqual(get_variable_values(tracked), [1, 3])
                for checkpoint_dir in (tempdir, path_1):
                    set_variable_values(tracked, [10, 30])
                    self.assertEqual(get_variable_values(tracked), [10, 30])
                    model.load_model(checkpoint_dir)
                    self.assertEqual(get_variable_values(tracked), [1, 3])

                # A path with no checkpoint is reported as an IOError.
                with self.assertRaisesRegex(
                        IOError, 'Checkpoint file does not exist.*'):
                    model.load_model(path_3)
示例#3
0
    def test_bigger_is_better(self):
        """With ``smaller_is_better=False``, larger metrics win.

        The snapshot taken at the largest metric is the one restored on exit.
        """
        with self.get_session():
            a, b, c = _populate_variables()
            self.assertEqual(get_variable_values([a, b, c]), [1, 2, 3])

            set_variable_values([a, b, c], [1, 2, 3])
            self.assertEqual(get_variable_values([a, b, c]), [1, 2, 3])
            with early_stopping([a, b], smaller_is_better=False) as stopper:
                # The first reported metric is accepted.
                set_variable_values([a], [10])
                self.assertTrue(stopper.update(.5))
                self.assertAlmostEqual(stopper.best_metric, .5)

                # A larger metric is accepted and becomes the new best.
                set_variable_values([a, b], [100, 20])
                self.assertTrue(stopper.update(1.))
                self.assertAlmostEqual(stopper.best_metric, 1.)

                # A smaller metric is rejected; the best is unchanged.
                set_variable_values([a, b, c], [1000, 200, 30])
                self.assertFalse(stopper.update(.8))
                self.assertAlmostEqual(stopper.best_metric, 1.)

            # After exit, a and b hold the values from the best update.
            # c is not tracked and keeps its last value.
            self.assertAlmostEqual(stopper.best_metric, 1.)
            self.assertEqual(get_variable_values([a, b, c]), [100, 20, 30])
示例#4
0
    def test_restore_on_error(self):
        """``restore_on_error`` controls rollback when the body raises.

        The flag decides whether tracked variables are rolled back to the
        best snapshot when the early_stopping body raises an exception.
        """
        with self.get_session():
            a, b, c = _populate_variables()
            self.assertEqual(get_variable_values([a, b, c]), [1, 2, 3])

            def fail_inside_context(restore):
                # Record one metric, mutate a and b, then raise inside the
                # context. assertRaisesRegex expects and absorbs the error,
                # so the early_stopping object can still be returned.
                with self.assertRaisesRegex(ValueError, 'value error'):
                    with early_stopping([a, b],
                                        restore_on_error=restore) as stopper:
                        self.assertTrue(stopper.update(1.))
                        set_variable_values([a, b], [10, 20])
                        raise ValueError('value error')
                return stopper

            # restore_on_error=False: the mutated values survive the error.
            stopper = fail_inside_context(False)
            self.assertAlmostEqual(stopper.best_metric, 1.)
            self.assertEqual(get_variable_values([a, b, c]), [10, 20, 3])

            # restore_on_error=True: a and b roll back to the snapshot.
            set_variable_values([a, b, c], [1, 2, 3])
            self.assertEqual(get_variable_values([a, b, c]), [1, 2, 3])
            stopper = fail_inside_context(True)
            self.assertAlmostEqual(stopper.best_metric, 1.)
            self.assertEqual(get_variable_values([a, b, c]), [1, 2, 3])
示例#5
0
    def test_basic(self):
        """Core early_stopping behaviour with the default metric ordering.

        Covers argument validation, snapshot/restore of the tracked
        variables, and the ``initial_metric`` threshold.
        """
        with self.get_session():
            a, b, c = _populate_variables()
            self.assertEqual(get_variable_values([a, b, c]), [1, 2, 3])

            def reset_all():
                # Put a, b, c back to their initial values and confirm it.
                set_variable_values([a, b, c], [1, 2, 3])
                self.assertEqual(get_variable_values([a, b, c]), [1, 2, 3])

            # An empty list of variables is rejected.
            with self.assertRaisesRegex(
                    ValueError, '`param_vars` must not be empty.'):
                with early_stopping([]):
                    pass

            # No update() call: nothing is restored on exit.
            with early_stopping([a, b]):
                set_variable_values([a], [10])
            self.assertEqual(get_variable_values([a, b, c]), [10, 2, 3])

            # The first update() is always accepted. Changes made after it
            # are rolled back on exit.
            with early_stopping([a, b]) as stopper:
                set_variable_values([a], [10])
                self.assertTrue(stopper.update(1.))
                set_variable_values([a, b], [100, 20])
            self.assertAlmostEqual(stopper.best_metric, 1.)
            self.assertEqual(get_variable_values([a, b, c]), [10, 2, 3])

            # By default smaller metrics win. The best one is remembered.
            reset_all()
            with early_stopping([a, b]) as stopper:
                set_variable_values([a], [10])
                self.assertTrue(stopper.update(1.))
                self.assertAlmostEqual(stopper.best_metric, 1.)
                set_variable_values([a, b], [100, 20])
                self.assertTrue(stopper.update(.5))
                self.assertAlmostEqual(stopper.best_metric, .5)
                set_variable_values([a, b, c], [1000, 200, 30])
                self.assertFalse(stopper.update(.8))
                self.assertAlmostEqual(stopper.best_metric, .5)
            self.assertAlmostEqual(stopper.best_metric, .5)
            # a and b roll back to the 0.5 snapshot; c is not tracked.
            self.assertEqual(get_variable_values([a, b, c]), [100, 20, 30])

            # initial_metric sets the bar that updates must beat.
            reset_all()
            with early_stopping([a, b], initial_metric=.6) as stopper:
                set_variable_values([a], [10])
                self.assertFalse(stopper.update(1.))
                self.assertAlmostEqual(stopper.best_metric, .6)
                set_variable_values([a, b], [100, 20])
                self.assertTrue(stopper.update(.5))
                self.assertAlmostEqual(stopper.best_metric, .5)
            self.assertEqual(get_variable_values([a, b, c]), [100, 20, 3])
示例#6
0
    def test_save_dir(self):
        """early_stopping writes snapshots under ``save_dir``.

        By default it removes ``save_dir`` on exit. With ``cleanup=False``
        the directory is left in place.
        """
        with self.get_session():
            a, b, c = _populate_variables()
            self.assertEqual(get_variable_values([a, b, c]), [1, 2, 3])

            with TemporaryDirectory() as tempdir:
                # Default: the save directory is gone after the context.
                removed_dir = os.path.join(tempdir, '1')
                removed_marker = os.path.join(removed_dir, 'latest')
                with early_stopping([a, b], save_dir=removed_dir) as stopper:
                    self.assertTrue(stopper.update(1.))
                    self.assertTrue(os.path.exists(removed_marker))
                self.assertFalse(os.path.exists(removed_dir))

                # cleanup=False: the snapshot survives the context.
                kept_dir = os.path.join(tempdir, '2')
                kept_marker = os.path.join(kept_dir, 'latest')
                with early_stopping([a, b], save_dir=kept_dir,
                                    cleanup=False) as stopper:
                    self.assertTrue(stopper.update(1.))
                    self.assertTrue(os.path.exists(kept_marker))
                self.assertTrue(os.path.exists(kept_marker))
示例#7
0
    def test_early_stopping(self):
        """``train_loop(early_stopping=True)`` restores the best snapshot.

        On exit, the tracked variables are restored to their values at the
        best validation metric. Only ``a`` is tracked; ``b`` is a control
        that keeps whatever was last written to it.
        """
        with self.get_session():
            a = tf.get_variable('a', shape=(), dtype=tf.int32)
            b = tf.get_variable('b', shape=(), dtype=tf.int32)

            def reset():
                # Restore the baseline values and confirm them.
                set_variable_values([a, b], [1, 2])
                self.assertEqual(get_variable_values([a, b]), [1, 2])

            # No metric was ever reported, so nothing is restored.
            reset()
            with train_loop([a], early_stopping=True):
                set_variable_values([a, b], [10, 20])
            self.assertEqual(get_variable_values([a, b]), [10, 20])

            # No valid_metric given; metrics reported as valid_loss are
            # treated as smaller-is-better, so the best is 0.6.
            reset()
            with train_loop([a], max_epoch=1, early_stopping=True) as trainer:
                for _ in trainer.iter_epochs():
                    for step, loss in trainer.iter_steps([0.7, 0.6, 0.8]):
                        set_variable_values([a, b], [10 + step, 20 + step])
                        trainer.add_metrics(valid_loss=loss)
            self.assertAlmostEqual(trainer.best_valid_metric, 0.6)
            self.assertEqual(get_variable_values([a, b]), [12, 23])

            # valid_metric=('y', False): larger is better, so the best is 0.8.
            reset()
            with train_loop([a],
                            max_epoch=1,
                            valid_metric=('y', False),
                            early_stopping=True) as trainer:
                for _ in trainer.iter_epochs():
                    for step, y_value in trainer.iter_steps([0.7, 0.6, 0.8]):
                        set_variable_values([a, b], [10 + step, 20 + step])
                        trainer.add_metrics(y=y_value)
            self.assertAlmostEqual(trainer.best_valid_metric, 0.8)
            self.assertEqual(get_variable_values([a, b]), [13, 23])