def test_run_flow_calls_callbacks(monkeypatch):
    """Check that ``run_flow`` invokes the ``on_start`` and ``on_exit`` callbacks.

    The default flow-runner factory and ``dask_kubernetes.KubeCluster`` are
    replaced with mocks. The flow is cloudpickled to a temporary file whose
    path is passed to ``run_flow`` via ``prefect.context(flow_file_path=...)``.
    """
    start_func = MagicMock()
    exit_func = MagicMock()

    environment = DaskKubernetesEnvironment(on_start=start_func, on_exit=exit_func)

    # Patch the runner factory so the test can inspect the flow handed to it.
    flow_runner = MagicMock()
    monkeypatch.setattr(
        "prefect.engine.get_default_flow_runner_class",
        MagicMock(return_value=flow_runner),
    )

    # Replace the cluster class with a mock so no real cluster is constructed.
    kube_cluster = MagicMock()
    monkeypatch.setattr("dask_kubernetes.KubeCluster", kube_cluster)

    with tempfile.TemporaryDirectory() as directory:
        flow_path = os.path.join(directory, "flow_env.prefect")
        flow = prefect.Flow("test", storage=Local(directory))

        # Write the serialized flow through a single handle; the original also
        # held an unused "w+" handle on the same path while writing it.
        with open(flow_path, "wb") as f:
            cloudpickle.dump(flow, f)

        with set_temporary_config({"cloud.auth_token": "test"}), prefect.context(
            flow_file_path=flow_path
        ):
            environment.run_flow()

        assert flow_runner.call_args[1]["flow"].name == "test"

    assert start_func.called
    assert exit_func.called
# Example 2
def test_run_flow_calls_callbacks(monkeypatch):
    """Verify ``run_flow`` fires both lifecycle callbacks when the flow comes from Cloud.

    The GraphQL ``Client`` used by the k8s Dask environment is mocked so that the
    flow-run query returns a single flow whose storage is a ``Local`` store
    holding a flow named ``"name"``. The runner factory and ``KubeCluster`` are
    mocked as well.
    """
    on_start_mock = MagicMock()
    on_exit_mock = MagicMock()
    environment = DaskKubernetesEnvironment(
        on_start=on_start_mock, on_exit=on_exit_mock
    )

    # Runner factory mock: lets the test read back the flow passed to the runner.
    runner_cls = MagicMock()
    monkeypatch.setattr(
        "prefect.engine.get_default_flow_runner_class",
        MagicMock(return_value=runner_cls),
    )
    monkeypatch.setattr("dask_kubernetes.KubeCluster", MagicMock())

    with tempfile.TemporaryDirectory() as tmp_dir:
        storage = Local(tmp_dir)
        storage.add_flow(prefect.Flow("name"))

        # Build the fake GraphQL response from the inside out.
        flow_info = GraphQLResult({"name": "name", "storage": storage.serialize()})
        flow_run_info = GraphQLResult({"flow": flow_info})
        query_result = MagicMock(data=MagicMock(flow_run=[flow_run_info]))

        client_cls = MagicMock()
        client_cls.return_value.graphql = MagicMock(return_value=query_result)
        monkeypatch.setattr(
            "prefect.environments.execution.dask.k8s.Client", client_cls
        )

        with set_temporary_config({"cloud.auth_token": "test"}):
            with prefect.context({"flow_run_id": "id"}):
                environment.run_flow()

        assert runner_cls.call_args[1]["flow"].name == "name"

    assert on_start_mock.called
    assert on_exit_mock.called