Example 1
def test_start_and_end_run():
    # Use start_run() as a context manager and verify the logged metric is recorded on the finished run.

    with start_run() as active_run:
        mlflow.log_metric("name_1", 25)
    finished_run = tracking.MlflowClient().get_run(active_run.info.run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 1
    assert finished_run.data.metrics["name_1"] == 25
Example 2
def test_start_run_context_manager():
    with start_run() as first_run:
        first_uuid = first_run.info.run_id
        # Check that start_run() causes the run information to be persisted in the store
        persisted_run = tracking.MlflowClient().get_run(first_uuid)
        assert persisted_run is not None
        assert persisted_run.info == first_run.info
    finished_run = tracking.MlflowClient().get_run(first_uuid)
    assert finished_run.info.status == RunStatus.to_string(RunStatus.FINISHED)
    # Launch a separate run that fails, verify the run status is FAILED and the run UUID is
    # different
    with pytest.raises(Exception):
        with start_run() as second_run:
            second_run_id = second_run.info.run_id
            raise Exception("Failing run!")
    assert second_run_id != first_uuid
    finished_run2 = tracking.MlflowClient().get_run(second_run_id)
    assert finished_run2.info.status == RunStatus.to_string(RunStatus.FAILED)
Example 3
def test_create_experiment_with_duplicate_name(tracking_uri_mock):
    name = "popular_name"
    exp_id = mlflow.create_experiment(name)

    with pytest.raises(MlflowException):
        mlflow.create_experiment(name)

    tracking.MlflowClient().delete_experiment(exp_id)
    with pytest.raises(MlflowException):
        mlflow.create_experiment(name)
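
Several of the examples here take a tracking_uri_mock fixture as an argument. The fixture itself is not part of the snippets; a minimal sketch, assuming it only needs to point MLflow at a temporary file store for the duration of a test (the name and cleanup behavior are assumptions, not MLflow's actual test fixture):

import mlflow
import pytest


@pytest.fixture
def tracking_uri_mock(tmpdir):
    # Point the tracking APIs at a throwaway file store so tests never touch a real server.
    mlflow.set_tracking_uri("file://%s" % tmpdir.strpath)
    yield
    # Restore the default so later tests are unaffected.
    mlflow.set_tracking_uri(None)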
Example 4
def test_log_param(tracking_uri_mock):
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_param("name_1", "a")
        mlflow.log_param("name_2", "b")
        mlflow.log_param("name_1", "c")
        mlflow.log_param("nested/nested/name", 5)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate params
    assert finished_run.data.params == {"name_1": "c", "name_2": "b", "nested/nested/name": "5"}
Example 5
def test_log_batch_duplicate_entries_raises():
    with start_run() as active_run:
        run_id = active_run.info.run_id
        with pytest.raises(
            MlflowException, match=r"Duplicate parameter keys have been submitted."
        ) as e:
            tracking.MlflowClient().log_batch(
                run_id=run_id, params=[Param("a", "1"), Param("a", "2")]
            )
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
Example 6
def test_set_experiment_with_deleted_experiment_name(tracking_uri_mock):
    name = "dead_exp"
    mlflow.set_experiment(name)
    with start_run() as run:
        exp_id = run.info.experiment_id

    tracking.MlflowClient().delete_experiment(exp_id)

    with pytest.raises(MlflowException):
        mlflow.set_experiment(name)
Example 7
def dump(run_id, artifact_max_level):
    print("Options:")
    print("  run_id:",run_id)
    print("  artifact_max_level:",artifact_max_level)

    client = tracking.MlflowClient()

    run = client.get_run(run_id)
    dump_run(run)
    dump_artifacts(client, run_id, "", INDENT_INC, artifact_max_level)
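
dump_run, dump_artifacts, and INDENT_INC are helpers that are not shown in this example. A minimal sketch of what they might look like, assuming dump_artifacts simply walks the artifact tree with MlflowClient.list_artifacts down to the requested depth:

INDENT_INC = "  "  # assumed indentation step per artifact directory level


def dump_run(run):
    # Print the run's identifying fields along with its params, metrics, and tags.
    print("Run:")
    print("  run_id:", run.info.run_id)
    print("  experiment_id:", run.info.experiment_id)
    print("  status:", run.info.status)
    print("  params:", run.data.params)
    print("  metrics:", run.data.metrics)
    print("  tags:", run.data.tags)


def dump_artifacts(client, run_id, path, indent, max_level):
    # Recursively list artifacts, descending one level per directory up to max_level.
    if max_level <= 0:
        return
    for artifact in client.list_artifacts(run_id, path):
        print(indent + artifact.path)
        if artifact.is_dir:
            dump_artifacts(client, run_id, artifact.path,
                           indent + INDENT_INC, max_level - 1)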
Example 8
def test_start_deleted_run():
    run_id = None
    with mlflow.start_run() as active_run:
        run_id = active_run.info.run_uuid
    tracking.MlflowClient().delete_run(run_id)
    with pytest.raises(MlflowException,
                       match='because it is in the deleted state.'):
        with mlflow.start_run(run_uuid=run_id):
            pass
    assert mlflow.active_run() is None
Example 9
def _print_description_and_log_tags(self):
    eprint("=== Launched MLflow run as Databricks job run with ID %s. Getting run status "
           "page URL... ===" % self._databricks_run_id)
    run_info = self._job_runner.jobs_runs_get(self._databricks_run_id)
    jobs_page_url = run_info["run_page_url"]
    eprint("=== Check the run's status at %s ===" % jobs_page_url)
    host_creds = databricks_utils.get_databricks_host_creds(self._job_runner.databricks_profile)
    tracking.MlflowClient().set_tag(self._mlflow_run_id,
                                    MLFLOW_DATABRICKS_RUN_URL, jobs_page_url)
    tracking.MlflowClient().set_tag(self._mlflow_run_id,
                                    MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, self._databricks_run_id)
    tracking.MlflowClient().set_tag(self._mlflow_run_id,
                                    MLFLOW_DATABRICKS_WEBAPP_URL, host_creds.host)
    job_id = run_info.get('job_id')
    # Some Databricks releases do not return the job ID; it is included in DB releases 2.80
    # and above.
    if job_id is not None:
        tracking.MlflowClient().set_tag(self._mlflow_run_id,
                                        MLFLOW_DATABRICKS_SHELL_JOB_ID, job_id)
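
The eprint helper used above is not defined in this excerpt; a minimal sketch, assuming it simply writes status messages to stderr:

import sys


def eprint(*args, **kwargs):
    # Send status output to stderr so it does not mix with regular program output.
    print(*args, file=sys.stderr, **kwargs)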
Example 10
def test_log_batch(tracking_uri_mock, tmpdir):
    expected_metrics = {"metric-key0": 1.0, "metric-key1": 4.0}
    expected_params = {"param-key0": "param-val0", "param-key1": "param-val1"}
    exact_expected_tags = {"tag-key0": "tag-val0", "tag-key1": "tag-val1"}
    approx_expected_tags = set([MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE])

    t = int(time.time())
    sorted_expected_metrics = sorted(expected_metrics.items(),
                                     key=lambda kv: kv[0])
    metrics = [
        Metric(key=key, value=value, timestamp=t, step=i)
        for i, (key, value) in enumerate(sorted_expected_metrics)
    ]
    params = [
        Param(key=key, value=value) for key, value in expected_params.items()
    ]
    tags = [
        RunTag(key=key, value=value)
        for key, value in exact_expected_tags.items()
    ]

    with start_run() as active_run:
        run_uuid = active_run.info.run_uuid
        mlflow.tracking.MlflowClient().log_batch(run_id=run_uuid,
                                                 metrics=metrics,
                                                 params=params,
                                                 tags=tags)
    finished_run = tracking.MlflowClient().get_run(run_uuid)
    # Validate metrics
    assert len(finished_run.data.metrics) == 2
    for key, value in finished_run.data.metrics.items():
        assert expected_metrics[key] == value
    # TODO: use client get_metric_history API here instead once it exists
    fs = FileStore(os.path.join(tmpdir.strpath, "mlruns"))
    metric_history0 = fs.get_metric_history(run_uuid, "metric-key0")
    assert set([(m.value, m.timestamp, m.step) for m in metric_history0]) == set([
        (1.0, t, 0),
    ])
    metric_history1 = fs.get_metric_history(run_uuid, "metric-key1")
    assert set([(m.value, m.timestamp, m.step) for m in metric_history1]) == set([
        (4.0, t, 1),
    ])

    # Validate tags (for automatically-set tags)
    assert len(finished_run.data.tags) == len(exact_expected_tags) + len(approx_expected_tags)
    for tag_key, tag_value in finished_run.data.tags.items():
        if tag_key in approx_expected_tags:
            pass
        else:
            assert exact_expected_tags[tag_key] == tag_value
    # Validate params
    assert finished_run.data.params == expected_params
Example 11
def _build_docker_image(work_dir, project, active_run):
    """
    Build a docker image containing the project in `work_dir`, using the base image and tagging the
    built image with the project name specified by `project`.
    """
    if not project.name:
        raise ExecutionException(
            "Project name in MLProject must be specified when using docker "
            "for image tagging.")
    tag_name = "mlflow-{name}-{version}".format(
        name=(project.name if project.name else "docker-project"),
        version=_get_git_commit(work_dir)[:7],
    )
    dockerfile = ("FROM {imagename}\n"
                  "LABEL Name={tag_name}\n"
                  "COPY {build_context_path}/* /mlflow/projects/code/\n"
                  "WORKDIR /mlflow/projects/code/\n").format(
                      imagename=project.docker_env.get('image'),
                      tag_name=tag_name,
                      build_context_path=_PROJECT_TAR_ARCHIVE_NAME)
    build_ctx_path = _create_docker_build_ctx(work_dir, dockerfile)
    with open(build_ctx_path, 'rb') as docker_build_ctx:
        _logger.info("=== Building docker image %s ===", tag_name)
        client = docker.from_env()
        image = client.images.build(tag=tag_name,
                                    forcerm=True,
                                    dockerfile=posixpath.join(
                                        _PROJECT_TAR_ARCHIVE_NAME,
                                        _GENERATED_DOCKERFILE_NAME),
                                    fileobj=docker_build_ctx,
                                    custom_context=True,
                                    encoding="gzip")
    try:
        os.remove(build_ctx_path)
    except Exception:  # pylint: disable=broad-except
        _logger.info("Temporary docker context file %s was not deleted.",
                     build_ctx_path)
    tracking.MlflowClient().set_tag(active_run.info.run_uuid,
                                    MLFLOW_DOCKER_IMAGE_NAME, tag_name)
    tracking.MlflowClient().set_tag(active_run.info.run_uuid,
                                    MLFLOW_DOCKER_IMAGE_ID, image[0].id)
    return tag_name
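
_get_git_commit and _create_docker_build_ctx are project-internal helpers that are not part of this excerpt. As a rough sketch of the first one (an assumption, not the actual implementation, which may for example return None outside a git repository), it only needs to produce the commit SHA that the caller truncates to seven characters for the image tag:

import subprocess


def _get_git_commit(work_dir):
    # Assumed sketch: ask git for the HEAD commit of the repository at work_dir.
    return subprocess.check_output(
        ["git", "rev-parse", "HEAD"], cwd=work_dir
    ).decode("utf-8").strip()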
Example 12
def test_log_batch_validates_entity_names_and_values():
    with start_run() as active_run:
        run_id = active_run.info.run_id

        metrics = [
            Metric(key="../bad/metric/name", value=0.3, timestamp=3, step=0)
        ]
        with pytest.raises(MlflowException, match="Invalid metric name") as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        metrics = [
            Metric(key="ok-name",
                   value="non-numerical-value",
                   timestamp=3,
                   step=0)
        ]
        with pytest.raises(MlflowException, match="Got invalid value") as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        metrics = [
            Metric(key="ok-name",
                   value=0.3,
                   timestamp="non-numerical-timestamp",
                   step=0)
        ]
        with pytest.raises(MlflowException,
                           match="Got invalid timestamp") as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        params = [Param(key="../bad/param/name", value="my-val")]
        with pytest.raises(MlflowException,
                           match="Invalid parameter name") as e:
            tracking.MlflowClient().log_batch(run_id, params=params)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        tags = [Param(key="../bad/tag/name", value="my-val")]
        with pytest.raises(MlflowException, match="Invalid tag name") as e:
            tracking.MlflowClient().log_batch(run_id, tags=tags)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
Example 13
def test_log_metric_validation():
    try:
        tracking.set_tracking_uri(tempfile.mkdtemp())
        active_run = start_run()
        run_uuid = active_run.info.run_uuid
        with active_run:
            mlflow.log_metric("name_1", "apple")
        finished_run = tracking.MlflowClient().get_run(run_uuid)
        assert len(finished_run.data.metrics) == 0
    finally:
        tracking.set_tracking_uri(None)
Example 14
def test_log_params(tracking_uri_mock):
    expected_params = {"name_1": "c", "name_2": "b", "nested/nested/name": "5"}
    active_run = start_run()
    run_uuid = active_run.info.run_uuid
    with active_run:
        mlflow.log_params(expected_params)
    finished_run = tracking.MlflowClient().get_run(run_uuid)
    # Validate params
    assert len(finished_run.data.params) == 3
    for param in finished_run.data.params:
        assert expected_params[param.key] == param.value
Example 15
def test_start_and_end_run(tracking_uri_mock):
    # Use the start_run() and end_run() APIs without a `with` block, verify they work.
    active_run = start_run()
    mlflow.log_metric("name_1", 25)
    end_run()
    finished_run = tracking.MlflowClient().get_run(active_run.info.run_uuid)
    # Validate metrics
    assert len(finished_run.data.metrics) == 1
    expected_pairs = {"name_1": 25}
    for metric in finished_run.data.metrics:
        assert expected_pairs[metric.key] == metric.value
Example 16
def _resolve_runs_uri(uri):
    from mlflow.utils import mlflow_tags
    client = tracking.MlflowClient()
    toks = uri.split("/")
    if len(toks) < 2:
        raise Exception("Bad runs URI")
    run_id = toks[1]
    run = client.get_run(run_id)
    uri = _get_tag(run, mlflow_tags.MLFLOW_SOURCE_NAME)
    version = _get_tag(run, mlflow_tags.MLFLOW_GIT_COMMIT)
    return uri
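
_get_tag is another helper that is not shown here; a minimal sketch, assuming the dict-valued run.data.tags of newer MLflow clients:

def _get_tag(run, tag_name):
    # Return the tag's value if the run has it, otherwise None.
    return run.data.tags.get(tag_name)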
Example 17
def test_log_params(tracking_uri_mock):
    expected_params = {"name_1": "c", "name_2": "b", "nested/nested/name": "5"}
    with start_run() as active_run:
        run_uuid = active_run.info.run_uuid
        mlflow.log_params(expected_params)
    finished_run = tracking.MlflowClient().get_run(run_uuid)
    # Validate params
    assert finished_run.data.params == {
        "name_1": "c",
        "name_2": "b",
        "nested/nested/name": "5"
    }
Example 18
def test_list_experiments_paginated_errors():
    client = tracking.MlflowClient()
    # test that providing a completely invalid page token throws
    with pytest.raises(MlflowException, match="Invalid page token") as exception_context:
        client.list_experiments(page_token="evilhax", max_results=20)
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    # test that providing too large of a max_results throws
    with pytest.raises(
        MlflowException, match="Invalid value for request parameter max_results"
    ) as exception_context:
        client.list_experiments(page_token=None, max_results=int(1e15))
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
Example 19
def test_log_metrics(tracking_uri_mock):
    active_run = start_run()
    run_uuid = active_run.info.run_uuid
    expected_metrics = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    with active_run:
        mlflow.log_metrics(expected_metrics)
    finished_run = tracking.MlflowClient().get_run(run_uuid)
    # Validate metric key/values match what we expect, and that all metrics have the same timestamp
    common_timestamp = finished_run.data.metrics[0].timestamp
    assert len(finished_run.data.metrics) == len(expected_metrics)
    for metric in finished_run.data.metrics:
        assert expected_metrics[metric.key] == metric.value
        assert metric.timestamp == common_timestamp
Example 20
def test_log_params_duplicate_keys_raises():
    params = {"a": "1", "b": "2"}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_params(params)
        with pytest.raises(
            expected_exception=MlflowException,
            match=r"Changing param values is not allowed. Param with key=",
        ) as e:
            mlflow.log_param("a", "3")
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    finished_run = tracking.MlflowClient().get_run(run_id)
    assert finished_run.data.params == params
Example 21
def test_log_batch(tracking_uri_mock, tmpdir):
    expected_metrics = {"metric-key0": 1.0, "metric-key1": 4.0}
    expected_params = {"param-key0": "param-val0", "param-key1": "param-val1"}
    exact_expected_tags = {"tag-key0": "tag-val0", "tag-key1": "tag-val1"}
    approx_expected_tags = set([MLFLOW_USER, MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE])

    t = int(time.time())
    sorted_expected_metrics = sorted(expected_metrics.items(), key=lambda kv: kv[0])
    metrics = [Metric(key=key, value=value, timestamp=t, step=i)
               for i, (key, value) in enumerate(sorted_expected_metrics)]
    params = [Param(key=key, value=value) for key, value in expected_params.items()]
    tags = [RunTag(key=key, value=value) for key, value in exact_expected_tags.items()]

    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.tracking.MlflowClient().log_batch(run_id=run_id, metrics=metrics, params=params,
                                                 tags=tags)
    client = tracking.MlflowClient()
    finished_run = client.get_run(run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 2
    for key, value in finished_run.data.metrics.items():
        assert expected_metrics[key] == value
    metric_history0 = client.get_metric_history(run_id, "metric-key0")
    assert set([(m.value, m.timestamp, m.step) for m in metric_history0]) == set([
        (1.0, t, 0),
    ])
    metric_history1 = client.get_metric_history(run_id, "metric-key1")
    assert set([(m.value, m.timestamp, m.step) for m in metric_history1]) == set([
        (4.0, t, 1),
    ])

    # Validate tags (for automatically-set tags)
    assert len(finished_run.data.tags) == len(exact_expected_tags) + len(approx_expected_tags)
    for tag_key, tag_value in finished_run.data.tags.items():
        if tag_key in approx_expected_tags:
            pass
        else:
            assert exact_expected_tags[tag_key] == tag_value
    # Validate params
    assert finished_run.data.params == expected_params
    # test that log_batch works with fewer params
    new_tags = {"1": "2", "3": "4", "5": "6"}
    tags = [RunTag(key=key, value=value) for key, value in new_tags.items()]
    client.log_batch(run_id=run_id, tags=tags)
    finished_run_2 = client.get_run(run_id)
    # Validate tags (for automatically-set tags)
    assert len(finished_run_2.data.tags) == len(finished_run.data.tags) + 3
    for tag_key, tag_value in finished_run_2.data.tags.items():
        if tag_key in new_tags:
            assert new_tags[tag_key] == tag_value
Example 22
def test_set_tags():
    exact_expected_tags = {"name_1": "c", "name_2": "b", "nested/nested/name": 5}
    approx_expected_tags = set([MLFLOW_USER, MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE])
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.set_tags(exact_expected_tags)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate tags
    assert len(finished_run.data.tags) == len(exact_expected_tags) + len(approx_expected_tags)
    for tag_key, tag_val in finished_run.data.tags.items():
        if tag_key in approx_expected_tags:
            pass
        else:
            assert str(exact_expected_tags[tag_key]) == tag_val
Example 23
def test_create_experiment_with_duplicate_name():
    name = "popular_name"
    exp_id = mlflow.create_experiment(name)

    with pytest.raises(
            MlflowException,
            match=re.escape(f"Experiment(name={name}) already exists")):
        mlflow.create_experiment(name)

    tracking.MlflowClient().delete_experiment(exp_id)
    with pytest.raises(
            MlflowException,
            match=re.escape(f"Experiment(name={name}) already exists")):
        mlflow.create_experiment(name)
Example 24
def test_log_metric(tracking_uri_mock):
    active_run = start_run()
    run_uuid = active_run.info.run_uuid
    with active_run:
        mlflow.log_metric("name_1", 25)
        mlflow.log_metric("name_2", -3)
        mlflow.log_metric("name_1", 30)
        mlflow.log_metric("nested/nested/name", 40)
    finished_run = tracking.MlflowClient().get_run(run_uuid)
    # Validate metrics
    assert len(finished_run.data.metrics) == 3
    expected_pairs = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    for metric in finished_run.data.metrics:
        assert expected_pairs[metric.key] == metric.value
Example 25
def test_start_run_context_manager():
    try:
        tracking.set_tracking_uri(tempfile.mkdtemp())
        first_run = start_run()
        first_uuid = first_run.info.run_uuid
        with first_run:
            # Check that start_run() causes the run information to be persisted in the store
            persisted_run = tracking.MlflowClient().get_run(first_uuid)
            assert persisted_run is not None
            assert persisted_run.info == first_run.info
        finished_run = tracking.MlflowClient().get_run(first_uuid)
        assert finished_run.info.status == RunStatus.FINISHED
        # Launch a separate run that fails, verify the run status is FAILED and the run UUID is
        # different
        second_run = start_run()
        assert second_run.info.run_uuid != first_uuid
        with pytest.raises(Exception):
            with second_run:
                raise Exception("Failing run!")
        finished_run2 = tracking.MlflowClient().get_run(second_run.info.run_uuid)
        assert finished_run2.info.status == RunStatus.FAILED
    finally:
        tracking.set_tracking_uri(None)
Example 26
def test_log_param(tracking_uri_mock):
    print(tracking.get_tracking_uri())
    active_run = start_run()
    run_uuid = active_run.info.run_uuid
    with active_run:
        mlflow.log_param("name_1", "a")
        mlflow.log_param("name_2", "b")
        mlflow.log_param("name_1", "c")
        mlflow.log_param("nested/nested/name", 5)
    finished_run = tracking.MlflowClient().get_run(run_uuid)
    # Validate params
    assert len(finished_run.data.params) == 3
    expected_pairs = {"name_1": "c", "name_2": "b", "nested/nested/name": "5"}
    for param in finished_run.data.params:
        assert expected_pairs[param.key] == param.value
Example 27
def test_set_experiment_with_deleted_experiment():
    name = "dead_exp"
    mlflow.set_experiment(name)
    with start_run() as run:
        exp_id = run.info.experiment_id

    tracking.MlflowClient().delete_experiment(exp_id)

    with pytest.raises(MlflowException, match="Cannot set a deleted experiment") as exc:
        mlflow.set_experiment(name)
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with pytest.raises(MlflowException, match="Cannot set a deleted experiment") as exc:
        mlflow.set_experiment(experiment_id=exp_id)
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
Example 28
def test_log_metric():
    with start_run() as active_run, mock.patch("time.time") as time_mock:
        time_mock.side_effect = [123 for _ in range(100)]
        run_id = active_run.info.run_id
        mlflow.log_metric("name_1", 25)
        mlflow.log_metric("name_2", -3)
        mlflow.log_metric("name_1", 30, 5)
        mlflow.log_metric("name_1", 40, -2)
        mlflow.log_metric("nested/nested/name", 40)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 3
    expected_pairs = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    for key, value in finished_run.data.metrics.items():
        assert expected_pairs[key] == value
    client = tracking.MlflowClient()
    metric_history_name1 = client.get_metric_history(run_id, "name_1")
    assert set([(m.value, m.timestamp, m.step) for m in metric_history_name1]) == set(
        [(25, 123 * 1000, 0), (30, 123 * 1000, 5), (40, 123 * 1000, -2)]
    )
    metric_history_name2 = client.get_metric_history(run_id, "name_2")
    assert set([(m.value, m.timestamp, m.step) for m in metric_history_name2]) == set(
        [(-3, 123 * 1000, 0)]
    )
Example 29
def test_log_metrics_uses_millisecond_timestamp_resolution_fluent():
    with start_run() as active_run, mock.patch("time.time") as time_mock:
        time_mock.side_effect = lambda: 123
        mlflow.log_metrics({"name_1": 25, "name_2": -3})
        mlflow.log_metrics({"name_1": 30})
        mlflow.log_metrics({"name_1": 40})
        run_id = active_run.info.run_id

    client = tracking.MlflowClient()
    metric_history_name1 = client.get_metric_history(run_id, "name_1")
    assert set([(m.value, m.timestamp) for m in metric_history_name1]) == set(
        [(25, 123 * 1000), (30, 123 * 1000), (40, 123 * 1000)]
    )
    metric_history_name2 = client.get_metric_history(run_id, "name_2")
    assert set([(m.value, m.timestamp) for m in metric_history_name2]) == set([(-3, 123 * 1000)])
Example 30
def test_log_metrics(tracking_uri_mock, step_kwarg):
    expected_metrics = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    with start_run() as active_run:
        run_uuid = active_run.info.run_uuid
        mlflow.log_metrics(expected_metrics, step=step_kwarg)
    finished_run = tracking.MlflowClient().get_run(run_uuid)
    # Validate metric key/values match what we expect, and that all metrics have the same timestamp
    assert len(finished_run.data.metrics) == len(expected_metrics)
    for key, value in finished_run.data.metrics.items():
        assert expected_metrics[key] == value
    common_timestamp = finished_run.data._metric_objs[0].timestamp
    expected_step = step_kwarg if step_kwarg is not None else 0
    for metric_obj in finished_run.data._metric_objs:
        assert metric_obj.timestamp == common_timestamp
        assert metric_obj.step == expected_step
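
The step_kwarg argument above is a pytest fixture (or parametrized argument) that is not included in the snippet; a minimal sketch, assuming it just cycles through a few candidate step values including the default of None:

import pytest


@pytest.fixture(params=[None, -2, 0, 3])
def step_kwarg(request):
    # Each parametrized test invocation receives one candidate step value.
    return request.param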