Example #1
import logging
from typing import Any, Dict

import torch

# `Pathlike` (a str/Path alias) and `AcousticModel` are assumed to come
# from the surrounding project; they are not defined in these examples.


def load_checkpoint(filename: Pathlike,
                    model: AcousticModel) -> Dict[str, Any]:
    """Load a checkpoint into `model` and return the full checkpoint dict."""
    logging.info(f'load checkpoint from {filename}')

    checkpoint = torch.load(filename, map_location='cpu')

    keys = [
        'state_dict', 'epoch', 'learning_rate', 'objf', 'valid_objf',
        'num_features', 'num_classes', 'subsampling_factor',
        'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f"Missing keys in checkpoint: {missing_keys}")

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # The checkpoint was saved by DDP, so its keys carry a 'module.'
        # prefix that the local model does not have; strip it while copying.
        logging.info('loading checkpoint saved by DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = f'module.{key}'
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0  # every DDP key must be consumed
        model.load_state_dict(dst_state_dict)
    else:
        model.load_state_dict(checkpoint['state_dict'])

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    return checkpoint
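
A minimal usage sketch for Example #1. The model class `TdnnLstm1b` and its
constructor arguments are hypothetical placeholders; any `AcousticModel`
subclass matching the saved architecture works:

model = TdnnLstm1b(num_features=80, num_classes=5000,
                   subsampling_factor=3)  # hypothetical model class
checkpoint = load_checkpoint('exp/epoch-9.pt', model)
print(f"epoch={checkpoint['epoch']}, objf={checkpoint['objf']:.4f}, "
      f"valid_objf={checkpoint['valid_objf']:.4f}")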
Example #2
import logging
from typing import Any, Dict, Optional

import torch
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel

# `Pathlike` and `AcousticModel` are assumed to come from the
# surrounding project, as in Example #1.


def load_checkpoint(
    filename: Pathlike,
    model: AcousticModel,
    optimizer: Optional[object] = None,
    scheduler: Optional[object] = None,
    scaler: Optional[GradScaler] = None,
) -> Dict[str, Any]:
    """Load a checkpoint into `model`; optionally restore the optimizer,
    scheduler, and AMP grad scaler state as well."""
    logging.info(f'load checkpoint from {filename}')

    checkpoint = torch.load(filename, map_location='cpu')

    keys = [
        'state_dict', 'optimizer', 'scheduler', 'epoch', 'learning_rate',
        'objf', 'valid_objf', 'num_features', 'num_classes',
        'subsampling_factor', 'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f"Missing keys in checkpoint: {missing_keys}")

    if isinstance(model, DistributedDataParallel):
        # Unwrap DDP so the attributes below are set on the real module.
        model = model.module

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # The checkpoint was saved by DDP, so its keys carry a 'module.'
        # prefix that the local model does not have; strip it while copying.
        logging.info('loading checkpoint saved by DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = f'module.{key}'
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0  # every DDP key must be consumed
        model.load_state_dict(dst_state_dict, strict=False)
    else:
        model.load_state_dict(checkpoint['state_dict'], strict=False)
    # Note we used strict=False above so that the current code
    # can load models trained with P_scores.

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler'])

    # 'grad_scaler' is not in the validated `keys` list above, so guard
    # against checkpoints that were saved without AMP state.
    if scaler is not None and 'grad_scaler' in checkpoint:
        scaler.load_state_dict(checkpoint['grad_scaler'])

    return checkpoint
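
A sketch of resuming training with Example #2; it assumes the optimizer,
scheduler, and scaler were constructed the same way as in the run that
produced the checkpoint (the concrete classes below are illustrative):

from torch.cuda.amp import GradScaler

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
scaler = GradScaler()

checkpoint = load_checkpoint('exp/epoch-9.pt', model,
                             optimizer=optimizer,
                             scheduler=scheduler,
                             scaler=scaler)
start_epoch = checkpoint['epoch'] + 1
global_batch_idx_train = checkpoint['global_batch_idx_train']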
Example #3
import logging
from typing import Any, Dict, List

import torch

# `Pathlike` and `AcousticModel` are assumed to come from the
# surrounding project, as in Example #1.


def average_checkpoint(filenames: List[Pathlike],
                       model: AcousticModel) -> Dict[str, Any]:
    """Average the weights of several checkpoints, load the result into
    `model`, and return the last checkpoint dict with averaged weights."""
    assert len(filenames) > 0
    logging.info(f'average over checkpoints {filenames}')

    avg_model = None

    # Sum the parameters of all checkpoints, accumulating into the state
    # dict of the first one.
    for filename in filenames:
        checkpoint = torch.load(filename, map_location='cpu')
        checkpoint_model = checkpoint['state_dict']
        if avg_model is None:
            avg_model = checkpoint_model
        else:
            for k in avg_model.keys():
                avg_model[k] += checkpoint_model[k]

    # Divide by the number of checkpoints; integer tensors (e.g. the
    # num_batches_tracked buffers of BatchNorm) need floor division.
    for k in avg_model.keys():
        if avg_model[k].is_floating_point():
            avg_model[k] /= len(filenames)
        else:
            avg_model[k] //= len(filenames)

    # Reuse the metadata (epoch, objf, ...) of the last loaded checkpoint.
    checkpoint['state_dict'] = avg_model

    keys = [
        'state_dict', 'optimizer', 'scheduler', 'epoch', 'learning_rate',
        'objf', 'valid_objf', 'num_features', 'num_classes',
        'subsampling_factor', 'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f"Missing keys in checkpoint: {missing_keys}")

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # The checkpoint was saved by DDP; strip the 'module.' prefix
        # while copying into the local model's state dict.
        logging.info('loading checkpoint saved by DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = f'module.{key}'
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0  # every DDP key must be consumed
        model.load_state_dict(dst_state_dict, strict=False)
    else:
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    return checkpoint
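
Example #3 is typically used to average the weights of the last few epochs
before decoding (a sketch; the filename pattern is illustrative):

filenames = [f'exp/epoch-{epoch}.pt' for epoch in range(7, 10)]
checkpoint = average_checkpoint(filenames, model)
# `model` now holds the averaged weights and can be used for decoding.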
Example #4
import logging

import torch

# `Pathlike` and `AcousticModel` are assumed to come from the
# surrounding project, as in Example #1.


def save_checkpoint(filename: Pathlike,
                    model: AcousticModel,
                    epoch: int,
                    learning_rate: float,
                    objf: float,
                    valid_objf: float,
                    global_batch_idx_train: int,
                    local_rank: int = 0) -> None:
    """Save a training checkpoint; a no-op on non-zero ranks under DDP."""
    if local_rank is not None and local_rank != 0:
        return
    logging.info(
        f'Save checkpoint to {filename}: epoch={epoch}, '
        f'learning_rate={learning_rate}, objf={objf}, valid_objf={valid_objf}')
    checkpoint = {
        'state_dict': model.state_dict(),
        'num_features': model.num_features,
        'num_classes': model.num_classes,
        'subsampling_factor': model.subsampling_factor,
        'epoch': epoch,
        'learning_rate': learning_rate,
        'objf': objf,
        'valid_objf': valid_objf,
        'global_batch_idx_train': global_batch_idx_train,
    }
    torch.save(checkpoint, filename)
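
A sketch of the saving side with Example #4, called once per epoch; thanks
to the `local_rank` guard, under DDP only the rank-0 process writes the file
(the variable names are illustrative):

save_checkpoint(filename=f'exp/epoch-{epoch}.pt',
                model=model,
                epoch=epoch,
                learning_rate=optimizer.param_groups[0]['lr'],
                objf=train_objf,
                valid_objf=valid_objf,
                global_batch_idx_train=global_batch_idx_train,
                local_rank=local_rank)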
Example #5
import logging

import torch

# `Pathlike` and `AcousticModel` are assumed to come from the
# surrounding project, as in Example #1.


def save_checkpoint(filename: Pathlike,
                    model: AcousticModel,
                    epoch: int,
                    learning_rate: float,
                    objf: float,
                    local_rank: int = 0) -> None:
    """Save a training checkpoint (a variant of Example #4 without
    `valid_objf` and `global_batch_idx_train`)."""
    if local_rank is not None and local_rank != 0:
        return
    logging.info(f'Save checkpoint to {filename}: epoch={epoch}, '
                 f'learning_rate={learning_rate}, objf={objf}')
    checkpoint = {
        'state_dict': model.state_dict(),
        'num_features': model.num_features,
        'num_classes': model.num_classes,
        'subsampling_factor': model.subsampling_factor,
        'epoch': epoch,
        'learning_rate': learning_rate,
        'objf': objf
    }
    torch.save(checkpoint, filename)