def testGetVariablesToRestore(self):
    """Variables created with default arguments are all marked for restore.

    Two variables share the local name 'a' but are created under different
    variable scopes ('A' and 'B'); the test expects both of them to be
    returned by `variables.get_variables_to_restore()`, in creation order.
    """
    with self.test_session():
        with tf.variable_scope('A'):
            a = variables.variable('a', [5])
        with tf.variable_scope('B'):
            b = variables.variable('a', [5])
        self.assertEqual([a, b], variables.get_variables_to_restore())
def testNoneGetVariablesToRestore(self):
    """restore=False variables are tracked but excluded from restoring.

    One variable named 'a' is created under each of the scopes 'A' and 'B',
    both with restore=False. The test expects an empty restore list while
    `variables.get_variables()` still reports both, in creation order.

    NOTE(review): this chunk defines two methods with this name; if they
    share a class, the later definition silently replaces the earlier one.
    Confirm and keep only one.
    """
    with self.test_session():
        created = []
        for scope_name in ('A', 'B'):
            with tf.variable_scope(scope_name):
                created.append(variables.variable('a', [5], restore=False))
        self.assertEqual([], variables.get_variables_to_restore())
        self.assertEqual(created, variables.get_variables())
def testNoneGetVariablesToRestore(self):
    """Variables created with restore=False are not marked for restore.

    Both variables are still expected to be reported by
    `variables.get_variables()`, in creation order.

    NOTE(review): this chunk defines two methods with this name; if they
    share a class, the later definition silently replaces the earlier one.
    Confirm and keep only one.
    """
    with self.test_session():
        with tf.variable_scope("A"):
            a = variables.variable("a", [5], restore=False)
        with tf.variable_scope("B"):
            b = variables.variable("a", [5], restore=False)
        self.assertEqual([], variables.get_variables_to_restore())
        self.assertEqual([a, b], variables.get_variables())
def testGetMixedVariablesToRestore(self):
    """Only variables left at the default restore setting are restored.

    Each of the scopes 'A' and 'B' gets one default variable and one created
    with restore=False. `variables.get_variables()` is expected to return all
    four in creation order; `variables.get_variables_to_restore()` only the
    two default ones.

    NOTE(review): this chunk defines two methods with this name; if they
    share a class, the later definition silently replaces the earlier one.
    Confirm and keep only one.
    """
    with self.test_session():
        with tf.variable_scope('A'):
            a = variables.variable('a', [5])
            b = variables.variable('b', [5], restore=False)
        with tf.variable_scope('B'):
            c = variables.variable('c', [5])
            d = variables.variable('d', [5], restore=False)
        self.assertEqual([a, b, c, d], variables.get_variables())
        self.assertEqual([a, c], variables.get_variables_to_restore())
def testGetMixedVariablesToRestore(self):
    """Only variables left at the default restore setting are restored.

    Each of the scopes "A" and "B" gets one default variable and one created
    with restore=False. `variables.get_variables()` is expected to return all
    four in creation order; `variables.get_variables_to_restore()` only the
    two default ones.

    NOTE(review): this chunk defines two methods with this name; if they
    share a class, the later definition silently replaces the earlier one.
    Confirm and keep only one.
    """
    with self.test_session():
        with tf.variable_scope("A"):
            a = variables.variable("a", [5])
            b = variables.variable("b", [5], restore=False)
        with tf.variable_scope("B"):
            c = variables.variable("c", [5])
            d = variables.variable("d", [5], restore=False)
        self.assertEqual([a, b, c, d], variables.get_variables())
        self.assertEqual([a, c], variables.get_variables_to_restore())
def testVariableRestoreWithArgScopeNested(self):
    """Nested arg_scopes: inner overrides apply only to variables made inside.

    Under an outer arg_scope with restore=True, "a" and "c" are created in the
    outer scope and "b" inside an inner arg_scope that sets trainable=False
    and collections=["A", "B"]. The test expects all three to be restorable,
    only "a" and "c" to be trainable, and only "b" in collections "A" and "B".

    NOTE(review): this chunk defines two methods with this name; if they
    share a class, the later definition silently replaces the earlier one.
    Confirm and keep only one.
    """
    make_var = variables.variable
    with self.test_session():
        with scopes.arg_scope([make_var], restore=True):
            outer_first = make_var("a", [])
            with scopes.arg_scope([make_var],
                                  trainable=False,
                                  collections=["A", "B"]):
                inner_only = make_var("b", [])
            outer_last = make_var("c", [])
        self.assertListEqual([outer_first, inner_only, outer_last],
                             variables.get_variables_to_restore())
        self.assertListEqual([outer_first, outer_last],
                             tf.trainable_variables())
        for collection_name in ("A", "B"):
            self.assertListEqual([inner_only],
                                 tf.get_collection(collection_name))
def testVariableRestoreWithArgScopeNested(self):
    """Inner arg_scope settings must not leak to variables outside it.

    'a' and 'c' are created directly under an arg_scope with restore=True;
    'b' is created inside a nested arg_scope with trainable=False and
    collections=['A', 'B']. Expected: all three restorable, 'a' and 'c'
    trainable, and 'b' alone in collections 'A' and 'B'.

    NOTE(review): this chunk defines two methods with this name; if they
    share a class, the later definition silently replaces the earlier one.
    Confirm and keep only one.
    """
    with self.test_session():
        frozen_overrides = {'trainable': False, 'collections': ['A', 'B']}
        with scopes.arg_scope([variables.variable], restore=True):
            leading = variables.variable('a', [])
            with scopes.arg_scope([variables.variable], **frozen_overrides):
                frozen = variables.variable('b', [])
            trailing = variables.variable('c', [])
        restorable = variables.get_variables_to_restore()
        self.assertListEqual([leading, frozen, trailing], restorable)
        trainable = tf.trainable_variables()
        self.assertListEqual([leading, trailing], trainable)
        self.assertListEqual([frozen], tf.get_collection('A'))
        self.assertListEqual([frozen], tf.get_collection('B'))