def _RelPositionBias(query, abs_pos_emb):
  """Computes relative position bias for general cases."""
  _, t, n, h = py_utils.GetShape(query)
  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

  # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
  # Change to [T-1, T-2, ... 0, -1, -2, ... -(T-2), -(T-1)]
  abs_pos_emb = tf.reverse(abs_pos_emb, [0])

  # [B, N, T, L=2T-1]
  term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

  # Convert to [B, N, T, T].
  # Part 1.
  term_bd_left = term_bd[:, :, :, :t]
  term_bd_left = tf.reverse(term_bd_left, [2, 3])
  term_bd_left = RelShift(term_bd_left)
  # [B, N, T, T]
  term_bd_left = tf.reverse(term_bd_left, [2, 3])
  # Part 2.
  term_bd_right = term_bd[:, :, :, t - 1:]
  # [B, N, T, T]
  term_bd_right = RelShift(term_bd_right)
  # Lower triangle.
  mask = tf.linalg.band_part(tf.ones_like(term_bd_right), -1, 0)

  # Stitch the two halves together.
  return tf.where(mask > 0, term_bd_left, term_bd_right)
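# Illustrative usage sketch (not part of the library; the shapes below are
# assumptions for this example). With batch B=2, time T=4, N=8 heads, and
# H=16 dims per head, the embedding table must cover all relative offsets
# between -(T-1) and T-1, i.e. 2T-1 = 7 rows:
#
#   query = tf.random.normal([2, 4, 8, 16])        # [B, T, N, H]
#   abs_pos_emb = tf.random.normal([7, 8, 16])     # [2T-1, N, H]
#   bias = _RelPositionBias(query, abs_pos_emb)    # -> [2, 8, 4, 4]
#
# bias[b, n, i, j] combines query position i with the embedding row for the
# relative offset between positions i and j.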
def _RelPositionBiasCausal(query, abs_pos_emb):
  """Computes relative position bias for causal self attention."""
  _, t, n, h = py_utils.GetShape(query)
  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

  # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
  # Retain only half and change the order to [T-1, T-2, ... 0].
  # [T, N, H]
  abs_pos_emb = tf.reverse(abs_pos_emb, [0])[:t]

  # [B, N, T, L=T]
  term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

  # Perform shifting.
  term_bd = tf.reverse(term_bd, [2, 3])
  term_bd = RelShift(term_bd)
  return tf.reverse(term_bd, [2, 3])
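# Illustrative usage sketch (assumed shapes, not part of the library). The
# causal variant consumes the same [2T-1, N, H] table as the general case
# but keeps only the half covering the T causal offsets, since a causal
# query never attends to later keys:
#
#   query = tf.random.normal([2, 4, 8, 16])          # [B, T, N, H]
#   abs_pos_emb = tf.random.normal([7, 8, 16])       # [2T-1, N, H]
#   bias = _RelPositionBiasCausal(query, abs_pos_emb)  # -> [2, 8, 4, 4]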
def add_point_cloud(self, feature, laser_names, range_image_pose):
  """Convert the range images in `feature` to 3D point clouds.

  Adds the point cloud data to the tf.Example feature map.

  Args:
    feature: A tf.Example feature map.
    laser_names: A list of laser names (e.g., 'TOP', 'REAR', 'SIDE_LEFT').
    range_image_pose: A range image pose Tensor for the GBR.
  """
  for laser_name in laser_names:
    beam_inclinations = np.array(
        feature['%s_beam_inclinations' % laser_name].float_list.value[:])
    # beam_inclinations will be populated if there is a non-uniform
    # beam configuration (e.g., for the TOP lasers). Lasers with uniform
    # beam inclinations are parameterized only by the min and max. We use
    # this min and max if beam_inclinations is not present, and turn them
    # into a uniform inclinations array.
    if beam_inclinations.size == 0:
      beam_inclination_min = feature['%s_beam_inclination_min' %
                                     laser_name].float_list.value[:]
      beam_inclination_max = feature['%s_beam_inclination_max' %
                                     laser_name].float_list.value[:]
      laser_ri_name = '%s_ri1' % laser_name
      range_image_shape = feature[laser_ri_name +
                                  '_shape'].int64_list.value[:]
      height = tf.cast(range_image_shape[0], tf.float32)
      beam_inclinations = tf.constant(
          [beam_inclination_min[0], beam_inclination_max[0]])
      beam_inclinations = range_image_utils.compute_inclination(
          beam_inclinations, height)

    beam_extrinsics = np.array(
        feature['%s_extrinsics' % laser_name].float_list.value[:]).reshape(
            4, 4)

    for ri_type in ['ri1', 'ri2']:
      laser_ri_name = '%s_%s' % (laser_name, ri_type)
      # For each of the 4 features of the lasers:
      range_image = np.array(feature[laser_ri_name].float_list.value[:])
      range_image_shape = feature[laser_ri_name +
                                  '_shape'].int64_list.value[:]
      range_image = range_image.reshape(range_image_shape)
      # Compute mask. At the moment, invalid values in the range image
      # representation are indicated via a -1. entry. Callers are expected
      # to create this mask when passing into the conversion function below.
      range_image_mask = range_image[..., 0] >= 0

      # Get the 'range' feature from the range images.
      range_image_range = range_image[..., 0]

      # Call utility to convert point cloud to cartesian coordinates.
      #
      # API expects a batch dimension for all inputs.
      batched_pixel_pose = None
      batched_frame_pose = None
      # At the moment, only the GBR has per-pixel pose.
      if laser_name == 'TOP':
        batched_pixel_pose = range_image_pose[tf.newaxis, ...]
        batched_frame_pose = self.frame_pose[tf.newaxis, ...]

      batched_range_image_range = tf.convert_to_tensor(
          range_image_range[np.newaxis, ...], dtype=tf.float32)
      batched_extrinsics = tf.convert_to_tensor(
          beam_extrinsics[np.newaxis, ...], dtype=tf.float32)
      batched_inclinations = tf.convert_to_tensor(
          beam_inclinations[np.newaxis, ...], dtype=tf.float32)

      batched_inclinations = tf.reverse(batched_inclinations, axis=[-1])

      range_image_cartesian = (
          range_image_utils.extract_point_cloud_from_range_image(
              batched_range_image_range,
              batched_extrinsics,
              batched_inclinations,
              pixel_pose=batched_pixel_pose,
              frame_pose=batched_frame_pose))

      points_xyz = tf.gather_nd(range_image_cartesian[0],
                                tf.where(range_image_mask))

      # Fetch the features corresponding to each xyz coordinate and
      # concatenate them together.
      points_features = tf.cast(
          tf.gather_nd(range_image[..., 1:], tf.where(range_image_mask)),
          tf.float32)
      points_data = tf.concat([points_xyz, points_features], axis=-1)

      # Add laser feature to output.
      #
      # Skip embedding shape since we assume that all points have six
      # features and so we can reconstruct the number of points.
      points_list = list(points_data.numpy().reshape([-1]))
      feature['laser_%s' % laser_ri_name].float_list.value[:] = points_list
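# Illustrative reader-side sketch (assumed feature key; mirrors the writer
# above and is not part of the library). Because the flattened floats are
# written without an explicit shape, a consumer can recover the
# [num_points, 6] layout by fixing the known per-point feature count:
#
#   raw = np.array(feature['laser_TOP_ri1'].float_list.value[:])
#   points = raw.reshape([-1, 6])  # xyz plus 3 remaining range-image features
#   points_xyz, points_features = points[:, :3], points[:, 3:]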
def FProp(self, theta, inputs, paddings, state0=None, segment_id=None):
  """Computes LSTM forward pass.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer
      and its children layers.
    inputs: A single tensor or a tuple of tensors with cardinality equal to
      rnn_cell.inputs_arity. For every input tensor, the first dimension is
      assumed to be time, second dimension batch, and third dimension depth.
    paddings: A tensor. First dim is time, second dim is batch, and third
      dim is expected to be 1.
    state0: If not None, the initial rnn state in a `.NestedMap`. Defaults
      to the cell's zero-state.
    segment_id: A tensor to support packed inputs. First dim is time, second
      dim is batch, and third dim is expected to be 1.

  Returns:
    A tensor of [time, batch, dims] activations, and the final recurrent
    state.
  """
  p = self.params
  assert isinstance(self.cell, rnn_cell.RNNCell)

  if not isinstance(inputs, (list, tuple)):
    inputs = [inputs]

  # Slicing wm to wm_{i,h} outside the loop to get a 20% speedup over the
  # regular LSTM baseline. Keeping the slicing within the loop gives only a
  # < 3% speedup.
  cell_theta = theta.cell.copy()
  num_input_nodes = p.cell.num_input_nodes
  cell_theta['wm_i'] = cell_theta.wm[:num_input_nodes, :]
  cell_theta['wm_h'] = cell_theta.wm[num_input_nodes:, :]
  tf.logging.vlog(1, 'cell_theta: %r', cell_theta)

  if p.packed_input:
    assert segment_id is not None
    reset_mask = rnn_layers.GeneratePackedInputResetMask(
        segment_id, is_reverse=False)
    reset_mask = py_utils.HasShape(reset_mask, tf.shape(paddings))
  else:
    reset_mask = tf.zeros_like(paddings)

  if p.reverse:
    inputs = [tf.reverse(x, [0]) for x in inputs]
    paddings = tf.reverse(paddings, [0])
    reset_mask = tf.reverse(reset_mask, [0])

  if not state0:
    batch_size = py_utils.GetShape(paddings)[1]
    state0 = self.cell.zero_state(cell_theta, batch_size)

  # [T, B, H]
  proj_inputs = self.cell.ProjectInputSequence(
      cell_theta, py_utils.NestedMap(act=inputs))
  proj_inputs = py_utils.NestedMap(
      proj_inputs=proj_inputs, padding=paddings, reset_mask=reset_mask)

  acc_state, final_state = recurrent.Recurrent(
      theta=cell_theta,
      state0=state0,
      inputs=proj_inputs,
      cell_fn=self.cell.FPropWithProjectedInput,
      cell_type=self.cell.layer_type,
      accumulator_layer=self,
      allow_implicit_capture=p.allow_implicit_capture)

  act = self.cell.GetOutput(acc_state)
  if p.reverse:
    act = tf.reverse(act, [0])
  return act, final_state
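# Minimal sketch of why the wm split pays off (assumed shapes, not library
# code; the gate width 4 * num_hidden is an assumption for a standard LSTM).
# wm stacks the input and recurrent projections as
# [num_input_nodes + num_hidden, 4 * num_hidden]. Pre-splitting lets the
# input half be applied to the whole sequence in one batched matmul, leaving
# only the recurrent half inside the per-step loop:
#
#   wm_i = wm[:num_input_nodes, :]   # input projection, applied once per seq
#   wm_h = wm[num_input_nodes:, :]   # recurrent projection, applied per step
#   proj = tf.einsum('tbi,ij->tbj', x, wm_i)  # hoisted out of the recurrence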
def _XYZFromRangeImage(self,
                       lidar_image,
                       lidar_image_mask,
                       extrinsics,
                       inclinations,
                       pixel_pose=None,
                       frame_pose=None):
  """Extract the cartesian coordinates from the range image.

  Args:
    lidar_image: [H, W, C] range image Tensor.
    lidar_image_mask: [H, W] boolean indicating which 2d coordinates in the
      lidar image are present.
    extrinsics: [4, 4] float matrix representing transformation matrix to
      world coordinates.
    inclinations: [V] beam inclinations vector.
    pixel_pose: [64, 2650, 4, 4] tensor representing per pixel pose of GBR.
    frame_pose: [4, 4] matrix representing vehicle to world transformation.

  Returns:
    [H, W, 3] range image cartesian coordinates.
  """
  height, width, channels = py_utils.GetShape(lidar_image, 3)

  conversion_dtype = tf.float32
  lidar_image = tf.cast(lidar_image, conversion_dtype)
  extrinsics = tf.cast(extrinsics, conversion_dtype)
  inclinations = tf.cast(inclinations, conversion_dtype)
  inclinations = tf.reverse(inclinations, axis=[-1])

  az_correction = py_utils.HasShape(
      tf.atan2(extrinsics[1, 0], extrinsics[0, 0]), [])
  ratios = (tf.cast(tf.range(width, 0, -1), dtype=conversion_dtype) -
            .5) / tf.cast(width, conversion_dtype)
  ratios = py_utils.HasShape(ratios, [width])

  azimuth = (ratios * 2. - 1.) * np.pi - az_correction[..., tf.newaxis]
  azimuth = py_utils.HasShape(azimuth, [width])

  lidar_image_mask = lidar_image_mask[..., tf.newaxis]
  lidar_image_mask = tf.tile(lidar_image_mask, [1, 1, channels])
  lidar_image = tf.where(lidar_image_mask, lidar_image,
                         tf.zeros_like(lidar_image))
  lidar_image_range = lidar_image[..., 0]

  azimuth = py_utils.HasShape(azimuth[tf.newaxis, ...], [1, width])
  inclinations = py_utils.HasShape(inclinations[..., tf.newaxis],
                                   [height, 1])

  cos_azimuth = tf.cos(azimuth)
  sin_azimuth = tf.sin(azimuth)
  cos_incl = tf.cos(inclinations)
  sin_incl = tf.sin(inclinations)

  x = cos_azimuth * cos_incl * lidar_image_range
  y = sin_azimuth * cos_incl * lidar_image_range
  z = sin_incl * lidar_image_range

  lidar_image_points = tf.stack([x, y, z], -1)
  lidar_image_points = py_utils.HasShape(lidar_image_points,
                                         [height, width, 3])
  rotation = extrinsics[0:3, 0:3]
  translation = extrinsics[0:3, 3][tf.newaxis, ...]

  # Transform the image points in cartesian coordinates to the world
  # coordinate system using the extrinsics matrix.
  #
  # We first flatten the points, apply the rotation, reshape to restore the
  # original shape, and then apply the translation.
  lidar_image_points = tf.matmul(
      tf.reshape(lidar_image_points, [-1, 3]), rotation, transpose_b=True)
  lidar_image_points = tf.reshape(lidar_image_points, [height, width, 3])
  lidar_image_points += translation

  lidar_image_points = py_utils.HasShape(lidar_image_points,
                                         [height, width, 3])

  # GBR uses per pixel pose.
  if pixel_pose is not None:
    pixel_pose_rotation = pixel_pose[..., 0:3, 0:3]
    pixel_pose_translation = pixel_pose[..., 0:3, 3]
    lidar_image_points = tf.einsum(
        'hwij,hwj->hwi', pixel_pose_rotation,
        lidar_image_points) + pixel_pose_translation
    if frame_pose is None:
      raise ValueError('frame_pose must be set when pixel_pose is set.')
    # To the vehicle frame corresponding to the given frame_pose.
    # [4, 4]
    world_to_vehicle = tf.linalg.inv(frame_pose)
    world_to_vehicle_rotation = world_to_vehicle[0:3, 0:3]
    world_to_vehicle_translation = world_to_vehicle[0:3, 3]
    # [H, W, 3]
    lidar_image_points = tf.einsum(
        'ij,hwj->hwi', world_to_vehicle_rotation,
        lidar_image_points) + world_to_vehicle_translation[tf.newaxis,
                                                           tf.newaxis, :]

  return lidar_image_points
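# Quick numeric check of the spherical-to-cartesian step above (assumed
# values, illustrative only). A return at range r, azimuth a, and
# inclination i maps to:
#
#   x = r * cos(a) * cos(i), y = r * sin(a) * cos(i), z = r * sin(i)
#
# For r=10, a=0, i=0 this gives (10, 0, 0): a point straight ahead of the
# sensor, before the extrinsics rotate and translate it into the world frame.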
def GatherK(selected_pos, values, k, num_devices=1):
  """Gather up to k elements from given tensors at selected pos under SPMD.

  Example::

    # Input
    k = 3
    selected_pos = [
        [0, 0, 1, 1],
        [0, 1, 1, 0],
        [0, 0, 0, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1],  # topk(k=3) largest indices are selected in this row.
    ]
    value_2d = [
        [1, 3, 5, 7],
        [9, 11, 13, 15],
        [17, 19, 21, 23],
        [25, 27, 29, 31],
        [33, 35, 37, 39],
    ]

    # Output:
    output = [
        [0, 5, 7],
        [0, 11, 13],
        [0, 0, 0],
        [25, 27, 29],
        [35, 37, 39],
    ]

    # Output padding:
    output_padding = [
        [1, 0, 0],
        [1, 0, 0],
        [1, 1, 1],
        [0, 0, 0],
        [0, 0, 0],
    ]

  Args:
    selected_pos: a 0/1 2D tf.int32 tensor of shape [batch, time].
    values: a list of tensors, the rank of each is at least rank=2.
      [batch, time, ...].
    k: a scalar tf.int32 tensor or a Python int. On TPU, k must be a
      compile-time constant.
    num_devices: number of TPU devices used in xla_sharding SPMD.

  Returns:
    A tuple (output, padding).

    - output: a list of tensors of shape [batch, k, ...].
    - padding: a 2D 0/1 tensor of shape [batch, k], '1's are padded
      locations.
  """
  global_batch, seq_len = py_utils.GetShape(selected_pos, 2)
  if num_devices:
    device_batch = global_batch // num_devices
  else:
    device_batch = global_batch

  for i in range(len(values)):
    # Assert the first 2 dims of values[i] are [global_batch, seq_len].
    values[i] = py_utils.HasShape(values[i], [global_batch, seq_len], 2)

  # Indices are 1-based for now, to distinguish between padding and selected
  # locations.
  indices = 1 + tf.range(tf.shape(values[0])[1], dtype=tf.int32)
  # [1, seq_len]
  indices = tf.expand_dims(indices, axis=0)

  # If 0, the position is not selected.
  # [1, seq_len] * [global_batch, seq_len] => [global_batch, t]
  # -- topk --> [global_batch, k]
  topk_indices, _ = tf.math.top_k(
      indices * tf.cast(selected_pos, indices.dtype), k)

  # [global_batch, k], sorted in ascending order.
  indices = tf.reverse(topk_indices, [-1])
  # [global_batch, k], padded positions are '1's.
  padding = tf.cast(tf.equal(indices, 0), values[0].dtype)
  padding = Split(padding, 0, num_devices)

  # [global_batch, k], zero-based indices.
  mp_idx = tf.maximum(0, indices - 1)
  mp_idx = Split(mp_idx, 0, num_devices)

  # [device_batch, k]
  if num_devices > 1 and py_utils.use_tpu():
    mp_idx = xla_sharding.auto_to_manual_spmd_partition(
        mp_idx, xla_sharding.get_op_sharding(mp_idx.op))
  # [device_batch, k, 1]
  mp_idx = tf.expand_dims(mp_idx, -1)

  # [device_batch]
  batch_ids = tf.range(device_batch, dtype=tf.int32)
  # [device_batch, 1, 1]
  batch_ids = tf.reshape(batch_ids, [device_batch, 1, 1])
  # [device_batch, k, 1]
  batch_ids = tf.broadcast_to(batch_ids, [device_batch, k, 1])

  # [device_batch, k, 2]
  final_indices = tf.concat([batch_ids, mp_idx], axis=-1)

  output = []
  for v in values:
    # Begin manually partitioned gather.
    v = Split(v, 0, num_devices)
    v_shape = v.shape.as_list()
    if num_devices > 1 and py_utils.use_tpu():
      op_sharding = xla_sharding.get_op_sharding(v.op)
      v = xla_sharding.auto_to_manual_spmd_partition(v, op_sharding)
    # Returns [global_batch, k, ...] (a per-device batch under manual
    # partitioning).
    v_out = tf.gather_nd(v, final_indices)
    if num_devices > 1 and py_utils.use_tpu():
      v_shape[1] = k
      v_out = xla_sharding.manual_to_auto_spmd_partition(
          v_out, op_sharding, full_shape=tf.TensorShape(v_shape))
    output.append(v_out)
  return output, padding
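# Worked trace of the 1-based index trick above (illustrative, using the
# docstring's first row). With selected_pos = [0, 0, 1, 1] and k = 3:
#
#   indices                 = [1, 2, 3, 4]
#   indices * selected_pos  = [0, 0, 3, 4]
#   top_k(k=3)              = [4, 3, 0]    -> reversed: [0, 3, 4]
#   padding                 = [1, 0, 0]    (zeros mark unselected slots)
#   mp_idx = max(0, idx-1)  = [0, 2, 3]    (the padded slot clamps to 0)
#
# Padded slots therefore gather whatever sits at position 0; callers are
# expected to mask them out using `padding`, as the docstring's output
# example shows with its zeroed entries.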