示例#1
0
def run_tensorflow(computation_proto, arg=None):
    """Executes the TensorFlow computation `computation_proto` on `arg`.

    The computation is deserialized into a brand-new graph, its argument (if
    the type signature declares a parameter) is stamped into that graph, and
    the result is fetched from a session bound to the same graph.

    Args:
      computation_proto: An instance of `pb.Computation`.
      arg: The value to pass to the computation, or None when the computation
        declares no parameter type and therefore does not expect one.

    Returns:
      The value obtained via `tensorflow_utils.fetch_value_in_session` for the
      computation's result.
    """
    graph = tf.Graph()
    with graph.as_default():
        signature = type_serialization.deserialize_type(computation_proto.type)
        parameter_type = signature.parameter
        # Only stamp `arg` when the signature declares a parameter; otherwise
        # the computation is invoked with no argument at all.
        stamped_arg = (
            _stamp_value_into_graph(arg, parameter_type, graph)
            if parameter_type is not None else None)
        init_op, output = (
            tensorflow_deserialization.deserialize_and_call_tf_computation(
                computation_proto, stamped_arg, graph))
    with tf.compat.v1.Session(graph=graph) as session:
        if init_op:
            session.run(init_op)
        return tensorflow_utils.fetch_value_in_session(session, output)
示例#2
0
def run_tensorflow(comp, arg):
  """Executes the compiled TensorFlow computation `comp` on `arg`.

  Args:
    comp: A `building_blocks.CompiledComputation` that embeds TensorFlow code.
    arg: A `ComputedValue` holding the argument, or `None` when the
      computation expects no argument.

  Returns:
    A `ComputedValue` built from the fetched result, using the result type of
    `comp.type_signature`.

  Raises:
    Whatever `py_typecheck.check_type` raises (presumably `TypeError`) when
    `comp` or a non-None `arg` has the wrong type.
  """
  py_typecheck.check_type(comp, building_blocks.CompiledComputation)
  if arg is not None:
    py_typecheck.check_type(arg, ComputedValue)
  graph = tf.Graph()
  with graph.as_default():
    graph_arg = stamp_computed_value_into_graph(arg, graph)
    deserialize = tensorflow_deserialization.deserialize_and_call_tf_computation
    init_op, result_tensors = deserialize(comp.proto, graph_arg, graph)
  with tf.compat.v1.Session(graph=graph) as session:
    if init_op:
      session.run(init_op)
    fetched = tensorflow_utils.fetch_value_in_session(session, result_tensors)
  result_type = comp.type_signature.result
  return capture_computed_value_from_graph(fetched, result_type)
示例#3
0
 def test_deserialize_and_call_tf_computation_with_add_one(self):
   """Re-stamps a compiled int32 identity and checks it echoes its input.

   NOTE(review): despite the `_with_add_one` suffix, this exercises an
   identity computation (input 10, expected output 10). The name is kept so
   selecting the test by name keeps working. If another method with the same
   name lives in the same class, the later definition shadows this one.
   """
   identity = building_block_factory.create_compiled_identity(tf.int32)
   ten = tf.constant(10)
   init_op, output = (
       tensorflow_deserialization.deserialize_and_call_tf_computation(
           identity.proto, ten, tf.compat.v1.get_default_graph()))
   self.assertTrue(tf.is_tensor(output))
   with tf.compat.v1.Session() as session:
     if init_op:
       session.run(init_op)
     value = session.run(output)
   self.assertEqual(value, 10)
 def invoke(self, comp, arg):
     """Deserializes the TensorFlow computation `comp` into `self._graph`.

     This path handles a `tff.tf_computation` invoked from inside another
     tf_computation. Any initializer op produced by the deserialization is
     appended to `self._init_ops` instead of being run here.

     Args:
       comp: A `computation_base.Computation`.
       arg: The argument forwarded to the deserialized computation.

     Returns:
       The result returned by `deserialize_and_call_tf_computation`.
     """
     py_typecheck.check_type(comp, computation_base.Computation)
     proto = computation_impl.ComputationImpl.get_proto(comp)
     deserialize_and_call = (
         tensorflow_deserialization.deserialize_and_call_tf_computation)
     init_op, result = deserialize_and_call(proto, arg, self._graph)
     # Presumably the owning context runs the collected init ops later;
     # confirm against the rest of the class.
     if init_op:
         self._init_ops.append(init_op)
     return result
 def test_deserialize_and_call_tf_computation_with_add_one(self):
   """Serializes `x + 1` as a TF computation, re-stamps it, and expects 11.

   Uses `tf.compat.v1.get_default_graph`, `tf.compat.v1.Session` and
   `tf.is_tensor` rather than `tf.get_default_graph`, `tf.Session` and
   `tf.contrib.framework.is_tensor`. The latter three are not available in
   TensorFlow 2 (`tf.contrib` was removed entirely).
   """
   ctx_stack = context_stack_impl.context_stack
   add_one = tensorflow_serialization.serialize_py_func_as_tf_computation(
       lambda x: tf.add(x, 1, name='the_add'), tf.int32, ctx_stack)
   init_op, result = (
       tensorflow_deserialization.deserialize_and_call_tf_computation(
           add_one, tf.constant(10, name='the_ten'),
           tf.compat.v1.get_default_graph()))
   self.assertTrue(tf.is_tensor(result))
   # The session binds to the same default graph the computation was
   # stamped into above.
   with tf.compat.v1.Session() as sess:
     if init_op:
       sess.run(init_op)
     result_val = sess.run(result)
   self.assertEqual(result_val, 11)