Example 1: run a local Prefect ensemble and track it until all realizations stop
def test_run_prefect_ensemble_with_path(unused_tcp_port):
    with tmp(os.path.join(SOURCE_DIR, "test-data/local/prefect_test_case")):
        config = parse_config("config.yml")
        config.update({"config_path": Path.cwd()})
        config.update({"realizations": 2})
        config.update({"executor": "local"})

        config["config_path"] = Path(config["config_path"])
        config["run_path"] = Path(config["run_path"])
        config["storage"]["storage_path"] = Path(
            config["storage"]["storage_path"])

        service_config = EvaluatorServerConfig(unused_tcp_port)
        ensemble = PrefectEnsemble(config)

        evaluator = EnsembleEvaluator(ensemble, service_config, ee_id="1")

        mon = evaluator.run()

        for event in mon.track():
            if event.data is not None and event.data.get("status") in [
                    "Failed",
                    "Stopped",
            ]:
                mon.signal_done()

        assert evaluator._snapshot.get_status() == "Stopped"

        successful_realizations = evaluator._snapshot.get_successful_realizations()

        assert successful_realizations == config["realizations"]
Example 2: a RuntimeError while fetching input files ends the evaluation as Failed
def test_run_prefect_ensemble_exception(unused_tcp_port):
    with tmp(os.path.join(SOURCE_DIR, "test-data/local/prefect_test_case")):
        config = parse_config("config.yml")
        config.update({"config_path": Path.absolute(Path("."))})
        config.update({"realizations": 2})
        config.update({"executor": "local"})

        service_config = EvaluatorServerConfig(unused_tcp_port)

        ensemble = PrefectEnsemble(config)
        evaluator = EnsembleEvaluator(ensemble, service_config, ee_id="1")

        with patch.object(ensemble,
                          "_fetch_input_files",
                          side_effect=RuntimeError()):
            mon = evaluator.run()
            for event in mon.track():
                if event["type"] in (
                        ids.EVTYPE_EE_SNAPSHOT_UPDATE,
                        ids.EVTYPE_EE_SNAPSHOT,
                ) and event.data.get("status") in [
                        "Stopped",
                        "Failed",
                ]:
                    mon.signal_done()
            assert evaluator._snapshot.get_status() == "Failed"
Example 3: run a UnixStep directly in a Prefect flow and check its output
def test_unix_step(unused_tcp_port):
    host = "localhost"
    url = f"ws://{host}:{unused_tcp_port}"
    messages = []
    mock_ws_thread = threading.Thread(
        target=partial(_mock_ws, messages=messages),
        args=(host, unused_tcp_port),
    )

    mock_ws_thread.start()

    def _on_task_failure(task, state):
        raise Exception(state.message)

    with tmp(Path(SOURCE_DIR) / "test-data/local/prefect_test_case"):
        config = parse_config("config.yml")
        storage = storage_driver_factory(config=config.get("storage"),
                                         run_path=".")
        resource = storage.store("unix_test_script.py")
        jobs = [{
            "id": "0",
            "name": "test_script",
            "executable": "unix_test_script.py",
            "args": ["vas"],
        }]

        stage_task = UnixStep(
            resources=[resource],
            outputs=["output.out"],
            job_list=jobs,
            iens=1,
            cmd="python3",
            url=url,
            step_id="step_id_0",
            stage_id="stage_id_0",
            ee_id="ee_id_0",
            on_failure=_on_task_failure,
            run_path=config.get("run_path"),
            storage_config=config.get("storage"),
            max_retries=1,
            retry_delay=timedelta(seconds=2),
        )

        flow = Flow("testing")
        flow.add_task(stage_task)
        flow_run = flow.run()

        # Stop the mock evaluator WS server
        with Client(url) as c:
            c.send("stop")
        mock_ws_thread.join()

        task_result = flow_run.result[stage_task]
        assert task_result.is_successful()
        assert flow_run.is_successful()

        assert len(task_result.result["outputs"]) == 1
        expected_path = storage.get_storage_path(1) / "output.out"
        output_path = flow_run.result[stage_task].result["outputs"][0]
        assert expected_path == output_path
        assert output_path.exists()
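
Every step test spawns `_mock_ws` in a background thread as a stand-in for the evaluator's websocket endpoint; its call signature here is `_mock_ws(host, port, messages)`. The helper itself is not part of these examples, so the following is a plausible sketch using the `websockets` library, not the project's actual code:

import asyncio

import websockets


def _mock_ws(host, port, messages):
    # Run a one-shot websocket server that records every message it
    # receives and shuts down once a literal "stop" arrives.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    done = loop.create_future()

    async def _handler(websocket, path):
        async for msg in websocket:
            messages.append(msg)
            if msg == "stop":
                done.set_result(None)

    async def _serve():
        async with websockets.serve(_handler, host, port):
            await done

    loop.run_until_complete(_serve())

The tests shut the server down by sending the literal string "stop" through `Client(url)` and joining the thread before asserting on the collected `messages`.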
Example 4: a step that exhausts its retries emits job failures and a step failure
def test_on_task_failure_fail_step(unused_tcp_port, tmpdir):
    host = "localhost"
    url = f"ws://{host}:{unused_tcp_port}"
    messages = []
    mock_ws_thread = threading.Thread(
        target=partial(_mock_ws, messages=messages),
        args=(host, unused_tcp_port),
    )

    mock_ensemble = _MockedPrefectEnsemble()

    mock_ws_thread.start()
    script_location = (
        Path(SOURCE_DIR) / "test-data/local/prefect_test_case/unix_test_retry_script.py"
    )
    input_ = script_transmitter("script", script_location, storage_path=tmpdir)
    with tmp() as runpath:
        step = get_step(
            step_name="test_step",
            inputs=[("script", Path("unix_test_retry_script.py"),
                     "application/x-python")],
            outputs=[],
            jobs=[("script", Path("unix_test_retry_script.py"), [runpath])],
            type_="unix",
        )

        with prefect.context(url=url, token=None, cert=None):
            output_trans = step_output_transmitters(step, storage_path=tmpdir)
            with Flow("testing") as flow:
                task = step.get_task(
                    output_transmitters=output_trans,
                    ee_id="test_ee_id",
                    max_retries=1,
                    retry_delay=timedelta(seconds=1),
                    on_failure=mock_ensemble._on_task_failure,
                )
                result = task(inputs=input_)
            flow_run = flow.run()

    # Stop the mock evaluator WS server
    with Client(url) as c:
        c.send("stop")
    mock_ws_thread.join()

    task_result = flow_run.result[result]
    assert not task_result.is_successful()
    assert not flow_run.is_successful()

    fail_job_messages = [
        msg for msg in messages if ids.EVTYPE_FM_JOB_FAILURE in msg
    ]
    fail_step_messages = [
        msg for msg in messages if ids.EVTYPE_FM_STEP_FAILURE in msg
    ]

    expected_job_failed_messages = 2
    expected_step_failed_messages = 1
    assert expected_job_failed_messages == len(fail_job_messages)
    assert expected_step_failed_messages == len(fail_step_messages)
Example 5: the generated flow respects step dependencies for any step order in the config
def test_get_flow(coefficients, unused_tcp_port):
    with tmp(Path(SOURCE_DIR) / "test-data/local/prefect_test_case"):
        config = parse_config("config.yml")
        config.update({
            "config_path": os.getcwd(),
            ids.REALIZATIONS: 2,
            ids.EXECUTOR: "local",
        })
        inputs = {}
        coeffs_trans = coefficient_transmitters(
            coefficients,
            config.get(ids.STORAGE)["storage_path"])
        script_trans = script_transmitters(config)
        for iens in range(2):
            inputs[iens] = {**coeffs_trans[iens], **script_trans[iens]}
        config.update({
            "inputs": inputs,
            "outputs": output_transmitters(config),
        })
        server_config = EvaluatorServerConfig(unused_tcp_port)
        for permuted_steps in permutations(config["steps"]):
            permuted_config = copy.deepcopy(config)
            permuted_config["steps"] = permuted_steps
            permuted_config["dispatch_uri"] = server_config.dispatch_uri
            ensemble = PrefectEnsemble(permuted_config)

            for iens in range(2):
                with prefect.context(
                        url=server_config.url,
                        token=server_config.token,
                        cert=server_config.cert,
                ):
                    flow = ensemble.get_flow(ensemble._ee_id, [iens])

                # Get the ordered tasks and retrieve their step ids.
                flow_steps = [
                    task.get_step() for task in flow.sorted_tasks()
                    if isinstance(task, UnixTask)
                ]
                assert len(flow_steps) == 4

                realization_steps = list(
                    ensemble.get_reals()[iens].get_steps_sorted_topologically()
                )

                # Both orderings must respect the step dependency graph
                for step_ordering in [realization_steps, flow_steps]:
                    mapping = {
                        step._name: idx
                        for idx, step in enumerate(step_ordering)
                    }
                    assert mapping["second_degree"] < mapping["zero_degree"]
                    assert mapping["zero_degree"] < mapping["add_coeffs"]
                    assert mapping["first_degree"] < mapping["add_coeffs"]
                    assert mapping["second_degree"] < mapping["add_coeffs"]
Example 6: run a cloudpickled Python function as a function step
def test_function_step(unused_tcp_port, tmpdir):
    host = "localhost"
    url = f"ws://{host}:{unused_tcp_port}"
    messages = []
    mock_ws_thread = threading.Thread(
        target=partial(_mock_ws, messages=messages),
        args=(host, unused_tcp_port),
    )

    mock_ws_thread.start()

    test_values = {"values": [42, 24, 6]}
    inputs = input_transmitter("values",
                               test_values["values"],
                               storage_path=tmpdir)

    def sum_function(values):
        return [sum(values)]

    step = get_step(
        step_name="test_step",
        inputs=[("values", "NA", "text/whatever")],
        outputs=[("output", Path("output.out"), "application/json")],
        jobs=[("test_function", cloudpickle.dumps(sum_function), None)],
        type_="function",
    )

    with prefect.context(url=url, token=None, cert=None):
        output_trans = step_output_transmitters(step, storage_path=tmpdir)
        with Flow("testing") as flow:
            task = step.get_task(output_transmitters=output_trans,
                                 ee_id="test_ee_id")
            result = task(inputs=inputs)
        with tmp():
            flow_run = flow.run()

    # Stop the mock evaluator WS server
    with Client(url) as c:
        c.send("stop")
    mock_ws_thread.join()

    task_result = flow_run.result[result]
    assert task_result.is_successful()
    assert flow_run.is_successful()

    assert len(task_result.result) == 1
    expected_uri = output_trans["output"]._uri
    output_uri = task_result.result["output"]._uri
    assert expected_uri == output_uri
    transmitted_record = asyncio.get_event_loop().run_until_complete(
        task_result.result["output"].load())
    transmitted_result = transmitted_record.data
    expected_result = sum_function(**test_values)
    assert expected_result == transmitted_result
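
The job tuple carries the function as bytes produced by `cloudpickle.dumps`, so the function-step runner only has to unpickle and call it; no import of the defining module is required. A self-contained roundtrip showing just that mechanism (an illustration, not the project's step runner):

import cloudpickle


def sum_function(values):
    return [sum(values)]


# What the job tuple carries: the function serialized by value.
payload = cloudpickle.dumps(sum_function)

# Runner side: no reference to the original function object is needed.
restored = cloudpickle.loads(payload)
assert restored([42, 24, 6]) == [72]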
Example 7: a job called with an unrecognized argument fails the unix step
def test_unix_step_error(unused_tcp_port, tmpdir):
    host = "localhost"
    url = f"ws://{host}:{unused_tcp_port}"
    messages = []
    mock_ws_thread = threading.Thread(
        target=partial(_mock_ws, messages=messages),
        args=(host, unused_tcp_port),
    )

    mock_ws_thread.start()

    script_location = (Path(SOURCE_DIR) /
                       "test-data/local/prefect_test_case/unix_test_script.py")
    input_ = script_transmitter("test_script",
                                script_location,
                                storage_path=tmpdir)
    step = get_step(
        step_name="test_step",
        inputs=[("test_script", Path("unix_test_script.py"),
                 "application/x-python")],
        outputs=[("output", Path("output.out"), "application/json")],
        jobs=[("test_script", Path("unix_test_script.py"), ["foo", "bar"])],
        type_="unix",
    )

    with prefect.context(url=url, token=None, cert=None):
        output_trans = step_output_transmitters(step, storage_path=tmpdir)
        with Flow("testing") as flow:
            task = step.get_task(output_transmitters=output_trans,
                                 ee_id="test_ee_id")
            result = task(inputs=input_)
        with tmp():
            flow_run = flow.run()

    # Stop the mock evaluator WS server
    with Client(url) as c:
        c.send("stop")
    mock_ws_thread.join()

    task_result = flow_run.result[result]
    assert not task_result.is_successful()
    assert not flow_run.is_successful()

    assert isinstance(task_result.result, Exception)
    assert ("unix_test_script.py: error: unrecognized arguments: bar"
            in task_result.message)
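
The asserted message is standard `argparse` output, which implies `unix_test_script.py` accepts exactly one positional argument: `["foo", "bar"]` trips "unrecognized arguments: bar", while the single-argument calls in the other examples succeed. A plausible sketch of such a script (an assumption about the test data, not its verbatim contents):

#!/usr/bin/env python3
import argparse
from pathlib import Path

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("argument")  # exactly one positional argument
    args = parser.parse_args()  # extras abort with "unrecognized arguments: ..."
    Path("output.out").write_text(args.argument)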
Example 8: run a unix task with the websocket url passed to the step itself
def test_unix_task(unused_tcp_port, tmpdir):
    host = "localhost"
    url = f"ws://{host}:{unused_tcp_port}"
    messages = []
    mock_ws_thread = threading.Thread(
        target=partial(_mock_ws, messages=messages),
        args=(host, unused_tcp_port),
    )

    mock_ws_thread.start()

    script_location = (Path(SOURCE_DIR) /
                       "test-data/local/prefect_test_case/unix_test_script.py")
    input_ = script_transmitter("script", script_location, storage_path=tmpdir)
    step = get_step(
        step_name="test_step",
        inputs=[("script", Path("unix_test_script.py"), "application/x-python")
                ],
        outputs=[("output", Path("output.out"), "application/json")],
        jobs=[("script", Path("unix_test_script.py"), ["vas"])],
        url=url,
        type_="unix",
    )

    output_trans = step_output_transmitters(step, storage_path=tmpdir)
    with Flow("testing") as flow:
        task = step.get_task(output_transmitters=output_trans,
                             ee_id="test_ee_id")
        result = task(inputs=input_)
    with tmp():
        flow_run = flow.run()

    # Stop the mock evaluator WS server
    with Client(url) as c:
        c.send("stop")
    mock_ws_thread.join()

    task_result = flow_run.result[result]
    assert task_result.is_successful()
    assert flow_run.is_successful()

    assert len(task_result.result) == 1
    expected_uri = output_trans["output"]._uri
    output_uri = task_result.result["output"]._uri
    assert expected_uri == output_uri
Example 9: a transmitter-based variant of the ensemble run, with the evaluator used as a context manager
def test_run_prefect_ensemble_with_path(unused_tcp_port, coefficients):
    with tmp(os.path.join(SOURCE_DIR, "test-data/local/prefect_test_case")):
        config = parse_config("config.yml")
        config.update({
            "config_path": os.getcwd(),
            "realizations": 2,
            "executor": "local",
        })
        inputs = {}
        coeffs_trans = coefficient_transmitters(
            coefficients,
            config.get(ids.STORAGE)["storage_path"])
        script_trans = script_transmitters(config)
        for iens in range(2):
            inputs[iens] = {**coeffs_trans[iens], **script_trans[iens]}
        config.update({
            "inputs": inputs,
            "outputs": output_transmitters(config),
        })

        service_config = EvaluatorServerConfig(unused_tcp_port)
        config["config_path"] = Path(config["config_path"])
        config["run_path"] = Path(config["run_path"])
        config["storage"]["storage_path"] = Path(
            config["storage"]["storage_path"])
        config["dispatch_uri"] = service_config.dispatch_uri

        ensemble = PrefectEnsemble(config)

        evaluator = EnsembleEvaluator(ensemble, service_config, 0, ee_id="1")

        with evaluator.run() as mon:
            for event in mon.track():
                if isinstance(event.data, dict) and event.data.get("status") in [
                    "Failed",
                    "Stopped",
                ]:
                    mon.signal_done()

        assert evaluator._ensemble.get_status() == "Stopped"
        successful_realizations = evaluator._ensemble.get_successful_realizations()
        assert successful_realizations == config["realizations"]
Example 10: transmit in-memory data and dump it to a file
async def test_simple_record_transmit_and_dump(
    record_transmitter_factory_context: ContextManager[
        Callable[[str], RecordTransmitter]
    ],
    data_in,
    expected_data,
    application_type,
):
    with record_transmitter_factory_context() as record_transmitter_factory, tmp():
        transmitter = record_transmitter_factory(name="some_name")
        await transmitter.transmit_data(data_in)

        await transmitter.dump("record.json")
        if application_type == "application/json":
            with open("record.json") as f:
                assert json.dumps(expected_data) == f.read()
        else:
            with open("record.json", "rb") as f:
                assert expected_data[0] == f.read()
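
The `data_in`/`expected_data`/`application_type` parameters suggest this coroutine is driven by a shared `pytest.mark.parametrize` table covering both JSON and binary payloads. A reduced, runnable sketch of such a table (the names and values are assumptions):

import json

import pytest

TRANSMIT_CASES = [
    # (data_in, expected_data, application_type)
    ({"a": 1}, {"a": 1}, "application/json"),
    ([b"\x00\x01"], [b"\x00\x01"], "application/octet-stream"),
]


@pytest.mark.parametrize("data_in,expected_data,application_type", TRANSMIT_CASES)
def test_dump_roundtrip_shape(data_in, expected_data, application_type):
    # JSON payloads round-trip through text, binary payloads through raw bytes
    if application_type == "application/json":
        assert json.loads(json.dumps(expected_data)) == data_in
    else:
        assert expected_data[0] == data_in[0]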
Example 11: cancelling a running ensemble ends with status Cancelled
def test_cancel_run_prefect_ensemble(unused_tcp_port):
    with tmp(Path(SOURCE_DIR) / "test-data/local/prefect_test_case"):
        config = parse_config("config.yml")
        config.update({"config_path": Path.absolute(Path("."))})
        config.update({"realizations": 2})
        config.update({"executor": "local"})

        service_config = EvaluatorServerConfig(unused_tcp_port)
        ensemble = PrefectEnsemble(config)

        evaluator = EnsembleEvaluator(ensemble, service_config, ee_id="2")

        mon = evaluator.run()
        cancel = True
        for _ in mon.track():
            if cancel:
                mon.signal_cancel()
                cancel = False

        assert evaluator._snapshot.get_status() == "Cancelled"
Example 12: a transmitter accepts a file once and rejects a second transmit
async def test_simple_record_transmit_from_file(
    record_transmitter_factory_context: ContextManager[
        Callable[[str], RecordTransmitter]
    ],
    data_in,
    expected_data,
    application_type,
):
    filename = "record.file"
    with record_transmitter_factory_context() as record_transmitter_factory, tmp():
        transmitter = record_transmitter_factory(name="some_name")
        if application_type == "application/json":
            with open(filename, "w") as f:
                json.dump(expected_data, f)
        else:
            with open(filename, "wb") as f:
                f.write(expected_data[0])
        await transmitter.transmit_file(filename, mime=application_type)
        assert transmitter.is_transmitted()
        with pytest.raises(RuntimeError, match="Record already transmitted"):
            await transmitter.transmit_file(filename, mime=application_type)
Example 13: a transmitter-based variant of the cancel test
def test_cancel_run_prefect_ensemble(unused_tcp_port, coefficients):
    with tmp(Path(SOURCE_DIR) / "test-data/local/prefect_test_case"):
        config = parse_config("config.yml")
        config.update({
            "config_path": os.getcwd(),
            "realizations": 2,
            "executor": "local",
        })
        inputs = {}
        coeffs_trans = coefficient_transmitters(
            coefficients,
            config.get(ids.STORAGE)["storage_path"])
        script_trans = script_transmitters(config)
        for iens in range(2):
            inputs[iens] = {**coeffs_trans[iens], **script_trans[iens]}
        config.update({
            "inputs": inputs,
            "outputs": output_transmitters(config),
        })

        service_config = EvaluatorServerConfig(unused_tcp_port)
        config["config_path"] = Path(config["config_path"])
        config["run_path"] = Path(config["run_path"])
        config["storage"]["storage_path"] = Path(
            config["storage"]["storage_path"])
        config["dispatch_uri"] = service_config.dispatch_uri

        ensemble = PrefectEnsemble(config)

        evaluator = EnsembleEvaluator(ensemble, service_config, 0, ee_id="2")

        with evaluator.run() as mon:
            cancel = True
            for _ in mon.track():
                if cancel:
                    mon.signal_cancel()
                    cancel = False

        assert evaluator._snapshot.get_status() == "Cancelled"
Example 14: cloudpickle lets a function step run a function whose module is not importable
def test_function_step_for_function_defined_outside_py_environment(
        unused_tcp_port, tmpdir):
    # Create a temporary module defining a function `bar`; `bar` calls another
    # function, `internal_call`, defined in the same Python file.
    with tmpdir.as_cwd():
        module_path = Path(tmpdir) / "foo"
        module_path.mkdir()
        init_file = module_path / "__init__.py"
        init_file.touch()
        file_path = module_path / "bar.py"
        file_path.write_text(
            "def bar(values):\n    return internal_call(values)\n"
            "def internal_call(values):\n    return [sum(values)]\n")
        spec = importlib.util.spec_from_file_location("foo", str(file_path))
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        func = getattr(module, "bar")

    # Check module is not in the python environment
    with pytest.raises(ModuleNotFoundError):
        import foo.bar

    host = "localhost"
    url = f"ws://{host}:{unused_tcp_port}"
    messages = []
    mock_ws_thread = threading.Thread(
        target=partial(_mock_ws, messages=messages),
        args=(host, unused_tcp_port),
    )

    mock_ws_thread.start()

    test_values = {"values": [42, 24, 6]}
    inputs = input_transmitter("values",
                               test_values["values"],
                               storage_path=tmpdir)

    step = get_step(
        step_name="test_step",
        inputs=[("values", "NA", "text/whatever")],
        outputs=[("output", Path("output.out"), "application/json")],
        jobs=[("test_function", cloudpickle.dumps(func), None)],
        type_="function",
    )
    expected_result = func(**test_values)
    # Make sure the function is no longer available before we start creating the flow and task
    del func

    with prefect.context(url=url, token=None, cert=None):
        output_trans = step_output_transmitters(step, storage_path=tmpdir)
        with Flow("testing") as flow:
            task = step.get_task(output_transmitters=output_trans,
                                 ee_id="test_ee_id")
            result = task(inputs=inputs)
        with tmp():
            flow_run = flow.run()

    # Stop the mock evaluator WS server
    with Client(url) as c:
        c.send("stop")
    mock_ws_thread.join()

    task_result = flow_run.result[result]
    assert task_result.is_successful()
    assert flow_run.is_successful()

    assert len(task_result.result) == 1
    expected_uri = output_trans["output"]._uri
    output_uri = task_result.result["output"]._uri
    assert expected_uri == output_uri
    transmitted_record = asyncio.get_event_loop().run_until_complete(
        task_result.result["output"].load())
    transmitted_result = transmitted_record.data
    assert expected_result == transmitted_result
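
What makes this pass is that `cloudpickle` serializes the dynamically loaded function by value, together with the module-level names it references (here `internal_call`), so unpickling works in a process where `foo.bar` cannot be imported. A standalone demonstration of that property (assumed analogous, not the project's code):

import cloudpickle

# Define `bar` in a throwaway namespace, as if loaded from a file that is
# not on sys.path; `bar` depends on the module-level `internal_call`.
namespace = {}
exec(
    "def bar(values):\n    return internal_call(values)\n"
    "def internal_call(values):\n    return [sum(values)]\n",
    namespace,
)
payload = cloudpickle.dumps(namespace["bar"])  # pickled by value, deps included

# Unpickling needs no importable module; the captured globals come along.
assert cloudpickle.loads(payload)([42, 24, 6]) == [72]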
Example 15: retried jobs emit job-failure events but no step-failure event
def test_on_task_failure(unused_tcp_port):
    host = "localhost"
    url = f"ws://{host}:{unused_tcp_port}"
    messages = []
    mock_ws_thread = threading.Thread(
        target=partial(_mock_ws, messages=messages),
        args=(host, unused_tcp_port),
    )

    mock_ws_thread.start()

    with tmp(Path(SOURCE_DIR) / "test-data/local/prefect_test_case", False):
        config = parse_config("config.yml")
        storage = storage_driver_factory(config=config.get("storage"),
                                         run_path=".")
        resource = storage.store("unix_test_retry_script.py")
        jobs = [{
            "id": "0",
            "name": "test_script",
            "executable": "unix_test_retry_script.py",
            "args": [],
        }]

        stage_task = UnixStep(
            resources=[resource],
            outputs=[],
            job_list=jobs,
            iens=1,
            cmd="python3",
            url=url,
            step_id="step_id_0",
            stage_id="stage_id_0",
            ee_id="ee_id_0",
            on_failure=partial(PrefectEnsemble._on_task_failure, url=url),
            run_path=config.get("run_path"),
            storage_config=config.get("storage"),
            max_retries=3,
            retry_delay=timedelta(seconds=1),
        )

        flow = Flow("testing")
        flow.add_task(stage_task)
        flow_run = flow.run()

        # Stop the mock evaluator WS server
        with Client(url) as c:
            c.send("stop")
        mock_ws_thread.join()

        task_result = flow_run.result[stage_task]
        assert task_result.is_successful()
        assert flow_run.is_successful()

        fail_job_messages = [
            msg for msg in messages if ids.EVTYPE_FM_JOB_FAILURE in msg
        ]
        fail_step_messages = [
            msg for msg in messages if ids.EVTYPE_FM_STEP_FAILURE in msg
        ]

        expected_job_failed_messages = 2
        expected_step_failed_messages = 0
        assert expected_job_failed_messages == len(fail_job_messages)
        assert expected_step_failed_messages == len(fail_step_messages)
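
Two job-failure events followed by an overall success means `unix_test_retry_script.py` must fail on its first two runs and succeed on the third, which `max_retries=3` accommodates. A sketch of a script with that behavior (an assumption about the test data, not its verbatim contents):

#!/usr/bin/env python3
import sys
from pathlib import Path

if __name__ == "__main__":
    counter = Path("attempts.txt")
    attempts = int(counter.read_text()) if counter.exists() else 0
    counter.write_text(str(attempts + 1))
    if attempts < 2:
        sys.exit(1)  # fail the first two attempts, succeed on the third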