def test_inference_image_classification_head_imagenet_22k(self):
    """Check the ImageNet-22k fine-tuned BEiT-large classifier on the fixture image.

    Loads ``microsoft/beit-large-patch16-224-pt22k-ft22k``, runs a single
    forward pass without gradient tracking, and compares the resulting logits
    against reference values: the output shape, the first three logits
    (within ``atol=1e-4``), and the index of the highest-scoring class.
    """
    checkpoint = "microsoft/beit-large-patch16-224-pt22k-ft22k"
    classifier = BeitForImageClassification.from_pretrained(checkpoint).to(torch_device)

    # Preprocess the shared test image into tensors on the device under test.
    processor = self.default_feature_extractor
    pixel_batch = processor(images=prepare_img(), return_tensors="pt").to(torch_device)

    # Inference only: no autograd graph is needed for these checks.
    with torch.no_grad():
        scores = classifier(**pixel_batch).logits

    # Batch of one, with 21841 output classes for this checkpoint.
    self.assertEqual(scores.shape, torch.Size((1, 21841)))

    reference_head = torch.tensor([1.6881, -0.2787, 0.5901]).to(torch_device)
    self.assertTrue(torch.allclose(scores[0, :3], reference_head, atol=1e-4))

    predicted_idx = scores.argmax(-1).item()
    self.assertEqual(predicted_idx, 2396)
def test_inference_image_classification_head_imagenet_1k(self):
    """Check the ImageNet-1k BEiT-base classifier on the fixture image.

    Loads ``microsoft/beit-base-patch16-224``, runs a single forward pass and
    compares the logits against reference values: the output shape, the first
    three logits (within ``atol=1e-4``), and the argmax class index.
    """
    model = BeitForImageClassification.from_pretrained(
        "microsoft/beit-base-patch16-224").to(torch_device)

    feature_extractor = self.default_feature_extractor
    image = prepare_img()
    inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)

    # forward pass
    # Wrapped in no_grad (as in the 22k test): this is inference only, so
    # building and retaining an autograd graph would just waste memory/time.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # verify the logits
    # Batch of one, with 1000 output classes for this checkpoint.
    expected_shape = torch.Size((1, 1000))
    self.assertEqual(logits.shape, expected_shape)

    expected_slice = torch.tensor([-1.2385, -1.0987, -1.0108]).to(torch_device)
    self.assertTrue(
        torch.allclose(logits[0, :3], expected_slice, atol=1e-4))

    expected_class_idx = 281
    self.assertEqual(logits.argmax(-1).item(), expected_class_idx)