def run_tensorflow(computation_proto, arg=None):
    """Executes the serialized TensorFlow computation on `arg`.

    A fresh graph is created. If the computation's deserialized type
    signature declares a parameter, `arg` is stamped into that graph;
    otherwise no argument is passed. The computation is then deserialized
    and called in the graph. Finally, a new session runs the initializer op
    (when one is returned) and fetches the result.

    Args:
      computation_proto: An instance of `pb.Computation`.
      arg: The value to invoke the computation with, or `None` when the
        computation declares no parameter type and does not expect one.

    Returns:
      The result of the computation, as fetched from the session.
    """
    with tf.Graph().as_default() as graph:
        signature = type_serialization.deserialize_type(computation_proto.type)
        param_type = signature.parameter
        # Only stamp the argument when the computation actually takes one.
        stamped_arg = (
            None if param_type is None
            else _stamp_value_into_graph(arg, param_type, graph))
        init_op, result_tensors = (
            tensorflow_deserialization.deserialize_and_call_tf_computation(
                computation_proto, stamped_arg, graph))
    with tf.compat.v1.Session(graph=graph) as sess:
        if init_op:
            sess.run(init_op)
        fetched = tensorflow_utils.fetch_value_in_session(sess, result_tensors)
    return fetched
def run_tensorflow(comp, arg):
    """Evaluates the compiled TensorFlow computation `comp` on `arg`.

    Args:
      comp: A `building_blocks.CompiledComputation` whose proto embeds
        TensorFlow code.
      arg: A `ComputedValue` holding the argument, or `None` if the
        computation expects no argument.

    Returns:
      A `ComputedValue` built from the fetched result and typed by
      `comp.type_signature.result`.
    """
    py_typecheck.check_type(comp, building_blocks.CompiledComputation)
    if arg is not None:
        py_typecheck.check_type(arg, ComputedValue)

    with tf.Graph().as_default() as graph:
        graph_arg = stamp_computed_value_into_graph(arg, graph)
        deserialized = (
            tensorflow_deserialization.deserialize_and_call_tf_computation(
                comp.proto, graph_arg, graph))
        init_op, result_tensors = deserialized

    with tf.compat.v1.Session(graph=graph) as sess:
        if init_op:
            sess.run(init_op)
        fetched = tensorflow_utils.fetch_value_in_session(sess, result_tensors)

    result_type = comp.type_signature.result
    return capture_computed_value_from_graph(fetched, result_type)
def test_deserialize_and_call_tf_computation_with_add_one(self):
    """Checks that a deserialized compiled int32 identity echoes its input.

    Despite the method name, the computation exercised here is the compiled
    identity from `create_compiled_identity`. The expected output therefore
    equals the input constant (10).
    """
    identity_comp = building_block_factory.create_compiled_identity(tf.int32)
    ten = tf.constant(10)
    init_op, result = (
        tensorflow_deserialization.deserialize_and_call_tf_computation(
            identity_comp.proto, ten, tf.compat.v1.get_default_graph()))
    self.assertTrue(tf.is_tensor(result))

    with tf.compat.v1.Session() as sess:
        if init_op:
            sess.run(init_op)
        fetched = sess.run(result)
    self.assertEqual(fetched, 10)
def invoke(self, comp, arg):
    """Calls the TensorFlow computation `comp` inside this context's graph.

    This path is taken when a `tff.tf_computation` is invoked from within
    another tf_computation. The callee's proto is deserialized and called in
    `self._graph`. Any initializer op it returns is collected in
    `self._init_ops` (presumably run later by the owner of this context —
    confirm against callers).

    Args:
      comp: A `computation_base.Computation` to invoke.
      arg: The argument forwarded to the deserialized computation.

    Returns:
      The result returned by `deserialize_and_call_tf_computation`.
    """
    py_typecheck.check_type(comp, computation_base.Computation)
    proto = computation_impl.ComputationImpl.get_proto(comp)
    init_op, outputs = (
        tensorflow_deserialization.deserialize_and_call_tf_computation(
            proto, arg, self._graph))
    if init_op:
        self._init_ops.append(init_op)
    return outputs
def test_deserialize_and_call_tf_computation_with_add_one(self):
    """Round-trips an `x + 1` computation and checks that 10 maps to 11.

    The Python function is serialized as a TF computation, then deserialized
    and called on a constant in the default graph. The result is evaluated
    in a session after running the initializer op, if one is returned.
    """
    ctx_stack = context_stack_impl.context_stack
    add_one = tensorflow_serialization.serialize_py_func_as_tf_computation(
        lambda x: tf.add(x, 1, name='the_add'), tf.int32, ctx_stack)
    init_op, result = (
        tensorflow_deserialization.deserialize_and_call_tf_computation(
            add_one, tf.constant(10, name='the_ten'),
            tf.compat.v1.get_default_graph()))
    # `tf.contrib` was removed in TF 2.x; `tf.is_tensor` is the supported
    # replacement. The graph/session APIs are reached via `tf.compat.v1`.
    self.assertTrue(tf.is_tensor(result))
    with tf.compat.v1.Session() as sess:
        if init_op:
            sess.run(init_op)
        result_val = sess.run(result)
    self.assertEqual(result_val, 11)