def topk_score(scores, K=40):
    r"""
    Select the K highest-scoring points of a score map for every image.

    The selection runs in two stages: the top K locations are first taken
    inside each channel, then the top K of those ``channel * K`` candidates
    are taken per image.

    Args:
        scores (Tensor): score map of shape (batch, channel, height, width).
        K (int): number of points to keep per image. Must not exceed
            ``height * width``, otherwise ``torch.topk`` raises.

    Returns:
        tuple: five tensors, each of shape (batch, K):

        - topk_score: the selected scores.
        - topk_inds: flattened spatial index (``y * width + x``) of each point.
        - topk_clses: channel index of each point (int32).
        - topk_ys: row of each point, as float.
        - topk_xs: column of each point, as float.
    """
    batch, channel, _, width = scores.shape

    # Top K inside every channel over the flattened H x W plane. The indices
    # returned here are already in [0, height * width), so no wrap is needed.
    topk_scores, topk_inds = torch.topk(scores.reshape(batch, channel, -1), K)

    # Integer floor division keeps the row exact. True division would go
    # through float32 and can be off by one for indices beyond 2**24.
    topk_ys = (topk_inds // width).float()
    topk_xs = (topk_inds % width).float()

    # Top K over all channels of an image, chosen from the C x K candidates.
    topk_score, index = torch.topk(topk_scores.reshape(batch, -1), K)

    # Candidates are laid out channel by channel in groups of K, so the
    # channel a candidate came from is its position divided by K.
    topk_clses = (index // K).int()

    # NOTE(review): gather_feature is defined outside this block; it is
    # presumably a gather along dim 1 by `index` - confirm against its source.
    topk_inds = gather_feature(topk_inds.view(batch, -1, 1), index).reshape(batch, K)
    topk_ys = gather_feature(topk_ys.reshape(batch, -1, 1), index).reshape(batch, K)
    topk_xs = gather_feature(topk_xs.reshape(batch, -1, 1), index).reshape(batch, K)

    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
def forward(self, output, mask, index, target):
    """
    Compute a masked L1 loss between gathered predictions and targets.

    The predictions are picked out of ``output`` with ``gather_feature`` at
    ``index``. The absolute errors are summed over the masked entries and
    divided by the number of masked elements (plus 1e-4 so that an all-zero
    mask does not divide by zero).

    Args:
        output (Tensor): raw prediction map passed to ``gather_feature``.
        mask (Tensor): validity mask; it is unsqueezed at dim 2 and expanded
            to the shape of the gathered predictions.
        index (Tensor): indices passed to ``gather_feature``.
        target (Tensor): regression targets, same shape as the gathered
            predictions.

    Returns:
        Tensor: scalar loss.
    """
    # NOTE(review): gather_feature is defined outside this block; with
    # use_transform=True it presumably flattens the spatial dims before
    # gathering - confirm against its source.
    gathered = gather_feature(output, index, use_transform=True)
    weights = mask.unsqueeze(dim=2).expand_as(gathered).float()

    abs_error_sum = F.l1_loss(gathered * weights, target * weights, reduction='sum')
    normalizer = weights.sum() + 1e-4
    return abs_error_sum / normalizer
def decode(fmap, wh, reg=None, cat_spec_wh=False, K=100):
    r"""
    Turn the network's output maps into the top K box detections per image.

    Args:
        fmap (Tensor): score map of shape (batch, channel, height, width).
        wh (Tensor): box size map, gathered at the selected locations.
        reg (Tensor): optional centre offset map. When given, the gathered
            (x, y) offsets are added to the integer peak coordinates; when
            ``None``, 0.5 is added to both coordinates instead.
        cat_spec_wh (bool): if True, the gathered sizes are viewed as
            (batch, K, channel, 2) and the (w, h) pair belonging to each
            detection's class is selected; otherwise they are used as a
            single (w, h) pair per location.
        K (int): number of detections to keep per image.

    Returns:
        tuple: ``(bboxes, scores, clses)`` where

        - bboxes has shape (batch, K, 4), laid out as
          ``[x - w/2, y - h/2, x + w/2, y + h/2]``;
        - scores has shape (batch, K, 1);
        - clses has shape (batch, K, 1), class indices as float.
    """
    batch, channel, _, _ = fmap.shape

    # NOTE(review): pseudo_nms is defined outside this block; presumably it
    # keeps only local peaks of the score map - confirm against its source.
    peaks = CenterNetDecoder.pseudo_nms(fmap)
    scores, index, clses, ys, xs = CenterNetDecoder.topk_score(peaks, K=K)

    if reg is None:
        # No offset prediction: shift by half a cell.
        center_x = xs.view(batch, K, 1) + 0.5
        center_y = ys.view(batch, K, 1) + 0.5
    else:
        offsets = gather_feature(reg, index, use_transform=True)
        offsets = offsets.reshape(batch, K, 2)
        center_x = xs.view(batch, K, 1) + offsets[:, :, 0:1]
        center_y = ys.view(batch, K, 1) + offsets[:, :, 1:2]

    sizes = gather_feature(wh, index, use_transform=True)
    if cat_spec_wh:
        # One (w, h) pair per class: pick the pair of the predicted class.
        sizes = sizes.view(batch, K, channel, 2)
        class_picker = clses.view(batch, K, 1, 1).expand(batch, K, 1, 2).long()
        sizes = sizes.gather(2, class_picker).reshape(batch, K, 2)
    else:
        sizes = sizes.reshape(batch, K, 2)

    class_ids = clses.reshape(batch, K, 1).float()
    peak_scores = scores.reshape(batch, K, 1)

    half_w = sizes[..., 0:1] / 2
    half_h = sizes[..., 1:2] / 2
    corners = [
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    ]
    bboxes = torch.cat(corners, dim=2)

    return bboxes, peak_scores, class_ids