def test_map_points_to_boxes_tf(self):
  """Checks map_points_to_boxes on two axis-aligned boxes (-1 = no box)."""
  # pyformat: disable
  query_points = tf.constant([[0.0, 0.0, 0.0],   # expected: box 0
                              [1.5, 0.0, 0.0],   # expected: box 0
                              [1.0, 1.5, 0.0],   # expected: outside
                              [2.5, 2.5, 2.5],   # expected: box 1
                              [2.5, 2.5, 0.0]])  # expected: outside
  # pyformat: enable
  lengths = tf.constant([[4.0], [1.0]])
  heights = tf.constant([[1.0], [1.0]])
  widths = tf.constant([[2.0], [1.0]])
  # Both boxes use the identity rotation.
  rotations = tf.eye(3, batch_shape=[2])
  centers = tf.constant([[0.0, 0.0, 0.0],
                         [2.5, 2.5, 2.5]])
  expected_indices = np.array([0, 0, -1, 1, -1])

  actual_indices = box_utils.map_points_to_boxes(
      points=query_points,
      boxes_length=lengths,
      boxes_height=heights,
      boxes_width=widths,
      boxes_rotation_matrix=rotations,
      boxes_center=centers,
      box_margin=0.0)
  self.assertAllEqual(actual_indices.numpy(), expected_indices)
def compute_semantic_labels(inputs, points_key, box_margin=0.1):
  """Assigns every point the category label of the object box containing it.

  Points that lie in no (margin-enlarged) object box receive label 0, which
  stands for background / unknown.

  Args:
    inputs: A dictionary holding the point tensor (under `points_key`) and the
      object tensors `objects/shape/dimension`, `objects/pose/R`,
      `objects/pose/t` and `objects/category/label`.
    points_key: Name of the point-position tensor inside `inputs`.
    box_margin: Amount by which every object box is grown before the
      containment test, so points on a box boundary count as inside.

  Returns:
    A [num_points, 1] tensor of per-point labels. Its dtype is that of
    `objects/category/label` (presumably tf.int32 — confirm with callers).

  Raises:
    ValueError: If `points_key` or any required object key is missing from
      `inputs`.
  """
  if points_key not in inputs:
    raise ValueError('points_key: %s not in inputs.' % points_key)
  # Checked in this order so the first missing key is the one reported.
  for required_key in ('objects/shape/dimension', 'objects/pose/R',
                       'objects/pose/t', 'objects/category/label'):
    if required_key not in inputs:
      raise ValueError('`%s` not in inputs.' % required_key)

  dimensions = inputs['objects/shape/dimension']
  # Prepend a zero row: after shifting box indices by +1, index 0 (i.e. the
  # "no box" value -1) selects the background label.
  labels_with_background = tf.pad(
      tf.expand_dims(inputs['objects/category/label'], axis=1),
      paddings=[[1, 0], [0, 0]])

  box_index_per_point = box_utils.map_points_to_boxes(
      points=inputs[points_key],
      boxes_length=dimensions[:, 0:1],
      boxes_height=dimensions[:, 2:3],
      boxes_width=dimensions[:, 1:2],
      boxes_rotation_matrix=inputs['objects/pose/R'],
      boxes_center=inputs['objects/pose/t'],
      box_margin=box_margin)
  return tf.gather(labels_with_background, box_index_per_point + 1)
def compute_motion_labels(scene,
                          frame0,
                          frame1,
                          frame_start_index,
                          points_key,
                          box_margin=0.1):
  """Computes, for every frame0 point, its motion vector towards frame1.

  Objects are matched across the two frames by equal `objects/name`. A frame0
  point that lies inside a matched frame0 box is carried rigidly with that
  box. The point is first expressed in the box's local coordinates. It is then
  re-placed with the pose of the matching frame1 box, after that pose has been
  mapped into frame0 coordinates via the scene frame poses. The motion vector
  is the new position minus the old one. Points outside every matched box
  (including boxes whose object has no frame1 match) get a zero vector.

  Args:
    scene: dict of tensors for the whole scene; reads `frames/pose/R` and
      `frames/pose/t`.
    frame0: dict of tensors with the points and objects of the first frame.
    frame1: dict of tensors with the objects of the second frame.
    frame_start_index: Scene index of frame0. The frame at
      `frame_start_index + 1` is used as frame1's pose.
    points_key: Name of the point-position tensor inside `frame0`.
    box_margin: Amount by which boxes are enlarged so that nearby points are
      included.

  Returns:
    A motion tensor with the same shape as `frame0[points_key]`
    (i.e. [N, 3] for [N, 3] points).
  """
  points = frame0[points_key]

  # Every (frame0 object, frame1 object) pair whose names are equal.
  same_name = tf.math.equal(
      tf.expand_dims(frame0['objects/name'], axis=1),
      tf.expand_dims(frame1['objects/name'], axis=0))
  matches = tf.where(same_name)
  frame0_ids = matches[:, 0]
  frame1_ids = matches[:, 1]

  # Matched boxes: frame0 geometry, and frame1 pose (in frame1 coordinates).
  dims0 = tf.gather(frame0['objects/shape/dimension'], frame0_ids, axis=0)
  rot0 = tf.gather(frame0['objects/pose/R'], frame0_ids, axis=0)
  center0 = tf.gather(frame0['objects/pose/t'], frame0_ids, axis=0)
  rot1 = tf.gather(frame1['objects/pose/R'], frame1_ids, axis=0)
  center1 = tf.gather(frame1['objects/pose/t'], frame1_ids, axis=0)

  # Scene-level poses of the two frames.
  pose_rotations = scene['frames/pose/R']
  pose_translations = scene['frames/pose/t']
  next_index = frame_start_index + 1
  frame0_rot = pose_rotations[frame_start_index]
  frame0_t = pose_translations[frame_start_index]
  frame1_rot = pose_rotations[next_index]
  frame1_t = pose_translations[next_index]

  # frame1 box centers: frame1 -> world (R1 c + t1) -> frame0 (R0^T (x - t0)).
  center1_world = tf.tensordot(center1, frame1_rot, axes=(1, 1)) + frame1_t
  center1_in_frame0 = tf.tensordot(
      center1_world - frame0_t, frame0_rot, axes=(1, 0))

  # Index into the *matched* boxes only, so unmatched boxes never claim points.
  box_of_point = box_utils.map_points_to_boxes(
      points=points,
      boxes_length=dims0[:, 0:1],
      boxes_height=dims0[:, 2:3],
      boxes_width=dims0[:, 1:2],
      boxes_rotation_matrix=rot0,
      boxes_center=center0,
      box_margin=box_margin)

  # Positions (in the full point array) of points inside some matched box;
  # adding 1 turns the "no box" value -1 into 0, which tf.where skips.
  inside = tf.where(box_of_point + 1)[:, 0]
  owner = tf.gather(box_of_point, inside)
  inside_points = tf.gather(points, inside)

  # Local box coordinates R^T (p - c), computed per point ([K, 3, 3] x [K, 3]).
  local_points = tf.einsum('ikj,ik->ij', tf.gather(rot0, owner),
                           inside_points - tf.gather(center0, owner))

  # frame1 box rotations: frame1 -> world -> frame0. The transposes restore
  # the [num_boxes, 3, 3] layout; transposing R0 is done by the summed axis.
  rot1_world = tf.transpose(
      tf.tensordot(frame1_rot, rot1, axes=(1, 1)), perm=(1, 0, 2))
  rot1_in_frame0 = tf.transpose(
      tf.tensordot(frame0_rot, rot1_world, axes=(0, 1)), perm=(1, 0, 2))

  # Where each point ends up if it follows its box to the frame1 pose.
  moved_points = tf.einsum(
      'ijk,ik->ij', tf.gather(rot1_in_frame0, owner),
      local_points) + tf.gather(center1_in_frame0, owner)

  # TODO(huangrui): points in boxes with no frame1 match (disappeared objects)
  # get zero motion; consider NaN or an ignore label instead.
  return tf.scatter_nd(
      indices=tf.expand_dims(inside, axis=1),
      updates=moved_points - inside_points,
      shape=tf.shape(points, out_type=tf.dtypes.int64))