Ejemplo n.º 1
0
import os
import sys
import numpy as np

# Stop pycache: set before the project modules below are imported so that
# no .pyc files are written for them.
sys.dont_write_bytecode = True

# Keras and tensorflow imports ##############################################
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
import tensorflow.keras as keras
import tensorflow.keras.backend as K

# Allow package imports when this file is executed directly as a script:
# put the directory two levels above this file on sys.path so that the
# ``image2seq`` package (imported from just below) can be found.
if __name__ == "__main__" and __package__ is None:
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
    import image2seq
    __package__ = "image2seq"

from image2seq.metrics.edit_distance import edit_distance_metric

# Smoke test for edit_distance_metric on a tiny hand-written batch of two
# token-id sequences. Trailing 0s presumably act as padding -- TODO confirm
# against edit_distance_metric.
prediction = np.array([[1, 3, 2, 2, 3, 1], [4, 1, 2, 1, 0, 0]])
target = np.array([[1, 3, 4, 2, 0, 0], [3, 2, 0, 0, 0, 0]])

prediction = tf.convert_to_tensor(prediction, dtype=tf.float32)
target = tf.convert_to_tensor(target, dtype=tf.float32)

# Relative to the current working directory; presumably meant to be run from
# the directory that contains the ``image2seq`` package.
prediction_file = "image2seq/metrics/test_file.txt"

edit_distance_metric(target=target,
                     prediction=prediction,
                     predictions_file=prediction_file)
Ejemplo n.º 2
0
    def call(self, inputs, val_mode=False, dropout=False):
        """Run the attention encoder-decoder over one batch.

        Encodes the image, embeds the token sequence, then decodes one step
        at a time: attention over the image features, an LSTM step, and an
        MLP head producing a token prediction. ``masked_ce_loss_fn`` is
        accumulated for every target position after index 0.

        Args:
          inputs: pair ``(input_image, input_tokens)``. ``input_image`` is
            asserted to be ``(batch, 299, 299, 3)`` when
            ``self.image_encoder == "inceptionv3"`` and ``(batch, 64, 2048)``
            otherwise. ``input_tokens`` is ``(batch, token_seq_len)``; column
            0 is the GO token used as the first decoder input.
          val_mode: if True, each step's argmax prediction is re-embedded and
            fed back as the next decoder input, and the edit distance of the
            predicted sequence is computed. If False, teacher forcing is used
            (the ground-truth embedding of the current token is fed back).
          dropout: forwarded as ``dropout=`` to the sub-models and as
            ``training=`` to the LSTM layer.

        Returns:
          ``(batch_loss, batch_mean_edit_distance)``. The edit distance stays
          ``0`` unless ``val_mode`` is True.

        Side effects:
          Sets ``self.batch_size`` and ``self.attention_weights`` (one entry
          per decoding step). In ``val_mode`` also sets
          ``self.attention_predictions`` and passes ``self.predictions_file``
          to ``edit_distance_metric``.
        """
        # Train or validation mode ##############################################
        if val_mode:
            logging.debug("MODEL EDA XU MLP CALL - Validation mode")
        else:
            logging.debug("MODEL EDA XU MLP CALL - Train mode")

        # STEP 0: Process inputs ################################################
        #   input_image  : (batch, 299, 299, 3) for inceptionv3, otherwise
        #                  (batch, 64, 2048)
        #   input_tokens : (batch, token_seq_len)
        input_image = inputs[0]
        input_tokens = inputs[1]
        self.batch_size = input_tokens.shape[0]
        batch_token_seq_len = input_tokens.shape[1]

        # Logging, Debug & Assert
        logging.debug("MODEL EDA XU MLP CALL - Step 0 - Process inputs - "
                      "batch_size {}".format(self.batch_size))
        logging.debug("MODEL EDA XU MLP CALL - Step 0 - Process inputs - "
                      "input_image shape {}".format(K.int_shape(input_image)))
        logging.debug("MODEL EDA XU MLP CALL - Step 0 - Process inputs - "
                      "input_tokens shape {}".format(
                          K.int_shape(input_tokens)))
        if self.image_encoder == "inceptionv3":
            tf.compat.v1.debugging.assert_equal(K.int_shape(input_image),
                                                (self.batch_size, 299, 299, 3))
        else:
            tf.compat.v1.debugging.assert_equal(K.int_shape(input_image),
                                                (self.batch_size, 64, 2048))
        tf.compat.v1.debugging.assert_equal(
            K.int_shape(input_tokens), (self.batch_size, batch_token_seq_len))

        # STEP 1: Reset decoder hidden state ####################################
        #   decoder_hidden_state : zeros (batch, decoder_hidden_dim)
        decoder_hidden_state = keras.backend.zeros(
            shape=(self.batch_size, self.decoder_hidden_dim))

        # Logging, Debug & Assert
        logging.debug("MODEL EDA XU MLP CALL - Step 1 - Reset decoder hidden "
                      "state - decoder_hidden_state shape {}".format(
                          K.int_shape(decoder_hidden_state)))
        tf.compat.v1.debugging.assert_equal(
            K.int_shape(decoder_hidden_state),
            (self.batch_size, self.decoder_hidden_dim))

        # STEP 2: Image encoding ################################################
        #   input_image_features : (batch, 64, image_embedding_dim)
        input_image_features = self.model1_image_encoding([input_image],
                                                          dropout=dropout)

        # Logging, Debug & Assert
        logging.debug(
            "MODEL EDA XU MLP CALL - Step 2 - Image encoding dense"
            " and activations - input_image_features shape {}".format(
                K.int_shape(input_image_features)))
        tf.compat.v1.debugging.assert_equal(
            K.int_shape(input_image_features),
            (self.batch_size, 64, self.image_embedding_dim))

        # STEP 3: Token embedding for all batch input sequences #################
        #   input_token_embeddings : (batch, token_seq_len, token_embedding_dim)
        input_token_embeddings = self.model2_token_embedding([input_tokens],
                                                             dropout=dropout)

        # Logging, Debug & Assert
        logging.debug("MODEL EDA XU MLP CALL - Step 3 - Token embeddings - "
                      "target_token_embeddings shape {}".format(
                          keras.backend.int_shape(input_token_embeddings)))
        tf.compat.v1.debugging.assert_equal(
            keras.backend.int_shape(input_token_embeddings),
            (self.batch_size, batch_token_seq_len, self.token_embedding_dim))

        # STEP 4: First decoder input is the 'GO' token #########################
        #   decoder_token_input : (batch, 1, token_embedding_dim)
        # The embedding of column 0 is the first input in both teacher-forcing
        # and validation mode.
        decoder_token_input = K.expand_dims(input_token_embeddings[:, 0], 1)

        # Logging, Debug, & Assert
        logging.debug("MODEL EDA XU MLP CALL - Step 4 - Decoder inputs - "
                      "decoder_teaching_forcing_inputs shape {}".format(
                          K.int_shape(decoder_token_input)))
        tf.compat.v1.debugging.assert_equal(
            K.int_shape(decoder_token_input),
            (self.batch_size, 1, self.token_embedding_dim))

        # STEP 5: Loop through token sequence ###################################
        batch_loss = 0
        batch_mean_edit_distance = 0
        self.attention_weights = []
        # Only filled (and only read) in validation mode.
        list_predictions = []
        for i in range(1, batch_token_seq_len):
            # STEP 5.1: Attention #################################################
            #   context_vector : (batch, decoder_hidden_dim)
            context_vector, attention_weights = self.model3_attention(
                [input_image_features, decoder_hidden_state], dropout=dropout)

            # Keep every step's attention weights.
            self.attention_weights.append(attention_weights)

            # Logging, Debug & Assert
            logging.debug("MODEL EDA XU MLP CALL - Step 5.1 - Attention - "
                          "Context vector shape {}".format(
                              K.int_shape(context_vector)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(context_vector),
                (self.batch_size, self.decoder_hidden_dim))

            # STEP 5.2: LSTM input ################################################
            #   lstm_input : (batch, 1, decoder_hidden_dim + token_embedding_dim)
            context_vector_expanded = self.layer1_expand_dims(
                context_vector, 1)

            # Logging, Debug & Assert
            logging.debug("MODEL EDA XU MLP CALL - Step 5.2 - Expand context "
                          "vector - context_vector_expanded shape {}".format(
                              K.int_shape(context_vector_expanded)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(context_vector_expanded),
                (self.batch_size, 1, self.decoder_hidden_dim))

            lstm_input = self.layer2_concatenate(
                [context_vector_expanded, decoder_token_input], axis=-1)

            # Logging, Debug & Assert
            logging.debug(
                "MODEL EDA XU MLP CALL - Step 5.2 - Concat context "
                "vector and token embedding - lstm_input shape {}".format(
                    K.int_shape(lstm_input)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(lstm_input),
                (self.batch_size, 1,
                 self.token_embedding_dim + self.decoder_hidden_dim))

            # STEP 5.3: LSTM ######################################################
            #   lstm_output          : (batch, 1, decoder_hidden_dim)
            #   decoder_hidden_state : (batch, decoder_hidden_dim)
            #   decoder_cell_state   : (batch, decoder_hidden_dim)
            # `dropout` doubles as the Keras `training` flag: True while
            # training, False while validating.
            lstm_output, decoder_hidden_state, decoder_cell_state = (
                self.layer3_lstm(lstm_input, training=dropout))

            # Logging, Debug & Assert
            logging.debug("MODEL EDA XU MLP CALL - Step 5.3 - LSTM output - "
                          "lstm_output shape {}".format(
                              K.int_shape(lstm_output)))
            logging.debug("MODEL EDA XU MLP CALL - Step 5.3 - LSTM output - "
                          "decoder_hidden_state shape {}".format(
                              K.int_shape(decoder_hidden_state)))
            logging.debug("MODEL EDA XU MLP CALL - Step 5.3 - LSTM output - "
                          "decoder_cell_state shape {}".format(
                              K.int_shape(decoder_cell_state)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(lstm_output),
                (self.batch_size, 1, self.decoder_hidden_dim))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(decoder_hidden_state),
                (self.batch_size, self.decoder_hidden_dim))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(decoder_cell_state),
                (self.batch_size, self.decoder_hidden_dim))

            # STEP 5.4: MLP #######################################################
            #   single_token_prediction : (batch, token_vocab_size)
            mlp_input = self.layer4_concatenate(
                [context_vector_expanded, lstm_output], axis=-1)

            single_token_prediction = self.model4_mlp([mlp_input],
                                                      dropout=dropout)

            # Logging. Debug & Assert
            logging.debug("MODEL EDA XU MLP CALL - Step 5.4 - MLP output - "
                          "single_token_prediction shape {}".format(
                              K.int_shape(single_token_prediction)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(single_token_prediction),
                (self.batch_size, self.token_vocab_size))

            # STEP 5.5: Accumulate loss for target position i #####################
            batch_loss += masked_ce_loss_fn(
                target=input_tokens[:, i],
                prediction=single_token_prediction,
                batch_size=self.batch_size,
                token_vocab_size=self.token_vocab_size)

            # Logging, Debug & Assert
            logging.debug("MODEL EDA XU MLP CALL - Step 5.5 - "
                          "Cumulative batch loss {}".format(batch_loss))

            # STEP 5.6: Update decoder input ######################################
            #   decoder_token_input : (batch, 1, token_embedding_dim)
            if val_mode:
                # Validation: feed back the model's own argmax prediction.
                argmax_prediction = tf.argmax(single_token_prediction,
                                              axis=1,
                                              output_type=tf.dtypes.int32)
                list_predictions.append(argmax_prediction)
                argmax_prediction_expanded = K.expand_dims(argmax_prediction)
                decoder_token_input = self.model2_token_embedding(
                    [argmax_prediction_expanded])
            else:
                # Training: teacher forcing with the ground-truth token i.
                decoder_token_input = K.expand_dims(
                    input_token_embeddings[:, i], 1)

            # Logging, Debug & Assert
            logging.debug("MODEL EDA XU MLP CALL - Step 5.6 - Update decoder "
                          "input - decoder_token_input shape {}".format(
                              K.int_shape(decoder_token_input)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(decoder_token_input),
                (self.batch_size, 1, self.token_embedding_dim))

        # STEP 6: Calculate Levenshtein (edit) distance in validation mode ######
        if val_mode:
            # (batch, token_seq_len - 1): one prediction per decoded position.
            stack_predictions = tf.stack(list_predictions, axis=1)

            self.attention_predictions = stack_predictions
            # Logging, Debug & Assert
            logging.debug("MODEL EDA XU CALL - Step 6 - Stack predictions "
                          "shape {}".format(K.int_shape(stack_predictions)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(stack_predictions),
                (self.batch_size, batch_token_seq_len - 1))

            # Compare against the targets without the leading GO column.
            batch_mean_edit_distance = edit_distance_metric(
                target=input_tokens[:, 1:],
                prediction=stack_predictions,
                predictions_file=self.predictions_file)

        # STEP 7: Return word sequence batch loss ###############################
        return batch_loss, batch_mean_edit_distance
Ejemplo n.º 3
0
    def call(self, inputs, val_mode=False, dropout=False, train_mode="full"):
        # Train or validation mode ##############################################
        if val_mode:
            logging.debug("MODEL DRAKE NESTED CALL - Train mode")
        else:
            logging.debug("MODEL DRAKE NESTED CALL - Validation mode")

        # STEP 0: Process Inputs ################################################
        # Input               | Encoder input       | batch_size=None x         #
        #                     |                     | feature_map_wxh=None(64) x#
        #                     |                     | image_embedding_dim=      #
        #                     |                     | image_embedding_dim       #
        #                     | Token input         | batch_size = None x       #
        #                     |                     | token_seq_len =  x        #
        #                     |                     | token_seq_len             #
        #_____________________|_____________________|___________________________#
        input_image = inputs[0]
        input_tokens = inputs[1]
        self.batch_size = input_tokens.shape[0]
        batch_token_seq_len = input_tokens.shape[1]

        # Logging, Debug & Assert
        logging.debug("MODEL DRAKE NESTED CALL - Step 0 - Process inputs - "
                      "batch_size {}".format(self.batch_size))
        logging.debug("MODEL DRAKE NESTED CALL - Step 0 - Process inputs - "
                      "input_image shape {}".format(K.int_shape(input_image)))
        logging.debug("MODEL DRAKE NESTED CALL - Step 0 - Process inputs - "
                      "input_tokens shape {}".format(
                          K.int_shape(input_tokens)))
        if self.image_encoder == "inceptionv3":
            tf.compat.v1.debugging.assert_equal(K.int_shape(input_image),
                                                (self.batch_size, 299, 299, 3))
        else:
            tf.compat.v1.debugging.assert_equal(K.int_shape(input_image),
                                                (self.batch_size, 64, 2048))
        tf.compat.v1.debugging.assert_equal(
            K.int_shape(input_tokens), (self.batch_size, batch_token_seq_len))

        # STEP 1: Reset Decoder Hidden State ####################################
        # Zeroes              | Initial decoder     | batch_size=None x         #
        #                     | hidden state        | decoder_hidden_dim=       #
        #                     |                     | decoder_hidden_dim        #
        #                     | Initial outer       | batch_size=None x         #
        #                     | decoder hidden      | decoder_hidden_dim=       #
        #                     | state               | decoder_hidden_dim        #
        #_____________________|_____________________|___________________________#
        decoder_hidden_state = \
          keras.backend.zeros(shape=(self.batch_size, self.decoder_hidden_dim))
        outer_decoder_hidden_state = \
          keras.backend.zeros(shape=(self.batch_size, self.decoder_hidden_dim))

        # Logging, Debug & Assert
        logging.debug(
            "MODEL DRAKE NESTED CALL - Step 1 - Reset decoder hidden "
            "state - decoder_hidden_state shape {}".format(
                K.int_shape(decoder_hidden_state)))
        logging.debug("MODEL DRAKE NESTED CALL - Step 1 - Reset outer decoder "
                      "hidden state - decoder_hidden_state shape {}".format(
                          K.int_shape(outer_decoder_hidden_state)))
        tf.compat.v1.debugging.assert_equal(
            K.int_shape(decoder_hidden_state),
            (self.batch_size, self.decoder_hidden_dim))
        tf.compat.v1.debugging.assert_equal(
            K.int_shape(outer_decoder_hidden_state),
            (self.batch_size, self.decoder_hidden_dim))

        # STEP 2: Image Encoding ################################################
        # Dense + Activations | Image encoder       | batch_size=None x         #
        #                     | output              | feature_map_wxh=None(64)  #
        #                     |                     | image_embedding_dim=      #
        #                     |                     | image_embedding_dim       #
        #_____________________|_____________________|___________________________#
        input_image_features = \
          self.model1_image_encoding([input_image],
                                      dropout=dropout)

        # Logging, Debug & Assert
        logging.debug(
            "MODEL DRAKE NESTED CALL - Step 2 - Image encoding dense"
            " and activations - input_image_features shape {}".format(
                K.int_shape(input_image_features)))
        tf.compat.v1.debugging.assert_equal(
            K.int_shape(input_image_features),
            (self.batch_size, 64, self.image_embedding_dim))

        # STEP 3: Token Embedding for all batch input sequences #################
        # Embedding           | Token embedding     | batch_size=None x         #
        #                     |                     | token_seq_len =           #
        #                     |                     | token_seq_len x           #
        #                     |                     | token_embedding_dim=      #
        #                     |                     | token_embedding_dim       #
        # ____________________|_____________________|___________________________#
        input_token_embeddings = \
          self.model2_token_embedding([input_tokens],
                                      dropout=dropout)

        # Logging, Debug & Assert
        logging.debug("MODEL DRAKE NESTED CALL - Step 3 - Token embeddings - "
                      "target_token_embeddings shape {}".format(
                          keras.backend.int_shape(input_token_embeddings)))
        tf.compat.v1.debugging.assert_equal(
            keras.backend.int_shape(input_token_embeddings),
            (self.batch_size, batch_token_seq_len, self.token_embedding_dim))

        # STEP 4: Decoder inputs is a 'GO' ######################################
        # Slice + Expand dims | GO column           | batch_size=None x         #
        #                     |                     | token_seq_len = 1 x       #
        #                     |                     | token_embedding_dim=      #
        #                     |                     | token_embedding_dim       #
        # ____________________|_____________________|___________________________#
        # For first character input is always  GO = 1 at index 0
        # Both for teaching forcing mode and validation mode
        decoder_token_input = \
          K.expand_dims(input_token_embeddings[:, 0], 1)

        # Logging, Debug, & Assert
        logging.debug("MODEL DRAKE NESTED CALL - Step 4 - Decoder inputs - "
                      "decoder_teaching_forcing_inputs shape {}".format(
                          K.int_shape(decoder_token_input)))
        tf.compat.v1.debugging.assert_equal(
            K.int_shape(decoder_token_input),
            (self.batch_size, 1, self.token_embedding_dim))

        # STEP 5: Loop through token sequence ###################################
        batch_loss = 0
        batch_mean_edit_distance = 0
        if val_mode:
            list_predictions = []
        for i in range(1, batch_token_seq_len):
            # Nested mode trains the subnet on only the first five characters
            # Note this approach assumes embedded padding.
            if train_mode == "nested" and i >= 6:
                logging.debug(
                    "MODEL DRAKE NESTED CALL - Step 5 - Skip loop if "
                    "in nested mode and index is > 6")
                continue
            # STEP 5.1: Outer attention ###########################################
            # Summed weights    | Outer context       | batch_size=None x         #
            #                   | vector              | decoder_hidden_dim=       #
            #                   |                     | decoder_hidden_dim        #
            # __________________|_____________________|___________________________#
            outer_context_vector, outer_attention_weights = \
              self.model5_attention([input_image_features,
                                     outer_decoder_hidden_state],
                                    dropout=dropout)

            # Logging, Debug & Assert
            logging.debug(
                "MODEL DRAKE NESTED CALL - Step 5.1 - Outer attention - "
                "Context vector shape {}".format(
                    K.int_shape(outer_context_vector)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(outer_context_vector),
                (self.batch_size, self.decoder_hidden_dim))

            # STEP 5.2. Outer LSTM Input ##########################################
            # Expand +          | Outer LSTM input    | batch_size=None x         #
            # Concatenate       |                     | token_seq_len=1 x         #
            #                   |                     | lstm_input_dim=           #
            #                   |                     | decoder_hidden_dim +      #
            #                   |                     | token_embedding_dim       #
            # __________________|_____________________|___________________________#
            outer_context_vector_expanded = \
              self.layer5_expand_dims(outer_context_vector, 1)

            # Logging, Debug & Assert
            logging.debug(
                "MODEL DRAKE NESTED CALL - Step 5.2 - Expand context "
                "vector - context_vector_expanded shape {}".format(
                    K.int_shape(outer_context_vector_expanded)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(outer_context_vector_expanded),
                (self.batch_size, 1, self.decoder_hidden_dim))

            outer_lstm_input = \
              self.layer6_concatenate([outer_context_vector_expanded,
                                       decoder_token_input],
                                      axis=-1)

            # Logging, Debug & Assert
            logging.debug(
                "MODEL DRAKE NESTED CALL - Step 5.2 - Concat context"
                "vector and token embedding - lstm_input shape {}".format(
                    K.int_shape(outer_lstm_input)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(outer_lstm_input),
                (self.batch_size, 1,
                 self.token_embedding_dim + self.decoder_hidden_dim))

            # STEP 5.3: Outer LSTM ################################################
            # LSTM return       | LSTM Output         | batch_size=None x         #
            # sequences and     |                     | token_seq_len=1 x         #
            # state             |                     | decoder_hidden_dim=       #
            #                   |                     | decoder_hidden_dim        #
            #                   |                     |                           #
            #                   | LSTM Hidden State   | batch_size=None x         #
            #                   |                     | decoder_hidden_dim=       #
            #                   |                     | decoder_hidden_dim        #
            #                   |                     |                           #
            #                   | LSTM Cell State     | batch_size=None x         #
            #                   |                     | decoder_hidden_dim=       #
            #                   |                     | decoder_hidden_dim        #
            # __________________|_____________________|___________________________#
            # If dropout = true then in training mode
            # If dropout = false then in validation mode
            # outer_lstm_output, outer_decoder_hidden_state, \
            #   outer_decoder_cell_state = \
            #   self.layer7_lstm(outer_lstm_input, training=dropout)
            outer_lstm_output, outer_decoder_hidden_state, \
              outer_decoder_cell_state = \
              self.layer7_lstm(outer_lstm_input)

            # Logging, Debug & Assert
            logging.debug("MODEL DRAKE NESTED CALL - Step 5.3 - LSTM output - "
                          "lstm_output shape {}".format(
                              K.int_shape(outer_lstm_output)))
            logging.debug("MODEL DRAKE NESTED CALL - Step 5.3 - LSTM output - "
                          "decoder_hidden_state shape {}".format(
                              K.int_shape(outer_decoder_hidden_state)))
            logging.debug("MODEL DRAKE NESTED CALL - Step 5.3 - LSTM output - "
                          "decoder_cell_state shape {}".format(
                              K.int_shape(outer_decoder_cell_state)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(outer_lstm_output),
                (self.batch_size, 1, self.decoder_hidden_dim))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(outer_decoder_hidden_state),
                (self.batch_size, self.decoder_hidden_dim))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(outer_decoder_cell_state),
                (self.batch_size, self.decoder_hidden_dim))

            # STEP 5.4: Inner attention ###########################################
            # Summed weights    | Context vector      | batch_size=None x         #
            #                   |                     | decoder_hidden_dim=       #
            #                   |                     | decoder_hidden_dim        #
            # __________________|_____________________|___________________________#
            # If in full or outer training mode and reach the end of a sub-sequence
            # feed in outer attention output rather than lstm output into
            # inner attention
            if train_mode != "nested" and (i - 1) % 5 == 0 and i != 1:
                decoder_hidden_state = outer_context_vector

            context_vector, attention_weights = \
              self.model3_attention([input_image_features, decoder_hidden_state],
                                    dropout=dropout)

            # Logging, Debug & Assert
            logging.debug(
                "MODEL DRAKE NESTED CALL - Step 5.4 - Inner attention - "
                "Context vector shape {}".format(K.int_shape(context_vector)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(context_vector),
                (self.batch_size, self.decoder_hidden_dim))

            # STEP 5.5: LSTM Input ################################################
            # Expand +          | LSTM input          | batch_size=None x         #
            # Concatenate       |                     | token_seq_len=1 x         #
            #                   |                     | lstm_input_dim=           #
            #                   |                     | decoder_hidden_dim +      #
            #                   |                     | token_embedding_dim       #
            # __________________|_____________________|___________________________#
            context_vector_expanded = self.layer1_expand_dims(
                context_vector, 1)

            # Logging, Debug & Assert
            logging.debug(
                "MODEL DRAKE NESTED CALL - Step 5.5 - Expand context "
                "vector - context_vector_expanded shape {}".format(
                    K.int_shape(context_vector_expanded)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(context_vector_expanded),
                (self.batch_size, 1, self.decoder_hidden_dim))

            lstm_input = self.layer2_concatenate(
                [context_vector_expanded, decoder_token_input], axis=-1)

            # Logging, Debug & Assert
            logging.debug(
                "MODEL DRAKE NESTED CALL - Step 5.5 - Concat context"
                "vector and token embedding - lstm_input shape {}".format(
                    K.int_shape(lstm_input)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(lstm_input),
                (self.batch_size, 1,
                 self.token_embedding_dim + self.decoder_hidden_dim))

            # STEP 5.6: LSTM ######################################################
            # LSTM return       | LSTM Output         | batch_size=None x         #
            # sequences and     |                     | token_seq_len=1 x         #
            # state             |                     | decoder_hidden_dim=       #
            #                   |                     | decoder_hidden_dim        #
            #                   |                     |                           #
            #                   | LSTM Hidden State   | batch_size=None x         #
            #                   |                     | decoder_hidden_dim=       #
            #                   |                     | decoder_hidden_dim        #
            #                   |                     |                           #
            #                   | LSTM Cell State     | batch_size=None x         #
            #                   |                     | decoder_hidden_dim=       #
            #                   |                     | decoder_hidden_dim        #
            # __________________|_____________________|___________________________#
            # If dropout = true then in training mode
            # If dropout = false then in validation mode
            # lstm_output, decoder_hidden_state, decoder_cell_state = \
            #   self.layer3_lstm(lstm_input, training=dropout)
            lstm_output, decoder_hidden_state, decoder_cell_state = \
              self.layer3_lstm(lstm_input)

            # Logging, Debug & Assert
            logging.debug("MODEL DRAKE NESTED CALL - Step 5.6 - LSTM output - "
                          "lstm_output shape {}".format(
                              K.int_shape(lstm_output)))
            logging.debug("MODEL DRAKE NESTED CALL - Step 5.6 - LSTM output - "
                          "decoder_hidden_state shape {}".format(
                              K.int_shape(decoder_hidden_state)))
            logging.debug("MODEL DRAKE NESTED CALL - Step 5.6 - LSTM output - "
                          "decoder_cell_state shape {}".format(
                              K.int_shape(decoder_cell_state)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(lstm_output),
                (self.batch_size, 1, self.decoder_hidden_dim))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(decoder_hidden_state),
                (self.batch_size, self.decoder_hidden_dim))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(decoder_cell_state),
                (self.batch_size, self.decoder_hidden_dim))

            # STEP 5.7: MLP #######################################################
            # Dense             | Predicted token     | batch_size=None x         #
            #                   |                     | token_vocab_size x        #
            #                   |                     | token_vocab_size          #
            #___________________|_____________________|___________________________#
            mlp_input = self.layer4_concatenate(
                [context_vector_expanded, lstm_output], axis=-1)

            single_token_prediction = self.model4_mlp([mlp_input],
                                                      dropout=dropout)

            # Logging. Debug & Assert
            logging.debug("MODEL DRAKE NESTED CALL - Step 5.7 - MLP output - "
                          "single_token_prediction shape {}".format(
                              K.int_shape(single_token_prediction)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(single_token_prediction),
                (self.batch_size, self.token_vocab_size))

            # STEP 5.8: Calculate loss ############################################
            # Loss              | Single token loss   | int                       #
            #___________________|_____________________|___________________________#
            batch_loss += masked_ce_loss_fn(
                target=input_tokens[:, i],
                prediction=single_token_prediction,
                batch_size=self.batch_size,
                token_vocab_size=self.token_vocab_size)

            # Logging, Debug & Assert
            logging.debug("MODEL DRAKE NESTED CALL - Step 5.8 - "
                          "Single prediction loss {}".format(batch_loss))

            # STEP 5.9 Update decoder input #######################################
            # Decoder input     | New decoder         | batch_size=None x         #
            #                   | hidden state        | decoder_hidden_dim=       #
            #                   |                     | decoder_hidden_dim        #
            #___________________|_____________________|___________________________#
            if val_mode:
                # In validation mode use argmax output from decoder
                argmax_prediction = tf.argmax(single_token_prediction,
                                              axis=1,
                                              output_type=tf.dtypes.int32)
                list_predictions.append(argmax_prediction)
                argmax_prediction_expanded = K.expand_dims(argmax_prediction)
                decoder_token_input = \
                  self.model2_token_embedding([argmax_prediction_expanded])
            else:
                # In training mode use teacher forcing inputs
                decoder_token_input = \
                  K.expand_dims(input_token_embeddings[:, i], 1)

            # Logging, Debug & Assert
            logging.debug(
                "MODEL DRAKE NESTED CALL - Step 5.9 - Update decoder "
                " input - decoder_token_input shape {}".format(
                    K.int_shape(decoder_token_input)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(decoder_token_input),
                (self.batch_size, 1, self.token_embedding_dim))

        # STEP 6: Calculate levenstein distance
        if val_mode:
            stack_predictions = tf.stack(list_predictions, axis=1)
            stack_predictions_len = stack_predictions.shape[1]

            # Logging, Debug & Assert
            logging.debug(
                "MODEL DRAKE NESTED CALL - Step 6 - Stack predictions "
                "shape {}".format(K.int_shape(stack_predictions)))
            tf.compat.v1.debugging.assert_equal(
                K.int_shape(stack_predictions),
                (self.batch_size, stack_predictions_len))

            batch_mean_edit_distance = \
              edit_distance_metric(
                target=input_tokens[:, 1:stack_predictions_len +1],
                prediction=stack_predictions,
                predictions_file=self.predictions_file)

        # STEP 7: Return word sequence batch loss ###############################
        return batch_loss, batch_mean_edit_distance