Code example #1
def test_extract_tensor_names_from_binding_with_sequence(self):
  binding = pb.TensorFlow.Binding(
      sequence=pb.TensorFlow.SequenceBinding(variant_tensor_name='foo'))
  result = tensorflow_utils.extract_tensor_names_from_binding(binding)
  self.assertEqual(str(sorted(result)), '[\'foo\']')
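
For reference, the same helper also flattens nested bindings. Below is a minimal companion sketch, assuming the `NamedTupleBinding` and `TensorBinding` messages from the same proto (the exact field names are an assumption and may differ across versions):

def test_extract_tensor_names_from_binding_with_tuple(self):
  # Hypothetical companion test: a two-element tuple binding should yield
  # both tensor names.
  binding = pb.TensorFlow.Binding(
      tuple=pb.TensorFlow.NamedTupleBinding(element=[
          pb.TensorFlow.Binding(
              tensor=pb.TensorFlow.TensorBinding(tensor_name='a:0')),
          pb.TensorFlow.Binding(
              tensor=pb.TensorFlow.TensorBinding(tensor_name='b:0')),
      ]))
  result = tensorflow_utils.extract_tensor_names_from_binding(binding)
  self.assertEqual(str(sorted(result)), '[\'a:0\', \'b:0\']')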
Code example #2
def _get_wrapped_function_from_comp(comp, must_pin_function_to_cpu, param_type,
                                    device):
    """Extracts the TensorFlow function from serialized computation.

  Args:
    comp: An instance of `pb.Computation`.
    must_pin_function_to_cpu: A boolean flag to indicate if the computation is
      forced to be on CPUs.
    param_type: A `tff.Type` instance or None.
    device: A `tf.config.LogicalDevice` or None.

  Returns:
    A TensorFlow ConcreteFunction.
  """
    def function_to_wrap():
        """No-arg function to import graph def.

    We pass a no-arg function to `tf.compat.v1.wrap_function` to avoid
    the leftover placeholders that can result from binding arguments to the
    imported graphdef via `input_map`. The correct signature will be added to
    this function later, via the `prune` call below.

    Returns:
      Result of importing graphdef backing `comp`.
    """
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        # TODO(b/159180073): clean raise after fixing dataset reduce.
        _check_dataset_reduce_in_multi_gpu(graph_def)

        init_op = comp.tensorflow.initialize_op
        if init_op:
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, init_op)

        def _import_fn():
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def), name='')

        if must_pin_function_to_cpu:
            with tf.device('cpu'):
                return _import_fn()
        elif device is not None:
            with tf.device(device.name):
                return _import_fn()
        else:
            return _import_fn()

    wrapped_noarg_fn = tf.compat.v1.wrap_function(function_to_wrap,
                                                  signature=[])

    if param_type is not None:
        input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []
    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)
    import_graph = wrapped_noarg_fn.graph
    try:
        wrapped_fn = wrapped_noarg_fn.prune(
            feeds=tf.nest.map_structure(import_graph.as_graph_element,
                                        input_tensor_names),
            fetches=tf.nest.map_structure(import_graph.as_graph_element,
                                          output_tensor_names),
        )
    except KeyError as e:
        raise TypeError(
            'Caught exception trying to prune graph `{g}` with '
            'feeds {feeds} and fetches {fetches}. This indicates that these '
            'names may not refer to tensors in the graph.\nException: {e}'.
            format(g=import_graph,
                   feeds=input_tensor_names,
                   fetches=output_tensor_names,
                   e=e))
    return wrapped_fn
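
The wrap-then-prune pattern above is easier to see in isolation. Here is a minimal, self-contained sketch of the same pattern using only TensorFlow (the toy graph and tensor names are illustrative, not part of the code above):

import tensorflow as tf

# Build a trivial graph (y = x + 1.0) and serialize it, standing in for the
# graphdef that backs `comp` above.
with tf.Graph().as_default() as g:
    x = tf.compat.v1.placeholder(tf.float32, shape=[], name='x')
    y = tf.add(x, 1.0, name='y')
graph_def = g.as_graph_def()

# Import via a no-arg function to avoid leftover placeholders, then attach
# the real signature with `prune`, exactly as in the function above.
wrapped_noarg_fn = tf.compat.v1.wrap_function(
    lambda: tf.import_graph_def(graph_def, name=''), signature=[])
import_graph = wrapped_noarg_fn.graph
wrapped_fn = wrapped_noarg_fn.prune(
    feeds=import_graph.as_graph_element('x:0'),
    fetches=import_graph.as_graph_element('y:0'))

print(wrapped_fn(tf.constant(2.0)))  # => tf.Tensor(3.0, shape=(), dtype=float32)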
Code example #3
def deserialize_and_call_tf_computation(computation_proto, arg, graph):
  """Deserializes a TF computation and inserts it into `graph`.

  This method performs an action that can be considered roughly the opposite of
  what `tensorflow_serialization.serialize_py_fn_as_tf_computation` does. At
  the moment, it simply imports the graph in the current context. A future
  implementation may rely on different mechanisms. The caller should not be
  concerned with the specifics of the implementation. At this point, the method
  is expected to only be used within the body of another TF computation (within
  an instance of `tf_computation_context.TensorFlowComputationContext` at the
  top of the stack), and potentially also in certain types of interpreted
  execution contexts (TBD).

  Args:
    computation_proto: An instance of `pb.Computation` with the `computation`
      oneof equal to `tensorflow`, to be deserialized and called.
    arg: The argument to invoke the computation with, or None if the computation
      does not specify a parameter type and does not expect one.
    graph: The graph to stamp into.

  Returns:
    A tuple (init_op, result) where:
       init_op: String name of an op to initialize the graph.
       result: The results to be fetched from TensorFlow. Depending on
           the type of the result, this can be `tf.Tensor` or `tf.data.Dataset`
           instances, or a nested structure (such as an
           `anonymous_tuple.AnonymousTuple`).

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `computation_proto` is not a TensorFlow computation proto.
  """
  py_typecheck.check_type(computation_proto, pb.Computation)
  computation_oneof = computation_proto.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    raise ValueError(
        'Expected a TensorFlow computation, got {}.'.format(computation_oneof))
  py_typecheck.check_type(graph, tf.Graph)
  with graph.as_default():
    type_spec = type_serialization.deserialize_type(computation_proto.type)
    if type_spec.parameter is None:
      if arg is None:
        input_map = None
      else:
        raise TypeError(
            'The computation declared no parameters; encountered an unexpected '
            'argument {}.'.format(arg))
    elif arg is None:
      raise TypeError(
          'The computation declared a parameter of type {}, but the argument '
          'was not supplied.'.format(type_spec.parameter))
    else:
      arg_type, arg_binding = tensorflow_utils.capture_result_from_graph(
          arg, graph)
      if not type_utils.is_assignable_from(type_spec.parameter, arg_type):
        raise TypeError(
            'The computation declared a parameter of type {}, but the argument '
            'is of a mismatching type {}.'.format(type_spec.parameter,
                                                  arg_type))
      else:
        input_map = {
            k: graph.get_tensor_by_name(v) for k, v in six.iteritems(
                tensorflow_utils.compute_map_from_bindings(
                    computation_proto.tensorflow.parameter, arg_binding))
        }
    return_elements = tensorflow_utils.extract_tensor_names_from_binding(
        computation_proto.tensorflow.result)
    orig_init_op_name = computation_proto.tensorflow.initialize_op
    if orig_init_op_name:
      return_elements.append(orig_init_op_name)
    # N. B. Unlike MetaGraphDef, the GraphDef alone contains no information
    # about collections, and hence, when we import a graph with Variables,
    # those Variables are not added to global collections, so functions
    # like tf.compat.v1.global_variables_initializer() will not include
    # their initialization ops.
    output_tensors = tf.import_graph_def(
        serialization_utils.unpack_graph_def(
            computation_proto.tensorflow.graph_def),
        input_map,
        return_elements,
        # N. B. It is very important not to return any names from the original
        # computation_proto.tensorflow.graph_def, as those names might or might
        # not be valid in the current graph. Using a different scope makes the
        # graph somewhat more readable, since _N style de-duplication of graph
        # node names is less likely to be needed.
        name='subcomputation')

    output_map = {k: v for k, v in zip(return_elements, output_tensors)}
    new_init_op_name = output_map.pop(orig_init_op_name, None)
    return (new_init_op_name,
            tensorflow_utils.assemble_result_from_graph(
                type_spec.result, computation_proto.tensorflow.result,
                output_map))
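
For context, a hedged end-to-end sketch of driving this function, assuming `comp_proto` is a serialized identity computation of type (int32 -> int32) and that the function is importable from its defining module (referred to here as `tensorflow_deserialization`, an assumption):

import tensorflow as tf

with tf.Graph().as_default() as graph:
  arg = tf.constant(10)
  # Stamp the serialized computation into `graph`; get back the (possibly
  # empty) init op name and the fetchable result structure.
  init_op, result = tensorflow_deserialization.deserialize_and_call_tf_computation(
      comp_proto, arg, graph)
with tf.compat.v1.Session(graph=graph) as sess:
  if init_op:
    sess.run(init_op)
  print(sess.run(result))  # => 10 for an identity computation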
Code example #4
def import_tensorflow_computation(comp, name='fn'):
    """Creates a `computation_module.ComputationModule` from a TF computation.

  WARNING: This helper function is under construction, and most capabilities are
  not implemented at this stage:

  * The parameter and result of `comp` can only be a single tensor. Named
    tuples, sequences, or functional types are not currently supported.

  * Only tensorflow code can be imported.

  TODO(b/153499219): Add support for named tuples, sequences, and functions.

  Args:
    comp: An instance of a `pb.Computation` with TensorFlow code to import.
    name: An optional `str` name of the (single) function in the IREE module.

  Returns:
    An instance of `Module` with the imported function present.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation.
  """
    py_typecheck.check_type(comp, pb.Computation)
    type_spec = type_serialization.deserialize_type(comp.type)
    if not type_spec.is_function():
        type_spec = computation_types.FunctionType(None, type_spec)

    # TODO(b/153499219): Replace this with a recursive check of the signature
    # after relaxing the type restrictions and introducing nested structures.
    py_typecheck.check_type(type_spec.result, computation_types.TensorType)
    if type_spec.parameter is not None:
        py_typecheck.check_type(type_spec.parameter,
                                computation_types.TensorType)

    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)
    if type_spec.parameter is not None:
        input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    return_elements = input_tensor_names + output_tensor_names
    if init_op:
        graph_def = tensorflow_utils.add_control_deps_for_init_op(
            graph_def, init_op)
        return_elements.append(init_op)

    with tf.Graph().as_default() as graph:
        # TODO(b/153499219): See if we can reintroduce uniquify_shared_names().
        # Right now it causes loader breakage, and it is unclear whether it is
        # still necessary.
        import_results = tf.graph_util.import_graph_def(
            graph_def, input_map={}, return_elements=return_elements, name='')

    if init_op:
        initializer = import_results[-1]
        import_results.pop()
    else:
        initializer = None

    inputs = import_results[0:len(input_tensor_names)]
    outputs = import_results[len(input_tensor_names):]

    with graph.as_default():
        # TODO(b/153499219): Find a way to reflect the nested parameter and result
        # structure here after relaxing the restrictions.
        if inputs:
            assert len(inputs) < 2
            input_dict = {
                'parameter':
                tf.compat.v1.saved_model.utils.build_tensor_info(inputs[0])
            }
        else:
            input_dict = {}
        assert len(outputs) == 1
        output_dict = {
            'result':
            tf.compat.v1.saved_model.utils.build_tensor_info(outputs[0])
        }
        sig_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
            inputs=input_dict, outputs=output_dict, method_name=name)
        with tempfile.TemporaryDirectory() as model_dir:
            builder = tf.compat.v1.saved_model.Builder(model_dir)
            with tf.compat.v1.Session(graph=graph) as sess:
                builder.add_meta_graph_and_variables(
                    sess, ['unused'],
                    signature_def_map={name: sig_def},
                    legacy_init_op=initializer,
                    strip_default_attrs=True)
                builder.save()
            iree_module = iree.compiler.tf.compile_saved_model(
                model_dir,
                import_type='SIGNATURE_DEF',
                import_only=True,
                saved_model_tags=set(['unused']),
                exported_names=[name])
            return computation_module.ComputationModule(
                iree_module, name, type_spec)
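
The SavedModel plumbing in the second half of this function can be exercised on its own. A minimal sketch of the same SignatureDef/SavedModelBuilder pattern, using only TensorFlow (the toy graph, tag, and signature name are illustrative; the IREE import step is omitted since it needs the IREE toolchain):

import tempfile
import tensorflow as tf

with tf.Graph().as_default() as graph:
    x = tf.compat.v1.placeholder(tf.float32, shape=[], name='x')
    y = tf.add(x, 1.0, name='y')
    # Describe the single-tensor parameter/result, mirroring the function above.
    sig_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
        inputs={'parameter': tf.compat.v1.saved_model.utils.build_tensor_info(x)},
        outputs={'result': tf.compat.v1.saved_model.utils.build_tensor_info(y)},
        method_name='fn')
    with tempfile.TemporaryDirectory() as model_dir:
        builder = tf.compat.v1.saved_model.Builder(model_dir)
        with tf.compat.v1.Session(graph=graph) as sess:
            builder.add_meta_graph_and_variables(
                sess, ['unused'],
                signature_def_map={'fn': sig_def},
                strip_default_attrs=True)
            builder.save()
        # `model_dir` now holds a SavedModel that a SIGNATURE_DEF importer,
        # such as the IREE compiler call above, could consume.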
Code example #5
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional device name.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation.
  """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is not None:
        if not type_utils.are_equivalent_types(type_spec, comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    type_spec, comp_type))
    else:
        type_spec = comp_type
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    if isinstance(type_spec, computation_types.FunctionType):
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    if param_type is not None:
        input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)

    def function_to_wrap(*args):  # pylint: disable=missing-docstring
        if len(args) != len(input_tensor_names):
            raise RuntimeError('Expected {} arguments, found {}.'.format(
                len(input_tensor_names), len(args)))
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        if init_op:
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, init_op)

        def _import_fn():
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def),
                input_map=dict(list(zip(input_tensor_names, args))),
                return_elements=output_tensor_names)

        if device is not None:
            with tf.device(device):
                return _import_fn()
        else:
            return _import_fn()

    signature = []
    param_fns = []
    if param_type is not None:
        for spec in anonymous_tuple.flatten(type_spec.parameter):
            if isinstance(spec, computation_types.TensorType):
                signature.append(tf.TensorSpec(spec.shape, spec.dtype))
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                signature.append(tf.TensorSpec([], tf.variant))
                param_fns.append(tf.data.experimental.to_variant)

    wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

    result_fns = []
    for spec in anonymous_tuple.flatten(result_type):
        if isinstance(spec, computation_types.TensorType):
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            structure = type_utils.type_to_tf_structure(spec.element)

            def fn(x, structure=structure):
                return tf.data.experimental.from_variant(x, structure)

            result_fns.append(fn)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)

        # There is a tf.wrap_function(...) issue, b/144127474, where variables
        # created by tf.import_graph_def(...) inside tf.wrap_function(...) are
        # not destroyed. So get all the variables from `wrapped_fn` and destroy
        # them manually.
        # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
        # is fixed.
        resources = []
        for op in wrapped_fn.graph.get_operations():
            if op.type == 'VarHandleOp':
                resources += op.outputs
        if resources:
            for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
                tf.raw_ops.DestroyResourceOp(resource=resource)

        result_elements = []
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)

    if device is not None:
        old_fn_to_return = fn_to_return

        # pylint: disable=function-redefined
        def fn_to_return(x):
            with tf.device(device):
                return old_fn_to_return(x)

        # pylint: enable=function-redefined

    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
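
A hedged usage sketch for the function above. It assumes `comp_proto` is the `pb.Computation` behind a `tff.tf_computation`; the `computation_impl.ComputationImpl.get_proto` accessor is an assumption about TFF internals of this era:

import tensorflow as tf
import tensorflow_federated as tff

@tff.tf_computation(tf.int32)
def add_one(x):
    return x + 1

# Assumed internal accessor; the exact module path has moved between TFF
# versions.
comp_proto = computation_impl.ComputationImpl.get_proto(add_one)

fn = embed_tensorflow_computation(comp_proto)
print(fn(tf.constant(10)))  # => tf.Tensor(11, shape=(), dtype=int32)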
Code example #6
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional `tf.config.LogicalDevice`.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation.
  """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is not None:
        if not type_spec.is_equivalent_to(comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    type_spec, comp_type))
    else:
        type_spec = comp_type
    # TODO(b/155198591): Currently, TF will raise on any function returning a
    # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
    # this gating when we can.
    must_pin_function_to_cpu = type_analysis.contains_types(
        type_spec.result, computation_types.SequenceType)
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    if isinstance(type_spec, computation_types.FunctionType):
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    if param_type is not None:
        input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)

    def function_to_wrap():
        """No-arg function to import graph def.

    We pass a no-arg function to `tf.compat.v1.wrap_function` to avoid
    the leftover placeholders that can result from binding arguments to the
    imported graphdef via `input_map`. The correct signature will be added to
    this function later, via the `prune` call below.

    Returns:
      Result of importing graphdef backing `comp`.
    """
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        if init_op:
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, init_op)

        def _import_fn():
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def), name='')

        if must_pin_function_to_cpu:
            with tf.device('cpu'):
                return _import_fn()
        elif device is not None:
            with tf.device(device.name):
                return _import_fn()
        else:
            return _import_fn()

    param_fns = []
    if param_type is not None:
        for spec in anonymous_tuple.flatten(type_spec.parameter):
            if isinstance(spec, computation_types.TensorType):
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                param_fns.append(tf.data.experimental.to_variant)

    wrapped_noarg_fn = tf.compat.v1.wrap_function(function_to_wrap,
                                                  signature=[])
    import_graph = wrapped_noarg_fn.graph
    try:
        wrapped_fn = wrapped_noarg_fn.prune(
            feeds=tf.nest.map_structure(import_graph.as_graph_element,
                                        input_tensor_names),
            fetches=tf.nest.map_structure(import_graph.as_graph_element,
                                          output_tensor_names),
        )
    except KeyError as e:
        raise TypeError(
            'Caught exception trying to prune graph `{g}` with '
            'feeds {feeds} and fetches {fetches}. This indicates that these '
            'names may not refer to tensors in the graph.\nException: {e}'.
            format(g=import_graph,
                   feeds=input_tensor_names,
                   fetches=output_tensor_names,
                   e=e))

    result_fns = []
    for spec in anonymous_tuple.flatten(result_type):
        if isinstance(spec, computation_types.TensorType):
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            structure = type_conversions.type_to_tf_structure(spec.element)

            def fn(x, structure=structure):
                return tf.data.experimental.from_variant(x, structure)

            result_fns.append(fn)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)

        # There is a tf.wrap_function(...) issue, b/144127474, where variables
        # created by tf.import_graph_def(...) inside tf.wrap_function(...) are
        # not destroyed. So get all the variables from `wrapped_fn` and destroy
        # them manually.
        # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
        # is fixed.
        resources = []
        for op in wrapped_fn.graph.get_operations():
            if op.type == 'VarHandleOp':
                resources += op.outputs
        if resources:
            for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
                tf.raw_ops.DestroyResourceOp(resource=resource)

        result_elements = []
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)

    # pylint: disable=function-redefined
    if must_pin_function_to_cpu:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device('cpu'):
                return old_fn_to_return(x)
    elif device is not None:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device(device.name):
                return old_fn_to_return(x)

    # pylint: enable=function-redefined

    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
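
Note the difference from the previous example: here `device` must be a `tf.config.LogicalDevice` (the code pins via `device.name`) rather than a plain device name string. A hedged call sketch, reusing `comp_proto` from the previous sketch:

# Pick a logical GPU if one exists; otherwise run unpinned.
gpus = tf.config.list_logical_devices('GPU')
fn = embed_tensorflow_computation(comp_proto, device=gpus[0] if gpus else None)
print(fn(tf.constant(10)))  # => tf.Tensor(11, shape=(), dtype=int32)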