def build(self, hp, inputs=None):
    input_node = nest.flatten(inputs)[0]
    output_node = input_node

    # Translate
    translation_factor = utils.add_to_hp(self.translation_factor, hp)
    if translation_factor not in [0, (0, 0)]:
        height_factor, width_factor = self._get_fraction_value(
            translation_factor
        )
        output_node = layers.RandomTranslation(height_factor, width_factor)(
            output_node
        )

    # Flip
    horizontal_flip = self.horizontal_flip
    if horizontal_flip is None:
        horizontal_flip = hp.Boolean("horizontal_flip", default=True)
    vertical_flip = self.vertical_flip
    if vertical_flip is None:
        vertical_flip = hp.Boolean("vertical_flip", default=True)
    if not horizontal_flip and not vertical_flip:
        flip_mode = ""
    elif horizontal_flip and vertical_flip:
        flip_mode = "horizontal_and_vertical"
    elif horizontal_flip and not vertical_flip:
        flip_mode = "horizontal"
    else:
        flip_mode = "vertical"
    if flip_mode != "":
        output_node = layers.RandomFlip(mode=flip_mode)(output_node)

    # Rotate
    rotation_factor = utils.add_to_hp(self.rotation_factor, hp)
    if rotation_factor != 0:
        output_node = layers.RandomRotation(rotation_factor)(output_node)

    # Zoom
    zoom_factor = utils.add_to_hp(self.zoom_factor, hp)
    if zoom_factor not in [0, (0, 0)]:
        height_factor, width_factor = self._get_fraction_value(zoom_factor)
        # TODO: Add back RandomZoom when it is ready.
        # output_node = layers.RandomZoom(
        #     height_factor, width_factor)(output_node)

    # Contrast
    contrast_factor = utils.add_to_hp(self.contrast_factor, hp)
    if contrast_factor not in [0, (0, 0)]:
        output_node = layers.RandomContrast(contrast_factor)(output_node)

    return output_node
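# A minimal sketch of the `_get_fraction_value` helper used above; this
# is an assumption based on how it is called (it normalizes a factor
# into a (height_factor, width_factor) pair), not code from the
# original snippet.
def _get_fraction_value(value):
    # A tuple is already a (height, width) pair of factors.
    if isinstance(value, tuple):
        return value
    # A scalar factor applies equally to both dimensions.
    return value, value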
def build_model_augment() -> keras.Sequential:
    # Build a sequential model.
    model = keras.Sequential()
    # Add a data augmentation stage.
    data_augmentation = keras.Sequential([
        # Random horizontal flip.
        layers.RandomFlip("horizontal"),
        # Random vertical flip.
        layers.RandomFlip("vertical"),
        # Random rotation by at most a factor of 0.2 (a fraction of
        # 2*pi, i.e. about 72 degrees).
        layers.RandomRotation(0.2),
        # Random zoom by at most 20%.
        layers.RandomZoom(0.2),
        # Random contrast adjustment by at most 10%.
        layers.RandomContrast(0.1),
    ])
    model.add(data_augmentation)
    # First convolutional layer, with ReLU activation and "same" padding.
    model.add(
        keras.layers.Conv2D(32, (3, 3),
                            activation="relu",
                            padding="same",
                            input_shape=(32, 32, 3)))  # [32, 32, 32]
    # Max-pooling layer to reduce the spatial dimensions.
    model.add(keras.layers.MaxPooling2D(2, 2))  # [16, 16, 32]
    # Second convolutional layer, with ReLU activation and "same" padding.
    model.add(
        keras.layers.Conv2D(64, (3, 3), activation="relu",
                            padding="same"))  # [16, 16, 64]
    # Max-pooling layer to reduce the spatial dimensions.
    model.add(keras.layers.MaxPooling2D(2, 2))  # [8, 8, 64]
    # Third convolutional layer, with ReLU activation and "same" padding.
    model.add(
        keras.layers.Conv2D(128, (3, 3), activation="relu",
                            padding="same"))  # [8, 8, 128]
    # Max-pooling layer to reduce the spatial dimensions.
    model.add(keras.layers.MaxPooling2D(2, 2))  # [4, 4, 128]
    # Flatten the feature maps into a 1-D vector.
    model.add(keras.layers.Flatten())  # [1024]
    # Fully connected layer with ReLU activation.
    model.add(keras.layers.Dense(128, activation="relu"))
    # Output layer: one unit per class, with softmax activation for
    # multiclass classification.
    model.add(keras.layers.Dense(10, activation="softmax"))
    # Compile the model with SGD as the optimizer and categorical
    # cross-entropy as the loss function.
    model.compile(optimizer=optimizers.SGD(learning_rate=1e-3, momentum=0.9),
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])
    # Return the model.
    return model
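# A minimal usage sketch, not part of the original snippet: the
# (32, 32, 3) input shape and 10 output classes match CIFAR-10, so the
# model could plausibly be trained like this. "categorical_crossentropy"
# expects one-hot labels, hence to_categorical below; scaling pixels to
# [0, 1] is a choice made here, not taken from the original.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

model = build_model_augment()
model.fit(x_train, y_train, epochs=10, batch_size=64,
          validation_data=(x_test, y_test))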
plt.axis("off") """ ### Data augmentation We can use the preprocessing layers APIs for image augmentation. """ from tensorflow.keras.models import Sequential from tensorflow.keras import layers img_augmentation = Sequential( [ layers.RandomRotation(factor=0.15), layers.RandomTranslation(height_factor=0.1, width_factor=0.1), layers.RandomFlip(), layers.RandomContrast(factor=0.1), ], name="img_augmentation", ) """ This `Sequential` model object can be used both as a part of the model we later build, and as a function to preprocess data before feeding into the model. Using them as function makes it easy to visualize the augmented images. Here we plot 9 examples of augmentation result of a given figure. """ for image, label in ds_train.take(1): for i in range(9): ax = plt.subplot(3, 3, i + 1) aug_img = img_augmentation(tf.expand_dims(image, axis=0))
strip_chars = strip_chars.replace("<", "")
strip_chars = strip_chars.replace(">", "")

vectorization = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=SEQ_LENGTH,
    standardize=custom_standardization,
)
vectorization.adapt(text_data)

# Data augmentation for image data
image_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.2),
    layers.RandomContrast(0.3),
])

"""
## Building a `tf.data.Dataset` pipeline for training

We will generate pairs of images and corresponding captions using a
`tf.data.Dataset` object. The pipeline consists of two steps:

1. Read the image from the disk
2. Tokenize all five captions corresponding to the image
"""

def decode_and_resize(img_path, size=IMAGE_SIZE):
    img = tf.io.read_file(img_path)
    img = tf.image.decode_jpeg(img, channels=3)
# transformations (small random crop and contrast adjustment) to them
# each time we are looping over them. This way, we "augment" our
# training dataset to contain more data.
#
# The augmentation transformations are implemented as preprocessing
# layers in Keras. There are various such layers readily available,
# see https://keras.io/guides/preprocessing_layers/ for more
# information.
#
# ### Initialization

inputs = keras.Input(shape=INPUT_IMAGE_SIZE + [3])
x = layers.Rescaling(scale=1. / 255)(inputs)
x = layers.RandomCrop(75, 75)(x)
x = layers.RandomContrast(0.1)(x)
x = layers.Conv2D(32, (3, 3), activation='relu')(x)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)
x = layers.Conv2D(32, (3, 3), activation='relu')(x)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)
x = layers.Conv2D(64, (3, 3), activation='relu')(x)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)
x = layers.Flatten()(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(43, activation='softmax')(x)
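# A minimal continuation sketch, added because the original snippet
# stops at the output layer: wrap the graph in a Model and compile it.
# The optimizer and loss below are assumptions (the 43 integer-labeled
# classes suggest sparse categorical cross-entropy).
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()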
def main():
    # Download the dataset
    # Using the Flickr8K dataset for this tutorial. This dataset
    # comprises over 8,000 images, each paired with five different
    # captions.
    #!wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
    #!wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
    #!unzip -qq Flickr8k_Dataset.zip
    #!unzip -qq Flickr8k_text.zip
    #!rm Flickr8k_Dataset.zip Flickr8k_text.zip

    # Path to the images.
    IMAGES_PATH = "Flicker8k_Dataset"
    # Desired image dimensions.
    IMAGE_SIZE = (299, 299)
    # Vocabulary size.
    VOCAB_SIZE = 10000
    # Fixed length allowed for any sequence.
    SEQ_LENGTH = 25
    # Dimension for the image embeddings and token embeddings.
    EMBED_DIM = 512
    # Per-layer units in the feed-forward network.
    FF_DIM = 512
    # Other training parameters.
    BATCH_SIZE = 64
    EPOCHS = 30
    AUTOTUNE = tf.data.AUTOTUNE

    # Preparing the dataset.
    def load_captions_data(filename):
        # Load the captions (text) data and map them to the
        # corresponding images.
        # @param: filename, path to the text file containing caption
        #         data.
        # @return: caption_mapping, dictionary mapping image names to
        #          the corresponding captions.
        # @return: text_data, list containing all the available
        #          captions.
        with open(filename) as caption_file:
            caption_data = caption_file.readlines()
            caption_mapping = {}
            text_data = []
            images_to_skip = set()

            for line in caption_data:
                line = line.rstrip("\n")
                # Image name and captions are separated by a tab.
                img_name, caption = line.split("\t")
                # Each image is repeated five times for the five
                # different captions. Each image name has a suffix
                # '#(caption_number)'.
                img_name = img_name.split("#")[0]
                img_name = os.path.join(IMAGES_PATH, img_name.strip())

                # Remove captions that are either too short or too
                # long.
                tokens = caption.strip().split()
                if len(tokens) < 5 or len(tokens) > SEQ_LENGTH:
                    images_to_skip.add(img_name)
                    continue

                if img_name.endswith("jpg") and img_name not in images_to_skip:
                    # Add a start and an end token to each caption,
                    # separated from the text by spaces so they stay
                    # distinct tokens.
                    caption = "<start> " + caption.strip() + " <end>"
                    text_data.append(caption)

                    if img_name in caption_mapping:
                        caption_mapping[img_name].append(caption)
                    else:
                        caption_mapping[img_name] = [caption]

            for img_name in images_to_skip:
                if img_name in caption_mapping:
                    del caption_mapping[img_name]

            return caption_mapping, text_data

    def train_val_split(caption_data, train_size=0.8, shuffle=True):
        # Split the captioning dataset into train and validation sets.
        # @param: caption_data (dict), dictionary containing the mapped
        #         data.
        # @param: train_size (float), fraction of the full dataset to
        #         use as training data.
        # @param: shuffle (bool), whether to shuffle the dataset before
        #         splitting.
        # @return: training and validation datasets as two separate
        #          dicts.

        # 1) Get the list of all image names.
        all_images = list(caption_data.keys())

        # 2) Shuffle if necessary.
        if shuffle:
            np.random.shuffle(all_images)

        # 3) Split into training and validation sets.
        train_size = int(len(caption_data) * train_size)
        training_data = {
            img_name: caption_data[img_name]
            for img_name in all_images[:train_size]
        }
        validation_data = {
            img_name: caption_data[img_name]
            for img_name in all_images[train_size:]
        }

        # 4) Return the splits.
        return training_data, validation_data

    # Load the dataset.
    captions_mapping, text_data = load_captions_data("Flickr8k.token.txt")

    # Split the dataset into training and validation sets.
    train_data, valid_data = train_val_split(captions_mapping)
    print("Number of training samples: ", len(train_data))
    print("Number of validation samples: ", len(valid_data))

    # Vectorizing the text data
    # Use the TextVectorization layer to vectorize the text data, that
    # is, to turn the original strings into integer sequences where
    # each integer represents the index of a word in a vocabulary. Use
    # a custom string standardization scheme (in this case, strip
    # punctuation characters except < and >) and the default splitting
    # scheme (split on whitespace).
    def custom_standardization(input_string):
        lowercase = tf.strings.lower(input_string)
        return tf.strings.regex_replace(lowercase,
                                        "[%s]" % re.escape(strip_chars), "")

    strip_chars = "!\"#$%&'()*+,-./;<=>?@[\\]^_`{|}~"
    strip_chars = strip_chars.replace("<", "")
    strip_chars = strip_chars.replace(">", "")

    vectorization = TextVectorization(
        max_tokens=VOCAB_SIZE,
        output_mode="int",
        output_sequence_length=SEQ_LENGTH,
        standardize=custom_standardization,
    )
    vectorization.adapt(text_data)

    # Data augmentation for image data.
    image_augmentation = keras.Sequential([
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.2),
        layers.RandomContrast(0.3),
    ])

    # Building a tf.data.Dataset pipeline for training
    # Generate pairs of images and corresponding captions using a
    # tf.data.Dataset object. The pipeline consists of two steps:
    # 1) Read the image from the disk.
    # 2) Tokenize all five captions corresponding to the image.
    def decode_and_resize(img_path):
        img = tf.io.read_file(img_path)
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, IMAGE_SIZE)
        img = tf.image.convert_image_dtype(img, tf.float32)
        return img

    def process_input(img_path, captions):
        return decode_and_resize(img_path), vectorization(captions)

    def make_dataset(images, captions):
        dataset = tf.data.Dataset.from_tensor_slices((images, captions))
        dataset = dataset.shuffle(len(images))
        dataset = dataset.map(process_input, num_parallel_calls=AUTOTUNE)
        dataset = dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)
        return dataset

    # Pass the list of images and the list of corresponding captions.
    train_dataset = make_dataset(list(train_data.keys()),
                                 list(train_data.values()))
    valid_dataset = make_dataset(list(valid_data.keys()),
                                 list(valid_data.values()))

    # Building the model
    # The image captioning architecture consists of three models:
    # 1) A CNN: used to extract the image features.
    # 2) A TransformerEncoder: the extracted image features are then
    #    passed to a Transformer-based encoder that generates a new
    #    representation of the inputs.
    # 3) A TransformerDecoder: this model takes the encoder output and
    #    the text data (sequences) as inputs and tries to learn to
    #    generate the caption.
    def get_cnn_model():
        base_model = efficientnet.EfficientNetB0(
            input_shape=(*IMAGE_SIZE, 3),
            include_top=False,
            weights="imagenet",
        )
        # Freeze the feature extractor.
        base_model.trainable = False
        base_model_out = base_model.output
        base_model_out = layers.Reshape(
            (-1, base_model_out.shape[-1]))(base_model_out)
        cnn_model = keras.models.Model(base_model.input, base_model_out)
        return cnn_model

    class TransformerEncoderBlock(layers.Layer):
        def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
            super().__init__(**kwargs)
            self.embed_dim = embed_dim
            self.dense_dim = dense_dim
            self.num_heads = num_heads
            self.attention_1 = layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=embed_dim, dropout=0.0)
            self.layernorm1 = layers.LayerNormalization()
            self.layernorm2 = layers.LayerNormalization()
            self.dense_1 = layers.Dense(embed_dim, activation="relu")

        def call(self, inputs, training, mask=None):
            inputs = self.layernorm1(inputs)
            inputs = self.dense_1(inputs)
            attention_output_1 = self.attention_1(
                query=inputs,
                value=inputs,
                key=inputs,
                attention_mask=None,
                training=training,
            )
            out_1 = self.layernorm2(inputs + attention_output_1)
            return out_1

    class PositionalEmbedding(layers.Layer):
        def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
            super().__init__(**kwargs)
            self.token_embeddings = layers.Embedding(input_dim=vocab_size,
                                                     output_dim=embed_dim)
            self.position_embeddings = layers.Embedding(
                input_dim=sequence_length, output_dim=embed_dim)
            self.sequence_length = sequence_length
            self.vocab_size = vocab_size
            self.embed_dim = embed_dim
            self.embed_scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))

        def call(self, inputs):
            length = tf.shape(inputs)[-1]
            positions = tf.range(start=0, limit=length, delta=1)
            embedded_tokens = self.token_embeddings(inputs)
            embedded_tokens = embedded_tokens * self.embed_scale
            embedded_positions = self.position_embeddings(positions)
            return embedded_tokens + embedded_positions

        def compute_mask(self, inputs, mask=None):
            # Mask out padding tokens (id 0).
            return tf.math.not_equal(inputs, 0)

    class TransformerDecoderBlock(layers.Layer):
        def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
            super().__init__(**kwargs)
            self.embed_dim = embed_dim
            self.ff_dim = ff_dim
            self.num_heads = num_heads
            self.attention_1 = layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=embed_dim, dropout=0.1)
            self.attention_2 = layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=embed_dim, dropout=0.1)
            self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
            self.ffn_layer_2 = layers.Dense(embed_dim)
            self.layernorm_1 = layers.LayerNormalization()
            self.layernorm_2 = layers.LayerNormalization()
            self.layernorm_3 = layers.LayerNormalization()
            self.embedding = PositionalEmbedding(embed_dim=EMBED_DIM,
                                                 sequence_length=SEQ_LENGTH,
                                                 vocab_size=VOCAB_SIZE)
            self.out = layers.Dense(VOCAB_SIZE, activation="softmax")
            self.dropout_1 = layers.Dropout(0.3)
            self.dropout_2 = layers.Dropout(0.5)
            self.supports_masking = True

        def call(self, inputs, encoder_outputs, training, mask=None):
            inputs = self.embedding(inputs)
            causal_mask = self.get_causal_attention_mask(inputs)
            # Initialize the masks so the call also works when no
            # padding mask is passed in.
            combined_mask = causal_mask
            padding_mask = None
            if mask is not None:
                padding_mask = tf.cast(mask[:, :, tf.newaxis],
                                       dtype=tf.int32)
                combined_mask = tf.cast(mask[:, tf.newaxis, :],
                                        dtype=tf.int32)
                combined_mask = tf.minimum(combined_mask, causal_mask)

            attention_output_1 = self.attention_1(
                query=inputs,
                value=inputs,
                key=inputs,
                attention_mask=combined_mask,
                training=training,
            )
            out_1 = self.layernorm_1(inputs + attention_output_1)

            attention_output_2 = self.attention_2(
                query=out_1,
                value=encoder_outputs,
                key=encoder_outputs,
                attention_mask=padding_mask,
                training=training,
            )
            out_2 = self.layernorm_2(out_1 + attention_output_2)

            ffn_out = self.ffn_layer_1(out_2)
            ffn_out = self.dropout_1(ffn_out, training=training)
            ffn_out = self.ffn_layer_2(ffn_out)
            ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
            ffn_out = self.dropout_2(ffn_out, training=training)
            preds = self.out(ffn_out)
            return preds

        def get_causal_attention_mask(self, inputs):
            input_shape = tf.shape(inputs)
            batch_size, sequence_length = input_shape[0], input_shape[1]
            # Lower-triangular mask: position i may only attend to
            # positions j <= i.
            i = tf.range(sequence_length)[:, tf.newaxis]
            j = tf.range(sequence_length)
            mask = tf.cast(i >= j, dtype="int32")
            mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
            mult = tf.concat(
                [
                    tf.expand_dims(batch_size, -1),
                    tf.constant([1, 1], dtype=tf.int32)
                ],
                axis=0,
            )
            # Tile the mask across the batch dimension.
            return tf.tile(mask, mult)

    class ImageCaptioningModel(keras.Model):
        def __init__(self,
                     cnn_model,
                     encoder,
                     decoder,
                     num_captions_per_image=5,
                     image_aug=None):
            super().__init__()
            self.cnn_model = cnn_model
            self.encoder = encoder
            self.decoder = decoder
            self.loss_tracker = keras.metrics.Mean(name="loss")
            self.acc_tracker = keras.metrics.Mean(name="accuracy")
            self.num_captions_per_image = num_captions_per_image
            self.image_aug = image_aug

        def calculate_loss(self, y_true, y_pred, mask):
            loss = self.loss(y_true, y_pred)
            mask = tf.cast(mask, dtype=loss.dtype)
            loss *= mask
            return tf.reduce_sum(loss) / tf.reduce_sum(mask)

        def calculate_accuracy(self, y_true, y_pred, mask):
            accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
            accuracy = tf.math.logical_and(mask, accuracy)
            accuracy = tf.cast(accuracy, dtype=tf.float32)
            mask = tf.cast(mask, dtype=tf.float32)
            return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

        def _compute_caption_loss_and_acc(self,
                                          img_embed,
                                          batch_seq,
                                          training=True):
            encoder_out = self.encoder(img_embed, training=training)
            batch_seq_inp = batch_seq[:, :-1]
            batch_seq_true = batch_seq[:, 1:]
            mask = tf.math.not_equal(batch_seq_true, 0)
            batch_seq_pred = self.decoder(batch_seq_inp,
                                          encoder_out,
                                          training=training,
                                          mask=mask)
            loss = self.calculate_loss(batch_seq_true, batch_seq_pred, mask)
            acc = self.calculate_accuracy(batch_seq_true, batch_seq_pred,
                                          mask)
            return loss, acc

        def train_step(self, batch_data):
            batch_img, batch_seq = batch_data
            batch_loss = 0
            batch_acc = 0

            if self.image_aug:
                batch_img = self.image_aug(batch_img)

            # 1) Get image embeddings.
            img_embed = self.cnn_model(batch_img)

            # 2) Pass each of the five captions one by one to the
            #    decoder along with the encoder outputs and compute the
            #    loss as well as accuracy for each caption.
            for i in range(self.num_captions_per_image):
                with tf.GradientTape() as tape:
                    loss, acc = self._compute_caption_loss_and_acc(
                        img_embed, batch_seq[:, i, :], training=True)

                    # 3) Update loss and accuracy.
                    batch_loss += loss
                    batch_acc += acc

                # 4) Get the list of all the trainable weights.
                train_vars = (self.encoder.trainable_variables +
                              self.decoder.trainable_variables)

                # 5) Get the gradients.
                grads = tape.gradient(loss, train_vars)

                # 6) Update the trainable weights.
                self.optimizer.apply_gradients(zip(grads, train_vars))

            # 7) Update the trackers.
            batch_acc /= float(self.num_captions_per_image)
            self.loss_tracker.update_state(batch_loss)
            self.acc_tracker.update_state(batch_acc)

            # 8) Return the loss and accuracy values.
            return {
                "loss": self.loss_tracker.result(),
                "acc": self.acc_tracker.result()
            }

        def test_step(self, batch_data):
            batch_img, batch_seq = batch_data
            batch_loss = 0
            batch_acc = 0

            # 1) Get image embeddings.
            img_embed = self.cnn_model(batch_img)

            # 2) Pass each of the five captions one by one to the
            #    decoder along with the encoder outputs and compute the
            #    loss as well as accuracy for each caption.
            for i in range(self.num_captions_per_image):
                loss, acc = self._compute_caption_loss_and_acc(
                    img_embed, batch_seq[:, i, :], training=False)

                # 3) Update loss and accuracy.
                batch_loss += loss
                batch_acc += acc

            batch_acc /= float(self.num_captions_per_image)

            # 4) Update the trackers.
            self.loss_tracker.update_state(batch_loss)
            self.acc_tracker.update_state(batch_acc)

            # 5) Return the loss and accuracy values.
            return {
                "loss": self.loss_tracker.result(),
                "acc": self.acc_tracker.result()
            }

        @property
        def metrics(self):
            # List the metrics here so that `reset_states()` can be
            # called automatically.
            return [self.loss_tracker, self.acc_tracker]

    cnn_model = get_cnn_model()
    encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM,
                                      dense_dim=FF_DIM,
                                      num_heads=1)
    decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM,
                                      ff_dim=FF_DIM,
                                      num_heads=2)
    caption_model = ImageCaptioningModel(
        cnn_model=cnn_model,
        encoder=encoder,
        decoder=decoder,
        image_aug=image_augmentation,
    )

    # Model training
    # Define the loss function.
    cross_entropy = keras.losses.SparseCategoricalCrossentropy(
        from_logits=False, reduction="none")

    # Early stopping criteria.
    early_stopping = keras.callbacks.EarlyStopping(
        patience=3, restore_best_weights=True)

    # Learning rate scheduler for the optimizer.
    class LRSchedule(keras.optimizers.schedules.LearningRateSchedule):
        def __init__(self, post_warmup_learning_rate, warmup_steps):
            super().__init__()
            self.post_warmup_learning_rate = post_warmup_learning_rate
            self.warmup_steps = warmup_steps

        def __call__(self, step):
            global_step = tf.cast(step, tf.float32)
            warmup_steps = tf.cast(self.warmup_steps, tf.float32)
            warmup_progress = global_step / warmup_steps
            warmup_learning_rate = (self.post_warmup_learning_rate *
                                    warmup_progress)
            # Ramp up linearly during warmup, then hold constant.
            return tf.cond(
                global_step < warmup_steps,
                lambda: warmup_learning_rate,
                lambda: self.post_warmup_learning_rate,
            )

    # Create a learning rate schedule.
    num_train_steps = len(train_dataset) * EPOCHS
    num_warmup_steps = num_train_steps // 15
    lr_schedule = LRSchedule(post_warmup_learning_rate=1e-4,
                             warmup_steps=num_warmup_steps)

    # Compile the model.
    caption_model.compile(optimizer=keras.optimizers.Adam(lr_schedule),
                          loss=cross_entropy)

    # Fit the model.
    caption_model.fit(
        train_dataset,
        epochs=EPOCHS,
        validation_data=valid_dataset,
        callbacks=[early_stopping],
    )

    # Check sample predictions
    '''
    vocab = vectorization.get_vocabulary()
    index_lookup = dict(zip(range(len(vocab)), vocab))
    max_decoded_sentence_length = SEQ_LENGTH - 1
    valid_images = list(valid_data.keys())

    def generate_caption():
        # Select a random image from the validation dataset.
        sample_img = np.random.choice(valid_images)

        # Read the image from the disk.
        sample_img = decode_and_resize(sample_img)
        img = sample_img.numpy().clip(0, 255).astype(np.uint8)
        plt.imshow(img)
        plt.show()

        # Pass the image to the CNN.
        img = tf.expand_dims(sample_img, 0)
        img = caption_model.cnn_model(img)

        # Pass the image features to the Transformer encoder.
        encoded_img = caption_model.encoder(img, training=False)

        # Generate the caption using the Transformer decoder.
decoded_caption = "<start>" for i in range(max_decoded_sentence_length): tokenized_caption = vectorization([decoded_caption])[:, :-1] mask = tf.math.not_equal(tokenized_caption, 0) predictions = caption_model.decoder( tokenized_caption, encoded_img, training=False, mask=mask ) sampled_token_index = np.argmax(predictions[0, i, :]) sampled_token = index_lookup[sampled_token_index] if sampled_token == " <end>": break decoded_caption += " " + sampled_token decoded_caption = decoded_caption.replace("<start>", "") decoded_caption = decoded_caption.replace("<end>", "").strip() print("Predicted Caption: ", decoded_caption) # Check predictions for a few samples. generate_caption() generate_caption() generate_caption() ''' # End Notes # Notice that the model starts to generate reasonable captions # after a few epochs. To keep this example easily runnable, it has # been trained with a few constraints, like a minimal number of # attention heads. To improve predictions, try changing these # training settings and find a good model for your use case. # Exit the program. exit(0)