Ejemplo n.º 1
0
def convert_vit_checkpoint(model_name,
                           pytorch_dump_folder_path,
                           base_model=True):
    """
    Copy/paste/tweak model's weights to our ViT structure.

    Loads `model_name` from the `facebookresearch/dino:main` torch hub repo,
    renames its state dict keys, loads them into a `ViTModel` (or a
    `ViTForImageClassification` when `base_model` is False), compares the
    outputs of both models on a test image and saves the converted model
    together with its feature extractor.

    Args:
        model_name: Torch hub entry point to convert. A name ending in "8"
            selects a patch size of 8; "dino_vits8" and "dino_vits16" select
            the smaller architecture (hidden size 384, 6 attention heads).
        pytorch_dump_folder_path: Directory the converted model and feature
            extractor are written to. Created (including parents) if missing.
        base_model: If True, the classification head is removed and a bare
            `ViTModel` is produced; otherwise a 1000-label classifier is built.

    Raises:
        AssertionError: If the converted model's outputs do not match the
            original model's outputs within tolerance.
    """

    # define default ViT configuration
    config = ViTConfig()
    # patch_size
    if model_name.endswith("8"):
        config.patch_size = 8
    # set labels if required
    if not base_model:
        config.num_labels = 1000
        repo_id = "datasets/huggingface/label-files"
        filename = "imagenet-1k-id2label.json"
        # use a context manager so the label file handle is always closed
        with open(hf_hub_download(repo_id, filename), "r",
                  encoding="utf-8") as label_file:
            id2label = json.load(label_file)
        # JSON object keys are strings; the config expects integer ids
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
    # size of the architecture
    if model_name in ["dino_vits8", "dino_vits16"]:
        config.hidden_size = 384
        config.intermediate_size = 1536
        config.num_hidden_layers = 12
        config.num_attention_heads = 6

    # load original model from torch hub
    original_model = torch.hub.load("facebookresearch/dino:main", model_name)
    original_model.eval()

    # load state_dict of original model, remove and rename some keys
    state_dict = original_model.state_dict()
    if base_model:
        remove_classification_head_(state_dict)
    rename_keys = create_rename_keys(config, base_model=base_model)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, base_model)

    # load HuggingFace model
    if base_model:
        model = ViTModel(config, add_pooling_layer=False).eval()
    else:
        model = ViTForImageClassification(config).eval()
    model.load_state_dict(state_dict)

    # Check outputs on an image, prepared by ViTFeatureExtractor
    feature_extractor = ViTFeatureExtractor()
    encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    # inference only: eval() does not disable autograd, so do it explicitly
    with torch.no_grad():
        outputs = model(pixel_values)
        original_output = original_model(pixel_values)

    if base_model:
        # the original model's output is compared with the first token of
        # the converted model's last hidden state
        assert torch.allclose(
            original_output,
            outputs.last_hidden_state[:, 0, :],
            atol=1e-1), "Final hidden state of first token not as expected"
    else:
        assert original_output.shape == outputs.logits.shape, (
            "Shape of logits not as expected")
        assert torch.allclose(original_output, outputs.logits,
                              atol=1e-3), "Logits not as expected"

    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving feature extractor to {pytorch_dump_folder_path}")
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our ViT structure.

    Creates the pretrained timm model `vit_name`, renames its state dict keys,
    loads them into a `ViTModel` (names ending in "in21k") or a
    `ViTForImageClassification` (all other names), compares the outputs of
    both models on a test image and saves the converted model together with
    its feature extractor.

    Args:
        vit_name: Name of the timm model to convert. Patch size and image
            size are parsed from fixed character offsets at the end of the
            name, and the architecture size from the word following the
            model family prefix.
        pytorch_dump_folder_path: Directory the converted model and feature
            extractor are written to. Created (including parents) if missing.

    Raises:
        ValueError: If the fixed-offset slices of `vit_name` are not integers.
        AssertionError: If the converted model's outputs do not match the
            timm model's outputs within tolerance.
    """

    # define default ViT configuration
    config = ViTConfig()
    # dataset (ImageNet-21k only or also fine-tuned on ImageNet 2012), patch_size and image_size
    base_model = vit_name.endswith("in21k")
    if base_model:
        # name ends in "<2-digit patch>_<3-digit size>_in21k"
        config.patch_size = int(vit_name[-12:-10])
        config.image_size = int(vit_name[-9:-6])
    else:
        config.num_labels = 1000
        repo_id = "datasets/huggingface/label-files"
        filename = "imagenet-1k-id2label.json"
        # use a context manager so the label file handle is always closed
        with open(cached_download(hf_hub_url(repo_id, filename)), "r",
                  encoding="utf-8") as label_file:
            id2label = json.load(label_file)
        # JSON object keys are strings; the config expects integer ids
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
        # name ends in "<2-digit patch>_<3-digit size>"
        config.patch_size = int(vit_name[-6:-4])
        config.image_size = int(vit_name[-3:])
    # size of the architecture
    if "deit" in vit_name:
        # the size keyword is expected at character offset 9 of the name
        if vit_name[9:].startswith("tiny"):
            config.hidden_size = 192
            config.intermediate_size = 768
            config.num_hidden_layers = 12
            config.num_attention_heads = 3
        elif vit_name[9:].startswith("small"):
            config.hidden_size = 384
            config.intermediate_size = 1536
            config.num_hidden_layers = 12
            config.num_attention_heads = 6
        # any other DeiT variant keeps the default configuration
    else:
        # the size keyword is expected at character offset 4 of the name
        if vit_name[4:].startswith("small"):
            config.hidden_size = 768
            config.intermediate_size = 2304
            config.num_hidden_layers = 8
            config.num_attention_heads = 8
        elif vit_name[4:].startswith("large"):
            config.hidden_size = 1024
            config.intermediate_size = 4096
            config.num_hidden_layers = 24
            config.num_attention_heads = 16
        elif vit_name[4:].startswith("huge"):
            config.hidden_size = 1280
            config.intermediate_size = 5120
            config.num_hidden_layers = 32
            config.num_attention_heads = 16
        # "base" (and any unrecognised size) keeps the default configuration

    # load original model from timm
    timm_model = timm.create_model(vit_name, pretrained=True)
    timm_model.eval()

    # load state_dict of original model, remove and rename some keys
    state_dict = timm_model.state_dict()
    if base_model:
        remove_classification_head_(state_dict)
    rename_keys = create_rename_keys(config, base_model)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, base_model)

    # load HuggingFace model
    if base_model:
        model = ViTModel(config).eval()
    else:
        model = ViTForImageClassification(config).eval()
    model.load_state_dict(state_dict)

    # Check outputs on an image, prepared by ViTFeatureExtractor/DeiTFeatureExtractor
    if "deit" in vit_name:
        feature_extractor = DeiTFeatureExtractor(size=config.image_size)
    else:
        feature_extractor = ViTFeatureExtractor(size=config.image_size)
    encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    # inference only: eval() does not disable autograd, so do it explicitly
    with torch.no_grad():
        outputs = model(pixel_values)
        if base_model:
            timm_output = timm_model.forward_features(pixel_values)
        else:
            timm_output = timm_model(pixel_values)

    if base_model:
        assert timm_output.shape == outputs.pooler_output.shape, (
            "Shape of pooled output not as expected")
        assert torch.allclose(timm_output, outputs.pooler_output,
                              atol=1e-3), "Pooled output not as expected"
    else:
        assert timm_output.shape == outputs.logits.shape, (
            "Shape of logits not as expected")
        assert torch.allclose(timm_output, outputs.logits,
                              atol=1e-3), "Logits not as expected"

    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model {vit_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving feature extractor to {pytorch_dump_folder_path}")
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
def convert_tr_ocr_checkpoint(checkpoint_url, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our VisionEncoderDecoderModel structure.

    Downloads the state dict at `checkpoint_url`, renames its keys, loads it
    into a `VisionEncoderDecoderModel` built from a `ViTModel` encoder and a
    `TrOCRForCausalLM` decoder, verifies the first logits on a test image and
    saves the model together with its processor.

    Args:
        checkpoint_url: URL of the original checkpoint. Must contain "base"
            or "large"; the substrings "large-printed" and "stage1" switch
            the decoder configuration, and the four known "trocr-*" names
            select the reference logits used for verification.
        pytorch_dump_folder_path: Directory the converted model and processor
            are written to. Created (including parents) if missing.

    Raises:
        ValueError: If neither "base" nor "large" occurs in `checkpoint_url`.
        KeyError: If an expected key is missing from the checkpoint.
        AssertionError: If the logits do not have the expected shape or
            values.
    """
    # define encoder and decoder configs based on checkpoint_url
    encoder_config = ViTConfig(image_size=384, qkv_bias=False)
    decoder_config = TrOCRConfig()

    # size of the architecture
    if "base" in checkpoint_url:
        decoder_config.encoder_hidden_size = 768
    elif "large" in checkpoint_url:
        # use ViT-large encoder
        encoder_config.hidden_size = 1024
        encoder_config.intermediate_size = 4096
        encoder_config.num_hidden_layers = 24
        encoder_config.num_attention_heads = 16
        decoder_config.encoder_hidden_size = 1024
    else:
        raise ValueError(
            "Should either find 'base' or 'large' in checkpoint URL")

    # the large-printed + stage1 checkpoints use sinusoidal position embeddings, no layernorm afterwards
    if "large-printed" in checkpoint_url or "stage1" in checkpoint_url:
        decoder_config.tie_word_embeddings = False
        decoder_config.activation_function = "relu"
        decoder_config.max_position_embeddings = 1024
        decoder_config.scale_embedding = True
        decoder_config.use_learned_position_embeddings = False
        decoder_config.layernorm_embedding = False

    # load HuggingFace model
    encoder = ViTModel(encoder_config, add_pooling_layer=False)
    decoder = TrOCRForCausalLM(decoder_config)
    model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
    model.eval()

    # load state_dict of original model, rename some keys
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url,
                                                    map_location="cpu",
                                                    check_hash=True)["model"]

    rename_keys = create_rename_keys(encoder_config, decoder_config)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, encoder_config)

    # remove parameters we don't need
    del state_dict["encoder.deit.head.weight"]
    del state_dict["encoder.deit.head.bias"]
    del state_dict["decoder.version"]

    # add prefix to decoder keys; iterate over a snapshot of the keys because
    # the dict is modified in place (every entry is popped and re-inserted)
    for key in list(state_dict):
        val = state_dict.pop(key)
        if key.startswith("decoder") and "output_projection" not in key:
            state_dict["decoder.model." + key] = val
        else:
            state_dict[key] = val

    # load state dict
    model.load_state_dict(state_dict)

    # Check outputs on an image
    feature_extractor = ViTFeatureExtractor(size=encoder_config.image_size)
    tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
    processor = TrOCRProcessor(feature_extractor, tokenizer)

    pixel_values = processor(images=prepare_img(checkpoint_url),
                             return_tensors="pt").pixel_values

    # verify logits
    decoder_input_ids = torch.tensor(
        [[model.config.decoder.decoder_start_token_id]])
    # inference only: eval() does not disable autograd, so do it explicitly
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values,
                        decoder_input_ids=decoder_input_ids)
    logits = outputs.logits

    expected_shape = torch.Size([1, 1, 50265])
    # reference values for the first 10 logits, keyed by a substring of the
    # checkpoint URL; the first matching entry wins
    reference_slices = {
        "trocr-base-handwritten": [
            -1.4502, -4.6683, -0.5347, -2.9291, 9.1435, -3.0571, 8.9764,
            1.7560, 8.7358, -1.5311
        ],
        "trocr-large-handwritten": [
            -2.6437, -1.3129, -2.2596, -5.3455, 6.3539, 1.7604, 5.4991, 1.4702,
            5.6113, 2.0170
        ],
        "trocr-base-printed": [
            -5.6816, -5.8388, 1.1398, -6.9034, 6.8505, -2.4393, 1.2284,
            -1.0232, -1.9661, -3.9210
        ],
        "trocr-large-printed": [
            -6.0162, -7.0959, 4.4155, -5.1063, 7.0468, -3.1631, 2.6466,
            -0.3081, -0.8106, -1.7535
        ],
    }
    # stays None for checkpoints without reference values, instead of leaving
    # the name unbound and crashing with UnboundLocalError further down
    expected_slice = next(
        (torch.tensor(values) for name, values in reference_slices.items()
         if name in checkpoint_url),
        None,
    )

    # stage1 checkpoints are saved without any output verification
    if "stage1" not in checkpoint_url:
        assert logits.shape == expected_shape, "Shape of logits not as expected"
        if expected_slice is None:
            print("No reference logits available for this checkpoint, "
                  "skipping value check")
        else:
            assert torch.allclose(
                logits[0, 0, :10], expected_slice,
                atol=1e-3), "First elements of logits not as expected"

    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving processor to {pytorch_dump_folder_path}")
    processor.save_pretrained(pytorch_dump_folder_path)