Example 1
    def __call__(self, pred_hm, pred_heights, pred_reg_xoffset,
                 pred_reg_yoffset, pred_pose, heatmap, heights, reg_xoffset,
                 reg_yoffset, pose, reg_mask, ind):
        """

        Args:
            pred_hm: tensor, (batch, cls, h, w).
            pred_heights: tensor, (batch, 3, h, w).
            pred_reg_xoffset: tensor, (batch, 3, h, w).
            pred_reg_yoffset: tensor, (batch, 3, h, w).
            pred_pose: tensor, (batch, 8, h, w).
            heatmap: tensor, (batch, cls, h, w).
            heights: tensor, (batch, max_obj, 3).
            reg_xoffset: tensor, (batch, max_obj, 3).
            reg_yoffset: tensor, (batch, max_obj, 3).
            pose: tensor, (batch, max_obj).
            reg_mask: tensor, (batch, max_obj, 3).
            ind: tensor, (batch, max_obj).

        Returns:
            Tuple of (hm_loss, heights_loss, xoff_loss, yoff_loss, pose_loss).
        """
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

        # (batch, 3, h, w) => (batch, max_obj, 3)
        pred_heights = tranpose_and_gather_feat(pred_heights, ind)
        pred_reg_xoffset = tranpose_and_gather_feat(pred_reg_xoffset, ind)
        pred_reg_yoffset = tranpose_and_gather_feat(pred_reg_yoffset, ind)

        # cross_entropy only accepts (N,C,d) order
        # (batch, 8, h, w) => (batch, 8, max_obj)
        pred_pose = pred_pose.view(pred_pose.shape[0], pred_pose.shape[1], -1)
        pred_pose = pred_pose.gather(
            2,
            ind.unsqueeze(1).expand(-1, pred_pose.shape[1], -1))

        mask = reg_mask.float()
        avg_factor = mask.sum() + 1e-4

        heights_loss = weighted_l1(
            pred_heights, heights, mask,
            avg_factor=avg_factor) * self.heights_weight
        xoff_loss = weighted_l1(
            pred_reg_xoffset, reg_xoffset, mask,
            avg_factor=avg_factor) * self.xoff_weight
        yoff_loss = weighted_l1(
            pred_reg_yoffset, reg_yoffset, mask,
            avg_factor=avg_factor) * self.yoff_weight

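        # pose is a single label per object, so collapse the (batch, max_obj, 3) mask to one flag per instance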
        instance_mask = mask[..., 0]
        instance_af = instance_mask.sum() + 1e-4
        pose_loss = cross_entropy(
            pred_pose, pose, instance_mask,
            avg_factor=instance_af) * self.pose_weight

        return hm_loss, heights_loss, xoff_loss, yoff_loss, pose_loss
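Example 1 relies on a few helpers that are not shown. `tranpose_and_gather_feat` (spelling as in the code) is assumed to be the usual CenterNet-style gather: flatten the spatial grid of a (batch, C, h, w) map and pick the C-dimensional vector at each `ind` location, giving (batch, max_obj, C); the manual `gather` on `pred_pose` does the same thing but keeps the class dimension second because `cross_entropy` expects (N, C, d). A minimal sketch under that assumption:

    import torch

    def tranpose_and_gather_feat(feat, ind):
        # feat: (batch, C, h, w); ind: (batch, max_obj) flat indices into the h*w grid
        batch, C = feat.shape[0], feat.shape[1]
        feat = feat.permute(0, 2, 3, 1).contiguous().view(batch, -1, C)  # (batch, h*w, C)
        ind = ind.unsqueeze(2).expand(batch, ind.size(1), C)             # (batch, max_obj, C)
        return feat.gather(1, ind)                                       # (batch, max_obj, C)

`weighted_l1` is likewise assumed to be a masked L1 normalised by `avg_factor` (here the number of valid mask entries), roughly:

    def weighted_l1(pred, target, weight, avg_factor=None, reduction='mean'):
        # element-wise L1, zeroed where weight is 0, averaged over the valid entries
        loss = torch.abs(pred - target) * weight
        if reduction != 'mean':
            return loss.sum()
        if avg_factor is None:
            avg_factor = weight.sum() + 1e-4
        return loss.sum() / avg_factor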
Example 2
    def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind,
                 reg_offset, center_location):
        """

        Args:
            pred_hm: tensor, (batch, 80, h, w).
            pred_wh: tensor, (batch, 2, h, w).
            heatmap: tensor, (batch, 80, h, w).
            wh: tensor, (batch, max_obj, 2).
            reg_mask: tensor, tensor <=> img, (batch, max_obj).
            ind: tensor, (batch, max_obj).
            reg_offset: tensor, (batch, max_obj, 2).
            center_location: tensor, (batch, max_obj, 2). Only useful when using GIOU.

        Returns:
            Tuple of (hm_loss, wh_loss).
        """
        H, W = pred_hm.shape[2:]
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

        # (batch, 2, h, w) => (batch, max_obj, 2)
        pred = tranpose_and_gather_feat(pred_wh, ind)
        mask = reg_mask.unsqueeze(2).expand_as(pred).float()
        avg_factor = mask.sum() + 1e-4

        if self.use_giou:
            pred_boxes = torch.cat(
                (center_location - pred / 2., center_location + pred / 2.),
                dim=2)
            box_br = center_location + wh / 2.
            box_br[:, :, 0] = box_br[:, :, 0].clamp(max=W - 1)
            box_br[:, :, 1] = box_br[:, :, 1].clamp(max=H - 1)
            boxes = torch.cat(
                (torch.clamp(center_location - wh / 2., min=0), box_br), dim=2)
            mask_no_expand = mask[:, :, 0]
            wh_loss = giou_loss(pred_boxes, boxes,
                                mask_no_expand) * self.giou_weight
        else:
            if self.use_smooth_l1:
                wh_loss = smooth_l1_loss(
                    pred, wh, mask, avg_factor=avg_factor) * self.wh_weight
            else:
                wh_loss = weighted_l1(pred, wh, mask,
                                      avg_factor=avg_factor) * self.wh_weight

        return hm_loss, wh_loss
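`ct_focal_loss`, shared by all of these examples, is assumed to be the modified focal loss from CornerNet/CenterNet applied to Gaussian-splatted heatmaps; the `clamp` on the sigmoid output above exists to keep its log terms finite. A sketch of that assumed form:

    import torch

    def ct_focal_loss(pred, gt, gamma=2.0):
        # pred: clamped sigmoid heatmap; gt: Gaussian-splatted target in [0, 1]
        pos_mask = gt.eq(1).float()              # exact object centers
        neg_mask = gt.lt(1).float()              # everything else
        neg_weights = torch.pow(1. - gt, 4)      # soften penalties near a center
        pos_loss = torch.log(pred) * torch.pow(1. - pred, gamma) * pos_mask
        neg_loss = torch.log(1. - pred) * torch.pow(pred, gamma) * neg_weights * neg_mask
        num_pos = pos_mask.sum()
        if num_pos == 0:
            return -neg_loss.sum()
        return -(pos_loss.sum() + neg_loss.sum()) / num_pos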
Example 3
    def forward(self, objness, box_centers, box_scales, cls_preds, objness_t,
                center_t, scale_t, weight_t, class_t, class_mask):
        # normalization count over all dimensions except the batch dimension
        denorm = torch.tensor(objness_t.size())[1:].prod().to(torch.float32)
        class_mask = class_mask.to(torch.float32)
        weight_t = weight_t * objness_t
        hard_objness_t = torch.where(objness_t > 0, torch.ones_like(objness_t),
                                     objness_t)
        new_objness_mask = torch.where(objness_t > 0, objness_t,
                                       (objness_t >= 0).to(torch.float32))
        obj_loss = F.binary_cross_entropy_with_logits(
            objness, hard_objness_t, new_objness_mask) * denorm
        center_loss = F.binary_cross_entropy_with_logits(
            box_centers, center_t, weight_t) * denorm * 2
        scale_loss = weighted_l1(
            box_scales, scale_t, weight_t, reduction='mean') * denorm * 2
        denorm_class = torch.tensor(
            class_t.size())[1:].prod(dtype=torch.float32)
        class_mask = class_mask * objness_t
        cls_loss = F.binary_cross_entropy_with_logits(
            cls_preds, class_t, class_mask) * denorm_class
        return obj_loss, center_loss, scale_loss, cls_loss
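The two `torch.where` calls above encode the YOLOv3-style target convention assumed here: `objness_t > 0` marks a positive cell carrying a soft weight, `0` marks a negative cell, and negative values mark cells to ignore. A tiny worked example of what they produce:

    import torch

    objness_t = torch.tensor([[0.7], [0.0], [-1.0]])
    hard_objness_t = torch.where(objness_t > 0, torch.ones_like(objness_t), objness_t)
    # -> [[1.0], [0.0], [-1.0]]: positives are trained toward 1 regardless of the soft value
    new_objness_mask = torch.where(objness_t > 0, objness_t,
                                   (objness_t >= 0).to(torch.float32))
    # -> [[0.7], [1.0], [0.0]]: soft weight on positives, full weight on negatives,
    #    zero weight on ignored cells, so they drop out of the BCE term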
Example 4
    def __call__(self, pred_hm, pred_wh, pred_reg_offset, heatmap, wh,
                 reg_mask, ind, reg_offset, center_location):
        """

        Args:
            pred_hm: tensor, (batch, 80, h, w).
            pred_wh: tensor, (batch, 2, h, w).
            pred_reg_offset: None or tensor, (batch, 2, h, w).
            heatmap: tensor, (batch, 80, h, w).
            wh: tensor, (batch, max_obj, 2).
            reg_mask: tensor, tensor <=> img, (batch, max_obj).
            ind: tensor, (batch, max_obj).
            reg_offset: tensor, (batch, max_obj, 2).
            center_location: tensor, (batch, max_obj, 2). Only useful when using GIOU.

        Returns:
            Tuple of (hm_loss, wh_loss, off_loss).
        """
        H, W = pred_hm.shape[2:]
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(pred_hm, heatmap) * self.hm_weight

        # (batch, 2, h, w) => (batch, max_obj, 2)
        pred = tranpose_and_gather_feat(pred_wh, ind)
        mask = reg_mask.unsqueeze(2).expand_as(pred).float()
        avg_factor = mask.sum() + 1e-4

        if self.use_giou:
            pred_boxes = torch.cat(
                (center_location - pred / 2., center_location + pred / 2.),
                dim=2)
            box_br = center_location + wh / 2.
            box_br[:, :, 0] = box_br[:, :, 0].clamp(max=W - 1)
            box_br[:, :, 1] = box_br[:, :, 1].clamp(max=H - 1)
            boxes = torch.cat(
                (torch.clamp(center_location - wh / 2., min=0), box_br), dim=2)
            mask_no_expand = mask[:, :, 0]
            wh_loss = giou_loss(pred_boxes, boxes,
                                mask_no_expand) * self.giou_weight
        else:
            if self.use_smooth_l1:
                wh_loss = smooth_l1_loss(
                    pred, wh, mask, avg_factor=avg_factor) * self.wh_weight
            else:
                wh_loss = weighted_l1(pred, wh, mask,
                                      avg_factor=avg_factor) * self.wh_weight

        off_loss = hm_loss.new_tensor(0.)
        if self.use_reg_offset:
            pred_reg = tranpose_and_gather_feat(pred_reg_offset, ind)
            off_loss = weighted_l1(
                pred_reg, reg_offset, mask,
                avg_factor=avg_factor) * self.off_weight

            add_summary('centernet',
                        gt_reg_off=reg_offset[reg_offset > 0].mean().item())

        if every_n_local_step(500):
            add_feature_summary('centernet/heatmap',
                                pred_hm.detach().cpu().numpy())
            add_feature_summary('centernet/gt_heatmap',
                                heatmap.detach().cpu().numpy())
            if self.use_reg_offset:
                add_feature_summary('centernet/reg_offset',
                                    pred_reg_offset.detach().cpu().numpy())

        return hm_loss, wh_loss, off_loss
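`giou_loss` is not shown either; the sketch below follows a standard GIoU formulation with the per-instance mask used in this example (the exact signature and the trailing normalisation are assumptions):

    import torch

    def giou_loss(pred_boxes, gt_boxes, mask, eps=1e-7):
        # pred_boxes, gt_boxes: (batch, max_obj, 4) as (x1, y1, x2, y2); mask: (batch, max_obj)
        lt = torch.max(pred_boxes[..., :2], gt_boxes[..., :2])
        rb = torch.min(pred_boxes[..., 2:], gt_boxes[..., 2:])
        wh = (rb - lt).clamp(min=0)
        inter = wh[..., 0] * wh[..., 1]
        area_p = (pred_boxes[..., 2] - pred_boxes[..., 0]) * (pred_boxes[..., 3] - pred_boxes[..., 1])
        area_g = (gt_boxes[..., 2] - gt_boxes[..., 0]) * (gt_boxes[..., 3] - gt_boxes[..., 1])
        union = area_p + area_g - inter + eps
        iou = inter / union
        # smallest box enclosing both prediction and ground truth
        wh_c = (torch.max(pred_boxes[..., 2:], gt_boxes[..., 2:]) -
                torch.min(pred_boxes[..., :2], gt_boxes[..., :2])).clamp(min=0)
        area_c = wh_c[..., 0] * wh_c[..., 1] + eps
        giou = iou - (area_c - union) / area_c
        return ((1. - giou) * mask).sum() / (mask.sum() + 1e-4)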
Example 5
    def __call__(self, pred_hm, pred_wh, heatmap, wh, reg_mask, ind,
                 center_location):
        """

        Args:
            pred_hm: list(tensor), tensor <=> level, (batch, 80, h, w).
            pred_wh: list(tensor), tensor <=> level, (batch, 2, h, w).
            heatmap: tensor, (batch, 80, sum of h*w over all levels).
            wh: tensor, (batch, max_obj*level_num, 2).
            reg_mask: tensor, tensor <=> img, (batch, max_obj*level_num).
            ind: tensor, (batch, max_obj*level_num).
            center_location: tensor or None, (batch, max_obj*level_num, 2). Only useful when
                using GIOU.

        Returns:
            Tuple of (hm_loss, wh_loss).
        """
        if every_n_local_step(500):
            for lvl, hm in enumerate(pred_hm):
                hm_summary = hm.clone().detach().sigmoid_()
                add_feature_summary('centernet_heatmap_lv{}'.format(lvl),
                                    hm_summary.cpu().numpy())

        H, W = pred_hm[0].shape[2:]
        level_num = len(pred_hm)
        pred_hm = torch.cat([x.view(*x.shape[:2], -1) for x in pred_hm],
                            dim=-1)
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(pred_hm, heatmap, self.gamma) * self.hm_weight

        # (batch, 2, h, w) for all levels => (batch, max_obj*level_num, 2)
        ind_levels = ind.chunk(level_num, dim=1)
        pred_wh_pruned = []
        for pred_wh_per_lvl, ind_lvl in zip(pred_wh, ind_levels):
            pred_wh_pruned.append(
                tranpose_and_gather_feat(pred_wh_per_lvl, ind_lvl))
        pred_wh_pruned = torch.cat(pred_wh_pruned,
                                   dim=1)  # (batch, max_obj*level_num, 2)
        mask = reg_mask.unsqueeze(2).expand_as(pred_wh_pruned).float()
        avg_factor = mask.sum() + 1e-4

        if self.use_giou:
            pred_boxes = torch.cat((center_location - pred_wh_pruned / 2.,
                                    center_location + pred_wh_pruned / 2.),
                                   dim=2)
            box_br = center_location + wh / 2.
            box_br[:, :, 0] = box_br[:, :, 0].clamp(max=W - 1)
            box_br[:, :, 1] = box_br[:, :, 1].clamp(max=H - 1)
            box_tl = torch.clamp(center_location - wh / 2., min=0)
            boxes = torch.cat((box_tl, box_br), dim=2)
            mask_expand_4 = mask.repeat(1, 1, 2)
            wh_loss = giou_loss(pred_boxes, boxes, mask_expand_4)
        else:
            if self.use_smooth_l1:
                wh_loss = smooth_l1_loss(
                    pred_wh_pruned, wh, mask,
                    avg_factor=avg_factor) * self.wh_weight
            else:
                wh_loss = weighted_l1(
                    pred_wh_pruned, wh, mask,
                    avg_factor=avg_factor) * self.wh_weight

        return hm_loss, wh_loss
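A toy check of the multi-level layout this variant assumes: each level's heatmap is flattened over its h*w grid and the levels are concatenated along the last dimension, so the `heatmap` target must store all levels back to back (and `ind`, `wh`, `reg_mask` hold `max_obj` entries per level in the same order).

    import torch

    pred_hm = [torch.zeros(1, 80, 4, 4), torch.zeros(1, 80, 2, 2)]  # two levels: 4x4 and 2x2
    flat = torch.cat([x.view(*x.shape[:2], -1) for x in pred_hm], dim=-1)
    assert flat.shape == (1, 80, 20)  # 16 + 4 cells, matching a (batch, 80, 20) target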