def od_sample_detection(od_sample_raw_preds, od_detection_mask_dataset):
    """Return a single-image detection result built from the sample raw preds.

    The first raw prediction is thresholded at 0.001 and passed to
    ``_extract_od_results`` with four generic class labels and the first image
    path of ``od_detection_mask_dataset``. The result is tagged with
    ``idx = 0`` and its ``"keypoints"`` entry is removed before returning.
    """
    thresholded_pred = _apply_threshold(od_sample_raw_preds[0], threshold=0.001)
    class_labels = ["one", "two", "three", "four"]
    first_im_path = od_detection_mask_dataset.im_paths[0]

    result = _extract_od_results(thresholded_pred, class_labels, first_im_path)
    result["idx"] = 0
    # Keypoints are dropped from this sample result.
    del result["keypoints"]
    return result
def test__extract_od_results(od_sample_raw_preds, od_data_path_labels):
    """Test that ``_extract_od_results`` can convert raw predictions.

    The first raw prediction's tensors are detached and converted to numpy
    arrays on the CPU, then passed through ``_extract_od_results``. The result
    must hold five ``DetectionBbox`` objects. The ``"masks"`` and
    ``"keypoints"`` arrays must have the expected shapes for this sample.
    """
    pred = {
        k: v.detach().cpu().numpy() for k, v in od_sample_raw_preds[0].items()
    }
    out = _extract_od_results(pred, labels=od_data_path_labels, im_path=None)

    bboxes = out["det_bboxes"]
    # Check the count before touching individual elements so an empty result
    # fails with a clear assertion rather than an IndexError.
    assert len(bboxes) == 5
    assert all(isinstance(bbox, DetectionBbox) for bbox in bboxes)
    # Expected shapes for this particular sample fixture.
    assert out["masks"].shape == (5, 666, 499)
    assert out["keypoints"].shape == (5, 13, 3)