def body(self, features):
  """Encodes the inputs and decodes them into target-space outputs.

  Args:
    features: Dict of feature tensors; `self._preprocess` turns it into
      `(inputs, inputs_length, targets, targets_length)`.

  Returns:
    A tensor of decoder RNN outputs with a singleton axis inserted at
    position 2 (rank-4, as T2T top transformations expect).
  """
  inputs, inputs_length, targets, targets_length = self._preprocess(
      features)

  # Run the configured encoder over the input sequences.
  encode = encoders.get(self._hparams.encoder_type)
  encoded = encode(
      inputs=inputs,
      inputs_length=inputs_length,
      mode=self._hparams.mode,
      hparams=self._hparams)

  # Fetch the shared target embedding table from the target modality's
  # original variable scope so the weights are reused, not recreated.
  modality = self._problem_hparams.modality["targets"]
  modality_scope = self._variable_scopes[modality.name]
  embeddings = model_utils.get_embeddings(
      modality=modality,
      outer_scope=modality_scope,
      inner_scope="shared")

  # Build a decoder conditioned on the encoder outputs and final state,
  # then unroll it dynamically.
  decoder_cls = decoders.get(self._hparams.decoder_type)
  decoder = decoder_cls(
      embeddings=embeddings,
      inputs=targets,
      inputs_length=targets_length,
      hiddens=encoded.outputs,
      hiddens_length=inputs_length,
      enc_state=encoded.final_state,
      mode=self._hparams.mode,
      hparams=self._hparams)
  decoded, _, _ = contrib_seq2seq.dynamic_decode(decoder=decoder)

  return tf.expand_dims(decoded.rnn_output, axis=2)
def _build_decoder_lm_loss(self, central_lang_tag="<en>"): """Builds LM loss on the auxiliary decodings.""" # Get target embeddigns and vocab size. target_modality = self._problem_hparams.modality["targets"] target_modality_scope = self._variable_scopes[target_modality.name] target_embeddings = model_utils.get_embeddings( modality=target_modality, outer_scope=target_modality_scope, inner_scope="shared") target_vocab_size = target_modality._vocab_size # pylint: disable=protected-access # Build auxiliary sequences (if necessary). aux_keys = self._build_aux_sequences( target_embeddings, target_vocab_size, central_lang_tag=central_lang_tag) # Make sure LM loss does not affect embeddings. target_embeddings = tf.stop_gradient(target_embeddings) # Build loss. aux_loss = 0. with tf.name_scope("aux_lm_loss"): for key in aux_keys: dec_outputs = tf.expand_dims( self.dec_outputs[key]["rnn_output"], axis=2) dec_output_tags = tf.expand_dims( self.inputs[key][0][:, :1], axis=2) dec_lengths = self.dec_outputs[key]["length"] # Preprocess LM features. lm_features = { "targets": dec_outputs, "target_tags": dec_output_tags} inputs, inputs_length = self._build_lm_inputs(lm_features) # Build LM (with frozen weights in PREDICT mode). lm_outputs = self.language_model( inputs=inputs, inputs_length=inputs_length, mode=tf.estimator.ModeKeys.PREDICT, hparams=self._hparams, trainable=False, reuse=tf.AUTO_REUSE) # Compute logits. lm_logits = model_utils.build_logits( sequences=tf.expand_dims(lm_outputs, axis=2), embeddings=target_embeddings, vocab_size=target_vocab_size) # Compute decoder probabilities. dec_logits = model_utils.build_logits( sequences=dec_outputs, embeddings=target_embeddings, vocab_size=target_vocab_size) dec_probs = tf.nn.softmax(dec_logits, axis=-1) # Compute cross-entropy loss. aux_loss = aux_loss + losses.CrossEntropyLoss(sparse=False)( lm_logits, dec_probs, dec_lengths) aux_loss = self._hparams.lm_loss_coeff * aux_loss return aux_loss
def _top_single(self, body_output, target_modality, features):
  """Top transformation that ensures correct reuse of target embeddings."""
  t2t_model.log_info(
      "Transforming body output with %s.top", target_modality.name)

  # Re-resolve the modality from problem hparams so logits are computed
  # against the shared embedding table in its original variable scope.
  modality = self._problem_hparams.modality["targets"]
  modality_scope = self._variable_scopes[modality.name]
  embeddings = model_utils.get_embeddings(
      modality=modality,
      outer_scope=modality_scope,
      inner_scope="shared")
  vocab_size = modality._vocab_size  # pylint: disable=protected-access

  # During pointwise-top prediction only the newest position's logits are
  # needed (unless a full predict is forced), so slim the body output
  # down to a single time step first.
  if (modality.top_is_pointwise
      and self.hparams.mode == tf.estimator.ModeKeys.PREDICT
      and not self.hparams.force_full_predict):
    if "decode_loop_step" not in features:
      # Outside a TPU decode loop: the last position is the newest.
      body_output = tf.expand_dims(body_output[:, -1, :, :], axis=[1])
    else:
      # Inside a decode loop: slice out the current loop step.
      shape = body_output.shape.as_list()
      body_output = tf.slice(
          body_output,
          [0, features["decode_loop_step"][0], 0, 0],
          [shape[0], 1, shape[2], shape[3]])

  return model_utils.build_logits(
      sequences=body_output,
      embeddings=embeddings,
      vocab_size=vocab_size)
def body(self, features):
  """Builds the src>tgt encode/decode graph plus optional agreement losses.

  Stores encoders/decoders and their outputs on `self` (self.enc_outputs,
  self.dec_outputs, etc.) so auxiliary-loss builders can reuse them.

  Args:
    features: Dict of feature tensors, including "inputs".

  Returns:
    A tuple of (outputs, aux_losses) where outputs is the decoder
    rnn_output expanded to rank 4, and aux_losses maps loss names to
    per-batch loss tensors.
  """
  # Save a reference to the features to access in other methods.
  self._body_features = features
  # Preprocess features.
  self.inputs, self.targets = self._preprocess(features)
  # Ensure auxiliary parts of the graph are built when necessary.
  batch_size = common_layers.shape_list(features["inputs"])[0]
  global_step = model_utils.get_global_step(self._hparams)
  # Encode (src>tgt).
  key = "src>tgt"
  self.enc_outputs = {}
  self.encoder = encoders.get(self._hparams.encoder_type)
  encode_func = self.get_encode_func(*self.inputs[key])
  self.enc_outputs[key] = encode_func()
  # Get target embeddings (shared table, reused from the modality's
  # variable scope).
  target_modality = self._problem_hparams.modality["targets"]
  target_modality_scope = self._variable_scopes[target_modality.name]
  target_embeddings = model_utils.get_embeddings(
      modality=target_modality,
      outer_scope=target_modality_scope,
      inner_scope="shared")
  # Decode (src>tgt).
  key = "src>tgt"
  self.decoders = {}
  self.dec_outputs = {}
  self.decoder = decoders.get(self._hparams.decoder_type)
  # Prepare for decoding.
  target_seqs, target_lens = self.targets[key]
  hiddens = self.enc_outputs[key].outputs
  hiddens_length = self.inputs[key][1]
  enc_state = self.enc_outputs[key].final_state
  # Decode.
  decode_func = self.get_decode_func(
      target_embeddings, target_seqs, target_lens,
      hiddens, hiddens_length, enc_state,
      mode=self._hparams.mode)
  self.dec_outputs[key] = decode_func()
  outputs = tf.expand_dims(self.dec_outputs[key]["rnn_output"], axis=2)
  # Construct agreement losses. Each is gated by tf.cond on the global
  # step so it only kicks in after its configured enable step; before
  # that the branch contributes zeros.
  aux_losses = {}
  if self._hparams.mode == tf.estimator.ModeKeys.TRAIN:
    if self._hparams.enc_agreement_coeff > 0:
      aux_losses["agreement_enc"] = tf.cond(
          global_step > self._hparams.enc_agreement_enable_step,
          self._build_encoder_agreement_loss,
          lambda: tf.zeros([batch_size]))
    if self._hparams.dec_agreement_coeff > 0:
      aux_losses["agreement_dec"] = tf.cond(
          global_step > self._hparams.dec_agreement_enable_step,
          self._build_decoder_agreement_loss,
          lambda: tf.zeros([batch_size]))
  return outputs, aux_losses
def _build_decoder_agreement_loss(self, central_lang_tag="<en>"):
  """Builds an agreement loss that enforces consistency of the decodings.

  Args:
    central_lang_tag: A string with the tag of the central language.
      A ``central'' language (usually English) is the one that has parallel
      data with all other languages. It is used to protect supervised
      directions from gradients coming from auxiliary losses.

  Returns:
    loss: <float32> [] for the agreement losses.
  """
  # Get target embeddings and vocab size.
  target_modality = self._problem_hparams.modality["targets"]
  target_modality_scope = self._variable_scopes[target_modality.name]
  target_embeddings = model_utils.get_embeddings(
      modality=target_modality,
      outer_scope=target_modality_scope,
      inner_scope="shared")
  target_vocab_size = target_modality._vocab_size  # pylint: disable=protected-access
  # Build auxiliary sequences (if necessary).
  aux_keys = self._build_aux_sequences(
      target_embeddings, target_vocab_size,
      central_lang_tag=central_lang_tag)
  # Build loss.
  aux_loss = 0.
  with tf.name_scope("dec_agreement_loss"):
    # Pair each key with its reverse (for two keys this yields
    # (k1, k2) and (k2, k1)), so agreement is enforced in both
    # directions.
    for key1, key2 in zip(aux_keys, aux_keys[::-1]):
      # Prepare for decoding: re-decode key1's encodings against the
      # (right-shifted) outputs previously decoded for key2.
      targets = self.dec_outputs[key2]["rnn_output"]
      targets_length = self.dec_outputs[key2]["length"]
      shifted_targets = common_layers.shift_right_3d(targets)
      hiddens = self.enc_outputs[key1].outputs
      hiddens_length = self.inputs[key1][1]
      enc_state = self.enc_outputs[key1].final_state
      # Decode in PREDICT mode with a capped number of iterations.
      decode_func = self.get_decode_func(
          target_embeddings, shifted_targets, targets_length,
          hiddens, hiddens_length, enc_state,
          mode=tf.estimator.ModeKeys.PREDICT,
          decoder_iterations=self._hparams.aux_decode_length)
      aux_dec_outputs = decode_func()
      # Compute logits (protect central directions from the gradients).
      aux_logits_1 = model_utils.build_logits(
          sequences=tf.expand_dims(aux_dec_outputs["rnn_output"], axis=2),
          embeddings=target_embeddings,
          vocab_size=target_vocab_size)
      # Stop gradients where key1 is a central (supervised) direction so
      # the agreement loss cannot perturb it.
      aux_logits_1 = tf.where(self._is_central[key1],
                              tf.stop_gradient(aux_logits_1),
                              aux_logits_1)
      # Compute KL loss.
      logits = tf.squeeze(aux_logits_1, axis=2)
      if self._hparams.dec_agreement_loss_sparse:
        # Sparse variant: match key2's sampled token ids directly.
        target_ids = self.dec_outputs[key2]["sample_id"]
        aux_loss = aux_loss + losses.CrossEntropyLoss(sparse=True)(
            logits, target_ids, targets_length)
      else:
        # Dense variant: match key2's full output distribution.
        aux_logits_2 = tf.squeeze(self.dec_outputs[key2]["logits"], axis=2)
        target_probs = tf.nn.softmax(aux_logits_2, axis=-1)
        aux_loss = aux_loss + losses.CrossEntropyLoss(sparse=False)(
            logits, target_probs, targets_length)
  aux_loss = self._hparams.dec_agreement_coeff * aux_loss
  return aux_loss