  def prepare_decoder(self, targets):
    """Prepares targets for transformer decoder."""
    shape = utils.shape_list(targets)
    # Sequence should be [batch, seq_length].
    assert len(shape) == 2, "Sequence tensors should be 2-dimensional"
    assert len(self.hparams.query_shape) == 1, (
        "query shape should be 1-dimensional")
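    # Example (hypothetical numbers): targets of shape [8, 1024] with
    # hparams.query_shape == (256,) is processed as four blocks of 256
    # positions each.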

    # Randomly mask out target positions (token-level dropout).
    if self.hparams.target_dropout:
      targets = tf.where(
          tf.random.uniform(shape) < self.hparams.target_dropout,
          tf.zeros_like(targets), targets)
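      # E.g. with target_dropout == 0.1, roughly 10% of target tokens are
      # replaced with id 0 (assumed here to act as a blank/padding id).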
    # Right-shift targets blockwise so each position is predicted only
    # from earlier positions.
    targets = tf.expand_dims(targets, axis=-1)
    targets = utils.right_shift_blockwise_nd(targets,
                                             self.hparams.query_shape)
    targets = tf.squeeze(targets, axis=-1)
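    # With a 1-D query shape this is effectively a plain right shift by
    # one token: [t0, t1, t2, ...] -> [0, t0, t1, ...].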
    # Add token embeddings.
    targets = utils.get_embeddings(
        targets=targets,
        hidden_size=self.hparams.embedding_dims,
        vocab_size=self.vocab_size)
    if self.hparams.dropout:
      # tf.nn.dropout takes keep_prob (TF1 semantics), hence 1 - dropout.
      targets = tf.nn.dropout(targets, 1 - self.hparams.dropout)
    # Project embeddings from embedding_dims to the model's hidden size.
    targets = tf.layers.dense(
        targets, self.hidden_size, activation=None, name="emb_dense")
    if self.hparams.add_timing_signal:
      # Add a fixed position (timing) signal, broadcast over the batch.
      targets += utils.get_timing_signal_1d(
          self.hparams.max_target_length, self.hidden_size)
    return targets

  def body(self, features, decode_step=None, cache=None, decoding_stats=None):
    targets = tf.to_int32(features["targets"])
    if self.is_decode:
      targets = self.process_partial_targets_decoding(targets)
    decoder_input = self.prepare_decoder(targets)
    extra_losses = []

    if not self.hparams.unconditional:  # condition on class label
      if not self.is_inputs_class_label:
        raise ValueError("SparseImagetransformer can only condition on "
                         "'inputs' feature if it represents class label.")
      inputs = features["inputs"]

      # Embed class here rather than in bottom().
      if inputs.dtype not in [tf.int32, tf.int64]:
        raise ValueError("Do not embed 'inputs' before body(). "
                         "Found dtype=%s." % inputs.dtype)
      inputs = utils.get_embeddings(
          targets=inputs,
          vocab_size=self.inputs_vocab_size,
          hidden_size=self.hidden_size,
          name="class_conditional_embedding")

      # Add class embedding to each spatial location.
      batch_size = tf.shape(targets)[0]
      hidden_size = tf.shape(inputs)[-1]
      num_middle_dims = len(decoder_input.shape) - 2
      decoder_input += tf.reshape(inputs, [batch_size] + [1] * num_middle_dims +
                                  [hidden_size])
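      # The [batch, 1, ..., 1, hidden] reshape broadcasts the class
      # embedding across every spatial position of decoder_input.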

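    # Run the sparse transformer decoder stack; layers may append
    # auxiliary losses to extra_losses.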
    decoder_output = utils.transformer_decoder_layers(
        inputs=decoder_input,
        num_layers=self.num_decoder_layers,
        hparams=self.hparams,
        decode_step=decode_step,
        losses=extra_losses,
        cache=cache,
        name="decoder",
        decoding_stats=decoding_stats)
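    # Project decoder outputs to per-position logits over the vocabulary.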
    logits = self.produce_output(decoder_output)

    # Return logits as-is in decoding mode
    if self.is_decode:
      return logits

    # Sample predictions from the logits (used below for image summaries).
    results = self.multinomial_squeeze(logits, self.hparams.sampling_temp)
    results = tf.reshape(
        results, [-1, self.frame_height, self.frame_width, self.num_channels])
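    # Under XLA compilation, stash sampled images in the module-level
    # _IMGS dict (assumed to be consumed by a summary hook outside the
    # compiled computation).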
    if utils.is_xla_compiled():
      _IMGS["predictions"] = results

    # Prepare loss.
    loss_dict = {}
    if extra_losses:
      loss_dict["extra_loss"] = tf.add_n(extra_losses)
    return logits, loss_dict