def test_inference_image_classification_head(self):
        """Run the pretrained classifier on the sample image and check its logits.

        Loads the "google/vit-base-patch16-224" checkpoint, feeds it the image
        returned by ``prepare_img()`` (preprocessed by the test's default
        feature extractor) and verifies the logits shape as well as the first
        three logit values of the single batch entry.
        """
        classifier = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

        # turn the sample image into TensorFlow tensors
        extractor = self.default_feature_extractor
        encoded = extractor(images=prepare_img(), return_tensors="tf")

        # forward pass
        logits = classifier(**encoded).logits

        # a single row of 1000 scores is expected
        self.assertEqual(logits.shape, tf.TensorShape((1, 1000)))

        # the leading three scores must match the reference values within 1e-4
        reference = tf.constant([-0.2744, 0.8215, -0.0836])
        tf.debugging.assert_near(logits[0, :3], reference, atol=1e-4)
Esempio n. 2
0
# Training pipeline: shuffle -> load/process -> batch.
# The shuffle buffer spans the whole dataset, so shuffle *before* mapping:
# the buffer then holds only the un-processed source elements (presumably
# path/label pairs, as in the validation set below -- confirm against the
# dataset's construction) instead of every processed image at once.
dataset_train = dataset_train.shuffle(len(dataset_train))
dataset_train = dataset_train.map(load_and_process_image,
                                  num_parallel_calls=tf.data.AUTOTUNE)
# drop_remainder=True discards a final partial batch, so every batch has
# exactly BATCH_SIZE elements.
dataset_train = dataset_train.batch(BATCH_SIZE, drop_remainder=True)

# Validation pipeline: same loading and batching, but no shuffling.
dataset_validation = tf.data.Dataset.from_tensor_slices(
    (image_paths['validation'], image_labels['validation']))
dataset_validation = dataset_validation.map(
    load_and_process_image, num_parallel_calls=tf.data.AUTOTUNE)
dataset_validation = dataset_validation.batch(BATCH_SIZE, drop_remainder=True)

# ## Model
#
# ### Initialization

# Load the pretrained ViT checkpoint with a single-output classification head.
# ignore_mismatched_sizes=True lets loading proceed when the checkpoint's head
# does not have the shape of the requested one-label head.
model = TFViTForImageClassification.from_pretrained(
    VITMODEL, num_labels=1, ignore_mismatched_sizes=True)

LR = 1e-5

optimizer = tf.keras.optimizers.Adam(learning_rate=LR)
# The classification head outputs raw logits (no sigmoid is applied by the
# model), so the loss has to apply the sigmoid itself via from_logits=True.
loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# NOTE(review): 'accuracy' is evaluated on the same raw logits -- confirm the
# decision threshold Keras ends up using; an explicit
# tf.keras.metrics.BinaryAccuracy(threshold=0.0) may be the better fit.
metric = 'accuracy'

model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

# summary() prints the table itself and returns None, so it is not wrapped in
# print() (which would emit an extra "None" line).
model.summary()

# ### Learning

logdir = os.path.join(
    os.getcwd(), "logs",