Example #1
0
    def _decode_target(
        self,
        labels,
        encoder_outputs,
        encoder_state,
        encoder_sequence_length,
        step=None,
        training=None,
    ):
        """Runs the decoder over the target inputs.

        When ``params["contrastive_learning"]`` is enabled and ``labels``
        carries ``"noisy_ids"``, the decoder is also run on that noisy target
        sequence, starting from the same initial state.

        Args:
          labels: Target labels. ``"noisy_ids"`` and ``"noisy_length"`` are read
            when contrastive learning is enabled.
          encoder_outputs: Encoder outputs, passed to the decoder as memory.
          encoder_state: Encoder state, used to build the decoder initial state.
          encoder_sequence_length: Lengths of ``encoder_outputs``.
          step: Training step, forwarded to the scheduled sampling schedule.
          training: Whether the model runs in training mode.

        Returns:
          A dict with ``"logits"`` and ``"attention"``, plus ``"noisy_logits"``
          when the noisy target sequence was decoded.
        """
        params = self.params
        target_inputs = self.labels_inputter(labels, training=training)

        def embed_ids(ids):
            # Maps token ids to decoder inputs; given to the decoder as input_fn.
            return self.labels_inputter({"ids": ids}, training=training)

        # Scheduled sampling is only configured in training mode.
        if training:
            sampling_probability = decoder_util.get_sampling_probability(
                step,
                read_probability=params.get(
                    "scheduled_sampling_read_probability"),
                schedule_type=params.get("scheduled_sampling_type"),
                k=params.get("scheduled_sampling_k"),
            )
        else:
            sampling_probability = None

        initial_state = self.decoder.initial_state(
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length,
            initial_state=encoder_state,
        )

        def run_decoder(inputs, length):
            # Both decoding passes share the initial state and sampling setup.
            return self.decoder(
                inputs,
                length,
                state=initial_state,
                input_fn=embed_ids,
                sampling_probability=sampling_probability,
                training=training,
            )

        logits, _, attention = run_decoder(
            target_inputs, self.labels_inputter.get_length(labels))
        outputs = {"logits": logits, "attention": attention}

        noisy_ids = labels.get("noisy_ids")
        if noisy_ids is None or not params.get("contrastive_learning"):
            return outputs

        # Contrastive learning: also forward the erroneous translation so its
        # log likelihood can be computed later.
        noisy_inputs = self.labels_inputter({"ids": noisy_ids},
                                            training=training)
        noisy_logits, _, _ = run_decoder(noisy_inputs, labels["noisy_length"])
        outputs["noisy_logits"] = noisy_logits
        return outputs
    def _build(self, features, labels, params, mode, config=None):
        """Builds the sequence-to-sequence graph for the given mode.

        The source is embedded and encoded under the ``"encoder"`` variable
        scope. When ``labels`` are given, the decoder is run on the target
        inputs under the ``"decoder"`` scope. Outside training mode, the decoder
        is also run in inference mode (``dynamic_decode`` when
        ``beam_width <= 1``, ``dynamic_decode_and_search`` otherwise) to build
        the predictions.

        Args:
          features: Source features; ``features["tokens"]`` is read when
            ``replace_unknown_target`` is enabled.
          labels: Target labels, or ``None``.
          params: Dictionary of user parameters (scheduled sampling, beam
            search, decoding lengths, sampling, unknown target replacement).
          mode: A ``tf.estimator.ModeKeys`` value.
          config: Optional run configuration whose ``model_dir`` is passed to
            the inputters as ``log_dir``.

        Returns:
          A tuple ``(outputs, predictions)``. ``outputs`` is ``None`` without
          labels, otherwise the decoder logits, or a dict with ``"logits"`` and
          ``"attention"`` when ``labels`` contains ``"alignment"``.
          ``predictions`` is ``None`` in training mode, otherwise a dict with
          ``"tokens"``, ``"length"``, ``"log_probs"`` and, when available,
          ``"alignment"``.

        Raises:
          TypeError: if ``replace_unknown_target`` is set but the decoder
            returns no alignment history, or the source inputter is not an
            ``inputters.WordEmbedder``.
        """
        features_length = self.features_inputter.get_length(features)
        log_dir = config.model_dir if config is not None else None

        source_input_scope = self._get_input_scope(default_name="encoder")
        target_input_scope = self._get_input_scope(default_name="decoder")

        # Embed the source under source_input_scope; the variable reuse policy
        # is defined by _maybe_reuse_embedding_fn.
        source_inputs = _maybe_reuse_embedding_fn(
            lambda ids: self.source_inputter.transform_data(
                ids, mode=mode, log_dir=log_dir),
            scope=source_input_scope)(features)

        with tf.variable_scope("encoder"):
            encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
                source_inputs, sequence_length=features_length, mode=mode)

        target_vocab_size = self.target_inputter.vocabulary_size
        target_dtype = self.target_inputter.dtype
        # Maps token ids to decoder inputs; used by the labels branch and by
        # the inference decoders below.
        target_embedding_fn = _maybe_reuse_embedding_fn(
            lambda ids: self.target_inputter.make_inputs(
                {"ids": ids}, training=mode == tf.estimator.ModeKeys.TRAIN),
            scope=target_input_scope)

        if labels is not None:
            target_inputs = _maybe_reuse_embedding_fn(
                lambda ids: self.target_inputter.transform_data(
                    ids, mode=mode, log_dir=log_dir),
                scope=target_input_scope)(labels)

            with tf.variable_scope("decoder"):
                # Scheduled sampling is only configured in training mode.
                sampling_probability = None
                if mode == tf.estimator.ModeKeys.TRAIN:
                    sampling_probability = get_sampling_probability(
                        tf.train.get_or_create_global_step(),
                        read_probability=params.get(
                            "scheduled_sampling_read_probability"),
                        schedule_type=params.get("scheduled_sampling_type"),
                        k=params.get("scheduled_sampling_k"))

                logits, _, _, attention = self.decoder.decode(
                    target_inputs,
                    self.labels_inputter.get_length(labels),
                    vocab_size=target_vocab_size,
                    initial_state=encoder_state,
                    sampling_probability=sampling_probability,
                    embedding=target_embedding_fn,
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    return_alignment_history=True)
                # The attention is returned alongside the logits only when the
                # labels carry an "alignment" entry.
                if "alignment" in labels:
                    outputs = {"logits": logits, "attention": attention}
                else:
                    outputs = logits
        else:
            outputs = None

        if mode != tf.estimator.ModeKeys.TRAIN:
            # Reuse the decoder variables when the labels branch above already
            # created them.
            with tf.variable_scope("decoder", reuse=labels is not None):
                batch_size = tf.shape(
                    tf.contrib.framework.nest.flatten(encoder_outputs)[0])[0]
                beam_width = params.get("beam_width", 1)
                maximum_iterations = params.get("maximum_iterations", 250)
                minimum_length = params.get("minimum_decoding_length", 0)
                sample_from = params.get("sampling_topk", 1)
                start_tokens = tf.fill([batch_size],
                                       constants.START_OF_SENTENCE_ID)
                end_token = constants.END_OF_SENTENCE_ID

                if beam_width <= 1:
                    # Single-hypothesis decoding; sample_from is forwarded
                    # to the decoder.
                    sampled_ids, _, sampled_length, log_probs, alignment = self.decoder.dynamic_decode(
                        target_embedding_fn,
                        start_tokens,
                        end_token,
                        vocab_size=target_vocab_size,
                        initial_state=encoder_state,
                        maximum_iterations=maximum_iterations,
                        minimum_length=minimum_length,
                        mode=mode,
                        memory=encoder_outputs,
                        memory_sequence_length=encoder_sequence_length,
                        dtype=target_dtype,
                        return_alignment_history=True,
                        sample_from=sample_from)
                else:
                    # Beam search decoding with beam_width hypotheses.
                    length_penalty = params.get("length_penalty", 0)
                    sampled_ids, _, sampled_length, log_probs, alignment = (
                        self.decoder.dynamic_decode_and_search(
                            target_embedding_fn,
                            start_tokens,
                            end_token,
                            vocab_size=target_vocab_size,
                            initial_state=encoder_state,
                            beam_width=beam_width,
                            length_penalty=length_penalty,
                            maximum_iterations=maximum_iterations,
                            minimum_length=minimum_length,
                            mode=mode,
                            memory=encoder_outputs,
                            memory_sequence_length=encoder_sequence_length,
                            dtype=target_dtype,
                            return_alignment_history=True,
                            sample_from=sample_from))

            # Convert the decoded ids back to token strings.
            target_vocab_rev = self.target_inputter.vocabulary_lookup_reverse()
            target_tokens = target_vocab_rev.lookup(
                tf.cast(sampled_ids, tf.int64))

            if params.get("replace_unknown_target", False):
                if alignment is None:
                    raise TypeError(
                        "replace_unknown_target is not compatible with decoders "
                        "that don't return alignment history")
                if not isinstance(self.source_inputter,
                                  inputters.WordEmbedder):
                    raise TypeError(
                        "replace_unknown_target is only defined when the source "
                        "inputter is a WordEmbedder")
                source_tokens = features["tokens"]
                if beam_width > 1:
                    # Repeat each source sequence beam_width times so it lines
                    # up with the merged batch*beam hypotheses below.
                    source_tokens = tf.contrib.seq2seq.tile_batch(
                        source_tokens, multiplier=beam_width)
                # Merge batch and beam dimensions.
                original_shape = tf.shape(target_tokens)
                target_tokens = tf.reshape(target_tokens,
                                           [-1, original_shape[-1]])
                attention = tf.reshape(
                    alignment,
                    [-1, tf.shape(alignment)[2],
                     tf.shape(alignment)[3]])
                replaced_target_tokens = replace_unknown_target(
                    target_tokens, source_tokens, attention)
                # Restore the original (unmerged) token shape.
                target_tokens = tf.reshape(replaced_target_tokens,
                                           original_shape)

            predictions = {
                "tokens": target_tokens,
                "length": sampled_length,
                "log_probs": log_probs
            }
            if alignment is not None:
                predictions["alignment"] = alignment
        else:
            predictions = None

        return outputs, predictions
    def _call(self, features, labels, params, mode):
        """Builds the training outputs and, outside training, the predictions.

        The source is embedded with ``self.features_inputter`` and encoded under
        the ``"encoder"`` variable scope. With ``labels``, the decoder is run on
        the target inputs under the ``"decoder"`` scope. Outside training mode,
        ``dynamic_decode_and_search`` produces the predicted tokens.

        Args:
          features: Source features; ``features["tokens"]`` is read when
            ``replace_unknown_target`` is enabled.
          labels: Target labels, or ``None``.
          params: Dictionary of user parameters (scheduled sampling, beam
            search, sampling, unknown target replacement).
          mode: A ``tf.estimator.ModeKeys`` value.

        Returns:
          A tuple ``(outputs, predictions)``. ``outputs`` is ``None`` without
          labels, otherwise the logits, or a dict with ``"logits"`` and
          ``"attention"`` when ``labels`` contains ``"alignment"``.
          ``predictions`` is ``None`` in training mode, otherwise a dict with
          ``"tokens"``, ``"length"``, ``"log_probs"`` and, when available,
          ``"alignment"``.

        Raises:
          TypeError: if ``replace_unknown_target`` is set but the decoder
            returns no alignment history, or the source inputter is not an
            ``inputters.WordEmbedder``.
        """
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        source_length = self.features_inputter.get_length(features)
        source_inputs = self.features_inputter.make_inputs(
            features, training=is_training)
        with tf.variable_scope("encoder"):
            encoded = self.encoder.encode(
                source_inputs, sequence_length=source_length, mode=mode)
            encoder_outputs, encoder_state, encoder_sequence_length = encoded

        target_vocab_size = self.labels_inputter.vocabulary_size
        target_dtype = self.labels_inputter.dtype

        outputs = None
        if labels is not None:
            target_inputs = self.labels_inputter.make_inputs(
                labels, training=is_training)
            with tf.variable_scope("decoder"):
                # Scheduled sampling is only configured while training.
                sampling_probability = None
                if is_training:
                    sampling_probability = get_sampling_probability(
                        tf.train.get_or_create_global_step(),
                        read_probability=params.get(
                            "scheduled_sampling_read_probability"),
                        schedule_type=params.get("scheduled_sampling_type"),
                        k=params.get("scheduled_sampling_k"))
                logits, _, _, attention = self.decoder.decode(
                    target_inputs,
                    self.labels_inputter.get_length(labels),
                    vocab_size=target_vocab_size,
                    initial_state=encoder_state,
                    sampling_probability=sampling_probability,
                    embedding=self.labels_inputter.embedding,
                    output_layer=self.output_layer,
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    return_alignment_history=True)
            # The attention is returned alongside the logits only when the
            # labels carry an "alignment" entry.
            if "alignment" in labels:
                outputs = {"logits": logits, "attention": attention}
            else:
                outputs = logits

        if is_training:
            return outputs, None

        # Inference: reuse the decoder variables if they were created above.
        with tf.variable_scope("decoder", reuse=labels is not None):
            batch_size = tf.shape(
                tf.contrib.framework.nest.flatten(encoder_outputs)[0])[0]
            beam_width = params.get("beam_width", 1)
            start_tokens = tf.fill([batch_size],
                                   constants.START_OF_SENTENCE_ID)
            embedding = self.labels_inputter.embedding
            search_options = {
                "vocab_size": target_vocab_size,
                "initial_state": encoder_state,
                "output_layer": self.output_layer,
                "beam_width": beam_width,
                "length_penalty": params.get("length_penalty", 0),
                "maximum_iterations": params.get("maximum_iterations", 250),
                "minimum_length": params.get("minimum_decoding_length", 0),
                "mode": mode,
                "memory": encoder_outputs,
                "memory_sequence_length": encoder_sequence_length,
                "dtype": target_dtype,
                "return_alignment_history": True,
                "sample_from": params.get("sampling_topk"),
                "sample_temperature": params.get("sampling_temperature"),
            }
            (sampled_ids, _, sampled_length, log_probs,
             alignment) = self.decoder.dynamic_decode_and_search(
                 embedding,
                 start_tokens,
                 constants.END_OF_SENTENCE_ID,
                 **search_options)

        vocab_reverse = self.labels_inputter.vocabulary_lookup_reverse()
        target_tokens = vocab_reverse.lookup(tf.cast(sampled_ids, tf.int64))

        if params.get("replace_unknown_target", False):
            if alignment is None:
                raise TypeError(
                    "replace_unknown_target is not compatible with decoders "
                    "that don't return alignment history")
            if not isinstance(self.features_inputter, inputters.WordEmbedder):
                raise TypeError(
                    "replace_unknown_target is only defined when the source "
                    "inputter is a WordEmbedder")
            source_tokens = features["tokens"]
            if beam_width > 1:
                # One copy of each source sequence per beam hypothesis.
                source_tokens = tf.contrib.seq2seq.tile_batch(
                    source_tokens, multiplier=beam_width)
            # Fold the beam dimension into the batch dimension.
            tokens_shape = tf.shape(target_tokens)
            flat_tokens = tf.reshape(target_tokens, [-1, tokens_shape[-1]])
            alignment_dims = shape_list(alignment)
            flat_attention = tf.reshape(alignment, [
                alignment_dims[0] * alignment_dims[1], alignment_dims[2],
                alignment_dims[3]
            ])
            # The alignment history has no step for </s>: make the attention
            # time dimension match the tokens time dimension.
            flat_attention = reducer.align_in_time(flat_attention,
                                                   tf.shape(flat_tokens)[1])
            replaced_tokens = replace_unknown_target(
                flat_tokens, source_tokens, flat_attention)
            target_tokens = tf.reshape(replaced_tokens, tokens_shape)

        predictions = {
            "tokens": target_tokens,
            "length": sampled_length,
            "log_probs": log_probs
        }
        if alignment is not None:
            predictions["alignment"] = alignment
        return outputs, predictions
  def _call(self, features, labels, params, mode):
    """Builds the training outputs and, outside training, the predictions.

    Args:
      features: Source features; ``features["tokens"]`` is read when
        ``replace_unknown_target`` is enabled.
      labels: Target labels, or ``None``. ``labels["length"]`` is read, and
        ``"noisy_ids"``/``"noisy_length"`` when contrastive learning is enabled.
      params: Dictionary of user parameters (scheduled sampling, beam search,
        sampling, coverage penalty, unknown target replacement, decoding
        noise, number of hypotheses, contrastive learning).
      mode: A ``tf.estimator.ModeKeys`` value.

    Returns:
      A tuple ``(outputs, predictions)``. ``outputs`` is ``None`` without
      labels; otherwise the logits, or a dict with ``"logits"`` and
      ``"attention"`` when ``labels`` contains ``"alignment"``, plus
      ``"noisy_logits"`` when the noisy target was decoded. ``predictions`` is
      ``None`` in training mode, otherwise a dict with ``"tokens"``,
      ``"length"``, ``"log_probs"`` and, when available, ``"alignment"``.

    Raises:
      TypeError: if ``replace_unknown_target`` is set but the decoder returns
        no alignment history, or the source inputter is not an
        ``inputters.WordEmbedder``.
      ValueError: if ``num_hypotheses`` is greater than ``beam_width``.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN

    features_length = self.features_inputter.get_length(features)
    source_inputs = self.features_inputter.make_inputs(features, training=training)
    with tf.variable_scope("encoder"):
      encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
          source_inputs,
          sequence_length=features_length,
          mode=mode)

    target_vocab_size = self.labels_inputter.vocabulary_size
    target_dtype = self.labels_inputter.dtype
    if labels is not None:
      # Scheduled sampling is only configured in training mode.
      sampling_probability = None
      if mode == tf.estimator.ModeKeys.TRAIN:
        sampling_probability = get_sampling_probability(
            tf.train.get_or_create_global_step(),
            read_probability=params.get("scheduled_sampling_read_probability"),
            schedule_type=params.get("scheduled_sampling_type"),
            k=params.get("scheduled_sampling_k"))

      def _decode_inputs(inputs, length, reuse=None):
        """Runs the decoder on ``inputs`` under the ``"decoder"`` scope.

        ``reuse=True`` shares the variables created by a previous call.
        """
        with tf.variable_scope("decoder", reuse=reuse):
          return self.decoder.decode(
              inputs,
              length,
              vocab_size=target_vocab_size,
              initial_state=encoder_state,
              sampling_probability=sampling_probability,
              embedding=self.labels_inputter.embedding,
              output_layer=self.output_layer,
              mode=mode,
              memory=encoder_outputs,
              memory_sequence_length=encoder_sequence_length,
              return_alignment_history=True)

      target_inputs = self.labels_inputter.make_inputs(labels, training=training)
      logits, _, _, attention = _decode_inputs(target_inputs, labels["length"])
      # The attention is returned alongside the logits only when the labels
      # carry an "alignment" entry.
      if "alignment" in labels:
        outputs = {
            "logits": logits,
            "attention": attention
        }
      else:
        outputs = logits

      noisy_ids = labels.get("noisy_ids")
      if noisy_ids is not None and params.get("contrastive_learning"):
        # In case of contrastive learning, also forward the erroneous
        # translation to compute its log likelihood later.
        noisy_inputs = self.labels_inputter.make_inputs({"ids": noisy_ids}, training=training)
        noisy_logits = _decode_inputs(noisy_inputs, labels["noisy_length"], reuse=True)[0]
        # Promote plain logits to a dict so the noisy logits can be attached.
        if not isinstance(outputs, dict):
          outputs = dict(logits=outputs)
        outputs["noisy_logits"] = noisy_logits
    else:
      outputs = None

    if mode != tf.estimator.ModeKeys.TRAIN:
      # Reuse the decoder variables when the labels branch already created them.
      with tf.variable_scope("decoder", reuse=labels is not None):
        batch_size = tf.shape(tf.contrib.framework.nest.flatten(encoder_outputs)[0])[0]
        beam_width = params.get("beam_width", 1)
        start_tokens = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
        end_token = constants.END_OF_SENTENCE_ID
        sampled_ids, _, sampled_length, log_probs, alignment = (
            self.decoder.dynamic_decode_and_search(
                self.labels_inputter.embedding,
                start_tokens,
                end_token,
                vocab_size=target_vocab_size,
                initial_state=encoder_state,
                output_layer=self.output_layer,
                beam_width=beam_width,
                length_penalty=params.get("length_penalty", 0),
                maximum_iterations=params.get("maximum_iterations", 250),
                minimum_length=params.get("minimum_decoding_length", 0),
                mode=mode,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length,
                dtype=target_dtype,
                return_alignment_history=True,
                sample_from=params.get("sampling_topk"),
                sample_temperature=params.get("sampling_temperature"),
                coverage_penalty=params.get("coverage_penalty", 0)))

      # Convert the decoded ids back to token strings.
      target_vocab_rev = self.labels_inputter.vocabulary_lookup_reverse()
      target_tokens = target_vocab_rev.lookup(tf.cast(sampled_ids, tf.int64))

      if params.get("replace_unknown_target", False):
        if alignment is None:
          raise TypeError("replace_unknown_target is not compatible with decoders "
                          "that don't return alignment history")
        if not isinstance(self.features_inputter, inputters.WordEmbedder):
          raise TypeError("replace_unknown_target is only defined when the source "
                          "inputter is a WordEmbedder")
        source_tokens = features["tokens"]
        if beam_width > 1:
          # One copy of each source sequence per beam hypothesis.
          source_tokens = tf.contrib.seq2seq.tile_batch(source_tokens, multiplier=beam_width)
        # Merge batch and beam dimensions.
        original_shape = tf.shape(target_tokens)
        target_tokens = tf.reshape(target_tokens, [-1, original_shape[-1]])
        align_shape = shape_list(alignment)
        attention = tf.reshape(
            alignment, [align_shape[0] * align_shape[1], align_shape[2], align_shape[3]])
        # We don't have attention for </s> but ensure that the attention time dimension matches
        # the tokens time dimension.
        attention = reducer.align_in_time(attention, tf.shape(target_tokens)[1])
        replaced_target_tokens = replace_unknown_target(target_tokens, source_tokens, attention)
        target_tokens = tf.reshape(replaced_target_tokens, original_shape)

      decoding_noise = params.get("decoding_noise")
      if decoding_noise:
        # The noise is applied to the hypotheses without the trailing </s>,
        # which is counted back into the length afterwards.
        sampled_length -= 1  # Ignore </s>
        target_tokens, sampled_length = _add_noise(
            target_tokens,
            sampled_length,
            decoding_noise,
            params.get("decoding_subword_token", "■"))
        sampled_length += 1
        alignment = None  # Invalidate alignments.

      predictions = {
          "tokens": target_tokens,
          "length": sampled_length,
          "log_probs": log_probs
      }
      if alignment is not None:
        predictions["alignment"] = alignment

      num_hypotheses = params.get("num_hypotheses", 1)
      if num_hypotheses > 0:
        if num_hypotheses > beam_width:
          raise ValueError("n_best cannot be greater than beam_width")
        # Keep the first num_hypotheses entries along axis 1 of every
        # prediction (presumably the beam axis -- TODO confirm).
        for key, value in six.iteritems(predictions):
          predictions[key] = value[:, :num_hypotheses]
    else:
      predictions = None

    return outputs, predictions
Example #5
0
  def _build(self, features, labels, params, mode, config=None):
    """Builds the encoder-decoder graph for the given mode.

    Args:
      features: Source features, transformed by ``self.source_inputter``.
      labels: Target labels, or ``None`` when no targets are available.
      params: Dictionary of user parameters (scheduled sampling, beam search,
        maximum decoding length, length penalty).
      mode: A ``tf.estimator.ModeKeys`` value.
      config: Optional run configuration. When set, its ``model_dir`` is
        passed to the inputters as ``log_dir``; ``None`` is accepted.

    Returns:
      A tuple ``(logits, predictions)``. ``logits`` is ``None`` when
      ``labels`` is ``None``. ``predictions`` is ``None`` in training mode and
      otherwise a dict with the ``"tokens"``, ``"length"`` and ``"log_probs"``
      of the decoded sequences.
    """
    features_length = self._get_features_length(features)
    # Tolerate a missing run configuration instead of failing on
    # ``None.model_dir``.
    log_dir = config.model_dir if config is not None else None

    with tf.variable_scope("encoder"):
      source_inputs = self.source_inputter.transform_data(
          features,
          mode=mode,
          log_dir=log_dir)
      encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
          source_inputs,
          sequence_length=features_length,
          mode=mode)

    target_vocab_size = self.target_inputter.vocabulary_size

    with tf.variable_scope("decoder") as decoder_scope:
      if labels is not None:
        # Scheduled sampling is a training-time technique: only enable it in
        # TRAIN mode so that evaluation is not affected by the sampling
        # schedule.
        sampling_probability = None
        if mode == tf.estimator.ModeKeys.TRAIN:
          sampling_probability = get_sampling_probability(
              tf.train.get_or_create_global_step(),
              read_probability=params.get("scheduled_sampling_read_probability"),
              schedule_type=params.get("scheduled_sampling_type"),
              k=params.get("scheduled_sampling_k"))

        target_inputs = self.target_inputter.transform_data(
            labels,
            mode=mode,
            log_dir=log_dir)
        logits, _, _ = self.decoder.decode(
            target_inputs,
            self._get_labels_length(labels),
            target_vocab_size,
            initial_state=encoder_state,
            sampling_probability=sampling_probability,
            embedding=self._scoped_target_embedding_fn(mode, decoder_scope),
            mode=mode,
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length)
      else:
        logits = None

    if mode != tf.estimator.ModeKeys.TRAIN:
      # Reuse the decoder variables when the labels branch above created them.
      with tf.variable_scope(decoder_scope, reuse=labels is not None) as decoder_scope:
        batch_size = tf.shape(encoder_sequence_length)[0]
        beam_width = params.get("beam_width", 1)
        maximum_iterations = params.get("maximum_iterations", 250)
        start_tokens = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
        end_token = constants.END_OF_SENTENCE_ID

        if beam_width <= 1:
          sampled_ids, _, sampled_length, log_probs = self.decoder.dynamic_decode(
              self._scoped_target_embedding_fn(mode, decoder_scope),
              start_tokens,
              end_token,
              target_vocab_size,
              initial_state=encoder_state,
              maximum_iterations=maximum_iterations,
              mode=mode,
              memory=encoder_outputs,
              memory_sequence_length=encoder_sequence_length)
        else:
          length_penalty = params.get("length_penalty", 0)
          sampled_ids, _, sampled_length, log_probs = self.decoder.dynamic_decode_and_search(
              self._scoped_target_embedding_fn(mode, decoder_scope),
              start_tokens,
              end_token,
              target_vocab_size,
              initial_state=encoder_state,
              beam_width=beam_width,
              length_penalty=length_penalty,
              maximum_iterations=maximum_iterations,
              mode=mode,
              memory=encoder_outputs,
              memory_sequence_length=encoder_sequence_length)

      # Ids outside the vocabulary file (e.g. OOV buckets) are looked up as
      # constants.UNKNOWN_TOKEN.
      target_vocab_rev = tf.contrib.lookup.index_to_string_table_from_file(
          self.target_inputter.vocabulary_file,
          vocab_size=target_vocab_size - self.target_inputter.num_oov_buckets,
          default_value=constants.UNKNOWN_TOKEN)

      predictions = {
          "tokens": target_vocab_rev.lookup(tf.cast(sampled_ids, tf.int64)),
          "length": sampled_length,
          "log_probs": log_probs
      }
    else:
      predictions = None

    return logits, predictions
Example #6
0
    def testSamplingProbability(self):
        """Checks get_sampling_probability on invalid and valid schedules."""
        get_probability = decoder.get_sampling_probability
        step = tf.constant(5, dtype=tf.int64)
        large_step = tf.constant(1000, dtype=tf.int64)

        # Neither a read probability nor a schedule: sampling is disabled.
        self.assertIsNone(get_probability(step))

        # Incomplete or unknown schedule configurations are rejected.
        invalid_configurations = (
            (ValueError, {"schedule_type": "linear"}),
            (ValueError, {"schedule_type": "linear", "k": 1}),
            (TypeError, {"schedule_type": "foo", "k": 1}),
        )
        for expected_error, options in invalid_configurations:
            with self.assertRaises(expected_error):
                get_probability(step, **options)

        constant_probability = get_probability(step, read_probability=0.9)
        scheduled_cases = [
            (0.5, get_probability(
                step, read_probability=1.0, schedule_type="linear", k=0.1)),
            # A read probability above 1 behaves like 1.
            (0.5, get_probability(
                step, read_probability=2.0, schedule_type="linear", k=0.1)),
            (1.0, get_probability(
                large_step, read_probability=1.0, schedule_type="linear",
                k=0.1)),
            (1.0 - pow(0.8, 5), get_probability(
                step, schedule_type="exponential", k=0.8)),
            (1.0 - (1.0 / (1.0 + math.exp(5.0 / 1.0))), get_probability(
                step, schedule_type="inverse_sigmoid", k=1)),
        ]

        self.assertAlmostEqual(0.1, constant_probability)
        for expected, probability in scheduled_cases:
            self.assertAlmostEqual(expected, self.evaluate(probability))
  def _build(self, features, labels, params, mode, config=None):
    """Builds the encoder-decoder graph for the given mode.

    Args:
      features: Source features, transformed by ``self.source_inputter``.
      labels: Target labels, or ``None`` when no targets are available.
      params: Dictionary of user parameters (scheduled sampling, beam search,
        maximum decoding length, length penalty).
      mode: A ``tf.estimator.ModeKeys`` value.
      config: Optional run configuration. When set, its ``model_dir`` is
        passed to the inputters as ``log_dir``; ``None`` is accepted.

    Returns:
      A tuple ``(logits, predictions)``. ``logits`` is ``None`` when
      ``labels`` is ``None``. ``predictions`` is ``None`` in training mode and
      otherwise a dict with the ``"tokens"``, ``"length"`` and ``"log_probs"``
      of the decoded sequences.
    """
    features_length = self._get_features_length(features)
    # Tolerate a missing run configuration instead of failing on
    # ``None.model_dir``.
    log_dir = config.model_dir if config is not None else None

    with tf.variable_scope("encoder"):
      source_inputs = self.source_inputter.transform_data(
          features,
          mode=mode,
          log_dir=log_dir)
      encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
          source_inputs,
          sequence_length=features_length,
          mode=mode)

    target_vocab_size = self.target_inputter.vocabulary_size
    target_dtype = self.target_inputter.dtype

    with tf.variable_scope("decoder") as decoder_scope:
      if labels is not None:
        # Scheduled sampling is a training-time technique: only enable it in
        # TRAIN mode so that evaluation is not affected by the sampling
        # schedule.
        sampling_probability = None
        if mode == tf.estimator.ModeKeys.TRAIN:
          sampling_probability = get_sampling_probability(
              tf.train.get_or_create_global_step(),
              read_probability=params.get("scheduled_sampling_read_probability"),
              schedule_type=params.get("scheduled_sampling_type"),
              k=params.get("scheduled_sampling_k"))

        target_inputs = self.target_inputter.transform_data(
            labels,
            mode=mode,
            log_dir=log_dir)
        logits, _, _ = self.decoder.decode(
            target_inputs,
            self._get_labels_length(labels),
            target_vocab_size,
            initial_state=encoder_state,
            sampling_probability=sampling_probability,
            embedding=self._scoped_target_embedding_fn(mode, decoder_scope),
            mode=mode,
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length)
      else:
        logits = None

    if mode != tf.estimator.ModeKeys.TRAIN:
      # Reuse the decoder variables when the labels branch above created them.
      with tf.variable_scope(decoder_scope, reuse=labels is not None) as decoder_scope:
        batch_size = tf.shape(encoder_sequence_length)[0]
        beam_width = params.get("beam_width", 1)
        maximum_iterations = params.get("maximum_iterations", 250)
        start_tokens = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
        end_token = constants.END_OF_SENTENCE_ID

        if beam_width <= 1:
          sampled_ids, _, sampled_length, log_probs = self.decoder.dynamic_decode(
              self._scoped_target_embedding_fn(mode, decoder_scope),
              start_tokens,
              end_token,
              target_vocab_size,
              initial_state=encoder_state,
              maximum_iterations=maximum_iterations,
              mode=mode,
              memory=encoder_outputs,
              memory_sequence_length=encoder_sequence_length,
              dtype=target_dtype)
        else:
          length_penalty = params.get("length_penalty", 0)
          sampled_ids, _, sampled_length, log_probs = self.decoder.dynamic_decode_and_search(
              self._scoped_target_embedding_fn(mode, decoder_scope),
              start_tokens,
              end_token,
              target_vocab_size,
              initial_state=encoder_state,
              beam_width=beam_width,
              length_penalty=length_penalty,
              maximum_iterations=maximum_iterations,
              mode=mode,
              memory=encoder_outputs,
              memory_sequence_length=encoder_sequence_length,
              dtype=target_dtype)

      # Ids outside the vocabulary file (e.g. OOV buckets) are looked up as
      # constants.UNKNOWN_TOKEN.
      target_vocab_rev = tf.contrib.lookup.index_to_string_table_from_file(
          self.target_inputter.vocabulary_file,
          vocab_size=target_vocab_size - self.target_inputter.num_oov_buckets,
          default_value=constants.UNKNOWN_TOKEN)

      predictions = {
          "tokens": target_vocab_rev.lookup(tf.cast(sampled_ids, tf.int64)),
          "length": sampled_length,
          "log_probs": log_probs
      }
    else:
      predictions = None

    return logits, predictions
    def _build(self, features, labels, params, mode, config=None):
        """Builds the encoder-decoder graph for ``mode``.

        The source features are always encoded. When ``labels`` are given, the
        decoder is run on the target inputs to produce ``logits``. Outside of
        training, the decoder is additionally run in inference mode (greedy
        decoding when ``beam_width <= 1``, beam search otherwise) to produce
        ``predictions``.

        Args:
          features: The source features, as consumed by
            ``self.source_inputter``.
          labels: The target labels, or ``None`` when no references are
            available.
          params: A dict of hyperparameters. Keys read here:
            ``scheduled_sampling_read_probability``,
            ``scheduled_sampling_type``, ``scheduled_sampling_k``,
            ``beam_width`` (default 1), ``maximum_iterations`` (default 250),
            ``length_penalty`` (default 0) and ``replace_unknown_target``
            (default False).
          mode: A ``tf.estimator.ModeKeys`` value.
          config: An optional run configuration; its ``model_dir`` is
            forwarded to the inputters as ``log_dir``.

        Returns:
          A ``(logits, predictions)`` tuple. ``logits`` is ``None`` when
          ``labels`` is ``None``. ``predictions`` is ``None`` in training mode,
          otherwise a dict with ``tokens``, ``length``, ``log_probs`` and, when
          the decoder returns an alignment history, ``alignment``.

        Raises:
          TypeError: If ``replace_unknown_target`` is enabled but the decoder
            returns no alignment history or the source inputter is not an
            ``inputters.WordEmbedder``.
        """
        features_length = self._get_features_length(features)
        log_dir = config.model_dir if config is not None else None

        with tf.variable_scope("encoder"):
            source_inputs = self.source_inputter.transform_data(
                features, mode=mode, log_dir=log_dir)
            encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
                source_inputs, sequence_length=features_length, mode=mode)

        target_vocab_size = self.target_inputter.vocabulary_size
        target_dtype = self.target_inputter.dtype

        with tf.variable_scope("decoder") as decoder_scope:
            if labels is not None:
                # Scheduled sampling is only applied while training. The
                # probability depends on the global step, so enabling it in
                # evaluation would make the evaluation logits depend on
                # training progress.
                sampling_probability = None
                if mode == tf.estimator.ModeKeys.TRAIN:
                    sampling_probability = get_sampling_probability(
                        tf.train.get_or_create_global_step(),
                        read_probability=params.get(
                            "scheduled_sampling_read_probability"),
                        schedule_type=params.get("scheduled_sampling_type"),
                        k=params.get("scheduled_sampling_k"))

                target_inputs = self.target_inputter.transform_data(
                    labels, mode=mode, log_dir=log_dir)
                logits, _, _ = self.decoder.decode(
                    target_inputs,
                    self._get_labels_length(labels),
                    vocab_size=target_vocab_size,
                    initial_state=encoder_state,
                    sampling_probability=sampling_probability,
                    embedding=self._scoped_target_embedding_fn(
                        mode, decoder_scope),
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length)
            else:
                logits = None

        if mode == tf.estimator.ModeKeys.TRAIN:
            return logits, None

        # The decoder variables already exist when the labels branch above was
        # built, so they must be reused rather than created again.
        reuse_decoder = labels is not None
        with tf.variable_scope(decoder_scope,
                               reuse=reuse_decoder) as decoder_scope:
            batch_size = tf.shape(encoder_sequence_length)[0]
            beam_width = params.get("beam_width", 1)
            maximum_iterations = params.get("maximum_iterations", 250)
            start_tokens = tf.fill([batch_size],
                                   constants.START_OF_SENTENCE_ID)
            end_token = constants.END_OF_SENTENCE_ID

            if beam_width <= 1:
                sampled_ids, _, sampled_length, log_probs, alignment = self.decoder.dynamic_decode(
                    self._scoped_target_embedding_fn(mode, decoder_scope),
                    start_tokens,
                    end_token,
                    vocab_size=target_vocab_size,
                    initial_state=encoder_state,
                    maximum_iterations=maximum_iterations,
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    dtype=target_dtype,
                    return_alignment_history=True)
            else:
                length_penalty = params.get("length_penalty", 0)
                sampled_ids, _, sampled_length, log_probs, alignment = (
                    self.decoder.dynamic_decode_and_search(
                        self._scoped_target_embedding_fn(
                            mode, decoder_scope),
                        start_tokens,
                        end_token,
                        vocab_size=target_vocab_size,
                        initial_state=encoder_state,
                        beam_width=beam_width,
                        length_penalty=length_penalty,
                        maximum_iterations=maximum_iterations,
                        mode=mode,
                        memory=encoder_outputs,
                        memory_sequence_length=encoder_sequence_length,
                        dtype=target_dtype,
                        return_alignment_history=True))

        # Map ids back to token strings. The OOV buckets are excluded from the
        # vocabulary file size, and ids without an entry map to UNKNOWN_TOKEN.
        target_vocab_rev = tf.contrib.lookup.index_to_string_table_from_file(
            self.target_inputter.vocabulary_file,
            vocab_size=target_vocab_size -
            self.target_inputter.num_oov_buckets,
            default_value=constants.UNKNOWN_TOKEN)
        target_tokens = target_vocab_rev.lookup(
            tf.cast(sampled_ids, tf.int64))

        if params.get("replace_unknown_target", False):
            target_tokens = self._replace_unknown_target_tokens(
                target_tokens, features, alignment, beam_width)

        predictions = {
            "tokens": target_tokens,
            "length": sampled_length,
            "log_probs": log_probs
        }
        if alignment is not None:
            predictions["alignment"] = alignment

        return logits, predictions

    def _replace_unknown_target_tokens(self, target_tokens, features,
                                       alignment, beam_width):
        """Replaces unknown target tokens using the decoder alignment history.

        Args:
          target_tokens: The decoded target token strings. All dimensions but
            the last are merged before calling ``replace_unknown_target`` and
            restored afterwards.
          features: The source features; ``features["tokens"]`` is read.
          alignment: The alignment history returned by the decoder, or
            ``None``. Presumably shaped
            ``[batch, beam, target_time, source_time]``, since the reshape
            below keeps dimensions 2 and 3 and merges the rest. Confirm
            against the decoder.
          beam_width: The beam width used for decoding. Source tokens are
            tiled to match when it is greater than 1.

        Returns:
          The tokens returned by ``replace_unknown_target``, reshaped to the
          original shape of ``target_tokens``.

        Raises:
          TypeError: If ``alignment`` is ``None`` or the source inputter is
            not an ``inputters.WordEmbedder``.
        """
        if alignment is None:
            raise TypeError(
                "replace_unknown_target is not compatible with decoders "
                "that don't return alignment history")
        if not isinstance(self.source_inputter, inputters.WordEmbedder):
            raise TypeError(
                "replace_unknown_target is only defined when the source "
                "inputter is a WordEmbedder")

        source_tokens = features["tokens"]
        if beam_width > 1:
            # Repeat each source entry beam_width times so it lines up with
            # the merged batch * beam dimension built below.
            source_tokens = tf.contrib.seq2seq.tile_batch(
                source_tokens, multiplier=beam_width)

        # Merge batch and beam dimensions.
        original_shape = tf.shape(target_tokens)
        flat_target_tokens = tf.reshape(target_tokens,
                                        [-1, original_shape[-1]])
        alignment_shape = tf.shape(alignment)
        attention = tf.reshape(
            alignment, [-1, alignment_shape[2], alignment_shape[3]])

        replaced_target_tokens = replace_unknown_target(
            flat_target_tokens, source_tokens, attention)
        return tf.reshape(replaced_target_tokens, original_shape)