Ejemplo n.º 1
0
 def test_get_predefined_config_path_domain_failure(self) -> None:
     """Test that an unsupported domain raises when getting a predefined config.

     Requests the predefined config for the "object_detection" domain of the
     "onnxrt" framework and expects ``get_predefined_config_path`` to raise.
     """
     with self.assertRaises(Exception):
         get_predefined_config_path(
             framework="onnxrt",
             domain="object_detection",
         )
Ejemplo n.º 2
0
 def test_get_predefined_config_path_framework_failure(self) -> None:
     """Test that an unsupported framework raises when getting a predefined config.

     Requests the predefined config for the "image_recognition" domain using
     the framework name "onnx" and expects ``get_predefined_config_path`` to
     raise.
     """
     with self.assertRaises(Exception):
         get_predefined_config_path(
             framework="onnx",
             domain="image_recognition",
         )
Ejemplo n.º 3
0
def get_predefined_configuration(
    data: Dict[str, Any], ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """Return the predefined configuration for a model's framework and domain.

    Args:
        data: Request payload. Reads ``"model_path"`` (path to an existing
            model file) and ``"domain"`` (non-empty domain name).

    Returns:
        A dict with the serialized predefined ``"config"``, the detected
        ``"framework"``, the model ``"name"`` (file stem of the model path)
        and the requested ``"domain"``.

    Raises:
        ClientErrorException: if the model file does not exist, the domain is
            missing or empty, or no framework can be detected from the path.
    """
    # Function-local imports; presumably to avoid an import cycle - TODO confirm.
    from lpot.ux.utils.utils import get_framework_from_path, get_predefined_config_path
    from lpot.ux.utils.workload.config import Config

    path_to_model = data.get("model_path", "")
    if not os.path.isfile(path_to_model):
        raise ClientErrorException(
            f"Could not find model in specified path: {path_to_model}.",
        )

    stem = Path(path_to_model).stem

    model_domain = data.get("domain", None)
    if not model_domain:
        raise ClientErrorException("Domain is not defined!")

    detected_framework = get_framework_from_path(path_to_model)
    if detected_framework is None:
        raise ClientErrorException(
            f"Could not find framework for specified model {stem} in path {path_to_model}.",
        )

    configuration = Config()
    configuration.load(
        get_predefined_config_path(detected_framework, model_domain),
    )

    return dict(
        config=configuration.serialize(),
        framework=detected_framework,
        name=stem,
        domain=model_domain,
    )
Ejemplo n.º 4
0
 def _assert_predefined_config_path(self, framework: str,
                                    domain: str) -> None:
     """Assert the predefined config path for a framework/domain pair.

     Checks that ``get_predefined_config_path`` returns
     ``<dir of its defining module>/configs/predefined_configs/<framework>/<domain>.yaml``
     and that this file exists on disk.

     Args:
         framework: Framework name used as the sub-directory.
         domain: Domain name used as the YAML file stem.
     """
     result = get_predefined_config_path(framework, domain)
     module_dir = os.path.abspath(
         os.path.dirname(inspect.getfile(get_predefined_config_path)),
     )
     expected = os.path.join(
         module_dir,
         "configs",
         "predefined_configs",
         framework,
         f"{domain}.yaml",
     )
     self.assertEqual(result, expected)
     self.assertTrue(
         os.path.isfile(result),
         f"Predefined config file does not exist: {result}",
     )
Ejemplo n.º 5
0
    def __init__(self, data: Dict[str, Any]):
        """Initialize Workload class from a request payload.

        Args:
            data: Workload description. Required keys: ``"id"``,
                ``"model_path"`` and ``"domain"``. Optional keys:
                ``"framework"``, ``"config_path"``, ``"workspace_path"``,
                ``"workload_path"``, ``"eval_dataset_path"`` and
                ``"accuracy_goal"`` (default 0.01). Dataset paths are then
                passed through ``set_dataset_paths(data)``.

        Raises:
            ClientErrorException: if the workload id, model path or domain is
                missing, the framework cannot be determined, a dataset location
                does not exist, or the model file does not exist.
        """
        super().__init__()
        self.config: Config = Config()

        self.id: str = str(data.get("id", ""))
        if not self.id:
            raise ClientErrorException("Workload ID not specified.")

        self.model_path: str = data.get("model_path", "")
        if not self.model_path:
            raise ClientErrorException("Model path is not defined!")

        self.model_name = Path(self.model_path).stem

        self.domain: str = data.get("domain", None)

        if not self.domain:
            raise ClientErrorException("Domain is not defined!")

        # Resolve defaults lazily: dict.get(key, default) evaluates the default
        # even when the key is present, so detection/lookup would always run.
        self.framework: str = (
            data["framework"]
            if "framework" in data
            else get_framework_from_path(self.model_path)
        )
        if self.framework is None:
            raise ClientErrorException(
                f"Could not find framework for specified model {self.model_name} in path {self.model_path}.",
            )
        # get_predefined_config_path raises for unsupported framework/domain
        # pairs, so only call it when no explicit "config_path" is supplied.
        self.predefined_config_path = (
            data["config_path"]
            if "config_path" in data
            else get_predefined_config_path(self.framework, self.domain)
        )
        self.workspace_path = data.get(
            "workspace_path",
            os.path.dirname(self.model_path),
        )
        self.workload_path = data.get(
            "workload_path",
            os.path.join(
                self.workspace_path,
                "workloads",
                f"{self.model_name}_{self.id}",
            ),
        )

        self.set_workspace()
        self.config_name = "config.yaml"
        self.config_path = os.path.join(
            self.workload_path,
            self.config_name,
        )

        # Output model keeps the input's extension with an "_int8" suffix.
        model_output_name = (self.model_name + "_int8." +
                             get_file_extension(self.model_path))
        self.model_output_path = os.path.join(
            self.workload_path,
            model_output_name,
        )

        self.eval_dataset_path: str = data.get("eval_dataset_path", "")
        # NOTE(review): the calibration path is seeded from "eval_dataset_path",
        # not a dedicated key; presumably an intentional fallback that
        # set_dataset_paths may override - confirm.
        self.calib_dataset_path: str = data.get("eval_dataset_path", "")
        self.set_dataset_paths(data)

        # "no_dataset_location" is a sentinel that skips the existence check.
        for dataset_path in [self.eval_dataset_path, self.calib_dataset_path]:
            if dataset_path != "no_dataset_location" and not os.path.exists(
                    dataset_path, ):
                raise ClientErrorException(
                    f'Could not found dataset in specified location: "{dataset_path}".',
                )

        if not os.path.isfile(self.model_path):
            raise ClientErrorException(
                f'Could not found model in specified location: "{self.model_path}".',
            )

        self.accuracy_goal: float = data.get("accuracy_goal", 0.01)

        # Prefer an existing workload config; otherwise start from the
        # predefined template.
        if not os.path.isfile(self.config_path):
            self.config.load(self.predefined_config_path)
        else:
            self.config.load(self.config_path)

        self.config.model.name = self.model_name
        self.config.set_evaluation_dataset_path(self.eval_dataset_path)
        self.config.set_quantization_dataset_path(self.calib_dataset_path)
        self.config.set_workspace(self.workload_path)
        self.config.set_accuracy_goal(self.accuracy_goal)
Ejemplo n.º 6
0
    def __init__(self, data: Dict[str, Any]):
        """Initialize Workload class from a request payload.

        Args:
            data: Workload description. Required keys: ``"id"``,
                ``"model_path"`` and ``"domain"``. Optional keys include
                ``"framework"``, ``"config_path"``, ``"workspace_path"``,
                ``"workload_path"``, ``"eval_dataset_path"``,
                ``"accuracy_goal"`` (default 0.01), ``"precision"`` /
                ``"output_precision"``, ``"tune"``, ``"inputs"`` /
                ``"input_nodes"`` and ``"outputs"`` / ``"output_nodes"``.

        Raises:
            ClientErrorException: if the workload id, model path or domain is
                missing, the framework cannot be determined, a dataset location
                does not exist, or the model path is not a valid model.
        """
        super().__init__()
        self.config: Config = Config()

        self.id: str = str(data.get("id", ""))
        if not self.id:
            raise ClientErrorException("Workload ID not specified.")

        self.model_path: str = data.get("model_path", "")
        if not self.model_path:
            raise ClientErrorException("Model path is not defined!")

        self.model_name = Path(self.model_path).stem

        self.domain: str = data.get("domain", None)

        if not self.domain:
            raise ClientErrorException("Domain is not defined!")

        # Detect the framework only when the caller did not supply one;
        # dict.get(key, default) would evaluate the detection unconditionally.
        self.framework: str = (
            data["framework"]
            if "framework" in data
            else get_framework_from_path(self.model_path)
        )
        # Fail before set_workspace() touches the filesystem.
        if self.framework is None:
            raise ClientErrorException(
                f"Could not find framework for specified model {self.model_name} in path {self.model_path}.",
            )

        self.workspace_path = data.get(
            "workspace_path",
            os.path.dirname(self.model_path),
        )
        self.workload_path = data.get(
            "workload_path",
            os.path.join(
                self.workspace_path,
                "workloads",
                f"{self.model_name}_{self.id}",
            ),
        )

        self.set_workspace()

        self.eval_dataset_path: str = data.get("eval_dataset_path", "")
        # NOTE(review): the calibration path is seeded from "eval_dataset_path",
        # not a dedicated key; presumably an intentional fallback that
        # set_dataset_paths may override - confirm.
        self.calib_dataset_path: str = data.get("eval_dataset_path", "")
        self.set_dataset_paths(data)

        # "no_dataset_location" is a sentinel that skips the existence check.
        for dataset_path in [self.eval_dataset_path, self.calib_dataset_path]:
            if dataset_path != "no_dataset_location" and not os.path.exists(
                    dataset_path, ):
                raise ClientErrorException(
                    f'Could not found dataset in specified location: "{dataset_path}".',
                )

        if not ModelRepository.is_model_path(self.model_path):
            raise ClientErrorException(
                f'Could not found model in specified location: "{self.model_path}".',
            )

        self.accuracy_goal: float = data.get("accuracy_goal", 0.01)

        self.config_name = "config.yaml"
        # get_predefined_config_path raises for unsupported framework/domain
        # pairs, so only call it when no explicit "config_path" is supplied.
        self.predefined_config_path = (
            data["config_path"]
            if "config_path" in data
            else get_predefined_config_path(self.framework, self.domain)
        )
        self.config_path = os.path.join(
            self.workload_path,
            self.config_name,
        )

        self.input_precision = Precisions.FP32  # TODO: Detect input model precision
        self.output_precision = data.get("precision",
                                         data.get("output_precision"))

        self.mode = self.get_optimization_mode()
        self.tune = data.get("tune", self.is_tuning_enabled(data))

        self.initialize_config(data)

        self.input_nodes: Optional[str] = data.get("inputs",
                                                   data.get("input_nodes"))
        self.output_nodes: Optional[str] = data.get("outputs",
                                                    data.get("output_nodes"))

        # model_output_name is not assigned in this method; presumably provided
        # elsewhere in the class (e.g. a property or initialize_config) - confirm.
        self.model_output_path = os.path.join(
            self.workload_path,
            self.model_output_name,
        )
        self.version = "2.0"