예제 #1
0
 def replace_function(self, G: NNGraph, subgraph: GraphView):
     """Build the fusion node that replaces a matched linear subgraph.

     The fusion node is named after the first FcParameters node found in
     ``subgraph`` (its name plus ``"_fusion"``). When ``G`` carries
     quantization records, the records of the contained nodes are moved
     under the fusion node and a record spanning the first node's inputs
     and the last node's outputs is stored for the fusion node itself.

     Args:
         G: graph being rewritten; only ``G.quantization`` is used here.
         subgraph: the matched nodes that will be wrapped by the fusion.

     Returns:
         The tuple ``(pnode, None, None)`` where ``pnode`` is the new
         ConvFusionParameters node.

     Raises:
         ValueError: if ``subgraph`` contains no FcParameters node, or if
             the first quantization record is of an unsupported type.
     """
     # NOTE(review): numbering stops at the first FcParameters node, so
     # nodes yielded after it keep whatever step_idx they already had.
     # This is the original behaviour and is preserved on purpose.
     for step, node in enumerate(subgraph.nodes()):
         node.step_idx = step
         if isinstance(node, FcParameters):
             linear_name = node.name + "_fusion"
             break
     else:
         # Previously this fell through and failed later with an
         # UnboundLocalError on linear_name.
         raise ValueError("fusion subgraph contains no FcParameters node")
     LOG.info("fusing nodes %s", ",".join(
         (node.name for node in subgraph.nodes())))
     # simple node order is necessary because nodes() will not necessarily
     # be in order
     pnode = ConvFusionParameters(linear_name, fusion_type="linear_active", subgraph=subgraph)
     if G.quantization:
         qrecs = G.quantization.get_all(pnode.contained_nodes())
         if qrecs:
             first_qrec = qrecs[0]
             # Pick the record class before touching G.quantization so an
             # unsupported type cannot leave the records half migrated.
             if isinstance(first_qrec, (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                 record_class = SymmetricQuantizationRecord
             elif isinstance(first_qrec, (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                 record_class = MultQuantizationRecord
             elif isinstance(first_qrec, (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                 record_class = Float32QuantizationRecord
             else:
                 raise ValueError(
                     "unsupported quantization record type %s" % type(first_qrec).__name__)
             prec = record_class(in_qs=first_qrec.in_qs, out_qs=qrecs[-1].out_qs)
             for node in pnode.contained_nodes():
                 G.quantization.move_to_fusion(node, pnode)
             G.quantization[NodeId(pnode)] = prec
     return pnode, None, None
예제 #2
0
 def replace_function(self, G: GraphView, subgraph: GraphView):
     """Build the fusion node that replaces a matched convolution subgraph.

     The fusion node is named after the first Conv2DParameters node found
     in ``subgraph`` (its name plus ``"_fusion"``) and uses
     ``self.fusion_type``. When ``G`` carries quantization records, the
     records of the contained nodes are moved under the fusion node and a
     record spanning the first node's inputs and the last node's outputs
     is stored for the fusion node itself.

     Args:
         G: graph being rewritten; only ``G.quantization`` is used here.
         subgraph: the matched nodes that will be wrapped by the fusion.

     Returns:
         The tuple ``(pnode, None, None)`` where ``pnode`` is the new
         ConvFusionParameters node.

     Raises:
         DontReplaceError: if ``self.validate_match(subgraph)`` is falsy.
         ValueError: if ``subgraph`` contains no Conv2DParameters node, or
             if the first quantization record is of an unsupported type.
     """
     if not self.validate_match(subgraph):
         raise DontReplaceError()
     # NOTE(review): numbering stops at the first Conv2DParameters node, so
     # nodes yielded after it keep whatever step_idx they already had.
     # This is the original behaviour and is preserved on purpose.
     for step, node in enumerate(subgraph.nodes()):
         node.step_idx = step
         if isinstance(node, Conv2DParameters):
             conv_name = node.name + "_fusion"
             break
     else:
         # Previously this fell through and failed later with an
         # UnboundLocalError on conv_name.
         raise ValueError("fusion subgraph contains no Conv2DParameters node")
     LOG.debug("fused nodes %s", ",".join((node.name for node in subgraph.nodes())))
     # simple node order is necessary because nodes() will not necessarily
     # be in order
     pnode = ConvFusionParameters(conv_name, fusion_type=self.fusion_type, subgraph=subgraph)
     if G.quantization:
         qrecs = G.quantization.get_all(pnode.contained_nodes())
         if qrecs:
             first_qrec = qrecs[0]
             # Pick the record class before touching G.quantization so an
             # unsupported type cannot leave the records half migrated.
             if isinstance(first_qrec, (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                 record_class = SymmetricQuantizationRecord
             elif isinstance(first_qrec, (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                 record_class = MultQuantizationRecord
             elif isinstance(first_qrec, (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                 record_class = Float32QuantizationRecord
             else:
                 raise ValueError(
                     "unsupported quantization record type %s" % type(first_qrec).__name__)
             prec = record_class(in_qs=first_qrec.in_qs, out_qs=qrecs[-1].out_qs)
             for node in pnode.contained_nodes():
                 G.quantization.move_to_fusion(node, pnode)
             G.quantization[NodeId(pnode)] = prec
     return pnode, None, None
예제 #3
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse each convolution in ``G`` with the nodes chained after it.

        For every Conv2DParameters node, ``self.get_node_list`` supplies the
        candidate chain. Chains shorter than two nodes are ignored. The
        chain is wrapped in a ConvFusionParameters node, quantization
        records (if any) are moved under it, and the chain is replaced by
        the fusion node in ``G``.

        Args:
            G: graph to rewrite in place.
            set_identity: when true, ``self.set_identity(G)`` is called
                once all candidates have been processed.

        Returns:
            True if at least one chain was fused, otherwise False.
        """
        # Known record families: (types accepted for the first record,
        # class used for the record of the fusion node).
        record_families = (
            ((SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord),
             SymmetricQuantizationRecord),
            ((MultQuantizationRecord, MultScalableFilterQuantizationRecord),
             MultQuantizationRecord),
            ((Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord),
             Float32QuantizationRecord),
        )
        fused_any = False
        # Snapshot the convolutions first: the graph is edited inside the loop.
        conv_nodes = [candidate for candidate in G.nodes()
                      if isinstance(candidate, Conv2DParameters)]
        for conv_node in conv_nodes:
            found = self.get_node_list(G, conv_node)
            if found is None or len(found.order) < 2:
                continue
            if found.fusion_type == 'conv_active_pool':
                if found.pool.pool_type == "average":
                    # Keep only the first two nodes and drop the pool.
                    found.order = found.order[:2]
                    found.pool = None
            elif found.fusion_type == 'conv_pool_active':
                if found.pool.pool_type == "average" and found.active.activation != "relu":
                    continue
            chain = found.order
            LOG.info("fusing nodes %s", ",".join(member.name for member in chain))
            fused_any = True
            # Link the chain members one after another inside a new subgraph.
            subgraph = GraphView()
            for upstream, downstream in zip(chain, chain[1:]):
                subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
            tail = chain[-1]
            # fusion_type is read again here, after the order may have been
            # shortened above, exactly as the attribute is read at this point
            # in the unfused flow.
            pnode = ConvFusionParameters(
                found.conv.name + '_fusion',
                fusion_type=found.fusion_type,
                subgraph=subgraph,
                in_dims_hint=found.conv.in_dims_hint,
                out_dims_hint=found.conv.out_dims_hint,
                input_mapping=[[(found.conv, idx)] for idx in range(3)],
                output_mapping=[(tail, 0)])
            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    head_qrec = qrecs[0]
                    # Stays None when the first record is of an unknown family.
                    prec = None
                    for accepted_types, record_class in record_families:
                        if isinstance(head_qrec, accepted_types):
                            prec = record_class(in_qs=head_qrec.in_qs,
                                                out_qs=qrecs[-1].out_qs)
                            break
                    for inner in pnode.contained_nodes():
                        G.quantization.move_to_fusion(inner, pnode)
                    G.quantization[NodeId(pnode)] = prec
            # Capture the boundary edges before the chain is removed.
            incoming = G.in_edges(found.conv.name)
            outgoing = G.out_edges(tail.name)
            for member in chain:
                G.remove(member)
            for edge in incoming:
                G.add_edge(NNEdge(edge.from_node, pnode,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in outgoing:
                G.add_edge(NNEdge(pnode, edge.to_node,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return fused_any
예제 #4
0
 def quantize_fusion(self,
                     G: NNGraph,
                     node: ConvFusionParameters,
                     in_qs,
                     force_out=None) -> SymmetricQuantizationRecord:
     """Compute quantization records for a fusion node.

     Only the ``'conv_active'`` fusion type is handled here; every other
     fusion type is delegated to ``default_quantize_fusion``.

     Args:
         G: graph owning the fusion; only forwarded to
             ``default_quantize_fusion``.
         node: the fusion node. For ``'conv_active'`` its first contained
             node is treated as the convolution and its second as the
             activation.
         in_qs: input quantization types; ``in_qs[0]`` is handed to
             ``calculate_filter_q``.
         force_out: optional output constraint; its ``bits`` and ``q``
             attributes are read when it is truthy.

     Returns:
         A tuple ``(fusion_record, per_node_records)``: the record for the
         fusion node as a whole and an OrderedDict of records keyed by
         ``NodeId(node, contained_node)``. The return annotation names only
         the first element.

     Raises:
         NotImplementedError: if the forced output q would require a left
             shift, i.e. it exceeds the activation's maximum q or the
             convolution's output q.
     """
     if node.fusion_type != 'conv_active':
         return self.default_quantize_fusion(G,
                                             node,
                                             in_qs,
                                             force_out=force_out)

     result = OrderedDict()
     nodes = node.contained_nodes()
     conv_node = nodes[0]
     conv_astats = self._activation_stats.get(NodeId(node, conv_node))
     conv_qrec = self.calculate_filter_q(conv_node,
                                         conv_astats,
                                         in_q=in_qs[0],
                                         force_width=self._force_width,
                                         out_as_acc=True)
     result[NodeId(node, conv_node)] = conv_qrec
     conv_out_q = conv_qrec.out_qs[0].q
     act_node = nodes[1]
     act_astats = self._activation_stats.get(NodeId(node, act_node))
     if force_out and force_out.bits:
         act_max_q = self.compute_activation_out_maxq(
             act_node, force_out.bits)
         if force_out.q is not None:
             if (act_max_q is not None and force_out.q > act_max_q
                 ) or force_out.q > conv_out_q:
                 # We cannot shift left in the kernel
                 # TODO - This should try to increase the input q and perhaps the width
                 # Unlikely to happen
                 raise NotImplementedError()
             act_o_q = QType(bits=force_out.bits,
                             q=force_out.q,
                             signed=True)
         else:
             # Only the width is forced. The previous code read act_o_q
             # here before it was ever assigned (UnboundLocalError), so
             # derive it from the activation statistics at the forced
             # width first.
             # NOTE(review): clamping to the convolution output q mirrors
             # the unforced branch below - confirm this is the intent.
             act_o_q = QType.from_min_max(
                 max_val=act_astats['range_out'][0]['max'],
                 min_val=act_astats['range_out'][0]['min'],
                 bits=force_out.bits)
             act_o_q.q = min(act_o_q.q, conv_out_q)
             if act_max_q is not None:
                 act_o_q.q = min(act_max_q, act_o_q.q)
     else:
         act_o_q = QType.from_min_max(
             max_val=act_astats['range_out'][0]['max'],
             min_val=act_astats['range_out'][0]['min'],
             bits=self._force_width)
         act_o_q.q = min(act_o_q.q, conv_out_q)
         # "is not None" (rather than truthiness) so that a forced q of 0
         # is honoured, matching the check in the branch above.
         if force_out and force_out.q is not None:
             # act_max_q was never computed on this path before, which made
             # the comparison below fail with an UnboundLocalError.
             # NOTE(review): assumes QType exposes the chosen width as
             # .bits - confirm against the QType definition.
             act_max_q = self.compute_activation_out_maxq(
                 act_node, act_o_q.bits)
             if (act_max_q is not None and force_out.q > act_max_q
                 ) or force_out.q > conv_out_q:
                 # We cannot shift left in the kernel
                 # TODO - This should try to increase the input q and perhaps the width
                 # Unlikely to happen
                 raise NotImplementedError()
             act_o_q.q = force_out.q
     act_qrec = SymmetricQuantizationRecord(in_qs=conv_qrec.out_qs,
                                            out_qs=[act_o_q])
     result[NodeId(node, act_node)] = act_qrec
     return SymmetricQuantizationRecord(in_qs=in_qs,
                                        out_qs=act_qrec.out_qs), result
예제 #5
0
 def default_quantize_fusion(self,
                             G: NNGraph,
                             node: ConvFusionParameters,
                             in_qs,
                             force_out=None) -> SymmetricQuantizationRecord:
     """Quantize a fusion by quantizing its contained nodes one after another.

     Each contained node is passed to ``self.calculate_q`` with the output
     types of the previous node as its inputs; the first node receives
     ``in_qs``.

     Args:
         G: unused.
         node: fusion node whose ``contained_nodes()`` are quantized.
         in_qs: input quantization types of the fusion.
         force_out: forwarded unchanged to every ``calculate_q`` call.

     Returns:
         A tuple ``(fusion_record, per_node_records)``: a
         SymmetricQuantizationRecord going from ``in_qs`` to the last
         node's output types, and an OrderedDict of the individual records
         keyed by ``NodeId(node, contained_node)``.
     """
     del G
     per_node_records = OrderedDict()
     chained_qs = in_qs
     for inner in node.contained_nodes():
         inner_id = NodeId(node, inner)
         inner_stats = self._activation_stats.get(inner_id)
         inner_record = self.calculate_q(inner,
                                         inner_stats,
                                         chained_qs,
                                         self._force_width,
                                         force_out=force_out)
         per_node_records[inner_id] = inner_record
         # The outputs of this node feed the next one in the fusion.
         chained_qs = inner_record.out_qs
     fusion_record = SymmetricQuantizationRecord(in_qs=in_qs, out_qs=chained_qs)
     return fusion_record, per_node_records
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse each convolution in ``G`` with the nodes chained after it.

        The set of activations accepted by ``self.get_node_list`` depends on
        the ``group_identity`` keyword: ``'pow2_match_group'`` selects
        VALID_ACTIVATIONS_POW2, anything else VALID_ACTIVATIONS_SQ8. Chains
        shorter than two nodes are ignored. Each accepted chain is wrapped
        in a ConvFusionParameters node, quantization records (if any) are
        moved under it, and the chain is replaced by the fusion node.

        Args:
            G: graph to rewrite in place.
            set_identity: when true, ``self.set_identity(G)`` is called
                once all candidates have been processed.
            **kwargs: only ``group_identity`` is read.

        Returns:
            True if at least one chain was fused, otherwise False.
        """
        valid_activations = (VALID_ACTIVATIONS_POW2
                             if kwargs.get('group_identity') == 'pow2_match_group'
                             else VALID_ACTIVATIONS_SQ8)
        fused_any = False
        # Snapshot the convolutions first: the graph is edited inside the loop.
        conv_nodes = [candidate for candidate in G.nodes()
                      if isinstance(candidate, Conv2DParameters)]
        for conv_node in conv_nodes:
            found = self.get_node_list(G, conv_node, valid_activations)
            if found is None or len(found.order) < 2:
                continue
            if found.fusion_type == 'conv_active_pool':
                if found.pool.pool_type == "average":
                    # Keep only the first two nodes and drop the pool.
                    found.order = found.order[:2]
                    found.pool = None
            elif found.fusion_type == 'conv_pool_active':
                # NOTE: This is only for old POW2 kernels - SQ8 can handle this
                if found.pool.pool_type == "average" and found.active.activation != "relu":
                    continue
            chain = found.order
            LOG.info("fusing nodes %s", ",".join(member.name for member in chain))
            fused_any = True
            # Link the chain members one after another inside a new subgraph.
            subgraph = GraphView()
            for upstream, downstream in zip(chain, chain[1:]):
                subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
            tail = chain[-1]
            # fusion_type is read again here, after the order may have been
            # shortened above.
            pnode = ConvFusionParameters(
                found.conv.name + '_fusion',
                fusion_type=found.fusion_type,
                subgraph=subgraph,
                in_dims_hint=found.conv.in_dims_hint,
                out_dims_hint=found.conv.out_dims_hint,
                in_dims=deepcopy(found.conv.in_dims),
                out_dims=deepcopy(tail.out_dims),
                input_mapping=[[(found.conv, idx)] for idx in range(3)],
                output_mapping=[(tail, 0)])
            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    # TODO - stats
                    head_qrec = qrecs[0]
                    prec = QRec.copy_ktype(head_qrec,
                                           in_qs=deepcopy(head_qrec.in_qs),
                                           out_qs=deepcopy(qrecs[-1].out_qs))
                    for inner in pnode.contained_nodes():
                        G.quantization.move_to_fusion(inner, pnode)
                    G.quantization[NodeId(pnode)] = prec
            # Capture the boundary edges before the chain is removed.
            incoming = G.in_edges(found.conv.name)
            outgoing = G.out_edges(tail.name)
            for member in chain:
                G.remove(member)
            for edge in incoming:
                G.add_edge(NNEdge(edge.from_node, pnode,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in outgoing:
                G.add_edge(NNEdge(pnode, edge.to_node,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return fused_any