def create_apply_graph(self, signature, input_tensors, name):
        """See `ModuleImpl.create_apply_graph`."""
        signature_def = self._meta_graph.signature_def.get(signature)

        # Build an input map to feed when importing the apply-graph by augmenting
        # the state_map with the input args. This allows an input to override a
        # tensor from the state-graph.
        feed_map = dict(self._state_map)
        feed_map.update(
            tensor_info.build_input_map(signature_def.inputs, input_tensors))

        # Make state tensors enter the current context. This way the Module can be
        # applied inside a control flow structure such as a while_loop.
        control_flow = self._graph._get_control_flow_context()  # pylint: disable=protected-access
        if control_flow:
            for key, value in sorted(feed_map.items()):
                feed_map[key] = control_flow.AddValue(value)

        # Don't mark the name as used at this point - import_scoped_meta_graph will
        # start using it.
        absolute_scope_name = self._graph.unique_name(name, mark_as_used=False)
        relative_scope_name = absolute_scope_name.split("/")[-1]

        import_collections = [
            # In most cases ASSET_FILEPATHS are only used for the TABLE_INITIALIZERS
            # ops, however a graph could use an asset at any other time. So every
            # time we import a tensor that carries an asset filename, we must
            # annotate it as such, so that later re-exports keep that semantic
            # information and can handle it.
            tf.GraphKeys.ASSET_FILEPATHS,
            tf.GraphKeys.COND_CONTEXT,
            tf.GraphKeys.WHILE_CONTEXT,
        ]
        if self._trainable:
            import_collections.extend([tf.GraphKeys.UPDATE_OPS])

        meta_graph = meta_graph_pb2.MetaGraphDef()
        meta_graph.CopyFrom(self._meta_graph)

        meta_graph_lib.filter_collections(meta_graph, import_collections)
        meta_graph_lib.prefix_shared_name_attributes(meta_graph,
                                                     absolute_scope_name)

        tf.train.import_meta_graph(meta_graph,
                                   input_map=feed_map,
                                   import_scope=relative_scope_name)
        fix_colocation_after_import(input_map=feed_map,
                                    absolute_import_scope=absolute_scope_name)

        def get_tensor(name):
            # When an output is also an input, no node for it is created within
            # the apply scope, so it must be looked up in the input map.
            try:
                return feed_map[name]
            except KeyError:
                return self._graph.get_tensor_by_name(
                    meta_graph_lib.prepend_name_scope(
                        name, import_scope=absolute_scope_name))

        return tensor_info.build_output_map(signature_def.outputs, get_tensor)
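The control-flow handling above (routing every feed_map value through control_flow.AddValue) is what allows the imported apply-graph to live inside a control flow structure such as a while_loop. A minimal usage sketch, assuming the public TF1 hub.Module API; the module handle below is an illustrative placeholder, not taken from this code:

# Sketch only: applying a hub.Module inside tf.while_loop, which exercises the
# AddValue handling above. The handle URL is a placeholder/assumption.
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

tf.disable_eager_execution()

embed = hub.Module("https://tfhub.dev/google/nnlm-en-dim50/1")  # illustrative handle

def body(i, acc):
  # Each call imports a fresh apply-graph under the loop's control-flow context.
  vectors = embed(tf.constant(["hello world"]))
  return i + 1, acc + tf.reduce_sum(vectors)

_, total = tf.while_loop(lambda i, acc: i < 3, body, [0, 0.0])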
Example #2
  def create_apply_graph(self, signature, input_tensors, name):
    """See `ModuleImpl.create_apply_graph`."""
    signature_def = self._meta_graph.signature_def.get(signature)
    meta_graph = meta_graph_pb2.MetaGraphDef()
    meta_graph.CopyFrom(self._meta_graph)
    apply_graph = tf_v1.get_default_graph()
    infeed_map = tensor_info.build_input_map(signature_def.inputs,
                                             input_tensors)

    # Build an input map to feed when importing the apply-graph by augmenting
    # the state_map with the input args. This allows an input to override a
    # tensor from the state-graph.
    feed_map = dict(self._state_map)
    # If we are applying the module in a function with a TPUReplicateContext, we
    # must capture the state tensors when building our feed_map and prune out
    # assign ops. Function graph semantics are different in that all ops are
    # executed regardless of dependency.
    # TODO(b/112575006): The following adds functionality of function call
    # within a TPU context. Work to generalize this for all function calls is
    # ongoing.
    if _is_tpu_graph_function():
      for k, v in self._state_map.items():
        feed_map[k] = apply_graph.capture(v)
      meta_graph_lib.prune_unused_nodes(meta_graph, signature_def)
      # After we prune the metagraph def, we might need to prune away
      # infeeds which no longer exist.
      meta_graph_lib.prune_feed_map(meta_graph, infeed_map)
    elif apply_graph.building_function:
      # Log a warning if a user is using a hub module in a function graph.
      # This is only expected to work if the function graph is pruned and
      # not all nodes are executed.
      #
      # E.g. it could work with "tf.compat.v1.wrap_function", but it will not
      # work with defun, Dataset.map_fn, etc.
      logging.warning("Using `hub.Module` while building a function: %s. This "
                      "can lead to errors if the function is not pruned.",
                      apply_graph.name)

    # As state ops in the apply graph are unused, replace them with Placeholders
    # so that in a hierarchical instantiation, apply_graph state ops are
    # ignored.
    replace_apply_state(
        meta_graph,
        list_registered_stateful_ops_without_inputs(meta_graph.graph_def),
        feed_map)
    feed_map.update(infeed_map)

    # Make state tensors enter the current context. This way the Module can be
    # applied inside a control flow structure such as a while_loop.
    control_flow = apply_graph._get_control_flow_context()  # pylint: disable=protected-access
    if control_flow:
      for key, value in sorted(feed_map.items()):
        feed_map[key] = control_flow.AddValue(value)

    # Don't mark the name as used at this point - import_scoped_meta_graph will
    # start using it.
    absolute_scope_name = apply_graph.unique_name(name, mark_as_used=False)
    relative_scope_name = absolute_scope_name.split("/")[-1]

    import_collections = [
        # In most cases ASSET_FILEPATHS are only used for the TABLE_INITIALIZERS
        # ops, however a graph could use an asset at any other time. So every
        # time we import a tensor that carries an asset filename, we must
        # annotate it as such, so that later re-exports keep that semantic
        # information and can handle it.
        tf_v1.GraphKeys.ASSET_FILEPATHS,
        tf_v1.GraphKeys.COND_CONTEXT,
        tf_v1.GraphKeys.WHILE_CONTEXT,
    ]
    if self._trainable:
      import_collections.extend([tf_v1.GraphKeys.UPDATE_OPS])

    meta_graph_lib.filter_collections(meta_graph, import_collections)
    meta_graph_lib.prefix_shared_name_attributes(meta_graph,
                                                 absolute_scope_name)
    if len(meta_graph.collection_def) and _is_tpu_graph_function():
      raise NotImplementedError(
          "Applying modules with collections inside TPU functions is not "
          "supported. Collections found: %s" % str(meta_graph.collection_def))

    tf_v1.train.import_meta_graph(
        meta_graph,
        input_map=feed_map,
        import_scope=relative_scope_name)
    fix_colocation_after_import(input_map=feed_map,
                                absolute_import_scope=absolute_scope_name)

    def get_tensor(name):
      # When an output is also an input, no node for it is created within
      # the apply scope, so it must be looked up in the input map.
      try:
        return feed_map[name]
      except KeyError:
        return apply_graph.get_tensor_by_name(
            meta_graph_lib.prepend_name_scope(
                name, import_scope=absolute_scope_name))

    return tensor_info.build_output_map(signature_def.outputs, get_tensor)
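The warning branch above points at the one function-graph case that is expected to work: a pruned graph such as tf.compat.v1.wrap_function, where unreached state ops are never executed. A minimal sketch of that case, assuming the public TF1 hub.Module API; the module handle is again an illustrative placeholder:

# Sketch only: applying a module under wrap_function, the pruned-function case
# named in the warning above. The handle URL is a placeholder/assumption.
import tensorflow.compat.v1 as tf_v1
import tensorflow_hub as hub

def model_fn(sentences):
  embed = hub.Module("https://tfhub.dev/google/nnlm-en-dim50/1")  # illustrative handle
  return embed(sentences)

wrapped = tf_v1.wrap_function(
    model_fn, signature=[tf_v1.TensorSpec(shape=[None], dtype=tf_v1.string)])
embeddings = wrapped(tf_v1.constant(["hello world"]))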
Example #3
  def _create_state_graph(self, name):
    """Creates the graph nodes that hold the state of the Module.

    Args:
      name: name scope to create the state graph in.

    Returns:
      A tuple consisting of:
        variables_tensor_map: a map from tensor names in the original graph def
          to the created Variable objects.
        state_map: a map from tensor names in the original graph def to the
          instantiated tensors to be used as a state_map.
    """
    import_collections = [
        tf_v1.GraphKeys.GLOBAL_VARIABLES,
        tf_v1.GraphKeys.MODEL_VARIABLES,
        tf_v1.GraphKeys.TABLE_INITIALIZERS,
        tf_v1.GraphKeys.ASSET_FILEPATHS,  # Typically used to initialize tables.
        tf_v1.GraphKeys.COND_CONTEXT,
        tf_v1.GraphKeys.WHILE_CONTEXT,
    ]
    if self._trainable:
      # TODO(b/64049014): Import UPDATE_OPS which do not depend on inputs.
      import_collections.extend([tf_v1.GraphKeys.TRAINABLE_VARIABLES,
                                 tf_v1.GraphKeys.REGULARIZATION_LOSSES])

    absolute_scope_name = tf_v1.get_default_graph().unique_name(
        name, mark_as_used=False)
    relative_scope_name = absolute_scope_name.split("/")[-1]
    assert relative_scope_name == name  # verify name scope was indeed unused.

    meta_graph = meta_graph_pb2.MetaGraphDef()
    meta_graph.CopyFrom(self._meta_graph)

    meta_graph_lib.filter_collections(meta_graph, import_collections)
    meta_graph_lib.prefix_shared_name_attributes(meta_graph,
                                                 absolute_scope_name)

    tf_v1.train.import_meta_graph(
        meta_graph,
        input_map={},
        import_scope=relative_scope_name)

    # Build a map from the variable names in the module definition to the actual
    # instantiated variables.
    variables_tensor_map = {}
    for var in tf_v1.global_variables():
      if var.op.name.startswith(absolute_scope_name + "/"):
        variables_tensor_map[var.name[len(absolute_scope_name)+1:]] = var

    # Build a map of tensors to feed from the state-graph into subsequent
    # apply-graphs.
    def _get_tensor(tensor_name):
      return tf_v1.get_default_graph().get_tensor_by_name(
          meta_graph_lib.prepend_name_scope(
              tensor_name, import_scope=absolute_scope_name))

    state_op_names = list_registered_stateful_ops_without_inputs(
        meta_graph.graph_def)
    state_map = get_state_map(meta_graph, state_op_names, set(), _get_tensor)

    return variables_tensor_map, state_map
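The variables_tensor_map returned here is what the Module exposes to callers (via variable_map in the public TF1 hub.Module API), while state_map is what every later create_apply_graph call feeds through input_map. A minimal sketch of reading that map, assuming the public hub.Module API; the module handle is an illustrative placeholder:

# Sketch only: inspecting the variable map backed by _create_state_graph.
# The handle URL is a placeholder/assumption.
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

tf.disable_eager_execution()

module = hub.Module("https://tfhub.dev/google/nnlm-en-dim50/1",  # illustrative handle
                    trainable=True)
# Keys are variable names from the module's graph def; values are the Variable
# objects instantiated inside the module's state scope.
for name, variable in sorted(module.variable_map.items()):
  print(name, variable.shape)

# Each application below imports a fresh apply-graph whose state ops are fed
# from the single state graph created above.
outputs_a = module(tf.constant(["one sentence"]))
outputs_b = module(tf.constant(["another sentence"]))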