def set_config(
    self,
    model_path,
    num_threads,
    mkldnn_cache_capacity,
    warmup_data=None,
    use_analysis=False,
    enable_ptq=False,
):
    """Create the AnalysisConfig used to run the model at ``model_path``.

    Args:
        model_path: Passed directly to ``AnalysisConfig``.
        num_threads: Forwarded to ``set_cpu_math_library_num_threads``.
        mkldnn_cache_capacity: Forwarded to ``set_mkldnn_cache_capacity``;
            only applied when ``use_analysis`` is true.
        warmup_data: Handed to the quantizer via ``set_quant_data``; only
            read when post-training quantization is switched on.
        use_analysis: When true, disable GPU, enable feed/fetch ops,
            IR optimization and MKL-DNN.
        enable_ptq: When true (together with ``use_analysis``), also
            enable the quantizer with a quantization batch size of 1.

    Returns:
        The configured ``AnalysisConfig`` instance.
    """
    # NOTE(review): the PTQ setup is treated as nested under `use_analysis`;
    # confirm against the original formatting of this method.
    analysis_config = AnalysisConfig(model_path)
    analysis_config.set_cpu_math_library_num_threads(num_threads)
    if not use_analysis:
        return analysis_config

    analysis_config.disable_gpu()
    analysis_config.switch_use_feed_fetch_ops(True)
    analysis_config.switch_ir_optim(True)
    analysis_config.enable_mkldnn()
    analysis_config.set_mkldnn_cache_capacity(mkldnn_cache_capacity)
    if not enable_ptq:
        return analysis_config

    # fc_lstm_fuse_pass has to be registered ahead of fc_fuse_pass to work
    # properly, hence the explicit insert position instead of an append.
    analysis_config.pass_builder().insert_pass(5, "fc_lstm_fuse_pass")
    analysis_config.enable_quantizer()
    quantizer = analysis_config.quantizer_config()
    quantizer.set_quant_data(warmup_data)
    quantizer.set_quant_batch_size(1)
    return analysis_config
def set_config_ptq(model_path, warmup_data):
    """Build an AnalysisConfig with MKL-DNN post-training quantization enabled.

    If ``model_path`` contains a ``__model__`` file the directory itself is
    given to ``AnalysisConfig``; otherwise the ``model`` and ``params`` files
    inside ``model_path`` are passed as separate paths.

    Thread count, MKL-DNN cache capacity and the optional comma-separated
    list of op types to quantize are read from the module-level ``test_args``.

    Args:
        model_path: Directory holding the saved inference model.
        warmup_data: Data handed to the quantizer via ``set_quant_data``.

    Returns:
        The configured ``AnalysisConfig`` instance.
    """
    if os.path.exists(os.path.join(model_path, '__model__')):
        config = AnalysisConfig(model_path)
    else:
        # Build both paths the same way as the existence check above rather
        # than concatenating a hard-coded '/' separator.
        config = AnalysisConfig(
            os.path.join(model_path, 'model'),
            os.path.join(model_path, 'params'),
        )
    config.switch_ir_optim(True)

    # This pass must be added before fc_fuse_pass to work properly
    passes = config.pass_builder()
    passes.insert_pass(5, "fc_lstm_fuse_pass")
    passes.append_pass("fc_mkldnn_pass")

    config.enable_mkldnn()
    config.set_mkldnn_cache_capacity(test_args.mkldnn_cache_capacity)
    config.set_cpu_math_library_num_threads(test_args.num_threads)

    config.enable_quantizer()
    quantizer = config.quantizer_config()
    quantizer.set_quant_data(warmup_data)
    quantizer.set_quant_batch_size(1)

    # Tolerate a missing/empty option, whitespace around names ("fc, lstm")
    # and stray commas; an empty set is passed when no op types were given.
    raw_ops = test_args.ops_to_quantize or ''
    ops_to_quantize = {
        op.strip() for op in raw_ops.split(',') if op.strip()
    }
    quantizer.set_enabled_op_types(ops_to_quantize)
    return config