def load_GT_eval(indice, database, split,
                 root_dir='/root/frustum-pointnets_RSC/dataset/'):
    """Collect ground-truth 3D box corners for every sample id below ``indice + 1``.

    Args:
        indice: upper bound on sample ids; a sample is kept when
            ``sample_id < indice + 1`` (``sample_id <= indice`` for ints).
        database: forwarded to ``KittiDataset`` as ``dataset``.
        split: forwarded to ``KittiDataset`` as ``split``.
        root_dir: dataset root forwarded to ``KittiDataset``; defaults to the
            path that used to be hard-coded here.

    Returns:
        (corners_frame, id_list_new): ``corners_frame[n]`` is the output of
        ``kitti_utils.boxes3d_to_corners3d(..., transform=False)`` for the
        ``filtrate_objects``-filtered labels of sample ``id_list_new[n]``.
    """
    data_val = KittiDataset(root_dir=root_dir,
                            dataset=database,
                            mode='TRAIN',
                            split=split)
    corners_frame = []
    id_list_new = []
    for sample_id in data_val.sample_id_list:
        if sample_id < indice + 1:
            gt_obj_list = data_val.filtrate_objects(
                data_val.get_label(sample_id))
            gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
            corners_frame.append(
                kitti_utils.boxes3d_to_corners3d(gt_boxes3d, transform=False))
            id_list_new.append(sample_id)
    return corners_frame, id_list_new
    def __init__(self,
                 npoints,
                 split,
                 random_flip=False,
                 random_shift=False,
                 rotate_to_center=False,
                 overwritten_data_path=None,
                 from_rgb_detection=False,
                 one_hot=False,
                 root_dir='/home/amben/frustum-pointnets_RSC/dataset/'):
        '''
        Input:
            npoints: int scalar, number of points for frustum point cloud.
            split: string, train or val
            random_flip: bool, in 50% randomly flip the point cloud
                in left and right (after the frustum rotation if any)
            random_shift: bool, if True randomly shift the point cloud
                back and forth by a random distance
            rotate_to_center: bool, whether to do frustum rotation
            overwritten_data_path: string, specify pickled file path.
                if None, use default path (with the split)
            from_rgb_detection: bool, if True we assume we do not have
                groundtruth, just return data elements.
            one_hot: bool, if True, return one hot vector
            root_dir: string, dataset root handed to KittiDataset
                (defaults to the previously hard-coded path).

        Only the from_rgb_detection path reads overwritten_data_path; the
        'train' and 'val' splits rebuild their frustums from the raw KITTI
        frames. Any other split value leaves the per-frustum lists unset.
        '''
        self.dataset_kitti = KittiDataset(
            root_dir=root_dir,
            mode='TRAIN',
            split=split)
        self.npoints = npoints
        self.random_flip = random_flip
        self.random_shift = random_shift
        self.rotate_to_center = rotate_to_center
        self.one_hot = one_hot
        if overwritten_data_path is None:
            overwritten_data_path = os.path.join(
                ROOT_DIR, 'kitti/frustum_carpedcyc_%s.pickle' % (split))

        self.from_rgb_detection = from_rgb_detection
        if from_rgb_detection:
            # The pickle holds consecutive objects, read back in this order.
            with open(overwritten_data_path, 'rb') as fp:
                self.id_list = pickle.load(fp)
                self.box2d_list = pickle.load(fp)
                self.input_list = pickle.load(fp)
                self.type_list = pickle.load(fp)
                # frustum_angle is clockwise angle from positive x-axis
                self.frustum_angle_list = pickle.load(fp)
                self.prob_list = pickle.load(fp)
        elif split == 'train':
            self._load_train_frustums()
        elif split == 'val':
            self._load_val_frustums()

    def _reset_frustum_lists(self):
        """Create the empty per-frustum lists filled by the split loaders."""
        self.frustum_angle_list = []
        self.input_list = []
        self.label_list = []
        self.box3d_list = []
        self.box2d_list = []
        self.type_list = []
        self.heading_list = []
        self.size_list = []

    @staticmethod
    def _frustum_angle(pc_lidar, pixels, box2d):
        """Return ``-atan2(p[2], p[0])`` for the lidar point nearest the 2D box center.

        box2d indices 0/2 are the x bounds and 1/3 the y bounds (the height
        filters below use ``box2d[3] - box2d[1]``), so the center row must
        average indices 1 and 3.
        """
        center_box2d = np.array([(box2d[0] + box2d[2]) / 2.0,
                                 (box2d[1] + box2d[3]) / 2.0])
        pc_center_frus = get_closest_pc_to_center(pc_lidar, pixels,
                                                  center_box2d)
        return -1 * np.arctan2(pc_center_frus[2], pc_center_frus[0])

    def _load_train_frustums(self):
        """Build frustums around (perturbed) ground-truth 2D boxes of every frame.

        A frustum is dropped when its 2D box is under 25 px tall or when no
        frustum point lies inside the object's 3D box. Point labels are 1
        inside that box and 0 elsewhere.
        """
        self.id_list = self.dataset_kitti.sample_id_list
        self.idx_batch = self.id_list
        self._reset_frustum_lists()
        batch_list = []
        pos_cnt = 0
        all_cnt = 0

        perturb_box2d = True
        augment_x = 1  # frustums generated per ground-truth object
        for sample_id in self.id_list:
            pc_lidar = self.dataset_kitti.get_lidar(sample_id)
            gt_obj_list = self.dataset_kitti.get_label(sample_id)
            pixels = get_pixels(sample_id)
            for gt_obj in gt_obj_list:
                for _ in range(augment_x):
                    # Augment data by box2d perturbation
                    if perturb_box2d:
                        box2d = random_shift_box2d(gt_obj.box2d)
                    else:
                        box2d = gt_obj.box2d
                    # Cheap size filter first: such boxes are discarded anyway.
                    if box2d[3] - box2d[1] < 25:
                        continue
                    frus_pc, _ = extract_pc_in_box2d(pc_lidar, pixels, box2d)

                    # Label frustum points that fall inside this object's 3D box.
                    cls_label = np.zeros((frus_pc.shape[0]), dtype=np.int32)
                    gt_boxes3d = kitti_utils.objs_to_boxes3d([gt_obj])
                    gt_corners = kitti_utils.boxes3d_to_corners3d(
                        gt_boxes3d, transform=True)
                    box_corners = gt_corners[0]
                    fg_pt_flag = kitti_utils.in_hull(frus_pc[:, 0:3],
                                                     box_corners)
                    cls_label[fg_pt_flag] = 1
                    if np.sum(cls_label) == 0:
                        continue

                    self.input_list.append(frus_pc)
                    self.frustum_angle_list.append(
                        self._frustum_angle(pc_lidar, pixels, box2d))
                    self.label_list.append(cls_label)
                    self.box3d_list.append(box_corners)
                    self.box2d_list.append(box2d)
                    self.type_list.append("Pedestrian")
                    self.heading_list.append(gt_obj.ry)
                    self.size_list.append(
                        np.array([gt_obj.l, gt_obj.w, gt_obj.h]))
                    batch_list.append(sample_id)
                    pos_cnt += np.sum(cls_label)
                    all_cnt += frus_pc.shape[0]

        self.id_list = batch_list
        # Guard: with no kept frustum both ratios would divide by zero.
        if all_cnt > 0:
            print('Average pos ratio: %f' % (pos_cnt / float(all_cnt)))
            print('Average npoints: %f' % (float(all_cnt) / len(self.id_list)))

    def _load_val_frustums(self):
        """Build frustums around the detected 2D boxes of every frame.

        Boxes under 25 px tall, or whose frustum holds fewer than 50 points,
        are skipped. Each frustum is matched to the ground-truth 3D box that
        contains most of its points (first box on ties). With fewer than 20
        points inside any ground-truth box, the frustum gets an all-zero
        segmentation, a unit size, a zero heading and the corners of a dummy
        box centred at (-1, -1, -1).
        """
        self.id_list = self.dataset_kitti.sample_id_list
        self.idx_batch = self.id_list
        self._reset_frustum_lists()
        batch_list = []
        total_points = 0
        for sample_id in self.id_list:
            pc_lidar = self.dataset_kitti.get_lidar(sample_id)
            box2ds = get_2Dboxes_detected(sample_id)
            # `is None`: `== None` on a numpy array is elementwise and makes
            # the `if` ambiguous.
            if box2ds is None:
                continue
            pixels = get_pixels(sample_id)
            for box2d in box2ds:
                if (box2d[3] - box2d[1]) < 25:
                    continue
                frus_pc, _ = extract_pc_in_box2d(pc_lidar, pixels, box2d)
                if len(frus_pc) < 50:
                    continue
                frustum_angle = self._frustum_angle(pc_lidar, pixels, box2d)

                gt_obj_list = self.dataset_kitti.filtrate_objects(
                    self.dataset_kitti.get_label(sample_id))
                gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
                gt_corners = kitti_utils.boxes3d_to_corners3d(
                    gt_boxes3d, transform=True)
                n_boxes = gt_boxes3d.shape[0]

                # cls_label[p] = k + 1 when point p lies in ground-truth box k
                # (later boxes overwrite earlier ones), 0 otherwise.
                cls_label = np.zeros((frus_pc.shape[0]), dtype=np.int32)
                for k in range(n_boxes):
                    fg_pt_flag = kitti_utils.in_hull(frus_pc[:, 0:3],
                                                     gt_corners[k])
                    cls_label[fg_pt_flag] = k + 1

                if np.count_nonzero(cls_label > 0) < 20:
                    seg = np.zeros_like(cls_label)
                    size = np.ones((3))
                    rot_angle = 0.0
                    box3d_center = np.ones((3)) * (-1.0)
                    box3d = np.array([[
                        box3d_center[0], box3d_center[1], box3d_center[2],
                        size[0], size[1], size[2], rot_angle
                    ]])
                    bb_corners = kitti_utils.boxes3d_to_corners3d(
                        box3d, transform=True)[0]
                else:
                    # argmax returns the first maximum, like a strict-`>` scan.
                    counts = np.bincount(cls_label, minlength=n_boxes + 1)[1:]
                    corners_max = int(np.argmax(counts))
                    seg = np.where(cls_label == corners_max + 1, 1, 0)
                    bb_corners = gt_corners[corners_max]
                    # Size/heading must come from the selected box, not from
                    # the last index of the labelling loop.
                    obj = gt_boxes3d[corners_max]
                    size = np.array([obj[3], obj[4], obj[5]])
                    rot_angle = obj[6]

                self.input_list.append(frus_pc)
                total_points += frus_pc.shape[0]
                self.frustum_angle_list.append(frustum_angle)
                self.label_list.append(seg)
                self.box3d_list.append(bb_corners)
                self.box2d_list.append(box2d)
                self.type_list.append("Pedestrian")
                self.heading_list.append(rot_angle)
                self.size_list.append(size)
                batch_list.append(sample_id)
        self.id_list = batch_list
        # Running total instead of re-summing every stored frustum per append.
        if self.input_list:
            print("average number of cloud:",
                  total_points / len(self.input_list))
Beispiel #3
0
    def __init__(self, radar_file, npoints, split,
                 random_flip=False, random_shift=False, rotate_to_center=False,
                 overwritten_data_path=None, from_rgb_detection=False,
                 one_hot=False, all_batches=False):
        '''
        Input:
            radar_file: passed unchanged as the first positional argument of
                KittiDataset (presumably selects the radar source -- TODO confirm).
            npoints: int scalar, number of points for frustum point cloud.
            split: string, train or val
            random_flip: bool, in 50% randomly flip the point cloud
                in left and right (after the frustum rotation if any)
            random_shift: bool, if True randomly shift the point cloud
                back and forth by a random distance
            rotate_to_center: bool, whether to do frustum rotation
            overwritten_data_path: string, specify pickled file path.
                if None, use default path (with the split)
            from_rgb_detection: bool, if True we assume we do not have
                groundtruth, just return data elements.
            one_hot: bool, if True, return one hot vector
            all_batches: bool, stored as self.all_batches (only the disabled
                prototype below reads it).

        Without from_rgb_detection this constructor only evaluates how well
        radar returns select ground-truth lidar points and prints recall /
        precision statistics; the per-frustum lists are created empty.
        '''
        self.dataset_kitti = KittiDataset(
            radar_file,
            root_dir='/home/amben/frustum-pointnets_RSC/dataset/',
            mode='TRAIN',
            split=split)
        self.all_batches = all_batches
        self.npoints = npoints
        self.random_flip = random_flip
        self.random_shift = random_shift
        self.rotate_to_center = rotate_to_center
        self.one_hot = one_hot
        self.pc_lidar_list = []
        if overwritten_data_path is None:
            overwritten_data_path = os.path.join(
                ROOT_DIR, 'kitti/frustum_carpedcyc_%s.pickle' % (split))

        self.from_rgb_detection = from_rgb_detection
        if from_rgb_detection:
            with open(overwritten_data_path, 'rb') as fp:
                self.id_list = pickle.load(fp)
                self.box2d_list = pickle.load(fp)
                self.input_list = pickle.load(fp)
                self.type_list = pickle.load(fp)
                # frustum_angle is clockwise angle from positive x-axis
                self.frustum_angle_list = pickle.load(fp)
                self.prob_list = pickle.load(fp)
        else:
            self.id_list = self.dataset_kitti.sample_id_list
            self.idx_batch = self.id_list
            self.radar_OI = []
            self.batch_size = []
            self.batch_train = []
            self.frustum_angle_list = []
            self.input_list = []
            self.label_list = []
            self.box3d_list = []
            self.type_list = []
            self.heading_list = []
            self.size_list = []
            self.radar_point_list = []

            average_recall = 0
            average_extracted_pt = 0
            average_precision = 0
            average_extracted_pt_per_frame = 0
            for sample_id in self.id_list:
                recall, precision, extracted_pt, n_radar = \
                    self._radar_frame_stats(sample_id)
                # A frame without radar returns has zero sums; dividing by 1
                # makes it contribute 0 instead of raising ZeroDivisionError.
                radar_denom = float(max(n_radar, 1))
                print("recall", recall / radar_denom)
                print("precision", precision / radar_denom)
                average_recall += recall / radar_denom
                average_extracted_pt += extracted_pt / radar_denom
                average_extracted_pt_per_frame += extracted_pt
                average_precision += precision / radar_denom
            # Same guard for an empty sample list.
            frame_denom = float(max(len(self.id_list), 1))
            average_recall = average_recall / frame_denom
            average_extracted_pt = average_extracted_pt / frame_denom
            average_precision = average_precision / frame_denom
            average_extracted_pt_per_frame = \
                average_extracted_pt_per_frame / frame_denom
            print("average_recall", average_recall)
            print("average_precision", average_precision)
            print("average_extracted_pt", average_extracted_pt)
            print("average_extracted_pt_per_frame",
                  average_extracted_pt_per_frame)

            # NOTE(review): the string below is a disabled prototype that would
            # fill the per-frustum lists from radar-selected lidar points. It
            # is a bare string literal and never executes. If revived: it uses
            # `batch_list`, and its non-empty branch reads `gt_boxes3d[k]`
            # where `gt_boxes3d[corners_max]` looks intended.
            """

                m=0
                for j in range(len(pc_radar)):
                    #print(pc_radar[j].reshape(-1, 3).shape[0])
                    if (pc_radar[j,2]>1.5):
                        radar_mask = get_radar_mask(pc_lidar, pc_radar[j].reshape(-1, 3))
                        if(np.count_nonzero(radar_mask==1)>50):
                            radar_idx = np.argwhere(radar_mask == 1)
                            pc_fil = pc_lidar[radar_idx.reshape(-1)]
                            self.radar_OI.append(j)
                            m=m+1
                            radar_angle = -1 * np.arctan2(pc_radar[j,2],pc_radar[j,0])
                            cls_label = np.zeros((pc_fil.shape[0]), dtype=np.int32)
                            gt_obj_list = self.dataset_kitti.filtrate_objects(
                                self.dataset_kitti.get_label(self.id_list[i]))
                            gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
                            gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, transform=True)
                            for k in range(gt_boxes3d.shape[0]):
                                box_corners = gt_corners[k]
                                fg_pt_flag = kitti_utils.in_hull(pc_fil[:, 0:3], box_corners)
                                cls_label[fg_pt_flag] = k + 1
                            if (np.count_nonzero(cls_label > 0) < 20):
                                if(self.all_batches):
                                    center = np.ones((3)) * (-1.0)
                                    heading = 0.0
                                    size = np.ones((3))
                                    cls_label[cls_label > 0] = 0
                                    seg = cls_label
                                    rot_angle = 0.0
                                    box3d_center = np.ones((3)) * (-1.0)
                                    box3d = np.array([[box3d_center[0], box3d_center[1], box3d_center[2], size[0], size[1],
                                                   size[2], rot_angle]])
                                    corners_empty = kitti_utils.boxes3d_to_corners3d(box3d, transform=True)
                                    bb_corners = corners_empty[0]
                                    batch = 0
                                else:
                                    continue

                            else:
                                max = 0
                                corners_max = 0
                                for k in range(gt_boxes3d.shape[0]):
                                    count = np.count_nonzero(cls_label == k + 1)
                                    if count > max:
                                        max = count
                                        corners_max = k
                                seg = np.where(cls_label == corners_max + 1, 1, 0)
                                bb_corners = gt_corners[corners_max]
                                obj = gt_boxes3d[k]
                                center = np.array([obj[0], obj[1], obj[2]])
                                size = np.array([obj[3], obj[4], obj[5]])
                                rot_angle = obj[6]
                                batch = 1
                            self.input_list.append(pc_fil)
                            self.frustum_angle_list.append(radar_angle)
                            self.label_list.append(seg)
                            self.box3d_list.append(bb_corners)
                            self.type_list.append("Pedestrian")
                            self.heading_list.append(rot_angle)
                            self.size_list.append(size)
                            self.batch_train.append(batch)
                            self.radar_point_list.append(pc_radar[j])
                            batch_list.append(self.id_list[i])
                            print(len(batch_list))
                            print(len(self.input_list))


                self.batch_size.append(m)
            self.id_list= batch_list
            print("id_list",len(self.id_list))
            print("self.input_list",len(self.input_list))

            #load radar
            #load pc
            #create mask
            #save only the one containing pc cpntainiing more than 50>


            """

    def _radar_frame_stats(self, sample_id):
        """Score each radar return of one frame against the ground-truth lidar labels.

        Every lidar point gets label k + 1 when it lies in ground-truth box k
        (0 otherwise). For each radar return, ``get_radar_mask`` selects lidar
        points; returns that hit no labelled point are skipped. For the others,
        the box with the most selected points is the target:
        recall = hits / points of that box, precision = hits / selected points.

        Returns:
            (recall_sum, precision_sum, extracted_pt, n_radar): sums over the
            scored returns, the total number of selected lidar points, and
            ``len(pc_radar)``.
        """
        print("frame nbr", sample_id)
        pc_radar = self.dataset_kitti.get_radar(sample_id)
        pc_lidar = self.dataset_kitti.get_lidar(sample_id)
        cls_label = np.zeros((pc_lidar.shape[0]), dtype=np.int32)
        gt_obj_list = self.dataset_kitti.filtrate_objects(
            self.dataset_kitti.get_label(sample_id))
        gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
        gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d,
                                                      transform=True)
        n_boxes = gt_boxes3d.shape[0]
        for k in range(n_boxes):
            fg_pt_flag = kitti_utils.in_hull(pc_lidar[:, 0:3], gt_corners[k])
            cls_label[fg_pt_flag] = k + 1

        recall = 0.0
        precision = 0.0
        extracted_pt = 0.0
        for radar_pt in pc_radar:
            radar_mask = get_radar_mask(pc_lidar, radar_pt.reshape(-1, 3))
            radar_label_int = radar_mask * cls_label
            n_hit = np.count_nonzero(radar_label_int > 0)
            print("radar_label_int_nbr", n_hit)
            if n_hit == 0:
                continue
            # Box with the most selected points (first one on ties).
            best_count = 0
            corners_max = 0
            for k in range(n_boxes):
                count = np.count_nonzero(radar_label_int == k + 1)
                if count > best_count:
                    best_count = count
                    corners_max = k
            n_selected = np.count_nonzero(radar_mask == 1)
            n_gt = np.count_nonzero(cls_label == corners_max + 1)
            recall_r = best_count / float(n_gt)
            precision_r = best_count / float(n_selected)
            print("radar_mask", n_selected)
            print("label_extracted", best_count)
            print("ground truth", n_gt)
            print(recall_r)
            extracted_pt += n_selected
            print("recall_r", recall_r)
            print("precision_r", precision_r)
            recall += recall_r
            precision += precision_r
        return recall, precision, extracted_pt, len(pc_radar)
Beispiel #4
0
    def __getitem__(self, index):
        """Return the resampled, image-visible point cloud of ``sample_id_list[index]``.

        Points are kept only where ``get_valid_flag`` accepts their image
        projection, then resampled to exactly ``self.npoints`` indices. When
        downsampling, every point with depth >= 40 is kept and the rest of the
        budget is filled with nearer points; if the far points alone exceed
        the budget, they are subsampled instead.

        In 'TEST' mode the dict holds 'sample_id', 'pts_input', 'pts_rect' and
        'pts_features'. Otherwise it holds 'sample_id', 'pts_input',
        'pts_rect' and 'cls_labels' (from ``generate_training_labels`` on the
        filtered ground-truth boxes).

        Raises:
            ValueError: if upsampling is needed but no point survived the
                validity filter.
        """
        sample_id = int(self.sample_id_list[index])
        calib = self.get_calib(sample_id)
        img_shape = self.get_image_shape(sample_id)
        pts_lidar = self.get_lidar(sample_id)

        # get valid point (projected points should be in image)
        pts_rect = calib.lidar_to_rect(pts_lidar[:, 0:3])
        pts_intensity = pts_lidar[:, 3]
        pts_img, pts_rect_depth = calib.rect_to_img(pts_rect)
        pts_valid_flag = self.get_valid_flag(pts_rect, pts_img,
                                             pts_rect_depth, img_shape)
        pts_rect = pts_rect[pts_valid_flag][:, 0:3]
        pts_intensity = pts_intensity[pts_valid_flag]

        if self.npoints < len(pts_rect):
            pts_near_flag = pts_rect[:, 2] < 40.0
            far_idxs_choice = np.where(~pts_near_flag)[0]
            near_idxs = np.where(pts_near_flag)[0]
            n_near = self.npoints - len(far_idxs_choice)
            if n_near < 0:
                # More far points than the budget: a negative sample size would
                # make np.random.choice raise, so subsample the far points.
                choice = np.random.choice(far_idxs_choice, self.npoints,
                                          replace=False)
            else:
                near_idxs_choice = np.random.choice(near_idxs, n_near,
                                                    replace=False)
                choice = np.concatenate((near_idxs_choice, far_idxs_choice),
                                        axis=0)
            np.random.shuffle(choice)
        else:
            choice = np.arange(0, len(pts_rect), dtype=np.int32)
            n_extra = self.npoints - len(pts_rect)
            if n_extra > 0:
                if len(choice) == 0:
                    raise ValueError(
                        'sample %d has no image-visible lidar points' % sample_id)
                # Without replacement we can add at most len(choice) duplicates.
                extra_choice = np.random.choice(choice, n_extra,
                                                replace=n_extra > len(choice))
                choice = np.concatenate((choice, extra_choice), axis=0)
            np.random.shuffle(choice)

        ret_pts_rect = pts_rect[choice, :]
        # Shift intensity by -0.5 (maps [0, 1] to [-0.5, 0.5], assuming the raw
        # intensity lies in [0, 1]).
        ret_pts_intensity = pts_intensity[choice] - 0.5
        # Intensity is the only extra per-point feature: shape (N, 1).
        ret_pts_features = ret_pts_intensity.reshape(-1, 1)

        if USE_INTENSITY:
            pts_input = np.concatenate((ret_pts_rect, ret_pts_features),
                                       axis=1)  # (N, C)
        else:
            pts_input = ret_pts_rect

        sample_info = {'sample_id': sample_id}
        if self.mode == 'TEST':
            sample_info['pts_input'] = pts_input
            sample_info['pts_rect'] = ret_pts_rect
            sample_info['pts_features'] = ret_pts_features
            return sample_info

        gt_obj_list = self.filtrate_objects(self.get_label(sample_id))
        gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)

        # generate training labels
        cls_labels = self.generate_training_labels(ret_pts_rect, gt_boxes3d)
        sample_info['pts_input'] = pts_input
        sample_info['pts_rect'] = ret_pts_rect
        sample_info['cls_labels'] = cls_labels
        return sample_info
    def __init__(self,
                 npoints,
                 database,
                 split,
                 res,
                 random_flip=False,
                 random_shift=False,
                 rotate_to_center=False,
                 overwritten_data_path=None,
                 from_rgb_detection=False,
                 one_hot=False):
        '''
        Input:
            npoints: int scalar, number of points for frustum point cloud.
            split: string, train or val
            random_flip: bool, in 50% randomly flip the point cloud
                in left and right (after the frustum rotation if any)
            random_shift: bool, if True randomly shift the point cloud
                back and forth by a random distance
            rotate_to_center: bool, whether to do frustum rotation
            overwritten_data_path: string, specify pickled file path.
                if None, use default path (with the split)
            from_rgb_detection: bool, if True we assume we do not have
                groundtruth, just return data elements.
            one_hot: bool, if True, return one hot vector
        '''
        self.dataset_kitti = KittiDataset(
            root_dir='/root/frustum-pointnets_RSC/dataset/',
            dataset=database,
            mode='TRAIN',
            split=split)
        self.npoints = npoints
        self.random_flip = random_flip
        self.random_shift = random_shift
        self.rotate_to_center = rotate_to_center
        self.res_det = res
        self.one_hot = one_hot

        if overwritten_data_path is None:
            overwritten_data_path = os.path.join(
                ROOT_DIR, 'kitti/frustum_carpedcyc_%s.pickle' % (split))

        self.from_rgb_detection = from_rgb_detection
        if from_rgb_detection:
            with open(overwritten_data_path, 'rb') as fp:
                self.id_list = pickle.load(fp)
                self.box2d_list = pickle.load(fp)
                self.input_list = pickle.load(fp)
                self.type_list = pickle.load(fp)
                # frustum_angle is clockwise angle from positive x-axis
                self.frustum_angle_list = pickle.load(fp)
                self.prob_list = pickle.load(fp)
        elif (split == 'train'):
            """
            with open(overwritten_data_path,'rb') as fp:
                self.id_list = pickle.load(fp)
                self.box2d_list = pickle.load(fp)
                self.box3d_list = pickle.load(fp)
                self.input_list = pickle.load(fp)
                self.label_list = pickle.load(fp)
                self.type_list = pickle.load(fp)
                self.heading_list = pickle.load(fp)
                self.size_list = pickle.load(fp)
                # frustum_angle is clockwise angle from positive x-axis
                self.frustum_angle_list = pickle.load(fp) 
            """

            self.id_list = self.dataset_kitti.sample_id_list[:32]
            self.idx_batch = self.id_list
            batch_list = []
            self.frustum_angle_list = []
            self.input_list = []
            self.label_list = []
            self.box3d_list = []
            self.box2d_list = []
            self.type_list = []
            self.heading_list = []
            self.size_list = []

            perturb_box2d = True
            augmentX = 5
            for i in range(len(self.id_list)):
                #load pc
                print(self.id_list[i])
                pc_lidar = self.dataset_kitti.get_lidar(self.id_list[i])
                #load_labels
                gt_obj_list_2D = self.dataset_kitti.get_label_2D(
                    self.id_list[i])
                ps = pc_lidar
                """gt_obj_list = self.dataset_kitti.get_label(self.id_list[i])
                gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
                # gt_boxes3d = gt_boxes3d[self.box_present[index] - 1].reshape(-1, 7)

                cls_label = np.zeros((pc_lidar.shape[0]), dtype=np.int32)
                gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, transform=False)
                for k in range(gt_boxes3d.shape[0]):
                    box_corners = gt_corners[k]
                    fg_pt_flag = kitti_utils.in_hull(pc_lidar[:, 0:3], box_corners)
                    cls_label[fg_pt_flag] = 1

                seg = cls_label
                fig = mlab.figure(figure=None, bgcolor=(0.4, 0.4, 0.4), fgcolor=None, engine=None, size=(1000, 500))
                mlab.points3d(ps[:, 0], ps[:, 1], ps[:, 2], seg, mode='point', colormap='gnuplot', scale_factor=1,
                              figure=fig)
                mlab.points3d(0, 0, 0, color=(1, 1, 1), mode='sphere', scale_factor=0.2, figure=fig)
                for s in range(len(gt_corners)):
                    center = np.array([gt_boxes3d[s][0], gt_boxes3d[s][1], gt_boxes3d[s][2]])
                    size = np.array([gt_boxes3d[s][3], gt_boxes3d[s][4], gt_boxes3d[s][5]])
                    rot_angle = gt_boxes3d[s][6]
                    box3d_from_label = get_3d_box(size, rot_angle,
                                                  center)
                    draw_gt_boxes3d([box3d_from_label], fig, color=(1, 0, 0))
                mlab.orientation_axes()
                raw_input()"""
                #load pixels
                pixels = get_pixels(self.id_list[i], split)
                for j in range(len(gt_obj_list_2D)):
                    for _ in range(augmentX):
                        # Augment data by box2d perturbation
                        if perturb_box2d:
                            box2d = random_shift_box2d(gt_obj_list_2D[j].box2d)
                        frus_pc, frus_pc_ind = extract_pc_in_box2d(
                            pc_lidar, pixels, box2d)
                        #get frus angle
                        center_box2d = np.array([(box2d[0] + box2d[2]) / 2.0,
                                                 (box2d[1] + box2d[2]) / 2.0])
                        pc_center_frus = get_closest_pc_to_center(
                            pc_lidar, pixels, center_box2d)
                        frustum_angle = -np.arctan2(pc_center_frus[2],
                                                    pc_center_frus[0])
                        #fig = plt.figure()
                        #ax = fig.add_subplot(111, projection="3d")
                        #ax.scatter(frus_pc[:, 0], frus_pc[:, 1], frus_pc[:, 2], c=frus_pc[:, 3:6], s=1)
                        #plt.show()

                        #get label list
                        gt_obj_list = self.dataset_kitti.get_label(
                            self.id_list[i])

                        cls_label = np.zeros((frus_pc.shape[0]),
                                             dtype=np.int32)
                        gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
                        gt_corners = kitti_utils.boxes3d_to_corners3d(
                            gt_boxes3d, transform=False)
                        for k in range(gt_boxes3d.shape[0]):
                            box_corners = gt_corners[k]
                            fg_pt_flag = kitti_utils.in_hull(
                                frus_pc[:, 0:3], box_corners)
                            cls_label[fg_pt_flag] = k + 1
                        max = 0
                        corners_max = 0
                        for k in range(gt_boxes3d.shape[0]):
                            count = np.count_nonzero(cls_label == k + 1)
                            if count > max:
                                max = count
                                corners_max = k
                        seg = np.where(cls_label == corners_max + 1, 1.0, 0.0)

                        cls_label = seg
                        print("train", np.count_nonzero(cls_label == 1))
                        if box2d[3] - box2d[1] < 25 or np.sum(cls_label) == 0:
                            continue
                        self.input_list.append(frus_pc)
                        self.frustum_angle_list.append(frustum_angle)
                        self.label_list.append(cls_label)
                        self.box3d_list.append(gt_corners[corners_max])
                        self.box2d_list.append(box2d)
                        self.type_list.append("Pedestrian")
                        self.heading_list.append(gt_obj_list[corners_max].ry)
                        self.size_list.append(
                            np.array([
                                gt_obj_list[corners_max].h,
                                gt_obj_list[corners_max].w,
                                gt_obj_list[corners_max].l
                            ]))
                        batch_list.append(self.id_list[i])
            #estimate average pc input
            self.id_list = batch_list

            #estimate average labels
        elif (split == 'val' or split == 'test'):

            self.indice_box = []
            self.dataset_kitti.sample_id_list = self.dataset_kitti.sample_id_list[:
                                                                                  32]
            self.id_list = self.dataset_kitti.sample_id_list
            self.idx_batch = self.id_list
            batch_list = []
            self.frustum_angle_list = []
            self.input_list = []
            self.label_list = []
            self.box3d_list = []
            self.box2d_list = []
            self.type_list = []
            self.heading_list = []
            self.size_list = []
            for i in range(len(self.id_list)):
                pc_lidar = self.dataset_kitti.get_lidar(self.id_list[i])
                gt_obj_list = self.dataset_kitti.get_label(self.id_list[i])
                print(self.id_list[i])
                """ps = pc_lidar
                gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
                # gt_boxes3d = gt_boxes3d[self.box_present[index] - 1].reshape(-1, 7)

                cls_label = np.zeros((pc_lidar.shape[0]), dtype=np.int32)
                gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, transform=False)
                for k in range(gt_boxes3d.shape[0]):
                    box_corners = gt_corners[k]
                    fg_pt_flag = kitti_utils.in_hull(pc_lidar[:, 0:3], box_corners)
                    cls_label[fg_pt_flag] = 1

                seg = cls_label
                fig = mlab.figure(figure=None, bgcolor=(0.4, 0.4, 0.4), fgcolor=None, engine=None, size=(1000, 500))
                mlab.points3d(ps[:, 0], ps[:, 1], ps[:, 2], seg, mode='point', colormap='gnuplot', scale_factor=1,
                              figure=fig)
                mlab.points3d(0, 0, 0, color=(1, 1, 1), mode='sphere', scale_factor=0.2, figure=fig)"""
                """for s in range(len(gt_corners)):
                    center = np.array([gt_boxes3d[s][0], gt_boxes3d[s][1], gt_boxes3d[s][2]])
                    size = np.array([gt_boxes3d[s][3], gt_boxes3d[s][4], gt_boxes3d[s][5]])
                    rot_angle = gt_boxes3d[s][6]
                    box3d_from_label = get_3d_box(size,rot_angle,
                                                  center)
                    draw_gt_boxes3d([box3d_from_label], fig, color=(1, 0, 0))
                mlab.orientation_axes()
                raw_input()"""
                #get val 2D boxes:
                box2ds = get_2Dboxes_detected(self.id_list[i], self.res_det,
                                              split)
                if box2ds == None:
                    print("what")
                    continue
                print("number detection", len(box2ds))
                pixels = get_pixels(self.id_list[i], split)
                for j in range(len(box2ds)):
                    box2d = box2ds[j]

                    if (box2d[3] - box2d[1]) < 25 or (
                        (box2d[3] > 720 and box2d[1] > 720)) or (
                            (box2d[0] > 1280 and box2d[2] > 1280)) or (
                                (box2d[3] <= 0
                                 and box2d[1] <= 0)) or (box2d[0] <= 0
                                                         and box2d[2] <= 0):
                        continue
                    print(box2d)
                    print("box_height", box2d[3] - box2d[1])
                    frus_pc, frus_pc_ind = extract_pc_in_box2d(
                        pc_lidar, pixels, box2d)
                    #fig = plt.figure()
                    #ax = fig.add_subplot(111, projection="3d")
                    #ax.scatter(frus_pc[:, 0], frus_pc[:, 1], frus_pc[:, 2], c=frus_pc[:, 3:6], s=1)
                    #plt.show()
                    # get frus angle
                    center_box2d = np.array([(box2d[0] + box2d[2]) / 2.0,
                                             (box2d[1] + box2d[2]) / 2.0])
                    pc_center_frus = get_closest_pc_to_center(
                        pc_lidar, pixels, center_box2d)
                    frustum_angle = -1 * np.arctan2(pc_center_frus[2],
                                                    pc_center_frus[0])

                    if len(frus_pc) < 20:
                        continue

                    # get_labels
                    gt_obj_list = self.dataset_kitti.filtrate_objects(
                        self.dataset_kitti.get_label(self.id_list[i]))
                    gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
                    # gt_boxes3d = gt_boxes3d[self.box_present[index] - 1].reshape(-1, 7)

                    cls_label = np.zeros((frus_pc.shape[0]), dtype=np.int32)
                    gt_corners = kitti_utils.boxes3d_to_corners3d(
                        gt_boxes3d, transform=False)
                    for k in range(gt_boxes3d.shape[0]):
                        box_corners = gt_corners[k]
                        fg_pt_flag = kitti_utils.in_hull(
                            frus_pc[:, 0:3], box_corners)
                        cls_label[fg_pt_flag] = k + 1
                    print("gt in frus", np.count_nonzero(cls_label > 0))
                    if (np.count_nonzero(cls_label > 0) < 20):
                        center = np.ones((3)) * (-10.0)
                        heading = 0.0
                        size = np.ones((3))
                        cls_label[cls_label > 0] = 0
                        seg = cls_label
                        rot_angle = 0.0
                        box3d_center = np.ones((3)) * (-1.0)
                        box3d = np.array([[
                            box3d_center[0], box3d_center[1], box3d_center[2],
                            size[0], size[1], size[2], rot_angle
                        ]])
                        corners_empty = kitti_utils.boxes3d_to_corners3d(
                            box3d, transform=False)
                        bb_corners = corners_empty[0]
                        self.indice_box.append(0)
                    else:
                        max = 0
                        corners_max = 0
                        for k in range(gt_boxes3d.shape[0]):
                            count = np.count_nonzero(cls_label == k + 1)
                            if count > max:
                                max = count
                                corners_max = k
                        seg = np.where(cls_label == corners_max + 1, 1, 0)
                        self.indice_box.append(corners_max + 1)
                        print("val:", np.count_nonzero(cls_label == 1))
                        bb_corners = gt_corners[corners_max]
                        obj = gt_boxes3d[corners_max]
                        center = np.array([obj[0], obj[1], obj[2]])
                        size = np.array([obj[3], obj[4], obj[5]])
                        rot_angle = obj[6]
                    self.input_list.append(frus_pc)
                    self.frustum_angle_list.append(frustum_angle)
                    self.label_list.append(seg)
                    self.box3d_list.append(bb_corners)
                    self.box2d_list.append(box2d)
                    self.type_list.append("Pedestrian")
                    self.heading_list.append(rot_angle)
                    self.size_list.append(size)
                    batch_list.append(self.id_list[i])
            self.id_list = batch_list
            print("batch_list", batch_list)
    def __getitem__(self, index):
        """Build the sample dict for the ``index``-th frame.

        NOTE(review): another ``__getitem__`` is defined later in this file;
        if both belong to the same class, the later one shadows this one.

        The frame's points are resampled to exactly ``self.npoints`` rows:

        * too many points: every point whose depth (column 2) is >= 20.0 is
          kept and the remaining budget is drawn, without replacement, from
          the nearer points. If the far points alone exceed the budget they
          are subsampled to ``self.npoints`` instead.
        * too few points: when fewer than ``self.npoints / 2`` points exist
          the cloud is first zero-padded to that size, then randomly chosen
          existing rows are duplicated until ``self.npoints`` is reached.

        Args:
            index: position in ``self.sample_id_list``.

        Returns:
            dict with ``'sample_id'``, ``'pts_input'`` and ``'pts_rect'``. In
            ``'TEST'`` mode it also holds ``'pts_features'``; otherwise it
            holds ``'cls_labels'`` produced by
            ``self.generate_training_labels``.
        """
        sample_id = int(self.sample_id_list[index])
        pts_lidar = self.get_lidar(sample_id)
        # Points are assumed to already lie inside the image, so no
        # projection-based validity filtering is done here.
        pts_rect = pts_lidar[:, 0:3]
        pts_intensity = pts_lidar[:, 3:]

        if self.npoints < len(pts_rect):
            pts_near_flag = pts_rect[:, 2] < 20.0
            far_idxs_choice = np.where(pts_near_flag == 0)[0]
            near_idxs = np.where(pts_near_flag == 1)[0]
            num_near = self.npoints - len(far_idxs_choice)
            if num_near >= 0:
                # Keep all far points, fill the rest of the budget with
                # randomly chosen near points.
                near_idxs_choice = np.random.choice(near_idxs, num_near,
                                                    replace=False)
                choice = np.concatenate((near_idxs_choice, far_idxs_choice),
                                        axis=0)
            else:
                # More far points than the whole budget: a negative sample
                # size would make np.random.choice raise, so subsample them.
                choice = np.random.choice(far_idxs_choice, self.npoints,
                                          replace=False)
            np.random.shuffle(choice)
        else:
            if (self.npoints / 2) > len(pts_rect):
                # Zero-pad very small clouds up to half the budget.
                diff = int(self.npoints / 2 - len(pts_rect))
                add_pts = np.zeros((diff, 3), dtype=np.float32)
                add_int = np.zeros((diff, pts_intensity.shape[1]),
                                   dtype=np.float32)
                pts_rect = np.concatenate((pts_rect, add_pts), axis=0)
                pts_intensity = np.concatenate((pts_intensity, add_int),
                                               axis=0)
            choice = np.arange(0, len(pts_rect), dtype=np.int32)
            num_extra = self.npoints - len(pts_rect)
            if num_extra > 0:
                # Sampling without replacement is impossible once num_extra
                # exceeds the available rows (e.g. odd npoints after padding).
                extra_choice = np.random.choice(choice, num_extra,
                                                replace=num_extra > len(choice))
                choice = np.concatenate((choice, extra_choice), axis=0)
            np.random.shuffle(choice)

        ret_pts_rect = pts_rect[choice, :]
        # TODO don't use intensity feature or try a method to add rgb
        ret_pts_features = pts_intensity[choice].reshape(-1, 3)

        if USE_INTENSITY:
            pts_input = np.concatenate((ret_pts_rect, ret_pts_features),
                                       axis=1)  # (N, C)
        else:
            pts_input = ret_pts_rect

        sample_info = {'sample_id': sample_id}
        sample_info['pts_input'] = pts_input
        sample_info['pts_rect'] = ret_pts_rect
        if self.mode == 'TEST':
            sample_info['pts_features'] = ret_pts_features
            return sample_info

        gt_obj_list = self.filtrate_objects(self.get_label(sample_id))
        gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
        sample_info['cls_labels'] = self.generate_training_labels(
            ret_pts_rect, gt_boxes3d)
        return sample_info
    def __getitem__(self, index):
        """Return the training/eval tuple for the ``index``-th radar sample.

        Each entry of ``self.id_list`` is paired with ``self.radar_OI[index]``
        (which radar point of the frame to use) and
        ``self.box_present[index]`` (1-based index of the ground-truth box to
        use as the target).

        Steps:
          1. load the frame's radar and lidar points and keep only the lidar
             points selected by ``get_radar_mask`` around the chosen radar
             point;
          2. resample those points to exactly ``self.npoints`` rows: points
             with depth (column 2) >= 20.0 are all kept and near points are
             subsampled; when short, the cloud is zero-padded up to
             ``npoints / 2`` and random rows are duplicated;
          3. mark every point inside the selected 3D box with 1, others 0;
          4. build heading/size targets and apply the optional flip/shift
             augmentation.

        Returns:
            If ``self.from_rgb_detection``:
                ``(point_set, rot_angle, prob[, one_hot_vec])``.
            Otherwise:
                ``(point_set, seg, box3d_center, angle_class, angle_residual,
                size_class, size_residual, rot_angle[, one_hot_vec],
                radar_mask)``; ``one_hot_vec`` is present only when
                ``self.one_hot``.
        """
        sample_id = self.id_list[index]
        # ------------------------------ INPUTS ----------------------------
        input_radar = self.dataset_kitti.get_radar(sample_id)
        pc_lidar = self.dataset_kitti.get_lidar(sample_id)
        # Keep only the lidar points selected by the mask around the chosen
        # radar point of interest.
        roi_mask = get_radar_mask(
            pc_lidar, input_radar[self.radar_OI[index]].reshape(-1, 3))
        pc_lidar = pc_lidar[np.argwhere(roi_mask == 1).reshape(-1)]

        pts_rect = pc_lidar[:, 0:3]
        pts_intensity = pc_lidar[:, 3:]
        if self.npoints < len(pts_rect):
            pts_near_flag = pts_rect[:, 2] < 20.0
            far_idxs_choice = np.where(pts_near_flag == 0)[0]
            near_idxs = np.where(pts_near_flag == 1)[0]
            num_near = self.npoints - len(far_idxs_choice)
            if num_near >= 0:
                near_idxs_choice = np.random.choice(near_idxs, num_near,
                                                    replace=False)
                choice = np.concatenate((near_idxs_choice, far_idxs_choice),
                                        axis=0)
            else:
                # More far points than the whole budget: a negative sample
                # size would make np.random.choice raise, so subsample them.
                choice = np.random.choice(far_idxs_choice, self.npoints,
                                          replace=False)
        else:
            if (self.npoints / 2) > len(pts_rect):
                # Zero-pad very small clouds up to half the budget.
                diff = int(self.npoints / 2 - len(pts_rect))
                add_pts = np.zeros((diff, 3), dtype=np.float32)
                add_int = np.zeros((diff, pts_intensity.shape[1]),
                                   dtype=np.float32)
                pts_rect = np.concatenate((pts_rect, add_pts), axis=0)
                pts_intensity = np.concatenate((pts_intensity, add_int),
                                               axis=0)
            choice = np.arange(0, len(pts_rect), dtype=np.int32)
            num_extra = self.npoints - len(pts_rect)
            if num_extra > 0:
                # Sampling without replacement is impossible once num_extra
                # exceeds the available rows (e.g. odd npoints after padding).
                extra_choice = np.random.choice(choice, num_extra,
                                                replace=num_extra > len(choice))
                choice = np.concatenate((choice, extra_choice), axis=0)
        np.random.shuffle(choice)

        ret_pts_rect = pts_rect[choice, :]
        # TODO don't use intensity feature or try a method to add rgb
        ret_pts_features = pts_intensity[choice].reshape(-1, 3)
        pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1)

        # NOTE(review): unlike roi_mask above, this mask is computed against
        # every radar point of the frame, not only radar_OI -- confirm intended.
        radar_mask = get_radar_mask(pts_input, input_radar)

        # ------------------------------ LABELS ----------------------------
        gt_obj_list = self.dataset_kitti.filtrate_objects(
            self.dataset_kitti.get_label(sample_id))
        gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
        # box_present stores 1-based box indices, hence the -1.
        gt_boxes3d = gt_boxes3d[self.box_present[index] - 1].reshape(-1, 7)
        corners3d = kitti_utils.boxes3d_to_corners3d(gt_boxes3d,
                                                     transform=True)

        cls_label = np.zeros((pts_input.shape[0]), dtype=np.int32)
        for k in range(gt_boxes3d.shape[0]):
            fg_pt_flag = kitti_utils.in_hull(pts_input[:, 0:3], corners3d[k])
            cls_label[fg_pt_flag] = 1

        cls_type = "Pedestrian"
        heading = gt_boxes3d[:, 6]
        size = gt_boxes3d[:, 3:6]
        # Frustum and rotation angles are fixed placeholders (0.0) here.
        frustum_angle = 0.0
        rot_angle = 0.0

        if self.one_hot:
            one_hot_vec = np.zeros((3))
            one_hot_vec[g_type2onehotclass[cls_type]] = 1

        if self.rotate_to_center:
            point_set = self.get_center_view_point_set(pts_input,
                                                       frustum_angle)
        else:
            point_set = pts_input

        if self.from_rgb_detection:
            if self.one_hot:
                return point_set, rot_angle, self.prob_list[index], one_hot_vec
            return point_set, rot_angle, self.prob_list[index]

        seg = cls_label
        if self.rotate_to_center:
            box3d_center = self.get_center_view_box3d_center(
                corners3d, frustum_angle)
            heading_angle = heading - rot_angle
        else:
            box3d_center = self.get_box3d_center(corners3d)
            heading_angle = heading

        size_class, size_residual = size2class(size, cls_type)

        # ---------------------------- AUGMENTATION ------------------------
        if self.random_flip and np.random.random() > 0.5:  # 50% chance
            # note: rot_angle won't be correct if we have random_flip,
            # so do not use it in case of random flipping.
            point_set[:, 0] *= -1
            box3d_center[0] *= -1
            heading_angle = np.pi - heading_angle
        if self.random_shift:
            # NOTE(review): randn()*dist*0.05 is clipped to [0.8*dist,
            # 1.2*dist], so the shift is effectively always 0.8*dist rather
            # than a small random back-and-forth offset -- confirm bounds.
            dist = np.sqrt(np.sum(box3d_center[0]**2 + box3d_center[1]**2))
            shift = np.clip(np.random.randn() * dist * 0.05, dist * 0.8,
                            dist * 1.2)
            point_set[:, 2] += shift
            box3d_center[2] += shift

        angle_class, angle_residual = angle2class(heading_angle,
                                                  NUM_HEADING_BIN)
        if self.one_hot:
            return point_set, seg, box3d_center, angle_class, angle_residual, \
                   size_class, size_residual, rot_angle, one_hot_vec, radar_mask
        return point_set, seg, box3d_center, angle_class, angle_residual, \
               size_class, size_residual, rot_angle, radar_mask
    def __init__(self,
                 npoints,
                 split,
                 random_flip=False,
                 random_shift=False,
                 rotate_to_center=False,
                 overwritten_data_path=None,
                 from_rgb_detection=False,
                 one_hot=False,
                 generate_database=False):
        '''
        Input:
            npoints: int scalar, number of points for frustum point cloud.
            split: string, train or val
            random_flip: bool, in 50% randomly flip the point cloud
                in left and right (after the frustum rotation if any)
            random_shift: bool, if True randomly shift the point cloud
                back and forth by a random distance
            rotate_to_center: bool, whether to do frustum rotation
            overwritten_data_path: string, specify pickled file path.
                if None, use default path (with the split)
            from_rgb_detection: bool, if True we assume we do not have
                groundtruth, just return data elements.
            one_hot: bool, if True, return one hot vector
        '''
        self.dataset_kitti = KittiDataset(
            root_dir='/root/frustum-pointnets_RSC/dataset/',
            mode='TRAIN',
            split=split)

        self.npoints = npoints
        self.random_flip = random_flip
        self.random_shift = random_shift
        self.rotate_to_center = rotate_to_center
        self.one_hot = one_hot
        if overwritten_data_path is None:
            overwritten_data_path = os.path.join(
                ROOT_DIR, 'kitti/frustum_carpedcyc_%s.pickle' % (split))

        self.from_rgb_detection = from_rgb_detection
        if from_rgb_detection:
            with open(overwritten_data_path, 'rb') as fp:
                self.id_list = pickle.load(fp)
                self.box2d_list = pickle.load(fp)
                self.input_list = pickle.load(fp)
                self.type_list = pickle.load(fp)
                # frustum_angle is clockwise angle from positive x-axis
                self.frustum_angle_list = pickle.load(fp)
                self.prob_list = pickle.load(fp)
        else:
            #with open(overwritten_data_path, 'rb') as fp:
            #load list of frames
            self.id_list = self.dataset_kitti.sample_id_list
            print("id_list", len(self.id_list))
            #fil = np.zeros((len(self.id_list)))
            #for i in range(len(self.id_list)):
            #    print(self.id_list[i])
            #    gt_obj_list = self.dataset_kitti.filtrate_objects(self.dataset_kitti.get_label(self.id_list[i]))
            #    print(len(gt_obj_list))
            #gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
            #print(gt_boxes3d)
            #    if(len(gt_obj_list)==1):
            #        fil[i]=1

            #self.id_list= np.extract(fil,self.id_list)
            if (generate_database):
                self.index_batch = []
                self.label_present = []
                self.radar_OI = []
                self.radar_mask_len = []
                self.cls_labels_len = []

                for i in range(len(self.id_list)):
                    pc_input = self.dataset_kitti.get_lidar(self.id_list[i])
                    pc_radar = self.dataset_kitti.get_radar(self.id_list[i])
                    gt_obj_list = self.dataset_kitti.filtrate_objects(
                        self.dataset_kitti.get_label(self.id_list[i]))

                    gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
                    corners3d = kitti_utils.boxes3d_to_corners3d(
                        gt_boxes3d, transform=True)

                    cls_label = np.zeros((pc_input.shape[0]), dtype=np.int32)
                    gt_corners = kitti_utils.boxes3d_to_corners3d(
                        gt_boxes3d, transform=True)
                    for k in range(gt_boxes3d.shape[0]):
                        box_corners = gt_corners[k]
                        fg_pt_flag = kitti_utils.in_hull(
                            pc_input[:, 0:3], box_corners)
                        cls_label[fg_pt_flag] = k + 1
                    print("indice", self.id_list[i])
                    print("number of boxes", gt_boxes3d.shape[0])
                    print("cls_label nunber", np.count_nonzero(cls_label > 0))
                    print("number of radar pts present in frame",
                          len(pc_radar))
                    for j in range(len(pc_radar)):
                        radar_mask = get_radar_mask(pc_input,
                                                    pc_radar[j].reshape(-1, 3))
                        # check label present in radar ROI
                        print("radar_mask", np.count_nonzero(radar_mask == 1))
                        label_radar_intersection = radar_mask * cls_label

                        print("intersection",
                              np.count_nonzero(label_radar_intersection > 0))
                        # if there's one
                        labels_present = []
                        for m in range(gt_boxes3d.shape[0]):
                            if (np.isin(m + 1, label_radar_intersection)):
                                labels_present.append(m + 1)
                        print("labels present", labels_present)
                        if (len(labels_present) == 1):
                            self.radar_mask_len.append(
                                np.count_nonzero(radar_mask == 1))
                            self.index_batch.append(self.id_list[i])
                            self.radar_OI.append(j)
                            self.label_present.append(labels_present[0])
                            self.cls_labels_len.append(
                                np.count_nonzero(label_radar_intersection > 0))
                print("retained indices", self.index_batch)
                print("len retained indices", len(self.index_batch))
                # keep this as a batch
                # if there's isn't
                # than forget about it
                with open('radar_batches_stats_' + split + '.csv',
                          mode='w') as csv_file:
                    fieldnames = [
                        'index_batch', 'radar_mask_len', 'radar_OI',
                        'box_present', 'cls_labels_len'
                    ]
                    writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
                    writer.writeheader()
                    for i in range(len(self.index_batch)):
                        writer.writerow({
                            'index_batch':
                            self.index_batch[i],
                            'radar_mask_len':
                            self.radar_mask_len[i],
                            'radar_OI':
                            self.radar_OI[i],
                            'box_present':
                            self.box_present[i],
                            'cls_labels_len':
                            self.cls_labels_len[i]
                        })
                self.id_list = self.index_batch
                print("id_list_filtered", len(self.id_list))
            else:
                database_infos = pandas.read_csv(
                    '/root/frustum-pointnets_RSC_RADAR_fil_PC_batch/train/radar_batches_stats_'
                    + split + '.csv')
                self.id_list = database_infos['index_batch']
                self.radar_OI = database_infos['radar_OI']
                self.box_present = database_infos['box_present']
            """
if __name__ == '__main__':

    dataset_kitti = KittiDataset(
        "pc_radar_f3_vox",
        root_dir='/home/amben/frustum-pointnets_RSC/dataset/',
        mode='TRAIN',
        split="trainval")
    id_list = dataset_kitti.sample_id_list
    det_obj = []
    present_obj = []
    radar_pts = []
    for i in range(len(id_list)):
        pc_radar = dataset_kitti.get_radar(id_list[i])
        gt_obj_list = dataset_kitti.filtrate_objects(
            dataset_kitti.get_label(id_list[i]))
        gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)

        gt_boxes3d[:, 3] = gt_boxes3d[:, 3] + 2
        gt_boxes3d[:, 4] = gt_boxes3d[:, 4] + 2
        gt_boxes3d[:, 5] = gt_boxes3d[:, 5] + 2
        # gt_boxes3d = gt_boxes3d[self.box_present[index] - 1].reshape(-1, 7)
        gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d,
                                                      transform=True)
        for k in range(gt_boxes3d.shape[0]):
            box_corners = gt_corners[k]
            fg_pt_flag = kitti_utils.in_hull(pc_radar[:, 0:3], box_corners)
            count_radar = np.count_nonzero(fg_pt_flag == True)
            radar_rest_idx = np.argwhere(fg_pt_flag == False)
            pc_radar = pc_radar[radar_rest_idx.reshape(-1)]
            radar_pos_idx = np.argwhere(fg_pt_flag == True)
            #pc_radar_obj = pc_radar(radar_pos_idx)