Пример #1
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Eliminate strided slices whose slice shape equals their input shape.

        For every ``StridedSliceParameters`` node whose ``slice_shape`` equals
        the shape of its first input:

        * if ``slice_shape`` also equals ``out_shape``, the node is removed and
          its neighbours are reconnected with ``NNEdge`` edges;
        * otherwise it is replaced by a ``ReshapeParameters`` node going from
          ``slice_shape`` to ``out_shape``.

        Any quantization record stored under the old node's id is dropped
        (removal case) or moved to the new reshape node (replacement case).

        Args:
            G: graph to rewrite in place.
            set_identity: when true, call ``self.set_identity(G)`` afterwards.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if at least one strided slice was removed or replaced.
        """

        def move_quant_record(old_id, new_node=None):
            # Drop the record keyed by old_id; when new_node is given, re-key
            # the record onto new_node first.
            qrecs = G.quantization
            if not (qrecs and old_id in qrecs):
                return
            if new_node is not None:
                qrecs[NodeId(new_node)] = qrecs[old_id]
            del qrecs[old_id]

        changed = False
        # Snapshot the candidates up front: the loop body removes/replaces
        # nodes in G.
        slices = list(G.nodes(node_classes=StridedSliceParameters))
        for slice_node in slices:
            full_input_shape = tuple(slice_node.in_dims[0].shape)
            if slice_node.slice_shape != full_input_shape:
                # The slice actually drops elements - leave it alone.
                continue

            changed = True
            # Capture the id before the node leaves the graph.
            old_id = NodeId(slice_node)

            if slice_node.slice_shape == slice_node.out_shape:
                LOG.info(
                    f'removing strided slice {slice_node.name} that does nothing')
                G.remove_and_reconnect(slice_node, edge_class=NNEdge)
                move_quant_record(old_id)
                continue

            # Shapes differ only in layout, so a reshape is equivalent.
            replacement = ReshapeParameters(
                G.unique_name(f'{slice_node.name}_reshape'),
                old_shape=slice_node.slice_shape,
                shape=slice_node.out_shape)
            LOG.info(
                f'replacing strided slice {slice_node.name} with reshape {replacement.name}'
            )
            G.replace_node(slice_node, replacement)
            move_quant_record(old_id, replacement)

        if set_identity:
            self.set_identity(G)

        return changed
Пример #2
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fold a transpose feeding input 1 of a matmul into the matmul itself.

        When input 1 of a ``MatMulOpParameters`` node (or subclass) is produced
        by a ``TransposeParameters`` node, the transpose is removed. The matmul
        is then replaced by its counterpart, built from the original node's
        name:

        * ``MatMulTransposedParameters`` -> ``MatMulOpParameters``
        * anything else -> ``MatMulTransposedParameters``

        The transpose's own producer is wired directly into input 1 of the new
        node.

        A transpose whose output also feeds other edges is left in place.
        Deleting it would disconnect those consumers.

        Args:
            G: graph to rewrite in place.
            set_identity: when true, call ``self.set_identity(G)`` afterwards.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if at least one matmul was rewritten.
        """
        has_modified_graph = False

        # Snapshot the candidates: the loop replaces and removes nodes, so do
        # not iterate a live view of the graph while mutating it (the strided
        # slice matcher does the same).
        for node in list(G.nodes(node_classes=MatMulOpParameters)):
            in_edges = list(G.indexed_in_edges(node.name))
            # Nothing to fold if input 1 is not connected.
            if len(in_edges) < 2 or in_edges[1] is None:
                continue
            trans_node = in_edges[1].from_node
            if not isinstance(trans_node, TransposeParameters):
                continue
            # The transpose is deleted below. If its output also feeds another
            # node (or input 0 of this same matmul), that consumer would lose
            # its input, so only fold a transpose with exactly one consumer.
            if len(list(G.out_edges(trans_node.name))) != 1:
                continue
            # NOTE(review): folding assumes the transpose permutation is exactly
            # the one MatMulTransposedParameters applies to input 1 (e.g. (1, 0)
            # for 2-D). The permutation is not checked here - confirm.
            if isinstance(node, MatMulTransposedParameters):
                new_node = MatMulOpParameters(node.name)
            else:
                new_node = MatMulTransposedParameters(node.name)

            # Edge feeding the transpose; its producer becomes the new input 1.
            in_trans_edge = list(G.indexed_in_edges(trans_node.name))[0]
            G.replace_node(node.name, new_node)
            G.remove(trans_node)
            G.add_edge(
                NNEdge(in_trans_edge.from_node,
                       new_node,
                       from_idx=in_trans_edge.from_idx,
                       to_idx=1))
            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph