def test_viz_if_flow_state_provided(self, state):
    """A flow state passed to ``visualize`` shows up as the task node's color."""
    # NOTE(review): imported but never referenced here -- presumably a guard so
    # the test errors out early when graphviz is not installed; confirm.
    import graphviz

    # Stand-in IPython module whose get_ipython() reports a kernel app config.
    fake_ipython = MagicMock(
        get_ipython=lambda: MagicMock(config=dict(IPKernelApp=True))
    )
    with patch.dict("sys.modules", IPython=fake_ipython):
        nice_task = Task(name="a_nice_task")
        flow = Flow(name="test")
        flow.add_task(nice_task)
        flow_state = Success(result={nice_task: state})
        graph = flow.visualize(flow_state=flow_state)

    assert "label=a_nice_task" in graph.source
    # the state's color is rendered with an "80" suffix appended
    assert 'color="' + state.color + '80"' in graph.source
    assert "shape=ellipse" in graph.source
def test_sorted_tasks_with_invalid_start_task():
    """
    The flow only contains t1 -> t2; sorting from t3, which was never
    added, must raise a ValueError naming the missing task.
    """
    flow = Flow(name="test")
    first, second, outsider = Task("1"), Task("2"), Task("3")
    flow.add_edge(first, second)

    with pytest.raises(ValueError) as excinfo:
        flow.sorted_tasks(root_tasks=[outsider])

    message = str(excinfo.value)
    assert "not found in Flow" in message
def test_reset_reference_tasks_to_terminal_tasks():
    """Setting an empty reference-task list makes reference tasks equal the terminal tasks."""
    with Flow(name="test") as flow:
        head = Task()
        middle = Task()
        tail = Task()

    # chain: head -> middle -> tail
    for upstream, downstream in ((head, middle), (middle, tail)):
        flow.add_edge(upstream, downstream)

    flow.set_reference_tasks([middle])
    assert flow.reference_tasks() == {middle}

    flow.set_reference_tasks([])
    assert flow.reference_tasks() == flow.terminal_tasks()
def test_calling_a_task_returns_a_copy():
    """Calling a task that was already bound produces a distinct task of the same type."""
    bound_task = AddTask()

    with Flow(name="test") as flow:
        bound_task.bind(4, 2)
        # the call on the already-bound task is expected to emit a UserWarning
        with pytest.warns(UserWarning):
            called_task = bound_task(9, 0)

    assert isinstance(called_task, AddTask)
    assert bound_task != called_task

    task_states = flow.run().result
    assert task_states[bound_task].result == 6
    assert task_states[called_task].result == 9
def test_cache_all_upstream_edges(self):
    """``all_upstream_edges`` reads from ``_cache`` until a new edge is added."""
    flow = Flow(name="test")
    a, b, c = Task(), Task(), Task()
    flow.add_edge(a, b)
    flow.all_upstream_edges()

    cache_key = ("all_upstream_edges", ())
    planted = 1
    # a planted value under the cache key should be returned as-is
    flow._cache[cache_key] = planted
    assert flow.all_upstream_edges() == planted

    # after adding an edge the planted value must no longer be returned
    flow.add_edge(b, c)
    assert flow.all_upstream_edges() != planted
def test_set_dependencies_converts_arguments_to_tasks():
    """Plain values passed to ``set_dependencies`` each end up as a task in the flow."""

    class ArgTask(Task):
        def run(self, x):
            return x

    flow = Flow(name="test")
    center = ArgTask()
    # three bare integers plus the ArgTask itself -> four tasks expected
    flow.set_dependencies(
        task=center,
        upstream_tasks=[2],
        downstream_tasks=[3],
        keyword_tasks=dict(x=4),
    )

    assert len(flow.tasks) == 4
def test_cache_survives_pickling(self):
    """A planted ``_cache`` entry is still served after a cloudpickle round trip."""
    flow = Flow(name="test")
    a, b, c = Task(), Task(), Task()
    flow.add_edge(a, b)
    flow.sorted_tasks()

    cache_key = ("_sorted_tasks", (("root_tasks", ()),))
    planted = 1
    flow._cache[cache_key] = planted
    assert flow.sorted_tasks() == planted

    restored = cloudpickle.loads(cloudpickle.dumps(flow))
    assert restored.sorted_tasks() == planted

    # adding an edge to the restored flow must stop the planted value being served
    restored.add_edge(b, c)
    assert restored.sorted_tasks() != planted
def test_replace_replaces_all_the_things(self):
    """``replace`` swaps the old task out of tasks, edges and reference tasks."""
    with Flow(name="test") as flow:
        old = Task(name="t1")()
        child = Task(name="t2")(upstream_tasks=[old])

    new = Task(name="t3")
    flow.set_reference_tasks([old])
    flow.replace(old, new)

    assert flow.tasks == {child, new}

    upstream_ends = {edge.upstream_task for edge in flow.edges}
    assert upstream_ends == {new}
    downstream_ends = {edge.downstream_task for edge in flow.edges}
    assert downstream_ends == {child}

    assert flow.reference_tasks() == {new}
    assert flow.terminal_tasks() == {child}

    # looking up edges for the replaced task is expected to raise
    with pytest.raises(ValueError):
        flow.edges_to(old)
def test_flow_dot_run_handles_cached_states(self):
    """Scheduled runs of a flow with a cached task record y values [1, 1, 3]."""

    class MockSchedule(prefect.schedules.Schedule):
        call_count = 0

        def next(self, n):
            # three immediate run times, then abort the scheduling loop
            if self.call_count >= 3:
                raise SyntaxError("Cease scheduling!")
            self.call_count += 1
            return [pendulum.now("utc")]

    class StatefulTask(Task):
        call_count = 0

        def __init__(self, maxit=False, **kwargs):
            self.maxit = maxit
            super().__init__(**kwargs)

        def run(self):
            self.call_count += 1
            # with maxit the result is never below 2
            return max(self.call_count, 2) if self.maxit else self.call_count

    @task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=partial_inputs_only(validate_on=["x"]),
    )
    def return_x(x, y):
        return y

    storage = dict(y=[])

    @task
    def store_y(y):
        storage["y"].append(y)

    capped, counter = StatefulTask(maxit=True), StatefulTask()
    schedule = MockSchedule()
    with Flow(name="test", schedule=schedule) as flow:
        store_y(return_x(x=capped, y=counter))

    # the schedule's SyntaxError is what ends the run
    with pytest.raises(SyntaxError):
        flow.run()

    assert storage == {"y": [1, 1, 3]}
def test_set_dependencies_adds_all_arguments_to_flow():
    """Every task handed to ``set_dependencies`` becomes a member of the flow."""
    flow = Flow(name="test")

    class ArgTask(Task):
        def run(self, x):
            return x

    center = ArgTask()
    above, below, keyword = Task(), Task(), Task()
    flow.set_dependencies(
        task=center,
        upstream_tasks=[above],
        downstream_tasks=[below],
        keyword_tasks=dict(x=keyword),
    )

    assert flow.tasks == {center, above, below, keyword}
def test_sorted_tasks_with_start_task():
    """
    Graph under test:

        t1 -> t2 -> t3 -> t4
                    t3 -> t5
    """
    flow = Flow(name="test")
    t1, t2, t3, t4, t5 = (Task(label) for label in "12345")
    for upstream, downstream in ((t1, t2), (t2, t3), (t3, t4), (t3, t5)):
        flow.add_edge(upstream, downstream)

    # an empty root list yields every task
    everything = flow.sorted_tasks(root_tasks=[])
    assert set(everything) == {t1, t2, t3, t4, t5}

    # rooting at t3 yields t3 and the tasks downstream of it
    from_t3 = flow.sorted_tasks(root_tasks=[t3])
    assert set(from_t3) == {t3, t4, t5}
def test_copy():
    """A flow copy compares equal to the original and is not affected by later edits to it."""
    with Flow(name="test") as flow:
        head = Task()
        middle = Task()
        tail = Task()

    flow.add_edge(head, middle)
    flow.add_edge(middle, tail)
    flow.set_reference_tasks([head])

    clone = flow.copy()
    assert clone == flow

    # one new edge between two new tasks on the original only
    flow.add_edge(Task(), Task())
    assert len(clone.tasks) == len(flow.tasks) - 2
    assert len(clone.edges) == len(flow.edges) - 1
    assert flow.reference_tasks() == clone.reference_tasks() == {head}
def test_cache_sorted_tasks(self):
    """``sorted_tasks`` stores its result in ``_cache``, reads it back, and recomputes after a new edge."""
    flow = Flow(name="test")
    a, b, c = Task(), Task(), Task()
    flow.add_edge(a, b)
    flow.sorted_tasks()

    cache_key = ("_sorted_tasks", (("root_tasks", ()),))

    # the computed ordering is stored under the key
    assert flow._cache[cache_key] == (a, b)

    # a planted value under the key is what gets returned
    flow._cache[cache_key] = 1
    assert flow.sorted_tasks() == 1

    # a new edge leads to a freshly computed ordering
    flow.add_edge(b, c)
    assert flow.sorted_tasks() == (a, b, c)
def test_flow_raises_for_irrelevant_user_provided_parameters():
    """Running with parameters the flow does not declare is rejected."""

    class ParameterTask(Task):
        def run(self):
            return prefect.context.get("parameters")

    with Flow(name="test") as flow:
        declared = Parameter("x")
        reader = ParameterTask()
        flow.add_task(declared)
        flow.add_task(reader)

    # errors because of the invalid parameter
    with pytest.raises(ValueError):
        flow.run(parameters={"x": 10, "y": 3, "z": 9})

    # errors because the parameter is passed to FlowRunner.run() as an invalid kwarg
    with pytest.raises(TypeError):
        flow.run(x=10, y=3, z=9)
def test_upstream_and_downstream_error_msgs_when_task_is_not_in_flow():
    """Edge and neighbor lookups for a task missing from the flow raise a descriptive ValueError.

    The message is checked on ``str(excinfo.value)`` *after* each
    ``pytest.raises`` block. The previous form asserted membership against
    the ``ExceptionInfo`` object itself, which never verified the message:
    placed after the raising call inside the ``with`` it is unreachable, and
    placed outside it is not a valid string containment check.
    """
    f = Flow(name="test")
    t = Task()  # deliberately never added to the flow

    lookups = (f.edges_to, f.edges_from, f.upstream_tasks, f.downstream_tasks)
    for lookup in lookups:
        with pytest.raises(ValueError) as excinfo:
            lookup(t)
        # runs only once the expected exception has been captured
        assert "was not found in Flow" in str(excinfo.value)
def test_replace_runs_smoothly(self):
    """After replacing an add task with a subtract task, the flow runs and yields the difference."""
    adder = AddTask()

    class SubTask(Task):
        def run(self, x, y):
            return x - y

    subtractor = SubTask()

    with Flow(name="test") as flow:
        left, right = Parameter("x"), Parameter("y")
        total = adder(left, right)

    before = flow.run(x=10, y=11)
    assert before.result[total].result == 21

    flow.replace(total, subtractor)
    after = flow.run(x=10, y=11)
    assert after.result[subtractor].result == -1
def test_cache_terminal_tasks(self):
    """``terminal_tasks`` stores its result in ``_cache``, reads it back, and recomputes after a new edge."""
    flow = Flow(name="test")
    a, b, c = Task(), Task(), Task()
    flow.add_edge(a, b)
    flow.terminal_tasks()

    cache_key = ("terminal_tasks", ())

    # the computed set is stored under the key
    assert flow._cache[cache_key] == {b}

    # a planted value under the key is what gets returned
    flow._cache[cache_key] = 1
    assert flow.terminal_tasks() == 1

    # a new edge leads to a freshly computed result
    flow.add_edge(b, c)
    assert flow.terminal_tasks() == {c}
def test_cache_task_ids(self):
    """``task_ids`` stores its value in ``_cache``, reads it back, and recomputes after a new edge."""
    flow = Flow(name="test")
    a, b, c = Task(), Task(), Task()
    flow.add_edge(a, b)
    computed_ids = flow.task_ids

    cache_key = ("task_ids", ())

    # the computed ids are stored under the key
    assert flow._cache[cache_key] == computed_ids

    # a planted value under the key is what gets returned
    flow._cache[cache_key] = 1
    assert flow.task_ids == 1

    # a new edge leads to freshly computed ids, one per task
    flow.add_edge(b, c)
    assert len(flow.task_ids) == 3
def test_binding_a_task_with_var_kwargs_expands_the_kwargs():
    """Binding keyword arguments to a ``**kwargs`` task adds one keyed edge per keyword."""

    class KwargsTask(Task):
        def run(self, **kwargs):
            return kwargs

    upstream_by_key = {"a": Task(), "b": Task(), "c": Task()}
    kw = KwargsTask()

    with Flow(name="test") as flow:
        kw.bind(**upstream_by_key)

    # every bound task was added to the flow ...
    for upstream in upstream_by_key.values():
        assert upstream in flow.tasks

    # ... and is connected to kw by an edge keyed with its keyword name
    for key, upstream in upstream_by_key.items():
        assert Edge(upstream, kw, key=key) in flow.edges
def test_deserialization(self):
    """A serialized flow loads back with its tasks, edges, reference tasks, name and schedule type."""
    param = Parameter("1")
    left = Task("2")
    right = Task("3")
    original = Flow(
        name="hi",
        tasks=[param, left, right],
        schedule=prefect.schedules.CronSchedule("0 0 * * *"),
    )
    # the parameter feeds both tasks
    for target in (left, right):
        original.add_edge(param, target)

    payload = original.serialize()
    restored = prefect.serialization.flow.FlowSchema().load(payload)

    assert len(restored.tasks) == 3
    assert len(restored.edges) == 2
    assert len(restored.reference_tasks()) == 2
    assert {ref.name for ref in restored.reference_tasks()} == {"2", "3"}
    assert restored.name == original.name
    assert isinstance(restored.schedule, prefect.schedules.CronSchedule)
def test_eager_validation_is_off_by_default(monkeypatch):
    """Building a flow does not call ``Flow.validate`` under the default config."""
    # https://github.com/PrefectHQ/prefect/issues/919
    assert not prefect.config.flows.eager_edge_validation

    # swap Flow.validate for a mock so calls to it can be counted
    validate_mock = MagicMock()
    monkeypatch.setattr("prefect.core.flow.Flow.validate", validate_mock)

    @task
    def length(x):
        return len(x)

    numbers = list(range(10))
    with Flow(name="test") as flow:
        length.map(numbers)

    assert validate_mock.call_count == 0

    # an explicit call is still routed to the mock
    flow.validate()
    assert validate_mock.call_count == 1
def test_flow_dot_run_stops_on_schedule(self):
    """``run`` returns once the schedule yields no further run times."""

    class MockSchedule(prefect.schedules.Schedule):
        call_count = 0

        def next(self, n):
            # one run time slightly in the future, then nothing
            if self.call_count >= 1:
                return []
            self.call_count += 1
            return [pendulum.now("utc").add(seconds=0.05)]

    class StatefulTask(Task):
        call_count = 0

        def run(self):
            self.call_count += 1

    counter = StatefulTask()
    flow = Flow(name="test", tasks=[counter], schedule=MockSchedule())
    flow.run()

    assert counter.call_count == 1
def test_scheduled_runs_handle_retries():
    """A task that fails once is retried within a scheduled run."""

    class MockSchedule(prefect.schedules.Schedule):
        call_count = 0

        def next(self, n):
            # one immediate run time, then abort the scheduling loop
            if self.call_count >= 1:
                raise SyntaxError("Cease scheduling!")
            self.call_count += 1
            return [pendulum.now("utc")]

    class StatefulTask(Task):
        call_count = 0

        def run(self):
            self.call_count += 1
            # fail on the first attempt only
            if self.call_count == 1:
                raise OSError("I need to run again.")

    state_history = []

    def handler(task, old, new):
        state_history.append(new)
        return new

    flaky = StatefulTask(
        max_retries=1,
        retry_delay=datetime.timedelta(minutes=0),
        state_handlers=[handler],
    )
    flow = Flow(name="test", tasks=[flaky], schedule=MockSchedule())

    with pytest.raises(SyntaxError) as excinfo:
        flow.run()

    assert "Cease" in str(excinfo.value)
    assert flaky.call_count == 2
    # Running, Failed, Retrying, Running, Success
    assert len(state_history) == 5
def test_flow_dot_run_doesnt_run_on_schedule(self):
    """With ``run_on_schedule=False`` the task runs exactly once."""

    class MockSchedule(prefect.schedules.Schedule):
        call_count = 0

        def next(self, n):
            if self.call_count >= 2:
                raise SyntaxError("Cease scheduling!")
            self.call_count += 1
            # add small delta to trigger "naptime"
            return [pendulum.now("utc").add(seconds=0.05)]

    class StatefulTask(Task):
        call_count = 0

        def run(self):
            self.call_count += 1

    counter = StatefulTask()
    flow = Flow(name="test", tasks=[counter], schedule=MockSchedule())
    final_state = flow.run(run_on_schedule=False)

    assert counter.call_count == 1
def test_flow_dot_run_runs_on_schedule():
    """``run`` executes the task once per run time the schedule hands out."""

    class MockSchedule(prefect.schedules.Schedule):
        call_count = 0

        def next(self, n):
            # two immediate run times, then abort the scheduling loop
            if self.call_count >= 2:
                raise SyntaxError("Cease scheduling!")
            self.call_count += 1
            return [pendulum.now("utc")]

    class StatefulTask(Task):
        call_count = 0

        def run(self):
            self.call_count += 1

    counter = StatefulTask()
    flow = Flow(name="test", tasks=[counter], schedule=MockSchedule())

    with pytest.raises(SyntaxError) as excinfo:
        flow.run()

    assert "Cease" in str(excinfo.value)
    assert counter.call_count == 2
def test_viz_reflects_mapping_if_flow_state_provided(self):
    """A ``Mapped`` state is drawn as one colored node per map state, with indexed dashed edges."""
    # Stand-in IPython module whose get_ipython() reports a kernel app config.
    fake_ipython = MagicMock(
        get_ipython=lambda: MagicMock(config=dict(IPKernelApp=True))
    )
    add = AddTask(name="a_nice_task")
    list_task = Task(name="a_list_task")
    map_state = Mapped(map_states=[Success(), Failed()])

    with patch.dict("sys.modules", IPython=fake_ipython):
        with Flow(name="test") as flow:
            mapped = add.map(x=list_task, y=8)
        flow_state = Success(result={mapped: map_state, list_task: Success()})
        graph = flow.visualize(flow_state=flow_state)

    # one colored node for each mapped result
    success_node = 'label="a_nice_task <map>" color="{success}80"'.format(
        success=Success.color
    )
    assert success_node in graph.source

    failed_node = 'label="a_nice_task <map>" color="{failed}80"'.format(
        failed=Failed.color
    )
    assert failed_node in graph.source

    list_node = 'label=a_list_task color="{success}80"'.format(success=Success.color)
    assert list_node in graph.source

    assert 'label=8 color="#00000080"' in graph.source

    # two edges for each input to add()
    for var in ("x", "y"):
        for index in (0, 1):
            edge = "{0} [label={1} style=dashed]".format(index, var)
            assert edge in graph.source
def test_create_flow_with_on_failure(self):
    """Passing ``on_failure`` results in exactly one state handler on the flow."""

    def ignore(*args):
        return None

    flow = Flow(name="test", on_failure=ignore)
    handler_count = len(flow.state_handlers)
    assert handler_count == 1
def test_skip_validate_edges():
    """With ``validate=False``, ``add_edge`` accepts edges that would otherwise be invalid."""
    flow = Flow(name="test")
    # these tasks don't support keyed edges
    first = Task()
    second = Task()

    flow.add_edge(first, second, key="x", validate=False)
    # this introduces a cycle
    flow.add_edge(second, first, validate=False)
def test_create_flow_without_state_handler(self):
    """A flow created with only a name has an empty ``state_handlers`` list."""
    flow = Flow(name="test")
    assert flow.state_handlers == []
def test_validate_edges():
    """With eager edge validation switched on, a keyed edge to a plain task raises TypeError."""
    eager = {"flows.eager_edge_validation": True}
    with set_temporary_config(eager):
        flow = Flow(name="test")
        # these tasks don't support keyed edges
        first = Task()
        second = Task()
        with pytest.raises(TypeError):
            flow.add_edge(first, second, key="x")