def __call__(self, pred_hm, pred_heights, pred_reg_xoffset, pred_reg_yoffset,
             pred_pose, heatmap, heights, reg_xoffset, reg_yoffset, pose,
             reg_mask, ind):
    """Compute heatmap, height, x/y-offset and pose losses.

    Args:
        pred_hm: tensor, (batch, cls, h, w). Heatmap logits. ``sigmoid_`` is
            applied in place, so the caller's tensor is overwritten.
        pred_heights: tensor, (batch, 3, h, w).
        pred_reg_xoffset: tensor, (batch, 3, h, w).
        pred_reg_yoffset: tensor, (batch, 3, h, w).
        pred_pose: tensor, (batch, 8, h, w). Pose class logits.
        heatmap: tensor, (batch, cls, h, w).
        heights: tensor, (batch, max_obj, 3).
        reg_xoffset: tensor, (batch, max_obj, 3).
        reg_yoffset: tensor, (batch, max_obj, 3).
        pose: tensor, (batch, max_obj). Pose class targets.
        reg_mask: tensor, (batch, max_obj, 3).
        ind: tensor, (batch, max_obj). Indices into the flattened h*w map;
            must be an integer (int64) tensor because it is used with
            ``torch.gather``.

    Returns:
        tuple (hm_loss, heights_loss, xoff_loss, yoff_loss, pose_loss), each
        already multiplied by its corresponding ``self.*_weight``.
    """
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

    # (batch, 3, h, w) => (batch, max_obj, 3)
    pred_heights = tranpose_and_gather_feat(pred_heights, ind)
    pred_reg_xoffset = tranpose_and_gather_feat(pred_reg_xoffset, ind)
    pred_reg_yoffset = tranpose_and_gather_feat(pred_reg_yoffset, ind)

    # cross_entropy only accepts (N, C, d) order, so the pose logits are
    # gathered without transposing: (batch, 8, h, w) => (batch, 8, max_obj).
    # reshape (rather than view) so a non-contiguous head output does not
    # raise; it is still a zero-copy view whenever the layout allows it.
    batch, num_pose = pred_pose.shape[0], pred_pose.shape[1]
    pred_pose = pred_pose.reshape(batch, num_pose, -1)
    pose_index = ind.unsqueeze(1).expand(-1, num_pose, -1)
    pred_pose = pred_pose.gather(2, pose_index)

    mask = reg_mask.float()
    # Small epsilon keeps the normalizer non-zero when no object is present.
    avg_factor = mask.sum() + 1e-4
    heights_loss = weighted_l1(
        pred_heights, heights, mask, avg_factor=avg_factor) * self.heights_weight
    xoff_loss = weighted_l1(
        pred_reg_xoffset, reg_xoffset, mask, avg_factor=avg_factor) * self.xoff_weight
    yoff_loss = weighted_l1(
        pred_reg_yoffset, reg_yoffset, mask, avg_factor=avg_factor) * self.yoff_weight

    # Channel 0 of the regression mask is used as the per-object mask for the
    # pose classification (assumes all 3 channels agree per object - TODO
    # confirm with the target builder).
    instance_mask = mask[..., 0]
    instance_af = instance_mask.sum() + 1e-4
    pose_loss = cross_entropy(
        pred_pose, pose, instance_mask, avg_factor=instance_af) * self.pose_weight
    return hm_loss, heights_loss, xoff_loss, yoff_loss, pose_loss
def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind, reg_offset, center_location):
    """Compute the CenterNet heatmap loss (width/height loss currently disabled).

    Args:
        pred_hm: tensor, (batch, 80, h, w). Heatmap logits; ``sigmoid_`` is
            applied in place, so the caller's tensor is overwritten.
        pred_wh: tensor, (batch, 2, h, w).
        heatmap: tensor, (batch, 80, h, w).
        wh: tensor, (batch, max_obj, 2).
        reg_mask: tensor, tensor <=> img, (batch, max_obj).
        ind: tensor, (batch, max_obj).
        reg_offset: tensor, (batch, max_obj, 2). Not used anywhere in this
            method, even in the unreachable section.
        center_location: tensor, (batch, max_obj, 2). Only useful when using GIOU.

    Returns:
        hm_loss: focal loss on ``pred_hm`` scaled by ``self.hm_weight``.
        Because of the early ``return hm_loss`` below, this method currently
        returns a single tensor, not the ``(hm_loss, wh_loss)`` tuple that the
        final (unreachable) return statement would produce.
    """
    # H, W are only used by the (unreachable) GIoU branch below.
    H, W = pred_hm.shape[2:]
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight
    # NOTE(review): early return - everything after this line is dead code, so
    # pred_wh, wh, reg_mask, ind and center_location are ignored. This looks
    # like a debugging leftover; confirm whether callers expect a single
    # tensor or a (hm_loss, wh_loss) tuple before removing it.
    return hm_loss
    # (batch, 2, h, w) => (batch, max_obj, 2)
    pred = tranpose_and_gather_feat(pred_wh, ind)
    mask = reg_mask.unsqueeze(2).expand_as(pred).float()
    avg_factor = mask.sum() + 1e-4
    if self.use_giou:
        # Predicted boxes as (x1, y1, x2, y2) around the object centers.
        pred_boxes = torch.cat(
            (center_location - pred / 2., center_location + pred / 2.), dim=2)
        # Ground-truth bottom-right corner, clipped to the feature map.
        box_br = center_location + wh / 2.
        box_br[:, :, 0] = box_br[:, :, 0].clamp(max=W - 1)
        box_br[:, :, 1] = box_br[:, :, 1].clamp(max=H - 1)
        boxes = torch.cat(
            (torch.clamp(center_location - wh / 2., min=0), box_br), dim=2)
        mask_no_expand = mask[:, :, 0]
        wh_loss = giou_loss(pred_boxes, boxes, mask_no_expand) * self.giou_weight
    else:
        if self.use_smooth_l1:
            wh_loss = smooth_l1_loss(
                pred, wh, mask, avg_factor=avg_factor) * self.wh_weight
        else:
            wh_loss = weighted_l1(pred, wh, mask, avg_factor=avg_factor) * self.wh_weight
    return hm_loss, wh_loss
def forward(self, objness, box_centers, box_scales, cls_preds, objness_t,
            center_t, scale_t, weight_t, class_t, class_mask):
    """Compute the objectness, center, scale and class losses.

    ``objness_t`` acts as a three-way label:

    * ``> 0`` marks a positive. Its target is 1 and its value is the weight.
    * ``== 0`` marks a negative. Its target is 0 and its weight is 1.
    * ``< 0`` marks an ignored entry, with weight 0.

    The BCE terms use the default element-wise mean. Each term is then
    multiplied by the number of non-batch elements, so the result equals a
    per-sample sum averaged over the batch. ``weighted_l1`` is called with
    ``reduction='mean'``; presumably it behaves the same way (confirm).

    Returns:
        tuple (obj_loss, center_loss, scale_loss, cls_loss).
    """
    # Number of elements per sample (every axis except the batch axis).
    obj_scale = torch.tensor(objness_t.size())[1:].prod().to(torch.float32)
    cls_scale = torch.tensor(class_t.size())[1:].prod(dtype=torch.float32)

    is_positive = objness_t > 0
    obj_target = torch.where(is_positive, torch.ones_like(objness_t), objness_t)
    obj_weight = torch.where(
        is_positive, objness_t, (objness_t >= 0).to(torch.float32))

    # Box and class terms are both scaled element-wise by objness_t.
    box_weight = weight_t * objness_t
    cls_weight = class_mask.to(torch.float32) * objness_t

    obj_loss = F.binary_cross_entropy_with_logits(
        objness, obj_target, obj_weight) * obj_scale
    center_loss = F.binary_cross_entropy_with_logits(
        box_centers, center_t, box_weight) * obj_scale * 2
    scale_loss = weighted_l1(
        box_scales, scale_t, box_weight, reduction='mean') * obj_scale * 2
    cls_loss = F.binary_cross_entropy_with_logits(
        cls_preds, class_t, cls_weight) * cls_scale
    return obj_loss, center_loss, scale_loss, cls_loss
def __call__(self, pred_hm, pred_wh, pred_reg_offset, heatmap, wh, reg_mask,
             ind, reg_offset, center_location):
    """Compute CenterNet heatmap, size and (optional) center-offset losses.

    Args:
        pred_hm: tensor, (batch, 80, h, w). Heatmap logits; ``sigmoid_`` is
            applied in place, so the caller's tensor is overwritten.
        pred_wh: tensor, (batch, 2, h, w).
        pred_reg_offset: None or tensor, (batch, 2, h, w).
        heatmap: tensor, (batch, 80, h, w).
        wh: tensor, (batch, max_obj, 2).
        reg_mask: tensor, tensor <=> img, (batch, max_obj).
        ind: tensor, (batch, max_obj).
        reg_offset: tensor, (batch, max_obj, 2).
        center_location: tensor, (batch, max_obj, 2). Only useful when using GIOU.

    Returns:
        tuple (hm_loss, wh_loss, off_loss). ``off_loss`` is a zero tensor
        unless ``self.use_reg_offset`` is set.
    """
    height, width = pred_hm.shape[2:]
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

    # (batch, 2, h, w) => (batch, max_obj, 2)
    pred_size = tranpose_and_gather_feat(pred_wh, ind)
    mask = reg_mask.unsqueeze(2).expand_as(pred_size).float()
    avg_factor = mask.sum() + 1e-4

    if self.use_giou:
        half_pred = pred_size / 2.
        pred_boxes = torch.cat(
            (center_location - half_pred, center_location + half_pred), dim=2)
        # Ground-truth corners; bottom-right is clipped to the feature map,
        # top-left to zero.
        half_gt = wh / 2.
        gt_br = center_location + half_gt
        gt_br[:, :, 0] = gt_br[:, :, 0].clamp(max=width - 1)
        gt_br[:, :, 1] = gt_br[:, :, 1].clamp(max=height - 1)
        gt_tl = torch.clamp(center_location - half_gt, min=0)
        gt_boxes = torch.cat((gt_tl, gt_br), dim=2)
        wh_loss = giou_loss(pred_boxes, gt_boxes, mask[:, :, 0]) * self.giou_weight
    elif self.use_smooth_l1:
        wh_loss = smooth_l1_loss(
            pred_size, wh, mask, avg_factor=avg_factor) * self.wh_weight
    else:
        wh_loss = weighted_l1(
            pred_size, wh, mask, avg_factor=avg_factor) * self.wh_weight

    off_loss = hm_loss.new_tensor(0.)
    if self.use_reg_offset:
        # The 2-channel size mask doubles as the mask for the 2-channel offset.
        pred_reg = tranpose_and_gather_feat(pred_reg_offset, ind)
        off_loss = weighted_l1(
            pred_reg, reg_offset, mask, avg_factor=avg_factor) * self.off_weight
        add_summary('centernet',
                    gt_reg_off=reg_offset[reg_offset > 0].mean().item())

    if every_n_local_step(500):
        add_feature_summary('centernet/heatmap', pred_hm.detach().cpu().numpy())
        add_feature_summary('centernet/gt_heatmap', heatmap.detach().cpu().numpy())
        if self.use_reg_offset:
            add_feature_summary('centernet/reg_offset',
                                pred_reg_offset.detach().cpu().numpy())

    return hm_loss, wh_loss, off_loss
def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind, center_location):
    """Compute multi-level CenterNet heatmap and size losses.

    Args:
        pred_hm: list(tensor), tensor <=> batch, (batch, 80, h, w).
        pred_wh: list(tensor), tensor <=> batch, (batch, 2, h, w).
        heatmap: tensor, (batch, 80, h*w for all levels).
        wh: tensor, (batch, max_obj*level_num, 2).
        reg_mask: tensor, tensor <=> img, (batch, max_obj*level_num).
        ind: tensor, (batch, max_obj*level_num). Split into ``level_num``
            equal chunks along dim 1; chunk i indexes level i.
        center_location: tensor or None, (batch, max_obj*level_num, 2).
            Only useful when using GIOU.

    Returns:
        tuple (hm_loss, wh_loss).
    """
    if every_n_local_step(500):
        for level, level_hm in enumerate(pred_hm):
            add_feature_summary('centernet_heatmap_lv{}'.format(level),
                                level_hm.clone().detach().sigmoid_().cpu().numpy())

    height, width = pred_hm[0].shape[2:]
    level_num = len(pred_hm)
    # Flatten every level's spatial dims and concatenate them along the last
    # axis to match the layout of `heatmap`.
    flat_hm = torch.cat(
        [level_hm.view(*level_hm.shape[:2], -1) for level_hm in pred_hm], dim=-1)
    flat_hm = torch.clamp(flat_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(flat_hm, heatmap, self.gamma) * self.hm_weight

    # (batch, 2, h, w) per level => (batch, max_obj*level_num, 2)
    ind_per_level = ind.chunk(level_num, dim=1)
    pred_size = torch.cat(
        [tranpose_and_gather_feat(level_wh, level_ind)
         for level_wh, level_ind in zip(pred_wh, ind_per_level)],
        dim=1)
    mask = reg_mask.unsqueeze(2).expand_as(pred_size).float()
    avg_factor = mask.sum() + 1e-4

    if not self.use_giou:
        size_loss_fn = smooth_l1_loss if self.use_smooth_l1 else weighted_l1
        wh_loss = size_loss_fn(
            pred_size, wh, mask, avg_factor=avg_factor) * self.wh_weight
        return hm_loss, wh_loss

    half_pred = pred_size / 2.
    pred_boxes = torch.cat(
        (center_location - half_pred, center_location + half_pred), dim=2)
    # NOTE(review): GT boxes are clipped to the first level's (H, W); confirm
    # center_location / wh are expressed in that level's coordinates.
    half_gt = wh / 2.
    gt_br = center_location + half_gt
    gt_br[:, :, 0] = gt_br[:, :, 0].clamp(max=width - 1)
    gt_br[:, :, 1] = gt_br[:, :, 1].clamp(max=height - 1)
    gt_tl = torch.clamp(center_location - half_gt, min=0)
    gt_boxes = torch.cat((gt_tl, gt_br), dim=2)
    # NOTE(review): the GIoU loss here is not multiplied by any weight
    # attribute, and the 2-channel mask is repeated to 4 channels; confirm
    # both match what giou_loss expects.
    wh_loss = giou_loss(pred_boxes, gt_boxes, mask.repeat(1, 1, 2))
    return hm_loss, wh_loss