def _draw_heatmap(self, heatmap, center, h, w):
    """Draw one object's Gaussian peak on ``heatmap`` and return its center.

    Args:
        heatmap: tensor handed straight to ``draw_truncate_gaussian`` /
            ``draw_umich_gaussian``. The helpers' return values are
            discarded here, so they are presumably expected to write into
            it in place -- confirm against the helper definitions.
        center: tensor holding the (possibly fractional) object center.
        h: single-element tensor, object height (``.item()`` is called on
            values derived from it).
        w: single-element tensor, object width.

    Returns:
        ``center`` converted to ``torch.int``; shifted by 0.5 first when
        ``self.gt_plus_dot5`` is set.
    """
    shifted = center + 0.5 if self.gt_plus_dot5 else center
    peak = shifted.to(torch.int)

    # Isotropic Gaussian: a single radius derived from the box size.
    if not self.use_truncate_gaussia:
        iso_radius = gaussian_radius((h.ceil(), w.ceil()))
        iso_radius = max(0, int(iso_radius.item()))
        draw_umich_gaussian(heatmap, peak, iso_radius)
        return peak

    # Truncated Gaussian: separate radii along the height and width axes.
    if self.use_tight_gauusia:
        # "Tight" mode uses half of the box extent on each axis.
        radius_y = (h / 2).int().item()
        radius_x = (w / 2).int().item()
    else:
        # Stretch the isotropic radius by the square root of the aspect
        # ratio so that radius_y / radius_x follows h / w.
        base_radius = gaussian_radius((h.ceil(), w.ceil()))
        base_radius = max(0, int(base_radius.item()))
        radius_y = (base_radius * (h / w).sqrt()).int().item()
        radius_x = (base_radius * (w / h).sqrt()).int().item()

    draw_truncate_gaussian(heatmap, peak, radius_y, radius_x)
    return peak
def target_single_image(self, gt_boxes, gt_labels, feat_shape):
    """Build the dense heatmap and box-regression targets for one image.

    Args:
        gt_boxes: tensor, tensor <=> img, (num_gt, 4). Columns 0/2 are
            treated as x coordinates and columns 1/3 as y coordinates, in
            the original (pre-downsampling) scale.
        gt_labels: tensor, tensor <=> img, (num_gt,). ``label - 1`` is used
            as the heatmap channel, so labels are presumably 1-based --
            confirm against the caller.
        feat_shape: tuple ``(output_h, output_w)``.

    Returns:
        heatmap: tensor, (self.num_fg, h, w).
        box_target: tensor, (self.wh_planes, h, w). Initialised to -1 and
            filled with the original-scale gt box at every target pixel.
        centerness: tensor, (1, h, w). Only written when ``self.fovea_hm``.
        wh_weight: tensor, (self.wh_planes // 4, h, w).
        hm_weight: tensor, (self.wh_planes // 4, h, w). Only written when
            ``self.ct_version`` is falsy.
    """
    output_h, output_w = feat_shape
    heatmap_channel = self.num_fg

    heatmap = gt_boxes.new_zeros((heatmap_channel, output_h, output_w))
    # Scratch buffer that is zeroed and redrawn for every box.
    fake_heatmap = gt_boxes.new_zeros((output_h, output_w))
    box_target = gt_boxes.new_ones(
        (self.wh_planes, output_h, output_w)) * -1
    wh_weight = gt_boxes.new_zeros(
        (self.wh_planes // 4, output_h, output_w))
    hm_weight = gt_boxes.new_zeros(
        (self.wh_planes // 4, output_h, output_w))
    centerness = gt_boxes.new_zeros((1, output_h, output_w))

    # Despite the "_log" suffix, the transform applied to the areas depends
    # on self.wh_area_process (log / sqrt / untouched).
    if self.wh_area_process == 'log':
        boxes_areas_log = bbox_areas(gt_boxes).log()
    elif self.wh_area_process == 'sqrt':
        boxes_areas_log = bbox_areas(gt_boxes).sqrt()
    else:
        boxes_areas_log = bbox_areas(gt_boxes)
    # topk over the full length == sort by area, largest first.
    boxes_area_topk_log, boxes_ind = torch.topk(boxes_areas_log,
                                                boxes_areas_log.size(0))

    if self.wh_area_process == 'norm':
        boxes_area_topk_log[:] = 1.

    gt_boxes = gt_boxes[boxes_ind]
    gt_labels = gt_labels[boxes_ind]

    feat_gt_boxes = gt_boxes / self.down_ratio
    feat_gt_boxes[:, [0, 2]] = torch.clamp(feat_gt_boxes[:, [0, 2]],
                                           min=0, max=output_w - 1)
    feat_gt_boxes[:, [1, 3]] = torch.clamp(feat_gt_boxes[:, [1, 3]],
                                           min=0, max=output_h - 1)
    feat_hs, feat_ws = (feat_gt_boxes[:, 3] - feat_gt_boxes[:, 1],
                        feat_gt_boxes[:, 2] - feat_gt_boxes[:, 0])
    r1 = (1 - self.center_ratio) / 2
    r2 = (1 - self.ignore_ratio) / 2

    # we calc the center and ignore area based on the gt-boxes of the origin scale
    # no peak will fall between pixels
    ct_ints = (torch.stack([(gt_boxes[:, 0] + gt_boxes[:, 2]) / 2,
                            (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2],
                           dim=1) / self.down_ratio).to(torch.int)

    if self.hm_center_ratio is None:
        radiuses = torch.clamp(gaussian_radius(
            (feat_hs.ceil(), feat_ws.ceil())), min=0)
        hw_ratio_sqrt = (feat_hs / feat_ws).sqrt()
        h_radiuses = (radiuses * hw_ratio_sqrt).int()
        w_radiuses = (radiuses / hw_ratio_sqrt).int()
        if self.ct_gaussian:
            radiuses = radiuses.int()
    else:
        h_radiuses = (feat_hs * self.hm_center_ratio).int()
        w_radiuses = (feat_ws * self.hm_center_ratio).int()

    # Separate radii for the wh heatmap, needed only when its extent
    # differs from the classification heatmap's.
    if (self.center_ratio / 2 != self.hm_center_ratio) and self.wh_heatmap:
        wh_h_radiuses = (feat_hs * self.center_ratio / 2).int()
        wh_w_radiuses = (feat_ws * self.center_ratio / 2).int()

    # calculate positive (center) regions
    ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = calc_region(
        gt_boxes.transpose(0, 1), r1, use_round=False)
    ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = [
        torch.round(x / self.down_ratio).int()
        for x in [ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s]
    ]
    # Clamp on both sides: these values are used as slice bounds below, and
    # a negative bound would be interpreted as a wrap-around index.
    ctr_x1s, ctr_x2s = [
        torch.clamp(x, min=0, max=output_w - 1) for x in [ctr_x1s, ctr_x2s]
    ]
    ctr_y1s, ctr_y2s = [
        torch.clamp(y, min=0, max=output_h - 1) for y in [ctr_y1s, ctr_y2s]
    ]
    ctr_xs_diff, ctr_ys_diff = ctr_x2s - ctr_x1s + 1, ctr_y2s - ctr_y1s + 1

    if self.fill_small:
        are_fill_small = (ctr_ys_diff <= 4) & (ctr_xs_diff <= 4)

    collide_pixels_summary = 0
    # larger boxes have lower priority than small boxes.
    # (Boxes are visited largest-first, so later, smaller boxes overwrite.)
    for k in range(boxes_ind.shape[0]):
        cls_id = gt_labels[k] - 1

        ctr_x1, ctr_y1, ctr_x2, ctr_y2 = (ctr_x1s[k], ctr_y1s[k],
                                          ctr_x2s[k], ctr_y2s[k])
        ctr_x_diff, ctr_y_diff = ctr_xs_diff[k], ctr_ys_diff[k]
        if self.fovea_hm or (self.fill_small and are_fill_small[k]):
            ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                feat_gt_boxes[k], r2, (output_h, output_w))
            if not self.fovea_hm:
                # Small box: use the ignore region as the center region.
                ctr_x1, ctr_y1, ctr_x2, ctr_y2 = (ignore_x1, ignore_y1,
                                                  ignore_x2, ignore_y2)

        fake_heatmap.zero_()
        if self.ct_gaussian:
            draw_umich_gaussian(fake_heatmap, ct_ints[k],
                                radiuses[k].item())
        else:
            draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                   h_radiuses[k].item(),
                                   w_radiuses[k].item())

        if self.fovea_hm:
            # ignore_mask_box is necessary to prevent the ignore areas covering the
            # pos areas of larger boxes
            ignore_mask_box = (heatmap[cls_id,
                                       ignore_y1:ignore_y2 + 1,
                                       ignore_x1:ignore_x2 + 1] == 0)
            heatmap[cls_id,
                    ignore_y1:ignore_y2 + 1,
                    ignore_x1:ignore_x2 + 1][ignore_mask_box] = -1
            heatmap[cls_id, ctr_y1:ctr_y2 + 1, ctr_x1:ctr_x2 + 1] = 1
            centerness[0] = torch.max(centerness[0], fake_heatmap)
        else:
            heatmap[cls_id] = torch.max(heatmap[cls_id], fake_heatmap)

        if self.wh_heatmap:
            if self.hm_center_ratio != self.center_ratio / 2:
                fake_heatmap.zero_()
                draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                       wh_h_radiuses[k].item(),
                                       wh_w_radiuses[k].item())
            box_target_inds = fake_heatmap > 0
        else:
            # Build the mask through a comparison so it has the same mask
            # dtype as the branch above; indexing with an explicit uint8
            # mask is deprecated in current PyTorch.
            ctr_region = torch.zeros_like(fake_heatmap)
            ctr_region[ctr_y1:ctr_y2 + 1, ctr_x1:ctr_x2 + 1] = 1
            box_target_inds = ctr_region > 0

        # A value > 0 already present at a target pixel means an earlier
        # box wrote there: count it as a collision before overwriting.
        if self.wh_agnostic:
            collide_pixels_summary += (box_target[:, box_target_inds] >
                                       0).sum()
            box_target[:, box_target_inds] = gt_boxes[k][:, None]
        else:
            collide_pixels_summary += (box_target[
                (cls_id * 4):(cls_id + 1) * 4, box_target_inds] > 0).sum()
            box_target[(cls_id * 4):((cls_id + 1) * 4),
                       box_target_inds] = gt_boxes[k][:, None]

        # Mask indexing returns a copy, so the in-place scaling below does
        # not touch fake_heatmap.
        local_heatmap = fake_heatmap[box_target_inds]
        ct_div = local_heatmap.sum()
        local_heatmap *= boxes_area_topk_log[k]

        if self.wh_agnostic:
            # Class-agnostic weights live in plane 0.
            cls_id = 0

        if self.avg_wh_weightv2 and ct_div > 0:
            wh_weight[cls_id, box_target_inds] = local_heatmap / ct_div
        elif (self.avg_wh_weightv3 and ct_div > 0 and ctr_y_diff > 6
              and ctr_x_diff > 6):
            wh_weight[cls_id, box_target_inds] = local_heatmap / ct_div
        elif (self.avg_wh_weightv4 and ct_div > 0 and ctr_y_diff > 6
              and ctr_x_diff > 6):
            wh_weight[cls_id, box_target_inds] = local_heatmap / ct_div
        else:
            # Spread the box weight uniformly over its target pixels.
            wh_weight[cls_id, box_target_inds] = \
                boxes_area_topk_log[k] / box_target_inds.sum().float()

        if self.avg_wh_weightv4:
            wh_weight[cls_id, ct_ints[k, 1].item(), ct_ints[k, 0].item()] = \
                boxes_area_topk_log[k]

        if not self.ct_version:
            # Half of the weight goes to the center pixel, the other half
            # is shared by the remaining pixels above 0.9.
            target_loc = fake_heatmap > 0.9
            hm_target_num = target_loc.sum().float()
            hm_weight[cls_id, target_loc] = 1 / (2 * (hm_target_num - 1))
            hm_weight[cls_id, ct_ints[k, 1].item(),
                      ct_ints[k, 0].item()] = 1 / 2.

    add_summary('box_target', collide_pixels=collide_pixels_summary)
    pos_pixels_summary = (box_target > 0).sum()
    add_summary('box_target', pos_pixels=pos_pixels_summary)
    # Clamp the denominator so an image without positive pixels reports a
    # ratio of 0 instead of NaN.
    add_summary('box_target',
                collide_ratio=collide_pixels_summary /
                pos_pixels_summary.float().clamp(min=1))

    return heatmap, box_target, centerness, wh_weight, hm_weight
def target_single_image(self, gt_boxes, gt_labels, feat_shape):
    """Build the heatmap and per-object regression targets for one image.

    ``gt_boxes`` is NOT modified: the boxes are rescaled into a new tensor
    before being clamped to the feature map.

    Args:
        gt_boxes: tensor, tensor <=> img, (num_gt, 4). Columns 0/2 are
            treated as x coordinates and columns 1/3 as y coordinates.
            The per-object buffers are indexed by box index, so ``num_gt``
            must not exceed ``self.max_objs``.
        gt_labels: tensor, tensor <=> img, (num_gt,). ``label - 1`` is used
            as the heatmap channel, so labels are presumably 1-based --
            confirm against the caller.
        feat_shape: tuple ``(output_h, output_w)``.

    Returns:
        heatmap: tensor, tensor <=> img, (self.num_fg, h, w).
        wh: tensor, tensor <=> img, (max_obj, 2), stored as (w, h).
        reg_mask: tensor, tensor <=> img, (max_obj,), uint8; 1 for slots
            that hold a valid object.
        ind: tensor, tensor <=> img, (max_obj,), flattened center index
            ``y * output_w + x``.
        reg: tensor or None, tensor <=> img, (max_obj, 2). Sub-pixel center
            offset; None unless ``self.use_reg_offset``.
        center_location: tensor or None, tensor <=> img, (max_obj, 2).
            None unless ``self.use_giou``.
    """
    output_h, output_w = feat_shape
    heatmap = gt_boxes.new_zeros((self.num_fg, output_h, output_w))
    wh = gt_boxes.new_zeros((self.max_objs, 2))
    reg_mask = gt_boxes.new_zeros((self.max_objs, ), dtype=torch.uint8)
    ind = gt_boxes.new_zeros((self.max_objs, ), dtype=torch.long)

    reg, center_location = None, None
    if self.use_reg_offset:
        reg = gt_boxes.new_zeros((self.max_objs, 2))
    if self.use_giou:
        center_location = gt_boxes.new_zeros((self.max_objs, 2))

    # Out-of-place division: an in-place `/=` (and the clamps below) would
    # rescale the caller's tensor as a hidden side effect.
    gt_boxes = gt_boxes / self.down_ratio
    gt_boxes[:, [0, 2]] = torch.clamp(gt_boxes[:, [0, 2]], 0, output_w - 1)
    gt_boxes[:, [1, 3]] = torch.clamp(gt_boxes[:, [1, 3]], 0, output_h - 1)
    hs, ws = (gt_boxes[:, 3] - gt_boxes[:, 1],
              gt_boxes[:, 2] - gt_boxes[:, 0])

    for k in range(gt_boxes.shape[0]):
        cls_id = gt_labels[k] - 1
        h, w = hs[k], ws[k]
        # Boxes that collapse to zero width or height after clamping get
        # no target; their slot keeps reg_mask == 0.
        if h > 0 and w > 0:
            center = gt_boxes.new_tensor([
                (gt_boxes[k, 0] + gt_boxes[k, 2]) / 2,
                (gt_boxes[k, 1] + gt_boxes[k, 3]) / 2
            ])

            # no peak will fall between pixels
            if self.gt_plus_dot5:
                ct_int = (center + 0.5).to(torch.int)
            else:
                ct_int = center.to(torch.int)

            if self.use_truncate_gaussia:
                if self.use_tight_gauusia:
                    h_radius = (h / 2).int().item()
                    w_radius = (w / 2).int().item()
                else:
                    radius = gaussian_radius((h.ceil(), w.ceil()))
                    radius = max(0, int(radius.item()))
                    h_radius = (radius * (h / w).sqrt()).int().item()
                    w_radius = (radius * (w / h).sqrt()).int().item()
                draw_truncate_gaussian(heatmap[cls_id], ct_int,
                                       h_radius, w_radius)
            else:
                radius = gaussian_radius((h.ceil(), w.ceil()))
                radius = max(0, int(radius.item()))
                draw_umich_gaussian(heatmap[cls_id], ct_int, radius)

            # directly predict the width and height
            wh[k] = wh.new_tensor([1. * w, 1. * h])
            ind[k] = ct_int[1] * output_w + ct_int[0]
            if self.use_reg_offset:
                reg[k] = center - ct_int.float()
            if self.use_giou:
                center_location[k] = center
            reg_mask[k] = 1

    return heatmap, wh, reg_mask, ind, reg, center_location