Example #1
    def quantize(self, G: NNGraph) -> OrderedDict:
        '''quantize the graph'''
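        # undo any existing quantization so quantization always starts from float parameters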
        if G.has_quantized_parameters:
            self.dequantize(G)
            G.has_quantized_parameters = False
            G.quantization = None

        self.qrecs = QuantizationSet()
        edge_recs = {}
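        # base quantizer options; options set on the quantizer override them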
        opts = {
            'force_width': self._force_width,
            'quantized_dimension': self._quantized_dimension,
            'narrow_weights': self._narrow_weights
        }
        opts.update(self._options)
        quant_kwargs = {
            'opts': opts,
            'all_stats': self._activation_stats,
            'G': G,
            'qrecs': self.qrecs
        }
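        # map the forced bit width to its dtype and quantize the graph node by
        # node, recording a qrec for every node in self.qrecs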
        dtype = WIDTH_TO_DTYPE[self._force_width]
        self.quantize_forward(edge_recs, dtype=dtype, **quant_kwargs)
        self.qrecs['__quantizer'] = self
        G.graph_identity.quantization_type = 'SQ8'
        return self.qrecs
Example #2
    def quantize_forward(self, G: NNGraph, edge_recs, result=None):
        if result is None:
            result = QuantizationSet()
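        # walk the nodes in execution-step order so every node sees the
        # quantization of its inputs before being quantized itself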
        for node in [step['node'] for step in G.graph_state.steps]:
            LOG.debug("quantize forward %s", node.name)
            in_qs = self.get_in_qs(G, edge_recs, node)
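            # fusions yield a qrec for the fusion itself plus one per internal node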
            if isinstance(node, ConvFusionParameters):
                qrec, qrecs = self.quantize_fusion(G, node, in_qs)
                for node_id, fqrec in qrecs.items():
                    result[node_id] = fqrec
            elif isinstance(node, ConcatParameters):
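                # concat inputs must end up with the same quantization, so
                # requantize the already-visited producers backwards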
                qrec = self.quantize_backward(G, result, edge_recs, node)
            else:
                qrec = self.calculate_q(
                    node, self._activation_stats.get(NodeId(node, None)),
                    self._filter_stats.get(NodeId(node, None)), in_qs,
                    self._min_qsnr, self._force_width)
            result[NodeId(node, None)] = qrec
            if not qrec:
                break

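            # publish this node's output quantization on each out edge so that
            # downstream nodes can read it as their input quantization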
            for edges in G.indexed_out_edges(node.name):
                for edge in edges:
                    edge_recs[edge.params] = qrec.out_qs[edge.from_idx]
        return result
Example #3
    def do_qtune(self, args):
        """
Tune quantization of graph."""
        self._check_graph()
        nodes, node_descr = self.get_node_step_or_name(args.step)
        if not nodes:
            return
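        # 'q': adjust the Q format of a single parameter on an already quantized graph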
        if args.tune == 'q':
            self._check_quantized()
            for node in nodes:
                tuneq(self.G,
                      self.G.quantization,
                      node,
                      args.parameter,
                      args.X,
                      args.Y,
                      index=args.index)
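        # 'set': change the quantization type of the selected nodes, creating
        # a quantization set on the graph if none exists yet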
        elif args.tune == 'set':
            if not self.G.quantization:
                self.G.quantization = QuantizationSet()
            for node in nodes:
                tune_type(self.G, self.G.quantization, node, args.type)
            self.pfeedback(
                f'quantization changed to {args.type} on {node_descr}')
        else:
            self.perror(f'{args.tune} subcommand invalid')
Example #4
    def create_graph(self, filename, opts):
        opts = self.get_opts(opts)
        self._name_cache = {}
        add_sys_path(os.path.dirname(__file__))
        with open(filename, "rb") as file_handle:
            buf = file_handle.read()
        model = Model.GetRootAsModel(buf, 0)
        LOG.info("Importing TFLITE model version %s", model.Version())
        check(model.Version() == 3, "Only support version 3 graphs at present")
        if model.SubgraphsLength() > 1:
            LOG.warning(
                "nntool only supports one subgraph. There may be errors loading this graph."
            )
        G = NNGraph(model=model,
                    filename=filename,
                    name=opts.get('name'),
                    constant_store=ConstantStore())
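        # when quantization is loaded from the flatbuffer the graph is marked
        # as carrying quantized parameters under the SQ8 scheme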
        if opts.get('load_quantization'):
            G.quantization = QuantizationSet()
            G.has_quantized_parameters = True
            G.quantization.schemes_present.add('SQ8')

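        # import the tflite operators, then run the clean-up and matching
        # passes over the resulting graph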
        self._import_tflite_graph(G, model, opts)
        clean_dangling_nodes(G)
        fix_split_in_edges(G)
        MatchDuplicateConstants().match(G)
        # DrawGraphReporter().report(G)
        G.add_dimensions()
        remove_concats(G)
        if opts['remove_quantize_ops']:
            RemoveQuantizeOperators().match(G)
            G.add_dimensions()

        if opts.get('load_quantization'):
            # get rid of qrecs on nodes that were not used
            to_remove = []
            for nid in G.quantization:
                if nid.node_name not in G:
                    to_remove.append(nid)
            for nid in to_remove:
                del G.quantization[nid]
            nodes_with_bad_quantization = self.find_nodes_with_bad_quantization(
                G)
            quantizer = UnifiedQuantizer.from_quantized_graph(G)
            # check for quantization problems
            # 1) need to force softmax/Sigmoid input to POW2 quantization
            # 2) need to check that all concats and splits have same input and
            #    output quantization
            # 3) Need to check that all nodes have qrecs and that they are consistent
            nodes_with_bad_quantization |= set(
                G.nodes(node_classes=(ConcatParameters, SoftMaxParameters,
                                      SplitParameters,
                                      SigmoidActivationParameters)))
            G.quantization = quantizer.quantize(
                G, start_nodes=nodes_with_bad_quantization)
            G.add_dimensions()

        return G
Example #5
    def quantize(self, G: NNGraph) -> OrderedDict:
        '''quantize the graph'''
        if G.has_quantized_parameters:
            self.dequantize(G)
            G.has_quantized_parameters = False
            G.quantization = None
        self.qrecs = QuantizationSet()
        edge_recs = {}
        dtype = WIDTH_TO_DTYPE[self._force_width]
        self.quantize_forward(G, edge_recs, dtype)
        self.qrecs['__quantizer'] = self
        G.graph_identity.quantization_type = 'SQ8'
        return self.qrecs
Example #6
    def bfs_pass(self, only_inserted=False):
        """Execute breadth first quantization pass identifying all quantization conflicts.

        Args:
            schemes (list[str]): List of quantization schemes to use in order of priority
        """
        if not self._schemes:
            raise ValueError('no quantization schemes set')
        self._qtypes = TransactionalDict()
        self._qset = QuantizationSet()
        self.remove_quantizers(only_inserted=only_inserted)
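        # start a breadth-first traversal from every graph input, building up
        # the quantization set as nodes are visited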
        visited = []
        for node in self._graph.inputs():
            self._bfs_pass(self._graph, node, self._qset, visited)
Example #7
    def create_graph(self, filename, opts):
        opts = self.get_opts(opts)
        self._name_cache = {}
        add_sys_path(os.path.dirname(__file__))
        with open(filename, "rb") as file_handle:
            buf = file_handle.read()
        model = Model.GetRootAsModel(buf, 0)
        LOG.info("Importing TFLITE model version %s", model.Version())
        check(model.Version() == 3, "Only support version 3 graphs at present")
        if model.SubgraphsLength() > 1:
            LOG.warning("nntool only supports one subgraph. There may be errors loading this graph.")
        G = NNGraph(model=model, filename=filename, name=opts.get('name'),
                    constant_store=ConstantStore())
        if opts.get('load_quantization'):
            G.quantization = QuantizationSet()
            G.has_quantized_parameters = True
            G.graph_identity.quantization_types.add('SQ8')

        self._import_tflite_graph(G, TFLiteGraph.from_model(model, 0), opts)
        clean_dangling_nodes(G)
        fix_split_in_edges(G)
        MatchDuplicateConstants().match(G)
        G.add_dimensions()
        remove_concats(G)
        if opts['remove_quantize_ops']:
            RemoveQuantizeOperators().match(G)
            G.add_dimensions()

        if opts.get('load_quantization'):
            # get rid of qrecs on nodes that were not used
            to_remove = []
            for nid in G.quantization:
                if nid.node_name not in G:
                    to_remove.append(nid)
            for nid in to_remove:
                del G.quantization[nid]

        return G
Example #8
    def create_graph(self, filename, opts) -> NNGraph:
        opts = self.get_opts(opts)
        model = onnx.load(filename)

        # onnx.checker.check_model(model)
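        # shape inference can fail on some graphs; log it as a warning since
        # the import may still succeed without inferred shapes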
        try:
            model = shape_inference.infer_shapes(model)
        except RuntimeError as ex:
            msg = "\n".join(f">   {line}" for line in str(ex).split("\n")
                            if line)
            logger.warning(
                'shape inference failed on onnx graph. '
                f'This may not affect import.\nONNX runtime error was:\n{msg}')

        self._name_cache = {}
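        # models older than IR version 3 carry no opset imports, so fall back
        # to version 1 of the default ONNX operator set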
        if model.ir_version < 3:
            opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)]
        else:
            opset_import = model.opset_import
        G = NNGraph(filename=filename, name=opts.get('name'))
        G, qrecs = self._import_onnx_model(G, model.graph, opset_import, opts)
        G.add_dimensions(quiet=True)
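        # qrecs recovered from FakeQuantize nodes are propagated through the
        # graph and then refined into a full SQ8 quantization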
        if qrecs:
            propagate_qrecs(G, qrecs)
            qset = QuantizationSet()
            qset.update(qrecs)
            qset.scheme_priority = ['SQ8']
            qset.schemes_present = {'SQ8'}
            G.quantization = qset
            try:
                quantizer = NewQuantizer(G)
                quantizer.quantize()
            except ValueError as ex:
                logger.warning(
                    f'unable to import quantization from FakeQuantize nodes correctly - {ex}'
                )

        clean_dangling_nodes(G)
        MatchDuplicateConstants().match(G)
        return G