Пример #1
0
 def test_identity_with_unordered_dict(self):
     """Checks `identity` on a plain dict holding a single scalar tensor.

     The returned object must be distinct from the input dict, and its
     'foo' entry must still evaluate to 10 in a session.
     """
     g = tf.Graph()
     with g.as_default():
         ten = tf.constant(10, dtype=tf.int32, shape=[])
         original = {'foo': ten}
         copied = tf_computation_utils.identity(original)
     # identity must build a new container rather than return its argument.
     self.assertIsNot(copied, original)
     with tf.compat.v1.Session(graph=g) as session:
         value = session.run(copied['foo'])
     self.assertEqual(value, 10)
Пример #2
0
 def test_identity_with_no_nesting(self):
     """Checks `identity` on a bare scalar tensor with no container around it.

     The output must be a different object from the input tensor, yet
     evaluate to the same value (10).
     """
     g = tf.Graph()
     with g.as_default():
         source = tf.constant(10, dtype=tf.int32, shape=[])
         output = tf_computation_utils.identity(source)
     # A new tensor is expected, not the input handed back unchanged.
     self.assertIsNot(output, source)
     with tf.compat.v1.Session(graph=g) as session:
         fetched = session.run(output)
     self.assertEqual(fetched, 10)
Пример #3
0
def _mnist_batch_train(model, batch):
  """Builds a graph that takes one gradient-descent step from `model` on `batch`.

  Creates variables shaped by `_mnist_model_type`, assigns `model` into them,
  minimizes `_mnist_batch_loss(model_vars, batch)` with plain gradient descent
  (learning rate 0.01), and returns an identity copy of the variables that is
  created after the training op.

  Args:
    model: Values assigned into the freshly created variables. Presumably a
      structure matching `_mnist_model_type` (TODO confirm against caller).
    batch: Data forwarded to `_mnist_batch_loss`.

  Returns:
    The result of `tf_computation_utils.identity(model_vars)`, built under a
    control dependency on the training op.
  """
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.01)
  model_vars = tf_computation_utils.create_variables('v', _mnist_model_type)
  assign_vars_op = tf_computation_utils.assign(model_vars, model)
  # Build the loss and training op inside this context so they are ordered
  # after the assignment. The intent is that the step starts from `model`,
  # not from the variables' initial values.
  with tf.control_dependencies([assign_vars_op]):
    train_op = optimizer.minimize(_mnist_batch_loss(model_vars, batch))
    # Nested so that the returned identity ops are created only after
    # `train_op`, i.e. they reflect the post-update variable values.
    with tf.control_dependencies([train_op]):
      return tf_computation_utils.identity(model_vars)
Пример #4
0
 def test_identity_with_structure(self):
   """Checks `identity` on a `structure.Struct` with one named element.

   The result must be a new object whose `foo` element still evaluates
   to 10.
   """
   g = tf.Graph()
   with g.as_default():
     scalar = tf.constant(10, dtype=tf.int32, shape=[])
     source = structure.Struct([('foo', scalar)])
     output = tf_computation_utils.identity(source)
   # The Struct passed in must not be returned as-is.
   self.assertIsNot(output, source)
   with tf.compat.v1.Session(graph=g) as session:
     fetched = session.run(output.foo)
   self.assertEqual(fetched, 10)
Пример #5
0
 def test_identity_with_anonymous_tuple(self):
     """Checks `identity` on an AnonymousTuple with one named element.

     The result must be a different object from the input, and its `foo`
     element must still evaluate to 10.
     """
     with tf.Graph().as_default() as graph:
         c1 = anonymous_tuple.AnonymousTuple(
             [('foo', tf.constant(10, dtype=tf.int32, shape=[]))])
         c2 = tf_computation_utils.identity(c1)
     self.assertIsNot(c2, c1)
     # `tf.Session` is not exported in the TF2 namespace; `tf.compat.v1.Session`
     # is available under both TF1 (>= 1.13) and TF2.
     with tf.compat.v1.Session(graph=graph) as sess:
         result = sess.run(c2.foo)
     self.assertEqual(result, 10)