Exemple #1
0
def test_lightning_app_checkpointing_with_nested_flows():
    """Run a checkpointing app over a nested flow chain, then restore a fresh app from disk.

    The root counter and the counter of the work at the bottom of the flow
    chain are compared after the run, on a brand-new app, and after loading
    the checkpoint directory into that new app.
    """

    def innermost_work_counter(lightning_app):
        # Counter of the work attached at the far end of the nested flows.
        return lightning_app.root.flow.flow.flow.flow.flow.flow.flow.flow.flow.flow.work.counter

    application = LightningApp(CheckpointFlow(CheckpointCounter()))
    application.checkpointing = True
    runtime = SingleProcessRuntime(application, start_server=False)
    runtime.dispatch()

    assert application.root.counter == 6
    assert innermost_work_counter(application) == 5

    # A newly constructed app starts with both counters at zero.
    application = LightningApp(CheckpointFlow(CheckpointCounter()))
    assert application.root.counter == 0
    assert innermost_work_counter(application) == 0

    application.load_state_dict_from_checkpoint_dir(application.checkpoint_dir)
    # The root counter was incremented to 6 only after the latest checkpoint
    # was created, so the restored value is 5.
    assert application.root.counter == 5
    assert innermost_work_counter(application) == 5
def test_lightning_flow_iterate(tmpdir, runtime_cls, run_once):
    """Check the ``experimental_iterate`` call record before and after a checkpoint reload.

    The test dispatches an app built around ``CFlow(run_once)``, inspects the
    ``_calls`` entry whose key mentions ``experimental_iterate``, then builds
    a second app, loads its state from the checkpoint directory, marks the
    root as restarting and dispatches again.

    Args:
        tmpdir: Not used in the body; presumably the pytest fixture of the
            same name (TODO confirm).
        runtime_cls: Callable invoked as ``runtime_cls(app, start_server=False)``
            whose result exposes ``dispatch()``.
        run_once: Forwarded to ``CFlow``; also selects the tracker value
            expected after the second dispatch (10 when truthy, 20 otherwise).
    """
    app = LightningApp(CFlow(run_once))
    runtime_cls(app, start_server=False).dispatch()
    assert app.root.looping == 0
    assert app.root.tracker == 4
    # First key in the root's call registry that refers to experimental_iterate.
    call_hash = next(key for key in app.root._calls if "experimental_iterate" in key)
    iterate_call = app.root._calls[call_hash]
    assert iterate_call["counter"] == 4
    assert not iterate_call["has_finished"]

    checkpoint_dir = os.path.join(storage_root_dir(), "checkpoints")
    app = LightningApp(CFlow(run_once))
    app.load_state_dict_from_checkpoint_dir(checkpoint_dir)
    app.root.restarting = True
    # The loaded state must match what the first run left behind.
    assert app.root.looping == 0
    assert app.root.tracker == 4
    runtime_cls(app, start_server=False).dispatch()
    assert app.root.looping == 2
    # Compute the expectation first: `assert x == 10 if run_once else 20`
    # parses as `assert (x == 10) if run_once else 20`, which can never fail
    # when run_once is falsy because it asserts the constant 20.
    expected_tracker = 10 if run_once else 20
    assert app.root.tracker == expected_tracker
    iterate_call = app.root._calls[call_hash]
    assert iterate_call["has_finished"]
Exemple #3
0
def test_load_state_dict_from_checkpoint_dir(tmpdir):
    """Exercise ``load_state_dict_from_checkpoint_dir`` for latest, versioned and failing loads.

    Covers: loading without a version, loading each explicit version,
    the errors raised for a missing directory and for a directory without
    checkpoints, loading after two checkpoint files were removed, and the
    file name produced by the next dump.
    """
    lightning_app = LightningApp(CheckpointFlow(CheckpointCounter()))

    total = 11
    dumped_paths = []
    # Dump 11 checkpoints, bumping the root counter after each dump.
    for _ in range(total):
        dumped_paths.append(lightning_app._dump_checkpoint())
        lightning_app.root.counter += 1

    # Without an explicit version the last dumped state is expected.
    last_counter = total - 1
    lightning_app.load_state_dict_from_checkpoint_dir(lightning_app.checkpoint_dir)
    assert lightning_app.root.counter == last_counter

    # Each version is expected to restore the counter value equal to its index.
    for version in range(total):
        lightning_app.load_state_dict_from_checkpoint_dir(lightning_app.checkpoint_dir, version=version)
        assert lightning_app.root.counter == version

    with pytest.raises(FileNotFoundError, match="The provided directory"):
        lightning_app.load_state_dict_from_checkpoint_dir("./random_folder/")

    empty_dir = os.path.join(_PROJECT_ROOT, "tests/tests_app/")
    with pytest.raises(Exception, match="No checkpoints where found"):
        lightning_app.load_state_dict_from_checkpoint_dir(empty_dir)

    # Remove two of the checkpoint files.
    for removed_index in (4, 7):
        os.remove(dumped_paths[removed_index])

    lightning_app.load_state_dict_from_checkpoint_dir(lightning_app.checkpoint_dir)
    assert lightning_app.root.counter == last_counter

    lightning_app.load_state_dict_from_checkpoint_dir(lightning_app.checkpoint_dir, version=5)
    new_checkpoint = lightning_app._dump_checkpoint()

    # The next dump must be named after version 11.
    new_name = os.path.basename(new_checkpoint)
    assert new_name.startswith("v_11")