Exemplo n.º 1
0
    def _save_project_model(model: ProjectModel, path: str):
        """
        Copy the file at ``path`` into the model's storage location and save
        the model record, undoing the partial work if any step fails.

        The source file is checked with ``validate_onnx_file`` before anything
        is touched, and the copied file is checked again afterwards.

        :param model: the project model whose file should be set
        :param path: path of the source file to copy in
        :raises Exception: re-raises whatever error interrupted the save,
            after a best-effort rollback
        """
        # check the source before any state is modified
        validate_onnx_file(path)

        try:
            model.setup_filesystem()
            model.file = "model.onnx"
            shutil.copy(path, model.file_path)
            # second validation confirms the copy produced a usable file
            validate_onnx_file(model.file_path)
            model.save()
        except Exception as save_err:
            failure_msg = (
                "error while creating new project model, rolling back: {}".format(
                    save_err
                )
            )
            _LOGGER.error(failure_msg)

            # rollback step 1: drop the copied file; a failed removal
            # (for example, the file was never created) is ignored
            try:
                os.remove(model.file_path)
            except OSError:
                pass

            # rollback step 2: clear the file reference on the record;
            # a failure here is logged but must not mask the original error
            try:
                model.file = None
                model.save()
            except Exception as undo_err:
                undo_msg = "error while rolling back new project model: {}".format(
                    undo_err
                )
                _LOGGER.error(undo_msg)

            raise save_err
Exemplo n.º 2
0
def get_model_details(project_id: str) -> Tuple[Response, int]:
    """
    Route returning the model details of the project identified by project_id.
    Raises an HTTPNotFoundError when no model exists for the project;
    the project itself is resolved through get_project_by_id.

    :param project_id: the project_id to get the model details for
    :return: a tuple containing (json response, http status code)
    """
    start_msg = "getting the model details for project_id {}".format(project_id)
    _LOGGER.info(start_msg)

    project = get_project_by_id(project_id)
    query = ProjectModel.select(ProjectModel).where(ProjectModel.project == project)

    # iterate every row; when several rows match, the last one is used
    matches = [found for found in query]
    project_model = matches[-1] if matches else None

    if project_model is None:
        raise HTTPNotFoundError(
            "could not find model for project_id {}".format(project_id)
        )

    resp_model = data_dump_and_validation(
        ResponseProjectModelSchema(), {"model": project_model}
    )
    _LOGGER.info("retrieved project model details {}".format(resp_model))

    status_code = HTTPStatus.OK.value
    return jsonify(resp_model), status_code
Exemplo n.º 3
0
def load_model_from_repo(project_id: str) -> Tuple[Response, int]:
    """
    Route for loading a model for a project from the Neural Magic model repo.
    Creates the ProjectModel and its Job, links them, prepares the model's
    filesystem and then asks the JobWorkerManager to refresh so the job can
    be picked up. The state of the job can be checked after.
    Raises an HTTPNotFoundError if the project is not found in the database.

    If any creation step fails, the partially created model and job records
    are deleted (best effort) and the original error is re-raised.

    :param project_id: the id of the project to load the model for
    :return: a tuple containing (json response, http status code)
    """
    # Parse the body for logging with force=True (same as the real load below)
    # and silent=True, so that a missing/incorrect Content-Type header or a
    # malformed body can never make the log statement itself raise; the
    # non-silent load below still reports invalid bodies to the caller.
    _LOGGER.info(
        "loading model from repo for project {} for request json {}".format(
            project_id, request.get_json(force=True, silent=True)))
    project = _add_model_check(project_id)
    data = SetProjectModelFromSchema().load(request.get_json(force=True))
    project_model = None
    job = None

    try:
        # the model is created first because the job's worker args need
        # its model_id; the job is attached to the model afterwards
        project_model = ProjectModel.create(project=project,
                                            source="downloaded_repo",
                                            job=None)
        job = Job.create(
            project_id=project.project_id,
            type_=ModelFromRepoJobWorker.get_type(),
            worker_args=ModelFromRepoJobWorker.format_args(
                model_id=project_model.model_id,
                uri=data["uri"],
            ),
        )
        project_model.job = job
        project_model.save()
        project_model.setup_filesystem()
        project_model.validate_filesystem()
    except Exception as err:
        _LOGGER.error(
            "error while creating new project model, rolling back: {}".format(
                err))
        # roll back whichever records were created; rollback failures are
        # logged only so the original error is what the caller sees
        if project_model:
            try:
                project_model.delete_instance()
            except Exception as rollback_err:
                _LOGGER.error("error while rolling back new model: {}".format(
                    rollback_err))
        if job:
            try:
                job.delete_instance()
            except Exception as rollback_err:
                _LOGGER.error("error while rolling back new model: {}".format(
                    rollback_err))
        raise err

    # call into JobWorkerManager to kick off job if it's not already running
    JobWorkerManager().refresh()

    resp_model = data_dump_and_validation(ResponseProjectModelSchema(),
                                          {"model": project_model})
    _LOGGER.info("created project model from repo {}".format(resp_model))

    return jsonify(resp_model), HTTPStatus.OK.value
Exemplo n.º 4
0
    def _get_project_model(self) -> ProjectModel:
        """
        Look up the ProjectModel whose project_id equals this object's
        ``_project_id``.

        :return: the matching ProjectModel
        :raises ValueError: if no ProjectModel matches the project id
        """
        project_id = self._project_id
        found = ProjectModel.get_or_none(ProjectModel.project_id == project_id)

        if found is not None:
            return found

        raise ValueError(
            "ProjectModel with project_id {} was not found".format(self._project_id)
        )
Exemplo n.º 5
0
    def _get_project_model(self) -> ProjectModel:
        """
        Look up the ProjectModel whose model_id equals this object's
        ``model_id``.

        :return: the matching ProjectModel
        :raises ValueError: if no ProjectModel matches the model id
        """
        found = ProjectModel.get_or_none(ProjectModel.model_id == self.model_id)

        if found is not None:
            return found

        raise ValueError("could not find model_id of {}".format(self.model_id))
Exemplo n.º 6
0
def upload_model(project_id: str) -> Tuple[Response, int]:
    """
    Route that stores an uploaded model file (form field 'model_file')
    as the model of a project.
    Raises a ValidationError when the request carries no 'model_file' upload;
    the project is resolved through _add_model_check.

    If creating the model record or copying the file fails, the new record
    is deleted (best effort) and the original error is re-raised.

    :param project_id: the id of the project to upload the model for
    :return: a tuple containing (json response, http status code)
    """
    _LOGGER.info("uploading model for project {}".format(project_id))
    project = _add_model_check(project_id)

    has_upload = "model_file" in request.files
    if not has_upload:
        _LOGGER.error("missing uploaded file 'model_file'")
        raise ValidationError("missing uploaded file 'model_file'")

    uploaded = request.files["model_file"]
    project_model = None

    with NamedTemporaryFile() as scratch:
        # stage the upload in a temporary file and validate it before
        # any project model record is created
        scratch_path = os.path.join(gettempdir(), scratch.name)
        uploaded.save(scratch_path)
        validate_onnx_file(scratch_path)

        try:
            # create the record, then move the staged file into place
            model_fields = CreateUpdateProjectModelSchema().dump(
                {"file": "model.onnx", "source": "uploaded"}
            )
            project_model = ProjectModel.create(project=project, **model_fields)
            project_model.setup_filesystem()
            shutil.copy(scratch_path, project_model.file_path)
            project_model.validate_filesystem()
        except Exception as create_err:
            _LOGGER.error(
                "error while creating new project model, rolling back: {}".format(
                    create_err
                )
            )
            if project_model:
                # remove the half-created record; a failure here is only
                # logged so the original error still reaches the caller
                try:
                    project_model.delete_instance()
                except Exception as undo_err:
                    _LOGGER.error(
                        "error while rolling back new model: {}".format(undo_err)
                    )
            raise create_err

        resp_model = data_dump_and_validation(
            ResponseProjectModelSchema(), {"model": project_model}
        )
        _LOGGER.info("created project model {}".format(resp_model))

        status_code = HTTPStatus.OK.value
        return jsonify(resp_model), status_code
Exemplo n.º 7
0
def get_project_model_by_project_id(
    project_id: str, raise_not_found: bool = True
) -> ProjectModel:
    """
    Fetch the project model that belongs to the given project_id.

    :param project_id: project id of the project model
    :param raise_not_found: when True and no model is found, log the miss and
        raise an HTTPNotFoundError; when False the lookup result is returned
        as is, which is None if nothing matched
    :return: the project model with the project id
        (None if not found and raise_not_found is False)
    """
    found = ProjectModel.get_or_none(ProjectModel.project_id == project_id)

    # anything found, or a caller that tolerates a miss, returns directly
    if found is not None or not raise_not_found:
        return found

    message = "could not find model for project_id {}".format(project_id)
    _LOGGER.error(message)
    raise HTTPNotFoundError(message)