def dynamic_decode(
    self,
    embeddings,
    start_ids,
    end_id=constants.END_OF_SENTENCE_ID,
    initial_state=None,
    decoding_strategy=None,
    sampler=None,
    maximum_iterations=None,
    minimum_iterations=0,
    tflite_output_size=None,
):
    """Decodes dynamically from :obj:`start_ids`.

    Args:
      embeddings: Target embeddings or :class:`opennmt.inputters.WordEmbedder`
        to apply on decoded ids.
      start_ids: Initial input IDs of shape :math:`[B]`.
      end_id: ID of the end of sequence token.
      initial_state: Initial decoder state.
      decoding_strategy: A :class:`opennmt.utils.DecodingStrategy` instance
        that defines the decoding logic. Defaults to a greedy search.
      sampler: A :class:`opennmt.utils.Sampler` instance that samples
        predictions from the model output. Defaults to an argmax sampling.
      maximum_iterations: The maximum number of iterations to decode for.
      minimum_iterations: The minimum number of iterations to decode for.
      tflite_output_size: If not ``None``, runs the TFLite-safe decoding loop
        and sets the size of the 1D output tensor.

    Returns:
      A :class:`opennmt.utils.DecodingResult` instance.

    See Also:
      :func:`opennmt.utils.dynamic_decode`
    """
    if tflite_output_size is not None:
        input_fn = lambda ids: embeddings.tflite_call(ids)
    elif isinstance(embeddings, text_inputter.WordEmbedder):
        input_fn = lambda ids: embeddings({"ids": ids})
    else:
        input_fn = lambda ids: tf.nn.embedding_lookup(embeddings, ids)

    # TODO: find a better way to pass the state reorder flags.
    if hasattr(decoding_strategy, "_set_state_reorder_flags"):
        state_reorder_flags = self._get_state_reorder_flags()
        decoding_strategy._set_state_reorder_flags(state_reorder_flags)

    return decoding.dynamic_decode(
        lambda ids, step, state: self(input_fn(ids), step, state),
        start_ids,
        end_id=end_id,
        initial_state=initial_state,
        decoding_strategy=decoding_strategy,
        sampler=sampler,
        maximum_iterations=maximum_iterations,
        minimum_iterations=minimum_iterations,
        attention_history=self.support_alignment_history,
        attention_size=tf.shape(self.memory)[1]
        if self.support_alignment_history
        else None,
        tflite_output_size=tflite_output_size,
    )
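# --- Usage sketch (not part of the original source) ---
# A minimal example of calling dynamic_decode above for greedy generation.
# Assumptions: `decoder` is a built opennmt decoder exposing this method,
# `target_embedder` is the target-side opennmt.inputters.WordEmbedder, and
# `encoder_state` comes from an existing encoder; none are defined here.
import tensorflow as tf

from opennmt import constants


def generate_greedy(decoder, target_embedder, encoder_state, batch_size):
    # Leaving decoding_strategy and sampler as None selects the documented
    # defaults: greedy search with argmax sampling.
    result = decoder.dynamic_decode(
        target_embedder,
        tf.fill([batch_size], constants.START_OF_SENTENCE_ID),
        initial_state=encoder_state,
        maximum_iterations=250,
    )
    # result is an opennmt.utils.DecodingResult; ids has shape
    # [batch, num_hypotheses, time] (a single hypothesis for greedy search).
    return result.ids[:, 0], result.lengths[:, 0]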
def testGreedyDecodeWithMaximumIterations(self):
    logits_fn = _generate_logits_fn(10, [[4, 5, 6, 2], [3, 8, 2, 8]])
    ids, lengths, _, _, _ = decoding.dynamic_decode(
        logits_fn, [1, 1], end_id=2, maximum_iterations=2
    )
    self.assertAllEqual(self.evaluate(ids), [[[4, 5]], [[3, 8]]])
    self.assertAllEqual(self.evaluate(lengths), [[2], [2]])
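# The _generate_logits_fn helper is defined elsewhere in the test file; a
# plausible sketch (assumption: it scripts one output sequence per batch entry
# and replays it as one-hot logits, so argmax decoding reproduces the script):
import tensorflow as tf


def _generate_logits_fn(vocab_size, sampled_ids):
    scripted = tf.constant(sampled_ids)

    def _logits_fn(ids, step, state):
        # Emit the scripted token for the current timestep as one-hot logits.
        logits = tf.one_hot(scripted[:, step], vocab_size)
        return logits, state

    return _logits_fn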
def call(self, features, labels=None, training=None, step=None):
    outputs, predictions = None, None
    ids, length = features["ids"], features["length"]
    if labels is not None:
        # For training and evaluation, forward the full sequence.
        logits, _ = self._decode(
            labels.get("ids", ids), labels.get("length", length), training=training
        )
        outputs = dict(logits=logits)
    else:
        assert_fixed_length = tf.debugging.Assert(
            tf.reduce_all(tf.equal(length, tf.reduce_max(length))),
            [
                "Language model does not support variable length contexts during "
                "generation, consider setting batch_size or length_bucket_width to 1"
            ],
        )
        assert_non_empty_start = tf.debugging.Assert(
            tf.math.not_equal(tf.math.reduce_max(length), 0),
            [
                "The language model requires a context sequence to initialize the decoding. "
                "If you want nonconditional sequence generation, you should configure the "
                "sequence_controls parameter before training."
            ],
        )

        # Run the decoder on the context, if any.
        with tf.control_dependencies([assert_fixed_length, assert_non_empty_start]):
            context_ids, start_ids = tf.split(ids, [tf.shape(ids)[1] - 1, 1], axis=1)
            context_length = length - 1
            batch_size = tf.shape(context_length)[0]
            state = tf.cond(
                tf.equal(tf.reduce_sum(context_length), 0),
                true_fn=lambda: self.decoder.initial_state(
                    batch_size=batch_size, dtype=self.dtype
                ),
                false_fn=lambda: self._decode(context_ids, context_length)[1],
            )
            params = self.params

            def _decode_with_step_offset(ids, step, state):
                return self._decode(ids, step + context_length[0], state)

            # Iteratively decode from the last decoder state.
            sampled_ids, sampled_length, _, _, _ = decoding.dynamic_decode(
                _decode_with_step_offset,
                tf.squeeze(start_ids, 1),
                initial_state=state,
                sampler=decoding.Sampler.from_params(params),
                maximum_iterations=params.get("maximum_decoding_length", 250),
                minimum_iterations=params.get("minimum_decoding_length", 0),
            )
            sampled_ids = tf.reshape(sampled_ids, [batch_size, -1])
            sampled_length = tf.reshape(sampled_length, [batch_size])

            # Build the full prediction.
            if self.features_inputter.mark_start:
                # Remove leading <s> if included in the context sequence.
                ids = ids[:, 1:]
                length -= 1
            full_ids = tf.concat([ids, sampled_ids], 1)
            full_length = length + sampled_length
            tokens = self.features_inputter.ids_to_tokens.lookup(full_ids)
            predictions = dict(tokens=tokens, length=full_length)

    return outputs, predictions
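# Two notes on the generation path above, as a standalone sketch.
#
# 1) The step offset: decoding.dynamic_decode restarts its step counter at 0,
# but after priming on a context of length C a positional decoder (e.g. a
# Transformer) needs the absolute timestep, hence step + context_length[0].
#
# 2) The sampler is built from user parameters. A minimal sketch (assumption:
# `params` mirrors the inference section of the run configuration):
from opennmt.utils import decoding

params = {
    "sampling_topk": 10,  # 1 selects BestSampler (argmax); k != 1 selects RandomSampler
    "sampling_temperature": 0.9,  # divides the logits; higher means more random
    "maximum_decoding_length": 100,
}
sampler = decoding.Sampler.from_params(params)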
def _call(self, features, labels, params, mode):
    training = mode == tf.estimator.ModeKeys.TRAIN
    outputs, predictions = None, None
    ids, length = features["ids"], features["length"]
    if mode != tf.estimator.ModeKeys.PREDICT:
        # For training and evaluation, forward the full sequence.
        logits, _ = self._decode(ids, length, training=training)
        outputs = dict(logits=logits)
    else:
        assert_fixed_length = tf.debugging.Assert(
            tf.reduce_all(tf.equal(length, tf.reduce_max(length))),
            [
                "Language model does not support variable length contexts during "
                "generation, consider setting batch_size or bucket_width to 1"
            ],
        )

        # Run the decoder on the context, if any.
        with tf.control_dependencies([assert_fixed_length]):
            context_ids, start_ids = tf.split(ids, [tf.shape(ids)[1] - 1, 1], axis=1)
            context_length = length - 1
            batch_size = tf.shape(context_length)[0]
            state = tf.cond(
                tf.equal(tf.reduce_sum(context_length), 0),
                true_fn=lambda: self.decoder.get_initial_state(
                    batch_size=batch_size, dtype=self.dtype
                ),
                false_fn=lambda: self._decode(context_ids, context_length)[1],
                name=self.name + "/",  # Force the name scope.
            )
            sampling_topk = params.get("sampling_topk")
            if sampling_topk is not None and sampling_topk != 1:
                sampler = decoding.RandomSampler(
                    from_top_k=sampling_topk,
                    temperature=params.get("sampling_temperature"),
                )
            else:
                sampler = decoding.BestSampler()

            def _decode_with_step_offset(ids, step, state):
                return self._decode(ids, step + context_length[0], state)

            # Iteratively decode from the last decoder state.
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                sampled_ids, sampled_length, _, _, _ = decoding.dynamic_decode(
                    _decode_with_step_offset,
                    tf.squeeze(start_ids, 1),
                    initial_state=state,
                    sampler=sampler,
                    maximum_iterations=params.get("maximum_iterations", 250),
                    minimum_iterations=params.get("minimum_decoding_length", 0),
                )
            sampled_ids = tf.squeeze(sampled_ids, 1)
            sampled_length = tf.squeeze(sampled_length, 1)

            # Build the full prediction.
            full_ids = tf.concat([ids, sampled_ids], 1)
            full_length = length + sampled_length
            vocab_rev = self.examples_inputter.vocabulary_lookup_reverse()
            tokens = vocab_rev.lookup(full_ids)
            predictions = dict(tokens=tokens, length=full_length)

    return outputs, predictions
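# The sampler branch above, reduced to a standalone sketch (assumption:
# `params` is the plain dict handed to _call by the estimator wrapper):
from opennmt.utils import decoding

params = {"sampling_topk": 0, "sampling_temperature": 1.0}
sampling_topk = params.get("sampling_topk")
if sampling_topk is not None and sampling_topk != 1:
    # Per the docstring convention elsewhere in this codebase, topk == 0
    # samples from the full output distribution.
    sampler = decoding.RandomSampler(
        from_top_k=sampling_topk, temperature=params.get("sampling_temperature")
    )
else:
    sampler = decoding.BestSampler()  # plain argmax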
def dynamic_decode_and_search(
    self,
    embedding,
    start_tokens,
    end_token,
    vocab_size=None,
    initial_state=None,
    output_layer=None,
    beam_width=5,
    length_penalty=0.0,
    maximum_iterations=250,
    minimum_length=0,
    mode=tf.estimator.ModeKeys.PREDICT,
    memory=None,
    memory_sequence_length=None,
    dtype=None,
    return_alignment_history=False,
    sample_from=None,
    sample_temperature=None,
    coverage_penalty=0.0,
):
    """Decodes dynamically from :obj:`start_tokens` with beam search.

    Usually used for inference.

    Args:
      embedding: The embedding tensor or a callable that takes word ids.
      start_tokens: The start token ids with shape :math:`[B]`.
      end_token: The end token id.
      vocab_size: The output vocabulary size. Must be set if
        :obj:`output_layer` is not set.
      initial_state: The initial state as a (possibly nested tuple of...)
        tensors.
      output_layer: Optional layer to apply to the output prior to sampling.
        Must be set if :obj:`vocab_size` is not set.
      beam_width: The width of the beam.
      length_penalty: The length penalty weight during beam search.
      maximum_iterations: The maximum number of decoding iterations.
      minimum_length: The minimum length of decoded sequences
        (:obj:`end_token` excluded).
      mode: A ``tf.estimator.ModeKeys`` mode.
      memory: (optional) Memory values to query.
      memory_sequence_length: (optional) Memory values length.
      dtype: The data type. Required if :obj:`memory` is ``None``.
      return_alignment_history: If ``True``, also returns the alignment
        history from the attention layer (``None`` will be returned if
        unsupported by the decoder).
      sample_from: Sample predictions from the :obj:`sample_from` most likely
        tokens. If 0, sample from the full output distribution.
      sample_temperature: Value dividing logits. In random sampling, a high
        value generates more random samples.
      coverage_penalty: The coverage penalty weight during beam search.

    Returns:
      A tuple ``(predicted_ids, state, sequence_length, log_probs)`` or
      ``(predicted_ids, state, sequence_length, log_probs, alignment_history)``
      if :obj:`return_alignment_history` is ``True``.
""" batch_size = tf.shape(start_tokens)[0] * beam_width if dtype is None: if memory is None: raise ValueError( "dtype argument is required when no memory is set") dtype = tf.contrib.framework.nest.flatten(memory)[0].dtype if beam_width > 1: if initial_state is not None: initial_state = tf.contrib.seq2seq.tile_batch( initial_state, multiplier=beam_width) if memory is not None: memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width) if memory_sequence_length is not None: memory_sequence_length = tf.contrib.seq2seq.tile_batch( memory_sequence_length, multiplier=beam_width) embedding_fn = get_embedding_fn(embedding) step_fn, initial_state = self.step_fn( mode, batch_size, initial_state=initial_state, memory=memory, memory_sequence_length=memory_sequence_length, dtype=dtype) if output_layer is None: if vocab_size is None: raise ValueError( "vocab_size must be known when the output_layer is not set" ) output_layer = build_output_layer(self.output_size, vocab_size, dtype=dtype) def _symbols_to_logits_fn(ids, step, state): inputs = embedding_fn(ids) returned_values = step_fn(step, inputs, state, mode) if self.support_alignment_history: outputs, state, attention = returned_values else: outputs, state = returned_values attention = None logits = output_layer(outputs) return logits, state, attention if beam_width == 1: decoding_strategy = decoding.GreedySearch() else: decoding_strategy = decoding.BeamSearch( beam_width, length_penalty=length_penalty, coverage_penalty=coverage_penalty) if sample_from is not None and sample_from != 1: sampler = decoding.RandomSampler(from_top_k=sample_from, temperature=sample_temperature) else: sampler = decoding.BestSampler() outputs, lengths, log_probs, attention, state = decoding.dynamic_decode( _symbols_to_logits_fn, start_tokens, end_id=end_token, initial_state=initial_state, decoding_strategy=decoding_strategy, sampler=sampler, maximum_iterations=maximum_iterations, minimum_iterations=minimum_length, attention_history=self.support_alignment_history and not isinstance(memory, (list, tuple)), attention_size=tf.shape(memory)[1] if self.support_alignment_history else None) # For backward compatibility, include </s> in length. lengths = tf.minimum(lengths + 1, tf.shape(outputs)[2]) if return_alignment_history: return (outputs, state, lengths, log_probs, attention) return (outputs, state, lengths, log_probs)