Esempio n. 1
0
    def __call__(self, match_quality_matrix, batched=0):
        """
        Match every predicted element to a ground-truth element.

        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted
                elements. When ``batched`` is truthy, a BxMxN tensor holding one
                such matrix per batch entry.
            batched (bool or int): whether ``match_quality_matrix`` has a
                leading batch dimension. Defaults to 0 (unbatched).

        Returns:
            matches (Tensor[int64]): an N tensor (BxN when batched) where
            matches[i] is a matched gt in [0, M - 1] or a negative value
            indicating that prediction i could not be matched.

        Raises:
            ValueError: if the matrix is empty because there are no
                ground-truth boxes or no proposal boxes.
        """
        if match_quality_matrix.numel() == 0:
            # empty targets or proposals not supported during training.
            # For batched input dim 0 is the batch axis, so the gt axis is dim 1.
            gt_dim = 1 if batched else 0
            if match_quality_matrix.shape[gt_dim] == 0:
                raise ValueError(
                    "No ground-truth boxes available for one of the images "
                    "during training")
            raise ValueError(
                "No proposal boxes available for one of the images "
                "during training")

        if match_quality_matrix.is_cuda:
            if batched:
                return _C.match_proposals(
                    match_quality_matrix, self.allow_low_quality_matches,
                    self.low_threshold, self.high_threshold)
            # The compiled op appears to expect a leading batch dimension, so an
            # unbatched MxN matrix is lifted to 1xMxN and the result squeezed back.
            return _C.match_proposals(
                match_quality_matrix.unsqueeze(0), self.allow_low_quality_matches,
                self.low_threshold, self.high_threshold).squeeze(0)

        if batched:
            # The CPU matcher handles a single MxN matrix: a max over dim 0 of a
            # BxMxN tensor would reduce over the batch axis instead of the
            # ground-truth axis, so match each batch entry on its own.
            import torch
            return torch.stack(
                [self._match_single_cpu(per_image) for per_image in match_quality_matrix])
        return self._match_single_cpu(match_quality_matrix)

    def _match_single_cpu(self, match_quality_matrix):
        """Match one MxN quality matrix on the CPU and return an N tensor of gt indices."""
        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        matched_vals, matches = match_quality_matrix.max(dim=0)
        # Keep the unthresholded argmax so low-quality matches can be restored below.
        all_matches = matches.clone() if self.allow_low_quality_matches else None

        # Assign candidate matches with low quality to negative (unassigned) values
        below_low_threshold = matched_vals < self.low_threshold
        between_thresholds = (matched_vals >= self.low_threshold) & (
            matched_vals < self.high_threshold
        )
        matches.masked_fill_(below_low_threshold, Matcher.BELOW_LOW_THRESHOLD)
        matches.masked_fill_(between_thresholds, Matcher.BETWEEN_THRESHOLDS)

        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
        return matches
Esempio n. 2
0
    def __call__(self, match_quality_matrix):
        """
        Assign each predicted element to a ground-truth element using the
        compiled ``_C.match_proposals`` op.

        Args:
            match_quality_matrix (Tensor[float]): pairwise quality scores laid
                out as M (ground truth) x N (predicted).

        Returns:
            Tensor: the output of ``_C.match_proposals``. For each prediction it
            holds either a gt index in [0, M - 1] or a negative value meaning
            the prediction could not be matched.
            NOTE(review): the dtype is whatever the kernel returns; an earlier
            docstring said float, but the values are indices — confirm.
        """
        keep_low_quality = self.allow_low_quality_matches
        low = self.low_threshold
        high = self.high_threshold
        return _C.match_proposals(match_quality_matrix, keep_low_quality, low, high)