async def flow(self, project_id, tmpdir):
    """
    Build and register a four-task diamond flow: a -> (b, c) -> d.

    The flow uses local storage pointed at `tmpdir`. The value returned by
    `api.flows.create_flow` is attached to the flow as `server_id` before the
    flow is returned.
    """
    diamond = prefect.Flow(
        "diamond",
        storage=prefect.environments.storage.Local(directory=tmpdir),
        environment=prefect.environments.LocalEnvironment(),
    )

    # expose every task as an attribute of the flow object
    for task_name in "abcd":
        setattr(diamond, task_name, prefect.Task(task_name))

    for upstream, downstream in (
        (diamond.a, diamond.b),
        (diamond.a, diamond.c),
        (diamond.b, diamond.d),
        (diamond.c, diamond.d),
    ):
        diamond.add_edge(upstream, downstream)

    # the flow is built and registered while the "dev" config key is True
    with set_temporary_config(key="dev", value=True):
        diamond.server_id = await api.flows.create_flow(
            project_id=project_id,
            serialized_flow=diamond.serialize(build=True),
        )

    return diamond
async def test_create_flow_and_register_tasks_separately(self, run_query, project_id):
    """
    A flow created without its tasks has no tasks until they are registered
    through the register-tasks mutation.
    """
    two_task_flow = prefect.Flow(name="test", tasks=[prefect.Task(), prefect.Task()])
    serialized_flow = two_task_flow.serialize(build=False)
    # strip the tasks so they can be registered in a second step
    tasks = serialized_flow.pop("tasks")

    create_result = await run_query(
        query=self.create_flow_mutation,
        variables={
            "input": {"serialized_flow": serialized_flow, "project_id": project_id}
        },
    )
    flow_id = create_result.data.create_flow.id

    stored = await models.Flow.where(id=flow_id).first({"tasks": {"name"}})
    assert stored.tasks == []

    tasks_result = await run_query(
        query=self.register_tasks_mutation,
        variables={"input": {"serialized_tasks": tasks, "flow_id": flow_id}},
    )
    assert tasks_result.data.register_tasks.success is True

    stored = await models.Flow.where(id=flow_id).first({"tasks": {"name"}})
    task_names = [task.name for task in stored.tasks]
    assert task_names == ["Task", "Task"]
async def test_register_edges_provides_helpful_error(self, run_query, project_id):
    """
    Registering edges for a flow whose tasks were never registered reports
    that the edges reference tasks that do not exist.
    """
    with prefect.Flow(name="test") as flow:
        prefect.Task().set_upstream(prefect.Task())
        prefect.Task().set_upstream(prefect.Task())

    serialized_flow = flow.serialize()
    # the tasks are removed from the payload and deliberately never registered
    serialized_flow.pop("tasks")
    edges = serialized_flow.pop("edges")

    flow_result = await run_query(
        query=self.create_flow_mutation,
        variables={
            "input": {"serialized_flow": serialized_flow, "project_id": project_id}
        },
    )

    edges_result = await run_query(
        query=self.register_edges_mutation,
        variables={
            "input": {
                "serialized_edges": edges,
                "flow_id": flow_result.data.create_flow.id,
            }
        },
    )

    first_error = edges_result.errors[0]
    assert "reference tasks that do not exist" in first_error.message
async def test_create_flow_run_also_creates_task_runs_with_cache_keys(self, project_id):
    """
    Task runs created alongside a flow run carry the cache keys of their tasks
    (`None` for a task that has no cache key).
    """
    tasks = [
        prefect.Task(cache_key="test-key"),
        prefect.Task(),
        prefect.Task(cache_key="wat"),
    ]
    flow_id = await api.flows.create_flow(
        project_id=project_id,
        serialized_flow=prefect.Flow(name="test", tasks=tasks).serialize(),
    )
    flow_run_id = await api.runs.create_flow_run(flow_id=flow_id)

    run_filter = {"flow_run_id": {"_eq": flow_run_id}}
    task_runs = await models.TaskRun.where(run_filter).get({"cache_key"})

    cache_keys = {task_run.cache_key for task_run in task_runs}
    assert cache_keys == {"test-key", "wat", None}
async def flow_id():
    r"""
    Create the ten-task "traversal flow" and return the result of
    `prefect_server.api.flows.create_flow` for it.

    1 -> 2 -> 3 -> 4 -> 5
          \            /
           7 -> 8 -> 9 -> 10
          /
         6

    Edges: the chains 1-2-3-4-5 and 6-7-8-9-10, plus 2 -> 7 and 9 -> 5.
    """
    # NOTE: the docstring is a raw string because the diagram contains a
    # backslash, which is an invalid escape sequence in a normal string literal.
    t1 = prefect.Task("t1", slug="t1")
    t2 = prefect.Task("t2", slug="t2")
    t3 = prefect.Task("t3", slug="t3")
    t4 = prefect.Task("t4", slug="t4")
    t5 = prefect.Task("t5", slug="t5")
    t6 = prefect.Task("t6", slug="t6")
    t7 = prefect.Task("t7", slug="t7")
    t8 = prefect.Task("t8", slug="t8")
    t9 = prefect.Task("t9", slug="t9")
    t10 = prefect.Task("t10", slug="t10")

    f = prefect.Flow("traversal flow")
    f.chain(t1, t2, t3, t4, t5)
    f.chain(t6, t7, t8, t9, t10)
    # cross links between the two chains
    f.chain(t2, t7)
    f.chain(t9, t5)

    return await prefect_server.api.flows.create_flow(serialized_flow=f.serialize())
async def test_create_flow_copies_settings_between_versions(self, run_query, settings):
    """
    Settings stored on a flow are present on the next version created from the
    same serialized flow.
    """
    serialized_flow = prefect.Flow(
        name="My Flow", tasks=[prefect.Task(), prefect.Task()]
    ).serialize(build=False)

    async def register_version():
        # each call submits the same serialized flow through the mutation
        return await run_query(
            query=self.create_flow_mutation,
            variables={"input": {"serialized_flow": serialized_flow}},
        )

    # version one
    first_result = await register_version()
    flow_id = first_result.data.create_flow.id

    # store an arbitrary setting on version one
    await models.Flow.where(id=flow_id).update(set={"settings": settings})

    original = await models.Flow.where(id=flow_id).first({"name", "version_group_id"})

    # version two
    second_result = await register_version()
    new_flow = await models.Flow.where(id=second_result.data.create_flow.id).first(
        {"version", "version_group_id", "settings"}
    )

    # the flow being inspected really is a new version of the first flow
    assert new_flow.version_group_id == original.version_group_id
    assert new_flow.version == 2
    # and the settings survived the version bump
    assert new_flow.settings == settings
def test_simple_two_task_flow_with_final_task_set_to_fail(monkeypatch, executor):
    """
    With the second task run already in a Failed state, the flow run ends up
    failed: the first task run succeeds (version 2) and the second task run
    stays failed at version 0.
    """
    flow_run_id, task_run_id_1, task_run_id_2 = (str(uuid.uuid4()) for _ in range(3))

    with prefect.Flow(name="test") as flow:
        t1 = prefect.Task()
        t2 = prefect.Task()
        t2.set_upstream(t1)

    first_run = TaskRun(id=task_run_id_1, task_id=t1.id, flow_run_id=flow_run_id)
    # the final task run starts out Failed
    second_run = TaskRun(
        id=task_run_id_2, task_id=t2.id, flow_run_id=flow_run_id, state=Failed()
    )
    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[first_run, second_run],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        runner = CloudFlowRunner(flow=flow)
        state = runner.run(return_tasks=flow.tasks, executor=executor)

    assert state.is_failed()
    assert client.flow_runs[flow_run_id].state.is_failed()
    assert client.task_runs[task_run_id_1].state.is_successful()
    assert client.task_runs[task_run_id_1].version == 2
    assert client.task_runs[task_run_id_2].state.is_failed()
    assert client.task_runs[task_run_id_2].version == 0
def test_slug_mismatch_raises_informative_error(monkeypatch):
    """
    A task run whose slug does not match any task in the flow makes the flow
    run fail with a result that describes the mismatch.
    """
    flow_run_id, task_run_id_1, task_run_id_2 = (str(uuid.uuid4()) for _ in range(3))

    with prefect.Flow(name="test") as flow:
        t1 = prefect.Task()
        t2 = prefect.Task()
        t2.set_upstream(t1)

    matching_run = TaskRun(
        id=task_run_id_1, task_slug=flow.slugs[t1], flow_run_id=flow_run_id
    )
    # this slug is not one of the flow's slugs
    mismatched_run = TaskRun(
        id=task_run_id_2, task_slug="bad-slug", flow_run_id=flow_run_id
    )
    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[matching_run, mismatched_run],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        state = CloudFlowRunner(flow=flow).run(return_tasks=flow.tasks)

    assert state.is_failed()

    # assert informative message; can't use `match` because the real exception
    # is one layer deeper than the ENDRUN
    assert "KeyError" in repr(state.result)
    assert "not found" in repr(state.result)
    assert "changing the Flow" in repr(state.result)
async def test_create_flow_without_edges(self):
    """
    A flow that has tasks but no edges between them can be created.
    """
    flow = prefect.Flow(name="test")
    flow.add_task(prefect.Task())
    flow.add_task(prefect.Task())

    # serialize the flow built above; previously a fresh, empty
    # `prefect.Flow(name="test")` was serialized and the two tasks were unused
    flow_id = await flows.create_flow(serialized_flow=flow.serialize())
    assert await m.Flow.exists(flow_id)
async def test_create_flow_and_register_edges_separately(self, run_query, project_id):
    """
    A flow created without tasks and edges has neither; after registering the
    tasks and then the edges, it has four tasks and two edges.
    """
    with prefect.Flow(name="test") as flow:
        prefect.Task().set_upstream(prefect.Task())
        prefect.Task().set_upstream(prefect.Task())

    serialized_flow = flow.serialize()
    serialized_tasks = serialized_flow.pop("tasks")
    serialized_edges = serialized_flow.pop("edges")

    flow_result = await run_query(
        query=self.create_flow_mutation,
        variables={
            "input": {"serialized_flow": serialized_flow, "project_id": project_id}
        },
    )
    flow_id = flow_result.data.create_flow.id

    async def fetch_flow():
        # task names plus the aggregate edge count for the flow under test
        return await models.Flow.where(id=flow_id).first(
            {"tasks": {"name"}, "edges_aggregate": {"aggregate": {"count"}}},
            apply_schema=False,
        )

    stored = await fetch_flow()
    assert stored.tasks == []
    assert stored.edges_aggregate.aggregate.count == 0

    # need to register tasks first
    await api.flows.register_tasks(
        flow_id=flow_id,
        tasks=serialized_tasks,
        tenant_id=None,
    )

    edges_result = await run_query(
        query=self.register_edges_mutation,
        variables={
            "input": {"serialized_edges": serialized_edges, "flow_id": flow_id}
        },
    )
    assert edges_result.data.register_edges.success is True

    stored = await fetch_flow()
    assert [task.name for task in stored.tasks] == ["Task", "Task"] * 2
    assert stored.edges_aggregate.aggregate.count == 2
async def test_create_flow_without_edges(self, project_id):
    """
    A flow that has tasks but no edges between them can be created in a
    project.
    """
    flow = prefect.Flow(name="test")
    flow.add_task(prefect.Task())
    flow.add_task(prefect.Task())

    # serialize the flow built above; previously a fresh, empty
    # `prefect.Flow(name="test")` was serialized and the two tasks were unused
    flow_id = await api.flows.create_flow(
        project_id=project_id, serialized_flow=flow.serialize()
    )
    assert await models.Flow.exists(flow_id)
async def test_create_flow(self, run_query):
    """
    The create-flow mutation returns the ID of a flow that exists afterwards.
    """
    two_task_flow = prefect.Flow(name="test", tasks=[prefect.Task(), prefect.Task()])
    payload = two_task_flow.serialize(build=False)

    result = await run_query(
        query=self.create_flow_mutation,
        variables={"input": {"serialized_flow": payload}},
    )

    created_id = result.data.create_flow.id
    assert await models.Flow.exists(created_id)
async def db_flow_id():
    """
    A minimal, controlled flow for use in testing the flow run -> task run
    creation trigger.

    The flow has two tasks, each with its own cache key; the value returned by
    `api.flows.create_flow` is returned.
    """
    test_flow = prefect.Flow(name="Test Flow")
    for task_name, cache_key in (("task-1", "my-key-1"), ("task-2", "my-key-2")):
        test_flow.add_task(prefect.Task(task_name, cache_key=cache_key))

    return await api.flows.create_flow(serialized_flow=test_flow.serialize())
def flow():
    """
    Build (without registering) "my flow": an edge t1 -> t2, a parameter `x`
    with default 1, and a task t3 that receives parameter `y` as a mapped
    input and task t4 as the non-mapped input "not_mapped".
    """
    my_flow = prefect.Flow(name="my flow")

    t1 = prefect.Task("t1", tags={"red", "blue"})
    t2 = prefect.Task("t2", cache_key="test-key", tags={"red", "green"})
    my_flow.add_edge(t1, t2)

    my_flow.add_task(prefect.Parameter("x", default=1))

    mapped_task = prefect.Task("t3", tags={"mapped"})
    # one mapped and one non-mapped upstream for the same task
    my_flow.add_edge(prefect.Parameter("y"), mapped_task, key="y", mapped=True)
    my_flow.add_edge(prefect.Task("t4"), mapped_task, key="not_mapped")

    return my_flow
def test_simple_three_task_flow_with_first_task_retrying(monkeypatch, executor):
    """
    If the first task retries, then the next two tasks shouldn't even make
    calls to Cloud because they won't pass their upstream checks
    """

    @prefect.task(max_retries=1, retry_delay=datetime.timedelta(minutes=20))
    def error():
        1 / 0

    flow_run_id = str(uuid.uuid4())
    task_run_id_1 = str(uuid.uuid4())
    task_run_id_2 = str(uuid.uuid4())
    task_run_id_3 = str(uuid.uuid4())

    with prefect.Flow(name="test") as flow:
        t1 = error()
        t2 = prefect.Task()
        t3 = prefect.Task()
        t2.set_upstream(t1)
        t3.set_upstream(t2)

    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[
            TaskRun(
                id=task_run_id_1, task_slug=flow.slugs[t1], flow_run_id=flow_run_id
            ),
            TaskRun(
                id=task_run_id_2, task_slug=flow.slugs[t2], flow_run_id=flow_run_id
            ),
            TaskRun(
                id=task_run_id_3, task_slug=flow.slugs[t3], flow_run_id=flow_run_id
            ),
        ],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        state = CloudFlowRunner(flow=flow).run(
            return_tasks=flow.tasks, executor=executor
        )

    assert state.is_running()
    assert client.flow_runs[flow_run_id].state.is_running()
    assert isinstance(client.task_runs[task_run_id_1].state, Retrying)
    assert client.task_runs[task_run_id_1].version == 3
    assert client.task_runs[task_run_id_2].state.is_pending()
    assert client.task_runs[task_run_id_2].version == 0
    assert client.task_runs[task_run_id_3].state.is_pending()
    # check the third task run's version (this line used to re-check task run 2)
    assert client.task_runs[task_run_id_3].version == 0
    # every state update recorded by the client belongs to the first task run
    assert client.call_count["set_task_run_state"] == 3
async def test_create_flow_run_also_creates_task_runs(self):
    """
    Creating a flow run for a three-task flow creates three task runs.
    """
    three_task_flow = prefect.Flow(
        name="test", tasks=[prefect.Task(), prefect.Task(), prefect.Task()]
    )
    flow_id = await flows.create_flow(serialized_flow=three_task_flow.serialize())
    flow_run_id = await runs.create_flow_run(flow_id=flow_id)

    task_run_count = await models.TaskRun.where(
        {"flow_run_id": {"_eq": flow_run_id}}
    ).count()
    assert task_run_count == 3
async def test_create_flow(self, run_query, project_id):
    """
    A flow created through the mutation is stored with the given project ID.
    """
    two_task_flow = prefect.Flow(name="test", tasks=[prefect.Task(), prefect.Task()])
    payload = two_task_flow.serialize(build=False)

    result = await run_query(
        query=self.create_flow_mutation,
        variables={"input": {"serialized_flow": payload, "project_id": project_id}},
    )

    created_id = result.data.create_flow.id
    stored = await models.Flow.where(id=created_id).first({"project_id"})
    assert stored.project_id == project_id
async def test_create_flow_with_project(self, run_query):
    """
    Checks Cloud-compatible API: the mutation accepts a `project_id` (here a
    random UUID string) and the returned flow exists.
    """
    two_task_flow = prefect.Flow(name="test", tasks=[prefect.Task(), prefect.Task()])
    payload = two_task_flow.serialize(build=False)

    result = await run_query(
        query=self.create_flow_mutation,
        variables={
            "input": {"serialized_flow": payload, "project_id": str(uuid.uuid4())}
        },
    )

    created_id = result.data.create_flow.id
    assert await models.Flow.exists(created_id)
def test_client_register_with_flow_that_cant_be_deserialized(patch_post, monkeypatch):
    """
    Registering a flow that cannot be deserialized raises a `ValueError` whose
    message names both the deserialization failure and its cause.
    """
    patch_post({"data": {"project": [{"id": "proj-id"}]}})
    monkeypatch.setattr(
        "prefect.client.Client.get_default_tenant_slug",
        MagicMock(return_value="tslug"),
    )

    cloud_config = {
        "cloud.api": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
        "backend": "cloud",
    }
    with set_temporary_config(cloud_config):
        client = Client()

    # we add a max_retries value to the task without a corresponding
    # retry_delay; this will fail at deserialization
    task = prefect.Task()
    task.max_retries = 3

    flow = prefect.Flow(name="test", tasks=[task])
    flow.result = prefect.engine.result.Result()

    expected_message = (
        "(Flow could not be deserialized).*"
        "(`retry_delay` must be provided if max_retries > 0)"
    )
    with pytest.raises(ValueError, match=expected_message):
        client.register(flow, project_name="my-default-project", build=False)
async def flow_id():
    """
    Create "Test Flow" (a daily interval schedule starting 2018-01-01, an
    edge t1 -> t2 and a parameter `x` defaulting to 1) and return the value
    returned by `api.flows.create_flow`.
    """
    daily = prefect.schedules.IntervalSchedule(
        start_date=pendulum.datetime(2018, 1, 1),
        interval=datetime.timedelta(days=1),
    )
    test_flow = prefect.Flow(name="Test Flow", schedule=daily)

    upstream = prefect.Task("t1", tags={"red", "blue"})
    downstream = prefect.Task("t2", tags={"red", "green"})
    test_flow.add_edge(upstream, downstream)
    test_flow.add_task(prefect.Parameter("x", default=1))

    return await api.flows.create_flow(serialized_flow=test_flow.serialize())
def test_simple_three_task_flow_with_one_failing_task(monkeypatch, executor):
    """
    Two plain tasks followed by a task that raises: the flow run fails, the
    first two task runs succeed and the third task run fails.
    """

    @prefect.task
    def error():
        1 / 0

    flow_run_id = str(uuid.uuid4())
    task_run_id_1 = str(uuid.uuid4())
    task_run_id_2 = str(uuid.uuid4())
    task_run_id_3 = str(uuid.uuid4())

    with prefect.Flow(name="test") as flow:
        t1 = prefect.Task()
        t2 = prefect.Task()
        t3 = error()
        t2.set_upstream(t1)
        t3.set_upstream(t2)

    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[
            TaskRun(
                id=task_run_id_1, task_slug=flow.slugs[t1], flow_run_id=flow_run_id
            ),
            TaskRun(
                id=task_run_id_2, task_slug=flow.slugs[t2], flow_run_id=flow_run_id
            ),
            TaskRun(
                id=task_run_id_3, task_slug=flow.slugs[t3], flow_run_id=flow_run_id
            ),
        ],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        state = CloudFlowRunner(flow=flow).run(
            return_tasks=flow.tasks, executor=executor
        )

    assert state.is_failed()
    assert client.flow_runs[flow_run_id].state.is_failed()
    assert client.task_runs[task_run_id_1].state.is_successful()
    assert client.task_runs[task_run_id_1].version == 2
    assert client.task_runs[task_run_id_2].state.is_successful()
    assert client.task_runs[task_run_id_2].version == 2
    assert client.task_runs[task_run_id_3].state.is_failed()
    # check the third task run's version (this line used to re-check task run 2)
    assert client.task_runs[task_run_id_3].version == 2
async def labeled_flow_id(project_id):
    """
    Create "Labeled Flow" in the given project and return the value returned
    by `api.flows.create_flow`.

    The flow has a `UniversalRun` run config labelled "foo" and "bar", a daily
    interval schedule starting 2018-01-01, an edge t1 -> t2 and a parameter
    `x` defaulting to 1.
    """
    labeled = prefect.Flow(
        name="Labeled Flow",
        run_config=prefect.run_configs.UniversalRun(labels=["foo", "bar"]),
        schedule=prefect.schedules.IntervalSchedule(
            start_date=pendulum.datetime(2018, 1, 1),
            interval=datetime.timedelta(days=1),
        ),
    )

    upstream = prefect.Task("t1", tags={"red", "blue"})
    downstream = prefect.Task("t2", tags={"red", "green"})
    labeled.add_edge(upstream, downstream)
    labeled.add_task(prefect.Parameter("x", default=1))

    return await api.flows.create_flow(
        project_id=project_id, serialized_flow=labeled.serialize()
    )
async def labeled_flow_id():
    """
    Create "Labeled Flow" and return the value returned by
    `api.flows.create_flow`.

    The flow has a `RemoteEnvironment` labelled "foo" and "bar", a daily
    interval schedule starting 2018-01-01, an edge t1 -> t2 and a parameter
    `x` defaulting to 1.
    """
    labeled = prefect.Flow(
        name="Labeled Flow",
        environment=prefect.environments.execution.remote.RemoteEnvironment(
            labels=["foo", "bar"]
        ),
        schedule=prefect.schedules.IntervalSchedule(
            start_date=pendulum.datetime(2018, 1, 1),
            interval=datetime.timedelta(days=1),
        ),
    )

    upstream = prefect.Task("t1", tags={"red", "blue"})
    downstream = prefect.Task("t2", tags={"red", "green"})
    labeled.add_edge(upstream, downstream)
    labeled.add_task(prefect.Parameter("x", default=1))

    return await api.flows.create_flow(serialized_flow=labeled.serialize())
def test_flow_runner_doesnt_set_running_states_twice(client):
    """
    Running a flow whose only task is Retrying with a start time one day in
    the future fetches the flow run info once and sets the flow run state once.
    """
    task = prefect.Task()
    flow = prefect.Flow(name="test", tasks=[task])

    retry_tomorrow = Retrying(start_time=pendulum.now("utc").add(days=1))
    CloudFlowRunner(flow=flow).run(task_states={task: retry_tomorrow})

    # one time to pull latest state
    assert client.get_flow_run_info.call_count == 1
    # Pending -> Running
    assert client.set_flow_run_state.call_count == 1
def test_client_register_raises_for_keyed_flows_with_no_result(
    patch_post, compressed, monkeypatch, tmpdir
):
    """
    Registering a flow that passes data between tasks but has `result` set to
    None emits a `UserWarning` mentioning "result handler".
    """
    # the name of the mutation in the mocked response depends on `compressed`
    mutation_name = (
        "create_flow_from_compressed_string" if compressed else "create_flow"
    )
    patch_post(
        {
            "data": {
                "project": [{"id": "proj-id"}],
                mutation_name: {"id": "long-id"},
            }
        }
    )
    monkeypatch.setattr(
        "prefect.client.Client.get_default_tenant_slug",
        MagicMock(return_value="tslug"),
    )

    @prefect.task
    def a(x):
        pass

    cloud_config = {
        "cloud.api": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
        "backend": "cloud",
    }
    with set_temporary_config(cloud_config):
        client = Client()

    with prefect.Flow(
        name="test", storage=prefect.environments.storage.Local(tmpdir)
    ) as flow:
        a(prefect.Task())

    flow.result = None

    with pytest.warns(UserWarning, match="result handler"):
        client.register(
            flow,
            project_name="my-default-project",
            compressed=compressed,
            version_group_id=str(uuid.uuid4()),
            no_url=True,
        )
def test_switch_works_with_raise_on_exception():
    """
    A flow containing a `switch` can be run inside `raise_on_exception()`;
    the test has no assertions beyond the run completing without raising.
    """

    @prefect.task
    def return_b():
        return "b"

    # one task per case, named after its letter
    cases = {letter: prefect.Task(name=letter) for letter in "abcde"}

    with Flow(name="test") as flow:
        switch(return_b, cases)

    with raise_on_exception():
        flow.run()
async def flow(self, project_id, tmpdir):
    """
    Build and register the linear "pause" flow a -> b -> c, where task b uses
    the `manual_only` trigger.

    The flow uses local storage pointed at `tmpdir` and a `LocalRun` run
    config. The value returned by `api.flows.create_flow` is attached to the
    flow as `server_id` before the flow is returned.
    """
    flow = prefect.Flow(
        "pause",
        storage=prefect.storage.Local(directory=tmpdir),
        run_config=prefect.run_configs.LocalRun(),
    )

    # tasks are exposed as attributes of the flow object
    flow.a = prefect.Task("a")
    flow.b = prefect.Task("b", trigger=prefect.triggers.manual_only)
    flow.c = prefect.Task("c")

    flow.add_edge(flow.a, flow.b)
    flow.add_edge(flow.b, flow.c)

    # the flow is built and registered while the "dev" config key is True
    with set_temporary_config(key="dev", value=True):
        flow.server_id = await api.flows.create_flow(
            project_id=project_id, serialized_flow=flow.serialize(build=True)
        )

    return flow
def test_client_register_raises_for_keyed_flows_with_no_result_handler(
    patch_post, compressed
):
    """
    Registering a flow that passes data between tasks but has `result_handler`
    set to None emits a `UserWarning` mentioning "result handler".
    """
    # the name of the mutation in the mocked response depends on `compressed`
    mutation_name = "createFlowFromCompressedString" if compressed else "createFlow"
    patch_post(
        {
            "data": {
                "project": [{"id": "proj-id"}],
                mutation_name: {"id": "long-id"},
            }
        }
    )

    @prefect.task
    def a(x):
        pass

    cloud_config = {
        "cloud.graphql": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
    }
    with set_temporary_config(cloud_config):
        client = Client()

    with prefect.Flow(
        name="test", storage=prefect.environments.storage.Memory()
    ) as flow:
        a(prefect.Task())

    flow.result_handler = None

    with pytest.warns(UserWarning, match="result handler"):
        client.register(
            flow,
            project_name="my-default-project",
            compressed=compressed,
            version_group_id=str(uuid.uuid4()),
        )
def test_client_deploy_with_flow_that_cant_be_deserialized(patch_post):
    """
    Deploying a flow that cannot be deserialized raises a `ValueError` whose
    message names both the deserialization failure and its cause.
    """
    patch_post({"data": {"project": [{"id": "proj-id"}]}})

    cloud_config = {
        "cloud.graphql": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
    }
    with set_temporary_config(cloud_config):
        client = Client()

    # we add a max_retries value to the task without a corresponding
    # retry_delay; this will fail at deserialization
    task = prefect.Task()
    task.max_retries = 3

    flow = prefect.Flow(name="test", tasks=[task])

    expected_message = (
        "(Flow could not be deserialized).*"
        "(`retry_delay` must be provided if max_retries > 0)"
    )
    with pytest.raises(ValueError, match=expected_message):
        client.deploy(flow, project_name="my-default-project", build=False)
def test_client_is_always_called_even_during_state_handler_failures(client):
    """
    When a flow state handler raises, the client is still asked to set the
    flow run state once, to a Failed state carrying the handler's exception,
    and no task run info is fetched.
    """

    def handler(task, old, new):
        1 / 0

    flow = prefect.Flow(
        name="test", tasks=[prefect.Task()], state_handlers=[handler]
    )

    # flow run setup
    flow.run(state=Pending())

    # assertions
    assert client.get_flow_run_info.call_count == 1  # one time to pull latest state
    assert client.set_flow_run_state.call_count == 1  # Failed

    # inspect the state passed in the most recent set_flow_run_state call
    recorded_states = [
        call[1]["state"] for call in client.set_flow_run_state.call_args_list
    ]
    final_state = recorded_states[-1]

    assert final_state.is_failed()
    assert "state handlers" in final_state.message
    assert isinstance(final_state.result, ZeroDivisionError)
    assert client.get_task_run_info.call_count == 0