def _mnist_batch_train(model, batch):
  """Builds one gradient-descent update of `model` on `batch`.

  Fresh variables (name prefix 'v', type `_mnist_model_type`) are created and
  `model` is assigned into them. `_mnist_batch_loss` on `batch` is then
  minimized with plain gradient descent at learning rate 0.01. The variables'
  values, read after the update op, are returned.

  NOTE(review): presumably traced in a TF1-style graph context (e.g. inside a
  tf_computation); confirm against the caller/decorator.

  Args:
    model: Initial model values to assign into the created variables.
    batch: The batch passed to `_mnist_batch_loss`.

  Returns:
    The result of `tf_computation_utils.identity` on the variables, created
    after the training op.
  """
  sgd = tf.compat.v1.train.GradientDescentOptimizer(0.01)
  variables = tf_computation_utils.create_variables('v', _mnist_model_type)
  init_op = tf_computation_utils.assign(variables, model)

  # Build the loss and the update only after the incoming weights are copied
  # into the variables.
  with tf.control_dependencies([init_op]):
    loss = _mnist_batch_loss(variables, batch)
    step_op = sgd.minimize(loss)

  # Read the variables back only once the update has been applied.
  with tf.control_dependencies([step_op]):
    updated = tf_computation_utils.identity(variables)
  return updated
def test_create_variables_with_named_tuple_type(self):
  """Checks create_variables on a tuple type mixing named and unnamed elements.

  The result should be an AnonymousTuple of three scalar variables. `dir`
  lists only the named elements. The unnamed third element's variable is
  named by its index ('foo/2:0').
  """
  element_types = [('x', tf.int32), ('y', tf.string), tf.bool]
  variables = tf_computation_utils.create_variables('foo', element_types)
  self.assertIsInstance(variables, anonymous_tuple.AnonymousTuple)
  self.assertLen(variables, 3)
  self.assertEqual(dir(variables), ['x', 'y'])
  expected = [
      ('foo/x:0', tf.int32),
      ('foo/y:0', tf.string),
      ('foo/2:0', tf.bool),
  ]
  for index, (var_name, dtype) in enumerate(expected):
    self._assertMatchesVariable(variables[index], var_name, (), dtype)
def test_create_variables_with_tensor_type(self):
  """Checks create_variables on a plain tensor type.

  A single scalar tf.Variable of the requested dtype should be returned. Its
  name should be exactly the given prefix ('foo:0').
  """
  variable = tf_computation_utils.create_variables('foo', tf.int32)
  self.assertIsInstance(variable, tf.Variable)
  self.assertIs(variable.dtype.base_dtype, tf.int32)
  scalar_shape = tf.TensorShape([])
  self.assertEqual(variable.shape, scalar_shape)
  expected_name = 'foo:0'
  self.assertEqual(str(variable.name), expected_name)