示例#1
0
def test_multiple_ckpts_and_criteria(tmpdir):
    """Check checkpoint selection when several checkpoints tie on a meta key.

    Exercises ``save_and_keep_only`` with ``min_keys=["error"]`` and
    ``keep_recent=True``, then ``find_checkpoint`` / ``find_checkpoints``
    with ``min_key="error"`` when more than one checkpoint has the same
    error. The asserts expect the more recently saved checkpoint to win ties.

    Args:
        tmpdir: pytest temporary directory used as the checkpoint directory.
    """
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    class Scaler(torch.nn.Module):
        # Tiny module with one learnable scalar, used as the recoverable.
        def __init__(self, init_value):
            super().__init__()
            self.param = torch.nn.Parameter(torch.tensor([init_value]))

        def forward(self, x):
            return x * self.param

    checkpointer = Checkpointer(tmpdir, {"recoverable": Scaler(1.0)})

    def save_keeping_best_and_latest(error):
        # Save, then prune per min_keys=["error"] + keep_recent=True.
        checkpointer.save_and_keep_only(
            meta={"error": error}, min_keys=["error"], keep_recent=True
        )

    # Two saves with an identical error: with no criterion given,
    # find_checkpoint() is expected to return the newest one.
    save_keeping_best_and_latest(5)
    oldest = checkpointer.find_checkpoint()
    save_keeping_best_and_latest(5)
    newer = checkpointer.find_checkpoint()
    assert oldest.meta["unixtime"] < newer.meta["unixtime"]

    # A worse error is kept because it is the most recent; of the two
    # error-5 checkpoints only the newer one is expected to survive.
    save_keeping_best_and_latest(6)
    latest = checkpointer.find_checkpoint()
    on_disk = checkpointer.list_checkpoints()
    assert oldest not in on_disk
    assert newer in on_disk
    assert latest in on_disk

    # With equal importance criteria the latest checkpoint should always be
    # returned -- first for a single lookup ...
    tie_a = checkpointer.save_checkpoint(meta={"error": 5})
    assert checkpointer.find_checkpoint(min_key="error") == tie_a
    tie_b = checkpointer.save_checkpoint(meta={"error": 5})
    # ... and likewise when asking for several (newest first).
    best_two = checkpointer.find_checkpoints(
        min_key="error", max_num_checkpoints=2
    )
    assert best_two == [tie_b, tie_a]
示例#2
0
def test_checkpoint_deletion(tmpdir, device):
    """Exercise ``Checkpointer.delete_checkpoints`` and ``save_and_keep_only``.

    Covers the default deletion behaviour, deleting everything with
    ``num_to_keep=0``, and protecting checkpoints through ``max_keys``,
    ``importance_keys`` and ``ckpt_predicate``.

    Args:
        tmpdir: pytest temporary directory used as the checkpoint directory.
        device: device on which the module parameter is created (presumably
            a fixture supplied by the project's conftest -- confirm).
    """
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    class Recoverable(torch.nn.Module):
        # Minimal module with a single scalar parameter, checkpointed below.
        def __init__(self, param):
            super().__init__()
            self.param = torch.nn.Parameter(
                torch.tensor([param], device=device)
            )

        def forward(self, x):
            return x * self.param

    recoverable = Recoverable(1.0)
    recoverables = {"recoverable": recoverable}
    recoverer = Checkpointer(tmpdir, recoverables)

    # --- Default deletion behaviour ---
    first_ckpt = recoverer.save_checkpoint()
    recoverer.delete_checkpoints()
    # Will not delete the only checkpoint by default:
    assert first_ckpt in recoverer.list_checkpoints()
    second_ckpt = recoverer.save_checkpoint()
    recoverer.delete_checkpoints()
    # Oldest checkpoint is deleted by default:
    assert first_ckpt not in recoverer.list_checkpoints()
    # Other syntax also should work: save_and_keep_only() saves a new
    # checkpoint and prunes older ones, so second_ckpt is removed.
    recoverer.save_and_keep_only()
    assert second_ckpt not in recoverer.list_checkpoints()
    # Can delete all checkpoints:
    recoverer.delete_checkpoints(num_to_keep=0)
    assert not recoverer.list_checkpoints()

    # --- Each criterion protects a different checkpoint; all should stay ---
    # Highest foo -> protected by max_keys=["foo"]
    c1 = recoverer.save_checkpoint(meta={"foo": 2})
    # Latest checkpoint among those passing ckpt_predicate -> kept by unixtime
    c2 = recoverer.save_checkpoint(meta={"foo": 1})
    # Filtered out by ckpt_predicate; the test expects it to survive
    c3 = recoverer.save_checkpoint(meta={"epoch_ckpt": True})
    recoverer.delete_checkpoints(
        num_to_keep=1,
        max_keys=["foo"],
        importance_keys=[lambda c: c.meta["unixtime"]],
        ckpt_predicate=lambda c: "epoch_ckpt" not in c.meta,
    )
    # List the directory once and name the missing checkpoint on failure,
    # instead of a bare all(...) that hides which element failed.
    remaining = recoverer.list_checkpoints()
    for ckpt in (c1, c2, c3):
        assert ckpt in remaining, f"{ckpt} should have been kept"
    # Reset:
    recoverer.delete_checkpoints(num_to_keep=0)
    assert not recoverer.list_checkpoints()

    # --- Several importance keys, no predicate ---
    # Neither the latest nor the highest foo -> this should be deleted:
    c_to_delete = recoverer.save_checkpoint(meta={"foo": 2})
    # Highest foo
    c1 = recoverer.save_checkpoint(meta={"foo": 3})
    # Latest checkpoint (no predicate here, so nothing is filtered out)
    c2 = recoverer.save_checkpoint(meta={"foo": 1})
    recoverer.delete_checkpoints(
        num_to_keep=1,
        importance_keys=[lambda c: c.meta["unixtime"], lambda c: c.meta["foo"]],
    )
    remaining = recoverer.list_checkpoints()
    for ckpt in (c1, c2):
        assert ckpt in remaining, f"{ckpt} should have been kept"
    assert c_to_delete not in remaining