def get_predefined_configuration(
    data: Dict[str, Any],
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """Get configuration.

    Build the predefined configuration for the model described in ``data``.
    The framework is detected from the model path, and the predefined config
    for that framework/domain pair is loaded.

    Args:
        data: request payload; the keys ``model_path`` and ``domain`` are read.

    Returns:
        A dict with keys ``config`` (the serialized loaded Config),
        ``framework``, ``name`` (the model file's stem) and ``domain``.
        NOTE(review): the annotation also allows a list, but this
        implementation always returns a single dict.

    Raises:
        ClientErrorException: if ``model_path`` is not an existing file, if
            ``domain`` is missing or empty, or if no framework can be
            detected for the model path.
    """
    # Imported lazily, exactly as before, so import order/side effects match.
    from lpot.ux.utils.utils import (
        get_framework_from_path,
        get_predefined_config_path,
    )
    from lpot.ux.utils.workload.config import Config

    # --- validate the model path -------------------------------------------
    path_to_model = data.get("model_path", "")
    if not os.path.isfile(path_to_model):
        raise ClientErrorException(
            f"Could not find model in specified path: {path_to_model}.",
        )
    stem = Path(path_to_model).stem

    # --- validate the domain -------------------------------------------------
    requested_domain = data.get("domain")
    if not requested_domain:
        raise ClientErrorException("Domain is not defined!")

    # --- detect the framework ------------------------------------------------
    detected_framework = get_framework_from_path(path_to_model)
    if detected_framework is None:
        raise ClientErrorException(
            f"Could not find framework for specified model {stem} in path {path_to_model}.",
        )

    # --- load the predefined config for this framework/domain ---------------
    predefined = Config()
    predefined.load(get_predefined_config_path(detected_framework, requested_domain))

    return {
        "config": predefined.serialize(),
        "framework": detected_framework,
        "name": stem,
        "domain": requested_domain,
    }
def test_load(self) -> None:
    """Test load.

    Serialize the predefined config to YAML and serve it through a mocked
    ``open`` in the config module. Check that ``Config.load`` opens the
    given path exactly once and produces the same serialized result as
    building a ``Config`` directly from the same data.
    """
    yaml_path = "path to yaml file"
    loaded = Config()
    yaml_text = yaml.dump(self.predefined_config, sort_keys=False)

    # ``patch`` with an explicit replacement yields that same object, so the
    # mock can be asserted on directly.
    fake_open = mock_open(read_data=yaml_text)
    with patch("lpot.ux.utils.workload.config.open", fake_open):
        loaded.load(yaml_path)
        fake_open.assert_called_once_with(yaml_path)

    reference = Config(self.predefined_config)
    self.assertEqual(reference.serialize(), loaded.serialize())