Code example #1
def test_node_hook_logging_above_limit_fail_strategy(kedro_project,
                                                     dummy_run_params,
                                                     param_length):
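    """Check that with ``long_params_strategy="fail"`` the hook raises a
    ``ValueError`` when a parameter value exceeds the mlflow length limit."""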

    _write_yaml(
        kedro_project / "conf" / "local" / "mlflow.yml",
        dict(tracking=dict(params=dict(long_params_strategy="fail"))),
    )

    mlflow_tracking_uri = (kedro_project / "mlruns").as_uri()
    mlflow.set_tracking_uri(mlflow_tracking_uri)

    mlflow_node_hook = MlflowHook()

    param_value = param_length * "a"
    node_inputs = {"params:my_param": param_value}

    bootstrap_project(kedro_project)
    with KedroSession.create(project_path=kedro_project) as session:
        context = session.load_context()
        mlflow_node_hook.after_context_created(context)

        with mlflow.start_run():
            mlflow_node_hook.before_pipeline_run(
                run_params=dummy_run_params,
                pipeline=Pipeline([]),
                catalog=DataCatalog(),
            )

            # IMPORTANT: exceeding the parameter length limit should raise an
            # error on every mlflow backend, but the FileStore backend does not
            # enforce it:
            # https://github.com/mlflow/mlflow/issues/2814#issuecomment-628284425
            # Since the tests use FileStore for simplicity, logging would succeed,
            # so the hook enforces the failure itself (which differs slightly
            # from mlflow's behaviour).
            with pytest.raises(
                    ValueError,
                    match=f"Parameter 'my_param' length is {param_length}"):
                mlflow_node_hook.before_node_run(
                    node=node(func=lambda x: x,
                              inputs=dict(x="a"),
                              outputs=None),
                    catalog=DataCatalog(),  # can be empty
                    inputs=node_inputs,
                    is_async=False,
                )
Code example #2
File: test_session_hooks.py  Project: szczeles/kedro
    def test_broken_input_update_parallel(
        self, mock_settings_import, tmp_path, dummy_dataframe
    ):
        mock_settings_import.return_value.HOOKS = (BrokenBeforeNodeRunHook(),)

        session = KedroSession.create(MOCK_PACKAGE_NAME, tmp_path)
        context = session.load_context()
        catalog = context.catalog
        catalog.save("cars", dummy_dataframe)
        catalog.save("boats", dummy_dataframe)

        pattern = (
            "`before_node_run` must return either None or a dictionary "
            "mapping dataset names to updated values, got `MockDatasetReplacement`"
        )
        with pytest.raises(TypeError, match=re.escape(pattern)):
            session.run(runner=ParallelRunner())
Code example #3
    def test_register_pipelines_with_duplicate_entries(self, tmp_path,
                                                       mock_pipelines, mocker):
        mocker.patch("kedro.framework.project._validate_module")
        pattern = ("Found duplicate pipeline entries. The following "
                   "will be overwritten: __default__")
        with pytest.warns(UserWarning, match=re.escape(pattern)):
            configure_project(MOCK_PACKAGE_NAME)

        session = KedroSession.create(MOCK_PACKAGE_NAME, tmp_path)
        context = session.load_context()
        # check that all pipeline dictionaries merged together correctly
        expected_pipelines = {
            key: CONTEXT_PIPELINE
            for key in ("__default__", "de", "pipe")
        }
        assert mock_pipelines == expected_pipelines
        assert context.pipelines == expected_pipelines
Code example #4
def test_kedro_mlflow_config_setup_export_credentials(kedro_project_with_mlflow_conf):
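    """Check that ``KedroMlflowConfig.setup()`` exports the credentials declared
    in ``credentials.yml`` as environment variables."""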

    (kedro_project_with_mlflow_conf / "conf/base/credentials.yml").write_text(
        yaml.dump(dict(my_mlflow_creds=dict(fake_mlflow_cred="my_fake_cred")))
    )

    # the config should export the credentials as environment variables
    config = KedroMlflowConfig(
        server=dict(credentials="my_mlflow_creds"),
    )

    bootstrap_project(kedro_project_with_mlflow_conf)
    with KedroSession.create(project_path=kedro_project_with_mlflow_conf) as session:
        context = session.load_context()  # setup config
        config.setup(context)

    assert os.environ["fake_mlflow_cred"] == "my_fake_cred"
Code example #5
    def test_git_describe_error(
        self, fake_project, exception, mock_package_name, mocker, caplog
    ):
        """Test that git information is not added to the session store
        if call to git fails
        """
        mocker.patch("subprocess.check_output", side_effect=exception)
        session = KedroSession.create(mock_package_name, fake_project)
        assert "git" not in session.store

        expected_log_messages = [f"Unable to git describe {fake_project}"]
        actual_log_messages = [
            rec.getMessage()
            for rec in caplog.records
            if rec.name == SESSION_LOGGER_NAME and rec.levelno == logging.WARN
        ]
        assert actual_log_messages == expected_log_messages
Code example #6
def test_on_pipeline_error(kedro_project_with_mlflow_conf):
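    """Check that a failing node closes the active mlflow run and marks it as FAILED."""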

    tracking_uri = (kedro_project_with_mlflow_conf / "mlruns").as_uri()

    project_metadata = _get_project_metadata(kedro_project_with_mlflow_conf)
    _add_src_to_path(project_metadata.source_dir, kedro_project_with_mlflow_conf)
    configure_project(project_metadata.package_name)
    with KedroSession.create(
        package_name=project_metadata.package_name,
        project_path=kedro_project_with_mlflow_conf,
    ):

        def failing_node():
            mlflow.start_run(nested=True)
            raise ValueError("Let's make this pipeline fail")

        class DummyContextWithHook(KedroContext):
            project_name = "fake project"
            package_name = "fake_project"
            project_version = "0.16.5"

            hooks = (MlflowPipelineHook(),)

            def _get_pipeline(self, name: str = None) -> Pipeline:
                return Pipeline(
                    [
                        node(
                            func=failing_node,
                            inputs=None,
                            outputs="fake_output",
                        )
                    ]
                )

        with pytest.raises(ValueError):
            failing_context = DummyContextWithHook(
                "fake_package", kedro_project_with_mlflow_conf.as_posix()
            )
            failing_context.run()

        # the run we want is the last one in Default experiment
        failing_run_info = MlflowClient(tracking_uri).list_run_infos("0")[0]
        assert mlflow.active_run() is None  # the run must have been closed
        assert failing_run_info.status == RunStatus.to_string(
            RunStatus.FAILED
        )  # it must be marked as failed
Code example #7
File: test_session.py  Project: MerelTheisenQB/kedro
    def test_default_store(self, fake_project, fake_session_id, caplog,
                           mock_package_name):
        session = KedroSession.create(mock_package_name, fake_project)
        assert isinstance(session.store, dict)
        assert session._store.__class__ is BaseSessionStore
        assert session._store._path == (fake_project / "sessions").as_posix()
        assert session._store._session_id == fake_session_id
        session.close()
        expected_log_messages = [
            "`read()` not implemented for `BaseSessionStore`. Assuming empty store.",
            "`save()` not implemented for `BaseSessionStore`. Skipping the step.",
        ]
        actual_log_messages = [
            rec.getMessage() for rec in caplog.records
            if rec.name == STORE_LOGGER_NAME and rec.levelno == logging.INFO
        ]
        assert actual_log_messages == expected_log_messages
Code example #8
File: cli.py  Project: ravindra-siwach/kedro
def run(
    tag,
    env,
    parallel,
    runner,
    is_async,
    node_names,
    to_nodes,
    from_nodes,
    from_inputs,
    load_version,
    pipeline,
    config,
    params,
):
    """Run the pipeline."""
    if parallel and runner:
        raise KedroCliError(
            "Both --parallel and --runner options cannot be used together. "
            "Please use either --parallel or --runner.")
    runner = runner or "SequentialRunner"
    if parallel:
        runner = "ParallelRunner"
    runner_class = load_obj(runner, "kedro.runner")

    tag = _get_values_as_tuple(tag) if tag else tag
    node_names = _get_values_as_tuple(node_names) if node_names else node_names

    package_name = str(Path(__file__).resolve().parent.name)
    with KedroSession.create(package_name, env=env,
                             extra_params=params) as session:
        session.run(
            tags=tag,
            runner=runner_class(is_async=is_async),
            node_names=node_names,
            from_nodes=from_nodes,
            to_nodes=to_nodes,
            from_inputs=from_inputs,
            load_versions=load_version,
            pipeline_name=pipeline,
        )

        # Logging parameters for some e2e tests
        params_to_log = session.load_context().params
        logging.info("Parameters: %s", json.dumps(params_to_log,
                                                  sort_keys=True))
Code example #9
    def test_shelve_store(self, fake_project, fake_session_id, caplog, mocker):
        mocker.patch("pathlib.Path.is_file", return_value=True)
        shelve_location = fake_project / "nested" / "sessions"
        other = KedroSession.create(_FAKE_PACKAGE_NAME, fake_project)
        assert other._store.__class__ is ShelveStore
        assert other._store._path == shelve_location.as_posix()
        assert other._store._location == shelve_location / fake_session_id / "store"
        assert other._store._session_id == fake_session_id
        assert not shelve_location.is_dir()

        other.close()  # session data persisted
        assert shelve_location.is_dir()
        actual_log_messages = [
            rec.getMessage() for rec in caplog.records
            if rec.name == STORE_LOGGER_NAME and rec.levelno == logging.INFO
        ]
        assert not actual_log_messages
Code example #10
def test_node_hook_logging(
    kedro_project,
    dummy_run_params,
    dummy_catalog,
    dummy_pipeline,
    dummy_node,
    flatten,
    expected,
):
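    """Check that node parameters are logged to mlflow, flattened (or not)
    according to the ``dict_params`` settings written to ``mlflow.yml`` below."""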

    _write_yaml(
        kedro_project / "conf" / "base" / "mlflow.yml",
        dict(tracking=dict(params=dict(
            dict_params=dict(flatten=flatten, recursive=False, sep="-")))),
    )

    node_inputs = {
        v: dummy_catalog._data_sets.get(v)
        for k, v in dummy_node._inputs.items()
    }

    mlflow_tracking_uri = (kedro_project / "mlruns").as_uri()

    bootstrap_project(kedro_project)
    with KedroSession.create(project_path=kedro_project) as session:
        context = session.load_context()
        mlflow_node_hook = MlflowHook()
        mlflow_node_hook.after_context_created(context)  # setup mlflow_config
        mlflow.set_tracking_uri(mlflow_tracking_uri)
        with mlflow.start_run():
            mlflow_node_hook.before_pipeline_run(
                run_params=dummy_run_params,
                pipeline=dummy_pipeline,
                catalog=dummy_catalog,
            )
            mlflow_node_hook.before_node_run(
                node=dummy_node,
                catalog=dummy_catalog,
                inputs=node_inputs,
                is_async=False,
            )
            run_id = mlflow.active_run().info.run_id

        mlflow_client = MlflowClient(mlflow_tracking_uri)
        current_run = mlflow_client.get_run(run_id)
        assert current_run.data.params == expected
Code example #11
def test_kedro_mlflow_config_new_experiment_does_not_exists(
    kedro_project_with_mlflow_conf,
):

    config = KedroMlflowConfig(
        server=dict(mlflow_tracking_uri="mlruns"),
        tracking=dict(experiment=dict(name="exp1")),
    )

    bootstrap_project(kedro_project_with_mlflow_conf)
    with KedroSession.create(project_path=kedro_project_with_mlflow_conf) as session:
        context = session.load_context()  # setup config
        config.setup(context)

    assert "exp1" in [
        exp.name for exp in config.server._mlflow_client.list_experiments()
    ]
Code example #12
File: cli.py  Project: Ryusuketa/titanic
def run(
    tag,
    env,
    parallel,
    runner,
    is_async,
    node_names,
    to_nodes,
    from_nodes,
    from_inputs,
    load_version,
    pipeline,
    config,
    params,
    run_mode,
):
    """Run the pipeline."""
    if parallel and runner:
        raise KedroCliError(
            "Both --parallel and --runner options cannot be used together. "
            "Please use either --parallel or --runner."
        )
    runner = runner or "SequentialRunner"
    if parallel:
        runner = "ParallelRunner"
    runner_class = load_obj(runner, "kedro.runner")

    project_hooks.set_mode(run_mode)

    tag = _get_values_as_tuple(tag) if tag else tag
    node_names = _get_values_as_tuple(node_names) if node_names else node_names

    package_name = str(Path(__file__).resolve().parent.name)
    with KedroSession.create(package_name, env=env, extra_params=params) as session:
        session.run(
            tags=tag,
            runner=runner_class(is_async=is_async),
            node_names=node_names,
            from_nodes=from_nodes,
            to_nodes=to_nodes,
            from_inputs=from_inputs,
            load_versions=load_version,
            pipeline_name=pipeline,
        )
Code example #13
def reload_kedro(path, env: str = None, extra_params: Dict[str, Any] = None):
    """Line magic which reloads all Kedro default variables."""

    import kedro.config.default_logger  # noqa: F401 # pylint: disable=unused-import
    from kedro.framework.cli import load_entry_points
    from kedro.framework.project import pipelines
    from kedro.framework.session import KedroSession
    from kedro.framework.session.session import _activate_session
    from kedro.framework.startup import bootstrap_project

    _clear_hook_manager()

    path = path or project_path
    metadata = bootstrap_project(path)

    _remove_cached_modules(metadata.package_name)

    session = KedroSession.create(metadata.package_name,
                                  path,
                                  env=env,
                                  extra_params=extra_params)
    _activate_session(session, force=True)
    logging.debug("Loading the context from %s", str(path))
    context = session.load_context()
    catalog = context.catalog

    get_ipython().push(
        variables={
            "context": context,
            "catalog": catalog,
            "session": session,
            "pipelines": pipelines,
        })

    logging.info("** Kedro project %s", str(metadata.project_name))
    logging.info(
        "Defined global variable `context`, `session`, `catalog` and `pipelines`"
    )

    for line_magic in load_entry_points("line_magic"):
        register_line_magic(needs_local_scope(line_magic))
        logging.info("Registered line magic `%s`",
                     line_magic.__name__)  # type: ignore
Code example #14
def test_node_hook_logging_above_limit_truncate_strategy(
        kedro_project, dummy_run_params, param_length):
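    """Check that the "truncate" strategy logs only the first
    ``MAX_PARAM_VAL_LENGTH`` characters of an over-long parameter."""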

    _write_yaml(
        kedro_project / "conf" / "local" / "mlflow.yml",
        dict(hooks=dict(node=dict(long_parameters_strategy="truncate"))),
    )

    mlflow_tracking_uri = (kedro_project / "mlruns").as_uri()
    mlflow.set_tracking_uri(mlflow_tracking_uri)

    mlflow_node_hook = MlflowNodeHook()

    param_value = param_length * "a"
    node_inputs = {"params:my_param": param_value}

    project_metadata = _get_project_metadata(kedro_project)
    _add_src_to_path(project_metadata.source_dir, kedro_project)
    configure_project(project_metadata.package_name)
    with KedroSession.create(
            package_name=project_metadata.package_name,
            project_path=kedro_project,
    ):
        with mlflow.start_run():
            mlflow_node_hook.before_pipeline_run(
                run_params=dummy_run_params,
                pipeline=Pipeline([]),
                catalog=DataCatalog(),
            )
            mlflow_node_hook.before_node_run(
                node=node(func=lambda x: x, inputs=dict(x="a"), outputs=None),
                catalog=DataCatalog(),  # can be empty
                inputs=node_inputs,
                is_async=False,
                run_id="132",
            )
            run_id = mlflow.active_run().info.run_id

        mlflow_client = MlflowClient(mlflow_tracking_uri)
        current_run = mlflow_client.get_run(run_id)
        assert current_run.data.params == {
            "my_param": param_value[0:MAX_PARAM_VAL_LENGTH]
        }
Code example #15
def test_kedro_mlflow_config_new_experiment_does_not_exists(
    kedro_project_with_mlflow_conf,
):

    config = KedroMlflowConfig(
        project_path=kedro_project_with_mlflow_conf,
        mlflow_tracking_uri="mlruns",
        experiment_opts=dict(name="exp1"),
    )

    project_metadata = _get_project_metadata(kedro_project_with_mlflow_conf)
    _add_src_to_path(project_metadata.source_dir, kedro_project_with_mlflow_conf)
    configure_project(project_metadata.package_name)
    with KedroSession.create(
        "fake_project", project_path=kedro_project_with_mlflow_conf
    ):
        config.setup()

    assert "exp1" in [exp.name for exp in config.mlflow_client.list_experiments()]
Code example #16
File: cli.py  Project: pierreGouedard/datalab
def run(tag, env, node_names, to_nodes, from_nodes, from_inputs, pipeline,
        debug, params, only_missing, is_async):
    """Run the pipeline."""
    runner = DatalabRunner(only_missing, is_async)
    tag = _get_values_as_tuple(tag) if tag else tag
    node_names = _get_values_as_tuple(node_names) if node_names else node_names
    package_name = str(Path(__file__).resolve().parent.name)

    with KedroSession.create(package_name, env=env,
                             extra_params=params) as session:
        session.run(
            tags=tag,
            runner=runner,
            node_names=node_names,
            from_nodes=from_nodes,
            to_nodes=to_nodes,
            from_inputs=from_inputs,
            pipeline_name=pipeline,
        )
Code example #17
    def test_register_pipelines_with_duplicate_entries(self, tmp_path,
                                                       dummy_dataframe):
        session = KedroSession.create(MOCK_PACKAGE_NAME, tmp_path)
        context = session.load_context()
        catalog = context.catalog
        catalog.save("cars", dummy_dataframe)
        catalog.save("boats", dummy_dataframe)

        pattern = ("Found duplicate pipeline entries. The following "
                   "will be overwritten: __default__")
        with pytest.warns(UserWarning, match=re.escape(pattern)):
            session.run()

        # check that all pipeline dictionaries merged together correctly
        expected_pipelines = {
            key: CONTEXT_PIPELINE
            for key in ("__default__", "de", "pipe")
        }
        assert context.pipelines == expected_pipelines
Code example #18
def test_modelify_with_artifact_path_arg(monkeypatch, kp_for_modelify):
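    """Check that running ``modelify`` with ``--artifact-path`` creates a new
    mlflow run whose artifacts are stored under the requested path."""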
    monkeypatch.chdir(kp_for_modelify)

    cli_runner = CliRunner()

    bootstrap_project(Path().cwd())
    with KedroSession.create() as session:
        context = session.load_context()
        catalog = context.catalog
        catalog.save("trained_model", 2)

    runs_id_set_before_cmd = {
        run_info.run_id
        for run_info in context.mlflow.server._mlflow_client.list_run_infos(
            context.mlflow.tracking.experiment._experiment.experiment_id)
    }

    result = cli_runner.invoke(
        cli_modelify,
        [
            "--pipeline",
            "inference",
            "--input-name",
            "my_input_data",
            "--artifact-path",
            "my_new_model",
        ],
        catch_exceptions=True,
    )
    runs_id_set_after_cmd = {
        run_info.run_id
        for run_info in context.mlflow.server._mlflow_client.list_run_infos(
            context.mlflow.tracking.experiment._experiment.experiment_id)
    }

    new_run_id = runs_id_set_after_cmd - runs_id_set_before_cmd

    assert result.exit_code == 0
    assert "my_new_model" in [
        file.path
        for file in context.mlflow.server._mlflow_client.list_artifacts(
            list(new_run_id)[0])
    ]
Code example #19
File: test_session.py  Project: MerelTheisenQB/kedro
    def test_create(
        self,
        fake_project,
        mock_context_class,
        fake_session_id,
        mock_package_name,
        mocker,
        env,
        extra_params,
    ):
        mock_click_ctx = mocker.patch("click.get_current_context").return_value
        session = KedroSession.create(mock_package_name,
                                      fake_project,
                                      env=env,
                                      extra_params=extra_params)

        expected_cli_data = {
            "args": mock_click_ctx.args,
            "params": mock_click_ctx.params,
            "command_name": mock_click_ctx.command.name,
            "command_path": mock_click_ctx.command_path,
        }
        expected_store = {
            "project_path": fake_project,
            "session_id": fake_session_id,
            "package_name": mock_package_name,
            "cli": expected_cli_data,
        }
        if env:
            expected_store["env"] = env
        if extra_params:
            expected_store["extra_params"] = extra_params

        assert session.store == expected_store
        # called for logging setup
        mock_context_class.assert_called_once_with(
            project_path=fake_project,
            package_name=mock_package_name,
            env=env,
            extra_params=extra_params,
        )

        assert session.load_context() is mock_context_class.return_value
Code example #20
File: ipython.py  Project: saadiaminhas/kedro
def load_kedro_objects(path, line=None):  # pylint: disable=unused-argument
    """Line magic which reloads all Kedro default variables."""

    import kedro.config.default_logger  # noqa: F401 # pylint: disable=unused-import
    from kedro.framework.cli import load_entry_points
    from kedro.framework.cli.utils import _add_src_to_path
    from kedro.framework.project import configure_project
    from kedro.framework.session import KedroSession
    from kedro.framework.session.session import _activate_session
    from kedro.framework.startup import _get_project_metadata

    global context
    global catalog
    global session

    path = path or project_path
    metadata = _get_project_metadata(path)
    _add_src_to_path(metadata.source_dir, path)
    configure_project(metadata.package_name)

    _clear_hook_manager()

    _remove_cached_modules(metadata.package_name)

    session = KedroSession.create(metadata.package_name, path)
    _activate_session(session)
    logging.debug("Loading the context from %s", str(path))
    context = session.load_context()
    catalog = context.catalog

    get_ipython().push(variables={
        "context": context,
        "catalog": catalog,
        "session": session
    })

    logging.info("** Kedro project %s", str(metadata.project_name))
    logging.info("Defined global variable `context`, `session` and `catalog`")

    for line_magic in load_entry_points("line_magic"):
        register_line_magic(needs_local_scope(line_magic))
        logging.info("Registered line magic `%s`", line_magic.__name__)
Code example #21
def test_kedro_mlflow_config_experiment_exists(kedro_project_with_mlflow_conf):

    # create an experiment with the same name
    mlflow_tracking_uri = (
        kedro_project_with_mlflow_conf / "conf" / "local" / "mlruns"
    ).as_uri()
    MlflowClient(mlflow_tracking_uri).create_experiment("exp1")
    config = KedroMlflowConfig(
        server=dict(mlflow_tracking_uri="mlruns"),
        tracking=dict(experiment=dict(name="exp1")),
    )

    bootstrap_project(kedro_project_with_mlflow_conf)
    with KedroSession.create(project_path=kedro_project_with_mlflow_conf) as session:
        context = session.load_context()  # setup config
        config.setup(context)

    assert "exp1" in [
        exp.name for exp in config.server._mlflow_client.list_experiments()
    ]
Code example #22
def test_kedro_mlflow_config_setup_export_credentials(kedro_project_with_mlflow_conf):

    (kedro_project_with_mlflow_conf / "conf/base/credentials.yml").write_text(
        yaml.dump(dict(my_mlflow_creds=dict(fake_mlflow_cred="my_fake_cred")))
    )

    # the config should export the credentials as environment variables
    config = KedroMlflowConfig(
        project_path=kedro_project_with_mlflow_conf, credentials="my_mlflow_creds"
    )

    project_metadata = _get_project_metadata(kedro_project_with_mlflow_conf)
    _add_src_to_path(project_metadata.source_dir, kedro_project_with_mlflow_conf)
    configure_project(project_metadata.package_name)
    with KedroSession.create(
        "fake_project", project_path=kedro_project_with_mlflow_conf
    ):
        config.setup()

    assert os.environ["fake_mlflow_cred"] == "my_fake_cred"
Code example #23
def test_on_pipeline_error(kedro_project_with_mlflow_conf):
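    """Check that a failing ``session.run()`` closes the active mlflow run and
    marks it as FAILED in the configured experiment."""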

    tracking_uri = (kedro_project_with_mlflow_conf / "mlruns").as_uri()

    bootstrap_project(kedro_project_with_mlflow_conf)
    with KedroSession.create(
            project_path=kedro_project_with_mlflow_conf) as session:
        context = session.load_context()
        with pytest.raises(ValueError):
            session.run()

        # the run we want is the last one in the configuration experiment
        mlflow_client = MlflowClient(tracking_uri)
        experiment = mlflow_client.get_experiment_by_name(
            context.mlflow.tracking.experiment.name)
        failing_run_info = MlflowClient(tracking_uri).list_run_infos(
            experiment.experiment_id)[0]
        assert mlflow.active_run() is None  # the run must have been closed
        assert failing_run_info.status == RunStatus.to_string(
            RunStatus.FAILED)  # it must be marked as failed
Code example #24
def test_kedro_mlflow_config_setup_set_tracking_uri(kedro_project_with_mlflow_conf):
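    """Check that ``setup()`` sets the mlflow tracking uri declared in the config."""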

    # the tracking uri the config is expected to set
    mlflow_tracking_uri = (kedro_project_with_mlflow_conf / "awesome_tracking").as_uri()

    # the config must set the tracking uri properly
    config = KedroMlflowConfig(
        project_path=kedro_project_with_mlflow_conf,
        mlflow_tracking_uri="awesome_tracking",
        experiment_opts=dict(name="exp1"),
    )

    project_metadata = _get_project_metadata(kedro_project_with_mlflow_conf)
    _add_src_to_path(project_metadata.source_dir, kedro_project_with_mlflow_conf)
    configure_project(project_metadata.package_name)
    with KedroSession.create(
        "fake_project", project_path=kedro_project_with_mlflow_conf
    ):
        config.setup()

    assert mlflow.get_tracking_uri() == mlflow_tracking_uri
Code example #25
def test_kedro_mlflow_config_experiment_exists(mocker, kedro_project_with_mlflow_conf):

    # create an experiment with the same name
    mlflow_tracking_uri = (
        kedro_project_with_mlflow_conf / "conf" / "local" / "mlruns"
    ).as_uri()
    MlflowClient(mlflow_tracking_uri).create_experiment("exp1")
    config = KedroMlflowConfig(
        project_path=kedro_project_with_mlflow_conf,
        mlflow_tracking_uri="mlruns",
        experiment_opts=dict(name="exp1"),
    )

    project_metadata = _get_project_metadata(kedro_project_with_mlflow_conf)
    _add_src_to_path(project_metadata.source_dir, kedro_project_with_mlflow_conf)
    configure_project(project_metadata.package_name)
    with KedroSession.create(
        "fake_project", project_path=kedro_project_with_mlflow_conf
    ):
        config.setup()
    assert "exp1" in [exp.name for exp in config.mlflow_client.list_experiments()]
Code example #26
    def test_shelve_store(self, fake_project, fake_session_id, caplog):
        caplog.set_level(logging.WARN, logger="kedro.framework.session.store")
        kedro_yml = fake_project / ".kedro.yml"
        shelve_location = fake_project / "nested" / "sessions"
        with kedro_yml.open("r+") as f:
            data = yaml.safe_load(f)
            data["session_store"] = {
                "type": "ShelveStore",
                "path": shelve_location.as_posix(),
            }
            yaml.safe_dump(data, f)

        other = KedroSession.create(fake_project)
        assert other._store.__class__ is ShelveStore
        assert other._store._path == shelve_location.as_posix()
        assert other._store._location == shelve_location / fake_session_id / "store"
        assert other._store._session_id == fake_session_id
        assert not shelve_location.is_dir()
        other.close()  # session data persisted
        assert shelve_location.is_dir()
        assert not caplog.records
Code example #27
File: test_session_hooks.py  Project: szczeles/kedro
    def test_before_and_after_node_run_hooks_parallel_runner(
        self, tmp_path, logging_hooks, dummy_dataframe
    ):
        log_records = []
        session = KedroSession.create(MOCK_PACKAGE_NAME, tmp_path)
        context = session.load_context()
        catalog = context.catalog
        catalog.save("cars", dummy_dataframe)
        catalog.save("boats", dummy_dataframe)

        class LogHandler(logging.Handler):  # pylint: disable=abstract-method
            def handle(self, record):
                log_records.append(record)

        logs_queue_listener = QueueListener(logging_hooks.queue, LogHandler())
        logs_queue_listener.start()

        try:
            session.run(runner=ParallelRunner(), node_names=["node1", "node2"])
        finally:
            logs_queue_listener.stop()

        before_node_run_log_records = [
            r for r in log_records if r.funcName == "before_node_run"
        ]
        assert len(before_node_run_log_records) == 2
        for record in before_node_run_log_records:
            assert record.getMessage() == "About to run node"
            assert record.node.name in ["node1", "node2"]
            assert set(record.inputs.keys()) <= {"cars", "boats"}

        after_node_run_log_records = [
            r for r in log_records if r.funcName == "after_node_run"
        ]
        assert len(after_node_run_log_records) == 2
        for record in after_node_run_log_records:
            assert record.getMessage() == "Ran node"
            assert record.node.name in ["node1", "node2"]
            assert set(record.outputs.keys()) <= {"planes", "ships"}
Code example #28
    def test_catalog_and_params(  # pylint: disable=too-many-locals
            self, fake_repo_path, fake_project_cli, fake_metadata,
            fake_package_path):
        """Test that catalog and parameter configs generated in pipeline
        sections propagate into the context"""
        pipelines_dir = fake_package_path / "pipelines"
        assert pipelines_dir.is_dir()

        cmd = ["pipeline", "create", PIPELINE_NAME]
        result = CliRunner().invoke(fake_project_cli, cmd, obj=fake_metadata)
        assert result.exit_code == 0

        # write pipeline catalog
        conf_dir = fake_repo_path / settings.CONF_ROOT / "base"
        catalog_dict = {
            "ds_from_pipeline": {
                "type": "pandas.CSVDataSet",
                "filepath": "data/01_raw/iris.csv",
            }
        }
        catalog_file = conf_dir / "catalog" / f"{PIPELINE_NAME}.yml"
        catalog_file.parent.mkdir()
        with catalog_file.open("w") as f:
            yaml.dump(catalog_dict, f)

        # write pipeline parameters
        params_file = conf_dir / "parameters" / f"{PIPELINE_NAME}.yml"
        assert params_file.is_file()
        params_dict = {"params_from_pipeline": {"p1": [1, 2, 3], "p2": None}}
        with params_file.open("w") as f:
            yaml.dump(params_dict, f)

        with KedroSession.create(PACKAGE_NAME) as session:
            ctx = session.load_context()
        assert isinstance(ctx.catalog._data_sets["ds_from_pipeline"],
                          CSVDataSet)
        assert isinstance(ctx.catalog.load("ds_from_pipeline"), DataFrame)
        assert ctx.params["params_from_pipeline"] == params_dict[
            "params_from_pipeline"]
Code example #29
def test_cli_init_existing_config_force_option(monkeypatch, kedro_project,
                                               mock_settings_fake_project):
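    """Check that invoking ``cli_init`` with ``--force`` overwrites an existing
    ``mlflow.yml`` with the default configuration."""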
    # "kedro_project" is a pytest.fixture declared in conftest
    monkeypatch.chdir(kedro_project)
    cli_runner = CliRunner()

    bootstrap_project(kedro_project)
    with KedroSession.create(project_path=kedro_project) as session:

        # emulate a first call by writing an mlflow.yml file
        yaml_str = yaml.dump(dict(server=dict(mlflow_tracking_uri="toto")))
        (kedro_project / mock_settings_fake_project.CONF_SOURCE / "local" /
         "mlflow.yml").write_text(yaml_str)

        result = cli_runner.invoke(cli_init, args="--force")

        # check the success message is displayed
        assert "successfully updated" in result.output

        # check the file has been overwritten with the default configuration
        context = session.load_context()
        assert context.mlflow.server.mlflow_tracking_uri.endswith("mlruns")
Code example #30
def test_node_hook_logging_above_limit_truncate_strategy(
        kedro_project, dummy_run_params, param_length):
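    """Check that ``long_params_strategy="truncate"`` logs only the first
    ``MAX_PARAM_VAL_LENGTH`` characters of an over-long parameter."""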

    _write_yaml(
        kedro_project / "conf" / "local" / "mlflow.yml",
        dict(tracking=dict(params=dict(long_params_strategy="truncate"))),
    )

    mlflow_tracking_uri = (kedro_project / "mlruns").as_uri()
    mlflow.set_tracking_uri(mlflow_tracking_uri)

    param_value = param_length * "a"
    node_inputs = {"params:my_param": param_value}

    bootstrap_project(kedro_project)
    with KedroSession.create(project_path=kedro_project) as session:
        context = session.load_context()
        mlflow_node_hook = MlflowHook()
        mlflow_node_hook.after_context_created(context)
        with mlflow.start_run():
            mlflow_node_hook.before_pipeline_run(
                run_params=dummy_run_params,
                pipeline=Pipeline([]),
                catalog=DataCatalog(),
            )
            mlflow_node_hook.before_node_run(
                node=node(func=lambda x: x, inputs=dict(x="a"), outputs=None),
                catalog=DataCatalog(),  # can be empty
                inputs=node_inputs,
                is_async=False,
            )
            run_id = mlflow.active_run().info.run_id

        mlflow_client = MlflowClient(mlflow_tracking_uri)
        current_run = mlflow_client.get_run(run_id)
        assert current_run.data.params == {
            "my_param": param_value[0:MAX_PARAM_VAL_LENGTH]
        }