示例#1
0
def test_set_task_run_state_with_error(patch_post):
    """A GraphQL ``errors`` entry must surface as a ``ClientError``.

    The patched endpoint answers with a null ``setTaskRunStates`` payload and a
    single error; the raised exception is expected to carry that error's
    message.
    """
    error_payload = {
        "data": {"setTaskRunStates": None},
        "errors": [{"message": "something went wrong"}],
    }
    patch_post(error_payload)

    cloud_settings = {
        "cloud.graphql": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
    }
    with set_temporary_config(cloud_settings):
        client = Client()

    # The client is built inside the temporary config but used after it exits.
    with pytest.raises(ClientError, match="something went wrong"):
        client.set_task_run_state(task_run_id="76-salt", version=0, state=Pending())
示例#2
0
def test_client_is_always_called_even_during_state_handler_failures(client):
    """A crashing flow state handler still results in one state report to the client.

    The reported state must be a failure that mentions state handlers and
    carries the handler's ``ZeroDivisionError`` as its result.
    """

    def exploding_handler(task, old, new):
        # Deliberately raises ZeroDivisionError on every invocation.
        1 / 0

    flow = prefect.Flow(
        name="test",
        tasks=[prefect.Task()],
        state_handlers=[exploding_handler],
    )

    ## flow run setup
    flow.run(state=Pending())

    ## assertions
    assert client.get_flow_run_info.call_count == 1  # one time to pull latest state
    assert client.set_flow_run_state.call_count == 1  # Failed

    reported_states = [
        recorded[1]["state"] for recorded in client.set_flow_run_state.call_args_list
    ]
    final_state = reported_states[-1]
    assert final_state.is_failed()
    assert "state handlers" in final_state.message
    assert isinstance(final_state.result, ZeroDivisionError)

    # No task-level lookups are expected when the flow-level handler fails.
    assert client.get_task_run_info.call_count == 0
示例#3
0
    def test_reads_result_if_cached_valid_using_task_result(task, client):
        """A valid cached state is returned as-is and hydrated via the task's ``Result``.

        NOTE(review): the first parameter is named ``task`` and is rebound
        below; presumably it stands in for ``self`` -- confirm before renaming.
        """

        class MyResult(Result):
            # Stub reader: sets its own value to 53 and returns itself.
            def read(self, *args, **kwargs):
                self.value = 53
                return self

        task = Task(
            result=MyResult(),
            cache_for=datetime.timedelta(minutes=1),
            cache_validator=duration_only,
        )
        cached_state = Cached(
            result=PrefectResult(location="2"),
            cached_result_expiration=pendulum.now("utc").add(minutes=1),
        )
        client.get_latest_cached_states = MagicMock(return_value=[cached_state])

        runner = CloudTaskRunner(task)
        outcome = runner.check_task_is_cached(
            state=Pending(), inputs={"a": Result(1)}
        )

        # Identity check: the very same cached state object must come back.
        assert outcome is cached_state
        assert outcome.result == 53
示例#4
0
def test_set_flow_run_state_with_error(monkeypatch):
    """A GraphQL ``errors`` entry on ``setFlowRunState`` must raise ``ClientError``.

    ``requests.post`` is replaced with a mock whose ``json()`` returns a null
    payload plus one error; the exception text must contain that error's
    message.
    """
    error_payload = {
        "data": {"setFlowRunState": None},
        "errors": [{"message": "something went wrong"}],
    }
    fake_response = MagicMock()
    fake_response.json.return_value = error_payload
    monkeypatch.setattr("requests.post", MagicMock(return_value=fake_response))

    cloud_settings = {
        "cloud.graphql": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
    }
    with set_temporary_config(cloud_settings):
        client = Client()

    with pytest.raises(ClientError, match="something went wrong"):
        client.set_flow_run_state(flow_run_id="74-salt", version=0, state=Pending())
示例#5
0
    def check_task_is_cached(self, state: State,
                             inputs: Dict[str, Result]) -> State:
        """
        Looks for a cached state that can be reused instead of running the task.

        When the task sets `cache_for`, candidate states stored in `prefect.context.caches`
        under the task's name are tried first, in order. After that, an incoming `Cached`
        state is validated on its own.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: the first context candidate accepted by the task's `cache_validator`;
                otherwise `state` itself if it is not cached, or is cached and still valid;
                otherwise a new `Pending` state when the incoming cache is rejected
        """

        def unwrap_inputs() -> dict:
            # The validator receives raw values rather than Result wrappers.
            return {name: result.value for name, result in inputs.items()}

        if self.task.cache_for is not None:
            context_candidates = prefect.context.caches.get(self.task.name, [])
            raw_inputs = unwrap_inputs()
            for cached in context_candidates:
                accepted = self.task.cache_validator(
                    cached, raw_inputs, prefect.context.get("parameters"))
                if accepted:
                    cached._result = cached._result.to_result()
                    return cached

        if not state.is_cached():
            return state

        assert isinstance(state, Cached)  # mypy assert
        raw_inputs = unwrap_inputs()
        if self.task.cache_validator(state, raw_inputs,
                                     prefect.context.get("parameters")):
            state._result = state._result.to_result()
            return state

        # The incoming cache was rejected: log it and hand back a runnable state.
        task_label = prefect.context.get("task_full_name", self.task.name)
        self.logger.warning(
            "Task '{name}': can't use cache because it "
            "is now invalid".format(name=task_label))
        return Pending("Cache was invalid; ready to run.")
示例#6
0
class TestCheckScheduledStep:
    """Exercises ``FlowRunner.check_flow_reached_start_time`` with different states."""

    @pytest.mark.parametrize("state", [Failed(), Pending(), Running(), Success()])
    def test_non_scheduled_states(self, state):
        """States that are not ``Scheduled`` come back as the very same object."""
        runner = FlowRunner(flow=Flow(name="test"))
        checked = runner.check_flow_reached_start_time(state=state)
        assert checked is state

    def test_scheduled_states_without_start_time(self):
        """``Scheduled(start_time=None)`` passes the check and is returned unchanged."""
        state = Scheduled(start_time=None)
        runner = FlowRunner(flow=Flow(name="test"))
        checked = runner.check_flow_reached_start_time(state=state)
        assert checked is state

    def test_scheduled_states_with_future_start_time(self):
        """A start time ten minutes ahead raises ``ENDRUN`` carrying the same state."""
        ten_minutes = datetime.timedelta(minutes=10)
        state = Scheduled(start_time=pendulum.now("utc") + ten_minutes)
        runner = FlowRunner(flow=Flow(name="test"))
        with pytest.raises(ENDRUN) as exc:
            runner.check_flow_reached_start_time(state=state)
        assert exc.value.state is state

    def test_scheduled_states_with_past_start_time(self):
        """A start time one minute in the past passes the check."""
        one_minute = datetime.timedelta(minutes=1)
        state = Scheduled(start_time=pendulum.now("utc") - one_minute)
        runner = FlowRunner(flow=Flow(name="test"))
        checked = runner.check_flow_reached_start_time(state=state)
        assert checked is state
示例#7
0
 async def test_setting_flow_run_to_cancelled_state_sets_unfinished_task_runs_to_cancelled(
         self, flow_run_id):
     """Cancelling a running flow run cancels its pending and running task runs.

     Task runs that were already set to ``Success`` are expected to keep
     that state.
     """
     task_runs = await models.TaskRun.where(
         {"flow_run_id": {"_eq": flow_run_id}}
     ).get({"id"})
     task_run_ids = [task_run.id for task_run in task_runs]

     # update the state to Running
     await api.states.set_flow_run_state(
         flow_run_id=flow_run_id, state=Running()
     )

     # Currently this flow_run_id fixture has at least 3 tasks, if this
     # changes the test will need to be updated
     assert len(task_run_ids) >= 3, "flow_run_id fixture has changed"

     # Set one task run to pending, one to running, and the rest to success
     pending_task_run, running_task_run, *rest = task_run_ids
     await api.states.set_task_run_state(
         task_run_id=pending_task_run, state=Pending()
     )
     await api.states.set_task_run_state(
         task_run_id=running_task_run, state=Running()
     )
     for finished_id in rest:
         await api.states.set_task_run_state(
             task_run_id=finished_id, state=Success()
         )

     # set the flow run to a cancelled state
     await api.states.set_flow_run_state(
         flow_run_id=flow_run_id, state=Cancelled()
     )

     # Confirm the unfinished task runs have been marked as cancelled
     task_runs = await models.TaskRun.where(
         {"flow_run_id": {"_eq": flow_run_id}}
     ).get({"id", "state"})
     state_by_run = {task_run.id: task_run.state for task_run in task_runs}
     assert state_by_run[pending_task_run] == "Cancelled"
     assert state_by_run[running_task_run] == "Cancelled"
     assert all(state_by_run[run_id] == "Success" for run_id in rest)
示例#8
0
def test_set_task_run_state(patch_post):
    """With a ``SUCCESS`` status reply, the client returns the very state it was given."""
    patch_post(
        {"data": {"set_task_run_states": {"states": [{"status": "SUCCESS"}]}}}
    )
    submitted = Pending()

    cloud_settings = {
        "cloud.api": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
    }
    with set_temporary_config(cloud_settings):
        client = Client()

    returned = client.set_task_run_state(
        task_run_id="76-salt", version=0, state=submitted
    )

    # Identity, not equality: no new state object should be created.
    assert returned is submitted
示例#9
0
def test_set_task_run_state_serializes(patch_post):
    """Submitting a state whose result wraps a lambda raises a marshmallow ``ValidationError``."""
    patch_post(
        {"data": {"set_task_run_states": {"states": [{"status": "SUCCESS"}]}}}
    )

    cloud_settings = {
        "cloud.api": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
    }
    with set_temporary_config(cloud_settings):
        client = Client()

    # A SafeResult holding a lambda, with no result handler attached.
    lambda_result = SafeResult(lambda: None, result_handler=None)
    with pytest.raises(marshmallow.exceptions.ValidationError):
        client.set_task_run_state(
            task_run_id="76-salt", version=0, state=Pending(result=lambda_result)
        )
示例#10
0
def test_task_runner_prioritizes_kwarg_states_over_db_states(monkeypatch, state):
    """A state passed to ``run()`` is used even though the backend reports another.

    NOTE(review): ``state`` is called like a state class below; presumably it
    is supplied by a parametrization not visible here -- confirm.
    """
    task = Task(name="test")
    db_state = state("already", result=10)

    def echo_state(task_run_id, version, state, cache_for):
        # Mirrors a backend that accepts whatever state it is sent.
        return state

    fake_client = MagicMock(
        get_task_run_info=MagicMock(return_value=MagicMock(state=db_state)),
        set_task_run_state=MagicMock(side_effect=echo_state),
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client",
        MagicMock(return_value=fake_client),
    )

    CloudTaskRunner(task=task).run(
        state=Pending("let's do this"), context={"map_index": 1}
    )

    ## assertions
    assert fake_client.get_task_run_info.call_count == 1  # one time to pull latest state
    assert fake_client.set_task_run_state.call_count == 2  # Pending -> Running -> Success

    submitted = [
        recorded[1]["state"]
        for recorded in fake_client.set_task_run_state.call_args_list
    ]
    assert [type(s).__name__ for s in submitted] == ["Running", "Success"]
示例#11
0
    def initialize_run(
            self, state: Optional[State],
            context: Dict[str, Any]) -> Tuple[State, Dict[str, Any]]:
        """
        Prepares the starting state and context for a Task run.

        Meta states are peeled away until a non-meta state (or a non-`State` value) is
        reached; a falsy result is replaced with a fresh `Pending` state. The context
        is returned untouched.

        Args:
            - state (Optional[State]): the initial state of the run
            - context (dict): the context to be updated with relevant information

        Returns:
            - tuple: a tuple of the updated state and context objects
        """
        unwrapped = state

        # meta states may be nested -> for example a Submitted( Queued( Retry ) )
        while isinstance(unwrapped, State) and unwrapped.is_meta_state():
            unwrapped = unwrapped.state  # type: ignore

        if not unwrapped:
            unwrapped = Pending()

        return unwrapped, context
示例#12
0
def test_set_flow_run_state(patch_post):
    """With a ``SUCCESS`` status reply, ``set_flow_run_state`` returns a ``Pending`` state."""
    accepted_state = {"id": 1, "status": "SUCCESS", "message": None}
    patch_post({"data": {"set_flow_run_states": {"states": [accepted_state]}}})

    cloud_settings = {
        "cloud.api": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
        "backend": "cloud",
    }
    with set_temporary_config(cloud_settings):
        client = Client()

    submitted = Pending()
    returned = client.set_flow_run_state(
        flow_run_id="74-salt", version=0, state=submitted
    )
    assert isinstance(returned, State)
    assert isinstance(returned, Pending)
def client(monkeypatch):
    """Yield a mocked Cloud client patched into both cloud runner modules.

    The mock reports a ``Pending`` flow run with empty parameters, reports no
    task run state, and echoes back whatever state (or states) it is asked to
    set or fetch.

    NOTE(review): tests in this file request ``client`` by name as a fixture,
    but no ``@pytest.fixture`` decorator is visible on this generator --
    confirm the decorator is applied.
    """

    def echo_flow_run_state(flow_run_id, version, state):
        return state

    def echo_task_run_state(task_run_id, version, state, cache_for):
        return state

    def echo_latest_states(flow_run_id, states, result_handler):
        return states

    flow_run_info = MagicMock(state=Pending(), parameters={})
    cloud_client = MagicMock(
        get_flow_run_info=MagicMock(return_value=flow_run_info),
        set_flow_run_state=MagicMock(side_effect=echo_flow_run_state),
        get_task_run_info=MagicMock(return_value=MagicMock(state=None)),
        set_task_run_state=MagicMock(side_effect=echo_task_run_state),
        get_latest_task_run_states=MagicMock(side_effect=echo_latest_states),
    )

    # Each runner module gets its own factory mock returning the shared client.
    patched_targets = (
        "prefect.engine.cloud.task_runner.Client",
        "prefect.engine.cloud.flow_runner.Client",
    )
    for target in patched_targets:
        monkeypatch.setattr(target, MagicMock(return_value=cloud_client))

    yield cloud_client
示例#14
0
    def test_task_runner_preserves_location_of_inputs_when_retrying(
            self, client):
        """
        Cached input locations must reach the Retrying state unchanged.

        If a user opts out of checkpointing via checkpoint=False, we don't want to
        surprise them by storing the result in cached_inputs.  The input without a
        location must still have none, and the input located at "foo" must keep it.
        """

        @prefect.task(max_retries=1, retry_delay=datetime.timedelta(days=1))
        def add(x, y):
            return x + y

        cached_inputs = {
            "x": PrefectResult(value=1),
            "y": PrefectResult(value="0", location="foo"),
        }
        initial_state = Pending(cached_inputs=cached_inputs)
        upstream_states = {
            Edge(Task(), Task(), key="x"): Success(),
            Edge(Task(), Task(), key="y"): Success(),
        }

        CloudTaskRunner(task=add).run(
            state=initial_state, upstream_states=upstream_states
        )

        ## assertions
        assert client.get_task_run_info.call_count == 0  # never called
        # Pending -> Running -> Failed -> Retrying
        assert client.set_task_run_state.call_count == 3

        running, failed, retrying = [
            recorded[1]["state"]
            for recorded in client.set_task_run_state.call_args_list
        ]
        assert running.is_running()
        assert failed.is_failed()
        assert isinstance(retrying, Retrying)
        assert retrying.cached_inputs["x"].location is None
        assert retrying.cached_inputs["y"].location == "foo"
示例#15
0
def test_task_runner_handles_looping_with_retries(client):
    """A looping task that fails once mid-loop is retried and still finishes successfully."""

    # note that looping _requires_ a result handler in Cloud
    @prefect.task(
        max_retries=1,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def looper():
        ctx = prefect.context
        on_second_loop = ctx.get("task_loop_count") == 2
        # Fail exactly once: on the second loop of the first run attempt.
        if on_second_loop and ctx.get("task_run_count", 1) == 1:
            raise ValueError("Stop")
        if ctx.get("task_loop_count", 1) < 3:
            raise LOOP(result=ctx.get("task_loop_result", 0) + 10)
        return ctx.get("task_loop_result")

    # Backend answers: version 0 is Pending, later versions are Looped states.
    run_infos = []
    for version in range(5):
        db_state = Pending() if version == 0 else Looped(loop_count=version)
        run_infos.append(MagicMock(version=version, state=db_state))
    client.get_task_run_info.side_effect = run_infos

    res = CloudTaskRunner(task=looper).run(
        context={"task_run_version": 1}, state=None, upstream_states={}
    )

    ## assertions
    assert res.is_successful()
    assert client.get_task_run_info.call_count == 4
    # Pending -> Running -> Looped (1) -> Running -> Failed -> Retrying -> Running -> Looped(2) -> Running -> Success
    assert client.set_task_run_state.call_count == 9

    # Only truthy versions are compared; falsy ones (None / 0) are skipped.
    versions = []
    for recorded in client.set_task_run_state.call_args_list:
        version = recorded[1]["version"]
        if version:
            versions.append(version)
    assert versions == [1, 2, 3]
    def test_task_runner_validates_cached_state_inputs_if_task_has_caching_and_uses_task_handler(
        self, client
    ):
        """A cached state fetched from the client is accepted and read via the task's ``Result``.

        The client offers two unexpired cached states, one without cached inputs
        and one with a cached input for ``x``. The check must return a
        successful, cached state whose result is the value produced by the
        task's own ``Result.read``.
        """

        class MyResult(Result):
            # Stub reader: returns a copy of itself holding 1337.
            def read(self, *args, **kwargs):
                hydrated = self.copy()
                hydrated.value = 1337
                return hydrated

        @prefect.task(
            cache_for=datetime.timedelta(minutes=1),
            cache_validator=all_inputs,
            result=MyResult(),
        )
        def cached_task(x):
            return 42

        def two_minutes_from_now():
            return datetime.datetime.utcnow() + datetime.timedelta(minutes=2)

        inputless_candidate = Cached(
            cached_result_expiration=two_minutes_from_now(),
            result=PrefectResult(location="-1"),
        )
        candidate_with_inputs = Cached(
            cached_result_expiration=two_minutes_from_now(),
            result=PrefectResult(location="99"),
            cached_inputs={"x": PrefectResult(location="2")},
        )
        client.get_latest_cached_states = MagicMock(
            return_value=[inputless_candidate, candidate_with_inputs]
        )

        runner = CloudTaskRunner(task=cached_task)
        res = runner.check_task_is_cached(
            Pending(), inputs={"x": PrefectResult(value=2)}
        )

        assert client.get_latest_cached_states.called
        assert res.is_successful()
        assert res.is_cached()
        assert res.result == 1337
示例#17
0
def test_task_runner_validates_cached_state_inputs_if_task_has_caching_and_uses_task_handler(
    client,
):
    """A cached state fetched from the client is accepted and read via the task's handler.

    The client offers two unexpired cached states, one without cached inputs
    and one with a cached input for ``x``. The check must return a successful,
    cached state whose result is the value produced by the task's own result
    handler.
    """

    class Handler(ResultHandler):
        # Stub handler: every read resolves to 1337.
        def read(self, val):
            return 1337

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result_handler=Handler(),
    )
    def cached_task(x):
        return 42

    def two_minutes_from_now():
        return datetime.datetime.utcnow() + datetime.timedelta(minutes=2)

    inputless_candidate = Cached(
        cached_result_expiration=two_minutes_from_now(),
        result=SafeResult("-1", JSONResultHandler()),
    )
    candidate_with_inputs = Cached(
        cached_result_expiration=two_minutes_from_now(),
        result=SafeResult("99", JSONResultHandler()),
        cached_inputs={"x": SafeResult("2", result_handler=JSONResultHandler())},
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[inputless_candidate, candidate_with_inputs]
    )

    runner = CloudTaskRunner(task=cached_task)
    res = runner.check_task_is_cached(
        Pending(), inputs={"x": Result(2, result_handler=JSONResultHandler())}
    )

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 1337
示例#18
0
def test_set_task_run_state_responds_to_status(patch_post):
    """A ``QUEUED`` status reply makes the client return a queued state with no inner state."""
    patch_post(
        {"data": {"set_task_run_states": {"states": [{"status": "QUEUED"}]}}}
    )
    submitted = Pending()

    cloud_settings = {
        "cloud.api": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
    }
    with set_temporary_config(cloud_settings):
        client = Client()

    returned = client.set_task_run_state(
        task_run_id="76-salt", version=0, state=submitted
    )

    assert returned.is_queued()
    assert returned.state is None  # caller should set this
class TestTaskRunStates:
    """Tests for ``api.states.set_task_run_state`` against stored ``TaskRun`` rows."""

    async def test_set_task_run_state(self, task_run_id):
        """Setting ``Failed`` returns a result for the task run and updates the stored row."""
        result = await api.states.set_task_run_state(
            task_run_id=task_run_id, state=Failed()
        )

        assert result.task_run_id == task_run_id

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"}
        )

        # the stored row reflects the new state both as a name and in serialized form
        assert query.version == 2
        assert query.state == "Failed"
        assert query.serialized_state["type"] == "Failed"

    @pytest.mark.parametrize("state", [Failed(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(self, state):
        """A random, unknown task run id is rejected with a ``ValueError``."""
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_task_run_state(
                task_run_id=str(uuid.uuid4()), state=state
            )

    @pytest.mark.parametrize(
        "state", [s() for s in State.children() if not s().is_running()]
    )
    async def test_state_does_not_set_heartbeat_unless_running(
        self, state, task_run_id
    ):
        """For every non-running state, the heartbeat stays ``None`` after the update."""
        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

    async def test_running_state_sets_heartbeat(self, task_run_id, running_flow_run_id):
        """Setting ``Running`` records a heartbeat later than the time captured beforehand."""
        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

        # timestamp taken before the update, used as the lower bound below
        dt = pendulum.now("UTC")
        await api.states.set_task_run_state(task_run_id=task_run_id, state=Running())

        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat > dt

    async def test_trigger_failed_state_does_not_set_end_time(self, task_run_id):
        """After ``TriggerFailed``, neither ``start_time`` nor ``end_time`` is set."""
        await api.states.set_task_run_state(
            task_run_id=task_run_id, state=TriggerFailed()
        )
        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"}
        )
        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "state",
        [s() for s in State.children() if s not in _MetaState.children()],
        ids=[s.__name__ for s in State.children() if s not in _MetaState.children()],
    )
    async def test_setting_a_task_run_state_pulls_cached_inputs_if_possible(
        self, task_run_id, state, running_flow_run_id
    ):
        """Cached inputs already stored on the task run carry over into the new state."""

        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Failed(cached_inputs=complex_result)
        # write the serialized state directly to the row, bypassing the states API
        await models.TaskRun.where(id=task_run_id).update(
            set=dict(serialized_state=cached_state.serialize())
        )

        # move the task run into the parametrized state
        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        task_run = await models.TaskRun.where(id=task_run_id).first(
            {"serialized_state"}
        )

        # ensure the state change took place
        assert task_run.serialized_state["type"] == type(state).__name__
        assert task_run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert task_run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state",
        [
            s(cached_inputs=None)
            for s in State.children()
            if s not in _MetaState.children()
        ],
        ids=[s.__name__ for s in State.children() if s not in _MetaState.children()],
    )
    async def test_task_runs_with_null_cached_inputs_do_not_overwrite_cache(
        self, state, task_run_id, running_flow_run_id
    ):
        """A state with null cached inputs followed by ``Retrying`` with inputs keeps those inputs."""

        # first set the parametrized state, which carries no cached inputs
        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)
        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=cached_state)
        run = await models.TaskRun.where(id=task_run_id).first({"serialized_state"})

        assert run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_task_runs_cached_inputs_give_preference_to_new_cached_inputs(
        self, state_cls, task_run_id, running_flow_run_id
    ):
        """When both the old and the new state carry cached inputs, the new state's are stored."""

        # first set a state of the parametrized class carrying cached inputs "b" and "c"
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"a": 2}, result_handler=JSONResultHandler())
        complex_result = {"b": res1, "c": res2}
        cached_state = state_cls(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=cached_state)
        # then set a Retrying state carrying different cached inputs, "x" and "y"
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=cached_state)
        run = Box(
            await models.TaskRun.where(id=task_run_id).first({"serialized_state"})
        )

        # verify that we have cached inputs, and that preference has been given to the new
        # cached inputs
        assert run.serialized_state.cached_inputs
        assert run.serialized_state.cached_inputs.x.value == 1
        assert run.serialized_state.cached_inputs.y.value == {"z": 2}

    @pytest.mark.parametrize(
        "flow_run_state", [Pending(), Running(), Failed(), Success()]
    )
    async def test_running_states_can_not_be_set_if_flow_run_is_not_running(
        self, flow_run_id, task_run_id, flow_run_state
    ):
        """A task run can only be set to ``Running`` while its flow run is running."""

        await api.states.set_flow_run_state(
            flow_run_id=flow_run_id, state=flow_run_state
        )

        # the coroutine is created here but only awaited inside the matching branch below
        set_running_coroutine = api.states.set_task_run_state(
            task_run_id=task_run_id, state=Running()
        )

        if flow_run_state.is_running():
            assert await set_running_coroutine
            assert (
                await models.TaskRun.where(id=task_run_id).first({"state"})
            ).state == "Running"
        else:

            # any non-running flow run state must cause the update to be rejected
            with pytest.raises(ValueError, match="is not in a running state"):
                await set_running_coroutine
            assert (
                await models.TaskRun.where(id=task_run_id).first({"state"})
            ).state != "Running"
示例#20
0
def test_task_map_with_no_upstream_results_and_a_mapped_state(executor):
    """
    Mapped tasks must regenerate their children on re-runs even when upstream results are gone.

    The same flow is run three times from progressively later starting points, each with
    more upstream results missing. Every run is expected to succeed with a final sum of 12.
    """

    @prefect.task
    def numbers():
        return [1, 2, 3]

    @prefect.task
    def plus_one(x):
        return x + 1

    @prefect.task
    def get_sum(x):
        return sum(x)

    with Flow(name="test") as flow:
        source = numbers()
        once = plus_one.map(source)
        twice = plus_one.map(once)
        total = get_sum(twice)

    def run_from(task_states):
        # Every scenario uses a fresh runner, the same executor, and returns all tasks.
        return FlowRunner(flow=flow).run(
            executor=executor,
            task_states=task_states,
            return_tasks=flow.tasks,
        )

    def resultless_children():
        # Three distinct Success children that carry no result.
        return Mapped(map_states=[Success() for _ in range(3)])

    # first run with a missing result from `source` but map_states for `once`
    pending_children = [
        Pending(cached_inputs={"x": Result(i)}) for i in range(1, 4)
    ]
    state = run_from({
        source: Success(),
        once: Mapped(map_states=pending_children),
    })

    assert state.is_successful()
    assert state.result[total].result == 12

    # next run with missing results for `source` and `once`
    partially_done = [
        Success(result=3),
        Success(result=4),
        Retrying(cached_inputs={"x": Result(4)}),
    ]
    state = run_from({
        source: Success(),
        once: resultless_children(),
        twice: Mapped(map_states=partially_done),
    })

    assert state.is_successful()
    assert state.result[total].result == 12

    # next run with missing results for `source`, `once`, and `twice`
    all_done = [Success(result=3), Success(result=4), Success(result=5)]
    state = run_from({
        source: Success(),
        once: resultless_children(),
        twice: Mapped(map_states=all_done),
    })

    assert state.is_successful()
    assert state.result[total].result == 12
示例#21
0
def test_states_are_hashable():
    """State instances of different types can be collected into a set."""
    hashable_states = {State(), Pending(), Success()}
    assert hashable_states
示例#22
0
    def get_flow_run_state(
        self,
        state: State,
        task_states: Dict[Task, State],
        task_contexts: Dict[Task, Dict[str, Any]],
        return_tasks: Set[Task],
        task_runner_state_handlers: Iterable[Callable],
        executor: "prefect.executors.base.Executor",
    ) -> State:
        """
        Runs the flow: walks the flow's tasks in sorted order, submits each one (or each
        mapped child) to the executor, waits for the terminal / reference / returned tasks
        and computes the final flow state from them.

        Args:
            - state (State): current state of the Flow; it must report `is_running()`,
                otherwise the run is ended immediately
            - task_states (dict): dictionary of task states to begin
                computation with, with keys being Tasks and values their corresponding state.
                This dictionary is updated in place as tasks are submitted.
            - task_contexts (Dict[Task, Dict[str, Any]]): contexts that will be provided to
                each task
            - return_tasks ([Task], optional): Tasks to include in the
                final returned Flow state. `None` is treated as an empty set
            - task_runner_state_handlers (Iterable[Callable]): A list of state change handlers
                that will be provided to the task_runner, and called whenever a task changes
                state.
            - executor (Executor): executor used to submit the task runs and to wait on
                their results

        Returns:
            - State: `State` representing the final post-run state of the `Flow`.

        Raises:
            - ENDRUN: if `state` is not a running state
            - ValueError: if `return_tasks` contains tasks that are not part of the flow

        """
        # this dictionary is used for tracking the states of "children" mapped tasks;
        # when running on Dask, we want to avoid serializing futures, so instead
        # of storing child task states in the `map_states` attribute we instead store
        # in this dictionary and only after they are resolved do we attach them to the Mapped state
        mapped_children = dict()  # type: Dict[Task, list]

        if not state.is_running():
            self.logger.info("Flow is not in a Running state.")
            raise ENDRUN(state)

        if return_tasks is None:
            return_tasks = set()
        if set(return_tasks).difference(self.flow.tasks):
            raise ValueError("Some tasks in return_tasks were not found in the flow.")

        # per-task values passed to `run_task` as `extra_context`; `task_index` stays
        # None for unmapped runs and is the child index for mapped children
        def extra_context(task: Task, task_index: int = None) -> dict:
            return {
                "task_name": task.name,
                "task_tags": task.tags,
                "task_index": task_index,
            }

        # -- process each task in order

        with self.check_for_cancellation(), executor.start():

            for task in self.flow.sorted_tasks():
                task_state = task_states.get(task)

                # if a task is a constant task, we already know its return value
                # no need to use up resources by running it through a task runner
                if task_state is None and isinstance(
                    task, prefect.tasks.core.constants.Constant
                ):
                    task_states[task] = task_state = Success(result=task.value)

                # Always restart completed resource setup/cleanup tasks and
                # secret tasks unless they were explicitly cached.
                # TODO: we only need to rerun these tasks if any pending
                # downstream tasks depend on them.
                if (
                    isinstance(
                        task,
                        (
                            prefect.tasks.core.resource_manager.ResourceSetupTask,
                            prefect.tasks.core.resource_manager.ResourceCleanupTask,
                            prefect.tasks.secrets.SecretBase,
                        ),
                    )
                    and task_state is not None
                    and task_state.is_finished()
                    and not task_state.is_cached()
                ):
                    task_states[task] = task_state = Pending()

                # if the state is finished, don't run the task, just use the provided state if
                # the state is cached / mapped, we still want to run the task runner pipeline
                # steps to either ensure the cache is still valid / or to recreate the mapped
                # pipeline for possible retries
                if (
                    isinstance(task_state, State)
                    and task_state.is_finished()
                    and not task_state.is_cached()
                    and not task_state.is_mapped()
                ):
                    continue

                upstream_states = {}  # type: Dict[Edge, State]

                # this dictionary is used exclusively for "reduce" tasks in particular we store
                # the states / futures corresponding to the upstream children, and if running
                # on Dask, let Dask resolve them at the appropriate time.
                # Note: this is an optimization that allows Dask to resolve the mapped
                # dependencies by "elevating" them to a function argument.
                upstream_mapped_states = {}  # type: Dict[Edge, list]

                # -- process each edge to the task
                for edge in self.flow.edges_to(task):

                    # load the upstream task states (supplying Pending as a default)
                    upstream_states[edge] = task_states.get(
                        edge.upstream_task, Pending(message="Task state not available.")
                    )

                    # if the edge is flattened and not the result of a map, then we
                    # preprocess the upstream states. If it IS the result of a
                    # map, it will be handled in `prepare_upstream_states_for_mapping`
                    if edge.flattened:
                        if not isinstance(upstream_states[edge], Mapped):
                            upstream_states[edge] = executor.submit(
                                executors.flatten_upstream_state, upstream_states[edge]
                            )

                    # this checks whether the task is a "reduce" task for a mapped pipeline
                    # and if so, collects the appropriate upstream children
                    if not edge.mapped and isinstance(upstream_states[edge], Mapped):
                        children = mapped_children.get(edge.upstream_task, [])

                        # if the edge is flattened, then we need to wait for the mapped children
                        # to complete and then flatten them
                        if edge.flattened:
                            children = executors.flatten_mapped_children(
                                mapped_children=children, executor=executor
                            )

                        upstream_mapped_states[edge] = children

                # augment edges with upstream constants
                for key, val in self.flow.constants[task].items():
                    edge = Edge(
                        upstream_task=prefect.tasks.core.constants.Constant(val),
                        downstream_task=task,
                        key=key,
                    )
                    upstream_states[edge] = Success(
                        "Auto-generated constant value",
                        result=ConstantResult(value=val),
                    )

                # handle mapped tasks
                if any(edge.mapped for edge in upstream_states.keys()):

                    # wait on upstream states to determine the width of the pipeline
                    # this is the key to depth-first execution
                    upstream_states = executor.wait(
                        {e: state for e, state in upstream_states.items()}
                    )
                    # we submit the task to the task runner to determine if
                    # we can proceed with mapping - if the new task state is not a Mapped
                    # state then we don't proceed
                    task_states[task] = executor.wait(
                        executor.submit(
                            run_task,
                            task=task,
                            state=task_state,  # original state
                            upstream_states=upstream_states,
                            context=dict(
                                prefect.context, **task_contexts.get(task, {})
                            ),
                            flow_result=self.flow.result,
                            task_runner_cls=self.task_runner_cls,
                            task_runner_state_handlers=task_runner_state_handlers,
                            upstream_mapped_states=upstream_mapped_states,
                            is_mapped_parent=True,
                            extra_context=extra_context(task),
                        )
                    )

                    # either way, we should now have enough resolved states to restructure
                    # the upstream states into a list of upstream state dictionaries to iterate over
                    list_of_upstream_states = (
                        executors.prepare_upstream_states_for_mapping(
                            task_states[task],
                            upstream_states,
                            mapped_children,
                            executor=executor,
                        )
                    )

                    submitted_states = []

                    for idx, states in enumerate(list_of_upstream_states):
                        # if we are on a future rerun of a partially complete flow run,
                        # there might be mapped children in a retrying state; this check
                        # looks into the current task state's map_states for such info
                        if (
                            isinstance(task_state, Mapped)
                            and len(task_state.map_states) >= idx + 1
                        ):
                            current_state = task_state.map_states[
                                idx
                            ]  # type: Optional[State]
                        elif isinstance(task_state, Mapped):
                            # Mapped parent with fewer stored children than this index:
                            # the child starts without a prior state
                            current_state = None
                        else:
                            current_state = task_state

                        # this is where each child is submitted for actual work
                        submitted_states.append(
                            executor.submit(
                                run_task,
                                task=task,
                                state=current_state,
                                upstream_states=states,
                                context=dict(
                                    prefect.context,
                                    **task_contexts.get(task, {}),
                                    map_index=idx,
                                ),
                                flow_result=self.flow.result,
                                task_runner_cls=self.task_runner_cls,
                                task_runner_state_handlers=task_runner_state_handlers,
                                upstream_mapped_states=upstream_mapped_states,
                                extra_context=extra_context(task, task_index=idx),
                            )
                        )
                    # only remember the children when the parent actually ended up Mapped
                    if isinstance(task_states.get(task), Mapped):
                        mapped_children[task] = submitted_states  # type: ignore

                else:
                    # unmapped task: submit a single run and keep whatever the executor
                    # returns (it is resolved later through `executor.wait`)
                    task_states[task] = executor.submit(
                        run_task,
                        task=task,
                        state=task_state,
                        upstream_states=upstream_states,
                        context=dict(prefect.context, **task_contexts.get(task, {})),
                        flow_result=self.flow.result,
                        task_runner_cls=self.task_runner_cls,
                        task_runner_state_handlers=task_runner_state_handlers,
                        upstream_mapped_states=upstream_mapped_states,
                        extra_context=extra_context(task),
                    )

            # ---------------------------------------------
            # Collect results
            # ---------------------------------------------

            # terminal tasks determine if the flow is finished
            terminal_tasks = self.flow.terminal_tasks()

            # reference tasks determine flow state
            reference_tasks = self.flow.reference_tasks()

            # wait until all terminal tasks are finished
            final_tasks = terminal_tasks.union(reference_tasks).union(return_tasks)
            final_states = executor.wait(
                {
                    t: task_states.get(t, Pending("Task not evaluated by FlowRunner."))
                    for t in final_tasks
                }
            )

            # also wait for any children of Mapped tasks to finish, and add them
            # to the dictionary to determine flow state
            all_final_states = final_states.copy()
            for t, s in list(final_states.items()):
                if s.is_mapped():
                    # ensure we wait for any mapped children to complete
                    if t in mapped_children:
                        s.map_states = executor.wait(mapped_children[t])
                    s.result = [ms.result for ms in s.map_states]
                    all_final_states[t] = s.map_states

            # NOTE(review): presumably narrows the type returned by `executor.wait`
            # for static type checkers - confirm
            assert isinstance(final_states, dict)

        # mapped tasks contribute a list of child states, hence the flattening
        key_states = set(flatten_seq([all_final_states[t] for t in reference_tasks]))
        terminal_states = set(
            flatten_seq([all_final_states[t] for t in terminal_tasks])
        )
        return_states = {t: final_states[t] for t in return_tasks}

        state = self.determine_final_state(
            state=state,
            key_states=key_states,
            return_states=return_states,
            terminal_states=terminal_states,
        )

        return state
示例#23
0
def test_states_with_mutable_attrs_are_hashable():
    """States carrying mutable attributes (list result, dict inputs) can be put in a set."""
    state_with_list_result = State(result=[1])
    state_with_dict_inputs = Pending(cached_inputs={"a": 1})
    assert {state_with_list_result, state_with_dict_inputs}
示例#24
0
    def run(
        self,
        state: State = None,
        task_states: Dict[Task, State] = None,
        return_tasks: Iterable[Task] = None,
        parameters: Dict[str, Any] = None,
        task_runner_state_handlers: Iterable[Callable] = None,
        executor: "prefect.executors.Executor" = None,
        context: Dict[str, Any] = None,
        task_contexts: Dict[Task, Dict[str, Any]] = None,
    ) -> State:
        """
        Entry point of the FlowRunner: initializes the run, moves the flow through its
        pre-run checks into a running state, executes it and returns the resulting state.

        Args:
            - state (State, optional): starting state for the Flow
            - task_states (dict, optional): mapping of Tasks to the states they should
                begin computation with; the mapping is copied before use
            - return_tasks ([Task], optional): Tasks to include in the final returned
                Flow state. Defaults to `None`
            - parameters (dict, optional): Parameter values keyed by Parameter name;
                the mapping is copied before use
            - task_runner_state_handlers (Iterable[Callable], optional): state change
                handlers forwarded to the task runner
            - executor (Executor, optional): executor used for the computation; when
                omitted, the flow's `executor` attribute is used if set, otherwise an
                instance of the default executor class
            - context (Dict[str, Any], optional): values layered on top of the global
                `prefect.context` for this run
            - task_contexts (Dict[Task, Dict[str, Any]], optional): contexts that will be
                provided to each task; the mapping is copied before use

        Returns:
            - State: `State` representing the final post-run state of the `Flow`.

        Raises:
            - Exception: any unexpected error is re-raised when the run context has
                `raise_on_exception` set; otherwise it is converted to a `Failed` state

        """
        self.logger.info("Beginning Flow run for '{}'".format(self.flow.name))

        # never mutate the caller's containers: work on shallow copies
        parameters = dict(parameters) if parameters else {}
        task_states = dict(task_states) if task_states else {}
        task_contexts = dict(task_contexts) if task_contexts else {}

        # global context first, explicitly provided context wins
        run_context = dict(prefect.context)
        if context:
            run_context.update(context)

        # executor resolution order: argument, flow attribute, configured default
        if executor is None:
            executor = getattr(self.flow, "executor", None)
        if executor is None:
            executor = prefect.engine.get_default_executor_class()()

        self.logger.debug("Using executor type %s", type(executor).__name__)

        try:
            state, task_states, run_context, task_contexts = self.initialize_run(
                state=state,
                task_states=task_states,
                context=run_context,
                task_contexts=task_contexts,
                parameters=parameters,
            )

            with prefect.context(run_context):
                # each pipeline step takes the current state and returns the next one
                state = self.check_flow_is_pending_or_running(state)
                state = self.check_flow_reached_start_time(state)
                state = self.set_flow_to_running(state)
                state = self.get_flow_run_state(
                    state,
                    task_states=task_states,
                    task_contexts=task_contexts,
                    return_tasks=return_tasks,
                    task_runner_state_handlers=task_runner_state_handlers,
                    executor=executor,
                )

        except ENDRUN as signal:
            # a pipeline step asked to stop; the signal carries the state to return
            state = signal.state

        except Exception as exc:
            # anything else is logged and turned into a Failed state (or re-raised)
            error_message = "Unexpected error while running flow: {}".format(repr(exc))
            self.logger.exception(error_message)
            if run_context.get("raise_on_exception"):
                raise exc
            failure = Failed(message=error_message, result=exc)
            state = self.handle_state_change(state or Pending(), failure)

        return state
示例#25
0
文件: runs.py 项目: kad-schoom/server
async def get_or_create_task_run_info(flow_run_id: str,
                                      task_id: str,
                                      map_index: int = None) -> dict:
    """
    Look up the task run identified by (flow_run_id, task_id, map_index) and return its
    details, inserting a new task run first when none exists yet.

    Args:
        - flow_run_id (str): ID of the flow run the task run belongs to
        - task_id (str): ID of the task
        - map_index (int, optional): map index of the task run; `None` is stored and
            queried as -1

    Returns:
        - dict: a dict of details about the task run, with the keys `id`, `version`,
            `state` and `serialized_state`.

    Raises:
        - ValueError: if the task run has to be created but no task with `task_id` exists
    """

    # a missing map index is normalized to -1 for both the query and the insert
    map_index = -1 if map_index is None else map_index

    lookup = {
        "flow_run_id": {"_eq": flow_run_id},
        "task_id": {"_eq": task_id},
        "map_index": {"_eq": map_index},
    }
    existing = await models.TaskRun.where(lookup).first(
        {"id", "version", "state", "serialized_state"})

    if existing:
        return {
            "id": existing.id,
            "version": existing.version,
            "state": existing.state,
            "serialized_state": existing.serialized_state,
        }

    # not found: the task run has to be created, which needs the parent task's details
    task = await models.Task.where(id=task_id).first(
        {"cache_key", "tenant_id"})
    if not task:
        raise ValueError("Invalid task ID")

    new_run = models.TaskRun(
        tenant_id=task.tenant_id,
        flow_run_id=flow_run_id,
        task_id=task_id,
        map_index=map_index,
        cache_key=task.cache_key,
        version=0,
    )

    initial_state = models.TaskRunState(
        tenant_id=task.tenant_id,
        state="Pending",
        timestamp=pendulum.now(),
        message="Task run created",
        serialized_state=Pending(message="Task run created").serialize(),
    )
    new_run.states = [initial_state]

    # on a unique-key conflict only `cache_key` is updated
    conflict_policy = {
        "constraint": "task_run_unique_identifier_key",
        "update_columns": ["cache_key"],
    }
    inserted = await new_run.insert(
        on_conflict=conflict_policy,
        selection_set={"returning": {"id"}},
    )

    return {
        "id": inserted.returning.id,
        "version": new_run.version,
        "state": "Pending",
        "serialized_state": initial_state.serialized_state,
    }
示例#26
0
文件: runs.py 项目: kad-schoom/server
async def _create_flow_run(
    flow_id: str = None,
    parameters: dict = None,
    context: dict = None,
    scheduled_start_time: datetime.datetime = None,
    flow_run_name: str = None,
    version_group_id: str = None,
    labels: List[str] = None,
    run_config: dict = None,
) -> Any:
    """
    Creates a new flow run for an existing flow.

    The flow is looked up by `flow_id` when given, otherwise the most recent unarchived
    flow of `version_group_id` is used. None of the provided arguments are mutated.

    Args:
        - flow_id (str): A string representing the current flow id
        - parameters (dict, optional): A dictionary of parameters that were specified for the flow;
            they are layered on top of the flow group's default parameters
        - context (dict, optional): A dictionary of context values
        - scheduled_start_time (datetime.datetime): When the flow_run should be scheduled to run. If `None`,
            defaults to right now. Must be UTC.
        - flow_run_name (str, optional): An optional string representing this flow run
        - version_group_id (str, optional): An optional version group ID; if provided, will run the most
            recent unarchived version of the group
        - labels (List[str], optional): a list of labels to apply to this individual flow run
        - run_config (dict, optional): A run-config override for this flow run.

    Returns:
        - the value returned by inserting the new flow run (used as the flow run ID)

    Raises:
        - ValueError: if neither `flow_id` nor `version_group_id` is provided, if the flow
            is archived, or if required parameters are missing
        - exceptions.NotFound: if no matching flow exists
    """

    # Truthiness (rather than `is None`) so that empty-string ids are rejected here
    # instead of leaving `where_clause` unbound below.
    if not flow_id and not version_group_id:
        raise ValueError(
            "One of flow_id or version_group_id must be provided.")

    scheduled_start_time = scheduled_start_time or pendulum.now()

    if flow_id:
        where_clause = {"id": {"_eq": flow_id}}
    else:
        where_clause = {
            "version_group_id": {
                "_eq": version_group_id
            },
            "archived": {
                "_eq": False
            },
        }

    flow = await models.Flow.where(where=where_clause).first(
        {
            "id": True,
            "archived": True,
            "tenant_id": True,
            "environment": True,
            "run_config": True,
            "parameters": True,
            "flow_group_id": True,
            "flow_group": {
                "default_parameters": True,
                "labels": True,
                "run_config": True,
            },
        },
        order_by={"version": EnumValue("desc")},
    )  # type: Any

    if not flow:
        msg = (f"Flow {flow_id} not found" if flow_id else
               f"Version group {version_group_id} has no unarchived flows.")
        raise exceptions.NotFound(msg)
    elif flow.archived:
        raise ValueError(f"Flow {flow.id} is archived.")

    # determine active labels; precedence: explicit labels, run_config override,
    # flow group labels, flow group run_config, flow run_config, flow environment
    if labels is not None:
        run_labels = labels
    elif run_config is not None:
        run_labels = run_config.get("labels") or []
    elif flow.flow_group.labels is not None:
        run_labels = flow.flow_group.labels
    elif flow.flow_group.run_config is not None:
        run_labels = flow.flow_group.run_config.get("labels") or []
    elif flow.run_config is not None:
        run_labels = flow.run_config.get("labels") or []
    elif flow.environment is not None:
        run_labels = flow.environment.get("labels") or []
    else:
        run_labels = []
    # `sorted` builds a new list, so the caller's `labels` / `run_config` and the
    # fetched flow objects are not reordered in place
    run_labels = sorted(run_labels)

    # determine active run_config
    if run_config is None:
        if flow.flow_group.run_config is not None:
            run_config = flow.flow_group.run_config
        else:
            run_config = flow.run_config

    # check parameters: start from a copy of the flow group's defaults so that the
    # fetched object is left untouched (and a missing value is treated as empty)
    run_parameters = dict(flow.flow_group.default_parameters or {})
    run_parameters.update(parameters or {})
    required_parameters = [p["name"] for p in flow.parameters if p["required"]]
    missing = set(required_parameters).difference(run_parameters)
    if missing:
        raise ValueError(f"Required parameters were not supplied: {missing}")
    state = Scheduled(message="Flow run scheduled.",
                      start_time=scheduled_start_time)

    # the run is inserted with a Pending state; the Scheduled state is applied afterwards
    run = models.FlowRun(
        tenant_id=flow.tenant_id,
        flow_id=flow_id or flow.id,
        labels=run_labels,
        parameters=run_parameters,
        run_config=run_config,
        context=context or {},
        scheduled_start_time=scheduled_start_time,
        name=flow_run_name or names.generate_slug(2),
        states=[
            models.FlowRunState(
                tenant_id=flow.tenant_id,
                **models.FlowRunState.fields_from_state(
                    Pending(message="Flow run created")),
            )
        ],
    )

    flow_run_id = await run.insert()

    # apply the flow run's initial state via `set_flow_run_state`
    await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=state)

    return flow_run_id
示例#27
0
def test_preparing_state_for_cloud_replaces_cached_inputs_with_safe():
    """A Pending state prepared for cloud stays pending, has no result and keeps its cached inputs."""
    x_result = Result(3, result_handler=JSONResultHandler())
    pending = Pending(cached_inputs={"x": x_result})

    cloud_state = prepare_state_for_cloud(pending)

    assert cloud_state.is_pending()
    assert cloud_state.result == NoResult
    assert cloud_state.cached_inputs == {"x": x_result}
示例#28
0
 def __init__(self, id, state=None, version=None):
     """Store the id, falling back to a `Pending` state and version 0 for falsy values."""
     self.id = id
     self.state = state if state else Pending()
     self.version = version if version else 0
示例#29
0
def test_flow_run_handles_error_states_when_initial_state_is_provided():
    """Running this flow from an explicit Pending state yields a Failed flow state."""
    with Flow(name="test") as flow:
        # presumably fails at run time because a str and an int are added - verify
        # against AddTask
        res = AddTask()("5", 5)

    final_state = flow.run(state=Pending())

    assert final_state.is_failed()
示例#30
0

@pytest.mark.parametrize(
    "state_check",
    [
        dict(state=Cancelled(), assert_true={"is_finished"}),
        dict(state=Cached(),
             assert_true={"is_cached", "is_finished", "is_successful"}),
        dict(state=ClientFailed(), assert_true={"is_meta_state"}),
        dict(state=Failed(), assert_true={"is_finished", "is_failed"}),
        dict(state=Finished(), assert_true={"is_finished"}),
        dict(state=Looped(), assert_true={"is_finished", "is_looped"}),
        dict(state=Mapped(),
             assert_true={"is_finished", "is_mapped", "is_successful"}),
        dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Pending(), assert_true={"is_pending"}),
        dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}),
        dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Retrying(),
             assert_true={"is_pending", "is_scheduled", "is_retrying"}),
        dict(state=Running(), assert_true={"is_running"}),
        dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Skipped(),
             assert_true={"is_finished", "is_successful", "is_skipped"}),
        dict(state=Submitted(), assert_true={"is_meta_state", "is_submitted"}),
        dict(state=Success(), assert_true={"is_finished", "is_successful"}),
        dict(state=TimedOut(), assert_true={"is_finished", "is_failed"}),
        dict(state=TriggerFailed(), assert_true={"is_finished", "is_failed"}),
    ],
)
def test_state_is_methods(state_check):