Example #1
  def body(self, features):
    """Universal Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "targets": Target decoder outputs.
              [batch_size, decoder_length, hidden_dim]
          "target_space_id"

    Returns:
      Final encoder representation. [batch_size, input_length, 1, hidden_dim]
    """
    hparams = self._hparams

    assert self.has_input, ("universal_transformer_encoder is only "
                            "applicable to problems with inputs")

    inputs = features["inputs"]
    target_space = features["target_space_id"]
    encoder_output, enc_extra_output = self.encode(
        inputs, target_space, hparams, features=features)

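    # Model bodies in T2T operate on 4-D tensors; restore the squeezed axis so
    # the output is [batch_size, input_length, 1, hidden_dim].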
    encoder_output = tf.expand_dims(encoder_output, 2)

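    # With ACT (Adaptive Computation Time) recurrence, add the ponder penalty
    # as an extra loss so the model learns to halt early.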
    if hparams.recurrence_type == "act" and hparams.act_loss_weight != 0:
      ponder_times, remainders = enc_extra_output
      act_loss = hparams.act_loss_weight * tf.reduce_mean(ponder_times +
                                                          remainders)
      contrib.summary().scalar("act_loss", act_loss)

      return encoder_output, {"act_loss": act_loss}
    return encoder_output
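The ACT penalty in Example #1 is just the mean of per-position ponder counts plus remainders, scaled by hparams.act_loss_weight. A minimal standalone sketch of that reduction, with toy tensors and a hypothetical weight of 0.01 (neither value is taken from the example):

import tensorflow as tf

# Toy per-position values, as an ACT recurrence would emit them:
# one ponder count and one remainder per (batch, position).
ponder_times = tf.constant([[2.0, 3.0], [1.0, 4.0]])
remainders = tf.constant([[0.3, 0.1], [0.7, 0.2]])

act_loss_weight = 0.01  # hypothetical stand-in for hparams.act_loss_weight
act_loss = act_loss_weight * tf.reduce_mean(ponder_times + remainders)
# act_loss is a scalar; the body returns it under the "act_loss" key so it can
# be added to the training loss.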
Example #2
  def body(self, features):
    """Universal Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "targets": Target decoder outputs.
              [batch_size, decoder_length, hidden_dim]
          "target_space_id"

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
    hparams = self._hparams
    if hparams.add_position_timing_signal:
      # Turning off addition of positional embedding in the encoder/decoder
      # preparation as we do it in the beginning of each step.
      hparams.pos = None

    if self.has_input:
      inputs = features["inputs"]
      target_space = features["target_space_id"]
      (encoder_output, encoder_decoder_attention_bias,
       enc_extra_output) = self.encode(
           inputs, target_space, hparams, features=features)
    else:
      (encoder_output, encoder_decoder_attention_bias,
       enc_extra_output) = (None, None, (None, None))

    targets = features["targets"]
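    # Targets arrive as [batch_size, decoder_length, 1, hidden_dim]; collapse
    # the singleton axis so the decoder works on 3-D tensors.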
    targets = common_layers.flatten4d3d(targets)

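    # Shift the targets right and build the causal (lower-triangular)
    # self-attention bias for the decoder.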
    (decoder_input,
     decoder_self_attention_bias) = transformer.transformer_prepare_decoder(
         targets, hparams, features=features)

    decoder_output, dec_extra_output = self.decode(
        decoder_input,
        encoder_output,
        encoder_decoder_attention_bias,
        decoder_self_attention_bias,
        hparams,
        nonpadding=transformer.features_to_nonpadding(features, "targets"))

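    # Optional supervised-attention loss, used when the input pipeline provides
    # ground-truth attention targets.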
    expected_attentions = features.get("expected_attentions")
    if expected_attentions is not None:
      attention_loss = common_attention.encoder_decoder_attention_loss(
          expected_attentions, self.attention_weights,
          hparams.expected_attention_loss_type,
          hparams.expected_attention_loss_multiplier)
      return decoder_output, {"attention_loss": attention_loss}

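    # ACT (Adaptive Computation Time) ponder penalty, combined over the encoder
    # (when present) and the decoder.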
    if hparams.recurrence_type == "act" and hparams.act_loss_weight != 0:
      if self.has_input:
        enc_ponder_times, enc_remainders = enc_extra_output
        enc_act_loss = (
            hparams.act_loss_weight *
            tf.reduce_mean(enc_ponder_times + enc_remainders))
      else:
        enc_act_loss = 0.0

      (dec_ponder_times, dec_remainders) = dec_extra_output
      dec_act_loss = (
          hparams.act_loss_weight *
          tf.reduce_mean(dec_ponder_times + dec_remainders))
      act_loss = enc_act_loss + dec_act_loss
      contrib.summary().scalar("act_loss", act_loss)
      return decoder_output, {"act_loss": act_loss}

    return decoder_output
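The target preparation in Example #2 relies on common_layers.flatten4d3d, which joins the length axis with the singleton axis of T2T's 4-D feature layout. A minimal standalone sketch of the equivalent reshape, using toy shapes and plain tf.reshape rather than the library helper:

import tensorflow as tf

# T2T model bodies see features as 4-D tensors: [batch_size, length, 1, hidden_dim].
targets_4d = tf.zeros([8, 20, 1, 512])

# Equivalent of common_layers.flatten4d3d: merge the length and singleton axes.
shape = targets_4d.shape.as_list()
targets_3d = tf.reshape(targets_4d, [shape[0], shape[1] * shape[2], shape[3]])
print(targets_3d.shape)  # (8, 20, 512), the layout transformer_prepare_decoder expects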