def Encode(self, text):
  """Encodes `text` into wordpiece ids together with their string form.

  The text is split on whitespace; every word is prefixed with `BOW_STR` and
  encoded via `self._EncodeToIds`, and the resulting ids are appended, in word
  order, to one growing id sequence.

  Args:
    text: the text to encode (presumably a scalar string tensor — confirm).

  Returns:
    (ids, tokens) where ids is the encoded integer ids and tokens is the
    encoded string produced by `self._TokenToString(ids)`.
  """
  split_words = tf.strings.split([text])
  word_list = tf.sparse.to_dense(split_words, default_value='')[0]
  word_count = tf.size(word_list)
  collected = tf.TensorArray(tf.int32, 0, dynamic_size=True)

  def _Continue(index, *_):
    return index < word_count

  def _AppendWord(index, all_words, acc):
    piece_ids = self._EncodeToIds(BOW_STR + all_words[index])
    # Place the new ids directly after the ids collected so far.
    start = acc.size()
    stop = acc.size() + tf.size(piece_ids)
    acc = acc.scatter(tf.range(start, stop), piece_ids)
    return index + 1, all_words, acc

  _, _, collected = tf.while_loop(
      _Continue,
      _AppendWord,
      loop_vars=(tf.constant(0, tf.int32), word_list, collected),
      parallel_iterations=30,
      back_prop=False)

  ids = collected.stack()
  tokens = self._TokenToString(ids)
  return ids, tokens
def ReOrderHyps(x_in):
  """Reorders x_in based on prev hyp ids.

  NOTE(review): this is a closure. `p`, `old_hyp_ids`,
  `old_hyp_ids_in_cache_order` and `cur_step` are free variables resolved
  from an enclosing scope that is not visible here.

  Args:
    x_in: the value to reorder. Only `tf.Tensor`s whose static rank is > 0 are
      reordered; any other value is returned unchanged.

  Returns:
    For rank-1 tensors, `tf.gather(x_in, old_hyp_ids)`. For higher-rank
    tensors, either `fast_gather` over `old_hyp_ids` (when
    `p.batch_major_state`) or a per-time-step gather by `correct_old_hyp_ids`.
    The result is given the same static shape as `x_in`. Non-reordered inputs
    are returned as-is.
  """
  if isinstance(x_in, tf.Tensor) and x_in.shape.ndims > 0:
    # For rank > 1 tensors we make use of an efficient matmul based gather
    # on tpu that takes in account the range of the values. For R1, we
    # rely on the tf.gather and xla to optimize it efficiently for R1
    # layout.
    # NOTE(review): fast_gather is used only when p.batch_major_state is set;
    # otherwise rank > 1 tensors are reordered one time slice at a time below.
    if x_in.shape.ndims > 1:
      if p.batch_major_state:
        num_hyps = tf.shape(old_hyp_ids)[0]
        x_out = beam_search_tpu_ops.fast_gather(
            x_in,
            old_hyp_ids,
            num_hyps,
            max_value=None,
            batch_major_state=p.batch_major_state)
      else:
        # Use corrected indices only here for batch major compute as
        # key/value caches are the states being affected.
        correct_old_hyp_ids = (
            old_hyp_ids_in_cache_order
            if p.batch_major_compute else old_hyp_ids)

        def _GatherStep(x_in, t):
          """Gather for one time step.

          Args:
            x_in: in the shape of [T, B, ...] we first get slice(t) from the
              tensors, then gather correct_old_hyp_ids from the slice and
              write the gathered slice in place to update the original x_in.
            t: current time step

          Returns:
            Updated x_in and time step
          """
          x = tf.gather(tf.gather(x_in, t), correct_old_hyp_ids)
          return inplace_ops.alias_inplace_update(x_in, t, x), t + 1

        # Visits t = 0, 1, ..., cur_step (inclusive): only the first
        # cur_step + 1 time slices are reordered.
        x_out, _ = tf.while_loop(lambda _, t: t <= cur_step, _GatherStep,
                                 (x_in, tf.zeros([], tf.int32)))
    else:
      # Rank-1 tensor: gather directly along its only axis.
      x_out = tf.gather(x_in, old_hyp_ids)
    # Presumably restores static shape information lost by the gather /
    # while_loop above.
    x_out.set_shape(x_in.get_shape())
    return x_out
  else:
    return x_in
def wrap_computation_in_while_loop(op_fn, n, host_device):
  """Wraps the ops generated by `op_fn` in a tf.while_loop running `n` times.

  Args:
    op_fn: zero-argument callable returning an op/tensor or a list of them.
      It is called while the loop body is being built.
    n: number of loop iterations; the counter runs from 0 while it is < n.
    host_device: device scope under which the loop is created.

  Returns:
    The output of `tf.while_loop` for the single loop counter.
  """

  def _Iterate(counter):
    produced = op_fn()
    deps = produced if isinstance(produced, list) else [produced]
    # The incremented counter (and its debug print) depends on the wrapped
    # ops, so they run on every iteration.
    with tf.control_dependencies(deps):
      return tf.Print(counter + 1, [counter], 'while_loop:')

  with tf.device(host_device):
    return tf.while_loop(
        lambda counter: tf.less(counter, n),
        _Iterate, [tf.constant(0)],
        parallel_iterations=1)
def _EncodeToIds(self, word):
  """Encodes one word into wordpiece ids by repeated adjacent-pair merging.

  The word is split into UTF-8 characters. Each character is mapped to a
  token id, and characters that map to `NO_TOKEN` are replaced by
  `self.unk_id`. Adjacent tokens are then merged in a loop. Each step merges
  the pair whose merged token id is smallest (`tf.argmin`). The loop stops
  when no mergeable pair remains, or early when
  `tf.random.uniform([]) < self._merge_prob` is false for that step.

  Args:
    word: the word to encode (presumably a scalar string tensor — confirm).

  Returns:
    The final tokens tensor (1-D, assuming `word` is a scalar).
  """
  # Below:
  # * a token is a wordpiece ID.
  # * the tokens array will be merged in-place.
  # * the candidates array is an array of size len(tokens) - 1.
  # It contains the token for the merged wordpiece, if it exists,
  # NO_TOKEN otherwise. For instance, candidate[3] = id(token[3] + token[4]).
  # NOTE(review): the merge picks candidates with tf.argmin, so NO_TOKEN
  # presumably must compare greater than every real id — confirm.
  # First, split into basic UTF-8 characters (letters).
  chars = tf.strings.unicode_split(word, 'UTF-8')
  tokens = self._StringToToken(chars)
  tokens = tf.where(
      tf.equal(tokens, NO_TOKEN),  # Unseen character.
      tf.broadcast_to(self.unk_id, tf.shape(tokens)),
      tokens)
  # Create initial candidate list.
  candidates = tf.map_fn(
      self._MergeTokens, (tokens[:-1], tokens[1:]), dtype=tokens.dtype)

  def _ShouldMerge(unused_tokens, candidates):
    """Merge until not possible, or we abort early according to merge_prob."""
    return tf.math.logical_and(
        tf.reduce_any(tf.not_equal(candidates, NO_TOKEN)),
        tf.random.uniform([]) < self._merge_prob)

  def _MergeOneToken(tokens, i):
    """Returns the merge candidate of tokens[i] and tokens[i + 1] as shape [1]."""
    return tf.expand_dims(
        self._MergeTokens((tokens[i], tokens[i + 1])), axis=-1)

  def _MergeCandidates(tokens, candidates):
    """Performs one merge step and updates the candidates incrementally."""
    best_id = tf.argmin(candidates, output_type=tf.int32)
    # Perform the merge at position best_id: tokens[best_id] and
    # tokens[best_id + 1] are replaced by the single merged token, so tokens
    # shrinks by one.
    tokens = tf.concat(
        [tokens[:best_id], [candidates[best_id]], tokens[best_id + 2:]],
        axis=0)
    # Recompute the merge candidates.
    # Only the neighbors of best_id need to be recomputed.
    empty = tf.zeros([0], dtype=candidates.dtype)

    def _MergeLeft():
      # Pairs entirely left of the merge are kept; the pair ending at the
      # merged token is recomputed.
      return tf.concat(
          [candidates[:best_id - 1],
           _MergeOneToken(tokens, best_id - 1)],
          axis=0)

    # No left neighbor when the merge happened at position 0.
    left_candidates = tf.cond(tf.equal(best_id, 0), lambda: empty, _MergeLeft)

    def _MergeRight():
      # The pair starting at the merged token is recomputed. The old
      # candidates from best_id + 2 onwards are still valid (their indices
      # shift down by one).
      return tf.concat(
          [_MergeOneToken(tokens, best_id), candidates[best_id + 2:]], axis=0)

    # No right neighbor when the merged token is now the last token.
    right_candidates = tf.cond(
        tf.greater_equal(best_id, tf.size(tokens) - 1), lambda: empty,
        _MergeRight)

    candidates = tf.concat([left_candidates, right_candidates], axis=0)
    return tokens, candidates

  # Only the merged tokens are returned; the final candidates are dropped.
  return tf.while_loop(
      _ShouldMerge,
      _MergeCandidates, (tokens, candidates),
      parallel_iterations=1,
      back_prop=False)[0]
def GreedySearchDecode(self,
                       theta,
                       encoder_outputs,
                       init_beam_search_state=None,
                       pre_beam_search_step_callback=None,
                       post_beam_search_step_callback=None,
                       max_steps=None):
  """Performs greedy-search based decoding.

  Search state is initialized with a single hyp per beam, and
  `self._GreedySearchStep` is run in a `tf.while_loop`. The loop stops after
  `max_steps` steps or once every hyp is done.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to the
      callbacks.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer to
      the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    max_steps: maximum number of decoding steps. If None, use
      self.params.target_seq_len.

  Returns:
    A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is same as
    src_batch_size.

      - hyp_ids: [num_hyps, max_step]. Hyps end with <eos> token if the <eos>
        token is encountered during search.
      - hyp_lens: [num_hyps].
      - done_hyps: [num_hyps], whether or not an eos is encountered.
  """
  params = self.params
  steps_limit = params.target_seq_len if max_steps is None else max_steps

  hyps_per_beam = 1
  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, hyps_per_beam)
  num_hyps = tf.shape(initial_results.log_probs)[0]

  if 'step_ids' not in initial_results:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(params.target_sos_id, dtype=tf.int32))
  else:
    # Caller-provided first-step inputs, [num_hyps, 1].
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])

  start_step = tf.constant(0, dtype=tf.int32)
  # init=True fills each buffer with its dtype's default value.
  done_init = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.bool, init=True, name='done_hyps')
  lens_init = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.int32, init=True, name='hyp_lens')
  ids_init = inplace_ops.empty(
      shape=[steps_limit, num_hyps], dtype=tf.int32, init=True, name='hyp_ids')

  def _KeepGoing(step, unused_step_ids, unused_hyp_ids, unused_hyp_lens, done,
                 unused_other_states_list):
    within_limit = step < steps_limit
    return tf.math.logical_and(within_limit,
                               tf.math.logical_not(tf.reduce_all(done)))

  def _Step(step, cur_ids, cur_hyp_ids, cur_lens, done, other_states_list):
    (next_step, next_ids, next_hyp_ids, next_lens, next_done,
     next_other_states) = self._GreedySearchStep(
         theta, encoder_outputs, step, cur_ids, cur_hyp_ids, cur_lens, done,
         other_states.Pack(other_states_list), pre_beam_search_step_callback,
         post_beam_search_step_callback)
    return (next_step, next_ids, next_hyp_ids, next_lens, next_done,
            next_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  loop_vars = (start_step, step_ids, ids_init, lens_init, done_init,
               flat_other_states)
  # Fixed shapes for the five tensor loop vars; relaxed shapes for the rest.
  invariants = tuple(tf.TensorShape(v.get_shape()) for v in loop_vars[:-1])
  invariants += (_GetShapes(flat_other_states, none_shapes=True),)

  results = tf.while_loop(
      _KeepGoing,
      _Step,
      loop_vars=loop_vars,
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=invariants)
  final_hyp_ids, final_hyp_lens, final_done_hyps = results[2:5]

  # hyp_ids are accumulated as [max_steps, num_hyps]; transpose them to
  # [num_hyps, max_steps] so they match BeamSearchDecode's output.
  return tf.transpose(final_hyp_ids), final_hyp_lens, final_done_hyps
def BeamSearchDecode(self,
                     theta,
                     encoder_outputs,
                     num_hyps_per_beam_override=0,
                     init_beam_search_state=None,
                     pre_beam_search_step_callback=None,
                     post_beam_search_step_callback=None,
                     max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to the
      callbacks. Mostly opaque to BeamSearchHelper, except that it should
      contain either a 'seq_lengths' field of shape [source_batch_size] or a
      'padding' field of shape [source_max_lengths, source_batch_size]. If
      'padding' is itself a NestedMap, only its first flattened entry is used.
    num_hyps_per_beam_override: If set to a value <= 0, this parameter is
      ignored. If set to a value > 0, then this value will be used to override
      `p.num_hyps_per_beam`.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer to
      the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A `BeamSearchDecodeOutput`.
  """
  p = self.params
  num_hyps_per_beam = p.num_hyps_per_beam
  if num_hyps_per_beam_override > 0:
    num_hyps_per_beam = num_hyps_per_beam_override
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  # Initial per-beam best score, presumably chosen so any real hypothesis
  # score beats it.
  # NOTE(review): -1e36 is not representable in float16 (it would become
  # -inf); confirm p.dtype is never float16 here.
  min_score = -1e36
  best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
  # Per-step buffers of shape [max_steps, num_hyps]. The attention buffer has
  # an extra trailing axis sized from initial_results.atten_probs.
  in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
  in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
  bs_atten_probs = tf.zeros(
      [max_steps, num_hyps,
       tf.shape(initial_results.atten_probs)[1]],
      dtype=p.dtype)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)
  core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                    in_prev_hyps, in_done_hyps, bs_atten_probs)

  def LoopContinue(cur_step, all_done, unused_step_ids, unused_core_bs_states,
                   unused_other_states_list):
    return tf.math.logical_and(cur_step < max_steps,
                               tf.math.logical_not(all_done))

  def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
               other_states_list):
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta, encoder_outputs, cur_step, step_ids, core_bs_states,
         other_states.Pack(other_states_list), num_hyps_per_beam,
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(all_done.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        _GetShapes(core_bs_states),
                        _GetShapes(flat_other_states, none_shapes=True)))
  # [max_steps, num_beams * num_hyps_per_beam].
  final_done_hyps = final_bs_states[5]
  final_other_states = other_states.Pack(flat_final_other_states)

  # Use `seq_lengths` from `encoder_outputs` when present. Otherwise assume
  # `padding` has shape [source_max_lengths, source_batch_size] and count the
  # unpadded positions of each example.
  encoded_seq_lengths = getattr(encoder_outputs, 'seq_lengths', None)
  if encoded_seq_lengths is None:
    source_paddings = encoder_outputs.padding
    if isinstance(source_paddings, py_utils.NestedMap):
      # Only the first flattened padding tensor is used.
      source_paddings = source_paddings.Flatten()[0]
    encoded_seq_lengths = tf.cast(
        tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
        tf.int32)

  # [num_beams, num_hyps_per_beam].
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      encoded_seq_lengths,
      k=num_hyps_per_beam,
      num_hyps_per_beam=num_hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)
  # [num_beams * num_hyps_per_beam, ...].
  # A static max_seq_length is only passed when max_steps is a Python value;
  # for a Tensor, 0 is passed instead.
  max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

  return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                topk_lens, topk_scores, None,
                                final_other_states)
def _OutfeedDequeueLoop(self, per_example_tensors, num_loops, num_devices):
  """Process all per-example tensor outfeed data for a TPU sess.run.

  Builds a `tf.while_loop` with `num_loops` iterations. Each iteration
  dequeues one outfeed tuple from every (replica, core) of the TPU device
  assignment and writes it into one TensorArray per tensor key.

  Args:
    per_example_tensors: dict of key -> tensor as generated by TpuTrainStep.
    num_loops: number of times that TpuTrainStep will be executed by TpuTrain.
    num_devices: number of TPU cores assigned to this process.

  Returns:
    A dict mapping each key of `per_example_tensors` to all values dequeued
    for it, concatenated along axis 0. Returns `tf.no_op()` when
    `per_example_tensors` is empty.
  """
  if not per_example_tensors:
    return tf.no_op()

  keys = sorted(per_example_tensors)
  tensor_shapes = [py_utils.GetShape(per_example_tensors[k]) for k in keys]
  tensor_types = [tf.as_dtype(per_example_tensors[k].dtype) for k in keys]

  def _DequeueAll():
    """Dequeues one outfeed tuple per (replica, core), replica-major order."""
    assignment = py_utils.GetTpuDeviceAssignment()
    assert assignment
    dequeued = []
    for replica in range(assignment.num_replicas):
      for core in range(assignment.num_cores_per_replica):
        # Outfeed ops execute on each JF node, so they must be located on the
        # nodes.
        with tf.device(assignment.host_device(replica, core)):
          dequeued.append(
              tpu_ops.outfeed_dequeue_tuple(
                  tensor_types,
                  tensor_shapes,
                  device_ordinal=assignment.tpu_ordinal(replica, core)))
    return dequeued

  def _Body(step, *arrays):
    """Stores one TpuTrainStep's outfeed data; returns step+1 and arrays."""
    per_device = _DequeueAll()
    base = step * num_devices
    # One TensorArray per per-example tensor; each receives one entry from
    # every device for this step.
    updated = []
    for idx, ta in enumerate(arrays):
      for dev_idx, values in enumerate(per_device):
        ta = ta.write(base + dev_idx, values[idx])
      updated.append(ta)
    return tuple([step + 1] + updated)

  def _Cond(step, *unused_arrays):
    return step < num_loops

  initial_arrays = [
      tf.TensorArray(dtype, size=num_loops * num_devices, element_shape=shape)
      for dtype, shape in zip(tensor_types, tensor_shapes)
  ]
  # Loop once for each time that TpuTrainStep runs.
  loop_out = tf.while_loop(
      _Cond, _Body, [0] + initial_arrays, parallel_iterations=1)
  final_arrays = loop_out[1:]
  return {key: ta.concat() for key, ta in zip(keys, final_arrays)}
def _StringsToIdsImpl(self, strs, max_length, append_eos, languages):
  """Takes a tensor of strings and returns id/padding tensors.

  Each string is encoded with `self._wpm_encoder.Encode`, optionally
  followed by EOS, and then padded or truncated to `max_length`. When
  `p.pad_to_max_length` is false, the results are then trimmed to the longest
  unpadded row in the batch.

  Args:
    strs: a string Tensor.
    max_length: a python integer. The second dimension of the returned arrays.
      All sequences are padded or truncated to that length.
    append_eos: a python bool. See `BaseTokenizer` for explanation. If None,
      `p.append_eos` is used.
    languages: A vector of strings with the same length as `strs`. Not used by
      this implementation.

  Returns:
    A tuple of 3 tensors:

    - token_ids: a tensor of sequences of WPM ids starting with SOS. Sequences
      always end with EOS unless the sequence exceeds the maximum length.
      Always padded with EOS.
    - target_ids: a tensor of sequences of WPM ids not starting with SOS
      but ending with EOS. Always padded with EOS.
    - paddings: a tensor of floats indicating, at each position, whether
      the corresponding position is padded.
  """
  del languages  # Unused.
  params = self.params
  add_eos = params.append_eos if append_eos is None else append_eos
  num_sentences = py_utils.GetShape(strs)[0]
  arrays = (tf.TensorArray(tf.int32, num_sentences),
            tf.TensorArray(tf.int32, num_sentences),
            tf.TensorArray(tf.float32, num_sentences))

  def _Fit(tensor, fill_value):
    """Pads with `fill_value` or truncates `tensor` to `[max_length]`."""
    return py_utils.PadOrTrimTo(tensor, [max_length], fill_value)

  def _Tokenize(idx, sentences, token_acc, target_acc, padding_acc):
    """Encodes sentence `idx` and records its three fixed-length rows."""
    encoded, _ = self._wpm_encoder.Encode(sentences[idx])
    if add_eos:
      encoded = tf.concat([encoded, [self.eos_id]], axis=0)
    # Truncation happens after </s> is appended, so sentences that are too
    # long may lose their </s>.
    with_sos = tf.concat([[self.sos_id], encoded], axis=0)
    token_acc = token_acc.write(idx, _Fit(with_sos, self.eos_id))
    target_acc = target_acc.write(idx, _Fit(encoded, self.eos_id))
    not_padded = tf.zeros_like(encoded, dtype=tf.float32)
    padding_acc = padding_acc.write(idx, _Fit(not_padded, 1.))
    return idx + 1, sentences, token_acc, target_acc, padding_acc

  _, _, token_acc, target_acc, padding_acc = tf.while_loop(
      lambda idx, *_: idx < num_sentences,
      _Tokenize,
      loop_vars=(tf.constant(0, tf.int32), strs) + arrays,
      parallel_iterations=30,
      back_prop=False)

  token_ids, target_ids, paddings = (token_acc.stack(), target_acc.stack(),
                                     padding_acc.stack())

  if params.pad_to_max_length:
    return token_ids, target_ids, paddings

  # Trim the time axis to the longest unpadded row in the batch.
  longest = tf.cast(
      tf.round(tf.reduce_max(tf.reduce_sum(1.0 - paddings, axis=1))),
      tf.int32)
  return (token_ids[:, :longest], target_ids[:, :longest],
          paddings[:, :longest])
def _BeamSearchDecodeIds(self,
                         theta,
                         encoder_outputs,
                         num_hyps_per_beam,
                         init_beam_search_state=None,
                         pre_beam_search_step_callback=None,
                         post_beam_search_step_callback=None,
                         max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap computed by encoder. Its `padding` field is
      used to compute the source sequence lengths. If it is a NestedMap, only
      its first flattened entry is used.
    num_hyps_per_beam: Number of hyps per beam.
    init_beam_search_state: The InitBeamSearchState callback. Please refer to
      the class header comments for more details.
    pre_beam_search_step_callback: The PreBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The PostBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A flat tuple with the following layout.

    The first three entries are the canonical shapes (from
    `py_utils.GetShape`) of `source_seq_lengths`, `hyps` and `atten_probs`.
    These are presumably used by the caller to restore the original ranks.

    The next eight entries are the tensors below, each reshaped to rank 1
    ([-1]) to avoid the cost of copying padded layouts:

      hyps: [time, b * k] ids of the token selected.
      prev_hyps: [time, b * k] index to the previous hyps which was selected.
      done_hyps: [time, b * k] bools indicating if the hyp was terminated.
      scores: [time, b * k] scores of the token selected.
      atten_probs: [time, b * k, seq_len] attention probabilities over the
        source words against word in the previous hyps.
      eos_scores: [time, b * k] scores of the eos token selected.
      eos_atten_probs: [time, b * k, seq_len] attention probabilities over
        the source words against word in the previous hyps, for eos.
      source_seq_lengths: int32 per-example source lengths. Its length is the
        second dimension of `encoder_outputs.padding`.

    The remaining entries are the flattened final other states.
  """
  p = self.params
  if max_steps is None:
    # Honor the documented default. The TensorArray sizes and loop bounds
    # below all require a concrete value.
    max_steps = p.target_seq_len
  source_paddings = encoder_outputs.padding
  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  # We cache the NestedMap as member variable so that we can use it to
  # pack the final outputs. Tpu rewrite methods force us to strictly pass
  # in Tensors, and output Tensors.
  self._other_states = other_states

  step_ids = tf.fill([num_hyps, 1],
                     tf.constant(p.target_sos_id, dtype=tf.int32))
  min_score = -1e36
  fprop_dtype = py_utils.FPropDtype(p)
  best_scores = (tf.zeros(shape=[num_beams], dtype=fprop_dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=fprop_dtype)
  histories = tf.zeros(shape=[num_hyps], dtype=tf.int32)
  in_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_prev_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_done_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_eos_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_eos_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)
  # States for beam search that are inputs into Beam search step.
  accum_bs_states = [best_scores, cumulative_scores, histories]
  # States that are not accumulators.
  non_accum_bs_states = [
      in_scores,
      in_hyps,
      in_prev_hyps,
      in_done_hyps,
      in_atten_probs,
      in_eos_scores,
      in_eos_atten_probs,
  ]
  core_bs_states = tuple(accum_bs_states + non_accum_bs_states)

  flat_other_states = other_states.Flatten()

  # If there is an optimized implementation for short sequences,
  # LoopBodyShort runs first for short_seq_limit steps; after that it gives
  # no performance benefit. LoopBodyLong (the default implementation) then
  # continues with the remaining steps. Decoders without a
  # short-sequence-specific implementation only run LoopBodyLong.
  if p.short_seq_limit > 0:

    def LoopContinueShort(cur_step, all_done, unused_step_ids,
                          unused_core_bs_states, unused_other_states_list):
      """Use short_seq optimization when cur_step is smaller than limit."""
      return tf.math.logical_and(cur_step < p.short_seq_limit,
                                 tf.math.logical_not(all_done))

    def LoopBodyShort(cur_step, unused_all_done, step_ids, core_bs_states,
                      other_states_list):
      """Loop body of short_seq optimization.

      Instead of doing computation for the entire padded sequence, while loop
      with early exit is used within each _BeamSearchStep to do computation
      for only the actual sequence (seq_length <= cur_step).
      use_short_seq_opt is used as the flag to pass this information down to
      the decoder implementation.

      Args:
        cur_step: A scalar int tensor, the current time step, 0-based.
        unused_all_done: A tf.bool, indicating whether the decoding finishes.
        step_ids: An int32 tensor of shape [num_hyps, 1]. The input ids to
          the current search step.
        core_bs_states: A tuple of core beam search states.
        other_states_list: A flattened NestedMap of other beam search states.

      Returns:
        The updated input tuple, with the same shape.
      """
      (cur_step, all_done, new_step_ids, new_bs_states,
       new_other_states) = self._BeamSearchStep(
           theta,
           encoder_outputs,
           cur_step,
           step_ids,
           core_bs_states,
           other_states.Pack(other_states_list),
           num_hyps_per_beam,
           pre_beam_search_step_callback,
           post_beam_search_step_callback,
           use_short_seq_opt=True)
      return (cur_step, all_done, new_step_ids, new_bs_states,
              new_other_states.Flatten())

    (cur_step, all_done, step_ids, core_bs_states,
     flat_other_states) = tf.while_loop(
         LoopContinueShort,
         LoopBodyShort,
         loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                    flat_other_states),
         parallel_iterations=10,
         back_prop=False,
         swap_memory=False,
         shape_invariants=(
             tf.TensorShape(cur_step.get_shape()),
             tf.TensorShape(all_done.get_shape()),
             tf.TensorShape(step_ids.get_shape()),
             tuple(
                 list(_GetShapes(accum_bs_states)) +
                 list(_GetShapes(non_accum_bs_states, none_shapes=True))),
             _GetShapes(flat_other_states, none_shapes=True)),
         maximum_iterations=max_steps)

  def LoopContinueLong(cur_step, all_done, unused_step_ids,
                       unused_core_bs_states, unused_other_states_list):
    """Continue default implementation until decoding finishes."""
    return tf.math.logical_and(cur_step < max_steps,
                               tf.math.logical_not(all_done))

  def LoopBodyLong(cur_step, unused_all_done, step_ids, core_bs_states,
                   other_states_list):
    """Loop body of default long_seq implementation."""
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta,
         encoder_outputs,
         cur_step,
         step_ids,
         core_bs_states,
         other_states.Pack(other_states_list),
         num_hyps_per_beam,
         pre_beam_search_step_callback,
         post_beam_search_step_callback,
         use_short_seq_opt=False)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  # NOTE(review): unlike the short loop, other states here use
  # none_shapes=False (fixed shapes). Confirm this asymmetry is intended,
  # e.g. to keep static shapes for TPU compilation.
  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinueLong,
      LoopBodyLong,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(
          tf.TensorShape(cur_step.get_shape()),
          tf.TensorShape(all_done.get_shape()),
          tf.TensorShape(step_ids.get_shape()),
          tuple(
              list(_GetShapes(accum_bs_states)) +
              list(_GetShapes(non_accum_bs_states, none_shapes=True))),
          _GetShapes(flat_other_states, none_shapes=False)),
      maximum_iterations=max_steps)

  # Count the unpadded positions per example. For a NestedMap padding, only
  # the first flattened tensor is used.
  if isinstance(source_paddings, py_utils.NestedMap):
    source_paddings = source_paddings.Flatten()[0]
  source_seq_lengths = tf.cast(
      tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
      dtype=tf.int32)

  # Stack the per-step TensorArrays (indices follow core_bs_states: 0-2 are
  # accumulators, 3-9 are the TensorArrays) along a new leading time axis.
  scores = final_bs_states[3].stack()
  hyps = final_bs_states[4].stack()
  prev_hyps = final_bs_states[5].stack()
  done_hyps = tf.cast(final_bs_states[6].stack(), tf.bool)
  atten_probs = final_bs_states[7].stack()
  eos_scores = final_bs_states[8].stack()
  eos_atten_probs = final_bs_states[9].stack()
  rets = (hyps, prev_hyps, done_hyps, scores, atten_probs, eos_scores,
          eos_atten_probs, source_seq_lengths)

  # TODO(rohananil): Only send a single R1 tensor to host instead of 3 after
  # b/111131551 is resolved.
  # Canonical shapes for tensors of various ranks.
  r_shapes = [
      py_utils.GetShape(source_seq_lengths),
      py_utils.GetShape(hyps),
      py_utils.GetShape(atten_probs)
  ]
  # Reshape all tensors to [-1] to avoid cost of copy due to padding.
  rets_r1 = [tf.reshape(r, [-1]) for r in rets]

  return tuple(r_shapes) + tuple(rets_r1) + tuple(flat_final_other_states)