Esempio n. 1
0
def _get_atss_threshold(gtbb, all_anchors, k):
    '''
    Compute the adaptive IoU threshold (ATSS) for one ground truth box.

    1. Choose the nearest k anchors from each level based on the L2 distance
    2. Put them together, calculate the IoU with ground truth
    3. Adaptive threshold = mean + standard deviation

    Args:
        gtbb: ground truth box tensor with 4 entries, unpacked as
            (cx, cy, w, h) and passed to bboxes_iou as a (1, 4) tensor.
        all_anchors (list): one 2-D tensor per level with 4 columns;
            columns 0 and 1 are read as the anchor center x and y.
        k (int): number of nearest anchors taken from every level.

    Returns:
        0-dim tensor holding mean + std of the candidate IoUs. With a single
        candidate the standard deviation is taken as 0.
    '''
    assert isinstance(all_anchors, list) and isinstance(k, int)
    cx, cy, _, _ = gtbb
    all_candidates = []
    for anchors in all_anchors:
        # the squared distance is enough to rank the anchors
        l2_2 = (cx - anchors[:, 0]).pow(2) + (cy - anchors[:, 1]).pow(2)
        _, idxs = torch.topk(l2_2, k, largest=False, sorted=False)
        all_candidates.append(anchors[idxs, :])
    all_candidates = torch.cat(all_candidates, dim=0)
    assert all_candidates.shape == (len(all_anchors) * k, 4)
    # reshape(-1) instead of squeeze(): squeeze() would turn the result into
    # a 0-dim tensor when there is exactly one candidate.
    ious = bboxes_iou(gtbb.view(1, 4), all_candidates,
                      xyxy=False).reshape(-1)
    assert ious.dim() == 1
    mean = ious.mean()
    # the unbiased std of a single sample is NaN; fall back to 0 in that case
    std = ious.std() if ious.numel() > 1 else torch.zeros_like(mean)
    atss_thres = mean + std
    return atss_thres
Esempio n. 2
0
    def forward(self, raw: dict, img_size, labels=None):
        '''
        Decode the raw outputs of one detection level and, if labels are
        given, compute the training loss of this level.

        Args:
            raw (dict): network outputs with keys 'bbox', 'conf' and 'class'.
                'bbox' must be (nB, nA, nH, nW, 4); 'conf' and 'class' are
                indexed with the same leading dimensions. 'class' is only
                read when self.n_cls > 0.
            img_size: not used by the currently implemented code paths.
            labels (list, optional): one ImageObjects per image in the batch.
                If None, only the predictions are returned.

        Returns:
            tuple: (preds, loss). preds is a dict of CPU tensors: 'bbox'
            (nB, nA*nH*nW, 4), 'class_idx' and 'score' (nB, nA*nH*nW).
            loss is None when labels is None, otherwise the summed loss
            divided by the batch size.

        Raises:
            NotImplementedError: for any bbox_format other than 'cxcywh', for
                sample_selection 'ATSS' (or any value other than 'best') and
                for loss_bbox 'GIoU' when labels are given.
        '''
        assert isinstance(raw, dict)
        t_bbox = raw['bbox']
        device = t_bbox.device
        nB = t_bbox.shape[0]  # batch size
        nA = self.num_anchors  # number of anchors
        nH, nW = t_bbox.shape[2:4]  # prediction grid size
        assert t_bbox.shape[1] == nA and t_bbox.shape[-1] == 4
        conf_logits = raw['conf']
        cls_logits = raw['class']

        if self.bbox_format != 'cxcywh':
            # 'cxcywhd' (boxes with an angle) and any other format are not
            # supported by this head yet
            raise NotImplementedError()

        # ----------------------- logits to prediction -----------------------
        p_bbox = t_bbox.detach().clone().contiguous()
        # bounding box
        if self.grid.shape[2:4] != (nH, nW):
            self._make_grid(nH, nW, device)
        self.grid = self.grid.to(device=device)
        p_bbox = torch.sigmoid(p_bbox)
        # x, y: offset in (-0.5, 1.5) around the grid cell, scaled to pixels
        p_bbox[..., 0:2] = (p_bbox[..., 0:2] * 2 - 0.5 +
                            self.grid) * self.stride
        # w, h: factor in (0, 4) applied to the anchor size
        anch_wh = self.anchors.view(1, nA, 1, 1, 2).to(device=device)
        p_bbox[..., 2:4] = (p_bbox[..., 2:4] * 2)**2 * anch_wh
        bb_param = p_bbox.shape[-1]
        p_bbox = p_bbox.view(nB, nA * nH * nW, bb_param).cpu()

        # Logistic activation for confidence score
        p_conf = torch.sigmoid(conf_logits.detach())
        # Logistic activation for categories
        if self.n_cls > 0:
            p_cls = torch.sigmoid(cls_logits.detach())
            cls_score, cls_idx = torch.max(p_cls, dim=-1, keepdim=True)
            confs = p_conf * cls_score
        else:
            # no class branch: every prediction gets class 0 and the score
            # is the objectness alone
            cls_idx = torch.zeros(nB, nA, nH, nW, dtype=torch.int64)
            confs = p_conf
        preds = {
            'bbox': p_bbox,
            'class_idx': cls_idx.view(nB, nA * nH * nW).cpu(),
            'score': confs.view(nB, nA * nH * nW).cpu(),
        }
        if labels is None:
            return preds, None

        if self.sample_selection == 'ATSS':
            # ATSS sample selection is not implemented for this head
            raise NotImplementedError()

        assert isinstance(labels, list)
        valid_gt_num = 0
        TargetConf = torch.zeros(nB, nA, nH, nW, 1)
        if self.conf_target == 'zero-one':
            IgnoredMask = torch.zeros(nB, nA, nH, nW, dtype=torch.bool)
        loss_xy = 0
        loss_wh = 0
        loss_cls = 0
        bce_logits = tnf.binary_cross_entropy_with_logits
        # traverse all images in a batch
        for b in range(nB):
            im_labels: ImageObjects = labels[b]
            im_labels.sanity_check()
            num_gt = len(im_labels)
            if num_gt == 0:
                # no ground truth
                continue
            gt_bboxes = im_labels.bboxes
            gt_cls_idxs = im_labels.cats
            assert gt_bboxes.shape[1] == 4  # TODO:

            for gt_bb, gt_cidx in zip(gt_bboxes, gt_cls_idxs):
                # --------------- find positive samples
                if self.sample_selection != 'best':
                    raise NotImplementedError()
                # compare shapes only: move the GT to the origin before
                # computing the IoU with the origin-centered anchors
                _gt_00wh = gt_bb.clone()
                _gt_00wh[0:2] = 0
                anchor_ious = bboxes_iou(_gt_00wh,
                                         self.anch_00wh_all,
                                         xyxy=False)
                anch_idx_all = torch.argmax(anchor_ious,
                                            dim=1).squeeze().item()
                if not (self.indices == anch_idx_all).any():
                    # this layer is not responsible for this GT
                    continue
                ta = anch_idx_all % nA
                ti = (gt_bb[0] / self.stride).long()  # horizontal
                tj = (gt_bb[1] / self.stride).long()  # vertical
                valid_gt_num += 1
                # positive sample is (ta, tj, ti)

                # loss for bounding box
                if self.loss_bbox == 'smooth_L1':
                    _t_bb = t_bbox[b, ta, tj, ti]
                    assert _t_bb.dim() == 1
                    # targets are the inverse of the decoding above, so they
                    # live in the sigmoid output range
                    _tgtxy = ((gt_bb[:2] / self.stride) % 1 + 0.5) / 2
                    _tgtxy = _tgtxy.to(device=device)
                    loss_xy = loss_xy + bce_logits(
                        _t_bb[:2], _tgtxy, reduction='sum')
                    _tgtwh = torch.sqrt(gt_bb[2:4] / self.anchors[ta, :]) / 2
                    _tgtwh = _tgtwh.to(device=device)
                    loss_wh = loss_wh + bce_logits(
                        _t_bb[2:4], _tgtwh, reduction='sum')
                    if _t_bb.shape[0] > 4:
                        raise NotImplementedError()
                elif self.loss_bbox == 'GIoU':
                    raise NotImplementedError()

                # loss for categories
                if self.n_cls > 0:
                    _t_cls = cls_logits[b, ta, tj, ti]
                    assert _t_cls.shape[-1] == self.n_cls
                    _tgt_cls = torch.zeros_like(_t_cls)
                    _tgt_cls[..., gt_cidx] = 1
                    # NOTE(review): uses the default 'mean' reduction while
                    # the other terms use 'sum' - confirm this is intended
                    loss_cls = loss_cls + bce_logits(_t_cls, _tgt_cls)

                # regression target for confidence score
                if self.conf_target == 'zero-one':
                    TargetConf[b, ta, tj, ti] = 1

            # loss for confidence score
            pred_ious = bboxes_iou(p_bbox[b], gt_bboxes, xyxy=False)
            iou_with_gt, _ = pred_ious.max(dim=1)
            if self.conf_target == 'IoU':
                TargetConf[b] = iou_with_gt.view(nA, nH, nW, 1)
            elif self.conf_target == 'zero-one':
                # predictions that already overlap a GT well enough are not
                # penalized as negatives
                IgnoredMask[b] = (iou_with_gt > self.negative_thres).view(
                    nA, nH, nW)

        # move the targets to the device of the logits
        TargetConf = TargetConf.to(device=device)
        ignored_num = 0
        if self.conf_target == 'IoU':
            loss_conf = bce_logits(conf_logits, TargetConf, reduction='sum')
        elif self.conf_target == 'zero-one':
            IgnoredMask = IgnoredMask.to(device=device)
            _pos_mask = TargetConf.squeeze(-1).bool()
            _penalty = _pos_mask | (~IgnoredMask)
            loss_conf = bce_logits(conf_logits[_penalty],
                                   TargetConf[_penalty],
                                   reduction='sum')
            ignored_num = (IgnoredMask & (~_pos_mask)).sum().cpu().item()
        else:
            raise NotImplementedError()

        loss = loss_xy + loss_wh + loss_conf + loss_cls
        loss = loss / nB

        # logging
        ngt = valid_gt_num + 1e-16  # avoid division by zero
        self.loss_str = f'yolo_{nH}x{nW} pos/ignore: {int(ngt)}/{ignored_num}: ' \
                        f'xy/gt {loss_xy/ngt:.3f}, wh/gt {loss_wh/ngt:.3f}, ' \
                        f'conf {loss_conf:.3f}, class {loss_cls:.3f}'
        self._assigned_num = valid_gt_num
        return preds, loss
Esempio n. 3
0
    def forward(self, raw: dict, img_size, labels=None):
        '''
        Decode anchor-relative predictions (inference) or compute the
        bbox + class loss of this level (training).

        Args:
            raw (dict): 'bbox' logits of shape
                (nB, nA, nH, nW, self.n_bbparam) and 'class' logits of shape
                (nB, nA, nH, nW, self.n_cls).
            img_size: (height, width) of the input image.
            labels (list, optional): one ImageObjects per image. If None the
                method runs in inference mode.

        Returns:
            tuple: (preds, None) in inference mode, where preds holds CPU
            tensors 'bbox', 'class_idx' and 'score'; (None, loss) in
            training mode, where loss is the summed loss divided by the
            batch size.
        '''
        step = self.stride
        height, width = img_size
        n_anch = self.num_anchors  # anchors per grid cell
        rows, cols = int(height / step), int(width / step)
        n_classes = self.n_cls

        bbox_raw = raw['bbox']
        class_raw = raw['class']
        batch = bbox_raw.shape[0]
        assert bbox_raw.shape == (batch, n_anch, rows, cols, self.n_bbparam)
        assert class_raw.shape == (batch, n_anch, rows, cols, n_classes)
        dev = bbox_raw.device

        def _anchor_grid(where):
            # anchor centers and sizes, broadcastable to (1, nA, nH, nW)
            centers_x = torch.arange(step / 2, width, step,
                                     device=where).view(1, 1, 1, cols)
            centers_y = torch.arange(step / 2, height, step,
                                     device=where).view(1, 1, rows, 1)
            sizes = self.anchor_wh.view(1, n_anch, 1, 1, 2).to(device=where)
            return centers_x, centers_y, sizes

        if labels is None:
            # ------------------------- inference -------------------------
            ax, ay, awh = _anchor_grid(dev)
            decoded = torch.empty_like(bbox_raw).contiguous()
            # the center is shifted by a fraction of the anchor size
            decoded[..., 0] = ax + bbox_raw[..., 0] * awh[..., 0]
            decoded[..., 1] = ay + bbox_raw[..., 1] * awh[..., 1]
            # width and height scale the anchor size exponentially
            decoded[..., 2:4] = torch.exp(bbox_raw[..., 2:4]) * awh
            decoded[..., 0:4].clamp_(min=1, max=max(img_size))
            if self.pred_bbox_format == 'cxcywhd':
                # angle in degrees, mapped to (-180, 180)
                decoded[..., 4] = torch.sigmoid(bbox_raw[..., 4]) * 360 - 180
            probs = torch.sigmoid(class_raw)
            best_score, best_cls = torch.max(probs, dim=-1)
            flat = n_anch * rows * cols
            result = {
                'bbox': decoded.view(batch, flat, self.n_bbparam).cpu(),
                'class_idx': best_cls.view(batch, flat).cpu(),
                'score': best_score.view(batch, flat).cpu(),
            }
            return result, None

        # --------------------------- training ---------------------------
        # targets are assembled on the CPU and moved to the logits' device
        ax, ay, awh = _anchor_grid(torch.device('cpu'))
        full = (n_anch, rows, cols)
        anchor_boxes = torch.cat([
            ax.view(1, 1, cols, 1).expand(*full, 1),
            ay.view(1, rows, 1, 1).expand(*full, 1),
            awh.view(n_anch, 1, 1, 2).expand(*full, 2),
        ], dim=-1)

        bbox_loss_sum = 0
        cls_loss_sum = 0
        pos_count = 0
        sample_count = 0
        for b in range(batch):
            objs = labels[b]
            assert isinstance(objs, ImageObjects)
            objs.sanity_check()
            assert self.pred_bbox_format == objs._bb_format

            if len(objs) == 0:
                # image without objects: every class target is zero
                empty_target = torch.zeros(n_anch, rows, cols, n_classes,
                                           device=dev)
                cls_loss_sum = cls_loss_sum + bce_w_logits(
                    class_raw[b], empty_target, reduction='sum')
                continue

            gts = objs.bboxes
            overlap = bboxes_iou(anchor_boxes.view(-1, 4), gts[:, :4],
                                 xyxy=False)
            # best matching ground truth for every anchor
            best_iou, best_gt = overlap.max(dim=1)
            best_iou = best_iou.view(*full)
            best_gt = best_gt.view(*full)
            is_pos = best_iou > self.positive_thres
            is_neg = best_iou < self.negative_thres
            n_pos = is_pos.sum()
            pos_count += n_pos
            sample_count += n_anch * rows * cols

            # bbox target, relative to the anchor each GT is matched with
            matched = gts[best_gt, :]
            box_target = torch.zeros(n_anch, rows, cols, 4)
            box_target[..., 0:2] = (
                matched[..., 0:2] - anchor_boxes[..., 0:2]
            ) / anchor_boxes[..., 2:4]
            box_target[..., 2:4] = torch.log(
                matched[..., 2:4] / anchor_boxes[..., 2:4] + 1e-8)

            # class target
            cls_target = torch.zeros(n_anch, rows, cols, n_classes)
            cls_target[is_pos, objs.cats[best_gt[is_pos]]] = 1

            # only penalize class predictions that are not good enough yet:
            # positives below p=0.95 and negatives above p=0.01
            logits_cpu = class_raw[b].detach().cpu().squeeze(-1)
            high_enough = np.log(0.95 / (1 - 0.95))
            low_enough = np.log(0.01 / (1 - 0.01))
            too_low = is_pos & (logits_cpu < high_enough)
            too_high = is_neg & (logits_cpu > low_enough)
            penalized = too_low | too_high

            if self.pred_bbox_format == 'cxcywhd':
                # the angle loss works in radians
                angle_target = (matched[..., 4] / 180 * np.pi).to(dev)
            box_target = box_target.to(dev)
            cls_target = cls_target.to(dev)

            # bbox loss, positives only
            if n_pos > 0:
                pos_pred = bbox_raw[b][is_pos]
                img_box_loss = fvcore.nn.smooth_l1_loss(pos_pred[:, 0:4],
                                                        box_target[is_pos, :],
                                                        beta=0.1,
                                                        reduction='sum')
                if self.pred_bbox_format == 'cxcywhd':
                    angle_pred = torch.sigmoid(
                        pos_pred[:, 4]) * 2 * np.pi - np.pi
                    img_box_loss = img_box_loss + self.loss_angle(
                        angle_pred, angle_target[is_pos])
                bbox_loss_sum = bbox_loss_sum + img_box_loss

            # class loss on the penalized samples
            img_cls_loss = bce_w_logits(class_raw[b, penalized],
                                        cls_target[penalized],
                                        reduction='sum')
            cls_loss_sum = cls_loss_sum + img_cls_loss

        loss = (bbox_loss_sum + cls_loss_sum) / batch

        # logging
        self.loss_str = f'level_{rows}x{cols} pos {pos_count}/{sample_count}: ' \
                        f'xywh {bbox_loss_sum:.3f}, class {cls_loss_sum:.3f}'
        return None, loss
Esempio n. 4
0
    def forward(self, raw: dict, img_size, labels=None):
        '''
        Decode the raw outputs of one detection level and, if labels are
        given, compute the training loss of this level.

        Args:
            raw (dict): network outputs with keys 'bbox', 'conf' and 'class'.
                'bbox' must be (nB, nA, nH, nW, 4); 'conf' and 'class' are
                indexed with the same leading dimensions. 'class' is only
                read when self.n_cls > 0.
            img_size: two numbers whose product is used as the image area
                when weighting the bbox loss.
            labels (list, optional): one labels object per image, providing
                sanity_check(), len(), .bboxes and .cats. If None, only the
                predictions are returned.

        Returns:
            tuple: (preds, loss). preds is a dict of CPU tensors: 'bbox'
            (nB, nA*nH*nW, 4), 'class_idx' and 'score' (nB, nA*nH*nW).
            loss is None when labels is None, otherwise the summed loss
            divided by the batch size.
        '''
        assert isinstance(raw, dict)
        t_xywh = raw['bbox']
        device = t_xywh.device
        nB = t_xywh.shape[0]  # batch size
        nA = self.num_anchors  # number of anchors
        nH, nW = t_xywh.shape[2:4]  # prediction grid size
        assert t_xywh.shape[1] == nA and t_xywh.shape[-1] == 4
        conf_logits = raw['conf']
        cls_logits = raw['class']

        # ----------------------- logits to prediction -----------------------
        p_xywh = t_xywh.detach().clone().contiguous()
        # sigmoid activation for xy, obj_conf
        y_ = torch.arange(nH, dtype=torch.float,
                          device=device).view(1, 1, nH, 1)
        x_ = torch.arange(nW, dtype=torch.float,
                          device=device).view(1, 1, 1, nW)
        p_xywh[..., 0] = (torch.sigmoid(p_xywh[..., 0]) + x_) * self.stride
        p_xywh[..., 1] = (torch.sigmoid(p_xywh[..., 1]) + y_) * self.stride
        # w, h
        anch_wh = self.anchors.view(1, nA, 1, 1, 2).to(device=device)
        p_xywh[..., 2:4] = torch.exp(p_xywh[..., 2:4]) * anch_wh
        p_xywh = p_xywh.view(nB, nA * nH * nW, 4).cpu()

        # Logistic activation for confidence score
        p_conf = torch.sigmoid(conf_logits.detach())
        # Logistic activation for categories
        if self.n_cls > 0:
            p_cls = torch.sigmoid(cls_logits.detach())
            cls_score, cls_idx = torch.max(p_cls, dim=-1, keepdim=True)
            confs = p_conf * cls_score
        else:
            cls_idx = torch.zeros(nB, nA, nH, nW, dtype=torch.int64)
            confs = p_conf
        preds = {
            'bbox': p_xywh,
            'class_idx': cls_idx.view(nB, nA * nH * nW).cpu(),
            'score': confs.view(nB, nA * nH * nW).cpu(),
        }
        if labels is None:
            return preds, None

        assert isinstance(labels, list)
        valid_gt_num = 0
        # targets are assembled on the CPU and moved to the device afterwards
        gt_mask = torch.zeros(nB, nA, nH, nW, dtype=torch.bool)
        conf_loss_mask = torch.ones(nB, nA, nH, nW, dtype=torch.bool)
        weighted = torch.zeros(nB, nA, nH, nW)
        tgt_xywh = torch.zeros(nB, nA, nH, nW, 4)
        tgt_conf = torch.zeros(nB, nA, nH, nW, 1)
        tgt_cls = torch.zeros(nB, nA, nH, nW, self.n_cls)
        # traverse all images in a batch
        for b in range(nB):
            im_labels = labels[b]
            im_labels.sanity_check()
            num_gt = len(im_labels)
            if num_gt == 0:
                # no ground truth
                continue
            gt_bboxes = im_labels.bboxes
            gt_cls_idx = im_labels.cats
            assert gt_bboxes.shape[1] == 4

            # calculate iou between truth and reference anchors
            gt_00wh = torch.zeros(num_gt, 4)
            gt_00wh[:, 2:4] = gt_bboxes[:, 2:4]
            anchor_ious = bboxes_iou(gt_00wh, self.anch_00wh_all, xyxy=False)
            best_n_all = torch.argmax(anchor_ious, dim=1)
            best_n = best_n_all % self.num_anchors

            # keep the ground truths whose best anchor belongs to this level
            valid_mask = torch.zeros(num_gt, dtype=torch.bool)
            for ind in self.indices:
                valid_mask = (valid_mask | (best_n_all == ind))
            # int(...) keeps valid_gt_num (and self._assigned_num) a plain
            # Python int instead of a 0-dim tensor
            num_valid = int(valid_mask.sum())
            if num_valid == 0:
                # no anchor is responsible for any ground truth
                continue
            valid_gt_num += num_valid

            pred_ious = bboxes_iou(p_xywh[b], gt_bboxes, xyxy=False)
            iou_with_gt, _ = pred_ious.max(dim=1)
            # ignore the conf of a pred BB if its IoU with a gt reaches
            # self.ignore_thre
            conf_loss_mask[b] = (iou_with_gt < self.ignore_thre).view(
                nA, nH, nW)
            # conf_loss_mask = 1 -> give penalty

            gt_bboxes = gt_bboxes[valid_mask, :]
            grid_tx = gt_bboxes[:, 0] / self.stride
            grid_ty = gt_bboxes[:, 1] / self.stride
            # clamp so that a center on the right/bottom border stays inside
            ti, tj = grid_tx.long().clamp(max=nW -
                                          1), grid_ty.long().clamp(max=nH - 1)
            tn = best_n[valid_mask]  # target anchor box number

            conf_loss_mask[b, tn, tj, ti] = 1
            gt_mask[b, tn, tj, ti] = 1
            tgt_xywh[b, tn, tj, ti, 0] = grid_tx - grid_tx.floor()
            tgt_xywh[b, tn, tj, ti, 1] = grid_ty - grid_ty.floor()
            tgt_xywh[b, tn, tj, ti,
                     2] = torch.log(gt_bboxes[:, 2] / self.anchors[tn, 0] +
                                    1e-8)
            tgt_xywh[b, tn, tj, ti,
                     3] = torch.log(gt_bboxes[:, 3] / self.anchors[tn, 1] +
                                    1e-8)
            tgt_conf[b, tn, tj, ti] = 1  # objectness confidence
            if self.n_cls > 0:
                tgt_cls[b, tn, tj, ti, gt_cls_idx[valid_mask]] = 1
            # smaller objects have higher losses
            img_area = img_size[0] * img_size[1]
            weighted[b, tn, tj,
                     ti] = 2 - gt_bboxes[:, 2] * gt_bboxes[:, 3] / img_area

        # move the targets to the device of the logits
        gt_mask = gt_mask.to(device=device)
        conf_loss_mask = conf_loss_mask.to(device=device)
        weighted = weighted.unsqueeze(-1).to(device=device)
        tgt_xywh = tgt_xywh.to(device=device)
        tgt_conf = tgt_conf.to(device=device)
        tgt_cls = tgt_cls.to(device=device)

        bce_logits = tnf.binary_cross_entropy_with_logits
        # weighted BCE loss for x,y
        loss_xy = bce_logits(t_xywh[..., 0:2][gt_mask],
                             tgt_xywh[..., 0:2][gt_mask],
                             weight=weighted[gt_mask],
                             reduction='sum')
        # weighted squared error for w,h
        loss_wh = (t_xywh[..., 2:4][gt_mask] -
                   tgt_xywh[..., 2:4][gt_mask]).pow(2)
        loss_wh = 0.5 * (weighted[gt_mask] * loss_wh).sum()
        loss_conf = bce_logits(conf_logits[conf_loss_mask],
                               tgt_conf[conf_loss_mask],
                               reduction='sum')
        if self.n_cls > 0:
            loss_cls = bce_logits(cls_logits[gt_mask],
                                  tgt_cls[gt_mask],
                                  reduction='sum')
        else:
            loss_cls = 0
        loss = loss_xy + loss_wh + loss_conf + loss_cls
        loss = loss / nB

        # logging
        ngt = valid_gt_num + 1e-16  # avoid division by zero
        self.loss_str = f'yolo_{nH}x{nW} total {int(ngt)} objects: ' \
                        f'xy/gt {loss_xy/ngt:.3f}, wh/gt {loss_wh/ngt:.3f}, ' \
                        f'conf {loss_conf:.3f}, class {loss_cls:.3f}'
        self._assigned_num = valid_gt_num
        return preds, loss
Esempio n. 5
0
    def forward(self, raw, img_size, labels=None):
        '''
        Decode per-location left/top/right/bottom predictions of one level
        and, if labels are given, compute the training loss of this level.

        Args:
            raw (dict): network outputs. 'bbox' must be (nB, nH, nW, 4),
                'conf' (nB, nH, nW, 1) and 'class' (nB, nH, nW, n_cls),
                where nH, nW = image size / self.stride.
            img_size: (height, width) of the input image.
            labels (list, optional): one ImageObjects per image. If None,
                only the predictions are returned.

        Returns:
            tuple: (preds, loss). preds is a dict with 'bbox' (nB, nH*nW, 4),
            'class_idx' and 'score' (nB, nH*nW). loss is None when labels is
            None, otherwise the summed loss divided by the batch size.

        Raises:
            Exception: if self.ltrb_setting starts with neither 'exp' nor
                'relu'.
            NotImplementedError: in training, if self.ltrb_setting does not
                start with 'exp' or ends with neither 'sl1' nor 'l2'.
        '''
        stride = self.stride
        img_h, img_w = img_size
        nH, nW = int(img_h / stride), int(img_w / stride)
        nCls = self.n_cls
        assert isinstance(raw, dict)

        t_ltrb = raw['bbox']
        conf_logits = raw['conf']
        cls_logits = raw['class']
        nB = t_ltrb.shape[0]  # batch size
        assert t_ltrb.shape == (nB, nH, nW, 4)
        assert conf_logits.shape == (nB, nH, nW, 1)
        assert cls_logits.shape == (nB, nH, nW, nCls)
        device = t_ltrb.device

        # activation function for left, top, right, bottom
        if self.ltrb_setting.startswith('exp'):
            p_ltrb = torch.exp(t_ltrb.detach()) * stride
        elif self.ltrb_setting.startswith('relu'):
            # NOTE(review): 'relu' is accepted here, but the loss computation
            # below raises NotImplementedError for it
            p_ltrb = tnf.relu(t_ltrb.detach()) * stride
        else:
            raise Exception('Unknown ltrb_setting')

        # ---------------------------- testing ----------------------------
        # Force the prediction to be in the image
        # presumably _ltrb_to turns the per-location distances into boxes of
        # the requested format - verify against its definition
        p_xyxy = _ltrb_to(p_ltrb, nH, nW, stride, 'x1y1x2y2')
        p_xyxy[..., 0].clamp_(min=0, max=img_w)
        p_xyxy[..., 1].clamp_(min=0, max=img_h)
        p_xyxy[..., 2].clamp_(min=0, max=img_w)
        p_xyxy[..., 3].clamp_(min=0, max=img_h)
        p_xywh = _xyxy_to_xywh(p_xyxy)
        # Logistic activation for 'centerness'
        p_conf = torch.sigmoid(conf_logits.detach())
        # Logistic activation for categories
        p_cls = torch.sigmoid(cls_logits.detach())
        cls_score, cls_idx = torch.max(p_cls, dim=3, keepdim=True)
        # final score: geometric mean of confidence and best class score
        confs = torch.sqrt(p_conf * cls_score)
        preds = {
            'bbox': p_xywh.view(nB, nH * nW, 4),
            'class_idx': cls_idx.view(nB, nH * nW),
            'score': confs.view(nB, nH * nW),
        }
        # Return the final predictions when testing
        if labels is None:
            return preds, None

        # unclamped predictions, only used to find the ignored samples
        p_xywh = _ltrb_to(p_ltrb, nH, nW, stride, 'cxcywh').cpu()
        # ------------------------------ training ------------------------------
        assert isinstance(labels, list)
        # Build x,y meshgrid of the cell centers, each of size (nH,nW)
        x_ = torch.linspace(0, img_w, steps=nW + 1)[:-1] + 0.5 * stride
        y_ = torch.linspace(0, img_h, steps=nH + 1)[:-1] + 0.5 * stride
        gy, gx = torch.meshgrid(y_, x_)
        # Initialize the prediction target of the batch
        # positive: at the center region and max(tgt_ltrb) in (min, max)
        # ignored: predicted bbox IoU with GT > self.ignore_thre
        # Conf: positive or (not ignored)
        # LTRB, Ctr, CLs: positive
        PositiveMask = torch.zeros(nB, nH, nW, dtype=torch.bool)
        IgnoredMask = torch.zeros(nB, nH, nW, dtype=torch.bool)
        TargetConf = torch.zeros(nB, nH, nW, 1)
        TargetLTRB = torch.zeros(nB, nH, nW, 4)
        # TargetCtr = torch.zeros(nB, nH, nW, 1)
        TargetCls = torch.zeros(nB, nH, nW,
                                self.n_cls) if self.n_cls > 0 else None
        assert self.n_cls > 0
        # traverse all images in a batch
        for b in range(nB):
            im_labels = labels[b]
            assert isinstance(im_labels, ImageObjects)
            if len(im_labels) == 0:
                continue
            im_labels.sanity_check()

            # sort the ground truths by area, largest first
            gt_xywh = im_labels.bboxes
            areas = gt_xywh[:, 2] * gt_xywh[:, 3]
            lg2sml_idx = torch.argsort(areas, descending=True)
            gt_xywh = gt_xywh[lg2sml_idx, :]
            gt_cls_idx = im_labels.cats[lg2sml_idx]

            # ignore the conf of a pred BB if it matches a gt more than self.ignore_thre
            ious = bboxes_iou(p_xywh[b].view(-1, 4), gt_xywh, xyxy=False)
            iou_with_gt, _ = torch.max(ious, dim=1)
            IgnoredMask[b] = (iou_with_gt > self.ignore_thre).view(nH, nW)

            # Since the gt labels are sorted by area (descending), \
            # small object targets are set later so they get higher priority
            for bb, cidx in zip(gt_xywh, gt_cls_idx):
                # Convert cxcywh to x1y1x2y2
                Tx1, Ty1, Tx2, Ty2 = _xywh_to_xyxy(bb, cr=1)
                # regression target at each location
                tgt_l, tgt_t, tgt_r, tgt_b = gx - Tx1, gy - Ty1, Tx2 - gx, Ty2 - gy
                # stacking them together, we get target for ltrb
                tgt_ltrb = torch.stack([tgt_l, tgt_t, tgt_r, tgt_b], dim=-1)
                assert tgt_ltrb.shape == (nH, nW, 4)
                # full bounding box mask
                # (only referenced by the commented-out center score code)
                bbox_mask = torch.prod((tgt_ltrb > 0), dim=-1).bool()
                # Find positive samples for this bounding box
                # 1. the center part of the bounding box
                Cx1, Cy1, Cx2, Cy2 = _xywh_to_xyxy(bb, cr=self.center_region)
                center_mask = (gx > Cx1) & (gx < Cx2) & (gy > Cy1) & (gy < Cy2)
                # 2. max predicted ltrb within the range
                max_tgt_ltrb, _ = torch.max(tgt_ltrb, dim=-1)
                anch_mask = (self.anch_min < max_tgt_ltrb) & (max_tgt_ltrb <
                                                              self.anch_max)
                # 3. positive samples must satisfy both 1 and 2
                pos_mask = center_mask & anch_mask
                if not pos_mask.any():
                    continue
                # set target for ltrb
                TargetLTRB[b, pos_mask, :] = tgt_ltrb[pos_mask, :]
                # compute target for center score
                # tgt_center = torch.min(tgt_l, tgt_r) / torch.max(tgt_l, tgt_r) * \
                #             torch.min(tgt_t, tgt_b) / torch.max(tgt_t, tgt_b)
                # tgt_center.mul_(bbox_mask).sqrt_()
                # import matplotlib.pyplot as plt
                # plt.imshow(tgt_center.numpy(), cmap='gray'); plt.show()
                # assert TargetCtr.shape[-1] == 1
                # TargetCtr[b, bbox_mask] = tgt_center[bbox_mask].unsqueeze(-1)
                # the target for confidence score is 1
                TargetConf[b, pos_mask] = 1
                # set target for category classification
                # NOTE(review): the class of an earlier (larger) object is not
                # cleared at locations a later object overwrites, so such
                # locations can end up with two classes set - confirm intent
                _Hidx, _Widx = pos_mask.nonzero(as_tuple=True)
                TargetCls[b, _Hidx, _Widx, cidx] = 1
                # Update the batch positive sample mask
                PositiveMask[b] = PositiveMask[b] | pos_mask

        # Transfer targets to the device of the logits
        PositiveMask = PositiveMask.to(device=device)
        IgnoredMask = IgnoredMask.to(device=device)
        TargetConf = TargetConf.to(device=device)
        TargetLTRB = TargetLTRB.to(device=device)
        # TargetCtr = TargetCtr.to(device=device)
        TargetCls = TargetCls.to(device=device) if self.n_cls > 0 else None

        # Compute loss
        pLTRB, tgtLTRB = t_ltrb[PositiveMask], TargetLTRB[PositiveMask]
        assert (tgtLTRB > 0).all()  # Sanity check
        if self.ltrb_setting.startswith('exp'):
            # inverse of the 'exp' activation, so the loss is on the logits
            tgtLTRB = torch.log(tgtLTRB / stride)
        else:
            raise NotImplementedError()
        if self.ltrb_setting.endswith('sl1'):
            # smooth L1 loss for l,t,r,b
            loss_bbox = lossLib.smooth_L1_loss(pLTRB,
                                               tgtLTRB,
                                               beta=0.2,
                                               reduction='sum')
        elif self.ltrb_setting.endswith('l2'):
            loss_bbox = tnf.mse_loss(pLTRB, tgtLTRB, reduction='sum')
        else:
            raise NotImplementedError()
        bce_logits = tnf.binary_cross_entropy_with_logits
        # Binary cross entropy for confidence score
        # penalized locations: positives plus everything that is not ignored
        _penalty = PositiveMask | (~IgnoredMask)
        _pConf, _tgtConf = conf_logits[_penalty], TargetConf[_penalty]
        loss_conf = bce_logits(_pConf, _tgtConf, reduction='sum')
        # Binary cross entropy for category classification
        _pCls, _tgtCls = cls_logits[PositiveMask], TargetCls[PositiveMask]
        loss_cls = bce_logits(_pCls, _tgtCls, reduction='sum')
        loss = loss_bbox + loss_conf + loss_cls
        loss = loss / nB

        # logging
        pos_num = PositiveMask.sum().cpu().item()
        total_sample_num = nB * nH * nW
        ignored_num = (IgnoredMask & (~PositiveMask)).sum().cpu().item()
        self.loss_str = f'level_{nH}x{nW}, pos {pos_num}/{total_sample_num}, ' \
                        f'ignored {ignored_num}/{total_sample_num}: ' \
                        f'bbox/gt {loss_bbox:.3f}, conf {loss_conf:.3f}, ' \
                        f'class/gt {loss_cls:.3f}'
        return preds, loss
Esempio n. 6
0
    def forward(self, raw, img_size, labels=None):
        '''
        Decode the raw outputs of one feature level and, when labels are
        given, compute the training loss of that level.

        Args:
            raw (dict): network outputs with keys 'bbox', 'conf' and 'class'.
                They are asserted to have shapes (nB, nH, nW, 4),
                (nB, nH, nW, 1) and (nB, nH, nW, n_cls) respectively.
            img_size: (image height, image width) pair. The grid size is
                derived from it as nH = img_h / stride, nW = img_w / stride.
            labels: None at inference time, otherwise a list holding one
                ImageObjects per image of the batch.

        Returns:
            tuple (preds, loss).
            preds is a dict of detached predictions:
                'bbox': (nB, nH*nW, 4) boxes, clamped to the image and
                    converted with _xyxy_to_xywh,
                'class_idx': (nB, nH*nW) index of the highest class score,
                'score': (nB, nH*nW) sqrt(sigmoid(conf) * max class score).
            loss is None when labels is None; otherwise it is the sum of the
            bbox, confidence and class losses (each with reduction='sum').

        Raises:
            Exception: ltrb_setting starts with neither 'exp' nor 'relu'.
            NotImplementedError: training is requested with an ltrb_setting
                that does not start with 'exp', or that ends with neither
                'sl1' nor 'l2'.

        Side effects:
            When training, self.loss_str is set to a summary for logging.
        '''
        stride = self.stride
        img_h, img_w = img_size
        nH, nW = int(img_h / stride), int(img_w / stride)
        nCls = self.n_cls
        assert isinstance(raw, dict)

        t_ltrb = raw['bbox']
        conf_logits = raw['conf']
        cls_logits = raw['class']
        nB = t_ltrb.shape[0]  # batch size
        assert t_ltrb.shape == (nB, nH, nW, 4)
        assert conf_logits.shape == (nB, nH, nW, 1)
        assert cls_logits.shape == (nB, nH, nW, nCls)
        device = t_ltrb.device

        # activation function for left, top, right, bottom
        # (detached: the decoded distances are only used for predictions and
        # for the ignore mask, never for back-propagation)
        if self.ltrb_setting.startswith('exp'):
            p_ltrb = torch.exp(t_ltrb.detach()) * stride
        elif self.ltrb_setting.startswith('relu'):
            p_ltrb = tnf.relu(t_ltrb.detach()) * stride
        else:
            raise Exception('Unknown ltrb_setting')

        # ---------------------------- testing ----------------------------
        # Force the prediction to be in the image.
        # clamp_ acts in place on each coordinate view of p_xyxy.
        p_xyxy = _ltrb_to(p_ltrb, nH, nW, stride, 'x1y1x2y2')
        p_xyxy[..., 0].clamp_(min=0, max=img_w)
        p_xyxy[..., 1].clamp_(min=0, max=img_h)
        p_xyxy[..., 2].clamp_(min=0, max=img_w)
        p_xyxy[..., 3].clamp_(min=0, max=img_h)
        p_xywh = _xyxy_to_xywh(p_xyxy)
        # Logistic activation for 'centerness'
        p_conf = torch.sigmoid(conf_logits.detach())
        # Logistic activation for categories
        p_cls = torch.sigmoid(cls_logits.detach())
        cls_score, cls_idx = torch.max(p_cls, dim=3, keepdim=True)
        # final score: geometric mean of confidence and best class score
        confs = torch.sqrt(p_conf * cls_score)
        preds = {
            'bbox': p_xywh.view(nB, nH * nW, 4),
            'class_idx': cls_idx.view(nB, nH * nW),
            'score': confs.view(nB, nH * nW),
        }
        # Return the final predictions when testing
        if labels is None:
            return preds, None

        # ------------------------------ training ------------------------------
        # Unclamped predictions on CPU, used only to build the ignore mask.
        # Kept under its own name so it cannot be confused with preds['bbox'].
        p_cxcywh_cpu = _ltrb_to(p_ltrb, nH, nW, stride, 'cxcywh').cpu()
        assert isinstance(labels, list)
        assert nCls > 0

        # Build x,y meshgrid of this level. torch.meshgrid is used with its
        # default ordering, so gy and gx both have shape (nH, nW).
        x_ = torch.linspace(0, img_w, steps=nW + 1)[:-1] + 0.5 * stride
        y_ = torch.linspace(0, img_h, steps=nH + 1)[:-1] + 0.5 * stride
        gy, gx = torch.meshgrid(y_, x_)
        gy, gx = gy.contiguous(), gx.contiguous()
        # Anchor boxes (x, y, anchor w, anchor h) of this level. They depend
        # only on the grid and self.anchor, so build them once here instead
        # of once per ground-truth box.
        level_anchor_bbs = torch.cat(
            [torch.stack([gx, gy], dim=-1),
             torch.ones(nH, nW, 2) * self.anchor],
            dim=-1).view(nH * nW, 4)

        # Build anchor boxes for all levels (needed by the ATSS threshold).
        # Calculating this at each level is not very efficient
        # Ideally this should be done only once
        # But to achieve that, code structure must be changed.
        all_anchor_bbs = []
        for li, s in enumerate(self.strides_all):
            assert img_w % s == 0 and img_h % s == 0
            _sdH, _sdW = img_h // s, img_w // s
            _x = torch.linspace(0, img_w, steps=_sdW + 1)[:-1] + 0.5 * s
            _y = torch.linspace(0, img_h, steps=_sdH + 1)[:-1] + 0.5 * s
            _gy, _gx = torch.meshgrid(_y, _x)
            assert _gy.shape == _gx.shape == (_sdH, _sdW)
            anch_wh = torch.ones(_sdH * _sdW, 2) * self.anchors_all[li]
            all_anchor_bbs.append(
                torch.cat([_gx.reshape(-1, 1), _gy.reshape(-1, 1), anch_wh],
                          dim=1))

        # Initialize the prediction target of the batch (built on CPU)
        # positive: inside the gt box and anchor IoU above the ATSS threshold
        # ignored: predicted bbox IoU with any gt > self.ignore_thre
        # Conf loss: positive or (not ignored)
        # LTRB, Cls loss: positive only
        PositiveMask = torch.zeros(nB, nH, nW, dtype=torch.bool)
        IgnoredMask = torch.zeros(nB, nH, nW, dtype=torch.bool)
        TargetConf = torch.zeros(nB, nH, nW, 1)
        TargetLTRB = torch.zeros(nB, nH, nW, 4)
        TargetCls = torch.zeros(nB, nH, nW, nCls)

        # traverse all images in a batch
        for b in range(nB):
            im_labels = labels[b]
            assert isinstance(im_labels, ImageObjects)
            if len(im_labels) == 0:
                continue
            im_labels.sanity_check()

            # sort the ground truths by area, largest first
            gt_xywh = im_labels.bboxes
            areas = gt_xywh[:, 2] * gt_xywh[:, 3]
            lg2sml_idx = torch.argsort(areas, descending=True)
            gt_xywh = gt_xywh[lg2sml_idx, :]
            gt_cls_idx = im_labels.cats[lg2sml_idx]

            # ignore the conf of a pred BB if its best IoU with a gt exceeds
            # self.ignore_thre
            pred_gt_ious = bboxes_iou(p_cxcywh_cpu[b].view(-1, 4), gt_xywh,
                                      xyxy=False)
            iou_with_gt, _ = torch.max(pred_gt_ious, dim=1)
            IgnoredMask[b] = (iou_with_gt > self.ignore_thre).view(nH, nW)

            # Since the gt labels are sorted by area (descending),
            # small object targets are set later so they get higher priority
            for bb, cidx in zip(gt_xywh, gt_cls_idx):
                # Convert cxcywh to x1y1x2y2
                Tx1, Ty1, Tx2, Ty2 = _xywh_to_xyxy(bb, cr=1)
                # regression target (left, top, right, bottom) at each location
                tgt_ltrb = torch.stack(
                    [gx - Tx1, gy - Ty1, Tx2 - gx, Ty2 - gy], dim=-1)
                assert tgt_ltrb.shape == (nH, nW, 4)
                # locations strictly inside the gt box: all distances > 0
                bbox_mask = (tgt_ltrb > 0).all(dim=-1)
                # Find positive samples for this bounding box
                thres = _get_atss_threshold(bb, all_anchor_bbs, self.topk)
                anchor_ious = bboxes_iou(level_anchor_bbs, bb.view(1, 4),
                                         xyxy=False).squeeze()
                pos_mask = (anchor_ious > thres).view(nH, nW) & bbox_mask
                if not pos_mask.any():
                    continue
                # set target for ltrb (overwrites targets of larger boxes)
                TargetLTRB[b, pos_mask, :] = tgt_ltrb[pos_mask, :]
                # the target for confidence score is 1
                TargetConf[b, pos_mask] = 1
                # set target for category classification
                # NOTE(review): class bits are only ever set, never cleared,
                # so a location claimed by two boxes keeps both classes while
                # its ltrb target comes from the smaller box -- confirm this
                # multi-label behaviour is intended.
                _Hidx, _Widx = pos_mask.nonzero(as_tuple=True)
                TargetCls[b, _Hidx, _Widx, cidx] = 1
                # Update the batch positive sample mask
                PositiveMask[b] = PositiveMask[b] | pos_mask

        # Transfer targets to the device of the network outputs
        PositiveMask = PositiveMask.to(device=device)
        IgnoredMask = IgnoredMask.to(device=device)
        TargetConf = TargetConf.to(device=device)
        TargetLTRB = TargetLTRB.to(device=device)
        TargetCls = TargetCls.to(device=device)

        # Compute loss
        # The bbox loss is taken on the raw outputs, so the targets are
        # mapped back through the inverse of the 'exp' activation.
        pLTRB, tgtLTRB = t_ltrb[PositiveMask], TargetLTRB[PositiveMask]
        assert (tgtLTRB > 0).all()  # Sanity check
        if self.ltrb_setting.startswith('exp'):
            tgtLTRB = torch.log(tgtLTRB / stride)
        else:
            raise NotImplementedError()
        if self.ltrb_setting.endswith('sl1'):
            # smooth L1 loss for l,t,r,b
            loss_bbox = lossLib.smooth_L1_loss(pLTRB,
                                               tgtLTRB,
                                               beta=0.2,
                                               reduction='sum')
        elif self.ltrb_setting.endswith('l2'):
            loss_bbox = tnf.mse_loss(pLTRB, tgtLTRB, reduction='sum')
        else:
            raise NotImplementedError()
        bce_logits = tnf.binary_cross_entropy_with_logits
        # Binary cross entropy for confidence score:
        # penalize positives and every location that is not ignored
        _penalty = PositiveMask | (~IgnoredMask)
        _pConf, _tgtConf = conf_logits[_penalty], TargetConf[_penalty]
        loss_conf = bce_logits(_pConf, _tgtConf, reduction='sum')
        # Binary cross entropy for category classification (positives only)
        _pCls, _tgtCls = cls_logits[PositiveMask], TargetCls[PositiveMask]
        loss_cls = bce_logits(_pCls, _tgtCls, reduction='sum')
        # Sum over the whole batch; not divided by the batch size here.
        loss = loss_bbox + loss_conf + loss_cls

        # logging
        pos_num = PositiveMask.sum().cpu().item()
        total_sample_num = nB * nH * nW
        ignored_num = (IgnoredMask & (~PositiveMask)).sum().cpu().item()
        self.loss_str = f'level_{nH}x{nW}, pos {pos_num}/{total_sample_num}, ' \
                        f'ignored {ignored_num}/{total_sample_num}: ' \
                        f'bbox/gt {loss_bbox:.3f}, conf {loss_conf:.3f}, ' \
                        f'class/gt {loss_cls:.3f}'
        return preds, loss