def test_multiple_ckpts_and_criteria(tmpdir):
    """Check retention and lookup when checkpoints tie on a meta criterion.

    Saves several checkpoints whose ``"error"`` meta values are equal (or
    worse) and verifies that ``save_and_keep_only`` keeps the expected ones,
    and that ``find_checkpoint`` / ``find_checkpoints`` prefer the most
    recent checkpoint among equally ranked candidates.

    Arguments
    ---------
    tmpdir : path-like
        Temporary directory (pytest fixture) used as the checkpoint root.
    """
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    class Recoverable(torch.nn.Module):
        """Minimal module with one scalar parameter to checkpoint."""

        def __init__(self, param):
            # Module must be initialised before a Parameter can be assigned.
            super().__init__()
            self.param = torch.nn.Parameter(torch.tensor([param]))

        def forward(self, x):
            return x * self.param

    recoverable = Recoverable(1.0)
    recoverables = {"recoverable": recoverable}
    recoverer = Checkpointer(tmpdir, recoverables)

    # Here testing multiple checkpoints with equal meta criteria
    recoverer.save_and_keep_only(
        meta={"error": 5}, min_keys=["error"], keep_recent=True
    )
    # By default, get the most recent one:
    first_ckpt = recoverer.find_checkpoint()
    recoverer.save_and_keep_only(
        meta={"error": 5}, min_keys=["error"], keep_recent=True
    )
    second_ckpt = recoverer.find_checkpoint()
    assert first_ckpt.meta["unixtime"] < second_ckpt.meta["unixtime"]
    recoverer.save_and_keep_only(
        meta={"error": 6}, min_keys=["error"], keep_recent=True
    )
    third_ckpt = recoverer.find_checkpoint()
    remaining_ckpts = recoverer.list_checkpoints()
    # The second checkpoint is kept as the best (lowest, newest) "error",
    # the third as the most recent; the first one has been superseded.
    assert first_ckpt not in remaining_ckpts
    assert second_ckpt in remaining_ckpts
    assert third_ckpt in remaining_ckpts

    # With equal importance criteria, the latest checkpoint should always be
    # returned
    fourth_ckpt = recoverer.save_checkpoint(meta={"error": 5})
    found_ckpt = recoverer.find_checkpoint(min_key="error")
    assert found_ckpt == fourth_ckpt
    fifth_ckpt = recoverer.save_checkpoint(meta={"error": 5})
    # Similarly for getting multiple checkpoints:
    found_ckpts = recoverer.find_checkpoints(
        min_key="error", max_num_checkpoints=2
    )
    assert found_ckpts == [fifth_ckpt, fourth_ckpt]
def test_checkpointer(tmpdir, device):
    """Exercise the basic save / list / recover cycle of ``Checkpointer``.

    Covers: recovery with no checkpoint present, saving and locating a
    checkpoint, adding a recoverable after the first save, recovery with a
    custom importance key, partial loading, named checkpoints with meta
    information, and the refusal to reuse a checkpoint name.

    Arguments
    ---------
    tmpdir : path-like
        Temporary directory (pytest fixture) used as the checkpoint root.
    device : str or torch.device
        Device (pytest fixture) on which the parameter tensors are created.
    """
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    class Recoverable(torch.nn.Module):
        """Minimal module with one scalar parameter to checkpoint."""

        def __init__(self, param):
            # Module must be initialised before a Parameter can be assigned.
            super().__init__()
            self.param = torch.nn.Parameter(torch.tensor([param]))

        def forward(self, x):
            return x * self.param

    recoverable = Recoverable(2.0)
    recoverables = {"recoverable": recoverable}
    recoverer = Checkpointer(tmpdir, recoverables)
    recoverable.param.data = torch.tensor([1.0], device=device)
    # Should not be possible since no checkpoint saved yet:
    assert not recoverer.recover_if_possible()
    result = recoverable(10.0)
    # Check that parameter has not been loaded from original value:
    assert recoverable.param.data == torch.tensor([1.0], device=device)

    ckpt = recoverer.save_checkpoint()
    # Check that the name recoverable has a save file:
    # NOTE: Here assuming .ckpt filename; if convention changes, change test
    assert (ckpt.path / "recoverable.ckpt").exists()
    # Check that saved checkpoint is found, and location correct:
    assert recoverer.list_checkpoints()[0] == ckpt
    assert recoverer.list_checkpoints()[0].path.parent == tmpdir
    recoverable.param.data = torch.tensor([2.0], device=device)
    # Overwrite the modified value with the one stored in the checkpoint:
    recoverer.recover_if_possible()
    # Check that parameter has been loaded immediately:
    assert recoverable.param.data == torch.tensor([1.0], device=device)
    result = recoverable(10.0)
    # And result correct
    assert result == 10.0

    other = Recoverable(2.0)
    recoverer.add_recoverable("other", other)
    # Check that both objects are now found:
    assert recoverer.recoverables["recoverable"] == recoverable
    assert recoverer.recoverables["other"] == other
    new_ckpt = recoverer.save_checkpoint()
    # Check that now both recoverables have a save file:
    assert (new_ckpt.path / "recoverable.ckpt").exists()
    assert (new_ckpt.path / "other.ckpt").exists()
    assert new_ckpt in recoverer.list_checkpoints()
    recoverable.param.data = torch.tensor([2.0], device=device)
    other.param.data = torch.tensor([10.0], device=device)
    chosen_ckpt = recoverer.recover_if_possible()
    # Should choose newest by default:
    assert chosen_ckpt == new_ckpt
    # Check again that parameters have been loaded immediately:
    assert recoverable.param.data == torch.tensor([1.0], device=device)
    assert other.param.data == torch.tensor([2.0], device=device)
    other_result = other(10.0)
    # And again we should have the correct computations:
    assert other_result == 20.0

    # Recover from oldest, which does not have "other":
    # This also tests a custom sort
    # Raises by default:
    with pytest.raises(RuntimeError):
        chosen_ckpt = recoverer.recover_if_possible(
            importance_key=lambda x: -x.meta["unixtime"]
        )
    # However this operation may have loaded the first object
    # so let's set the values manually:
    recoverable.param.data = torch.tensor([2.0], device=device)
    other.param.data = torch.tensor([10.0], device=device)
    recoverer.allow_partial_load = True
    chosen_ckpt = recoverer.recover_if_possible(
        importance_key=lambda x: -x.meta["unixtime"]
    )
    # Should have chosen the original:
    assert chosen_ckpt == ckpt
    # And should recover recoverable:
    assert recoverable.param.data == torch.tensor([1.0], device=device)
    # But not other:
    other_result = other(10.0)
    assert other.param.data == torch.tensor([10.0], device=device)
    assert other_result == 100.0

    # Test saving names checkpoints with meta info, and custom filter
    epoch_ckpt = recoverer.save_checkpoint(name="ep1", meta={"loss": 2.0})
    assert "ep1" in epoch_ckpt.path.name
    other.param.data = torch.tensor([2.0], device=device)
    recoverer.save_checkpoint(meta={"loss": 3.0})
    # Lowest loss wins, so the named "ep1" checkpoint should be restored,
    # bringing back the value "other" had when it was saved.
    chosen_ckpt = recoverer.recover_if_possible(
        ckpt_predicate=lambda ckpt: "loss" in ckpt.meta,
        importance_key=lambda ckpt: -ckpt.meta["loss"],
    )
    assert chosen_ckpt == epoch_ckpt
    assert other.param.data == torch.tensor([10.0], device=device)

    # Make sure checkpoints can't be name saved by the same name
    with pytest.raises(FileExistsError):
        recoverer.save_checkpoint(name="ep1")
def test_checkpoint_deletion(tmpdir, device):
    """Check which checkpoints ``Checkpointer`` deletes and which it keeps.

    Covers the default deletion policy, the ``save_and_keep_only``
    shorthand, deleting everything, and keeping one checkpoint per
    importance criterion, with and without a filtering predicate.

    Arguments
    ---------
    tmpdir : path-like
        Temporary directory (pytest fixture) used as the checkpoint root.
    device : str or torch.device
        Device (pytest fixture) on which the parameter tensor is created.
    """
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    class Recoverable(torch.nn.Module):
        """Minimal module with one scalar parameter to checkpoint."""

        def __init__(self, param):
            # Module must be initialised before a Parameter can be assigned.
            super().__init__()
            self.param = torch.nn.Parameter(
                torch.tensor([param], device=device)
            )

        def forward(self, x):
            return x * self.param

    recoverable = Recoverable(1.0)
    recoverables = {"recoverable": recoverable}
    recoverer = Checkpointer(tmpdir, recoverables)
    first_ckpt = recoverer.save_checkpoint()
    recoverer.delete_checkpoints()
    # Will not delete only checkpoint by default:
    assert first_ckpt in recoverer.list_checkpoints()
    second_ckpt = recoverer.save_checkpoint()
    recoverer.delete_checkpoints()
    # Oldest checkpoint is deleted by default:
    assert first_ckpt not in recoverer.list_checkpoints()
    # Other syntax also should work:
    recoverer.save_and_keep_only()
    assert second_ckpt not in recoverer.list_checkpoints()
    # Can delete all checkpoints:
    recoverer.delete_checkpoints(num_to_keep=0)
    assert not recoverer.list_checkpoints()

    # Now each should be kept:
    # Highest foo
    c1 = recoverer.save_checkpoint(meta={"foo": 2})
    # Latest CKPT after filtering
    c2 = recoverer.save_checkpoint(meta={"foo": 1})
    # Filtered out
    c3 = recoverer.save_checkpoint(meta={"epoch_ckpt": True})
    recoverer.delete_checkpoints(
        num_to_keep=1,
        max_keys=["foo"],
        importance_keys=[lambda c: c.meta["unixtime"]],
        ckpt_predicate=lambda c: "epoch_ckpt" not in c.meta,
    )
    assert all(c in recoverer.list_checkpoints() for c in [c1, c2, c3])
    # Reset:
    recoverer.delete_checkpoints(num_to_keep=0)
    assert not recoverer.list_checkpoints()

    # Test the keeping multiple checkpoints without predicate:
    # This should be deleted:
    c_to_delete = recoverer.save_checkpoint(meta={"foo": 2})
    # Highest foo
    c1 = recoverer.save_checkpoint(meta={"foo": 3})
    # Latest CKPT after filtering
    c2 = recoverer.save_checkpoint(meta={"foo": 1})
    recoverer.delete_checkpoints(
        num_to_keep=1,
        importance_keys=[
            lambda c: c.meta["unixtime"],
            lambda c: c.meta["foo"],
        ],
    )
    assert all(c in recoverer.list_checkpoints() for c in [c1, c2])
    assert c_to_delete not in recoverer.list_checkpoints()