def body(self, features):
  """Universal Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
        "targets": Target decoder outputs.
            [batch_size, decoder_length, hidden_dim]
        "target_space_id"

  Returns:
    Final encoder representation [batch_size, input_length, 1, hidden_dim],
    optionally paired with a dict containing the auxiliary ACT loss.
  """
  hparams = self._hparams

  assert self.has_input, ("universal_transformer_encoder is applicable on "
                          "problems with inputs")

  inputs = features["inputs"]
  target_space = features["target_space_id"]
  encoder_output, enc_extra_output = self.encode(
      inputs, target_space, hparams, features=features)

  encoder_output = tf.expand_dims(encoder_output, 2)

  if hparams.recurrence_type == "act" and hparams.act_loss_weight != 0:
    ponder_times, remainders = enc_extra_output
    act_loss = hparams.act_loss_weight * tf.reduce_mean(ponder_times +
                                                        remainders)
    contrib.summary().scalar("act_loss", act_loss)
    return encoder_output, {"act_loss": act_loss}
  return encoder_output
def body(self, features):
  """Universal Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
        "targets": Target decoder outputs.
            [batch_size, decoder_length, hidden_dim]
        "target_space_id"

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  hparams = self._hparams
  if hparams.add_position_timing_signal:
    # Turning off addition of positional embedding in the encoder/decoder
    # preparation as we do it in the beginning of each step.
    hparams.pos = None

  if self.has_input:
    inputs = features["inputs"]
    target_space = features["target_space_id"]
    (encoder_output, encoder_decoder_attention_bias,
     enc_extra_output) = self.encode(
         inputs, target_space, hparams, features=features)
  else:
    (encoder_output, encoder_decoder_attention_bias,
     enc_extra_output) = (None, None, (None, None))

  targets = features["targets"]
  targets = common_layers.flatten4d3d(targets)

  (decoder_input,
   decoder_self_attention_bias) = transformer.transformer_prepare_decoder(
       targets, hparams, features=features)

  decoder_output, dec_extra_output = self.decode(
      decoder_input,
      encoder_output,
      encoder_decoder_attention_bias,
      decoder_self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, "targets"))

  expected_attentions = features.get("expected_attentions")
  if expected_attentions is not None:
    attention_loss = common_attention.encoder_decoder_attention_loss(
        expected_attentions, self.attention_weights,
        hparams.expected_attention_loss_type,
        hparams.expected_attention_loss_multiplier)
    return decoder_output, {"attention_loss": attention_loss}

  if hparams.recurrence_type == "act" and hparams.act_loss_weight != 0:
    if self.has_input:
      enc_ponder_times, enc_remainders = enc_extra_output
      enc_act_loss = (
          hparams.act_loss_weight *
          tf.reduce_mean(enc_ponder_times + enc_remainders))
    else:
      enc_act_loss = 0.0

    (dec_ponder_times, dec_remainders) = dec_extra_output
    dec_act_loss = (
        hparams.act_loss_weight *
        tf.reduce_mean(dec_ponder_times + dec_remainders))
    act_loss = enc_act_loss + dec_act_loss
    contrib.summary().scalar("act_loss", act_loss)
    return decoder_output, {"act_loss": act_loss}

  return decoder_output
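
# A minimal illustrative sketch, not part of the T2T API: the auxiliary ACT
# ("Adaptive Computation Time") loss used in both body() methods above is the
# act_loss_weight-scaled mean of per-position ponder times (number of recurrent
# steps taken) plus halting remainders (leftover halting probability), which
# penalizes spending extra computation per position. The function name and the
# tensor shapes below are assumptions for illustration only; tf is assumed to
# be TensorFlow, imported as in the rest of this module.
def _example_act_ponder_loss(ponder_times, remainders, act_loss_weight):
  """Illustrative ACT ponder cost: weight * mean(ponder_times + remainders).

  Args:
    ponder_times: Tensor of per-position recurrent step counts,
      e.g. [batch_size, length].
    remainders: Tensor of per-position halting remainders, same shape.
    act_loss_weight: Scalar weight for the auxiliary loss.

  Returns:
    A scalar loss, returned from body() in the losses dict so it is added to
    the training objective.
  """
  return act_loss_weight * tf.reduce_mean(ponder_times + remainders)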