def test_initialize_variables(self):
    """Model.ensure_variables_initialized() must leave outside variables alone.

    A variable created next to (but not by) the model has to stay
    uninitialized after the model initializes its own variables.
    """
    # NOTE(review): sibling tests use self.test_session(); confirm that
    # get_session() is the intended helper on this TestCase.
    with self.get_session():
        model = _MyModel()
        model.build()
        # Created after build(), outside the model.
        external_var = tf.get_variable('out_var', shape=(), dtype=tf.int32)

        # Before any initialization, the model's variables (global step
        # included) and the external one are all reported. Query first,
        # then build the expected set, to keep the evaluation order.
        uninitialized = set(get_uninitialized_variables())
        expected = {
            model.model_var,
            model.nested_var,
            model.other_var,
            model.get_global_step(),
            external_var,
        }
        self.assertEqual(uninitialized, expected)

        # Once the model initializes itself, only the external one is left.
        model.ensure_variables_initialized()
        remaining = set(get_uninitialized_variables())
        self.assertEqual(remaining, {external_var})
def test_ensure_variables_initialized_using_dict(self):
    """ensure_variables_initialized() accepts a dict of variables.

    Passing ``{'a': var_a}`` must initialize ``var_a`` and leave ``var_b``
    uninitialized.
    """
    var_a = tf.get_variable('a', dtype=tf.int32, initializer=1)
    var_b = tf.get_variable('b', dtype=tf.int32, initializer=2)

    with self.test_session():
        ensure_variables_initialized({'a': var_a})
        still_uninitialized = get_uninitialized_variables([var_a, var_b])
        self.assertEqual(still_uninitialized, [var_b])
def test_ensure_variables_initialized(self):
    """ensure_variables_initialized() with and without an explicit list.

    Called with no argument, it is expected to initialize only the
    variables created with default collections (``a`` and ``b``). Given an
    explicit list, it also initializes the ones registered only under
    ``MODEL_VARIABLES`` (``c`` and ``d``).
    """
    # a and b use the default collections; c and d are placed only in
    # tf.GraphKeys.MODEL_VARIABLES.
    var_a = tf.get_variable('a', dtype=tf.int32, initializer=1)
    var_b = tf.get_variable('b', dtype=tf.int32, initializer=2)
    var_c = tf.get_variable('c', dtype=tf.int32, initializer=3,
                            collections=[tf.GraphKeys.MODEL_VARIABLES])
    var_d = tf.get_variable('d', dtype=tf.int32, initializer=4,
                            collections=[tf.GraphKeys.MODEL_VARIABLES])
    all_vars = [var_a, var_b, var_c, var_d]

    with self.test_session():
        # Nothing has been initialized yet.
        self.assertEqual(get_uninitialized_variables(all_vars), all_vars)

        # Default call: the MODEL_VARIABLES-only pair is still pending.
        ensure_variables_initialized()
        self.assertEqual(get_uninitialized_variables(all_vars),
                         [var_c, var_d])

        # Explicit list: everything ends up initialized.
        ensure_variables_initialized(all_vars)
        self.assertEqual(get_uninitialized_variables(all_vars), [])
def test_get_uninitialized_variables(self):
    """get_uninitialized_variables() with and without an explicit list.

    The no-argument form is expected to report only ``a`` and ``b``. The
    variables created solely in ``MODEL_VARIABLES`` (``c`` and ``d``) only
    show up when passed explicitly.
    """
    with self.test_session() as session:
        var_a = tf.get_variable('a', dtype=tf.int32, initializer=1)
        var_b = tf.get_variable('b', dtype=tf.int32, initializer=2)
        var_c = tf.get_variable('c', dtype=tf.int32, initializer=3,
                                collections=[tf.GraphKeys.MODEL_VARIABLES])
        var_d = tf.get_variable('d', dtype=tf.int32, initializer=4,
                                collections=[tf.GraphKeys.MODEL_VARIABLES])
        everything = [var_a, var_b, var_c, var_d]

        def check(by_default, explicitly):
            # Default query first, then the explicit-list query.
            self.assertEqual(get_uninitialized_variables(), by_default)
            self.assertEqual(get_uninitialized_variables(everything),
                             explicitly)

        # Fresh graph: nothing is initialized.
        check([var_a, var_b], [var_a, var_b, var_c, var_d])

        # Initialize one variable from each collection group.
        session.run(tf.variables_initializer([var_a, var_c]))
        check([var_b], [var_b, var_d])

        # Initialize the rest.
        session.run(tf.variables_initializer([var_b, var_d]))
        check([], [])