def checkpoint_filter_fn(state_dict, model):
    """ resize the patch position embedding when the checkpoint grid does not match the model """
    if state_dict['patch_pos'].shape != model.patch_pos.shape:
        state_dict['patch_pos'] = resize_pos_embed(
            state_dict['patch_pos'], model.patch_pos,
            getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size)
    return state_dict
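

# The resize_pos_embed call above is assumed to behave like timm's helper of the same
# name: reshape the flat (1, N, C) position table back to its 2D grid, interpolate to
# the new grid size, and flatten it again. Below is a minimal, self-contained sketch of
# the grid-interpolation step only; handling of prefix tokens (the num_tokens argument
# above) is omitted, and _resize_grid_pos_embed is a hypothetical name for illustration,
# not part of this repo or of timm.
def _resize_grid_pos_embed(posemb, gs_new):
    """Sketch: interpolate a (1, N, C) position table to a new (H, W) grid."""
    import math
    import torch.nn.functional as F
    gs_old = int(math.sqrt(posemb.shape[1]))                              # assumes a square source grid
    posemb = posemb.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)    # (1, C, H, W)
    posemb = F.interpolate(posemb, size=gs_new, mode='bicubic', align_corners=False)
    return posemb.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
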
def checkpoint_filter_fn(state_dict, model, args):
    """ filter a pretrained checkpoint: unwrap DeiT checkpoints, then resize the position
    embedding and the patch embedding kernel when their shapes do not match the target model """
    out_dict = {}
    if 'model' in state_dict:
        # For deit models
        state_dict = state_dict['model']
    for k, v in state_dict.items():
        if args.pretrain_pos_only:
            # head.weight and head.bias will be ignored in the timm library
            if k not in ['pos_embed', 'head.weight', 'head.bias']:
                continue
        if k == 'pos_embed' and v.shape != model.pos_embed.shape:
            # To resize pos embedding when using model at different size from pretrained weights
            v = resize_pos_embed(v, model.pos_embed)
        elif k == 'patch_embed.proj.weight' and v.shape != model.patch_embed.proj.weight.shape:
            # Resize kernel
            _logger.warning("Patch size doesn't match.")
            if True:  # NOTE: this short-circuit makes the arg-based branches below unreachable
                _logger.warning('Downsample patch embedding')
                v = v.reshape(*v.shape[:2], args.patch, v.shape[2] // args.patch,
                              args.patch, v.shape[3] // args.patch).sum(dim=[3, 5])
            else:
                if args.patch_embed_scratch:
                    _logger.warning('Use initialized patch embedding')
                    continue
                elif args.downsample_factor:
                    _logger.warning('Downsample patch embedding')
                    v = v.reshape(*v.shape[:2], args.patch, v.shape[2] // args.patch,
                                  args.patch, v.shape[3] // args.patch).sum(dim=[3, 5])
                else:
                    _logger.warning('Downsample patch embedding with F.interpolate')
                    v = F.interpolate(v, model.patch_embed.proj.weight.shape[-1])
        out_dict[k] = v
    for key in model.state_dict().keys():
        if key not in out_dict:
            _logger.warning('Initialized {}'.format(key))
    return out_dict
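

# The reshape(...).sum(dim=[3, 5]) trick above is block sum-pooling of the conv kernel:
# a (C_out, C_in, H, W) patch-embedding weight is split into an args.patch x args.patch
# grid of blocks and each block is summed, adapting a pretrained kernel to a smaller
# patch size. A small self-contained check of that equivalence follows; the shapes, the
# helper name and the value of p are illustrative assumptions, not values from this repo.
def _demo_patch_kernel_sum_pool():
    """Sketch: show that the reshape + sum above equals sum pooling of the kernel."""
    import torch
    import torch.nn.functional as F
    w = torch.randn(8, 3, 16, 16)   # stand-in for a pretrained patch_embed.proj.weight
    p = 8                           # stand-in for args.patch (target patch size)
    pooled = w.reshape(*w.shape[:2], p, w.shape[2] // p, p, w.shape[3] // p).sum(dim=[3, 5])
    # Sum pooling == average pooling scaled by the block area.
    reference = F.avg_pool2d(w, kernel_size=w.shape[2] // p) * (w.shape[2] // p) ** 2
    assert pooled.shape == (8, 3, p, p)
    assert torch.allclose(pooled, reference, atol=1e-5)

# If a checkpoint loader expects a (state_dict, model) filter, the three-argument
# function above would typically be adapted first, e.g. with functools.partial to bind
# args; that wiring is assumed, not shown in this file.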