Example #1
    def dynamic_decode(
        self,
        embeddings,
        start_ids,
        end_id=constants.END_OF_SENTENCE_ID,
        initial_state=None,
        decoding_strategy=None,
        sampler=None,
        maximum_iterations=None,
        minimum_iterations=0,
        tflite_output_size=None,
    ):
        """Decodes dynamically from :obj:`start_ids`.

        Args:
          embeddings: Target embeddings or :class:`opennmt.inputters.WordEmbedder`
            to apply on decoded ids.
          start_ids: Initial input IDs of shape :math:`[B]`.
          end_id: ID of the end of sequence token.
          initial_state: Initial decoder state.
          decoding_strategy: A :class:`opennmt.utils.DecodingStrategy`
            instance that defines the decoding logic. Defaults to a greedy search.
          sampler: A :class:`opennmt.utils.Sampler` instance that samples
            predictions from the model output. Defaults to an argmax sampling.
          maximum_iterations: The maximum number of iterations to decode for.
          minimum_iterations: The minimum number of iterations to decode for.
          tflite_output_size: If not None, run in TFLite compatible mode and use
            this value as the size of the 1D output tensor.

        Returns:
          A :class:`opennmt.utils.DecodingResult` instance.

        See Also:
          :func:`opennmt.utils.dynamic_decode`
        """
        if tflite_output_size is not None:
            input_fn = lambda ids: embeddings.tflite_call(ids)
        elif isinstance(embeddings, text_inputter.WordEmbedder):
            input_fn = lambda ids: embeddings({"ids": ids})
        else:
            input_fn = lambda ids: tf.nn.embedding_lookup(embeddings, ids)

        # TODO: find a better way to pass the state reorder flags.
        if hasattr(decoding_strategy, "_set_state_reorder_flags"):
            state_reorder_flags = self._get_state_reorder_flags()
            decoding_strategy._set_state_reorder_flags(state_reorder_flags)

        return decoding.dynamic_decode(
            lambda ids, step, state: self(input_fn(ids), step, state),
            start_ids,
            end_id=end_id,
            initial_state=initial_state,
            decoding_strategy=decoding_strategy,
            sampler=sampler,
            maximum_iterations=maximum_iterations,
            minimum_iterations=minimum_iterations,
            attention_history=self.support_alignment_history,
            attention_size=tf.shape(self.memory)[1]
            if self.support_alignment_history else None,
            tflite_output_size=tflite_output_size,
        )
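
A minimal usage sketch for this method, assuming an already built and initialized decoder instance (`decoder`) and a target `opennmt.inputters.WordEmbedder` (`target_embedder`); both names are placeholders, not part of the snippet above.

import tensorflow as tf
from opennmt import constants
from opennmt.utils import decoding

# Hypothetical usage; `decoder` and `target_embedder` are assumed to have been
# built and initialized elsewhere (e.g. with the source memory already set).
start_ids = tf.fill([2], constants.START_OF_SENTENCE_ID)  # batch of 2
result = decoder.dynamic_decode(
    target_embedder,
    start_ids,
    decoding_strategy=decoding.BeamSearch(4),  # omit to fall back to greedy search
    maximum_iterations=50,
)
print(result.ids, result.lengths, result.log_probs)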
Example #2
  def testGreedyDecodeWithMaximumIterations(self):
    logits_fn = _generate_logits_fn(10, [[4, 5, 6, 2], [3, 8, 2, 8]])
    ids, lengths, _, _, _ = decoding.dynamic_decode(
        logits_fn, [1, 1], end_id=2, maximum_iterations=2)
    self.assertAllEqual(self.evaluate(ids), [[[4, 5]], [[3, 8]]])
    self.assertAllEqual(self.evaluate(lengths), [[2], [2]])
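
The `_generate_logits_fn` helper is defined elsewhere in the test module; the sketch below is a plausible reconstruction (an assumption, not the original code) that forces each decoding step to emit the scheduled symbol, which is why greedy search returns `[4, 5]` and `[3, 8]` after two iterations.

import tensorflow as tf

def _generate_logits_fn(vocab_size, symbols):
    """Builds a logits function that emits `symbols` one step at a time."""
    symbols = tf.constant(symbols, dtype=tf.int32)

    def _logits_fn(ids, step, state):
        # One-hot logits put all the mass on the symbol scheduled for this
        # step, so argmax (greedy) sampling selects exactly that id.
        logits = tf.one_hot(symbols[:, step], vocab_size)
        return logits, state

    return _logits_fn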
Example #3
  def call(self, features, labels=None, training=None, step=None):
    outputs, predictions = None, None

    ids, length = features["ids"], features["length"]
    if labels is not None:
      # For training and evaluation, forward the full sequence.
      logits, _ = self._decode(
          labels.get("ids", ids),
          labels.get("length", length),
          training=training)
      outputs = dict(logits=logits)
    else:
      assert_fixed_length = tf.debugging.Assert(
          tf.reduce_all(tf.equal(length, tf.reduce_max(length))),
          ["Language model does not support variable length contexts during "
           "generation, consider setting batch_size or length_bucket_width to 1"])
      assert_non_empty_start = tf.debugging.Assert(
          tf.math.not_equal(tf.math.reduce_max(length), 0),
          ["The language model requires a context sequence to initialize the decoding. "
           "If you want nonconditional sequence generation, you should configure the "
           "sequence_controls parameter before training."])

      # Run decoder on the context, if any.
      with tf.control_dependencies([assert_fixed_length, assert_non_empty_start]):
        context_ids, start_ids = tf.split(ids, [tf.shape(ids)[1] - 1, 1], axis=1)
        context_length = length - 1
        batch_size = tf.shape(context_length)[0]
        state = tf.cond(
            tf.equal(tf.reduce_sum(context_length), 0),
            true_fn=lambda: self.decoder.initial_state(batch_size=batch_size, dtype=self.dtype),
            false_fn=lambda: self._decode(context_ids, context_length)[1])

      params = self.params

      def _decode_with_step_offset(ids, step, state):
        return self._decode(ids, step + context_length[0], state)

      # Iteratively decode from the last decoder state.
      sampled_ids, sampled_length, _, _, _ = decoding.dynamic_decode(
          _decode_with_step_offset,
          tf.squeeze(start_ids, 1),
          initial_state=state,
          sampler=decoding.Sampler.from_params(params),
          maximum_iterations=params.get("maximum_decoding_length", 250),
          minimum_iterations=params.get("minimum_decoding_length", 0))
      sampled_ids = tf.reshape(sampled_ids, [batch_size, -1])
      sampled_length = tf.reshape(sampled_length, [batch_size])

      # Build the full prediction.
      if self.features_inputter.mark_start:
        # Remove leading <s> if included in the context sequence.
        ids = ids[:, 1:]
        length -= 1
      full_ids = tf.concat([ids, sampled_ids], 1)
      full_length = length + sampled_length
      tokens = self.features_inputter.ids_to_tokens.lookup(full_ids)
      predictions = dict(tokens=tokens, length=full_length)

    return outputs, predictions
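
A hypothetical generation call against the `call()` shown above; `model` stands for an initialized language model instance and the token ids are made up for illustration.

import tensorflow as tf

# All sequences in the batch must share the same length (see the fixed-length
# assertion above) and the context must be non-empty.
features = {
    "ids": tf.constant([[12, 7, 53]], dtype=tf.int64),
    "length": tf.constant([3], dtype=tf.int32),
}
_, predictions = model(features)  # labels=None selects the generation branch
print(predictions["tokens"].numpy(), predictions["length"].numpy())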
Example #4
  def _call(self, features, labels, params, mode):
    training = mode == tf.estimator.ModeKeys.TRAIN
    outputs, predictions = None, None

    ids, length = features["ids"], features["length"]
    if mode != tf.estimator.ModeKeys.PREDICT:
      # For training and evaluation, forward the full sequence.
      logits, _ = self._decode(ids, length, training=training)
      outputs = dict(logits=logits)
    else:
      assert_fixed_length = tf.debugging.Assert(
          tf.reduce_all(tf.equal(length, tf.reduce_max(length))),
          ["Language model does not support variable length contexts during "
           "generation, consider setting batch_size or bucket_width to 1"])

      # Run the decoder on the context, if any.
      with tf.control_dependencies([assert_fixed_length]):
        context_ids, start_ids = tf.split(ids, [tf.shape(ids)[1] - 1, 1], axis=1)
        context_length = length - 1
        batch_size = tf.shape(context_length)[0]
        state = tf.cond(
            tf.equal(tf.reduce_sum(context_length), 0),
            true_fn=lambda: self.decoder.get_initial_state(batch_size=batch_size, dtype=self.dtype),
            false_fn=lambda: self._decode(context_ids, context_length)[1],
            name=self.name + "/")  # Force the name scope.

      sampling_topk = params.get("sampling_topk")
      if sampling_topk is not None and sampling_topk != 1:
        sampler = decoding.RandomSampler(
            from_top_k=sampling_topk, temperature=params.get("sampling_temperature"))
      else:
        sampler = decoding.BestSampler()

      def _decode_with_step_offset(ids, step, state):
        return self._decode(ids, step + context_length[0], state)

      # Iteratively decode from the last decoder state.
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        sampled_ids, sampled_length, _, _, _ = decoding.dynamic_decode(
            _decode_with_step_offset,
            tf.squeeze(start_ids, 1),
            initial_state=state,
            sampler=sampler,
            maximum_iterations=params.get("maximum_iterations", 250),
            minimum_iterations=params.get("minimum_decoding_length", 0))
        sampled_ids = tf.squeeze(sampled_ids, 1)
        sampled_length = tf.squeeze(sampled_length, 1)

      # Build the full prediction.
      full_ids = tf.concat([ids, sampled_ids], 1)
      full_length = length + sampled_length
      vocab_rev = self.examples_inputter.vocabulary_lookup_reverse()
      tokens = vocab_rev.lookup(full_ids)
      predictions = dict(tokens=tokens, length=full_length)

    return outputs, predictions
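
For comparison, the newer code path in the previous example builds its sampler directly from the params dictionary with `decoding.Sampler.from_params`; below is an illustrative params configuration that would select random sampling (the key names come from the snippets above, the values are made up).

params = {
    "sampling_topk": 10,            # sample among the 10 most likely tokens
    "sampling_temperature": 0.8,    # divide logits before sampling
    "maximum_decoding_length": 200,
    "minimum_decoding_length": 0,
}
sampler = decoding.Sampler.from_params(params)  # a RandomSampler in this case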
Example #5
    def dynamic_decode_and_search(self,
                                  embedding,
                                  start_tokens,
                                  end_token,
                                  vocab_size=None,
                                  initial_state=None,
                                  output_layer=None,
                                  beam_width=5,
                                  length_penalty=0.0,
                                  maximum_iterations=250,
                                  minimum_length=0,
                                  mode=tf.estimator.ModeKeys.PREDICT,
                                  memory=None,
                                  memory_sequence_length=None,
                                  dtype=None,
                                  return_alignment_history=False,
                                  sample_from=None,
                                  sample_temperature=None,
                                  coverage_penalty=0.0):
        """Decodes dynamically from :obj:`start_tokens` with beam search.

        Usually used for inference.

        Args:
          embedding: The embedding tensor or a callable that takes word ids.
          start_tokens: The start token ids with shape :math:`[B]`.
          end_token: The end token id.
          vocab_size: The output vocabulary size. Must be set if :obj:`output_layer`
            is not set.
          initial_state: The initial state as a (possibly nested tuple of...) tensors.
          output_layer: Optional layer to apply to the output prior to sampling.
            Must be set if :obj:`vocab_size` is not set.
          beam_width: The width of the beam.
          length_penalty: The length penalty weight during beam search.
          maximum_iterations: The maximum number of decoding iterations.
          minimum_length: The minimum length of decoded sequences (:obj:`end_token`
            excluded).
          mode: A ``tf.estimator.ModeKeys`` mode.
          memory: (optional) Memory values to query.
          memory_sequence_length: (optional) Memory values length.
          dtype: The data type. Required if :obj:`memory` is ``None``.
          return_alignment_history: If ``True``, also returns the alignment
            history from the attention layer (``None`` will be returned if
            unsupported by the decoder).
          sample_from: Sample predictions from the :obj:`sample_from` most likely
            tokens. If 0, sample from the full output distribution.
          sample_temperature: Value dividing logits. In random sampling, a high
            value generates more random samples.
          coverage_penalty: The coverage penalty weight during beam search.

        Returns:
          A tuple ``(predicted_ids, state, sequence_length, log_probs)`` or
          ``(predicted_ids, state, sequence_length, log_probs, alignment_history)``
          if :obj:`return_alignment_history` is ``True``.
        """
        batch_size = tf.shape(start_tokens)[0] * beam_width
        if dtype is None:
            if memory is None:
                raise ValueError(
                    "dtype argument is required when no memory is set")
            dtype = tf.contrib.framework.nest.flatten(memory)[0].dtype

        if beam_width > 1:
            if initial_state is not None:
                initial_state = tf.contrib.seq2seq.tile_batch(
                    initial_state, multiplier=beam_width)
            if memory is not None:
                memory = tf.contrib.seq2seq.tile_batch(memory,
                                                       multiplier=beam_width)
            if memory_sequence_length is not None:
                memory_sequence_length = tf.contrib.seq2seq.tile_batch(
                    memory_sequence_length, multiplier=beam_width)

        embedding_fn = get_embedding_fn(embedding)
        step_fn, initial_state = self.step_fn(
            mode,
            batch_size,
            initial_state=initial_state,
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            dtype=dtype)
        if output_layer is None:
            if vocab_size is None:
                raise ValueError(
                    "vocab_size must be known when the output_layer is not set"
                )
            output_layer = build_output_layer(self.output_size,
                                              vocab_size,
                                              dtype=dtype)

        def _symbols_to_logits_fn(ids, step, state):
            inputs = embedding_fn(ids)
            returned_values = step_fn(step, inputs, state, mode)
            if self.support_alignment_history:
                outputs, state, attention = returned_values
            else:
                outputs, state = returned_values
                attention = None
            logits = output_layer(outputs)
            return logits, state, attention

        if beam_width == 1:
            decoding_strategy = decoding.GreedySearch()
        else:
            decoding_strategy = decoding.BeamSearch(
                beam_width,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty)

        if sample_from is not None and sample_from != 1:
            sampler = decoding.RandomSampler(from_top_k=sample_from,
                                             temperature=sample_temperature)
        else:
            sampler = decoding.BestSampler()

        outputs, lengths, log_probs, attention, state = decoding.dynamic_decode(
            _symbols_to_logits_fn,
            start_tokens,
            end_id=end_token,
            initial_state=initial_state,
            decoding_strategy=decoding_strategy,
            sampler=sampler,
            maximum_iterations=maximum_iterations,
            minimum_iterations=minimum_length,
            attention_history=self.support_alignment_history
            and not isinstance(memory, (list, tuple)),
            attention_size=tf.shape(memory)[1]
            if self.support_alignment_history else None)

        # For backward compatibility, include </s> in length.
        lengths = tf.minimum(lengths + 1, tf.shape(outputs)[2])
        if return_alignment_history:
            return (outputs, state, lengths, log_probs, attention)
        return (outputs, state, lengths, log_probs)
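
A hypothetical call to this TF1-era method; `decoder`, `target_embedding`, `memory`, `memory_sequence_length`, and `vocab_size` are placeholders assumed to be built elsewhere in the graph.

predicted_ids, state, lengths, log_probs = decoder.dynamic_decode_and_search(
    target_embedding,
    start_tokens=tf.fill([tf.shape(memory)[0]], constants.START_OF_SENTENCE_ID),
    end_token=constants.END_OF_SENTENCE_ID,
    vocab_size=vocab_size,
    beam_width=5,
    length_penalty=0.6,
    maximum_iterations=200,
    memory=memory,
    memory_sequence_length=memory_sequence_length)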