def recurrent_transformer_decoder(
    decoder_input,
    encoder_output,
    decoder_self_attention_bias,
    encoder_decoder_attention_bias,
    hparams,
    name="decoder",
    nonpadding=None,
    save_weights_to=None,
    make_image_summary=True):
  """Recurrent decoder function."""
  x = decoder_input
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))
  with tf.variable_scope(name):
    ffn_unit = functools.partial(
        # Use the encoder FFN unit, since the decoder FFN unit uses left
        # padding.
        universal_transformer_util.transformer_encoder_ffn_unit,
        hparams=hparams,
        nonpadding_mask=nonpadding)

    attention_unit = functools.partial(
        universal_transformer_util.transformer_decoder_attention_unit,
        hparams=hparams,
        encoder_output=encoder_output,
        decoder_self_attention_bias=decoder_self_attention_bias,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        attention_dropout_broadcast_dims=attention_dropout_broadcast_dims,
        save_weights_to=save_weights_to,
        make_image_summary=make_image_summary)

    x, extra_output = universal_transformer_util.universal_transformer_layer(
        x, hparams, ffn_unit, attention_unit)

    return common_layers.layer_preprocess(x, hparams), extra_output
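

# A minimal usage sketch (an assumption, not part of this module): wiring the
# decoder into a Universal Transformer style model. `universal_transformer_base`
# and `attention_bias_lower_triangle` are existing tensor2tensor helpers; the
# input tensors stand in for whatever the surrounding model produces.
#
#   hparams = universal_transformer.universal_transformer_base()
#   decoder_self_attention_bias = (
#       common_attention.attention_bias_lower_triangle(
#           common_layers.shape_list(decoder_input)[1]))
#   decoder_output, extra_output = recurrent_transformer_decoder(
#       decoder_input,
#       encoder_output,
#       decoder_self_attention_bias,
#       encoder_decoder_attention_bias,
#       hparams,
#       nonpadding=nonpadding)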