def testFindHintedOutputNodes(self):
  """Checks that the hinted-output lookup reports one node per hinted op.

  Two multiplications are each wrapped in their own `OpHint` inside a fresh
  graph. The lookup is then run against the session, and only the *number* of
  returned nodes is compared with the number of expected node names.
  """
  with ops.Graph().as_default():

    def _hinted_product(hint_name, lhs, rhs):
      """Registers `lhs` and `rhs` on a new OpHint and returns its output."""
      hint = op_hint.OpHint(hint_name)
      hinted_lhs = hint.add_input(lhs)
      hinted_rhs = hint.add_input(rhs)
      return hint.add_output(math_ops.mul(hinted_lhs, hinted_rhs))

    first_output = _hinted_product(
        "custom_op_1", array_ops.constant([1.]), array_ops.constant([2.]))
    second_output = _hinted_product(
        "custom_op_2", array_ops.constant([3.]), array_ops.constant([4.]))

    with self.cached_session() as sess:
      found_nodes = op_hint.find_all_hinted_output_nodes(sess)
      expected_names = [
          _node_name(tensor.name) for tensor in (first_output, second_output)
      ]
      self.assertEqual(len(found_nodes), len(expected_names))
def testFindHintedOutputNodes(self):
  """Checks that the hinted-output lookup reports one node per hinted op.

  Two multiplications are each wrapped in their own `OpHint` (built in
  whatever graph is current when the test runs). The lookup is then run
  against the session, and only the *number* of returned nodes is compared
  with the number of expected node names.
  """

  def _hinted_product(hint_name, lhs, rhs):
    """Registers `lhs` and `rhs` on a new OpHint and returns its output."""
    hint = op_hint.OpHint(hint_name)
    hinted_lhs = hint.add_input(lhs)
    hinted_rhs = hint.add_input(rhs)
    return hint.add_output(math_ops.mul(hinted_lhs, hinted_rhs))

  first_output = _hinted_product(
      "custom_op_1", array_ops.constant([1.]), array_ops.constant([2.]))
  second_output = _hinted_product(
      "custom_op_2", array_ops.constant([3.]), array_ops.constant([4.]))

  with self.cached_session() as sess:
    found_nodes = op_hint.find_all_hinted_output_nodes(sess)
    expected_names = [
        _node_name(tensor.name) for tensor in (first_output, second_output)
    ]
    self.assertEqual(len(found_nodes), len(expected_names))
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
  """Fuse subgraph between input_nodes and output_nodes into a single custom op.

  The returned graph contains, in this order: every node reachable from
  `input_nodes` (in their original sequence order), one new node standing in
  for the fused region, a rewired copy of each output node, and finally every
  node of `graph_def` that is not reachable from `output_nodes`.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    input_nodes: List of names of the input nodes to the subgraph to be fused.
      These become the inputs of the fused node.
    output_nodes: List of names of the output nodes of the subgraph to be
      fused. Each must have exactly one input; that input is replaced by the
      matching output of the fused node.
    output_dtypes: A list of output datatypes for the custom op; stored in the
      fused node's `_output_types` attribute.
    output_quantized: A boolean flag that indicates if output is quantized;
      stored in the fused node's `_output_quantized` attribute.
    op_name: fused op name.
    op_type: fused op type.

  Returns:
    The GraphDef of the new graph. `graph_def` itself is not modified.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, if
      'input_nodes' or 'output_nodes' is a string rather than a list, or if a
      node inside the fused region consumes a node that is upstream of
      'input_nodes' without being one of them.
    AssertionError: If an output node does not have exactly one input.
  """
  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

  # A bare string would otherwise be silently iterated character by character.
  if isinstance(input_nodes, six.string_types):
    raise TypeError("input_nodes must be a list.")

  if isinstance(output_nodes, six.string_types):
    raise TypeError("output_nodes must be a list.")

  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

  # Nodes up to and including input_nodes.
  reachable_by_input = _bfs_for_reachable_nodes(input_nodes,
                                                name_to_input_name)
  # Nodes up to and including output_nodes.
  reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                 name_to_input_name)

  # Set views of the two lists for O(1) membership tests.
  input_nodes_set = set(input_nodes)
  output_nodes_set = set(output_nodes)

  nodes_post_output = []
  for node in graph_def.node:
    n = _node_name(node.name)
    if n not in reachable_by_output:
      # Not upstream of any output node: re-emitted after the fused op.
      nodes_post_output.append(n)
      continue
    if n in reachable_by_input or n in output_nodes_set:
      # Copied separately below (as an upstream node or as an output node).
      continue

    # n is between input and output, i.e., part of the fused op. Walk its
    # producers breadth-first and check that the region is closed: it may
    # only touch the upstream graph through the declared input_nodes.
    #
    # `to_visit` is an append-only queue read through `cursor`, and `seen`
    # guarantees each node is expanded once. Without it, shared producers are
    # re-expanded repeatedly and a cycle would never terminate. First-visit
    # order is unchanged, so the same offending node is reported.
    to_visit = [n]
    seen = {n}
    cursor = 0
    while cursor < len(to_visit):
      cur_node = to_visit[cursor]
      cursor += 1
      if cur_node in input_nodes_set:
        # Declared boundary of the fused region; do not walk past it.
        continue
      if cur_node in reachable_by_input:
        raise TypeError("Node %s uses input %s not in input_nodes." %
                        (n, cur_node))
      for producer in name_to_input_name[cur_node]:
        if producer not in seen:
          seen.add(producer)
          to_visit.append(producer)

  # Add all nodes up to the input nodes, preserving their original order.
  out = graph_pb2.GraphDef()
  reachable_by_input_sorted = sorted(
      reachable_by_input, key=lambda name: name_to_seq_num[name])
  for node in reachable_by_input_sorted:
    out.node.extend([copy.deepcopy(name_to_node[node])])

  # Add the custom op.
  new_node = node_def_pb2.NodeDef()
  new_node.input.extend(input_nodes)
  new_node.attr["_output_types"].list.type[:] = output_dtypes
  new_node.attr["_output_quantized"].b = output_quantized
  new_node.op = op_type
  new_node.name = op_name
  out.node.extend([new_node])

  # Add the nodes in the output of the custom op.
  for index, n in enumerate(output_nodes):
    # Kept as an assert so the exception type seen by callers is unchanged.
    assert len(name_to_node[n].input) == 1
    new_node = copy.deepcopy(name_to_node[n])
    del new_node.input[:]
    # Output 0 is addressed by the bare op name, output k > 0 as "name:k".
    new_node.input.append(op_name + (":" + str(index) if index != 0 else ""))
    out.node.extend([new_node])

  # Add the nodes post output_nodes.
  for n in nodes_post_output:
    out.node.extend([copy.deepcopy(name_to_node[n])])

  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)
  return out
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
  """Fuse subgraph between input_nodes and output_nodes into a single custom op.

  The returned graph contains, in this order: every node reachable from
  `input_nodes` (in their original sequence order), one new node standing in
  for the fused region, a rewired copy of each output node, and finally every
  node of `graph_def` that is not reachable from `output_nodes`.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    input_nodes: List of names of the input nodes to the subgraph to be fused.
      These become the inputs of the fused node.
    output_nodes: List of names of the output nodes of the subgraph to be
      fused. Each must have exactly one input; that input is replaced by the
      matching output of the fused node.
    output_dtypes: A list of output datatypes for the custom op; stored in the
      fused node's `_output_types` attribute.
    output_quantized: A boolean flag that indicates if output is quantized;
      stored in the fused node's `_output_quantized` attribute.
    op_name: fused op name.
    op_type: fused op type.

  Returns:
    The GraphDef of the new graph. `graph_def` itself is not modified.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, if
      'input_nodes' or 'output_nodes' is a string rather than a list, or if a
      node inside the fused region consumes a node that is upstream of
      'input_nodes' without being one of them.
    AssertionError: If an output node does not have exactly one input.
  """
  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

  # A bare string would otherwise be silently iterated character by character.
  if isinstance(input_nodes, six.string_types):
    raise TypeError("input_nodes must be a list.")

  if isinstance(output_nodes, six.string_types):
    raise TypeError("output_nodes must be a list.")

  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

  # Nodes up to and including input_nodes.
  reachable_by_input = _bfs_for_reachable_nodes(input_nodes,
                                                name_to_input_name)
  # Nodes up to and including output_nodes.
  reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                 name_to_input_name)

  # Set views of the two lists for O(1) membership tests.
  input_nodes_set = set(input_nodes)
  output_nodes_set = set(output_nodes)

  nodes_post_output = []
  for node in graph_def.node:
    n = _node_name(node.name)
    if n not in reachable_by_output:
      # Not upstream of any output node: re-emitted after the fused op.
      nodes_post_output.append(n)
      continue
    if n in reachable_by_input or n in output_nodes_set:
      # Copied separately below (as an upstream node or as an output node).
      continue

    # n is between input and output, i.e., part of the fused op. Walk its
    # producers breadth-first and check that the region is closed: it may
    # only touch the upstream graph through the declared input_nodes.
    #
    # `to_visit` is an append-only queue read through `cursor`, and `seen`
    # guarantees each node is expanded once. Without it, shared producers are
    # re-expanded repeatedly and a cycle would never terminate. First-visit
    # order is unchanged, so the same offending node is reported.
    to_visit = [n]
    seen = {n}
    cursor = 0
    while cursor < len(to_visit):
      cur_node = to_visit[cursor]
      cursor += 1
      if cur_node in input_nodes_set:
        # Declared boundary of the fused region; do not walk past it.
        continue
      if cur_node in reachable_by_input:
        raise TypeError("Node %s uses input %s not in input_nodes." %
                        (n, cur_node))
      for producer in name_to_input_name[cur_node]:
        if producer not in seen:
          seen.add(producer)
          to_visit.append(producer)

  # Add all nodes up to the input nodes, preserving their original order.
  out = graph_pb2.GraphDef()
  reachable_by_input_sorted = sorted(
      reachable_by_input, key=lambda name: name_to_seq_num[name])
  for node in reachable_by_input_sorted:
    out.node.extend([copy.deepcopy(name_to_node[node])])

  # Add the custom op.
  new_node = node_def_pb2.NodeDef()
  new_node.input.extend(input_nodes)
  new_node.attr["_output_types"].list.type[:] = output_dtypes
  new_node.attr["_output_quantized"].b = output_quantized
  new_node.op = op_type
  new_node.name = op_name
  out.node.extend([new_node])

  # Add the nodes in the output of the custom op.
  for index, n in enumerate(output_nodes):
    # Kept as an assert so the exception type seen by callers is unchanged.
    assert len(name_to_node[n].input) == 1
    new_node = copy.deepcopy(name_to_node[n])
    del new_node.input[:]
    # Output 0 is addressed by the bare op name, output k > 0 as "name:k".
    new_node.input.append(op_name + (":" + str(index) if index != 0 else ""))
    out.node.extend([new_node])

  # Add the nodes post output_nodes.
  for n in nodes_post_output:
    out.node.extend([copy.deepcopy(name_to_node[n])])

  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)
  return out