Exemplo n.º 1
0
    def topk_score(scores, K=40):
        r"""
        Select the K highest-scoring locations of a score map per batch item.

        The selection runs in two stages: first the top K entries of every
        channel's flattened H x W map, then the top K among the resulting
        ``channel * K`` candidates of each batch item.

        Args:
            scores (Tensor): score map of shape (batch, channel, height, width).
            K (int): number of locations to keep. ``torch.topk`` needs
                ``K <= height * width``.

        Returns:
            tuple: ``(topk_score, topk_inds, topk_clses, topk_ys, topk_xs)``,
            each of shape (batch, K):

            - ``topk_score``: the selected scores.
            - ``topk_inds``: flat index into the H x W map (``y * width + x``).
            - ``topk_clses``: channel each score came from (int32).
            - ``topk_ys`` / ``topk_xs``: row / column of each location (float).
        """
        batch, channel, height, width = scores.shape

        # Stage 1: top K scores and their flat spatial index, per channel.
        topk_scores, topk_inds = torch.topk(scores.reshape(batch, channel, -1),
                                            K)

        # topk ran over the flattened H x W axis, so the indices are already
        # below height * width; the modulo is kept as a cheap guard.
        topk_inds = topk_inds % (height * width)
        # Use integer floor division: `/` on integer tensors is true division
        # on recent PyTorch, which would route the index math through floats.
        topk_ys = (topk_inds // width).float()
        topk_xs = (topk_inds % width).float()

        # Stage 2: top K over all channel * K candidates of a batch item.
        topk_score, index = torch.topk(topk_scores.reshape(batch, -1), K)
        # Candidates are laid out channel-major (C x K), so the channel of a
        # candidate is its position divided by K.
        topk_clses = (index // K).int()

        # gather_feature is defined elsewhere; presumably it picks the rows
        # at `index` along dim 1 -- confirm against its definition.
        topk_inds = gather_feature(topk_inds.view(batch, -1, 1),
                                   index).reshape(batch, K)
        topk_ys = gather_feature(topk_ys.reshape(batch, -1, 1),
                                 index).reshape(batch, K)
        topk_xs = gather_feature(topk_xs.reshape(batch, -1, 1),
                                 index).reshape(batch, K)

        return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
Exemplo n.º 2
0
 def forward(self, output, mask, index, target):
     """
     Compute a masked L1 loss normalized by the sum of the mask.

     Predictions are picked from ``output`` by ``gather_feature`` at the
     positions in ``index``. Predictions and ``target`` are both multiplied
     by the mask before the summed absolute error is taken.

     Args:
         output (Tensor): network output handed to ``gather_feature``.
         mask (Tensor): weighting mask; a new axis is inserted at dim 2 and
             it is expanded to the shape of the gathered predictions
             (presumably (batch, num_objects) -- confirm against caller).
         index (Tensor): positions forwarded to ``gather_feature``.
         target (Tensor): regression targets, combined elementwise with the
             gathered predictions.

     Returns:
         Tensor: summed absolute error divided by the sum of the expanded
         mask plus ``1e-4``; the epsilon keeps the division finite when the
         mask is all zeros.
     """
     gathered = gather_feature(output, index, use_transform=True)
     weights = mask.unsqueeze(dim=2).expand_as(gathered).float()
     abs_error_sum = F.l1_loss(
         gathered * weights, target * weights, reduction='sum')
     return abs_error_sum / (weights.sum() + 1e-4)
Exemplo n.º 3
0
    def decode(fmap, wh, reg=None, cat_spec_wh=False, K=100):
        r"""
        Turn a score map, a size map and an optional offset map into boxes.

        Args:
            fmap (Tensor): 4-D score map (batch, channel, height, width).
            wh (Tensor): map holding the (width, height) predictions.
            reg (Tensor): optional map holding sub-pixel center offsets. When
                it is ``None`` every center is shifted by 0.5 instead.
            cat_spec_wh (bool): if True, ``wh`` is treated as carrying one
                (width, height) pair per channel of ``fmap`` and the pair of
                the detected class is selected.
            K (int): number of detections kept per batch item.

        Returns:
            tuple: ``(bboxes, scores, clses)`` where ``bboxes`` has shape
            (batch, K, 4) laid out as ``[x1, y1, x2, y2]``, and ``scores`` and
            ``clses`` (float) both have shape (batch, K, 1).
        """
        batch, channel, _, _ = fmap.shape

        # pseudo_nms and topk_score are defined elsewhere on CenterNetDecoder.
        peaks = CenterNetDecoder.pseudo_nms(fmap)
        scores, index, clses, ys, xs = CenterNetDecoder.topk_score(peaks, K=K)

        if reg is None:
            # No offset map: shift every center by a fixed half cell.
            xs = xs.view(batch, K, 1) + 0.5
            ys = ys.view(batch, K, 1) + 0.5
        else:
            offsets = gather_feature(reg, index, use_transform=True)
            offsets = offsets.reshape(batch, K, 2)
            xs = xs.view(batch, K, 1) + offsets[:, :, 0:1]
            ys = ys.view(batch, K, 1) + offsets[:, :, 1:2]

        sizes = gather_feature(wh, index, use_transform=True)
        if not cat_spec_wh:
            sizes = sizes.reshape(batch, K, 2)
        else:
            # One (w, h) pair per class: pick the pair of the detected class.
            sizes = sizes.view(batch, K, channel, 2)
            cls_idx = clses.view(batch, K, 1, 1).expand(batch, K, 1, 2).long()
            sizes = sizes.gather(2, cls_idx).reshape(batch, K, 2)

        clses = clses.reshape(batch, K, 1).float()
        scores = scores.reshape(batch, K, 1)

        # Boxes span half the predicted size on each side of the center.
        half = sizes / 2
        half_w = half[..., 0:1]
        half_h = half[..., 1:2]
        x1, y1 = xs - half_w, ys - half_h
        x2, y2 = xs + half_w, ys + half_h
        bboxes = torch.cat([x1, y1, x2, y2], dim=2)

        return (bboxes, scores, clses)