Exemplo n.º 1
0
    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """Run the training forward pass and return the bbox head losses.

        Args:
            img: batched input images, passed to ``self.extract_feat``.
            img_metas: image meta information, forwarded to the head loss.
            gt_bboxes: ground-truth boxes, indexed per image.
            gt_labels: ground-truth labels, indexed per image.
            gt_bboxes_ignore: optional boxes to ignore, forwarded to the
                head loss as a keyword argument.

        Returns:
            The value returned by ``self.bbox_head.loss``.
        """
        # Each tensor in this tuple corresponds to one feature level.
        x = self.extract_feat(img)

        if every_n_local_step(self.train_cfg.get('vis_every_n_iters', 2000)):
            # TODO remove hardcode
            # The std values are listed in the same channel order as the
            # mean values; the first two entries used to be swapped.
            vis_img = tensor2imgs(img,
                                  mean=[123.675, 116.28, 103.53],
                                  std=[58.395, 57.12, 57.375],
                                  to_rgb=True)[0]
            add_image_summary('image/origin', vis_img, gt_bboxes[0].cpu(),
                              gt_labels[0].cpu())
            # The extractor may return a tuple of tuples; in that case only
            # the first inner tuple is visualised.
            feature_p = x[0] if isinstance(x[0], tuple) else x
            add_feature_summary('feature/x',
                                feature_p[-1].detach().cpu().numpy())

        outs = self.bbox_head(x)
        loss_inputs = outs + (gt_bboxes, gt_labels, img_metas, self.train_cfg)
        losses = self.bbox_head.loss(*loss_inputs,
                                     gt_bboxes_ignore=gt_bboxes_ignore)
        return losses
Exemplo n.º 2
0
    def forward(self, x):
        """Run the stem and the residual stages, collecting stage outputs.

        Returns one of:
            * ``(tuple(outs), mask)`` when ``self.attention_mask_layer_n``
              is truthy,
            * the output of ``self.fc`` when ``self.return_fc`` is set,
            * ``tuple(outs)`` otherwise,
        where ``outs`` holds the outputs of the stages listed in
        ``self.out_indices``.
        """
        # R50: 22.5ms
        x = self.maxpool(self.relu(self.norm1(self.conv1(x))))

        collected = []
        mask = None
        for stage_idx, stage_name in enumerate(self.res_layers):
            mask_stage = self.attention_mask_layer_n
            # A falsy value (None or 0) disables the attention mask.
            if mask_stage and mask_stage == stage_idx:
                x, mask = self.attention_mask(x)

            stage = getattr(self, stage_name)
            x = stage(x)
            if stage_idx in self.out_indices:
                collected.append(x)

            period = self.add_summay_every_n_step
            if period and every_n_local_step(period):
                stage_no = stage_idx + 1
                add_histogram_summary(
                    'resnet_feat_layer{}'.format(stage_no), x.detach().cpu())
                # Weights of the second conv in the last block of the stage.
                add_histogram_summary(
                    'resnet_weight_layer{}'.format(stage_no),
                    stage[-1].conv2.weight.detach().cpu(),
                    is_param=True)

        if self.attention_mask_layer_n:
            return tuple(collected), mask

        if not self.return_fc:
            return tuple(collected)

        pooled = torch.flatten(self.avgpool(x), 1)
        return self.fc(pooled)
Exemplo n.º 3
0
    def forward(self, img, img_k):
        """Compute contrastive logits for a batch of query/key images.

        Input:
            img: a batch of query images
            img_k: a batch of key images
        Output:
            ``dict(moco_loss=loss)`` when ``self.training`` is set,
            otherwise the tuple ``(logits, labels)``.
        """
        img_q = img
        if every_n_local_step(self.train_cfg.get('vis_freq', 100)):
            # add_image_summary('origin', img[0], type='0to1')
            add_image_summary('query', img_q[0], type='mean0var')
            add_image_summary('key', img_k[0], type='mean0var')

        # compute query features
        q = self.encoder_q(img_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(img_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+len_queue)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.temperature

        # labels: positive key indicators (the positive is always column 0).
        # Allocate directly on the device of the logits instead of building
        # on CPU and moving to the current CUDA device.
        labels = torch.zeros(logits.shape[0],
                             dtype=torch.long,
                             device=logits.device)

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        if self.training:
            loss = self.criterion(logits, labels)
            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
            add_summary('acc', top1=acc1[0], top5=acc5[0])
            return dict(moco_loss=loss)

        return logits, labels
Exemplo n.º 4
0
    def forward(self, feats):
        """Predict a heatmap and a wh map for every (selected) level.

        Args:
            feats: list(tensor).

        Returns:
            hms: list(tensor), tensor <=> level. (batch, 80, h, w).
            whs: list(tensor), tensor <=> level. (batch, 2, h, w).
        """
        hms, whs = [], []
        if self.select_feat_index:
            feats = [feats[i] for i in self.select_feat_index]

        for feat in feats:
            hm = self.hm(feat)
            wh = self.wh(feat)
            if self.use_neg_wh:
                wh = wh * -1
            if self.use_exp_wh:
                wh = wh.exp()
            hms.append(hm)
            whs.append(wh)

        if every_n_local_step(500):
            for lvl, (feat, hm, wh) in enumerate(zip(feats, hms, whs)):
                add_histogram_summary('mlct_head_feat_fpn_lv{}'.format(lvl),
                                      feat.detach().cpu())
                add_histogram_summary(
                    'mlct_head_feat_heatmap_lv{}'.format(lvl),
                    hm.detach().cpu())
                add_histogram_summary('mlct_head_feat_wh_lv{}'.format(lvl),
                                      wh.detach().cpu())

            # Last layer of each prediction branch.
            hm_summary = self.hm[-1]
            wh_summary = self.wh[-1]
            add_histogram_summary('mlct_head_param_hm',
                                  hm_summary.weight.detach().cpu(),
                                  is_param=True)
            add_histogram_summary('mlct_head_param_wh',
                                  wh_summary.weight.detach().cpu(),
                                  is_param=True)

            # ``.grad`` stays None until a backward pass has filled it, so
            # skip the gradient histograms instead of raising.
            grad_summaries = (
                ('mlct_head_param_hm_grad', hm_summary.weight.grad),
                ('mlct_head_param_wh_grad', wh_summary.weight.grad),
            )
            for tag, grad in grad_summaries:
                if grad is not None:
                    add_histogram_summary(tag,
                                          grad.detach().cpu(),
                                          collect_type='none')

        return hms, whs
Exemplo n.º 5
0
    def forward_train(self,
                      img,
                      img_meta,
                      gt_bboxes=None,
                      gt_labels=None,
                      gt_bboxes_ignore=None):
        """Run the RPN training forward pass and return its losses.

        Every ``vis_every_n_iters`` local steps (default 2000) the first
        image of the batch is logged together with its ground truth.
        The return value is whatever ``self.rpn_head.loss`` produces.
        """
        vis_period = self.train_cfg.get('vis_every_n_iters', 2000)
        if every_n_local_step(vis_period):
            first_img = tensor2imgs(img,
                                    mean=[123.675, 116.28, 103.53],
                                    std=[58.395, 57.12, 57.375])[0]
            first_boxes = gt_bboxes[0].cpu().numpy()
            first_labels = gt_labels[0].cpu().numpy()
            add_image_summary('image/origin', first_img, first_boxes,
                              first_labels)

        feats = self.extract_feat(img)
        head_outs = self.rpn_head(feats)

        # The head outputs come first, followed by targets and config.
        loss_args = head_outs + (gt_bboxes, img_meta, self.train_cfg.rpn)
        return self.rpn_head.loss(*loss_args,
                                  gt_bboxes_ignore=gt_bboxes_ignore)
Exemplo n.º 6
0
    def forward(self, feats):
        """Upsample the last feature map and run all prediction heads.

        Args:
            feats: list(tensor).

        Returns:
            heatmap: tensor, (batch, cls, h, w).
            heights: tensor, (batch, 3, h, w).
            xoffset: tensor, (batch, 3, h, w).
            yoffset: tensor, (batch, 3, h, w).
            poses: tensor, (batch, 8, h, w).
            feat: tensor, (batch, c, h, w).
        """
        x = feats[-1]
        stages = zip(self.deconv_layers, self.shortcut_layers)
        for level, (upsample, lateral) in enumerate(stages):
            x = upsample(x)

            if not self.use_shortcut:
                continue

            # Stage ``level`` is fused with the feature map two positions
            # further from the end of ``feats`` than the stage index.
            skip = lateral(feats[-level - 2])
            if self.neg_shortcut:
                # Keep only the non-positive part of the shortcut.
                skip = F.relu(skip * -1) * -1
            x = x + skip

            if every_n_local_step(500):
                add_feature_summary('ct_head_shortcut_{}'.format(level),
                                    skip.detach().cpu().numpy())

        return (self.hm(x), self.heights_head(x), self.xoffset_head(x),
                self.yoffset_head(x), self.pose_head(x), x)
Exemplo n.º 7
0
    def forward_single_level(self, x, idx):
        """Run the classification and regression towers on one level.

        Retina-R50: R50 takes 24.0ms, FPN takes 3.62ms, HEAD takes 17.48ms, consuming 46.5ms.

        |      | cls_feat | reg_feat | cls_score | bbox_pred | Total  |
        | ---- | -------- | -------- | --------- | --------- | ------ |
        | P3   | 3.47ms   | 3.46ms   | 2.23ms    | 0.24ms    | 9.41ms |
        | P4   | 1.30ms   | 1.26ms   | 0.74ms    | 0.11ms    | 3.43ms |
        | P5   | NA       | NA       | NA        | NA        | 1.92ms |
        | P6   | NA       | NA       | NA        | NA        | 1.35ms |
        | P7   | NA       | NA       | NA        | NA        | 1.37ms |

        Args:
            x: tensor, or a tuple whose first two items are the
                classification and regression inputs respectively.
            idx: level index; used in the summary tags, and parameter
                summaries are only written when it equals 0.

        Returns:
            cls_score, bbox_pred: outputs of ``self.retina_cls`` and
                ``self.retina_reg``.
        """
        # for a single level of multiple images.
        if isinstance(x, tuple):
            cls_feat, reg_feat = x[0], x[1]
        else:
            cls_feat, reg_feat = x, x

        if every_n_local_step():
            add_histogram_summary('retina_head_feat/cls_in_{}'.format(idx),
                                  cls_feat.detach().cpu())
            add_histogram_summary('retina_head_feat/reg_in_{}'.format(idx),
                                  reg_feat.detach().cpu())

        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)

        if every_n_local_step():
            add_histogram_summary('retina_head_feat/cls_out_{}'.format(idx),
                                  cls_feat.detach().cpu())
            add_histogram_summary('retina_head_feat/reg_out_{}'.format(idx),
                                  reg_feat.detach().cpu())
            # The towers are shared by all levels, so their parameters are
            # only logged once, for the first level.
            if idx == 0:
                for i, (cls_conv, reg_conv) in enumerate(
                        zip(self.cls_convs, self.reg_convs)):
                    add_histogram_summary(
                        'retina_head_param/cls_conv_{}'.format(i),
                        cls_conv.conv.weight.detach().cpu(),
                        is_param=True)
                    add_histogram_summary(
                        'retina_head_param/reg_conv_{}'.format(i),
                        reg_conv.conv.weight.detach().cpu(),
                        is_param=True)

                # ``.grad`` stays None until a backward pass has filled it,
                # so skip the gradient histograms instead of raising.
                grad_summaries = (
                    ('retina_head_param/cls_convs_grad',
                     self.cls_convs[-1].conv.weight.grad),
                    ('retina_head_param/reg_conv_grad',
                     self.reg_convs[-1].conv.weight.grad),
                )
                for tag, grad in grad_summaries:
                    if grad is not None:
                        add_histogram_summary(tag,
                                              grad.detach().cpu(),
                                              collect_type='none')

        cls_score = self.retina_cls(cls_feat)
        bbox_pred = self.retina_reg(reg_feat)
        return cls_score, bbox_pred
Exemplo n.º 8
0
    def forward_train(self,
                      img,
                      img_meta,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None):
        """Compute the training losses of the RPN, bbox and mask branches.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.

            img_meta (list[dict]): list of image info dict where each dict has:
                'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

            gt_bboxes (list[Tensor]): each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.

            gt_labels (list[Tensor]): class indices corresponding to each box

            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

            gt_masks (None | Tensor) : true segmentation masks for each box
                used if the architecture supports a segmentation task.

            proposals : override rpn proposals with custom proposals. Use when
                `with_rpn` is False.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        x = self.extract_feat(img)

        losses = dict()
        # Periodically log the first image with its ground truth, plus the
        # last element of the extracted features.
        if every_n_local_step(self.train_cfg.get('vis_every_n_iters', 2000)):
            add_image_summary(
                'image/origin',
                tensor2imgs(img, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])[0],
                gt_bboxes[0].cpu(),
                gt_labels[0].cpu())
            add_feature_summary('feature/x', x[-1].detach().cpu().numpy())

        # RPN forward and loss
        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                          self.train_cfg.rpn)
            rpn_losses = self.rpn_head.loss(
                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
            losses.update(rpn_losses)

            # Fall back to the test-time RPN config when the training config
            # has no 'rpn_proposal' entry.
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        else:
            # Without an RPN the caller-supplied proposals are used as-is.
            proposal_list = proposals

        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler, context=self)
            num_imgs = img.size(0)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            # Per-image counts of sampled negatives / positives, logged
            # below as batch means.
            num_bgs = []
            num_fgs = []
            for i in range(num_imgs):
                assign_result = bbox_assigner.assign(proposal_list[i],
                                                     gt_bboxes[i],
                                                     gt_bboxes_ignore[i],
                                                     gt_labels[i])
                # The sampler receives each level's features for image i
                # with the leading dimension kept (size 1).
                sampling_result = bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes[i],
                    gt_labels[i],
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results.append(sampling_result)
                num_fgs.append(sampling_result.pos_inds.shape[0])
                num_bgs.append(sampling_result.neg_inds.shape[0])
            add_summary(prefix="sample_fast_rcnn_targets",
                        num_fgs=np.mean(num_fgs), num_bgs=np.mean(num_bgs))

        # bbox head forward and loss
        if self.with_bbox:
            rois = bbox2roi([res.bboxes for res in sampling_results])
            # TODO: a more flexible way to decide which feature maps to use
            bbox_feats = self.bbox_roi_extractor(
                x[:self.bbox_roi_extractor.num_inputs], rois)
            if self.with_shared_head:
                bbox_feats = self.shared_head(bbox_feats)
            cls_score, bbox_pred = self.bbox_head(bbox_feats)

            bbox_targets = self.bbox_head.get_target(sampling_results,
                                                     gt_bboxes, gt_labels,
                                                     self.train_cfg.rcnn)
            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
                                            *bbox_targets)
            losses.update(loss_bbox)

        # mask head forward and loss
        if self.with_mask:
            if not self.share_roi_extractor:
                # Extract dedicated mask features for the positive RoIs only.
                pos_rois = bbox2roi(
                    [res.pos_bboxes for res in sampling_results])
                mask_feats = self.mask_roi_extractor(
                    x[:self.mask_roi_extractor.num_inputs], pos_rois)
                if self.with_shared_head:
                    mask_feats = self.shared_head(mask_feats)
            else:
                # Reuse the bbox features: build a flat 0/1 selector that is,
                # for each image, ones for its positives followed by zeros
                # for its negatives.
                # NOTE(review): this branch reads bbox_feats from the bbox
                # branch above, so it presumably requires with_bbox - confirm.
                pos_inds = []
                device = bbox_feats.device
                for res in sampling_results:
                    pos_inds.append(
                        torch.ones(
                            res.pos_bboxes.shape[0],
                            device=device,
                            dtype=torch.uint8))
                    pos_inds.append(
                        torch.zeros(
                            res.neg_bboxes.shape[0],
                            device=device,
                            dtype=torch.uint8))
                pos_inds = torch.cat(pos_inds)
                mask_feats = bbox_feats[pos_inds]

            # Skip the mask loss entirely when no positive sample exists.
            if mask_feats.shape[0] > 0:
                mask_pred = self.mask_head(mask_feats)
                mask_targets = self.mask_head.get_target(
                    sampling_results, gt_masks, self.train_cfg.rcnn)
                pos_labels = torch.cat(
                    [res.pos_gt_labels for res in sampling_results])
                loss_mask = self.mask_head.loss(mask_pred, mask_targets,
                                                pos_labels)
                losses.update(loss_mask)

        return losses
Exemplo n.º 9
0
    def __call__(self, pred_hm, pred_wh, pred_centerness, heatmap, box_target,
                 centerness, wh_weight, hm_weight):
        """Compute the heatmap, box, centerness and merge losses.

        Args:
            pred_hm: tensor, (batch, 80, h, w).
            pred_wh: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
            pred_centerness: tensor or None, (batch, 1, h, w).
            heatmap: tensor, (batch, 80, h, w).
            box_target: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
            centerness: tensor or None, (batch, 1, h, w).
            wh_weight: tensor or None, (batch, 80, h, w).
            hm_weight: forwarded to ``ct_focal_loss`` as ``hm_weight`` when
                ``self.fovea_hm`` is false and ``self.ct_version`` is falsy;
                otherwise unused.

        Returns:
            hm_loss, wh_loss, centerness_loss, merge_loss.
        """
        if every_n_local_step(100):
            pred_hm_summary = torch.clamp(torch.sigmoid(pred_hm),
                                          min=1e-4,
                                          max=1 - 1e-4)
            gt_hm_summary = heatmap.clone()
            if self.fovea_hm:
                if not self.only_merge:
                    pred_ctn_summary = torch.clamp(
                        torch.sigmoid(pred_centerness), min=1e-4, max=1 - 1e-4)
                    add_feature_summary(
                        'centernet/centerness',
                        pred_ctn_summary.detach().cpu().numpy(),
                        type='f')
                    add_feature_summary(
                        'centernet/merge',
                        (pred_ctn_summary *
                         pred_hm_summary).detach().cpu().numpy(),
                        type='max')

                add_feature_summary('centernet/gt_centerness',
                                    centerness.detach().cpu().numpy(),
                                    type='f')
                add_feature_summary('centernet/gt_merge',
                                    (centerness *
                                     gt_hm_summary).detach().cpu().numpy(),
                                    type='max')

            add_feature_summary('centernet/heatmap',
                                pred_hm_summary.detach().cpu().numpy())
            add_feature_summary('centernet/gt_heatmap',
                                gt_hm_summary.detach().cpu().numpy())

        H, W = pred_hm.shape[2:]
        if not self.fovea_hm:
            # Note: sigmoid_ is applied to pred_hm in place.
            pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
            hm_weight = None if self.ct_version else hm_weight
            hm_loss = ct_focal_loss(pred_hm, heatmap,
                                    hm_weight=hm_weight) * self.hm_weight
            centerness_loss = hm_loss.new_tensor([0.])
            merge_loss = hm_loss.new_tensor([0.])
        else:
            # Negative heatmap entries are excluded from the loss.
            care_mask = (heatmap >= 0).float()
            avg_factor = torch.sum(heatmap > 0).float().item() + 1e-6
            if not self.only_merge:
                hm_loss = py_sigmoid_focal_loss(
                    pred_hm, heatmap, care_mask,
                    reduction='sum') / avg_factor * self.hm_weight

                pred_centerness = torch.clamp(torch.sigmoid(pred_centerness),
                                              min=1e-4,
                                              max=1 - 1e-4)
                centerness_loss = ct_focal_loss(
                    pred_centerness, centerness, gamma=2.) * self.ct_weight

                merge_loss = ct_focal_loss(
                    torch.clamp(torch.sigmoid(pred_hm) * pred_centerness,
                                min=1e-4,
                                max=1 - 1e-4),
                    heatmap * centerness,
                    weight=(heatmap >= 0).float()) * self.merge_weight
            else:
                hm_loss = pred_hm.new_tensor([0.])
                centerness_loss = pred_hm.new_tensor([0.])
                merge_loss = ct_focal_loss(
                    torch.clamp(torch.sigmoid(pred_hm), min=1e-4,
                                max=1 - 1e-4),
                    heatmap * centerness,
                    weight=(heatmap >= 0).float()) * self.merge_weight

        if not self.wh_agnostic:
            # Fold the per-class box channels into the batch dimension.
            pred_wh = pred_wh.view(pred_wh.size(0) * pred_hm.size(1), 4, H, W)
            box_target = box_target.view(
                box_target.size(0) * pred_hm.size(1), 4, H, W)
        mask = wh_weight.view(-1, H, W)
        avg_factor = mask.sum() + 1e-4

        # The grid of base locations is cached on the instance. Rebuild it
        # whenever the feature-map size differs from the cached grid, since
        # a stale grid would no longer line up with pred_wh.
        if (self.base_loc is None or H != self.base_loc.shape[1]
                or W != self.base_loc.shape[2]):
            base_step = self.down_ratio
            shifts_x = torch.arange(0, (W - 1) * base_step + 1,
                                    base_step,
                                    dtype=torch.float32,
                                    device=heatmap.device)
            shifts_y = torch.arange(0, (H - 1) * base_step + 1,
                                    base_step,
                                    dtype=torch.float32,
                                    device=heatmap.device)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            self.base_loc = torch.stack((shift_x, shift_y), dim=0)  # (2, h, w)

        # Channels 0-1 are subtracted from and channels 2-3 added to the
        # base location to form the predicted boxes.
        # (batch, h, w, 4)
        pred_boxes = torch.cat((self.base_loc - pred_wh[:, [0, 1]],
                                self.base_loc + pred_wh[:, [2, 3]]),
                               dim=1).permute(0, 2, 3, 1)
        # (batch, h, w, 4)
        boxes = box_target.permute(0, 2, 3, 1)
        wh_loss = giou_loss(pred_boxes, boxes, mask,
                            avg_factor=avg_factor) * self.giou_weight

        return hm_loss, wh_loss, centerness_loss, merge_loss
Exemplo n.º 10
0
    def forward(self, feats):
        """Upsample the last feature map and predict heatmap, wh, centerness.

        Args:
            feats: list(tensor).

        Returns:
            hm: tensor, (batch, 80, h, w).
            wh: tensor, (batch, 2, h, w).
            centerness: None, or the output of ``self.centerness`` when
                ``self.fovea_hm`` is set and ``self.only_merge`` is not.
        """
        x = feats[-1]
        for i, (deconv_layer, shortcut_layer) in enumerate(
                zip(self.deconv_layers, self.shortcut_layers)):
            x = deconv_layer(x)
            if self.use_shortcut:
                shortcut = shortcut_layer(feats[-i - 2])
                x = x + shortcut

        if not self.predict_together:
            hm = self.hm(x)
            wh = self.wh(x)
        else:
            # A single branch predicts groups of 5 channels: the first is
            # the heatmap, the remaining four are the box values.
            N, _, H, W = x.shape
            hmwh = self.hmwh(x).view(N, -1, 5, H, W).transpose(1, 2)
            hm = hmwh[:, 0]
            wh = hmwh[:, 1:5].transpose(1, 2).contiguous().view(N, -1, H, W)
        wh = wh.exp() if self.use_exp_wh else F.relu(wh)

        if self.wh_offset_base is not None:
            if isinstance(self.wh_offset_base, nn.Module):
                wh = self.wh_offset_base(wh)
            else:
                # Out-of-place on purpose: ``wh`` is the direct output of
                # exp/relu, which autograd keeps for the backward pass, so
                # it must not be modified in place.
                wh = wh * self.wh_offset_base

        if self.norm_wh:
            N, _, H, W = wh.shape
            wh = wh.view(N, -1, 4, H, W).transpose(1, 2)
            # Scale components 0/2 by the heatmap width and 1/3 by its height.
            wh[:, [0, 2]] = wh[:, [0, 2]] * hm.size(3)
            wh[:, [1, 3]] = wh[:, [1, 3]] * hm.size(2)
            wh = wh.transpose(1, 2).view(N, -1, H, W)

        if every_n_local_step(100):
            add_histogram_summary('ct_head_feat/heatmap', hm.detach().cpu())
            add_histogram_summary('ct_head_feat/wh', wh.detach().cpu())

            if not self.predict_together:
                hm_summary = self.hm[-1]
                wh_summary = self.wh[-1]
                add_histogram_summary('ct_head_param/hm',
                                      hm_summary.weight.detach().cpu(),
                                      is_param=True)
                add_histogram_summary('ct_head_param/wh',
                                      wh_summary.weight.detach().cpu(),
                                      is_param=True)

                # ``.grad`` stays None until a backward pass has filled it,
                # so skip the gradient histograms instead of raising.
                grad_summaries = (
                    ('ct_head_param/hm_grad', hm_summary.weight.grad),
                    ('ct_head_param/wh_grad', wh_summary.weight.grad),
                )
                for tag, grad in grad_summaries:
                    if grad is not None:
                        add_histogram_summary(tag,
                                              grad.detach().cpu(),
                                              collect_type='none')
        centerness = None
        if self.fovea_hm and not self.only_merge:
            centerness = self.centerness(x)

        return hm, wh, centerness
Exemplo n.º 11
0
    def __call__(self, pred_hm, pred_wh, pred_reg_offset, heatmap, wh,
                 reg_mask, ind, reg_offset, center_location):
        """Compute the heatmap, size and offset losses.

        Args:
            pred_hm: tensor, (batch, 80, h, w).
            pred_wh: tensor, (batch, 2, h, w).
            pred_reg_offset: None or tensor, (batch, 2, h, w).
            heatmap: tensor, (batch, 80, h, w).
            wh: tensor, (batch, max_obj, 2).
            reg_mask: tensor, tensor <=> img, (batch, max_obj).
            ind: tensor, (batch, max_obj).
            reg_offset: tensor, (batch, max_obj, 2).
            center_location: tensor, (batch, max_obj, 2). Only useful when using GIOU.

        Returns:
            hm_loss, wh_loss, off_loss.
        """
        feat_h, feat_w = pred_hm.shape[2:]

        # Heatmap term; the sigmoid is applied to pred_hm in place.
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

        # (batch, 2, h, w) => (batch, max_obj, 2)
        gathered_wh = tranpose_and_gather_feat(pred_wh, ind)
        obj_mask = reg_mask.unsqueeze(2).expand_as(gathered_wh).float()
        normalizer = obj_mask.sum() + 1e-4

        if self.use_giou:
            half_pred = gathered_wh / 2.
            pred_boxes = torch.cat(
                (center_location - half_pred, center_location + half_pred),
                dim=2)

            # Target boxes are clipped to the feature-map extent.
            half_gt = wh / 2.
            gt_br = center_location + half_gt
            gt_br[..., 0] = gt_br[..., 0].clamp(max=feat_w - 1)
            gt_br[..., 1] = gt_br[..., 1].clamp(max=feat_h - 1)
            gt_tl = torch.clamp(center_location - half_gt, min=0)
            gt_boxes = torch.cat((gt_tl, gt_br), dim=2)

            wh_loss = giou_loss(pred_boxes, gt_boxes,
                                obj_mask[:, :, 0]) * self.giou_weight
        elif self.use_smooth_l1:
            wh_loss = smooth_l1_loss(
                gathered_wh, wh, obj_mask,
                avg_factor=normalizer) * self.wh_weight
        else:
            wh_loss = weighted_l1(
                gathered_wh, wh, obj_mask,
                avg_factor=normalizer) * self.wh_weight

        off_loss = hm_loss.new_tensor(0.)
        if self.use_reg_offset:
            gathered_off = tranpose_and_gather_feat(pred_reg_offset, ind)
            off_loss = weighted_l1(
                gathered_off, reg_offset, obj_mask,
                avg_factor=normalizer) * self.off_weight

            # Mean of the strictly positive target offsets.
            positive_offsets = reg_offset[reg_offset > 0]
            add_summary('centernet',
                        gt_reg_off=positive_offsets.mean().item())

        if every_n_local_step(500):
            add_feature_summary('centernet/heatmap',
                                pred_hm.detach().cpu().numpy())
            add_feature_summary('centernet/gt_heatmap',
                                heatmap.detach().cpu().numpy())
            if self.use_reg_offset:
                add_feature_summary('centernet/reg_offset',
                                    pred_reg_offset.detach().cpu().numpy())

        return hm_loss, wh_loss, off_loss
Exemplo n.º 12
0
    def forward(self, feats):
        """Upsample the last feature map and predict heatmap, wh and offset.

        Args:
            feats: list(tensor).

        Returns:
            hm: tensor, (batch, 80, h, w).
            wh: tensor, (batch, 2, h, w).
            reg: None or tensor, (batch, 2, h, w).
        """

        x = feats[-1]
        if not self.use_dla:
            for i, (deconv_layer, shortcut_layer) in enumerate(
                    zip(self.deconv_layers, self.shortcut_layers)):
                x = deconv_layer(x)

                if self.use_shortcut:
                    shortcut = shortcut_layer(feats[-i - 2])
                    if self.neg_shortcut:
                        # Keep only the non-positive part of the shortcut.
                        shortcut = -1 * F.relu(-1 * shortcut)
                    x = x + shortcut

                    if every_n_local_step(500):
                        add_feature_summary('ct_head_shortcut_{}'.format(i),
                                            shortcut.detach().cpu().numpy())

        if self.use_rep_points:
            share_feat = self.share_head_conv(x)
            # The wh branch output is split into two offset chunks and a
            # modulation mask, which then drive the heatmap branch.
            o1, o2, mask = torch.chunk(self.wh(share_feat), 3, dim=1)
            offset = torch.cat(
                (o1, o2),
                dim=1)  # 18 channels for example, h1, w1, h2, w2, ...
            mask = torch.sigmoid(mask)
            hm = self.hm(share_feat, offset, mask)

            # seems like the code below will not improve the mAP, but it suppose to.
            # De-interleave the (h, w) pairs so that o1 holds all h offsets
            # and o2 all w offsets.
            kernel_spatial = self.rep_points_kernel**2
            o1, o2 = torch.chunk(offset.permute(0, 2, 3, 1).contiguous().view(
                -1, kernel_spatial, 2).transpose(1, 2).contiguous().view(
                    offset.shape[0], *offset.shape[2:],
                    kernel_spatial * 2).permute(0, 3, 1, 2),
                                 2,
                                 dim=1)

            if every_n_local_step(100):
                for i in range(offset.shape[1]):
                    add_histogram_summary('ct_rep_points_{}'.format(i),
                                          offset[:, [i]].detach().cpu())

            # Base positions of the kernel sampling grid, relative to its
            # centre, laid out to match the channel order of o1 / o2.
            radius = (self.rep_points_kernel - 1) // 2
            h_base = hm.new_tensor(list(range(-radius, radius + 1)))
            h_base = torch.stack(
                [h_base for _ in range(self.rep_points_kernel)],
                dim=1).view(1, kernel_spatial, 1, 1)
            w_base = hm.new_tensor(list(range(-radius, radius + 1)))[None]
            w_base = torch.cat([w_base for _ in range(self.rep_points_kernel)],
                               dim=0).view(1, kernel_spatial, 1, 1)

            # Width / height are the extents spanned by the sampled points.
            h_loc, w_loc = o1 + h_base, o2 + w_base
            wh = torch.cat([
                w_loc.max(1, keepdim=True)[0] - w_loc.min(1, keepdim=True)[0],
                h_loc.max(1, keepdim=True)[0] - h_loc.min(1, keepdim=True)[0]
            ],
                           dim=1)
        else:
            hm = self.hm(x)
            wh = self.wh(x)
        reg = self.reg(x) if self.use_reg_offset else None
        if self.use_exp_wh:
            wh = wh.exp()

        if every_n_local_step(500):
            add_histogram_summary('ct_head_feat/heatmap', hm.detach().cpu())
            add_histogram_summary('ct_head_feat/wh', wh.detach().cpu())
            if self.use_reg_offset:
                add_histogram_summary('ct_head_feat/reg', reg.detach().cpu())

            # Pick the modules whose weights are logged; the structure of
            # the branches depends on the head configuration.
            if self.use_rep_points:
                hm_summary, wh_summary = self.hm, self.wh
            elif self.use_exp_hm:
                hm_summary, wh_summary = self.hm[-1].conv, self.wh[-1]
            else:
                hm_summary, wh_summary = self.hm[-1], self.wh[-1]

            add_histogram_summary('ct_head_param/hm',
                                  hm_summary.weight.detach().cpu(),
                                  is_param=True)
            add_histogram_summary('ct_head_param/wh',
                                  wh_summary.weight.detach().cpu(),
                                  is_param=True)
            if self.use_reg_offset:
                add_histogram_summary('ct_head_param/reg',
                                      self.reg[-1].weight.detach().cpu(),
                                      is_param=True)

            # ``.grad`` stays None until a backward pass has filled it, so
            # skip the gradient histograms instead of raising.
            grad_summaries = (
                ('ct_head_param/hm_grad', hm_summary.weight.grad),
                ('ct_head_param/wh_grad', wh_summary.weight.grad),
            )
            for tag, grad in grad_summaries:
                if grad is not None:
                    add_histogram_summary(tag,
                                          grad.detach().cpu(),
                                          collect_type='none')

        return hm, wh, reg
Exemplo n.º 13
0
    def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind,
                 center_location):
        """Compute the heatmap loss and the width/height loss over all levels.

        Args:
            pred_hm: list(tensor), one tensor per level, (batch, 80, h, w).
            pred_wh: list(tensor), one tensor per level, (batch, 2, h, w).
            heatmap: tensor, (batch, 80, h*w for all levels).
            wh: tensor, (batch, max_obj*level_num, 2).
            reg_mask: tensor, (batch, max_obj*level_num).
            ind: tensor, (batch, max_obj*level_num).
            center_location: tensor or None, (batch, max_obj*level_num, 2).
                Only read when ``self.use_giou`` is set.

        Returns:
            tuple: ``(hm_loss, wh_loss)``. ``hm_loss`` is the focal loss
            multiplied by ``self.hm_weight``. ``wh_loss`` is the GIoU loss
            (not multiplied by ``self.wh_weight``) when ``self.use_giou`` is
            set, otherwise the smooth-L1 or weighted-L1 loss multiplied by
            ``self.wh_weight``.
        """
        if every_n_local_step(500):
            for level_idx, level_hm in enumerate(pred_hm):
                tag = 'centernet_heatmap_lv{}'.format(level_idx)
                probs = level_hm.clone().detach().sigmoid_()
                add_feature_summary(tag, probs.cpu().numpy())

        # Spatial size of the first level; only used to clip the GIoU target
        # boxes below. Must be read before `pred_hm` is flattened.
        feat_h, feat_w = pred_hm[0].shape[2:]
        num_levels = len(pred_hm)

        # Flatten each level to (batch, classes, h*w) and join the levels on
        # the last axis so the prediction lines up with `heatmap`.
        flat_hm = torch.cat(
            [lvl_hm.view(lvl_hm.shape[0], lvl_hm.shape[1], -1)
             for lvl_hm in pred_hm],
            dim=-1)
        flat_hm = torch.clamp(flat_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(flat_hm, heatmap, self.gamma) * self.hm_weight

        # (batch, 2, h, w) for all levels => (batch, max_obj*level_num, 2)
        ind_per_level = torch.chunk(ind, num_levels, dim=1)
        gathered_wh = [
            tranpose_and_gather_feat(level_wh, level_ind)
            for level_wh, level_ind in zip(pred_wh, ind_per_level)
        ]
        pred_wh_pruned = torch.cat(gathered_wh, dim=1)

        # Per-coordinate validity mask; the epsilon keeps the averaging
        # factor non-zero when no object is present.
        mask = reg_mask[:, :, None].expand_as(pred_wh_pruned).float()
        avg_factor = mask.sum() + 1e-4

        if self.use_giou:
            half_pred = pred_wh_pruned / 2.
            pred_boxes = torch.cat(
                (center_location - half_pred, center_location + half_pred),
                dim=2)

            # Target boxes are clipped to the feature map of the first level.
            half_target = wh / 2.
            box_tl = torch.clamp(center_location - half_target, min=0)
            box_br = center_location + half_target
            box_br[..., 0] = box_br[..., 0].clamp(max=feat_w - 1)
            box_br[..., 1] = box_br[..., 1].clamp(max=feat_h - 1)
            boxes = torch.cat((box_tl, box_br), dim=2)

            # Widen the (.., 2) mask to cover the four box coordinates.
            wh_loss = giou_loss(pred_boxes, boxes, mask.repeat(1, 1, 2))
            return hm_loss, wh_loss

        l1_variant = smooth_l1_loss if self.use_smooth_l1 else weighted_l1
        wh_loss = l1_variant(
            pred_wh_pruned, wh, mask, avg_factor=avg_factor) * self.wh_weight
        return hm_loss, wh_loss
Exemplo n.º 14
0
    def target_single_image(self, gt_boxes, gt_labels, feat_shapes,
                            obj_sizes_of_interest):
        """Build the multi-level training targets for a single image.

        A ground-truth box is handled by every level whose size range (a row
        of ``obj_sizes_of_interest``) contains the longer side of the box.
        For each such level the box is scaled by the level stride, clipped to
        the feature map, and written into that level's heatmap and
        regression buffers.

        Args:
            gt_boxes: tensor, tensor <=> img, (num_gt, 4), read as
                (x1, y1, x2, y2).
            gt_labels: tensor, tensor <=> img, (num_gt,). 1 is subtracted
                from a label to obtain its heatmap channel.
            feat_shapes: list(tuple). tuple <=> level, (h, w) of the level.
            obj_sizes_of_interest: tensor, (level_num, 2), inclusive
                (min, max) size handled by each level.

        Returns:
            all_heatmap: tensor, tensor <=> img, (80, h*w for all levels).
            all_wh: tensor, tensor <=> img, (max_obj*level_num, 2).
            all_reg_mask: uint8 tensor, tensor <=> img, (max_obj*level_num,).
            all_ind: long tensor, tensor <=> img, (max_obj*level_num,).
            all_center_location: tensor, tensor <=> img,
                (max_obj*level_num, 2) when ``self.use_giou`` is set;
                otherwise a list holding one ``None`` per level.
        """
        level_size = len(self.fpn_strides)
        all_heatmap, all_wh, all_reg_mask, all_ind, all_center_location = [], [], [], [], []

        # Longer side of every box, repeated once per level so it can be
        # compared against each level's size range in one shot.
        box_heights = gt_boxes[:, 3] - gt_boxes[:, 1]
        box_widths = gt_boxes[:, 2] - gt_boxes[:, 0]
        max_wh_target = torch.max(box_heights, box_widths).unsqueeze(
            -1).repeat(1, level_size)
        is_cared_in_the_level = \
            (max_wh_target >= obj_sizes_of_interest[:, 0]) & \
            (max_wh_target <= obj_sizes_of_interest[:, 1])  # (gt_num, level_num)

        cared_gt_num_per_level = {
            'num_lv{}'.format(lvl): is_cared_in_the_level[:, lvl].sum().item()
            for lvl in range(level_size)
        }
        add_summary('centernet', **cared_gt_num_per_level)

        for lvl in range(level_size):
            # get target for a single level of a single image.
            output_h, output_w = feat_shapes[lvl]
            heatmap = gt_boxes.new_zeros(
                (self.num_classes, output_h, output_w))
            wh = gt_boxes.new_zeros((self.max_objs, 2))
            reg_mask = gt_boxes.new_zeros((self.max_objs, ), dtype=torch.uint8)
            ind = gt_boxes.new_zeros((self.max_objs, ), dtype=torch.long)

            center_location = None
            if self.use_giou:
                center_location = gt_boxes.new_zeros((self.max_objs, 2))

            # Boolean-mask indexing returns copies, so the in-place scaling
            # below does not touch the caller's `gt_boxes`. The labels MUST
            # be filtered with the same mask: `k` below indexes the filtered
            # boxes, not the original ones.
            cared_in_lvl = is_cared_in_the_level[:, lvl]
            gt_boxes_in_lvl = gt_boxes[cared_in_lvl]
            gt_labels_in_lvl = gt_labels[cared_in_lvl]
            if gt_boxes_in_lvl.size(0) > 0:
                gt_boxes_in_lvl /= self.fpn_strides[lvl]
                gt_boxes_in_lvl[:, [0, 2]] = torch.clamp(
                    gt_boxes_in_lvl[:, [0, 2]], 0, output_w - 1)
                gt_boxes_in_lvl[:, [1, 3]] = torch.clamp(
                    gt_boxes_in_lvl[:, [1, 3]], 0, output_h - 1)
                hs = gt_boxes_in_lvl[:, 3] - gt_boxes_in_lvl[:, 1]
                ws = gt_boxes_in_lvl[:, 2] - gt_boxes_in_lvl[:, 0]

                if every_n_local_step(500):
                    add_histogram_summary('mlct_head_hs_lv{}'.format(lvl),
                                          hs.detach().cpu(),
                                          collect_type='none')
                    add_histogram_summary('mlct_head_ws_lv{}'.format(lvl),
                                          ws.detach().cpu(),
                                          collect_type='none')

                # The regression buffers only have `max_objs` slots; extra
                # boxes are dropped instead of indexing out of bounds.
                num_objs = min(gt_boxes_in_lvl.shape[0], self.max_objs)
                for k in range(num_objs):
                    cls_id = gt_labels_in_lvl[k] - 1
                    h, w = hs[k], ws[k]
                    # Boxes that collapse after clipping produce no target.
                    if h > 0 and w > 0:
                        radius = gaussian_radius((h.ceil(), w.ceil()))
                        radius = max(0, int(radius.item()))
                        center = gt_boxes.new_tensor([
                            (gt_boxes_in_lvl[k, 0] + gt_boxes_in_lvl[k, 2]) /
                            2,
                            (gt_boxes_in_lvl[k, 1] + gt_boxes_in_lvl[k, 3]) / 2
                        ])
                        # no peak will fall between pixels
                        ct_int = center.to(torch.int)
                        draw_umich_gaussian(heatmap[cls_id], ct_int, radius)
                        # directly predict the width and height
                        wh[k] = wh.new_tensor([1. * w, 1. * h])
                        # Flattened (row-major) index of the center pixel.
                        ind[k] = ct_int[1] * output_w + ct_int[0]
                        if self.use_giou:
                            center_location[k] = center
                        reg_mask[k] = 1

            all_heatmap.append(heatmap.view(heatmap.shape[0], -1))
            all_wh.append(wh)
            all_reg_mask.append(reg_mask)
            all_ind.append(ind)
            all_center_location.append(center_location)

        all_heatmap, all_reg_mask, all_ind = [
            torch.cat(x, dim=-1) for x in [all_heatmap, all_reg_mask, all_ind]
        ]
        all_wh = torch.cat(all_wh, dim=0)
        if self.use_giou:
            all_center_location = torch.cat(all_center_location, dim=0)

        return all_heatmap, all_wh, all_reg_mask, all_ind, all_center_location