Beispiel #1
0
def checkpoint_filter_fn(state_dict, model):
    """Resize the checkpoint's ``patch_pos`` embedding to fit ``model``.

    When the checkpoint's ``patch_pos`` entry has a different shape from
    ``model.patch_pos``, it is replaced in place by the result of
    ``resize_pos_embed``. That call receives:

    * the model's ``num_tokens`` (1 when the model has no such attribute);
    * ``model.pixel_embed.grid_size``.

    Checkpoints without a ``patch_pos`` entry are returned unchanged.

    Args:
        state_dict: checkpoint mapping of parameter names to tensors. It is
            modified in place.
        model: the target model. It must expose ``patch_pos``. When a resize
            is needed, it must also expose ``pixel_embed.grid_size``.

    Returns:
        The same ``state_dict`` object.
    """
    pos = state_dict.get('patch_pos')
    # Missing key: nothing to adapt. Previously this raised KeyError here.
    if pos is not None and pos.shape != model.patch_pos.shape:
        state_dict['patch_pos'] = resize_pos_embed(
            pos, model.patch_pos,
            getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size)
    return state_dict
Beispiel #2
0
def _sum_pool_patch_kernel(weight, patch):
    """Shrink a conv kernel to ``patch`` x ``patch`` by summing blocks of taps.

    Each of the two spatial axes of ``weight`` (axes 2 and 3) is split into
    ``patch`` groups of ``size // patch`` consecutive taps, and each group is
    summed. The ``reshape`` only succeeds for a 4-D ``(out, in, kh, kw)``
    kernel whose ``kh`` and ``kw`` are multiples of ``patch``.
    TODO(review): confirm callers guarantee this.
    """
    return weight.reshape(*weight.shape[:2], patch,
                          weight.shape[2] // patch, patch,
                          weight.shape[3] // patch).sum(dim=[3, 5])


def checkpoint_filter_fn(state_dict, model, args):
    """Adapt a pretrained checkpoint to ``model`` before loading.

    Steps:
      * unwrap ``state_dict['model']`` when present (DeiT-style checkpoints);
      * if ``args.pretrain_pos_only`` is set, keep only ``pos_embed``,
        ``head.weight`` and ``head.bias``;
      * resize ``pos_embed`` with ``resize_pos_embed`` when its shape differs
        from ``model.pos_embed``;
      * if ``patch_embed.proj.weight`` differs in shape from the model's,
        shrink it to ``args.patch`` x ``args.patch`` by block-summing taps;
      * log a warning for each key of ``model.state_dict()`` missing from the
        result.

    Args:
        state_dict: checkpoint mapping, possibly wrapped under ``'model'``.
        model: target model, used for shape checks and key coverage.
        args: namespace providing ``pretrain_pos_only`` and ``patch``.

    Returns:
        A new dict holding the kept, possibly transformed, checkpoint entries.
    """
    if 'model' in state_dict:
        # For deit models
        state_dict = state_dict['model']

    # Keys retained when only the position embedding is being transferred.
    keep_only = ({'pos_embed', 'head.weight', 'head.bias'}
                 if args.pretrain_pos_only else None)

    out_dict = {}
    for k, v in state_dict.items():
        if keep_only is not None and k not in keep_only:
            # head.weight and head.bias will be ignored in the timm library
            continue

        if k == 'pos_embed' and v.shape != model.pos_embed.shape:
            # To resize pos embedding when using model at different size from pretrained weights
            v = resize_pos_embed(v, model.pos_embed)
        elif k == 'patch_embed.proj.weight' and v.shape != model.patch_embed.proj.weight.shape:
            _logger.warning("Patch size doesn't match. ")
            # Block-summing is the only strategy. The scratch-init and
            # F.interpolate alternatives were unreachable (guarded by a
            # hard-coded `if True`) and have been dropped.
            _logger.warning('Downsample patch embedding')
            v = _sum_pool_patch_kernel(v, args.patch)
        out_dict[k] = v

    # Anything the checkpoint did not supply keeps the model's own initialization.
    for key in model.state_dict():
        if key not in out_dict:
            _logger.warning('Initialized %s', key)
    return out_dict