def test_to_var_dict_duplicate_names(self):
    """Checks that `to_var_dict` rejects variables that share a name.

    Two variables are created with the same requested name, and the call is
    expected to raise a `ValueError` whose message matches 'multiple.*foo'.
    """
    v1 = tf.Variable(0, name='foo')
    v2 = tf.Variable(0, name='foo')
    # Precondition: both variables must actually report the same name,
    # otherwise the duplicate-name path is never exercised. A unittest
    # assertion is used (rather than a bare `assert`) so the check still
    # runs under `python -O`.
    # NOTE(review): presumably this holds only when TensorFlow does not
    # uniquify variable names (e.g. eager execution) -- confirm.
    self.assertEqual(v1.name, v2.name)
    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias.
    with self.assertRaisesRegex(ValueError, 'multiple.*foo'):
        tensor_utils.to_var_dict([v1, v2])
def _get_weights(model):
    """Packs a model's variables into a `ModelWeights` namedtuple.

    Args:
        model: An object exposing `trainable_variables` and
            `non_trainable_variables` attributes.

    Returns:
        A `ModelWeights` namedtuple with two fields, `trainable` and
        `non_trainable`, each holding the result of
        `tensor_utils.to_var_dict` applied to the matching attribute of
        `model`.
    """
    # The tuple type is built on every call, so each result has its own
    # (structurally identical) `ModelWeights` class.
    weights_cls = collections.namedtuple(
        'ModelWeights', 'trainable non_trainable')
    trainable_vars = tensor_utils.to_var_dict(model.trainable_variables)
    frozen_vars = tensor_utils.to_var_dict(model.non_trainable_variables)
    return weights_cls(trainable_vars, frozen_vars)
def test_to_var_dict_preserves_order(self):
    """Checks that `to_var_dict` keeps its keys in input-list order.

    Variables named 'a', 'b' and 'c' are passed in the order c, a, b, and
    the keys of the result are expected to come back as 'c', 'a', 'b'.
    """
    var_a = tf.Variable(0, name='a')
    var_b = tf.Variable(0, name='b')
    var_c = tf.Variable(0, name='c')
    # Deliberately non-alphabetical, so a sorted result would fail the check.
    result = tensor_utils.to_var_dict([var_c, var_a, var_b])
    observed_keys = list(result.keys())
    self.assertEqual(['c', 'a', 'b'], observed_keys)