def download_config(self) -> None:
    """Validate the download request and start fetching the model's YAML config.

    Every one of ``request_id``, ``framework``, ``domain``, ``model`` and
    ``workspace_path`` must be truthy. If any is not, an error payload is
    posted to ``self.mq`` under the ``"download_finish"`` topic and
    ``ClientErrorException`` is raised.

    On success ``self.download_dir`` is set to
    ``<workspace_path>/examples/<framework>/<domain>/<model>``. The model's
    config entry is then handed to ``self.download_yaml_config``.

    Raises:
        ClientErrorException: when any required request attribute is falsy.
        Exception: when the framework/domain/model triple is not present in
            the model config.
    """
    required_fields = ("request_id", "framework", "domain", "model", "workspace_path")
    # Generator keeps the same left-to-right short-circuit as an `and` chain.
    if not all(getattr(self, field) for field in required_fields):
        message = "Missing request id, workspace path, framework, domain or model."
        error_payload = {
            "message": message,
            "code": 404,
            "id": self.request_id,
        }
        self.mq.post_error("download_finish", error_payload)
        raise ClientErrorException(message)

    config = load_model_config()
    framework_section = config.get(self.framework, {})
    domain_section = framework_section.get(self.domain, {})
    model_info = domain_section.get(self.model, None)
    if model_info is None:
        raise Exception(
            f"{self.framework} {self.domain} {self.model} is not supported.",
        )

    self.download_dir = os.path.join(
        self.workspace_path,
        "examples",
        self.framework,
        self.domain,
        self.model,
    )
    self.download_yaml_config(model_info)
def get_frameworks() -> List[dict]:
    """Return the frameworks offered by the model config.

    A top-level config key counts as a framework when it is not a
    ``__help__*`` entry and it also appears in ``framework_extensions``.

    Returns:
        ``{"name": <framework>, "help": <text>}`` dicts, in config order. The
        help text is read from the ``__help__<framework>`` key, defaulting to
        ``""`` when that key is absent.
    """
    config = load_model_config()
    return [
        {"name": name, "help": config.get(f"__help__{name}", "")}
        for name in config.keys()
        if not name.startswith("__help__") and name in framework_extensions.keys()
    ]
def get_available_models(
    workspace_path: Optional[str],
) -> List[Dict[str, Any]]:
    """Collect model-zoo examples compatible with the installed frameworks.

    For every framework in ``SUPPORTED_FRAMEWORKS`` whose version can be read
    via ``get_module_version``, walk the model config's
    ``framework -> domain -> model`` hierarchy. Each model whose
    ``framework_version`` requirement passes ``check_version`` for the
    installed version is kept. Non-dict entries (e.g. help strings) are
    skipped at both the domain and model level.

    Args:
        workspace_path: Passed through to ``get_model_zoo_config_path`` and
            ``get_model_zoo_model_path`` to build each entry's paths.

    Returns:
        A list of dicts with ``framework``, ``domain``, ``model``, ``yaml``
        and ``model_path`` keys, after ``validate_model_list`` has been run
        on it.
    """
    model_list: List[Dict[str, Any]] = []
    full_list = load_model_config()
    for framework in SUPPORTED_FRAMEWORKS:
        try:
            framework_version = get_module_version(framework)
        except Exception:
            # Best effort: a framework whose version cannot be read is treated
            # as not installed and contributes no models.
            log.debug(f"Framework {framework} not installed.")
            continue
        log.debug(f"{framework} version is {framework_version}")

        # A supported framework may have no model-zoo entry. Treat it as empty
        # instead of raising KeyError, matching the `.get(framework, {})`
        # lookups used elsewhere for this same config.
        framework_dict = full_list.get(framework, {})
        for domain, domain_dict in framework_dict.items():
            if not isinstance(domain_dict, dict):
                continue
            for model, model_dict in domain_dict.items():
                if not isinstance(model_dict, dict):
                    continue
                if not check_version(
                    framework_version,
                    model_dict.get("framework_version", []),
                ):
                    continue
                model_list.append(
                    {
                        "framework": framework,
                        "domain": domain,
                        "model": model,
                        "yaml": get_model_zoo_config_path(
                            workspace_path,
                            framework,
                            domain,
                            model,
                            model_dict,
                        ),
                        "model_path": get_model_zoo_model_path(
                            workspace_path,
                            framework,
                            domain,
                            model,
                            model_dict,
                        ),
                    },
                )
    validate_model_list(model_list)
    return model_list
def get_domains(self) -> List[Dict[str, Any]]:
    """List the domains configured for the framework in ``self.config``.

    Returns:
        ``{"name": <domain>, "help": <text>}`` dicts for every non-``__help__``
        key under the selected framework, in config order. The help text comes
        from the framework's ``__help__<domain>`` key, or ``""`` when absent.

    Raises:
        ClientErrorException: when ``self.config`` has no ``"framework"`` set.
    """
    framework = self.config.get("framework", None)
    if framework is None:
        raise ClientErrorException("Framework not set.")

    framework_section = load_model_config().get(framework, {})
    return [
        {
            "name": name,
            "help": framework_section.get(f"__help__{name}", ""),
        }
        for name in framework_section.keys()
        if not name.startswith("__help__")
    ]
def get_models(self) -> List[Dict[str, Any]]:
    """List the models configured for the framework/domain in ``self.config``.

    Returns:
        ``{"name": <model>, "help": <text>}`` dicts for every non-``__help__``
        key under the selected framework and domain, in config order. The help
        text comes from the domain's ``__help__<model>`` key, or ``""``.

    Raises:
        ClientErrorException: when ``"framework"`` or ``"domain"`` is missing
            from ``self.config`` (framework is checked first).
    """
    framework = self.config.get("framework", None)
    if framework is None:
        raise ClientErrorException("Framework not set.")
    domain = self.config.get("domain", None)
    if domain is None:
        raise ClientErrorException("Domain not set.")

    domain_models = load_model_config().get(framework, {}).get(domain, {})
    return [
        {"name": name, "help": domain_models.get(f"__help__{name}", "")}
        for name in domain_models.keys()
        if not name.startswith("__help__")
    ]
def download_model(self) -> None:
    """Look up the selected model's config entry and start downloading it.

    Sets ``self.download_dir`` to
    ``<workspace_path>/examples/<framework>/<domain>/<model>``. It then passes
    the model's config entry to ``self.download``.

    Raises:
        Exception: when the framework/domain/model triple is not present in
            the model config.
    """
    config = load_model_config()
    framework_section = config.get(self.framework, {})
    domain_section = framework_section.get(self.domain, {})
    model_info = domain_section.get(self.model, None)
    if model_info is None:
        raise Exception(
            f"{self.framework} {self.domain} {self.model} is not supported.",
        )

    target_dir = os.path.join(
        self.workspace_path,
        "examples",
        self.framework,
        self.domain,
        self.model,
    )
    self.download_dir = target_dir
    self.download(model_info)
def test_load_model_config(self) -> None:
    """Test that the models config loads as a non-empty dict."""
    result = load_model_config()
    self.assertIsInstance(result, dict)
    # The previous `assertIsNot(result, {})` compared identity with a freshly
    # built literal and could never fail. Check for actual content instead.
    self.assertNotEqual(result, {})