def test_construction_methods(self, max_num_points, batch_size,
                              batch_shape):
    """Checks that the three ways of constructing a PointCloud agree.

    A point cloud is built from
      (a) padded points plus per-cloud `sizes`,
      (b) flat (segmented) points plus per-point batch ids, and
      (c) flat points plus flat per-cloud sizes.
    The internal flat representation of (a) and the padded output of
    `get_points` must match across all three constructions.

    Args:
      max_num_points: Padded number of points per cloud.
      batch_size: Number of clouds; must equal the number of entries in
        `sizes` (the reshape of `sizes` below relies on it).
      batch_shape: Batch dimensions passed to the random point cloud helper
        and to `set_batch_shape`.
    """
    points, sizes = utils._create_random_point_cloud_padded(
        max_num_points, batch_shape)
    # Take the point dimension from the data instead of hard-coding 3, so
    # the test stays valid if the helper's default dimension changes.
    dimension = points.shape[-1]

    sizes_flat = sizes.reshape([batch_size])
    points_flat = points.reshape([batch_size, max_num_points, dimension])
    # One id per valid point: cloud i contributes sizes_flat[i] copies of i.
    batch_ids = np.repeat(np.arange(0, batch_size), sizes_flat)

    # valid[i, j] is True iff point j of cloud i is a real (non-padding)
    # point. Boolean-mask indexing walks the mask in row-major order, i.e.
    # cloud by cloud, which matches the ordering of `batch_ids`. Unlike a
    # float64 `np.empty` buffer, it also keeps the dtype of `points`, so all
    # three PointClouds below are built from the same dtype.
    valid = (np.arange(max_num_points)[np.newaxis, :]
             < sizes_flat[:, np.newaxis])
    points_seg = points_flat[valid]

    pc_from_padded = PointCloud(points, sizes=sizes)
    self.assertAllEqual(batch_ids, pc_from_padded._batch_ids)
    self.assertAllClose(points_seg, pc_from_padded._points)

    pc_from_ids = PointCloud(points_seg, batch_ids)
    pc_from_ids.set_batch_shape(batch_shape)
    self.assertAllEqual(batch_ids, pc_from_ids._batch_ids)

    pc_from_sizes = PointCloud(points_seg, sizes=sizes_flat)
    pc_from_sizes.set_batch_shape(batch_shape)
    self.assertAllEqual(batch_ids, pc_from_sizes._batch_ids)

    points_from_padded = pc_from_padded.get_points(
        max_num_points=max_num_points)
    points_from_ids = pc_from_ids.get_points(max_num_points=max_num_points)
    points_from_sizes = pc_from_sizes.get_points(
        max_num_points=max_num_points)

    # Exact equality is transitive, so these two checks cover all three.
    self.assertAllEqual(points_from_padded, points_from_ids)
    self.assertAllEqual(points_from_ids, points_from_sizes)
 def test_flatten_unflatten_padded(self, batch_shape, num_points,
                                   dimension):
     """Checks that padded points survive a PointCloud round trip.

     Random padded points are wrapped in a PointCloud and read back with
     `get_points()`. The result must have the input's shape and reproduce
     every valid point. Every padding slot must be zero.

     Args:
       batch_shape: Batch dimensions of the random point cloud.
       num_points: Padded number of points per cloud.
       dimension: Dimension of each point.
     """
     # `int` is needed because `np.prod([])` is the float 1.0, which
     # `reshape` rejects for an empty batch shape.
     batch_size = int(np.prod(batch_shape))
     points, sizes = utils._create_random_point_cloud_padded(
         num_points, batch_shape, dimension=dimension)
     point_cloud = PointCloud(points, sizes=sizes)
     retrieved_points = point_cloud.get_points().numpy()
     self.assertAllEqual(points.shape, retrieved_points.shape)

     points = points.reshape([batch_size, num_points, dimension])
     retrieved_points = retrieved_points.reshape(
         [batch_size, num_points, dimension])
     sizes = sizes.reshape([batch_size])

     # valid[i, j] is True iff point j of cloud i is a real point. Points
     # past sizes[i] are padding, whose input values are not compared.
     valid = np.arange(num_points)[np.newaxis, :] < sizes[:, np.newaxis]
     self.assertAllClose(points[valid], retrieved_points[valid])
     # assertAllEqual (rather than assertTrue(np.all(...))) reports the
     # offending values on failure.
     padding = retrieved_points[~valid]
     self.assertAllEqual(padding, np.zeros_like(padding))