Example No. 1
import logging
from typing import Any, Dict, List, Optional

import torch
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel

# `Pathlike` and `AcousticModel` come from the surrounding project
# (a str/Path alias and the acoustic-model base class); the imports
# above cover all three examples.

def load_checkpoint(filename: Pathlike,
                    model: AcousticModel) -> Dict[str, Any]:
    logging.info('load checkpoint from {}'.format(filename))

    checkpoint = torch.load(filename, map_location='cpu')

    keys = [
        'state_dict', 'epoch', 'learning_rate', 'objf', 'valid_objf',
        'num_features', 'num_classes', 'subsampling_factor',
        'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f"Missing keys in checkpoint: {missing_keys}")

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # The checkpoint was saved by DDP: its keys carry a 'module.'
        # prefix (e.g. 'module.encoder.weight') that must be stripped
        # before loading into a non-DDP model.
        logging.info('load checkpoint from DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = '{}.{}'.format('module', key)
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0
        model.load_state_dict(dst_state_dict)
    else:
        model.load_state_dict(checkpoint['state_dict'])

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    return checkpoint
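
A minimal usage sketch for this variant. `build_model()` and the checkpoint path below are placeholders for whatever the surrounding project provides, not part of the original code:

model = build_model()  # hypothetical factory returning an AcousticModel
ckpt = load_checkpoint('exp/epoch-9.pt', model)  # path is a placeholder
print(ckpt['epoch'], ckpt['learning_rate'], ckpt['objf'], ckpt['valid_objf'])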
Example No. 2
def load_checkpoint(
    filename: Pathlike,
    model: AcousticModel,
    optimizer: Optional[object] = None,
    scheduler: Optional[object] = None,
    scaler: Optional[GradScaler] = None,
) -> Dict[str, Any]:
    logging.info('load checkpoint from {}'.format(filename))

    checkpoint = torch.load(filename, map_location='cpu')

    keys = [
        'state_dict', 'optimizer', 'scheduler', 'epoch', 'learning_rate',
        'objf', 'valid_objf', 'num_features', 'num_classes',
        'subsampling_factor', 'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f"Missing keys in checkpoint: {missing_keys}")

    if isinstance(model, DistributedDataParallel):
        model = model.module

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # the checkpoint was saved by DDP
        logging.info('load checkpoint from DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = '{}.{}'.format('module', key)
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0
        model.load_state_dict(dst_state_dict, strict=False)
    else:
        model.load_state_dict(checkpoint['state_dict'], strict=False)
    # Note we used strict=False above so that the current code
    # can load models trained with P_scores.

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler'])

    if scaler is not None:
        # Note: 'grad_scaler' is not in the required-keys check above, so
        # this raises KeyError for checkpoints saved without scaler state.
        scaler.load_state_dict(checkpoint['grad_scaler'])

    return checkpoint
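
A hedged resume-training sketch. The factory and the SGD/StepLR choices below are stand-ins for whatever the actual training script uses:

model = build_model()  # hypothetical factory, as above
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
scaler = GradScaler()

ckpt = load_checkpoint('exp/epoch-9.pt', model,
                       optimizer=optimizer,
                       scheduler=scheduler,
                       scaler=scaler)
start_epoch = ckpt['epoch'] + 1  # resume from the next epoch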
Example No. 3
def average_checkpoint(filenames: List[Pathlike],
                       model: AcousticModel) -> Dict[str, Any]:
    logging.info('average over checkpoints {}'.format(filenames))

    avg_model = None

    # Sum the parameters across all checkpoints, accumulating in place
    # into the state dict of the first one.
    for filename in filenames:
        checkpoint = torch.load(filename, map_location='cpu')
        checkpoint_model = checkpoint['state_dict']
        if avg_model is None:
            avg_model = checkpoint_model
        else:
            for k in avg_model.keys():
                avg_model[k] += checkpoint_model[k]
    # Average: floating-point tensors use true division; integer tensors
    # (e.g. BatchNorm's num_batches_tracked) use floor division.
    for k in avg_model.keys():
        if avg_model[k] is not None:
            if avg_model[k].is_floating_point():
                avg_model[k] /= len(filenames)
            else:
                avg_model[k] //= len(filenames)

    # `checkpoint` still refers to the last file loaded in the loop above;
    # its metadata (epoch, objf, ...) is kept, with the averaged weights
    # swapped in.
    checkpoint['state_dict'] = avg_model

    keys = [
        'state_dict', 'optimizer', 'scheduler', 'epoch', 'learning_rate',
        'objf', 'valid_objf', 'num_features', 'num_classes',
        'subsampling_factor', 'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f"Missing keys in checkpoint: {missing_keys}")

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # the checkpoint was saved by DDP
        logging.info('load checkpoint from DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = '{}.{}'.format('module', key)
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0
        model.load_state_dict(dst_state_dict, strict=False)
    else:
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    return checkpoint
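
A hedged sketch of averaging the last few epochs before decoding; the 'exp/epoch-{i}.pt' naming is a placeholder for the project's own checkpoint layout:

model = build_model()  # hypothetical factory, as above
filenames = [f'exp/epoch-{i}.pt' for i in range(5, 10)]  # placeholder paths
ckpt = average_checkpoint(filenames, model)
model.eval()  # averaged weights are typically used for decoding only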