def create_and_check_for_image_classification(self, config, pixel_values, labels):
    """Build an image-classification model and check the shape of its logits.

    ``config.num_labels`` is overwritten with ``self.type_sequence_label_size``
    (mutating the passed-in config) before the model is constructed. The model
    is run on ``pixel_values`` and its logits must have shape
    ``(self.batch_size, self.type_sequence_label_size)``.

    ``labels`` is not used by this check (presumably kept so every
    ``create_and_check_*`` helper shares one signature).
    """
    num_classes = self.type_sequence_label_size
    config.num_labels = num_classes

    classifier = FlaxBeitForImageClassification(config=config)
    outputs = classifier(pixel_values)

    wanted_shape = (self.batch_size, num_classes)
    self.parent.assertEqual(outputs.logits.shape, wanted_shape)
def test_inference_image_classification_head_imagenet_22k(self):
    """Check the ImageNet-22k fine-tuned BEiT-large checkpoint against reference logits.

    Runs ``microsoft/beit-large-patch16-224-pt22k-ft22k`` on the image returned
    by ``prepare_img()`` and verifies three things against hard-coded
    reference values:

    * the logits shape ``(1, 21841)``,
    * the first three logits (``atol=1e-4``),
    * the arg-max class index.
    """
    model = FlaxBeitForImageClassification.from_pretrained(
        "microsoft/beit-large-patch16-224-pt22k-ft22k"
    )

    feature_extractor = self.default_feature_extractor
    image = prepare_img()
    inputs = feature_extractor(images=image, return_tensors="np")

    # forward pass
    outputs = model(**inputs)
    logits = outputs.logits

    # verify the logits: a single row of 21841 class logits
    expected_shape = (1, 21841)
    self.assertEqual(logits.shape, expected_shape)

    expected_slice = np.array([1.6881, -0.2787, 0.5901])
    # assert_allclose reports which values mismatched and by how much on
    # failure, whereas assertTrue(np.allclose(...)) only says "False is not
    # true". rtol=1e-5 matches np.allclose's default, so the tolerance is
    # exactly as strict as before.
    np.testing.assert_allclose(logits[0, :3], expected_slice, rtol=1e-5, atol=1e-4)

    expected_class_idx = 2396
    self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
def test_inference_image_classification_head_imagenet_1k(self):
    """Check the ``microsoft/beit-base-patch16-224`` checkpoint against reference logits.

    Runs the checkpoint on the image returned by ``prepare_img()`` and
    verifies three things against hard-coded reference values:

    * the logits shape ``(1, 1000)``,
    * the first three logits (``atol=1e-4``),
    * the arg-max class index.
    """
    model = FlaxBeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224")

    feature_extractor = self.default_feature_extractor
    image = prepare_img()
    inputs = feature_extractor(images=image, return_tensors="np")

    # forward pass
    outputs = model(**inputs)
    logits = outputs.logits

    # verify the logits: a single row of 1000 class logits
    expected_shape = (1, 1000)
    self.assertEqual(logits.shape, expected_shape)

    expected_slice = np.array([-1.2385, -1.0987, -1.0108])
    # assert_allclose reports which values mismatched and by how much on
    # failure, whereas assertTrue(np.allclose(...)) only says "False is not
    # true". rtol=1e-5 matches np.allclose's default, so the tolerance is
    # exactly as strict as before.
    np.testing.assert_allclose(logits[0, :3], expected_slice, rtol=1e-5, atol=1e-4)

    expected_class_idx = 281
    self.assertEqual(logits.argmax(-1).item(), expected_class_idx)