예제 #1
0
    def calculate_q(self, node, astats, in_qs, force_width, force_out=None):
        """Build the quantization record for one node, dispatching on its type.

        Input-like nodes are handed to ``calculate_output_q``, convolution and
        fully connected nodes to ``calculate_filter_q`` (using the first input
        qtype only). Softmax gets a fixed ``QType(16, 15, True)`` output,
        activations get the qtype from ``compute_activation_out_qtype`` and
        every other node simply propagates its input qtypes to its outputs.
        """
        output_driven = (InputParameters, MatrixBroadcastedLinearOpParameters,
                         ConstantInputParameters, MatScaleFusionParameters)
        if isinstance(node, output_driven):
            return self.calculate_output_q(node,
                                           astats,
                                           in_qs,
                                           force_width=force_width,
                                           force_out=force_out)

        # convolutions and fully connected layers share the same filter path
        if isinstance(node, (Conv2DParameters, FcParameters)):
            return self.calculate_filter_q(node,
                                           astats,
                                           in_q=in_qs[0],
                                           force_width=force_width,
                                           force_out=force_out)

        if isinstance(node, SoftMaxParameters):
            # softmax always outputs Q15
            out_qs = [QType(16, 15, True)]
        elif isinstance(node, ActivationParameters):
            out_qs = [self.compute_activation_out_qtype(node, in_qs[0])]
        else:
            # anything else passes its input quantization straight through
            out_qs = in_qs
        return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=out_qs)
예제 #2
0
    def _quantize(cls, params, in_qs, stats, **kwargs):
        """Quantize a split node: every output shares the qtype of the input.

        Args:
            params: the split node; ``name`` and ``num_splits`` are read.
            in_qs: input qtypes; only the first one is used.
            stats: unused here.
            **kwargs: forwarded to ``cls.get_pow2_opts``; ``backwards`` is
                also read to detect backward propagation.

        Returns:
            A ``SymmetricQuantizationRecord`` or ``None`` when the forced
            output qtypes conflict with each other or (going forwards) with
            the input qtype.
        """
        o_q = in_qs[0]
        force_out_qs, _ = cls.get_pow2_opts(**kwargs)
        # first output qtype that is actually forced, if any
        first_forced_q = force_out_qs and next(
            (out_q for out_q in force_out_qs if out_q is not None), None)
        if first_forced_q and not all(out_q == first_forced_q
                                      for out_q in force_out_qs
                                      if out_q is not None):
            LOG.error(
                'split %s is being forced to have different output qtypes',
                params.name)
            return None

        if first_forced_q:
            backwards = kwargs.get('backwards', None)
            if backwards:
                # if going backwards and forced then we force our input.
                # in_qs must be a list like in every other record; the bare
                # qtype that used to be passed here was not indexable.
                return SymmetricQuantizationRecord(
                    in_qs=[first_forced_q],
                    out_qs=[
                        deepcopy(first_forced_q)
                        for _ in range(params.num_splits)
                    ])
            if o_q != first_forced_q:
                LOG.error(
                    'split %s is being forced to have different output to input',
                    params.name)
                return None
            # continue here if forced since o_q == forced_q

        return SymmetricQuantizationRecord(
            in_qs=in_qs,
            out_qs=[deepcopy(o_q) for _ in range(params.num_splits)])
예제 #3
0
 def quantize_fusion(self,
                     G: NNGraph,
                     node: ConvFusionParameters,
                     in_qs,
                     force_out=None) -> SymmetricQuantizationRecord:
     """Quantize a fusion node, with a dedicated path for 'conv_active'.

     For 'conv_active' the convolution is quantized with its output left in
     the accumulator format (``out_as_acc=True``) and the activation output
     qtype is then chosen so that it never needs more fractional bits than
     the convolution provides. Any other fusion type is delegated to
     ``default_quantize_fusion``.

     Args:
         G: graph containing the fusion (only forwarded to the default path).
         node: the fusion node.
         in_qs: input qtypes of the fusion; the first feeds the convolution.
         force_out: optional qtype-like object whose ``bits`` and/or ``q``
             constrain the fusion output.

     Returns:
         A ``(record, per_node_records)`` pair: the record for the fusion
         itself and an ``OrderedDict`` keyed by ``NodeId`` for the contained
         nodes. Note the annotation only names the first element.

     Raises:
         NotImplementedError: the forced q would require a left shift in the
             kernel.
     """
     if node.fusion_type != 'conv_active':
         return self.default_quantize_fusion(G,
                                             node,
                                             in_qs,
                                             force_out=force_out)

     result = OrderedDict()
     nodes = node.contained_nodes()
     conv_node = nodes[0]
     conv_astats = self._activation_stats.get(NodeId(node, conv_node))
     conv_qrec = self.calculate_filter_q(conv_node,
                                         conv_astats,
                                         in_q=in_qs[0],
                                         force_width=self._force_width,
                                         out_as_acc=True)
     result[NodeId(node, conv_node)] = conv_qrec
     conv_out_q = conv_qrec.out_qs[0].q
     act_node = nodes[1]
     act_astats = self._activation_stats.get(NodeId(node, act_node))
     if force_out and force_out.bits:
         act_max_q = self.compute_activation_out_maxq(
             act_node, force_out.bits)
         if force_out.q is not None:
             if (act_max_q is not None and force_out.q > act_max_q
                 ) or force_out.q > conv_out_q:
                 # We cannot shift left in the kernel
                 # TODO - This should try to increase the input q and perhaps the width
                 # Unlikely to happen
                 raise NotImplementedError()
             act_o_q = QType(bits=force_out.bits,
                             q=force_out.q,
                             signed=True)
         else:
             # Only the width is forced: derive q from the statistics at that
             # width. This branch used to read act_o_q before it was ever
             # assigned and so always raised UnboundLocalError.
             act_o_q = QType.from_min_max(
                 max_val=act_astats['range_out'][0]['max'],
                 min_val=act_astats['range_out'][0]['min'],
                 bits=force_out.bits)
             act_o_q.q = min(act_o_q.q, conv_out_q)
             if act_max_q is not None:
                 act_o_q.q = min(act_max_q, act_o_q.q)
     else:
         act_o_q = QType.from_min_max(
             max_val=act_astats['range_out'][0]['max'],
             min_val=act_astats['range_out'][0]['min'],
             bits=self._force_width)
         act_o_q.q = min(act_o_q.q, conv_out_q)
         if force_out and force_out.q:
             # act_max_q was never computed on this path before, so the
             # check below raised UnboundLocalError; compute it for the width
             # actually selected and tolerate "no limit" (None).
             act_max_q = self.compute_activation_out_maxq(
                 act_node, act_o_q.bits)
             if (act_max_q is not None and force_out.q > act_max_q
                 ) or force_out.q > conv_out_q:
                 # We cannot shift left in the kernel
                 # TODO - This should try to increase the input q and perhaps the width
                 # Unlikely to happen
                 raise NotImplementedError()
             act_o_q.q = force_out.q
     act_qrec = SymmetricQuantizationRecord(in_qs=conv_qrec.out_qs,
                                            out_qs=[act_o_q])
     result[NodeId(node, act_node)] = act_qrec
     return SymmetricQuantizationRecord(in_qs=in_qs,
                                        out_qs=act_qrec.out_qs), result
예제 #4
0
    def _quantize(cls, params, in_qs, stats, **kwargs):
        """Pick the POW2 output qtype for an activation node.

        The number of integer bits needed comes from the activation's upper
        bound (relu6 / relun) or from the recorded output range (relu,
        hswish, hsigmoid, leaky). Unless an output qtype is forced, the
        remaining bits of the POW2 width become fractional bits.

        Raises:
            ValueError: the activation kind is not handled.
        """
        forced_qs, out_dtype = cls.get_pow2_opts(**kwargs)
        forced = forced_qs and forced_qs[0]

        kind = params.activation
        if kind == "relu6":
            int_bits = calc_bits(6)
        elif kind == "relun":
            upper = params.activation_params
            # a list of bounds is reduced to its largest entry
            int_bits = calc_bits(max(upper) if isinstance(upper, list) else upper)
        elif kind in ("relu", "hswish", "hsigmoid", "leaky"):
            out_range = stats['range_out'][0]
            int_bits = bits(out_range['max'], out_range['min'])
        else:
            raise ValueError(
                f'no support for activation {params.activation} in POW2 quantizer'
            )

        first_in_q = in_qs[0]
        if forced is not None:
            if forced.bits - forced.q < int_bits:
                LOG.warning(
                    'quantization is forcing node %s to have an output that may clip',
                    params.name)
            chosen = forced
        else:
            frac_bits = max(cls.get_pow2_bits(**kwargs) - int_bits, 0)
            chosen = QType(q=frac_bits, dtype=out_dtype)
        return SymmetricQuantizationRecord(in_qs=[first_in_q], out_qs=[chosen])
예제 #5
0
 def replace_function(self, G: GraphView, subgraph: GraphView):
     """Replace a matched subgraph with a single ConvFusionParameters node.

     Args:
         G: graph being rewritten; its ``quantization`` is updated if set.
         subgraph: the matched nodes.

     Returns:
         ``(pnode, None, None)`` where ``pnode`` is the new fusion node.

     Raises:
         DontReplaceError: ``validate_match`` rejected the subgraph.
         NotImplementedError: the first contained quantization record is of
             a type this fusion does not know how to combine.
     """
     if not self.validate_match(subgraph):
         raise DontReplaceError()
     # Number the nodes up to and including the convolution, whose name
     # becomes the fusion's name.
     # NOTE(review): nodes after the convolution get no step_idx because of
     # the break; confirm this is intended.
     for step, node in enumerate(subgraph.nodes()):
         node.step_idx = step
         if isinstance(node, Conv2DParameters):
             conv_name = node.name + "_fusion"
             break
     LOG.debug("fused nodes %s", ",".join((node.name for node in subgraph.nodes())))
     # simple node order is necessary because nodes() will not necessarily
     # be in order
     pnode = ConvFusionParameters(conv_name, fusion_type=self.fusion_type, subgraph=subgraph)
     if G.quantization:
         qrecs = G.quantization.get_all(pnode.contained_nodes())
         if qrecs:
             first = qrecs[0]
             if isinstance(first, (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                 rec_class = SymmetricQuantizationRecord
             elif isinstance(first, (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                 rec_class = MultQuantizationRecord
             elif isinstance(first, (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                 rec_class = Float32QuantizationRecord
             else:
                 # Fail before touching G.quantization; this used to surface
                 # as an UnboundLocalError after the records had been moved.
                 raise NotImplementedError(
                     "unsupported quantization record type %s" % type(first).__name__)
             # the fusion takes the inputs of its first node and the outputs of its last
             prec = rec_class(in_qs=first.in_qs, out_qs=qrecs[-1].out_qs)
             for node in pnode.contained_nodes():
                 G.quantization.move_to_fusion(node, pnode)
             G.quantization[NodeId(pnode)] = prec
     return pnode, None, None
예제 #6
0
 def replace_function(self, G: NNGraph, subgraph: GraphView):
     """Replace a matched subgraph with a 'linear_active' fusion node.

     Args:
         G: graph being rewritten; its ``quantization`` is updated if set.
         subgraph: the matched nodes.

     Returns:
         ``(pnode, None, None)`` where ``pnode`` is the new fusion node.

     Raises:
         NotImplementedError: the first contained quantization record is of
             a type this fusion does not know how to combine.
     """
     # Number the nodes up to and including the linear layer, whose name
     # becomes the fusion's name.
     # NOTE(review): assumes the matcher guarantees an FcParameters node is
     # present, otherwise linear_name is never bound - confirm.
     for step, node in enumerate(subgraph.nodes()):
         node.step_idx = step
         if isinstance(node, FcParameters):
             linear_name = node.name + "_fusion"
             break
     LOG.info("fusing nodes %s", ",".join(
         (node.name for node in subgraph.nodes())))
     # simple node order is necessary because nodes() will not necessarily
     # be in order
     pnode = ConvFusionParameters(linear_name, fusion_type="linear_active", subgraph=subgraph)
     if G.quantization:
         qrecs = G.quantization.get_all(pnode.contained_nodes())
         if qrecs:
             first = qrecs[0]
             if isinstance(first, (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                 rec_class = SymmetricQuantizationRecord
             elif isinstance(first, (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                 rec_class = MultQuantizationRecord
             elif isinstance(first, (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                 rec_class = Float32QuantizationRecord
             else:
                 # Fail before touching G.quantization; this used to surface
                 # as an UnboundLocalError after the records had been moved.
                 raise NotImplementedError(
                     "unsupported quantization record type %s" % type(first).__name__)
             # the fusion takes the inputs of its first node and the outputs of its last
             prec = rec_class(in_qs=first.in_qs, out_qs=qrecs[-1].out_qs)
             for node in pnode.contained_nodes():
                 G.quantization.move_to_fusion(node, pnode)
             G.quantization[NodeId(pnode)] = prec
     return pnode, None, None
예제 #7
0
 def replace_function(self, G: NNGraph, subgraph: GraphView):
     """Replace a matched two-node subgraph with an ActivationFusion node.

     The fusion is named after the first node (``name + "fusion"``) and its
     type is the first node's ``op_name`` suffixed with ``"_active"``.

     Args:
         G: graph being rewritten; its ``quantization`` is updated if set.
         subgraph: the matched nodes; the first two are given step indexes.

     Returns:
         The new fusion node.

     Raises:
         NotImplementedError: the first contained quantization record is of
             a type this fusion does not know how to combine.
     """
     nodes = list(subgraph.nodes())
     pnode = ActivationFusion(nodes[0].name + "fusion",
                              nodes[0].op_name + "_active", subgraph)
     nodes[0].step_idx = 0
     nodes[1].step_idx = 1
     LOG.debug("fused nodes %s", ",".join((node.name for node in nodes)))
     if G.quantization:
         qrecs = G.quantization.get_all(subgraph.nodes())
         if qrecs:
             first = qrecs[0]
             if isinstance(first,
                           (SymmetricQuantizationRecord,
                            SymmetricScalableFilterQuantizationRecord)):
                 rec_class = SymmetricQuantizationRecord
             elif isinstance(first,
                             (MultQuantizationRecord,
                              MultScalableFilterQuantizationRecord)):
                 rec_class = MultQuantizationRecord
             elif isinstance(first,
                             (Float32QuantizationRecord,
                              Float32ScalableFilterQuantizationRecord)):
                 rec_class = Float32QuantizationRecord
             else:
                 # Fail before touching G.quantization; this used to surface
                 # as an UnboundLocalError after the records had been moved.
                 raise NotImplementedError(
                     "unsupported quantization record type %s" % type(first).__name__)
             # the fusion takes the inputs of its first node and the outputs of its last
             prec = rec_class(in_qs=first.in_qs, out_qs=qrecs[-1].out_qs)
             for node in subgraph.nodes():
                 G.quantization.move_to_fusion(node, pnode)
             G.quantization[NodeId(pnode)] = prec
     return pnode
예제 #8
0
 def calculate_output_q(self,
                        node: Parameters,
                        astats,
                        in_qs,
                        force_width=None,
                        force_out=None):
     """Compute the output qtype of a node from its statistics and constraints.

     Args:
         node: unused.
         astats: activation statistics; ``range_out[0]`` min/max are read
             unless both width and q are forced.
         in_qs: input qtypes, stored unchanged in the record.
         force_width: width used when ``force_out`` does not force one.
         force_out: optional qtype-like constraint. ``bits`` forces the
             width, ``q`` (including 0) forces the fractional bits.

     Returns:
         A ``SymmetricQuantizationRecord`` with a single output qtype.
     """
     del node
     forced_bits = force_out.bits if force_out else None
     # q is compared against None so that a forced q of 0 is honoured; the
     # previous truthiness test silently ignored it.
     forced_q = force_out.q if force_out else None

     if forced_bits and forced_q is not None:
         # fully specified: the statistics are not needed
         o_q = QType(bits=forced_bits, q=forced_q, signed=True)
     else:
         # Also covers a force_out that constrains nothing, which previously
         # left o_q unassigned and raised UnboundLocalError.
         o_q = QType.from_min_max(max_val=astats['range_out'][0]['max'],
                                  min_val=astats['range_out'][0]['min'],
                                  bits=forced_bits or force_width)
         if forced_q is not None:
             o_q.q = forced_q
     return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=[o_q])
예제 #9
0
 def _quantize(cls, params, in_qs, stats, **kwargs):
     """Quantize a node whose output is always ``QType.Pow2(16, 15, True)``.

     Returns ``None`` when a different output qtype is being forced, since
     the fixed output cannot satisfy it.
     """
     forced_qs, _ = cls.get_pow2_opts(**kwargs)
     forced = forced_qs and forced_qs[0]
     fixed_q = QType.Pow2(16, 15, True)
     conflicting = bool(forced) and forced != fixed_q
     if conflicting:
         return None
     return SymmetricQuantizationRecord(in_qs=in_qs,
                                        out_qs=[QType.Pow2(16, 15, True)])
예제 #10
0
    def replace_function(self, G: GraphView, subgraph: GraphView):
        """Collapse a relu / constant / multiply subgraph into one hsigmoid node.

        The new ``HSigmoidActivationParameters`` is named after the multiply
        node. When the graph is quantized, the constant's record is dropped
        and the new node gets a record of the same family as the relu's,
        using the relu's inputs and the multiply's outputs.

        Returns:
            ``(activation, None, None)``.

        Raises:
            NotImplementedError: the relu's record is of an unknown family.
        """
        relu_node = constant_node = mul_node = None
        for candidate in subgraph.nodes():
            if isinstance(candidate, ReluActivationParameters):
                relu_node = candidate
            elif isinstance(candidate, ConstantInputParameters):
                constant_node = candidate
            elif isinstance(candidate, MatrixMulParameters):
                mul_node = candidate

        fused_name = mul_node.name + "_fused_close_hsigmoid"
        activation = HSigmoidActivationParameters(fused_name, offset=0)

        if not G.quantization:
            return activation, None, None

        relu_rec = G.quantization[NodeId(relu_node)]
        mul_rec = G.quantization[NodeId(mul_node)]
        # the constant disappears with the fusion, so its record goes too
        del G.quantization[NodeId(constant_node)]
        for rec_class in (SymmetricQuantizationRecord,
                          MultQuantizationRecord,
                          Float32QuantizationRecord):
            if isinstance(relu_rec, rec_class):
                fused_rec = rec_class(in_qs=relu_rec.in_qs,
                                      out_qs=mul_rec.out_qs)
                break
        else:
            raise NotImplementedError()
        G.quantization[NodeId(activation)] = fused_rec
        return activation, None, None
예제 #11
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse every convolution in G with the nodes found by get_node_list.

        For each Conv2DParameters node a candidate chain is requested from
        ``get_node_list``; chains with fewer than two nodes are skipped. The
        chain is wrapped in a ConvFusionParameters node that replaces the
        original nodes in G, and any quantization records are moved onto the
        fusion.

        Args:
            G: graph to rewrite in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.

        Returns:
            True if at least one fusion was created.
        """
        has_modified_graph = False
        # iterate over a snapshot of the convolutions since G is mutated below
        for conv_node in [params for params in G.nodes() if isinstance(params, Conv2DParameters)]:
            node_list = self.get_node_list(G, conv_node)
            if node_list is None or len(node_list.order) < 2:
                continue
            if node_list.fusion_type == 'conv_active_pool':
                if node_list.pool.pool_type == "average":
                    # keep only the first two nodes and leave the average pool out
                    node_list.order = node_list.order[:2:]
                    node_list.pool = None
            elif node_list.fusion_type == 'conv_pool_active':
                # an average pool is only fused here when the activation is relu
                if node_list.pool.pool_type == "average" and node_list.active.activation != "relu":
                    continue
            LOG.info("fusing nodes %s", ",".join((node.name for node in node_list.order)))
            has_modified_graph = True
            # build the internal subgraph as a simple chain in node_list order
            subgraph = GraphView()
            last_node = None
            for node in node_list.order:
                if last_node is not None:
                    subgraph.add_edge(NNEdge(from_node=last_node, to_node=node))
                last_node = node
            # fusion inputs 0..2 map onto inputs 0..2 of the convolution
            input_mapping = [[(node_list.conv, idx)] for idx in range(3)]
            output_mapping = [(last_node, 0)]
            pnode = ConvFusionParameters(
                node_list.conv.name + '_fusion',
                fusion_type=node_list.fusion_type,
                subgraph=subgraph,
                in_dims_hint=node_list.conv.in_dims_hint,
                out_dims_hint=node_list.conv.out_dims_hint,
                input_mapping=input_mapping,
                output_mapping=output_mapping)
            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    # the fusion record uses the first node's inputs and the
                    # last node's outputs; it stays None for an unknown type
                    prec = None
                    if isinstance(qrecs[0], (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                        prec = SymmetricQuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    elif isinstance(qrecs[0], (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                        prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    elif isinstance(qrecs[0], (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                        prec = Float32QuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    for node in pnode.contained_nodes():
                        G.quantization.move_to_fusion(node, pnode)
                    G.quantization[NodeId(pnode)] = prec
            # capture the external edges before the nodes are removed from G
            in_edges = G.in_edges(node_list.conv.name)
            out_edges = G.out_edges(last_node.name)
            for node in node_list.order:
                G.remove(node)
            # reconnect the captured edges to the fusion node
            for edge in in_edges:
                G.add_edge(NNEdge(edge.from_node, pnode, from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in out_edges:
                G.add_edge(NNEdge(pnode, edge.to_node, from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
예제 #12
0
    def _quantize(cls, params, in_qs, stats, **kwargs):
        """Quantize a node whose output qtype must equal its input qtypes.

        Going backwards, the forced output qtype is imposed on every input.
        Going forwards, the first input qtype is copied to the output;
        ``None`` is returned if a forced output differs from any input.
        """
        forced_qs, _ = cls.get_pow2_opts(**kwargs)
        forced = forced_qs and forced_qs[0]

        if kwargs.get('backwards'):
            # propagating backwards only makes sense with a forced output
            assert forced, f'going backwards at {params.name} but output is not forced'
            imposed_in_qs = [forced] * len(in_qs)
            return SymmetricQuantizationRecord(in_qs=imposed_in_qs,
                                               out_qs=[deepcopy(forced)])

        # forwards: a forced output that some input does not already match
        # cannot be satisfied by this node
        if forced and any(not (in_q == forced) for in_q in in_qs):
            return None

        passthrough_q = deepcopy(in_qs[0])
        return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=[passthrough_q])
예제 #13
0
    def _quantize(cls, params, in_qs, stats, **kwargs):
        """Derive the output qtype from the recorded output range.

        A forced output qtype overrides the computed one; a warning is
        logged when the forced qtype has fewer integer bits than the one
        derived from the statistics.
        """
        forced_qs, out_dtype = cls.get_pow2_opts(**kwargs)
        forced = forced_qs and forced_qs[0]
        computed = QType.from_min_max_pow2(stats['range_out'][0]['min'],
                                           stats['range_out'][0]['max'],
                                           dtype=out_dtype)
        if not forced:
            return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=[computed])

        forced_int_bits = forced.bits - forced.q
        needed_int_bits = computed.bits - computed.q
        if forced_int_bits < needed_int_bits:
            LOG.warning('%s is being forced to output in Q%s and may clip',
                        params.name, forced.q)
        return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=[forced])
예제 #14
0
 def default_quantize_fusion(self,
                             G: NNGraph,
                             node: ConvFusionParameters,
                             in_qs,
                             force_out=None) -> SymmetricQuantizationRecord:
     """Quantize the nodes contained in a fusion one after the other.

     Each contained node is passed to ``calculate_q`` with the output
     qtypes of the previous one as its inputs; ``force_out`` is handed to
     every call.

     Returns:
         A ``(record, per_node_records)`` pair: the record spanning the
         fusion's inputs and final outputs, and an ``OrderedDict`` of the
         contained nodes' records keyed by ``NodeId``.
     """
     del G
     per_node = OrderedDict()
     current_qs = in_qs
     for inner in node.contained_nodes():
         inner_id = NodeId(node, inner)
         inner_stats = self._activation_stats.get(inner_id)
         inner_rec = self.calculate_q(inner,
                                      inner_stats,
                                      current_qs,
                                      self._force_width,
                                      force_out=force_out)
         per_node[inner_id] = inner_rec
         # the next contained node consumes what this one produces
         current_qs = inner_rec.out_qs
     fusion_rec = SymmetricQuantizationRecord(in_qs=in_qs, out_qs=current_qs)
     return fusion_rec, per_node
예제 #15
0
 def calculate_output_q(self,
                        node: Parameters,
                        astats,
                        in_qs,
                        min_qsnr=None,
                        force_width=None,
                        force_out=None):
     """Compute the output qtype of a node from its statistics and constraints.

     Args:
         node: unused.
         astats: activation statistics handed to ``get_quantization``.
         in_qs: input qtypes, stored unchanged in the record.
         min_qsnr: minimum QSNR handed to ``get_quantization``; not applied
             when only the width is forced.
         force_width: width used when ``force_out`` does not force one.
         force_out: optional qtype-like constraint. ``bits`` forces the
             width, ``q`` (including 0) forces the fractional bits.

     Returns:
         A ``SymmetricQuantizationRecord`` with a single output qtype.
     """
     del node
     forced_bits = force_out.bits if force_out else None
     # q is compared against None so that a forced q of 0 is honoured; the
     # previous truthiness test silently ignored it.
     forced_q = force_out.q if force_out else None

     if forced_bits and forced_q is not None:
         # fully specified: the statistics are not needed
         o_q = QType(bits=forced_bits, q=forced_q, signed=True)
     elif forced_bits:
         o_q = self.get_quantization(astats, None, forced_bits)
     else:
         # Also covers a force_out that constrains nothing, which previously
         # left o_q unassigned and raised UnboundLocalError.
         o_q = self.get_quantization(astats, min_qsnr, force_width)
         if forced_q is not None:
             o_q.q = forced_q
     return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=[o_q])
예제 #16
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse each pad node with the add (and optional activation) after it.

        For every PadParameters node a candidate chain is requested from
        ``get_node_list``; chains with fewer than two nodes are skipped. The
        chain is wrapped in a PaddedAddFusionParameters node that replaces
        the original nodes in G, and any quantization records are moved onto
        the fusion.

        Args:
            G: graph to rewrite in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.

        Returns:
            True if at least one fusion was created.
        """
        has_modified_graph = False
        # iterate over a snapshot of the pad nodes since G is mutated below
        for pad_node in [
                params for params in G.nodes()
                if isinstance(params, PadParameters)
        ]:
            node_list = self.get_node_list(G, pad_node)
            if node_list is None or len(node_list.order) < 2:
                continue
            LOG.info("fusing nodes %s", ",".join(
                (node.name for node in node_list.order)))
            has_modified_graph = True
            subgraph = GraphView()
            # input index of the consumer of the pad's first out edge
            # NOTE(review): presumably that consumer is node_list.add - confirm
            padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
            subgraph.add_edge(
                NNEdge(from_node=node_list.pad,
                       to_node=node_list.add,
                       to_idx=padded_input_idx))
            last_node = node_list.add
            node_list.add.force_quantized_index = 0
            if node_list.active:
                subgraph.add_edge(
                    NNEdge(from_node=node_list.add, to_node=node_list.active))
                last_node = node_list.active
            # keep the fusion's input order identical to the add's input order
            if padded_input_idx == 0:
                input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
            else:
                input_mapping = [[(node_list.add, 0)], [(node_list.pad, 1)]]

            output_mapping = [(last_node, 0)]
            pnode = PaddedAddFusionParameters(
                "PADDED_" + node_list.add.name,
                fusion_type=node_list.fusion_type,
                subgraph=subgraph,
                input_mapping=input_mapping,
                output_mapping=output_mapping)
            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    # the fusion record uses the first node's inputs and the
                    # last node's outputs; it stays None for an unknown type
                    prec = None
                    if isinstance(qrecs[0],
                                  (SymmetricQuantizationRecord,
                                   SymmetricScalableFilterQuantizationRecord)):
                        prec = SymmetricQuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    elif isinstance(qrecs[0],
                                    (MultQuantizationRecord,
                                     MultScalableFilterQuantizationRecord)):
                        prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs,
                                                      out_qs=qrecs[-1].out_qs)
                    elif isinstance(qrecs[0],
                                    (Float32QuantizationRecord,
                                     Float32ScalableFilterQuantizationRecord)):
                        prec = Float32QuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    for node in pnode.contained_nodes():
                        G.quantization.move_to_fusion(node, pnode)
                    G.quantization[NodeId(pnode)] = prec
            # Capture the external edges before the nodes are removed: the
            # pad's own input edges stand in for the padded add input, the
            # add's other input is taken as is.
            if padded_input_idx == 0:
                in_edges = G.in_edges(node_list.pad.name) + G.indexed_in_edges(
                    node_list.add.name)[1::]
            else:
                in_edges = G.indexed_in_edges(
                    node_list.add.name)[0:1:] + G.in_edges(node_list.pad.name)
            out_edges = G.out_edges(last_node.name)
            for node in node_list.order:
                G.remove(node)
            # reconnect the captured edges to the fusion node
            for edge in in_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           pnode,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))
            for edge in out_edges:
                G.add_edge(
                    NNEdge(pnode,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
예제 #17
0
    def quantize_backward(self,
                          G: NNGraph,
                          result,
                          edge_recs,
                          node,
                          force_out=None):
        """Re-quantize `node` so that its output satisfies `force_out`.

        The node is quantized from its current input qtypes. When the result
        does not meet the constraint, the producer node(s) are re-quantized
        recursively with an adjusted constraint and this node is calculated
        once more. `result` and `edge_recs` are updated in place.

        Args:
            G: the graph being quantized.
            result: mapping of NodeId to quantization record, updated here.
            edge_recs: mapping of edge params to qtype, updated here.
            node: the node to re-quantize.
            force_out: qtype-like constraint on the node's output.

        Returns:
            The quantization record finally chosen for `node`.

        Raises:
            NotImplementedError: no solution was found after one
                recalculation, the node is a softmax, or the node has
                several inputs and no dedicated handling.
        """

        LOG.debug("quantize backwards %s", node.name)
        recalculated = False
        # at most two passes: the initial attempt and one recalculation
        while True:
            in_qs = self.get_in_qs(G, edge_recs, node)
            if self.is_filter_node(node):
                if isinstance(node, ConvFusionParameters):
                    qrec, qrecs = self.quantize_fusion(G,
                                                       node,
                                                       in_qs,
                                                       force_out=force_out)
                    for node_id, fqrec in qrecs.items():
                        result[node_id] = fqrec
                else:
                    qrec = self.calculate_q(node,
                                            self._activation_stats.get(
                                                NodeId(node, None)),
                                            in_qs,
                                            self._force_width,
                                            force_out=force_out)

                # the filter could not reach the requested q on its output
                if force_out and force_out.q is not None and qrec.out_qs[
                        0].q < force_out.q:
                    if recalculated:
                        raise NotImplementedError(
                            "no quantization solution found")
                    # NOTE(review): this reads `qrec.q` while the check above
                    # reads `qrec.out_qs[0].q` - confirm the record exposes `.q`
                    bits_to_gain = force_out.q - qrec.q
                    if bits_to_gain > in_qs[0].q:
                        raise NotImplementedError()
                    # Try to adjust the inputs to satisfy and then
                    # recalculate
                    pnode = G.in_edges(node.name)[0].from_node
                    self.quantize_backward(G,
                                           result,
                                           edge_recs,
                                           pnode,
                                           force_out=QType(bits=force_out.bits,
                                                           q=in_qs[0].q -
                                                           bits_to_gain,
                                                           signed=True))
            elif isinstance(node, ConcatParameters):
                assert not recalculated
                # start from the widest input and the smallest input q
                max_width = max(in_q.bits for in_q in in_qs)
                min_q = min(in_q.q for in_q in in_qs)
                if force_out:
                    if not self.satisfied(force_out.bits, max_width):
                        max_width = force_out.bits
                    if not self.satisfied(force_out.q, min_q):
                        min_q = force_out.q
                LOG.debug("normalizing concat to %s",
                          QType(bits=max_width, q=min_q, signed=True))
                # push the common qtype back into every producer that differs
                for pidx, pnode in enumerate(
                    [edge.from_node for edge in G.in_edges(node.name)]):
                    pqrec = in_qs[pidx]
                    if pqrec.q != min_q or pqrec.bits != max_width:
                        self.quantize_backward(G,
                                               result,
                                               edge_recs,
                                               pnode,
                                               force_out=QType(bits=max_width,
                                                               q=min_q,
                                                               signed=True))
                o_q = QType(bits=max_width, q=min_q, signed=True)
                # input qtypes are read again since the producers may have changed
                qrec = SymmetricQuantizationRecord(in_qs=self.get_in_qs(
                    G, edge_recs, node),
                                                   out_qs=[o_q])
            elif isinstance(node, SoftMaxParameters):
                raise NotImplementedError(
                    "softmax kernel cannot change width or q")
            else:
                if isinstance(node, ConvFusionParameters):
                    qrec, qrecs = self.quantize_fusion(G,
                                                       node,
                                                       in_qs,
                                                       force_out=force_out)
                    for node_id, fqrec in qrecs.items():
                        result[node_id] = fqrec
                else:
                    qrec = self.calculate_q(node,
                                            self._activation_stats.get(
                                                NodeId(node, None)),
                                            in_qs,
                                            self._force_width,
                                            force_out=force_out)
                o_q = qrec.out_qs[0]
                # NOTE(review): force_out is dereferenced without a None check
                # although its default is None - confirm callers always pass it
                if not (self.satisfied(force_out.q, o_q.q)
                        and self.satisfied(force_out.bits, o_q.bits)):
                    if recalculated:
                        raise NotImplementedError(
                            "no quantization solution found")
                    if len(G.in_edges(node.name)) > 1:
                        raise NotImplementedError(
                            "Nodes with multiple input edges \
                            need custom handling")
                    # pass the same constraint on to the single producer
                    pnode = G.in_edges(node.name)[0].from_node
                    self.quantize_backward(G,
                                           result,
                                           edge_recs,
                                           pnode,
                                           force_out=force_out)

            # publish this node's output qtypes on all of its outgoing edges
            for edges in G.indexed_out_edges(node.name):
                for edge in edges:
                    edge_recs[edge.params] = qrec.out_qs[edge.from_idx]

            result[NodeId(node, None)] = qrec

            o_q = qrec.out_qs[0]
            if self.satisfied_force(force_out, o_q):
                break
            if recalculated:
                raise NotImplementedError("no quantization solution found")
            LOG.debug("recalculate %s", node.name)
            recalculated = True
        LOG.debug("back complete %s %s", node.name, qrec)
        return qrec
예제 #18
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Replace fusable matmul node groups in ``G`` with single fusion nodes.

        Every ``MatMulOpParameters`` node currently in ``G`` is handed to
        ``self.get_node_list``. When that returns a match whose ``order``
        holds at least two nodes, those nodes are removed from ``G`` and a
        ``MatMulOpFusionParameters`` node named ``<matmul name>_fusion`` is
        wired in their place. Any quantization records of the contained nodes
        are moved into the fusion and a combined record is stored for it.

        Args:
            G: graph that is searched and modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called once
                all candidates have been processed.

        Returns:
            True if at least one group of nodes was fused, otherwise False.
        """
        fused_any = False
        # Snapshot the candidates first: the graph is mutated while fusing.
        candidates = [
            node for node in G.nodes()
            if isinstance(node, MatMulOpParameters)
        ]
        for matmul_node in candidates:
            node_list = self.get_node_list(G, matmul_node)
            if node_list is None or len(node_list.order) < 2:
                continue

            LOG.info("fusing nodes %s",
                     ",".join(member.name for member in node_list.order))
            fused_any = True

            # Internal graph of the fusion: matmul feeding the activation.
            subgraph = GraphView()
            if node_list.active is not None:
                subgraph.add_edge(
                    NNEdge(from_node=node_list.matmul,
                           to_node=node_list.active))

            # Fusion inputs 0 and 1 map to the matmul inputs; when an add node
            # was matched, a third fusion input maps to matmul input 2.
            input_mapping = [[(node_list.matmul, 0)], [(node_list.matmul, 1)]]
            if node_list.add:
                input_mapping.append([(node_list.matmul, 2)])

            # The fusion output is the activation output when there is one,
            # otherwise the matmul output.
            if node_list.active:
                output_mapping = [(node_list.active, 0)]
            else:
                output_mapping = [(node_list.matmul, 0)]

            fusion_node = MatMulOpFusionParameters(
                node_list.matmul.name + '_fusion',
                fusion_type=node_list.fusion_type,
                subgraph=subgraph,
                input_mapping=input_mapping,
                output_mapping=output_mapping)

            if G.quantization:
                qrecs = G.quantization.get_all(fusion_node.contained_nodes())
                if qrecs:
                    # Accepted record types -> record type built for the
                    # fusion. Checked in this order against the first record.
                    record_families = (
                        ((SymmetricQuantizationRecord,
                          SymmetricScalableFilterQuantizationRecord),
                         SymmetricQuantizationRecord),
                        ((MultQuantizationRecord,
                          MultScalableFilterQuantizationRecord),
                         MultQuantizationRecord),
                        ((Float32QuantizationRecord,
                          Float32ScalableFilterQuantizationRecord),
                         Float32QuantizationRecord),
                    )
                    # NOTE(review): stays None when the first record is of
                    # none of the types above, and None is then stored for
                    # the fusion node - confirm this is intended.
                    fusion_qrec = None
                    for accepted_types, record_cls in record_families:
                        if isinstance(qrecs[0], accepted_types):
                            # Inputs of the first record, outputs of the last.
                            fusion_qrec = record_cls(in_qs=qrecs[0].in_qs,
                                                     out_qs=qrecs[-1].out_qs)
                            break
                    for contained in fusion_node.contained_nodes():
                        G.quantization.move_to_fusion(contained, fusion_node)
                    G.quantization[NodeId(fusion_node)] = fusion_qrec

            # Capture all external connections before the nodes are removed.
            matmul_in_edges = G.in_edges(node_list.matmul.name)
            if node_list.add:
                constant_edges = [
                    candidate
                    for candidate in G.in_edges(node_list.add.name)
                    if isinstance(candidate.from_node,
                                  ConstantInputParameters)
                ]
                bias_edge = constant_edges[0]
            last_out_edges = G.out_edges(node_list.order[-1].name)

            for member in node_list.order:
                G.remove(member)

            # Reconnect the captured edges to the fusion node.
            for in_edge in matmul_in_edges:
                G.add_edge(
                    NNEdge(in_edge.from_node,
                           fusion_node,
                           from_idx=in_edge.from_idx,
                           to_idx=in_edge.to_idx))
            if node_list.add:
                # The constant input of the add node becomes fusion input 2.
                G.add_edge(
                    NNEdge(bias_edge.from_node,
                           fusion_node,
                           from_idx=bias_edge.from_idx,
                           to_idx=2))
            for out_edge in last_out_edges:
                G.add_edge(
                    NNEdge(fusion_node,
                           out_edge.to_node,
                           from_idx=out_edge.from_idx,
                           to_idx=out_edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return fused_any