示例#1
0
    def quantize(self, G: NNGraph) -> OrderedDict:
        '''Quantize graph G and return the quantization records.

        If G already carries quantized parameters they are dequantized first
        and the graph's quantization state is cleared. A fresh
        QuantizationSet is stored on ``self.qrecs``, filled in by
        ``quantize_forward`` and returned with this quantizer registered
        under the ``'__quantizer'`` key.
        '''
        # Undo any quantization that is already applied to the graph.
        if G.has_quantized_parameters:
            self.dequantize(G)
            G.has_quantized_parameters = False
            G.quantization = None

        self.qrecs = QuantizationSet()
        edge_records = {}

        # Settings held on the instance come first; entries in self._options
        # override them.
        options = dict(
            force_width=self._force_width,
            quantized_dimension=self._quantized_dimension,
            narrow_weights=self._narrow_weights,
        )
        options.update(self._options)
        activation_stats = self._activation_stats

        target_dtype = WIDTH_TO_DTYPE[self._force_width]
        self.quantize_forward(
            edge_records,
            dtype=target_dtype,
            opts=options,
            all_stats=activation_stats,
            G=G,
            qrecs=self.qrecs,
        )

        self.qrecs['__quantizer'] = self
        G.graph_identity.quantization_type = 'SQ8'
        return self.qrecs
示例#2
0
    def create_graph(self, filename, opts):
        """Import the TFLite model stored in *filename* into an NNGraph.

        Args:
            filename: Path of the model file; it is read in binary mode.
            opts: Importer options. They are normalised through
                ``self.get_opts`` before use. The keys read here are
                ``'name'``, ``'load_quantization'`` and
                ``'remove_quantize_ops'``.

        Returns:
            The NNGraph built from the model. When ``'load_quantization'`` is
            set, the graph's quantization is replaced by the result of
            re-quantizing from the nodes found to have bad quantization.
        """
        opts = self.get_opts(opts)
        self._name_cache = {}
        add_sys_path(os.path.dirname(__file__))
        # Read the whole file inside a context manager so the handle is
        # closed even if reading fails.
        with open(filename, "rb") as model_file:
            buf = model_file.read()
        model = Model.GetRootAsModel(buf, 0)
        LOG.info("Importing TFLITE model version %s", model.Version())
        check(model.Version() == 3, "Only support version 3 graphs at present")
        if model.SubgraphsLength() > 1:
            LOG.warning(
                "nntool only supports one subgraph. There may be errors loading this graph."
            )
        G = NNGraph(model=model,
                    filename=filename,
                    name=opts.get('name'),
                    constant_store=ConstantStore())
        if opts.get('load_quantization'):
            G.quantization = QuantizationSet()
            G.has_quantized_parameters = True
            G.quantization.schemes_present.add('SQ8')

        self._import_tflite_graph(G, model, opts)
        clean_dangling_nodes(G)
        fix_split_in_edges(G)
        MatchDuplicateConstants().match(G)
        G.add_dimensions()
        remove_concats(G)
        if opts['remove_quantize_ops']:
            RemoveQuantizeOperators().match(G)
            # dimensions are recomputed after the quantize operators are matched
            G.add_dimensions()

        if opts.get('load_quantization'):
            # get rid of qrecs on nodes that were not used; collect first so
            # the set is not modified while it is being iterated
            stale_ids = [
                nid for nid in G.quantization if nid.node_name not in G
            ]
            for nid in stale_ids:
                del G.quantization[nid]
            nodes_with_bad_quantization = self.find_nodes_with_bad_quantization(
                G)
            quantizer = UnifiedQuantizer.from_quantized_graph(G)
            # check for quantization problems
            # 1) need to force softmax/Sigmoid input to POW2 quantization
            # 2) need to check that all concats and splits have same input and
            #    output quantization
            # 3) Need to check that all nodes have qrecs and that they are consistent
            nodes_with_bad_quantization |= set(
                G.nodes(node_classes=(ConcatParameters, SoftMaxParameters,
                                      SplitParameters,
                                      SigmoidActivationParameters)))
            G.quantization = quantizer.quantize(
                G, start_nodes=nodes_with_bad_quantization)
            G.add_dimensions()

        return G
示例#3
0
 def quantize(self, G: NNGraph) -> OrderedDict:
     '''Quantize graph G and return the records from quantize_forward.

     A graph that already has quantized parameters is dequantized and has
     its quantization state cleared first. This quantizer is stored in the
     returned records under the ``'__quantizer'`` key.
     '''
     already_quantized = G.has_quantized_parameters
     if already_quantized:
         # Undo the existing quantization before producing a new one.
         self.dequantize(G)
         G.has_quantized_parameters = False
         G.quantization = None

     edge_records = {}
     target_dtype = WIDTH_TO_DTYPE[self._force_width]
     records = self.quantize_forward(G, edge_records, target_dtype)
     records['__quantizer'] = self

     G.graph_identity.quantization_type = 'SQ8'
     return records
示例#4
0
文件: tflite.py 项目: brupa9/gap_sdk
    def create_graph(self, filename, opts):
        """Import the TFLite model stored in *filename* into an NNGraph.

        Args:
            filename: Path of the model file; it is read in binary mode.
            opts: Importer options. They are normalised through
                ``self.get_opts`` before use. The keys read here are
                ``'name'``, ``'load_quantization'`` and
                ``'remove_quantize_ops'``.

        Returns:
            The NNGraph built from the model. When ``'load_quantization'`` is
            set, quantization records whose node is no longer in the graph
            have been removed.
        """
        opts = self.get_opts(opts)
        self._name_cache = {}
        add_sys_path(os.path.dirname(__file__))
        # Read the whole file inside a context manager so the handle is
        # closed even if reading fails.
        with open(filename, "rb") as model_file:
            buf = model_file.read()
        model = Model.GetRootAsModel(buf, 0)
        LOG.info("Importing TFLITE model version %s", model.Version())
        check(model.Version() == 3, "Only support version 3 graphs at present")
        if model.SubgraphsLength() > 1:
            LOG.warning("nntool only supports one subgraph. There may be errors loading this graph.")
        G = NNGraph(model=model, filename=filename, name=opts.get('name'),
                    constant_store=ConstantStore())
        if opts.get('load_quantization'):
            G.quantization = QuantizationSet()
            G.has_quantized_parameters = True
            G.graph_identity.quantization_types.add('SQ8')

        # Only the subgraph at index 0 is imported.
        self._import_tflite_graph(G, TFLiteGraph.from_model(model, 0), opts)
        clean_dangling_nodes(G)
        fix_split_in_edges(G)
        MatchDuplicateConstants().match(G)
        G.add_dimensions()
        remove_concats(G)
        if opts['remove_quantize_ops']:
            RemoveQuantizeOperators().match(G)
            # dimensions are recomputed after the quantize operators are matched
            G.add_dimensions()

        if opts.get('load_quantization'):
            # get rid of qrecs on nodes that were not used; collect first so
            # the set is not modified while it is being iterated
            stale_ids = [
                nid for nid in G.quantization if nid.node_name not in G
            ]
            for nid in stale_ids:
                del G.quantization[nid]

        return G