def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
  """Concats sequence `x` with sequence `y`.

  This function is length aware (based off the paddings).

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    y: A sequence of tokens of shape [batch_size, y_len_max].
    y_paddings: The paddings of `y`.
    pad: The <pad> token to fill the concatenated sequence (of type integer).

  Returns:
    A tuple.
      - Concatenation of `x` and `y` of shape
        [batch_size, x_len_max + y_len_max].
      - Paddings of the concatenation of shape
        [batch_size, x_len_max + y_len_max].
  """
  # Get the length (w/ eos).
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)

  batch_size = py_utils.GetShape(x)[0]
  y_len_max = py_utils.GetShape(y)[1]

  # Pad `x` with necessary <pad>.
  x = tf.concat([x, tf.fill(py_utils.GetShape(y), pad)], 1)
  # Replace all <pad> with 0.
  x = tf.where(tf.not_equal(x, pad), x, tf.fill(py_utils.GetShape(x), 0))

  # Compute the write indices of `y` in `xy`.
  indices = tf.stack([
      tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]),
      (tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
       tf.expand_dims(x_len, 1)),
  ], 2)

  xy = x + tf.scatter_nd(indices, y, py_utils.GetShape(x))

  # We need to remap all <pad> to `pad`.
  xy = tf.where(
      tf.less(tf.expand_dims(tf.range(py_utils.GetShape(xy)[1]), 0),
              tf.expand_dims(x_len + y_len, 1)), xy,
      tf.fill(py_utils.GetShape(xy), pad))
  xy_paddings = 1 - tf.sequence_mask(x_len + y_len,
                                     py_utils.GetShape(xy)[1],
                                     x_paddings.dtype)
  return xy, xy_paddings
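
# Hedged usage sketch for `SequenceConcat` (not from the original source).
# The literal tensors below are illustrative assumptions; the expected
# outputs are shown in the trailing comments.
def _ExampleSequenceConcat():
  """Length-aware concat of two toy batches of token ids."""
  x = tf.constant([[1, 2, 0], [3, 0, 0]])                 # x_len = [2, 1]
  x_paddings = tf.constant([[0., 0., 1.], [0., 1., 1.]])
  y = tf.constant([[4, 5], [6, 0]])                       # y_len = [2, 1]
  y_paddings = tf.constant([[0., 0.], [0., 1.]])
  xy, xy_paddings = SequenceConcat(x, x_paddings, y, y_paddings, pad=0)
  # xy          == [[1, 2, 4, 5, 0], [3, 6, 0, 0, 0]]
  # xy_paddings == [[0., 0., 0., 0., 1.], [0., 0., 1., 1., 1.]]
  return xy, xy_paddings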
def _ProcessBeamSearchDecodeOut(self, input_batch, encoder_outputs,
                                decoder_outs):
  """Unpacks the flat beam search decoder outputs into a results dict."""
  self.r1_shape = decoder_outs[0]
  self.r2_shape = decoder_outs[1]
  self.r3_shape = decoder_outs[2]
  tf.logging.info('r1_shape: %s', self.r1_shape)
  tf.logging.info('r2_shape: %s', self.r2_shape)
  tf.logging.info('r3_shape: %s', self.r3_shape)

  hyps = decoder_outs[3]
  prev_hyps = decoder_outs[4]
  done_hyps = decoder_outs[5]
  scores = decoder_outs[6]
  atten_probs = decoder_outs[7]
  eos_scores = decoder_outs[8]
  eos_atten_probs = decoder_outs[9]
  source_seq_lengths = decoder_outs[10]

  # Target length excluding the leading sos token.
  tlen = tf.cast(
      tf.round(tf.reduce_sum(1.0 - input_batch.tgt.paddings, 1) - 1.0),
      tf.int32)

  ret_dict = {
      'target_ids': input_batch.tgt.ids[:, 1:],
      'eval_weight': input_batch.eval_weight,
      'tlen': tlen,
      'hyps': hyps,
      'prev_hyps': prev_hyps,
      'done_hyps': done_hyps,
      'scores': scores,
      'atten_probs': atten_probs,
      'eos_scores': eos_scores,
      'eos_atten_probs': eos_atten_probs,
      'source_seq_lengths': source_seq_lengths,
  }
  return ret_dict
def SequenceAppendToken(x, x_paddings, token, extend=False):
  """Appends <token> to sequence `x`.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    token: The token to append (of type integer).
    extend: Whether to extend `x` along the length dimension. This must be
      True if any sequence in `x` already has length `x_len_max`, or else an
      invalid sequence will be emitted.

  Returns:
    A tuple.
      - The new sequence, Tensor of shape [batch_size, x_len_max]
        (or [batch_size, x_len_max + 1] if `extend` is True).
      - The new paddings, of the same shape as the new sequence.
  """
  batch_size = py_utils.GetShape(x)[0]
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  if extend:
    x = tf.pad(x, [[0, 0], [0, 1]])
  # Mask all invalid entries of `x` to 0.
  x *= tf.sequence_mask(x_len, py_utils.GetShape(x)[1], x.dtype)
  # Append the <token> based on `x_len`.
  x += tf.scatter_nd(
      tf.stack([tf.range(batch_size), x_len], axis=1),
      tf.cast(tf.fill([batch_size], token), x.dtype), py_utils.GetShape(x))
  x_paddings = 1 - tf.sequence_mask(x_len + 1,
                                    py_utils.GetShape(x)[1],
                                    x_paddings.dtype)
  return x, x_paddings
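
# Hedged usage sketch for `SequenceAppendToken` (illustrative values, not
# from the original source). Appending <eos>=2 with extend=True grows the
# length dimension by one so that full-length rows stay valid.
def _ExampleSequenceAppendToken():
  """Appends an eos token to each sequence in a toy batch."""
  x = tf.constant([[1, 3, 4], [5, 0, 0]])                 # x_len = [3, 1]
  x_paddings = tf.constant([[0., 0., 0.], [0., 1., 1.]])
  x, x_paddings = SequenceAppendToken(x, x_paddings, token=2, extend=True)
  # x          == [[1, 3, 4, 2], [5, 2, 0, 0]]
  # x_paddings == [[0., 0., 0., 0.], [0., 0., 1., 1.]]
  return x, x_paddings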
def SequenceLength(padding):
  """Computes the length of a sequence based on binary padding.

  Args:
    padding: A tensor of binary paddings shaped [batch, seqlen].

  Returns:
    seq_lens, A tensor of shape [batch] containing the non-padded length of
      each element of `padding` along the batch dimension.
  """
  seq_lens = tf.cast(tf.round(tf.reduce_sum(1 - padding, axis=1)), tf.int32)
  # Get rid of any extra dimensions.
  batch_size = tf.shape(padding)[0]
  seq_lens = tf.reshape(seq_lens, [batch_size], name='seq_lens')
  return seq_lens
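
# Hedged usage sketch for `SequenceLength` (illustrative values).
def _ExampleSequenceLength():
  """Recovers per-example lengths from binary paddings."""
  padding = tf.constant([[0., 0., 1., 1.], [0., 0., 0., 0.]])
  return SequenceLength(padding)  # => [2, 4]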
def FProp(self, theta, x, x_paddings=None, eos_id=1,
          force_sample_last_token=True):
  """Applies SymbolInsertionLayer.

  We take in a `x`, which represents the groundtruth sequence (i.e., English
  sequence). We return a sampled rollin (observed) canvas (i.e., random
  subset of the English sequence), as well as the target (indices) for an
  insertion-based model (i.e., the targets given the random observed subset).

  Args:
    theta: Ignored, this can be None.
    x: The symbol ids of shape `[batch_size, time_dim]`.
    x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
      0 is valid and 1 is invalid.
    eos_id: The <eos> token id to represent end-of-slot.
    force_sample_last_token: Set True to force sample the last token of `x`.

  Returns:
    A `NestedMap`.
      - canvas: The canvas (based off of the `rollin_policy`) of shape
        [batch_size, c_dim]. Note that `c_dim` <= `time_dim` but need not be
        equal.
      - canvas_indices: The canvas indices (into `x`).
      - canvas_paddings: The paddings of `canvas_indices`.
      - target_indices: The target indices of shape [num_targets, 3].
        `num_targets` is the number of total targets in the entire batch.
        [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
        captures the token. Each row [batch, slot, vocab] represents the
        indices of the target -- i.e., the batch, slot and vocab combination
        of the target. Typical usage of these indices is to tf.gather_nd
        the log-probs (from the softmax layer).
      - target_weights: The target weights.

  Raises:
    ValueError: If invalid params.
  """
  p = self.params

  batch_size = py_utils.GetShape(x)[0]
  time_dim = py_utils.GetShape(x)[1]

  if x_paddings is None:
    x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

  oracle_policy = p.oracle_policy
  rollin_policy = (
      oracle_policy if p.rollin_policy == 'oracle' else p.rollin_policy)

  if rollin_policy != 'uniform':
    raise ValueError('Unknown or unsupported rollin policy: %s' %
                     rollin_policy)
  if oracle_policy != 'uniform':
    raise ValueError('Unknown or unsupported oracle policy: %s' %
                     oracle_policy)

  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  # Compute the desired length per example in the batch.
  ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
  if force_sample_last_token:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32), x_len - 1) + 1
  else:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len)
  # Compute the maximum length across the batch.
  c_len_max = tf.reduce_max(c_len)

  # Grab subset of random valid indices per example.
  z_logits = tf.cast(
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
      tf.float32) * -1e9
  if force_sample_last_token:
    # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
    # accomplish this by adding +LARGE_NUMBER to the logits.
    z_logits += tf.cast(
        tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                 tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
  # Gumbel-max trick to sample (we only sample valid positions per sample in
  # the batch).
  z = -tf.math.log(-tf.math.log(
      tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
  unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

  # Trim everything > c_len_max.
  c_indices = c_indices[:, :c_len_max]

  # Invalidate any indices >= c_len, we use the last index as the default
  # invalid index.
  c_indices = tf.where(
      tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
      c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

  # Materialize the canvas.
  c_indices = tf.sort(c_indices)
  c = tf.gather_nd(
      x,
      tf.stack([
          tf.reshape(
              tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                      [1, c_len_max]), [-1]),
          tf.reshape(c_indices, [-1])
      ], 1))
  c = tf.reshape(c, [batch_size, c_len_max])

  # Compute the paddings.
  c_paddings = 1 - tf.sequence_mask(c_len, c_len_max, dtype=x_paddings.dtype)
  c *= tf.cast(1 - c_paddings, tf.int32)

  indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
          [batch_size * c_len_max, 1]),
      tf.reshape(c_indices, [batch_size * c_len_max, 1])
  ], 1)
  x_token_is_observed = tf.scatter_nd(
      indices, tf.ones([batch_size * c_len_max], tf.int32),
      py_utils.GetShape(x))
  # `x_segments` captures which slot each `x` belongs to (both observed and
  # tokens that need to be observed).
  x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

  x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
  prev_x_token_is_observed = tf.pad(
      x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True)

  x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
  prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
  x_is_valid = tf.cast(1 - x_paddings, tf.bool)
  x_is_valid = tf.reshape(x_is_valid, [-1])

  # Remap all the observed to <eos>, note some of these need a zero weight
  # (or else there would be <eos> and valid token in the same slot).
  target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
  target_indices = tf.where(
      x_token_is_observed,
      tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

  # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
  # we may want to weigh this term by the original sequence length.
  target_weights = tf.ones_like(target_indices, tf.float32)

  # We need to set all the weights for <eos> which actually have valid tokens
  # in the slot to zero.
  target_weights = tf.where(
      x_token_is_observed & ~prev_x_token_is_observed,
      tf.zeros_like(target_weights), target_weights)

  # TODO(williamchan): Consider dropping the entries w/ weight zero.

  # Add the batch and slot indices.
  target_indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]),
          [batch_size * time_dim, 1]),
      tf.reshape(x_segments, [-1, 1]), target_indices
  ], 1)

  # Select only the valid indices. The selected valid ones include slots w/
  # <eos>.
  target_indices = target_indices[x_is_valid]
  target_weights = target_weights[x_is_valid]

  return py_utils.NestedMap(
      canvas=c,
      canvas_indices=c_indices,
      canvas_paddings=c_paddings,
      target_indices=target_indices,
      target_weights=target_weights)
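
# Hedged usage sketch for `FProp` above. The layer construction is assumed
# to have happened elsewhere; the ids and the <eos>=2 convention below are
# illustrative, not from the original source.
def _ExampleSymbolInsertionFProp(layer):
  """Samples a rollin canvas and insertion targets for a toy batch."""
  x = tf.constant([[7, 8, 9, 2], [5, 6, 2, 0]], tf.int32)  # 2 = <eos>
  x_paddings = tf.constant([[0., 0., 0., 0.], [0., 0., 0., 1.]])
  # `theta` is documented as ignored, so None is passed here.
  out = layer.FProp(None, x, x_paddings, eos_id=2,
                    force_sample_last_token=True)
  # out.canvas:         [batch_size, c_dim] random observed subset of `x`.
  # out.target_indices: [num_targets, 3] rows of (batch, slot, token),
  #   typically used with tf.gather_nd on the softmax log-probs.
  return out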
def BeamSearchDecode(self,
                     theta,
                     encoder_outputs,
                     num_hyps_per_beam_override=0,
                     init_beam_search_state=None,
                     pre_beam_search_step_callback=None,
                     post_beam_search_step_callback=None,
                     max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks. Mostly opaque to BeamSearchHelper, except that it should
      contain either a 'seq_lengths' field of shape [source_batch_size] or a
      'paddings' field of shape [source_max_lengths, source_batch_size].
    num_hyps_per_beam_override: If set to a value <= 0, this parameter is
      ignored. If set to a value > 0, then this value will be used to
      override `p.num_hyps_per_beam`.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer
      to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. Please refer to the class header comments for more details.
    max_steps: Maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A `BeamSearchDecodeOutput`.
  """
  p = self.params
  num_hyps_per_beam = p.num_hyps_per_beam
  if num_hyps_per_beam_override > 0:
    num_hyps_per_beam = num_hyps_per_beam_override
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  min_score = -1e36
  best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
  in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
  in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
  bs_atten_probs = tf.zeros(
      [max_steps, num_hyps,
       tf.shape(initial_results.atten_probs)[1]],
      dtype=p.dtype)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)
  core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                    in_prev_hyps, in_done_hyps, bs_atten_probs)

  def LoopContinue(cur_step, all_done, unused_step_ids, unused_core_bs_states,
                   unused_other_states_list):
    return tf.math.logical_and(cur_step < max_steps,
                               tf.math.logical_not(all_done))

  def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
               other_states_list):
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta, encoder_outputs, cur_step, step_ids, core_bs_states,
         other_states.Pack(other_states_list), num_hyps_per_beam,
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(all_done.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        _GetShapes(core_bs_states),
                        _GetShapes(flat_other_states, none_shapes=True)))
  # [target_seq_len, num_beams * num_hyps_per_beam].
  final_done_hyps = final_bs_states[5]
  final_other_states = other_states.Pack(flat_final_other_states)

  # Assume that `paddings` has shape [source_max_lengths, source_batch_size]
  # by default, and compute `encoded_seq_lengths` accordingly. This can be
  # overridden by directly passing `seq_lengths` in the `encoder_outputs`
  # NestedMap.
  encoded_seq_lengths = getattr(encoder_outputs, 'seq_lengths', None)
  if encoded_seq_lengths is None:
    source_paddings = encoder_outputs.padding
    if isinstance(source_paddings, py_utils.NestedMap):
      encoded_seq_lengths = tf.cast(
          tf.round(
              tf.reduce_sum(
                  1.0 - tf.transpose(source_paddings.Flatten()[0]), 1)),
          tf.int32)
    else:
      encoded_seq_lengths = tf.cast(
          tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
          tf.int32)

  # [num_beams, num_hyps_per_beam].
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      encoded_seq_lengths,
      k=num_hyps_per_beam,
      num_hyps_per_beam=num_hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)
  # [num_beams * num_hyps_per_beam, ...].
  max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

  return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                topk_lens, topk_scores, None,
                                final_other_states)
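
# Hedged usage sketch for `BeamSearchDecode` (assumed wiring; the decoder
# attribute and callback names below are illustrative, not confirmed by this
# source). A decoder typically forwards its own callbacks and reads the
# top-k fields off the result.
def _ExampleBeamSearchDecode(decoder, theta, encoder_outputs):
  bs_output = decoder.beam_search.BeamSearchDecode(
      theta,
      encoder_outputs,
      num_hyps_per_beam_override=0,  # <= 0 keeps p.num_hyps_per_beam.
      init_beam_search_state=decoder._InitBeamSearchStateCallback,
      pre_beam_search_step_callback=decoder._PreBeamSearchStepCallback,
      post_beam_search_step_callback=decoder._PostBeamSearchStepCallback,
      max_steps=None)  # None defaults to p.target_seq_len.
  # topk_ids/topk_lens are [num_beams * num_hyps_per_beam, ...];
  # topk_scores is [num_beams, num_hyps_per_beam].
  return bs_output.topk_ids, bs_output.topk_lens, bs_output.topk_scores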
def _StringsToIdsImpl(self, strs, max_length, append_eos, languages):
  """Takes a tensor of strings and returns id/padding tensors.

  This generates `token_ids`, `target_ids`, and `paddings` in the format
  that is expected for tokenizers. This performs padding to a fixed length
  and appends the end-of-sentence token as appropriate.

  Args:
    strs: a string Tensor.
    max_length: a python integer. The second dimension of the returned
      arrays. All sequences are padded or truncated to that length.
    append_eos: a python bool. See `BaseTokenizer` for explanation.
    languages: A vector of strings with the same length as `strs`.

  Returns:
    A tuple of 3 tensors:
      - token_ids: a tensor of sequences of WPM ids starting with SOS.
        Sequences always end with EOS unless the sequence exceeds the
        maximum length. Always padded with EOS.
      - target_ids: a tensor of sequences of WPM ids not starting with SOS
        but ending with EOS. Always padded with EOS.
      - paddings: a tensor of floats indicating, at each position, whether
        the corresponding position is padded.
  """
  p = self.params
  if append_eos is None:
    append_eos = p.append_eos

  batch_size = py_utils.GetShape(strs)[0]
  token_ids_ta = tf.TensorArray(tf.int32, batch_size)
  target_ids_ta = tf.TensorArray(tf.int32, batch_size)
  paddings_ta = tf.TensorArray(tf.float32, batch_size)

  def _TokenizeOneSentence(i, strs, token_ids_ta, target_ids_ta,
                           paddings_ta):
    """Tokenizes a single sentence."""
    ids, _ = self._wpm_encoder.Encode(strs[i])

    if append_eos:
      ids = tf.concat([ids, [self.eos_id]], axis=0)

    # This truncates after the eos is added, so some sentences might
    # not have </s> at the end.
    token_ids_ta = token_ids_ta.write(
        i,
        py_utils.PadOrTrimTo(
            tf.concat([[self.sos_id], ids], axis=0), [max_length],
            self.eos_id))
    target_ids_ta = target_ids_ta.write(
        i, py_utils.PadOrTrimTo(ids, [max_length], self.eos_id))
    paddings_ta = paddings_ta.write(
        i,
        py_utils.PadOrTrimTo(
            tf.zeros_like(ids, dtype=tf.float32), [max_length], 1.))

    return i + 1, strs, token_ids_ta, target_ids_ta, paddings_ta

  _, _, token_ids_ta, target_ids_ta, paddings_ta = tf.while_loop(
      lambda i, *_: i < batch_size,
      _TokenizeOneSentence,
      loop_vars=(tf.constant(0, tf.int32), strs, token_ids_ta, target_ids_ta,
                 paddings_ta),
      parallel_iterations=30,
      back_prop=False)

  token_ids = token_ids_ta.stack()
  target_ids = target_ids_ta.stack()
  paddings = paddings_ta.stack()

  if not p.pad_to_max_length:
    maxlen = tf.cast(
        tf.round(tf.reduce_max(tf.reduce_sum(1.0 - paddings, axis=1))),
        tf.int32)
    token_ids = token_ids[:, :maxlen]
    target_ids = target_ids[:, :maxlen]
    paddings = paddings[:, :maxlen]

  return token_ids, target_ids, paddings
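
# Hedged usage sketch for `_StringsToIdsImpl` (illustrative call; assumes a
# constructed WPM tokenizer instance). `languages` is unused in the body
# above, so None is passed here.
def _ExampleStringsToIds(tokenizer):
  strs = tf.constant(['hello world', 'foo'])
  token_ids, target_ids, paddings = tokenizer._StringsToIdsImpl(
      strs, max_length=10, append_eos=True, languages=None)
  # token_ids:  [2, 10], starts with SOS, padded with EOS.
  # target_ids: [2, 10], ends with EOS, padded with EOS.
  # paddings:   [2, 10] floats, 1. at padded positions. (All three may be
  #   trimmed to the batch max length if p.pad_to_max_length is False.)
  return token_ids, target_ids, paddings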
def _BeamSearchDecodeIds(self,
                         theta,
                         encoder_outputs,
                         num_hyps_per_beam,
                         init_beam_search_state=None,
                         pre_beam_search_step_callback=None,
                         post_beam_search_step_callback=None,
                         max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A NestedMap computed by encoder.
    num_hyps_per_beam: Number of hyps per beam.
    init_beam_search_state: The InitBeamSearchState callback. Please refer
      to the class header comments for more details.
    pre_beam_search_step_callback: The PreBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The PostBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    max_steps: Maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    hyps: A tensor of shape [time, b * k] with ids of the token selected.
    prev_hyps: A tensor of shape [time, b * k] with index to the previous
      hyps which were selected.
    done_hyps: A boolean tensor of shape [time, b * k] indicating whether
      the hyp was terminated.
    scores: A tensor of shape [time, b * k] with scores of the token
      selected.
    atten_probs: A tensor of shape [time, b * k, seq_len] which contains the
      attention probabilities over the source words against the word in the
      previous hyps.
    eos_scores: A tensor of shape [time, b * k] with scores of the eos token
      selected.
    eos_atten_probs: A tensor of shape [time, b * k, seq_len] which contains
      the attention probabilities over the source words against the word in
      the previous hyps.
    source_seq_lengths: A tensor of shape [source_batch] containing the
      source seq_lengths.
    flat_final_other_states: An array of tensors that are part of other
      states.
  """
  p = self.params
  source_paddings = encoder_outputs.padding

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  # We cache the NestedMap as a member variable so that we can use it to
  # pack the final outputs. TPU rewrite methods force us to strictly pass
  # in Tensors, and output Tensors.
  self._other_states = other_states

  step_ids = tf.fill([num_hyps, 1],
                     tf.constant(p.target_sos_id, dtype=tf.int32))
  min_score = -1e36
  fprop_dtype = py_utils.FPropDtype(p)
  best_scores = (tf.zeros(shape=[num_beams], dtype=fprop_dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=fprop_dtype)
  histories = tf.zeros(shape=[num_hyps], dtype=tf.int32)
  in_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_prev_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_done_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_eos_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_eos_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)

  # States for beam search that are inputs into Beam search step.
  accum_bs_states = [best_scores, cumulative_scores, histories]
  # States that are not accumulators.
  non_accum_bs_states = [
      in_scores,
      in_hyps,
      in_prev_hyps,
      in_done_hyps,
      in_atten_probs,
      in_eos_scores,
      in_eos_atten_probs,
  ]
  core_bs_states = tuple(accum_bs_states + non_accum_bs_states)

  flat_other_states = other_states.Flatten()

  # If there is an optimized implementation for short sequences,
  # LoopBodyShort will run first for short_seq_limit steps (after which
  # LoopBodyShort no longer has a performance benefit). Then LoopBodyLong
  # (the default implementation) is used to continue the rest of the steps.
  # For decoders which do not have a short-sequence-specific implementation,
  # only LoopBodyLong (the default implementation) will run.
  if p.short_seq_limit > 0:

    def LoopContinueShort(cur_step, all_done, unused_step_ids,
                          unused_core_bs_states, unused_other_states_list):
      """Use short_seq optimization when cur_step is smaller than limit."""
      return tf.math.logical_and(cur_step < p.short_seq_limit,
                                 tf.math.logical_not(all_done))

    def LoopBodyShort(cur_step, unused_all_done, step_ids, core_bs_states,
                      other_states_list):
      """Loop body of short_seq optimization.

      Instead of doing computation for the entire padded sequence, while
      loop with early exit is used within each _BeamSearchStep to do
      computation for only the actual sequence (seq_length <= cur_step).
      use_short_seq_opt is used as the flag to pass this information down
      to the decoder implementation.

      Args:
        cur_step: A scalar int tensor, the current time step, 0-based.
        unused_all_done: A tf.bool, indicating whether the decoding
          finishes.
        step_ids: An int32 tensor of shape [num_hyps, 1]. The input ids to
          the current search step.
        core_bs_states: A tuple of core beam search states.
        other_states_list: A flattened NestedMap of other beam search
          states.

      Returns:
        The updated input tuple, with the same shape.
      """
      (cur_step, all_done, new_step_ids, new_bs_states,
       new_other_states) = self._BeamSearchStep(
           theta,
           encoder_outputs,
           cur_step,
           step_ids,
           core_bs_states,
           other_states.Pack(other_states_list),
           num_hyps_per_beam,
           pre_beam_search_step_callback,
           post_beam_search_step_callback,
           use_short_seq_opt=True)
      return (cur_step, all_done, new_step_ids, new_bs_states,
              new_other_states.Flatten())

    (cur_step, all_done, step_ids, core_bs_states,
     flat_other_states) = tf.while_loop(
         LoopContinueShort,
         LoopBodyShort,
         loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                    flat_other_states),
         parallel_iterations=10,
         back_prop=False,
         swap_memory=False,
         shape_invariants=(
             tf.TensorShape(cur_step.get_shape()),
             tf.TensorShape(all_done.get_shape()),
             tf.TensorShape(step_ids.get_shape()),
             tuple(
                 list(_GetShapes(accum_bs_states)) +
                 list(_GetShapes(non_accum_bs_states, none_shapes=True))),
             _GetShapes(flat_other_states, none_shapes=True)),
         maximum_iterations=max_steps)

  def LoopContinueLong(cur_step, all_done, unused_step_ids,
                       unused_core_bs_states, unused_other_states_list):
    """Continue default implementation until decoding finishes."""
    return tf.math.logical_and(cur_step < max_steps,
                               tf.math.logical_not(all_done))

  def LoopBodyLong(cur_step, unused_all_done, step_ids, core_bs_states,
                   other_states_list):
    """Loop body of default long_seq implementation."""
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta,
         encoder_outputs,
         cur_step,
         step_ids,
         core_bs_states,
         other_states.Pack(other_states_list),
         num_hyps_per_beam,
         pre_beam_search_step_callback,
         post_beam_search_step_callback,
         use_short_seq_opt=False)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinueLong,
      LoopBodyLong,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(
          tf.TensorShape(cur_step.get_shape()),
          tf.TensorShape(all_done.get_shape()),
          tf.TensorShape(step_ids.get_shape()),
          tuple(
              list(_GetShapes(accum_bs_states)) +
              list(_GetShapes(non_accum_bs_states, none_shapes=True))),
          _GetShapes(flat_other_states, none_shapes=False)),
      maximum_iterations=max_steps)

  if isinstance(source_paddings, py_utils.NestedMap):
    source_seq_lengths = tf.cast(
        tf.round(
            tf.reduce_sum(1.0 - tf.transpose(source_paddings.Flatten()[0]),
                          1)),
        dtype=tf.int32)
  else:
    source_seq_lengths = tf.cast(
        tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
        dtype=tf.int32)

  # Concatenate all outputs on axis=0.
  scores = final_bs_states[3].stack()
  hyps = final_bs_states[4].stack()
  prev_hyps = final_bs_states[5].stack()
  done_hyps = tf.cast(final_bs_states[6].stack(), tf.bool)
  atten_probs = final_bs_states[7].stack()
  eos_scores = final_bs_states[8].stack()
  eos_atten_probs = final_bs_states[9].stack()
  rets = (hyps, prev_hyps, done_hyps, scores, atten_probs, eos_scores,
          eos_atten_probs, source_seq_lengths)

  # TODO(rohananil): Only send a single R1 tensor to host instead of 3 after
  # b/111131551 is resolved.
  # Canonical shapes for tensors of various ranks.
  r_shapes = [
      py_utils.GetShape(source_seq_lengths),
      py_utils.GetShape(hyps),
      py_utils.GetShape(atten_probs)
  ]
  # Reshape all tensors to [-1] to avoid cost of copy due to padding.
  rets_r1 = [tf.reshape(r, [-1]) for r in rets]

  return tuple(r_shapes) + tuple(rets_r1) + tuple(flat_final_other_states)
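
# Hedged sketch (assumed consumption pattern, not from the original source):
# the flat tuple returned above can be unpacked on the host by reshaping
# each rank-1 tensor back to its canonical shape via the leading r_shapes,
# mirroring the indexing used in _ProcessBeamSearchDecodeOut.
def _ExampleUnpackDecodeIds(decoder_outs):
  r1_shape, r2_shape, r3_shape = decoder_outs[0:3]
  hyps = tf.reshape(decoder_outs[3], r2_shape)         # [time, b * k]
  prev_hyps = tf.reshape(decoder_outs[4], r2_shape)
  done_hyps = tf.reshape(decoder_outs[5], r2_shape)
  atten_probs = tf.reshape(decoder_outs[7], r3_shape)  # [time, b*k, src]
  source_seq_lengths = tf.reshape(decoder_outs[10], r1_shape)
  return hyps, prev_hyps, done_hyps, atten_probs, source_seq_lengths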