def load_GT_eval(indice, database, split,
                 root_dir='/root/frustum-pointnets_RSC/dataset/'):
    """Load ground-truth box corners for every sample id below ``indice + 1``.

    Args:
        indice: upper bound on sample ids; ids with ``id < indice + 1`` are
            kept.
        database: forwarded to ``KittiDataset`` as ``dataset``.
        split: forwarded to ``KittiDataset`` as ``split``.
        root_dir: dataset root handed to ``KittiDataset`` (defaults to the
            path that used to be hard-coded).

    Returns:
        ``(corners_frame, id_list_new)``: ``corners_frame[i]`` is the result of
        ``kitti_utils.boxes3d_to_corners3d(..., transform=False)`` for the
        filtered ground-truth objects of sample ``id_list_new[i]``. Ids keep
        the order of ``KittiDataset.sample_id_list``.
    """
    data_val = KittiDataset(root_dir=root_dir,
                            dataset=database,
                            mode='TRAIN',
                            split=split)
    corners_frame = []
    id_list_new = []
    for sample_id in data_val.sample_id_list:
        if sample_id < indice + 1:
            gt_obj_list = data_val.filtrate_objects(
                data_val.get_label(sample_id))
            gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
            # Only the corners are returned; the per-object angle/size class
            # and center bookkeeping the old version computed was never used.
            corners_frame.append(
                kitti_utils.boxes3d_to_corners3d(gt_boxes3d, transform=False))
            id_list_new.append(sample_id)

    return corners_frame, id_list_new
Beispiel #2
0
                                   lr_scheduler, total_it, tb_log, log_f)

        if epoch % args.ckpt_save_interval == 0:
            with torch.no_grad():
                avg_iou = eval_one_epoch(model, eval_loader, epoch, tb_log,
                                         log_f)
                ckpt_name = os.path.join(ckpt_dir,
                                         'checkpoint_epoch_%d' % epoch)
                save_checkpoint(model, epoch, ckpt_name)


if __name__ == '__main__':
    # Script entry point: build the network module and the KITTI data loaders.
    # NOTE(review): `args` is not defined in this fragment; it must provide
    # `net`, `batch_size`, `workers` and `mode` (presumably an argparse
    # namespace -- confirm).
    MODEL = importlib.import_module(args.net)  # import network module
    model = MODEL.get_model(input_channels=0)

    # Evaluation loader: built in every mode, iterated in fixed order
    # (shuffle=False).
    eval_set = KittiDataset(root_dir='./data', mode='EVAL')
    eval_loader = DataLoader(eval_set,
                             batch_size=args.batch_size,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=args.workers,
                             collate_fn=eval_set.collate_batch)

    # Training loader: only built when args.mode == 'train'; shuffled.
    if args.mode == 'train':
        train_set = KittiDataset(root_dir='./data', mode='TRAIN')
        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=args.workers,
                                  collate_fn=train_set.collate_batch)
Beispiel #3
0
    def __init__(self, radar_file, npoints, split,
                 random_flip=False, random_shift=False, rotate_to_center=False,
                 overwritten_data_path=None, from_rgb_detection=False, one_hot=False,
                 all_batches=False,
                 root_dir='/home/amben/frustum-pointnets_RSC/dataset/'):
        '''
        Input:
            radar_file: forwarded as the first positional argument of
                KittiDataset.
            npoints: int scalar, number of points for frustum point cloud.
            split: string, train or val
            random_flip: bool, in 50% randomly flip the point cloud
                in left and right (after the frustum rotation if any)
            random_shift: bool, if True randomly shift the point cloud
                back and forth by a random distance
            rotate_to_center: bool, whether to do frustum rotation
            overwritten_data_path: string, specify pickled file path.
                if None, use default path (with the split)
            from_rgb_detection: bool, if True we assume we do not have
                groundtruth, just return data elements.
            one_hot: bool, if True, return one hot vector
            all_batches: bool, stored as self.all_batches.
            root_dir: string, dataset root handed to KittiDataset (defaults
                to the previously hard-coded path).

        NOTE(review): with from_rgb_detection=False this constructor does not
        build any samples. It only measures how well each radar point's LiDAR
        mask (get_radar_mask) covers the ground-truth boxes, printing
        per-frame and averaged recall/precision, and leaves the per-sample
        lists empty.
        '''
        self.dataset_kitti = KittiDataset(radar_file, root_dir=root_dir,
                                          mode='TRAIN', split=split)
        self.all_batches = all_batches
        self.npoints = npoints
        self.random_flip = random_flip
        self.random_shift = random_shift
        self.rotate_to_center = rotate_to_center
        self.one_hot = one_hot
        self.pc_lidar_list = []
        if overwritten_data_path is None:
            overwritten_data_path = os.path.join(
                ROOT_DIR, 'kitti/frustum_carpedcyc_%s.pickle' % (split))

        self.from_rgb_detection = from_rgb_detection
        if from_rgb_detection:
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # files from a trusted source.
            with open(overwritten_data_path, 'rb') as fp:
                self.id_list = pickle.load(fp)
                self.box2d_list = pickle.load(fp)
                self.input_list = pickle.load(fp)
                self.type_list = pickle.load(fp)
                # frustum_angle is clockwise angle from positive x-axis
                self.frustum_angle_list = pickle.load(fp)
                self.prob_list = pickle.load(fp)
            return

        self.id_list = self.dataset_kitti.sample_id_list
        self.idx_batch = self.id_list
        self.radar_OI = []
        self.batch_size = []
        self.batch_train = []
        # Per-sample containers; nothing in this constructor appends to them
        # on this path.
        self.frustum_angle_list = []  # clockwise angle from positive x-axis
        self.input_list = []          # point cloud of each sample
        self.label_list = []          # per-point label of each sample
        self.box3d_list = []          # 3D box in corners format
        self.type_list = []           # object type of each sample
        self.heading_list = []        # rotation of the 3D label box (ry)
        self.size_list = []           # box size as l, w, h
        self.radar_point_list = []

        self._report_radar_extraction()

    def _report_radar_extraction(self):
        ''' Print per-frame and averaged radar-extraction recall/precision.

        The per-frame sums from _radar_extraction_stats are divided by the
        frame's number of radar points and then averaged over all frames.
        Frames without radar points, and an empty id list, contribute 0
        instead of raising ZeroDivisionError.
        '''
        average_recall = 0.0
        average_precision = 0.0
        average_extracted_pt = 0.0
        average_extracted_pt_per_frame = 0.0
        for sample_id in self.id_list:
            recall, precision, extracted_pt, n_radar = \
                self._radar_extraction_stats(sample_id)
            # All sums are 0 when the frame has no radar points, so dividing
            # by 1 there yields 0.
            n_radar = float(max(n_radar, 1))
            print("recall", recall / n_radar)
            print("precision", precision / n_radar)
            average_recall += recall / n_radar
            average_extracted_pt += extracted_pt / n_radar
            average_extracted_pt_per_frame += extracted_pt
            average_precision += precision / n_radar

        n_frames = float(max(len(self.id_list), 1))
        average_recall /= n_frames
        average_extracted_pt /= n_frames
        average_precision /= n_frames
        average_extracted_pt_per_frame /= n_frames
        print("average_recall", average_recall)
        print("average_precision", average_precision)
        print("average_extracted_pt", average_extracted_pt)
        print("average_extracted_pt_per_frame", average_extracted_pt_per_frame)

    def _radar_extraction_stats(self, sample_id):
        ''' Radar-gated extraction statistics for a single frame.

        Each LiDAR point gets label k + 1 when kitti_utils.in_hull places it in
        ground-truth box k (0 otherwise; later boxes overwrite earlier ones).
        For every radar point, the LiDAR points selected by get_radar_mask are
        matched to the box they hit most often. Recall is best_count divided
        by that box's point count; precision is best_count divided by the
        number of masked points. Radar points whose mask hits no box are
        skipped.

        Returns:
            (recall_sum, precision_sum, extracted_pt, n_radar): recall and
            precision summed over the matched radar points, the number of
            LiDAR points selected by those masks, and the total number of
            radar points in the frame.
        '''
        print("frame nbr", sample_id)
        pc_radar = self.dataset_kitti.get_radar(sample_id)
        pc_lidar = self.dataset_kitti.get_lidar(sample_id)
        cls_label = np.zeros((pc_lidar.shape[0]), dtype=np.int32)
        gt_obj_list = self.dataset_kitti.filtrate_objects(
            self.dataset_kitti.get_label(sample_id))
        gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
        gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, transform=True)
        num_boxes = gt_boxes3d.shape[0]

        for k in range(num_boxes):
            fg_pt_flag = kitti_utils.in_hull(pc_lidar[:, 0:3], gt_corners[k])
            cls_label[fg_pt_flag] = k + 1

        recall = 0.0
        precision = 0.0
        extracted_pt = 0.0
        for radar_point in pc_radar:
            radar_mask = get_radar_mask(pc_lidar, radar_point.reshape(-1, 3))
            radar_label_int = radar_mask * cls_label
            n_hit = np.count_nonzero(radar_label_int > 0)
            print("radar_label_int_nbr", n_hit)
            if n_hit == 0:
                continue

            # Box with the most points inside this mask; the first box wins ties.
            best_count = 0
            corners_max = 0
            for k in range(num_boxes):
                count = np.count_nonzero(radar_label_int == k + 1)
                if count > best_count:
                    best_count = count
                    corners_max = k

            n_masked = np.count_nonzero(radar_mask == 1)
            n_gt = np.count_nonzero(cls_label == corners_max + 1)
            recall_r = best_count / float(n_gt)
            precision_r = best_count / float(n_masked)
            print("radar_mask", n_masked)
            print("label_extracted", best_count)
            print("ground truth", n_gt)
            print(recall_r)
            print("recall_r", recall_r)
            print("precision_r", precision_r)
            extracted_pt += n_masked
            recall += recall_r
            precision += precision_r
        return recall, precision, extracted_pt, len(pc_radar)
            """
Beispiel #4
0
class FrustumDataset(object):
    ''' Dataset class for Frustum PointNets training/evaluation.
    Load prepared KITTI data from pickled files, return individual data element
    [optional] along with its annotations.
    '''

    def __init__(self, radar_file, npoints, split,
                 random_flip=False, random_shift=False, rotate_to_center=False,
                 overwritten_data_path=None, from_rgb_detection=False, one_hot=False,
                 all_batches=False,
                 root_dir='/home/amben/frustum-pointnets_RSC/dataset/'):
        '''
        Input:
            radar_file: forwarded as the first positional argument of
                KittiDataset.
            npoints: int scalar, number of points for frustum point cloud.
            split: string, train or val
            random_flip: bool, in 50% randomly flip the point cloud
                in left and right (after the frustum rotation if any)
            random_shift: bool, if True randomly shift the point cloud
                back and forth by a random distance
            rotate_to_center: bool, whether to do frustum rotation
            overwritten_data_path: string, specify pickled file path.
                if None, use default path (with the split)
            from_rgb_detection: bool, if True we assume we do not have
                groundtruth, just return data elements.
            one_hot: bool, if True, return one hot vector
            all_batches: bool, stored as self.all_batches.
            root_dir: string, dataset root handed to KittiDataset (defaults
                to the previously hard-coded path).

        NOTE(review): with from_rgb_detection=False this constructor does not
        build any samples. It only measures how well each radar point's LiDAR
        mask (get_radar_mask) covers the ground-truth boxes, printing
        per-frame and averaged recall/precision, and leaves the per-sample
        lists empty.
        '''
        self.dataset_kitti = KittiDataset(radar_file, root_dir=root_dir,
                                          mode='TRAIN', split=split)
        self.all_batches = all_batches
        self.npoints = npoints
        self.random_flip = random_flip
        self.random_shift = random_shift
        self.rotate_to_center = rotate_to_center
        self.one_hot = one_hot
        self.pc_lidar_list = []
        if overwritten_data_path is None:
            overwritten_data_path = os.path.join(
                ROOT_DIR, 'kitti/frustum_carpedcyc_%s.pickle' % (split))

        self.from_rgb_detection = from_rgb_detection
        if from_rgb_detection:
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # files from a trusted source.
            with open(overwritten_data_path, 'rb') as fp:
                self.id_list = pickle.load(fp)
                self.box2d_list = pickle.load(fp)
                self.input_list = pickle.load(fp)
                self.type_list = pickle.load(fp)
                # frustum_angle is clockwise angle from positive x-axis
                self.frustum_angle_list = pickle.load(fp)
                self.prob_list = pickle.load(fp)
            return

        self.id_list = self.dataset_kitti.sample_id_list
        self.idx_batch = self.id_list
        self.radar_OI = []
        self.batch_size = []
        self.batch_train = []
        # Per-sample containers; nothing in this constructor appends to them
        # on this path.
        self.frustum_angle_list = []  # clockwise angle from positive x-axis
        self.input_list = []          # point cloud of each sample
        self.label_list = []          # per-point label of each sample
        self.box3d_list = []          # 3D box in corners format
        self.type_list = []           # object type of each sample
        self.heading_list = []        # rotation of the 3D label box (ry)
        self.size_list = []           # box size as l, w, h
        self.radar_point_list = []

        self._report_radar_extraction()

    def _report_radar_extraction(self):
        ''' Print per-frame and averaged radar-extraction recall/precision.

        The per-frame sums from _radar_extraction_stats are divided by the
        frame's number of radar points and then averaged over all frames.
        Frames without radar points, and an empty id list, contribute 0
        instead of raising ZeroDivisionError.
        '''
        average_recall = 0.0
        average_precision = 0.0
        average_extracted_pt = 0.0
        average_extracted_pt_per_frame = 0.0
        for sample_id in self.id_list:
            recall, precision, extracted_pt, n_radar = \
                self._radar_extraction_stats(sample_id)
            # All sums are 0 when the frame has no radar points, so dividing
            # by 1 there yields 0.
            n_radar = float(max(n_radar, 1))
            print("recall", recall / n_radar)
            print("precision", precision / n_radar)
            average_recall += recall / n_radar
            average_extracted_pt += extracted_pt / n_radar
            average_extracted_pt_per_frame += extracted_pt
            average_precision += precision / n_radar

        n_frames = float(max(len(self.id_list), 1))
        average_recall /= n_frames
        average_extracted_pt /= n_frames
        average_precision /= n_frames
        average_extracted_pt_per_frame /= n_frames
        print("average_recall", average_recall)
        print("average_precision", average_precision)
        print("average_extracted_pt", average_extracted_pt)
        print("average_extracted_pt_per_frame", average_extracted_pt_per_frame)

    def _radar_extraction_stats(self, sample_id):
        ''' Radar-gated extraction statistics for a single frame.

        Each LiDAR point gets label k + 1 when kitti_utils.in_hull places it in
        ground-truth box k (0 otherwise; later boxes overwrite earlier ones).
        For every radar point, the LiDAR points selected by get_radar_mask are
        matched to the box they hit most often. Recall is best_count divided
        by that box's point count; precision is best_count divided by the
        number of masked points. Radar points whose mask hits no box are
        skipped.

        Returns:
            (recall_sum, precision_sum, extracted_pt, n_radar): recall and
            precision summed over the matched radar points, the number of
            LiDAR points selected by those masks, and the total number of
            radar points in the frame.
        '''
        print("frame nbr", sample_id)
        pc_radar = self.dataset_kitti.get_radar(sample_id)
        pc_lidar = self.dataset_kitti.get_lidar(sample_id)
        cls_label = np.zeros((pc_lidar.shape[0]), dtype=np.int32)
        gt_obj_list = self.dataset_kitti.filtrate_objects(
            self.dataset_kitti.get_label(sample_id))
        gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
        gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, transform=True)
        num_boxes = gt_boxes3d.shape[0]

        for k in range(num_boxes):
            fg_pt_flag = kitti_utils.in_hull(pc_lidar[:, 0:3], gt_corners[k])
            cls_label[fg_pt_flag] = k + 1

        recall = 0.0
        precision = 0.0
        extracted_pt = 0.0
        for radar_point in pc_radar:
            radar_mask = get_radar_mask(pc_lidar, radar_point.reshape(-1, 3))
            radar_label_int = radar_mask * cls_label
            n_hit = np.count_nonzero(radar_label_int > 0)
            print("radar_label_int_nbr", n_hit)
            if n_hit == 0:
                continue

            # Box with the most points inside this mask; the first box wins ties.
            best_count = 0
            corners_max = 0
            for k in range(num_boxes):
                count = np.count_nonzero(radar_label_int == k + 1)
                if count > best_count:
                    best_count = count
                    corners_max = k

            n_masked = np.count_nonzero(radar_mask == 1)
            n_gt = np.count_nonzero(cls_label == corners_max + 1)
            recall_r = best_count / float(n_gt)
            precision_r = best_count / float(n_masked)
            print("radar_mask", n_masked)
            print("label_extracted", best_count)
            print("ground truth", n_gt)
            print(recall_r)
            print("recall_r", recall_r)
            print("precision_r", precision_r)
            extracted_pt += n_masked
            recall += recall_r
            precision += precision_r
        return recall, precision, extracted_pt, len(pc_radar)


    def __len__(self):
        return len(self.id_list)

    def __getitem__(self, index):
        ''' Get index-th element from the picked file dataset. '''
        # ------------------------------ INPUTS ----------------------------

        label_mask = self.batch_train[index]
        rot_angle = self.get_center_view_rot_angle(index)

        # Compute one hot vector
        if self.one_hot:
            cls_type = self.type_list[index]
            assert (cls_type in ['Car', 'Pedestrian', 'Cyclist'])
            one_hot_vec = np.zeros((3))
            one_hot_vec[g_type2onehotclass[cls_type]] = 1

        # Get point cloud
        if self.rotate_to_center:
            point_set = self.get_center_view_point_set(index)
        else:
            point_set = self.input_list[index]
        # Resample
        choice = np.random.choice(point_set.shape[0], self.npoints, replace=True)
        point_set = point_set[choice, :]

        if self.from_rgb_detection:
            if self.one_hot:
                return point_set, rot_angle, self.prob_list[index], one_hot_vec
            else:
                return point_set, rot_angle, self.prob_list[index]

        # ------------------------------ LABELS ----------------------------
        seg = self.label_list[index]
        seg = seg[choice]

        # Get center point of 3D box
        if self.rotate_to_center:
            box3d_center = self.get_center_view_box3d_center(index)
        else:
            box3d_center = self.get_box3d_center(index)

        # Heading
        if self.rotate_to_center:
            heading_angle = self.heading_list[index] - rot_angle
        else:
            heading_angle = self.heading_list[index]

        # Size
        size_class, size_residual = size2class(self.size_list[index],
                                               self.type_list[index])

        # Data Augmentation
        if self.random_flip:
            # note: rot_angle won't be correct if we have random_flip
            # so do not use it in case of random flipping.
            if np.random.random() > 0.5:  # 50% chance flipping
                point_set[:, 0] *= -1
                box3d_center[0] *= -1
                heading_angle = np.pi - heading_angle
        if self.random_shift:
            dist = np.sqrt(np.sum(box3d_center[0] ** 2 + box3d_center[1] ** 2))
            shift = np.clip(np.random.randn() * dist * 0.05, dist * 0.8, dist * 1.2)
            point_set[:, 2] += shift
            box3d_center[2] += shift

        angle_class, angle_residual = angle2class(heading_angle,
                                                  NUM_HEADING_BIN)

        if self.one_hot:
            return point_set, seg, box3d_center, angle_class, angle_residual, \
                   size_class, size_residual, rot_angle, one_hot_vec,label_mask
        else:
            return point_set, seg, box3d_center, angle_class, angle_residual, \
                   size_class, size_residual, rot_angle,label_mask
    """
    def __getitem__(self, index):
        ''' Get index-th element from the picked file dataset. '''
        # ------------------------------ INPUTS ----------------------------
        #input_radar = self.dataset_kitti.get_radar(self.id_list[index])
        #input = self.dataset_kitti.get_lidar(self.id_list[index])
        #radar_mask = get_radar_mask(input, input_radar[self.radar_OI[index]].reshape(-1, 3))
        #num_point_fil = np.count_nonzero(radar_mask == 1)
        #radar_idx =np.argwhere(radar_mask==1)
        input = self.pc_lidar_list[index]

        pts_rect = input[:, 0:3]
        pts_intensity = input[:, 3:]
        if self.npoints < len(pts_rect):
            pts_depth = pts_rect[:, 2]
            pts_near_flag = pts_depth < 20.0
            far_idxs_choice = np.where(pts_near_flag == 0)[0]
            near_idxs = np.where(pts_near_flag == 1)[0]
            near_idxs_choice = np.random.choice(near_idxs, self.npoints - len(far_idxs_choice), replace=False)

            choice = np.concatenate((near_idxs_choice, far_idxs_choice), axis=0) \
                if len(far_idxs_choice) > 0 else near_idxs_choice
            np.random.shuffle(choice)
        else:
            if (self.npoints / 2) > len(pts_rect):
                diff = int(self.npoints / 2 - len(pts_rect))
                add_pts = np.zeros((diff, 3), dtype=np.float32)
                add_int = np.zeros((diff, 3), dtype=np.float32)
                pts_rect = np.concatenate((pts_rect, add_pts), axis=0)
                pts_intensity = np.concatenate((pts_intensity, add_int), axis=0)
            choice = np.arange(0, len(pts_rect), dtype=np.int32)
            if self.npoints > len(pts_rect):
                extra_choice = np.random.choice(choice, self.npoints - len(pts_rect), replace=False)
                choice = np.concatenate((choice, extra_choice), axis=0)
        np.random.shuffle(choice)
        ret_pts_rect = pts_rect[choice, :]
        # TODO don't use intensity feature or try a method to add rgb
        ret_pts_intensity = pts_intensity[choice]
        pts_features = [ret_pts_intensity.reshape(-1, 3)]
        ret_pts_features = np.concatenate(pts_features, axis=1) if pts_features.__len__() > 1 else pts_features[0]
        ret_pts_features = np.ones((len(ret_pts_rect),1))
        input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1)

        #radar_mask = get_radar_mask(input,input_radar)
        point_set=input
        type = "Pedestrian"
        if self.one_hot:
            cls_type = type
            assert (cls_type in ['Car', 'Pedestrian', 'Cyclist'])
            one_hot_vec = np.zeros((3))
            one_hot_vec[g_type2onehotclass[cls_type]] = 1

        gt_obj_list = self.dataset_kitti.filtrate_objects(self.dataset_kitti.get_label(self.id_list[index]))
        gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
        #gt_boxes3d = gt_boxes3d[self.box_present[index] - 1].reshape(-1, 7)

        cls_label = np.zeros((input.shape[0]), dtype=np.int32)
        gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, transform=True)
        for k in range(gt_boxes3d.shape[0]):
            box_corners = gt_corners[k]
            fg_pt_flag = kitti_utils.in_hull(input[:, 0:3], box_corners)
            cls_label[fg_pt_flag] = k+1

        if(np.count_nonzero(cls_label>0) < 50):
            center = np.zeros((3))
            heading = 0.0
            size = np.zeros((3))
            frustum_angle = 0.0
            seg = cls_label[cls_label > 0] = 0
            #print("label_0:", np.count_nonzero(seg>0))
            rot_angle = 0.0
            #print("cls_label: ", np.count_nonzero(cls_label > 0))
            #seg = cls_label[cls_label > 0]= 1
            #print("seg: ",np.count_nonzero(cls_label == 0))
            box3d_center = np.zeros((3))

            if self.rotate_to_center:
                heading_angle = heading - rot_angle
            else:
                heading_angle = heading
            angle_class, angle_residual = angle2class(heading_angle,
                                                      NUM_HEADING_BIN)
            size_class, size_residual = size2class(size,
                                                   type)
            label_mask = 0.0
            size_residual = np.zeros((3))
            #print(" box3d_center, angle_class, angle_residual, size_class, size_residual, rot_angle, one_hot_vec, label_mask", box3d_center, angle_class, angle_residual, \
            #size_class, size_residual, rot_angle, one_hot_vec, label_mask)
            #print("noo zero point cloud: ", np.count_nonzero(input > 0.0))
        else:
            max = 0
            corners_max = 0
            for k in range(gt_boxes3d.shape[0]):
                count = np.count_nonzero(cls_label == k+1)
                if count > max:
                    max = count
                    corners_max = k
            # give the indice box to x generation
            center = gt_boxes3d[corners_max, 0:3]
            heading = gt_boxes3d[corners_max, 6]
            size = gt_boxes3d[corners_max, 3:6]
            # frustum angle = 0.0
            frustum_angle = 0.0

            if self.rotate_to_center:
                point_set = self.get_center_view_point_set(input, frustum_angle)
            else:
                point_set = input
            # ------------------------------ LABELS ----------------------------
            #print("cls_label: ", np.count_nonzero(cls_label > 0))
            #seg = cls_label[cls_label > 0]= 1
            seg = np.where(cls_label>0,1,0)
            #print("seg: ",np.count_nonzero(cls_label == 1) )
            if self.rotate_to_center:
                box3d_center = self.get_center_view_box3d_center(gt_corners[corners_max], frustum_angle)
            else:
                box3d_center = self.get_box3d_center(gt_corners[corners_max])

            rot_angle = 0.0
            if self.rotate_to_center:
                heading_angle = heading - rot_angle
            else:
                heading_angle = heading

            angle_class, angle_residual = angle2class(heading_angle,
                                                      NUM_HEADING_BIN)

            size_class, size_residual = size2class(size,
                                                   type)

            label_mask = 1.0
            #print(" box3d_center, angle_class, angle_residual,size_class, size_residual, rot_angle, one_hot_vec, label_mask", box3d_center, angle_class, angle_residual, \
            #size_class, size_residual, rot_angle, one_hot_vec, label_mask)
            #print("noo zero point cloud: ", np.count_nonzero(input > 0.0))
        return point_set, seg, box3d_center, angle_class, angle_residual, \
                   size_class, size_residual, rot_angle, one_hot_vec, label_mask


        '''
        gt_obj_list = self.dataset_kitti.filtrate_objects(self.dataset_kitti.get_label(self.id_list[index]))
        gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
        gt_boxes3d= gt_boxes3d[self.box_present[index]-1].reshape(-1,7)
        corners3d = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, transform=True)

        cls_label = np.zeros((input.shape[0]), dtype=np.int32)
        gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, transform=True)
        for k in range(gt_boxes3d.shape[0]):
            box_corners = gt_corners[k]
            fg_pt_flag = kitti_utils.in_hull(input[:, 0:3], box_corners)
            cls_label[fg_pt_flag] = 1

        type="Pedestrian"
        center = gt_boxes3d[:,0:3]
        #closest_radar_point = get_closest_radar_point(center,input_radar)
        heading = gt_boxes3d[:, 6]

        size=gt_boxes3d[:, 3:6]
        # frustum_angle with 0.0 populate
        frustum_angle=0.0


        rot_angle=0.0
        # Compute one hot vector
        if self.one_hot:
            cls_type = type
            assert (cls_type in ['Car', 'Pedestrian', 'Cyclist'])
            one_hot_vec = np.zeros((3))
            one_hot_vec[g_type2onehotclass[cls_type]] = 1

        # Get point cloud
        if self.rotate_to_center:
            point_set = self.get_center_view_point_set(input,frustum_angle)
        else:
            point_set = input
        # Resample

        #print(point_set.shape[0],self.npoints)
        #choice = np.random.choice(point_set.shape[0], self.npoints, replace=True)
        #print(len(choice))
        #point_set = point_set[choice, :]

        if self.from_rgb_detection:
            if self.one_hot:
                return point_set, rot_angle, self.prob_list[index], one_hot_vec
            else:
                return point_set, rot_angle, self.prob_list[index]

        # ------------------------------ LABELS ----------------------------
        seg = cls_label
        #seg = seg[choice]
        #print("batch seg 3asba:", np.count_nonzero(seg == 1))

        # Get center point of 3D box
        if self.rotate_to_center:
            box3d_center = self.get_center_view_box3d_center(corners3d,frustum_angle)
        else:
            box3d_center = self.get_box3d_center(corners3d)

        # Heading
        if self.rotate_to_center:
            heading_angle = heading - rot_angle
        else:
            heading_angle = heading

        # Size
        size_class, size_residual = size2class(size,
                                               type)

        # Data Augmentation
        if self.random_flip:
            # note: rot_angle won't be correct if we have random_flip
            # so do not use it in case of random flipping.
            if np.random.random() > 0.5:  # 50% chance flipping
                point_set[:, 0] *= -1
                box3d_center[0] *= -1
                heading_angle = np.pi - heading_angle
        if self.random_shift:
            dist = np.sqrt(np.sum(box3d_center[0] ** 2 + box3d_center[1] ** 2))
            shift = np.clip(np.random.randn() * dist * 0.05, dist * 0.8, dist * 1.2)
            point_set[:, 2] += shift
            box3d_center[2] += shift
        print(heading_angle)
        angle_class, angle_residual = angle2class(heading_angle,
                                                  NUM_HEADING_BIN)
        #print(angle_class,angle_residual)
        
        rot_angle=0.0
        if self.one_hot:
            return point_set,one_hot_vec,rot_angle
        else:
            return point_set,rot_angle
        '''
    """
    def get_center_view_rot_angle(self, index):
        ''' Return the frustum rotation angle of sample `index`.

        The stored frustum angle is shifted by pi/2 so that the result can
        be used directly to adjust the ground-truth heading angle.
        '''
        frustum_angle = self.frustum_angle_list[index]
        return np.pi / 2.0 + frustum_angle

    def get_box3d_center(self, index):
        ''' Return the XYZ center of the 3D box stored for sample `index`.

        The center is taken as the midpoint of box corners 0 and 6.
        '''
        corners = self.box3d_list[index]
        first_corner = corners[0, :]
        opposite_corner = corners[6, :]
        return (first_corner + opposite_corner) / 2.0

    def get_center_view_box3d_center(self, index):
        ''' Return the 3D box center of sample `index` after the frustum
        rotation (about the y axis, by get_center_view_rot_angle).

        The center is the midpoint of box corners 0 and 6; it is lifted to a
        1x3 point set for rotate_pc_along_y and squeezed back afterwards.
        '''
        corners = self.box3d_list[index]
        center = (corners[0, :] + corners[6, :]) / 2.0
        center_as_points = center[np.newaxis, :]
        angle = self.get_center_view_rot_angle(index)
        rotated = rotate_pc_along_y(center_as_points, angle)
        return rotated.squeeze()

    def get_center_view_box3d(self, index):
        ''' Return the 3D box corners of sample `index` after the frustum
        rotation (about the y axis, by get_center_view_rot_angle).

        The stored corners are copied first so the rotation can never
        modify the cached box.
        '''
        corners_copy = np.copy(self.box3d_list[index])
        angle = self.get_center_view_rot_angle(index)
        return rotate_pc_along_y(corners_copy, angle)

    def get_center_view_point_set(self, index):
        ''' Return the point cloud of sample `index` after the frustum
        rotation (about the y axis, by get_center_view_rot_angle).

        Points are NxC with XYZ in the first 3 channels
        (z is facing forward, x is leftward, y is downward).
        '''
        # Work on a copy so the cached input cloud is not corrupted.
        points_copy = np.copy(self.input_list[index])
        angle = self.get_center_view_rot_angle(index)
        return rotate_pc_along_y(points_copy, angle)
    def __init__(self,
                 npoints,
                 database,
                 split,
                 res,
                 random_flip=False,
                 random_shift=False,
                 rotate_to_center=False,
                 overwritten_data_path=None,
                 from_rgb_detection=False,
                 one_hot=False,
                 max_samples=32):
        '''
        Build the list of frustum samples for one dataset split.

        Input:
            npoints: int scalar, number of points for frustum point cloud.
            database: forwarded to KittiDataset as its `dataset` argument.
            split: string, 'train', 'val' or 'test'. Any other value (without
                from_rgb_detection) leaves the per-sample lists unset.
            res: forwarded to get_2Dboxes_detected to load the 2D detections
                used for the 'val'/'test' splits.
            random_flip: bool, in 50% randomly flip the point cloud
                in left and right (after the frustum rotation if any)
            random_shift: bool, if True randomly shift the point cloud
                back and forth by a random distance
            rotate_to_center: bool, whether to do frustum rotation
            overwritten_data_path: string, pickled file path read when
                from_rgb_detection is True. If None, use the default path
                (with the split).
            from_rgb_detection: bool, if True we assume we do not have
                groundtruth and load pre-extracted frustums from the pickle.
            one_hot: bool, if True, return one hot vector
            max_samples: int or None, only the first `max_samples` frame ids
                of the split are used for 'train'/'val'/'test'. None uses all
                frames. Defaults to 32, the previously hard-coded limit.
        '''
        self.dataset_kitti = KittiDataset(
            root_dir='/root/frustum-pointnets_RSC/dataset/',
            dataset=database,
            mode='TRAIN',
            split=split)
        self.npoints = npoints
        self.random_flip = random_flip
        self.random_shift = random_shift
        self.rotate_to_center = rotate_to_center
        self.res_det = res
        self.one_hot = one_hot

        if overwritten_data_path is None:
            overwritten_data_path = os.path.join(
                ROOT_DIR, 'kitti/frustum_carpedcyc_%s.pickle' % (split))

        self.from_rgb_detection = from_rgb_detection
        if from_rgb_detection:
            # NOTE(review): pickle.load can execute arbitrary code; only use
            # trusted, locally generated files here.
            with open(overwritten_data_path, 'rb') as fp:
                self.id_list = pickle.load(fp)
                self.box2d_list = pickle.load(fp)
                self.input_list = pickle.load(fp)
                self.type_list = pickle.load(fp)
                # frustum_angle is clockwise angle from positive x-axis
                self.frustum_angle_list = pickle.load(fp)
                self.prob_list = pickle.load(fp)
        elif split == 'train':
            self._build_train_frustums(split, max_samples)
        elif split == 'val' or split == 'test':
            self._build_eval_frustums(split, max_samples)

    def _reset_frustum_lists(self):
        ''' Create the empty per-frustum lists filled by the build helpers. '''
        self.frustum_angle_list = []
        self.input_list = []
        self.label_list = []
        self.box3d_list = []
        self.box2d_list = []
        self.type_list = []
        self.heading_list = []
        self.size_list = []

    @staticmethod
    def _label_points_by_box(points, gt_corners):
        ''' Label each point with k + 1 if it lies in GT box k, else 0.

        Boxes are tested in order, so on overlap the later box wins.
        Returns an int32 array with one entry per row of `points`.
        '''
        cls_label = np.zeros((points.shape[0]), dtype=np.int32)
        for k in range(len(gt_corners)):
            fg_pt_flag = kitti_utils.in_hull(points[:, 0:3], gt_corners[k])
            cls_label[fg_pt_flag] = k + 1
        return cls_label

    @staticmethod
    def _dominant_box(cls_label, num_boxes):
        ''' Return the 0-based index of the box holding the most points.

        Ties go to the lowest index; 0 is returned when no box holds any
        point (or there are no boxes).
        '''
        if num_boxes == 0:
            return 0
        counts = np.bincount(cls_label,
                             minlength=num_boxes + 1)[1:num_boxes + 1]
        return int(np.argmax(counts))

    @staticmethod
    def _frustum_angle(pc_lidar, pixels, box2d):
        ''' Frustum angle for a 2D box given as (x1, y1, x2, y2).

        Uses the point returned by get_closest_pc_to_center for the box
        center (presumably the point projecting nearest to it) and returns
        -atan2(z, x) of that point.
        '''
        # Center is ((x1 + x2) / 2, (y1 + y2) / 2); the y term previously
        # used x2 by mistake.
        center_box2d = np.array([(box2d[0] + box2d[2]) / 2.0,
                                 (box2d[1] + box2d[3]) / 2.0])
        pc_center_frus = get_closest_pc_to_center(pc_lidar, pixels,
                                                  center_box2d)
        return -np.arctan2(pc_center_frus[2], pc_center_frus[0])

    def _build_train_frustums(self, split, max_samples):
        ''' Fill the per-frustum lists for the 'train' split.

        Every labelled 2D box is perturbed `num_augment` times. Each
        perturbed box yields a frustum whose points are labelled against the
        3D GT boxes. A frustum is kept only if its 2D box is at least 25 px
        tall and it contains points of at least one GT box.
        '''
        sample_ids = self.dataset_kitti.sample_id_list
        if max_samples is not None:
            sample_ids = sample_ids[:max_samples]
        self.id_list = sample_ids
        self.idx_batch = self.id_list
        self._reset_frustum_lists()
        batch_list = []

        perturb_box2d = True
        num_augment = 5
        for sample_id in self.id_list:
            print(sample_id)
            pc_lidar = self.dataset_kitti.get_lidar(sample_id)
            gt_obj_list_2D = self.dataset_kitti.get_label_2D(sample_id)
            pixels = get_pixels(sample_id, split)
            # The 3D labels depend only on the frame, so load them once here
            # instead of once per perturbed 2D box.
            # NOTE(review): unlike the val/test path these objects are not
            # passed through filtrate_objects, yet every kept frustum is
            # typed "Pedestrian" -- confirm this is intended.
            gt_obj_list = self.dataset_kitti.get_label(sample_id)
            gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
            gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d,
                                                          transform=False)
            for obj_2d in gt_obj_list_2D:
                for _ in range(num_augment):
                    # Augment data by box2d perturbation.
                    if perturb_box2d:
                        box2d = random_shift_box2d(obj_2d.box2d)
                    else:
                        box2d = obj_2d.box2d
                    frus_pc = extract_pc_in_box2d(pc_lidar, pixels, box2d)[0]
                    frustum_angle = self._frustum_angle(pc_lidar, pixels,
                                                        box2d)

                    cls_label = self._label_points_by_box(frus_pc, gt_corners)
                    corners_max = self._dominant_box(cls_label,
                                                     gt_boxes3d.shape[0])
                    # Binary float mask of the points inside the dominant box.
                    cls_label = np.where(cls_label == corners_max + 1, 1.0,
                                         0.0)
                    print("train", np.count_nonzero(cls_label == 1))
                    if box2d[3] - box2d[1] < 25 or np.sum(cls_label) == 0:
                        continue

                    gt_obj = gt_obj_list[corners_max]
                    self.input_list.append(frus_pc)
                    self.frustum_angle_list.append(frustum_angle)
                    self.label_list.append(cls_label)
                    # Copy: gt_corners is shared by every perturbation of this
                    # frame; keep each sample's box independent.
                    self.box3d_list.append(np.copy(gt_corners[corners_max]))
                    self.box2d_list.append(box2d)
                    self.type_list.append("Pedestrian")
                    self.heading_list.append(gt_obj.ry)
                    self.size_list.append(
                        np.array([gt_obj.h, gt_obj.w, gt_obj.l]))
                    batch_list.append(sample_id)
        self.id_list = batch_list

    def _build_eval_frustums(self, split, max_samples):
        ''' Fill the per-frustum lists for the 'val'/'test' splits.

        Frustums come from the 2D detections returned by
        get_2Dboxes_detected. Boxes shorter than 25 px or lying entirely
        outside the image are skipped, as are frustums with fewer than 20
        points. A frustum with fewer than 20 GT points gets a placeholder box
        and an all-zero mask and records 0 in self.indice_box; otherwise it
        records the 1-based index of the GT box owning most of its points.
        '''
        self.indice_box = []
        if max_samples is not None:
            self.dataset_kitti.sample_id_list = \
                self.dataset_kitti.sample_id_list[:max_samples]
        self.id_list = self.dataset_kitti.sample_id_list
        self.idx_batch = self.id_list
        self._reset_frustum_lists()
        batch_list = []
        for sample_id in self.id_list:
            pc_lidar = self.dataset_kitti.get_lidar(sample_id)
            print(sample_id)
            box2ds = get_2Dboxes_detected(sample_id, self.res_det, split)
            # `is None`: `== None` is elementwise (and ambiguous in `if`)
            # when the detections are a numpy array.
            if box2ds is None:
                print("what")
                continue
            print("number detection", len(box2ds))
            pixels = get_pixels(sample_id, split)
            # GT boxes depend only on the frame: load them once, not per box.
            gt_obj_list = self.dataset_kitti.filtrate_objects(
                self.dataset_kitti.get_label(sample_id))
            gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
            gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d,
                                                          transform=False)
            for box2d in box2ds:
                too_short = (box2d[3] - box2d[1]) < 25
                # Entirely outside the image on one side (bounds presumably
                # a 1280x720 image -- confirm against the camera data).
                outside = ((box2d[3] > 720 and box2d[1] > 720)
                           or (box2d[0] > 1280 and box2d[2] > 1280)
                           or (box2d[3] <= 0 and box2d[1] <= 0)
                           or (box2d[0] <= 0 and box2d[2] <= 0))
                if too_short or outside:
                    continue
                print(box2d)
                print("box_height", box2d[3] - box2d[1])
                frus_pc = extract_pc_in_box2d(pc_lidar, pixels, box2d)[0]
                frustum_angle = self._frustum_angle(pc_lidar, pixels, box2d)
                if len(frus_pc) < 20:
                    continue

                cls_label = self._label_points_by_box(frus_pc, gt_corners)
                print("gt in frus", np.count_nonzero(cls_label > 0))
                if np.count_nonzero(cls_label > 0) < 20:
                    # Too few GT points: negative sample with an all-zero mask
                    # and a unit placeholder box centred at (-1, -1, -1).
                    seg = np.zeros_like(cls_label)
                    size = np.ones((3))
                    rot_angle = 0.0
                    placeholder = np.array([[-1.0, -1.0, -1.0,
                                             size[0], size[1], size[2],
                                             rot_angle]])
                    bb_corners = kitti_utils.boxes3d_to_corners3d(
                        placeholder, transform=False)[0]
                    self.indice_box.append(0)
                else:
                    corners_max = self._dominant_box(cls_label,
                                                     gt_boxes3d.shape[0])
                    seg = np.where(cls_label == corners_max + 1, 1, 0)
                    self.indice_box.append(corners_max + 1)
                    # NOTE(review): this counts points of GT box 1, not of the
                    # selected box.
                    print("val:", np.count_nonzero(cls_label == 1))
                    # Copy: gt_corners is shared by all detections of the frame.
                    bb_corners = np.copy(gt_corners[corners_max])
                    obj = gt_boxes3d[corners_max]
                    size = np.array([obj[3], obj[4], obj[5]])
                    rot_angle = obj[6]
                self.input_list.append(frus_pc)
                self.frustum_angle_list.append(frustum_angle)
                self.label_list.append(seg)
                self.box3d_list.append(bb_corners)
                self.box2d_list.append(box2d)
                self.type_list.append("Pedestrian")
                self.heading_list.append(rot_angle)
                self.size_list.append(size)
                batch_list.append(sample_id)
        self.id_list = batch_list
        print("batch_list", batch_list)
Beispiel #6
0
                        augmented_lidar_cam_coords)
                    loss = criterion(predicted_locs, predicted_scores, boxes,
                                     classes)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * bat_size
    return model  #, val_acc_history


# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ssd = SSD(resnet_type=34, n_classes=2).to(device)

# Both splits are read from the "training" folder; the `valid` flag picks
# which portion each dataset exposes (presumably a held-out validation
# subset -- confirm in KittiDataset).
trainset = KittiDataset(root="/hdd/KITTY/", mode="training", valid=False)
valset = KittiDataset(root="/hdd/KITTY/", mode="training", valid=True)

datasets = dict(train=trainset, val=valset)
dataloaders_dict = {
    phase_name: DataLoader(phase_set,
                           batch_size=4,
                           shuffle=True,
                           collate_fn=phase_set.collate_fn,
                           num_workers=0,
                           drop_last=True)
    for phase_name, phase_set in datasets.items()
}

# SGD with momentum; the loss is built from the model's `priors_cxcy`.
optimizer_ft = torch.optim.SGD(ssd.parameters(), lr=0.0001, momentum=0.9)
criterion = MultiBoxLoss(priors_cxcy=ssd.priors_cxcy).to(device)
                                   lr_scheduler, total_it, tb_log, log_f)

        if epoch % args.ckpt_save_interval == 0:
            with torch.no_grad():
                avg_iou = eval_one_epoch(model, eval_loader, epoch, tb_log,
                                         log_f)
                ckpt_name = os.path.join(ckpt_dir,
                                         'checkpoint_epoch_%d' % epoch)
                save_checkpoint(model, epoch, ckpt_name)


if __name__ == '__main__':
    MODEL = importlib.import_module(args.net)  # import network module
    model = MODEL.get_model(input_channels=0)

    eval_set = KittiDataset(root_dir='./data', mode='EVAL', split='val')
    eval_loader = DataLoader(eval_set,
                             batch_size=args.batch_size,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=args.workers,
                             collate_fn=eval_set.collate_batch)

    if args.mode == 'train':
        train_set = KittiDataset(root_dir='./data',
                                 mode='TRAIN',
                                 split='train')
        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
Beispiel #8
0
                        augmented_lidar_cam_coords)
                    loss = criterion(predicted_locs, predicted_scores, boxes,
                                     classes)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * bat_size
    return model  #, val_acc_history


# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ssd = SSD(resnet_type=34, n_classes=2).to(device)

# Both splits are read from the "training" folder; the `valid` flag picks
# which portion each dataset exposes (presumably a held-out validation
# subset -- confirm in KittiDataset).
trainset = KittiDataset(root="/repository/KITTI", mode="training", valid=False)
valset = KittiDataset(root="/repository/KITTI", mode="training", valid=True)

datasets = dict(train=trainset, val=valset)
dataloaders_dict = {
    phase_name: DataLoader(phase_set,
                           batch_size=4,
                           shuffle=True,
                           collate_fn=phase_set.collate_fn,
                           num_workers=0,
                           drop_last=True)
    for phase_name, phase_set in datasets.items()
}

# SGD with momentum; the loss is built from the model's `priors_cxcy`.
optimizer_ft = torch.optim.SGD(ssd.parameters(), lr=0.0001, momentum=0.9)
criterion = MultiBoxLoss(priors_cxcy=ssd.priors_cxcy).to(device)
Beispiel #9
0
                batch_num += 1

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    predicted_locs, predicted_scores, _ = model(augmented_lidar_cam_coords)
                    loss = criterion(predicted_locs, predicted_scores, boxes, classes)
                    
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        
                running_loss += loss.item() * bat_size
    return model #, val_acc_history
    
# Use the first GPU when present and fall back to the CPU otherwise, so the
# script no longer crashes at `.to(device)` on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ssd = SSD(resnet_type=34, n_classes=2).to(device)

# Both splits are read from the "training" folder; the `valid` flag picks
# which portion each dataset exposes (presumably a held-out validation
# subset -- confirm in KittiDataset).
trainset = KittiDataset(root="/Users/xymbiotec/Desktop/painting-master/work", mode="training", valid=False)
valset = KittiDataset(root="/Users/xymbiotec/Desktop/painting-master/work", mode="training", valid=True)

datasets = {'train': trainset, 'val': valset}
dataloaders_dict = {
    x: DataLoader(datasets[x],
                  batch_size=4,
                  shuffle=True,
                  collate_fn=datasets[x].collate_fn,
                  num_workers=0,
                  drop_last=True)
    for x in ['train', 'val']
}

# SGD with momentum; the loss is built from the model's `priors_cxcy`.
optimizer_ft = torch.optim.SGD(ssd.parameters(), lr=0.0001, momentum=0.9)
criterion = MultiBoxLoss(priors_cxcy=ssd.priors_cxcy).to(device)

ssd = train_model(ssd, dataloaders_dict, criterion, optimizer_ft, num_epochs=10)
# NOTE(review): the checkpoint is named "pointpillars" although the model is
# an SSD; confirm the intended file name.
torch.save(ssd.state_dict(), './pointpillars.pth')
class FrustumDataset(object):
    ''' Dataset class for Frustum PointNets training/evaluation.
    Load prepared KITTI data from pickled files, return individual data element
    [optional] along with its annotations.
    '''
    def __init__(self,
                 npoints,
                 split,
                 random_flip=False,
                 random_shift=False,
                 rotate_to_center=False,
                 overwritten_data_path=None,
                 from_rgb_detection=False,
                 one_hot=False,
                 generate_database=False):
        '''
        Input:
            npoints: int scalar, number of points for frustum point cloud.
            split: string, train or val
            random_flip: bool, in 50% randomly flip the point cloud
                in left and right (after the frustum rotation if any)
            random_shift: bool, if True randomly shift the point cloud
                back and forth by a random distance
            rotate_to_center: bool, whether to do frustum rotation
            overwritten_data_path: string, specify pickled file path.
                if None, use default path (with the split)
            from_rgb_detection: bool, if True we assume we do not have
                groundtruth, just return data elements.
            one_hot: bool, if True, return one hot vector
            generate_database: bool. If True, scan every frame of the split
                and keep each (frame, radar point) pair whose radar ROI
                contains LiDAR points of exactly one ground-truth box; the
                selection is written to 'radar_batches_stats_<split>.csv' in
                the current working directory. If False, the selection is
                read from a previously generated CSV.

        When from_rgb_detection is False, id_list, radar_OI and box_present
        are aligned per sample afterwards: frame id, index of the radar
        point within that frame, and 1-based index of the selected
        ground-truth box.
        '''
        self.dataset_kitti = KittiDataset(
            root_dir='/root/frustum-pointnets_RSC/dataset/',
            mode='TRAIN',
            split=split)

        self.npoints = npoints
        self.random_flip = random_flip
        self.random_shift = random_shift
        self.rotate_to_center = rotate_to_center
        self.one_hot = one_hot
        if overwritten_data_path is None:
            overwritten_data_path = os.path.join(
                ROOT_DIR, 'kitti/frustum_carpedcyc_%s.pickle' % (split))

        self.from_rgb_detection = from_rgb_detection
        if from_rgb_detection:
            with open(overwritten_data_path, 'rb') as fp:
                self.id_list = pickle.load(fp)
                self.box2d_list = pickle.load(fp)
                self.input_list = pickle.load(fp)
                self.type_list = pickle.load(fp)
                # frustum_angle is clockwise angle from positive x-axis
                self.frustum_angle_list = pickle.load(fp)
                self.prob_list = pickle.load(fp)
        else:
            # Start from every frame of the split; narrowed down below.
            self.id_list = self.dataset_kitti.sample_id_list
            print("id_list", len(self.id_list))
            if generate_database:
                self.index_batch = []
                self.label_present = []
                self.radar_OI = []
                self.radar_mask_len = []
                self.cls_labels_len = []

                for i in range(len(self.id_list)):
                    pc_input = self.dataset_kitti.get_lidar(self.id_list[i])
                    pc_radar = self.dataset_kitti.get_radar(self.id_list[i])
                    gt_obj_list = self.dataset_kitti.filtrate_objects(
                        self.dataset_kitti.get_label(self.id_list[i]))

                    gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)

                    # Per-point instance label: k + 1 for points inside box k,
                    # 0 for points outside every box.
                    cls_label = np.zeros((pc_input.shape[0]), dtype=np.int32)
                    gt_corners = kitti_utils.boxes3d_to_corners3d(
                        gt_boxes3d, transform=True)
                    for k in range(gt_boxes3d.shape[0]):
                        box_corners = gt_corners[k]
                        fg_pt_flag = kitti_utils.in_hull(
                            pc_input[:, 0:3], box_corners)
                        cls_label[fg_pt_flag] = k + 1
                    print("indice", self.id_list[i])
                    print("number of boxes", gt_boxes3d.shape[0])
                    print("cls_label nunber", np.count_nonzero(cls_label > 0))
                    print("number of radar pts present in frame",
                          len(pc_radar))
                    for j in range(len(pc_radar)):
                        radar_mask = get_radar_mask(pc_input,
                                                    pc_radar[j].reshape(-1, 3))
                        # check label present in radar ROI
                        print("radar_mask", np.count_nonzero(radar_mask == 1))
                        label_radar_intersection = radar_mask * cls_label

                        print("intersection",
                              np.count_nonzero(label_radar_intersection > 0))
                        # 1-based ids of the boxes with points in this ROI.
                        labels_present = [
                            m + 1 for m in range(gt_boxes3d.shape[0])
                            if np.isin(m + 1, label_radar_intersection)
                        ]
                        print("labels present", labels_present)
                        # Keep the (frame, radar point) pair only when its ROI
                        # touches exactly one box; otherwise drop it.
                        if len(labels_present) == 1:
                            self.radar_mask_len.append(
                                np.count_nonzero(radar_mask == 1))
                            self.index_batch.append(self.id_list[i])
                            self.radar_OI.append(j)
                            self.label_present.append(labels_present[0])
                            self.cls_labels_len.append(
                                np.count_nonzero(label_radar_intersection > 0))
                print("retained indices", self.index_batch)
                print("len retained indices", len(self.index_batch))
                # The selected box ids were collected in label_present, but the
                # CSV column and __getitem__ both use the name box_present
                # (previously never assigned here -> AttributeError).
                self.box_present = self.label_present
                with open('radar_batches_stats_' + split + '.csv',
                          mode='w') as csv_file:
                    fieldnames = [
                        'index_batch', 'radar_mask_len', 'radar_OI',
                        'box_present', 'cls_labels_len'
                    ]
                    writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
                    writer.writeheader()
                    # Columns are aligned lists (appended together above),
                    # zipped in the same order as fieldnames.
                    rows = zip(self.index_batch, self.radar_mask_len,
                               self.radar_OI, self.box_present,
                               self.cls_labels_len)
                    for row in rows:
                        writer.writerow(dict(zip(fieldnames, row)))
                self.id_list = self.index_batch
                print("id_list_filtered", len(self.id_list))
            else:
                # Reuse a previously generated selection table.
                database_infos = pandas.read_csv(
                    '/root/frustum-pointnets_RSC_RADAR_fil_PC_batch/train/radar_batches_stats_'
                    + split + '.csv')
                self.id_list = database_infos['index_batch']
                self.radar_OI = database_infos['radar_OI']
                self.box_present = database_infos['box_present']

    def __len__(self):
        return len(self.id_list)

    def __getitem__(self, index):
        ''' Get index-th element: the LiDAR points inside one radar ROI plus
        the labels of the ground-truth box selected for that ROI.

        The sample is identified by self.id_list[index] (frame id),
        self.radar_OI[index] (radar point within that frame) and
        self.box_present[index] (1-based ground-truth box index).

        Returns, unless from_rgb_detection is set:
            point_set, seg, box3d_center, angle_class, angle_residual,
            size_class, size_residual, rot_angle, [one_hot_vec,] radar_mask
        '''
        # ------------------------------ INPUTS ----------------------------
        frame_id = self.id_list[index]
        # load radar points
        input_radar = self.dataset_kitti.get_radar(frame_id)
        point_cloud = self.dataset_kitti.get_lidar(frame_id)
        # Keep only the LiDAR points inside the ROI of the selected radar point.
        roi_mask = get_radar_mask(
            point_cloud, input_radar[self.radar_OI[index]].reshape(-1, 3))
        roi_idx = np.argwhere(roi_mask == 1)
        point_cloud = point_cloud[roi_idx.reshape(-1)]
        print(point_cloud.shape)

        pts_rect = point_cloud[:, 0:3]
        pts_intensity = point_cloud[:, 3:]
        if self.npoints < len(pts_rect):
            print(pts_rect.shape)
            # Points with depth (column 2) >= 20 are "far": keep all of them
            # and fill the rest of the budget with random near points.
            pts_depth = pts_rect[:, 2]
            pts_near_flag = pts_depth < 20.0
            far_idxs_choice = np.where(pts_near_flag == 0)[0]
            near_idxs = np.where(pts_near_flag == 1)[0]
            if len(far_idxs_choice) > self.npoints:
                # Far points alone exceed the budget: the near-point sample
                # size would be negative (np.random.choice raises), so
                # subsample the far points instead.
                choice = np.random.choice(far_idxs_choice, self.npoints,
                                          replace=False)
            else:
                near_idxs_choice = np.random.choice(
                    near_idxs, self.npoints - len(far_idxs_choice),
                    replace=False)
                if len(far_idxs_choice) > 0:
                    choice = np.concatenate(
                        (near_idxs_choice, far_idxs_choice), axis=0)
                else:
                    choice = near_idxs_choice
            np.random.shuffle(choice)
        else:
            # Pad with zero points up to ceil(npoints / 2) so that the
            # duplicates drawn below (without replacement) never outnumber
            # the available points. Flooring here broke odd npoints.
            half_npoints = (self.npoints + 1) // 2
            if half_npoints > len(pts_rect):
                diff = half_npoints - len(pts_rect)
                add_pts = np.zeros((diff, 3), dtype=np.float32)
                # assumes 3 extra feature channels per point (see the
                # reshape(-1, 3) below) -- TODO confirm
                add_int = np.zeros((diff, 3), dtype=np.float32)
                pts_rect = np.concatenate((pts_rect, add_pts), axis=0)
                pts_intensity = np.concatenate((pts_intensity, add_int),
                                               axis=0)
            choice = np.arange(0, len(pts_rect), dtype=np.int32)
            if self.npoints > len(pts_rect):
                extra_choice = np.random.choice(choice,
                                                self.npoints - len(pts_rect),
                                                replace=False)
                choice = np.concatenate((choice, extra_choice), axis=0)
        np.random.shuffle(choice)
        ret_pts_rect = pts_rect[choice, :]
        # TODO don't use intensity feature or try a method to add rgb
        ret_pts_intensity = pts_intensity[choice]
        ret_pts_features = ret_pts_intensity.reshape(-1, 3)
        point_cloud = np.concatenate((ret_pts_rect, ret_pts_features), axis=1)

        # NOTE(review): this replaces the single-radar-point ROI mask with a
        # mask computed against all radar points of the frame on the
        # resampled cloud; it is the mask returned below. Confirm intended.
        radar_mask = get_radar_mask(point_cloud, input_radar)

        gt_obj_list = self.dataset_kitti.filtrate_objects(
            self.dataset_kitti.get_label(frame_id))
        gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
        # box_present is 1-based (0 is background in the database labels).
        gt_boxes3d = gt_boxes3d[self.box_present[index] - 1].reshape(-1, 7)
        corners3d = kitti_utils.boxes3d_to_corners3d(gt_boxes3d,
                                                     transform=True)

        # Binary segmentation target: 1 for points inside the selected box.
        cls_label = np.zeros((point_cloud.shape[0]), dtype=np.int32)
        for k in range(gt_boxes3d.shape[0]):
            fg_pt_flag = kitti_utils.in_hull(point_cloud[:, 0:3], corners3d[k])
            cls_label[fg_pt_flag] = 1

        obj_type = "Pedestrian"
        heading = gt_boxes3d[:, 6]
        size = gt_boxes3d[:, 3:6]
        # No frustum is used here: frustum and rotation angles are fixed at 0.
        frustum_angle = 0.0
        rot_angle = 0.0

        # Compute one hot vector
        if self.one_hot:
            assert (obj_type in ['Car', 'Pedestrian', 'Cyclist'])
            one_hot_vec = np.zeros((3))
            one_hot_vec[g_type2onehotclass[obj_type]] = 1

        # Get point cloud
        if self.rotate_to_center:
            point_set = self.get_center_view_point_set(point_cloud,
                                                       frustum_angle)
        else:
            point_set = point_cloud

        if self.from_rgb_detection:
            if self.one_hot:
                return point_set, rot_angle, self.prob_list[index], one_hot_vec
            else:
                return point_set, rot_angle, self.prob_list[index]

        # ------------------------------ LABELS ----------------------------
        seg = cls_label

        # Box center and heading, optionally in the rotated frame.
        if self.rotate_to_center:
            box3d_center = self.get_center_view_box3d_center(
                corners3d, frustum_angle)
            heading_angle = heading - rot_angle
        else:
            box3d_center = self.get_box3d_center(corners3d)
            heading_angle = heading

        # Size
        size_class, size_residual = size2class(size, obj_type)

        # Data Augmentation
        if self.random_flip:
            # note: rot_angle won't be correct if we have random_flip
            # so do not use it in case of random flipping.
            if np.random.random() > 0.5:  # 50% chance flipping
                point_set[:, 0] *= -1
                box3d_center[0] *= -1
                heading_angle = np.pi - heading_angle
        if self.random_shift:
            dist = np.sqrt(np.sum(box3d_center[0]**2 + box3d_center[1]**2))
            # NOTE(review): clipping N(0, 0.05 * dist) into
            # [0.8 * dist, 1.2 * dist] nearly always yields 0.8 * dist, i.e. an
            # almost constant forward shift. Kept unchanged -- confirm intent.
            shift = np.clip(np.random.randn() * dist * 0.05, dist * 0.8,
                            dist * 1.2)
            point_set[:, 2] += shift
            box3d_center[2] += shift
        print(heading_angle)
        angle_class, angle_residual = angle2class(heading_angle,
                                                  NUM_HEADING_BIN)
        if self.one_hot:
            return point_set, seg, box3d_center, angle_class, angle_residual, \
                   size_class, size_residual, rot_angle, one_hot_vec, radar_mask
        else:
            return point_set, seg, box3d_center, angle_class, angle_residual, \
                   size_class, size_residual, rot_angle, radar_mask

    def get_center_view_rot_angle(self, frustrum_angle):
        ''' Return the frustum rotation angle used to adjust the GT heading.

        Frustum rotation is disabled in this dataset, so the result is 0.0
        whatever `frustrum_angle` is; the frustum-based formula
        (np.pi / 2.0 + frustum angle) is not applied.
        '''
        disabled_rotation = 0.0
        return disabled_rotation

    def get_box3d_center(self, corners3d):
        ''' Return the XYZ center of a 3D bounding box.

        corners3d: the 8 box corners, as an array reshapeable to (8, 3).
        The center is the midpoint of corner 0 and corner 6.
        '''
        corner_rows = corners3d.reshape((8, 3))
        corner_a = corner_rows[0, :]
        corner_b = corner_rows[6, :]
        return (corner_a + corner_b) / 2.0

    def get_center_view_box3d_center(self, box3d, frustrum_angle):
        ''' Frustum rotation of 3D bounding box center.

        box3d: the 8 box corners, reshaped to (8, 3); the center is the
            midpoint of corner 0 and corner 6.
        frustrum_angle: forwarded to get_center_view_rot_angle to obtain
            the rotation angle.
        Returns the rotated center, squeezed (presumably shape (3,)).
        '''
        box3d = box3d.reshape((8, 3))
        box3d_center = (box3d[0, :] + box3d[6, :]) / 2.0
        # Rotate exactly once. np.expand_dims returns a view of box3d_center,
        # so if rotate_pc_along_y modifies its argument in place, an extra
        # discarded call (as before) would rotate the center twice.
        return rotate_pc_along_y(
            np.expand_dims(box3d_center, 0),
            self.get_center_view_rot_angle(frustrum_angle)).squeeze()

    def get_center_view_box3d(self, index):
        ''' Frustum-rotate a copy of the stored 3D box corners for `index`.

        Reads self.box3d_list[index]; note that __init__ does not assign
        box3d_list (its loader is disabled), so it must be set elsewhere
        before calling this. The stored corners are copied, not modified.
        '''
        corners_copy = np.copy(self.box3d_list[index])
        rotated = rotate_pc_along_y(corners_copy,
                                    self.get_center_view_rot_angle(index))
        return rotated

    def get_center_view_point_set(self, input, frustrum_angle):
        ''' Frustum rotation of point clouds.

        input: NxC points with first 3 channels as XYZ
            (z is facing forward, x is left ward, y is downward).
        frustrum_angle: forwarded to get_center_view_rot_angle.
        The points are copied first so the caller's array is left untouched.
        '''
        rotated_points = rotate_pc_along_y(
            np.copy(input), self.get_center_view_rot_angle(frustrum_angle))
        return rotated_points
    def __init__(self,
                 npoints,
                 split,
                 random_flip=False,
                 random_shift=False,
                 rotate_to_center=False,
                 overwritten_data_path=None,
                 from_rgb_detection=False,
                 one_hot=False,
                 generate_database=False):
        '''
        Input:
            npoints: int scalar, number of points for frustum point cloud.
            split: string, train or val
            random_flip: bool, in 50% randomly flip the point cloud
                in left and right (after the frustum rotation if any)
            random_shift: bool, if True randomly shift the point cloud
                back and forth by a random distance
            rotate_to_center: bool, whether to do frustum rotation
            overwritten_data_path: string, specify pickled file path.
                if None, use default path (with the split)
            from_rgb_detection: bool, if True we assume we do not have
                groundtruth, just return data elements.
            one_hot: bool, if True, return one hot vector
        '''
        self.dataset_kitti = KittiDataset(
            root_dir='/root/frustum-pointnets_RSC/dataset/',
            mode='TRAIN',
            split=split)

        self.npoints = npoints
        self.random_flip = random_flip
        self.random_shift = random_shift
        self.rotate_to_center = rotate_to_center
        self.one_hot = one_hot
        if overwritten_data_path is None:
            overwritten_data_path = os.path.join(
                ROOT_DIR, 'kitti/frustum_carpedcyc_%s.pickle' % (split))

        self.from_rgb_detection = from_rgb_detection
        if from_rgb_detection:
            with open(overwritten_data_path, 'rb') as fp:
                self.id_list = pickle.load(fp)
                self.box2d_list = pickle.load(fp)
                self.input_list = pickle.load(fp)
                self.type_list = pickle.load(fp)
                # frustum_angle is clockwise angle from positive x-axis
                self.frustum_angle_list = pickle.load(fp)
                self.prob_list = pickle.load(fp)
        else:
            #with open(overwritten_data_path, 'rb') as fp:
            #load list of frames
            self.id_list = self.dataset_kitti.sample_id_list
            print("id_list", len(self.id_list))
            #fil = np.zeros((len(self.id_list)))
            #for i in range(len(self.id_list)):
            #    print(self.id_list[i])
            #    gt_obj_list = self.dataset_kitti.filtrate_objects(self.dataset_kitti.get_label(self.id_list[i]))
            #    print(len(gt_obj_list))
            #gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
            #print(gt_boxes3d)
            #    if(len(gt_obj_list)==1):
            #        fil[i]=1

            #self.id_list= np.extract(fil,self.id_list)
            if (generate_database):
                self.index_batch = []
                self.label_present = []
                self.radar_OI = []
                self.radar_mask_len = []
                self.cls_labels_len = []

                for i in range(len(self.id_list)):
                    pc_input = self.dataset_kitti.get_lidar(self.id_list[i])
                    pc_radar = self.dataset_kitti.get_radar(self.id_list[i])
                    gt_obj_list = self.dataset_kitti.filtrate_objects(
                        self.dataset_kitti.get_label(self.id_list[i]))

                    gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
                    corners3d = kitti_utils.boxes3d_to_corners3d(
                        gt_boxes3d, transform=True)

                    cls_label = np.zeros((pc_input.shape[0]), dtype=np.int32)
                    gt_corners = kitti_utils.boxes3d_to_corners3d(
                        gt_boxes3d, transform=True)
                    for k in range(gt_boxes3d.shape[0]):
                        box_corners = gt_corners[k]
                        fg_pt_flag = kitti_utils.in_hull(
                            pc_input[:, 0:3], box_corners)
                        cls_label[fg_pt_flag] = k + 1
                    print("indice", self.id_list[i])
                    print("number of boxes", gt_boxes3d.shape[0])
                    print("cls_label nunber", np.count_nonzero(cls_label > 0))
                    print("number of radar pts present in frame",
                          len(pc_radar))
                    for j in range(len(pc_radar)):
                        radar_mask = get_radar_mask(pc_input,
                                                    pc_radar[j].reshape(-1, 3))
                        # check label present in radar ROI
                        print("radar_mask", np.count_nonzero(radar_mask == 1))
                        label_radar_intersection = radar_mask * cls_label

                        print("intersection",
                              np.count_nonzero(label_radar_intersection > 0))
                        # if there's one
                        labels_present = []
                        for m in range(gt_boxes3d.shape[0]):
                            if (np.isin(m + 1, label_radar_intersection)):
                                labels_present.append(m + 1)
                        print("labels present", labels_present)
                        if (len(labels_present) == 1):
                            self.radar_mask_len.append(
                                np.count_nonzero(radar_mask == 1))
                            self.index_batch.append(self.id_list[i])
                            self.radar_OI.append(j)
                            self.label_present.append(labels_present[0])
                            self.cls_labels_len.append(
                                np.count_nonzero(label_radar_intersection > 0))
                print("retained indices", self.index_batch)
                print("len retained indices", len(self.index_batch))
                # keep this as a batch
                # if there's isn't
                # than forget about it
                with open('radar_batches_stats_' + split + '.csv',
                          mode='w') as csv_file:
                    fieldnames = [
                        'index_batch', 'radar_mask_len', 'radar_OI',
                        'box_present', 'cls_labels_len'
                    ]
                    writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
                    writer.writeheader()
                    for i in range(len(self.index_batch)):
                        writer.writerow({
                            'index_batch':
                            self.index_batch[i],
                            'radar_mask_len':
                            self.radar_mask_len[i],
                            'radar_OI':
                            self.radar_OI[i],
                            'box_present':
                            self.box_present[i],
                            'cls_labels_len':
                            self.cls_labels_len[i]
                        })
                self.id_list = self.index_batch
                print("id_list_filtered", len(self.id_list))
            else:
                database_infos = pandas.read_csv(
                    '/root/frustum-pointnets_RSC_RADAR_fil_PC_batch/train/radar_batches_stats_'
                    + split + '.csv')
                self.id_list = database_infos['index_batch']
                self.radar_OI = database_infos['radar_OI']
                self.box_present = database_infos['box_present']
            """
# Example (Beispiel) #12, score 0
                    print(f'phase is {phase} and batch is {batch_num}.')
                batch_num += 1

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    predicted_locs, predicted_scores, _ = model(augmented_lidar_cam_coords)
                    loss = criterion(predicted_locs, predicted_scores, boxes, classes)
                    
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        
                running_loss += loss.item() * bat_size
    return model #, val_acc_history
    
# Use the first GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
ssd = SSD(resnet_type=34, n_classes=2).to(device)
trainset = KittiDataset(root="/home/jovyan/work", mode="training", valid=False)
valset = KittiDataset(root="/home/jovyan/work", mode="training", valid=True)

# One loader per phase with identical settings and per-dataset collate_fn.
datasets = dict(train=trainset, val=valset)
dataloaders_dict = {
    phase: DataLoader(split_data,
                      batch_size=4,
                      shuffle=True,
                      collate_fn=split_data.collate_fn,
                      num_workers=0,
                      drop_last=True)
    for phase, split_data in datasets.items()
}

optimizer_ft = torch.optim.SGD(ssd.parameters(), lr=0.0001, momentum=0.9)
criterion = MultiBoxLoss(priors_cxcy=ssd.priors_cxcy).to(device)

# Train, then save only the learned weights.
ssd = train_model(ssd, dataloaders_dict, criterion, optimizer_ft, num_epochs=10)
torch.save(ssd.state_dict(), './pointpillars.pth')
# Example (Beispiel) #13, score 0
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    predicted_locs, predicted_scores, _ = model(augmented_lidar_cam_coords)
                    loss = criterion(predicted_locs, predicted_scores, boxes, classes)
                    
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        
                running_loss += loss.item() * bat_size
    return model  # val_acc_history


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ssd = SSD(resnet_type=34, n_classes=2).to(device)
trainset = KittiDataset(root="D:/AI/KITTI", mode="training", valid=False)
# Validate on the held-out part of the *training* split: the 'val' phase
# computes the MultiBox loss against ground-truth boxes, and the KITTI
# "testing" split is distributed without labels.
valset = KittiDataset(root="D:/AI/KITTI", mode="training", valid=True)

datasets = {'train': trainset, 'val': valset}
dataloaders_dict = {x: DataLoader(datasets[x], batch_size=4, shuffle=True, collate_fn=datasets[x].collate_fn,
                                  num_workers=0, drop_last=True) for x in ['train', 'val']}

optimizer_ft = torch.optim.SGD(ssd.parameters(), lr=0.0001, momentum=0.9)
criterion = MultiBoxLoss(priors_cxcy=ssd.priors_cxcy).to(device)

# Train, then save only the learned weights.
ssd = train_model(ssd, dataloaders_dict, criterion, optimizer_ft, num_epochs=10)
torch.save(ssd.state_dict(), './pointpillars.pth')
from dataset import KittiDataset
from collections import Counter
import kitti_utils
import csv
import pandas
from pypcd import pypcd
# Python 2/3 compatibility: make the name `raw_input` available everywhere.
# On Python 2 the builtin exists and the bare lookup succeeds; on Python 3
# the lookup raises NameError and `input` is bound under the old name.
try:
    raw_input  # Python 2
except NameError:
    raw_input = input  # Python 3

if __name__ == '__main__':

    # Radar + LiDAR dataset for the "trainval" split.
    dataset_kitti = KittiDataset(
        "pc_radar_f3_vox",
        root_dir='/home/amben/frustum-pointnets_RSC/dataset/',
        mode='TRAIN',
        split="trainval")
    id_list = dataset_kitti.sample_id_list
    det_obj, present_obj, radar_pts = [], [], []
    for i, frame_id in enumerate(id_list):
        pc_radar = dataset_kitti.get_radar(frame_id)
        labels = dataset_kitti.get_label(frame_id)
        gt_obj_list = dataset_kitti.filtrate_objects(labels)
        gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)

        # Enlarge every box by 2 along each of its three size columns (3..5).
        gt_boxes3d[:, 3:6] += 2
class FrustumDataset(object):
    ''' Dataset class for Frustum PointNets training/evaluation.

    Either loads pre-extracted frustums from a pickle (``from_rgb_detection``)
    or builds them from the KITTI-style dataset for the 'train' / 'val'
    split, and returns individual elements (optionally with annotations)
    through ``__getitem__``.

    The per-frustum lists ``input_list``, ``label_list``, ``box3d_list``,
    ``box2d_list``, ``type_list``, ``heading_list``, ``size_list`` and
    ``frustum_angle_list`` are index-aligned; ``id_list`` holds the frame id
    each frustum was cut from.
    '''

    def __init__(self,
                 npoints,
                 split,
                 random_flip=False,
                 random_shift=False,
                 rotate_to_center=False,
                 overwritten_data_path=None,
                 from_rgb_detection=False,
                 one_hot=False):
        '''
        Input:
            npoints: int scalar, number of points for frustum point cloud.
            split: string, train or val
            random_flip: bool, in 50% randomly flip the point cloud
                in left and right (after the frustum rotation if any)
            random_shift: bool, if True randomly shift the point cloud
                back and forth by a random distance
            rotate_to_center: bool, whether to do frustum rotation
            overwritten_data_path: string, specify pickled file path.
                if None, use default path (with the split)
            from_rgb_detection: bool, if True we assume we do not have
                groundtruth, just return data elements.
            one_hot: bool, if True, return one hot vector

        Splits other than 'train' / 'val' (without from_rgb_detection) load
        no samples.
        '''
        self.dataset_kitti = KittiDataset(
            root_dir='/home/amben/frustum-pointnets_RSC/dataset/',
            mode='TRAIN',
            split=split)
        self.npoints = npoints
        self.random_flip = random_flip
        self.random_shift = random_shift
        self.rotate_to_center = rotate_to_center
        self.one_hot = one_hot
        if overwritten_data_path is None:
            overwritten_data_path = os.path.join(
                ROOT_DIR, 'kitti/frustum_carpedcyc_%s.pickle' % (split))

        self.from_rgb_detection = from_rgb_detection
        if from_rgb_detection:
            self._load_rgb_detection_pickle(overwritten_data_path)
        elif split == 'train':
            self._load_train_split()
        elif split == 'val':
            self._load_val_split()

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------
    def _load_rgb_detection_pickle(self, path):
        ''' Load frustums pre-extracted from 2D RGB detections (no GT). '''
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # files produced by this project.
        with open(path, 'rb') as fp:
            self.id_list = pickle.load(fp)
            self.box2d_list = pickle.load(fp)
            self.input_list = pickle.load(fp)
            self.type_list = pickle.load(fp)
            # frustum_angle is clockwise angle from positive x-axis
            self.frustum_angle_list = pickle.load(fp)
            self.prob_list = pickle.load(fp)

    def _reset_sample_lists(self):
        ''' Point id_list at all dataset frames and empty every per-frustum list. '''
        self.id_list = self.dataset_kitti.sample_id_list
        self.idx_batch = self.id_list
        self.frustum_angle_list = []
        self.input_list = []
        self.label_list = []
        self.box3d_list = []
        self.box2d_list = []
        self.type_list = []
        self.heading_list = []
        self.size_list = []

    @staticmethod
    def _box2d_center(box2d):
        ''' Return the pixel center of a [xmin, ymin, xmax, ymax] 2D box. '''
        # The vertical coordinate averages ymin with ymax (index 3); the
        # previous code averaged ymin with xmax (index 2).
        return np.array([(box2d[0] + box2d[2]) / 2.0,
                         (box2d[1] + box2d[3]) / 2.0])

    def _frustum_angle(self, pc_lidar, pixels, box2d):
        ''' Angle of the frustum through the 2D box center, from the point
        closest to that center: -atan2(z, x). '''
        pc_center_frus = get_closest_pc_to_center(
            pc_lidar, pixels, self._box2d_center(box2d))
        return -1 * np.arctan2(pc_center_frus[2], pc_center_frus[0])

    def _append_sample(self, frus_pc, frustum_angle, seg, box_corners, box2d,
                       heading, size):
        ''' Append one frustum to every index-aligned per-frustum list. '''
        self.input_list.append(frus_pc)
        self.frustum_angle_list.append(frustum_angle)
        self.label_list.append(seg)
        self.box3d_list.append(box_corners)
        self.box2d_list.append(box2d)
        self.type_list.append("Pedestrian")
        self.heading_list.append(heading)
        self.size_list.append(size)

    @staticmethod
    def _best_box_index(cls_label, num_boxes):
        ''' Return the 0-based index of the GT box that owns the most points.

        cls_label holds 0 for background and k + 1 for points in box k.
        Ties (and the all-zero case) resolve to the lowest index.
        '''
        if num_boxes == 0:
            return 0
        counts = np.bincount(cls_label, minlength=num_boxes + 1)[1:num_boxes + 1]
        return int(np.argmax(counts))

    @staticmethod
    def _empty_target(num_points):
        ''' Placeholder target for a frustum that matches no GT box:
        all-background labels, unit size, zero heading, and the corners of a
        unit box centered at (-1, -1, -1). '''
        seg = np.zeros((num_points), dtype=np.int32)
        size = np.ones((3))
        rot_angle = 0.0
        placeholder = np.array([[-1.0, -1.0, -1.0,
                                 size[0], size[1], size[2], rot_angle]])
        bb_corners = kitti_utils.boxes3d_to_corners3d(
            placeholder, transform=True)[0]
        return seg, bb_corners, rot_angle, size

    def _load_train_split(self, perturb_box2d=True, augmentX=1):
        ''' Build training frustums from every ground-truth 2D box.

        Each GT 2D box (jittered by random_shift_box2d when perturb_box2d is
        set) cuts a frustum out of the LiDAR cloud. Points inside that
        object's 3D box are labeled 1. Frustums whose 2D box is shorter than
        25 px, or that contain no foreground point, are dropped.
        '''
        self._reset_sample_lists()
        batch_list = []
        pos_cnt = 0
        all_cnt = 0
        for frame_id in self.id_list:
            pc_lidar = self.dataset_kitti.get_lidar(frame_id)
            gt_obj_list = self.dataset_kitti.get_label(frame_id)
            pixels = get_pixels(frame_id)
            for gt_obj in gt_obj_list:
                for _ in range(augmentX):
                    # Augment data by box2d perturbation; without it, use the
                    # GT box as-is (previously box2d was left undefined).
                    if perturb_box2d:
                        box2d = random_shift_box2d(gt_obj.box2d)
                    else:
                        box2d = gt_obj.box2d
                    # Cheap rejection before frustum extraction / in_hull.
                    if box2d[3] - box2d[1] < 25:
                        continue
                    frus_pc, _ = extract_pc_in_box2d(pc_lidar, pixels, box2d)
                    frustum_angle = self._frustum_angle(pc_lidar, pixels, box2d)

                    # Points inside this object's 3D box are foreground (1).
                    gt_boxes3d = kitti_utils.objs_to_boxes3d([gt_obj])
                    box_corners = kitti_utils.boxes3d_to_corners3d(
                        gt_boxes3d, transform=True)[0]
                    cls_label = np.zeros((frus_pc.shape[0]), dtype=np.int32)
                    cls_label[kitti_utils.in_hull(frus_pc[:, 0:3],
                                                  box_corners)] = 1
                    num_fg = np.sum(cls_label)
                    if num_fg == 0:
                        continue

                    self._append_sample(
                        frus_pc, frustum_angle, cls_label, box_corners, box2d,
                        gt_obj.ry,
                        np.array([gt_obj.l, gt_obj.w, gt_obj.h]))
                    batch_list.append(frame_id)
                    pos_cnt += num_fg
                    all_cnt += frus_pc.shape[0]

        self.id_list = batch_list
        if all_cnt > 0:
            print('Average pos ratio: %f' % (pos_cnt / float(all_cnt)))
            print('Average npoints: %f' % (float(all_cnt) / len(self.id_list)))
        else:
            print('No training frustums were kept.')

    def _load_val_split(self):
        ''' Build validation frustums from detected 2D boxes.

        Boxes shorter than 25 px or whose frustum has fewer than 50 points are
        skipped. Each kept frustum is matched to the GT 3D box owning the most
        of its points. With fewer than 20 GT-box points in total, a
        placeholder target (see _empty_target) is used instead.
        '''
        self._reset_sample_lists()
        batch_list = []
        total_points = 0
        for frame_id in self.id_list:
            pc_lidar = self.dataset_kitti.get_lidar(frame_id)
            box2ds = get_2Dboxes_detected(frame_id)
            if box2ds is None:
                continue
            pixels = get_pixels(frame_id)
            gt_boxes3d = None  # GT is loaded lazily, once per frame
            for box2d in box2ds:
                if box2d[3] - box2d[1] < 25:
                    continue
                frus_pc, _ = extract_pc_in_box2d(pc_lidar, pixels, box2d)
                if len(frus_pc) < 50:
                    continue
                frustum_angle = self._frustum_angle(pc_lidar, pixels, box2d)

                if gt_boxes3d is None:
                    gt_obj_list = self.dataset_kitti.filtrate_objects(
                        self.dataset_kitti.get_label(frame_id))
                    gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
                    gt_corners = kitti_utils.boxes3d_to_corners3d(
                        gt_boxes3d, transform=True)
                num_boxes = gt_boxes3d.shape[0]

                # Label each point with 1 + index of the GT box containing it
                # (later boxes overwrite earlier ones where boxes overlap).
                cls_label = np.zeros((frus_pc.shape[0]), dtype=np.int32)
                for k in range(num_boxes):
                    cls_label[kitti_utils.in_hull(frus_pc[:, 0:3],
                                                  gt_corners[k])] = k + 1

                if np.count_nonzero(cls_label > 0) < 20:
                    seg, bb_corners, rot_angle, size = self._empty_target(
                        cls_label.shape[0])
                else:
                    best = self._best_box_index(cls_label, num_boxes)
                    seg = np.where(cls_label == best + 1, 1, 0)
                    bb_corners = gt_corners[best]
                    # Size/heading come from the same box as bb_corners (the
                    # old code used the last loop index instead).
                    obj = gt_boxes3d[best]
                    size = np.array([obj[3], obj[4], obj[5]])
                    rot_angle = obj[6]

                self._append_sample(frus_pc, frustum_angle, seg, bb_corners,
                                    box2d, rot_angle, size)
                batch_list.append(frame_id)
                total_points += frus_pc.shape[0]

        self.id_list = batch_list
        if self.input_list:
            print("average number of cloud:",
                  total_points / len(self.input_list))

    # ------------------------------------------------------------------
    # Dataset protocol
    # ------------------------------------------------------------------
    def __len__(self):
        return len(self.input_list)

    def __getitem__(self, index):
        ''' Get index-th element from the picked file dataset.

        Returns (point_set, rot_angle, prob[, one_hot_vec]) in
        from_rgb_detection mode. Otherwise returns (point_set, seg,
        box3d_center, angle_class, angle_residual, size_class, size_residual,
        rot_angle[, one_hot_vec]).
        '''
        # ------------------------------ INPUTS ----------------------------
        rot_angle = self.get_center_view_rot_angle(index)

        # Compute one hot vector
        if self.one_hot:
            cls_type = self.type_list[index]
            assert (cls_type in ['Car', 'Pedestrian', 'Cyclist'])
            one_hot_vec = np.zeros((3))
            one_hot_vec[g_type2onehotclass[cls_type]] = 1

        # Get point cloud
        if self.rotate_to_center:
            point_set = self.get_center_view_point_set(index)
        else:
            point_set = self.input_list[index]
        # Resample to exactly npoints (with replacement, so small frustums work)
        choice = np.random.choice(point_set.shape[0],
                                  self.npoints,
                                  replace=True)
        point_set = point_set[choice, :]

        if self.from_rgb_detection:
            if self.one_hot:
                return point_set, rot_angle, self.prob_list[index], one_hot_vec
            else:
                return point_set, rot_angle, self.prob_list[index]

        # ------------------------------ LABELS ----------------------------
        seg = self.label_list[index]
        seg = seg[choice]

        # Get center point of 3D box
        if self.rotate_to_center:
            box3d_center = self.get_center_view_box3d_center(index)
        else:
            box3d_center = self.get_box3d_center(index)

        # Heading
        if self.rotate_to_center:
            heading_angle = self.heading_list[index] - rot_angle
        else:
            heading_angle = self.heading_list[index]

        # Size
        size_class, size_residual = size2class(self.size_list[index],
                                               self.type_list[index])

        # Data Augmentation
        if self.random_flip:
            # note: rot_angle won't be correct if we have random_flip
            # so do not use it in case of random flipping.
            if np.random.random() > 0.5:  # 50% chance flipping
                point_set[:, 0] *= -1
                box3d_center[0] *= -1
                heading_angle = np.pi - heading_angle
        if self.random_shift:
            # NOTE(review): the clip bounds are [0.8*dist, 1.2*dist], so the
            # shift is always at least 0.8*dist rather than a small jitter
            # around 0; dist also ignores the z component. Confirm intent.
            dist = np.sqrt(np.sum(box3d_center[0]**2 + box3d_center[1]**2))
            shift = np.clip(np.random.randn() * dist * 0.05, dist * 0.8,
                            dist * 1.2)
            point_set[:, 2] += shift
            box3d_center[2] += shift

        angle_class, angle_residual = angle2class(heading_angle,
                                                  NUM_HEADING_BIN)

        if self.one_hot:
            return point_set, seg, box3d_center, angle_class, angle_residual,\
                size_class, size_residual, rot_angle, one_hot_vec
        else:
            return point_set, seg, box3d_center, angle_class, angle_residual,\
                size_class, size_residual, rot_angle

    def get_center_view_rot_angle(self, index):
        ''' Get the frustum rotation angle. It is shifted by pi/2 so that it
        can be directly used to adjust GT heading angle. '''
        return np.pi / 2.0 + self.frustum_angle_list[index]

    def get_box3d_center(self, index):
        ''' Get the center (XYZ) of 3D bounding box: midpoint of corners 0 and 6. '''
        box3d_center = (self.box3d_list[index][0, :] +
                        self.box3d_list[index][6, :]) / 2.0
        return box3d_center

    def get_center_view_box3d_center(self, index):
        ''' Frustum rotation of 3D bounding box center. '''
        box3d_center = (self.box3d_list[index][0, :] +
                        self.box3d_list[index][6, :]) / 2.0
        return rotate_pc_along_y(np.expand_dims(box3d_center, 0),
                                 self.get_center_view_rot_angle(index)).squeeze()

    def get_center_view_box3d(self, index):
        ''' Frustum rotation of 3D bounding box corners (on a copy). '''
        box3d = self.box3d_list[index]
        box3d_center_view = np.copy(box3d)
        return rotate_pc_along_y(box3d_center_view,
                                 self.get_center_view_rot_angle(index))

    def get_center_view_point_set(self, index):
        ''' Frustum rotation of point clouds.
        NxC points with first 3 channels as XYZ
        z is facing forward, x is left ward, y is downward
        '''
        # Use np.copy to avoid corrupting original data
        point_set = np.copy(self.input_list[index])
        return rotate_pc_along_y(point_set,
                                 self.get_center_view_rot_angle(index))
# Example #16
# 0
# NOTE(review): H, args, kwargs, device, resnet50 and Logger are not defined in
# this snippet; they are presumably set up above — confirm before running.
W = 608
# RGB input: resize, convert to tensor, then normalize per channel
# (presumably with the ImageNet statistics expected by the pretrained ResNet).
data_transform = transforms.Compose([
    transforms.Resize(size=(H, W)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Depth targets: converted to PIL first so they can be resized like the input.
depth_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(size=(H, W)),
    transforms.ToTensor(),
])


kitti_train = KittiDataset(data_transform=data_transform)
train_dataloader = DataLoader(kitti_train, batch_size=args.batch_size, shuffle=True, **kwargs)
kitti_test = KittiDataset(root_dir='../images/test', train=False, data_transform=data_transform, depth_transform=depth_transform, num_test=args.num_test)
test_dataloader = DataLoader(kitti_test, batch_size=1, shuffle=False, **kwargs)
print(len(kitti_test))

net = resnet50(pretrained=True, use_deconv=args.use_deconv).to(device)
# Resume from a checkpoint only when a path was given. `is not ''` compared
# object identity rather than string equality; truthiness also skips None.
if args.model_dir:
    net.load_state_dict(torch.load(args.model_dir))

optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

# One log directory per run, named after the start timestamp.
# NOTE(review): str(datetime) contains ':' and a space, which Windows paths reject.
log_dir = '../images/logs'
now = str(datetime.datetime.now())
log_dir = os.path.join(log_dir, now)
logger = Logger(log_dir)
    def __init__(self,
                 npoints,
                 split,
                 random_flip=False,
                 random_shift=False,
                 rotate_to_center=False,
                 overwritten_data_path=None,
                 from_rgb_detection=False,
                 one_hot=False):
        '''
        Input:
            npoints: int scalar, number of points for frustum point cloud.
            split: string, train or val
            random_flip: bool, in 50% randomly flip the point cloud
                in left and right (after the frustum rotation if any)
            random_shift: bool, if True randomly shift the point cloud
                back and forth by a random distance
            rotate_to_center: bool, whether to do frustum rotation
            overwritten_data_path: string, specify pickled file path.
                if None, use default path (with the split)
            from_rgb_detection: bool, if True we assume we do not have
                groundtruth, just return data elements.
            one_hot: bool, if True, return one hot vector

        Fills the index-aligned per-frustum lists (input_list, label_list,
        box3d_list, box2d_list, type_list, heading_list, size_list,
        frustum_angle_list); id_list ends up holding each frustum's frame id.
        Splits other than 'train' / 'val' (without from_rgb_detection) load
        no samples.
        '''
        self.dataset_kitti = KittiDataset(
            root_dir='/home/amben/frustum-pointnets_RSC/dataset/',
            mode='TRAIN',
            split=split)
        self.npoints = npoints
        self.random_flip = random_flip
        self.random_shift = random_shift
        self.rotate_to_center = rotate_to_center
        self.one_hot = one_hot
        if overwritten_data_path is None:
            overwritten_data_path = os.path.join(
                ROOT_DIR, 'kitti/frustum_carpedcyc_%s.pickle' % (split))

        self.from_rgb_detection = from_rgb_detection
        if from_rgb_detection:
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # files produced by this project.
            with open(overwritten_data_path, 'rb') as fp:
                self.id_list = pickle.load(fp)
                self.box2d_list = pickle.load(fp)
                self.input_list = pickle.load(fp)
                self.type_list = pickle.load(fp)
                # frustum_angle is clockwise angle from positive x-axis
                self.frustum_angle_list = pickle.load(fp)
                self.prob_list = pickle.load(fp)
            return
        if split not in ('train', 'val'):
            return

        self.id_list = self.dataset_kitti.sample_id_list
        self.idx_batch = self.id_list
        self.frustum_angle_list = []
        self.input_list = []
        self.label_list = []
        self.box3d_list = []
        self.box2d_list = []
        self.type_list = []
        self.heading_list = []
        self.size_list = []
        batch_list = []

        def frustum_angle_of(pc_lidar, pixels, box2d):
            # Pixel center of the [xmin, ymin, xmax, ymax] box. The vertical
            # coordinate averages ymin with ymax; the old code used xmax.
            center_box2d = np.array([(box2d[0] + box2d[2]) / 2.0,
                                     (box2d[1] + box2d[3]) / 2.0])
            pc_center_frus = get_closest_pc_to_center(
                pc_lidar, pixels, center_box2d)
            return -1 * np.arctan2(pc_center_frus[2], pc_center_frus[0])

        def append_sample(frus_pc, frustum_angle, seg, box_corners, box2d,
                          heading, size, frame_id):
            # Keep every per-frustum list index-aligned.
            self.input_list.append(frus_pc)
            self.frustum_angle_list.append(frustum_angle)
            self.label_list.append(seg)
            self.box3d_list.append(box_corners)
            self.box2d_list.append(box2d)
            self.type_list.append("Pedestrian")
            self.heading_list.append(heading)
            self.size_list.append(size)
            batch_list.append(frame_id)

        if split == 'train':
            # Frustums from GT 2D boxes; points inside the object's 3D box are 1.
            perturb_box2d = True
            augmentX = 1
            pos_cnt = 0
            all_cnt = 0
            for frame_id in self.id_list:
                pc_lidar = self.dataset_kitti.get_lidar(frame_id)
                gt_obj_list = self.dataset_kitti.get_label(frame_id)
                pixels = get_pixels(frame_id)
                for gt_obj in gt_obj_list:
                    for _ in range(augmentX):
                        # Augment data by box2d perturbation; otherwise use
                        # the GT box (previously box2d was left undefined).
                        if perturb_box2d:
                            box2d = random_shift_box2d(gt_obj.box2d)
                        else:
                            box2d = gt_obj.box2d
                        # Cheap rejection before extraction / in_hull.
                        if box2d[3] - box2d[1] < 25:
                            continue
                        frus_pc, _ = extract_pc_in_box2d(pc_lidar, pixels, box2d)
                        frustum_angle = frustum_angle_of(pc_lidar, pixels, box2d)

                        gt_boxes3d = kitti_utils.objs_to_boxes3d([gt_obj])
                        box_corners = kitti_utils.boxes3d_to_corners3d(
                            gt_boxes3d, transform=True)[0]
                        cls_label = np.zeros((frus_pc.shape[0]), dtype=np.int32)
                        cls_label[kitti_utils.in_hull(frus_pc[:, 0:3],
                                                      box_corners)] = 1
                        num_fg = np.sum(cls_label)
                        if num_fg == 0:
                            continue
                        append_sample(frus_pc, frustum_angle, cls_label,
                                      box_corners, box2d, gt_obj.ry,
                                      np.array([gt_obj.l, gt_obj.w, gt_obj.h]),
                                      frame_id)
                        pos_cnt += num_fg
                        all_cnt += frus_pc.shape[0]

            self.id_list = batch_list
            if all_cnt > 0:
                print('Average pos ratio: %f' % (pos_cnt / float(all_cnt)))
                print('Average npoints: %f' % (float(all_cnt) / len(self.id_list)))
            else:
                print('No training frustums were kept.')
            return

        # split == 'val': frustums from detected 2D boxes, matched to GT.
        total_points = 0
        for frame_id in self.id_list:
            pc_lidar = self.dataset_kitti.get_lidar(frame_id)
            box2ds = get_2Dboxes_detected(frame_id)
            if box2ds is None:
                continue
            pixels = get_pixels(frame_id)
            gt_boxes3d = None  # GT is loaded lazily, once per frame
            for box2d in box2ds:
                if box2d[3] - box2d[1] < 25:
                    continue
                frus_pc, _ = extract_pc_in_box2d(pc_lidar, pixels, box2d)
                if len(frus_pc) < 50:
                    continue
                frustum_angle = frustum_angle_of(pc_lidar, pixels, box2d)

                if gt_boxes3d is None:
                    gt_obj_list = self.dataset_kitti.filtrate_objects(
                        self.dataset_kitti.get_label(frame_id))
                    gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
                    gt_corners = kitti_utils.boxes3d_to_corners3d(
                        gt_boxes3d, transform=True)
                num_boxes = gt_boxes3d.shape[0]

                # Label each point with 1 + index of the GT box containing it.
                cls_label = np.zeros((frus_pc.shape[0]), dtype=np.int32)
                for k in range(num_boxes):
                    cls_label[kitti_utils.in_hull(frus_pc[:, 0:3],
                                                  gt_corners[k])] = k + 1

                if np.count_nonzero(cls_label > 0) < 20:
                    # Placeholder target: all background, unit box at -1.
                    seg = np.zeros((frus_pc.shape[0]), dtype=np.int32)
                    size = np.ones((3))
                    rot_angle = 0.0
                    placeholder = np.array([[-1.0, -1.0, -1.0,
                                             size[0], size[1], size[2],
                                             rot_angle]])
                    bb_corners = kitti_utils.boxes3d_to_corners3d(
                        placeholder, transform=True)[0]
                else:
                    # GT box owning the most points; ties go to lowest index.
                    counts = np.bincount(
                        cls_label, minlength=num_boxes + 1)[1:num_boxes + 1]
                    best = int(np.argmax(counts))
                    seg = np.where(cls_label == best + 1, 1, 0)
                    bb_corners = gt_corners[best]
                    # Size/heading come from the same box as bb_corners (the
                    # old code used the last loop index instead).
                    obj = gt_boxes3d[best]
                    size = np.array([obj[3], obj[4], obj[5]])
                    rot_angle = obj[6]

                append_sample(frus_pc, frustum_angle, seg, bb_corners, box2d,
                              rot_angle, size, frame_id)
                total_points += frus_pc.shape[0]

        self.id_list = batch_list
        if self.input_list:
            print("average number of cloud:",
                  total_points / len(self.input_list))
    def __init__(self,
                 npoints,
                 split,
                 random_flip=False,
                 random_shift=False,
                 rotate_to_center=False,
                 overwritten_data_path=None,
                 from_rgb_detection=False,
                 one_hot=False,
                 generate_database=False):
        '''
        Input:
            npoints: int scalar, number of points for frustum point cloud.
            split: string, train or val
            random_flip: bool, in 50% randomly flip the point cloud
                in left and right (after the frustum rotation if any)
            random_shift: bool, if True randomly shift the point cloud
                back and forth by a random distance
            rotate_to_center: bool, whether to do frustum rotation
            overwritten_data_path: string, specify pickled file path.
                if None, use default path (with the split)
            from_rgb_detection: bool, if True we assume we do not have
                groundtruth, just return data elements.
            one_hot: bool, if True, return one hot vector
        '''
        self.dataset_kitti = KittiDataset(
            root_dir='/root/frustum-pointnets_RSC/dataset/',
            mode='TRAIN',
            split=split)

        self.npoints = npoints
        self.random_flip = random_flip
        self.random_shift = random_shift
        self.rotate_to_center = rotate_to_center
        self.one_hot = one_hot
        if overwritten_data_path is None:
            overwritten_data_path = os.path.join(
                ROOT_DIR, 'kitti/frustum_carpedcyc_%s.pickle' % (split))

        self.from_rgb_detection = from_rgb_detection
        if from_rgb_detection:
            with open(overwritten_data_path, 'rb') as fp:
                self.id_list = pickle.load(fp)
                self.box2d_list = pickle.load(fp)
                self.input_list = pickle.load(fp)
                self.type_list = pickle.load(fp)
                # frustum_angle is clockwise angle from positive x-axis
                self.frustum_angle_list = pickle.load(fp)
                self.prob_list = pickle.load(fp)
        else:
            #list = os.listdir("/root/3D_BoundingBox_Annotation_Tool_3D_BAT/input/NuScenes/ONE/pointclouds_Radar")
            self.id_list = self.dataset_kitti.sample_id_list
            self.idx_batch = self.id_list
            batch_list = []
            self.radar_OI = []
            self.batch_size = []
            for i in range(len(self.id_list)):
                pc_radar = self.dataset_kitti.get_radar(self.id_list[i])
                m = 0
                for j in range(len(pc_radar)):
                    if (pc_radar[j, 2] > 2.0):
                        batch_list.append(self.id_list[i])
                        self.radar_OI.append(j)
                        m = m + 1
                self.batch_size.append(m)
            self.id_list = batch_list

            #with open(overwritten_data_path, 'rb') as fp:
            #load list of frames
            #self.id_list = self.dataset_kitti.sample_id_list
            print("id_list", len(self.id_list))
            #fil = np.zeros((len(self.id_list)))
            #for i in range(len(self.id_list)):
            #    print(self.id_list[i])
            #    gt_obj_list = self.dataset_kitti.filtrate_objects(self.dataset_kitti.get_label(self.id_list[i]))
            #    print(len(gt_obj_list))
            #gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
            #print(gt_boxes3d)
            #    if(len(gt_obj_list)==1):
            #        fil[i]=1

            #self.id_list= np.extract(fil,self.id_list)
            """