Example #1
File: model.py Project: vuiseng9/lpot
    def __init__(self, model, framework_specific_info={}, **kwargs):
        # Record the framework metadata and pull the (possibly empty)
        # tensor-name hints out of it.
        self.framework_specific_info = framework_specific_info
        input_tensor_names = deep_get(framework_specific_info,
                                      'input_tensor_names', [])
        output_tensor_names = deep_get(framework_specific_info,
                                       'output_tensor_names', [])
        self.workspace_path = deep_get(framework_specific_info,
                                       'workspace_path', './')
        # Keep a pristine copy of kwargs before injecting the model name
        # used for session creation.
        self.kwargs = copy.deepcopy(kwargs)
        kwargs.update({'name': deep_get(framework_specific_info, 'name')})
        self._model = model
        # Build the session and resolve the actual input/output tensor
        # names, then write the resolved names back into
        # framework_specific_info.
        self.sess, self._input_tensor_names, self._output_tensor_names = \
            create_session_with_input_output(
                model, input_tensor_names, output_tensor_names, **kwargs)

        deep_set(framework_specific_info, 'input_tensor_names',
                 self._input_tensor_names)
        deep_set(framework_specific_info, 'output_tensor_names',
                 self._output_tensor_names)

        # sess_global_initialize(self.sess)
        # Keep a handle to the dataset-iterator initializer op, if present.
        self.iter_op = None
        if any(node.op == 'MakeIterator'
               for node in self.sess.graph.as_graph_def().node):
            self.iter_op = self.sess.graph.get_operation_by_name(
                'MakeIterator')

        tf.compat.v1.get_variable_scope().reuse_variables()
        self._graph_info = {}
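
A minimal usage sketch. The class name TensorflowModel is taken from the call site in Example #3 (and the model.py file name); the path, model name, and tensor names are placeholders:

    fsi = {
        'name': 'resnet50',
        'input_tensor_names': ['input:0'],
        'output_tensor_names': ['softmax:0'],
        'workspace_path': './workspace',
    }
    # Assumes create_session_with_input_output accepts a frozen-graph path
    # as the model argument.
    model = TensorflowModel('frozen_model.pb', fsi)
    print(model.sess)  # a live tf.compat.v1.Session over the loaded graph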
Example #2
    def __init__(self,
                 input_graph,
                 output_graph,
                 inputs=[],
                 outputs=[],
                 qt_config={},
                 int8_sequences={},
                 fp32_ops=[],
                 bf16_ops=[],
                 data_loader=None):
        """Convert graph.

        :param input_graph: input graph pb file.
        :param output_graph: output graph pb file. If set, output directory should be exist.
        :param inputs: input nodes' names.
        :param outputs: output nodes' names.
        :param qt_config: quantization configs, including interation and op-wise quant config
        :param fp32_ops: fall back to fp32 dtype op list
        :param bf16_ops: fall back to bf16 dtype op list
        :param data_loader: for calibration phase used dataloader
        """
        # Logger initialization
        self.logger = logging.getLogger()
        self.debug = self.logger.level == logging.DEBUG

        # Outputs may carry a ':N' suffix; strip it to get the raw node names.
        self.output_node_names = list(set([x.split(":")[0] for x in outputs]))
        self.input_node_names = list(set([x.split(":")[0] for x in inputs]))
        # For lpot, input_graph is a Graph object rather than a graph file path.
        self.input_graph = get_graph_def(input_graph, self.output_node_names)
        # Protect the dataset-iterator initializer from being pruned away.
        if 'MakeIterator' in [node.op for node in self.input_graph.node]:
            self.output_node_names.append('MakeIterator')
        self.output_graph = output_graph
        self.input_tensor_names = inputs
        self.output_tensor_names = outputs

        # quantize specific config
        self.calib_iteration = qt_config['calib_iteration']
        self.op_wise_config = qt_config['op_wise_config']
        self.advance_config = deep_get(qt_config, 'advance')
        self.device = qt_config.get('device', 'cpu')
        self.int8_sequences = int8_sequences
        self.fp32_ops = fp32_ops
        self.bf16_ops = bf16_ops

        self._calibration_data = []
        self._fp32_print_data = []
        self.data_loader = data_loader
        self._check_tf_version()
        self._check_args()
        self._gen_tmp_filenames()
        self._kl_op_dict = {}
        self._kl_keys = []
        self._print_node_mapping = {}
        # Ops configured to calibrate with the KL-divergence algorithm.
        self._enable_kl_op_names = [
            k for k in self.op_wise_config if self.op_wise_config[k][1] == 'kl'
        ]
        self._fp32_origin_graph = copy.deepcopy(self.input_graph)
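
Note that inputs/outputs take tensor names (with an ':N' output-index suffix), while the graph-level passes need bare node names; the constructor derives the latter by stripping the suffix and de-duplicating. A quick illustration (names hypothetical):

    outputs = ['logits:0', 'logits:1', 'probs:0']
    output_node_names = list(set(x.split(':')[0] for x in outputs))
    # -> ['logits', 'probs'] in some order; set() drops the duplicate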
Example #3
    def __init__(self,
                 model,
                 qt_config={},
                 recipes={},
                 int8_sequences={},
                 fp32_ops=[],
                 bf16_ops=[],
                 data_loader=None,
                 fake_quant=False):
        """Convert graph.

        :param model: input tensorflow model.
        :param qt_config: quantization configs, including interation and op-wise quant config
        :param fp32_ops: fall back to fp32 dtype op list
        :param bf16_ops: fall back to bf16 dtype op list
        :param data_loader: for calibration phase used dataloader
        :param fake_quant: for quantization-aware training model conversion to default model
        """
        # Logger initialization
        self.logger = logging.getLogger()
        self.debug = self.logger.level == logging.DEBUG
        self.model = model
        # quantize specific config
        # A fake_quant (QAT) model already carries quantization ranges,
        # so no calibration iterations are needed.
        self.calib_iteration = 0 if fake_quant else qt_config['calib_iteration']
        self.op_wise_config = qt_config['op_wise_config']
        self.advance_config = deep_get(qt_config, 'advance')
        self.device = qt_config.get('device', 'cpu')
        self.int8_sequences = int8_sequences
        self.fp32_ops = fp32_ops
        self.bf16_ops = bf16_ops
        self.recipes = recipes
        self.fake_quant = fake_quant
        self.quantized_node_info = []
        self._calibration_data = []
        self._fp32_print_data = []
        self.data_loader = data_loader
        self._check_tf_version()
        self._check_args()
        self._gen_tmp_filenames()
        self._kl_op_dict = {}
        self._kl_keys = []
        self._print_node_mapping = {}
        self._enable_kl_op_names = [
            k for k in self.op_wise_config if self.op_wise_config[k][1] == 'kl'
        ]
        # Keep an untouched fp32 copy of the model.
        self._fp32_model = TensorflowModel(self.model._model,
                                           self.model.framework_specific_info,
                                           **self.model.kwargs)
        self._fp32_model.graph_def = self.model.graph_def

        # A second, independent wrapper used for calibration sampling.
        self._sampling_model = TensorflowModel(
            self.model._model, self.model.framework_specific_info,
            **self.model.kwargs)
        self._sampling_model.graph_def = self.model.graph_def

        # All later transformations operate on a deep copy, leaving
        # self.model.graph_def untouched.
        self._tmp_graph_def = copy.deepcopy(self.model.graph_def)
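
A minimal instantiation sketch, assuming the enclosing class is lpot's GraphConverter and that tf_model is a TensorflowModel like the one built in Example #1 (all config values hypothetical):

    converter = GraphConverter(
        model=tf_model,                     # a TensorflowModel instance
        qt_config={'calib_iteration': 100,  # hypothetical values
                   'op_wise_config': {},
                   'device': 'cpu'},
        int8_sequences=int8_sequences,      # assumed fusion patterns
        data_loader=calib_loader,           # yields calibration batches
    )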
Example #4
    def _fuse_requantize_with_fused_quantized_node(self):
        if self.fake_quant:
            # QAT path: freeze the FakeQuant ops before any fusion.
            self._tmp_graph_def = FreezeFakeQuantOpOptimizer(
                self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = FuseConvRequantizeTransformer(
            self._tmp_graph_def, self.device).do_transformation()

        if not self.fake_quant:
            self._tmp_graph_def = FuseMatMulRequantizeTransformer(
                self._tmp_graph_def).do_transformation()

            self._tmp_graph_def = FuseMatMulRequantizeDequantizeTransformer(
                self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = StripUnusedNodesOptimizer(
            self._tmp_graph_def, self._tmp_model.input_node_names,
            self._tmp_model.output_node_names).do_transformation()

        self._tmp_graph_def = RemoveTrainingNodesOptimizer(
            self._tmp_graph_def,
            protected_nodes=self._tmp_model.output_node_names
        ).do_transformation()

        self._tmp_graph_def = FoldBatchNormNodesOptimizer(
            self._tmp_graph_def).do_transformation()

        if self.recipes.get('scale_propagation_concat'):
            self._tmp_graph_def = RerangeQuantizedConcat(
                self._tmp_graph_def, self.device).do_transformation()

        self._tmp_graph_def = MetaInfoChangingMemOpOptimizer(
            self._tmp_graph_def).do_transformation()

        if self.advance_config is not None and \
           deep_get(self.advance_config, 'bias_correction') is not None:
            self._tmp_graph_def = BiasCorrection(
                self._tmp_graph_def, self.model.graph_def).do_transformation()

        # Preserve the original graph's function library.
        self._tmp_graph_def.library.CopyFrom(self.model.graph_def.library)

        self._tmp_model.graph_def = self._tmp_graph_def
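
Every pass above follows the same contract: wrap a GraphDef (plus any pass-specific arguments) and call do_transformation() to get a rewritten GraphDef back. A chain like this can be condensed with a small helper; apply_passes is a hypothetical name, shown only to make the pattern explicit:

    def apply_passes(graph_def, passes):
        # Each entry wraps a GraphDef in a transformer whose
        # do_transformation() returns the rewritten GraphDef.
        for make_pass in passes:
            graph_def = make_pass(graph_def).do_transformation()
        return graph_def

    # e.g. for the unconditional single-argument passes above:
    # graph_def = apply_passes(graph_def, [
    #     FoldBatchNormNodesOptimizer,
    #     MetaInfoChangingMemOpOptimizer,
    # ])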
Example #5
    def _fuse_requantize_with_fused_quantized_node(self):
        self._tmp_graph_def = FuseConvRequantizeTransformer(
            self._tmp_graph_def, self.device).do_transformation()

        self._tmp_graph_def = FuseMatMulRequantizeTransformer(
            self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = FuseMatMulRequantizeDequantizeTransformer(
            self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = StripUnusedNodesOptimizer(
            self._tmp_graph_def, self.input_node_names,
            self.output_node_names).do_transformation()

        self._tmp_graph_def = RemoveTrainingNodesOptimizer(
            self._tmp_graph_def,
            protected_nodes=self.output_node_names).do_transformation()

        self._tmp_graph_def = FoldBatchNormNodesOptimizer(
            self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = RerangeQuantizedConcat(
            self._tmp_graph_def, self.device).do_transformation()

        # Final common-subexpression-elimination pass.
        self._tmp_graph_def = PostCseOptimizer(
            self._tmp_graph_def).do_transformation()

        if self.advance_config is not None and \
           deep_get(self.advance_config, 'bias_correction') is not None:
            self._tmp_graph_def = BiasCorrection(
                self._tmp_graph_def, self.input_graph).do_transformation()

        self._tmp_graph_def.library.CopyFrom(self.input_graph.library)

        if self.debug:
            write_graph(self._tmp_graph_def, self.output_graph)
            self.logger.info('Converted graph file is saved to: %s',
                             self.output_graph)
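
Because self.debug mirrors the root logger's level (see the __init__ in Example #2), getting this variant to dump the converted graph only requires enabling debug logging before the converter is constructed:

    import logging

    # Root logger at DEBUG makes self.debug True, so the converted
    # GraphDef is written to output_graph at the end of the method.
    logging.getLogger().setLevel(logging.DEBUG)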