def test_equal(self):
    """Exercise ``KittiLabel.equal`` on matching and non-matching label files.

    Part one reads the same label file twice and asserts that ``equal``
    holds for every ordered pair of the two results (each against itself
    and against the other). Part two reads two different label files and
    asserts that ``equal`` is falsy in both directions. Every comparison
    is restricted to the classes ``['Car', 'Van']`` with ``rtol=1e-5``.
    """
    path_000012 = './unit-test/data/test_KittiLabel_000012.txt'
    path_000003 = './unit-test/data/test_KittiLabel_000003.txt'

    # Two independent reads of the same file: all four ordered pairs
    # (a, a), (a, b), (b, a), (b, b) must compare equal.
    read_a = KittiLabel(path_000012).read_label_file(no_dontcare=True)
    read_b = KittiLabel(path_000012).read_label_file(no_dontcare=True)
    for lhs in (read_a, read_b):
        for rhs in (read_a, read_b):
            self.assertTrue(lhs.equal(rhs, ['Car', 'Van'], rtol=1e-5))

    # Different files: the comparison must fail whichever side it is
    # invoked on.
    from_000012 = KittiLabel(path_000012).read_label_file(no_dontcare=True)
    from_000003 = KittiLabel(path_000003).read_label_file(no_dontcare=True)
    for lhs, rhs in ((from_000012, from_000003), (from_000003, from_000012)):
        self.assertTrue(not lhs.equal(rhs, ['Car', 'Van'], rtol=1e-5))
def test_targets(self):
    """Round-trip check: decoded training targets must match the ground truth.

    The third batch produced by ``self.dataloader`` is taken. Its target
    class labels are turned into one-hot pseudo scores and its regression
    targets are decoded against the anchors with ``self.box_coder``. The
    result is filtered by score and rotated NMS, converted to ``KittiObj``
    instances in the camera frame, and compared with the ground-truth
    objects parsed from the batch metadata using ``KittiLabel.equal``
    (classes "Car" and "Pedestrian", ``rtol=1e-2``).

    The tensors are moved to the GPU with ``.cuda()``, so a CUDA device is
    required to run this test.
    """
    import torch
    import torch.nn as nn
    from det3.methods.second.ops.torch_ops import rotate_nms
    from det3.dataloader.kittidata import KittiCalib, KittiLabel, KittiObj

    # Advance to the third batch (index 2). If the loader yields fewer
    # batches, `data` stays bound to the last one that was produced.
    for i, data in enumerate(self.dataloader):
        if i == 2:
            break

    # Ground truth is taken from the first metadata entry of the batch.
    label = data["metadata"][0]["label"]
    tag = data["metadata"][0]["tag"]

    # Build pseudo class scores from the target labels: negative labels
    # are zeroed, the result is one-hot encoded over 3 classes and
    # channel 0 is dropped (presumably background -- TODO confirm).
    cls_pred = torch.from_numpy(data["labels"]).cuda().float()
    cls_pred *= (cls_pred >= 0).float()
    cls_pred = cls_pred.long()
    cls_pred = nn.functional.one_hot(cls_pred, num_classes=2+1)
    cls_pred = cls_pred[..., 1:]

    anchors = torch.from_numpy(data["anchors"]).cuda().float()
    box_pred = torch.from_numpy(data["reg_targets"]).cuda().float()
    # NOTE: an earlier revision had (commented-out) code here that dumped
    # {"pred_dict": {"cls_preds": cls_pred * 10, "box_preds": box_pred},
    #  "box_coder": self.box_coder} to "test_model_est.pkl" with
    # det3.ops.write_pkl before decoding.
    box_pred = self.box_coder.decode(box_pred, anchors)

    # Per-sample post-processing: keep anchors whose best class score is
    # at least 0.5, then run rotated NMS on columns [0, 1, 3, 4, 6] of
    # the boxes (presumably x, y, w, l, ry -- TODO confirm).
    for box_preds, cls_preds in zip(box_pred, cls_pred):
        box_preds = box_preds.float()
        cls_preds = cls_preds.float()
        top_scores, top_labels = torch.max(cls_preds, dim=-1)
        top_scores_keep = top_scores >= 0.5
        top_scores = top_scores.masked_select(top_scores_keep)
        box_preds = box_preds[top_scores_keep]
        top_labels = top_labels[top_scores_keep]
        boxes_for_nms = box_preds[:, [0, 1, 3, 4, 6]]
        selected = rotate_nms(
            boxes_for_nms,
            top_scores,
            pre_max_size=1000,
            post_max_size=1000,
            iou_threshold=0.3,
        )
        predictions_dict = {
            "box3d_lidar": box_preds[selected],
            "scores": top_scores[selected],
            "label_preds": top_labels[selected],
        }
    # NOTE(review): `predictions_dict` is overwritten on every iteration,
    # so it ends up holding the LAST sample of the batch, while `label`
    # and `tag` come from metadata[0]. These only agree when the batch
    # holds a single sample -- confirm the loader's batch size.

    # Fields that are reset to 0 on every object before comparison.
    zeroed_fields = ("truncated", "occluded", "alpha",
                     "bbox_l", "bbox_t", "bbox_r", "bbox_b")

    label_gt = KittiLabel()
    label_est = KittiLabel()
    calib = KittiCalib(f"unit_tests/data/test_kittidata/training/calib/{tag}.txt").read_calib_file()

    # Ground truth: keep Car/Pedestrian objects whose box centre, in the
    # lidar frame, lies inside x in [0, 52.8], y in [-30, 30], z in [-3, 1].
    for obj_str in label.split("\n"):
        if len(obj_str) == 0:
            continue
        obj = KittiObj(obj_str)
        if obj.type not in ["Car", "Pedestrian"]:
            continue
        bcenter_Fcam = np.array([obj.x, obj.y, obj.z]).reshape(-1, 3)
        bcenter_Flidar = calib.leftcam2lidar(bcenter_Fcam)
        # Shift by half the height along lidar z to get the box centre.
        center_Flidar = bcenter_Flidar + np.array([0, 0, obj.h/2.0]).reshape(-1, 3)
        if (center_Flidar[0, 0] < 0 or center_Flidar[0, 0] > 52.8
                or center_Flidar[0, 1] < -30 or center_Flidar[0, 1] > 30
                or center_Flidar[0, 2] < -3 or center_Flidar[0, 2] > 1):
            continue
        for field in zeroed_fields:
            setattr(obj, field, 0)
        label_gt.add_obj(obj)

    # Estimates: convert each decoded lidar-frame box back into a
    # camera-frame KittiObj. The scores are iterated but not used.
    for box3d_lidar, label_pred, _score in zip(
            predictions_dict["box3d_lidar"],
            predictions_dict["label_preds"],
            predictions_dict["scores"]):
        obj = KittiObj()
        obj.type = "Car" if label_pred == 0 else "Pedestrian"
        xyzwlhry_Flidar = box3d_lidar.cpu().numpy().flatten()
        bcenter_Flidar = xyzwlhry_Flidar[:3].reshape(-1, 3)
        bcenter_Fcam = calib.lidar2leftcam(bcenter_Flidar)
        obj.x, obj.y, obj.z = bcenter_Fcam.flatten()
        obj.w, obj.l, obj.h, obj.ry = xyzwlhry_Flidar[3:]
        for field in zeroed_fields:
            setattr(obj, field, 0)
        label_est.add_obj(obj)

    self.assertTrue(label_gt.equal(label_est, acc_cls=["Car", "Pedestrian"], rtol=1e-2))