def test_lightning_app_checkpointing_with_nested_flows():
    """Restore a root flow and a deeply nested work from the checkpoints of a previous run.

    A first app runs with checkpointing enabled. A freshly built app starts with
    zeroed counters and is then loaded from the checkpoint directory.
    """

    def innermost_work(root):
        # The CheckpointCounter work sits at the bottom of a chain of nested `.flow` attributes.
        return root.flow.flow.flow.flow.flow.flow.flow.flow.flow.flow.work

    app = LightningApp(CheckpointFlow(CheckpointCounter()))
    app.checkpointing = True
    SingleProcessRuntime(app, start_server=False).dispatch()
    assert app.root.counter == 6
    assert innermost_work(app.root).counter == 5

    # Rebuild from scratch: nothing has been restored yet.
    app = LightningApp(CheckpointFlow(CheckpointCounter()))
    assert app.root.counter == 0
    assert innermost_work(app.root).counter == 0

    app.load_state_dict_from_checkpoint_dir(app.checkpoint_dir)
    # The root counter was incremented to 6 only after the latest checkpoint was written,
    # so the restored value is 5.
    assert app.root.counter == 5
    assert innermost_work(app.root).counter == 5
def test_lightning_flow_iterate(tmpdir, runtime_cls, run_once):
    """Resume an `experimental_iterate` loop from a checkpoint and run it to completion.

    The first dispatch stops part-way through the iteration (counter at 4, not
    finished). A new app loaded from the checkpoint directory must resume from that
    state and finish the loop.
    """
    app = LightningApp(CFlow(run_once))
    runtime_cls(app, start_server=False).dispatch()
    assert app.root.looping == 0
    assert app.root.tracker == 4

    # Find the recorded call entry for `experimental_iterate`; fail with a clear message if absent.
    call_hash = next((key for key in app.root._calls if "experimental_iterate" in key), None)
    assert call_hash is not None, "no `experimental_iterate` call was recorded"
    iterate_call = app.root._calls[call_hash]
    assert iterate_call["counter"] == 4
    assert not iterate_call["has_finished"]

    checkpoint_dir = os.path.join(storage_root_dir(), "checkpoints")
    app = LightningApp(CFlow(run_once))
    app.load_state_dict_from_checkpoint_dir(checkpoint_dir)
    app.root.restarting = True
    assert app.root.looping == 0
    assert app.root.tracker == 4

    runtime_cls(app, start_server=False).dispatch()
    assert app.root.looping == 2
    # Parenthesized on purpose: without parentheses this parses as
    # `(tracker == 10) if run_once else 20`, which always passes when run_once is falsy.
    assert app.root.tracker == (10 if run_once else 20)
    iterate_call = app.root._calls[call_hash]
    assert iterate_call["has_finished"]
def test_load_state_dict_from_checkpoint_dir(tmpdir):
    """Load checkpoints by latest/explicit version and check error handling and version numbering.

    Eleven checkpoints are dumped, each with a distinct root counter value. The test
    then covers loading the latest checkpoint, loading each explicit version, missing
    and empty directories, deleted intermediate checkpoints, and the name of the next
    dumped checkpoint.
    """
    work = CheckpointCounter()
    app = LightningApp(CheckpointFlow(work))
    checkpoints = []
    num_checkpoints = 11
    # Generate 11 checkpoints; checkpoint `i` captures root.counter == i.
    for _ in range(num_checkpoints):
        checkpoints.append(app._dump_checkpoint())
        app.root.counter += 1

    # Without a version, the most recent checkpoint is loaded.
    app.load_state_dict_from_checkpoint_dir(app.checkpoint_dir)
    assert app.root.counter == (num_checkpoints - 1)

    for version in range(num_checkpoints):
        app.load_state_dict_from_checkpoint_dir(app.checkpoint_dir, version=version)
        assert app.root.counter == version

    # A path under the per-test tmpdir is guaranteed not to exist, unlike a cwd-relative folder.
    missing_dir = os.path.join(str(tmpdir), "random_folder")
    with pytest.raises(FileNotFoundError, match="The provided directory"):
        app.load_state_dict_from_checkpoint_dir(missing_dir)

    # The match string must mirror the library's runtime message verbatim (including "where").
    with pytest.raises(Exception, match="No checkpoints where found"):
        app.load_state_dict_from_checkpoint_dir(os.path.join(_PROJECT_ROOT, "tests/tests_app/"))

    # Delete 2 intermediate checkpoints; the latest one must still be loadable.
    for index in (4, 7):
        os.remove(checkpoints[index])
    app.load_state_dict_from_checkpoint_dir(app.checkpoint_dir)
    assert app.root.counter == (num_checkpoints - 1)

    # Loading an older version must not reset numbering: the next dump is still v_11.
    app.load_state_dict_from_checkpoint_dir(app.checkpoint_dir, version=5)
    checkpoint_path = app._dump_checkpoint()
    assert os.path.basename(checkpoint_path).startswith("v_11")