def upload_perf_profile(project_id: str):
    """
    Route that stores a perf profile for a project from an uploaded JSON file.

    The upload is read from the ``perf_file`` entry of the request files.
    Identity, timestamp, source and job values present in the upload are
    overwritten so the stored record is always a new "uploaded" profile.

    :param project_id: the id of the project to create a perf profile for
    :return: a tuple containing (json response, http status code)
    :raises ValidationError: if ``perf_file`` is missing, cannot be parsed as
        JSON, or the project has no model set
    :raises HTTPNotFoundError: expected from get_project_by_id when the
        project does not exist (per the route contract; helper not shown here)
    """
    _LOGGER.info("uploading perf profile for project {}".format(project_id))
    project = get_project_by_id(project_id)  # validate id

    if "perf_file" not in request.files:
        missing_msg = "missing uploaded file 'perf_file'"
        _LOGGER.error(missing_msg)
        raise ValidationError(missing_msg)

    # read perf analysis file
    try:
        uploaded = json.load(request.files["perf_file"])
    except Exception as err:
        read_msg = "error while reading uploaded perf analysis file: {}".format(
            err
        )
        _LOGGER.error(read_msg)
        raise ValidationError(read_msg)

    # override or default potential previous data fields
    uploaded.update(CreateProjectPerfProfileSchema().load(uploaded))
    uploaded["profile_id"] = "<none>"
    uploaded["project_id"] = "<none>"
    uploaded["created"] = datetime.datetime.now()
    uploaded["source"] = "uploaded"
    uploaded["job"] = None

    profile_fields = data_dump_and_validation(ProjectPerfProfileSchema(), uploaded)
    # the placeholder ids only existed to satisfy validation:
    # profile_id is generated on DB insert, project is passed in explicitly
    del profile_fields["profile_id"]
    del profile_fields["project_id"]

    if project.model is None:
        raise ValidationError(
            (
                "A model has not been set for the project with id {}, "
                "project must set a model before running a perf profile."
            ).format(project_id)
        )

    created_profile = ProjectPerfProfile.create(project=project, **profile_fields)
    response_body = data_dump_and_validation(
        ResponseProjectPerfProfileSchema(), {"profile": created_profile}
    )
    _LOGGER.info(
        "created perf profile: id: {}, name: {}".format(
            response_body["profile"]["profile_id"],
            response_body["profile"]["name"],
        )
    )

    return jsonify(response_body), HTTPStatus.OK.value
def get_loss_profile(project_id: str, profile_id: str):
    """
    Route returning a single loss profile that belongs to the given project.

    :param project_id: the id of the project to get the loss profile for
    :param profile_id: the id of the loss profile to get
    :return: a tuple containing (json response, http status code)
    :raises HTTPNotFoundError: if no loss profile matches both ids; the
        project lookup in get_project_by_id is expected to raise it as well
    """
    _LOGGER.info(
        "getting loss profile for project {} with id {}".format(
            project_id, profile_id
        )
    )
    get_project_by_id(project_id)  # validate id

    # filtering on project_id too ensures the profile belongs to this project
    found_profile = ProjectLossProfile.get_or_none(
        ProjectLossProfile.profile_id == profile_id,
        ProjectLossProfile.project_id == project_id,
    )

    if found_profile is None:
        not_found_msg = (
            "could not find loss profile with profile_id {} and project_id {}".format(
                profile_id, project_id
            )
        )
        _LOGGER.error(not_found_msg)
        raise HTTPNotFoundError(not_found_msg)

    response_body = data_dump_and_validation(
        ResponseProjectLossProfileSchema(), {"profile": found_profile}
    )
    _LOGGER.info(
        "found loss profile with profile_id {} and project_id: {}".format(
            profile_id, project_id
        )
    )

    return jsonify(response_body), HTTPStatus.OK.value
def get_perf_profiles(project_id: str):
    """
    Route returning one page of a project's perf profiles, ordered by their
    ``created`` field.

    The "page" and "page_length" values are taken from the flask request args
    after loading them through SearchProjectProfilesSchema.

    :param project_id: the id of the project to get the perf profiles for
    :return: a tuple containing (json response, http status code)
    :raises HTTPNotFoundError: expected from get_project_by_id when the
        project does not exist (helper not shown here)
    """
    _LOGGER.info("getting perf profiles for project {} ".format(project_id))
    get_project_by_id(project_id)  # validate id

    search_args = SearchProjectProfilesSchema().load(dict(request.args.items()))
    page_query = (
        ProjectPerfProfile.select()
        .where(ProjectPerfProfile.project_id == project_id)
        .order_by(ProjectPerfProfile.created)
        .paginate(search_args["page"], search_args["page_length"])
    )
    page_of_profiles = list(page_query)

    response_body = data_dump_and_validation(
        ResponseProjectPerfProfilesSchema(), {"profiles": page_of_profiles}
    )
    _LOGGER.info("retrieved {} profiles".format(len(page_of_profiles)))

    return jsonify(response_body), HTTPStatus.OK.value
def delete_analysis(project_id: str) -> Tuple[Response, int]:
    """
    Route that clears the stored model analysis of a project.

    The project's model record is kept; only its ``analysis`` field is set to
    None and saved.

    :param project_id: the project_id to delete the analysis for
    :return: a tuple containing (json response, http status code)
    :raises HTTPNotFoundError: expected from the project / model lookup
        helpers when either does not exist (helpers not shown here)
    """
    _LOGGER.info("deleting analysis for project_id {} model".format(project_id))
    get_project_by_id(project_id)

    model_record = get_project_model_by_project_id(project_id)
    model_record.analysis = None
    model_record.save()

    response_body = data_dump_and_validation(
        ResponseProjectModelDeletedSchema(),
        {"project_id": project_id, "model_id": model_record.model_id},
    )
    _LOGGER.info(
        "deleted model analysis for project_id {} and model_id {} from path {}".format(
            project_id, model_record.model_id, model_record.file_path
        )
    )

    return jsonify(response_body), HTTPStatus.OK.value
def get_analysis(project_id: str) -> Tuple[Response, int]:
    """
    Route returning the stored model analysis of a project.

    :param project_id: the project_id to get the analysis for
    :return: a tuple containing (json response, http status code)
    :raises ValidationError: if the model has no analysis stored yet
    :raises HTTPNotFoundError: expected from the project / model lookup
        helpers when either does not exist (helpers not shown here)
    """
    _LOGGER.info("getting model analysis for project_id {}".format(project_id))
    get_project_by_id(project_id)
    model_record = get_project_model_by_project_id(project_id)

    if model_record.analysis:
        analysis = ProjectModelAnalysisSchema().dump(model_record.analysis)
    else:
        analysis = None

    if analysis is None:
        raise ValidationError("analysis must be created first")

    response_body = data_dump_and_validation(
        ResponseProjectModelAnalysisSchema(), {"analysis": analysis}
    )
    _LOGGER.info(
        "retrieved model analysis for project_id {} and model_id {}".format(
            project_id, model_record.model_id
        )
    )

    return jsonify(response_body), HTTPStatus.OK.value
def create_analysis(project_id: str) -> Tuple[Response, int]:
    """
    Route that analyzes a project's model file and stores the result on the
    model record, replacing any analysis that was stored before.

    :param project_id: the project_id to create the analysis for
    :return: a tuple containing (json response, http status code)
    :raises HTTPNotFoundError: expected from the project / model lookup
        helpers when either does not exist (helpers not shown here)
    """
    _LOGGER.info("creating analysis for project_id {} model".format(project_id))
    get_project_by_id(project_id)

    model_record = get_project_model_by_project_id(project_id)
    model_record.validate_filesystem()

    # run the analyzer on the model file and normalize its output via the schema
    raw_analysis = ModelAnalyzer(model_record.file_path).dict()
    analysis = ProjectModelAnalysisSchema().load(raw_analysis)
    model_record.analysis = analysis
    model_record.save()

    response_body = data_dump_and_validation(
        ResponseProjectModelAnalysisSchema(), {"analysis": analysis}
    )
    _LOGGER.info(
        "analyzed model for project_id {} and model_id {} from path {}".format(
            project_id, model_record.model_id, model_record.file_path
        )
    )

    return jsonify(response_body), HTTPStatus.OK.value
def get_model_details(project_id: str) -> Tuple[Response, int]:
    """
    Route returning the details of the model attached to a project.

    If the query yields several models, the last row iterated is the one
    returned.

    :param project_id: the project_id to get the model details for
    :return: a tuple containing (json response, http status code)
    :raises HTTPNotFoundError: if no model row references the project; the
        project lookup in get_project_by_id is expected to raise it as well
    """
    _LOGGER.info("getting the model details for project_id {}".format(project_id))
    project = get_project_by_id(project_id)

    matching_models = ProjectModel.select(ProjectModel).where(
        ProjectModel.project == project
    )
    # walk the whole result so the loop variable ends up on the last row;
    # it stays None when the query is empty
    project_model = None
    for project_model in matching_models:
        pass

    if project_model is None:
        raise HTTPNotFoundError(
            "could not find model for project_id {}".format(project_id)
        )

    response_body = data_dump_and_validation(
        ResponseProjectModelSchema(), {"model": project_model}
    )
    _LOGGER.info("retrieved project model details {}".format(response_body))

    return jsonify(response_body), HTTPStatus.OK.value
def load_model_from_repo(project_id: str) -> Tuple[Response, int]:
    """
    Route that registers a model for a project to be loaded from a repo uri.

    Creates the model record and a Job carrying the ModelFromRepoJobWorker
    arguments, links them, prepares the model's filesystem location and then
    asks the JobWorkerManager to refresh. If anything fails while creating
    the records, whatever was created is deleted again and the error is
    re-raised.

    :param project_id: the id of the project to load the model for
    :return: a tuple containing (json response, http status code)
    """
    _LOGGER.info(
        "loading model from repo for project {} for request json {}".format(
            project_id, request.json
        )
    )
    project = _add_model_check(project_id)
    load_args = SetProjectModelFromSchema().load(request.get_json(force=True))

    project_model = None
    job = None

    try:
        project_model = ProjectModel.create(
            project=project, source="downloaded_repo", job=None
        )
        worker_args = ModelFromRepoJobWorker.format_args(
            model_id=project_model.model_id,
            uri=load_args["uri"],
        )
        job = Job.create(
            project_id=project.project_id,
            type_=ModelFromRepoJobWorker.get_type(),
            worker_args=worker_args,
        )
        project_model.job = job
        project_model.save()
        project_model.setup_filesystem()
        project_model.validate_filesystem()
    except Exception as err:
        _LOGGER.error(
            "error while creating new project model, rolling back: {}".format(err)
        )

        # best-effort rollback: model first, then job; failures are only logged
        for created in (project_model, job):
            if created:
                try:
                    created.delete_instance()
                except Exception as rollback_err:
                    _LOGGER.error(
                        "error while rolling back new model: {}".format(
                            rollback_err
                        )
                    )

        raise err

    # call into JobWorkerManager to kick off job if it's not already running
    JobWorkerManager().refresh()

    response_body = data_dump_and_validation(
        ResponseProjectModelSchema(), {"model": project_model}
    )
    _LOGGER.info("created project model from repo {}".format(response_body))

    return jsonify(response_body), HTTPStatus.OK.value
def delete_benchmark(project_id: str, benchmark_id: str):
    """
    Route that deletes one benchmark of a project.

    :param project_id: the id of the project to delete the benchmark for
    :param benchmark_id: the id of the benchmark to delete
    :return: a tuple containing (json response, http status code)
    :raises HTTPNotFoundError: expected from the project / benchmark lookup
        helpers when either does not exist (helpers not shown here)
    """
    _LOGGER.info(
        "deleting benchmark for project {} with id {}".format(
            project_id, benchmark_id
        )
    )
    get_project_by_id(project_id)

    target = get_project_benchmark_by_ids(project_id, benchmark_id)
    target.delete_instance()

    response_body = data_dump_and_validation(
        ResponseProjectBenchmarkDeletedSchema(),
        {"success": True, "project_id": project_id, "benchmark_id": benchmark_id},
    )
    _LOGGER.info(
        "deleted benchmark with benchmark_id {} and project_id: {}".format(
            benchmark_id, project_id
        )
    )

    return jsonify(response_body), HTTPStatus.OK.value
def get_data_details(project_id: str):
    """
    Route returning one page of the data entries of a project, ordered by
    their ``created`` field.

    The "page" and "page_length" values come from the flask request args
    after loading them through SearchProjectDataSchema.

    :param project_id: the project_id to get the data details for
    :return: a tuple containing (json response, http status code)
    :raises HTTPNotFoundError: expected from the project / model lookup
        helpers when either does not exist (helpers not shown here)
    """
    raw_args = dict(request.args.items())
    _LOGGER.info(
        "getting all the data for project_id {} and request args {}".format(
            project_id, raw_args
        )
    )
    search_args = SearchProjectDataSchema().load(raw_args)

    # Validate project and model
    get_project_by_id(project_id)
    get_project_model_by_project_id(project_id)

    data_page = (
        ProjectData.select()
        .where(ProjectData.project_id == project_id)
        .group_by(ProjectData)
        .order_by(ProjectData.created)
        .paginate(search_args["page"], search_args["page_length"])
    )

    response_body = data_dump_and_validation(
        ResponseProjectDataSchema(), {"data": data_page}
    )
    _LOGGER.info("sending project data {}".format(response_body))

    return jsonify(response_body), HTTPStatus.OK.value
def info():
    """
    Route returning the system info reported by get_ml_sys_info for the
    machine the server runs on.

    :return: a tuple containing (json response, http status code)
    """
    _LOGGER.info("getting system info")

    response_body = data_dump_and_validation(
        ResponseSystemInfo(), {"info": get_ml_sys_info()}
    )
    _LOGGER.info("retrieved system info {}".format(response_body))

    return jsonify(response_body), HTTPStatus.OK.value
def upload_model(project_id: str) -> Tuple[Response, int]:
    """
    Route for uploading a model file to a project.

    The upload is read from the ``model_file`` entry of the request files,
    written to a scratch location, validated with validate_onnx_file and then
    copied into the new project model's file path. If creating or populating
    the model record fails, the record is deleted again and the error is
    re-raised.

    The scratch file lives inside a TemporaryDirectory instead of an open
    NamedTemporaryFile: reopening a still-open NamedTemporaryFile by name is
    not portable (it fails on Windows), and the previous join of
    gettempdir() with the already absolute temp name had no effect.

    :param project_id: the id of the project to upload the model for
    :return: a tuple containing (json response, http status code)
    :raises ValidationError: if ``model_file`` is missing from the request
    """
    # function-scope import so this block does not depend on file-level changes
    from tempfile import TemporaryDirectory

    _LOGGER.info("uploading model for project {}".format(project_id))
    project = _add_model_check(project_id)

    if "model_file" not in request.files:
        _LOGGER.error("missing uploaded file 'model_file'")
        raise ValidationError("missing uploaded file 'model_file'")

    model_file = request.files["model_file"]
    project_model = None

    with TemporaryDirectory() as temp_dir:
        # Verify onnx model is valid and contains opset field
        tempname = os.path.join(temp_dir, "model.onnx")
        model_file.save(tempname)
        validate_onnx_file(tempname)

        try:
            # Create project model
            data = CreateUpdateProjectModelSchema().dump(
                {"file": "model.onnx", "source": "uploaded"}
            )
            project_model = ProjectModel.create(project=project, **data)
            project_model.setup_filesystem()
            shutil.copy(tempname, project_model.file_path)
            project_model.validate_filesystem()
        except Exception as err:
            _LOGGER.error(
                "error while creating new project model, rolling back: {}".format(
                    err
                )
            )
            # best-effort rollback of the record; a failure here is only logged
            if project_model:
                try:
                    project_model.delete_instance()
                except Exception as rollback_err:
                    _LOGGER.error(
                        "error while rolling back new model: {}".format(
                            rollback_err
                        )
                    )
            raise err

    resp_model = data_dump_and_validation(
        ResponseProjectModelSchema(), {"model": project_model}
    )
    _LOGGER.info("created project model {}".format(resp_model))

    return jsonify(resp_model), HTTPStatus.OK.value
def _run_benchmark(
    self,
    benchmark: ProjectBenchmark,
    model: Union[str, ModelProto],
    runner: ModelRunner,
    core_count: int,
    batch_size: int,
    inference_engine: str,
    inference_model_optimization: Union[str, None],
    num_steps: int,
    step_index: int,
):
    """
    Run a single benchmark configuration and append its result to
    ``benchmark.result["benchmarks"]``.

    This is a generator: after each runner iteration it yields the overall
    progress, computed as
    (step_index + fraction of this step's iterations done) / num_steps.
    Measurements from the first ``self.warmup_iterations_per_check``
    iterations are dropped before the result is recorded.

    :param benchmark: the benchmark whose result list receives the new entry
    :param model: the model passed to DataLoader.from_model_random
    :param runner: the runner whose run_iter produces the measurements
    :param core_count: recorded in the result entry
    :param batch_size: batch size for the data loader; recorded in the result
    :param inference_engine: recorded in the result entry
    :param inference_model_optimization: recorded in the result entry
    :param num_steps: total number of steps the progress value is scaled by
    :param step_index: index of this step within num_steps
    """
    data_iter = DataLoader.from_model_random(
        model, batch_size=batch_size, iter_steps=-1
    )
    collected = []
    total_iterations = self.warmup_iterations_per_check + self.iterations_per_check

    run_results = runner.run_iter(
        data_iter,
        show_progress=False,
        max_steps=total_iterations,
    )
    for completed, (_, step_measurements) in enumerate(run_results, start=1):
        collected.append(step_measurements)
        yield (step_index + completed / total_iterations) / num_steps

    # discard the warmup iterations so only the measured ones are reported
    if self.warmup_iterations_per_check > 0:
        collected = collected[self.warmup_iterations_per_check:]

    result_entry = data_dump_and_validation(
        ProjectBenchmarkResultSchema(),
        {
            "core_count": core_count,
            "batch_size": batch_size,
            "inference_engine": inference_engine,
            "inference_model_optimization": inference_model_optimization,
            "measurements": collected,
        },
    )
    benchmark.result["benchmarks"].append(result_entry)
def handle_client_error(error: ValidationError):
    """
    Turn a client-side error raised in the flask app into a 400 response.

    The response body carries the error's class name and its string form.

    :param error: the error that occurred
    :return: a tuple containing (json response, http status code [400])
    """
    _LOGGER.error("handling client error, returning 400 status: {}".format(error))

    error_fields = {
        "error_type": error.__class__.__name__,
        "error_message": str(error),
    }
    response_body = data_dump_and_validation(ErrorSchema(), error_fields)

    return jsonify(response_body), HTTPStatus.BAD_REQUEST
def handle_not_found_error(error: HTTPNotFoundError):
    """
    Turn a not-found error raised in the flask app into a 404 response.

    The response body carries the error's class name and its string form.

    :param error: the error that occurred
    :return: a tuple containing (json response, http status code [404])
    """
    _LOGGER.error(
        "handling not found error, returning 404 status: {}".format(error)
    )

    error_fields = {
        "error_type": error.__class__.__name__,
        "error_message": str(error),
    }
    response_body = data_dump_and_validation(ErrorSchema(), error_fields)

    return jsonify(response_body), HTTPStatus.NOT_FOUND
def delete_perf_profile(project_id: str, profile_id: str):
    """
    Route that deletes one perf profile of a project.

    :param project_id: the id of the project to delete the perf profile for
    :param profile_id: the id of the perf profile to delete
    :return: a tuple containing (json response, http status code)
    :raises HTTPNotFoundError: if no perf profile matches both ids; the
        project lookup in get_project_by_id is expected to raise it as well
    """
    _LOGGER.info(
        "deleting perf profile for project {} with id {}".format(
            project_id, profile_id
        )
    )
    get_project_by_id(project_id)  # validate id

    # filtering on project_id too ensures the profile belongs to this project
    target = ProjectPerfProfile.get_or_none(
        ProjectPerfProfile.profile_id == profile_id,
        ProjectPerfProfile.project_id == project_id,
    )

    if target is None:
        not_found_msg = (
            "could not find perf profile with profile_id {} and project_id {}".format(
                profile_id, project_id
            )
        )
        _LOGGER.error(not_found_msg)
        raise HTTPNotFoundError(not_found_msg)

    target.delete_instance()

    response_body = data_dump_and_validation(
        ResponseProjectProfileDeletedSchema(),
        {"success": True, "project_id": project_id, "profile_id": profile_id},
    )
    _LOGGER.info(
        "deleted perf profile with profile_id {} and project_id: {}".format(
            profile_id, project_id
        )
    )

    return jsonify(response_body), HTTPStatus.OK.value
def handle_unexpected_error(error: Exception):
    """
    Turn an unexpected error raised in the flask app into a 500 response.

    The response body carries the error's class name and its string form.

    :param error: the error that occurred
    :return: a tuple containing (json response, http status code [500])
    """
    _LOGGER.error(
        "handling unexpected error, returning 500 status: {}".format(error)
    )

    error_fields = {
        "error_type": error.__class__.__name__,
        "error_message": str(error),
    }
    response_body = data_dump_and_validation(ErrorSchema(), error_fields)

    return jsonify(response_body), HTTPStatus.INTERNAL_SERVER_ERROR
def get_data_single_details(project_id: str, data_id: str):
    """
    Route returning the details of a single data entry of a project.

    :param project_id: the project_id to get the data details for
    :param data_id: the data_id to get the data details for
    :return: a tuple containing (json response, http status code)
    :raises HTTPNotFoundError: expected from the project / model / data
        lookup helpers when any of them does not exist (helpers not shown here)
    """
    _LOGGER.info(
        "getting the data with data_id {} for project_id {}".format(
            data_id, project_id
        )
    )
    # project and model lookups are run only for their validation effect
    get_project_by_id(project_id)
    get_project_model_by_project_id(project_id)

    data_record = get_project_data_by_ids(project_id, data_id)
    response_body = data_dump_and_validation(
        ResponseProjectDataSingleSchema(), {"data": data_record}
    )
    _LOGGER.info("sending project data from {}".format(data_record.file_path))

    return jsonify(response_body), HTTPStatus.OK.value
def delete_model(project_id: str) -> Tuple[Response, int]:
    """
    Route that deletes the model of a project, both its record and its files.

    Request args are loaded through DeleteProjectModelSchema; when the loaded
    "force" value is truthy, an error raised during deletion is logged and
    swallowed instead of being re-raised.

    :param project_id: the project_id to delete the model for
    :return: a tuple containing (json response, http status code)
    :raises HTTPNotFoundError: expected from the project / model lookup
        helpers when either does not exist (helpers not shown here)
    """
    _LOGGER.info("deleting the model for project_id {}".format(project_id))
    delete_args = DeleteProjectModelSchema().load(dict(request.args.items()))

    get_project_by_id(project_id)
    project_model = get_project_model_by_project_id(project_id)
    model_id = project_model.model_id

    try:
        project_model.delete_instance()
        project_model.delete_filesystem()
    except Exception as err:
        _LOGGER.error(
            "error while deleting project model for {}, rolling back: {}".format(
                project_id, err
            )
        )
        if not delete_args["force"]:
            raise err

    response_body = data_dump_and_validation(
        ResponseProjectModelDeletedSchema(),
        {"project_id": project_id, "model_id": model_id},
    )
    _LOGGER.info(
        "deleted model for project_id {} and model_id {} from path {}".format(
            project_id, project_model.model_id, project_model.dir_path
        )
    )

    return jsonify(response_body), HTTPStatus.OK.value
def optim_trainable_default_nodes(
    default_trainable: bool,
    model_analysis: Dict,
    node_overrides: Union[None, List[Dict[str, Any]]] = None,
) -> List[Dict[str, Any]]:
    """
    Build the trainable-node entries for every prunable node of a model.

    Nodes whose "prunable" value is falsy are skipped. Each remaining node
    gets ``default_trainable`` unless an override with a matching "node_id"
    exists, in which case the first matching override's "trainable" value
    is used.

    :param default_trainable: value used for nodes without an override
    :param model_analysis: the analysis for the model; its "nodes" entries
        are read for "id" and "prunable"
    :param node_overrides: optional overrides, each read for "node_id" and
        "trainable"
    :return: the validated trainable node entries, in node order
    """
    overrides = node_overrides if node_overrides else ()
    trainable_nodes = []

    for node in model_analysis["nodes"]:
        if not node["prunable"]:
            continue

        # first override naming this node wins; None means "no override"
        matched = next(
            (ovr for ovr in overrides if ovr["node_id"] == node["id"]), None
        )
        trainable_val = default_trainable if matched is None else matched["trainable"]

        entry = data_dump_and_validation(
            ProjectOptimizationModifierTrainableNodeSchema(),
            {"node_id": node["id"], "trainable": trainable_val},
        )
        trainable_nodes.append(entry)

    return trainable_nodes
def delete_data(project_id: str, data_id: str):
    """
    Route to delete a data file for a given project matching the project_id.

    The optional ``force`` request arg controls error handling: when it is
    set to a true-like value ("1", "true", "t", "yes", "y", "on",
    case-insensitive), an error raised during deletion is logged and
    swallowed; otherwise the error is re-raised. A missing ``force`` arg
    counts as false.

    :param project_id: the project_id to get the data for
    :param data_id: the data_id to get the data for
    :return: a tuple containing (json response, http status code)
    :raises HTTPNotFoundError: expected from the project / model / data
        lookup helpers when any of them does not exist (helpers not shown here)
    """
    _LOGGER.info(
        "deleting data with data_id {} for project_id {}".format(
            data_id, project_id
        )
    )
    get_project_by_id(project_id)
    get_project_model_by_project_id(project_id)
    project_data = get_project_data_by_ids(project_id, data_id)

    # Query-string values are strings, so "false" would be truthy if used
    # directly, and indexing a missing key inside the except block would
    # replace the real error with a KeyError. Parse the flag explicitly.
    force = str(request.args.get("force", "")).strip().lower() in (
        "1",
        "true",
        "t",
        "yes",
        "y",
        "on",
    )

    try:
        project_data.delete_instance()
        project_data.delete_filesystem()
    except Exception as err:
        _LOGGER.error(
            "error while deleting project data for {}, rolling back: {}".format(
                data_id, err
            )
        )

        if not force:
            raise err

    resp_deleted = data_dump_and_validation(
        ResponseProjectDataDeletedSchema(),
        {"project_id": project_id, "data_id": data_id},
    )
    _LOGGER.info(
        "deleted data for project_id {} and data_id {} from path {}".format(
            project_id, project_data.data_id, project_data.dir_path
        )
    )

    return jsonify(resp_deleted), HTTPStatus.OK.value
def get_benchmark(project_id: str, benchmark_id: str):
    """
    Route returning a single benchmark of a project.

    :param project_id: the id of the project to get the benchmark for
    :param benchmark_id: the id of the benchmark to get
    :return: a tuple containing (json response, http status code)
    :raises HTTPNotFoundError: expected from the project / benchmark lookup
        helpers when either does not exist (helpers not shown here)
    """
    _LOGGER.info(
        "getting benchmark for project {} with id {}".format(
            project_id, benchmark_id
        )
    )
    get_project_by_id(project_id)

    found_benchmark = get_project_benchmark_by_ids(project_id, benchmark_id)
    response_body = data_dump_and_validation(
        ResponseProjectBenchmarkSchema(), {"benchmark": found_benchmark}
    )
    _LOGGER.info(
        "found benchmark with benchmark_id {} and project_id: {}".format(
            benchmark_id, project_id
        )
    )

    return jsonify(response_body), HTTPStatus.OK.value
def get_benchmarks(project_id: str):
    """
    Route for getting a list of benchmarks for a given project
    filtered by the flask request args.

    Returns one page of benchmarks ordered by their ``created`` field; the
    "page" and "page_length" values come from the request args after loading
    them through SearchProjectBenchmarksSchema.

    :param project_id: the id of the project to get benchmarks for
    :return: a tuple containing (json response, http status code)
    :raises HTTPNotFoundError: expected from get_project_by_id when the
        project does not exist (helper not shown here)
    """
    args = {key: val for key, val in request.args.items()}
    _LOGGER.info("getting project benchmark for project_id {}".format(project_id))

    # validate id: without this lookup an unknown project silently returned
    # an empty list instead of the not-found error the route documents
    get_project_by_id(project_id)

    args = SearchProjectBenchmarksSchema().load(args)

    query = (
        ProjectBenchmark.select()
        .where(ProjectBenchmark.project_id == project_id)
        .order_by(ProjectBenchmark.created)
        .paginate(args["page"], args["page_length"])
    )
    benchmarks = [res for res in query]

    resp_benchmarks = data_dump_and_validation(
        ResponseProjectBenchmarksSchema(), {"benchmarks": benchmarks}
    )
    _LOGGER.info("retrieved {} benchmarks".format(len(benchmarks)))

    return jsonify(resp_benchmarks), HTTPStatus.OK.value
def create_perf_profile(project_id: str):
    """
    Route for creating a new perf profile for a given project.

    Loads the request json through CreateProjectPerfProfileSchema, fills in
    "core_count" and "instruction_sets" from the system info, then creates
    the profile record and a Job carrying the CreatePerfProfileJobWorker
    arguments. If creating or linking the records fails, whatever was
    created is deleted again and the error is re-raised.

    :param project_id: the id of the project to create a perf profile for
    :return: a tuple containing (json response, http status code)
    :raises ValidationError: if the project has no model set
    :raises HTTPNotFoundError: expected from get_project_by_id when the
        project does not exist (helper not shown here)
    """
    _LOGGER.info(
        "creating perf profile for project {} for request json {}".format(
            project_id, request.json
        )
    )
    project = get_project_by_id(project_id)

    perf_profile_params = CreateProjectPerfProfileSchema().load(
        request.get_json(force=True)
    )
    sys_info = get_ml_sys_info()

    # a missing or non-positive core_count falls back to the system value
    if not perf_profile_params["core_count"] or perf_profile_params["core_count"] < 1:
        perf_profile_params["core_count"] = sys_info["cores_per_socket"]

    if not perf_profile_params["core_count"]:
        # extra check in case the system couldn't get cores_per_socket
        perf_profile_params["core_count"] = -1

    perf_profile_params["instruction_sets"] = sys_info["available_instructions"]

    model = project.model
    if model is None:
        # message wording fixed ("is has" -> "has") to match the other routes
        raise ValidationError(
            (
                "A model has not been set for the project with id {}, "
                "project must set a model before running a perf profile."
            ).format(project_id)
        )

    perf_profile = None
    job = None

    try:
        perf_profile = ProjectPerfProfile.create(
            project=project, source="generated", **perf_profile_params
        )
        job = Job.create(
            project_id=project_id,
            type_=CreatePerfProfileJobWorker.get_type(),
            worker_args=CreatePerfProfileJobWorker.format_args(
                model_id=model.model_id,
                profile_id=perf_profile.profile_id,
                batch_size=perf_profile_params["batch_size"],
                core_count=perf_profile_params["core_count"],
                pruning_estimations=perf_profile_params["pruning_estimations"],
                quantized_estimations=perf_profile_params["quantized_estimations"],
                iterations_per_check=perf_profile_params["iterations_per_check"],
                warmup_iterations_per_check=perf_profile_params[
                    "warmup_iterations_per_check"
                ],
            ),
        )
        perf_profile.job = job
        perf_profile.save()
    except Exception as err:
        _LOGGER.error(
            "error while creating new perf profile, rolling back: {}".format(err)
        )
        # best-effort rollback; failures are logged and the original error
        # is still the one that propagates
        if perf_profile:
            try:
                perf_profile.delete_instance()
            except Exception as rollback_err:
                _LOGGER.error(
                    "error while rolling back new perf profile: {}".format(
                        rollback_err
                    )
                )
        if job:
            try:
                job.delete_instance()
            except Exception as rollback_err:
                _LOGGER.error(
                    "error while rolling back new perf profile: {}".format(
                        rollback_err
                    )
                )
        raise err

    # call into JobWorkerManager to kick off job if it's not already running
    JobWorkerManager().refresh()

    resp_profile = data_dump_and_validation(
        ResponseProjectPerfProfileSchema(), {"profile": perf_profile}
    )
    _LOGGER.info("created perf profile and job: {}".format(resp_profile))

    return jsonify(resp_profile), HTTPStatus.OK.value
def create_benchmark(project_id: str):
    """
    Route that creates a benchmark record for a project together with the
    job that will run it.

    Request json is loaded through CreateProjectBenchmarkSchema and
    "instruction_sets" is filled from the system info (empty list when the
    system info has no "available_instructions"). If creating or linking the
    records fails, whatever was created is deleted again and the error is
    re-raised.

    :param project_id: the id of the project to create a benchmark for
    :return: a tuple containing (json response, http status code)
    :raises ValidationError: if the project has no model set
    :raises HTTPNotFoundError: expected from get_project_by_id when the
        project does not exist (helper not shown here)
    """
    _LOGGER.info(
        "creating benchmark for project {} for request json {}".format(
            project_id, request.get_json()
        )
    )
    project = get_project_by_id(project_id)
    params = CreateProjectBenchmarkSchema().load(request.get_json(force=True))

    model = project.model
    if model is None:
        raise ValidationError(
            (
                "A model has not been set for the project with id {}, "
                "project must set a model before running a benchmark."
            ).format(project_id)
        )

    sys_info = get_ml_sys_info()
    benchmark = None
    job = None

    try:
        if "available_instructions" in sys_info:
            params["instruction_sets"] = sys_info["available_instructions"]
        else:
            params["instruction_sets"] = []

        benchmark = ProjectBenchmark.create(
            project=project, source="generated", **params
        )
        worker_args = CreateBenchmarkJobWorker.format_args(
            model_id=model.model_id,
            benchmark_id=benchmark.benchmark_id,
            core_counts=benchmark.core_counts,
            batch_sizes=benchmark.batch_sizes,
            instruction_sets=benchmark.instruction_sets,
            inference_models=benchmark.inference_models,
            warmup_iterations_per_check=benchmark.warmup_iterations_per_check,
            iterations_per_check=benchmark.iterations_per_check,
        )
        job = Job.create(
            project_id=project.project_id,
            type_=CreateBenchmarkJobWorker.get_type(),
            worker_args=worker_args,
        )
        benchmark.job = job
        benchmark.save()
    except Exception as err:
        _LOGGER.error(
            "error while creating new benchmark, rolling back: {}".format(err)
        )

        # best-effort rollback: benchmark first, then job; failures only logged
        for created in (benchmark, job):
            if created:
                try:
                    created.delete_instance()
                except Exception as rollback_err:
                    _LOGGER.error(
                        "error while rolling back new benchmark: {}".format(
                            rollback_err
                        )
                    )

        raise err

    JobWorkerManager().refresh()

    response_body = data_dump_and_validation(
        ResponseProjectBenchmarkSchema(), {"benchmark": benchmark}
    )
    _LOGGER.info("created benchmark and job: {}".format(response_body))

    return jsonify(response_body), HTTPStatus.OK.value
def create_loss_profile(project_id: str):
    """
    Route that creates a loss profile record for a project together with the
    job that will generate it.

    Request json is loaded through CreateProjectLossProfileSchema. If
    creating or linking the records fails, whatever was created is deleted
    again and the error is re-raised.

    :param project_id: the id of the project to create a loss profile for
    :return: a tuple containing (json response, http status code)
    :raises ValidationError: if the project has no model set
    :raises HTTPNotFoundError: expected from get_project_by_id when the
        project does not exist (helper not shown here)
    """
    _LOGGER.info(
        "creating loss profile for project {} for request json {}".format(
            project_id, request.json
        )
    )
    project = get_project_by_id(project_id)
    params = CreateProjectLossProfileSchema().load(request.get_json(force=True))

    model = project.model
    if model is None:
        raise ValidationError(
            (
                "A model has not been set for the project with id {}, "
                "project must set a model before running a loss profile."
            ).format(project_id)
        )

    loss_profile = None
    job = None

    try:
        loss_profile = ProjectLossProfile.create(
            project=project, source="generated", **params
        )
        worker_args = CreateLossProfileJobWorker.format_args(
            model_id=model.model_id,
            profile_id=loss_profile.profile_id,
            pruning_estimations=params["pruning_estimations"],
            pruning_estimation_type=params["pruning_estimation_type"],
            pruning_structure=params["pruning_structure"],
            quantized_estimations=params["quantized_estimations"],
        )
        job = Job.create(
            project_id=project_id,
            type_=CreateLossProfileJobWorker.get_type(),
            worker_args=worker_args,
        )
        loss_profile.job = job
        loss_profile.save()
    except Exception as err:
        _LOGGER.error(
            "error while creating new loss profile, rolling back: {}".format(err)
        )

        # best-effort rollback: profile first, then job; failures only logged
        for created in (loss_profile, job):
            if created:
                try:
                    created.delete_instance()
                except Exception as rollback_err:
                    _LOGGER.error(
                        "error while rolling back new loss profile: {}".format(
                            rollback_err
                        )
                    )

        raise err

    # call into JobWorkerManager to kick off job if it's not already running
    JobWorkerManager().refresh()

    response_body = data_dump_and_validation(
        ResponseProjectLossProfileSchema(), {"profile": loss_profile}
    )
    _LOGGER.info("created loss profile and job: {}".format(response_body))

    return jsonify(response_body), HTTPStatus.OK.value
def upload_benchmark(project_id: str):
    """
    Route for creating a new benchmark for a given project from uploaded data.
    Raises an HTTPNotFoundError if the project is not found in the database.

    :param project_id: the id of the project to create a benchmark for
    :return: a tuple containing (json response, http status code)
    :raises ValidationError: if the 'benchmark_file' upload is missing, cannot
        be read as a JSON object, or the project does not have a model set
    """
    _LOGGER.info("uploading benchmark for project {}".format(project_id))

    project = get_project_by_id(project_id)  # validate id

    if "benchmark_file" not in request.files:
        _LOGGER.error("missing uploaded file 'benchmark_file'")
        raise ValidationError("missing uploaded file 'benchmark_file'")

    # read benchmark file
    try:
        benchmark = json.load(request.files["benchmark_file"])  # type: Dict
    except Exception as err:
        _LOGGER.error(
            "error while reading uploaded benchmark file: {}".format(err)
        )
        raise ValidationError(
            "error while reading uploaded benchmark file: {}".format(err)
        )

    # valid JSON that is not an object (list, string, number, ...) cannot take
    # the field overrides below; reject it as a validation problem instead of
    # failing later with a TypeError
    if not isinstance(benchmark, dict):
        _LOGGER.error("uploaded benchmark file must contain a JSON object")
        raise ValidationError(
            "uploaded benchmark file must contain a JSON object"
        )

    # override any previous values; ids and created are placeholders that only
    # exist for the schema pass and are removed again before the DB insert
    benchmark["benchmark_id"] = "<none>"
    benchmark["project_id"] = "<none>"
    benchmark["source"] = "uploaded"
    benchmark["job"] = None
    benchmark["created"] = datetime.datetime.now()

    benchmark = data_dump_and_validation(ProjectBenchmarkSchema(), benchmark)
    del benchmark["benchmark_id"]
    del benchmark["project_id"]
    del benchmark["created"]

    model = project.model
    if model is None:
        raise ValidationError(
            (
                "A model has not been set for the project with id {}, "
                "project must set a model before running a benchmark."
            ).format(project_id)
        )

    benchmark_model = None

    try:
        benchmark_model = ProjectBenchmark.create(project=project, **benchmark)
        resp_benchmark = data_dump_and_validation(
            ResponseProjectBenchmarkSchema(), {"benchmark": benchmark_model}
        )
    except Exception as err:
        _LOGGER.error(
            "error while creating new benchmark, rolling back: {}".format(err)
        )

        # only roll back when the row was actually created; if create() itself
        # failed there is nothing to delete
        if benchmark_model is not None:
            try:
                benchmark_model.delete_instance()
            except Exception as rollback_err:
                _LOGGER.error(
                    "error while rolling back new benchmark: {}".format(
                        rollback_err
                    )
                )

        raise

    _LOGGER.info(
        "created benchmark: id: {}, name: {}".format(
            resp_benchmark["benchmark"]["benchmark_id"],
            resp_benchmark["benchmark"]["name"],
        )
    )

    return jsonify(resp_benchmark), HTTPStatus.OK.value
def optim_lr_sched_default_mods(
    training_init_lr: Union[float, None],
    training_final_lr: Union[float, None],
    start_epoch: Union[float, None],
    start_fine_tuning_epoch: Union[float, None],
    end_epoch: Union[float, None],
) -> List[Dict[str, Any]]:
    """
    Default modifiers for an LR schedule for pruning a model.
    If training_init_lr and start_epoch are set, adds a set LR modifier.
    If, in addition, training_final_lr, start_fine_tuning_epoch and end_epoch
    are set, adds a step LR modifier after it.

    :param training_init_lr: the initial LR for training
    :param training_final_lr: the final LR for training
    :param start_epoch: the epoch training should start at
    :param start_fine_tuning_epoch: the epoch fine tuning should start at
    :param end_epoch: the final epoch for training
    :return: the default modifiers for an LR schedule
    """
    # without an initial LR and a start epoch there is nothing to schedule
    if training_init_lr is None or start_epoch is None:
        return []

    # LR held during pruning: midpoint of init and final when a (non-zero)
    # final LR is given, otherwise just the initial LR
    if training_final_lr:
        pruning_lr = (training_init_lr + training_final_lr) / 2.0
    else:
        pruning_lr = training_init_lr

    set_lr_mod = data_dump_and_validation(
        ProjectOptimizationModifierLRSchema(),
        {
            "clazz": "set",
            "start_epoch": start_epoch,
            "end_epoch": -1.0,
            "init_lr": pruning_lr,
            "args": {},
        },
    )
    modifiers = [set_lr_mod]

    if (
        training_final_lr is None
        or start_fine_tuning_epoch is None
        or end_epoch is None
    ):
        return modifiers

    gamma = 0.25
    fine_tuning_epochs = end_epoch - start_fine_tuning_epoch
    init_lr = pruning_lr * gamma
    target_final_lr = training_final_lr * 0.1
    # final_lr = init_lr * gamma ^ n : solve for n
    num_steps = math.log(target_final_lr / init_lr) / math.log(gamma)
    step_size = math.floor((fine_tuning_epochs - 1.0) / num_steps)

    step_lr_mod = data_dump_and_validation(
        ProjectOptimizationModifierLRSchema(),
        {
            "clazz": "step",
            "start_epoch": start_fine_tuning_epoch,
            "end_epoch": -1.0,
            "init_lr": init_lr,
            "args": {"step_size": step_size, "gamma": gamma},
        },
    )
    modifiers.append(step_lr_mod)

    return modifiers
def upload_data(project_id: str):
    """
    Route for uploading a data file to a project.
    The upload is written to a temporary file, validated against the
    project's model, copied into the project data's file path and validated
    again. On any failure the copied file and the new DB row are removed
    before the error is re-raised.
    Raises an HTTPNotFoundError if the project is not found in the database.

    :param project_id: the id of the project to upload the data for
    :return: a tuple containing (json response, http status code)
    :raises ValidationError: if the 'data_file' upload is missing
    """
    _LOGGER.info("uploading data file for project {}".format(project_id))

    project = get_project_by_id(project_id)
    project_model = get_project_model_by_project_id(project_id)

    if "data_file" not in request.files:
        _LOGGER.error("missing uploaded file 'data_file'")
        raise ValidationError("missing uploaded file 'data_file'")

    data_file = request.files["data_file"]

    # must exist before the try block: the rollback handler reads it, and a
    # failure before ProjectData.create would otherwise hit an unbound name
    # and mask the real error
    project_data = None

    with NamedTemporaryFile() as temp:
        data_path = gettempdir()
        tempname = os.path.join(data_path, temp.name)
        data_file.save(tempname)

        try:
            _LOGGER.info(project_model.file_path)

            # NOTE(review): this globs every entry in the temp directory, not
            # only the uploaded file; confirm that is intended
            validate_model_data(
                os.path.join(data_path, "*"), project_model.file_path
            )

            data = CreateUpdateProjectDataSchema().dump(
                {"source": "uploaded", "job": None}
            )
            project_data = ProjectData.create(project=project, **data)
            # name the stored file after the row's id rather than relying on
            # the model's string representation
            project_data.file = "{}.npz".format(project_data.data_id)
            project_data.setup_filesystem()
            shutil.copy(tempname, project_data.file_path)
            project_data.validate_filesystem()
            # validate again against the copy that will actually be kept
            validate_model_data(project_data.file_path, project_model.file_path)
            project_data.save()
        except Exception as err:
            if project_data:
                # best-effort cleanup of the copied file; it may not exist yet
                try:
                    os.remove(project_data.file_path)
                except OSError:
                    pass

                try:
                    project_data.delete_instance()
                except Exception as rollback_err:
                    _LOGGER.error(
                        "error while rolling back new data: {}".format(
                            rollback_err
                        )
                    )

            _LOGGER.error(
                "error while creating new project data, rolling back: {}".format(
                    err
                )
            )
            raise

    resp_data = data_dump_and_validation(
        ResponseProjectDataSingleSchema(), {"data": project_data}
    )
    _LOGGER.info("created project data {}".format(resp_data))

    return jsonify(resp_data), HTTPStatus.OK.value
def load_data_from_repo(project_id: str):
    """
    Route for loading data file(s) for a project from the Neural Magic model repo.
    Creates the project data entry and a background job in the JobWorker setup
    to run the load. The state of the job can be checked after.
    Raises an HTTPNotFoundError if the project is not found in the database.

    :param project_id: the id of the project to load the data for
    :return: a tuple containing (json response, http status code)
    """
    _LOGGER.info(
        "loading data from repo for project {} for request json {}".format(
            project_id, request.json
        )
    )
    project = get_project_by_id(project_id)
    # result is unused; presumably called so a project without a model is
    # rejected before anything is created
    get_project_model_by_project_id(project_id)
    repo_args = SetProjectDataFromSchema().load(request.get_json(force=True))

    project_data = None
    job = None

    try:
        project_data = ProjectData.create(
            project=project, source="downloaded_path", job=None
        )
        job_type = DataFromRepoJobWorker.get_type()
        job_args = DataFromRepoJobWorker.format_args(
            data_id=project_data.data_id, uri=repo_args["uri"]
        )
        job = Job.create(
            project_id=project.project_id, type_=job_type, worker_args=job_args
        )
        # link the data entry to its job, then prepare its place on disk
        project_data.job = job
        project_data.save()
        project_data.setup_filesystem()
        project_data.validate_filesystem()
    except Exception as err:
        # best-effort rollback: file on disk, then the data row, then the job
        if project_data:
            try:
                os.remove(project_data.file_path)
            except OSError:
                pass

            try:
                project_data.delete_instance()
            except Exception as rollback_err:
                _LOGGER.error(
                    "error while rolling back new data: {}".format(rollback_err)
                )

        if job:
            try:
                job.delete_instance()
            except Exception as rollback_err:
                _LOGGER.error(
                    "error while rolling back new data: {}".format(rollback_err)
                )

        _LOGGER.error(
            "error while creating new project data, rolling back: {}".format(err)
        )
        raise err

    # call into JobWorkerManager to kick off job if it's not already running
    JobWorkerManager().refresh()

    resp_data = data_dump_and_validation(
        ResponseProjectDataSingleSchema(), {"data": project_data}
    )
    _LOGGER.info("created project data from path {}".format(resp_data))

    return jsonify(resp_data), HTTPStatus.OK.value