예제 #1
0
 def test_equal(self):
     """Check ``KittiLabel.equal`` on identical and on differing label files.

     Two independently parsed copies of ``test_KittiLabel_000012.txt`` must
     compare equal in every pairing, self-comparison included. A label parsed
     from 000012 and one parsed from 000003 must compare unequal in both
     directions. Only the 'Car' and 'Van' classes are passed to ``equal``.
     """
     path_a = './unit-test/data/test_KittiLabel_000012.txt'
     path_b = './unit-test/data/test_KittiLabel_000003.txt'

     def same(lhs, rhs):
         # A fresh class list per call, exactly as the inline calls had.
         return lhs.equal(rhs, ['Car', 'Van'], rtol=1e-5)

     first = KittiLabel(path_a).read_label_file(no_dontcare=True)
     second = KittiLabel(path_a).read_label_file(no_dontcare=True)
     for lhs, rhs in ((first, first), (first, second),
                      (second, first), (second, second)):
         self.assertTrue(same(lhs, rhs))

     first = KittiLabel(path_a).read_label_file(no_dontcare=True)
     second = KittiLabel(path_b).read_label_file(no_dontcare=True)
     self.assertFalse(same(first, second))
     self.assertFalse(same(second, first))
예제 #2
0
 def test_targets(self):
     """Round-trip regression targets through decode + rotated NMS and compare to GT.

     Steps:

     1. Take one batch from ``self.dataloader``.
     2. Treat its encoded regression targets (``data["reg_targets"]``) as if
        they were network predictions. One-hot class "scores" are derived
        from ``data["labels"]``.
     3. Decode them with ``self.box_coder.decode`` and run ``rotate_nms``.
     4. Convert the surviving boxes back into KITTI camera-frame objects.
     5. Assert they equal the ground-truth label of the same batch element,
        filtered to the same lidar-frame range, for 'Car' and 'Pedestrian'.

     Needs a CUDA device (tensors are moved with ``.cuda()``) and the calib
     fixtures under ``unit_tests/data/test_kittidata/training/calib``.
     """
     import torch
     import torch.nn as nn
     from det3.methods.second.ops.torch_ops import rotate_nms
     from det3.dataloader.kittidata import KittiCalib, KittiObj, KittiLabel

     classes = ["Car", "Pedestrian"]

     def _clear_2d_fields(obj):
         # Zero everything not derived from the 3D box, so the comparison
         # depends only on type, location, dimensions and rotation.
         obj.truncated = 0
         obj.occluded = 0
         obj.alpha = 0
         obj.bbox_l = 0
         obj.bbox_t = 0
         obj.bbox_r = 0
         obj.bbox_b = 0

     def _build_gt_label(label_str, calib):
         # Parse a KITTI label string and keep Car/Pedestrian objects whose
         # lidar-frame box centre lies inside x in [0, 52.8],
         # y in [-30, 30] and z in [-3, 1].
         label_gt = KittiLabel()
         for obj_str in label_str.split("\n"):
             if not obj_str:
                 continue
             obj = KittiObj(obj_str)
             if obj.type not in classes:
                 continue
             bcenter_Fcam = np.array([obj.x, obj.y, obj.z]).reshape(-1, 3)
             bcenter_Flidar = calib.leftcam2lidar(bcenter_Fcam)
             # Shift from the bottom centre to the geometric centre (+h/2 on
             # lidar z) before the range check.
             center_Flidar = bcenter_Flidar + np.array([0, 0, obj.h / 2.0]).reshape(-1, 3)
             x, y, z = center_Flidar[0, 0], center_Flidar[0, 1], center_Flidar[0, 2]
             if x < 0 or x > 52.8 or y < -30 or y > 30 or z < -3 or z > 1:
                 continue
             _clear_2d_fields(obj)
             label_gt.add_obj(obj)
         return label_gt

     def _build_est_label(boxes_lidar, labels, calib):
         # Turn decoded lidar-frame boxes back into camera-frame KittiObjs.
         # The box layout is assumed to be x, y, z, w, l, h, ry (as the
         # original variable name says). TODO confirm against box_coder.
         label_est = KittiLabel()
         for box3d_lidar, cls_idx in zip(boxes_lidar, labels):
             obj = KittiObj()
             obj.type = "Car" if cls_idx == 0 else "Pedestrian"
             xyzwlhry_Flidar = box3d_lidar.cpu().numpy().flatten()
             bcenter_Fcam = calib.lidar2leftcam(xyzwlhry_Flidar[:3].reshape(-1, 3))
             obj.x, obj.y, obj.z = bcenter_Fcam.flatten()
             obj.w, obj.l, obj.h, obj.ry = xyzwlhry_Flidar[3:]
             _clear_2d_fields(obj)
             label_est.add_obj(obj)
         return label_est

     # Use the batch at index 2. If the loader is shorter, the last batch is
     # used, as before. An empty loader now fails with a clear message
     # instead of an UnboundLocalError.
     data = None
     for i, data in enumerate(self.dataloader):
         if i == 2:
             break
     self.assertIsNotNone(data, "dataloader yielded no batches")

     # Negative (ignored) labels become background 0. One-hot over background
     # plus the two classes, then drop the background column.
     cls_pred = torch.from_numpy(data["labels"]).cuda().float()
     cls_pred *= (cls_pred >= 0).float()
     cls_pred = nn.functional.one_hot(cls_pred.long(), num_classes=2 + 1)[..., 1:]
     anchors = torch.from_numpy(data["anchors"]).cuda().float()
     box_pred = torch.from_numpy(data["reg_targets"]).cuda().float()
     box_pred = self.box_coder.decode(box_pred, anchors)

     for idx, (box_preds, cls_preds) in enumerate(zip(box_pred, cls_pred)):
         # Each batch element is compared with its own metadata; the old code
         # always used metadata[0].
         metadata = data["metadata"][idx]
         box_preds = box_preds.float()
         top_scores, top_labels = torch.max(cls_preds.float(), dim=-1)
         keep = top_scores >= 0.5
         top_scores = top_scores.masked_select(keep)
         box_preds = box_preds[keep]
         top_labels = top_labels[keep]
         # Columns 0, 1, 3, 4, 6 of the decoded box go to rotate_nms.
         selected = rotate_nms(
             box_preds[:, [0, 1, 3, 4, 6]],
             top_scores,
             pre_max_size=1000,
             post_max_size=1000,
             iou_threshold=0.3,
         )
         calib = KittiCalib(
             f"unit_tests/data/test_kittidata/training/calib/{metadata['tag']}.txt"
         ).read_calib_file()
         label_gt = _build_gt_label(metadata["label"], calib)
         label_est = _build_est_label(box_preds[selected], top_labels[selected], calib)
         self.assertTrue(label_gt.equal(label_est, acc_cls=["Car", "Pedestrian"], rtol=1e-2))