def test_multiple_ckpts_and_criteria(tmpdir):
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    class Recoverable(torch.nn.Module):
        def __init__(self, param):
            super().__init__()
            self.param = torch.nn.Parameter(torch.tensor([param]))

        def forward(self, x):
            return x * self.param

    recoverable = Recoverable(1.0)
    recoverables = {"recoverable": recoverable}
    recoverer = Checkpointer(tmpdir, recoverables)
    # Here testing multiple checkpoints with equal meta criteria
    recoverer.save_and_keep_only(
        meta={"error": 5}, min_keys=["error"], keep_recent=True
    )
    # By default, get the most recent one:
    first_ckpt = recoverer.find_checkpoint()
    recoverer.save_and_keep_only(
        meta={"error": 5}, min_keys=["error"], keep_recent=True
    )
    second_ckpt = recoverer.find_checkpoint()
    assert first_ckpt.meta["unixtime"] < second_ckpt.meta["unixtime"]
    recoverer.save_and_keep_only(
        meta={"error": 6}, min_keys=["error"], keep_recent=True
    )
    third_ckpt = recoverer.find_checkpoint()
    remaining_ckpts = recoverer.list_checkpoints()
    assert first_ckpt not in remaining_ckpts
    assert second_ckpt in remaining_ckpts
    assert third_ckpt in remaining_ckpts
    # With equal importance criteria, the latest checkpoint should always be
    # returned
    fourth_ckpt = recoverer.save_checkpoint(meta={"error": 5})
    found_ckpt = recoverer.find_checkpoint(min_key="error")
    assert found_ckpt == fourth_ckpt
    fifth_ckpt = recoverer.save_checkpoint(meta={"error": 5})
    # Similarly for getting multiple checkpoints:
    found_ckpts = recoverer.find_checkpoints(
        min_key="error", max_num_checkpoints=2
    )
    assert found_ckpts == [fifth_ckpt, fourth_ckpt]
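
# A minimal usage sketch (not part of the original tests) of the lookup
# behaviour asserted above: with tied "error" values, find_checkpoint() and
# recover_if_possible() fall back to recency, so the newest of the tied
# checkpoints is the one chosen. The Linear module and meta values are
# illustrative assumptions; recover_if_possible(min_key=...) is assumed to
# accept the same key arguments as find_checkpoint().
def _example_recover_best(tmpdir):
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    model = torch.nn.Linear(1, 1)
    recoverer = Checkpointer(tmpdir, {"model": model})
    recoverer.save_checkpoint(meta={"error": 5})
    recoverer.save_checkpoint(meta={"error": 5})
    # Ties on "error" resolve to the most recent checkpoint:
    best = recoverer.find_checkpoint(min_key="error")
    assert best == recoverer.find_checkpoints(min_key="error")[0]
    # Load the chosen checkpoint's parameters back into the model:
    recoverer.recover_if_possible(min_key="error")
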
def test_checkpoint_deletion(tmpdir, device):
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    class Recoverable(torch.nn.Module):
        def __init__(self, param):
            super().__init__()
            self.param = torch.nn.Parameter(
                torch.tensor([param], device=device)
            )

        def forward(self, x):
            return x * self.param

    recoverable = Recoverable(1.0)
    recoverables = {"recoverable": recoverable}
    recoverer = Checkpointer(tmpdir, recoverables)
    first_ckpt = recoverer.save_checkpoint()
    recoverer.delete_checkpoints()
    # Will not delete the only checkpoint by default:
    assert first_ckpt in recoverer.list_checkpoints()
    second_ckpt = recoverer.save_checkpoint()
    recoverer.delete_checkpoints()
    # The oldest checkpoint is deleted by default:
    assert first_ckpt not in recoverer.list_checkpoints()
    # The other syntax should also work:
    recoverer.save_and_keep_only()
    assert second_ckpt not in recoverer.list_checkpoints()
    # Can delete all checkpoints:
    recoverer.delete_checkpoints(num_to_keep=0)
    assert not recoverer.list_checkpoints()
    # Now each of these should be kept:
    # Highest foo
    c1 = recoverer.save_checkpoint(meta={"foo": 2})
    # Latest ckpt after filtering
    c2 = recoverer.save_checkpoint(meta={"foo": 1})
    # Filtered out by the predicate
    c3 = recoverer.save_checkpoint(meta={"epoch_ckpt": True})
    recoverer.delete_checkpoints(
        num_to_keep=1,
        max_keys=["foo"],
        importance_keys=[lambda c: c.meta["unixtime"]],
        ckpt_predicate=lambda c: "epoch_ckpt" not in c.meta,
    )
    assert all(c in recoverer.list_checkpoints() for c in [c1, c2, c3])
    # Reset:
    recoverer.delete_checkpoints(num_to_keep=0)
    assert not recoverer.list_checkpoints()
    # Test keeping multiple checkpoints without a predicate:
    # This one should be deleted:
    c_to_delete = recoverer.save_checkpoint(meta={"foo": 2})
    # Highest foo
    c1 = recoverer.save_checkpoint(meta={"foo": 3})
    # Latest ckpt after filtering
    c2 = recoverer.save_checkpoint(meta={"foo": 1})
    recoverer.delete_checkpoints(
        num_to_keep=1,
        importance_keys=[
            lambda c: c.meta["unixtime"],
            lambda c: c.meta["foo"],
        ],
    )
    assert all(c in recoverer.list_checkpoints() for c in [c1, c2])
    assert c_to_delete not in recoverer.list_checkpoints()
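
# A minimal sketch (not part of the original tests) of how the retention
# policy exercised above is typically used during training: calling
# save_and_keep_only() each epoch keeps only the newest checkpoint plus the
# one with the lowest "error". The loop values and the expected count are
# illustrative assumptions, based on num_to_keep defaulting to 1 per criterion.
def _example_epoch_cleanup(tmpdir):
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    model = torch.nn.Linear(1, 1)
    recoverer = Checkpointer(tmpdir, {"model": model})
    for error in [5.0, 4.0, 6.0]:
        # Keep the most recent checkpoint and the best (lowest) "error":
        recoverer.save_and_keep_only(
            meta={"error": error}, min_keys=["error"], keep_recent=True
        )
    # Expect at most two survivors: the best-"error" one and the newest one.
    assert len(recoverer.list_checkpoints()) <= 2
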