示例#1
0
def _mnist_batch_train(model, batch):
  """Builds the ops for one gradient-descent step on `batch` from `model`.

  Fresh variables of type `_mnist_model_type` are created and assigned the
  incoming `model` values. A `GradientDescentOptimizer` with learning rate
  0.01 then minimizes `_mnist_batch_loss` over those variables.

  Args:
    model: Model values assigned to the freshly created variables.
    batch: Batch passed through to `_mnist_batch_loss`.

  Returns:
    The result of `tf_computation_utils.identity` applied to the variables,
    created under a control dependency on the training op.
  """
  sgd = tf.compat.v1.train.GradientDescentOptimizer(0.01)
  variables = tf_computation_utils.create_variables('v', _mnist_model_type)
  load_model_op = tf_computation_utils.assign(variables, model)
  # The loss and the update are built under a control dependency on the
  # assignment, so the variables hold `model` before the step is taken.
  with tf.control_dependencies([load_model_op]):
    loss = _mnist_batch_loss(variables, batch)
    step_op = sgd.minimize(loss)
    # The returned values must be produced only after the update has run.
    with tf.control_dependencies([step_op]):
      updated_model = tf_computation_utils.identity(variables)
  return updated_model
示例#2
0
 def test_create_variables_with_named_tuple_type(self):
     """Checks `create_variables` on a spec mixing named and unnamed elements.

     The spec has two named elements ('x', 'y') followed by one unnamed
     element. The result is expected to be a three-element `AnonymousTuple`
     whose variables are named 'foo/x:0', 'foo/y:0' and 'foo/2:0'.
     """
     type_spec = [('x', tf.int32), ('y', tf.string), tf.bool]
     variables = tf_computation_utils.create_variables('foo', type_spec)

     self.assertIsInstance(variables, anonymous_tuple.AnonymousTuple)
     self.assertLen(variables, 3)
     # Only the two named elements are expected to be listed by dir().
     self.assertEqual(dir(variables), ['x', 'y'])

     # The unnamed third element is expected under its positional index.
     expected = (
         ('foo/x:0', tf.int32),
         ('foo/y:0', tf.string),
         ('foo/2:0', tf.bool),
     )
     for index, (name, dtype) in enumerate(expected):
         self._assertMatchesVariable(variables[index], name, (), dtype)
示例#3
0
 def test_create_variables_with_tensor_type(self):
     """Checks `create_variables` on a bare `tf.int32` type spec.

     The result is expected to be a single `tf.Variable` named 'foo:0' whose
     base dtype is `tf.int32` and whose shape equals `tf.TensorShape([])`.
     """
     variable = tf_computation_utils.create_variables('foo', tf.int32)
     self.assertIsInstance(variable, tf.Variable)

     base_dtype = variable.dtype.base_dtype
     self.assertIs(base_dtype, tf.int32)

     empty_shape = tf.TensorShape([])
     self.assertEqual(variable.shape, empty_shape)

     variable_name = str(variable.name)
     self.assertEqual(variable_name, 'foo:0')