Example #1
0
def run_tensorflow(computation_proto, arg=None):
    """Executes a serialized TensorFlow computation, optionally on an argument.

    Args:
      computation_proto: An instance of `pb.Computation`.
      arg: The value to invoke the computation with, or None when the
        computation declares no parameter type and does not expect one.

    Returns:
      The value produced by running the computation in a TF session.
    """
    with tf.Graph().as_default() as graph:
        signature = type_serialization.deserialize_type(computation_proto.type)
        param_type = signature.parameter
        # Only stamp `arg` into the graph when the computation declares a
        # parameter; otherwise the computation is invoked with None.
        graph_arg = (None if param_type is None else
                     _stamp_value_into_graph(arg, param_type, graph))
        init_op, output_tensors = (
            tensorflow_utils.deserialize_and_call_tf_computation(
                computation_proto, graph_arg, graph))
    with tf.compat.v1.Session(graph=graph) as session:
        # The deserializer may hand back an initializer that must run first.
        if init_op:
            session.run(init_op)
        fetched = tensorflow_utils.fetch_value_in_session(session,
                                                          output_tensors)
    return fetched
def run_tensorflow(comp, arg):
  """Executes the TensorFlow computation embedded in `comp` on `arg`.

  Args:
    comp: A `building_blocks.CompiledComputation` carrying serialized
      TensorFlow code.
    arg: A `ComputedValue` holding the argument, or `None` if the
      computation expects no argument.

  Returns:
    A `ComputedValue` wrapping the computation's result.
  """
  py_typecheck.check_type(comp, building_blocks.CompiledComputation)
  has_arg = arg is not None
  if has_arg:
    py_typecheck.check_type(arg, ComputedValue)
  with tf.Graph().as_default() as graph:
    graph_arg = stamp_computed_value_into_graph(arg, graph)
    deserialize_and_call = (
        tensorflow_deserialization.deserialize_and_call_tf_computation)
    init_op, output_tensors = deserialize_and_call(comp.proto, graph_arg,
                                                   graph)
  with tf.compat.v1.Session(graph=graph) as session:
    # Run the initializer (if the deserializer produced one) before fetching.
    if init_op:
      session.run(init_op)
    fetched = tensorflow_utils.fetch_value_in_session(session, output_tensors)
  result_type = comp.type_signature.result
  return capture_computed_value_from_graph(fetched, result_type)
Example #3
0
 def test_fetch_value_in_session_with_empty_structure(self):
   # A nested struct whose only leaf is an empty struct; the fetched value is
   # expected to stringify with the same shape.
   empty_leaf = structure.Struct([])
   inner = structure.Struct([('b', empty_leaf)])
   nested = structure.Struct([('a', inner)])
   with tf.compat.v1.Session() as session:
     fetched = tensorflow_utils.fetch_value_in_session(session, nested)
   self.assertEqual(str(fetched), '<a=<b=<>>>')
Example #4
0
 def test_fetch_value_in_session_without_data_sets(self):
   # A nested struct holding a plain constant tensor; fetching it in a
   # session is expected to yield the evaluated value 10 at the leaf.
   leaf = tf.constant(10)
   inner = structure.Struct([('b', leaf)])
   nested = structure.Struct([('a', inner)])
   with tf.compat.v1.Session() as session:
     fetched = tensorflow_utils.fetch_value_in_session(session, nested)
   self.assertEqual(str(fetched), '<a=<b=10>>')
 def test_fetch_value_in_session_without_data_sets(self):
   # NOTE(review): a method named test_fetch_value_in_session_without_data_sets
   # (built with structure.Struct) is also defined earlier in this source. If
   # both sit in the same class, this later definition replaces the earlier
   # one and only this version runs -- confirm and rename one of them if so.
   # Builds a nested AnonymousTuple with a constant tensor leaf and expects
   # the fetched value to stringify with the evaluated value 10.
   leaf = tf.constant(10)
   inner = anonymous_tuple.AnonymousTuple([('b', leaf)])
   nested = anonymous_tuple.AnonymousTuple([('a', inner)])
   with tf.compat.v1.Session() as session:
     fetched = tensorflow_utils.fetch_value_in_session(session, nested)
   self.assertEqual(str(fetched), '<a=<b=10>>')
 def test_fetch_value_in_session_with_partially_empty_structure(self):
   # Mixes an empty sub-tuple with a constant tensor under the same parent;
   # the empty part should stay empty and the tensor should be evaluated.
   empty_part = anonymous_tuple.AnonymousTuple([])
   tensor_part = tf.constant(10)
   inner = anonymous_tuple.AnonymousTuple([
       ('b', empty_part),
       ('c', tensor_part),
   ])
   nested = anonymous_tuple.AnonymousTuple([('a', inner)])
   with tf.compat.v1.Session() as session:
     fetched = tensorflow_utils.fetch_value_in_session(session, nested)
   self.assertEqual(str(fetched), '<a=<b=<>,c=10>>')
Example #7
0
 def test_fetch_value_in_session_with_string(self):
   # A scalar string tensor; the test expects the fetched value to
   # stringify as 'abc'.
   text_tensor = tf.constant('abc')
   with tf.compat.v1.Session() as session:
     fetched = tensorflow_utils.fetch_value_in_session(session, text_tensor)
   self.assertEqual(str(fetched), 'abc')