def get_predefined_configuration(
    data: Dict[str, Any],
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """Return the predefined configuration for the model described by ``data``.

    Reads ``model_path`` and ``domain`` from ``data``, detects the framework
    from the model path, loads the predefined config for that
    framework/domain pair and returns it serialized together with the
    framework, model name (file stem) and domain.

    Raises ClientErrorException when the model file does not exist, when the
    domain is missing or empty, or when no framework can be detected.
    """
    from lpot.ux.utils.utils import get_framework_from_path, get_predefined_config_path
    from lpot.ux.utils.workload.config import Config

    path_to_model = data.get("model_path", "")
    if not os.path.isfile(path_to_model):
        raise ClientErrorException(
            f"Could not find model in specified path: {path_to_model}.",
        )
    name = Path(path_to_model).stem

    requested_domain = data.get("domain")
    if not requested_domain:
        raise ClientErrorException("Domain is not defined!")

    detected_framework = get_framework_from_path(path_to_model)
    if detected_framework is None:
        raise ClientErrorException(
            f"Could not find framework for specified model {name} in path {path_to_model}.",
        )

    predefined = Config()
    predefined.load(get_predefined_config_path(detected_framework, requested_domain))

    response = {
        "config": predefined.serialize(),
        "framework": detected_framework,
        "name": name,
        "domain": requested_domain,
    }
    return response
def test_get_tensorflow_framework_from_path(
    self,
    mocked_get_model_type: MagicMock,
) -> None:
    """Check that a ``.pb`` path reported as ``frozen_pb`` maps to tensorflow."""
    model_file = "/home/user/model.pb"
    mocked_get_model_type.return_value = "frozen_pb"

    detected_framework = get_framework_from_path(model_file)

    self.assertEqual(detected_framework, "tensorflow")
    # The model-type lookup must have been asked about exactly this path.
    mocked_get_model_type.assert_called_with(model_file)
def find(self, model_path: str) -> Reader:
    """Return a new graph model reader suited to the model at ``model_path``.

    The framework is detected from the path and looked up in
    ``self._framework_readers``; the entry found there is called with no
    arguments to produce the reader.

    Raises ClientErrorException when the framework cannot be detected or when
    no reader is registered for it.
    """
    detected_framework = get_framework_from_path(model_path)
    if detected_framework is None:
        raise ClientErrorException(
            f"Models of {model_path} type are not yet supported.",
        )

    reader_factory = self._framework_readers.get(detected_framework)
    if reader_factory is None:
        raise ClientErrorException(
            f"Models from {detected_framework} framework are not yet supported.",
        )

    return reader_factory()
def __init__(self, data: Dict[str, Any]):
    """Build a workload from a request payload.

    Resolves the model, workspace, workload and dataset locations from
    ``data``, then loads the workload configuration (the ``config.yaml``
    inside the workload directory when it exists, otherwise the predefined
    one) and applies the model name, dataset paths, workspace and accuracy
    goal to it.

    ``id``, ``model_path`` and ``domain`` are required keys of ``data``;
    ``framework``, ``config_path``, ``workspace_path``, ``workload_path``,
    ``eval_dataset_path`` and ``accuracy_goal`` are optional.

    Raises ClientErrorException when a required key is missing or empty, when
    a dataset location other than ``"no_dataset_location"`` does not exist,
    or when ``model_path`` is not an existing file.
    """
    super().__init__()
    self.config: Config = Config()

    self.id: str = str(data.get("id", ""))
    if not self.id:
        raise ClientErrorException("Workload ID not specified.")

    self.model_path: str = data.get("model_path", "")
    if not self.model_path:
        raise ClientErrorException("Model path is not defined!")
    self.model_name = Path(self.model_path).stem

    self.domain: str = data.get("domain")
    if not self.domain:
        raise ClientErrorException("Domain is not defined!")

    # Each fallback below is computed up front, whether or not the payload
    # provides an explicit value for the key.
    detected_framework = get_framework_from_path(self.model_path)
    self.framework: str = data.get("framework", detected_framework)

    default_config_path = get_predefined_config_path(self.framework, self.domain)
    self.predefined_config_path = data.get("config_path", default_config_path)

    model_directory = os.path.dirname(self.model_path)
    self.workspace_path = data.get("workspace_path", model_directory)

    default_workload_path = os.path.join(
        self.workspace_path,
        "workloads",
        f"{self.model_name}_{self.id}",
    )
    self.workload_path = data.get("workload_path", default_workload_path)
    self.set_workspace()

    self.config_name = "config.yaml"
    self.config_path = os.path.join(self.workload_path, self.config_name)

    quantized_model_name = self.model_name + "_int8." + get_file_extension(self.model_path)
    self.model_output_path = os.path.join(self.workload_path, quantized_model_name)

    # NOTE(review): the calibration path is seeded from the same
    # "eval_dataset_path" key as the evaluation path; set_dataset_paths(data)
    # runs right after - confirm the shared key is intended.
    initial_dataset_path = data.get("eval_dataset_path", "")
    self.eval_dataset_path: str = initial_dataset_path
    self.calib_dataset_path: str = initial_dataset_path
    self.set_dataset_paths(data)

    for location in (self.eval_dataset_path, self.calib_dataset_path):
        if location == "no_dataset_location" or os.path.exists(location):
            continue
        raise ClientErrorException(
            f'Could not found dataset in specified location: "{location}".',
        )

    if not os.path.isfile(self.model_path):
        raise ClientErrorException(
            f'Could not found model in specified location: "{self.model_path}".',
        )

    self.accuracy_goal: float = data.get("accuracy_goal", 0.01)

    # Prefer a config already stored in the workload directory.
    if os.path.isfile(self.config_path):
        config_source = self.config_path
    else:
        config_source = self.predefined_config_path
    self.config.load(config_source)

    self.config.model.name = self.model_name
    self.config.set_evaluation_dataset_path(self.eval_dataset_path)
    self.config.set_quantization_dataset_path(self.calib_dataset_path)
    self.config.set_workspace(self.workload_path)
    self.config.set_accuracy_goal(self.accuracy_goal)
def __init__(self, data: Dict[str, Any]):
    """Build a workload from a request payload.

    Resolves the model, workspace, workload and dataset locations from
    ``data``, validates them, records precision/tuning settings, initializes
    the configuration through ``initialize_config`` and finally stores the
    boundary node names, the output model path and the workload version.

    ``id``, ``model_path`` and ``domain`` are required keys of ``data``.
    ``framework``, ``workspace_path``, ``workload_path``,
    ``eval_dataset_path``, ``accuracy_goal``, ``config_path``, ``tune``,
    ``precision``/``output_precision``, ``inputs``/``input_nodes`` and
    ``outputs``/``output_nodes`` are optional.

    Raises ClientErrorException when a required key is missing or empty, when
    a dataset location other than ``"no_dataset_location"`` does not exist,
    or when ``ModelRepository.is_model_path`` rejects ``model_path``.
    """
    super().__init__()
    self.config: Config = Config()

    self.id: str = str(data.get("id", ""))
    if not self.id:
        raise ClientErrorException("Workload ID not specified.")

    self.model_path: str = data.get("model_path", "")
    if not self.model_path:
        raise ClientErrorException("Model path is not defined!")
    self.model_name = Path(self.model_path).stem

    self.domain: str = data.get("domain")
    if not self.domain:
        raise ClientErrorException("Domain is not defined!")

    # Each fallback below is computed up front, whether or not the payload
    # provides an explicit value for the key.
    detected_framework = get_framework_from_path(self.model_path)
    self.framework: str = data.get("framework", detected_framework)

    model_directory = os.path.dirname(self.model_path)
    self.workspace_path = data.get("workspace_path", model_directory)

    default_workload_path = os.path.join(
        self.workspace_path,
        "workloads",
        f"{self.model_name}_{self.id}",
    )
    self.workload_path = data.get("workload_path", default_workload_path)
    self.set_workspace()

    # NOTE(review): the calibration path is seeded from the same
    # "eval_dataset_path" key as the evaluation path; set_dataset_paths(data)
    # runs right after - confirm the shared key is intended.
    initial_dataset_path = data.get("eval_dataset_path", "")
    self.eval_dataset_path: str = initial_dataset_path
    self.calib_dataset_path: str = initial_dataset_path
    self.set_dataset_paths(data)

    for location in (self.eval_dataset_path, self.calib_dataset_path):
        if location == "no_dataset_location" or os.path.exists(location):
            continue
        raise ClientErrorException(
            f'Could not found dataset in specified location: "{location}".',
        )

    if not ModelRepository.is_model_path(self.model_path):
        raise ClientErrorException(
            f'Could not found model in specified location: "{self.model_path}".',
        )

    self.accuracy_goal: float = data.get("accuracy_goal", 0.01)

    self.config_name = "config.yaml"
    default_config_path = get_predefined_config_path(self.framework, self.domain)
    self.predefined_config_path = data.get("config_path", default_config_path)
    self.config_path = os.path.join(self.workload_path, self.config_name)

    self.input_precision = Precisions.FP32  # TODO: Detect input model precision
    # "precision" takes priority over "output_precision" when both are given.
    fallback_precision = data.get("output_precision")
    self.output_precision = data.get("precision", fallback_precision)

    self.mode = self.get_optimization_mode()
    tuning_default = self.is_tuning_enabled(data)
    self.tune = data.get("tune", tuning_default)

    self.initialize_config(data)

    # Short keys ("inputs"/"outputs") win over the long "*_nodes" spellings.
    fallback_inputs = data.get("input_nodes")
    self.input_nodes: Optional[str] = data.get("inputs", fallback_inputs)
    fallback_outputs = data.get("output_nodes")
    self.output_nodes: Optional[str] = data.get("outputs", fallback_outputs)

    self.model_output_path = os.path.join(self.workload_path, self.model_output_name)

    self.version = "2.0"
def test_get_unknown_framework_from_path(self) -> None:
    """Check that an unrecognized file extension yields no framework."""
    unsupported_model = "/home/user/model.some_extension"

    detected_framework = get_framework_from_path(unsupported_model)

    self.assertIsNone(detected_framework)
def test_get_onnx_framework_from_path(self) -> None:
    """Check that an ``.onnx`` model path is reported as ``onnxrt``."""
    onnx_model = "/home/user/model.onnx"

    detected_framework = get_framework_from_path(onnx_model)

    self.assertEqual(detected_framework, "onnxrt")
def get_boundary_nodes(data: Dict[str, Any]) -> None:
    """Find the boundary nodes of a model and publish them on the message queue.

    Nothing is returned. A ``boundary_nodes_start`` success message is posted
    once the request has been validated, and the outcome is posted under
    ``boundary_nodes_finish``: the boundary node data (extended with the
    request id, framework and framework version) on success, or an error
    payload with code 404 on failure.

    Reads ``id`` and ``model_path`` from ``data``. A missing id or path, a
    model path that is not an existing file, an undetectable framework and a
    framework module that fails ``check_module`` are all reported as errors
    on the queue; ClientErrorException does not propagate to the caller.
    """
    from lpot.ux.utils.utils import find_boundary_nodes

    request_id = str(data.get("id", ""))
    model_path = data.get("model_path", None)

    def _post_failure(message: str) -> None:
        """Post an error for this request on the finish subject."""
        mq.post_error(
            "boundary_nodes_finish",
            {"message": message, "code": 404, "id": request_id},
        )

    if not (request_id and model_path):
        _post_failure("Missing model path or request id.")
        return

    if not os.path.isfile(model_path):
        _post_failure("Could not found model in specified path.")
        return

    try:
        mq.post_success(
            "boundary_nodes_start",
            {"message": "started", "id": request_id},
        )

        framework = get_framework_from_path(model_path)
        if framework is None:
            supported_frameworks = ", ".join(framework_extensions.keys())
            raise ClientErrorException(
                f"Framework for specified model is not yet supported. "
                f"Supported frameworks are: {supported_frameworks}.",
            )

        try:
            check_module(framework)
        except ClientErrorException as missing_module:
            # Re-raise with a user-oriented message, keeping the original
            # failure attached as the explicit cause.
            raise ClientErrorException(
                f"Detected {framework} model. "
                f"Could not find installed {framework} module. "
                f"Please install {framework}.",
            ) from missing_module

        framework_version = get_module_version(framework)

        response_data = find_boundary_nodes(model_path)
        response_data["id"] = request_id
        response_data["framework"] = framework
        response_data["framework_version"] = framework_version
    except ClientErrorException as err:
        _post_failure(str(err))
        return

    log.debug(f"Parsed data is {json.dumps(response_data)}")
    mq.post_success("boundary_nodes_finish", response_data)
def test_get_tensorflow_framework_from_path(self) -> None:
    """Check that a ``.pb`` model path is reported as ``tensorflow``."""
    frozen_graph = "/home/user/model.pb"

    detected_framework = get_framework_from_path(frozen_graph)

    self.assertEqual(detected_framework, "tensorflow")