def _restore_run():
    """Handle a RestoreRun request.

    Parses the incoming ``RestoreRun`` message, asks the tracking store to
    restore the run identified by ``run_id``, and returns an empty
    ``RestoreRun.Response`` serialized as JSON.
    """
    req = _get_request_message(RestoreRun())
    _get_tracking_store().restore_run(req.run_id)
    payload = message_to_json(RestoreRun.Response())
    http_response = Response(mimetype='application/json')
    http_response.set_data(payload)
    return http_response
def _restore_experiment():
    """Handle a RestoreExperiment request.

    Parses the incoming ``RestoreExperiment`` message, asks the tracking store
    to restore the experiment identified by ``experiment_id``, and returns an
    empty ``RestoreExperiment.Response`` serialized as JSON.
    """
    req = _get_request_message(RestoreExperiment())
    store = _get_tracking_store()
    store.restore_experiment(req.experiment_id)
    payload = message_to_json(RestoreExperiment.Response())
    http_response = Response(mimetype='application/json')
    http_response.set_data(payload)
    return http_response
def search_registered_models(self, filter_string=None, max_results=None, order_by=None,
                             page_token=None):
    """
    Search the backend for registered models matching the given criteria.

    :param filter_string: Filter query string; when omitted, all registered models are searched.
    :param max_results: Upper bound on the number of registered models returned.
    :param order_by: List of column names, each optionally annotated with ASC|DESC, used to
                     order the matching results.
    :param page_token: Token identifying the next page of results, as obtained from a previous
                       ``search_registered_models`` call.
    :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
             matching the search expressions. The token for the next page is available via
             the ``token`` attribute of the returned object.
    """
    request_proto = SearchRegisteredModels(filter=filter_string,
                                           max_results=max_results,
                                           order_by=order_by,
                                           page_token=page_token)
    response_proto = self._call_endpoint(SearchRegisteredModels, message_to_json(request_proto))
    models = list(map(RegisteredModel.from_proto, response_proto.registered_models))
    return PagedList(models, response_proto.next_page_token)
def _delete_tag():
    """Handle a DeleteTag request.

    Removes the tag ``key`` from the run ``run_id`` via the tracking store and
    returns an empty ``DeleteTag.Response`` serialized as JSON.
    """
    req = _get_request_message(DeleteTag())
    _get_tracking_store().delete_tag(req.run_id, req.key)
    payload = message_to_json(DeleteTag.Response())
    http_response = Response(mimetype='application/json')
    http_response.set_data(payload)
    return http_response
def delete_tag(self, run_id, key):
    """
    Remove a tag from a run. This is irreversible.

    :param run_id: String ID of the run.
    :param key: Name of the tag to delete.
    """
    request_proto = DeleteTag(run_id=run_id, key=key)
    self._call_endpoint(DeleteTag, message_to_json(request_proto))
def _get_run():
    """Handle a GetRun request.

    The run is looked up by ``run_id``, falling back to ``run_uuid`` when
    ``run_id`` is falsy (e.g. empty). The run's proto is merged into a
    ``GetRun.Response``, which is returned serialized as JSON.
    """
    req = _get_request_message(GetRun())
    reply = GetRun.Response()
    target_run_id = req.run_id or req.run_uuid
    run_proto = _get_tracking_store().get_run(target_run_id).to_proto()
    reply.run.MergeFrom(run_proto)
    http_response = Response(mimetype='application/json')
    http_response.set_data(message_to_json(reply))
    return http_response
def test_message_to_json():
    """An Experiment proto serializes to JSON with snake_case field names."""
    proto = Experiment("123", "name", "arty", 'active').to_proto()
    expected = {
        "experiment_id": "123",
        "name": "name",
        "artifact_location": "arty",
        "lifecycle_stage": 'active',
    }
    assert json.loads(message_to_json(proto)) == expected
def _set_experiment_tag():
    """Handle a SetExperimentTag request.

    Builds an ``ExperimentTag`` from the request's ``key``/``value``, stores it
    on the experiment ``experiment_id``, and returns an empty
    ``SetExperimentTag.Response`` serialized as JSON.
    """
    req = _get_request_message(SetExperimentTag())
    new_tag = ExperimentTag(req.key, req.value)
    _get_tracking_store().set_experiment_tag(req.experiment_id, new_tag)
    payload = message_to_json(SetExperimentTag.Response())
    http_response = Response(mimetype='application/json')
    http_response.set_data(payload)
    return http_response
def _log_param():
    """Handle a LogParam request.

    Builds a ``Param`` from the request's ``key``/``value`` and logs it to the
    run identified by ``run_id`` (or ``run_uuid`` when ``run_id`` is falsy).
    Returns an empty ``LogParam.Response`` serialized as JSON.
    """
    req = _get_request_message(LogParam())
    new_param = Param(req.key, req.value)
    target_run_id = req.run_id or req.run_uuid
    _get_tracking_store().log_param(target_run_id, new_param)
    payload = message_to_json(LogParam.Response())
    http_response = Response(mimetype='application/json')
    http_response.set_data(payload)
    return http_response
def _update_run():
    """Handle an UpdateRun request.

    Passes the request's ``status`` and ``end_time`` to the tracking store for
    the run identified by ``run_id`` (or ``run_uuid`` when ``run_id`` is falsy).
    The run info returned by the store is placed in an ``UpdateRun.Response``,
    which is returned serialized as JSON.
    """
    req = _get_request_message(UpdateRun())
    target_run_id = req.run_id or req.run_uuid
    store = _get_tracking_store()
    run_info = store.update_run_info(target_run_id, req.status, req.end_time)
    reply = UpdateRun.Response(run_info=run_info.to_proto())
    http_response = Response(mimetype='application/json')
    http_response.set_data(message_to_json(reply))
    return http_response
def _update_experiment():
    """Handle an UpdateExperiment request.

    The experiment is renamed only when the request carries a truthy
    ``new_name``; otherwise the tracking store is not called. An empty
    ``UpdateExperiment.Response`` serialized as JSON is always returned.
    """
    req = _get_request_message(UpdateExperiment())
    requested_name = req.new_name
    if requested_name:
        _get_tracking_store().rename_experiment(req.experiment_id, requested_name)
    payload = message_to_json(UpdateExperiment.Response())
    http_response = Response(mimetype='application/json')
    http_response.set_data(payload)
    return http_response
def _set_tag():
    """Handle a SetTag request.

    Builds a ``RunTag`` from the request's ``key``/``value`` and sets it on the
    run identified by ``run_id`` (or ``run_uuid`` when ``run_id`` is falsy).
    Returns an empty ``SetTag.Response`` serialized as JSON.
    """
    req = _get_request_message(SetTag())
    new_tag = RunTag(req.key, req.value)
    target_run_id = req.run_id or req.run_uuid
    _get_tracking_store().set_tag(target_run_id, new_tag)
    payload = message_to_json(SetTag.Response())
    http_response = Response(mimetype='application/json')
    http_response.set_data(payload)
    return http_response
def update_run_info(self, run_id, run_status, end_time):
    """
    Update the status and end time of the specified run.

    :param run_id: String ID of the run.
    :param run_status: New run status.
    :param end_time: New end time for the run.
    :return: The updated ``RunInfo`` parsed from the endpoint response.
    """
    # Both ``run_uuid`` and ``run_id`` carry the same ID -- presumably so that
    # servers reading either field see it; TODO confirm against the server.
    request_proto = UpdateRun(run_uuid=run_id, run_id=run_id,
                              status=run_status, end_time=end_time)
    response_proto = self._call_endpoint(UpdateRun, message_to_json(request_proto))
    return RunInfo.from_proto(response_proto.run_info)
def _create_experiment():
    """Handle a CreateExperiment request.

    Creates an experiment from the request's ``name`` and
    ``artifact_location`` and returns a ``CreateExperiment.Response`` carrying
    the new ``experiment_id``, serialized as JSON.
    """
    req = _get_request_message(CreateExperiment())
    new_id = _get_tracking_store().create_experiment(req.name, req.artifact_location)
    reply = CreateExperiment.Response()
    reply.experiment_id = new_id
    http_response = Response(mimetype='application/json')
    http_response.set_data(message_to_json(reply))
    return http_response
def _get_experiment():
    """Handle a GetExperiment request.

    Fetches the experiment ``experiment_id`` from the tracking store, merges
    its proto into a ``GetExperiment.Response``, and returns it serialized as
    JSON.
    """
    req = _get_request_message(GetExperiment())
    reply = GetExperiment.Response()
    store = _get_tracking_store()
    experiment_proto = store.get_experiment(req.experiment_id).to_proto()
    reply.experiment.MergeFrom(experiment_proto)
    http_response = Response(mimetype='application/json')
    http_response.set_data(message_to_json(reply))
    return http_response
def _log_metric():
    """Handle a LogMetric request.

    Builds a ``Metric`` from the request's key, value, timestamp and step and
    logs it to the run identified by ``run_id`` (or ``run_uuid`` when
    ``run_id`` is falsy). Returns an empty ``LogMetric.Response`` as JSON.
    """
    req = _get_request_message(LogMetric())
    new_metric = Metric(req.key, req.value, req.timestamp, req.step)
    target_run_id = req.run_id or req.run_uuid
    _get_tracking_store().log_metric(target_run_id, new_metric)
    payload = message_to_json(LogMetric.Response())
    http_response = Response(mimetype='application/json')
    http_response.set_data(payload)
    return http_response
def _get_metric_history():
    """Handle a GetMetricHistory request.

    Retrieves every logged value of ``metric_key`` for the run identified by
    ``run_id`` (or ``run_uuid`` when ``run_id`` is falsy) and returns them in a
    ``GetMetricHistory.Response`` serialized as JSON.
    """
    req = _get_request_message(GetMetricHistory())
    reply = GetMetricHistory.Response()
    target_run_id = req.run_id or req.run_uuid
    history = _get_tracking_store().get_metric_history(target_run_id, req.metric_key)
    reply.metrics.extend([metric.to_proto() for metric in history])
    http_response = Response(mimetype='application/json')
    http_response.set_data(message_to_json(reply))
    return http_response
def _list_experiments():
    """Handle a ListExperiments request.

    Lists experiments filtered by the request's ``view_type`` and returns
    their protos in a ``ListExperiments.Response`` serialized as JSON.
    """
    req = _get_request_message(ListExperiments())
    experiments = _get_tracking_store().list_experiments(req.view_type)
    reply = ListExperiments.Response()
    experiment_protos = [experiment.to_proto() for experiment in experiments]
    reply.experiments.extend(experiment_protos)
    http_response = Response(mimetype='application/json')
    http_response.set_data(message_to_json(reply))
    return http_response
def get_registered_model(self, name):
    """
    Fetch a registered model by name.

    :param name: Registered model name.
    :return: A single :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    """
    request_proto = GetRegisteredModel(name=name)
    response_proto = self._call_endpoint(GetRegisteredModel, message_to_json(request_proto))
    return RegisteredModel.from_proto(response_proto.registered_model)
def list_experiments(self, view_type=ViewType.ACTIVE_ONLY):
    """
    List experiments known to the backend.

    :param view_type: Which experiments to include; sent as the request's
                      ``view_type``. Defaults to ``ViewType.ACTIVE_ONLY``.
    :return: a list of all known Experiment objects
    """
    request_json = message_to_json(ListExperiments(view_type=view_type))
    response_proto = self._call_endpoint(ListExperiments, request_json)
    return list(map(Experiment.from_proto, response_proto.experiments))
def delete_registered_model(self, name):
    """
    Delete a registered model. The backend raises an exception if no
    registered model with the given name exists.

    :param name: Registered model name.
    :return: None
    """
    request_proto = DeleteRegisteredModel(name=name)
    self._call_endpoint(DeleteRegisteredModel, message_to_json(request_proto))
def log_batch(self, run_id, metrics, params, tags):
    """
    Log metrics, params and tags for a run in a single LogBatch request.

    :param run_id: String ID of the run.
    :param metrics: Iterable of metric entities; each is converted with ``to_proto()``.
    :param params: Iterable of param entities; each is converted with ``to_proto()``.
    :param tags: Iterable of tag entities; each is converted with ``to_proto()``.
    """
    batch = LogBatch(
        metrics=[m.to_proto() for m in metrics],
        params=[p.to_proto() for p in params],
        tags=[t.to_proto() for t in tags],
        run_id=run_id,
    )
    self._call_endpoint(LogBatch, message_to_json(batch))
def get_run(self, run_id):
    """
    Fetch a run from the backend store.

    :param run_id: Unique identifier for the run.
    :return: A single Run object if it exists; otherwise an exception is raised.
    """
    # The same ID is sent in both ``run_uuid`` and ``run_id`` -- presumably for
    # compatibility with servers reading either field; TODO confirm.
    request_proto = GetRun(run_uuid=run_id, run_id=run_id)
    response_proto = self._call_endpoint(GetRun, message_to_json(request_proto))
    return Run.from_proto(response_proto.run)
def set_registered_model_tag(self, name, tag):
    """
    Attach a tag to a registered model.

    :param name: Registered model name.
    :param tag: :py:class:`mlflow.entities.model_registry.RegisteredModelTag` instance to log;
                its ``key`` and ``value`` are sent in the request.
    :return: None
    """
    request_proto = SetRegisteredModelTag(name=name, key=tag.key, value=tag.value)
    self._call_endpoint(SetRegisteredModelTag, message_to_json(request_proto))
def delete_registered_model_tag(self, name, key):
    """
    Remove a tag from a registered model.

    :param name: Registered model name.
    :param key: Registered model tag key.
    :return: None
    """
    request_proto = DeleteRegisteredModelTag(name=name, key=key)
    self._call_endpoint(DeleteRegisteredModelTag, message_to_json(request_proto))
def delete_model_version(self, name, version):
    """
    Delete a model version in the backend.

    :param name: Registered model name.
    :param version: Registered model version; converted with ``str()`` before sending.
    :return: None
    """
    request_proto = DeleteModelVersion(name=name, version=str(version))
    self._call_endpoint(DeleteModelVersion, message_to_json(request_proto))
def set_experiment_tag(self, experiment_id, tag):
    """
    Set a tag on the specified experiment.

    :param experiment_id: String ID of the experiment.
    :param tag: Tag object exposing ``key`` and ``value`` (e.g. an ``ExperimentTag``);
                both attributes are sent in the request.
    """
    request_proto = SetExperimentTag(experiment_id=experiment_id,
                                     key=tag.key,
                                     value=tag.value)
    self._call_endpoint(SetExperimentTag, message_to_json(request_proto))
def rename_registered_model(self, name, new_name):
    """
    Rename a registered model.

    :param name: Current registered model name.
    :param new_name: New proposed name.
    :return: A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    """
    request_proto = RenameRegisteredModel(name=name, new_name=new_name)
    response_proto = self._call_endpoint(RenameRegisteredModel, message_to_json(request_proto))
    return RegisteredModel.from_proto(response_proto.registered_model)
def update_registered_model(self, name, description):
    """
    Replace the description of a registered model.

    :param name: Registered model name.
    :param description: New description.
    :return: A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    """
    request_proto = UpdateRegisteredModel(name=name, description=description)
    response_proto = self._call_endpoint(UpdateRegisteredModel, message_to_json(request_proto))
    return RegisteredModel.from_proto(response_proto.registered_model)
def delete_model_version_tag(self, name, version, key):
    """
    Delete a tag associated with the model version.

    :param name: Registered model name.
    :param version: Registered model version. Converted with ``str()`` before being
                    placed in the request, as ``delete_model_version`` does, so
                    integer versions are accepted as well as strings.
    :param key: Tag key.
    :return: None
    """
    # ``delete_model_version`` fills the same ``version`` field with
    # ``str(version)``, which suggests the proto field is a string; coercing
    # here keeps int arguments from being rejected by the proto constructor.
    req_body = message_to_json(
        DeleteModelVersionTag(name=name, version=str(version), key=key))
    self._call_endpoint(DeleteModelVersionTag, req_body)