Example #1
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
    """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to the total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses per
    beam.
  """
    source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
    value_dict = {}
    for output in beam_search_outputs:
        hyps_per_beam = py_utils.with_dependencies([
            py_utils.assert_equal(source_batch,
                                  tf.shape(output.topk_hyps)[0]),
        ], tf.shape(output.topk_hyps)[1])
        for k, v in six.iteritems(output._asdict()):
            if v is None:
                continue
            if k == 'done_hyps':
                v = tf.transpose(v)
            if k not in value_dict:
                value_dict[k] = []
            value_dict[k].append(
                tf.reshape(v, [source_batch, hyps_per_beam, -1]))

    # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
    concatenated = {}
    for k, values in six.iteritems(value_dict):
        if len(values) != len(beam_search_outputs):
            raise ValueError('Incomplete values for %s: %s' %
                             (k, beam_search_outputs))
        concatenated[k] = tf.concat(values, axis=1)

    scores = concatenated['topk_scores']
    scores = tf.where(tf.equal(concatenated['topk_lens'], 0),
                      tf.fill(tf.shape(scores), -1e6), scores)
    scores = tf.squeeze(scores, -1)

    # Select top max_hyps_per_beam indices per beam.
    _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
    batch_ids = tf.tile(tf.expand_dims(tf.range(source_batch), -1),
                        [1, max_hyps_per_beam])
    # [source_batch, max_hyps_per_beam, 2]
    gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

    # Gather the merged top hyps according to 'gather_indices'.
    top = beam_search_outputs[0]._asdict()
    total_hyps = source_batch * max_hyps_per_beam
    for k, v in six.iteritems(concatenated):
        v = tf.gather_nd(v, gather_indices)
        if k == 'done_hyps':
            v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
        elif k == 'topk_hyps':
            v = tf.reshape(v, [source_batch, max_hyps_per_beam])
        elif k == 'topk_ids':
            v = tf.reshape(v, [total_hyps, -1])
        elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
            v = tf.reshape(v, [total_hyps])
        else:
            raise ValueError('Unexpected field: %s' % k)
        top[k] = v
    return BeamSearchDecodeOutput(**top)
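
The selection step above is the usual top-k-plus-gather_nd pattern: stack batch ids with the top-k indices to build [batch, k, 2] coordinates for tf.gather_nd. A minimal self-contained sketch of just that pattern, with toy scores (the names here are illustrative, not part of the API):

import tensorflow as tf

scores = tf.constant([[0.1, 0.9, 0.5],
                      [0.7, 0.2, 0.4]])          # [source_batch=2, num_hyps=3]
_, top_indices = tf.nn.top_k(scores, k=2)        # per-row best 2: [[1, 2], [0, 2]]
batch_ids = tf.tile(tf.expand_dims(tf.range(2), -1), [1, 2])
gather_indices = tf.stack([batch_ids, top_indices], axis=-1)   # [2, 2, 2]
top_scores = tf.gather_nd(scores, gather_indices)  # [[0.9, 0.5], [0.7, 0.4]]
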
Example #2
    def add_point_cloud(self, feature, laser_names, range_image_pose):
        """Convert the range images in `feature` to 3D point clouds.

    Adds the point cloud data to the tf.Example feature map.

    Args:
      feature: A tf.Example feature map.
      laser_names: A list of laser names (e.g., 'TOP', 'REAR', 'SIDE_LEFT').
      range_image_pose: A range image pose Tensor for the top laser.
    """
        # Stash metadata for each laser. This metadata can be useful
        # for reconstructing the range images.
        self.laser_info = {}

        for laser_name in laser_names:
            beam_inclinations = np.array(
                feature['%s_beam_inclinations' %
                        laser_name].float_list.value[:])
            # beam_inclinations is populated only for lasers with a
            # non-uniform beam configuration (e.g., the TOP laser). Lasers
            # with uniform beam inclinations are parameterized only by the
            # min and max; when beam_inclinations is absent, we use these to
            # build a uniformly spaced inclinations array.
            if beam_inclinations.size == 0:
                beam_inclination_min = feature['%s_beam_inclination_min' %
                                               laser_name].float_list.value[:]
                beam_inclination_max = feature['%s_beam_inclination_max' %
                                               laser_name].float_list.value[:]

                laser_ri_name = '%s_ri1' % laser_name
                range_image_shape = feature[laser_ri_name +
                                            '_shape'].int64_list.value[:]
                height = tf.cast(range_image_shape[0], tf.float32)

                beam_inclinations = tf.constant(
                    [beam_inclination_min[0], beam_inclination_max[0]])
                beam_inclinations = range_image_utils.compute_inclination(
                    beam_inclinations, height)

            beam_extrinsics = np.array(
                feature['%s_extrinsics' %
                        laser_name].float_list.value[:]).reshape(4, 4)

            for ri_type in ['ri1', 'ri2']:
                laser_ri_name = '%s_%s' % (laser_name, ri_type)
                # For each of the 4 features of the lasers:
                range_image = np.array(
                    feature[laser_ri_name].float_list.value[:])
                range_image_shape = feature[laser_ri_name +
                                            '_shape'].int64_list.value[:]
                range_image = range_image.reshape(range_image_shape)
                # Compute mask.  At the moment, invalid values in the range image
                # representation are indicated via a -1. entry.  Callers are expected
                # to create this mask when passing into the conversion function below.
                range_image_mask = range_image[..., 0] >= 0

                # Get the 'range' feature from the range images.
                range_image_range = range_image[..., 0]

                # Call utility to convert point cloud to cartesian coordinates.
                #
                # API expects a batch dimension for all inputs.
                batched_pixel_pose = None
                batched_frame_pose = None
                # At the moment, only the top has per-pixel pose.
                if laser_name == 'TOP' and range_image_pose is not None:
                    batched_pixel_pose = range_image_pose[tf.newaxis, ...]
                    batched_frame_pose = self.frame_pose[tf.newaxis, ...]

                batched_range_image_range = tf.convert_to_tensor(
                    range_image_range[np.newaxis, ...], dtype=tf.float32)
                batched_extrinsics = tf.convert_to_tensor(
                    beam_extrinsics[np.newaxis, ...], dtype=tf.float32)
                batched_inclinations = tf.convert_to_tensor(
                    beam_inclinations[np.newaxis, ...], dtype=tf.float32)

                batched_inclinations = tf.reverse(batched_inclinations,
                                                  axis=[-1])

                range_image_cartesian = (
                    range_image_utils.extract_point_cloud_from_range_image(
                        batched_range_image_range,
                        batched_extrinsics,
                        batched_inclinations,
                        pixel_pose=batched_pixel_pose,
                        frame_pose=batched_frame_pose))

                info = py_utils.NestedMap()
                self.laser_info[laser_ri_name] = info
                info.range_image = range_image
                info.range_image_shape = range_image_shape

                ri_indices = tf.where(range_image_mask)
                points_xyz = tf.gather_nd(range_image_cartesian[0], ri_indices)
                info.num_points = tf.shape(points_xyz).numpy()[0]

                # Fetch the features corresponding to each xyz coordinate and
                # concatenate them together.
                points_features = tf.cast(
                    tf.gather_nd(range_image[..., 1:], ri_indices), tf.float32)
                if self._use_range_image_index_as_lidar_feature:
                    points_data = tf.concat([
                        points_xyz,
                        tf.cast(ri_indices, tf.float32),
                        points_features[..., 2:],
                    ], axis=-1)
                else:
                    points_data = tf.concat([points_xyz, points_features],
                                            axis=-1)

                # Add laser feature to output.
                #
                # Skip embedding shape since we assume that all points have six features
                # and so we can reconstruct the number of points.
                points_list = list(points_data.numpy().reshape([-1]))
                feature['laser_%s' %
                        laser_ri_name].float_list.value[:] = points_list

                laser_ri_flow_name = '%s_flow' % laser_ri_name
                if laser_ri_flow_name in feature:
                    range_image_flow = np.array(
                        feature[laser_ri_flow_name].float_list.value[:])
                    range_image_flow_shape = feature[
                        laser_ri_flow_name + '_shape'].int64_list.value[:]
                    range_image_flow = range_image_flow.reshape(
                        range_image_flow_shape)
                    flow_data = tf.cast(
                        tf.gather_nd(range_image_flow, ri_indices), tf.float32)
                    flow_list = list(flow_data.numpy().reshape([-1]))
                    feature['laser_%s' %
                            laser_ri_flow_name].float_list.value[:] = flow_list
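
The heart of the conversion is masking invalid returns (range channel < 0) and gathering the surviving pixels. A standalone sketch of that step on a toy range image, assuming eager mode (as the .numpy() calls above already do):

import numpy as np
import tensorflow as tf

# Toy [H=2, W=3, 2] range image; channel 0 is range, -1 marks invalid pixels.
range_image = np.array([[[5.0, 0.1], [-1.0, 0.0], [7.0, 0.3]],
                        [[-1.0, 0.0], [2.0, 0.2], [-1.0, 0.0]]], np.float32)
range_image_mask = range_image[..., 0] >= 0   # [2, 3] validity mask
ri_indices = tf.where(range_image_mask)       # [num_valid=3, 2] pixel coords
points_features = tf.cast(
    tf.gather_nd(range_image[..., 1:], ri_indices), tf.float32)  # [3, 1]
num_points = tf.shape(ri_indices)[0].numpy()  # 3 valid returns
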
Example #3
def ComputeSparseAttention(q, k, v, sparsity_indices, paddings=None):
    """Computes attention according to a sparsity pattern.

  We use the following capital letters to denote shape parameters:
    B = batch size
    S = length of the source sequence
    T = length of the target sequence
    N = number of attention heads
    H = dimensions of each attention head
    K = number of clusters
    W = attention window (W <= S)

  The 'sparsity_indices' is a tensor of integral type where the last dimension
  contains, for each query position, the W indices (W is the attention window)
  along S in 'k' that the query is allowed to attend to.

  For example, if sparsity_indices[batch_idx, target time step, head_idx] =
  [1, 7, 8], the query token at that position attends to values with indices
  1, 7, and 8, and the attention window here is 3.

  The valid values in 'sparsity_indices' lie in the range [-1, S-1]. The value
  -1 is reserved to mean padding, distinct from the valid index (S-1).

  For example, if W=S and 'sparsity_indices' contains range(S) on the last
  dimension, this degenerates to the original full attention.

  We require that 'sparsity_indices' does not contain duplicates (except for -1
  to indicate paddings), but we do not require 'sparsity_indices' to be sorted.

  Note that this implementation is flexible and generic but is not optimized
  for time or space complexity. Please consider grouping queries that attend
  to the same subset of values first for efficiency.

  Args:
    q: (projected) queries, [B, T, N, H];
    k: (projected) keys, [B, S, N, H];
    v: (projected) values, [B, S, N, H];
    sparsity_indices: [B, T, N, W], where W is the attention window;
    paddings: paddings for keys, [B, S] if not None.

  Returns:
    output: the encoded output, [B, T, N, H].
    atten_probs: the attention weights, [B, T, N, S].
  """
    q = tf.convert_to_tensor(q)
    k = tf.convert_to_tensor(k)
    v = tf.convert_to_tensor(v)
    sparsity_indices = tf.convert_to_tensor(sparsity_indices)

    k = py_utils.HasRank(k, 4)
    _, source_length, _, dim_per_head = py_utils.GetShape(k, 4)
    sparsity_indices = py_utils.HasRank(sparsity_indices, 4)
    batch_size, target_length, num_heads, attention_window = py_utils.GetShape(
        sparsity_indices, 4)
    py_utils.assert_less_equal(
        attention_window, source_length,
        'The provided sparsity_indices has attention window > source length. '
        'This is likely an error.')

    # To prepare for gathering the relevant vectors from 'k', we prepare
    # gather_idx of shape [B, T, N, W, 3] where the last dimension corresponds to
    # slices in 'k' indexed by (batch index, source time step, head index),
    # where the source length index comes from the original W dimension in
    # 'sparsity_indices'.
    seq_idx = tf.expand_dims(sparsity_indices, axis=-1)
    # Overwrite the paddings -1 with valid gather indices (zeros). We will
    # fix the logits with -inf in these positions later.
    seq_idx = tf.where(seq_idx < 0, tf.zeros_like(seq_idx), seq_idx)
    batch_idx = tf.reshape(
        tf.range(0, batch_size, dtype=sparsity_indices.dtype),
        [batch_size, 1, 1, 1, 1])
    batch_idx = tf.tile(batch_idx,
                        [1, target_length, num_heads, attention_window, 1])
    head_idx = tf.reshape(tf.range(0, num_heads, dtype=sparsity_indices.dtype),
                          [1, 1, num_heads, 1, 1])
    head_idx = tf.tile(head_idx,
                       [batch_size, target_length, 1, attention_window, 1])
    # [B, T, N, W, 3], where last dimension is (batch index, source length index,
    # head index).
    gather_idx = tf.concat([batch_idx, seq_idx, head_idx], axis=-1)

    # Both the gathered k and v have shape [B, T, N, W, H]
    k = tf.gather_nd(k, gather_idx)
    v = tf.gather_nd(v, gather_idx)

    if paddings is None:
        paddings = tf.zeros([batch_size, source_length])
    paddings = tf.convert_to_tensor(paddings)
    paddings = tf.expand_dims(paddings, axis=-1)
    # [B, S, N]
    paddings = tf.tile(paddings, [1, 1, num_heads])
    # [B, T, N, W]
    paddings = tf.gather_nd(paddings, gather_idx)

    logits = tf.einsum('BTNH, BTNWH -> BTNW', q, k)
    logits *= tf.math.rsqrt(tf.cast(dim_per_head, q.dtype))

    very_negative_logits = (tf.ones_like(logits) * logits.dtype.max *
                            tf.constant(-0.7, dtype=logits.dtype))
    padded_logits = tf.where(
        tf.math.logical_or(sparsity_indices < 0, paddings > 0.0),
        very_negative_logits, logits)

    # [B, T, N, W]
    atten_probs = tf.nn.softmax(padded_logits, name='attention_weights')
    atten_probs = tf.where(sparsity_indices < 0, tf.zeros_like(logits),
                           atten_probs)
    output = tf.einsum('BTNW, BTNWH -> BTNH', atten_probs, v)

    # Scatter 'atten_probs' back into the original source length.
    # [B, T, N, W, 1]
    batch_idx = tf.tile(
        tf.range(batch_size)[:, None, None, None, None],
        [1, target_length, num_heads, attention_window, 1])
    # [B, T, N, W, 1]
    target_seq_idx = tf.tile(
        tf.range(target_length)[None, :, None, None, None],
        [batch_size, 1, num_heads, attention_window, 1])
    # [B, T, N, W, 1]
    head_idx = tf.tile(
        tf.range(num_heads)[None, None, :, None, None],
        [batch_size, target_length, 1, attention_window, 1])
    # seq_idx: [B, T, N, W, 1]
    # [B, T, N, W, 4]
    scatter_idx = tf.concat([batch_idx, target_seq_idx, head_idx, seq_idx], -1)
    # [B, T, N, S]
    scattered_probs = tf.scatter_nd(
        scatter_idx, atten_probs,
        [batch_size, target_length, num_heads, source_length])
    return output, scattered_probs
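
As a quick sanity check of the degenerate case called out in the docstring (W = S with range(S) indices reduces to full attention), here is a minimal sketch, assuming eager mode and that Lingvo's py_utils dependency is importable:

import tensorflow as tf

B, S, T, N, H = 2, 4, 3, 1, 8
q = tf.random.normal([B, T, N, H])
k = tf.random.normal([B, S, N, H])
v = tf.random.normal([B, S, N, H])
# W = S and indices = range(S): every query may attend to every key.
sparsity_indices = tf.broadcast_to(
    tf.range(S)[tf.newaxis, tf.newaxis, tf.newaxis, :], [B, T, N, S])
output, atten_probs = ComputeSparseAttention(q, k, v, sparsity_indices)
print(output.shape)                     # (2, 3, 1, 8)
print(tf.reduce_sum(atten_probs, -1))   # each row sums to ~1.0 (full softmax)
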
Example #4
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Per-host infeed is not supported when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = (tpu_function.get_tpu_context().number_of_shards //
                      num_infeed_hosts)
            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            tpu_emb_input_keys = (
                list(tpu_embedding.feature_to_config_dict.keys())
                if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if isinstance(batch, py_utils.NestedMap):
                        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                        # Note that when MultiTaskData is used, bucket_keys will be at the
                        # second level of the dictionary.
                        batch = batch.FilterKeyVal(
                            lambda k, _: not k.endswith('bucket_keys'))
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment()
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment())

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        # For executor-driven multiple programs, we need more fine-grained
        # access rather than using a single global graph collection.
        self.tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
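
The dense-to-sparse step for TPU embedding inputs is worth seeing in isolation; a small self-contained sketch, with -1 as the assumed padding id (as in the loop above):

import tensorflow as tf

# A dense id split for one core, where -1 is the assumed padding id.
split = tf.constant([[3, 7, -1],
                     [-1, 5, -1]])
sample_indices = tf.where(tf.not_equal(split, -1))       # [[0,0], [0,1], [1,1]]
embedding_indices = tf.gather_nd(split, sample_indices)  # [3, 7, 5]
# These two tensors are what tpu_embedding_lib.EnqueueData is built from above.
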
Example #5
  def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
    """Loop body for farthest point sampler."""

    def _GetRandomRealPoint():
      """Select the first point.

      For the first point, we want any random real (non-padded) point, so we
      create a random value per point, and then set all padded ones to some
      large value (more than the maxval). We then take the min per batch
      element to get the first points.

      Returns:
        Tensor containing the index of a random point selected for each example
        in the batch.
      """
      random_values = tf.random.uniform((batch_size, num_points),
                                        minval=0,
                                        maxval=1,
                                        dtype=tf.float32,
                                        seed=random_seed)
      random_values = tf.where(
          tf.equal(padding, 0.0), random_values, padding * 10)
      return tf.argmin(random_values, axis=1, output_type=tf.int32)

    def _GetFurthestPoint():
      """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for each
        example in the batch.
      """
      # Set padded points distance to negative so they aren't selected.
      padding_masked_distance_to_selected = tf.where(
          tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
              (batch_size, num_points), dtype=tf.float32))
      # But only do this when we still have valid points left.
      padding_masked_distance_to_selected = tf.where(
          tf.less(curr_idx, num_valid_points),
          padding_masked_distance_to_selected, distance_to_selected)
      return tf.argmax(
          padding_masked_distance_to_selected, axis=-1, output_type=tf.int32)

    def _GetSeededPoint():
      """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for each
        example in the batch.
      """
      return tf.ones((batch_size,), dtype=tf.int32) * curr_idx

    # Select indices for this loop iteration.
    def _Seeded():
      return tf.cond(
          tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
          _GetFurthestPoint)

    def _Real():
      return tf.cond(
          tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)

    new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded, _Real)
    sampled_idx = sampled_idx.write(curr_idx, new_selected)

    # Extract the distance to the latest point selected to update
    # distance_to_selected.
    new_selected_gather_idx = tf.stack([tf.range(batch_size), new_selected],
                                       axis=1)
    if precomputed_squared_distance is not None:
      new_distance = tf.gather_nd(precomputed_squared_distance,
                                  new_selected_gather_idx)
    else:
      new_points = tf.reshape(
          tf.gather_nd(points, new_selected_gather_idx), [batch_size, 1, dims])
      new_distance = tf.reshape(
          SquaredDistanceMatrix(points, new_points), [batch_size, num_points])

    is_newly_closest = tf.less(new_distance, distance_to_selected)
    distance_to_selected = tf.minimum(distance_to_selected, new_distance)

    # Track the index to the closest selected point.
    new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
    closest_idx = tf.cond(
        tf.equal(curr_idx, 0),
        # At the first loop iteration, the init points are the closest.
        lambda: new_selected_tiled,
        # Otherwise, update with the new points based on the distances.
        lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx))
    return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx
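
The distance bookkeeping above is the standard farthest point sampling update: keep, per point, the squared distance to the nearest point selected so far, take the elementwise minimum against the newest selection, then argmax to pick the next point. A standalone sketch of one update step on toy points (SquaredDistanceMatrix is replaced here by an explicit reduce_sum):

import tensorflow as tf

points = tf.constant([[[0., 0.], [1., 0.], [4., 0.]]])   # [batch=1, 3, dims=2]
distance_to_selected = tf.fill([1, 3], 1e30)             # nothing selected yet
new_selected = tf.constant([0])                          # suppose point 0 was picked
gather_idx = tf.stack([tf.range(1), new_selected], axis=1)
new_points = tf.reshape(tf.gather_nd(points, gather_idx), [1, 1, 2])
new_distance = tf.reduce_sum((points - new_points)**2, axis=-1)  # [[0., 1., 16.]]
distance_to_selected = tf.minimum(distance_to_selected, new_distance)
next_idx = tf.argmax(distance_to_selected, axis=-1)      # [2]: the farthest point
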
Example #6
    def FProp(self, theta, input_batch):
        # pyformat: disable
        """Compute features for the pillars and convert them back to a dense grid.

    Args:
      theta: A `.NestedMap` object containing variable values of this task.
      input_batch: A `.NestedMap` object containing input tensors. Following
        keys are required:

        - grid_num_points: Integer tensor with shape [batch size, nx, ny, nz],
          where nx, ny, nz correspond to the grid sizes (i.e., the number of
          voxels in each axis dimension).
        - pillar_points: Float tensor with shape [batch size, num_pillars,
          num_points_per_pillar, 3 + num_laser_features]
        - pillar_centers: Float tensor with shape [batch size, num_pillars,
          num_points_per_pillar, 3]
        - pillar_locations: Float tensor with shape [batch size, num_pillars, 3]

    Returns:
      The dense features with shape [b, nx, ny, nz * fdims].
    """
        # pyformat: enable
        p = self.params
        bs, nx, ny, nz = py_utils.GetShape(input_batch.grid_num_points, 4)
        # Process points to concatenate a set of fixed features (e.g.,
        # add means, centers, normalize points to means).
        num_features = 3 + p.num_laser_features
        pillar_points = py_utils.HasShape(input_batch.pillar_points,
                                          [bs, -1, -1, num_features])
        _, npillars, npoints, _ = py_utils.GetShape(pillar_points, 4)
        pillar_xyz = pillar_points[..., :3]

        # Compute number of points per pillar and prepare for broadcasting.
        pillar_num_points = tf.gather_nd(input_batch.grid_num_points,
                                         input_batch.pillar_locations,
                                         batch_dims=1)
        pillar_num_points = pillar_num_points[..., tf.newaxis, tf.newaxis]

        # Compute mean by computing sum and dividing by number of points. Clip the
        # denominator by 1.0 to gracefully handle empty pillars.
        pillar_sum = tf.reduce_sum(pillar_xyz, axis=2, keepdims=True)
        pillar_means = pillar_sum / tf.maximum(
            tf.cast(pillar_num_points, tf.float32), 1.0)

        pillar_feats = pillar_points[..., 3:]
        pillar_centers = py_utils.HasShape(input_batch.pillar_centers,
                                           [bs, -1, 1, 3])
        pillar_concat = tf.concat(
            axis=3,
            values=[
                pillar_xyz - pillar_means,
                pillar_feats,
                tf.tile(pillar_means, [1, 1, npoints, 1]),
                tf.tile(pillar_centers, [1, 1, npoints, 1]),
            ])
        # Featurize pillars.
        pillar_features = self.featurizer.FProp(theta.featurizer,
                                                pillar_concat)

        # Convert back to the dense grid.
        pillar_locations = py_utils.HasShape(input_batch.pillar_locations,
                                             [bs, npillars, 3])
        dense_features = SparseToDense(grid_shape=(nx, ny, nz),
                                       locations=pillar_locations,
                                       feats=pillar_features)
        return dense_features
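
The empty-pillar guard above (clipping the denominator at 1.0) is easy to check in isolation; a small sketch with one real and one empty pillar:

import tensorflow as tf

# [batch=1, num_pillars=2, points_per_pillar=2, 3]; the second pillar is empty.
pillar_xyz = tf.constant([[[[1., 2., 3.], [3., 2., 1.]],
                           [[0., 0., 0.], [0., 0., 0.]]]])
pillar_num_points = tf.constant([[2, 0]])[..., tf.newaxis, tf.newaxis]
pillar_sum = tf.reduce_sum(pillar_xyz, axis=2, keepdims=True)
pillar_means = pillar_sum / tf.maximum(
    tf.cast(pillar_num_points, tf.float32), 1.0)
# -> [[[[2., 2., 2.]], [[0., 0., 0.]]]]: no divide-by-zero for the empty pillar.
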
Example #7
def GatherK(selected_pos, values, k, num_devices=1):
  """Gather up to k elements from given tensors at selected pos under SPMD.

  Example::

    # Input
    k = 3

    selected_pos = [
        [0, 0, 1, 1],
        [0, 1, 1, 0],
        [0, 0, 0, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1],  # topk(k=3) largest indices are selected in this row.
    ]

    value_2d = [
        [1, 3, 5, 7],
        [9, 11, 13, 15],
        [17, 19, 21, 23],
        [25, 27, 29, 31],
        [33, 35, 37, 39],
    ]

    # Output:
    output = [
        [0, 5, 7],
        [0, 11, 13],
        [0, 0, 0],
        [25, 27, 29],
        [35, 37, 39],
    ]

    # Output padding:
    output_padding = [
        [1, 0, 0],
        [1, 0, 0],
        [1, 1, 1],
        [0, 0, 0],
        [0, 0, 0],
    ]

  Args:
    selected_pos: a 0/1 2D tf.int32 tensor of shape [batch, time].
    values: a list of tensors, the rank of each is at least rank=2. [batch,
      time, ...].
    k: a scalar tf.int32 tensor or a Python int. On TPU, k must be a
      compile-time constant.
    num_devices: number of TPU devices used in xla_sharding SPMD.

  Returns:
    A tuple (output, padding).

    - output: a list of tensors of shape [batch, k, ...].
    - padding: a 2D 0/1 tensor of shape [batch, k], '1's are padded locations.
  """
  global_batch, seq_len = py_utils.GetShape(selected_pos, 2)
  if num_devices:
    device_batch = global_batch // num_devices
  else:
    device_batch = global_batch

  for i in range(len(values)):
    # Assert the first 2 dims of values[i] are [global_batch, seq_len].
    values[i] = py_utils.HasShape(values[i], [global_batch, seq_len], 2)
  # indices are 1-based for now, to distinguish between padding and selected
  # locations.
  indices = 1 + tf.range(tf.shape(values[0])[1], dtype=tf.int32)
  # [1, seq_len]
  indices = tf.expand_dims(indices, axis=0)

  # If 0, the position is not selected.
  # [1, seq_len] * [global_batch, seq_len] => [global_batch, seq_len]
  # -- topk --> [global_batch, k]
  topk_indices, _ = tf.math.top_k(
      indices * tf.cast(selected_pos, indices.dtype), k)

  # [global_batch, k], sorted in ascending order.
  indices = tf.reverse(topk_indices, [-1])
  # [global_batch, k], padded positions are '1's.
  padding = tf.cast(tf.equal(indices, 0), values[0].dtype)
  padding = Split(padding, 0, num_devices)

  # [global_batch, k], zero_based_indices
  mp_idx = tf.maximum(0, indices - 1)
  mp_idx = Split(mp_idx, 0, num_devices)

  # [device_batch, k]
  if num_devices > 1 and py_utils.use_tpu():
    mp_idx = xla_sharding.auto_to_manual_spmd_partition(
        mp_idx, xla_sharding.get_op_sharding(mp_idx.op))
  # [device_batch, k, 1]
  mp_idx = tf.expand_dims(mp_idx, -1)

  # [device_batch]
  batch_ids = tf.range(device_batch, dtype=tf.int32)
  # [device_batch, 1, 1]
  batch_ids = tf.reshape(batch_ids, [device_batch, 1, 1])
  # [device_batch, k, 1]
  batch_ids = tf.broadcast_to(batch_ids, [device_batch, k, 1])

  # [device_batch, k, 2]
  final_indices = tf.concat([batch_ids, mp_idx], axis=-1)

  output = []
  for v in values:
    # Begin manually partition gather.
    v = Split(v, 0, num_devices)
    v_shape = v.shape.as_list()
    if num_devices > 1 and py_utils.use_tpu():
      op_sharding = xla_sharding.get_op_sharding(v.op)
      v = xla_sharding.auto_to_manual_spmd_partition(v, op_sharding)
    # Returns [device_batch, k, ...]
    v_out = tf.gather_nd(v, final_indices)

    if num_devices > 1 and py_utils.use_tpu():
      v_shape[1] = k
      v_out = xla_sharding.manual_to_auto_spmd_partition(
          v_out, op_sharding, full_shape=tf.TensorShape(v_shape))
    output.append(v_out)

  return output, padding
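
The 1-based index trick above lets a single top_k call both select positions and expose padding (zeros mark the padded slots). A self-contained illustration on the second row of the docstring example:

import tensorflow as tf

selected_pos = tf.constant([[0, 1, 1, 0]])                # [batch=1, time=4]
indices = 1 + tf.range(4, dtype=tf.int32)[tf.newaxis, :]  # [[1, 2, 3, 4]]
topk_indices, _ = tf.math.top_k(indices * selected_pos, k=3)  # [[3, 2, 0]]
indices = tf.reverse(topk_indices, [-1])                  # [[0, 2, 3]] ascending
padding = tf.cast(tf.equal(indices, 0), tf.float32)       # [[1., 0., 0.]]
mp_idx = tf.maximum(0, indices - 1)                       # [[0, 1, 2]] zero-based
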