Ejemplo n.º 1
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Replace each single-node MatScale match in ``G`` with a fusion node.

        A ``GraphMatcher`` containing one ``MatScaleNodeMatch`` is run over
        ``G``. For every match, the matched node is removed and replaced by a
        ``MatScaleFusionParameters`` node named ``"<node>_fusion"`` that wraps
        the matched fragment. The in-edges listed in ``match['inputs']`` are
        re-attached to the fusion node (input ``i`` for the ``i``-th listed
        input), and the matched node's out-edges are re-attached so they
        originate from the fusion node.

        Args:
            G: graph to rewrite in place.
            set_identity: when true, ``self.set_identity(G)`` is called once
                after all matches have been processed.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if at least one match was fused, otherwise False.
        """
        fragment = GraphMatcher(
            match_function=lambda state, frag: (frag, state['match']))
        fragment.add_node(MatScaleNodeMatch())
        has_modified_graph = False
        for frag, match in fragment.match_graph(G):
            # Read the reported inputs once: the same list drives both the edge
            # rewiring and the fusion's input mapping, so the two must agree.
            inputs = list(match['inputs'])
            match_edges = [
                G.indexed_in_edges(node.name)[idx]
                for node, idx in inputs
            ]
            matched_node = list(frag.nodes())[0]
            # Capture the out-edges before the node (and its edges) is removed.
            out_edges = G.out_edges(matched_node.name)
            has_modified_graph = True
            G.remove(matched_node)
            fnode = MatScaleFusionParameters(
                "{}_fusion".format(matched_node.name),
                fusion_type=match['type'],
                subgraph=frag,
                # Fusion input i feeds the internal input that the i-th
                # reported edge originally fed (previously hard-coded to 0, 1).
                input_mapping=[[(matched_node, idx)] for _, idx in inputs])
            G.add_node(fnode)
            for fusion_idx, edge in enumerate(match_edges):
                edge.to_node = fnode
                edge.to_idx = fusion_idx
                G.add_edge(edge)
            for edge in out_edges:
                edge.from_node = fnode
                G.add_edge(edge)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Ejemplo n.º 2
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse each matched MatScale node pair in ``G`` into one fusion node.

        For every fragment yielded by the matcher from
        ``MatScalePairMatchFactory``, all of the fragment's nodes are removed
        from ``G``. They are replaced by a ``MatScaleFusionParameters`` node of
        type ``"vec_scalar"``, named ``"<first>_<last>_fusion"`` after the
        fragment's first input node and first output node.

        The matched in-edges are cloned onto the fusion node at the input index
        given by the fusion's input mapping. Where the producing node has an
        ``out_dims_hint``, it is copied into the fusion node's
        ``in_dims_hint`` for that same input. The last node's out-edges are
        cloned to originate from the fusion node.

        Args:
            G: graph to rewrite in place.
            set_identity: when true, ``self.set_identity(G)`` is called once
                after all matches have been processed.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if at least one fragment was fused, otherwise False.
        """
        fac = MatScalePairMatchFactory()
        has_modified_graph = False

        for frag, match in fac.get_matcher().match_graph(G):
            match_edges = [
                G.indexed_in_edges(node.name)[idx] for node, idx in match
            ]
            first_node = frag.inputs()[0]
            last_node = frag.outputs()[0]
            # Capture the out-edges before the fragment's nodes are removed.
            out_edges = G.out_edges(last_node.name)
            for node in frag.nodes():
                G.remove(node)

            input_mapping = MatScaleFusionParameters.get_mapping_from_edges(
                match_edges)

            fnode = MatScaleFusionParameters(
                "{}_{}_fusion".format(first_node.name, last_node.name),
                fusion_type="vec_scalar",
                subgraph=frag,
                input_mapping=MatScaleFusionParameters.convert_input_mapping(
                    input_mapping))
            has_modified_graph = True
            G.add_node(fnode)
            fnode.in_dims_hint = [None] * 3

            for edge in match_edges:
                # Fusion input this edge is wired to; the dims hint must be
                # stored at the same index, not at the edge's list position.
                fusion_idx = list(input_mapping[edge.to_node].keys())[0]
                new_edge = edge.clone(to_node=fnode, to_idx=fusion_idx)
                if new_edge.from_node.out_dims_hint:
                    fnode.in_dims_hint[fusion_idx] = (
                        new_edge.from_node.out_dims_hint[edge.from_idx])
                G.add_edge(new_edge)
            for edge in out_edges:
                G.add_edge(edge.clone(from_node=fnode))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph