def target_single_image(self, gt_boxes, gt_labels, feat_shape):
        """Build the heatmap, box and regression-weight targets of one image.

        Boxes are processed from the largest (processed) area to the smallest,
        so wherever targets overlap the smaller box is written last and wins.

        Args:
            gt_boxes: tensor, tensor <=> img, (num_gt, 4). Columns are read as
                (x1, y1, x2, y2); dividing by ``self.down_ratio`` maps them
                onto the feature map.
            gt_labels: tensor, tensor <=> img, (num_gt,). ``gt_labels - 1`` is
                used as the heatmap channel, so labels are presumably 1-based
                foreground ids -- confirm against the caller.
            feat_shape: tuple, (output_h, output_w) of the feature map.

        Returns:
            heatmap: tensor, tensor <=> img, (self.num_fg, h, w).
            box_target: tensor, tensor <=> img, (self.wh_planes, h, w). Holds
                the un-scaled gt box at every positive location, -1 elsewhere.
            reg_weight: tensor, tensor <=> img, (self.wh_planes // 4, h, w).
                For each box the weights written to its positive locations
                sum up to the box's processed area.
        """
        output_h, output_w = feat_shape

        heatmap = gt_boxes.new_zeros((self.num_fg, output_h, output_w))
        # Scratch buffer, cleared and redrawn for every box.
        fake_heatmap = gt_boxes.new_zeros((output_h, output_w))
        # -1 marks the locations that carry no box target.
        box_target = gt_boxes.new_ones(
            (self.wh_planes, output_h, output_w)) * -1
        reg_weight = gt_boxes.new_zeros(
            (self.wh_planes // 4, output_h, output_w))

        if self.wh_area_process == 'log':
            boxes_areas_log = bbox_areas(gt_boxes).log()
        elif self.wh_area_process == 'sqrt':
            boxes_areas_log = bbox_areas(gt_boxes).sqrt()
        else:
            boxes_areas_log = bbox_areas(gt_boxes)
        # topk over all elements == sort by area, largest first.
        boxes_area_topk_log, boxes_ind = torch.topk(boxes_areas_log,
                                                    boxes_areas_log.size(0))

        if self.wh_area_process == 'norm':
            # Every box gets the same total regression weight.
            boxes_area_topk_log[:] = 1.

        gt_boxes = gt_boxes[boxes_ind]
        gt_labels = gt_labels[boxes_ind]

        feat_gt_boxes = gt_boxes / self.down_ratio
        feat_gt_boxes[:, [0, 2]] = torch.clamp(feat_gt_boxes[:, [0, 2]],
                                               min=0,
                                               max=output_w - 1)
        feat_gt_boxes[:, [1, 3]] = torch.clamp(feat_gt_boxes[:, [1, 3]],
                                               min=0,
                                               max=output_h - 1)
        feat_hs, feat_ws = (feat_gt_boxes[:, 3] - feat_gt_boxes[:, 1],
                            feat_gt_boxes[:, 2] - feat_gt_boxes[:, 0])

        # we calc the center and ignore area based on the gt-boxes of the origin scale
        # no peak will fall between pixels
        ct_ints = (torch.stack([(gt_boxes[:, 0] + gt_boxes[:, 2]) / 2,
                                (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2],
                               dim=1) / self.down_ratio).to(torch.int)

        # The box-target gaussian only has to be redrawn when its radii
        # (beta) differ from the heatmap radii (alpha).
        redraw_for_wh = self.wh_gaussian and self.alpha != self.beta

        h_radiuses_alpha = (feat_hs / 2. * self.alpha).int()
        w_radiuses_alpha = (feat_ws / 2. * self.alpha).int()
        if redraw_for_wh:
            h_radiuses_beta = (feat_hs / 2. * self.beta).int()
            w_radiuses_beta = (feat_ws / 2. * self.beta).int()

        if not self.wh_gaussian:
            # calculate positive (center) regions
            r1 = (1 - self.beta) / 2
            ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = calc_region(
                gt_boxes.transpose(0, 1), r1)
            ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = [
                torch.round(x.float() / self.down_ratio).int()
                for x in [ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s]
            ]
            # Clamp on both sides: a negative coordinate would otherwise be
            # interpreted as a from-the-end index by the slicing below.
            ctr_x1s, ctr_x2s = [
                torch.clamp(x, min=0, max=output_w - 1)
                for x in [ctr_x1s, ctr_x2s]
            ]
            ctr_y1s, ctr_y2s = [
                torch.clamp(y, min=0, max=output_h - 1)
                for y in [ctr_y1s, ctr_y2s]
            ]

        # larger boxes have lower priority than small boxes.
        for k in range(boxes_ind.shape[0]):
            cls_id = gt_labels[k] - 1

            fake_heatmap.zero_()
            self.draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                        h_radiuses_alpha[k].item(),
                                        w_radiuses_alpha[k].item())
            heatmap[cls_id] = torch.max(heatmap[cls_id], fake_heatmap)

            if self.wh_gaussian:
                if redraw_for_wh:
                    fake_heatmap.zero_()
                    self.draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                                h_radiuses_beta[k].item(),
                                                w_radiuses_beta[k].item())
                box_target_inds = fake_heatmap > 0
            else:
                # Build the mask through a comparison so that it has the same
                # dtype as in the gaussian branch instead of a uint8 mask,
                # whose use for indexing is deprecated in PyTorch.
                center_region = torch.zeros_like(fake_heatmap)
                center_region[ctr_y1s[k]:ctr_y2s[k] + 1,
                              ctr_x1s[k]:ctr_x2s[k] + 1] = 1
                box_target_inds = center_region > 0

            if self.wh_agnostic:
                box_target[:, box_target_inds] = gt_boxes[k][:, None]
                # Class-agnostic regression uses a single weight plane.
                cls_id = 0
            else:
                box_target[(cls_id * 4):((cls_id + 1) * 4),
                           box_target_inds] = gt_boxes[k][:, None]

            if self.wh_gaussian:
                # Gaussian-shaped weights, rescaled to sum to the box area.
                local_heatmap = fake_heatmap[box_target_inds]
                ct_div = local_heatmap.sum()
                local_heatmap *= boxes_area_topk_log[k]
                reg_weight[cls_id, box_target_inds] = local_heatmap / ct_div
            else:
                # Uniform weights that sum to the box area.
                reg_weight[cls_id, box_target_inds] = \
                    boxes_area_topk_log[k] / box_target_inds.sum().float()

        return heatmap, box_target, reg_weight
# Example #2
0
    def target_single_image(self, gt_boxes, gt_labels, feat_shape):
        """Build the heatmap, box, centerness and weight targets of one image.

        Boxes are processed from the largest (processed) area to the smallest,
        so wherever targets overlap the smaller box is written last and wins.

        Args:
            gt_boxes: tensor, tensor <=> img, (num_gt, 4). Columns are read as
                (x1, y1, x2, y2); dividing by ``self.down_ratio`` maps them
                onto the feature map.
            gt_labels: tensor, tensor <=> img, (num_gt,). ``gt_labels - 1`` is
                used as the heatmap channel, so labels are presumably 1-based
                foreground ids -- confirm against the caller.
            feat_shape: tuple, (output_h, output_w) of the feature map.

        Returns:
            heatmap: tensor, tensor <=> img, (self.num_fg, h, w). With
                ``self.fovea_hm`` it holds 1 inside the center regions and -1
                on still-empty pixels of the ignore regions; otherwise the
                element-wise max of the per-box gaussians.
            box_target: tensor, tensor <=> img, (self.wh_planes, h, w). Holds
                the un-scaled gt box at every positive location, -1 elsewhere.
            centerness: tensor, (1, h, w). Only written when
                ``self.fovea_hm`` is set.
            wh_weight: tensor, (self.wh_planes // 4, h, w).
            hm_weight: tensor, (self.wh_planes // 4, h, w). Only written when
                ``self.ct_version`` is falsy.
        """
        output_h, output_w = feat_shape
        heatmap_channel = self.num_fg

        heatmap = gt_boxes.new_zeros((heatmap_channel, output_h, output_w))
        # Scratch buffer, cleared and redrawn for every box.
        fake_heatmap = gt_boxes.new_zeros((output_h, output_w))
        # -1 marks the locations that carry no box target.
        box_target = gt_boxes.new_ones(
            (self.wh_planes, output_h, output_w)) * -1
        wh_weight = gt_boxes.new_zeros(
            (self.wh_planes // 4, output_h, output_w))
        hm_weight = gt_boxes.new_zeros(
            (self.wh_planes // 4, output_h, output_w))
        centerness = gt_boxes.new_zeros((1, output_h, output_w))

        if self.wh_area_process == 'log':
            boxes_areas_log = bbox_areas(gt_boxes).log()
        elif self.wh_area_process == 'sqrt':
            boxes_areas_log = bbox_areas(gt_boxes).sqrt()
        else:
            boxes_areas_log = bbox_areas(gt_boxes)
        # topk over all elements == sort by area, largest first.
        boxes_area_topk_log, boxes_ind = torch.topk(boxes_areas_log,
                                                    boxes_areas_log.size(0))

        if self.wh_area_process == 'norm':
            # Every box gets the same total regression weight.
            boxes_area_topk_log[:] = 1.

        gt_boxes = gt_boxes[boxes_ind]
        gt_labels = gt_labels[boxes_ind]

        feat_gt_boxes = gt_boxes / self.down_ratio
        feat_gt_boxes[:, [0, 2]] = torch.clamp(feat_gt_boxes[:, [0, 2]],
                                               min=0,
                                               max=output_w - 1)
        feat_gt_boxes[:, [1, 3]] = torch.clamp(feat_gt_boxes[:, [1, 3]],
                                               min=0,
                                               max=output_h - 1)
        feat_hs, feat_ws = (feat_gt_boxes[:, 3] - feat_gt_boxes[:, 1],
                            feat_gt_boxes[:, 2] - feat_gt_boxes[:, 0])

        r1 = (1 - self.center_ratio) / 2
        r2 = (1 - self.ignore_ratio) / 2

        # we calc the center and ignore area based on the gt-boxes of the origin scale
        # no peak will fall between pixels
        ct_ints = (torch.stack([(gt_boxes[:, 0] + gt_boxes[:, 2]) / 2,
                                (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2],
                               dim=1) / self.down_ratio).to(torch.int)

        # The box-target gaussian is redrawn whenever its radii
        # (center_ratio / 2) differ from the heatmap ones. This is also the
        # case when hm_center_ratio is None, so the radii below must be
        # computed for that configuration too (they used to be defined only
        # when hm_center_ratio was set, which raised a NameError in the loop).
        redraw_wh_heatmap = (self.wh_heatmap and
                             self.hm_center_ratio != self.center_ratio / 2)

        if self.hm_center_ratio is None:
            radiuses = torch.clamp(gaussian_radius(
                (feat_hs.ceil(), feat_ws.ceil())),
                                   min=0)
            # Split the isotropic radius into h/w radii that follow the box
            # aspect ratio while keeping their product.
            hw_ratio_sqrt = (feat_hs / feat_ws).sqrt()
            h_radiuses = (radiuses * hw_ratio_sqrt).int()
            w_radiuses = (radiuses / hw_ratio_sqrt).int()
            if self.ct_gaussian:
                radiuses = radiuses.int()
        else:
            # NOTE(review): `radiuses` is not defined on this path, so
            # combining ct_gaussian with a non-None hm_center_ratio fails in
            # the loop below -- confirm whether that combination is expected.
            h_radiuses = (feat_hs * self.hm_center_ratio).int()
            w_radiuses = (feat_ws * self.hm_center_ratio).int()

        if redraw_wh_heatmap:
            wh_h_radiuses = (feat_hs * self.center_ratio / 2).int()
            wh_w_radiuses = (feat_ws * self.center_ratio / 2).int()

        # calculate positive (center) regions
        ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = calc_region(
            gt_boxes.transpose(0, 1), r1, use_round=False)
        ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = [
            torch.round(x / self.down_ratio).int()
            for x in [ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s]
        ]
        # Clamp on both sides: a negative coordinate would otherwise be
        # interpreted as a from-the-end index by the slicing below.
        ctr_x1s, ctr_x2s = [
            torch.clamp(x, min=0, max=output_w - 1)
            for x in [ctr_x1s, ctr_x2s]
        ]
        ctr_y1s, ctr_y2s = [
            torch.clamp(y, min=0, max=output_h - 1)
            for y in [ctr_y1s, ctr_y2s]
        ]
        ctr_xs_diff, ctr_ys_diff = ctr_x2s - ctr_x1s + 1, ctr_y2s - ctr_y1s + 1

        if self.fill_small:
            are_fill_small = (ctr_ys_diff <= 4) & (ctr_xs_diff <= 4)

        # Number of box-target entries that overwrite an earlier box.
        collide_pixels_summary = 0
        # larger boxes have lower priority than small boxes.
        for k in range(boxes_ind.shape[0]):
            cls_id = gt_labels[k] - 1
            ctr_x1, ctr_y1, ctr_x2, ctr_y2 = ctr_x1s[k], ctr_y1s[k], ctr_x2s[
                k], ctr_y2s[k]
            ctr_x_diff, ctr_y_diff = ctr_xs_diff[k], ctr_ys_diff[k]

            if self.fovea_hm or (self.fill_small and are_fill_small[k]):
                ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                    feat_gt_boxes[k], r2, (output_h, output_w))

                if not self.fovea_hm:
                    # Small box: use the ignore region as its center region.
                    ctr_x1, ctr_y1, ctr_x2, ctr_y2 = \
                        ignore_x1, ignore_y1, ignore_x2, ignore_y2

            fake_heatmap.zero_()
            if self.ct_gaussian:
                draw_umich_gaussian(fake_heatmap, ct_ints[k],
                                    radiuses[k].item())
            else:
                draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                       h_radiuses[k].item(),
                                       w_radiuses[k].item())

            if self.fovea_hm:
                # ignore_mask_box is necessary to prevent the ignore areas covering the
                # pos areas of larger boxes
                ignore_mask_box = (heatmap[cls_id, ignore_y1:ignore_y2 + 1,
                                           ignore_x1:ignore_x2 + 1] == 0)
                heatmap[cls_id, ignore_y1:ignore_y2 + 1,
                        ignore_x1:ignore_x2 + 1][ignore_mask_box] = -1
                heatmap[cls_id, ctr_y1:ctr_y2 + 1, ctr_x1:ctr_x2 + 1] = 1
                centerness[0] = torch.max(centerness[0], fake_heatmap)
            else:
                heatmap[cls_id] = torch.max(heatmap[cls_id], fake_heatmap)

            if self.wh_heatmap:
                if redraw_wh_heatmap:
                    fake_heatmap.zero_()
                    draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                           wh_h_radiuses[k].item(),
                                           wh_w_radiuses[k].item())
                box_target_inds = fake_heatmap > 0
            else:
                # Build the mask through a comparison so that it has the same
                # dtype as in the gaussian branch instead of a uint8 mask,
                # whose use for indexing is deprecated in PyTorch.
                center_region = torch.zeros_like(fake_heatmap)
                center_region[ctr_y1:ctr_y2 + 1, ctr_x1:ctr_x2 + 1] = 1
                box_target_inds = center_region > 0

            if self.wh_agnostic:
                collide_pixels_summary += (box_target[:, box_target_inds] >
                                           0).sum()

                box_target[:, box_target_inds] = gt_boxes[k][:, None]
            else:
                collide_pixels_summary += (box_target[(
                    cls_id * 4):(cls_id + 1) * 4, box_target_inds] > 0).sum()

                box_target[(cls_id * 4):((cls_id + 1) * 4),
                           box_target_inds] = gt_boxes[k][:, None]

            local_heatmap = fake_heatmap[box_target_inds]
            ct_div = local_heatmap.sum()
            local_heatmap *= boxes_area_topk_log[k]

            if self.wh_agnostic:
                # Class-agnostic regression uses a single weight plane.
                cls_id = 0

            # v2 always uses gaussian-shaped weights; v3/v4 only do so for
            # center regions larger than 6 pixels in both directions.
            if self.avg_wh_weightv2:
                use_gaussian_weight = ct_div > 0
            elif self.avg_wh_weightv3 or self.avg_wh_weightv4:
                use_gaussian_weight = (ct_div > 0 and ctr_y_diff > 6 and
                                       ctr_x_diff > 6)
            else:
                use_gaussian_weight = False

            if use_gaussian_weight:
                # Gaussian-shaped weights, rescaled to sum to the box area.
                wh_weight[cls_id, box_target_inds] = local_heatmap / ct_div
            else:
                # Uniform weights that sum to the box area.
                wh_weight[cls_id, box_target_inds] = \
                    boxes_area_topk_log[k] / box_target_inds.sum().float()

            if self.avg_wh_weightv4:
                # v4 additionally puts the full box weight on the center.
                wh_weight[cls_id, ct_ints[k, 1].item(), ct_ints[k, 0].item()] = \
                    boxes_area_topk_log[k]

            if not self.ct_version:
                # The center pixel gets weight 1/2, the remaining pixels with
                # a gaussian value above 0.9 share the other half.
                # NOTE(review): with a single such pixel the shared value is
                # a division by zero before the center overwrite -- confirm
                # that pixel is always the center.
                target_loc = fake_heatmap > 0.9
                hm_target_num = target_loc.sum().float()
                hm_weight[cls_id, target_loc] = 1 / (2 * (hm_target_num - 1))
                hm_weight[cls_id, ct_ints[k, 1].item(),
                          ct_ints[k, 0].item()] = 1 / 2.

        add_summary('box_target', collide_pixels=collide_pixels_summary)
        pos_pixels_summary = (box_target > 0).sum()
        add_summary('box_target', pos_pixels=pos_pixels_summary)
        add_summary('box_target',
                    collide_ratio=collide_pixels_summary /
                    pos_pixels_summary.float())

        return heatmap, box_target, centerness, wh_weight, hm_weight