Example #1
0
def kitti_eval():
    """Run the offline KITTI AP evaluation on previously dumped detections.

    Reads command-line options via ``parse_args()`` and evaluates the KITTI
    detection files under ``<result_dir>/final_result/data`` against the
    labels in ``<data_dir>/KITTI/object/training/label_2``, using the split
    file ``<data_dir>/KITTI/ImageSets/<split>.txt``. The class evaluated is
    ``args.class_name`` (one of 'Car', 'Pedestrian', 'Cyclist'). The AP
    summary string and AP dict are printed to stdout.

    Exits the process with status 1 on Python versions older than 3.6.
    """
    # Compare the version tuple rather than a string slice: sys.version[:3]
    # is "3.1" on Python 3.10+, which a float comparison would wrongly treat
    # as older than 3.6.
    if sys.version_info < (3, 6):
        print("KITTI mAP evaluation can only run with python3.6+")
        sys.exit(1)

    args = parse_args()

    label_dir = os.path.join(args.data_dir, 'KITTI/object/training', 'label_2')
    split_file = os.path.join(args.data_dir, 'KITTI/ImageSets',
                              '{}.txt'.format(args.split))
    final_output_dir = os.path.join(args.result_dir, 'final_result', 'data')
    # Class name -> class index expected by kitti_evaluate's current_class.
    name_to_class = {'Car': 0, 'Pedestrian': 1, 'Cyclist': 2}

    # Imported only after the version check above; presumably the evaluation
    # package itself needs Python 3.6+ (TODO confirm).
    from tools.kitti_object_eval_python.evaluate import evaluate as kitti_evaluate
    ap_result_str, ap_dict = kitti_evaluate(
        label_dir, final_output_dir, label_split_file=split_file,
        current_class=name_to_class[args.class_name])

    print("KITTI evaluate: ", ap_result_str, ap_dict)
Example #2
0
def eval_one_epoch_joint(model, dataloader, epoch_id, result_dir, logger):
    """Evaluate the joint RPN + RCNN model for one epoch.

    Runs ``model`` over every batch of ``dataloader``, decodes the refined
    RCNN boxes, keeps boxes whose normalized score exceeds
    ``cfg.RCNN.SCORE_THRESH``, applies BEV NMS and writes the survivors in
    KITTI format to ``<result_dir>/final_result/data``.

    Unless ``args.test`` is set, it accumulates ROI and refined-box recall at
    several 3D-IoU thresholds. When ``cfg.RPN.FIXED`` is false, it also
    accumulates the RPN foreground-segmentation IoU. With ``args.save_result``
    the ROI boxes, the refined boxes and the raw RPN outputs are dumped as well.

    After the loop, an empty result file is written for every frame of the
    dataset split that received no detection. Unless
    ``cfg.TEST.SPLIT == 'test'``, the KITTI AP evaluation is then run.

    Args:
        model: network called as ``model({'pts_input': tensor})``; its output
            dict is read for 'rois', 'roi_scores_raw', 'seg_result',
            'rcnn_cls', 'rcnn_reg', 'rpn_cls' and 'backbone_xyz'.
        dataloader: iterable of batch dicts; ``dataloader.dataset`` must
            provide ``get_calib``, ``get_image_shape``, ``imageset_dir``,
            ``split`` and ``label_dir``.
        epoch_id: epoch identifier, used only in log messages.
        result_dir: root directory for all output files.
        logger: logger receiving progress and metric messages.

    Returns:
        dict: metric name -> value. It contains the empty-file count, the
        average RPN IoU, the classification accuracies, the average
        detections per frame, ROI/RCNN recall per threshold and, when
        evaluated, the KITTI AP entries. The classification accuracies are
        never accumulated in this function, so they are always 0.
    """
    np.random.seed(666)
    MEAN_SIZE = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
    mode = 'TEST' if args.test else 'EVAL'

    final_output_dir = os.path.join(result_dir, 'final_result', 'data')
    os.makedirs(final_output_dir, exist_ok=True)

    if args.save_result:
        roi_output_dir = os.path.join(result_dir, 'roi_result', 'data')
        refine_output_dir = os.path.join(result_dir, 'refine_result', 'data')
        rpn_output_dir = os.path.join(result_dir, 'rpn_result', 'data')
        os.makedirs(rpn_output_dir, exist_ok=True)
        os.makedirs(roi_output_dir, exist_ok=True)
        os.makedirs(refine_output_dir, exist_ok=True)

    logger.info('---- EPOCH %s JOINT EVALUATION ----' % epoch_id)
    logger.info('==> Output file: %s' % result_dir)
    model.eval()

    thresh_list = [0.1, 0.3, 0.5, 0.7, 0.9]
    total_recalled_bbox_list, total_gt_bbox = [0] * 5, 0
    total_roi_recalled_bbox_list = [0] * 5
    dataset = dataloader.dataset
    # total_cls_acc / total_cls_acc_refined are never incremented below; they
    # stay 0 and are reported as such.
    cnt = final_total = total_cls_acc = total_cls_acc_refined = total_rpn_iou = 0

    progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval')
    for data in dataloader:
        cnt += 1
        sample_id, pts_rect, pts_features, pts_input = \
            data['sample_id'], data['pts_rect'], data['pts_features'], data['pts_input']
        batch_size = len(sample_id)
        inputs = torch.from_numpy(pts_input).cuda(non_blocking=True).float()
        input_data = {'pts_input': inputs}

        # model inference
        ret_dict = model(input_data)

        roi_scores_raw = ret_dict['roi_scores_raw']  # (B, M)
        roi_boxes3d = ret_dict['rois']  # (B, M, 7)
        seg_result = ret_dict['seg_result'].long()  # (B, N)

        rcnn_cls = ret_dict['rcnn_cls'].view(batch_size, -1,
                                             ret_dict['rcnn_cls'].shape[1])
        rcnn_reg = ret_dict['rcnn_reg'].view(
            batch_size, -1, ret_dict['rcnn_reg'].shape[1])  # (B, M, C)

        # bounding box regression
        anchor_size = MEAN_SIZE
        if cfg.RCNN.SIZE_RES_ON_ROI:
            # Per-ROI size residuals are not supported in joint evaluation.
            assert False

        pred_boxes3d = decode_bbox_target(
            roi_boxes3d.view(-1, 7),
            rcnn_reg.view(-1, rcnn_reg.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)

        # scoring
        if rcnn_cls.shape[2] == 1:
            raw_scores = rcnn_cls  # (B, M, 1)

            norm_scores = torch.sigmoid(raw_scores)
            pred_classes = (norm_scores > cfg.RCNN.SCORE_THRESH).long()
        else:
            # NOTE(review): this multi-class branch indexes as if rcnn_cls
            # were 2-D (M, C), but here it has been viewed as (B, M, C);
            # confirm before relying on it.
            pred_classes = torch.argmax(rcnn_cls, dim=1).view(-1)
            cls_norm_scores = F.softmax(rcnn_cls, dim=1)
            raw_scores = rcnn_cls[:, pred_classes]
            norm_scores = cls_norm_scores[:, pred_classes]

        # evaluation
        if not args.test:
            if not cfg.RPN.FIXED:
                rpn_cls_label, rpn_reg_label = data['rpn_cls_label'], data[
                    'rpn_reg_label']
                rpn_cls_label = torch.from_numpy(rpn_cls_label).cuda(
                    non_blocking=True).long()

            gt_boxes3d = data['gt_boxes3d']

            for k in range(batch_size):
                # calculate recall
                cur_gt_boxes3d = gt_boxes3d[k]
                tmp_idx = cur_gt_boxes3d.__len__() - 1

                # Skip trailing all-zero GT rows (presumably batch padding).
                while tmp_idx >= 0 and cur_gt_boxes3d[tmp_idx].sum() == 0:
                    tmp_idx -= 1

                if tmp_idx >= 0:
                    cur_gt_boxes3d = cur_gt_boxes3d[:tmp_idx + 1]

                    cur_gt_boxes3d = torch.from_numpy(cur_gt_boxes3d).cuda(
                        non_blocking=True).float()
                    iou3d = iou3d_utils.boxes_iou3d_gpu(
                        pred_boxes3d[k], cur_gt_boxes3d)
                    # Best IoU reached by any refined box, for each GT box.
                    gt_max_iou, _ = iou3d.max(dim=0)

                    for idx, thresh in enumerate(thresh_list):
                        total_recalled_bbox_list[idx] += (gt_max_iou >
                                                          thresh).sum().item()
                    total_gt_bbox += cur_gt_boxes3d.shape[0]

                    # original recall
                    iou3d_in = iou3d_utils.boxes_iou3d_gpu(
                        roi_boxes3d[k], cur_gt_boxes3d)
                    gt_max_iou_in, _ = iou3d_in.max(dim=0)

                    for idx, thresh in enumerate(thresh_list):
                        total_roi_recalled_bbox_list[idx] += (
                            gt_max_iou_in > thresh).sum().item()

            # The segmentation IoU below is computed over the whole batch at
            # once, so it is accumulated once per batch (not once per sample):
            # avg_rpn_iou divides total_rpn_iou by the number of batches.
            if not cfg.RPN.FIXED:
                fg_mask = rpn_cls_label > 0
                correct = ((seg_result == rpn_cls_label)
                           & fg_mask).sum().float()
                union = fg_mask.sum().float() + (seg_result >
                                                 0).sum().float() - correct
                rpn_iou = correct / torch.clamp(union, min=1.0)
                total_rpn_iou += rpn_iou.item()

        disp_dict = {
            'mode': mode,
            'recall': '%d/%d' % (total_recalled_bbox_list[3], total_gt_bbox)
        }
        progress_bar.set_postfix(disp_dict)
        progress_bar.update()

        if args.save_result:
            # save roi and refine results
            roi_boxes3d_np = roi_boxes3d.cpu().numpy()
            pred_boxes3d_np = pred_boxes3d.cpu().numpy()
            roi_scores_raw_np = roi_scores_raw.cpu().numpy()
            raw_scores_np = raw_scores.cpu().numpy()

            rpn_cls_np = ret_dict['rpn_cls'].cpu().numpy()
            rpn_xyz_np = ret_dict['backbone_xyz'].cpu().numpy()
            seg_result_np = seg_result.cpu().numpy()
            output_data = np.concatenate(
                (rpn_xyz_np, rpn_cls_np.reshape(batch_size, -1, 1),
                 seg_result_np.reshape(batch_size, -1, 1)),
                axis=2)

            for k in range(batch_size):
                cur_sample_id = sample_id[k]
                calib = dataset.get_calib(cur_sample_id)
                image_shape = dataset.get_image_shape(cur_sample_id)
                save_kitti_format(cur_sample_id, calib, roi_boxes3d_np[k],
                                  roi_output_dir, roi_scores_raw_np[k],
                                  image_shape)
                save_kitti_format(cur_sample_id, calib, pred_boxes3d_np[k],
                                  refine_output_dir, raw_scores_np[k],
                                  image_shape)

                # NOTE(review): the whole batch's output_data is saved for
                # each sample id; only equivalent to per-sample data when
                # batch_size == 1.
                output_file = os.path.join(rpn_output_dir,
                                           '%06d.npy' % cur_sample_id)
                np.save(output_file, output_data.astype(np.float32))

        # scores thresh
        inds = norm_scores > cfg.RCNN.SCORE_THRESH

        for k in range(batch_size):
            cur_inds = inds[k].view(-1)
            if cur_inds.sum() == 0:
                continue

            pred_boxes3d_selected = pred_boxes3d[k, cur_inds]
            raw_scores_selected = raw_scores[k, cur_inds]
            norm_scores_selected = norm_scores[k, cur_inds]

            # NMS thresh
            # rotated nms
            boxes_bev_selected = kitti_utils.boxes3d_to_bev_torch(
                pred_boxes3d_selected)
            keep_idx = iou3d_utils.nms_gpu(boxes_bev_selected,
                                           raw_scores_selected,
                                           cfg.RCNN.NMS_THRESH).view(-1)
            pred_boxes3d_selected = pred_boxes3d_selected[keep_idx]
            scores_selected = raw_scores_selected[keep_idx]
            pred_boxes3d_selected, scores_selected = pred_boxes3d_selected.cpu(
            ).numpy(), scores_selected.cpu().numpy()

            cur_sample_id = sample_id[k]
            calib = dataset.get_calib(cur_sample_id)
            final_total += pred_boxes3d_selected.shape[0]
            image_shape = dataset.get_image_shape(cur_sample_id)
            save_kitti_format(cur_sample_id, calib, pred_boxes3d_selected,
                              final_output_dir, scores_selected, image_shape)

    progress_bar.close()
    # dump empty files
    split_file = os.path.join(dataset.imageset_dir, '..', '..', 'ImageSets',
                              dataset.split + '.txt')
    split_file = os.path.abspath(split_file)
    with open(split_file) as split_f:
        image_idx_list = [x.strip() for x in split_f.readlines()]
    empty_cnt = 0
    for image_idx in image_idx_list:
        cur_file = os.path.join(final_output_dir, '%s.txt' % image_idx)
        if not os.path.exists(cur_file):
            # Frames without detections still get a (empty) result file,
            # presumably so the KITTI evaluator finds one file per frame.
            with open(cur_file, 'w'):
                pass
            empty_cnt += 1
            logger.info('empty_cnt=%d: dump empty file %s' %
                        (empty_cnt, cur_file))

    ret_dict = {'empty_cnt': empty_cnt}

    logger.info(
        '-------------------performance of epoch %s---------------------' %
        epoch_id)
    logger.info(str(datetime.now()))

    avg_rpn_iou = (total_rpn_iou / max(cnt, 1.0))
    avg_cls_acc = (total_cls_acc / max(cnt, 1.0))
    avg_cls_acc_refined = (total_cls_acc_refined / max(cnt, 1.0))
    avg_det_num = (final_total / max(len(dataset), 1.0))
    logger.info('final average detections: %.3f' % avg_det_num)
    logger.info('final average rpn_iou refined: %.3f' % avg_rpn_iou)
    logger.info('final average cls acc: %.3f' % avg_cls_acc)
    logger.info('final average cls acc refined: %.3f' % avg_cls_acc_refined)
    ret_dict['rpn_iou'] = avg_rpn_iou
    ret_dict['rcnn_cls_acc'] = avg_cls_acc
    ret_dict['rcnn_cls_acc_refined'] = avg_cls_acc_refined
    ret_dict['rcnn_avg_num'] = avg_det_num

    for idx, thresh in enumerate(thresh_list):
        cur_roi_recall = total_roi_recalled_bbox_list[idx] / max(
            total_gt_bbox, 1.0)
        logger.info('total roi bbox recall(thresh=%.3f): %d / %d = %f' %
                    (thresh, total_roi_recalled_bbox_list[idx], total_gt_bbox,
                     cur_roi_recall))
        ret_dict['rpn_recall(thresh=%.2f)' % thresh] = cur_roi_recall

    for idx, thresh in enumerate(thresh_list):
        cur_recall = total_recalled_bbox_list[idx] / max(total_gt_bbox, 1.0)
        logger.info(
            'total bbox recall(thresh=%.3f): %d / %d = %f' %
            (thresh, total_recalled_bbox_list[idx], total_gt_bbox, cur_recall))
        ret_dict['rcnn_recall(thresh=%.2f)' % thresh] = cur_recall

    if cfg.TEST.SPLIT != 'test':
        logger.info('Averate Precision:')
        name_to_class = {'Car': 0, 'Pedestrian': 1, 'Cyclist': 2}
        ap_result_str, ap_dict = kitti_evaluate(
            dataset.label_dir,
            final_output_dir,
            label_split_file=split_file,
            current_class=name_to_class[cfg.CLASSES])
        logger.info(ap_result_str)
        ret_dict.update(ap_dict)

    logger.info('result is saved to: %s' % result_dir)
    return ret_dict
Example #3
0
def eval_one_epoch_rcnn(model, dataloader, epoch_id, result_dir, logger):
    """Evaluate the RCNN stage alone for one epoch (batch size 1 only).

    Each batch from ``dataloader`` already carries the ROIs ('roi_boxes3d',
    'roi_scores') together with the point inputs. The model refines them,
    and refined boxes whose normalized score exceeds
    ``cfg.RCNN.SCORE_THRESH`` are suppressed with BEV NMS. The survivors are
    written in KITTI format to ``<result_dir>/final_result/data``.

    Unless ``args.test`` is set, ROI/refined recall and classification
    accuracies are accumulated. With ``args.save_result`` the ROI and
    refined boxes are also saved.

    Afterwards an empty result file is written for every frame of the split
    without detections. Unless ``cfg.TEST.SPLIT == 'test'``, the KITTI AP
    evaluation is then run.

    Args:
        model: network called with the dict of CUDA tensors built from the
            batch; its output dict is read for 'rcnn_cls' and 'rcnn_reg'.
        dataloader: iterable of batch dicts; ``dataloader.dataset`` must
            provide ``get_calib``, ``get_image_shape``, ``imageset_dir``,
            ``split`` and ``label_dir``.
        epoch_id: epoch identifier, used only in log messages.
        result_dir: root directory for all output files.
        logger: logger receiving progress and metric messages.

    Returns:
        dict: metric name -> value. It contains the empty-file count, the
        classification accuracies, the average detections per frame,
        ROI/RCNN recall per threshold and, when evaluated, the KITTI AP
        entries.
    """
    np.random.seed(1024)
    MEAN_SIZE = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
    mode = 'TEST' if args.test else 'EVAL'

    final_output_dir = os.path.join(result_dir, 'final_result', 'data')
    os.makedirs(final_output_dir, exist_ok=True)

    if args.save_result:
        roi_output_dir = os.path.join(result_dir, 'roi_result', 'data')
        refine_output_dir = os.path.join(result_dir, 'refine_result', 'data')
        os.makedirs(roi_output_dir, exist_ok=True)
        os.makedirs(refine_output_dir, exist_ok=True)

    logger.info('---- EPOCH %s RCNN EVALUATION ----' % epoch_id)
    model.eval()

    thresh_list = [0.1, 0.3, 0.5, 0.7, 0.9]
    total_recalled_bbox_list, total_gt_bbox = [0] * 5, 0
    total_roi_recalled_bbox_list = [0] * 5
    dataset = dataloader.dataset
    cnt = final_total = total_cls_acc = total_cls_acc_refined = 0

    progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval')
    for data in dataloader:
        sample_id = data['sample_id']
        cnt += 1
        assert args.batch_size == 1, 'Only support bs=1 here'
        # Move every array except the sample id to the GPU as float32.
        input_data = {}
        for key, val in data.items():
            if key != 'sample_id':
                input_data[key] = torch.from_numpy(val).contiguous().cuda(
                    non_blocking=True).float()

        roi_boxes3d = input_data['roi_boxes3d']
        roi_scores = input_data['roi_scores']
        if cfg.RCNN.ROI_SAMPLE_JIT:
            # Add a leading batch dimension to all inputs except the GT ones.
            for key, val in input_data.items():
                if key in ['gt_iou', 'gt_boxes3d']:
                    continue
                input_data[key] = input_data[key].unsqueeze(dim=0)
        else:
            pts_input = torch.cat(
                (input_data['pts_input'], input_data['pts_features']), dim=-1)
            input_data['pts_input'] = pts_input

        ret_dict = model(input_data)
        rcnn_cls = ret_dict['rcnn_cls']
        rcnn_reg = ret_dict['rcnn_reg']

        # bounding box regression
        anchor_size = MEAN_SIZE
        if cfg.RCNN.SIZE_RES_ON_ROI:
            roi_size = input_data['roi_size']
            anchor_size = roi_size

        pred_boxes3d = decode_bbox_target(
            roi_boxes3d,
            rcnn_reg,
            anchor_size=anchor_size,
            loc_scope=cfg.RCNN.LOC_SCOPE,
            loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True)

        # scoring
        if rcnn_cls.shape[1] == 1:
            raw_scores = rcnn_cls.view(-1)
            norm_scores = torch.sigmoid(raw_scores)
            pred_classes = (norm_scores > cfg.RCNN.SCORE_THRESH).long()
        else:
            # NOTE(review): rcnn_cls[:, pred_classes] selects whole columns
            # rather than one score per row; confirm before relying on the
            # multi-class path.
            pred_classes = torch.argmax(rcnn_cls, dim=1).view(-1)
            cls_norm_scores = F.softmax(rcnn_cls, dim=1)
            raw_scores = rcnn_cls[:, pred_classes]
            norm_scores = cls_norm_scores[:, pred_classes]

        # evaluation
        disp_dict = {'mode': mode}
        if not args.test:
            gt_boxes3d = input_data['gt_boxes3d']
            gt_iou = input_data['gt_iou']

            # calculate recall
            gt_num = gt_boxes3d.shape[0]
            if gt_num > 0:
                iou3d = iou3d_utils.boxes_iou3d_gpu(pred_boxes3d, gt_boxes3d)
                # Best IoU reached by any refined box, for each GT box.
                gt_max_iou, _ = iou3d.max(dim=0)

                for idx, thresh in enumerate(thresh_list):
                    total_recalled_bbox_list[idx] += (gt_max_iou >
                                                      thresh).sum().item()
                total_gt_bbox += gt_num

                iou3d_in = iou3d_utils.boxes_iou3d_gpu(roi_boxes3d, gt_boxes3d)
                gt_max_iou_in, _ = iou3d_in.max(dim=0)

                for idx, thresh in enumerate(thresh_list):
                    total_roi_recalled_bbox_list[idx] += (gt_max_iou_in >
                                                          thresh).sum().item()

            # classification accuracy, ignoring ROIs whose IoU lies strictly
            # between the background and foreground thresholds
            cls_label = (gt_iou > cfg.RCNN.CLS_FG_THRESH).float()
            cls_valid_mask = ((gt_iou >= cfg.RCNN.CLS_FG_THRESH) |
                              (gt_iou <= cfg.RCNN.CLS_BG_THRESH)).float()
            cls_acc = ((pred_classes == cls_label.long()).float() *
                       cls_valid_mask).sum() / max(cls_valid_mask.sum(), 1.0)

            iou_thresh = 0.7 if cfg.CLASSES == 'Car' else 0.5
            cls_label_refined = (gt_iou >= iou_thresh).float()
            cls_acc_refined = (
                pred_classes == cls_label_refined.long()).float().sum() / max(
                    cls_label_refined.shape[0], 1.0)

            total_cls_acc += cls_acc.item()
            total_cls_acc_refined += cls_acc_refined.item()

            disp_dict['recall'] = '%d/%d' % (total_recalled_bbox_list[3],
                                             total_gt_bbox)
            disp_dict['cls_acc_refined'] = '%.2f' % cls_acc_refined.item()

        progress_bar.set_postfix(disp_dict)
        progress_bar.update()

        image_shape = dataset.get_image_shape(sample_id)
        if args.save_result:
            # save roi and refine results
            roi_boxes3d_np = roi_boxes3d.cpu().numpy()
            pred_boxes3d_np = pred_boxes3d.cpu().numpy()
            calib = dataset.get_calib(sample_id)

            save_kitti_format(sample_id, calib, roi_boxes3d_np, roi_output_dir,
                              roi_scores, image_shape)
            save_kitti_format(sample_id, calib, pred_boxes3d_np,
                              refine_output_dir,
                              raw_scores.cpu().numpy(), image_shape)

        # NMS and scoring
        # scores thresh
        inds = norm_scores > cfg.RCNN.SCORE_THRESH
        if inds.sum() == 0:
            continue

        pred_boxes3d_selected = pred_boxes3d[inds]
        raw_scores_selected = raw_scores[inds]

        # NMS thresh
        boxes_bev_selected = kitti_utils.boxes3d_to_bev_torch(
            pred_boxes3d_selected)
        keep_idx = iou3d_utils.nms_gpu(boxes_bev_selected, raw_scores_selected,
                                       cfg.RCNN.NMS_THRESH)
        pred_boxes3d_selected = pred_boxes3d_selected[keep_idx]

        scores_selected = raw_scores_selected[keep_idx]
        pred_boxes3d_selected, scores_selected = pred_boxes3d_selected.cpu(
        ).numpy(), scores_selected.cpu().numpy()

        calib = dataset.get_calib(sample_id)
        final_total += pred_boxes3d_selected.shape[0]
        save_kitti_format(sample_id, calib, pred_boxes3d_selected,
                          final_output_dir, scores_selected, image_shape)

    progress_bar.close()

    # dump empty files
    split_file = os.path.join(dataset.imageset_dir, '..', '..', 'ImageSets',
                              dataset.split + '.txt')
    split_file = os.path.abspath(split_file)
    with open(split_file) as split_f:
        image_idx_list = [x.strip() for x in split_f.readlines()]
    empty_cnt = 0
    for image_idx in image_idx_list:
        cur_file = os.path.join(final_output_dir, '%s.txt' % image_idx)
        if not os.path.exists(cur_file):
            # Frames without detections still get a (empty) result file,
            # presumably so the KITTI evaluator finds one file per frame.
            with open(cur_file, 'w'):
                pass
            empty_cnt += 1
            logger.info('empty_cnt=%d: dump empty file %s' %
                        (empty_cnt, cur_file))

    ret_dict = {'empty_cnt': empty_cnt}

    logger.info(
        '-------------------performance of epoch %s---------------------' %
        epoch_id)
    logger.info(str(datetime.now()))

    avg_cls_acc = (total_cls_acc / max(cnt, 1.0))
    avg_cls_acc_refined = (total_cls_acc_refined / max(cnt, 1.0))
    avg_det_num = (final_total / max(cnt, 1.0))
    logger.info('final average detections: %.3f' % avg_det_num)
    logger.info('final average cls acc: %.3f' % avg_cls_acc)
    logger.info('final average cls acc refined: %.3f' % avg_cls_acc_refined)
    ret_dict['rcnn_cls_acc'] = avg_cls_acc
    ret_dict['rcnn_cls_acc_refined'] = avg_cls_acc_refined
    ret_dict['rcnn_avg_num'] = avg_det_num

    for idx, thresh in enumerate(thresh_list):
        cur_roi_recall = total_roi_recalled_bbox_list[idx] / max(
            total_gt_bbox, 1.0)
        logger.info('total roi bbox recall(thresh=%.3f): %d / %d = %f' %
                    (thresh, total_roi_recalled_bbox_list[idx], total_gt_bbox,
                     cur_roi_recall))
        ret_dict['rpn_recall(thresh=%.2f)' % thresh] = cur_roi_recall

    for idx, thresh in enumerate(thresh_list):
        cur_recall = total_recalled_bbox_list[idx] / max(total_gt_bbox, 1.0)
        logger.info(
            'total bbox recall(thresh=%.3f): %d / %d = %f' %
            (thresh, total_recalled_bbox_list[idx], total_gt_bbox, cur_recall))
        ret_dict['rcnn_recall(thresh=%.2f)' % thresh] = cur_recall

    if cfg.TEST.SPLIT != 'test':
        logger.info('Averate Precision:')
        name_to_class = {'Car': 0, 'Pedestrian': 1, 'Cyclist': 2}
        ap_result_str, ap_dict = kitti_evaluate(
            dataset.label_dir,
            final_output_dir,
            label_split_file=split_file,
            current_class=name_to_class[cfg.CLASSES])
        logger.info(ap_result_str)
        ret_dict.update(ap_dict)

    logger.info('result is saved to: %s' % result_dir)

    return ret_dict
Example #4
0
    ret_dict['rcnn_cls_acc_refined'] = avg_cls_acc_refined
    ret_dict['rcnn_avg_num'] = avg_det_num

    # ROI recall per IoU threshold; stored under the 'rpn_recall(...)' keys.
    # Dividing by max(total_gt_bbox, 1.0) avoids a zero division when no
    # GT boxes were seen.
    for idx, thresh in enumerate(thresh_list):
        cur_roi_recall = total_roi_recalled_bbox_list[idx] / max(
            total_gt_bbox, 1.0)
        logger.info('total roi bbox recall(thresh=%.3f): %d / %d = %f' %
                    (thresh, total_roi_recalled_bbox_list[idx], total_gt_bbox,
                     cur_roi_recall))
        ret_dict['rpn_recall(thresh=%.2f)' % thresh] = cur_roi_recall

    # Refined-box recall per IoU threshold; stored under 'rcnn_recall(...)'.
    for idx, thresh in enumerate(thresh_list):
        cur_recall = total_recalled_bbox_list[idx] / max(total_gt_bbox, 1.0)
        logger.info(
            'total bbox recall(thresh=%.3f): %d / %d = %f' %
            (thresh, total_recalled_bbox_list[idx], total_gt_bbox, cur_recall))
        ret_dict['rcnn_recall(thresh=%.2f)' % thresh] = cur_recall

    # KITTI AP is only computed when the split is not 'test'; the AP entries
    # returned by kitti_evaluate are merged into ret_dict.
    if cfg.TEST.SPLIT != 'test':
        logger.info('Averate Precision:')
        # Class name -> class index expected by kitti_evaluate's current_class.
        name_to_class = {'Car': 0, 'Pedestrian': 1, 'Cyclist': 2}
        ap_result_str, ap_dict = kitti_evaluate(
            dataset.label_dir,
            final_output_dir,
            label_split_file=split_file,
            current_class=name_to_class[cfg.CLASSES])
        logger.info(ap_result_str)
        ret_dict.update(ap_dict)

    logger.info('result is saved to: %s' % result_dir)
Example #5
0
def eval():
    """Evaluate a PointRCNN checkpoint on KITTI with PaddlePaddle.

    Parses the command line, loads the config (plus any ``--set`` overrides)
    and the weights, then builds the evaluation program for the stage
    selected by ``args.eval_mode`` ('rpn', 'rcnn' or 'rcnn_offline') and runs
    it on GPU 0.

    Each step's fetched outputs are put on a multiprocessing queue consumed
    by ``METRIC_PROC_NUM`` metric worker processes (``rpn_metric`` /
    ``rcnn_metric``). The workers share a manager dict that is read back
    for the final report. When the loader raises EOF, one ``None`` sentinel
    per worker is queued and the function waits until
    ``mdict['exit_proc']`` reaches ``METRIC_PROC_NUM``. It then logs the
    recall / IoU / accuracy summaries and, for RCNN evaluation, runs the
    KITTI AP evaluation on ``./result_dir/final_result/data``.

    Raises:
        NotImplementedError: if ``args.eval_mode`` is not recognised.
        AssertionError: if ``args.batch_size`` is not 1 in an RCNN mode, or
            if the given weights file does not exist.
    """
    args = parse_args()
    print_arguments(args)
    # check whether the installed paddle is compiled with GPU
    # PointRCNN model can only run on GPU
    check_gpu(True)

    load_config(args.cfg)
    if args.set_cfgs is not None:
        set_config_from_list(args.set_cfgs)

    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)

    if args.eval_mode == 'rpn':
        cfg.RPN.ENABLED = True
        cfg.RCNN.ENABLED = False
    elif args.eval_mode == 'rcnn':
        cfg.RCNN.ENABLED = True
        cfg.RPN.ENABLED = cfg.RPN.FIXED = True
        # The original `assert args.batch_size` accepted any non-zero size,
        # contradicting its own message.
        assert args.batch_size == 1, "batch size must be 1 in rcnn evaluation"
    elif args.eval_mode == 'rcnn_offline':
        cfg.RCNN.ENABLED = True
        cfg.RPN.ENABLED = False
        assert args.batch_size == 1, "batch size must be 1 in rcnn_offline evaluation"
    else:
        raise NotImplementedError("unkown eval mode: {}".format(
            args.eval_mode))

    place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)

    # build model
    startup = fluid.Program()
    eval_prog = fluid.Program()
    with fluid.program_guard(eval_prog, startup):
        with fluid.unique_name.guard():
            eval_model = PointRCNN(cfg, args.batch_size, True, 'TEST')
            eval_model.build()
            eval_loader = eval_model.get_loader()
            eval_feeds = eval_model.get_feeds()
            eval_outputs = eval_model.get_outputs()
    eval_prog = eval_prog.clone(True)

    # Extra fetch keys needed by the RPN metric workers.
    extra_keys = []
    if args.eval_mode == 'rpn':
        extra_keys.extend(['sample_id', 'rpn_cls_label', 'gt_boxes3d'])
        if args.save_rpn_feature:
            extra_keys.extend([
                'pts_rect',
                'pts_features',
                'pts_input',
            ])
    eval_keys, eval_values = parse_outputs(eval_outputs,
                                           prog=eval_prog,
                                           extra_keys=extra_keys)

    eval_compile_prog = fluid.compiler.CompiledProgram(
        eval_prog).with_data_parallel()

    exe.run(startup)

    # load weights
    if not os.path.isdir(args.weights):
        assert os.path.exists("{}.pdparams".format(args.weights)), \
                "Given resume weight {}.pdparams not exist.".format(args.weights)
    fluid.load(eval_prog, args.weights, exe)

    kitti_feature_dir = os.path.join(args.output_dir, 'features')
    kitti_output_dir = os.path.join(args.output_dir, 'detections', 'data')
    seg_output_dir = os.path.join(args.output_dir, 'seg_result')
    if args.save_rpn_feature:
        # Start from empty feature/detection/segmentation directories.
        if os.path.exists(kitti_feature_dir):
            shutil.rmtree(kitti_feature_dir)
        os.makedirs(kitti_feature_dir)
        if os.path.exists(kitti_output_dir):
            shutil.rmtree(kitti_output_dir)
        os.makedirs(kitti_output_dir)
        if os.path.exists(seg_output_dir):
            shutil.rmtree(seg_output_dir)
        os.makedirs(seg_output_dir)

    # must make sure these dirs existing
    roi_output_dir = os.path.join('./result_dir', 'roi_result', 'data')
    refine_output_dir = os.path.join('./result_dir', 'refine_result', 'data')
    final_output_dir = os.path.join("./result_dir", 'final_result', 'data')
    if not os.path.exists(final_output_dir):
        os.makedirs(final_output_dir)
    if args.save_result:
        if not os.path.exists(roi_output_dir):
            os.makedirs(roi_output_dir)
        if not os.path.exists(refine_output_dir):
            os.makedirs(refine_output_dir)

    # get reader
    kitti_rcnn_reader = KittiRCNNReader(
        data_dir=args.data_dir,
        npoints=cfg.RPN.NUM_POINTS,
        split=cfg.TEST.SPLIT,
        mode='EVAL',
        classes=cfg.CLASSES,
        rcnn_eval_roi_dir=args.rcnn_eval_roi_dir,
        rcnn_eval_feature_dir=args.rcnn_eval_feature_dir)
    eval_reader = kitti_rcnn_reader.get_multiprocess_reader(
        args.batch_size, eval_feeds)
    eval_loader.set_sample_list_generator(eval_reader, place)

    thresh_list = [0.1, 0.3, 0.5, 0.7, 0.9]
    queue = multiprocessing.Queue(128)
    mgr = multiprocessing.Manager()
    lock = multiprocessing.Lock()
    mdict = mgr.dict()
    # NOTE(review): in 'rcnn' mode both cfg.RPN.ENABLED and cfg.RCNN.ENABLED
    # are True, so both worker sets below are started on the same queue and
    # p_list is overwritten by the RCNN set; confirm this is intended.
    if cfg.RPN.ENABLED:
        mdict['exit_proc'] = 0
        mdict['total_gt_bbox'] = 0
        mdict['total_cnt'] = 0
        mdict['total_rpn_iou'] = 0
        for i in range(len(thresh_list)):
            mdict['total_recalled_bbox_list_{}'.format(i)] = 0

        p_list = []
        for i in range(METRIC_PROC_NUM):
            p_list.append(
                multiprocessing.Process(target=rpn_metric,
                                        args=(queue, mdict, lock, thresh_list,
                                              args.save_rpn_feature,
                                              kitti_feature_dir,
                                              seg_output_dir, kitti_output_dir,
                                              kitti_rcnn_reader, cfg.CLASSES)))
            p_list[-1].start()

    if cfg.RCNN.ENABLED:
        for i in range(len(thresh_list)):
            mdict['total_recalled_bbox_list_{}'.format(i)] = 0
            mdict['total_roi_recalled_bbox_list_{}'.format(i)] = 0
        mdict['exit_proc'] = 0
        mdict['total_cls_acc'] = 0
        mdict['total_cls_acc_refined'] = 0
        mdict['total_det_num'] = 0
        mdict['total_gt_bbox'] = 0
        p_list = []
        for i in range(METRIC_PROC_NUM):
            p_list.append(
                multiprocessing.Process(
                    target=rcnn_metric,
                    args=(queue, mdict, lock, thresh_list, kitti_rcnn_reader,
                          roi_output_dir, refine_output_dir, final_output_dir,
                          args.save_result)))
            p_list[-1].start()

    try:
        eval_loader.start()
        eval_iter = 0
        start_time = time.time()

        cur_time = time.time()
        # Runs until the loader is exhausted and exe.run raises EOFException.
        while True:
            eval_outs = exe.run(eval_compile_prog,
                                fetch_list=eval_values,
                                return_numpy=False)
            rets_dict = {
                k: (np.array(v), v.recursive_sequence_lengths())
                for k, v in zip(eval_keys, eval_outs)
            }
            run_time = time.time() - cur_time
            cur_time = time.time()
            queue.put(rets_dict)
            eval_iter += 1

            logger.info("[EVAL] iter {}, time: {:.2f}".format(
                eval_iter, run_time))

    except fluid.core.EOFException:
        # terminate metric process: one None sentinel per worker, then wait
        # until every worker has reported its exit through mdict.
        for i in range(METRIC_PROC_NUM):
            queue.put(None)
        while mdict['exit_proc'] < METRIC_PROC_NUM:
            time.sleep(1)
        for p in p_list:
            if p.is_alive():
                p.join()

        end_time = time.time()
        # max(..., 1) avoids a ZeroDivisionError when no batch was produced.
        logger.info(
            "[EVAL] total {} iter finished, average time: {:.2f}".format(
                eval_iter, (end_time - start_time) / float(max(eval_iter, 1))))

        if cfg.RPN.ENABLED:
            avg_rpn_iou = mdict['total_rpn_iou'] / max(len(kitti_rcnn_reader),
                                                       1.)
            logger.info("average rpn iou: {:.3f}".format(avg_rpn_iou))
            total_gt_bbox = float(max(mdict['total_gt_bbox'], 1.0))
            for idx, thresh in enumerate(thresh_list):
                recall = mdict['total_recalled_bbox_list_{}'.format(
                    idx)] / total_gt_bbox
                logger.info(
                    "total bbox recall(thresh={:.3f}): {} / {} = {:.3f}".
                    format(thresh,
                           mdict['total_recalled_bbox_list_{}'.format(idx)],
                           mdict['total_gt_bbox'], recall))

        if cfg.RCNN.ENABLED:
            cnt = float(max(eval_iter, 1.0))
            avg_cls_acc = mdict['total_cls_acc'] / cnt
            avg_cls_acc_refined = mdict['total_cls_acc_refined'] / cnt
            avg_det_num = mdict['total_det_num'] / cnt

            logger.info("avg_cls_acc: {}".format(avg_cls_acc))
            logger.info("avg_cls_acc_refined: {}".format(avg_cls_acc_refined))
            logger.info("avg_det_num: {}".format(avg_det_num))

            total_gt_bbox = float(max(mdict['total_gt_bbox'], 1.0))
            for idx, thresh in enumerate(thresh_list):
                cur_roi_recall = mdict['total_roi_recalled_bbox_list_{}'.
                                       format(idx)] / total_gt_bbox
                logger.info(
                    'total roi bbox recall(thresh=%.3f): %d / %d = %f' %
                    (thresh,
                     mdict['total_roi_recalled_bbox_list_{}'.format(idx)],
                     total_gt_bbox, cur_roi_recall))

            for idx, thresh in enumerate(thresh_list):
                cur_recall = mdict['total_recalled_bbox_list_{}'.format(
                    idx)] / total_gt_bbox
                logger.info(
                    'total bbox recall(thresh=%.2f) %d / %.2f = %.4f' %
                    (thresh, mdict['total_recalled_bbox_list_{}'.format(idx)],
                     total_gt_bbox, cur_recall))

            # Create an empty result file for every val frame that has none.
            split_file = os.path.join('./data/KITTI', 'ImageSets', 'val.txt')
            with open(split_file) as split_f:
                image_idx_list = [x.strip() for x in split_f.readlines()]
            for image_idx in image_idx_list:
                cur_file = os.path.join(final_output_dir,
                                        '%s.txt' % image_idx)
                if not os.path.exists(cur_file):
                    with open(cur_file, 'w'):
                        pass

            # Compare the version tuple: sys.version[:3] is "3.1" on Python
            # 3.10+, which a float comparison would treat as older than 3.6
            # and silently skip the mAP evaluation.
            if sys.version_info >= (3, 6):
                label_dir = os.path.join('./data/KITTI/object/training',
                                         'label_2')
                name_to_class = {'Car': 0, 'Pedestrian': 1, 'Cyclist': 2}

                from tools.kitti_object_eval_python.evaluate import evaluate as kitti_evaluate
                ap_result_str, ap_dict = kitti_evaluate(
                    label_dir,
                    final_output_dir,
                    label_split_file=split_file,
                    current_class=name_to_class["Car"])

                logger.info("KITTI evaluate: {}, {}".format(
                    ap_result_str, ap_dict))

            else:
                logger.info(
                    "KITTI mAP only support python version >= 3.6, users can "
                    "run 'python3 tools/kitti_eval.py' to evaluate KITTI mAP.")

    finally:
        eval_loader.reset()
Example #6
0
def _strip_trailing_zero_boxes(boxes):
    """Return ``boxes`` without the all-zero rows at its end.

    The ground-truth arrays are right-padded with all-zero rows; only the
    rows up to and including the last non-zero one are real boxes.
    """
    last = len(boxes) - 1
    while last >= 0 and boxes[last].sum() == 0:
        last -= 1
    return boxes[:last + 1]


def _draw_bev_box_outlines(ax, boxes3d, first_edge_color, other_edge_color):
    """Draw the x/z-plane outline of every box in ``boxes3d`` onto ``ax``.

    Corners come from ``kitti_utils.boxes3d_to_corners3d(boxes3d, rotate=True)``.
    For each box the four edges between its first four corners are drawn.
    The corner0 -> corner1 edge uses ``first_edge_color``. The remaining three
    edges use ``other_edge_color``.

    Args:
        ax: matplotlib Axes to draw into.
        boxes3d: boxes accepted by ``boxes3d_to_corners3d`` (presumably an
            (N, 7) numpy array -- confirm against kitti_utils).
        first_edge_color: color of the corner0 -> corner1 edge.
        other_edge_color: color of the other three edges.
    """
    corners3d = kitti_utils.boxes3d_to_corners3d(boxes3d, rotate=True)
    for box_corners in corners3d:
        xs = box_corners[0:4, 0]
        zs = box_corners[0:4, 2]
        for edge in range(4):
            nxt = (edge + 1) % 4
            color = first_edge_color if edge == 0 else other_edge_color
            ax.add_line(
                Line2D((xs[edge], xs[nxt]), (zs[edge], zs[nxt]),
                       linewidth=1,
                       color=color))


def _save_rpn_visualization(inputs, point_scores, rpn_center, gt_boxes3d,
                            sample_id, save_path='../visual/rpn.jpg'):
    """Save an x/z scatter of the RPN input points, kept centers and GT centers.

    Points are colored by ``point_scores``. Kept centers are drawn as white
    'x' markers and GT centers as blue '+' markers (skipped when
    ``gt_boxes3d`` is None). The figure is closed after saving.
    """
    inputs_plt = inputs.detach().cpu().numpy()
    scores_plt = point_scores.detach().cpu().numpy()
    point_center_plt = rpn_center.cpu().numpy()

    fig = plt.figure(figsize=(10, 10))
    plt.axes(facecolor='silver')
    plt.axis([-30, 30, 0, 70])
    plt.title('point_regressed_center %06d' % sample_id)
    plt.scatter(inputs_plt[:, 0],
                inputs_plt[:, 2],
                s=15,
                c=scores_plt,
                edgecolor='none',
                cmap=plt.get_cmap('rainbow'),
                alpha=1,
                marker='.',
                vmin=0,
                vmax=1)
    if point_center_plt.shape[0] > 0:
        plt.scatter(point_center_plt[:, 0],
                    point_center_plt[:, 1],
                    s=200,
                    c='white',
                    alpha=0.5,
                    marker='x')
    if gt_boxes3d is not None:
        gt_plt = gt_boxes3d.reshape(-1, 7)
        plt.scatter(gt_plt[:, 0],
                    gt_plt[:, 2],
                    s=200,
                    c='blue',
                    alpha=0.5,
                    marker='+')
    plt.savefig(save_path)
    plt.close(fig)


def _save_bev_visualization(inputs, pred_boxes3d_numpy, gt_boxes3d,
                            save_path):
    """Save an x/z scatter of ``inputs`` with predicted and GT box outlines.

    Points are colored by their y coordinate. Predicted boxes are outlined in
    green/red. GT boxes are outlined in yellow/purple when ``gt_boxes3d`` is
    not None and holds at least one box along dim 1. The figure is closed
    after saving.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    inputs_plt = inputs.detach().cpu().numpy()
    plt.axis([-35, 35, 0, 70])
    plt.scatter(inputs_plt[:, 0],
                inputs_plt[:, 2],
                s=15,
                c=inputs_plt[:, 1],
                edgecolor='none',
                cmap=plt.get_cmap('Blues'),
                alpha=1,
                marker='.',
                vmin=-1,
                vmax=2)
    _draw_bev_box_outlines(ax, pred_boxes3d_numpy, 'green', 'red')
    if gt_boxes3d is not None and gt_boxes3d.shape[1] > 0:
        _draw_bev_box_outlines(ax, gt_boxes3d.reshape(-1, 7), 'yellow',
                               'purple')
    plt.savefig(save_path)
    plt.close(fig)


def eval_one_epoch_joint(model, dataloader, epoch_id, result_dir, logger):
    """Run joint RPN + RCNN evaluation over ``dataloader`` and score it.

    Per batch:

    1. The RPN regresses an object center for every point.
    2. Centers are kept when the point score exceeds ``cfg.RPN.SCORE_THRESH``
       and the regressed center lies more than 0.2 from the point in x/z.
    3. Centers are thinned by a greedy radius NMS (0.3, highest raw score
       first).
    4. The RCNN head refines one box per kept center from the points within
       4.0 of it.
    5. Refined boxes are filtered by the RCNN / IoU-net score thresholds and a
       hard-coded (h, w, l) size window.
    6. Boxes are de-duplicated by a greedy NMS on the first IoU returned by
       ``boxes_iou3d_gpu``.
    7. Results are written in KITTI format under
       ``<result_dir>/final_result/data``, which is wiped first.

    Afterwards an empty result file is created for every split id without
    detections. Unless ``cfg.TEST.SPLIT == 'test'``, KITTI AP is computed with
    ``kitti_evaluate``.

    Only the first sample of every batch is used (``sample_id_list[0]``,
    ``gt_boxes3d[0]``), so a batch size of 1 is assumed.

    Reads the module-level ``args`` (``test``, ``eval_all``), ``cfg`` and
    ``VISUAL``.

    Args:
        model: network exposing ``eval()``, ``rpn_forward`` and
            ``rcnn_forward``.
        dataloader: iterable of dict batches with ``'sample_id'`` and
            ``'pts_input'`` (plus ``'gt_boxes3d'`` unless ``args.test``).
            ``dataloader.dataset`` must provide ``get_calib``,
            ``get_image_shape``, ``imageset_dir``, ``split`` and
            ``label_dir``.
        epoch_id: label used in log messages.
        result_dir: root directory for all outputs.
        logger: logger used for progress and metric reporting.

    Returns:
        ``(precision, recall, F2_score)``:

        - ``precision`` is the sum of the ``Car_3d_easy``, ``Car_3d_moderate``
          and ``Car_3d_hard`` APs. It is 0.0 when AP evaluation is not run
          (test split).
        - ``recall`` is the fraction of all non-padding GT boxes matched with
          3D IoU > 0.7.
        - ``F2_score`` is always 0.
    """
    np.random.seed(666)
    mode = 'TEST' if args.test else 'EVAL'

    final_output_dir = os.path.join(result_dir, 'final_result', 'data')
    # Start from an empty directory so stale detections from an earlier run
    # cannot leak into the AP computation.
    if os.path.exists(final_output_dir):
        shutil.rmtree(final_output_dir)
    os.makedirs(final_output_dir, exist_ok=True)

    logger.info('---- EPOCH %s JOINT EVALUATION ----' % epoch_id)
    logger.info('==> Output file: %s' % result_dir)
    model.eval()

    thresh_list = [0.1, 0.3, 0.5, 0.7, 0.9]
    recall_idx = thresh_list.index(0.7)  # headline recall threshold
    total_recalled_bbox_list, total_gt_bbox = [0] * len(thresh_list), 0
    dataset = dataloader.dataset
    # total_cls_acc / total_cls_acc_refined / total_rpn_iou are never updated
    # by this pipeline; they stay at 0 so the report keeps its fields.
    cnt = final_total = total_cls_acc = total_cls_acc_refined = total_rpn_iou = 0
    prop_count = 0
    progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval')

    # Per refined box: best IoU with GT and predicted IoU score (diagnostic plot).
    iou_list = []
    iou_p_score_list = []

    for data in dataloader:
        sample_id_list, pts_input = data['sample_id'], data['pts_input']
        sample_id = sample_id_list[0]
        cnt += len(sample_id_list)
        # Advance up front so samples skipped by `continue` are still counted.
        progress_bar.update()

        gt_boxes3d = None if args.test else data['gt_boxes3d']
        cur_gt_boxes3d = None
        if gt_boxes3d is not None:
            # Count every real GT box here, before any `continue`, so frames
            # without surviving detections still count as missed in recall.
            cur_gt_boxes3d = _strip_trailing_zero_boxes(gt_boxes3d[0])
            total_gt_bbox += cur_gt_boxes3d.shape[0]

        # ---- RPN inference ----
        inputs = torch.from_numpy(pts_input).cuda(non_blocking=True).float()
        rpn_out = model.rpn_forward({'pts_input': inputs})
        rpn_cls, rpn_reg = rpn_out['rpn_cls'], rpn_out['rpn_reg']
        rpn_backbone_xyz = rpn_out['backbone_xyz']

        rpn_scores_raw = rpn_cls[:, :, 0]
        rpn_scores_norm = torch.sigmoid(rpn_scores_raw)
        rcnn_input_scores = rpn_scores_norm.view(-1).clone()
        inputs = inputs.view(-1, inputs.shape[-1])
        rpn_backbone_xyz = rpn_backbone_xyz.view(-1,
                                                 rpn_backbone_xyz.shape[-1])

        # Decode a regressed object center for every backbone point.
        rpn_rois = decode_center_target(
            rpn_backbone_xyz,
            rpn_reg.view(-1, rpn_reg.shape[-1]),
            loc_scope=cfg.RPN.LOC_SCOPE,
            loc_bin_size=cfg.RPN.LOC_BIN_SIZE,
        ).view(-1, 3)
        rpn_reg_dist = rpn_rois - rpn_backbone_xyz

        # Keep confident points whose regressed center moved more than 0.2
        # away from the point itself in the x/z plane.
        rpn_mask = (rpn_scores_norm.view(-1) > cfg.RPN.SCORE_THRESH) & (
            rpn_reg_dist[:, [0, 2]].pow(2).sum(-1).sqrt() > 0.2)
        if not rpn_mask.any():
            continue
        rpn_scores_raw = rpn_scores_raw.view(-1)[rpn_mask]
        rpn_rois = rpn_rois[rpn_mask]

        # ---- Radius NMS on the candidate centers, highest raw score first ----
        rpn_rois = rpn_rois[torch.argsort(-rpn_scores_raw)]
        keep_id = [0]
        if rpn_rois.shape[0] > 1:
            center_dist = distance_2(rpn_rois[:, [0, 2]], rpn_rois[:, [0, 2]])
            for i in range(1, rpn_rois.shape[0]):
                if center_dist[keep_id, i].min() > 0.3:
                    keep_id.append(i)
        rpn_center = rpn_rois[keep_id][:, [0, 2]]  # (num_centers, [x, z])

        # ---- Keep only points near some kept center as RCNN input ----
        # distance_2(centers, points) is indexed as (num_points, num_centers)
        # below -- inferred from usage; confirm against distance_2.
        point_center_distance = distance_2(rpn_center, inputs[:, [0, 2]])
        near_mask = torch.min(point_center_distance, dim=-1)[0] < 4.0
        point_center_distance = point_center_distance[near_mask]
        inputs = inputs[near_mask]
        rcnn_input_scores = rcnn_input_scores[near_mask]

        if VISUAL:
            _save_rpn_visualization(inputs, rcnn_input_scores, rpn_center,
                                    gt_boxes3d, sample_id)

        # ---- RCNN refinement: one forward pass per kept center ----
        # y is shifted by 1.65 for the RCNN input and added back to the boxes
        # below. x/z are untouched, so point_center_distance stays valid.
        inputs[:, 1] -= 1.65
        prop_count += rpn_center.shape[0]
        box_list, raw_score_list, iou_score_list = [], [], []
        for c in range(rpn_center.shape[0]):
            member_mask = (point_center_distance[:, c] < 4.0).view(-1)
            if not member_mask.any():
                continue

            # Advanced indexing returns copies, so these in-place edits do not
            # modify `inputs`.
            points_xyz = inputs[member_mask, :3]
            points_xyz[:, 0] -= rpn_center[c, 0]
            points_xyz[:, 2] -= rpn_center[c, 1]
            points_r = inputs[member_mask, 3].view(-1, 1)
            points_mask = (rcnn_input_scores[member_mask] > 0.5).view(
                -1, 1).float()

            rcnn_out = model.rcnn_forward({
                'cur_box_point': points_xyz.unsqueeze(0).float(),
                'cur_box_reflect': points_r.unsqueeze(0).float(),
                'train_mask': points_mask.unsqueeze(0).float() - 0.5,
            })

            rcnn_box3d = rcnn_out['refined_box']
            # Wrap the heading into (-pi, pi] (element-wise, any box count).
            heading = rcnn_box3d[:, :, 6] % (np.pi * 2)
            rcnn_box3d[:, :, 6] = torch.where(heading > np.pi,
                                              heading - np.pi * 2, heading)
            # Move the box from the center-relative frame back to the input
            # frame and undo the y shift.
            rcnn_box3d[:, :, 0] += rpn_center[c][0]
            rcnn_box3d[:, :, 2] += rpn_center[c][1]
            rcnn_box3d[:, :, 1] += 1.65

            box_list.append(rcnn_box3d)
            raw_score_list.append(rcnn_out['rcnn_cls'].view(1, 1))
            iou_score_list.append(rcnn_out['rcnn_iou'].view(1, 1))

        # No center had any input points: nothing was refined (and torch.cat
        # on an empty list would raise).
        if not box_list:
            continue

        pred_boxes3d = torch.cat(box_list, dim=1)
        raw_rcnn_score = torch.cat(raw_score_list, dim=0).unsqueeze(0).float()
        norm_ioun_score = torch.cat(iou_score_list, dim=0).unsqueeze(0).float()
        norm_rcnn_score = torch.sigmoid(raw_rcnn_score)

        # ---- Score thresholds plus a hard-coded (h, w, l) size window ----
        pred_h = pred_boxes3d[:, :, 3].view(-1)
        pred_w = pred_boxes3d[:, :, 4].view(-1)
        pred_l = pred_boxes3d[:, :, 5].view(-1)
        inds = ((norm_rcnn_score > cfg.RCNN.SCORE_THRESH) &
                (norm_ioun_score > cfg.IOUN.SCORE_THRESH)).view(-1)
        inds = inds & \
            (pred_h > 1.1) & (pred_h < 2.3) & \
            (pred_w > 1.2) & (pred_w < 2.1) & \
            (pred_l > 2.1) & (pred_l < 5.1)
        pred_boxes3d = pred_boxes3d[:, inds]
        norm_ioun_score = norm_ioun_score[:, inds]
        if pred_boxes3d.shape[1] == 0:
            continue

        # ---- Recall against the ground truth of the (single) sample ----
        if cur_gt_boxes3d is not None and cur_gt_boxes3d.shape[0] > 0:
            gt_tensor = torch.from_numpy(cur_gt_boxes3d).cuda(
                non_blocking=True).float()
            _, iou3d = iou3d_utils.boxes_iou3d_gpu(pred_boxes3d[0], gt_tensor)
            # Assumes iou3d is (num_pred, num_gt) -- confirm against iou3d_utils.
            gt_max_iou, _ = iou3d.max(dim=0)
            refined_iou, _ = iou3d.max(dim=1)
            iou_list.append(refined_iou.view(-1, 1))
            iou_p_score_list.append(norm_ioun_score.view(-1, 1))
            for idx, thresh in enumerate(thresh_list):
                total_recalled_bbox_list[idx] += (gt_max_iou >
                                                  thresh).sum().item()

        if cnt == 1000 and iou_list:
            # One-off diagnostic: GT IoU vs. predicted IoU score, on its own
            # figure so it never draws over a VISUAL figure.
            iou_all = torch.cat(iou_list, dim=0).detach().cpu().numpy()
            iou_score_all = torch.cat(iou_p_score_list,
                                      dim=0).detach().cpu().numpy()
            fig = plt.figure()
            plt.axis([-.1, 1.1, -.1, 1.1])
            plt.scatter(iou_all,
                        iou_score_all,
                        s=20,
                        c='blue',
                        edgecolor='none',
                        alpha=1,
                        marker='.')
            plt.savefig(os.path.join(result_dir, 'distributercnn.png'))
            plt.close(fig)

        progress_bar.set_postfix({
            'mode': mode,
            'recall': '%d/%d' % (total_recalled_bbox_list[recall_idx],
                                 total_gt_bbox)
        })

        if VISUAL:
            _save_bev_visualization(inputs,
                                    pred_boxes3d[0].detach().cpu().numpy(),
                                    gt_boxes3d, '../visual/rcnn.jpg')

        # ---- Greedy NMS over refined boxes, highest IoU score first ----
        # (Every remaining box already passed both score thresholds above.)
        order = torch.argsort(-norm_ioun_score[0].view(-1))
        pred_boxes3d_selected = pred_boxes3d[0][order]
        norm_iou_scores_selected = norm_ioun_score[0][order]
        if pred_boxes3d_selected.shape[0] > 1:
            keep_id = [0]
            iou2d, _ = iou3d_utils.boxes_iou3d_gpu(pred_boxes3d_selected,
                                                   pred_boxes3d_selected)
            for i in range(1, pred_boxes3d_selected.shape[0]):
                if iou2d[keep_id, i].max() < 0.01:
                    keep_id.append(i)
            pred_boxes3d_selected = pred_boxes3d_selected[keep_id]
            norm_iou_scores_selected = norm_iou_scores_selected[keep_id]

        pred_boxes3d_selected = pred_boxes3d_selected.cpu().numpy()
        norm_iou_scores_selected = norm_iou_scores_selected.cpu().numpy()

        calib = dataset.get_calib(sample_id)
        final_total += pred_boxes3d_selected.shape[0]
        image_shape = dataset.get_image_shape(sample_id)
        save_kitti_format(sample_id, calib, pred_boxes3d_selected,
                          final_output_dir, norm_iou_scores_selected,
                          image_shape)

        if VISUAL:
            _save_bev_visualization(inputs, pred_boxes3d_selected, gt_boxes3d,
                                    '../visual/ioun.jpg')

    progress_bar.close()

    # The KITTI evaluator expects one (possibly empty) result file per split id.
    split_file = os.path.abspath(
        os.path.join(dataset.imageset_dir, '..', 'ImageSets',
                     dataset.split + '.txt'))
    with open(split_file) as split_f:
        image_idx_list = [x.strip() for x in split_f]
    empty_cnt = 0
    for image_idx in image_idx_list:
        cur_file = os.path.join(final_output_dir, '%s.txt' % image_idx)
        if not os.path.exists(cur_file):
            with open(cur_file, 'w'):
                pass
            empty_cnt += 1
            logger.info('empty_cnt=%d: dump empty file %s' %
                        (empty_cnt, cur_file))

    if not args.eval_all:
        logger.info(
            '-------------------performance of epoch %s---------------------' %
            epoch_id)
        logger.info(str(datetime.now()))

        avg_rpn_iou = (total_rpn_iou / max(cnt, 1.0))
        avg_cls_acc = (total_cls_acc / max(cnt, 1.0))
        avg_cls_acc_refined = (total_cls_acc_refined / max(cnt, 1.0))
        avg_det_num = (final_total / max(len(dataset), 1.0))
        logger.info('final average detections: %.3f' % avg_det_num)
        logger.info('final average proposals per sample: %.3f' %
                    (prop_count / max(cnt, 1.0)))
        logger.info('final average rpn_iou refined: %.3f' % avg_rpn_iou)
        logger.info('final average cls acc: %.3f' % avg_cls_acc)
        logger.info('final average cls acc refined: %.3f' %
                    avg_cls_acc_refined)

        for idx, thresh in enumerate(thresh_list):
            cur_recall = total_recalled_bbox_list[idx] / max(
                total_gt_bbox, 1.0)
            logger.info('total bbox recall(thresh=%.3f): %d / %d = %f' %
                        (thresh, total_recalled_bbox_list[idx], total_gt_bbox,
                         cur_recall))

    # ap_dict stays None when AP is not evaluated (test split); previously
    # this path crashed with a NameError below.
    ap_dict = None
    if cfg.TEST.SPLIT != 'test':
        logger.info('Average Precision:')
        name_to_class = {'Car': 0, 'Pedestrian': 1, 'Cyclist': 2}
        ap_result_str, ap_dict = kitti_evaluate(
            dataset.label_dir,
            final_output_dir,
            label_split_file=split_file,
            current_class=name_to_class[cfg.CLASSES])
        if not args.eval_all:
            logger.info(ap_result_str)

    logger.info('result is saved to: %s' % result_dir)
    if ap_dict is None:
        precision = 0.0
    else:
        precision = ap_dict['Car_3d_easy'] + ap_dict[
            'Car_3d_moderate'] + ap_dict['Car_3d_hard']
    recall = total_recalled_bbox_list[recall_idx] / max(total_gt_bbox, 1.0)
    F2_score = 0
    return precision, recall, F2_score