Пример #1
0
def get_predefined_configuration(
    data: Dict[str, Any], ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """Build the predefined configuration for the model described in ``data``.

    The model file must exist and a domain must be given. The framework is
    detected from the model path, and the predefined config for the
    (framework, domain) pair is loaded into a fresh ``Config``.

    Args:
        data: Request payload. The ``model_path`` and ``domain`` keys are read.

    Returns:
        A dict with the serialized ``config``, the detected ``framework``, the
        model ``name`` (file stem of ``model_path``) and the ``domain``.

    Raises:
        ClientErrorException: if the model file does not exist, the domain is
            missing or empty, or no framework is detected for the model path.
    """
    from lpot.ux.utils.utils import get_framework_from_path, get_predefined_config_path
    from lpot.ux.utils.workload.config import Config

    path_to_model = data.get("model_path", "")
    if not os.path.isfile(path_to_model):
        raise ClientErrorException(
            f"Could not find model in specified path: {path_to_model}.", )

    stem = Path(path_to_model).stem

    model_domain = data.get("domain", None)
    if not model_domain:
        raise ClientErrorException("Domain is not defined!")

    detected_framework = get_framework_from_path(path_to_model)
    if detected_framework is None:
        raise ClientErrorException(
            f"Could not find framework for specified model {stem} in path {path_to_model}.",
        )

    # Config() is constructed before the predefined path is resolved, as before.
    predefined = Config()
    predefined.load(get_predefined_config_path(detected_framework, model_domain))

    serialized = predefined.serialize()
    return {
        "config": serialized,
        "framework": detected_framework,
        "name": stem,
        "domain": model_domain,
    }
Пример #2
0
 def test_get_tensorflow_framework_from_path(
         self, mocked_get_model_type: MagicMock) -> None:
     """Check that a path whose model type is ``frozen_pb`` maps to TensorFlow.

     ``mocked_get_model_type`` is injected by a patch decorator outside this
     view (presumably patching ``get_model_type``). It is primed to report
     ``frozen_pb``, and the test verifies both the detected framework and the
     argument the mock received.
     """
     model_file = "/home/user/model.pb"
     mocked_get_model_type.return_value = "frozen_pb"

     detected = get_framework_from_path(model_file)

     self.assertEqual(detected, "tensorflow")
     mocked_get_model_type.assert_called_with(model_file)
Пример #3
0
    def find(self, model_path: str) -> Reader:
        """Return a Graph Model Reader for the model at ``model_path``.

        The framework is detected from the path. The entry registered for it in
        ``self._framework_readers`` is then called with no arguments, and its
        result is returned.

        Args:
            model_path: Path to the model whose reader is requested.

        Returns:
            The object produced by calling the registered reader entry.

        Raises:
            ClientErrorException: if no framework is detected for the path, or
                no reader is registered for the detected framework.
        """
        detected_framework = get_framework_from_path(model_path)
        if detected_framework is None:
            raise ClientErrorException(
                f"Models of {model_path} type are not yet supported.", )

        reader_factory = self._framework_readers.get(detected_framework)
        if reader_factory is None:
            raise ClientErrorException(
                f"Models from {detected_framework} framework are not yet supported.", )

        return reader_factory()
Пример #4
0
    def __init__(self, data: Dict[str, Any]):
        """Initialize Workload class.

        Builds the workload description from the request payload ``data`` and
        validates the dataset and model locations. It then loads a config: the
        workload's own ``config.yaml`` if that file already exists, otherwise
        the predefined config. Finally it applies the model name, dataset
        paths, workspace and accuracy goal to that config.

        Args:
            data: Request payload. Keys read here: ``id``, ``model_path``,
                ``domain``, ``framework``, ``config_path``, ``workspace_path``,
                ``workload_path``, ``eval_dataset_path`` and ``accuracy_goal``.
                ``set_dataset_paths`` may read further keys.

        Raises:
            ClientErrorException: if the id, model path or domain is missing,
                or if a dataset location or the model file does not exist.
        """
        super().__init__()
        self.config: Config = Config()

        self.id: str = str(data.get("id", ""))
        if not self.id:
            raise ClientErrorException("Workload ID not specified.")

        self.model_path: str = data.get("model_path", "")
        if not self.model_path:
            raise ClientErrorException("Model path is not defined!")

        self.model_name = Path(self.model_path).stem

        self.domain: str = data.get("domain", None)

        if not self.domain:
            raise ClientErrorException("Domain is not defined!")

        # Resolve these defaults lazily. ``data.get(key, helper(...))`` calls the
        # helper even when the caller supplied the key, so an explicit
        # "framework"/"config_path" could not bypass framework detection or the
        # predefined-config lookup (and any error those helpers raise).
        # Membership tests keep dict.get semantics: an explicit None is kept.
        if "framework" in data:
            self.framework: str = data["framework"]
        else:
            self.framework = get_framework_from_path(self.model_path)
        if "config_path" in data:
            self.predefined_config_path = data["config_path"]
        else:
            self.predefined_config_path = get_predefined_config_path(
                self.framework,
                self.domain,
            )
        self.workspace_path = data.get(
            "workspace_path",
            os.path.dirname(self.model_path),
        )
        self.workload_path = data.get(
            "workload_path",
            os.path.join(
                self.workspace_path,
                "workloads",
                f"{self.model_name}_{self.id}",
            ),
        )

        self.set_workspace()
        self.config_name = "config.yaml"
        self.config_path = os.path.join(
            self.workload_path,
            self.config_name,
        )

        model_output_name = (self.model_name + "_int8." +
                             get_file_extension(self.model_path))
        self.model_output_path = os.path.join(
            self.workload_path,
            model_output_name,
        )

        self.eval_dataset_path: str = data.get("eval_dataset_path", "")
        # NOTE(review): the calibration path is seeded from "eval_dataset_path",
        # not from a calibration-specific key. Presumably set_dataset_paths()
        # overrides it when needed; confirm this is intended.
        self.calib_dataset_path: str = data.get("eval_dataset_path", "")
        self.set_dataset_paths(data)

        # "no_dataset_location" is a sentinel that skips the existence check.
        for dataset_path in [self.eval_dataset_path, self.calib_dataset_path]:
            if dataset_path != "no_dataset_location" and not os.path.exists(
                    dataset_path, ):
                raise ClientErrorException(
                    f'Could not found dataset in specified location: "{dataset_path}".',
                )

        if not os.path.isfile(self.model_path):
            raise ClientErrorException(
                f'Could not found model in specified location: "{self.model_path}".',
            )

        self.accuracy_goal: float = data.get("accuracy_goal", 0.01)

        # Reuse the workload's own config.yaml when present; otherwise start
        # from the predefined config.
        if not os.path.isfile(self.config_path):
            self.config.load(self.predefined_config_path)
        else:
            self.config.load(self.config_path)

        self.config.model.name = self.model_name
        self.config.set_evaluation_dataset_path(self.eval_dataset_path)
        self.config.set_quantization_dataset_path(self.calib_dataset_path)
        self.config.set_workspace(self.workload_path)
        self.config.set_accuracy_goal(self.accuracy_goal)
Пример #5
0
    def __init__(self, data: Dict[str, Any]):
        """Initialize Workload class.

        Builds the workload description from the request payload ``data`` and
        validates the dataset locations and the model path (via
        ``ModelRepository.is_model_path``). It then resolves the config paths,
        precisions, optimization mode and tuning flag, and delegates config
        setup to ``initialize_config``.

        Args:
            data: Request payload. Keys read here include ``id``, ``model_path``,
                ``domain``, ``framework``, ``workspace_path``, ``workload_path``,
                ``eval_dataset_path``, ``accuracy_goal``, ``config_path``,
                ``precision``/``output_precision``, ``tune``,
                ``inputs``/``input_nodes`` and ``outputs``/``output_nodes``.

        Raises:
            ClientErrorException: if the id, model path or domain is missing,
                if a dataset location does not exist, or if the model path is
                not recognized as a model.
        """
        super().__init__()
        self.config: Config = Config()

        self.id: str = str(data.get("id", ""))
        if not self.id:
            raise ClientErrorException("Workload ID not specified.")

        self.model_path: str = data.get("model_path", "")
        if not self.model_path:
            raise ClientErrorException("Model path is not defined!")

        self.model_name = Path(self.model_path).stem

        self.domain: str = data.get("domain", None)

        if not self.domain:
            raise ClientErrorException("Domain is not defined!")

        # Detect the framework only when the caller did not supply one.
        # ``data.get(key, helper(...))`` would call the helper unconditionally.
        # The membership test keeps dict.get semantics: an explicit None is kept.
        if "framework" in data:
            self.framework: str = data["framework"]
        else:
            self.framework = get_framework_from_path(self.model_path)

        self.workspace_path = data.get(
            "workspace_path",
            os.path.dirname(self.model_path),
        )
        self.workload_path = data.get(
            "workload_path",
            os.path.join(
                self.workspace_path,
                "workloads",
                f"{self.model_name}_{self.id}",
            ),
        )

        self.set_workspace()

        self.eval_dataset_path: str = data.get("eval_dataset_path", "")
        # NOTE(review): the calibration path is seeded from "eval_dataset_path",
        # not from a calibration-specific key. Presumably set_dataset_paths()
        # overrides it when needed; confirm this is intended.
        self.calib_dataset_path: str = data.get("eval_dataset_path", "")
        self.set_dataset_paths(data)

        # "no_dataset_location" is a sentinel that skips the existence check.
        for dataset_path in [self.eval_dataset_path, self.calib_dataset_path]:
            if dataset_path != "no_dataset_location" and not os.path.exists(
                    dataset_path, ):
                raise ClientErrorException(
                    f'Could not found dataset in specified location: "{dataset_path}".',
                )

        if not ModelRepository.is_model_path(self.model_path):
            raise ClientErrorException(
                f'Could not found model in specified location: "{self.model_path}".',
            )

        self.accuracy_goal: float = data.get("accuracy_goal", 0.01)

        self.config_name = "config.yaml"
        # An explicit "config_path" wins. The predefined-config lookup runs only
        # when it is absent, so a caller-supplied path is never overridden by
        # (or blocked on) that lookup.
        if "config_path" in data:
            self.predefined_config_path = data["config_path"]
        else:
            self.predefined_config_path = get_predefined_config_path(
                self.framework,
                self.domain,
            )
        self.config_path = os.path.join(
            self.workload_path,
            self.config_name,
        )

        self.input_precision = Precisions.FP32  # TODO: Detect input model precision
        self.output_precision = data.get("precision",
                                         data.get("output_precision"))

        self.mode = self.get_optimization_mode()
        self.tune = data.get("tune", self.is_tuning_enabled(data))

        self.initialize_config(data)

        self.input_nodes: Optional[str] = data.get("inputs",
                                                   data.get("input_nodes"))
        self.output_nodes: Optional[str] = data.get("outputs",
                                                    data.get("output_nodes"))

        # NOTE(review): model_output_name is not assigned in this method.
        # Presumably initialize_config() (or a property) provides it; confirm.
        self.model_output_path = os.path.join(
            self.workload_path,
            self.model_output_name,
        )
        self.version = "2.0"
Пример #6
0
 def test_get_unknown_framework_from_path(self) -> None:
     """Check that a path with an unrecognized extension maps to no framework."""
     self.assertIsNone(
         get_framework_from_path("/home/user/model.some_extension"),
     )
Пример #7
0
 def test_get_onnx_framework_from_path(self) -> None:
     """Check that an ``.onnx`` model path is detected as the ``onnxrt`` framework."""
     detected = get_framework_from_path("/home/user/model.onnx")
     expected_framework = "onnxrt"
     self.assertEqual(detected, expected_framework)
Пример #8
0
def get_boundary_nodes(data: Dict[str, Any]) -> None:
    """Find the boundary nodes of a model and report them through ``mq``.

    First the request id and model path are validated. Then
    ``boundary_nodes_start`` is posted, the framework is detected and its
    module is checked, and ``find_boundary_nodes`` is run. The outcome is
    posted as ``boundary_nodes_finish``:

    - on any ``ClientErrorException`` (including validation failures), an
      error payload with ``message``, ``code`` 404 and ``id``;
    - on success, the ``find_boundary_nodes`` result extended with ``id``,
      ``framework`` and ``framework_version``.

    Exceptions other than ``ClientErrorException`` are not caught here.

    Args:
        data: Request payload. The ``id`` and ``model_path`` keys are read.
    """
    from lpot.ux.utils.utils import find_boundary_nodes

    request_id = str(data.get("id", ""))
    model_path = data.get("model_path", None)

    def _post_finish_error(message: str) -> None:
        # Every failure is reported on the same channel with the same payload shape.
        mq.post_error(
            "boundary_nodes_finish",
            {
                "message": message,
                "code": 404,
                "id": request_id
            },
        )

    if not (request_id and model_path):
        _post_finish_error("Missing model path or request id.")
        return

    if not os.path.isfile(model_path):
        _post_finish_error("Could not found model in specified path.")
        return

    try:
        mq.post_success(
            "boundary_nodes_start",
            {
                "message": "started",
                "id": request_id
            },
        )
        framework = get_framework_from_path(model_path)
        if framework is None:
            supported_frameworks = list(framework_extensions.keys())
            raise ClientErrorException(
                f"Framework for specified model is not yet supported. "
                f"Supported frameworks are: {', '.join(supported_frameworks)}.",
            )
        try:
            check_module(framework)
        except ClientErrorException as err:
            # Chain the original error so the check_module cause is not lost
            # when debugging; the posted message is the user-facing one below.
            raise ClientErrorException(
                f"Detected {framework} model. "
                f"Could not find installed {framework} module. "
                f"Please install {framework}.", ) from err
        framework_version = get_module_version(framework)

        response_data = find_boundary_nodes(model_path)
        response_data["id"] = request_id
        response_data["framework"] = framework
        response_data["framework_version"] = framework_version
    except ClientErrorException as err:
        _post_finish_error(str(err))
    else:
        log.debug(f"Parsed data is {json.dumps(response_data)}")
        mq.post_success("boundary_nodes_finish", response_data)
Пример #9
0
 def test_get_tensorflow_framework_from_path(self) -> None:
     """Check that a ``.pb`` model path is detected as the ``tensorflow`` framework."""
     detected = get_framework_from_path("/home/user/model.pb")
     expected_framework = "tensorflow"
     self.assertEqual(detected, expected_framework)