Beispiel #1
0
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
    """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Two implementations live in this function. If the default graph has a
  truthy `_c_graph`, the import is delegated to
  `c_api.TF_GraphImportGraphDefWithResults`; otherwise the nodes are
  re-created one by one in Python with `Graph.create_op`.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. The passed value is ignored;
      it is overwritten with the registered ops on entry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.
      Note that `graph_def` is modified in place in that case.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` if
    `return_elements` (after validation) is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
    # The caller-supplied `op_dict` is discarded: the registered op
    # definitions are always used.
    op_dict = op_def_registry.get_registered_ops()

    graph_def = _ProcessGraphDefParam(graph_def, op_dict)
    input_map = _ProcessInputMapParam(input_map)
    return_elements = _ProcessReturnElementsParam(return_elements)

    if producer_op_list is not None:
        # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
        _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

    graph = ops.get_default_graph()

    # C API path: the graph is backed by a C graph object, so the import is
    # performed by TF_GraphImportGraphDefWithResults.
    if graph._c_graph:  # pylint: disable=protected-access
        with ops.name_scope(name, 'import', input_map.values()) as scope:
            # Save unique prefix generated by name_scope
            if scope:
                assert scope.endswith('/')
                prefix = scope[:-1]
            else:
                prefix = ''

            # Generate any input map tensors inside name scope
            input_map = _ConvertInputMapValues(name, input_map)

        # `scoped_options` must stay referenced for as long as `options` is
        # used below.
        # NOTE(review): presumably the wrapper frees the options when it is
        # garbage collected -- confirm against c_api_util.
        scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
        options = scoped_options.options
        _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                         return_elements)

        # _ProcessNewOps mutates the new operations. _lock ensures a Session.run
        # call cannot occur between creating the TF_Operations in the
        # TF_GraphImportGraphDefWithResults call and mutating them in
        # _ProcessNewOps.
        with graph._lock:  # pylint: disable=protected-access
            with c_api_util.tf_buffer(
                    graph_def.SerializeToString()) as serialized:
                try:
                    with errors.raise_exception_on_not_ok_status() as status:
                        results = c_api.TF_GraphImportGraphDefWithResults(
                            graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
                except errors.InvalidArgumentError as e:
                    # Convert to ValueError for backwards compatibility.
                    raise ValueError(str(e))

            _ProcessNewOps(graph)

        # Create _DefinedFunctions for any imported functions.
        #
        # We do this by creating _DefinedFunctions directly from `graph_def`, and
        # adding them to `graph`. Adding an existing function to a TF_Graph is a
        # no-op, so this only has the effect of updating the Python state (usually
        # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
        #
        # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
        # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
        if graph_def.library and graph_def.library.function:
            # pylint: disable=protected-access
            functions = function._from_library(graph_def.library)
            for f in functions:
                f.add_to_graph(graph)
            # pylint: enable=protected-access

        # Treat input mappings that don't appear in the graph as an error, because
        # they are likely to be due to a typo.
        missing_unused_input_keys = (
            c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
                results))
        if missing_unused_input_keys:
            missing_unused_input_keys = [
                compat.as_str(s) for s in missing_unused_input_keys
            ]
            raise ValueError(
                'Attempted to map inputs that were not found in graph_def: [%s]'
                % ', '.join(missing_unused_input_keys))

        if return_elements is None:
            return None
        else:
            return _GatherReturnElements(return_elements, graph, results)

    # Python path: no C graph is available, so every node is re-created with
    # `g.create_op` and wired up manually.
    else:
        g = graph

        # Use a canonical representation for all tensor names.
        input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
        # Keys of `input_map` that were actually consumed by some node input.
        used_input_keys = set()
        # Maps the original (unprefixed) node name to the created Operation.
        name_to_op = {}

        # Add any functions defined in `graph_def` to `g`
        if graph_def.library and graph_def.library.function:
            # Copy op_dict so we don't clobber the original
            op_dict = copy.copy(op_dict)
            # pylint: disable=protected-access
            # Note that we do not prepend `name` to the function name. The reasoning
            # is that function names are similar to op definition names, which
            # currently do not have a scoped name or namespace scheme.
            functions = function._from_library(graph_def.library)
            for f in functions:
                f.add_to_graph(g)
                op_dict[f.name] = f.definition.signature
            # pylint: enable=protected-access

        # LINT.IfChange
        with ops.name_scope(name, 'import', input_map.values()) as scope:
            # TODO(ashankar): Should this just copy over or should it do some
            # more nuanced merging? For example, the graph may already have some
            # marked "bad versions" and we don't want to lose those because of
            # what's in graph_def.versions? The C++ ImportGraphDef does something
            # more nuanced.
            g.graph_def_versions.CopyFrom(graph_def.versions)

            input_map = _ConvertInputMapValues(name, input_map)

            # NOTE(mrry): We do this in two passes, because there may be a cycle in
            # `graph_def`.

            # 1. Add operations without their inputs.
            for node in graph_def.node:
                # Check to see if this op's name matches a previously seen op
                if node.name in name_to_op:
                    raise ValueError('Duplicate name \'%s\' in GraphDef.' %
                                     node.name)
                if node.op not in op_dict:
                    raise ValueError('No op named %s in defined operations.' %
                                     node.op)
                op_def = op_dict[node.op]

                output_types = _OutputTypes(node, op_dict)
                # Inputs are passed as an empty list here; they are attached
                # in pass 2 once every op exists.
                name_to_op[node.name] = g.create_op(node.op, [],
                                                    output_types,
                                                    name=node.name,
                                                    attrs=node.attr,
                                                    compute_shapes=False,
                                                    compute_device=False,
                                                    op_def=op_def)

            # Maps from a node to the ops it is colocated with, if colocation
            # is specified in the attributes.
            colocation_pairs = collections.defaultdict(list)

            # 2. Add inputs to the operations.
            for node in graph_def.node:
                op = name_to_op[node.name]
                input_types = _InputTypes(node, op_dict)
                apply_device_function = True

                # Rewrite the colocation attributes in the graph, since the
                # names of new ops may have changed.
                for key, value in op.node_def.attr.items():
                    if key == '_class':
                        class_values = value.list
                        new_class_values = []
                        for class_value in class_values.s:
                            if class_value.startswith(b'loc:@'):
                                # Strip the 5-byte b'loc:@' prefix to get the
                                # original name of the colocation target.
                                op_to_bind_to = class_value[5:].decode()
                                # Find the op by its original name.
                                if op_to_bind_to not in name_to_op:
                                    raise ValueError(
                                        'Specified colocation to an op that '
                                        'does not exist during import: %s in %s'
                                        % (op_to_bind_to, node.name))
                                original_op = name_to_op[op_to_bind_to]
                                new_class_values.append(
                                    compat.as_bytes('loc:@' +
                                                    original_op.name))
                                if op_to_bind_to != node.name:
                                    # Keep track of this mapping for a later phase.
                                    colocation_pairs[op].append(original_op)
                                    # Don't apply this op's device function,
                                    # the colocation constraint will ensure
                                    # the proper device gets assigned at runtime.
                                    apply_device_function = False

                            else:
                                new_class_values.append(class_value)
                        value.list.CopyFrom(
                            attr_value_pb2.AttrValue.ListValue(
                                s=new_class_values))

                # NOTE(mrry): We cannot use zip here because control inputs do not
                # appear in the list of input_types.
                for i, input_name in enumerate(
                    [_CanonicalInputName(x) for x in node.input]):

                    if _IsControlInput(input_name):
                        # (a) Input is a control input that should be taken from an op
                        #     in "graph_def".
                        try:
                            # The first character of the name is dropped
                            # before the lookup.
                            source_op = name_to_op[input_name[1:]]
                        except KeyError:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Control input %r not found in graph_def.'
                                    % (input_name, )))
                        # pylint: disable=protected-access
                        op._add_control_input(source_op)
                        # pylint: enable=protected-access

                    else:
                        try:
                            input_type = input_types[i]
                        except IndexError:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'More inputs specified (%r) than the op expects.'
                                    % (input_name, )))

                        if input_name in input_map:
                            # (b) Input should be replaced by a tensor from the caller.
                            source_tensor = input_map[input_name]
                            used_input_keys.add(input_name)

                        else:
                            # (c) Input should be taken from an op in `graph_def`.
                            operation_name, output_index = _ParseTensorName(
                                input_name)
                            try:
                                source_op = name_to_op[operation_name]
                                source_tensor = list(
                                    source_op.values())[output_index]
                            except (KeyError, IndexError):
                                raise ValueError(
                                    _InvalidNodeMessage(
                                        node,
                                        'Input tensor %r not found in graph_def.'
                                        % (input_name, )))

                        try:
                            # pylint: disable=protected-access
                            op._add_input(source_tensor, dtype=input_type)
                            # pylint: enable=protected-access
                        except TypeError as te:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Input tensor %r %s' % (input_name, te)))

                # Sanity check: the types of the inputs actually attached must
                # match the types derived from the op definition.
                # pylint: disable=protected-access
                if op._input_types != input_types:
                    raise ValueError(
                        _InvalidNodeMessage(
                            node,
                            'Input types mismatch (expected %r but got %r)' %
                            (', '.join(
                                dtypes.as_dtype(x).name
                                for x in input_types), ', '.join(
                                    x.name for x in op._input_types))))
                # pylint: enable=protected-access

                if not g._is_function(op.type):  # pylint: disable=protected-access
                    # Execute shape inference for this op.
                    # NOTE(mrry): If the graph contains a cycle, the full shape
                    # information may not be available for this op's inputs.
                    ops.set_shapes_for_outputs(op)
                # For nodes with _output_shapes set, set the output shapes.
                if '_output_shapes' in op.node_def.attr:
                    for i, output in enumerate(op.outputs):
                        dims = op.node_def.attr['_output_shapes'].list.shape[i]
                        # Negative dimension sizes are converted to None;
                        # an unknown rank yields a fully unknown shape.
                        output_shape = tensor_shape.TensorShape(
                            None if dims.unknown_rank else [
                                dim.size if dim.size >= 0 else None
                                for dim in dims.dim
                            ])

                        try:
                            output.set_shape(output_shape)
                        except ValueError as e:
                            # If the output shape is incompatible with what is inferred
                            # by the graph for a very specific whitelist of ops, then we
                            # ignore this output shape.  This can happen if there is a
                            # bug in the shape function for some operation, and the
                            # serialized graph def has the incorrect shape set when
                            # running on a newer binary with the fixed shape function.
                            # This is an escape hatch that allows us to correct shape
                            # functions that are not critical to correct execution but
                            # would cause graphs to fail if imported after correcting.
                            #
                            # This can be removed after 2017/03/08.
                            if op.type in [
                                    'RandomShuffleQueue', 'PaddingFIFOQueue',
                                    'FIFOQueue', 'PriorityQueue', 'QueueSize',
                                    'Stack', 'Barrier', 'BarrierReadySize',
                                    'BarrierIncompleteSize', 'HashTable',
                                    'MutableHashTable',
                                    'MutableHashTableOfTensors', 'Mutex',
                                    'CuckooTable', 'IndexTable',
                                    'WholeFileReader', 'TextLineReader',
                                    'FixedLengthRecordReader',
                                    'TFRecordReader', 'IdentityReader',
                                    'LMDBReader', 'RefSwitch', 'RefEnter',
                                    'RefNextIteration', 'RefMerge',
                                    'RefIdentity'
                            ]:
                                pass
                            elif op.type in [
                                    'ConditionalAccumulator',
                                    'SparseConditionalAccumulator', 'Table'
                            ]:
                                # This can be removed after 2017/04/24.
                                pass
                            else:
                                raise e

                    # The attr is removed from the imported op once it has
                    # been applied.
                    del op.node_def.attr['_output_shapes']

                # NOTE(mrry): We do this after configuring the inputs, because
                # the result of the device functions may depend on the inputs.
                if apply_device_function:
                    with _MaybeDevice(node.device):
                        g._apply_device_functions(op)  # pylint: disable=protected-access

            # The following loop populates the device field of ops that are
            # colocated with another op.  This is implied by the colocation
            # attribute, but we propagate the device field for completeness.
            for op, coloc_op_list in colocation_pairs.items():
                coloc_device = None
                # Find any device in the list of colocated ops that have a
                # device, if it exists.  We assume that if multiple ops
                # have devices, they refer to the same device.  Otherwise, a
                # runtime error will occur since the colocation property
                # cannot be guaranteed.
                #
                # One possible improvement is to try to check for compatibility
                # of all devices in this list at import time here, which would
                # require implementing a compatibility function for device specs
                # in python.
                for coloc_op in coloc_op_list:
                    if coloc_op.device:
                        coloc_device = pydev.DeviceSpec.from_string(
                            coloc_op.device)
                        break
                if coloc_device:
                    op._set_device(coloc_device)  # pylint: disable=protected-access

            # Treat input mappings that don't appear in the graph as an error,
            # because they are likely to be due to a typo.
            def _IsImportedNodeOutput(tensor_name):
                # True if `tensor_name` names an existing output of an imported
                # op; such keys are tolerated even when no node consumed them.
                operation_name, output_index = _ParseTensorName(tensor_name)
                try:
                    return output_index < len(
                        name_to_op[operation_name].outputs)
                except KeyError:
                    return False

            absent_input_keys = [
                k for k in frozenset(input_map.keys()).difference(
                    used_input_keys) if not _IsImportedNodeOutput(k)
            ]
            if absent_input_keys:
                raise ValueError(
                    'Attempted to map inputs that were not found in graph_def: [%s]'
                    % ', '.join(absent_input_keys))

            if return_elements is None:
                return None
            else:
                ret = []
                # NOTE: the loop variable below rebinds the `name` parameter;
                # the parameter is not read again after this point.
                for name in return_elements:
                    name = compat.as_str(name)
                    # Names containing ':' are looked up as tensors, all
                    # other names as operations.
                    if ':' in name:
                        try:
                            operation_name, output_index = _ParseTensorName(
                                name)
                            ret.append(name_to_op[operation_name].
                                       outputs[output_index])
                        except (ValueError, KeyError, IndexError):
                            raise ValueError(
                                'Requested return_element %r not found in graph_def.'
                                % name)
                    else:
                        try:
                            ret.append(name_to_op[name])
                        except KeyError:
                            raise ValueError(
                                'Requested return_element %r not found in graph_def.'
                                % name)
                return ret
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
    """Imports the graph from `graph_def` into the current default `Graph`.

    The serialized
    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
    is handed to `c_api.TF_GraphImportGraphDefWithResults`, which adds its
    nodes to the current default `Graph`. Individual objects of the
    `GraphDef` can be extracted as `tf.Tensor` and `tf.Operation` objects
    through `return_elements`. See `tf.Graph.as_graph_def` for a way to
    create a `GraphDef` proto.

    Args:
      graph_def: A `GraphDef` proto containing operations to be imported into
        the default graph.
      input_map: A dictionary mapping input names (as strings) in `graph_def`
        to `Tensor` objects. The values of the named input tensors in the
        imported graph will be re-mapped to the respective `Tensor` values.
      return_elements: A list of strings containing operation names in
        `graph_def` that will be returned as `Operation` objects; and/or
        tensor names in `graph_def` that will be returned as `Tensor`
        objects.
      name: (Optional.) A prefix that will be prepended to the names in
        `graph_def`. Note that this does not apply to imported function
        names. Defaults to `"import"`.
      op_dict: (Optional.) Deprecated, do not use. The passed value is
        ignored; the registered ops are always used.
      producer_op_list: (Optional.) An `OpList` proto with the (possibly
        stripped) list of `OpDef`s used by the producer of the graph. If
        provided, unrecognized attrs for ops in `graph_def` that have their
        default value according to `producer_op_list` will be removed. This
        will allow some more `GraphDef`s produced by later binaries to be
        accepted by earlier binaries.

    Returns:
      A list of `Operation` and/or `Tensor` objects from the imported graph,
      corresponding to the names in `return_elements`, or `None` if
      `return_elements` (after validation) is `None`.

    Raises:
      TypeError: If `graph_def` is not a `GraphDef` proto,
        `input_map` is not a dictionary mapping strings to `Tensor` objects,
        or `return_elements` is not a list of strings.
      ValueError: If `input_map`, or `return_elements` contains names that
        do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
        it refers to an unknown tensor).
    """
    # The deprecated `op_dict` argument is replaced unconditionally.
    op_dict = op_def_registry.get_registered_ops()

    # Validate / normalize the user-supplied arguments.
    graph_def = _ProcessGraphDefParam(graph_def, op_dict)
    input_map = _ProcessInputMapParam(input_map)
    return_elements = _ProcessReturnElementsParam(return_elements)

    if producer_op_list is not None:
        # TODO(skyewm): make a copy of graph_def so we're not mutating the
        # argument?
        _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

    graph = ops.get_default_graph()
    with ops.name_scope(name, 'import', input_map.values()) as scope:
        # Remember the unique prefix produced by name_scope, without its
        # trailing slash. An empty scope means no prefix at all.
        prefix = ''
        if scope:
            assert scope.endswith('/')
            prefix = scope[:-1]

        # Input map tensors have to be created while the name scope is active.
        input_map = _ConvertInputMapValues(name, input_map)

    # Keep the scoped wrapper referenced while its raw options are in use.
    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    import_options = scoped_options.options
    _PopulateTFImportGraphDefOptions(import_options, prefix, input_map,
                                     return_elements)

    # _ProcessNewOps mutates the new operations. Holding _mutation_lock makes
    # sure no Session.run call can happen between the creation of the
    # TF_Operations by TF_GraphImportGraphDefWithResults and their mutation
    # in _ProcessNewOps.
    with graph._mutation_lock():  # pylint: disable=protected-access
        with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
            try:
                scoped_results = c_api_util.ScopedTFImportGraphDefResults(
                    c_api.TF_GraphImportGraphDefWithResults(
                        graph._c_graph, serialized, import_options))  # pylint: disable=protected-access
            except errors.InvalidArgumentError as e:
                # Convert to ValueError for backwards compatibility.
                raise ValueError(str(e))

        # Create _DefinedFunctions for any imported functions, straight from
        # `graph_def`, and add them to `graph`. Adding an already existing
        # function to a TF_Graph is a no-op, so this merely updates the
        # Python-side state (usually _DefinedFunction.add_to_graph also adds
        # the function to the TF_Graph).
        #
        # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
        # TODO(skyewm): avoid sending serialized FunctionDefs back to the
        # TF_Graph
        # TODO(b/74620627): move this after _ProcessNewOps outside the lock
        # once _USE_C_SHAPES is removed.
        library = graph_def.library
        if library and library.function:
            # pylint: disable=protected-access
            for imported_fn in function._from_library(library):
                imported_fn.add_to_graph(graph)
            # pylint: enable=protected-access

        _ProcessNewOps(graph)

    # An input mapping that matches nothing in the graph is most likely a
    # typo, so it is reported as an error.
    unmatched_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            scoped_results.results))
    if unmatched_keys:
        readable_keys = [compat.as_str(key) for key in unmatched_keys]
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]' %
            ', '.join(readable_keys))

    if return_elements is None:
        return None
    return _GatherReturnElements(return_elements, graph,
                                 scoped_results.results)