Example #1
def _print_description_and_log_tags(self):
    _logger.info(
        "=== Launched MLflow run as Databricks job run with ID %s."
        " Getting run status page URL... ===", self._databricks_run_id)
    run_info = self._job_runner.jobs_runs_get(self._databricks_run_id)
    jobs_page_url = run_info["run_page_url"]
    _logger.info("=== Check the run's status at %s ===", jobs_page_url)
    host_creds = databricks_utils.get_databricks_host_creds(
        self._job_runner.databricks_profile_uri)
    tracking.MlflowClient().set_tag(self._mlflow_run_id,
                                    MLFLOW_DATABRICKS_RUN_URL,
                                    jobs_page_url)
    tracking.MlflowClient().set_tag(self._mlflow_run_id,
                                    MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID,
                                    self._databricks_run_id)
    tracking.MlflowClient().set_tag(self._mlflow_run_id,
                                    MLFLOW_DATABRICKS_WEBAPP_URL,
                                    host_creds.host)
    job_id = run_info.get('job_id')
    # Some Databricks releases do not return the job ID; it is included in
    # releases 2.80 and above.
    if job_id is not None:
        tracking.MlflowClient().set_tag(self._mlflow_run_id,
                                        MLFLOW_DATABRICKS_SHELL_JOB_ID,
                                        job_id)
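The method above only writes the Databricks bookkeeping tags; they can be read back through the same tracking client. A minimal sketch, assuming `tracking` is importable as in the snippets above and that `run_id` names an existing MLflow run (the substring filter on tag keys is an assumption, not part of the original code):

def print_databricks_tags(run_id):
    # Fetch the run and show any Databricks-related tags set by the launcher above.
    run = tracking.MlflowClient().get_run(run_id)
    for key, value in sorted(run.data.tags.items()):
        if "databricks" in key:
            print(key, "=", value)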
Example #2
def _build_docker_image(work_dir, repository_uri, base_image, run_id):
    """
    Build a docker image containing the project in `work_dir`, using the base image.
    """
    image_uri = _get_docker_image_uri(repository_uri=repository_uri,
                                      work_dir=work_dir)
    dockerfile = ("FROM {imagename}\n"
                  "COPY {build_context_path}/ {workdir}\n"
                  "WORKDIR {workdir}\n").format(
                      imagename=base_image,
                      build_context_path=_PROJECT_TAR_ARCHIVE_NAME,
                      workdir=_MLFLOW_DOCKER_WORKDIR_PATH)
    build_ctx_path = _create_docker_build_ctx(work_dir, dockerfile)
    with open(build_ctx_path, 'rb') as docker_build_ctx:
        _logger.info("=== Building docker image %s ===", image_uri)
        client = docker.from_env()
        image, _ = client.images.build(tag=image_uri,
                                       forcerm=True,
                                       dockerfile=posixpath.join(
                                           _PROJECT_TAR_ARCHIVE_NAME,
                                           _GENERATED_DOCKERFILE_NAME),
                                       fileobj=docker_build_ctx,
                                       custom_context=True,
                                       encoding="gzip")
    try:
        os.remove(build_ctx_path)
    except Exception:  # pylint: disable=broad-except
        _logger.info("Temporary docker context file %s was not deleted.",
                     build_ctx_path)
    tracking.MlflowClient().set_tag(run_id, MLFLOW_DOCKER_IMAGE_URI, image_uri)
    tracking.MlflowClient().set_tag(run_id, MLFLOW_DOCKER_IMAGE_ID, image.id)
    return image
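A hedged usage sketch for the builder above; the work directory, repository URI, and base image are placeholder values, and `active_run` is assumed to be an already-created MLflow run:

image = _build_docker_image(
    work_dir="/path/to/project",     # local project checkout (placeholder)
    repository_uri="my-project",     # used to derive the image tag (placeholder)
    base_image="python:3.7",         # becomes the FROM line of the generated Dockerfile
    run_id=active_run.info.run_id)   # run that receives the image URI/ID tags
print(image.id, image.tags)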
Example #3
def test_log_metric():
    with start_run() as active_run, mock.patch("time.time") as time_mock:
        time_mock.side_effect = [123 for _ in range(100)]
        run_id = active_run.info.run_id
        kiwi.log_metric("name_1", 25)
        kiwi.log_metric("name_2", -3)
        kiwi.log_metric("name_1", 30, 5)
        kiwi.log_metric("name_1", 40, -2)
        kiwi.log_metric("nested/nested/name", 40)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 3
    expected_pairs = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    for key, value in finished_run.data.metrics.items():
        assert expected_pairs[key] == value
    client = tracking.MlflowClient()
    metric_history_name1 = client.get_metric_history(run_id, "name_1")
    assert set([(m.value, m.timestamp, m.step)
                for m in metric_history_name1]) == set([
                    (25, 123 * 1000, 0),
                    (30, 123 * 1000, 5),
                    (40, 123 * 1000, -2),
                ])
    metric_history_name2 = client.get_metric_history(run_id, "name_2")
    assert set([(m.value, m.timestamp, m.step)
                for m in metric_history_name2]) == set([
                    (-3, 123 * 1000, 0),
                ])
Example #4
def _maybe_set_run_terminated(active_run, status):
    """
    If the passed-in active run is defined and still running (i.e. hasn't already been terminated
    within user code), mark it as terminated with the passed-in status.
    """
    if active_run is None:
        return
    run_id = active_run.info.run_id
    cur_status = tracking.MlflowClient().get_run(run_id).info.status
    if RunStatus.is_terminated(cur_status):
        return
    tracking.MlflowClient().set_terminated(run_id, status)
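A minimal usage sketch (Example #20 below makes the same calls in context); the status argument is a terminal run-status string such as "FINISHED" or "FAILED":

_maybe_set_run_terminated(active_run, "FINISHED")  # no-op if the run already finished in user code
_maybe_set_run_terminated(None, "FAILED")          # also a no-op when no tracked run exists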
Example #5
def test_log_metrics_uses_millisecond_timestamp_resolution_fluent():
    with start_run() as active_run, mock.patch("time.time") as time_mock:
        time_mock.side_effect = lambda: 123
        kiwi.log_metrics({
            "name_1": 25,
            "name_2": -3,
        })
        kiwi.log_metrics({
            "name_1": 30,
        })
        kiwi.log_metrics({
            "name_1": 40,
        })
        run_id = active_run.info.run_id

    client = tracking.MlflowClient()
    metric_history_name1 = client.get_metric_history(run_id, "name_1")
    assert set([(m.value, m.timestamp) for m in metric_history_name1]) == set([
        (25, 123 * 1000),
        (30, 123 * 1000),
        (40, 123 * 1000),
    ])
    metric_history_name2 = client.get_metric_history(run_id, "name_2")
    assert set([(m.value, m.timestamp) for m in metric_history_name2]) == set([
        (-3, 123 * 1000),
    ])
Example #6
def test_list_experiments():
    def _assert_exps(ids_to_lifecycle_stage, view_type_arg):
        result = set([
            (exp.experiment_id, exp.lifecycle_stage)
            for exp in client.list_experiments(view_type=view_type_arg)
        ])
        assert result == set([
            (exp_id, stage)
            for exp_id, stage in ids_to_lifecycle_stage.items()
        ])

    experiment_id = kiwi.create_experiment("exp_1")
    assert experiment_id == '1'
    client = tracking.MlflowClient()
    _assert_exps({
        '0': LifecycleStage.ACTIVE,
        '1': LifecycleStage.ACTIVE
    }, ViewType.ACTIVE_ONLY)
    _assert_exps({
        '0': LifecycleStage.ACTIVE,
        '1': LifecycleStage.ACTIVE
    }, ViewType.ALL)
    _assert_exps({}, ViewType.DELETED_ONLY)
    client.delete_experiment(experiment_id)
    _assert_exps({'0': LifecycleStage.ACTIVE}, ViewType.ACTIVE_ONLY)
    _assert_exps({
        '0': LifecycleStage.ACTIVE,
        '1': LifecycleStage.DELETED
    }, ViewType.ALL)
    _assert_exps({'1': LifecycleStage.DELETED}, ViewType.DELETED_ONLY)
Example #7
def get_or_create_run(run_id, uri, experiment_id, work_dir, version,
                      entry_point, parameters):
    if run_id:
        return tracking.MlflowClient().get_run(run_id)
    else:
        return _create_run(uri, experiment_id, work_dir, version, entry_point,
                           parameters)
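A hedged usage sketch: a truthy `run_id` resumes the existing run, otherwise a new run is created via `_create_run` (Example #24). All argument values below are placeholders:

run = get_or_create_run(run_id=None, uri=".", experiment_id="0",
                        work_dir=".", version=None,
                        entry_point="main", parameters={})
print(run.info.run_id)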
Example #8
def test_log_batch_validates_entity_names_and_values():
    bad_kwargs = {
        "metrics": [
            [Metric(key="../bad/metric/name", value=0.3, timestamp=3, step=0)],
            [
                Metric(key="ok-name",
                       value="non-numerical-value",
                       timestamp=3,
                       step=0)
            ],
            [
                Metric(key="ok-name",
                       value=0.3,
                       timestamp="non-numerical-timestamp",
                       step=0)
            ],
        ],
        "params": [[Param(key="../bad/param/name", value="my-val")]],
        "tags": [[Param(key="../bad/tag/name", value="my-val")]],
    }
    with start_run() as active_run:
        for kwarg, bad_values in bad_kwargs.items():
            for bad_kwarg_value in bad_values:
                final_kwargs = {
                    "run_id": active_run.info.run_id,
                    "metrics": [],
                    "params": [],
                    "tags": [],
                }
                final_kwargs[kwarg] = bad_kwarg_value
                with pytest.raises(MlflowException) as e:
                    tracking.MlflowClient().log_batch(**final_kwargs)
                assert e.value.error_code == ErrorCode.Name(
                    INVALID_PARAMETER_VALUE)
Example #9
def _resolve_experiment_id(experiment_name=None, experiment_id=None):
    """
    Resolve and return an experiment ID.

    Verifies that at most one of ``experiment_name`` and ``experiment_id`` is
    specified; passing both raises an ``MlflowException``.

    If ``experiment_name`` is provided and does not exist, an experiment
    of that name is created and its id is returned.

    :param experiment_name: Name of experiment under which to launch the run.
    :param experiment_id: ID of experiment under which to launch the run.
    :return: str
    """

    if experiment_name and experiment_id:
        raise MlflowException(
            "Specify only one of 'experiment_name' or 'experiment_id'.")

    if experiment_id:
        return str(experiment_id)

    if experiment_name:
        client = tracking.MlflowClient()
        exp = client.get_experiment_by_name(experiment_name)
        if exp:
            return exp.experiment_id
        else:
            print(
                "INFO: '{}' does not exist. Creating a new experiment".format(
                    experiment_name))
            return client.create_experiment(experiment_name)

    return _get_experiment_id()
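A short sketch of the three resolution paths; the name and ID below are placeholders:

_resolve_experiment_id(experiment_id=3)            # -> "3"
_resolve_experiment_id(experiment_name="my-exp")   # existing ID, or the ID of a newly created experiment
_resolve_experiment_id()                           # falls back to _get_experiment_id()
# Passing both experiment_name and experiment_id raises MlflowException.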
Example #10
def test_log_metric_validation():
    with start_run() as active_run:
        run_id = active_run.info.run_id
        with pytest.raises(MlflowException) as e:
            kiwi.log_metric("name_1", "apple")
    assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    finished_run = tracking.MlflowClient().get_run(run_id)
    assert len(finished_run.data.metrics) == 0
Example #11
def test_start_and_end_run():
    # Start a run with the start_run() context manager and verify the logged metric
    # is persisted once the run ends.

    with start_run() as active_run:
        kiwi.log_metric("name_1", 25)
    finished_run = tracking.MlflowClient().get_run(active_run.info.run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 1
    assert finished_run.data.metrics["name_1"] == 25
Example #12
def test_start_run_context_manager():
    with start_run() as first_run:
        first_uuid = first_run.info.run_id
        # Check that start_run() causes the run information to be persisted in the store
        persisted_run = tracking.MlflowClient().get_run(first_uuid)
        assert persisted_run is not None
        assert persisted_run.info == first_run.info
    finished_run = tracking.MlflowClient().get_run(first_uuid)
    assert finished_run.info.status == RunStatus.to_string(RunStatus.FINISHED)
    # Launch a separate run that fails, verify the run status is FAILED and the run UUID is
    # different
    with pytest.raises(Exception):
        with start_run() as second_run:
            second_run_id = second_run.info.run_id
            raise Exception("Failing run!")
    assert second_run_id != first_uuid
    finished_run2 = tracking.MlflowClient().get_run(second_run_id)
    assert finished_run2.info.status == RunStatus.to_string(RunStatus.FAILED)
Example #13
def test_set_experiment_with_deleted_experiment_name():
    name = "dead_exp"
    kiwi.set_experiment(name)
    with start_run() as run:
        exp_id = run.info.experiment_id

    tracking.MlflowClient().delete_experiment(exp_id)

    with pytest.raises(MlflowException):
        kiwi.set_experiment(name)
Example #14
def test_create_experiment_with_duplicate_name():
    name = "popular_name"
    exp_id = kiwi.create_experiment(name)

    with pytest.raises(MlflowException):
        kiwi.create_experiment(name)

    tracking.MlflowClient().delete_experiment(exp_id)
    with pytest.raises(MlflowException):
        kiwi.create_experiment(name)
Example #15
def test_start_deleted_run():
    run_id = None
    with kiwi.start_run() as active_run:
        run_id = active_run.info.run_id
    tracking.MlflowClient().delete_run(run_id)
    with pytest.raises(MlflowException,
                       match='because it is in the deleted state.'):
        with kiwi.start_run(run_id=run_id):
            pass
    assert kiwi.active_run() is None
Example #16
def test_log_params():
    expected_params = {"name_1": "c", "name_2": "b", "nested/nested/name": 5}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        kiwi.log_params(expected_params)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate params
    assert finished_run.data.params == {
        "name_1": "c",
        "name_2": "b",
        "nested/nested/name": "5"
    }
Example #17
def test_log_param():
    with start_run() as active_run:
        run_id = active_run.info.run_id
        kiwi.log_param("name_1", "a")
        kiwi.log_param("name_2", "b")
        kiwi.log_param("nested/nested/name", 5)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate params
    assert finished_run.data.params == {
        "name_1": "a",
        "name_2": "b",
        "nested/nested/name": "5"
    }
Example #18
@pytest.mark.parametrize("step_kwarg", [None, -10, 5])  # illustrative values; the original parametrization is not shown in this snippet
def test_log_metrics_uses_common_timestamp_and_step_per_invocation(step_kwarg):
    expected_metrics = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        kiwi.log_metrics(expected_metrics, step=step_kwarg)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate metric key/values match what we expect, and that all metrics have the same timestamp
    assert len(finished_run.data.metrics) == len(expected_metrics)
    for key, value in finished_run.data.metrics.items():
        assert expected_metrics[key] == value
    common_timestamp = finished_run.data._metric_objs[0].timestamp
    expected_step = step_kwarg if step_kwarg is not None else 0
    for metric_obj in finished_run.data._metric_objs:
        assert metric_obj.timestamp == common_timestamp
        assert metric_obj.step == expected_step
Example #19
def test_set_tags():
    exact_expected_tags = {
        "name_1": "c",
        "name_2": "b",
        "nested/nested/name": 5
    }
    approx_expected_tags = set(
        [MLFLOW_USER, MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE])
    with start_run() as active_run:
        run_id = active_run.info.run_id
        kiwi.set_tags(exact_expected_tags)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate tags
    assert len(finished_run.data.tags
               ) == len(exact_expected_tags) + len(approx_expected_tags)
    for tag_key, tag_val in finished_run.data.tags.items():
        if tag_key in approx_expected_tags:
            pass
        else:
            assert str(exact_expected_tags[tag_key]) == tag_val
Example #20
def _wait_for(submitted_run_obj):
    """Wait on the passed-in submitted run, reporting its status to the tracking server."""
    run_id = submitted_run_obj.run_id
    active_run = None
    # Note: there's a small chance we fail to report the run's status to the tracking server if
    # we're interrupted before we reach the try block below
    try:
        active_run = tracking.MlflowClient().get_run(
            run_id) if run_id is not None else None
        if submitted_run_obj.wait():
            _logger.info("=== Run (ID '%s') succeeded ===", run_id)
            _maybe_set_run_terminated(active_run, "FINISHED")
        else:
            _maybe_set_run_terminated(active_run, "FAILED")
            raise ExecutionException("Run (ID '%s') failed" % run_id)
    except KeyboardInterrupt:
        _logger.error("=== Run (ID '%s') interrupted, cancelling run ===",
                      run_id)
        submitted_run_obj.cancel()
        _maybe_set_run_terminated(active_run, "FAILED")
        raise
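A hedged sketch of a typical call site: `submitted_run` stands for whatever a project backend's run() returned (see Example #23), and the surrounding names are only for illustration:

submitted_run = backend.run(uri, entry_point, parameters, version,
                            backend_config, experiment_id, tracking_store_uri)
_wait_for(submitted_run)  # blocks; marks the MLflow run FINISHED on success, FAILED otherwise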
Example #21
def test_log_metrics_uses_millisecond_timestamp_resolution_client():
    with start_run() as active_run, mock.patch("time.time") as time_mock:
        time_mock.side_effect = lambda: 123
        mlflow_client = tracking.MlflowClient()
        run_id = active_run.info.run_id

        mlflow_client.log_metric(run_id=run_id, key="name_1", value=25)
        mlflow_client.log_metric(run_id=run_id, key="name_2", value=-3)
        mlflow_client.log_metric(run_id=run_id, key="name_1", value=30)
        mlflow_client.log_metric(run_id=run_id, key="name_1", value=40)

    metric_history_name1 = mlflow_client.get_metric_history(run_id, "name_1")
    assert set([(m.value, m.timestamp) for m in metric_history_name1]) == set([
        (25, 123 * 1000),
        (30, 123 * 1000),
        (40, 123 * 1000),
    ])
    metric_history_name2 = mlflow_client.get_metric_history(run_id, "name_2")
    assert set([(m.value, m.timestamp) for m in metric_history_name2]) == set([
        (-3, 123 * 1000),
    ])
Example #22
def verify_has_parent_id_tag(child_id, expected_parent_id):
    tags = tracking.MlflowClient().get_run(child_id).data.tags
    assert tags[MLFLOW_PARENT_RUN_ID] == expected_parent_id
Example #23
def _run(uri, experiment_id, entry_point, version, parameters, docker_args,
         backend_name, backend_config, use_conda, storage_dir, synchronous):
    """
    Helper that delegates to the project-running method corresponding to the passed-in backend.
    Returns a ``SubmittedRun`` corresponding to the project run.
    """
    tracking_store_uri = tracking.get_tracking_uri()
    # TODO: remove this check once local, databricks, kubernetes execution have been refactored
    # into their own built-in project execution backends.
    if backend_name not in {"local", "databricks", "kubernetes"}:
        backend = loader.load_backend(backend_name)
        if backend:
            submitted_run = backend.run(uri, entry_point, parameters, version,
                                        backend_config, experiment_id,
                                        tracking_store_uri)
            tracking.MlflowClient().set_tag(submitted_run.run_id,
                                            MLFLOW_PROJECT_BACKEND,
                                            backend_name)
            return submitted_run

    work_dir = fetch_and_validate_project(uri, version, entry_point,
                                          parameters)
    project = load_project(work_dir)
    _validate_execution_environment(project, backend_name)

    existing_run_id = None
    if backend_name == "local" and _MLFLOW_LOCAL_BACKEND_RUN_ID_CONFIG in backend_config:
        existing_run_id = backend_config[_MLFLOW_LOCAL_BACKEND_RUN_ID_CONFIG]
    active_run = get_or_create_run(existing_run_id, uri, experiment_id,
                                   work_dir, version, entry_point, parameters)

    if backend_name == "databricks":
        tracking.MlflowClient().set_tag(active_run.info.run_id,
                                        MLFLOW_PROJECT_BACKEND, "databricks")
        from kiwi.projects.databricks import run_databricks
        return run_databricks(remote_run=active_run,
                              uri=uri,
                              entry_point=entry_point,
                              work_dir=work_dir,
                              parameters=parameters,
                              experiment_id=experiment_id,
                              cluster_spec=backend_config)

    elif backend_name == "local":
        tracking.MlflowClient().set_tag(active_run.info.run_id,
                                        MLFLOW_PROJECT_BACKEND, "local")
        command_args = []
        command_separator = " "
        # If a docker_env attribute is defined in MLproject then it takes precedence over conda yaml
        # environments, so the project will be executed inside a docker container.
        if project.docker_env:
            tracking.MlflowClient().set_tag(active_run.info.run_id,
                                            MLFLOW_PROJECT_ENV, "docker")
            _validate_docker_env(project)
            _validate_docker_installation()
            image = _build_docker_image(
                work_dir=work_dir,
                repository_uri=project.name,
                base_image=project.docker_env.get('image'),
                run_id=active_run.info.run_id)
            command_args += _get_docker_command(
                image=image,
                active_run=active_run,
                docker_args=docker_args,
                volumes=project.docker_env.get("volumes"),
                user_env_vars=project.docker_env.get("environment"))
        # Synchronously create a conda environment (even though this may take some time)
        # to avoid failures due to multiple concurrent attempts to create the same conda env.
        elif use_conda:
            tracking.MlflowClient().set_tag(active_run.info.run_id,
                                            MLFLOW_PROJECT_ENV, "conda")
            command_separator = " && "
            conda_env_name = _get_or_create_conda_env(project.conda_env_path)
            command_args += _get_conda_command(conda_env_name)
        # In synchronous mode, run the entry point command in a blocking fashion, sending status
        # updates to the tracking server when finished. Note that the run state may not be
        # persisted to the tracking server if interrupted
        if synchronous:
            command_args += _get_entry_point_command(project, entry_point,
                                                     parameters, storage_dir)
            command_str = command_separator.join(command_args)
            return _run_entry_point(command_str,
                                    work_dir,
                                    experiment_id,
                                    run_id=active_run.info.run_id)
        # Otherwise, invoke `mlflow run` in a subprocess
        return _invoke_mlflow_run_subprocess(work_dir=work_dir,
                                             entry_point=entry_point,
                                             parameters=parameters,
                                             experiment_id=experiment_id,
                                             use_conda=use_conda,
                                             storage_dir=storage_dir,
                                             run_id=active_run.info.run_id)
    elif backend_name == "kubernetes":
        from kiwi.projects import kubernetes as kb
        tracking.MlflowClient().set_tag(active_run.info.run_id,
                                        MLFLOW_PROJECT_ENV, "docker")
        tracking.MlflowClient().set_tag(active_run.info.run_id,
                                        MLFLOW_PROJECT_BACKEND, "kubernetes")
        _validate_docker_env(project)
        _validate_docker_installation()
        kube_config = _parse_kubernetes_config(backend_config)
        image = _build_docker_image(
            work_dir=work_dir,
            repository_uri=kube_config["repository-uri"],
            base_image=project.docker_env.get('image'),
            run_id=active_run.info.run_id)
        image_digest = kb.push_image_to_registry(image.tags[0])
        submitted_run = kb.run_kubernetes_job(
            project.name, active_run, image.tags[0], image_digest,
            _get_entry_point_command(project, entry_point, parameters,
                                     storage_dir),
            _get_run_env_vars(run_id=active_run.info.run_uuid,
                              experiment_id=active_run.info.experiment_id),
            kube_config.get('kube-context', None),
            kube_config['kube-job-template'])
        return submitted_run

    supported_backends = ["local", "databricks", "kubernetes"]
    raise ExecutionException("Got unsupported execution mode %s. Supported "
                             "values: %s" % (backend_name, supported_backends))
Example #24
def _create_run(uri, experiment_id, work_dir, version, entry_point,
                parameters):
    """
    Create a ``Run`` against the current MLflow tracking server, logging metadata (e.g. the URI,
    entry point, and parameters of the project) about the run. Return an ``ActiveRun`` that can be
    used to report additional data about the run (metrics/params) to the tracking server.
    """
    if _is_local_uri(uri):
        source_name = tracking._tracking_service.utils._get_git_url_if_present(
            _expand_uri(uri))
    else:
        source_name = _expand_uri(uri)
    source_version = _get_git_commit(work_dir)
    git_diff = _get_git_diff(work_dir)
    existing_run = fluent.active_run()
    if existing_run:
        parent_run_id = existing_run.info.run_id
    else:
        parent_run_id = None

    tags = {
        MLFLOW_USER: _get_user(),
        MLFLOW_SOURCE_NAME: source_name,
        MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.PROJECT),
        MLFLOW_PROJECT_ENTRY_POINT: entry_point
    }

    # System tags: each value includes a descriptive string header, so none of them is empty
    tags[KIWI_SYSTEM_HW_CPU] = _get_cpu_info()
    tags[KIWI_SYSTEM_HW_MEMORY] = _get_mem_info()
    tags[KIWI_SYSTEM_HW_DISK] = _get_disk_info()
    tags[KIWI_SYSTEM_HW_GPU] = _get_gpu_info()
    tags[KIWI_SYSTEM_OS] = _get_os_info()

    if source_version is not None:
        tags[MLFLOW_GIT_COMMIT] = source_version
        tags[MLFLOW_GIT_DIFF] = git_diff
    if parent_run_id is not None:
        tags[MLFLOW_PARENT_RUN_ID] = parent_run_id

    repo_url = _get_git_repo_url(work_dir)
    if repo_url is not None:
        tags[MLFLOW_GIT_REPO_URL] = repo_url
        tags[LEGACY_MLFLOW_GIT_REPO_URL] = repo_url

    # Add branch name tag if a branch is specified through -version
    if _is_valid_branch_name(work_dir, version):
        tags[MLFLOW_GIT_BRANCH] = version
        tags[LEGACY_MLFLOW_GIT_BRANCH_NAME] = version
    active_run = tracking.MlflowClient().create_run(
        experiment_id=experiment_id, tags=tags)

    project = _project_spec.load_project(work_dir)
    # Consolidate parameters for logging.
    # `storage_dir` is `None` because we want to log the original parameter paths, not downloaded local copies
    entry_point_obj = project.get_entry_point(entry_point)
    final_params, extra_params = entry_point_obj.compute_parameters(
        parameters, storage_dir=None)
    params_list = [
        Param(key, value) for key, value in list(final_params.items()) +
        list(extra_params.items())
    ]
    tracking.MlflowClient().log_batch(active_run.info.run_id,
                                      params=params_list)
    return active_run
Example #25
def test_log_batch():
    expected_metrics = {"metric-key0": 1.0, "metric-key1": 4.0}
    expected_params = {"param-key0": "param-val0", "param-key1": "param-val1"}
    exact_expected_tags = {"tag-key0": "tag-val0", "tag-key1": "tag-val1"}
    approx_expected_tags = set(
        [MLFLOW_USER, MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE])

    t = int(time.time())
    sorted_expected_metrics = sorted(expected_metrics.items(),
                                     key=lambda kv: kv[0])
    metrics = [
        Metric(key=key, value=value, timestamp=t, step=i)
        for i, (key, value) in enumerate(sorted_expected_metrics)
    ]
    params = [
        Param(key=key, value=value) for key, value in expected_params.items()
    ]
    tags = [
        RunTag(key=key, value=value)
        for key, value in exact_expected_tags.items()
    ]

    with start_run() as active_run:
        run_id = active_run.info.run_id
        kiwi.tracking.MlflowClient().log_batch(run_id=run_id,
                                               metrics=metrics,
                                               params=params,
                                               tags=tags)
    client = tracking.MlflowClient()
    finished_run = client.get_run(run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 2
    for key, value in finished_run.data.metrics.items():
        assert expected_metrics[key] == value
    metric_history0 = client.get_metric_history(run_id, "metric-key0")
    assert set([(m.value, m.timestamp, m.step)
                for m in metric_history0]) == set([
                    (1.0, t, 0),
                ])
    metric_history1 = client.get_metric_history(run_id, "metric-key1")
    assert set([(m.value, m.timestamp, m.step)
                for m in metric_history1]) == set([
                    (4.0, t, 1),
                ])

    # Validate tags (for automatically-set tags)
    assert len(finished_run.data.tags
               ) == len(exact_expected_tags) + len(approx_expected_tags)
    for tag_key, tag_value in finished_run.data.tags.items():
        if tag_key in approx_expected_tags:
            pass
        else:
            assert exact_expected_tags[tag_key] == tag_value
    # Validate params
    assert finished_run.data.params == expected_params
    # test that log_batch works with fewer params
    new_tags = {"1": "2", "3": "4", "5": "6"}
    tags = [RunTag(key=key, value=value) for key, value in new_tags.items()]
    client.log_batch(run_id=run_id, tags=tags)
    finished_run_2 = client.get_run(run_id)
    # Validate tags (for automatically-set tags)
    assert len(finished_run_2.data.tags) == len(finished_run.data.tags) + 3
    for tag_key, tag_value in finished_run_2.data.tags.items():
        if tag_key in new_tags:
            assert new_tags[tag_key] == tag_value