def body(i, inputs_tuple, all_cache_key, all_cache_value, past_length,
                 initial_id):
            """[This is the body of the top k top p decoder]

            Args:
                i ([tf.tensor]): [iterator (an int)]
                inputs ([List of model inputs]): [description]
                all_cache_key ([K]): [description]
                all_cache_value ([V]): [description]
                past_length ([tf.tensor (1 x batch_size)]): [description]
                This is our main output or decoded ids]
                initial_id ([tf.tensor]): [To keep track of concatanted ids generated
                in each iteration]

            Returns:
                [List of tensors]: [Outputs]
            """
            inputs = {}
            for k in range(len(self.input_name_list)):
                inputs[self.input_name_list[k]] = inputs_tuple[k]
            inputs["all_cache_key"] = all_cache_key
            inputs["all_cache_value"] = all_cache_value
            inputs["past_length"] = past_length

            model_outputs = self.model(inputs)
            model_logits = model_outputs["last_token_logits"]

            if self.top_k > 0:
                model_logits = top_k_logits(model_logits, k=self.top_k)
            if self.top_p > 0:
                model_logits = top_p_logits(model_logits, p=self.top_p)
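            # top_k_logits keeps only the k highest-scoring logits and
            # top_p_logits keeps the smallest nucleus whose cumulative
            # probability exceeds p; both push the discarded logits to a
            # large negative value so they are never sampled.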

            if self.do_sample:
                prediction_ids = tf.random.categorical(model_logits,
                                                       num_samples=1)
                input_ids = tf.cast(prediction_ids, tf.int32)
            else:
                prediction_ids = tf.argmax(model_logits, axis=1)
                input_ids = tf.cast(tf.expand_dims(prediction_ids, axis=1),
                                    tf.int32)

            inputs_tuple = [None] * len(self.input_name_list)

            for index, name in self.input_name_map.items():
                if name == "input_ids":
                    inputs_tuple[index] = input_ids
                elif name == "input_type_ids":
                    inputs_tuple[index] = tf.ones_like(
                        input_ids) * self.input_type_ids
                elif name == "input_mask":
                    inputs_tuple[index] = tf.ones_like(
                        input_ids) * self.input_mask_ids
            # Convert to tuple
            inputs_tuple = tuple(inputs_tuple)
            return [
                i + 1,
                inputs_tuple,
                model_outputs["all_cache_key"],
                model_outputs["all_cache_value"],
                model_outputs["past_length"],
                tf.concat([initial_id, input_ids], axis=1),
            ]

        def call_top_k_top_p(inputs):
            """The main function to perform top-k top-p (nucleus) decoding.

            Args:
                inputs ([dict]): [dict of tf.tensors (model inputs)]
            """
            input_ids_orig = inputs["input_ids"]
            batch_size = tf.shape(inputs["input_ids"])[0]
            max_sequence_length = tf.shape(inputs["input_ids"])[1]

            if self.max_iterations is None:
                iterations = tf.squeeze(inputs["iterations"])
            else:
                iterations = self.max_iterations

            model_inputs = {}
            for input_key, input_value in inputs.items():
                if input_key == "iterations":
                    continue
                model_inputs[input_key] = tf.repeat(
                    input_value, [self.num_return_sequences], axis=0)
            # Updated batch size
            batch_size_updated = tf.shape(model_inputs["input_ids"])[0]

            # Pre-initialize additional inputs
            zero_entry = tf.zeros((
                self.num_hidden_layers,
                batch_size_updated,
                self.num_attention_heads,
                max_sequence_length,
                self.attention_state,
            ))
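            # Cache layout: (num_layers, batch, num_heads, sequence_length,
            # attention_state); one such K and one V entry per layer.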
            all_cache_key = zero_entry
            all_cache_value = zero_entry
            # past_length for keeping track of positional ids
            past_length = tf.expand_dims(
                tf.zeros(batch_size_updated, dtype=tf.int32), 0)
            # Iterator to keep track of the loop
            i = tf.constant([[0]])
            initial_id = tf.ones(shape=(batch_size_updated, 1), dtype=tf.int32)

            # Add remaining model inputs
            model_inputs["all_cache_key"] = all_cache_key
            model_inputs["all_cache_value"] = all_cache_value
            model_inputs["past_length"] = past_length

            if "input_type_ids" in self.input_name_list:
                model_inputs["input_type_ids"] = tf.ones_like(
                    model_inputs["input_ids"]) * self.input_type_ids

            if "input_mask" in self.input_name_list:
                model_inputs["input_mask"] = tf.ones_like(
                    model_inputs["input_ids"]) * self.input_mask_ids

            # First pass through the model
            model_outputs = self.model(model_inputs)
            model_logits = model_outputs["last_token_logits"]

            if self.top_k > 0:
                model_logits = top_k_logits(model_logits, k=self.top_k)
            if self.top_p > 0:
                model_logits = top_p_logits(model_logits, p=self.top_p)

            if self.do_sample:
                prediction_ids = tf.random.categorical(model_logits,
                                                       num_samples=1)
                input_ids = tf.cast(prediction_ids, tf.int32)
            else:
                prediction_ids = tf.argmax(model_logits, axis=1)
                input_ids = tf.cast(tf.expand_dims(prediction_ids, axis=1),
                                    tf.int32)
            inputs_tuple = [None] * len(self.input_name_list)
            input_shapes_tuple = [tf.TensorShape([None, None])] * len(
                self.input_name_list)
            for index, name in self.input_name_map.items():
                if name == "input_ids":
                    inputs_tuple[index] = input_ids
                elif name == "input_type_ids":
                    inputs_tuple[index] = tf.ones_like(
                        input_ids) * self.input_type_ids
                elif name == "input_mask":
                    inputs_tuple[index] = tf.ones_like(
                        input_ids) * self.input_mask_ids

            inputs_tuple = tuple(inputs_tuple)
            input_shapes_tuple = tuple(input_shapes_tuple)

            # Concatenate
            initial_id = tf.concat([initial_id, input_ids], axis=1)

            # On step 0: zero out cache entries at padded (-1) input positions
            masks = tf.cast(tf.not_equal(model_inputs["input_ids"], -1),
                            tf.float32)
            masks = tf.reshape(
                masks,
                (1, batch_size_updated, 1, tf.shape(
                    model_inputs["input_ids"])[1], 1),
            )

            all_cache_key = model_outputs["all_cache_key"]
            all_cache_value = model_outputs["all_cache_value"]
            all_cache_key = all_cache_key * masks
            all_cache_value = all_cache_value * masks
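
            # `cond` is defined elsewhere in this class. A minimal sketch,
            # assuming termination is delegated to maximum_iterations below:
            #
            #     def cond(i, *loop_vars):
            #         return True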

            results = tf.while_loop(
                cond,
                body,
                maximum_iterations=iterations - 1,
                loop_vars=[
                    i,
                    inputs_tuple,
                    all_cache_key,
                    all_cache_value,
                    model_outputs["past_length"],
                    initial_id,
                ],
                shape_invariants=[
                    i.get_shape(),
                    input_shapes_tuple,
                    tf.TensorShape([
                        self.num_hidden_layers,
                        None,
                        self.num_attention_heads,
                        None,
                        self.attention_state,
                    ]),
                    tf.TensorShape([
                        self.num_hidden_layers,
                        None,
                        self.num_attention_heads,
                        None,
                        self.attention_state,
                    ]),
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None]),
                ],
            )
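            # The None dims in shape_invariants let the decoded ids and the
            # cache sequence length grow across iterations while the layer,
            # head and attention_state dims stay fixed.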

            results_dict = {}
            results_dict["iterations"] = results[0]
            results_dict["input_ids"] = input_ids_orig
            # Drop the initial placeholder id
            results_dict["predicted_ids"] = results[-1][:, 1:]
            results_dict["predicted_ids"] = tf.reshape(
                results_dict["predicted_ids"],
                (batch_size, self.num_return_sequences, -1),
            )

            matched_positions = (tf.squeeze(
                tf.reshape(
                    tf.argmax(
                        tf.cast(
                            tf.equal(self.eos_id,
                                     results_dict["predicted_ids"]),
                            tf.int32,
                        ),
                        axis=2,
                    ),
                    (-1, batch_size * self.num_return_sequences),
                ),
                [0],
            ) - 1)
            # Positions with no eos match will be 0; replace them with -1
            eos_pos_mask = tf.cast(tf.equal(matched_positions, 0),
                                   tf.int32) * -1
            matched_positions = tf.cast(matched_positions,
                                        tf.int32) + eos_pos_mask
            results_dict["matched_eos_pos"] = matched_positions

            return results_dict
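
        # Illustrative only (input names assumed, not part of this module):
        #
        #     outputs = call_top_k_top_p({
        #         "input_ids": tf.constant([[64, 139, 1874]], dtype=tf.int32),
        #         "iterations": tf.constant([[32]], dtype=tf.int32),
        #     })
        #     # outputs["predicted_ids"]: (batch, num_return_sequences, steps)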

        def call_beam(inputs):
            """The main function to perform beam-search decoding.

            Args:
                inputs ([dict]): [dict of tf.tensors (model inputs)]
            """
            input_ids_orig = inputs["input_ids"]
            # We take 2x beams
            beams_to_keep = 2 * self.beam_size
            # Original batch size and sequence length
            batch_size = tf.shape(inputs["input_ids"])[0]
            max_sequence_length = tf.shape(inputs["input_ids"])[1]
            # Repeat for beam search (we need batch_size x beam_size)
            model_inputs = {}
            for input_key, input_value in inputs.items():
                if input_key == "iterations":
                    continue
                model_inputs[input_key] = tf.repeat(input_value,
                                                    [self.beam_size],
                                                    axis=0)
            # New batch size
            batch_size_updated = tf.shape(model_inputs["input_ids"])[0]

            # Pre-initialize additional inputs
            zero_entry = tf.zeros((
                self.num_hidden_layers,
                batch_size_updated,
                self.num_attention_heads,
                max_sequence_length,
                self.attention_state,
            ))
            all_cache_key = zero_entry
            all_cache_value = zero_entry
            # past_length for keeping track of positional ids
            past_length = tf.expand_dims(
                tf.zeros(batch_size_updated, dtype=tf.int32), 0)
            # Iterator to keep track of the loop
            i = tf.constant([[0]])

            if self.max_iterations is None:
                iterations = tf.squeeze(inputs["iterations"])
            else:
                iterations = self.max_iterations

            # Add remaining model inputs
            model_inputs["all_cache_key"] = all_cache_key
            model_inputs["all_cache_value"] = all_cache_value
            model_inputs["past_length"] = past_length

            if "input_type_ids" in self.input_name_list:
                model_inputs["input_type_ids"] = tf.ones_like(
                    model_inputs["input_ids"]) * self.input_type_ids

            if "input_mask" in self.input_name_list:
                model_inputs["input_mask"] = tf.ones_like(
                    model_inputs["input_ids"]) * self.input_mask_ids

            # We need these to re-order beams and keep track of the best log probs
            alive_log_probs = -np.inf * tf.ones(
                (batch_size, self.beam_size - 1))
            alive_log_probs = tf.concat(
                [tf.zeros([batch_size, 1]), alive_log_probs], axis=1)
            alive_seq = tf.zeros((batch_size, self.beam_size, 1))
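            # Only the first beam starts at log-prob 0; the others start at
            # -inf so step 0 does not pick duplicate candidates across beams.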

            # First pass through the model
            model_outputs = self.model(model_inputs)
            model_logits = model_outputs["last_token_logits"] / self.temperature
            # Update iter
            i = i + 1
            all_cache_key = model_outputs["all_cache_key"]
            all_cache_value = model_outputs["all_cache_value"]
            past_length = model_outputs["past_length"]

            if self.top_k > 0:
                model_logits = top_k_logits(model_logits, k=self.top_k)
            if self.top_p > 0:
                model_logits = top_p_logits(model_logits, p=self.top_p)

            # vocab size
            vocab_size = tf.shape(model_logits)[1]
            logits = tf.reshape(model_logits, (batch_size, self.beam_size, -1))
            # Convert logits to normalized log probs
            candidate_log_probs = _log_prob_from_logits(logits)

            # Calculate new log probabilities if each of the alive sequences
            # were extended by the candidate ids.
            # Shape [batch_size, beam_size, vocab_size]
            log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs,
                                                             axis=2)

            # Add length penalty
            length_penalty = tf.pow(
                ((5.0 + (tf.cast(i, tf.float32) + 1.0)) / 6.0), self.alpha)
            log_probs = log_probs / length_penalty
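            # This is the GNMT-style length penalty ((5 + step) / 6) ** alpha;
            # alpha = 0.0 disables it, larger alpha favours longer sequences.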

            # Each batch item has beam_size * vocab_size candidate sequences. For each
            # batch item, get the k candidates with the highest log probabilities.
            flat_log_probs = tf.reshape(log_probs,
                                        [-1, self.beam_size * vocab_size])

            if self.do_sample:
                next_tokens = tf.random.categorical(
                    flat_log_probs, dtype=tf.int32,
                    num_samples=beams_to_keep)  # (batch_size, 2 * num_beams)

                # Compute next scores
                next_scores = tf.gather(
                    flat_log_probs, next_tokens,
                    batch_dims=1)  # (batch_size, 2 * num_beams)

                # Sort the sampled vector to make sure that the first
                # num_beams samples are the best
                next_scores_indices = tf.argsort(next_scores,
                                                 direction="DESCENDING",
                                                 axis=1)
                next_scores = tf.gather(
                    next_scores, next_scores_indices,
                    batch_dims=1)  # (batch_size, num_beams * 2)
                next_tokens = tf.gather(
                    next_tokens, next_scores_indices,
                    batch_dims=1)  # (batch_size, num_beams * 2)

                topk_log_probs = next_scores
                topk_indices = next_tokens
            else:
                topk_log_probs, topk_indices = tf.nn.top_k(flat_log_probs,
                                                           k=beams_to_keep)

            topk_beam_indices = topk_indices // vocab_size
            topk_seq, coordinates = _gather_beams(alive_seq, topk_beam_indices,
                                                  batch_size, beams_to_keep)
            topk_seq = tf.cast(topk_seq, tf.int32)
            topk_ids = topk_indices % vocab_size
            topk_seq = tf.concat(
                [topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

            topk_alive_seq = topk_seq[:, :self.beam_size, :]
            alive_log_probs = topk_log_probs[:, :self.beam_size]
            input_ids = tf.reshape(topk_ids[:, :self.beam_size], [-1, 1])
            alive_seq = topk_alive_seq
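            # Keep only the top beam_size of the 2 * beam_size candidates as
            # the alive set; the extra candidates give headroom in case some
            # high-scoring candidates correspond to eos.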

            inputs_tuple = [None] * len(self.input_name_list)
            input_shapes_tuple = [tf.TensorShape([None, None])] * len(
                self.input_name_list)
            for index, name in self.input_name_map.items():
                if name == "input_ids":
                    inputs_tuple[index] = input_ids
                    # input_shapes_tuple.append(tf.TensorShape([None, None]))
                    continue
                if name == "input_type_ids":
                    inputs_tuple[index] = tf.ones_like(
                        input_ids) * self.input_type_ids
                    # input_shapes_tuple.append(tf.TensorShape([None, None]))
                    continue
                if name == "input_mask":
                    inputs_tuple[index] = tf.ones_like(
                        input_ids) * self.input_mask_ids
                    # input_shapes_tuple.append(tf.TensorShape([None, None]))
                    continue

            inputs_tuple = tuple(inputs_tuple)
            input_shapes_tuple = tuple(input_shapes_tuple)

            # On step 0: zero out cache entries at padded (-1) input positions
            masks = tf.cast(tf.not_equal(model_inputs["input_ids"], -1),
                            tf.float32)
            masks = tf.reshape(
                masks,
                (1, batch_size_updated, 1, tf.shape(
                    model_inputs["input_ids"])[1], 1),
            )
            all_cache_key = all_cache_key * masks
            all_cache_value = all_cache_value * masks

            all_cache_key, all_cache_value = self.reorder_past_batches(
                all_cache_key, all_cache_value, coordinates, self.beam_size)
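            # reorder_past_batches gathers the cached K/V along the batch axis
            # using `coordinates`, so each surviving beam continues from the
            # cache of the beam it extended.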

            results = tf.while_loop(
                cond,
                body,
                maximum_iterations=iterations - 1,
                loop_vars=[
                    i,
                    inputs_tuple,
                    all_cache_key,
                    all_cache_value,
                    past_length,
                    alive_log_probs,
                    alive_seq,
                ],
                shape_invariants=[
                    i.get_shape(),
                    input_shapes_tuple,
                    tf.TensorShape([
                        self.num_hidden_layers,
                        None,
                        self.num_attention_heads,
                        None,
                        self.attention_state,
                    ]),
                    tf.TensorShape([
                        self.num_hidden_layers,
                        None,
                        self.num_attention_heads,
                        None,
                        self.attention_state,
                    ]),
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None, None]),
                ],
            )

            results_dict = {}
            results_dict["iterations"] = results[0]
            results_dict["input_ids"] = input_ids_orig
            # Drop the initial placeholder id
            results_dict["predicted_ids"] = results[-1][:, :, 1:]

            matched_positions = (tf.squeeze(
                tf.reshape(
                    tf.argmax(
                        tf.cast(
                            tf.equal(self.eos_id,
                                     results_dict["predicted_ids"]),
                            tf.int32,
                        ),
                        axis=2,
                    ),
                    (-1, batch_size * self.beam_size),
                ),
                [0],
            ) - 1)
            # Positions with no eos match will be 0; replace them with -1
            eos_pos_mask = tf.cast(tf.equal(matched_positions, 0),
                                   tf.int32) * -1
            matched_positions = tf.cast(matched_positions,
                                        tf.int32) + eos_pos_mask
            results_dict["matched_eos_pos"] = matched_positions

            return results_dict

        def body(
            i,
            inputs_tuple,
            all_cache_key,
            all_cache_value,
            past_length,
            alive_log_probs,
            alive_seq,
        ):
            """[This is the body of the beam decoder]

            Args:
                i ([tf.tensor]): [iterator (an int)]
                inputs_tuple ([tuple of tf.tensors]): [model inputs, ordered
                as in self.input_name_list]
                all_cache_key ([tf.tensor]): [cached attention keys (K)]
                all_cache_value ([tf.tensor]): [cached attention values (V)]
                past_length ([tf.tensor (1 x batch_size)]): [lengths decoded
                so far, used for positional ids]
                alive_log_probs ([tf.tensor]): [log probs of the alive beams]
                alive_seq ([tf.tensor]): [decoded ids of the alive beams;
                this is the main output]

            Returns:
                [List of tensors]: [updated loop variables]
            """
            inputs = {}
            for k in range(len(self.input_name_list)):
                inputs[self.input_name_list[k]] = inputs_tuple[k]
            inputs["all_cache_key"] = all_cache_key
            inputs["all_cache_value"] = all_cache_value
            inputs["past_length"] = past_length

            beams_to_keep = 2 * self.beam_size
            model_outputs = self.model(inputs)

            model_logits = model_outputs["last_token_logits"]

            all_cache_key = model_outputs["all_cache_key"]
            all_cache_value = model_outputs["all_cache_value"]
            past_length = model_outputs["past_length"]

            if self.top_k > 0:
                model_logits = top_k_logits(model_logits, k=self.top_k)
            if self.top_p > 0:
                model_logits = top_p_logits(model_logits, p=self.top_p)

            vocab_size = tf.shape(model_logits)[1]
            batch_size = tf.shape(inputs["input_ids"])[0] // self.beam_size
            logits = tf.reshape(model_logits, (batch_size, self.beam_size, -1))
            # Convert logits to normalized log probs
            candidate_log_probs = _log_prob_from_logits(logits)

            # Calculate new log probabilities if each of the alive sequences
            # were extended by the candidate ids.
            # Shape [batch_size, beam_size, vocab_size]
            log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs,
                                                             axis=2)

            # Add length penalty
            length_penalty = tf.pow(
                ((5.0 + (tf.cast(i, tf.float32) + 1.0)) / 6.0), self.alpha)
            log_probs = log_probs / length_penalty

            # Each batch item has beam_size * vocab_size candidate sequences. For each
            # batch item, get the k candidates with the highest log probabilities.
            flat_log_probs = tf.reshape(log_probs,
                                        [-1, self.beam_size * vocab_size])

            if self.do_sample:
                next_tokens = tf.random.categorical(
                    flat_log_probs, dtype=tf.int32,
                    num_samples=beams_to_keep)  # (batch_size, 2 * num_beams)

                # Compute next scores
                next_scores = tf.gather(
                    flat_log_probs, next_tokens,
                    batch_dims=1)  # (batch_size, 2 * num_beams)

                # Sort the sampled vector to make sure that the first
                # num_beams samples are the best
                next_scores_indices = tf.argsort(next_scores,
                                                 direction="DESCENDING",
                                                 axis=1)
                next_scores = tf.gather(
                    next_scores, next_scores_indices,
                    batch_dims=1)  # (batch_size, num_beams * 2)
                next_tokens = tf.gather(
                    next_tokens, next_scores_indices,
                    batch_dims=1)  # (batch_size, num_beams * 2)

                topk_log_probs = next_scores
                topk_indices = next_tokens
            else:
                topk_log_probs, topk_indices = tf.nn.top_k(flat_log_probs,
                                                           k=beams_to_keep)

            topk_beam_indices = topk_indices // vocab_size
            topk_seq, coordinates = _gather_beams(alive_seq, topk_beam_indices,
                                                  batch_size, beams_to_keep)
            topk_seq = tf.cast(topk_seq, tf.int32)
            topk_ids = topk_indices % vocab_size
            topk_seq = tf.concat(
                [topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

            topk_alive_seq = topk_seq[:, :self.beam_size, :]
            alive_log_probs = topk_log_probs[:, :self.beam_size]
            input_ids = tf.reshape(topk_ids[:, :self.beam_size], [-1, 1])
            alive_seq = topk_alive_seq

            inputs_tuple = [None] * len(self.input_name_list)

            for index, name in self.input_name_map.items():
                if name == "input_ids":
                    inputs_tuple[index] = input_ids
                elif name == "input_type_ids":
                    inputs_tuple[index] = tf.ones_like(
                        input_ids) * self.input_type_ids
                elif name == "input_mask":
                    inputs_tuple[index] = tf.ones_like(
                        input_ids) * self.input_mask_ids
            # Convert to tuple
            inputs_tuple = tuple(inputs_tuple)

            all_cache_key, all_cache_value = self.reorder_past_batches(
                all_cache_key, all_cache_value, coordinates, self.beam_size)
            model_outputs["all_cache_key"] = all_cache_key
            model_outputs["all_cache_value"] = all_cache_value

            return [
                i + 1,
                inputs_tuple,
                model_outputs["all_cache_key"],
                model_outputs["all_cache_value"],
                model_outputs["past_length"],
                alive_log_probs,
                alive_seq,
            ]

    def top_k_top_p(
        self,
        tokenized_input_dict,
        max_iterations,
        top_k=0,
        top_p=0,
        temperature=1.0,
        do_sample=True,
        num_return_sequences=1,
        eos_id=-100,
    ):
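        """Performs top-k top-p (nucleus) decoding for encoder-decoder models.

        tokenized_input_dict: dict of model inputs (encoder_input_ids etc.)
        max_iterations: number of steps to decode
        top_k: keep only the k highest logits (0 disables)
        top_p: nucleus probability mass (0 disables)
        temperature: softmax temperature applied to the logits
        do_sample: sample from the filtered distribution instead of argmax
        num_return_sequences: number of generated sequences per input
        eos_id: id treated as the decoder stop token
        """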

        # We need this to return
        input_ids_original = tokenized_input_dict["encoder_input_ids"]
        batch_size = tf.shape(tokenized_input_dict["encoder_input_ids"])[0]

        # Repeat inputs num_return_sequences times
        tokenized_input_dict_ragged = {}
        for input_key, input_value in tokenized_input_dict.items():
            tokenized_input_dict_ragged[input_key] = tf.repeat(
                input_value, [num_return_sequences], axis=0)

        # Updated batch size
        batch_size_updated = tf.shape(
            tokenized_input_dict_ragged["encoder_input_ids"])[0]

        # Initialize with zeros
        encoder_sequence_length = tf.shape(
            tokenized_input_dict_ragged["encoder_input_ids"])[1]
        decoder_start_sequence_length = 1

        encoder_hidden_states = tf.zeros(
            (batch_size_updated, encoder_sequence_length, self.embedding_size))
        all_cache_key = tf.zeros((
            self.decoder_num_hidden_layers,
            batch_size_updated,
            self.decoder_num_attention_heads,
            decoder_start_sequence_length,
            self.decoder_attention_state,
        ))
        all_cache_value = tf.zeros((
            self.decoder_num_hidden_layers,
            batch_size_updated,
            self.decoder_num_attention_heads,
            decoder_start_sequence_length,
            self.decoder_attention_state,
        ))

        # Inputs ready
        tokenized_input_dict_ragged["decoder_input_ids"] = tf.cast(
            tf.ones(shape=(batch_size_updated, 1)) *
            self.decode_start_token_id,
            tf.int32,
        )
        tokenized_input_dict_ragged["decoder_all_cache_key"] = all_cache_key
        tokenized_input_dict_ragged[
            "decoder_all_cache_value"] = all_cache_value
        tokenized_input_dict_ragged[
            "encoder_hidden_states"] = encoder_hidden_states

        if self.decoder_input_type_ids > -1:
            tokenized_input_dict_ragged["decoder_input_type_ids"] = (
                tf.ones_like(tokenized_input_dict_ragged["decoder_input_ids"])
                * self.decoder_input_type_ids)

        all_predictions = []
        matched_positions = tf.constant([-1] * batch_size_updated)

        # Iterate over decoding steps
        for i in range(max_iterations):
            result = self.model_fn(tokenized_input_dict_ragged)

            model_logits = result["last_token_logits"]
            all_cache_key = result["decoder_all_cache_key"]
            all_cache_value = result["decoder_all_cache_value"]
            encoder_hidden_states = result["encoder_hidden_states"]

            model_logits = model_logits / temperature

            if top_k > 0:
                model_logits = top_k_logits(model_logits, k=top_k)
            if top_p > 0:
                model_logits = top_p_logits(model_logits, p=top_p)

            if do_sample:
                prediction_ids = tf.random.categorical(model_logits,
                                                       num_samples=1)
                input_ids = tf.cast(prediction_ids, tf.int32)
            else:
                prediction_ids = tf.argmax(model_logits, axis=1)
                input_ids = tf.cast(tf.expand_dims(prediction_ids, axis=1),
                                    tf.int32)

            all_predictions.append(input_ids)

            tokenized_input_dict_ragged["decoder_input_ids"] = tf.cast(
                input_ids, tf.int32)
            tokenized_input_dict_ragged[
                "decoder_all_cache_key"] = all_cache_key
            tokenized_input_dict_ragged[
                "decoder_all_cache_value"] = all_cache_value
            tokenized_input_dict_ragged[
                "encoder_hidden_states"] = encoder_hidden_states

            eos_check = tf.greater(
                tf.reduce_prod(
                    tf.reduce_sum(
                        tf.cast(
                            tf.equal(tf.concat(all_predictions, axis=1),
                                     eos_id), tf.int32),
                        axis=[1],
                    )),
                0,
            )
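            # reduce_sum counts eos hits per sequence; reduce_prod across the
            # batch is > 0 only once every sequence has produced at least one
            # eos, allowing early exit.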
            if eos_check:
                break

        matched_positions = (tf.reshape(
            tf.argmax(
                tf.cast(tf.equal(eos_id, tf.concat(all_predictions, axis=1)),
                        tf.int32),
                axis=1,
            ),
            -1,
        ) - 1)
        # Positions with no eos match will be 0; replace them with -1
        eos_pos_mask = tf.cast(tf.equal(matched_positions, 0), tf.int32) * -1
        matched_positions = tf.cast(matched_positions, tf.int32) + eos_pos_mask

        all_predictions = tf.reshape(tf.concat(all_predictions, axis=1),
                                     (batch_size, num_return_sequences, -1))

        return {
            "iterations": i + 1,
            "input_ids": input_ids_original,
            "predicted_ids": all_predictions,
            "matched_eos_pos": matched_positions,
        }
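
    # Illustrative only (receiver and input names assumed): a hypothetical
    # call
    #
    #     results = model.top_k_top_p(
    #         {"encoder_input_ids": ids, "encoder_input_mask": mask},
    #         max_iterations=32, top_k=50, top_p=0.95)
    #
    # returns a dict with "predicted_ids" of shape
    # (batch, num_return_sequences, steps) and "matched_eos_pos".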

    def beam_decode(
        self,
        tokenized_input_dict,
        beam_size,
        max_iterations,
        temperature=1.0,
        alpha=0.0,
        top_k=0,
        top_p=0,
        do_sample=False,
        eos_id=-100,
    ):
        """Supports Variable Batch Decoding for GPT2

        text_list: a list of text
        beam_size: int
        length: number of steps to decode
        vocab_size: vocabulary size
        do_sample: Using multinomial distribution to \
            sample the most likely word, still uses beam
        eos_ids: list of IDS, to consider as decoder stop
        """

        # We need this to return
        input_ids_original = tokenized_input_dict["encoder_input_ids"]
        batch_size = tf.shape(tokenized_input_dict["encoder_input_ids"])[0]

        tokenized_input_dict_ragged = {}
        # Repeat for beam search
        for input_key, input_value in tokenized_input_dict.items():
            tokenized_input_dict_ragged[input_key] = tf.repeat(input_value,
                                                               [beam_size],
                                                               axis=0)

        # We take 2x beams
        beams_to_keep = 2 * beam_size
        batch_size_updated = tf.shape(
            tokenized_input_dict_ragged["encoder_input_ids"])[0]

        # Initialize with zeros
        encoder_sequence_length = tf.shape(
            tokenized_input_dict_ragged["encoder_input_ids"])[1]
        decoder_start_sequence_length = 1

        encoder_hidden_states = tf.zeros(
            (batch_size_updated, encoder_sequence_length, self.embedding_size))
        all_cache_key = tf.zeros((
            self.decoder_num_hidden_layers,
            batch_size_updated,
            self.decoder_num_attention_heads,
            decoder_start_sequence_length,
            self.decoder_attention_state,
        ))
        all_cache_value = tf.zeros((
            self.decoder_num_hidden_layers,
            batch_size_updated,
            self.decoder_num_attention_heads,
            decoder_start_sequence_length,
            self.decoder_attention_state,
        ))

        # Inputs ready
        tokenized_input_dict_ragged["decoder_input_ids"] = tf.cast(
            tf.ones(shape=(batch_size_updated, 1)) *
            self.decode_start_token_id,
            tf.int32,
        )
        tokenized_input_dict_ragged["decoder_all_cache_key"] = all_cache_key
        tokenized_input_dict_ragged[
            "decoder_all_cache_value"] = all_cache_value
        tokenized_input_dict_ragged[
            "encoder_hidden_states"] = encoder_hidden_states

        if self.decoder_input_type_ids > -1:
            tokenized_input_dict_ragged["decoder_input_type_ids"] = (
                tf.ones_like(tokenized_input_dict_ragged["decoder_input_ids"])
                * self.decoder_input_type_ids)

        matched_positions = tf.constant([-1] * batch_size_updated)

        alive_log_probs = -np.inf * tf.ones((batch_size, beam_size - 1))
        alive_log_probs = tf.concat(
            [tf.zeros([batch_size, 1]), alive_log_probs], axis=1)
        alive_seq = tf.zeros((batch_size, beam_size, 1))

        for i in range(max_iterations):

            result = self.model_fn(tokenized_input_dict_ragged)

            model_logits = result["last_token_logits"]
            all_cache_key = result["decoder_all_cache_key"]
            all_cache_value = result["decoder_all_cache_value"]
            encoder_hidden_states = result["encoder_hidden_states"]

            model_logits = model_logits / temperature

            if top_k > 0:
                model_logits = top_k_logits(model_logits, k=top_k)
            if top_p > 0:
                model_logits = top_p_logits(model_logits, p=top_p)

            vocab_size = tf.shape(model_logits)[1]
            logits = tf.reshape(model_logits, (batch_size, beam_size, -1))
            # Convert logits to normalized log probs
            candidate_log_probs = _log_prob_from_logits(logits)

            # Calculate new log probabilities if each of the alive sequences
            # were extended by the candidate ids.
            # Shape [batch_size, beam_size, vocab_size]
            log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs,
                                                             axis=2)

            # Add length penalty
            length_penalty = tf.pow(
                ((5.0 + (tf.cast(i, tf.float32) + 1.0)) / 6.0), alpha)
            log_probs = log_probs / length_penalty
            # Each batch item has beam_size * vocab_size candidate sequences. For each
            # batch item, get the k candidates with the highest log probabilities.
            flat_log_probs = tf.reshape(log_probs,
                                        [-1, beam_size * vocab_size])

            if do_sample:
                next_tokens = tf.random.categorical(
                    flat_log_probs, dtype=tf.int32,
                    num_samples=beams_to_keep)  # (batch_size, 2 * num_beams)

                # Compute next scores
                next_scores = tf.gather(
                    flat_log_probs, next_tokens,
                    batch_dims=1)  # (batch_size, 2 * num_beams)

                # Sort the sampled vector to make sure that the first
                # num_beams samples are the best
                next_scores_indices = tf.argsort(next_scores,
                                                 direction="DESCENDING",
                                                 axis=1)
                next_scores = tf.gather(
                    next_scores, next_scores_indices,
                    batch_dims=1)  # (batch_size, num_beams * 2)
                next_tokens = tf.gather(
                    next_tokens, next_scores_indices,
                    batch_dims=1)  # (batch_size, num_beams * 2)

                topk_log_probs = next_scores
                topk_indices = next_tokens
            else:
                topk_log_probs, topk_indices = tf.nn.top_k(
                    flat_log_probs,
                    k=beams_to_keep)  # (batch_size x k (beams_to_keep))

            topk_beam_indices = topk_indices // vocab_size
            topk_seq, coordinates = _gather_beams(alive_seq, topk_beam_indices,
                                                  batch_size, beams_to_keep)
            topk_seq = tf.cast(topk_seq, tf.int32)
            topk_ids = topk_indices % vocab_size
            topk_seq = tf.concat(
                [topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

            topk_alive_seq = topk_seq[:, :beam_size, :]
            alive_log_probs = topk_log_probs[:, :beam_size]
            input_ids = tf.reshape(topk_ids[:, :beam_size], [-1, 1])
            alive_seq = topk_alive_seq

            all_cache_key, all_cache_value = self.reorder_past_batches(
                all_cache_key, all_cache_value, coordinates, beam_size)

            tokenized_input_dict_ragged["decoder_input_ids"] = tf.cast(
                input_ids, tf.int32)
            tokenized_input_dict_ragged[
                "decoder_all_cache_key"] = all_cache_key
            tokenized_input_dict_ragged[
                "decoder_all_cache_value"] = all_cache_value
            tokenized_input_dict_ragged[
                "encoder_hidden_states"] = encoder_hidden_states

            eos_check = tf.greater(
                tf.reduce_prod(
                    tf.reduce_sum(
                        tf.cast(tf.equal(topk_alive_seq, eos_id), tf.int32),
                        axis=[2],
                    )),
                0,
            )
            if eos_check:
                break

        matched_positions = (tf.reshape(
            tf.argmax(tf.cast(tf.equal(eos_id, topk_alive_seq), tf.int32),
                      axis=2),
            -1,
        ) - 1)
        # Positions with no eos match will be 0; replace them with -1
        eos_pos_mask = tf.cast(tf.equal(matched_positions, 0), tf.int32) * -1
        matched_positions = tf.cast(matched_positions, tf.int32) + eos_pos_mask

        return {
            "iterations": i + 1,
            "input_ids": input_ids_original,
            "predicted_ids": topk_alive_seq[:, :, 1:],  # to avoid initial 0
            "matched_eos_pos": matched_positions - 1,
        }

    def top_k_top_p(
        self,
        tokenized_input_dict,
        max_iterations,
        top_k=0,
        top_p=0,
        temperature=1.0,
        do_sample=True,
        num_return_sequences=1,
        eos_id=-100,
    ):
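        """Performs top-k top-p (nucleus) decoding for decoder-only models.

        tokenized_input_dict: dict of model inputs (input_ids etc.)
        max_iterations: maximum number of steps to decode
        top_k: keep only the k highest logits (0 disables)
        top_p: nucleus probability mass (0 disables)
        temperature: softmax temperature applied to the logits
        do_sample: sample from the filtered distribution instead of argmax
        num_return_sequences: number of generated sequences per input
        eos_id: id treated as the decoder stop token
        """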

        # We need this to return
        input_ids_original = tokenized_input_dict["input_ids"]
        batch_size = tf.shape(input_ids_original)[0]
        max_sequence_length = tf.shape(input_ids_original)[1]

        # Repeat inputs num_return_sequences times
        tokenized_input_dict_ragged = {}
        for input_key, input_value in tokenized_input_dict.items():
            tokenized_input_dict_ragged[input_key] = tf.repeat(input_value, [num_return_sequences], axis=0)

        # Updated batch size
        batch_size_updated = tf.shape(tokenized_input_dict_ragged["input_ids"])[0]

        # Initialize with zeros
        zero_entry = tf.zeros(
            (
                self.num_hidden_layers,
                batch_size_updated,
                self.num_attention_heads,
                max_sequence_length,
                self.attention_state,
            )
        )
        all_cache_key = zero_entry
        all_cache_value = zero_entry
        past_length = tf.expand_dims(tf.zeros(batch_size_updated, dtype=tf.int32), 0)

        # Inputs ready
        tokenized_input_dict_ragged["all_cache_key"] = all_cache_key
        tokenized_input_dict_ragged["all_cache_value"] = all_cache_value
        tokenized_input_dict_ragged["past_length"] = past_length

        if self.input_type_ids > -1:
            tokenized_input_dict_ragged["input_type_ids"] = (
                tf.ones_like(tokenized_input_dict_ragged["input_ids"]) * self.input_type_ids
            )
        if self.input_mask_ids > -1:
            tokenized_input_dict_ragged["input_mask"] = (
                tf.ones_like(tokenized_input_dict_ragged["input_ids"]) * self.input_mask_ids
            )

        all_predictions = []
        matched_positions = tf.constant([-1] * batch_size_updated)

        # Iterate over decoding steps
        for i in range(max_iterations):
            result = self.model_fn(tokenized_input_dict_ragged)
            model_logits = result["last_token_logits"]
            all_cache_key = result["all_cache_key"]
            all_cache_value = result["all_cache_value"]
            past_length = result["past_length"]

            if top_k > 0:
                model_logits = top_k_logits(model_logits, k=top_k)
            if top_p > 0:
                model_logits = top_p_logits(model_logits, p=top_p)

            if do_sample:
                prediction_ids = tf.random.categorical(model_logits, num_samples=1)
                input_ids = tf.cast(prediction_ids, tf.int32)
            else:
                prediction_ids = tf.argmax(model_logits, axis=1)
                input_ids = tf.cast(tf.expand_dims(prediction_ids, axis=1), tf.int32)
            all_predictions.append(input_ids)
            if i == 0:
                # Zero out cache entries at padded (-1) input positions
                masks = tf.cast(
                    tf.not_equal(tokenized_input_dict_ragged["input_ids"], -1),
                    tf.float32,
                )
                masks = tf.reshape(masks, (1, batch_size_updated, 1, max_sequence_length, 1))
                all_cache_key = all_cache_key * masks
                all_cache_value = all_cache_value * masks

            tokenized_input_dict_ragged["input_ids"] = tf.cast(input_ids, tf.int32)
            tokenized_input_dict_ragged["all_cache_key"] = all_cache_key
            tokenized_input_dict_ragged["all_cache_value"] = all_cache_value
            tokenized_input_dict_ragged["past_length"] = past_length

            if self.input_type_ids > -1:
                tokenized_input_dict_ragged["input_type_ids"] = (
                    tf.ones_like(tokenized_input_dict_ragged["input_ids"]) * self.input_type_ids
                )
            if self.input_mask_ids > -1:
                tokenized_input_dict_ragged["input_mask"] = (
                    tf.ones_like(tokenized_input_dict_ragged["input_ids"]) * self.input_mask_ids
                )

            if eos_id:
                temp_m = tf.concat(all_predictions, axis=1)
                eos_check = tf.greater(
                    tf.reduce_prod(tf.reduce_sum(tf.cast(tf.equal(temp_m, eos_id), tf.int32), axis=[1])),
                    0,
                )
                if eos_check:
                    matched_positions = tf.argmax(tf.cast(tf.equal(eos_id, temp_m), tf.int32), axis=1)
                    break

        matched_positions = (
            tf.reshape(
                tf.argmax(
                    tf.cast(tf.equal(eos_id, tf.concat(all_predictions, axis=1)), tf.int32),
                    axis=1,
                ),
                -1,
            )
            - 1
        )
        # Positions with no eos match will be 0; replace them with -1
        eos_pos_mask = tf.cast(tf.equal(matched_positions, 0), tf.int32) * -1
        matched_positions = tf.cast(matched_positions, tf.int32) + eos_pos_mask

        all_predictions = tf.reshape(tf.concat(all_predictions, axis=1), (batch_size, num_return_sequences, -1))
        return {
            "iterations": i + 1,
            "input_ids": input_ids_original,
            "predicted_ids": all_predictions,
            "matched_eos_pos": matched_positions,
        }