def __call__(self, proposals, bbx, cat, iscrowd):
    """Match proposals to ground truth boxes

    For every image, the ground truth boxes are prepended to the proposals.
    When ``self.void_threshold`` is non-zero, candidates overlapping crowd
    regions by at least that threshold are dropped. The remaining candidates
    are labelled by their best IoU with the ground truth and then sub-sampled
    with ``self._subsample``.

    Parameters
    ----------
    proposals : PackedSequence
        A sequence of N tensors with shapes P_i x 4 containing bounding box proposals, entries can be None
    bbx : sequence of torch.Tensor
        A sequence of N tensors with shapes K_i x 4 containing ground truth bounding boxes, entries can be None.
        Empty (zero-row) tensors are treated the same as None.
    cat : sequence of torch.Tensor
        A sequence of N tensors with shapes K_i containing ground truth instance -> category mappings, entries can
        be None. It is iterated alongside the other inputs but not used by the matching itself.
    iscrowd : sequence of torch.Tensor
        Sequence of N tensors of ground truth crowd regions (shapes H_i x W_i), or ground truth crowd bounding boxes
        (shapes K_i x 4), entries can be None

    Returns
    -------
    out_proposals : PackedSequence
        A sequence of N tensors with shapes S_i x 4 containing the selected non-void proposals, positives first and
        then negatives. An entry is None when the image has no ground truth boxes, when no candidate survives the
        void filter, or when no candidate is either positive or negative.
    match : PackedSequence
        A sequence of matching results with shape S_i, with the following semantic:
          - match[i][j] == -1: the j-th proposal in image i is negative
          - match[i][j] == k, k >= 0: the j-th proposal in image i is matched to bbx[i][k]
    """
    out_proposals = []
    match = []

    for proposals_i, bbx_i, _cat_i, iscrowd_i in zip(proposals, bbx, cat, iscrowd):
        # A zero-row ground truth tensor would make `ious` return a P x 0 matrix,
        # and `max(dim=1)` over an empty dimension raises. Treat it like None.
        if bbx_i is not None and bbx_i.size(0) == 0:
            bbx_i = None

        try:
            # Images without ground truth produce no samples
            if bbx_i is None:
                raise Empty

            # Prepend the ground truth boxes so they are always among the candidates
            if proposals_i is not None:
                proposals_i = torch.cat([bbx_i, proposals_i], dim=0)
            else:
                proposals_i = bbx_i

            # Optionally drop candidates that overlap crowd regions too much
            if self.void_threshold != 0 and iscrowd_i is not None:
                # NOTE(review): only uint8 tensors take the mask path; any other dtype
                # (e.g. torch.bool masks) is treated as crowd boxes. Confirm callers.
                if iscrowd_i.dtype == torch.uint8:
                    overlap = mask_overlap(proposals_i, iscrowd_i)
                else:
                    overlap = bbx_overlap(proposals_i, iscrowd_i)

                overlap, _ = overlap.max(dim=1)
                proposals_i = proposals_i[overlap < self.void_threshold]

            if proposals_i.size(0) == 0:
                raise Empty

            # Label candidates as positive / negative by their best IoU with the ground truth
            iou = ious(proposals_i, bbx_i)
            best_iou, best_gt = iou.max(dim=1)

            pos_idx = best_iou >= self.pos_threshold
            neg_idx = (best_iou >= self.neg_threshold_lo) & (best_iou < self.neg_threshold_hi)

            # Check that there are still some non-voids and do sub-sampling
            if not pos_idx.any().item() and not neg_idx.any().item():
                raise Empty
            pos_idx, neg_idx = self._subsample(pos_idx, neg_idx)

            # Gather selected proposals: positives first, then negatives
            out_proposals_i = proposals_i[torch.cat([pos_idx, neg_idx])]

            # Negatives keep -1; the leading positive slots get their best-matching GT index
            match_i = out_proposals_i.new_full((out_proposals_i.size(0),), -1, dtype=torch.long)
            match_i[:pos_idx.numel()] = best_gt[pos_idx]

            out_proposals.append(out_proposals_i)
            match.append(match_i)
        except Empty:
            out_proposals.append(None)
            match.append(None)

    return PackedSequence(out_proposals), PackedSequence(match)
def __call__(self, anchors, bbx, iscrowd, valid_size):
    """Match anchors to ground truth boxes

    Anchors not entirely inside the valid image area are void. When
    ``self.void_threshold`` is non-zero, anchors overlapping crowd regions by
    at least that threshold are also void. The remaining anchors are labelled
    as follows:
      - IoU with their best GT >= ``self.pos_threshold``: matched to that GT.
      - Best IoU < ``self.neg_threshold``: negative.
      - Otherwise: void.
    In addition, every GT box with a positive best IoU is assigned to its
    highest-IoU anchor, overriding the labels above. The result is finally
    passed to ``self._subsample``.

    Parameters
    ----------
    anchors : torch.Tensor
        Tensors of anchor bounding boxes with shapes M x 4
    bbx : sequence of torch.Tensor
        Sequence of N tensors of ground truth bounding boxes with shapes M_i x 4, entries can be None. Empty
        (zero-row) tensors are treated the same as None.
    iscrowd : sequence of torch.Tensor
        Sequence of N tensors of ground truth crowd regions (shapes H_i x W_i), or ground truth crowd bounding boxes
        (shapes K_i x 4), entries can be None
    valid_size : list of tuple of int
        List of N valid image sizes in input coordinates

    Returns
    -------
    match : torch.Tensor
        Tensor of matching results with shape N x M, with the following semantic:
          - match[i, j] == -2: the j-th anchor in image i is void
          - match[i, j] == -1: the j-th anchor in image i is negative
          - match[i, j] == k, k >= 0: the j-th anchor in image i is matched to bbx[i][k]
    """
    match = []
    for bbx_i, iscrowd_i, valid_size_i in zip(bbx, iscrowd, valid_size):
        # Default labels: everything is void
        match_i = anchors.new_full((anchors.size(0),), -2, dtype=torch.long)

        try:
            # Find anchors that are entirely within the original image area
            valid = self._is_inside(anchors, valid_size_i)

            # Check overlap with crowd
            if self.void_threshold != 0 and iscrowd_i is not None:
                # NOTE(review): only uint8 tensors take the mask path; any other dtype
                # (e.g. torch.bool masks) is treated as crowd boxes. Confirm callers.
                if iscrowd_i.dtype == torch.uint8:
                    overlap = mask_overlap(anchors, iscrowd_i)
                else:
                    overlap = bbx_overlap(anchors, iscrowd_i)

                overlap, _ = overlap.max(dim=1)
                valid = valid & (overlap < self.void_threshold)

            if not valid.any().item():
                raise Empty

            valid_anchors = anchors[valid]

            # A zero-row GT tensor would make `max(dim=1)` reduce over an empty
            # dimension and raise, so it is handled like "no ground truth".
            if bbx_i is not None and bbx_i.size(0) > 0:
                max_a2g_iou = bbx_i.new_zeros(valid_anchors.size(0))
                max_a2g_idx = bbx_i.new_full((valid_anchors.size(0),), -1, dtype=torch.long)
                max_g2a_iou = []
                max_g2a_idx = []

                # Calculate assignments over GT chunks so only one
                # (num_valid_anchors x CHUNK_SIZE) IoU matrix is alive at a time
                for j, bbx_i_j in enumerate(torch.split(bbx_i, CHUNK_SIZE, dim=0)):
                    iou = ious(valid_anchors, bbx_i_j)

                    # Anchor -> GT: keep the best GT seen so far (global GT index)
                    iou_max, iou_idx = iou.max(dim=1)
                    replace_idx = iou_max > max_a2g_iou

                    max_a2g_idx[replace_idx] = iou_idx[replace_idx] + j * CHUNK_SIZE
                    max_a2g_iou[replace_idx] = iou_max[replace_idx]

                    # GT -> Anchor: best anchor for each GT in this chunk
                    max_g2a_iou_j, max_g2a_idx_j = iou.transpose(0, 1).max(dim=1)
                    max_g2a_iou.append(max_g2a_iou_j)
                    max_g2a_idx.append(max_g2a_idx_j)

                    # Release before the next chunk's IoU is computed
                    del iou

                max_g2a_iou = torch.cat(max_g2a_iou, dim=0)
                max_g2a_idx = torch.cat(max_g2a_idx, dim=0)

                a2g_pos = max_a2g_iou >= self.pos_threshold
                a2g_neg = max_a2g_iou < self.neg_threshold
                g2a_pos = max_g2a_iou > 0

                # Order matters: negatives override positives, then the per-GT best
                # anchor assignment overrides both
                valid_match = valid_anchors.new_full((valid_anchors.size(0),), -2, dtype=torch.long)
                valid_match[a2g_pos] = max_a2g_idx[a2g_pos]
                valid_match[a2g_neg] = -1
                valid_match[max_g2a_idx[g2a_pos]] = g2a_pos.nonzero().view(-1)
            else:
                # No ground truth boxes for this image: everything that is not void is negative
                valid_match = valid_anchors.new_full((valid_anchors.size(0),), -1, dtype=torch.long)

            # Subsample positives and negatives; the return value is ignored, so this
            # presumably edits valid_match in place (verify against _subsample)
            self._subsample(valid_match)

            match_i[valid] = valid_match
        except Empty:
            pass

        match.append(match_i)

    return torch.stack(match, dim=0)
def zero_matcher(self, anchors, bbx, iscrowd, valid_size):
    """Match anchors to ground truth boxes

    NOTE(review): this logic duplicates the anchor-matching ``__call__`` that
    appears alongside it in this source. Confirm whether a distinct "zero"
    policy was intended.

    Anchors not entirely inside the valid image area are void. When
    ``self.void_threshold`` is non-zero, anchors overlapping crowd regions by
    at least that threshold are also void. The remaining anchors are labelled
    as follows:
      - IoU with their best GT >= ``self.pos_threshold``: matched to that GT.
      - Best IoU < ``self.neg_threshold``: negative.
      - Otherwise: void.
    In addition, every GT box with a positive best IoU is assigned to its
    highest-IoU anchor, overriding the labels above. The result is finally
    passed to ``self._subsample``.

    Parameters
    ----------
    anchors : torch.Tensor
        Anchor bounding boxes, indexed as M x 4 by the helpers it is passed to
    bbx : sequence of torch.Tensor
        Sequence of N ground truth box tensors, entries can be None. Empty
        (zero-row) tensors are treated the same as None.
    iscrowd : sequence of torch.Tensor
        Sequence of N crowd tensors, entries can be None. uint8 tensors are
        passed to ``mask_overlap``; any other dtype is passed to ``bbx_overlap``.
    valid_size : sequence
        Sequence of N valid image sizes, forwarded to ``self._is_inside``

    Returns
    -------
    match : torch.Tensor
        Tensor with shape N x M:
          - -2: void anchor
          - -1: negative anchor
          - k >= 0: anchor matched to bbx[i][k]
    """
    match = []
    for bbx_i, iscrowd_i, valid_size_i in zip(bbx, iscrowd, valid_size):
        # Default labels: everything is void
        match_i = anchors.new_full((anchors.size(0),), -2, dtype=torch.long)

        try:
            # Find anchors that are entirely within the original image area
            valid = self._is_inside(anchors, valid_size_i)

            # Check overlap with crowd
            if self.void_threshold != 0 and iscrowd_i is not None:
                if iscrowd_i.dtype == torch.uint8:
                    overlap = mask_overlap(anchors, iscrowd_i)
                else:
                    overlap = bbx_overlap(anchors, iscrowd_i)

                overlap, _ = overlap.max(dim=1)
                valid = valid & (overlap < self.void_threshold)

            if not valid.any().item():
                raise Empty

            valid_anchors = anchors[valid]

            # A zero-row GT tensor would make `max(dim=1)` reduce over an empty
            # dimension and raise, so it is handled like "no ground truth".
            if bbx_i is not None and bbx_i.size(0) > 0:
                max_a2g_iou = bbx_i.new_zeros(valid_anchors.size(0))
                max_a2g_idx = bbx_i.new_full((valid_anchors.size(0),), -1, dtype=torch.long)
                max_g2a_iou = []
                max_g2a_idx = []

                # Calculate assignments over GT chunks so only one
                # (num_valid_anchors x CHUNK_SIZE) IoU matrix is alive at a time
                for j, bbx_i_j in enumerate(torch.split(bbx_i, CHUNK_SIZE, dim=0)):
                    iou = ious(valid_anchors, bbx_i_j)

                    # Anchor -> GT: higher than the best seen so far replaces it
                    iou_max, iou_idx = iou.max(dim=1)
                    replace_idx = iou_max > max_a2g_iou

                    max_a2g_idx[replace_idx] = iou_idx[replace_idx] + j * CHUNK_SIZE
                    max_a2g_iou[replace_idx] = iou_max[replace_idx]

                    # GT -> Anchor: best anchor for each GT in this chunk
                    max_g2a_iou_j, max_g2a_idx_j = iou.max(dim=0)
                    max_g2a_iou.append(max_g2a_iou_j)
                    max_g2a_idx.append(max_g2a_idx_j)

                    # Release this chunk's IoU before the next one is allocated;
                    # otherwise two chunk matrices coexist and chunking saves less memory
                    del iou

                max_g2a_iou = torch.cat(max_g2a_iou, dim=0)
                max_g2a_idx = torch.cat(max_g2a_idx, dim=0)

                a2g_pos = max_a2g_iou >= self.pos_threshold  # at/above threshold: positive
                a2g_neg = max_a2g_iou < self.neg_threshold  # below threshold: negative
                g2a_pos = max_g2a_iou > 0  # GTs whose best anchor has non-zero IoU

                # Order matters: negatives override positives, then the per-GT best
                # anchor assignment overrides both
                valid_match = valid_anchors.new_full((valid_anchors.size(0),), -2, dtype=torch.long)
                valid_match[a2g_pos] = max_a2g_idx[a2g_pos]
                valid_match[a2g_neg] = -1
                valid_match[max_g2a_idx[g2a_pos]] = g2a_pos.nonzero().view(-1)
            else:
                # No ground truth boxes for this image: everything that is not void is negative
                valid_match = valid_anchors.new_full((valid_anchors.size(0),), -1, dtype=torch.long)

            # Subsample positives and negatives; the return value is ignored, so this
            # presumably edits valid_match in place (verify against _subsample)
            self._subsample(valid_match)

            match_i[valid] = valid_match
        except Empty:
            pass

        match.append(match_i)

    return torch.stack(match, dim=0)