예제 #1
0
파일: detection.py 프로젝트: jgwak/GSDN
  def loss_refinement_head(self, datum, output):
    """Compute the refinement-head loss and its per-term breakdown.

    Args:
      datum: Not read by this method.
      output: Dict of network outputs. Optional prediction entries are 'ref_class_logits',
        'ref_bbox' and 'ref_rotdelta'; their targets are read from 'ref_target_class_ids',
        'ref_target_bbox' and 'ref_target_rots', each a sequence of tensors that is concatenated.

    Returns:
      Tuple `(loss, loss_details)`. `loss` is the weighted sum of the class, bbox and (when
      present) rotation terms; it stays zero when `self.train_rpnonly` is set or when the number
      of class logits does not exceed `train_ref_min_sample_per_batch * batch_size`.
      `loss_details` maps term names to scalar tensors.
    """
    if self.train_rpnonly:
      return torch.scalar_tensor(0.).to(self.device), dict()

    has_class = output.get('ref_class_logits') is not None
    has_bbox = output.get('ref_bbox') is not None
    train_rotation = output.get('ref_rotdelta') is not None

    # The class labels and the positive-ROI indices derived from them are shared by all three
    # terms. Resolve them once here so no term depends on another term's branch having run.
    if has_class or has_bbox or train_rotation:
      ref_labels = torch.cat(output['ref_target_class_ids'])
    if has_bbox or train_rotation:
      positive_rois = torch.where(ref_labels > 0)[0]
      # Label 0 is skipped above, so the per-class prediction slot is `label - 1`.
      bbox_idx = ref_labels[positive_rois] - 1

    optimize_ref = False
    if has_class:
      min_samples = self.config.train_ref_min_sample_per_batch * self.config.batch_size
      optimize_ref = output['ref_class_logits'].shape[0] > min_samples
      ref_class_loss = self.ref_class_criterion(output['ref_class_logits'], ref_labels)
    else:
      ref_class_loss = torch.scalar_tensor(0.).to(self.device)

    if has_bbox:
      ref_bbox_pred = output['ref_bbox'][positive_rois, bbox_idx]
      ref_bbox_gt = torch.cat(output['ref_target_bbox'])
      ref_bbox_loss = self.ref_bbox_criterion(ref_bbox_pred, ref_bbox_gt)
    else:
      ref_bbox_loss = torch.scalar_tensor(0.).to(self.device)

    ref_rotation_diff = torch.scalar_tensor(0.).to(self.device)
    if train_rotation:
      ref_rotation_pred = output['ref_rotdelta'][positive_rois, bbox_idx]
      ref_rotation_gt = torch.cat(output['ref_target_rots'])
      ref_rotation_loss = self.ref_rotation_criterion(ref_rotation_pred, ref_rotation_gt)
      # Mean absolute difference between decoded prediction and target, reported for logging.
      ref_rotation_pred = self.ref_rotation_criterion.pred(ref_rotation_pred)
      ref_rotation_diff = torch.abs(
          detection_utils.normalize_rotation(ref_rotation_pred - ref_rotation_gt)).mean()
    else:
      ref_rotation_loss = torch.scalar_tensor(0.).to(self.device)

    # A fresh tensor is required here: it is accumulated in place below.
    loss = torch.scalar_tensor(0.).to(self.device)
    if optimize_ref:
      loss += self.config.ref_class_weight * ref_class_loss \
          + self.config.ref_bbox_weight * ref_bbox_loss
      if train_rotation:
        loss += self.config.ref_rotation_weight * ref_rotation_loss

    loss_details = {
        'ref_class_loss': ref_class_loss,
        'ref_bbox_loss': ref_bbox_loss,
        # Stored as a scalar with the same dtype/device as `loss`.
        'optimize_ref': torch.tensor([optimize_ref], dtype=torch.int).to(loss)[0],
    }

    if train_rotation:
      loss_details['ref_rotation_loss'] = ref_rotation_loss
      loss_details['ref_rotation_diff'] = ref_rotation_diff

    return loss, loss_details
예제 #2
0
파일: detection.py 프로젝트: jgwak/GSDN
  def loss_region_proposal(self, datum, output):
    """Compute the region-proposal loss from dense anchor targets.

    Args:
      datum: Dict with 'rpn_match' (anchors with a non-zero value enter the class loss, anchors
        equal to 1 are the positives), 'rpn_bbox' (reshaped to rows of 6; all-zero rows are
        dropped) and, when rotation is trained, 'rpn_rotation'.
      output: Dict with 'rpn_class_logits' (reshaped to rows of 2), 'rpn_bbox' (reshaped to rows
        of 6) and optionally 'rpn_rotation'.

    Returns:
      Tuple `(loss, loss_details)` with the weighted total loss and a dict of the individual
      scalar terms.
    """
    # Compute RPN class loss.
    rpn_match_gt = datum['rpn_match'].flatten()
    # squeeze(-1) keeps a 1-D index even when exactly one anchor is non-zero; a bare squeeze()
    # would collapse it to a 0-d tensor and make `.size(0)` fail.
    rpn_match_mask = rpn_match_gt.nonzero().squeeze(-1)
    if rpn_match_mask.size(0) > 0:
      rpn_match_gt_valid = (rpn_match_gt[rpn_match_mask] == 1).long()
      rpn_class_logits = output['rpn_class_logits'].reshape(-1, 2)[rpn_match_mask]
      rpn_class_loss = self.rpn_class_criterion(rpn_class_logits, rpn_match_gt_valid)
    else:
      rpn_class_loss = torch.scalar_tensor(0.).to(self.device)

    # Compute RPN bbox regression loss.
    rpn_bbox_mask = torch.where(rpn_match_gt == 1)[0]
    if rpn_bbox_mask.size(0) > 0:
      rpn_bbox_pred = output['rpn_bbox'].reshape(-1, 6)[rpn_bbox_mask]
      rpn_bbox_gt = datum['rpn_bbox'].reshape(-1, 6)
      # All-zero target rows are treated as padding and removed.
      # NOTE(review): presumably the remaining rows line up one-to-one with the positive anchors
      # selected above - confirm against the code that builds `datum`.
      rpn_gt_mask = ~torch.all(rpn_bbox_gt == 0, 1)
      rpn_bbox_gt = rpn_bbox_gt[rpn_gt_mask]
      rpn_bbox_loss = self.rpn_bbox_criterion(rpn_bbox_pred, rpn_bbox_gt)
    else:
      rpn_bbox_loss = torch.scalar_tensor(0.).to(self.device)

    loss = self.config.rpn_class_weight * rpn_class_loss \
        + self.config.rpn_bbox_weight * rpn_bbox_loss

    # Compute RPN bbox rotation loss.
    train_rotation = output.get('rpn_rotation') is not None
    rpn_rotation_diff = torch.scalar_tensor(0.).to(self.device)
    if train_rotation:
      if rpn_bbox_mask.size(0) > 0:
        rpn_rotation_pred = output['rpn_rotation'].reshape(
            -1, self.rpn_rotation_criterion.NUM_OUTPUT)[rpn_bbox_mask]
        # Reuses the padding mask computed for the bbox targets above.
        rpn_rotation_gt = datum['rpn_rotation'].flatten()[rpn_gt_mask]
        rpn_rotation_loss = self.rpn_rotation_criterion(rpn_rotation_pred, rpn_rotation_gt)
        # Mean absolute difference between decoded prediction and target, for logging only.
        rpn_rotation_pred = self.rpn_rotation_criterion.pred(rpn_rotation_pred)
        rpn_rotation_diff = torch.abs(
            detection_utils.normalize_rotation(rpn_rotation_pred - rpn_rotation_gt)).mean()
      else:
        rpn_rotation_loss = torch.scalar_tensor(0.).to(self.device)
      loss += self.config.rpn_rotation_weight * rpn_rotation_loss

    loss_details = {
        'rpn_class_loss': rpn_class_loss,
        'rpn_bbox_loss': rpn_bbox_loss,
    }

    if train_rotation:
      loss_details['rpn_rotation_loss'] = rpn_rotation_loss
      loss_details['rpn_rotation_diff'] = rpn_rotation_diff

    return loss, loss_details
예제 #3
0
파일: detection.py 프로젝트: jgwak/GSDN
  def loss_region_proposal(self, datum, output):
    """Compute the region-proposal loss from per-layer sparse outputs.

    Args:
      datum: Dict with per-layer sequences 'sparse_anchor_coords', 'sparse_rpn_match',
        'sparse_rpn_bbox' and, when rotation is trained, 'sparse_rpn_rotation'.
      output: Dict with per-layer sequences 'rpn_class_logits' and 'rpn_bbox' (entries expose
        their features as `.F`, or are None), optional 'rpn_rotation', 'fpn_targets' /
        'fpn_classification', and optional cached 'rpn2anchor_maps'. When the cache is missing it
        is computed here and stored back into `output`.

    Returns:
      Tuple `(loss, loss_details)` with the weighted total loss and a dict of the individual
      scalar terms.
    """
    train_sfpn = output['fpn_targets'] is not None
    if train_sfpn:
      num_layers = len(self.backbone.OUT_PIXEL_DIST)
      sfpn_class_loss = torch.scalar_tensor(0.).to(self.device)
      for logits, targets in zip(output['fpn_classification'], output['fpn_targets']):
        if targets is None:
          continue
        # Dividing by the layer count keeps the contribution an average over layers.
        sfpn_class_loss += self.sfpn_class_criterion(logits.F, targets.to(self.device)) / num_layers

    # A missing or None 'rpn_rotation' entry means rotation is not trained.
    rpn_rotations = output.get('rpn_rotation')

    if output.get('rpn2anchor_maps') is None:
      output['rpn2anchor_maps'] = []
      for i, (rpn_class_logits, anchor) in enumerate(
              zip(output['rpn_class_logits'], datum['sparse_anchor_coords'])):
        if rpn_class_logits is None:
          output['rpn2anchor_maps'].append((None, None))
          continue
        # All heads of one layer must share the same coordinates for the mapping to be reusable.
        assert rpn_class_logits.coords_key == output['rpn_bbox'][i].coords_key
        if rpn_rotations is not None and rpn_rotations[i] is not None:
          assert rpn_class_logits.coords_key == rpn_rotations[i].coords_key
        output['rpn2anchor_maps'].append(detection_utils.map_coordinates(
            rpn_class_logits, anchor, check_input_map=True))

    # Class loss: anchors with a non-zero match value participate; value 1 is the positive label.
    rpn_masks = []
    rpn_class_preds = []
    rpn_class_gts = []
    for rpn_class_logits, rpn_match, (rpn2anchor, anchor2rpn) in zip(
            output['rpn_class_logits'], datum['sparse_rpn_match'], output['rpn2anchor_maps']):
      if rpn_class_logits is None:
        rpn_masks.append(None)
        continue
      rpn_match = rpn_match[anchor2rpn].flatten()
      # Positive anchors of this layer, reused by the bbox and rotation terms below.
      rpn_masks.append(torch.where(rpn_match == 1)[0])
      rpn_match_mask = rpn_match.nonzero().squeeze(-1)
      if rpn_match_mask.size(0) > 0:
        rpn_class_preds.append(rpn_class_logits.F[rpn2anchor].reshape(-1, 2)[rpn_match_mask])
        rpn_class_gts.append((rpn_match[rpn_match_mask] == 1).long().to(self.device))
    rpn_class_loss = torch.scalar_tensor(0.).to(self.device)
    if rpn_class_preds:
      rpn_class_loss = self.rpn_class_criterion(torch.cat(rpn_class_preds),
                                                torch.cat(rpn_class_gts))

    # Bbox regression loss over positive anchors of every layer.
    rpn_bbox_preds = []
    rpn_bbox_gts = []
    for rpn_bbox_pred, rpn_bbox_gt, rpn_mask, (rpn2anchor, anchor2rpn) in zip(
            output['rpn_bbox'], datum['sparse_rpn_bbox'], rpn_masks, output['rpn2anchor_maps']):
      if rpn_bbox_pred is not None and rpn_mask.size(0) > 0:
        rpn_bbox_preds.append(rpn_bbox_pred.F[rpn2anchor].reshape(-1, 6)[rpn_mask])
        rpn_bbox_gts.append(rpn_bbox_gt[anchor2rpn].reshape(-1, 6)[rpn_mask].to(rpn_bbox_pred.F))
    rpn_bbox_loss = torch.scalar_tensor(0.).to(self.device)
    if rpn_bbox_preds:
      rpn_bbox_loss = self.rpn_bbox_criterion(torch.cat(rpn_bbox_preds), torch.cat(rpn_bbox_gts))

    loss = self.config.rpn_class_weight * rpn_class_loss \
        + self.config.rpn_bbox_weight * rpn_bbox_loss

    if train_sfpn:
      loss += self.config.sfpn_class_weight * sfpn_class_loss

    # Rotation loss, only when at least one layer produced a rotation output.
    train_rotation = rpn_rotations is not None and any(rot is not None for rot in rpn_rotations)
    if train_rotation:
      rpn_rotation_preds = []
      rpn_rotation_gts = []
      for rpn_rotation_pred, rpn_rotation_gt, rpn_mask, (rpn2anchor, anchor2rpn) in zip(
              rpn_rotations, datum['sparse_rpn_rotation'], rpn_masks,
              output['rpn2anchor_maps']):
        if rpn_rotation_pred is not None and rpn_mask.size(0) > 0:
          rpn_rotation_preds.append(rpn_rotation_pred.F[rpn2anchor].reshape(
              -1, self.rpn_rotation_criterion.NUM_OUTPUT)[rpn_mask])
          rpn_rotation_gts.append(
              rpn_rotation_gt[anchor2rpn].flatten()[rpn_mask].to(rpn_rotation_pred.F))
      rpn_rotation_loss = torch.scalar_tensor(0.).to(self.device)
      rpn_rotation_diff = torch.scalar_tensor(0.).to(self.device)
      if rpn_rotation_preds:
        rpn_rotation_preds = torch.cat(rpn_rotation_preds)
        rpn_rotation_gts = torch.cat(rpn_rotation_gts)
        rpn_rotation_loss = self.rpn_rotation_criterion(rpn_rotation_preds, rpn_rotation_gts)
        # Mean absolute difference between decoded prediction and target, for logging only.
        rpn_rotation_preds = self.rpn_rotation_criterion.pred(rpn_rotation_preds)
        rpn_rotation_diff = torch.abs(detection_utils.normalize_rotation(
            rpn_rotation_preds - rpn_rotation_gts)).mean()
      loss += self.config.rpn_rotation_weight * rpn_rotation_loss

    loss_details = {
        'rpn_class_loss': rpn_class_loss,
        'rpn_bbox_loss': rpn_bbox_loss,
    }

    if train_sfpn:
      loss_details['sfpn_class_loss'] = sfpn_class_loss

    if train_rotation:
      loss_details['rpn_rotation_loss'] = rpn_rotation_loss
      loss_details['rpn_rotation_diff'] = rpn_rotation_diff

    return loss, loss_details
예제 #4
0
파일: detection.py 프로젝트: jgwak/GSDN
 def detection_refinement(self, b_probs, b_rois, b_deltas, b_rots, b_rotdeltas):
   """Convert refinement-head predictions into per-sample detection arrays.

   Args:
     b_probs: Class probabilities for all ROIs of the batch stacked along dim 0, or None.
     b_rois: Per-sample ROI tensors; their first dimensions define how the stacked tensors split.
     b_deltas: Per-class box deltas stacked like `b_probs`.
     b_rots: Per-sample ROI rotations, or None when rotation is not predicted.
     b_rotdeltas: Per-class rotation deltas stacked like `b_probs` (read only with `b_rots`).

   Returns:
     List with one numpy array per sample. Populated rows hold the unnormalized box in the first
     six columns, then the rotation (only with `b_rots`), the class id and the score.
   """
   has_rotation = b_rots is not None
   if b_probs is None:
     # NOTE(review): the placeholder is 9 columns wide without rotations and 8 with them, which
     # is the opposite of populated rows - confirm what callers expect before changing it.
     return [np.zeros((0, 8 if has_rotation else 9))]

   roi_counts = [rois.shape[0] for rois in b_rois]
   num_samples = sum(roi_counts)
   assert num_samples == b_probs.shape[0] == b_deltas.shape[0]
   if has_rotation:
     assert num_samples == sum(rots.shape[0] for rots in b_rots) == b_rotdeltas.shape[0]

   # (start, end) row range of every sample inside the stacked tensors.
   spans = []
   offset = 0
   for count in roi_counts:
     spans.append((offset, offset + count))
     offset += count
   b_probs = [b_probs[lo:hi] for lo, hi in spans]
   b_deltas = [b_deltas[lo:hi] for lo, hi in spans]
   if has_rotation:
     b_rotdeltas = [b_rotdeltas[lo:hi] for lo, hi in spans]

   detections = []
   for i, (probs, rois, deltas) in enumerate(zip(b_probs, b_rois, b_deltas)):
     rois = rois.reshape(-1, rois.shape[-1])
     class_ids = torch.argmax(probs, dim=1)
     rows = range(probs.shape[0])
     # Per-class outputs have no slot for class 0, hence the shift by one. ROIs predicted as
     # class 0 index slot -1 here and are discarded by `keep` below.
     delta_cols = class_ids - 1
     class_scores = probs[rows, class_ids]
     class_deltas = deltas[rows, delta_cols]
     class_deltas *= torch.tensor(self.config.rpn_bbox_std).to(deltas)
     refined_rois = detection_utils.apply_box_deltas(rois, class_deltas,
                                                     self.config.normalize_bbox)
     if has_rotation:
       rot_deltas = self.ref_rotation_criterion.pred(b_rotdeltas[i][rows, delta_cols])
       refined_rots = detection_utils.normalize_rotation(b_rots[i] + rot_deltas)

     # Keep ROIs predicted as a non-zero class and, if configured, above the score threshold.
     keep = torch.where(class_ids > 0)[0].cpu().numpy()
     if self.config.detection_min_confidence:
       confident = torch.where(class_scores > self.config.detection_min_confidence)[0]
       keep = np.array(list(set(confident.cpu().numpy()).intersection(keep)))
     if keep.size == 0:
       detections.append(np.zeros((0, 8)))
       continue

     kept_class_ids = class_ids[keep] - 1
     kept_scores = class_scores[keep]
     kept_rois = refined_rois[keep]
     if has_rotation:
       kept_rots = refined_rots[keep]

     # Suppress overlapping boxes independently for every class.
     nms_scores, nms_rois, nms_classes, nms_rots = [], [], [], []
     for class_id in torch.unique(kept_class_ids):
       in_class = kept_class_ids == class_id
       class_rois = kept_rois[in_class]
       class_rots = kept_rots[in_class] if has_rotation else None
       nms_roi, nms_rot, nms_score = detection_utils.non_maximum_suppression(
           class_rois, class_rots, kept_scores[in_class],
           self.config.detection_nms_threshold, self.config.detection_max_instances,
           self.config.detection_rot_nms, self.config.detection_aggregate_overlap)
       nms_rois.append(nms_roi)
       nms_scores.append(nms_score)
       nms_classes.append(torch.ones(len(nms_score)).to(class_rois) * class_id)
       if has_rotation:
         if self.config.normalize_rotation2:
           nms_rot = nms_rot / 2 + np.pi / 2
         nms_rots.append(nms_rot)
     nms_scores = torch.cat(nms_scores)
     nms_rois = torch.cat(nms_rois)
     nms_classes = torch.cat(nms_classes)

     # Cap the merged per-class results at the highest-scoring detections.
     num_keep = min(self.config.detection_max_instances, nms_scores.shape[0])
     top = torch.topk(nms_scores, num_keep)[1]
     boxes = detection_utils.unnormalize_boxes(
         nms_rois[top].cpu().numpy(), self.config.max_ptc_size)
     sample_dets = np.hstack((boxes, nms_classes[top, None].cpu().numpy(),
                              nms_scores[top, None].cpu().numpy()))
     if has_rotation:
       # Insert the rotation column between the six box columns and the class/score columns.
       rot_column = torch.cat(nms_rots)[top, None].cpu().numpy()
       sample_dets = np.hstack((sample_dets[:, :6], rot_column, sample_dets[:, 6:]))
     detections.append(sample_dets)
   return detections
예제 #5
0
파일: detection.py 프로젝트: jgwak/GSDN
  def _get_detection_target(self, b_proposals, b_rotations, b_gt_classes, b_gt_boxes,
                            b_gt_rotations):
    """Sample positive/negative ROIs from proposals and build refinement-head targets.

    Args:
      b_proposals: Per-sample proposal tensors; all-zero rows are dropped before matching.
      b_rotations: Per-sample proposal rotations, or None when rotation is not used.
      b_gt_classes: Per-sample numpy arrays of ground-truth class ids.
      b_gt_boxes: Per-sample numpy arrays of ground-truth boxes.
      b_gt_rotations: Per-sample numpy arrays of ground-truth rotations (read only when
        `b_rotations` is given).

    Returns:
      Tuple `(b_rois, b_rots, b_roi_gt_classes, b_deltas, b_rot_deltas,
      b_roi_gt_box_assignment)` of per-sample lists. ROIs are ordered positives first, then
      negatives; class targets are shifted by one so that 0 marks a negative ROI; box and
      rotation deltas are given for the positives only. `b_rots` and `b_rot_deltas` are None
      when `b_rotations` is None.
    """
    def _random_subsample_idxs(indices, num_samples):
      # Draw `num_samples` indices without replacement when more than that many are available.
      if indices.size(0) > num_samples:
        return torch.from_numpy(np.random.choice(indices.cpu().numpy(), num_samples,
                                replace=False)).to(indices)
      return indices
    b_rois, b_roi_gt_classes, b_deltas, b_roi_gt_box_assignment = [], [], [], []
    b_rots, b_rot_deltas = (None, None) if b_rotations is None else ([], [])
    for i, (proposals, gt_classes, gt_boxes) in enumerate(
            zip(b_proposals, b_gt_classes, b_gt_boxes)):
      with torch.no_grad():
        proposals = proposals[~torch.all(proposals == 0, 1)]
        gt_boxes = torch.from_numpy(gt_boxes).to(proposals)
        if gt_boxes.shape[0] == 0 or proposals.shape[0] == 0:
          # Nothing to match for this sample: emit empty placeholders.
          # NOTE(review): the placeholder deltas/rois have `gt_boxes.shape[0]` rows and no entry
          # is appended to b_roi_gt_box_assignment for this sample - confirm callers expect that.
          b_deltas.append(torch.zeros((gt_boxes.shape[0], proposals.shape[1])).to(proposals))
          b_roi_gt_classes.append(torch.zeros(0).to(proposals).long())
          b_rois.append(torch.zeros((gt_boxes.shape[0], proposals.shape[1])).to(proposals))
          if b_rotations is not None:
            b_rots.append(torch.zeros(0).to(proposals))
            b_rot_deltas.append(torch.zeros(0).to(proposals))
          continue
        gt_classes = torch.from_numpy(gt_classes).to(proposals)
        pred_rotation, gt_rotation = None, None
        if self.config.ref_rotation_overlap and b_rotations is not None:
          pred_rotation = b_rotations[i]
          gt_rotation = b_gt_rotations[i]
        overlaps = detection_utils.compute_overlaps(proposals, gt_boxes, pred_rotation, gt_rotation)
        roi_iou_max = overlaps.max(1)[0]
        positive_roi = roi_iou_max > self.config.detection_match_positive_iou_threshold
        if self.config.force_proposal_match:
          # Guarantee every ground-truth box at least one positive ROI by promoting its
          # best-overlapping proposal. Updating the mask (rather than only the index list) also
          # removes the promoted proposals from the negative pool below.
          positive_roi[torch.argmax(overlaps, 0)] = True
        positive_indices = torch.where(positive_roi)[0]
        negative_indices = torch.where(~positive_roi)[0]
        # Cap the positives, then size the negatives to keep the configured positive ratio.
        positive_count = int(self.config.roi_num_proposals_training
                             * self.config.roi_positive_ratio_training)
        positive_indices = _random_subsample_idxs(positive_indices, positive_count)
        positive_count = positive_indices.size(0)
        negative_count = int(
            positive_count / self.config.roi_positive_ratio_training) - positive_count
        negative_indices = _random_subsample_idxs(negative_indices, negative_count)
        negative_count = negative_indices.size(0)

        positive_rois = torch.index_select(proposals, 0, positive_indices)
        negative_rois = torch.index_select(proposals, 0, negative_indices)

        # Assign each positive ROI to the ground-truth box it overlaps most.
        positive_overlaps = torch.index_select(overlaps, 0, positive_indices)
        roi_gt_box_assignment = (
            positive_overlaps.argmax(1)
            if positive_count else torch.empty(0).to(positive_overlaps).long())
        b_roi_gt_box_assignment.append(roi_gt_box_assignment)

        roi_gt_boxes = torch.index_select(gt_boxes, 0, roi_gt_box_assignment)
        roi_gt_classes = torch.index_select(gt_classes, 0, roi_gt_box_assignment).long()
        deltas = detection_utils.get_bbox_target(
            positive_rois, roi_gt_boxes, self.config.rpn_bbox_std)
        if b_rotations is not None:
          # NOTE(review): b_rotations[i] is indexed with positions in the zero-filtered
          # proposals - this assumes no all-zero rows were dropped above; confirm.
          rotation = torch.cat((b_rotations[i][positive_indices], b_rotations[i][negative_indices]))
          if self.config.normalize_rotation2:
            rotation = rotation / 2 + np.pi / 2
          b_rots.append(rotation)
          b_rot_deltas.append(
              detection_utils.normalize_rotation(
                  torch.from_numpy(b_gt_rotations[i]).to(proposals)[roi_gt_box_assignment]
                  - b_rotations[i][positive_indices]))

        b_deltas.append(deltas)
        # Shift classes by one so that label 0 is reserved for the negative ROIs appended here.
        roi_gt_classes = torch.cat(
            (roi_gt_classes + 1, torch.zeros(negative_count).to(roi_gt_classes)))
        b_roi_gt_classes.append(roi_gt_classes)

        rois = torch.cat((positive_rois, negative_rois))
        b_rois.append(rois)
    return b_rois, b_rots, b_roi_gt_classes, b_deltas, b_rot_deltas, b_roi_gt_box_assignment