コード例 #1
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Collapse every fusable chain in ``G`` into a single fusion node.

        Candidates are the nodes whose class is a key of ``VALID_FUSIONS``.
        For each one, ``self.get_node_list`` returns a match (or None).
        Matches with fewer than two nodes are ignored. An accepted chain is
        replaced by an instance of ``match.fusions_class``. That instance
        takes the head node's inputs and the tail node's output 0.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted for interface compatibility; not read here.

        Returns:
            True if at least one chain was fused, otherwise False.
        """
        fused_any = False

        # NOTE(review): G is mutated inside this loop; this is only safe if
        # G.nodes(...) returns a snapshot rather than a live view — confirm.
        for candidate in G.nodes(node_classes=tuple(VALID_FUSIONS.keys())):
            match = self.get_node_list(G, candidate,
                                       FusionMatch(self._default_ktype))
            if match is None or len(match.order) < 2:
                continue
            chain = list(match.order)
            head = match.node
            tail = chain[-1]
            LOG.info("fusing nodes %s", ",".join(
                member.name for member in chain))
            fused_any = True

            # Internal subgraph is the straight line chain[0] -> ... -> chain[-1].
            subgraph = GraphView()
            for upstream, downstream in zip(chain, chain[1:]):
                subgraph.add_edge(
                    NNEdge(from_node=upstream, to_node=downstream))

            # The head may have several inputs, but the tail is assumed to
            # have exactly one output.
            fused = match.fusions_class(
                head.name + '_fusion',
                fusion_type=match.fusion_type,
                subgraph=subgraph,
                input_mapping=[[(head, in_idx)]
                               for in_idx in range(G.num_in_edges(head.name))],
                output_mapping=[(tail, 0)])

            if G.quantization:
                # TODO - stats
                qrecs = G.quantization.get_all(fused.contained_nodes())
                if qrecs:
                    fused_qrec = QRec.copy_ktype(qrecs[0],
                                                 in_qs=qrecs[0].in_qs,
                                                 out_qs=qrecs[-1].out_qs)
                    for inner in fused.contained_nodes():
                        G.quantization.move_to_fusion(inner, fused)
                    G.quantization[NodeId(fused)] = fused_qrec

            # Capture the boundary edges before the chain is removed.
            head_in_edges = G.in_edges(head.name)
            tail_out_edges = G.out_edges(tail.name)
            for member in chain:
                G.remove(member)
            for edge in head_in_edges:
                G.add_edge(NNEdge(edge.from_node, fused,
                                  from_idx=edge.from_idx,
                                  to_idx=edge.to_idx))
            for edge in tail_out_edges:
                G.add_edge(NNEdge(fused, edge.to_node,
                                  from_idx=edge.from_idx,
                                  to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return fused_any
コード例 #2
0
ファイル: match_gap_conv.py プロジェクト: brupa9/gap_sdk
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse each Conv2D node with the activation/pooling nodes after it.

        Special cases:
        - For a 'conv_active_pool' match with average pooling, the pool is
          dropped from the chain, so only conv + activation are fused.
        - A 'conv_pool_active' match with average pooling and a non-"relu"
          activation is not fused at all.

        The fused node's quantization record is built from the first inner
        record's ``in_qs`` and the last one's ``out_qs``. Its record class is
        chosen from the class of the first inner record.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.

        Returns:
            True if at least one chain was fused, otherwise False.
        """
        fused_any = False
        # (accepted inner record classes, class used for the fused record)
        record_kinds = (
            ((SymmetricQuantizationRecord,
              SymmetricScalableFilterQuantizationRecord),
             SymmetricQuantizationRecord),
            ((MultQuantizationRecord, MultScalableFilterQuantizationRecord),
             MultQuantizationRecord),
            ((Float32QuantizationRecord,
              Float32ScalableFilterQuantizationRecord),
             Float32QuantizationRecord),
        )
        conv_nodes = [params for params in G.nodes()
                      if isinstance(params, Conv2DParameters)]
        for conv_node in conv_nodes:
            match = self.get_node_list(G, conv_node)
            if match is None or len(match.order) < 2:
                continue
            if (match.fusion_type == 'conv_active_pool'
                    and match.pool.pool_type == "average"):
                # keep conv + activation only; the average pool stays outside
                match.order = match.order[:2:]
                match.pool = None
            elif (match.fusion_type == 'conv_pool_active'
                  and match.pool.pool_type == "average"
                  and match.active.activation != "relu"):
                continue

            chain = list(match.order)
            conv = match.conv
            tail = chain[-1]
            LOG.info("fusing nodes %s", ",".join(member.name for member in chain))
            fused_any = True

            subgraph = GraphView()
            for upstream, downstream in zip(chain, chain[1:]):
                subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))

            fused = ConvFusionParameters(
                conv.name + '_fusion',
                fusion_type=match.fusion_type,
                subgraph=subgraph,
                in_dims_hint=conv.in_dims_hint,
                out_dims_hint=conv.out_dims_hint,
                # conv inputs 0..2 (presumably input, weights, biases)
                input_mapping=[[(conv, in_idx)] for in_idx in range(3)],
                output_mapping=[(tail, 0)])

            if G.quantization:
                qrecs = G.quantization.get_all(fused.contained_nodes())
                if qrecs:
                    first, last = qrecs[0], qrecs[-1]
                    for accepted, fused_kind in record_kinds:
                        if isinstance(first, accepted):
                            fused_qrec = fused_kind(in_qs=first.in_qs,
                                                    out_qs=last.out_qs)
                            break
                    else:
                        # unknown record class: the fused node gets None
                        fused_qrec = None
                    for inner in fused.contained_nodes():
                        G.quantization.move_to_fusion(inner, fused)
                    G.quantization[NodeId(fused)] = fused_qrec

            conv_in_edges = G.in_edges(conv.name)
            tail_out_edges = G.out_edges(tail.name)
            for member in chain:
                G.remove(member)
            for edge in conv_in_edges:
                G.add_edge(NNEdge(edge.from_node, fused, from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in tail_out_edges:
                G.add_edge(NNEdge(fused, edge.to_node, from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return fused_any
コード例 #3
0
ファイル: fuse_pad.py プロジェクト: hasetz/gap_sdk
 def match_function(self, G: GraphView):
     """Match a pad node that feeds a filter-like node with no padding.

     The pattern has two nodes: '0' (a PadParameters node) -> '1' (a
     FilterLikeParameters node for which ``self.has_no_padding`` holds).
     Returns the result of ``G.match_fragment`` for this pattern.
     """
     def is_pad(node):
         return isinstance(node, PadParameters)

     def is_unpadded_filter(node):
         return (isinstance(node, FilterLikeParameters)
                 and self.has_no_padding(node))

     pattern = GraphView()
     pattern.add_node(MatchNode('0', matcher=is_pad))
     pattern.add_node(MatchNode('1', matcher=is_unpadded_filter))
     pattern.add_edge(Edge('0', '1'))
     return G.match_fragment(pattern)
コード例 #4
0
 def match_function(self, G: GraphView):
     """Match a fully-connected node followed by an activation.

     The pattern is '0' -> '1':
     - '0' is an FcParameters node accepted by ``self.valid_linear``.
     - '1' is an ActivationParameters node accepted by
       ``self.valid_activation``.
     Returns the result of ``G.match_fragment`` for this pattern.
     """
     def is_linear(node):
         return isinstance(node, FcParameters) and self.valid_linear(node)

     def is_activation(node):
         return (isinstance(node, ActivationParameters)
                 and self.valid_activation(node))

     pattern = GraphView()
     for name, matcher in (('0', is_linear), ('1', is_activation)):
         pattern.add_node(MatchNode(name, matcher=matcher))
     pattern.add_edge(Edge('0', '1'))
     return G.match_fragment(pattern)
コード例 #5
0
def split_down_from(cur_g, node, res_g=None):
    """Split ``cur_g`` into two graphs: everything from ``node`` down, and the rest.

    ``node`` and every node reachable from it through out edges are removed
    from ``cur_g`` and added to ``res_g``. Clones of their out edges are
    added to ``res_g`` as well. Edges that enter this set from nodes outside
    it are not copied to ``res_g``.

    Args:
        cur_g: graph to split; modified in place.
        node: root of the part that is moved.
        res_g: graph that receives the moved part; a new ``GraphView`` is
            created when None.

    Returns:
        ``res_g``.
    """
    if res_g is None:
        res_g = GraphView()
    # Gather the whole downstream set and all its out edges *before* touching
    # cur_g. Removing a node from cur_g also removes its incident edges. The
    # previous recursive version therefore lost an edge, or tried to remove a
    # node twice, when two paths from ``node`` met again further down. It
    # could also hit the recursion limit on long chains.
    downstream = [node]
    seen = {node}
    edges = []
    pos = 0
    while pos < len(downstream):
        current = downstream[pos]
        pos += 1
        for edge in cur_g.out_edges(current.name):
            edges.append(edge)
            if edge.to_node not in seen:
                seen.add(edge.to_node)
                downstream.append(edge.to_node)
    for current in downstream:
        cur_g.remove(current)
        if current not in res_g.nodes():
            res_g.add_node(current)
    for edge in edges:
        res_g.add_edge(edge.clone())
    return res_g
コード例 #6
0
    def match_function(self, G: GraphView):
        """Match a filter whose output is added to a constant.

        The pattern is:
        - '0' is a FilterParameters node and feeds input 0 of '1'.
        - '1' is a MatrixAddParameters node.
        - '2' is a ConstantInputParameters node and feeds input 1 of '1'.
        Returns the result of ``G.match_fragment`` for this pattern.
        """
        def of_type(param_class):
            return lambda node: isinstance(node, param_class)

        pattern = GraphView()
        for name, param_class in (('0', FilterParameters),
                                  ('1', MatrixAddParameters),
                                  ('2', ConstantInputParameters)):
            pattern.add_node(MatchNode(name, matcher=of_type(param_class)))
        for source, input_idx in (('0', 0), ('2', 1)):
            pattern.add_edge(Edge(source, '1', to_idx=input_idx))

        return G.match_fragment(pattern)
コード例 #7
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse each fully-connected node with the activation chain after it.

        The accepted activations come from ``kwargs['group_identity']``:
        'pow2_match_group' selects ``VALID_ACTIVATIONS_POW2``; anything else
        (including absence) selects ``VALID_ACTIVATIONS_SQ8``. Chains shorter
        than two nodes are ignored. Each accepted chain is replaced by a
        ``LinearFusionParameters`` node named ``<fc name>_fusion``.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: only 'group_identity' is read.

        Returns:
            True if at least one chain was fused, otherwise False.
        """
        use_pow2 = kwargs.get('group_identity') == 'pow2_match_group'
        valid_activations = (VALID_ACTIVATIONS_POW2 if use_pow2
                             else VALID_ACTIVATIONS_SQ8)
        fused_any = False
        fc_nodes = [params for params in G.nodes()
                    if isinstance(params, FcParameters)]
        for fc_node in fc_nodes:
            match = self.get_node_list(G, fc_node, valid_activations)
            if match is None or len(match.order) < 2:
                continue
            chain = list(match.order)
            linear = match.linear
            tail = chain[-1]
            LOG.info("fusing nodes %s", ",".join(
                member.name for member in chain))
            fused_any = True

            subgraph = GraphView()
            for upstream, downstream in zip(chain, chain[1:]):
                subgraph.add_edge(
                    NNEdge(from_node=upstream, to_node=downstream))

            fused = LinearFusionParameters(
                linear.name + '_fusion',
                fusion_type=match.fusion_type,
                subgraph=subgraph,
                # linear inputs 0..2 (presumably input, weights, biases)
                input_mapping=[[(linear, in_idx)] for in_idx in range(3)],
                output_mapping=[(tail, 0)])

            if G.quantization:
                # TODO - stats
                qrecs = G.quantization.get_all(fused.contained_nodes())
                if qrecs:
                    fused_qrec = QRec.copy_ktype(
                        qrecs[0], in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    for inner in fused.contained_nodes():
                        G.quantization.move_to_fusion(inner, fused)
                    G.quantization[NodeId(fused)] = fused_qrec

            linear_in_edges = G.in_edges(linear.name)
            tail_out_edges = G.out_edges(tail.name)
            for member in chain:
                G.remove(member)
            for edge in linear_in_edges:
                G.add_edge(NNEdge(edge.from_node, fused,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in tail_out_edges:
                G.add_edge(NNEdge(fused, edge.to_node,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return fused_any
コード例 #8
0
def construct_subgraph(G, nodes):
    """Construct a subgraph from ``nodes`` and the edges of ``G`` between them.

    Every node is added to a new ``GraphView``. Every edge of ``G`` whose two
    endpoints are both in ``nodes`` is cloned into it exactly once.

    Args:
        G: source graph; only read.
        nodes: nodes of ``G`` to include; duplicates are ignored. The
            argument is no longer consumed: the previous version popped
            from the caller's list and left it empty.

    Returns:
        The new ``GraphView``.
    """
    sub_g = GraphView()
    # Private, de-duplicated copy that keeps the caller's order.
    ordered = list(dict.fromkeys(nodes))
    # Nodes not processed yet. An edge is added when its first endpoint is
    # processed, because the other endpoint is still pending at that point.
    # This way each internal edge is cloned exactly once.
    pending = set(ordered)
    for node in ordered:
        pending.discard(node)
        if node not in sub_g.nodes():
            sub_g.add_node(node)
        for edge in G.out_edges(node.name):
            if edge.to_node in pending:
                sub_g.add_edge(edge.clone())
        for edge in G.in_edges(node.name):
            if edge.from_node in pending:
                sub_g.add_edge(edge.clone())
    return sub_g
コード例 #9
0
ファイル: expression_matcher.py プロジェクト: brupa9/gap_sdk
    def match(self, G: GraphView, set_identity: bool = True):
        """Replace each node set from ``self.find_sets(G)`` with one expression node.

        ``group_edges`` splits each set's edges into external inputs,
        external outputs and internal edges. The internal edges form the
        subgraph of a new ``ExpressionFusionParameters`` named
        ``expr_<n>``. ``G.replace_fragment`` then splices that node into
        ``G`` in place of the set. ``self._expr_num`` is advanced once per
        set.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.

        Returns:
            True if at least one set was replaced, otherwise False.
        """
        changed = False
        for node_set in self.find_sets(G):
            changed = True
            in_edges, out_edges, internal_edges = group_edges(G, node_set)

            frag = GraphView()
            for edge in internal_edges:
                frag.add_edge(edge)

            # in_edges / out_edges are keyed by the producing
            # (node, output index); each value groups the edges it drives.
            in_sources = list(in_edges)
            out_sources = list(out_edges)
            in_mapping = [[(edge.to_node, edge.to_idx) for edge in group]
                          for group in in_edges.values()]
            in_dims = [src.out_dims[src_idx] for src, src_idx in in_sources]
            out_dims = [src.out_dims[src_idx] for src, src_idx in out_sources]
            constant_inputs = [src for src, _ in in_sources
                               if isinstance(src, ConstantInputParameters)]

            LOG.info('matched expression - creating expression %s',
                     self._expr_num)
            expr = ExpressionFusionParameters(f"expr_{self._expr_num}",
                                              subgraph=frag,
                                              input_mapping=in_mapping,
                                              output_mapping=out_sources,
                                              in_dims=in_dims,
                                              out_dims=out_dims,
                                              constant_inputs=constant_inputs)

            consumers = [[(edge.to_node, edge.to_idx) for edge in group]
                         for group in out_edges.values()]
            G.replace_fragment(
                frag,
                expr,
                frag_in_edges=list(set.union(*in_edges.values())),
                frag_out_edges=list(set.union(*out_edges.values())),
                edge_in_mapping=list(in_sources),
                edge_out_mapping=consumers)
            self._expr_num += 1

        if set_identity:
            self.set_identity(G)

        return changed
コード例 #10
0
ファイル: expression_matcher.py プロジェクト: dilawar/gap_sdk
    def match(self, G: GraphView, set_identity: bool = True):
        """Replace connected groups of fusable nodes with expression nodes.

        Candidate nodes are instances of ``FUSE_NODES``, plus
        ``ConstantInputParameters`` whose first output has size 1. The
        candidates are grouped with ``group_nodes``, and groups made only of
        constants are dropped. Each remaining group is replaced in ``G`` by an
        ``ExpressionFusionParameters`` named ``expr_<self._expr_num>``.

        Fusion input ``i`` must correspond to the ``i``-th external source
        everywhere. One ordering (the dict order of ``in_edges``) is therefore
        used for ``input_mapping``, ``constant_inputs`` and
        ``edge_in_mapping``. Previously ``edge_in_mapping`` alone was sorted
        by output index, which could wire external sources to the wrong
        fusion inputs.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.

        Returns:
            True if at least one group was replaced, otherwise False.
        """
        has_modified_graph = False
        # collect connected node sets
        node_sets = group_nodes(G, [
            node for node in G.nodes() if isinstance(node, FUSE_NODES) or (
                isinstance(node, ConstantInputParameters)
                and node.out_dims[0].size() == 1)
        ])
        # remove sets that are only ConstantInputs
        node_sets = [
            node_set for node_set in node_sets if not all(
                isinstance(node, ConstantInputParameters) for node in node_set)
        ]
        for node_set in node_sets:
            has_modified_graph = True
            in_edges, out_edges, internal_edges = group_edges(G, node_set)
            frag = GraphView()
            for edge in internal_edges:
                frag.add_edge(edge)
            # External sources keyed by (from_node, from_idx), in one fixed
            # order shared by every per-input list below.
            in_sources = list(in_edges.keys())
            in_mapping = [[(edge.to_node, edge.to_idx) for edge in in_edges[source]]
                          for source in in_sources]
            out_mapping = list(out_edges.keys())
            constant_inputs = [
                isinstance(node_edge_idx[0], ConstantInputParameters)
                for node_edge_idx in in_sources
            ]
            expr = ExpressionFusionParameters("expr_%s" % self._expr_num,
                                              subgraph=frag,
                                              input_mapping=in_mapping,
                                              output_mapping=out_mapping,
                                              constant_inputs=constant_inputs)
            G.replace_fragment(
                frag,
                expr,
                frag_in_edges=list(set.union(*in_edges.values())),
                frag_out_edges=list(set.union(*out_edges.values())),
                edge_in_mapping=list(in_sources),
                edge_out_mapping=[[(edge.to_node, edge.to_idx)
                                   for edge in edge_set]
                                  for edge_set in out_edges.values()])

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
コード例 #11
0
def find_concats_up(G, concat, subgraph: GraphView = None):
    """Collect the tree of concats that feeds ``concat``.

    Produces a subgraph of concats operating on axis 0, separated by copies
    or reshapes. The output node is the final concat. The input nodes are
    all the inputs of a single condensed concat that could replace this
    subgraph.

    For every input edge of ``concat``:
    - If ``traverse_to_concat`` returns a path, the path's edges are added.
    - Otherwise a ``DummyInput`` named ``<source name>_<source output idx>``
      is wired to the same input of ``concat``.

    Args:
        G: graph being searched.
        concat: the concat node to start from.
        subgraph: graph to add to; a new ``GraphView`` is created when None.

    Returns:
        ``subgraph``.
    """
    if subgraph is None:
        subgraph = GraphView()
    for edge in G.indexed_in_edges(concat.name):
        edge_path = traverse_to_concat(G, edge, subgraph)
        if edge_path:
            for inter_edge in edge_path:
                subgraph.add_edge(inter_edge)
        else:
            subgraph.add_edge(
                NNEdge(from_node=DummyInput(
                    f"{edge.from_node.name}_{edge.from_idx}", edge),
                       to_node=edge.to_node,
                       to_idx=edge.to_idx))
    return subgraph
コード例 #12
0
ファイル: find_hsigmoid.py プロジェクト: dilawar/gap_sdk
    def match_function(self, G: GraphView):
        """Match the hard-sigmoid pattern relu6(x) * (1/6).

        The pattern has three nodes:
        - '0' is a ReluActivationParameters node with ``upper_bound == 6``
          and feeds input 0 of '1'.
        - '1' is a MatrixMulParameters node.
        - '2' is a ConstantInputParameters node for which
          ``check_equals(G, node, 1.0 / 6.0)`` holds; it feeds input 1 of '1'.
        Returns the result of ``G.match_fragment`` for this pattern.
        """
        def is_relu6(node):
            return (isinstance(node, ReluActivationParameters)
                    and node.upper_bound == 6)

        def is_mul(node):
            return isinstance(node, MatrixMulParameters)

        def is_one_sixth(node):
            return (isinstance(node, ConstantInputParameters)
                    and check_equals(G, node, 1.0 / 6.0))

        pattern = GraphView()
        for name, matcher in (('0', is_relu6), ('1', is_mul), ('2', is_one_sixth)):
            pattern.add_node(MatchNode(name, matcher=matcher))
        pattern.add_edge(Edge('0', '1', to_idx=0))
        pattern.add_edge(Edge('2', '1', to_idx=1))

        return G.match_fragment(pattern)
コード例 #13
0
    def match_function(self, G: GraphView):
        """Match a Conv2D node followed by an optional activation and/or pool.

        Node '0' is a Conv2DParameters node accepted by
        ``self.valid_activation``. After it come one or two more nodes,
        forming a linear chain '0' -> '1' [-> '2']:
        - both match flags set: activation then pool when
          ``self.pool_after_activation`` is true, otherwise pool then
          activation;
        - only ``self.match_activation``: an activation;
        - only ``self.match_pool``: a pool;
        - neither flag: the conv alone.
        Returns the result of ``G.match_fragment`` for this pattern.
        """
        pattern = GraphView()
        pattern.add_node(MatchNode('0', matcher=lambda params:
                                   isinstance(params, Conv2DParameters) and
                                   self.valid_activation(params)))

        if self.match_activation and self.match_pool:
            if self.pool_after_activation:
                adders = (self.add_activation, self.add_pooling)
            else:
                adders = (self.add_pooling, self.add_activation)
        elif self.match_activation:
            adders = (self.add_activation,)
        elif self.match_pool:
            adders = (self.add_pooling,)
        else:
            adders = ()

        names = [str(position) for position in range(1, len(adders) + 1)]
        # add every node first, then link them in order
        for name, adder in zip(names, adders):
            adder(name, pattern)
        for source, target in zip(['0'] + names, names):
            pattern.add_edge(Edge(source, target))

        return G.match_fragment(pattern)
コード例 #14
0
def construct_subgraph(G, nodes):
    """Wrap ``nodes`` of ``G`` in a subgraph with explicit fusion inputs/outputs.

    Only edges of ``G`` between two nodes of ``nodes`` are copied.

    Inputs: each distinct external source ``(from_node, from_idx)`` feeding
    a subgraph input node gets a ``FusionInputParameters`` node. That node
    is wired to the same consumers.

    Outputs: for each subgraph output node and each of its output indexes
    used in ``G``, a ``FusionOutputParameters`` node is attached.

    NOTE(review), both inherited from the original:
    - Nodes that have no internal edge never enter the subgraph.
    - External edges of nodes that also have internal in/out edges are
      not mapped.

    Returns:
        ``(subg, inputs_map, outputs_map)``:
        - ``inputs_map`` is a list of ``(from_node, from_idx)``, one per
          fusion input.
        - ``outputs_map`` has one list per output node. That list holds, for
          each output index, the list of ``(to_node, to_idx)`` consumers
          in ``G``.
    """
    subg = GraphView()
    nodes = set(nodes)
    for node in nodes:
        for edge in G.out_edges(node.name):
            # only add internal edges
            if edge.to_node in nodes:
                subg.add_edge(NNEdge(from_node=edge.from_node, to_node=edge.to_node,
                                     from_idx=edge.from_idx, to_idx=edge.to_idx))

    def red_fn(state, edge):
        # group consumers by the (node, output index) that feeds them
        state.setdefault((edge.from_node, edge.from_idx), []
                         ).append((edge.to_node, edge.to_idx))
        return state

    inputs = reduce(red_fn, set([edge for node in subg.inputs()
                                 for edge in G.in_edges(node.name)]), {})
    inputs_map = []
    for (fnode, fidx), outs in inputs.items():
        inp = FusionInputParameters(
            f'{fnode.name}_{fidx}_in', dims=fnode.out_dims[fidx])
        inputs_map.append((fnode, fidx))
        for (tnode, tidx) in outs:
            subg.add_edge(NNEdge(from_node=inp, to_node=tnode, to_idx=tidx))
    outputs = [(node, sorted(set(edge.from_idx for edge in G.out_edges(node.name))))
               for node in subg.outputs()]
    outputs_map = []
    for (node, fidxes) in outputs:
        output_map = []
        outputs_map.append(output_map)
        for fidx in fidxes:
            # Materialise the consumer list now. Previously a generator was
            # appended here: its `from_idx == fidx` filter was evaluated
            # lazily, so it saw the last value of `fidx`, and it could only
            # be iterated once.
            output_map.append([(edge.to_node, edge.to_idx)
                               for edge in G.out_edges(node.name)
                               if edge.from_idx == fidx])
            outp = FusionOutputParameters(
                f'{node.name}_{fidx}_out', dims=node.out_dims[fidx])
            subg.add_edge(NNEdge(from_node=node, to_node=outp, from_idx=fidx))
    return (subg, inputs_map, outputs_map)
コード例 #15
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse each Conv2D node with the activation/pooling nodes after it.

        The accepted activations come from ``kwargs['group_identity']``:
        'pow2_match_group' selects ``VALID_ACTIVATIONS_POW2``; anything else
        selects ``VALID_ACTIVATIONS_SQ8``.

        Special cases:
        - For a 'conv_active_pool' match with average pooling, the pool is
          dropped from the chain.
        - A 'conv_pool_active' match with average pooling and a non-"relu"
          activation is skipped.

        The fused node gets:
        - deep copies of the conv's ``in_dims`` and of the tail's
          ``out_dims``;
        - a ``QRec`` copied from the first inner record, with deep-copied
          ``in_qs`` (first record) and ``out_qs`` (last record).

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: only 'group_identity' is read.

        Returns:
            True if at least one chain was fused, otherwise False.
        """
        use_pow2 = kwargs.get('group_identity') == 'pow2_match_group'
        valid_activations = (VALID_ACTIVATIONS_POW2 if use_pow2
                             else VALID_ACTIVATIONS_SQ8)
        fused_any = False
        conv_nodes = [params for params in G.nodes()
                      if isinstance(params, Conv2DParameters)]
        for conv_node in conv_nodes:
            match = self.get_node_list(G, conv_node, valid_activations)
            if match is None or len(match.order) < 2:
                continue
            if (match.fusion_type == 'conv_active_pool'
                    and match.pool.pool_type == "average"):
                # keep conv + activation only; the average pool stays outside
                match.order = match.order[:2:]
                match.pool = None
            elif (match.fusion_type == 'conv_pool_active'
                  and match.pool.pool_type == "average"
                  and match.active.activation != "relu"):
                # NOTE: This is only for old POW2 kernels - SQ8 can handle this
                continue

            chain = list(match.order)
            conv = match.conv
            tail = chain[-1]
            LOG.info("fusing nodes %s", ",".join(
                member.name for member in chain))
            fused_any = True

            subgraph = GraphView()
            for upstream, downstream in zip(chain, chain[1:]):
                subgraph.add_edge(NNEdge(from_node=upstream,
                                         to_node=downstream))

            fused = ConvFusionParameters(
                conv.name + '_fusion',
                fusion_type=match.fusion_type,
                subgraph=subgraph,
                in_dims_hint=conv.in_dims_hint,
                out_dims_hint=conv.out_dims_hint,
                in_dims=deepcopy(conv.in_dims),
                out_dims=deepcopy(tail.out_dims),
                # conv inputs 0..2 (presumably input, weights, biases)
                input_mapping=[[(conv, in_idx)] for in_idx in range(3)],
                output_mapping=[(tail, 0)])

            if G.quantization:
                qrecs = G.quantization.get_all(fused.contained_nodes())
                if qrecs:
                    # TODO - stats
                    fused_qrec = QRec.copy_ktype(
                        qrecs[0],
                        in_qs=deepcopy(qrecs[0].in_qs),
                        out_qs=deepcopy(qrecs[-1].out_qs))
                    for inner in fused.contained_nodes():
                        G.quantization.move_to_fusion(inner, fused)
                    G.quantization[NodeId(fused)] = fused_qrec

            conv_in_edges = G.in_edges(conv.name)
            tail_out_edges = G.out_edges(tail.name)
            for member in chain:
                G.remove(member)
            for edge in conv_in_edges:
                G.add_edge(NNEdge(edge.from_node, fused,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in tail_out_edges:
                G.add_edge(NNEdge(fused, edge.to_node,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return fused_any
コード例 #16
0
ファイル: expression_matcher.py プロジェクト: mfkiwl/gap_sdk
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Replace each node set from ``self.find_sets(G)`` with an expression node.

        For every set, a fragment is built from the set's nodes and internal
        edges. That fragment is swapped in ``G`` for a new
        ``ExpressionFusionParameters`` named via
        ``G.unique_name(f"expr_{n}")``.

        When ``G`` carries quantization:
        - a 'scaled' ``QRec`` is stored for the expression, built from the
          existing input/output qtypes and their min/max stats;
        - ``QuantizeParameters`` nodes directly after the expression are
          removed (with their qrecs);
        - after all sets are processed, ``G.quantization`` is rebuilt by
          ``UnifiedQuantizer`` starting from the new expressions.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.
            **kwargs: not read by this method.

        Returns:
            True if at least one node set was fused, otherwise False.
        """
        has_modified_graph = False
        to_quantize = []  # expressions to requantize after the loop
        node_sets = self.find_sets(G)
        for node_set in node_sets:
            # Fresh SymbolStats control per set; its stats are read back below
            # through Symbol.CURRENT_CONTROL (presumably this same control —
            # TODO confirm).
            Symbol.set_default_control(SymbolStats())
            has_modified_graph = True
            # in_edges / out_edges are keyed by (from_node, from_idx) (see the
            # unpacking below); their values are sets of edges (see set.union).
            in_edges, out_edges, internal_edges = group_edges(G, node_set)
            frag = GraphView()
            for node in node_set:
                frag.add_node(node)
            for edge in internal_edges:
                frag.add_edge(edge)
            # For each external source: the internal (node, input idx) it feeds.
            in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                          for edge_group in in_edges.values()]
            in_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in in_edges
            ]
            out_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in out_edges
            ]
            out_mapping = list(out_edges.keys())
            constant_inputs = [
                node_edge_idx[0] for node_edge_idx in in_edges
                if isinstance(node_edge_idx[0], ConstantInputParameters)
            ]
            LOG.debug(
                "inputs coming from: %s",
                ",".join(f"{from_node.__repr__()}:{from_idx}"
                         for from_node, from_idx in in_edges))
            LOG.info("fusing nodes: %s into expr_%s",
                     ",".join(node.__repr__() for node in node_set),
                     self._expr_num)
            expr = ExpressionFusionParameters(
                G.unique_name(f"expr_{self._expr_num}"),
                subgraph=frag,
                qrecs=G.quantization,
                input_mapping=in_mapping,
                output_mapping=out_mapping,
                in_dims=in_dims,
                out_dims=out_dims,
                constant_inputs=constant_inputs)
            in_edge_mapping = list(in_edges.keys())
            out_edge_mapping = [[(edge.to_node, edge.to_idx)
                                 for edge in edge_set]
                                for edge_set in out_edges.values()]
            # NOTE(review): set.union(*...) raises TypeError if in_edges or
            # out_edges is empty — presumably find_sets never yields such a set.
            G.replace_fragment(
                frag,
                expr,
                frag_in_edges=list(set.union(*in_edges.values())),
                frag_out_edges=list(set.union(*out_edges.values())),
                edge_in_mapping=in_edge_mapping,
                edge_out_mapping=out_edge_mapping,
                edge_class=NNEdge)
            if G.quantization:
                qrecs = G.quantization
                # Qtype of each fusion input: the in_q of the first internal
                # consumer of that input.
                in_qs = [
                    qrecs[NodeId(in_map[0][0])].in_qs[in_map[0][1]]
                    for in_map in in_mapping
                ]
                # Qtype of each fusion output: the producing node's out_q.
                out_qs = [
                    qrecs[NodeId(node)].out_qs[idx]
                    for node, idx in out_mapping
                ]
                # Record min/max of each input/output qtype as stats of the
                # matching expression symbol.
                stats = Symbol.CURRENT_CONTROL.stats
                func_col = expr.func_col
                for idx, qtype in enumerate(in_qs):
                    symbol = func_col.variables[func_col.input_names[idx]]
                    stats[symbol.name] = {
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    }
                for idx, qtype in enumerate(out_qs):
                    symbol = func_col.variables[func_col.output_names[idx]]
                    stats[symbol.name] = {
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    }
                G.quantization[NodeId(expr)] = QRec(in_qs=in_qs,
                                                    out_qs=out_qs,
                                                    expression=stats,
                                                    ktype='scaled')
                # delete any quantize parameters on outputs to allow the quantizer
                # to fuse them into the expression
                out_edges = G.out_edges(expr.name)
                for edge in out_edges:
                    if isinstance(edge.to_node, QuantizeParameters):
                        G.remove_and_reconnect(edge.to_node)
                        if NodeId(edge.to_node) in G.quantization:
                            del G.quantization[NodeId(edge.to_node)]
                to_quantize.append(expr)

            self._expr_num += 1

        if to_quantize:
            # Rebuild the graph quantization starting from the new expressions.
            quantizer = UnifiedQuantizer.from_quantized_graph(G)
            G.quantization = quantizer.quantize(G, start_nodes=to_quantize)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
コード例 #17
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse a pad feeding an add (plus an optional activation) into one node.

        For every ``PadParameters`` node, ``self.get_node_list`` looks for a
        pad -> add [-> activation] chain; matches with fewer than two nodes
        are skipped. Each match is replaced by a ``PaddedAddFusionParameters``
        named ``"PADDED_" + add.name``.

        Fusion input ``i`` corresponds to the add's input ``i``. The padded
        side goes through the pad's only input, index 0. External edges are
        reconnected to the fusion input they are mapped to.

        Previously, when the pad fed add input 1:
        - the mapping pointed at a non-existent pad input 1;
        - both external edges were reconnected to fusion input 0, because
          the pad's own in edge has ``to_idx == 0``.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted for interface compatibility; not read here.

        Returns:
            True if at least one chain was fused, otherwise False.
        """
        has_modified_graph = False
        for pad_node in [
                params for params in G.nodes()
                if isinstance(params, PadParameters)
        ]:
            node_list = self.get_node_list(G, pad_node)
            if node_list is None or len(node_list.order) < 2:
                continue
            LOG.info("fusing nodes %s", ",".join(
                (node.name for node in node_list.order)))
            has_modified_graph = True
            subgraph = GraphView()
            # add input index driven by the pad's output
            padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
            subgraph.add_edge(
                NNEdge(from_node=node_list.pad,
                       to_node=node_list.add,
                       to_idx=padded_input_idx))
            last_node = node_list.add
            node_list.add.force_quantized_index = 0
            if node_list.active:
                subgraph.add_edge(
                    NNEdge(from_node=node_list.add, to_node=node_list.active))
                last_node = node_list.active
            # Per fusion input: the (internal node, internal input index) it
            # drives. The pad has a single input, index 0.
            if padded_input_idx == 0:
                input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
            else:
                input_mapping = [[(node_list.add, 0)], [(node_list.pad, 0)]]

            output_mapping = [(last_node, 0)]
            pnode = PaddedAddFusionParameters(
                "PADDED_" + node_list.add.name,
                fusion_type=node_list.fusion_type,
                subgraph=subgraph,
                input_mapping=input_mapping,
                output_mapping=output_mapping)
            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                # if there are quantization stats then clear them. They need to be created again
                G.quantization.stats = None
                if qrecs:
                    prec = QRec.copy_ktype(qrecs[0],
                                           in_qs=qrecs[0].in_qs,
                                           out_qs=qrecs[-1].out_qs)
                    for node in pnode.contained_nodes():
                        G.quantization.move_to_fusion(node, pnode)
                    G.quantization[NodeId(pnode)] = prec
            # Pair each external in edge with the fusion input it must feed.
            pad_in_edges = G.in_edges(node_list.pad.name)
            add_in_edges = G.indexed_in_edges(node_list.add.name)
            if padded_input_idx == 0:
                in_edges = ([(edge, 0) for edge in pad_in_edges] +
                            [(edge, edge.to_idx) for edge in add_in_edges[1::]])
            else:
                in_edges = ([(edge, edge.to_idx) for edge in add_in_edges[0:1:]] +
                            [(edge, 1) for edge in pad_in_edges])
            out_edges = G.out_edges(last_node.name)
            for node in node_list.order:
                G.remove(node)
            for edge, fusion_idx in in_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           pnode,
                           from_idx=edge.from_idx,
                           to_idx=fusion_idx))
            for edge in out_edges:
                G.add_edge(
                    NNEdge(pnode,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
コード例 #18
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse Pad -> Add (-> activation) chains into PaddedAddFusionParameters.

        Every PadParameters node in ``G`` is offered to ``self.get_node_list``.
        Chains with at least two nodes are replaced by a single
        ``PADDED_<add name>`` fusion node. The fusion's subgraph holds the
        pad, the add and, when present, the activation. Quantization records
        of the fused nodes are moved into the fusion.

        Args:
            G: graph rewritten in place.
            set_identity: when True, ``self.set_identity(G)`` is called once
                all chains have been processed.

        Returns:
            True if at least one chain was fused, False otherwise.
        """
        has_modified_graph = False
        # Materialize the candidates up front because G is modified below.
        for pad_node in [
                params for params in G.nodes()
                if isinstance(params, PadParameters)
        ]:
            node_list = self.get_node_list(G, pad_node)
            if node_list is None or len(node_list.order) < 2:
                continue
            LOG.info("fusing nodes %s", ",".join(
                (node.name for node in node_list.order)))
            has_modified_graph = True
            subgraph = GraphView()
            # Input slot of the add that is fed by the pad.
            padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
            subgraph.add_edge(
                NNEdge(from_node=node_list.pad,
                       to_node=node_list.add,
                       to_idx=padded_input_idx))
            last_node = node_list.add
            node_list.add.force_quantized_index = 0
            if node_list.active:
                subgraph.add_edge(
                    NNEdge(from_node=node_list.add, to_node=node_list.active))
                last_node = node_list.active
            # Fusion input i keeps the position it had on the add. Each
            # mapping entry is (internal node, that node's input index). The
            # pad is assumed to have exactly one input, so the fusion input
            # that reaches it always targets pad input 0.
            if padded_input_idx == 0:
                input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
            else:
                input_mapping = [[(node_list.add, 0)], [(node_list.pad, 0)]]

            output_mapping = [(last_node, 0)]
            pnode = PaddedAddFusionParameters(
                "PADDED_" + node_list.add.name,
                fusion_type=node_list.fusion_type,
                subgraph=subgraph,
                input_mapping=input_mapping,
                output_mapping=output_mapping)
            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    prec = None
                    if isinstance(qrecs[0],
                                  (SymmetricQuantizationRecord,
                                   SymmetricScalableFilterQuantizationRecord)):
                        prec = SymmetricQuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    elif isinstance(qrecs[0],
                                    (MultQuantizationRecord,
                                     MultScalableFilterQuantizationRecord)):
                        prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs,
                                                      out_qs=qrecs[-1].out_qs)
                    elif isinstance(qrecs[0],
                                    (Float32QuantizationRecord,
                                     Float32ScalableFilterQuantizationRecord)):
                        prec = Float32QuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    # NOTE(review): if qrecs[0] matches none of the kinds
                    # above, prec stays None and None is stored as the
                    # fusion's record - confirm this cannot happen.
                    for node in pnode.contained_nodes():
                        G.quantization.move_to_fusion(node, pnode)
                    G.quantization[NodeId(pnode)] = prec
            # Pair every external input edge with the fusion input it must
            # feed. The pad's incoming edge lands on pad input 0, so its own
            # to_idx cannot be reused. It has to feed fusion input
            # padded_input_idx; otherwise, when padded_input_idx == 1, it
            # would collide with the add's unpadded input on fusion input 0.
            pad_feeds = [(edge, padded_input_idx)
                         for edge in G.in_edges(node_list.pad.name)]
            add_in_edges = G.indexed_in_edges(node_list.add.name)
            if padded_input_idx == 0:
                in_edges = pad_feeds + [(edge, edge.to_idx)
                                        for edge in add_in_edges[1:]]
            else:
                in_edges = [(edge, edge.to_idx)
                            for edge in add_in_edges[:1]] + pad_feeds
            out_edges = G.out_edges(last_node.name)
            for node in node_list.order:
                G.remove(node)
            for edge, to_idx in in_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           pnode,
                           from_idx=edge.from_idx,
                           to_idx=to_idx))
            for edge in out_edges:
                G.add_edge(
                    NNEdge(pnode,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
コード例 #19
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Collapse MatMul (+ constant-bias Add) (+ activation) chains.

        Each MatMulOpParameters node is offered to ``self.get_node_list``.
        Chains of two or more nodes become one MatMulOpFusionParameters node
        named ``<matmul name>_fusion``. When the chain contains an add, the
        add's ConstantInputParameters producer is rewired to fusion input 2.

        Args:
            G: graph rewritten in place.
            set_identity: when True, ``self.set_identity(G)`` runs at the end.

        Returns:
            True when at least one chain was fused.
        """
        modified = False
        candidates = [params for params in G.nodes()
                      if isinstance(params, MatMulOpParameters)]
        for matmul_node in candidates:
            chain = self.get_node_list(G, matmul_node)
            if chain is None or len(chain.order) < 2:
                continue
            LOG.info("fusing nodes %s",
                     ",".join(member.name for member in chain.order))
            modified = True

            matmul = chain.matmul
            activation = chain.active
            subgraph = GraphView()
            if activation is not None:
                subgraph.add_edge(NNEdge(from_node=matmul,
                                         to_node=activation))
            # The two operands go straight to the matmul; a fused bias is
            # presented as its third input.
            input_mapping = [[(matmul, 0)], [(matmul, 1)]]
            if chain.add:
                input_mapping.append([(matmul, 2)])
            tail = activation if activation else matmul
            pnode = MatMulOpFusionParameters(matmul.name + '_fusion',
                                             fusion_type=chain.fusion_type,
                                             subgraph=subgraph,
                                             input_mapping=input_mapping,
                                             output_mapping=[(tail, 0)])

            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    head_qrec = qrecs[0]
                    tail_qrec = qrecs[-1]
                    # Ordered (accepted record kinds, record to build) pairs;
                    # the first kind matching the head record wins.
                    record_kinds = (
                        ((SymmetricQuantizationRecord,
                          SymmetricScalableFilterQuantizationRecord),
                         SymmetricQuantizationRecord),
                        ((MultQuantizationRecord,
                          MultScalableFilterQuantizationRecord),
                         MultQuantizationRecord),
                        ((Float32QuantizationRecord,
                          Float32ScalableFilterQuantizationRecord),
                         Float32QuantizationRecord),
                    )
                    prec = None
                    for accepted, record_cls in record_kinds:
                        if isinstance(head_qrec, accepted):
                            prec = record_cls(in_qs=head_qrec.in_qs,
                                              out_qs=tail_qrec.out_qs)
                            break
                    for member in pnode.contained_nodes():
                        G.quantization.move_to_fusion(member, pnode)
                    G.quantization[NodeId(pnode)] = prec

            # Capture the surrounding connections before the chain is removed.
            in_edges = G.in_edges(matmul.name)
            if chain.add:
                bias_edge = [
                    candidate for candidate in G.in_edges(chain.add.name)
                    if isinstance(candidate.from_node, ConstantInputParameters)
                ][0]
            out_edges = G.out_edges(chain.order[-1].name)
            for member in chain.order:
                G.remove(member)
            for edge in in_edges:
                G.add_edge(NNEdge(edge.from_node, pnode,
                                  from_idx=edge.from_idx,
                                  to_idx=edge.to_idx))
            if chain.add:
                G.add_edge(NNEdge(bias_edge.from_node, pnode,
                                  from_idx=bias_edge.from_idx,
                                  to_idx=2))
            for edge in out_edges:
                G.add_edge(NNEdge(pnode, edge.to_node,
                                  from_idx=edge.from_idx,
                                  to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return modified
コード例 #20
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse (Global)Pooling -> activation chains into ActivationFusion nodes.

        The activation tables passed to ``self.get_node_list`` are chosen from
        ``kwargs['group_identity']``. The POW2 tables are used for
        ``'pow2_match_group'`` and the SQ8 tables otherwise. Chains of at
        least two nodes are replaced by one ``<pool name>_fusion`` node.

        Args:
            G: graph rewritten in place.
            set_identity: when True, ``self.set_identity(G)`` runs at the end.
            **kwargs: only ``group_identity`` is read.

        Returns:
            True when at least one chain was fused.
        """
        if kwargs.get('group_identity') == 'pow2_match_group':
            activations, activations_wo_pool = (
                VALID_ACTIVATIONS_POW2, VALID_ACTIVATIONS_POW2_WO_POOL)
        else:
            activations, activations_wo_pool = (
                VALID_ACTIVATIONS_SQ8, VALID_ACTIVATIONS_SQ8_WO_POOL)
        fused_any = False
        for pool_node in G.nodes(node_classes=(PoolingParameters,
                                               GlobalPoolingParameters)):
            chain = self.get_node_list(G, pool_node, activations,
                                       activations_wo_pool)
            if chain is None or len(chain.order) < 2:
                continue
            LOG.info("fusing nodes %s",
                     ",".join(member.name for member in chain.order))
            fused_any = True

            # Link consecutive chain members: order[0] -> order[1] -> ...
            subgraph = GraphView()
            for upstream, downstream in zip(chain.order, chain.order[1:]):
                subgraph.add_edge(NNEdge(from_node=upstream,
                                         to_node=downstream))
            tail = chain.order[-1]
            pnode = ActivationFusion(chain.pool.name + '_fusion',
                                     fusion_type=chain.fusion_type,
                                     subgraph=subgraph,
                                     input_mapping=[[(chain.pool, 0)]],
                                     output_mapping=[(tail, 0)])

            if G.quantization:
                # TODO - stats
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    prec = QRec.copy_ktype(qrecs[0],
                                           in_qs=qrecs[0].in_qs,
                                           out_qs=qrecs[-1].out_qs)
                    for member in pnode.contained_nodes():
                        G.quantization.move_to_fusion(member, pnode)
                        if not isinstance(member, GlobalPoolingParameters):
                            continue
                        # A global pool fused with activations keeps only the
                        # activation scale: its output qtype becomes a copy
                        # of its input qtype with dtype set to int32.
                        member_qrec = G.quantization[NodeId(pnode, member)]
                        member_qrec.out_qs[0] = deepcopy(member_qrec.in_qs[0])
                        member_qrec.out_qs[0].dtype = np.int32
                    G.quantization[NodeId(pnode)] = prec

            # Capture the surrounding connections before the chain is removed.
            in_edges = G.in_edges(chain.pool.name)
            out_edges = G.out_edges(tail.name)
            for member in chain.order:
                G.remove(member)
            for edge in in_edges:
                G.add_edge(NNEdge(edge.from_node, pnode,
                                  from_idx=edge.from_idx,
                                  to_idx=edge.to_idx))
            for edge in out_edges:
                G.add_edge(NNEdge(pnode, edge.to_node,
                                  from_idx=edge.from_idx,
                                  to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return fused_any
コード例 #21
0
ファイル: remove_noops.py プロジェクト: dilawar/gap_sdk
 def match_function(self, G: GraphView):
     """Return the matches of a one-node NoOPMatcher fragment within ``G``."""
     pattern = GraphView()
     noop_node = NoOPMatcher('0')
     pattern.add_node(noop_node)
     matches = G.match_fragment(pattern)
     return matches