Example #1
0
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Keys that are not strings/bytes and empty collections are skipped. Any
  exception raised while serializing is logged as a warning and the entry
  created for `key` is removed from `meta_graph_def.collection_def` again.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections. The default graph is
      used when this is not passed in.
    export_scope: Optional `string`. Name scope to remove.

  Raises:
    TypeError: If `graph` is passed in but is not an `ops.Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    # Interpolate the type into the message. Passing it as a second argument
    # to TypeError would leave the "%s" placeholder unformatted.
    raise TypeError("graph must be of type Graph, not %s" % (type(graph),))

  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  collection_list = graph.get_collection(key)
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        proto = to_proto(x, export_scope=export_scope)
        # Items for which `to_proto` returns a falsy value are skipped.
        if proto:
          # Additional type check to make sure the returned proto is indeed
          # what we expect. Raised explicitly (rather than asserted) so the
          # check still runs under `python -O`; the handler below logs it.
          if not isinstance(proto, proto_type):
            raise TypeError("proto %s is not type %s" % (proto, proto_type))
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      # The kind of the first item decides which CollectionDef field is used.
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          # Only names under `export_scope` are kept, with the scope stripped.
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(list(collection_list))
  except Exception as e:  # pylint: disable=broad-except
    # Best effort: a collection that cannot be serialized is dropped instead
    # of failing the whole export.
    logging.warning("Error encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s", key, str(e))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
Example #2
0
def add_collection_def(meta_graph_def, key):
    """Serializes the collection stored under `key` into `meta_graph_def`.

    Nothing is written when `key` is neither a string nor bytes (a warning is
    logged instead) or when the collection is empty. If serialization raises,
    a warning is logged and any entry created for `key` is removed again.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
    """
    key_is_stringlike = (isinstance(key, six.string_types) or
                         isinstance(key, bytes))
    if not key_is_stringlike:
        logging.warning(
            "Only collections with string type keys will be "
            "serialized. This key has %s", type(key))
        return

    items = ops.get_collection(key)
    if not items:
        return

    try:
        entry = meta_graph_def.collection_def[key]
        serialize = ops.get_to_proto_function(key)
        expected_type = ops.get_collection_proto_type(key)
        if serialize:
            # A registered to_proto function means items are stored as
            # serialized protos in the bytes_list field.
            for item in items:
                message = serialize(item)
                # Guard against a to_proto function returning the wrong type.
                if not isinstance(message, expected_type):
                    raise TypeError("proto %s is not type %s" %
                                    (message, expected_type))
                entry.bytes_list.value.append(message.SerializeToString())
        else:
            # The first item decides which CollectionDef field is filled.
            field_name = _get_kind_name(items[0])
            values = getattr(entry, field_name).value
            if field_name == "node_list":
                values.extend([item.name for item in items])
            elif field_name == "bytes_list":
                # NOTE(opensource): This force conversion is to work around
                # the fact that Python3 distinguishes between bytes and
                # strings.
                values.extend([compat.as_bytes(item) for item in items])
            else:
                values.extend(list(items))
    except Exception as e:  # pylint: disable=broad-except
        logging.warning(
            "Error encountered when serializing %s.\n"
            "Type is unsupported, or the types of the items don't "
            "match field type in CollectionDef.\n%s", key, str(e))
        # Drop the half-written entry so no partial collection is exported.
        if key in meta_graph_def.collection_def:
            del meta_graph_def.collection_def[key]
        return
Example #3
0
def add_collection_def(meta_graph_def, key):
    """Copies the collection named `key` into a MetaGraphDef protocol buffer.

    The call is a no-op (apart from a logged warning) for keys that are not
    strings or bytes, and a silent no-op for empty collections. Failures
    during serialization are logged as a warning and the entry for `key` is
    deleted from `meta_graph_def.collection_def`.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
    """
    if not (isinstance(key, six.string_types) or isinstance(key, bytes)):
        logging.warning(
            "Only collections with string type keys will be "
            "serialized. This key has %s",
            type(key),
        )
        return

    members = ops.get_collection(key)
    if not members:
        return

    try:
        col_entry = meta_graph_def.collection_def[key]
        proto_fn = ops.get_to_proto_function(key)
        wanted_type = ops.get_collection_proto_type(key)
        if not proto_fn:
            # No to_proto function registered: store the raw values in the
            # field selected by the kind of the first member.
            kind = _get_kind_name(members[0])
            sink = getattr(col_entry, kind).value
            if kind == "node_list":
                payload = [member.name for member in members]
            elif kind == "bytes_list":
                # NOTE(opensource): This force conversion is to work around
                # the fact that Python3 distinguishes between bytes and
                # strings.
                payload = [compat.as_bytes(member) for member in members]
            else:
                payload = list(members)
            sink.extend(payload)
        else:
            for member in members:
                converted = proto_fn(member)
                # Additional type check to make sure the returned proto is
                # indeed what we expect.
                if not isinstance(converted, wanted_type):
                    raise TypeError("proto %s is not type %s" % (converted, wanted_type))
                col_entry.bytes_list.value.append(converted.SerializeToString())
    except Exception as e:  # pylint: disable=broad-except
        logging.warning(
            "Error encountered when serializing %s.\n"
            "Type is unsupported, or the types of the items don't "
            "match field type in CollectionDef.\n%s",
            key,
            str(e),
        )
        # Remove whatever was written before the failure.
        if key in meta_graph_def.collection_def:
            del meta_graph_def.collection_def[key]
        return
Example #4
0
def _add_collection_def(meta_graph_def, key):
  """Adds a collection to MetaGraphDef protocol buffer.

  Keys that are not strings and empty collections are skipped. Any exception
  raised while serializing is logged as a warning and the entry created for
  `key` is removed from `meta_graph_def.collection_def` again.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
  """
  # `unicode` only exists on Python 2; fall back to (str, bytes) on Python 3
  # instead of failing with a NameError.
  try:
    string_types = (str, bytes, unicode)  # pylint: disable=undefined-variable
  except NameError:
    string_types = (str, bytes)

  if not isinstance(key, string_types):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s" % type(key))
    return
  collection_list = ops.get_collection(key)
  if not collection_list:
    return
  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        proto = to_proto(x)
        # Additional type check to make sure the returned proto is indeed
        # what we expect. Raised explicitly (rather than asserted) so the
        # check still runs under `python -O`; the handler below logs it.
        if not isinstance(proto, proto_type):
          raise TypeError("proto %s is not type %s" % (proto, proto_type))
        getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      # The kind of the first item decides which CollectionDef field is used.
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        getattr(col_def, kind).value.extend([x.name for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(list(collection_list))
  # `except ... as` is valid on Python 2.6+ and Python 3; the older
  # `except Exception, e` form is a syntax error on Python 3.
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s" % str(e))
    # Best effort: drop the half-written entry instead of failing the export.
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
Example #5
0
def add_collection_def(meta_graph_def,
                       key,
                       graph=None,
                       export_scope=None,
                       exclude_nodes=None,
                       override_contents=None):
    """Adds a collection to MetaGraphDef protocol buffer.

    Keys that are not `str`/`bytes` and collections that are empty after
    filtering are skipped. Any exception raised while serializing is logged
    as a warning and the entry created for `key` is removed from
    `meta_graph_def.collection_def` again.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
      graph: The `Graph` from which to get collections. The default graph is
        used when this is not passed in.
      export_scope: Optional `string`. Name scope to remove.
      exclude_nodes: An iterable of nodes or `string` node names to omit from
        the collection, or None.
      override_contents: An iterable of values to place in the collection,
        ignoring the current values (if set).

    Raises:
      TypeError: If `graph` is passed in but is not an `ops.Graph`.
    """
    if graph and not isinstance(graph, ops.Graph):
        raise TypeError(
            f"graph must be of type Graph. Received type: {type(graph)}.")

    if not isinstance(key, str) and not isinstance(key, bytes):
        logging.warning(
            "Only collections with string type keys will be "
            "serialized. This key has %s", type(key))
        return

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    if override_contents:
        collection_list = override_contents
    else:
        collection_list = graph.get_collection(key)

    # Remove nodes that should not be exported from the collection list.
    collection_list = [
        x for x in collection_list
        if _should_include_node(x, export_scope, exclude_nodes)
    ]
    if not collection_list:
        return

    try:
        col_def = meta_graph_def.collection_def[key]
        to_proto = ops.get_to_proto_function(key)
        proto_type = ops.get_collection_proto_type(key)
        if to_proto:
            kind = "bytes_list"
            for x in collection_list:
                proto = to_proto(x, export_scope=export_scope)
                # Items for which `to_proto` returns a falsy value are skipped.
                if proto:
                    # Additional type check to make sure the returned proto is
                    # indeed what we expect. Raised explicitly (rather than
                    # asserted) so the check still runs under `python -O`; the
                    # handler below logs it.
                    if not isinstance(proto, proto_type):
                        raise TypeError(
                            f"proto {proto} is not type {proto_type}")
                    getattr(col_def,
                            kind).value.append(proto.SerializeToString())
        else:
            # The kind of the first item decides which CollectionDef field is
            # used.
            kind = _get_kind_name(collection_list[0])
            if kind == "node_list":
                for x in collection_list:
                    # Only names under `export_scope` are kept, with the scope
                    # stripped.
                    if not export_scope or x.name.startswith(export_scope):
                        getattr(col_def, kind).value.append(
                            ops.strip_name_scope(x.name, export_scope))
            elif kind == "bytes_list":
                # NOTE(opensource): This force conversion is to work around the fact
                # that Python3 distinguishes between bytes and strings.
                getattr(col_def, kind).value.extend(
                    [compat.as_bytes(x) for x in collection_list])
            else:
                getattr(col_def, kind).value.extend(list(collection_list))
    except Exception as e:  # pylint: disable=broad-except
        # Best effort: a collection that cannot be serialized is dropped
        # instead of failing the whole export.
        logging.warning(
            "Issue encountered when serializing %s.\n"
            "Type is unsupported, or the types of the items don't "
            "match field type in CollectionDef. Note this is a warning "
            "and probably safe to ignore.\n%s", key, str(e))
        if key in meta_graph_def.collection_def:
            del meta_graph_def.collection_def[key]
        return