def convert_tr_ocr_checkpoint(checkpoint_url, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our VisionEncoderDecoderModel structure.

    Args:
        checkpoint_url: URL of the original TrOCR checkpoint. The URL text is
            used to pick the architecture: it must contain "base" or "large",
            and "large-printed" / "stage1" switch the decoder configuration.
        pytorch_dump_folder_path: Directory in which the converted model and
            the processor are saved. Created (including parents) if missing.

    Raises:
        ValueError: If the URL contains neither "base" nor "large", or if the
            checkpoint is not a stage1 checkpoint and no reference logits are
            known for it.
        AssertionError: If the logits of the converted model do not match the
            expected shape or reference values.
    """
    # define encoder and decoder configs based on checkpoint_url
    encoder_config = ViTConfig(image_size=384, qkv_bias=False)
    decoder_config = TrOCRConfig()

    # size of the architecture
    if "base" in checkpoint_url:
        decoder_config.encoder_hidden_size = 768
    elif "large" in checkpoint_url:
        # use ViT-large encoder
        encoder_config.hidden_size = 1024
        encoder_config.intermediate_size = 4096
        encoder_config.num_hidden_layers = 24
        encoder_config.num_attention_heads = 16
        decoder_config.encoder_hidden_size = 1024
    else:
        raise ValueError("Should either find 'base' or 'large' in checkpoint URL")

    # the large-printed + stage1 checkpoints uses sinusoidal position embeddings, no layernorm afterwards
    if "large-printed" in checkpoint_url or "stage1" in checkpoint_url:
        decoder_config.tie_word_embeddings = False
        decoder_config.activation_function = "relu"
        decoder_config.max_position_embeddings = 1024
        decoder_config.scale_embedding = True
        decoder_config.use_learned_position_embeddings = False
        decoder_config.layernorm_embedding = False

    # load HuggingFace model
    encoder = ViTModel(encoder_config, add_pooling_layer=False)
    decoder = TrOCRForCausalLM(decoder_config)
    model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
    model.eval()

    # load state_dict of original model, rename some keys
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["model"]

    rename_keys = create_rename_keys(encoder_config, decoder_config)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, encoder_config)

    # remove parameters we don't need
    del state_dict["encoder.deit.head.weight"]
    del state_dict["encoder.deit.head.bias"]
    del state_dict["decoder.version"]

    # add prefix to decoder keys; iterate over a snapshot of the keys because
    # the dict is mutated (every entry is popped and re-inserted) in the loop
    for key in list(state_dict.keys()):
        val = state_dict.pop(key)
        if key.startswith("decoder") and "output_projection" not in key:
            state_dict["decoder.model." + key] = val
        else:
            state_dict[key] = val

    # load state dict
    model.load_state_dict(state_dict)

    # Check outputs on an image
    feature_extractor = ViTFeatureExtractor(size=encoder_config.image_size)
    tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
    processor = TrOCRProcessor(feature_extractor, tokenizer)

    pixel_values = processor(images=prepare_img(checkpoint_url), return_tensors="pt").pixel_values

    # verify logits
    decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]])
    outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
    logits = outputs.logits

    expected_shape = torch.Size([1, 1, 50265])
    # reference values for the first 10 logits, keyed by a substring of the
    # checkpoint URL; the first matching entry (in insertion order) is used
    reference_slices = {
        "trocr-base-handwritten": [
            -1.4502, -4.6683, -0.5347, -2.9291, 9.1435, -3.0571, 8.9764, 1.7560, 8.7358, -1.5311
        ],
        "trocr-large-handwritten": [
            -2.6437, -1.3129, -2.2596, -5.3455, 6.3539, 1.7604, 5.4991, 1.4702, 5.6113, 2.0170
        ],
        "trocr-base-printed": [
            -5.6816, -5.8388, 1.1398, -6.9034, 6.8505, -2.4393, 1.2284, -1.0232, -1.9661, -3.9210
        ],
        "trocr-large-printed": [
            -6.0162, -7.0959, 4.4155, -5.1063, 7.0468, -3.1631, 2.6466, -0.3081, -0.8106, -1.7535
        ],
    }

    # stage1 checkpoints are not verified against reference logits
    if "stage1" not in checkpoint_url:
        assert logits.shape == expected_shape, "Shape of logits not as expected"

        expected_values = next(
            (values for name, values in reference_slices.items() if name in checkpoint_url),
            None,
        )
        if expected_values is None:
            # previously this case crashed with an unbound `expected_slice`
            raise ValueError(f"No reference logits available to verify checkpoint {checkpoint_url}")
        expected_slice = torch.tensor(expected_values)
        assert torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-3), "First elements of logits not as expected"

    # parents=True so that a nested, not-yet-existing output path works too
    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving processor to {pytorch_dump_folder_path}")
    processor.save_pretrained(pytorch_dump_folder_path)
def convert_donut_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
    """
    Convert an original Donut checkpoint to a VisionEncoderDecoderModel and verify it.

    The original model is loaded with `DonutModel.from_pretrained`, its weights
    are converted and loaded into a `DonutSwinModel` encoder + `MBartForCausalLM`
    decoder, and the patch embeddings, encoder hidden states and logits of both
    models are compared on one scanned document.

    Args:
        model_name: Name of the original Donut checkpoint. Also selects the
            task prompt used for the logits comparison.
        pytorch_dump_folder_path: If not None, directory in which the converted
            model and processor are saved.
        push_to_hub: If True, push the converted model and processor to the
            hub under the "nielsr/" namespace.

    Raises:
        ValueError: If `model_name` is not one of the supported checkpoints.
        AssertionError: If the outputs of the converted model differ from the
            original model beyond the tolerances used below.
    """
    # load original model
    original_model = DonutModel.from_pretrained(model_name).eval()

    # load HuggingFace model
    encoder_config, decoder_config = get_configs(original_model)
    encoder = DonutSwinModel(encoder_config)
    decoder = MBartForCausalLM(decoder_config)
    model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
    model.eval()

    state_dict = original_model.state_dict()
    new_state_dict = convert_state_dict(state_dict, model)
    model.load_state_dict(new_state_dict)

    # verify results on scanned document
    dataset = load_dataset("hf-internal-testing/example-documents")
    image = dataset["test"][0]["image"].convert("RGB")

    tokenizer = XLMRobertaTokenizerFast.from_pretrained(model_name, from_slow=True)
    feature_extractor = DonutFeatureExtractor(
        do_align_long_axis=original_model.config.align_long_axis,
        # the original config's input_size is passed in reversed order
        size=original_model.config.input_size[::-1],
    )
    processor = DonutProcessor(feature_extractor, tokenizer)
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # pick the task prompt matching the checkpoint
    if model_name == "naver-clova-ix/donut-base-finetuned-docvqa":
        task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
        question = "When is the coffee break?"
        task_prompt = task_prompt.replace("{user_input}", question)
    elif model_name == "naver-clova-ix/donut-base-finetuned-rvlcdip":
        task_prompt = "<s_rvlcdip>"
    elif model_name in [
        "naver-clova-ix/donut-base-finetuned-cord-v1",
        "naver-clova-ix/donut-base-finetuned-cord-v1-2560",
    ]:
        task_prompt = "<s_cord>"
    elif model_name == "naver-clova-ix/donut-base-finetuned-cord-v2":
        # fixed: the opening "<" of the task token was missing
        task_prompt = "<s_cord-v2>"
    elif model_name == "naver-clova-ix/donut-base-finetuned-zhtrainticket":
        task_prompt = "<s_zhtrainticket>"
    elif model_name in ["naver-clova-ix/donut-proto", "naver-clova-ix/donut-base"]:
        # use a random prompt
        task_prompt = "hello world"
    else:
        raise ValueError("Model name not supported")
    prompt_tensors = original_model.decoder.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt")[
        "input_ids"
    ]

    # verify patch embeddings
    original_patch_embed = original_model.encoder.model.patch_embed(pixel_values)
    patch_embeddings, _ = model.encoder.embeddings(pixel_values)
    assert torch.allclose(original_patch_embed, patch_embeddings, atol=1e-3)

    # verify encoder hidden states
    original_last_hidden_state = original_model.encoder(pixel_values)
    last_hidden_state = model.encoder(pixel_values).last_hidden_state
    assert torch.allclose(original_last_hidden_state, last_hidden_state, atol=1e-2)

    # verify decoder hidden states
    original_logits = original_model(pixel_values, prompt_tensors, None).logits
    logits = model(pixel_values, decoder_input_ids=prompt_tensors).logits
    assert torch.allclose(original_logits, logits, atol=1e-3)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        print(f"Saving model and processor to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        # repo id is the last path component of the model name under "nielsr/"
        repo_id = "nielsr/" + model_name.split("/")[-1]
        model.push_to_hub(repo_id, commit_message="Update model")
        processor.push_to_hub(repo_id, commit_message="Update model")