Пример #1
0
 def test_assemble_result_from_graph_with_named_tuple(self):
   """Checks that a struct binding is reassembled into a namedtuple."""
   result_type = collections.namedtuple('_', 'X Y')
   type_spec = result_type(tf.int32, tf.int32)
   # One tensor binding per namedtuple field, keyed by graph tensor name.
   element_bindings = [
       pb.TensorFlow.Binding(
           tensor=pb.TensorFlow.TensorBinding(tensor_name=name))
       for name in ('P', 'Q')
   ]
   binding = pb.TensorFlow.Binding(
       struct=pb.TensorFlow.StructBinding(element=element_bindings))
   first_tensor = tf.constant(1, name='A')
   second_tensor = tf.constant(2, name='B')
   result = tensorflow_utils.assemble_result_from_graph(
       type_spec, binding, {'P': first_tensor, 'Q': second_tensor})
   self.assertIsInstance(result, result_type)
   self.assertEqual(result.X, first_tensor)
   self.assertEqual(result.Y, second_tensor)
Пример #2
0
 def test_assemble_result_from_graph_with_sequence_of_odicts(self):
   """Checks a sequence-of-OrderedDict result is rebuilt from a variant."""
   element_type = collections.OrderedDict([('X', tf.int32), ('Y', tf.int32)])
   type_spec = computation_types.SequenceType(element_type)
   binding = pb.TensorFlow.Binding(
       sequence=pb.TensorFlow.SequenceBinding(variant_tensor_name='foo'))
   dataset = tf.data.Dataset.from_tensors(
       {'X': tf.constant(1), 'Y': tf.constant(2)})
   # The binding refers to the dataset via its variant tensor.
   variant = tf.data.experimental.to_variant(dataset)
   result = tensorflow_utils.assemble_result_from_graph(
       type_spec, binding, {'foo': variant})
   self.assertIsInstance(result, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
   self.assertEqual(
       str(tf.compat.v1.data.get_output_types(result)),
       'OrderedDict([(\'X\', tf.int32), (\'Y\', tf.int32)])')
   self.assertEqual(
       str(tf.compat.v1.data.get_output_shapes(result)),
       'OrderedDict([(\'X\', TensorShape([])), (\'Y\', TensorShape([]))])')
Пример #3
0
 def test_assemble_result_from_graph_with_sequence_of_namedtuples(self):
   """Checks a sequence-of-namedtuple result is rebuilt from a variant."""
   element_type = collections.namedtuple('TestNamedTuple', 'X Y')
   type_spec = computation_types.SequenceType(
       element_type(tf.int32, tf.int32))
   binding = pb.TensorFlow.Binding(
       sequence=pb.TensorFlow.SequenceBinding(variant_tensor_name='foo'))
   dataset = tf.data.Dataset.from_tensors(
       {'X': tf.constant(1), 'Y': tf.constant(2)})
   # The binding refers to the dataset via its variant tensor.
   variant = tf.data.experimental.to_variant(dataset)
   result = tensorflow_utils.assemble_result_from_graph(
       type_spec, binding, {'foo': variant})
   self.assertIsInstance(result, tensorflow_utils.DATASET_REPRESENTATION_TYPES)
   self.assertEqual(
       str(tf.compat.v1.data.get_output_types(result)),
       'TestNamedTuple(X=tf.int32, Y=tf.int32)')
   self.assertEqual(
       str(tf.compat.v1.data.get_output_shapes(result)),
       'TestNamedTuple(X=TensorShape([]), Y=TensorShape([]))')
Пример #4
0
 def test_assemble_result_from_graph_with_sequence_of_odicts(self):
   """Checks a sequence-of-OrderedDict result is rebuilt from a variant."""
   element_type = collections.OrderedDict([('X', tf.int32), ('Y', tf.int32)])
   type_spec = computation_types.SequenceType(element_type)
   binding = pb.TensorFlow.Binding(
       sequence=pb.TensorFlow.SequenceBinding(variant_tensor_name='foo'))
   dataset = tf.data.Dataset.from_tensors(
       {'X': tf.constant(1), 'Y': tf.constant(2)})
   # The binding refers to the dataset via its variant tensor.
   variant = tf.data.experimental.to_variant(dataset)
   result = tensorflow_utils.assemble_result_from_graph(
       type_spec, binding, {'foo': variant})
   self.assertIsInstance(result,
                         type_conversions.TF_DATASET_REPRESENTATION_TYPES)
   expected_spec = collections.OrderedDict([
       ('X', tf.TensorSpec(shape=(), dtype=tf.int32)),
       ('Y', tf.TensorSpec(shape=(), dtype=tf.int32)),
   ])
   self.assertEqual(result.element_spec, expected_spec)
def deserialize_and_call_tf_computation(computation_proto, arg, graph):
  """Deserializes a TF computation and inserts it into `graph`.

  This method performs an action that can be considered roughly the opposite of
  what `tensorflow_serialization.serialize_py_fn_as_tf_computation` does. At
  the moment, it simply imports the graph in the current context. A future
  implementation may rely on different mechanisms. The caller should not be
  concerned with the specifics of the implementation. At this point, the method
  is expected to only be used within the body of another TF computation (within
  an instance of `tf_computation_context.TensorFlowComputationContext` at the
  top of the stack), and potentially also in certain types of interpreted
  execution contexts (TBD).

  Args:
    computation_proto: An instance of `pb.Computation` with the `computation`
      oneof equal to `tensorflow`, to be deserialized and called.
    arg: The argument to invoke the computation with, or None if the computation
      does not specify a parameter type and does not expect one.
    graph: The graph to stamp into.

  Returns:
    A tuple (init_op, result) where:
       init_op:  String name of an op to initialize the graph.
       result: The results to be fetched from TensorFlow. Depending on
           the type of the result, this can be `tf.Tensor` or `tf.data.Dataset`
           instances, or a nested structure (such as an
           `anonymous_tuple.AnonymousTuple`).

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `computation_proto` is not a TensorFlow computation proto.
  """
  py_typecheck.check_type(computation_proto, pb.Computation)
  computation_oneof = computation_proto.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    raise ValueError(
        'Expected a TensorFlow computation, got {}.'.format(computation_oneof))
  py_typecheck.check_type(graph, tf.Graph)
  with graph.as_default():
    type_spec = type_serialization.deserialize_type(computation_proto.type)
    if type_spec.parameter is None:
      if arg is None:
        input_map = None
      else:
        raise TypeError(
            'The computation declared no parameters; encountered an unexpected '
            'argument {}.'.format(arg))
    elif arg is None:
      raise TypeError(
          'The computation declared a parameter of type {}, but the argument '
          'was not supplied.'.format(type_spec.parameter))
    else:
      arg_type, arg_binding = tensorflow_utils.capture_result_from_graph(
          arg, graph)
      if not type_utils.is_assignable_from(type_spec.parameter, arg_type):
        raise TypeError(
            'The computation declared a parameter of type {}, but the argument '
            'is of a mismatching type {}.'.format(type_spec.parameter,
                                                  arg_type))
      # Resolve the computation's parameter binding to the tensors captured
      # from `arg` in the current graph, keyed by the serialized tensor names.
      input_map = {
          k: graph.get_tensor_by_name(v)
          for k, v in tensorflow_utils.compute_map_from_bindings(
              computation_proto.tensorflow.parameter, arg_binding).items()
      }
    return_elements = tensorflow_utils.extract_tensor_names_from_binding(
        computation_proto.tensorflow.result)
    orig_init_op_name = computation_proto.tensorflow.initialize_op
    if orig_init_op_name:
      # Also fetch the initializer so it gets remapped into the new graph.
      return_elements.append(orig_init_op_name)
    # N. B. Unlike MetaGraphDef, the GraphDef alone contains no information
    # about collections, and hence, when we import a graph with Variables,
    # those Variables are not added to global collections, and hence
    # functions like tf.compat.v1.global_variables_initializers() will not
    # contain their initialization ops.
    output_tensors = tf.import_graph_def(
        serialization_utils.unpack_graph_def(
            computation_proto.tensorflow.graph_def),
        input_map,
        return_elements,
        # N. B. It is very important not to return any names from the original
        # computation_proto.tensorflow.graph_def, those names might or might not
        # be valid in the current graph. Using a different scope makes the graph
        # somewhat more readable, since _N style de-duplication of graph
        # node names is less likely to be needed.
        name='subcomputation')

    output_map = dict(zip(return_elements, output_tensors))
    # Separate the (renamed) init op from the result tensors; None when the
    # computation declared no initialize_op.
    new_init_op_name = output_map.pop(orig_init_op_name, None)
    return (new_init_op_name,
            tensorflow_utils.assemble_result_from_graph(
                type_spec.result, computation_proto.tensorflow.result,
                output_map))