def _get_analysis_config(self, use_gpu=False, use_trt=False, use_mkldnn=False):
    '''
    Build and return a brand-new AnalysisConfig for the model at self.path.

    Every config starts from the same CPU baseline:
    - GPU disabled;
    - input names specified;
    - IR optimisation on;
    - feed/fetch ops off.

    The flags then select one backend:
      * use_gpu    -> GPU execution via enable_use_gpu(100, 0); in addition,
                      use_trt turns on the TensorRT engine, configured from
                      self.trt_parameters. Optional extras are the
                      inspector, dynamic-shape info and varseqlen.
      * use_mkldnn -> MKL-DNN, considered only when use_gpu is False;
                      bfloat16 is added when self.enable_mkldnn_bfloat16.
    use_trt has no effect unless use_gpu is also set.

    The config summary is printed before returning.
    '''
    config = AnalysisConfig(self.path)
    # CPU baseline shared by every backend.
    config.disable_gpu()
    config.switch_specify_input_names(True)
    config.switch_ir_optim(True)
    config.switch_use_feed_fetch_ops(False)

    if use_gpu:
        config.enable_use_gpu(100, 0)
        if use_trt:
            trt = self.trt_parameters
            engine_args = (
                trt.workspace_size,
                trt.max_batch_size,
                trt.min_subgraph_size,
                trt.precision,
                trt.use_static,
                trt.use_calib_mode,
            )
            config.enable_tensorrt_engine(*engine_args)

            if trt.use_inspector:
                config.enable_tensorrt_inspector()
                # Verify that the inspector switch actually took effect.
                self.assertTrue(
                    config.tensorrt_inspector_enabled(),
                    "The inspector option is not set correctly.")

            shapes = self.dynamic_shape_params
            if shapes:
                shape_info = (
                    shapes.min_input_shape,
                    shapes.max_input_shape,
                    shapes.optim_input_shape,
                    shapes.disable_trt_plugin_fp16,
                )
                config.set_trt_dynamic_shape_info(*shape_info)

            if self.enable_tensorrt_varseqlen:
                config.enable_tensorrt_varseqlen()
    elif use_mkldnn:
        config.enable_mkldnn()
        if self.enable_mkldnn_bfloat16:
            config.enable_mkldnn_bfloat16()

    print('config summary:', config.summary())
    return config
def _get_analysis_config(self, use_gpu=False, use_trt=False, use_mkldnn=False):
    '''
    Build and return a brand-new AnalysisConfig.

    The model is loaded from two files inside self.path, named "model" and
    "params". Every config starts from the same CPU baseline:
    - GPU disabled;
    - input names specified;
    - IR optimisation on;
    - feed/fetch ops off.

    The flags then select one backend:
      * use_gpu    -> GPU execution via enable_use_gpu(100, 0); in addition,
                      use_trt turns on the TensorRT engine, configured from
                      self.trt_parameters, plus dynamic-shape info when
                      self.dynamic_shape_params is set.
      * use_mkldnn -> MKL-DNN, considered only when use_gpu is False;
                      bfloat16 is added when self.enable_mkldnn_bfloat16.
    use_trt has no effect unless use_gpu is also set.
    '''
    model_file = os.path.join(self.path, "model")
    params_file = os.path.join(self.path, "params")
    config = AnalysisConfig(model_file, params_file)

    # CPU baseline shared by every backend.
    config.disable_gpu()
    config.switch_specify_input_names(True)
    config.switch_ir_optim(True)
    config.switch_use_feed_fetch_ops(False)

    if use_gpu:
        config.enable_use_gpu(100, 0)
        if use_trt:
            trt = self.trt_parameters
            engine_args = (
                trt.workspace_size,
                trt.max_batch_size,
                trt.min_subgraph_size,
                trt.precision,
                trt.use_static,
                trt.use_calib_mode,
            )
            config.enable_tensorrt_engine(*engine_args)

            shapes = self.dynamic_shape_params
            if shapes:
                shape_info = (
                    shapes.min_input_shape,
                    shapes.max_input_shape,
                    shapes.optim_input_shape,
                    shapes.disable_trt_plugin_fp16,
                )
                config.set_trt_dynamic_shape_info(*shape_info)
    elif use_mkldnn:
        config.enable_mkldnn()
        if self.enable_mkldnn_bfloat16:
            config.enable_mkldnn_bfloat16()

    return config