def test_task_runner_handles_looping_with_retries(client):
    """A task that LOOPs and fails mid-loop is retried and resumes looping,
    ending Successful; verifies the exact sequence of state-sets and versions."""
    # note that looping _requires_ a result handler in Cloud
    @prefect.task(
        max_retries=1,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def looper():
        # fail on the second loop iteration of the *first* run only,
        # so the retry machinery kicks in mid-loop
        if (prefect.context.get("task_loop_count") == 2
                and prefect.context.get("task_run_count", 1) == 1):
            raise ValueError("Stop")
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP(result=prefect.context.get("task_loop_result", 0) + 10)
        return prefect.context.get("task_loop_result")

    # Cloud reports the prior Looped state on each re-entry
    client.get_task_run_info.side_effect = [
        MagicMock(version=i, state=Pending() if i == 0 else Looped(loop_count=i))
        for i in range(5)
    ]
    res = CloudTaskRunner(task=looper).run(context={"task_run_version": 1},
                                           state=None,
                                           upstream_states={})

    ## assertions
    assert res.is_successful()
    assert client.get_task_run_info.call_count == 4
    assert (
        client.set_task_run_state.call_count == 9
    )  # Pending -> Running -> Looped (1) -> Running -> Failed -> Retrying -> Running -> Looped(2) -> Running -> Success
    versions = [
        call[1]["version"] for call in client.set_task_run_state.call_args_list
        if call[1]["version"]
    ]
    assert versions == [1, 2, 3]
def test_reads_result_if_cached_valid(self, client):
    """A still-valid Cached state is returned as-is and its stored result
    (JSON location "2") is read back into the value 2."""
    client.get_latest_cached_states = MagicMock(return_value=[])
    with pytest.warns(UserWarning):
        task = Task(cache_validator=duration_only, result=PrefectResult())
    cached_state = Cached(
        result=PrefectResult(location="2"),
        cached_result_expiration=pendulum.now("utc").add(minutes=1),
    )

    checked = CloudTaskRunner(task).check_task_is_cached(
        state=cached_state, inputs={"a": PrefectResult(value=1)}
    )

    assert checked is cached_state
    assert checked.result == 2
def test_task_runner_handles_looping(client):
    """A LOOP-ing task runs three iterations to Success, with one state-set
    per transition and monotonically increasing versions."""
    @prefect.task(result=PrefectResult())
    def looper():
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP(result=prefect.context.get("task_loop_result", 0) + 10)
        return prefect.context.get("task_loop_result")

    client.get_task_run_info.side_effect = [
        MagicMock(version=version, state=Pending()) for version in (1, 2, 3)
    ]
    res = CloudTaskRunner(task=looper).run(
        context={"task_run_version": 1}, state=None, upstream_states={}
    )

    assert res.is_successful()
    assert client.get_task_run_info.call_count == 3
    # Pending -> Running -> Looped (1) -> Running -> Looped (2) -> Running -> Success
    assert client.set_task_run_state.call_count == 6
    recorded_versions = [
        c[1]["version"]
        for c in client.set_task_run_state.call_args_list
        if c[1]["version"]
    ]
    assert recorded_versions == [1, 2, 3]
def test_cloud_task_runner_handles_retries_with_queued_states_from_cloud(
        client):
    """When Cloud answers a retry's state-set with Queued, the runner waits
    and re-runs, eventually succeeding; verifies the full state sequence."""
    calls = []

    def queued_mock(*args, **kwargs):
        # record every state-set; Cloud "accepts" each state except the 4th,
        # which it converts to Queued
        calls.append(kwargs)
        # first retry attempt will get queued
        if len(calls) == 4:
            return Queued()  # immediate start time
        else:
            return kwargs.get("state")

    client.set_task_run_state = queued_mock

    @prefect.task(
        max_retries=2,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def tagged_task(x):
        # fail only on the first run so exactly one retry happens
        if prefect.context.get("task_run_count", 1) == 1:
            raise ValueError("gimme a sec")
        return x

    upstream_result = PrefectResult(value=42, location="42")
    res = CloudTaskRunner(task=tagged_task).run(
        context={"task_run_version": 1},
        state=None,
        upstream_states={
            Edge(Task(), tagged_task, key="x"): Success(result=upstream_result)
        },
    )

    assert res.is_successful()
    assert res.result == 42
    assert (len(calls) == 6
            )  # Running -> Failed -> Retrying -> Queued -> Running -> Success
    assert [type(c["state"]).__name__ for c in calls] == [
        "Running",
        "Failed",
        "Retrying",
        "Running",
        "Running",
        "Success",
    ]
    # the Retrying state carried the upstream input for the re-run
    assert calls[2]["state"].cached_inputs["x"].value == 42
def test_serialize_and_deserialize_on_mixed_cached_state():
    """Round-tripping a Cached state through serialize/deserialize keeps the
    result location and expiration but strips cached-input values."""
    location = json.dumps({"hi": 5, "bye": 6})
    expiration = pendulum.now("utc")
    original = Cached(
        cached_inputs={"x": PrefectResult(value=2), "p": PrefectResult(value="p")},
        result=PrefectResult(location=location),
        cached_result_expiration=expiration,
    )

    round_tripped = State.deserialize(original.serialize())

    assert isinstance(round_tripped, Cached)
    assert round_tripped.color == original.color
    assert round_tripped._result.location == location
    assert round_tripped.cached_result_expiration == original.cached_result_expiration
    # input values do not survive the round trip — only empty results remain
    assert round_tripped.cached_inputs == {"x": PrefectResult(), "p": PrefectResult()}
def test_deep_map_with_a_failure(monkeypatch, executor):
    """Map three levels deep where one mapped child fails: the flow run fails,
    the mapped parents are Mapped, and the failure propagates down the same
    map index through downstream mapped tasks.

    Fix: the ``executor`` fixture was accepted but never forwarded to
    ``run()``, so the test always exercised the default executor regardless of
    parametrization; it is now passed through (matching ``test_deep_map``).
    """
    flow_run_id = str(uuid.uuid4())
    task_run_id_1 = str(uuid.uuid4())
    task_run_id_2 = str(uuid.uuid4())
    task_run_id_3 = str(uuid.uuid4())

    with prefect.Flow(name="test", result=PrefectResult()) as flow:
        t1 = plus_one.map([-1, 0, 1])
        t2 = invert_fail_once.map(t1)
        t3 = plus_one.map(t2)

    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[
            TaskRun(id=task_run_id_1, task_slug=flow.slugs[t1], flow_run_id=flow_run_id),
            TaskRun(id=task_run_id_2, task_slug=flow.slugs[t2], flow_run_id=flow_run_id),
            TaskRun(id=task_run_id_3, task_slug=flow.slugs[t3], flow_run_id=flow_run_id),
        ]
        + [
            TaskRun(id=str(uuid.uuid4()), task_slug=flow.slugs[t], flow_run_id=flow_run_id)
            for t in flow.tasks
            if t not in [t1, t2, t3]
        ],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        state = CloudFlowRunner(flow=flow).run(
            return_tasks=flow.tasks, executor=executor
        )

    assert state.is_failed()
    assert client.flow_runs[flow_run_id].state.is_failed()
    assert client.task_runs[task_run_id_1].state.is_mapped()
    assert client.task_runs[task_run_id_2].state.is_mapped()
    assert client.task_runs[task_run_id_3].state.is_mapped()

    # there should be a total of 4 task runs corresponding to each mapped task
    # (1 parent run + 3 children)
    for t in [t1, t2, t3]:
        assert (len([
            tr for tr in client.task_runs.values()
            if tr.task_slug == flow.slugs[t]
        ]) == 4)

    # t2's first child task should have failed
    t2_0 = next(tr for tr in client.task_runs.values()
                if tr.task_slug == flow.slugs[t2] and tr.map_index == 0)
    assert t2_0.state.is_failed()

    # t3's first child task should have failed
    t3_0 = next(tr for tr in client.task_runs.values()
                if tr.task_slug == flow.slugs[t3] and tr.map_index == 0)
    assert t3_0.state.is_failed()
def test_deep_map(monkeypatch, executor):
    """Map three levels deep with all children succeeding: flow and all three
    mapped parents end Mapped/Successful with 4 task runs each (parent + 3)."""
    flow_run_id = str(uuid.uuid4())
    task_run_id_1 = str(uuid.uuid4())
    task_run_id_2 = str(uuid.uuid4())
    task_run_id_3 = str(uuid.uuid4())

    with prefect.Flow(name="test", result=PrefectResult()) as flow:
        t1 = plus_one.map([0, 1, 2])
        t2 = plus_one.map(t1)
        t3 = plus_one.map(t2)

    # pre-register one TaskRun per task (plus one per remaining flow task)
    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[
            TaskRun(
                id=task_run_id_1, task_slug=flow.slugs[t1], flow_run_id=flow_run_id
            ),
            TaskRun(
                id=task_run_id_2, task_slug=flow.slugs[t2], flow_run_id=flow_run_id
            ),
            TaskRun(
                id=task_run_id_3, task_slug=flow.slugs[t3], flow_run_id=flow_run_id
            ),
        ]
        + [
            TaskRun(
                id=str(uuid.uuid4()), task_slug=flow.slugs[t], flow_run_id=flow_run_id
            )
            for t in flow.tasks
            if t not in [t1, t2, t3]
        ],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        state = CloudFlowRunner(flow=flow).run(
            return_tasks=flow.tasks, executor=executor
        )

    assert state.is_successful()
    assert client.flow_runs[flow_run_id].state.is_successful()
    assert client.task_runs[task_run_id_1].state.is_mapped()
    assert client.task_runs[task_run_id_2].state.is_mapped()
    assert client.task_runs[task_run_id_3].state.is_mapped()

    # there should be a total of 4 task runs corresponding to each mapped task
    for t in [t1, t2, t3]:
        assert (
            len(
                [
                    tr
                    for tr in client.task_runs.values()
                    if tr.task_slug == flow.slugs[t]
                ]
            )
            == 4
        )
def test_task_runner_uses_upstream_result_handlers(client):
    """Upstream results are read with the *upstream* task's Result subclass,
    not the downstream task's, so the custom reader's value wins."""
    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.value = "cool"
            return self

        def write(self, *args, **kwargs):
            return self

    @prefect.task(result=PrefectResult())
    def t(x):
        return x

    upstream_edge = Edge(Task(result=MyResult()), t, key="x")
    final = CloudTaskRunner(task=t).run(
        upstream_states={upstream_edge: Success(result=PrefectResult(location="1"))}
    )

    assert final.is_successful()
    assert final.result == "cool"
def check_for_retry(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Checks to see if a FAILED task should be retried.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.

    Returns:
        - State: the state of the task after running the check
    """
    if state.is_failed():
        run_count = prefect.context.get("task_run_count", 1)
        # if the failure happened mid-LOOP, stash the loop bookkeeping in the
        # retry's cached_inputs so the retry resumes at the same iteration
        if prefect.context.get("task_loop_count") is not None:
            loop_result = self.result.from_value(
                value=prefect.context.get("task_loop_result")
            )
            ## checkpoint tasks if a result is present, except for when the user has opted out by disabling checkpointing
            if (
                prefect.context.get("checkpointing") is True
                and self.task.checkpoint is not False
                and loop_result.value is not None
            ):
                try:
                    value = prefect.context.get("task_loop_result")
                    loop_result = self.result.write(
                        value, filename="output", **prefect.context
                    )
                except NotImplementedError:
                    # result type has no write support; keep the in-memory value
                    pass
            loop_context = {
                "_loop_count": PrefectResult(
                    location=json.dumps(prefect.context["task_loop_count"]),
                ),
                "_loop_result": loop_result,
            }
            inputs.update(loop_context)
        # run_count counts attempts so far; retry while attempts remain
        if run_count <= self.task.max_retries:
            start_time = pendulum.now("utc") + self.task.retry_delay
            msg = "Retrying Task (after attempt {n} of {m})".format(
                n=run_count, m=self.task.max_retries + 1
            )
            retry_state = Retrying(
                start_time=start_time,
                cached_inputs=inputs,
                message=msg,
                run_count=run_count,
            )
            return retry_state
    return state
def test_cloud_task_runner_sends_heartbeat_on_queued_retries(client):
    """A heartbeat is sent for each run attempt (initial + queued retry),
    carrying the task_run_id from context."""
    calls = []
    tr_ids = []

    def queued_mock(*args, **kwargs):
        # Cloud "accepts" every state-set except the 4th, which gets Queued
        calls.append(kwargs)
        # first retry attempt will get queued
        if len(calls) == 4:
            return Queued()  # immediate start time
        else:
            return kwargs.get("state")

    def mock_heartbeat(**kwargs):
        tr_ids.append(kwargs.get("task_run_id"))

    client.set_task_run_state = queued_mock
    client.update_task_run_heartbeat = mock_heartbeat

    @prefect.task(
        max_retries=2,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def tagged_task(x):
        # fail only on the first run so exactly one retry happens
        if prefect.context.get("task_run_count", 1) == 1:
            raise ValueError("gimme a sec")
        return x

    upstream_result = PrefectResult(value=42, location="42")
    CloudTaskRunner(task=tagged_task).run(
        context={
            "task_run_version": 1,
            "task_run_id": "id"
        },
        state=None,
        upstream_states={
            Edge(Task(), tagged_task, key="x"): Success(result=upstream_result)
        },
    )

    assert len(calls) == 6
    # one heartbeat per attempt, both tagged with the context's task_run_id
    assert tr_ids == ["id", "id"]
def test_load_results_from_upstream_reads_results_using_upstream_handlers(
        self, cloud_api):
    """load_results defers to the upstream task's Result subclass when
    hydrating upstream values."""
    class CustomResult(Result):
        def read(self, *args, **kwargs):
            return "foo-bar-baz".split("-")

    upstream_edge = Edge(Task(result=CustomResult()), 2, key="x")
    upstream_state = Success(result=PrefectResult(location="1"))

    new_state, upstreams = CloudTaskRunner(task=Task()).load_results(
        state=Pending(), upstream_states={upstream_edge: upstream_state}
    )

    assert upstreams[upstream_edge].result == ["foo", "bar", "baz"]
def test_state_kwarg_is_prioritized_over_db_caches(self, client):
    """A Cached state handed directly to the runner wins over Cached states
    fetched from the Cloud database."""
    task = Task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=duration_only,
        result=PrefectResult(),
    )
    db_cached = Cached(
        result=PrefectResult(location="2"),
        cached_result_expiration=pendulum.now("utc").add(minutes=1),
    )
    provided = Cached(
        result=PrefectResult(location="99"),
        cached_result_expiration=pendulum.now("utc").add(minutes=1),
    )
    client.get_latest_cached_states = MagicMock(return_value=[db_cached])

    checked = CloudTaskRunner(task).check_task_is_cached(
        state=provided, inputs={"a": Result(1)}
    )

    assert checked is provided
    assert checked.result == 99
def test_task_runner_uses_cached_inputs_from_db_state(monkeypatch):
    """When run without explicit upstream states, the runner pulls the latest
    state from Cloud and uses its cached_inputs (x=41) to compute 42."""
    @prefect.task(name="test", result=PrefectResult())
    def add_one(x):
        return x + 1

    # Cloud reports a Retrying state carrying the input to resume with
    db_state = Retrying(cached_inputs=dict(x=PrefectResult(value=41)))
    get_task_run_info = MagicMock(return_value=MagicMock(state=db_state))
    # echo back whatever state is set, as the real Client does
    set_task_run_state = MagicMock(
        side_effect=lambda task_run_id, version, state, cache_for: state)
    client = MagicMock(get_task_run_info=get_task_run_info,
                       set_task_run_state=set_task_run_state)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client",
                        MagicMock(return_value=client))
    res = CloudTaskRunner(task=add_one).run(context={"map_index": 1})

    ## assertions
    assert get_task_run_info.call_count == 1  # one time to pull latest state
    assert set_task_run_state.call_count == 2  # Pending -> Running -> Success
    assert res.is_successful()
    assert res.result == 42
def __init__(
    self,
    name: str,
    default: JSONSerializableParameterValue = no_default,
    required: bool = True,
    tags: Iterable[str] = None,
) -> None:
    """Create a datetime-valued parameter.

    Args:
        - name (str): the parameter name
        - default: an optional default value; `no_default` means none was given
        - required (bool): whether a value must be supplied at runtime
        - tags (Iterable[str]): optional tags to attach to the parameter
    """
    super().__init__(name=name, default=default, required=required, tags=tags)
    # plain JSON cannot encode datetimes, so force a result type whose
    # serializer can round-trip them
    self.result = PrefectResult(serializer=DateTimeSerializer())
def __init__(
    self,
    name: str,
    required: bool = True,
    tags: Iterable[str] = None,
) -> None:
    """Create a datetime parameter with no user-specifiable default.

    Args:
        - name (str): the parameter name
        - required (bool): if False, the parameter defaults to None
        - tags (Iterable[str]): optional tags to attach to the parameter
    """
    # optional parameters default to None; required ones have no default at all
    default = no_default if required else None
    super().__init__(name=name, default=default, required=required, tags=tags)
    # plain JSON cannot encode datetimes, so force a result type whose
    # serializer can round-trip them
    self.result = PrefectResult(serializer=DateTimeSerializer())
def test_load_results_from_upstream_reads_secret_results(self, cloud_api):
    """SecretResult upstreams are re-read from the secret store rather than
    from the serialized state's location."""
    secret_handler = SecretResult(
        prefect.tasks.secrets.PrefectSecret(name="foo"))
    upstream_state = Success(result=PrefectResult(location="foo"))

    with prefect.context(secrets=dict(foo=42)):
        edge = Edge(Task(result=secret_handler), 2, key="x")
        new_state, upstreams = CloudTaskRunner(task=Task()).load_results(
            state=Pending(), upstream_states={edge: upstream_state}
        )

    assert upstreams[edge].result == 42
def test_task_runner_treats_unfound_files_as_invalid_caches(client, tmpdir):
    """A Cached state whose backing file is missing is skipped in favor of an
    older-but-readable cache entry."""
    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    missing_file_cache = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(minutes=2),
        result=LocalResult(location=str(tmpdir / "made_up_data.prefect")),
    )
    fallback_cache = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(days=1),
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[missing_file_cache, fallback_cache]
    )

    res = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    # the unreadable cache was rejected; the fallback's value was used
    assert res.result == 13
def test_task_runner_validates_cached_states_if_task_has_caching(client):
    """Expired Cached states from Cloud are rejected and the task re-runs,
    producing its freshly computed value."""
    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    expired_recent = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        - datetime.timedelta(minutes=2),
        result=PrefectResult(location="99"),
    )
    expired_old = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        - datetime.timedelta(days=1),
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[expired_recent, expired_old]
    )

    res = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    # neither stale cache was used — the task actually ran
    assert res.result == 42
def test_task_runner_preserves_location_of_inputs_when_retrying(
        self, client):
    """
    If a user opts out of checkpointing via checkpoint=False, we don't want to
    surprise them by storing the result in cached_inputs.  This test ensures
    that whatever location is provided to a downstream task is the one that is
    used.
    """
    @prefect.task(max_retries=1, retry_delay=datetime.timedelta(days=1))
    def add(x, y):
        return x + y

    # x has no location; y already has a location ("foo") that must survive
    x = PrefectResult(value=1)
    y = PrefectResult(value="0", location="foo")
    state = Pending(cached_inputs=dict(x=x, y=y))
    x_state = Success()
    y_state = Success()
    upstream_states = {
        Edge(Task(), Task(), key="x"): x_state,
        Edge(Task(), Task(), key="y"): y_state,
    }
    # "1" + "0" is a TypeError... wait, 1 + "0" raises, which drives the retry
    res = CloudTaskRunner(task=add).run(state=state,
                                        upstream_states=upstream_states)

    ## assertions
    assert client.get_task_run_info.call_count == 0  # never called
    assert (client.set_task_run_state.call_count == 3
            )  # Pending -> Running -> Failed -> Retrying

    states = [
        call[1]["state"] for call in client.set_task_run_state.call_args_list
    ]
    assert states[0].is_running()
    assert states[1].is_failed()
    assert isinstance(states[2], Retrying)
    # locations are preserved exactly as provided, never (re)computed
    assert states[2].cached_inputs["x"].location is None
    assert states[2].cached_inputs["y"].location == "foo"
def test_task_runner_validates_cached_state_inputs_if_task_has_caching_and_uses_task_handler(
    self, client
):
    """With an all_inputs cache validator, only the Cached state whose
    cached_inputs match is accepted, and its result is read through the
    task's own result handler (yielding 1337)."""
    class MyResult(Result):
        def read(self, *args, **kwargs):
            new = self.copy()
            new.value = 1337
            return new

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result=MyResult(),
    )
    def cached_task(x):
        return 42

    # candidate with no matching cached_inputs — should be rejected
    dull_state = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(minutes=2),
        result=PrefectResult(location="-1"),
    )
    # candidate whose cached_inputs match the provided inputs — should win
    state = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(minutes=2),
        result=PrefectResult(location="99"),
        cached_inputs={"x": PrefectResult(location="2")},
    )
    client.get_latest_cached_states = MagicMock(return_value=[dull_state, state])

    res = CloudTaskRunner(task=cached_task).check_task_is_cached(
        Pending(), inputs={"x": PrefectResult(value=2)}
    )
    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    # 1337, not 99: the task's MyResult handler performed the read
    assert res.result == 1337
async def test_set_task_run_state_with_result(self, run_query, task_run_id):
    """Setting a task run state that carries a result succeeds and bumps the
    stored version to 2.

    Fix: the local name ``result`` was reused for two unrelated values (first
    a PrefectResult, then the GraphQL response), shadowing the former; the
    response now has its own name.
    """
    state = Success(result=PrefectResult(location="10"))
    response = await run_query(
        query=self.mutation,
        variables=dict(input=dict(states=[
            dict(task_run_id=task_run_id, state=state.serialize())
        ])),
    )
    tr = await models.TaskRun.where(
        id=response.data.set_task_run_states.states[0].id
    ).first({"state", "version"})
    assert tr.version == 2
    assert tr.state == "Success"
def __init__(
    self,
    name: str,
    default: Any = no_default,
    required: bool = None,
    tags: Iterable[str] = None,
):
    """Create a parameter whose slug equals its name.

    Unless `required` is given explicitly, the parameter is required exactly
    when no default value was supplied.
    """
    if required is None:
        required = default is no_default
    self.required = required
    # normalize the no-default sentinel to None for storage
    self.default = None if default is no_default else default

    super().__init__(
        name=name,
        slug=name,
        tags=tags,
        result=PrefectResult(),
        checkpoint=True,
    )
def test_serializer_not_configurable(self):
    """PrefectResult always uses a JSONSerializer; any other serializer type
    is rejected with a TypeError."""
    # default construction provides its own JSONSerializer
    assert isinstance(PrefectResult().serializer, JSONSerializer)

    # an explicit JSONSerializer is accepted at construction...
    explicit = JSONSerializer()
    result = PrefectResult(serializer=explicit)
    assert result.serializer is explicit

    # ...and via attribute assignment
    replacement = JSONSerializer()
    result.serializer = replacement
    assert result.serializer is replacement

    # non-JSON serializers raise, both on assignment...
    with pytest.raises(TypeError):
        result.serializer = PickleSerializer()
    # ...and at construction
    with pytest.raises(TypeError):
        PrefectResult(serializer=PickleSerializer())
def test_task_runner_gracefully_handles_load_results_failures(client):
    """An exception raised while reading upstream results marks the task
    Failed (with a 'task results' message) instead of crashing the runner."""
    class MyResult(Result):
        def read(self, *args, **kwargs):
            raise TypeError("something is wrong!")

    @prefect.task(result=PrefectResult())
    def t(x):
        return x

    edge = Edge(Task(result=MyResult()), t, key="x")
    state = CloudTaskRunner(task=t).run(
        upstream_states={edge: Success(result=MyResult(location="foo.txt"))}
    )

    assert state.is_failed()
    assert "task results" in state.message
    # exactly one transition was recorded: Pending -> Failed
    assert client.set_task_run_state.call_count == 1
    set_states = [c[1]["state"] for c in client.set_task_run_state.call_args_list]
    assert [type(s).__name__ for s in set_states] == ["Failed"]
def test_reads_result_using_handler_attribute_if_cached_valid(
        self, client):
    """A valid Cached state's result is read back through the task's own
    result handler, whose read() supplies the value (53)."""
    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.value = 53
            return self

    with pytest.warns(UserWarning):
        task = Task(cache_validator=duration_only, result=MyResult())
    cached_state = Cached(
        result=PrefectResult(location="2"),
        cached_result_expiration=pendulum.now("utc").add(minutes=1),
    )
    client.get_latest_cached_states = MagicMock(return_value=[])

    checked = CloudTaskRunner(task).check_task_is_cached(
        state=cached_state, inputs={"a": Result(1)}
    )

    assert checked is cached_state
    assert checked.result == 53
def __init__(
    self,
    name=None,
    slug=None,
    tags=None,
    max_retries=None,
    retry_delay=None,
    timeout=None,
    trigger=None,
    skip_on_upstream_skip=True,
    cache_for=None,
    cache_validator=None,
    cache_key=None,
    checkpoint=None,
    result_handler=None,
    state_handlers=None,
    on_failure=None,
    log_stdout=False,
    result=None,
    target=None,
):
    """Initialize the task, forwarding all arguments to the parent
    constructor and defaulting the result type to a fresh PrefectResult.

    Fix: the previous default ``result=PrefectResult()`` was evaluated once at
    definition time and therefore *shared* by every instance constructed
    without an explicit ``result`` (the mutable-default-argument pitfall).
    A fresh PrefectResult is now created per call; passing ``result=None``
    likewise yields a fresh instance, matching the intent that this task
    type always carries a JSON result.
    """
    if result is None:
        result = PrefectResult()
    super().__init__(
        name=name,
        slug=slug,
        tags=tags,
        max_retries=max_retries,
        retry_delay=retry_delay,
        timeout=timeout,
        trigger=trigger,
        skip_on_upstream_skip=skip_on_upstream_skip,
        cache_for=cache_for,
        cache_validator=cache_validator,
        cache_key=cache_key,
        checkpoint=checkpoint,
        result_handler=result_handler,
        state_handlers=state_handlers,
        on_failure=on_failure,
        log_stdout=log_stdout,
        result=result,
        target=target,
    )
def test_reads_result_if_cached_valid_using_task_result(task, client):
    """A valid Cached state fetched from Cloud is returned and read through
    the task's configured result handler (yielding 53)."""
    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.value = 53
            return self

    # NOTE: intentionally shadows the `task` fixture with a locally-built Task
    task = Task(
        result=MyResult(),
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=duration_only,
    )
    db_state = Cached(
        result=PrefectResult(location="2"),
        cached_result_expiration=pendulum.now("utc").add(minutes=1),
    )
    client.get_latest_cached_states = MagicMock(return_value=[db_state])

    checked = CloudTaskRunner(task).check_task_is_cached(
        state=Pending(), inputs={"a": Result(1)}
    )

    assert checked is db_state
    assert checked.result == 53
def check_for_retry(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Checks to see if a FAILED task should be retried.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.

    Returns:
        - State: the state of the task after running the check
    """
    if state.is_failed():
        run_count = prefect.context.get("task_run_count", 1)
        # if the failure happened mid-LOOP, stash the loop bookkeeping in the
        # retry's cached_inputs so the retry resumes at the same iteration
        if prefect.context.get("task_loop_count") is not None:
            loop_context = {
                "_loop_count": PrefectResult(location=json.dumps(
                    prefect.context["task_loop_count"]),
                ),
                "_loop_result": self.result.from_value(
                    value=prefect.context.get("task_loop_result")),
            }
            inputs.update(loop_context)
        # run_count counts attempts so far; retry while attempts remain
        if run_count <= self.task.max_retries:
            start_time = pendulum.now("utc") + self.task.retry_delay
            msg = "Retrying Task (after attempt {n} of {m})".format(
                n=run_count, m=self.task.max_retries + 1)
            retry_state = Retrying(
                start_time=start_time,
                cached_inputs=inputs,
                message=msg,
                run_count=run_count,
            )
            return retry_state
    return state
def test_starting_at_arbitrary_loop_index_from_cloud_context(client):
    """A flow run whose Cloud context carries task_loop_count starts looping
    at that index — here at 20, so the LOOP body never fires."""
    @prefect.task
    def looper(x):
        accumulated = prefect.context.get("task_loop_result", 0) + x
        if prefect.context.get("task_loop_count", 1) < 20:
            raise LOOP(result=accumulated)
        return accumulated

    @prefect.task
    def downstream(value):
        return value ** 2

    with prefect.Flow(name="looping", result=PrefectResult()) as f:
        inter = looper(10)
        final = downstream(inter)

    client.get_flow_run_info = MagicMock(
        return_value=MagicMock(context={"task_loop_count": 20})
    )

    flow_state = CloudFlowRunner(flow=f).run(return_tasks=[inter, final])

    assert flow_state.is_successful()
    # loop index 20 means no iterations run: result is just 0 + 10
    assert flow_state.result[inter].result == 10
    assert flow_state.result[final].result == 100
def test_task_failure_with_upstream_secrets_doesnt_store_secret_value_and_recompute_if_necessary(
    client,
):
    """A Retrying state never persists an upstream secret's value (location
    stays None), and on resume the secret is recomputed from context."""
    @prefect.task(max_retries=2, retry_delay=timedelta(seconds=100))
    def is_p_three(p):
        if p == 3:
            raise ValueError("No thank you.")
        return p

    with prefect.Flow("test", result=PrefectResult()) as f:
        p = prefect.tasks.secrets.PrefectSecret("p")
        res = is_p_three(p)

    # first run: secret == 3 forces a failure and a Retrying state
    with prefect.context(secrets=dict(p=3)):
        state = CloudFlowRunner(flow=f).run(return_tasks=[res])

    assert state.is_running()
    assert isinstance(state.result[res], Retrying)
    # the secret's value must not have been written anywhere
    assert state.result[res].cached_inputs["p"].location is None

    ## here we set the result of the secret to an empty result, ensuring
    ## it will get converted to a "true" result;
    ## we expect that the upstream value will actually get recomputed from context
    ## through the SecretResultHandler
    safe = SecretResult(p)
    state.result[p] = Success(result=safe)
    # make the retry immediately eligible to run
    state.result[res].start_time = pendulum.now("utc")
    state.result[res].cached_inputs = dict(p=safe)

    # second run: secret == 4 is re-read from context and the task succeeds
    with prefect.context(secrets=dict(p=4)):
        new_state = CloudFlowRunner(flow=f).run(
            return_tasks=[res], task_states=state.result
        )

    assert new_state.is_successful()
    assert new_state.result[res].result == 4