def to_proto(self, export_scope=None):
    """Serializes this `QueueRunner` into a `QueueRunnerDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to strip from the queue and
        op names written into the proto.

    Returns:
      A `QueueRunnerDef` protocol buffer, or `None` when `export_scope` is
      given and the name of this runner's queue does not start with it.
    """
    # Guard clause: a runner whose queue lives outside the scope is not
    # exported at all.
    if (export_scope is not None and
        not self.queue.name.startswith(export_scope)):
      return None

    def _scoped(name):
      # Every name stored in the proto has the export scope removed.
      return ops.strip_name_scope(name, export_scope)

    proto = queue_runner_pb2.QueueRunnerDef()
    proto.queue_name = _scoped(self.queue.name)
    proto.enqueue_op_name.extend(
        [_scoped(op.name) for op in self.enqueue_ops])
    proto.close_op_name = _scoped(self.close_op.name)
    proto.cancel_op_name = _scoped(self.cancel_op.name)
    # Exception classes are stored as their error codes.
    proto.queue_closed_exception_types.extend(
        [errors.error_code_from_exception_type(exc_type)
         for exc_type in self._queue_closed_exception_types])
    return proto
  def to_proto(self, export_scope=None):
    """Serializes this `ResourceVariable` into a `VariableDef` proto.

    Args:
      export_scope: Optional `string`. Name scope to strip from the names
        written into the proto.

    Returns:
      A `VariableDef` protocol buffer, or `None` when `export_scope` is given
      and the variable's handle name does not start with it.
    """
    out_of_scope = (export_scope is not None and
                    not self.handle.name.startswith(export_scope))
    if out_of_scope:
      return None

    proto = variable_pb2.VariableDef()
    proto.variable_name = ops.strip_name_scope(self.handle.name, export_scope)
    proto.initializer_name = ops.strip_name_scope(
        self.initializer.name, export_scope)
    # The snapshot name is only recorded when a cached value exists.
    cached = self._cached_value
    if cached is not None:
      proto.snapshot_name = ops.strip_name_scope(cached.name, export_scope)
    proto.is_resource = True
    slice_info = self._save_slice_info
    if slice_info:
      proto.save_slice_info_def.MergeFrom(
          slice_info.to_proto(export_scope=export_scope))
    return proto
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Looks up the collection named `key` in `graph` and serializes its items into
  `meta_graph_def.collection_def[key]`. Nothing is written when `key` is not a
  string/bytes value or when the collection is empty. If serialization fails
  for any reason, a warning is logged and any entry already created for `key`
  is removed from `meta_graph_def` again.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections. Defaults to the default
      graph when not given.
    export_scope: Optional `string`. Name scope to remove.

  Raises:
    TypeError: If `graph` is provided but is not an `ops.Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    # Interpolate the type into the message. Passing it as a second argument
    # to TypeError would leave the "%s" placeholder unformatted.
    raise TypeError("graph must be of type Graph, not %s" % type(graph))

  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  collection_list = graph.get_collection(key)
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      # A to_proto function is registered for this key: every item is stored
      # as a serialized proto in the bytes_list field.
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(x, export_scope=export_scope)
        # A falsy result (e.g. `None`) means the item is skipped.
        if proto:
          assert isinstance(proto, proto_type)
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      # No to_proto function: the CollectionDef field is chosen from the first
      # item of the collection.
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        # Items are stored by name, keeping only those inside export_scope.
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(list(collection_list))
  except Exception as e:  # pylint: disable=broad-except
    # Best effort: an unserializable collection is dropped with a warning
    # rather than failing the whole export.
    logging.warning("Error encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s", key, str(e))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
Exemple #4
0
def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  The input proto is deep-copied; `from_node_def` itself is not modified.
  Inputs that lie outside `export_scope` are renamed with
  `_UNBOUND_INPUT_PREFIX` (keeping any leading `^`) and the renamed names are
  appended to `unbound_inputs`.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: An array of unbound input names if they exist. This list
      is appended to in place.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    # A leading "^" is ignored when deciding whether the input is in scope.
    if (export_scope and
        not node_def.input[i].lstrip("^").startswith(export_scope)):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in six.iteritems(from_node_def.attr):
    if k == "_class":
      # Keep only the class entries whose part after "@" is inside the export
      # scope, and strip the scope from the ones that are kept.
      new_s = [compat.as_bytes(
          ops.strip_name_scope(s, export_scope)) for s in v.list.s
               if not export_scope or
               compat.as_str(s).split("@")[1].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope))
        node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s))
      else:
        # The frame name is outside export_scope, so there is nothing to
        # strip. Keep the original value instead of reusing `new_s`, which
        # would be unbound or left over from an earlier attribute.
        node_def.attr[k].CopyFrom(v)
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def
  def to_proto(self, export_scope=None):
    """Serializes this `ResourceVariable` into a `VariableDef` proto.

    Args:
      export_scope: Optional `string`. Name scope to strip from the names
        written into the proto.

    Raises:
      RuntimeError: If run in EAGER mode.

    Returns:
      A `VariableDef` protocol buffer, or `None` when `export_scope` is given
      and the variable's handle name does not start with it.
    """
    if context.executing_eagerly():
      raise RuntimeError("to_proto not supported in EAGER mode.")
    if (export_scope is not None and
        not self.handle.name.startswith(export_scope)):
      return None

    def _scoped(name):
      # Every name stored in the proto has the export scope removed.
      return ops.strip_name_scope(name, export_scope)

    proto = variable_pb2.VariableDef()
    proto.variable_name = _scoped(self.handle.name)
    # The initial value is optional for backwards compatibility: it might be
    # None for variables constructed from old protos.
    if self._initial_value is not None:
      proto.initial_value_name = _scoped(self._initial_value.name)
    proto.initializer_name = _scoped(self.initializer.name)
    # Prefer the cached value as the snapshot; otherwise fall back to the
    # graph element.
    if self._cached_value is None:
      snapshot_source = self._graph_element
    else:
      snapshot_source = self._cached_value
    proto.snapshot_name = _scoped(snapshot_source.name)
    proto.is_resource = True
    proto.trainable = self.trainable
    slice_info = self._save_slice_info
    if slice_info:
      proto.save_slice_info_def.MergeFrom(
          slice_info.to_proto(export_scope=export_scope))
    return proto
def _node_def(from_node_def, export_scope, unbound_inputs):
  """Builds a copy of a `NodeDef` with `export_scope` removed from its names.

  Inputs that lie outside `export_scope` are renamed with a
  "$unbound_inputs_" prefix (keeping any leading "^") and the renamed names
  are appended to `unbound_inputs`.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer. It is deep-copied
      and not modified.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: A list that receives the names of unbound inputs; it is
      appended to in place.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  result = copy.deepcopy(from_node_def)
  for idx, input_name in enumerate(result.input):
    in_scope = (not export_scope or
                result.input[idx].lstrip("^").startswith(export_scope))
    if in_scope:
      result.input[idx] = ops.strip_name_scope(input_name, export_scope)
      continue
    # Out-of-scope inputs get a recognizable prefix and are reported back to
    # the caller.
    result.input[idx] = re.sub(r"([\^]|^)(.*)", r"\1$unbound_inputs_\2",
                               compat.as_str(input_name))
    unbound_inputs.append(result.input[idx])

  result.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))

  for attr_name, attr_value in six.iteritems(from_node_def.attr):
    if attr_name != "_class":
      result.attr[attr_name].CopyFrom(attr_value)
      continue
    # "_class" entries are filtered on the part after "@" and have the scope
    # stripped from the ones that are kept.
    kept = []
    for entry in attr_value.list.s:
      if (export_scope and
          not compat.as_str(entry).split("@")[1].startswith(export_scope)):
        continue
      kept.append(compat.as_bytes(ops.strip_name_scope(entry, export_scope)))
    result.attr[attr_name].CopyFrom(attr_value_pb2.AttrValue(
        list=attr_value_pb2.AttrValue.ListValue(s=kept)))

  return result
Exemple #7
0
        def to_proto(self, export_scope=None):
            """Builds a `SaveSliceInfoDef` proto from this slice info.

            Args:
              export_scope: Optional `string`. Name scope to remove.

            Returns:
              A `SaveSliceInfoDef` protocol buffer, or None when
              `export_scope` is given and `full_name` does not start with it.
            """
            if export_scope is not None and not self.full_name.startswith(export_scope):
                return None

            proto = variable_pb2.SaveSliceInfoDef()
            proto.full_name = ops.strip_name_scope(self.full_name, export_scope)
            # The three repeated fields are copied element by element from the
            # attributes of the same name.
            for field_name in ("full_shape", "var_offset", "var_shape"):
                target = getattr(proto, field_name)
                for entry in getattr(self, field_name):
                    target.append(entry)
            return proto
Exemple #8
0
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Side effect: when any of `export_scope`, `clear_extraneous_savers` or
  `clear_devices` is set and `unbound_inputs_col_name` is non-empty, the
  collection of that name on `graph` is cleared and refilled with the unbound
  inputs found during this export.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True` and `filename` is given, also writes the
      GraphDebugInfo to a separate file in the same directory as `filename`,
      named like `filename` with its extension replaced by `.debug`.
    **kwargs: Optional keyed arguments, including meta_info_def and
        collection_list.

  Returns:
    A tuple of the `MetaGraphDef` proto and a dictionary of `Variables` in the
    exported name scope, keyed by variable name with `export_scope` stripped.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
  """
  # In eager mode the export is only allowed when both graph_def and graph
  # are passed explicitly.
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  # The GraphDef only has to be rebuilt node by node when something must be
  # stripped, filtered or cleared.
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      # A GraphDef was supplied: filter and rewrite its nodes into a new one.
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # No GraphDef was supplied: build one from the nodes of `graph`.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      # Running total of the ByteSize() of the included source node_defs,
      # checked against the 2GB limit below.
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      # Nodes are visited in ascending id order.
      for key in sorted(graph._nodes_by_id):
        if _should_include_node(graph._nodes_by_id[key].name,
                                export_scope,
                                exclude_nodes):
          value = graph._nodes_by_id[key]
          # pylint: enable=protected-access
          node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                               clear_devices=clear_devices)
          graph_def.node.extend([node_def])
          # Record the output shapes of the op on the node just added.
          if value.outputs:
            assert "_output_shapes" not in graph_def.node[-1].attr
            graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
                output.get_shape().as_proto() for output in value.outputs])
          bytesize += value.node_def.ByteSize()
          if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError("GraphDef cannot be larger than 2GB.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  # Map scope-stripped variable names to the variables found in the scope.
  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    # NOTE(review): this passes the variable object rather than its name;
    # presumably `_should_include_node` accepts either - confirm.
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      # The debug file name is `filename` with its extension replaced by
      # ".debug".
      name, _ = os.path.splitext(filename)
      debug_filename = "{name}{ext}".format(name=name, ext=".debug")

      # Looks up, by name, the graph operation for every node in the exported
      # graph_def. The scope is prepended again because node names in the
      # exported graph_def have it stripped.
      # TODO(liufengdb): fix this for functions.
      ops_to_export = []
      for node in scoped_meta_graph_def.graph_def.node:
        scoped_op_name = ops.prepend_name_scope(node.name, export_scope)
        ops_to_export.append(("", graph.get_operation_by_name(scoped_op_name)))

      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list
Exemple #9
0
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Note that with `clear_devices=True` the device fields are cleared directly on
  the nodes of the `MetaGraphDef` being imported, so a proto passed in by the
  caller is modified.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs that are not exactly
      covered by `input_map`, or if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        # Decode the stored names once. The stored values may be bytes, so
        # both the comparison and the membership test against the string keys
        # of input_map must use the decoded form.
        unbound_names = [compat.as_str(v) for v in field.value]
        if unbound_names and (
            not input_map or
            sorted(unbound_names) != sorted(input_map)):
          # Report only the inputs that are actually missing. When none are
          # missing (input_map has keys that are not unbound inputs), list
          # every unbound input so the message is never empty.
          missing = [name for name in unbound_names
                     if not input_map or name not in input_map]
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join(missing or unbound_names))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    # Ask the graph which scope name the import would get, without reserving
    # it. That name is used below to prefix the restored collection items.
    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    # Cache of variables already rebuilt, keyed by their serialized proto, so
    # a variable listed in several collections is created only once.
    variable_objects = {}
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          # Stored names are resolved back to graph elements under the
          # import scope.
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    # Map scope-stripped variable names to the imported variables.
    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
Exemple #10
0
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Side effect: when any of `export_scope`, `clear_extraneous_savers` or
  `clear_devices` is set and `unbound_inputs_col_name` is non-empty, the
  collection of that name on `graph` is cleared and refilled with the unbound
  inputs found during this export.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
        removed while exporting the GraphDef.
    **kwargs: Optional keyed arguments, including meta_info_def and
        collection_list.

  Returns:
    A tuple of the `MetaGraphDef` proto and a dictionary of `Variables` in the
    exported name scope, keyed by variable name with `export_scope` stripped.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    ValueError: When eager execution is enabled.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  # The GraphDef only has to be rebuilt node by node when something must be
  # stripped, filtered or cleared.
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      # A GraphDef was supplied: filter and rewrite its nodes into a new one.
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # No GraphDef was supplied: build one from the nodes of `graph`.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      # Running total of the ByteSize() of the included source node_defs,
      # checked against the 2GB limit below.
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      # Nodes are visited in ascending id order.
      for key in sorted(graph._nodes_by_id):
        if _should_include_node(graph._nodes_by_id[key].name,
                                export_scope,
                                exclude_nodes):
          value = graph._nodes_by_id[key]
          # pylint: enable=protected-access
          node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                               clear_devices=clear_devices)
          graph_def.node.extend([node_def])
          # Record the output shapes of the op on the node just added.
          if value.outputs:
            assert "_output_shapes" not in graph_def.node[-1].attr
            graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
                output.get_shape().as_proto() for output in value.outputs])
          bytesize += value.node_def.ByteSize()
          if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError("GraphDef cannot be larger than 2GB.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  # Map scope-stripped variable names to the variables found in the scope.
  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    # NOTE(review): this passes the variable object rather than its name;
    # presumably `_should_include_node` accepts either - confirm.
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  # Writing to disk is optional; the proto is returned either way.
  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)

  return scoped_meta_graph_def, var_list
Exemple #11
0
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             **kwargs):
  """Builds a `MetaGraphDef` for a graph and optionally writes it to a file.

  When `export_scope` is given (and no `graph_def`), a new `GraphDef` is
  assembled from the nodes of `graph` accepted by `_should_include_node`, each
  converted with `_node_def`, and annotated with an `_output_shapes` attr.
  The result is handed to `create_meta_graph_def` together with `**kwargs`.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export from. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name is stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported. graph_def and export_scope cannot both be specified.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, the collection
      with this name on `graph` is cleared and refilled with the inputs
      gathered while scoping the nodes (only when `export_scope` is used).
    **kwargs: Optional keyed arguments forwarded to `create_meta_graph_def`,
      including meta_info_def, saver_def, collection_list.

  Returns:
    A tuple of the `MetaGraphDef` proto and a dictionary mapping
    scope-stripped names to the variables found in the exported name scope.

  Raises:
    ValueError: When both `graph_def` and `export_scope` are given, or when
      the `GraphDef` is larger than 2GB.
  """
  if not graph:
    graph = ops.get_default_graph()
  if graph_def and export_scope:
    raise ValueError("graph_def and export_scope cannot both "
                     "be specified.")

  if graph_def is None and export_scope:
    # Only do this complicated work if we want to remove a name scope.
    unbound_inputs = []
    graph_def = graph_pb2.GraphDef()
    # pylint: disable=protected-access
    graph_def.versions.CopyFrom(graph._graph_def_versions)
    nodes_by_name = graph._nodes_by_name
    # pylint: enable=protected-access
    bytesize = 0
    for node_name in sorted(nodes_by_name):
      if not _should_include_node(node_name, export_scope):
        continue
      op = nodes_by_name[node_name]
      scoped_node = _node_def(op.node_def, export_scope, unbound_inputs)
      graph_def.node.extend([scoped_node])
      if op.outputs:
        exported = graph_def.node[-1]
        assert "_output_shapes" not in exported.attr
        shape_list = exported.attr["_output_shapes"].list.shape
        shape_list.extend([out.get_shape().as_proto() for out in op.outputs])
      # Track the running size so oversized graphs fail early.
      bytesize += op.node_def.ByteSize()
      if bytesize >= (1 << 31) or bytesize < 0:
        raise ValueError("GraphDef cannot be larger than 2GB.")

    # Inputs that live outside export_scope are recorded in a dedicated
    # collection so that they end up in the exported meta graph.
    if unbound_inputs_col_name:
      graph.clear_collection(unbound_inputs_col_name)
      for input_name in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, input_name)

  scoped_variables = graph.get_collection(ops.GraphKeys.VARIABLES,
                                          scope=export_scope)
  var_list = {
      ops.strip_name_scope(var.name, export_scope): var
      for var in scoped_variables
      if _should_include_node(var, export_scope)
  }

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      **kwargs)

  if filename:
    directory = os.path.dirname(filename)
    basename = os.path.basename(filename)
    training_util.write_graph(
        scoped_meta_graph_def, directory, basename, as_text=as_text)

  return scoped_meta_graph_def, var_list
Exemple #12
0
  def __iadd__(self, other):
    """Implements `self += other`: logs a warning, returns `self + other`."""
    # NOTE(review): the message and count arguments of this call were missing
    # (the call was left unterminated). The wording below is a reconstruction;
    # confirm it against the upstream source.
    logging.log_first_n(
        logging.WARN,
        "Variable += will be deprecated. Use variable.assign_add"
        " if you want assignment to the variable value or 'x = x + y'"
        " if you want a new python Tensor object.", 1)
    return self + other

  def __isub__(self, other):
    """Implements `self -= other`: logs a warning, returns `self - other`."""
    # NOTE(review): the message and count arguments of this call were missing
    # (the call was left unterminated). The wording below is a reconstruction;
    # confirm it against the upstream source.
    logging.log_first_n(
        logging.WARN,
        "Variable -= will be deprecated. Use variable.assign_sub"
        " if you want assignment to the variable value or 'x = x - y'"
        " if you want a new python Tensor object.", 1)
    return self - other

  def __imul__(self, other):
    """Implements `self *= other`: logs a warning, returns `self * other`."""
    # NOTE(review): the message and count arguments of this call were missing
    # (the call was left unterminated). The wording below is a reconstruction;
    # confirm it against the upstream source.
    logging.log_first_n(
        logging.WARN,
        "Variable *= will be deprecated. Use `var.assign(var * other)`"
        " if you want assignment to the variable value or `x = x * y`"
        " if you want a new python Tensor object.", 1)
    return self * other

  def __idiv__(self, other):
    """Implements Python 2 `self /= other`: warns, returns `self / other`."""
    # NOTE(review): the message and count arguments of this call were missing
    # (the call was left unterminated). The wording below is a reconstruction;
    # confirm it against the upstream source.
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other

  def __itruediv__(self, other):
    """Implements `self /= other`: logs a warning, returns `self / other`."""
    # NOTE(review): the message and count arguments of this call were missing
    # (the call was left unterminated). The wording below is a reconstruction;
    # confirm it against the upstream source.
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other

  def __irealdiv__(self, other):
    """In-place real-division hook: logs a warning, returns `self / other`."""
    # NOTE(review): the message and count arguments of this call were missing
    # (the call was left unterminated). The wording below is a reconstruction;
    # confirm it against the upstream source.
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other

  def __ipow__(self, other):
    """Implements `self **= other`: logs a warning, returns `self ** other`."""
    # NOTE(review): the message and count arguments of this call were missing
    # (the call was left unterminated). The wording below is a reconstruction;
    # confirm it against the upstream source.
    logging.log_first_n(
        logging.WARN,
        "Variable **= will be deprecated. Use `var.assign(var ** other)`"
        " if you want assignment to the variable value or `x = x ** y`"
        " if you want a new python Tensor object.", 1)
    return self ** other

  class SaveSliceInfo(object):
    """Holds the full name/shape of a variable plus one slice's offset/shape.

    Instances carry four attributes -- `full_name`, `full_shape`,
    `var_offset` and `var_shape` -- populated either from the constructor
    arguments or from a `SaveSliceInfoDef` proto.
    """

    def __init__(self,
                 full_name=None,
                 full_shape=None,
                 var_offset=None,
                 var_shape=None,
                 save_slice_info_def=None,
                 import_scope=None):
      """Initializes from explicit values or from a `SaveSliceInfoDef`.

      Args:
        full_name: Stored as `self.full_name` when no proto is given.
        full_shape: Stored as `self.full_shape` when no proto is given.
        var_offset: Stored as `self.var_offset` when no proto is given.
        var_shape: Stored as `self.var_shape` when no proto is given.
        save_slice_info_def: Optional `SaveSliceInfoDef` proto. When truthy,
          it takes precedence and the four arguments above are ignored.
        import_scope: Optional name scope prepended to the proto's
          `full_name`. Only used together with `save_slice_info_def`.
      """
      if not save_slice_info_def:
        self.full_name = full_name
        self.full_shape = full_shape
        self.var_offset = var_offset
        self.var_shape = var_shape
        return

      assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
      self.full_name = ops.prepend_name_scope(
          save_slice_info_def.full_name, import_scope=import_scope)
      # Copy the repeated proto fields into plain Python lists.
      self.full_shape = list(save_slice_info_def.full_shape)
      self.var_offset = list(save_slice_info_def.var_offset)
      self.var_shape = list(save_slice_info_def.var_shape)

    @property
    def spec(self):
      """Returns the slice spec string.

      The string is the space-separated `full_shape` entries, one space, then
      colon-separated `offset,size` pairs built from `var_offset` and
      `var_shape`.
      """
      dims = " ".join("%d" % dim for dim in self.full_shape)
      slices = ":".join(
          "%d,%d" % pair for pair in zip(self.var_offset, self.var_shape))
      return dims + " " + slices

    def to_proto(self, export_scope=None):
      """Converts this object to a `SaveSliceInfoDef` proto.

      Args:
        export_scope: Optional `string`. Name scope to remove from
          `full_name`.

      Returns:
        A `SaveSliceInfoDef` proto, or `None` when `export_scope` is given
        and `full_name` does not start with it.
      """
      if (export_scope is not None and
          not self.full_name.startswith(export_scope)):
        return None

      proto = variable_pb2.SaveSliceInfoDef()
      proto.full_name = ops.strip_name_scope(self.full_name, export_scope)
      for field_name in ("full_shape", "var_offset", "var_shape"):
        repeated_field = getattr(proto, field_name)
        for entry in getattr(self, field_name):
          repeated_field.append(entry)
      return proto

  def _set_save_slice_info(self, save_slice_info):
    """Stores `save_slice_info` on this object as `_save_slice_info`.

    Args:
      save_slice_info: Value to store; presumably a `SaveSliceInfo` instance
        -- TODO confirm against callers.
    """
    self._save_slice_info = save_slice_info

  def _get_save_slice_info(self):
    """Returns the value currently stored in `_save_slice_info`."""
    return self._save_slice_info
Exemple #13
0
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             **kwargs):
    """Builds a `MetaGraphDef` for a graph and optionally writes it to a file.

    When `export_scope` or `clear_devices` is set, the `GraphDef` is rebuilt
    node by node: from `graph_def` if one was passed in, otherwise from the
    nodes of `graph`. Only nodes accepted by `_should_include_node` are kept,
    and each one is converted with `_node_def`. The result is handed to
    `create_meta_graph_def` together with `**kwargs`.

    Args:
      filename: Optional filename including the path for writing the
        generated `MetaGraphDef` protocol buffer.
      graph_def: `GraphDef` protocol buffer.
      graph: The `Graph` to export from. If `None`, use the default graph.
      export_scope: Optional `string`. Name scope under which to extract
        the subgraph. The scope name is stripped from the node definitions
        for easy import later into new name scopes. If `None`, the whole
        graph is exported.
      as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
      unbound_inputs_col_name: Optional `string`. If provided, the collection
        with this name on `graph` is cleared and refilled with the inputs
        gathered while rebuilding the nodes (only when `export_scope` or
        `clear_devices` is set).
      clear_devices: Boolean which controls whether to clear device
        information before exporting the graph.
      **kwargs: Optional keyed arguments forwarded to
        `create_meta_graph_def`, including meta_info_def, saver_def,
        collection_list.

    Returns:
      A tuple of the `MetaGraphDef` proto and a dictionary mapping
      scope-stripped names to the variables found in the exported name scope.

    Raises:
      ValueError: When the `GraphDef` built from `graph` is larger than 2GB.
    """
    if not graph:
        graph = ops.get_default_graph()

    unbound_inputs = []
    if export_scope or clear_devices:
        if graph_def:
            # Filter and rewrite the nodes of the caller-supplied GraphDef.
            filtered_graph_def = graph_pb2.GraphDef()
            filtered_graph_def.versions.CopyFrom(graph_def.versions)
            for source_node in graph_def.node:
                if not _should_include_node(source_node.name, export_scope):
                    continue
                rewritten = _node_def(source_node,
                                      export_scope,
                                      unbound_inputs,
                                      clear_devices=clear_devices)
                filtered_graph_def.node.extend([rewritten])
            graph_def = filtered_graph_def
        else:
            # Only do this complicated work if we want to remove a name scope.
            graph_def = graph_pb2.GraphDef()
            graph_def.versions.CopyFrom(graph.graph_def_versions)
            # pylint: disable=protected-access
            nodes_by_id = graph._nodes_by_id
            # pylint: enable=protected-access
            bytesize = 0
            for op_id in sorted(nodes_by_id):
                op = nodes_by_id[op_id]
                if not _should_include_node(op.name, export_scope):
                    continue
                rewritten = _node_def(op.node_def,
                                      export_scope,
                                      unbound_inputs,
                                      clear_devices=clear_devices)
                graph_def.node.extend([rewritten])
                if op.outputs:
                    exported = graph_def.node[-1]
                    assert "_output_shapes" not in exported.attr
                    shape_list = exported.attr["_output_shapes"].list.shape
                    shape_list.extend(
                        [out.get_shape().as_proto() for out in op.outputs])
                # Track the running size so oversized graphs fail early.
                bytesize += op.node_def.ByteSize()
                if bytesize >= (1 << 31) or bytesize < 0:
                    raise ValueError("GraphDef cannot be larger than 2GB.")

        # Inputs that live outside export_scope are recorded in a dedicated
        # collection so that they end up in the exported meta graph.
        if unbound_inputs_col_name:
            graph.clear_collection(unbound_inputs_col_name)
            for input_name in unbound_inputs:
                graph.add_to_collection(unbound_inputs_col_name, input_name)

    scoped_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                            scope=export_scope)
    var_list = {
        ops.strip_name_scope(var.name, export_scope): var
        for var in scoped_variables
        if _should_include_node(var, export_scope)
    }

    scoped_meta_graph_def = create_meta_graph_def(
        graph_def=graph_def,
        graph=graph,
        export_scope=export_scope,
        **kwargs)

    if filename:
        directory = os.path.dirname(filename)
        basename = os.path.basename(filename)
        graph_io.write_graph(
            scoped_meta_graph_def, directory, basename, as_text=as_text)

    return scoped_meta_graph_def, var_list
Exemple #14
0
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs"):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the target graph,
  recreates all the collections (except the unbound-inputs collection), and
  returns a dictionary of the variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false. Note that the `device` fields are
      cleared in place on the `graph_def` of the `MetaGraphDef` being
      imported.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Output` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Output` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.

  Returns:
    A dictionary of all the `Variables` imported into the name scope, keyed
    by variable name with `import_scope` stripped.

  Raises:
    ValueError: If the graph_def contains unbound inputs whose names do not
      exactly match the keys of `input_map`.
  """
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key != unbound_inputs_col_name:
        continue
      kind = col_def.WhichOneof("kind")
      # A CollectionDef with no "kind" set carries no values, so there is
      # nothing unbound to check. Guarding here avoids the TypeError that
      # getattr(col_def, None) would raise, and mirrors the `kind is None`
      # handling in the restore loop below.
      if kind is not None:
        field = getattr(col_def, kind)
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([compat.as_str(v) for v in field.value]))
      break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""
    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in meta_graph_def.collection_def.items():
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto:
        # Collections with a registered from_proto function are stored as
        # serialized protos and rebuilt into Python objects here.
        assert kind == "bytes_list"
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=import_scope))
      else:
        field = getattr(col_def, kind)
        if kind == "node_list":
          # Node names are resolved to graph elements under import_scope.
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, import_scope))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, import_scope))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=import_scope)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, import_scope)] = v

  return var_list
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs"):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the target graph,
  recreates all the collections (except the unbound-inputs collection), and
  returns a dictionary of the variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false. Note that the `device` fields are
      cleared in place on the `graph_def` of the `MetaGraphDef` being
      imported.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.

  Returns:
    A dictionary of all the `Variables` imported into the name scope, keyed
    by variable name with `import_scope` stripped.

  Raises:
    ValueError: If the graph_def contains unbound inputs whose names do not
      exactly match the keys of `input_map`.
  """
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # WhichOneof returns None for a CollectionDef with nothing set.
        # Such a collection has no values, so skip the check rather than
        # letting getattr(col_def, None) raise a TypeError.
        unbound = []
        if kind is not None:
          unbound = [compat.as_str(v) for v in getattr(col_def, kind).value]
        if unbound and (not input_map or
                        sorted(unbound) != sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join(unbound))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""
    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in meta_graph_def.collection_def.items():
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto:
        # Collections with a registered from_proto function are stored as
        # serialized protos and rebuilt into Python objects here.
        assert kind == "bytes_list"
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=import_scope))
      else:
        field = getattr(col_def, kind)
        if kind == "node_list":
          # Node names are resolved to graph elements under import_scope.
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, import_scope))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, import_scope))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.VARIABLES,
                                     scope=import_scope)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, import_scope)] = v

  return var_list
Exemple #16
0
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None, exclude_nodes=None,
                       override_contents=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Items rejected by `_should_include_node` are dropped first. If nothing
  remains, the function returns without touching `meta_graph_def`. Any
  exception raised while serializing is logged as a warning and the partially
  written collection is removed again.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string. Keys that are neither
      strings nor bytes are skipped with a warning.
    graph: The `Graph` from which to get collections. If falsy, the default
      graph is used.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      collection, or None.
    override_contents: An iterable of values to place in the collection,
      ignoring the current values (if set). A falsy value (including an empty
      iterable) falls back to the graph's collection.

  Raises:
    TypeError: If `graph` is given but is not a `Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    # Interpolate the type into the message. Passing it as a second argument
    # to TypeError would leave the "%s" placeholder unformatted.
    raise TypeError("graph must be of type Graph, not %s" % type(graph))

  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  if override_contents:
    collection_list = override_contents
  else:
    collection_list = graph.get_collection(key)

  # Remove nodes that should not be exported from the collection list.
  collection_list = [x for x in collection_list if
                     _should_include_node(x, export_scope, exclude_nodes)]
  if not collection_list:
    return

  try:
    # Accessing the map entry creates it; it is deleted again in the except
    # clause if serialization fails.
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(x, export_scope=export_scope)
        if proto:
          assert isinstance(proto, proto_type)
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      # Without a registered to_proto function, the storage kind is chosen
      # from the first item of the collection.
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(collection_list)
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Issue encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef. Note this is a warning "
                    "and probably safe to ignore.\n%s", key, str(e))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
Exemple #17
0
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
    """Recreates a `Graph` saved in a `MetaGraphDef` proto.

    This function takes a `MetaGraphDef` protocol buffer as input. If
    the argument is a file containing a `MetaGraphDef` protocol buffer,
    it constructs a protocol buffer from the file content. The function
    then adds all the nodes from the `graph_def` field to the target graph,
    recreates the desired collections, and returns a dictionary of all the
    Variables imported into the name scope.

    In combination with `export_scoped_meta_graph()`, this function can be
    used to

    * Serialize a graph along with other Python objects such as
      `QueueRunner`, `Variable` into a `MetaGraphDef`.

    * Restart training from a saved graph and checkpoints.

    * Run inference from a saved graph and checkpoints.

    Args:
      meta_graph_or_file: `MetaGraphDef` protocol buffer or filename
        (including the path) containing a `MetaGraphDef`.
      clear_devices: Boolean which controls whether to clear device
        information from graph_def. Default false. Note that the `device`
        fields are cleared in place on the `graph_def` of the `MetaGraphDef`
        being imported.
      graph: The `Graph` to import into. If `None`, use the default graph.
      import_scope: Optional `string`. Name scope into which to import the
        subgraph. If `None`, the graph is imported to the root name scope.
      input_map: A dictionary mapping input names (as strings) in `graph_def`
        to `Tensor` objects. The values of the named input tensors in the
        imported graph will be re-mapped to the respective `Tensor` values.
      unbound_inputs_col_name: Collection name for looking up unbound inputs.
      restore_collections_predicate: a predicate on collection names. A
        collection named c (i.e whose key is c) will be restored iff
        1) `restore_collections_predicate(c)` is True, and
        2) `c != unbound_inputs_col_name`.

    Returns:
      A dictionary of all the `Variables` imported into the name scope, keyed
      by variable name with the import scope stripped.

    Raises:
      ValueError: If eager execution is enabled, or if the graph_def contains
        unbound inputs whose names do not exactly match the keys of
        `input_map`.
    """
    if context.executing_eagerly():
        raise ValueError(
            "Exporting/importing meta graphs is not supported when "
            "eager execution is enabled.")
    if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
        meta_graph_def = meta_graph_or_file
    else:
        meta_graph_def = read_meta_graph_file(meta_graph_or_file)

    if unbound_inputs_col_name:
        for key, col_def in meta_graph_def.collection_def.items():
            if key == unbound_inputs_col_name:
                kind = col_def.WhichOneof("kind")
                # WhichOneof returns None for a CollectionDef with nothing
                # set. Such a collection has no values, so skip the check
                # rather than letting getattr(col_def, None) raise.
                unbound = []
                if kind is not None:
                    unbound = [
                        compat.as_str(v) for v in getattr(col_def, kind).value
                    ]
                if unbound and (not input_map
                                or sorted(unbound) != sorted(input_map)):
                    # Compare the decoded names against input_map. The raw
                    # collection values may be bytes, which never match the
                    # str keys of input_map on Python 3.
                    missing = [
                        name for name in unbound
                        if not input_map or name not in input_map
                    ]
                    raise ValueError(
                        "Graph contains unbound inputs: %s. Must "
                        "provide these inputs through input_map." %
                        ",".join(missing))
                break

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    # Gathers the list of nodes we are interested in.
    with graph.as_default():
        producer_op_list = None
        if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
            producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
        input_graph_def = meta_graph_def.graph_def
        # Remove all the explicit device specifications for this node. This
        # helps to make the graph more portable.
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        # Ask the graph for the scope name to use without reserving it
        # (mark_as_used=False); it is reused below when prefixing collection
        # entries.
        scope_to_prepend_to_names = graph.unique_name(import_scope or "",
                                                      mark_as_used=False)

        importer.import_graph_def(input_graph_def,
                                  name=(import_scope
                                        or scope_to_prepend_to_names),
                                  input_map=input_map,
                                  producer_op_list=producer_op_list)

        # Restores all the other collections.
        # Maps a serialized variable proto to the object built from it, so a
        # variable listed in several collections is only constructed once.
        variable_objects = {}
        for key, col_def in sorted(meta_graph_def.collection_def.items()):
            # Don't add unbound_inputs to the new graph.
            if key == unbound_inputs_col_name:
                continue
            if not restore_collections_predicate(key):
                continue

            kind = col_def.WhichOneof("kind")
            if kind is None:
                logging.error(
                    "Cannot identify data type for collection %s. Skipping.",
                    key)
                continue
            from_proto = ops.get_from_proto_function(key)
            if from_proto and kind == "bytes_list":
                proto_type = ops.get_collection_proto_type(key)
                if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
                    for value in col_def.bytes_list.value:
                        variable = variable_objects.get(value, None)
                        if variable is None:
                            proto = proto_type()
                            proto.ParseFromString(value)
                            variable = from_proto(
                                proto, import_scope=scope_to_prepend_to_names)
                            variable_objects[value] = variable
                        graph.add_to_collection(key, variable)
                else:
                    for value in col_def.bytes_list.value:
                        proto = proto_type()
                        proto.ParseFromString(value)
                        graph.add_to_collection(
                            key,
                            from_proto(proto,
                                       import_scope=scope_to_prepend_to_names))
            else:
                field = getattr(col_def, kind)
                if key in _COMPAT_COLLECTION_LIST:
                    logging.warning(
                        "The saved meta_graph is possibly from an older release:\n"
                        "'%s' collection should be of type 'byte_list', but instead "
                        "is of type '%s'.", key, kind)
                if kind == "node_list":
                    # Node names are resolved to graph elements under the
                    # import scope.
                    for value in field.value:
                        col_op = graph.as_graph_element(
                            ops.prepend_name_scope(value,
                                                   scope_to_prepend_to_names))
                        graph.add_to_collection(key, col_op)
                elif kind == "int64_list":
                    # NOTE(opensource): This force conversion is to work
                    # around the fact that Python2 distinguishes between int
                    # and long, while Python3 has only int.
                    for value in field.value:
                        graph.add_to_collection(key, int(value))
                else:
                    for value in field.value:
                        graph.add_to_collection(
                            key,
                            ops.prepend_name_scope(value,
                                                   scope_to_prepend_to_names))

        var_list = {}
        variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                         scope=scope_to_prepend_to_names)
        for v in variables:
            var_list[ops.strip_name_scope(v.name,
                                          scope_to_prepend_to_names)] = v

    return var_list
Exemple #18
0
def import_scoped_meta_graph_with_return_elements(
    meta_graph_or_file,
    clear_devices=False,
    graph=None,
    import_scope=None,
    input_map=None,
    unbound_inputs_col_name="unbound_inputs",
    restore_collections_predicate=(lambda key: True),
    return_elements=None):
  """Imports graph from `MetaGraphDef` and returns vars and return elements.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  target graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Note: when `meta_graph_or_file` is a `MetaGraphDef` and `clear_devices` is
  true, the `device` field of the nodes in that proto's `graph_def` is cleared
  in place, i.e. the caller's proto is modified.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
      If falsy, the unbound-inputs validation is skipped.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e. whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.
    return_elements: A list of strings containing operation names in the
      `MetaGraphDef` that will be returned as `Operation` objects; and/or
      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.

  Returns:
    A tuple of (
      dictionary of all the `Variables` imported into the name scope, keyed by
      variable name with the import scope stripped,
      list of `Operation` or `Tensor` objects from the `return_elements` list).

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs that are not exactly covered by `input_map`.

  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  # Validate that every unbound input recorded in the meta graph is supplied
  # through `input_map` before touching the target graph.
  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key != unbound_inputs_col_name:
        continue
      kind = col_def.WhichOneof("kind")
      # A collection without a populated `kind` carries no values, so there is
      # nothing to validate (and `getattr(col_def, None)` would raise).
      if kind is not None:
        # Collection values may be `bytes` while `input_map` keys are `str`;
        # decode once so that every comparison below is str-to-str.
        unbound_names = [
            compat.as_str(v) for v in getattr(col_def, kind).value]
        if unbound_names and (
            not input_map or sorted(unbound_names) != sorted(input_map)):
          if input_map:
            missing = [
                name for name in unbound_names if name not in input_map]
          else:
            missing = unbound_names
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join(missing))
      break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    # Compute the scope name the import will use without reserving it, so the
    # same prefix can be applied to the names stored in the collections.
    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    imported_return_elements = importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list,
        return_elements=return_elements)

    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
    # without a VariableDef.trainable field set.
    tf_version = meta_graph_def.meta_info_def.tensorflow_version
    if not tf_version:
      variables_have_trainable = True
    else:
      variables_have_trainable = (
          distutils_version.LooseVersion(tf_version)
          >= distutils_version.LooseVersion("1.9"))

    # Sort collections so we see TRAINABLE_VARIABLES first and can default these
    # variables to trainable if the value is not set in their VariableDef.
    sorted_collections = []
    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
      sorted_collections.append(
          (ops.GraphKeys.TRAINABLE_VARIABLES,
           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
    for key, value in sorted(meta_graph_def.collection_def.items()):
      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
        sorted_collections.append((key, value))

    # Restores all the other collections.
    # `variable_objects` maps a serialized VariableDef to the Python object
    # built from it, so a variable listed in several collections is created
    # once and the same object is added to each of them.
    variable_objects = {}
    for key, col_def in sorted_collections:
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)

      # Temporary change to allow the TFMA evaluator to read metric variables
      # saved as a bytes list.
      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
      if key == ops.GraphKeys.METRIC_VARIABLES:
        # Metric variables will use the same proto functions as GLOBAL_VARIABLES
        from_proto = ops.get_from_proto_function(ops.GraphKeys.GLOBAL_VARIABLES)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              if not variables_have_trainable:
                # If the VariableDef proto does not contain a "trainable"
                # property because it was exported before that property was
                # added, we default it to whether the variable is in the
                # TRAINABLE_VARIABLES collection. We've sorted
                # TRAINABLE_VARIABLES to be first, so trainable variables will
                # be created from that collection.
                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          # Node names are resolved to the graph elements they refer to.
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    # Report the imported global variables keyed by their scope-stripped name.
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    var_list = {
        ops.strip_name_scope(v.name, scope_to_prepend_to_names): v
        for v in variables
    }

  return var_list, imported_return_elements
Exemple #19
0
 def strip_name_scope(name_scope):
     """Delegate to `ops.strip_name_scope` using the surrounding `export_scope`."""
     stripped = ops.strip_name_scope(name_scope, export_scope)
     return stripped