def target_single_image(self, gt_boxes, gt_labels, feat_shape):
    """Build heatmap / box-regression targets for a single image.

    Boxes are visited from largest to smallest (processed) area, so where
    target regions overlap, a smaller box overwrites a larger one.

    Args:
        gt_boxes: tensor, (num_gt, 4) boxes as (x1, y1, x2, y2) at the
            input-image scale (they are divided by ``self.down_ratio`` to
            reach feature-map scale).
        gt_labels: tensor, (num_gt,). The heatmap channel used is
            ``label - 1``, so labels are presumably 1-based (confirm).
        feat_shape: tuple, (h, w) of the output feature map.

    Returns:
        heatmap: tensor, (num_fg, h, w) per-class center heatmap (element-wise
            max of the per-box gaussians).
        box_target: tensor, (wh_planes, h, w); 4 planes when ``wh_agnostic``,
            otherwise 4 planes per class. Pixels without a target hold -1.
        reg_weight: tensor, (wh_planes // 4, h, w) per-pixel weight of the box
            regression loss.
    """
    output_h, output_w = feat_shape
    heatmap_channel = self.num_fg

    heatmap = gt_boxes.new_zeros((heatmap_channel, output_h, output_w))
    fake_heatmap = gt_boxes.new_zeros((output_h, output_w))
    # -1 marks pixels that carry no box-regression target.
    box_target = gt_boxes.new_ones((self.wh_planes, output_h, output_w)) * -1
    reg_weight = gt_boxes.new_zeros((self.wh_planes // 4, output_h, output_w))

    if self.wh_area_process == 'log':
        boxes_areas_log = bbox_areas(gt_boxes).log()
    elif self.wh_area_process == 'sqrt':
        boxes_areas_log = bbox_areas(gt_boxes).sqrt()
    else:
        boxes_areas_log = bbox_areas(gt_boxes)
    # topk over all boxes == sort by processed area, largest first.
    boxes_area_topk_log, boxes_ind = torch.topk(boxes_areas_log, boxes_areas_log.size(0))

    if self.wh_area_process == 'norm':
        # 'norm' gives every box the same regression weight budget.
        boxes_area_topk_log[:] = 1.

    gt_boxes = gt_boxes[boxes_ind]
    gt_labels = gt_labels[boxes_ind]

    # Feature-map-scale boxes, clamped to the map; only used for the radii.
    feat_gt_boxes = gt_boxes / self.down_ratio
    feat_gt_boxes[:, [0, 2]] = torch.clamp(feat_gt_boxes[:, [0, 2]], min=0, max=output_w - 1)
    feat_gt_boxes[:, [1, 3]] = torch.clamp(feat_gt_boxes[:, [1, 3]], min=0, max=output_h - 1)
    feat_hs, feat_ws = (feat_gt_boxes[:, 3] - feat_gt_boxes[:, 1],
                        feat_gt_boxes[:, 2] - feat_gt_boxes[:, 0])

    # we calc the center and ignore area based on the gt-boxes of the origin scale
    # no peak will fall between pixels
    ct_ints = (torch.stack([(gt_boxes[:, 0] + gt_boxes[:, 2]) / 2,
                            (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2],
                           dim=1) / self.down_ratio).to(torch.int)

    # alpha sizes the heatmap gaussian; beta sizes the regression-target region.
    h_radiuses_alpha = (feat_hs / 2. * self.alpha).int()
    w_radiuses_alpha = (feat_ws / 2. * self.alpha).int()
    if self.wh_gaussian and self.alpha != self.beta:
        h_radiuses_beta = (feat_hs / 2. * self.beta).int()
        w_radiuses_beta = (feat_ws / 2. * self.beta).int()

    if not self.wh_gaussian:
        # calculate positive (center) regions
        r1 = (1 - self.beta) / 2
        ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = calc_region(gt_boxes.transpose(0, 1), r1)
        ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = [
            torch.round(x.float() / self.down_ratio).int()
            for x in [ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s]
        ]
        ctr_x1s, ctr_x2s = [torch.clamp(x, max=output_w - 1) for x in [ctr_x1s, ctr_x2s]]
        ctr_y1s, ctr_y2s = [torch.clamp(y, max=output_h - 1) for y in [ctr_y1s, ctr_y2s]]

    # larger boxes have lower priority than small boxes.
    for k in range(boxes_ind.shape[0]):
        cls_id = gt_labels[k] - 1

        fake_heatmap = fake_heatmap.zero_()
        self.draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                    h_radiuses_alpha[k].item(), w_radiuses_alpha[k].item())
        heatmap[cls_id] = torch.max(heatmap[cls_id], fake_heatmap)

        if self.wh_gaussian:
            if self.alpha != self.beta:
                # Redraw with the beta radii to get the regression region.
                fake_heatmap = fake_heatmap.zero_()
                self.draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                            h_radiuses_beta[k].item(),
                                            w_radiuses_beta[k].item())
            box_target_inds = fake_heatmap > 0
        else:
            ctr_x1, ctr_y1, ctr_x2, ctr_y2 = ctr_x1s[k], ctr_y1s[k], ctr_x2s[k], ctr_y2s[k]
            # Boolean mask: uint8 masks are deprecated for tensor indexing.
            box_target_inds = torch.zeros_like(fake_heatmap, dtype=torch.bool)
            box_target_inds[ctr_y1:ctr_y2 + 1, ctr_x1:ctr_x2 + 1] = True

        if self.wh_agnostic:
            box_target[:, box_target_inds] = gt_boxes[k][:, None]
            # A single weight plane is shared by all classes.
            cls_id = 0
        else:
            box_target[(cls_id * 4):((cls_id + 1) * 4), box_target_inds] = gt_boxes[k][:, None]

        if self.wh_gaussian:
            # Distribute the box's area weight proportionally to the gaussian.
            local_heatmap = fake_heatmap[box_target_inds]
            ct_div = local_heatmap.sum()
            local_heatmap *= boxes_area_topk_log[k]
            reg_weight[cls_id, box_target_inds] = local_heatmap / ct_div
        else:
            # Distribute the box's area weight uniformly over its region.
            reg_weight[cls_id, box_target_inds] = \
                boxes_area_topk_log[k] / box_target_inds.sum().float()

    return heatmap, box_target, reg_weight
def target_single_image(self, gt_boxes, gt_labels, feat_shape):
    """Build heatmap / box-regression / loss-weight targets for a single image.

    Boxes are visited from largest to smallest (processed) area, so where
    target regions overlap, a smaller box overwrites a larger one. Collision
    statistics of the box targets are reported through ``add_summary``.

    Args:
        gt_boxes: tensor, (num_gt, 4) boxes as (x1, y1, x2, y2) at the
            input-image scale (divided by ``self.down_ratio`` to reach the
            feature map).
        gt_labels: tensor, (num_gt,). The channel used is ``label - 1``, so
            labels are presumably 1-based (confirm).
        feat_shape: tuple, (h, w) of the output feature map.

    Returns:
        heatmap: tensor, (num_fg, h, w). With ``fovea_hm`` it is 1 inside each
            box's center region and -1 in its ignore region (where not already
            set); otherwise it is the element-wise max of per-box gaussians.
        box_target: tensor, (wh_planes, h, w); 4 planes when ``wh_agnostic``,
            otherwise 4 planes per class. Pixels without a target hold -1.
        centerness: tensor, (1, h, w); only filled when ``fovea_hm``.
        wh_weight: tensor, (wh_planes // 4, h, w) regression loss weights.
        hm_weight: tensor, (wh_planes // 4, h, w) heatmap loss weights; only
            filled when ``ct_version`` is falsy.
    """
    output_h, output_w = feat_shape
    heatmap_channel = self.num_fg

    heatmap = gt_boxes.new_zeros((heatmap_channel, output_h, output_w))
    fake_heatmap = gt_boxes.new_zeros((output_h, output_w))
    # -1 marks pixels that carry no box-regression target.
    box_target = gt_boxes.new_ones((self.wh_planes, output_h, output_w)) * -1
    wh_weight = gt_boxes.new_zeros((self.wh_planes // 4, output_h, output_w))
    hm_weight = gt_boxes.new_zeros((self.wh_planes // 4, output_h, output_w))
    centerness = gt_boxes.new_zeros((1, output_h, output_w))

    if self.wh_area_process == 'log':
        boxes_areas_log = bbox_areas(gt_boxes).log()
    elif self.wh_area_process == 'sqrt':
        boxes_areas_log = bbox_areas(gt_boxes).sqrt()
    else:
        boxes_areas_log = bbox_areas(gt_boxes)
    # topk over all boxes == sort by processed area, largest first.
    boxes_area_topk_log, boxes_ind = torch.topk(boxes_areas_log, boxes_areas_log.size(0))

    if self.wh_area_process == 'norm':
        boxes_area_topk_log[:] = 1.

    gt_boxes = gt_boxes[boxes_ind]
    gt_labels = gt_labels[boxes_ind]

    # Feature-map-scale boxes, clamped to the map.
    feat_gt_boxes = gt_boxes / self.down_ratio
    feat_gt_boxes[:, [0, 2]] = torch.clamp(feat_gt_boxes[:, [0, 2]], min=0, max=output_w - 1)
    feat_gt_boxes[:, [1, 3]] = torch.clamp(feat_gt_boxes[:, [1, 3]], min=0, max=output_h - 1)
    feat_hs, feat_ws = (feat_gt_boxes[:, 3] - feat_gt_boxes[:, 1],
                        feat_gt_boxes[:, 2] - feat_gt_boxes[:, 0])

    # Shrink factors passed to calc_region for the center and ignore regions.
    r1 = (1 - self.center_ratio) / 2
    r2 = (1 - self.ignore_ratio) / 2

    # we calc the center and ignore area based on the gt-boxes of the origin scale
    # no peak will fall between pixels
    ct_ints = (torch.stack([(gt_boxes[:, 0] + gt_boxes[:, 2]) / 2,
                            (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2],
                           dim=1) / self.down_ratio).to(torch.int)

    if self.hm_center_ratio is None:
        radiuses = torch.clamp(gaussian_radius((feat_hs.ceil(), feat_ws.ceil())), min=0)
        # Stretch the radius by the box aspect ratio into h / w radii.
        hw_ratio_sqrt = (feat_hs / feat_ws).sqrt()
        h_radiuses = (radiuses * hw_ratio_sqrt).int()
        w_radiuses = (radiuses / hw_ratio_sqrt).int()
        if self.ct_gaussian:
            radiuses = radiuses.int()
    else:
        h_radiuses = (feat_hs * self.hm_center_ratio).int()
        w_radiuses = (feat_ws * self.hm_center_ratio).int()

    # Separate radii for the regression region when it differs from the heatmap one.
    if (self.center_ratio / 2 != self.hm_center_ratio) and self.wh_heatmap:
        wh_h_radiuses = (feat_hs * self.center_ratio / 2).int()
        wh_w_radiuses = (feat_ws * self.center_ratio / 2).int()

    # calculate positive (center) regions
    ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = calc_region(
        gt_boxes.transpose(0, 1), r1, use_round=False)
    ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = [
        torch.round(x / self.down_ratio).int()
        for x in [ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s]
    ]
    ctr_x1s, ctr_x2s = [torch.clamp(x, max=output_w - 1) for x in [ctr_x1s, ctr_x2s]]
    ctr_y1s, ctr_y2s = [torch.clamp(y, max=output_h - 1) for y in [ctr_y1s, ctr_y2s]]
    ctr_xs_diff, ctr_ys_diff = ctr_x2s - ctr_x1s + 1, ctr_y2s - ctr_y1s + 1

    if self.fill_small:
        # Boxes whose center region is at most 4x4 pixels use the ignore region instead.
        are_fill_small = (ctr_ys_diff <= 4) & (ctr_xs_diff <= 4)

    collide_pixels_summary = 0
    # larger boxes have lower priority than small boxes.
    for k in range(boxes_ind.shape[0]):
        cls_id = gt_labels[k] - 1
        ctr_x1, ctr_y1, ctr_x2, ctr_y2 = ctr_x1s[k], ctr_y1s[k], ctr_x2s[k], ctr_y2s[k]
        ctr_x_diff, ctr_y_diff = ctr_xs_diff[k], ctr_ys_diff[k]

        if self.fovea_hm or (self.fill_small and are_fill_small[k]):
            ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                feat_gt_boxes[k], r2, (output_h, output_w))
            if not self.fovea_hm:
                ctr_x1, ctr_y1, ctr_x2, ctr_y2 = ignore_x1, ignore_y1, ignore_x2, ignore_y2

        fake_heatmap = fake_heatmap.zero_()
        if self.ct_gaussian:
            draw_umich_gaussian(fake_heatmap, ct_ints[k], radiuses[k].item())
        else:
            draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                   h_radiuses[k].item(), w_radiuses[k].item())

        if self.fovea_hm:
            # ignore_mask_box is necessary to prevent the ignore areas covering the
            # pos areas of larger boxes
            ignore_mask_box = (heatmap[cls_id, ignore_y1:ignore_y2 + 1,
                                       ignore_x1:ignore_x2 + 1] == 0)
            heatmap[cls_id, ignore_y1:ignore_y2 + 1,
                    ignore_x1:ignore_x2 + 1][ignore_mask_box] = -1
            heatmap[cls_id, ctr_y1:ctr_y2 + 1, ctr_x1:ctr_x2 + 1] = 1

            centerness[0] = torch.max(centerness[0], fake_heatmap)
        else:
            heatmap[cls_id] = torch.max(heatmap[cls_id], fake_heatmap)

        if self.wh_heatmap:
            if self.hm_center_ratio != self.center_ratio / 2:
                fake_heatmap = fake_heatmap.zero_()
                draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                       wh_h_radiuses[k].item(), wh_w_radiuses[k].item())
            box_target_inds = fake_heatmap > 0
        else:
            # Boolean mask: uint8 masks are deprecated for tensor indexing.
            box_target_inds = torch.zeros_like(fake_heatmap, dtype=torch.bool)
            box_target_inds[ctr_y1:ctr_y2 + 1, ctr_x1:ctr_x2 + 1] = True

        # Count target entries already claimed by a previous (larger) box.
        if self.wh_agnostic:
            collide_pixels_summary += (box_target[:, box_target_inds] > 0).sum()
            box_target[:, box_target_inds] = gt_boxes[k][:, None]
        else:
            collide_pixels_summary += (box_target[(cls_id * 4):(cls_id + 1) * 4,
                                                  box_target_inds] > 0).sum()
            box_target[(cls_id * 4):((cls_id + 1) * 4), box_target_inds] = gt_boxes[k][:, None]

        local_heatmap = fake_heatmap[box_target_inds]
        ct_div = local_heatmap.sum()
        local_heatmap *= boxes_area_topk_log[k]

        if self.wh_agnostic:
            # A single weight plane is shared by all classes.
            cls_id = 0

        # v2: gaussian-proportional weights whenever possible; v3/v4: only for
        # center regions larger than 6x6 pixels. Otherwise weights are uniform.
        large_region = ctr_y_diff > 6 and ctr_x_diff > 6
        if ct_div > 0 and (self.avg_wh_weightv2 or
                           ((self.avg_wh_weightv3 or self.avg_wh_weightv4) and large_region)):
            wh_weight[cls_id, box_target_inds] = local_heatmap / ct_div
        else:
            wh_weight[cls_id, box_target_inds] = \
                boxes_area_topk_log[k] / box_target_inds.sum().float()

        ct_x, ct_y = ct_ints[k, 0].item(), ct_ints[k, 1].item()
        if self.avg_wh_weightv4:
            wh_weight[cls_id, ct_y, ct_x] = boxes_area_topk_log[k]

        if not self.ct_version:
            # Half of the weight on the center pixel; the other half shared by the
            # remaining pixels whose gaussian value exceeds 0.9 (presumably the
            # center is one of them — it is overwritten just below).
            target_loc = fake_heatmap > 0.9
            hm_target_num = target_loc.sum().float()
            hm_weight[cls_id, target_loc] = 1 / (2 * (hm_target_num - 1))
            hm_weight[cls_id, ct_y, ct_x] = 1 / 2.

    add_summary('box_target', collide_pixels=collide_pixels_summary)
    pos_pixels_summary = (box_target > 0).sum()
    add_summary('box_target', pos_pixels=pos_pixels_summary)
    # Clamp the denominator so an image without positive targets logs 0, not NaN.
    add_summary('box_target',
                collide_ratio=collide_pixels_summary / pos_pixels_summary.float().clamp(min=1))
    return heatmap, box_target, centerness, wh_weight, hm_weight