Example #1
def test_extract_tensor_names_from_binding_with_tuple_of_tensors(self):
    with tf.Graph().as_default() as graph:
        _, binding = graph_utils.capture_result_from_graph(
            collections.OrderedDict([('foo', tf.constant(10, name='A')),
                                     ('bar', tf.constant(20, name='B'))]),
            graph)
    result = graph_utils.extract_tensor_names_from_binding(binding)
    self.assertEqual(str(sorted(result)), "['A:0', 'B:0']")
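The test above goes through a graph to obtain the binding. For comparison, here is a minimal sketch that builds a single-tensor binding by hand; it assumes `pb` aliases the TFF v0 computation protos as in the other examples on this page:

# A minimal sketch, assuming `pb` is the computation-proto module used in the
# test above; bindings can also be constructed directly, without a graph.
binding = pb.TensorFlow.Binding(
    tensor=pb.TensorFlow.TensorBinding(tensor_name='A:0'))
assert graph_utils.extract_tensor_names_from_binding(binding) == ['A:0']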
Example #2
def prune_tensorflow_proto(proto):
    """Extracts subgraph from `proto` preserving parameter, result and initialize.

  Args:
    proto: Instance of `pb.Computation` of the `tensorflow` variety whose
      `graphdef` attribute we wish to prune of extraneous ops.

  Returns:
    A transformed instance of `pb.Computation` of the `tensorflow` variety,
    whose `graphdef` attribute contains only ops which can reach the
    parameter or result bindings, or initialize op.
  """
    py_typecheck.check_type(proto, pb.Computation)
    computation_oneof = proto.WhichOneof('computation')
    if computation_oneof != 'tensorflow':
        raise TypeError(
            '`prune_tensorflow_proto` only accepts `Computation` '
            'protos of the \'tensorflow\' variety; you have passed '
            'one of variety {}.'.format(computation_oneof))
    if proto.tensorflow.parameter.WhichOneof('binding'):
        parameter_tensor_names = graph_utils.extract_tensor_names_from_binding(
            proto.tensorflow.parameter)
        parameter_names = [
            ':'.join(x.split(':')[:-1]) for x in parameter_tensor_names
        ]
    else:
        parameter_names = []
    return_tensor_names = graph_utils.extract_tensor_names_from_binding(
        proto.tensorflow.result)
    return_names = [':'.join(x.split(':')[:-1]) for x in return_tensor_names]
    graph_def = serialization_utils.unpack_graph_def(
        proto.tensorflow.graph_def)
    init_op_name = proto.tensorflow.initialize_op
    names_to_preserve = parameter_names + return_names
    if init_op_name:
        names_to_preserve.append(init_op_name)
    subgraph_def = tf.compat.v1.graph_util.extract_sub_graph(
        graph_def, names_to_preserve)
    tf_block = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(subgraph_def),
        initialize_op=proto.tensorflow.initialize_op,
        parameter=proto.tensorflow.parameter,
        result=proto.tensorflow.result)
    pruned_proto = pb.Computation(type=proto.type, tensorflow=tf_block)
    return pruned_proto
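A hedged usage sketch follows; `comp_pb` is a hypothetical name for a `pb.Computation` of the `tensorflow` variety obtained from whatever serialization path produced it:

# Hypothetical usage: `comp_pb` (an assumed name) holds a serialized
# `tensorflow` computation. Pruning drops ops unreachable from the
# parameter/result bindings and the initialize op.
pruned = prune_tensorflow_proto(comp_pb)
# The type signature and bindings carry over unchanged; only the graph shrinks.
assert pruned.type == comp_pb.type
assert pruned.tensorflow.result == comp_pb.tensorflow.result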
Example #3
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional device name.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation.
  """
    # TODO(b/134543154): Decide whether this belongs in `graph_utils.py` since
    # it deals exclusively with eager mode. Incubate here, and potentially move
    # there, once stable.

    if device is not None:
        raise NotImplementedError(
            'Unable to embed TF code on a specific device.')

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is not None:
        if not type_utils.are_equivalent_types(type_spec, comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    str(type_spec), str(comp_type)))
    else:
        type_spec = comp_type
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    if isinstance(type_spec, computation_types.FunctionType):
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    if param_type is not None:
        input_tensor_names = graph_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    output_tensor_names = graph_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)

    def function_to_wrap(*args):  # pylint: disable=missing-docstring
        if len(args) != len(input_tensor_names):
            raise RuntimeError('Expected {} arguments, found {}.'.format(
                str(len(input_tensor_names)), str(len(args))))
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        init_names = [init_op] if init_op else []
        returned_elements = tf.import_graph_def(
            graph_merge.uniquify_shared_names(graph_def),
            input_map=dict(zip(input_tensor_names, args)),
            return_elements=output_tensor_names + init_names)
        if init_names:
            with tf.control_dependencies([returned_elements[-1]]):
                return [tf.identity(x) for x in returned_elements[0:-1]]
        else:
            return returned_elements

    signature = []
    param_fns = []
    if param_type is not None:
        for spec in anonymous_tuple.flatten(type_spec.parameter):
            if isinstance(spec, computation_types.TensorType):
                signature.append(tf.TensorSpec(spec.shape, spec.dtype))
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                signature.append(tf.TensorSpec([], tf.variant))
                param_fns.append(tf.data.experimental.to_variant)

    wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

    result_fns = []
    for spec in anonymous_tuple.flatten(result_type):
        if isinstance(spec, computation_types.TensorType):
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            structure = type_utils.type_to_tf_structure(spec.element)

            def fn(x, structure=structure):
                return tf.data.experimental.from_variant(x, structure)

            result_fns.append(fn)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    str(len(param_fns)), str(len(arg_parts))))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)
        result_elements = []
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)
    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
Example #4
def deserialize_and_call_tf_computation(computation_proto, arg, graph):
    """Deserializes a TF computation and inserts it into `graph`.

  This method performs an action that can be considered roughly the opposite of
  what `tensorflow_serialization.serialize_py_fn_as_tf_computation` does. At
  the moment, it simply imports the graph in the current context. A future
  implementation may rely on different mechanisms. The caller should not be
  concerned with the specifics of the implementation. At this point, the method
  is expected to only be used within the body of another TF computation (within
  an instance of `tf_computation_context.TensorFlowComputationContext` at the
  top of the stack), and potentially also in certain types of interpreted
  execution contexts (TBD).

  Args:
    computation_proto: An instance of `pb.Computation` with the `computation`
      one of equal to `tensorflow` to be deserialized and called.
    arg: The argument to invoke the computation with, or None if the computation
      does not specify a parameter type and does not expects one.
    graph: The graph to stamp into.

  Returns:
    A tuple (init_op, result) where:
       init_op:  String name of an op to initialize the graph.
       result: The results to be fetched from TensorFlow. Depending on
           the type of the result, this can be `tf.Tensor` or `tf.data.Dataset`
           instances, or a nested structure (such as an
           `anonymous_tuple.AnonymousTuple`).

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `computation_proto` is not a TensorFlow computation proto.
  """
    py_typecheck.check_type(computation_proto, pb.Computation)
    computation_oneof = computation_proto.WhichOneof('computation')
    if computation_oneof != 'tensorflow':
        raise ValueError('Expected a TensorFlow computation, got {}.'.format(
            computation_oneof))
    py_typecheck.check_type(graph, tf.Graph)
    with graph.as_default():
        type_spec = type_serialization.deserialize_type(computation_proto.type)
        if type_spec.parameter is None:
            if arg is None:
                input_map = None
            else:
                raise TypeError(
                    'The computation declared no parameters; encountered an unexpected '
                    'argument {}.'.format(str(arg)))
        elif arg is None:
            raise TypeError(
                'The computation declared a parameter of type {}, but the argument '
                'was not supplied.'.format(str(type_spec.parameter)))
        else:
            arg_type, arg_binding = graph_utils.capture_result_from_graph(
                arg, graph)
            if not type_utils.is_assignable_from(type_spec.parameter,
                                                 arg_type):
                raise TypeError(
                    'The computation declared a parameter of type {}, but the argument '
                    'is of a mismatching type {}.'.format(
                        str(type_spec.parameter), str(arg_type)))
            else:
                input_map = {
                    k: graph.get_tensor_by_name(v)
                    for k, v in six.iteritems(
                        graph_utils.compute_map_from_bindings(
                            computation_proto.tensorflow.parameter,
                            arg_binding))
                }
        return_elements = graph_utils.extract_tensor_names_from_binding(
            computation_proto.tensorflow.result)
        orig_init_op_name = computation_proto.tensorflow.initialize_op
        if orig_init_op_name:
            return_elements.append(orig_init_op_name)
        # N. B. Unlike MetaGraphDef, the GraphDef alone contains no information
        # about collections, and hence, when we import a graph with Variables,
        # those Variables are not added to global collections, and hence
        # functions like tf.compat.v1.global_variables_initializer() will not
        # contain their initialization ops.
        output_tensors = tf.import_graph_def(
            serialization_utils.unpack_graph_def(
                computation_proto.tensorflow.graph_def),
            input_map,
            return_elements,
            # N. B. It is very important not to return any names from the original
            # computation_proto.tensorflow.graph_def, those names might or might not
            # be valid in the current graph. Using a different scope makes the graph
            # somewhat more readable, since _N style de-duplication of graph
            # node names is less likely to be needed.
            name='subcomputation')

        output_map = dict(zip(return_elements, output_tensors))
        new_init_op_name = output_map.pop(orig_init_op_name, None)
        return (new_init_op_name,
                graph_utils.assemble_result_from_graph(
                    type_spec.result, computation_proto.tensorflow.result,
                    output_map))
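Since this function only stamps ops into a graph, nothing runs until a session executes them. A hedged sketch, with `comp_pb` again an assumed name for a one-parameter `tensorflow` computation whose result is a plain tensor:

# Hypothetical usage: stamp an assumed `comp_pb` into a fresh graph, then run
# it in a tf.compat.v1.Session; `init_op` must be run first if it is set.
with tf.Graph().as_default() as graph:
    arg = tf.constant(10)
    init_op, result = deserialize_and_call_tf_computation(comp_pb, arg, graph)
with tf.compat.v1.Session(graph=graph) as sess:
    if init_op:
        sess.run(init_op)
    print(sess.run(result))  # Assumes the result is a single tf.Tensor.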
Example #5
def test_extract_tensor_names_from_binding_with_sequence(self):
    binding = pb.TensorFlow.Binding(sequence=pb.TensorFlow.SequenceBinding(
        iterator_string_handle_name='foo'))
    result = graph_utils.extract_tensor_names_from_binding(binding)
    self.assertEqual(str(sorted(result)), "['foo']")