Example #1
0
def upload_loss_profile(project_id: str):
    """
    Route for creating a new loss profile for a given project from uploaded data.
    Raises an HTTPNotFoundError if the project is not found in the database.

    The uploaded data is read from the multipart file field ``loss_file`` and
    must be a JSON object. Fields loaded through CreateProjectLossProfileSchema
    override the matching values from the file. The bookkeeping fields
    (profile_id, project_id, created, source, job) are reset before the
    result is validated with ProjectLossProfileSchema and stored.

    :param project_id: the id of the project to create a loss profile for
    :return: a tuple containing (json response, http status code)
    :raises ValidationError: if the project has no model set, if the
        ``loss_file`` upload is missing, or if it cannot be read as a JSON
        object
    """
    _LOGGER.info("uploading loss profile for project {}".format(project_id))

    project = get_project_by_id(project_id)  # validate id

    # fail fast: check the project has a model before doing any file parsing
    # or schema validation work on the upload
    if project.model is None:
        raise ValidationError(
            ("A model has not been set for the project with id {}, "
             "project must set a model before running a loss profile."
             ).format(project_id))

    if "loss_file" not in request.files:
        _LOGGER.error("missing uploaded file 'loss_file'")
        raise ValidationError("missing uploaded file 'loss_file'")

    # read loss analysis file
    try:
        loss_analysis = json.load(request.files["loss_file"])  # type: Dict
    except Exception as err:
        _LOGGER.error(
            "error while reading uploaded loss analysis file: {}".format(err))
        raise ValidationError(
            "error while reading uploaded loss analysis file: {}".format(
                err)) from err

    # valid JSON can still be a list/number/string; the merge below needs a
    # mapping, so reject anything else with a clear message
    if not isinstance(loss_analysis, dict):
        _LOGGER.error(
            "uploaded loss analysis file must contain a JSON object")
        raise ValidationError(
            "uploaded loss analysis file must contain a JSON object")

    # override or default potential previous data fields
    loss_analysis_args = CreateProjectLossProfileSchema().load(loss_analysis)
    loss_analysis.update(loss_analysis_args)
    # placeholders so the full schema validates; the ids are removed again
    # below before the DB insert
    loss_analysis["profile_id"] = "<none>"
    loss_analysis["project_id"] = "<none>"
    loss_analysis["created"] = datetime.datetime.now()
    loss_analysis["source"] = "uploaded"
    loss_analysis["job"] = None

    loss_analysis = data_dump_and_validation(ProjectLossProfileSchema(),
                                             loss_analysis)
    # remove to let the DB generate a new id on insert
    loss_analysis.pop("profile_id", None)
    # remove because the project is passed explicitly on DB insert
    loss_analysis.pop("project_id", None)

    loss_profile = ProjectLossProfile.create(project=project, **loss_analysis)

    resp_profile = data_dump_and_validation(ResponseProjectLossProfileSchema(),
                                            {"profile": loss_profile})
    _LOGGER.info("created loss profile: id: {}, name: {}".format(
        resp_profile["profile"]["profile_id"],
        resp_profile["profile"]["name"]))

    return jsonify(resp_profile), HTTPStatus.OK.value
Example #2
0
def create_loss_profile(project_id: str):
    """
    Route for creating a new loss profile for a given project.
    Raises an HTTPNotFoundError if the project is not found in the database.

    Creates a ProjectLossProfile with source "generated" and a Job for
    CreateLossProfileJobWorker, then links the job to the profile. If any step
    fails, whichever of the two records was already created is deleted
    (best effort) and the original exception is re-raised. On success,
    JobWorkerManager is refreshed so the job can be picked up.

    :param project_id: the id of the project to create a loss profile for
    :return: a tuple containing (json response, http status code)
    :raises ValidationError: if the project does not have a model set
    """
    # parse the body once; force=True accepts it regardless of content type.
    # Reading ``request.json`` separately would ignore ``force`` (returning
    # None, or raising 415 on newer Werkzeug) for non-JSON content types.
    request_json = request.get_json(force=True)
    _LOGGER.info(
        "creating loss profile for project {} for request json {}".format(
            project_id, request_json))
    project = get_project_by_id(project_id)

    loss_profile_params = CreateProjectLossProfileSchema().load(request_json)

    model = project.model
    if model is None:
        raise ValidationError(
            ("A model has not been set for the project with id {}, "
             "project must set a model before running a loss profile."
             ).format(project_id))

    loss_profile = None
    job = None

    try:
        loss_profile = ProjectLossProfile.create(project=project,
                                                 source="generated",
                                                 **loss_profile_params)
        job = Job.create(
            project_id=project_id,
            type_=CreateLossProfileJobWorker.get_type(),
            worker_args=CreateLossProfileJobWorker.format_args(
                model_id=model.model_id,
                profile_id=loss_profile.profile_id,
                pruning_estimations=loss_profile_params["pruning_estimations"],
                pruning_estimation_type=loss_profile_params[
                    "pruning_estimation_type"],
                pruning_structure=loss_profile_params["pruning_structure"],
                quantized_estimations=loss_profile_params[
                    "quantized_estimations"],
            ),
        )
        loss_profile.job = job
        loss_profile.save()
    except Exception as err:
        _LOGGER.error(
            "error while creating new loss profile, rolling back: {}".format(
                err))
        # best-effort rollback of whatever was created; failures are logged
        # so the original error is still the one re-raised. The profile is
        # deleted before the job it may reference.
        for instance, label in ((loss_profile, "loss profile"), (job, "job")):
            if instance is None:
                continue
            try:
                instance.delete_instance()
            except Exception as rollback_err:
                _LOGGER.error(
                    "error while rolling back new {}: {}".format(
                        label, rollback_err))
        raise

    # call into JobWorkerManager to kick off job if it's not already running
    JobWorkerManager().refresh()

    resp_profile = data_dump_and_validation(ResponseProjectLossProfileSchema(),
                                            {"profile": loss_profile})
    _LOGGER.info("created loss profile and job: {}".format(resp_profile))

    return jsonify(resp_profile), HTTPStatus.OK.value