Esempio n. 1
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Replace each single matched MatScale node in G with a fusion node.

        A ``GraphMatcher`` holding a ``MatScaleNodeMatch`` yields
        ``(fragment, match)`` pairs. For each pair:

        - the matched node is removed from ``G``;
        - a ``MatScaleFusionParameters`` node (fusion type ``match['type']``,
          subgraph = the fragment) is added in its place;
        - the original in-edge objects named by ``match['inputs']`` are
          re-targeted at the fusion node with consecutive input indexes;
        - the matched node's out-edge objects are re-sourced from it.

        Args:
            G: graph to rewrite in place.
            set_identity: if True, call ``self.set_identity(G)`` at the end.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            True if at least one node was fused, otherwise False.
        """
        node_matcher = GraphMatcher(
            match_function=lambda state, frag: (frag, state['match']))
        node_matcher.add_node(MatScaleNodeMatch())
        modified = False
        for frag, match in node_matcher.match_graph(G):
            # look the input edges up while the matched node is still in G
            inbound = [G.indexed_in_edges(in_node.name)[in_idx]
                       for in_node, in_idx in match['inputs']]
            target = list(frag.nodes())[0]
            outbound = G.out_edges(target.name)
            modified = True
            G.remove(target)
            fused = MatScaleFusionParameters(
                "{}_fusion".format(target.name),
                fusion_type=match['type'],
                subgraph=frag,
                input_mapping=[[(target, 0)], [(target, 1)]])
            G.add_node(fused)
            # the existing edge objects are reused (mutated), not cloned
            for new_idx, edge in enumerate(inbound):
                edge.to_node = fused
                edge.to_idx = new_idx
                G.add_edge(edge)
            for edge in outbound:
                edge.from_node = fused
                G.add_edge(edge)

        if set_identity:
            self.set_identity(G)

        return modified
Esempio n. 2
0
 def match_function(self, G: GraphView):
     """Match a pad node feeding directly into a filter-like node.

     Node '0' must be a ``PadParameters``; node '1' must be a
     ``FilterLikeParameters`` for which ``self.has_no_padding(node)`` is true.
     Returns the result of ``G.match_fragment`` for that two-node pattern.
     """
     def is_pad(node):
         return isinstance(node, PadParameters)

     def is_filter_without_padding(node):
         return (isinstance(node, FilterLikeParameters)
                 and self.has_no_padding(node))

     pattern = GraphView()
     pattern.add_node(MatchNode('0', matcher=is_pad))
     pattern.add_node(MatchNode('1', matcher=is_filter_without_padding))
     pattern.add_edge(Edge('0', '1'))
     return G.match_fragment(pattern)
Esempio n. 3
0
 def match_function(self, G: GraphView):
     """Match an FcParameters node followed directly by an activation node.

     Node '0' must be an ``FcParameters`` accepted by ``self.valid_linear``;
     node '1' must be an ``ActivationParameters`` accepted by
     ``self.valid_activation``. Returns ``G.match_fragment`` for the pattern.
     """
     def linear_ok(node):
         return isinstance(node, FcParameters) and self.valid_linear(node)

     def activation_ok(node):
         return (isinstance(node, ActivationParameters)
                 and self.valid_activation(node))

     pattern = GraphView()
     pattern.add_node(MatchNode('0', matcher=linear_ok))
     pattern.add_node(MatchNode('1', matcher=activation_ok))
     pattern.add_edge(Edge('0', '1'))
     return G.match_fragment(pattern)
Esempio n. 4
0
def split_down_from(cur_g, node, res_g=None):
    """Split cur_g into 2 graphs: everything from node down, and the rest.

    Every node reachable from ``node`` through out-edges (``node`` included)
    is moved from ``cur_g`` into ``res_g``, together with clones of all edges
    leaving those nodes (which, by construction, end inside the moved set).
    ``cur_g`` is modified in place and keeps the remaining nodes.

    Args:
        cur_g: graph to split; modified in place.
        node: root of the part to move; must be a node of ``cur_g``.
        res_g: graph that receives the moved part; a new ``GraphView`` is
            created when omitted.

    Returns:
        ``res_g``.
    """
    if res_g is None:
        res_g = GraphView()

    # Collect the whole downstream set before touching cur_g. Removing a node
    # is assumed to also drop its edges (confirm against the graph's remove):
    # copying edges while removing node by node would then lose edges into
    # nodes reached by more than one path (e.g. C->D in A->B->D, A->C->D).
    # An explicit stack also avoids deep recursion on long chains.
    below = {}
    pending = [node]
    while pending:
        current = pending.pop()
        if current.name in below:
            continue
        below[current.name] = current
        pending.extend(edge.to_node for edge in cur_g.out_edges(current.name))

    # snapshot every edge leaving the downstream set while it is still intact
    moved_edges = [
        edge
        for down_node in below.values()
        for edge in cur_g.out_edges(down_node.name)
    ]

    for down_node in below.values():
        if down_node not in res_g.nodes():
            res_g.add_node(down_node)
    for edge in moved_edges:
        res_g.add_edge(edge.clone())
    for down_node in below.values():
        cur_g.remove(down_node)
    return res_g
Esempio n. 5
0
def construct_subgraph(G, nodes):
    """Construct a subgraph of G from the given nodes.

    Returns a new ``GraphView`` holding every node in ``nodes`` plus a clone of
    each edge of ``G`` whose two endpoints are both in ``nodes``. Each such
    edge is copied once, while processing whichever of its endpoints comes
    first in ``nodes``.

    ``nodes`` is not modified. Repeated entries are processed once, and
    nodes are identified by ``name``.
    """
    sub_g = GraphView()
    ordered = list(nodes)
    # names of nodes not processed yet; an edge is copied only when its far
    # end is still pending, so every internal edge is copied exactly once
    pending = {node.name for node in ordered}
    for node in ordered:
        if node.name not in pending:
            continue  # repeated entry, already handled
        pending.discard(node.name)
        if node not in sub_g.nodes():
            sub_g.add_node(node)
        for edge in G.out_edges(node.name):
            if edge.to_node.name in pending:
                sub_g.add_edge(edge.clone())
        for edge in G.in_edges(node.name):
            if edge.from_node.name in pending:
                sub_g.add_edge(edge.clone())
    return sub_g
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Replace each matched MatScale node pair in G with one fusion node.

        For every ``(fragment, match)`` produced by the
        ``MatScalePairMatchFactory`` matcher:

        - all fragment nodes are removed from ``G``;
        - a ``MatScaleFusionParameters`` node with
          ``fusion_type="vec_scalar"`` is added in their place;
        - the matched input edges are cloned onto it, with the input index
          taken from ``get_mapping_from_edges``;
        - the out-edges of the fragment's first output node are cloned so
          they start at the fusion node.

        Args:
            G: graph to rewrite in place.
            set_identity: if True, call ``self.set_identity(G)`` at the end.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            True if at least one fusion was made, otherwise False.
        """
        pair_matcher = MatScalePairMatchFactory().get_matcher()
        modified = False

        for frag, match in pair_matcher.match_graph(G):
            # resolve the matched (node, input index) pairs to G's edges
            # before the fragment is removed
            inbound = [G.indexed_in_edges(node.name)[idx]
                       for node, idx in match]
            head = frag.inputs()[0]
            tail = frag.outputs()[0]
            outbound = G.out_edges(tail.name)
            for node in frag.nodes():
                G.remove(node)

            mapping = MatScaleFusionParameters.get_mapping_from_edges(inbound)
            fused_name = "{}_{}_fusion".format(head.name, tail.name)
            fused = MatScaleFusionParameters(
                fused_name,
                fusion_type="vec_scalar",
                subgraph=frag,
                input_mapping=MatScaleFusionParameters.convert_input_mapping(
                    mapping))
            modified = True
            G.add_node(fused)
            # fixed-size list of three input hints; a slot is filled only
            # when the producing node exposes out_dims_hint
            fused.in_dims_hint = [None] * 3

            for pos, edge in enumerate(inbound):
                target_idx = list(mapping[edge.to_node].keys())[0]
                rewired = edge.clone(to_node=fused, to_idx=target_idx)
                producer_hints = rewired.from_node.out_dims_hint
                if producer_hints:
                    fused.in_dims_hint[pos] = producer_hints[edge.from_idx]
                G.add_edge(rewired)
            for edge in outbound:
                G.add_edge(edge.clone(from_node=fused))

        if set_identity:
            self.set_identity(G)

        return modified
Esempio n. 7
0
    def match_function(self, G: GraphView):
        """Match a Conv2D node optionally followed by activation and/or pooling.

        Node '0' is a ``Conv2DParameters`` accepted by ``self.valid_activation``.
        Depending on ``self.match_activation`` and ``self.match_pool``, zero,
        one or two further nodes ('1', '2') are chained after it. When both
        are enabled, ``self.pool_after_activation`` decides their order.
        Returns the result of ``G.match_fragment`` for the built pattern.
        """
        pattern = GraphView()
        pattern.add_node(MatchNode(
            '0',
            matcher=lambda node: (isinstance(node, Conv2DParameters)
                                  and self.valid_activation(node))))

        # stage builders to chain after node '0', in pattern order
        stages = []
        if self.match_activation:
            stages.append(self.add_activation)
        if self.match_pool:
            if self.match_activation and not self.pool_after_activation:
                stages.insert(0, self.add_pooling)
            else:
                stages.append(self.add_pooling)

        for label, add_stage in enumerate(stages, start=1):
            add_stage(str(label), pattern)
        for label in range(1, len(stages) + 1):
            pattern.add_edge(Edge(str(label - 1), str(label)))

        return G.match_fragment(pattern)
Esempio n. 8
0
    def match_function(self, G: GraphView):
        """Match a filter whose output is added to a constant.

        Pattern: ``FilterParameters`` ('0') feeds input 0 and
        ``ConstantInputParameters`` ('2') feeds input 1 of a
        ``MatrixAddParameters`` node ('1'). Returns ``G.match_fragment``.
        """
        def is_instance_of(cls):
            return lambda node: isinstance(node, cls)

        pattern = GraphView()
        for label, cls in (('0', FilterParameters),
                           ('1', MatrixAddParameters),
                           ('2', ConstantInputParameters)):
            pattern.add_node(MatchNode(label, matcher=is_instance_of(cls)))
        pattern.add_edge(Edge('0', '1', to_idx=0))
        pattern.add_edge(Edge('2', '1', to_idx=1))

        return G.match_fragment(pattern)
Esempio n. 9
0
    def match_function(self, G: GraphView):
        """Match a ReLU (upper_bound == 6) multiplied by a constant of 1/6.

        Pattern: ``ReluActivationParameters`` with ``upper_bound == 6`` ('0')
        feeds input 0 of a ``MatrixMulParameters`` node ('1'). Input 1 of that
        node comes from a ``ConstantInputParameters`` ('2') accepted by
        ``check_equals(G, node, 1.0 / 6.0)``. Returns ``G.match_fragment``.
        """
        def is_relu6(node):
            return (isinstance(node, ReluActivationParameters)
                    and node.upper_bound == 6)

        def is_mul(node):
            return isinstance(node, MatrixMulParameters)

        def is_one_sixth(node):
            return (isinstance(node, ConstantInputParameters)
                    and check_equals(G, node, 1.0 / 6.0))

        pattern = GraphView()
        for label, test in (('0', is_relu6), ('1', is_mul), ('2', is_one_sixth)):
            pattern.add_node(MatchNode(label, matcher=test))
        pattern.add_edge(Edge('0', '1', to_idx=0))
        pattern.add_edge(Edge('2', '1', to_idx=1))

        return G.match_fragment(pattern)
Esempio n. 10
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse each node set found by ``self.find_sets`` into an expression node.

        For every node set:

        - the nodes and their internal edges are copied into a fragment;
        - the fragment is wrapped in an ``ExpressionFusionParameters`` named
          from ``expr_<n>`` (made unique by ``G.unique_name``);
        - the wrapper is swapped into ``G`` with ``G.replace_fragment``.

        When ``G.quantization`` is set, three more things happen:

        - a ``QRec`` is built for the new node from the quantization records
          of the fragment's boundary nodes;
        - ``QuantizeParameters`` nodes directly after it are removed;
        - all new expression nodes are re-quantized with a
          ``UnifiedQuantizer`` after the loop.

        Args:
            G: graph to rewrite in place.
            set_identity: if True, call ``self.set_identity(G)`` at the end.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            True if at least one node set was fused, otherwise False.
        """
        has_modified_graph = False
        # expression nodes created while G is quantized; re-quantized
        # together after the loop
        to_quantize = []
        node_sets = self.find_sets(G)
        for node_set in node_sets:
            # install a fresh stats collector for this expression; presumably
            # this is the object read back via Symbol.CURRENT_CONTROL.stats
            Symbol.set_default_control(SymbolStats())
            has_modified_graph = True
            # in_edges / out_edges are used below as mappings keyed by
            # (from_node, from_idx) whose values are sets of edges;
            # internal_edges are the edges copied into the fragment
            in_edges, out_edges, internal_edges = group_edges(G, node_set)
            # copy the set and its internal edges into a standalone fragment
            frag = GraphView()
            for node in node_set:
                frag.add_node(node)
            for edge in internal_edges:
                frag.add_edge(edge)
            # for each external producer output: the (node, input index)
            # pairs inside the fragment that it feeds
            in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                          for edge_group in in_edges.values()]
            in_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in in_edges
            ]
            out_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in out_edges
            ]
            out_mapping = list(out_edges.keys())
            # external producers that are constant inputs
            constant_inputs = [
                node_edge_idx[0] for node_edge_idx in in_edges
                if isinstance(node_edge_idx[0], ConstantInputParameters)
            ]
            LOG.debug(
                "inputs coming from: %s",
                ",".join(f"{from_node.__repr__()}:{from_idx}"
                         for from_node, from_idx in in_edges))
            LOG.info("fusing nodes: %s into expr_%s",
                     ",".join(node.__repr__() for node in node_set),
                     self._expr_num)
            expr = ExpressionFusionParameters(
                G.unique_name(f"expr_{self._expr_num}"),
                subgraph=frag,
                qrecs=G.quantization,
                input_mapping=in_mapping,
                output_mapping=out_mapping,
                in_dims=in_dims,
                out_dims=out_dims,
                constant_inputs=constant_inputs)
            # presumably: expression input i is fed by the i-th external
            # producer output, and each expression output fans out to the
            # consumer (node, idx) pairs listed for it
            in_edge_mapping = list(in_edges.keys())
            out_edge_mapping = [[(edge.to_node, edge.to_idx)
                                 for edge in edge_set]
                                for edge_set in out_edges.values()]
            # NOTE(review): set.union(*...) raises TypeError when given no
            # sets, i.e. if in_edges or out_edges is empty; this assumes every
            # fused set has at least one external input and one output.
            G.replace_fragment(
                frag,
                expr,
                frag_in_edges=list(set.union(*in_edges.values())),
                frag_out_edges=list(set.union(*out_edges.values())),
                edge_in_mapping=in_edge_mapping,
                edge_out_mapping=out_edge_mapping,
                edge_class=NNEdge)
            if G.quantization:
                qrecs = G.quantization
                # each expression input takes the input qtype of the first
                # fragment node/index that consumes it
                in_qs = [
                    qrecs[NodeId(in_map[0][0])].in_qs[in_map[0][1]]
                    for in_map in in_mapping
                ]
                # each expression output takes the output qtype of the
                # fragment node/index that produces it
                out_qs = [
                    qrecs[NodeId(node)].out_qs[idx]
                    for node, idx in out_mapping
                ]
                stats = Symbol.CURRENT_CONTROL.stats
                func_col = expr.func_col
                # record the min/max of each boundary qtype under the name of
                # the matching input/output symbol of the expression
                for idx, qtype in enumerate(in_qs):
                    symbol = func_col.variables[func_col.input_names[idx]]
                    stats[symbol.name] = {
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    }
                for idx, qtype in enumerate(out_qs):
                    symbol = func_col.variables[func_col.output_names[idx]]
                    stats[symbol.name] = {
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    }
                G.quantization[NodeId(expr)] = QRec(in_qs=in_qs,
                                                    out_qs=out_qs,
                                                    expression=stats,
                                                    ktype='scaled')
                # delete any quantize parameters on outputs to allow the quantizer
                # to fuse them into the expression
                # (out_edges is rebound here from the group_edges mapping to
                # the list of the expression node's out edges in G)
                out_edges = G.out_edges(expr.name)
                for edge in out_edges:
                    if isinstance(edge.to_node, QuantizeParameters):
                        G.remove_and_reconnect(edge.to_node)
                        if NodeId(edge.to_node) in G.quantization:
                            del G.quantization[NodeId(edge.to_node)]
                to_quantize.append(expr)

            self._expr_num += 1

        if to_quantize:
            # build a quantizer from the already-quantized graph and
            # re-quantize starting at the new expression nodes
            quantizer = UnifiedQuantizer.from_quantized_graph(G)
            G.quantization = quantizer.quantize(G, start_nodes=to_quantize)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 11
0
 def match_function(self, G: GraphView):
     """Match the single-node pattern built from ``NoOPMatcher('0')``.

     Returns the result of ``G.match_fragment`` for that pattern.
     """
     def build_pattern():
         fragment = GraphView()
         fragment.add_node(NoOPMatcher('0'))
         return fragment

     return G.match_fragment(build_pattern())