def test_db_cancelled_states_interrupt_flow_run(client, monkeypatch):
    """Check that a ``Cancelled`` state reported by the backend interrupts the run.

    The stubbed ``client.graphql`` answers ``"Running"`` for its first three
    calls and ``"Cancelled"`` on every call after that.
    """
    poll_tracker = {"count": 0}

    def fake_graphql(*args, **kwargs):
        if poll_tracker["count"] != 3:
            poll_tracker["count"] += 1
            reported_state = "Running"
        else:
            reported_state = "Cancelled"
        return Box({"data": {"flow_run_by_pk": {"state": reported_state}}})

    client.graphql = fake_graphql

    @prefect.task
    def sleeper():
        # Sleeps much longer than the heartbeat interval configured below.
        time.sleep(3)

    flow = prefect.Flow("test", tasks=[sleeper])
    heartbeat_config = {"cloud.heartbeat_interval": 0.025}
    with set_temporary_config(heartbeat_config):
        final_state = CloudFlowRunner(flow=flow).run(return_tasks=[sleeper])

    assert isinstance(final_state, Cancelled)
    assert "interrupt" in final_state.message.lower()
def test_docker_serialize_with_flows():
    """Dump a ``Docker`` storage that holds one flow, then load it back."""
    docker_storage = storage.Docker(
        registry_url="url",
        image_name="name",
        image_tag="tag",
        secrets=["FOO"],
    )
    flow = prefect.Flow("test")
    docker_storage.add_flow(flow)

    dumped = DockerSchema().dump(docker_storage)

    assert dumped
    assert dumped["__version__"] == prefect.__version__

    # Field-by-field comparison of the serialized payload, in a fixed order.
    expected_fields = (
        ("image_name", "name"),
        ("image_tag", "tag"),
        ("registry_url", "url"),
        ("flows", {"test": "/opt/prefect/flows/test.prefect"}),
        ("secrets", ["FOO"]),
    )
    for field_name, expected_value in expected_fields:
        assert dumped[field_name] == expected_value

    restored = DockerSchema().load(dumped)
    assert flow.name in restored
    assert restored.secrets == ["FOO"]
def test_environment_run():
    """``FargateTaskEnvironment.run`` should execute the flow through its executor.

    A ``LocalDaskExecutor`` subclass records whether ``submit`` was invoked,
    and the single task records that it ran by writing into a dict.
    """

    class MyExecutor(LocalDaskExecutor):
        # Flipped to True on the instance the first time submit() is called.
        submit_called = False

        def submit(self, *args, **kwargs):
            self.submit_called = True
            return super().submit(*args, **kwargs)

    run_marker = {}

    @prefect.task
    def add_to_dict():
        run_marker["run"] = True

    tracking_executor = MyExecutor()
    fargate_env = FargateTaskEnvironment(executor=tracking_executor)
    flow = prefect.Flow("test", tasks=[add_to_dict], environment=fargate_env)

    fargate_env.run(flow=flow)

    assert run_marker.get("run") is True
    assert tracking_executor.submit_called
def test_flow_runner_heartbeat_sets_command(monkeypatch, setting_available):
    """``_heartbeat`` should return ``True`` and populate ``heartbeat_cmd``.

    ``setting_available`` selects whether the mocked flow settings contain
    ``heartbeat_enabled=True`` or are an empty dict.
    """
    if setting_available:
        flow_settings = {"heartbeat_enabled": True}
    else:
        flow_settings = {}

    mocked_client = MagicMock()
    monkeypatch.setattr(
        "prefect.engine.cloud.flow_runner.Client",
        MagicMock(return_value=mocked_client),
    )
    mocked_client.graphql.return_value.data.flow_run_by_pk.flow.settings = (
        flow_settings
    )

    cloud_runner = CloudFlowRunner(flow=prefect.Flow(name="test"))
    with prefect.context(flow_run_id="foo"):
        heartbeat_result = cloud_runner._heartbeat()

    expected_cmd = [
        sys.executable,
        "-m",
        "prefect",
        "heartbeat",
        "flow-run",
        "-i",
        "foo",
    ]
    assert heartbeat_result is True
    assert cloud_runner.heartbeat_cmd == expected_cmd
def test_check_interrupt_loop_robust_to_api_errors(self, client, monkeypatch):
    """The run should still end up ``Cancelled`` when some API calls raise.

    The stubbed ``get_flow_run_info`` raises ``ValueError`` whenever it is
    called from a function named ``interrupt_if_cancelling`` and the
    zero-based call index is odd. Otherwise it reports ``Cancelling`` once
    the task has set the trigger event and ``Running`` before that.
    """
    import inspect

    trigger = threading.Event()
    call_indices = itertools.count()
    error_was_raised = False
    ran_longer_than_expected = False

    def get_flow_run_info(*args, **kwargs):
        nonlocal error_was_raised
        call_index = next(call_indices)
        # The frame lookup must stay directly in this function: f_back has to
        # be the immediate caller of the stub.
        caller_name = inspect.currentframe().f_back.f_code.co_name
        if caller_name == "interrupt_if_cancelling" and call_index % 2:
            error_was_raised = True
            raise ValueError("Woops!")
        if trigger.is_set():
            current_state = Cancelling()
        else:
            current_state = Running()
        return MagicMock(version=call_index, state=current_state)

    client.get_flow_run_info = get_flow_run_info

    @prefect.task
    def set_trigger(x):
        nonlocal ran_longer_than_expected
        trigger.set()
        time.sleep(10)
        # Only reached if the sleep above was not interrupted.
        ran_longer_than_expected = True
        return x + 1

    with prefect.Flow("test") as flow:
        set_trigger(1)

    with set_temporary_config({"cloud.check_cancellation_interval": 0.1}):
        final_state = CloudFlowRunner(flow=flow).run()

    assert isinstance(final_state, Cancelled)
    assert error_was_raised
    assert not ran_longer_than_expected
def test_client_register_doesnt_raise_if_no_keyed_edges(
    patch_post, compressed, monkeypatch, tmpdir
):
    """``Client.register`` should return the id found in the patched response.

    ``compressed`` selects which mutation key the canned response carries.
    """
    mutation_key = (
        "create_flow_from_compressed_string" if compressed else "create_flow"
    )
    canned_response = {
        "data": {
            "project": [{"id": "proj-id"}],
            mutation_key: {"id": "long-id"},
        }
    }
    patch_post(canned_response)

    monkeypatch.setattr(
        "prefect.client.Client.get_default_tenant_slug",
        MagicMock(return_value="tslug"),
    )

    cloud_settings = {
        "cloud.api": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
        "backend": "cloud",
    }
    with set_temporary_config(cloud_settings):
        client = Client()

    flow = prefect.Flow(name="test", storage=prefect.storage.Local(tmpdir))
    flow.result = None

    registered_id = client.register(
        flow,
        project_name="my-default-project",
        compressed=compressed,
        version_group_id=str(uuid.uuid4()),
        no_url=True,
    )
    assert registered_id == "long-id"
def test_simple_two_task_flow_with_final_task_already_running(
    monkeypatch, executor
):
    """Run a two-task flow whose second task run starts out ``Running``.

    Expects the flow run to stay running, the first task run to succeed at
    version 2, and the second task run to remain ``Running`` at version 1.
    """
    flow_run_id, task_run_id_1, task_run_id_2 = (
        str(uuid.uuid4()) for _ in range(3)
    )

    with prefect.Flow(name="test") as flow:
        t1 = prefect.Task()
        t2 = prefect.Task()
        t2.set_upstream(t1)

    first_task_run = TaskRun(
        id=task_run_id_1, task_slug=flow.slugs[t1], flow_run_id=flow_run_id
    )
    second_task_run = TaskRun(
        id=task_run_id_2,
        task_slug=flow.slugs[t2],
        version=1,
        flow_run_id=flow_run_id,
        state=Running(),
    )
    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[first_task_run, second_task_run],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        state = CloudFlowRunner(flow=flow).run(
            return_tasks=flow.tasks, executor=executor
        )

    assert state.is_running()
    assert client.flow_runs[flow_run_id].state.is_running()

    stored_first = client.task_runs[task_run_id_1]
    assert stored_first.state.is_successful()
    assert stored_first.version == 2

    stored_second = client.task_runs[task_run_id_2]
    assert stored_second.state.is_running()
    assert stored_second.version == 1
def test_run_flow(monkeypatch):
    """``CloudEnvironment.run_flow`` should hand the stored flow to the runner.

    A flow named ``"test"`` is serialized as JSON into a temporary directory,
    ``flow_file_path`` is placed in the Prefect context, and the mocked
    ``prefect.engine.FlowRunner`` is expected to have been called with a
    ``flow`` keyword argument whose name is ``"test"``.
    """
    environment = CloudEnvironment()

    flow_runner = MagicMock()
    monkeypatch.setattr("prefect.engine.FlowRunner", flow_runner)

    kube_cluster = MagicMock()
    monkeypatch.setattr("dask_kubernetes.KubeCluster", kube_cluster)

    with tempfile.TemporaryDirectory() as directory:
        flow = prefect.Flow("test")
        flow_path = os.path.join(directory, "flow_env.prefect")

        # Open the file once only. A second, unused "w+" handle on the same
        # path adds nothing: mode "w" already creates/truncates the file.
        with open(flow_path, "w") as f:
            json.dump(flow.serialize(), f)

        with set_temporary_config({"cloud.auth_token": "test"}):
            with prefect.context(flow_file_path=flow_path):
                environment.run_flow()

        assert flow_runner.call_args[1]["flow"].name == "test"
def test_simple_three_task_flow_with_one_failing_task(monkeypatch, executor):
    """Run a three-task chain whose last task raises ``ZeroDivisionError``.

    Expects the flow run to fail, the first two task runs to succeed and the
    third to fail, with every task run ending at version 2.
    """

    @prefect.task
    def error():
        1 / 0

    flow_run_id = str(uuid.uuid4())
    task_run_id_1 = str(uuid.uuid4())
    task_run_id_2 = str(uuid.uuid4())
    task_run_id_3 = str(uuid.uuid4())

    with prefect.Flow(name="test") as flow:
        t1 = prefect.Task()
        t2 = prefect.Task()
        t3 = error()
        t2.set_upstream(t1)
        t3.set_upstream(t2)

    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[
            TaskRun(id=task_run_id_1, task_id=t1.id, flow_run_id=flow_run_id),
            TaskRun(id=task_run_id_2, task_id=t2.id, flow_run_id=flow_run_id),
            TaskRun(id=task_run_id_3, task_id=t3.id, flow_run_id=flow_run_id),
        ],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        state = CloudFlowRunner(flow=flow).run(
            return_tasks=flow.tasks, executor=executor
        )

    assert state.is_failed()
    assert client.flow_runs[flow_run_id].state.is_failed()
    assert client.task_runs[task_run_id_1].state.is_successful()
    assert client.task_runs[task_run_id_1].version == 2
    assert client.task_runs[task_run_id_2].state.is_successful()
    assert client.task_runs[task_run_id_2].version == 2
    assert client.task_runs[task_run_id_3].state.is_failed()
    # Check the failing task's own version. This line previously re-checked
    # task run 2, so task run 3's version was never verified.
    # NOTE(review): assumes the failed run goes through two state updates
    # like the successful ones - confirm against MockedCloudClient.
    assert client.task_runs[task_run_id_3].version == 2
def test_client_deploy(monkeypatch, compressed):
    """``Client.deploy`` should return the flow id from the mocked HTTP response.

    ``requests.Session`` is replaced with a mock whose ``post`` returns a
    canned JSON payload; ``compressed`` selects the mutation key it carries.
    """
    mutation_key = "createFlowFromCompressedString" if compressed else "createFlow"
    canned_response = {
        "data": {
            "project": [{"id": "proj-id"}],
            mutation_key: {"id": "long-id"},
        }
    }

    http_response = MagicMock(json=MagicMock(return_value=canned_response))
    fake_session = MagicMock()
    fake_session.return_value.post = MagicMock(return_value=http_response)
    monkeypatch.setattr("requests.Session", fake_session)

    cloud_settings = {
        "cloud.graphql": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
    }
    with set_temporary_config(cloud_settings):
        client = Client()

    flow = prefect.Flow(name="test", storage=prefect.environments.storage.Memory())
    deployed_id = client.deploy(
        flow, project_name="my-default-project", compressed=compressed
    )
    assert deployed_id == "long-id"
def test_task_failure_caches_inputs_automatically(client):
    """A task that fails into ``Retrying`` should carry its inputs in the state.

    The cached input is compared against a ``Result`` both before and after
    ``store_safe_value()`` is called on that expected result.
    """

    @prefect.task(max_retries=2, retry_delay=timedelta(seconds=100))
    def is_p_three(p):
        if p == 3:
            raise ValueError("No thank you.")

    with prefect.Flow("test") as f:
        p = prefect.Parameter("p")
        res = is_p_three(p)

    state = CloudFlowRunner(flow=f).run(return_tasks=[res], parameters={"p": 3})
    assert state.is_running()
    assert isinstance(state.result[res], Retrying)

    expected = Result(3, result_handler=JSONResultHandler())
    # The two only compare equal once the safe value has been stored.
    assert not state.result[res].cached_inputs["p"] == expected
    expected.store_safe_value()
    assert state.result[res].cached_inputs["p"] == expected

    # Keyword arguments of the final set_task_run_state call.
    final_call = client.set_task_run_state.call_args_list[-1]
    reported_state = final_call[-1]["state"]
    assert isinstance(reported_state, Retrying)
    assert reported_state.cached_inputs["p"] == expected
def test_environment_execute_with_env_runner():
    """``LocalEnvironment.execute`` should run the flow via ``get_env_runner``.

    The storage subclass makes ``get_flow`` raise ``NotImplementedError`` and
    supplies an env runner instead, so the task can only run through that
    runner.
    """

    class TestStorage(Memory):
        def get_flow(self, *args, **kwargs):
            raise NotImplementedError()

        def get_env_runner(self, flow_loc):
            # Go through the parent's get_flow, since ours always raises.
            loaded_flow = super().get_flow(flow_loc)

            def run_ignoring_env(env):
                return loaded_flow.run()

            return run_ignoring_env

    run_marker = {}

    @prefect.task
    def add_to_dict():
        run_marker["run"] = True

    environment = LocalEnvironment()
    storage = TestStorage()
    flow = prefect.Flow("test", tasks=[add_to_dict])
    flow_location = storage.add_flow(flow)

    environment.execute(storage, flow_location)

    assert run_marker.get("run") is True
async def test_create_flow_run_with_version_group_id_uses_latest_version(
    self,
):
    """A run created by version group id should use the newer unarchived flow.

    Fifteen flows are created in the ``"test-group"`` version group. The one
    created first and the one created eleventh are left alone, the other
    thirteen are archived, and the new flow run is expected to point at the
    eleventh.
    """
    created_ids = [
        await flows.create_flow(
            serialized_flow=prefect.Flow(name="test").serialize(),
            version_group_id="test-group",
        )
        for _ in range(15)
    ]

    # Index 0 and index 10 stay unarchived; archive the rest in creation order.
    newer_id = created_ids[10]
    ids_to_archive = created_ids[1:10] + created_ids[11:]
    for archived_id in ids_to_archive:
        await flows.archive_flow(archived_id)

    flow_run_id = await runs.create_flow_run(version_group_id="test-group")

    run_query = models.FlowRun.where(id=flow_run_id)
    flow_run = await run_query.first({"flow": {"id": True}})
    assert flow_run.flow.id == newer_id
def test_run_flow(monkeypatch):
    """``FargateTaskEnvironment.run_flow`` should pass the stored flow to the runner.

    A flow named ``"test"`` is written with ``cloudpickle`` into a temporary
    directory, ``flow_file_path`` is placed in the Prefect context, and the
    runner class returned by the patched ``get_default_flow_runner_class`` is
    expected to have been called with a ``flow`` keyword argument whose name
    is ``"test"``.
    """
    environment = FargateTaskEnvironment()

    flow_runner = MagicMock()
    monkeypatch.setattr(
        "prefect.engine.get_default_flow_runner_class",
        MagicMock(return_value=flow_runner),
    )

    with tempfile.TemporaryDirectory() as directory:
        flow = prefect.Flow("test")
        flow_path = os.path.join(directory, "flow_env.prefect")

        # Open the file once only. A second, unused "w+" handle on the same
        # path adds nothing: mode "wb" already creates/truncates the file.
        with open(flow_path, "wb") as f:
            cloudpickle.dump(flow, f)

        with set_temporary_config({"cloud.auth_token": "test"}):
            with prefect.context(flow_file_path=flow_path):
                environment.run_flow()

        assert flow_runner.call_args[1]["flow"].name == "test"
async def test_create_flow_run_also_creates_task_runs_with_cache_keys(
    self,
):
    """Task runs created with a flow run should carry their task's cache key.

    The flow has two tasks with explicit cache keys and one task without.
    """
    tasks = [
        prefect.Task(cache_key="test-key"),
        prefect.Task(),
        prefect.Task(cache_key="wat"),
    ]
    serialized = prefect.Flow(name="test", tasks=tasks).serialize()
    flow_id = await flows.create_flow(serialized_flow=serialized)

    flow_run_id = await runs.create_flow_run(flow_id=flow_id)

    run_filter = {"flow_run_id": {"_eq": flow_run_id}}
    task_runs = await models.TaskRun.where(run_filter).get({"cache_key"})

    found_keys = {task_run.cache_key for task_run in task_runs}
    assert found_keys == {"test-key", "wat", None}
def test_environment_execute():
    """``RemoteEnvironment.execute`` should run the flow's task.

    The task writes ``"success"`` to an ``output`` file inside a temporary
    directory; the test reads that file back after ``execute`` returns.
    """
    with tempfile.TemporaryDirectory() as directory:

        @prefect.task
        def add_to_dict():
            with open(path.join(directory, "output"), "w") as tmp:
                tmp.write("success")

        flow = prefect.Flow("test", tasks=[add_to_dict])
        flow_path = path.join(directory, "flow_env.prefect")

        # Open the file once only. A second, unused "w+" handle on the same
        # path adds nothing: mode "wb" already creates/truncates the file.
        with open(flow_path, "wb") as f:
            cloudpickle.dump(flow, f)

        environment = RemoteEnvironment()
        storage = Local(directory)
        storage.add_flow(flow)
        environment.execute(flow=flow)

        with open(path.join(directory, "output"), "r") as file:
            assert file.read() == "success"
def test_run_flow_calls_callbacks(monkeypatch, tmpdir):
    """``run_flow`` should invoke both the ``on_start`` and ``on_exit`` callbacks.

    The client's ``graphql`` call is mocked to return a single flow run whose
    flow is named ``"name"`` and whose storage is a serialized ``Local``
    storage rooted at ``tmpdir``.
    """
    on_start_mock = MagicMock()
    on_exit_mock = MagicMock()
    environment = FargateTaskEnvironment(
        on_start=on_start_mock, on_exit=on_exit_mock
    )

    flow_runner = MagicMock()
    monkeypatch.setattr(
        "prefect.engine.get_default_flow_runner_class",
        MagicMock(return_value=flow_runner),
    )

    local_storage = Local(str(tmpdir))
    local_storage.add_flow(prefect.Flow("name"))

    # Assemble the mocked GraphQL reply from the inside out.
    flow_payload = GraphQLResult(
        {"name": "name", "storage": local_storage.serialize()}
    )
    flow_run_payload = GraphQLResult({"flow": flow_payload})
    graphql_reply = MagicMock(data=MagicMock(flow_run=[flow_run_payload]))

    client_cls = MagicMock()
    client_cls.return_value.graphql = MagicMock(return_value=graphql_reply)
    monkeypatch.setattr("prefect.environments.execution.base.Client", client_cls)

    with set_temporary_config({"cloud.auth_token": "test"}):
        with prefect.context({"flow_run_id": "id"}):
            environment.run_flow()

    assert flow_runner.call_args[1]["flow"].name == "name"
    assert on_start_mock.called
    assert on_exit_mock.called
def test_task_failure_with_upstream_secrets_doesnt_store_secret_value_and_recompute_if_necessary(
    client,
):
    """Retrying a task fed by a ``Secret`` should re-read the secret from context.

    The first run uses secret value 3 and is expected to end in ``Retrying``;
    the second run uses secret value 4 and is expected to succeed with
    result 4.
    """

    @prefect.task(max_retries=2, retry_delay=timedelta(seconds=100))
    def is_p_three(p):
        if p == 3:
            raise ValueError("No thank you.")
        return p

    with prefect.Flow("test", result_handler=JSONResultHandler()) as f:
        p = prefect.tasks.secrets.Secret("p")
        res = is_p_three(p)

    with prefect.context(secrets={"p": 3}):
        state = CloudFlowRunner(flow=f).run(return_tasks=[res])

    assert state.is_running()
    assert isinstance(state.result[res], Retrying)

    expected = Result(3, result_handler=SecretResultHandler(p))
    # The two only compare equal once the safe value has been stored.
    assert not state.result[res].cached_inputs["p"] == expected
    expected.store_safe_value()
    assert state.result[res].cached_inputs["p"] == expected

    # Swap the secret's result for a SafeResult so it has to be turned back
    # into a "true" result. The test expects the upstream value to be
    # recomputed from context through the SecretResultHandler.
    safe_secret = SafeResult("p", result_handler=SecretResultHandler(p))
    state.result[p] = Success(result=safe_secret)
    state.result[res].start_time = pendulum.now("utc")
    state.result[res].cached_inputs = {"p": safe_secret}

    with prefect.context(secrets={"p": 4}):
        rerun_state = CloudFlowRunner(flow=f).run(
            return_tasks=[res], task_states=state.result
        )

    assert rerun_state.is_successful()
    assert rerun_state.result[res].result == 4
def test_skip_if_already_run(monkeypatch, test_logger, state, is_skipped):
    """
    Run the skip_if_already_run task with the workflow's most recent state
    mocked to ``state``, and check that the task succeeds and that it is
    skipped exactly when ``is_skipped`` says it should be.
    """
    session_factory = Mock()
    recent_state_lookup = Mock(return_value=state)
    monkeypatch.setattr("autoflow.utils.get_session", session_factory)
    monkeypatch.setattr(
        "autoflow.sensor.WorkflowRuns.get_most_recent_state",
        recent_state_lookup,
    )

    task_runner = TaskRunner(task=skip_if_already_run)
    incoming_edge = Edge(
        prefect.Task(), skip_if_already_run, key="parametrised_workflow"
    )
    # The upstream result is a (workflow, parameters) pair.
    upstream_result = (
        prefect.Flow(name="DUMMY_WORFLOW_NAME"),
        {"DUMMY_PARAM": "DUMMY_VALUE"},
    )

    with set_temporary_config({"db_uri": "DUMMY_DB_URI"}):
        task_state = task_runner.run(
            upstream_states={incoming_edge: Success(result=upstream_result)},
            context={"logger": test_logger},
        )

    session_factory.assert_called_once_with("DUMMY_DB_URI")
    recent_state_lookup.assert_called_once_with(
        workflow_name="DUMMY_WORFLOW_NAME",
        parameters={"DUMMY_PARAM": "DUMMY_VALUE"},
        session=session_factory.return_value,
    )
    assert task_state.is_successful()
    assert is_skipped == task_state.is_skipped()
def test_run_workflow_fails(test_logger):
    """
    Check that the run_workflow task ends in a failed state when the
    function inside the workflow raises an exception.
    """
    failing_function = create_autospec(
        lambda dummy_param: None, side_effect=Exception("Workflow failed")
    )

    with prefect.Flow("Dummy workflow") as dummy_workflow:
        dummy_param = prefect.Parameter("dummy_param")
        FunctionTask(failing_function)(dummy_param=dummy_param)

    task_runner = TaskRunner(task=run_workflow)
    incoming_edge = Edge(prefect.Task(), run_workflow, key="parametrised_workflow")
    # The upstream result is a (workflow, parameters) pair.
    upstream_result = (dummy_workflow, {"dummy_param": "DUMMY_VALUE"})

    task_state = task_runner.run(
        upstream_states={incoming_edge: Success(result=upstream_result)},
        context={"logger": test_logger},
    )

    assert task_state.is_failed()
def test_client_is_always_called_even_during_state_handler_failures(client):
    """The client should still be told about the failure when a handler raises.

    The flow's only state handler raises ``ZeroDivisionError``.
    """

    def handler(task, old, new):
        1 / 0

    flow = prefect.Flow(
        name="test", tasks=[prefect.Task()], state_handlers=[handler]
    )

    ## flow run setup
    _run_result = flow.run(state=Pending())

    ## assertions
    # one call to pull the latest state
    assert client.get_flow_run_info.call_count == 1
    # one call to set the Failed state
    assert client.set_flow_run_state.call_count == 1

    reported_states = []
    for recorded_call in client.set_flow_run_state.call_args_list:
        reported_states.append(recorded_call[1]["state"])
    last_reported = reported_states[-1]

    assert last_reported.is_failed()
    assert "state handlers" in last_reported.message
    assert isinstance(last_reported.result, ZeroDivisionError)
    assert client.get_task_run_info.call_count == 0
def test_starting_at_arbitrary_loop_index_from_cloud_context(client):
    """A looping task should pick up ``task_loop_count`` from the run's context.

    The mocked flow run info carries ``task_loop_count=20``. The looping task
    is then expected to return 10 and the downstream task 100.
    """

    @prefect.task
    def looper(x):
        running_total = prefect.context.get("task_loop_result", 0) + x
        loop_count = prefect.context.get("task_loop_count", 1)
        if loop_count < 20:
            raise LOOP(result=running_total)
        return running_total

    @prefect.task
    def downstream(l):
        return l**2

    with prefect.Flow(name="looping", result_handler=JSONResultHandler()) as f:
        inter = looper(10)
        final = downstream(inter)

    run_info = MagicMock(context={"task_loop_count": 20})
    client.get_flow_run_info = MagicMock(return_value=run_info)

    flow_state = CloudFlowRunner(flow=f).run(return_tasks=[inter, final])

    assert flow_state.is_successful()
    assert flow_state.result[inter].result == 10
    assert flow_state.result[final].result == 100
async def flow(self, project_id, tmpdir):
    """
    A three-task chain a -> b -> c in which b uses the manual_only trigger.

    The flow is created through api.flows.create_flow and the value that call
    returns is stored on the flow as server_id.
    """
    pause_flow = prefect.Flow(
        "pause",
        storage=prefect.environments.storage.Local(directory=tmpdir),
        environment=prefect.environments.LocalEnvironment(),
    )

    # Tasks are attached as attributes of the flow object.
    pause_flow.a = prefect.Task("a")
    pause_flow.b = prefect.Task("b", trigger=prefect.triggers.manual_only)
    pause_flow.c = prefect.Task("c")

    for upstream, downstream in (
        (pause_flow.a, pause_flow.b),
        (pause_flow.b, pause_flow.c),
    ):
        pause_flow.add_edge(upstream, downstream)

    with set_temporary_config(key="dev", value=True):
        pause_flow.server_id = await api.flows.create_flow(
            project_id=project_id,
            serialized_flow=pause_flow.serialize(build=True),
        )

    return pause_flow
def flow_no_schedule():
    """Build and return a four-task flow named ``"test_flow_name"``.

    ``task1`` feeds both ``task2`` and ``task3``; ``task3`` feeds ``task4``.
    """

    @prefect.task
    def task1():
        return 5

    @prefect.task
    def task2(x):
        return x

    @prefect.task
    def task3(x):
        return x

    @prefect.task
    def task4(x):
        return x

    with prefect.Flow("test_flow_name") as built_flow:
        source = task1()
        task2(source)
        via_task3 = task3(source)
        task4(via_task3)

    return built_flow
def test_client_deploy(patch_post, compressed):
    """``Client.deploy`` should return the flow id from the patched response.

    ``compressed`` selects which mutation key the canned response carries.
    """
    mutation_key = "createFlowFromCompressedString" if compressed else "createFlow"
    canned_response = {
        "data": {
            "project": [{"id": "proj-id"}],
            mutation_key: {"id": "long-id"},
        }
    }
    patch_post(canned_response)

    cloud_settings = {
        "cloud.graphql": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
    }
    with set_temporary_config(cloud_settings):
        client = Client()

    flow = prefect.Flow(name="test", storage=prefect.environments.storage.Memory())
    deployed_id = client.deploy(
        flow, project_name="my-default-project", compressed=compressed
    )
    assert deployed_id == "long-id"
def test_run_workflow(test_logger):
    """
    Check that the run_workflow task succeeds and that the function inside
    the workflow is called once with the supplied parameter value.
    """
    recorded_function = create_autospec(lambda dummy_param: None)

    with prefect.Flow("Dummy workflow") as dummy_workflow:
        dummy_param = prefect.Parameter("dummy_param")
        FunctionTask(recorded_function)(dummy_param=dummy_param)

    task_runner = TaskRunner(task=run_workflow)
    incoming_edge = Edge(prefect.Task(), run_workflow, key="parametrised_workflow")
    # The upstream result is a (workflow, parameters) pair.
    upstream_result = (dummy_workflow, {"dummy_param": "DUMMY_VALUE"})

    task_state = task_runner.run(
        upstream_states={incoming_edge: Success(result=upstream_result)},
        context={"logger": test_logger},
    )

    assert task_state.is_successful()
    recorded_function.assert_called_once_with(dummy_param="DUMMY_VALUE")
async def test_create_flow_with_only_flow_group_schedule_keeps_schedule_active(
    self, project_id, flow_group_id
):
    """A flow with no schedule of its own should still have an active schedule.

    Only the flow group gets a schedule (one cron clock). The flow is then
    created in that group's version group with no schedule of its own.
    """
    cron_clock = {"type": "CronClock", "cron": "42 0 0 * * *"}
    schedule_was_set = await api.flow_groups.set_flow_group_schedule(
        flow_group_id=flow_group_id,
        clocks=[cron_clock],
    )
    assert schedule_was_set is True

    group_query = models.FlowGroup.where(id=flow_group_id)
    flow_group = await group_query.first({"schedule", "name"})
    assert flow_group.schedule is not None

    empty_flow = prefect.Flow("empty Flow")
    flow_id = await api.flows.create_flow(
        project_id=project_id,
        serialized_flow=empty_flow.serialize(),
        version_group_id=flow_group.name,
    )

    stored_flow = await models.Flow.where(id=flow_id).first({"is_schedule_active"})
    assert stored_flow.is_schedule_active is True
def test_flow_runner_retries_forever_on_queued_state(
    client, monkeypatch, num_attempts
):
    """``CloudFlowRunner.run`` should keep retrying while the run is ``Queued``.

    The patched ``FlowRunner.run`` returns ``Queued`` for the first
    ``num_attempts - 1`` calls and ``Success`` on the last one.
    """
    sleep_mock = MagicMock()
    monkeypatch.setattr("prefect.engine.cloud.flow_runner.time.sleep", sleep_mock)

    # Queued states with start times staggered one second apart, then Success.
    scripted_states = []
    for offset in range(num_attempts - 1):
        start_time = pendulum.now("UTC").add(seconds=offset)
        scripted_states.append(Queued(start_time=start_time))
    scripted_states.append(Success())
    run_mock = MagicMock(side_effect=scripted_states)

    run_infos = [MagicMock(version=version) for version in range(num_attempts)]
    client.get_flow_run_info = MagicMock(side_effect=run_infos)

    # Replace the real flow execution with the scripted states above.
    monkeypatch.setattr("prefect.engine.cloud.flow_runner.FlowRunner.run", run_mock)

    @prefect.task
    def return_one():
        return 1

    with prefect.Flow("test-cloud-flow-runner-with-queues") as flow:
        return_one()

    final_state = CloudFlowRunner(flow=flow).run()

    assert final_state.is_successful()
    assert run_mock.call_count == num_attempts
    # Run info is not fetched on the initial attempt.
    assert client.get_flow_run_info.call_count == num_attempts - 1
async def flow(self, project_id, tmpdir):
    """
    A diamond-shaped flow (a -> b, a -> c, b -> d, c -> d) whose tasks raise
    while task_run_count is at most 1.

    The flow is created through api.flows.create_flow and the value that call
    returns is stored on the flow as server_id.
    """

    class FailOnceTask(prefect.Task):
        def __init__(self, name):
            # One retry, with no delay before it.
            super().__init__(
                name=name,
                retry_delay=datetime.timedelta(seconds=0),
                max_retries=1,
            )

        def run(self):
            if prefect.context.task_run_count <= 1:
                raise ValueError("Run me again!")

    diamond = prefect.Flow(
        "diamond fail once",
        storage=prefect.environments.storage.Local(directory=tmpdir),
        environment=prefect.environments.LocalEnvironment(),
    )

    # Tasks are attached as attributes of the flow object.
    diamond.a = FailOnceTask("a")
    diamond.b = FailOnceTask("b")
    diamond.c = FailOnceTask("c")
    diamond.d = FailOnceTask("d")

    edges = (
        (diamond.a, diamond.b),
        (diamond.a, diamond.c),
        (diamond.b, diamond.d),
        (diamond.c, diamond.d),
    )
    for upstream, downstream in edges:
        diamond.add_edge(upstream, downstream)

    with set_temporary_config(key="dev", value=True):
        diamond.server_id = await api.flows.create_flow(
            project_id=project_id,
            serialized_flow=diamond.serialize(build=True),
        )

    return diamond
def test_scheduled_start_time_is_in_context(monkeypatch, executor):
    """Run a flow containing only ``whats_the_time`` and expect a datetime result."""
    flow_run_id, task_run_id_1 = (str(uuid.uuid4()) for _ in range(2))

    flow = prefect.Flow(name="test", tasks=[whats_the_time])

    only_task_run = TaskRun(
        id=task_run_id_1, task_id=whats_the_time.id, flow_run_id=flow_run_id
    )
    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[only_task_run],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        state = CloudFlowRunner(flow=flow).run(
            return_tasks=flow.tasks, executor=executor
        )

    assert state.is_successful()
    assert client.flow_runs[flow_run_id].state.is_successful()
    assert client.task_runs[task_run_id_1].state.is_successful()

    task_result = state.result[whats_the_time].result
    assert isinstance(task_result, datetime.datetime)