def encode_centroid_box(self, targets, anchors):
    """Encode ``targets`` relative to ``anchors`` and scale by ``self.weights``.

    The encoding is delegated to ``second_box_encode`` (called with
    ``smooth_dim=self.smooth_dim``).  The last column of the result is then
    passed through ``limit_period(..., 0.5, math.pi)`` and written back in
    place before the per-component weights are applied.

    Args:
        targets: boxes to encode; forwarded unchanged to ``second_box_encode``.
        anchors: reference boxes; forwarded unchanged to ``second_box_encode``.

    Returns:
        The encoded boxes multiplied by ``self.weights`` (moved to the
        encoding's device first).
    """
    encoded = second_box_encode(targets, anchors, smooth_dim=self.smooth_dim)

    # yaw diff in [-pi/2, pi/2] (per the original author's note; the exact
    # range depends on limit_period, which is defined elsewhere).
    wrapped_last = limit_period(encoded[:, -1], 0.5, math.pi)
    encoded[:, -1] = wrapped_last

    scale = self.weights.to(encoded.device)
    return encoded * scale
def encode_torch(self, boxes, anchors):
    """Encode ``boxes`` against ``anchors`` via ``box_torch_ops.second_box_encode``.

    ``self.vec_encode`` and ``self.linear_dim`` are forwarded positionally as
    the third and fourth arguments of the underlying encoder.

    Returns:
        Whatever ``box_torch_ops.second_box_encode`` returns for these inputs.
    """
    encode = box_torch_ops.second_box_encode
    return encode(boxes, anchors, self.vec_encode, self.linear_dim)
def _match_refine_targets(coarse_boxes, gt_boxes, gt_classes, box_code_size,
                          matched_threshold, unmatched_threshold):
    """Assign a label and a regression target to each coarse box of one sample.

    Args:
        coarse_boxes: decoded coarse boxes used as anchors for the refine
            stage (assumed ``[num_boxes, box_code_size]`` -- TODO confirm).
        gt_boxes: ground-truth boxes of this sample.
        gt_classes: class id per ground-truth box; values ``> 0`` are treated
            as foreground by the ``labels > 0`` test below.
        box_code_size: width of the regression target.
        matched_threshold: overlap at or above which a box becomes foreground.
        unmatched_threshold: overlap below which a box becomes background.

    Returns:
        ``(labels, bbox_targets)``.  ``labels`` is int64 with ``-1`` for
        "don't care", ``0`` for background and the ground-truth class for
        foreground.  ``bbox_targets`` is zero except on foreground rows.
    """
    num_boxes = coarse_boxes.shape[0]
    device = coarse_boxes.device
    labels = torch.full((num_boxes, ), -1, dtype=torch.int64, device=device)
    bbox_targets = torch.zeros((num_boxes, box_code_size),
                               dtype=coarse_boxes.dtype,
                               device=device)

    # Single guard for "nothing to match".  The original tested two different
    # tensors in different places, which left the matching variables
    # undefined (or stale from the previous sample) when ground truth existed
    # but no anchor was valid.
    if len(gt_boxes) == 0 or num_boxes == 0:
        labels[:] = 0
        return labels, bbox_targets

    # Overlap matrix between coarse boxes (rows) and gt boxes (columns).
    overlaps = similarity_fn_torch(coarse_boxes, gt_boxes)

    # For each box: best-overlapping gt and the overlap amount.
    anchor_to_gt_argmax = overlaps.argmax(dim=1).type_as(labels)
    row_index = torch.arange(num_boxes, device=overlaps.device)
    anchor_to_gt_max = overlaps[row_index, anchor_to_gt_argmax]

    # For each gt: the best overlap achieved by any box.
    gt_to_anchor_max = overlaps.max(dim=0)[0]
    # A gt that overlaps nothing must not force a match: -1 can never equal
    # a real overlap value, so its mask column below stays empty.
    gt_to_anchor_max[gt_to_anchor_max == 0] = -1

    # All boxes that share a gt's maximum overlap (ties included).  The
    # original used argmax over the mask, which returns row 0 for an
    # all-False column (wrongly forcing box 0 to foreground) and drops ties;
    # nonzero mirrors the reference ``np.where(...)[0]``.
    tie_mask = torch.eq(overlaps, gt_to_anchor_max)
    forced = torch.nonzero(tie_mask)[:, 0]
    gt_inds_force = anchor_to_gt_argmax[forced]
    labels[forced] = gt_classes[gt_inds_force]

    # Foreground: overlap at or above the matched threshold.
    pos_inds = anchor_to_gt_max >= matched_threshold
    labels[pos_inds] = gt_classes[anchor_to_gt_argmax[pos_inds]]

    # Captured before background zeroing, exactly as in the original flow.
    fg_inds = labels > 0

    # Background: overlap below the unmatched threshold ...
    labels[anchor_to_gt_max < unmatched_threshold] = 0
    # ... but forced matches stay foreground even with a low overlap.
    labels[forced] = gt_classes[gt_inds_force]

    bbox_targets[fg_inds, :] = box_torch_ops.second_box_encode(
        gt_boxes[anchor_to_gt_argmax[fg_inds], :], coarse_boxes[fg_inds, :])
    return labels, bbox_targets


def create_refine_loss_V2(loc_loss_ftor,
                          cls_loss_ftor,
                          example,
                          coarse_box_batch_preds,
                          coarse_cls_batch_preds,
                          refine_box_batch_preds,
                          refine_cls_batch_preds,
                          num_class,
                          loss_norm_type,
                          encode_background_as_zeros=True,
                          encode_rad_error_by_sin=True,
                          box_code_size=7,
                          matched_threshold=0.6,
                          unmatched_threshold=0.45):
    """Compute refine-stage losses with targets re-assigned on the coarse boxes.

    For every sample the coarse predictions are decoded against the prior
    anchors; the decoded boxes then act as anchors and are matched to the
    ground truth to build refine-stage classification and regression targets.

    Args:
        loc_loss_ftor: callable ``(preds, targets, weights=...)`` for the
            localization loss.
        cls_loss_ftor: callable ``(preds, one_hot_targets, weights=...)`` for
            the classification loss.
        example: dict providing ``'anchors'``, ``'gt_boxes'`` (column 0 is
            compared against the sample index, the remaining columns are the
            box), ``'gt_classes'`` and ``'anchors_mask'``.
        coarse_box_batch_preds: coarse box regression output; viewed as
            ``[batch, -1, box_code_size]``.
        coarse_cls_batch_preds: unused here; kept for interface compatibility.
        refine_box_batch_preds: refine box regression output; viewed as
            ``[batch, -1, box_code_size]``.
        refine_cls_batch_preds: refine classification output.
        num_class: number of foreground classes.
        loss_norm_type: forwarded to ``prepare_loss_weights``.
        encode_background_as_zeros: if true the background column is dropped
            from the one-hot targets and the class predictions are viewed
            with ``num_class`` channels, otherwise ``num_class + 1``.
        encode_rad_error_by_sin: if true, predictions and targets go through
            ``add_sin_difference``; otherwise they are used directly.
        box_code_size: size of one encoded box.
        matched_threshold: overlap at or above which a box is foreground.
        unmatched_threshold: overlap below which a box is background.

    Returns:
        ``(refine_loc_losses, refine_cls_losses)`` as returned by the two
        loss functors.
    """
    prior_anchors = example['anchors']
    batch_size = prior_anchors.shape[0]
    gt_batch_boxes = example['gt_boxes']
    gt_batch_classes = example['gt_classes']
    anchors_batch_mask = example['anchors_mask']

    coarse_box_batch_preds = coarse_box_batch_preds.view(
        batch_size, -1, box_code_size)
    refine_box_batch_preds = refine_box_batch_preds.view(
        batch_size, -1, box_code_size)
    anchor_batch = prior_anchors.view(batch_size, -1, box_code_size)

    # Allocate on the predictions' device instead of a hard-coded .cuda(),
    # so the function also runs on CPU and on non-default GPUs.
    device = refine_box_batch_preds.device
    batch_bbox_targets = torch.zeros(prior_anchors.shape,
                                     dtype=torch.float32,
                                     device=device)
    # -1 marks anchors that never receive a label (masked out / don't care).
    batch_labels = torch.full(prior_anchors.shape[:2],
                              -1,
                              dtype=torch.int64,
                              device=device)

    for i in range(batch_size):
        # Decoded coarse predictions become the anchors of the refine stage.
        decoded_coarse = box_torch_ops.second_box_decode(
            coarse_box_batch_preds[i], anchor_batch[i])

        sample_gt_mask = gt_batch_boxes[:, 0] == i
        gt_boxes = gt_batch_boxes[sample_gt_mask, 1:]
        gt_classes = gt_batch_classes[sample_gt_mask]

        anchors_mask = anchors_batch_mask[i]
        valid_anchors = torch.arange(
            len(anchors_mask), device=anchors_mask.device)[anchors_mask]

        labels, bbox_targets = _match_refine_targets(
            decoded_coarse[valid_anchors, :], gt_boxes, gt_classes,
            box_code_size, matched_threshold, unmatched_threshold)

        batch_bbox_targets[i, valid_anchors, :] = bbox_targets
        batch_labels[i, valid_anchors] = labels

    cls_weights, reg_weights, cared = prepare_loss_weights(
        batch_labels,
        pos_cls_weight=1.0,
        neg_cls_weight=1.0,
        loss_norm_type=loss_norm_type,
        dtype=torch.float32)
    # Zero out the labels that `cared` marks as ignored.
    cls_targets = batch_labels * cared.type_as(batch_labels)

    conf_channels = num_class if encode_background_as_zeros else num_class + 1
    refine_conf = refine_cls_batch_preds.view(batch_size, -1, conf_channels)

    one_hot_targets = torchplus.nn.one_hot(
        cls_targets, depth=num_class + 1, dtype=refine_box_batch_preds.dtype)
    if encode_background_as_zeros:
        one_hot_targets = one_hot_targets[..., 1:]

    if encode_rad_error_by_sin:
        # sin(a - b) = sinacosb-cosasinb
        box_preds, reg_targets = add_sin_difference(refine_box_batch_preds,
                                                    batch_bbox_targets)
    else:
        # The original left these names undefined on this path (NameError).
        box_preds, reg_targets = refine_box_batch_preds, batch_bbox_targets

    refine_loc_losses = loc_loss_ftor(
        box_preds, reg_targets, weights=reg_weights)
    refine_cls_losses = cls_loss_ftor(
        refine_conf, one_hot_targets, weights=cls_weights)
    return refine_loc_losses, refine_cls_losses
def create_refine_loss(
        loc_loss_ftor,
        cls_loss_ftor,
        example,
        coarse_box_preds,
        coarse_cls_preds,
        refine_box_preds,
        refine_cls_preds,
        cls_targets,  # [B,H*W,1]
        cls_weights,  # [B,H*W]
        reg_targets,  # [B,H*W, 7]
        reg_weights,  # [B,H*W]
        num_class,
        encode_background_as_zeros=True,
        encode_rad_error_by_sin=True,
        box_code_size=7,
        reg_weights_ori=None):
    """Compute refine-stage losses by re-encoding the targets on coarse boxes.

    The anchor-relative ``reg_targets`` are decoded back to boxes, the coarse
    predictions are decoded against the same anchors, and the ground truth is
    then re-encoded relative to the decoded coarse boxes so that the refine
    head regresses the residual left by the coarse head.

    Args:
        loc_loss_ftor: callable ``(preds, targets, weights=...)`` for the
            localization loss.
        cls_loss_ftor: callable ``(preds, one_hot_targets, weights=...)`` for
            the classification loss.
        example: dict providing ``'anchors'``.
        coarse_box_preds: coarse box regression output; viewed as
            ``[batch, -1, box_code_size]``.
        coarse_cls_preds: unused here; kept for interface compatibility.
        refine_box_preds: refine box regression output; viewed as
            ``[batch, -1, box_code_size]``.
        refine_cls_preds: refine classification output.
        cls_targets: class targets; a trailing singleton dim is squeezed.
        cls_weights: weights forwarded to ``cls_loss_ftor``.
        reg_targets: regression targets encoded relative to the prior anchors.
        reg_weights: weights forwarded to ``loc_loss_ftor``.
        num_class: number of foreground classes.
        encode_background_as_zeros: if true the background column is dropped
            from the one-hot targets and the class predictions are viewed
            with ``num_class`` channels, otherwise ``num_class + 1``.
        encode_rad_error_by_sin: if true, predictions and targets go through
            ``add_sin_difference``; otherwise they are used directly.
        box_code_size: size of one encoded box.
        reg_weights_ori: unused here; kept for interface compatibility.

    Returns:
        ``(refine_loc_losses, refine_cls_losses)`` as returned by the two
        loss functors.
    """
    batch_size = example['anchors'].shape[0]
    anchors = example["anchors"].view(batch_size, -1, box_code_size)
    coarse_box_preds = coarse_box_preds.view(batch_size, -1, box_code_size)
    refine_box_preds = refine_box_preds.view(batch_size, -1, box_code_size)

    # Decode the coarse predictions and the ground truth against the prior
    # anchors, then re-encode the ground truth relative to the coarse boxes.
    de_coarse_boxes = box_torch_ops.second_box_decode(coarse_box_preds,
                                                      anchors)
    de_gt_boxes = box_torch_ops.second_box_decode(reg_targets, anchors)
    new_gt = box_torch_ops.second_box_encode(de_gt_boxes, de_coarse_boxes)

    conf_channels = num_class if encode_background_as_zeros else num_class + 1
    refine_conf = refine_cls_preds.view(batch_size, -1, conf_channels)

    cls_targets = cls_targets.squeeze(-1)
    one_hot_targets = torchplus.nn.one_hot(
        cls_targets, depth=num_class + 1, dtype=refine_box_preds.dtype)
    if encode_background_as_zeros:
        one_hot_targets = one_hot_targets[..., 1:]

    if encode_rad_error_by_sin:
        # sin(a - b) = sinacosb-cosasinb
        box_preds, reg_targets = add_sin_difference(refine_box_preds, new_gt)
    else:
        # The original left `box_preds` undefined on this path (NameError)
        # and would have kept the anchor-relative `reg_targets` instead of
        # the coarse-box-relative ones.
        box_preds, reg_targets = refine_box_preds, new_gt

    refine_loc_losses = loc_loss_ftor(
        box_preds, reg_targets, weights=reg_weights)  # [N, M]
    refine_cls_losses = cls_loss_ftor(
        refine_conf, one_hot_targets, weights=cls_weights)  # [N, M]

    # NOTE: experimental IoU and corner losses existed here only as
    # commented-out code and were never part of the returned value.
    return refine_loc_losses, refine_cls_losses