Example #1
0
    def loss_single(self, pred_hm, pred_wh, heatmap, box_target, wh_weight,
                    down_ratio, hm_weight_factor, wh_weight_factor):
        """Compute the heatmap loss and the GIoU box loss for one output level.

        Args:
            pred_hm: heatmap logits; ``shape[2:]`` is read as (H, W).
                ``sigmoid_`` is applied in place, so the caller's tensor is
                modified.
            pred_wh: per-location box offsets; channels 0/1 are subtracted
                from and channels 2/3 are added to the location grid.
            heatmap: target heatmap passed to ``ct_focal_loss``.
            box_target: target boxes, channel-first like ``pred_wh``.
            wh_weight: per-location box weights, reshaped to (-1, H, W).
            down_ratio: forwarded to ``self.get_down_ratio`` to obtain the
                spacing between neighbouring grid locations.
            hm_weight_factor: multiplier applied to the heatmap loss.
            wh_weight_factor: multiplier applied to the box loss.

        Returns:
            tuple: (loss_cls, loss_bbox).
        """
        feat_h, feat_w = pred_hm.shape[2:]
        hm_prob = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        loss_cls = ct_focal_loss(hm_prob, heatmap) * hm_weight_factor

        # Location grid: one (x, y) pair per feature cell, spaced by `stride`.
        stride = self.get_down_ratio(down_ratio)
        grid_opts = dict(dtype=torch.float32, device=heatmap.device)
        xs = torch.arange(0, (feat_w - 1) * stride + 1, stride, **grid_opts)
        ys = torch.arange(0, (feat_h - 1) * stride + 1, stride, **grid_opts)
        grid_y, grid_x = torch.meshgrid(ys, xs)
        grid = torch.stack((grid_x, grid_y), dim=0)  # (2, h, w)

        # Decode to corner boxes and move channels last: (batch, h, w, 4).
        top_left = grid - pred_wh[:, [0, 1]]
        bottom_right = grid + pred_wh[:, [2, 3]]
        pred_boxes = torch.cat((top_left, bottom_right),
                               dim=1).permute(0, 2, 3, 1)
        target_boxes = box_target.permute(0, 2, 3, 1)

        weights = wh_weight.view(-1, feat_h, feat_w)
        # The epsilon keeps the normaliser positive when nothing is weighted.
        norm = weights.sum() + 1e-4
        loss_bbox = giou_loss(pred_boxes, target_boxes, weights,
                              avg_factor=norm) * wh_weight_factor

        return loss_cls, loss_bbox
Example #2
0
    def __call__(self, pred_hm, pred_heights, pred_reg_xoffset,
                 pred_reg_yoffset, pred_pose, heatmap, heights, reg_xoffset,
                 reg_yoffset, pose, reg_mask, ind):
        """Compute heatmap, height, x/y-offset and pose losses.

        Args:
            pred_hm: tensor, (batch, cls, h, w). Heatmap logits; ``sigmoid_``
                is applied in place, so the caller's tensor is modified.
            pred_heights: tensor, (batch, 3, h, w).
            pred_reg_xoffset: tensor, (batch, 3, h, w).
            pred_reg_yoffset: tensor, (batch, 3, h, w).
            pred_pose: tensor, (batch, 8, h, w).
            heatmap: tensor, (batch, cls, h, w).
            heights: tensor, (batch, max_obj, 3).
            reg_xoffset: tensor, (batch, max_obj, 3).
            reg_yoffset: tensor, (batch, max_obj, 3).
            pose: tensor, (batch, max_obj).
            reg_mask: tensor, (batch, max_obj, 3).
            ind: tensor, (batch, max_obj).

        Returns:
            tuple: (hm_loss, heights_loss, xoff_loss, yoff_loss, pose_loss),
            each multiplied by the corresponding ``self.*_weight``.
        """
        hm_prob = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(hm_prob, heatmap) * self.hm_weight

        # (batch, 3, h, w) => (batch, max_obj, 3) for each regression map.
        obj_heights, obj_xoff, obj_yoff = [
            tranpose_and_gather_feat(feat, ind)
            for feat in (pred_heights, pred_reg_xoffset, pred_reg_yoffset)
        ]

        # cross_entropy wants (N, C, d) order, so gather along the flattened
        # spatial axis instead of transposing:
        # (batch, 8, h, w) => (batch, 8, max_obj)
        n_batch, n_ch = pred_pose.shape[0], pred_pose.shape[1]
        pose_flat = pred_pose.view(n_batch, n_ch, -1)
        pose_index = ind.unsqueeze(1).expand(-1, n_ch, -1)
        pose_logits = pose_flat.gather(2, pose_index)

        mask = reg_mask.float()
        avg_factor = mask.sum() + 1e-4

        def _masked_l1(pred, target, scale):
            # Shared mask and normaliser for the three regression targets.
            return weighted_l1(pred, target, mask,
                               avg_factor=avg_factor) * scale

        heights_loss = _masked_l1(obj_heights, heights, self.heights_weight)
        xoff_loss = _masked_l1(obj_xoff, reg_xoffset, self.xoff_weight)
        yoff_loss = _masked_l1(obj_yoff, reg_yoffset, self.yoff_weight)

        # `pose` holds one label per object, so the first mask column serves
        # as the per-object weight.
        per_obj_mask = mask[..., 0]
        pose_loss = cross_entropy(
            pose_logits, pose, per_obj_mask,
            avg_factor=per_obj_mask.sum() + 1e-4) * self.pose_weight

        return hm_loss, heights_loss, xoff_loss, yoff_loss, pose_loss
Example #3
0
    def loss_calc(self, pred_hm, pred_wh, pred_off, heatmap, box_target,
                  wh_weight, off_target):
        """Compute heatmap, GIoU box and offset losses.

        Args:
            pred_hm: tensor, (batch, 80, h, w). Heatmap logits; ``sigmoid_``
                is applied in place, so the caller's tensor is modified.
            pred_wh: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
            pred_off: offset predictions, compared with ``off_target`` by
                ``self.crit_off``.
            heatmap: tensor, same as pred_hm.
            box_target: tensor, same as pred_wh.
            wh_weight: tensor, same as pred_wh.
            off_target: offset targets; entries > 0 are the supervised ones.

        Returns:
            hm_loss
            wh_loss
            off_loss
        """
        H, W = pred_hm.shape[2:]
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

        mask = wh_weight.view(-1, H, W)
        avg_factor = mask.sum() + 1e-4

        # Rebuild the cached location grid (cell coordinates scaled by
        # self.down_ratio) only when the feature-map size changes.
        if (self.base_loc is None or H != self.base_loc.shape[1]
                or W != self.base_loc.shape[2]):
            base_step = self.down_ratio
            shifts_x = torch.arange(0, (W - 1) * base_step + 1,
                                    base_step,
                                    dtype=torch.float32,
                                    device=heatmap.device)
            shifts_y = torch.arange(0, (H - 1) * base_step + 1,
                                    base_step,
                                    dtype=torch.float32,
                                    device=heatmap.device)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            self.base_loc = torch.stack((shift_x, shift_y), dim=0)  # (2, h, w)

        # (batch, h, w, 4)
        pred_boxes = torch.cat((self.base_loc - pred_wh[:, [0, 1]],
                                self.base_loc + pred_wh[:, [2, 3]]),
                               dim=1).permute(0, 2, 3, 1)
        # (batch, h, w, 4)
        boxes = box_target.permute(0, 2, 3, 1)
        wh_loss = giou_loss(pred_boxes, boxes, mask,
                            avg_factor=avg_factor) * self.wh_weight

        # Supervise offsets only where the target is positive. No clone is
        # needed: the comparison already allocates a new tensor.
        pos_off = off_target > 0
        mask_off = pos_off.float()
        # Half the number of positive entries (presumably two offset values
        # per object -- confirm). The epsilon matches the other normalisers
        # and prevents a 0 divisor (NaN/inf loss) when no offset is positive.
        # Summing the bool mask counts exactly without allocating the index
        # tensor that torch.nonzero would build.
        num_mask = pos_off.sum().item() * 0.5 + 1e-4
        off_loss = self.crit_off(
            pred_off, off_target, weight=mask_off,
            avg_factor=num_mask) * self.off_weight

        return hm_loss, wh_loss, off_loss
Example #4
0
    def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind,
                 reg_offset, center_location, *, return_wh_loss=False):
        """Compute the heatmap loss and, on request, the size loss.

        Args:
            pred_hm: tensor, (batch, 80, h, w). Heatmap logits; ``sigmoid_``
                is applied in place, so the caller's tensor is modified.
            pred_wh: tensor, (batch, 2, h, w).
            heatmap: tensor, (batch, 80, h, w).
            wh: tensor, (batch, max_obj, 2).
            reg_mask: tensor, tensor <=> img, (batch, max_obj).
            ind: tensor, (batch, max_obj).
            reg_offset: tensor, (batch, max_obj, 2). Not used by this method.
            center_location: tensor, (batch, max_obj, 2). Only useful when using GIOU.
            return_wh_loss: keyword-only bool. False (the default) returns
                only the heatmap loss. True also computes and returns the
                size loss, which used to be unreachable code behind an
                unconditional ``return``.

        Returns:
            hm_loss if ``return_wh_loss`` is False, else (hm_loss, wh_loss).
        """
        H, W = pred_hm.shape[2:]
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

        if not return_wh_loss:
            return hm_loss

        # (batch, 2, h, w) => (batch, max_obj, 2)
        pred = tranpose_and_gather_feat(pred_wh, ind)
        mask = reg_mask.unsqueeze(2).expand_as(pred).float()
        avg_factor = mask.sum() + 1e-4

        if self.use_giou:
            pred_boxes = torch.cat(
                (center_location - pred / 2., center_location + pred / 2.),
                dim=2)
            # Clip target corners to the feature map: bottom-right to
            # (W - 1, H - 1), top-left to >= 0.
            box_br = center_location + wh / 2.
            box_br[:, :, 0] = box_br[:, :, 0].clamp(max=W - 1)
            box_br[:, :, 1] = box_br[:, :, 1].clamp(max=H - 1)
            boxes = torch.cat(
                (torch.clamp(center_location - wh / 2., min=0), box_br), dim=2)
            mask_no_expand = mask[:, :, 0]
            wh_loss = giou_loss(pred_boxes, boxes,
                                mask_no_expand) * self.giou_weight
        else:
            if self.use_smooth_l1:
                wh_loss = smooth_l1_loss(
                    pred, wh, mask, avg_factor=avg_factor) * self.wh_weight
            else:
                wh_loss = weighted_l1(pred, wh, mask,
                                      avg_factor=avg_factor) * self.wh_weight

        return hm_loss, wh_loss
Example #5
0
    def loss_single(self,
                    pred_hm,
                    pred_wh,
                    heatmap,
                    box_target,
                    wh_weight,
                    down_ratio,
                    base_loc_name,
                    hm_weight_factor,
                    wh_weight_factor,
                    focal_loss_beta):
        """Compute heatmap and GIoU box losses for one output level.

        The location grid is cached on ``self`` under the attribute named by
        ``base_loc_name``. It is rebuilt when that attribute is None or when
        its spatial size differs from ``pred_hm``.

        Args:
            pred_hm: heatmap logits; ``shape[2:]`` is read as (H, W).
                ``sigmoid_`` is applied in place, so the caller's tensor is
                modified.
            pred_wh: per-location box offsets; channels 0/1 are subtracted
                from and channels 2/3 are added to the location grid.
            heatmap: target heatmap.
            box_target: target boxes, channel-first like ``pred_wh``.
            wh_weight: per-location box weights, reshaped to (-1, H, W).
            down_ratio: spacing between neighbouring grid locations.
            base_loc_name: name of the attribute on ``self`` holding the grid.
            hm_weight_factor: multiplier for the heatmap loss.
            wh_weight_factor: multiplier for the box loss.
            focal_loss_beta: forwarded to ``ct_focal_loss`` as ``beta``.

        Returns:
            tuple: (loss_cls, loss_bbox).
        """
        H, W = pred_hm.shape[2:]
        hm_prob = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        loss_cls = ct_focal_loss(hm_prob, heatmap,
                                 beta=focal_loss_beta) * hm_weight_factor

        grid = getattr(self, base_loc_name)
        if grid is None or grid.shape[1] != H or grid.shape[2] != W:
            coord_opts = dict(dtype=torch.float32, device=heatmap.device)
            xs = torch.arange(0, (W - 1) * down_ratio + 1, down_ratio,
                              **coord_opts)
            ys = torch.arange(0, (H - 1) * down_ratio + 1, down_ratio,
                              **coord_opts)
            grid_y, grid_x = torch.meshgrid(ys, xs)
            grid = torch.stack((grid_x, grid_y), dim=0)  # (2, h, w)
            setattr(self, base_loc_name, grid)

        # Corner boxes, channels last: (batch, h, w, 4)
        pred_boxes = torch.cat((grid - pred_wh[:, [0, 1]],
                                grid + pred_wh[:, [2, 3]]),
                               dim=1).permute(0, 2, 3, 1)
        boxes = box_target.permute(0, 2, 3, 1)

        weights = wh_weight.view(-1, H, W)
        norm = weights.sum() + 1e-4
        loss_bbox = giou_loss(pred_boxes, boxes, weights,
                              avg_factor=norm) * wh_weight_factor

        return loss_cls, loss_bbox
Example #6
0
    def loss_calc(self, pred_hm, pred_wh, pred_hm_2, pred_wh_2, pred_hm_3,
                  pred_wh_3, heatmap, box_target, wh_weight, heatmap_2,
                  box_target_2, wh_weight_2, heatmap_3, box_target_3,
                  wh_weight_3):
        """Compute heatmap and GIoU box losses for three parallel heads.

        Every head's weights are reshaped with the (H, W) of ``pred_hm``, and
        every head is decoded against the single cached ``self.base_loc``
        grid, so all heads are expected to share that spatial size.

        Args:
            pred_hm, pred_hm_2, pred_hm_3: tensor, (batch, 80, h, w) heatmap
                logits; ``sigmoid_`` is applied to each in place.
            pred_wh, pred_wh_2, pred_wh_3: tensor, (batch, 4, h, w) or
                (batch, 80 * 4, h, w).
            heatmap, heatmap_2, heatmap_3: tensor, same as pred_hm.
            box_target, box_target_2, box_target_3: tensor, same as pred_wh.
            wh_weight, wh_weight_2, wh_weight_3: tensor, same as pred_wh.

        Returns:
            tuple: (hm_loss, wh_loss, hm_loss_2, wh_loss_2, hm_loss_3,
            wh_loss_3), each scaled by the matching ``self`` weight.
        """
        H, W = pred_hm.shape[2:]

        # Apply sigmoid + clamp to all three heads before any loss is taken.
        hm_probs = [torch.clamp(logits.sigmoid_(), min=1e-4, max=1 - 1e-4)
                    for logits in (pred_hm, pred_hm_2, pred_hm_3)]
        hm_targets = (heatmap, heatmap_2, heatmap_3)
        hm_scales = (self.hm_weight, self.hm_weight_2, self.hm_weight_3)
        hm_losses = [ct_focal_loss(prob, target) * scale
                     for prob, target, scale
                     in zip(hm_probs, hm_targets, hm_scales)]

        masks = [weight.view(-1, H, W)
                 for weight in (wh_weight, wh_weight_2, wh_weight_3)]
        norms = [m.sum() + 1e-4 for m in masks]

        grid_stale = (self.base_loc is None or self.base_loc.shape[1] != H
                      or self.base_loc.shape[2] != W)
        if grid_stale:
            step = self.down_ratio
            xs = torch.arange(0, (W - 1) * step + 1, step,
                              dtype=torch.float32, device=heatmap.device)
            ys = torch.arange(0, (H - 1) * step + 1, step,
                              dtype=torch.float32, device=heatmap.device)
            grid_y, grid_x = torch.meshgrid(ys, xs)
            self.base_loc = torch.stack((grid_x, grid_y), dim=0)  # (2, h, w)
        grid = self.base_loc

        wh_losses = []
        heads = zip((pred_wh, pred_wh_2, pred_wh_3),
                    (box_target, box_target_2, box_target_3),
                    masks, norms,
                    (self.wh_weight, self.wh_weight_2, self.wh_weight_3))
        for dist, target, head_mask, norm, scale in heads:
            # Decode to corner boxes, channels last: (batch, h, w, 4)
            pred_boxes = torch.cat((grid - dist[:, [0, 1]],
                                    grid + dist[:, [2, 3]]),
                                   dim=1).permute(0, 2, 3, 1)
            gt_boxes = target.permute(0, 2, 3, 1)
            wh_losses.append(
                giou_loss(pred_boxes, gt_boxes, head_mask,
                          avg_factor=norm) * scale)

        return (hm_losses[0], wh_losses[0], hm_losses[1], wh_losses[1],
                hm_losses[2], wh_losses[2])
Example #7
0
    def loss_calc(self, pred_hm_large, pred_hm_little, pred_wh_large,
                  pred_wh_little, heatmap_large, heatmap_little,
                  box_target_large, box_target_little, wh_weight_large,
                  wh_weight_little):
        """Compute summed heatmap and box losses for the large/little heads.

        Both heads' weights are reshaped with the (H, W) of
        ``pred_hm_little``, and both heads are decoded against the single
        cached ``self.base_loc`` grid.

        Args:
            pred_hm_large, pred_hm_little: tensor, (batch, 80, h, w) heatmap
                logits; ``sigmoid_`` is applied to each in place.
            pred_wh_large, pred_wh_little: tensor, (batch, 4, h, w) or
                (batch, 80 * 4, h, w).
            heatmap_large, heatmap_little: tensor, same as the heatmaps.
            box_target_large, box_target_little: tensor, same as pred_wh.
            wh_weight_large, wh_weight_little: tensor, same as pred_wh.

        Returns:
            hm_loss: sum of both heads' focal losses, each times
                ``self.hm_weight``.
            wh_loss: sum of both heads' ``giou_loss_ct`` values, each times
                ``self.wh_weight``.
        """
        H, W = pred_hm_little.shape[2:]
        prob_large = torch.clamp(pred_hm_large.sigmoid_(),
                                 min=1e-4, max=1 - 1e-4)
        prob_little = torch.clamp(pred_hm_little.sigmoid_(),
                                  min=1e-4, max=1 - 1e-4)
        hm_loss = (ct_focal_loss(prob_large, heatmap_large) * self.hm_weight
                   + ct_focal_loss(prob_little, heatmap_little)
                   * self.hm_weight)

        mask_large = wh_weight_large.view(-1, H, W)
        mask_little = wh_weight_little.view(-1, H, W)
        norm_large = mask_large.sum() + 1e-4
        norm_little = mask_little.sum() + 1e-4

        grid = self.base_loc
        if grid is None or grid.shape[1] != H or grid.shape[2] != W:
            step = self.down_ratio
            device = heatmap_little.device
            xs = torch.arange(0, (W - 1) * step + 1, step,
                              dtype=torch.float32, device=device)
            ys = torch.arange(0, (H - 1) * step + 1, step,
                              dtype=torch.float32, device=device)
            grid_y, grid_x = torch.meshgrid(ys, xs)
            grid = torch.stack((grid_x, grid_y), dim=0)  # (2, h, w)
            self.base_loc = grid

        def to_corner_boxes(dist):
            # Channels 0/1 subtracted from, 2/3 added to the grid, then
            # channels moved last: (batch, h, w, 4).
            return torch.cat((grid - dist[:, [0, 1]],
                              grid + dist[:, [2, 3]]),
                             dim=1).permute(0, 2, 3, 1)

        pred_boxes_large = to_corner_boxes(pred_wh_large)
        pred_boxes_little = to_corner_boxes(pred_wh_little)
        boxes_large = box_target_large.permute(0, 2, 3, 1)
        boxes_little = box_target_little.permute(0, 2, 3, 1)

        loss_large = giou_loss_ct(pred_boxes_large, boxes_large, mask_large,
                                  avg_factor=norm_large) * self.wh_weight
        loss_little = giou_loss_ct(pred_boxes_little, boxes_little,
                                   mask_little,
                                   avg_factor=norm_little) * self.wh_weight
        wh_loss = loss_large + loss_little

        return hm_loss, wh_loss
Example #8
0
    def __call__(self, pred_hm, pred_wh, pred_centerness, heatmap, box_target,
                 centerness, wh_weight, hm_weight):
        """Compute heatmap, GIoU box, centerness and merge losses.

        Args:
            pred_hm: tensor, (batch, 80, h, w). Heatmap logits. In the
                non-fovea path ``sigmoid_`` is applied in place, so the
                caller's tensor is modified.
            pred_wh: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
            pred_centerness: tensor or None, (batch, 1, h, w).
            heatmap: tensor, (batch, 80, h, w). In the fovea path, entries
                < 0 are excluded from the heatmap/merge losses.
            box_target: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
            centerness: tensor or None, (batch, 1, h, w).
            wh_weight: tensor or None, (batch, 80, h, w).
            hm_weight: in the non-fovea path, passed to ``ct_focal_loss`` as
                ``hm_weight`` unless ``self.ct_version`` is set (then None).

        Returns:
            tuple: (hm_loss, wh_loss, centerness_loss, merge_loss). Losses
            that the current configuration does not compute are zero tensors.
        """
        if every_n_local_step(100):
            pred_hm_summary = torch.clamp(torch.sigmoid(pred_hm),
                                          min=1e-4,
                                          max=1 - 1e-4)
            gt_hm_summary = heatmap.clone()
            if self.fovea_hm:
                if not self.only_merge:
                    pred_ctn_summary = torch.clamp(
                        torch.sigmoid(pred_centerness), min=1e-4, max=1 - 1e-4)
                    add_feature_summary(
                        'centernet/centerness',
                        pred_ctn_summary.detach().cpu().numpy(),
                        type='f')
                    add_feature_summary(
                        'centernet/merge',
                        (pred_ctn_summary *
                         pred_hm_summary).detach().cpu().numpy(),
                        type='max')

                add_feature_summary('centernet/gt_centerness',
                                    centerness.detach().cpu().numpy(),
                                    type='f')
                add_feature_summary('centernet/gt_merge',
                                    (centerness *
                                     gt_hm_summary).detach().cpu().numpy(),
                                    type='max')

            add_feature_summary('centernet/heatmap',
                                pred_hm_summary.detach().cpu().numpy())
            add_feature_summary('centernet/gt_heatmap',
                                gt_hm_summary.detach().cpu().numpy())

        H, W = pred_hm.shape[2:]
        if not self.fovea_hm:
            pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
            hm_weight = None if self.ct_version else hm_weight
            hm_loss = ct_focal_loss(pred_hm, heatmap,
                                    hm_weight=hm_weight) * self.hm_weight
            centerness_loss = hm_loss.new_tensor([0.])
            merge_loss = hm_loss.new_tensor([0.])
        else:
            care_mask = (heatmap >= 0).float()
            # Normaliser for the sigmoid focal loss: number of positive
            # heatmap entries.
            avg_factor = torch.sum(heatmap > 0).float().item() + 1e-6
            if not self.only_merge:
                hm_loss = py_sigmoid_focal_loss(
                    pred_hm, heatmap, care_mask,
                    reduction='sum') / avg_factor * self.hm_weight

                pred_centerness = torch.clamp(torch.sigmoid(pred_centerness),
                                              min=1e-4,
                                              max=1 - 1e-4)
                centerness_loss = ct_focal_loss(
                    pred_centerness, centerness, gamma=2.) * self.ct_weight

                merge_loss = ct_focal_loss(
                    torch.clamp(torch.sigmoid(pred_hm) * pred_centerness,
                                min=1e-4,
                                max=1 - 1e-4),
                    heatmap * centerness,
                    weight=(heatmap >= 0).float()) * self.merge_weight
            else:
                hm_loss = pred_hm.new_tensor([0.])
                centerness_loss = pred_hm.new_tensor([0.])
                merge_loss = ct_focal_loss(
                    torch.clamp(torch.sigmoid(pred_hm), min=1e-4,
                                max=1 - 1e-4),
                    heatmap * centerness,
                    weight=(heatmap >= 0).float()) * self.merge_weight

        if not self.wh_agnostic:
            # Fold the class dimension into the batch so every class has its
            # own 4-channel box map.
            pred_wh = pred_wh.view(pred_wh.size(0) * pred_hm.size(1), 4, H, W)
            box_target = box_target.view(
                box_target.size(0) * pred_hm.size(1), 4, H, W)
        mask = wh_weight.view(-1, H, W)
        avg_factor = mask.sum() + 1e-4

        # Rebuild the cached location grid whenever the feature-map size
        # changes. Checking only `is None` would reuse a grid built for a
        # different (H, W) (e.g. multi-scale training or varying test
        # sizes) and make the decode below fail or use wrong locations.
        if (self.base_loc is None or H != self.base_loc.shape[1]
                or W != self.base_loc.shape[2]):
            base_step = self.down_ratio
            shifts_x = torch.arange(0, (W - 1) * base_step + 1,
                                    base_step,
                                    dtype=torch.float32,
                                    device=heatmap.device)
            shifts_y = torch.arange(0, (H - 1) * base_step + 1,
                                    base_step,
                                    dtype=torch.float32,
                                    device=heatmap.device)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            self.base_loc = torch.stack((shift_x, shift_y), dim=0)  # (2, h, w)

        # (batch, h, w, 4)
        pred_boxes = torch.cat((self.base_loc - pred_wh[:, [0, 1]],
                                self.base_loc + pred_wh[:, [2, 3]]),
                               dim=1).permute(0, 2, 3, 1)
        # (batch, h, w, 4)
        boxes = box_target.permute(0, 2, 3, 1)
        wh_loss = giou_loss(pred_boxes, boxes, mask,
                            avg_factor=avg_factor) * self.giou_weight

        return hm_loss, wh_loss, centerness_loss, merge_loss
Example #9
0
    def __call__(self, pred_hm, pred_wh, pred_reg_offset, heatmap, wh,
                 reg_mask, ind, reg_offset, center_location):
        """Compute heatmap, box-size and (optional) centre-offset losses.

        Args:
            pred_hm: tensor, (batch, 80, h, w). Heatmap logits; ``sigmoid_``
                is applied in place, so the caller's tensor is modified.
            pred_wh: tensor, (batch, 2, h, w).
            pred_reg_offset: None or tensor, (batch, 2, h, w).
            heatmap: tensor, (batch, 80, h, w).
            wh: tensor, (batch, max_obj, 2).
            reg_mask: tensor, tensor <=> img, (batch, max_obj).
            ind: tensor, (batch, max_obj).
            reg_offset: tensor, (batch, max_obj, 2).
            center_location: tensor, (batch, max_obj, 2). Only useful when using GIOU.

        Returns:
            tuple: (hm_loss, wh_loss, off_loss). ``off_loss`` is a zero
            tensor unless ``self.use_reg_offset`` is set.
        """
        H, W = pred_hm.shape[2:]
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

        # (batch, 2, h, w) => (batch, max_obj, 2)
        pred_size = tranpose_and_gather_feat(pred_wh, ind)
        mask = reg_mask.unsqueeze(2).expand_as(pred_size).float()
        avg_factor = mask.sum() + 1e-4

        if self.use_giou:
            half_pred = pred_size / 2.
            half_gt = wh / 2.
            pred_boxes = torch.cat((center_location - half_pred,
                                    center_location + half_pred), dim=2)
            # Clip target corners to the feature map: top-left to >= 0,
            # bottom-right to (W - 1, H - 1).
            gt_tl = torch.clamp(center_location - half_gt, min=0)
            gt_br = center_location + half_gt
            gt_br[:, :, 0] = gt_br[:, :, 0].clamp(max=W - 1)
            gt_br[:, :, 1] = gt_br[:, :, 1].clamp(max=H - 1)
            gt_boxes = torch.cat((gt_tl, gt_br), dim=2)
            wh_loss = giou_loss(pred_boxes, gt_boxes,
                                mask[:, :, 0]) * self.giou_weight
        else:
            size_loss_fn = smooth_l1_loss if self.use_smooth_l1 else weighted_l1
            wh_loss = size_loss_fn(pred_size, wh, mask,
                                   avg_factor=avg_factor) * self.wh_weight

        if self.use_reg_offset:
            pred_reg = tranpose_and_gather_feat(pred_reg_offset, ind)
            off_loss = weighted_l1(pred_reg, reg_offset, mask,
                                   avg_factor=avg_factor) * self.off_weight
            add_summary('centernet',
                        gt_reg_off=reg_offset[reg_offset > 0].mean().item())
        else:
            off_loss = hm_loss.new_tensor(0.)

        if every_n_local_step(500):
            add_feature_summary('centernet/heatmap',
                                pred_hm.detach().cpu().numpy())
            add_feature_summary('centernet/gt_heatmap',
                                heatmap.detach().cpu().numpy())
            if self.use_reg_offset:
                add_feature_summary('centernet/reg_offset',
                                    pred_reg_offset.detach().cpu().numpy())

        return hm_loss, wh_loss, off_loss
Example #10
0
    def loss_calc(self, pred_feat, pred_hm, pred_wh, heatmap, box_target,
                  wh_weight):
        """Compute heatmap and box losses, plus an optional refinement loss.

        Args:
            pred_feat: feature map passed to ``self.align`` to pool RoI
                features in the two-stage branch.
            pred_hm: tensor, (batch, 80, h, w). Heatmap logits; ``sigmoid_``
                is applied in place, so the caller's tensor is modified.
            pred_wh: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
            heatmap: tensor, same as pred_hm.
            box_target: tensor, same as pred_wh.
            wh_weight: tensor, same as pred_wh.

        Returns:
            hm_loss
            wh_loss
            wh2_loss: second-stage GIoU loss; a zero tensor when
                ``self.two_stage`` is off or no box survives filtering.
        """
        H, W = pred_hm.shape[2:]
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

        mask = wh_weight.view(-1, H, W)
        avg_factor = mask.sum() + 1e-4

        # Rebuild the cached location grid only when the feature-map size
        # changes.
        if (self.base_loc is None or H != self.base_loc.shape[1]
                or W != self.base_loc.shape[2]):
            base_step = self.down_ratio
            shifts_x = torch.arange(0, (W - 1) * base_step + 1,
                                    base_step,
                                    dtype=torch.float32,
                                    device=heatmap.device)
            shifts_y = torch.arange(0, (H - 1) * base_step + 1,
                                    base_step,
                                    dtype=torch.float32,
                                    device=heatmap.device)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            self.base_loc = torch.stack((shift_x, shift_y), dim=0)  # (2, h, w)

        # (batch, h, w, 4)
        pred_boxes = torch.cat((self.base_loc - pred_wh[:, [0, 1]],
                                self.base_loc + pred_wh[:, [2, 3]]),
                               dim=1).permute(0, 2, 3, 1)
        # (batch, h, w, 4)
        boxes = box_target.permute(0, 2, 3, 1)
        wh_loss = self.iou_loss(pred_boxes, boxes, mask,
                                avg_factor=avg_factor) * self.wh_weight

        wh2_loss = wh_loss.new_zeros([1])
        if self.two_stage:
            heat = simple_nms(pred_hm)
            scores, inds, clses, ys, xs = self._topk(heat, topk=100)

            # Flatten the spatial grid and pick the top-k locations:
            # (batch, h*w, 4) -> (batch, topk, 4).
            pred_boxes_2 = pred_boxes.view(pred_boxes.size(0), -1,
                                           pred_boxes.size(3))
            boxes_2 = boxes.view(*pred_boxes_2.shape)
            inds = inds.unsqueeze(2).expand(inds.size(0), inds.size(1),
                                            pred_boxes_2.size(2))
            pred_boxes_2 = pred_boxes_2.gather(1, inds)  # (batch, 100, 4)
            boxes_2 = boxes_2.gather(1, inds)

            score_thr = 0.01
            scores_keep = scores > score_thr  # (batch, topk)

            # Image index of every kept box, used as the first RoI column.
            # Built directly with the boxes' dtype/device: `new_tensor` on an
            # existing tensor makes an extra copy and emits a UserWarning.
            num_imgs, num_topk = pred_boxes_2.shape[0], pred_boxes_2.shape[1]
            batch_idx = torch.arange(
                num_imgs, dtype=pred_boxes_2.dtype,
                device=pred_boxes_2.device).view(-1, 1, 1).expand(
                    num_imgs, num_topk, 1)[scores_keep]
            pred_boxes_2 = pred_boxes_2[scores_keep]
            boxes_2 = boxes_2[scores_keep].detach()

            # Keep boxes whose target coordinates are all non-negative.
            # `all` states this directly. A `min` reduction over a bool
            # tensor may not be supported on every PyTorch version/backend.
            valid_boxes = (boxes_2 >= 0).all(dim=1)
            batch_idx = batch_idx[valid_boxes]  # (n, 1)
            pred_boxes_2 = pred_boxes_2[valid_boxes]  # (n, 4)
            boxes_2 = boxes_2[valid_boxes]  # (n, 4)
            # RoIs are (batch_idx, box); detached so the RoI coordinates are
            # treated as constants by self.align.
            roi_boxes = torch.cat((batch_idx, pred_boxes_2), dim=1).detach()

            if roi_boxes.size(0) > 0:
                rois = self.align(pred_feat, roi_boxes)  # (n, cha, 7, 7)
                pred_wh2 = self.wh2(rois).view(-1, 4)
                # Refine the detached first-stage boxes with the second-stage
                # outputs scaled by a fixed factor of 16 (origin of the
                # constant not visible here -- TODO confirm).
                pred_boxes_2[:, [0, 1]] = pred_boxes_2[:, [0, 1]].detach() - \
                                          pred_wh2[:, [0, 1]] * 16
                pred_boxes_2[:, [2, 3]] = pred_boxes_2[:, [2, 3]].detach() + \
                                          pred_wh2[:, [2, 3]] * 16
                wh2_loss = giou_loss(pred_boxes_2, boxes_2,
                                     boxes_2.new_ones(boxes_2.size(0)))

        return hm_loss, wh_loss, wh2_loss
Example #11
0
    def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind,
                 center_location):
        """Compute multi-level heatmap and box-size losses.

        Args:
            pred_hm: list(tensor), tensor <=> batch, (batch, 80, h, w).
            pred_wh: list(tensor), tensor <=> batch, (batch, 2, h, w).
            heatmap: tensor, (batch, 80, h*w for all levels).
            wh: tensor, (batch, max_obj*level_num, 2).
            reg_mask: tensor, tensor <=> img, (batch, max_obj*level_num).
            ind: tensor, (batch, max_obj*level_num). Split into
                ``len(pred_hm)`` equal chunks along dim 1, one per level.
            center_location: tensor or None, (batch, max_obj*level_num, 2).
                Only useful when using GIOU.

        Returns:
            tuple: (hm_loss, wh_loss).
        """
        if every_n_local_step(500):
            for level, level_hm in enumerate(pred_hm):
                preview = level_hm.clone().detach().sigmoid_()
                add_feature_summary('centernet_heatmap_lv{}'.format(level),
                                    preview.cpu().numpy())

        # The first level's size is used for clipping GIoU targets below.
        H, W = pred_hm[0].shape[2:]
        level_num = len(pred_hm)

        # Flatten each level spatially and concatenate: (batch, 80, sum(h*w)).
        flat_hm = torch.cat([lvl.view(*lvl.shape[:2], -1) for lvl in pred_hm],
                            dim=-1)
        flat_hm = torch.clamp(flat_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(flat_hm, heatmap, self.gamma) * self.hm_weight

        # (batch, 2, h, w) for all levels => (batch, max_obj*level_num, 2)
        per_level_ind = ind.chunk(level_num, dim=1)
        pred_size = torch.cat(
            [tranpose_and_gather_feat(level_wh, level_ind)
             for level_wh, level_ind in zip(pred_wh, per_level_ind)],
            dim=1)
        mask = reg_mask.unsqueeze(2).expand_as(pred_size).float()
        avg_factor = mask.sum() + 1e-4

        if not self.use_giou:
            size_loss_fn = smooth_l1_loss if self.use_smooth_l1 else weighted_l1
            wh_loss = size_loss_fn(pred_size, wh, mask,
                                   avg_factor=avg_factor) * self.wh_weight
            return hm_loss, wh_loss

        half_pred = pred_size / 2.
        pred_boxes = torch.cat((center_location - half_pred,
                                center_location + half_pred), dim=2)
        # Clip target corners: bottom-right to (W - 1, H - 1), top-left >= 0.
        gt_br = center_location + wh / 2.
        gt_br[:, :, 0] = gt_br[:, :, 0].clamp(max=W - 1)
        gt_br[:, :, 1] = gt_br[:, :, 1].clamp(max=H - 1)
        gt_tl = torch.clamp(center_location - wh / 2., min=0)
        gt_boxes = torch.cat((gt_tl, gt_br), dim=2)
        # NOTE(review): the weight passed here is (batch, n, 4), one entry
        # per box coordinate, and no avg_factor/weight multiplier is applied.
        # Confirm giou_loss expects this rather than a per-box (batch, n)
        # weight.
        wh_loss = giou_loss(pred_boxes, gt_boxes, mask.repeat(1, 1, 2))
        return hm_loss, wh_loss