def test_velocity_l2(self):
    """Check velocity_l2() against hand-computed Euclidean distances between velocities."""
    # Each entry: (velocity of first box, velocity of second box, expected distance).
    cases = [
        # Identical velocities give zero distance.
        ((4, 4), (4, 4), 0),
        # Components with opposite signs.
        ((-1, -1), (1, 1), np.sqrt((1 + 1)**2 + (1 + 1)**2)),
        # Arbitrary non-integer values.
        ((8.2, 1.4), (6.4, -9.4), np.sqrt((6.4 - 8.2)**2 + (-9.4 - 1.4)**2)),
    ]
    for vel_first, vel_second, expected in cases:
        box_first = DetectionBox(velocity=vel_first)
        box_second = DetectionBox(velocity=vel_second)
        self.assertAlmostEqual(velocity_l2(box_first, box_second), expected)
def accumulate(gt_boxes: EvalBoxes,
               pred_boxes: EvalBoxes,
               class_name: str,
               dist_fcn: Callable,
               dist_th: float,
               verbose: bool = False) -> DetectionMetricData:
    """
    Average Precision over predefined different recall thresholds for a single distance threshold.
    The recall/conf thresholds and other raw metrics will be used in secondary metrics.
    :param gt_boxes: Maps every sample_token to a list of its sample_annotations.
    :param pred_boxes: Maps every sample_token to a list of its sample_results.
    :param class_name: Class to compute AP on.
    :param dist_fcn: Distance function used to match detections and ground truths.
    :param dist_th: Distance threshold for a match (a match requires distance strictly below it).
    :param verbose: If true, print debug messages.
    :return: A DetectionMetricData holding precision and confidence sampled at
        DetectionMetricData.nelem evenly spaced recall values in [0, 1]. It also holds the
        cumulative-mean true-positive errors (translation, velocity, scale, orientation,
        attribute), re-sampled onto those confidences. DetectionMetricData.no_predictions()
        is returned when the class has no GT boxes, or when no prediction matched any GT box.
    """
    # ---------------------------------------------
    # Organize input and initialize accumulators.
    # ---------------------------------------------

    # Count the positives.
    npos = sum(1 for gt_box in gt_boxes.all if gt_box.detection_name == class_name)
    if verbose:
        print("Found {} GT of class {} out of {} total across {} samples.".
              format(npos, class_name, len(gt_boxes.all), len(gt_boxes.sample_tokens)))

    # For missing classes in the GT, return a data structure corresponding to no predictions.
    if npos == 0:
        return DetectionMetricData.no_predictions()

    # Organize the predictions in a single list.
    pred_boxes_list = [box for box in pred_boxes.all if box.detection_name == class_name]
    pred_confs = [box.detection_score for box in pred_boxes_list]

    if verbose:
        print("Found {} PRED of class {} out of {} total across {} samples.".
              format(len(pred_confs), class_name, len(pred_boxes.all), len(pred_boxes.sample_tokens)))

    # Sort by confidence, highest first. Sorting (value, index) ascending and then reversing
    # also breaks ties by descending index; keep this exact order so results stay reproducible.
    sortind = [i for (v, i) in sorted((v, i) for (i, v) in enumerate(pred_confs))][::-1]

    # Do the actual matching.
    tp = []  # Accumulator of true positives.
    fp = []  # Accumulator of false positives.
    conf = []  # Accumulator of confidences.

    # match_data holds the extra metrics we calculate for each match.
    match_data = {'trans_err': [],
                  'vel_err': [],
                  'scale_err': [],
                  'orient_err': [],
                  'attr_err': [],
                  'conf': [],
                  'ego_dist': [],
                  'vel_magn': []}

    # Barrier orientation is only determined up to 180 degree. (For cones orientation is discarded later)
    # class_name is fixed for the whole call, so compute the period once.
    period = np.pi if class_name == 'barrier' else 2 * np.pi

    # ---------------------------------------------
    # Match and accumulate match data.
    # ---------------------------------------------

    # (sample_token, gt_idx) pairs already matched; initially no gt bounding box is matched.
    taken = set()
    for ind in sortind:
        pred_box = pred_boxes_list[ind]
        min_dist = np.inf
        match_gt_idx = None

        for gt_idx, gt_box in enumerate(gt_boxes[pred_box.sample_token]):

            # Find closest match among ground truth boxes
            if gt_box.detection_name == class_name and (pred_box.sample_token, gt_idx) not in taken:
                this_distance = dist_fcn(gt_box, pred_box)
                if this_distance < min_dist:
                    min_dist = this_distance
                    match_gt_idx = gt_idx

        # If the closest match is close enough according to threshold we have a match!
        is_match = min_dist < dist_th

        if is_match:
            taken.add((pred_box.sample_token, match_gt_idx))

            #  Update tp, fp and confs.
            tp.append(1)
            fp.append(0)
            conf.append(pred_box.detection_score)

            # Since it is a match, update match data also.
            gt_box_match = gt_boxes[pred_box.sample_token][match_gt_idx]

            match_data['trans_err'].append(center_distance(gt_box_match, pred_box))
            match_data['vel_err'].append(velocity_l2(gt_box_match, pred_box))
            match_data['scale_err'].append(1 - scale_iou(gt_box_match, pred_box))
            match_data['orient_err'].append(yaw_diff(gt_box_match, pred_box, period=period))
            match_data['attr_err'].append(1 - attr_acc(gt_box_match, pred_box))
            match_data['conf'].append(pred_box.detection_score)

            # For debugging only.
            match_data['ego_dist'].append(gt_box_match.ego_dist)
            match_data['vel_magn'].append(np.sqrt(np.sum(np.array(gt_box_match.velocity) ** 2)))

        else:
            # No match. Mark this as a false positive.
            tp.append(0)
            fp.append(1)
            conf.append(pred_box.detection_score)

    # Check if we have any matches. If not, just return a "no predictions" array.
    if len(match_data['trans_err']) == 0:
        return DetectionMetricData.no_predictions()

    # ---------------------------------------------
    # Calculate and interpolate precision and recall
    # ---------------------------------------------

    # Accumulate. Use the builtin float: the np.float alias was deprecated in NumPy 1.20 and
    # removed in 1.24, where it raises AttributeError. It was an alias of float, so results match.
    tp = np.cumsum(tp).astype(float)
    fp = np.cumsum(fp).astype(float)
    conf = np.array(conf)

    # Calculate precision and recall. Every prediction adds to tp or fp, so fp + tp >= 1.
    prec = tp / (fp + tp)
    rec = tp / float(npos)

    rec_interp = np.linspace(0, 1, DetectionMetricData.nelem)  # 101 steps, from 0% to 100% recall.
    # Beyond the highest achieved recall, precision and confidence are set to 0 (right=0).
    prec = np.interp(rec_interp, rec, prec, right=0)
    conf = np.interp(rec_interp, rec, conf, right=0)
    rec = rec_interp

    # ---------------------------------------------
    # Re-sample the match-data to match, prec, recall and conf.
    # ---------------------------------------------

    for key in match_data.keys():
        if key == "conf":
            continue  # Confidence is used as reference to align with fp and tp. So skip in this step.

        # For each match_data, we first calculate the accumulated mean.
        tmp = cummean(np.array(match_data[key]))

        # Then interpolate based on the confidences. (Note reversing since np.interp needs increasing arrays)
        match_data[key] = np.interp(conf[::-1], match_data['conf'][::-1], tmp[::-1])[::-1]

    # ---------------------------------------------
    # Done. Instantiate MetricData and return
    # ---------------------------------------------
    return DetectionMetricData(recall=rec,
                               precision=prec,
                               confidence=conf,
                               trans_err=match_data['trans_err'],
                               vel_err=match_data['vel_err'],
                               scale_err=match_data['scale_err'],
                               orient_err=match_data['orient_err'],
                               attr_err=match_data['attr_err'])