Ejemplo n.º 1
0
    def model_output_name(self) -> str:
        """Get output model name.

        The name is built from ``self.model_name`` according to ``self.mode``:

        * ``Optimizations.TUNING``: ``<model_name>_tuned_<output_precision>``
        * ``Optimizations.GRAPH``: ``<model_name>_optimized_[tuned_]<p1>_<p2>...``
          where ``tuned_`` is present only when ``self.tune`` is truthy and
          the precisions come from the comma-separated
          ``self.output_precision`` with surrounding whitespace stripped.

        The extension reported by ``get_file_extension(self.model_path)`` is
        appended after a ``"."``. When the model path has no extension (the
        helper returns ``""``), no trailing dot is added.

        Returns:
            The output model file name.

        Raises:
            ClientErrorException: If ``self.mode`` is not a supported mode.
        """
        output_name = self.model_name
        if self.mode == Optimizations.TUNING:
            output_name += "_tuned_" + self.output_precision
        elif self.mode == Optimizations.GRAPH:
            output_name = self.model_name + "_optimized_"
            if self.tune:
                output_name += "tuned_"
            # "int8, bf16" -> "int8_bf16"
            output_name += "_".join(
                precision.strip()
                for precision in self.output_precision.split(",")
            )
        else:
            raise ClientErrorException(f"Mode {self.mode} is not supported.")

        extension = get_file_extension(self.model_path)
        # Avoid producing a name ending in a dangling "." for models whose
        # path has no extension.
        if not extension:
            return output_name
        return output_name + "." + extension
Ejemplo n.º 2
0
    def __init__(self, data: Dict[str, Any]):
        """Initialize Workload class.

        Args:
            data: Workload description. ``id``, ``model_path`` and ``domain``
                are required. ``framework``, ``config_path``,
                ``workspace_path``, ``workload_path``, ``eval_dataset_path``
                and ``accuracy_goal`` are optional and fall back to derived
                defaults. The whole mapping is also passed to
                ``set_dataset_paths``.

        Raises:
            ClientErrorException: If the workload id, model path or domain is
                missing, if a dataset path does not exist, or if the model
                file does not exist.
        """
        super().__init__()
        self.config: Config = Config()

        self.id: str = str(data.get("id", ""))
        if not self.id:
            raise ClientErrorException("Workload ID not specified.")

        self.model_path: str = data.get("model_path", "")
        if not self.model_path:
            raise ClientErrorException("Model path is not defined!")

        self.model_name = Path(self.model_path).stem

        self.domain: str = data.get("domain", None)

        if not self.domain:
            raise ClientErrorException("Domain is not defined!")

        # Derive defaults only when the key is absent. ``dict.get(key, default)``
        # evaluates ``default`` unconditionally, so the lookups below used to
        # run (and could raise) even when the caller supplied explicit values.
        self.framework: str = (
            data["framework"]
            if "framework" in data
            else get_framework_from_path(self.model_path)
        )
        self.predefined_config_path = (
            data["config_path"]
            if "config_path" in data
            else get_predefined_config_path(self.framework, self.domain)
        )
        self.workspace_path = data.get(
            "workspace_path",
            os.path.dirname(self.model_path),
        )
        self.workload_path = data.get(
            "workload_path",
            os.path.join(
                self.workspace_path,
                "workloads",
                f"{self.model_name}_{self.id}",
            ),
        )

        # NOTE(review): set_workspace runs before the dataset/model existence
        # checks below; confirm its side effects are acceptable when those
        # checks fail.
        self.set_workspace()
        self.config_name = "config.yaml"
        self.config_path = os.path.join(
            self.workload_path,
            self.config_name,
        )

        # get_file_extension returns "" for extensionless paths; only add the
        # "." separator when there is an extension to append.
        extension = get_file_extension(self.model_path)
        model_output_name = self.model_name + "_int8"
        if extension:
            model_output_name += "." + extension
        self.model_output_path = os.path.join(
            self.workload_path,
            model_output_name,
        )

        self.eval_dataset_path: str = data.get("eval_dataset_path", "")
        # NOTE(review): the calibration path defaults to the *evaluation*
        # dataset path; set_dataset_paths(data) may override it. Confirm this
        # is intentional and not a copy-paste of the key above.
        self.calib_dataset_path: str = data.get("eval_dataset_path", "")
        self.set_dataset_paths(data)

        # "no_dataset_location" is a sentinel that skips the existence check.
        for dataset_path in [self.eval_dataset_path, self.calib_dataset_path]:
            if dataset_path != "no_dataset_location" and not os.path.exists(
                    dataset_path, ):
                raise ClientErrorException(
                    f'Could not found dataset in specified location: "{dataset_path}".',
                )

        if not os.path.isfile(self.model_path):
            raise ClientErrorException(
                f'Could not found model in specified location: "{self.model_path}".',
            )

        self.accuracy_goal: float = data.get("accuracy_goal", 0.01)

        # Prefer a config already stored in the workload directory; otherwise
        # start from the predefined (or caller-supplied) config.
        if not os.path.isfile(self.config_path):
            self.config.load(self.predefined_config_path)
        else:
            self.config.load(self.config_path)

        self.config.model.name = self.model_name
        self.config.set_evaluation_dataset_path(self.eval_dataset_path)
        self.config.set_quantization_dataset_path(self.calib_dataset_path)
        self.config.set_workspace(self.workload_path)
        self.config.set_accuracy_goal(self.accuracy_goal)
Ejemplo n.º 3
0
 def test_get_file_without_extension(self) -> None:
     """Check that a path lacking an extension yields an empty extension."""
     self.assertEqual(get_file_extension("/home/user/file"), "")
Ejemplo n.º 4
0
 def test_get_file_with_dots_extension(self) -> None:
     """Check that only the part after the last dot is the extension."""
     expected = "ext2"
     self.assertEqual(get_file_extension("/home/user/file.name.ext2"), expected)
Ejemplo n.º 5
0
Archivo: model.py Proyecto: intel/lpot
 def supports_path(path: str) -> bool:
     """Return whether ``get_file_extension(path)`` is exactly ``"onnx"``."""
     extension = get_file_extension(path)
     return extension == "onnx"