Пример #1
0
def off_loss_(preds, target, mask):
    """Compute the smooth L1 loss of offsets, restricted to corner locations.

    A location is kept where the sum of ``mask`` over dim 1 is positive.
    Everywhere else both ``preds`` and ``target`` are multiplied by zero
    before the loss is evaluated.

    :param preds: predicted offsets
    :param target: ground-truth offsets
    :param mask: denotes where the corners are
    :return: smooth L1 loss of the offsets, requested with
        ``reduction='none'``
    """
    keep = (mask.sum(1) > 0).unsqueeze(1).type_as(preds)
    # Multiply out-of-place: an in-place ``*=`` would overwrite the caller's
    # tensors and raises on leaf tensors that require grad.
    masked_preds = preds * keep
    masked_target = target * keep

    return smooth_l1_loss(masked_preds, masked_target, reduction='none')
Пример #2
0
    def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind,
                 reg_offset, center_location):
        """Compute the weighted heatmap focal loss.

        Only the heatmap loss is computed and returned. The remaining
        arguments are accepted so the call signature stays unchanged, but
        they are not read.

        Args:
            pred_hm: tensor, (batch, 80, h, w). Modified in place by the
                sigmoid below.
            pred_wh: tensor, (batch, 2, h, w). Unused.
            heatmap: tensor, (batch, 80, h, w).
            wh: tensor, (batch, max_obj, 2). Unused.
            reg_mask: tensor, tensor <=> img, (batch, max_obj). Unused.
            ind: tensor, (batch, max_obj). Unused.
            reg_offset: tensor, (batch, max_obj, 2). Unused.
            center_location: tensor, (batch, max_obj, 2). Unused.

        Returns:
            ``ct_focal_loss(pred_hm, heatmap)`` scaled by ``self.hm_weight``.
        """
        # NOTE(review): the width/height regression code that used to follow
        # the return was unreachable and has been removed. Restore it from
        # version control if a ``(hm_loss, wh_loss)`` return is wanted again.
        # ``sigmoid_`` is in-place; the clamp bounds keep the result strictly
        # inside (0, 1).
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

        return hm_loss
Пример #3
0
    def loss(self,
             okpd_outs,
             rois,
             labels,
             label_weights,
             bbox_targets,
             bbox_weights,
             reduction_override=None):
        """Compute the classification, box-regression and keypoint losses.

        Args:
            okpd_outs: sequence of head outputs. The last two entries are
                read as the classification scores and the box predictions.
                When ``self.num_kp > 0``, entries 0 and 1 are also used for
                the keypoint terms.
            rois: RoIs; columns ``1:`` are passed to ``self.bbox_coder`` when
                ``self.reg_decoded_bbox`` is set.
            labels: per-sample class labels. Values in
                ``[0, self.num_classes)`` are treated as positives.
            label_weights: per-sample classification weights. The number of
                strictly positive weights (at least 1) is the classification
                averaging factor.
            bbox_targets: per-sample regression targets.
            bbox_weights: per-sample regression weights.
            reduction_override: forwarded to ``self.loss_cls`` and
                ``self.loss_bbox``.

        Returns:
            dict: may contain ``loss_okpd_cls`` and ``okpd_acc`` (when there
            are classification scores), ``loss_okpd_bbox`` (when
            ``bbox_pred`` is not None) and ``kp_loss`` (when
            ``self.num_kp > 0``).
        """
        losses = dict()
        cls_score = okpd_outs[-2]
        bbox_pred = okpd_outs[-1]
        bg_class_ind = self.num_classes
        pos_inds = (labels >= 0) & (labels < bg_class_ind)
        avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)

        if cls_score.numel() > 0:
            losses['loss_okpd_cls'] = self.loss_cls(
                cls_score,
                labels,
                label_weights,
                avg_factor=avg_factor,
                reduction_override=reduction_override)
            losses['okpd_acc'] = accuracy(cls_score, labels)

        if bbox_pred is not None:
            # do not perform bounding box regression for BG anymore.
            if pos_inds.any():
                # Convert once and reuse for every positive-sample selection.
                pos_bool = pos_inds.type(torch.bool)
                if self.reg_decoded_bbox:
                    bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred)
                if self.reg_class_agnostic:
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), 4)[pos_bool]
                else:
                    # Pick, for each positive sample, the 4 values that
                    # belong to its own label.
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), -1, 4)[pos_bool, labels[pos_bool]]
                losses['loss_okpd_bbox'] = self.loss_bbox(
                    pos_bbox_pred,
                    bbox_targets[pos_bool],
                    bbox_weights[pos_bool],
                    avg_factor=bbox_targets.size(0),
                    reduction_override=reduction_override)
            else:
                # Zero-valued loss that still depends on bbox_pred. It uses
                # the same key as the branch above so the set of returned
                # keys does not change from batch to batch.
                losses['loss_okpd_bbox'] = bbox_pred.sum() * 0

        if self.num_kp > 0:
            # Target is 1 for positive samples and 0 elsewhere.
            kp_max_target = torch.zeros_like(okpd_outs[0])
            kp_max_target[pos_inds, ...] = 1.
            # Base weight 0.5, raised to 1.0 where the output falls on the
            # wrong side of 0.5 relative to its target.
            kp_max_wts = torch.zeros_like(okpd_outs[0]) + 0.5
            fn_inds = (kp_max_target > 0.) & (okpd_outs[0] < 0.5)
            fp_inds = (kp_max_target < 1.) & (okpd_outs[0] > 0.5)
            kp_max_wts[fp_inds | fn_inds] = 1.0
            Ld_loss = smooth_l1_loss(okpd_outs[0] * kp_max_wts,
                                     kp_max_target * kp_max_wts,
                                     beta=0.05)

            # Second term only counts positive samples; the mask broadcasts
            # over three trailing dims (assumes okpd_outs[1] is 4-D -- TODO
            # confirm).
            pos_mask = pos_inds.float()[:, None, None, None]
            kp_ch_sum_target = torch.zeros_like(okpd_outs[1])
            kp_ch_sum_target[pos_inds, ...] = 1.
            Lu_loss = smooth_l1_loss(okpd_outs[1] * pos_mask,
                                     kp_ch_sum_target * pos_mask,
                                     beta=1.0)
            losses['kp_loss'] = 0.03 * (Ld_loss + Lu_loss)

        return losses
Пример #4
0
    def __call__(self, pred_hm, pred_wh, pred_reg_offset, heatmap, wh,
                 reg_mask, ind, reg_offset, center_location):
        """Compute the heatmap, size and (optional) center-offset losses.

        Args:
            pred_hm: tensor, (batch, 80, h, w). Modified in place by the
                sigmoid.
            pred_wh: tensor, (batch, 2, h, w).
            pred_reg_offset: None or tensor, (batch, 2, h, w). Read only when
                ``self.use_reg_offset`` is set.
            heatmap: tensor, (batch, 80, h, w).
            wh: tensor, (batch, max_obj, 2).
            reg_mask: tensor, tensor <=> img, (batch, max_obj).
            ind: tensor, (batch, max_obj).
            reg_offset: tensor, (batch, max_obj, 2).
            center_location: tensor, (batch, max_obj, 2). Only useful when
                using GIOU.

        Returns:
            tuple: ``(hm_loss, wh_loss, off_loss)``. ``off_loss`` is a zero
            tensor when ``self.use_reg_offset`` is not set.
        """
        feat_h, feat_w = pred_hm.shape[2:]
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

        # (batch, 2, h, w) => (batch, max_obj, 2)
        wh_pred = tranpose_and_gather_feat(pred_wh, ind)
        obj_mask = reg_mask.unsqueeze(2).expand_as(wh_pred).float()
        avg_factor = obj_mask.sum() + 1e-4

        if self.use_giou:
            half_pred = wh_pred / 2.
            pred_boxes = torch.cat(
                (center_location - half_pred, center_location + half_pred),
                dim=2)
            # Target corners are clipped to the feature-map extent.
            half_wh = wh / 2.
            gt_tl = torch.clamp(center_location - half_wh, min=0)
            gt_br = center_location + half_wh
            gt_br[:, :, 0] = gt_br[:, :, 0].clamp(max=feat_w - 1)
            gt_br[:, :, 1] = gt_br[:, :, 1].clamp(max=feat_h - 1)
            gt_boxes = torch.cat((gt_tl, gt_br), dim=2)
            wh_loss = giou_loss(pred_boxes, gt_boxes,
                                obj_mask[:, :, 0]) * self.giou_weight
        elif self.use_smooth_l1:
            wh_loss = smooth_l1_loss(
                wh_pred, wh, obj_mask, avg_factor=avg_factor) * self.wh_weight
        else:
            wh_loss = weighted_l1(
                wh_pred, wh, obj_mask, avg_factor=avg_factor) * self.wh_weight

        if self.use_reg_offset:
            off_pred = tranpose_and_gather_feat(pred_reg_offset, ind)
            off_loss = weighted_l1(
                off_pred, reg_offset, obj_mask,
                avg_factor=avg_factor) * self.off_weight

            mean_pos_offset = reg_offset[reg_offset > 0].mean().item()
            add_summary('centernet', gt_reg_off=mean_pos_offset)
        else:
            off_loss = hm_loss.new_tensor(0.)

        # Periodic feature dumps, gated by every_n_local_step(500).
        if every_n_local_step(500):
            feature_dumps = [('centernet/heatmap', pred_hm),
                             ('centernet/gt_heatmap', heatmap)]
            if self.use_reg_offset:
                feature_dumps.append(('centernet/reg_offset', pred_reg_offset))
            for tag, feat in feature_dumps:
                add_feature_summary(tag, feat.detach().cpu().numpy())

        return hm_loss, wh_loss, off_loss
Пример #5
0
    def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind,
                 center_location):
        """Compute the heatmap and size losses over several feature levels.

        Args:
            pred_hm: list of per-level tensors, each (batch, 80, h, w).
            pred_wh: list of per-level tensors, each (batch, 2, h, w).
            heatmap: tensor, (batch, 80, h*w for all levels).
            wh: tensor, (batch, max_obj*level_num, 2).
            reg_mask: tensor, tensor <=> img, (batch, max_obj*level_num).
            ind: tensor, (batch, max_obj*level_num).
            center_location: tensor or None, (batch, max_obj*level_num, 2).
                Only useful when using GIOU.

        Returns:
            tuple: ``(hm_loss, wh_loss)``.
        """
        # Periodic per-level heatmap dumps, gated by every_n_local_step(500).
        if every_n_local_step(500):
            for lvl, lvl_hm in enumerate(pred_hm):
                lvl_prob = lvl_hm.clone().detach().sigmoid_()
                add_feature_summary('centernet_heatmap_lv{}'.format(lvl),
                                    lvl_prob.cpu().numpy())

        # The clamp limits below come from the first level's spatial size.
        feat_h, feat_w = pred_hm[0].shape[2:]
        num_levels = len(pred_hm)

        # Flatten each level's spatial dims and join all levels on the last
        # axis.
        flat_hm = torch.cat(
            [lvl_hm.view(*lvl_hm.shape[:2], -1) for lvl_hm in pred_hm],
            dim=-1)
        flat_hm = torch.clamp(flat_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(flat_hm, heatmap, self.gamma) * self.hm_weight

        # (batch, 2, h, w) for all levels => (batch, max_obj*level_num, 2)
        per_level_ind = ind.chunk(num_levels, dim=1)
        wh_pred = torch.cat(
            [tranpose_and_gather_feat(lvl_wh, lvl_ind)
             for lvl_wh, lvl_ind in zip(pred_wh, per_level_ind)],
            dim=1)
        obj_mask = reg_mask.unsqueeze(2).expand_as(wh_pred).float()
        avg_factor = obj_mask.sum() + 1e-4

        if self.use_giou:
            half_pred = wh_pred / 2.
            pred_boxes = torch.cat(
                (center_location - half_pred, center_location + half_pred),
                dim=2)
            half_wh = wh / 2.
            gt_br = center_location + half_wh
            gt_br[:, :, 0] = gt_br[:, :, 0].clamp(max=feat_w - 1)
            gt_br[:, :, 1] = gt_br[:, :, 1].clamp(max=feat_h - 1)
            gt_tl = torch.clamp(center_location - half_wh, min=0)
            gt_boxes = torch.cat((gt_tl, gt_br), dim=2)
            # Mask is repeated along the last dim to cover all 4 box coords.
            # NOTE(review): this branch applies no weight factor to the
            # GIoU loss -- confirm that is intended.
            wh_loss = giou_loss(pred_boxes, gt_boxes,
                                obj_mask.repeat(1, 1, 2))
        elif self.use_smooth_l1:
            wh_loss = smooth_l1_loss(
                wh_pred, wh, obj_mask, avg_factor=avg_factor) * self.wh_weight
        else:
            wh_loss = weighted_l1(
                wh_pred, wh, obj_mask, avg_factor=avg_factor) * self.wh_weight

        return hm_loss, wh_loss