def set_terminated(self, run_id, status=None, end_time=None):
    """Set a run's status to terminated.

    :param run_id: ID of the run to terminate.
    :param status: A string value of :py:class:`mlflow.entities.RunStatus`.
                   Defaults to "FINISHED".
    :param end_time: Value recorded as the run's end time. If not provided,
                     defaults to the current time, computed as
                     ``int(time.time() * 1000)``.
    """
    # Compare against None instead of relying on truthiness so that an
    # explicit end_time of 0 is stored rather than replaced by "now".
    if end_time is None:
        end_time = int(time.time() * 1000)
    # Any falsy status (None or an empty string) still selects the default.
    if not status:
        status = RunStatus.to_string(RunStatus.FINISHED)
    self.store.update_run_info(
        run_id, run_status=RunStatus.from_string(status), end_time=end_time
    )
def update_run_info(self, run_uuid, run_status, end_time):
    """Store a new status and end time on a run and return its updated info.

    The run is fetched with ``_get_run`` and passed through
    ``_check_run_is_active`` before any field is modified.

    :param run_uuid: Identifier of the run to update.
    :param run_status: Status value accepted by ``RunStatus.to_string``.
    :param end_time: Value to store as the run's end time.
    :return: The ``info`` attribute of the run's mlflow entity.
    """
    run_record = self._get_run(run_uuid)
    self._check_run_is_active(run_record)

    status_string = RunStatus.to_string(run_status)
    run_record.status = status_string
    run_record.end_time = end_time
    self._save_to_db(run_record)

    return run_record.to_mlflow_entity().info
def test_start_existing_run_end_time(empty_active_run_stack):
    """Restarting and ending an already finished run must yield a later end time."""
    run_id = mlflow.start_run().info.run_id

    def _stored_info():
        # Read the run back through a fresh client on every call.
        return MlflowClient().get_run(run_id).info

    mlflow.end_run()
    first_info = _stored_info()
    first_end_time = first_info.end_time
    assert first_info.status == RunStatus.to_string(RunStatus.FINISHED)

    # Resume the same run and end it a second time.
    mlflow.start_run(run_id)
    mlflow.end_run()
    assert _stored_info().end_time > first_end_time
def test_on_pipeline_error(tmp_path, config_dir, mocker):
    """A failing pipeline must leave no active mlflow run and a FAILED run status."""
    # config_dir is a global fixture in conftest that emulates
    # the root of a Kedro project

    # Disable logging.config.dictConfig in KedroContext._setup_logging as
    # it changes logging.config and affects other unit tests
    mocker.patch("logging.config.dictConfig")
    mocker.patch("kedro_mlflow.utils._is_kedro_project", return_value=True)

    # create the extra mlflow.yml config file for the plugin
    tracking_uri = (tmp_path / "mlruns").as_uri()
    mlflow_yml = tmp_path / "conf" / "base" / "mlflow.yml"
    mlflow_yml.parent.mkdir(parents=True, exist_ok=True)
    mlflow_yml.write_text(yaml.dump(dict(mlflow_tracking_uri=tracking_uri)))

    def failing_node():
        mlflow.start_run(nested=True)
        raise ValueError("Let's make this pipeline fail")

    class DummyContextWithHook(KedroContext):
        project_name = "fake project"
        package_name = "fake_project"
        project_version = "0.16.0"
        hooks = (MlflowPipelineHook(),)

        def _get_pipelines(self):
            only_node = node(
                func=failing_node,
                inputs=None,
                outputs="fake_output",
            )
            return {"__default__": Pipeline([only_node])}

    with pytest.raises(ValueError):
        failing_context = DummyContextWithHook(tmp_path.as_posix())
        failing_context.run()

    # the run we want is the last one in Default experiment
    failing_run_info = MlflowClient(tracking_uri).list_run_infos("0")[0]

    # the run must have been closed
    assert mlflow.active_run() is None
    # it must be marked as failed
    expected_status = RunStatus.to_string(RunStatus.FAILED)
    assert failing_run_info.status == expected_status
def update_run_info(self, run_id, run_status, end_time):
    """Store a new status and end time on a run and return its updated info.

    All work, including the conversion to an mlflow entity, happens inside a
    single managed session.

    :param run_id: Identifier of the run to update.
    :param run_status: Status value accepted by ``RunStatus.to_string``.
    :param end_time: Value to store as the run's end time.
    :return: The ``info`` attribute of the run's mlflow entity.
    """
    with self.ManagedSessionMaker() as session:
        run_record = self._get_run(run_uuid=run_id, session=session)
        self._check_run_is_active(run_record)

        status_string = RunStatus.to_string(run_status)
        run_record.status = status_string
        run_record.end_time = end_time
        self._save_to_db(objs=run_record, session=session)

        return run_record.to_mlflow_entity().info
def test_run_local_git_repo(
        patch_user,  # pylint: disable=unused-argument
        local_git_repo,
        local_git_repo_uri,
        tracking_uri_mock,  # pylint: disable=unused-argument
        use_start_run,
        version):
    """Run the test project from a local git repo and validate what was tracked."""
    if version is None:
        uri = os.path.join("%s/" % local_git_repo, TEST_PROJECT_NAME)
    else:
        uri = local_git_repo_uri + "#" + TEST_PROJECT_NAME
    if version == "git-commit":
        version = _get_version_local_git_repo(local_git_repo)

    submitted_run = mlflow.projects.run(
        uri,
        entry_point="test_tracking",
        version=version,
        parameters={"use_start_run": use_start_run},
        use_conda=False,
        experiment_id=FileStore.DEFAULT_EXPERIMENT_ID)

    # Blocking runs should be finished when they return
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
    # Test that we can call wait() on a synchronous run & that the run has the correct
    # status after calling wait().
    submitted_run.wait()
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)

    # Validate run contents in the FileStore
    run_id = submitted_run.run_id
    client = mlflow.tracking.MlflowClient()
    active_infos = client.list_run_infos(
        experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
        run_view_type=ViewType.ACTIVE_ONLY)
    assert len(active_infos) == 1
    assert run_id == active_infos[0].run_id

    tracked_run = client.get_run(run_id)
    assert tracked_run.info.status == RunStatus.to_string(RunStatus.FINISHED)
    assert tracked_run.data.params == {"use_start_run": use_start_run}
    assert tracked_run.data.metrics == {"some_key": 3}

    tags = tracked_run.data.tags
    assert tags[MLFLOW_USER] == MOCK_USER
    assert "file:" in tags[MLFLOW_SOURCE_NAME]
    assert tags[MLFLOW_SOURCE_TYPE] == SourceType.to_string(SourceType.PROJECT)
    assert tags[MLFLOW_PROJECT_ENTRY_POINT] == "test_tracking"
    assert tags[MLFLOW_PROJECT_BACKEND] == "local"

    # Git tags are only checked when the project was run from the master branch.
    if version == "master":
        assert tags[MLFLOW_GIT_BRANCH] == "master"
        assert tags[MLFLOW_GIT_REPO_URL] == local_git_repo_uri
        assert tags[LEGACY_MLFLOW_GIT_BRANCH_NAME] == "master"
        assert tags[LEGACY_MLFLOW_GIT_REPO_URL] == local_git_repo_uri
def test_list_run_infos(self):
    """Every listed run info must match the stored fixture data for that run."""
    store = FileStore(self.test_root)
    for experiment_id in self.experiments:
        for listed_info in store.list_run_infos(experiment_id, run_view_type=ViewType.ALL):
            expected = self.run_data[listed_info.run_id]
            # Run data entries are not part of a RunInfo, so drop them before comparing.
            for data_key in ("metrics", "params", "tags"):
                expected.pop(data_key)
            expected['lifecycle_stage'] = LifecycleStage.ACTIVE
            expected['status'] = RunStatus.to_string(expected['status'])
            self.assertEqual(expected, dict(listed_info))
def test_bad_comparators(entity_type, bad_comparators, key, entity_value):
    """Each unsupported comparator must make ``SearchUtils.filter`` raise."""
    info = RunInfo(
        run_uuid="hi",
        run_id="hi",
        experiment_id=0,
        user_id="user-id",
        status=RunStatus.to_string(RunStatus.FAILED),
        start_time=0,
        end_time=1,
        lifecycle_stage=LifecycleStage.ACTIVE)
    data = RunData(metrics=[], params=[], tags=[])
    candidates = [Run(run_info=info, run_data=data)]

    filter_template = "{entity_type}.{key} {comparator} {value}"
    for comparator in bad_comparators:
        bad_filter = filter_template.format(
            entity_type=entity_type,
            key=key,
            comparator=comparator,
            value=entity_value)
        with pytest.raises(MlflowException) as e:
            SearchUtils.filter(candidates, bad_filter)
        assert "Invalid comparator" in str(e.value.message)
def test_correct_filtering(filter_string, matching_runs):
    """``SearchUtils.filter`` must return exactly the runs at ``matching_runs``."""

    def _make_run(run_id, experiment_id, status, metric_value, param_value, tags):
        # All runs share the same user, timestamps and lifecycle stage.
        info = RunInfo(
            run_uuid=run_id,
            run_id=run_id,
            experiment_id=experiment_id,
            user_id="user-id",
            status=RunStatus.to_string(status),
            start_time=0,
            end_time=1,
            lifecycle_stage=LifecycleStage.ACTIVE)
        data = RunData(
            metrics=[Metric("key1", metric_value, 1, 0)],
            params=[Param("my_param", param_value)],
            tags=tags)
        return Run(run_info=info, run_data=data)

    runs = [
        _make_run("hi", 0, RunStatus.FAILED, 121, "A", []),
        _make_run("hi2", 0, RunStatus.FINISHED, 123, "A", [RunTag("tag1", "C")]),
        _make_run("hi3", 1, RunStatus.FAILED, 125, "B", [RunTag("tag1", "D")]),
    ]

    expected = {runs[index] for index in matching_runs}
    assert set(SearchUtils.filter(runs, filter_string)) == expected
def _read_persisted_run_info_dict(run_info_dict):
    """Build a ``RunInfo`` from a persisted run-info dictionary.

    The input dictionary is not modified; a copy is normalised and passed to
    ``RunInfo.from_dictionary``.
    """
    fields = run_info_dict.copy()
    fields.setdefault('lifecycle_stage', LifecycleStage.ACTIVE)

    # 'status' is stored as an integer enum in meta file, but RunInfo.status field is a string.
    # converting to string before hydrating RunInfo.
    # If 'status' value not recorded in files, mark it as 'RUNNING' (default)
    persisted_status = run_info_dict.get('status', RunStatus.RUNNING)
    fields['status'] = RunStatus.to_string(persisted_status)

    # 'experiment_id' was changed from int to string, so we must cast to string
    # when reading legacy run_infos
    experiment_id = fields["experiment_id"]
    if isinstance(experiment_id, int):
        fields["experiment_id"] = str(experiment_id)

    return RunInfo.from_dictionary(fields)
def _create():
    """Create a ``RunInfo`` from random values.

    :return: A tuple holding the ``RunInfo`` followed by the values used to
             build it: run id, experiment id, user id, status, start time,
             end time, lifecycle stage and artifact URI.
    """
    # The random helpers are called in a fixed order, one per field.
    run_id = str(uuid.uuid4())
    experiment_id = str(random_int(10, 2000))
    user_id = random_str(random_int(10, 25))
    status = RunStatus.to_string(random.choice(RunStatus.all_status()))
    start_time = random_int(1, 10)
    end_time = start_time + random_int(1, 10)
    lifecycle_stage = LifecycleStage.ACTIVE
    artifact_uri = random_str(random_int(10, 40))

    expected_values = (run_id, experiment_id, user_id, status, start_time,
                       end_time, lifecycle_stage, artifact_uri)
    run_info = RunInfo(
        run_uuid=run_id,
        run_id=run_id,
        experiment_id=experiment_id,
        user_id=user_id,
        status=status,
        start_time=start_time,
        end_time=end_time,
        lifecycle_stage=lifecycle_stage,
        artifact_uri=artifact_uri)
    return (run_info,) + expected_values
def _get_run_configs(self, name='test', experiment_id=None):
    """Return a dictionary of run field values for the given name and experiment.

    Each call produces a new ``run_uuid`` and reads the current time for
    ``start_time`` and ``end_time``.
    """
    return dict(
        experiment_id=experiment_id,
        name=name,
        user_id='Anderson',
        run_uuid=uuid.uuid4().hex,
        status=RunStatus.to_string(RunStatus.SCHEDULED),
        source_type=SourceType.to_string(SourceType.NOTEBOOK),
        source_name='Python application',
        entry_point_name='main.py',
        start_time=int(time.time()),
        end_time=int(time.time()),
        source_version=mlflow.__version__,
        lifecycle_stage=entities.LifecycleStage.ACTIVE,
        artifact_uri='//',
    )
def test_on_pipeline_error(kedro_project_with_mlflow_conf):
    """A failing pipeline must leave no active mlflow run and a FAILED run status."""
    project_path = kedro_project_with_mlflow_conf
    tracking_uri = (project_path / "mlruns").as_uri()

    metadata = _get_project_metadata(project_path)
    _add_src_to_path(metadata.source_dir, project_path)
    configure_project(metadata.package_name)

    with KedroSession.create(
        package_name=metadata.package_name,
        project_path=project_path,
    ):

        def failing_node():
            mlflow.start_run(nested=True)
            raise ValueError("Let's make this pipeline fail")

        class DummyContextWithHook(KedroContext):
            project_name = "fake project"
            package_name = "fake_project"
            project_version = "0.16.5"
            hooks = (MlflowPipelineHook(),)

            def _get_pipeline(self, name: str = None) -> Pipeline:
                only_node = node(
                    func=failing_node,
                    inputs=None,
                    outputs="fake_output",
                )
                return Pipeline([only_node])

        with pytest.raises(ValueError):
            failing_context = DummyContextWithHook(
                "fake_package", project_path.as_posix()
            )
            failing_context.run()

        # the run we want is the last one in Default experiment
        failing_run_info = MlflowClient(tracking_uri).list_run_infos("0")[0]

        # the run must have been closed
        assert mlflow.active_run() is None
        # it must be marked as failed
        expected_status = RunStatus.to_string(RunStatus.FAILED)
        assert failing_run_info.status == expected_status
def create_run(self, experiment_id, user_id, run_name, source_type, source_name,
               entry_point_name, start_time, source_version, tags, parent_run_id):
    """Create and persist a new run in status RUNNING under an active experiment.

    Tags are de-duplicated by key, with later entries overriding earlier ones;
    the parent run id and run name, when given, are stored as tags as well.

    :raises MlflowException: with ``INVALID_STATE`` if the experiment's
                             lifecycle stage is not active.
    :return: The mlflow entity of the newly saved run.
    """
    with self.ManagedSessionMaker() as session:
        experiment = self.get_experiment(experiment_id)
        if experiment.lifecycle_stage != LifecycleStage.ACTIVE:
            raise MlflowException(
                'Experiment id={} must be active'.format(experiment_id),
                INVALID_STATE)

        run_uuid = uuid.uuid4().hex
        artifact_location = build_path(
            experiment.artifact_location,
            run_uuid,
            SqlAlchemyStore.ARTIFACTS_FOLDER_NAME)
        new_run = SqlRun(
            name=run_name or "",
            artifact_uri=artifact_location,
            run_uuid=run_uuid,
            experiment_id=experiment_id,
            source_type=SourceType.to_string(source_type),
            source_name=source_name,
            entry_point_name=entry_point_name,
            user_id=user_id,
            status=RunStatus.to_string(RunStatus.RUNNING),
            start_time=start_time,
            end_time=None,
            source_version=source_version,
            lifecycle_stage=LifecycleStage.ACTIVE)

        # Collect tags by key first so that each key ends up with one value.
        tag_values = {tag.key: tag.value for tag in tags}
        if parent_run_id:
            tag_values[MLFLOW_PARENT_RUN_ID] = parent_run_id
        if run_name:
            tag_values[MLFLOW_RUN_NAME] = run_name
        new_run.tags = [
            SqlTag(key=tag_key, value=tag_value)
            for tag_key, tag_value in tag_values.items()
        ]

        self._save_to_db(objs=new_run, session=session)
        return new_run.to_mlflow_entity()
def test_filter_runs_by_start_time():
    """Filtering on ``attribute.start_time`` must honour ``>=``, ``>`` and ``=``."""

    def _run_started_at(start_time, run_id):
        info = RunInfo(
            run_uuid=run_id,
            run_id=run_id,
            experiment_id=0,
            user_id="user-id",
            status=RunStatus.to_string(RunStatus.FINISHED),
            start_time=start_time,
            end_time=1,
            lifecycle_stage=LifecycleStage.ACTIVE,
        )
        return Run(run_info=info, run_data=RunData())

    # Runs "a", "b" and "c" start at 0, 1 and 2 respectively.
    runs = [_run_started_at(idx, run_id) for idx, run_id in enumerate(["a", "b", "c"])]
    last_run_only = runs[2:]

    assert SearchUtils.filter(runs, "attribute.start_time >= 0") == runs
    assert SearchUtils.filter(runs, "attribute.start_time > 1") == last_run_only
    assert SearchUtils.filter(runs, "attribute.start_time = 2") == last_run_only
def create_run(self, experiment_id, user_id, start_time, tags):
    """
    Create a run in status RUNNING, persist its metadata and sub-directories,
    apply the given tags, and return the run as read back from the store.

    A ``None`` experiment ID selects ``FileStore.DEFAULT_EXPERIMENT_ID``.

    :raises MlflowException: if the experiment does not exist or is not active.
    """
    if experiment_id is None:
        experiment_id = FileStore.DEFAULT_EXPERIMENT_ID

    experiment = self.get_experiment(experiment_id)
    if experiment is None:
        raise MlflowException(
            "Could not create run under experiment with ID %s - no such experiment "
            "exists." % experiment_id,
            databricks_pb2.RESOURCE_DOES_NOT_EXIST,
        )
    if experiment.lifecycle_stage != LifecycleStage.ACTIVE:
        raise MlflowException(
            "Could not create run under non-active experiment with ID "
            "%s." % experiment_id,
            databricks_pb2.INVALID_STATE,
        )

    new_run_id = uuid.uuid4().hex
    run_info = RunInfo(
        run_uuid=new_run_id,
        run_id=new_run_id,
        experiment_id=experiment_id,
        artifact_uri=self._get_artifact_dir(experiment_id, new_run_id),
        user_id=user_id,
        status=RunStatus.to_string(RunStatus.RUNNING),
        start_time=start_time,
        end_time=None,
        lifecycle_stage=LifecycleStage.ACTIVE,
    )

    # Persist run metadata and create directories for logging metrics, parameters, artifacts
    run_dir = self._get_run_dir(run_info.experiment_id, run_info.run_id)
    mkdir(run_dir)
    write_yaml(run_dir, FileStore.META_DATA_FILE_NAME, _make_persisted_run_info_dict(run_info))
    for folder_name in (
        FileStore.METRICS_FOLDER_NAME,
        FileStore.PARAMS_FOLDER_NAME,
        FileStore.ARTIFACTS_FOLDER_NAME,
    ):
        mkdir(run_dir, folder_name)

    for tag in tags:
        self.set_tag(new_run_id, tag)
    return self.get_run(run_id=new_run_id)
def test_on_pipeline_error(kedro_project_with_mlflow_conf):
    """A failing session run must leave no active mlflow run and a FAILED run status."""
    project_path = kedro_project_with_mlflow_conf
    tracking_uri = (project_path / "mlruns").as_uri()

    bootstrap_project(project_path)
    with KedroSession.create(project_path=project_path) as session:
        context = session.load_context()
        with pytest.raises(ValueError):
            session.run()

        # the run we want is the last one in the configuration experiment
        mlflow_client = MlflowClient(tracking_uri)
        experiment_name = context.mlflow.tracking.experiment.name
        experiment = mlflow_client.get_experiment_by_name(experiment_name)
        run_infos = MlflowClient(tracking_uri).list_run_infos(experiment.experiment_id)
        failing_run_info = run_infos[0]

        # the run must have been closed
        assert mlflow.active_run() is None
        # it must be marked as failed
        expected_status = RunStatus.to_string(RunStatus.FAILED)
        assert failing_run_info.status == expected_status
def test_order_by_metric_with_nans_and_infs():
    """Sorting by a metric must order infinities correctly and put NaN last."""
    metric_vals_str = ["nan", "inf", "-inf", "-1000", "0", "1000"]

    def _run_with_metric(value_str):
        # The run id doubles as the textual form of the metric value.
        info = RunInfo(
            run_id=value_str,
            run_uuid=value_str,
            experiment_id=0,
            user_id="user",
            status=RunStatus.to_string(RunStatus.FINISHED),
            start_time=0,
            end_time=1,
            lifecycle_stage=LifecycleStage.ACTIVE)
        data = RunData(metrics=[Metric("x", float(value_str), 1, 0)])
        return Run(run_info=info, run_data=data)

    runs = [_run_with_metric(value_str) for value_str in metric_vals_str]

    def _sorted_ids(order_by):
        return [run.info.run_id for run in SearchUtils.sort(runs, [order_by])]

    sorted_runs_asc = _sorted_ids("metrics.x asc")
    sorted_runs_desc = _sorted_ids("metrics.x desc")

    # asc
    assert ["-inf", "-1000", "0", "1000", "inf", "nan"] == sorted_runs_asc
    # desc
    assert ["inf", "1000", "0", "-1000", "-inf", "nan"] == sorted_runs_desc
def create_run(self, experiment_id, user_id, start_time, tags):
    """
    Create a run under the specified experiment ID, setting the run's status to "RUNNING"
    and its start time to ``start_time``.

    :param experiment_id: String id of the experiment for this run
    :param user_id: ID of the user launching this run
    :param start_time: Value stored as the run's start time
    :param tags: Iterable of tags; each one is applied with ``set_tag`` after the run
                 info has been created
    :raises MlflowException: with ``RESOURCE_DOES_NOT_EXIST`` if the experiment is not
                             found, or ``INVALID_STATE`` if it is not active
    :return: The created Run object
    """
    experiment = self.get_experiment(experiment_id)
    if experiment is None:
        raise MlflowException(
            "Could not create run under experiment with ID %s - no such experiment "
            "exists." % experiment_id,
            RESOURCE_DOES_NOT_EXIST,
        )
    if experiment.lifecycle_stage != LifecycleStage.ACTIVE:
        raise MlflowException(
            "Could not create run under non-active experiment with ID "
            "%s." % experiment_id,
            INVALID_STATE,
        )
    run_id = uuid.uuid4().hex
    # The artifact location is always <experiment artifact location>/<run id>/artifacts.
    artifact_uri = os.path.join(experiment.artifact_location, run_id, "artifacts")
    run_info = RunInfo(
        run_uuid=run_id,
        run_id=run_id,
        experiment_id=experiment_id,
        artifact_uri=artifact_uri,
        user_id=user_id,
        status=RunStatus.to_string(RunStatus.RUNNING),
        start_time=start_time,
        end_time=None,
        lifecycle_stage=LifecycleStage.ACTIVE,
    )
    # Tags are only written once the run info itself has been stored.
    # NOTE(review): the original indentation of the final return was lost; it is
    # placed inside this branch here, so a falsy result from _create_run_info
    # makes the function return None. Confirm against the upstream file.
    if self._create_run_info(_make_persisted_run_info_dict(run_info)):
        for tag in tags:
            self.set_tag(run_id, tag)
        return self.get_run(run_id)
def create_run(self, experiment_id, user_id, run_name, source_type, source_name,
               entry_point_name, start_time, source_version, tags, parent_run_id):
    """Create and persist a new run in status RUNNING under an active experiment.

    ``parent_run_id`` is accepted but not used by this implementation.

    :raises MlflowException: with ``INVALID_STATE`` if the experiment's
                             lifecycle stage is not active.
    :return: The mlflow entity of the newly saved run.
    """
    # Bind the unused argument to a throwaway name, as the signature requires it.
    _ = parent_run_id

    experiment = self.get_experiment(experiment_id)
    if experiment.lifecycle_stage != Experiment.ACTIVE_LIFECYCLE:
        raise MlflowException('Experiment id={} must be active'.format(experiment_id),
                              INVALID_STATE)

    new_run = SqlRun(
        name=run_name,
        artifact_uri=None,
        run_uuid=uuid.uuid4().hex,
        experiment_id=experiment_id,
        source_type=source_type,
        source_name=source_name,
        entry_point_name=entry_point_name,
        user_id=user_id,
        status=RunStatus.to_string(RunStatus.RUNNING),
        start_time=start_time,
        end_time=None,
        source_version=source_version,
        lifecycle_stage=RunInfo.ACTIVE_LIFECYCLE)

    for tag in tags:
        sql_tag = SqlTag(key=tag.key, value=tag.value)
        new_run.tags.append(sql_tag)

    self._save_to_db([new_run])
    return new_run.to_mlflow_entity()
def test_run(
        tmpdir,  # pylint: disable=unused-argument
        patch_user,  # pylint: disable=unused-argument
        tracking_uri_mock,  # pylint: disable=unused-argument
        use_start_run):
    """Run the test project from its directory and validate what was tracked."""
    submitted_run = mlflow.projects.run(
        TEST_PROJECT_DIR,
        entry_point="test_tracking",
        parameters={"use_start_run": use_start_run},
        use_conda=False,
        experiment_id=FileStore.DEFAULT_EXPERIMENT_ID)
    assert submitted_run.run_id is not None

    # Blocking runs should be finished when they return
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
    # Test that we can call wait() on a synchronous run & that the run has the correct
    # status after calling wait().
    submitted_run.wait()
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)

    # Validate run contents in the FileStore
    run_id = submitted_run.run_id
    client = mlflow.tracking.MlflowClient()
    active_infos = client.list_run_infos(
        experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
        run_view_type=ViewType.ACTIVE_ONLY)
    assert len(active_infos) == 1
    assert run_id == active_infos[0].run_id

    tracked_run = client.get_run(run_id)
    assert tracked_run.info.status == RunStatus.to_string(RunStatus.FINISHED)
    assert tracked_run.data.params == {"use_start_run": use_start_run}
    assert tracked_run.data.metrics == {"some_key": 3}

    tags = tracked_run.data.tags
    assert tags[MLFLOW_USER] == MOCK_USER
    assert "file:" in tags[MLFLOW_SOURCE_NAME]
    assert tags[MLFLOW_SOURCE_TYPE] == SourceType.to_string(SourceType.PROJECT)
    assert tags[MLFLOW_PROJECT_ENTRY_POINT] == "test_tracking"
def on_pipeline_error(
    self,
    error: Exception,
    run_params: Dict[str, Any],
    pipeline: Pipeline,
    catalog: DataCatalog,
):
    """Hook invoked when the pipeline execution fails.

    When mlflow is enabled, every active mlflow run is ended with status
    FAILED so that it cannot interfere with further execution.

    Args:
        error: (Not used) The uncaught exception thrown during the pipeline run.
        run_params: (Not used) The params used to run the pipeline.
            Should be identical to the data logged by Journal with the following schema::

               {
                 "run_id": str
                 "project_path": str,
                 "env": str,
                 "kedro_version": str,
                 "tags": Optional[List[str]],
                 "from_nodes": Optional[List[str]],
                 "to_nodes": Optional[List[str]],
                 "node_names": Optional[List[str]],
                 "from_inputs": Optional[List[str]],
                 "load_versions": Optional[List[str]],
                 "pipeline_name": str,
                 "extra_params": Optional[Dict[str, Any]]
               }
        pipeline: (Not used) The ``Pipeline`` that was run.
        catalog: The ``DataCatalog`` used during the run. Only touched when
            mlflow is disabled, where its logging is switched back on.
    """
    if not self._is_mlflow_enabled:  # pragma: no cover
        # the catalog is supposed to be reloaded each time with _get_catalog,
        # hence it should not be modified. this is only a safeguard
        switch_catalog_logging(catalog, True)
        return

    # Runs can be nested, so keep ending them until none is active.
    failed_status = RunStatus.to_string(RunStatus.FAILED)
    while mlflow.active_run():
        mlflow.end_run(failed_status)
def create_run(self, experiment_id, user_id, run_name, source_type, source_name,
               entry_point_name, start_time, source_version, tags, parent_run_id):
    """Create and persist a new run in status RUNNING under an active experiment.

    The given tags are attached to the run, followed by a parent-run-id tag
    and a run-name tag when those values are provided.

    :raises MlflowException: with ``INVALID_STATE`` if the experiment's
                             lifecycle stage is not active.
    :return: The mlflow entity of the newly saved run.
    """
    experiment = self.get_experiment(experiment_id)
    if experiment.lifecycle_stage != LifecycleStage.ACTIVE:
        raise MlflowException(
            'Experiment id={} must be active'.format(experiment_id),
            INVALID_STATE)

    run_uuid = uuid.uuid4().hex
    new_run = SqlRun(
        name=run_name or "",
        artifact_uri=build_path(experiment.artifact_location, run_uuid,
                                SqlAlchemyStore.ARTIFACTS_FOLDER_NAME),
        run_uuid=run_uuid,
        experiment_id=experiment_id,
        source_type=SourceType.to_string(source_type),
        source_name=source_name,
        entry_point_name=entry_point_name,
        user_id=user_id,
        status=RunStatus.to_string(RunStatus.RUNNING),
        start_time=start_time,
        end_time=None,
        source_version=source_version,
        lifecycle_stage=LifecycleStage.ACTIVE)

    for tag in tags:
        sql_tag = SqlTag(key=tag.key, value=tag.value)
        new_run.tags.append(sql_tag)
    if parent_run_id:
        parent_tag = SqlTag(key=MLFLOW_PARENT_RUN_ID, value=parent_run_id)
        new_run.tags.append(parent_tag)
    if run_name:
        name_tag = SqlTag(key=MLFLOW_RUN_NAME, value=run_name)
        new_run.tags.append(name_tag)

    self._save_to_db([new_run])
    return new_run.to_mlflow_entity()
def get_infos(run_uuid, store=None):
    """Collect a run's info fields, metrics and params into one flat dictionary.

    Keys are 2-tuples: ``("run", <field>)`` for run info values,
    ``("metric", <key>)`` for metrics and ``("param", <key>)`` for params.
    The ``("run", "duration")`` entry is ``end_time - start_time``, or ``None``
    when the run has no end time.
    """
    from mlflow.entities import RunStatus

    run = get_run(run_uuid, store=store)
    info = run.info

    duration = None
    if info.end_time is not None:
        duration = info.end_time - info.start_time

    infos = {
        ("run", "uuid"): info.run_uuid,
        ("run", "experiment_id"): info.experiment_id,
        ("run", "status"): RunStatus.to_string(info.status),
        ("run", "start_time"): info.start_time,
        ("run", "end_time"): info.end_time,
        ("run", "duration"): duration,
    }
    for metric in get_all_metrics(run_uuid, store=store):
        infos[("metric", metric.key)] = metric.value
    for param in get_all_params(run_uuid, store=store):
        infos[("param", param.key)] = param.value
    return infos
def create_run(self, experiment_id, user_id, start_time, tags):
    """Create and persist a new run in status RUNNING under an active experiment.

    Source-related columns are filled with empty strings and an UNKNOWN source
    type. Tags are de-duplicated by key, with later entries overriding earlier
    ones.

    :return: The mlflow entity of the newly saved run.
    """
    with self.ManagedSessionMaker() as session:
        experiment = self.get_experiment(experiment_id)
        self._check_experiment_is_active(experiment)

        run_id = uuid.uuid4().hex
        artifact_location = append_to_uri_path(
            experiment.artifact_location,
            run_id,
            SqlAlchemyStore.ARTIFACTS_FOLDER_NAME)
        new_run = SqlRun(
            name="",
            artifact_uri=artifact_location,
            run_uuid=run_id,
            experiment_id=experiment_id,
            source_type=SourceType.to_string(SourceType.UNKNOWN),
            source_name="",
            entry_point_name="",
            user_id=user_id,
            status=RunStatus.to_string(RunStatus.RUNNING),
            start_time=start_time,
            end_time=None,
            source_version="",
            lifecycle_stage=LifecycleStage.ACTIVE)

        # Collect tags by key first so that each key ends up with one value.
        tag_values = {tag.key: tag.value for tag in tags}
        new_run.tags = [
            SqlTag(key=tag_key, value=tag_value)
            for tag_key, tag_value in tag_values.items()
        ]

        self._save_to_db(objs=new_run, session=session)
        return new_run.to_mlflow_entity()
def end_run(status=RunStatus.to_string(RunStatus.FINISHED)):
    """End the active MLflow run, if there is one, recording ``status`` on it.

    When no run is active this is a no-op.

    .. code-block:: python
        :caption: Example

        import mlflow

        # Start run and get status
        mlflow.start_run()
        run = mlflow.active_run()
        print("run_id: {}; status: {}".format(run.info.run_id, run.info.status))

        # End run and get status
        mlflow.end_run()
        run = mlflow.get_run(run.info.run_id)
        print("run_id: {}; status: {}".format(run.info.run_id, run.info.status))
        print("--")

        # Check for any active runs
        print("Active run: {}".format(mlflow.active_run()))

    .. code-block:: text
        :caption: Output

        run_id: b47ee4563368419880b44ad8535f6371; status: RUNNING
        run_id: b47ee4563368419880b44ad8535f6371; status: FINISHED
        --
        Active run: None
    """
    global _active_run_stack
    if not _active_run_stack:
        return

    # Clear out the global existing run environment variable as well.
    env.unset_variable(_RUN_ID_ENV_VAR)
    ended_run = _active_run_stack.pop()
    MlflowClient().set_terminated(ended_run.info.run_id, status)
def test_status_mappings(self):
    """``RunStatus`` must map between enum values and strings in both directions."""
    # test enum to string mappings
    for status_name in ("RUNNING", "SCHEDULED", "FINISHED", "FAILED", "KILLED"):
        status_value = getattr(RunStatus, status_name)
        self.assertEqual(status_name, RunStatus.to_string(status_value))
        self.assertEqual(status_value, RunStatus.from_string(status_name))

    # Unknown values must be rejected in either direction.
    with self.assertRaises(Exception) as e:
        RunStatus.to_string(-120)
    self.assertIn("Could not get string corresponding to run status -120", str(e.exception))

    with self.assertRaises(Exception) as e:
        RunStatus.from_string("the IMPOSSIBLE status string")
    self.assertIn("Could not get run status corresponding to string the IMPO", str(e.exception))
def __exit__(self, exc_type, exc_val, exc_tb):
    """End the run as FINISHED on a clean exit, or FAILED if an exception occurred.

    :return: ``True`` when no exception was raised in the ``with`` body,
             ``False`` otherwise.
    """
    completed_cleanly = exc_type is None
    if completed_cleanly:
        final_status = RunStatus.FINISHED
    else:
        final_status = RunStatus.FAILED
    end_run(RunStatus.to_string(final_status))
    return completed_cleanly
class SqlRun(Base):
    """
    DB model for :py:class:`mlflow.entities.Run`. These are recorded in ``runs`` table.
    """

    __tablename__ = "runs"

    run_uuid = Column(String(32), nullable=False)
    """
    Run UUID: `String` (limit 32 characters). *Primary Key* for ``runs`` table.
    """
    name = Column(String(250))
    """
    Run name: `String` (limit 250 characters).
    """
    source_type = Column(String(20), default=SourceType.to_string(SourceType.LOCAL))
    """
    Source Type: `String` (limit 20 characters). Can be one of ``NOTEBOOK``, ``JOB``, ``PROJECT``,
                 ``LOCAL`` (default), or ``UNKNOWN``.
    """
    source_name = Column(String(500))
    """
    Name of source recording the run: `String` (limit 500 characters).
    """
    entry_point_name = Column(String(50))
    """
    Entry-point name that launched the run: `String` (limit 50 characters).
    """
    user_id = Column(String(256), nullable=True, default=None)
    """
    User ID: `String` (limit 256 characters). Defaults to ``null``.
    """
    status = Column(String(20), default=RunStatus.to_string(RunStatus.SCHEDULED))
    """
    Run Status: `String` (limit 20 characters). Can be one of ``RUNNING``, ``SCHEDULED`` (default),
                ``FINISHED``, ``FAILED``, ``KILLED``.
    """
    # The default must be a callable: a plain ``int(time.time())`` would be
    # evaluated once when this module is imported, and every row relying on
    # the default would then share that single import-time timestamp.
    start_time = Column(BigInteger, default=lambda: int(time.time()))
    """
    Run start time: `BigInteger`. Defaults to the system time at insert.
    """
    end_time = Column(BigInteger, nullable=True, default=None)
    """
    Run end time: `BigInteger`.
    """
    source_version = Column(String(50))
    """
    Source version: `String` (limit 50 characters).
    """
    lifecycle_stage = Column(String(20), default=LifecycleStage.ACTIVE)
    """
    Lifecycle Stage of run: `String` (limit 20 characters).
                            Can be either ``active`` (default) or ``deleted``.
    """
    artifact_uri = Column(String(200), default=None)
    """
    Default artifact location for this run: `String` (limit 200 characters).
    """
    experiment_id = Column(Integer, ForeignKey("experiments.experiment_id"))
    """
    Experiment ID to which this run belongs to: *Foreign Key* into ``experiment`` table.
    """
    experiment = relationship("SqlExperiment", backref=backref("runs", cascade="all"))
    """
    SQLAlchemy relationship (many:one) with :py:class:`mlflow.store.dbmodels.models.SqlExperiment`.
    """

    __table_args__ = (
        CheckConstraint(source_type.in_(SourceTypes), name="source_type"),
        CheckConstraint(status.in_(RunStatusTypes), name="status"),
        CheckConstraint(
            lifecycle_stage.in_(LifecycleStage.view_type_to_stages(ViewType.ALL)),
            name="runs_lifecycle_stage",
        ),
        PrimaryKeyConstraint("run_uuid", name="run_pk"),
    )

    @staticmethod
    def get_attribute_name(mlflow_attribute_name):
        """
        Resolves an MLflow attribute name to a `SqlRun` attribute name.

        :param mlflow_attribute_name: Attribute name as used in MLflow search.
        :return: The same name, unchanged.
        """
        # Currently, MLflow Search attributes defined in `SearchUtils.VALID_SEARCH_ATTRIBUTE_KEYS`
        # share the same names as their corresponding `SqlRun` attributes. Therefore, this function
        # returns the same attribute name
        return mlflow_attribute_name

    def to_mlflow_entity(self):
        """
        Convert DB model to corresponding MLflow entity.

        The run ID and run UUID of the entity are both taken from ``run_uuid``,
        and the experiment ID is converted to a string.

        :return: :py:class:`mlflow.entities.Run`.
        """
        run_info = RunInfo(
            run_uuid=self.run_uuid,
            run_id=self.run_uuid,
            experiment_id=str(self.experiment_id),
            user_id=self.user_id,
            status=self.status,
            start_time=self.start_time,
            end_time=self.end_time,
            lifecycle_stage=self.lifecycle_stage,
            artifact_uri=self.artifact_uri,
        )

        run_data = RunData(
            metrics=[m.to_mlflow_entity() for m in self.latest_metrics],
            params=[p.to_mlflow_entity() for p in self.params],
            tags=[t.to_mlflow_entity() for t in self.tags],
        )

        return Run(run_info=run_info, run_data=run_data)
ViewType, ExperimentTag, ) from mlflow.entities.lifecycle_stage import LifecycleStage from mlflow.store.db.base_sql_model import Base SourceTypes = [ SourceType.to_string(SourceType.NOTEBOOK), SourceType.to_string(SourceType.JOB), SourceType.to_string(SourceType.LOCAL), SourceType.to_string(SourceType.UNKNOWN), SourceType.to_string(SourceType.PROJECT), ] RunStatusTypes = [ RunStatus.to_string(RunStatus.SCHEDULED), RunStatus.to_string(RunStatus.FAILED), RunStatus.to_string(RunStatus.FINISHED), RunStatus.to_string(RunStatus.RUNNING), RunStatus.to_string(RunStatus.KILLED), ] class SqlExperiment(Base): """ DB model for :py:class:`mlflow.entities.Experiment`. These are recorded in ``experiment`` table. """ __tablename__ = "experiments" experiment_id = Column(Integer, autoincrement=True)