def create_and_check_model(self, config, pixel_values, labels, pixel_labels): model = Data2VecVisionModel(config=config) model.to(torch_device) model.eval() result = model(pixel_values) # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) num_patches = (self.image_size // self.patch_size)**2 self.parent.assertEqual( result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
def test_model_from_pretrained(self): for model_name in DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = Data2VecVisionModel.from_pretrained(model_name) self.assertIsNotNone(model)
def main(): args = get_args() is_finetuned = "ft1k" in args.hf_checkpoint_name is_large = "large" in args.hf_checkpoint_name if is_finetuned: # To convert Beit's data2vec_vision to HF you need to copy # https://github.com/facebookresearch/data2vec_vision/blob/main/beit/modeling_finetune.py # into this folder. import modeling_finetune # noqa: F401 else: # To convert Beit's data2vec_vision to HF you need to copy # https://github.com/facebookresearch/data2vec_vision/blob/main/beit/modeling_cyclical.py # into this folder # IMPORTANT: Note that for now we've only converted the down-stream # model and not the full pretrained model. This means for the integration # test you need to add a `return x` after the following line: # https://github.com/facebookresearch/data2vec_vision/blob/af9a36349aaed59ae66e69b5dabeef2d62fdc5da/beit/modeling_cyclical.py#L197 # to make the integration test pass. import modeling_cyclical # noqa: F401 # 1. Create model config config = Data2VecVisionConfig() if is_finetuned: config.use_relative_position_bias = True config.use_shared_relative_position_bias = False config.use_mean_pooling = True config.num_labels = 1000 repo_id = "datasets/huggingface/label-files" filename = "imagenet-1k-id2label.json" id2label = json.load(open(hf_hub_download(repo_id, filename), "r")) id2label = {int(k): v for k, v in id2label.items()} config.id2label = id2label config.label2id = {v: k for k, v in id2label.items()} else: config.use_relative_position_bias = False config.use_shared_relative_position_bias = True config.use_mean_pooling = False if is_large: config.hidden_size = 1024 config.intermediate_size = 4096 config.num_hidden_layers = 24 config.num_attention_heads = 16 # 2. Load Beit model orig_model = load_beit_model(args, is_finetuned, is_large) orig_model.eval() # 3. Forward Beit model feature_extractor = BeitFeatureExtractor(size=config.image_size, do_center_crop=False) image = Image.open("../../../../tests/fixtures/tests_samples/COCO/000000039769.png") encoding = feature_extractor(images=image, return_tensors="pt") pixel_values = encoding["pixel_values"] orig_args = (pixel_values,) if is_finetuned else (pixel_values, None) with torch.no_grad(): orig_model_output = orig_model(*orig_args) # 4. Load HF Data2VecVision model if is_finetuned: hf_model = Data2VecVisionForImageClassification(config) hf_model.eval() has_lm_head = False hf_prefix = "data2vec_vision." else: hf_model = Data2VecVisionModel(config) hf_model.eval() has_lm_head = True hf_prefix = "" rename_keys = create_rename_keys(config, hf_prefix=hf_prefix, has_lm_head=has_lm_head) state_dict = orig_model.state_dict() for src, dest in rename_keys: val = state_dict.pop(src) state_dict[dest] = val read_in_q_k_v(state_dict, config, hf_prefix=hf_prefix, has_lm_head=has_lm_head) missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False) print("HF missing", missing_keys) print("HF unexpected_keys", unexpected_keys) # 5. Forward HF Data2VecVision model with torch.no_grad(): hf_model_output = hf_model(pixel_values) hf_output = hf_model_output.logits if is_finetuned else hf_model_output.last_hidden_state # 6. Compare max_absolute_diff = torch.max(torch.abs(hf_output - orig_model_output)).item() print(f"max_absolute_diff = {max_absolute_diff}") success = torch.allclose(hf_output, orig_model_output, atol=1e-3) print("Do both models output the same tensors?", "🔥" if success else "💩") if not success: raise Exception("Something went wRoNg") # 7. Save print(f"Saving to {args.hf_checkpoint_name}") hf_model.save_pretrained(args.hf_checkpoint_name) feature_extractor.save_pretrained(args.hf_checkpoint_name)