Example #1
def _import_meta_graph_def(meta_graph_def):
  """Recreates a Graph saved in a a `MetaGraphDef` proto.

  This function adds all the nodes from the meta graph def proto to the current
  graph, recreates all the collections, and returns a saver from saver_def.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer.

  Returns:
    A saver constructed from `saver_def` in `meta_graph_def`.
  """
  # Gathers the list of nodes we are interested in.
  importer.import_graph_def(meta_graph_def.graph_def, name="")

  # Restores all the other collections.
  for key, col_def in meta_graph_def.collection_def.items():
    kind = col_def.WhichOneof("kind")
    if kind is None:
      logging.error("Cannot identify data type for collection %s. Skipping."
                    % key)
      continue
    from_proto = ops.get_from_proto_function(key)
    if from_proto:
      assert kind == "bytes_list"
      proto_type = ops.get_collection_proto_type(key)
      for value in col_def.bytes_list.value:
        proto = proto_type()
        proto.ParseFromString(value)
        ops.add_to_collection(key, from_proto(proto))
    else:
      field = getattr(col_def, kind)
      if kind == "node_list":
        for value in field.value:
          col_op = ops.get_default_graph().as_graph_element(value)
          ops.add_to_collection(key, col_op)
      elif kind == "int64_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python2 distinguishes between int and long, while Python3 has
        # only int.
        for value in field.value:
          ops.add_to_collection(key, int(value))
      else:
        for value in field.value:
          ops.add_to_collection(key, value)

  if meta_graph_def.HasField("saver_def"):
    return Saver(saver_def=meta_graph_def.saver_def)
  else:
    return Saver()
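The helper above is the core of what the public `tf.compat.v1.train.import_meta_graph` wrapper does. A minimal usage sketch of that wrapper, assuming a checkpoint was previously written to the hypothetical paths below:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Session() as sess:
  # "/tmp/model.ckpt" is a hypothetical checkpoint prefix written earlier by
  # Saver.save(); import_meta_graph reads the ".meta" file and returns a Saver
  # reconstructed from the saver_def stored in the MetaGraphDef.
  saver = tf.train.import_meta_graph("/tmp/model.ckpt.meta")
  saver.restore(sess, "/tmp/model.ckpt")
  # Collections recreated by the import are now visible in the default graph.
  train_ops = tf.get_collection(tf.GraphKeys.TRAIN_OP)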
Example #2
def _handle_collection_def(multi_gpu_meta_graph_def, op_names_to_replicate,
                           num_replicas):
    allow_bytes_list_keys = [tf.GraphKeys.QUEUE_RUNNERS,
                             tf.GraphKeys.GLOBAL_VARIABLES,
                             tf.GraphKeys.TRAINABLE_VARIABLES,
                             tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
                             tf.GraphKeys.LOCAL_VARIABLES,
                             tf.GraphKeys.MODEL_VARIABLES,
                             tf.GraphKeys.GRADIENTS_INFO,
                             tf.GraphKeys.GLOBAL_STEP]
    keys_to_remove = []
    for key, col_def in multi_gpu_meta_graph_def.collection_def.items():
        kind = col_def.WhichOneof("kind")
        # Update node_list collections (e.g., GLOBAL_STEP, TRAIN_OP, UPDATE_OP,
        # LOSSES, ...)
        if kind == 'node_list':
            new_col_def = get_new_col_def_of_node_list(
                            col_def, op_names_to_replicate, num_replicas)
            multi_gpu_meta_graph_def.collection_def[key].Clear()
            multi_gpu_meta_graph_def.collection_def[key].CopyFrom(new_col_def)
        elif kind == 'bytes_list':
            if ops.get_from_proto_function(key):
                # Collections in allow_bytes_list_keys will be handled
                # explicitly below
                # (e.g., QUEUE_RUNNERS, LOCAL_VARIABLES, ...)
                if key in allow_bytes_list_keys:
                    continue
                # Remove unhandled collections (e.g., COND_CONTEXT)
                # TODO: Handle all protos in tf.GraphKeys
                else:
                    keys_to_remove.append(key)
            # Keep collections without proto function
            # (e.g., user defined string)
            else:
                continue
        else:
            raise RuntimeError("Should not reach here")
    for key in keys_to_remove:
        del multi_gpu_meta_graph_def.collection_def[key]

    # Update QUEUE_RUNNERS and LOCAL_VARIABLES collection
    update_queue_runners(multi_gpu_meta_graph_def, op_names_to_replicate,
                         num_replicas)
    update_local_variables(multi_gpu_meta_graph_def, op_names_to_replicate,
                           num_replicas)
    update_shard_info_for_in_graph(multi_gpu_meta_graph_def, num_replicas)
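A hedged sketch of how this helper might be invoked. The surrounding helpers (`get_new_col_def_of_node_list`, `update_queue_runners`, `update_local_variables`, `update_shard_info_for_in_graph`) are assumed to be defined elsewhere in the same module, and the inputs below are purely illustrative:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Purely illustrative inputs: a real in-graph replication pass would both build
# the multi-GPU graph and compute the set of op names it replicated.
multi_gpu_meta_graph_def = tf.train.export_meta_graph()
op_names_to_replicate = {"tower_0/dense/MatMul"}  # hypothetical op name
_handle_collection_def(multi_gpu_meta_graph_def, op_names_to_replicate,
                       num_replicas=4)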
Example #3
def _restore_collections(dest_graph, src_meta_graph_def,
                         collection_keys):
    """Restores collections that we need to keep."""
    scope = ""
    for key in collection_keys:
        collection_def = src_meta_graph_def.collection_def[key]
        kind = collection_def.WhichOneof("kind")
        if kind is None:
            tf_logging.error(
                "Cannot identify data type for collection %s. Skipping.",
                key)
            continue
        from_proto = ops.get_from_proto_function(key)
        if from_proto and kind == "bytes_list":
            proto_type = ops.get_collection_proto_type(key)
            # It is assumed that there are no variable keys in the collections.
            for value in collection_def.bytes_list.value:
                proto = proto_type()
                proto.ParseFromString(value)
                try:
                    new_value = from_proto(proto, import_scope=scope)
                except Exception:
                    continue
                dest_graph.add_to_collection(key, new_value)
        else:
            field = getattr(collection_def, kind)
            if kind == "node_list":
                for value in field.value:
                    name = ops.prepend_name_scope(value, scope)
                    # Since the graph has been optimized, the node may no
                    # longer exist.
                    try:
                        col_op = dest_graph.as_graph_element(name)
                    except (TypeError, ValueError, KeyError):
                        continue
                    dest_graph.add_to_collection(key, col_op)
            elif kind == "int64_list":
                # NOTE(opensource): This force conversion is to work around
                # the fact that Python2 distinguishes between int and long,
                # while Python3 has only int.
                for value in field.value:
                    dest_graph.add_to_collection(key, int(value))
            else:
                for value in field.value:
                    dest_graph.add_to_collection(
                        key, ops.prepend_name_scope(value, scope))
Example #4
 def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
   """Restores collections that we need to keep."""
   scope = ""
   for key in collection_keys:
     collection_def = src_meta_graph_def.collection_def[key]
     kind = collection_def.WhichOneof("kind")
     if kind is None:
       tf_logging.error(
           "Cannot identify data type for collection %s. Skipping.", key)
       continue
     from_proto = ops.get_from_proto_function(key)
     if from_proto and kind == "bytes_list":
       proto_type = ops.get_collection_proto_type(key)
        # It is assumed that there are no variable keys in the collections.
       for value in collection_def.bytes_list.value:
         proto = proto_type()
         proto.ParseFromString(value)
         try:
           new_value = from_proto(proto, import_scope=scope)
          except Exception:
           continue
         dest_graph.add_to_collection(key, new_value)
     else:
       field = getattr(collection_def, kind)
       if kind == "node_list":
         for value in field.value:
           name = ops.prepend_name_scope(value, scope)
           # Since the graph has been optimized, the node may no longer
            # exist.
           try:
             col_op = dest_graph.as_graph_element(name)
            except (TypeError, ValueError, KeyError):
             continue
           dest_graph.add_to_collection(key, col_op)
       elif kind == "int64_list":
         # NOTE(opensource): This force conversion is to work around the
         # fact that Python2 distinguishes between int and long, while
         # Python3 has only int.
         for value in field.value:
           dest_graph.add_to_collection(key, int(value))
       else:
         for value in field.value:
           dest_graph.add_to_collection(key,
                                        ops.prepend_name_scope(value, scope))
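Both versions branch on `collection_def.WhichOneof("kind")`. A small self-contained sketch (TF1 compatibility API) showing which kind each sort of collection item ends up as when a meta graph is exported:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

g = tf.Graph()
with g.as_default():
  v = tf.get_variable("v", shape=[], initializer=tf.zeros_initializer())
  tf.add_to_collection("my_ints", 7)             # serialized as an int64_list
  tf.add_to_collection("my_ops", v.initializer)  # serialized as a node_list
  mgd = tf.train.export_meta_graph(graph=g)

for key in ("variables", "my_ints", "my_ops"):
  print(key, mgd.collection_def[key].WhichOneof("kind"))
# variables -> bytes_list (VariableDef protos handled by from_proto)
# my_ints   -> int64_list
# my_ops    -> node_list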
Example #5
def import_scoped_meta_graph_with_return_elements(
        meta_graph_or_file,
        clear_devices=False,
        graph=None,
        import_scope=None,
        input_map=None,
        unbound_inputs_col_name="unbound_inputs",
        restore_collections_predicate=(lambda key: True),
        return_elements=None):
    """Imports graph from `MetaGraphDef` and returns vars and return elements.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e. whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.
    return_elements:  A list of strings containing operation names in the
      `MetaGraphDef` that will be returned as `Operation` objects; and/or
      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.

  Returns:
    A tuple of (
      dictionary of all the `Variables` imported into the name scope,
      list of `Operation` or `Tensor` objects from the `return_elements` list).

  Raises:
    ValueError: If the graph_def contains unbound inputs.

  """
    if context.executing_eagerly():
        raise ValueError(
            "Exporting/importing meta graphs is not supported when "
            "eager execution is enabled.")
    if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
        meta_graph_def = meta_graph_or_file
    else:
        meta_graph_def = read_meta_graph_file(meta_graph_or_file)

    if unbound_inputs_col_name:
        for key, col_def in meta_graph_def.collection_def.items():
            if key == unbound_inputs_col_name:
                kind = col_def.WhichOneof("kind")
                field = getattr(col_def, kind)
                if field.value and (not input_map or sorted(
                    [compat.as_str(v)
                     for v in field.value]) != sorted(input_map)):
                    raise ValueError(
                        "Graph contains unbound inputs: %s. Must "
                        "provide these inputs through input_map." % ",".join(
                            compat.as_str(v) for v in field.value
                            if not input_map or v not in input_map))
                break

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    # Gathers the list of nodes we are interested in.
    with graph.as_default():
        producer_op_list = None
        if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
            producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
        input_graph_def = meta_graph_def.graph_def
        # Remove all the explicit device specifications for these nodes. This
        # helps to make the graph more portable.
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        scope_to_prepend_to_names = graph.unique_name(import_scope or "",
                                                      mark_as_used=False)

        imported_return_elements = importer.import_graph_def(
            input_graph_def,
            name=(import_scope or scope_to_prepend_to_names),
            input_map=input_map,
            producer_op_list=producer_op_list,
            return_elements=return_elements)

        # TensorFlow versions before 1.9 exported SavedModels without a
        # VariableDef.trainable field set.
        tf_version = meta_graph_def.meta_info_def.tensorflow_version
        if not tf_version:
            variables_have_trainable = True
        else:
            variables_have_trainable = (packaging_version.parse(tf_version) >=
                                        packaging_version.parse("1.9"))

        # Sort collections so we see TRAINABLE_VARIABLES first and can default these
        # variables to trainable if the value is not set in their VariableDef.
        sorted_collections = []
        if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
            sorted_collections.append((ops.GraphKeys.TRAINABLE_VARIABLES,
                                       meta_graph_def.collection_def[
                                           ops.GraphKeys.TRAINABLE_VARIABLES]))
        for key, value in sorted(meta_graph_def.collection_def.items()):
            if key != ops.GraphKeys.TRAINABLE_VARIABLES:
                sorted_collections.append((key, value))

        # Restores all the other collections.
        variable_objects = {}
        for key, col_def in sorted_collections:
            # Don't add unbound_inputs to the new graph.
            if key == unbound_inputs_col_name:
                continue
            if not restore_collections_predicate(key):
                continue

            kind = col_def.WhichOneof("kind")
            if kind is None:
                logging.error(
                    "Cannot identify data type for collection %s. Skipping.",
                    key)
                continue
            from_proto = ops.get_from_proto_function(key)

            # Temporary change to allow the TFMA evaluator to read metric variables
            # saved as a bytes list.
            # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
            if key == ops.GraphKeys.METRIC_VARIABLES:
                # Metric variables will use the same proto functions as GLOBAL_VARIABLES
                from_proto = ops.get_from_proto_function(
                    ops.GraphKeys.GLOBAL_VARIABLES)
            if from_proto and kind == "bytes_list":
                proto_type = ops.get_collection_proto_type(key)
                if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
                    for value in col_def.bytes_list.value:
                        variable = variable_objects.get(value, None)
                        if variable is None:
                            proto = proto_type()
                            proto.ParseFromString(value)
                            if not variables_have_trainable:
                                # If the VariableDef proto does not contain a "trainable"
                                # property because it was exported before that property was
                                # added, we default it to whether the variable is in the
                                # TRAINABLE_VARIABLES collection. We've sorted
                                # TRAINABLE_VARIABLES to be first, so trainable variables will
                                # be created from that collection.
                                proto.trainable = (
                                    key == ops.GraphKeys.TRAINABLE_VARIABLES)
                            variable = from_proto(
                                proto, import_scope=scope_to_prepend_to_names)
                            variable_objects[value] = variable
                        graph.add_to_collection(key, variable)
                else:
                    for value in col_def.bytes_list.value:
                        proto = proto_type()
                        proto.ParseFromString(value)
                        graph.add_to_collection(
                            key,
                            from_proto(proto,
                                       import_scope=scope_to_prepend_to_names))
            else:
                field = getattr(col_def, kind)
                if key in _COMPAT_COLLECTION_LIST:
                    logging.warning(
                        "The saved meta_graph is possibly from an older release:\n"
                        "'%s' collection should be of type 'byte_list', but instead "
                        "is of type '%s'.", key, kind)
                if kind == "node_list":
                    for value in field.value:
                        col_op = graph.as_graph_element(
                            ops.prepend_name_scope(value,
                                                   scope_to_prepend_to_names))
                        graph.add_to_collection(key, col_op)
                elif kind == "int64_list":
                    # NOTE(opensource): This force conversion is to work around the fact
                    # that Python2 distinguishes between int and long, while Python3 has
                    # only int.
                    for value in field.value:
                        graph.add_to_collection(key, int(value))
                else:
                    for value in field.value:
                        graph.add_to_collection(
                            key,
                            ops.prepend_name_scope(value,
                                                   scope_to_prepend_to_names))

        var_list = {}
        variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                         scope=scope_to_prepend_to_names)
        for v in variables:
            var_list[ops.strip_name_scope(v.name,
                                          scope_to_prepend_to_names)] = v

    return var_list, imported_return_elements
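A hedged usage sketch. `import_scoped_meta_graph_with_return_elements` lives in the internal `tensorflow.python.framework.meta_graph` module (an assumption about where this code is exposed; internal paths can change between releases), and the file path and tensor name below are hypothetical:

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import meta_graph

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
  var_list, returned = meta_graph.import_scoped_meta_graph_with_return_elements(
      "/tmp/export/my_model.meta",   # hypothetical MetaGraphDef file
      clear_devices=True,
      import_scope="imported",
      return_elements=["logits:0"])  # hypothetical tensor name in the MetaGraphDef
  logits = returned[0]               # a Tensor now living under the "imported" scope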
Example #6
def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
    # pylint: disable=line-too-long
    """Construct and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    graph: The `Graph` to create `MetaGraphDef` out of.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from all
      collections, or None.
    clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
        collection.  Note this method does not alter the graph, so any
        extraneous Save/Restore ops should have been removed already, as needed.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
    # pylint: enable=line-too-long
    # Type check.
    if graph and not isinstance(graph, ops.Graph):
        raise TypeError(
            f"graph must be of type Graph. Received type: {type(graph)}.")
    if meta_info_def and not isinstance(
            meta_info_def, meta_graph_pb2.MetaGraphDef.MetaInfoDef):
        raise TypeError("meta_info_def must be of type MetaInfoDef. "
                        f"Received type: {type(meta_info_def)}.")
    if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError("graph_def must be of type GraphDef. "
                        f"Received type: {type(graph_def)}.")
    if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
        raise TypeError(f"saver_def must be of type SaverDef. "
                        f"Received type: {type(saver_def)}.")

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    # Creates a MetaGraphDef proto.
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    # Adds meta_info_def.
    if not meta_info_def:
        meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()

    # Set the tf version strings to the current tf build.
    meta_info_def.tensorflow_version = versions.__version__
    meta_info_def.tensorflow_git_version = versions.__git_version__
    meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

    # Adds graph_def or the default.
    if not graph_def:
        meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))
    else:
        meta_graph_def.graph_def.MergeFrom(graph_def)

    # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
    # pylint: disable=g-explicit-length-test
    if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
        meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
            stripped_op_list_for_graph(meta_graph_def.graph_def))
    # pylint: enable=g-explicit-length-test

    # Strip default valued attributes in graph_def.
    if strip_default_attrs:
        strip_graph_default_valued_attrs(meta_graph_def)

    # Adds saver_def.
    if saver_def:
        meta_graph_def.saver_def.MergeFrom(saver_def)

    # Adds collection_list.
    if collection_list is not None:
        clist = collection_list
    else:
        clist = graph.get_all_collection_keys()

    for ctype in clist:
        if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS:
            # Avoid importing Saver here
            from_proto = ops.get_from_proto_function(ctype)
            add_collection_def(meta_graph_def,
                               ctype,
                               graph=graph,
                               export_scope=export_scope,
                               exclude_nodes=exclude_nodes,
                               override_contents=[from_proto(saver_def)])
        else:
            add_collection_def(meta_graph_def,
                               ctype,
                               graph=graph,
                               export_scope=export_scope,
                               exclude_nodes=exclude_nodes)
    return meta_graph_def
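A minimal sketch of calling the builder directly; the `tensorflow.python.framework.meta_graph` import path is an assumption (it is an internal module), and the graph contents are illustrative:

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import meta_graph

tf.disable_eager_execution()

g = tf.Graph()
with g.as_default():
  x = tf.placeholder(tf.float32, [None, 4], name="x")
  w = tf.get_variable("w", shape=[4, 2])
  y = tf.matmul(x, w, name="y")

# Builds the proto from the graph's GraphDef and all of its collection keys.
mgd = meta_graph.create_meta_graph_def(graph=g, strip_default_attrs=True)
print(mgd.meta_info_def.tensorflow_version)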
Example #7
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e. whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([compat.as_str(v) for v in field.value
                                     if not input_map or v not in input_map]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for these nodes. This
    # helps to make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
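A hedged round-trip sketch pairing this function with `export_scoped_meta_graph()`, as the docstring suggests; the internal module path and the scope names are assumptions:

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import meta_graph

tf.disable_eager_execution()

# Export only the "model" sub-scope of one graph...
g1 = tf.Graph()
with g1.as_default():
  with tf.variable_scope("model"):
    v = tf.get_variable("v", shape=[3])
  mgd, _ = meta_graph.export_scoped_meta_graph(export_scope="model")

# ...then recreate it inside a fresh graph under a new name scope.
g2 = tf.Graph()
with g2.as_default():
  var_list = meta_graph.import_scoped_meta_graph(mgd, import_scope="copy")
# var_list maps scope-stripped names (e.g. "v:0") to the newly created Variables.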
Example #8
def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
  # pylint: disable=line-too-long
  """Construct and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    graph: The `Graph` to create `MetaGraphDef` out of.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from all
      collections, or None.
    clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
        collection.  Note this method does not alter the graph, so any
        extraneous Save/Restore ops should have been removed already, as needed.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # pylint: enable=line-too-long
  # Type check.
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError("graph must be of type Graph, not %s" % type(graph))
  if meta_info_def and not isinstance(meta_info_def,
                                      meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError("meta_info_def must be of type MetaInfoDef, not %s" %
                    type(meta_info_def))
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be of type GraphDef, not %s" %
                    type(graph_def))
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError("saver_def must be of type SaverDef, not %s" %
                    type(saver_def))

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Creates a MetaGraphDef proto.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # Adds meta_info_def.
  if not meta_info_def:
    meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()

  # Set the tf version strings to the current tf build.
  meta_info_def.tensorflow_version = versions.__version__
  meta_info_def.tensorflow_git_version = versions.__git_version__
  meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

  # Adds graph_def or the default.
  if not graph_def:
    meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))
  else:
    meta_graph_def.graph_def.MergeFrom(graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
  # pylint: disable=g-explicit-length-test
  if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))
  # pylint: enable=g-explicit-length-test

  # Strip default valued attributes in graph_def.
  if strip_default_attrs:
    _strip_graph_default_valued_attrs(meta_graph_def)

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list.
  if collection_list is not None:
    clist = collection_list
  else:
    clist = graph.get_all_collection_keys()

  for ctype in clist:
    if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS:
      # Avoid importing Saver here
      from_proto = ops.get_from_proto_function(ctype)
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes,
                         override_contents=[from_proto(saver_def)])
    else:
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes)
  return meta_graph_def
Example #9
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e. whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  if context.in_eager_mode():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([compat.as_str(v) for v in field.value
                                     if not input_map or v not in input_map]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for these nodes. This
    # helps to make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
Example #10
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs"):
  """Recreates a`Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates all the collections, and returns a saver
  constructed from the `saver_def` field.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([compat.as_str(v) for v in field.value]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for these nodes. This
    # helps to make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""
    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in meta_graph_def.collection_def.items():
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto:
        assert kind == "bytes_list"
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=import_scope))
      else:
        field = getattr(col_def, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, import_scope))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, import_scope))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.VARIABLES,
                                     scope=import_scope)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, import_scope)] = v

  return var_list
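Finally, a hedged sketch of the `input_map` path guarded by the unbound-inputs check above. It assumes the internal `tensorflow.python.framework.meta_graph` module and the `$unbound_inputs_<name>` key format this era of the code uses to record inputs coming from outside the exported scope; the graph contents are illustrative:

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import meta_graph

tf.disable_eager_execution()

g1 = tf.Graph()
with g1.as_default():
  x = tf.placeholder(tf.float32, [None, 3], name="x")  # defined outside the scope
  with tf.name_scope("head"):
    y = tf.reduce_sum(x, name="y")                      # consumes the outside tensor
  mgd, _ = meta_graph.export_scoped_meta_graph(export_scope="head")

g2 = tf.Graph()
with g2.as_default():
  new_x = tf.placeholder(tf.float32, [None, 3], name="new_x")
  # Importing without input_map would raise the "Graph contains unbound inputs"
  # ValueError; the map rebinds the recorded unbound input to a local tensor.
  meta_graph.import_scoped_meta_graph(mgd, input_map={"$unbound_inputs_x": new_x})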