def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak a timm ViT checkpoint's weights into our ViT structure.

    The converted HuggingFace model is checked against the timm model on the image
    returned by ``prepare_img()``. The model and its feature extractor are then saved.

    Args:
        vit_name: timm model name. Names ending in ``"in21k"`` are converted as
            headless base models (``ViTModel``). All other names get a 1000-label
            ImageNet-1k classification head (``ViTForImageClassification``).
            Patch size and image size are parsed from fixed character positions
            at the end of the name.
        pytorch_dump_folder_path: directory where the converted model and feature
            extractor are written. It is created, including parents, if missing.

    Raises:
        AssertionError: (only when assertions are enabled) if output shapes differ,
            or if outputs differ from timm's by more than ``atol=1e-3``.
    """
    # define default ViT configuration; size branches below that do nothing keep these defaults
    config = ViTConfig()

    # dataset (ImageNet-21k only or also fine-tuned on ImageNet 2012), patch_size and image_size
    base_model = vit_name.endswith("in21k")
    if base_model:
        # e.g. "..._patch16_224_in21k": patch size at [-12:-10], image size at [-9:-6]
        config.patch_size = int(vit_name[-12:-10])
        config.image_size = int(vit_name[-9:-6])
    else:
        config.num_labels = 1000
        repo_id = "datasets/huggingface/label-files"
        filename = "imagenet-1k-id2label.json"
        # Use a context manager so the downloaded label file is closed promptly.
        with open(cached_download(hf_hub_url(repo_id, filename)), "r", encoding="utf-8") as f:
            id2label = json.load(f)
        # JSON object keys are strings; the config expects integer class ids.
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
        # e.g. "..._patch16_224": patch size at [-6:-4], image size at [-3:]
        config.patch_size = int(vit_name[-6:-4])
        config.image_size = int(vit_name[-3:])

    # size of the architecture
    if "deit" in vit_name:
        # skip a 9-char prefix (presumably "vit_deit_") to reach the size token
        deit_size = vit_name[9:]
        if deit_size.startswith("tiny"):
            config.hidden_size = 192
            config.intermediate_size = 768
            config.num_hidden_layers = 12
            config.num_attention_heads = 3
        elif deit_size.startswith("small"):
            config.hidden_size = 384
            config.intermediate_size = 1536
            config.num_hidden_layers = 12
            config.num_attention_heads = 6
        # any other DeiT size keeps the default configuration
    else:
        # skip a 4-char prefix (presumably "vit_") to reach the size token
        vit_size = vit_name[4:]
        if vit_size.startswith("small"):
            config.hidden_size = 768
            config.intermediate_size = 2304
            config.num_hidden_layers = 8
            config.num_attention_heads = 8
        elif vit_size.startswith("base"):
            pass  # base keeps the default configuration
        elif vit_size.startswith("large"):
            config.hidden_size = 1024
            config.intermediate_size = 4096
            config.num_hidden_layers = 24
            config.num_attention_heads = 16
        elif vit_size.startswith("huge"):
            config.hidden_size = 1280
            config.intermediate_size = 5120
            config.num_hidden_layers = 32
            config.num_attention_heads = 16

    # load original model from timm
    timm_model = timm.create_model(vit_name, pretrained=True)
    timm_model.eval()

    # load state_dict of original model, remove and rename some keys
    state_dict = timm_model.state_dict()
    if base_model:
        # in21k checkpoints are converted without a classification head
        remove_classification_head_(state_dict)
    rename_keys = create_rename_keys(config, base_model)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, base_model)

    # load HuggingFace model (headless for in21k checkpoints)
    if base_model:
        model = ViTModel(config).eval()
    else:
        model = ViTForImageClassification(config).eval()
    model.load_state_dict(state_dict)

    # Check outputs on an image, prepared by ViTFeatureExtractor/DeiTFeatureExtractor
    if "deit" in vit_name:
        feature_extractor = DeiTFeatureExtractor(size=config.image_size)
    else:
        feature_extractor = ViTFeatureExtractor(size=config.image_size)
    encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding["pixel_values"]
    outputs = model(pixel_values)

    if base_model:
        timm_pooled_output = timm_model.forward_features(pixel_values)
        assert timm_pooled_output.shape == outputs.pooler_output.shape
        assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3)
    else:
        timm_logits = timm_model(pixel_values)
        assert timm_logits.shape == outputs.logits.shape
        assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)

    # parents=True so a nested, not-yet-existing output path also works
    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model {vit_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving feature extractor to {pytorch_dump_folder_path}")
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
def default_feature_extractor(self):
    """
    Return a ``DeiTFeatureExtractor`` loaded from the
    ``facebook/deit-base-distilled-patch16-224`` checkpoint.

    Returns ``None`` when ``is_vision_available()`` reports False, so the
    extractor is only constructed when vision support is present.
    """
    if not is_vision_available():
        return None
    checkpoint = "facebook/deit-base-distilled-patch16-224"
    return DeiTFeatureExtractor.from_pretrained(checkpoint)
def convert_deit_checkpoint(deit_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak a timm DeiT checkpoint's weights into our DeiT structure.

    The converted ``DeiTForImageClassificationWithTeacher`` is checked against the
    timm model's logits on the image returned by ``prepare_img()``. The model and
    its feature extractor are then saved.

    Args:
        deit_name: timm model name. The architecture size is read starting at
            character 9 (presumably after a "vit_deit_" prefix). Patch size and
            image size are parsed from the end of the name (e.g. "..._patch16_224").
        pytorch_dump_folder_path: directory where the converted model and feature
            extractor are written. It is created, including parents, if missing.

    Raises:
        AssertionError: (only when assertions are enabled) if the logits' shape
            differs from timm's, or if values differ by more than ``atol=1e-3``.
    """
    # define default DeiT configuration
    config = DeiTConfig()
    # all deit models have fine-tuned heads
    base_model = False

    # dataset (fine-tuned on ImageNet 2012), patch_size and image_size
    config.num_labels = 1000
    repo_id = "datasets/huggingface/label-files"
    filename = "imagenet-1k-id2label.json"
    # Use a context manager so the downloaded label file is closed promptly.
    with open(cached_download(hf_hub_url(repo_id, filename)), "r", encoding="utf-8") as f:
        id2label = json.load(f)
    # JSON object keys are strings; the config expects integer class ids.
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}
    # e.g. "..._patch16_224": patch size at [-6:-4], image size at [-3:]
    config.patch_size = int(deit_name[-6:-4])
    config.image_size = int(deit_name[-3:])

    # size of the architecture. All sizes are matched on the same 9-char offset;
    # the "large" check previously used deit_name[4:], which could never match.
    arch = deit_name[9:]
    if arch.startswith("tiny"):
        config.hidden_size = 192
        config.intermediate_size = 768
        config.num_hidden_layers = 12
        config.num_attention_heads = 3
    elif arch.startswith("small"):
        config.hidden_size = 384
        config.intermediate_size = 1536
        config.num_hidden_layers = 12
        config.num_attention_heads = 6
    elif arch.startswith("base"):
        pass  # base keeps the default configuration
    elif arch.startswith("large"):
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16

    # load original model from timm
    timm_model = timm.create_model(deit_name, pretrained=True)
    timm_model.eval()

    # load state_dict of original model, remove and rename some keys
    state_dict = timm_model.state_dict()
    rename_keys = create_rename_keys(config, base_model)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, base_model)

    # load HuggingFace model
    model = DeiTForImageClassificationWithTeacher(config).eval()
    model.load_state_dict(state_dict)

    # Check outputs on an image, prepared by DeiTFeatureExtractor.
    # Resize to 256/224 of the crop size to maintain the same ratio w.r.t. 224 images, see
    # https://github.com/facebookresearch/deit/blob/ab5715372db8c6cad5740714b2216d55aeae052e/datasets.py#L103
    size = int((256 / 224) * config.image_size)
    feature_extractor = DeiTFeatureExtractor(size=size, crop_size=config.image_size)
    encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding["pixel_values"]
    outputs = model(pixel_values)

    timm_logits = timm_model(pixel_values)
    assert timm_logits.shape == outputs.logits.shape
    assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)

    # parents=True so a nested, not-yet-existing output path also works
    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model {deit_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving feature extractor to {pytorch_dump_folder_path}")
    feature_extractor.save_pretrained(pytorch_dump_folder_path)