Exemplo n.º 1
0
 def __init__(self, object_graph_proto, saved_model_proto, export_dir):
     """Initializes loader state from a SavedModel proto and loads its objects.

     Args:
       object_graph_proto: Object graph proto, kept as `self._proto`.
       saved_model_proto: SavedModel proto. Only its first entry in
         `meta_graphs` is read; it supplies `asset_file_def` and the function
         library found under `graph_def.library`.
       export_dir: Export directory, kept as `self._export_dir`.

     Every function in that library is added with `add_to_graph(None)` and
     indexed by its `name` in `self._defined_functions` (a later function with
     the same name replaces an earlier one). `self._load_all()` runs last.
     """
     first_meta_graph = saved_model_proto.meta_graphs[0]
     self._asset_file_def = first_meta_graph.asset_file_def
     self._proto = object_graph_proto
     self._export_dir = export_dir
     self._defined_functions = {}
     function_library = first_meta_graph.graph_def.library
     for loaded_fn in function_lib.from_library(function_library):
         # TODO(allenl): Do we need to do name mapping here? Not quite sure what
         # happens when loaded names collide with existing names.
         # NOTE(review): `None` presumably means "the current default graph";
         # confirm against `add_to_graph`.
         loaded_fn.add_to_graph(None)
         self._defined_functions[loaded_fn.name] = loaded_fn
     self._load_all()
Exemplo n.º 2
0
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph. When `producer_op_list` is provided, this proto may
      be modified in place.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. Any value passed here is
      ignored; the ops registered in `op_def_registry` are used instead.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`,
    and None if `return_elements` is None.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor). An `InvalidArgumentError` raised by the
      underlying C import is re-raised as `ValueError` for backwards
      compatibility.
  """
  # The deprecated `op_dict` argument is deliberately overwritten: the
  # registered ops are always used.
  op_dict = op_def_registry.get_registered_ops()

  # Normalize the user-supplied arguments before touching the graph.
  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Save unique prefix generated by name_scope
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]
    else:
      prefix = ''

    # Generate any input map tensors inside name scope
    input_map = _ConvertInputMapValues(name, input_map)

  # `scoped_options` is kept as a named local instead of being inlined.
  # Presumably it owns the C options object that `options` refers to, so it
  # must stay alive until the import below completes.
  # NOTE(review): confirm against c_api_util before restructuring.
  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
  options = scoped_options.options
  _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                   return_elements)

  # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
  # Session.run call cannot occur between creating the TF_Operations in the
  # TF_GraphImportGraphDefWithResults call and mutating them in
  # _ProcessNewOps.
  with graph._mutation_lock():  # pylint: disable=protected-access
    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        results = c_api.TF_GraphImportGraphDefWithResults(
            graph._c_graph, serialized, options)  # pylint: disable=protected-access
        results = c_api_util.ScopedTFImportGraphDefResults(results)
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    # Must run while the mutation lock is still held (see the comment above
    # the `with` statement).
    _ProcessNewOps(graph)

  # Create _DefinedFunctions for any imported functions.
  #
  # We do this by creating _DefinedFunctions directly from `graph_def`, and
  # adding them to `graph`. Adding an existing function to a TF_Graph is a
  # no-op, so this only has the effect of updating the Python state (usually
  # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
  #
  # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
  # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
  if graph_def.library and graph_def.library.function:
    functions = function.from_library(graph_def.library)
    for f in functions:
      f.add_to_graph(graph)

  # Treat input mappings that don't appear in the graph as an error, because
  # they are likely to be due to a typo.
  missing_unused_input_keys = (
      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
          results.results))
  if missing_unused_input_keys:
    missing_unused_input_keys = [
        compat.as_str(s) for s in missing_unused_input_keys
    ]
    raise ValueError(
        'Attempted to map inputs that were not found in graph_def: [%s]' %
        ', '.join(missing_unused_input_keys))

  # Nothing to gather when no return elements were requested.
  if return_elements is None:
    return None
  else:
    return _GatherReturnElements(return_elements, graph, results.results)
Exemplo n.º 3
0
def _import_graph_def_internal(  # pylint: disable=invalid-name
        graph_def,
        input_map=None,
        return_elements=None,
        validate_colocation_constraints=True,
        name=None,
        op_dict=None,
        producer_op_list=None):
    """Imports the graph from `graph_def` into the current default `Graph`.

    This function provides a way to import a serialized TensorFlow
    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
    protocol buffer, and extract individual objects in the `GraphDef` as
    `tf.Tensor` and `tf.Operation` objects. Once extracted,
    these objects are placed into the current default `Graph`. See
    `tf.Graph.as_graph_def` for a way to create a `GraphDef`
    proto.

    Args:
      graph_def: A `GraphDef` proto containing operations to be imported into the
        default graph. When `producer_op_list` is provided, this proto may be
        modified in place.
      input_map: A dictionary mapping input names (as strings) in `graph_def` to
        `Tensor` objects. The values of the named input tensors in the imported
        graph will be re-mapped to the respective `Tensor` values.
      return_elements: A list of strings containing operation names in `graph_def`
        that will be returned as `Operation` objects; and/or tensor names in
        `graph_def` that will be returned as `Tensor` objects.
      validate_colocation_constraints: Whether to validate colocation constraints.
        Forwarded to `_PopulateTFImportGraphDefOptions`.
      name: (Optional.) A prefix that will be prepended to the names in
        `graph_def`. Note that this does not apply to imported function names.
        Defaults to `"import"`.
      op_dict: (Optional.) Deprecated, do not use. Any value passed here is
        ignored; the ops registered in `op_def_registry` are used instead.
      producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
        list of `OpDef`s used by the producer of the graph. If provided,
        unrecognized attrs for ops in `graph_def` that have their default value
        according to `producer_op_list` will be removed. This will allow some more
        `GraphDef`s produced by later binaries to be accepted by earlier binaries.

    Returns:
      A list of `Operation` and/or `Tensor` objects from the imported graph,
      corresponding to the names in `return_elements`,
      and None if `return_elements` is None.

    Raises:
      TypeError: If `graph_def` is not a `GraphDef` proto,
        `input_map` is not a dictionary mapping strings to `Tensor` objects,
        or `return_elements` is not a list of strings.
      ValueError: If `input_map`, or `return_elements` contains names that
        do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
        it refers to an unknown tensor). An `InvalidArgumentError` raised by the
        underlying C import is re-raised as `ValueError` for backwards
        compatibility.
    """
    # The deprecated `op_dict` argument is deliberately overwritten: the
    # registered ops are always used.
    op_dict = op_def_registry.get_registered_ops()

    # Normalize the user-supplied arguments before touching the graph.
    graph_def = _ProcessGraphDefParam(graph_def, op_dict)
    input_map = _ProcessInputMapParam(input_map)
    return_elements = _ProcessReturnElementsParam(return_elements)

    if producer_op_list is not None:
        # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
        _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

    graph = ops.get_default_graph()
    with ops.name_scope(name, 'import', input_map.values()) as scope:
        # Save unique prefix generated by name_scope
        if scope:
            assert scope.endswith('/')
            prefix = scope[:-1]
        else:
            prefix = ''

        # Generate any input map tensors inside name scope
        input_map = _ConvertInputMapValues(name, input_map)

    # `scoped_options` is kept as a named local instead of being inlined.
    # Presumably it owns the C options object that `options` refers to, so it
    # must stay alive until the import below completes.
    # NOTE(review): confirm against c_api_util before restructuring.
    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements,
                                     validate_colocation_constraints)

    # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
    # Session.run call cannot occur between creating the TF_Operations in the
    # TF_GraphImportGraphDefWithResults call and mutating them in
    # _ProcessNewOps.
    with graph._mutation_lock():  # pylint: disable=protected-access
        with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
            try:
                results = c_api.TF_GraphImportGraphDefWithResults(
                    graph._c_graph, serialized, options)  # pylint: disable=protected-access
                results = c_api_util.ScopedTFImportGraphDefResults(results)
            except errors.InvalidArgumentError as e:
                # Convert to ValueError for backwards compatibility.
                raise ValueError(str(e))

        # Must run while the mutation lock is still held (see the comment
        # above the `with` statement).
        _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
        functions = function.from_library(graph_def.library)
        for f in functions:
            f.add_to_graph(graph)

    # Treat input mappings that don't appear in the graph as an error, because
    # they are likely to be due to a typo.
    missing_unused_input_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            results.results))
    if missing_unused_input_keys:
        missing_unused_input_keys = [
            compat.as_str(s) for s in missing_unused_input_keys
        ]
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]' %
            ', '.join(missing_unused_input_keys))

    # Nothing to gather when no return elements were requested.
    if return_elements is None:
        return None
    else:
        return _GatherReturnElements(return_elements, graph, results.results)