def test_inference_image_classification_head_imagenet_22k(self):
        """Compare the 22k classification head's logits with reference values.

        Loads the ``microsoft/beit-large-patch16-224-pt22k-ft22k`` checkpoint,
        runs it on the image returned by ``prepare_img()`` and checks the
        logits' shape, their first three values (absolute tolerance 1e-4)
        and the index of the largest logit against hard-coded expectations.
        """
        checkpoint = "microsoft/beit-large-patch16-224-pt22k-ft22k"
        classifier = BeitForImageClassification.from_pretrained(checkpoint)
        classifier = classifier.to(torch_device)

        # Preprocess the sample image and move the result to the test device.
        processor = self.default_feature_extractor
        encoded = processor(images=prepare_img(), return_tensors="pt")
        encoded = encoded.to(torch_device)

        # Inference only: no gradient tracking for the forward pass.
        with torch.no_grad():
            scores = classifier(**encoded).logits

        # Shape check: the logits are expected to be (1, 21841).
        self.assertEqual(scores.shape, torch.Size((1, 21841)))

        # Value check on the first three logits of the first row.
        reference = torch.tensor([1.6881, -0.2787, 0.5901]).to(torch_device)
        leading = scores[0, :3]
        self.assertTrue(torch.allclose(leading, reference, atol=1e-4))

        # The largest logit is expected at index 2396.
        predicted_idx = scores.argmax(-1).item()
        self.assertEqual(predicted_idx, 2396)
    def test_inference_image_classification_head_imagenet_1k(self):
        """Compare the 1k classification head's logits with reference values.

        Loads the ``microsoft/beit-base-patch16-224`` checkpoint, runs it on
        the image returned by ``prepare_img()`` and checks the logits' shape,
        their first three values (absolute tolerance 1e-4) and the index of
        the largest logit against hard-coded expectations.
        """
        model = BeitForImageClassification.from_pretrained(
            "microsoft/beit-base-patch16-224").to(torch_device)

        feature_extractor = self.default_feature_extractor
        image = prepare_img()
        inputs = feature_extractor(images=image,
                                   return_tensors="pt").to(torch_device)

        # forward pass
        # Run under no_grad, as the sibling 22k test does: this is an
        # inference-only check, so recording an autograd graph would only
        # waste memory.
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits

        # verify the logits
        expected_shape = torch.Size((1, 1000))
        self.assertEqual(logits.shape, expected_shape)

        expected_slice = torch.tensor([-1.2385, -1.0987,
                                       -1.0108]).to(torch_device)

        # Compare only the first three logits of the first row.
        self.assertTrue(
            torch.allclose(logits[0, :3], expected_slice, atol=1e-4))

        # The largest logit is expected at index 281.
        expected_class_idx = 281
        self.assertEqual(logits.argmax(-1).item(), expected_class_idx)