Example #1
    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        x = self.extract_feat(
            img)  # each tensor in this tuple corresponds to a feature level

        if every_n_local_step(self.train_cfg.get('vis_every_n_iters', 2000)):
            # TODO remove hardcode
            add_image_summary(
                'image/origin',
                tensor2imgs(img,
                            mean=[123.675, 116.28, 103.53],
                            std=[58.395, 57.12, 57.375],
                            to_rgb=True)[0], gt_bboxes[0].cpu(),
                gt_labels[0].cpu())
            if isinstance(x[0], tuple):
                feature_p = x[0]
            else:
                feature_p = x
            add_feature_summary('feature/x',
                                feature_p[-1].detach().cpu().numpy())

        outs = self.bbox_head(x)
        loss_inputs = outs + (gt_bboxes, gt_labels, img_metas, self.train_cfg)
        losses = self.bbox_head.loss(*loss_inputs,
                                     gt_bboxes_ignore=gt_bboxes_ignore)
        return losses
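For context, the `tensor2imgs` call above undoes the input normalization before logging the image. Below is a minimal sketch of such a helper; it is a hypothetical stand-in rather than the mmcv original, and the channel-flip direction for `to_rgb` is an assumption.

import numpy as np

def denormalize_to_images(img, mean, std, to_rgb=True):
    # img: (N, C, H, W) float tensor that was normalized as (x - mean) / std
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    imgs = []
    for t in img:
        arr = t.detach().cpu().numpy().transpose(1, 2, 0)  # (H, W, C)
        arr = arr * std + mean  # invert the per-channel normalization
        if to_rgb:
            arr = arr[..., ::-1]  # assumed BGR storage; flip to RGB
        imgs.append(np.clip(arr, 0, 255).astype(np.uint8))
    return imgs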
Example #2
    def forward(self, feats):
        """

        Args:
            feats: list(tensor).

        Returns:
            heatmap: tensor, (batch, cls, h, w).
            heights: tensor, (batch, 3, h, w).
            xoffset: tensor, (batch, 3, h, w).
            yoffset: tensor, (batch, 3, h, w).
            poses: tensor, (batch, 8, h, w).
            feat: tensor, (batch, c, h, w).
        """

        x = feats[-1]
        for i, (deconv_layer, shortcut_layer) in enumerate(
                zip(self.deconv_layers, self.shortcut_layers)):
            x = deconv_layer(x)

            if self.use_shortcut:
                shortcut = shortcut_layer(feats[-i - 2])
                if self.neg_shortcut:
                    # -relu(-x) keeps only the negative part, i.e. min(x, 0)
                    shortcut = -1 * F.relu(-1 * shortcut)
                x = x + shortcut

                if every_n_local_step(500):
                    add_feature_summary('ct_head_shortcut_{}'.format(i),
                                        shortcut.detach().cpu().numpy())

        heatmap = self.hm(x)
        heights = self.heights_head(x)
        xoffset = self.xoffset_head(x)
        yoffset = self.yoffset_head(x)
        poses = self.pose_head(x)

        return heatmap, heights, xoffset, yoffset, poses, x
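The `neg_shortcut` branch relies on the identity `-relu(-x) == min(x, 0)`, i.e. it zeroes out the positive part of the shortcut. A quick check of that identity:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3)
neg_part = -1 * F.relu(-1 * x)                        # the trick used above
assert torch.equal(neg_part, torch.clamp(x, max=0.))  # elementwise min(x, 0)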
Example #3
    def forward_train(self,
                      img,
                      img_meta,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None):
        """
        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.

            img_meta (list[dict]): list of image info dict where each dict has:
                'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

            gt_bboxes (list[Tensor]): each item is the ground-truth boxes for
                one image in [tl_x, tl_y, br_x, br_y] format.

            gt_labels (list[Tensor]): class indices corresponding to each box.

            gt_bboxes_ignore (None | list[Tensor]): specifies which bounding
                boxes can be ignored when computing the loss.

            gt_masks (None | Tensor): ground-truth segmentation masks for each
                box, used if the architecture supports a segmentation task.

            proposals: override RPN proposals with custom proposals. Used when
                `with_rpn` is False.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        x = self.extract_feat(img)

        losses = dict()
        if every_n_local_step(self.train_cfg.get('vis_every_n_iters', 2000)):
            add_image_summary(
                'image/origin',
                tensor2imgs(img,
                            mean=[123.675, 116.28, 103.53],
                            std=[58.395, 57.12, 57.375])[0],
                gt_bboxes[0].cpu(),
                gt_labels[0].cpu())
            add_feature_summary('feature/x', x[-1].detach().cpu().numpy())

        # RPN forward and loss
        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                          self.train_cfg.rpn)
            rpn_losses = self.rpn_head.loss(
                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
            losses.update(rpn_losses)

            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        else:
            proposal_list = proposals

        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler, context=self)
            num_imgs = img.size(0)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            num_bgs = []
            num_fgs = []
            for i in range(num_imgs):
                assign_result = bbox_assigner.assign(proposal_list[i],
                                                     gt_bboxes[i],
                                                     gt_bboxes_ignore[i],
                                                     gt_labels[i])
                sampling_result = bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes[i],
                    gt_labels[i],
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results.append(sampling_result)
                num_fgs.append(sampling_result.pos_inds.shape[0])
                num_bgs.append(sampling_result.neg_inds.shape[0])
            add_summary(prefix="sample_fast_rcnn_targets",
                        num_fgs=np.mean(num_fgs), num_bgs=np.mean(num_bgs))

        # bbox head forward and loss
        if self.with_bbox:
            rois = bbox2roi([res.bboxes for res in sampling_results])
            # TODO: a more flexible way to decide which feature maps to use
            bbox_feats = self.bbox_roi_extractor(
                x[:self.bbox_roi_extractor.num_inputs], rois)
            if self.with_shared_head:
                bbox_feats = self.shared_head(bbox_feats)
            cls_score, bbox_pred = self.bbox_head(bbox_feats)

            bbox_targets = self.bbox_head.get_target(sampling_results,
                                                     gt_bboxes, gt_labels,
                                                     self.train_cfg.rcnn)
            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
                                            *bbox_targets)
            losses.update(loss_bbox)

        # mask head forward and loss
        if self.with_mask:
            if not self.share_roi_extractor:
                pos_rois = bbox2roi(
                    [res.pos_bboxes for res in sampling_results])
                mask_feats = self.mask_roi_extractor(
                    x[:self.mask_roi_extractor.num_inputs], pos_rois)
                if self.with_shared_head:
                    mask_feats = self.shared_head(mask_feats)
            else:
                pos_inds = []
                device = bbox_feats.device
                for res in sampling_results:
                    pos_inds.append(
                        torch.ones(
                            res.pos_bboxes.shape[0],
                            device=device,
                            dtype=torch.uint8))
                    pos_inds.append(
                        torch.zeros(
                            res.neg_bboxes.shape[0],
                            device=device,
                            dtype=torch.uint8))
                pos_inds = torch.cat(pos_inds)
                mask_feats = bbox_feats[pos_inds]

            if mask_feats.shape[0] > 0:
                mask_pred = self.mask_head(mask_feats)
                mask_targets = self.mask_head.get_target(
                    sampling_results, gt_masks, self.train_cfg.rcnn)
                pos_labels = torch.cat(
                    [res.pos_gt_labels for res in sampling_results])
                loss_mask = self.mask_head.loss(mask_pred, mask_targets,
                                                pos_labels)
                losses.update(loss_mask)

        return losses
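`bbox2roi` above is mmdetection's helper that flattens per-image proposal lists into a single RoI tensor. A sketch of the convention it is assumed to follow (the exact mmdet implementation may differ in details such as empty-input handling):

import torch

def bbox2roi_sketch(bbox_list):
    # bbox_list: list of (n_i, 4) tensors, one per image in the batch.
    # Returns (sum n_i, 5) laid out as [batch_ind, x1, y1, x2, y2] so the
    # RoI extractor can route each box to the right image's features.
    rois = []
    for img_id, bboxes in enumerate(bbox_list):
        img_inds = bboxes.new_full((bboxes.size(0), 1), img_id)
        rois.append(torch.cat([img_inds, bboxes[:, :4]], dim=-1))
    return torch.cat(rois, dim=0)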
Example #4
    def __call__(self, pred_hm, pred_wh, pred_centerness, heatmap, box_target,
                 centerness, wh_weight, hm_weight):
        """

        Args:
            pred_hm: tensor, (batch, 80, h, w).
            pred_wh: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
            pred_centerness: tensor or None, (batch, 1, h, w).
            heatmap: tensor, (batch, 80, h, w).
            box_target: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
            centerness: tensor or None, (batch, 1, h, w).
            wh_weight: tensor or None, (batch, 80, h, w).

        Returns:

        """
        if every_n_local_step(100):
            pred_hm_summary = torch.clamp(torch.sigmoid(pred_hm),
                                          min=1e-4,
                                          max=1 - 1e-4)
            gt_hm_summary = heatmap.clone()
            if self.fovea_hm:
                if not self.only_merge:
                    pred_ctn_summary = torch.clamp(
                        torch.sigmoid(pred_centerness), min=1e-4, max=1 - 1e-4)
                    add_feature_summary(
                        'centernet/centerness',
                        pred_ctn_summary.detach().cpu().numpy(),
                        type='f')
                    add_feature_summary(
                        'centernet/merge',
                        (pred_ctn_summary *
                         pred_hm_summary).detach().cpu().numpy(),
                        type='max')

                add_feature_summary('centernet/gt_centerness',
                                    centerness.detach().cpu().numpy(),
                                    type='f')
                add_feature_summary('centernet/gt_merge',
                                    (centerness *
                                     gt_hm_summary).detach().cpu().numpy(),
                                    type='max')

            add_feature_summary('centernet/heatmap',
                                pred_hm_summary.detach().cpu().numpy())
            add_feature_summary('centernet/gt_heatmap',
                                gt_hm_summary.detach().cpu().numpy())

        H, W = pred_hm.shape[2:]
        if not self.fovea_hm:
            pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
            hm_weight = None if self.ct_version else hm_weight
            hm_loss = ct_focal_loss(pred_hm, heatmap,
                                    hm_weight=hm_weight) * self.hm_weight
            centerness_loss = hm_loss.new_tensor([0.])
            merge_loss = hm_loss.new_tensor([0.])
        else:
            care_mask = (heatmap >= 0).float()
            avg_factor = torch.sum(heatmap > 0).float().item() + 1e-6
            if not self.only_merge:
                hm_loss = py_sigmoid_focal_loss(
                    pred_hm, heatmap, care_mask,
                    reduction='sum') / avg_factor * self.hm_weight

                pred_centerness = torch.clamp(torch.sigmoid(pred_centerness),
                                              min=1e-4,
                                              max=1 - 1e-4)
                centerness_loss = ct_focal_loss(
                    pred_centerness, centerness, gamma=2.) * self.ct_weight

                merge_loss = ct_focal_loss(
                    torch.clamp(torch.sigmoid(pred_hm) * pred_centerness,
                                min=1e-4,
                                max=1 - 1e-4),
                    heatmap * centerness,
                    weight=care_mask) * self.merge_weight
            else:
                hm_loss = pred_hm.new_tensor([0.])
                centerness_loss = pred_hm.new_tensor([0.])
                merge_loss = ct_focal_loss(
                    torch.clamp(torch.sigmoid(pred_hm), min=1e-4,
                                max=1 - 1e-4),
                    heatmap * centerness,
                    weight=care_mask) * self.merge_weight

        if not self.wh_agnostic:
            pred_wh = pred_wh.view(pred_wh.size(0) * pred_hm.size(1), 4, H, W)
            box_target = box_target.view(
                box_target.size(0) * pred_hm.size(1), 4, H, W)
        mask = wh_weight.view(-1, H, W)
        avg_factor = mask.sum() + 1e-4

        if self.base_loc is None:
            base_step = self.down_ratio
            shifts_x = torch.arange(0, (W - 1) * base_step + 1,
                                    base_step,
                                    dtype=torch.float32,
                                    device=heatmap.device)
            shifts_y = torch.arange(0, (H - 1) * base_step + 1,
                                    base_step,
                                    dtype=torch.float32,
                                    device=heatmap.device)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            self.base_loc = torch.stack((shift_x, shift_y), dim=0)  # (2, h, w)

        # (batch, h, w, 4)
        pred_boxes = torch.cat((self.base_loc - pred_wh[:, [0, 1]],
                                self.base_loc + pred_wh[:, [2, 3]]),
                               dim=1).permute(0, 2, 3, 1)
        # (batch, h, w, 4)
        boxes = box_target.permute(0, 2, 3, 1)
        wh_loss = giou_loss(pred_boxes, boxes, mask,
                            avg_factor=avg_factor) * self.giou_weight

        return hm_loss, wh_loss, centerness_loss, merge_loss
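The `base_loc` block above caches a stride-spaced coordinate grid so that `pred_wh`, interpreted as per-pixel (left, top, right, bottom) distances, can be decoded into absolute boxes. A condensed sketch of that decoding:

import torch

def decode_ltrb(pred_wh, down_ratio):
    # pred_wh: (batch, 4, H, W) distances to the four box sides.
    _, _, H, W = pred_wh.shape
    xs = torch.arange(W, dtype=torch.float32, device=pred_wh.device) * down_ratio
    ys = torch.arange(H, dtype=torch.float32, device=pred_wh.device) * down_ratio
    # indexing='ij' matches the older torch.meshgrid default used above
    shift_y, shift_x = torch.meshgrid(ys, xs, indexing='ij')
    base_loc = torch.stack((shift_x, shift_y), dim=0)        # (2, H, W)
    boxes = torch.cat((base_loc - pred_wh[:, [0, 1]],
                       base_loc + pred_wh[:, [2, 3]]), dim=1)
    return boxes.permute(0, 2, 3, 1)                         # (batch, H, W, 4)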
Example #5
    def __call__(self, pred_hm, pred_wh, pred_reg_offset, heatmap, wh,
                 reg_mask, ind, reg_offset, center_location):
        """

        Args:
            pred_hm: tensor, (batch, 80, h, w).
            pred_wh: tensor, (batch, 2, h, w).
            pred_reg_offset: None or tensor, (batch, 2, h, w).
            heatmap: tensor, (batch, 80, h, w).
            wh: tensor, (batch, max_obj, 2).
            reg_mask: tensor, tensor <=> img, (batch, max_obj).
            ind: tensor, (batch, max_obj).
            reg_offset: tensor, (batch, max_obj, 2).
            center_location: tensor, (batch, max_obj, 2). Only useful when using GIOU.

        Returns:

        """
        H, W = pred_hm.shape[2:]
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

        # (batch, 2, h, w) => (batch, max_obj, 2)
        pred = tranpose_and_gather_feat(pred_wh, ind)
        mask = reg_mask.unsqueeze(2).expand_as(pred).float()
        avg_factor = mask.sum() + 1e-4

        if self.use_giou:
            pred_boxes = torch.cat(
                (center_location - pred / 2., center_location + pred / 2.),
                dim=2)
            box_br = center_location + wh / 2.
            box_br[:, :, 0] = box_br[:, :, 0].clamp(max=W - 1)
            box_br[:, :, 1] = box_br[:, :, 1].clamp(max=H - 1)
            boxes = torch.cat(
                (torch.clamp(center_location - wh / 2., min=0), box_br), dim=2)
            mask_no_expand = mask[:, :, 0]
            wh_loss = giou_loss(pred_boxes, boxes,
                                mask_no_expand) * self.giou_weight
        else:
            if self.use_smooth_l1:
                wh_loss = smooth_l1_loss(
                    pred, wh, mask, avg_factor=avg_factor) * self.wh_weight
            else:
                wh_loss = weighted_l1(pred, wh, mask,
                                      avg_factor=avg_factor) * self.wh_weight

        off_loss = hm_loss.new_tensor(0.)
        if self.use_reg_offset:
            pred_reg = tranpose_and_gather_feat(pred_reg_offset, ind)
            off_loss = weighted_l1(
                pred_reg, reg_offset, mask,
                avg_factor=avg_factor) * self.off_weight

            add_summary('centernet',
                        gt_reg_off=reg_offset[reg_offset > 0].mean().item())

        if every_n_local_step(500):
            add_feature_summary('centernet/heatmap',
                                pred_hm.detach().cpu().numpy())
            add_feature_summary('centernet/gt_heatmap',
                                heatmap.detach().cpu().numpy())
            if self.use_reg_offset:
                add_feature_summary('centernet/reg_offset',
                                    pred_reg_offset.detach().cpu().numpy())

        return hm_loss, wh_loss, off_loss
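`tranpose_and_gather_feat` (the identifier's spelling is the repo's own) is the usual CenterNet utility for pulling per-object vectors out of a dense map. A sketch of what it is assumed to do:

import torch

def transpose_and_gather_feat_sketch(feat, ind):
    # feat: (batch, C, H, W); ind: (batch, max_obj) long indices, each the
    # flattened y * W + x center location. Returns (batch, max_obj, C).
    feat = feat.permute(0, 2, 3, 1).contiguous()       # (batch, H, W, C)
    feat = feat.view(feat.size(0), -1, feat.size(3))   # (batch, H*W, C)
    ind = ind.unsqueeze(2).expand(-1, -1, feat.size(2))
    return feat.gather(1, ind)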
Example #6
    def forward(self, feats):
        """

        Args:
            feats: list(tensor).

        Returns:
            hm: tensor, (batch, 80, h, w).
            wh: tensor, (batch, 2, h, w).
            reg: None or tensor, (batch, 2, h, w).
        """

        x = feats[-1]
        if not self.use_dla:
            for i, (deconv_layer, shortcut_layer) in enumerate(
                    zip(self.deconv_layers, self.shortcut_layers)):
                x = deconv_layer(x)

                if self.use_shortcut:
                    shortcut = shortcut_layer(feats[-i - 2])
                    if self.neg_shortcut:
                        # -relu(-x) keeps only the negative part, i.e. min(x, 0)
                        shortcut = -1 * F.relu(-1 * shortcut)
                    x = x + shortcut

                    if every_n_local_step(500):
                        add_feature_summary('ct_head_shortcut_{}'.format(i),
                                            shortcut.detach().cpu().numpy())

        if self.use_rep_points:
            share_feat = self.share_head_conv(x)
            o1, o2, mask = torch.chunk(self.wh(share_feat), 3, dim=1)
            offset = torch.cat(
                (o1, o2),
                dim=1)  # 18 channels for example, h1, w1, h2, w2, ...
            mask = torch.sigmoid(mask)
            hm = self.hm(share_feat, offset, mask)

            # The regrouping below does not seem to improve mAP, though it is
            # supposed to: it turns interleaved (h, w) offset pairs into an
            # all-h block followed by an all-w block, so o1 holds the h
            # offsets and o2 the w offsets.
            kernel_spatial = self.rep_points_kernel**2
            o1, o2 = torch.chunk(offset.permute(0, 2, 3, 1).contiguous().view(
                -1, kernel_spatial, 2).transpose(1, 2).contiguous().view(
                    offset.shape[0], *offset.shape[2:],
                    kernel_spatial * 2).permute(0, 3, 1, 2),
                                 2,
                                 dim=1)

            if every_n_local_step(100):
                for i in range(offset.shape[1]):
                    add_histogram_summary('ct_rep_points_{}'.format(i),
                                          offset[:, [i]].detach().cpu())

            radius = (self.rep_points_kernel - 1) // 2
            h_base = hm.new_tensor([i for i in range(-radius, radius + 1)])
            h_base = torch.stack(
                [h_base for _ in range(self.rep_points_kernel)],
                dim=1).view(1, kernel_spatial, 1, 1)
            w_base = hm.new_tensor([i
                                    for i in range(-radius, radius + 1)])[None]
            w_base = torch.cat([w_base for _ in range(self.rep_points_kernel)],
                               dim=0).view(1, kernel_spatial, 1, 1)

            h_loc, w_loc = o1 + h_base, o2 + w_base
            wh = torch.cat([
                w_loc.max(1, keepdim=True)[0] - w_loc.min(1, keepdim=True)[0],
                h_loc.max(1, keepdim=True)[0] - h_loc.min(1, keepdim=True)[0]
            ],
                           dim=1)
        else:
            hm = self.hm(x)
            wh = self.wh(x)
        reg = self.reg(x) if self.use_reg_offset else None
        if self.use_exp_wh:
            wh = wh.exp()

        if every_n_local_step(500):
            add_histogram_summary('ct_head_feat/heatmap', hm.detach().cpu())
            add_histogram_summary('ct_head_feat/wh', wh.detach().cpu())
            if self.use_reg_offset:
                add_histogram_summary('ct_head_feat/reg', reg.detach().cpu())

            if self.use_rep_points:
                hm_summary, wh_summary = self.hm, self.wh
            elif self.use_exp_hm:
                hm_summary, wh_summary = self.hm[-1].conv, self.wh[-1]
            else:
                hm_summary, wh_summary = self.hm[-1], self.wh[-1]

            add_histogram_summary('ct_head_param/hm',
                                  hm_summary.weight.detach().cpu(),
                                  is_param=True)
            add_histogram_summary('ct_head_param/wh',
                                  wh_summary.weight.detach().cpu(),
                                  is_param=True)
            if self.use_reg_offset:
                add_histogram_summary('ct_head_param/reg',
                                      self.reg[-1].weight.detach().cpu(),
                                      is_param=True)

            add_histogram_summary('ct_head_param/hm_grad',
                                  hm_summary.weight.grad.detach().cpu(),
                                  collect_type='none')
            add_histogram_summary('ct_head_param/wh_grad',
                                  wh_summary.weight.grad.detach().cpu(),
                                  collect_type='none')

        return hm, wh, reg
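The `h_base`/`w_base` construction above builds constant per-tap row and column offsets for the deformable kernel; adding the learned offsets gives absolute sample locations, and the max-min span across taps is read off as height/width. The grids it produces for a 3x3 kernel, reproduced compactly:

import torch

kernel, radius = 3, 1                    # rep_points_kernel = 3
kernel_spatial = kernel * kernel
taps = torch.arange(-radius, radius + 1, dtype=torch.float32)
# same layout as the stack/cat construction above:
h_base = taps.repeat_interleave(kernel).view(1, kernel_spatial, 1, 1)
w_base = taps.repeat(kernel).view(1, kernel_spatial, 1, 1)
# h_base taps: -1,-1,-1, 0, 0, 0, 1, 1, 1  (row offset per kernel tap)
# w_base taps: -1, 0, 1,-1, 0, 1,-1, 0, 1  (column offset per kernel tap)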
Example #7
    def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind,
                 center_location):
        """

        Args:
            pred_hm: list(tensor), tensor <=> batch, (batch, 80, h, w).
            pred_wh: list(tensor), tensor <=> batch, (batch, 2, h, w).
            heatmap: tensor, (batch, 80, h*w for all levels).
            wh: tensor, (batch, max_obj*level_num, 2).
            reg_mask: tensor, tensor <=> img, (batch, max_obj*level_num).
            ind: tensor, (batch, max_obj*level_num).
            center_location: tensor or None, (batch, max_obj*level_num, 2). Only useful when
                using GIOU.

        Returns:

        """
        if every_n_local_step(500):
            for lvl, hm in enumerate(pred_hm):
                hm_summary = hm.clone().detach().sigmoid_()
                add_feature_summary('centernet_heatmap_lv{}'.format(lvl),
                                    hm_summary.cpu().numpy())

        H, W = pred_hm[0].shape[2:]
        level_num = len(pred_hm)
        pred_hm = torch.cat([x.view(*x.shape[:2], -1) for x in pred_hm],
                            dim=-1)
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(pred_hm, heatmap, self.gamma) * self.hm_weight

        # (batch, 2, h, w) for all levels => (batch, max_obj*level_num, 2)
        ind_levels = ind.chunk(level_num, dim=1)
        pred_wh_pruned = []
        for pred_wh_per_lvl, ind_lvl in zip(pred_wh, ind_levels):
            pred_wh_pruned.append(
                tranpose_and_gather_feat(pred_wh_per_lvl, ind_lvl))
        pred_wh_pruned = torch.cat(pred_wh_pruned,
                                   dim=1)  # (batch, max_obj*level_num, 2)
        mask = reg_mask.unsqueeze(2).expand_as(pred_wh_pruned).float()
        avg_factor = mask.sum() + 1e-4

        if self.use_giou:
            pred_boxes = torch.cat((center_location - pred_wh_pruned / 2.,
                                    center_location + pred_wh_pruned / 2.),
                                   dim=2)
            box_br = center_location + wh / 2.
            box_br[:, :, 0] = box_br[:, :, 0].clamp(max=W - 1)
            box_br[:, :, 1] = box_br[:, :, 1].clamp(max=H - 1)
            box_tl = torch.clamp(center_location - wh / 2., min=0)
            boxes = torch.cat((box_tl, box_br), dim=2)
            mask_expand_4 = mask.repeat(1, 1, 2)
            wh_loss = giou_loss(pred_boxes, boxes, mask_expand_4)
        else:
            if self.use_smooth_l1:
                wh_loss = smooth_l1_loss(
                    pred_wh_pruned, wh, mask,
                    avg_factor=avg_factor) * self.wh_weight
            else:
                wh_loss = weighted_l1(
                    pred_wh_pruned, wh, mask,
                    avg_factor=avg_factor) * self.wh_weight

        return hm_loss, wh_loss
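`ct_focal_loss`, used throughout these examples, is presumably the CornerNet/CenterNet variant of focal loss. A hedged sketch under that assumption, ignoring the `hm_weight`/care-mask variants seen in Example #4:

import torch

def ct_focal_loss_sketch(pred, gt, gamma=2.0):
    # pred: clamped sigmoid heatmap; gt: Gaussian-splatted target where
    # exactly-1 pixels are positives and (1 - gt)^4 down-weights the
    # negatives near each center (the CornerNet formulation).
    pos_mask = gt.eq(1).float()
    neg_mask = gt.lt(1).float()
    neg_weights = torch.pow(1 - gt, 4)

    pos_loss = torch.log(pred) * torch.pow(1 - pred, gamma) * pos_mask
    neg_loss = (torch.log(1 - pred) * torch.pow(pred, gamma)
                * neg_weights * neg_mask)

    num_pos = pos_mask.sum().clamp(min=1.0)  # normalize by positive count
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos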