def model_output_name(self) -> str:
    """Get output model name.

    The name starts with ``self.model_name`` and gets a mode-dependent suffix:

    * ``Optimizations.TUNING``: ``<model>_tuned_<output_precision>``
    * ``Optimizations.GRAPH``: ``<model>_optimized_[tuned_]<p1>_<p2>...``,
      where ``p1, p2, ...`` are the comma-separated entries of
      ``self.output_precision`` with surrounding whitespace stripped, and
      ``tuned_`` is present only when ``self.tune`` is truthy.

    The extension of ``self.model_path`` is then appended after a dot. If the
    model path has no extension (``get_file_extension`` returns ``""``), no
    dot is appended, so the result never ends with a stray ``"."``.

    Returns:
        The output model file name.

    Raises:
        ClientErrorException: If ``self.mode`` is neither TUNING nor GRAPH.
    """
    output_name = self.model_name
    if self.mode == Optimizations.TUNING:
        output_name += "_tuned_" + self.output_precision
    elif self.mode == Optimizations.GRAPH:
        output_name += "_optimized_"
        if self.tune:
            output_name += "tuned_"
        # "int8, fp32" -> "int8_fp32": keep commas/spaces out of the file name.
        output_name += "_".join(
            precision.strip() for precision in self.output_precision.split(",")
        )
    else:
        raise ClientErrorException(f"Mode {self.mode} is not supported.")

    extension = get_file_extension(self.model_path)
    if extension:
        output_name += "." + extension
    return output_name
def __init__(self, data: Dict[str, Any]):
    """Initialize Workload class.

    Args:
        data: Workload description.
            Required keys: ``id``, ``model_path``, ``domain``.
            Optional keys read here: ``framework``, ``config_path``,
            ``workspace_path``, ``workload_path``, ``eval_dataset_path``,
            ``accuracy_goal``.
            ``set_dataset_paths(data)`` is also called and may read further
            keys.

    Raises:
        ClientErrorException: If the id, model path or domain is missing, or
            if the model file or a dataset location does not exist.
    """
    super().__init__()
    self.config: Config = Config()

    # An explicit ``None`` id is treated as missing. Previously str(None)
    # produced the truthy string "None" and bypassed the check below.
    raw_id = data.get("id")
    self.id: str = "" if raw_id is None else str(raw_id)
    if not self.id:
        raise ClientErrorException("Workload ID not specified.")

    self.model_path: str = data.get("model_path", "")
    if not self.model_path:
        raise ClientErrorException("Model path is not defined!")

    self.model_name = Path(self.model_path).stem

    self.domain: str = data.get("domain", None)
    if not self.domain:
        raise ClientErrorException("Domain is not defined!")

    # Each fallback below is computed only when its key is absent. The former
    # ``data.get(key, helper(...))`` form always ran the helper, so a failing
    # or costly helper affected callers that had already supplied the value.
    self.framework: str = (
        data["framework"]
        if "framework" in data
        else get_framework_from_path(self.model_path)
    )
    self.predefined_config_path = (
        data["config_path"]
        if "config_path" in data
        else get_predefined_config_path(self.framework, self.domain)
    )
    self.workspace_path = (
        data["workspace_path"]
        if "workspace_path" in data
        else os.path.dirname(self.model_path)
    )
    self.workload_path = (
        data["workload_path"]
        if "workload_path" in data
        else os.path.join(
            self.workspace_path,
            "workloads",
            f"{self.model_name}_{self.id}",
        )
    )

    self.set_workspace()
    self.config_name = "config.yaml"
    self.config_path = os.path.join(
        self.workload_path,
        self.config_name,
    )

    model_output_name = (self.model_name + "_int8."
                         + get_file_extension(self.model_path))
    self.model_output_path = os.path.join(
        self.workload_path,
        model_output_name,
    )

    self.eval_dataset_path: str = data.get("eval_dataset_path", "")
    # NOTE(review): the calibration path is also seeded from
    # "eval_dataset_path", so calibration uses the evaluation dataset unless
    # set_dataset_paths() overrides it. Confirm this is intentional.
    self.calib_dataset_path: str = data.get("eval_dataset_path", "")
    self.set_dataset_paths(data)

    # "no_dataset_location" is a sentinel that skips the existence check.
    for dataset_path in [self.eval_dataset_path, self.calib_dataset_path]:
        if dataset_path != "no_dataset_location" and not os.path.exists(
            dataset_path,
        ):
            raise ClientErrorException(
                f'Could not found dataset in specified location: "{dataset_path}".',
            )

    if not os.path.isfile(self.model_path):
        raise ClientErrorException(
            f'Could not found model in specified location: "{self.model_path}".',
        )

    self.accuracy_goal: float = data.get("accuracy_goal", 0.01)

    # Reuse an existing workload config; otherwise start from the predefined one.
    if not os.path.isfile(self.config_path):
        self.config.load(self.predefined_config_path)
    else:
        self.config.load(self.config_path)

    self.config.model.name = self.model_name
    self.config.set_evaluation_dataset_path(self.eval_dataset_path)
    self.config.set_quantization_dataset_path(self.calib_dataset_path)
    self.config.set_workspace(self.workload_path)
    self.config.set_accuracy_goal(self.accuracy_goal)
def test_get_file_without_extension(self) -> None:
    """Test that a path with no extension yields an empty extension string."""
    self.assertEqual(get_file_extension("/home/user/file"), "")
def test_get_file_with_dots_extension(self) -> None:
    """Test that a file name containing several dots yields the part after the last dot."""
    expected_extension = "ext2"
    self.assertEqual(
        get_file_extension("/home/user/file.name.ext2"),
        expected_extension,
    )
def supports_path(path: str) -> bool:
    """Return whether *path* names a supported (ONNX) model.

    The check is name-based only: the extension reported by
    ``get_file_extension`` must equal the literal string ``"onnx"``.
    """
    extension = get_file_extension(path)
    return extension == "onnx"