Example #1
    def create_apply_graph(self, signature, input_tensors, name):
        """See `ModuleImpl.create_apply_graph`."""
        signature_def = self._meta_graph.signature_def.get(signature)

        # Build an input map to feed when importing the apply-graph by augmenting the
        # state_map with the input args. This allows an input to override a tensor
        # from the state-graph.
        feed_map = dict(self._state_map)
        feed_map.update(
            tensor_info.build_input_map(signature_def.inputs, input_tensors))

        # Make state tensors enter the current context. This way the Module can be
        # applied inside a control flow structure such as a while_loop.
        control_flow = self._graph._get_control_flow_context()  # pylint: disable=protected-access
        if control_flow:
            for key, value in sorted(feed_map.items()):
                feed_map[key] = control_flow.AddValue(value)

        # Don't mark the name as used at this point - import_scoped_meta_graph will
        # start using it.
        absolute_scope_name = self._graph.unique_name(name, mark_as_used=False)
        relative_scope_name = absolute_scope_name.split("/")[-1]

        import_collections = [
            # In most cases ASSET_FILEPATHS are only used for the TABLE_INITIALIZERS
            # ops, but a graph could use an asset at any other time. So whenever
            # we bring in a tensor that carries an asset filename, we must
            # annotate it as such, so that later re-exports keep that semantic
            # information and can handle it.
            tf.GraphKeys.ASSET_FILEPATHS,
            tf.GraphKeys.COND_CONTEXT,
            tf.GraphKeys.WHILE_CONTEXT,
        ]
        if self._trainable:
            import_collections.extend([tf.GraphKeys.UPDATE_OPS])

        meta_graph = meta_graph_pb2.MetaGraphDef()
        meta_graph.CopyFrom(self._meta_graph)

        meta_graph_lib.filter_collections(meta_graph, import_collections)
        meta_graph_lib.prefix_shared_name_attributes(meta_graph,
                                                     absolute_scope_name)

        tf.train.import_meta_graph(meta_graph,
                                   input_map=feed_map,
                                   import_scope=relative_scope_name)
        fix_colocation_after_import(input_map=feed_map,
                                    absolute_import_scope=absolute_scope_name)

        def get_tensor(name):
            # When an output is also an input, no node for it is created within
            # the apply scope, so it must be looked up in the input map.
            try:
                return feed_map[name]
            except KeyError:
                return self._graph.get_tensor_by_name(
                    meta_graph_lib.prepend_name_scope(
                        name, import_scope=absolute_scope_name))

        return tensor_info.build_output_map(signature_def.outputs, get_tensor)
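
A self-contained sketch (not part of the example above) of the core mechanism feed_map relies on: the input_map of tf.compat.v1.train.import_meta_graph lets a tensor in the outer graph override a tensor baked into the imported graph. All names here are illustrative.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Export a tiny graph: y = x + 1.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [], name="x")
    tf.add(x, 1.0, name="y")
    meta = tf.train.export_meta_graph()

# Re-import it, feeding a constant wherever the graph consumed "x:0".
with tf.Graph().as_default() as g:
    two = tf.constant(2.0)
    tf.train.import_meta_graph(meta, input_map={"x:0": two},
                               import_scope="apply")
    out = g.get_tensor_by_name("apply/y:0")
    with tf.Session() as sess:
        print(sess.run(out))  # 3.0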
Example #2
  def create_apply_graph(self, signature, inputs, name):
    """See `ModuleImpl.create_apply_graph`."""
    signature_def = self._meta_graph.signature_def.get(signature)

    input_tensors = tensor_info.convert_to_input_tensors(
        signature_def.inputs, inputs)

    # Build an input map to feed when importing the apply-graph by augmenting the
    # state_map with the input args. This allows an input to override a tensor
    # from the state-graph.
    feed_map = dict(self._state_map)
    feed_map.update(
        tensor_info.build_input_map(signature_def.inputs, input_tensors))

    # Make state tensors enter the current context. This way the Module can be
    # applied inside a control flow structure such as a while_loop.
    control_flow = self._graph._get_control_flow_context()  # pylint: disable=protected-access
    if control_flow:
      for key, value in sorted(feed_map.items()):
        feed_map[key] = control_flow.AddValue(value)

    # Don't mark the name as used at this point - import_scoped_meta_graph will
    # start using it.
    absolute_scope_name = self._graph.unique_name(name, mark_as_used=False)
    relative_scope_name = absolute_scope_name.split("/")[-1]

    import_collections = [
        # In most cases ASSET_FILEPATHS are only used for the TABLE_INITIALIZERS
        # ops, but a graph could use an asset at any other time. So whenever
        # we bring in a tensor that carries an asset filename, we must
        # annotate it as such, so that later re-exports keep that semantic
        # information and can handle it.
        tf.GraphKeys.ASSET_FILEPATHS,
        tf.GraphKeys.COND_CONTEXT,
        tf.GraphKeys.WHILE_CONTEXT,
    ]
    if self._trainable:
      import_collections.extend([tf.GraphKeys.UPDATE_OPS])

    tf.train.import_meta_graph(
        adapted_meta_graph_for_import(self._meta_graph, absolute_scope_name),
        input_map=feed_map,
        import_scope=relative_scope_name,
        restore_collections_predicate=(lambda key: key in import_collections))
    fix_colocation_after_import(input_map=feed_map,
                                absolute_import_scope=absolute_scope_name)

    def get_tensor(name):
      # When an output is also an input, no node for it is created within
      # the apply scope, so it must be looked up in the input map.
      try:
        return feed_map[name]
      except KeyError:
        return self._graph.get_tensor_by_name(
            prepend_name_scope(name, import_scope=absolute_scope_name))

    return tensor_info.build_output_map(signature_def.outputs, get_tensor)
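
The control-flow handling above (control_flow.AddValue) exists so a module can be applied inside constructs like tf.while_loop. A minimal tf.compat.v1 illustration of the situation, without hub: tensors created outside the loop must be made available inside the loop context, which TF does automatically for direct references but which the import path must do by hand for feed_map values.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

state = tf.constant(10.0)  # stands in for a module state tensor

def body(i, acc):
    # Referencing `state` inside the while context works because TF rewrites
    # the capture; AddValue is the low-level hook that does the same for
    # tensors routed through input_map during import.
    return i + 1, acc + state

_, total = tf.while_loop(lambda i, acc: i < 3, body, [0, 0.0])

with tf.Session() as sess:
    print(sess.run(total))  # 30.0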
Example #3
  def testBuildInputMap(self):
    x = tf.placeholder(tf.int32, [2])
    y = tf.sparse_placeholder(tf.string, [None])
    sig = _make_signature({"x": x, "y": y}, {})

    input_map = tensor_info.build_input_map(sig.inputs, {"x": x, "y": y})
    self.assertEqual(len(input_map), 4)
    self.assertEqual(input_map[x.name], x)
    self.assertEqual(input_map[y.indices.name], y.indices)
    self.assertEqual(input_map[y.values.name], y.values)
    self.assertEqual(input_map[y.dense_shape.name], y.dense_shape)
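
A hedged sketch of the mapping this test checks, written against the TensorInfo protos in the signature. This is a simplified re-implementation for illustration only, not hub's actual code: a dense tensor contributes one entry, and a SparseTensor contributes three (indices, values, dense_shape).

import tensorflow.compat.v1 as tf

def sketch_build_input_map(sig_inputs, inputs):
    input_map = {}
    for key, value in inputs.items():
        info = sig_inputs[key]  # a TensorInfo proto from the SignatureDef
        if info.WhichOneof("encoding") == "name":
            input_map[info.name] = value                   # dense tensor
        elif info.WhichOneof("encoding") == "coo_sparse":
            coo = info.coo_sparse                          # tf.SparseTensor
            input_map[coo.indices_tensor_name] = value.indices
            input_map[coo.values_tensor_name] = value.values
            input_map[coo.dense_shape_tensor_name] = value.dense_shape
    return input_map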
Example #4
def make_prune(input_tensors):
    # Note: this reaches into hub.Module internals (m._impl, m._spec) to
    # prune the module's MetaGraphDef in place.
    m = import_module()
    # def model_fn():
    #     z = tf.placeholder(dtype=tf.float32)
    #     x = m(z, signature='generate', as_dict=True)['default']
    #     hub.add_signature(signature, z, x)
    # tags_and_args = [({'train'}, {}), (set(), {})]
    # drop_collections = [v for v in tf.global_variables() if 'Generator' not in v.name]
    # p = hub.create_module_spec(model_fn, tags_and_args, drop_collections)
    graph = m._impl._meta_graph
    signature_def = m._spec._get_signature_def('generate', {'train'})
    infeed_map = tensor_info.build_input_map(signature_def.inputs,
                                             input_tensors)
    prune_unused_nodes(graph, signature_def)
    prune_feed_map(graph, infeed_map)
    return m
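
Hypothetical usage of make_prune, assuming import_module() returns a hub.Module whose 'generate' signature (under the {'train'} tag set) takes a single input named 'z', as the commented-out model_fn suggests:

z = tf.compat.v1.placeholder(tf.float32)
m = make_prune({"z": z})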
Example #5
    def testBuildInputMap(self):
        with tf.compat.v1.Graph().as_default():
            x = tf.compat.v1.placeholder(tf.int32, [2])
            y = tf.compat.v1.sparse_placeholder(tf.string, [None])
            r = tf.compat.v1.ragged.placeholder(tf.float32, 1, ())
            sig = _make_signature({"x": x, "y": y, "r": r}, {})

            input_map = tensor_info.build_input_map(sig.inputs, {
                "x": x,
                "y": y,
                "r": r
            })
            self.assertEqual(len(input_map), 6)
            self.assertIs(input_map[x.name], x)
            self.assertIs(input_map[y.indices.name], y.indices)
            self.assertIs(input_map[y.values.name], y.values)
            self.assertIs(input_map[y.dense_shape.name], y.dense_shape)
            self.assertIs(input_map[r.values.name], r.values)
            self.assertIs(input_map[r.row_splits.name], r.row_splits)
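
For reference, the two ragged entries checked above come from the components of the ragged placeholder; a quick way to inspect them (tensor names depend on graph state):

import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
    r = tf.ragged.placeholder(tf.float32, ragged_rank=1, value_shape=())
    print(r.values.name)      # flat values tensor
    print(r.row_splits.name)  # row partition tensor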
Example #6
  def create_apply_graph(self, signature, input_tensors, name):
    """See `ModuleImpl.create_apply_graph`."""
    signature_def = self._meta_graph.signature_def.get(signature)
    meta_graph = meta_graph_pb2.MetaGraphDef()
    meta_graph.CopyFrom(self._meta_graph)
    apply_graph = tf_v1.get_default_graph()
    infeed_map = tensor_info.build_input_map(signature_def.inputs,
                                             input_tensors)

    # Build an input map to feed when importing the apply-graph by augmenting the
    # state_map with the input args. This allows an input to override a tensor
    # from the state-graph.
    feed_map = dict(self._state_map)
    # If we are applying the module in a function with a TPUReplicateContext, we
    # must capture the state tensors when generating our feed map and prune out
    # assign ops. Function-graph semantics differ in that all ops are executed
    # regardless of dependencies.
    # TODO(b/112575006): The following adds functionality of function call
    # within a TPU context. Work to generalize this for all function calls is
    # ongoing.
    if _is_tpu_graph_function():
      for k, v in self._state_map.items():
        feed_map[k] = apply_graph.capture(v)
      meta_graph_lib.prune_unused_nodes(meta_graph, signature_def)
      # After we prune the metagraph def, we might need to prune away
      # infeeds which no longer exist.
      meta_graph_lib.prune_feed_map(meta_graph, infeed_map)
    elif apply_graph.building_function:
      # Log a warning if a user is using a hub module in function graph.
      # This is only expected to work if the function graph is pruned and
      # not all nodes are executed.
      #
      # E.g. it could work with "tf.compat.v1.wrap_function", but it will not
      # work with defun, Dataset.map_fn, etc...
      logging.warning("Using `hub.Module` while building a function: %s. This "
                      "can lead to errors if the function is not pruned.",
                      apply_graph.name)

    # As state ops in the apply graph are unused, replace them with Placeholders
    # so that in a hierarchical instantiation, apply_graph state ops are
    # ignored.
    replace_apply_state(
        meta_graph,
        list_registered_stateful_ops_without_inputs(meta_graph.graph_def),
        feed_map)
    feed_map.update(infeed_map)

    # Make state tensors enter the current context. This way the Module can be
    # applied inside a control flow structure such as a while_loop.
    control_flow = apply_graph._get_control_flow_context()  # pylint: disable=protected-access
    if control_flow:
      for key, value in sorted(feed_map.items()):
        feed_map[key] = control_flow.AddValue(value)

    # Don't mark the name as used at this point - import_scoped_meta_graph will
    # start using it.
    absolute_scope_name = apply_graph.unique_name(name, mark_as_used=False)
    relative_scope_name = absolute_scope_name.split("/")[-1]

    import_collections = [
        # In most cases ASSET_FILEPATHS are only used for the TABLE_INITIALIZERS
        # ops, but a graph could use an asset at any other time. So whenever
        # we bring in a tensor that carries an asset filename, we must
        # annotate it as such, so that later re-exports keep that semantic
        # information and can handle it.
        tf_v1.GraphKeys.ASSET_FILEPATHS,
        tf_v1.GraphKeys.COND_CONTEXT,
        tf_v1.GraphKeys.WHILE_CONTEXT,
    ]
    if self._trainable:
      import_collections.extend([tf_v1.GraphKeys.UPDATE_OPS])

    meta_graph_lib.filter_collections(meta_graph, import_collections)
    meta_graph_lib.prefix_shared_name_attributes(meta_graph,
                                                 absolute_scope_name)
    if len(meta_graph.collection_def) and _is_tpu_graph_function():
      raise NotImplementedError(
          "Applying modules with collections inside TPU functions is not "
          "supported. Collections found: %s" % str(meta_graph.collection_def))

    tf_v1.train.import_meta_graph(
        meta_graph,
        input_map=feed_map,
        import_scope=relative_scope_name)
    fix_colocation_after_import(input_map=feed_map,
                                absolute_import_scope=absolute_scope_name)

    def get_tensor(name):
      # When an output is also an input, no node for it is created within
      # the apply scope, so it must be looked up in the input map.
      try:
        return feed_map[name]
      except KeyError:
        return apply_graph.get_tensor_by_name(
            meta_graph_lib.prepend_name_scope(
                name, import_scope=absolute_scope_name))

    return tensor_info.build_output_map(signature_def.outputs, get_tensor)
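
A minimal sketch of the building_function check above: inside tf.compat.v1.wrap_function the default graph is a FuncGraph, so the warning branch would trigger. wrap_function prunes the resulting graph, which is why the comment names it as the case that can work.

import tensorflow.compat.v1 as tf

def fn():
    g = tf.get_default_graph()
    assert g.building_function  # True inside a FuncGraph
    return tf.constant(1.0)

wrapped = tf.wrap_function(fn, signature=[])
print(wrapped())  # runs the pruned function graph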