def BeamSearchDecode(self,
                     theta,
                     encoder_outputs,
                     num_hyps_per_beam_override=0,
                     init_beam_search_state=None,
                     pre_beam_search_step_callback=None,
                     post_beam_search_step_callback=None,
                     max_steps=None):
  """Runs beam search over the decoder and returns the top-k hypotheses.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to the
      callbacks. Its `padding` field is also read here to derive the source
      sequence lengths.
    num_hyps_per_beam_override: Values > 0 replace `p.num_hyps_per_beam`;
      values <= 0 are ignored.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer to
      the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    max_steps: Maximum number of beam search steps. When None,
      `self.params.target_seq_len` is used.

  Returns:
    A `BeamSearchDecodeOutput`.
  """
  p = self.params
  hyps_per_beam = (
      num_hyps_per_beam_override
      if num_hyps_per_beam_override > 0 else p.num_hyps_per_beam)
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, hyps_per_beam)

  total_hyps = tf.shape(initial_results.log_probs)[0]
  total_beams = total_hyps // hyps_per_beam

  # [num_hyps, 1]: ids fed to the first decoding step.
  if 'step_ids' in initial_results:
    first_step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    first_step_ids = tf.fill([total_hyps, 1],
                             tf.constant(p.target_sos_id, dtype=tf.int32))

  # Beam search state carried through the loop. The per-step buffers are laid
  # out as [max_steps, num_hyps, ...]; the tuple order is what
  # _BeamSearchStep expects.
  min_score = -1e36
  step_buffer_shape = [max_steps, total_hyps]
  core_bs_states = (
      # best_scores: [num_beams], initialized to a very low score.
      tf.zeros(shape=[total_beams], dtype=p.dtype) + min_score,
      # cumulative_scores: [num_hyps].
      tf.zeros(shape=[total_hyps], dtype=p.dtype),
      # in_scores.
      tf.zeros(step_buffer_shape, dtype=p.dtype),
      # in_hyps.
      tf.zeros(step_buffer_shape, dtype=tf.int32),
      # in_prev_hyps.
      tf.zeros(step_buffer_shape, dtype=tf.int32),
      # in_done_hyps.
      tf.zeros(step_buffer_shape, dtype=tf.string),
      # bs_atten_probs.
      tf.zeros(
          [max_steps, total_hyps,
           tf.shape(initial_results.atten_probs)[1]],
          dtype=p.dtype),
  )
  start_step = tf.constant(0, dtype=tf.int32)
  start_done = tf.constant(False, dtype=tf.bool)
  flat_other_states = other_states.Flatten()

  def _KeepSearching(step, done, unused_ids, unused_bs_states,
                     unused_flat_states):
    return tf.logical_and(step < max_steps, tf.logical_not(done))

  def _SearchOneStep(step, unused_done, ids, bs_states, flat_states):
    (next_step, next_done, next_ids, next_bs_states,
     next_other_states) = self._BeamSearchStep(
         theta, encoder_outputs, step, ids, bs_states,
         other_states.Pack(flat_states), hyps_per_beam,
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (next_step, next_done, next_ids, next_bs_states,
            next_other_states.Flatten())

  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      _KeepSearching,
      _SearchOneStep,
      loop_vars=(start_step, start_done, first_step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(start_step.get_shape()),
                        tf.TensorShape(start_done.get_shape()),
                        tf.TensorShape(first_step_ids.get_shape()),
                        _GetShapes(core_bs_states),
                        _GetShapes(flat_other_states, none_shapes=True)))

  # [target_seq_len, num_beams * num_hyps_per_beam]; index 5 is in_done_hyps.
  final_done_hyps = final_bs_states[5]
  final_other_states = other_states.Pack(flat_final_other_states)

  # TODO(rpang): avoid inspecting 'encoder_outputs'.
  source_paddings = encoder_outputs.padding
  if isinstance(source_paddings, py_utils.NestedMap):
    # Only the first flattened padding tensor is used.
    source_paddings = source_paddings.Flatten()[0]
  source_seq_lengths = tf.cast(
      tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
      tf.int32)

  # [num_beams, num_hyps_per_beam].
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      source_seq_lengths,
      k=hyps_per_beam,
      num_hyps_per_beam=hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)
  # A static maximum length is only available when max_steps is a Python
  # value; 0 is passed when it is a tensor.
  max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
  # [num_beams * num_hyps_per_beam, ...].
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))
  return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                topk_lens, topk_scores, None,
                                final_other_states)
def GreedySearchDecode(self,
                       theta,
                       encoder_outputs,
                       init_beam_search_state=None,
                       pre_beam_search_step_callback=None,
                       post_beam_search_step_callback=None,
                       max_steps=None):
  """Decodes greedily, i.e. beam search with one hypothesis per beam.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to the
      callbacks.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer to
      the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    max_steps: Maximum number of search steps. When None,
      `self.params.target_seq_len` is used.

  Returns:
    A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is the same as
    src_batch_size.
      - hyp_ids: [num_hyps, max_step]. Hyps end with the <eos> token if <eos>
        was encountered during search.
      - hyp_lens: [num_hyps].
      - done_hyps: [num_hyps], whether or not an eos was encountered.
  """
  p = self.params
  max_steps = p.target_seq_len if max_steps is None else max_steps

  # Greedy search asks the init callback for exactly one hypothesis per beam.
  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, 1)
  batch_hyps = tf.shape(initial_results.log_probs)[0]

  # [num_hyps, 1]
  if 'step_ids' in initial_results:
    start_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    start_ids = tf.fill([batch_hyps, 1],
                        tf.constant(p.target_sos_id, dtype=tf.int32))

  start_step = tf.constant(0, dtype=tf.int32)
  init_done = inplace_ops.empty(
      shape=[batch_hyps], dtype=tf.bool, init=True, name='done_hyps')
  init_lens = inplace_ops.empty(
      shape=[batch_hyps], dtype=tf.int32, init=True, name='hyp_lens')
  # [max_steps, num_hyps]; transposed before returning.
  init_ids = inplace_ops.empty(
      shape=[max_steps, batch_hyps],
      dtype=tf.int32,
      init=True,
      name='hyp_ids')
  flat_other_states = other_states.Flatten()

  def _NotFinished(step, unused_ids, unused_hyp_ids, unused_hyp_lens, done,
                   unused_flat_states):
    # Stop at max_steps or once every hypothesis is done.
    return tf.logical_and(step < max_steps,
                          tf.logical_not(tf.reduce_all(done)))

  def _GreedyStep(step, ids, hyp_ids, hyp_lens, done, flat_states):
    (step, next_ids, hyp_ids, hyp_lens, done,
     next_other_states) = self._GreedySearchStep(
         theta, encoder_outputs, step, ids, hyp_ids, hyp_lens, done,
         other_states.Pack(flat_states), pre_beam_search_step_callback,
         post_beam_search_step_callback)
    return (step, next_ids, hyp_ids, hyp_lens, done,
            next_other_states.Flatten())

  _, _, final_hyp_ids, final_hyp_lens, final_done_hyps, _ = tf.while_loop(
      _NotFinished,
      _GreedyStep,
      loop_vars=(start_step, start_ids, init_ids, init_lens, init_done,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(
          tf.TensorShape(start_step.get_shape()),
          tf.TensorShape(start_ids.get_shape()),
          tf.TensorShape(init_ids.get_shape()),
          tf.TensorShape(init_lens.get_shape()),
          tf.TensorShape(init_done.get_shape()),
          _GetShapes(flat_other_states, none_shapes=True)))

  # Transpose hyp_ids to [num_hyps, max_steps] to match BeamSearchDecode.
  return tf.transpose(final_hyp_ids), final_hyp_lens, final_done_hyps
def _updated_statistics(self, var, partitioned_grads):
  """Returns the ops that update Shampoo statistics L_t, R_t, etc.

  Args:
    var: tf.Variable associated with the gradient; the statistics matrices are
      read from its slots.
    partitioned_grads: Partitioned gradient tensors (a list).

  Returns:
    A list. When `_statistics_computation_frequency <= 1` it holds one
    `assign_add` result per statistics matrix. Otherwise it holds a single
    `tf.while_loop` output that performs every update iff
    `_run_statistics_computation` is true.
  """
  num_partitions = len(partitioned_grads)
  # Parallel lists: the statistics slot, the gradient partition it is computed
  # from, and the dimension it preconditions.
  mat_stats, mat_grads, mat_dims = [], [], []
  for pt_idx, pt_grad in enumerate(partitioned_grads):
    pt_shape = pt_grad.get_shape()
    has_preconditioner = self._preconditioner_available_for_dims(pt_shape)
    for dim in range(len(pt_shape)):
      if not has_preconditioner[dim]:
        continue
      stat_key = self._statistics_key_for_partition_and_dim(
          dim, pt_idx, num_partitions)
      mat_stats.append(self.get_slot(var, stat_key))
      mat_grads.append(pt_grad)
      mat_dims.append(dim)

  def _update_statistics(dim, stat_var, grad):
    """Accumulates tensordot(grad, grad) over all axes but `dim` into stat."""
    with tf.name_scope("GradientStatistics"):
      var_rank = len(grad.get_shape())
      # Contract over every axis except `dim`.
      axes = [axis for axis in range(var_rank) if axis != dim]
      new_stat = math_ops.tensordot(grad, grad, axes=(axes, axes))
      beta = self._second_moment_averaging
      if beta == 1.0:
        return state_ops.assign_add(stat_var, new_stat)
      # stat + (beta - 1) * stat + (1 - beta) * new
      #   == beta * stat + (1 - beta) * new  (exponential moving average).
      return state_ops.assign_add(
          stat_var, (beta - 1.0) * stat_var + (1.0 - beta) * new_stat)

  def _update_all():
    return [
        _update_statistics(dim, mat_stat, mat_grad)
        for mat_stat, mat_grad, dim in zip(mat_stats, mat_grads, mat_dims)
    ]

  if self._statistics_computation_frequency <= 1:
    return _update_all()

  # NOTE: We rewrite tf.cond() as a while loop to avoid certain overheads
  # in XLA from buffer allocation. The body returns False, so it runs at most
  # once: only when `_run_statistics_computation` is true.
  def _loop_body(unused_perform_step):
    with tf.control_dependencies(_update_all()):
      return tf.constant(False)

  return [
      tf.while_loop(lambda perform_step: perform_step, _loop_body,
                    [self._run_statistics_computation])
  ]
def _OutfeedDequeueLoop(self, per_example_tensors, num_loops, num_devices):
  """Dequeues all per-example tensor outfeed data for a TPU sess.run.

  Args:
    per_example_tensors: dict of key -> tensor as generated by TpuTrainStep.
    num_loops: number of times that TpuTrainStep will be executed by TpuTrain.
    num_devices: number of TPU cores assigned to this process.

  Returns:
    A dict mapping each key of `per_example_tensors` to the concatenation of
    everything dequeued for it over all loops and devices. Returns
    `tf.no_op()` instead when `per_example_tensors` is empty.
  """
  if not per_example_tensors:
    return tf.no_op()

  sorted_keys = sorted(per_example_tensors)
  tensor_shapes = [
      py_utils.GetShape(per_example_tensors[key]) for key in sorted_keys
  ]
  tensor_types = [
      tf.as_dtype(per_example_tensors[key].dtype) for key in sorted_keys
  ]

  def LoopBody(i, *input_arrays):
    """Dequeues the outfeed of a single TpuTrainStep into the arrays.

    Args:
      i: current loop index.
      *input_arrays: One tf.TensorArray per outfeed tensor.

    Returns:
      i+1 (new index) plus post-write tf.TensorArray handles.
    """
    # Outfeed ops execute on each JF node, so they must be located on the
    # nodes.
    dequeued = []
    device_assignment = py_utils.GetTpuDeviceAssignment()
    assert device_assignment
    for replica in range(device_assignment.num_replicas):
      for core in range(device_assignment.num_cores_per_replica):
        with tf.device(device_assignment.host_device(replica, core)):
          dequeued.append(
              tpu_ops.outfeed_dequeue_tuple(
                  tensor_types,
                  tensor_shapes,
                  device_ordinal=device_assignment.tpu_ordinal(replica,
                                                               core)))

    offset = i * num_devices
    # Each array holds a different per-example tensor. Slot offset + k gets
    # the value dequeued from the k-th device for this TpuTrainStep.
    output_arrays = []
    for j, array in enumerate(input_arrays):
      for k, device_outputs in enumerate(dequeued):
        array = array.write(offset + k, device_outputs[j])
      output_arrays.append(array)
    return tuple([i + 1] + output_arrays)

  def LoopCond(i, *unused_output_arrays):
    return i < num_loops

  initial_arrays = [
      tf.TensorArray(dtype, size=num_loops * num_devices, element_shape=shape)
      for dtype, shape in zip(tensor_types, tensor_shapes)
  ]
  # Loop once for each time that TpuTrainStep runs; drop the index result.
  final_arrays = tf.while_loop(
      LoopCond, LoopBody, [0] + initial_arrays, parallel_iterations=1)[1:]
  return dict(zip(sorted_keys, [array.concat() for array in final_arrays]))
def ComputePredictions(self, encoder_outputs, pronunciations,
                       is_inference=False):
  """Runs the decoder over every target step and computes the loss.

  Despite the name, this method does the bulk of the decoding as well as the
  loss computation.

  Args:
    encoder_outputs: a NestedMap produced by the FeatureNeighborhoodEncoder.
      This method reads its `state` field (the initial decoder state) and its
      `batch` field, and passes the whole map to `self.Decode` at each step.
    pronunciations: NestedMap whose `pronunciations` field is a
      [*, max_pronunciation_len] tensor of target ids.
    is_inference: If False, feeds the ground-truth target to the next step
      (teacher forcing). Otherwise feeds back the argmax prediction
      (autoregression).

  Returns:
    A NestedMap with:
      loss: mean of `per_sequence_losses`.
      per_sequence_losses: summed masked loss per sequence, divided by the
        padded target length.
      labels: argmax of the predictions, [*, max_pronunciation_len].
      attention: [*, max_pronunciation_len, max_spelling_len].
      neighbor_attention: [*, max_pronunciation_len, max_neighbors] when
        `p.use_neighbors` is set; otherwise the squeezed stacked weights.
      batch: the raw batch passed through from the encoder.
  """
  p = self.params
  targets = pronunciations.pronunciations
  t_len = int(targets.get_shape().as_list()[1])

  attention_ta = tf.TensorArray(dtype=tf.float32, size=t_len)
  neighbor_attention_ta = tf.TensorArray(dtype=tf.float32, size=t_len)
  outputs_ta = tf.TensorArray(dtype=tf.float32, size=t_len)
  # The first decoder input is the start symbol for every example.
  start_input = tf.convert_to_tensor([p.start] * p.input.batch_size)

  def _StepCond(step, *unused_loop_vars):
    return tf.less(step, t_len)

  def _StepBody(step, dec_input, attention_ta, neighbor_attention_ta, state,
                outputs_ta):
    """Decodes one step, records its outputs and picks the next input."""
    result = self.Decode(encoder_outputs, dec_input, state)
    outputs_ta = outputs_ta.write(step, result.predictions)
    attention_ta = attention_ta.write(step, result.attention_weights)
    neighbor_attention_ta = neighbor_attention_ta.write(
        step, tf.cast(result.neighbor_attention_weights, dtype=tf.float32))
    next_input = (
        tf.cast(tf.argmax(result.predictions, 1), tf.int32)
        if is_inference else targets[:, step])
    return (step + 1, next_input, attention_ta, neighbor_attention_ta,
            result.state, outputs_ta)

  (_, _, attention_ta, neighbor_attention_ta, _,
   outputs_ta) = tf.while_loop(
       _StepCond,
       _StepBody,
       loop_vars=[
           tf.constant(0), start_input, attention_ta, neighbor_attention_ta,
           encoder_outputs.state, outputs_ta
       ])

  # Stacked per-step predictions with the time axis moved to position 1.
  outputs = tf.transpose(outputs_ta.stack(), [1, 0, 2])
  labels = tf.argmax(outputs, axis=-1)
  # Positions whose target id is 0 get zero weight in the loss.
  mask = tf.cast(
      tf.math.logical_not(tf.math.equal(targets, 0)), dtype=tf.float32)
  summed_loss = tf.reduce_sum(
      self._loss_object(targets, outputs, sample_weight=mask), axis=1)
  # NOTE(review): this normalizes by the padded length t_len, not by the
  # number of unmasked positions.
  per_sequence_losses = summed_loss / t_len

  predictions = py_utils.NestedMap()
  predictions.loss = tf.reduce_mean(per_sequence_losses)
  predictions.per_sequence_losses = per_sequence_losses
  predictions.labels = labels
  # NOTE(review): tf.squeeze without an axis drops every size-1 dimension,
  # including a size-1 batch, which could break the rank-3 transpose.
  predictions.attention = tf.transpose(
      tf.squeeze(attention_ta.stack()), perm=[1, 0, 2])
  neighbor_attention = tf.squeeze(neighbor_attention_ta.stack())
  if p.use_neighbors:
    neighbor_attention = tf.transpose(neighbor_attention, perm=[1, 0, 2])
  predictions.neighbor_attention = neighbor_attention
  # Expose this for subsequent data analysis.
  predictions.batch = encoder_outputs.batch
  return predictions
def _EncodeToIds(self, word):
  """Splits `word` into characters and iteratively merges them into pieces.

  Terminology:
    * a token is a wordpiece id; the token array is merged in place;
    * `candidates[i]` is the id of the wordpiece formed by merging
      tokens[i] and tokens[i + 1], or NO_TOKEN when no such piece exists, so
      len(candidates) == len(tokens) - 1.

  Args:
    word: a string tensor, presumably scalar (it is split with
      tf.strings.unicode_split and the result is indexed as 1-D).

  Returns:
    A 1-D tensor of token ids.
  """
  # Split into UTF-8 characters; characters without a token become unk_id.
  chars = tf.strings.unicode_split(word, 'UTF-8')
  char_tokens = self._StringToToken(chars)
  is_unseen = tf.equal(char_tokens, NO_TOKEN)
  unk_tokens = tf.broadcast_to(self.unk_id, tf.shape(char_tokens))
  tokens = tf.where(is_unseen, unk_tokens, char_tokens)

  # Merge candidate for every adjacent pair of initial tokens.
  candidates = tf.map_fn(
      self._MergeTokens, (tokens[:-1], tokens[1:]), dtype=tokens.dtype)

  def _MergePair(toks, idx):
    """Returns the [1]-shaped merge candidate for toks[idx], toks[idx + 1]."""
    return tf.expand_dims(
        self._MergeTokens((toks[idx], toks[idx + 1])), axis=-1)

  def _ShouldMerge(unused_tokens, cands):
    """Merge until not possible, or abort early according to merge_prob."""
    any_mergeable = tf.reduce_any(tf.not_equal(cands, NO_TOKEN))
    keep_going = tf.random.uniform([]) < self._merge_prob
    return tf.logical_and(any_mergeable, keep_going)

  def _MergeBest(toks, cands):
    """Merges the pair with the smallest candidate id and refreshes cands."""
    # NOTE(review): argmin picks the smallest id, so this presumably relies on
    # NO_TOKEN comparing greater than every real wordpiece id; confirm.
    best = tf.argmin(cands, output_type=tf.int32)
    toks = tf.concat([toks[:best], [cands[best]], toks[best + 2:]], axis=0)

    # Only the candidates adjacent to the merged token change.
    no_candidates = tf.zeros([0], dtype=cands.dtype)
    left = tf.cond(
        tf.equal(best, 0), lambda: no_candidates,
        lambda: tf.concat([cands[:best - 1], _MergePair(toks, best - 1)],
                          axis=0))
    right = tf.cond(
        tf.greater_equal(best, tf.size(toks) - 1), lambda: no_candidates,
        lambda: tf.concat([_MergePair(toks, best), cands[best + 2:]], axis=0))
    return toks, tf.concat([left, right], axis=0)

  final_tokens, _ = tf.while_loop(
      _ShouldMerge,
      _MergeBest, (tokens, candidates),
      parallel_iterations=1,
      back_prop=False)
  return final_tokens
def _StringsToIdsImpl(self, strs, max_length, append_eos, languages):
  """Takes a tensor of strings and returns id/padding tensors.

  This generates `token_ids`, `target_ids`, and `paddings` in the format that
  is expected for tokenizers. This performs padding to a fixed length and
  appends the end-of-sentence token as appropriate.

  Args:
    strs: a string Tensor.
    max_length: a python integer. The second dimension of the returned arrays.
      All sequences are padded or truncated to that length.
    append_eos: a python bool. See `BaseTokenizer` for explanation. When None,
      `p.append_eos` is used.
    languages: A vector of strings with the same length as `strs`. Unused by
      this implementation.

  Returns:
    token_ids: a tensor of sequences of WPM ids starting with SOS. Sequences
      always end with EOS unless the sequence exceeds the maximum length.
      Always padded with EOS.
    target_ids: a tensor of sequences of WPM ids not starting with SOS
      but ending with EOS. Always padded with EOS.
    paddings: a tensor of floats indicating, at each position, whether
      the corresponding position of target_ids is padded.
    When `p.pad_to_max_length` is False, all three are trimmed to the length
    of the longest unpadded target sequence in the batch.
  """
  del languages  # Unused.
  p = self.params
  if append_eos is None:
    append_eos = p.append_eos
  batch_size = py_utils.GetShape(strs)[0]
  token_ids_ta = tf.TensorArray(tf.int32, batch_size)
  target_ids_ta = tf.TensorArray(tf.int32, batch_size)
  paddings_ta = tf.TensorArray(tf.float32, batch_size)

  def _TokenizeOneSentence(i, strs, token_ids_ta, target_ids_ta, paddings_ta):
    """Tokenizes a single sentence."""
    ids, _ = self._wpm_encoder.Encode(strs[i])

    if append_eos:
      ids = tf.concat([ids, [self.eos_id]], axis=0)

    # This truncates after the eos is added, so some sentences might
    # not have </s> at the end.
    token_ids_ta = token_ids_ta.write(
        i,
        py_utils.PadOrTrimTo(
            tf.concat([[self.sos_id], ids], axis=0), [max_length],
            self.eos_id))
    target_ids_ta = target_ids_ta.write(
        i, py_utils.PadOrTrimTo(ids, [max_length], self.eos_id))
    paddings_ta = paddings_ta.write(
        i,
        py_utils.PadOrTrimTo(
            tf.zeros_like(ids, dtype=tf.float32), [max_length], 1.))

    return i + 1, strs, token_ids_ta, target_ids_ta, paddings_ta

  _, _, token_ids_ta, target_ids_ta, paddings_ta = tf.while_loop(
      lambda i, *_: i < batch_size,
      _TokenizeOneSentence,
      loop_vars=(tf.constant(0, tf.int32), strs, token_ids_ta, target_ids_ta,
                 paddings_ta),
      parallel_iterations=30,
      back_prop=False)

  token_ids = token_ids_ta.stack()
  target_ids = target_ids_ta.stack()
  paddings = paddings_ta.stack()

  if not p.pad_to_max_length:
    # Trim to the longest unpadded sequence. tf.cast replaces the deprecated
    # tf.to_int32 (not in the TF2 API); rounding guards against float error in
    # the summed length, matching the sibling tokenizer implementation.
    maxlen = tf.cast(
        tf.round(tf.reduce_max(tf.reduce_sum(1.0 - paddings, axis=1))),
        tf.int32)
    token_ids = token_ids[:, :maxlen]
    target_ids = target_ids[:, :maxlen]
    paddings = paddings[:, :maxlen]

  return token_ids, target_ids, paddings
def FarthestPointSampler(points,
                         padding,
                         num_sampled_points,
                         precomputed_squared_distance=None,
                         num_seeded_points=0,
                         random_seed=None):
  """Samples num_sampled_points from points using farthest point sampling.

  Algorithm:
  1. Start by selecting a random point and adding to a selected set.
  2. For all remaining points, find the furthest point from those selected.
  3. Add furthest point to selected.
  4. Repeat 2-3 until num_sampled_points are selected.

  More details at https://en.wikipedia.org/wiki/Farthest-first_traversal

  The output of this function can be used with tf.batch_gather to extract the
  desired points, for example: tf.batch_gather(points, sampled_idx)

  Args:
    points: floating point tf.Tensor of shape [N, P1, dims]
    padding: A floating point tf.Tensor of shape [N, P1] with 0 if the point is
      real, and 1 otherwise.
    num_sampled_points: integer number of points to sample.
    precomputed_squared_distance: optional tf.Tensor of shape [N, P1, P1] of
      distances between each point. if None, distances will be computed on the
      fly.
    num_seeded_points: If num_seeded_points > 0, then the first
      num_seeded_points in points are considered to be seeded in the FPS
      sampling. Note that we assume that these points are *not* padded, and do
      not check padding when seeding them.
    random_seed: optional integer random seed to use with all the random ops.

  Returns:
    A tuple of tf.Tensors (sampled_idx, closest_idx) of types
    (tf.int32, tf.int32).

    sampled_idx is of shape [N, num_sampled_points] representing the indices
    selected using the sampler. Its values lie in [0, P1).

    closest_idx is of shape [N, P1]. For each input point it holds the
    position within sampled_idx (in [0, num_sampled_points)) of the closest
    sampled point. closest_idx is used in PCNN as part of the pooling
    operation: each point is assigned to the closest sampled point and a max
    is taken over them.
  """
  points = py_utils.HasRank(points, 3)
  batch_size, num_points, dims = py_utils.GetShape(points, 3)

  points = py_utils.with_dependencies(
      [py_utils.assert_greater_equal(num_points, num_sampled_points)], points)

  # Add a tiny bit of noise to the distance matrix or points so all
  # points are unique. This will also ensure true repeated points
  # like padded points are only selected after all valid points are selected.
  if precomputed_squared_distance is not None:
    precomputed_squared_distance = py_utils.HasShape(
        precomputed_squared_distance, [batch_size, num_points, num_points])
    precomputed_squared_distance += tf.random.uniform(
        (batch_size, num_points, 1),
        minval=1e-6,
        maxval=1e-5,
        dtype=tf.float32,
        seed=random_seed)
  else:
    points += tf.random.uniform((batch_size, num_points, dims),
                                minval=1e-6,
                                maxval=1e-5,
                                dtype=tf.float32,
                                seed=random_seed)

  # TensorArray to store the sampled indices in the loop.
  sampled_idx = tf.TensorArray(tf.int32, num_sampled_points)

  # Initialize distance_to_selected to inf for all points.
  distance_to_selected = float('inf') * tf.ones((batch_size, num_points))

  # For tracking the index to the closest selected point.
  closest_idx = tf.zeros((batch_size, num_points), dtype=tf.int32)

  # Current loop index counter.
  curr_idx = tf.constant(0, dtype=tf.int32)

  # Get number of valid points (1 is padded, so num_points - num_padded).
  num_valid_points = tf.cast(
      tf.cast(num_points, dtype=tf.float32) - tf.reduce_sum(padding, axis=1),
      dtype=tf.int32)

  def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
    """Loop body for farthest point sampler."""

    def _GetRandomRealPoint():
      """Select the first point.

      For the first point, we want any random real (non padded) point, so we
      create a random value per point, and then set all padded ones to some
      large value (more than the maxval). We then take the min per batch
      element to get the first points.

      Returns:
        Tensor containing the index of a random point selected for each example
        in the batch.
      """
      random_values = tf.random.uniform((batch_size, num_points),
                                        minval=0,
                                        maxval=1,
                                        dtype=tf.float32,
                                        seed=random_seed)
      random_values = tf.where(
          tf.equal(padding, 0.0), random_values, padding * 10)
      return tf.argmin(random_values, axis=1, output_type=tf.int32)

    def _GetFurthestPoint():
      """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for
        each example in the batch.
      """
      # Set padded points distance to negative so they aren't selected.
      padding_masked_distance_to_selected = tf.where(
          tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
              (batch_size, num_points), dtype=tf.float32))
      # But only do this when we still have valid points left. The per-example
      # [N] condition is broadcast explicitly to [N, P1]. A rank-1 condition
      # selects rows under tf.compat.v1.where, but TF2's tf.where broadcasts
      # it against the trailing (P1) axis, which fails or misaligns.
      has_valid_points_left = tf.broadcast_to(
          tf.expand_dims(tf.less(curr_idx, num_valid_points), axis=-1),
          tf.shape(distance_to_selected))
      padding_masked_distance_to_selected = tf.where(
          has_valid_points_left, padding_masked_distance_to_selected,
          distance_to_selected)
      return tf.argmax(
          padding_masked_distance_to_selected, axis=-1, output_type=tf.int32)

    def _GetSeededPoint():
      """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for each
        example in the batch.
      """
      return tf.ones((batch_size,), dtype=tf.int32) * curr_idx

    # Select indices for this loop iteration.
    def _Seeded():
      return tf.cond(
          tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
          _GetFurthestPoint)

    def _Real():
      return tf.cond(
          tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)

    new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded, _Real)
    sampled_idx = sampled_idx.write(curr_idx, new_selected)

    # Extract the distance to the latest point selected to update
    # distance_to_selected.
    new_selected_gather_idx = tf.stack([tf.range(batch_size), new_selected],
                                       axis=1)
    if precomputed_squared_distance is not None:
      new_distance = tf.gather_nd(precomputed_squared_distance,
                                  new_selected_gather_idx)
    else:
      new_points = tf.reshape(
          tf.gather_nd(points, new_selected_gather_idx), [batch_size, 1, dims])
      new_distance = tf.reshape(
          SquaredDistanceMatrix(points, new_points), [batch_size, num_points])

    is_newly_closest = tf.less(new_distance, distance_to_selected)
    distance_to_selected = tf.minimum(distance_to_selected, new_distance)

    # Track the index (loop position) of the closest selected point.
    new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
    closest_idx = tf.cond(
        tf.equal(curr_idx, 0),
        # At the first loop iteration, the init points are the closest.
        lambda: new_selected_tiled,
        # Otherwise, update with the new points based on the distances.
        lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx))
    return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx

  _, _, sampled_idx, closest_idx = tf.while_loop(
      lambda curr_idx, *args: tf.less(curr_idx, num_sampled_points),
      _BodyFn,
      loop_vars=(curr_idx, distance_to_selected, sampled_idx, closest_idx),
      back_prop=False,
      maximum_iterations=num_sampled_points)

  sampled_idx = sampled_idx.stack()  # num_sampled_points x n
  sampled_idx = tf.transpose(sampled_idx, [1, 0])

  if isinstance(batch_size, int) and isinstance(num_sampled_points, int):
    sampled_idx.set_shape((batch_size, num_sampled_points))

  return sampled_idx, closest_idx
def _StringsToIdsImpl(self, strs, max_length, append_eos, languages):
  """Tokenizes a batch of strings into fixed-length id/padding tensors.

  Args:
    strs: a string Tensor whose first dimension is the batch (presumably 1-D).
    max_length: length every row is padded or trimmed to.
    append_eos: whether to append `self.eos_id` to each row; None means
      `p.append_eos`.
    languages: ignored.

  Returns:
    A tuple (token_ids, target_ids, paddings):
      token_ids: int32 rows that start with `self.sos_id`, padded with 0.
      target_ids: int32 rows, padded with 0. They start with `self.sos_id`
        only when `p.prepend_sos` is set.
      paddings: float32 rows, 1.0 at padded positions of target_ids.
    When `p.pad_to_max_length` is False, all three are trimmed to the longest
    unpadded target row in the batch.
  """
  del languages  # Unused.
  p = self.params
  should_append_eos = p.append_eos if append_eos is None else append_eos
  batch_size = py_utils.GetShape(strs)[0]

  def _PadOrTrim(ids, pad_value):
    return py_utils.PadOrTrimTo(ids, [max_length], pad_value)

  def _TokenizeOneSentence(i, text, token_ids_ta, target_ids_ta, paddings_ta):
    """Tokenizes row `i` of `text` and writes it into the TensorArrays."""
    text_i = tf.gather(text, i) if tf.is_tensor(i) else text[i]
    ids = self._tokenizer.tokenize(text_i).merge_dims(0, -1)
    ids.set_shape([None])
    if should_append_eos:
      ids = tf.concat([ids, [self.eos_id]], axis=0)
    sos_ids = tf.concat([[self.sos_id], ids], axis=0)
    target_ids = sos_ids if p.prepend_sos else ids
    # Trimming happens after EOS is appended, so rows longer than max_length
    # lose their EOS.
    return (i + 1, text,
            token_ids_ta.write(i, _PadOrTrim(sos_ids, 0)),
            target_ids_ta.write(i, _PadOrTrim(target_ids, 0)),
            paddings_ta.write(
                i,
                _PadOrTrim(tf.zeros_like(target_ids, dtype=tf.float32), 1.)))

  _, _, token_ids_ta, target_ids_ta, paddings_ta = tf.while_loop(
      lambda i, *_: i < batch_size,
      _TokenizeOneSentence,
      loop_vars=(tf.constant(0, tf.int32), strs,
                 tf.TensorArray(tf.int32, batch_size),
                 tf.TensorArray(tf.int32, batch_size),
                 tf.TensorArray(tf.float32, batch_size)),
      parallel_iterations=30,
      back_prop=False)

  token_ids, target_ids, paddings = (token_ids_ta.stack(),
                                     target_ids_ta.stack(),
                                     paddings_ta.stack())
  if not p.pad_to_max_length:
    # Trim every output to the longest unpadded row in the batch.
    maxlen = tf.cast(
        tf.round(tf.reduce_max(tf.reduce_sum(1.0 - paddings, axis=1))),
        tf.int32)
    token_ids, target_ids, paddings = (token_ids[:, :maxlen],
                                       target_ids[:, :maxlen],
                                       paddings[:, :maxlen])
  return token_ids, target_ids, paddings