def forward_train(self,
                  img,
                  img_metas,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None):
    # Each tensor in this tuple corresponds to one feature level.
    x = self.extract_feat(img)

    if every_n_local_step(self.train_cfg.get('vis_every_n_iters', 2000)):
        # TODO remove hardcoded normalization constants
        add_image_summary(
            'image/origin',
            tensor2imgs(img,
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True)[0],
            gt_bboxes[0].cpu(), gt_labels[0].cpu())
        if isinstance(x[0], tuple):
            feature_p = x[0]
        else:
            feature_p = x
        add_feature_summary('feature/x', feature_p[-1].detach().cpu().numpy())

    outs = self.bbox_head(x)
    loss_inputs = outs + (gt_bboxes, gt_labels, img_metas, self.train_cfg)
    losses = self.bbox_head.loss(*loss_inputs,
                                 gt_bboxes_ignore=gt_bboxes_ignore)
    return losses
def forward(self, feats):
    """
    Args:
        feats: list(tensor).

    Returns:
        heatmap: tensor, (batch, cls, h, w).
        heights: tensor, (batch, 3, h, w).
        xoffset: tensor, (batch, 3, h, w).
        yoffset: tensor, (batch, 3, h, w).
        poses: tensor, (batch, 8, h, w).
        feat: tensor, (batch, c, h, w).
    """
    x = feats[-1]
    for i, (deconv_layer, shortcut_layer) in enumerate(
            zip(self.deconv_layers, self.shortcut_layers)):
        x = deconv_layer(x)
        if self.use_shortcut:
            shortcut = shortcut_layer(feats[-i - 2])
            if self.neg_shortcut:
                shortcut = -1 * F.relu(-1 * shortcut)
            x = x + shortcut
            if every_n_local_step(500):
                add_feature_summary('ct_head_shortcut_{}'.format(i),
                                    shortcut.detach().cpu().numpy())

    heatmap = self.hm(x)
    heights = self.heights_head(x)
    xoffset = self.xoffset_head(x)
    yoffset = self.yoffset_head(x)
    poses = self.pose_head(x)
    return heatmap, heights, xoffset, yoffset, poses, x
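# A minimal sketch (not part of the head) of the `neg_shortcut` clamp used
# above: -relu(-x) equals min(x, 0), so when neg_shortcut is enabled the
# shortcut branch can only suppress the upsampled feature, never add to it.
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
assert torch.equal(-1 * F.relu(-1 * x), torch.clamp(x, max=0.0))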
def forward_train(self,
                  img,
                  img_meta,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None,
                  gt_masks=None,
                  proposals=None):
    """
    Args:
        img (Tensor): of shape (N, C, H, W) encoding input images.
            Typically these should be mean centered and std scaled.
        img_meta (list[dict]): list of image info dicts where each dict
            has: 'img_shape', 'scale_factor', 'flip', and may also contain
            'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For
            details on the values of these keys see
            `mmdet/datasets/pipelines/formatting.py:Collect`.
        gt_bboxes (list[Tensor]): each item contains the ground-truth boxes
            of one image in [tl_x, tl_y, br_x, br_y] format.
        gt_labels (list[Tensor]): class indices corresponding to each box.
        gt_bboxes_ignore (None | list[Tensor]): specifies which bounding
            boxes can be ignored when computing the loss.
        gt_masks (None | Tensor): ground-truth segmentation masks for each
            box, used if the architecture supports a segmentation task.
        proposals: override RPN proposals with custom proposals. Use when
            `with_rpn` is False.

    Returns:
        dict[str, Tensor]: a dictionary of loss components
    """
    x = self.extract_feat(img)
    losses = dict()

    if every_n_local_step(self.train_cfg.get('vis_every_n_iters', 2000)):
        add_image_summary(
            'image/origin',
            tensor2imgs(img,
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375])[0],
            gt_bboxes[0].cpu(), gt_labels[0].cpu())
        add_feature_summary('feature/x', x[-1].detach().cpu().numpy())

    # RPN forward and loss
    if self.with_rpn:
        rpn_outs = self.rpn_head(x)
        rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                      self.train_cfg.rpn)
        rpn_losses = self.rpn_head.loss(
            *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        losses.update(rpn_losses)

        proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)
        proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
        proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
    else:
        proposal_list = proposals

    # assign gts and sample proposals
    if self.with_bbox or self.with_mask:
        bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                     context=self)
        num_imgs = img.size(0)
        if gt_bboxes_ignore is None:
            gt_bboxes_ignore = [None for _ in range(num_imgs)]
        sampling_results = []
        num_bgs = []
        num_fgs = []
        for i in range(num_imgs):
            assign_result = bbox_assigner.assign(proposal_list[i],
                                                 gt_bboxes[i],
                                                 gt_bboxes_ignore[i],
                                                 gt_labels[i])
            sampling_result = bbox_sampler.sample(
                assign_result,
                proposal_list[i],
                gt_bboxes[i],
                gt_labels[i],
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results.append(sampling_result)
            num_fgs.append(sampling_result.pos_inds.shape[0])
            num_bgs.append(sampling_result.neg_inds.shape[0])
        add_summary(prefix="sample_fast_rcnn_targets",
                    num_fgs=np.mean(num_fgs),
                    num_bgs=np.mean(num_bgs))

    # bbox head forward and loss
    if self.with_bbox:
        rois = bbox2roi([res.bboxes for res in sampling_results])
        # TODO: a more flexible way to decide which feature maps to use
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        if self.with_shared_head:
            bbox_feats = self.shared_head(bbox_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_feats)

        bbox_targets = self.bbox_head.get_target(sampling_results,
                                                 gt_bboxes, gt_labels,
                                                 self.train_cfg.rcnn)
        loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, *bbox_targets)
        losses.update(loss_bbox)

    # mask head forward and loss
    if self.with_mask:
        if not self.share_roi_extractor:
            pos_rois = bbox2roi(
                [res.pos_bboxes for res in sampling_results])
            mask_feats = self.mask_roi_extractor(
                x[:self.mask_roi_extractor.num_inputs], pos_rois)
            if self.with_shared_head:
                mask_feats = self.shared_head(mask_feats)
        else:
            pos_inds = []
            device = bbox_feats.device
            for res in sampling_results:
                pos_inds.append(
                    torch.ones(res.pos_bboxes.shape[0],
                               device=device,
                               dtype=torch.uint8))
                pos_inds.append(
                    torch.zeros(res.neg_bboxes.shape[0],
                                device=device,
                                dtype=torch.uint8))
            pos_inds = torch.cat(pos_inds)
            mask_feats = bbox_feats[pos_inds]

        if mask_feats.shape[0] > 0:
            mask_pred = self.mask_head(mask_feats)
            mask_targets = self.mask_head.get_target(sampling_results,
                                                     gt_masks,
                                                     self.train_cfg.rcnn)
            pos_labels = torch.cat(
                [res.pos_gt_labels for res in sampling_results])
            loss_mask = self.mask_head.loss(mask_pred, mask_targets,
                                            pos_labels)
            losses.update(loss_mask)

    return losses
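# A minimal sketch (hypothetical counts) of the shared-RoI-extractor branch
# above: positives and negatives are sampled per image in order, so a
# concatenated 1/0 mask over all sampled RoIs selects exactly the positive
# rows of `bbox_feats` for the mask head. Boolean dtype is used here in
# place of the legacy torch.uint8 mask.
import torch

num_pos, num_neg = 3, 5  # counts from one hypothetical sampling_result
pos_inds = torch.cat([torch.ones(num_pos, dtype=torch.bool),
                      torch.zeros(num_neg, dtype=torch.bool)])
bbox_feats = torch.rand(num_pos + num_neg, 256, 7, 7)
mask_feats = bbox_feats[pos_inds]
assert mask_feats.shape[0] == num_pos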
def __call__(self, pred_hm, pred_wh, pred_centerness, heatmap, box_target,
             centerness, wh_weight, hm_weight):
    """
    Args:
        pred_hm: tensor, (batch, 80, h, w).
        pred_wh: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
        pred_centerness: tensor or None, (batch, 1, h, w).
        heatmap: tensor, (batch, 80, h, w).
        box_target: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
        centerness: tensor or None, (batch, 1, h, w).
        wh_weight: tensor or None, (batch, 80, h, w).
        hm_weight: tensor or None.

    Returns:
        tuple: (hm_loss, wh_loss, centerness_loss, merge_loss).
    """
    if every_n_local_step(100):
        pred_hm_summary = torch.clamp(torch.sigmoid(pred_hm),
                                      min=1e-4, max=1 - 1e-4)
        gt_hm_summary = heatmap.clone()
        if self.fovea_hm:
            if not self.only_merge:
                pred_ctn_summary = torch.clamp(
                    torch.sigmoid(pred_centerness), min=1e-4, max=1 - 1e-4)
                add_feature_summary(
                    'centernet/centerness',
                    pred_ctn_summary.detach().cpu().numpy(),
                    type='f')
                add_feature_summary(
                    'centernet/merge',
                    (pred_ctn_summary *
                     pred_hm_summary).detach().cpu().numpy(),
                    type='max')
            add_feature_summary('centernet/gt_centerness',
                                centerness.detach().cpu().numpy(),
                                type='f')
            add_feature_summary(
                'centernet/gt_merge',
                (centerness * gt_hm_summary).detach().cpu().numpy(),
                type='max')

        add_feature_summary('centernet/heatmap',
                            pred_hm_summary.detach().cpu().numpy())
        add_feature_summary('centernet/gt_heatmap',
                            gt_hm_summary.detach().cpu().numpy())

    H, W = pred_hm.shape[2:]
    if not self.fovea_hm:
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_weight = None if self.ct_version else hm_weight
        hm_loss = ct_focal_loss(pred_hm, heatmap,
                                hm_weight=hm_weight) * self.hm_weight
        centerness_loss = hm_loss.new_tensor([0.])
        merge_loss = hm_loss.new_tensor([0.])
    else:
        care_mask = (heatmap >= 0).float()
        avg_factor = torch.sum(heatmap > 0).float().item() + 1e-6
        if not self.only_merge:
            hm_loss = py_sigmoid_focal_loss(
                pred_hm, heatmap, care_mask,
                reduction='sum') / avg_factor * self.hm_weight
            pred_centerness = torch.clamp(torch.sigmoid(pred_centerness),
                                          min=1e-4, max=1 - 1e-4)
            centerness_loss = ct_focal_loss(pred_centerness, centerness,
                                            gamma=2.) * self.ct_weight
            merge_loss = ct_focal_loss(
                torch.clamp(torch.sigmoid(pred_hm) * pred_centerness,
                            min=1e-4, max=1 - 1e-4),
                heatmap * centerness,
                weight=(heatmap >= 0).float()) * self.merge_weight
        else:
            hm_loss = pred_hm.new_tensor([0.])
            centerness_loss = pred_hm.new_tensor([0.])
            merge_loss = ct_focal_loss(
                torch.clamp(torch.sigmoid(pred_hm),
                            min=1e-4, max=1 - 1e-4),
                heatmap * centerness,
                weight=(heatmap >= 0).float()) * self.merge_weight

    if not self.wh_agnostic:
        pred_wh = pred_wh.view(pred_wh.size(0) * pred_hm.size(1), 4, H, W)
        box_target = box_target.view(
            box_target.size(0) * pred_hm.size(1), 4, H, W)

    mask = wh_weight.view(-1, H, W)
    avg_factor = mask.sum() + 1e-4

    if self.base_loc is None:
        base_step = self.down_ratio
        shifts_x = torch.arange(0, (W - 1) * base_step + 1, base_step,
                                dtype=torch.float32, device=heatmap.device)
        shifts_y = torch.arange(0, (H - 1) * base_step + 1, base_step,
                                dtype=torch.float32, device=heatmap.device)
        shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
        self.base_loc = torch.stack((shift_x, shift_y), dim=0)  # (2, h, w)

    # (batch, h, w, 4)
    pred_boxes = torch.cat((self.base_loc - pred_wh[:, [0, 1]],
                            self.base_loc + pred_wh[:, [2, 3]]),
                           dim=1).permute(0, 2, 3, 1)
    # (batch, h, w, 4)
    boxes = box_target.permute(0, 2, 3, 1)
    wh_loss = giou_loss(pred_boxes, boxes, mask,
                        avg_factor=avg_factor) * self.giou_weight

    return hm_loss, wh_loss, centerness_loss, merge_loss
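# A minimal sketch (hypothetical sizes) of the `base_loc` decoding above:
# `base_loc` holds the image-space (x, y) coordinate of every feature cell
# at stride `down_ratio`, and the 4-channel prediction is interpreted as
# (left, top, right, bottom) distances from that cell to the box sides.
import torch

H, W, down_ratio = 2, 3, 4
shifts_x = torch.arange(0, (W - 1) * down_ratio + 1, down_ratio,
                        dtype=torch.float32)
shifts_y = torch.arange(0, (H - 1) * down_ratio + 1, down_ratio,
                        dtype=torch.float32)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
base_loc = torch.stack((shift_x, shift_y), dim=0)  # (2, H, W)

pred_wh = torch.rand(1, 4, H, W)  # per-cell (l, t, r, b) distances
pred_boxes = torch.cat((base_loc - pred_wh[:, [0, 1]],
                        base_loc + pred_wh[:, [2, 3]]), dim=1)
assert pred_boxes.shape == (1, 4, H, W)  # x1, y1, x2, y2 per cell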
def __call__(self, pred_hm, pred_wh, pred_reg_offset, heatmap, wh, reg_mask,
             ind, reg_offset, center_location):
    """
    Args:
        pred_hm: tensor, (batch, 80, h, w).
        pred_wh: tensor, (batch, 2, h, w).
        pred_reg_offset: None or tensor, (batch, 2, h, w).
        heatmap: tensor, (batch, 80, h, w).
        wh: tensor, (batch, max_obj, 2).
        reg_mask: tensor, tensor <=> img, (batch, max_obj).
        ind: tensor, (batch, max_obj).
        reg_offset: tensor, (batch, max_obj, 2).
        center_location: tensor, (batch, max_obj, 2). Only useful when
            using GIoU.

    Returns:
        tuple: (hm_loss, wh_loss, off_loss).
    """
    H, W = pred_hm.shape[2:]
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

    # (batch, 2, h, w) => (batch, max_obj, 2)
    pred = tranpose_and_gather_feat(pred_wh, ind)
    mask = reg_mask.unsqueeze(2).expand_as(pred).float()
    avg_factor = mask.sum() + 1e-4

    if self.use_giou:
        pred_boxes = torch.cat(
            (center_location - pred / 2., center_location + pred / 2.),
            dim=2)
        box_br = center_location + wh / 2.
        box_br[:, :, 0] = box_br[:, :, 0].clamp(max=W - 1)
        box_br[:, :, 1] = box_br[:, :, 1].clamp(max=H - 1)
        boxes = torch.cat(
            (torch.clamp(center_location - wh / 2., min=0), box_br), dim=2)
        mask_no_expand = mask[:, :, 0]
        wh_loss = giou_loss(pred_boxes, boxes,
                            mask_no_expand) * self.giou_weight
    else:
        if self.use_smooth_l1:
            wh_loss = smooth_l1_loss(
                pred, wh, mask, avg_factor=avg_factor) * self.wh_weight
        else:
            wh_loss = weighted_l1(
                pred, wh, mask, avg_factor=avg_factor) * self.wh_weight

    off_loss = hm_loss.new_tensor(0.)
    if self.use_reg_offset:
        pred_reg = tranpose_and_gather_feat(pred_reg_offset, ind)
        off_loss = weighted_l1(pred_reg, reg_offset, mask,
                               avg_factor=avg_factor) * self.off_weight
        add_summary('centernet',
                    gt_reg_off=reg_offset[reg_offset > 0].mean().item())

    if every_n_local_step(500):
        add_feature_summary('centernet/heatmap',
                            pred_hm.detach().cpu().numpy())
        add_feature_summary('centernet/gt_heatmap',
                            heatmap.detach().cpu().numpy())
        if self.use_reg_offset:
            add_feature_summary('centernet/reg_offset',
                                pred_reg_offset.detach().cpu().numpy())

    return hm_loss, wh_loss, off_loss
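# A hedged sketch of what `tranpose_and_gather_feat` is assumed to do (the
# helper ships with the repo; this hypothetical reference implementation
# only illustrates the semantics): for each object, pick the C-channel
# prediction at its flattened h*w index.
import torch

def gather_feat_sketch(feat, ind):
    # feat: (batch, C, H, W); ind: (batch, max_obj), flattened H*W indices.
    batch, c = feat.shape[:2]
    feat = feat.view(batch, c, -1).permute(0, 2, 1)       # (batch, H*W, C)
    ind = ind.unsqueeze(2).expand(batch, ind.size(1), c)  # (batch, max_obj, C)
    return feat.gather(1, ind)                            # (batch, max_obj, C)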
def forward(self, feats):
    """
    Args:
        feats: list(tensor).

    Returns:
        hm: tensor, (batch, 80, h, w).
        wh: tensor, (batch, 2, h, w).
        reg: None or tensor, (batch, 2, h, w).
    """
    x = feats[-1]
    if not self.use_dla:
        for i, (deconv_layer, shortcut_layer) in enumerate(
                zip(self.deconv_layers, self.shortcut_layers)):
            x = deconv_layer(x)
            if self.use_shortcut:
                shortcut = shortcut_layer(feats[-i - 2])
                if self.neg_shortcut:
                    shortcut = -1 * F.relu(-1 * shortcut)
                x = x + shortcut
                if every_n_local_step(500):
                    add_feature_summary('ct_head_shortcut_{}'.format(i),
                                        shortcut.detach().cpu().numpy())

    if self.use_rep_points:
        share_feat = self.share_head_conv(x)
        o1, o2, mask = torch.chunk(self.wh(share_feat), 3, dim=1)
        # 18 channels for example: h1, w1, h2, w2, ...
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        hm = self.hm(share_feat, offset, mask)

        # The code below does not seem to improve mAP, though it is
        # expected to.
        kernel_spatial = self.rep_points_kernel**2
        o1, o2 = torch.chunk(
            offset.permute(0, 2, 3, 1).contiguous().view(
                -1, kernel_spatial, 2).transpose(1, 2).contiguous().view(
                    offset.shape[0], *offset.shape[2:],
                    kernel_spatial * 2).permute(0, 3, 1, 2),
            2, dim=1)

        if every_n_local_step(100):
            for i in range(offset.shape[1]):
                add_histogram_summary('ct_rep_points_{}'.format(i),
                                      offset[:, [i]].detach().cpu())

        radius = (self.rep_points_kernel - 1) // 2
        h_base = hm.new_tensor([i for i in range(-radius, radius + 1)])
        h_base = torch.stack(
            [h_base for _ in range(self.rep_points_kernel)],
            dim=1).view(1, kernel_spatial, 1, 1)
        w_base = hm.new_tensor([i for i in range(-radius, radius + 1)])[None]
        w_base = torch.cat(
            [w_base for _ in range(self.rep_points_kernel)],
            dim=0).view(1, kernel_spatial, 1, 1)
        h_loc, w_loc = o1 + h_base, o2 + w_base
        wh = torch.cat([
            w_loc.max(1, keepdim=True)[0] - w_loc.min(1, keepdim=True)[0],
            h_loc.max(1, keepdim=True)[0] - h_loc.min(1, keepdim=True)[0]
        ], dim=1)
    else:
        hm = self.hm(x)
        wh = self.wh(x)

    reg = self.reg(x) if self.use_reg_offset else None
    if self.use_exp_wh:
        wh = wh.exp()

    if every_n_local_step(500):
        add_histogram_summary('ct_head_feat/heatmap', hm.detach().cpu())
        add_histogram_summary('ct_head_feat/wh', wh.detach().cpu())
        if self.use_reg_offset:
            add_histogram_summary('ct_head_feat/reg', reg.detach().cpu())

        if self.use_rep_points:
            hm_summary, wh_summary = self.hm, self.wh
        elif self.use_exp_hm:
            hm_summary, wh_summary = self.hm[-1].conv, self.wh[-1]
        else:
            hm_summary, wh_summary = self.hm[-1], self.wh[-1]
        add_histogram_summary('ct_head_param/hm',
                              hm_summary.weight.detach().cpu(),
                              is_param=True)
        add_histogram_summary('ct_head_param/wh',
                              wh_summary.weight.detach().cpu(),
                              is_param=True)
        if self.use_reg_offset:
            add_histogram_summary('ct_head_param/reg',
                                  self.reg[-1].weight.detach().cpu(),
                                  is_param=True)
        # grad is None before the first backward pass.
        if hm_summary.weight.grad is not None:
            add_histogram_summary('ct_head_param/hm_grad',
                                  hm_summary.weight.grad.detach().cpu(),
                                  collect_type='none')
            add_histogram_summary('ct_head_param/wh_grad',
                                  wh_summary.weight.grad.detach().cpu(),
                                  collect_type='none')

    return hm, wh, reg
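# A minimal numeric sketch of the rep-points -> wh reduction above: each of
# the k*k sampling points sits at (DCN offset + fixed kernel grid), and wh
# is the per-axis extent (max - min) of those points. Assumes a 3x3 kernel;
# with zero offsets the extent collapses to the kernel span of 2 * radius.
import torch

k = 3
radius = (k - 1) // 2
kernel_spatial = k * k
h_base = torch.arange(-radius, radius + 1, dtype=torch.float32)
h_base = torch.stack([h_base] * k, dim=1).view(1, kernel_spatial, 1, 1)
w_base = torch.arange(-radius, radius + 1, dtype=torch.float32)[None]
w_base = torch.cat([w_base] * k, dim=0).view(1, kernel_spatial, 1, 1)

o1 = torch.zeros(1, kernel_spatial, 1, 1)  # row offsets (zero => pure grid)
o2 = torch.zeros(1, kernel_spatial, 1, 1)  # col offsets
h_loc, w_loc = o1 + h_base, o2 + w_base
wh = torch.cat([w_loc.max(1, keepdim=True)[0] - w_loc.min(1, keepdim=True)[0],
                h_loc.max(1, keepdim=True)[0] - h_loc.min(1, keepdim=True)[0]],
               dim=1)
assert wh[0, 0, 0, 0] == wh[0, 1, 0, 0] == 2 * radius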
def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind,
             center_location):
    """
    Args:
        pred_hm: list(tensor), tensor <=> level, (batch, 80, h, w).
        pred_wh: list(tensor), tensor <=> level, (batch, 2, h, w).
        heatmap: tensor, (batch, 80, h*w for all levels).
        wh: tensor, (batch, max_obj*level_num, 2).
        reg_mask: tensor, tensor <=> img, (batch, max_obj*level_num).
        ind: tensor, (batch, max_obj*level_num).
        center_location: tensor or None, (batch, max_obj*level_num, 2).
            Only useful when using GIoU.

    Returns:
        tuple: (hm_loss, wh_loss).
    """
    if every_n_local_step(500):
        for lvl, hm in enumerate(pred_hm):
            hm_summary = hm.clone().detach().sigmoid_()
            add_feature_summary('centernet_heatmap_lv{}'.format(lvl),
                                hm_summary.cpu().numpy())

    H, W = pred_hm[0].shape[2:]
    level_num = len(pred_hm)
    pred_hm = torch.cat([x.view(*x.shape[:2], -1) for x in pred_hm], dim=-1)
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(pred_hm, heatmap, self.gamma) * self.hm_weight

    # (batch, 2, h, w) for all levels => (batch, max_obj*level_num, 2)
    ind_levels = ind.chunk(level_num, dim=1)
    pred_wh_pruned = []
    for pred_wh_per_lvl, ind_lvl in zip(pred_wh, ind_levels):
        pred_wh_pruned.append(
            tranpose_and_gather_feat(pred_wh_per_lvl, ind_lvl))
    # (batch, max_obj*level_num, 2)
    pred_wh_pruned = torch.cat(pred_wh_pruned, dim=1)
    mask = reg_mask.unsqueeze(2).expand_as(pred_wh_pruned).float()
    avg_factor = mask.sum() + 1e-4

    if self.use_giou:
        pred_boxes = torch.cat((center_location - pred_wh_pruned / 2.,
                                center_location + pred_wh_pruned / 2.),
                               dim=2)
        box_br = center_location + wh / 2.
        box_br[:, :, 0] = box_br[:, :, 0].clamp(max=W - 1)
        box_br[:, :, 1] = box_br[:, :, 1].clamp(max=H - 1)
        box_tl = torch.clamp(center_location - wh / 2., min=0)
        boxes = torch.cat((box_tl, box_br), dim=2)
        mask_expand_4 = mask.repeat(1, 1, 2)
        wh_loss = giou_loss(pred_boxes, boxes, mask_expand_4)
    else:
        if self.use_smooth_l1:
            wh_loss = smooth_l1_loss(
                pred_wh_pruned, wh, mask,
                avg_factor=avg_factor) * self.wh_weight
        else:
            wh_loss = weighted_l1(
                pred_wh_pruned, wh, mask,
                avg_factor=avg_factor) * self.wh_weight

    return hm_loss, wh_loss
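# A minimal shape sketch (hypothetical sizes) of the multi-level flattening
# above: each level's (batch, cls, h, w) heatmap is flattened to
# (batch, cls, h*w) and the levels are concatenated along the last dim so
# they line up with the packed (batch, cls, sum of h*w) target.
import torch

levels = [torch.rand(2, 80, 32, 32), torch.rand(2, 80, 16, 16)]
packed = torch.cat([x.view(*x.shape[:2], -1) for x in levels], dim=-1)
assert packed.shape == (2, 80, 32 * 32 + 16 * 16)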