def test_intersection3d_9dof_box(self):
    """Checks 9-DoF pairwise intersection volumes for unit cubes.

    boxes1 holds two unit cubes with identity rotation: one centered at the
    origin and one centered at (1, 1, 1). boxes2 holds a single unit cube at
    the origin. The first pair is identical, so the intersection is 1.0. The
    second pair touches only at the corner (0.5, 0.5, 0.5), so the
    intersection is 0.0. The result is expected to have shape [N, M] = [2, 1].
    """
    boxes1_rotations = np.tile(np.expand_dims(np.eye(3), axis=0), [2, 1, 1])
    boxes1_length = np.array([1.0, 1.0])
    boxes1_height = np.array([1.0, 1.0])
    boxes1_width = np.array([1.0, 1.0])
    boxes1_center = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
    boxes2_rotations = np.expand_dims(np.eye(3), axis=0)
    boxes2_length = np.array([1.0])
    boxes2_height = np.array([1.0])
    boxes2_width = np.array([1.0])
    boxes2_center = np.array([[0.0, 0.0, 0.0]])
    intersections = np_box_ops.intersection3d_9dof_box(
        boxes1_length=boxes1_length,
        boxes1_height=boxes1_height,
        boxes1_width=boxes1_width,
        boxes1_center=boxes1_center,
        boxes1_rotation_matrix=boxes1_rotations,
        boxes2_length=boxes2_length,
        boxes2_height=boxes2_height,
        boxes2_width=boxes2_width,
        boxes2_center=boxes2_center,
        boxes2_rotation_matrix=boxes2_rotations)
    expected_intersections = np.array([[1.0], [0.0]], dtype=np.float32)
    # Keep the tolerance tight. With the previous atol=1.0, any value within
    # 1.0 of the expected volumes (e.g. 0.5 for both pairs) would have passed,
    # so the test could not tell full overlap from corner contact.
    self.assertAllClose(
        intersections, expected_intersections, rtol=1e-3, atol=1e-3)
def intersection3d(boxlist1, boxlist2):
    """Computes pairwise intersection volumes between two 3D box collections.

    Dispatch rules:
      - If both box lists carry a rotation matrix, the 9-DoF intersection
        (full 3x3 rotations) is used.
      - Otherwise the 7-DoF intersection is used, which takes each list's
        rotation about the z axis in radians.

    Args:
      boxlist1: BoxList3d holding N boxes.
      boxlist2: BoxList3d holding M boxes.

    Returns:
      A numpy array with shape [N, M] holding the intersection volume of every
      (boxlist1 box, boxlist2 box) pair.
      NOTE(review): [N, M] is the shape returned by the 9-DoF path; confirm
      the 7-DoF path matches.
    """
    rotation_matrix1 = boxlist1.get_rotation_matrix()
    rotation_matrix2 = boxlist2.get_rotation_matrix()

    # Size and center arguments are identical for both intersection variants;
    # only the rotation representation differs.
    shared_kwargs = {
        'boxes1_length': boxlist1.get_length(),
        'boxes1_height': boxlist1.get_height(),
        'boxes1_width': boxlist1.get_width(),
        'boxes1_center': boxlist1.get_center(),
        'boxes2_length': boxlist2.get_length(),
        'boxes2_height': boxlist2.get_height(),
        'boxes2_width': boxlist2.get_width(),
        'boxes2_center': boxlist2.get_center(),
    }

    if rotation_matrix1 is not None and rotation_matrix2 is not None:
        return np_box_ops.intersection3d_9dof_box(
            boxes1_rotation_matrix=rotation_matrix1,
            boxes2_rotation_matrix=rotation_matrix2,
            **shared_kwargs)

    # At least one list lacks a full rotation matrix: fall back to the
    # z-rotation (yaw-only) representation for both lists.
    return np_box_ops.intersection3d_7dof_box(
        boxes1_rotation_z_radians=boxlist1.get_rotation_z_radians(),
        boxes2_rotation_z_radians=boxlist2.get_rotation_z_radians(),
        **shared_kwargs)