def _build_decoder_lm_loss(self, central_lang_tag="<en>"):
  """Builds LM loss on the auxiliary decodings.

  For every auxiliary direction, the decoder outputs are passed through the
  language model (invoked in PREDICT mode with `trainable=False`), and a dense
  cross-entropy is taken between the LM logits and the softmax of the decoder
  logits. The per-direction terms are added up and scaled by
  `hparams.lm_loss_coeff`.

  Args:
    central_lang_tag: A string with the tag of the central language. It is
      forwarded unchanged to `_build_aux_sequences`.

  Returns:
    The accumulated LM loss multiplied by `self._hparams.lm_loss_coeff`.
  """
  # Look up the shared target embeddings and the target vocabulary size.
  modality = self._problem_hparams.modality["targets"]
  modality_scope = self._variable_scopes[modality.name]
  embeddings = model_utils.get_embeddings(
      modality=modality, outer_scope=modality_scope, inner_scope="shared")
  vocab_size = modality._vocab_size  # pylint: disable=protected-access

  # Auxiliary sequences are built (if necessary) with the original embeddings.
  aux_keys = self._build_aux_sequences(
      embeddings, vocab_size, central_lang_tag=central_lang_tag)

  # From here on the embeddings are only seen through stop_gradient, so the
  # LM loss does not affect them.
  frozen_embeddings = tf.stop_gradient(embeddings)

  def _lm_loss_for(key):
    """Returns the LM cross-entropy term for one auxiliary direction."""
    decoded = tf.expand_dims(self.dec_outputs[key]["rnn_output"], axis=2)
    decoded_tags = tf.expand_dims(self.inputs[key][0][:, :1], axis=2)
    decoded_lengths = self.dec_outputs[key]["length"]

    # Preprocess LM features.
    lm_inputs, lm_inputs_length = self._build_lm_inputs({
        "targets": decoded,
        "target_tags": decoded_tags,
    })

    # Run the LM with frozen weights in PREDICT mode.
    lm_outputs = self.language_model(
        inputs=lm_inputs,
        inputs_length=lm_inputs_length,
        mode=tf.estimator.ModeKeys.PREDICT,
        hparams=self._hparams,
        trainable=False,
        reuse=tf.AUTO_REUSE)

    # Logits predicted by the LM.
    lm_logits = model_utils.build_logits(
        sequences=tf.expand_dims(lm_outputs, axis=2),
        embeddings=frozen_embeddings,
        vocab_size=vocab_size)

    # Probabilities assigned by the decoder.
    decoder_logits = model_utils.build_logits(
        sequences=decoded,
        embeddings=frozen_embeddings,
        vocab_size=vocab_size)
    decoder_probs = tf.nn.softmax(decoder_logits, axis=-1)

    # Dense cross-entropy between LM logits and decoder probabilities.
    return losses.CrossEntropyLoss(sparse=False)(
        lm_logits, decoder_probs, decoded_lengths)

  total = 0.
  with tf.name_scope("aux_lm_loss"):
    for aux_key in aux_keys:
      total = total + _lm_loss_for(aux_key)

  return self._hparams.lm_loss_coeff * total
def _build_aux_sequences(self, target_embeddings, target_vocab_size,
                         central_lang_tag="<en>"):
  """Builds sequences in an auxiliary language.

  For each auxiliary direction ("src>aux" and "tgt>aux") this encodes and
  decodes only when the result is not already present in `self.enc_outputs` /
  `self.dec_outputs`, attaches logits to freshly decoded outputs, and then
  wraps every decoder output so that examples whose tag equals the central
  language receive no gradient.

  Side effects: sets `self._is_central` and may add entries to
  `self.enc_outputs` and `self.dec_outputs`; entries of `self.dec_outputs`
  for the auxiliary keys are replaced by their gradient-masked versions.

  Args:
    target_embeddings: Target embeddings passed to the decoder and to
      `model_utils.build_logits`.
    target_vocab_size: Target vocabulary size passed to
      `model_utils.build_logits`.
    central_lang_tag: A string with the tag of the central language.

  Returns:
    aux_keys: The list of auxiliary direction keys, `["src>aux", "tgt>aux"]`.
  """
  aux_keys = ["src>aux", "tgt>aux"]

  # Determine which src and tgt sentences are central.
  # NOTE: `tf.equal` is used rather than the Python `==` operator. With TF1
  # tensor semantics `Tensor.__eq__` compares object identity and yields a
  # Python bool, which would make the mask a constant and silently disable
  # the gradient protection below. `tf.equal` is element-wise in both TF1
  # and TF2.
  central_lang_id = translate_multilingual.get_tag_id(central_lang_tag)
  self._is_central = {
      "src>aux": tf.squeeze(
          tf.equal(self._body_features["input_tags_raw"], central_lang_id)),
      "tgt>aux": tf.squeeze(
          tf.equal(self._body_features["target_tags_raw"], central_lang_id)),
  }

  for key in aux_keys:
    # Encode (if necessary).
    if key not in self.enc_outputs:
      encode_func = self.get_encode_func(*self.inputs[key])
      self.enc_outputs[key] = encode_func()

    # Decode (if necessary).
    if key not in self.dec_outputs:
      # Prepare for decoding.
      target_seqs, target_lens = self.targets[key]
      hiddens = self.enc_outputs[key].outputs
      hiddens_length = self.inputs[key][1]
      enc_state = self.enc_outputs[key].final_state
      decoder_hparams = contrib_training.HParams(auxiliary=True)

      # Decode.
      decode_func = self.get_decode_func(
          target_embeddings,
          target_seqs, target_lens,
          hiddens, hiddens_length,
          enc_state,
          mode=self._hparams.mode,
          decoder_hparams=decoder_hparams,
          decoder_iterations=self._hparams.aux_decode_length)
      self.dec_outputs[key] = decode_func()

      # Compute logits.
      self.dec_outputs[key]["logits"] = model_utils.build_logits(
          sequences=tf.expand_dims(
              self.dec_outputs[key]["rnn_output"], axis=2),
          embeddings=target_embeddings,
          vocab_size=target_vocab_size)

    # Protect central directions from the gradients: where the example is
    # central, the stop_gradient version of each decoder output is selected.
    # Only existing keys are reassigned, so iterating the mapping is safe.
    for element in self.dec_outputs[key]:
      self.dec_outputs[key][element] = tf.where(
          self._is_central[key],
          tf.stop_gradient(self.dec_outputs[key][element]),
          self.dec_outputs[key][element])

  return aux_keys
def _top_single(self, body_output, target_modality, features):
  """Top transformation that ensures correct reuse of target embeddings.

  Args:
    body_output: Output of the model body. It is indexed with four axes
      below when only the last position is kept.
    target_modality: Modality whose `name` is used in the log message. The
      logits themselves are always built from the "targets" modality of the
      problem hparams.
    features: Mapping of features; the optional "decode_loop_step" entry
      selects which position to keep when only one position is needed.

  Returns:
    Logits produced by `model_utils.build_logits` from the (possibly sliced)
    body output and the shared target embeddings.
  """
  t2t_model.log_info(
      "Transforming body output with %s.top", target_modality.name)

  # Get target embeddings. The argument above is only used for logging; the
  # modality used from here on comes from the problem hparams.
  modality = self._problem_hparams.modality["targets"]
  modality_scope = self._variable_scopes[modality.name]
  embeddings = model_utils.get_embeddings(
      modality=modality, outer_scope=modality_scope, inner_scope="shared")
  vocab_size = modality._vocab_size  # pylint: disable=protected-access

  # Preprocess body output: in PREDICT mode with a pointwise top (and no
  # forced full prediction) a single position is enough.
  keep_single_position = (
      modality.top_is_pointwise and
      self.hparams.mode == tf.estimator.ModeKeys.PREDICT and
      not self.hparams.force_full_predict)
  if keep_single_position:
    if "decode_loop_step" in features:
      # Slice out the position given by the current decode loop step.
      static_shape = body_output.shape.as_list()
      begin = [0, features["decode_loop_step"][0], 0, 0]
      size = [static_shape[0], 1, static_shape[2], static_shape[3]]
      body_output = tf.slice(body_output, begin, size)
    else:
      # Take body outputs for the last position only.
      last_position = body_output[:, -1, :, :]
      body_output = tf.expand_dims(last_position, axis=[1])

  # Build logits.
  return model_utils.build_logits(
      sequences=body_output,
      embeddings=embeddings,
      vocab_size=vocab_size)
def _build_decoder_agreement_loss(self, central_lang_tag="<en>"):
  """Builds an agreement loss that enforces consistency of the decodings.

  Each auxiliary direction is re-decoded (in PREDICT mode) against the
  right-shifted decoder outputs of the other direction, and a cross-entropy
  between the resulting logits and the other direction's decoding is added
  to the loss.

  Args:
    central_lang_tag: A string with the tag of the central language.
      A ``central'' language (usually English) is the one that has parallel
      data with all other languages. It is used to protect supervised
      directions from gradients coming from auxiliary losses.

  Returns:
    loss: <float32> [] for the agreement losses.
  """
  # Get target embeddings and vocab size.
  modality = self._problem_hparams.modality["targets"]
  modality_scope = self._variable_scopes[modality.name]
  embeddings = model_utils.get_embeddings(
      modality=modality, outer_scope=modality_scope, inner_scope="shared")
  vocab_size = modality._vocab_size  # pylint: disable=protected-access

  # Build auxiliary sequences (if necessary).
  aux_keys = self._build_aux_sequences(
      embeddings, vocab_size, central_lang_tag=central_lang_tag)

  total = 0.
  with tf.name_scope("dec_agreement_loss"):
    # Pair every auxiliary direction with the opposite one.
    for student_key, teacher_key in zip(aux_keys, reversed(aux_keys)):
      # Prepare for decoding: the other direction's decoder outputs act as
      # the targets, the current direction's encoder outputs as the context.
      teacher_outputs = self.dec_outputs[teacher_key]["rnn_output"]
      teacher_length = self.dec_outputs[teacher_key]["length"]
      shifted_teacher = common_layers.shift_right_3d(teacher_outputs)
      student_encoding = self.enc_outputs[student_key]
      hiddens = student_encoding.outputs
      hiddens_length = self.inputs[student_key][1]
      enc_state = student_encoding.final_state

      # Decode.
      decode_func = self.get_decode_func(
          embeddings,
          shifted_teacher, teacher_length,
          hiddens, hiddens_length,
          enc_state,
          mode=tf.estimator.ModeKeys.PREDICT,
          decoder_iterations=self._hparams.aux_decode_length)
      student_dec_outputs = decode_func()

      # Compute logits (protect central directions from the gradients).
      student_logits = model_utils.build_logits(
          sequences=tf.expand_dims(
              student_dec_outputs["rnn_output"], axis=2),
          embeddings=embeddings,
          vocab_size=vocab_size)
      student_logits = tf.where(
          self._is_central[student_key],
          tf.stop_gradient(student_logits),
          student_logits)

      # Compute the loss against either soft or hard targets.
      logits = tf.squeeze(student_logits, axis=2)
      if not self._hparams.dec_agreement_loss_sparse:
        # Dense targets: softmax over the other direction's logits.
        teacher_logits = tf.squeeze(
            self.dec_outputs[teacher_key]["logits"], axis=2)
        teacher_probs = tf.nn.softmax(teacher_logits, axis=-1)
        total = total + losses.CrossEntropyLoss(sparse=False)(
            logits, teacher_probs, teacher_length)
      else:
        # Sparse targets: ids sampled by the other direction's decoder.
        teacher_ids = self.dec_outputs[teacher_key]["sample_id"]
        total = total + losses.CrossEntropyLoss(sparse=True)(
            logits, teacher_ids, teacher_length)

  return self._hparams.dec_agreement_coeff * total