示例#1
0
def _randomly_transform_points_boxes(
    mesh_inputs, object_inputs, x_min_degree_rotation, x_max_degree_rotation,
    y_min_degree_rotation, y_max_degree_rotation, z_min_degree_rotation,
    z_max_degree_rotation, rotation_center, min_scale_ratio, max_scale_ratio,
    translation_range):
  """Randomly rotates, scales and (optionally) translates points and boxes.

  The work is delegated to `preprocessor_utils` helpers, which are given
  `mesh_inputs` and `object_inputs` directly. Their return values are not
  used, so they presumably update the dictionaries in place (TODO confirm).
  The steps run in this order: rotation, scaling, then translation.

  Args:
    mesh_inputs: A dictionary containing mesh input tensors.
    object_inputs: A dictionary containing object input tensors.
    x_min_degree_rotation: Min degree of rotation around the x axis.
    x_max_degree_rotation: Max degree of rotation around the x axis.
    y_min_degree_rotation: Min degree of rotation around the y axis.
    y_max_degree_rotation: Max degree of rotation around the y axis.
    z_min_degree_rotation: Min degree of rotation around the z axis.
    z_max_degree_rotation: Max degree of rotation around the z axis.
    rotation_center: A 3d point that points are rotated around.
    min_scale_ratio: Minimum scale ratio.
    max_scale_ratio: Maximum scale ratio.
    translation_range: A non-negative float. For each of the x, y and z
      directions, the translation bounds passed on are -translation_range and
      translation_range. Zero is accepted. If None, no translation is applied.

  Raises:
    ValueError: If `translation_range` is negative. The check runs before any
      transformation, so the inputs are left untouched when it fails.
  """
  # Validate before touching the inputs. If this check ran after rotation and
  # scaling, a bad argument would leave the dictionaries half-transformed.
  if translation_range is not None and translation_range < 0:
    raise ValueError('Translation range should be positive')

  # Random rotation of points in camera frame.
  preprocessor_utils.rotate_randomly(
      mesh_inputs=mesh_inputs,
      object_inputs=object_inputs,
      x_min_degree_rotation=x_min_degree_rotation,
      x_max_degree_rotation=x_max_degree_rotation,
      y_min_degree_rotation=y_min_degree_rotation,
      y_max_degree_rotation=y_max_degree_rotation,
      z_min_degree_rotation=z_min_degree_rotation,
      z_max_degree_rotation=z_max_degree_rotation,
      rotation_center=rotation_center)

  # Random scaling.
  preprocessor_utils.randomly_scale_points_and_objects(
      mesh_inputs=mesh_inputs,
      object_inputs=object_inputs,
      min_scale_ratio=min_scale_ratio,
      max_scale_ratio=max_scale_ratio)

  # Random translation, symmetric around zero on every axis.
  if translation_range is not None:
    preprocessor_utils.translate_randomly(
        mesh_inputs=mesh_inputs,
        object_inputs=object_inputs,
        delta_x_min=-translation_range,
        delta_x_max=translation_range,
        delta_y_min=-translation_range,
        delta_y_max=translation_range,
        delta_z_min=-translation_range,
        delta_z_max=translation_range)
示例#2
0
 def test_rotate_randomly(self):
     """rotate_randomly must keep point, center and rotation shapes intact."""
     fields = standard_fields.InputDataFields
     num_points = 100
     num_objects = 20

     def _uniform(shape, bound):
         # Symmetric float32 uniform sample in [-bound, bound).
         return tf.random.uniform(shape,
                                  minval=-bound,
                                  maxval=bound,
                                  dtype=tf.float32)

     mesh_inputs = {fields.point_positions: _uniform([num_points, 3], 10.0)}
     # Dict entries are evaluated in order, so centers are sampled before the
     # rotation matrices, matching the draw order of the inputs.
     object_inputs = {
         fields.objects_center: _uniform([num_objects, 3], 10.0),
         fields.objects_rotation_matrix: _uniform([num_objects, 3, 3], 1.0),
     }

     rotation_limits = dict(
         x_min_degree_rotation=-10.0,
         x_max_degree_rotation=10.0,
         y_min_degree_rotation=-180.0,
         y_max_degree_rotation=180.0,
         z_min_degree_rotation=-10.0,
         z_max_degree_rotation=10.0)
     preprocessor_utils.rotate_randomly(mesh_inputs=mesh_inputs,
                                        object_inputs=object_inputs,
                                        rotation_center=(0.0, 0.0, 0.0),
                                        **rotation_limits)

     # Each entry: (input dict, field key, shape expected after rotation).
     expected_shapes = (
         (mesh_inputs, fields.point_positions, [num_points, 3]),
         (object_inputs, fields.objects_center, [num_objects, 3]),
         (object_inputs, fields.objects_rotation_matrix,
          [num_objects, 3, 3]),
     )
     for inputs, key, shape in expected_shapes:
         self.assertAllEqual(inputs[key].shape, shape)