Пример #1
0
def test_set_active_run(context, run_id):
    """Check the ``_start_run`` calls issued by ``_set_active_run``.

    ``_setup`` is patched out during construction and ``_start_run`` is
    replaced by a mock, so only the call pattern is inspected:

    * with a ``parent_run_id`` set, two calls are expected -- one for the
      parent run and one for ``run_id`` with ``nested=True``;
    * without one, a single call with ``nested=False`` is expected.
    """
    # Given: an MlFlow object built from the context, with _setup patched out
    with patch.object(MlFlow, "_setup"):
        mlf = MlFlow(context)

    # When: _set_active_run is invoked while _start_run is mocked
    with patch.object(MlFlow, "_start_run") as mock_start_run:
        mlf._set_active_run(run_id=run_id)  # pylint: disable=protected-access

    # Then (run is not nested): a single, non-nested start for the run itself
    if mlf.parent_run_id is None:
        mock_start_run.assert_called_once_with(
            run_id=run_id, run_name=mlf.run_name, nested=False
        )
        return

    # Then (run is nested):
    # - the parent run is started
    parent_kwargs = {"run_id": mlf.parent_run_id, "run_name": mlf.run_name}
    mock_start_run.assert_any_call(**parent_kwargs)
    # - the requested run is started with nested=True
    child_kwargs = {"run_id": run_id, "run_name": mlf.run_name, "nested": True}
    mock_start_run.assert_any_call(**child_kwargs)
    # - nothing else was started
    assert mock_start_run.call_count == 2
Пример #2
0
def test_log_params(mock_log_params, context, num_of_params, string_maker):
    """Check how often the mocked log-params hook fires for one ``log_params`` call.

    The expected call count is ``num_of_params`` divided by 100, rounded up.
    """
    # Given: an MlFlow object
    mlf = MlFlow(context)
    # And: a dictionary built from num_of_params generated key/value pairs
    params = {string_maker(5): string_maker(5) for _ in range(num_of_params)}

    # When: the parameters are logged
    mlf.log_params(params)

    # Then: one call per full group of 100, plus one for any remainder
    full_batches, leftover = divmod(num_of_params, 100)
    expected_calls = full_batches + (1 if leftover else 0)
    assert mock_log_params.call_count == expected_calls
Пример #3
0
def test_get_current_run_id(context, experiment, run_df):
    """Check ``_get_current_run_id`` against a patched ``mlflow.search_runs``.

    ``mlflow.search_runs`` is patched to return ``run_df``. An empty frame
    must yield ``None``; otherwise the first ``run_id`` value is expected.
    """
    # Given: an initialised MlFlow object
    mlf = MlFlow(context)

    # When: the current run id is looked up with search_runs patched
    with patch("mlflow.search_runs", return_value=run_df):
        found_run_id = mlf._get_current_run_id(  # pylint: disable=protected-access
            experiment=experiment
        )

    # Then: no runs means no id, otherwise the first run id of the frame
    if run_df.empty:
        assert found_run_id is None
    else:
        assert found_run_id == run_df.run_id.values[0]
Пример #4
0
def test_chunks(context, num_of_params, string_maker, chunk):
    """Check that ``MlFlow.chunks`` splits a dictionary without losing entries.

    The number of chunks must equal ``num_of_params / chunk`` rounded up, and
    merging the chunks back together must reproduce the original dictionary.
    """
    # Given: an MlFlow object
    mlf = MlFlow(context)
    # And: a dictionary of generated key/value pairs
    original = {string_maker(5): string_maker(5) for _ in range(num_of_params)}

    # When: the dictionary is chunked
    pieces = list(mlf.chunks(original, chunk))

    # Then: the number of chunks is the rounded-up quotient
    whole, remainder = divmod(num_of_params, chunk)
    assert len(pieces) == whole + (1 if remainder else 0)

    # And: re-assembling the chunks yields the dictionary we started from
    reassembled = {}
    for piece in pieces:
        reassembled.update(piece.items())
    assert reassembled == original
Пример #5
0
def test_set_all_tags(mock_mlflow_set_tags, context):
    """Check the tag dictionary that ``_set_all_tags`` hands to the set-tags mock.

    The expected tags are: one entry per name in ``env_to_tag`` (value taken
    from the ``env`` section of the resource config), the ``dagster_run_id``,
    and any ``extra_tags`` configured on the object.
    """
    # Given: an MlFlow object built from the context, with _setup patched out
    with patch.object(MlFlow, "_setup"):
        mlf = MlFlow(context)

    # When: all tags are set
    mlf._set_all_tags()  # pylint: disable=protected-access

    # Then: build the tags we expect to have been sent
    expected_tags = {}
    for tag_name in context.resource_config["env_to_tag"]:
        expected_tags[tag_name] = context.resource_config["env"][tag_name]
    expected_tags["dagster_run_id"] = mlf.dagster_run_id
    if mlf.extra_tags:
        expected_tags.update(mlf.extra_tags)

    # And: verify they were passed in a single call
    mock_mlflow_set_tags.assert_called_once_with(expected_tags)
Пример #6
0
def test_set_active_run_parent_zero(child_context):
    """Check that a ``parent_run_id`` of ``0`` is still treated as a parent run.

    ``0`` is falsy, so this exercises the edge case where a truthiness check
    on the parent id would wrongly skip the nested-run path.
    """
    edge_case_run_id = "what-is-an-edge-case"

    # Given: a parent_run_id of zero in the resource config
    child_context.resource_config["parent_run_id"] = 0
    # And: an MlFlow object initialised from that context
    mlf = MlFlow(child_context)

    with patch.object(MlFlow, "_start_run") as mock_start_run:
        # When: _set_active_run is called with a run id
        mlf._set_active_run(run_id=edge_case_run_id)  # pylint: disable=protected-access

        # Then: the parent run is started
        parent_kwargs = {"run_id": mlf.parent_run_id, "run_name": mlf.run_name}
        mock_start_run.assert_any_call(**parent_kwargs)

        # And: the requested run is started with nested=True
        child_kwargs = {
            "run_id": edge_case_run_id,
            "run_name": mlf.run_name,
            "nested": True,
        }
        mock_start_run.assert_any_call(**child_kwargs)

        # And: _start_run was called exactly twice
        assert mock_start_run.call_count == 2
Пример #7
0
def test_cleanup_on_error(
        mock_mlflow_end_run,
        any_error,
        context,
        cleanup_mlflow_runs  # pylint: disable=unused-argument
):
    """Check the status that ``cleanup_on_error`` passes to the end-run mock.

    ``sys.exc_info`` is patched to report ``any_error`` as the active
    exception value. Expected outcomes:

    * falsy ``any_error``  -> the end-run mock is never called;
    * ``KeyboardInterrupt`` -> called once with ``status="KILLED"``;
    * any other error      -> called once with ``status="FAILED"``.
    """
    # Given: an MlFlow object built from the context, with _setup patched out
    with patch.object(MlFlow, "_setup"):
        mlf = MlFlow(context)
    # And: a started run
    mlf.start_run()  # pylint: disable=no-member

    # When: cleanup_on_error runs while exc_info reports `any_error`
    with patch("sys.exc_info", return_value=[0, any_error]):
        mlf.cleanup_on_error()

    # Then: with no error flagged, the run must not be ended
    if not any_error:
        mock_mlflow_end_run.assert_not_called()
        return

    # Otherwise the run is ended once; an interrupt counts as KILLED,
    # everything else as FAILED
    expected_status = (
        "KILLED" if isinstance(any_error, KeyboardInterrupt) else "FAILED"
    )
    mock_mlflow_end_run.assert_called_once_with(status=expected_status)
Пример #8
0
def test_setup(mock_atexit, context):
    """Check which collaborators ``_setup`` invokes.

    The object is constructed with ``_setup`` patched out; ``_setup`` is then
    called for real with ``_get_current_run_id``, ``_set_active_run`` and
    ``_set_all_tags`` mocked so that only the orchestration is verified.
    """
    # Given: an MlFlow object built from the context, with _setup patched out
    with patch.object(MlFlow, "_setup"):
        mlf = MlFlow(context)

    with patch.object(
        MlFlow, "_get_current_run_id", return_value="run_id_mock"
    ) as mock_get_current_run_id:
        with patch.object(MlFlow, "_set_active_run") as mock_set_active_run:
            with patch.object(MlFlow, "_set_all_tags") as mock_set_all_tags:
                # When: the real _setup is executed
                mlf._setup()  # pylint: disable=protected-access

                # Then: the current run id was looked up
                mock_get_current_run_id.assert_called()
                # And: the id it returned was activated exactly once
                mock_set_active_run.assert_called_once_with(run_id="run_id_mock")
                # And: the tags were set exactly once
                mock_set_all_tags.assert_called_once()

    # And: the mocked atexit hook received mlf.end_run
    # (presumably atexit.unregister -- confirm against the fixture)
    mock_atexit.assert_called_once_with(mlf.end_run)  # pylint: disable=no-member
Пример #9
0
def test_mlflow_constructor_basic(
    mock_mlflowclient,
    mock_set_experiment,
    mock_get_experiment_by_name,
    mock_set_tracking_uri,
    mock_environ_update,
    context,
):
    """Check everything ``MlFlow.__init__`` sets up from the context.

    ``_setup`` is patched out, so this verifies only the constructor: that
    ``_setup`` is invoked once, that context and resource-config values are
    copied onto the object, and that the mocked mlflow/environment hooks are
    called (or not) according to the configuration.
    """
    with patch.object(MlFlow, "_setup") as mock_setup:
        # Given: a context passed into the __init__ for MlFlow
        mlf = MlFlow(context)
        # Then: _setup() is called once
        mock_setup.assert_called_once()

    # - every name listed by dir(MlFlow), except "__name__", is available on
    #   the instance. An exact comparison is used on purpose: testing against
    #   ("__name__") would be a substring check on a plain string.
    assert all(
        hasattr(mlf, attr) for attr in dir(MlFlow) if attr != "__name__"
    )

    # - the context associated attributes passed have been set
    assert mlf.log == context.log
    assert mlf.run_name == context.pipeline_run.pipeline_name
    assert mlf.dagster_run_id == context.run_id

    # - the tracking URI is the same as what was passed
    assert mlf.tracking_uri == context.resource_config.get(
        "mlflow_tracking_uri")
    # - the tracking URI was handed to mlflow only when one is configured
    if mlf.tracking_uri:
        mock_set_tracking_uri.assert_called_once_with(mlf.tracking_uri)
    else:
        mock_set_tracking_uri.assert_not_called()

    # - the resource config attributes have been set
    assert mlf.parent_run_id == context.resource_config.get("parent_run_id")
    assert mlf.experiment_name == context.resource_config.get(
        "experiment_name")
    assert mlf.env_tags_to_log == context.resource_config.get("env_to_tag", [])
    assert mlf.extra_tags == context.resource_config.get("extra_tags")
    assert mlf.env_vars == context.resource_config.get("env", {})

    # - the environment is updated only when env vars are configured
    if mlf.env_vars:
        mock_environ_update.assert_called_once_with(mlf.env_vars)
    else:
        mock_environ_update.assert_not_called()

    # - the experiment is set and then looked up by name. The lookup is
    #   asserted rather than merely invoked on the mock, which would verify
    #   nothing.
    mock_set_experiment.assert_called_once_with(mlf.experiment_name)
    mock_get_experiment_by_name.assert_called_once_with(mlf.experiment_name)

    # - mlflow.tracking.MlflowClient has been called
    mock_mlflowclient.assert_called_once()
Пример #10
0
def test_start_run(mock_start_run, context):
    """Check that ``_start_run`` forwards the run id to the start-run mock.

    A second call with the same run id is made afterwards; it must simply
    complete without raising.
    """
    # Given: an MlFlow object built from the context, with _setup patched out
    with patch.object(MlFlow, "_setup"):
        mlf = MlFlow(context)

    # When: a run is started under a freshly generated id
    generated_run_id = str(uuid.uuid4())
    mlf._start_run(run_id=generated_run_id)  # pylint: disable=protected-access

    # Then: the start-run mock received exactly that id
    mock_start_run.assert_called_once_with(run_id=generated_run_id)

    # And: starting the same run id again raises no exception
    mlf._start_run(run_id=generated_run_id)  # pylint: disable=protected-access