def test_lightning_flow_counter(runtime_cls, tmpdir):
    """Verify that a checkpointed ``FlowCounter`` app can resume from every checkpoint.

    The app is dispatched once with checkpointing switched on and must finish
    with ``root.counter == 3``. Four checkpoint files are then expected under
    ``<storage_root_dir()>/checkpoints``; each one is unpickled into a fresh
    app, which is dispatched again and must also end with ``root.counter == 3``.

    Args:
        runtime_cls: Runtime class under test; instantiated as
            ``runtime_cls(app, start_server=False)`` and dispatched.
        tmpdir: Not referenced in the body; presumably requested only for its
            fixture side effects -- TODO confirm.
    """
    # Initial run, with checkpointing enabled.
    lightning_app = LightningApp(FlowCounter())
    lightning_app.checkpointing = True
    runtime_cls(lightning_app, start_server=False).dispatch()
    assert lightning_app.root.counter == 3

    # The run above is expected to have left exactly four checkpoint files.
    ckpt_root = os.path.join(storage_root_dir(), "checkpoints")
    ckpt_names = os.listdir(ckpt_root)
    assert len(ckpt_names) == 4

    # Resuming from any of the checkpoints must reach the same final counter.
    for ckpt_name in ckpt_names:
        with open(os.path.join(ckpt_root, ckpt_name), "rb") as ckpt_file:
            lightning_app = LightningApp(FlowCounter())
            lightning_app.set_state(pickle.load(ckpt_file))
            runtime_cls(lightning_app, start_server=False).dispatch()
            assert lightning_app.root.counter == 3
def test_lightning_app_checkpointing_with_nested_flows():
    """Verify that checkpoint state is restored through a deeply nested flow tree.

    A ``CheckpointFlow`` app is run with checkpointing enabled, then a brand-new
    app is built and loaded from the first app's checkpoint directory. Both the
    root counter and the counter of the work reached through ten ``.flow`` hops
    must come back with their checkpointed values.
    """
    # Number of ``.flow`` attribute hops between the root and the flow owning ``work``.
    nesting_depth = 10

    def innermost_work_counter(lightning_app):
        """Return ``work.counter`` of the flow ``nesting_depth`` levels below the root."""
        node = lightning_app.root
        for _ in range(nesting_depth):
            node = node.flow
        return node.work.counter

    # First run: checkpointing enabled, executed in a single process.
    app = LightningApp(CheckpointFlow(CheckpointCounter()))
    app.checkpointing = True
    SingleProcessRuntime(app, start_server=False).dispatch()

    assert app.root.counter == 6
    assert innermost_work_counter(app) == 5

    # A freshly built app starts with both counters at zero.
    app = LightningApp(CheckpointFlow(CheckpointCounter()))
    assert app.root.counter == 0
    assert innermost_work_counter(app) == 0

    app.load_state_dict_from_checkpoint_dir(app.checkpoint_dir)
    # The root counter comes back as 5 rather than 6: it was incremented to 6
    # only after the latest checkpoint had been created.
    assert app.root.counter == 5
    assert innermost_work_counter(app) == 5