Esempio n. 1
0
    def get_triplets_as_string(self, top_obj: BoxList,
                               top_pred: BoxPairList) -> List[str]:
        """
        Render predicted relationships as "subject predicate object" strings.

        :param top_obj: BoxList of detected objects; its "labels" field gives
            the class index of each detection
        :param top_pred: BoxPairList of predicted relationships; its
            "idx_pairs" field gives (subject, object) row indices into
            ``top_obj`` and its "scores" field gives one score per predicate
            column
        :return: one string per row of ``top_pred``, in the same row order
        """
        labels = top_obj.get_field("labels")
        pairs = top_pred.get_field("idx_pairs")
        # Index of the highest-scoring predicate column for every pair.
        _, best_predicate = top_pred.get_field("scores").max(1)

        subject_labels = labels[pairs[:, 0]]
        object_labels = labels[pairs[:, 1]]
        # Rows of (subject class, object class, predicate) as Python values.
        triplets = torch.stack(
            (subject_labels, object_labels, best_predicate), dim=1).tolist()

        dataset = self.data_loader_test.dataset
        class_names = dataset.ind_to_classes
        predicate_names = dataset.ind_to_predicates

        return [
            " ".join((class_names[subj], predicate_names[pred],
                      class_names[obj]))
            for subj, obj, pred in triplets
        ]
Esempio n. 2
0
    def filter_results(self, boxlist, num_classes):
        """Return bounding-box detections by thresholding on scores and
        applying per-class non-maximum suppression (NMS).

        :param boxlist: BoxList whose ``bbox`` is reshaped here to
            (-1, num_classes * 4), i.e. one box per class per proposal, and
            which carries a "scores" field (reshaped to (-1, num_classes))
            and a "features" field indexed per proposal
        :param num_classes: number of classes including class 0, which is
            skipped as background
        :return: BoxList concatenated over classes 1..num_classes-1 with
            "scores", "features" and "labels" fields; when
            ``self.detections_per_img`` is positive and exceeded, only
            detections scoring at least the k-th highest score are kept
            (ties can keep slightly more than k)
        """
        # unwrap the boxlist to avoid additional overhead.
        # if we had multi-class NMS, we could perform this directly on the boxlist
        boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        scores = boxlist.get_field("scores").reshape(-1, num_classes)
        features = boxlist.get_field("features")

        device = scores.device
        result = []
        # Apply threshold on detection probabilities and apply NMS
        # Skip j = 0, because it's the background class
        inds_all = scores > self.score_thresh
        for j in range(1, num_classes):
            inds = inds_all[:, j].nonzero().squeeze(1)
            scores_j = scores[inds, j]
            features_j = features[inds]
            # columns j*4 .. j*4+3 hold the box regressed for class j
            boxes_j = boxes[inds, j * 4:(j + 1) * 4]
            boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
            boxlist_for_class.add_field("scores", scores_j)
            boxlist_for_class.add_field("features", features_j)
            boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms)
            num_labels = len(boxlist_for_class)
            boxlist_for_class.add_field(
                "labels",
                torch.full((num_labels, ), j, dtype=torch.int64,
                           device=device))
            result.append(boxlist_for_class)

        result = cat_boxlist(result)
        number_of_detections = len(result)

        # Limit to max_per_image detections **over all classes**
        if number_of_detections > self.detections_per_img > 0:
            cls_scores = result.get_field("scores")
            # score of the detections_per_img-th best detection
            image_thresh, _ = torch.kthvalue(
                cls_scores.cpu(),
                number_of_detections - self.detections_per_img + 1)
            keep = cls_scores >= image_thresh.item()
            keep = torch.nonzero(keep).squeeze(1)
            result = result[keep]
        return result
    def filter_results_nm(self, boxlist, num_classes, thresh=0.05,
                          nms_thresh=0.3):
        """Return bounding-box detections using per-class NMS followed by a
        single label per proposal. Similar to Neural-Motif Network.

        For every foreground class whose best score over all proposals
        exceeds ``thresh``, NMS with IoU threshold ``nms_thresh`` is run over
        all proposals. Each proposal is then labelled with its highest-scoring
        class among the classes it survived NMS for. Proposals that survived
        for no class are dropped. The rest are sorted by decreasing score, and
        those scoring ``<= thresh`` are removed. At most
        ``self.detections_per_img`` are kept when that value is positive.

        :param boxlist: BoxList whose ``bbox`` is reshaped here to
            (-1, num_classes * 4), i.e. one box per class per proposal, and
            which carries "scores" and "logits" fields (both reshaped to
            (-1, num_classes)) and a "features" field indexed per proposal
        :param num_classes: number of classes including class 0, which is
            skipped as background
        :param thresh: minimum score for a class to be considered and for a
            final detection to be kept
        :param nms_thresh: IoU threshold passed to ``boxlist_nms``
        :return: BoxList with "labels", "scores", "logits" and "features"
            fields, ordered by decreasing score
        """
        # unwrap the boxlist to avoid additional overhead.
        boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        scores = boxlist.get_field("scores").reshape(-1, num_classes)
        logits = boxlist.get_field("logits").reshape(-1, num_classes)
        features = boxlist.get_field("features")
        device = scores.device
        num_boxes = scores.shape[0]

        # nms_mask[i, j] == 1 iff proposal i survives NMS for class j.
        nms_mask = torch.zeros_like(scores)

        # max over dim 0 fails on an empty tensor, so skip NMS entirely when
        # there are no proposals; the code below then yields an empty result.
        if num_boxes > 0:
            # Foreground classes (column 0 skipped as background) whose best
            # score exceeds thresh; +1 undoes the offset of the [:, 1:] slice.
            valid_cls = (scores[:, 1:].max(0)[0] > thresh).nonzero() + 1
            # Original row indices, carried through NMS so survivors can be
            # mapped back; kept on the same device as the boxes/scores.
            row_idxs = torch.arange(num_boxes, dtype=torch.int64,
                                    device=device)
            for j in valid_cls.view(-1).tolist():
                boxlist_for_class = BoxList(boxes[:, j * 4:(j + 1) * 4],
                                            boxlist.size, mode="xyxy")
                boxlist_for_class.add_field("scores", scores[:, j])
                boxlist_for_class.add_field("idxs", row_idxs)
                boxlist_for_class = boxlist_nms(boxlist_for_class, nms_thresh)
                nms_mask[boxlist_for_class.get_field("idxs"), j] = 1

        # Filter duplicate boxes: keep one class per proposal, the
        # best-scoring class among those it survived NMS for.
        dists_all = nms_mask * scores
        scores_pre, labels_pre = dists_all.max(1)
        inds_all = scores_pre.nonzero().squeeze(1)

        labels_all = labels_pre[inds_all]
        scores_all = scores_pre[inds_all]

        # With bbox viewed as (num_proposals * num_classes, 4), row
        # i * num_classes + c is the box for class c of proposal i.
        box_inds_all = inds_all * scores.shape[1] + labels_all
        result = BoxList(boxlist.bbox.reshape(-1, 4)[box_inds_all],
                         boxlist.size,
                         mode="xyxy")
        result.add_field("labels", labels_all)
        result.add_field("scores", scores_all)
        result.add_field("logits", logits[inds_all])
        result.add_field("features", features[inds_all])

        # Sort by decreasing score, drop low scores, then cap the count.
        # A non-positive detections_per_img means "no cap", as in
        # filter_results.
        sorted_scores, order = torch.sort(scores_all, dim=0, descending=True)
        order = order[sorted_scores > thresh]
        if 0 < self.detections_per_img < order.size(0):
            order = order[:self.detections_per_img]
        return result[order]