def ComputeWer(hyps, refs):
    """Computes word errors in hypotheses relative to reference transcripts.

    Both inputs are whitespace-normalized first (leading/trailing whitespace
    stripped, runs of whitespace collapsed to one space), so a "word" is a
    space-separated token.

    Args:
      hyps: Hypotheses, represented as string tensors of shape [N].
      refs: References, represented as string tensors of shape [N].

    Returns:
      An int64 tensor, word_errs, of size [N, 2] where word_errs[i, 0] corresponds
      to the number of word errors in hyps[i] relative to refs[i]; word_errs[i, 1]
      corresponds to the number of words in refs[i].
    """

    def _NormalizeWhitespace(s):
        return tf.regex_replace(tf.strings.strip(s), r'\s+', ' ')

    hyps = _NormalizeWhitespace(hyps)
    refs = _NormalizeWhitespace(refs)

    hyps = py_utils.HasRank(hyps, 1)
    refs = py_utils.HasRank(refs, 1)
    hyps = py_utils.HasShape(hyps, tf.shape(refs))

    # Unnormalized (absolute) word-level edit distance per hyp/ref pair.
    # NOTE(review): the tokenizing operands were missing from this file and
    # have been restored as whitespace splits of hyps/refs; confirm.
    word_errors = tf.cast(
        tf.edit_distance(tf.string_split(hyps), tf.string_split(refs),
                         normalize=False), tf.int64)

    # Count number of spaces in reference, and increment by 1 to get total number
    # of words.
    ref_words = tf.cast(
        tf.strings.length(tf.regex_replace(refs, '[^ ]', '')) + 1, tf.int64)
    # Set number of words to 0 if the reference was empty.
    ref_words = tf.where(tf.equal(refs, ''),
                         tf.zeros_like(ref_words, tf.int64), ref_words)

    # Stack the two [N] vectors as columns of an [N, 2] result.
    return tf.concat(
        [tf.expand_dims(word_errors, -1),
         tf.expand_dims(ref_words, -1)],
        axis=1)
  def StreamStep(self, theta, inputs, paddings, state0):
    """Apply a single step of convolution to input_tensor.

    Only supports 1d causal convolution. Doesn't support dilation.

    Args:
      theta: A NestedMap of layer params.
      inputs: A Tensor of shape [b, t, 1, c]
      paddings: A 0/1 valued tensor of shape [b, t].
      state0: A NestedMap of tensors of the same struct as returned by
        (the rest of this sentence is missing from this file).

    Returns:
      outputs: A Tensor of shape [b, t, 1, c * channel_multiplier]
      padding: the same as input paddings.
      state1: A NestedMap of the same struct as input state
    """
    p = self.params
    assert p.filter_shape[1] == 1, (
        'StreamStep only supports 1d causal convolution.')
    assert p.filter_stride[0] == 1, ('StreamStep doesn\'t support striding')
    assert p.dilation_rate == (1, 1), ('StreamStep doesn\'t support dilation')

    # NOTE(review): the argument of tf.name_scope(...) and its closing `):` are
    # missing from this file.
    with tf.name_scope(
      inputs = py_utils.HasShape(inputs, [-1, -1, 1, p.filter_shape[2]])
      paddings = py_utils.HasShape(paddings, py_utils.GetShape(inputs)[:2])

      # Prepend the cached context to the new inputs; padded positions of the
      # new inputs are zeroed by the (1 - paddings) mask.
      # NOTE(review): the axis argument and closing paren of tf.concat(...) are
      # missing from this file.
      concat_inputs = tf.concat(
          [state0.context, inputs * (1 - py_utils.AppendDims(paddings, 2))],
      # NOTE(review): the input/filter arguments, any trailing arguments and the
      # closing paren of tf.nn.depthwise_conv2d(...) are missing from this file.
      outputs = tf.nn.depthwise_conv2d(
          strides=(1, 1, 1, 1),
          dilations=(1, 1),
      # Keep the last (filter height - 1) steps as context for the next call.
      new_context = concat_inputs[:, -(p.filter_shape[0] - 1):]
      return outputs, paddings, py_utils.NestedMap(context=new_context)
    def testSpeed(self):
        """Times `ops.non_max_suppression_3d` over several box/class counts.

        For every (num_bboxes, num_classes) pair the op is evaluated 10 times
        and the timings are printed. No assertions are visible in this block.
        """
        num_bboxes_list = [500, 1000, 10000]
        num_classes_list = [3, 10, 25]

        for num_bboxes in num_bboxes_list:
            for num_classes in num_classes_list:
                # NOTE(review): the remaining arguments and closing paren of
                # this tf.random.uniform(...) call are missing from this file.
                bboxes_3d = tf.random.uniform((num_bboxes, 7),
                # Make half zero so we can see behavior with very low values that
                # will get filtered out quickly.
                class_scores = tf.concat([
                    # NOTE(review): the tail of this tf.random.uniform(...) call
                    # is missing from this file.
                    tf.random.uniform((num_bboxes // 2, num_classes),
                    tf.zeros((num_bboxes // 2, num_classes), dtype=tf.float32)
                # NOTE(review): the closing of the tf.concat([...]) call is
                # missing from this file.

                with self.session():
                    # NOTE(review): the leading arguments of this call
                    # (presumably the boxes and scores built above) are missing
                    # from this file.
                    outputs = ops.non_max_suppression_3d(
                        nms_iou_threshold=[0.1] * num_classes,
                        score_threshold=[0.3] * num_classes)

                    # Wall-clock time of 10 consecutive evaluations.
                    timings = []
                    for _ in range(10):
                        start = time.time()
                        _ = self.evaluate(outputs)
                        end = time.time()
                        timings.append(end - start)
                    avg = sum(timings) / len(timings)
                    # NOTE(review): the fifth format argument and the closing
                    # parens of this print(...) are missing from this file.
                    print('[{},{},{},{},{}]'.format(num_bboxes, num_classes,
                                                    min(timings), avg,
    def FProp(self, theta, inputs):
        """Apply projection to inputs.

        Args:
          theta: A NestedMap object containing weights' values of this layer and its
            children layers.
          inputs: The inputs tensor.  Shaped [..., input_dims].

        Returns:
          Projected inputs.
        """
        p = self.params
        # NOTE(review): the argument of tf.name_scope(...), its closing `):`, and
        # the callee that receives (self, 'flops', <cost>) are missing from this
        # file. The visible cost expression is
        # prod(leading dims of inputs) * input_dims * output_dims * 2; the
        # `tf.cast(` that wraps symbolic.EvalExpr(...) is also missing.
        with tf.name_scope(
                self, 'flops',
                tf.reduce_prod(tf.cast(tf.shape(inputs)[:-1], tf.int64)) *
                    symbolic.EvalExpr(symbolic.TENSOR_VALUES, p.input_dims *
                                      p.output_dims), tf.int64) * 2)
            use_tpu = py_utils.use_tpu()
            shape = inputs.shape
            if use_tpu and (shape is not None and shape.rank is not None
                            and shape.rank < 26):
                # Avoids reshape if feasible and uses Einsum.
                if shape.rank == 2:
                    return tf.matmul(inputs, theta.w)
                # NOTE(review): an `else:` line appears to be missing here; the
                # three lines below are indented as its body.
                    s = ''.join([chr(x) for x in range(97, 123)])  # 'abc...xyz'
                    r = shape.rank
                    # NOTE(review): the second einsum operand and closing paren
                    # are missing from this file.
                    return tf.einsum('{0}y,yz->{0}z'.format(s[:r - 1]), inputs,

            # General path: flatten the leading dims, matmul, then restore them.
            input_dim = py_utils.GetShape(inputs)[-1]
            act = tf.matmul(tf.reshape(inputs, [-1, input_dim]), theta.w)
            output_dim = tf.shape(theta.w)[-1]
            act = tf.reshape(
                act, tf.concat([tf.shape(inputs)[:-1], [output_dim]], axis=0))
            return act
  def _CellFeaturizer(self, theta, input_batch):
    """Featurizes each center location.

    Args:
      theta: A NestedMap of variable values; `theta.cell_featurizer` is passed
        to the `cell_featurizer` child layer.
      input_batch: A NestedMap that provides `cell_feature` (rank 4),
        `cell_points_xyz` [batch, centers, points_per_cell, 3],
        `cell_center_xyz` [batch, centers, 3] and `cell_points_padding`
        [batch, centers, points_per_cell].

    Returns:
      The featurizer output, checked to have shape [batch, centers, -1].
    """
    # Validate Shapes
    cell_feature = py_utils.HasRank(input_batch.cell_feature, 4)
    batch_size, num_centers, num_points_per_cell = py_utils.GetShape(
        cell_feature, 3)

    cell_points_xyz = py_utils.HasShape(
        input_batch.cell_points_xyz,
        [batch_size, num_centers, num_points_per_cell, 3])
    cell_center_xyz = py_utils.HasShape(input_batch.cell_center_xyz,
                                        [batch_size, num_centers, 3])

    cell_points_padding = py_utils.HasShape(
        input_batch.cell_points_padding,
        [batch_size, num_centers, num_points_per_cell])

    # Center each cell: add a points axis to the centers so the subtraction
    # broadcasts over every point of the cell.
    cell_center_xyz = tf.reshape(cell_center_xyz,
                                 [batch_size, num_centers, 1, 3])
    centered_cell_points_xyz = cell_points_xyz - cell_center_xyz
    concat_feature = tf.concat([
        centered_cell_points_xyz, cell_feature
    ], axis=-1)  # pyformat: disable

    # Featurize point clouds at each center.
    point_input = py_utils.NestedMap({
        'points': centered_cell_points_xyz,
        'features': concat_feature,
        'padding': cell_points_padding,
    })
    featurized_cell = self.cell_featurizer.FProp(theta.cell_featurizer,
                                                 point_input)
    featurized_cell = py_utils.HasShape(featurized_cell,
                                        [batch_size, num_centers, -1])
    return featurized_cell
    def testCuDNNInitializerWrapper(self):
        """Checks the opaque params produced by CuDNNLSTMInitializer.

        With an all-ones base initializer the opaque buffer is expected to be
        the weights (all ones) followed by the biases (all zeros), for every
        combination of direction, input size, cell size and dtype.
        """
        # Nothing to check without a CUDA GPU.
        if not tf.test.is_gpu_available(cuda_only=True):
            return
        dirs = [UNIDIR, BIDIR]
        input_nodes = [128, 256, 512, 1024]
        cell_nodes = [128, 256, 512, 1024]
        dtypes = [tf.float32, tf.float64]
        for direction, input_dim, cell_dim, dtype in itertools.product(
                dirs, input_nodes, cell_nodes, dtypes):
            with self.session(use_gpu=True, graph=tf.Graph()):
                base_init = tf.ones_initializer()
                cudnn_initializer = cudnn_rnn_utils.CuDNNLSTMInitializer(
                    input_dim, cell_dim, direction)
                actual = cudnn_initializer.InitOpaqueParams(dtype, base_init)
                num_dir = 1 if direction == UNIDIR else 2
                # Both pieces are rank 1, so axis 0 is the only valid concat
                # axis.
                expected = tf.concat([
                    tf.ones([num_dir * 4 * cell_dim * (cell_dim + input_dim)],
                            dtype=dtype),
                    tf.zeros([num_dir * 8 * cell_dim], dtype=dtype)
                ], axis=0)

                self.assertAllClose(expected, actual)
  def _OutfeedDequeue(self):
    """Collect outfeed dequeue from all devices.

    Returns:
      A list of tensors corresponding to stacked decoded outputs. The decoder
      outputs are stacked on the first dimension (usually corresponds to
      batch size).
    """
    # Flatten once; the dtypes/shapes are the same for every core.
    decode_tensors = self.decode_nm.Flatten()
    dtypes = [x.dtype for x in decode_tensors]
    shapes = [x.shape for x in decode_tensors]
    # One independent accumulator per decode tensor (avoid `[[]] * n`, which
    # aliases a single list object n times).
    outfeed_ops = [[] for _ in decode_tensors]
    device_assignment = py_utils.GetTpuDeviceAssignment()
    assert device_assignment
    # Under SPMD only one core per replica is dequeued.
    num_cores_per_replica = (1 if self.spmd else
                             device_assignment.num_cores_per_replica)
    for replica in range(device_assignment.num_replicas):
      for core in range(num_cores_per_replica):
        with tf.device(device_assignment.host_device(replica, core)):
          outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
              dtypes=dtypes,
              shapes=shapes,
              device_ordinal=device_assignment.tpu_ordinal(replica, core))
          for idx_outfeed, out_feed in enumerate(outfeeds_per_core):
            outfeed_ops[idx_outfeed].append(out_feed)
    return [tf.concat(per_outfeed, axis=0) for per_outfeed in outfeed_ops]
def ComputeSparseAttention(q, k, v, sparsity_indices, paddings=None):
  """Computes attention according to a sparsity pattern.

  We use the following capital letters to denote shape parameters:
    B = batch size
    S = length of the source sequence
    T = length of the target sequence
    N = number of attention heads
    H = dimensions of each attention head
    K = number of clusters
    W = attention window (W <= S)

  The 'sparsity_indices' is a tensor of integral type where the last dimension
  contains W indices (W is the attention window) for each corresponding position
  along S in 'k' that the query is allowed to attend to.

  For example, if sparsity_indices[batch_idx, target time step, head_idx] =
  [1, 7, 8], it means that token in the query attends to values with indices
  1, 7, and 8, and the attention window here is 3.

  The valid values in 'sparsity_indices' are [-1, S-1]. Note that the value -1
  is reserved to mean paddings, distinct from the value (S-1).

  For example, if W=S and 'sparsity_indices' contains range(S) on the last
  dimension, this degenerates to the original full attention.

  We require that 'sparsity_indices' does not contain duplicates (except for -1
  to indicate paddings), but we do not require 'sparsity_indices' to be sorted.

  Note that this implementation is flexible and generic but is not optimized for
  time or space complexity. Please consider grouping queries that attend to the
  same subset of values first for efficiency.

  Args:
    q: (projected) queries, [B, T, N, H];
    k: (projected) keys, [B, S, N, H];
    v: (projected) values, [B, S, N, H];
    sparsity_indices: [B, T, N, W], where W is the attention window;
    paddings: paddings for keys, [B, S] if not None.

  Returns:
    output: the encoded output, [B, T, N, H].
    atten_probs: the attention weights, [B, T, N, S].
  """
  q = tf.convert_to_tensor(q)
  k = tf.convert_to_tensor(k)
  v = tf.convert_to_tensor(v)
  sparsity_indices = tf.convert_to_tensor(sparsity_indices)

  k = py_utils.HasRank(k, 4)
  _, source_length, _, dim_per_head = py_utils.GetShape(k, 4)
  sparsity_indices = py_utils.HasRank(sparsity_indices, 4)
  batch_size, target_length, num_heads, attention_window = py_utils.GetShape(
      sparsity_indices, 4)
  # NOTE(review): the callee of this check was missing from the file and has
  # been restored as py_utils.assert_less_equal; confirm against the project's
  # py_utils before relying on it.
  py_utils.assert_less_equal(
      attention_window, source_length,
      'The provided sparsity_indices has attention window '
      ' > source length. This is likely an error.')

  # To prepare for gathering the relevant vectors from 'k', we prepare
  # gather_idx of shape [B, T, N, W, 3] where the last dimension corresponds to
  # slices in 'k' indexed by (batch index, source time step, head index),
  # where the source length index comes from the original W dimension in
  # 'sparsity_indices'.
  seq_idx = tf.expand_dims(sparsity_indices, axis=-1)
  # Overwrite the paddings -1 with valid gather indices (zeros). We will
  # fix the logits with -inf in these positions later.
  seq_idx = tf.where(seq_idx < 0, tf.zeros_like(seq_idx), seq_idx)
  batch_idx = tf.reshape(
      tf.range(0, batch_size, dtype=sparsity_indices.dtype),
      [batch_size, 1, 1, 1, 1])
  batch_idx = tf.tile(batch_idx,
                      [1, target_length, num_heads, attention_window, 1])
  head_idx = tf.reshape(
      tf.range(0, num_heads, dtype=sparsity_indices.dtype),
      [1, 1, num_heads, 1, 1])
  head_idx = tf.tile(head_idx,
                     [batch_size, target_length, 1, attention_window, 1])
  # [B, T, N, W, 3], where last dimension is (batch index, source length index,
  # head index).
  gather_idx = tf.concat([batch_idx, seq_idx, head_idx], axis=-1)

  # Both the gathered k and v have shape [B, T, N, W, H]
  k = tf.gather_nd(k, gather_idx)
  v = tf.gather_nd(v, gather_idx)

  if paddings is None:
    paddings = tf.zeros([batch_size, source_length])
  paddings = tf.convert_to_tensor(paddings)
  paddings = tf.expand_dims(paddings, axis=-1)
  # [B, S, N]
  paddings = tf.tile(paddings, [1, 1, num_heads])
  # [B, T, N, W]
  paddings = tf.gather_nd(paddings, gather_idx)

  logits = tf.einsum('BTNH, BTNWH -> BTNW', q, k)
  logits *= tf.math.rsqrt(tf.cast(dim_per_head, q.dtype))

  # A large finite negative value (rather than -inf) for masked positions.
  very_negative_logits = (
      tf.ones_like(logits) * logits.dtype.max *
      tf.constant(-0.7, dtype=logits.dtype))
  padded_logits = tf.where(
      tf.math.logical_or(sparsity_indices < 0, paddings > 0.0),
      very_negative_logits, logits)

  # [B, T, N, W]
  atten_probs = tf.nn.softmax(padded_logits, name='attention_weights')
  # Force exactly zero weight on the -1 (padding) slots; they were redirected
  # to source index 0 for the gather and must not contribute there.
  atten_probs = tf.where(sparsity_indices < 0, tf.zeros_like(logits),
                         atten_probs)
  output = tf.einsum('BTNW, BTNWH -> BTNH', atten_probs, v)

  # Scatter 'atten_probs' back into the original source length.
  # [B, T, N, W, 1]
  batch_idx = tf.tile(
      tf.range(batch_size)[:, None, None, None, None],
      [1, target_length, num_heads, attention_window, 1])
  # [B, T, N, W, 1]
  target_seq_idx = tf.tile(
      tf.range(target_length)[None, :, None, None, None],
      [batch_size, 1, num_heads, attention_window, 1])
  # [B, T, N, W, 1]
  head_idx = tf.tile(
      tf.range(num_heads)[None, None, :, None, None],
      [batch_size, target_length, 1, attention_window, 1])
  # seq_idx: [B, T, N, W, 1]
  # [B, T, N, W, 4]
  scatter_idx = tf.concat([batch_idx, target_seq_idx, head_idx, seq_idx], -1)
  # [B, T, N, S]
  scattered_probs = tf.scatter_nd(
      scatter_idx, atten_probs,
      [batch_size, target_length, num_heads, source_length])
  return output, scattered_probs
    def FProp(self, theta, inputs, paddings):
        """Apply convolution to inputs.

        Args:
          theta: A `.NestedMap` object containing weights' values of this layer and
            its children layers.
          inputs: The inputs tensor. It is expected to be of shape [batch, time,
            frequency, channel]. The time dimension corresponds to the height
            dimension as in images and the frequency dimension corresponds to the
            width dimension as in images.
          paddings: The paddings tensor, expected to be of shape [batch, time].

        Returns:
          outputs, out_paddings pair.
        """
        p = self.params
        # NOTE(review): the argument of tf.name_scope(...) and its closing `):`
        # are missing from this file.
        with tf.name_scope(
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(paddings), [-1, -1]),
                # NOTE(review): the opening lines of a second shape assertion
                # (on `inputs`, judging by the input_channels term below) are
                # missing from this file.
                        [-1, symbolic.ToStatic(self.input_channels)]
                    ], 0))
            ], inputs)

            def _ApplyPadding(tensor_in, padding_in):
                # Expand [batch, time] padding to rank 4 so it broadcasts over
                # the last two axes of tensor_in.
                padding_expanded = tf.expand_dims(
                    tf.expand_dims(padding_in, -1), -1)
                return tensor_in * (1.0 - padding_expanded)

            # Zeroing out padded inputs.
            inputs = _ApplyPadding(inputs, paddings)

            # Apply conv on 'inputs'.
            if p.v2_padding:
                padded_inputs, slice_len = _PadForLengthCompatibleStridesV2(
                    inputs, p.filter_stride[0], 'SAME', 0.)
                out = self._ApplyConv(theta, padded_inputs)
                if p.filter_stride[0] > 1:
                    slice_end = py_utils.GetShape(out)[1] - slice_len
                    out = out[:, :slice_end, :, :]
                # NOTE(review): an `else:` line appears to be missing here; as
                # written the next statement overwrites the v2-padded result.
                out = self._ApplyConv(theta, inputs)

            if p.partial_conv:
                out = self._RescaleBoundary(out, paddings)
            # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1.
            # But there's likely no real problems. Trying to set it gives an error:
            # pooling with SAME padding is not implemented for dilation_rate > 1.
            # NOTE(review): a comment line appears to be missing before the next
            # line, which starts mid-sentence.
            # implementation. Consider updating it to be the actual shape.
            # NOTE(review): the arguments of both padding calls below, and the
            # `else:` between them, are missing from this file.
            if p.v2_padding:
                conv_padding = _ComputeConvOutputPaddingV2(
                conv_padding = ComputeConvOutputPadding(

            # Assuming padded nodes will be properly zero-ed out if necessary by
            # subsequent layers.
            # out = _ApplyPadding(out, conv_padding)
            out = py_utils.HasShape(
                out, symbolic.ToStatic(self.OutShape(tf.shape(inputs))))
            return out, conv_padding
 def _PadZeros():
     """Returns one zero tensor per entry of `tensor_list`.

     Each result has `num_points_out` as its leading dimension, the remaining
     dimensions of the corresponding input, and the input's dtype.
     """
     return [
         tf.zeros(
             shape=tf.concat([[num_points_out], tf.shape(src)[1:]], axis=0),
             dtype=src.dtype) for src in tensor_list
     ]
    def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
        # One decoding step: records the current ids/positions in the history,
        # calls `dec_callback` for logits, updates the n-best list with the
        # length-normalized EOS scores, then selects the next `beam_size`
        # hypotheses (optionally via the extension buffer) and advances `t`.
        # NOTE(review): the original def line carried residue of what look like
        # two dropped logging calls ('loop_vars=%r' / 'dec_state=%r').
        (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
         hist) = loop_vars
        (ext_id, ext_score, ext_mask) = ext
        (h_tgt_id, h_tgt_pos) = hist
        h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
        h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
        # not using tf.ones() here because of XLA compilation error
        tgt_segment_id = tgt_id * 0 + 1
        logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                         tgt_mask, dec_state, t)
        # take predicted EOS score for each hyp and compute normalized score
        eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

        def length_norm(t):
            # Length penalty ((t + 5) / 5) ** alpha.
            t = tf.cast(t, fprop_dtype)
            alpha = length_norm_alpha
            # NOTE(review): the next line is the tail of a call whose callee is
            # missing from this file (it looks like a logging call).
  'length_norm.alpha=%r', alpha)
            return tf.math.pow((t + 5.) / 5., alpha)

        hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
        eos_score_norm = eos_score / length_norm(hyp_len)
        # update the n-best list
        nbest_hyps = update_nbest(nbest_hyps,
                                  (tgt_mask, hyp_score, eos_score_norm))

        if debug:
            tpu_summary.tensor('eos_score', eos_score)
            tpu_summary.tensor('hyp_len', hyp_len)

        # take top k tokens for each hyp
        k = beam_size
        with tf.name_scope('topk1'):
            top_score, top_id = top_k_fn(logits, k)
            top_score = tf.cast(top_score, fprop_dtype)

        top_score += tf.expand_dims(hyp_score, -1)
        # Push EOS continuations far down so they are not selected as
        # extensions.
        top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

        top_score = tf.reshape(top_score, [batch_size, beam_size * k])
        top_id = tf.reshape(top_id, [batch_size, beam_size * k])
        top_mask = tf.repeat(tgt_mask, beam_size, 1)

        if debug:
            tpu_summary.tensor('top_id', top_id)
            tpu_summary.tensor('top_score', top_score)
            # tpu_summary.tensor('top_mask', top_mask)

        with tf.name_scope('update_ext'):
            # combine top k tokens with extension buffer (if any)
            if ext_size:
                ext_id = tf.concat([ext_id, top_id], 1)
                ext_score = tf.concat([ext_score, top_score], 1)
                ext_mask = tf.concat([ext_mask, top_mask], 1)
                # NOTE(review): an `else:` line appears to be missing here; as
                # written the next line discards the concatenations above.
                ext_id, ext_score, ext_mask = top_id, top_score, top_mask

            # sort by score
            ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
            # One-hot selection matrix used to reorder mask and ids by einsum.
            i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
            ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
            ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

            # pick top beam_size extensions to evaluate at next iteration
            if ext_size:
                hyp_score = ext_score[:, :beam_size]
                ext_score = ext_score[:, beam_size:]
                tgt_id = ext_id[:, :beam_size]
                ext_id = ext_id[:, beam_size:]
                tgt_mask = ext_mask[:, :beam_size]
                ext_mask = ext_mask[:, beam_size:]
                # NOTE(review): an `else:` line appears to be missing here; the
                # two lines below look like the no-extension-buffer branch.
                hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
                ext_score = ext_id = ext_mask = 0

        tgt_pos = tf.reduce_sum(tgt_mask, -1)
        tgt_pos = tf.cast(tgt_pos, tf.int32)

        t += 1
        with tf.name_scope('tgt_mask_extend'):
            # NOTE(review): the depth/dtype arguments and closing paren of this
            # tf.one_hot(...) call are missing from this file.
            tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,

        ext = (ext_id, ext_score, ext_mask)
        hist = (h_tgt_id, h_tgt_pos)
        # NOTE(review): the text after `hist)` on the next statement is residue
        # of dropped calls (it looks like logging); the line is not valid Python
        # as it stands.
        loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                     hist)'loop_vars=%r', loop_vars)'dec_state=%r', dec_state)
        return loop_vars, dec_state
 def _MergeLeft():
     """Rebuilds the candidates up to and including position best_id - 1.

     Keeps `candidates[:best_id - 1]` unchanged and appends whatever
     `_MergeOneToken(tokens, best_id - 1)` returns for the position just left
     of `best_id`.
     """
     # NOTE(review): the closing of this tf.concat([...]) call was missing from
     # the file; axis=0 assumes rank-1 pieces — confirm.
     return tf.concat([
         candidates[:best_id - 1],
         _MergeOneToken(tokens, best_id - 1)
     ], axis=0)
    def FProp(self, theta, input_batch):
        # pyformat: disable
        """Compute features for the pillars and convert them back to a dense grid.

        Args:
          theta: A `.NestedMap` object containing variable values of this task.
          input_batch: A `.NestedMap` object containing input tensors. Following
            keys are required:

            - grid_num_points: Integer tensor with shape [batch size, nx, ny, nz],
              where nx, ny, nz corresponds to the grid sizes (i.e., number of voxels
              in each axis dimension).
            - pillar_points: Float tensor with shape [batch size, num_pillars,
              num_points_per_pillar, 3 + num_laser_features]
            - pillar_centers: Float tensor with shape [batch size, num_pillars,
              num_points_per_pillar, 3]
            - pillar_locations: Float tensor with shape [batch size, num_pillars, 3]

        Returns:
          The dense features with shape [b, nx, ny, nz * fdims].
        """
        # pyformat: enable
        p = self.params
        bs, nx, ny, nz = py_utils.GetShape(input_batch.grid_num_points, 4)
        # Process points to concatenate a set of fixed features (e.g.,
        # add means, centers, normalize points to means).
        num_features = 3 + p.num_laser_features
        pillar_points = py_utils.HasShape(input_batch.pillar_points,
                                          [bs, -1, -1, num_features])
        _, npillars, npoints, _ = py_utils.GetShape(pillar_points, 4)
        pillar_xyz = pillar_points[..., :3]

        # Compute number of points per pillar and prepare for broadcasting.
        # NOTE(review): the index argument(s) and closing paren of this
        # tf.gather_nd(...) call are missing from this file.
        pillar_num_points = tf.gather_nd(input_batch.grid_num_points,
        pillar_num_points = pillar_num_points[..., tf.newaxis, tf.newaxis]

        # Compute mean by computing sum and dividing by number of points. Clip the
        # denominator by 1.0 to gracefully handle empty pillars.
        pillar_sum = tf.reduce_sum(pillar_xyz, axis=2, keep_dims=True)
        pillar_means = pillar_sum / tf.maximum(
            tf.cast(pillar_num_points, tf.float32), 1.0)

        pillar_feats = pillar_points[..., 3:]
        # NOTE(review): the shape checked here is [bs, -1, 1, 3], whereas the
        # docstring lists pillar_centers with num_points_per_pillar on axis 2 —
        # confirm which is intended.
        pillar_centers = py_utils.HasShape(input_batch.pillar_centers,
                                           [bs, -1, 1, 3])
        # NOTE(review): this tf.concat(...) call is truncated in this file: the
        # opening of the values list and the callees that take the
        # [1, 1, npoints, 1] multiples are missing.
        pillar_concat = tf.concat(axis=3,
                                      pillar_xyz - pillar_means, pillar_feats,
                                              [1, 1, npoints, 1]),
                                              [1, 1, npoints, 1])
        # Featurize pillars.
        # NOTE(review): the input argument and closing paren of this FProp(...)
        # call are missing from this file.
        pillar_features = self.featurizer.FProp(theta.featurizer,

        # Convert back to the dense grid.
        pillar_locations = py_utils.HasShape(input_batch.pillar_locations,
                                             [bs, npillars, 3])
        # NOTE(review): the remaining arguments and closing paren of this
        # SparseToDense(...) call are missing from this file.
        dense_features = SparseToDense(grid_shape=(nx, ny, nz),
        return dense_features
 def _Concat(self, name, *subs):
     r"""Concatenate outputs from \*subs along the last dimensions."""
     return self._Par(name, lambda *xs: tf.concat(xs, axis=-1), *subs)
    def add_point_cloud(self, feature, laser_names, range_image_pose):
        """Convert the range images in `feature` to 3D point clouds.

        Adds the point cloud data to the tf.Example feature map.

        Args:
          feature: A tf.Example feature map.
          laser_names: A list of laser names (e.g., 'TOP', 'REAR', 'SIDE_LEFT').
          range_image_pose: A range image pose Tensor for the GBR.
        """
        for laser_name in laser_names:
            # NOTE(review): statement truncated in this file — the tail of the
            # feature lookup and the closing of np.array(...) are missing.
            beam_inclinations = np.array(
                feature['%s_beam_inclinations' %
            # beam_inclinations will be populated if there is a non-uniform
            # beam configuration (e.g., for the TOP lasers).  Others that have
            # uniform beam inclinations are only parameterized by the min and max.
            # We use these min and max if the beam_inclinations are not present,
            # and turn them into a uniform inclinations array.
            if beam_inclinations.size == 0:
                # NOTE(review): both lookups below are truncated in this file.
                beam_inclination_min = feature['%s_beam_inclination_min' %
                beam_inclination_max = feature['%s_beam_inclination_max' %

                laser_ri_name = '%s_ri1' % laser_name
                # NOTE(review): statement truncated in this file.
                range_image_shape = feature[laser_ri_name +
                height = tf.cast(range_image_shape[0], tf.float32)

                beam_inclinations = tf.constant(
                    [beam_inclination_min[0], beam_inclination_max[0]])
                beam_inclinations = range_image_utils.compute_inclination(
                    beam_inclinations, height)

            beam_extrinsics = np.array(
                feature['%s_extrinsics' %
                        laser_name].float_list.value[:]).reshape(4, 4)

            for ri_type in ['ri1', 'ri2']:
                laser_ri_name = '%s_%s' % (laser_name, ri_type)
                # For each of the 4 features of the lasers:
                # NOTE(review): both statements below are truncated in this file.
                range_image = np.array(
                range_image_shape = feature[laser_ri_name +
                range_image = range_image.reshape(range_image_shape)
                # Compute mask.  At the moment, invalid values in the range image
                # representation are indicated via a -1. entry.  Callers are expected
                # to create this mask when passing into the conversion function below.
                range_image_mask = range_image[..., 0] >= 0

                # Get the 'range' feature from the range images.
                range_image_range = range_image[..., 0]

                # Call utility to convert point cloud to cartesian coordinates.
                # API expects a batch dimension for all inputs.
                batched_pixel_pose = None
                batched_frame_pose = None
                # At the moment, only the GBR has per-pixel pose.
                if laser_name == 'TOP':
                    batched_pixel_pose = range_image_pose[tf.newaxis, ...]
                    batched_frame_pose = self.frame_pose[tf.newaxis, ...]

                batched_range_image_range = tf.convert_to_tensor(
                    range_image_range[np.newaxis, ...], dtype=tf.float32)
                batched_extrinsics = tf.convert_to_tensor(
                    beam_extrinsics[np.newaxis, ...], dtype=tf.float32)
                batched_inclinations = tf.convert_to_tensor(
                    beam_inclinations[np.newaxis, ...], dtype=tf.float32)

                # NOTE(review): the axis argument and closing paren of this
                # tf.reverse(...) call are missing from this file.
                batched_inclinations = tf.reverse(batched_inclinations,

                # NOTE(review): the conversion call that produces
                # range_image_cartesian is missing from this file.
                range_image_cartesian = (

                # NOTE(review): the index argument and closing paren of this
                # tf.gather_nd(...) call are missing from this file.
                points_xyz = tf.gather_nd(range_image_cartesian[0],

                # Fetch the features corresponding to each xyz coordinate and
                # concatenate them together.
                points_features = tf.cast(
                    tf.gather_nd(range_image[..., 1:],
                                 tf.where(range_image_mask)), tf.float32)
                points_data = tf.concat([points_xyz, points_features], axis=-1)

                # Add laser feature to output.
                # Skip embedding shape since we assume that all points have six features
                # and so we can reconstruct the number of points.
                points_list = list(points_data.numpy().reshape([-1]))
                feature['laser_%s' %
                        laser_ri_name].float_list.value[:] = points_list
    def ComputePredictions(self, theta, input_batch):
        """Computes predictions for `input_batch`.

        Args:
          theta: A `.NestedMap` object containing variable values of this task.
          input_batch: A `.NestedMap` expected to contain lasers.points_xyz,
            lasers.points_feature, lasers.points_padding, cell_center_xyz,
            cell_points_xyz, cell_feature, anchor_bboxes,
            anchor_localization_residuals, assigned_gt_labels, and
            assigned_cls_mask. See class doc string for details.

        Returns:
          A `.NestedMap` object containing residuals and classification_logits.
        """
        p = self.params
        input_batch.Transform(lambda x:
                              (x.shape, x.shape.num_elements())).VLog(
                                  1, 'input_batch shapes: ')
        cell_feature = py_utils.HasRank(input_batch.cell_feature, 4)
        batch_size, num_centers = py_utils.GetShape(cell_feature, 2)

        featurized_cell = self._CellFeaturizer(theta, input_batch)

        # Project each featurized_cell features to each bbox per center.
        featurized_anchors = self.cell_feature_projector.FProp(
            theta.cell_feature_projector, featurized_cell)

        # Reshape output so that we have features per offset.
        featurized_anchors = tf.reshape(
            featurized_anchors,
            [batch_size, num_centers, p.num_anchor_bboxes_offsets, -1])

        # Predict localization residuals.
        predicted_residuals = self.localization_regressor.FProp(
            theta.localization_regressor, featurized_anchors)
        predicted_residuals = tf.reshape(
            predicted_residuals,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])

        # Oracle modes: replace selected slices of the 7-dim residual
        # (location 0:3, dimension 3:6, rotation 6:) with the ground truth.
        if any([p.oracle_location, p.oracle_dimension, p.oracle_rotation]):
            gt_residuals = py_utils.HasShape(
                input_batch.anchor_localization_residuals,
                [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])
            residuals = []
            if p.oracle_location:
                residuals.append(gt_residuals[..., 0:3])
            else:
                residuals.append(predicted_residuals[..., 0:3])

            if p.oracle_dimension:
                residuals.append(gt_residuals[..., 3:6])
            else:
                residuals.append(predicted_residuals[..., 3:6])

            if p.oracle_rotation:
                residuals.append(gt_residuals[..., 6:])
            else:
                residuals.append(predicted_residuals[..., 6:])
            predicted_residuals = tf.concat(residuals, axis=-1)

        if p.squash_rotation_predictions:
            # Squash the rotation residual into (-pi, pi) with a scaled tanh.
            predicted_rotations = predicted_residuals[..., 6:]
            predicted_rotations = np.pi * tf.tanh(predicted_rotations)
            predicted_residuals = tf.concat(
                [predicted_residuals[..., :6], predicted_rotations], axis=-1)

        # Predict object classification at each bbox.
        predicted_classification_logits = self.classifier.FProp(
            theta.classifier, featurized_anchors)
        predicted_classification_logits = tf.reshape(
            predicted_classification_logits, [
                batch_size, num_centers, p.num_anchor_bboxes_per_center,
                p.num_classes
            ])

        if p.oracle_classification:
            assigned_gt_labels = py_utils.HasShape(
                input_batch.assigned_gt_labels,
                [batch_size, num_centers, p.num_anchor_bboxes_per_center])
            predicted_classification_logits = tf.one_hot(
                assigned_gt_labels, p.num_classes)

        return py_utils.NestedMap({
            'residuals': predicted_residuals,
            'classification_logits': predicted_classification_logits,
        })
    def test_max_assign_batch_version(self):
        """Compares batched `max_assignment` against per-example runs.

        A 2x2 and a 3x3 problem are solved on their own and also together in
        one batch, with the 2x2 problem padded to 3x3.
        """
        # 2x2 example
        score1 = tf.convert_to_tensor([[0.5, 1.0], [0.2, 0.6]])
        row_sums1 = tf.convert_to_tensor([1.0, 1.0])
        col_sums1 = tf.convert_to_tensor([1.0, 1.0])
        upper_bound1 = tf.ones_like(score1)

        # 3x3 example
        score2 = tf.convert_to_tensor([[1.0, 0, 0], [0, 1.0, 0], [0, 0, 1.0]])
        row_sums2 = tf.convert_to_tensor([1.0, 1.0, 1.0])
        col_sums2 = tf.convert_to_tensor([1.0, 1.0, 1.0])
        upper_bound2 = tf.ones_like(score2)

        # Add a leading batch dimension of size 1 to every tensor.
        score1 = score1[tf.newaxis]
        row_sums1 = row_sums1[tf.newaxis]
        col_sums1 = col_sums1[tf.newaxis]
        upper_bound1 = upper_bound1[tf.newaxis]

        score2 = score2[tf.newaxis]
        row_sums2 = row_sums2[tf.newaxis]
        col_sums2 = col_sums2[tf.newaxis]
        upper_bound2 = upper_bound2[tf.newaxis]

        # A batch with example 1 and example 2. We need to pad example 1.
        # Padded scores should have very large negative value.
        # Padded sums and upper bound should be zero.
        # NOTE(review): the remaining argument(s) and closing paren of this
        # tf.pad(...) call are missing from this file.
        score1_ = tf.pad(score1, [[0, 0], [0, 1], [0, 1]],
        row_sums1_ = tf.pad(row_sums1, [[0, 0], [0, 1]])
        col_sums1_ = tf.pad(col_sums1, [[0, 0], [0, 1]])
        upper_bound1_ = tf.pad(upper_bound1, [[0, 0], [0, 1], [0, 1]])
        score3 = tf.concat([score1_, score2], axis=0)
        row_sums3 = tf.concat([row_sums1_, row_sums2], axis=0)
        col_sums3 = tf.concat([col_sums1_, col_sums2], axis=0)
        upper_bound3 = tf.concat([upper_bound1_, upper_bound2], axis=0)

        # NOTE(review): the argument lists of all three max_assignment(...)
        # calls are missing from this file.
        results1 = differentiable_assignment.max_assignment(
        results2 = differentiable_assignment.max_assignment(
        results3 = differentiable_assignment.max_assignment(
        assignment1 = results1[0]
        assignment2 = results2[0]
        assignment3 = results3[0]

        print("Test case - batched:")
        print("Used iter:", results1[1], results2[1], results3[1])
        print("Delta:", results1[-1], results2[-1], results3[-1])
        # NOTE(review): the start and end of the first assertion (comparing
        # against the top-left 2x2 of the padded batch entry) are missing from
        # this file.
                               assignment3[0, :2, :2],
        self.assertNDArrayNear(assignment2[0], assignment3[1], err=1e-4)
    def ComputePredictions(self, theta, input_batch):
        """Computes predictions for `input_batch`.

        Args:
          theta: A `.NestedMap` object containing variable values of this task.
          input_batch: A `.NestedMap` expected to contain cell_center_xyz,
            cell_points_xyz, cell_feature, anchor_bboxes,
            anchor_localization_residuals, assigned_gt_labels, and
            assigned_cls_mask. See class doc string for details.

        Returns:
          A `.NestedMap` object containing residuals and classification_logits.
        """
        p = self.params
        input_batch.Transform(lambda x:
                              (x.shape, x.shape.num_elements())).VLog(
                                  1, 'input_batch shapes: ')

        cell_feature = py_utils.HasRank(input_batch.cell_feature, 4)
        batch_size, num_centers, num_points_per_cell = py_utils.GetShape(
            cell_feature, 3)

        cell_points_xyz = py_utils.HasShape(
            input_batch.cell_points_xyz,
            [batch_size, num_centers, num_points_per_cell, 3])
        cell_center_xyz = py_utils.HasShape(input_batch.cell_center_xyz,
                                            [batch_size, num_centers, 3])

        # NOTE(review): the field name `cell_points_padding` is restored from
        # the local variable name; confirm it against the input generator.
        cell_points_padding = py_utils.HasShape(
            input_batch.cell_points_padding,
            [batch_size, num_centers, num_points_per_cell])

        # TODO(jngiam): Make concat_feature computation a layer or configurable.
        # Give the centers a singleton points axis so they broadcast against
        # the per-point coordinates.
        cell_center_xyz = tf.reshape(cell_center_xyz,
                                     [batch_size, num_centers, 1, 3])
        centered_cell_points_xyz = cell_points_xyz - cell_center_xyz
        concat_feature = tf.concat([
            tf.tile(cell_center_xyz, [1, 1, num_points_per_cell, 1]),
            centered_cell_points_xyz, cell_feature
        ], axis=-1)  # pyformat: disable

        # Featurize point clouds at each center.
        point_input = py_utils.NestedMap({
            'points': centered_cell_points_xyz,
            'features': concat_feature,
            'padding': cell_points_padding,
        })
        featurized_cell = self.cell_featurizer.FProp(theta.cell_featurizer,
                                                     point_input)
        featurized_cell = py_utils.HasShape(featurized_cell,
                                            [batch_size, num_centers, -1])

        # Predict localization residuals.
        predicted_residuals = self.localization_regressor.FProp(
            theta.localization_regressor, featurized_cell)
        predicted_residuals = tf.reshape(
            predicted_residuals,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])

        if p.squash_rotation_predictions:
            # Squash only the last (rotation) component into (-pi, pi).
            predicted_rotations = predicted_residuals[..., 6:]
            predicted_rotations = np.pi * tf.tanh(predicted_rotations)
            predicted_residuals = tf.concat(
                [predicted_residuals[..., :6], predicted_rotations], axis=-1)

        # Predict object classification at each bbox.
        predicted_classification_logits = self.classifier.FProp(
            theta.classifier, featurized_cell)
        predicted_classification_logits = tf.reshape(
            predicted_classification_logits, [
                batch_size, num_centers, p.num_anchor_bboxes_per_center,
                p.num_classes
            ])

        return py_utils.NestedMap({
            'residuals': predicted_residuals,
            'classification_logits': predicted_classification_logits,
        })
    def testMelFeaturesPaddedLeftStacked(self):
        """Checks paddings and feature statistics with left-context stacking.

        Feeds packetized PCM in which only the first 455 packets are unpadded
        through a frontend configured with `stack_left_context = 2`, then
        verifies the output paddings and the per-feature mean / standard
        deviation of the log-mel output against reference values.

        NOTE(review): the logging calls were cut off in the copy this was
        restored from; `tf.logging.info` is assumed -- confirm.
        """
        p = self.params
        p.stack_left_context = 2
        p.frame_stride = p.stack_left_context + 1
        mel_frontend = p.Instantiate()
        sample_rate, pcm = self._GetPcm()
        pcm *= 32768

        # Convert to 4D [batch, time, packet, channels].
        sample_count = tf.shape(pcm)[1]
        packet_size = 11  # A non-round number.
        trimmed_pcm = pcm[:, 0:(sample_count // packet_size) * packet_size]
        src_inputs = tf.reshape(trimmed_pcm, (1, -1, packet_size, 1))

        # Create paddings such that the first 455 packets are unpadded.
        paddings = tf.concat([
            tf.zeros([1, 455], dtype=tf.float32),
            tf.ones([1, tf.shape(src_inputs)[1] - 455], dtype=tf.float32)
        ], axis=1)
        # frame_step=240, frame_size=601, +1202 left padded frames
        # 455 packets * 11 frames rounds = 5005 frames, rounds down to 21 mel
        # frames. Divide by 3 for stacking = 7.
        expected_unpadded = 7

        outputs = mel_frontend.FPropDefaultTheta(
            py_utils.NestedMap(src_inputs=src_inputs, paddings=paddings))
        log_mel = outputs.src_inputs
        paddings = outputs.paddings
        with self.session():
            pcm = self.evaluate(pcm)
            tf.logging.info('pcm: ~ %s = %s', pcm.shape, pcm)
            self.assertGreater(33000, np.amax(pcm))
            self.assertGreater(np.amax(pcm), 2.)
            log_mel, paddings, sample_rate = self.evaluate(
                [log_mel, paddings, sample_rate])
            self.assertEqual(sample_rate, p.sample_rate)
            self.assertEqual(paddings.shape, log_mel.shape[0:2])
            # The leading frames are unpadded, everything after is padded.
            self.assertAllEqual(paddings[:, 0:expected_unpadded],
                                np.zeros([1, expected_unpadded]))
            self.assertAllEqual(
                paddings[:, expected_unpadded:],
                np.ones([1, paddings.shape[1] - expected_unpadded]))
            # log_mel ~ [batch, time, feature_size, channel]
            tf.logging.info('mel ~ %s', log_mel.shape)
            # Squeeze the batch and channel dimensions out.
            log_mel = np.squeeze(log_mel, axis=(0, 3))
            # Per-feature mean and sample standard deviation over time.
            t = log_mel.shape[0]
            mu = np.sum(log_mel, axis=0) / t
            d = log_mel - mu
            v = np.sum(d * d, axis=0) / (t - 1)
            s = np.sqrt(v)
            tf.logging.info('Found mean = %s', mu)
            tf.logging.info('Found stddev = %s', s)
            ref_mean = (13.38236332, 13.2698698, 13.45229626, 13.26469517,
                        13.46731281, 13.31649303)
            ref_stddev = (1.52104115, 1.27433181, 1.41266346, 1.27072334,
                          1.41251481, 1.28583682)
            self.assertAllClose(mu, ref_mean, atol=1e-4)
            self.assertAllClose(s, ref_stddev, atol=1e-3)
    def FProp(self,
        # NOTE(review): the remaining parameters of this signature are missing
        # from this copy of the file. The docstring and body refer to `theta`,
        # `x`, `x_paddings`, `eos_id` and `force_sample_last_token`; their
        # defaults must be restored from the upstream source.
        """Applies SymbolInsertionLayer.

        We take in a `x`, which represents the groundtruth sequence (i.e.,
        English sequence). We return a sampled rollin (observed) canvas (i.e.,
        random subset of the English sequence), as well as the target (indices)
        for an insertion-based model (i.e., the targets given the random
        observed subset).

        Args:
          theta: Ignored, this can be None.
          x: The symbol ids of shape `[batch_size, time_dim]`.
          x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]`
            where 0 is valid and 1 is invalid.
          eos_id: The <eos> token id to represent end-of-slot.
          force_sample_last_token: Set True to force sample the last token of
            `x`.

        Returns:
          A `NestedMap`.
            - canvas: The canvas (based off of the `rollin_policy`) of shape
              [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not
              be equal.
            - canvas_indices: The canvas indices (into `x`).
            - canvas_paddings: The paddings of `canvas_indices`.
            - target_indices: The target indices of shape [num_targets, 3].
              `num_targets` is the number of total targets in the entire batch.
              [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
              captures the token. Each row [batch, slot, vocab] represents the
              indices of the target -- i.e., the batch, slot and vocab
              combination of the target. Typical usage of these indices is to
              tf.gather_nd the log-probs (from the softmax layer).
            - target_weights: The target weights.

        Raises:
          ValueError: If invalid params.
        """
        p = self.params

        batch_size = py_utils.GetShape(x)[0]
        time_dim = py_utils.GetShape(x)[1]

        # With no paddings given, every position of `x` is treated as valid.
        if x_paddings is None:
            x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

        # A rollin policy of 'oracle' means "use the same policy as the oracle".
        oracle_policy = p.oracle_policy
        rollin_policy = (oracle_policy
                         if p.rollin_policy == 'oracle' else p.rollin_policy)

        # Only the 'uniform' policy is accepted for both rollin and oracle.
        # NOTE(review): in both `raise` statements below the right-hand operand
        # of `%` and the closing parenthesis are missing.
        if rollin_policy != 'uniform':
            raise ValueError('Unknown or unsupported rollin policy: %s' %
        if oracle_policy != 'uniform':
            raise ValueError('Unknown or unsupported oracle policy: %s' %

        # Number of valid (unpadded) tokens per example.
        x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

        # Compute the desired length per example in the batch.
        ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
        # NOTE(review): an `else:` line between the two `c_len` assignments and
        # the final argument of the second `tf.minimum` call are missing; as
        # written the second assignment is unterminated.
        if force_sample_last_token:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32),
                x_len - 1) + 1
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32),
        # Compute the maximum length across the batch.
        c_len_max = tf.reduce_max(c_len)

        # Grab subset of random valid indices per example.
        # Positions at or beyond `x_len` get a -1e9 logit so they sort last.
        z_logits = tf.cast(
            tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
            tf.float32) * -1e9
        if force_sample_last_token:
            # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
            # accomplish this by add +LARGE_NUMBER to the logits.
            z_logits += tf.cast(
                tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                         tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
        # Gumbel-max trick to sample (we only sample valid positions per sample in
        # the batch).
        z = -tf.math.log(-tf.math.log(
            tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
        unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

        # Trim everything > c_len_max.
        c_indices = c_indices[:, :c_len_max]

        # Invalidate any indices >= c_len, we use the last index as the default
        # invalid index.
        c_indices = tf.where(
            tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
            c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

        # Materialize the canvas.
        c_indices = tf.sort(c_indices)
        # NOTE(review): the first argument of `tf.gather_nd` and the opening of
        # the index-building expression are missing below; only the tail of
        # that expression survives.
        c = tf.gather_nd(
                    tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                            [1, c_len_max]), [-1]),
                tf.reshape(c_indices, [-1])
            ], 1))
        c = tf.reshape(c, [batch_size, c_len_max])

        # Compute the paddings.
        c_paddings = 1 - tf.sequence_mask(
            c_len, c_len_max, dtype=x_paddings.dtype)
        # Zero out the canvas entries that fall in the padded region.
        c *= tf.cast(1 - c_paddings, tf.int32)

        # (batch index, canvas index) pairs, one row per canvas position.
        # NOTE(review): the call that wraps the `tf.tile(...)` below (its
        # closing `, [batch_size * c_len_max, 1])` is still present) has lost
        # its opening line.
        indices = tf.concat([
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, c_len_max]), [batch_size * c_len_max, 1]),
            tf.reshape(c_indices, [batch_size * c_len_max, 1])
        ], 1)
        # NOTE(review): the output-shape argument and closing parenthesis of
        # `tf.scatter_nd` are missing.
        x_token_is_observed = tf.scatter_nd(
            indices, tf.ones([batch_size * c_len_max], tf.int32),
        # `x_segments` captures which slot each `x` belongs to (both observed and
        # tokens that need to be observed).
        x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

        x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
        # Shift the observed mask right by one position along time.
        # NOTE(review): the final argument(s) and closing parenthesis of this
        # `tf.pad` call are missing, so the fill value for the new first column
        # is not visible here.
        prev_x_token_is_observed = tf.pad(x_token_is_observed[:, :-1],
                                          [[0, 0], [1, 0]],
        x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
        prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
        x_is_valid = tf.cast(1 - x_paddings, tf.bool)
        x_is_valid = tf.reshape(x_is_valid, [-1])

        # Remap all the observed to <eos>, note some of these need a zero weight
        # (or else there would be <eos> and valid token in the same slot).
        target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
        # NOTE(review): the condition argument of `tf.where` is missing below;
        # per the comment above it presumably selects the observed tokens.
        target_indices = tf.where(
            tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

        # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
        # we may want to weigh this term by the original sequence length.
        target_weights = tf.ones_like(target_indices, tf.float32)

        # We need to set all the weights for <eos> which actually have valid tokens
        # in the slot to zero.
        target_weights = tf.where(
            x_token_is_observed & ~prev_x_token_is_observed,
            tf.zeros_like(target_weights), target_weights)

        # TODO(williamchan): Consider dropping the entries w/ weight zero.

        # Add the batch and slot indices.
        # NOTE(review): as above, the call wrapping the `tf.tile(...)` below
        # has lost its opening line.
        target_indices = tf.concat([
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, time_dim]), [batch_size * time_dim, 1]),
            tf.reshape(x_segments, [-1, 1]), target_indices
        ], 1)

        # Select only the valid indices. The selected valid ones include slots w/
        # <eos>.
        target_indices = target_indices[x_is_valid]
        target_weights = target_weights[x_is_valid]

        # NOTE(review): the remaining keyword arguments of the returned map
        # (the docstring lists canvas_indices, canvas_paddings, target_indices
        # and target_weights) are missing.
        return py_utils.NestedMap(canvas=c,
    def FProp(self, theta, batch, state0=None):
        """Encodes source as represented by 'inputs' and 'paddings'.

        NOTE(review): many statements in this method are cut off in this copy
        of the file (each is flagged inline); it does not parse as written.

        Args:
          theta: A NestedMap object containing weights' values of this
            layer and its children layers.
          batch: A NestedMap with fields:

            - src_inputs - The inputs tensor. It is expected to be of shape
              [batch, time, feature_dim, channels].
            - paddings - The paddings tensor. It is expected to be of shape
              [batch, time].
          state0: Recurrent input state. Not supported/ignored by this encoder.

        Returns:
          A NestedMap containing

          - 'encoded': a feature tensor of shape [time, batch, depth]
          - 'padding': a 0/1 tensor of shape [time, batch]
          - 'state': the updated recurrent state
          - '${layer_type}_${layer_index}': The per-layer encoder output. Each
            one is a NestedMap containing 'encoded' and 'padding' similar to
            regular final outputs, except that 'encoded' from conv or conv_lstm
            layers are of shape [time, batch, depth, channels].
        """
        p = self.params
        inputs, paddings = batch.src_inputs, batch.paddings
        outputs = py_utils.NestedMap()
        # NOTE(review): the argument and closing `):` of this `with` statement
        # are missing.
        with tf.name_scope(
            # Adding specAugmentation.
            if p.use_specaugment and not self.do_eval:
                inputs, paddings = self.specaugment.FProp(
                    theta.specaugment, inputs, paddings)
            # Add a few extra padded timesteps at the end. This is for ensuring the
            # correctness of the conv-layers at the edges.
            if p.pad_steps > 0:
                # inplace_update() is not supported by TPU for now. Since we have done
                # padding on the input_generator, we may avoid this additional padding.
                assert not py_utils.use_tpu()
                # Same shapes as `inputs` / `paddings` but with the time axis
                # (dim 1) replaced by `p.pad_steps`.
                inputs_pad = tf.zeros(
                    inplace_ops.inplace_update(tf.shape(inputs), 1,
                                               p.pad_steps), inputs.dtype)
                paddings_pad = tf.ones(
                    inplace_ops.inplace_update(tf.shape(paddings), 1,
                                               p.pad_steps), paddings.dtype)
                inputs = tf.concat([inputs, inputs_pad], 1, name='inputs')
                paddings = tf.concat([paddings, paddings_pad], 1)

            # NOTE(review): the call that builds the first `plots` entry has
            # lost its opening line and the closing `]` of the list.
            plots = [
                    tf.transpose(inputs, [0, 1, 3, 2]), paddings, 'inputs')

            conv_out = inputs
            out_padding = paddings
            for i, conv_layer in enumerate(self.conv):
                conv_out, out_padding = conv_layer.FProp(
                    theta.conv[i], conv_out, out_padding)
                if p.extra_per_layer_outputs:
                    # Zero out padded frames before exposing the layer output.
                    conv_out *= (1.0 -
                                 out_padding[:, :, tf.newaxis, tf.newaxis])
                    # NOTE(review): the `encoded=` / `padding=` arguments of
                    # this NestedMap and the start of the following plot
                    # statement are missing; only fragments remain.
                    outputs['conv_%d' % i] = py_utils.NestedMap(
                                             [1, 0, 2, 3]),  # to [t, b, d, c]
                        tf.transpose(conv_out, [0, 1, 3, 2]), out_padding,
                        'conv_%d_out' % i))

            def TransposeFirstTwoDims(t):
                # Swaps axes 0 and 1 of `t` for any rank >= 2 by flattening the
                # trailing axes, transposing, and restoring the trailing shape.
                first_dim = tf.shape(t)[0]
                second_dim = tf.shape(t)[1]
                t_new = tf.transpose(
                    tf.reshape(t, [first_dim, second_dim, -1]), [1, 0, 2])
                t_shape_new = tf.concat([[second_dim], [first_dim],
                                         tf.shape(t)[2:]], 0)
                return tf.reshape(t_new, t_shape_new)

            # Now the conv-lstm part.
            conv_lstm_out = conv_out
            conv_lstm_out_padding = out_padding
            for i, (rnn, cnn) in enumerate(
                    zip(self.conv_lstm_rnn, self.conv_lstm_cnn)):
                conv_lstm_in = conv_lstm_out
                # Move time dimension to be the first.
                conv_lstm_in = TransposeFirstTwoDims(conv_lstm_in)
                conv_lstm_in = tf.expand_dims(conv_lstm_in, 2)
                conv_lstm_in_padding = tf.expand_dims(
                    tf.transpose(conv_lstm_out_padding), 2)
                # NOTE(review): the last argument (presumably
                # `conv_lstm_in_padding`) and closing parenthesis of this call
                # are missing.
                lstm_out = rnn.FProp(theta.conv_lstm_rnn[i], conv_lstm_in,
                # Move time dimension to be the second.
                cnn_in = TransposeFirstTwoDims(lstm_out)
                cnn_in = tf.squeeze(cnn_in, 2)
                cnn_in_padding = conv_lstm_out_padding
                cnn_out, cnn_out_padding = cnn.FProp(theta.conv_lstm_cnn[i],
                                                     cnn_in, cnn_in_padding)
                conv_lstm_out, conv_lstm_out_padding = cnn_out, cnn_out_padding
                if p.extra_per_layer_outputs:
                    # Zero out padded frames before exposing the layer output.
                    conv_lstm_out *= (
                        1.0 -
                        conv_lstm_out_padding[:, :, tf.newaxis, tf.newaxis])
                    # NOTE(review): as in the conv loop, the NestedMap
                    # arguments and the start of the plot statement are
                    # missing.
                    outputs['conv_lstm_%d' % i] = py_utils.NestedMap(
                                             [1, 0, 2, 3]),  # to [t, b, d, c]
                        conv_lstm_out, conv_lstm_out_padding,
                        'conv_lstm_%d_out' % i))

            # Need to do a reshape before starting the rnn layers.
            # Collapse the trailing two axes into a single depth axis.
            conv_lstm_out = py_utils.HasRank(conv_lstm_out, 4)
            conv_lstm_out_shape = tf.shape(conv_lstm_out)
            new_shape = tf.concat([conv_lstm_out_shape[:2], [-1]], 0)
            conv_lstm_out = tf.reshape(conv_lstm_out, new_shape)
            if self._first_lstm_input_dim_pad:
                # NOTE(review): the tensor argument of `tf.pad` (presumably
                # `conv_lstm_out`) is missing.
                conv_lstm_out = tf.pad(
                    [[0, 0], [0, 0], [0, self._first_lstm_input_dim_pad]])

            conv_lstm_out = py_utils.HasShape(
                conv_lstm_out, [-1, -1, self._first_lstm_input_dim])

            # Transpose to move the time dimension to be the first.
            rnn_in = tf.transpose(conv_lstm_out, [1, 0, 2])
            # NOTE(review): the axis argument and closing parenthesis of this
            # `tf.expand_dims` call are missing.
            rnn_padding = tf.expand_dims(tf.transpose(conv_lstm_out_padding),
            # rnn_in is of shape [time, batch, depth]
            # rnn_padding is of shape [time, batch, 1]

            # Now the rnn layers.
            num_skips = 0
            for i in range(p.num_lstm_layers):
                rnn_out = self.rnn[i].FProp(theta.rnn[i], rnn_in, rnn_padding)
                residual_index = i - p.residual_start + 1
                if p.residual_start > 0 and residual_index >= 0:
                    # Remember the input at the start of each residual stride.
                    if residual_index % p.residual_stride == 0:
                        residual_in = rnn_in
                    if residual_index % p.residual_stride == p.residual_stride - 1:
                        # Highway skip connection.
                        # NOTE(review): the last argument of the highway
                        # `FProp` call and the `else:` that should introduce
                        # the residual branch are missing.
                        if p.highway_skip:
                            rnn_out = self.highway_skip[num_skips].FProp(
                                theta.highway_skip[num_skips], residual_in,
                            num_skips += 1
                            # Residual skip connection.
                            rnn_out += py_utils.HasShape(
                                residual_in, tf.shape(rnn_out))
                if p.project_lstm_output and (i < p.num_lstm_layers - 1):
                    # Projection layers.
                    # NOTE(review): the last argument and closing parenthesis
                    # of this call are missing.
                    rnn_out = self.proj[i].FProp(theta.proj[i], rnn_out,
                if i == p.num_lstm_layers - 1:
                    # Zero out padded frames of the final layer output.
                    rnn_out *= (1.0 - rnn_padding)
                if p.extra_per_layer_outputs:
                    rnn_out *= (1.0 - rnn_padding)
                    outputs['rnn_%d' % i] = py_utils.NestedMap(
                        encoded=rnn_out, padding=tf.squeeze(rnn_padding, [2]))
                # Stacking layer connection.
                if p.layer_index_before_stacking == i:
                    # Stacking layer expects input tensor shape as [batch, time, feature].
                    # So transpose the tensors before and after the layer.
                    # NOTE(review): this call passes only two positional
                    # arguments; a leading theta argument appears to be
                    # missing -- confirm.
                    rnn_out, rnn_padding = self.stacking.FProp(
                        tf.transpose(rnn_out, [1, 0, 2]),
                        tf.transpose(rnn_padding, [1, 0, 2]))
                    rnn_out = tf.transpose(rnn_out, [1, 0, 2])
                    rnn_padding = tf.transpose(rnn_padding, [1, 0, 2])

                # NOTE(review): the opening of the plot statement for this
                # layer is missing; only its arguments remain below.
                        tf.transpose(rnn_out, [1, 0, 2]),
                        tf.transpose(rnn_padding, [1, 0, 2]),
                        'rnn_%d_out' % i))
                rnn_in = rnn_out
            final_out = rnn_in

            # NOTE(review): `plots` is built above but no code consuming it is
            # present in this copy.

            outputs['encoded'] = final_out
            outputs['padding'] = tf.squeeze(rnn_padding, [2])
            outputs['state'] = py_utils.NestedMap()
            return outputs
 def _ConcatPerLayerVals(*per_layer_vals):
   """Joins the given per-layer tensors along a new leading axis."""
   with_layer_axis = []
   for layer_val in per_layer_vals:
     # Give every value a size-1 leading axis so they can be concatenated.
     with_layer_axis.append(tf.expand_dims(layer_val, 0))
   return tf.concat(with_layer_axis, 0)
 def _GetDefaultPaddings(self, inputs):
     """Gets the default paddings for an input.

     Returns an all-zero tensor whose shape is that of `inputs` with the last
     dimension replaced by 1.
     """
     # NOTE(review): the original statement was cut off after the shape
     # argument; `dtype=inputs.dtype` is a restoration -- confirm.
     padding_shape = tf.concat([tf.shape(inputs)[:-1], [1]], 0)
     return tf.zeros(padding_shape, dtype=inputs.dtype)
 def _RemoveFirstEos(x):
     """Drops the element of 1-D `x` at `first_eos_idx`, appending a zero.

     `first_eos_idx` is not a parameter; it is presumably captured from the
     enclosing scope.
     """
     # We remove the element at position `first_eos_idx`, and pad with 0
     # to keep length unchanged.
     zero = tf.constant(0, shape=(1, ), dtype=x.dtype)
     # `zero` is rank 1, so the pieces can only be joined along axis 0.
     return tf.concat([x[:first_eos_idx], x[first_eos_idx + 1:], zero],
                      axis=0)
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses per
    beam.

  Raises:
    ValueError: if a field is populated in some but not all of the inputs, or
      if an unexpected field is encountered.
  """
  source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
  value_dict = {}
  for output in beam_search_outputs:
    # NOTE(review): the dependency list was cut off in the copy this was
    # restored from. A batch-size equality check is assumed, matching the
    # "must share the same source_batch" contract -- confirm.
    hyps_per_beam = py_utils.with_dependencies([
        py_utils.assert_equal(source_batch, tf.shape(output.topk_hyps)[0]),
    ], tf.shape(output.topk_hyps)[1])
    for k, v in six.iteritems(output._asdict()):
      if v is None:
        # Unpopulated fields are skipped.
        continue
      if k == 'done_hyps':
        # Transposed here and transposed back after the merge below.
        v = tf.transpose(v)
      if k not in value_dict:
        value_dict[k] = []
      # Normalize every field to [source_batch, hyps_per_beam, -1] so all of
      # them can be concatenated and gathered uniformly.
      value_dict[k].append(tf.reshape(v, [source_batch, hyps_per_beam, -1]))

  # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
  concatenated = {}
  for k, values in six.iteritems(value_dict):
    if len(values) != len(beam_search_outputs):
      raise ValueError('Incomplete values for %s: %s' %
                       (k, beam_search_outputs))
    concatenated[k] = tf.concat(values, axis=1)

  # Give zero-length hyps a very low score so they are never selected ahead of
  # non-empty ones.
  scores = concatenated['topk_scores']
  scores = tf.where(
      tf.equal(concatenated['topk_lens'], 0), tf.fill(tf.shape(scores), -1e6),
      scores)
  scores = tf.squeeze(scores, -1)

  # Select top max_hyps_per_beam indices per beam.
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

  # Gather the merged top hyps according to 'gather_indices'.
  top = beam_search_outputs[0]._asdict()
  total_hyps = source_batch * max_hyps_per_beam
  for k, v in six.iteritems(concatenated):
    v = tf.gather_nd(v, gather_indices)
    # Restore each field's own layout.
    if k == 'done_hyps':
      v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
    elif k == 'topk_hyps':
      v = tf.reshape(v, [source_batch, max_hyps_per_beam])
    elif k == 'topk_ids':
      v = tf.reshape(v, [total_hyps, -1])
    elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
      v = tf.reshape(v, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % k)
    top[k] = v
  return BeamSearchDecodeOutput(**top)
    def AssignAnchors(self,
        """Assigns anchors to bboxes using a similarity function (SSD-based).

    Each anchor box is assigned to the top matching ground truth box.
    Ground truth boxes can be assigned to multiple anchor boxes.

    Assignments can result in 3 outcomes:

      - Positive assignment (if score >= foreground_assignment_threshold):
        assigned_gt_labels will reflect the assigned box label and
        assigned_cls_mask will be set to 1.0
      - Background assignment (if score <= background_assignment_threshold):
        assigned_gt_labels will be background_class_id and assigned_cls_mask
        will be set to 1.0
      - Ignore assignment (otherwise):
        assigned_gt_labels will be background_class_id and assigned_cls_mask
        will be set to 0.0

    The detection loss function would usually:

      - Use assigned_cls_mask for weighting the classification loss. The mask
        is set such that the loss applies to foreground and background
        assignments only - ignored anchors will be set to 0.
      - Use assigned_reg_mask for weighting the regression loss. The mask is set
        such that the loss applies to foreground assignments only.

    The thresholds (foreground_assignment_threshold and
    background_assignment_threshold) should be tuned per dataset.

    TODO(jngiam): Consider having a separate threshold for regression boxes; a
    separate threshold is used in PointRCNN.

      anchor_bboxes: tf.float32. [A, 7], where [..., :] corresponds to box
        parameters (x, y, z, dx, dy, dz, r).
      gt_bboxes: tf.float32. [G, 7], where [..., :] corresponds to ground truth
        box parameters (x, y, z, dx, dy, dz, r).
      gt_bboxes_labels: tensor with shape [G]. Ground truth labels for each
        bounding box.
      gt_bboxes_mask: tensor with shape [G]. Mask for ground truth boxes, 1 iff
        the gt_bbox is a real bbox.
      foreground_assignment_threshold: Similarity score threshold for assigning
        foreground bounding boxes; scores need to be >=
        foreground_assignment_threshold to be assigned to foreground.
      background_assignment_threshold: Similarity score threshold for assigning
        background bounding boxes; scores need to be <=
        background_assignment_threshold to be assigned to background.
      background_class_id: class id to be assigned to anchors_gt_class if no
        anchor boxes match.
      force_match: Boolean specifying if force matching is enabled. If
        force matching is enabled, then matched anchors which are also the
        highest scoring with a ground-truth box are considered foreground
        matches as long as their similarity score > 0.
      similarity_fn: Function that computes the a similarity score (e.g., IOU)
        between pairs of bounding boxes. This function should take in two
        tensors corresponding to anchor and ground-truth bboxes, and return a
        matrix [A, G] with the similarity score between each pair of bboxes. The
        score must be non-negative, with greater scores representing more
        similar. The fore/background_assignment_thresholds will be applied to
        this score to determine if the an anchor is foreground, background or
        ignored. If set to None, the function will default to IOU2DRotatedBoxes.

      NestedMap with the following keys

      - assigned_gt_idx: shape [A] index corresponding to the index of the
        assigned ground truth box. Anchors not assigned to a ground truth box
        will have the index set to -1.
      - assigned_gt_bbox: shape [A, 7] bbox parameters assigned to each anchor.
      - assigned_gt_similarity_score: shape [A] (iou) score between the anchor
        and the gt bbox.
      - assigned_gt_labels: shape [A] label assigned to bbox.
      - assigned_cls_mask: shape [A] mask for classification loss per anchor.
        This should be 1.0 if the anchor has a foreground or background
        assignment; otherwise, it will be assigned to 0.0.
      - assigned_reg_mask: shape [A] mask for regression loss per anchor.
        This should be 1.0 if the anchor has a foreground assignment;
        otherwise, it will be assigned to 0.0.
        Note: background anchors do not have regression targets.
        if similarity_fn is None:
            similarity_fn = self.IOU2DRotatedBoxes

        # Shape validation.
        anchor_bboxes = py_utils.HasShape(anchor_bboxes, [-1, 7])
        num_anchor_bboxes, _ = py_utils.GetShape(anchor_bboxes, 2)
        gt_bboxes = py_utils.HasShape(gt_bboxes, [-1, 7])
        num_gt_bboxes, _ = py_utils.GetShape(gt_bboxes, 2)

        # Compute similarity score and reduce max by anchors and by ground-truth.
        similarity_score = similarity_fn(anchor_bboxes, gt_bboxes)
        similarity_score = py_utils.HasShape(
            similarity_score, [num_anchor_bboxes, num_gt_bboxes])

        # Reduce over ground-truth boxes, so we have the max score per anchor.
        anchor_max_score = tf.reduce_max(similarity_score, axis=1)
        anchor_max_idx = tf.argmax(similarity_score, axis=1)

        if force_match:
            # Reduce over anchors, so we have the max score per ground truth box.
            gt_max_score = tf.reduce_max(similarity_score,

            # Force matches occur when the top matching gt bbox for an anchor is the
            # top matching anchor for the gt bbox. When force matching, we match
            # these boxes as long as their similarity score exceeds 0.
            force_matches = (
                tf.equal(similarity_score, gt_max_score)
                & tf.equal(similarity_score, anchor_max_score[..., tf.newaxis])
                & tf.greater(similarity_score, 0.)
                & tf.cast(gt_bboxes_mask[tf.newaxis, ...], tf.bool))
            force_match_indicator = tf.reduce_any(force_matches, axis=1)
            force_match_idx = tf.argmax(tf.cast(force_matches, tf.int32),

            # In assigning foreground/background anchors later, force_match_indicator
            # is used to determine which anchors are force foreground, and the index
            # assigned will be taken from anchor_max_idx.

            # Force matches must also be the max scoring gt bbox per anchor.
            # We overwrite anchor_max_idx to ensure that the right match is done.
            anchor_max_idx = tf.where(force_match_indicator, force_match_idx,

        # Ensure that max score boxes are not padded boxes by setting score to 0
        # for boxes that are padded.
        gathered_mask = tf.batch_gather(gt_bboxes_mask, anchor_max_idx)
        anchor_max_score = tf.where(tf.equal(gathered_mask, 1),

        # Boolean tensors corresponding to whether an anchor is background or
        # foreground based on thresholding.
        background_anchors = tf.less_equal(anchor_max_score,
        foreground_anchors = tf.greater_equal(anchor_max_score,
        if force_match:
            # Background anchors are below threshold and not force matches.
            background_anchors &= ~force_match_indicator
            # Foreground anchors are above thresholds or force matches.
            foreground_anchors |= force_match_indicator

        # Add dummy background bbox to gt_boxes to facilitate batch gather.
        dummy_bbox = tf.constant([[0, 0, 0, 1, 1, 1, 0]], dtype=tf.float32)

        # Since we are concatenating the dummy bbox, the index corresponds to the
        # number of boxes.
        dummy_bbox_idx = py_utils.GetShape(gt_bboxes, 1)[0]

        gt_bboxes = tf.concat([gt_bboxes, dummy_bbox], axis=0)
        gt_bboxes_labels = tf.concat([gt_bboxes_labels, [background_class_id]],

        # Gather indices so that all foreground boxes are gathered from gt_bboxes,
        # while all background and ignore boxes gather the dummy_bbox.
        anchor_gather_idx = tf.where(
            foreground_anchors, anchor_max_idx,

        # Gather the bboxes and weights.
        assigned_gt_bbox = tf.batch_gather(gt_bboxes, anchor_gather_idx)
        assigned_gt_labels = tf.batch_gather(gt_bboxes_labels,

        # Set masks for classification and regression losses.
        assigned_cls_mask = tf.cast(background_anchors | foreground_anchors,
        assigned_reg_mask = tf.cast(foreground_anchors, tf.float32)

        # Set assigned_gt_idx such that dummy boxes have idx = -1.
        assigned_gt_idx = tf.where(tf.equal(anchor_gather_idx, dummy_bbox_idx),
                                   tf.ones_like(anchor_gather_idx) * -1,
        assigned_gt_idx = tf.cast(assigned_gt_idx, tf.int32)

        return py_utils.NestedMap(
 def _MergeRight():
     # Concatenates the result of `_MergeOneToken(tokens, best_id)` with the
     # untouched candidates from position `best_id + 2` onwards.
     # `_MergeOneToken`, `tokens`, `candidates` and `best_id` come from the
     # enclosing scope.
     # NOTE(review): the closing of this call was missing from the source;
     # `axis=0` is assumed because the operands are sliced as 1-D sequences —
     # confirm against the sibling merge helper.
     return tf.concat([
         _MergeOneToken(tokens, best_id), candidates[best_id + 2:]
     ], axis=0)
    def ComputePredictions(self, theta, input_batch):
        """Computes predictions for `input_batch`.

        Args:
          theta: A `.NestedMap` object containing variable values of this task.
          input_batch: A `.NestedMap` object containing input tensors to this
            tower.

        Returns:
          A `.NestedMap` containing
            residuals - [b, nx, ny, nz, na, 7] regression residuals.
            classification_logits - [b, nx, ny, nz, na, num_classes].
            predicted_dir - [b, nx, ny, nz, na, 2]; only set when
              `p.direction_classifier_weight > 0`.
        """
        p = self.params
        input_batch.Transform(lambda x:
                              (x.shape, x.shape.num_elements())).VLog(
                                  0, 'input_batch shapes: ')

        # Make pillars representation from input_batch.
        # NOTE(review): the second FProp argument was missing from the source;
        # `input_batch` is assumed — confirm.
        dense_features = self.input_featurizer.FProp(theta.input_featurizer,
                                                     input_batch)

        # Backbone
        tf.logging.vlog(1, 'dense_features.shape = %s', dense_features.shape)
        act = self.backbone.FProp(theta.backbone, dense_features)
        tf.logging.vlog(1, 'act.shape = %s', act.shape)

        # Convert the output of the backbone into class logits and regression
        # residuals using two different layers.
        class_detection = self.class_detector.FProp(theta.class_detector, act)
        reg_detection = self.regression_detector.FProp(
            theta.regression_detector, act)
        bs, nx, ny, _ = py_utils.GetShape(class_detection, 4)
        predicted_classification_logits = tf.reshape(
            class_detection,
            [bs, nx, ny, p.grid_size_z, p.num_anchors, p.num_classes])
        predicted_residuals = tf.reshape(
            reg_detection, [bs, nx, ny, p.grid_size_z, p.num_anchors, 7])

        if p.squash_rotation_predictions:
            # Squash only the rotation component (last channel) into
            # (-pi, pi); the first six channels are left untouched.
            predicted_rotations = predicted_residuals[..., 6:]
            predicted_rotations = np.pi * tf.tanh(predicted_rotations)
            predicted_residuals = tf.concat(
                [predicted_residuals[..., :6], predicted_rotations], axis=-1)

        if p.oracle_location or p.oracle_dimension or p.oracle_rotation:
            # NOTE(review): the tensor validated here was missing from the
            # source; `input_batch.anchor_localization_residuals` is assumed
            # to hold the ground-truth residuals — confirm.
            gt_residuals = py_utils.HasShape(
                input_batch.anchor_localization_residuals,
                [bs, nx, ny, p.grid_size_z, p.num_anchors, 7])

            # Replace the predicted components with the ground truth if needed.
            if p.oracle_location:
                location = gt_residuals[..., 0:3]
            else:
                location = predicted_residuals[..., 0:3]

            if p.oracle_dimension:
                dimension = gt_residuals[..., 3:6]
            else:
                dimension = predicted_residuals[..., 3:6]

            if p.oracle_rotation:
                rotation = gt_residuals[..., 6:]
            else:
                rotation = predicted_residuals[..., 6:]
            predicted_residuals = tf.concat([location, dimension, rotation],
                                            axis=-1)

        # NOTE(review): the dict entries were missing from the source; these
        # key names are assumed — confirm against the consumers of this map.
        ret = py_utils.NestedMap({
            'residuals': predicted_residuals,
            'classification_logits': predicted_classification_logits,
        })

        if p.direction_classifier_weight > 0.0:
            predicted_dir = self.direction_classifier.FProp(
                theta.direction_classifier, act)
            predicted_dir = tf.reshape(
                predicted_dir, [bs, nx, ny, p.grid_size_z, p.num_anchors, 2])
            ret.predicted_dir = predicted_dir

        return ret
  def _CreateCanvasAndTargets(self, batch):
    # pyformat: disable
    """Create the canvas and targets.

    Args:
      batch: A `.NestedMap`.

        - src: A `.NestedMap`.
          - ids: The source ids, ends in <eos>.
          - paddings: The source paddings.

        - tgt: A `.NestedMap`.
          - ids: The target ids, ends in <eos>.
          - paddings: The target paddings.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based off of the `rollin_policy`) of shape
          [batch_size, c_dim].
        - canvas_paddings: The paddings of `canvas`.
        - target_indices: The target indices (i.e., use these indices to
          tf.gather_nd the log-probs). Optional, only during training.
        - target_weights: The target weights. Optional, only during training.
    """
    # pyformat: enable
    p = self.params

    if not p.is_eval:
      # Sample our src and tgt canvas.
      # NOTE(review): the second argument of both calls was missing from the
      # source; the matching paddings are assumed — confirm.
      src_descriptor = self._SampleCanvasAndTargets(batch.src.ids,
                                                    batch.src.paddings)
      tgt_descriptor = self._SampleCanvasAndTargets(batch.tgt.ids,
                                                    batch.tgt.paddings)

      # Offset the src ids (to unshare embeddings between src/tgt). Note, we
      # only offset the canvas ids, but we do not offset the vocab ids. This
      # will result in unshared embeddings, but shared softmax. This is due to
      # GPU/TPU memory limitations, empirically it is known that unsharing
      # everything results in better performance.
      vocab_size = p.decoder.softmax.num_classes
      src_descriptor.canvas = tf.where(
          tf.equal(src_descriptor.canvas_paddings, 0),
          src_descriptor.canvas + vocab_size, src_descriptor.canvas)

      # Offset the tgt indices (need shift according to src length).
      batch_size = py_utils.GetShape(batch.src.ids)[0]
      # `target_batch` is a [num_targets, batch_size] tensor where each row
      # identifies which batch the target belongs to. Note the observation that,
      # tf.reduce_sum(target_batch, 1) == 1 \forall rows.
      # NOTE(review): the comparison op was missing from the source;
      # `tf.equal` is assumed since it yields the one-hot rows described
      # above — confirm.
      target_batch = tf.cast(
          tf.equal(
              tf.expand_dims(tf.range(batch_size), 0),
              tf.expand_dims(tgt_descriptor.target_indices[:, 0], 1)), tf.int32)
      src_lens = tf.cast(
          tf.reduce_sum(1 - src_descriptor.canvas_paddings, 1), tf.int32)
      # `tgt_offset` is shape [num_targets, 1] where each entry corresponds to
      # the offset needed for that target (due to the source length).
      tgt_offset = tf.matmul(target_batch, tf.expand_dims(src_lens, 1))
      # We shift the tgt slot without touching the batch or vocab.
      tgt_descriptor.target_indices += tf.concat(
          [tf.zeros_like(tgt_offset), tgt_offset,
           tf.zeros_like(tgt_offset)], 1)

      # The canvas is simply the sequence-level concat of the src and tgt.
      canvas, canvas_paddings = insertion.SequenceConcat(
          src_descriptor.canvas, src_descriptor.canvas_paddings,
          tgt_descriptor.canvas, tgt_descriptor.canvas_paddings)
      target_indices = tf.concat(
          [src_descriptor.target_indices, tgt_descriptor.target_indices], 0)
      target_weights = tf.concat(
          [src_descriptor.target_weights, tgt_descriptor.target_weights], 0)

      # NOTE(review): the arguments of this call were missing from the source;
      # the fields below follow the docstring's field list — confirm.
      return py_utils.NestedMap(
          canvas=canvas,
          canvas_paddings=canvas_paddings,
          target_indices=target_indices,
          target_weights=target_weights)

    # NOTE(review): no eval-mode branch is visible in this view, so the
    # function yields None when `p.is_eval` is set — confirm this is intended.
    return None
    def _Extract(self, features):
        p = self.params

        source_id = py_utils.HasShape(features['image/source_id'], [])
        xmin = _Dense(features['object/image/bbox/xmin'])
        xmax = _Dense(features['object/image/bbox/xmax'])
        ymin = _Dense(features['object/image/bbox/ymin'])
        ymax = _Dense(features['object/image/bbox/ymax'])

        # 2d bounding box in image coordinates.
        bboxes = tf.stack([ymin, xmin, ymax, xmax], axis=1)
        bboxes_count = tf.shape(bboxes)[0]
        bboxes = py_utils.PadOrTrimTo(bboxes, [p.max_num_objects, 4])

        bboxes_padding = 1.0 - py_utils.PadOrTrimTo(tf.ones([bboxes_count]),

        dim_xyz = tf.reshape(_Dense(features['object/velo/bbox/dim_xyz']),
                             [-1, 3])
        loc_xyz = tf.reshape(_Dense(features['object/velo/bbox/xyz']), [-1, 3])
        phi = tf.reshape(_Dense(features['object/velo/bbox/phi']), [-1, 1])
        # bboxes_3d is in [x, y, z, dx, dy, dz, phi].
        bboxes_3d = tf.concat([loc_xyz, dim_xyz, phi], axis=1)

        cx, cy, _, dx, dy, _, _ = tf.unstack(bboxes_3d, num=7, axis=-1)
        bboxes_td = tf.stack([
            cy - dy / 2,
            cx - dx / 2,
            cy + dy / 2,
            cx + dx / 2,
                             axis=-1)  # pyformat: disable
        bboxes_td = py_utils.PadOrTrimTo(bboxes_td, [p.max_num_objects, 4])

        has_3d_info = tf.cast(_Dense(features['object/has_3d_info']),
        bboxes_3d_mask = py_utils.PadOrTrimTo(has_3d_info, [p.max_num_objects])
        bboxes_td_mask = bboxes_3d_mask

        # Fill in difficulties from bounding box height, truncation and occlusion.
        bb_height = ymax - ymin
        box_image_height = py_utils.PadOrTrimTo(bb_height, [p.max_num_objects])
        box_image_height *= bboxes_3d_mask

        # 0 to 3 indicating occlusion level. 0 means fully visible, 1 means partly
        # occluded, 2 means largely occluded, 3 means unknown.
        occlusion = tf.reshape(_Dense(features['object/occlusion']), [-1])
        occlusion = tf.cast(occlusion, tf.float32)
        occlusion = py_utils.PadOrTrimTo(occlusion, [p.max_num_objects])
        occlusion *= bboxes_3d_mask

        # Truncation: 0 -> not truncated, 1.0 -> truncated
        truncation = tf.reshape(_Dense(features['object/truncation']), [-1])
        truncation = py_utils.PadOrTrimTo(truncation, [p.max_num_objects])
        truncation *= bboxes_3d_mask

        difficulties = ComputeKITTIDifficulties(box_image_height, occlusion,
        difficulties = py_utils.PadOrTrimTo(difficulties, [p.max_num_objects])

        # Make a batch axis to call BBoxCorners, and take the first result back.
        bbox3d_corners = geometry.BBoxCorners(bboxes_3d[tf.newaxis, ...])[0]

        # Project the 3D bbox to the image plane.
        velo_to_image_plane = features['transform/velo_to_image_plane']
        bboxes3d_proj_to_image_plane = geometry.PointsToImagePlane(
            tf.reshape(bbox3d_corners, [-1, 3]), velo_to_image_plane)

        # Output is [num_objects, 8 corners per object, (x, y)].
        bboxes3d_proj_to_image_plane = tf.reshape(bboxes3d_proj_to_image_plane,
                                                  [-1, 8, 2])
        bboxes3d_proj_to_image_plane = py_utils.PadOrTrimTo(
            bboxes3d_proj_to_image_plane, [p.max_num_objects, 8, 2])

        texts = features['object/label'].values
        labels = ops.static_map_string_int(x=texts,

        labels = py_utils.PadOrTrimTo(labels, [p.max_num_objects])
        texts = py_utils.PadOrTrimTo(texts, [p.max_num_objects])

        # Filter labels by setting bboxes_padding, bboxes_3d_mask, and
        # bboxes_td_mask appropriately.
        if p.filter_labels is not None:
            valid_labels = tf.constant([p.filter_labels])
            bbox_mask = tf.reduce_any(tf.equal(tf.expand_dims(labels, 1),
            bbox_mask = tf.cast(bbox_mask, tf.float32)
            bboxes_padding = 1 - bbox_mask * (1 - bboxes_padding)
            filtered_bboxes_3d_mask = bboxes_3d_mask * bbox_mask
            bboxes_td_mask *= bbox_mask
            filtered_bboxes_3d_mask = bboxes_3d_mask

        # Placeholder for counting the number of laser points that reside within
        # each 3-d bounding box. This must be filled in outside of this function
        # based on the loaded 3-d laser points.
        bboxes_3d_num_points = tf.zeros([p.max_num_objects], dtype=tf.int32)
        bboxes_3d_num_points = py_utils.PadOrTrimTo(bboxes_3d_num_points,

        # Pad bboxes_3d.
        bboxes_3d = py_utils.PadOrTrimTo(bboxes_3d, [p.max_num_objects, 7])

        return py_utils.NestedMap(