Code example #1
    def _freeze_requantization_ranges(self):
        """Freeze the requantization ranges recorded during calibration."""
        self._tmp_graph_def = freeze_max(self._tmp_graph_def,
                                         self._requant_min_max_log)
        self._tmp_graph_def = freeze_min(self._tmp_graph_def,
                                         self._requant_min_max_log)
        self._tmp_graph_def = freeze_requantization_range(
            self._tmp_graph_def, self._requant_min_max_log)
        if self.debug:
            write_graph(self._tmp_graph_def, self._int8_frozen_range_graph)
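This step consumes self._requant_min_max_log, a text log of the min/max values printed at runtime by the logging ops that code example #2 inserts; a sketch of how such a log can be produced follows that example.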
Code example #2
    def _insert_logging(self):
        """Insert logging ops that print min/max values at runtime."""
        int8_dynamic_range_graph_def = graph_pb2.GraphDef()
        int8_dynamic_range_graph_def.CopyFrom(self._tmp_graph_def)
        InsertLogging(self._tmp_graph_def,
                      ops=[
                          "RequantizationRange{}".format(
                              "PerChannel" if self.per_channel else "")
                      ],
                      message="__requant_min_max:").do_transformation()
        InsertLogging(self._tmp_graph_def, ops=["Min"],
                      message="__min:").do_transformation()
        InsertLogging(self._tmp_graph_def, ops=["Max"],
                      message="__max:").do_transformation()
        write_graph(self._tmp_graph_def, self._int8_logged_graph)
        # Restore the copy saved above so later transforms run on the
        # un-logged graph.
        self._tmp_graph_def.CopyFrom(int8_dynamic_range_graph_def)
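The logged graph written above only becomes useful after it is run on calibration data: the inserted ops print lines tagged __requant_min_max:, __min: and __max: to stderr, and code example #1 expects to find those lines in self._requant_min_max_log. Below is a minimal sketch of that intermediate step, assuming the calibration run is driven by an external shell command; run_calibration and its command-template argument are illustrative, not part of the original class.

import subprocess

def run_calibration(inference_cmd, logged_graph_path, log_path):
    """Run a calibration command on the logged graph and keep its stderr.

    inference_cmd is assumed to be a shell command template with a {graph}
    placeholder that loads the logged graph and runs a few batches of
    representative input data.
    """
    result = subprocess.run(inference_cmd.format(graph=logged_graph_path),
                            shell=True, capture_output=True, text=True)
    # The inserted logging ops write their messages to stderr; keep only
    # the lines the freeze_* helpers later parse.
    wanted = ("__requant_min_max:", "__min:", "__max:")
    with open(log_path, "w") as log_file:
        for line in result.stderr.splitlines():
            if any(tag in line for tag in wanted):
                log_file.write(line + "\n")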
Code example #3
    def _fuse_requantize_with_fused_quantized_conv(self):
        """Fuse Requantize ops into the fused quantized convolutions."""
        self._tmp_graph_def = fuse_quantized_conv_and_requantize(
            self._tmp_graph_def)
        # strip_unused_nodes with optimize_for_inference
        dtypes = self._get_dtypes(self._tmp_graph_def)
        # self._tmp_graph_def = optimize_for_inference(self._tmp_graph_def, self.inputs, self.outputs, dtypes, False)
        self._tmp_graph_def = StripUnusedNodes(self._tmp_graph_def,
                                               self.inputs, self.outputs,
                                               dtypes).do_transform()
        self._tmp_graph_def = graph_util.remove_training_nodes(
            self._tmp_graph_def, self.outputs)
        self._tmp_graph_def = FoldBatchNormNodes(
            self._tmp_graph_def).do_transform()
        # RerangeQuantizedConcat updates the graph in place, so its return
        # value is not reassigned.
        RerangeQuantizedConcat(self._tmp_graph_def).do_transformation()
        write_graph(self._tmp_graph_def, self.output_graph)
        logging.info('Converted graph file is saved to: %s', self.output_graph)
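This is the final rewrite in the sequence: once the Requantize ops are fused into the quantized convolutions and the graph is stripped and cleaned, the result written to self.output_graph is the converted model. The sketch after code example #5 shows where this call sits in the overall flow.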
Code example #4
    def _optimize_frozen_fp32_graph(self):
        """Optimize fp32 frozen graph."""

        self._tmp_graph_def = read_graph(self.input_graph,
                                         self.input_graph_binary_flag)
        dtypes = self._get_dtypes(self._tmp_graph_def)
        # self._tmp_graph_def = optimize_for_inference(self._tmp_graph_def, self.inputs, self.outputs, dtypes, False)
        self._tmp_graph_def = FuseColumnWiseMul(
            self._tmp_graph_def).do_transformation()
        self._tmp_graph_def = StripUnusedNodes(self._tmp_graph_def,
                                               self.inputs, self.outputs,
                                               dtypes).do_transform()
        self._tmp_graph_def = graph_util.remove_training_nodes(
            self._tmp_graph_def, self.outputs)
        self._tmp_graph_def = FoldBatchNormNodes(
            self._tmp_graph_def).do_transform()
        write_graph(self._tmp_graph_def, self._fp32_optimized_graph)
Code example #5
    def _quantize_graph(self):
        """quantize graph."""

        if not self._tmp_graph_def:
            self._tmp_graph_def = read_graph(self.input_graph,
                                             self.input_graph_binary_flag)

        # Import into a scratch graph first to validate the graph_def.
        g = ops.Graph()
        with g.as_default():
            importer.import_graph_def(self._tmp_graph_def)

        intel_quantizer = QuantizeGraphForIntel(
            self._tmp_graph_def,
            self.outputs,
            self.per_channel,
            excluded_ops=self.excluded_ops,
            excluded_nodes=self.excluded_nodes)
        self._tmp_graph_def = intel_quantizer.do_transform()

        if self.debug:
            write_graph(self._tmp_graph_def, self._int8_dynamic_range_graph)
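Taken together, the data flow between these methods implies an ordering: the frozen fp32 graph is optimized (#4) and quantized (#5), logging ops are inserted (#2), a calibration run fills self._requant_min_max_log, the ranges are frozen (#1), and requantization is fused into the final graph (#3). The following is a minimal driver sketch under that assumption; the convert name, the calibration_cmd attribute, and the run_calibration helper from the sketch after code example #2 are illustrative, not part of the original code.

    def convert(self):
        """Hypothetical driver chaining the methods shown above."""
        self._optimize_frozen_fp32_graph()  # code example #4
        self._quantize_graph()  # code example #5
        self._insert_logging()  # code example #2
        # Assumed calibration step: run the logged graph on representative
        # data and capture the printed min/max lines.
        run_calibration(self.calibration_cmd, self._int8_logged_graph,
                        self._requant_min_max_log)
        self._freeze_requantization_ranges()  # code example #1
        self._fuse_requantize_with_fused_quantized_conv()  # code example #3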