    def predict(self, features, max_decode_len, beam_size, **beam_kwargs):
        """Predicts output token ids by left-to-right decoding."""
        cache = self._encode(features, False)
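        # Encode the inputs once (training=False); cache["memory"] holds the
        # encoder outputs with shape [batch, input_len, hidden_size].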
        B, _, D = cache["memory"].shape
        T, V, H = max_decode_len, self._vocab_size, self._num_heads

        bias_1xTxT = attention.upper_triangle_bias(T, self._dtype)
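        # Pre-allocate a per-layer key/value cache for incremental decoding;
        # each entry has shape [batch, num_heads, max_decode_len, head_dim].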
        for i in range(len(self._decoder_layers)):
            cache[str(i)] = {
                "k": tf.zeros([B, H, T, D // H], self._dtype),
                "v": tf.zeros([B, H, T, D // H], self._dtype)
            }

        def symbols_to_logits_fn(dec_BxT, context, i):
            """Returns next-token logits for decode step i."""
            dec_Bx1 = tf.slice(dec_BxT, [0, tf.maximum(tf.cast(0, i.dtype), i - 1)],
                               [dec_BxT.shape[0], 1])
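            # Only the most recently generated token is embedded at each step;
            # the slice index is clamped to 0 on the first step.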
            bias_1x1xT = tf.slice(bias_1xTxT, [0, i, 0], [1, 1, T])
            dec_Bx1xD = self._embedding_layer(dec_Bx1, True)
            dec_Bx1xD *= tf.cast(tf.greater(i, 0), self._dtype)
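            # Step 0 has no previous token, so its embedding is zeroed out,
            # mirroring the zero-padded first position used during training.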
            dec_Bx1xD = timing.add_time_signal(dec_Bx1xD, start_index=i)
            with tf.variable_scope(self._decoder_scope_name, reuse=tf.AUTO_REUSE):
                dec_Bx1xD = transformer_block.stack(self._decoder_layers, False,
                                                    dec_Bx1xD, bias_1x1xT,
                                                    context["memory"],
                                                    context["memory_bias"], context, i)
                dec_Bx1xD = contrib_layers.layer_norm(dec_Bx1xD, begin_norm_axis=2)
            logits_Bx1xV = self._embedding_layer(dec_Bx1xD, False)
            logits_BxV = tf.squeeze(logits_Bx1xV, axis=1)
            return logits_BxV

        # Decode left-to-right over max_decode_len steps using the incremental
        # cache above.
        decodes_BxT = decoding.left2right_decode(symbols_to_logits_fn, cache, B, T,
                                                 V, beam_size, **beam_kwargs)
        return {"outputs": decodes_BxT}

    def __call__(self, features, training):
        """Create model.

        Args:
          features: dictionary of tensors including "inputs" [batch, input_len]
            and "targets" [batch, output_len].
          training: bool indicating whether the model is in training mode.

        Returns:
          Tuple of (loss, outputs): loss is a scalar cross-entropy loss; outputs
            is a dictionary of tensors containing the output logits along with
            the targets, one-hot targets, decoder hidden states, and encoder
            context.
        """
        if "inputs" not in features or "targets" not in features:
            raise ValueError("Require inputs and targets keys in features.")

        context = self._encode(features, training)
        self._context = context
        targets_BxT = features["targets"]
        bias_1xTxT = attention.upper_triangle_bias(
            tf.shape(targets_BxT)[1], self._dtype)
        states_BxTxD = self._embedding_layer(targets_BxT, True)
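        # Shift the target embeddings right by one step (prepend a zero vector,
        # drop the last position) so each position is predicted only from the
        # preceding tokens.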
        states_BxTxD = tf.pad(states_BxTxD,
                              [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
        states_BxTxD = timing.add_time_signal(states_BxTxD)
        states_BxTxD = self._dropout_fn(states_BxTxD, training)
        with tf.variable_scope(self._decoder_scope_name, reuse=tf.AUTO_REUSE):
            states_BxTxD = transformer_block.stack(self._decoder_layers,
                                                   training, states_BxTxD,
                                                   bias_1xTxT,
                                                   context["memory"],
                                                   context["memory_bias"])
            states_BxTxD = contrib_layers.layer_norm(states_BxTxD,
                                                     begin_norm_axis=2)
        logits_BxTxV = self._embedding_layer(states_BxTxD, False)
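        # Padding positions (token id 0) receive zero weight in the loss.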
        targets_mask_BxT = tf.cast(tf.greater(targets_BxT, 0), self._dtype)

        XENT_loss = tf.losses.softmax_cross_entropy(
            tf.one_hot(targets_BxT, self._vocab_size),
            logits_BxTxV,
            label_smoothing=self._label_smoothing,
            weights=targets_mask_BxT)

        # want the one hot targets for sampling
        one_hot_targets = tf.one_hot(targets_BxT, self._vocab_size)

        return XENT_loss, {
            "logits": logits_BxTxV,
            "targets": targets_BxT,
            "one_hot_targets": one_hot_targets,
            "hidden_states": states_BxTxD,
            "context_memory": context["memory"],
            "context_bias": context["memory_bias"]
        }

    def double_sampling(self,
                        features,
                        training,
                        batchsize,
                        seqlen,
                        mixed=False):
        """Two-pass parallel scheduled sampling.

        The first pass decodes with teacher forcing on the ground-truth
        targets; its argmax predictions are fed back as decoder inputs for a
        second pass. If `mixed` is True, the logits of the two passes are
        randomly mixed before the cross-entropy loss is computed.
        """
        if "inputs" not in features or "targets" not in features:
            raise ValueError("Require inputs and targets keys in features.")

        # First "loop" - uses ground truth to supplement
        context = self._encode(features, training)
        self._context = context
        targets_BxT = features["targets"]
        bias_1xTxT = attention.upper_triangle_bias(
            tf.shape(targets_BxT)[1], self._dtype)
        states_BxTxD = self._embedding_layer(targets_BxT, True)
        states_BxTxD = tf.pad(states_BxTxD,
                              [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
        states_BxTxD = timing.add_time_signal(states_BxTxD)
        states_BxTxD = self._dropout_fn(states_BxTxD, training)
        with tf.variable_scope(self._decoder_scope_name, reuse=tf.AUTO_REUSE):
            states_BxTxD = transformer_block.stack(self._decoder_layers,
                                                   training, states_BxTxD,
                                                   bias_1xTxT,
                                                   context["memory"],
                                                   context["memory_bias"])
            states_BxTxD = contrib_layers.layer_norm(states_BxTxD,
                                                     begin_norm_axis=2)
        logits_BxTxV = self._embedding_layer(states_BxTxD, False)
        targets_mask_BxT = tf.cast(tf.greater(targets_BxT, 0), self._dtype)

        # Argmax the logits to obtain the teacher-forced prediction sequence.
        # Before it is passed into the second pass, it should contain no EOS
        # other than the final end token.
        new_input = tf.reshape(tf.math.argmax(logits_BxTxV, axis=2),
                               [batchsize, seqlen])

        # nucleus or top-k processing
        # new_input = iid_process_logits(logits_BxTxV, seqlen, batchsize, logits_BxTxV.get_shape().as_list()[-1],
        #                                top_k=0, top_p=0.9, temperature=1.0)

        # def tensor_loop(i, max_decode_len, logits, new_input, unused_targets_BxT):
        #     def f2(logits_BxTxV, new_input):
        #         topk_probs, topk_indices = tf.math.top_k(logits_BxTxV[0, i], k=2)
        #         topk_inds2 = tf.slice(topk_indices, [1], [1, ])
        #         return tf.tensor_scatter_nd_update(new_input, [[0, i]], tf.cast(topk_inds2, tf.int64))

        #     def f3(i, new_input):
        #         new_input2 = new_input[0].numpy().tolist()
        #         return True if new_input2[0][i] == 1 else False

        #     new_input = tf.cond(tf.py_function(f3, (i, [new_input]), tf.bool), lambda: f2(logits_BxTxV, new_input),
        #                         lambda: new_input)
        #     return i + 1, max_decode_len, logits, new_input, unused_targets_BxT

        # def finish_cond_ref(i, max_decode_len, unused_logits, unused_new_input, targets_BxT):
        #     # add here: a condition that returns the reference summary length

        #     def f4(i, targets, max_len):
        #         targets2 = targets[0].numpy().tolist()
        #         if targets2[i] == 0:  # padded token
        #           return i
        #         else:  # if not 1, still needs a number to refer to
        #           return max_len

        #     ref_len = tf.py_function(f4, (i, targets_BxT, max_decode_len), tf.int32)
        #     return i < ref_len  # T/F -> will change depending on padded token presence

        # def finish_cond_max(i, max_decode_len, unused_logits, unused_new_input, unused_targets_BxT):
        #     return i < max_decode_len

        # _, _, _, new_input, _ = tf.while_loop(finish_cond_max, tensor_loop,
        #                                       [0, seqlen, logits_BxTxV, new_input, targets_BxT])

        # Find the reference target length with tf.py_function, since the EOS
        # search has to run outside the graph.
        def f5(targets_BxT):
            try:
                # Position of the first EOS (token id 1) in the first example.
                exist = targets_BxT[0].numpy().tolist().index(1)
            except ValueError:
                # No EOS found; fall back to the full sequence length.
                exist = targets_BxT.get_shape().as_list()[-1]
            return tf.constant(exist, dtype=tf.int32)

        cut_off = tf.py_function(f5, [targets_BxT], tf.int32)

        # Truncate the predicted sequence at the reference EOS position and
        # zero-pad it back to the full length. Note that the slice keeps only
        # the first batch element, so as written this assumes batchsize == 1.
        new_input2 = tf.slice(new_input, [0, 0], [1, cut_off])
        new_input = tf.reshape(
            tf.pad(
                new_input2,
                [[0, 0], [0, new_input.get_shape().as_list()[-1] - cut_off]],
                "CONSTANT"), [batchsize, seqlen])

        # Second "loop" - uses predicted sequence as input
        context_2 = self._encode(features, training)
        # targets_BxT = features["targets"]
        bias_1xTxT_2 = attention.upper_triangle_bias(
            tf.shape(new_input)[1], self._dtype)
        states_BxTxD_2 = self._embedding_layer(new_input, True)
        states_BxTxD_2 = tf.pad(states_BxTxD_2,
                                [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
        states_BxTxD_2 = timing.add_time_signal(states_BxTxD_2)
        states_BxTxD_2 = self._dropout_fn(states_BxTxD_2, training)
        with tf.variable_scope(self._decoder_scope_name, reuse=tf.AUTO_REUSE):
            states_BxTxD_2 = transformer_block.stack(self._decoder_layers,
                                                     training, states_BxTxD_2,
                                                     bias_1xTxT_2,
                                                     context_2["memory"],
                                                     context_2["memory_bias"])
            states_BxTxD_2 = contrib_layers.layer_norm(states_BxTxD_2,
                                                       begin_norm_axis=2)
        logits_BxTxV_2 = self._embedding_layer(states_BxTxD_2, False)
        targets_mask_BxT_2 = tf.cast(tf.greater(new_input, 0), self._dtype)
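        # The loss weights now mask padding in the sampled sequence rather than
        # in the reference targets.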

        # Mixed parallel scheduled sampling: randomly keep ~75% of the logit
        # entries from the teacher-forced pass and ~25% from the second pass.
        if mixed:
            bool_mask = np.random.choice(
                [True, False],
                [batchsize, seqlen,
                 logits_BxTxV.get_shape().as_list()[2]],
                p=[0.75, 0.25])
            mixed_logits = tf.where(bool_mask, logits_BxTxV, logits_BxTxV_2)
        else:
            # Placeholder of zeros so the return values keep the same structure
            # when mixing is disabled.
            mixed_logits = tf.zeros(logits_BxTxV.get_shape().as_list())

        XENT_loss = tf.losses.softmax_cross_entropy(
            tf.one_hot(new_input, self._vocab_size),
            mixed_logits,
            label_smoothing=self._label_smoothing,
            weights=targets_mask_BxT_2)

        # want the one hot targets for sampling
        one_hot_targets = tf.one_hot(new_input, self._vocab_size)
        return XENT_loss, {
            "sampled_BxT": new_input,
            "logits1": logits_BxTxV,
            "logits2": logits_BxTxV_2,
            "targets": targets_BxT,
            "mixed_logits": mixed_logits
        }
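
    # Hypothetical usage sketch (names and sizes are illustrative only):
    # assuming `model` is an instance of this class and `features` maps
    # "inputs"/"targets" to int id tensors, the entry points above could be
    # driven roughly as follows.
    #
    #   loss, outputs = model(features, training=True)
    #   ss_loss, ss_outputs = model.double_sampling(
    #       features, training=True, batchsize=1, seqlen=128, mixed=True)
    #   preds = model.predict(features, max_decode_len=128, beam_size=1)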