コード例 #1
0
    def _draw_heatmap(self, heatmap, center, h, w):
        """Draw one object's Gaussian peak into ``heatmap``; return the peak.

        The return values of ``draw_truncate_gaussian`` /
        ``draw_umich_gaussian`` are discarded, so they presumably write into
        ``heatmap`` in place — TODO confirm against their implementations.

        Args:
            heatmap: tensor handed straight to the draw helper; presumably a
                single (H, W) channel — verify against callers.
            center: tensor holding the object's sub-pixel center; presumably
                ``(x, y)`` in feature-map coordinates — verify against callers.
            h, w: single-element tensors with the box height and width.

        Returns:
            ``center`` (shifted by +0.5 first when ``self.gt_plus_dot5``)
            converted to ``torch.int``, i.e. truncated toward zero.
        """
        peak = (center + 0.5 if self.gt_plus_dot5 else center).to(torch.int)

        def _integer_radius():
            # CenterNet-style radius for the (ceil'ed) box, floored at 0.
            raw = gaussian_radius((h.ceil(), w.ceil()))
            return max(0, int(raw.item()))

        if not self.use_truncate_gaussia:
            # Round gaussian with a single radius.
            draw_umich_gaussian(heatmap, peak, _integer_radius())
            return peak

        if self.use_tight_gauusia:
            # Elliptical gaussian spanning half the box in each direction.
            radius_y = (h / 2).int().item()
            radius_x = (w / 2).int().item()
        else:
            # Stretch the round radius by the box aspect ratio.
            base = _integer_radius()
            radius_y = (base * (h / w).sqrt()).int().item()
            radius_x = (base * (w / h).sqrt()).int().item()
        draw_truncate_gaussian(heatmap, peak, radius_y, radius_x)
        return peak
コード例 #2
0
    def target_single_image(self, gt_boxes, gt_labels, feat_shape):
        """Build the heatmap and box-regression targets for a single image.

        Boxes are visited in descending order of ``boxes_areas_log`` (box
        area, optionally log/sqrt-compressed). Smaller boxes are therefore
        written last and overwrite the targets of larger overlapping boxes.

        Args:
            gt_boxes: tensor, tensor <=> img, (num_gt, 4), in input-image
                coordinates; divided by ``self.down_ratio`` to reach the
                feature map.
            gt_labels: tensor, tensor <=> img, (num_gt,); label ``c`` is
                written to heatmap channel ``c - 1``.
            feat_shape: tuple, ``(output_h, output_w)``.

        Returns:
            heatmap: tensor, tensor <=> img, (num_fg, h, w). With
                ``self.fovea_hm`` it holds 1 (positive), -1 (ignore) and 0.
                Otherwise it holds the per-class max of the drawn gaussians.
            box_target: tensor, tensor <=> img, (wh_planes, h, w), e.g.
                (4, h, w) or (80 * 4, h, w); unassigned locations stay -1.
            centerness: tensor, (1, h, w); only filled when ``self.fovea_hm``.
            wh_weight: tensor, (wh_planes // 4, h, w), regression weights.
            hm_weight: tensor, (wh_planes // 4, h, w); only filled when
                ``self.ct_version`` is falsy.
        """
        output_h, output_w = feat_shape
        heatmap_channel = self.num_fg

        heatmap = gt_boxes.new_zeros((heatmap_channel, output_h, output_w))
        fake_heatmap = gt_boxes.new_zeros((output_h, output_w))
        box_target = gt_boxes.new_ones(
            (self.wh_planes, output_h, output_w)) * -1
        wh_weight = gt_boxes.new_zeros(
            (self.wh_planes // 4, output_h, output_w))
        hm_weight = gt_boxes.new_zeros(
            (self.wh_planes // 4, output_h, output_w))
        centerness = gt_boxes.new_zeros((1, output_h, output_w))

        # Per-box regression weight derived from the box area.
        if self.wh_area_process == 'log':
            boxes_areas_log = bbox_areas(gt_boxes).log()
        elif self.wh_area_process == 'sqrt':
            boxes_areas_log = bbox_areas(gt_boxes).sqrt()
        else:
            boxes_areas_log = bbox_areas(gt_boxes)
        # topk with k == num_gt is a full descending sort.
        boxes_area_topk_log, boxes_ind = torch.topk(boxes_areas_log,
                                                    boxes_areas_log.size(0))

        if self.wh_area_process == 'norm':
            boxes_area_topk_log[:] = 1.

        gt_boxes = gt_boxes[boxes_ind]
        gt_labels = gt_labels[boxes_ind]

        feat_gt_boxes = gt_boxes / self.down_ratio
        feat_gt_boxes[:, [0, 2]] = torch.clamp(feat_gt_boxes[:, [0, 2]],
                                               min=0,
                                               max=output_w - 1)
        feat_gt_boxes[:, [1, 3]] = torch.clamp(feat_gt_boxes[:, [1, 3]],
                                               min=0,
                                               max=output_h - 1)
        feat_hs, feat_ws = (feat_gt_boxes[:, 3] - feat_gt_boxes[:, 1],
                            feat_gt_boxes[:, 2] - feat_gt_boxes[:, 0])

        r1 = (1 - self.center_ratio) / 2
        r2 = (1 - self.ignore_ratio) / 2

        # we calc the center and ignore area based on the gt-boxes of the origin scale
        # no peak will fall between pixels
        ct_ints = (torch.stack([(gt_boxes[:, 0] + gt_boxes[:, 2]) / 2,
                                (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2],
                               dim=1) / self.down_ratio).to(torch.int)

        # ``radiuses`` (round gaussian) is needed whenever ``ct_gaussian`` is
        # set. It used to be bound only when ``hm_center_ratio`` was None, so
        # ct_gaussian together with a hm_center_ratio raised NameError.
        if self.hm_center_ratio is None or self.ct_gaussian:
            radiuses = torch.clamp(gaussian_radius(
                (feat_hs.ceil(), feat_ws.ceil())),
                                   min=0)
        if self.hm_center_ratio is None:
            hw_ratio_sqrt = (feat_hs / feat_ws).sqrt()
            h_radiuses = (radiuses * hw_ratio_sqrt).int()
            w_radiuses = (radiuses / hw_ratio_sqrt).int()
        else:
            h_radiuses = (feat_hs * self.hm_center_ratio).int()
            w_radiuses = (feat_ws * self.hm_center_ratio).int()
        if self.ct_gaussian:
            radiuses = radiuses.int()

        # The box-target area is re-drawn with ``center_ratio / 2`` whenever it
        # differs from the heatmap ratio. A None hm_center_ratio counts as
        # "differs", which previously hit an unbound ``wh_h_radiuses``.
        redraw_wh_gaussian = (self.wh_heatmap and
                              self.hm_center_ratio != self.center_ratio / 2)
        if redraw_wh_gaussian:
            wh_h_radiuses = (feat_hs * self.center_ratio / 2).int()
            wh_w_radiuses = (feat_ws * self.center_ratio / 2).int()

        # calculate positive (center) regions
        ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = calc_region(
            gt_boxes.transpose(0, 1), r1, use_round=False)
        ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = [
            torch.round(x / self.down_ratio).int()
            for x in [ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s]
        ]
        ctr_x1s, ctr_x2s = [
            torch.clamp(x, max=output_w - 1) for x in [ctr_x1s, ctr_x2s]
        ]
        ctr_y1s, ctr_y2s = [
            torch.clamp(y, max=output_h - 1) for y in [ctr_y1s, ctr_y2s]
        ]
        ctr_xs_diff, ctr_ys_diff = ctr_x2s - ctr_x1s + 1, ctr_y2s - ctr_y1s + 1

        if self.fill_small:
            are_fill_small = (ctr_ys_diff <= 4) & (ctr_xs_diff <= 4)

        collide_pixels_summary = 0
        # larger boxes have lower priority than small boxes.
        for k in range(boxes_ind.shape[0]):
            cls_id = gt_labels[k] - 1
            ctr_x1, ctr_y1, ctr_x2, ctr_y2 = (ctr_x1s[k], ctr_y1s[k],
                                              ctr_x2s[k], ctr_y2s[k])
            ctr_x_diff, ctr_y_diff = ctr_xs_diff[k], ctr_ys_diff[k]

            if self.fovea_hm or (self.fill_small and are_fill_small[k]):
                ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                    feat_gt_boxes[k], r2, (output_h, output_w))

                # Small boxes (fill_small, no fovea) take the whole ignore
                # region as their positive region.
                if not self.fovea_hm:
                    ctr_x1, ctr_y1, ctr_x2, ctr_y2 = ignore_x1, ignore_y1, ignore_x2, ignore_y2

            fake_heatmap = fake_heatmap.zero_()
            if self.ct_gaussian:
                draw_umich_gaussian(fake_heatmap, ct_ints[k],
                                    radiuses[k].item())
            else:
                draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                       h_radiuses[k].item(),
                                       w_radiuses[k].item())

            if self.fovea_hm:
                # ignore_mask_box is necessary to prevent the ignore areas covering the
                # pos areas of larger boxes
                ignore_mask_box = (heatmap[cls_id, ignore_y1:ignore_y2 + 1,
                                           ignore_x1:ignore_x2 + 1] == 0)
                heatmap[cls_id, ignore_y1:ignore_y2 + 1,
                        ignore_x1:ignore_x2 + 1][ignore_mask_box] = -1
                heatmap[cls_id, ctr_y1:ctr_y2 + 1, ctr_x1:ctr_x2 + 1] = 1
                centerness[0] = torch.max(centerness[0], fake_heatmap)
            else:
                heatmap[cls_id] = torch.max(heatmap[cls_id], fake_heatmap)

            if self.wh_heatmap:
                if redraw_wh_gaussian:
                    fake_heatmap = fake_heatmap.zero_()
                    draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                           wh_h_radiuses[k].item(),
                                           wh_w_radiuses[k].item())
                box_target_inds = fake_heatmap > 0
            else:
                # Build the rectangle mask via a comparison so it has the same
                # mask dtype as the branch above (bool on PyTorch >= 1.2, where
                # uint8 mask indexing is deprecated).
                box_target_inds = torch.zeros_like(fake_heatmap)
                box_target_inds[ctr_y1:ctr_y2 + 1, ctr_x1:ctr_x2 + 1] = 1
                box_target_inds = box_target_inds > 0

            if self.wh_agnostic:
                collide_pixels_summary += (box_target[:, box_target_inds] >
                                           0).sum()

                box_target[:, box_target_inds] = gt_boxes[k][:, None]
            else:
                collide_pixels_summary += (box_target[(
                    cls_id * 4):(cls_id + 1) * 4, box_target_inds] > 0).sum()

                box_target[(cls_id * 4):((cls_id + 1) * 4),
                           box_target_inds] = gt_boxes[k][:, None]

            local_heatmap = fake_heatmap[box_target_inds]
            ct_div = local_heatmap.sum()
            local_heatmap *= boxes_area_topk_log[k]

            # wh_weight / hm_weight have a single channel when class-agnostic.
            if self.wh_agnostic:
                cls_id = 0

            if self.avg_wh_weightv2 and ct_div > 0:
                wh_weight[cls_id, box_target_inds] = local_heatmap / ct_div
            elif self.avg_wh_weightv3 and ct_div > 0 and ctr_y_diff > 6 and ctr_x_diff > 6:
                wh_weight[cls_id, box_target_inds] = local_heatmap / ct_div
            elif self.avg_wh_weightv4 and ct_div > 0 and ctr_y_diff > 6 and ctr_x_diff > 6:
                wh_weight[cls_id, box_target_inds] = local_heatmap / ct_div
            else:
                wh_weight[cls_id, box_target_inds] = \
                    boxes_area_topk_log[k] / box_target_inds.sum().float()

            if self.avg_wh_weightv4:
                wh_weight[cls_id, ct_ints[k, 1].item(), ct_ints[k, 0].item()] = \
                    boxes_area_topk_log[k]

            if not self.ct_version:
                # Every pixel with value > 0.9 first gets 1 / (2 * (n - 1)),
                # then the peak pixel is overwritten with 1/2. This assumes the
                # peak is among those n pixels — TODO confirm for radius-0 boxes.
                target_loc = fake_heatmap > 0.9
                hm_target_num = target_loc.sum().float()
                hm_weight[cls_id, target_loc] = 1 / (2 * (hm_target_num - 1))
                hm_weight[cls_id, ct_ints[k, 1].item(),
                          ct_ints[k, 0].item()] = 1 / 2.

        add_summary('box_target', collide_pixels=collide_pixels_summary)
        pos_pixels_summary = (box_target > 0).sum()
        add_summary('box_target', pos_pixels=pos_pixels_summary)
        # clamp(min=1): an image with no positive pixels would otherwise log 0/0 = nan.
        add_summary('box_target',
                    collide_ratio=collide_pixels_summary /
                    pos_pixels_summary.float().clamp(min=1))

        return heatmap, box_target, centerness, wh_weight, hm_weight
コード例 #3
0
    def target_single_image(self, gt_boxes, gt_labels, feat_shape):
        """Build the heatmap and per-object regression slots for one image.

        Args:
            gt_boxes: tensor, tensor <=> img, (num_gt, 4), in input-image
                coordinates. Not modified: the feature-map rescaling works on
                a copy.
            gt_labels: tensor, tensor <=> img, (num_gt,); label ``c`` is drawn
                on heatmap channel ``c - 1``.
            feat_shape: tuple, ``(output_h, output_w)``.

        Returns:
            heatmap: tensor, tensor <=> img, (num_fg, h, w).
            wh: tensor, tensor <=> img, (max_obj, 2), box ``(w, h)`` in
                feature-map units.
            reg_mask: tensor, tensor <=> img, (max_obj,), 1 for filled slots.
            ind: tensor, tensor <=> img, (max_obj,), flattened peak index
                ``y * output_w + x``.
            reg: tensor or None, (max_obj, 2), offset ``center - ct_int``;
                None unless ``self.use_reg_offset``.
            center_location: tensor or None, tensor <=> img, (max_obj, 2),
                float center; None unless ``self.use_giou``.

        Only the first ``self.max_objs`` boxes are encoded; previously a
        larger ``num_gt`` raised IndexError on the fixed-size slots.
        """
        output_h, output_w = feat_shape
        heatmap = gt_boxes.new_zeros((self.num_fg, output_h, output_w))
        wh = gt_boxes.new_zeros((self.max_objs, 2))
        reg_mask = gt_boxes.new_zeros((self.max_objs, ), dtype=torch.uint8)
        ind = gt_boxes.new_zeros((self.max_objs, ), dtype=torch.long)

        reg, center_location = None, None
        if self.use_reg_offset:
            reg = gt_boxes.new_zeros((self.max_objs, 2))
        if self.use_giou:
            center_location = gt_boxes.new_zeros((self.max_objs, 2))

        # Rescale out of place: ``gt_boxes /= ...`` would silently rewrite the
        # caller's ground-truth tensor.
        gt_boxes = gt_boxes / self.down_ratio
        gt_boxes[:, [0, 2]] = torch.clamp(gt_boxes[:, [0, 2]], 0, output_w - 1)
        gt_boxes[:, [1, 3]] = torch.clamp(gt_boxes[:, [1, 3]], 0, output_h - 1)
        hs = gt_boxes[:, 3] - gt_boxes[:, 1]
        ws = gt_boxes[:, 2] - gt_boxes[:, 0]

        # wh / ind / reg / center_location have exactly max_objs slots.
        num_objs = min(gt_boxes.size(0), self.max_objs)
        for k in range(num_objs):
            h, w = hs[k], ws[k]
            if not (h > 0 and w > 0):
                # Boxes collapsed by clamping (or NaN) get no target.
                continue
            cls_id = gt_labels[k] - 1
            center = (gt_boxes[k, :2] + gt_boxes[k, 2:]) / 2

            # no peak will fall between pixels
            if self.gt_plus_dot5:
                ct_int = (center + 0.5).to(torch.int)
            else:
                ct_int = center.to(torch.int)
            if self.use_truncate_gaussia:
                if self.use_tight_gauusia:
                    h_radius = (h / 2).int().item()
                    w_radius = (w / 2).int().item()
                else:
                    radius = gaussian_radius((h.ceil(), w.ceil()))
                    radius = max(0, int(radius.item()))
                    h_radius = (radius * (h / w).sqrt()).int().item()
                    w_radius = (radius * (w / h).sqrt()).int().item()
                draw_truncate_gaussian(heatmap[cls_id], ct_int, h_radius,
                                       w_radius)
            else:
                radius = gaussian_radius((h.ceil(), w.ceil()))
                radius = max(0, int(radius.item()))
                draw_umich_gaussian(heatmap[cls_id], ct_int, radius)

            # directly predict the width and height
            wh[k] = torch.stack([w, h])
            ind[k] = ct_int[1] * output_w + ct_int[0]
            if self.use_reg_offset:
                reg[k] = center - ct_int.float()
            if self.use_giou:
                center_location[k] = center
            reg_mask[k] = 1

        return heatmap, wh, reg_mask, ind, reg, center_location