def keypoint_rcnn_inference(pred_keypoint_logits, pred_instances):
    """
    Decode keypoint heatmaps into (x, y, score) triples and store them on the
    per-image prediction objects as the field `pred_keypoints`.

    Args:
        pred_keypoint_logits (Tensor): spatial logits of shape (R, K, S, S).
            R is the total number of instances over the whole batch, K the
            number of keypoints, S the side length of a keypoint heatmap.
        pred_instances (list[Instances]): N Instances objects, one per image.
            Each must expose `pred_boxes.tensor` and support `len()`.

    Returns:
        None. Every element of `pred_instances` is modified in place and gains
        a "pred_keypoints" tensor of shape (#instance, K, 3) whose last
        dimension is (x, y, score). The scores are larger than 0.
    """
    # Stack the boxes of every image into one Rx4 tensor so that all heatmaps
    # can be decoded with a single call.
    all_boxes = cat([inst.pred_boxes.tensor for inst in pred_instances], dim=0)

    # Decoding is pure post-processing: both inputs are detached from autograd.
    decoded = heatmaps_to_keypoints(pred_keypoint_logits.detach(), all_boxes.detach())

    instance_counts = [len(inst) for inst in pred_instances]

    # Only columns 0, 1 and 3 of the last axis are kept; these are the
    # (x, y, score) values exposed to callers.
    xy_score = decoded[:, :, [0, 1, 3]]

    # Cut the batch-flattened result back into one chunk per image.
    per_image_keypoints = xy_score.split(instance_counts, dim=0)

    for image_keypoints, image_instances in zip(per_image_keypoints, pred_instances):
        # image_keypoints is (num instances)x(num keypoints)x(x, y, score)
        image_instances.pred_keypoints = image_keypoints
def keypoint_rcnn_inference(pred_keypoint_logits: torch.Tensor, pred_instances: List[Instances]):
    """
    Decode keypoint heatmaps into (x, y, score) triples and store both the
    decoded keypoints and the raw heatmaps on the per-image prediction objects.

    Args:
        pred_keypoint_logits (Tensor): spatial logits of shape (R, K, S, S).
            R is the total number of instances over the whole batch, K the
            number of keypoints, S the side length of a keypoint heatmap.
        pred_instances (list[Instances]): N Instances objects, one per image.
            Each must expose `pred_boxes.tensor` and support `len()`.

    Returns:
        None. Every element of `pred_instances` is modified in place and gains
        two fields:

        * "pred_keypoints": tensor of shape (#instance, K, 3) whose last
          dimension is (x, y, score). The scores are larger than 0.
        * "pred_keypoint_heatmaps": the (detached) keypoint logits given to
          this function, restricted to the instances of that image.
    """
    # Stack the boxes of every image into one Rx4 tensor so that all heatmaps
    # can be decoded with a single call.
    all_boxes = cat([inst.pred_boxes.tensor for inst in pred_instances], dim=0)

    # Detach once and reuse: the same detached logits feed the decoder and are
    # later handed out as the per-image heatmaps.
    pred_keypoint_logits = pred_keypoint_logits.detach()
    decoded = heatmaps_to_keypoints(pred_keypoint_logits, all_boxes.detach())

    instance_counts = [len(inst) for inst in pred_instances]

    # TorchScript cannot slice several dimensions with a Python sequence
    # (https://github.com/pytorch/pytorch/issues/43943), so the columns to
    # keep are passed as a tensor index rather than as a list.
    kept_columns = torch.tensor([0, 1, 3])
    xy_score = decoded[:, :, kept_columns]

    # Cut the batch-flattened tensors back into one chunk per image.
    per_image_keypoints = xy_score.split(instance_counts, dim=0)
    per_image_heatmaps = pred_keypoint_logits.split(instance_counts, dim=0)

    for image_keypoints, image_heatmaps, image_instances in zip(
        per_image_keypoints, per_image_heatmaps, pred_instances
    ):
        # image_keypoints is (num instances)x(num keypoints)x(x, y, score)
        # image_heatmaps is (num instances)x(num keypoints)x(side)x(side)
        image_instances.pred_keypoints = image_keypoints
        image_instances.pred_keypoint_heatmaps = image_heatmaps