def create_apply_graph(self, signature, input_tensors, name):
  """See `ModuleImpl.create_apply_graph`.

  Imports a copy of the module's MetaGraphDef into `self._graph` under a fresh
  name scope derived from `name`. The signature's inputs are fed from
  `input_tensors` and the module state from `self._state_map`.

  Returns:
    The mapping produced by `tensor_info.build_output_map` for the outputs of
    `signature`.
  """
  graph = self._graph
  sig_def = self._meta_graph.signature_def.get(signature)

  # Signature inputs are layered on top of the state map, so an explicitly
  # supplied input overrides a tensor that comes from the state graph.
  input_map = dict(self._state_map)
  input_map.update(tensor_info.build_input_map(sig_def.inputs, input_tensors))

  # Route every mapped tensor into the current control-flow context (if any),
  # so the module can be applied inside e.g. a while_loop body.
  flow_ctx = graph._get_control_flow_context()  # pylint: disable=protected-access
  if flow_ctx:
    for key in sorted(input_map):
      input_map[key] = flow_ctx.AddValue(input_map[key])

  # Don't claim the name here: the import below is what starts using it.
  scope_abs = graph.unique_name(name, mark_as_used=False)
  scope_rel = scope_abs.rpartition("/")[2]

  # ASSET_FILEPATHS is mostly consumed by TABLE_INITIALIZERS ops, but a graph
  # may use an asset anywhere. Every tensor carrying an asset filename must
  # therefore stay annotated, so that later re-exports keep that information.
  kept_keys = [
      tf.GraphKeys.ASSET_FILEPATHS,
      tf.GraphKeys.COND_CONTEXT,
      tf.GraphKeys.WHILE_CONTEXT,
  ]
  if self._trainable:
    kept_keys.append(tf.GraphKeys.UPDATE_OPS)

  # Work on a private copy so filtering/prefixing never touches the module's
  # own MetaGraphDef.
  graph_copy = meta_graph_pb2.MetaGraphDef()
  graph_copy.CopyFrom(self._meta_graph)
  meta_graph_lib.filter_collections(graph_copy, kept_keys)
  meta_graph_lib.prefix_shared_name_attributes(graph_copy, scope_abs)
  tf.train.import_meta_graph(graph_copy,
                             input_map=input_map,
                             import_scope=scope_rel)
  fix_colocation_after_import(input_map=input_map,
                              absolute_import_scope=scope_abs)

  def lookup(tensor_name):
    # An output that is itself a fed input creates no node inside the import
    # scope, so it has to be served straight from the input map.
    if tensor_name in input_map:
      return input_map[tensor_name]
    return graph.get_tensor_by_name(
        meta_graph_lib.prepend_name_scope(tensor_name, import_scope=scope_abs))

  return tensor_info.build_output_map(sig_def.outputs, lookup)
def create_apply_graph(self, signature, inputs, name):
  """See `ModuleImpl.create_apply_graph`.

  Converts `inputs` with `tensor_info.convert_to_input_tensors` against the
  signature's inputs. It then imports the module graph, adapted for the chosen
  scope, into `self._graph`.

  Returns:
    The mapping produced by `tensor_info.build_output_map` for the outputs of
    `signature`.
  """
  graph = self._graph
  sig_def = self._meta_graph.signature_def.get(signature)
  converted = tensor_info.convert_to_input_tensors(sig_def.inputs, inputs)

  # State tensors first, then signature inputs, so an input can override a
  # tensor coming from the state graph.
  input_map = dict(self._state_map)
  input_map.update(tensor_info.build_input_map(sig_def.inputs, converted))

  # Bring every mapped tensor into the current control-flow context (if any),
  # so the module can be applied inside e.g. a while_loop.
  flow_ctx = graph._get_control_flow_context()  # pylint: disable=protected-access
  if flow_ctx:
    for key in sorted(input_map):
      input_map[key] = flow_ctx.AddValue(input_map[key])

  # Don't claim the name here: the import below is what starts using it.
  scope_abs = graph.unique_name(name, mark_as_used=False)
  scope_rel = scope_abs.rpartition("/")[2]

  # ASSET_FILEPATHS is mostly consumed by TABLE_INITIALIZERS ops, but a graph
  # may use an asset anywhere. Every tensor carrying an asset filename must
  # therefore stay annotated, so that later re-exports keep that information.
  kept_keys = [
      tf.GraphKeys.ASSET_FILEPATHS,
      tf.GraphKeys.COND_CONTEXT,
      tf.GraphKeys.WHILE_CONTEXT,
  ]
  if self._trainable:
    kept_keys.append(tf.GraphKeys.UPDATE_OPS)

  tf.train.import_meta_graph(
      adapted_meta_graph_for_import(self._meta_graph, scope_abs),
      input_map=input_map,
      import_scope=scope_rel,
      restore_collections_predicate=lambda key: key in kept_keys)
  fix_colocation_after_import(input_map=input_map,
                              absolute_import_scope=scope_abs)

  def lookup(tensor_name):
    # An output that is itself a fed input creates no node inside the import
    # scope, so it has to be served straight from the input map.
    if tensor_name in input_map:
      return input_map[tensor_name]
    return graph.get_tensor_by_name(
        prepend_name_scope(tensor_name, import_scope=scope_abs))

  return tensor_info.build_output_map(sig_def.outputs, lookup)
def testBuildOutputMap(self):
  """build_output_map must resolve each output to the original tensor object.

  The test builds its placeholders in a private graph so they don't leak into
  the process-wide default graph. It compares with `assertIs` because the
  outputs are expected to be the very same objects. Depending on the
  TensorFlow version, `==` on tensors may build an elementwise comparison
  instead of checking identity.
  """
  with tf.Graph().as_default():
    x = tf.placeholder(tf.int32, [2])
    y = tf.sparse_placeholder(tf.string, [None])
    sig = _make_signature({}, {"x": x, "y": y})

    def _get_tensor(name):
      return tf.get_default_graph().get_tensor_by_name(name)

    output_map = tensor_info.build_output_map(sig.outputs, _get_tensor)
    self.assertEqual(len(output_map), 2)
    self.assertIs(output_map["x"], x)
    self.assertIs(output_map["y"].indices, y.indices)
    self.assertIs(output_map["y"].values, y.values)
    self.assertIs(output_map["y"].dense_shape, y.dense_shape)
def testBuildOutputMap(self):
  """Dense, sparse and ragged outputs map back to the exact tensors named."""
  with tf.compat.v1.Graph().as_default() as graph:
    dense = tf.compat.v1.placeholder(tf.int32, [2])
    sparse = tf.compat.v1.sparse_placeholder(tf.string, [None])
    ragged = tf.compat.v1.ragged.placeholder(tf.float32, 1, ())
    sig = _make_signature({}, {"x": dense, "y": sparse, "r": ragged})

    # `graph` is the default graph inside this block, so its bound lookup is
    # the same resolver as calling get_default_graph() on every request.
    outputs = tensor_info.build_output_map(sig.outputs,
                                           graph.get_tensor_by_name)

    self.assertEqual(len(outputs), 3)
    self.assertIs(outputs["x"], dense)
    for component in ("indices", "values", "dense_shape"):
      self.assertIs(getattr(outputs["y"], component),
                    getattr(sparse, component))
    for component in ("values", "row_splits"):
      self.assertIs(getattr(outputs["r"], component),
                    getattr(ragged, component))
def create_apply_graph(self, signature, input_tensors, name):
  """See `ModuleImpl.create_apply_graph`.

  Imports a filtered, scope-prefixed copy of the module's MetaGraphDef into
  the current default graph. The signature's inputs are fed from
  `input_tensors` and the module state from `self._state_map`.

  Args:
    signature: key into `self._meta_graph.signature_def`.
    input_tensors: values passed to `tensor_info.build_input_map` for the
      signature's inputs.
    name: requested name scope for the imported ops; it is made unique in the
      current graph.

  Returns:
    The mapping produced by `tensor_info.build_output_map` for the signature's
    outputs.

  Raises:
    NotImplementedError: if called inside a TPU graph function while the
      filtered meta graph still carries collections.
  """
  signature_def = self._meta_graph.signature_def.get(signature)
  meta_graph = meta_graph_pb2.MetaGraphDef()
  meta_graph.CopyFrom(self._meta_graph)
  apply_graph = tf_v1.get_default_graph()
  infeed_map = tensor_info.build_input_map(signature_def.inputs,
                                           input_tensors)

  # Build an input map to feed when importing the apply-graph by augmenting
  # the state_map with the input args. This allows an input to override a
  # tensor from the state-graph.
  feed_map = dict(self._state_map)

  # Decide once whether we are inside a TPU graph function. The same answer
  # drives both the pruning below and the collections guard before import.
  in_tpu_function = _is_tpu_graph_function()

  # If we are applying the module in a function with a TPUReplicateContext, we
  # must capture the state tensors in generating our feedmap and prune out
  # assign ops. Function graph semantics are different in that all ops are
  # executed regardless of dependency.
  # TODO(b/112575006): The following adds functionality of function call
  # within a TPU context. Work to generalize this for all function calls is
  # ongoing.
  if in_tpu_function:
    for k, v in self._state_map.items():
      feed_map[k] = apply_graph.capture(v)
    meta_graph_lib.prune_unused_nodes(meta_graph, signature_def)
    # After we prune the metagraph def, we might need to prune away
    # infeeds which no longer exist.
    meta_graph_lib.prune_feed_map(meta_graph, infeed_map)
  elif apply_graph.building_function:
    # Log a warning if a user is using a hub module in function graph.
    # This is only expected to work if the function graph is pruned and
    # not all nodes are executed.
    #
    # E.g. it could work with "tf.compat.v1.wrap_function", but it will not
    # work with defun, Dataset.map_fn, etc...
    logging.warning("Using `hub.Module` while building a function: %s. This "
                    "can lead to errors if the function is not pruned.",
                    apply_graph.name)

  # As state ops in the apply graph are unused, replace them with Placeholders
  # so that in a hierarchical instantiation, apply_graph state ops are
  # ignored.
  replace_apply_state(
      meta_graph,
      list_registered_stateful_ops_without_inputs(meta_graph.graph_def),
      feed_map)
  feed_map.update(infeed_map)

  # Make state tensors enter the current context. This way the Module can be
  # applied inside a control flow structure such as a while_loop.
  control_flow = apply_graph._get_control_flow_context()  # pylint: disable=protected-access
  if control_flow:
    for key, value in sorted(feed_map.items()):
      feed_map[key] = control_flow.AddValue(value)

  # Don't mark the name as used at this point: the import below will start
  # using it.
  absolute_scope_name = apply_graph.unique_name(name, mark_as_used=False)
  relative_scope_name = absolute_scope_name.split("/")[-1]

  import_collections = [
      # ASSET_FILEPATHS are mostly consumed by TABLE_INITIALIZERS ops, but a
      # graph may use an asset at any other time too. Every tensor carrying
      # an asset filename must therefore stay annotated, so that later
      # re-exports keep that semantic information and can handle it.
      tf_v1.GraphKeys.ASSET_FILEPATHS,
      tf_v1.GraphKeys.COND_CONTEXT,
      tf_v1.GraphKeys.WHILE_CONTEXT,
  ]
  if self._trainable:
    import_collections.append(tf_v1.GraphKeys.UPDATE_OPS)

  meta_graph_lib.filter_collections(meta_graph, import_collections)
  meta_graph_lib.prefix_shared_name_attributes(meta_graph, absolute_scope_name)
  if in_tpu_function and meta_graph.collection_def:
    raise NotImplementedError(
        "Applying modules with collections inside TPU functions is not "
        "supported. Collections found: %s" % str(meta_graph.collection_def))

  tf_v1.train.import_meta_graph(
      meta_graph, input_map=feed_map, import_scope=relative_scope_name)
  fix_colocation_after_import(input_map=feed_map,
                              absolute_import_scope=absolute_scope_name)

  def get_tensor(tensor_name):
    # When trying to output an input tensor there are no nodes created within
    # the apply scope, so one must look into the input map.
    try:
      return feed_map[tensor_name]
    except KeyError:
      return apply_graph.get_tensor_by_name(
          meta_graph_lib.prepend_name_scope(
              tensor_name, import_scope=absolute_scope_name))

  return tensor_info.build_output_map(signature_def.outputs, get_tensor)