def body(self, features, decode_step=None, cache=None, decoding_stats=None):
    """Runs the sparse image decoder, optionally conditioned on a class label.

    Args:
      features: dict of tensors. "targets" is always read; "inputs" (integer
        class ids) is read unless `hparams.unconditional` is set.
      decode_step: forwarded unchanged to `utils.transformer_decoder_layers`.
      cache: forwarded unchanged to `utils.transformer_decoder_layers`.
      decoding_stats: forwarded unchanged to
        `utils.transformer_decoder_layers`.

    Returns:
      In decode mode, only the logits from `self.produce_output`. Otherwise a
      `(logits, loss_dict)` tuple; `loss_dict` contains "extra_loss" (the sum
      of auxiliary losses appended by the decoder layers) only when the
      decoder produced any, and is empty otherwise.

    Raises:
      ValueError: if conditioning is requested but "inputs" is not a class
        label, or if "inputs" arrives with a non-integer (already embedded)
        dtype.
    """
    tgt_ids = tf.to_int32(features["targets"])
    if self.is_decode:
      tgt_ids = self.process_partial_targets_decoding(tgt_ids)
    decoder_input = self.prepare_decoder(tgt_ids)
    aux_losses = []

    conditional = not self.hparams.unconditional
    if conditional:
      # Only a class-label "inputs" feature is supported as conditioning.
      if not self.is_inputs_class_label:
        raise ValueError("SparseImagetransformer can only condition on "
                         "'inputs' feature if it represents class label.")
      class_ids = features["inputs"]

      # The class embedding is looked up here rather than in bottom(), so
      # body() must receive the raw integer ids.
      if class_ids.dtype not in (tf.int32, tf.int64):
        raise ValueError("Do not embed 'inputs' before body(). "
                         "Found dtype=%s." % class_ids.dtype)
      class_emb = utils.get_embeddings(
          targets=class_ids,
          vocab_size=self.inputs_vocab_size,
          hidden_size=self.hidden_size,
          name="class_conditional_embedding")

      # Reshape to [batch, 1, ..., 1, hidden] (one 1 per middle axis of
      # decoder_input) so the embedding is added at every spatial location.
      broadcast_shape = ([tf.shape(tgt_ids)[0]] +
                         [1] * (len(decoder_input.shape) - 2) +
                         [tf.shape(class_emb)[-1]])
      decoder_input += tf.reshape(class_emb, broadcast_shape)

    decoder_output = utils.transformer_decoder_layers(
        inputs=decoder_input,
        num_layers=self.num_decoder_layers,
        hparams=self.hparams,
        decode_step=decode_step,
        losses=aux_losses,
        cache=cache,
        name="decoder",
        decoding_stats=decoding_stats)
    logits = self.produce_output(decoder_output)

    if self.is_decode:
      # The decoding loop consumes raw logits; no summaries or losses here.
      return logits

    # Draw samples from the logits (at hparams.sampling_temp) and shape them
    # as images; they are only recorded in _IMGS when running under XLA.
    sampled = tf.reshape(
        self.multinomial_squeeze(logits, self.hparams.sampling_temp),
        [-1, self.frame_height, self.frame_width, self.num_channels])
    if utils.is_xla_compiled():
      _IMGS["predictions"] = sampled

    losses = {"extra_loss": tf.add_n(aux_losses)} if aux_losses else {}
    return logits, losses
    def body(self,
             features,
             decode_step=None,
             cache=None,
             decoding_stats=None,
             add_summary=True):
        """Runs the optional encoder and the decoder, then builds the loss.

        Args:
          features: dict of tensors. "targets" is required; "inputs",
            "token_bias_inputs" and "token_bias_targets" are optional.
          decode_step: forwarded to the decoder layers; forced to None unless
            `hparams.fast_decode` is set.
          cache: optional dict forwarded to the decoder layers. The encoder
            output is memoized in it under the key -1, so it is only built
            once per cache.
          decoding_stats: forwarded to the decoder layers.
          add_summary: if True, emit scalar summaries of the loss terms.

        Returns:
          In decode mode, the logits from `self.produce_output`. Otherwise a
          `(logits, loss_dict)` tuple where logits are reshaped to
          `shape[:2] + [1, 1] + shape[-1:]` and loss_dict holds the
          padding-weighted cross entropy under "training", plus "extra_loss"
          when auxiliary losses were collected.
        """
        encoder_output = None
        extra_losses = []
        padding_bias = None
        if not self.hparams.fast_decode:
            decode_step = None
        if "inputs" in features:
            inputs = features["inputs"]
            # Flatten to [batch, length, hidden_size]; presumably this drops
            # singleton axes added by the input pipeline -- TODO confirm.
            inputs = tf.reshape(
                inputs,
                utils.shape_list(inputs)[:2] + [self.hidden_size])
            # Padding bias only used for seq2seq models. It is computed before
            # the random masking below, so masked positions are not padding.
            padding_bias = utils.embedding_to_padding(inputs)
            shape = utils.shape_list(inputs)
            if self.hparams.input_dropout:
                # Zero out random elements of the input embeddings.
                # NOTE(review): as far as this block shows, this is applied in
                # every mode (train/eval/decode) -- confirm that is intended.
                inputs = tf.where(
                    tf.random.uniform(shape) < self.hparams.input_dropout,
                    tf.zeros_like(inputs), inputs)
            if self.hparams.add_timing_signal:
                # NOTE(review): the signal is built for hparams.max_length
                # positions; this assumes inputs are exactly that long so
                # the addition broadcasts -- confirm.
                inputs += utils.get_timing_signal_1d(self.hparams.max_length,
                                                     self.hidden_size)
            if cache is not None and -1 in cache:
                encoder_output = cache[-1]
            else:
                encoder_output = utils.transformer_encoder_layers(
                    inputs=inputs,
                    num_layers=self.num_encoder_layers,
                    hparams=self.hparams,
                    losses=extra_losses,
                    name="encoder",
                    token_bias=features.get("token_bias_inputs"),
                    padding_bias=padding_bias)
            if cache is not None and -1 not in cache:
                cache[-1] = encoder_output
        targets = tf.to_int32(features["targets"])
        # Keep only the first two axes (the trailing axes are assumed to be
        # singletons -- TODO confirm against the input pipeline).
        targets = tf.reshape(targets, utils.shape_list(targets)[:2])
        # Clamp targets to max_target_length
        targets = targets[:, :self.hparams.max_target_length]
        if self.is_decode:
            targets = self.process_partial_targets_decoding(targets)
        decoder_input = self.prepare_decoder(targets)

        decoder_output = utils.transformer_decoder_layers(
            inputs=decoder_input,
            num_layers=self.num_decoder_layers,
            hparams=self.hparams,
            encoder_output=encoder_output,
            decode_step=decode_step,
            losses=extra_losses,
            cache=cache,
            name="decoder",
            decoding_stats=decoding_stats,
            token_bias_inputs=features.get("token_bias_inputs"),
            token_bias_targets=features.get("token_bias_targets"),
            padding_bias=padding_bias)
        logits = self.produce_output(decoder_output)

        # Return logits as-is in decoding mode
        if self.is_decode:
            return logits

        # Cross entropy averaged over non-padding tokens (id 0 is padding).
        # targets is already int32 here (tf.to_int32 above), so no cast.
        one_hot_targets = tf.one_hot(targets, self.vocab_size)
        x_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_targets, logits=logits)
        weights = tf.to_float(tf.not_equal(targets, 0))
        total_weight = tf.reduce_sum(weights)
        weighted_x_entropy = tf.reduce_sum(x_entropy * weights)
        # Guard the denominator: an all-padding batch would otherwise give
        # 0 / 0 = NaN. weights are 0/1, so any batch with at least one real
        # token has total_weight >= 1 and the loss value is unchanged.
        loss = weighted_x_entropy / tf.maximum(total_weight, 1.0)
        if add_summary:
            tf.summary.scalar("losses/weight", total_weight)
            tf.summary.scalar("losses/x_entropy", weighted_x_entropy)

        loss_dict = {"training": loss}
        if extra_losses:
            loss_dict["extra_loss"] = tf.add_n(extra_losses)
        # Hack for T2T metrics: re-insert two singleton axes so logits have
        # the layout shape[:2] + [1, 1] + [vocab].
        logits = tf.reshape(
            logits,
            utils.shape_list(logits)[:2] + [1, 1] +
            utils.shape_list(logits)[-1:])
        return logits, loss_dict