Exemplo n.º 1
0
def _mnist_batch_train(model, batch):
  """Builds graph ops that take one gradient-descent step on `batch`.

  The incoming `model` values are first assigned into variables obtained via
  `tf_computation_utils.get_variables`. A single optimizer step (learning rate
  0.01) is then taken on the loss produced by `_mnist_batch_loss`, and the
  variables are read back after that step.

  Args:
    model: Values assigned into the model variables before training.
      Presumably structured like `_mnist_model_type` (TODO confirm against
      callers).
    batch: Passed unchanged to `_mnist_batch_loss`; its structure is not
      visible here.

  Returns:
    The result of `tf_computation_utils.identity(model_vars)`. It is created
    under a control dependency on the training op, so its reads happen after
    the step has run.
  """
  # Plain gradient descent with a fixed learning rate of 0.01.
  optimizer = tf.train.GradientDescentOptimizer(0.01)
  # Variables matching `_mnist_model_type`, created under the name 'v'.
  model_vars = tf_computation_utils.get_variables('v', _mnist_model_type)
  # Load the incoming model values into those variables.
  assign_vars_op = tf_computation_utils.assign(model_vars, model)
  # Ops built inside this context (the loss and the training step) may only
  # run after the assignment, so the gradient step starts from `model`.
  with tf.control_dependencies([assign_vars_op]):
    train_op = optimizer.minimize(_mnist_batch_loss(model_vars, batch))
    # Read the variables back only once the training step has executed.
    with tf.control_dependencies([train_op]):
      return tf_computation_utils.identity(model_vars)
Exemplo n.º 2
0
 def test_get_variables_with_named_tuple_type(self):
     """Checks get_variables on a tuple type with named and unnamed elements.

     Each element becomes its own scalar variable under the 'foo' name.
     Named elements use their element name ('foo/x', 'foo/y') and the
     unnamed element uses its position ('foo/2'). The variables' properties
     are checked directly instead of comparing `str(x)`, because the text
     of the `tf.Variable` repr is not a stable format.
     """
     x = tf_computation_utils.get_variables('foo',
                                            [('x', tf.int32),
                                             ('y', tf.string), tf.bool])
     self.assertIsInstance(x, anonymous_tuple.AnonymousTuple)
     elements = anonymous_tuple.to_elements(x)
     # The third element of the requested type has no name.
     self.assertEqual([name for name, _ in elements], ['x', 'y', None])
     expected = [('foo/x:0', tf.int32),
                 ('foo/y:0', tf.string),
                 ('foo/2:0', tf.bool)]
     for (_, variable), (var_name, dtype) in zip(elements, expected):
         self.assertIsInstance(variable, tf.Variable)
         self.assertIs(variable.dtype.base_dtype, dtype)
         self.assertEqual(variable.shape, tf.TensorShape([]))
         self.assertEqual(str(variable.name), var_name)
Exemplo n.º 3
0
 def test_get_variables_with_named_tuple_type(self):
   """Checks that each tuple element becomes a scalar variable under 'foo'."""
   tuple_type = [('x', tf.int32), ('y', tf.string), tf.bool]
   variables = tf_computation_utils.get_variables('foo', tuple_type)
   self.assertIsInstance(variables, anonymous_tuple.AnonymousTuple)
   self.assertLen(variables, 3)
   # Only the first two elements of the requested type carry names.
   self.assertEqual(dir(variables), ['x', 'y'])
   expected = (
       ('foo/x:0', tf.int32),
       ('foo/y:0', tf.string),
       # The unnamed element's variable is named after its index.
       ('foo/2:0', tf.bool),
   )
   for index, (var_name, var_dtype) in enumerate(expected):
     self._assertMatchesVariable(variables[index], var_name, (), var_dtype)
Exemplo n.º 4
0
 def test_get_variables_with_tensor_type(self):
     """A bare tensor type yields one scalar `tf.Variable` named 'foo'."""
     variable = tf_computation_utils.get_variables('foo', tf.int32)
     self.assertIsInstance(variable, tf.Variable)
     scalar_shape = tf.TensorShape([])
     self.assertIs(variable.dtype.base_dtype, tf.int32)
     self.assertEqual(variable.shape, scalar_shape)
     self.assertEqual(str(variable.name), 'foo:0')