def upload_loss_profile(project_id: str):
    """
    Route for creating a new loss profile for a given project from uploaded data.
    Raises an HTTPNotFoundError if the project is not found in the database.
    Raises a ValidationError if the 'loss_file' upload is missing, cannot be
    read as JSON, or if the project does not have a model set.

    :param project_id: the id of the project to create a loss profile for
    :return: a tuple containing (json response, http status code)
    """
    _LOGGER.info("uploading loss profile for project {}".format(project_id))
    project = get_project_by_id(project_id)  # validate id

    if "loss_file" not in request.files:
        _LOGGER.error("missing uploaded file 'loss_file'")
        raise ValidationError("missing uploaded file 'loss_file'")

    # read loss analysis file
    try:
        loss_analysis = json.load(request.files["loss_file"])  # type: Dict
    except Exception as err:
        _LOGGER.error(
            "error while reading uploaded loss analysis file: {}".format(err))
        # chain the read/parse error so the root cause stays on the traceback
        raise ValidationError(
            "error while reading uploaded loss analysis file: {}".format(err)
        ) from err

    # override or default potential previous data fields
    loss_analysis_args = CreateProjectLossProfileSchema().load(loss_analysis)
    loss_analysis.update(loss_analysis_args)
    # placeholder ids, presumably so the full profile schema validates;
    # the real ids are set on DB insert below
    loss_analysis["profile_id"] = "<none>"
    loss_analysis["project_id"] = "<none>"
    loss_analysis["created"] = datetime.datetime.now()
    loss_analysis["source"] = "uploaded"
    loss_analysis["job"] = None
    loss_analysis = data_dump_and_validation(ProjectLossProfileSchema(),
                                             loss_analysis)
    # drop the placeholders: a new profile_id is created on DB insert and the
    # project is passed in explicitly; pop() tolerates a dump without the key
    loss_analysis.pop("profile_id", None)
    loss_analysis.pop("project_id", None)

    model = project.model
    if model is None:
        raise ValidationError(
            ("A model has not been set for the project with id {}, "
             "project must set a model before running a loss profile."
            ).format(project_id))

    loss_profile = ProjectLossProfile.create(project=project, **loss_analysis)
    resp_profile = data_dump_and_validation(ResponseProjectLossProfileSchema(),
                                            {"profile": loss_profile})
    _LOGGER.info("created loss profile: id: {}, name: {}".format(
        resp_profile["profile"]["profile_id"], resp_profile["profile"]["name"]))

    return jsonify(resp_profile), HTTPStatus.OK.value
def create_loss_profile(project_id: str):
    """
    Route for creating a new loss profile for a given project.
    Raises an HTTPNotFoundError if the project is not found in the database.
    Raises a ValidationError if the project does not have a model set.
    If creating the profile or its job fails, any records already created are
    deleted (best-effort) and the original error is re-raised.

    :param project_id: the id of the project to create a loss profile for
    :return: a tuple containing (json response, http status code)
    """
    # Parse the body once with force=True and reuse it for logging. Reading
    # request.json separately does not honor force=True: depending on the
    # Flask version it yields None or rejects non-JSON content types before
    # the forced parse can run.
    request_json = request.get_json(force=True)
    _LOGGER.info(
        "creating loss profile for project {} for request json {}".format(
            project_id, request_json))
    project = get_project_by_id(project_id)
    loss_profile_params = CreateProjectLossProfileSchema().load(request_json)

    model = project.model
    if model is None:
        raise ValidationError(
            ("A model has not been set for the project with id {}, "
             "project must set a model before running a loss profile."
            ).format(project_id))

    loss_profile = None
    job = None

    try:
        loss_profile = ProjectLossProfile.create(project=project,
                                                 source="generated",
                                                 **loss_profile_params)
        job = Job.create(
            project_id=project_id,
            type_=CreateLossProfileJobWorker.get_type(),
            worker_args=CreateLossProfileJobWorker.format_args(
                model_id=model.model_id,
                profile_id=loss_profile.profile_id,
                pruning_estimations=loss_profile_params["pruning_estimations"],
                pruning_estimation_type=loss_profile_params[
                    "pruning_estimation_type"],
                pruning_structure=loss_profile_params["pruning_structure"],
                quantized_estimations=loss_profile_params[
                    "quantized_estimations"],
            ),
        )
        loss_profile.job = job
        loss_profile.save()
    except Exception as err:
        _LOGGER.error(
            "error while creating new loss profile, rolling back: {}".format(
                err))
        # best-effort rollback: delete the profile first, then the job it
        # references; rollback failures are only logged so the original
        # error is what propagates
        for instance, label in ((loss_profile, "loss profile"),
                                (job, "job")):
            if instance is None:
                continue
            try:
                instance.delete_instance()
            except Exception as rollback_err:
                _LOGGER.error("error while rolling back new {}: {}".format(
                    label, rollback_err))
        raise

    # call into JobWorkerManager to kick off job if it's not already running
    JobWorkerManager().refresh()

    resp_profile = data_dump_and_validation(ResponseProjectLossProfileSchema(),
                                            {"profile": loss_profile})
    _LOGGER.info("created loss profile and job: {}".format(resp_profile))

    return jsonify(resp_profile), HTTPStatus.OK.value