def forward_train(self,
                  img,
                  img_metas,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None):
    """Run the training forward pass and return the bbox head's losses.

    Args:
        img: batched input handed to ``self.extract_feat``.
        img_metas: per-image meta information, forwarded to the head loss.
        gt_bboxes: ground-truth boxes, indexed per image.
        gt_labels: ground-truth labels, indexed per image.
        gt_bboxes_ignore: optional boxes to ignore, forwarded to the loss.

    Returns:
        Whatever ``self.bbox_head.loss`` returns.
    """
    # each tensor in this tuple is corresponding to a level.
    x = self.extract_feat(img)

    if every_n_local_step(self.train_cfg.get('vis_every_n_iters', 2000)):
        # TODO remove hardcode
        # std must follow the same channel order as mean; the first two
        # entries used to be swapped, which skewed the de-normalised image.
        add_image_summary(
            'image/origin',
            tensor2imgs(img,
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True)[0],
            gt_bboxes[0].cpu(),
            gt_labels[0].cpu())
        # The first level may itself be a tuple; in that case visualise it
        # instead of the whole feature tuple.
        if isinstance(x[0], tuple):
            feature_p = x[0]
        else:
            feature_p = x
        add_feature_summary('feature/x',
                            feature_p[-1].detach().cpu().numpy())

    outs = self.bbox_head(x)
    loss_inputs = outs + (gt_bboxes, gt_labels, img_metas, self.train_cfg)
    losses = self.bbox_head.loss(*loss_inputs,
                                 gt_bboxes_ignore=gt_bboxes_ignore)
    return losses
def forward(self, x):
    """Run the stem and the residual stages, collecting selected outputs.

    Args:
        x: input handed to ``self.conv1``.

    Returns:
        ``(tuple(outs), mask)`` when ``self.attention_mask_layer_n`` is
        truthy; otherwise the output of ``self.fc`` when ``self.return_fc``
        is set; otherwise ``tuple(outs)``. ``outs`` holds the outputs of the
        stages whose index is in ``self.out_indices``.
    """
    # R50: 22.5ms
    feat = self.maxpool(self.relu(self.norm1(self.conv1(x))))

    collected = []
    attn_mask = None
    for stage_idx, stage_name in enumerate(self.res_layers):
        # A stage index of 0 (or None) is falsy and disables the mask.
        if (self.attention_mask_layer_n
                and self.attention_mask_layer_n == stage_idx):
            feat, attn_mask = self.attention_mask(feat)

        stage = getattr(self, stage_name)
        feat = stage(feat)
        if stage_idx in self.out_indices:
            collected.append(feat)

        if self.add_summay_every_n_step and every_n_local_step(
                self.add_summay_every_n_step):
            stage_no = stage_idx + 1
            add_histogram_summary('resnet_feat_layer{}'.format(stage_no),
                                  feat.detach().cpu())
            add_histogram_summary('resnet_weight_layer{}'.format(stage_no),
                                  stage[-1].conv2.weight.detach().cpu(),
                                  is_param=True)

    if self.attention_mask_layer_n:
        return tuple(collected), attn_mask
    if not self.return_fc:
        return tuple(collected)

    pooled = self.avgpool(feat)
    flat = torch.flatten(pooled, 1)
    return self.fc(flat)
def forward(self, img, img_k):
    """Compute contrastive logits (or the loss) for a query/key batch.

    Args:
        img: a batch of query images.
        img_k: a batch of key images.

    Returns:
        ``dict(moco_loss=loss)`` when ``self.training`` is true, otherwise
        the tuple ``(logits, labels)``.
    """
    img_q = img
    if every_n_local_step(self.train_cfg.get('vis_freq', 100)):
        add_image_summary('query', img_q[0], type='mean0var')
        add_image_summary('key', img_k[0], type='mean0var')

    # compute query features
    q = self.encoder_q(img_q)  # queries: NxC
    q = nn.functional.normalize(q, dim=1)

    # compute key features
    with torch.no_grad():  # no gradient to keys
        self._momentum_update_key_encoder()  # update the key encoder

        # shuffle for making use of BN
        im_k, idx_unshuffle = self._batch_shuffle_ddp(img_k)

        k = self.encoder_k(im_k)  # keys: NxC
        k = nn.functional.normalize(k, dim=1)

        # undo shuffle
        k = self._batch_unshuffle_ddp(k, idx_unshuffle)

    # compute logits
    # Einstein sum is more intuitive
    # positive logits: Nx1
    l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
    # negative logits: NxK
    l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

    # logits: Nx(1+len_queue)
    logits = torch.cat([l_pos, l_neg], dim=1)

    # apply temperature
    logits /= self.temperature

    # labels: positive key indicators (the positive is always column 0).
    # Allocate directly on the logits' device instead of a hard-coded
    # ``.cuda()`` so the labels always sit next to the logits.
    labels = torch.zeros(logits.shape[0],
                         dtype=torch.long,
                         device=logits.device)

    # dequeue and enqueue
    self._dequeue_and_enqueue(k)

    if self.training:
        loss = self.criterion(logits, labels)
        acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
        add_summary('acc', top1=acc1[0], top5=acc5[0])
        return dict(moco_loss=loss)
    return logits, labels
def forward(self, feats):
    """Apply the heatmap and wh heads to every (selected) feature level.

    Args:
        feats: list(tensor), one entry per level.

    Returns:
        hms: list(tensor), tensor <=> level. (batch, 80, h, w) per the
            original author's note.
        whs: list(tensor), tensor <=> level. (batch, 2, h, w) per the
            original author's note.
    """
    hms, whs = [], []
    if self.select_feat_index:
        feats = [feats[i] for i in self.select_feat_index]

    for feat in feats:
        hm = self.hm(feat)
        wh = self.wh(feat)
        if self.use_neg_wh:
            wh = wh * -1
        if self.use_exp_wh:
            wh = wh.exp()
        hms.append(hm)
        whs.append(wh)

    if every_n_local_step(500):
        for lvl, (feat, hm, wh) in enumerate(zip(feats, hms, whs)):
            add_histogram_summary('mlct_head_feat_fpn_lv{}'.format(lvl),
                                  feat.detach().cpu())
            add_histogram_summary('mlct_head_feat_heatmap_lv{}'.format(lvl),
                                  hm.detach().cpu())
            add_histogram_summary('mlct_head_feat_wh_lv{}'.format(lvl),
                                  wh.detach().cpu())

        hm_summary = self.hm[-1]
        wh_summary = self.wh[-1]
        add_histogram_summary('mlct_head_param_hm',
                              hm_summary.weight.detach().cpu(),
                              is_param=True)
        add_histogram_summary('mlct_head_param_wh',
                              wh_summary.weight.detach().cpu(),
                              is_param=True)

        # ``.grad`` is None until a backward pass has populated it (and
        # again after gradients are reset to None), so skip the gradient
        # histograms instead of crashing inside forward.
        hm_grad = hm_summary.weight.grad
        if hm_grad is not None:
            add_histogram_summary('mlct_head_param_hm_grad',
                                  hm_grad.detach().cpu(),
                                  collect_type='none')
        wh_grad = wh_summary.weight.grad
        if wh_grad is not None:
            add_histogram_summary('mlct_head_param_wh_grad',
                                  wh_grad.detach().cpu(),
                                  collect_type='none')
    return hms, whs
def forward_train(self,
                  img,
                  img_meta,
                  gt_bboxes=None,
                  gt_labels=None,
                  gt_bboxes_ignore=None):
    """Run the RPN training forward pass and return the RPN losses.

    Args:
        img: batched input handed to ``self.extract_feat``.
        img_meta: per-image meta information, forwarded to the RPN loss.
        gt_bboxes: ground-truth boxes, indexed per image.
        gt_labels: ground-truth labels; only used for the image summary.
        gt_bboxes_ignore: optional boxes to ignore, forwarded to the loss.

    Returns:
        Whatever ``self.rpn_head.loss`` returns.
    """
    if every_n_local_step(self.train_cfg.get('vis_every_n_iters', 2000)):
        # Both arguments default to None, and the loss below does not need
        # labels; only draw the summary when there is something to index
        # rather than failing on ``None[0]``.
        if gt_bboxes is not None and gt_labels is not None:
            add_image_summary(
                'image/origin',
                tensor2imgs(img,
                            mean=[123.675, 116.28, 103.53],
                            std=[58.395, 57.12, 57.375])[0],
                gt_bboxes[0].cpu().numpy(),
                gt_labels[0].cpu().numpy())

    x = self.extract_feat(img)
    rpn_outs = self.rpn_head(x)

    rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, self.train_cfg.rpn)
    losses = self.rpn_head.loss(*rpn_loss_inputs,
                                gt_bboxes_ignore=gt_bboxes_ignore)
    return losses
def forward(self, feats):
    """Upsample the last feature map and run all prediction heads on it.

    Args:
        feats: list(tensor).

    Returns:
        A 6-tuple ``(heatmap, heights, xoffset, yoffset, poses, feat)``.
        Shapes as stated by the original author:
        heatmap: tensor, (batch, cls, h, w).
        heights: tensor, (batch, 3, h, w).
        xoffset: tensor, (batch, 3, h, w).
        yoffset: tensor, (batch, 3, h, w).
        poses: tensor, (batch, 8, h, w).
        feat: tensor, (batch, c, h, w), the feature all heads consumed.
    """
    x = feats[-1]
    stages = zip(self.deconv_layers, self.shortcut_layers)
    for level, (upsample, lateral) in enumerate(stages):
        x = upsample(x)
        if not self.use_shortcut:
            continue

        # Pair each upsampling step with the next-shallower input feature.
        skip = lateral(feats[-level - 2])
        if self.neg_shortcut:
            skip = -1 * F.relu(-1 * skip)
        x = x + skip

        if every_n_local_step(500):
            add_feature_summary('ct_head_shortcut_{}'.format(level),
                                skip.detach().cpu().numpy())

    heatmap = self.hm(x)
    heights = self.heights_head(x)
    xoffset = self.xoffset_head(x)
    yoffset = self.yoffset_head(x)
    poses = self.pose_head(x)
    return heatmap, heights, xoffset, yoffset, poses, x
def forward_single_level(self, x, idx):
    """Run the classification and regression towers on one feature level.

    Retina-R50: R50 takes 24.0ms, FPN takes 3.62ms, HEAD takes 17.48ms,
    consuming 46.5ms.

    |      | cls_feat | reg_feat | cls_score | bbox_pred | Total  |
    | ---- | -------- | -------- | --------- | --------- | ------ |
    | P3   | 3.47ms   | 3.46ms   | 2.23ms    | 0.24ms    | 9.41ms |
    | P4   | 1.30ms   | 1.26ms   | 0.74ms    | 0.11ms    | 3.43ms |
    | P5   | NA       | NA       | NA        | NA        | 1.92ms |
    | P6   | NA       | NA       | NA        | NA        | 1.35ms |
    | P7   | NA       | NA       | NA        | NA        | 1.37ms |

    Args:
        x: tensor, or a tuple whose first two items are the classification
            and regression inputs respectively.
        idx: index of the level; used in summary tags, and parameter
            summaries are only written for ``idx == 0``.

    Returns:
        ``(cls_score, bbox_pred)`` from ``self.retina_cls`` and
        ``self.retina_reg``.
    """
    # for a single level of multiple images.
    if isinstance(x, tuple):
        cls_feat, reg_feat = x[0], x[1]
    else:
        cls_feat, reg_feat = x, x

    if every_n_local_step():
        add_histogram_summary('retina_head_feat/cls_in_{}'.format(idx),
                              cls_feat.detach().cpu())
        add_histogram_summary('retina_head_feat/reg_in_{}'.format(idx),
                              reg_feat.detach().cpu())

    for cls_conv in self.cls_convs:
        cls_feat = cls_conv(cls_feat)
    for reg_conv in self.reg_convs:
        reg_feat = reg_conv(reg_feat)

    if every_n_local_step():
        add_histogram_summary('retina_head_feat/cls_out_{}'.format(idx),
                              cls_feat.detach().cpu())
        add_histogram_summary('retina_head_feat/reg_out_{}'.format(idx),
                              reg_feat.detach().cpu())
        if idx == 0:
            for i, (cls_conv, reg_conv) in enumerate(
                    zip(self.cls_convs, self.reg_convs)):
                add_histogram_summary(
                    'retina_head_param/cls_conv_{}'.format(i),
                    cls_conv.conv.weight.detach().cpu(),
                    is_param=True)
                add_histogram_summary(
                    'retina_head_param/reg_conv_{}'.format(i),
                    reg_conv.conv.weight.detach().cpu(),
                    is_param=True)

            # ``.grad`` is None until a backward pass has populated it, so
            # skip the gradient histograms rather than crash in forward.
            cls_grad = self.cls_convs[-1].conv.weight.grad
            if cls_grad is not None:
                add_histogram_summary('retina_head_param/cls_convs_grad',
                                      cls_grad.detach().cpu(),
                                      collect_type='none')
            reg_grad = self.reg_convs[-1].conv.weight.grad
            if reg_grad is not None:
                add_histogram_summary('retina_head_param/reg_conv_grad',
                                      reg_grad.detach().cpu(),
                                      collect_type='none')

    cls_score = self.retina_cls(cls_feat)
    bbox_pred = self.retina_reg(reg_feat)
    return cls_score, bbox_pred
def forward_train(self,
                  img,
                  img_meta,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None,
                  gt_masks=None,
                  proposals=None):
    """Two-stage training forward pass: RPN, RoI sampling, bbox and mask.

    Args:
        img (Tensor): of shape (N, C, H, W) encoding input images.
            Typically these should be mean centered and std scaled.
        img_meta (list[dict]): list of image info dict where each dict has:
            'img_shape', 'scale_factor', 'flip', and may also contain
            'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
            For details on the values of these keys see
            `mmdet/datasets/pipelines/formatting.py:Collect`.
        gt_bboxes (list[Tensor]): each item are the truth boxes for each
            image in [tl_x, tl_y, br_x, br_y] format.
        gt_labels (list[Tensor]): class indices corresponding to each box.
        gt_bboxes_ignore (None | list[Tensor]): specify which bounding
            boxes can be ignored when computing the loss.
        gt_masks (None | Tensor): true segmentation masks for each box,
            used if the architecture supports a segmentation task.
        proposals: override rpn proposals with custom proposals. Use when
            `with_rpn` is False.

    Returns:
        dict[str, Tensor]: a dictionary of loss components.
    """
    x = self.extract_feat(img)
    losses = {}

    if every_n_local_step(self.train_cfg.get('vis_every_n_iters', 2000)):
        vis_img = tensor2imgs(img,
                              mean=[123.675, 116.28, 103.53],
                              std=[58.395, 57.12, 57.375])[0]
        add_image_summary('image/origin', vis_img, gt_bboxes[0].cpu(),
                          gt_labels[0].cpu())
        add_feature_summary('feature/x', x[-1].detach().cpu().numpy())

    # Stage 1: RPN forward/loss, or externally supplied proposals.
    if not self.with_rpn:
        proposal_list = proposals
    else:
        rpn_outs = self.rpn_head(x)
        rpn_losses = self.rpn_head.loss(
            *(rpn_outs + (gt_bboxes, img_meta, self.train_cfg.rpn)),
            gt_bboxes_ignore=gt_bboxes_ignore)
        losses.update(rpn_losses)

        proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)
        proposal_list = self.rpn_head.get_bboxes(
            *(rpn_outs + (img_meta, proposal_cfg)))

    # Stage 2a: assign ground truth to proposals and sample, per image.
    if self.with_bbox or self.with_mask:
        bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler,
                                     context=self)
        num_imgs = img.size(0)
        if gt_bboxes_ignore is None:
            gt_bboxes_ignore = [None] * num_imgs

        sampling_results = []
        for img_idx in range(num_imgs):
            assign_result = bbox_assigner.assign(proposal_list[img_idx],
                                                 gt_bboxes[img_idx],
                                                 gt_bboxes_ignore[img_idx],
                                                 gt_labels[img_idx])
            per_img_feats = [lvl_feat[img_idx][None] for lvl_feat in x]
            sampling_results.append(
                bbox_sampler.sample(assign_result,
                                    proposal_list[img_idx],
                                    gt_bboxes[img_idx],
                                    gt_labels[img_idx],
                                    feats=per_img_feats))

        num_fgs = [res.pos_inds.shape[0] for res in sampling_results]
        num_bgs = [res.neg_inds.shape[0] for res in sampling_results]
        add_summary(prefix="sample_fast_rcnn_targets",
                    num_fgs=np.mean(num_fgs),
                    num_bgs=np.mean(num_bgs))

    # Stage 2b: bbox head forward and loss.
    if self.with_bbox:
        rois = bbox2roi([res.bboxes for res in sampling_results])
        # TODO: a more flexible way to decide which feature maps to use
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        if self.with_shared_head:
            bbox_feats = self.shared_head(bbox_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_feats)

        bbox_targets = self.bbox_head.get_target(sampling_results, gt_bboxes,
                                                 gt_labels,
                                                 self.train_cfg.rcnn)
        losses.update(
            self.bbox_head.loss(cls_score, bbox_pred, *bbox_targets))

    # Stage 2c: mask head forward and loss.
    if self.with_mask:
        if self.share_roi_extractor:
            # Reuse the bbox features: select the positive rows with a mask
            # of ones (positives) followed by zeros (negatives) per image.
            device = bbox_feats.device
            selector = []
            for res in sampling_results:
                selector.extend([
                    torch.ones(res.pos_bboxes.shape[0],
                               device=device,
                               dtype=torch.uint8),
                    torch.zeros(res.neg_bboxes.shape[0],
                                device=device,
                                dtype=torch.uint8),
                ])
            pos_inds = torch.cat(selector)
            mask_feats = bbox_feats[pos_inds]
        else:
            pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
            mask_feats = self.mask_roi_extractor(
                x[:self.mask_roi_extractor.num_inputs], pos_rois)
            if self.with_shared_head:
                mask_feats = self.shared_head(mask_feats)

        if mask_feats.shape[0] > 0:
            mask_pred = self.mask_head(mask_feats)
            mask_targets = self.mask_head.get_target(sampling_results,
                                                     gt_masks,
                                                     self.train_cfg.rcnn)
            pos_labels = torch.cat(
                [res.pos_gt_labels for res in sampling_results])
            losses.update(
                self.mask_head.loss(mask_pred, mask_targets, pos_labels))

    return losses
def __call__(self, pred_hm, pred_wh, pred_centerness, heatmap, box_target,
             centerness, wh_weight, hm_weight):
    """Compute the heatmap, box (GIoU), centerness and merge losses.

    Args:
        pred_hm: tensor, (batch, 80, h, w). Modified in place by
            ``sigmoid_`` when ``self.fovea_hm`` is false.
        pred_wh: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
        pred_centerness: tensor or None, (batch, 1, h, w).
        heatmap: tensor, (batch, 80, h, w).
        box_target: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
        centerness: tensor or None, (batch, 1, h, w).
        wh_weight: tensor, (batch, 80, h, w) per the original note; it is
            always reshaped with ``.view`` here, so it must not be None.
        hm_weight: weight forwarded to ``ct_focal_loss`` unless
            ``self.ct_version`` is truthy, in which case None is passed.

    Returns:
        ``(hm_loss, wh_loss, centerness_loss, merge_loss)``.
    """
    if every_n_local_step(100):
        pred_hm_summary = torch.clamp(torch.sigmoid(pred_hm),
                                      min=1e-4,
                                      max=1 - 1e-4)
        gt_hm_summary = heatmap.clone()
        if self.fovea_hm:
            if not self.only_merge:
                pred_ctn_summary = torch.clamp(
                    torch.sigmoid(pred_centerness), min=1e-4, max=1 - 1e-4)
                add_feature_summary(
                    'centernet/centerness',
                    pred_ctn_summary.detach().cpu().numpy(),
                    type='f')
                add_feature_summary(
                    'centernet/merge',
                    (pred_ctn_summary *
                     pred_hm_summary).detach().cpu().numpy(),
                    type='max')
            add_feature_summary('centernet/gt_centerness',
                                centerness.detach().cpu().numpy(),
                                type='f')
            add_feature_summary(
                'centernet/gt_merge',
                (centerness * gt_hm_summary).detach().cpu().numpy(),
                type='max')
        add_feature_summary('centernet/heatmap',
                            pred_hm_summary.detach().cpu().numpy())
        add_feature_summary('centernet/gt_heatmap',
                            gt_hm_summary.detach().cpu().numpy())

    H, W = pred_hm.shape[2:]
    if not self.fovea_hm:
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_weight = None if self.ct_version else hm_weight
        hm_loss = ct_focal_loss(pred_hm, heatmap,
                                hm_weight=hm_weight) * self.hm_weight
        centerness_loss = hm_loss.new_tensor([0.])
        merge_loss = hm_loss.new_tensor([0.])
    else:
        care_mask = (heatmap >= 0).float()
        avg_factor = torch.sum(heatmap > 0).float().item() + 1e-6
        if not self.only_merge:
            hm_loss = py_sigmoid_focal_loss(
                pred_hm, heatmap, care_mask,
                reduction='sum') / avg_factor * self.hm_weight
            pred_centerness = torch.clamp(torch.sigmoid(pred_centerness),
                                          min=1e-4,
                                          max=1 - 1e-4)
            centerness_loss = ct_focal_loss(
                pred_centerness, centerness, gamma=2.) * self.ct_weight
            merge_loss = ct_focal_loss(
                torch.clamp(torch.sigmoid(pred_hm) * pred_centerness,
                            min=1e-4,
                            max=1 - 1e-4),
                heatmap * centerness,
                weight=(heatmap >= 0).float()) * self.merge_weight
        else:
            hm_loss = pred_hm.new_tensor([0.])
            centerness_loss = pred_hm.new_tensor([0.])
            merge_loss = ct_focal_loss(
                torch.clamp(torch.sigmoid(pred_hm), min=1e-4,
                            max=1 - 1e-4),
                heatmap * centerness,
                weight=(heatmap >= 0).float()) * self.merge_weight

    if not self.wh_agnostic:
        # Fold the class axis into the batch axis so each class has its own
        # 4-channel box map.
        pred_wh = pred_wh.view(
            pred_wh.size(0) * pred_hm.size(1), 4, H, W)
        box_target = box_target.view(
            box_target.size(0) * pred_hm.size(1), 4, H, W)
    mask = wh_weight.view(-1, H, W)
    avg_factor = mask.sum() + 1e-4

    # The location grid is cached on the instance. Rebuild it whenever the
    # feature-map size (e.g. multi-scale inputs) or the device changes;
    # otherwise a stale grid would be combined with ``pred_wh`` below.
    base_loc = self.base_loc
    if (base_loc is None or base_loc.shape[1] != H
            or base_loc.shape[2] != W
            or base_loc.device != heatmap.device):
        base_step = self.down_ratio
        shifts_x = torch.arange(0, (W - 1) * base_step + 1,
                                base_step,
                                dtype=torch.float32,
                                device=heatmap.device)
        shifts_y = torch.arange(0, (H - 1) * base_step + 1,
                                base_step,
                                dtype=torch.float32,
                                device=heatmap.device)
        shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
        self.base_loc = torch.stack((shift_x, shift_y), dim=0)  # (2, h, w)

    # (batch, h, w, 4)
    pred_boxes = torch.cat((self.base_loc - pred_wh[:, [0, 1]],
                            self.base_loc + pred_wh[:, [2, 3]]),
                           dim=1).permute(0, 2, 3, 1)
    # (batch, h, w, 4)
    boxes = box_target.permute(0, 2, 3, 1)
    wh_loss = giou_loss(pred_boxes, boxes, mask,
                        avg_factor=avg_factor) * self.giou_weight

    return hm_loss, wh_loss, centerness_loss, merge_loss
def forward(self, feats):
    """Upsample the last feature map and predict heatmap, wh, centerness.

    Args:
        feats: list(tensor).

    Returns:
        hm: tensor, (batch, 80, h, w) per the original author's note.
        wh: tensor; passed through ``exp`` or ``relu`` and, optionally, an
            offset base and a feature-map-size scaling.
        centerness: output of ``self.centerness`` when ``self.fovea_hm`` is
            set and ``self.only_merge`` is not; otherwise None.
    """
    x = feats[-1]
    for i, (deconv_layer, shortcut_layer) in enumerate(
            zip(self.deconv_layers, self.shortcut_layers)):
        x = deconv_layer(x)
        if self.use_shortcut:
            shortcut = shortcut_layer(feats[-i - 2])
            x = x + shortcut

    if not self.predict_together:
        hm = self.hm(x)
        wh = self.wh(x)
    else:
        # A single conv predicts 5 channels per group: 1 heatmap + 4 wh.
        N, _, H, W = x.shape
        hmwh = self.hmwh(x).view(N, -1, 5, H, W).transpose(1, 2)
        hm = hmwh[:, 0]
        wh = hmwh[:, 1:5].transpose(1, 2).contiguous().view(N, -1, H, W)

    wh = wh.exp() if self.use_exp_wh else F.relu(wh)
    if self.wh_offset_base is not None:
        if isinstance(self.wh_offset_base, nn.Module):
            wh = self.wh_offset_base(wh)
        else:
            wh *= self.wh_offset_base

    if self.norm_wh:
        # Scale channels 0/2 by the heatmap width and 1/3 by its height.
        N, _, H, W = wh.shape
        wh = wh.view(N, -1, 4, H, W).transpose(1, 2)
        wh[:, [0, 2]] = wh[:, [0, 2]] * hm.size(3)
        wh[:, [1, 3]] = wh[:, [1, 3]] * hm.size(2)
        wh = wh.transpose(1, 2).view(N, -1, H, W)

    if every_n_local_step(100):
        add_histogram_summary('ct_head_feat/heatmap', hm.detach().cpu())
        add_histogram_summary('ct_head_feat/wh', wh.detach().cpu())
        if not self.predict_together:
            hm_summary = self.hm[-1]
            wh_summary = self.wh[-1]
            add_histogram_summary('ct_head_param/hm',
                                  hm_summary.weight.detach().cpu(),
                                  is_param=True)
            add_histogram_summary('ct_head_param/wh',
                                  wh_summary.weight.detach().cpu(),
                                  is_param=True)

            # ``.grad`` is None until a backward pass has populated it, so
            # skip the gradient histograms rather than crash in forward.
            hm_grad = hm_summary.weight.grad
            if hm_grad is not None:
                add_histogram_summary('ct_head_param/hm_grad',
                                      hm_grad.detach().cpu(),
                                      collect_type='none')
            wh_grad = wh_summary.weight.grad
            if wh_grad is not None:
                add_histogram_summary('ct_head_param/wh_grad',
                                      wh_grad.detach().cpu(),
                                      collect_type='none')

    centerness = None
    if self.fovea_hm and not self.only_merge:
        centerness = self.centerness(x)
    return hm, wh, centerness
def __call__(self, pred_hm, pred_wh, pred_reg_offset, heatmap, wh, reg_mask,
             ind, reg_offset, center_location):
    """Compute the heatmap, size and (optional) offset losses.

    Args:
        pred_hm: tensor, (batch, 80, h, w). Modified in place by
            ``sigmoid_``.
        pred_wh: tensor, (batch, 2, h, w).
        pred_reg_offset: None or tensor, (batch, 2, h, w).
        heatmap: tensor, (batch, 80, h, w).
        wh: tensor, (batch, max_obj, 2).
        reg_mask: tensor, tensor <=> img, (batch, max_obj).
        ind: tensor, (batch, max_obj).
        reg_offset: tensor, (batch, max_obj, 2).
        center_location: tensor, (batch, max_obj, 2). Only useful when
            using GIOU.

    Returns:
        ``(hm_loss, wh_loss, off_loss)``; ``off_loss`` is a zero tensor
        unless ``self.use_reg_offset`` is set.
    """
    feat_h, feat_w = pred_hm.shape[2:]
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

    # (batch, 2, h, w) => (batch, max_obj, 2)
    gathered_wh = tranpose_and_gather_feat(pred_wh, ind)
    weight = reg_mask.unsqueeze(2).expand_as(gathered_wh).float()
    normalizer = weight.sum() + 1e-4

    if self.use_giou:
        half_pred = gathered_wh / 2.
        pred_boxes = torch.cat(
            (center_location - half_pred, center_location + half_pred),
            dim=2)

        # Ground-truth corners, clamped to the feature-map extent.
        half_gt = wh / 2.
        box_br = center_location + half_gt
        box_br[:, :, 0] = box_br[:, :, 0].clamp(max=feat_w - 1)
        box_br[:, :, 1] = box_br[:, :, 1].clamp(max=feat_h - 1)
        box_tl = torch.clamp(center_location - half_gt, min=0)
        boxes = torch.cat((box_tl, box_br), dim=2)

        per_box_weight = weight[:, :, 0]
        wh_loss = giou_loss(pred_boxes, boxes,
                            per_box_weight) * self.giou_weight
    elif self.use_smooth_l1:
        wh_loss = smooth_l1_loss(gathered_wh, wh, weight,
                                 avg_factor=normalizer) * self.wh_weight
    else:
        wh_loss = weighted_l1(gathered_wh, wh, weight,
                              avg_factor=normalizer) * self.wh_weight

    off_loss = hm_loss.new_tensor(0.)
    if self.use_reg_offset:
        gathered_off = tranpose_and_gather_feat(pred_reg_offset, ind)
        off_loss = weighted_l1(gathered_off, reg_offset, weight,
                               avg_factor=normalizer) * self.off_weight
        positive_offsets = reg_offset[reg_offset > 0]
        add_summary('centernet',
                    gt_reg_off=positive_offsets.mean().item())

    if every_n_local_step(500):
        add_feature_summary('centernet/heatmap',
                            pred_hm.detach().cpu().numpy())
        add_feature_summary('centernet/gt_heatmap',
                            heatmap.detach().cpu().numpy())
        if self.use_reg_offset:
            add_feature_summary('centernet/reg_offset',
                                pred_reg_offset.detach().cpu().numpy())

    return hm_loss, wh_loss, off_loss
def forward(self, feats):
    """Upsample the last feature map and predict heatmap, wh and offset.

    Args:
        feats: list(tensor).

    Returns:
        hm: tensor, (batch, 80, h, w) per the original author's note.
        wh: tensor, (batch, 2, h, w) per the original author's note. With
            ``self.use_rep_points`` it is derived from the spread of the
            predicted sampling points instead of a dedicated head.
        reg: None or tensor, (batch, 2, h, w); None unless
            ``self.use_reg_offset`` is set.
    """
    x = feats[-1]
    if not self.use_dla:
        for i, (deconv_layer, shortcut_layer) in enumerate(
                zip(self.deconv_layers, self.shortcut_layers)):
            x = deconv_layer(x)
            if self.use_shortcut:
                shortcut = shortcut_layer(feats[-i - 2])
                if self.neg_shortcut:
                    shortcut = -1 * F.relu(-1 * shortcut)
                x = x + shortcut

                if every_n_local_step(500):
                    add_feature_summary('ct_head_shortcut_{}'.format(i),
                                        shortcut.detach().cpu().numpy())

    if self.use_rep_points:
        share_feat = self.share_head_conv(x)
        o1, o2, mask = torch.chunk(self.wh(share_feat), 3, dim=1)
        # 18 channels for example, h1, w1, h2, w2, ...
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        hm = self.hm(share_feat, offset, mask)

        # Original author's note: the code below seems not to improve the
        # mAP, although it is supposed to.
        kernel_spatial = self.rep_points_kernel**2
        # De-interleave the (h, w) pairs into all-h and all-w channels.
        o1, o2 = torch.chunk(
            offset.permute(0, 2, 3, 1).contiguous().view(
                -1, kernel_spatial, 2).transpose(1, 2).contiguous().view(
                    offset.shape[0], *offset.shape[2:],
                    kernel_spatial * 2).permute(0, 3, 1, 2),
            2,
            dim=1)

        if every_n_local_step(100):
            for ch in range(offset.shape[1]):
                add_histogram_summary('ct_rep_points_{}'.format(ch),
                                      offset[:, [ch]].detach().cpu())

        # Base grid of the kernel's sampling positions, centred on zero.
        radius = (self.rep_points_kernel - 1) // 2
        base_positions = list(range(-radius, radius + 1))
        h_base = hm.new_tensor(base_positions)
        h_base = torch.stack(
            [h_base for _ in range(self.rep_points_kernel)],
            dim=1).view(1, kernel_spatial, 1, 1)
        w_base = hm.new_tensor(base_positions)[None]
        w_base = torch.cat(
            [w_base for _ in range(self.rep_points_kernel)],
            dim=0).view(1, kernel_spatial, 1, 1)
        h_loc, w_loc = o1 + h_base, o2 + w_base

        # Width/height are the extents of the sampled point set.
        wh = torch.cat([
            w_loc.max(1, keepdim=True)[0] - w_loc.min(1, keepdim=True)[0],
            h_loc.max(1, keepdim=True)[0] - h_loc.min(1, keepdim=True)[0]
        ],
                       dim=1)
    else:
        hm = self.hm(x)
        wh = self.wh(x)

    reg = self.reg(x) if self.use_reg_offset else None

    if self.use_exp_wh:
        wh = wh.exp()

    if every_n_local_step(500):
        add_histogram_summary('ct_head_feat/heatmap', hm.detach().cpu())
        add_histogram_summary('ct_head_feat/wh', wh.detach().cpu())
        if self.use_reg_offset:
            add_histogram_summary('ct_head_feat/reg', reg.detach().cpu())

        if self.use_rep_points:
            hm_summary, wh_summary = self.hm, self.wh
        elif self.use_exp_hm:
            hm_summary, wh_summary = self.hm[-1].conv, self.wh[-1]
        else:
            hm_summary, wh_summary = self.hm[-1], self.wh[-1]

        add_histogram_summary('ct_head_param/hm',
                              hm_summary.weight.detach().cpu(),
                              is_param=True)
        add_histogram_summary('ct_head_param/wh',
                              wh_summary.weight.detach().cpu(),
                              is_param=True)
        if self.use_reg_offset:
            add_histogram_summary('ct_head_param/reg',
                                  self.reg[-1].weight.detach().cpu(),
                                  is_param=True)

        # ``.grad`` is None until a backward pass has populated it, so skip
        # the gradient histograms rather than crash in forward.
        hm_grad = hm_summary.weight.grad
        if hm_grad is not None:
            add_histogram_summary('ct_head_param/hm_grad',
                                  hm_grad.detach().cpu(),
                                  collect_type='none')
        wh_grad = wh_summary.weight.grad
        if wh_grad is not None:
            add_histogram_summary('ct_head_param/wh_grad',
                                  wh_grad.detach().cpu(),
                                  collect_type='none')

    return hm, wh, reg
def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind,
             center_location):
    """Compute the heatmap and size losses over all feature levels.

    Args:
        pred_hm: list(tensor), one per level, each (batch, 80, h, w).
        pred_wh: list(tensor), one per level, each (batch, 2, h, w).
        heatmap: tensor, (batch, 80, h*w for all levels).
        wh: tensor, (batch, max_obj*level_num, 2).
        reg_mask: tensor, tensor <=> img, (batch, max_obj*level_num).
        ind: tensor, (batch, max_obj*level_num).
        center_location: tensor or None, (batch, max_obj*level_num, 2).
            Only useful when using GIOU.

    Returns:
        ``(hm_loss, wh_loss)``.
    """
    if every_n_local_step(500):
        for lvl, level_hm in enumerate(pred_hm):
            vis_hm = level_hm.clone().detach().sigmoid_()
            add_feature_summary('centernet_heatmap_lv{}'.format(lvl),
                                vis_hm.cpu().numpy())

    # Spatial size of the first level; used to clamp the GIoU targets.
    feat_h, feat_w = pred_hm[0].shape[2:]
    level_num = len(pred_hm)

    # Flatten each level spatially and concatenate along the last axis.
    flat_levels = [x.view(*x.shape[:2], -1) for x in pred_hm]
    pred_hm = torch.cat(flat_levels, dim=-1)
    pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
    hm_loss = ct_focal_loss(pred_hm, heatmap, self.gamma) * self.hm_weight

    # (batch, 2, h, w) for all levels => (batch, max_obj*level_num, 2)
    ind_levels = ind.chunk(level_num, dim=1)
    pred_wh_pruned = torch.cat([
        tranpose_and_gather_feat(level_wh, level_ind)
        for level_wh, level_ind in zip(pred_wh, ind_levels)
    ],
                               dim=1)

    weight = reg_mask.unsqueeze(2).expand_as(pred_wh_pruned).float()
    normalizer = weight.sum() + 1e-4

    if self.use_giou:
        half_pred = pred_wh_pruned / 2.
        pred_boxes = torch.cat(
            (center_location - half_pred, center_location + half_pred),
            dim=2)

        # Ground-truth corners, clamped to the first level's extent.
        half_gt = wh / 2.
        box_br = center_location + half_gt
        box_br[:, :, 0] = box_br[:, :, 0].clamp(max=feat_w - 1)
        box_br[:, :, 1] = box_br[:, :, 1].clamp(max=feat_h - 1)
        box_tl = torch.clamp(center_location - half_gt, min=0)
        boxes = torch.cat((box_tl, box_br), dim=2)

        weight_4 = weight.repeat(1, 1, 2)
        wh_loss = giou_loss(pred_boxes, boxes, weight_4)
    elif self.use_smooth_l1:
        wh_loss = smooth_l1_loss(pred_wh_pruned, wh, weight,
                                 avg_factor=normalizer) * self.wh_weight
    else:
        wh_loss = weighted_l1(pred_wh_pruned, wh, weight,
                              avg_factor=normalizer) * self.wh_weight

    return hm_loss, wh_loss
def target_single_image(self, gt_boxes, gt_labels, feat_shapes,
                        obj_sizes_of_interest):
    """Build the per-level training targets for a single image.

    Each ground-truth box is assigned to every level whose size range
    (``obj_sizes_of_interest``) contains the box's longer side.

    Args:
        gt_boxes: tensor, tensor <=> img, (num_gt, 4).
        gt_labels: tensor, tensor <=> img, (num_gt,). The class channel is
            ``label - 1``.
        feat_shapes: list(tuple). tuple <=> level, ``(h, w)``.
        obj_sizes_of_interest: tensor, (level_num, 2).

    Returns:
        all_heatmap: tensor, tensor <=> img, (80, h*w for all levels).
        all_wh: tensor, tensor <=> img, (max_obj*level_num, 2).
        all_reg_mask: tensor, tensor <=> img, (max_obj*level_num,).
        all_ind: tensor, tensor <=> img, (max_obj*level_num,).
        all_center_location: tensor, (max_obj*level_num, 2), when
            ``self.use_giou`` is set; otherwise a list holding one None
            per level.
    """
    level_size = len(self.fpn_strides)
    all_heatmap, all_wh, all_reg_mask = [], [], []
    all_ind, all_center_location = [], []

    # Longer side of every box, repeated per level: (gt_num, level_num).
    max_wh_target = torch.max(
        gt_boxes[:, 3] - gt_boxes[:, 1],
        gt_boxes[:, 2] - gt_boxes[:, 0]).unsqueeze(-1).repeat(
            1, level_size)
    is_cared_in_the_level = \
        (max_wh_target >= obj_sizes_of_interest[:, 0]) & \
        (max_wh_target <= obj_sizes_of_interest[:, 1])  # (gt_num, level_num)

    cared_gt_num_per_level = {
        'num_lv{}'.format(lvl): is_cared_in_the_level[:, lvl].sum().item()
        for lvl in range(level_size)
    }
    add_summary('centernet', **cared_gt_num_per_level)

    for lvl in range(level_size):
        # get target for a single level of a single image.
        output_h, output_w = feat_shapes[lvl]
        heatmap = gt_boxes.new_zeros((self.num_classes, output_h, output_w))
        wh = gt_boxes.new_zeros((self.max_objs, 2))
        reg_mask = gt_boxes.new_zeros((self.max_objs, ), dtype=torch.uint8)
        ind = gt_boxes.new_zeros((self.max_objs, ), dtype=torch.long)
        center_location = None
        if self.use_giou:
            center_location = gt_boxes.new_zeros((self.max_objs, 2))

        cared = is_cared_in_the_level[:, lvl]
        gt_boxes_in_lvl = gt_boxes[cared]
        # Filter the labels with the same mask: ``k`` below indexes the
        # filtered boxes, so indexing the unfiltered labels would pair a
        # box with another box's class.
        gt_labels_in_lvl = gt_labels[cared]

        if gt_boxes_in_lvl.size(0) > 0:
            # Map the boxes onto this level's feature map and clip them.
            gt_boxes_in_lvl = gt_boxes_in_lvl / self.fpn_strides[lvl]
            gt_boxes_in_lvl[:, [0, 2]] = torch.clamp(
                gt_boxes_in_lvl[:, [0, 2]], 0, output_w - 1)
            gt_boxes_in_lvl[:, [1, 3]] = torch.clamp(
                gt_boxes_in_lvl[:, [1, 3]], 0, output_h - 1)
            hs = gt_boxes_in_lvl[:, 3] - gt_boxes_in_lvl[:, 1]
            ws = gt_boxes_in_lvl[:, 2] - gt_boxes_in_lvl[:, 0]

            if every_n_local_step(500):
                add_histogram_summary('mlct_head_hs_lv{}'.format(lvl),
                                      hs.detach().cpu(),
                                      collect_type='none')
                add_histogram_summary('mlct_head_ws_lv{}'.format(lvl),
                                      ws.detach().cpu(),
                                      collect_type='none')

            for k in range(gt_boxes_in_lvl.shape[0]):
                cls_id = gt_labels_in_lvl[k] - 1
                h, w = hs[k], ws[k]
                if h > 0 and w > 0:
                    radius = gaussian_radius((h.ceil(), w.ceil()))
                    radius = max(0, int(radius.item()))
                    center = gt_boxes.new_tensor([
                        (gt_boxes_in_lvl[k, 0] + gt_boxes_in_lvl[k, 2]) / 2,
                        (gt_boxes_in_lvl[k, 1] + gt_boxes_in_lvl[k, 3]) / 2
                    ])
                    # no peak will fall between pixels
                    ct_int = center.to(torch.int)
                    draw_umich_gaussian(heatmap[cls_id], ct_int, radius)
                    # directly predict the width and height
                    wh[k] = wh.new_tensor([1. * w, 1. * h])
                    # Flattened (row-major) index of the centre pixel.
                    ind[k] = ct_int[1] * output_w + ct_int[0]
                    if self.use_giou:
                        center_location[k] = center
                    reg_mask[k] = 1

        all_heatmap.append(heatmap.view(heatmap.shape[0], -1))
        all_wh.append(wh)
        all_reg_mask.append(reg_mask)
        all_ind.append(ind)
        all_center_location.append(center_location)

    all_heatmap, all_reg_mask, all_ind = [
        torch.cat(x, dim=-1) for x in [all_heatmap, all_reg_mask, all_ind]
    ]
    all_wh = torch.cat(all_wh, dim=0)
    if self.use_giou:
        all_center_location = torch.cat(all_center_location, dim=0)

    return all_heatmap, all_wh, all_reg_mask, all_ind, all_center_location