def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses
    per beam.
  """
  source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
  value_dict = {}
  for output in beam_search_outputs:
    hyps_per_beam = py_utils.with_dependencies([
        py_utils.assert_equal(source_batch,
                              tf.shape(output.topk_hyps)[0]),
    ], tf.shape(output.topk_hyps)[1])
    for k, v in six.iteritems(output._asdict()):
      if v is None:
        continue
      if k == 'done_hyps':
        v = tf.transpose(v)
      if k not in value_dict:
        value_dict[k] = []
      value_dict[k].append(tf.reshape(v, [source_batch, hyps_per_beam, -1]))

  # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
  concatenated = {}
  for k, values in six.iteritems(value_dict):
    if len(values) != len(beam_search_outputs):
      raise ValueError('Incomplete values for %s: %s' %
                       (k, beam_search_outputs))
    concatenated[k] = tf.concat(values, axis=1)

  scores = concatenated['topk_scores']
  scores = tf.where(
      tf.equal(concatenated['topk_lens'], 0), tf.fill(tf.shape(scores), -1e6),
      scores)
  scores = tf.squeeze(scores, -1)

  # Select top max_hyps_per_beam indices per beam.
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

  # Gather the merged top hyps according to 'gather_indices'.
  top = beam_search_outputs[0]._asdict()
  total_hyps = source_batch * max_hyps_per_beam
  for k, v in six.iteritems(concatenated):
    v = tf.gather_nd(v, gather_indices)
    if k == 'done_hyps':
      v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
    elif k == 'topk_hyps':
      v = tf.reshape(v, [source_batch, max_hyps_per_beam])
    elif k == 'topk_ids':
      v = tf.reshape(v, [total_hyps, -1])
    elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
      v = tf.reshape(v, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % k)
    top[k] = v
  return BeamSearchDecodeOutput(**top)
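
# Illustrative sketch (not part of the library): the per-beam selection above
# boils down to a top_k over scores followed by a batched gather_nd. The helper
# name and the toy tensors below are made up; `tf` is the module-level
# TensorFlow import.
def _SketchPerBeamTopKGather():
  """Selects the top-2 scoring hyps per beam from a [beam, hyp] score matrix."""
  scores = tf.constant([[0.1, 0.9, 0.5],
                        [0.7, 0.2, 0.3]])  # [source_batch=2, num_hyps=3]
  hyp_ids = tf.constant([[10, 11, 12],
                         [20, 21, 22]])    # per-hyp payload to gather
  max_hyps_per_beam = 2
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(tf.shape(scores)[0]), -1),
      [1, max_hyps_per_beam])
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)
  # [[11, 12], [20, 22]]: the two best hyps of each beam.
  return tf.gather_nd(hyp_ids, gather_indices)
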
def add_point_cloud(self, feature, laser_names, range_image_pose):
  """Convert the range images in `feature` to 3D point clouds.

  Adds the point cloud data to the tf.Example feature map.

  Args:
    feature: A tf.Example feature map.
    laser_names: A list of laser names (e.g., 'TOP', 'REAR', 'SIDE_LEFT').
    range_image_pose: A range image pose Tensor for the top laser.
  """
  # Stash metadata for the lasers. This metadata can be useful for
  # reconstructing the range image.
  self.laser_info = {}

  for laser_name in laser_names:
    beam_inclinations = np.array(
        feature['%s_beam_inclinations' % laser_name].float_list.value[:])
    # beam_inclinations will be populated if there is a non-uniform
    # beam configuration (e.g., for the TOP lasers). Others that have
    # uniform beam inclinations are only parameterized by the min and max.
    # We use these min and max if the beam_inclinations are not present,
    # and turn them into a uniform inclinations array.
    if beam_inclinations.size == 0:
      beam_inclination_min = feature['%s_beam_inclination_min' %
                                     laser_name].float_list.value[:]
      beam_inclination_max = feature['%s_beam_inclination_max' %
                                     laser_name].float_list.value[:]

      laser_ri_name = '%s_ri1' % laser_name
      range_image_shape = feature[laser_ri_name + '_shape'].int64_list.value[:]
      height = tf.cast(range_image_shape[0], tf.float32)

      beam_inclinations = tf.constant(
          [beam_inclination_min[0], beam_inclination_max[0]])
      beam_inclinations = range_image_utils.compute_inclination(
          beam_inclinations, height)

    beam_extrinsics = np.array(
        feature['%s_extrinsics' % laser_name].float_list.value[:]).reshape(
            4, 4)

    for ri_type in ['ri1', 'ri2']:
      laser_ri_name = '%s_%s' % (laser_name, ri_type)
      # For each of the 4 features of the lasers:
      range_image = np.array(feature[laser_ri_name].float_list.value[:])
      range_image_shape = feature[laser_ri_name + '_shape'].int64_list.value[:]
      range_image = range_image.reshape(range_image_shape)
      # Compute mask. At the moment, invalid values in the range image
      # representation are indicated via a -1. entry. Callers are expected
      # to create this mask when passing into the conversion function below.
      range_image_mask = range_image[..., 0] >= 0

      # Get the 'range' feature from the range images.
      range_image_range = range_image[..., 0]

      # Call utility to convert point cloud to cartesian coordinates.
      #
      # API expects a batch dimension for all inputs.
      batched_pixel_pose = None
      batched_frame_pose = None
      # At the moment, only the TOP laser has per-pixel pose.
      if laser_name == 'TOP' and range_image_pose is not None:
        batched_pixel_pose = range_image_pose[tf.newaxis, ...]
        batched_frame_pose = self.frame_pose[tf.newaxis, ...]

      batched_range_image_range = tf.convert_to_tensor(
          range_image_range[np.newaxis, ...], dtype=tf.float32)
      batched_extrinsics = tf.convert_to_tensor(
          beam_extrinsics[np.newaxis, ...], dtype=tf.float32)
      batched_inclinations = tf.convert_to_tensor(
          beam_inclinations[np.newaxis, ...], dtype=tf.float32)

      batched_inclinations = tf.reverse(batched_inclinations, axis=[-1])

      range_image_cartesian = (
          range_image_utils.extract_point_cloud_from_range_image(
              batched_range_image_range,
              batched_extrinsics,
              batched_inclinations,
              pixel_pose=batched_pixel_pose,
              frame_pose=batched_frame_pose))

      info = py_utils.NestedMap()
      self.laser_info[laser_ri_name] = info
      info.range_image = range_image
      info.range_image_shape = range_image_shape

      ri_indices = tf.where(range_image_mask)
      points_xyz = tf.gather_nd(range_image_cartesian[0], ri_indices)
      info.num_points = tf.shape(points_xyz).numpy()[0]

      # Fetch the features corresponding to each xyz coordinate and
      # concatenate them together.
      points_features = tf.cast(
          tf.gather_nd(range_image[..., 1:], ri_indices), tf.float32)
      if self._use_range_image_index_as_lidar_feature:
        points_data = tf.concat([
            points_xyz,
            tf.cast(ri_indices, tf.float32),
            points_features[..., 2:]
        ], axis=-1)
      else:
        points_data = tf.concat([points_xyz, points_features], axis=-1)

      # Add laser feature to output.
      #
      # Skip embedding shape since we assume that all points have six features
      # and so we can reconstruct the number of points.
      points_list = list(points_data.numpy().reshape([-1]))
      feature['laser_%s' % laser_ri_name].float_list.value[:] = points_list

      laser_ri_flow_name = '%s_flow' % laser_ri_name
      if laser_ri_flow_name in feature:
        range_image_flow = np.array(
            feature[laser_ri_flow_name].float_list.value[:])
        range_image_flow_shape = feature[laser_ri_flow_name +
                                         '_shape'].int64_list.value[:]
        range_image_flow = range_image_flow.reshape(range_image_flow_shape)
        flow_data = tf.cast(
            tf.gather_nd(range_image_flow, ri_indices), tf.float32)
        flow_list = list(flow_data.numpy().reshape([-1]))
        feature['laser_%s' % laser_ri_flow_name].float_list.value[:] = flow_list
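
# Illustrative sketch (not part of the library): how the mask/gather step above
# pulls valid points out of a range image. The helper name and the tiny 2x3
# "range image" are made up; -1 marks invalid returns, mirroring the convention
# documented above. `tf` is the module-level TensorFlow import.
def _SketchRangeImageGather():
  """Extracts [N, 3] cartesian points at the valid range-image pixels."""
  range_image_range = tf.constant([[5.0, -1.0, 2.0],
                                   [-1.0, 3.0, 4.0]])      # [H=2, W=3]
  range_image_cartesian = tf.reshape(
      tf.range(2 * 3 * 3, dtype=tf.float32), [2, 3, 3])    # [H, W, 3]
  range_image_mask = range_image_range >= 0
  ri_indices = tf.where(range_image_mask)                  # [N, 2] pixel coords
  points_xyz = tf.gather_nd(range_image_cartesian, ri_indices)
  return points_xyz                                        # [N=4, 3]
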
def ComputeSparseAttention(q, k, v, sparsity_indices, paddings=None):
  """Computes attention according to a sparsity pattern.

  We use the following capital letters to denote shape parameters:
    B = batch size
    S = length of the source sequence
    T = length of the target sequence
    N = number of attention heads
    H = dimensions of each attention head
    K = number of clusters
    W = attention window (K <= S)

  The 'sparsity_indices' is a tensor of integral type where the last dimension
  contains W indices (W is the attention window) for each corresponding
  position along S in 'k' that the query is allowed to attend to.

  For example, if sparsity_indices[batch_idx, target time step, head_idx] =
  [1, 7, 8], it means that the token in the query attends to values with
  indices 1, 7, and 8, and the attention window here is 3.

  The valid values in 'sparsity_indices' are [-1, S-1]. Note that the value -1
  is reserved to mean paddings, distinct from the value (S-1).

  For example, if W=S and 'sparsity_indices' contains range(S) on the last
  dimension, this degenerates to the original full attention.

  We require that 'sparsity_indices' does not contain duplicates (except for
  -1 to indicate paddings), but we do not require 'sparsity_indices' to be
  sorted.

  Note that this implementation is flexible and generic but is not optimized
  for time or space complexity. Please consider grouping queries that attend
  to the same subset of values first for efficiency.

  Args:
    q: (projected) queries, [B, T, N, H];
    k: (projected) keys, [B, S, N, H];
    v: (projected) values, [B, S, N, H];
    sparsity_indices: [B, T, N, W], where W is the attention window;
    paddings: paddings for keys, [B, S] if not None.

  Returns:
    output: the encoded output, [B, T, N, H].
    atten_probs: the attention weights, [B, T, N, S].
  """
  q = tf.convert_to_tensor(q)
  k = tf.convert_to_tensor(k)
  v = tf.convert_to_tensor(v)
  sparsity_indices = tf.convert_to_tensor(sparsity_indices)

  k = py_utils.HasRank(k, 4)
  _, source_length, _, dim_per_head = py_utils.GetShape(k, 4)
  sparsity_indices = py_utils.HasRank(sparsity_indices, 4)
  batch_size, target_length, num_heads, attention_window = py_utils.GetShape(
      sparsity_indices, 4)
  py_utils.assert_less_equal(
      attention_window, source_length,
      'The provided sparsity_indices has attention window '
      ' > source length. This is likely an error.')

  # To prepare for gathering the relevant vectors from 'k', we prepare
  # gather_idx of shape [B, T, N, W, 3] where the last dimension corresponds
  # to slices in 'k' indexed by (batch index, source time step, head index),
  # where the source length index comes from the original W dimension in
  # 'sparsity_indices'.
  seq_idx = tf.expand_dims(sparsity_indices, axis=-1)
  # Overwrite the paddings -1 with valid gather indices (zeros). We will
  # fix the logits with -inf in these positions later.
  seq_idx = tf.where(seq_idx < 0, tf.zeros_like(seq_idx), seq_idx)
  batch_idx = tf.reshape(
      tf.range(0, batch_size, dtype=sparsity_indices.dtype),
      [batch_size, 1, 1, 1, 1])
  batch_idx = tf.tile(batch_idx,
                      [1, target_length, num_heads, attention_window, 1])
  head_idx = tf.reshape(
      tf.range(0, num_heads, dtype=sparsity_indices.dtype),
      [1, 1, num_heads, 1, 1])
  head_idx = tf.tile(head_idx,
                     [batch_size, target_length, 1, attention_window, 1])
  # [B, T, N, W, 3], where last dimension is (batch index, source length
  # index, head index).
  gather_idx = tf.concat([batch_idx, seq_idx, head_idx], axis=-1)

  # Both the gathered k and v have shape [B, T, N, W, H]
  k = tf.gather_nd(k, gather_idx)
  v = tf.gather_nd(v, gather_idx)

  if paddings is None:
    paddings = tf.zeros([batch_size, source_length])
  paddings = tf.convert_to_tensor(paddings)
  paddings = tf.expand_dims(paddings, axis=-1)
  # [B, S, N]
  paddings = tf.tile(paddings, [1, 1, num_heads])
  # [B, T, N, W]
  paddings = tf.gather_nd(paddings, gather_idx)

  logits = tf.einsum('BTNH, BTNWH -> BTNW', q, k)
  logits *= tf.math.rsqrt(tf.cast(dim_per_head, q.dtype))

  very_negative_logits = (
      tf.ones_like(logits) * logits.dtype.max *
      tf.constant(-0.7, dtype=logits.dtype))
  padded_logits = tf.where(
      tf.math.logical_or(sparsity_indices < 0, paddings > 0.0),
      very_negative_logits, logits)

  # [B, T, N, W]
  atten_probs = tf.nn.softmax(padded_logits, name='attention_weights')
  atten_probs = tf.where(sparsity_indices < 0, tf.zeros_like(logits),
                         atten_probs)
  output = tf.einsum('BTNW, BTNWH -> BTNH', atten_probs, v)

  # Scatter 'atten_probs' back into the original source length.
  # [B, T, N, W, 1]
  batch_idx = tf.tile(
      tf.range(batch_size)[:, None, None, None, None],
      [1, target_length, num_heads, attention_window, 1])
  # [B, T, N, W, 1]
  target_seq_idx = tf.tile(
      tf.range(target_length)[None, :, None, None, None],
      [batch_size, 1, num_heads, attention_window, 1])
  # [B, T, N, W, 1]
  head_idx = tf.tile(
      tf.range(num_heads)[None, None, :, None, None],
      [batch_size, target_length, 1, attention_window, 1])
  # seq_idx: [B, T, N, W, 1]
  # [B, T, N, W, 4]
  scatter_idx = tf.concat([batch_idx, target_seq_idx, head_idx, seq_idx], -1)
  # [B, T, N, S]
  scattered_probs = tf.scatter_nd(
      scatter_idx, atten_probs,
      [batch_size, target_length, num_heads, source_length])
  return output, scattered_probs
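
# Illustrative sketch (not part of the library): the core indexing trick used
# above, gathering a per-query window of keys with a [B, T, N, W, 3] index
# tensor. The helper name and shapes are made up; `tf` is the module-level
# TensorFlow import.
def _SketchWindowedKeyGather():
  """Gathers keys of shape [B, T, N, W, H] from keys of shape [B, S, N, H]."""
  b, s, t, n, w, h = 2, 5, 3, 2, 2, 4
  k = tf.random.normal([b, s, n, h])
  # Each (batch, target, head) position attends to `w` source positions.
  sparsity_indices = tf.random.uniform([b, t, n, w], 0, s, dtype=tf.int32)

  seq_idx = tf.expand_dims(sparsity_indices, axis=-1)              # [B,T,N,W,1]
  batch_idx = tf.tile(
      tf.reshape(tf.range(b), [b, 1, 1, 1, 1]), [1, t, n, w, 1])
  head_idx = tf.tile(
      tf.reshape(tf.range(n), [1, 1, n, 1, 1]), [b, t, 1, w, 1])
  gather_idx = tf.concat([batch_idx, seq_idx, head_idx], axis=-1)  # [B,T,N,W,3]
  return tf.gather_nd(k, gather_idx)                               # [B,T,N,W,H]
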
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch."""
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
  tf.logging.info('num_devices_per_split {}'.format(
      cluster.num_devices_per_split))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal(
        'Doesn\'t support per host infeed mode when '
        'num_devices_per_split({}) > num_cores_per_host({})'.format(
            cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = (tpu_function.get_tpu_context().number_of_shards //
              num_infeed_hosts)
    input_ops_list = []
    queues = []
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (
        tpu_embedding_collection[0] if tpu_embedding_collection else None)

    tpu_emb_input_keys = (
        list(tpu_embedding.feature_to_config_dict.keys())
        if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if isinstance(batch, py_utils.NestedMap):
          # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
          # Note that when MultiTaskData is used, bucket_keys will be at the
          # second level of the dictionary.
          batch = batch.FilterKeyVal(
              lambda k, _: not k.endswith('bucket_keys'))
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)
        q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
        queues.append(q)
        assert shards is not None
        q.set_number_of_shards(shards)

        if p.use_per_host_infeed:

          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops
        tf.logging.info('input_ops_list %s', input_ops_list)

    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallelism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    # For executor-driven multiple programs, we need more fine-grained
    # access rather than using a single global graph collection.
    self.tpu_infeed_op = tpu_infeed_op

  with tf.device(tf.tpu.core(0)):
    tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
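
# Illustrative sketch (not part of the library): the dense-to-sparse conversion
# used above before enqueueing TPU embedding ids, where -1 is assumed to be the
# padding id. The helper name and the toy `split` tensor are made up; `tf` is
# the module-level TensorFlow import.
def _SketchDenseToSparseIds():
  """Turns a padded [batch, ids_per_row] id tensor into (indices, values)."""
  split = tf.constant([[3, 7, -1],
                       [-1, -1, 5]])                      # -1 marks padding.
  sample_indices = tf.where(tf.not_equal(split, -1))      # [nnz, 2] coordinates
  embedding_indices = tf.gather_nd(split, sample_indices) # [nnz] ids: [3, 7, 5]
  return sample_indices, embedding_indices
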
def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
  """Loop body for farthest point sampler."""

  def _GetRandomRealPoint():
    """Select the first point.

    For the first point, we want any random real (non padded) point, so we
    create a random value per point, and then set all padded ones to some
    large value (more than the maxval). We then take the min per batch
    element to get the first points.

    Returns:
      Tensor containing the index of a random point selected for each example
      in the batch.
    """
    random_values = tf.random.uniform((batch_size, num_points),
                                      minval=0,
                                      maxval=1,
                                      dtype=tf.float32,
                                      seed=random_seed)
    random_values = tf.where(
        tf.equal(padding, 0.0), random_values, padding * 10)
    return tf.argmin(random_values, axis=1, output_type=tf.int32)

  def _GetFurthestPoint():
    """Get point that is furthest from those already selected.

    We also bias the sampling towards real points by setting the distance
    to padded points negative until we are out of real points.

    Returns:
      Tensor containing the index of the next farthest point selected for
      each example in the batch.
    """
    # Set padded points distance to negative so they aren't selected.
    padding_masked_distance_to_selected = tf.where(
        tf.equal(padding, 0.0), distance_to_selected,
        -1.0 * tf.ones((batch_size, num_points), dtype=tf.float32))
    # But only do this when we still have valid points left.
    padding_masked_distance_to_selected = tf.where(
        tf.less(curr_idx, num_valid_points),
        padding_masked_distance_to_selected, distance_to_selected)
    return tf.argmax(
        padding_masked_distance_to_selected, axis=-1, output_type=tf.int32)

  def _GetSeededPoint():
    """Select a seeded point.

    Seeded points are assumed to be at the beginning of the original points.

    Returns:
      Tensor containing the index of the next seeded point to select for each
      example in the batch.
    """
    return tf.ones((batch_size,), dtype=tf.int32) * curr_idx

  # Select indices for this loop iteration.
  def _Seeded():
    return tf.cond(
        tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
        _GetFurthestPoint)

  def _Real():
    return tf.cond(
        tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)

  new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded, _Real)
  sampled_idx = sampled_idx.write(curr_idx, new_selected)

  # Extract the distance to the latest point selected to update
  # distance_to_selected.
  new_selected_gather_idx = tf.stack([tf.range(batch_size), new_selected],
                                     axis=1)
  if precomputed_squared_distance is not None:
    new_distance = tf.gather_nd(precomputed_squared_distance,
                                new_selected_gather_idx)
  else:
    new_points = tf.reshape(
        tf.gather_nd(points, new_selected_gather_idx), [batch_size, 1, dims])
    new_distance = tf.reshape(
        SquaredDistanceMatrix(points, new_points), [batch_size, num_points])

  is_newly_closest = tf.less(new_distance, distance_to_selected)
  distance_to_selected = tf.minimum(distance_to_selected, new_distance)

  # Track the index to the closest selected point.
  new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
  closest_idx = tf.cond(
      tf.equal(curr_idx, 0),
      # At the first loop iteration, the init points are the closest.
      lambda: new_selected_tiled,
      # Otherwise, update with the new points based on the distances.
      lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx))
  return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx
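
# Illustrative sketch (not part of the library): the recurrence the loop body
# above maintains, shown for a single un-padded, un-seeded example in NumPy
# (`np` is assumed to be the NumPy import used elsewhere in this code). Each
# step picks the point farthest from the selected set and refreshes
# distance_to_selected with an elementwise minimum.
def _SketchFarthestPointSampling(points, num_samples):
  """Greedy FPS over points of shape [num_points, dims]; returns indices."""
  sampled_idx = [0]  # Start from an arbitrary (here: the first) point.
  distance_to_selected = np.sum((points - points[0])**2, axis=-1)
  for _ in range(1, num_samples):
    new_selected = int(np.argmax(distance_to_selected))
    sampled_idx.append(new_selected)
    new_distance = np.sum((points - points[new_selected])**2, axis=-1)
    distance_to_selected = np.minimum(distance_to_selected, new_distance)
  return sampled_idx
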
def FProp(self, theta, input_batch):
  # pyformat: disable
  """Compute features for the pillars and convert them back to a dense grid.

  Args:
    theta: A `.NestedMap` object containing variable values of this task.
    input_batch: A `.NestedMap` object containing input tensors. Following
      keys are required:

      - grid_num_points: Integer tensor with shape [batch size, nx, ny, nz],
        where nx, ny, nz corresponds to the grid sizes (i.e., number of voxels
        in each axis dimension).
      - pillar_points: Float tensor with shape [batch size, num_pillars,
        num_points_per_pillar, 3 + num_laser_features]
      - pillar_centers: Float tensor with shape [batch size, num_pillars,
        num_points_per_pillar, 3]
      - pillar_locations: Float tensor with shape [batch size, num_pillars, 3]

  Returns:
    The dense features with shape [b, nx, ny, nz * fdims].
  """
  # pyformat: enable
  p = self.params
  bs, nx, ny, nz = py_utils.GetShape(input_batch.grid_num_points, 4)

  # Process points to concatenate a set of fixed features (e.g.,
  # add means, centers, normalize points to means).
  num_features = 3 + p.num_laser_features
  pillar_points = py_utils.HasShape(input_batch.pillar_points,
                                    [bs, -1, -1, num_features])
  _, npillars, npoints, _ = py_utils.GetShape(pillar_points, 4)
  pillar_xyz = pillar_points[..., :3]

  # Compute number of points per pillar and prepare for broadcasting.
  pillar_num_points = tf.gather_nd(
      input_batch.grid_num_points, input_batch.pillar_locations, batch_dims=1)
  pillar_num_points = pillar_num_points[..., tf.newaxis, tf.newaxis]

  # Compute mean by computing sum and dividing by number of points. Clip the
  # denominator by 1.0 to gracefully handle empty pillars.
  pillar_sum = tf.reduce_sum(pillar_xyz, axis=2, keepdims=True)
  pillar_means = pillar_sum / tf.maximum(
      tf.cast(pillar_num_points, tf.float32), 1.0)

  pillar_feats = pillar_points[..., 3:]
  pillar_centers = py_utils.HasShape(input_batch.pillar_centers,
                                     [bs, -1, 1, 3])
  pillar_concat = tf.concat(
      axis=3,
      values=[
          pillar_xyz - pillar_means, pillar_feats,
          tf.tile(pillar_means, [1, 1, npoints, 1]),
          tf.tile(pillar_centers, [1, 1, npoints, 1])
      ])
  # Featurize pillars.
  pillar_features = self.featurizer.FProp(theta.featurizer, pillar_concat)

  # Convert back to the dense grid.
  pillar_locations = py_utils.HasShape(input_batch.pillar_locations,
                                       [bs, npillars, 3])
  dense_features = SparseToDense(
      grid_shape=(nx, ny, nz),
      locations=pillar_locations,
      feats=pillar_features)
  return dense_features
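
# Illustrative sketch (not part of the library): a scatter-back step in the
# spirit of the SparseToDense call above, written directly with tf.scatter_nd.
# SparseToDense itself is defined elsewhere and may differ (e.g., it may also
# reshape to [b, nx, ny, nz * fdims]); this helper, its name, and its signature
# are assumptions for illustration only. `tf` is the module-level import.
def _SketchScatterPillarsToGrid(pillar_locations, pillar_features, grid_shape):
  """Scatters [b, p, fdims] features to a dense [b, nx, ny, nz, fdims] grid.

  Assumes `pillar_locations` holds integer (x, y, z) grid coordinates and that
  static shapes are known.
  """
  b, p, _ = pillar_locations.shape.as_list()
  fdims = pillar_features.shape.as_list()[-1]
  nx, ny, nz = grid_shape
  batch_idx = tf.tile(tf.reshape(tf.range(b), [b, 1, 1]), [1, p, 1])
  scatter_idx = tf.concat(
      [batch_idx, tf.cast(pillar_locations, tf.int32)], axis=-1)  # [b, p, 4]
  return tf.scatter_nd(scatter_idx, pillar_features, [b, nx, ny, nz, fdims])
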
def GatherK(selected_pos, values, k, num_devices=1):
  """Gather up to k elements from given tensors at selected pos under SPMD.

  Example::

    # Input
    k = 3

    selected_pos = [
        [0, 0, 1, 1],
        [0, 1, 1, 0],
        [0, 0, 0, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1],  # topk(k=3) largest indices are selected in this row.
    ]

    value_2d = [
        [1, 3, 5, 7],
        [9, 11, 13, 15],
        [17, 19, 21, 23],
        [25, 27, 29, 31],
        [33, 35, 37, 39],
    ]

    # Output:
    output = [
        [0, 5, 7],
        [0, 11, 13],
        [0, 0, 0],
        [25, 27, 29],
        [35, 37, 39],
    ]

    # Output padding:
    output_padding = [
        [1, 0, 0],
        [1, 0, 0],
        [1, 1, 1],
        [0, 0, 0],
        [0, 0, 0],
    ]

  Args:
    selected_pos: a 0/1 2D tf.int32 tensor of shape [batch, time].
    values: a list of tensors, the rank of each is at least rank=2.
      [batch, time, ...].
    k: a scalar tf.int32 tensor or a Python int. On TPU, k must be a
      compile-time constant.
    num_devices: number of TPU devices used in xla_sharding SPMD.

  Returns:
    A tuple (output, padding).

    - output: a list of tensors of shape [batch, k, ...].
    - padding: a 2D 0/1 tensor of shape [batch, k], '1's are padded locations.
  """
  global_batch, seq_len = py_utils.GetShape(selected_pos, 2)
  if num_devices:
    device_batch = global_batch // num_devices
  else:
    device_batch = global_batch

  for i in range(len(values)):
    # Assert the first 2 dim of values[i] is [global_batch, seq_len]
    values[i] = py_utils.HasShape(values[i], [global_batch, seq_len], 2)
  # indices are 1-based for now, to distinguish between padding and selected
  # locations.
  indices = 1 + tf.range(tf.shape(values[0])[1], dtype=tf.int32)
  # [1, seq_len]
  indices = tf.expand_dims(indices, axis=0)

  # if 0, the position is not selected.
  # [1, seq_len] * [global_batch, seq_len] => [global_batch, t]
  # -- topk --> [global_batch, k]
  topk_indices, _ = tf.math.top_k(
      indices * tf.cast(selected_pos, indices.dtype), k)

  # [global_batch, k], sorted in ascending order.
  indices = tf.reverse(topk_indices, [-1])
  # [global_batch, k], padded positions are '1's.
  padding = tf.cast(tf.equal(indices, 0), values[0].dtype)
  padding = Split(padding, 0, num_devices)

  # [global_batch, k], zero_based_indices
  mp_idx = tf.maximum(0, indices - 1)
  mp_idx = Split(mp_idx, 0, num_devices)

  # [device_batch, k]
  if num_devices > 1 and py_utils.use_tpu():
    mp_idx = xla_sharding.auto_to_manual_spmd_partition(
        mp_idx, xla_sharding.get_op_sharding(mp_idx.op))
  # [device_batch, k, 1]
  mp_idx = tf.expand_dims(mp_idx, -1)

  # [device_batch]
  batch_ids = tf.range(device_batch, dtype=tf.int32)
  # [device_batch, 1, 1]
  batch_ids = tf.reshape(batch_ids, [device_batch, 1, 1])
  # [device_batch, k, 1]
  batch_ids = tf.broadcast_to(batch_ids, [device_batch, k, 1])

  # [device_batch, k, 2]
  final_indices = tf.concat([batch_ids, mp_idx], axis=-1)

  output = []
  for v in values:
    # Begin manually partition gather.
    v = Split(v, 0, num_devices)
    v_shape = v.shape.as_list()
    if num_devices > 1 and py_utils.use_tpu():
      op_sharding = xla_sharding.get_op_sharding(v.op)
      v = xla_sharding.auto_to_manual_spmd_partition(v, op_sharding)
    # Returns [global_batch, k, ...]
    v_out = tf.gather_nd(v, final_indices)

    if num_devices > 1 and py_utils.use_tpu():
      v_shape[1] = k
      v_out = xla_sharding.manual_to_auto_spmd_partition(
          v_out, op_sharding, full_shape=tf.TensorShape(v_shape))
    output.append(v_out)

  return output, padding
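
# Illustrative sketch (not part of the library): the 1-based top_k trick used
# above, run on the first row of the docstring example with num_devices=1
# semantics. The helper name is made up; `tf` is the module-level TensorFlow
# import.
def _SketchSelectedPosTopK():
  """Recovers sorted selected positions and the padding mask for k=3."""
  k = 3
  selected_pos = tf.constant([[0, 0, 1, 1]])           # [batch=1, time=4]
  indices = 1 + tf.range(tf.shape(selected_pos)[1], dtype=tf.int32)
  indices = tf.expand_dims(indices, axis=0)            # [1, time]
  topk_indices, _ = tf.math.top_k(indices * selected_pos, k)
  indices = tf.reverse(topk_indices, [-1])             # [[0, 3, 4]], ascending
  padding = tf.cast(tf.equal(indices, 0), tf.int32)    # [[1, 0, 0]]
  zero_based_idx = tf.maximum(0, indices - 1)          # [[0, 2, 3]]
  return zero_based_idx, padding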