コード例 #1
0
 def create_and_check_for_image_classification(self, config, pixel_values,
                                               labels):
     """Check the logits shape produced by the image-classification model.

     Sets ``config.num_labels`` to ``self.type_sequence_label_size``, builds
     a ``FlaxBeitForImageClassification`` from ``config``, calls it on
     ``pixel_values`` and asserts, through ``self.parent``, that the logits
     have shape ``(self.batch_size, self.type_sequence_label_size)``.

     ``labels`` is accepted but not used in this check.
     """
     num_classes = self.type_sequence_label_size
     config.num_labels = num_classes

     classifier = FlaxBeitForImageClassification(config=config)
     logits_shape = classifier(pixel_values).logits.shape

     wanted_shape = (self.batch_size, num_classes)
     self.parent.assertEqual(logits_shape, wanted_shape)
コード例 #2
0
    def test_inference_image_classification_head_imagenet_22k(self):
        """Compare classification logits against hard-coded reference values.

        Loads the ``microsoft/beit-large-patch16-224-pt22k-ft22k`` checkpoint,
        runs it on the image returned by ``prepare_img()`` and checks the
        logits shape, the first three logits (``atol=1e-4``) and the argmax
        index.
        """
        checkpoint = "microsoft/beit-large-patch16-224-pt22k-ft22k"
        classifier = FlaxBeitForImageClassification.from_pretrained(checkpoint)

        # Turn the test image into NumPy inputs for the model.
        processor = self.default_feature_extractor
        model_inputs = processor(images=prepare_img(), return_tensors="np")

        # Forward pass; only the logits are inspected.
        logits = classifier(**model_inputs).logits

        # The logits must have exactly this shape.
        self.assertEqual(logits.shape, (1, 21841))

        # The first three logits must match the reference values.
        reference_head = np.array([1.6881, -0.2787, 0.5901])
        head_matches = np.allclose(logits[0, :3], reference_head, atol=1e-4)
        self.assertTrue(head_matches)

        # The argmax over the last axis must be the reference index.
        predicted_idx = logits.argmax(-1).item()
        self.assertEqual(predicted_idx, 2396)
コード例 #3
0
    def test_inference_image_classification_head_imagenet_1k(self):
        """Compare classification logits against hard-coded reference values.

        Loads the ``microsoft/beit-base-patch16-224`` checkpoint, runs it on
        the image returned by ``prepare_img()`` and checks the logits shape,
        the first three logits (``atol=1e-4``) and the argmax index.
        """
        checkpoint = "microsoft/beit-base-patch16-224"
        classifier = FlaxBeitForImageClassification.from_pretrained(checkpoint)

        # Turn the test image into NumPy inputs for the model.
        processor = self.default_feature_extractor
        model_inputs = processor(images=prepare_img(), return_tensors="np")

        # Forward pass; only the logits are inspected.
        logits = classifier(**model_inputs).logits

        # The logits must have exactly this shape.
        self.assertEqual(logits.shape, (1, 1000))

        # The first three logits must match the reference values.
        reference_head = np.array([-1.2385, -1.0987, -1.0108])
        head_matches = np.allclose(logits[0, :3], reference_head, atol=1e-4)
        self.assertTrue(head_matches)

        # The argmax over the last axis must be the reference index.
        predicted_idx = logits.argmax(-1).item()
        self.assertEqual(predicted_idx, 281)