Example #1
    def encode(label_boxes_3d, label_boxes_2d, p2):
        """
        Args:
            label_boxes_3d: shape(N, K)
        Returns:
            C_2d: shape(N, 2)
            depth: shape(N, )
            side_points_2d: shape(N, 2, 2)
        """
        num_samples = label_boxes_3d.shape[0]
        location = label_boxes_3d[:, :3]
        C_2d = geometry_utils.torch_points_3d_to_points_2d(location, p2)
        instance_depth = location[:, 2:3]  # keep the dim so cat below works

        # get the side points (both sides are predicted)
        corners_2d = geometry_utils.torch_boxes_3d_to_corners_2d(
            label_boxes_3d, p2)
        bottom_corners = corners_2d[:, [0, 1, 2, 3]]
        #  left_side = corners_2d[:,[0,3]]
        #  right_side = corners_2d[:,[1,2]]

        encoded_all = torch.cat(
            [C_2d, instance_depth,
             bottom_corners.view(num_samples, -1)],
            dim=-1)
        return encoded_all
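For reference, a minimal decode sketch for this layout, assuming the geometry_utils.torch_points_2d_to_points_3d back-projection helper used in Example #5; the slice boundaries follow the concatenation above:

    def decode(encoded_all, p2):
        # split [C_2d (2), instance_depth (1), bottom_corners (8)]
        C_2d = encoded_all[:, :2]
        instance_depth = encoded_all[:, 2:3]
        bottom_corners = encoded_all[:, 3:].view(-1, 4, 2)
        # back-project the projected center with its depth to recover
        # the 3d location
        location = geometry_utils.torch_points_2d_to_points_3d(
            C_2d, instance_depth, p2)
        return location, bottom_corners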
Example #2
    def encode(label_boxes_3d, p2):
        """
            projection points of 3d bbox center and its corners_3d in local
            coordinates frame

        Returns:
            depth of center:
            center 3d location:
            local_corners:
        """
        location = label_boxes_3d[:, :3]
        center_depth = location[:, -1:]
        center_2d = geometry_utils.torch_points_3d_to_points_2d(location, p2)
        ry = label_boxes_3d[:, -1:]
        dims = label_boxes_3d[:, 3:6]

        return torch.cat([dims, ry, center_2d, center_depth, location], dim=-1)
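The matching decode is then a fixed-size split; a sketch under the slice layout implied by the return value (the location could equally be re-derived by back-projecting center_2d with center_depth, as in Example #5):

    def decode(encoded):
        # split [dims (3), ry (1), center_2d (2), center_depth (1),
        # location (3)]
        dims = encoded[:, :3]
        ry = encoded[:, 3:4]
        center_2d = encoded[:, 4:6]
        center_depth = encoded[:, 6:7]
        location = encoded[:, 7:]
        return dims, ry, center_2d, center_depth, location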
Example #3
def test_corners_3d_coder():

    coder_config = {'type': constants.KEY_CORNERS_3D}
    bbox_coder = bbox_coders.build(coder_config)

    dataset = build_dataset()
    sample = dataset[0]
    label_boxes_3d = torch.from_numpy(sample[constants.KEY_LABEL_BOXES_3D])
    label_boxes_2d = torch.from_numpy(sample[constants.KEY_LABEL_BOXES_2D])
    p2 = torch.from_numpy(sample[constants.KEY_STEREO_CALIB_P2])
    proposals = torch.from_numpy(sample[constants.KEY_LABEL_BOXES_2D])
    num_instances = torch.from_numpy(sample[constants.KEY_NUM_INSTANCES])

    # add a batch dimension of size 1
    label_boxes_3d = label_boxes_3d[:num_instances].unsqueeze(0)
    label_boxes_2d = label_boxes_2d[:num_instances].unsqueeze(0)
    proposals = proposals[:num_instances].unsqueeze(0)
    p2 = p2.unsqueeze(0)

    encoded_corners_3d = bbox_coder.encode_batch(label_boxes_3d,
                                                 label_boxes_2d, p2)
    num_boxes = encoded_corners_3d.shape[1]
    batch_size = encoded_corners_3d.shape[0]

    decoded_corners_3d = bbox_coder.decode_batch(
        encoded_corners_3d.view(batch_size, num_boxes, -1), proposals, p2)

    decoded_corners_2d = geometry_utils.torch_points_3d_to_points_2d(
        decoded_corners_3d[0].view(-1, 3), p2[0]).view(-1, 8, 2)
    decoded_corners_2d = decoded_corners_2d.cpu().detach().numpy()

    image_path = sample[constants.KEY_IMAGE_PATH]
    image_dir = '/data/object/training/image_2'
    result_dir = './results/data'
    save_dir = 'results/images'
    calib_dir = '/data/object/training/calib'
    label_dir = None
    calib_file = None
    visualizer = ImageVisualizer(image_dir,
                                 result_dir,
                                 label_dir=label_dir,
                                 calib_dir=calib_dir,
                                 calib_file=calib_file,
                                 online=False,
                                 save_dir=save_dir)
    visualizer.render_image_corners_2d(image_path, decoded_corners_2d)
Example #4
    def encode(label_boxes_3d, label_boxes_2d, p2, image_info):
        """
            return projections of 3d bbox corners in the inner of 2d bbox.
            Note that set the visibility at the same time according to the 2d bbox
            and image boundary.(truncated or occluded)
        """

        # shape(N, 8, 2)
        corners_3d = geometry_utils.torch_boxes_3d_to_corners_3d(
            label_boxes_3d)
        corners_2d = geometry_utils.torch_points_3d_to_points_2d(
            corners_3d.reshape((-1, 3)), p2).reshape(-1, 8, 2)
        corners_2d = NearestV2CornerCoder.reorder_boxes_4c(corners_2d)

        image_shape = torch.tensor([0, 0, image_info[1], image_info[0]])
        image_shape = image_shape.type_as(corners_2d).view(1, 4)
        image_filter = geometry_utils.torch_window_filter(corners_2d,
                                                          image_shape,
                                                          deltas=200)

        boxes_2d_filter = geometry_utils.torch_window_filter(
            corners_2d, label_boxes_2d)

        # self-occlusion filter (use the commented lines below to disable it)
        self_occluded_filter = Corner2DCoder.get_occluded_filter(corners_3d)
        # self_occluded_filter = torch.ones_like(image_filter)
        # self_occluded_filter = 0.1 * self_occluded_filter.float()

        # points outside of the image must be filtered out
        visibility = image_filter.float() * self_occluded_filter
        # visibility = visibility & boxes_2d_filter & self_occluded_filter

        # remove invisible points (disabled)
        #  corners_2d[~visibility] = -1

        # normalize using label bbox 2d
        label_boxes_2d_xywh = geometry_utils.torch_xyxy_to_xywh(
            label_boxes_2d.unsqueeze(0)).squeeze(0)
        wh = label_boxes_2d_xywh[:, 2:].unsqueeze(1)
        left_top = label_boxes_2d[:, :2].unsqueeze(1)
        # mid = label_boxes_2d_xywh[:, :2].unsqueeze(1)
        encoded_corners_2d = (corners_2d - left_top) / wh

        encoded_corners_2d = torch.cat(
            [encoded_corners_2d,
             visibility.unsqueeze(-1).float()], dim=-1)
        return encoded_corners_2d.contiguous().view(
            encoded_corners_2d.shape[0], -1)
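A sketch of the matching decode, assuming the per-box normalization is inverted with the same 2d boxes; each corner carries an (x, y, visibility) triple:

    def decode(encoded_corners_2d, boxes_2d):
        N = encoded_corners_2d.shape[0]
        encoded = encoded_corners_2d.view(N, 8, 3)
        boxes_2d_xywh = geometry_utils.torch_xyxy_to_xywh(
            boxes_2d.unsqueeze(0)).squeeze(0)
        wh = boxes_2d_xywh[:, 2:].unsqueeze(1)
        left_top = boxes_2d[:, :2].unsqueeze(1)
        # undo (corners_2d - left_top) / wh
        corners_2d = encoded[..., :2] * wh + left_top
        visibility = encoded[..., 2]
        return corners_2d, visibility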
Example #5
    def decode_bbox(self,
                    center_2d,
                    center_depth,
                    dims,
                    ry,
                    p2,
                    to_2d=False,
                    proposals_xywh=None):
        # location
        location = []
        N, M = center_2d.shape[:2]
        for batch_ind in range(N):
            location.append(
                geometry_utils.torch_points_2d_to_points_3d(
                    center_2d[batch_ind], center_depth[batch_ind],
                    p2[batch_ind]))

        location = torch.stack(location, dim=0)

        # local corners
        local_corners = []
        for batch_ind in range(N):
            local_corners.append(
                self.calc_local_corners(dims[batch_ind], ry[batch_ind]))
        local_corners = torch.stack(local_corners, dim=0)

        # global corners
        global_corners = (location.view(N, M, 1, 3) +
                          local_corners.view(N, M, 8, 3)).view(N, M, -1)
        if to_2d:
            corners_2d = []
            for batch_ind in range(N):
                corners_2d.append(
                    geometry_utils.torch_points_3d_to_points_2d(
                        global_corners[batch_ind].view(-1, 3), p2[batch_ind]))
            corners_2d = torch.stack(corners_2d, dim=0).view(N, M, -1)

            # encode all projected points by the proposal size (important!)
            encoded_corners_2d = (corners_2d.view(N, M, 8, 2) -
                                  proposals_xywh[:, :, None, :2]
                                  ) / proposals_xywh[:, :, None, 2:]
            return encoded_corners_2d.view(N, M, -1)
        return global_corners
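calc_local_corners is used here but not shown. A plausible sketch under the usual KITTI-style convention (dims given as (h, w, l), box origin at the bottom center, rotation about the camera y axis by ry, the first four corners forming the bottom face); the corner ordering and the exact signature of torch_ry_to_rotation_matrix are assumptions:

    def calc_local_corners(dims, ry):
        # dims: shape(M, 3) as (h, w, l); ry: shape(M, 1)
        h, w, l = dims[:, 0], dims[:, 1], dims[:, 2]
        zeros = torch.zeros_like(l)
        x_corners = torch.stack(
            [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2],
            dim=-1)
        y_corners = torch.stack(
            [zeros, zeros, zeros, zeros, -h, -h, -h, -h], dim=-1)
        z_corners = torch.stack(
            [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2],
            dim=-1)
        # shape(M, 3, 8)
        corners = torch.stack([x_corners, y_corners, z_corners], dim=1)
        # rotation about y; helper also used in Example #7 (assumed to
        # accept shape(M,) angles)
        R = geometry_utils.torch_ry_to_rotation_matrix(
            ry.view(-1)).type_as(corners)
        # shape(M, 8, 3)
        return torch.matmul(R, corners).permute(0, 2, 1)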
Example #6
    def encode_batch_bbox(self, gt_boxes_3d, proposals, assigned_gt_labels,
                          p2):
        """
        encoding dims may be better,here just encode dims_2d
        Args:
            dims: shape(N,6), (h,w,l) and their projection in 2d
        """
        location = gt_boxes_3d[:, 3:6]
        dims = gt_boxes_3d[:, :3]
        ry = gt_boxes_3d[:, 6:]

        # ray_angle = -torch.atan2(location[:, 2], location[:, 0])
        # local_ry = ry - ray_angle.unsqueeze(-1)
        center_depth = location[:, -1:]
        center_2d = geometry_utils.torch_points_3d_to_points_2d(location, p2)

        targets = torch.cat([dims, ry, center_depth, center_2d], dim=-1)
        return targets
Example #7
    def encode(label_boxes_3d, label_boxes_2d, p2):
        """
            projection points of 3d bbox center and its corners_3d in local
            coordinates frame
        """
        # global to local
        global_corners_3d = geometry_utils.torch_boxes_3d_to_corners_3d(
            label_boxes_3d)
        C = label_boxes_3d[:, :3]

        # proj of 3d bbox center
        C_2d = geometry_utils.torch_points_3d_to_points_2d(C, p2)

        alpha = geometry_utils.compute_ray_angle(C_2d.unsqueeze(0),
                                                 p2.unsqueeze(0)).squeeze(0)
        R = geometry_utils.torch_ry_to_rotation_matrix(-alpha).type_as(
            global_corners_3d)

        # local coords
        num_boxes = global_corners_3d.shape[0]
        local_corners_3d = torch.matmul(
            R,
            global_corners_3d.permute(0, 2, 1) - C.unsqueeze(-1)).permute(
                0, 2, 1).contiguous().view(num_boxes, -1)

        # instance depth
        instance_depth = C[:, -1:]

        # finally encode them (local_corners_3d is already in encoded form)
        # C_2d is encoded relative to the 2d bbox
        # torch_xyxy_to_xywh supports batched input only, hence the
        # unsqueeze/squeeze pair
        label_boxes_2d_xywh = geometry_utils.torch_xyxy_to_xywh(
            label_boxes_2d.unsqueeze(0)).squeeze(0)
        encoded_C_2d = (
            C_2d - label_boxes_2d_xywh[:, :2]) / label_boxes_2d_xywh[:, 2:]

        # instance_depth could also be encoded by inverting it (disabled)
        # instance_depth_inv = 1 / instance_depth

        return torch.cat([local_corners_3d, encoded_C_2d, instance_depth],
                         dim=-1)
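The inverse undoes each step in reverse order; a sketch assuming the same helpers:

    def decode(encoded, label_boxes_2d, p2):
        local_corners_3d = encoded[:, :24].view(-1, 8, 3)
        encoded_C_2d = encoded[:, 24:26]
        instance_depth = encoded[:, 26:27]

        # un-normalize the projected center by the 2d bbox
        label_boxes_2d_xywh = geometry_utils.torch_xyxy_to_xywh(
            label_boxes_2d.unsqueeze(0)).squeeze(0)
        C_2d = (encoded_C_2d * label_boxes_2d_xywh[:, 2:] +
                label_boxes_2d_xywh[:, :2])
        C = geometry_utils.torch_points_2d_to_points_3d(
            C_2d, instance_depth, p2)

        # rotate back by +alpha and translate into the camera frame
        alpha = geometry_utils.compute_ray_angle(C_2d.unsqueeze(0),
                                                 p2.unsqueeze(0)).squeeze(0)
        R = geometry_utils.torch_ry_to_rotation_matrix(alpha).type_as(
            local_corners_3d)
        global_corners_3d = torch.matmul(
            R, local_corners_3d.permute(0, 2, 1)) + C.unsqueeze(-1)
        return global_corners_3d.permute(0, 2, 1)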
Example #8
def test_geometry():

    dataset = build_dataset()
    for sample in dataset:

        label_boxes_3d = torch.from_numpy(sample['gt_boxes_3d'])
        p2 = torch.from_numpy(sample['p2'])
        # reorder from (dims, location, ry) to (location, dims, ry)
        label_boxes_3d = torch.cat([
            label_boxes_3d[:, 3:6], label_boxes_3d[:, :3],
            label_boxes_3d[:, 6:]
        ], dim=-1)

        corners_3d = geometry_utils.torch_boxes_3d_to_corners_3d(
            label_boxes_3d)
        front_mid = corners_3d[:, [0, 1]].mean(dim=1)
        rear_mid = corners_3d[:, [2, 3]].mean(dim=1)
        points_3d = torch.cat([rear_mid, front_mid], dim=0)
        points_2d = geometry_utils.torch_points_3d_to_points_2d(points_3d, p2)

        lines = points_2d.contiguous().view(2, -1, 2).permute(
            1, 0, 2).contiguous().view(-1, 4)
        ry_pred1 = geometry_utils.torch_pts_2d_to_dir_3d_v2(
            lines.unsqueeze(0), p2.unsqueeze(0))[0]
        # ry_pred2 = geometry_utils.torch_dir_to_angle()
        # deltas = points_3d[1]-points_3d[0]
        # ry_pred2 = -torch.atan2(deltas[2], deltas[0])
        ry_gt = label_boxes_3d[:, -1]
        height = label_boxes_3d[:, 1]
        ry_gt[height < 0] = geometry_utils.reverse_angle(ry_gt[height < 0])
        cond = torch.abs(ry_pred1 - ry_gt) < 1e-4
        assert cond.all(), '{} error {} {}'.format(sample['img_name'], ry_gt,
                                                   ry_pred1)
Example #9
    def loss(self, prediction_dict, feed_dict):
        """
        assign proposals label and subsample from them
        Then calculate loss
        """
        loss_dict = super().loss(prediction_dict, feed_dict)
        targets = prediction_dict[constants.KEY_TARGETS]
        # rcnn_corners_loss = 0
        # rcnn_dim_loss = 0

        proposals = prediction_dict[constants.KEY_PROPOSALS]
        p2 = feed_dict[constants.KEY_STEREO_CALIB_P2]
        image_info = feed_dict[constants.KEY_IMAGE_INFO]
        mean_dims = torch.tensor([1.8, 1.8, 3.7]).type_as(proposals)
        corners_2d_loss = 0
        center_depth_loss = 0
        location_loss = 0

        for stage_ind in range(self.num_stages):
            corners_target = targets[stage_ind][2]
            # rcnn_corners_loss = rcnn_corners_loss + common_loss.calc_loss(
            # self.rcnn_corners_loss, orient_target, True)
            preds = corners_target['pred']
            # do not shadow `targets`, it is indexed again per stage
            corners_gt = corners_target['target']
            weights = corners_target['weight']
            weights = weights.unsqueeze(-1)

            # gt
            local_corners_gt = corners_gt[:, :, :24]
            location_gt = corners_gt[:, :, 24:27]
            dims_gt = corners_gt[:, :, 27:]
            N, M = local_corners_gt.shape[:2]

            global_corners_gt = (local_corners_gt.view(N, M, 8, 3) +
                                 location_gt.view(N, M, 1, 3)).view(N, M, -1)
            corners_depth_gt = global_corners_gt.view(N, M, 8, 3)[..., -1]
            center_depth_gt = location_gt[:, :, 2:]

            # preds
            corners_2d_preds = preds[:, :, :16]

            corners_2d_preds = self.decode_corners_2d(corners_2d_preds,
                                                      proposals)

            local_corners_preds = []
            # calc local corners preds
            for batch_ind in range(N):
                local_corners_preds.append(
                    geometry_utils.torch_points_2d_to_points_3d(
                        corners_2d_preds[batch_ind].view(-1, 2),
                        corners_depth_gt[batch_ind].view(-1), p2[batch_ind]))
            local_corners_preds = torch.stack(
                local_corners_preds, dim=0).view(N, M, -1)
            dims_preds = self.calc_dims_preds(local_corners_preds)

            dims_loss = self.l1_loss(dims_preds, dims_gt) * weights

            center_2d_deltas_preds = preds[:, :, 16:18]
            center_depth_preds = preds[:, :, 18:]
            # decode center_2d
            proposals_xywh = geometry_utils.torch_xyxy_to_xywh(proposals)
            center_2d_preds = (
                center_2d_deltas_preds * proposals_xywh[:, :, 2:] +
                proposals_xywh[:, :, :2])
            # center_depth_preds_detach = center_depth_preds.detach()

            # use the gt depth to compute the loss so that gradients stay smooth
            location_preds = []
            for batch_ind in range(N):
                location_preds.append(
                    geometry_utils.torch_points_2d_to_points_3d(
                        center_2d_preds[batch_ind], center_depth_gt[batch_ind],
                        p2[batch_ind]))
            location_preds = torch.stack(location_preds, dim=0)
            global_corners_preds = (location_preds.view(N, M, 1, 3) +
                                    local_corners_preds.view(N, M, 8, 3)).view(
                                        N, M, -1)

            # corners depth loss and center depth loss
            corners_depth_preds = local_corners_preds.view(N, M, 8, 3)[..., -1]
            corners_depth_gt = local_corners_gt.view(N, M, 8, 3)[..., -1]

            center_depth_loss = self.l1_loss(center_depth_preds,
                                             center_depth_gt) * weights

            # location loss
            location_loss = self.l1_loss(location_preds, location_gt) * weights

            # global corners loss
            global_corners_loss = self.l1_loss(global_corners_preds,
                                               global_corners_gt) * weights

            # proj 2d loss
            corners_2d_gt = []
            for batch_ind in range(N):
                corners_2d_gt.append(
                    geometry_utils.torch_points_3d_to_points_2d(
                        global_corners_gt[batch_ind].view(-1, 3),
                        p2[batch_ind]))

            corners_2d_gt = torch.stack(corners_2d_gt, dim=0).view(N, M, -1)

            # image filter
            zeros = torch.zeros_like(image_info[:, 0])
            image_shape = torch.stack(
                [zeros, zeros, image_info[:, 1], image_info[:, 0]], dim=-1)
            image_shape = image_shape.type_as(corners_2d_gt).view(-1, 4)
            image_filter = geometry_utils.torch_window_filter(
                corners_2d_gt.view(N, -1, 2), image_shape,
                deltas=200).float().view(N, M, -1)

            corners_2d_loss = self.l1_loss(
                corners_2d_preds.view(N, M, -1), corners_2d_gt) * weights
            corners_2d_loss = (corners_2d_loss.view(N, M, 8, 2) *
                               image_filter.unsqueeze(-1)).view(N, M, -1)
            corners_depth_loss = self.l1_loss(
                corners_depth_preds, corners_depth_gt) * weights * image_filter

            # dim_target = targets[stage_ind][3]
            # rcnn_dim_loss = rcnn_dim_loss + common_loss.calc_loss(
            # self.rcnn_bbox_loss, dim_target, True)

            # global_corners_loss = self.l1_loss(global_corners_preds,
            # global_corners_gt) * weights
            # local_corners_loss = self.l1_loss(local_corners_preds,
            # local_corners_gt) * weights

        loss_dict.update({
            # 'global_corners_loss': global_corners_loss * 10,
            # 'local_corners_loss': local_corners_loss * 10,
            'corners_2d_loss': corners_2d_loss,
            # 'center_depth_loss': center_depth_loss * 10,
            # 'location_loss': location_loss * 10,
            # 'corners_depth_loss': corners_depth_loss * 10,
            # 'rcnn_corners_loss': rcnn_corners_loss,
            # 'rcnn_dim_loss': rcnn_dim_loss
            'dims_loss': dims_loss * 10
        })

        return loss_dict
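decode_corners_2d is referenced above but not shown; a hypothetical sketch that mirrors the center_2d decode in the same loop (deltas scaled by the proposal size and shifted by its position; whether the anchor is the proposal center or its top-left corner is an assumption here):

    def decode_corners_2d(self, encoded_corners_2d, proposals):
        N, M = encoded_corners_2d.shape[:2]
        proposals_xywh = geometry_utils.torch_xyxy_to_xywh(proposals)
        wh = proposals_xywh[:, :, None, 2:]
        xy = proposals_xywh[:, :, None, :2]
        # undo the per-corner normalization by the proposal box
        corners_2d = encoded_corners_2d.view(N, M, 8, 2) * wh + xy
        return corners_2d.view(N, M, -1)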
Example #10
    def encode(label_boxes_3d, proposals, p2):
        """
        Args:
            label_boxes_3d: shape(N, 7)
            proposals: shape(N, 4)
            p2: shape(3, 4)
        """
        # shape(N, 8, 3)
        corners_3d = geometry_utils.torch_boxes_3d_to_corners_3d(
            label_boxes_3d)
        corners_2d = geometry_utils.torch_points_3d_to_points_2d(
            corners_3d.reshape((-1, 3)), p2).reshape(-1, 8, 2)
        # shape(N, 3)
        left_side_points_3d = (corners_3d[:, 0] + corners_3d[:, 3]) / 2
        right_side_points_3d = (corners_3d[:, 1] + corners_3d[:, 2]) / 2

        # shape(N, 2, 2)
        left_side = torch.stack([corners_2d[:, 0], corners_2d[:, 3]], dim=1)
        right_side = torch.stack([corners_2d[:, 1], corners_2d[:, 2]], dim=1)

        # shape(N, 2, 2, 2)
        side = torch.stack([left_side, right_side], dim=1)

        # camera center: with no rotation, p2 = K[I | t] and C = -t
        K = p2[:3, :3]
        KT = p2[:, -1]
        T = torch.matmul(torch.inverse(K), KT)
        C = -T
        # shape(N, )
        left_dist = torch.norm(left_side_points_3d - C, dim=-1)
        right_dist = torch.norm(right_side_points_3d - C, dim=-1)
        dist = torch.stack([left_dist, right_dist], dim=-1)
        _, visible_index = torch.min(dist, dim=-1)

        row = torch.arange(visible_index.numel()).type_as(visible_index)
        # either one side is visible or neither is
        visible_side = side[row, visible_index]

        # in the abnormal case both sides are invisible
        left_slope = geometry_utils.torch_line_to_orientation(left_side[:, 0],
                                                              left_side[:, 1])
        right_slope = geometry_utils.torch_line_to_orientation(
            right_side[:, 0], right_side[:, 1])
        non_visible_cond = left_slope * right_slope < 0

        visible_slope = geometry_utils.torch_line_to_orientation(
            visible_side[:, 0], visible_side[:, 1])
        # cls_orients
        cls_orients = visible_slope > 0
        cls_orients = cls_orients.float()
        cls_orients[non_visible_cond] = 2.0

        # reg_orients
        boxes_3d_proj = geometry_utils.torch_corners_2d_to_boxes_2d(corners_2d)
        # shape(N, 4)
        boxes_3d_proj_xywh = geometry_utils.torch_xyxy_to_xywh(
            boxes_3d_proj.unsqueeze(0)).squeeze(0)
        direction = torch.abs(visible_side[:, 0] - visible_side[:, 1])
        reg_orients = direction / boxes_3d_proj_xywh[:, 2:]

        return torch.cat([cls_orients.unsqueeze(-1), reg_orients], dim=-1)
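On the decode side, the classification bit restores the sign that the abs() above discarded; a sketch (cls_orients == 2 marks the no-visible-side case and is left to the caller):

    def decode(orients, boxes_3d_proj):
        cls_orients = orients[:, 0]
        reg_orients = orients[:, 1:]
        boxes_3d_proj_xywh = geometry_utils.torch_xyxy_to_xywh(
            boxes_3d_proj.unsqueeze(0)).squeeze(0)
        # un-normalize |dx, dy| by the projected 3d-box size
        direction = reg_orients * boxes_3d_proj_xywh[:, 2:]
        # cls_orients == 1 means a positive slope; flip dy otherwise
        neg = cls_orients == 0
        direction[neg, 1] = -direction[neg, 1]
        return direction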
Example #11
    def loss(self, prediction_dict, feed_dict):
        loss_dict = {}
        anchors = prediction_dict['anchors']
        anchors_dict = {}
        anchors_dict[constants.KEY_PRIMARY] = anchors
        anchors_dict[
            constants.KEY_BOXES_2D] = prediction_dict['rpn_bbox_preds']
        anchors_dict[constants.KEY_CLASSES] = prediction_dict['rpn_cls_scores']
        anchors_dict[
            constants.KEY_CORNERS_3D_GRNET] = prediction_dict['corners_3d']

        gt_dict = {}
        gt_dict[constants.KEY_PRIMARY] = feed_dict[
            constants.KEY_LABEL_BOXES_2D]
        gt_dict[constants.KEY_CLASSES] = None
        gt_dict[constants.KEY_BOXES_2D] = None
        gt_dict[constants.KEY_CORNERS_3D_GRNET] = None

        auxiliary_dict = {}
        auxiliary_dict[constants.KEY_BOXES_2D] = feed_dict[
            constants.KEY_LABEL_BOXES_2D]
        gt_labels = feed_dict[constants.KEY_LABEL_CLASSES]
        auxiliary_dict[constants.KEY_CLASSES] = torch.ones_like(gt_labels)
        auxiliary_dict[constants.KEY_NUM_INSTANCES] = feed_dict[
            constants.KEY_NUM_INSTANCES]
        auxiliary_dict[constants.KEY_PROPOSALS] = anchors
        auxiliary_dict[constants.KEY_BOXES_3D] = feed_dict[
            constants.KEY_LABEL_BOXES_3D]
        auxiliary_dict[constants.KEY_STEREO_CALIB_P2] = feed_dict[
            constants.KEY_STEREO_CALIB_P2]

        subsample = not self.use_focal_loss
        _, targets, _ = self.target_generators.generate_targets(
            anchors_dict, gt_dict, auxiliary_dict, subsample=subsample)

        cls_target = targets[constants.KEY_CLASSES]
        reg_target = targets[constants.KEY_BOXES_2D]

        # loss

        if self.use_focal_loss:
            # when using focal loss, don't normalize by the number of all samples
            cls_targets = cls_target['target']
            pos = cls_targets > 0  # [N,#anchors]
            num_pos = pos.long().sum().clamp(min=1).float()
            rpn_cls_loss = common_loss.calc_loss(
                self.rpn_cls_loss, cls_target, normalize=False) / num_pos
        else:
            rpn_cls_loss = common_loss.calc_loss(self.rpn_cls_loss, cls_target)
        rpn_reg_loss = common_loss.calc_loss(self.rpn_bbox_loss, reg_target)
        loss_dict.update({
            'rpn_cls_loss': rpn_cls_loss,
            'rpn_reg_loss': rpn_reg_loss
        })

        proposals = anchors_dict[constants.KEY_PRIMARY]
        p2 = feed_dict[constants.KEY_STEREO_CALIB_P2]
        image_info = feed_dict[constants.KEY_IMAGE_INFO]
        mean_dims = torch.tensor([1.8, 1.8, 3.7]).type_as(proposals)
        corners_2d_loss = 0
        center_depth_loss = 0
        location_loss = 0

        corners_target = targets[constants.KEY_CORNERS_3D_GRNET]
        # rcnn_corners_loss = rcnn_corners_loss + common_loss.calc_loss(
        # self.rcnn_corners_loss, orient_target, True)
        preds = corners_target['pred']
        # do not shadow the `targets` dict
        corners_gt = corners_target['target']
        weights = corners_target['weight']
        weights = weights.unsqueeze(-1)

        local_corners_gt = corners_gt[:, :, :24]
        location_gt = corners_gt[:, :, 24:27]
        dims_gt = corners_gt[:, :, 27:]
        N, M = local_corners_gt.shape[:2]

        global_corners_gt = (local_corners_gt.view(N, M, 8, 3) +
                             location_gt.view(N, M, 1, 3)).view(N, M, -1)
        center_depth_gt = location_gt[:, :, 2:]

        dims_preds = torch.exp(preds[:, :, :3]) * mean_dims
        dims_loss = self.l1_loss(dims_preds, dims_gt) * weights
        ry_preds = preds[:, :, 3:4]
        # ray_angle = -torch.atan2(location_gt[:, :, 2], location_gt[:, :, 0])
        # ry_preds = ry_preds + ray_angle.unsqueeze(-1)
        local_corners_preds = []
        # calc local corners preds
        for batch_ind in range(N):
            local_corners_preds.append(
                self.calc_local_corners(dims_preds[batch_ind].detach(),
                                        ry_preds[batch_ind]))
        local_corners_preds = torch.stack(local_corners_preds, dim=0)

        center_2d_deltas_preds = preds[:, :, 4:6]
        center_depth_preds = preds[:, :, 6:]
        # decode center_2d
        proposals_xywh = geometry_utils.torch_xyxy_to_xywh(proposals)
        center_depth_init = self.decode_center_depth(dims_preds,
                                                     proposals_xywh, p2)
        center_depth_preds = center_depth_init * center_depth_preds
        center_2d_preds = (center_2d_deltas_preds * proposals_xywh[:, :, 2:] +
                           proposals_xywh[:, :, :2])
        # center_depth_preds_detach = center_depth_preds.detach()

        # use the gt depth to compute the loss so that gradients stay smooth
        location_preds = []
        for batch_ind in range(N):
            location_preds.append(
                geometry_utils.torch_points_2d_to_points_3d(
                    center_2d_preds[batch_ind], center_depth_gt[batch_ind],
                    p2[batch_ind]))
        location_preds = torch.stack(location_preds, dim=0)
        global_corners_preds = (location_preds.view(N, M, 1, 3) +
                                local_corners_preds.view(N, M, 8, 3)).view(
                                    N, M, -1)

        # corners depth loss and center depth loss
        corners_depth_preds = local_corners_preds.view(N, M, 8, 3)[..., -1]
        corners_depth_gt = local_corners_gt.view(N, M, 8, 3)[..., -1]

        center_depth_loss = self.l1_loss(center_depth_preds,
                                         center_depth_gt) * weights

        # location loss
        location_loss = self.l1_loss(location_preds, location_gt) * weights

        # global corners loss
        global_corners_loss = self.l1_loss(global_corners_preds,
                                           global_corners_gt) * weights

        # proj 2d loss
        corners_2d_preds = []
        corners_2d_gt = []
        for batch_ind in range(N):
            corners_2d_preds.append(
                geometry_utils.torch_points_3d_to_points_2d(
                    global_corners_preds[batch_ind].view(-1, 3),
                    p2[batch_ind]))
            corners_2d_gt.append(
                geometry_utils.torch_points_3d_to_points_2d(
                    global_corners_gt[batch_ind].view(-1, 3), p2[batch_ind]))

        corners_2d_preds = torch.stack(corners_2d_preds, dim=0).view(N, M, -1)
        corners_2d_gt = torch.stack(corners_2d_gt, dim=0).view(N, M, -1)

        # image filter
        zeros = torch.zeros_like(image_info[:, 0])
        image_shape = torch.stack(
            [zeros, zeros, image_info[:, 1], image_info[:, 0]], dim=-1)
        image_shape = image_shape.type_as(corners_2d_gt).view(-1, 4)
        image_filter = geometry_utils.torch_window_filter(
            corners_2d_gt.view(N, -1, 2), image_shape,
            deltas=200).float().view(N, M, -1)

        encoded_corners_2d_gt = corners_2d_gt.view(N, M, 8, 2)
        encoded_corners_2d_preds = corners_2d_preds.view(N, M, 8, 2)
        corners_2d_loss = self.l2_loss(encoded_corners_2d_preds.view(
            N, M, -1), encoded_corners_2d_gt.view(N, M, -1)) * weights
        corners_2d_loss = (corners_2d_loss.view(N, M, 8, 2) *
                           image_filter.unsqueeze(-1))
        # mask = self.select_corners(global_corners_gt)
        # mask = mask.unsqueeze(-1).expand_as(corners_2d_loss).float()
        corners_2d_loss = corners_2d_loss.view(N, M, -1)
        corners_depth_loss = self.l1_loss(
            corners_depth_preds, corners_depth_gt) * weights * image_filter

        # dim_target = targets[stage_ind][3]
        # rcnn_dim_loss = rcnn_dim_loss + common_loss.calc_loss(
        # self.rcnn_bbox_loss, dim_target, True)

        num_pos = (weights > 0).long().sum().clamp(min=1).float()

        loss_dict.update({
            # 'global_corners_loss': global_corners_loss,
            # 'local_corners_loss': local_corners_loss * 10,
            'corners_2d_loss': corners_2d_loss,
            # 'center_depth_loss': center_depth_loss,
            # 'location_loss': location_loss,
            # 'corners_depth_loss': corners_depth_loss * 10,
            # 'rcnn_corners_loss': rcnn_corners_loss,
            # 'rcnn_dim_loss': rcnn_dim_loss
            # 'dims_loss': dims_loss
        })

        return loss_dict
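decode_center_depth is not shown; a plausible sketch from pinhole geometry, where an object of physical height H projecting to h pixels under focal length f lies at depth roughly z = f * H / h (the name and the exact form are assumptions):

    def decode_center_depth(self, dims_preds, proposals_xywh, p2):
        f = p2[:, 0, 0]  # focal length, shape(N,)
        height_3d = dims_preds[:, :, 0]  # (h, w, l) -> h
        height_2d = proposals_xywh[:, :, 3]  # 2d box height in pixels
        depth_init = f.unsqueeze(-1) * height_3d / height_2d.clamp(min=1)
        return depth_init.unsqueeze(-1)  # shape(N, M, 1)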
Example #12
    def loss(self, prediction_dict, feed_dict):
        loss_dict = {}

        targets = prediction_dict[constants.KEY_TARGETS]

        cls_target = targets[constants.KEY_CLASSES]
        loc1_target = targets[constants.KEY_BOXES_2D]
        loc2_target = targets[constants.KEY_BOXES_2D_REFINE]
        os_target = targets[constants.KEY_OBJECTNESS]
        corners_target = targets[constants.KEY_CORNERS_3D_GRNET]
        # dims_target = targets[constants.KEY_DIMS]
        # orients_target = targets[constants.KEY_ORIENTS_V2]

        loc1_preds = loc1_target['pred']
        loc2_preds = loc2_target['pred']
        loc1_target = loc1_target['target']
        loc2_target = loc2_target['target']
        assert loc1_target.shape == loc2_target.shape
        loc_target = loc1_target

        conf_preds = cls_target['pred']
        conf_target = cls_target['target']
        conf_weight = cls_target['weight']
        conf_target[conf_weight == 0] = -1

        os_preds = os_target['pred']
        os_target_ = os_target['target']
        os_weight = os_target['weight']
        os_target_[os_weight == 0] = -1

        loc_loss, os_loss, conf_loss = self.two_step_loss(loc1_preds,
                                                          loc2_preds,
                                                          loc_target,
                                                          conf_preds,
                                                          conf_target,
                                                          os_preds,
                                                          os_target_,
                                                          is_print=False)

        # 3d loss
        # corners_loss = common_loss.calc_loss(self.rcnn_corners_loss,
        # corners_2d_target)
        preds = corners_target['pred']
        # do not shadow the `targets` dict
        corners_gt = corners_target['target']
        weights = corners_target['weight']
        proposals = prediction_dict[constants.KEY_PROPOSALS]
        p2 = feed_dict[constants.KEY_STEREO_CALIB_P2]
        image_info = feed_dict[constants.KEY_IMAGE_INFO]
        weights = weights.unsqueeze(-1)

        local_corners_gt = corners_gt[:, :, :24]
        location_gt = corners_gt[:, :, 24:27]
        dims_gt = corners_gt[:, :, 27:]
        N, M = local_corners_gt.shape[:2]

        global_corners_gt = (local_corners_gt.view(N, M, 8, 3) +
                             location_gt.view(N, M, 1, 3)).view(N, M, -1)
        center_depth_gt = location_gt[:, :, 2:]

        mean_dims = torch.tensor([1.8, 1.8, 3.7]).type_as(preds)
        dims_preds = torch.exp(preds[:, :, :3]) * mean_dims
        dims_loss = self.l1_loss(dims_preds, dims_gt) * weights
        ry_preds = preds[:, :, 3:4]
        # ray_angle = -torch.atan2(location_gt[:, :, 2],
        # location_gt[:, :, 0])
        # ry_preds = ry_preds + ray_angle.unsqueeze(-1)
        local_corners_preds = []
        # calc local corners preds
        for batch_ind in range(N):
            local_corners_preds.append(
                self.calc_local_corners(dims_preds[batch_ind].detach(),
                                        ry_preds[batch_ind]))
        local_corners_preds = torch.stack(local_corners_preds, dim=0)

        center_2d_deltas_preds = preds[:, :, 4:6]
        center_depth_preds = preds[:, :, 6:]
        # decode center_2d
        proposals_xywh = geometry_utils.torch_xyxy_to_xywh(proposals)
        center_depth_init = self.decode_center_depth(dims_preds,
                                                     proposals_xywh, p2)
        center_depth_preds = center_depth_init * center_depth_preds
        center_2d_preds = (center_2d_deltas_preds * proposals_xywh[:, :, 2:] +
                           proposals_xywh[:, :, :2])
        # center_depth_preds_detach = center_depth_preds.detach()

        # note: unlike the variants above, the location is decoded here from
        # the predicted depth rather than the gt depth
        location_preds = []
        for batch_ind in range(N):
            location_preds.append(
                geometry_utils.torch_points_2d_to_points_3d(
                    center_2d_preds[batch_ind], center_depth_preds[batch_ind],
                    p2[batch_ind]))
        location_preds = torch.stack(location_preds, dim=0)
        global_corners_preds = (location_preds.view(N, M, 1, 3) +
                                local_corners_preds.view(N, M, 8, 3)).view(
                                    N, M, -1)

        # corners depth loss and center depth loss
        corners_depth_preds = local_corners_preds.view(N, M, 8, 3)[..., -1]
        corners_depth_gt = local_corners_gt.view(N, M, 8, 3)[..., -1]

        center_depth_loss = self.l1_loss(center_depth_preds,
                                         center_depth_gt) * weights

        # location loss
        location_loss = self.l1_loss(location_preds, location_gt) * weights

        # global corners loss
        global_corners_loss = self.l1_loss(global_corners_preds,
                                           global_corners_gt) * weights

        # proj 2d loss
        corners_2d_preds = []
        corners_2d_gt = []
        for batch_ind in range(N):
            corners_2d_preds.append(
                geometry_utils.torch_points_3d_to_points_2d(
                    global_corners_preds[batch_ind].view(-1, 3),
                    p2[batch_ind]))
            corners_2d_gt.append(
                geometry_utils.torch_points_3d_to_points_2d(
                    global_corners_gt[batch_ind].view(-1, 3), p2[batch_ind]))

        corners_2d_preds = torch.stack(corners_2d_preds, dim=0).view(N, M, -1)
        corners_2d_gt = torch.stack(corners_2d_gt, dim=0).view(N, M, -1)

        # image filter
        zeros = torch.zeros_like(image_info[:, 0])
        image_shape = torch.stack(
            [zeros, zeros, image_info[:, 1], image_info[:, 0]], dim=-1)
        image_shape = image_shape.type_as(corners_2d_gt).view(-1, 4)
        image_filter = geometry_utils.torch_window_filter(
            corners_2d_gt.view(N, -1, 2), image_shape,
            deltas=200).float().view(N, M, -1)

        encoded_corners_2d_gt = corners_2d_gt.view(N, M, 8, 2)
        encoded_corners_2d_preds = corners_2d_preds.view(N, M, 8, 2)
        corners_2d_loss = self.l1_loss(encoded_corners_2d_preds.view(
            N, M, -1), encoded_corners_2d_gt.view(N, M, -1)) * weights
        corners_2d_loss = (corners_2d_loss.view(N, M, 8, 2) *
                           image_filter.unsqueeze(-1))
        # mask = self.select_corners(global_corners_gt)
        # mask = mask.unsqueeze(-1).expand_as(corners_2d_loss).float()
        corners_2d_loss = corners_2d_loss.view(N, M, -1)
        corners_depth_loss = self.l1_loss(
            corners_depth_preds, corners_depth_gt) * weights * image_filter

        # rpn_orients_loss = common_loss.calc_loss(self.rcnn_orient_loss,
        # corners_2d_target) * 100
        pos = weights > 0  # shape(N, #anchors)
        num_pos = pos.long().sum().clamp(min=1).float()

        loss_dict['loc_loss'] = loc_loss
        loss_dict['os_loss'] = os_loss
        loss_dict['conf_loss'] = conf_loss
        # loss_dict['corners_2d_loss'] = corners_2d_loss.sum() / num_pos * 0.1
        loss_dict['dims_loss'] = dims_loss.sum() / num_pos * 10
        loss_dict['global_corners_loss'] = global_corners_loss.sum(
        ) / num_pos * 10
        loss_dict['location_loss'] = location_loss.sum() / num_pos * 10
        loss_dict['center_depth_loss'] = center_depth_loss.sum() / num_pos * 10
        # loss_dict['orients_loss'] = rpn_orients_loss

        return loss_dict
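Every loss term above is multiplied elementwise by weights and only summed and normalized by num_pos at the end, which requires the loss modules to be unreduced. A minimal setup consistent with that usage (a sketch with plain torch.nn losses; assumes `from torch import nn`):

        # in __init__
        self.l1_loss = nn.L1Loss(reduction='none')
        self.l2_loss = nn.MSELoss(reduction='none')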
Example #13
    def encode(label_boxes_3d, proposals, p2, image_info):
        """
        return projections of 3d bbox corners in the inner of 2d bbox.
            Note that set the visibility at the same time according to the 2d bbox
            and image boundary.(truncated or occluded)
        """
        label_boxes_2d = proposals
        # shape(N, 8, 3)
        corners_3d = geometry_utils.torch_boxes_3d_to_corners_3d(
            label_boxes_3d)
        corners_2d = geometry_utils.torch_points_3d_to_points_2d(
            corners_3d.reshape((-1, 3)), p2).reshape(-1, 8, 2)

        image_shape = torch.tensor([0, 0, image_info[1], image_info[0]])
        image_shape = image_shape.type_as(corners_2d).view(1, 4)
        image_filter = geometry_utils.torch_window_filter(corners_2d,
                                                          image_shape,
                                                          deltas=200)

        # points outside of the image must be filtered out
        visibility = image_filter.float()

        # normalize using label bbox 2d
        label_boxes_2d_xywh = geometry_utils.torch_xyxy_to_xywh(
            label_boxes_2d.unsqueeze(0)).squeeze(0)
        # shape(N, 4, 2)
        label_corners_4c = geometry_utils.torch_xyxy_to_corner_4c(
            label_boxes_2d.unsqueeze(0)).squeeze(0)
        wh = label_boxes_2d_xywh[:, 2:].unsqueeze(1).unsqueeze(1)
        # left_top = label_boxes_2d[:, :2].unsqueeze(1)
        # mid = label_boxes_2d_xywh[:, :2].unsqueeze(1)
        corners_2d = corners_2d.unsqueeze(2)
        label_corners_4c = label_corners_4c.unsqueeze(1)
        encoded_corners_2d = (corners_2d - label_corners_4c) / wh
        # mean_size = torch.sqrt(wh[..., 0] * wh[..., 1])
        # weights = math_utils.gaussian2d(
        # corners_2d, label_corners_4c, sigma=mean_size)

        dist = torch.norm(encoded_corners_2d, dim=-1)  # (N,8,4)
        dist_min, dist_argmin = dist.min(dim=-1)  # (N,8)
        corners_2d_scores = torch.zeros_like(dist)
        corners_2d_scores = corners_2d_scores.view(-1, 4)
        col_index = dist_argmin.view(-1)
        row_index = torch.arange(col_index.numel()).type_as(col_index)
        corners_2d_scores[row_index, col_index] = 1
        corners_2d_scores = corners_2d_scores.view(-1, 8, 4)
        visibility = visibility.unsqueeze(-1) * corners_2d_scores

        N = encoded_corners_2d.shape[0]
        return torch.cat([
            encoded_corners_2d.contiguous().view(N, -1),
            visibility.view(N, -1),
            dist_argmin.float().view(N, -1)
        ],
                         dim=-1)
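A decode sketch for this layout (per box: 8 corners x 4 proposal-corner offsets x 2 coords = 64 values, then 32 visibility scores and 8 nearest-corner indices), keeping only the offset anchored at the nearest proposal corner:

    def decode(encoded, proposals):
        N = encoded.shape[0]
        offsets = encoded[:, :64].view(N, 8, 4, 2)
        visibility = encoded[:, 64:96].view(N, 8, 4)
        dist_argmin = encoded[:, 96:].long()  # shape(N, 8)

        proposals_xywh = geometry_utils.torch_xyxy_to_xywh(
            proposals.unsqueeze(0)).squeeze(0)
        label_corners_4c = geometry_utils.torch_xyxy_to_corner_4c(
            proposals.unsqueeze(0)).squeeze(0)
        wh = proposals_xywh[:, 2:].unsqueeze(1).unsqueeze(1)
        # undo (corners_2d - label_corners_4c) / wh
        corners_2d = offsets * wh + label_corners_4c.unsqueeze(1)
        # pick the offset anchored at the nearest proposal corner
        index = dist_argmin.view(N, 8, 1, 1).expand(-1, -1, 1, 2)
        corners_2d = corners_2d.gather(2, index).squeeze(2)
        return corners_2d, visibility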
Example #14
    def encode(label_boxes_3d, proposals, p2, image_info, label_boxes_2d):
        """
            projection points of 3d bbox center and its corners_3d in local
            coordinates frame

        Returns:
            depth of center:
            center 3d location:
            local_corners:
        """
        num_instances = label_boxes_3d.shape[0]
        # project the 3d corners into the image
        corners_2d = geometry_utils.torch_boxes_3d_to_corners_2d(
            label_boxes_3d, p2)

        proposals_xywh = geometry_utils.torch_xyxy_to_xywh(
            proposals.unsqueeze(0)).squeeze(0)
        wh = proposals_xywh[:, 2:].unsqueeze(1)
        xy = proposals_xywh[:, :2].unsqueeze(1)

        corners_3d = geometry_utils.torch_boxes_3d_to_corners_3d(
            label_boxes_3d)
        bottom_corners_3d = corners_3d[:, [0, 1, 2, 3]]
        visible_index = Corner3DCoder.find_visible_side(bottom_corners_3d)
        visible_corners_3d = tensor_utils.multidim_index(
            bottom_corners_3d, visible_index)
        visible_side_line_2d = geometry_utils.torch_points_3d_to_points_2d(
            visible_corners_3d.contiguous().view(-1, 3),
            p2).view(num_instances, -1, 2)
        visible_cond = (
            visible_side_line_2d[:, 1, 0] - visible_side_line_2d[:, 0, 0]
        ) * (visible_side_line_2d[:, 2, 0] - visible_side_line_2d[:, 0, 0]) < 0

        # visible_index[invisible_cond, -1] = visible_index[invisible_cond, -2]
        _, order = torch.sort(visible_side_line_2d[..., 0],
                              dim=-1,
                              descending=False)
        visible_index = tensor_utils.multidim_index(
            visible_index.unsqueeze(-1), order).squeeze(-1)

        bottom_corners = corners_2d[:, [0, 1, 2, 3]]
        top_corners = corners_2d[:, [4, 5, 6, 7]]
        bottom_corners = tensor_utils.multidim_index(bottom_corners,
                                                     visible_index)
        top_corners = tensor_utils.multidim_index(top_corners, visible_index)
        bottom_corners_3d = tensor_utils.multidim_index(
            bottom_corners_3d, visible_index)
        dist = torch.norm(bottom_corners_3d, dim=-1)
        merge_left_cond = dist[:, 0] < dist[:, 2]

        # clamp truncated corners into the 2d box
        # bottom corners
        # left
        bottom_corners[:, 0, 0] = torch.min(bottom_corners[:, 0, 0],
                                            label_boxes_2d[:, 2])
        bottom_corners[:, 0, 0] = torch.max(bottom_corners[:, 0, 0],
                                            label_boxes_2d[:, 0])

        # right
        bottom_corners[:, 2, 0] = torch.min(bottom_corners[:, 2, 0],
                                            label_boxes_2d[:, 2])
        bottom_corners[:, 2, 0] = torch.max(bottom_corners[:, 2, 0],
                                            label_boxes_2d[:, 0])

        # top
        top_corners[:, 0, 0] = torch.min(top_corners[:, 0, 0],
                                         label_boxes_2d[:, 2])
        top_corners[:, 0, 0] = torch.max(top_corners[:, 0, 0],
                                         label_boxes_2d[:, 0])

        top_corners[:, 2, 0] = torch.min(top_corners[:, 2, 0],
                                         label_boxes_2d[:, 2])
        top_corners[:, 2, 0] = torch.max(top_corners[:, 2, 0],
                                         label_boxes_2d[:, 0])

        in_box_cond = (bottom_corners[:, 1, 0] < label_boxes_2d[:, 2]) & (
            bottom_corners[:, 1, 0] > label_boxes_2d[:, 0])

        visibility = visible_cond.float() * in_box_cond.float()
        index = torch.nonzero(visibility <= 0).view(-1)
        tmp = bottom_corners[index]
        merge_left_cond = merge_left_cond[index]
        merge_right_cond = ~merge_left_cond
        tmp_left = torch.stack([tmp[:, 0], tmp[:, 0], tmp[:, 2]], dim=1)
        tmp_right = torch.stack([tmp[:, 0], tmp[:, 2], tmp[:, 2]], dim=1)
        tmp[merge_left_cond] = tmp_left[merge_left_cond]
        tmp[merge_right_cond] = tmp_right[merge_right_cond]
        bottom_corners[index] = tmp

        tmp = top_corners[index]
        tmp_left = torch.stack([tmp[:, 0], tmp[:, 0], tmp[:, 2]], dim=1)
        tmp_right = torch.stack([tmp[:, 0], tmp[:, 2], tmp[:, 2]], dim=1)
        tmp[merge_left_cond] = tmp_left[merge_left_cond]
        tmp[merge_right_cond] = tmp_right[merge_right_cond]
        top_corners[index] = tmp

        # encode
        encoded_bottom_corners = (bottom_corners - xy) / wh
        encoded_heights = (bottom_corners[..., 1] -
                           top_corners[..., 1]) / wh[..., 1]

        mid_x = bottom_corners[:, 1, 0]
        ratio = (mid_x - proposals[:, 0]) / wh[:, 0, 0]
        ratio = ratio.clamp(min=0, max=1)

        return torch.cat([
            encoded_bottom_corners.contiguous().view(num_instances, -1),
            encoded_heights.contiguous().view(num_instances, -1),
            ratio.view(num_instances, -1)
        ],
                         dim=-1)
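And the corresponding decode, undoing the proposal normalization (a sketch; by construction the top corners share their x coordinates with the bottom corners):

    def decode(encoded, proposals):
        num_instances = encoded.shape[0]
        proposals_xywh = geometry_utils.torch_xyxy_to_xywh(
            proposals.unsqueeze(0)).squeeze(0)
        wh = proposals_xywh[:, 2:].unsqueeze(1)
        xy = proposals_xywh[:, :2].unsqueeze(1)

        # undo (bottom_corners - xy) / wh
        bottom_corners = encoded[:, :6].view(num_instances, 3, 2) * wh + xy
        heights = encoded[:, 6:9] * wh[..., 1]
        top_corners = bottom_corners.clone()
        top_corners[..., 1] = bottom_corners[..., 1] - heights
        ratio = encoded[:, 9:]
        return bottom_corners, top_corners, ratio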