Exemplo n.º 1
0
    def ioun_tensorboard(self, visual_dict, batch, it):
        """Log IoU-head diagnostics for foreground RoIs to TensorBoard.

        Two groups of metrics are written at step ``it``:

        * tags without prefix for ``visual_dict['pred_boxes3d']``;
        * tags prefixed ``ref_`` for the same boxes after ``refine_box``
          applies ``visual_dict['rcnn_ref']``.

        Each group logs the fraction of foreground RoIs whose 3D IoU with
        their gt box exceeds 0.5 / 0.7. It also logs histograms of the
        predicted IoU (``visual_dict['rcnn_iou']``), of (3D IoU - predicted
        IoU), and of the 3D IoU itself.

        :param visual_dict: per-RoI ``rcnn_iou``, ``rcnn_ref`` and
            ``pred_boxes3d`` tensors, indexable by a boolean RoI mask.
        :param batch: dict with per-RoI ``cls`` labels and ``gt_boxes``; RoIs
            with ``cls > 0`` are treated as foreground.
        :param it: global step for TensorBoard.
        """
        valid_mask = batch['cls'].view(-1) > 0
        if not valid_mask.any():
            # No foreground RoI: the recall would divide by zero and
            # TensorBoard rejects empty histograms, so there is nothing to log.
            return
        num_valid = valid_mask.float().sum()

        pred_iou = visual_dict['rcnn_iou'][valid_mask].view(-1)
        rcnn_ref = visual_dict['rcnn_ref'][valid_mask]
        gt_boxes = batch['gt_boxes'][valid_mask].view(-1, 7)
        pred_boxes3d = visual_dict['pred_boxes3d'][valid_mask]

        def paired_iou3d(boxes3d):
            # boxes_iou3d_gpu returns pairwise (K, K) matrices; box i is
            # matched with gt box i, so only the diagonal is needed.
            _, iou3d = iou3d_utils.boxes_iou3d_gpu(boxes3d.view(-1, 7), gt_boxes)
            return torch.diagonal(iou3d)

        def log_group(prefix, iou3d):
            self.tb_log.add_scalar(prefix + 'recalled_0.5', (iou3d > 0.5).float().sum() / num_valid, it)
            self.tb_log.add_scalar(prefix + 'recalled_0.7', (iou3d > 0.7).float().sum() / num_valid, it)
            self.tb_log.add_histogram(prefix + 'pred_iou', pred_iou, it)
            self.tb_log.add_histogram(prefix + 'offset_iou', iou3d - pred_iou, it)
            self.tb_log.add_histogram(prefix + 'iou', iou3d, it)

        log_group('', paired_iou3d(pred_boxes3d))
        # The refined group logs the same predicted IoU as the unrefined one,
        # so its offset histogram compares the refined box against it.
        log_group('ref_', paired_iou3d(refine_box(pred_boxes3d, rcnn_ref)))
Exemplo n.º 2
0
    def aug_roi_by_noise_torch(self,
                               roi_boxes3d,
                               gt_boxes3d,
                               iou3d_src,
                               aug_times=10):
        """Randomly jitter each RoI and record its overlap with its gt box.

        Up to ``aug_times`` attempts are made per RoI. Each attempt keeps the
        original box with probability 0.2 and otherwise draws
        ``self.random_aug_box3d``. Attempts stop early once the overlap
        reported by ``iou3d_utils.boxes_iou3d_gpu`` reaches
        min(REG_FG_THRESH, CLS_FG_THRESH).

        ``roi_boxes3d`` is written in place (row ``k`` becomes the last
        candidate) and returned together with the per-RoI IoU.
        ``iou3d_src[k]`` is reused when no attempt was made or when the last
        attempt kept the original box.
        """
        n_rois = roi_boxes3d.shape[0]
        out_iou = torch.zeros(n_rois).type_as(gt_boxes3d)
        min_fg_iou = min(cfg.RCNN.REG_FG_THRESH, cfg.RCNN.CLS_FG_THRESH)

        for i in range(n_rois):
            base_box = roi_boxes3d[i]
            target_gt = gt_boxes3d[i].view(1, 7)
            candidate = base_box
            used_original = True
            last_iou, attempts = 0, 0
            while last_iou < min_fg_iou and attempts < aug_times:
                used_original = np.random.rand() < 0.2
                if used_original:
                    candidate = base_box
                else:
                    candidate = self.random_aug_box3d(base_box)
                candidate = candidate.view((1, 7))
                last_iou = iou3d_utils.boxes_iou3d_gpu(candidate, target_gt)[0][0]
                attempts += 1
            roi_boxes3d[i] = candidate.view(-1)
            out_iou[i] = last_iou if (attempts > 0 and not used_original) else iou3d_src[i]
        return roi_boxes3d, out_iou
Exemplo n.º 3
0
    def rcnn_tensorboard(self, visual_dict, batch, it):
        """Log box statistics of foreground RoIs to TensorBoard.

        For RoIs with ``batch['cls'] > 0`` this writes:

        * histograms of the gt box parameters (``*_label``);
        * histograms of the predicted box parameters (``trans_*``);
        * histograms of their per-parameter differences (``*_offset``);
        * histograms of the paired BEV / 3D IoU;
        * the recall at 3D IoU 0.5 / 0.7.

        :param visual_dict: dict with per-RoI ``pred_boxes3d``.
        :param batch: dict with per-RoI ``cls`` labels and ``gt_boxes``
            (indexed as ``gt_boxes[:, 0, j]``, i.e. one gt box per RoI).
        :param it: global step for TensorBoard.
        """
        valid_mask = batch['cls'].view(-1) > 0
        if not valid_mask.any():
            # No foreground RoI: recall would divide by zero and TensorBoard
            # rejects empty histograms.
            return
        num_valid = valid_mask.float().sum()

        gt_boxes = batch['gt_boxes'][valid_mask]
        pred_boxes3d = visual_dict['pred_boxes3d'][valid_mask].view(-1, 1, 7)

        # (tag, column) for the box layout [x, y, z, h, w, l, ry]; ry is
        # logged separately because it is wrapped with a modulo.
        fields = (('x', 0), ('z', 2), ('y', 1), ('h', 3), ('w', 4), ('l', 5))

        for name, col in fields:
            self.tb_log.add_histogram(name + '_label', gt_boxes[:, 0, col], it)
        self.tb_log.add_histogram('ry_label', gt_boxes[:, 0, 6] % np.pi, it)

        iou2d, iou3d = iou3d_utils.boxes_iou3d_gpu(pred_boxes3d.squeeze(1), gt_boxes.squeeze(1))
        # Box i is paired with gt box i: keep the diagonal of the (K, K) matrices.
        iou2d = torch.diagonal(iou2d)
        iou3d = torch.diagonal(iou3d)

        for name, col in fields:
            self.tb_log.add_histogram('trans_' + name, pred_boxes3d[:, :, col], it)
        self.tb_log.add_histogram('trans_ry', pred_boxes3d[:, :, 6] % np.pi, it)

        for name, col in fields:
            self.tb_log.add_histogram(name + '_offset', pred_boxes3d[:, 0, col] - gt_boxes[:, 0, col], it)
        two_pi = np.pi * 2
        self.tb_log.add_histogram('ry_offset', pred_boxes3d[:, 0, 6] % two_pi - gt_boxes[:, 0, 6] % two_pi, it)

        self.tb_log.add_histogram('iou2d', iou2d, it)
        self.tb_log.add_histogram('iou3d', iou3d, it)

        # Normalise by the number of foreground RoIs rather than by the label
        # sum, so non-positive labels cannot distort the denominator.
        self.tb_log.add_scalar('recalled_0.5', (iou3d > 0.5).float().sum() / num_valid, it)
        self.tb_log.add_scalar('recalled_0.7', (iou3d > 0.7).float().sum() / num_valid, it)
Exemplo n.º 4
0
    def aug_roi_by_noise_torch(self,
                               roi_boxes3d,
                               gt_boxes3d,
                               iou3d_src,
                               aug_times=10):
        """Randomly jitter each RoI and record its overlap with its gt box.

        ``gt_boxes3d`` rows hold 8 values; only the first 7 (box geometry)
        are passed to ``iou3d_utils.boxes_iou3d_gpu``.

        Up to ``aug_times`` attempts are made per RoI. Each attempt keeps the
        original box with probability 0.2 and otherwise draws
        ``self.random_aug_box3d``. Attempts stop early once the overlap
        reaches min(REG_FG_THRESH, CLS_FG_THRESH).

        ``roi_boxes3d`` is written in place and returned with the per-RoI
        IoU. ``iou3d_src[k]`` is reused when no attempt was made or when the
        last attempt kept the original box.
        """
        n_rois = roi_boxes3d.shape[0]
        out_iou = torch.zeros(n_rois).type_as(gt_boxes3d)
        min_fg_iou = min(cfg.RCNN.REG_FG_THRESH, cfg.RCNN.CLS_FG_THRESH)

        for i in range(n_rois):
            base_box = roi_boxes3d[i]
            target_gt = gt_boxes3d[i].view(1, 8)[:, 0:7]
            candidate = base_box
            used_original = True
            last_iou, attempts = 0, 0
            while last_iou < min_fg_iou and attempts < aug_times:
                used_original = np.random.rand() < 0.2
                if used_original:
                    candidate = base_box
                else:
                    candidate = self.random_aug_box3d(base_box)
                candidate = candidate.view((1, 7))
                last_iou = iou3d_utils.boxes_iou3d_gpu(candidate, target_gt)[0][0]
                attempts += 1
            roi_boxes3d[i] = candidate.view(-1)
            out_iou[i] = last_iou if (attempts > 0 and not used_original) else iou3d_src[i]
        return roi_boxes3d, out_iou
Exemplo n.º 5
0
    def sample_rois_for_rcnn(self, roi_boxes3d, gt_boxes3d):
        """
        Sample a fixed number of foreground/background RoIs per scene for RCNN training.

        A RoI is foreground when its max 3D IoU with the scene's gt boxes is
        >= min(REG_FG_THRESH, CLS_FG_THRESH). Background RoIs are split into:

        * "easy": IoU < CLS_BG_THRESH_LO;
        * "hard": CLS_BG_THRESH_LO <= IoU < CLS_BG_THRESH.

        Sampled RoIs are jittered with ``aug_roi_by_noise_torch``.

        :param roi_boxes3d: (B, M, 7)
        :param gt_boxes3d: (B, N, 8) [x, y, z, h, w, l, ry, cls]; trailing all-zero
            rows are treated as padding and dropped
        :return
            batch_rois: (B, ROI_PER_IMAGE, 7)
            batch_gt_of_rois: (B, ROI_PER_IMAGE, C) with C = gt_boxes3d.size(-1)
            batch_roi_iou: (B, ROI_PER_IMAGE)
        :raises NotImplementedError: if a scene has neither fg nor bg RoIs
        """
        batch_size = roi_boxes3d.size(0)

        fg_rois_per_image = int(
            np.round(cfg.RCNN.FG_RATIO * cfg.RCNN.ROI_PER_IMAGE))
        fg_thresh = min(cfg.RCNN.REG_FG_THRESH, cfg.RCNN.CLS_FG_THRESH)

        batch_rois = gt_boxes3d.new(batch_size, cfg.RCNN.ROI_PER_IMAGE,
                                    7).zero_()
        # gt rows keep all their columns (including cls), so size the buffer
        # from gt_boxes3d instead of hard-coding 7.
        batch_gt_of_rois = gt_boxes3d.new(batch_size, cfg.RCNN.ROI_PER_IMAGE,
                                          gt_boxes3d.size(-1)).zero_()
        batch_roi_iou = gt_boxes3d.new(batch_size,
                                       cfg.RCNN.ROI_PER_IMAGE).zero_()

        for idx in range(batch_size):
            cur_roi, cur_gt = roi_boxes3d[idx], gt_boxes3d[idx]

            # strip trailing zero-padded gt rows
            k = len(cur_gt) - 1
            while cur_gt[k].sum() == 0:
                k -= 1
            cur_gt = cur_gt[:k + 1]

            iou3d = iou3d_utils.boxes_iou3d_gpu(cur_roi, cur_gt[:, 0:7])  # (M, N)

            max_overlaps, gt_assignment = torch.max(iou3d, dim=1)

            # sample fg, easy_bg, hard_bg
            fg_inds = torch.nonzero((max_overlaps >= fg_thresh)).view(-1)

            # TODO: this will mix the fg and bg when CLS_BG_THRESH_LO < iou < CLS_BG_THRESH
            # fg_inds = torch.cat((fg_inds, roi_assignment), dim=0)  # consider the roi which has max_iou with gt as fg

            easy_bg_inds = torch.nonzero(
                (max_overlaps < cfg.RCNN.CLS_BG_THRESH_LO)).view(-1)
            hard_bg_inds = torch.nonzero(
                (max_overlaps < cfg.RCNN.CLS_BG_THRESH)
                & (max_overlaps >= cfg.RCNN.CLS_BG_THRESH_LO)).view(-1)

            fg_num_rois = fg_inds.numel()
            bg_num_rois = hard_bg_inds.numel() + easy_bg_inds.numel()

            if fg_num_rois > 0 and bg_num_rois > 0:
                # sampling fg without replacement
                fg_rois_per_this_image = min(fg_rois_per_image, fg_num_rois)

                rand_num = torch.from_numpy(np.random.permutation(
                    fg_num_rois)).type_as(gt_boxes3d).long()
                fg_inds = fg_inds[rand_num[:fg_rois_per_this_image]]

                # sampling bg to fill the rest
                bg_rois_per_this_image = cfg.RCNN.ROI_PER_IMAGE - fg_rois_per_this_image
                bg_inds = self.sample_bg_inds(hard_bg_inds, easy_bg_inds,
                                              bg_rois_per_this_image)

            elif fg_num_rois > 0 and bg_num_rois == 0:
                # sampling fg with replacement to fill the whole budget
                rand_num = np.floor(
                    np.random.rand(cfg.RCNN.ROI_PER_IMAGE) * fg_num_rois)
                rand_num = torch.from_numpy(rand_num).type_as(
                    gt_boxes3d).long()
                fg_inds = fg_inds[rand_num]
                fg_rois_per_this_image = cfg.RCNN.ROI_PER_IMAGE
                bg_rois_per_this_image = 0
            elif bg_num_rois > 0 and fg_num_rois == 0:
                # sampling bg only
                bg_rois_per_this_image = cfg.RCNN.ROI_PER_IMAGE
                bg_inds = self.sample_bg_inds(hard_bg_inds, easy_bg_inds,
                                              bg_rois_per_this_image)

                fg_rois_per_this_image = 0
            else:
                # Every RoI falls between CLS_BG_THRESH and the fg threshold.
                raise NotImplementedError(
                    'sample %d: no foreground or background RoI to sample '
                    '(max overlaps all in [CLS_BG_THRESH, fg_thresh))' % idx)

            # augment the rois by noise
            roi_list, roi_iou_list, roi_gt_list = [], [], []
            if fg_rois_per_this_image > 0:
                fg_rois_src = cur_roi[fg_inds]
                gt_of_fg_rois = cur_gt[gt_assignment[fg_inds]]
                iou3d_src = max_overlaps[fg_inds]
                fg_rois, fg_iou3d = self.aug_roi_by_noise_torch(
                    fg_rois_src,
                    gt_of_fg_rois,
                    iou3d_src,
                    aug_times=cfg.RCNN.ROI_FG_AUG_TIMES)
                roi_list.append(fg_rois)
                roi_iou_list.append(fg_iou3d)
                roi_gt_list.append(gt_of_fg_rois)

            if bg_rois_per_this_image > 0:
                bg_rois_src = cur_roi[bg_inds]
                gt_of_bg_rois = cur_gt[gt_assignment[bg_inds]]
                iou3d_src = max_overlaps[bg_inds]
                aug_times = 1 if cfg.RCNN.ROI_FG_AUG_TIMES > 0 else 0
                bg_rois, bg_iou3d = self.aug_roi_by_noise_torch(
                    bg_rois_src, gt_of_bg_rois, iou3d_src, aug_times=aug_times)
                roi_list.append(bg_rois)
                roi_iou_list.append(bg_iou3d)
                roi_gt_list.append(gt_of_bg_rois)

            batch_rois[idx] = torch.cat(roi_list, dim=0)
            batch_gt_of_rois[idx] = torch.cat(roi_gt_list, dim=0)
            batch_roi_iou[idx] = torch.cat(roi_iou_list, dim=0)

        return batch_rois, batch_gt_of_rois, batch_roi_iou
    def aug_one_scene(self, sample_id, pts_rect, pts_intensity, all_gt_boxes3d):
        """
        Paste objects sampled from the gt database into one scene.

        Each candidate object goes through these steps:

        1. It is moved vertically onto the scene's road plane.
        2. It is rejected if its box overlaps any box already in the scene.
           Existing and newly added boxes are enlarged by 0.5 in columns 4
           and 5 for this test.
        3. Scene points inside the new box, with its column 3 enlarged by 2,
           are removed and replaced by the object's points.

        :param sample_id: scene id, used to look up the road plane
        :param pts_rect: (N, 3) points of the scene
        :param pts_intensity: per-point intensities aligned with pts_rect
        :param all_gt_boxes3d: (M, 7) gt boxes already present in the scene
        :return: (success, pts_rect, pts_intensity, extra_gt_boxes3d, extra_gt_obj_list);
            when no object is pasted, the inputs are returned with None, None
        """
        assert self.gt_database is not None
        extra_gt_num = np.random.randint(10, 15)
        try_times = 50
        cnt = 0
        cur_gt_boxes3d = all_gt_boxes3d.copy()
        cur_gt_boxes3d[:, 4] += 0.5
        cur_gt_boxes3d[:, 5] += 0.5  # enlarge new added box to avoid too nearby boxes

        extra_gt_obj_list = []
        extra_gt_boxes3d_list = []
        new_pts_list, new_pts_intensity_list = [], []
        src_pts_flag = np.ones(pts_rect.shape[0], dtype=np.int32)

        road_plane = self.get_road_plane(sample_id)
        a, b, c, d = road_plane

        while try_times > 0:
            try_times -= 1

            # randint's upper bound is exclusive: len() makes every entry reachable
            rand_idx = np.random.randint(0, len(self.gt_database))

            new_gt_dict = self.gt_database[rand_idx]
            new_gt_box3d = new_gt_dict['gt_box3d'].copy()
            new_gt_points = new_gt_dict['points'].copy()
            new_gt_intensity = new_gt_dict['intensity'].copy()
            new_gt_obj = new_gt_dict['obj']
            center = new_gt_box3d[0:3]
            if PC_REDUCE_BY_RANGE and (self.check_pc_range(center) is False):
                continue
            if cnt > extra_gt_num:
                break
            if len(new_gt_points) < 5:  # too few points
                continue

            # put it on the road plane: solve a*x + b*y + c*z + d = 0 for y
            cur_height = (-d - a * center[0] - c * center[2]) / b
            move_height = new_gt_box3d[1] - cur_height
            new_gt_box3d[1] -= move_height
            new_gt_points[:, 1] -= move_height

            cnt += 1

            # With no existing box there is nothing to collide with (and the
            # max of an empty IoU result would raise).
            if cur_gt_boxes3d.shape[0] > 0:
                iou3d = iou3d_utils.boxes_iou3d_gpu(torch.from_numpy(new_gt_box3d.reshape(1, 7)).cuda(),
                                                    torch.from_numpy(cur_gt_boxes3d).cuda()).cpu().numpy()
                valid_flag = iou3d.max() < 1e-8
                if not valid_flag:
                    continue

            enlarged_box3d = new_gt_box3d.copy()
            enlarged_box3d[3] += 2  # remove the points above and below the object
            boxes_pts_mask_list = roipool3d_utils.pts_in_boxes3d_cpu(torch.from_numpy(pts_rect),
                                                                     torch.from_numpy(enlarged_box3d.reshape(1, 7)))
            pt_mask_flag = (boxes_pts_mask_list[0].numpy() == 1)
            src_pts_flag[pt_mask_flag] = 0  # remove the original points which are inside the new box

            new_pts_list.append(new_gt_points)
            new_pts_intensity_list.append(new_gt_intensity)
            enlarged_box3d = new_gt_box3d.copy()
            enlarged_box3d[4] += 0.5
            enlarged_box3d[5] += 0.5  # enlarge new added box to avoid too nearby boxes
            cur_gt_boxes3d = np.concatenate((cur_gt_boxes3d, enlarged_box3d.reshape(1, 7)), axis=0)
            extra_gt_boxes3d_list.append(new_gt_box3d.reshape(1, 7))
            extra_gt_obj_list.append(new_gt_obj)

        if not new_pts_list:
            return False, pts_rect, pts_intensity, None, None

        extra_gt_boxes3d = np.concatenate(extra_gt_boxes3d_list, axis=0)
        # remove original points and add new points
        pts_rect = pts_rect[src_pts_flag == 1]
        pts_intensity = pts_intensity[src_pts_flag == 1]
        new_pts_rect = np.concatenate(new_pts_list, axis=0)
        new_pts_intensity = np.concatenate(new_pts_intensity_list, axis=0)
        pts_rect = np.concatenate((pts_rect, new_pts_rect), axis=0)
        pts_intensity = np.concatenate((pts_intensity, new_pts_intensity), axis=0)

        return True, pts_rect, pts_intensity, extra_gt_boxes3d, extra_gt_obj_list
Exemplo n.º 7
0
    def get_ioun_loss(model, ret_dict, tb_dict, visual_dict, input_data):
        """Compute the IoU-head loss and log its parts.

        The loss is the sum of two terms:

        * ``loss_reg``: smooth-L1 regression of ``rcnn_ref`` towards the
          normalised residual between predicted and gt boxes. It uses
          foreground RoIs only (``cls > 0``); location and size are scaled
          by 300, angle by 20.
        * ``loss_iou``: MSE (x100) between ``rcnn_iou`` and the squared 3D
          IoU of each refined box with its gt box. It covers every RoI whose
          gt box is not all-zero.

        Scalar values are written to ``tb_dict``. The raw head outputs and
        boxes are copied into ``visual_dict``.

        :param model: unused; kept for interface compatibility.
        :param input_data: unused; kept for interface compatibility.
        :return: scalar tensor ``loss_iou + loss_reg``
        """
        rcnn_iou = ret_dict['rcnn_iou'].clone()
        rcnn_ref = ret_dict['rcnn_ref'].clone()
        gt_boxes3d = ret_dict['gt_boxes'].clone().view(-1, 7)
        pred_boxes3d = ret_dict['pred_boxes3d'].clone().view(-1, 7)
        refined_boxes3d = ret_dict['refined_box'].clone().view(-1, 7)
        fg_mask = ret_dict['cls'].float().view(-1) > 0
        iou_loss_dict = {}

        # box refinement regression on foreground RoIs
        if fg_mask.any():
            fg_ref = rcnn_ref[fg_mask]
            fg_pred = pred_boxes3d[fg_mask]
            fg_gt = gt_boxes3d[fg_mask]

            loc_pred, siz_pred, ang_pred = fg_pred[:, :3], fg_pred[:, 3:6], fg_pred[:, 6]
            loc_label, siz_label, ang_label = fg_gt[:, :3], fg_gt[:, 3:6], fg_gt[:, 6]

            # location residual normalised by the predicted size
            loss_loc = F.smooth_l1_loss(
                fg_ref[:, :3], (loc_label - loc_pred) / siz_pred) * 300

            # relative size residual
            loss_siz = F.smooth_l1_loss(
                fg_ref[:, 3:6], (siz_label - siz_pred) / siz_pred) * 300

            # heading residual with both angles wrapped to [0, pi)
            angle_residual = ang_label % np.pi - ang_pred % np.pi
            loss_ang = F.smooth_l1_loss(fg_ref[:, 6], angle_residual) * 20
        else:
            loss_loc = torch.zeros(()).cuda()
            loss_siz = torch.zeros(()).cuda()
            loss_ang = torch.zeros(()).cuda()

        loss_reg = loss_loc + loss_siz + loss_ang

        # IoU regression ("range MSE") over RoIs whose gt row is not all-zero
        has_gt_mask = gt_boxes3d.sum(-1) != 0
        if has_gt_mask.any():
            _, iou3d = iou3d_utils.boxes_iou3d_gpu(refined_boxes3d, gt_boxes3d)
            # box i is paired with gt box i: keep the diagonal of the matrix
            iou3d_label = torch.diagonal(iou3d).detach().pow(2)
            loss_iou = F.mse_loss(rcnn_iou[has_gt_mask].view(-1),
                                  iou3d_label[has_gt_mask].view(-1)) * 100
        else:
            # mse_loss over an empty selection would be NaN
            loss_iou = torch.zeros(()).cuda()

        iou_loss_dict['ioun_loss_loc'] = loss_loc.item()
        iou_loss_dict['ioun_loss_siz'] = loss_siz.item()
        iou_loss_dict['ioun_loss_ang'] = loss_ang.item()
        iou_loss_dict['loss_iou'] = loss_iou.item()
        iou_loss_dict['loss_reg'] = loss_reg.item()

        tb_dict.update(iou_loss_dict)

        rcnn_loss_iou = loss_iou + loss_reg

        tb_dict['rcnn_loss_iou'] = rcnn_loss_iou.item()
        visual_dict['rcnn_iou'] = ret_dict['rcnn_iou'].clone()
        visual_dict['rcnn_ref'] = ret_dict['rcnn_ref'].clone()
        visual_dict['pred_boxes3d'] = ret_dict['pred_boxes3d'].clone().view(
            -1, 7)
        visual_dict['refined_box'] = ret_dict['refined_box'].clone().view(
            -1, 7)
        return rcnn_loss_iou
Exemplo n.º 8
0
    def get_rcnn_loss(model, ret_dict, tb_dict, visual_dict):
        """Compute the RCNN classification + box regression loss and log it.

        Regression is computed on foreground RoIs (``cls > 0``):

        * bin-based loc / angle / size losses from
          ``loss_utils.get_rcnn_reg_loss``, with loc scaled by 20 and size
          by 300;
        * for foreground RoIs whose paired 3D IoU exceeds 0.5, a corner loss
          (x10, added to the total) and an IoU loss (logged only).

        Classification is a sigmoid BCE averaged over RoIs with ``cls >= 0``.

        :param model: unused; kept for interface compatibility.
        :return: scalar tensor ``rcnn_loss_cls + rcnn_loss_reg + corner_loss``
        :raises NotImplementedError: for an unsupported ``cfg.RCNN.LOSS_CLS``
        """
        rcnn_cls, rcnn_reg = ret_dict['rcnn_cls'], ret_dict['rcnn_reg']
        batch_size = rcnn_reg.shape[0]

        gt_boxes3d = ret_dict['gt_boxes'].clone().view(batch_size, 7)
        cls_label = ret_dict['cls'].float().view(-1)
        pred_boxes3d = ret_dict['pred_boxes3d'].clone().view(-1, 7)

        # rcnn regression loss (foreground RoIs only)
        fg_mask = cls_label > 0
        if fg_mask.any():
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_rcnn_reg_loss(rcnn_reg.view(batch_size, -1)[fg_mask],
                                             gt_boxes3d[fg_mask],
                                             loc_scope=cfg.RCNN.LOC_SCOPE,
                                             loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                             num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                             anchor_size=MEAN_SIZE,
                                             get_xz_fine=cfg.RCNN.LOC_XZ_FINE, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                             loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                             get_ry_fine=False)

            # paired 3D IoU: box i vs gt box i is the diagonal of the matrix
            _, iou3d = iou3d_utils.boxes_iou3d_gpu(pred_boxes3d[fg_mask],
                                                   gt_boxes3d[fg_mask])
            iou_mask = torch.diagonal(iou3d).detach() > 0.5
            if iou_mask.any():
                # corner loss: distance to the nearer of the gt corners and the
                # corners of the gt box rotated by pi
                gt_boxes3d_fcorner = gt_boxes3d[fg_mask][iou_mask]
                pred_corner = boxes3d_to_corners3d_torch(
                    pred_boxes3d[fg_mask][iou_mask])
                gt_corner = boxes3d_to_corners3d_torch(gt_boxes3d_fcorner)
                gt_boxes3d_fcorner[:, 6] += np.pi
                gt_flip_corner = boxes3d_to_corners3d_torch(gt_boxes3d_fcorner)
                corner_dist = torch.min(
                    torch.norm(pred_corner - gt_corner, dim=-1),
                    torch.norm(pred_corner - gt_flip_corner, dim=-1))
                corner_loss = F.smooth_l1_loss(corner_dist,
                                               torch.zeros_like(corner_dist))

                # IoU loss: logged only, not added to the total loss
                gious_loss = object_ious_3d_loss(
                    gt_boxes3d[fg_mask][iou_mask],
                    pred_boxes3d[fg_mask][iou_mask])
            else:
                corner_loss = torch.zeros(()).cuda()
                gious_loss = torch.zeros(()).cuda()

            loss_loc = loss_loc * 20
            loss_size = loss_size * 300
            corner_loss = corner_loss * 10
            rcnn_loss_reg = loss_loc + loss_angle + loss_size
            reg_loss_dict['loss_loc'] = loss_loc
            reg_loss_dict['loss_angle'] = loss_angle
            reg_loss_dict['loss_size'] = loss_size
            reg_loss_dict['loss_corner'] = corner_loss
            reg_loss_dict['loss_giou'] = gious_loss

        else:
            loss_loc = torch.zeros(()).cuda()
            loss_angle = torch.zeros(()).cuda()
            loss_size = torch.zeros(()).cuda()
            rcnn_loss_reg = torch.zeros(()).cuda()
            corner_loss = torch.zeros(()).cuda()
            gious_loss = torch.zeros(()).cuda()

        # rcnn classification loss
        if cfg.RCNN.LOSS_CLS == 'BinaryCrossEntropy':
            rcnn_cls_flat = rcnn_cls.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat), cls_label, reduction='none')
            # negative labels are excluded from the average
            cls_valid_mask = (cls_label >= 0).float()
            rcnn_loss_cls = (batch_loss_cls *
                             cls_valid_mask).sum() / torch.clamp(
                                 cls_valid_mask.sum(), min=1.0)
        else:
            raise NotImplementedError

        # Always bind the total: it used to be assigned only when
        # cfg.RCNN.ENABLED was set, leaving it unbound for the logging below.
        rcnn_loss = rcnn_loss_cls + rcnn_loss_reg + corner_loss

        tb_dict['rcnn_loss_cls'] = rcnn_loss_cls.item()
        tb_dict['rcnn_loss_reg'] = rcnn_loss_reg.item()
        tb_dict['rcnn_loss'] = rcnn_loss.item()

        tb_dict['rcnn_loss_loc'] = loss_loc.item()
        tb_dict['rcnn_loss_angle'] = loss_angle.item()
        tb_dict['rcnn_loss_size'] = loss_size.item()
        tb_dict['rcnn_loss_corner'] = corner_loss.item()
        tb_dict['rcnn_loss_giou'] = gious_loss.item()

        tb_dict['rcnn_cls_fg'] = (cls_label > 0).sum().item()
        tb_dict['rcnn_cls_bg'] = (cls_label == 0).sum().item()

        visual_dict['rcnn_cls'] = rcnn_cls
        visual_dict['rcnn_reg'] = rcnn_reg
        visual_dict['pred_boxes3d'] = ret_dict['pred_boxes3d'].clone().view(
            -1, 7)

        return rcnn_loss
Exemplo n.º 9
0
    def forward(self, input_data):
        """Run the 2nd and 3rd cascade stages on top of the 1st-stage boxes.

        The boxes in ``input_data['pred_boxes3d_1st']`` are the RoIs of
        stage 2, and the boxes decoded by stage 2 are the RoIs of stage 3.
        In training mode each stage re-samples its RoIs with
        ``self.proposal_target_layer`` and computes its losses. In eval mode
        the RoIs are pooled with ``self.roipooling``.

        :param input_data: input dict. Read directly here:
            ``'pred_boxes3d_1st'``, and ``'roi_boxes3d'`` (only its
            ``size(0)`` is used, as the batch size). Remaining keys are
            consumed by ``proposal_target_layer`` / ``roipooling``.
        :return: dict with the following entries.

            - ``'rcnn_cls_2nd'``, ``'rcnn_reg_2nd'``, ``'rcnn_cls_3rd'``,
              ``'rcnn_reg_3rd'``.
            - ``'pred_boxes3d_1st'``, plus ``'pred_boxes3d_2nd'`` and
              ``'pred_boxes3d_3rd'`` viewed as ``(batch_size, -1, 7)``.
            - ``'pre_iou2'`` and ``'pre_iou3'``: IoU-head outputs mapped
              with ``x / 2 + 0.5``.
            - Training mode only: ``'rcnn_loss_cls_2nd'``,
              ``'rcnn_loss_reg_2nd'``, ``'rcnn_loss_cls_3rd'``,
              ``'rcnn_loss_reg_3rd'`` and ``'rcnn_iou_loss'``.
        """
        input_data2 = input_data.copy()
        pred_boxes3d_1st = input_data2['pred_boxes3d_1st']
        batch_size = input_data['roi_boxes3d'].size(0)

        # ---------------- stage 2: RoIs are the 1st-stage boxes ----------------
        input_data2['roi_boxes3d'] = pred_boxes3d_1st
        if self.training:
            with torch.no_grad():
                target_dict_2nd = self.proposal_target_layer(input_data2,
                                                             stage=2)
            pts_input_2 = torch.cat((target_dict_2nd['sampled_pts'],
                                     target_dict_2nd['pts_feature']),
                                    dim=2)
            target_dict_2nd['pts_input'] = pts_input_2
            roi = target_dict_2nd['roi_boxes3d']
        else:
            roi = pred_boxes3d_1st
            pts_input_2 = self.roipooling(input_data2)

        xyz_2, features_2 = self._break_up_pc(pts_input_2)
        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input_2 = pts_input_2[
                ..., 0:self.rcnn_input_channel].transpose(1, 2).unsqueeze(dim=3)
            xyz_feature_2 = self.xyz_up_layer(xyz_input_2)
            rpn_feature_2 = pts_input_2[
                ..., self.rcnn_input_channel:].transpose(1, 2).unsqueeze(dim=3)
            merged_feature_2 = torch.cat((xyz_feature_2, rpn_feature_2), dim=1)
            merged_feature_2 = self.merge_down_layer(merged_feature_2)
            l_xyz_2, l_features_2 = [xyz_2], [merged_feature_2.squeeze(dim=3)]
        else:
            # Was bound to ``l_xyz__2`` (typo), which left ``l_xyz_2``
            # undefined for the SA loop below.
            l_xyz_2, l_features_2 = [xyz_2], [features_2]
        for i in range(len(self.SA_modules)):
            li_xyz_2, li_features_2 = self.SA_modules[i](l_xyz_2[i],
                                                         l_features_2[i])
            l_xyz_2.append(li_xyz_2)
            l_features_2.append(li_features_2)

        # Number of RoIs fed through stage 2 (leading dim of its outputs).
        num_rois_2nd = pts_input_2.shape[0]
        anchor_size = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
        rcnn_cls_2nd = self.cls_layer_2nd(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg_2nd = self.reg_layer_2nd(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        pre_iou2 = self.iou_layer(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)

        if self.training:
            cls_label = target_dict_2nd['cls_label'].float()
            rcnn_cls_flat = rcnn_cls_2nd.view(-1)
            cls_label_flat = cls_label.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat), cls_label_flat, reduction='none')
            # RoIs with a negative label are excluded from the cls loss.
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / \
                torch.clamp(cls_valid_mask.sum(), min=1.0)

            gt_boxes3d_ct = target_dict_2nd['gt_of_rois']
            reg_valid_mask = target_dict_2nd['reg_valid_mask']
            fg_mask = (reg_valid_mask > 0)
            # Without any foreground RoI, regress on the other RoIs instead
            # (presumably to avoid an empty selection in get_reg_loss).
            if rcnn_reg_2nd.view(num_rois_2nd, -1)[fg_mask].size(0) == 0:
                fg_mask = (reg_valid_mask <= 0)
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg_2nd.view(num_rois_2nd, -1)[fg_mask],
                                        gt_boxes3d_ct.view(num_rois_2nd, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)
            rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size

            two = {
                'rcnn_loss_cls_2nd': rcnn_loss_cls,
                'rcnn_loss_reg_2nd': rcnn_loss_reg
            }
        else:
            two = {}

        sec = {'rcnn_cls_2nd': rcnn_cls_2nd, 'rcnn_reg_2nd': rcnn_reg_2nd}

        pred_boxes3d_2nd = decode_bbox_target(
            roi.view(-1, 7),
            rcnn_reg_2nd.view(-1, rcnn_reg_2nd.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)

        # ---------------- stage 3: RoIs are the 2nd-stage boxes ----------------
        input_data3 = input_data.copy()
        input_data3['roi_boxes3d'] = pred_boxes3d_2nd
        if self.training:
            with torch.no_grad():
                target_dict_3rd = self.proposal_target_layer(input_data3,
                                                             stage=3)
            pts_input_3 = torch.cat((target_dict_3rd['sampled_pts'],
                                     target_dict_3rd['pts_feature']),
                                    dim=2)
            target_dict_3rd['pts_input'] = pts_input_3
            roi = target_dict_3rd['roi_boxes3d']
        else:
            roi = pred_boxes3d_2nd
            pts_input_3 = self.roipooling(input_data3)

        xyz_3, features_3 = self._break_up_pc(pts_input_3)
        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input_3 = pts_input_3[
                ..., 0:self.rcnn_input_channel].transpose(1, 2).unsqueeze(dim=3)
            xyz_feature_3 = self.xyz_up_layer_3(xyz_input_3)
            rpn_feature_3 = pts_input_3[
                ..., self.rcnn_input_channel:].transpose(1, 2).unsqueeze(dim=3)
            merged_feature_3 = torch.cat((xyz_feature_3, rpn_feature_3), dim=1)
            merged_feature_3 = self.merge_down_layer_3(merged_feature_3)
            l_xyz_3, l_features_3 = [xyz_3], [merged_feature_3.squeeze(dim=3)]
        else:
            # Was bound to ``l_xyz, l_features``, leaving the stage-3 lists
            # used by the SA loop below undefined.
            l_xyz_3, l_features_3 = [xyz_3], [features_3]

        for i in range(len(self.SA_modules_3)):
            li_xyz_3, li_features_3 = self.SA_modules_3[i](l_xyz_3[i],
                                                           l_features_3[i])
            l_xyz_3.append(li_xyz_3)
            l_features_3.append(li_features_3)
        del xyz_2, features_2, l_features_2

        # Number of RoIs fed through stage 3. This may differ from the
        # stage-2 count, so the stage-3 reshapes must not reuse it.
        num_rois_3rd = pts_input_3.shape[0]
        rcnn_cls_3rd = self.cls_layer_3rd(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg_3rd = self.reg_layer_3rd(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        pre_iou3 = self.iou_layer(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)

        if self.training:
            cls_label = target_dict_3rd['cls_label'].float()
            rcnn_cls_flat = rcnn_cls_3rd.view(-1)
            cls_label_flat = cls_label.view(-1)
            # Target is flattened to match the flattened logits, as in stage 2.
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat), cls_label_flat, reduction='none')
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / \
                torch.clamp(cls_valid_mask.sum(), min=1.0)

            gt_boxes3d_ct = target_dict_3rd['gt_of_rois']
            reg_valid_mask = target_dict_3rd['reg_valid_mask']
            fg_mask = (reg_valid_mask > 0)
            if rcnn_reg_3rd.view(num_rois_3rd, -1)[fg_mask].size(0) == 0:
                fg_mask = (reg_valid_mask <= 0)
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg_3rd.view(num_rois_3rd, -1)[fg_mask],
                                        gt_boxes3d_ct.view(num_rois_3rd, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)
            rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size
        else:
            three = {}

        pred_boxes3d_3rd = decode_bbox_target(
            roi.view(-1, 7),
            rcnn_reg_3rd.view(-1, rcnn_reg_3rd.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)

        if self.training:
            gt = target_dict_3rd['real_gt']
            flat_pred_3rd = pred_boxes3d_3rd.view(-1, 7)
            # Per-RoI IoU between each stage-3 box and its matched real GT.
            iou_label = []
            for i in range(num_rois_3rd):
                iou_label.append(
                    iou3d_utils.boxes_iou3d_gpu(flat_pred_3rd[i].view(1, 7),
                                                gt[i].view(1, 7)))
            iou_label = torch.cat(iou_label)
            # Train the IoU head on targets mapped from [0, 1] to [-1, 1].
            # pre_iou3 is mapped back with ``/ 2 + 0.5`` before returning.
            iou_label = (iou_label - 0.5) * 2
            iou_loss = F.mse_loss(pre_iou3[fg_mask], iou_label[fg_mask])
            three = {
                'rcnn_loss_cls_3rd': rcnn_loss_cls,
                'rcnn_loss_reg_3rd': rcnn_loss_reg,
                'rcnn_iou_loss': iou_loss
            }
            del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask

        pre_iou3 = pre_iou3 / 2 + 0.5
        pre_iou2 = pre_iou2 / 2 + 0.5
        ret_dict = {
            'rcnn_cls_3rd': rcnn_cls_3rd,
            'rcnn_reg_3rd': rcnn_reg_3rd,
            'pred_boxes3d_1st': pred_boxes3d_1st,
            'pred_boxes3d_2nd': pred_boxes3d_2nd,
            'pred_boxes3d_3rd': pred_boxes3d_3rd,
            'pre_iou3': pre_iou3,
            'pre_iou2': pre_iou2
        }
        ret_dict.update(sec)
        ret_dict.update(two)
        ret_dict.update(three)

        return ret_dict
Exemplo n.º 10
0
    def forward(self, input_data):
        """Run the three-stage cascade RCNN head.

        Stage 1 refines ``input_data['roi_boxes3d']``. Stage 2 refines the
        stage-1 boxes, and stage 3 refines the stage-2 boxes.

        In training mode each later stage re-samples its RoIs with
        ``self.proposal_target_layer``. The candidate RoIs are the previous
        stage's predictions concatenated with the previous stage's RoIs. In
        eval mode the previous stage's predictions are pooled with
        ``self.roipooling``.

        :param input_data: input dict. Keys read directly here depend on
            ``cfg.RCNN.ROI_SAMPLE_JIT`` / ``USE_INTENSITY`` / ``USE_DEPTH``
            and the mode. They include ``'roi_boxes3d'``, ``'rpn_xyz'``,
            ``'rpn_features'``, ``'rpn_intensity'``, ``'seg_mask'``,
            ``'pts_depth'``, ``'pts_input'``, ``'cls_label'``,
            ``'reg_valid_mask'`` and ``'gt_boxes3d_ct'``.
        :return: dict with the following entries.

            - Per-stage heads: ``'rcnn_cls'``, ``'rcnn_reg'``,
              ``'rcnn_cls_2nd'``, ``'rcnn_reg_2nd'``, ``'rcnn_cls_3rd'``,
              ``'rcnn_reg_3rd'``.
            - Decoded boxes: ``'pred_boxes3d_1st'``, ``'pred_boxes3d_2nd'``,
              ``'pred_boxes3d_3rd'``.
            - IoU-head outputs: ``'pre_iou1'``, ``'pre_iou2'``,
              ``'pre_iou3'``.
            - Training mode only: the per-stage losses plus every entry of
              the stage-1 target dict.
        """
        # ---------------- stage 1 input ----------------
        if cfg.RCNN.ROI_SAMPLE_JIT:
            if self.training:
                with torch.no_grad():
                    target_dict = self.proposal_target_layer(input_data,
                                                             stage=1)

                pts_input = torch.cat(
                    (target_dict['sampled_pts'], target_dict['pts_feature']),
                    dim=2)
                target_dict['pts_input'] = pts_input
            else:
                rpn_xyz, rpn_features = input_data['rpn_xyz'], input_data[
                    'rpn_features']
                batch_rois = input_data['roi_boxes3d']
                if cfg.RCNN.USE_INTENSITY:
                    pts_extra_input_list = [
                        input_data['rpn_intensity'].unsqueeze(dim=2),
                        input_data['seg_mask'].unsqueeze(dim=2)
                    ]
                else:
                    pts_extra_input_list = [
                        input_data['seg_mask'].unsqueeze(dim=2)
                    ]

                if cfg.RCNN.USE_DEPTH:
                    pts_depth = input_data['pts_depth'] / 70.0 - 0.5
                    pts_extra_input_list.append(pts_depth.unsqueeze(dim=2))
                pts_extra_input = torch.cat(pts_extra_input_list, dim=2)

                pts_feature = torch.cat((pts_extra_input, rpn_features), dim=2)
                pooled_features, pooled_empty_flag = \
                    roipool3d_utils.roipool3d_gpu(rpn_xyz, pts_feature, batch_rois, cfg.RCNN.POOL_EXTRA_WIDTH,
                                                  sampled_pt_num=cfg.RCNN.NUM_POINTS)

                # Canonical transformation. Subtract each RoI centre from the
                # pooled xyz, then rotate with rotate_pc_along_y_torch using
                # batch_rois[k, :, 6].
                batch_size = batch_rois.shape[0]
                roi_center = batch_rois[:, :, 0:3]
                pooled_features[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
                for k in range(batch_size):
                    pooled_features[k, :, :,
                                    0:3] = kitti_utils.rotate_pc_along_y_torch(
                                        pooled_features[k, :, :, 0:3],
                                        batch_rois[k, :, 6])

                pts_input = pooled_features.view(-1, pooled_features.shape[2],
                                                 pooled_features.shape[3])
        else:
            pts_input = input_data['pts_input']
            target_dict = {}
            target_dict['pts_input'] = input_data['pts_input']
            target_dict['roi_boxes3d'] = input_data['roi_boxes3d']
            if self.training:
                target_dict['cls_label'] = input_data['cls_label']
                target_dict['reg_valid_mask'] = input_data[
                    'reg_valid_mask'].view(-1)
                target_dict['gt_of_rois'] = input_data['gt_boxes3d_ct']

        # NOTE(review): 512 points and 128 feature channels are hard-coded
        # here. They presumably match cfg.RCNN.NUM_POINTS and the RPN
        # feature width; confirm.
        pts_input = pts_input.view(-1, 512, 128 + self.rcnn_input_channel)
        xyz, features = self._break_up_pc(pts_input)
        anchor_size = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input = pts_input[..., 0:self.rcnn_input_channel].transpose(
                1, 2).unsqueeze(dim=3)
            xyz_feature = self.xyz_up_layer(xyz_input)
            rpn_feature = pts_input[..., self.rcnn_input_channel:].transpose(
                1, 2).unsqueeze(dim=3)
            merged_feature = torch.cat((xyz_feature, rpn_feature), dim=1)
            merged_feature = self.merge_down_layer(merged_feature)
            l_xyz, l_features = [xyz], [merged_feature.squeeze(dim=3)]
        else:
            l_xyz, l_features = [xyz], [features]

        for i in range(len(self.SA_modules)):
            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        batch_size = input_data['roi_boxes3d'].size(0)
        # Number of RoIs fed through stage 1 (leading dim of its outputs).
        num_rois_1st = pts_input.shape[0]
        rcnn_cls = self.cls_layer(l_features[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg = self.reg_layer(l_features[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        pre_iou1 = self.iou_layer(l_features[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)

        # ---------------- stage 1 loss / decoding ----------------
        if self.training:
            roi_boxes3d = target_dict['roi_boxes3d'].view(-1, 7)
            cls_label = target_dict['cls_label'].float()
            rcnn_cls_flat = rcnn_cls.view(-1)
            cls_label_flat = cls_label.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat), cls_label_flat, reduction='none')
            # RoIs with a negative label are excluded from the cls loss.
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / \
                torch.clamp(cls_valid_mask.sum(), min=1.0)
            gt_boxes3d_ct = target_dict['gt_of_rois']
            reg_valid_mask = target_dict['reg_valid_mask']
            fg_mask = (reg_valid_mask > 0)
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg.view(num_rois_1st, -1)[fg_mask],
                                        gt_boxes3d_ct.view(num_rois_1st, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)
            rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size
        else:
            roi_boxes3d = input_data['roi_boxes3d'].view(-1, 7)
            one = {}

        pred_boxes3d_1st = decode_bbox_target(
            roi_boxes3d.view(-1, 7),
            rcnn_reg.view(-1, rcnn_reg.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)

        if not self.training and cfg.RCNN.ENABLED and not cfg.RPN.ENABLED:
            pred_boxes3d_1st = pred_boxes3d_1st.view(-1, 7)
        if self.training:
            gt = target_dict['real_gt']
            flat_pred_1st = pred_boxes3d_1st.view(-1, 7)
            iou_label = []
            for i in range(num_rois_1st):
                iou_label.append(
                    iou3d_utils.boxes_iou3d_gpu(flat_pred_1st[i].view(1, 7),
                                                gt[i].view(1, 7)))
            iou_label = torch.cat(iou_label)
            # NOTE(review): this stage-1 IoU loss is computed but never put
            # into ``one``, so it does not contribute to training. Either
            # return it or drop this computation.
            iou_loss = F.smooth_l1_loss(abs(pre_iou1), iou_label)

            one = {
                'rcnn_loss_cls': rcnn_loss_cls,
                'rcnn_loss_reg': rcnn_loss_reg
            }
            del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask

        # ---------------- stage 2 ----------------
        input_data2 = input_data.copy()
        if self.training:
            # Candidate RoIs are the stage-1 boxes plus the original RoIs.
            input_data2['roi_boxes3d'] = torch.cat(
                (pred_boxes3d_1st, input_data['roi_boxes3d']), 1)
            with torch.no_grad():
                target_dict_2nd = self.proposal_target_layer(input_data2,
                                                             stage=2)
            pts_input_2 = torch.cat((target_dict_2nd['sampled_pts'],
                                     target_dict_2nd['pts_feature']),
                                    dim=2)
            target_dict_2nd['pts_input'] = pts_input_2
            roi = target_dict_2nd['roi_boxes3d']
        else:
            input_data2['roi_boxes3d'] = pred_boxes3d_1st
            roi = pred_boxes3d_1st
            pts_input_2 = self.roipooling(input_data2)

        xyz_2, features_2 = self._break_up_pc(pts_input_2)
        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input_2 = pts_input_2[
                ..., 0:self.rcnn_input_channel].transpose(1, 2).unsqueeze(dim=3)
            xyz_feature_2 = self.xyz_up_layer(xyz_input_2)
            rpn_feature_2 = pts_input_2[
                ..., self.rcnn_input_channel:].transpose(1, 2).unsqueeze(dim=3)
            merged_feature_2 = torch.cat((xyz_feature_2, rpn_feature_2), dim=1)
            merged_feature_2 = self.merge_down_layer(merged_feature_2)
            l_xyz_2, l_features_2 = [xyz_2], [merged_feature_2.squeeze(dim=3)]
        else:
            # Was bound to ``l_xyz__2`` (typo), which left ``l_xyz_2``
            # undefined for the SA loop below.
            l_xyz_2, l_features_2 = [xyz_2], [features_2]
        for i in range(len(self.SA_modules)):
            li_xyz_2, li_features_2 = self.SA_modules[i](l_xyz_2[i],
                                                         l_features_2[i])
            l_xyz_2.append(li_xyz_2)
            l_features_2.append(li_features_2)
        del xyz, features, l_features

        # Stage-2 RoI count. Reshapes below use it rather than the stage-1
        # count.
        num_rois_2nd = pts_input_2.shape[0]
        rcnn_cls_2nd = self.cls_layer_2nd(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg_2nd = self.reg_layer_2nd(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        pre_iou2 = self.iou_layer(l_features_2[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)

        if self.training:
            cls_label = target_dict_2nd['cls_label'].float()
            rcnn_cls_flat = rcnn_cls_2nd.view(-1)
            cls_label_flat = cls_label.view(-1)
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat), cls_label_flat, reduction='none')
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / \
                torch.clamp(cls_valid_mask.sum(), min=1.0)
            gt_boxes3d_ct = target_dict_2nd['gt_of_rois']
            reg_valid_mask = target_dict_2nd['reg_valid_mask']
            fg_mask = (reg_valid_mask > 0)
            # Without any foreground RoI, regress on the other RoIs instead
            # (presumably to avoid an empty selection in get_reg_loss).
            if rcnn_reg_2nd.view(num_rois_2nd, -1)[fg_mask].size(0) == 0:
                fg_mask = (reg_valid_mask <= 0)
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg_2nd.view(num_rois_2nd, -1)[fg_mask],
                                        gt_boxes3d_ct.view(num_rois_2nd, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)
            rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size

            two = {
                'rcnn_loss_cls_2nd': rcnn_loss_cls,
                'rcnn_loss_reg_2nd': rcnn_loss_reg
            }
            del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask
        else:
            two = {}

        sec = {'rcnn_cls_2nd': rcnn_cls_2nd, 'rcnn_reg_2nd': rcnn_reg_2nd}

        pred_boxes3d_2nd = decode_bbox_target(
            roi.view(-1, 7),
            rcnn_reg_2nd.view(-1, rcnn_reg_2nd.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)

        # ---------------- stage 3 ----------------
        input_data3 = input_data2.copy()
        if self.training:
            # Candidate RoIs are the stage-2 boxes plus the stage-2 candidates.
            input_data3['roi_boxes3d'] = torch.cat(
                (pred_boxes3d_2nd, input_data2['roi_boxes3d']), 1)
            with torch.no_grad():
                target_dict_3rd = self.proposal_target_layer(input_data3,
                                                             stage=3)
            pts_input_3 = torch.cat((target_dict_3rd['sampled_pts'],
                                     target_dict_3rd['pts_feature']),
                                    dim=2)
            target_dict_3rd['pts_input'] = pts_input_3
            roi = target_dict_3rd['roi_boxes3d']
        else:
            input_data3['roi_boxes3d'] = pred_boxes3d_2nd
            roi = pred_boxes3d_2nd
            pts_input_3 = self.roipooling(input_data3)

        xyz_3, features_3 = self._break_up_pc(pts_input_3)
        # NOTE(review): stage 3 reuses xyz_up_layer / merge_down_layer /
        # SA_modules from stage 1. Confirm this weight sharing is intended.
        if cfg.RCNN.USE_RPN_FEATURES:
            xyz_input_3 = pts_input_3[
                ..., 0:self.rcnn_input_channel].transpose(1, 2).unsqueeze(dim=3)
            xyz_feature_3 = self.xyz_up_layer(xyz_input_3)
            rpn_feature_3 = pts_input_3[
                ..., self.rcnn_input_channel:].transpose(1, 2).unsqueeze(dim=3)
            merged_feature_3 = torch.cat((xyz_feature_3, rpn_feature_3), dim=1)
            merged_feature_3 = self.merge_down_layer(merged_feature_3)
            l_xyz_3, l_features_3 = [xyz_3], [merged_feature_3.squeeze(dim=3)]
        else:
            # Was bound to ``l_xyz, l_features``, leaving the stage-3 lists
            # used by the SA loop below undefined.
            l_xyz_3, l_features_3 = [xyz_3], [features_3]

        for i in range(len(self.SA_modules)):
            li_xyz_3, li_features_3 = self.SA_modules[i](l_xyz_3[i],
                                                         l_features_3[i])
            l_xyz_3.append(li_xyz_3)
            l_features_3.append(li_features_3)
        del xyz_2, features_2, l_features_2

        # Stage-3 RoI count. Reshapes below use it rather than the stage-1
        # count.
        num_rois_3rd = pts_input_3.shape[0]
        rcnn_cls_3rd = self.cls_layer_3rd(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, 1 or 2)
        rcnn_reg_3rd = self.reg_layer_3rd(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)  # (B*64, C)
        pre_iou3 = self.iou_layer(l_features_3[-1]).transpose(
            1, 2).contiguous().squeeze(dim=1)

        if self.training:
            cls_label = target_dict_3rd['cls_label'].float()
            rcnn_cls_flat = rcnn_cls_3rd.view(-1)
            cls_label_flat = cls_label.view(-1)
            # Target is flattened to match the flattened logits, as in
            # stages 1 and 2.
            batch_loss_cls = F.binary_cross_entropy(
                torch.sigmoid(rcnn_cls_flat), cls_label_flat, reduction='none')
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / \
                torch.clamp(cls_valid_mask.sum(), min=1.0)
            gt_boxes3d_ct = target_dict_3rd['gt_of_rois']
            reg_valid_mask = target_dict_3rd['reg_valid_mask']
            fg_mask = (reg_valid_mask > 0)
            if rcnn_reg_3rd.view(num_rois_3rd, -1)[fg_mask].size(0) == 0:
                fg_mask = (reg_valid_mask <= 0)
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg_3rd.view(num_rois_3rd, -1)[fg_mask],
                                        gt_boxes3d_ct.view(num_rois_3rd, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)
            rcnn_loss_reg = loss_loc + loss_angle + 3 * loss_size
        else:
            three = {}

        pred_boxes3d_3rd = decode_bbox_target(
            roi.view(-1, 7),
            rcnn_reg_3rd.view(-1, rcnn_reg_3rd.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)

        if self.training:
            gt = target_dict_3rd['real_gt']
            flat_pred_3rd = pred_boxes3d_3rd.view(-1, 7)
            # Per-RoI IoU between each stage-3 box and its matched real GT.
            iou_label = []
            for i in range(num_rois_3rd):
                iou_label.append(
                    iou3d_utils.boxes_iou3d_gpu(flat_pred_3rd[i].view(1, 7),
                                                gt[i].view(1, 7)))
            iou_label = torch.cat(iou_label)
            iou_loss = F.smooth_l1_loss(pre_iou3, iou_label)
            three = {
                'rcnn_loss_cls_3rd': rcnn_loss_cls,
                'rcnn_loss_reg_3rd': rcnn_loss_reg,
                'rcnn_iou_loss': iou_loss
            }
            del cls_label, rcnn_cls_flat, batch_loss_cls, cls_label_flat, cls_valid_mask, rcnn_loss_cls, gt_boxes3d_ct, reg_valid_mask, fg_mask

        ret_dict = {
            'rcnn_cls': rcnn_cls,
            'rcnn_reg': rcnn_reg,
            'rcnn_cls_3rd': rcnn_cls_3rd,
            'rcnn_reg_3rd': rcnn_reg_3rd,
            'pred_boxes3d_1st': pred_boxes3d_1st,
            'pred_boxes3d_2nd': pred_boxes3d_2nd,
            'pred_boxes3d_3rd': pred_boxes3d_3rd,
            'pre_iou3': pre_iou3,
            'pre_iou2': pre_iou2,
            'pre_iou1': pre_iou1
        }
        ret_dict.update(sec)
        ret_dict.update(one)
        ret_dict.update(two)
        ret_dict.update(three)
        if self.training:
            ret_dict.update(target_dict)
        return ret_dict
Exemplo n.º 11
0
def eval_one_epoch_rcnn(model, dataloader, epoch_id, result_dir, logger):
    """Evaluate the RCNN stage on pre-computed ROIs for one epoch.

    Every sample's ROIs are refined by ``model``. The refined boxes are filtered
    by ``cfg.RCNN.SCORE_THRESH`` and suppressed with BEV NMS
    (``cfg.RCNN.NMS_THRESH``). The survivors are written in KITTI format to
    ``<result_dir>/final_result/data``.

    When ``args.save_result`` is set, the raw ROIs and all refined boxes are
    also written to ``roi_result`` and ``refine_result``. Unless ``args.test``
    is set, ROI recall, refined-box recall and classification accuracy are
    accumulated. Unless ``cfg.TEST.SPLIT == 'test'``, KITTI AP is computed at
    the end.

    :param model: RCNN network, called as ``model(input_data)``.
    :param dataloader: loader yielding one sample per batch
        (``args.batch_size`` must be 1).
    :param epoch_id: epoch label, used in log messages only.
    :param result_dir: root directory for every file written here.
    :param logger: logger receiving progress and metric messages.
    :return: dict with ``empty_cnt``, ``rcnn_cls_acc``,
        ``rcnn_cls_acc_refined``, ``rcnn_avg_num``, and per-threshold
        ``rpn_recall(...)`` / ``rcnn_recall(...)`` entries. On non-test splits
        it also contains the AP entries returned by ``kitti_evaluate``.
    :raises AssertionError: if ``args.batch_size != 1``.
    """
    np.random.seed(1024)
    MEAN_SIZE = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
    mode = 'TEST' if args.test else 'EVAL'

    final_output_dir = os.path.join(result_dir, 'final_result', 'data')
    os.makedirs(final_output_dir, exist_ok=True)

    if args.save_result:
        roi_output_dir = os.path.join(result_dir, 'roi_result', 'data')
        refine_output_dir = os.path.join(result_dir, 'refine_result', 'data')
        os.makedirs(roi_output_dir, exist_ok=True)
        os.makedirs(refine_output_dir, exist_ok=True)

    logger.info('---- EPOCH %s RCNN EVALUATION ----' % epoch_id)
    model.eval()

    thresh_list = [0.1, 0.3, 0.5, 0.7, 0.9]
    total_recalled_bbox_list, total_gt_bbox = [0] * len(thresh_list), 0
    total_roi_recalled_bbox_list = [0] * len(thresh_list)
    dataset = dataloader.dataset
    cnt = final_total = total_cls_acc = total_cls_acc_refined = 0

    # The per-sample logic below handles exactly one sample per batch; this is
    # loop-invariant, so check it once instead of on every iteration.
    assert args.batch_size == 1, 'Only support bs=1 here'

    progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval')
    for data in dataloader:
        sample_id = data['sample_id']
        cnt += 1
        input_data = {
            key: torch.from_numpy(val).contiguous().cuda(non_blocking=True).float()
            for key, val in data.items() if key != 'sample_id'
        }

        # Keep references to the un-batched ROI tensors; the JIT branch below
        # replaces the dict entries with unsqueezed versions.
        roi_boxes3d = input_data['roi_boxes3d']
        roi_scores = input_data['roi_scores']
        if cfg.RCNN.ROI_SAMPLE_JIT:
            # Add a leading batch dim to every model input; the GT tensors used
            # only for evaluation below keep their original shape.
            for key in input_data:
                if key in ['gt_iou', 'gt_boxes3d']:
                    continue
                input_data[key] = input_data[key].unsqueeze(dim=0)
        else:
            input_data['pts_input'] = torch.cat(
                (input_data['pts_input'], input_data['pts_features']), dim=-1)

        ret_dict = model(input_data)
        rcnn_cls = ret_dict['rcnn_cls']
        rcnn_reg = ret_dict['rcnn_reg']

        # bounding box regression: size residuals are relative either to each
        # ROI's own size or to the class mean size
        anchor_size = input_data['roi_size'] if cfg.RCNN.SIZE_RES_ON_ROI else MEAN_SIZE

        pred_boxes3d = decode_bbox_target(
            roi_boxes3d,
            rcnn_reg,
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True)

        # scoring: a single logit uses sigmoid + threshold, otherwise softmax/argmax
        if rcnn_cls.shape[1] == 1:
            raw_scores = rcnn_cls.view(-1)
            norm_scores = torch.sigmoid(raw_scores)
            pred_classes = (norm_scores > cfg.RCNN.SCORE_THRESH).long()
        else:
            pred_classes = torch.argmax(rcnn_cls, dim=1).view(-1)
            cls_norm_scores = F.softmax(rcnn_cls, dim=1)
            raw_scores = rcnn_cls[:, pred_classes]
            norm_scores = cls_norm_scores[:, pred_classes]

        # evaluation (skipped in test mode, where no GT is available)
        disp_dict = {'mode': mode}
        if not args.test:
            gt_boxes3d = input_data['gt_boxes3d']
            gt_iou = input_data['gt_iou']

            # recall: a GT counts as recalled at `thresh` if its best IoU
            # (max over dim 0 of the IoU matrix) exceeds `thresh`
            gt_num = gt_boxes3d.shape[0]
            if gt_num > 0:
                iou3d = iou3d_utils.boxes_iou3d_gpu(pred_boxes3d, gt_boxes3d)
                gt_max_iou, _ = iou3d.max(dim=0)
                for idx, thresh in enumerate(thresh_list):
                    total_recalled_bbox_list[idx] += (gt_max_iou > thresh).sum().item()
                total_gt_bbox += gt_num

                iou3d_in = iou3d_utils.boxes_iou3d_gpu(roi_boxes3d, gt_boxes3d)
                gt_max_iou_in, _ = iou3d_in.max(dim=0)
                for idx, thresh in enumerate(thresh_list):
                    total_roi_recalled_bbox_list[idx] += (gt_max_iou_in > thresh).sum().item()

            # classification accuracy over ROIs clearly fg or bg by their GT IoU
            cls_label = (gt_iou > cfg.RCNN.CLS_FG_THRESH).float()
            cls_valid_mask = ((gt_iou >= cfg.RCNN.CLS_FG_THRESH) |
                              (gt_iou <= cfg.RCNN.CLS_BG_THRESH)).float()
            cls_acc = ((pred_classes == cls_label.long()).float() *
                       cls_valid_mask).sum() / max(cls_valid_mask.sum(), 1.0)

            # "refined" accuracy labels every ROI using the class-specific IoU threshold
            iou_thresh = 0.7 if cfg.CLASSES == 'Car' else 0.5
            cls_label_refined = (gt_iou >= iou_thresh).float()
            cls_acc_refined = (
                pred_classes == cls_label_refined.long()).float().sum() / max(
                    cls_label_refined.shape[0], 1.0)

            total_cls_acc += cls_acc.item()
            total_cls_acc_refined += cls_acc_refined.item()

            # index 3 of thresh_list is the 0.7 IoU threshold
            disp_dict['recall'] = '%d/%d' % (total_recalled_bbox_list[3],
                                             total_gt_bbox)
            disp_dict['cls_acc_refined'] = '%.2f' % cls_acc_refined.item()

        progress_bar.set_postfix(disp_dict)
        progress_bar.update()

        image_shape = dataset.get_image_shape(sample_id)
        if args.save_result:
            # save roi and refine results (scores converted to numpy for both)
            roi_boxes3d_np = roi_boxes3d.cpu().numpy()
            pred_boxes3d_np = pred_boxes3d.cpu().numpy()
            calib = dataset.get_calib(sample_id)

            save_kitti_format(sample_id, calib, roi_boxes3d_np, roi_output_dir,
                              roi_scores.cpu().numpy(), image_shape)
            save_kitti_format(sample_id, calib, pred_boxes3d_np,
                              refine_output_dir,
                              raw_scores.cpu().numpy(), image_shape)

        # NMS and scoring
        # scores thresh
        inds = norm_scores > cfg.RCNN.SCORE_THRESH
        if inds.sum() == 0:
            continue

        pred_boxes3d_selected = pred_boxes3d[inds]
        raw_scores_selected = raw_scores[inds]

        # NMS thresh (on bird's-eye-view boxes)
        boxes_bev_selected = kitti_utils.boxes3d_to_bev_torch(
            pred_boxes3d_selected)
        keep_idx = iou3d_utils.nms_gpu(boxes_bev_selected, raw_scores_selected,
                                       cfg.RCNN.NMS_THRESH)
        pred_boxes3d_selected = pred_boxes3d_selected[keep_idx].cpu().numpy()
        scores_selected = raw_scores_selected[keep_idx].cpu().numpy()

        calib = dataset.get_calib(sample_id)
        final_total += pred_boxes3d_selected.shape[0]
        save_kitti_format(sample_id, calib, pred_boxes3d_selected,
                          final_output_dir, scores_selected, image_shape)

    progress_bar.close()

    # dump empty result files for frames of the split with no saved detections
    split_file = os.path.join(dataset.imageset_dir, '..', '..', 'ImageSets',
                              dataset.split + '.txt')
    split_file = os.path.abspath(split_file)
    with open(split_file) as split_f:
        image_idx_list = [x.strip() for x in split_f.readlines()]
    empty_cnt = 0
    for image_idx in image_idx_list:
        cur_file = os.path.join(final_output_dir, '%s.txt' % image_idx)
        if not os.path.exists(cur_file):
            with open(cur_file, 'w'):
                pass
            empty_cnt += 1
            logger.info('empty_cnt=%d: dump empty file %s' %
                        (empty_cnt, cur_file))

    ret_dict = {'empty_cnt': empty_cnt}

    logger.info(
        '-------------------performance of epoch %s---------------------' %
        epoch_id)
    logger.info(str(datetime.now()))

    avg_cls_acc = (total_cls_acc / max(cnt, 1.0))
    avg_cls_acc_refined = (total_cls_acc_refined / max(cnt, 1.0))
    avg_det_num = (final_total / max(cnt, 1.0))
    logger.info('final average detections: %.3f' % avg_det_num)
    logger.info('final average cls acc: %.3f' % avg_cls_acc)
    logger.info('final average cls acc refined: %.3f' % avg_cls_acc_refined)
    ret_dict['rcnn_cls_acc'] = avg_cls_acc
    ret_dict['rcnn_cls_acc_refined'] = avg_cls_acc_refined
    ret_dict['rcnn_avg_num'] = avg_det_num

    for idx, thresh in enumerate(thresh_list):
        cur_roi_recall = total_roi_recalled_bbox_list[idx] / max(
            total_gt_bbox, 1.0)
        logger.info('total roi bbox recall(thresh=%.3f): %d / %d = %f' %
                    (thresh, total_roi_recalled_bbox_list[idx], total_gt_bbox,
                     cur_roi_recall))
        ret_dict['rpn_recall(thresh=%.2f)' % thresh] = cur_roi_recall

    for idx, thresh in enumerate(thresh_list):
        cur_recall = total_recalled_bbox_list[idx] / max(total_gt_bbox, 1.0)
        logger.info(
            'total bbox recall(thresh=%.3f): %d / %d = %f' %
            (thresh, total_recalled_bbox_list[idx], total_gt_bbox, cur_recall))
        ret_dict['rcnn_recall(thresh=%.2f)' % thresh] = cur_recall

    if cfg.TEST.SPLIT != 'test':
        logger.info('Averate Precision:')
        name_to_class = {'Car': 0, 'Pedestrian': 1, 'Cyclist': 2}
        ap_result_str, ap_dict = kitti_evaluate(
            dataset.label_dir,
            final_output_dir,
            label_split_file=split_file,
            current_class=name_to_class[cfg.CLASSES])
        logger.info(ap_result_str)
        ret_dict.update(ap_dict)

    logger.info('result is saved to: %s' % result_dir)

    return ret_dict
Exemplo n.º 12
0
    def _eval_data(self, masked_pts=None):
        """Run the detector on ``self.data`` and return a rank-based AP score for cars.

        If ``self.use_masked`` is true, the point cloud is replaced by
        ``masked_pts``: the first three columns of each sample are passed
        through ``calib.lidar_to_rect``.

        Predicted boxes are kept when their sigmoid RCNN score exceeds
        ``cfg.RCNN.SCORE_THRESH``, then reduced with rotated BEV NMS. For every
        kept prediction, its best IoU with any 'Car' GT box of the same sample
        is compared against 0.7. Walking the predictions in NMS output order
        gives a running precision, and the suffix maxima of that precision are
        summed into an AP-style value.

        :param masked_pts: optional per-sample point arrays used instead of
            ``self.data['pts_rect']`` when ``self.use_masked`` is true.
        :return: the summed AP-style value divided by the number of kept
            predictions. Returns ``0.0`` if no prediction survives for the
            samples that contain cars, and ``None`` if no sample contains a
            'Car' GT box.
        """
        with torch.no_grad():
            MEAN_SIZE = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
            batch_size = self.config['batch_size']

            sample_id = self.data['sample_id']
            pts_rect = self.data['pts_rect']
            gt_boxes3d = self.data['gt_boxes3d']
            labels = self.data['label']

            cls_types = [[label.cls_type for label in labels[k]]
                         for k in range(batch_size)]

            calib = [
                self.test_loader.dataset.get_calib(idx) for idx in sample_id
            ]
            if self.use_masked:
                # use masked/sampled pts instead of the dataset's rect points
                pts_rect = np.array([
                    c.lidar_to_rect(masked_pts[k][:, 0:3])
                    for k, c in enumerate(calib)
                ])

            inputs = torch.from_numpy(pts_rect).cuda(
                non_blocking=True).float().view(batch_size, -1, 3)
            gt_boxes3d = torch.from_numpy(gt_boxes3d).cuda(non_blocking=True)
            input_data = {'pts_input': inputs}

            # model inference
            ret_dict = self.model(input_data)

            roi_scores_raw = ret_dict['roi_scores_raw']  # (B, M)
            roi_boxes3d = ret_dict['rois']  # (B, M, 7)

            rcnn_cls = ret_dict['rcnn_cls'].view(batch_size, -1,
                                                 ret_dict['rcnn_cls'].shape[1])
            rcnn_reg = ret_dict['rcnn_reg'].view(
                batch_size, -1, ret_dict['rcnn_reg'].shape[1])  # (B, M, C)

            # remove low confidence predictions
            norm_scores = torch.sigmoid(rcnn_cls)
            thresh_mask = norm_scores > cfg.RCNN.SCORE_THRESH

            # bounding box regression
            pred_boxes3d = decode_bbox_target(
                roi_boxes3d.view(-1, 7),
                rcnn_reg.view(-1, rcnn_reg.shape[-1]),
                anchor_size=MEAN_SIZE,
                loc_scope=cfg.RCNN.LOC_SCOPE,
                loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                get_xz_fine=True,
                get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
                loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                get_ry_fine=True).view(batch_size, -1, 7)

            # per sample: score threshold, then rotated NMS on BEV boxes.
            # NOTE(review): NMS is ranked by the RPN scores (roi_scores_raw),
            # not by the RCNN scores -- confirm this is intended.
            pred_boxes3d_selected = []
            for k in range(batch_size):
                mask = thresh_mask[k].view(-1)
                boxes_k = pred_boxes3d[k][mask]
                scores_k = roi_scores_raw[k][mask]
                keep = iou3d_utils.nms_gpu(
                    kitti_utils.boxes3d_to_bev_torch(boxes_k), scores_k,
                    cfg.RCNN.NMS_THRESH).view(-1)
                pred_boxes3d_selected.append(boxes_k[keep])

            # only 'Car' GT boxes are evaluated; samples without any are dropped
            car_idx = [[i for i, cls_type in enumerate(cls_types[k]) if cls_type == 'Car']
                       for k in range(batch_size)]
            has_info = [k for k in range(batch_size) if len(car_idx[k]) > 0]
            if not has_info:
                return None
            gt_boxes3d_selected = [gt_boxes3d[k][car_idx[k]] for k in has_info]
            pred_boxes3d_selected = [pred_boxes3d_selected[k] for k in has_info]

            total_ap = 0
            num_preds = 0
            for gt_k, pred_k in zip(gt_boxes3d_selected, pred_boxes3d_selected):
                iou3d = iou3d_utils.boxes_iou3d_gpu(gt_k, pred_k)
                # dim=0 reduces over the GT axis, so this is the best GT overlap
                # of each *prediction* (assuming iou3d is (num_gt, num_pred)).
                best_iou = torch.max(iou3d, dim=0)[0]

                # running precision over predictions in NMS output order
                precision = []
                num_correct = 0
                for rank, pred_iou in enumerate(best_iou, start=1):
                    if pred_iou > 0.7:
                        num_correct += 1
                    precision.append(num_correct / rank)

                # interpolated precision: max precision at this rank or later
                ap = 0
                for i in range(len(precision)):
                    ap += max(precision[i:])

                total_ap += ap
                num_preds += len(best_iou)

            if num_preds == 0:
                # GT cars exist but no prediction survived thresholding/NMS
                return 0.0
            return total_ap / num_preds
Exemplo n.º 13
0
    def forward(self, input_data):
        """
        Run the RCNN head on a batch of ROIs and decode first-stage refined boxes.

        :param input_data: input dict.
            'roi_boxes3d' is always read.
            With cfg.RCNN.ROI_SAMPLE_JIT and training, the dict is passed to
            self.proposal_target_layer.
            With cfg.RCNN.ROI_SAMPLE_JIT and not training, 'rpn_xyz',
            'rpn_features' and 'seg_mask' are also read, plus 'rpn_intensity'
            when cfg.RCNN.USE_INTENSITY and 'pts_depth' when cfg.RCNN.USE_DEPTH.
            Without ROI_SAMPLE_JIT, 'pts_input' is read, plus 'cls_label',
            'reg_valid_mask' and 'gt_boxes3d_ct' when training.
        :return: dict with 'rcnn_cls', 'rcnn_reg', 'pred_boxes3d_1st' (decoded
            boxes viewed as (batch_size, -1, 7)) and 'pooled_feature' (output of
            the last SA module).
            When training with cfg.TRAIN.IOU_LAYER == 'split' it also holds
            'iou_label'. When training, every target_dict entry is merged in.
        """
        if cfg.RCNN.ROI_SAMPLE_JIT:
            if self.training:
                # Sample ROIs/targets on the fly; no gradients through target generation.
                with torch.no_grad():
                    target_dict = self.proposal_target_layer(input_data)

                pts_input = torch.cat((target_dict['sampled_pts'], target_dict['pts_feature']), dim=2)
                target_dict['pts_input'] = pts_input
            else:
                rpn_xyz, rpn_features = input_data['rpn_xyz'], input_data['rpn_features']
                batch_rois = input_data['roi_boxes3d']
                # Extra per-point channels (concatenated on dim 2): optional
                # intensity, the segmentation mask, and optional depth.
                if cfg.RCNN.USE_INTENSITY:
                    pts_extra_input_list = [input_data['rpn_intensity'].unsqueeze(dim=2),
                                            input_data['seg_mask'].unsqueeze(dim=2)]
                else:
                    pts_extra_input_list = [input_data['seg_mask'].unsqueeze(dim=2)]

                if cfg.RCNN.USE_DEPTH:
                    # Normalized depth d / 70 - 0.5 (70.0 presumably an approximate
                    # max range in metres -- TODO confirm).
                    pts_depth = input_data['pts_depth'] / 70.0 - 0.5
                    pts_extra_input_list.append(pts_depth.unsqueeze(dim=2))
                pts_extra_input = torch.cat(pts_extra_input_list, dim=2)

                pts_feature = torch.cat((pts_extra_input, rpn_features), dim=2)
                # Pool cfg.RCNN.NUM_POINTS points (with features) for each ROI; see
                # roipool3d_utils for how POOL_EXTRA_WIDTH affects the ROI extent.
                # pooled_empty_flag is not used below.
                pooled_features, pooled_empty_flag = \
                        roipool3d_utils.roipool3d_gpu(rpn_xyz, pts_feature, batch_rois, cfg.RCNN.POOL_EXTRA_WIDTH,
                                                      sampled_pt_num=cfg.RCNN.NUM_POINTS)

                # canonical transformation: move pooled xyz (first 3 channels) into each
                # ROI's frame -- subtract the ROI center (columns 0:3), then rotate with
                # the ROI's 7th parameter (index 6, presumably the heading angle).
                # pooled_features is modified in place.
                batch_size = batch_rois.shape[0]
                roi_center = batch_rois[:, :, 0:3]
                pooled_features[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
                for k in range(batch_size):
                    pooled_features[k, :, :, 0:3] = kitti_utils.rotate_pc_along_y_torch(pooled_features[k, :, :, 0:3],
                                                                                        batch_rois[k, :, 6])

                # Merge the first two dims so every ROI becomes one sample:
                # (B * M, num_points, C).
                pts_input = pooled_features.view(-1, pooled_features.shape[2], pooled_features.shape[3])
        else:
            # ROIs and pooled points were prepared upstream; pass them through.
            pts_input = input_data['pts_input']
            target_dict = {}
            target_dict['pts_input'] = input_data['pts_input']
            target_dict['roi_boxes3d'] = input_data['roi_boxes3d']
            if self.training:
                target_dict['cls_label'] = input_data['cls_label']
                target_dict['reg_valid_mask'] = input_data['reg_valid_mask']
                target_dict['gt_of_rois'] = input_data['gt_boxes3d_ct']

        xyz, features = self._break_up_pc(pts_input)
        # Leading dim of the input ROI tensor; used to reshape decoded boxes below.
        batch_size = input_data['roi_boxes3d'].size(0)
        if cfg.RCNN.USE_RPN_FEATURES:
            # Encode the first self.rcnn_input_channel channels with xyz_up_layer,
            # append the remaining (RPN) channels, and fuse with merge_down_layer.
            xyz_input = pts_input[..., 0:self.rcnn_input_channel].transpose(1, 2).unsqueeze(dim=3)
            xyz_feature = self.xyz_up_layer(xyz_input)

            rpn_feature = pts_input[..., self.rcnn_input_channel:].transpose(1, 2).unsqueeze(dim=3)

            merged_feature = torch.cat((xyz_feature, rpn_feature), dim=1)
            merged_feature = self.merge_down_layer(merged_feature)
            l_xyz, l_features = [xyz], [merged_feature.squeeze(dim=3)]
        else:
            l_xyz, l_features = [xyz], [features]

        # Set-abstraction hierarchy; each level consumes the previous level's output.
        for i in range(len(self.SA_modules)):
            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        rcnn_cls = self.cls_layer(l_features[-1]).transpose(1, 2).contiguous().squeeze(dim=1)  # (B, 1 or 2)
        rcnn_reg = self.reg_layer(l_features[-1]).transpose(1, 2).contiguous().squeeze(dim=1)  # (B, C)
        anchor_size = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
        # Decode relative to the ROIs the head actually saw: the sampled ROIs from
        # target_dict when training, the input ROIs otherwise.
        if self.training:
            roi_boxes3d=target_dict['roi_boxes3d'].view(-1,7)
            #roi_boxes3d = input_data['roi_boxes3d']
        else:
            roi_boxes3d=input_data['roi_boxes3d']
        pred_boxes3d_1st = decode_bbox_target(roi_boxes3d.view(-1, 7), rcnn_reg.view(-1, rcnn_reg.shape[-1]),
                                              anchor_size=anchor_size,
                                              loc_scope=cfg.RCNN.LOC_SCOPE,
                                              loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                              num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                              get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                              loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                              get_ry_fine=True).view(batch_size, -1, 7)
        ret_dict = {'rcnn_cls': rcnn_cls, 'rcnn_reg': rcnn_reg,'pred_boxes3d_1st':pred_boxes3d_1st}
        ret_dict['pooled_feature'] = l_features[-1]
        if cfg.TRAIN.IOU_LAYER == 'split' and self.training:

            # IoU regression target: IoU of the i-th decoded box with gt[i].
            # Assumes target_dict['real_gt'] holds one GT box per ROI -- confirm.
            gt = target_dict['real_gt']
            iou_label = []
            batch_size_2 = pts_input.shape[0]
            for i in range(batch_size_2):
                iou_label.append(
                    iou3d_utils.boxes_iou3d_gpu(pred_boxes3d_1st.view(-1, 7)[i].view(1, 7), gt[i].view(1, 7)))
            iou_label = torch.cat(iou_label)
            # Rescale IoU from [0, 1] to [-1, 1].
            iou_label = (iou_label - 0.5) * 2
            ret_dict['iou_label']=iou_label
        if self.training:
            ret_dict.update(target_dict)
        return ret_dict
Exemplo n.º 14
0
    def eval_epoch_ioun(self, d_loader):
        """Evaluate the box-refinement / IoU-prediction model for one epoch.

        For each foreground proposal (``data['cls'] > 0``), the 3D IoU is
        computed between the predicted box and its paired GT box, and between
        the refined box and the same GT box. Paired means row i with row i,
        i.e. the diagonal of the IoU matrix.

        Recall at IoU 0.5 / 0.7 is accumulated in two ways:
        - per proposal;
        - per unique GT (``single_*`` keys): for each ``(sample_id, box_id)``
          only the proposal with the highest predicted IoU counts.

        Distributions are logged to TensorBoard and the main recalls printed.

        :param d_loader: loader yielding one proposal per batch. Element 0 of
            ``sample_id`` / ``box_id`` is read, and ``.item()`` is called on
            per-proposal tensors.
        :return: ``(mean_loss, eval_dict)``. ``mean_loss`` is the loss summed
            over every evaluated batch divided by the number of foreground
            proposals (at least 1). ``eval_dict`` values are averaged over the
            same count, except the ``single_*`` recalls.
        """
        self.model.eval()

        eval_dict = {'recalled_0.5': 0, 'recalled_0.7': 0,
                     'ref_recalled_0.5': 0, 'ref_recalled_0.7': 0}
        total_loss = count = 0.0  # count = number of foreground proposals

        # per-foreground-proposal tensors collected over the epoch
        iou, p_iou, iou_offset, offset = [], [], [], []
        ref_iou, ref_offset, ref_box_list = [], [], []
        # [sample_id, box_id, predicted_iou, actual_iou] records
        recall_list, ref_recall_list = [], []
        # (sample_id, box_id) of every proposal tied to a GT box (box_id != -1)
        all_gt_list = []

        def paired_iou3d(boxes, gts):
            # IoU of boxes[i] with gts[i]: the diagonal of the pairwise IoU matrix.
            _, iou_mat = iou3d_utils.boxes_iou3d_gpu(boxes, gts)
            eye = torch.from_numpy(np.arange(0, iou_mat.shape[0]).reshape(-1, 1)).long().cuda()
            return torch.gather(iou_mat, 1, eye)

        def keep_best_per_box(records):
            # Keep one record per (sample_id, box_id): the first after sorting by
            # predicted IoU, descending. The in-place update reproduces the
            # original pairwise rule (it only fires if the sort left a record out
            # of order).
            records.sort(key=lambda x: x[2], reverse=True)
            first_seen = {}
            kept = []
            for rec in records:
                key = (rec[0], rec[1])
                target = first_seen.get(key)
                if target is None:
                    first_seen[key] = rec
                    kept.append(rec)
                elif rec[2] > target[2]:
                    target[3] = rec[3]
            return kept

        for data in tqdm.tqdm(d_loader, total=len(d_loader), leave=False, desc='val'):
            self.optimizer.zero_grad()

            sample_id = int(data['sample_id'].tolist()[0])
            box_id = int(data['box_id'].tolist()[0])
            if box_id != -1:
                all_gt_list.append((sample_id, box_id))

            # proposals without any points are not evaluated
            if data['cur_box_point'].shape[1] == 0:
                continue
            gt_boxes = torch.from_numpy(data['gt_boxes'][..., :7]).cuda().float()
            loss, tb_dict, disp_dict, visual_dict = self.model_fn_eval(self.model, data)
            total_loss += loss.item()

            for k, v in tb_dict.items():
                eval_dict[k] = eval_dict.get(k, 0) + v

            fg_mask = data['cls'].float().view(-1) > 0
            if fg_mask != 0:
                pred_boxes3d = visual_dict['pred_boxes3d'].view(-1, 7)
                refined_boxes3d = visual_dict['refined_box'].view(-1, 7)
                gt_boxes = gt_boxes[:, :, :7].view(-1, 7)
                pred_iou = visual_dict['rcnn_iou']

                # predicted (unrefined) boxes
                iou3d = paired_iou3d(pred_boxes3d, gt_boxes)
                eval_dict['recalled_0.5'] += torch.sum((iou3d > 0.5).long()).item()
                eval_dict['recalled_0.7'] += torch.sum((iou3d > 0.7).long()).item()

                p_iou.append(pred_iou.view(-1, 1))
                iou_offset.append(iou3d.view(-1, 1) - pred_iou.view(-1, 1))
                iou.append(iou3d.view(-1, 1))
                offset.append(pred_boxes3d - gt_boxes)
                count += 1
                recall_list.append([sample_id, box_id, pred_iou.item(), iou3d.item()])

                # refined boxes
                ref_iou3d = paired_iou3d(refined_boxes3d, gt_boxes)
                eval_dict['ref_recalled_0.5'] += torch.sum((ref_iou3d > 0.5).long()).item()
                eval_dict['ref_recalled_0.7'] += torch.sum((ref_iou3d > 0.7).long()).item()

                ref_iou.append(ref_iou3d.view(-1, 1))
                ref_offset.append(refined_boxes3d - gt_boxes)
                ref_box_list.append(refined_boxes3d)
                ref_recall_list.append([sample_id, box_id, pred_iou.item(), ref_iou3d.item()])

        # statistics this epoch
        for k in eval_dict:
            eval_dict[k] = eval_dict[k] / max(count, 1)

        num_single_gt = len(set(all_gt_list))
        best = keep_best_per_box(recall_list)
        ref_best = keep_best_per_box(ref_recall_list)

        # Histograms need at least one foreground proposal (torch.cat of an
        # empty list raises).
        if count > 0:
            offset = torch.cat(offset, dim=0).reshape(-1, 7)
            ref_offset = torch.cat(ref_offset, dim=0).reshape(-1, 7)
            ref_box_list = torch.cat(ref_box_list, dim=0).reshape(-1, 7)
            iou_offset = torch.cat(iou_offset, dim=0).reshape(-1, 1)
            p_iou = torch.cat(p_iou, dim=0).reshape(-1, 1)
            iou = torch.cat(iou, dim=0).reshape(-1, 1)
            ref_iou = torch.cat(ref_iou, dim=0).reshape(-1, 1)

            box_fields = ('x', 'y', 'z', 'h', 'w', 'l', 'ry')
            self.tb_log.add_histogram('val_iou', iou.view(-1), self.it)
            self.tb_log.add_histogram('val_ref_iou', ref_iou.view(-1), self.it)
            for col, name in enumerate(box_fields):
                self.tb_log.add_histogram('val_%s_offset' % name, offset[:, col], self.it)
            for col, name in enumerate(box_fields):
                self.tb_log.add_histogram('val_%s_roffset' % name, ref_offset[:, col], self.it)
            for col in (0, 2, 1, 3, 4, 5):
                self.tb_log.add_histogram('val_ref_%s' % box_fields[col], ref_box_list[:, col], self.it)
            self.tb_log.add_histogram('val_ref_ry', ref_box_list[:, 6] % (np.pi), self.it)

            self.tb_log.add_histogram('val_pred_iou', p_iou.view(-1), self.it)
            self.tb_log.add_histogram('val_iou_offset', iou_offset, self.it)

        denom = float(max(num_single_gt, 1))
        eval_dict['single_recalled_0.5'] = sum(1 for r in best if r[3] > 0.5) / denom
        eval_dict['single_recalled_0.7'] = sum(1 for r in best if r[3] > 0.7) / denom

        eval_dict['single_ref_recalled_0.5'] = sum(1 for r in ref_best if r[3] > 0.5) / denom
        eval_dict['single_ref_recalled_0.7'] = sum(1 for r in ref_best if r[3] > 0.7) / denom

        print('Recall_0.5 %.4f.' % eval_dict['recalled_0.5'])
        print('Recall_0.7 %.4f.' % eval_dict['recalled_0.7'])
        print('Recall_ref0.5 %.4f.' % eval_dict['ref_recalled_0.5'])
        print('Recall_ref0.7 %.4f.' % eval_dict['ref_recalled_0.7'])
        return total_loss / max(count, 1), eval_dict
Exemplo n.º 15
0
    def eval_epoch_rcnn(self, d_loader):
        """Evaluate the RCNN refinement model for one epoch.

        For each foreground proposal (``data['cls'] > 0``), the 3D IoU is
        computed between the predicted box and its paired GT box (the diagonal
        of the IoU matrix). Recall at IoU 0.5 / 0.7 is accumulated both per
        proposal and per unique ``(sample_id, box_id)`` GT (``single_*``
        keys). Classification logits, IoUs and per-parameter box offsets are
        logged to TensorBoard.

        :param d_loader: loader yielding one proposal per batch (element 0 of
            ``sample_id`` / ``box_id`` is read).
        :return: ``(mean_loss, eval_dict)``. ``mean_loss`` is the loss summed
            over every evaluated batch divided by the number of foreground
            proposals (at least 1). ``eval_dict`` values are averaged over the
            same count, except the ``single_*`` recalls.
        """
        self.model.eval()

        eval_dict = {'recalled_0.5': 0, 'recalled_0.7': 0}
        total_loss = count = 0.0  # count = number of foreground proposals

        cls = []
        iou = []
        offset = []
        # unique (sample_id, box_id) pairs recalled at each IoU threshold
        recalled_05 = set()
        recalled_07 = set()
        # (sample_id, box_id) of every proposal tied to a GT box (box_id != -1)
        all_gt_list = []

        for data in tqdm.tqdm(d_loader, total=len(d_loader), leave=False, desc='val'):
            self.optimizer.zero_grad()

            sample_id = int(data['sample_id'].tolist()[0])
            box_id = int(data['box_id'].tolist()[0])
            if box_id != -1:
                all_gt_list.append((sample_id, box_id))
            # proposals without any points are not evaluated
            if data['cur_box_point'].shape[1] == 0:
                continue
            gt_boxes = torch.from_numpy(data['gt_boxes']).cuda().float()
            loss, tb_dict, disp_dict, visual_dict = self.model_fn_eval(self.model, data)
            total_loss += loss.item()

            for k, v in tb_dict.items():
                eval_dict[k] = eval_dict.get(k, 0) + v

            cls.append(visual_dict['rcnn_cls'])
            fg_mask = data['cls'].float().view(-1) > 0
            if fg_mask != 0:
                pred_boxes3d = visual_dict['pred_boxes3d']

                # IoU of each box with its own GT: diagonal of the pairwise IoU matrix
                _, iou3d = iou3d_utils.boxes_iou3d_gpu(pred_boxes3d.squeeze(1), gt_boxes.squeeze(1))
                eye = torch.from_numpy(np.arange(0, iou3d.shape[0]).reshape(-1, 1)).long().cuda()
                iou3d = torch.gather(iou3d, 1, eye)

                eval_dict['recalled_0.5'] += torch.sum((iou3d > 0.5).long()).item()
                eval_dict['recalled_0.7'] += torch.sum((iou3d > 0.7).long()).item()
                if iou3d > 0.5:
                    recalled_05.add((sample_id, box_id))
                if iou3d > 0.7:
                    recalled_07.add((sample_id, box_id))

                iou.append(iou3d.view(-1, 1))
                pred_boxes3d = pred_boxes3d.view(-1, 7)
                gt_boxes = gt_boxes[:, :, :7].view(-1, 7)
                offset.append(pred_boxes3d - gt_boxes)
                count += 1

        # statistics this epoch
        for k in eval_dict:
            eval_dict[k] = eval_dict[k] / max(count, 1)

        # A GT counts once if any of its proposals reached the IoU threshold.
        single_gt = set(all_gt_list)
        denom = float(max(len(single_gt), 1))
        eval_dict['single_recalled_0.5'] = len(single_gt & recalled_05) / denom
        eval_dict['single_recalled_0.7'] = len(single_gt & recalled_07) / denom

        # torch.cat of an empty list raises, so only log what was collected
        if cls:
            self.tb_log.add_histogram('val_cls', torch.cat(cls, dim=0), self.it)
        if count > 0:
            iou = torch.cat(iou, dim=0).reshape(-1, 1)
            offset = torch.cat(offset, dim=0).reshape(-1, 7)
            self.tb_log.add_histogram('val_iou', iou.view(-1), self.it)
            for col, name in enumerate(('x', 'y', 'z', 'h', 'w', 'l', 'ry')):
                self.tb_log.add_histogram('val_%s_offset' % name, offset[:, col], self.it)

        print('Recall_0.5 %.4f.' % eval_dict['recalled_0.5'])
        print('Recall_0.7 %.4f.' % eval_dict['recalled_0.7'])
        print('Single_Recall_0.5 %.4f.' % eval_dict['single_recalled_0.5'])
        print('Single_Recall_0.7 %.4f.' % eval_dict['single_recalled_0.7'])

        return total_loss / max(count, 1), eval_dict
Exemplo n.º 16
0
def eval_one_epoch_joint(model, dataloader, epoch_id, result_dir, logger):
    """Evaluate RPN + RCNN jointly for one epoch and save Argoverse-format boxes.

    For every batch the model is run, the RCNN regression output is decoded
    into 3D boxes around each RoI, and (unless ``args.test``) recall against
    the ground truth is accumulated for both the refined boxes and the raw
    RoIs. Refined boxes whose normalized score exceeds
    ``cfg.RCNN.SCORE_THRESH`` go through rotated BEV NMS and are written to
    ``<result_dir>/final_result/data`` with ``save_argo_format``. With
    ``args.save_result`` the unfiltered RoIs, the unfiltered refined boxes and
    the per-point RPN output are saved as well.

    Args:
        model: network whose output dict provides 'rois', 'roi_scores_raw',
            'seg_result', 'rcnn_cls', 'rcnn_reg', 'rpn_cls' and 'backbone_xyz'.
        dataloader: iterable of batch dicts; ``dataloader.dataset`` must have
            a ``lidar_idx_table`` mapping '%06d' sample ids to file stems.
        epoch_id: label used only in log messages.
        result_dir: root directory for all output files.
        logger: logger for progress and summary messages.

    Returns:
        dict with the mean RPN segmentation IoU per batch, the average number
        of final detections per dataset sample, and RoI / refined-box recall
        at each IoU threshold. 'rcnn_cls_acc' and 'rcnn_cls_acc_refined' are
        not accumulated by this function and are always 0.
    """
    np.random.seed(666)

    # Mean size of the first class from the config; used as the anchor size
    # when decoding RCNN box regressions.
    MEAN_SIZE = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()

    mode = 'TEST' if args.test else 'EVAL'

    # Final (score-thresholded + NMS) detections are always written.
    final_output_dir = os.path.join(result_dir, 'final_result', 'data')
    os.makedirs(final_output_dir, exist_ok=True)

    # Intermediate outputs are only written when args.save_result is set.
    if args.save_result:
        roi_output_dir = os.path.join(result_dir, 'roi_result', 'data')
        refine_output_dir = os.path.join(result_dir, 'refine_result', 'data')
        rpn_output_dir = os.path.join(result_dir, 'rpn_result', 'data')
        os.makedirs(rpn_output_dir, exist_ok=True)
        os.makedirs(roi_output_dir, exist_ok=True)
        os.makedirs(refine_output_dir, exist_ok=True)
        # Inverse of argo_to_kitti, applied to the backbone points before they
        # are saved (presumably KITTI-style frame -> Argoverse frame; confirm).
        argo_from_kitti = np.linalg.inv(argo_to_kitti)

    logger.info('---- EPOCH %s JOINT EVALUATION ----' % epoch_id)
    logger.info('==> Output file: %s' % result_dir)
    model.eval()

    # IoU thresholds for recall; index 3 (0.7) is shown in the progress bar.
    thresh_list = [0.1, 0.3, 0.5, 0.7, 0.9]
    total_recalled_bbox_list, total_gt_bbox = [0] * 5, 0
    total_roi_recalled_bbox_list = [0] * 5
    dataset = dataloader.dataset
    lidar_idx_table = dataset.lidar_idx_table
    # NOTE: total_cls_acc / total_cls_acc_refined are never updated below;
    # they are kept so the returned dict keeps its keys.
    cnt = final_total = total_cls_acc = total_cls_acc_refined = total_rpn_iou = 0

    progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval')

    for data in dataloader:
        cnt += 1  # counts batches, not samples
        sample_id, pts_input = data['sample_id'], data['pts_input']
        batch_size = len(sample_id)
        inputs = torch.from_numpy(pts_input).cuda(non_blocking=True).float()
        input_data = {'pts_input': inputs}

        # model inference
        ret_dict = model(input_data)

        roi_scores_raw = ret_dict['roi_scores_raw']  # (B, M)
        roi_boxes3d = ret_dict['rois']  # (B, M, 7)
        seg_result = ret_dict['seg_result'].long()  # (B, N)

        rcnn_cls = ret_dict['rcnn_cls'].view(
            batch_size, -1, ret_dict['rcnn_cls'].shape[1])  # (B, M, C_cls)
        rcnn_reg = ret_dict['rcnn_reg'].view(
            batch_size, -1, ret_dict['rcnn_reg'].shape[1])  # (B, M, C)

        # bounding box regression
        anchor_size = MEAN_SIZE
        if cfg.RCNN.SIZE_RES_ON_ROI:
            # Explicit raise instead of `assert False`, which `python -O` strips.
            raise NotImplementedError(
                'cfg.RCNN.SIZE_RES_ON_ROI is not supported in joint evaluation')

        pred_boxes3d = decode_bbox_target(
            roi_boxes3d.view(-1, 7),
            rcnn_reg.view(-1, rcnn_reg.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)

        # scoring: raw_scores and norm_scores are (B, M, 1) in both branches
        if rcnn_cls.shape[2] == 1:
            raw_scores = rcnn_cls
            norm_scores = torch.sigmoid(raw_scores)
        else:
            # Multi-class head: keep the logit / softmax probability of the
            # arg-max class of each RoI. The class axis of rcnn_cls is dim 2.
            pred_classes = torch.argmax(rcnn_cls, dim=2, keepdim=True)
            cls_norm_scores = F.softmax(rcnn_cls, dim=2)
            raw_scores = torch.gather(rcnn_cls, 2, pred_classes)
            norm_scores = torch.gather(cls_norm_scores, 2, pred_classes)

        # evaluation
        if not args.test:
            if not cfg.RPN.FIXED:
                rpn_cls_label = torch.from_numpy(data['rpn_cls_label']).cuda(
                    non_blocking=True).long()
                # Foreground-segmentation IoU over the whole batch. It does not
                # depend on the sample index, so it is computed once per batch;
                # dividing the total by the batch count then gives the mean.
                fg_mask = rpn_cls_label > 0
                correct = ((seg_result == rpn_cls_label) & fg_mask).sum().float()
                union = (fg_mask.sum().float() + (seg_result > 0).sum().float()
                         - correct)
                rpn_iou = correct / torch.clamp(union, min=1.0)
                total_rpn_iou += rpn_iou.item()

            gt_boxes3d = filtrate_gtboxes(data['gt_boxes3d'])

            for k in range(batch_size):
                # drop trailing all-zero padding rows of this sample's gt boxes
                cur_gt_boxes3d = gt_boxes3d[k]
                tmp_idx = len(cur_gt_boxes3d) - 1
                while tmp_idx >= 0 and cur_gt_boxes3d[tmp_idx].sum() == 0:
                    tmp_idx -= 1
                if tmp_idx < 0:
                    continue  # this sample has no gt boxes

                cur_gt_boxes3d = torch.from_numpy(
                    cur_gt_boxes3d[:tmp_idx + 1]).cuda(non_blocking=True).float()

                # recall of the refined boxes: best IoU for every gt box
                iou3d = iou3d_utils.boxes_iou3d_gpu(pred_boxes3d[k],
                                                    cur_gt_boxes3d)
                gt_max_iou, _ = iou3d.max(dim=0)
                for idx, thresh in enumerate(thresh_list):
                    total_recalled_bbox_list[idx] += (
                        gt_max_iou > thresh).sum().item()
                total_gt_bbox += cur_gt_boxes3d.shape[0]

                # recall of the RoIs before refinement
                iou3d_in = iou3d_utils.boxes_iou3d_gpu(roi_boxes3d[k],
                                                       cur_gt_boxes3d)
                gt_max_iou_in, _ = iou3d_in.max(dim=0)
                for idx, thresh in enumerate(thresh_list):
                    total_roi_recalled_bbox_list[idx] += (
                        gt_max_iou_in > thresh).sum().item()

        disp_dict = {
            'mode': mode,
            'recall': '%d/%d' % (total_recalled_bbox_list[3], total_gt_bbox)
        }
        progress_bar.set_postfix(disp_dict)
        progress_bar.update()

        if args.save_result:
            # save roi and refine results
            roi_boxes3d_np = roi_boxes3d.cpu().numpy()
            pred_boxes3d_np = pred_boxes3d.cpu().numpy()
            roi_scores_raw_np = roi_scores_raw.cpu().numpy()
            raw_scores_np = raw_scores.cpu().numpy()

            rpn_cls_np = ret_dict['rpn_cls'].cpu().numpy()
            # Transform the points of every sample, not only sample 0:
            # x @ M.T == (M @ x.T).T, broadcast over the batch axis.
            rpn_xyz_np = (ret_dict['backbone_xyz'].cpu().numpy()
                          @ argo_from_kitti.T)  # (B, N, 3)
            seg_result_np = seg_result.cpu().numpy()

            output_data = np.concatenate(
                (rpn_xyz_np, rpn_cls_np.reshape(batch_size, -1, 1),
                 seg_result_np.reshape(batch_size, -1, 1)),
                axis=2)

            for k in range(batch_size):
                cur_sample_id = sample_id[k]
                save_argo_format(cur_sample_id, roi_boxes3d_np[k],
                                 roi_output_dir, roi_scores_raw_np[k],
                                 lidar_idx_table)
                save_argo_format(cur_sample_id, pred_boxes3d_np[k],
                                 refine_output_dir, raw_scores_np[k],
                                 lidar_idx_table)
                output_file = os.path.join(
                    rpn_output_dir,
                    lidar_idx_table['%06d' % cur_sample_id] + '.npy')
                # Save only this sample's rows; k:k+1 keeps the (1, N, 5)
                # layout the file had when batch_size == 1.
                np.save(output_file,
                        output_data[k:k + 1].astype(np.float32))

        # scores thresh
        inds = norm_scores > cfg.RCNN.SCORE_THRESH

        for k in range(batch_size):
            cur_inds = inds[k].view(-1)
            if cur_inds.sum() == 0:
                continue

            pred_boxes3d_selected = pred_boxes3d[k, cur_inds]
            raw_scores_selected = raw_scores[k, cur_inds]

            # rotated NMS in bird's-eye view
            boxes_bev_selected = kitti_utils.boxes3d_to_bev_torch(
                pred_boxes3d_selected)
            keep_idx = iou3d_utils.nms_gpu(boxes_bev_selected,
                                           raw_scores_selected,
                                           cfg.RCNN.NMS_THRESH).view(-1)
            pred_boxes3d_selected = pred_boxes3d_selected[keep_idx].cpu().numpy()
            scores_selected = raw_scores_selected[keep_idx].cpu().numpy()

            final_total += pred_boxes3d_selected.shape[0]
            save_argo_format(sample_id[k], pred_boxes3d_selected,
                             final_output_dir, scores_selected,
                             lidar_idx_table)

    progress_bar.close()

    ret_dict = {}
    logger.info(
        '-------------------performance of epoch %s---------------------' %
        epoch_id)
    logger.info(str(datetime.now()))

    avg_rpn_iou = (total_rpn_iou / max(cnt, 1.0))
    avg_cls_acc = (total_cls_acc / max(cnt, 1.0))
    avg_cls_acc_refined = (total_cls_acc_refined / max(cnt, 1.0))
    avg_det_num = (final_total / max(len(dataset), 1.0))
    logger.info('final average detections: %.3f' % avg_det_num)
    logger.info('final average rpn_iou refined: %.3f' % avg_rpn_iou)
    logger.info('final average cls acc: %.3f' % avg_cls_acc)
    logger.info('final average cls acc refined: %.3f' % avg_cls_acc_refined)
    ret_dict['rpn_iou'] = avg_rpn_iou
    ret_dict['rcnn_cls_acc'] = avg_cls_acc
    ret_dict['rcnn_cls_acc_refined'] = avg_cls_acc_refined
    ret_dict['rcnn_avg_num'] = avg_det_num

    for idx, thresh in enumerate(thresh_list):
        cur_roi_recall = total_roi_recalled_bbox_list[idx] / max(
            total_gt_bbox, 1.0)
        logger.info('total roi bbox recall(thresh=%.3f): %d / %d = %f' %
                    (thresh, total_roi_recalled_bbox_list[idx], total_gt_bbox,
                     cur_roi_recall))
        ret_dict['rpn_recall(thresh=%.2f)' % thresh] = cur_roi_recall

    for idx, thresh in enumerate(thresh_list):
        cur_recall = total_recalled_bbox_list[idx] / max(total_gt_bbox, 1.0)
        logger.info(
            'total bbox recall(thresh=%.3f): %d / %d = %f' %
            (thresh, total_recalled_bbox_list[idx], total_gt_bbox, cur_recall))
        ret_dict['rcnn_recall(thresh=%.2f)' % thresh] = cur_recall

    logger.info('result is saved to: %s' % result_dir)
    return ret_dict
Exemplo n.º 17
0
def eval_one_epoch_rpn(model, dataloader, epoch_id, result_dir, logger):
    """Evaluate the RPN stage for one epoch.

    For every sample the RPN foreground score is thresholded at
    ``cfg.RPN.SCORE_THRESH`` to get a per-point segmentation, and proposals
    are generated with ``model.rpn.proposal_layer``. Unless ``args.test`` is
    set, proposal recall and segmentation IoU are accumulated against the
    labels. With ``args.save_result`` or ``args.save_rpn_feature`` the
    per-point segmentation (``seg_result/<id>.npy``) and the proposals
    (KITTI format, ``detections/data``) are written under ``result_dir``.
    ``args.save_rpn_feature`` additionally saves backbone features.

    Args:
        model: network with an ``rpn.proposal_layer`` whose output dict
            provides 'rpn_cls', 'rpn_reg', 'backbone_xyz', 'backbone_features'.
        dataloader: iterable of batch dicts; ``dataloader.dataset`` must
            provide ``get_calib`` and ``get_image_shape`` when saving results.
        epoch_id: label used only in log messages.
        result_dir: root directory for all output files.
        logger: logger for progress and summary messages.

    Returns:
        dict with 'max_obj_num' (largest number of non-padding gt boxes in a
        single sample), 'rpn_iou' (mean per-sample segmentation IoU) and
        proposal recall at each IoU threshold.
    """
    np.random.seed(1024)
    mode = 'TEST' if args.test else 'EVAL'

    if args.save_rpn_feature:
        kitti_features_dir = os.path.join(result_dir, 'features')
        os.makedirs(kitti_features_dir, exist_ok=True)

    save_outputs = args.save_result or args.save_rpn_feature
    if save_outputs:
        kitti_output_dir = os.path.join(result_dir, 'detections', 'data')
        seg_output_dir = os.path.join(result_dir, 'seg_result')
        os.makedirs(kitti_output_dir, exist_ok=True)
        os.makedirs(seg_output_dir, exist_ok=True)

    logger.info('---- EPOCH %s RPN EVALUATION ----' % epoch_id)
    model.eval()

    # IoU thresholds for recall; index 3 (0.7) is shown in the progress bar.
    thresh_list = [0.1, 0.3, 0.5, 0.7, 0.9]
    total_recalled_bbox_list, total_gt_bbox = [0] * 5, 0
    dataset = dataloader.dataset
    cnt = max_num = rpn_iou_avg = 0  # cnt counts samples, not batches

    progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval')

    for data in dataloader:
        sample_id_list, pts_rect, pts_features, pts_input = \
            data['sample_id'], data['pts_rect'], data['pts_features'], data['pts_input']
        cnt += len(sample_id_list)

        if not args.test:
            gt_boxes3d = data['gt_boxes3d']  # (B, M, 7+)
            rpn_cls_label = torch.from_numpy(data['rpn_cls_label']).cuda(
                non_blocking=True).long()
            # Only move gt boxes to the GPU when the batch has any columns of
            # boxes; an empty (B, 0, ...) array stays a numpy array.
            if gt_boxes3d.shape[1] > 0:
                gt_boxes3d = torch.from_numpy(gt_boxes3d).cuda(
                    non_blocking=True).float()

        inputs = torch.from_numpy(pts_input).cuda(non_blocking=True).float()
        input_data = {'pts_input': inputs}

        # model inference
        ret_dict = model(input_data)
        rpn_cls, rpn_reg = ret_dict['rpn_cls'], ret_dict['rpn_reg']
        backbone_xyz = ret_dict['backbone_xyz']
        backbone_features = ret_dict['backbone_features']

        rpn_scores_raw = rpn_cls[:, :, 0]
        rpn_scores = torch.sigmoid(rpn_scores_raw)
        seg_result = (rpn_scores > cfg.RPN.SCORE_THRESH).long()

        # proposal layer
        rois, roi_scores_raw = model.rpn.proposal_layer(
            rpn_scores_raw, rpn_reg, backbone_xyz)  # (B, M, 7)
        batch_size = rois.shape[0]

        # calculate recall and save results to file
        for bs_idx in range(batch_size):
            cur_sample_id = sample_id_list[bs_idx]
            cur_scores_raw = roi_scores_raw[bs_idx]  # (N)
            cur_boxes3d = rois[bs_idx]  # (N, 7)
            cur_seg_result = seg_result[bs_idx]
            cur_pts_rect = pts_rect[bs_idx]

            if not args.test:
                cur_rpn_cls_label = rpn_cls_label[bs_idx]
                cur_gt_boxes3d = gt_boxes3d[bs_idx]

                # Strip trailing all-zero padding rows. `>= 0` so that a
                # sample made only of padding ends up with zero gt boxes
                # instead of keeping one zero box that inflates total_gt_bbox.
                k = len(cur_gt_boxes3d) - 1
                while k >= 0 and cur_gt_boxes3d[k].sum() == 0:
                    k -= 1
                cur_gt_boxes3d = cur_gt_boxes3d[:k + 1]
                num_gt = cur_gt_boxes3d.shape[0]
                max_num = max(max_num, num_gt)

                if num_gt > 0:
                    iou3d = iou3d_utils.boxes_iou3d_gpu(
                        cur_boxes3d, cur_gt_boxes3d[:, 0:7])
                    # best proposal IoU for every gt box
                    gt_max_iou, _ = iou3d.max(dim=0)

                    for idx, thresh in enumerate(thresh_list):
                        total_recalled_bbox_list[idx] += (
                            gt_max_iou > thresh).sum().item()
                    total_gt_bbox += num_gt

                # per-sample foreground segmentation IoU
                fg_mask = cur_rpn_cls_label > 0
                correct = ((cur_seg_result == cur_rpn_cls_label)
                           & fg_mask).sum().float()
                union = (fg_mask.sum().float()
                         + (cur_seg_result > 0).sum().float() - correct)
                rpn_iou = correct / torch.clamp(union, min=1.0)
                rpn_iou_avg += rpn_iou.item()

            if args.save_rpn_feature:
                save_rpn_features(
                    seg_result[bs_idx].float().cpu().numpy(),
                    rpn_scores_raw[bs_idx].float().cpu().numpy(),
                    pts_features[bs_idx], backbone_xyz[bs_idx].cpu().numpy(),
                    backbone_features[bs_idx].cpu().numpy().transpose(1, 0),
                    kitti_features_dir, cur_sample_id)

            if save_outputs:
                # per-point [x, y, z, (gt label,) predicted label]
                cur_pred_cls = cur_seg_result.cpu().numpy()
                output_file = os.path.join(seg_output_dir,
                                           '%06d.npy' % cur_sample_id)
                if not args.test:
                    cur_gt_cls = cur_rpn_cls_label.cpu().numpy()
                    output_data = np.concatenate(
                        (cur_pts_rect.reshape(-1, 3),
                         cur_gt_cls.reshape(-1, 1),
                         cur_pred_cls.reshape(-1, 1)),
                        axis=1)
                else:
                    output_data = np.concatenate(
                        (cur_pts_rect.reshape(-1, 3),
                         cur_pred_cls.reshape(-1, 1)),
                        axis=1)

                np.save(output_file, output_data.astype(np.float16))

                # save proposals in KITTI format; boxes and scores both as
                # numpy arrays
                calib = dataset.get_calib(cur_sample_id)
                image_shape = dataset.get_image_shape(cur_sample_id)
                save_kitti_format(cur_sample_id, calib,
                                  cur_boxes3d.cpu().numpy(),
                                  kitti_output_dir,
                                  cur_scores_raw.cpu().numpy(),
                                  image_shape)

        disp_dict = {
            'mode': mode,
            'recall': '%d/%d' % (total_recalled_bbox_list[3], total_gt_bbox),
            'rpn_iou': rpn_iou_avg / max(cnt, 1.0)
        }
        progress_bar.set_postfix(disp_dict)
        progress_bar.update()

    progress_bar.close()

    # guarded like the progress-bar value so an empty loader does not crash
    mean_rpn_iou = rpn_iou_avg / max(cnt, 1.0)

    logger.info(str(datetime.now()))
    logger.info(
        '-------------------performance of epoch %s---------------------' %
        epoch_id)
    logger.info('max number of objects: %d' % max_num)
    logger.info('rpn iou avg: %f' % mean_rpn_iou)

    ret_dict = {'max_obj_num': max_num, 'rpn_iou': mean_rpn_iou}

    for idx, thresh in enumerate(thresh_list):
        cur_recall = total_recalled_bbox_list[idx] / max(total_gt_bbox, 1.0)
        logger.info(
            'total bbox recall(thresh=%.3f): %d / %d = %f' %
            (thresh, total_recalled_bbox_list[idx], total_gt_bbox, cur_recall))
        ret_dict['rpn_recall(thresh=%.2f)' % thresh] = cur_recall
    logger.info('result is saved to: %s' % result_dir)

    return ret_dict
Exemplo n.º 18
0
    def sample_rois_for_rcnn(self, roi_boxes3d, gt_boxes3d):
        """Sample a fixed number of fg/bg RoIs per image for RCNN training.

        Every RoI is matched to the gt box it overlaps most (3D IoU). RoIs with
        IoU >= min(REG_FG_THRESH, CLS_FG_THRESH) are foreground; RoIs below
        CLS_BG_THRESH are background, split into easy (< CLS_BG_THRESH_LO) and
        hard (>= CLS_BG_THRESH_LO). Up to round(FG_RATIO * ROI_PER_IMAGE)
        foreground RoIs are drawn and the rest is filled with background RoIs
        (or with foreground RoIs drawn with replacement if there is no
        background). The sampled RoIs are then passed through
        ``aug_roi_by_noise_torch``.

        :param roi_boxes3d: (B, M, 7)
        :param gt_boxes3d: (B, N, 8) [x, y, z, h, w, l, ry, cls]; trailing
            all-zero rows are treated as padding
        :return
            batch_rois: (B, ROI_PER_IMAGE, 7)
            batch_gt_of_rois: (B, ROI_PER_IMAGE, 8) gt box matched to each RoI
            batch_roi_iou: (B, ROI_PER_IMAGE) IoU returned by
                aug_roi_by_noise_torch for each RoI
            cls_list: (B, ROI_PER_IMAGE) long tensor, column 7 (cls) of the
                matched gt box, in the same order as batch_rois
        :raises NotImplementedError: if an image has neither fg nor bg RoIs
        """
        batch_size = roi_boxes3d.size(0)

        fg_rois_per_image = int(
            np.round(cfg.RCNN.FG_RATIO * cfg.RCNN.ROI_PER_IMAGE))

        batch_rois = gt_boxes3d.new(batch_size, cfg.RCNN.ROI_PER_IMAGE,
                                    7).zero_()
        batch_gt_of_rois = gt_boxes3d.new(batch_size, cfg.RCNN.ROI_PER_IMAGE,
                                          8).zero_()
        batch_roi_iou = gt_boxes3d.new(batch_size,
                                       cfg.RCNN.ROI_PER_IMAGE).zero_()

        # per-image class labels of the sampled RoIs, each (1, ROI_PER_IMAGE)
        cls_list = []

        for idx in range(batch_size):
            cur_roi, cur_gt = roi_boxes3d[idx], gt_boxes3d[idx]

            # strip trailing all-zero padding rows from the gt boxes
            k = cur_gt.__len__() - 1
            while cur_gt[k].sum() == 0:
                k -= 1
            cur_gt = cur_gt[:k + 1]

            iou3d = iou3d_utils.boxes_iou3d_gpu(cur_roi,
                                                cur_gt[:, 0:7])  # (M, N)
            # IoU with, and index of, the best-matching gt box for every RoI
            max_overlaps, gt_assignment = torch.max(iou3d, dim=1)

            # split RoIs into fg, easy bg and hard bg by their best IoU
            fg_thresh = min(cfg.RCNN.REG_FG_THRESH, cfg.RCNN.CLS_FG_THRESH)
            fg_inds = torch.nonzero((max_overlaps >= fg_thresh)).view(-1)

            # TODO: this will mix the fg and bg when CLS_BG_THRESH_LO < iou < CLS_BG_THRESH
            # fg_inds = torch.cat((fg_inds, roi_assignment), dim=0)  # consider the roi which has max_iou with gt as fg

            easy_bg_inds = torch.nonzero(
                (max_overlaps < cfg.RCNN.CLS_BG_THRESH_LO)).view(-1)
            hard_bg_inds = torch.nonzero(
                (max_overlaps < cfg.RCNN.CLS_BG_THRESH)
                & (max_overlaps >= cfg.RCNN.CLS_BG_THRESH_LO)).view(-1)

            fg_num_rois = fg_inds.numel()
            bg_num_rois = hard_bg_inds.numel() + easy_bg_inds.numel()

            if fg_num_rois > 0 and bg_num_rois > 0:
                # sample up to fg_rois_per_image fg RoIs without replacement
                fg_rois_per_this_image = min(fg_rois_per_image, fg_num_rois)
                rand_num = torch.from_numpy(np.random.permutation(
                    fg_num_rois)).type_as(gt_boxes3d).long()
                fg_inds = fg_inds[rand_num[:fg_rois_per_this_image]]

                # fill the remaining slots with bg RoIs
                bg_rois_per_this_image = cfg.RCNN.ROI_PER_IMAGE - fg_rois_per_this_image
                bg_inds = self.sample_bg_inds(hard_bg_inds, easy_bg_inds,
                                              bg_rois_per_this_image)

            elif fg_num_rois > 0 and bg_num_rois == 0:
                # no bg available: fill every slot with fg RoIs drawn with
                # replacement
                rand_num = np.floor(
                    np.random.rand(cfg.RCNN.ROI_PER_IMAGE) * fg_num_rois)
                rand_num = torch.from_numpy(rand_num).type_as(
                    gt_boxes3d).long()
                fg_inds = fg_inds[rand_num]
                fg_rois_per_this_image = cfg.RCNN.ROI_PER_IMAGE
                bg_rois_per_this_image = 0
            elif bg_num_rois > 0 and fg_num_rois == 0:
                # no fg available: every slot is a bg RoI
                bg_rois_per_this_image = cfg.RCNN.ROI_PER_IMAGE
                bg_inds = self.sample_bg_inds(hard_bg_inds, easy_bg_inds,
                                              bg_rois_per_this_image)

                fg_rois_per_this_image = 0
            else:
                # No RoI falls in either the fg or the bg IoU range. Fail loudly
                # instead of stopping in a debugger (which hangs batch jobs).
                raise NotImplementedError(
                    'sample_rois_for_rcnn: image %d has no fg or bg RoIs '
                    '(fg=%d, bg=%d)' % (idx, fg_num_rois, bg_num_rois))

            # augment the rois by noise; fg entries come first, then bg
            roi_list, roi_iou_list, roi_gt_list = [], [], []
            if fg_rois_per_this_image > 0:
                fg_rois_src = cur_roi[fg_inds]
                gt_of_fg_rois = cur_gt[gt_assignment[fg_inds]]
                iou3d_src = max_overlaps[fg_inds]
                fg_rois, fg_iou3d = self.aug_roi_by_noise_torch(
                    fg_rois_src,
                    gt_of_fg_rois,
                    iou3d_src,
                    aug_times=cfg.RCNN.ROI_FG_AUG_TIMES)
                roi_list.append(fg_rois)
                roi_iou_list.append(fg_iou3d)
                roi_gt_list.append(gt_of_fg_rois)

            if bg_rois_per_this_image > 0:
                bg_rois_src = cur_roi[bg_inds]
                gt_of_bg_rois = cur_gt[gt_assignment[bg_inds]]
                iou3d_src = max_overlaps[bg_inds]
                aug_times = 1 if cfg.RCNN.ROI_FG_AUG_TIMES > 0 else 0
                bg_rois, bg_iou3d = self.aug_roi_by_noise_torch(
                    bg_rois_src, gt_of_bg_rois, iou3d_src, aug_times=aug_times)
                roi_list.append(bg_rois)
                roi_iou_list.append(bg_iou3d)
                roi_gt_list.append(gt_of_bg_rois)

            rois = torch.cat(roi_list, dim=0)
            iou_of_rois = torch.cat(roi_iou_list, dim=0)
            gt_of_rois = torch.cat(roi_gt_list, dim=0)

            # Class labels come from this image's gt_of_rois (same fg-then-bg
            # order). Referencing gt_of_fg_rois / gt_of_bg_rois directly would
            # reuse a previous image's tensors (or fail) when one group is empty.
            cls_list.append(gt_of_rois[:, 7].unsqueeze(dim=0))

            batch_rois[idx] = rois
            batch_gt_of_rois[idx] = gt_of_rois
            batch_roi_iou[idx] = iou_of_rois

        # one row per image: works for any batch size, not only 4
        cls_list = torch.cat(cls_list, dim=0).long()

        return batch_rois, batch_gt_of_rois, batch_roi_iou, cls_list
Exemplo n.º 19
0
def eval_one_epoch_joint(model, dataloader, epoch_id, result_dir, logger):
    np.random.seed(666)
    MEAN_SIZE = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
    mode = 'TEST' if args.test else 'EVAL'

    final_output_dir = os.path.join(result_dir, 'final_result', 'data')
    os.makedirs(final_output_dir, exist_ok=True)

    if args.save_result:
        roi_output_dir = os.path.join(result_dir, 'roi_result', 'data')
        refine_output_dir = os.path.join(result_dir, 'refine_result', 'data')
        rpn_output_dir = os.path.join(result_dir, 'rpn_result', 'data')
        os.makedirs(rpn_output_dir, exist_ok=True)
        os.makedirs(roi_output_dir, exist_ok=True)
        os.makedirs(refine_output_dir, exist_ok=True)

    logger.info('---- EPOCH %s JOINT EVALUATION ----' % epoch_id)
    logger.info('==> Output file: %s' % result_dir)
    model.eval()

    thresh_list = [0.1, 0.3, 0.5, 0.7, 0.9]
    total_recalled_bbox_list, total_gt_bbox = [0] * 5, 0
    total_roi_recalled_bbox_list = [0] * 5
    dataset = dataloader.dataset
    cnt = final_total = total_cls_acc = total_cls_acc_refined = total_rpn_iou = 0

    progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval')
    for data in dataloader:
        cnt += 1
        sample_id, pts_rect, pts_features, pts_input = \
            data['sample_id'], data['pts_rect'], data['pts_features'], data['pts_input']
        batch_size = len(sample_id)
        inputs = torch.from_numpy(pts_input).cuda(non_blocking=True).float()
        input_data = {'pts_input': inputs}

        # model inference
        ret_dict = model(input_data)

        roi_scores_raw = ret_dict['roi_scores_raw']  # (B, M)
        roi_boxes3d = ret_dict['rois']  # (B, M, 7)
        seg_result = ret_dict['seg_result'].long()  # (B, N)

        rcnn_cls = ret_dict['rcnn_cls'].view(batch_size, -1,
                                             ret_dict['rcnn_cls'].shape[1])
        rcnn_reg = ret_dict['rcnn_reg'].view(
            batch_size, -1, ret_dict['rcnn_reg'].shape[1])  # (B, M, C)

        # bounding box regression
        anchor_size = MEAN_SIZE
        if cfg.RCNN.SIZE_RES_ON_ROI:
            assert False

        pred_boxes3d = decode_bbox_target(
            roi_boxes3d.view(-1, 7),
            rcnn_reg.view(-1, rcnn_reg.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)

        # scoring
        if rcnn_cls.shape[2] == 1:
            raw_scores = rcnn_cls  # (B, M, 1)

            norm_scores = torch.sigmoid(raw_scores)
            pred_classes = (norm_scores > cfg.RCNN.SCORE_THRESH).long()
        else:
            pred_classes = torch.argmax(rcnn_cls, dim=1).view(-1)
            cls_norm_scores = F.softmax(rcnn_cls, dim=1)
            raw_scores = rcnn_cls[:, pred_classes]
            norm_scores = cls_norm_scores[:, pred_classes]

        # evaluation
        recalled_num = gt_num = rpn_iou = 0
        if not args.test:
            if not cfg.RPN.FIXED:
                rpn_cls_label, rpn_reg_label = data['rpn_cls_label'], data[
                    'rpn_reg_label']
                rpn_cls_label = torch.from_numpy(rpn_cls_label).cuda(
                    non_blocking=True).long()

            gt_boxes3d = data['gt_boxes3d']

            for k in range(batch_size):
                # calculate recall
                cur_gt_boxes3d = gt_boxes3d[k]
                tmp_idx = cur_gt_boxes3d.__len__() - 1

                while tmp_idx >= 0 and cur_gt_boxes3d[tmp_idx].sum() == 0:
                    tmp_idx -= 1

                if tmp_idx >= 0:
                    cur_gt_boxes3d = cur_gt_boxes3d[:tmp_idx + 1]

                    cur_gt_boxes3d = torch.from_numpy(cur_gt_boxes3d).cuda(
                        non_blocking=True).float()
                    iou3d = iou3d_utils.boxes_iou3d_gpu(
                        pred_boxes3d[k], cur_gt_boxes3d)
                    gt_max_iou, _ = iou3d.max(dim=0)
                    refined_iou, _ = iou3d.max(dim=1)

                    for idx, thresh in enumerate(thresh_list):
                        total_recalled_bbox_list[idx] += (gt_max_iou >
                                                          thresh).sum().item()
                    recalled_num += (gt_max_iou > 0.7).sum().item()
                    gt_num += cur_gt_boxes3d.shape[0]
                    total_gt_bbox += cur_gt_boxes3d.shape[0]

                    # original recall
                    iou3d_in = iou3d_utils.boxes_iou3d_gpu(
                        roi_boxes3d[k], cur_gt_boxes3d)
                    gt_max_iou_in, _ = iou3d_in.max(dim=0)

                    for idx, thresh in enumerate(thresh_list):
                        total_roi_recalled_bbox_list[idx] += (
                            gt_max_iou_in > thresh).sum().item()

                if not cfg.RPN.FIXED:
                    fg_mask = rpn_cls_label > 0
                    correct = ((seg_result == rpn_cls_label)
                               & fg_mask).sum().float()
                    union = fg_mask.sum().float() + (seg_result >
                                                     0).sum().float() - correct
                    rpn_iou = correct / torch.clamp(union, min=1.0)
                    total_rpn_iou += rpn_iou.item()

        disp_dict = {
            'mode': mode,
            'recall': '%d/%d' % (total_recalled_bbox_list[3], total_gt_bbox)
        }
        progress_bar.set_postfix(disp_dict)
        progress_bar.update()

        if args.save_result:
            # save roi and refine results
            roi_boxes3d_np = roi_boxes3d.cpu().numpy()
            pred_boxes3d_np = pred_boxes3d.cpu().numpy()
            roi_scores_raw_np = roi_scores_raw.cpu().numpy()
            raw_scores_np = raw_scores.cpu().numpy()

            rpn_cls_np = ret_dict['rpn_cls'].cpu().numpy()
            rpn_xyz_np = ret_dict['backbone_xyz'].cpu().numpy()
            seg_result_np = seg_result.cpu().numpy()
            output_data = np.concatenate(
                (rpn_xyz_np, rpn_cls_np.reshape(batch_size, -1, 1),
                 seg_result_np.reshape(batch_size, -1, 1)),
                axis=2)

            for k in range(batch_size):
                cur_sample_id = sample_id[k]
                calib = dataset.get_calib(cur_sample_id)
                image_shape = dataset.get_image_shape(cur_sample_id)
                save_kitti_format(cur_sample_id, calib, roi_boxes3d_np[k],
                                  roi_output_dir, roi_scores_raw_np[k],
                                  image_shape)
                save_kitti_format(cur_sample_id, calib, pred_boxes3d_np[k],
                                  refine_output_dir, raw_scores_np[k],
                                  image_shape)

                output_file = os.path.join(rpn_output_dir,
                                           '%06d.npy' % cur_sample_id)
                np.save(output_file, output_data.astype(np.float32))

        # scores thresh
        inds = norm_scores > cfg.RCNN.SCORE_THRESH

        for k in range(batch_size):
            cur_inds = inds[k].view(-1)
            if cur_inds.sum() == 0:
                continue

            pred_boxes3d_selected = pred_boxes3d[k, cur_inds]
            raw_scores_selected = raw_scores[k, cur_inds]
            norm_scores_selected = norm_scores[k, cur_inds]

            # NMS thresh
            # rotated nms
            boxes_bev_selected = kitti_utils.boxes3d_to_bev_torch(
                pred_boxes3d_selected)
            keep_idx = iou3d_utils.nms_gpu(boxes_bev_selected,
                                           raw_scores_selected,
                                           cfg.RCNN.NMS_THRESH).view(-1)
            pred_boxes3d_selected = pred_boxes3d_selected[keep_idx]
            scores_selected = raw_scores_selected[keep_idx]
            pred_boxes3d_selected, scores_selected = pred_boxes3d_selected.cpu(
            ).numpy(), scores_selected.cpu().numpy()

            cur_sample_id = sample_id[k]
            calib = dataset.get_calib(cur_sample_id)
            final_total += pred_boxes3d_selected.shape[0]
            image_shape = dataset.get_image_shape(cur_sample_id)
            save_kitti_format(cur_sample_id, calib, pred_boxes3d_selected,
                              final_output_dir, scores_selected, image_shape)

    progress_bar.close()
    # dump empty files
    split_file = os.path.join(dataset.imageset_dir, '..', '..', 'ImageSets',
                              dataset.split + '.txt')
    split_file = os.path.abspath(split_file)
    image_idx_list = [x.strip() for x in open(split_file).readlines()]
    empty_cnt = 0
    for k in range(image_idx_list.__len__()):
        cur_file = os.path.join(final_output_dir, '%s.txt' % image_idx_list[k])
        if not os.path.exists(cur_file):
            with open(cur_file, 'w') as temp_f:
                pass
            empty_cnt += 1
            logger.info('empty_cnt=%d: dump empty file %s' %
                        (empty_cnt, cur_file))

    ret_dict = {'empty_cnt': empty_cnt}

    logger.info(
        '-------------------performance of epoch %s---------------------' %
        epoch_id)
    logger.info(str(datetime.now()))

    avg_rpn_iou = (total_rpn_iou / max(cnt, 1.0))
    avg_cls_acc = (total_cls_acc / max(cnt, 1.0))
    avg_cls_acc_refined = (total_cls_acc_refined / max(cnt, 1.0))
    avg_det_num = (final_total / max(len(dataset), 1.0))
    logger.info('final average detections: %.3f' % avg_det_num)
    logger.info('final average rpn_iou refined: %.3f' % avg_rpn_iou)
    logger.info('final average cls acc: %.3f' % avg_cls_acc)
    logger.info('final average cls acc refined: %.3f' % avg_cls_acc_refined)
    ret_dict['rpn_iou'] = avg_rpn_iou
    ret_dict['rcnn_cls_acc'] = avg_cls_acc
    ret_dict['rcnn_cls_acc_refined'] = avg_cls_acc_refined
    ret_dict['rcnn_avg_num'] = avg_det_num

    for idx, thresh in enumerate(thresh_list):
        cur_roi_recall = total_roi_recalled_bbox_list[idx] / max(
            total_gt_bbox, 1.0)
        logger.info('total roi bbox recall(thresh=%.3f): %d / %d = %f' %
                    (thresh, total_roi_recalled_bbox_list[idx], total_gt_bbox,
                     cur_roi_recall))
        ret_dict['rpn_recall(thresh=%.2f)' % thresh] = cur_roi_recall

    for idx, thresh in enumerate(thresh_list):
        cur_recall = total_recalled_bbox_list[idx] / max(total_gt_bbox, 1.0)
        logger.info(
            'total bbox recall(thresh=%.3f): %d / %d = %f' %
            (thresh, total_recalled_bbox_list[idx], total_gt_bbox, cur_recall))
        ret_dict['rcnn_recall(thresh=%.2f)' % thresh] = cur_recall

    if cfg.TEST.SPLIT != 'test':
        logger.info('Averate Precision:')
        name_to_class = {'Car': 0, 'Pedestrian': 1, 'Cyclist': 2}
        ap_result_str, ap_dict = kitti_evaluate(
            dataset.label_dir,
            final_output_dir,
            label_split_file=split_file,
            current_class=name_to_class[cfg.CLASSES])
        logger.info(ap_result_str)
        ret_dict.update(ap_dict)

    logger.info('result is saved to: %s' % result_dir)
    return ret_dict
Exemplo n.º 20
0
def eval_one_epoch_joint(model, dataloader, epoch_id, result_dir, logger):
    np.random.seed(666)
    MEAN_SIZE = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
    mode = 'TEST' if args.test else 'EVAL'

    final_output_dir = os.path.join(result_dir, 'final_result', 'data')

    if os.path.exists(final_output_dir): shutil.rmtree(final_output_dir)
    os.makedirs(final_output_dir, exist_ok=True)

    logger.info('---- EPOCH %s JOINT EVALUATION ----' % epoch_id)
    logger.info('==> Output file: %s' % result_dir)
    model.eval()

    thresh_list = [0.1, 0.3, 0.5, 0.7, 0.9]
    total_recalled_bbox_list, total_gt_bbox = [0] * 5, 0
    total_roi_recalled_bbox_list = [0] * 5
    dataset = dataloader.dataset
    cnt = final_total = total_cls_acc = total_cls_acc_refined = total_rpn_iou = 0
    obj_num = 0
    progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval')

    iou_list = []
    iou_p_score_list = []
    rcnn_p_score_list = []
    prop_count = 0
    for data in dataloader:

        # Loading sample
        sample_id_list, pts_input = data['sample_id'], data['pts_input']
        sample_id = sample_id_list[0]
        cnt += len(sample_id_list)
        #if cnt < 118: continue
        #load label
        if not args.test:
            gt_boxes3d = data['gt_boxes3d']
            obj_num += gt_boxes3d.shape[1]
            # print(obj_num)
            if gt_boxes3d.shape[1] == 0:  # (B, M, 7)
                pass
            else:
                gt_boxes3d = gt_boxes3d

        # rpn model inference
        inputs = torch.from_numpy(pts_input).cuda(non_blocking=True).float()
        #inputs = inputs[:,torch.argsort(-inputs[0,:,2])]
        input_data = {'pts_input': inputs}
        ret_dict = model.rpn_forward(input_data)
        rpn_cls, rpn_reg = ret_dict['rpn_cls'], ret_dict['rpn_reg']
        rpn_backbone_xyz, rpn_backbone_features = ret_dict[
            'backbone_xyz'], ret_dict['backbone_features']

        # stage score parsing
        rpn_scores_raw = rpn_cls[:, :, 0]
        rpn_scores_norm = torch.sigmoid(rpn_cls[:, :, 0])
        rcnn_input_scores = rpn_scores_norm.view(-1).clone()
        inputs = inputs.view(-1, inputs.shape[-1])
        rpn_backbone_features = rpn_backbone_features.view(
            -1, rpn_backbone_features.shape[-2])
        rpn_backbone_xyz = rpn_backbone_xyz.view(-1,
                                                 rpn_backbone_xyz.shape[-1])

        # if VISUAL:
        #     order = torch.argsort(-rpn_scores_norm).view(-1)
        #     inputs = inputs.view(-1,inputs.shape[-1])[order]
        #     rpn_scores_norm = rpn_scores_norm.view(-1)[order]
        #     rpn_backbone_features = rpn_backbone_features.view(-1,rpn_backbone_features.shape[-1])[order]
        #
        #     norm_feature = F.normalize(rpn_backbone_features)
        #     similarity = norm_feature.mm(norm_feature.t())
        #
        #     inputs_plt = inputs.detach().cpu().numpy()
        #     scores_plt = rpn_scores_norm.detach().cpu().numpy()
        #     similarity_plt = similarity.detach().cpu().numpy()
        #
        #
        #     fig = plt.figure(figsize=(10, 10))
        #     plt.axes(facecolor='silver')
        #     plt.axis([-30,30,0,70])
        #     plt.title('point_regressed_center %06d'%sample_id)
        #     plt.scatter(inputs_plt[:, 0], inputs_plt[:, 2], s=15, c=scores_plt[:], edgecolor='none',
        #                 cmap=plt.get_cmap('rainbow'), alpha=1, marker='.', vmin=0, vmax=1)
        #     if args.test==False:
        #         gt_boxes3d = gt_boxes3d.reshape(-1,7)
        #         plt.scatter(gt_boxes3d[:, 0], gt_boxes3d[:, 2], s=200, c='blue',
        #                     alpha=0.5, marker='+', vmin=-1, vmax=1)
        #     plt.show()
        #
        #     for i in range(similarity_plt.shape[0]):
        #         fig = plt.figure(figsize=(10, 10))
        #         plt.axes(facecolor='silver')
        #         plt.axis([-30, 30, 0, 70])
        #         sm_plt = similarity_plt[i]
        #         plt.scatter(inputs_plt[i, 0].reshape(-1), inputs_plt[i, 2].reshape(-1), s=400, c='blue',
        #                     alpha=0.5, marker='+', vmin=-1, vmax=1)
        #         plt.scatter(inputs_plt[:, 0], inputs_plt[:, 2], s=15, c=(sm_plt[:]+scores_plt[:])/2, edgecolor='none',
        #                     cmap=plt.get_cmap('rainbow'), alpha=1, marker='.', vmin=0, vmax=1)
        #         plt.show()

        # thresh select and jump out
        # rpn_mask = rpn_scores_norm.view(-1) > cfg.RPN.SCORE_THRESH
        # if rpn_mask.float().sum() == 0: continue
        # rpn_scores_raw = rpn_scores_raw.view(-1)[rpn_mask]
        # rpn_scores_norm = rpn_scores_norm.view(-1)[rpn_mask]
        # rpn_reg = rpn_reg.view(-1, rpn_reg.shape[-1])[rpn_mask]
        # rpn_backbone_xyz = rpn_backbone_xyz.view(-1, rpn_backbone_xyz.shape[-1])[rpn_mask]

        # generate rois

        rpn_rois = decode_center_target(
            rpn_backbone_xyz,
            rpn_reg.view(-1, rpn_reg.shape[-1]),
            loc_scope=cfg.RPN.LOC_SCOPE,
            loc_bin_size=cfg.RPN.LOC_BIN_SIZE,
        ).view(-1, 3)
        rpn_reg_dist = (rpn_rois - rpn_backbone_xyz).clone()
        #similarity = torch.cosine_similarity(rpn_backbone_xyz[:, [0, 2]], rpn_reg_dist[:, [0, 2]], dim=1)

        # # thresh select and jump out
        rpn_mask = (rpn_scores_norm.view(-1) > cfg.RPN.SCORE_THRESH) & (
            rpn_reg_dist[:, [0, 2]].pow(2).sum(-1).sqrt() > 0.2)  #\
        #& (similarity > -0.7)
        if rpn_mask.float().sum() == 0: continue
        rpn_scores_raw = rpn_scores_raw.view(-1)[rpn_mask]
        rpn_scores_norm = rpn_scores_norm.view(-1)[rpn_mask]
        rpn_rois = rpn_rois[rpn_mask]
        rpn_backbone_xyz = rpn_backbone_xyz.view(
            -1, rpn_backbone_xyz.shape[-1])[rpn_mask]

        # radius NMS
        # sort by center score
        sort_points = torch.argsort(-rpn_scores_raw)
        rpn_rois = rpn_rois[sort_points]
        rpn_scores_norm = rpn_scores_norm[sort_points]
        rpn_scores_raw = rpn_scores_raw[sort_points]

        if rpn_rois.shape[0] > 1:
            keep_id = [0]
            prop_prop_distance = distance_2(rpn_rois[:, [0, 2]],
                                            rpn_rois[:, [0, 2]])
            for i in range(1, rpn_rois.shape[0]):
                #if torch.min(prop_prop_distance[:i, i], dim=-1)[0] > 0.3:
                if torch.min(prop_prop_distance[keep_id, i], dim=-1)[0] > 0.3:
                    keep_id.append(i)
            rpn_center = rpn_rois[keep_id][:, [0, 2]]
            rpn_scores_norm = rpn_scores_norm[keep_id]
            rpn_scores_raw = rpn_scores_raw[keep_id]

        else:
            rpn_center = rpn_rois[:, [0, 2]]
            rpn_scores_norm = rpn_scores_norm
            rpn_scores_raw = rpn_scores_raw

        # #rcnn input select:
        point_center_distance = distance_2(rpn_center, inputs[:, [0, 2]])
        cur_proposal_points_index = (torch.min(point_center_distance,
                                               dim=-1)[0] < 4.0)

        point_center_distance = point_center_distance[
            cur_proposal_points_index]
        inputs = inputs[cur_proposal_points_index]
        rcnn_input_scores = rcnn_input_scores.view(
            -1)[cur_proposal_points_index]

        if VISUAL:
            inputs_plt = inputs.detach().cpu().numpy()
            scores_plt = rcnn_input_scores.detach().cpu().numpy()
            # point_center= rpn_center[rpn_scores_norm > 0.5]
            # point_center_score = rpn_scores_norm[rpn_scores_norm > 0.5]
            point_center = rpn_center
            point_center_score = rpn_scores_norm
            fig = plt.figure(figsize=(10, 10))
            plt.axes(facecolor='silver')
            plt.axis([-30, 30, 0, 70])
            point_center_plt = point_center.cpu().numpy()
            plt.title('point_regressed_center %06d' % sample_id)
            plt.scatter(inputs_plt[:, 0],
                        inputs_plt[:, 2],
                        s=15,
                        c=scores_plt[:],
                        edgecolor='none',
                        cmap=plt.get_cmap('rainbow'),
                        alpha=1,
                        marker='.',
                        vmin=0,
                        vmax=1)
            if point_center.shape[0] > 0:
                plt.scatter(point_center_plt[:, 0],
                            point_center_plt[:, 1],
                            s=200,
                            c='white',
                            alpha=0.5,
                            marker='x',
                            vmin=-1,
                            vmax=1)
            if args.test == False:
                gt_boxes3d = gt_boxes3d.reshape(-1, 7)
                plt.scatter(gt_boxes3d[:, 0],
                            gt_boxes3d[:, 2],
                            s=200,
                            c='blue',
                            alpha=0.5,
                            marker='+',
                            vmin=-1,
                            vmax=1)
            plt.savefig('../visual/rpn.jpg')

        # RCNN stage
        box_list = []
        raw_score_list = []
        iou_score_list = []
        inputs[:, 1] -= 1.65
        point_center_distance = distance_2(rpn_center[:, :], inputs[:, [0, 2]])
        #for c in range(min(rpn_center.shape[0],100)):
        prop_count += rpn_center.shape[0]
        print('num %d' % (prop_count / float(cnt)))
        for c in range(rpn_center.shape[0]):
            # rcnn input generate
            cur_input = inputs.clone()
            cur_input_score = rcnn_input_scores.clone()

            # if COSINE_DISTANCE:
            #     cur_center_points_index = ((point_center_distance[:, c] < 4.0) & \
            #                                (point_prop_cos_matrix[:, c] > COS_THRESH) | \
            #                                (point_center_distance[:, c].view(-1) < 0.7)).view(-1)
            # else:
            cur_center_points_index = (point_center_distance[:, c] <
                                       4.0).view(-1)
            if cur_center_points_index.long().sum() == 0: continue

            cur_center_points_xyz = cur_input[cur_center_points_index, :3]
            cur_center_points_xyz[:, 0] -= rpn_center[c, 0]
            cur_center_points_xyz[:, 2] -= rpn_center[c, 1]
            cur_center_points_r = cur_input[cur_center_points_index,
                                            3].view(-1, 1)
            cur_center_points_mask = (cur_input_score[cur_center_points_index]
                                      > 0.5).view(-1, 1).float()

            # # easy sample sampling
            # if pts_input.shape[0]>512:
            #     cur_input = torch.cat((cur_center_points_xyz, cur_center_points_r,
            #                            (cur_input_score[cur_center_points_index] > 0.5).view(-1, 1).float()), dim=-1)
            #     pts_input = cur_input
            #     pts_input = pts_input[:min(pts_input.shape[0], 2000), :]
            #     pts_input = pts_input[:, :]
            #     sample_index = fps(pts_input[:, 0:3].contiguous(), ratio=min(512 / pts_input.shape[0], 0.99),
            #                        random_start=False)
            #     perm = sample_index
            #     while sample_index.shape[0] < 512:
            #         sample_index = torch.cat(
            #             (sample_index, perm[:min(perm.shape[0], 512 - sample_index.shape[0])]), dim=0)
            #
            #     cur_center_points_xyz = pts_input[sample_index, 0:3]
            #     cur_center_points_r = pts_input[sample_index, 3].reshape(-1, 1)
            #     cur_center_points_mask = pts_input[sample_index, 4].reshape(-1, 1)

            cur_center_points_xyz = cur_center_points_xyz.unsqueeze(0).float()
            cur_center_points_r = cur_center_points_r.unsqueeze(0).float()
            cur_center_points_mask = cur_center_points_mask.unsqueeze(
                0).float() - 0.5

            input_data = {
                'cur_box_point': cur_center_points_xyz,
                'cur_box_reflect': cur_center_points_r,
                'train_mask': cur_center_points_mask,
            }

            # # globaly random sampling
            # pts_input = pts_input[:min(pts_input.shape[0], self.npoints), :]
            # sample_index = np.arange(0, pts_input.shape[0], 1).astype(np.int)
            # perm = np.copy(sample_index)
            # while sample_index.shape[0] < self.npoints:
            #     sample_index = np.concatenate(
            #         (sample_index, perm[:min(perm.shape[0], self.npoints - sample_index.shape[0])]))
            #
            # cur_box_point = pts_input[sample_index, 0:3]
            # cur_box_reflect = pts_input[sample_index, 3].reshape(-1, 1)
            # cur_prob_mask = pts_input[sample_index, 4].reshape(-1, 1)
            # gt_mask = pts_input[sample_index, 5].reshape(-1, 1)

            # rcnn model inference
            ret_dict = model.rcnn_forward(input_data)
            rcnn_cls = ret_dict['rcnn_cls']
            ioun_cls = ret_dict['ioun_cls']
            rcnn_reg = ret_dict['rcnn_reg']
            rcnn_iou = ret_dict['rcnn_iou']
            rcnn_ref = ret_dict['rcnn_ref'].view(1, 1, -1)
            rcnn_box3d = ret_dict['pred_boxes3d']
            refined_box = ret_dict['refined_box']

            rcnn_box3d = refined_box
            rcnn_box3d[:, :, 6] = rcnn_box3d[:, :, 6] % (np.pi * 2)
            if rcnn_box3d[:, :, 6] > np.pi: rcnn_box3d[:, :, 6] -= np.pi * 2

            rcnn_box3d[:, :, 0] += rpn_center[c][0]
            rcnn_box3d[:, :, 2] += rpn_center[c][1]
            rcnn_box3d[:, :, 1] += 1.65

            box_list.append(rcnn_box3d)

            raw_score_list.append(rcnn_cls.view(1, 1))
            #raw_score_list.append(ioun_cls.view(1,1))

            iou_score_list.append(rcnn_iou.view(1, 1))

        rcnn_box3d = torch.cat((box_list), dim=1)
        raw_rcnn_score = torch.cat((raw_score_list),
                                   dim=0).unsqueeze(0).float()
        norm_ioun_score = torch.cat((iou_score_list),
                                    dim=0).unsqueeze(0).float()

        # scoring
        pred_boxes3d = rcnn_box3d
        norm_ioun_score = norm_ioun_score
        raw_rcnn_score = raw_rcnn_score
        norm_rcnn_score = torch.sigmoid(raw_rcnn_score)

        # scores thresh
        pred_h = pred_boxes3d[:, :, 3].view(-1)
        pred_w = pred_boxes3d[:, :, 4].view(-1)
        pred_l = pred_boxes3d[:, :, 5].view(-1)
        inds = (norm_rcnn_score > cfg.RCNN.SCORE_THRESH) & (
            norm_ioun_score > cfg.IOUN.SCORE_THRESH)
        inds = inds.view(-1)
        #size filiter
        # inds = inds & \
        #         (pred_h > 1.2) & (pred_h < 2.2) & \
        #         (pred_w > 1.3) & (pred_w < 2.0) & \
        #         (pred_l > 2.2) & (pred_l < 5.0)
        inds = inds & \
                (pred_h > 1.1) & (pred_h < 2.3) & \
                (pred_w > 1.2) & (pred_w < 2.1) & \
                (pred_l > 2.1) & (pred_l < 5.1)

        pred_boxes3d = pred_boxes3d[:, inds]
        norm_rcnn_score = norm_rcnn_score[:, inds]
        norm_ioun_score = norm_ioun_score[:, inds]
        raw_rcnn_score = raw_rcnn_score[:, inds]

        if pred_boxes3d.shape[1] == 0: continue
        # evaluation
        recalled_num = gt_num = 0

        if not args.test:
            gt_boxes3d = data['gt_boxes3d']

            for k in range(1):
                # calculate recall
                cur_gt_boxes3d = gt_boxes3d[k]
                tmp_idx = cur_gt_boxes3d.__len__() - 1

                while tmp_idx >= 0 and cur_gt_boxes3d[tmp_idx].sum() == 0:
                    tmp_idx -= 1

                if tmp_idx >= 0:
                    cur_gt_boxes3d = cur_gt_boxes3d[:tmp_idx + 1]

                    cur_gt_boxes3d = torch.from_numpy(cur_gt_boxes3d).cuda(
                        non_blocking=True).float()
                    _, iou3d = iou3d_utils.boxes_iou3d_gpu(
                        pred_boxes3d[k], cur_gt_boxes3d)
                    gt_max_iou, _ = iou3d.max(dim=0)
                    refined_iou, _ = iou3d.max(dim=1)

                    iou_list.append(refined_iou.view(-1, 1))
                    iou_p_score_list.append(norm_ioun_score.view(-1, 1))
                    rcnn_p_score_list.append(norm_rcnn_score.view(-1, 1))

                    for idx, thresh in enumerate(thresh_list):
                        total_recalled_bbox_list[idx] += (gt_max_iou >
                                                          thresh).sum().item()
                    recalled_num += (gt_max_iou > 0.7).sum().item()
                    gt_num += cur_gt_boxes3d.shape[0]
                    total_gt_bbox += cur_gt_boxes3d.shape[0]

        if cnt == 1000:
            iou_clloe = torch.cat(iou_list, dim=0).detach().cpu().numpy()
            iou_score_clloe = torch.cat(iou_p_score_list,
                                        dim=0).detach().cpu().numpy()
            plt.axis([-.1, 1.1, -.1, 1.1])
            plt.scatter(iou_clloe,
                        iou_score_clloe,
                        s=20,
                        c='blue',
                        edgecolor='none',
                        cmap=plt.get_cmap('YlOrRd'),
                        alpha=1,
                        marker='.')
            plt.savefig(os.path.join(result_dir, 'distributercnn.png'))

        disp_dict = {
            'mode': mode,
            'recall': '%d/%d' % (total_recalled_bbox_list[3], total_gt_bbox)
        }
        progress_bar.set_postfix(disp_dict)
        progress_bar.update()

        if VISUAL:
            fig, ax = plt.subplots(figsize=(10, 10))
            inputs_plt = inputs.detach().cpu().numpy()
            #plt.axes(facecolor='silver')
            plt.axis([-35, 35, 0, 70])
            plt.scatter(inputs_plt[:, 0],
                        inputs_plt[:, 2],
                        s=15,
                        c=inputs_plt[:, 1],
                        edgecolor='none',
                        cmap=plt.get_cmap('Blues'),
                        alpha=1,
                        marker='.',
                        vmin=-1,
                        vmax=2)
            pred_boxes3d_numpy = pred_boxes3d[0].detach().cpu().numpy()
            pred_boxes3d_corner = kitti_utils.boxes3d_to_corners3d(
                pred_boxes3d_numpy, rotate=True)
            for o in range(pred_boxes3d_corner.shape[0]):
                print_box_corner = pred_boxes3d_corner[o]

                x1, x2, x3, x4 = print_box_corner[0:4, 0]
                z1, z2, z3, z4 = print_box_corner[0:4, 2]

                polygon = np.zeros([5, 2], dtype=np.float32)
                polygon[0, 0] = x1
                polygon[1, 0] = x2
                polygon[2, 0] = x3
                polygon[3, 0] = x4
                polygon[4, 0] = x1

                polygon[0, 1] = z1
                polygon[1, 1] = z2
                polygon[2, 1] = z3
                polygon[3, 1] = z4
                polygon[4, 1] = z1

                line1 = [(x1, z1), (x2, z2)]
                line2 = [(x2, z2), (x3, z3)]
                line3 = [(x3, z3), (x4, z4)]
                line4 = [(x4, z4), (x1, z1)]
                (line1_xs, line1_ys) = zip(*line1)
                (line2_xs, line2_ys) = zip(*line2)
                (line3_xs, line3_ys) = zip(*line3)
                (line4_xs, line4_ys) = zip(*line4)
                ax.add_line(
                    Line2D(line1_xs, line1_ys, linewidth=1, color='green'))
                ax.add_line(
                    Line2D(line2_xs, line2_ys, linewidth=1, color='red'))
                ax.add_line(
                    Line2D(line3_xs, line3_ys, linewidth=1, color='red'))
                ax.add_line(
                    Line2D(line4_xs, line4_ys, linewidth=1, color='red'))

                # gt visualize

            if args.test == False and data['gt_boxes3d'].shape[1] > 0:
                gt_boxes3d_corner = kitti_utils.boxes3d_to_corners3d(
                    data['gt_boxes3d'].reshape(-1, 7), rotate=True)

                for o in range(gt_boxes3d_corner.shape[0]):
                    print_box_corner = gt_boxes3d_corner[o]

                    x1, x2, x3, x4 = print_box_corner[0:4, 0]
                    z1, z2, z3, z4 = print_box_corner[0:4, 2]

                    polygon = np.zeros([5, 2], dtype=np.float32)
                    polygon[0, 0] = x1
                    polygon[1, 0] = x2
                    polygon[2, 0] = x3
                    polygon[3, 0] = x4
                    polygon[4, 0] = x1

                    polygon[0, 1] = z1
                    polygon[1, 1] = z2
                    polygon[2, 1] = z3
                    polygon[3, 1] = z4
                    polygon[4, 1] = z1

                    line1 = [(x1, z1), (x2, z2)]
                    line2 = [(x2, z2), (x3, z3)]
                    line3 = [(x3, z3), (x4, z4)]
                    line4 = [(x4, z4), (x1, z1)]
                    (line1_xs, line1_ys) = zip(*line1)
                    (line2_xs, line2_ys) = zip(*line2)
                    (line3_xs, line3_ys) = zip(*line3)
                    (line4_xs, line4_ys) = zip(*line4)
                    ax.add_line(
                        Line2D(line1_xs, line1_ys, linewidth=1,
                               color='yellow'))
                    ax.add_line(
                        Line2D(line2_xs, line2_ys, linewidth=1,
                               color='purple'))
                    ax.add_line(
                        Line2D(line3_xs, line3_ys, linewidth=1,
                               color='purple'))
                    ax.add_line(
                        Line2D(line4_xs, line4_ys, linewidth=1,
                               color='purple'))
            plt.savefig('../visual/rcnn.jpg')

        # scores thresh
        inds = (norm_rcnn_score > cfg.RCNN.SCORE_THRESH) & (
            norm_ioun_score > cfg.IOUN.SCORE_THRESH)
        #inds = (norm_ioun_score > cfg.IOUN.SCORE_THRESH)

        for k in range(1):
            cur_inds = inds[k].view(-1)
            if cur_inds.sum() == 0:
                continue

            pred_boxes3d_selected = pred_boxes3d[k, cur_inds]
            norm_iou_scores_selected = norm_ioun_score[k, cur_inds]
            raw_rcnn_score_selected = raw_rcnn_score[k, cur_inds]

            #traditional nms
            # NMS thresh rotated nms
            # boxes_bev_selected = kitti_utils.boxes3d_to_bev_torch(pred_boxes3d_selected)
            # #score NMS
            # # boxes_bev_selected[:,-1] += np.pi/2
            # keep_idx = iou3d_utils.nms_normal_gpu(boxes_bev_selected, norm_iou_scores_selected, cfg.RCNN.NMS_THRESH).view(-1)
            # pred_boxes3d_selected = pred_boxes3d_selected[keep_idx]
            # norm_iou_scores_selected = norm_iou_scores_selected[keep_idx]
            # raw_rcnn_score_selected = raw_rcnn_score_selected[keep_idx]

            #self NMS
            sort_boxes = torch.argsort(-norm_iou_scores_selected.view(-1))
            pred_boxes3d_selected = pred_boxes3d_selected[sort_boxes]
            norm_iou_scores_selected = norm_iou_scores_selected[sort_boxes]

            if pred_boxes3d_selected.shape[0] > 1:
                keep_id = [0]
                iou2d, iou3d = iou3d_utils.boxes_iou3d_gpu(
                    pred_boxes3d_selected, pred_boxes3d_selected)
                for i in range(1, pred_boxes3d_selected.shape[0]):
                    # if torch.min(prop_prop_distance[:i, i], dim=-1)[0] > 0.3:
                    if torch.max(iou2d[keep_id, i], dim=-1)[0] < 0.01:
                        keep_id.append(i)
                pred_boxes3d_selected = pred_boxes3d_selected[keep_id]
                norm_iou_scores_selected = norm_iou_scores_selected[keep_id]
            else:
                pred_boxes3d_selected = pred_boxes3d_selected
                norm_iou_scores_selected = norm_iou_scores_selected

            pred_boxes3d_selected, norm_iou_scores_selected = pred_boxes3d_selected.cpu(
            ).numpy(), norm_iou_scores_selected.cpu().numpy()

            cur_sample_id = sample_id
            calib = dataset.get_calib(cur_sample_id)
            final_total += pred_boxes3d_selected.shape[0]
            image_shape = dataset.get_image_shape(cur_sample_id)
            save_kitti_format(cur_sample_id, calib, pred_boxes3d_selected,
                              final_output_dir, norm_iou_scores_selected,
                              image_shape)

            if VISUAL:
                fig, ax = plt.subplots(figsize=(10, 10))
                inputs_plt = inputs.detach().cpu().numpy()
                # plt.axes(facecolor='silver')
                plt.axis([-35, 35, 0, 70])
                plt.scatter(inputs_plt[:, 0],
                            inputs_plt[:, 2],
                            s=15,
                            c=inputs_plt[:, 1],
                            edgecolor='none',
                            cmap=plt.get_cmap('Blues'),
                            alpha=1,
                            marker='.',
                            vmin=-1,
                            vmax=2)
                pred_boxes3d_numpy = pred_boxes3d_selected
                pred_boxes3d_corner = kitti_utils.boxes3d_to_corners3d(
                    pred_boxes3d_numpy, rotate=True)
                for o in range(pred_boxes3d_corner.shape[0]):
                    print_box_corner = pred_boxes3d_corner[o]

                    x1, x2, x3, x4 = print_box_corner[0:4, 0]
                    z1, z2, z3, z4 = print_box_corner[0:4, 2]

                    polygon = np.zeros([5, 2], dtype=np.float32)
                    polygon[0, 0] = x1
                    polygon[1, 0] = x2
                    polygon[2, 0] = x3
                    polygon[3, 0] = x4
                    polygon[4, 0] = x1

                    polygon[0, 1] = z1
                    polygon[1, 1] = z2
                    polygon[2, 1] = z3
                    polygon[3, 1] = z4
                    polygon[4, 1] = z1

                    line1 = [(x1, z1), (x2, z2)]
                    line2 = [(x2, z2), (x3, z3)]
                    line3 = [(x3, z3), (x4, z4)]
                    line4 = [(x4, z4), (x1, z1)]
                    (line1_xs, line1_ys) = zip(*line1)
                    (line2_xs, line2_ys) = zip(*line2)
                    (line3_xs, line3_ys) = zip(*line3)
                    (line4_xs, line4_ys) = zip(*line4)
                    ax.add_line(
                        Line2D(line1_xs, line1_ys, linewidth=1, color='green'))
                    ax.add_line(
                        Line2D(line2_xs, line2_ys, linewidth=1, color='red'))
                    ax.add_line(
                        Line2D(line3_xs, line3_ys, linewidth=1, color='red'))
                    ax.add_line(
                        Line2D(line4_xs, line4_ys, linewidth=1, color='red'))

                    # gt visualize

                if args.test == False and data['gt_boxes3d'].shape[1] > 0:
                    gt_boxes3d_corner = kitti_utils.boxes3d_to_corners3d(
                        data['gt_boxes3d'].reshape(-1, 7), rotate=True)

                    for o in range(gt_boxes3d_corner.shape[0]):
                        print_box_corner = gt_boxes3d_corner[o]

                        x1, x2, x3, x4 = print_box_corner[0:4, 0]
                        z1, z2, z3, z4 = print_box_corner[0:4, 2]

                        polygon = np.zeros([5, 2], dtype=np.float32)
                        polygon[0, 0] = x1
                        polygon[1, 0] = x2
                        polygon[2, 0] = x3
                        polygon[3, 0] = x4
                        polygon[4, 0] = x1

                        polygon[0, 1] = z1
                        polygon[1, 1] = z2
                        polygon[2, 1] = z3
                        polygon[3, 1] = z4
                        polygon[4, 1] = z1

                        line1 = [(x1, z1), (x2, z2)]
                        line2 = [(x2, z2), (x3, z3)]
                        line3 = [(x3, z3), (x4, z4)]
                        line4 = [(x4, z4), (x1, z1)]
                        (line1_xs, line1_ys) = zip(*line1)
                        (line2_xs, line2_ys) = zip(*line2)
                        (line3_xs, line3_ys) = zip(*line3)
                        (line4_xs, line4_ys) = zip(*line4)
                        ax.add_line(
                            Line2D(line1_xs,
                                   line1_ys,
                                   linewidth=1,
                                   color='yellow'))
                        ax.add_line(
                            Line2D(line2_xs,
                                   line2_ys,
                                   linewidth=1,
                                   color='purple'))
                        ax.add_line(
                            Line2D(line3_xs,
                                   line3_ys,
                                   linewidth=1,
                                   color='purple'))
                        ax.add_line(
                            Line2D(line4_xs,
                                   line4_ys,
                                   linewidth=1,
                                   color='purple'))
                plt.savefig('../visual/ioun.jpg')

    progress_bar.close()
    # dump empty files
    split_file = os.path.join(dataset.imageset_dir, '..', 'ImageSets',
                              dataset.split + '.txt')
    split_file = os.path.abspath(split_file)
    image_idx_list = [x.strip() for x in open(split_file).readlines()]
    empty_cnt = 0
    for k in range(image_idx_list.__len__()):
        cur_file = os.path.join(final_output_dir, '%s.txt' % image_idx_list[k])
        if not os.path.exists(cur_file):
            with open(cur_file, 'w') as temp_f:
                pass
            empty_cnt += 1
            logger.info('empty_cnt=%d: dump empty file %s' %
                        (empty_cnt, cur_file))

    ret_dict = {'empty_cnt': empty_cnt}

    if not args.eval_all:
        logger.info(
            '-------------------performance of epoch %s---------------------' %
            epoch_id)
        logger.info(str(datetime.now()))

        avg_rpn_iou = (total_rpn_iou / max(cnt, 1.0))
        avg_cls_acc = (total_cls_acc / max(cnt, 1.0))
        avg_cls_acc_refined = (total_cls_acc_refined / max(cnt, 1.0))
        avg_det_num = (final_total / max(len(dataset), 1.0))
        logger.info('final average detections: %.3f' % avg_det_num)
        logger.info('final average rpn_iou refined: %.3f' % avg_rpn_iou)
        logger.info('final average cls acc: %.3f' % avg_cls_acc)
        logger.info('final average cls acc refined: %.3f' %
                    avg_cls_acc_refined)
        ret_dict['rpn_iou'] = avg_rpn_iou
        ret_dict['rcnn_cls_acc'] = avg_cls_acc
        ret_dict['rcnn_cls_acc_refined'] = avg_cls_acc_refined
        ret_dict['rcnn_avg_num'] = avg_det_num

        for idx, thresh in enumerate(thresh_list):
            cur_recall = total_recalled_bbox_list[idx] / max(
                total_gt_bbox, 1.0)
            logger.info('total bbox recall(thresh=%.3f): %d / %d = %f' %
                        (thresh, total_recalled_bbox_list[idx], total_gt_bbox,
                         cur_recall))
            ret_dict['rcnn_recall(thresh=%.2f)' % thresh] = cur_recall
            if thresh == 0.7:
                recall = cur_recall

    if cfg.TEST.SPLIT != 'test':
        logger.info('Averate Precision:')
        name_to_class = {'Car': 0, 'Pedestrian': 1, 'Cyclist': 2}
        ap_result_str, ap_dict = kitti_evaluate(
            dataset.label_dir,
            final_output_dir,
            label_split_file=split_file,
            current_class=name_to_class[cfg.CLASSES])
        if not args.eval_all:
            logger.info(ap_result_str)
            ret_dict.update(ap_dict)

    logger.info('result is saved to: %s' % result_dir)
    precision = ap_dict['Car_3d_easy'] + ap_dict['Car_3d_moderate'] + ap_dict[
        'Car_3d_hard']
    recall = total_recalled_bbox_list[3] / max(total_gt_bbox, 1.0)
    F2_score = 0
    return precision, recall, F2_score