def test_state_handler_failures_are_handled_appropriately_and_logged(client, caplog):
    """A raising ``on_failure`` handler fails the run and is logged at ERROR.

    ``do_nothing`` raises ``ValueError``; its ``on_failure`` handler then raises
    ``SyntaxError``. The handler's exception must become the result of both the
    returned state and the last state sent to the client, and both errors must
    appear among the ERROR-level log records.

    NOTE: this test used to share its name with a later, less thorough test in
    this module. Python keeps only the last definition bound to a name, so
    pytest never collected this version; the ``_and_logged`` suffix keeps both.
    """

    def bad(*args, **kwargs):
        raise SyntaxError("Syntax Errors are nice because they're so unique")

    @prefect.task(on_failure=bad)
    def do_nothing():
        raise ValueError("This task failed somehow")

    res = CloudTaskRunner(task=do_nothing).run()
    assert res.is_failed()
    assert "SyntaxError" in res.message
    assert isinstance(res.result, SyntaxError)

    # Only the Running state and the final Failed state are sent to the client.
    assert client.set_task_run_state.call_count == 2
    states = [call[1]["state"] for call in client.set_task_run_state.call_args_list]
    assert states[0].is_running()
    assert states[1].is_failed()
    assert isinstance(states[1].result, SyntaxError)

    # The task's own error and the handler's error are both logged; the
    # handler failure is expected to be the last ERROR record.
    error_logs = [r.message for r in caplog.records if r.levelname == "ERROR"]
    assert len(error_logs) >= 2
    assert any("This task failed somehow" in elog for elog in error_logs)
    assert "SyntaxError" in error_logs[-1]
    assert "unique" in error_logs[-1]
    assert "state handler" in error_logs[-1]


def test_task_runner_errors_if_no_result_provided_as_input(client):
    """Upstream ``Success`` states that carry no result make ``add`` fail.

    Both inputs come from bare ``Success`` states deserialized from
    ``{"type": "Success"}``, so ``x + y`` operates on missing values
    (presumably ``None + None``) and fails with an "unsupported operand"
    error, which must surface as a Failed state.
    """

    @prefect.task
    def add(x, y):
        return x + y

    base_state = prefect.serialization.state.StateSchema().load({"type": "Success"})
    x_state, y_state = base_state, base_state

    upstream_states = {
        Edge(Task(), Task(), key="x"): x_state,
        Edge(Task(), Task(), key="y"): y_state,
    }

    res = CloudTaskRunner(task=add).run(upstream_states=upstream_states)
    assert res.is_failed()
    assert "unsupported operand" in res.message

    ## assertions
    assert client.get_task_run_info.call_count == 0  # never called
    # Running, then Failed: the initial Pending state is never sent.
    assert client.set_task_run_state.call_count == 2

    states = [call[1]["state"] for call in client.set_task_run_state.call_args_list]
    # The first reported state is Running even though the task fails immediately.
    assert states[0].is_running()
    assert states[1].is_failed()
    assert "unsupported operand" in states[1].message
def test_task_runner_raises_endrun_if_client_cant_receive_state_updates(
    monkeypatch,
):
    """A client that fails to fetch task-run info yields a Failed state.

    ``Client.get_task_run_info`` is patched to raise ``SyntaxError``. The
    runner should not propagate it; instead ``run`` returns a Failed state
    whose result is that ``SyntaxError``.
    """
    target = Task(name="test")
    failing_info = MagicMock(side_effect=SyntaxError)
    fake_client = MagicMock(
        get_task_run_info=failing_info,
        set_task_run_state=MagicMock(),
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client",
        MagicMock(return_value=fake_client),
    )

    # An ENDRUN makes the TaskRunner hand back its most recently computed state.
    outcome = CloudTaskRunner(task=target).run(context={"map_index": 1})

    assert failing_info.called
    assert outcome.is_failed()
    assert isinstance(outcome.result, SyntaxError)
def test_state_handler_failures_are_handled_appropriately(client):
    """A raising ``on_failure`` handler turns the run into a Failed state.

    ``do_nothing`` raises ``ValueError``; its ``on_failure`` handler then
    raises ``SyntaxError``, which must be the result of both the returned
    state and the last state reported to the client.
    """

    def bad(*args, **kwargs):
        raise SyntaxError("Syntax Errors are nice because they're so unique")

    @prefect.task(on_failure=bad)
    def do_nothing():
        raise ValueError("This task failed somehow")

    outcome = CloudTaskRunner(task=do_nothing).run()
    assert outcome.is_failed()
    assert "SyntaxError" in outcome.message
    assert isinstance(outcome.result, SyntaxError)

    recorder = client.set_task_run_state
    assert recorder.call_count == 2
    # Each recorded call unpacks to (args, kwargs); the state is passed by keyword.
    running_state, failed_state = [
        kwargs["state"] for _, kwargs in recorder.call_args_list
    ]
    assert running_state.is_running()
    assert failed_state.is_failed()
    assert isinstance(failed_state.result, SyntaxError)
def test_task_runner_gracefully_handles_load_results_failures(client):
    """A result whose ``read`` raises makes the task fail while loading inputs.

    The upstream ``Success`` state carries a ``MyResult`` whose ``read``
    raises ``TypeError``. The run must end Failed with a message mentioning
    "task results", and the client must receive exactly one state (Failed),
    with no Running state reported before it.
    """

    class MyResult(Result):
        def read(self, *args, **kwargs):
            raise TypeError("something is wrong!")

    @prefect.task(result=PrefectResult())
    def t(x):
        return x

    upstream_success = Success(result=MyResult(location="foo.txt"))
    upstream_edge = Edge(Task(result=MyResult()), t, key="x")
    final_state = CloudTaskRunner(task=t).run(
        upstream_states={upstream_edge: upstream_success}
    )

    assert final_state.is_failed()
    assert "task results" in final_state.message

    recorder = client.set_task_run_state
    assert recorder.call_count == 1  # Pending -> Failed, no Running in between

    # Each recorded call unpacks to (args, kwargs); the state is passed by keyword.
    sent_types = [
        type(kwargs["state"]).__name__ for _, kwargs in recorder.call_args_list
    ]
    assert sent_types == ["Failed"]