def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse ``ReLU(upper_bound=6) -> MatrixMul(constant 1/6)`` into one HSigmoid.

    For every ReLU node whose ``upper_bound`` is 6, the pattern matches when:
    the ReLU has exactly one consumer, that consumer is a
    ``MatrixMulParameters`` node with exactly two inputs, and the other input
    comes from a ``ConstantInputParameters`` node that has a single consumer
    and equals 1/6 according to ``check_equals``. The ReLU, the multiply and
    the constant are removed and replaced by an
    ``HSigmoidActivationParameters`` node created with ``offset=0``.

    When the graph carries quantization, the new node's record is derived
    from the ReLU's record, with the ReLU's input qtypes and the multiply's
    output qtypes; the constant's record is deleted.

    Args:
        G: graph to rewrite in place.
        set_identity: not read by this method; part of the matcher signature.

    Returns:
        True if at least one fusion was performed.
    """

    def one_sixth_scaling(relu):
        # Return (mul, const) if `relu` feeds only a 2-input MatrixMul whose
        # other operand is a single-use constant equal to 1/6; otherwise None.
        relu_outs = G.out_edges(relu)
        if len(relu_outs) != 1:
            return None
        relu_out = relu_outs[0]
        mul = relu_out.to_node
        if not isinstance(mul, MatrixMulParameters):
            return None
        mul_ins = G.in_edges(mul)
        if len(mul_ins) != 2:
            return None
        # the mul input that does not come from the relu
        const = (set(mul_ins) - {relu_out}).pop().from_node
        if len(G.out_edges(const)) != 1:
            return None
        if not isinstance(const, ConstantInputParameters):
            return None
        if not check_equals(G, const, 1.0 / 6.0):
            return None
        return mul, const

    # materialize the candidate list first: the graph is mutated in the loop
    relu6_nodes = [
        candidate
        for candidate in G.nodes(node_classes=ReluActivationParameters)
        if candidate.upper_bound == 6
    ]
    fused_any = False
    for relu in relu6_nodes:
        found = one_sixth_scaling(relu)
        if found is None:
            continue
        mul, const = found
        fused_any = True

        hsigmoid = HSigmoidActivationParameters(
            G.unique_name(f'{mul.name}_hsigmoid'), offset=0)
        # capture the boundary edges before the matched nodes are removed
        feeding_edges = G.in_edges(relu)
        consuming_edges = G.out_edges(mul)
        doomed = [relu, mul, const]
        LOG.info(f'fusing {", ".join(n.name for n in doomed)} into HSIGMOID {hsigmoid.name}')
        G.remove_all(doomed)

        for edge in feeding_edges:
            G.add_edge(NNEdge.clone(edge, to_node=hsigmoid, to_idx=0))
        for edge in consuming_edges:
            G.add_edge(NNEdge.clone(edge, from_node=hsigmoid, from_idx=0))

        if G.quantization:
            relu_qrec = G.quantization[NodeId(relu)]
            mul_qrec = G.quantization[NodeId(mul)]
            del G.quantization[NodeId(const)]
            hsigmoid_qrec = QRec.copy_ktype(
                relu_qrec, in_qs=relu_qrec.in_qs, out_qs=mul_qrec.out_qs)
            G.quantization[NodeId(hsigmoid)] = hsigmoid_qrec
    return fused_any
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Merge runs of split outputs that feed consecutive inputs of one concat.

    For each ``SplitParameters`` node, outputs with exactly one outgoing edge
    are followed with ``search_down`` to a ``ConcatParameters`` node, passing
    through ``CopyParameters`` / ``NoOPParameters`` nodes. A run of two or
    more consecutive split outputs that land on consecutive inputs of the
    same concat is collapsed: their act slices / out shapes are combined with
    ``reduce_slices`` into the first output, the edge paths of the remaining
    outputs are removed with ``remove_edges``, and the later split outputs
    and later concat inputs are re-indexed to close the gap.

    Merging re-indexes edges by removing them and adding clones, so a group
    collected before a merge can hold stale edge objects and indices if it
    shares the split or concat node. The graph is therefore rescanned after
    every merge. This terminates because each merge removes at least one
    split output.

    Args:
        G: graph to rewrite in place.
        set_identity: not read by this method; part of the matcher signature.

    Returns:
        True if at least one run was merged.
    """

    def find_edge_groups():
        # Collect runs (len >= 2) of edge paths split -> ... -> concat where
        # consecutive split outputs reach consecutive inputs of one concat.
        edge_groups = []
        for node in G.nodes(node_classes=SplitParameters):
            cur_group = None
            for out_edge_bundle in G.indexed_out_edges(node):
                concat_node_edges = None
                if len(out_edge_bundle) == 1:
                    concat_node_edges = search_down(
                        G, out_edge_bundle[0], ConcatParameters,
                        can_pass=(CopyParameters, NoOPParameters))
                if not concat_node_edges:
                    # this output does not continue any run: close it
                    if cur_group and len(cur_group) > 1:
                        edge_groups.append(cur_group)
                    cur_group = None
                    continue
                if cur_group:
                    this_concat_edge = concat_node_edges[-1]
                    last_concat_edge = cur_group[-1][-1]
                    if (this_concat_edge.to_node == last_concat_edge.to_node
                            and this_concat_edge.to_idx == last_concat_edge.to_idx + 1):
                        cur_group.append(concat_node_edges)
                        continue
                    if len(cur_group) > 1:
                        edge_groups.append(cur_group)
                cur_group = [concat_node_edges]
            if cur_group and len(cur_group) > 1:
                edge_groups.append(cur_group)
        return edge_groups

    def combine_group(edge_group):
        # Collapse one run into the first split output / first concat input.
        split_node = edge_group[0][0].from_node
        concat_node = edge_group[0][-1].to_node
        from_idx = edge_group[0][0].from_idx
        to_idx = edge_group[-1][0].from_idx
        LOG.info(
            f"combining outputs {from_idx}:{to_idx} on split node {split_node.name} followed by concat {concat_node.name}"
        )
        # combine slices and shapes on edges in group
        new_slice, new_shape = reduce_slices(
            split_node.act_slices[from_idx:to_idx + 1],
            split_node.out_shapes[from_idx:to_idx + 1])
        split_node.act_slices = (split_node.act_slices[:from_idx] + [new_slice]
                                 + split_node.act_slices[to_idx + 1:])
        split_node.out_shapes = (split_node.out_shapes[:from_idx] + [new_shape]
                                 + split_node.out_shapes[to_idx + 1:])
        # remove all edges and intermediate nodes on all edge groups except the first
        for edge_list in edge_group[1:]:
            remove_edges(G, edge_list)
        # move split outputs beyond the group to sit right after the merged one;
        # such an output may have several consumers (or none), so move them all
        out_edge_bundles = G.indexed_out_edges(split_node)
        for offset, edge_list in enumerate(out_edge_bundles[to_idx + 1:]):
            for edge in list(edge_list):
                G.remove_edge(edge)
                G.add_edge(NNEdge.clone(edge, from_idx=from_idx + 1 + offset))
        # reindex the in edges in the concat
        concat_from_idx = edge_group[0][-1].to_idx
        concat_to_idx = edge_group[-1][-1].to_idx
        in_edges = G.indexed_in_edges(concat_node)
        for offset, in_edge in enumerate(in_edges[concat_to_idx + 1:]):
            G.remove_edge(in_edge)
            G.add_edge(NNEdge.clone(in_edge, to_idx=concat_from_idx + 1 + offset))

    # we leave the splits and concats after this since they will be cleared up by remove_noops
    something_changed = False
    while True:
        # rescan after each merge so indices and edge objects are never stale
        edge_groups = find_edge_groups()
        if not edge_groups:
            return something_changed
        combine_group(edge_groups[0])
        something_changed = True