def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Eliminate strided slices whose slice covers their whole input.

    A strided slice whose ``slice_shape`` equals the shape of its first
    input is removed outright when its output shape is also unchanged;
    otherwise it is replaced by a reshape to its ``out_shape``. A
    quantization record keyed by the slice is deleted, or moved to the
    reshape that replaces it.

    Args:
        G: Graph that is rewritten in place.
        set_identity: When true, ``self.set_identity(G)`` is called once the
            pass has finished.
        **kwargs: Accepted but not used by this matcher.

    Returns:
        True if at least one strided slice was removed or replaced.
    """
    changed = False
    # Snapshot the candidates first: the loop body edits the graph.
    for node in list(G.nodes(node_classes=StridedSliceParameters)):
        if node.slice_shape != tuple(node.in_dims[0].shape):
            continue
        changed = True
        # Key of the slice's quantization record, taken before the graph
        # is modified.
        old_key = NodeId(node)
        if node.slice_shape == node.out_shape:
            LOG.info(f'removing strided slice {node.name} that does nothing')
            G.remove_and_reconnect(node, edge_class=NNEdge)
            reshape = None
        else:
            reshape = ReshapeParameters(
                G.unique_name(f'{node.name}_reshape'),
                old_shape=node.slice_shape,
                shape=node.out_shape,
            )
            LOG.info(
                f'replacing strided slice {node.name} with reshape {reshape.name}'
            )
            G.replace_node(node, reshape)
        # Keep the quantization table in step with the graph edit.
        if G.quantization and old_key in G.quantization:
            if reshape is not None:
                G.quantization[NodeId(reshape)] = G.quantization[old_key]
            del G.quantization[old_key]
    if set_identity:
        self.set_identity(G)
    return changed
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Fold a transpose feeding a matmul's second input into the matmul.

    For every matmul node whose input 1 is produced by a transpose node,
    the matmul is replaced by its opposite variant (plain <-> transposed),
    the transpose node is removed, and the transpose's own first input is
    wired straight into input 1 of the new matmul.

    Args:
        G: Graph that is rewritten in place.
        set_identity: When true, ``self.set_identity(G)`` is called once the
            pass has finished.
        **kwargs: Accepted but not used by this matcher.

    Returns:
        True if at least one transpose was folded into a matmul.
    """
    has_modified_graph = False
    # Snapshot the candidates: the loop replaces and removes nodes, so it
    # must not walk a live view of the graph's node collection.
    for node in list(G.nodes(node_classes=MatMulOpParameters)):
        in_edges = list(G.indexed_in_edges(node.name))
        # Skip when there is no second input to inspect, e.g. because an
        # earlier iteration removed a transpose shared with this matmul.
        if len(in_edges) < 2 or in_edges[1] is None:
            continue
        trans_node = in_edges[1].from_node
        if not isinstance(trans_node, TransposeParameters):
            continue
        # Absorbing the transpose flips the matmul variant.
        # NOTE(review): presumably MatMulTransposedParameters subclasses
        # MatMulOpParameters, otherwise the first branch is unreachable.
        if isinstance(node, MatMulTransposedParameters):
            new_node = MatMulOpParameters(node.name)
        else:
            new_node = MatMulTransposedParameters(node.name)
        # Edge that feeds the transpose; its source becomes the new input 1.
        in_trans_edge = list(G.indexed_in_edges(trans_node.name))[0]
        G.replace_node(node.name, new_node)
        # NOTE(review): the transpose is removed without checking its
        # permutation or whether other nodes consume its output - confirm
        # that callers guarantee both.
        G.remove(trans_node)
        G.add_edge(
            NNEdge(
                in_trans_edge.from_node,
                new_node,
                from_idx=in_trans_edge.from_idx,
                to_idx=1,
            )
        )
        has_modified_graph = True
    if set_identity:
        self.set_identity(G)
    return has_modified_graph