Example No. 1
    def body(self, features):
        """Process features and produce outputs."""
        # Preprocess features.
        inputs, inputs_length, targets, targets_length = self._preprocess(
            features)

        # Encode.
        encoder = encoders.get(self._hparams.encoder_type)
        enc_outputs = encoder(inputs=inputs,
                              inputs_length=inputs_length,
                              mode=self._hparams.mode,
                              hparams=self._hparams)

        # Get target embeddings.
        target_modality = self._problem_hparams.modality["targets"]
        target_modality_scope = self._variable_scopes[target_modality.name]
        target_embeddings = model_utils.get_embeddings(
            modality=target_modality,
            outer_scope=target_modality_scope,
            inner_scope="shared")

        # Decode.
        decoder = decoders.get(self._hparams.decoder_type)
        decoder = decoder(embeddings=target_embeddings,
                          inputs=targets,
                          inputs_length=targets_length,
                          hiddens=enc_outputs.outputs,
                          hiddens_length=inputs_length,
                          enc_state=enc_outputs.final_state,
                          mode=self._hparams.mode,
                          hparams=self._hparams)
        dec_outputs, _, _ = contrib_seq2seq.dynamic_decode(decoder=decoder)

        return tf.expand_dims(dec_outputs.rnn_output, axis=2)
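A note on the final expand_dims: contrib_seq2seq.dynamic_decode returns
rnn_output shaped [batch, time, hidden], while the T2T-style "top"
transforms downstream (see Example No. 3) expect a 4-D body output
[batch, time, 1, hidden]. A minimal shape illustration with made-up sizes:

import tensorflow as tf

rnn_output = tf.zeros([8, 20, 512])          # [batch, time, hidden]
body_output = tf.expand_dims(rnn_output, 2)  # [batch, time, 1, hidden]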
Example No. 2
  def _build_decoder_lm_loss(self, central_lang_tag="<en>"):
    """Builds LM loss on the auxiliary decodings."""

    # Get target embeddings and vocab size.
    target_modality = self._problem_hparams.modality["targets"]
    target_modality_scope = self._variable_scopes[target_modality.name]
    target_embeddings = model_utils.get_embeddings(
        modality=target_modality,
        outer_scope=target_modality_scope,
        inner_scope="shared")
    target_vocab_size = target_modality._vocab_size  # pylint: disable=protected-access

    # Build auxiliary sequences (if necessary).
    aux_keys = self._build_aux_sequences(
        target_embeddings, target_vocab_size,
        central_lang_tag=central_lang_tag)

    # Make sure LM loss does not affect embeddings.
    target_embeddings = tf.stop_gradient(target_embeddings)

    # Build loss.
    aux_loss = 0.
    with tf.name_scope("aux_lm_loss"):
      for key in aux_keys:
        dec_outputs = tf.expand_dims(
            self.dec_outputs[key]["rnn_output"], axis=2)
        dec_output_tags = tf.expand_dims(
            self.inputs[key][0][:, :1], axis=2)
        dec_lengths = self.dec_outputs[key]["length"]
        # Preprocess LM features.
        lm_features = {
            "targets": dec_outputs,
            "target_tags": dec_output_tags}
        inputs, inputs_length = self._build_lm_inputs(lm_features)
        # Build LM (with frozen weights in PREDICT mode).
        lm_outputs = self.language_model(
            inputs=inputs,
            inputs_length=inputs_length,
            mode=tf.estimator.ModeKeys.PREDICT,
            hparams=self._hparams,
            trainable=False,
            reuse=tf.AUTO_REUSE)
        # Compute logits.
        lm_logits = model_utils.build_logits(
            sequences=tf.expand_dims(lm_outputs, axis=2),
            embeddings=target_embeddings,
            vocab_size=target_vocab_size)
        # Compute decoder probabilities.
        dec_logits = model_utils.build_logits(
            sequences=dec_outputs,
            embeddings=target_embeddings,
            vocab_size=target_vocab_size)
        dec_probs = tf.nn.softmax(dec_logits, axis=-1)
        # Compute cross-entropy loss.
        aux_loss = aux_loss + losses.CrossEntropyLoss(sparse=False)(
            lm_logits, dec_probs, dec_lengths)

    aux_loss = self._hparams.lm_loss_coeff * aux_loss

    return aux_loss
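The loss above is a distillation-style objective: the frozen language model
provides the logits, and the decoder's softmax distribution serves as the
soft target, nudging auxiliary decodings toward sequences the LM finds
likely. A minimal sketch of what losses.CrossEntropyLoss(sparse=False)
plausibly computes (the project's actual implementation may differ; shapes
here assume 3-D [batch, time, vocab] logits):

import tensorflow as tf

def soft_cross_entropy(logits, target_probs, lengths):
    # Token-level cross-entropy against soft targets, masked by length.
    xent = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=target_probs, logits=logits)             # [batch, time]
    mask = tf.sequence_mask(lengths, tf.shape(xent)[1], dtype=xent.dtype)
    return tf.reduce_sum(xent * mask, axis=1)            # [batch]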
Example No. 3
  def _top_single(self, body_output, target_modality, features):
    """Top transformation that ensures correct reuse of target embeddings."""
    t2t_model.log_info(
        "Transforming body output with %s.top", target_modality.name)

    # Get target embeddings.
    target_modality = self._problem_hparams.modality["targets"]
    target_modality_scope = self._variable_scopes[target_modality.name]
    target_embeddings = model_utils.get_embeddings(
        modality=target_modality,
        outer_scope=target_modality_scope,
        inner_scope="shared")
    target_vocab_size = target_modality._vocab_size  # pylint: disable=protected-access

    # Preprocess body output.
    last_only = (
        target_modality.top_is_pointwise and
        self.hparams.mode == tf.estimator.ModeKeys.PREDICT and
        not self.hparams.force_full_predict)
    if last_only:
      # Take body outputs for the last position only.
      if "decode_loop_step" not in features:
        body_output = tf.expand_dims(body_output[:, -1, :, :], axis=1)
      else:
        body_output_shape = body_output.shape.as_list()
        body_output = tf.slice(
            body_output, [0, features["decode_loop_step"][0], 0, 0], [
                body_output_shape[0], 1, body_output_shape[2],
                body_output_shape[3]
            ])

    # Build logits.
    logits = model_utils.build_logits(
        sequences=body_output,
        embeddings=target_embeddings,
        vocab_size=target_vocab_size)
    return logits
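model_utils.build_logits appears throughout these examples; given the shared
"inner_scope" embeddings, it presumably implements a tied softmax projection
onto the target embedding matrix. A sketch under that assumption
(build_logits_sketch is a hypothetical stand-in, not the project's API):

import tensorflow as tf

def build_logits_sketch(sequences, embeddings):
    # sequences:  [batch, time, 1, hidden] body output.
    # embeddings: [vocab_size, hidden] shared target embeddings.
    # Contract the hidden dimension against the embedding table,
    # yielding logits of shape [batch, time, 1, vocab_size].
    return tf.tensordot(sequences, embeddings, axes=[[3], [1]])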
Example No. 4
    def body(self, features):
        # Save a reference to the features to access in other methods.
        self._body_features = features

        # Preprocess features.
        self.inputs, self.targets = self._preprocess(features)

        # Ensure auxiliary parts of the graph are built when necessary.
        batch_size = common_layers.shape_list(features["inputs"])[0]
        global_step = model_utils.get_global_step(self._hparams)

        # Encode (src>tgt).
        key = "src>tgt"
        self.enc_outputs = {}
        self.encoder = encoders.get(self._hparams.encoder_type)
        encode_func = self.get_encode_func(*self.inputs[key])
        self.enc_outputs[key] = encode_func()

        # Get target embeddings.
        target_modality = self._problem_hparams.modality["targets"]
        target_modality_scope = self._variable_scopes[target_modality.name]
        target_embeddings = model_utils.get_embeddings(
            modality=target_modality,
            outer_scope=target_modality_scope,
            inner_scope="shared")

        # Decode (src>tgt).
        key = "src>tgt"
        self.decoders = {}
        self.dec_outputs = {}
        self.decoder = decoders.get(self._hparams.decoder_type)
        # Prepare for decoding.
        target_seqs, target_lens = self.targets[key]
        hiddens = self.enc_outputs[key].outputs
        hiddens_length = self.inputs[key][1]
        enc_state = self.enc_outputs[key].final_state
        # Decode.
        decode_func = self.get_decode_func(target_embeddings,
                                           target_seqs,
                                           target_lens,
                                           hiddens,
                                           hiddens_length,
                                           enc_state,
                                           mode=self._hparams.mode)
        self.dec_outputs[key] = decode_func()
        outputs = tf.expand_dims(self.dec_outputs[key]["rnn_output"], axis=2)

        # Construct agreement losses.
        aux_losses = {}
        if self._hparams.mode == tf.estimator.ModeKeys.TRAIN:
            if self._hparams.enc_agreement_coeff > 0:
                aux_losses["agreement_enc"] = tf.cond(
                    global_step > self._hparams.enc_agreement_enable_step,
                    self._build_encoder_agreement_loss,
                    lambda: tf.zeros([batch_size]))
            if self._hparams.dec_agreement_coeff > 0:
                aux_losses["agreement_dec"] = tf.cond(
                    global_step > self._hparams.dec_agreement_enable_step,
                    self._build_decoder_agreement_loss,
                    lambda: tf.zeros([batch_size]))

        return outputs, aux_losses
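The tf.cond pattern at the end keeps each auxiliary loss at zero until
training passes an hparams-controlled step threshold, so the supervised
objective trains alone first. Both branches of tf.cond must return tensors
of matching structure and dtype, which is why the disabled branch returns
tf.zeros([batch_size]) rather than a Python 0. The pattern in isolation
(names here are illustrative):

import tensorflow as tf

def gated_aux_loss(global_step, enable_step, build_loss_fn, batch_size):
    # build_loss_fn: a zero-argument callable returning a [batch] loss.
    return tf.cond(global_step > enable_step,
                   build_loss_fn,
                   lambda: tf.zeros([batch_size]))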
Example No. 5
    def _build_decoder_agreement_loss(self, central_lang_tag="<en>"):
        """Builds an agreement loss that enforces consistency of the decodings.

    Args:
      central_lang_tag: A string with the tag of the central language.
        A ``central'' language (usually English) is the one that has parallel
        data with all other languages. It is used to protect supervised
        directions from gradients coming from auxiliary losses.

    Returns:
      loss: <float32> [] for the agreement losses.
    """
        # Get target embeddings and vocab size.
        target_modality = self._problem_hparams.modality["targets"]
        target_modality_scope = self._variable_scopes[target_modality.name]
        target_embeddings = model_utils.get_embeddings(
            modality=target_modality,
            outer_scope=target_modality_scope,
            inner_scope="shared")
        target_vocab_size = target_modality._vocab_size  # pylint: disable=protected-access

        # Build auxiliary sequences (if necessary).
        aux_keys = self._build_aux_sequences(target_embeddings,
                                             target_vocab_size,
                                             central_lang_tag=central_lang_tag)

        # Build loss.
        aux_loss = 0.
        with tf.name_scope("dec_agreement_loss"):
            for key1, key2 in zip(aux_keys, aux_keys[::-1]):
                # Prepare for decoding.
                targets = self.dec_outputs[key2]["rnn_output"]
                targets_length = self.dec_outputs[key2]["length"]
                shifted_targets = common_layers.shift_right_3d(targets)
                hiddens = self.enc_outputs[key1].outputs
                hiddens_length = self.inputs[key1][1]
                enc_state = self.enc_outputs[key1].final_state
                # Decode.
                decode_func = self.get_decode_func(
                    target_embeddings,
                    shifted_targets,
                    targets_length,
                    hiddens,
                    hiddens_length,
                    enc_state,
                    mode=tf.estimator.ModeKeys.PREDICT,
                    decoder_iterations=self._hparams.aux_decode_length)
                aux_dec_outputs = decode_func()
                # Compute logits (protect central directions from the gradients).
                aux_logits_1 = model_utils.build_logits(
                    sequences=tf.expand_dims(aux_dec_outputs["rnn_output"],
                                             axis=2),
                    embeddings=target_embeddings,
                    vocab_size=target_vocab_size)
                aux_logits_1 = tf.where(self._is_central[key1],
                                        tf.stop_gradient(aux_logits_1),
                                        aux_logits_1)
                # Compute KL loss.
                logits = tf.squeeze(aux_logits_1, axis=2)
                if self._hparams.dec_agreement_loss_sparse:
                    target_ids = self.dec_outputs[key2]["sample_id"]
                    aux_loss = aux_loss + losses.CrossEntropyLoss(sparse=True)(
                        logits, target_ids, targets_length)
                else:
                    aux_logits_2 = tf.squeeze(self.dec_outputs[key2]["logits"],
                                              axis=2)
                    target_probs = tf.nn.softmax(aux_logits_2, axis=-1)
                    aux_loss = aux_loss + losses.CrossEntropyLoss(
                        sparse=False)(logits, target_probs, targets_length)

        aux_loss = self._hparams.dec_agreement_coeff * aux_loss

        return aux_loss
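The zip(aux_keys, aux_keys[::-1]) pairing is what makes this agreement loss
symmetric: each auxiliary direction is re-decoded from the other direction's
encoder states and pushed toward the other's output distribution, so every
(key1, key2) pair also appears as (key2, key1). With hypothetical keys:

aux_keys = ["src>aux", "aux>src"]
print(list(zip(aux_keys, aux_keys[::-1])))
# [('src>aux', 'aux>src'), ('aux>src', 'src>aux')]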