def test_write_config_file(self):
    """
    Round-trip a ModelConfig through write_config_to_file and
    create_from_file, then check that create_from_file rejects a path that
    does not exist and a path that is a regular file.

    Each MockModelConfig patch is stopped in a ``finally`` block so that a
    failing assertion cannot leave the patch active for later tests.
    """
    model_config = ModelConfig.create_from_dictionary(self._model_config)
    model_output_path = os.path.abspath('./model_config')

    # Write the model config to output, capturing the text passed to write().
    mock_model_config = MockModelConfig()
    mock_model_config.start()
    try:
        with patch('model_analyzer.triton.model.model_config.open',
                   mock_open()) as mocked_file:
            with patch('model_analyzer.triton.model.model_config.copy_tree',
                       MagicMock()):
                model_config.write_config_to_file(model_output_path,
                                                  '/mock/path', None)
            content = mocked_file().write.call_args.args[0]
    finally:
        mock_model_config.stop()

    # Read the captured content back and compare with the source dictionary.
    mock_model_config = MockModelConfig(content)
    mock_model_config.start()
    try:
        model_config_from_file = \
            ModelConfig.create_from_file(model_output_path)
        self.assertEqual(model_config_from_file.get_config(),
                         self._model_config)
    finally:
        mock_model_config.stop()

    # output path doesn't exist
    with patch('model_analyzer.triton.model.model_config.os.path.exists',
               MagicMock(return_value=False)):
        with self.assertRaises(TritonModelAnalyzerException):
            ModelConfig.create_from_file(model_output_path)

    # output path is a file. os.path.exists is forced to True as well, so
    # this case exercises the is-a-file check instead of possibly passing
    # only because the path happens to be absent on the real filesystem.
    with patch('model_analyzer.triton.model.model_config.os.path.exists',
               MagicMock(return_value=True)), \
            patch('model_analyzer.triton.model.model_config.os.path.isfile',
                  MagicMock(return_value=True)):
        with self.assertRaises(TritonModelAnalyzerException):
            ModelConfig.create_from_file(model_output_path)
def test_create_from_file(self):
    """
    Load a ModelConfig from a mocked protobuf file and check that its
    dictionary form equals the expected model config.

    The MockModelConfig patch is stopped in ``finally`` so that a failed
    assertion does not leak the patch into later tests.
    """
    mock_model_config = MockModelConfig(self._model_config_protobuf)
    mock_model_config.start()
    try:
        model_config = ModelConfig.create_from_file('/path/to/model_config')
        # assertEqual reports a dict diff on failure; assertTrue(a == b)
        # would only report "False is not true".
        self.assertEqual(model_config.get_config(), self._model_config)
    finally:
        mock_model_config.stop()
def _generate_run_configs(self):
    """
    Append one RunConfig per model-config sweep of ``self._model`` to
    ``self._run_configs``.

    For every sweep the base config is re-read from
    ``<model_repository>/<model name>`` and the sweep's key/value pairs
    overwrite the matching top-level entries. The resulting variant is
    renamed ``<model name>_i<index>``, where ``index`` is the sweep's
    position, so each variant can be profiled under a distinct name.
    """
    repository = self._analyzer_config['model_repository']
    model = self._model

    # Generate all the sweeps for a given parameter.
    sweeps = self._generate_model_config_combinations(
        model.model_config_parameters())

    for index, sweep in enumerate(sweeps):
        base_config = ModelConfig.create_from_file(
            f'{repository}/{model.model_name()}')
        merged = base_config.get_config()
        merged.update(sweep)
        variant = ModelConfig.create_from_dictionary(merged)

        # The same model name cannot be reused for different configurations
        # while profiling, so the variant gets a per-sweep suffix.
        variant_name = f'{model.model_name()}_i{index}'
        variant.set_field('name', variant_name)

        perf_configs = self._generate_perf_config_for_model(
            variant_name, model)
        self._run_configs.append(
            RunConfig(model.model_name(), variant, perf_configs))
def generate_run_config_for_model_sweep(self, model, model_sweep):
    """
    Append run configs for ``model`` under a single model-config sweep to
    ``self._run_configs``, one RunConfig per generated perf config.

    In 'remote' launch mode the model config is fetched through the Triton
    API and ``model_sweep`` is not used. Otherwise the config is read from
    ``<model_repository>/<model name>``. Non-None sweep values are then
    applied on top of it, and the result is renamed
    ``<model name>_i<index>`` before profiling.

    Parameters
    ----------
    model : ConfigModel
        The model for which a run config is being generated
    model_sweep: dict
        Model config parameters; entries whose value is None are skipped.
        May itself be None, in which case the config is used unchanged.
    """
    settings = self._analyzer_config
    repository = settings['model_repository']
    retries = settings['max_retries']
    is_remote = settings['triton_launch_mode'] == 'remote'

    if is_remote:
        remote_config = ModelConfig.create_from_triton_api(
            self._client, model.model_name(), retries)
        for perf_config in self._generate_perf_config_for_model(
                model.model_name(), model):
            # Add the new run config.
            self._run_configs.append(
                RunConfig(model.model_name(), remote_config, perf_config))
        return

    model_config = ModelConfig.create_from_file(
        f'{repository}/{model.model_name()}')
    if model_sweep is not None:
        merged = model_config.get_config()
        merged.update({key: value
                       for key, value in model_sweep.items()
                       if value is not None})
        model_config = ModelConfig.create_from_dictionary(merged)

    # If an identical config was recorded before, reuse its list position as
    # the index. Otherwise record this config and take the next counter value.
    candidate = model_config.get_config()
    try:
        config_index = self._model_configs.index(candidate)
    except ValueError:
        config_index = self._model_name_index
        self._model_configs.append(candidate)
        self._model_name_index += 1

    # Different configurations cannot share a model name while profiling,
    # so the original name is suffixed with _i<config_index>.
    tmp_name = f'{model.model_name()}_i{config_index}'
    model_config.set_field('name', tmp_name)

    for perf_config in self._generate_perf_config_for_model(tmp_name, model):
        self._run_configs.append(
            RunConfig(model.model_name(), model_config, perf_config))
def _get_base_model_config_dict(self):
    """
    Return the base model's config, loaded from
    ``<self._model_repository>/<self._base_model_name>``, as a dict
    (via ModelConfig.get_config()).
    """
    base_model_path = f'{self._model_repository}/{self._base_model_name}'
    return ModelConfig.create_from_file(base_model_path).get_config()