Example #1
  def _build_decoder_lm_loss(self, central_lang_tag="<en>"):
    """Builds LM loss on the auxiliary decodings."""

    # Get target embeddings and vocab size.
    target_modality = self._problem_hparams.modality["targets"]
    target_modality_scope = self._variable_scopes[target_modality.name]
    target_embeddings = model_utils.get_embeddings(
        modality=target_modality,
        outer_scope=target_modality_scope,
        inner_scope="shared")
    target_vocab_size = target_modality._vocab_size  # pylint: disable=protected-access

    # Build auxiliary sequences (if necessary).
    aux_keys = self._build_aux_sequences(
        target_embeddings, target_vocab_size,
        central_lang_tag=central_lang_tag)

    # Make sure LM loss does not affect embeddings.
    target_embeddings = tf.stop_gradient(target_embeddings)

    # Build loss.
    aux_loss = 0.
    with tf.name_scope("aux_lm_loss"):
      for key in aux_keys:
        dec_outputs = tf.expand_dims(
            self.dec_outputs[key]["rnn_output"], axis=2)
        dec_output_tags = tf.expand_dims(
            self.inputs[key][0][:, :1], axis=2)
        dec_lengths = self.dec_outputs[key]["length"]
        # Preprocess LM features.
        lm_features = {
            "targets": dec_outputs,
            "target_tags": dec_output_tags}
        inputs, inputs_length = self._build_lm_inputs(lm_features)
        # Build LM (with frozen weights in PREDICT mode).
        lm_outputs = self.language_model(
            inputs=inputs,
            inputs_length=inputs_length,
            mode=tf.estimator.ModeKeys.PREDICT,
            hparams=self._hparams,
            trainable=False,
            reuse=tf.AUTO_REUSE)
        # Compute logits.
        lm_logits = model_utils.build_logits(
            sequences=tf.expand_dims(lm_outputs, axis=2),
            embeddings=target_embeddings,
            vocab_size=target_vocab_size)
        # Compute decoder probabilities.
        dec_logits = model_utils.build_logits(
            sequences=dec_outputs,
            embeddings=target_embeddings,
            vocab_size=target_vocab_size)
        dec_probs = tf.nn.softmax(dec_logits, axis=-1)
        # Compute cross-entropy loss.
        aux_loss = aux_loss + losses.CrossEntropyLoss(sparse=False)(
            lm_logits, dec_probs, dec_lengths)

    aux_loss = self._hparams.lm_loss_coeff * aux_loss

    return aux_loss
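
Note: model_utils.build_logits above is a project-local helper, not part of TensorFlow. Below is a minimal sketch of what such a helper could look like, assuming the project ties the output projection to the (stop-gradiented) target embeddings; the names and shapes are assumptions for illustration, not the project's actual implementation.

import tensorflow as tf  # TF 1.x


def build_logits(sequences, embeddings, vocab_size):
  # Hypothetical weight-tied projection: multiply decoder states by the
  # transposed embedding matrix to obtain vocabulary logits.
  # sequences: <float32> [batch, time, 1, hidden]; embeddings: [vocab, hidden].
  hidden_size = embeddings.shape.as_list()[-1]
  flat = tf.reshape(sequences, [-1, hidden_size])
  flat_logits = tf.matmul(flat, embeddings, transpose_b=True)
  out_shape = tf.concat([tf.shape(sequences)[:-1], [vocab_size]], axis=0)
  return tf.reshape(flat_logits, out_shape)  # [batch, time, 1, vocab]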
Example #2
  def _build_aux_sequences(self,
                           target_embeddings,
                           target_vocab_size,
                           central_lang_tag="<en>"):
    """Builds sequences in an auxiliary language."""
    aux_keys = ["src>aux", "tgt>aux"]

    # Determine which src and tgt sentences are central.
    central_lang_id = translate_multilingual.get_tag_id(central_lang_tag)
    self._is_central = {
        "src>aux": tf.squeeze(
            tf.equal(self._body_features["input_tags_raw"], central_lang_id)),
        "tgt>aux": tf.squeeze(
            tf.equal(self._body_features["target_tags_raw"], central_lang_id)),
    }

    for key in aux_keys:
      # Encode (if necessary).
      if key not in self.enc_outputs:
        encode_func = self.get_encode_func(*self.inputs[key])
        self.enc_outputs[key] = encode_func()

      # Decode (if necessary).
      if key not in self.dec_outputs:
        # Prepare for decoding.
        target_seqs, target_lens = self.targets[key]
        hiddens = self.enc_outputs[key].outputs
        hiddens_length = self.inputs[key][1]
        enc_state = self.enc_outputs[key].final_state
        decoder_hparams = contrib_training.HParams(auxiliary=True)
        # Decode.
        decode_func = self.get_decode_func(
            target_embeddings,
            target_seqs,
            target_lens,
            hiddens,
            hiddens_length,
            enc_state,
            mode=self._hparams.mode,
            decoder_hparams=decoder_hparams,
            decoder_iterations=self._hparams.aux_decode_length)
        self.dec_outputs[key] = decode_func()
        # Compute logits.
        self.dec_outputs[key]["logits"] = model_utils.build_logits(
            sequences=tf.expand_dims(
                self.dec_outputs[key]["rnn_output"], axis=2),
            embeddings=target_embeddings,
            vocab_size=target_vocab_size)
        # Protect central directions from the gradients.
        for element in self.dec_outputs[key]:
          self.dec_outputs[key][element] = tf.where(
              self._is_central[key],
              tf.stop_gradient(self.dec_outputs[key][element]),
              self.dec_outputs[key][element])

    return aux_keys
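
The tf.where(is_central, tf.stop_gradient(x), x) pattern at the end of _build_aux_sequences leaves values unchanged in the forward pass but blocks gradients for the rows flagged as central. A standalone toy demonstration of the pattern (all names here are made up for the demo):

import tensorflow as tf  # TF 1.x

values = tf.get_variable("values", shape=[4, 8])
is_central = tf.constant([True, False, True, False])  # one flag per row
# Central rows are routed through stop_gradient, so the loss below yields
# zero gradient for rows 0 and 2 and the usual gradient for rows 1 and 3.
protected = tf.where(is_central, tf.stop_gradient(values), values)
loss = tf.reduce_sum(tf.square(protected))
grads = tf.gradients(loss, values)[0]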
Example #3
  def _top_single(self, body_output, target_modality, features):
    """Top transformation that ensures correct reuse of target embeddings."""
    t2t_model.log_info(
        "Transforming body output with %s.top", target_modality.name)

    # Get target embeddings.
    target_modality = self._problem_hparams.modality["targets"]
    target_modality_scope = self._variable_scopes[target_modality.name]
    target_embeddings = model_utils.get_embeddings(
        modality=target_modality,
        outer_scope=target_modality_scope,
        inner_scope="shared")
    target_vocab_size = target_modality._vocab_size  # pylint: disable=protected-access

    # Preprocess body output.
    last_only = (
        target_modality.top_is_pointwise and
        self.hparams.mode == tf.estimator.ModeKeys.PREDICT and
        not self.hparams.force_full_predict)
    if last_only:
      # Take body outputs for the last position only.
      if "decode_loop_step" not in features:
        body_output = tf.expand_dims(body_output[:, -1, :, :], axis=[1])
      else:
        body_output_shape = body_output.shape.as_list()
        body_output = tf.slice(
            body_output, [0, features["decode_loop_step"][0], 0, 0], [
                body_output_shape[0], 1, body_output_shape[2],
                body_output_shape[3]
            ])

    # Build logits.
    logits = model_utils.build_logits(
        sequences=body_output,
        embeddings=target_embeddings,
        vocab_size=target_vocab_size)
    return logits
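
In PREDICT mode with a pointwise top transformation, only the body output at the current decoding position is needed. A toy illustration of the tf.slice call above, with made-up shapes (in the real model, decode_loop_step is supplied by the decoding loop through features):

import tensorflow as tf  # TF 1.x

body_output = tf.zeros([2, 10, 1, 16])  # [batch, time, 1, hidden]
decode_loop_step = tf.constant([3])     # current position in the loop
shape = body_output.shape.as_list()
one_step = tf.slice(
    body_output, [0, decode_loop_step[0], 0, 0],
    [shape[0], 1, shape[2], shape[3]])  # -> shape [2, 1, 1, 16]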
Example #4
  def _build_decoder_agreement_loss(self, central_lang_tag="<en>"):
    """Builds an agreement loss that enforces consistency of the decodings.

    Args:
      central_lang_tag: A string with the tag of the central language.
        A ``central'' language (usually English) is the one that has parallel
        data with all other languages. It is used to protect supervised
        directions from gradients coming from auxiliary losses.

    Returns:
      loss: <float32> [] for the agreement losses.
    """
    # Get target embeddings and vocab size.
    target_modality = self._problem_hparams.modality["targets"]
    target_modality_scope = self._variable_scopes[target_modality.name]
    target_embeddings = model_utils.get_embeddings(
        modality=target_modality,
        outer_scope=target_modality_scope,
        inner_scope="shared")
    target_vocab_size = target_modality._vocab_size  # pylint: disable=protected-access

    # Build auxiliary sequences (if necessary).
    aux_keys = self._build_aux_sequences(
        target_embeddings, target_vocab_size,
        central_lang_tag=central_lang_tag)

    # Build loss.
    aux_loss = 0.
    with tf.name_scope("dec_agreement_loss"):
      for key1, key2 in zip(aux_keys, aux_keys[::-1]):
        # Prepare for decoding.
        targets = self.dec_outputs[key2]["rnn_output"]
        targets_length = self.dec_outputs[key2]["length"]
        shifted_targets = common_layers.shift_right_3d(targets)
        hiddens = self.enc_outputs[key1].outputs
        hiddens_length = self.inputs[key1][1]
        enc_state = self.enc_outputs[key1].final_state
        # Decode.
        decode_func = self.get_decode_func(
            target_embeddings,
            shifted_targets,
            targets_length,
            hiddens,
            hiddens_length,
            enc_state,
            mode=tf.estimator.ModeKeys.PREDICT,
            decoder_iterations=self._hparams.aux_decode_length)
        aux_dec_outputs = decode_func()
        # Compute logits (protect central directions from the gradients).
        aux_logits_1 = model_utils.build_logits(
            sequences=tf.expand_dims(aux_dec_outputs["rnn_output"], axis=2),
            embeddings=target_embeddings,
            vocab_size=target_vocab_size)
        aux_logits_1 = tf.where(
            self._is_central[key1],
            tf.stop_gradient(aux_logits_1),
            aux_logits_1)
        # Compute cross-entropy agreement loss.
        logits = tf.squeeze(aux_logits_1, axis=2)
        if self._hparams.dec_agreement_loss_sparse:
          target_ids = self.dec_outputs[key2]["sample_id"]
          aux_loss = aux_loss + losses.CrossEntropyLoss(sparse=True)(
              logits, target_ids, targets_length)
        else:
          aux_logits_2 = tf.squeeze(self.dec_outputs[key2]["logits"], axis=2)
          target_probs = tf.nn.softmax(aux_logits_2, axis=-1)
          aux_loss = aux_loss + losses.CrossEntropyLoss(sparse=False)(
              logits, target_probs, targets_length)

    aux_loss = self._hparams.dec_agreement_coeff * aux_loss

    return aux_loss
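
losses.CrossEntropyLoss is likewise project-local. A minimal sketch consistent with how it is called above, assuming logits of shape [batch, time, vocab], integer target ids for sparse=True, soft target distributions for sparse=False, and length-masked averaging; the project's real helper may differ:

import tensorflow as tf  # TF 1.x


class CrossEntropyLoss(object):
  # Hypothetical sketch of the losses.CrossEntropyLoss helper used above.

  def __init__(self, sparse=True):
    self._sparse = sparse

  def __call__(self, logits, targets, lengths):
    # logits: <float32> [batch, time, vocab]. targets: int ids [batch, time]
    # if sparse, else distributions [batch, time, vocab].
    if self._sparse:
      xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=targets, logits=logits)
    else:
      xent = tf.nn.softmax_cross_entropy_with_logits_v2(
          labels=targets, logits=logits)
    mask = tf.sequence_mask(lengths, tf.shape(logits)[1], dtype=xent.dtype)
    # Average cross-entropy over the non-padded positions.
    return tf.reduce_sum(xent * mask) / tf.maximum(tf.reduce_sum(mask), 1.)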