Example #1
    def forward(self, *x):
        x = enforce_singleton(x)  # unwrap the single input tensor from *x
        confidences = []
        locations = []

        # Multi-scale prediction: each backbone stage feeds its own
        # classification/regression header.
        x = self.backbond1(x)
        confidence0, location0 = self.compute_header(0, x)
        confidences.append(confidence0)
        locations.append(location0)

        x = self.backbond2(x)
        confidence1, location1 = self.compute_header(1, x)
        confidences.append(confidence1)
        locations.append(location1)

        x = self.backbond3(x)
        confidence2, location2 = self.compute_header(2, x)
        confidences.append(confidence2)
        locations.append(location2)

        # Extra feature layer appended after the backbone stages.
        x = self.extra(x)
        confidence3, location3 = self.compute_header(3, x)
        confidences.append(confidence3)
        locations.append(location3)

        # Concatenate predictions from all scales along the prior dimension.
        confidences = torch.cat(confidences, 1)
        locations = torch.cat(locations, 1)

        if self.training:
            # Training: return raw scores and encoded offsets for the loss.
            return confidences, locations
        else:
            # Inference: turn scores into probabilities and decode the offsets
            # back into box coordinates using the prior boxes.
            confidences = softmax(confidences, -1)
            locations = decode(locations, self.priors, self.variance)
            return confidences, locations
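
The pattern above is the multi-scale head of an SSD-style detector: each backbone stage feeds its own classification/regression header, the per-stage predictions are concatenated over the prior dimension, and in eval mode the scores are softmaxed and the offsets decoded against the prior boxes. Below is a minimal, self-contained sketch of that pattern using toy modules; every name in it (TinyMultiScaleHead, the stage and head layers) is illustrative rather than part of the original model, and the prior-box decode step is omitted because it needs the model's priors.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMultiScaleHead(nn.Module):
    def __init__(self, num_classes=21):
        super().__init__()
        self.num_classes = num_classes
        # Two toy "backbone" stages standing in for backbond1/backbond2/... above.
        self.stages = nn.ModuleList([
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
        ])
        # One (cls, loc) head pair per stage, 4 anchors per spatial location.
        self.cls_heads = nn.ModuleList([
            nn.Conv2d(16, 4 * num_classes, 3, padding=1),
            nn.Conv2d(32, 4 * num_classes, 3, padding=1),
        ])
        self.loc_heads = nn.ModuleList([
            nn.Conv2d(16, 4 * 4, 3, padding=1),
            nn.Conv2d(32, 4 * 4, 3, padding=1),
        ])

    def compute_header(self, i, x):
        # Flatten the spatial grid and anchors into a single "priors" axis.
        b = x.size(0)
        conf = self.cls_heads[i](x).permute(0, 2, 3, 1).reshape(b, -1, self.num_classes)
        loc = self.loc_heads[i](x).permute(0, 2, 3, 1).reshape(b, -1, 4)
        return conf, loc

    def forward(self, x):
        confidences, locations = [], []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            conf, loc = self.compute_header(i, x)
            confidences.append(conf)
            locations.append(loc)
        confidences = torch.cat(confidences, 1)
        locations = torch.cat(locations, 1)
        if self.training:
            return confidences, locations
        # Eval mode: probabilities only (decoding to boxes would need priors).
        return F.softmax(confidences, -1), locations


model = TinyMultiScaleHead()
conf, loc = model(torch.randn(2, 3, 64, 64))
print(conf.shape, loc.shape)  # torch.Size([2, 5120, 21]) torch.Size([2, 5120, 4])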
Example #2
    def forward(self, confidence, locations, target_confidence,
                target_locations):
        """Compute the GIoU localization loss between decoded predictions and targets.

        Args:
            confidence (batch_size, num_priors, num_classes): class predictions.
            locations (batch_size, num_priors, 4): predicted locations.
            target_confidence (batch_size, num_priors): real labels of all the priors.
            target_locations (batch_size, num_priors, 4): real boxes corresponding to all the priors.
        """
        num_classes = confidence.size(2)
        num_batch = confidence.size(0)

        # A predicted prior counts as a positive only when its top class is a
        # foreground class (index > 0) with probability above 0.5.
        confidence_logit = softmax(confidence, -1)
        confidence_logit_probs, confidence_logit_idxs = confidence_logit.max(-1)
        probs_mask = confidence_logit_probs > 0.5
        label_mask = confidence_logit_idxs > 0

        pos_target_mask_all = target_confidence > 0
        # Keep priors that are positive in the target AND pass both checks above.
        pos_infer_mask_all = (pos_target_mask_all.float() + probs_mask.float() +
                              label_mask.float() == 3)

        # Decode predictions and targets from prior-relative offsets to boxes.
        decode_locations_all = decode(
            locations, self.priors, (self.center_variance, self.size_variance))
        decode_target_locations_all = decode(
            target_locations, self.priors,
            (self.center_variance, self.size_variance))

        giou = 0.0
        overlaps = 0.0
        num_boxes = 0
        for i in range(num_batch):
            pos_target_mask = pos_target_mask_all[i]
            pos_infer_mask = pos_infer_mask_all[i]
            decode_locations = decode_locations_all[i][pos_infer_mask, :]
            decode_target_locations = decode_target_locations_all[i][pos_target_mask, :]
            num_boxes += decode_target_locations.shape[0]
            if decode_target_locations.shape[0] > 0 and decode_locations.shape[0] > 0:
                # Average (1 - GIoU) over this image's ground-truth boxes.
                giou = giou + (1 - (bbox_giou(decode_locations, decode_target_locations).sum(0) /
                                    decode_target_locations.shape[0])).sum()
                # Average negative log IoU, clipped to avoid log(0).
                overlaps = overlaps + (-log(clip(jaccard(decode_locations, decode_target_locations),
                                                 min=1e-8)).sum(0) /
                                       decode_target_locations.shape[0]).sum()
            elif decode_target_locations.shape[0] == 0 and decode_locations.shape[0] == 0:
                # No targets and no positive predictions: nothing to penalize.
                pass
            else:
                # Missed targets or false positives: apply the maximum penalty.
                giou = giou + 1
                overlaps = overlaps - log(to_tensor(1e-8))

        giou = giou / num_boxes
        overlaps = overlaps / num_boxes
        # Only the GIoU term is returned; `overlaps` is computed but unused here.
        return giou
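
The loop above accumulates two penalties per image: an averaged (1 - GIoU) term and an averaged negative log IoU term, both taken over an overlap matrix between decoded positive predictions and ground-truth boxes. The sketch below reproduces just the GIoU term with a small illustrative helper, pairwise_giou, for corner-format (x1, y1, x2, y2) boxes; it is not the library's bbox_giou, and it assumes bbox_giou likewise returns a (num_pred, num_target) matrix.

import torch

def pairwise_giou(pred, target):
    # Illustrative helper, not the library's bbox_giou.
    # pred: (N, 4), target: (M, 4) in corner format -> GIoU matrix of shape (N, M).
    area_p = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
    area_t = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
    lt = torch.max(pred[:, None, :2], target[None, :, :2])  # intersection top-left
    rb = torch.min(pred[:, None, 2:], target[None, :, 2:])  # intersection bottom-right
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    union = area_p[:, None] + area_t[None, :] - inter
    iou = inter / union.clamp(min=1e-8)
    # Smallest enclosing box of each pred/target pair.
    lt_c = torch.min(pred[:, None, :2], target[None, :, :2])
    rb_c = torch.max(pred[:, None, 2:], target[None, :, 2:])
    wh_c = (rb_c - lt_c).clamp(min=0)
    enclose = (wh_c[..., 0] * wh_c[..., 1]).clamp(min=1e-8)
    return iou - (enclose - union) / enclose

pred = torch.tensor([[0.0, 0.0, 2.0, 2.0], [1.0, 1.0, 3.0, 3.0]])
target = torch.tensor([[0.0, 0.0, 2.0, 2.0]])

giou = pairwise_giou(pred, target)                # (2, 1) matrix
loss = (1 - giou.sum(0) / target.shape[0]).sum()  # same reduction as the loop above
print(giou, loss)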