def test_set_task_run_state_with_error(patch_post):
    response = {
        "data": {"setTaskRunStates": None},
        "errors": [{"message": "something went wrong"}],
    }
    post = patch_post(response)

    with set_temporary_config(
        {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
    ):
        client = Client()
        with pytest.raises(ClientError, match="something went wrong"):
            client.set_task_run_state(
                task_run_id="76-salt", version=0, state=Pending()
            )
def test_client_is_always_called_even_during_state_handler_failures(client):
    def handler(task, old, new):
        1 / 0

    flow = prefect.Flow(name="test", tasks=[prefect.Task()], state_handlers=[handler])

    ## flow run setup
    res = flow.run(state=Pending())

    ## assertions
    assert client.get_flow_run_info.call_count == 1  # one time to pull latest state
    assert client.set_flow_run_state.call_count == 1  # Failed

    flow_states = [
        call[1]["state"] for call in client.set_flow_run_state.call_args_list
    ]
    state = flow_states.pop()
    assert state.is_failed()
    assert "state handlers" in state.message
    assert isinstance(state.result, ZeroDivisionError)
    assert client.get_task_run_info.call_count == 0
def test_reads_result_if_cached_valid_using_task_result(task, client):
    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.value = 53
            return self

    task = Task(
        result=MyResult(),
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=duration_only,
    )
    state = Cached(
        result=PrefectResult(location="2"),
        cached_result_expiration=pendulum.now("utc").add(minutes=1),
    )

    client.get_latest_cached_states = MagicMock(return_value=[state])
    new = CloudTaskRunner(task).check_task_is_cached(
        state=Pending(), inputs={"a": Result(1)}
    )
    assert new is state
    assert new.result == 53
def test_set_flow_run_state_with_error(monkeypatch):
    response = {
        "data": {"setFlowRunState": None},
        "errors": [{"message": "something went wrong"}],
    }
    post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response)))
    monkeypatch.setattr("requests.post", post)

    with set_temporary_config(
        {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
    ):
        client = Client()
        with pytest.raises(ClientError) as exc:
            client.set_flow_run_state(
                flow_run_id="74-salt", version=0, state=Pending()
            )
    assert "something went wrong" in str(exc.value)
def check_task_is_cached(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Checks if task is cached and whether the cache is still valid.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.

    Returns:
        - State: the state of the task after running the check

    Raises:
        - ENDRUN: if the task is not ready to run
    """
    if self.task.cache_for is not None:
        candidate_states = prefect.context.caches.get(self.task.name, [])
        sanitized_inputs = {key: res.value for key, res in inputs.items()}
        for candidate in candidate_states:
            if self.task.cache_validator(
                candidate, sanitized_inputs, prefect.context.get("parameters")
            ):
                candidate._result = candidate._result.to_result()
                return candidate

    if state.is_cached():
        assert isinstance(state, Cached)  # mypy assert
        sanitized_inputs = {key: res.value for key, res in inputs.items()}
        if self.task.cache_validator(
            state, sanitized_inputs, prefect.context.get("parameters")
        ):
            state._result = state._result.to_result()
            return state
        else:
            self.logger.warning(
                "Task '{name}': can't use cache because it "
                "is now invalid".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )
            return Pending("Cache was invalid; ready to run.")

    return state
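
# The `cache_validator` calls above assume Prefect's three-argument validator
# contract: a validator receives the candidate Cached state, the sanitized
# (raw-value) inputs, and the run parameters, and returns True if the cache may
# be reused. A minimal sketch of a custom validator under that contract follows;
# the name `expiration_only_validator` is illustrative, not from the source.
import pendulum


def expiration_only_validator(candidate_state, sanitized_inputs, parameters) -> bool:
    # reuse the cache only while its recorded expiration is in the future
    expiration = candidate_state.cached_result_expiration
    return expiration is None or expiration > pendulum.now("utc")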
class TestCheckScheduledStep:
    @pytest.mark.parametrize("state", [Failed(), Pending(), Running(), Success()])
    def test_non_scheduled_states(self, state):
        assert (
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(
                state=state
            )
            is state
        )

    def test_scheduled_states_without_start_time(self):
        state = Scheduled(start_time=None)
        assert (
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(
                state=state
            )
            is state
        )

    def test_scheduled_states_with_future_start_time(self):
        state = Scheduled(
            start_time=pendulum.now("utc") + datetime.timedelta(minutes=10)
        )
        with pytest.raises(ENDRUN) as exc:
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(
                state=state
            )
        assert exc.value.state is state

    def test_scheduled_states_with_past_start_time(self):
        state = Scheduled(
            start_time=pendulum.now("utc") - datetime.timedelta(minutes=1)
        )
        assert (
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(
                state=state
            )
            is state
        )
async def test_setting_flow_run_to_cancelled_state_sets_unfinished_task_runs_to_cancelled(
    self, flow_run_id
):
    task_runs = await models.TaskRun.where(
        {"flow_run_id": {"_eq": flow_run_id}}
    ).get({"id"})
    task_run_ids = [run.id for run in task_runs]

    # update the state to Running
    await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=Running())

    # Currently this flow_run_id fixture has at least 3 tasks; if this
    # changes, the test will need to be updated
    assert len(task_run_ids) >= 3, "flow_run_id fixture has changed"

    # Set one task run to pending, one to running, and the rest to success
    pending_task_run = task_run_ids[0]
    running_task_run = task_run_ids[1]
    rest = task_run_ids[2:]
    await api.states.set_task_run_state(task_run_id=pending_task_run, state=Pending())
    await api.states.set_task_run_state(task_run_id=running_task_run, state=Running())
    for task_run_id in rest:
        await api.states.set_task_run_state(task_run_id=task_run_id, state=Success())

    # set the flow run to a cancelled state
    await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=Cancelled())

    # Confirm the unfinished task runs have been marked as cancelled
    task_runs = await models.TaskRun.where(
        {"flow_run_id": {"_eq": flow_run_id}}
    ).get({"id", "state"})
    new_states = {run.id: run.state for run in task_runs}

    assert new_states[pending_task_run] == "Cancelled"
    assert new_states[running_task_run] == "Cancelled"
    assert all(new_states[id] == "Success" for id in rest)
def test_set_task_run_state(patch_post):
    response = {"data": {"set_task_run_states": {"states": [{"status": "SUCCESS"}]}}}
    post = patch_post(response)
    state = Pending()

    with set_temporary_config(
        {"cloud.api": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
    ):
        client = Client()
        result = client.set_task_run_state(
            task_run_id="76-salt", version=0, state=state
        )

    assert result is state
def test_set_task_run_state_serializes(patch_post):
    response = {"data": {"set_task_run_states": {"states": [{"status": "SUCCESS"}]}}}
    post = patch_post(response)

    with set_temporary_config(
        {"cloud.api": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
    ):
        client = Client()
        res = SafeResult(lambda: None, result_handler=None)
        with pytest.raises(marshmallow.exceptions.ValidationError):
            client.set_task_run_state(
                task_run_id="76-salt", version=0, state=Pending(result=res)
            )
def test_task_runner_prioritizes_kwarg_states_over_db_states(monkeypatch, state):
    task = Task(name="test")
    db_state = state("already", result=10)
    get_task_run_info = MagicMock(return_value=MagicMock(state=db_state))
    set_task_run_state = MagicMock(
        side_effect=lambda task_run_id, version, state, cache_for: state
    )
    client = MagicMock(
        get_task_run_info=get_task_run_info, set_task_run_state=set_task_run_state
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client", MagicMock(return_value=client)
    )
    res = CloudTaskRunner(task=task).run(
        state=Pending("let's do this"), context={"map_index": 1}
    )

    ## assertions
    assert get_task_run_info.call_count == 1  # one time to pull latest state
    assert set_task_run_state.call_count == 2  # Pending -> Running -> Success

    states = [call[1]["state"] for call in set_task_run_state.call_args_list]
    assert [type(s).__name__ for s in states] == ["Running", "Success"]
def initialize_run(
    self, state: Optional[State], context: Dict[str, Any]
) -> Tuple[State, Dict[str, Any]]:
    """
    Initializes the Task run by initializing state and context appropriately.

    If the provided state is a meta state, the state it wraps is extracted.

    Args:
        - state (Optional[State]): the initial state of the run
        - context (dict): the context to be updated with relevant information

    Returns:
        - tuple: a tuple of the updated state and context objects
    """
    # extract possibly nested meta states -> for example a Submitted( Queued( Retry ) )
    while isinstance(state, State) and state.is_meta_state():
        state = state.state  # type: ignore

    state = state or Pending()
    return state, context
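
# A minimal sketch of the meta-state unwrapping performed by initialize_run:
# nested meta states (here Submitted wrapping Queued) unwrap down to the
# innermost concrete state.
from prefect.engine.state import Queued, Retrying, State, Submitted

state = Submitted(state=Queued(state=Retrying()))
while isinstance(state, State) and state.is_meta_state():
    state = state.state
assert isinstance(state, Retrying)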
def test_set_flow_run_state(patch_post):
    response = {
        "data": {
            "set_flow_run_states": {
                "states": [{"id": 1, "status": "SUCCESS", "message": None}]
            }
        }
    }
    post = patch_post(response)

    with set_temporary_config(
        {
            "cloud.api": "http://my-cloud.foo",
            "cloud.auth_token": "secret_token",
            "backend": "cloud",
        }
    ):
        client = Client()
        state = Pending()
        result = client.set_flow_run_state(
            flow_run_id="74-salt", version=0, state=state
        )

    assert isinstance(result, State)
    assert isinstance(result, Pending)
@pytest.fixture
def client(monkeypatch):
    cloud_client = MagicMock(
        get_flow_run_info=MagicMock(
            return_value=MagicMock(state=Pending(), parameters={})
        ),
        set_flow_run_state=MagicMock(
            side_effect=lambda flow_run_id, version, state: state
        ),
        get_task_run_info=MagicMock(return_value=MagicMock(state=None)),
        set_task_run_state=MagicMock(
            side_effect=lambda task_run_id, version, state, cache_for: state
        ),
        get_latest_task_run_states=MagicMock(
            side_effect=lambda flow_run_id, states, result_handler: states
        ),
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client", MagicMock(return_value=cloud_client)
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.flow_runner.Client", MagicMock(return_value=cloud_client)
    )
    yield cloud_client
def test_task_runner_preserves_location_of_inputs_when_retrying(self, client):
    """
    If a user opts out of checkpointing via checkpoint=False, we don't want to
    surprise them by storing the result in cached_inputs. This test ensures that
    whatever location is provided to a downstream task is the one that is used.
    """

    @prefect.task(max_retries=1, retry_delay=datetime.timedelta(days=1))
    def add(x, y):
        return x + y

    x = PrefectResult(value=1)
    y = PrefectResult(value="0", location="foo")
    state = Pending(cached_inputs=dict(x=x, y=y))
    x_state = Success()
    y_state = Success()
    upstream_states = {
        Edge(Task(), Task(), key="x"): x_state,
        Edge(Task(), Task(), key="y"): y_state,
    }
    res = CloudTaskRunner(task=add).run(state=state, upstream_states=upstream_states)

    ## assertions
    assert client.get_task_run_info.call_count == 0  # never called
    assert (
        client.set_task_run_state.call_count == 3
    )  # Pending -> Running -> Failed -> Retrying

    states = [call[1]["state"] for call in client.set_task_run_state.call_args_list]
    assert states[0].is_running()
    assert states[1].is_failed()
    assert isinstance(states[2], Retrying)
    assert states[2].cached_inputs["x"].location is None
    assert states[2].cached_inputs["y"].location == "foo"
def test_task_runner_handles_looping_with_retries(client):
    # note that looping _requires_ a result handler in Cloud
    @prefect.task(
        max_retries=1,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def looper():
        if (
            prefect.context.get("task_loop_count") == 2
            and prefect.context.get("task_run_count", 1) == 1
        ):
            raise ValueError("Stop")
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP(result=prefect.context.get("task_loop_result", 0) + 10)
        return prefect.context.get("task_loop_result")

    client.get_task_run_info.side_effect = [
        MagicMock(version=i, state=Pending() if i == 0 else Looped(loop_count=i))
        for i in range(5)
    ]

    res = CloudTaskRunner(task=looper).run(
        context={"task_run_version": 1}, state=None, upstream_states={}
    )

    ## assertions
    assert res.is_successful()
    assert client.get_task_run_info.call_count == 4
    assert (
        client.set_task_run_state.call_count == 9
    )  # Pending -> Running -> Looped (1) -> Running -> Failed -> Retrying -> Running -> Looped (2) -> Running -> Success

    versions = [
        call[1]["version"]
        for call in client.set_task_run_state.call_args_list
        if call[1]["version"]
    ]
    assert versions == [1, 2, 3]
def test_task_runner_validates_cached_state_inputs_if_task_has_caching_and_uses_task_handler(
    self, client
):
    class MyResult(Result):
        def read(self, *args, **kwargs):
            new = self.copy()
            new.value = 1337
            return new

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result=MyResult(),
    )
    def cached_task(x):
        return 42

    dull_state = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(minutes=2),
        result=PrefectResult(location="-1"),
    )
    state = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(minutes=2),
        result=PrefectResult(location="99"),
        cached_inputs={"x": PrefectResult(location="2")},
    )
    client.get_latest_cached_states = MagicMock(return_value=[dull_state, state])

    res = CloudTaskRunner(task=cached_task).check_task_is_cached(
        Pending(), inputs={"x": PrefectResult(value=2)}
    )
    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 1337
def test_task_runner_validates_cached_state_inputs_if_task_has_caching_and_uses_task_handler(
    client,
):
    class Handler(ResultHandler):
        def read(self, val):
            return 1337

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result_handler=Handler(),
    )
    def cached_task(x):
        return 42

    dull_state = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(minutes=2),
        result=SafeResult("-1", JSONResultHandler()),
    )
    state = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(minutes=2),
        result=SafeResult("99", JSONResultHandler()),
        cached_inputs={"x": SafeResult("2", result_handler=JSONResultHandler())},
    )
    client.get_latest_cached_states = MagicMock(return_value=[dull_state, state])

    res = CloudTaskRunner(task=cached_task).check_task_is_cached(
        Pending(), inputs={"x": Result(2, result_handler=JSONResultHandler())}
    )
    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 1337
def test_set_task_run_state_responds_to_status(patch_post):
    response = {"data": {"set_task_run_states": {"states": [{"status": "QUEUED"}]}}}
    post = patch_post(response)
    state = Pending()

    with set_temporary_config(
        {"cloud.api": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
    ):
        client = Client()
        result = client.set_task_run_state(
            task_run_id="76-salt", version=0, state=state
        )

    assert result.is_queued()
    assert result.state is None  # caller should set this
class TestTaskRunStates:
    async def test_set_task_run_state(self, task_run_id):
        result = await api.states.set_task_run_state(
            task_run_id=task_run_id, state=Failed()
        )

        assert result.task_run_id == task_run_id

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"}
        )

        assert query.version == 2
        assert query.state == "Failed"
        assert query.serialized_state["type"] == "Failed"

    @pytest.mark.parametrize("state", [Failed(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(self, state):
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_task_run_state(
                task_run_id=str(uuid.uuid4()), state=state
            )

    @pytest.mark.parametrize(
        "state", [s() for s in State.children() if not s().is_running()]
    )
    async def test_state_does_not_set_heartbeat_unless_running(
        self, state, task_run_id
    ):
        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

    async def test_running_state_sets_heartbeat(
        self, task_run_id, running_flow_run_id
    ):
        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_task_run_state(task_run_id=task_run_id, state=Running())

        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat > dt

    async def test_trigger_failed_state_does_not_set_end_time(self, task_run_id):
        await api.states.set_task_run_state(
            task_run_id=task_run_id, state=TriggerFailed()
        )
        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"}
        )
        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "state",
        [s() for s in State.children() if s not in _MetaState.children()],
        ids=[s.__name__ for s in State.children() if s not in _MetaState.children()],
    )
    async def test_setting_a_task_run_state_pulls_cached_inputs_if_possible(
        self, task_run_id, state, running_flow_run_id
    ):
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Failed(cached_inputs=complex_result)
        await models.TaskRun.where(id=task_run_id).update(
            set=dict(serialized_state=cached_state.serialize())
        )

        # attempt to set the task run to the new state
        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        task_run = await models.TaskRun.where(id=task_run_id).first(
            {"serialized_state"}
        )

        # ensure the state change took place
        assert task_run.serialized_state["type"] == type(state).__name__
        assert task_run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert task_run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state",
        [
            s(cached_inputs=None)
            for s in State.children()
            if s not in _MetaState.children()
        ],
        ids=[s.__name__ for s in State.children() if s not in _MetaState.children()],
    )
    async def test_task_runs_with_null_cached_inputs_do_not_overwrite_cache(
        self, state, task_run_id, running_flow_run_id
    ):
        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(
            task_run_id=task_run_id, state=cached_state
        )

        run = await models.TaskRun.where(id=task_run_id).first({"serialized_state"})

        assert run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_task_runs_cached_inputs_give_preference_to_new_cached_inputs(
        self, state_cls, task_run_id, running_flow_run_id
    ):
        # set up an initial state with its own cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"a": 2}, result_handler=JSONResultHandler())
        complex_result = {"b": res1, "c": res2}
        cached_state = state_cls(cached_inputs=complex_result)
        await api.states.set_task_run_state(
            task_run_id=task_run_id, state=cached_state
        )

        # set up a Retrying state with different, non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(
            task_run_id=task_run_id, state=cached_state
        )

        run = Box(
            await models.TaskRun.where(id=task_run_id).first({"serialized_state"})
        )

        # verify that we have cached inputs, and that preference has been given
        # to the new cached inputs
        assert run.serialized_state.cached_inputs
        assert run.serialized_state.cached_inputs.x.value == 1
        assert run.serialized_state.cached_inputs.y.value == {"z": 2}

    @pytest.mark.parametrize(
        "flow_run_state", [Pending(), Running(), Failed(), Success()]
    )
    async def test_running_states_can_not_be_set_if_flow_run_is_not_running(
        self, flow_run_id, task_run_id, flow_run_state
    ):
        await api.states.set_flow_run_state(
            flow_run_id=flow_run_id, state=flow_run_state
        )

        set_running_coroutine = api.states.set_task_run_state(
            task_run_id=task_run_id, state=Running()
        )

        if flow_run_state.is_running():
            assert await set_running_coroutine
            assert (
                await models.TaskRun.where(id=task_run_id).first({"state"})
            ).state == "Running"
        else:
            with pytest.raises(ValueError, match="is not in a running state"):
                await set_running_coroutine
            assert (
                await models.TaskRun.where(id=task_run_id).first({"state"})
            ).state != "Running"
def test_task_map_with_no_upstream_results_and_a_mapped_state(executor):
    """
    This test makes sure that mapped tasks properly generate children tasks even
    when run multiple times and without available upstream results. In this test,
    we run the pipeline from a variety of starting points, ensuring that some
    upstream results are unavailable and checking that children pipelines are
    properly regenerated.
    """

    @prefect.task
    def numbers():
        return [1, 2, 3]

    @prefect.task
    def plus_one(x):
        return x + 1

    @prefect.task
    def get_sum(x):
        return sum(x)

    with Flow(name="test") as f:
        n = numbers()
        x = plus_one.map(n)
        y = plus_one.map(x)
        s = get_sum(y)

    # first run with a missing result from `n` but map_states for `x`
    state = FlowRunner(flow=f).run(
        executor=executor,
        task_states={
            n: Success(),
            x: Mapped(
                map_states=[
                    Pending(cached_inputs={"x": Result(i)}) for i in range(1, 4)
                ]
            ),
        },
        return_tasks=f.tasks,
    )
    assert state.is_successful()
    assert state.result[s].result == 12

    # next run with missing results for n and x
    state = FlowRunner(flow=f).run(
        executor=executor,
        task_states={
            n: Success(),
            x: Mapped(map_states=[Success(), Success(), Success()]),
            y: Mapped(
                map_states=[
                    Success(result=3),
                    Success(result=4),
                    Retrying(cached_inputs={"x": Result(4)}),
                ]
            ),
        },
        return_tasks=f.tasks,
    )
    assert state.is_successful()
    assert state.result[s].result == 12

    # next run with missing results for n, x, and y
    state = FlowRunner(flow=f).run(
        executor=executor,
        task_states={
            n: Success(),
            x: Mapped(map_states=[Success(), Success(), Success()]),
            y: Mapped(
                map_states=[Success(result=3), Success(result=4), Success(result=5)]
            ),
        },
        return_tasks=f.tasks,
    )
    assert state.is_successful()
    assert state.result[s].result == 12
def test_states_are_hashable():
    assert {State(), Pending(), Success()}
def get_flow_run_state(
    self,
    state: State,
    task_states: Dict[Task, State],
    task_contexts: Dict[Task, Dict[str, Any]],
    return_tasks: Set[Task],
    task_runner_state_handlers: Iterable[Callable],
    executor: "prefect.executors.base.Executor",
) -> State:
    """
    Runs the flow.

    Args:
        - state (State): starting state for the Flow. Defaults to `Pending`
        - task_states (dict): dictionary of task states to begin computation with,
            with keys being Tasks and values their corresponding state
        - task_contexts (Dict[Task, Dict[str, Any]]): contexts that will be provided
            to each task
        - return_tasks ([Task], optional): list of Tasks to include in the final
            returned Flow state. Defaults to `None`
        - task_runner_state_handlers (Iterable[Callable]): A list of state change
            handlers that will be provided to the task_runner, and called whenever
            a task changes state.
        - executor (Executor): executor to use when performing computation; defaults
            to the executor provided in your prefect configuration

    Returns:
        - State: `State` representing the final post-run state of the `Flow`.
    """
    # this dictionary is used for tracking the states of "children" mapped tasks;
    # when running on Dask, we want to avoid serializing futures, so instead of
    # storing child task states in the `map_states` attribute we store them in
    # this dictionary, and only after they are resolved do we attach them to the
    # Mapped state
    mapped_children = dict()  # type: Dict[Task, list]

    if not state.is_running():
        self.logger.info("Flow is not in a Running state.")
        raise ENDRUN(state)

    if return_tasks is None:
        return_tasks = set()
    if set(return_tasks).difference(self.flow.tasks):
        raise ValueError("Some tasks in return_tasks were not found in the flow.")

    def extra_context(task: Task, task_index: int = None) -> dict:
        return {
            "task_name": task.name,
            "task_tags": task.tags,
            "task_index": task_index,
        }

    # -- process each task in order

    with self.check_for_cancellation(), executor.start():

        for task in self.flow.sorted_tasks():
            task_state = task_states.get(task)

            # if a task is a constant task, we already know its return value;
            # no need to use up resources by running it through a task runner
            if task_state is None and isinstance(
                task, prefect.tasks.core.constants.Constant
            ):
                task_states[task] = task_state = Success(result=task.value)

            # Always restart completed resource setup/cleanup tasks and
            # secret tasks unless they were explicitly cached.
            # TODO: we only need to rerun these tasks if any pending
            # downstream tasks depend on them.
            if (
                isinstance(
                    task,
                    (
                        prefect.tasks.core.resource_manager.ResourceSetupTask,
                        prefect.tasks.core.resource_manager.ResourceCleanupTask,
                        prefect.tasks.secrets.SecretBase,
                    ),
                )
                and task_state is not None
                and task_state.is_finished()
                and not task_state.is_cached()
            ):
                task_states[task] = task_state = Pending()

            # if the state is finished, don't run the task, just use the provided
            # state; if the state is cached / mapped, we still want to run the task
            # runner pipeline steps to either ensure the cache is still valid or to
            # recreate the mapped pipeline for possible retries
            if (
                isinstance(task_state, State)
                and task_state.is_finished()
                and not task_state.is_cached()
                and not task_state.is_mapped()
            ):
                continue

            upstream_states = {}  # type: Dict[Edge, State]

            # this dictionary is used exclusively for "reduce" tasks; in particular
            # we store the states / futures corresponding to the upstream children,
            # and if running on Dask, let Dask resolve them at the appropriate time.
            # Note: this is an optimization that allows Dask to resolve the mapped
            # dependencies by "elevating" them to a function argument.
            upstream_mapped_states = {}  # type: Dict[Edge, list]

            # -- process each edge to the task
            for edge in self.flow.edges_to(task):

                # load the upstream task states (supplying Pending as a default)
                upstream_states[edge] = task_states.get(
                    edge.upstream_task, Pending(message="Task state not available.")
                )

                # if the edge is flattened and not the result of a map, then we
                # preprocess the upstream states. If it IS the result of a
                # map, it will be handled in `prepare_upstream_states_for_mapping`
                if edge.flattened:
                    if not isinstance(upstream_states[edge], Mapped):
                        upstream_states[edge] = executor.submit(
                            executors.flatten_upstream_state, upstream_states[edge]
                        )

                # this checks whether the task is a "reduce" task for a mapped
                # pipeline and if so, collects the appropriate upstream children
                if not edge.mapped and isinstance(upstream_states[edge], Mapped):
                    children = mapped_children.get(edge.upstream_task, [])

                    # if the edge is flattened, then we need to wait for the mapped
                    # children to complete and then flatten them
                    if edge.flattened:
                        children = executors.flatten_mapped_children(
                            mapped_children=children, executor=executor
                        )

                    upstream_mapped_states[edge] = children

            # augment edges with upstream constants
            for key, val in self.flow.constants[task].items():
                edge = Edge(
                    upstream_task=prefect.tasks.core.constants.Constant(val),
                    downstream_task=task,
                    key=key,
                )
                upstream_states[edge] = Success(
                    "Auto-generated constant value",
                    result=ConstantResult(value=val),
                )

            # handle mapped tasks
            if any(edge.mapped for edge in upstream_states.keys()):

                # wait on upstream states to determine the width of the pipeline;
                # this is the key to depth-first execution
                upstream_states = executor.wait(
                    {e: state for e, state in upstream_states.items()}
                )

                # we submit the task to the task runner to determine if
                # we can proceed with mapping - if the new task state is not a
                # Mapped state then we don't proceed
                task_states[task] = executor.wait(
                    executor.submit(
                        run_task,
                        task=task,
                        state=task_state,  # original state
                        upstream_states=upstream_states,
                        context=dict(prefect.context, **task_contexts.get(task, {})),
                        flow_result=self.flow.result,
                        task_runner_cls=self.task_runner_cls,
                        task_runner_state_handlers=task_runner_state_handlers,
                        upstream_mapped_states=upstream_mapped_states,
                        is_mapped_parent=True,
                        extra_context=extra_context(task),
                    )
                )

                # either way, we should now have enough resolved states to
                # restructure the upstream states into a list of upstream state
                # dictionaries to iterate over
                list_of_upstream_states = (
                    executors.prepare_upstream_states_for_mapping(
                        task_states[task],
                        upstream_states,
                        mapped_children,
                        executor=executor,
                    )
                )

                submitted_states = []

                for idx, states in enumerate(list_of_upstream_states):
                    # if we are on a future rerun of a partially complete flow run,
                    # there might be mapped children in a retrying state; this check
                    # looks into the current task state's map_states for such info
                    if (
                        isinstance(task_state, Mapped)
                        and len(task_state.map_states) >= idx + 1
                    ):
                        current_state = task_state.map_states[
                            idx
                        ]  # type: Optional[State]
                    elif isinstance(task_state, Mapped):
                        current_state = None
                    else:
                        current_state = task_state

                    # this is where each child is submitted for actual work
                    submitted_states.append(
                        executor.submit(
                            run_task,
                            task=task,
                            state=current_state,
                            upstream_states=states,
                            context=dict(
                                prefect.context,
                                **task_contexts.get(task, {}),
                                map_index=idx,
                            ),
                            flow_result=self.flow.result,
                            task_runner_cls=self.task_runner_cls,
                            task_runner_state_handlers=task_runner_state_handlers,
                            upstream_mapped_states=upstream_mapped_states,
                            extra_context=extra_context(task, task_index=idx),
                        )
                    )

                if isinstance(task_states.get(task), Mapped):
                    mapped_children[task] = submitted_states  # type: ignore

            else:
                task_states[task] = executor.submit(
                    run_task,
                    task=task,
                    state=task_state,
                    upstream_states=upstream_states,
                    context=dict(prefect.context, **task_contexts.get(task, {})),
                    flow_result=self.flow.result,
                    task_runner_cls=self.task_runner_cls,
                    task_runner_state_handlers=task_runner_state_handlers,
                    upstream_mapped_states=upstream_mapped_states,
                    extra_context=extra_context(task),
                )

        # ---------------------------------------------
        # Collect results
        # ---------------------------------------------

        # terminal tasks determine if the flow is finished
        terminal_tasks = self.flow.terminal_tasks()

        # reference tasks determine flow state
        reference_tasks = self.flow.reference_tasks()

        # wait until all terminal tasks are finished
        final_tasks = terminal_tasks.union(reference_tasks).union(return_tasks)
        final_states = executor.wait(
            {
                t: task_states.get(t, Pending("Task not evaluated by FlowRunner."))
                for t in final_tasks
            }
        )

        # also wait for any children of Mapped tasks to finish, and add them
        # to the dictionary to determine flow state
        all_final_states = final_states.copy()
        for t, s in list(final_states.items()):
            if s.is_mapped():
                # ensure we wait for any mapped children to complete
                if t in mapped_children:
                    s.map_states = executor.wait(mapped_children[t])
                s.result = [ms.result for ms in s.map_states]
                all_final_states[t] = s.map_states

        assert isinstance(final_states, dict)

    key_states = set(flatten_seq([all_final_states[t] for t in reference_tasks]))
    terminal_states = set(flatten_seq([all_final_states[t] for t in terminal_tasks]))
    return_states = {t: final_states[t] for t in return_tasks}

    state = self.determine_final_state(
        state=state,
        key_states=key_states,
        return_states=return_states,
        terminal_states=terminal_states,
    )

    return state
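
# A minimal sketch of the terminal- vs reference-task distinction used in the
# result-collection step above: terminal tasks decide when the flow is finished,
# while reference tasks decide the flow's final state. The flow and task names
# here are illustrative, not from the source.
from prefect import Flow, task


@task
def extract():
    return [1, 2, 3]


@task
def load(data):
    return len(data)


with Flow(name="sketch") as flow:
    data = extract()
    loaded = load(data)

assert flow.terminal_tasks() == {loaded}
assert flow.reference_tasks() == {loaded}  # defaults to the terminal tasks

flow.set_reference_tasks([data])
assert flow.reference_tasks() == {data}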
def test_states_with_mutable_attrs_are_hashable():
    assert {State(result=[1]), Pending(cached_inputs=dict(a=1))}
def run(
    self,
    state: State = None,
    task_states: Dict[Task, State] = None,
    return_tasks: Iterable[Task] = None,
    parameters: Dict[str, Any] = None,
    task_runner_state_handlers: Iterable[Callable] = None,
    executor: "prefect.executors.Executor" = None,
    context: Dict[str, Any] = None,
    task_contexts: Dict[Task, Dict[str, Any]] = None,
) -> State:
    """
    The main endpoint for FlowRunners. Calling this method will perform all
    computations contained within the Flow and return the final state of the Flow.

    Args:
        - state (State, optional): starting state for the Flow. Defaults to `Pending`
        - task_states (dict, optional): dictionary of task states to begin
            computation with, with keys being Tasks and values their corresponding state
        - return_tasks ([Task], optional): list of Tasks to include in the final
            returned Flow state. Defaults to `None`
        - parameters (dict, optional): dictionary of any needed Parameter values,
            with keys being strings representing Parameter names and values being
            their corresponding values
        - task_runner_state_handlers (Iterable[Callable], optional): A list of state
            change handlers that will be provided to the task_runner, and called
            whenever a task changes state.
        - executor (Executor, optional): executor to use when performing computation;
            defaults to the executor specified in your prefect configuration
        - context (Dict[str, Any], optional): prefect.Context to use for execution
        - task_contexts (Dict[Task, Dict[str, Any]], optional): contexts that will
            be provided to each task

    Returns:
        - State: `State` representing the final post-run state of the `Flow`.
    """
    self.logger.info("Beginning Flow run for '{}'".format(self.flow.name))

    # make copies to avoid modifying user inputs
    parameters = dict(parameters or {})
    task_states = dict(task_states or {})
    task_contexts = dict(task_contexts or {})

    # Default to global context, with provided context as override
    run_context = dict(prefect.context)
    run_context.update(context or {})

    if executor is None:
        # Use the executor on the flow, if configured
        executor = getattr(self.flow, "executor", None)
        if executor is None:
            executor = prefect.engine.get_default_executor_class()()

    self.logger.debug("Using executor type %s", type(executor).__name__)

    try:
        state, task_states, run_context, task_contexts = self.initialize_run(
            state=state,
            task_states=task_states,
            context=run_context,
            task_contexts=task_contexts,
            parameters=parameters,
        )

        with prefect.context(run_context):
            state = self.check_flow_is_pending_or_running(state)
            state = self.check_flow_reached_start_time(state)
            state = self.set_flow_to_running(state)
            state = self.get_flow_run_state(
                state,
                task_states=task_states,
                task_contexts=task_contexts,
                return_tasks=return_tasks,
                task_runner_state_handlers=task_runner_state_handlers,
                executor=executor,
            )

    except ENDRUN as exc:
        state = exc.state

    # All other exceptions are trapped and turned into Failed states
    except Exception as exc:
        self.logger.exception(
            "Unexpected error while running flow: {}".format(repr(exc))
        )
        if run_context.get("raise_on_exception"):
            raise exc
        new_state = Failed(
            message="Unexpected error while running flow: {}".format(repr(exc)),
            result=exc,
        )
        state = self.handle_state_change(state or Pending(), new_state)

    return state
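
# A minimal sketch of driving FlowRunner.run directly, mirroring how the tests
# above exercise it; the flow and task names are illustrative, not from the
# source. With return_tasks provided, the final state's `result` is a dict
# mapping tasks to their final states.
from prefect import Flow, task
from prefect.engine.flow_runner import FlowRunner


@task
def say_hello():
    return "hello"


with Flow(name="sketch") as flow:
    greeting = say_hello()

final_state = FlowRunner(flow=flow).run(return_tasks=[greeting])
assert final_state.is_successful()
assert final_state.result[greeting].result == "hello"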
async def get_or_create_task_run_info(
    flow_run_id: str, task_id: str, map_index: int = None
) -> dict:
    """
    Given a flow_run_id, task_id, and map_index, return details about the
    corresponding task run. If the task run doesn't exist, it will be created.

    Returns:
        - dict: a dict of details about the task run, including its id, version,
            and state.
    """

    if map_index is None:
        map_index = -1

    task_run = await models.TaskRun.where(
        {
            "flow_run_id": {"_eq": flow_run_id},
            "task_id": {"_eq": task_id},
            "map_index": {"_eq": map_index},
        }
    ).first({"id", "version", "state", "serialized_state"})

    if task_run:
        return dict(
            id=task_run.id,
            version=task_run.version,
            state=task_run.state,
            serialized_state=task_run.serialized_state,
        )

    # if it isn't found, add it to the DB
    task = await models.Task.where(id=task_id).first({"cache_key", "tenant_id"})
    if not task:
        raise ValueError("Invalid task ID")

    db_task_run = models.TaskRun(
        tenant_id=task.tenant_id,
        flow_run_id=flow_run_id,
        task_id=task_id,
        map_index=map_index,
        cache_key=task.cache_key,
        version=0,
    )

    db_task_run_state = models.TaskRunState(
        tenant_id=task.tenant_id,
        state="Pending",
        timestamp=pendulum.now(),
        message="Task run created",
        serialized_state=Pending(message="Task run created").serialize(),
    )

    db_task_run.states = [db_task_run_state]
    run = await db_task_run.insert(
        on_conflict=dict(
            constraint="task_run_unique_identifier_key",
            update_columns=["cache_key"],
        ),
        selection_set={"returning": {"id"}},
    )

    return dict(
        id=run.returning.id,
        version=db_task_run.version,
        state="Pending",
        serialized_state=db_task_run_state.serialized_state,
    )
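
# Hypothetical usage of get_or_create_task_run_info; the UUID strings are
# placeholders for real records. Because the insert de-duplicates on the task
# run's unique-identifier constraint (and the initial lookup finds existing
# rows), a repeated call returns the same record instead of creating a
# duplicate.
import asyncio


async def demo() -> None:
    first = await get_or_create_task_run_info(
        flow_run_id="<flow-run-uuid>", task_id="<task-uuid>", map_index=None
    )
    second = await get_or_create_task_run_info(
        flow_run_id="<flow-run-uuid>", task_id="<task-uuid>", map_index=None
    )
    assert first["id"] == second["id"]
    assert first["state"] == "Pending"


asyncio.run(demo())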
async def _create_flow_run(
    flow_id: str = None,
    parameters: dict = None,
    context: dict = None,
    scheduled_start_time: datetime.datetime = None,
    flow_run_name: str = None,
    version_group_id: str = None,
    labels: List[str] = None,
    run_config: dict = None,
) -> Any:
    """
    Creates a new flow run for an existing flow.

    Args:
        - flow_id (str): A string representing the current flow id
        - parameters (dict, optional): A dictionary of parameters that were specified
            for the flow
        - context (dict, optional): A dictionary of context values
        - scheduled_start_time (datetime.datetime): When the flow_run should be
            scheduled to run. If `None`, defaults to right now. Must be UTC.
        - flow_run_name (str, optional): An optional string representing this flow run
        - version_group_id (str, optional): An optional version group ID; if provided,
            will run the most recent unarchived version of the group
        - labels (List[str], optional): a list of labels to apply to this individual
            flow run
        - run_config (dict, optional): A run-config override for this flow run.
    """

    if flow_id is None and version_group_id is None:
        raise ValueError("One of flow_id or version_group_id must be provided.")

    scheduled_start_time = scheduled_start_time or pendulum.now()

    if flow_id:
        where_clause = {"id": {"_eq": flow_id}}
    elif version_group_id:
        where_clause = {
            "version_group_id": {"_eq": version_group_id},
            "archived": {"_eq": False},
        }

    flow = await models.Flow.where(where=where_clause).first(
        {
            "id": True,
            "archived": True,
            "tenant_id": True,
            "environment": True,
            "run_config": True,
            "parameters": True,
            "flow_group_id": True,
            "flow_group": {
                "default_parameters": True,
                "labels": True,
                "run_config": True,
            },
        },
        order_by={"version": EnumValue("desc")},
    )  # type: Any

    if not flow:
        msg = (
            f"Flow {flow_id} not found"
            if flow_id
            else f"Version group {version_group_id} has no unarchived flows."
        )
        raise exceptions.NotFound(msg)
    elif flow.archived:
        raise ValueError(f"Flow {flow.id} is archived.")

    # determine active labels
    if labels is not None:
        run_labels = labels
    elif run_config is not None:
        run_labels = run_config.get("labels") or []
    elif flow.flow_group.labels is not None:
        run_labels = flow.flow_group.labels
    elif flow.flow_group.run_config is not None:
        run_labels = flow.flow_group.run_config.get("labels") or []
    elif flow.run_config is not None:
        run_labels = flow.run_config.get("labels") or []
    elif flow.environment is not None:
        run_labels = flow.environment.get("labels") or []
    else:
        run_labels = []
    run_labels.sort()

    # determine active run_config
    if run_config is None:
        if flow.flow_group.run_config is not None:
            run_config = flow.flow_group.run_config
        else:
            run_config = flow.run_config

    # check parameters
    run_parameters = flow.flow_group.default_parameters
    run_parameters.update((parameters or {}))
    required_parameters = [p["name"] for p in flow.parameters if p["required"]]
    missing = set(required_parameters).difference(run_parameters)
    if missing:
        raise ValueError(f"Required parameters were not supplied: {missing}")

    state = Scheduled(message="Flow run scheduled.", start_time=scheduled_start_time)

    run = models.FlowRun(
        tenant_id=flow.tenant_id,
        flow_id=flow_id or flow.id,
        labels=run_labels,
        parameters=run_parameters,
        run_config=run_config,
        context=context or {},
        scheduled_start_time=scheduled_start_time,
        name=flow_run_name or names.generate_slug(2),
        states=[
            models.FlowRunState(
                tenant_id=flow.tenant_id,
                **models.FlowRunState.fields_from_state(
                    Pending(message="Flow run created")
                ),
            )
        ],
    )

    flow_run_id = await run.insert()

    # apply the flow run's initial state via `set_flow_run_state`
    await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=state)

    return flow_run_id
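
# A worked example of the required-parameter check in _create_flow_run above;
# the parameter names are illustrative.
run_parameters = {"n": 10}  # flow-group defaults merged with caller parameters
required_parameters = ["n", "mode"]  # parameters the flow marks as required

missing = set(required_parameters).difference(run_parameters)
assert missing == {"mode"}
# _create_flow_run would raise:
# ValueError: Required parameters were not supplied: {'mode'}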
def test_preparing_state_for_cloud_replaces_cached_inputs_with_safe():
    xres = Result(3, result_handler=JSONResultHandler())
    state = prepare_state_for_cloud(Pending(cached_inputs=dict(x=xres)))
    assert state.is_pending()
    assert state.result == NoResult
    assert state.cached_inputs == dict(x=xres)
def __init__(self, id, state=None, version=None):
    self.id = id
    self.state = state or Pending()
    self.version = version or 0
def test_flow_run_handles_error_states_when_initial_state_is_provided():
    with Flow(name="test") as f:
        res = AddTask()("5", 5)
    state = f.run(state=Pending())
    assert state.is_failed()
@pytest.mark.parametrize(
    "state_check",
    [
        dict(state=Cancelled(), assert_true={"is_finished"}),
        dict(state=Cached(), assert_true={"is_cached", "is_finished", "is_successful"}),
        dict(state=ClientFailed(), assert_true={"is_meta_state"}),
        dict(state=Failed(), assert_true={"is_finished", "is_failed"}),
        dict(state=Finished(), assert_true={"is_finished"}),
        dict(state=Looped(), assert_true={"is_finished", "is_looped"}),
        dict(state=Mapped(), assert_true={"is_finished", "is_mapped", "is_successful"}),
        dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Pending(), assert_true={"is_pending"}),
        dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}),
        dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}),
        dict(
            state=Retrying(),
            assert_true={"is_pending", "is_scheduled", "is_retrying"},
        ),
        dict(state=Running(), assert_true={"is_running"}),
        dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}),
        dict(
            state=Skipped(),
            assert_true={"is_finished", "is_successful", "is_skipped"},
        ),
        dict(state=Submitted(), assert_true={"is_meta_state", "is_submitted"}),
        dict(state=Success(), assert_true={"is_finished", "is_successful"}),
        dict(state=TimedOut(), assert_true={"is_finished", "is_failed"}),
        dict(state=TriggerFailed(), assert_true={"is_finished", "is_failed"}),
    ],
)
def test_state_is_methods(state_check):
    """
    Iterates over all of the "is_*" methods of the state, asserting that each
    method named in `assert_true` returns True and every other one returns False.
    """
    state = state_check["state"]
    assert_true = state_check["assert_true"]

    for attr in dir(state):
        if attr.startswith("is_") and callable(getattr(state, attr)):
            if attr in assert_true:
                assert getattr(state, attr)()
            else:
                assert not getattr(state, attr)()