Example #1
def _RelPositionBias(query, abs_pos_emb):
    """Computes relative position bias for general cases."""
    _, t, n, h = py_utils.GetShape(query)
    abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

    # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
    # Change to [T-1, T-2, ... 0, -1, -2, ... -(T-2), -(T-1)]
    abs_pos_emb = tf.reverse(abs_pos_emb, [0])

    # [B, N, T, L=2T-1]
    term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

    # Convert to [B, N, T, T]
    # part1
    term_bd_left = term_bd[:, :, :, :t]
    term_bd_left = tf.reverse(term_bd_left, [2, 3])
    term_bd_left = RelShift(term_bd_left)
    # [B, N, T, T]
    term_bd_left = tf.reverse(term_bd_left, [2, 3])
    # part 2
    term_bd_right = term_bd[:, :, :, t - 1:]
    # [B, N, T, T]
    term_bd_right = RelShift(term_bd_right)
    # [lower triangle]
    mask = tf.linalg.band_part(tf.ones_like(term_bd_right), -1, 0)

    # stitching together
    return tf.where(mask > 0, term_bd_left, term_bd_right)
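
A minimal usage sketch (not part of the original snippet): the shapes below are made up, and it assumes `_RelPositionBias` and its lingvo helpers (`py_utils`, `RelShift`) are already importable in scope.

# Hypothetical usage sketch; shapes are invented for illustration.
import tensorflow as tf

batch, time, heads, dims = 2, 5, 4, 8
query = tf.random.normal([batch, time, heads, dims])         # [B, T, N, H]
# Relative embeddings cover offsets -(T-1) .. T-1, hence 2T-1 rows.
abs_pos_emb = tf.random.normal([2 * time - 1, heads, dims])  # [2T-1, N, H]

bias = _RelPositionBias(query, abs_pos_emb)                  # -> [B, N, T, T]
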
Example #2
def _RelPositionBiasCausal(query, abs_pos_emb):
    """Computes relative position bias for causal self attention."""
    _, t, n, h = py_utils.GetShape(query)

    abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

    # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
    # Retain only half and change order to [T-1, T-2, ... 0]
    # [T, N, H]
    abs_pos_emb = tf.reverse(abs_pos_emb, [0])[:t]

    # [B, N, T, L=T]
    term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

    # Perform shifting.
    term_bd = tf.reverse(term_bd, [2, 3])
    term_bd = RelShift(term_bd)
    return tf.reverse(term_bd, [2, 3])
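
A similar hedged sketch for the causal variant, again with made-up shapes. For non-future offsets (key position at or before the query position) it computes the same query/embedding dot products as the general helper; entries above the diagonal are expected to be removed by the causal attention mask downstream.

# Hypothetical usage sketch; shapes are invented for illustration.
import tensorflow as tf

batch, time, heads, dims = 2, 5, 4, 8
query = tf.random.normal([batch, time, heads, dims])         # [B, T, N, H]
abs_pos_emb = tf.random.normal([2 * time - 1, heads, dims])  # [2T-1, N, H]

causal_bias = _RelPositionBiasCausal(query, abs_pos_emb)     # -> [B, N, T, T]
# Only the lower triangle (current and past positions) is meaningful; the rest
# is assumed to be masked out by the causal attention mask later.
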
Example #3
    def add_point_cloud(self, feature, laser_names, range_image_pose):
        """Convert the range images in `feature` to 3D point clouds.

        Adds the point cloud data to the tf.Example feature map.

        Args:
          feature: A tf.Example feature map.
          laser_names: A list of laser names (e.g., 'TOP', 'REAR', 'SIDE_LEFT').
          range_image_pose: A range image pose Tensor for the GBR.
        """
        for laser_name in laser_names:
            beam_inclinations = np.array(
                feature['%s_beam_inclinations' %
                        laser_name].float_list.value[:])
            # beam_inclinations will be populated if there is a non-uniform
            # beam configuration (e.g., for the TOP lasers).  Others that have
            # uniform beam inclinations are only parameterized by the min and max.
            # We use these min and max if the beam_inclinations are not present,
            # and turn them into a uniform inclinations array.
            if beam_inclinations.size == 0:
                beam_inclination_min = feature['%s_beam_inclination_min' %
                                               laser_name].float_list.value[:]
                beam_inclination_max = feature['%s_beam_inclination_max' %
                                               laser_name].float_list.value[:]

                laser_ri_name = '%s_ri1' % laser_name
                range_image_shape = feature[laser_ri_name +
                                            '_shape'].int64_list.value[:]
                height = tf.cast(range_image_shape[0], tf.float32)

                beam_inclinations = tf.constant(
                    [beam_inclination_min[0], beam_inclination_max[0]])
                beam_inclinations = range_image_utils.compute_inclination(
                    beam_inclinations, height)

            beam_extrinsics = np.array(
                feature['%s_extrinsics' %
                        laser_name].float_list.value[:]).reshape(4, 4)

            for ri_type in ['ri1', 'ri2']:
                laser_ri_name = '%s_%s' % (laser_name, ri_type)
                # For each of the 4 features of the lasers:
                range_image = np.array(
                    feature[laser_ri_name].float_list.value[:])
                range_image_shape = feature[laser_ri_name +
                                            '_shape'].int64_list.value[:]
                range_image = range_image.reshape(range_image_shape)
                # Compute mask.  At the moment, invalid values in the range image
                # representation are indicated via a -1. entry.  Callers are expected
                # to create this mask when passing into the conversion function below.
                range_image_mask = range_image[..., 0] >= 0

                # Get the 'range' feature from the range images.
                range_image_range = range_image[..., 0]

                # Call utility to convert point cloud to cartesian coordinates.
                #
                # API expects a batch dimension for all inputs.
                batched_pixel_pose = None
                batched_frame_pose = None
                # At the moment, only the GBR has per-pixel pose.
                if laser_name == 'TOP':
                    batched_pixel_pose = range_image_pose[tf.newaxis, ...]
                    batched_frame_pose = self.frame_pose[tf.newaxis, ...]

                batched_range_image_range = tf.convert_to_tensor(
                    range_image_range[np.newaxis, ...], dtype=tf.float32)
                batched_extrinsics = tf.convert_to_tensor(
                    beam_extrinsics[np.newaxis, ...], dtype=tf.float32)
                batched_inclinations = tf.convert_to_tensor(
                    beam_inclinations[np.newaxis, ...], dtype=tf.float32)

                batched_inclinations = tf.reverse(batched_inclinations,
                                                  axis=[-1])

                range_image_cartesian = (
                    range_image_utils.extract_point_cloud_from_range_image(
                        batched_range_image_range,
                        batched_extrinsics,
                        batched_inclinations,
                        pixel_pose=batched_pixel_pose,
                        frame_pose=batched_frame_pose))

                points_xyz = tf.gather_nd(range_image_cartesian[0],
                                          tf.where(range_image_mask))

                # Fetch the features corresponding to each xyz coordinate and
                # concatenate them together.
                points_features = tf.cast(
                    tf.gather_nd(range_image[..., 1:],
                                 tf.where(range_image_mask)), tf.float32)
                points_data = tf.concat([points_xyz, points_features], axis=-1)

                # Add laser feature to output.
                #
                # Skip embedding shape since we assume that all points have six features
                # and so we can reconstruct the number of points.
                points_list = list(points_data.numpy().reshape([-1]))
                feature['laser_%s' %
                        laser_ri_name].float_list.value[:] = points_list
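
A tiny standalone sketch of the mask-and-gather idiom used above (toy values, no Waymo data): invalid range entries are marked with -1 and dropped via tf.where/tf.gather_nd.

# Illustrative sketch only; values are made up.
import tensorflow as tf

range_image = tf.constant([[[2.0, 0.1], [-1.0, 0.0]],
                           [[3.5, 0.2], [1.0, 0.3]]])        # [H=2, W=2, C=2]
range_image_mask = range_image[..., 0] >= 0                  # valid pixels only
valid_points = tf.gather_nd(range_image, tf.where(range_image_mask))  # [3, 2]
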
Example #4
    def FProp(self, theta, inputs, paddings, state0=None, segment_id=None):
        """Computes LSTM forward pass.

        Args:
          theta: A `.NestedMap` object containing weights' values of this layer and
            its children layers.
          inputs: A single tensor or a tuple of tensors with cardinality equal to
            rnn_cell.inputs_arity. For every input tensor, the first dimension is
            assumed to be time, second dimension batch, and third dimension depth.
          paddings: A tensor. First dim is time, second dim is batch, and third dim
            is expected to be 1.
          state0: If not None, the initial rnn state in a `.NestedMap`. Defaults to
            the cell's zero-state.
          segment_id: A tensor to support packed inputs. First dim is time, second
            dim is batch, and third dim is expected to be 1.

        Returns:
          A tensor of [time, batch, dims].
          The final recurrent state.
        """
        p = self.params
        assert isinstance(self.cell, rnn_cell.RNNCell)

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        # Slicing wm to wm_{i,h} outside the loop to get 20% speedup over regular
        # LSTM baseline.
        # Keeping slicing within the loop gives only < 3% speedup.
        cell_theta = theta.cell.copy()
        num_input_nodes = p.cell.num_input_nodes
        cell_theta['wm_i'] = cell_theta.wm[:num_input_nodes, :]
        cell_theta['wm_h'] = cell_theta.wm[num_input_nodes:, :]
        tf.logging.vlog(1, 'cell_theta: %r', cell_theta)
        if p.packed_input:
            assert segment_id is not None
            reset_mask = rnn_layers.GeneratePackedInputResetMask(
                segment_id, is_reverse=False)
            reset_mask = py_utils.HasShape(reset_mask, tf.shape(paddings))
        else:
            reset_mask = tf.zeros_like(paddings)

        if p.reverse:
            inputs = [tf.reverse(x, [0]) for x in inputs]
            paddings = tf.reverse(paddings, [0])
            reset_mask = tf.reverse(reset_mask, [0])

        if not state0:
            batch_size = py_utils.GetShape(paddings)[1]
            state0 = self.cell.zero_state(cell_theta, batch_size)

        # [T, B, H]
        proj_inputs = self.cell.ProjectInputSequence(
            cell_theta, py_utils.NestedMap(act=inputs))
        proj_inputs = py_utils.NestedMap(proj_inputs=proj_inputs,
                                         padding=paddings,
                                         reset_mask=reset_mask)

        acc_state, final_state = recurrent.Recurrent(
            theta=cell_theta,
            state0=state0,
            inputs=proj_inputs,
            cell_fn=self.cell.FPropWithProjectedInput,
            cell_type=self.cell.layer_type,
            accumulator_layer=self,
            allow_implicit_capture=p.allow_implicit_capture)

        act = self.cell.GetOutput(acc_state)
        if p.reverse:
            act = tf.reverse(act, [0])
        return act, final_state
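
A small illustrative sketch of the weight-slicing trick described in the comments above (toy sizes, outside any layer): the fused LSTM weight is split once, before the recurrent loop, into an input block and a recurrent block.

# Illustrative sketch only; sizes are made up and no lingvo cell is involved.
import tensorflow as tf

num_input_nodes, num_output_nodes = 8, 16
# Fused weight: rows cover [input; hidden], columns cover the four LSTM gates.
wm = tf.random.normal([num_input_nodes + num_output_nodes, 4 * num_output_nodes])
wm_i = wm[:num_input_nodes, :]   # applied to the input at every step
wm_h = wm[num_input_nodes:, :]   # applied to the previous hidden state
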
Example #5
  def _XYZFromRangeImage(self,
                         lidar_image,
                         lidar_image_mask,
                         extrinsics,
                         inclinations,
                         pixel_pose=None,
                         frame_pose=None):
    """Extract the cartesian coordinates from the range image.

    Args:
       lidar_image: [H, W, C] range image Tensor.
       lidar_image_mask: [H, W] boolean indicating which 2d coordinates in the
         lidar image are present.
       extrinsics: [4, 4] float matrix representing transformation matrix to
         world coordinates.
       inclinations: [V] beam inclinations vector.
       pixel_pose: [64, 2650, 4, 4] tensor representing per pixel pose of GBR.
       frame_pose: [4, 4] matrix representing vehicle to world transformation.

    Returns:
      [H, W, 3] range image cartesian coordinates.
    """
    height, width, channels = py_utils.GetShape(lidar_image, 3)

    conversion_dtype = tf.float32
    lidar_image = tf.cast(lidar_image, conversion_dtype)
    extrinsics = tf.cast(extrinsics, conversion_dtype)
    inclinations = tf.cast(inclinations, conversion_dtype)
    inclinations = tf.reverse(inclinations, axis=[-1])

    az_correction = py_utils.HasShape(
        tf.atan2(extrinsics[1, 0], extrinsics[0, 0]), [])
    ratios = (tf.cast(tf.range(width, 0, -1), dtype=conversion_dtype) -
              .5) / tf.cast(width, conversion_dtype)
    ratios = py_utils.HasShape(ratios, [width])

    azimuth = (ratios * 2. - 1.) * np.pi - az_correction[..., tf.newaxis]
    azimuth = py_utils.HasShape(azimuth, [width])

    lidar_image_mask = lidar_image_mask[..., tf.newaxis]
    lidar_image_mask = tf.tile(lidar_image_mask, [1, 1, channels])
    lidar_image = tf.where(lidar_image_mask, lidar_image,
                           tf.zeros_like(lidar_image))
    lidar_image_range = lidar_image[..., 0]

    azimuth = py_utils.HasShape(azimuth[tf.newaxis, ...], [1, width])
    inclinations = py_utils.HasShape(inclinations[..., tf.newaxis], [height, 1])

    cos_azimuth = tf.cos(azimuth)
    sin_azimuth = tf.sin(azimuth)
    cos_incl = tf.cos(inclinations)
    sin_incl = tf.sin(inclinations)

    x = cos_azimuth * cos_incl * lidar_image_range
    y = sin_azimuth * cos_incl * lidar_image_range
    z = sin_incl * lidar_image_range

    lidar_image_points = tf.stack([x, y, z], -1)
    lidar_image_points = py_utils.HasShape(lidar_image_points,
                                           [height, width, 3])
    rotation = extrinsics[0:3, 0:3]
    translation = extrinsics[0:3, 3][tf.newaxis, ...]

    # Transform the image points in cartesian coordinates to
    # the world coordinate system using the extrinsics matrix.
    #
    # We first flatten the points, apply rotation, then
    # reshape to restore the original input and then apply
    # translation.
    lidar_image_points = tf.matmul(
        tf.reshape(lidar_image_points, [-1, 3]), rotation, transpose_b=True)
    lidar_image_points = tf.reshape(lidar_image_points, [height, width, 3])
    lidar_image_points += translation

    lidar_image_points = py_utils.HasShape(lidar_image_points,
                                           [height, width, 3])
    # GBR uses per pixel pose.
    if pixel_pose is not None:
      pixel_pose_rotation = pixel_pose[..., 0:3, 0:3]
      pixel_pose_translation = pixel_pose[..., 0:3, 3]
      lidar_image_points = tf.einsum(
          'hwij,hwj->hwi', pixel_pose_rotation,
          lidar_image_points) + pixel_pose_translation
      if frame_pose is None:
        raise ValueError('frame_pose must be set when pixel_pose is set.')
      # To vehicle frame corresponding to the given frame_pose
      # [4, 4]
      world_to_vehicle = tf.linalg.inv(frame_pose)
      world_to_vehicle_rotation = world_to_vehicle[0:3, 0:3]
      world_to_vehicle_translation = world_to_vehicle[0:3, 3]
      # [H, W, 3]
      lidar_image_points = tf.einsum(
          'ij,hwj->hwi', world_to_vehicle_rotation,
          lidar_image_points) + world_to_vehicle_translation[tf.newaxis,
                                                             tf.newaxis, :]

    return lidar_image_points
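
A standalone sketch of the spherical-to-cartesian step at the core of the method above (toy scalars; the extrinsics and per-pixel pose transforms are omitted).

# Illustrative sketch only; values are made up.
import numpy as np
import tensorflow as tf

lidar_range = tf.constant([[10.0]])       # [H=1, W=1]
azimuth = tf.constant([[0.25 * np.pi]])   # [H, W], radians
inclination = tf.constant([[0.1]])        # [H, W], radians

x = tf.cos(azimuth) * tf.cos(inclination) * lidar_range
y = tf.sin(azimuth) * tf.cos(inclination) * lidar_range
z = tf.sin(inclination) * lidar_range
points = tf.stack([x, y, z], axis=-1)     # [H, W, 3]
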
Example #6
def GatherK(selected_pos, values, k, num_devices=1):
  """Gather up to k elements from given tensors at selected pos under SPMD.

  Example::

    # Input
    k = 3

    selected_pos = [
        [0, 0, 1, 1],
        [0, 1, 1, 0],
        [0, 0, 0, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1],  # topk(k=3) largest indices are selected in this row.
    ]

    value_2d = [
        [1, 3, 5, 7],
        [9, 11, 13, 15],
        [17, 19, 21, 23],
        [25, 27, 29, 31],
        [33, 35, 37, 39],
    ]

    # Output:
    output = [
        [0, 5, 7],
        [0, 11, 13],
        [0, 0, 0],
        [25, 27, 29],
        [35, 37, 39],
    ]

    # Output padding:
    output_padding = [
        [1, 0, 0],
        [1, 0, 0],
        [1, 1, 1],
        [0, 0, 0],
        [0, 0, 0],
    ]

  Args:
    selected_pos: a 0/1 2D tf.int32 tensor of shape [batch, time].
    values: a list of tensors, the rank of each is at least rank=2. [batch,
      time, ...].
    k: a scalar tf.int32 tensor or a Python int. On TPU, k must be a
      compile-time constant.
    num_devices: number of TPU devices used in xla_sharding SPMD.

  Returns:
    A tuple (output, padding).

    - output: a list of tensors of shape [batch, k, ...].
    - padding: a 2D 0/1 tensor of shape [batch, k], '1's are padded locations.
  """
  global_batch, seq_len = py_utils.GetShape(selected_pos, 2)
  if num_devices:
    device_batch = global_batch // num_devices
  else:
    device_batch = global_batch

  for i in range(len(values)):
    # Assert the first 2 dim of values[i] is [global_batch, seq_len]
    values[i] = py_utils.HasShape(values[i], [global_batch, seq_len], 2)
  # indices are 1-based for now, to distinguish between padding and selected
  # locations.
  indices = 1 + tf.range(tf.shape(values[0])[1], dtype=tf.int32)
  # [1, seq_len]
  indices = tf.expand_dims(indices, axis=0)

  # if 0, the position is not selected.
  # [1, seq_len] * [global_batch, seq_len] => [global_batch, seq_len]
  # -- topk --> [global_batch, k]
  topk_indices, _ = tf.math.top_k(
      indices * tf.cast(selected_pos, indices.dtype), k)

  # [global_batch, k], sorted in ascending order.
  indices = tf.reverse(topk_indices, [-1])
  # [global_batch, k], padded positions are '1's.
  padding = tf.cast(tf.equal(indices, 0), values[0].dtype)
  padding = Split(padding, 0, num_devices)

  # [global_batch, k], zero_based_indices
  mp_idx = tf.maximum(0, indices - 1)
  mp_idx = Split(mp_idx, 0, num_devices)

  # [device_batch, k]
  if num_devices > 1 and py_utils.use_tpu():
    mp_idx = xla_sharding.auto_to_manual_spmd_partition(
        mp_idx, xla_sharding.get_op_sharding(mp_idx.op))
  # [device_batch, k, 1]
  mp_idx = tf.expand_dims(mp_idx, -1)

  # [device_batch]
  batch_ids = tf.range(device_batch, dtype=tf.int32)
  # [device_batch, 1, 1]
  batch_ids = tf.reshape(batch_ids, [device_batch, 1, 1])
  # [device_batch, k, 1]
  batch_ids = tf.broadcast_to(batch_ids, [device_batch, k, 1])

  # [device_batch, k, 2]
  final_indices = tf.concat([batch_ids, mp_idx], axis=-1)

  output = []
  for v in values:
    # Begin manually partition gather.
    v = Split(v, 0, num_devices)
    v_shape = v.shape.as_list()
    if num_devices > 1 and py_utils.use_tpu():
      op_sharding = xla_sharding.get_op_sharding(v.op)
      v = xla_sharding.auto_to_manual_spmd_partition(v, op_sharding)
    # Returns [global_batch, k, ...]
    v_out = tf.gather_nd(v, final_indices)

    if num_devices > 1 and py_utils.use_tpu():
      v_shape[1] = k
      v_out = xla_sharding.manual_to_auto_spmd_partition(
          v_out, op_sharding, full_shape=tf.TensorShape(v_shape))
    output.append(v_out)

  return output, padding
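
A standalone sketch of the 1-based index trick used above, replaying the first row of the docstring example without any SPMD splitting: selected positions survive top_k, zeros mark padding, and indices are shifted back to zero-based for the gather.

# Illustrative sketch only; inputs are taken from the docstring example.
import tensorflow as tf

selected_pos = tf.constant([[0, 0, 1, 1]], dtype=tf.int32)  # [batch=1, time=4]
k = 3
indices = 1 + tf.range(4, dtype=tf.int32)[tf.newaxis, :]    # [[1, 2, 3, 4]]
topk, _ = tf.math.top_k(indices * selected_pos, k)          # [[4, 3, 0]]
ascending = tf.reverse(topk, [-1])                          # [[0, 3, 4]]
padding = tf.cast(tf.equal(ascending, 0), tf.float32)       # [[1., 0., 0.]]
mp_idx = tf.maximum(0, ascending - 1)                       # [[0, 2, 3]]
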