예제 #1
0
 def set_config(self,
                model_path,
                num_threads,
                mkldnn_cache_capacity,
                warmup_data=None,
                use_analysis=False,
                enable_ptq=False):
     """Create an ``AnalysisConfig`` for the model at ``model_path``.

     The CPU math library thread count is always set to ``num_threads``.
     The remaining settings are applied only on request.

     Args:
         model_path: Passed directly to the ``AnalysisConfig`` constructor.
         num_threads: Thread count for the CPU math library.
         mkldnn_cache_capacity: MKL-DNN cache capacity; used only when
             ``use_analysis`` is true.
         warmup_data: Data given to the quantizer via ``set_quant_data``;
             used only when both ``use_analysis`` and ``enable_ptq`` are
             true.
         use_analysis: When true, disable the GPU, turn on feed/fetch ops
             and IR optimization, and enable MKL-DNN.
         enable_ptq: When true (and ``use_analysis`` is true), insert
             ``fc_lstm_fuse_pass`` and enable the quantizer with a quant
             batch size of 1. Has no effect if ``use_analysis`` is false.

     Returns:
         The configured ``AnalysisConfig`` instance.
     """
     cfg = AnalysisConfig(model_path)
     cfg.set_cpu_math_library_num_threads(num_threads)

     # Nothing beyond the thread count is configured without analysis.
     if not use_analysis:
         return cfg

     cfg.disable_gpu()
     cfg.switch_use_feed_fetch_ops(True)
     cfg.switch_ir_optim(True)
     cfg.enable_mkldnn()
     cfg.set_mkldnn_cache_capacity(mkldnn_cache_capacity)

     if not enable_ptq:
         return cfg

     # To work properly, this pass must be added before fc_fuse_pass.
     cfg.pass_builder().insert_pass(5, "fc_lstm_fuse_pass")
     cfg.enable_quantizer()
     cfg.quantizer_config().set_quant_data(warmup_data)
     cfg.quantizer_config().set_quant_batch_size(1)
     return cfg
def set_config_ptq(model_path, warmup_data):
    """Build a quantizer-enabled ``AnalysisConfig`` for ``model_path``.

    If ``model_path`` contains a ``__model__`` file, the directory itself is
    handed to ``AnalysisConfig``; otherwise the ``model`` and ``params``
    files under ``model_path`` are passed separately.

    MKL-DNN cache capacity, thread count and the list of operators to
    quantize are read from the ``test_args`` object defined outside this
    function.

    Args:
        model_path: Directory holding the saved model.
        warmup_data: Data given to the quantizer via ``set_quant_data``.

    Returns:
        The configured ``AnalysisConfig`` instance.
    """
    if os.path.exists(os.path.join(model_path, '__model__')):
        config = AnalysisConfig(model_path)
    else:
        config = AnalysisConfig(model_path + '/model', model_path + '/params')
    config.switch_ir_optim(True)
    # This pass must be added before fc_fuse_pass to work properly
    config.pass_builder().insert_pass(5, "fc_lstm_fuse_pass")
    config.pass_builder().append_pass("fc_mkldnn_pass")

    config.enable_mkldnn()
    config.set_mkldnn_cache_capacity(test_args.mkldnn_cache_capacity)
    config.set_cpu_math_library_num_threads(test_args.num_threads)

    config.enable_quantizer()
    config.quantizer_config().set_quant_data(warmup_data)
    config.quantizer_config().set_quant_batch_size(1)

    # Parse the comma-separated operator list. Whitespace around names and
    # empty items (e.g. from "a, b," ) are dropped so that stray tokens such
    # as " b" or "" are never registered as op types. A missing/None option
    # is treated the same as an empty string.
    requested_ops = test_args.ops_to_quantize
    ops_to_quantize = set()
    if requested_ops:
        ops_to_quantize = {
            op_name.strip()
            for op_name in requested_ops.split(',')
            if op_name.strip()
        }
    config.quantizer_config().set_enabled_op_types(ops_to_quantize)

    return config