def test_lightning_flow_counter(runtime_cls, tmpdir):
    """Run ``FlowCounter`` with checkpointing, then resume from every checkpoint.

    The first run must finish with ``root.counter == 3``. It must also leave
    exactly four files under ``<storage_root_dir()>/checkpoints``. Each of those
    files is then unpickled into a fresh ``LightningApp``, and that resumed run
    must also finish with ``root.counter == 3``.

    Args:
        runtime_cls: Runtime class used to execute the app; called as
            ``runtime_cls(app, start_server=False).dispatch()``.
        tmpdir: pytest temporary-directory fixture; not referenced in the body.
    """
    app = LightningApp(FlowCounter())
    app.checkpointing = True
    runtime_cls(app, start_server=False).dispatch()
    assert app.root.counter == 3

    checkpoint_dir = os.path.join(storage_root_dir(), "checkpoints")
    # os.listdir returns entries in arbitrary order; sort so the resume order
    # (and therefore any failure) is reproducible across runs/platforms.
    checkpoints = sorted(os.listdir(checkpoint_dir))
    assert len(checkpoints) == 4
    for checkpoint in checkpoints:
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint)
        resumed_app = LightningApp(FlowCounter())
        # Only the unpickling needs the file handle; close it before running
        # the app instead of holding it open for the whole dispatch.
        # (pickle is acceptable here: the files were written by this test.)
        with open(checkpoint_path, "rb") as f:
            resumed_app.set_state(pickle.load(f))
        runtime_cls(resumed_app, start_server=False).dispatch()
        assert resumed_app.root.counter == 3
# Example #2
# 0
def test_lightning_app_checkpointing_with_nested_flows():
    """Check that a deeply nested flow/work tree is restored from checkpoints.

    A first app runs to completion with checkpointing enabled. A second,
    freshly built app starts with all counters at zero. It is then loaded via
    ``load_state_dict_from_checkpoint_dir``. The restored values come from the
    latest checkpoint, not from the first run's final in-memory state.
    """

    def innermost_work(root):
        # Descend ten levels of ``.flow`` to reach the flow owning the work.
        node = root
        for _ in range(10):
            node = node.flow
        return node.work

    first_app = LightningApp(CheckpointFlow(CheckpointCounter()))
    first_app.checkpointing = True
    SingleProcessRuntime(first_app, start_server=False).dispatch()

    assert first_app.root.counter == 6
    assert innermost_work(first_app.root).counter == 5

    restored_app = LightningApp(CheckpointFlow(CheckpointCounter()))
    assert restored_app.root.counter == 0
    assert innermost_work(restored_app.root).counter == 0

    # NOTE(review): this reads the *new* app's own checkpoint_dir, presumably
    # the same directory the first run wrote to; confirm.
    restored_app.load_state_dict_from_checkpoint_dir(restored_app.checkpoint_dir)
    # The root counter reached 6 only after the latest checkpoint was written,
    # so the restored root counter is 5.
    assert restored_app.root.counter == 5
    assert innermost_work(restored_app.root).counter == 5