コード例 #1
0
    def test_initialize_variables(self):
        """Check that ``model.ensure_variables_initialized()`` touches only
        the model's own variables.

        Before the call, every variable exposed by the built model, its
        global step, and an extra ``out_var`` created outside the model are
        reported as uninitialized. After the call, only ``out_var`` remains.
        """
        with self.get_session():
            model = _MyModel()
            model.build()
            out_var = tf.get_variable('out_var', shape=(), dtype=tf.int32)

            # Query first, then collect the expected variables, so that the
            # evaluation order matches a single inline assertion.
            pending = set(get_uninitialized_variables())
            expected = {
                model.model_var,
                model.nested_var,
                model.other_var,
                model.get_global_step(),
                out_var,
            }
            self.assertEqual(pending, expected)

            # Initializing through the model must leave `out_var` alone.
            model.ensure_variables_initialized()
            leftover = set(get_uninitialized_variables())
            self.assertEqual(leftover, {out_var})
コード例 #2
0
ファイル: test_session.py プロジェクト: shliujing/tfsnippet
    def test_ensure_variables_initialized_using_dict(self):
        """``ensure_variables_initialized`` also accepts a name -> variable
        dict, and initializes only the variables it contains."""
        var_a = tf.get_variable('a', dtype=tf.int32, initializer=1)
        var_b = tf.get_variable('b', dtype=tf.int32, initializer=2)

        with self.test_session():
            # Only `var_a` is handed over (as a dict value), so `var_b`
            # should still be reported as uninitialized afterwards.
            ensure_variables_initialized({'a': var_a})
            remaining = get_uninitialized_variables([var_a, var_b])
            self.assertEqual(remaining, [var_b])
コード例 #3
0
ファイル: test_session.py プロジェクト: shliujing/tfsnippet
    def test_ensure_variables_initialized(self):
        """``ensure_variables_initialized`` with no argument, then with an
        explicit list of variables.

        ``c`` and ``d`` are created with an explicit
        ``collections=[MODEL_VARIABLES]``. The no-argument call initializes
        ``a`` and ``b`` but leaves ``c`` and ``d`` uninitialized. Passing all
        four explicitly initializes the rest.
        """
        def int_var(name, value, **extra):
            # Small factory: every variable here is a scalar int32 with a
            # constant initializer; `extra` carries optional kwargs.
            return tf.get_variable(name, dtype=tf.int32, initializer=value,
                                   **extra)

        a = int_var('a', 1)
        b = int_var('b', 2)
        c = int_var('c', 3, collections=[tf.GraphKeys.MODEL_VARIABLES])
        d = int_var('d', 4, collections=[tf.GraphKeys.MODEL_VARIABLES])

        def still_uninitialized():
            # Builds a fresh list on every call, like the inline form would.
            return get_uninitialized_variables([a, b, c, d])

        with self.test_session():
            self.assertEqual(still_uninitialized(), [a, b, c, d])

            # No argument: `a` and `b` get initialized, `c` and `d` do not.
            ensure_variables_initialized()
            self.assertEqual(still_uninitialized(), [c, d])

            # Explicit list: everything is initialized afterwards.
            ensure_variables_initialized([a, b, c, d])
            self.assertEqual(still_uninitialized(), [])
コード例 #4
0
ファイル: test_session.py プロジェクト: shliujing/tfsnippet
 def test_get_uninitialized_variables(self):
     """Track ``get_uninitialized_variables`` as variables are initialized
     stage by stage.

     With no argument the helper reports only ``a`` and ``b``; ``c`` and
     ``d`` (created with ``collections=[MODEL_VARIABLES]``) show up only
     when passed explicitly.
     """
     with self.test_session() as session:
         def int_var(name, value, **extra):
             # Scalar int32 variable with a constant initializer.
             return tf.get_variable(name, dtype=tf.int32, initializer=value,
                                    **extra)

         a = int_var('a', 1)
         b = int_var('b', 2)
         c = int_var('c', 3, collections=[tf.GraphKeys.MODEL_VARIABLES])
         d = int_var('d', 4, collections=[tf.GraphKeys.MODEL_VARIABLES])

         # Each stage: which variables to initialize first (None = nothing),
         # the expected result of the no-argument query, and the expected
         # result when querying all four variables explicitly.
         stages = [
             (None, [a, b], [a, b, c, d]),
             ([a, c], [b], [b, d]),
             ([b, d], [], []),
         ]
         for to_init, expect_default, expect_all in stages:
             if to_init is not None:
                 session.run(tf.variables_initializer(to_init))
             self.assertEqual(get_uninitialized_variables(), expect_default)
             self.assertEqual(get_uninitialized_variables([a, b, c, d]),
                              expect_all)