def test_validate_autologging_run_validates_run_status_correctly():
    valid_autologging_tags = {
        MLFLOW_AUTOLOGGING: "test_integration",
    }

    with mlflow.start_run(tags=valid_autologging_tags) as run_finished:
        run_id_finished = run_finished.info.run_id

    assert (
        RunStatus.from_string(MlflowClient().get_run(run_id_finished).info.status)
        == RunStatus.FINISHED
    )
    _validate_autologging_run("test_integration", run_id_finished)

    with mlflow.start_run(tags=valid_autologging_tags) as run_failed:
        run_id_failed = run_failed.info.run_id

    MlflowClient().set_terminated(run_id_failed, status=RunStatus.to_string(RunStatus.FAILED))
    assert (
        RunStatus.from_string(MlflowClient().get_run(run_id_failed).info.status) == RunStatus.FAILED
    )
    _validate_autologging_run("test_integration", run_id_failed)

    run_non_terminal = MlflowClient().create_run(
        experiment_id=run_finished.info.experiment_id, tags=valid_autologging_tags
    )
    run_id_non_terminal = run_non_terminal.info.run_id
    assert (
        RunStatus.from_string(MlflowClient().get_run(run_id_non_terminal).info.status)
        == RunStatus.RUNNING
    )
    with pytest.raises(AssertionError, match="has a non-terminal status"):
        _validate_autologging_run("test_integration", run_id_non_terminal)
Example #2
def test_with_managed_run_with_throwing_class_exhibits_expected_behavior():
    client = MlflowClient()
    patch_function_active_run = None

    @with_managed_run
    class TestPatch(PatchFunction):
        def _patch_implementation(self, original, *args, **kwargs):
            nonlocal patch_function_active_run
            patch_function_active_run = mlflow.active_run()
            raise Exception("bad implementation")

        def _on_exception(self, exception):
            pass

    with pytest.raises(Exception, match="bad implementation"):
        TestPatch.call(lambda: "foo")

    assert patch_function_active_run is not None
    status1 = client.get_run(patch_function_active_run.info.run_id).info.status
    assert RunStatus.from_string(status1) == RunStatus.FAILED

    with mlflow.start_run() as active_run, pytest.raises(Exception, match="bad implementation"):
        TestPatch.call(lambda: "foo")

    assert patch_function_active_run == active_run
    # `with_managed_run` should not terminate a preexisting MLflow run,
    # even if the patch function throws
    status2 = client.get_run(active_run.info.run_id).info.status
    assert RunStatus.from_string(status2) == RunStatus.FINISHED
Example #3
def _create_entity(base, model):
    # Create a dict of kwargs properties for the entity and return the initialized entity
    config = {}
    for k in base._properties():
        # check whether the attribute is an MLflow entity and build it
        obj = getattr(model, k)

        if isinstance(model, SqlRun):
            if base is RunData:
                # RunData holds lists for metrics, params, and tags,
                # so obj is a list whose items must be converted
                if k == 'metrics':
                    # only keep the latest recorded metric per key
                    metrics = {}
                    for o in obj:
                        existing_metric = metrics.get(o.key)
                        if (existing_metric is None
                                or o.timestamp > existing_metric.timestamp
                                or (o.timestamp == existing_metric.timestamp
                                    and o.value > existing_metric.value)):
                            metrics[o.key] = Metric(o.key, o.value, o.timestamp)
                    obj = list(metrics.values())
                elif k == 'params':
                    obj = [Param(o.key, o.value) for o in obj]
                elif k == 'tags':
                    obj = [RunTag(o.key, o.value) for o in obj]
            elif base is RunInfo:
                if k == 'source_type':
                    obj = SourceType.from_string(obj)
                elif k == "status":
                    obj = RunStatus.from_string(obj)

        config[k] = obj
    return base(**config)
Example #4
    def to_mlflow_entity(self):
        """
        Convert DB model to corresponding MLflow entity.

        :return: :py:class:`mlflow.entities.Run`.
        """
        run_info = RunInfo(
            run_uuid=self.run_uuid,
            run_id=self.run_uuid,
            experiment_id=str(self.experiment_id),
            user_id=self.user_id,
            status=RunStatus.from_string(self.status),
            start_time=self.start_time,
            end_time=self.end_time,
            lifecycle_stage=self.lifecycle_stage,
            artifact_uri=self.artifact_uri)

        # only get latest recorded metrics per key
        all_metrics = [m.to_mlflow_entity() for m in self.metrics]
        metrics = {}
        for m in all_metrics:
            existing_metric = metrics.get(m.key)
            if (existing_metric is None
                    or (m.step, m.timestamp, m.value)
                    >= (existing_metric.step, existing_metric.timestamp, existing_metric.value)):
                metrics[m.key] = m

        run_data = RunData(
            metrics=list(metrics.values()),
            params=[p.to_mlflow_entity() for p in self.params],
            tags=[t.to_mlflow_entity() for t in self.tags])

        return Run(run_info=run_info, run_data=run_data)
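A brief usage sketch for the conversion above (assumption: `sql_run` is a SqlRun row loaded by the SQLAlchemy tracking store):

# Usage sketch: `sql_run` is assumed to be a SqlRun instance fetched from the database.
run = sql_run.to_mlflow_entity()
# In this version of the model, RunInfo.status holds the enum value, so convert it
# back to its string form for display.
print(run.info.run_id, RunStatus.to_string(run.info.status))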
Example #5
def test_with_managed_run_with_non_throwing_function_exhibits_expected_behavior():
    client = MlflowClient()

    @with_managed_run
    def patch_function(original, *args, **kwargs):
        return mlflow.active_run()

    run1 = patch_function(lambda: "foo")
    run1_status = client.get_run(run1.info.run_id).info.status
    assert RunStatus.from_string(run1_status) == RunStatus.FINISHED

    with mlflow.start_run() as active_run:
        run2 = patch_function(lambda: "foo")

    assert run2 == active_run
    run2_status = client.get_run(run2.info.run_id).info.status
    assert RunStatus.from_string(run2_status) == RunStatus.FINISHED
Example #6
    def set_terminated(self, run_id, status=None, end_time=None):
        """Sets a run's status to terminated.

        :param status: A string value of mlflow.entities.RunStatus. Defaults to "FINISHED".
        :param end_time: If not provided, defaults to the current time."""
        end_time = end_time if end_time else int(time.time() * 1000)
        status = status if status else "FINISHED"
        self.store.update_run_info(run_id,
                                   run_status=RunStatus.from_string(status),
                                   end_time=end_time)
Example #7
    def set_terminated(self, run_id, status=None, end_time=None):
        """Set a run's status to terminated.

        :param status: A string value of :py:class:`mlflow.entities.RunStatus`.
                       Defaults to "FINISHED".
        :param end_time: If not provided, defaults to the current time."""
        end_time = end_time if end_time else int(time.time() * 1000)
        status = status if status else RunStatus.to_string(RunStatus.FINISHED)
        self.store.update_run_info(run_id, run_status=RunStatus.from_string(status),
                                   end_time=end_time)
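A minimal usage sketch for the client method above (assumptions: the current MlflowClient API, and `run_id` referring to an existing active run):

# Usage sketch: mark an existing run as KILLED instead of the default FINISHED.
client = MlflowClient()
client.set_terminated(run_id, status=RunStatus.to_string(RunStatus.KILLED))
assert RunStatus.from_string(client.get_run(run_id).info.status) == RunStatus.KILLED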
Example #8
def _make_persisted_run_info_dict(run_info):
    run_info_dict = _entity_to_dict(run_info)
    if "status" in run_info_dict:
        # 'status' is stored as an integer enum in meta file, but RunInfo.status field is a string.
        # Convert from string to enum/int before storing.
        run_info_dict["status"] = RunStatus.from_string(run_info.status)
    else:
        run_info_dict["status"] = RunStatus.RUNNING
    run_info_dict["source_type"] = SourceType.LOCAL
    return run_info_dict
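As a complementary sketch, the read path implied by the comment above maps the persisted integer enum back to the string form that RunInfo.status expects; the helper name here is hypothetical:

def _read_persisted_status(run_info_dict):
    # Hypothetical inverse of the conversion above: the meta file stores 'status'
    # as an integer enum, so map it back to the string used by RunInfo.status.
    return RunStatus.to_string(run_info_dict.get("status", RunStatus.RUNNING))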
Example #9
def test_with_startrun():
    run_id = None
    import time
    t0 = int(time.time() * 1000)
    with mlflow.start_run() as active_run:
        assert mlflow.active_run() == active_run
        run_id = active_run.info.run_uuid
    t1 = int(time.time() * 1000)
    run_info = mlflow.tracking._get_store().get_run(run_id).info
    assert run_info.status == RunStatus.from_string("FINISHED")
    assert t0 <= run_info.end_time <= t1
    assert mlflow.active_run() is None
Example #10
def test_with_managed_run_with_non_throwing_class_exhibits_expected_behavior():
    client = MlflowClient()

    @with_managed_run
    class TestPatch(PatchFunction):
        def _patch_implementation(self, original, *args, **kwargs):
            return mlflow.active_run()

        def _on_exception(self, exception):
            pass

    run1 = TestPatch.call(lambda: "foo")
    run1_status = client.get_run(run1.info.run_id).info.status
    assert RunStatus.from_string(run1_status) == RunStatus.FINISHED

    with mlflow.start_run() as active_run:
        run2 = TestPatch.call(lambda: "foo")

    assert run2 == active_run
    run2_status = client.get_run(run2.info.run_id).info.status
    assert RunStatus.from_string(run2_status) == RunStatus.FINISHED
Example #11
def test_with_managed_run_ends_run_on_keyboard_interrupt():
    client = MlflowClient()
    run = None

    def original():
        nonlocal run
        run = mlflow.active_run()
        raise KeyboardInterrupt

    patch_function_1 = with_managed_run(
        "test_integration", lambda original, *args, **kwargs: original(*args, **kwargs)
    )

    with pytest.raises(KeyboardInterrupt):
        patch_function_1(original)

    assert not mlflow.active_run()
    run_status_1 = client.get_run(run.info.run_id).info.status
    assert RunStatus.from_string(run_status_1) == RunStatus.FAILED

    class PatchFunction2(PatchFunction):
        def _patch_implementation(self, original, *args, **kwargs):
            return original(*args, **kwargs)

        def _on_exception(self, exception):
            pass

    patch_function_2 = with_managed_run("test_integration", PatchFunction2)

    with pytest.raises(KeyboardInterrupt):
        patch_function_2.call(original)

    assert not mlflow.active_run()
    run_status_2 = client.get_run(run.info.run_id).info.status
    assert RunStatus.from_string(run_status_2) == RunStatus.FAILED
Example #12
def _make_persisted_run_info_dict(run_info):
    # 'tags' was moved from RunInfo to RunData, so we must keep storing it in the meta.yaml for
    # old mlflow versions to read
    run_info_dict = dict(run_info)
    run_info_dict['tags'] = []
    run_info_dict['name'] = ''
    if 'status' in run_info_dict:
        # 'status' is stored as an integer enum in meta file, but RunInfo.status field is a string.
        # Convert from string to enum/int before storing.
        run_info_dict['status'] = RunStatus.from_string(run_info.status)
    else:
        run_info_dict['status'] = RunStatus.RUNNING
    run_info_dict['source_type'] = SourceType.LOCAL
    run_info_dict['source_name'] = ''
    run_info_dict['entry_point_name'] = ''
    run_info_dict['source_version'] = ''
    return run_info_dict
Example #13
def _create_entity(base, model):
    # Create a dict of kwargs properties for the entity and return the initialized entity
    config = {}
    for k in base._properties():
        # check whether the attribute is an MLflow entity and build it
        obj = getattr(model, k)

        if isinstance(model, SqlRun):
            if base is RunData:
                # RunData holds lists for metrics, params, and tags,
                # so obj is a list whose items must be converted
                if k == 'metrics':
                    # only keep the latest recorded metric per key
                    metrics = {}
                    for o in obj:
                        existing_metric = metrics.get(o.key)
                        if (existing_metric is None
                                or (o.step, o.timestamp, o.value)
                                >= (existing_metric.step, existing_metric.timestamp,
                                    existing_metric.value)):
                            metrics[o.key] = Metric(o.key, o.value,
                                                    o.timestamp, o.step)
                    obj = list(metrics.values())
                elif k == 'params':
                    obj = [Param(o.key, o.value) for o in obj]
                elif k == 'tags':
                    obj = [RunTag(o.key, o.value) for o in obj]
            elif base is RunInfo:
                if k == 'source_type':
                    obj = SourceType.from_string(obj)
                elif k == "status":
                    obj = RunStatus.from_string(obj)
                elif k == "experiment_id":
                    obj = str(obj)

        # Our data model defines experiment_ids as ints, but the in-memory representation was
        # changed to be a string in time for 1.0.
        if isinstance(model, SqlExperiment) and k == "experiment_id":
            obj = str(obj)

        config[k] = obj
    return base(**config)
Example #14
    def test_status_mappings(self):
        # test enum to string mappings
        self.assertEqual("RUNNING", RunStatus.to_string(RunStatus.RUNNING))
        self.assertEqual(RunStatus.RUNNING, RunStatus.from_string("RUNNING"))

        self.assertEqual("SCHEDULED", RunStatus.to_string(RunStatus.SCHEDULED))
        self.assertEqual(RunStatus.SCHEDULED, RunStatus.from_string("SCHEDULED"))

        self.assertEqual("FINISHED", RunStatus.to_string(RunStatus.FINISHED))
        self.assertEqual(RunStatus.FINISHED, RunStatus.from_string("FINISHED"))

        self.assertEqual("FAILED", RunStatus.to_string(RunStatus.FAILED))
        self.assertEqual(RunStatus.FAILED, RunStatus.from_string("FAILED"))

        self.assertEqual("KILLED", RunStatus.to_string(RunStatus.KILLED))
        self.assertEqual(RunStatus.KILLED, RunStatus.from_string("KILLED"))

        with self.assertRaises(Exception) as e:
            RunStatus.to_string(-120)
        self.assertIn("Could not get string corresponding to run status -120", str(e.exception))

        with self.assertRaises(Exception) as e:
            RunStatus.from_string("the IMPOSSIBLE status string")
        self.assertIn("Could not get run status corresponding to string the IMPO", str(e.exception))
Example #15
def validate_exit_status(status_str, expected):
    assert RunStatus.from_string(status_str) == expected
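A short usage sketch for the helper above (assumption: `run_id` is the ID of a run that has already finished):

status_str = MlflowClient().get_run(run_id).info.status
validate_exit_status(status_str, RunStatus.FINISHED)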