def create_run(self, experiment_id, user_id, start_time, tags):
    """
    Create and persist a new run in RUNNING status under an experiment.

    The experiment is fetched and checked for being active before anything is
    written. The run gets a fresh hex UUID, and its artifact URI is the
    experiment's artifact location joined with that UUID and
    ``SqlAlchemyStore.ARTIFACTS_FOLDER_NAME``.

    :param experiment_id: ID of the experiment that will own the run.
    :param user_id: User recorded on the run.
    :param start_time: Start timestamp stored on the run as-is.
    :param tags: Iterable of objects exposing ``key`` and ``value``. If a key
                 appears more than once, the last value wins.
    :return: The persisted run converted via ``to_mlflow_entity()``.
    """
    with self.ManagedSessionMaker() as session:
        parent_experiment = self.get_experiment(experiment_id)
        # NOTE(review): presumably raises if the experiment is deleted/inactive — confirm.
        self._check_experiment_is_active(parent_experiment)

        new_run_uuid = uuid.uuid4().hex
        run_artifact_uri = append_to_uri_path(
            parent_experiment.artifact_location,
            new_run_uuid,
            SqlAlchemyStore.ARTIFACTS_FOLDER_NAME,
        )

        sql_run = SqlRun(
            name="",
            artifact_uri=run_artifact_uri,
            run_uuid=new_run_uuid,
            experiment_id=experiment_id,
            source_type=SourceType.to_string(SourceType.UNKNOWN),
            source_name="",
            entry_point_name="",
            user_id=user_id,
            status=RunStatus.to_string(RunStatus.RUNNING),
            start_time=start_time,
            end_time=None,
            source_version="",
            lifecycle_stage=LifecycleStage.ACTIVE,
        )

        # Collapse duplicate keys so only one SqlTag per key is created;
        # later occurrences overwrite earlier values.
        latest_value_by_key = {tag.key: tag.value for tag in tags}
        sql_run.tags = [
            SqlTag(key=tag_key, value=tag_value)
            for tag_key, tag_value in latest_value_by_key.items()
        ]

        self._save_to_db(objs=sql_run, session=session)
        return sql_run.to_mlflow_entity()
def set_tag(self, run_id, tag):
    """
    Attach a tag to an existing, active run.

    The tag's key/value pair is validated with ``_validate_tag``. The run is
    then loaded and checked for being active, and a ``SqlTag`` for it is
    merged into the session.

    :param run_id: String ID of the run to tag.
    :param tag: RunTag instance whose ``key`` and ``value`` are stored.
    """
    with self.ManagedSessionMaker() as db_session:
        tag_key, tag_value = tag.key, tag.value
        _validate_tag(tag_key, tag_value)
        target_run = self._get_run(run_uuid=run_id, session=db_session)
        self._check_run_is_active(target_run)
        # merge() rather than add(): presumably acts as an upsert when a tag with
        # the same primary key (likely run_uuid + key) already exists — confirm.
        db_session.merge(SqlTag(run_uuid=run_id, key=tag_key, value=tag_value))
def record_logged_model(self, run_id, mlflow_model):
    """
    Append a model's metadata to the run's ``MLFLOW_LOGGED_MODELS`` tag.

    The tag holds a JSON-encoded list of model dicts. If the run already has
    the tag, its decoded value is extended with the new entry; otherwise a new
    one-element list is written. The resulting value is validated with
    ``_validate_tag`` before being merged into the session.

    :param run_id: String ID of the (active) run to update.
    :param mlflow_model: ``mlflow.models.Model`` instance; its ``to_dict()``
                         output is the appended entry.
    :raises TypeError: If ``mlflow_model`` is not a ``Model`` instance.
    """
    if not isinstance(mlflow_model, Model):
        type_error_msg = "Argument 'mlflow_model' should be mlflow.models.Model, got '{}'".format(
            type(mlflow_model)
        )
        raise TypeError(type_error_msg)

    new_entry = mlflow_model.to_dict()
    with self.ManagedSessionMaker() as db_session:
        target_run = self._get_run(run_uuid=run_id, session=db_session)
        self._check_run_is_active(target_run)

        # First existing tag with this key, if any.
        existing_tag = next(
            (candidate for candidate in target_run.tags
             if candidate.key == MLFLOW_LOGGED_MODELS),
            None,
        )
        logged_models = [] if existing_tag is None else json.loads(existing_tag.value)
        serialized = json.dumps(logged_models + [new_entry])

        # Validate the final serialized value (e.g. it may grow past tag limits).
        _validate_tag(MLFLOW_LOGGED_MODELS, serialized)
        db_session.merge(
            SqlTag(key=MLFLOW_LOGGED_MODELS, value=serialized, run_uuid=run_id)
        )