Example #1
    def _get_proposal_pairs(self, proposals):
        proposal_pairs = []
        for i, proposals_per_image in enumerate(proposals):
            box_subj = proposals_per_image.bbox
            box_obj = proposals_per_image.bbox

            # build every ordered (subject, object) combination: both become N x N x 4
            box_subj = box_subj.unsqueeze(1).repeat(1, box_subj.shape[0], 1)
            box_obj = box_obj.unsqueeze(0).repeat(box_obj.shape[0], 1, 1)
            proposal_box_pairs = torch.cat((box_subj.view(-1, 4), box_obj.view(-1, 4)), 1)

            # the matching (subject index, object index) for every box pair
            idx_subj = torch.arange(box_subj.shape[0]).view(-1, 1, 1).repeat(1, box_obj.shape[0], 1).to(proposals_per_image.bbox.device)
            idx_obj = torch.arange(box_obj.shape[0]).view(1, -1, 1).repeat(box_subj.shape[0], 1, 1).to(proposals_per_image.bbox.device)
            proposal_idx_pairs = torch.cat((idx_subj.view(-1, 1), idx_obj.view(-1, 1)), 1)

            # drop self-pairs where subject and object are the same proposal
            keep_idx = (proposal_idx_pairs[:, 0] != proposal_idx_pairs[:, 1]).nonzero().view(-1)

            # optionally keep only pairs whose boxes overlap
            if self.cfg.MODEL.ROI_RELATION_HEAD.FILTER_NON_OVERLAP:
                ious = boxlist_iou(proposals_per_image, proposals_per_image).view(-1)
                ious = ious[keep_idx]
                keep_idx = keep_idx[(ious > 0).nonzero().view(-1)]
            proposal_idx_pairs = proposal_idx_pairs[keep_idx]
            proposal_box_pairs = proposal_box_pairs[keep_idx]
            proposal_pairs_per_image = BoxPairList(proposal_box_pairs, proposals_per_image.size, proposals_per_image.mode)
            proposal_pairs_per_image.add_field("idx_pairs", proposal_idx_pairs)

            proposal_pairs.append(proposal_pairs_per_image)
        return proposal_pairs
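A minimal stand-alone sketch of the pairing trick used above, with invented boxes and plain torch tensors (no BoxPairList): N proposals expand into N x N ordered (subject, object) pairs, and keep_idx then drops the N self-pairs.

import torch

# Hypothetical stand-alone illustration (boxes and N = 3 are made up).
boxes = torch.tensor([[0., 0., 10., 10.],
                      [5., 5., 15., 15.],
                      [20., 20., 30., 30.]])
n = boxes.shape[0]

box_subj = boxes.unsqueeze(1).repeat(1, n, 1)          # N x N x 4
box_obj = boxes.unsqueeze(0).repeat(n, 1, 1)           # N x N x 4
box_pairs = torch.cat((box_subj.view(-1, 4), box_obj.view(-1, 4)), 1)  # N*N x 8

idx_subj = torch.arange(n).view(-1, 1, 1).repeat(1, n, 1)
idx_obj = torch.arange(n).view(1, -1, 1).repeat(n, 1, 1)
idx_pairs = torch.cat((idx_subj.view(-1, 1), idx_obj.view(-1, 1)), 1)  # N*N x 2

keep_idx = (idx_pairs[:, 0] != idx_pairs[:, 1]).nonzero().view(-1)
print(idx_pairs[keep_idx])  # 6 ordered pairs: (0,1), (0,2), (1,0), (1,2), (2,0), (2,1)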
Example #2
    def get_triplets_as_string(self, top_obj: BoxList,
                               top_pred: BoxPairList) -> List[str]:
        """
        Given top detected objects and top predicted relationships, return
        the triplets in human-readable form.
        :param top_obj: BoxList containing top detected objects
        :param top_pred: BoxPairList containing the top detected triplets
        :return: List of triplets (in decreasing score order)
        """
        # class labels of the detected objects, shape (num_detected_objects,)
        obj_indices = top_obj.get_field("labels")

        # 100 x 2 (indices in obj_indices)
        obj_pairs_indices = top_pred.get_field("idx_pairs")

        # 100 (indices into the global predicate/relationship class list)
        rel_indices = top_pred.get_field("scores").max(1)[1]

        # 100 x 3
        top_triplets = torch.stack(
            (obj_indices[obj_pairs_indices[:, 0]],
             obj_indices[obj_pairs_indices[:, 1]], rel_indices), 1).tolist()

        idx_to_obj = self.data_loader_test.dataset.ind_to_classes
        idx_to_rel = self.data_loader_test.dataset.ind_to_predicates

        # convert integers to labels
        top_triplets_str = []
        for t in top_triplets:
            top_triplets_str.append(idx_to_obj[t[0]] + " " + idx_to_rel[t[2]] +
                                    " " + idx_to_obj[t[1]])

        return top_triplets_str
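A small sketch of the same label-to-string conversion with hypothetical vocabularies and tensors; idx_to_obj, idx_to_rel and every value below are invented for illustration and are not the dataset's real label lists.

import torch

# Hypothetical vocabularies and predictions, for illustration only.
idx_to_obj = ["__background__", "person", "horse"]
idx_to_rel = ["__no_relation__", "riding"]

obj_indices = torch.tensor([1, 2])          # detected labels: person, horse
obj_pairs_indices = torch.tensor([[0, 1]])  # one (subject, object) pair
rel_scores = torch.tensor([[0.1, 0.9]])     # per-pair relation scores
rel_indices = rel_scores.max(1)[1]          # -> tensor([1])

triplets = torch.stack((obj_indices[obj_pairs_indices[:, 0]],
                        obj_indices[obj_pairs_indices[:, 1]],
                        rel_indices), 1).tolist()
print([idx_to_obj[s] + " " + idx_to_rel[r] + " " + idx_to_obj[o]
       for s, o, r in triplets])            # ['person riding horse']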
Example #3
    def _get_proposal_pairs(self, proposals):
        proposal_pairs = []
        for i, proposals_per_image in enumerate(proposals):
            box_subj = proposals_per_image.bbox
            box_obj = proposals_per_image.bbox
            box_subj = box_subj.unsqueeze(1).repeat(1, box_subj.shape[0], 1)
            box_obj = box_obj.unsqueeze(0).repeat(box_obj.shape[0], 1, 1)
            proposal_box_pairs = torch.cat(
                (box_subj.view(-1, 4), box_obj.view(-1, 4)), 1)

            idx_subj = torch.arange(box_subj.shape[0]).view(-1, 1, 1).repeat(
                1, box_obj.shape[0], 1).to(proposals_per_image.bbox.device)
            idx_obj = torch.arange(box_obj.shape[0]).view(1, -1, 1).repeat(
                box_subj.shape[0], 1, 1).to(proposals_per_image.bbox.device)
            proposal_idx_pairs = torch.cat(
                (idx_subj.view(-1, 1), idx_obj.view(-1, 1)), 1)

            non_duplicate_idx = (proposal_idx_pairs[:, 0] !=
                                 proposal_idx_pairs[:, 1]).nonzero()
            proposal_idx_pairs = proposal_idx_pairs[non_duplicate_idx.view(-1)]
            proposal_box_pairs = proposal_box_pairs[non_duplicate_idx.view(-1)]
            proposal_pairs_per_image = BoxPairList(proposal_box_pairs,
                                                   proposals_per_image.size,
                                                   proposals_per_image.mode)
            proposal_pairs_per_image.add_field("idx_pairs", proposal_idx_pairs)

            proposal_pairs.append(proposal_pairs_per_image)
        return proposal_pairs
Example #4
    def match_targets_to_proposals(self, proposal, target):
        match_quality_matrix = boxlist_iou(target, proposal)
        temp = []
        target_box_pairs = []
        # score every proposal pair against every ordered ground-truth pair:
        # the pair quality is the mean of the subject and object IoUs
        for i in range(match_quality_matrix.shape[0]):
            for j in range(match_quality_matrix.shape[0]):
                match_i = match_quality_matrix[i].view(1, -1)
                match_j = match_quality_matrix[j].view(-1, 1)
                match_ij = (match_i + match_j) / 2
                # zero out self-pairs (the same proposal as subject and object)
                match_ij.view(-1)[::match_quality_matrix.shape[1] + 1] = 0
                temp.append(match_ij)
                boxi = target.bbox[i]
                boxj = target.bbox[j]
                box_pair = torch.cat((boxi, boxj), 0)
                target_box_pairs.append(box_pair)

        match_pair_quality_matrix = torch.stack(temp, 0).view(len(temp), -1)
        target_box_pairs = torch.stack(target_box_pairs, 0)
        target_pair = BoxPairList(target_box_pairs, target.size, target.mode)
        # relation label for every ordered ground-truth pair, flattened in the
        # same (i, j) order as target_box_pairs above
        target_pair.add_field("labels",
                              target.get_field("pred_labels").view(-1))

        box_subj = proposal.bbox
        box_obj = proposal.bbox
        box_subj = box_subj.unsqueeze(1).repeat(1, box_subj.shape[0], 1)
        box_obj = box_obj.unsqueeze(0).repeat(box_obj.shape[0], 1, 1)
        proposal_box_pairs = torch.cat(
            (box_subj.view(-1, 4), box_obj.view(-1, 4)), 1)
        proposal_pairs = BoxPairList(proposal_box_pairs, proposal.size,
                                     proposal.mode)

        idx_subj = torch.arange(box_subj.shape[0]).view(-1, 1, 1).repeat(
            1, box_obj.shape[0], 1).to(proposal.bbox.device)
        idx_obj = torch.arange(box_obj.shape[0]).view(1, -1, 1).repeat(
            box_subj.shape[0], 1, 1).to(proposal.bbox.device)
        proposal_idx_pairs = torch.cat(
            (idx_subj.view(-1, 1), idx_obj.view(-1, 1)), 1)
        proposal_pairs.add_field("idx_pairs", proposal_idx_pairs)

        # matched_idxs = self.proposal_matcher(match_quality_matrix)
        matched_idxs = self.proposal_pair_matcher(match_pair_quality_matrix)

        # Fast RCNN only needs the "labels" field for selecting the targets
        # target = target.copy_with_fields("pred_labels")
        # get the targets corresponding GT for each proposal
        # NB: need to clamp the indices because we can have a single
        # GT in the image, and matched_idxs can be -2, which goes
        # out of bounds
        if self.use_matched_pairs_only and (
                matched_idxs >= 0).sum() > self.minimal_matched_pairs:
            # filter all matched_idxs < 0
            proposal_pairs = proposal_pairs[matched_idxs >= 0]
            matched_idxs = matched_idxs[matched_idxs >= 0]

        matched_targets = target_pair[matched_idxs.clamp(min=0)]
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets, proposal_pairs
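A short numeric sketch of the pair-quality term built in the double loop above, with an invented 2 x 3 IoU matrix (2 ground-truth boxes, 3 proposals); fill_diagonal_ is used here purely to make the self-pair zeroing explicit.

import torch

# Hypothetical IoU matrix: 2 ground-truth boxes x 3 proposals (values invented).
match_quality_matrix = torch.tensor([[0.8, 0.1, 0.0],
                                     [0.0, 0.2, 0.9]])

i, j = 0, 1                                    # one ordered target pair
match_i = match_quality_matrix[i].view(1, -1)  # 1 x 3
match_j = match_quality_matrix[j].view(-1, 1)  # 3 x 1
match_ij = (match_i + match_j) / 2             # 3 x 3 proposal-pair qualities
match_ij.fill_diagonal_(0)                     # exclude proposal self-pairs
print(match_ij)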
Example #5
    def prepare_boxpairlist(self, boxes, scores, image_shape):
        """
        Returns a BoxPairList built from `boxes` and adds the probability
        scores as an extra field.
        `boxes` has shape (#detections, 4 * #classes), where each row represents
        a list of predicted bounding boxes for each of the object classes in the
        dataset (including the background class). The detections in each row
        originate from the same object proposal.
        `scores` has shape (#detections, #classes), where each row represents a
        list of object detection confidence scores for each of the object classes
        in the dataset (including the background class). `scores[i, j]` corresponds
        to the box at `boxes[i, j * 4:(j + 1) * 4]`.
        """
        # flatten into one row of 8 coordinates (subject box, object box) per pair
        boxes = boxes.reshape(-1, 8)
        scores = scores.reshape(-1)
        boxlist = BoxPairList(boxes, image_shape, mode="xyxy")
        boxlist.add_field("scores", scores)
        return boxlist
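A hedged usage sketch of the reshaping this helper performs, with invented shapes; the BoxPairList construction itself is project-specific and omitted. Each pair contributes one row of 8 coordinates (subject box followed by object box, matching the pairing code in the other examples) and one score.

import torch

# Hypothetical inputs for 5 proposal pairs (shapes invented for illustration).
boxes = torch.rand(5, 8)
scores = torch.rand(5)

pairs = boxes.reshape(-1, 8)      # one row per pair: subject box then object box
pair_scores = scores.reshape(-1)  # one confidence score per pair
subj_boxes, obj_boxes = pairs[:, :4], pairs[:, 4:]
print(pairs.shape, pair_scores.shape, subj_boxes.shape)  # (5, 8) (5,) (5, 4)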
Example #6
    def _fullsample_test(self, proposals):
        """
        This method gets all subject-object pairs and returns the proposals.
        Note: this function keeps state.

        Arguments:
            proposals (list[BoxList])
        """
        proposal_pairs = []
        for i, proposals_per_image in enumerate(proposals):
            box_subj = proposals_per_image.bbox
            box_obj = proposals_per_image.bbox

            box_subj = box_subj.unsqueeze(1).repeat(1, box_subj.shape[0], 1)
            box_obj = box_obj.unsqueeze(0).repeat(box_obj.shape[0], 1, 1)
            proposal_box_pairs = torch.cat(
                (box_subj.view(-1, 4), box_obj.view(-1, 4)), 1)

            idx_subj = torch.arange(box_subj.shape[0]).view(-1, 1, 1).repeat(
                1, box_obj.shape[0], 1).to(proposals_per_image.bbox.device)
            idx_obj = torch.arange(box_obj.shape[0]).view(1, -1, 1).repeat(
                box_subj.shape[0], 1, 1).to(proposals_per_image.bbox.device)
            proposal_idx_pairs = torch.cat(
                (idx_subj.view(-1, 1), idx_obj.view(-1, 1)), 1)

            keep_idx = (proposal_idx_pairs[:, 0] !=
                        proposal_idx_pairs[:, 1]).nonzero().view(-1)

            # optionally keep only pairs whose boxes overlap
            if self.cfg.MODEL.ROI_RELATION_HEAD.FILTER_NON_OVERLAP:
                ious = boxlist_iou(proposals_per_image,
                                   proposals_per_image).view(-1)
                ious = ious[keep_idx]
                keep_idx = keep_idx[(ious > 0).nonzero().view(-1)]
            proposal_idx_pairs = proposal_idx_pairs[keep_idx]
            proposal_box_pairs = proposal_box_pairs[keep_idx]
            proposal_pairs_per_image = BoxPairList(proposal_box_pairs,
                                                   proposals_per_image.size,
                                                   proposals_per_image.mode)
            proposal_pairs_per_image.add_field("idx_pairs", proposal_idx_pairs)

            proposal_pairs.append(proposal_pairs_per_image)
        return proposal_pairs
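A small sketch of the FILTER_NON_OVERLAP branch above, using torchvision.ops.box_iou as a stand-in for boxlist_iou and invented boxes: after self-pairs are removed, only pairs whose subject and object boxes actually intersect are kept.

import torch
from torchvision.ops import box_iou  # stand-in for boxlist_iou in this sketch

# Hypothetical proposals: boxes 0 and 1 overlap, box 2 is far away.
boxes = torch.tensor([[0., 0., 10., 10.],
                      [5., 5., 15., 15.],
                      [50., 50., 60., 60.]])
n = boxes.shape[0]

idx_subj = torch.arange(n).view(-1, 1, 1).repeat(1, n, 1)
idx_obj = torch.arange(n).view(1, -1, 1).repeat(n, 1, 1)
idx_pairs = torch.cat((idx_subj.view(-1, 1), idx_obj.view(-1, 1)), 1)
keep_idx = (idx_pairs[:, 0] != idx_pairs[:, 1]).nonzero().view(-1)

ious = box_iou(boxes, boxes).view(-1)          # N*N pairwise IoUs, flattened
ious = ious[keep_idx]
keep_idx = keep_idx[(ious > 0).nonzero().view(-1)]
print(idx_pairs[keep_idx])                     # only (0, 1) and (1, 0) remain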