예제 #1
0
def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

  First checks that both protos contain the same collection keys. Collections
  whose key has a registered proto type (`ops.get_collection_proto_type`) have
  each serialized `bytes_list` entry parsed and compared with
  `tester.assertProtoEquals`; other collections are compared with
  `tester.assertEqual`. Finally the remaining fields are compared as whole
  protos.

  NOTE: this mutates both `a` and `b` -- their `collection_def` fields are
  cleared before the final comparison.

  Args:
    tester: Test case object providing `assertEqual` and `assertProtoEquals`.
    a: First `MetaGraphDef` to compare.
    b: Second `MetaGraphDef` to compare.
  """
  # Carefully check the collection_defs
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  collection_keys = a.collection_def.keys()
  for k in collection_keys:
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      # ParseFromString clears the message before parsing, so the two
      # scratch protos can be reused for every entry.
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same
      tester.assertEqual(len(a_value.bytes_list.value),
                         len(b_value.bytes_list.value))
      for (a_value_item, b_value_item) in zip(
          a_value.bytes_list.value,
          b_value.bytes_list.value):
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # `assertEquals` is a deprecated alias that was removed in Python 3.12.
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")
  tester.assertProtoEquals(a, b)
예제 #2
0
def assert_meta_graph_protos_equal(tester, a, b):
    """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

    First checks that both protos contain the same collection keys.
    Collections whose key has a registered proto type
    (`ops.get_collection_proto_type`) have each serialized `bytes_list` entry
    parsed and compared with `tester.assertProtoEquals`; other collections are
    compared with `tester.assertEqual`. Finally the remaining fields are
    compared as whole protos.

    NOTE: this mutates both `a` and `b` -- their `collection_def` fields are
    cleared before the final comparison.

    Args:
      tester: Test case object providing `assertEqual` and
        `assertProtoEquals`.
      a: First `MetaGraphDef` to compare.
      b: Second `MetaGraphDef` to compare.
    """
    # Carefully check the collection_defs
    tester.assertEqual(set(a.collection_def), set(b.collection_def))
    collection_keys = a.collection_def.keys()
    for k in collection_keys:
        a_value = a.collection_def[k]
        b_value = b.collection_def[k]
        proto_type = ops.get_collection_proto_type(k)
        if proto_type:
            # ParseFromString clears the message before parsing, so the two
            # scratch protos can be reused for every entry.
            a_proto = proto_type()
            b_proto = proto_type()
            # Number of entries in the collections is the same
            tester.assertEqual(len(a_value.bytes_list.value),
                               len(b_value.bytes_list.value))
            for (a_value_item, b_value_item) in zip(a_value.bytes_list.value,
                                                    b_value.bytes_list.value):
                a_proto.ParseFromString(a_value_item)
                b_proto.ParseFromString(b_value_item)
                tester.assertProtoEquals(a_proto, b_proto)
        else:
            # `assertEquals` is a deprecated alias that was removed in
            # Python 3.12.
            tester.assertEqual(a_value, b_value)
    # Compared the fields directly, remove their raw values from the
    # proto comparison below.
    a.ClearField("collection_def")
    b.ClearField("collection_def")
    tester.assertProtoEquals(a, b)
예제 #3
0
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Reads collection `key` from `graph` and writes it into
  `meta_graph_def.collection_def[key]`. Non-string keys and empty collections
  are skipped (the former with a warning). If serialization fails, a warning
  is logged and any partially written entry for `key` is removed.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections. Defaults to the default
      graph when not given.
    export_scope: Optional `string`. Name scope to remove.

  Raises:
    TypeError: If `graph` is given and is not a `Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    # Exceptions do not apply %-style arguments the way logging calls do, so
    # the message must be formatted explicitly.
    raise TypeError("graph must be of type Graph, not %s" % type(graph))

  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  collection_list = graph.get_collection(key)
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        proto = to_proto(x, export_scope=export_scope)
        # Falsy results (presumably items outside `export_scope`) are skipped.
        if proto:
          # Additional type check to make sure the returned proto is indeed
          # what we expect. Raised explicitly rather than asserted so the
          # check survives `python -O`; it is handled by the except below.
          if not isinstance(proto, proto_type):
            raise TypeError("proto %s is not type %s" % (proto, proto_type))
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(collection_list)
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Error encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s", key, str(e))
    # Drop the half-written entry created by `collection_def[key]` above.
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
예제 #4
0
def add_collection_def(meta_graph_def, key):
    """Serialize the graph collection named `key` into `meta_graph_def`.

    The items come from `ops.get_collection(key)` and are written into
    `meta_graph_def.collection_def[key]`:

    * if a to-proto function is registered for `key`, each item is converted,
      checked against the registered proto type, and stored serialized in
      `bytes_list`;
    * otherwise the field is chosen by `_get_kind_name` on the first item:
      `node_list` stores each item's `.name`, `bytes_list` stores
      `compat.as_bytes(item)`, and any other kind stores the items as-is.

    Non-string keys (warned about) and empty collections are skipped. Any
    exception raised while serializing is logged as a warning and the
    possibly half-filled `collection_def[key]` entry is deleted.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
    """
    if not (isinstance(key, six.string_types) or isinstance(key, bytes)):
        logging.warning(
            "Only collections with string type keys will be "
            "serialized. This key has %s", type(key))
        return
    items = ops.get_collection(key)
    if not items:
        return
    try:
        entry = meta_graph_def.collection_def[key]
        convert = ops.get_to_proto_function(key)
        expected_type = ops.get_collection_proto_type(key)
        if convert:
            serialized = entry.bytes_list.value
            for item in items:
                message = convert(item)
                # Guard against a to-proto function producing the wrong type.
                if not isinstance(message, expected_type):
                    raise TypeError("proto %s is not type %s" %
                                    (message, expected_type))
                serialized.append(message.SerializeToString())
        else:
            field_kind = _get_kind_name(items[0])
            target = getattr(entry, field_kind).value
            if field_kind == "node_list":
                target.extend([item.name for item in items])
            elif field_kind == "bytes_list":
                # Python 3 separates bytes from str, so coerce to bytes.
                target.extend([compat.as_bytes(item) for item in items])
            else:
                target.extend(list(items))
    except Exception as err:  # pylint: disable=broad-except
        logging.warning(
            "Error encountered when serializing %s.\n"
            "Type is unsupported, or the types of the items don't "
            "match field type in CollectionDef.\n%s", key, str(err))
        if key in meta_graph_def.collection_def:
            del meta_graph_def.collection_def[key]
예제 #5
0
def _import_meta_graph_def(meta_graph_def):
  """Recreates a Graph saved in a `MetaGraphDef` proto.

  This function adds all the nodes from the meta graph def proto to the current
  graph, recreates all the collections, and returns a saver from saver_def.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer.

  Returns:
    A saver constructed from `saver_def` in `meta_graph_def`, or a default
    `Saver()` when `meta_graph_def` has no `saver_def` field.
  """
  # Gathers the list of nodes we are interested in.
  importer.import_graph_def(meta_graph_def.graph_def, name="")

  # Restores all the other collections.
  for key, col_def in meta_graph_def.collection_def.items():
    kind = col_def.WhichOneof("kind")
    if kind is None:
      # Pass `key` as a logging argument instead of pre-formatting with `%`.
      logging.error("Cannot identify data type for collection %s. Skipping.",
                    key)
      continue
    from_proto = ops.get_from_proto_function(key)
    if from_proto:
      # Collections with a registered from-proto function are expected to be
      # stored as serialized protos.
      assert kind == "bytes_list"
      proto_type = ops.get_collection_proto_type(key)
      for value in col_def.bytes_list.value:
        proto = proto_type()
        proto.ParseFromString(value)
        ops.add_to_collection(key, from_proto(proto))
    else:
      field = getattr(col_def, kind)
      if kind == "node_list":
        for value in field.value:
          col_op = ops.get_default_graph().as_graph_element(value)
          ops.add_to_collection(key, col_op)
      elif kind == "int64_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python2 distinguishes between int and long, while Python3 has
        # only int.
        for value in field.value:
          ops.add_to_collection(key, int(value))
      else:
        for value in field.value:
          ops.add_to_collection(key, value)

  if meta_graph_def.HasField("saver_def"):
    return Saver(saver_def=meta_graph_def.saver_def)
  else:
    return Saver()
예제 #6
0
def _get_all_protos_from_collection(meta_graph_def, collection_key):
    """Parses every serialized proto stored in a `bytes_list` collection.

    Args:
      meta_graph_def: `MetaGraphDef` whose `collection_def` is read.
      collection_key: Key of the collection to read.

    Returns:
      A list with one parsed proto (of the type registered for
      `collection_key` via `_ops.get_collection_proto_type`) per stored value.
      Empty if the collection is absent or holds no values.

    Raises:
      ValueError: If the collection is present but its kind is not
        `bytes_list`.
    """
    if collection_key not in meta_graph_def.collection_def:
        return []
    collection = meta_graph_def.collection_def[collection_key]
    # Check the populated oneof directly; an empty value list is not evidence
    # of a wrong type (an empty bytes_list is simply empty).
    if collection.WhichOneof('kind') != 'bytes_list':
        raise ValueError(
            'Collection {} is present but type is not bytes_list.'.format(
                collection_key))
    values = collection.bytes_list.value
    if not values:
        return []
    proto_type = _ops.get_collection_proto_type(collection_key)
    result = []
    for value in values:
        proto = proto_type()
        proto.ParseFromString(value)
        result.append(proto)
    return result
예제 #7
0
def _get_all_protos_from_collection(meta_graph_def, collection_key):
  """Parses every serialized proto stored in a `bytes_list` collection.

  Args:
    meta_graph_def: `MetaGraphDef` whose `collection_def` is read.
    collection_key: Key of the collection to read.

  Returns:
    A list with one parsed proto (of the type registered for `collection_key`
    via `_ops.get_collection_proto_type`) per stored value. Empty if the
    collection is absent or holds no values.

  Raises:
    ValueError: If the collection is present but its kind is not
      `bytes_list`.
  """
  if collection_key not in meta_graph_def.collection_def:
    return []
  collection = meta_graph_def.collection_def[collection_key]
  # Check the populated oneof directly; an empty value list is not evidence
  # of a wrong type (an empty bytes_list is simply empty).
  if collection.WhichOneof('kind') != 'bytes_list':
    raise ValueError(
        'Collection {} is present but type is not bytes_list.'.format(
            collection_key))
  values = collection.bytes_list.value
  if not values:
    return []
  proto_type = _ops.get_collection_proto_type(collection_key)
  result = []
  for value in values:
    proto = proto_type()
    proto.ParseFromString(value)
    result.append(proto)
  return result
예제 #8
0
def add_collection_def(meta_graph_def, key):
    """Write the contents of graph collection `key` into `meta_graph_def`.

    Storage rules for `meta_graph_def.collection_def[key]`:

    * a registered to-proto function for `key` means every member is
      converted, type-checked against the registered proto class, and
      appended serialized to `bytes_list`;
    * otherwise `_get_kind_name` on the first member picks the field:
      `node_list` receives member names, `bytes_list` receives
      `compat.as_bytes(member)`, anything else receives the members verbatim.

    Keys that are not strings/bytes are skipped with a warning; empty
    collections are skipped silently. On any serialization error a warning is
    logged and the entry for `key` is removed again.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
    """
    key_is_text = isinstance(key, six.string_types) or isinstance(key, bytes)
    if not key_is_text:
        logging.warning("Only collections with string type keys will be "
                        "serialized. This key has %s", type(key))
        return
    members = ops.get_collection(key)
    if not members:
        return
    try:
        collection = meta_graph_def.collection_def[key]
        to_proto_fn = ops.get_to_proto_function(key)
        proto_cls = ops.get_collection_proto_type(key)
        if not to_proto_fn:
            field_name = _get_kind_name(members[0])
            values = getattr(collection, field_name).value
            if field_name == "node_list":
                converted = [member.name for member in members]
            elif field_name == "bytes_list":
                # Python 3 keeps bytes and str apart; store bytes.
                converted = [compat.as_bytes(member) for member in members]
            else:
                converted = list(members)
            values.extend(converted)
        else:
            for member in members:
                proto = to_proto_fn(member)
                # Reject protos of an unexpected type before serializing.
                if not isinstance(proto, proto_cls):
                    raise TypeError("proto %s is not type %s" % (proto, proto_cls))
                collection.bytes_list.value.append(proto.SerializeToString())
    except Exception as exc:  # pylint: disable=broad-except
        logging.warning(
            "Error encountered when serializing %s.\n"
            "Type is unsupported, or the types of the items don't "
            "match field type in CollectionDef.\n%s",
            key,
            str(exc),
        )
        if key in meta_graph_def.collection_def:
            del meta_graph_def.collection_def[key]
예제 #9
0
 def _restore_collections(dest_graph, src_meta_graph_def,
                          collection_keys):
     """Restores collections that we need to keep.

     Copies each collection listed in `collection_keys` from
     `src_meta_graph_def.collection_def` into `dest_graph`, using an empty
     import scope. Entries that cannot be reconstructed or resolved in
     `dest_graph` (for example nodes removed by graph optimization) are
     skipped instead of aborting the restore.

     Args:
       dest_graph: `Graph` whose collections are extended.
       src_meta_graph_def: `MetaGraphDef` providing the collection contents.
       collection_keys: Iterable of collection keys to restore.
     """
     scope = ""
     for key in collection_keys:
         collection_def = src_meta_graph_def.collection_def[key]
         kind = collection_def.WhichOneof("kind")
         if kind is None:
             tf_logging.error(
                 "Cannot identify data type for collection %s. Skipping.",
                 key)
             continue
         from_proto = ops.get_from_proto_function(key)
         if from_proto and kind == "bytes_list":
             proto_type = ops.get_collection_proto_type(key)
             # It is assumed that there are no Variables Keys in collections
             for value in collection_def.bytes_list.value:
                 proto = proto_type()
                 proto.ParseFromString(value)
                 try:
                     new_value = from_proto(proto, import_scope=scope)
                 except Exception:  # pylint: disable=broad-except
                     # Best effort: presumably the object references nodes
                     # that no longer exist after optimization. Catching
                     # Exception (not a bare `except:`) lets
                     # KeyboardInterrupt/SystemExit propagate.
                     continue
                 dest_graph.add_to_collection(key, new_value)
         else:
             field = getattr(collection_def, kind)
             if kind == "node_list":
                 for value in field.value:
                     name = ops.prepend_name_scope(value, scope)
                     # Since the graph has been optimized, the node may no longer
                     # exist.
                     try:
                         col_op = dest_graph.as_graph_element(name)
                     except (TypeError, ValueError, KeyError):
                         continue
                     dest_graph.add_to_collection(key, col_op)
             elif kind == "int64_list":
                 # NOTE(opensource): This force conversion is to work around the
                 # fact that Python2 distinguishes between int and long, while
                 # Python3 has only int.
                 for value in field.value:
                     dest_graph.add_to_collection(key, int(value))
             else:
                 for value in field.value:
                     dest_graph.add_to_collection(
                         key, ops.prepend_name_scope(value, scope))
예제 #10
0
 def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
   """Restores collections that we need to keep.

   Copies each collection listed in `collection_keys` from
   `src_meta_graph_def.collection_def` into `dest_graph`, using an empty
   import scope. Entries that cannot be reconstructed or resolved in
   `dest_graph` (for example nodes removed by graph optimization) are skipped
   instead of aborting the restore.

   Args:
     dest_graph: `Graph` whose collections are extended.
     src_meta_graph_def: `MetaGraphDef` providing the collection contents.
     collection_keys: Iterable of collection keys to restore.
   """
   scope = ""
   for key in collection_keys:
     collection_def = src_meta_graph_def.collection_def[key]
     kind = collection_def.WhichOneof("kind")
     if kind is None:
       tf_logging.error(
           "Cannot identify data type for collection %s. Skipping.", key)
       continue
     from_proto = ops.get_from_proto_function(key)
     if from_proto and kind == "bytes_list":
       proto_type = ops.get_collection_proto_type(key)
       # It is assumed that there are no Variables Keys in collections
       for value in collection_def.bytes_list.value:
         proto = proto_type()
         proto.ParseFromString(value)
         try:
           new_value = from_proto(proto, import_scope=scope)
         except Exception:  # pylint: disable=broad-except
           # Best effort: presumably the object references nodes that no
           # longer exist after optimization. Catching Exception (not a bare
           # `except:`) lets KeyboardInterrupt/SystemExit propagate.
           continue
         dest_graph.add_to_collection(key, new_value)
     else:
       field = getattr(collection_def, kind)
       if kind == "node_list":
         for value in field.value:
           name = ops.prepend_name_scope(value, scope)
           # Since the graph has been optimized, the node may no longer
           # exist.
           try:
             col_op = dest_graph.as_graph_element(name)
           except (TypeError, ValueError, KeyError):
             continue
           dest_graph.add_to_collection(key, col_op)
       elif kind == "int64_list":
         # NOTE(opensource): This force conversion is to work around the
         # fact that Python2 distinguishes between int and long, while
         # Python3 has only int.
         for value in field.value:
           dest_graph.add_to_collection(key, int(value))
       else:
         for value in field.value:
           dest_graph.add_to_collection(key,
                                        ops.prepend_name_scope(value, scope))
예제 #11
0
    def _tf_meta_graph_def_to_lgf_meta_graph_info(self, meta_graph_def):
        """Builds an `lgf_pb2.MetaGraphInfo` from a TF `MetaGraphDef`.

        A copy of `meta_graph_def` with all graph nodes removed (the "partial"
        meta graph def) is serialized into
        `original_graph_info[PARTIAL_META_GRAPH_DEF]`. Strings found in that
        partial proto, including inside serialized collection protos, are
        matched against node names in `self._graph_def`; matches are added to
        `self._required_nodes` and recorded as required nodes.

        Args:
          meta_graph_def: The TF `MetaGraphDef` to convert.

        Returns:
          The populated `lgf_pb2.MetaGraphInfo`.
        """
        # Store the serialized partial meta graph def: everything except the
        # graph's nodes.
        partial_meta_graph_def = tf.MetaGraphDef()
        partial_meta_graph_def.saver_def.CopyFrom(meta_graph_def.saver_def)
        for k, v in meta_graph_def.collection_def.items():
            partial_meta_graph_def.collection_def[k].CopyFrom(v)
        for k, v in meta_graph_def.signature_def.items():
            partial_meta_graph_def.signature_def[k].CopyFrom(v)

        partial_meta_graph_def.graph_def.CopyFrom(meta_graph_def.graph_def)
        del partial_meta_graph_def.graph_def.node[:]

        meta_graph_info = lgf_pb2.MetaGraphInfo()
        partial_key = ImportTFSavedModelBase.PARTIAL_META_GRAPH_DEF
        meta_graph_info.original_graph_info[partial_key].v = (
            partial_meta_graph_def.SerializeToString())

        # Get all strings from the partial meta graph def
        proto_strings = self.get_strings_from_proto(partial_meta_graph_def)

        # Need to manually deserialize the collection defs
        for k, collection_def in partial_meta_graph_def.collection_def.items():
            if not collection_def.HasField("bytes_list"):
                continue
            proto_type = tf_ops.get_collection_proto_type(k)
            if proto_type is None:
                # No proto type is registered for this key (e.g. a plain
                # string collection stored as raw bytes), so there is nothing
                # to deserialize; calling `None()` would raise TypeError.
                continue
            for serialized_proto in collection_def.bytes_list.value:
                proto = proto_type()
                proto.ParseFromString(serialized_proto)
                proto_strings.extend(self.get_strings_from_proto(proto))

        # Get all the node names from proto_strings
        node_names = {n.name for n in self._graph_def.node}
        for string in proto_strings:
            if string != "":
                name = self.get_node_name_and_output_index(string)[0]
                if name in node_names:
                    self._required_nodes.add(name)

        # Add required nodes to meta_graph_info. Sorted so the serialized
        # output does not depend on set iteration order.
        meta_graph_info.required_nodes.extend(sorted(self._required_nodes))

        return meta_graph_info
예제 #12
0
파일: saver.py 프로젝트: hdzz/tensorflow
def _add_collection_def(meta_graph_def, key):
  """Adds a collection to MetaGraphDef protocol buffer.

  Reads collection `key` via `ops.get_collection` and writes it into
  `meta_graph_def.collection_def[key]`. Non-string keys are skipped with a
  warning and empty collections are skipped silently. If serialization fails,
  a warning is logged and any partially written entry for `key` is removed.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
  """
  # `unicode` does not exist in Python 3; `str` covers text keys there.
  if not isinstance(key, (str, bytes)):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return
  collection_list = ops.get_collection(key)
  if not collection_list:
    return
  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        proto = to_proto(x)
        # Additional type check to make sure the returned proto is indeed
        # what we expect. Raised explicitly (not asserted) so it survives
        # `python -O`; handled by the except below.
        if not isinstance(proto, proto_type):
          raise TypeError("proto %s is not type %s" % (proto, proto_type))
        getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        getattr(col_def, kind).value.extend([x.name for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(collection_list)
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s", str(e))
    # Drop the half-written entry created by `collection_def[key]` above.
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
예제 #13
0
def import_scoped_meta_graph_with_return_elements(
        meta_graph_or_file,
        clear_devices=False,
        graph=None,
        import_scope=None,
        input_map=None,
        unbound_inputs_col_name="unbound_inputs",
        restore_collections_predicate=(lambda key: True),
        return_elements=None):
    """Imports graph from `MetaGraphDef` and returns vars and return elements.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e. whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.
    return_elements: A list of strings containing operation names in the
      `MetaGraphDef` that will be returned as `Operation` objects; and/or
      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.

  Returns:
    A tuple of (
      dictionary of all the `Variables` imported into the name scope,
      list of `Operation` or `Tensor` objects from the `return_elements` list).

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs that are not exactly matched by the keys of `input_map`.

  """
    if context.executing_eagerly():
        raise ValueError(
            "Exporting/importing meta graphs is not supported when "
            "eager execution is enabled.")
    if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
        meta_graph_def = meta_graph_or_file
    else:
        meta_graph_def = read_meta_graph_file(meta_graph_or_file)

    if unbound_inputs_col_name:
        for key, col_def in meta_graph_def.collection_def.items():
            if key == unbound_inputs_col_name:
                kind = col_def.WhichOneof("kind")
                # A CollectionDef with no populated oneof holds no values;
                # `getattr(col_def, None)` would raise TypeError.
                if kind is not None:
                    # Decode once: stored values may be bytes while the keys
                    # of `input_map` are `str`, so both the comparison and
                    # the error message must use the decoded names.
                    unbound_inputs = [
                        compat.as_str(v) for v in getattr(col_def, kind).value
                    ]
                    if unbound_inputs and (
                            not input_map or
                            sorted(unbound_inputs) != sorted(input_map)):
                        raise ValueError(
                            "Graph contains unbound inputs: %s. Must "
                            "provide these inputs through input_map." %
                            ",".join(name for name in unbound_inputs
                                     if not input_map or
                                     name not in input_map))
                break

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    # Gathers the list of nodes we are interested in.
    with graph.as_default():
        producer_op_list = None
        if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
            producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
        input_graph_def = meta_graph_def.graph_def
        # Remove all the explicit device specifications for this node. This helps to
        # make the graph more portable.
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        scope_to_prepend_to_names = graph.unique_name(import_scope or "",
                                                      mark_as_used=False)

        imported_return_elements = importer.import_graph_def(
            input_graph_def,
            name=(import_scope or scope_to_prepend_to_names),
            input_map=input_map,
            producer_op_list=producer_op_list,
            return_elements=return_elements)

        # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
        # without a VariableDef.trainable field set.
        tf_version = meta_graph_def.meta_info_def.tensorflow_version
        if not tf_version:
            variables_have_trainable = True
        else:
            variables_have_trainable = (packaging_version.parse(tf_version) >=
                                        packaging_version.parse("1.9"))

        # Sort collections so we see TRAINABLE_VARIABLES first and can default these
        # variables to trainable if the value is not set in their VariableDef.
        sorted_collections = []
        if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
            sorted_collections.append((ops.GraphKeys.TRAINABLE_VARIABLES,
                                       meta_graph_def.collection_def[
                                           ops.GraphKeys.TRAINABLE_VARIABLES]))
        for key, value in sorted(meta_graph_def.collection_def.items()):
            if key != ops.GraphKeys.TRAINABLE_VARIABLES:
                sorted_collections.append((key, value))

        # Restores all the other collections.
        variable_objects = {}
        for key, col_def in sorted_collections:
            # Don't add unbound_inputs to the new graph.
            if key == unbound_inputs_col_name:
                continue
            if not restore_collections_predicate(key):
                continue

            kind = col_def.WhichOneof("kind")
            if kind is None:
                logging.error(
                    "Cannot identify data type for collection %s. Skipping.",
                    key)
                continue
            from_proto = ops.get_from_proto_function(key)

            # Temporary change to allow the TFMA evaluator to read metric variables
            # saved as a bytes list.
            # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
            if key == ops.GraphKeys.METRIC_VARIABLES:
                # Metric variables will use the same proto functions as GLOBAL_VARIABLES
                from_proto = ops.get_from_proto_function(
                    ops.GraphKeys.GLOBAL_VARIABLES)
            if from_proto and kind == "bytes_list":
                proto_type = ops.get_collection_proto_type(key)
                if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
                    for value in col_def.bytes_list.value:
                        # The same serialized variable may appear in several
                        # variable collections; reuse the object created first.
                        variable = variable_objects.get(value, None)
                        if variable is None:
                            proto = proto_type()
                            proto.ParseFromString(value)
                            if not variables_have_trainable:
                                # If the VariableDef proto does not contain a "trainable"
                                # property because it was exported before that property was
                                # added, we default it to whether the variable is in the
                                # TRAINABLE_VARIABLES collection. We've sorted
                                # TRAINABLE_VARIABLES to be first, so trainable variables will
                                # be created from that collection.
                                proto.trainable = (
                                    key == ops.GraphKeys.TRAINABLE_VARIABLES)
                            variable = from_proto(
                                proto, import_scope=scope_to_prepend_to_names)
                            variable_objects[value] = variable
                        graph.add_to_collection(key, variable)
                else:
                    for value in col_def.bytes_list.value:
                        proto = proto_type()
                        proto.ParseFromString(value)
                        graph.add_to_collection(
                            key,
                            from_proto(proto,
                                       import_scope=scope_to_prepend_to_names))
            else:
                field = getattr(col_def, kind)
                if key in _COMPAT_COLLECTION_LIST:
                    logging.warning(
                        "The saved meta_graph is possibly from an older release:\n"
                        "'%s' collection should be of type 'byte_list', but instead "
                        "is of type '%s'.", key, kind)
                if kind == "node_list":
                    for value in field.value:
                        col_op = graph.as_graph_element(
                            ops.prepend_name_scope(value,
                                                   scope_to_prepend_to_names))
                        graph.add_to_collection(key, col_op)
                elif kind == "int64_list":
                    # NOTE(opensource): This force conversion is to work around the fact
                    # that Python2 distinguishes between int and long, while Python3 has
                    # only int.
                    for value in field.value:
                        graph.add_to_collection(key, int(value))
                else:
                    for value in field.value:
                        graph.add_to_collection(
                            key,
                            ops.prepend_name_scope(value,
                                                   scope_to_prepend_to_names))

        var_list = {}
        variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                         scope=scope_to_prepend_to_names)
        for v in variables:
            var_list[ops.strip_name_scope(v.name,
                                          scope_to_prepend_to_names)] = v

    return var_list, imported_return_elements
예제 #14
0
def add_collection_def(meta_graph_def,
                       key,
                       graph=None,
                       export_scope=None,
                       exclude_nodes=None,
                       override_contents=None):
    """Adds a collection to MetaGraphDef protocol buffer.

    Serialization is best-effort: if the collection's items cannot be encoded
    into a `CollectionDef`, a warning is logged and any entry for `key` is
    removed from `meta_graph_def.collection_def`.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string. Keys that are neither
        `str` nor `bytes` are skipped with a warning.
      graph: The `Graph` from which to get collections. Defaults to the default
        graph.
      export_scope: Optional `string`. Name scope to remove.
      exclude_nodes: An iterable of nodes or `string` node names to omit from the
        collection, or None.
      override_contents: An iterable of values to place in the collection,
        ignoring the current values (if set). A falsy value (None or empty)
        falls back to the graph's current collection.

    Raises:
      TypeError: If `graph` is given and is not a `Graph`.
    """
    if graph and not isinstance(graph, ops.Graph):
        raise TypeError(
            f"graph must be of type Graph. Received type: {type(graph)}.")

    if not isinstance(key, (str, bytes)):
        logging.warning(
            "Only collections with string type keys will be "
            "serialized. This key has %s", type(key))
        return

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    if override_contents:
        collection_list = override_contents
    else:
        collection_list = graph.get_collection(key)

    # Remove nodes that should not be exported from the collection list.
    collection_list = [
        x for x in collection_list
        if _should_include_node(x, export_scope, exclude_nodes)
    ]
    if not collection_list:
        return

    try:
        # Accessing the map entry may create it, which is why the except
        # clause below removes `key` again on failure.
        col_def = meta_graph_def.collection_def[key]
        to_proto = ops.get_to_proto_function(key)
        proto_type = ops.get_collection_proto_type(key)
        if to_proto:
            kind = "bytes_list"
            for x in collection_list:
                proto = to_proto(x, export_scope=export_scope)
                if proto:
                    # Additional type check to make sure the returned proto is
                    # indeed what we expect. This is an explicit raise rather
                    # than `assert` so the check survives `python -O`. The
                    # handler below turns it into a warning.
                    if not isinstance(proto, proto_type):
                        raise TypeError(
                            f"to_proto for collection {key!r} returned "
                            f"{type(proto)}, expected {proto_type}.")
                    getattr(col_def,
                            kind).value.append(proto.SerializeToString())
        else:
            kind = _get_kind_name(collection_list[0])
            if kind == "node_list":
                for x in collection_list:
                    if not export_scope or x.name.startswith(export_scope):
                        getattr(col_def, kind).value.append(
                            ops.strip_name_scope(x.name, export_scope))
            elif kind == "bytes_list":
                # NOTE(opensource): This force conversion is to work around the fact
                # that Python3 distinguishes between bytes and strings.
                getattr(col_def, kind).value.extend(
                    [compat.as_bytes(x) for x in collection_list])
            else:
                getattr(col_def, kind).value.extend(collection_list)
    except Exception as e:  # pylint: disable=broad-except
        # Deliberately best-effort: an unserializable collection must not abort
        # exporting the rest of the meta graph.
        logging.warning(
            "Issue encountered when serializing %s.\n"
            "Type is unsupported, or the types of the items don't "
            "match field type in CollectionDef. Note this is a warning "
            "and probably safe to ignore.\n%s", key, str(e))
        if key in meta_graph_def.collection_def:
            del meta_graph_def.collection_def[key]
        return
예제 #15
0
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e. whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope, keyed by
    variable name with the import scope stripped.

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs whose names do not exactly match the keys of `input_map`.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # A CollectionDef with no kind set declares no unbound inputs. Without
        # this guard, `getattr(col_def, None)` would raise a TypeError.
        if kind is None:
          break
        field = getattr(col_def, kind)
        # Normalize once. Stored names may not be `str` (hence `as_str`), while
        # `input_map` keys are strings. Both the comparison and the error
        # filter must use the normalized names, or the filter never matches.
        unbound_names = [compat.as_str(v) for v in field.value]
        if unbound_names and (
            not input_map or sorted(unbound_names) != sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([name for name in unbound_names
                                     if not input_map or name not in input_map]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    # Serialized variable protos that are already reconstructed, so a variable
    # listed in several variable collections maps to one Python object.
    variable_objects = {}
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
예제 #16
0
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e. whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope, keyed by
    variable name with the import scope stripped.

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs whose names do not exactly match the keys of `input_map`.
  """
  if context.in_eager_mode():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # A CollectionDef with no kind set declares no unbound inputs. Without
        # this guard, `getattr(col_def, None)` would raise a TypeError.
        if kind is None:
          break
        field = getattr(col_def, kind)
        # Normalize once. Stored names may not be `str` (hence `as_str`), while
        # `input_map` keys are strings. Both the comparison and the error
        # filter must use the normalized names, or the filter never matches.
        unbound_names = [compat.as_str(v) for v in field.value]
        if unbound_names and (
            not input_map or sorted(unbound_names) != sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([name for name in unbound_names
                                     if not input_map or name not in input_map]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
예제 #17
0
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs"):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates all the collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.

  Returns:
    A dictionary of all the `Variables` imported into the name scope, keyed by
    variable name with the import scope stripped.

  Raises:
    ValueError: If the graph_def contains unbound inputs whose names do not
      exactly match the keys of `input_map`.
  """
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # A CollectionDef with no kind set declares no unbound inputs. Without
        # this guard, `getattr(col_def, None)` would raise a TypeError.
        if kind is None:
          break
        field = getattr(col_def, kind)
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([compat.as_str(v) for v in field.value]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""
    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in meta_graph_def.collection_def.items():
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      # Only decode through `from_proto` when the payload really is serialized
      # protos. Other kinds fall through to the generic restore below. This
      # replaces an `assert`, which `python -O` strips, so a mismatched
      # collection would otherwise be restored empty.
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=import_scope))
      else:
        field = getattr(col_def, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, import_scope))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, import_scope))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.VARIABLES,
                                     scope=import_scope)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, import_scope)] = v

  return var_list