import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import torch
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel

# Project-local base class; adjust the import path if your package
# layout differs.
from snowfall.models import AcousticModel

Pathlike = Union[str, Path]


def load_checkpoint(
        filename: Pathlike,
        model: AcousticModel,
        optimizer: Optional[object] = None,
        scheduler: Optional[object] = None,
        scaler: Optional[GradScaler] = None,
) -> Dict[str, Any]:
    logging.info('load checkpoint from {}'.format(filename))

    checkpoint = torch.load(filename, map_location='cpu')

    keys = [
        'state_dict', 'optimizer', 'scheduler', 'epoch', 'learning_rate',
        'objf', 'valid_objf', 'num_features', 'num_classes',
        'subsampling_factor', 'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f"Missing keys in checkpoint: {missing_keys}")

    if isinstance(model, DistributedDataParallel):
        model = model.module

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # The checkpoint was saved by a DDP-wrapped model, so every key
        # in its state dict carries a 'module.' prefix that the bare
        # model does not have. Strip the prefix while copying.
        logging.info('load checkpoint from DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = '{}.{}'.format('module', key)
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0
        model.load_state_dict(dst_state_dict, strict=False)
    else:
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    # Note: we use strict=False above so that the current code
    # can load models trained with P_scores.

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler'])

    if scaler is not None:
        # 'grad_scaler' is not in the required-keys list above, so a
        # checkpoint saved without AMP state will raise a KeyError here.
        scaler.load_state_dict(checkpoint['grad_scaler'])

    return checkpoint
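# A minimal resumption sketch, not part of the original module: it shows
# how load_checkpoint restores model, optimizer, scheduler, and AMP scaler
# state in one call. The AdamW/StepLR choices are illustrative stand-ins
# for whatever the training script actually uses.
def _example_resume(checkpoint_path: Pathlike,
                    model: AcousticModel) -> int:
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    scaler = GradScaler()
    ckpt = load_checkpoint(checkpoint_path, model,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           scaler=scaler)
    # The returned dict still holds the bookkeeping values, so training
    # can pick up its epoch numbering where it left off.
    start_epoch = ckpt['epoch'] + 1
    return start_epoch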
def average_checkpoint(filenames: List[Pathlike],
                       model: AcousticModel) -> Dict[str, Any]:
    logging.info('average over checkpoints {}'.format(filenames))

    avg_model = None

    # Sum the parameters over all checkpoints.
    for filename in filenames:
        checkpoint = torch.load(filename, map_location='cpu')
        checkpoint_model = checkpoint['state_dict']
        if avg_model is None:
            avg_model = checkpoint_model
        else:
            for k in avg_model.keys():
                avg_model[k] += checkpoint_model[k]

    # Divide by the number of checkpoints. Integer tensors (e.g. the
    # 'num_batches_tracked' buffers of BatchNorm) use floor division,
    # since true division would change their dtype.
    for k in avg_model.keys():
        if avg_model[k] is not None:
            if avg_model[k].is_floating_point():
                avg_model[k] /= len(filenames)
            else:
                avg_model[k] //= len(filenames)

    # Reuse the last loaded checkpoint dict, but with averaged weights.
    checkpoint['state_dict'] = avg_model

    keys = [
        'state_dict', 'optimizer', 'scheduler', 'epoch', 'learning_rate',
        'objf', 'valid_objf', 'num_features', 'num_classes',
        'subsampling_factor', 'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f"Missing keys in checkpoint: {missing_keys}")

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # The checkpoint was saved by a DDP-wrapped model; strip the
        # 'module.' prefix while copying, as in load_checkpoint above.
        logging.info('load checkpoint from DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = '{}.{}'.format('module', key)
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0
        model.load_state_dict(dst_state_dict, strict=False)
    else:
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    return checkpoint
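# A sketch of averaging the last few epoch checkpoints before decoding,
# which typically yields a slightly better model than any single epoch.
# The 'epoch-{n}.pt' filename pattern is an assumption; adapt it to
# however the training loop names its checkpoint files.
def _example_average(exp_dir: Path,
                     model: AcousticModel,
                     last_epoch: int,
                     num_to_average: int = 5) -> None:
    filenames = [
        exp_dir / f'epoch-{epoch}.pt'
        for epoch in range(last_epoch - num_to_average + 1, last_epoch + 1)
    ]
    average_checkpoint(filenames, model)
    model.eval()  # averaged weights are normally used for inference only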