示例#1
0
    def test_write_config_file(self):
        """
        Write a model config, read it back through create_from_file and
        check the round trip, then check that create_from_file raises
        TritonModelAnalyzerException when the path checks are patched to
        report a missing path or a plain file.
        """
        model_config = ModelConfig.create_from_dictionary(self._model_config)
        model_output_path = os.path.abspath('./model_config')

        # Write the model config to output. The try/finally guarantees the
        # mock is stopped even if the write or the capture below raises, so
        # a failure here cannot leave the mock started for later tests.
        mock_model_config = MockModelConfig()
        mock_model_config.start()
        try:
            with patch('model_analyzer.triton.model.model_config.open',
                       mock_open()) as mocked_file:
                with patch(
                        'model_analyzer.triton.model.model_config.copy_tree',
                        MagicMock()):
                    model_config.write_config_to_file(model_output_path,
                                                      '/mock/path', None)
                # First positional argument of the last write() call.
                # Tuple indexing is used instead of `.args` because the
                # named attribute only exists on Python 3.8+.
                content = mocked_file().write.call_args[0][0]
        finally:
            mock_model_config.stop()

        # Read the config back, with the mock seeded with the captured
        # content, and compare it against the original dictionary.
        mock_model_config = MockModelConfig(content)
        mock_model_config.start()
        try:
            model_config_from_file = \
                ModelConfig.create_from_file(model_output_path)
            # assertEqual reports both values on failure, unlike
            # assertTrue(a == b).
            self.assertEqual(model_config_from_file.get_config(),
                             self._model_config)
        finally:
            mock_model_config.stop()

        # output path doesn't exist
        with patch('model_analyzer.triton.model.model_config.os.path.exists',
                   MagicMock(return_value=False)):
            with self.assertRaises(TritonModelAnalyzerException):
                ModelConfig.create_from_file(model_output_path)

        # output path is a file
        with patch('model_analyzer.triton.model.model_config.os.path.isfile',
                   MagicMock(return_value=True)):
            with self.assertRaises(TritonModelAnalyzerException):
                ModelConfig.create_from_file(model_output_path)
示例#2
0
 def test_create_from_file(self):
     """
     Check that ModelConfig.create_from_file yields the expected config
     dictionary while MockModelConfig is active with the test protobuf.
     """
     test_protobuf = self._model_config_protobuf
     mock_model_config = MockModelConfig(test_protobuf)
     mock_model_config.start()
     try:
         model_config = ModelConfig.create_from_file('/path/to/model_config')
         # assertEqual reports both values on failure, unlike
         # assertTrue(a == b).
         self.assertEqual(model_config.get_config(), self._model_config)
     finally:
         # Always stop the mock so a failed assertion cannot leave it
         # started for later tests.
         mock_model_config.stop()
    def _generate_run_configs(self):
        """
        Build one RunConfig per model config sweep of self._model and
        append each of them to self._run_configs.
        """

        repository_path = self._analyzer_config['model_repository']
        target_model = self._model
        sweep_parameters = target_model.model_config_parameters()

        # Generate all the sweeps for a given parameter
        sweeps = self._generate_model_config_combinations(sweep_parameters)

        for variant_index, sweep in enumerate(sweeps):
            # Load a fresh config for every sweep and overlay the sweep's
            # values on top of its dictionary form.
            config_dict = ModelConfig.create_from_file(
                f'{repository_path}/{target_model.model_name()}').get_config()
            config_dict.update(sweep.items())
            variant_config = ModelConfig.create_from_dictionary(config_dict)

            # Each configuration needs its own temporary name for
            # profiling: the original model name suffixed with
            # _i<config_index>, where the config index is the position
            # of this sweep among the generated alternatives.
            variant_name = f'{target_model.model_name()}_i{variant_index}'
            variant_config.set_field('name', variant_name)
            perf_configs = self._generate_perf_config_for_model(
                variant_name, target_model)

            # Add the new run config.
            run_config = RunConfig(target_model.model_name(), variant_config,
                                   perf_configs)
            self._run_configs.append(run_config)
示例#4
0
    def generate_run_config_for_model_sweep(self, model, model_sweep):
        """
        Append one RunConfig per generated perf config to
        self._run_configs for the given model and sweep.

        In 'remote' launch mode the model config is obtained through
        ModelConfig.create_from_triton_api and the sweep is not applied.
        Otherwise the config is loaded from the model repository, the
        non-None sweep values are overlaid, and the config is renamed to
        <model_name>_i<index>.

        Parameters
        ----------
        model : ConfigModel
            The model for which a run config is being generated
        model_sweep: dict
            Model config parameters
        """

        settings = self._analyzer_config
        repository_path = settings['model_repository']
        retry_limit = settings['max_retries']

        if settings['triton_launch_mode'] == 'remote':
            remote_config = ModelConfig.create_from_triton_api(
                self._client, model.model_name(), retry_limit)
            remote_perf_configs = self._generate_perf_config_for_model(
                model.model_name(), model)

            for remote_perf_config in remote_perf_configs:
                # Add the new run config.
                self._run_configs.append(
                    RunConfig(model.model_name(), remote_config,
                              remote_perf_config))
            return

        variant_config = ModelConfig.create_from_file(
            f'{repository_path}/{model.model_name()}')

        if model_sweep is not None:
            overlaid = variant_config.get_config()
            for field_name, field_value in model_sweep.items():
                # None entries leave the loaded value untouched.
                if field_value is None:
                    continue
                overlaid[field_name] = field_value
            variant_config = ModelConfig.create_from_dictionary(overlaid)

        # Reuse the index of an identical, previously seen config;
        # otherwise record this one under the next free index.
        variant_dict = variant_config.get_config()
        try:
            variant_index = self._model_configs.index(variant_dict)
        except ValueError:
            variant_index = self._model_name_index
            self._model_configs.append(variant_dict)
            self._model_name_index += 1

        # Each configuration needs its own temporary name for profiling:
        # the original model name suffixed with _i<config_index>, where
        # the config index is the index of the model config alternative.
        variant_name = f'{model.model_name()}_i{variant_index}'
        variant_config.set_field('name', variant_name)
        local_perf_configs = self._generate_perf_config_for_model(
            variant_name, model)

        for local_perf_config in local_perf_configs:
            self._run_configs.append(
                RunConfig(model.model_name(), variant_config,
                          local_perf_config))
 def _get_base_model_config_dict(self):
     """
     Load the base model's config via ModelConfig.create_from_file from
     <model_repository>/<base_model_name> and return the result of its
     get_config() call.
     """
     base_model_path = f'{self._model_repository}/{self._base_model_name}'
     return ModelConfig.create_from_file(base_model_path).get_config()