Example #1
def sampling_points_from_frustum(height,
                                 width,
                                 focal,
                                 principal_point,
                                 depth_min=0.0,
                                 depth_max=5.,
                                 frustum_size=(256, 256, 256)):
    """Generates samples from a camera frustum."""

    # ------------------ Get the rays from the camera ----------------------------
    sampling_points = grid.generate((0., 0.),
                                    (float(width) - 1, float(height) - 1),
                                    (frustum_size[0], frustum_size[1]))
    sampling_points = tf.reshape(sampling_points, [-1, 2])  # [h*w, 2]
    rays = perspective.ray(sampling_points, focal, principal_point)  # [h*w, 3]

    # ------------------ Extract a volume in front of the camera -----------------
    depth_tensor = grid.generate((depth_min, ), (depth_max, ),
                                 (frustum_size[2], ))
    sampling_volume = tf.multiply(tf.expand_dims(
        rays, axis=-1), tf.transpose(depth_tensor))  # [h*w, 3, dstep]
    sampling_volume = tf.transpose(sampling_volume,
                                   [0, 2, 1])  # [h*w, dstep, 3]
    sampling_volume = tf.reshape(sampling_volume, [-1, 3])  # [h*w*dstep, 3]
    return sampling_volume
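A minimal usage sketch (not from the original source): the intrinsics below are made up, and the imports assume the usual TensorFlow Graphics module layout. The focal and principal point are given a leading batch dimension of 1 so they broadcast against the flattened sampling points.

import tensorflow as tf
from tensorflow_graphics.geometry.representation import grid
from tensorflow_graphics.rendering.camera import perspective

# Hypothetical pinhole camera with the principal point at the image center.
height, width = 256, 256
focal = tf.constant([[128.0, 128.0]])
principal_point = tf.constant([[128.0, 128.0]])

points = sampling_points_from_frustum(height, width, focal, principal_point,
                                      depth_min=0.5, depth_max=5.0,
                                      frustum_size=(64, 64, 32))
print(points.shape)  # (64 * 64 * 32, 3) = (131072, 3)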
Example #2
    def test_generate_random(self):
        """Test the uniform grid generation."""
        starts = np.array((0., 0.), dtype=np.float32)
        stops = np.random.randint(1, 10, size=(2,))
        nums = stops + 1
        stops = stops.astype(np.float32)

        g = grid.generate(starts, stops, nums)
        shape = nums.tolist() + [2]
        xv, yv = np.meshgrid(range(shape[0]), range(shape[1]), indexing="ij")
        gt = np.stack((xv, yv), axis=-1).astype(np.float32)

        self.assertAllClose(g, gt)
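For reference, a small sketch of the behavior this test checks: with integer-aligned bounds, grid.generate matches an "ij"-indexed meshgrid (module path assumed from TensorFlow Graphics).

from tensorflow_graphics.geometry.representation import grid

g = grid.generate((0.0, 0.0), (2.0, 3.0), (3, 4))
print(g.shape)  # (3, 4, 2): a 3x4 grid of 2D points
print(g[1, 2])  # [1. 2.]: the first axis indexes x, the second indexes y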
Example #3
def generate_ground_image(height,
                          width,
                          focal,
                          principal_point,
                          camera_rotation_matrix,
                          camera_translation_vector,
                          ground_color=(0.43, 0.43, 0.8)):
  """Generate an image depicting only the ground."""
  batch_size = camera_rotation_matrix.shape[0]
  background_image = np.ones((batch_size, height, width, 1, 1),
                             dtype=np.float32)
  background_image[:, -1, ...] = 0  # Zero the bottom line for proper sampling

  # The projection of the ground depends on the top right corner (approximation)
  plane_point_np = np.tile(np.array([[3.077984, 2.905388, 0.]],
                                    dtype=np.float32), (batch_size, 1))
  plane_point_rotated = rotation_matrix_3d.rotate(plane_point_np,
                                                  camera_rotation_matrix)
  plane_point_translated = plane_point_rotated + camera_translation_vector
  plane_point2d = perspective.project(plane_point_translated, focal,
                                      principal_point)
  _, y = tf.split(plane_point2d, [1, 1], axis=-1)
  sfactor = height / y
  helper_matrix1 = np.tile(np.array([[[1, 0, 0],
                                      [0, 0, 0],
                                      [0, 0, 0]]]), (batch_size, 1, 1))
  helper_matrix2 = np.tile(np.array([[[0, 0, 0],
                                      [0, 1, 0],
                                      [0, 0, 1]]]), (batch_size, 1, 1))
  transformation_matrix = tf.multiply(tf.expand_dims(sfactor, -1),
                                      helper_matrix1) + helper_matrix2
  plane_points = grid.generate((0., 0., 0.),
                               (float(height), float(width), 0.),
                               (height, width, 1))
  plane_points = tf.reshape(plane_points, [-1, 3])
  transf_plane_points = tf.matmul(transformation_matrix,
                                  plane_points,
                                  transpose_b=True)
  interpolated_points = trilinear.interpolate(
      background_image, tf.linalg.matrix_transpose(transf_plane_points))
  ground_alpha = 1 - tf.reshape(interpolated_points,
                                [batch_size, height, width, 1])
  ground_image = tf.ones((batch_size, height, width, 3)) * ground_color
  return ground_image, ground_alpha
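A hedged usage sketch for the batched variant above. The pose values are made up: identity rotation and a translation with positive depth, chosen so the reference plane point projects in front of the camera.

import tensorflow as tf

batch_size, height, width = 2, 256, 256
focal = tf.constant([[128.0, 128.0]] * batch_size)
principal_point = tf.constant([[128.0, 128.0]] * batch_size)
camera_rotation_matrix = tf.eye(3, batch_shape=[batch_size])
camera_translation_vector = tf.constant([[0.0, 0.0, 5.0]] * batch_size)

ground_image, ground_alpha = generate_ground_image(
    height, width, focal, principal_point,
    camera_rotation_matrix, camera_translation_vector)
# ground_alpha has shape [batch, height, width, 1]; composite with e.g.
# final = rendered * (1 - ground_alpha) + ground_image * ground_alpha.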
Example #4
def generate_ground_image(height,
                          width,
                          focal,
                          principal_point,
                          camera_rotation_matrix,
                          camera_translation_vector,
                          ground_color=(0.43, 0.43, 0.8)):
    """Generate an image depicting only the ground."""
    background_image = np.ones((height, width, 1, 1), dtype=np.float32)
    background_image[-1, ...] = 0  # Set the bottom line to 0 for proper sampling

    # The projection of the ground depends on the top right corner (approximation)
    plane_point_np = np.array([[3.077984, 2.905388, 0.]], dtype=np.float32)
    plane_point2d = perspective.project(
        rotation_matrix_3d.rotate(plane_point_np, camera_rotation_matrix) +
        camera_translation_vector.T, focal, principal_point)
    _, y = tf.split(plane_point2d, [1, 1], axis=-1)
    sfactor = y / float(height)
    transformation_matrix = (1/sfactor)*np.array([[1, 0, 0],
                                                  [0, 0, 0],
                                                  [0, 0, 0]]) + \
                            np.array([[0, 0, 0],
                                      [0, 1, 0],
                                      [0, 0, 1]])
    plane_points = grid.generate(
        (0., 0., 0.), (float(height), float(width), 0.), (height, width, 1))
    plane_points = tf.reshape(plane_points, [-1, 3])
    transf_plane_points = tf.matmul(transformation_matrix,
                                    plane_points,
                                    transpose_b=True)
    interpolated_points = \
      trilinear.interpolate(background_image, tf.transpose(transf_plane_points))
    ground_alpha = 1 - tf.reshape(interpolated_points, [height, width, 1])
    ground_image = tf.ones((height, width, 3)) * ground_color
    return ground_image, ground_alpha
Example #5
def perspective_transform(
    image: type_alias.TensorLike,
    transform_matrix: type_alias.TensorLike,
    output_shape: Optional[type_alias.TensorLike] = None,
    resampling_type: ResamplingType = ResamplingType.BILINEAR,
    border_type: BorderType = BorderType.ZERO,
    pixel_type: PixelType = PixelType.HALF_INTEGER,
    name: Optional[str] = "perspective_transform",
) -> tf.Tensor:
  """Applies a projective transformation to an image.

  The projective transformation is represented by a 3 x 3 matrix
  [[a0, a1, a2], [b0, b1, b2], [c0, c1, c2]], mapping a point `[x, y]` to a
  transformed point
  `[x', y'] = [(a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k]`, where
  `k = c0 x + c1 y + c2`.

  Note:
      The transformation matrix maps target to source by transforming output
      points to input points.

  Args:
    image: A tensor of shape `[B, H_i, W_i, C]`, where `B` is the batch size,
      `H_i` the height of the image, `W_i` the width of the image, and `C` the
      number of channels of the image.
    transform_matrix: A tensor of shape `[B, 3, 3]` containing projective
      transform matrices. The transformation maps target to source by
      transforming output points to input points.
    output_shape: The height `H_o` and width `W_o` output dimensions after the
      transform. If `None`, the output is the same size as the input image.
    resampling_type: Resampling mode. Supported values are
      `ResamplingType.NEAREST` and `ResamplingType.BILINEAR`.
    border_type: Border mode. Supported values are `BorderType.ZERO` and
      `BorderType.DUPLICATE`.
    pixel_type: Pixel mode. Supported values are `PixelType.INTEGER` and
      `PixelType.HALF_INTEGER`.
    name: A name for this op. Defaults to "perspective_transform".

  Returns:
    A tensor of shape `[B, H_o, W_o, C]` containing transformed images.

  Raises:
    ValueError: If `image` has rank != 4. If `transform_matrix` has rank < 3 or
    its last two dimensions are not 3. If `image` and `transform_matrix` batch
    dimension does not match.
  """
  with tf.name_scope(name):
    image = tf.convert_to_tensor(value=image, name="image")
    transform_matrix = tf.convert_to_tensor(
        value=transform_matrix, name="transform_matrix")
    output_shape = tf.shape(
        input=image)[-3:-1] if output_shape is None else tf.convert_to_tensor(
            value=output_shape, name="output_shape")

    shape.check_static(image, tensor_name="image", has_rank=4)
    shape.check_static(
        transform_matrix,
        tensor_name="transform_matrix",
        has_rank=3,
        has_dim_equals=((-1, 3), (-2, 3)))
    shape.compare_batch_dimensions(
        tensors=(image, transform_matrix),
        last_axes=0,
        broadcast_compatible=False)

    dtype = image.dtype
    zero = tf.cast(0.0, dtype)
    height, width = tf.unstack(output_shape, axis=-1)
    warp = grid.generate(
        starts=(zero, zero),
        stops=(tf.cast(width, dtype) - 1.0, tf.cast(height, dtype) - 1.0),
        nums=(width, height))
    warp = tf.transpose(a=warp, perm=[1, 0, 2])

    if pixel_type == PixelType.HALF_INTEGER:
      warp += 0.5

    padding = [[0, 0] for _ in range(warp.shape.ndims)]
    padding[-1][-1] = 1
    warp = tf.pad(
        tensor=warp, paddings=padding, mode="CONSTANT", constant_values=1.0)

    warp = warp[..., tf.newaxis]
    transform_matrix = transform_matrix[:, tf.newaxis, tf.newaxis, ...]
    warp = tf.linalg.matmul(transform_matrix, warp)
    warp = warp[..., 0:2, 0] / warp[..., 2, :]

    return sample(image, warp, resampling_type, border_type, pixel_type)
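A hedged sketch of the target-to-source convention (made-up inputs): the identity homography reproduces the input, while a scale-by-2 matrix makes each output pixel read from twice its coordinates in the source, i.e. a downscale.

import tensorflow as tf

image = tf.random.uniform((1, 4, 4, 3))
identity = tf.eye(3, batch_shape=[1])
same = perspective_transform(image, identity)  # equals `image` under bilinear sampling

downscale = tf.constant([[[2.0, 0.0, 0.0],
                          [0.0, 2.0, 0.0],
                          [0.0, 0.0, 1.0]]])
half = perspective_transform(image, downscale, output_shape=(2, 2))  # [1, 2, 2, 3]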
Example #6
def sampling_points_from_3d_grid(grid_size, dtype=tf.float32):
    """Returns a tensor of shape `[M, 3]`, with M the number of sampling points."""
    sampling_points = grid.generate((-1.0, -1.0, -1.0), (1.0, 1.0, 1.0),
                                    grid_size)
    sampling_points = tf.cast(sampling_points, dtype)
    return tf.reshape(sampling_points, [-1, 3])
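Hedged usage sketch: a 16^3 grid spanning the [-1, 1]^3 cube, a typical query set for trilinear interpolation or an SDF evaluation.

points = sampling_points_from_3d_grid((16, 16, 16))
print(points.shape)  # (4096, 3)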
Example #7
    def test_rasterize_preset(self):
        camera_origin = (0.0, 0.0, 0.0)
        camera_up = (0.0, 1.0, 0.0)
        look_at_point = (0.0, 0.0, 1.0)
        field_of_view = (60.0 * np.pi / 180.0, )
        near_plane = (0.01, )
        far_plane = (400.0, )

        # Construct the view projection matrix.
        model_to_eye_matrix = look_at.right_handed(camera_origin,
                                                   look_at_point, camera_up)
        perspective_matrix = perspective.right_handed(
            field_of_view, (float(_IMAGE_WIDTH) / float(_IMAGE_HEIGHT), ),
            near_plane, far_plane)
        view_projection_matrix = tf.linalg.matmul(perspective_matrix,
                                                  model_to_eye_matrix)
        view_projection_matrix = tf.expand_dims(view_projection_matrix, axis=0)

        depth = 1.0
        vertices = np.array([[(-2.0 * _TRIANGLE_SIZE, 0.0, depth),
                              (0.0, _TRIANGLE_SIZE, depth), (0.0, 0.0, depth),
                              (0.0, -_TRIANGLE_SIZE, depth)]],
                            dtype=np.float32)
        triangles = np.array(((1, 2, 0), (0, 2, 3)), np.int32)

        predicted_fb = rasterization_backend.rasterize(
            vertices, triangles, view_projection_matrix,
            (_IMAGE_WIDTH, _IMAGE_HEIGHT))

        with self.subTest(name="triangle_index"):
            groundtruth_triangle_index = np.zeros(
                (1, _IMAGE_HEIGHT, _IMAGE_WIDTH, 1), dtype=np.int32)
            groundtruth_triangle_index[..., :_IMAGE_WIDTH // 2, 0] = 0
            groundtruth_triangle_index[..., :_IMAGE_HEIGHT // 2,
                                       _IMAGE_WIDTH // 2:, 0] = 1
            self.assertAllEqual(groundtruth_triangle_index,
                                predicted_fb.triangle_id)

        with self.subTest(name="mask"):
            groundtruth_mask = np.ones((1, _IMAGE_HEIGHT, _IMAGE_WIDTH, 1),
                                       dtype=np.int32)
            groundtruth_mask[..., :_IMAGE_WIDTH // 2, 0] = 0
            self.assertAllEqual(groundtruth_mask, predicted_fb.foreground_mask)

        attributes = np.array(((1.0, 0.0, 0.0), (0.0, 1.0, 0.0),
                               (0.0, 0.0, 1.0))).astype(np.float32)
        perspective_correct_interpolation = lambda geometry, pixels: glm.perspective_correct_interpolation(  # pylint: disable=g-long-lambda,line-too-long
            geometry, attributes, pixels, model_to_eye_matrix,
            perspective_matrix,
            np.array((_IMAGE_WIDTH, _IMAGE_HEIGHT)).astype(np.float32),
            np.array((0.0, 0.0)).astype(np.float32))
        with self.subTest(name="barycentric_coordinates_triangle_0"):
            geometry_0 = tf.gather(vertices, triangles[0, :], axis=1)
            pixels_0 = tf.transpose(grid.generate((3.5, 2.5), (6.5, 4.5),
                                                  (4, 3)),
                                    perm=(1, 0, 2))
            barycentrics_gt_0 = perspective_correct_interpolation(
                geometry_0, pixels_0)
            self.assertAllClose(barycentrics_gt_0,
                                predicted_fb.barycentrics.value[0, 2:, 3:, :],
                                atol=1e-3)

        with self.subTest(name="barycentric_coordinates_triangle_1"):
            geometry_1 = tf.gather(vertices, triangles[1, :], axis=1)
            pixels_1 = tf.transpose(grid.generate((3.5, 0.5), (6.5, 1.5),
                                                  (4, 2)),
                                    perm=(1, 0, 2))
            barycentrics_gt_1 = perspective_correct_interpolation(
                geometry_1, pixels_1)
            self.assertAllClose(barycentrics_gt_1,
                                predicted_fb.barycentrics.value[0, 0:2, 3:, :],
                                atol=1e-3)
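For reference, the half-integer bounds in the grid.generate calls above produce pixel centers; a short sketch of the first one (module path assumed from TensorFlow Graphics).

import tensorflow as tf
from tensorflow_graphics.geometry.representation import grid

# Centers x in {3.5, 4.5, 5.5, 6.5} and y in {2.5, 3.5, 4.5}; the transpose
# swaps to (row, column, 2) so the grid aligns with the framebuffer layout.
pixels_0 = tf.transpose(grid.generate((3.5, 2.5), (6.5, 4.5), (4, 3)),
                        perm=(1, 0, 2))
print(pixels_0.shape)  # (3, 4, 2)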
Example #8
def random_patches(focal: tf.Tensor,
                   principal_point: tf.Tensor,
                   height: int,
                   width: int,
                   patch_height: int,
                   patch_width: int,
                   scale: float = 1.0,
                   indexing: str = "ij",
                   name: str = None) -> Tuple[tf.Tensor, tf.Tensor]:
    """Sample patches at different scales and from an image.

  Args:
    focal: A tensor of shape `[A1, ..., An, 2]`
    principal_point: A tensor of shape `[A1, ..., An, 2]`
    height: The height of the image plane in pixels.
    width: The width of the image plane in pixels.
    patch_height: The height M of the patch in pixels.
    patch_width: The width N of the patch in pixels.
    scale: The scale of the patch.
    indexing: Indexing of the patch ('ij' or 'xy')
    name: A name for this op that defaults to "random_patches".

  Returns:
    A tensor of shape `[A1, ..., An, M*N, 3]`, where the last dimension holds
      the directions of the 3D rays through the M*N pixels of the patch, and
      a tensor of shape `[A1, ..., An, M*N, 2]` with the pixel x, y locations.
  """
    with tf.compat.v1.name_scope(name, "random_patches",
                                 [focal, principal_point]):
        focal = tf.convert_to_tensor(value=focal)
        principal_point = tf.convert_to_tensor(value=principal_point)

        shape.check_static(tensor=focal,
                           tensor_name="focal",
                           has_dim_equals=(-1, 2))
        shape.check_static(tensor=principal_point,
                           tensor_name="principal_point",
                           has_dim_equals=(-1, 2))
        shape.compare_batch_dimensions(tensors=(focal, principal_point),
                                       tensor_names=("focal",
                                                     "principal_point"),
                                       last_axes=-2,
                                       broadcast_compatible=True)

        if indexing not in ["xy", "ij"]:
            raise ValueError("'axis' needs to be 'xy' or 'ij'")

        batch_shape = tf.shape(focal)[:-1]
        patch = grid.generate([0, 0], [patch_width - 1, patch_height - 1],
                              [patch_width, patch_height])
        if indexing == "xy":
            patch = tf.reverse(patch, axis=[-1])
        patch = tf.cast(patch, tf.float32)
        patch = patch * scale

        interm_shape = tf.concat([tf.ones_like(batch_shape),
                                  tf.shape(patch)],
                                 axis=0)
        patch = tf.reshape(patch, interm_shape)

        random_y = tf.random.uniform(batch_shape,
                                     minval=0,
                                     maxval=height -
                                     int(patch_height * scale) + 1,
                                     dtype=tf.int32)
        random_x = tf.random.uniform(batch_shape,
                                     minval=0,
                                     maxval=width - int(patch_width * scale) +
                                     1,
                                     dtype=tf.int32)

        patch_origins = tf.cast(tf.stack([random_x, random_y], axis=-1),
                                tf.float32)
        patch_origins = tf.expand_dims(tf.expand_dims(patch_origins, -2), -2)

        pixels = tf.cast(patch + patch_origins, tf.float32)

        final_shape = tf.concat([batch_shape, [patch_height * patch_width, 2]],
                                axis=0)
        pixels = tf.reshape(pixels, final_shape)

        rays = ray(pixels, tf.expand_dims(focal, -2),
                   tf.expand_dims(principal_point, -2))
        return rays, pixels
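A hedged usage sketch with made-up intrinsics and a batch of one camera, so the `[A1, ..., An]` prefix is `[1]`.

import tensorflow as tf

focal = tf.constant([[100.0, 100.0]])
principal_point = tf.constant([[64.0, 64.0]])
rays, pixels = random_patches(focal, principal_point,
                              height=128, width=128,
                              patch_height=8, patch_width=8)
print(rays.shape, pixels.shape)  # (1, 64, 3) (1, 64, 2)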
Example #9
    def test_rasterize_preset(self):
        model_to_eye_matrix = rasterization_test_utils.make_look_at_matrix(
            look_at_point=(0.0, 0.0, 1.0))
        perspective_matrix = rasterization_test_utils.make_perspective_matrix(
            _IMAGE_WIDTH, _IMAGE_HEIGHT)
        view_projection_matrix = tf.linalg.matmul(perspective_matrix,
                                                  model_to_eye_matrix)
        view_projection_matrix = tf.expand_dims(view_projection_matrix, axis=0)

        depth = 1.0
        vertices = np.array([[(-2.0 * _TRIANGLE_SIZE, 0.0, depth),
                              (0.0, _TRIANGLE_SIZE, depth), (0.0, 0.0, depth),
                              (0.0, -_TRIANGLE_SIZE, depth)]],
                            dtype=np.float32)
        triangles = np.array(((1, 2, 0), (0, 2, 3)), np.int32)

        predicted_fb = _proxy_rasterize(vertices, triangles,
                                        view_projection_matrix)

        with self.subTest(name="triangle_index"):
            groundtruth_triangle_index = np.zeros(
                (1, _IMAGE_HEIGHT, _IMAGE_WIDTH, 1), dtype=np.int32)
            groundtruth_triangle_index[..., :_IMAGE_WIDTH // 2, 0] = 0
            groundtruth_triangle_index[..., :_IMAGE_HEIGHT // 2,
                                       _IMAGE_WIDTH // 2:, 0] = 1
            self.assertAllEqual(groundtruth_triangle_index,
                                predicted_fb.triangle_id)

        with self.subTest(name="mask"):
            groundtruth_mask = np.ones((1, _IMAGE_HEIGHT, _IMAGE_WIDTH, 1),
                                       dtype=np.int32)
            groundtruth_mask[..., :_IMAGE_WIDTH // 2, 0] = 0
            self.assertAllEqual(groundtruth_mask, predicted_fb.foreground_mask)

        attributes = np.array(((1.0, 0.0, 0.0), (0.0, 1.0, 0.0),
                               (0.0, 0.0, 1.0))).astype(np.float32)
        perspective_correct_interpolation = lambda geometry, pixels: glm.perspective_correct_interpolation(  # pylint: disable=g-long-lambda,line-too-long
            geometry, attributes, pixels, model_to_eye_matrix,
            perspective_matrix,
            np.array((_IMAGE_WIDTH, _IMAGE_HEIGHT)).astype(np.float32),
            np.array((0.0, 0.0)).astype(np.float32))
        with self.subTest(name="barycentric_coordinates_triangle_0"):
            geometry_0 = tf.gather(vertices, triangles[0, :], axis=1)
            pixels_0 = tf.transpose(grid.generate((3.5, 2.5), (6.5, 4.5),
                                                  (4, 3)),
                                    perm=(1, 0, 2))
            barycentrics_gt_0 = perspective_correct_interpolation(
                geometry_0, pixels_0)
            self.assertAllClose(barycentrics_gt_0,
                                predicted_fb.barycentrics.value[0, 2:, 3:, :],
                                atol=1e-3)

        with self.subTest(name="barycentric_coordinates_triangle_1"):
            geometry_1 = tf.gather(vertices, triangles[1, :], axis=1)
            pixels_1 = tf.transpose(grid.generate((3.5, 0.5), (6.5, 1.5),
                                                  (4, 2)),
                                    perm=(1, 0, 2))
            barycentrics_gt_1 = perspective_correct_interpolation(
                geometry_1, pixels_1)
            self.assertAllClose(barycentrics_gt_1,
                                predicted_fb.barycentrics.value[0, 0:2, 3:, :],
                                atol=1e-3)
Example #10
    def update(self, labeled_sdfs, labeled_classes, labeled_poses,
               predicted_sdfs, predicted_classes, predicted_poses):
        """Update."""
        if labeled_sdfs or labeled_classes:
            print(labeled_sdfs)  # Debug print; labels are not scored here.
        mean_x = tf.reduce_mean(labeled_poses[1][:, 0])
        mean_z = tf.reduce_mean(labeled_poses[1][:, 2])
        samples_world = grid.generate(
            (mean_x - 0.5, 0.0, mean_z - 0.5),
            (mean_x + 0.5, 1.0, mean_z + 0.5),
            [self.resolution, self.resolution, self.resolution])
        samples_world = tf.reshape(samples_world, [-1, 3])

        # Debug visualization (disabled by default).
        status = False
        if status:
            _, axs = plt.subplots(3, 3)
            fig_obj_count = 0

        # Accumulate occupancy over the predicted objects
        num_collisions = 0
        prev_intersection = 0
        sdf_values = tf.zeros_like(samples_world)[:, 0:1]
        for classes, sdfs, poses in [(predicted_classes, predicted_sdfs,
                                      predicted_poses)]:
            for i in range(classes.shape[0]):
                sdf = tf.expand_dims(sdfs[i], -1)
                sdf = sdf * -1.0  # flip sign: inside positive, outside negative
                samples_object = centernet_utils.transform_pointcloud(
                    tf.reshape(samples_world, [1, 1, -1, 3]),
                    tf.reshape(poses[2][i], [1, 1, 3]),
                    tf.reshape(poses[0][i], [1, 1, 3, 3]),
                    tf.reshape(poses[1][i], [1, 1, 3]),
                    inverse=True) * 2.0
                samples_object = (samples_object *
                                  (29.0 / 32.0) / 2.0 + 0.5) * 32.0 - 0.5
                samples = tf.squeeze(samples_object)
                interpolated = trilinear.interpolate(sdf, samples)
                occupancy_value = tf.math.sign(
                    tf.nn.relu(interpolated + self.tol))
                sdf_values += occupancy_value
                intersection = tf.reduce_sum(
                    tf.math.sign(tf.nn.relu(sdf_values - 1)))
                if intersection > prev_intersection:
                    prev_intersection = intersection
                    num_collisions += 1
                # Debug visualization (disabled by default).
                status2 = False
                if status2:
                    a = 1  # projection axis for the 2D debug views
                    values = interpolated
                    inter = tf.reshape(
                        values,
                        [self.resolution, self.resolution, self.resolution])
                    inter = tf.transpose(tf.reduce_max(inter, axis=a))
                    im = axs[fig_obj_count, 0].matshow(inter.numpy())
                    plt.colorbar(im, ax=axs[fig_obj_count, 0])

                    values = tf.math.sign(tf.nn.relu(interpolated + self.tol))
                    inter = tf.reshape(
                        values,
                        [self.resolution, self.resolution, self.resolution])
                    inter = tf.transpose(tf.reduce_max(inter, axis=a))
                    im = axs[fig_obj_count, 1].matshow(inter.numpy())
                    plt.colorbar(im, ax=axs[fig_obj_count, 1])

                    values = sdf_values
                    inter = tf.reshape(
                        values,
                        [self.resolution, self.resolution, self.resolution])
                    inter = tf.transpose(tf.reduce_max(inter, axis=a))
                    im = axs[fig_obj_count, 2].matshow(inter.numpy())
                    plt.colorbar(im, ax=axs[fig_obj_count, 2])

                    fig_obj_count += 1

        intersection = tf.reduce_sum(tf.math.sign(tf.nn.relu(sdf_values - 1)))
        union = tf.reduce_sum(tf.math.sign(sdf_values))
        iou = intersection / union
        self.collisions.append(num_collisions)
        self.intersections.append(intersection)
        self.ious.append(iou)
        return num_collisions, intersection, iou
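A hedged sketch of the occupancy conversion used above, in isolation. The data and tolerance are made up, and the module path is assumed from TensorFlow Graphics.

import tensorflow as tf
from tensorflow_graphics.math.interpolation import trilinear

sdf_grid = tf.random.normal((8, 8, 8, 1))           # stand-in SDF, inside > 0
samples = tf.random.uniform((100, 3), 0.0, 7.0)     # query points, grid coords
values = trilinear.interpolate(sdf_grid, samples)   # (100, 1)
occupancy = tf.math.sign(tf.nn.relu(values + 0.1))  # 1 inside (within tol), else 0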
Example #11
    def update(self, labeled_sdfs, labeled_classes, labeled_poses,
               predicted_sdfs, predicted_classes, predicted_poses):
        """Update."""
        labeled_rotations = labeled_poses[0]
        labeled_translations = labeled_poses[1]
        labeled_sizes = labeled_poses[2]

        # Estimate the scene center from the labeled boxes; the else branch
        # falls back to the mean labeled translation.
        status = True
        if status:
            box_limits_x = [100, -100]
            # box_limits_y = [100, -100]
            box_limits_z = [100, -100]
            for i in range(labeled_translations.shape[0]):
                rot = tf.reshape(tf.gather(labeled_rotations[i], [0, 2, 6, 8]),
                                 [2, 2])

                min_x = tf.cast(0.0 - labeled_sizes[i][0] / 2.0,
                                dtype=tf.float32)
                max_x = tf.cast(0.0 + labeled_sizes[i][0] / 2.0,
                                dtype=tf.float32)
                # min_y = tf.cast(0.0 - labeled_sizes[i][1] / 2.0, dtype=tf.float32)
                # max_y = tf.cast(0.0 + labeled_sizes[i][1] / 2.0, dtype=tf.float32)
                min_z = tf.cast(0.0 - labeled_sizes[i][2] / 2.0,
                                dtype=tf.float32)
                max_z = tf.cast(0.0 + labeled_sizes[i][2] / 2.0,
                                dtype=tf.float32)

                translation = tf.reshape(
                    [labeled_translations[i][0], labeled_translations[i][2]],
                    [2, 1])

                pt_0 = rot @ tf.reshape([min_x, min_z], [2, 1]) + translation
                pt_1 = rot @ tf.reshape([min_x, max_z], [2, 1]) + translation
                pt_2 = rot @ tf.reshape([max_x, min_z], [2, 1]) + translation
                pt_3 = rot @ tf.reshape([max_x, max_z], [2, 1]) + translation

                for pt in [pt_0, pt_1, pt_2, pt_3]:
                    if pt[0] < box_limits_x[0]:
                        box_limits_x[0] = pt[0]

                    if pt[0] > box_limits_x[1]:
                        box_limits_x[1] = pt[0]

                    if pt[1] < box_limits_z[0]:
                        box_limits_z[0] = pt[1]

                    if pt[1] > box_limits_z[1]:
                        box_limits_z[1] = pt[1]
            mean_x = tf.reduce_mean(box_limits_x)
            mean_z = tf.reduce_mean(box_limits_z)
        else:
            mean_x = tf.reduce_mean(labeled_translations[:, 0])
            mean_z = tf.reduce_mean(labeled_translations[:, 2])
        samples_world = grid.generate(
            (mean_x - 0.5, 0.0, mean_z - 0.5),
            (mean_x + 0.5, 1.0, mean_z + 0.5),
            [self.resolution, self.resolution, self.resolution])
        # samples_world = grid.generate(
        #     (box_limits_x[0][0], box_limits_y[0], box_limits_z[0][0]),
        #     (box_limits_x[1][0], box_limits_y[1], box_limits_z[1][0]),
        #     [self.resolution, self.resolution, self.resolution])
        # samples_world = grid.generate(
        #     (-5.0, -5.0, -5.0),
        #     (5.0, 5.0, 5.0),
        #     [self.resolution, self.resolution, self.resolution])
        samples_world = tf.reshape(samples_world, [-1, 3])
        ious = []

        # Debug visualization (disabled by default).
        status = False
        if status:
            _, axs = plt.subplots(labeled_translations.shape[0], 5)
            fig_obj_count = 0
        for class_id in range(self.max_num_classes):
            # Do the same for the ground truth and predictions
            sdf_values = tf.zeros_like(samples_world)[:, 0:1]
            for mtype, (classes, sdfs, poses) in enumerate([
                (labeled_classes, labeled_sdfs, labeled_poses),
                (predicted_classes, predicted_sdfs, predicted_poses)
            ]):
                for i in range(classes.shape[0]):
                    if class_id == classes[i]:
                        sdf = tf.expand_dims(sdfs[i], -1)
                        sdf = sdf * -1.0  # flip sign: inside positive
                        samples_object = centernet_utils.transform_pointcloud(
                            tf.reshape(samples_world, [1, 1, -1, 3]),
                            tf.reshape(poses[2][i], [1, 1, 3]),
                            tf.reshape(poses[0][i], [1, 1, 3, 3]),
                            tf.reshape(poses[1][i], [1, 1, 3]),
                            inverse=True) * 2.0
                        samples_object = \
                            (samples_object * (29.0/32.0) / 2.0 + 0.5) * 32.0 - 0.5
                        samples = tf.squeeze(samples_object)
                        interpolated = trilinear.interpolate(sdf, samples)

                        sdf_values += tf.math.sign(
                            tf.nn.relu(interpolated + self.tol))
                        # Debug visualization (disabled by default).
                        status2 = False
                        if status2:
                            a = 2  # projection axis for the 2D debug views
                            values = interpolated
                            inter = tf.reshape(values, [
                                self.resolution, self.resolution,
                                self.resolution
                            ])
                            inter = tf.transpose(tf.reduce_max(inter, axis=a))
                            im = axs[fig_obj_count,
                                     mtype * 2 + 0].matshow(inter.numpy())
                            plt.colorbar(im,
                                         ax=axs[fig_obj_count, mtype * 2 + 0])
                            print(mtype, fig_obj_count, 0)

                            values = tf.math.sign(
                                tf.nn.relu(interpolated + self.tol))
                            inter = tf.reshape(values, [
                                self.resolution, self.resolution,
                                self.resolution
                            ])
                            inter = tf.transpose(tf.reduce_max(inter, axis=a))
                            im = axs[fig_obj_count,
                                     mtype * 2 + 1].matshow(inter.numpy())
                            plt.colorbar(im,
                                         ax=axs[fig_obj_count, mtype * 2 + 1])
                            print(mtype, fig_obj_count, 1)

                            if mtype == 1:
                                values = sdf_values
                                inter = tf.reshape(values, [
                                    self.resolution, self.resolution,
                                    self.resolution
                                ])
                                inter = tf.transpose(
                                    tf.reduce_max(inter, axis=a))
                                im = axs[fig_obj_count,
                                         4].matshow(inter.numpy())
                                plt.colorbar(im, ax=axs[fig_obj_count, 4])
                                print(mtype, fig_obj_count, 2)
                                fig_obj_count += 1

            intersection = tf.reduce_sum(
                tf.math.sign(tf.nn.relu(sdf_values - 1)))
            union = tf.reduce_sum(tf.math.sign(sdf_values))
            iou = intersection / union
            if not tf.math.is_nan(iou):
                ious.append(iou)
            # Debug scatter plot (disabled by default).
            status3 = False
            if status3:
                _ = plt.figure(figsize=(5, 5))
                plt.clf()
                # mask = (sdf_values.numpy() > 0)[:, 0]
                # plt.scatter(samples_world.numpy()[mask, 0],
                #             samples_world.numpy()[mask, 1],
                #             marker='.', c=sdf_values.numpy()[mask, 0])

                plt.scatter(samples_world.numpy()[:, 0],
                            samples_world.numpy()[:, 1],
                            marker='.',
                            c=sdf_values.numpy()[:, 0])
                plt.colorbar()
            if not tf.math.is_nan(iou):
                self.iou_per_class[class_id].append(iou)
        if not ious:  # guard against an empty list (all IoUs were NaN)
            ious = [0]
        return np.mean(ious), np.min(ious)