示例#1
0
    def __call__(self, proposals, bbx, cat, iscrowd):
        """Match proposals to ground truth boxes

        For every image the ground truth boxes are prepended to the proposals, proposals that overlap crowd / void
        regions by at least `self.void_threshold` are discarded, and the remaining ones are labelled positive or
        negative from their best IoU with the ground truth before being sub-sampled by `self._subsample`.

        Parameters
        ----------
        proposals : PackedSequence
            A sequence of N tensors with shapes P_i x 4 containing bounding box proposals, entries can be None
        bbx : sequence of torch.Tensor
            A sequence of N tensors with shapes K_i x 4 containing ground truth bounding boxes, entries can be None
        cat : sequence of torch.Tensor
            A sequence of N tensors with shapes K_i containing ground truth instance -> category mappings, entries can
            be None. The values are not read by this method; the sequence only takes part in the lock-step iteration
        iscrowd : sequence of torch.Tensor
            Sequence of N tensors of ground truth crowd regions (shapes H_i x W_i), or ground truth crowd bounding boxes
            (shapes K_i x 4), entries can be None

        Returns
        -------
        out_proposals : PackedSequence
            A sequence of N tensors with shapes S_i x 4 containing the selected non-void bounding box proposals,
            entries are None for images that have no ground truth boxes or for which no proposal is selected
        match : PackedSequence
            A sequence of matching results with shape S_i, with the following semantic:
              - match[i][j] == -1: the j-th selected proposal in image i is negative
              - match[i][j] == k, k >= 0: the j-th selected proposal in image i is matched to bbx[i][k]
            Entries are None wherever the corresponding entry of `out_proposals` is None
        """
        out_proposals = []
        match = []

        # `cat` is zipped only so that iteration stops at the shortest input, exactly as before
        for proposals_i, bbx_i, _cat_i, iscrowd_i in zip(
                proposals, bbx, cat, iscrowd):
            try:
                # Images without ground truth never produce an output. Making this an explicit guard removes
                # the former "no ground truth" labelling branch, which could not be reached and would have
                # failed on an unbound `best_gt` if it ever had been
                if bbx_i is None:
                    raise Empty

                # Prepend ground truth bounding boxes to the proposals before proceeding
                if proposals_i is not None:
                    proposals_i = torch.cat([bbx_i, proposals_i], dim=0)
                else:
                    proposals_i = bbx_i

                # Optionally check overlap with void
                if self.void_threshold != 0 and iscrowd_i is not None:
                    if iscrowd_i.dtype == torch.uint8:
                        # uint8 crowd annotations are handled as masks
                        overlap = mask_overlap(proposals_i, iscrowd_i)
                    else:
                        # Anything else is handled as crowd boxes: keep the largest overlap per proposal
                        overlap = bbx_overlap(proposals_i, iscrowd_i)
                        overlap, _ = overlap.max(dim=1)

                    valid = overlap < self.void_threshold
                    proposals_i = proposals_i[valid]

                if proposals_i.size(0) == 0:
                    raise Empty

                # Find positives and negatives based on IoU
                # NOTE(review): assumes `ious` returns a (proposals x ground truth) matrix -- confirm
                iou = ious(proposals_i, bbx_i)
                best_iou, best_gt = iou.max(dim=1)

                pos_idx = best_iou >= self.pos_threshold
                neg_idx = (best_iou >= self.neg_threshold_lo) & (
                    best_iou < self.neg_threshold_hi)

                # Check that there are still some candidates and do sub-sampling
                if not pos_idx.any().item() and not neg_idx.any().item():
                    raise Empty
                # NOTE(review): `_subsample` presumably turns the masks into index tensors, since its outputs
                # are concatenated and `pos_idx.numel()` is used as the number of positives -- confirm
                pos_idx, neg_idx = self._subsample(pos_idx, neg_idx)

                # Gather selected proposals, positives first
                out_proposals_i = proposals_i[torch.cat([pos_idx, neg_idx])]

                # Save matching: everything is negative (-1) except the leading positives
                match_i = out_proposals_i.new_full((out_proposals_i.size(0), ),
                                                   -1,
                                                   dtype=torch.long)
                match_i[:pos_idx.numel()] = best_gt[pos_idx]

                # Save to output
                out_proposals.append(out_proposals_i)
                match.append(match_i)
            except Empty:
                out_proposals.append(None)
                match.append(None)

        return PackedSequence(out_proposals), PackedSequence(match)
示例#2
0
    def __call__(self, anchors, bbx, iscrowd, valid_size):
        """Label every anchor of every image as void, negative or matched to a ground truth box

        Parameters
        ----------
        anchors : torch.Tensor
            Anchor bounding boxes, shape M x 4, shared by all images
        bbx : sequence of torch.Tensor
            N tensors of ground truth bounding boxes with shapes M_i x 4, entries can be None
        iscrowd : sequence of torch.Tensor
            N tensors of ground truth crowd regions (shapes H_i x W_i), or ground truth crowd bounding boxes
            (shapes K_i x 4), entries can be None
        valid_size : list of tuple of int
            N valid image sizes in input coordinates

        Returns
        -------
        match : torch.Tensor
            Matching results with shape N x M, where:
              - match[i, j] == -2: the j-th anchor in image i is void
              - match[i, j] == -1: the j-th anchor in image i is negative
              - match[i, j] == k, k >= 0: the j-th anchor in image i is matched to bbx[i][k]
        """
        labels = []
        for gt_boxes, crowd, img_size in zip(bbx, iscrowd, valid_size):
            # Until proven otherwise every anchor is void
            labels_i = anchors.new_full((anchors.size(0), ), -2, dtype=torch.long)

            try:
                # Keep only the anchors that `_is_inside` accepts for this image size
                keep = self._is_inside(anchors, img_size)

                # Drop anchors whose overlap with crowd regions reaches the void threshold
                if self.void_threshold != 0 and crowd is not None:
                    if crowd.dtype == torch.uint8:
                        crowd_overlap = mask_overlap(anchors, crowd)
                    else:
                        crowd_overlap, _ = bbx_overlap(anchors, crowd).max(dim=1)

                    keep = keep & (crowd_overlap < self.void_threshold)

                if not keep.any().item():
                    raise Empty

                kept_anchors = anchors[keep]
                num_kept = kept_anchors.size(0)

                if gt_boxes is None:
                    # No ground truth for this image: every surviving anchor is negative
                    kept_labels = kept_anchors.new_full((num_kept, ), -1, dtype=torch.long)
                else:
                    # Running best IoU / ground truth index for each anchor
                    best_iou = gt_boxes.new_zeros(num_kept)
                    best_gt = gt_boxes.new_full((num_kept, ), -1, dtype=torch.long)
                    # Per-chunk best IoU / anchor index for each ground truth box
                    gt_best_iou = []
                    gt_best_anchor = []

                    # Process the ground truth in chunks to bound the size of the IoU matrix
                    for chunk_no, gt_chunk in enumerate(torch.split(gt_boxes, CHUNK_SIZE, dim=0)):
                        iou = ious(kept_anchors, gt_chunk)

                        # Anchor -> GT: update anchors that found a strictly better box in this chunk
                        chunk_iou, chunk_gt = iou.max(dim=1)
                        improved = chunk_iou > best_iou
                        best_gt[improved] = chunk_gt[improved] + chunk_no * CHUNK_SIZE
                        best_iou[improved] = chunk_iou[improved]

                        # GT -> Anchor: best anchor for each box of the chunk
                        chunk_gt_iou, chunk_gt_anchor = iou.transpose(0, 1).max(dim=1)
                        gt_best_iou.append(chunk_gt_iou)
                        gt_best_anchor.append(chunk_gt_anchor)

                        del iou

                    gt_best_iou = torch.cat(gt_best_iou, dim=0)
                    gt_best_anchor = torch.cat(gt_best_anchor, dim=0)

                    is_pos = best_iou >= self.pos_threshold
                    is_neg = best_iou < self.neg_threshold
                    gt_overlaps = gt_best_iou > 0

                    # The order of the three writes matters: later writes override earlier ones
                    kept_labels = kept_anchors.new_full((num_kept, ), -2, dtype=torch.long)
                    kept_labels[is_pos] = best_gt[is_pos]
                    kept_labels[is_neg] = -1
                    # Every box with some overlap claims its best anchor, whatever the thresholds say
                    kept_labels[gt_best_anchor[gt_overlaps]] = gt_overlaps.nonzero().squeeze()

                # Subsample positives and negatives (return value is ignored, so presumably in-place)
                self._subsample(kept_labels)

                labels_i[keep] = kept_labels
            except Empty:
                # Nothing usable in this image: leave all anchors void
                pass

            labels.append(labels_i)

        return torch.stack(labels, dim=0)
示例#3
0
    def zero_matcher(self, anchors, bbx, iscrowd, valid_size):
        """Compute per-image anchor labels against the ground truth boxes

        Parameters
        ----------
        anchors : torch.Tensor
            Anchor bounding boxes; indexed along dimension 0, one row per anchor
        bbx : sequence of torch.Tensor
            Per-image ground truth bounding boxes, entries can be None
        iscrowd : sequence of torch.Tensor
            Per-image crowd annotations, entries can be None. Tensors of dtype uint8 are passed to `mask_overlap`,
            all others to `bbx_overlap`
        valid_size : sequence
            Per-image valid sizes, forwarded to `self._is_inside`

        Returns
        -------
        match : torch.Tensor
            Stacked long tensor with one row per image and one column per anchor, where:
              - -2 marks an anchor rejected by `_is_inside` or the crowd check, or left unassigned by the thresholds
              - -1 marks a negative anchor
              - k >= 0 marks an anchor matched to the k-th ground truth box of that image
        """
        per_image = []
        for gt_i, crowd_i, size_i in zip(bbx, iscrowd, valid_size):
            # Default labels: everything is void
            out_i = anchors.new_full((anchors.size(0), ), -2, dtype=torch.long)

            try:
                # Anchors accepted by `_is_inside` for this image size
                usable = self._is_inside(anchors, size_i)

                # Optionally reject anchors that overlap crowd annotations too much
                check_crowd = self.void_threshold != 0 and crowd_i is not None
                if check_crowd:
                    if crowd_i.dtype == torch.uint8:
                        void_overlap = mask_overlap(anchors, crowd_i)
                    else:
                        box_overlap = bbx_overlap(anchors, crowd_i)
                        void_overlap, _ = box_overlap.max(dim=1)

                    usable = usable & (void_overlap < self.void_threshold)

                if not usable.any().item():
                    raise Empty

                cand = anchors[usable]
                n_cand = cand.size(0)

                if gt_i is not None:
                    # Best IoU and ground truth index seen so far for each candidate anchor
                    a2g_iou = gt_i.new_zeros(n_cand)
                    a2g_idx = gt_i.new_full((n_cand, ), -1, dtype=torch.long)

                    g2a_iou_parts = []
                    g2a_idx_parts = []

                    # Calculate assignments chunk by chunk to save memory
                    for k, gt_chunk in enumerate(torch.split(gt_i, CHUNK_SIZE, dim=0)):
                        iou = ious(cand, gt_chunk)

                        # Anchor -> GT: keep the strictly best match across chunks
                        row_max, row_arg = iou.max(dim=1)
                        better = row_max > a2g_iou

                        a2g_idx[better] = row_arg[better] + k * CHUNK_SIZE
                        a2g_iou[better] = row_max[better]

                        # GT -> Anchor: best anchor for every box in this chunk
                        col_max, col_arg = iou.max(dim=0)

                        g2a_iou_parts.append(col_max)
                        g2a_idx_parts.append(col_arg)

                    g2a_iou = torch.cat(g2a_iou_parts, dim=0)
                    g2a_idx = torch.cat(g2a_idx_parts, dim=0)

                    # At or above the positive threshold -> positive
                    pos_mask = a2g_iou >= self.pos_threshold
                    # Below the negative threshold -> negative
                    neg_mask = a2g_iou < self.neg_threshold
                    # Boxes that overlap at least one candidate anchor
                    gt_hit = g2a_iou > 0

                    # Writes are applied in this order on purpose: later ones win on conflicts
                    cand_match = cand.new_full((n_cand, ), -2, dtype=torch.long)
                    cand_match[pos_mask] = a2g_idx[pos_mask]
                    cand_match[neg_mask] = -1
                    cand_match[g2a_idx[gt_hit]] = gt_hit.nonzero().squeeze()

                else:
                    # No ground truth boxes for this image: everything that is not void is negative
                    cand_match = cand.new_full((n_cand, ), -1, dtype=torch.long)

                # Subsample positives and negatives (result is not used, so presumably done in place)
                self._subsample(cand_match)
                out_i[usable] = cand_match

            except Empty:
                # No usable anchors: keep the all-void default for this image
                pass

            per_image.append(out_i)
        return torch.stack(per_image, dim=0)