def test_lightning_app_aggregation_speed(default_timeout, queue_type_cls: BaseQueue, sleep_time, expect):
    """Check that `_collect_deltas_from_ui_and_work_queues` aggregates several deltas within one time window.

    The API delta queue is replaced by a subclass whose `get` sleeps for `sleep_time` seconds after every read.
    With a non-zero `sleep_time` the last collected delta must carry counter value `expect`; without sleeping
    it must carry a value greater than `expect`.
    """

    class SlowQueue(queue_type_cls):
        def get(self, timeout):
            # Delay every read by `sleep_time` seconds to simulate a slow queue.
            item = super().get(timeout)
            sleep(sleep_time)
            return item

    app = LightningApp(EmptyFlow())
    app.api_delta_queue = SlowQueue("api_delta_queue", default_timeout)
    if queue_type_cls is RedisQueue:
        # Presumably drops items a shared Redis queue may still hold from an earlier run.
        app.api_delta_queue.clear()

    counter_key = "root['vars']['counter']"

    # Fill the queue with mocked deltas, each setting the counter to the next value.
    for value in range(expect + 10):
        app.api_delta_queue.put(Delta({"values_changed": {counter_key: {"new_value": value}}}))

    # multiprocessing.Queue hands writes to a background feeder thread, so give them a moment to land.
    sleep(0.001)

    last_delta = app._collect_deltas_from_ui_and_work_queues()[-1]
    generated = last_delta.to_dict()["values_changed"][counter_key]["new_value"]
    if not sleep_time:
        # Without the artificial delay the flow should have aggregated past `expect`.
        assert generated > expect
    else:
        assert generated == expect
def test_maybe_apply_changes_from_flow():
    """Validate that the app's `_has_updated` flag is set to True only if the state was changed in the flow."""
    app = LightningApp(SimpleFlow())
    assert not app._has_updated
    # Apply once before the flow has run; the flag is not checked at this point.
    app.maybe_apply_changes()
    # Running the root flow is expected to change its state, so the next apply should raise the flag.
    app.root.run()
    app.maybe_apply_changes()
    assert app._has_updated
    # Reset the flag by hand; applying again with no new flow changes must leave it False.
    app._has_updated = False
    app.maybe_apply_changes()
    assert not app._has_updated
def test_populate_changes():
    """Check that `LightningApp.populate_changes` records a work attribute change under the work's `changes`."""

    class WorkA(LightningWork):
        def __init__(self):
            super().__init__()
            self.counter = 0

        def run(self):
            pass

    class A(LightningFlow):
        def __init__(self):
            super().__init__()
            self.work = WorkA()

        def run(self):
            pass

    root_flow = A()
    flow_state = root_flow.state
    state_before = root_flow.work.state
    root_flow.work.counter = 1
    state_after = root_flow.work.state

    # Convert the work-level delta into one that can be added to the flow's state.
    work_level_delta = Delta(DeepDiff(state_before, state_after))
    flow_level_delta = _delta_to_appstate_delta(root_flow, root_flow.work, work_level_delta)
    new_flow_state = LightningApp.populate_changes(flow_state, flow_state + flow_level_delta)
    root_flow.set_state(new_flow_state)

    expected_changes = {"counter": {"from": 0, "to": 1}}
    assert root_flow.work.counter == 1
    assert new_flow_state["works"]["work"]["changes"] == expected_changes
    assert root_flow.work._changes == expected_changes
def test_simple_app(component_cls, runtime_cls, tmpdir):
    """Check the initial state of an app with works `work_a` and `work_b`, then that both finish after dispatch."""
    comp = component_cls()
    app = LightningApp(comp, debug=True)
    assert app.root == comp

    def initial_work_state():
        # Fresh dict per work; both works start with identical state.
        return {
            "vars": {"has_finished": False, "counter": 0, "_urls": {}, "_paths": {}},
            "calls": {},
            "changes": {},
        }

    expected_state = {
        "app_state": ANY,
        "vars": {"_layout": ANY, "_paths": {}},
        "calls": {},
        "flows": {},
        "works": {"work_b": initial_work_state(), "work_a": initial_work_state()},
        "changes": {},
    }
    assert app.state == expected_state

    runtime_cls(app, start_server=False).dispatch()
    assert comp.work_a.has_finished
    assert comp.work_b.has_finished
    # `work_a` may take a long time to start, by which point `work_b` may already have run several iterations.
    assert comp.work_a.counter == 1
    assert comp.work_b.counter >= 3
def test_flow_state_change_with_path():
    """Check that switching a flow attribute between None and a Path (or Path to Path) is reflected in its state."""

    class Flow(LightningFlow):
        def __init__(self):
            super().__init__()
            self.none_to_path = None
            self.path_to_none = Path()
            self.path_to_path = Path()

        def run(self):
            self.none_to_path = "lit://none/to/path"
            self.path_to_none = None
            self.path_to_path = "lit://path/to/path"
            self._exit()

    flow = Flow()
    MultiProcessRuntime(LightningApp(flow)).dispatch()

    none_url = "lit://none/to/path"
    path_url = "lit://path/to/path"

    # String values assigned in `run` come back as Path objects on the flow.
    assert flow.none_to_path == Path(none_url)
    assert flow.path_to_none is None
    assert flow.path_to_path == Path(path_url)

    # `path_to_none` (now None) leaves `_paths` but stays in `_state`.
    assert "path_to_none" not in flow._paths
    assert "path_to_none" in flow._state
    assert flow._paths["none_to_path"] == Path(none_url).to_dict()
    assert flow._paths["path_to_path"] == Path(path_url).to_dict()

    # The public state mirrors the attribute values.
    assert flow.state["vars"]["none_to_path"] == Path(none_url)
    assert flow.state["vars"]["path_to_none"] is None
    assert flow.state["vars"]["path_to_path"] == Path(path_url)
def test_populate_changes_status_removed():
    """Regression test for https://github.com/Lightning-AI/lightning/issues/342.

    Dropping the last entry from a work call's `statuses` list must survive `LightningApp.populate_changes`
    unchanged.
    """
    statuses = [
        {"stage": stage, "message": None, "reason": None, "timestamp": timestamp}
        for timestamp, stage in enumerate(("requesting", "starting", "requesting"), start=1)
    ]
    last_state = {
        "vars": {},
        "calls": {},
        "flows": {},
        "works": {
            "work": {
                "vars": {},
                "calls": {"latest_call_hash": "run:fe3f", "run:fe3f": {"statuses": statuses}},
                "changes": {},
            },
        },
        "changes": {},
    }

    new_state = deepcopy(last_state)
    run_call = new_state["works"]["work"]["calls"]["run:fe3f"]
    # Pretend that the most recent status was removed from the list.
    run_call["statuses"] = run_call["statuses"][:-1]
    snapshot = deepcopy(new_state)

    new_state = LightningApp.populate_changes(last_state, new_state)
    assert new_state == snapshot
def test_nested_component(runtime_cls):
    """Check that, after dispatch, the work at every nesting level of `A` ends with `c == 1`."""
    app = LightningApp(A(), debug=True)
    runtime_cls(app, start_server=False).dispatch()

    root = app.root
    nested_works = (
        root.w_a,
        root.b.w_b,
        root.b.c.w_c,
        root.b.c.d.w_d,
        root.b.c.d.e.w_e,
    )
    for work in nested_works:
        assert work.c == 1
def test_get_component_by_name():
    """Check that `get_component_by_name` resolves dotted names to the very same component objects."""
    app = LightningApp(A())
    root = app.root
    expected_by_name = {
        "root": root,
        "root.b": root.b,
        "root.w_a": root.w_a,
        "root.b.w_b": root.b.w_b,
        "root.b.c.d.e": root.b.c.d.e,
    }
    for name, component in expected_by_name.items():
        assert app.get_component_by_name(name) is component
def run_work_isolated(work, *args, start_server: bool = False, **kwargs):
    """Run `work` a single time through a `SingleWorkFlow` using the multiprocessing runtime.

    Extra positional and keyword arguments are handed to `SingleWorkFlow` as `args` and `kwargs`. After the run,
    the last status of the work's latest call (the stopped status) is removed.
    """
    flow = SingleWorkFlow(work, args, kwargs)
    runtime = MultiProcessRuntime(LightningApp(flow, debug=True), start_server=start_server)
    runtime.dispatch()

    # The final status of the latest call is the stopped status; drop it.
    latest_hash = work._calls["latest_call_hash"]
    work._calls[latest_hash]["statuses"].pop()
def test_invalid_layout(return_val):
    """Check that returning `return_val` from `configure_layout` makes `LightningApp` raise a `TypeError`.

    `return_val` is parametrized, presumably with values that are not a valid layout.
    """

    class Root(EmptyFlow):
        def configure_layout(self):
            return return_val

    root = Root()
    expected_message = escape("The return value of configure_layout() in `Root`")
    with pytest.raises(TypeError, match=expected_message):
        LightningApp(root)
def test_app_state_api(runtime_cls):
    """This test validates the AppState can properly broadcast changes from work within its own process.

    The run is expected to leave ``test_app_state_api.txt`` in the current working directory. It is removed in a
    ``finally`` clause so a failing assertion does not leak the file into subsequent tests.
    """
    artifact = "test_app_state_api.txt"
    app = LightningApp(_A())
    try:
        runtime_cls(app, start_server=True).dispatch()
        assert app.root.work_a.var_a == -1
        # The drive must list the file both from the work context and from the frontend context.
        _set_work_context()
        assert app.root.work_a.drive.list(".") == [artifact]
        _set_frontend_context()
        assert app.root.work_a.drive.list(".") == [artifact]
    finally:
        # Best-effort cleanup: the file may not exist if the run failed before writing it.
        if os.path.exists(artifact):
            os.remove(artifact)
def test_multiprocess_starts_frontend_servers(*_):
    """Check that `MultiProcessRuntime` starts and then stops the frontend servers of `flow0` and `flow1`."""
    root = StartFrontendServersTestFlow()
    app = LightningApp(root)
    MultiProcessRuntime(app).dispatch()

    frontends = [app.frontends[root.flow0.name], app.frontends[root.flow1.name]]
    for frontend in frontends:
        frontend.start_server.assert_called_once()
    for frontend in frontends:
        frontend.stop_server.assert_called_once()
def test_component_affiliation():
    """Check that `affiliation` returns the tuple of attribute names leading from the root to a component."""
    app = LightningApp(AA())
    root = app.root
    cases = [
        (root, ()),
        (root.b, ("b",)),
        (root.b.c1, ("b", "c1")),
        (root.b.c2, ("b", "c2")),
        (root.b.c2.work_cc, ("b", "c2", "work_cc")),
    ]
    for component, expected in cases:
        assert affiliation(component) == expected
def test_invalid_layout_unsupported_content_value():
    """Check that a layout dict whose `content` has an unsupported type (here a list of ints) raises `ValueError`."""

    class Root(EmptyFlow):
        def configure_layout(self):
            return [{"name": "one", "content": [1, 2, 3]}]

    root = Root()
    expected_message = escape("A dictionary returned by `Root.configure_layout()")
    with pytest.raises(ValueError, match=expected_message):
        LightningApp(root)
def test_invalid_layout_missing_content_key():
    """Check that a layout dict without a 'content' key raises a `ValueError` naming the missing key."""

    class Root(EmptyFlow):
        def configure_layout(self):
            return [{"name": "one"}]

    root = Root()
    message = "A dictionary returned by `Root.configure_layout()` is missing a key 'content'."
    with pytest.raises(ValueError, match=escape(message)):
        LightningApp(root)
def test_single_content_layout():
    """Check that `configure_layout` may return a single dict instead of a list of dicts."""

    class TestContentComponent(EmptyFlow):
        def __init__(self):
            super().__init__()
            self.component0 = EmptyFlow()
            self.component1 = EmptyFlow()

        def configure_layout(self):
            return {"name": "single", "content": self.component1}

    root = TestContentComponent()
    LightningApp(root)
    # The single entry ends up wrapped in a list, with the component replaced by its name.
    expected_layout = [{"name": "single", "content": "root.component1"}]
    assert root._layout == expected_layout
def test_lightning_flow_counter(runtime_cls, tmpdir):
    """Check that a checkpointing `FlowCounter` app reaches ``counter == 3`` and writes four checkpoints.

    Also checks that each of those checkpoints, loaded into a fresh app, runs to ``counter == 3`` again.
    """
    app = LightningApp(FlowCounter())
    app.checkpointing = True
    runtime_cls(app, start_server=False).dispatch()
    assert app.root.counter == 3

    checkpoint_dir = os.path.join(storage_root_dir(), "checkpoints")
    checkpoints = os.listdir(checkpoint_dir)
    assert len(checkpoints) == 4

    for checkpoint in checkpoints:
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint)
        restored_app = LightningApp(FlowCounter())
        # Close the checkpoint file as soon as the state is loaded instead of holding it open for the whole
        # dispatch below. pickle is acceptable only because these files were just written by this test.
        with open(checkpoint_path, "rb") as f:
            restored_app.set_state(pickle.load(f))
        runtime_cls(restored_app, start_server=False).dispatch()
        assert restored_app.root.counter == 3
def test_layout_leaf_node(find_ports_mock, flow):
    """Check the layout and frontend of a leaf layout node before and after dispatch."""
    find_ports_mock.side_effect = lambda: 100
    app = LightningApp(flow)
    assert flow._layout == {}

    # Snapshot the frontends collected at app instantiation: the dict gets updated with new instances
    # as the layout changes during the loop.
    initial_frontends = app.frontends.copy()
    MultiProcessRuntime(app).dispatch()
    assert flow.counter == 3

    # Dispatch started the servers, so the frontend's target URL is now known.
    assert flow._layout == {"target": "http://localhost:100/root"}
    current_frontend = app.frontends[flow.name]
    assert current_frontend.flow is flow
    # Servers are started for the frontends collected when the app was instantiated.
    initial_frontends[flow.name].start_server.assert_called_once()
    # Leaf layout nodes cannot change: the frontend is the one from the initial configuration.
    assert current_frontend == initial_frontends[flow.name]
def test_lightning_app_checkpointing_with_nested_flows():
    """Check that checkpoints of a deeply nested flow tree can be loaded into a freshly created app."""

    def innermost_work(target_app):
        # The work sits at the bottom of the nested `.flow` chain built by `CheckpointFlow`.
        return target_app.root.flow.flow.flow.flow.flow.flow.flow.flow.flow.flow.work

    work = CheckpointCounter()
    app = LightningApp(CheckpointFlow(work))
    app.checkpointing = True
    SingleProcessRuntime(app, start_server=False).dispatch()
    assert app.root.counter == 6
    assert innermost_work(app).counter == 5

    # A freshly built app starts from zero before loading the checkpoints.
    work = CheckpointCounter()
    app = LightningApp(CheckpointFlow(work))
    assert app.root.counter == 0
    assert innermost_work(app).counter == 0

    app.load_state_dict_from_checkpoint_dir(app.checkpoint_dir)
    # The root counter was incremented to 6 after the latest checkpoint was written, so it restores as 5.
    assert app.root.counter == 5
    assert innermost_work(app).counter == 5
def test_get_send_request(monkeypatch):
    """Check how `AppState` handles failing and successful HTTP responses when requesting and sending state."""
    app = LightningApp(Flow())
    monkeypatch.setattr(lightning_app.utilities.state, "_configure_session", mock.MagicMock())

    state = AppState(plugin=AppStatePlugin())
    # A failing GET must not raise; after a successful GET the state's affiliation is the root's.
    state._session.get.return_value = MockResponse(app.state_with_changes, 500)
    state._request_state()
    state._session.get.return_value = MockResponse(app.state_with_changes, 200)
    state._request_state()
    assert state._my_affiliation == ()

    # Sending a change must raise when the POST fails. The mock is configured outside `pytest.raises`
    # so only the triggering assignment can satisfy the expected exception.
    state._session.post.return_value = MockResponse(app.state_with_changes, 500)
    with pytest.raises(Exception, match="The response from"):
        state.w.counter = 1

    state._session.post.return_value = MockResponse(app.state_with_changes, 200)
    state.w.counter = 1
def test_default_content_layout():
    """Check that, when the root does not override `configure_layout`, every child flow is listed by attribute name."""

    class SimpleFlow(EmptyFlow):
        def configure_layout(self):
            static_frontend = StaticWebFrontend(serve_dir="a/b/c")
            # Replace the server start with a mock so no real server is launched.
            static_frontend.start_server = Mock()
            return static_frontend

    class TestContentComponent(EmptyFlow):
        def __init__(self):
            super().__init__()
            self.component0 = SimpleFlow()
            self.component1 = SimpleFlow()
            self.component2 = SimpleFlow()

    root = TestContentComponent()
    LightningApp(root)
    expected_layout = [{"name": f"component{i}", "content": f"root.component{i}"} for i in range(3)]
    assert root._layout == expected_layout
def _run_state_transformation(tmpdir, attribute, update_fn, inplace=False):
    """Define a flow that assigns `attribute` to `x`, transform it with `update_fn` and return the resulting state.

    Args:
        tmpdir: Not used in this helper; presumably accepted so callers can pass the pytest fixture (TODO confirm).
        attribute: Initial value assigned to the flow attribute `x`.
        update_fn: Callable applied to `x` inside the flow's `run`.
        inplace: If True, the return value of `update_fn` is discarded, so the transformation is only visible
            when `update_fn` mutates `x` in place. Otherwise the return value is assigned back to `x`.

    Returns:
        The value of `x` in `app.state["vars"]` after dispatching with `SingleProcessRuntime`.
    """

    class StateTransformationTest(LightningFlow):
        def __init__(self):
            super().__init__()
            self.x = attribute
            self.finished = False

        def run(self):
            # Once the transformation has been applied, ask the app to exit.
            if self.finished:
                self._exit()
            x = update_fn(self.x)
            if not inplace:
                self.x = x
            self.finished = True

    flow = StateTransformationTest()
    assert flow.x == attribute
    app = LightningApp(flow)
    SingleProcessRuntime(app, start_server=False).dispatch()
    return app.state["vars"]["x"]
def test_url_content_layout():
    """Check that layout entries can mix components and plain URLs, and that URL entries get a matching `target`."""

    class TestContentComponent(EmptyFlow):
        def __init__(self):
            super().__init__()
            self.component0 = EmptyFlow()
            self.component1 = EmptyFlow()

        def configure_layout(self):
            return [
                {"name": "one", "content": self.component0},
                {"name": "url", "content": "https://lightning.ai"},
                {"name": "two", "content": self.component1},
            ]

    root = TestContentComponent()
    LightningApp(root)

    url = "https://lightning.ai"
    expected_layout = [
        {"name": "one", "content": "root.component0"},
        {"name": "url", "content": url, "target": url},
        {"name": "two", "content": "root.component1"},
    ]
    assert root._layout == expected_layout
def test_lightning_flow_iterate(tmpdir, runtime_cls, run_once):
    """Check `experimental_iterate` bookkeeping, and that an app restored from its checkpoints resumes iterating."""
    app = LightningApp(CFlow(run_once))
    runtime_cls(app, start_server=False).dispatch()
    assert app.root.looping == 0
    assert app.root.tracker == 4

    # Locate the recorded `experimental_iterate` call and fail with a clear message if there is none.
    call_hash = next((name for name in app.root._calls if "experimental_iterate" in name), None)
    assert call_hash is not None, "no `experimental_iterate` call was recorded"
    iterate_call = app.root._calls[call_hash]
    assert iterate_call["counter"] == 4
    assert not iterate_call["has_finished"]

    checkpoint_dir = os.path.join(storage_root_dir(), "checkpoints")
    app = LightningApp(CFlow(run_once))
    app.load_state_dict_from_checkpoint_dir(checkpoint_dir)
    app.root.restarting = True
    assert app.root.looping == 0
    assert app.root.tracker == 4
    runtime_cls(app, start_server=False).dispatch()
    assert app.root.looping == 2
    # Parentheses are required: `assert x == 10 if c else 20` parses as `(x == 10) if c else 20`,
    # which asserts the always-truthy constant 20 when `run_once` is False.
    assert app.root.tracker == (10 if run_once else 20)
    iterate_call = app.root._calls[call_hash]
    assert iterate_call["has_finished"]
# NOTE(review): `run` below is the `run` method of `SimpleWork`, whose class header is not visible in this
# excerpt. Indentation was reconstructed from collapsed source, including whether the final `time.sleep(0.1)`
# in `RootFlow.run` belongs to the `else` branch; confirm against the original file.
def run(self):
    # Mark the work as started so the flow can tell it is running.
    self.is_running_now = True
    print("work_is_running")
    # Emit one "good" log line per second, then fail deliberately at i == 5.
    for i in range(1, 10):
        time.sleep(1)
        if i % 5 == 0:
            raise Exception(f"invalid_value_of_i_{i}")
        print(f"good_value_of_i_{i}")


class RootFlow(LightningFlow):
    """Flow that runs a single `SimpleWork` and prints markers depending on whether the work is running yet."""

    def __init__(self):
        super().__init__()
        self.simple_work = SimpleWork()

    def run(self):
        # Noise line printed on every flow iteration.
        print("useless_garbage_log_that_is_always_there_to_overload_logs")
        self.simple_work.run()
        if not self.simple_work.is_running_now:
            pass
            # work is not ready yet
            print("waiting_for_work_to_be_ready")
        else:
            print("flow_and_work_are_running")
            logger.info("logger_flow_work")
            time.sleep(0.1)


if __name__ == "__main__":
    app = LightningApp(RootFlow())
def test_payload_works(tmpdir):
    """Dispatch an app whose works pass return values to each other through the payload API.

    The storage root is redirected to `tmpdir`. The checks themselves presumably live inside `Flow`.
    """
    storage_root = pathlib.Path(tmpdir)
    with mock.patch("lightning_app.storage.path.storage_root_dir", lambda: storage_root):
        app = LightningApp(Flow(), debug=True)
        MultiProcessRuntime(app, start_server=False).dispatch()
def test_app_state_api_with_flows(runtime_cls, tmpdir):
    """Check that `AppState` broadcasts changes made by flows: the root's `var_a` ends at -1."""
    app = LightningApp(A2(), debug=True)
    runtime = runtime_cls(app, start_server=True)
    runtime.dispatch()
    assert app.root.var_a == -1
def test_multiprocess_runtime_sets_context():
    """Check that `MultiProcessRuntime` sets the global `COMPONENT_CONTEXT` in the flow and in its works.

    This test only dispatches; the context checks presumably live inside `ContxtFlow` and its works.
    """
    app = LightningApp(ContxtFlow())
    MultiProcessRuntime(app).dispatch()
def test_lightning_stop():
    """Check that an app rooted at `FlowStop` dispatches under `MultiProcessRuntime` (no server) without raising."""
    MultiProcessRuntime(LightningApp(FlowStop()), start_server=False).dispatch()
def test_lightning_app_exit():
    """Check that after the `FlowExit` app finishes, its root's work is in the STOPPED stage."""
    app = LightningApp(FlowExit())
    MultiProcessRuntime(app).dispatch()
    expected_stage = WorkStageStatus.STOPPED
    assert app.root.work.status.stage == expected_stage