def loss_single(self, pred_hm, pred_wh, heatmap, box_target, wh_weight,
                down_ratio, hm_weight_factor, wh_weight_factor):
    """Compute the heatmap and box losses for a single prediction level.

    Args:
        pred_hm: raw heatmap logits; its last two dims are read as (h, w).
            Sigmoid is applied in place, so the caller's tensor is modified.
        pred_wh: predicted box offsets; channels 0-1 and 2-3 on dim 1 are
            subtracted from / added to the grid locations.
        heatmap: heatmap target handed to ``ct_focal_loss``.
        box_target: box target, permuted the same way as the decoded boxes.
        wh_weight: per-location box weights, reshaped to (-1, h, w).
        down_ratio: forwarded to ``self.get_down_ratio`` to get the grid step.
        hm_weight_factor: multiplier applied to the heatmap loss.
        wh_weight_factor: multiplier applied to the box loss.

    Returns:
        tuple: ``(loss_cls, loss_bbox)``.
    """
    feat_h, feat_w = pred_hm.shape[2:]

    # Clamp away from 0/1 so the focal loss never sees saturated scores.
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    loss_cls = ct_focal_loss(pred_hm, heatmap) * hm_weight_factor

    # Grid of locations in input-image units, one per feature-map cell.
    stride = self.get_down_ratio(down_ratio)
    grid_opts = dict(dtype=torch.float32, device=heatmap.device)
    xs = torch.arange(0, (feat_w - 1) * stride + 1, stride, **grid_opts)
    ys = torch.arange(0, (feat_h - 1) * stride + 1, stride, **grid_opts)
    grid_y, grid_x = torch.meshgrid(ys, xs)
    base_loc = torch.stack((grid_x, grid_y), dim=0)  # (2, h, w)

    # Decode offsets into corner boxes: (batch, h, w, 4).
    corner_min = base_loc - pred_wh[:, [0, 1]]
    corner_max = base_loc + pred_wh[:, [2, 3]]
    pred_boxes = torch.cat((corner_min, corner_max), dim=1).permute(0, 2, 3, 1)
    gt_boxes = box_target.permute(0, 2, 3, 1)

    weight_map = wh_weight.view(-1, feat_h, feat_w)
    # Small epsilon keeps the normaliser non-zero when no location is weighted.
    normalizer = weight_map.sum() + 1e-4
    loss_bbox = giou_loss(
        pred_boxes, gt_boxes, weight_map,
        avg_factor=normalizer) * wh_weight_factor

    return loss_cls, loss_bbox
def __call__(self, pred_hm, pred_heights, pred_reg_xoffset, pred_reg_yoffset,
             pred_pose, heatmap, heights, reg_xoffset, reg_yoffset, pose,
             reg_mask, ind):
    """Compute heatmap, height, x/y-offset and pose losses.

    Args:
        pred_hm: tensor, (batch, cls, h, w). Sigmoid is applied in place.
        pred_heights: tensor, (batch, 3, h, w).
        pred_reg_xoffset: tensor, (batch, 3, h, w).
        pred_reg_yoffset: tensor, (batch, 3, h, w).
        pred_pose: tensor, (batch, 8, h, w).
        heatmap: tensor, (batch, cls, h, w).
        heights: tensor, (batch, max_obj, 3).
        reg_xoffset: tensor, (batch, max_obj, 3).
        reg_yoffset: tensor, (batch, max_obj, 3).
        pose: tensor, (batch, max_obj).
        reg_mask: tensor, (batch, max_obj, 3).
        ind: tensor, (batch, max_obj).

    Returns:
        tuple: ``(hm_loss, heights_loss, xoff_loss, yoff_loss, pose_loss)``.
    """
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

    # Pick the per-object predictions: (batch, 3, h, w) => (batch, max_obj, 3).
    obj_heights, obj_xoff, obj_yoff = [
        tranpose_and_gather_feat(dense, ind)
        for dense in (pred_heights, pred_reg_xoffset, pred_reg_yoffset)
    ]

    # cross_entropy wants (N, C, d) order, so gather along the flattened
    # spatial axis instead: (batch, 8, h, w) => (batch, 8, max_obj).
    n_batch, n_pose = pred_pose.shape[0], pred_pose.shape[1]
    flat_pose = pred_pose.view(n_batch, n_pose, -1)
    pose_index = ind.unsqueeze(1).expand(-1, flat_pose.shape[1], -1)
    pred_pose = flat_pose.gather(2, pose_index)

    reg_weights = reg_mask.float()
    reg_norm = reg_weights.sum() + 1e-4

    heights_loss = weighted_l1(
        obj_heights, heights, reg_weights,
        avg_factor=reg_norm) * self.heights_weight
    xoff_loss = weighted_l1(
        obj_xoff, reg_xoffset, reg_weights,
        avg_factor=reg_norm) * self.xoff_weight
    yoff_loss = weighted_l1(
        obj_yoff, reg_yoffset, reg_weights,
        avg_factor=reg_norm) * self.yoff_weight

    # Pose is one label per object, so use the first mask channel only.
    per_object = reg_weights[..., 0]
    per_object_norm = per_object.sum() + 1e-4
    pose_loss = cross_entropy(
        pred_pose, pose, per_object,
        avg_factor=per_object_norm) * self.pose_weight

    return hm_loss, heights_loss, xoff_loss, yoff_loss, pose_loss
def loss_calc(self, pred_hm, pred_wh, pred_off, heatmap, box_target,
              wh_weight, off_target):
    """Compute heatmap, box (GIoU) and offset losses.

    Args:
        pred_hm: tensor, (batch, 80, h, w). Sigmoid is applied in place.
        pred_wh: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
        pred_off: offset prediction handed to ``self.crit_off``.
        heatmap: tensor, same as pred_hm.
        box_target: tensor, same as pred_wh.
        wh_weight: tensor, same as pred_wh.
        off_target: offset target; entries greater than zero are the ones
            that receive weight in the offset loss.

    Returns:
        tuple: ``(hm_loss, wh_loss, off_loss)``.
    """
    feat_h, feat_w = pred_hm.shape[2:]
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

    box_weights = wh_weight.view(-1, feat_h, feat_w)
    box_norm = box_weights.sum() + 1e-4

    # The location grid is cached on the instance and rebuilt only when the
    # feature-map size changes.
    cached = self.base_loc
    if (cached is None or feat_h != cached.shape[1]
            or feat_w != cached.shape[2]):
        stride = self.down_ratio
        grid_opts = dict(dtype=torch.float32, device=heatmap.device)
        xs = torch.arange(0, (feat_w - 1) * stride + 1, stride, **grid_opts)
        ys = torch.arange(0, (feat_h - 1) * stride + 1, stride, **grid_opts)
        grid_y, grid_x = torch.meshgrid(ys, xs)
        self.base_loc = torch.stack((grid_x, grid_y), dim=0)  # (2, h, w)

    # Decode to corner boxes: (batch, h, w, 4).
    corner_min = self.base_loc - pred_wh[:, [0, 1]]
    corner_max = self.base_loc + pred_wh[:, [2, 3]]
    pred_boxes = torch.cat((corner_min, corner_max), dim=1).permute(0, 2, 3, 1)
    gt_boxes = box_target.permute(0, 2, 3, 1)  # (batch, h, w, 4)
    wh_loss = giou_loss(
        pred_boxes, gt_boxes, box_weights,
        avg_factor=box_norm) * self.wh_weight

    # Only positive offset targets contribute; the normaliser is half the
    # number of non-zero mask entries.
    mask_off = (off_target.clone() > 0).float()
    num_mask = torch.nonzero(mask_off).size(0) * 0.5
    off_loss = self.crit_off(
        pred_off, off_target, weight=mask_off,
        avg_factor=num_mask) * self.off_weight

    return hm_loss, wh_loss, off_loss
def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind, reg_offset, center_location):
    """Compute the heatmap focal loss.

    NOTE(review): as written, this method returns right after the heatmap
    loss; the width/height loss section below the first ``return`` is
    unreachable. Callers therefore receive a single value, not the
    ``(hm_loss, wh_loss)`` tuple of the final statement. Confirm whether
    the early return is intentional before removing either part.

    Args:
        pred_hm: tensor, (batch, 80, h, w). Sigmoid is applied in place.
        pred_wh: tensor, (batch, 2, h, w).
        heatmap: tensor, (batch, 80, h, w).
        wh: tensor, (batch, max_obj, 2).
        reg_mask: tensor, tensor <=> img, (batch, max_obj).
        ind: tensor, (batch, max_obj).
        reg_offset: tensor, (batch, max_obj, 2). Not read by this method.
        center_location: tensor, (batch, max_obj, 2). Only useful when using GIOU.

    Returns:
        hm_loss: the weighted heatmap focal loss.
    """
    H, W = pred_hm.shape[2:]
    # Clamp away from 0/1 so the focal loss never sees saturated scores.
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight
    return hm_loss

    # ---- Everything below is dead code because of the return above. ----

    # (batch, 2, h, w) => (batch, max_obj, 2)
    pred = tranpose_and_gather_feat(pred_wh, ind)
    mask = reg_mask.unsqueeze(2).expand_as(pred).float()
    avg_factor = mask.sum() + 1e-4

    if self.use_giou:
        # Build corner boxes from centre +/- half size.
        pred_boxes = torch.cat(
            (center_location - pred / 2., center_location + pred / 2.), dim=2)
        # Target bottom-right corner, clamped to the feature-map extent.
        box_br = center_location + wh / 2.
        box_br[:, :, 0] = box_br[:, :, 0].clamp(max=W - 1)
        box_br[:, :, 1] = box_br[:, :, 1].clamp(max=H - 1)
        boxes = torch.cat(
            (torch.clamp(center_location - wh / 2., min=0), box_br), dim=2)
        # One weight per object rather than per coordinate.
        mask_no_expand = mask[:, :, 0]
        wh_loss = giou_loss(pred_boxes, boxes, mask_no_expand) * self.giou_weight
    else:
        if self.use_smooth_l1:
            wh_loss = smooth_l1_loss(
                pred, wh, mask, avg_factor=avg_factor) * self.wh_weight
        else:
            wh_loss = weighted_l1(pred, wh, mask, avg_factor=avg_factor) * self.wh_weight

    return hm_loss, wh_loss
def loss_single(self, pred_hm, pred_wh, heatmap, box_target, wh_weight,
                down_ratio, base_loc_name, hm_weight_factor,
                wh_weight_factor, focal_loss_beta):
    """Compute heatmap and box losses for one level, with a cached grid.

    Args:
        pred_hm: raw heatmap logits; its last two dims are read as (h, w).
            Sigmoid is applied in place.
        pred_wh: predicted box offsets; channels 0-1 and 2-3 on dim 1 are
            subtracted from / added to the grid locations.
        heatmap: heatmap target handed to ``ct_focal_loss``.
        box_target: box target, permuted the same way as the decoded boxes.
        wh_weight: per-location box weights, reshaped to (-1, h, w).
        down_ratio: step between neighbouring grid locations.
        base_loc_name: name of the instance attribute used to cache the
            location grid for this level.
        hm_weight_factor: multiplier applied to the heatmap loss.
        wh_weight_factor: multiplier applied to the box loss.
        focal_loss_beta: forwarded to ``ct_focal_loss`` as ``beta``.

    Returns:
        tuple: ``(loss_cls, loss_bbox)``.
    """
    feat_h, feat_w = pred_hm.shape[2:]
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    loss_cls = ct_focal_loss(
        pred_hm, heatmap, beta=focal_loss_beta) * hm_weight_factor

    # Reuse the grid stored under `base_loc_name` unless it is missing or
    # was built for a different feature-map size.
    base_loc = getattr(self, base_loc_name)
    if (base_loc is None or feat_h != base_loc.shape[1]
            or feat_w != base_loc.shape[2]):
        stride = down_ratio
        grid_opts = dict(dtype=torch.float32, device=heatmap.device)
        xs = torch.arange(0, (feat_w - 1) * stride + 1, stride, **grid_opts)
        ys = torch.arange(0, (feat_h - 1) * stride + 1, stride, **grid_opts)
        grid_y, grid_x = torch.meshgrid(ys, xs)
        base_loc = torch.stack((grid_x, grid_y), dim=0)  # (2, h, w)
        setattr(self, base_loc_name, base_loc)

    # Decode to corner boxes: (batch, h, w, 4).
    corner_min = base_loc - pred_wh[:, [0, 1]]
    corner_max = base_loc + pred_wh[:, [2, 3]]
    pred_boxes = torch.cat((corner_min, corner_max), dim=1).permute(0, 2, 3, 1)
    gt_boxes = box_target.permute(0, 2, 3, 1)

    weight_map = wh_weight.view(-1, feat_h, feat_w)
    normalizer = weight_map.sum() + 1e-4
    loss_bbox = giou_loss(
        pred_boxes, gt_boxes, weight_map,
        avg_factor=normalizer) * wh_weight_factor

    return loss_cls, loss_bbox
def loss_calc(self, pred_hm, pred_wh, pred_hm_2, pred_wh_2, pred_hm_3,
              pred_wh_3, heatmap, box_target, wh_weight, heatmap_2,
              box_target_2, wh_weight_2, heatmap_3, box_target_3,
              wh_weight_3):
    """Compute heatmap and box losses for three prediction heads.

    All three heads share the feature-map size of ``pred_hm`` and the same
    cached location grid. The ``_2`` / ``_3`` arguments mirror the first
    head's arguments.

    Args:
        pred_hm: tensor, (batch, 80, h, w). Sigmoid is applied in place
            (likewise for pred_hm_2 and pred_hm_3).
        pred_wh: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
        heatmap: tensor, same as pred_hm.
        box_target: tensor, same as pred_wh.
        wh_weight: tensor, same as pred_wh.

    Returns:
        tuple: ``(hm_loss, wh_loss, hm_loss_2, wh_loss_2, hm_loss_3,
        wh_loss_3)``.
    """
    feat_h, feat_w = pred_hm.shape[2:]

    def _squash(logits):
        # In-place sigmoid followed by a clamp away from 0/1.
        return torch.clamp(logits.sigmoid_(), min=1e-4, max=1 - 1e-4)

    pred_hm = _squash(pred_hm)
    pred_hm_2 = _squash(pred_hm_2)
    pred_hm_3 = _squash(pred_hm_3)

    hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight
    hm_loss_2 = ct_focal_loss(pred_hm_2, heatmap_2) * self.hm_weight_2
    hm_loss_3 = ct_focal_loss(pred_hm_3, heatmap_3) * self.hm_weight_3

    def _weights(raw):
        # Reshape to (-1, h, w) and return it with its normaliser.
        reshaped = raw.view(-1, feat_h, feat_w)
        return reshaped, reshaped.sum() + 1e-4

    weights_1, norm_1 = _weights(wh_weight)
    weights_2, norm_2 = _weights(wh_weight_2)
    weights_3, norm_3 = _weights(wh_weight_3)

    # Rebuild the cached location grid only when the size changed.
    cached = self.base_loc
    if (cached is None or feat_h != cached.shape[1]
            or feat_w != cached.shape[2]):
        stride = self.down_ratio
        grid_opts = dict(dtype=torch.float32, device=heatmap.device)
        xs = torch.arange(0, (feat_w - 1) * stride + 1, stride, **grid_opts)
        ys = torch.arange(0, (feat_h - 1) * stride + 1, stride, **grid_opts)
        grid_y, grid_x = torch.meshgrid(ys, xs)
        self.base_loc = torch.stack((grid_x, grid_y), dim=0)  # (2, h, w)

    def _decode(offsets):
        # Offsets -> corner boxes, (batch, h, w, 4).
        corners = torch.cat((self.base_loc - offsets[:, [0, 1]],
                             self.base_loc + offsets[:, [2, 3]]), dim=1)
        return corners.permute(0, 2, 3, 1)

    decoded_1 = _decode(pred_wh)
    decoded_2 = _decode(pred_wh_2)
    decoded_3 = _decode(pred_wh_3)

    # Targets in the same (batch, h, w, 4) layout.
    target_1 = box_target.permute(0, 2, 3, 1)
    target_2 = box_target_2.permute(0, 2, 3, 1)
    target_3 = box_target_3.permute(0, 2, 3, 1)

    wh_loss = giou_loss(
        decoded_1, target_1, weights_1, avg_factor=norm_1) * self.wh_weight
    wh_loss_2 = giou_loss(
        decoded_2, target_2, weights_2, avg_factor=norm_2) * self.wh_weight_2
    wh_loss_3 = giou_loss(
        decoded_3, target_3, weights_3, avg_factor=norm_3) * self.wh_weight_3

    return hm_loss, wh_loss, hm_loss_2, wh_loss_2, hm_loss_3, wh_loss_3
def loss_calc(self, pred_hm_large, pred_hm_little, pred_wh_large,
              pred_wh_little, heatmap_large, heatmap_little,
              box_target_large, box_target_little, wh_weight_large,
              wh_weight_little):
    """Compute summed heatmap and box losses over the large/little heads.

    Both heads use the feature-map size of ``pred_hm_little`` and share
    one cached location grid.

    Args:
        pred_hm_*: tensor, (batch, 80, h, w). Sigmoid is applied in place.
        pred_wh_*: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
        heatmap_*: tensor, same as pred_hm.
        box_target_*: tensor, same as pred_wh.
        wh_weight_*: tensor, same as pred_wh.

    Returns:
        tuple: ``(hm_loss, wh_loss)``, each the sum of both heads' terms.
    """
    feat_h, feat_w = pred_hm_little.shape[2:]

    pred_hm_large = torch.clamp(
        pred_hm_large.sigmoid_(), min=1e-4, max=1 - 1e-4)
    pred_hm_little = torch.clamp(
        pred_hm_little.sigmoid_(), min=1e-4, max=1 - 1e-4)

    hm_term_large = ct_focal_loss(pred_hm_large, heatmap_large) * self.hm_weight
    hm_term_little = ct_focal_loss(pred_hm_little, heatmap_little) * self.hm_weight
    hm_loss = 1 * (hm_term_large + hm_term_little)

    weights_large = wh_weight_large.view(-1, feat_h, feat_w)
    weights_little = wh_weight_little.view(-1, feat_h, feat_w)
    norm_large = weights_large.sum() + 1e-4
    norm_little = weights_little.sum() + 1e-4

    # Rebuild the cached location grid only when the size changed.
    cached = self.base_loc
    if (cached is None or feat_h != cached.shape[1]
            or feat_w != cached.shape[2]):
        stride = self.down_ratio
        grid_opts = dict(dtype=torch.float32, device=heatmap_little.device)
        xs = torch.arange(0, (feat_w - 1) * stride + 1, stride, **grid_opts)
        ys = torch.arange(0, (feat_h - 1) * stride + 1, stride, **grid_opts)
        grid_y, grid_x = torch.meshgrid(ys, xs)
        self.base_loc = torch.stack((grid_x, grid_y), dim=0)  # (2, h, w)

    def _decode(offsets):
        # Offsets -> corner boxes, (batch, h, w, 4).
        corners = torch.cat((self.base_loc - offsets[:, [0, 1]],
                             self.base_loc + offsets[:, [2, 3]]), dim=1)
        return corners.permute(0, 2, 3, 1)

    decoded_large = _decode(pred_wh_large)
    decoded_little = _decode(pred_wh_little)

    # Targets in the same (batch, h, w, 4) layout.
    target_large = box_target_large.permute(0, 2, 3, 1)
    target_little = box_target_little.permute(0, 2, 3, 1)

    wh_term_large = giou_loss_ct(
        decoded_large, target_large, weights_large,
        avg_factor=norm_large) * self.wh_weight
    wh_term_little = giou_loss_ct(
        decoded_little, target_little, weights_little,
        avg_factor=norm_little) * self.wh_weight
    wh_loss = wh_term_large + wh_term_little

    return hm_loss, wh_loss
def __call__(self, pred_hm, pred_wh, pred_centerness, heatmap, box_target,
             centerness, wh_weight, hm_weight):
    """Compute heatmap, box, centerness and merge losses.

    Two modes are selected by ``self.fovea_hm``:

    * off: a focal loss on the sigmoid heatmap; centerness and merge
      losses are zero tensors.
    * on: either separate heatmap / centerness / merge losses, or (when
      ``self.only_merge`` is set) the merge loss alone.

    The box loss is a GIoU loss on boxes decoded from a location grid that
    is cached on ``self.base_loc``. The grid is rebuilt whenever it is
    missing or the feature-map size differs from the cached one, so a
    change of input resolution between calls cannot reuse a stale grid.

    Args:
        pred_hm: tensor, (batch, 80, h, w). Sigmoid is applied in place
            when ``self.fovea_hm`` is off.
        pred_wh: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
        pred_centerness: tensor or None, (batch, 1, h, w).
        heatmap: tensor, (batch, 80, h, w).
        box_target: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
        centerness: tensor or None, (batch, 1, h, w).
        wh_weight: tensor or None, (batch, 80, h, w).
        hm_weight: heatmap weight forwarded to ``ct_focal_loss`` unless
            ``self.ct_version`` is set, in which case None is passed.

    Returns:
        tuple: ``(hm_loss, wh_loss, centerness_loss, merge_loss)``.
    """
    # Periodic feature-map summaries; these do not affect the losses.
    if every_n_local_step(100):
        pred_hm_summary = torch.clamp(
            torch.sigmoid(pred_hm), min=1e-4, max=1 - 1e-4)
        gt_hm_summary = heatmap.clone()
        if self.fovea_hm:
            if not self.only_merge:
                pred_ctn_summary = torch.clamp(
                    torch.sigmoid(pred_centerness), min=1e-4, max=1 - 1e-4)
                add_feature_summary(
                    'centernet/centerness',
                    pred_ctn_summary.detach().cpu().numpy(), type='f')
                add_feature_summary(
                    'centernet/merge',
                    (pred_ctn_summary * pred_hm_summary).detach().cpu().numpy(),
                    type='max')
            add_feature_summary('centernet/gt_centerness',
                                centerness.detach().cpu().numpy(), type='f')
            add_feature_summary('centernet/gt_merge',
                                (centerness * gt_hm_summary).detach().cpu().numpy(),
                                type='max')
        add_feature_summary('centernet/heatmap',
                            pred_hm_summary.detach().cpu().numpy())
        add_feature_summary('centernet/gt_heatmap',
                            gt_hm_summary.detach().cpu().numpy())

    H, W = pred_hm.shape[2:]
    if not self.fovea_hm:
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_weight = None if self.ct_version else hm_weight
        hm_loss = ct_focal_loss(pred_hm, heatmap, hm_weight=hm_weight) * self.hm_weight
        centerness_loss = hm_loss.new_tensor([0.])
        merge_loss = hm_loss.new_tensor([0.])
    else:
        # Locations with a negative heatmap target get zero weight.
        care_mask = (heatmap >= 0).float()
        avg_factor = torch.sum(heatmap > 0).float().item() + 1e-6
        if not self.only_merge:
            hm_loss = py_sigmoid_focal_loss(
                pred_hm, heatmap, care_mask, reduction='sum') / avg_factor * self.hm_weight
            pred_centerness = torch.clamp(
                torch.sigmoid(pred_centerness), min=1e-4, max=1 - 1e-4)
            centerness_loss = ct_focal_loss(
                pred_centerness, centerness, gamma=2.) * self.ct_weight
            merge_loss = ct_focal_loss(
                torch.clamp(torch.sigmoid(pred_hm) * pred_centerness,
                            min=1e-4, max=1 - 1e-4),
                heatmap * centerness,
                weight=(heatmap >= 0).float()) * self.merge_weight
        else:
            hm_loss = pred_hm.new_tensor([0.])
            centerness_loss = pred_hm.new_tensor([0.])
            merge_loss = ct_focal_loss(
                torch.clamp(torch.sigmoid(pred_hm), min=1e-4, max=1 - 1e-4),
                heatmap * centerness,
                weight=(heatmap >= 0).float()) * self.merge_weight

    if not self.wh_agnostic:
        # Class-specific boxes: fold the class axis into the batch axis.
        pred_wh = pred_wh.view(pred_wh.size(0) * pred_hm.size(1), 4, H, W)
        box_target = box_target.view(
            box_target.size(0) * pred_hm.size(1), 4, H, W)
    mask = wh_weight.view(-1, H, W)
    avg_factor = mask.sum() + 1e-4

    # Rebuild the cached grid when it is missing OR was built for a
    # different feature-map size (previously only the None case was
    # handled, so a resolution change reused a wrongly sized grid).
    if (self.base_loc is None or H != self.base_loc.shape[1]
            or W != self.base_loc.shape[2]):
        base_step = self.down_ratio
        shifts_x = torch.arange(0, (W - 1) * base_step + 1, base_step,
                                dtype=torch.float32, device=heatmap.device)
        shifts_y = torch.arange(0, (H - 1) * base_step + 1, base_step,
                                dtype=torch.float32, device=heatmap.device)
        shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
        self.base_loc = torch.stack((shift_x, shift_y), dim=0)  # (2, h, w)

    # (batch, h, w, 4)
    pred_boxes = torch.cat((self.base_loc - pred_wh[:, [0, 1]],
                            self.base_loc + pred_wh[:, [2, 3]]),
                           dim=1).permute(0, 2, 3, 1)
    # (batch, h, w, 4)
    boxes = box_target.permute(0, 2, 3, 1)
    wh_loss = giou_loss(pred_boxes, boxes, mask, avg_factor=avg_factor) * self.giou_weight

    return hm_loss, wh_loss, centerness_loss, merge_loss
def __call__(self, pred_hm, pred_wh, pred_reg_offset, heatmap, wh, reg_mask,
             ind, reg_offset, center_location):
    """Compute heatmap, size and (optional) centre-offset losses.

    Args:
        pred_hm: tensor, (batch, 80, h, w). Sigmoid is applied in place.
        pred_wh: tensor, (batch, 2, h, w).
        pred_reg_offset: None or tensor, (batch, 2, h, w).
        heatmap: tensor, (batch, 80, h, w).
        wh: tensor, (batch, max_obj, 2).
        reg_mask: tensor, tensor <=> img, (batch, max_obj).
        ind: tensor, (batch, max_obj).
        reg_offset: tensor, (batch, max_obj, 2).
        center_location: tensor, (batch, max_obj, 2). Only useful when using GIOU.

    Returns:
        tuple: ``(hm_loss, wh_loss, off_loss)``; ``off_loss`` is a zero
        tensor when ``self.use_reg_offset`` is off.
    """
    feat_h, feat_w = pred_hm.shape[2:]
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

    # Per-object size predictions: (batch, 2, h, w) => (batch, max_obj, 2).
    size_pred = tranpose_and_gather_feat(pred_wh, ind)
    obj_mask = reg_mask.unsqueeze(2).expand_as(size_pred).float()
    normalizer = obj_mask.sum() + 1e-4

    if self.use_giou:
        pred_boxes = torch.cat(
            (center_location - size_pred / 2.,
             center_location + size_pred / 2.), dim=2)
        # Target corners, clamped to the feature-map extent.
        target_br = center_location + wh / 2.
        target_br[..., 0] = target_br[..., 0].clamp(max=feat_w - 1)
        target_br[..., 1] = target_br[..., 1].clamp(max=feat_h - 1)
        target_tl = torch.clamp(center_location - wh / 2., min=0)
        gt_boxes = torch.cat((target_tl, target_br), dim=2)
        # GIoU takes one weight per object, not per coordinate.
        per_object = obj_mask[..., 0]
        wh_loss = giou_loss(pred_boxes, gt_boxes, per_object) * self.giou_weight
    elif self.use_smooth_l1:
        wh_loss = smooth_l1_loss(
            size_pred, wh, obj_mask, avg_factor=normalizer) * self.wh_weight
    else:
        wh_loss = weighted_l1(
            size_pred, wh, obj_mask, avg_factor=normalizer) * self.wh_weight

    off_loss = hm_loss.new_tensor(0.)
    if self.use_reg_offset:
        offset_pred = tranpose_and_gather_feat(pred_reg_offset, ind)
        off_loss = weighted_l1(
            offset_pred, reg_offset, obj_mask,
            avg_factor=normalizer) * self.off_weight
        positive_offsets = reg_offset[reg_offset > 0]
        add_summary('centernet', gt_reg_off=positive_offsets.mean().item())

    # Periodic feature-map summaries; these do not affect the losses.
    if every_n_local_step(500):
        add_feature_summary('centernet/heatmap',
                            pred_hm.detach().cpu().numpy())
        add_feature_summary('centernet/gt_heatmap',
                            heatmap.detach().cpu().numpy())
        if self.use_reg_offset:
            add_feature_summary('centernet/reg_offset',
                                pred_reg_offset.detach().cpu().numpy())

    return hm_loss, wh_loss, off_loss
def loss_calc(self, pred_feat, pred_hm, pred_wh, heatmap, box_target, wh_weight):
    """Compute heatmap and box losses, plus an optional second-stage box loss.

    When ``self.two_stage`` is set, the top-scoring first-stage boxes are
    pooled from ``pred_feat`` via ``self.align``, refined by ``self.wh2``
    and scored with a GIoU loss against their targets.

    Args:
        pred_feat: feature map passed to ``self.align`` in the second stage.
        pred_hm: tensor, (batch, 80, h, w). Sigmoid is applied in place.
        pred_wh: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
        heatmap: tensor, same as pred_hm.
        box_target: tensor, same as pred_wh.
        wh_weight: tensor, same as pred_wh.

    Returns:
        hm_loss
        wh_loss
        wh2_loss: a zero tensor unless the second stage ran on at least
            one box.
    """
    H, W = pred_hm.shape[2:]
    # Clamp away from 0/1 so the focal loss never sees saturated scores.
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

    mask = wh_weight.view(-1, H, W)
    avg_factor = mask.sum() + 1e-4

    # The location grid is cached and rebuilt only when the size changes.
    if self.base_loc is None or H != self.base_loc.shape[
            1] or W != self.base_loc.shape[2]:
        base_step = self.down_ratio
        shifts_x = torch.arange(0, (W - 1) * base_step + 1, base_step,
                                dtype=torch.float32, device=heatmap.device)
        shifts_y = torch.arange(0, (H - 1) * base_step + 1, base_step,
                                dtype=torch.float32, device=heatmap.device)
        shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
        self.base_loc = torch.stack((shift_x, shift_y), dim=0)  # (2, h, w)

    # (batch, h, w, 4)
    pred_boxes = torch.cat((self.base_loc - pred_wh[:, [0, 1]],
                            self.base_loc + pred_wh[:, [2, 3]]), dim=1).permute(0, 2, 3, 1)
    # (batch, h, w, 4)
    boxes = box_target.permute(0, 2, 3, 1)
    wh_loss = self.iou_loss(pred_boxes, boxes, mask, avg_factor=avg_factor) * self.wh_weight

    # Default when the second stage is disabled or selects no boxes.
    wh2_loss = wh_loss.new_zeros([1])
    if self.two_stage:
        # Keep the 100 highest-scoring peaks after NMS on the heatmap.
        heat = simple_nms(pred_hm)
        scores, inds, clses, ys, xs = self._topk(heat, topk=100)

        # Flatten the spatial dims so boxes can be gathered by peak index.
        pred_boxes_2 = pred_boxes.view(pred_boxes.size(0), -1, pred_boxes.size(3))
        boxes_2 = boxes.view(*pred_boxes_2.shape)
        inds = inds.unsqueeze(2).expand(inds.size(0), inds.size(1), pred_boxes_2.size(2))
        pred_boxes_2 = pred_boxes_2.gather(1, inds)  # (batch, 100, 4)
        boxes_2 = boxes_2.gather(1, inds)

        score_thr = 0.01
        scores_keep = scores > score_thr  # (batch, topk)

        # Batch index of every kept box, used as the first RoI column.
        batch_idx = pred_boxes_2.new_tensor(
            torch.arange(0., pred_boxes_2.shape[0], 1.)).view(-1, 1, 1).expand(
                pred_boxes_2.shape[0], pred_boxes_2.shape[1], 1)[scores_keep]
        pred_boxes_2 = pred_boxes_2[scores_keep]
        boxes_2 = boxes_2[scores_keep].detach()

        # Drop boxes whose target has any negative coordinate.
        # NOTE(review): presumably negative targets mark "no box here" — confirm.
        valid_boxes = (boxes_2 >= 0).min(1)[0]
        batch_idx = batch_idx[valid_boxes]  # (n, 1)
        pred_boxes_2 = pred_boxes_2[valid_boxes]  # (n, 4)
        boxes_2 = boxes_2[valid_boxes]  # (n, 4)
        # RoIs are detached: no gradient flows into stage one through them.
        roi_boxes = torch.cat((batch_idx, pred_boxes_2), dim=1).detach()

        if roi_boxes.size(0) > 0:
            rois = self.align(pred_feat, roi_boxes)  # (n, cha, 7, 7)
            pred_wh2 = self.wh2(rois).view(-1, 4)
            # Refine the detached first-stage corners with the stage-two
            # offsets. NOTE(review): the factor 16 is hard-coded; its
            # meaning is not visible here — confirm against the head config.
            pred_boxes_2[:, [0, 1]] = pred_boxes_2[:, [0, 1]].detach() - \
                pred_wh2[:, [0, 1]] * 16
            pred_boxes_2[:, [2, 3]] = pred_boxes_2[:, [2, 3]].detach() + \
                pred_wh2[:, [2, 3]] * 16
            # Every surviving box gets unit weight.
            wh2_loss = giou_loss(pred_boxes_2, boxes_2,
                                 boxes_2.new_ones(boxes_2.size(0)))

    return hm_loss, wh_loss, wh2_loss
def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind,
             center_location):
    """Compute heatmap and size losses over several feature levels.

    Args:
        pred_hm: list(tensor), tensor <=> batch, (batch, 80, h, w).
        pred_wh: list(tensor), tensor <=> batch, (batch, 2, h, w).
        heatmap: tensor, (batch, 80, h*w for all levels).
        wh: tensor, (batch, max_obj*level_num, 2).
        reg_mask: tensor, tensor <=> img, (batch, max_obj*level_num).
        ind: tensor, (batch, max_obj*level_num).
        center_location: tensor or None, (batch, max_obj*level_num, 2).
            Only useful when using GIOU.

    Returns:
        tuple: ``(hm_loss, wh_loss)``.
    """
    # Periodic per-level heatmap summaries, taken on detached copies so the
    # predictions themselves are untouched.
    if every_n_local_step(500):
        for level, level_hm in enumerate(pred_hm):
            snapshot = level_hm.clone().detach().sigmoid_()
            add_feature_summary('centernet_heatmap_lv{}'.format(level),
                                snapshot.cpu().numpy())

    # Clamp bounds below come from the first level's feature-map size.
    feat_h, feat_w = pred_hm[0].shape[2:]
    level_num = len(pred_hm)

    # Flatten each level's spatial dims and concatenate across levels.
    flattened = [lvl.view(lvl.shape[0], lvl.shape[1], -1) for lvl in pred_hm]
    pred_hm = torch.cat(flattened, dim=-1)
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(pred_hm, heatmap, self.gamma) * self.hm_weight

    # (batch, 2, h, w) for all levels => (batch, max_obj*level_num, 2)
    ind_per_level = ind.chunk(level_num, dim=1)
    gathered = [
        tranpose_and_gather_feat(level_wh, level_ind)
        for level_wh, level_ind in zip(pred_wh, ind_per_level)
    ]
    size_pred = torch.cat(gathered, dim=1)

    # (batch, max_obj*level_num, 2)
    obj_mask = reg_mask.unsqueeze(2).expand_as(size_pred).float()
    normalizer = obj_mask.sum() + 1e-4

    if self.use_giou:
        pred_boxes = torch.cat((center_location - size_pred / 2.,
                                center_location + size_pred / 2.), dim=2)
        # Target corners, clamped to the feature-map extent.
        target_br = center_location + wh / 2.
        target_br[..., 0] = target_br[..., 0].clamp(max=feat_w - 1)
        target_br[..., 1] = target_br[..., 1].clamp(max=feat_h - 1)
        target_tl = torch.clamp(center_location - wh / 2., min=0)
        gt_boxes = torch.cat((target_tl, target_br), dim=2)
        # Duplicate the 2-channel mask so it covers all four coordinates.
        coord_mask = obj_mask.repeat(1, 1, 2)
        wh_loss = giou_loss(pred_boxes, gt_boxes, coord_mask)
    elif self.use_smooth_l1:
        wh_loss = smooth_l1_loss(
            size_pred, wh, obj_mask, avg_factor=normalizer) * self.wh_weight
    else:
        wh_loss = weighted_l1(
            size_pred, wh, obj_mask, avg_factor=normalizer) * self.wh_weight

    return hm_loss, wh_loss