Example #1
    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_pos_samples, cfg):
        # classification loss: flatten to (num_anchors, cls_out_channels)
        labels = labels.contiguous().view(-1, self.cls_out_channels)
        label_weights = label_weights.contiguous().view(
            -1, self.cls_out_channels)
        cls_score = cls_score.permute(0, 2, 3, 1).contiguous().view(
            -1, self.cls_out_channels)
        if 'ghmc' in cfg:
            # gradient harmonized classification loss (GHM-C)
            loss_cls = self.ghmc_loss.calc(cls_score, labels, label_weights)
        else:
            # sigmoid focal loss, averaged over the positive samples
            loss_cls = weighted_sigmoid_focal_loss(cls_score,
                                                   labels,
                                                   label_weights,
                                                   cfg.gamma,
                                                   cfg.alpha,
                                                   avg_factor=num_pos_samples)
        # regression loss: flatten to (num_anchors, 4)
        bbox_targets = bbox_targets.contiguous().view(-1, 4)
        bbox_weights = bbox_weights.contiguous().view(-1, 4)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).contiguous().view(-1, 4)
        if 'ghmr' in cfg:
            # gradient harmonized regression loss (GHM-R)
            loss_reg = self.ghmr_loss.calc(bbox_pred, bbox_targets,
                                           bbox_weights)
        else:
            loss_reg = weighted_smoothl1(bbox_pred,
                                         bbox_targets,
                                         bbox_weights,
                                         beta=cfg.smoothl1_beta,
                                         avg_factor=num_pos_samples)
        return loss_cls, loss_reg
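These loss_single methods look like per-level hooks that a detection head maps over the feature pyramid, one call per level, before summing the results. A minimal sketch of that driving pattern (the multi_apply helper below is a simplified stand-in written here for illustration, not code from the example above):

def multi_apply(func, *args, **kwargs):
    # Call func once per pyramid level and transpose the per-level result
    # tuples into one list per output, e.g. (losses_cls, losses_reg).
    results = [func(*level_args, **kwargs) for level_args in zip(*args)]
    return tuple(map(list, zip(*results)))

# Hypothetical usage, mapping loss_single over per-level lists of arguments:
# losses_cls, losses_reg = multi_apply(self.loss_single, cls_scores,
#                                      bbox_preds, ..., cfg=cfg)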
Example #2
    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, bbox_locs, num_total_samples, cfg):
        # classification loss over every location
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        loss_cls = weighted_sigmoid_focal_loss(cls_score,
                                               labels,
                                               label_weights,
                                               cfg.gamma,
                                               cfg.alpha,
                                               avg_factor=num_total_samples)
        # localization loss over the positive locations only
        if bbox_targets.size(0) == 0:
            # no positive locations: fall back to a zero regression loss
            loss_reg = bbox_pred.new_zeros(1)
        else:
            # gather predictions at the (batch, y, x) positions in bbox_locs
            bbox_pred = bbox_pred.permute(0, 2, 3, 1)
            bbox_pred = bbox_pred[bbox_locs[:, 0], bbox_locs[:, 1],
                                  bbox_locs[:, 2], :]
            loss_reg = select_iou_loss(bbox_pred,
                                       bbox_targets,
                                       cfg.bbox_reg_weight,
                                       avg_factor=num_total_samples)
        return loss_cls, loss_reg
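The else branch above relies on bbox_locs holding one (batch index, y, x) row per positive location, so advanced indexing on the permuted N x H x W x 4 map gathers exactly those predictions. A self-contained illustration with made-up shapes:

import torch

bbox_pred = torch.randn(2, 4, 8, 8)               # N x 4 x H x W
bbox_locs = torch.tensor([[0, 3, 5], [1, 2, 7]])  # (batch, y, x) per positive
nhwc = bbox_pred.permute(0, 2, 3, 1)              # N x H x W x 4
picked = nhwc[bbox_locs[:, 0], bbox_locs[:, 1], bbox_locs[:, 2], :]
print(picked.shape)  # torch.Size([2, 4])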
Example #3
    def loss_loc_single(self, loc_pred, loc_target, loc_weight, loc_avg_factor,
                        cfg):
        # location loss: focal loss or BCE on the flattened location map
        if self.loc_focal_loss:
            loss_loc = weighted_sigmoid_focal_loss(loc_pred.reshape(-1, 1),
                                                   loc_target.reshape(
                                                       -1, 1).long(),
                                                   loc_weight.reshape(-1, 1),
                                                   avg_factor=loc_avg_factor)
        else:
            loss_loc = weighted_binary_cross_entropy(loc_pred.reshape(-1, 1),
                                                     loc_target.reshape(
                                                         -1, 1).long(),
                                                     loc_weight.reshape(-1, 1),
                                                     avg_factor=loc_avg_factor)
        # optionally rescale the location loss by cfg.loc_weight
        if hasattr(cfg, 'loc_weight'):
            loss_loc = loss_loc * cfg.loc_weight
        return loss_loc
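Both branches above expect the location map flattened to an (N*H*W, 1) column with integer targets (presumably a binary objectness map, given the .long() cast). A quick shape check with toy tensors, sizes made up:

import torch

loc_pred = torch.randn(1, 1, 4, 4)              # N x 1 x H x W logits (assumed)
loc_target = torch.randint(0, 2, (1, 1, 4, 4))  # 0/1 location labels
loc_weight = torch.ones(1, 1, 4, 4)
print(loc_pred.reshape(-1, 1).shape,
      loc_target.reshape(-1, 1).long().shape,
      loc_weight.reshape(-1, 1).shape)
# torch.Size([16, 1]) torch.Size([16, 1]) torch.Size([16, 1])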
Example #4
    def feat_level_select(self, cls_score_list, bbox_pred_list, gt_bboxes,
                          gt_labels, cfg):
        if cfg.online_select:
            # online selection: measure each gt box's combined cls + loc loss
            # at every pyramid level, then keep the cheapest level per box
            num_levels = len(cls_score_list)
            num_boxes = gt_bboxes.size(0)
            feat_losses = gt_bboxes.new_zeros((num_boxes, num_levels))
            device = bbox_pred_list[0].device
            for lvl in range(num_levels):
                stride = self.feat_strides[lvl]
                norm = stride * self.norm_factor
                cls_score = cls_score_list[lvl].permute(1, 2, 0)  # h x w x C
                bbox_pred = bbox_pred_list[lvl].permute(1, 2, 0)  # h x w x 4
                h, w = cls_score.size()[:2]

                proj_boxes = gt_bboxes / stride
                x1, y1, x2, y2 = self.prop_box_bounds(proj_boxes,
                                                      cfg.pos_scale, w, h)

                for i in range(num_boxes):
                    locs_x = torch.arange(x1[i],
                                          x2[i],
                                          device=device,
                                          dtype=torch.long)
                    locs_y = torch.arange(y1[i],
                                          y2[i],
                                          device=device,
                                          dtype=torch.long)
                    locs_xx, locs_yy = self._meshgrid(locs_x, locs_y)
                    avg_factor = locs_xx.size(0)
                    # classification focal loss
                    scores = cls_score[locs_yy, locs_xx, :]
                    labels = gt_labels[i].repeat(avg_factor)
                    label_weights = torch.ones_like(labels).float()
                    loss_cls = weighted_sigmoid_focal_loss(
                        scores, labels, label_weights, cfg.gamma, cfg.alpha,
                        avg_factor)
                    # localization iou loss
                    deltas = bbox_pred[locs_yy, locs_xx, :]
                    # centers of the selected cells in image coordinates
                    shift_x = (locs_x.float() + 0.5) * stride
                    shift_y = (locs_y.float() + 0.5) * stride
                    shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
                    shifts = torch.stack(
                        (shift_xx, shift_yy, shift_xx, shift_yy), dim=-1)
                    # distances from each cell center to the gt box sides
                    # (left, top, right, bottom); normalized by norm below
                    shifts[:, 0] = shifts[:, 0] - gt_bboxes[i, 0]
                    shifts[:, 1] = shifts[:, 1] - gt_bboxes[i, 1]
                    shifts[:, 2] = gt_bboxes[i, 2] - shifts[:, 2]
                    shifts[:, 3] = gt_bboxes[i, 3] - shifts[:, 3]
                    loss_loc = select_iou_loss(deltas, shifts / norm,
                                               cfg.bbox_reg_weight, avg_factor)
                    feat_losses[i, lvl] = loss_cls + loss_loc
            feat_levels = torch.argmin(feat_losses, dim=1)
        else:
            # offline selection: pick the level from the gt box scale alone
            num_levels = len(self.feat_strides)
            lvl0 = cfg.canonical_level
            s0 = cfg.canonical_scale
            assert 0 <= lvl0 < num_levels
            gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
            gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
            s = torch.sqrt(gt_w * gt_h)
            # FPN Eq. (1): lvl = floor(lvl0 + log2(sqrt(w * h) / s0))
            feat_levels = torch.floor(lvl0 + torch.log2(s / s0 + 1e-6))
            feat_levels = torch.clamp(feat_levels, 0, num_levels - 1).int()
        return feat_levels
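The offline branch is the scale-based heuristic referenced by the FPN Eq. (1) comment, lvl = floor(lvl0 + log2(sqrt(w * h) / s0)), clamped to the valid level range. A standalone numeric check (the canonical_level / canonical_scale values here are illustrative assumptions):

import torch

gt_bboxes = torch.tensor([[10., 10., 74., 74.],   # 64 x 64 box
                          [0., 0., 512., 512.]])  # 512 x 512 box
lvl0, s0, num_levels = 2, 224.0, 5                # assumed canonical settings
s = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
               (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
feat_levels = torch.clamp(torch.floor(lvl0 + torch.log2(s / s0 + 1e-6)),
                          0, num_levels - 1).int()
print(feat_levels)  # tensor([0, 3], dtype=torch.int32)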