def loss_refinement_head(self, datum, output):
    """Compute the losses of the refinement (second-stage) head.

    Args:
      datum: input batch (not read by this method).
      output: dict of model outputs. Reads ``ref_class_logits`` with
        ``ref_target_class_ids`` (classification), ``ref_bbox`` with ``ref_target_bbox``
        (box regression) and ``ref_rotdelta`` with ``ref_target_rots`` (rotation). Each of
        ``ref_class_logits`` / ``ref_bbox`` / ``ref_rotdelta`` may be missing or None.

    Returns:
      ``(loss, loss_details)``. ``loss`` is the weighted sum of the refinement losses; it
      stays zero unless more than ``train_ref_min_sample_per_batch * batch_size`` RoIs were
      classified. ``loss_details`` holds the unweighted terms plus an ``optimize_ref`` 0/1
      flag; the rotation entries are present only when ``ref_rotdelta`` is output.
      When ``self.train_rpnonly`` is set, a zero loss and an empty dict are returned.
    """
    def _zero():
        return torch.scalar_tensor(0.).to(self.device)

    if self.train_rpnonly:
        return _zero(), dict()

    # Classification loss over all sampled RoIs.
    optimize_ref = False
    ref_labels = None
    if output.get('ref_class_logits') is not None:
        # Only optimize the head when enough RoIs were sampled for this batch.
        if output['ref_class_logits'].shape[0] \
                > self.config.train_ref_min_sample_per_batch * self.config.batch_size:
            optimize_ref = True
        ref_labels = torch.cat(output['ref_target_class_ids'])
        ref_class_loss = self.ref_class_criterion(output['ref_class_logits'], ref_labels)
    else:
        ref_class_loss = _zero()

    # Regression terms use foreground RoIs only (label > 0); ``label - 1`` selects the
    # class-specific output. Both regression terms need the labels, so they fall back to
    # zero when no class logits / targets were produced.
    if ref_labels is not None:
        positive_rois = torch.where(ref_labels > 0)[0]
        bbox_idx = ref_labels[positive_rois] - 1

    if output.get('ref_bbox') is not None and ref_labels is not None:
        ref_bbox_pred = output['ref_bbox'][positive_rois, bbox_idx]
        ref_bbox_gt = torch.cat(output['ref_target_bbox'])
        ref_bbox_loss = self.ref_bbox_criterion(ref_bbox_pred, ref_bbox_gt)
    else:
        ref_bbox_loss = _zero()

    ref_rotation_diff = _zero()
    train_rotation = output.get('ref_rotdelta') is not None
    if train_rotation and ref_labels is not None:
        ref_rotation_pred = output['ref_rotdelta'][positive_rois, bbox_idx]
        ref_rotation_gt = torch.cat(output['ref_target_rots'])
        ref_rotation_loss = self.ref_rotation_criterion(ref_rotation_pred, ref_rotation_gt)
        ref_rotation_pred = self.ref_rotation_criterion.pred(ref_rotation_pred)
        # Diagnostic only: mean absolute rotation error after normalize_rotation.
        ref_rotation_diff = torch.abs(
            detection_utils.normalize_rotation(ref_rotation_pred - ref_rotation_gt)).mean()
    else:
        ref_rotation_loss = _zero()

    loss = _zero()
    if optimize_ref:
        loss += self.config.ref_class_weight * ref_class_loss \
            + self.config.ref_bbox_weight * ref_bbox_loss
        if train_rotation:
            loss += self.config.ref_rotation_weight * ref_rotation_loss

    loss_details = {
        'ref_class_loss': ref_class_loss,
        'ref_bbox_loss': ref_bbox_loss,
        'optimize_ref': torch.tensor([optimize_ref], dtype=torch.int).to(loss)[0],
    }
    if train_rotation:
        loss_details['ref_rotation_loss'] = ref_rotation_loss
        loss_details['ref_rotation_diff'] = ref_rotation_diff
    return loss, loss_details
def loss_region_proposal(self, datum, output):
    """Compute the dense RPN losses: anchor classification, box regression and rotation.

    NOTE(review): another ``loss_region_proposal`` (sparse-anchor variant) is defined later
    in this file; if both are in the same class, the later definition shadows this one.

    Args:
      datum: input batch. Reads ``rpn_match`` (per-anchor label: 1 = positive, any other
        non-zero value = negative, 0 = not used), ``rpn_bbox`` (box targets; all-zero rows
        are dropped) and, when rotation is predicted, ``rpn_rotation``.
      output: model outputs. Reads ``rpn_class_logits`` (reshaped to 2 logits per anchor),
        ``rpn_bbox`` (reshaped to 6 values per anchor) and optionally ``rpn_rotation``.

    Returns:
      ``(loss, loss_details)``: the weighted total loss and a dict of the unweighted terms;
      ``rpn_rotation_loss`` / ``rpn_rotation_diff`` appear only when ``rpn_rotation`` is
      output.
    """
    # Compute RPN class loss over every anchor with a non-zero match label.
    rpn_match_gt = datum['rpn_match'].flatten()
    # squeeze(-1) keeps a 1-D index tensor even when exactly one anchor is labelled; a bare
    # squeeze() would give a 0-d tensor on which .size(0) raises.
    rpn_match_mask = rpn_match_gt.nonzero().squeeze(-1)
    if rpn_match_mask.size(0) > 0:
        rpn_match_gt_valid = (rpn_match_gt[rpn_match_mask] == 1).long()
        rpn_class_logits = output['rpn_class_logits'].reshape(-1, 2)[rpn_match_mask]
        rpn_class_loss = self.rpn_class_criterion(rpn_class_logits, rpn_match_gt_valid)
    else:
        rpn_class_loss = torch.scalar_tensor(0.).to(self.device)

    # Compute RPN bbox regression loss over positive anchors only.
    # NOTE(review): assumes the non-zero rows of datum['rpn_bbox'] line up one-to-one, in
    # order, with the positive anchors -- confirm against the data loader.
    rpn_bbox_mask = torch.where(rpn_match_gt == 1)[0]
    if rpn_bbox_mask.size(0) > 0:
        rpn_bbox_pred = output['rpn_bbox'].reshape(-1, 6)[rpn_bbox_mask]
        rpn_bbox_gt = datum['rpn_bbox'].reshape(-1, 6)
        rpn_gt_mask = ~torch.all(rpn_bbox_gt == 0, 1)
        rpn_bbox_gt = rpn_bbox_gt[rpn_gt_mask]
        rpn_bbox_loss = self.rpn_bbox_criterion(rpn_bbox_pred, rpn_bbox_gt)
    else:
        rpn_bbox_loss = torch.scalar_tensor(0.).to(self.device)

    loss = self.config.rpn_class_weight * rpn_class_loss \
        + self.config.rpn_bbox_weight * rpn_bbox_loss

    # Compute RPN bbox rotation loss (same positive anchors / target rows as the box loss).
    rpn_rotation_diff = torch.scalar_tensor(0.).to(self.device)
    if output.get('rpn_rotation') is not None:
        if rpn_bbox_mask.size(0) > 0:
            rpn_rotation_pred = output['rpn_rotation'].reshape(
                -1, self.rpn_rotation_criterion.NUM_OUTPUT)[rpn_bbox_mask]
            rpn_rotation_gt = datum['rpn_rotation'].flatten()[rpn_gt_mask]
            rpn_rotation_loss = self.rpn_rotation_criterion(rpn_rotation_pred, rpn_rotation_gt)
            rpn_rotation_pred = self.rpn_rotation_criterion.pred(rpn_rotation_pred)
            # Diagnostic only: mean absolute rotation error after normalize_rotation.
            rpn_rotation_diff = torch.abs(
                detection_utils.normalize_rotation(rpn_rotation_pred - rpn_rotation_gt)).mean()
        else:
            rpn_rotation_loss = torch.scalar_tensor(0.).to(self.device)
        loss += self.config.rpn_rotation_weight * rpn_rotation_loss

    loss_details = {
        'rpn_class_loss': rpn_class_loss,
        'rpn_bbox_loss': rpn_bbox_loss,
    }
    if output.get('rpn_rotation') is not None:
        loss_details['rpn_rotation_loss'] = rpn_rotation_loss
        loss_details['rpn_rotation_diff'] = rpn_rotation_diff
    return loss, loss_details
def loss_region_proposal(self, datum, output):
    """Compute the sparse RPN losses (class, box, optional rotation) and the optional sparse
    FPN classification loss.

    Args:
      datum: input batch. Reads ``sparse_anchor_coords``, ``sparse_rpn_match``,
        ``sparse_rpn_bbox`` and, when rotation is trained, ``sparse_rpn_rotation``; each is
        iterated in lockstep with the per-entry RPN outputs (presumably one entry per
        pyramid level -- confirm against the data loader).
      output: model outputs. Reads ``fpn_targets`` / ``fpn_classification`` and the lists
        ``rpn_class_logits``, ``rpn_bbox``, ``rpn_rotation`` (entries may be None). The
        entries are used through ``.F`` / ``.coords_key`` (presumably MinkowskiEngine
        sparse tensors -- confirm).

    Returns:
      ``(loss, loss_details)``: the weighted total loss and a dict of the unweighted terms
      (``sfpn_class_loss`` only when ``fpn_targets`` is not None; rotation terms only when
      some ``rpn_rotation`` entry is not None).

    Side effects:
      Stores the per-entry coordinate maps in ``output['rpn2anchor_maps']`` when that key is
      absent or None, so repeated calls on the same ``output`` reuse them.
    """
    # Sparse FPN classification loss, averaged over the backbone's output resolutions.
    if output['fpn_targets'] is not None:
        num_layers = len(self.backbone.OUT_PIXEL_DIST)
        sfpn_class_loss = torch.scalar_tensor(0.).to(self.device)
        for logits, targets in zip(output['fpn_classification'], output['fpn_targets']):
            if targets is None:
                continue
            sfpn_class_loss += self.sfpn_class_criterion(logits.F, targets.to(self.device)) / num_layers
    # Build (and cache) the index maps that align RPN output rows with anchor rows.
    # Entries without RPN output get a (None, None) placeholder to keep the zips aligned.
    if output.get('rpn2anchor_maps') is None:
        output['rpn2anchor_maps'] = []
        for i, (rpn_class_logits, anchor) in enumerate(
                zip(output['rpn_class_logits'], datum['sparse_anchor_coords'])):
            if rpn_class_logits is None:
                output['rpn2anchor_maps'].append((None, None))
            else:
                # The class, bbox and rotation heads must share coordinates so that a single
                # map is valid for all of them.
                assert rpn_class_logits.coords_key == output['rpn_bbox'][i].coords_key
                if output['rpn_rotation'][i] is not None:
                    assert rpn_class_logits.coords_key == output['rpn_rotation'][i].coords_key
                output['rpn2anchor_maps'].append(detection_utils.map_coordinates(
                    rpn_class_logits, anchor, check_input_map=True))
    # Classification: gather logits/labels of every anchor with a non-zero match label
    # (1 -> class 1, any other non-zero value -> class 0). rpn_masks remembers the positive
    # anchors (match == 1) per entry for the regression losses below.
    rpn_masks = []
    rpn_class_preds = []
    rpn_class_gts = []
    for rpn_class_logits, rpn_match, (rpn2anchor, anchor2rpn) in zip(
            output['rpn_class_logits'], datum['sparse_rpn_match'], output['rpn2anchor_maps']):
        if rpn_class_logits is None:
            rpn_masks.append(None)
            continue
        # Reorder the match labels so they line up with rpn_class_logits.F[rpn2anchor].
        rpn_match = rpn_match[anchor2rpn].flatten()
        rpn_masks.append(torch.where(rpn_match == 1)[0])
        rpn_match_mask = rpn_match.nonzero().squeeze(-1)
        if rpn_match_mask.size(0) > 0:
            rpn_class_preds.append(rpn_class_logits.F[rpn2anchor].reshape(-1, 2)[rpn_match_mask])
            rpn_class_gts.append((rpn_match[rpn_match_mask] == 1).long().to(self.device))
    rpn_class_loss = torch.scalar_tensor(0.).to(self.device)
    if rpn_class_preds:
        rpn_class_loss = self.rpn_class_criterion(torch.cat(rpn_class_preds), torch.cat(rpn_class_gts))
    # Box regression over positive anchors only (6 values per anchor).
    rpn_bbox_preds = []
    rpn_bbox_gts = []
    for rpn_bbox_pred, rpn_bbox_gt, rpn_mask, (rpn2anchor, anchor2rpn) in zip(
            output['rpn_bbox'], datum['sparse_rpn_bbox'], rpn_masks, output['rpn2anchor_maps']):
        if rpn_bbox_pred is not None and rpn_mask.size(0) > 0:
            rpn_bbox_preds.append(rpn_bbox_pred.F[rpn2anchor].reshape(-1, 6)[rpn_mask])
            rpn_bbox_gts.append(rpn_bbox_gt[anchor2rpn].reshape(-1, 6)[rpn_mask].to(rpn_bbox_pred.F))
    rpn_bbox_loss = torch.scalar_tensor(0.).to(self.device)
    if rpn_bbox_preds:
        rpn_bbox_loss = self.rpn_bbox_criterion(torch.cat(rpn_bbox_preds), torch.cat(rpn_bbox_gts))
    loss = self.config.rpn_class_weight * rpn_class_loss \
        + self.config.rpn_bbox_weight * rpn_bbox_loss
    if output['fpn_targets'] is not None:
        loss += self.config.sfpn_class_weight * sfpn_class_loss
    # NOTE(review): .get() returns None if 'rpn_rotation' is absent, which would make this
    # iteration raise TypeError; the key is in practice required (it is indexed directly
    # above when the maps are built).
    train_rotation = any([rot is not None for rot in output.get('rpn_rotation')])
    if train_rotation:
        # Rotation regression over the same positive anchors as the box loss.
        rpn_rotation_preds = []
        rpn_rotation_gts = []
        for rpn_rotation_pred, rpn_rotation_gt, rpn_mask, (rpn2anchor, anchor2rpn) in zip(
                output['rpn_rotation'], datum['sparse_rpn_rotation'], rpn_masks,
                output['rpn2anchor_maps']):
            if rpn_rotation_pred is not None and rpn_mask.size(0) > 0:
                rpn_rotation_preds.append(rpn_rotation_pred.F[rpn2anchor].reshape(
                    -1, self.rpn_rotation_criterion.NUM_OUTPUT)[rpn_mask])
                rpn_rotation_gts.append(
                    rpn_rotation_gt[anchor2rpn].flatten()[rpn_mask].to(rpn_rotation_pred.F))
        rpn_rotation_loss = torch.scalar_tensor(0.).to(self.device)
        rpn_rotation_diff = torch.scalar_tensor(0.).to(self.device)
        if rpn_rotation_preds:
            rpn_rotation_preds = torch.cat(rpn_rotation_preds)
            rpn_rotation_gts = torch.cat(rpn_rotation_gts)
            rpn_rotation_loss = self.rpn_rotation_criterion(rpn_rotation_preds, rpn_rotation_gts)
            # Diagnostic only: mean absolute rotation error after normalize_rotation.
            rpn_rotation_preds = self.rpn_rotation_criterion.pred(rpn_rotation_preds)
            rpn_rotation_diff = torch.abs(detection_utils.normalize_rotation(
                rpn_rotation_preds - rpn_rotation_gts)).mean()
        loss += self.config.rpn_rotation_weight * rpn_rotation_loss
    loss_details = {
        'rpn_class_loss': rpn_class_loss,
        'rpn_bbox_loss': rpn_bbox_loss,
    }
    if output['fpn_targets'] is not None:
        loss_details['sfpn_class_loss'] = sfpn_class_loss
    if train_rotation:
        loss_details['rpn_rotation_loss'] = rpn_rotation_loss
        loss_details['rpn_rotation_diff'] = rpn_rotation_diff
    return loss, loss_details
def detection_refinement(self, b_probs, b_rois, b_deltas, b_rots, b_rotdeltas):
    """Refine RoIs with the refinement-head outputs and run per-class NMS.

    Args:
      b_probs: class probabilities of all RoIs in the batch, concatenated along dim 0
        (class 0 is background), or None when there is nothing to refine.
      b_rois: list with one RoI tensor per batch element.
      b_deltas: per-class box deltas of all RoIs, concatenated along dim 0.
      b_rots: list with one rotation tensor per batch element, or None when rotations are
        not refined.
      b_rotdeltas: per-class rotation outputs of all RoIs, concatenated along dim 0
        (ignored when ``b_rots`` is None).

    Returns:
      A list of numpy arrays, one per batch element. Each row is
      ``[6 box values, class_id, score]`` (8 columns) or, when ``b_rots`` is given,
      ``[6 box values, rotation, class_id, score]`` (9 columns). When ``b_probs`` is None
      a single empty array with that column count is returned.
    """
    # 6 box values + class + score, plus a rotation column when rotations are refined.
    num_channel = 8 if b_rots is None else 9
    if b_probs is None:
        return [np.zeros((0, num_channel))]

    num_batch = [rois.shape[0] for rois in b_rois]
    num_samples = sum(num_batch)
    assert num_samples == b_probs.shape[0] == b_deltas.shape[0]
    if b_rots is not None:
        assert num_samples == sum(rots.shape[0] for rots in b_rots) == b_rotdeltas.shape[0]

    # Split the concatenated per-RoI outputs back into per-batch-element chunks.
    b_probs = torch.split(b_probs, num_batch)
    b_deltas = torch.split(b_deltas, num_batch)
    if b_rots is not None:
        b_rotdeltas = torch.split(b_rotdeltas, num_batch)

    b_nms = []
    for i, (probs, rois, deltas) in enumerate(zip(b_probs, b_rois, b_deltas)):
        rois = rois.reshape(-1, rois.shape[-1])
        class_ids = torch.argmax(probs, dim=1)
        batch_slice = range(probs.shape[0])
        class_scores = probs[batch_slice, class_ids]
        # Foreground class k uses delta slot k - 1; background rows (k = 0) pick an arbitrary
        # slot but are dropped by the keep mask below.
        class_deltas = deltas[batch_slice, class_ids - 1]
        class_deltas *= torch.tensor(self.config.rpn_bbox_std).to(deltas)
        refined_rois = detection_utils.apply_box_deltas(rois, class_deltas, self.config.normalize_bbox)
        if b_rots is not None:
            class_rot_deltas = b_rotdeltas[i][batch_slice, class_ids - 1]
            class_rot_deltas = self.ref_rotation_criterion.pred(class_rot_deltas)
            refined_rots = detection_utils.normalize_rotation(b_rots[i] + class_rot_deltas)

        # Keep foreground RoIs, optionally only those above the confidence threshold.
        keep_mask = class_ids > 0
        if self.config.detection_min_confidence:
            keep_mask &= class_scores > self.config.detection_min_confidence
        keep = torch.where(keep_mask)[0]
        if keep.numel() == 0:
            b_nms.append(np.zeros((0, num_channel)))
            continue

        pre_nms_class_ids = class_ids[keep] - 1
        pre_nms_scores = class_scores[keep]
        pre_nms_rois = refined_rois[keep]
        pre_nms_rots = refined_rots[keep] if b_rots is not None else None

        # Run NMS independently for each predicted class.
        nms_scores, nms_rois, nms_classes, nms_rots = [], [], [], []
        for class_id in torch.unique(pre_nms_class_ids):
            class_nms_mask = pre_nms_class_ids == class_id
            class_nms_scores = pre_nms_scores[class_nms_mask]
            class_nms_rois = pre_nms_rois[class_nms_mask]
            pre_nms_class_rots = None if pre_nms_rots is None else pre_nms_rots[class_nms_mask]
            nms_roi, nms_rot, nms_score = detection_utils.non_maximum_suppression(
                class_nms_rois, pre_nms_class_rots, class_nms_scores,
                self.config.detection_nms_threshold, self.config.detection_max_instances,
                self.config.detection_rot_nms, self.config.detection_aggregate_overlap)
            nms_rois.append(nms_roi)
            nms_scores.append(nms_score)
            nms_classes.append(torch.ones(len(nms_score)).to(class_nms_rois) * class_id)
            if b_rots is not None:
                if self.config.normalize_rotation2:
                    nms_rot = nms_rot / 2 + np.pi / 2
                nms_rots.append(nms_rot)
        nms_scores = torch.cat(nms_scores)
        nms_rois = torch.cat(nms_rois)
        nms_classes = torch.cat(nms_classes)

        # Keep the top-scoring detections across all classes.
        detection_max_instances = min(self.config.detection_max_instances, nms_scores.shape[0])
        ix = torch.topk(nms_scores, detection_max_instances)[1]
        nms_rois_unnorm = detection_utils.unnormalize_boxes(
            nms_rois[ix].cpu().numpy(), self.config.max_ptc_size)
        nms_bboxes = np.hstack((nms_rois_unnorm,
                                nms_classes[ix, None].cpu().numpy(),
                                nms_scores[ix, None].cpu().numpy()))
        if b_rots is not None:
            # Insert the rotation column right after the 6 box columns.
            nms_rots = torch.cat(nms_rots)[ix, None].cpu().numpy()
            nms_bboxes = np.hstack((nms_bboxes[:, :6], nms_rots, nms_bboxes[:, 6:]))
        b_nms.append(nms_bboxes)
    return b_nms
def _get_detection_target(self, b_proposals, b_rotations, b_gt_classes, b_gt_boxes, b_gt_rotations): def _random_subsample_idxs(indices, num_samples): if indices.size(0) > num_samples: return torch.from_numpy(np.random.choice(indices.cpu().numpy(), num_samples, replace=False)).to(indices) return indices b_rois, b_roi_gt_classes, b_deltas, b_roi_gt_box_assignment = [], [], [], [] b_rots, b_rot_deltas = (None, None) if b_rotations is None else ([], []) for i, (proposals, gt_classes, gt_boxes) in enumerate( zip(b_proposals, b_gt_classes, b_gt_boxes)): with torch.no_grad(): proposals = proposals[~torch.all(proposals == 0, 1)] gt_boxes = torch.from_numpy(gt_boxes).to(proposals) if gt_boxes.shape[0] == 0 or proposals.shape[0] == 0: b_deltas.append(torch.zeros((gt_boxes.shape[0], proposals.shape[1])).to(proposals)) b_roi_gt_classes.append(torch.zeros(0).to(proposals).long()) b_rois.append(torch.zeros((gt_boxes.shape[0], proposals.shape[1])).to(proposals)) if b_rotations is not None: b_rots.append(torch.zeros(0).to(proposals)) b_rot_deltas.append(torch.zeros(0).to(proposals)) continue gt_classes = torch.from_numpy(gt_classes).to(proposals) pred_rotation, gt_rotation = None, None if self.config.ref_rotation_overlap and b_rotations is not None: pred_rotation = b_rotations[i] gt_rotation = b_gt_rotations[i] overlaps = detection_utils.compute_overlaps(proposals, gt_boxes, pred_rotation, gt_rotation) roi_iou_max = overlaps.max(1)[0] positive_roi = roi_iou_max > self.config.detection_match_positive_iou_threshold positive_indices = torch.where(positive_roi)[0] if self.config.force_proposal_match: torch.unique(torch.cat((torch.argmax(overlaps, 0), positive_indices))) negative_indices = torch.where(~positive_roi)[0] positive_count = int(self.config.roi_num_proposals_training * self.config.roi_positive_ratio_training) positive_indices = _random_subsample_idxs(positive_indices, positive_count) positive_count = positive_indices.size(0) negative_count = int( positive_count / 
self.config.roi_positive_ratio_training) - positive_count negative_indices = _random_subsample_idxs(negative_indices, negative_count) negative_count = negative_indices.size(0) positive_rois = torch.index_select(proposals, 0, positive_indices) negative_rois = torch.index_select(proposals, 0, negative_indices) positive_overlaps = torch.index_select(overlaps, 0, positive_indices) roi_gt_box_assignment = ( positive_overlaps.argmax(1) if positive_count else torch.empty(0).to(positive_overlaps).long()) b_roi_gt_box_assignment.append(roi_gt_box_assignment) roi_gt_boxes = torch.index_select(gt_boxes, 0, roi_gt_box_assignment) roi_gt_classes = torch.index_select(gt_classes, 0, roi_gt_box_assignment).long() deltas = detection_utils.get_bbox_target( positive_rois, roi_gt_boxes, self.config.rpn_bbox_std) if b_rotations is not None: rotation = torch.cat((b_rotations[i][positive_indices], b_rotations[i][negative_indices])) if self.config.normalize_rotation2: rotation = rotation / 2 + np.pi / 2 b_rots.append(rotation) b_rot_deltas.append( detection_utils.normalize_rotation( torch.from_numpy(b_gt_rotations[i]).to(proposals)[roi_gt_box_assignment] - b_rotations[i][positive_indices])) b_deltas.append(deltas) roi_gt_classes = torch.cat( (roi_gt_classes + 1, torch.zeros(negative_count).to(roi_gt_classes))) b_roi_gt_classes.append(roi_gt_classes) rois = torch.cat((positive_rois, negative_rois)) b_rois.append(rois) return b_rois, b_rots, b_roi_gt_classes, b_deltas, b_rot_deltas, b_roi_gt_box_assignment