def create_and_check_for_image_classification(self, config, pixel_values,
                                               labels):
    """Build a TF ConvNeXt classifier from *config* and check its logits shape.

    Side effect: overwrites ``config.num_labels`` with
    ``self.type_sequence_label_size`` before the model is constructed.

    Args:
        config: model configuration handed to ``TFConvNextForImageClassification``.
        pixel_values: input batch passed positionally to the model.
        labels: passed through to the model call; only the logits shape is
            asserted here, nothing about the returned loss.
    """
    num_classes = self.type_sequence_label_size
    config.num_labels = num_classes

    classifier = TFConvNextForImageClassification(config)
    # Run in inference mode so training-only layers are not active.
    outputs = classifier(pixel_values, labels=labels, training=False)

    expected_logits_shape = (self.batch_size, num_classes)
    self.parent.assertEqual(outputs.logits.shape, expected_logits_shape)
    def test_inference_image_classification_head(self):
        """Check logits of the ``facebook/convnext-tiny-224`` checkpoint.

        Loads the checkpoint via ``from_pretrained`` and preprocesses the
        image from ``prepare_img()`` with the test case's default feature
        extractor. It then asserts that the logits have shape ``(1, 1000)``
        and that the first three values of row 0 match reference numbers
        within ``atol=1e-4``.
        """
        # NOTE(review): from_pretrained presumably fetches weights remotely;
        # confirm whether this test should carry a slow/network marker.
        checkpoint = "facebook/convnext-tiny-224"
        classifier = TFConvNextForImageClassification.from_pretrained(
            checkpoint)

        # Preprocess the image into TF tensors.
        extractor = self.default_feature_extractor
        encoded = extractor(images=prepare_img(), return_tensors="tf")

        # Forward pass; only the logits are inspected.
        logits = classifier(**encoded).logits

        # Shape check: a batch of one image with 1000 logits.
        self.assertEqual(logits.shape, tf.TensorShape((1, 1000)))

        # Value check on a small slice of the first row.
        reference = tf.constant([-0.0260, -0.4739, 0.1911])
        tf.debugging.assert_near(logits[0, :3], reference, atol=1e-4)