def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Insert a copy node on the input edge of selected split nodes.

    A ``SplitParameters`` node is selected when
    ``search_up_for_input(G, node)`` is truthy for it. For every selected
    node a new ``CopyParameters`` node is inserted on the first edge
    returned by ``G.in_edges`` for that node. When the graph carries
    quantization information, ``copy_qrec`` is called to derive the new
    node's record from the split node's input 0.

    Args:
        G: Graph that is searched and modified in place.
        set_identity: When true, ``self.set_identity(G)`` is called after
            all copy nodes have been inserted.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        True if at least one copy node was inserted, otherwise False.
    """
    # Gather the targets first so the graph is not edited while its nodes
    # are still being enumerated.
    splits_to_patch = []
    for split in G.nodes(node_classes=SplitParameters):
        if search_up_for_input(G, split):
            splits_to_patch.append(split)

    for split in splits_to_patch:
        LOG.info("Insert copy on split input %s", split.name)
        copy_name = G.unique_name(f'{split.name}_copy')
        copy_node = CopyParameters(copy_name)
        first_in_edge = G.in_edges(split.name)[0]
        G.insert_node_at_edge(copy_node, first_in_edge)
        if G.quantization:
            G.quantization.copy_qrec(split, 'in', 0, copy_node)

    if set_identity:
        self.set_identity(G)

    # The graph was modified exactly when there was something to patch.
    return bool(splits_to_patch)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Insert copy nodes on split/concat inputs that come from graph inputs.

    Every in edge of each ``SplitParameters`` / ``ConcatParameters`` node is
    examined. When the node returned by ``find_real_in_edge(G, edge)`` is an
    ``InputParameters`` or ``ConstantInputParameters`` instance, a new
    ``CopyParameters`` node is inserted on that edge. When the graph carries
    quantization information, ``copy_qrec`` is called with the index of the
    input that received the copy, so a copy placed on input *k* of a concat
    is derived from input *k* rather than always from input 0.

    Args:
        G: Graph that is searched and modified in place.
        set_identity: When true, ``self.set_identity(G)`` is called after
            all copy nodes have been inserted.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        True if at least one copy node was inserted, otherwise False.
    """
    # Collect (edge, input index) pairs before touching the graph so that
    # insertion does not disturb the enumeration of nodes and edges.
    need_a_copy_edges = []
    for node in G.nodes(node_classes=(SplitParameters, ConcatParameters)):
        for idx, edge in enumerate(G.indexed_in_edges(node.name)):
            real_from_node, _ = find_real_in_edge(G, edge)
            if isinstance(real_from_node,
                          (InputParameters, ConstantInputParameters)):
                need_a_copy_edges.append((edge, idx))

    for edge, in_idx in need_a_copy_edges:
        to_node = edge.to_node
        LOG.info("Insert copy on split input %s", to_node.name)
        cnode = CopyParameters(G.unique_name(f'{to_node.name}_copy'))
        G.insert_node_at_edge(cnode, edge)
        if G.quantization:
            # Use the index of the input the copy feeds, not a fixed 0:
            # concat nodes have several inputs.
            G.quantization.copy_qrec(to_node, 'in', in_idx, cnode)

    if set_identity:
        self.set_identity(G)

    return bool(need_a_copy_edges)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Move an operation duplicated on a split's outputs to before the split.

    For every ``SplitParameters`` node for which
    ``self.moveable_same_operation_edges(G, node)`` returns a non-empty
    list, each node attached to the split's out edges is removed from the
    graph and its former consumers are connected directly to the matching
    split output. One of the removed nodes (the first by name among the
    returned edges) is then re-inserted on the split's single input edge.

    Args:
        G: Graph that is searched and modified in place.
        set_identity: When true, ``self.set_identity(G)`` is called after
            all split nodes have been processed.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        True if at least one split node was rewritten, otherwise False.
    """
    has_modified_graph = False
    for node in G.nodes(node_classes=SplitParameters):
        same_op_edges = self.moveable_same_operation_edges(G, node)
        if not same_op_edges:
            # nothing to move for this split
            continue
        has_modified_graph = True
        in_edges = G.in_edges(node.name)
        # the kept node is re-inserted on this one input edge below
        assert len(in_edges) == 1
        # sort by name to ensure that operation is repeatable
        same_op_edges.sort(key=lambda x: x.to_node.name)
        keep_node = same_op_edges[0].to_node
        LOG.info('split node %s has duplicate operations on its out edges', node.name)
        LOG.info('moving %s before split node %s', keep_node.name, node.name)
        for edge in G.out_edges(node.name):
            # remember the consumers before the node is taken out of the graph
            node_out_edges = G.out_edges(edge.to_node.name)
            G.remove(edge.to_node)
            if edge.to_node != keep_node:
                LOG.info('deleting duplicate node %s', edge.to_node.name)
                if G.quantization:
                    # drop the quantization entry of the deleted node, if any
                    nid = NodeId(edge.to_node)
                    if nid in G.quantization:
                        del G.quantization[nid]
            # connect the removed node's consumers straight to the split
            # output that used to feed the removed node
            for out_edge in node_out_edges:
                G.add_edge(
                    NNEdge(from_node=node, from_idx=edge.from_idx,
                           to_node=out_edge.to_node, to_idx=out_edge.to_idx))
        # place the kept node on the split's input edge
        G.insert_node_at_edge(keep_node, in_edges[0], edge_class=NNEdge)
        # NOTE(review): re-quantization is assumed to run once per rewritten
        # split node; confirm it was not meant to run once after the loop.
        if G.quantization:
            quantizer = NewQuantizer.from_quantized_graph(G)
            quantizer.quantize()

    if set_identity:
        self.set_identity(G)

    return has_modified_graph