def run_tensorflow(computation_proto, arg=None):
    """Executes a serialized TensorFlow computation on the argument `arg`.

    Args:
        computation_proto: An instance of `pb.Computation`.
        arg: The argument to invoke the computation with, or None if the
            computation does not specify a parameter type and does not
            expect one.

    Returns:
        The result of the computation.
    """
    with tf.Graph().as_default() as graph:
        signature = type_serialization.deserialize_type(
            computation_proto.type)
        parameter_type = signature.parameter
        # `arg` is stamped into the graph only when the deserialized type
        # signature declares a parameter; otherwise nothing is passed.
        graph_arg = (
            None if parameter_type is None
            else _stamp_value_into_graph(arg, parameter_type, graph)
        )
        init_op, graph_result = (
            tensorflow_utils.deserialize_and_call_tf_computation(
                computation_proto, graph_arg, graph))

    with tf.compat.v1.Session(graph=graph) as sess:
        # Run the initialization op first, if deserialization yielded one.
        if init_op:
            sess.run(init_op)
        fetched = tensorflow_utils.fetch_value_in_session(sess, graph_result)
    return fetched
def run_tensorflow(comp, arg):
    """Executes the compiled TensorFlow computation `comp` on `arg`.

    Args:
        comp: An instance of `building_blocks.CompiledComputation` with
            embedded TensorFlow code.
        arg: An instance of `ComputedValue` that represents the argument, or
            `None` if the computation expects no argument.

    Returns:
        An instance of `ComputedValue` with the result.
    """
    py_typecheck.check_type(comp, building_blocks.CompiledComputation)
    # `None` is allowed and skips the type check.
    if arg is not None:
        py_typecheck.check_type(arg, ComputedValue)

    with tf.Graph().as_default() as graph:
        graph_arg = stamp_computed_value_into_graph(arg, graph)
        deserialize_and_call = (
            tensorflow_deserialization.deserialize_and_call_tf_computation)
        init_op, graph_result = deserialize_and_call(
            comp.proto, graph_arg, graph)

    with tf.compat.v1.Session(graph=graph) as sess:
        # Run the initialization op first, if deserialization yielded one.
        if init_op:
            sess.run(init_op)
        fetched = tensorflow_utils.fetch_value_in_session(sess, graph_result)

    result_type = comp.type_signature.result
    return capture_computed_value_from_graph(fetched, result_type)
def test_fetch_value_in_session_with_empty_structure(self):
    """Fetches a nested struct whose innermost element is an empty struct."""
    empty = structure.Struct([])
    inner = structure.Struct([('b', empty)])
    value = structure.Struct([('a', inner)])
    with tf.compat.v1.Session() as sess:
        fetched = tensorflow_utils.fetch_value_in_session(sess, value)
    # The string form of the fetched value must keep the nested layout.
    self.assertEqual(str(fetched), '<a=<b=<>>>')
def test_fetch_value_in_session_without_data_sets(self):
    """Fetches a nested struct that holds a single constant tensor."""
    inner = structure.Struct([('b', tf.constant(10))])
    value = structure.Struct([('a', inner)])
    with tf.compat.v1.Session() as sess:
        fetched = tensorflow_utils.fetch_value_in_session(sess, value)
    # The constant should appear as its value in the string form.
    self.assertEqual(str(fetched), '<a=<b=10>>')
def test_fetch_value_in_session_without_data_sets(self):
    """Fetches a nested anonymous tuple that holds a single constant tensor."""
    inner = anonymous_tuple.AnonymousTuple([('b', tf.constant(10))])
    value = anonymous_tuple.AnonymousTuple([('a', inner)])
    with tf.compat.v1.Session() as sess:
        fetched = tensorflow_utils.fetch_value_in_session(sess, value)
    # The constant should appear as its value in the string form.
    self.assertEqual(str(fetched), '<a=<b=10>>')
def test_fetch_value_in_session_with_partially_empty_structure(self):
    """Fetches a nested tuple mixing an empty tuple with a constant tensor."""
    empty = anonymous_tuple.AnonymousTuple([])
    inner = anonymous_tuple.AnonymousTuple([
        ('b', empty),
        ('c', tf.constant(10)),
    ])
    value = anonymous_tuple.AnonymousTuple([('a', inner)])
    with tf.compat.v1.Session() as sess:
        fetched = tensorflow_utils.fetch_value_in_session(sess, value)
    # Both the empty element and the constant must show up in the output.
    self.assertEqual(str(fetched), '<a=<b=<>,c=10>>')
def test_fetch_value_in_session_with_string(self):
    """Fetches a string constant tensor and checks its string form."""
    string_tensor = tf.constant('abc')
    with tf.compat.v1.Session() as sess:
        fetched = tensorflow_utils.fetch_value_in_session(
            sess, string_tensor)
    self.assertEqual(str(fetched), 'abc')