Example #1
    def decode(self, box_encodings, anchors):
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Arguments:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.
        """
        assert box_encodings.shape[0] == anchors.shape[0]
        assert anchors.shape[1] == 7
        num_classes = box_encodings.shape[1] // 7
        if num_classes != 1:
            num_loc = box_encodings.shape[0]
            box_encodings = box_encodings.view(-1, 7)
            anchors = anchors.view(num_loc, 1, 7)
            anchors = anchors.repeat(1, num_classes, 1).view(-1, 7)

        box_encodings = box_encodings / self.weights.to(box_encodings.device)
        box_encodings[:,3:6] = torch.clamp(box_encodings[:,3:6], max=self.bbox_xform_clip)
        boxes_decoded = second_box_decode(box_encodings, anchors, smooth_dim=self.smooth_dim)
        # wrap decoded yaw into [-pi/2, pi/2)
        boxes_decoded[:, -1] = limit_period(boxes_decoded[:, -1], 0.5, math.pi)

        if num_classes != 1:
            boxes_decoded = boxes_decoded.view(-1, num_classes * 7)

        return boxes_decoded
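The decoder above leaves the heavy lifting to two helpers: second_box_decode (sketched after Example #3 below) and limit_period, which wraps the decoded yaw into [-pi/2, pi/2). A minimal sketch of limit_period, assuming the usual second.pytorch-style formulation; the helper actually imported by this code may differ in detail:

import math
import torch

def limit_period(val, offset=0.5, period=math.pi):
    # Wrap values into [-offset * period, (1 - offset) * period).
    # With offset=0.5 and period=pi this maps any yaw into [-pi/2, pi/2).
    return val - torch.floor(val / period + offset) * period

# 3/4*pi and -1/4*pi describe the same orientation modulo pi:
print(limit_period(torch.tensor([0.75 * math.pi])))  # tensor([-0.7854])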
Example #2
def create_refine_loss_V2(loc_loss_ftor,
                          cls_loss_ftor,
                          example,
                          coarse_box_batch_preds,
                          coarse_cls_batch_preds,
                          refine_box_batch_preds,
                          refine_cls_batch_preds,
                          num_class,
                          loss_norm_type,
                          encode_background_as_zeros=True,
                          encode_rad_error_by_sin=True,
                          box_code_size=7):
    """Build refinement targets on top of the decoded coarse boxes and return
    (refine_loc_losses, refine_cls_losses).

    For each sample, the coarse box predictions are decoded against the prior
    anchors, ground-truth boxes are re-matched to the decoded boxes by overlap,
    regression targets are re-encoded relative to the decoded boxes, and the
    refinement-branch localization and classification losses are computed from
    those targets.
    """
    batch_size = example['anchors'].shape[0]
    batch_anchors_shape = example['anchors'].shape
    # coordinates = example['coordinates']
    gt_batch_boxes = example['gt_boxes']
    gt_batch_classes = example['gt_classes']
    anchors_batch_mask = example['anchors_mask']
    coarse_box_batch_preds = coarse_box_batch_preds.view(
        batch_size, -1, box_code_size)
    refine_box_batch_preds = refine_box_batch_preds.view(
        batch_size, -1, box_code_size)
    anchor_batch = example["anchors"].view(batch_size, -1, box_code_size)
    batch_out_True_bbox = torch.zeros(batch_anchors_shape,
                                      dtype=torch.float32).cuda()

    batch_out_True_label = -torch.ones(batch_anchors_shape[:2],
                                       dtype=torch.int64).cuda()
    for i in range(batch_size):

        anchors = anchor_batch[i, :, :]
        coarse_box_preds = coarse_box_batch_preds[i, :, :]
        refine_box_preds = refine_box_batch_preds[i, :, :]

        de_coarse_boxes = box_torch_ops.second_box_decode(
            coarse_box_preds, anchors)
        anchors = de_coarse_boxes

        gt_boxes_mask = gt_batch_boxes[:, 0] == i
        gt_boxes = gt_batch_boxes[gt_boxes_mask, 1:]
        gt_classes = gt_batch_classes[gt_boxes_mask]

        anchors_mask = anchors_batch_mask[i, :]
        valid_anchors = torch.arange(len(anchors_mask))[anchors_mask]

        ############## compute overlap
        num_inside = len(valid_anchors)
        total_anchors = anchors.shape[0]
        coarse_boxes = anchors[valid_anchors, :]

        matched_threshold = 0.6 * torch.ones(num_inside).cuda()
        unmatched_threshold = 0.45 * torch.ones(num_inside).cuda()

        labels = -torch.ones((num_inside, ), dtype=torch.int64).cuda()
        gt_ids = -torch.ones((num_inside, ), dtype=torch.int64).cuda()
        if len(gt_boxes) > 0 and coarse_boxes.shape[0] > 0:
            # Compute overlaps between the decoded coarse boxes and the gt boxes
            anchor_by_gt_overlap = similarity_fn_torch(coarse_boxes, gt_boxes)

            # Map from anchor to gt box that has highest overlap
            anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(
                dim=1).type_as(labels)
            # For each anchor, amount of overlap with most overlapping gt box
            anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_inside),
                                                    anchor_to_gt_argmax]  #

            # Map from gt box to an anchor that has highest overlap
            gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(
                dim=0).type_as(labels)
            # For each gt box, amount of overlap with most overlapping anchor
            gt_to_anchor_max = anchor_by_gt_overlap[
                gt_to_anchor_argmax,
                torch.arange(anchor_by_gt_overlap.shape[1])]

            # must exclude gt boxes that don't match any anchor.
            empty_gt_mask = gt_to_anchor_max == 0
            gt_to_anchor_max[empty_gt_mask] = -1

            # For each gt, pick an anchor that attains its max overlap.
            # Note: unlike the original np.where(anchor_by_gt_overlap ==
            # gt_to_anchor_max)[0] formulation, argmax keeps only the first
            # such anchor per gt, so ties beyond the first are dropped.
            mask = torch.eq(anchor_by_gt_overlap, gt_to_anchor_max)
            # cast to int so argmax works regardless of bool-reduction support
            anchors_with_max_overlap = torch.argmax(mask.int(), dim=0).sort()[0]
            anchors_with_max_overlap = anchors_with_max_overlap.type_as(labels)

            # Fg label: for each gt use anchors with highest overlap
            # (including ties)
            gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]
            labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
            gt_ids[anchors_with_max_overlap] = gt_inds_force
            # Fg label: above threshold IOU
            pos_inds = anchor_to_gt_max >= matched_threshold  ##
            gt_inds = anchor_to_gt_argmax[pos_inds]
            labels[pos_inds] = gt_classes[gt_inds]
            gt_ids[pos_inds] = gt_inds

            # bg_inds = np.where(anchor_to_gt_max < unmatched_threshold)[0]
            bg_inds = anchor_to_gt_max < unmatched_threshold
        else:
            bg_inds = torch.arange(num_inside)

        fg_inds = labels > 0

        # Use the same guard as the matching block above (coarse_boxes rather
        # than the full anchor set) so that anchors_with_max_overlap and
        # gt_inds_force are always defined in the else branch.
        if len(gt_boxes) == 0 or coarse_boxes.shape[0] == 0:
            labels[:] = 0
        else:
            labels[bg_inds] = 0
            # re-enable anchors_with_max_overlap
            labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]

        bbox_targets = torch.zeros((num_inside, box_code_size),
                                   dtype=coarse_boxes.dtype).cuda()

        if len(gt_boxes) > 0 and coarse_boxes.shape[0] > 0:
            bbox_targets[fg_inds, :] = box_torch_ops.second_box_encode(
                gt_boxes[anchor_to_gt_argmax[fg_inds], :],
                coarse_boxes[fg_inds, :])

        bbox_outside_weights = torch.zeros((num_inside, ),
                                           dtype=coarse_boxes.dtype).cuda()
        bbox_outside_weights[labels > 0] = 1.0

        # output: scatter the per-sample targets back into the batch tensors
        batch_out_True_bbox[i, valid_anchors, :] = bbox_targets
        batch_out_True_label[i, valid_anchors] = labels

    cls_weights, reg_weights, cared = prepare_loss_weights(
        batch_out_True_label,
        pos_cls_weight=1.0,  # pos_cls_weight = 1.0
        neg_cls_weight=1.0,  # neg_cls_weight = 1.0
        loss_norm_type=loss_norm_type,  ####################
        dtype=torch.float32)
    cls_targets = batch_out_True_label * cared.type_as(batch_out_True_label)
    cls_targets = cls_targets.unsqueeze(-1)

    if encode_background_as_zeros:
        coarse_conf = coarse_cls_batch_preds.view(batch_size, -1, num_class)
        refine_conf = refine_cls_batch_preds.view(batch_size, -1, num_class)

    else:
        coarse_conf = coarse_cls_batch_preds.view(batch_size, -1,
                                                  num_class + 1)
        refine_conf = refine_cls_batch_preds.view(batch_size, -1,
                                                  num_class + 1)

    cls_targets = cls_targets.squeeze(-1)
    one_hot_targets = torchplus.nn.one_hot(cls_targets,
                                           depth=num_class + 1,
                                           dtype=refine_box_batch_preds.dtype)
    if encode_background_as_zeros:  # True
        one_hot_targets = one_hot_targets[..., 1:]
    if encode_rad_error_by_sin:
        # sin(a - b) = sin(a)cos(b) - cos(a)sin(b)
        box_preds, reg_targets = add_sin_difference(refine_box_batch_preds,
                                                    batch_out_True_bbox)
    else:
        box_preds, reg_targets = refine_box_batch_preds, batch_out_True_bbox
    refine_loc_losses = loc_loss_ftor(
        box_preds, reg_targets, weights=reg_weights)  # [N, M]    # [2,70400,7]

    refine_cls_losses = cls_loss_ftor(refine_conf,
                                      one_hot_targets,
                                      weights=cls_weights)  # [N, M]
    return refine_loc_losses, refine_cls_losses
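create_refine_loss_V2 above (and create_refine_loss in Example #4 below) feeds its box predictions and targets through add_sin_difference before the localization loss, so the yaw term of the loss effectively penalises sin(a - b) = sin(a)cos(b) - cos(a)sin(b) instead of the raw angle difference. A minimal sketch of that helper in the second.pytorch style, assuming yaw is the last of the 7 box parameters; the project's own implementation may differ:

import torch

def add_sin_difference(boxes1, boxes2):
    # Replace the yaw entries so that subtracting the two tensors yields
    # sin(a)cos(b) - cos(a)sin(b) = sin(a - b) in the yaw slot.
    rad_pred_encoding = torch.sin(boxes1[..., -1:]) * torch.cos(boxes2[..., -1:])
    rad_tg_encoding = torch.cos(boxes1[..., -1:]) * torch.sin(boxes2[..., -1:])
    boxes1 = torch.cat([boxes1[..., :-1], rad_pred_encoding], dim=-1)
    boxes2 = torch.cat([boxes2[..., :-1], rad_tg_encoding], dim=-1)
    return boxes1, boxes2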
Example #3
    def decode_torch(self, boxes, anchors):
        return box_torch_ops.second_box_decode(boxes, anchors, self.vec_encode,
                                               self.linear_dim)
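Examples #1 through #4 all call second_box_decode in one form or another; it converts SECOND-style residuals (center offsets normalised by the anchor's ground-plane diagonal, log-scaled sizes, additive yaw) back into absolute 3D boxes. A simplified sketch of that decoder for the default case (no smooth_dim, no vector-encoded angle), written from the SECOND encoding formulas; the real library function supports more options:

import torch

def second_box_decode_simple(box_encodings, anchors):
    # Both tensors are [..., 7] in (x, y, z, w, l, h, yaw) order,
    # with z at the box bottom as in second.pytorch.
    xa, ya, za, wa, la, ha, ra = torch.split(anchors, 1, dim=-1)
    xt, yt, zt, wt, lt, ht, rt = torch.split(box_encodings, 1, dim=-1)
    za = za + ha / 2                        # bottom -> geometric center
    diagonal = torch.sqrt(la ** 2 + wa ** 2)
    xg = xt * diagonal + xa                 # offsets normalised by anchor diagonal
    yg = yt * diagonal + ya
    zg = zt * ha + za
    wg = torch.exp(wt) * wa                 # sizes are log-encoded ratios
    lg = torch.exp(lt) * la
    hg = torch.exp(ht) * ha
    rg = rt + ra                            # yaw residual is additive
    zg = zg - hg / 2                        # back to bottom convention
    return torch.cat([xg, yg, zg, wg, lg, hg, rg], dim=-1)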
Example #4
def create_refine_loss(
        loc_loss_ftor,
        cls_loss_ftor,
        example,
        coarse_box_preds,
        coarse_cls_preds,
        refine_box_preds,
        refine_cls_preds,
        cls_targets,  # [B,H*W,1]
        cls_weights,  # [B,H*W]
        reg_targets,  # [B,H*W, 7]
        reg_weights,  # [B,H*W]
        num_class,
        encode_background_as_zeros=True,
        encode_rad_error_by_sin=True,
        box_code_size=7,
        reg_weights_ori=None):

    batch_size = example['anchors'].shape[0]
    anchors = example["anchors"].view(batch_size, -1, box_code_size)
    coarse_box_preds = coarse_box_preds.view(batch_size, -1, box_code_size)
    refine_box_preds = refine_box_preds.view(batch_size, -1, box_code_size)

    # Decode the coarse box predictions against the prior anchors
    de_coarse_boxes = box_torch_ops.second_box_decode(coarse_box_preds,
                                                      anchors)

    # Decode the GT regression targets against the prior anchors, then
    # re-encode them relative to the decoded coarse boxes to obtain the
    # refinement targets.
    de_gt_boxes = box_torch_ops.second_box_decode(reg_targets, anchors)
    new_gt = box_torch_ops.second_box_encode(de_gt_boxes, de_coarse_boxes)

    if encode_background_as_zeros:
        coarse_conf = coarse_cls_preds.view(batch_size, -1, num_class)
        refine_conf = refine_cls_preds.view(batch_size, -1, num_class)

    else:
        coarse_conf = coarse_cls_preds.view(batch_size, -1, num_class + 1)
        refine_conf = refine_cls_preds.view(batch_size, -1, num_class + 1)

    cls_targets = cls_targets.squeeze(-1)
    one_hot_targets = torchplus.nn.one_hot(cls_targets,
                                           depth=num_class + 1,
                                           dtype=refine_box_preds.dtype)
    if encode_background_as_zeros:  # True
        one_hot_targets = one_hot_targets[..., 1:]
    if encode_rad_error_by_sin:
        # sin(a - b) = sin(a)cos(b) - cos(a)sin(b)
        box_preds, reg_targets = add_sin_difference(refine_box_preds, new_gt)
    else:
        box_preds, reg_targets = refine_box_preds, new_gt
    refine_loc_losses = loc_loss_ftor(
        box_preds, reg_targets, weights=reg_weights)  # [N, M]    # [2,70400,7]

    refine_cls_losses = cls_loss_ftor(refine_conf,
                                      one_hot_targets,
                                      weights=cls_weights)  # [N, M]

    ##############################
    # ## if cfg.USE_IOU_LOSS:
    # coarse_iou_loss = compute_iou_loss(de_coarse_boxes, de_gt_boxes, coarse_conf, thre = 0.7, weights=reg_weights_ori)
    # de_refine_boxes =  box_torch_ops.second_box_decode(refine_box_preds, de_coarse_boxes)
    # refine_iou_loss = compute_iou_loss(de_refine_boxes, de_gt_boxes, refine_conf, thre = 0.7, weights=reg_weights_ori)
    # #############################
    ## if cfg.USE_CORNER_LOSS:
    # corner_loss_sum = corners_loss(de_refine_boxes, de_gt_boxes, weights=reg_weights)

    return refine_loc_losses, refine_cls_losses  #,coarse_iou_loss, refine_iou_loss #, 0.5*corner_loss_sum
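Example #4 builds its refinement targets with box_torch_ops.second_box_encode, the inverse of the decoder sketched above. For reference, a simplified sketch of that encoder under the same assumptions (default case, no smooth_dim, no vector-encoded angle):

import torch

def second_box_encode_simple(boxes, anchors):
    # Both tensors are [..., 7] in (x, y, z, w, l, h, yaw) order.
    xa, ya, za, wa, la, ha, ra = torch.split(anchors, 1, dim=-1)
    xg, yg, zg, wg, lg, hg, rg = torch.split(boxes, 1, dim=-1)
    za = za + ha / 2                        # bottom -> geometric center
    zg = zg + hg / 2
    diagonal = torch.sqrt(la ** 2 + wa ** 2)
    xt = (xg - xa) / diagonal               # offsets normalised by anchor diagonal
    yt = (yg - ya) / diagonal
    zt = (zg - za) / ha
    wt = torch.log(wg / wa)                 # log-encoded size ratios
    lt = torch.log(lg / la)
    ht = torch.log(hg / ha)
    rt = rg - ra                            # additive yaw residual
    return torch.cat([xt, yt, zt, wt, lt, ht, rt], dim=-1)

Encoding a set of boxes against some anchors with this sketch and then decoding the result with second_box_decode_simple above should reproduce the original boxes up to floating-point error, which is a quick sanity check for either implementation.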