Example #1
    def build(self, hp, inputs=None):
        input_node = nest.flatten(inputs)[0]
        output_node = input_node

        # Translate
        translation_factor = utils.add_to_hp(self.translation_factor, hp)
        if translation_factor not in [0, (0, 0)]:
            height_factor, width_factor = self._get_fraction_value(
                translation_factor
            )
            output_node = layers.RandomTranslation(height_factor, width_factor)(
                output_node
            )

        # Flip
        horizontal_flip = self.horizontal_flip
        if horizontal_flip is None:
            horizontal_flip = hp.Boolean("horizontal_flip", default=True)
        vertical_flip = self.vertical_flip
        if vertical_flip is None:
            vertical_flip = hp.Boolean("vertical_flip", default=True)
        if not horizontal_flip and not vertical_flip:
            flip_mode = ""
        elif horizontal_flip and vertical_flip:
            flip_mode = "horizontal_and_vertical"
        elif horizontal_flip and not vertical_flip:
            flip_mode = "horizontal"
        elif not horizontal_flip and vertical_flip:
            flip_mode = "vertical"
        if flip_mode != "":
            output_node = layers.RandomFlip(mode=flip_mode)(output_node)

        # Rotate
        rotation_factor = utils.add_to_hp(self.rotation_factor, hp)
        if rotation_factor != 0:
            output_node = layers.RandomRotation(rotation_factor)(output_node)

        # Zoom
        zoom_factor = utils.add_to_hp(self.zoom_factor, hp)
        if zoom_factor not in [0, (0, 0)]:
            height_factor, width_factor = self._get_fraction_value(zoom_factor)
            # TODO: Add back RandomZoom when it is ready.
            # output_node = layers.RandomZoom(
            # height_factor, width_factor)(output_node)

        # Contrast
        contrast_factor = utils.add_to_hp(self.contrast_factor, hp)
        if contrast_factor not in [0, (0, 0)]:
            output_node = layers.RandomContrast(contrast_factor)(output_node)

        return output_node
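# For reference, here is a minimal standalone sketch (not part of the original
# snippet) of the augmentation pipeline that build() produces for one fixed
# hyperparameter configuration. The factor values are illustrative assumptions;
# only standard Keras preprocessing layers are used, and zoom is omitted to
# match the TODO above.
import tensorflow as tf
from tensorflow.keras import layers

def fixed_augmentation(inputs):
    # Equivalent to build() with translation_factor=0.1, horizontal_flip=True,
    # vertical_flip=False, rotation_factor=0.1, contrast_factor=0.1.
    x = layers.RandomTranslation(height_factor=0.1, width_factor=0.1)(inputs)
    x = layers.RandomFlip(mode="horizontal")(x)
    x = layers.RandomRotation(0.1)(x)
    x = layers.RandomContrast(0.1)(x)
    return x

inputs = tf.keras.Input(shape=(32, 32, 3))
augment = tf.keras.Model(inputs, fixed_augmentation(inputs))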
Example #2
def build_model_augment() -> keras.Sequential:
    # build a sequential model
    model = keras.Sequential()
    # add a data augmentation layer
    data_augmentation = keras.Sequential([
        # random horizontal flip
        layers.RandomFlip("horizontal"),
        # random vertical flip
        layers.RandomFlip("vertical"),
        # random rotation by at most 20% of a full turn (about 72 degrees)
        layers.RandomRotation(0.2),
        # random zoom by at most 20%
        layers.RandomZoom(0.2),
        # random contrast adjustment by at most 10%
        layers.RandomContrast(0.1)
    ])
    model.add(data_augmentation)
    # add the first convolutional layer, with ReLU as activation function and same padding
    model.add(
        keras.layers.Conv2D(32, (3, 3),
                            activation="relu",
                            padding="same",
                            input_shape=(32, 32, 3)))  # [32,32, 32]
    # add a maxpooling layer to reduce the dimension
    model.add(keras.layers.MaxPooling2D(2, 2))  # [16, 16, 32]
    # add the second convolutional layer, with ReLU as activation function and same padding
    model.add(
        keras.layers.Conv2D(64, (3, 3), activation="relu",
                            padding="same"))  # [16, 16, 64]
    # add a maxpooling layer to reduce the dimension
    model.add(keras.layers.MaxPooling2D(2, 2))  # [8, 8, 64]
    # add the third convolutional layer, with ReLU as activation function and same padding
    model.add(
        keras.layers.Conv2D(128, (3, 3), activation="relu",
                            padding="same"))  # [8, 8, 128]
    # add a maxpooling layer to reduce the dimension
    model.add(keras.layers.MaxPooling2D(2, 2))  # [4, 4, 128]
    # add a flatten layer to make the data into a 1-dimensional array
    model.add(keras.layers.Flatten())  # [1024, ]
    # add a fully connected layer, with ReLU as activation function
    model.add(keras.layers.Dense(128, activation="relu"))
    # add a fully connected layer, the output size is the same as the number of classes and
    # softmax as activation function to do multiclass classification
    model.add(keras.layers.Dense(10, activation="softmax"))
    # compile the model, use SGD as the optimizer and categorical crossentropy as loss function
    model.compile(optimizer=optimizers.SGD(learning_rate=1e-3, momentum=0.9),
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])
    # return the model
    return model
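# The 32x32x3 input shape and the 10-way softmax suggest a CIFAR-10-style
# dataset. As a hedged usage sketch (not part of the original snippet), the
# model could be trained as below; the [0, 1] rescaling and the epoch/batch
# settings are assumptions, not taken from the source.
from tensorflow import keras

model = build_model_augment()

# Load CIFAR-10 and one-hot encode the labels to match categorical_crossentropy.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

model.fit(x_train, y_train, epochs=10, batch_size=64,
          validation_data=(x_test, y_test))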
Example #3
    plt.axis("off")
"""
### Data augmentation

We can use the preprocessing layers API for image augmentation.
"""

from tensorflow.keras.models import Sequential
from tensorflow.keras import layers

img_augmentation = Sequential(
    [
        layers.RandomRotation(factor=0.15),
        layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
        layers.RandomFlip(),
        layers.RandomContrast(factor=0.1),
    ],
    name="img_augmentation",
)
"""
This `Sequential` model object can be used both as a part of
the model we build later and as a function to preprocess
data before feeding it into the model. Using it as a function makes
it easy to visualize the augmented images. Here we plot 9 augmented
versions of a given input image.
"""

for image, label in ds_train.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        aug_img = img_augmentation(tf.expand_dims(image, axis=0))
        # display the augmented image (cast back to uint8 for plotting)
        plt.imshow(aug_img[0].numpy().astype("uint8"))
        plt.axis("off")
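# As a complementary sketch (not from the original snippet), the same
# img_augmentation object can also be placed inside a model, so the random
# transformations run only during training. The input size and the small
# head below are illustrative assumptions, not taken from the source.
import tensorflow as tf
from tensorflow import keras

inputs = keras.Input(shape=(224, 224, 3))             # illustrative input size
x = img_augmentation(inputs)                          # augmentation as the first model stage
x = keras.layers.Conv2D(32, 3, activation="relu")(x)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(10, activation="softmax")(x)  # illustrative classification head
model = keras.Model(inputs, outputs)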
Example #4
strip_chars = strip_chars.replace("<", "")
strip_chars = strip_chars.replace(">", "")

vectorization = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=SEQ_LENGTH,
    standardize=custom_standardization,
)
vectorization.adapt(text_data)

# Data augmentation for image data
image_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.2),
    layers.RandomContrast(0.3),
])
"""
## Building a `tf.data.Dataset` pipeline for training

We will generate pairs of images and corresponding captions using a `tf.data.Dataset` object.
The pipeline consists of two steps:

1. Read the image from disk.
2. Tokenize all five captions corresponding to the image.
"""


def decode_and_resize(img_path, size=IMAGE_SIZE):
    img = tf.io.read_file(img_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, size)
    img = tf.image.convert_image_dtype(img, tf.float32)
    return img
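# A minimal sketch (not part of the original snippet) of how the two steps
# combine into a tf.data pipeline, mirroring the make_dataset used in Example
# #6 below; BATCH_SIZE is assumed to be defined alongside the constants above.
AUTOTUNE = tf.data.AUTOTUNE

def process_input(img_path, captions):
    # Step 1: read and resize the image; step 2: tokenize its captions.
    return decode_and_resize(img_path), vectorization(captions)

def make_dataset(images, captions):
    dataset = tf.data.Dataset.from_tensor_slices((images, captions))
    dataset = dataset.shuffle(len(images))
    dataset = dataset.map(process_input, num_parallel_calls=AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)
    return dataset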
Example #5

# We apply random transformations (a small random crop and a contrast
# adjustment) to the training images each time we loop over them. This
# way, we "augment" our training dataset to contain more data.
#
# The augmentation transformations are implemented as preprocessing
# layers in Keras. There are various such layers readily available,
# see https://keras.io/guides/preprocessing_layers/ for more
# information.
#
# ### Initialization

inputs = keras.Input(shape=INPUT_IMAGE_SIZE + [3])
x = layers.Rescaling(scale=1. / 255)(inputs)

x = layers.RandomCrop(75, 75)(x)
x = layers.RandomContrast(0.1)(x)

x = layers.Conv2D(32, (3, 3), activation='relu')(x)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)

x = layers.Conv2D(32, (3, 3), activation='relu')(x)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)

x = layers.Conv2D(64, (3, 3), activation='relu')(x)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)

x = layers.Flatten()(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(43, activation='softmax')(x)
Example #6
def main():
    # Download the dataset
    # Using the Flickr8K dataset for this tutorial. This dataset
    # comprises over 8,000 images, each paired with five different
    # captions.
    #!wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
    #!wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
    #!unzip -qq Flickr8k_Dataset.zip
    #!unzip -qq Flickr8k_text.zip
    #!rm Flickr8k_Dataset.zip Flickr8k_text.zip

    # Path to images.
    IMAGES_PATH = "Flicker8k_Dataset"

    # Desired image dimensions.
    IMAGE_SIZE = (299, 299)

    # Vocabulary size.
    VOCAB_SIZE = 10000

    # Fixed length allowed for any sequence.
    SEQ_LENGTH = 25

    # Dimension for the image embeddings and token embeddings.
    EMBED_DIM = 512

    # Per-layer units in the feed-forward network.
    FF_DIM = 512

    # Other training parameters.
    BATCH_SIZE = 64
    EPOCHS = 30
    AUTOTUNE = tf.data.AUTOTUNE

    # Preparing the dataset.
    def load_captions_data(filename):
        # Load captions (text) data and maps them to corresponding
        # images.
        # @param: filename, path to the text file containing caption
        #	data.
        # @return: caption_mapping, dictionary mapping image names and
        #	the corresponding captions.
        # @return: text_data, list containing all the available
        #	captions.
        with open(filename) as caption_file:
            caption_data = caption_file.readlines()
            caption_mapping = {}
            text_data = []
            images_to_skip = set()

            for line in caption_data:
                line = line.rstrip("\n")

                # Image name and captions are separated by a tab.
                img_name, caption = line.split("\t")

                # Each image is repeated five times for the five
                # different captions. Each image name has a suffix
                # '#(caption_number)'.
                img_name = img_name.split("#")[0]
                img_name = os.path.join(IMAGES_PATH, img_name.strip())

                # Remove captions that are either too short or too
                # long.
                tokens = caption.strip().split()

                if len(tokens) < 5 or len(tokens) > SEQ_LENGTH:
                    images_to_skip.add(img_name)
                    continue

                if img_name.endswith("jpg") and img_name not in images_to_skip:
                    # Add a start and end token to each caption.
                    caption = "<start>" + caption.strip() + "<end>"
                    text_data.append(caption)

                    if img_name in caption_mapping:
                        caption_mapping[img_name].append(caption)
                    else:
                        caption_mapping[img_name] = [caption]

            for img_name in images_to_skip:
                if img_name in caption_mapping:
                    del caption_mapping[img_name]

            return caption_mapping, text_data

    def train_val_split(caption_data, train_size=0.8, shuffle=True):
        # Split the captioning dataset into train and validation sets.
        # @param: caption_data (dict), dictionary containing the mapped
        #	data.
        # @param: train_size (float), fraction of all the full dataset
        #	to use as training data.
        # @param: shuffle (bool), whether to shuffle the dataset before
        #	splitting.
        # @return: training and validation datasets as two separated
        #	dicts.
        # 1) Get the list of all image names.
        all_images = list(caption_data.keys())

        # 2) Shuffle if necessary.
        if shuffle:
            np.random.shuffle(all_images)

        # 3) Split into training and validation sets.
        train_size = int(len(caption_data) * train_size)

        training_data = {
            img_name: caption_data[img_name]
            for img_name in all_images[:train_size]
        }
        validation_data = {
            img_name: caption_data[img_name]
            for img_name in all_images[train_size:]
        }

        # 4) Return the splits.
        return training_data, validation_data

    # Load the dataset.
    captions_mapping, text_data = load_captions_data("Flickr8k.token.txt")

    # Split the dataset into training and validation sets.
    train_data, valid_data = train_val_split(captions_mapping)
    print("Number of training samples: ", len(train_data))
    print("Number of validation samples: ", len(valid_data))

    # Vectorizing the text data
    # Use the TextVectorization layer to vectorize the text data, that
    # is to say, to turn the original strings into integer sequences
    # where each integer represents the index of a word in a
    # vocabulary. Use a custom string standardization scheme (in this
    # case, strip punctuation characters except < and >) and the
    # default splitting scheme (split on whitespace).
    def custom_standardization(input_string):
        lowercase = tf.strings.lower(input_string)
        return tf.strings.regex_replace(lowercase,
                                        "[%s]" % re.escape(strip_chars), "")

    strip_chars = "!\"#$%&'()*+,-./;<=>?@[\]^_`{|}~"
    strip_chars = strip_chars.replace("<", "")
    strip_chars = strip_chars.replace(">", "")

    vectorization = TextVectorization(
        max_tokens=VOCAB_SIZE,
        output_mode="int",
        output_sequence_length=SEQ_LENGTH,
        standardize=custom_standardization,
    )
    vectorization.adapt(text_data)

    # Data augmentation for image data.
    image_augmentation = keras.Sequential([
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.2),
        layers.RandomContrast(0.3),
    ])

    # Building a tf.data.Dataset pipeline for training
    # Generate pairs of images and corresponding captions using a
    # tf.data.Dataset object. The pipeline consists of two steps:
    # 1) Read the image from the disk.
    # 2) Tokenize all the five captions corresponding to the image.
    def decode_and_resize(img_path):
        img = tf.io.read_file(img_path)
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, IMAGE_SIZE)
        img = tf.image.convert_image_dtype(img, tf.float32)
        return img

    def process_input(img_path, captions):
        return decode_and_resize(img_path), vectorization(captions)

    def make_dataset(images, captions):
        '''
		if split == "train":
			img_dataset = tf.data.Dataset.from_tensor_slices(images).map(
				read_train_image, num_parallel_calls=AUTOTUNE
			)
		else:
			img_dataset = tf.data.Dataset.from_tensor_slices(images).map(
				read_valid_image, num_parallel_calls=AUTOTUNE
			)

		cap_dataset = tf.data.Dataset.from_tensor_slices(captions).map(
			vectorization, num_parallel_calls=AUTOTUNE
		)

		dataset = tf.data.Dataset.zip((img_dataset, cap_dataset))
		dataset = dataset.batch(BATCH_SIZE).shuffle(256).prefetch(AUTOTUNE)
		return dataset
		'''
        dataset = tf.data.Dataset.from_tensor_slices((images, captions))
        dataset = dataset.shuffle(len(images))
        dataset = dataset.map(process_input, num_parallel_calls=AUTOTUNE)
        dataset = dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)
        return dataset

    # Pass the list of images and the list of corresponding captions.
    train_dataset = make_dataset(list(train_data.keys()),
                                 list(train_data.values()))
    valid_dataset = make_dataset(list(valid_data.keys()),
                                 list(valid_data.values()))

    # Building the model
    # The image captioning architecture consists of three models:
    # 1) A CNN: Used to extract the image features.
    # 2) A TransformerEncoder: The extracted image features are then
    #	passed to a Transformer based encoder that generates a new
    #	representation of the inputs.
    # 3) A TransformerDecoder: This model takes the encoder output and
    #	the text data (sequences) as inputs and tries to learn to
    #	generate the caption.
    def get_cnn_model():
        base_model = efficientnet.EfficientNetB0(
            input_shape=(*IMAGE_SIZE, 3),
            include_top=False,
            weights="imagenet",
        )

        # Freeze the feature extractor.
        base_model.trainable = False
        base_model_out = base_model.output
        base_model_out = layers.Reshape(
            (-1, base_model_out.shape[-1]))(base_model_out)
        cnn_model = keras.models.Model(base_model.input, base_model_out)
        return cnn_model

    class TransformerEncoderBlock(layers.Layer):
        def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
            super().__init__(**kwargs)
            self.embed_dim = embed_dim
            self.dense_dim = dense_dim
            self.num_heads = num_heads
            self.attention_1 = layers.MultiHeadAttention(num_heads=num_heads,
                                                         key_dim=embed_dim,
                                                         dropout=0.0)
            self.layernorm1 = layers.LayerNormalization()
            self.layernorm2 = layers.LayerNormalization()
            self.dense_1 = layers.Dense(embed_dim, activation="relu")

        def call(self, inputs, training, mask=None):
            inputs = self.layernorm1(inputs)
            inputs = self.dense_1(inputs)

            attention_output_1 = self.attention_1(
                query=inputs,
                value=inputs,
                key=inputs,
                attention_mask=None,
                training=training,
            )

            out_1 = self.layernorm2(inputs + attention_output_1)
            return out_1

    class PositionalEmbedding(layers.Layer):
        def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
            super().__init__(**kwargs)
            self.token_embeddings = layers.Embedding(input_dim=vocab_size,
                                                     output_dim=embed_dim)
            self.position_embeddings = layers.Embedding(
                input_dim=sequence_length, output_dim=embed_dim)
            self.sequence_length = sequence_length
            self.vocab_size = vocab_size
            self.embed_dim = embed_dim
            self.embed_scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))

        def call(self, inputs):
            length = tf.shape(inputs)[-1]
            positions = tf.range(start=0, limit=length, delta=1)
            embedded_tokens = self.token_embeddings(inputs)
            embedded_tokens = embedded_tokens * self.embed_scale
            embedded_positions = self.position_embeddings(positions)
            return embedded_tokens + embedded_positions

        def compute_mask(self, inputs, mask=None):
            return tf.math.not_equal(inputs, 0)

    class TransformerDecoderBlock(layers.Layer):
        def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
            super().__init__(**kwargs)
            self.embed_dim = embed_dim
            self.ff_dim = ff_dim
            self.num_heads = num_heads
            self.attention_1 = layers.MultiHeadAttention(num_heads=num_heads,
                                                         key_dim=embed_dim,
                                                         dropout=0.1)
            self.attention_2 = layers.MultiHeadAttention(num_heads=num_heads,
                                                         key_dim=embed_dim,
                                                         dropout=0.1)
            self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
            self.ffn_layer_2 = layers.Dense(embed_dim)

            self.layernorm_1 = layers.LayerNormalization()
            self.layernorm_2 = layers.LayerNormalization()
            self.layernorm_3 = layers.LayerNormalization()

            self.embedding = PositionalEmbedding(embed_dim=EMBED_DIM,
                                                 sequence_length=SEQ_LENGTH,
                                                 vocab_size=VOCAB_SIZE)
            self.out = layers.Dense(VOCAB_SIZE, activation="softmax")

            self.dropout_1 = layers.Dropout(0.3)
            self.dropout_2 = layers.Dropout(0.5)
            self.supports_masking = True

        def call(self, inputs, encoder_outputs, training, mask=None):
            inputs = self.embedding(inputs)
            causal_mask = self.get_causal_attention_mask(inputs)

            if mask is not None:
                padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
                combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
                combined_mask = tf.minimum(combined_mask, causal_mask)
            else:
                # Fall back to the causal mask alone when no padding mask is passed.
                padding_mask = None
                combined_mask = causal_mask

            attention_output_1 = self.attention_1(
                query=inputs,
                value=inputs,
                key=inputs,
                attention_mask=combined_mask,
                training=training,
            )
            out_1 = self.layernorm_1(inputs + attention_output_1)

            attention_output_2 = self.attention_2(
                query=out_1,
                value=encoder_outputs,
                key=encoder_outputs,
                attention_mask=padding_mask,
                training=training,
            )
            out_2 = self.layernorm_2(out_1 + attention_output_2)

            ffn_out = self.ffn_layer_1(out_2)
            ffn_out = self.dropout_1(ffn_out, training=training)
            ffn_out = self.ffn_layer_2(ffn_out)

            ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
            ffn_out = self.dropout_2(ffn_out, training=training)
            preds = self.out(ffn_out)
            return preds

        def get_causal_attention_mask(self, inputs):
            input_shape = tf.shape(inputs)
            batch_size, sequence_length = input_shape[0], input_shape[1]
            i = tf.range(sequence_length)[:, tf.newaxis]
            j = tf.range(sequence_length)
            mask = tf.cast(i >= j, dtype="int32")
            mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
            mult = tf.concat(
                [
                    tf.expand_dims(batch_size, -1),
                    tf.constant([1, 1], dtype=tf.int32)
                ],
                axis=0,
            )
            return tf.tile(mask, mult)
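        # (Illustrative note, not in the original: for sequence_length = 3 the
        # mask built above is the lower-triangular matrix
        #   [[1, 0, 0],
        #    [1, 1, 0],
        #    [1, 1, 1]]
        # tiled once per batch element, so each position can only attend to
        # itself and to earlier positions.)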

    class ImageCaptioningModel(keras.Model):
        def __init__(self,
                     cnn_model,
                     encoder,
                     decoder,
                     num_captions_per_image=5,
                     image_aug=None):
            super().__init__()
            self.cnn_model = cnn_model
            self.encoder = encoder
            self.decoder = decoder
            self.loss_tracker = keras.metrics.Mean(name="loss")
            self.acc_tracker = keras.metrics.Mean(name="accuracy")
            self.num_captions_per_image = num_captions_per_image
            self.image_aug = image_aug

        def calculate_loss(self, y_true, y_pred, mask):
            loss = self.loss(y_true, y_pred)
            mask = tf.cast(mask, dtype=loss.dtype)
            loss *= mask
            return tf.reduce_sum(loss) / tf.reduce_sum(mask)

        def calculate_accuracy(self, y_true, y_pred, mask):
            accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
            accuracy = tf.math.logical_and(mask, accuracy)
            accuracy = tf.cast(accuracy, dtype=tf.float32)
            mask = tf.cast(mask, dtype=tf.float32)
            return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

        def _compute_caption_loss_and_acc(self,
                                          img_embed,
                                          batch_seq,
                                          training=True):
            encoder_out = self.encoder(img_embed, training=training)
            batch_seq_inp = batch_seq[:, :-1]
            batch_seq_true = batch_seq[:, 1:]
            mask = tf.math.not_equal(batch_seq_true, 0)
            batch_seq_pred = self.decoder(batch_seq_inp,
                                          encoder_out,
                                          training=training,
                                          mask=mask)
            loss = self.calculate_loss(batch_seq_true, batch_seq_pred, mask)
            acc = self.calculate_accuracy(batch_seq_true, batch_seq_pred, mask)
            return loss, acc

        def train_step(self, batch_data):
            batch_img, batch_seq = batch_data
            batch_loss = 0
            batch_acc = 0

            if self.image_aug:
                batch_img = self.image_aug(batch_img)

            # 1) Get image embeddings.
            img_embed = self.cnn_model(batch_img)

            # 2) Pass each of the five captions one by one to the
            # decoder along with the encoder outputs and compute the
            # loss as well as accuracy for each caption.
            for i in range(self.num_captions_per_image):
                with tf.GradientTape() as tape:
                    loss, acc = self._compute_caption_loss_and_acc(
                        img_embed, batch_seq[:, i, :], training=True)

                    # 3) Update loss and accuracy.
                    batch_loss += loss
                    batch_acc += acc

                # 4) Get the list of all the trainable weights.
                train_vars = (self.encoder.trainable_variables +
                              self.decoder.trainable_variables)

                # 5) Get the gradients.
                grads = tape.gradient(loss, train_vars)

                # 6) Update the trainable weights.
                self.optimizer.apply_gradients(zip(grads, train_vars))

            # 7) Update the trackers.
            batch_acc /= float(self.num_captions_per_image)
            self.loss_tracker.update_state(batch_loss)
            self.acc_tracker.update_state(batch_acc)

            # 8) Return the loss and accuracy values.
            return {
                "loss": self.loss_tracker.result(),
                "acc": self.acc_tracker.result()
            }

        def test_step(self, batch_data):
            batch_img, batch_seq = batch_data
            batch_loss = 0
            batch_acc = 0

            # 1) Get image embeddings.
            img_embed = self.cnn_model(batch_img)

            # 2) Pass each of the five captions one by one to the
            # decoder along with the encoder outputs and compute the
            # loss as well as accuracy for each caption.
            for i in range(self.num_captions_per_image):
                loss, acc = self._compute_caption_loss_and_acc(img_embed,
                                                               batch_seq[:,
                                                                         i, :],
                                                               training=False)

                # 3) Update loss and accuracy.
                batch_loss += loss
                batch_acc += acc

            batch_acc /= float(self.num_captions_per_image)

            # 4) Update the trackers.
            self.loss_tracker.update_state(batch_loss)
            self.acc_tracker.update_state(batch_acc)

            # 5) Return the loss and accuracy values.
            return {
                "loss": self.loss_tracker.result(),
                "acc": self.acc_tracker.result()
            }

        @property
        def metrics(self):
            # List the metrics here so the 'reset_states()' can be
            # called automatically.
            return [self.loss_tracker, self.acc_tracker]

    cnn_model = get_cnn_model()
    encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM,
                                      dense_dim=FF_DIM,
                                      num_heads=1)
    decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM,
                                      ff_dim=FF_DIM,
                                      num_heads=2)
    caption_model = ImageCaptioningModel(
        cnn_model=cnn_model,
        encoder=encoder,
        decoder=decoder,
        image_aug=image_augmentation,
    )

    # Model training
    # Define the loss function.
    cross_entropy = keras.losses.SparseCategoricalCrossentropy(
        from_logits=False, reduction="none")

    # Early stopping criteria.
    early_stopping = keras.callbacks.EarlyStopping(patience=3,
                                                   restore_best_weights=True)

    # Learning Rate Scheduler for the optimizer.
    class LRSchedule(keras.optimizers.schedules.LearningRateSchedule):
        def __init__(self, post_warmup_learning_rate, warmup_steps):
            super().__init__()
            self.post_warmup_learning_rate = post_warmup_learning_rate
            self.warmup_steps = warmup_steps

        def __call__(self, step):
            global_step = tf.cast(step, tf.float32)
            warmup_steps = tf.cast(self.warmup_steps, tf.float32)
            warmup_progress = global_step / warmup_steps
            warmup_learning_rate = self.post_warmup_learning_rate * warmup_progress
            return tf.cond(
                global_step < warmup_steps,
                lambda: warmup_learning_rate,
                lambda: self.post_warmup_learning_rate,
            )
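    # (Illustrative note, not in the original: with post_warmup_learning_rate=1e-4
    # and warmup_steps=100, this schedule returns 0.0 at step 0, 5e-5 at step 50,
    # and a constant 1e-4 from step 100 onward.)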

    # Create a learning rate schedule.
    num_train_steps = len(train_dataset) * EPOCHS
    num_warmup_steps = num_train_steps // 15
    lr_schedule = LRSchedule(post_warmup_learning_rate=1e-4,
                             warmup_steps=num_warmup_steps)

    # Compile the model.
    caption_model.compile(optimizer=keras.optimizers.Adam(lr_schedule),
                          loss=cross_entropy)

    # Fit the model.
    caption_model.fit(
        train_dataset,
        epochs=EPOCHS,
        validation_data=valid_dataset,
        callbacks=[early_stopping],
    )

    # Check sample predictions
    '''
	vocab = vectorization.get_vocabulary()
	index_lookup = dict(zip(range(len(vocab)), vocab))
	max_decoded_sentence_length = SEQ_LENGTH - 1
	valid_images = list(valid_data.keys())

	def generate_caption():
		# Select a random image from the validation dataset.
		sample_img = np.random.choice(valid_images)

		# Read the image from the disk.
		sample_img = decode_and_resize(sample_img)
		img = sample_img.numpy().clip(0, 255).astype(np.uint8)
		plt.imshow(img)
		plt.show()

		# Pass the image to the CNN.
		img = tf.expand_dims(sample_img, 0)
		img = caption_model.cnn_model(img)

		# Pass the image features to the Transformer encoder.
		encoded_img = caption_model.encoder(img, training=False)

		# Generate the caption using the Transformer decoder.
		decoded_caption = "<start>"
		for i in range(max_decoded_sentence_length):
			tokenized_caption = vectorization([decoded_caption])[:, :-1]
			mask = tf.math.not_equal(tokenized_caption, 0)
			predictions = caption_model.decoder(
				tokenized_caption, encoded_img, training=False, mask=mask
			)
			sampled_token_index = np.argmax(predictions[0, i, :])
			sampled_token = index_lookup[sampled_token_index]
			if sampled_token == " <end>":
				break
			decoded_caption += " " + sampled_token

		decoded_caption = decoded_caption.replace("<start>", "")
		decoded_caption = decoded_caption.replace("<end>", "").strip()
		print("Predicted Caption: ", decoded_caption)

	# Check predictions for a few samples.
	generate_caption()
	generate_caption()
	generate_caption()
	'''

    # End Notes
    # Notice that the model starts to generate reasonable captions
    # after a few epochs. To keep this example easily runnable, it has
    # been trained with a few constraints, like a minimal number of
    # attention heads. To improve predictions, try changing these
    # training settings and find a good model for your use case.

    # Exit the program.
    exit(0)