import pytest


def test_torch_meta(tmpdir):
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    class Recoverable(torch.nn.Module):
        def __init__(self, param):
            super().__init__()
            self.param = torch.nn.Parameter(torch.tensor([param]))

        def forward(self, x):
            return x * self.param

    recoverable = Recoverable(1.0)
    recoverables = {"recoverable": recoverable}
    recoverer = Checkpointer(tmpdir, recoverables)
    saved = recoverer.save_checkpoint(
        meta={"loss": torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])}
    )
    loaded = recoverer.recover_if_possible()
    assert saved.meta["loss"].allclose(loaded.meta["loss"])
def test_recovery_custom_io(tmpdir):
    from speechbrain.utils.checkpoints import register_checkpoint_hooks
    from speechbrain.utils.checkpoints import mark_as_saver
    from speechbrain.utils.checkpoints import mark_as_loader
    from speechbrain.utils.checkpoints import Checkpointer

    @register_checkpoint_hooks
    class CustomRecoverable:
        def __init__(self, param):
            self.param = int(param)

        @mark_as_saver
        def save(self, path):
            with open(path, "w") as fo:
                fo.write(str(self.param))

        @mark_as_loader
        def load(self, path, end_of_epoch, device):
            del end_of_epoch  # Unused
            del device  # Unused
            with open(path) as fi:
                self.param = int(fi.read())

    custom_recoverable = CustomRecoverable(0)
    recoverer = Checkpointer(tmpdir, {"custom_recoverable": custom_recoverable})
    custom_recoverable.param = 1
    # First, make sure no checkpoints are found
    # (e.g. somehow tmpdir contaminated)
    ckpt = recoverer.recover_if_possible()
    assert ckpt is None
    ckpt = recoverer.save_checkpoint()
    custom_recoverable.param = 2
    loaded_ckpt = recoverer.recover_if_possible()
    # Make sure we got the same thing:
    assert ckpt == loaded_ckpt
    # With this custom recoverable, the load is instant:
    assert custom_recoverable.param == 1
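# A hedged sketch, not exercised by the tests in this file: the ``device``
# argument that the loaders above discard exists so a loader can map saved
# tensors onto the right device. A tensor-backed recoverable might use it as
# below (DeviceAwareRecoverable is an illustrative name, not a SpeechBrain
# class):
def _device_aware_recoverable_sketch():
    from speechbrain.utils.checkpoints import register_checkpoint_hooks
    from speechbrain.utils.checkpoints import mark_as_saver
    from speechbrain.utils.checkpoints import mark_as_loader
    import torch

    @register_checkpoint_hooks
    class DeviceAwareRecoverable:
        def __init__(self):
            self.state = torch.zeros(1)

        @mark_as_saver
        def save(self, path):
            torch.save(self.state, path)

        @mark_as_loader
        def load(self, path, end_of_epoch, device):
            del end_of_epoch  # Unused
            # map_location moves the saved tensor to the requested device:
            self.state = torch.load(path, map_location=device)

    return DeviceAwareRecoverable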
def test_torch_defaults(tmpdir):
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    module = torch.nn.Linear(10, 10)
    optimizer = torch.optim.Adam(module.parameters())
    lr_scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer, 0.1, 1.0, cycle_momentum=False
    )
    # ReduceLROnPlateau is not an _LRScheduler for some reason,
    # so have a separate test for it:
    another_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    checkpointer = Checkpointer(
        tmpdir,
        recoverables={
            "module": module,
            "optimizer": optimizer,
            "scheduler": lr_scheduler,
            "scheduler2": another_scheduler,
        },
    )
    ckpt = checkpointer.save_checkpoint()
    # Test the module:
    inp = torch.randn((3, 10))
    prev_output = module(inp)
    # Re-initialize everything:
    module = torch.nn.Linear(10, 10)
    optimizer = torch.optim.Adam(module.parameters())
    lr_scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer, 0.1, 1.0, cycle_momentum=False
    )
    another_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    checkpointer = Checkpointer(
        tmpdir,
        recoverables={
            "module": module,
            "optimizer": optimizer,
            "scheduler": lr_scheduler,
            "scheduler2": another_scheduler,
        },
    )
    checkpointer.load_checkpoint(ckpt)
    assert torch.allclose(module(inp), prev_output)
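# A hedged note on the mechanism (an assumption grounded in the test above,
# not asserted by it): torch modules, optimizers, and schedulers need no
# custom hooks because the default save/load behavior can rely on their
# state_dict()/load_state_dict() interface. Roughly the per-recoverable
# equivalent:
def _state_dict_roundtrip_sketch(path):
    import torch

    module = torch.nn.Linear(10, 10)
    torch.save(module.state_dict(), path)
    restored = torch.nn.Linear(10, 10)
    restored.load_state_dict(torch.load(path))
    return restored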
def test_checkpointer(tmpdir, device):
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    class Recoverable(torch.nn.Module):
        def __init__(self, param):
            super().__init__()
            self.param = torch.nn.Parameter(torch.tensor([param]))

        def forward(self, x):
            return x * self.param

    recoverable = Recoverable(2.0)
    recoverables = {"recoverable": recoverable}
    recoverer = Checkpointer(tmpdir, recoverables)
    recoverable.param.data = torch.tensor([1.0], device=device)
    # Recovery should not be possible since no checkpoint is saved yet:
    assert not recoverer.recover_if_possible()
    result = recoverable(10.0)
    # Check that the parameter was not reset to its original value:
    assert recoverable.param.data == torch.tensor([1.0], device=device)
    ckpt = recoverer.save_checkpoint()
    # Check that the named recoverable has a save file:
    # NOTE: Here assuming .ckpt filename; if the convention changes, change the test
    assert (ckpt.path / "recoverable.ckpt").exists()
    # Check that the saved checkpoint is found, and its location is correct:
    assert recoverer.list_checkpoints()[0] == ckpt
    assert recoverer.list_checkpoints()[0].path.parent == tmpdir
    recoverable.param.data = torch.tensor([2.0], device=device)
    recoverer.recover_if_possible()
    # Check that the parameter has been loaded immediately:
    assert recoverable.param.data == torch.tensor([1.0], device=device)
    result = recoverable(10.0)
    # And the result is correct:
    assert result == 10.0
    other = Recoverable(2.0)
    recoverer.add_recoverable("other", other)
    # Check that both objects are now found:
    assert recoverer.recoverables["recoverable"] == recoverable
    assert recoverer.recoverables["other"] == other
    new_ckpt = recoverer.save_checkpoint()
    # Check that now both recoverables have a save file:
    assert (new_ckpt.path / "recoverable.ckpt").exists()
    assert (new_ckpt.path / "other.ckpt").exists()
    assert new_ckpt in recoverer.list_checkpoints()
    recoverable.param.data = torch.tensor([2.0], device=device)
    other.param.data = torch.tensor([10.0], device=device)
    chosen_ckpt = recoverer.recover_if_possible()
    # Should choose the newest by default:
    assert chosen_ckpt == new_ckpt
    # Check again that parameters have been loaded immediately:
    assert recoverable.param.data == torch.tensor([1.0], device=device)
    assert other.param.data == torch.tensor([2.0], device=device)
    other_result = other(10.0)
    # And again we should have the correct computations:
    assert other_result == 20.0
    # Recover from the oldest checkpoint, which does not have "other".
    # This also tests a custom sort.
    # Raises by default:
    with pytest.raises(RuntimeError):
        chosen_ckpt = recoverer.recover_if_possible(
            importance_key=lambda x: -x.meta["unixtime"]
        )
    # However, this operation may have loaded the first object,
    # so let's set the values manually:
    recoverable.param.data = torch.tensor([2.0], device=device)
    other.param.data = torch.tensor([10.0], device=device)
    recoverer.allow_partial_load = True
    chosen_ckpt = recoverer.recover_if_possible(
        importance_key=lambda x: -x.meta["unixtime"]
    )
    # Should have chosen the original:
    assert chosen_ckpt == ckpt
    # And should recover recoverable:
    assert recoverable.param.data == torch.tensor([1.0], device=device)
    # But not other:
    other_result = other(10.0)
    assert other.param.data == torch.tensor([10.0], device=device)
    assert other_result == 100.0
    # Test saving named checkpoints with meta info, and a custom filter:
    epoch_ckpt = recoverer.save_checkpoint(name="ep1", meta={"loss": 2.0})
    assert "ep1" in epoch_ckpt.path.name
    other.param.data = torch.tensor([2.0], device=device)
    recoverer.save_checkpoint(meta={"loss": 3.0})
    chosen_ckpt = recoverer.recover_if_possible(
        ckpt_predicate=lambda ckpt: "loss" in ckpt.meta,
        importance_key=lambda ckpt: -ckpt.meta["loss"],
    )
    assert chosen_ckpt == epoch_ckpt
    assert other.param.data == torch.tensor([10.0], device=device)
    # Make sure checkpoints can't be saved with the same name:
    with pytest.raises(FileExistsError):
        recoverer.save_checkpoint(name="ep1")
def test_checkpoint_hook_register(tmpdir):
    from speechbrain.utils.checkpoints import register_checkpoint_hooks
    from speechbrain.utils.checkpoints import mark_as_saver
    from speechbrain.utils.checkpoints import mark_as_loader
    from speechbrain.utils.checkpoints import Checkpointer

    # First a proper interface:
    @register_checkpoint_hooks
    class CustomRecoverable:
        def __init__(self, param):
            self.param = int(param)

        @mark_as_saver
        def save(self, path):
            with open(path, "w") as fo:
                fo.write(str(self.param))

        @mark_as_loader
        def load(self, path, end_of_epoch, device):
            del end_of_epoch  # Unused
            with open(path) as fi:
                self.param = int(fi.read())

    recoverable = CustomRecoverable(1.0)
    checkpointer = Checkpointer(tmpdir, {"recoverable": recoverable})
    checkpointer.save_checkpoint()
    recoverable.param = 2.0
    checkpointer.recover_if_possible()
    assert recoverable.param == 1.0
    # Improper interfaces:
    with pytest.raises(TypeError):

        class BadRecoverable:
            def __init__(self, param):
                self.param = int(param)

            def save(self, path):
                with open(path, "w") as fo:
                    fo.write(str(self.param))

            @mark_as_loader
            def load(self, path, end_of_epoch):  # MISSING device
                del end_of_epoch  # Unused
                with open(path) as fi:
                    self.param = int(fi.read())

    with pytest.raises(TypeError):

        class BadRecoverable:  # noqa: F811
            def __init__(self, param):
                self.param = int(param)

            @mark_as_saver
            def save(self, path, extra_arg):  # Extra argument
                with open(path, "w") as fo:
                    fo.write(str(self.param))

            def load(self, path, end_of_epoch, device):
                del end_of_epoch  # Unused
                with open(path) as fi:
                    self.param = int(fi.read())
def test_multiple_ckpts_and_criteria(tmpdir):
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    class Recoverable(torch.nn.Module):
        def __init__(self, param):
            super().__init__()
            self.param = torch.nn.Parameter(torch.tensor([param]))

        def forward(self, x):
            return x * self.param

    recoverable = Recoverable(1.0)
    recoverables = {"recoverable": recoverable}
    recoverer = Checkpointer(tmpdir, recoverables)
    # Here testing multiple checkpoints with equal meta criteria
    recoverer.save_and_keep_only(
        meta={"error": 5}, min_keys=["error"], keep_recent=True
    )
    # By default, get the most recent one:
    first_ckpt = recoverer.find_checkpoint()
    recoverer.save_and_keep_only(
        meta={"error": 5}, min_keys=["error"], keep_recent=True
    )
    second_ckpt = recoverer.find_checkpoint()
    assert first_ckpt.meta["unixtime"] < second_ckpt.meta["unixtime"]
    recoverer.save_and_keep_only(
        meta={"error": 6}, min_keys=["error"], keep_recent=True
    )
    third_ckpt = recoverer.find_checkpoint()
    remaining_ckpts = recoverer.list_checkpoints()
    assert first_ckpt not in remaining_ckpts
    assert second_ckpt in remaining_ckpts
    assert third_ckpt in remaining_ckpts
    # With equal importance criteria, the latest checkpoint should always be
    # returned:
    fourth_ckpt = recoverer.save_checkpoint(meta={"error": 5})
    found_ckpt = recoverer.find_checkpoint(min_key="error")
    assert found_ckpt == fourth_ckpt
    fifth_ckpt = recoverer.save_checkpoint(meta={"error": 5})
    # Similarly for getting multiple checkpoints:
    found_ckpts = recoverer.find_checkpoints(
        min_key="error", max_num_checkpoints=2
    )
    assert found_ckpts == [fifth_ckpt, fourth_ckpt]
def test_checkpoint_deletion(tmpdir, device):
    from speechbrain.utils.checkpoints import Checkpointer
    import torch

    class Recoverable(torch.nn.Module):
        def __init__(self, param):
            super().__init__()
            self.param = torch.nn.Parameter(
                torch.tensor([param], device=device)
            )

        def forward(self, x):
            return x * self.param

    recoverable = Recoverable(1.0)
    recoverables = {"recoverable": recoverable}
    recoverer = Checkpointer(tmpdir, recoverables)
    first_ckpt = recoverer.save_checkpoint()
    recoverer.delete_checkpoints()
    # Will not delete the only checkpoint by default:
    assert first_ckpt in recoverer.list_checkpoints()
    second_ckpt = recoverer.save_checkpoint()
    recoverer.delete_checkpoints()
    # The oldest checkpoint is deleted by default:
    assert first_ckpt not in recoverer.list_checkpoints()
    # The other syntax should also work:
    recoverer.save_and_keep_only()
    assert second_ckpt not in recoverer.list_checkpoints()
    # Can delete all checkpoints:
    recoverer.delete_checkpoints(num_to_keep=0)
    assert not recoverer.list_checkpoints()
    # Now each of these should be kept:
    # Highest foo
    c1 = recoverer.save_checkpoint(meta={"foo": 2})
    # Latest ckpt after filtering
    c2 = recoverer.save_checkpoint(meta={"foo": 1})
    # Filtered out
    c3 = recoverer.save_checkpoint(meta={"epoch_ckpt": True})
    recoverer.delete_checkpoints(
        num_to_keep=1,
        max_keys=["foo"],
        importance_keys=[lambda c: c.meta["unixtime"]],
        ckpt_predicate=lambda c: "epoch_ckpt" not in c.meta,
    )
    assert all(c in recoverer.list_checkpoints() for c in [c1, c2, c3])
    # Reset:
    recoverer.delete_checkpoints(num_to_keep=0)
    assert not recoverer.list_checkpoints()
    # Test keeping multiple checkpoints without a predicate:
    # This one should be deleted:
    c_to_delete = recoverer.save_checkpoint(meta={"foo": 2})
    # Highest foo
    c1 = recoverer.save_checkpoint(meta={"foo": 3})
    # Latest ckpt after filtering
    c2 = recoverer.save_checkpoint(meta={"foo": 1})
    recoverer.delete_checkpoints(
        num_to_keep=1,
        importance_keys=[
            lambda c: c.meta["unixtime"],
            lambda c: c.meta["foo"],
        ],
    )
    assert all(c in recoverer.list_checkpoints() for c in [c1, c2])
    assert c_to_delete not in recoverer.list_checkpoints()
def test_epoch_loop_recovery(tmpdir):
    from speechbrain.utils.checkpoints import Checkpointer
    from speechbrain.utils.epoch_loop import EpochCounter

    epoch_counter = EpochCounter(2)
    recoverer = Checkpointer(tmpdir, {"epoch": epoch_counter})
    for epoch in epoch_counter:
        assert epoch == 1
        # Save a mid-epoch checkpoint:
        recoverer.save_checkpoint(end_of_epoch=False)
        # Simulate interruption
        break
    # Now after recovery, still at epoch 1:
    recoverer.recover_if_possible()
    second_epoch = False  # Will manually update this
    for epoch in epoch_counter:
        if not second_epoch:
            assert epoch == 1
            recoverer.save_checkpoint(end_of_epoch=True)
            second_epoch = True
        else:
            assert epoch == 2
            # Again simulate interruption
            break
    # Now after recovery we are in epoch 2:
    recoverer.recover_if_possible()
    loop_runs = 0
    for epoch in epoch_counter:
        assert epoch == 2
        loop_runs += 1
        recoverer.save_checkpoint(end_of_epoch=True)
    # And that is that:
    assert loop_runs == 1
    # And now after recovery, no more epochs:
    recoverer.recover_if_possible()
    for epoch in epoch_counter:
        # Will not get here:
        assert False
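# A hedged sketch of the intended real-world pattern that
# test_epoch_loop_recovery exercises piece by piece (an illustration only;
# the training/validation body is elided):
def _example_resumable_loop(tmpdir):
    from speechbrain.utils.checkpoints import Checkpointer
    from speechbrain.utils.epoch_loop import EpochCounter

    epoch_counter = EpochCounter(10)
    recoverer = Checkpointer(tmpdir, {"epoch": epoch_counter})
    # Resume from the latest checkpoint, if one exists:
    recoverer.recover_if_possible()
    for epoch in epoch_counter:
        # ... train and validate here ...
        # end_of_epoch=True makes the counter resume at the next epoch:
        recoverer.save_checkpoint(end_of_epoch=True)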