def test_get_predefined_config_path_domain_failure(self) -> None:
    """Test that an unsupported domain is rejected for a valid framework.

    The test expects ``get_predefined_config_path`` to raise when asked for
    the ``object_detection`` domain of the ``onnxrt`` framework, i.e. a
    framework/domain pair that is assumed to have no predefined config.
    """
    with self.assertRaises(Exception):
        get_predefined_config_path(
            framework="onnxrt",
            domain="object_detection",
        )
def test_get_predefined_config_path_framework_failure(self) -> None:
    """Test that an unknown framework name is rejected.

    ``"onnx"`` is passed instead of the framework name ``"onnxrt"``, so
    ``get_predefined_config_path`` is expected to raise even though
    ``image_recognition`` is used as the domain.
    """
    with self.assertRaises(Exception):
        get_predefined_config_path(
            framework="onnx",
            domain="image_recognition",
        )
def get_predefined_configuration(
    data: Dict[str, Any],
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """Get configuration.

    Load the predefined config template that matches the model's framework
    and the requested domain.

    Args:
        data: request payload; ``model_path`` and ``domain`` are read.

    Returns:
        Dict with the serialized config plus the detected framework, the
        model name (file stem) and the domain.

    Raises:
        ClientErrorException: if the model file does not exist, the domain is
            empty/missing, or no framework can be detected for the model.
    """
    from lpot.ux.utils.utils import get_framework_from_path, get_predefined_config_path
    from lpot.ux.utils.workload.config import Config

    path_to_model = data.get("model_path", "")
    if not os.path.isfile(path_to_model):
        raise ClientErrorException(
            f"Could not find model in specified path: {path_to_model}.",
        )
    name = Path(path_to_model).stem

    model_domain = data.get("domain")
    if not model_domain:
        raise ClientErrorException("Domain is not defined!")

    model_framework = get_framework_from_path(path_to_model)
    if model_framework is None:
        raise ClientErrorException(
            f"Could not find framework for specified model {name} in path {path_to_model}.",
        )

    # The Config instance is created before the template path is resolved,
    # mirroring the original statement order.
    predefined = Config()
    predefined.load(get_predefined_config_path(model_framework, model_domain))

    return {
        "config": predefined.serialize(),
        "framework": model_framework,
        "name": name,
        "domain": model_domain,
    }
def _assert_predefined_config_path(self, framework: str, domain: str) -> None:
    """Assert predefined config path.

    Check that ``get_predefined_config_path(framework, domain)`` resolves to
    ``configs/predefined_configs/<framework>/<domain>.yaml`` under the
    directory of the module defining that function, and that the file exists.
    """
    actual_path = get_predefined_config_path(framework, domain)

    module_dir = os.path.abspath(
        os.path.dirname(inspect.getfile(get_predefined_config_path)),
    )
    expected_path = os.path.join(
        module_dir,
        "configs",
        "predefined_configs",
        f"{framework}",
        f"{domain}.yaml",
    )

    self.assertEqual(actual_path, expected_path)
    self.assertEqual(os.path.isfile(actual_path), True)
def __init__(self, data: Dict[str, Any]):
    """Initialize Workload class.

    Args:
        data: workload parameters. Required: ``id``, ``model_path``,
            ``domain``. Optional: ``framework``, ``config_path``,
            ``workspace_path``, ``workload_path``, ``eval_dataset_path``,
            ``accuracy_goal``; ``set_dataset_paths`` may read further keys.

    Raises:
        ClientErrorException: when the id, model path or domain is missing,
            or when a dataset location or the model file does not exist.
    """
    super().__init__()
    self.config: Config = Config()

    self.id: str = str(data.get("id", ""))
    if not self.id:
        raise ClientErrorException("Workload ID not specified.")

    self.model_path: str = data.get("model_path", "")
    if not self.model_path:
        raise ClientErrorException("Model path is not defined!")

    self.model_name = Path(self.model_path).stem

    self.domain: str = data.get("domain", None)
    if not self.domain:
        raise ClientErrorException("Domain is not defined!")

    # ``dict.get(key, default)`` evaluates ``default`` eagerly, so the
    # framework detection and predefined-config lookup used to run (and
    # ``get_predefined_config_path`` can raise for unsupported pairs) even
    # when the caller supplied the value. Only compute them when missing.
    self.framework: str = (
        data["framework"]
        if "framework" in data
        else get_framework_from_path(self.model_path)
    )
    self.predefined_config_path = (
        data["config_path"]
        if "config_path" in data
        else get_predefined_config_path(self.framework, self.domain)
    )

    self.workspace_path = data.get(
        "workspace_path",
        os.path.dirname(self.model_path),
    )
    self.workload_path = data.get(
        "workload_path",
        os.path.join(
            self.workspace_path,
            "workloads",
            f"{self.model_name}_{self.id}",
        ),
    )
    self.set_workspace()

    self.config_name = "config.yaml"
    self.config_path = os.path.join(
        self.workload_path,
        self.config_name,
    )

    # Output model keeps the input's extension, with an "_int8" name suffix.
    model_output_name = (
        self.model_name + "_int8." + get_file_extension(self.model_path)
    )
    self.model_output_path = os.path.join(
        self.workload_path,
        model_output_name,
    )

    self.eval_dataset_path: str = data.get("eval_dataset_path", "")
    # NOTE(review): the calibration path is also seeded from
    # "eval_dataset_path"; set_dataset_paths() may override it — confirm
    # this default is intended.
    self.calib_dataset_path: str = data.get("eval_dataset_path", "")
    self.set_dataset_paths(data)

    # "no_dataset_location" is accepted as a placeholder that skips the
    # existence check.
    for dataset_path in [self.eval_dataset_path, self.calib_dataset_path]:
        if dataset_path != "no_dataset_location" and not os.path.exists(
            dataset_path,
        ):
            raise ClientErrorException(
                f'Could not found dataset in specified location: "{dataset_path}".',
            )

    if not os.path.isfile(self.model_path):
        raise ClientErrorException(
            f'Could not found model in specified location: "{self.model_path}".',
        )

    self.accuracy_goal: float = data.get("accuracy_goal", 0.01)

    # Prefer a config already present in the workload directory; otherwise
    # start from the predefined template.
    if not os.path.isfile(self.config_path):
        self.config.load(self.predefined_config_path)
    else:
        self.config.load(self.config_path)

    self.config.model.name = self.model_name
    self.config.set_evaluation_dataset_path(self.eval_dataset_path)
    self.config.set_quantization_dataset_path(self.calib_dataset_path)
    self.config.set_workspace(self.workload_path)
    self.config.set_accuracy_goal(self.accuracy_goal)
def __init__(self, data: Dict[str, Any]):
    """Initialize Workload class.

    Args:
        data: workload parameters. Required: ``id``, ``model_path``,
            ``domain``. Optional keys read here include ``framework``,
            ``workspace_path``, ``workload_path``, ``eval_dataset_path``,
            ``accuracy_goal``, ``config_path``, ``precision`` /
            ``output_precision``, ``tune``, ``inputs`` / ``input_nodes`` and
            ``outputs`` / ``output_nodes``; helper methods may read more.

    Raises:
        ClientErrorException: when the id, model path or domain is missing,
            or when a dataset location or the model cannot be found.
    """
    super().__init__()
    self.config: Config = Config()

    self.id: str = str(data.get("id", ""))
    if not self.id:
        raise ClientErrorException("Workload ID not specified.")

    self.model_path: str = data.get("model_path", "")
    if not self.model_path:
        raise ClientErrorException("Model path is not defined!")

    self.model_name = Path(self.model_path).stem

    self.domain: str = data.get("domain", None)
    if not self.domain:
        raise ClientErrorException("Domain is not defined!")

    # Detect the framework only when the caller did not provide one;
    # ``dict.get(key, default)`` would evaluate the detection eagerly.
    self.framework: str = (
        data["framework"]
        if "framework" in data
        else get_framework_from_path(self.model_path)
    )

    self.workspace_path = data.get(
        "workspace_path",
        os.path.dirname(self.model_path),
    )
    self.workload_path = data.get(
        "workload_path",
        os.path.join(
            self.workspace_path,
            "workloads",
            f"{self.model_name}_{self.id}",
        ),
    )
    self.set_workspace()

    self.eval_dataset_path: str = data.get("eval_dataset_path", "")
    # NOTE(review): the calibration path is also seeded from
    # "eval_dataset_path"; set_dataset_paths() may override it — confirm
    # this default is intended.
    self.calib_dataset_path: str = data.get("eval_dataset_path", "")
    self.set_dataset_paths(data)

    # "no_dataset_location" is accepted as a placeholder that skips the
    # existence check.
    for dataset_path in [self.eval_dataset_path, self.calib_dataset_path]:
        if dataset_path != "no_dataset_location" and not os.path.exists(
            dataset_path,
        ):
            raise ClientErrorException(
                f'Could not found dataset in specified location: "{dataset_path}".',
            )

    if not ModelRepository.is_model_path(self.model_path):
        raise ClientErrorException(
            f'Could not found model in specified location: "{self.model_path}".',
        )

    self.accuracy_goal: float = data.get("accuracy_goal", 0.01)

    self.config_name = "config.yaml"
    # Look up the predefined template only when no explicit "config_path" is
    # given: get_predefined_config_path can raise for unsupported
    # framework/domain pairs, and a caller-supplied path makes it unnecessary.
    self.predefined_config_path = (
        data["config_path"]
        if "config_path" in data
        else get_predefined_config_path(self.framework, self.domain)
    )
    self.config_path = os.path.join(
        self.workload_path,
        self.config_name,
    )

    self.input_precision = Precisions.FP32  # TODO: Detect input model precision
    self.output_precision = data.get("precision", data.get("output_precision"))

    self.mode = self.get_optimization_mode()
    self.tune = data.get("tune", self.is_tuning_enabled(data))

    self.initialize_config(data)

    self.input_nodes: Optional[str] = data.get("inputs", data.get("input_nodes"))
    self.output_nodes: Optional[str] = data.get("outputs", data.get("output_nodes"))

    # NOTE(review): model_output_name is not assigned in this method;
    # presumably a property or set by initialize_config() — confirm.
    self.model_output_path = os.path.join(
        self.workload_path,
        self.model_output_name,
    )

    self.version = "2.0"