def export_module_spec(spec, export_path):
  """Export module with random initialization.

  Builds a fresh graph, instantiates `spec` as a `hub.Module` in it, runs the
  global variable initializer and exports the module from that session.

  Args:
    spec: The module spec to instantiate.
    export_path: Destination path handed to `Module.export`.
  """
  graph = tf_v1.Graph()
  with graph.as_default():
    module = hub.Module(spec)
    with tf_v1.Session() as sess:
      init_op = tf_v1.initializers.global_variables()
      sess.run(init_op)
      module.export(export_path, sess)
def testMatchingTensorInfoProtoMaps(self):
  """Checks tensor_info_proto_maps_match on dtype, sparsity and shape."""

  def dense_int32(shape):
    return tf_v1.placeholder(tf.int32, shape)

  def maps_match(lhs, rhs):
    return tensor_info.tensor_info_proto_maps_match(lhs, rhs)

  with tf_v1.Graph().as_default():
    sig1 = _make_signature({"x": dense_int32([2])}, {"x": dense_int32([2])})
    # Same inputs as sig1; the output is a sparse int64 instead of dense int32.
    sig2 = _make_signature(
        {"x": dense_int32([2])},
        {"x": tf_v1.sparse_placeholder(tf.int64, [2])})
    self.assertTrue(maps_match(sig1.inputs, sig2.inputs))
    self.assertFalse(maps_match(sig1.outputs, sig2.outputs))

    # Input differs from sig1 only in static shape ([None] vs [2]).
    sig3 = _make_signature({"x": dense_int32([None])}, {"x": dense_int32([2])})
    self.assertFalse(maps_match(sig1.inputs, sig3.inputs))
    self.assertTrue(maps_match(sig1.outputs, sig3.outputs))
def prune_unused_nodes(meta_graph, signature_def):
  """Function to prune unused ops given a signature def.

  This function does a graph traversal from all outputs defined in the
  signature_def to collect all used nodes. Any nodes that are unused are then
  discarded from `meta_graph.graph_def` (in place). This is useful for graphs
  that are executed eagerly or on TPUs.

  Args:
    meta_graph: The input/output MetaGraphDef which we wish to prune. It is
      modified in place.
    signature_def: A SignatureDef which specifies the outputs from which we
      wish to start graph traversal. Outputs may be encoded either by tensor
      `name` or as `coo_sparse` (values/indices/dense_shape tensors).
  """
  # Instantiate a temporary empty graph so that we have access to the Graph
  # API, and import the meta_graph into it.
  graph = tf_v1.Graph()
  with graph.as_default():
    tf_v1.train.import_meta_graph(meta_graph, input_map={}, import_scope="")
    # Traverse from all outputs and mark all nodes.
    used_node_names = set()
    for tensor_def in signature_def.outputs.values():
      if tensor_def.WhichOneof("encoding") == "coo_sparse":
        # A sparse output has no `name`; it is carried by three component
        # tensors, and every one of them (plus ancestors) must be kept.
        coo = tensor_def.coo_sparse
        tensor_names = (coo.values_tensor_name, coo.indices_tensor_name,
                        coo.dense_shape_tensor_name)
      else:
        tensor_names = (tensor_def.name,)
      for tensor_name in tensor_names:
        mark_backward(graph.get_tensor_by_name(tensor_name), used_node_names)
    # Filter out all nodes in the meta_graph that are not used.
    # Make a special exception for VarHandleOp. Removing VarHandleOps
    # will make the graph not importable as they often leave nodes hanging.
    # These will be disconnected through the feedmap when importing the
    # metagraph.
    kept_nodes = [
        node for node in meta_graph.graph_def.node
        if node.name in used_node_names or node.op == "VarHandleOp"
    ]
    del meta_graph.graph_def.node[:]
    meta_graph.graph_def.node.extend(kept_nodes)
  del graph
def testParsingTensorInfoProtoMaps(self):
  """Checks shape, dtype and sparsity parsed from signature proto maps."""
  with tf_v1.Graph().as_default():
    sig = _make_signature(
        {"x": tf_v1.placeholder(tf.string, [2])},
        {
            "y": tf_v1.placeholder(tf.int32, [2]),
            "z": tf_v1.sparse_placeholder(tf.float32, [2, 10]),
        })
    # Each entry: key -> (expected shape, expected dtype, expected sparsity).
    expected_inputs = {"x": ([2], tf.string, False)}
    expected_outputs = {
        "y": ([2], tf.int32, False),
        "z": ([2, 10], tf.float32, True),
    }
    for proto_map, expected in ((sig.inputs, expected_inputs),
                                (sig.outputs, expected_outputs)):
      parsed = tensor_info.parse_tensor_info_map(proto_map)
      self.assertEqual(set(parsed.keys()), set(expected.keys()))
      for key, (shape, dtype, is_sparse) in sorted(expected.items()):
        info = parsed[key]
        self.assertEqual(info.get_shape(), shape)
        self.assertEqual(info.dtype, dtype)
        check_sparsity = self.assertTrue if is_sparse else self.assertFalse
        check_sparsity(info.is_sparse)
def testBuildOutputMap(self):
  """Checks build_output_map resolves dense and sparse outputs by name."""
  with tf_v1.Graph().as_default():
    dense = tf_v1.placeholder(tf.int32, [2])
    sparse = tf_v1.sparse_placeholder(tf.string, [None])
    sig = _make_signature({}, {"x": dense, "y": sparse})

    graph = tf_v1.get_default_graph()
    output_map = tensor_info.build_output_map(
        sig.outputs, graph.get_tensor_by_name)

    self.assertEqual(len(output_map), 2)
    self.assertEqual(output_map["x"], dense)
    for component in ("indices", "values", "dense_shape"):
      self.assertEqual(
          getattr(output_map["y"], component), getattr(sparse, component))
def testBuildInputMap(self):
  """Checks build_input_map keys every component tensor by its name."""
  with tf_v1.Graph().as_default():
    dense = tf_v1.placeholder(tf.int32, [2])
    sparse = tf_v1.sparse_placeholder(tf.string, [None])
    feeds = {"x": dense, "y": sparse}
    sig = _make_signature(feeds, {})
    input_map = tensor_info.build_input_map(sig.inputs, dict(feeds))

    # One entry for the dense tensor plus three for the sparse components.
    expected = [dense, sparse.indices, sparse.values, sparse.dense_shape]
    self.assertEqual(len(input_map), len(expected))
    for tensor in expected:
      self.assertEqual(input_map[tensor.name], tensor)
def testConvertTensors(self):
  """Checks convert_dict_to_compatible_tensor on constants and sparse input."""
  with tf_v1.Graph().as_default():
    a = tf_v1.placeholder(tf.int32, [None])
    protomap = _make_signature({"a": a}, {}).inputs
    targets = tensor_info.parse_tensor_info_map(protomap)

    # convert constant: the result should take the dtype of the target.
    in0 = [1, 2, 3]
    output = tensor_info.convert_dict_to_compatible_tensor({"a": in0}, targets)
    self.assertEqual(output["a"].dtype, a.dtype)

    # check sparsity: feeding a sparse tensor for the dense target `a` is
    # expected to raise a TypeError whose message mentions "dense".
    in1 = tf_v1.sparse_placeholder(tf.int32, [])
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(TypeError, "dense"):
      tensor_info.convert_dict_to_compatible_tensor({"a": in1}, targets)
def testRepr(self):
  """Checks the repr of parsed dense and sparse tensor infos."""
  with tf_v1.Graph().as_default():
    sig = _make_signature(
        {"x": tf_v1.placeholder(tf.string, [2])},
        {
            "y": tf_v1.placeholder(tf.int32, [2]),
            "z": tf_v1.sparse_placeholder(tf.float32, [2, 10]),
        })
    outputs = tensor_info.parse_tensor_info_map(sig.outputs)
    expected_reprs = {
        "y": "<hub.ParsedTensorInfo shape=(2,) dtype=int32 is_sparse=False>",
        "z": "<hub.ParsedTensorInfo shape=(2, 10) dtype=float32 is_sparse=True>",
    }
    for key in ("y", "z"):
      self.assertEqual(repr(outputs[key]), expected_reprs[key])