def load_checkpoint(checkpoint_file, model, optimizer=None): """Loads the checkpoint from the given file.""" err_str = "Checkpoint '{}' not found" assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file) checkpoint = torch.load(checkpoint_file, map_location="cpu") unwrap_model(model).load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) if optimizer else () return model
def load_checkpoint(checkpoint_file, model, optimizer=None): """Loads the checkpoint from the given file.""" err_str = "Checkpoint '{}' not found" assert pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file) with pathmgr.open(checkpoint_file, "rb") as f: checkpoint = torch.load(f, map_location="cpu") unwrap_model(model).load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) if optimizer else () return checkpoint["epoch"]
def load_checkpoint(checkpoint_file, model, optimizer=None, replace=None):
    """Loads the checkpoint from the given file."""
    err_str = "Checkpoint '{}' not found"
    assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    # Optionally rename the 'se' module in the state-dict keys before loading
    if replace is not None:
        checkpoint["model_state"] = OrderedDict(
            [
                (k.replace('se', replace), v) if '.se.' in k else (k, v)
                for k, v in checkpoint["model_state"].items()
            ]
        )
    unwrap_model(model).load_state_dict(checkpoint["model_state"])
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]
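
# Hedged sketch (assumption, not from the original code): what the `replace`
# option above does to state-dict keys. Keys containing '.se.' have the 'se'
# segment renamed; all other keys are left untouched. The key names below are
# hypothetical examples.
from collections import OrderedDict
import torch

_example_state = OrderedDict(
    [
        ("s1.b1.se.f_ex.0.weight", torch.zeros(1)),  # hypothetical key
        ("s1.b1.conv.weight", torch.zeros(1)),       # hypothetical key
    ]
)
_remapped = OrderedDict(
    [
        (k.replace('se', 'att'), v) if '.se.' in k else (k, v)
        for k, v in _example_state.items()
    ]
)
# _remapped keys: ['s1.b1.att.f_ex.0.weight', 's1.b1.conv.weight']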
def save_ckpt(model, out=None):
    # save student weights
    checkpoint_file = 'model.pyth' if out is None else out
    checkpoint = {
        "epoch": 0,
        "model_state": unwrap_model(model).state_dict(),
    }
    torch.save(checkpoint, checkpoint_file)
    return checkpoint_file
def load_checkpoint(checkpoint_file, model, model_ema=None, optimizer=None):
    """
    Loads a checkpoint selectively based on the input options.

    Each checkpoint contains both the model and model_ema weights (except
    checkpoints created by old versions of the code). If both the model and
    model_ema weights are requested, both sets of weights are loaded. If only
    the model weights are requested (that is, if model_ema=None), the *better*
    set of weights is selected to be loaded (according to the lesser of
    test_err and ema_err, also stored in the checkpoint). The code is backward
    compatible with checkpoints that do not store the ema weights.
    """
    err_str = "Checkpoint '{}' not found"
    assert pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file)
    with pathmgr.open(checkpoint_file, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    # Get test_err and ema_err (with backward compatibility)
    test_err = checkpoint["test_err"] if "test_err" in checkpoint else 100
    ema_err = checkpoint["ema_err"] if "ema_err" in checkpoint else 100
    # Load model and optionally model_ema weights (with backward compatibility)
    ema_state = "ema_state" if "ema_state" in checkpoint else "model_state"
    if model_ema:
        unwrap_model(model).load_state_dict(checkpoint["model_state"])
        unwrap_model(model_ema).load_state_dict(checkpoint[ema_state])
    else:
        best_state = "model_state" if test_err <= ema_err else ema_state
        unwrap_model(model).load_state_dict(checkpoint[best_state])
    # Load optimizer if requested
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"], test_err, ema_err
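
# Hedged usage sketch (assumption, not from the original code): loading only the
# single best set of weights from an EMA-aware checkpoint for evaluation. With
# model_ema=None, load_checkpoint above picks "model_state" when test_err <= ema_err
# and the EMA weights otherwise. The model construction and path are placeholders.
import torch.nn as nn

def _eval_load_example():
    model = nn.Linear(10, 2)  # placeholder model (assumption)
    epoch, test_err, ema_err = load_checkpoint("checkpoints/model_epoch_0100.pyth", model)
    return epoch, min(test_err, ema_err)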
def save_checkpoint(model, model_ema, optimizer, epoch, test_err, ema_err):
    """Saves a checkpoint and also the best weights so far in a best checkpoint."""
    # Save checkpoints only from the main process
    if not dist.is_main_proc():
        return
    # Ensure that the checkpoint dir exists
    pathmgr.mkdirs(get_checkpoint_dir())
    # Record the state
    checkpoint = {
        "epoch": epoch,
        "test_err": test_err,
        "ema_err": ema_err,
        "model_state": unwrap_model(model).state_dict(),
        "ema_state": unwrap_model(model_ema).state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "cfg": cfg.dump(),
    }
    # Write the checkpoint
    checkpoint_file = get_checkpoint(epoch + 1)
    with pathmgr.open(checkpoint_file, "wb") as f:
        torch.save(checkpoint, f)
    # Store the best model and model_ema weights so far
    if not pathmgr.exists(get_checkpoint_best()):
        pathmgr.copy(checkpoint_file, get_checkpoint_best())
    else:
        with pathmgr.open(get_checkpoint_best(), "rb") as f:
            best = torch.load(f, map_location="cpu")
        # Select the best model weights and the best model_ema weights
        if test_err < best["test_err"] or ema_err < best["ema_err"]:
            if test_err < best["test_err"]:
                best["model_state"] = checkpoint["model_state"]
                best["test_err"] = test_err
            if ema_err < best["ema_err"]:
                best["ema_state"] = checkpoint["ema_state"]
                best["ema_err"] = ema_err
            with pathmgr.open(get_checkpoint_best(), "wb") as f:
                torch.save(best, f)
    return checkpoint_file
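
# Hedged usage sketch (assumption, not from the original code): calling the
# EMA-aware save_checkpoint above at the end of each epoch. The training step,
# model_ema update, and error values are placeholders; the best checkpoint is
# updated internally whenever either error improves.
def _epoch_loop_example(model, model_ema, optimizer, num_epochs):
    for epoch in range(num_epochs):
        # ... train for one epoch, update model_ema, then evaluate both ...
        test_err, ema_err = 10.0, 9.5  # placeholder metrics (assumption)
        save_checkpoint(model, model_ema, optimizer, epoch, test_err, ema_err)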
def save_checkpoint(model, optimizer, epoch):
    """Saves a checkpoint."""
    # Save checkpoints only from the master process
    if not dist.is_master_proc():
        return
    # Ensure that the checkpoint dir exists
    os.makedirs(get_checkpoint_dir(), exist_ok=True)
    # Record the state
    checkpoint = {
        "epoch": epoch,
        "model_state": unwrap_model(model).state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "cfg": cfg.dump(),
    }
    # Write the checkpoint
    checkpoint_file = get_checkpoint(epoch + 1)
    torch.save(checkpoint, checkpoint_file)
    return checkpoint_file
def save_checkpoint(model, optimizer, epoch, best):
    """Saves a checkpoint."""
    # Save checkpoints only from the master process
    if not dist.is_master_proc():
        return
    # Ensure that the checkpoint dir exists
    pathmgr.mkdirs(get_checkpoint_dir())
    # Record the state
    checkpoint = {
        "epoch": epoch,
        "model_state": unwrap_model(model).state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "cfg": cfg.dump(),
    }
    # Write the checkpoint
    checkpoint_file = get_checkpoint(epoch + 1)
    with pathmgr.open(checkpoint_file, "wb") as f:
        torch.save(checkpoint, f)
    # If best, copy the checkpoint to the best checkpoint
    if best:
        pathmgr.copy(checkpoint_file, get_checkpoint_best())
    return checkpoint_file
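
# Hedged usage sketch (assumption, not from the original code): computing the
# `best` flag for the save_checkpoint variant above by tracking the lowest
# error seen so far on the caller's side.
def _save_with_best_example(model, optimizer, epoch, test_err, best_err):
    best = test_err < best_err
    save_checkpoint(model, optimizer, epoch, best)
    return min(test_err, best_err)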