Ejemplo n.º 1
0
def main(args):
    """Validate and test a BEV LiDAR segmentation model on SemanticKITTI.

    First reports per-class IoU, mIoU and mean inference time on the
    validation split, then writes per-point predictions for the test split
    to ``<output_path>/sequences/.../predictions/*.label`` files.

    Args:
        args: parsed CLI namespace with ``data_dir``, ``test_batch_size``,
            ``model_save_path``, ``test_output_path``, ``grid_size``
            (3 elements) and ``model`` ('polar' or 'traditional').
    """
    data_path = args.data_dir
    test_batch_size = args.test_batch_size
    model_save_path = args.model_save_path
    output_path = args.test_output_path
    # The grid's third (height) dimension doubles as the BEV network's
    # n_height and as the point-feature compression size below.
    compression_model = args.grid_size[2]
    grid_size = args.grid_size
    pytorch_device = torch.device('cuda:0')
    model = args.model
    if model == 'polar':
        fea_dim = 9
        circular_padding = True
    elif model == 'traditional':
        fea_dim = 7
        circular_padding = False
    # NOTE(review): any other value of args.model leaves fea_dim and
    # circular_padding unbound and fails later with NameError — consider
    # an explicit `else: raise ValueError(...)`.

    # Prepare mIoU bookkeeping: label ids with the first (ignored) class
    # dropped and shifted to 0-based training ids, plus readable names.
    unique_label = np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str = [SemKITTI_label_name[x] for x in unique_label + 1]

    # prepare model
    my_BEV_model = BEV_Unet(n_class=len(unique_label),
                            n_height=compression_model,
                            input_batch_norm=True,
                            dropout=0.5,
                            circular_padding=circular_padding)
    my_model = ptBEVnet(my_BEV_model,
                        pt_model='pointnet',
                        grid_size=grid_size,
                        fea_dim=fea_dim,
                        max_pt_per_encode=256,
                        out_pt_fea_dim=512,
                        kernal_size=1,
                        pt_selection='random',
                        fea_compre=compression_model)
    # Load weights if a checkpoint exists; otherwise evaluation runs with
    # randomly initialized weights (silently — no warning is printed).
    if os.path.exists(model_save_path):
        my_model.load_state_dict(torch.load(model_save_path))
    my_model.to(pytorch_device)

    # prepare dataset
    test_pt_dataset = SemKITTI(data_path + '/sequences/',
                               imageset='test',
                               return_ref=True)
    val_pt_dataset = SemKITTI(data_path + '/sequences/',
                              imageset='val',
                              return_ref=True)
    # Polar uses a spherical (r, theta, z) grid; traditional a Cartesian
    # voxel grid. The test datasets also return sample indices
    # (return_test=True) so predictions can be mapped back to files.
    if model == 'polar':
        test_dataset = spherical_dataset(test_pt_dataset,
                                         grid_size=grid_size,
                                         ignore_label=0,
                                         fixed_volume_space=True,
                                         return_test=True)
        val_dataset = spherical_dataset(val_pt_dataset,
                                        grid_size=grid_size,
                                        ignore_label=0,
                                        fixed_volume_space=True)
    elif model == 'traditional':
        test_dataset = voxel_dataset(test_pt_dataset,
                                     grid_size=grid_size,
                                     ignore_label=0,
                                     fixed_volume_space=True,
                                     return_test=True)
        val_dataset = voxel_dataset(val_pt_dataset,
                                    grid_size=grid_size,
                                    ignore_label=0,
                                    fixed_volume_space=True)
    test_dataset_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=test_batch_size,
        collate_fn=collate_fn_BEV_test,
        shuffle=False,
        num_workers=4)
    val_dataset_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=test_batch_size,
        collate_fn=collate_fn_BEV,
        shuffle=False,
        num_workers=4)

    # validation
    print('*' * 80)
    print('Test network performance on validation split')
    print('*' * 80)
    pbar = tqdm(total=len(val_dataset_loader))
    my_model.eval()
    hist_list = []
    time_list = []
    with torch.no_grad():
        for i_iter_val, (_, val_vox_label, val_grid, val_pt_labs,
                         val_pt_fea) in enumerate(val_dataset_loader):
            # SemKITTI2train presumably remaps raw SemanticKITTI labels to
            # 0-based training ids — defined elsewhere; confirm mapping.
            val_vox_label = SemKITTI2train(val_vox_label)
            val_pt_labs = SemKITTI2train(val_pt_labs)
            # Per-sample point features and 2D grid coordinates (only the
            # first two grid axes are fed to the network).
            val_pt_fea_ten = [
                torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device)
                for i in val_pt_fea
            ]
            val_grid_ten = [
                torch.from_numpy(i[:, :2]).to(pytorch_device) for i in val_grid
            ]
            # NOTE(review): val_label_tensor is built but never used in
            # this function (no loss is computed here).
            val_label_tensor = val_vox_label.type(
                torch.LongTensor).to(pytorch_device)

            # Synchronize around the forward pass so wall-clock timing
            # reflects actual GPU work, not async kernel launch.
            torch.cuda.synchronize()
            start_time = time.time()
            predict_labels = my_model(val_pt_fea_ten, val_grid_ten)
            torch.cuda.synchronize()
            time_list.append(time.time() - start_time)

            predict_labels = torch.argmax(predict_labels, dim=1)
            predict_labels = predict_labels.cpu().detach().numpy()
            # Gather the predicted class of each point's voxel and
            # accumulate a confusion histogram per sample.
            for count, i_val_grid in enumerate(val_grid):
                hist_list.append(
                    fast_hist_crop(
                        predict_labels[count, val_grid[count][:, 0],
                                       val_grid[count][:, 1],
                                       val_grid[count][:, 2]],
                        val_pt_labs[count], unique_label))
            pbar.update(1)
    iou = per_class_iu(sum(hist_list))
    print('Validation per class iou: ')
    for class_name, class_iou in zip(unique_label_str, iou):
        print('%s : %.2f%%' % (class_name, class_iou * 100))
    val_miou = np.nanmean(iou) * 100
    del val_vox_label, val_grid, val_pt_fea, val_grid_ten
    pbar.close()
    print('Current val miou is %.3f ' % val_miou)
    print('Inference time per %d is %.4f seconds\n' %
          (test_batch_size, np.mean(time_list)))

    # test
    print('*' * 80)
    print('Generate predictions for test split')
    print('*' * 80)
    pbar = tqdm(total=len(test_dataset_loader))
    with torch.no_grad():
        for i_iter_test, (_, _, test_grid, _, test_pt_fea,
                          test_index) in enumerate(test_dataset_loader):
            # predict
            test_pt_fea_ten = [
                torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device)
                for i in test_pt_fea
            ]
            test_grid_ten = [
                torch.from_numpy(i[:, :2]).to(pytorch_device)
                for i in test_grid
            ]

            predict_labels = my_model(test_pt_fea_ten, test_grid_ten)
            predict_labels = torch.argmax(predict_labels, 1).type(torch.uint8)
            predict_labels = predict_labels.cpu().detach().numpy()
            # write to label file
            for count, i_test_grid in enumerate(test_grid):
                # Look up each point's predicted voxel label, map training
                # ids back to SemanticKITTI ids, and write one label per row.
                test_pred_label = predict_labels[count, test_grid[count][:, 0],
                                                 test_grid[count][:, 1],
                                                 test_grid[count][:, 2]]
                test_pred_label = train2SemKITTI(test_pred_label)
                test_pred_label = np.expand_dims(test_pred_label, axis=1)
                # Mirror the input path: .../sequences/XX/velodyne/YY.bin ->
                # <output>/sequences/XX/predictions/YY.label
                save_dir = test_pt_dataset.im_idx[test_index[count]]
                _, dir2 = save_dir.split('/sequences/', 1)
                new_save_dir = output_path + '/sequences/' + dir2.replace(
                    'velodyne', 'predictions')[:-3] + 'label'
                # Race-tolerant makedirs (EEXIST may be raised by a
                # concurrent worker creating the same directory).
                if not os.path.exists(os.path.dirname(new_save_dir)):
                    try:
                        os.makedirs(os.path.dirname(new_save_dir))
                    except OSError as exc:
                        if exc.errno != errno.EEXIST:
                            raise
                # Raw uint32 stream, as expected by the SemanticKITTI
                # .label file format.
                test_pred_label = test_pred_label.astype(np.uint32)
                test_pred_label.tofile(new_save_dir)
            pbar.update(1)
    del test_grid, test_pt_fea, test_index
    pbar.close()
    print(
        'Predicted test labels are saved in %s. Need to be shifted to original label format before submitting to the Competition website.'
        % output_path)
    print('Remapping script can be found in semantic-kitti-api.')
Ejemplo n.º 2
0
def main(args):
    """Validate and test a BEV LiDAR segmentation model on nuScenes.

    Reports per-class IoU, mIoU and inference time on the trainval/val
    split, then writes per-point test predictions as
    ``<output_path>/<token>_lidarseg.bin`` files.

    Args:
        args: parsed CLI namespace with ``data_dir``, ``test_batch_size``,
            ``model_save_path``, ``test_output_path``, ``grid_size``
            (3 elements), ``visibilty`` (sic) and ``model``
            ('polar' or 'traditional').
    """
    data_path = args.data_dir
    test_batch_size = args.test_batch_size
    model_save_path = args.model_save_path
    output_path = args.test_output_path
    # Grid height reused as both BEV n_height and feature-compression size.
    compression_model = args.grid_size[2]
    grid_size = args.grid_size
    # NOTE(review): 'visibilty' is misspelled throughout (should be
    # 'visibility'); renaming would require changing the CLI flag too.
    visibilty = args.visibilty
    pytorch_device = torch.device('cuda:0')
    model = args.model
    if model == 'polar':
        fea_dim = 9
        circular_padding = True
    elif model == 'traditional':
        fea_dim = 7
        circular_padding = False
    # NOTE(review): other args.model values leave fea_dim unbound
    # (NameError later).

    # Prepare mIoU bookkeeping from the nuScenes segmentation-class map:
    # names without the first (ignored) entry, and their 0-based ids.
    unique_label_str = list(
        map_name_from_segmentation_class_to_segmentation_index)[1:]
    unique_label = np.asarray([
        map_name_from_segmentation_class_to_segmentation_index[s]
        for s in unique_label_str
    ]) - 1

    # prepare model
    my_BEV_model = BEV_Unet(n_class=len(unique_label),
                            n_height=compression_model,
                            input_batch_norm=True,
                            dropout=0.5,
                            circular_padding=circular_padding,
                            use_vis_fea=visibilty)
    my_model = ptBEVnet(my_BEV_model,
                        pt_model='pointnet',
                        grid_size=grid_size,
                        fea_dim=fea_dim,
                        max_pt_per_encode=256,
                        out_pt_fea_dim=512,
                        kernal_size=1,
                        pt_selection='random',
                        fea_compre=compression_model)
    # Silently falls back to random init when no checkpoint is found.
    if os.path.exists(model_save_path):
        my_model.load_state_dict(torch.load(model_save_path))
    my_model.to(pytorch_device)

    # prepare dataset
    test_pt_dataset = Nuscenes(data_path + '/test/',
                               version='v1.0-test',
                               split='test',
                               return_ref=True)
    val_pt_dataset = Nuscenes(data_path + '/trainval/',
                              version='v1.0-trainval',
                              split='val',
                              return_ref=True)
    # Explicit fixed volume bounds for nuScenes: polar (r, theta, z) vs
    # Cartesian (x, y, z) limits.
    if model == 'polar':
        test_dataset = spherical_dataset(test_pt_dataset,
                                         grid_size=grid_size,
                                         ignore_label=0,
                                         fixed_volume_space=True,
                                         return_test=True,
                                         max_volume_space=[50, np.pi, 3],
                                         min_volume_space=[0, -np.pi, -5])
        val_dataset = spherical_dataset(val_pt_dataset,
                                        grid_size=grid_size,
                                        ignore_label=0,
                                        fixed_volume_space=True,
                                        max_volume_space=[50, np.pi, 3],
                                        min_volume_space=[0, -np.pi, -5])
    elif model == 'traditional':
        test_dataset = voxel_dataset(test_pt_dataset,
                                     grid_size=grid_size,
                                     ignore_label=0,
                                     fixed_volume_space=True,
                                     return_test=True,
                                     max_volume_space=[50, 50, 3],
                                     min_volume_space=[-50, -50, -5])
        val_dataset = voxel_dataset(val_pt_dataset,
                                    grid_size=grid_size,
                                    ignore_label=0,
                                    fixed_volume_space=True,
                                    max_volume_space=[50, 50, 3],
                                    min_volume_space=[-50, -50, -5])
    test_dataset_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=test_batch_size,
        collate_fn=collate_fn_BEV_test,
        shuffle=False,
        num_workers=4)
    val_dataset_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=test_batch_size,
        collate_fn=collate_fn_BEV,
        shuffle=False,
        num_workers=4)

    # validation
    print('*' * 80)
    print('Test network performance on validation split')
    print('*' * 80)
    pbar = tqdm(total=len(val_dataset_loader))
    my_model.eval()
    hist_list = []
    time_list = []
    with torch.no_grad():
        for i_iter_val, (val_vox_fea, val_vox_label, val_grid, val_pt_labs,
                         val_pt_fea) in enumerate(val_dataset_loader):
            # Dense voxel visibility features, used only when visibilty=True.
            val_vox_fea_ten = val_vox_fea.to(pytorch_device)
            # NOTE(review): SemKITTI2train is reused here on nuScenes
            # labels — presumably the same shift-to-training-id mapping;
            # confirm against its definition.
            val_vox_label = SemKITTI2train(val_vox_label)
            val_pt_labs = SemKITTI2train(val_pt_labs)
            val_pt_fea_ten = [
                torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device)
                for i in val_pt_fea
            ]
            # Only the first two grid axes are fed to the network.
            val_grid_ten = [
                torch.from_numpy(i[:, :2]).to(pytorch_device) for i in val_grid
            ]
            # NOTE(review): val_label_tensor is built but never used here.
            val_label_tensor = val_vox_label.type(
                torch.LongTensor).to(pytorch_device)

            # Synchronize so timing reflects actual GPU work.
            torch.cuda.synchronize()
            start_time = time.time()
            if visibilty:
                predict_labels = my_model(val_pt_fea_ten, val_grid_ten,
                                          val_vox_fea_ten)
            else:
                predict_labels = my_model(val_pt_fea_ten, val_grid_ten)
            torch.cuda.synchronize()
            time_list.append(time.time() - start_time)

            predict_labels = torch.argmax(predict_labels, dim=1)
            predict_labels = predict_labels.cpu().detach().numpy()
            # Per-sample confusion histogram at the labeled points.
            for count, i_val_grid in enumerate(val_grid):
                hist_list.append(
                    fast_hist_crop(
                        predict_labels[count, val_grid[count][:, 0],
                                       val_grid[count][:, 1],
                                       val_grid[count][:, 2]],
                        val_pt_labs[count], unique_label))
            pbar.update(1)
    iou = per_class_iu(sum(hist_list))
    print('Validation per class iou: ')
    for class_name, class_iou in zip(unique_label_str, iou):
        print('%s : %.2f%%' % (class_name, class_iou * 100))
    val_miou = np.nanmean(iou) * 100
    del val_vox_label, val_grid, val_pt_fea, val_grid_ten
    pbar.close()
    print('Current val miou is %.3f ' % val_miou)
    print('Inference time per %d is %.4f seconds\n' %
          (test_batch_size, np.mean(time_list)))

    # test
    print('*' * 80)
    print('Generate predictions for test split')
    print('*' * 80)
    pbar = tqdm(total=len(test_dataset_loader))
    with torch.no_grad():
        for i_iter_test, (test_vox_fea, _, test_grid, _, test_pt_fea,
                          test_index) in enumerate(test_dataset_loader):
            # predict
            test_vox_fea_ten = test_vox_fea.to(pytorch_device)
            test_pt_fea_ten = [
                torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device)
                for i in test_pt_fea
            ]
            test_grid_ten = [
                torch.from_numpy(i[:, :2]).to(pytorch_device)
                for i in test_grid
            ]

            if visibilty:
                predict_labels = my_model(test_pt_fea_ten, test_grid_ten,
                                          test_vox_fea_ten)
            else:
                predict_labels = my_model(test_pt_fea_ten, test_grid_ten)
            predict_labels = torch.argmax(predict_labels, 1).type(torch.uint8)
            predict_labels = predict_labels.cpu().detach().numpy()
            # write to label file
            for count, i_test_grid in enumerate(test_grid):
                # Per-point voxel label, mapped back to dataset ids and
                # written as a column vector.
                test_pred_label = predict_labels[count, test_grid[count][:, 0],
                                                 test_grid[count][:, 1],
                                                 test_grid[count][:, 2]]
                test_pred_label = train2SemKITTI(test_pred_label)
                test_pred_label = np.expand_dims(test_pred_label, axis=1)
                # Output file is keyed by the sample's lidarseg token.
                label_token = test_pt_dataset.train_token_list[
                    test_index[count]]
                new_save_dir = os.path.join(output_path,
                                            label_token + '_lidarseg.bin')

                # Race-tolerant directory creation.
                if not os.path.exists(os.path.dirname(new_save_dir)):
                    try:
                        os.makedirs(os.path.dirname(new_save_dir))
                    except OSError as exc:
                        if exc.errno != errno.EEXIST:
                            raise
                # Raw uint32 stream on disk.
                test_pred_label = test_pred_label.astype(np.uint32)
                test_pred_label.tofile(new_save_dir)
            pbar.update(1)
    del test_grid, test_pt_fea, test_index
    pbar.close()
    print('Predicted test labels are saved in %s.' % output_path)
Ejemplo n.º 3
0
def main(args):
    """Validate and test a panoptic BEV segmentation model on SemanticKITTI.

    Runs validation reporting panoptic metrics (PQ/SQ/RQ) and semantic
    mIoU, then writes panoptic test-split predictions to
    ``<output_path>/sequences/.../predictions/*.label``.

    Args:
        args: nested config dict with ``dataset`` (path, grid_size,
            output_path, instance_pkl_path) and ``model`` (test_batch_size,
            pretrained_model, visibility, polar, post_proc) sections.
    """
    data_path = args['dataset']['path']
    test_batch_size = args['model']['test_batch_size']
    pretrained_model = args['model']['pretrained_model']
    output_path = args['dataset']['output_path']
    # Grid height doubles as BEV n_height and feature-compression size.
    compression_model = args['dataset']['grid_size'][2]
    grid_size = args['dataset']['grid_size']
    visibility = args['model']['visibility']
    pytorch_device = torch.device('cuda:0')
    if args['model']['polar']:
        fea_dim = 9
        circular_padding = True
    else:
        fea_dim = 7
        circular_padding = False

    # Prepare mIoU bookkeeping: ids without the ignored class, 0-based,
    # plus readable names.
    unique_label=np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str=[SemKITTI_label_name[x] for x in unique_label+1]

    # prepare model
    my_BEV_model=BEV_Unet(n_class=len(unique_label), n_height = compression_model, input_batch_norm = True, dropout = 0.5, circular_padding = circular_padding, use_vis_fea=visibility)
    my_model = ptBEVnet(my_BEV_model, pt_model = 'pointnet', grid_size =  grid_size, fea_dim = fea_dim, max_pt_per_encode = 256,
                            out_pt_fea_dim = 512, kernal_size = 1, pt_selection = 'random', fea_compre = compression_model)
    # Silently runs with random weights if no checkpoint is found.
    if os.path.exists(pretrained_model):
        my_model.load_state_dict(torch.load(pretrained_model))
    pytorch_total_params = sum(p.numel() for p in my_model.parameters())
    print('params: ',pytorch_total_params)
    my_model.to(pytorch_device)
    # Eval mode for the whole function (validation and test).
    my_model.eval()

    # prepare dataset
    test_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'test', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])
    val_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'val', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])
    if args['model']['polar']:
        test_dataset=spherical_dataset(test_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0, return_test= True)
        val_dataset=spherical_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)
    else:
        test_dataset=voxel_dataset(test_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0, return_test= True)
        val_dataset=voxel_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)
    test_dataset_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                                    batch_size = test_batch_size,
                                                    collate_fn = collate_fn_BEV_test,
                                                    shuffle = False,
                                                    num_workers = 4)
    val_dataset_loader = torch.utils.data.DataLoader(dataset = val_dataset,
                                                    batch_size = test_batch_size,
                                                    collate_fn = collate_fn_BEV,
                                                    shuffle = False,
                                                    num_workers = 4)

    # validation
    print('*'*80)
    print('Test network performance on validation split')
    print('*'*80)
    pbar = tqdm(total=len(val_dataset_loader))
    time_list = []
    pp_time_list = []
    # Panoptic evaluator over all classes + ignore (id 0); instances with
    # fewer than 50 points are not counted.
    evaluator = PanopticEval(len(unique_label)+1, None, [0], min_points=50)
    with torch.no_grad():
        for i_iter_val,(val_vox_fea,val_vox_label,val_gt_center,val_gt_offset,val_grid,val_pt_labels,val_pt_ints,val_pt_fea) in enumerate(val_dataset_loader):
            val_vox_fea_ten = val_vox_fea.to(pytorch_device)
            val_vox_label = SemKITTI2train(val_vox_label)
            val_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in val_pt_fea]
            # Only the first two grid axes are fed to the network.
            val_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in val_grid]
            # NOTE(review): val_label_tensor / val_gt_*_tensor are moved to
            # GPU but only used in the del below — no loss is computed here.
            val_label_tensor=val_vox_label.type(torch.LongTensor).to(pytorch_device)
            val_gt_center_tensor = val_gt_center.to(pytorch_device)
            val_gt_offset_tensor = val_gt_offset.to(pytorch_device)

            # Synchronize so timing reflects actual GPU work.
            torch.cuda.synchronize()
            start_time = time.time()
            if visibility:            
                predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten, val_vox_fea_ten)
            else:
                predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten)
            torch.cuda.synchronize()
            time_list.append(time.time()-start_time)

            for count,i_val_grid in enumerate(val_grid):
                # get foreground_mask: mark only voxels that contain points
                for_mask = torch.zeros(1,grid_size[0],grid_size[1],grid_size[2],dtype=torch.bool).to(pytorch_device)
                for_mask[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]] = True
                # Post-processing (instance grouping) is timed separately
                # from the forward pass.
                torch.cuda.synchronize()
                start_time = time.time()
                panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),val_pt_dataset.thing_list,\
                                                                          threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\
                                                                          top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)
                torch.cuda.synchronize()
                pp_time_list.append(time.time()-start_time)
                panoptic_labels = panoptic_labels.cpu().detach().numpy().astype(np.uint32)
                # Per-point panoptic id looked up at each point's voxel.
                panoptic = panoptic_labels[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]]

                # Low 16 bits carry the semantic class; the full value is
                # the panoptic (semantic + instance) id.
                evaluator.addBatch(panoptic & 0xFFFF,panoptic,np.squeeze(val_pt_labels[count]),np.squeeze(val_pt_ints[count]))
            # Free GPU tensors eagerly before the next batch.
            del val_vox_label,val_pt_fea_ten,val_label_tensor,val_grid_ten,val_gt_center,val_gt_center_tensor,val_gt_offset,val_gt_offset_tensor,predict_labels,center,offset,panoptic_labels,center_points
            pbar.update(1)
    
    class_PQ, class_SQ, class_RQ, class_all_PQ, class_all_SQ, class_all_RQ = evaluator.getPQ()
    miou,ious = evaluator.getSemIoU()
    print('Validation per class PQ, SQ, RQ and IoU: ')
    # Index [1:] skips the ignored class in the per-class reports.
    for class_name, class_pq, class_sq, class_rq, class_iou in zip(unique_label_str,class_all_PQ[1:],class_all_SQ[1:],class_all_RQ[1:],ious[1:]):
        print('%15s : %6.2f%%  %6.2f%%  %6.2f%%  %6.2f%%' % (class_name, class_pq*100, class_sq*100, class_rq*100, class_iou*100))
    pbar.close()
    print('Current val PQ is %.3f' %
        (class_PQ*100))               
    print('Current val miou is %.3f'%
        (miou*100))
    print('Inference time per %d is %.4f seconds\n, postprocessing time is %.4f seconds per scan' %
        (test_batch_size,np.mean(time_list),np.mean(pp_time_list)))
    
    # test
    print('*'*80)
    print('Generate predictions for test split')
    print('*'*80)
    pbar = tqdm(total=len(test_dataset_loader))
    with torch.no_grad():
        for i_iter_test,(test_vox_fea,_,_,_,test_grid,_,_,test_pt_fea,test_index) in enumerate(test_dataset_loader):
            # predict
            test_vox_fea_ten = test_vox_fea.to(pytorch_device)
            test_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in test_pt_fea]
            test_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in test_grid]

            if visibility:
                predict_labels,center,offset = my_model(test_pt_fea_ten,test_grid_ten,test_vox_fea_ten)
            else:
                predict_labels,center,offset = my_model(test_pt_fea_ten,test_grid_ten)
            # write to label file
            for count,i_test_grid in enumerate(test_grid):
                # get foreground_mask: mark only voxels that contain points
                for_mask = torch.zeros(1,grid_size[0],grid_size[1],grid_size[2],dtype=torch.bool).to(pytorch_device)
                for_mask[0,test_grid[count][:,0],test_grid[count][:,1],test_grid[count][:,2]] = True
                # post processing: group thing-class voxels into instances
                panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),test_pt_dataset.thing_list,\
                                                                                          threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\
                                                                                          top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)
                panoptic_labels = panoptic_labels.cpu().detach().numpy().astype(np.uint32)
                panoptic = panoptic_labels[0,test_grid[count][:,0],test_grid[count][:,1],test_grid[count][:,2]]
                # Mirror the input path: .../sequences/XX/velodyne/YY.bin ->
                # <output>/sequences/XX/predictions/YY.label
                save_dir = test_pt_dataset.im_idx[test_index[count]]
                _,dir2 = save_dir.split('/sequences/',1)
                new_save_dir = output_path + '/sequences/' +dir2.replace('velodyne','predictions')[:-3]+'label'
                # Race-tolerant directory creation.
                if not os.path.exists(os.path.dirname(new_save_dir)):
                    try:
                        os.makedirs(os.path.dirname(new_save_dir))
                    except OSError as exc:
                        if exc.errno != errno.EEXIST:
                            raise
                # Raw uint32 panoptic labels, SemanticKITTI .label format.
                panoptic.tofile(new_save_dir)
            del test_pt_fea_ten,test_grid_ten,test_pt_fea,predict_labels,center,offset
            pbar.update(1)
    pbar.close()
    print('Predicted test labels are saved in %s. Need to be shifted to original label format before submitting to the Competition website.' % output_path)
    print('Remapping script can be found in semantic-kitti-api.')
Ejemplo n.º 4
0
def main(args):
    """Train a BEV LiDAR segmentation model on SemanticKITTI.

    Runs an endless epoch loop (no stopping criterion — the process must
    be interrupted externally). Every ``check_iter`` iterations it
    validates, prints per-class IoU, and checkpoints the model whenever
    validation mIoU improves.

    Args:
        args: parsed CLI namespace with ``data_dir``, ``train_batch_size``,
            ``val_batch_size``, ``check_iter``, ``model_save_path``,
            ``grid_size`` (3 elements) and ``model``
            ('polar' or 'traditional').
    """
    data_path = args.data_dir
    train_batch_size = args.train_batch_size
    val_batch_size = args.val_batch_size
    check_iter = args.check_iter
    model_save_path = args.model_save_path
    # Grid height doubles as BEV n_height and feature-compression size.
    compression_model = args.grid_size[2]
    grid_size = args.grid_size
    pytorch_device = torch.device('cuda:0')
    model = args.model
    if model == 'polar':
        fea_dim = 9
        circular_padding = True
    elif model == 'traditional':
        fea_dim = 7
        circular_padding = False
    # NOTE(review): other args.model values leave fea_dim unbound
    # (NameError later).

    # Prepare mIoU bookkeeping: ids without the ignored class, 0-based,
    # plus readable names.
    unique_label = np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str = [SemKITTI_label_name[x] for x in unique_label + 1]

    #prepare model
    my_BEV_model = BEV_Unet(n_class=len(unique_label),
                            n_height=compression_model,
                            input_batch_norm=True,
                            dropout=0.5,
                            circular_padding=circular_padding)
    my_model = ptBEVnet(my_BEV_model,
                        pt_model='pointnet',
                        grid_size=grid_size,
                        fea_dim=fea_dim,
                        max_pt_per_encode=256,
                        out_pt_fea_dim=512,
                        kernal_size=1,
                        pt_selection='random',
                        fea_compre=compression_model)
    # Resume from checkpoint if one exists at the save path.
    if os.path.exists(model_save_path):
        my_model.load_state_dict(torch.load(model_save_path))
    my_model.to(pytorch_device)

    # Adam with default hyperparameters; 255 is the ignore index produced
    # by the SemKITTI2train label mapping.
    optimizer = optim.Adam(my_model.parameters())
    loss_fun = torch.nn.CrossEntropyLoss(ignore_index=255)

    #prepare dataset
    train_pt_dataset = SemKITTI(data_path + '/sequences/',
                                imageset='train',
                                return_ref=True)
    val_pt_dataset = SemKITTI(data_path + '/sequences/',
                              imageset='val',
                              return_ref=True)
    # Training datasets use flip/rotate augmentation; validation does not.
    if model == 'polar':
        train_dataset = spherical_dataset(train_pt_dataset,
                                          grid_size=grid_size,
                                          flip_aug=True,
                                          ignore_label=0,
                                          rotate_aug=True,
                                          fixed_volume_space=True)
        val_dataset = spherical_dataset(val_pt_dataset,
                                        grid_size=grid_size,
                                        ignore_label=0,
                                        fixed_volume_space=True)
    elif model == 'traditional':
        train_dataset = voxel_dataset(train_pt_dataset,
                                      grid_size=grid_size,
                                      flip_aug=True,
                                      ignore_label=0,
                                      rotate_aug=True,
                                      fixed_volume_space=True)
        val_dataset = voxel_dataset(val_pt_dataset,
                                    grid_size=grid_size,
                                    ignore_label=0,
                                    fixed_volume_space=True)
    train_dataset_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        collate_fn=collate_fn_BEV,
        shuffle=True,
        num_workers=4)
    val_dataset_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                     batch_size=val_batch_size,
                                                     collate_fn=collate_fn_BEV,
                                                     shuffle=False,
                                                     num_workers=4)

    # training
    epoch = 0
    best_val_miou = 0
    start_training = False
    my_model.train()
    global_iter = 0
    exce_counter = 0

    # Endless epoch loop — run until killed; best model is checkpointed
    # along the way.
    while True:
        loss_list = []
        pbar = tqdm(total=len(train_dataset_loader))
        for i_iter, (_, train_vox_label, train_grid, _,
                     train_pt_fea) in enumerate(train_dataset_loader):
            # Periodic validation (also fires at global_iter == 0, i.e.
            # before any training step).
            if global_iter % check_iter == 0:
                my_model.eval()
                hist_list = []
                val_loss_list = []
                with torch.no_grad():
                    for i_iter_val, (
                            _, val_vox_label, val_grid, val_pt_labs,
                            val_pt_fea) in enumerate(val_dataset_loader):
                        val_vox_label = SemKITTI2train(val_vox_label)
                        val_pt_labs = SemKITTI2train(val_pt_labs)
                        val_pt_fea_ten = [
                            torch.from_numpy(i).type(
                                torch.FloatTensor).to(pytorch_device)
                            for i in val_pt_fea
                        ]
                        # Only the first two grid axes go to the network.
                        val_grid_ten = [
                            torch.from_numpy(i[:, :2]).to(pytorch_device)
                            for i in val_grid
                        ]
                        val_label_tensor = val_vox_label.type(
                            torch.LongTensor).to(pytorch_device)

                        predict_labels = my_model(val_pt_fea_ten, val_grid_ten)
                        # Combined Lovasz-softmax + cross-entropy loss.
                        # NOTE(review): softmax() without dim= is deprecated
                        # and emits a warning on recent PyTorch.
                        loss = lovasz_softmax(torch.nn.functional.softmax(
                            predict_labels).detach(),
                                              val_label_tensor,
                                              ignore=255) + loss_fun(
                                                  predict_labels.detach(),
                                                  val_label_tensor)
                        predict_labels = torch.argmax(predict_labels, dim=1)
                        predict_labels = predict_labels.cpu().detach().numpy()
                        # Per-sample confusion histogram at labeled points.
                        for count, i_val_grid in enumerate(val_grid):
                            hist_list.append(
                                fast_hist_crop(
                                    predict_labels[count, val_grid[count][:,
                                                                          0],
                                                   val_grid[count][:, 1],
                                                   val_grid[count][:, 2]],
                                    val_pt_labs[count], unique_label))
                        val_loss_list.append(loss.detach().cpu().numpy())
                my_model.train()
                iou = per_class_iu(sum(hist_list))
                print('Validation per class iou: ')
                for class_name, class_iou in zip(unique_label_str, iou):
                    print('%s : %.2f%%' % (class_name, class_iou * 100))
                val_miou = np.nanmean(iou) * 100
                del val_vox_label, val_grid, val_pt_fea, val_grid_ten

                # save model if performance is improved
                if best_val_miou < val_miou:
                    best_val_miou = val_miou
                    torch.save(my_model.state_dict(), model_save_path)

                print(
                    'Current val miou is %.3f while the best val miou is %.3f'
                    % (val_miou, best_val_miou))
                print('Current val loss is %.3f' % (np.mean(val_loss_list)))
                if start_training:
                    print('epoch %d iter %5d, loss: %.3f\n' %
                          (epoch, i_iter, np.mean(loss_list)))
                print('%d exceptions encountered during last training\n' %
                      exce_counter)
                exce_counter = 0
                loss_list = []

            # training
            try:
                train_vox_label = SemKITTI2train(train_vox_label)
                train_pt_fea_ten = [
                    torch.from_numpy(i).type(
                        torch.FloatTensor).to(pytorch_device)
                    for i in train_pt_fea
                ]
                train_grid_ten = [
                    torch.from_numpy(i[:, :2]).to(pytorch_device)
                    for i in train_grid
                ]
                # NOTE(review): train_vox_ten is built but never used.
                train_vox_ten = [
                    torch.from_numpy(i).to(pytorch_device) for i in train_grid
                ]
                point_label_tensor = train_vox_label.type(
                    torch.LongTensor).to(pytorch_device)

                # forward + backward + optimize
                outputs = my_model(train_pt_fea_ten, train_grid_ten)
                # NOTE(review): softmax() without dim= — see validation note.
                loss = lovasz_softmax(torch.nn.functional.softmax(outputs),
                                      point_label_tensor,
                                      ignore=255) + loss_fun(
                                          outputs, point_label_tensor)
                loss.backward()
                optimizer.step()
                loss_list.append(loss.item())
            except Exception:
                # NOTE(review): deliberately swallows ANY training-step
                # failure (e.g. CUDA OOM) and only counts it; real bugs are
                # hidden too — consider logging the exception.
                exce_counter += 1

            # Zero gradients after every iteration (success or failure) so
            # the next backward pass starts from a clean slate.
            optimizer.zero_grad()
            pbar.update(1)
            start_training = True
            global_iter += 1
        pbar.close()
        epoch += 1
Ejemplo n.º 5
0
def main(args) -> None:
    """Train the panoptic BEV segmentation network on SemanticKITTI.

    Builds a BEV_Unet head wrapped in ptBEVnet, optionally resumes from a
    saved checkpoint (or loads pretrained weights), then loops over epochs.
    Every ``check_iter`` iterations it runs panoptic validation (PQ/SQ/RQ and
    semantic IoU via ``PanopticEval``) and saves the model whenever the
    validation PQ improves.

    Args:
        args: nested configuration mapping with ``'dataset'`` and ``'model'``
            sections (paths, batch sizes, grid size, loss weights, SAP and
            post-processing settings).
    """
    # ---- unpack configuration ---------------------------------------------
    data_path = args['dataset']['path']
    train_batch_size = args['model']['train_batch_size']
    val_batch_size = args['model']['val_batch_size']
    check_iter = args['model']['check_iter']  # validate every N iterations
    model_save_path = args['model']['model_save_path']
    pretrained_model = args['model']['pretrained_model']
    # grid_size[2] (height bins) doubles as the feature-compression width
    compression_model = args['dataset']['grid_size'][2]
    grid_size = args['dataset']['grid_size']
    visibility = args['model']['visibility']  # feed voxel visibility features too
    pytorch_device = torch.device('cuda:0')
    # polar (cylindrical) input uses 9 point features and circular padding;
    # the Cartesian variant uses 7 features without it
    if args['model']['polar']:
        fea_dim = 9
        circular_padding = True
    else:
        fea_dim = 7
        circular_padding = False

    # prepare mIoU bookkeeping: drop the first (ignored) label id and shift
    # the remaining class ids down by one to make them 0-based
    unique_label = np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str = [SemKITTI_label_name[x] for x in unique_label + 1]

    # prepare model: 2-D BEV U-Net on top of a PointNet-style grid encoder
    my_BEV_model = BEV_Unet(n_class=len(unique_label),
                            n_height=compression_model,
                            input_batch_norm=True,
                            dropout=0.5,
                            circular_padding=circular_padding,
                            use_vis_fea=visibility)
    my_model = ptBEVnet(my_BEV_model,
                        pt_model='pointnet',
                        grid_size=grid_size,
                        fea_dim=fea_dim,
                        max_pt_per_encode=256,
                        out_pt_fea_dim=512,
                        kernal_size=1,
                        pt_selection='random',
                        fea_compre=compression_model)
    # resume from a previous checkpoint if one exists; otherwise fall back to
    # optional pretrained weights
    if os.path.exists(model_save_path):
        my_model = load_pretrained_model(my_model, torch.load(model_save_path))
    elif os.path.exists(pretrained_model):
        my_model = load_pretrained_model(my_model,
                                         torch.load(pretrained_model))
    my_model.to(pytorch_device)

    optimizer = optim.Adam(my_model.parameters())
    # combined semantic + center-heatmap + offset loss for panoptic training
    loss_fn = panoptic_loss(center_loss_weight = args['model']['center_loss_weight'], offset_loss_weight = args['model']['offset_loss_weight'],\
                            center_loss = args['model']['center_loss'], offset_loss=args['model']['offset_loss'])

    # prepare datasets (augmentation enabled only on the training split)
    train_pt_dataset = SemKITTI(
        data_path + '/sequences/',
        imageset='train',
        return_ref=True,
        instance_pkl_path=args['dataset']['instance_pkl_path'])
    val_pt_dataset = SemKITTI(
        data_path + '/sequences/',
        imageset='val',
        return_ref=True,
        instance_pkl_path=args['dataset']['instance_pkl_path'])
    if args['model']['polar']:
        train_dataset = spherical_dataset(train_pt_dataset,
                                          args['dataset'],
                                          grid_size=grid_size,
                                          ignore_label=0,
                                          use_aug=True)
        val_dataset = spherical_dataset(val_pt_dataset,
                                        args['dataset'],
                                        grid_size=grid_size,
                                        ignore_label=0)
    else:
        train_dataset = voxel_dataset(train_pt_dataset,
                                      args['dataset'],
                                      grid_size=grid_size,
                                      ignore_label=0,
                                      use_aug=True)
        val_dataset = voxel_dataset(val_pt_dataset,
                                    args['dataset'],
                                    grid_size=grid_size,
                                    ignore_label=0)
    train_dataset_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        collate_fn=collate_fn_BEV,
        shuffle=True,
        num_workers=4)
    val_dataset_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                     batch_size=val_batch_size,
                                                     collate_fn=collate_fn_BEV,
                                                     shuffle=False,
                                                     num_workers=4)

    # ---- training state ----------------------------------------------------
    epoch = 0
    best_val_PQ = 0
    start_training = False  # becomes True after the first training iteration
    my_model.train()
    global_iter = 0
    exce_counter = 0  # exceptions swallowed since the last validation report
    # class 0 is ignored; instances with fewer than 50 points are not counted
    evaluator = PanopticEval(len(unique_label) + 1, None, [0], min_points=50)

    while epoch < args['model']['max_epoch']:
        pbar = tqdm(total=len(train_dataset_loader))
        for i_iter, (train_vox_fea, train_label_tensor, train_gt_center,
                     train_gt_offset, train_grid, _, _,
                     train_pt_fea) in enumerate(train_dataset_loader):
            # ---- validation (every check_iter iterations, incl. iter 0) ----
            if global_iter % check_iter == 0:
                my_model.eval()
                evaluator.reset()
                with torch.no_grad():
                    for i_iter_val, (
                            val_vox_fea, val_vox_label, val_gt_center,
                            val_gt_offset, val_grid, val_pt_labels,
                            val_pt_ints,
                            val_pt_fea) in enumerate(val_dataset_loader):
                        val_vox_fea_ten = val_vox_fea.to(pytorch_device)
                        val_vox_label = SemKITTI2train(val_vox_label)
                        # per-sample point features / grid indices arrive as a
                        # list of numpy arrays (variable point counts)
                        val_pt_fea_ten = [
                            torch.from_numpy(i).type(
                                torch.FloatTensor).to(pytorch_device)
                            for i in val_pt_fea
                        ]
                        val_grid_ten = [
                            torch.from_numpy(i[:, :2]).to(pytorch_device)
                            for i in val_grid
                        ]
                        val_label_tensor = val_vox_label.type(
                            torch.LongTensor).to(pytorch_device)
                        val_gt_center_tensor = val_gt_center.to(pytorch_device)
                        val_gt_offset_tensor = val_gt_offset.to(pytorch_device)

                        if visibility:
                            predict_labels, center, offset = my_model(
                                val_pt_fea_ten, val_grid_ten, val_vox_fea_ten)
                        else:
                            predict_labels, center, offset = my_model(
                                val_pt_fea_ten, val_grid_ten)

                        for count, i_val_grid in enumerate(val_grid):
                            # get foreground_mask: mark voxels actually
                            # occupied by this sample's points
                            for_mask = torch.zeros(
                                1,
                                grid_size[0],
                                grid_size[1],
                                grid_size[2],
                                dtype=torch.bool).to(pytorch_device)
                            for_mask[0, val_grid[count][:, 0],
                                     val_grid[count][:, 1],
                                     val_grid[count][:, 2]] = True
                            # post processing: fuse semantic logits with
                            # center/offset predictions into panoptic ids
                            panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),\
                                                                                      val_pt_dataset.thing_list, threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\
                                                                                      top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)
                            panoptic_labels = panoptic_labels.cpu().detach(
                            ).numpy().astype(np.int32)
                            # gather voxel-wise panoptic ids back onto this
                            # sample's original points
                            panoptic = panoptic_labels[0, val_grid[count][:,
                                                                          0],
                                                       val_grid[count][:, 1],
                                                       val_grid[count][:, 2]]
                            # NOTE(review): low 16 bits are passed as the
                            # semantic id — presumably SemanticKITTI-style
                            # packed labels; confirm against
                            # get_panoptic_segmentation's output format
                            evaluator.addBatch(
                                panoptic & 0xFFFF, panoptic,
                                np.squeeze(val_pt_labels[count]),
                                np.squeeze(val_pt_ints[count]))
                        # release per-batch GPU tensors before the next batch
                        del val_vox_label, val_pt_fea_ten, val_label_tensor, val_grid_ten, val_gt_center, val_gt_center_tensor, val_gt_offset, val_gt_offset_tensor, predict_labels, center, offset, panoptic_labels, center_points
                my_model.train()
                class_PQ, class_SQ, class_RQ, class_all_PQ, class_all_SQ, class_all_RQ = evaluator.getPQ(
                )
                miou, ious = evaluator.getSemIoU()
                # per-class report; index [1:] skips the ignored class 0
                print('Validation per class PQ, SQ, RQ and IoU: ')
                for class_name, class_pq, class_sq, class_rq, class_iou in zip(
                        unique_label_str, class_all_PQ[1:], class_all_SQ[1:],
                        class_all_RQ[1:], ious[1:]):
                    print('%15s : %6.2f%%  %6.2f%%  %6.2f%%  %6.2f%%' %
                          (class_name, class_pq * 100, class_sq * 100,
                           class_rq * 100, class_iou * 100))
                # save model if performance is improved
                if best_val_PQ < class_PQ:
                    best_val_PQ = class_PQ
                    torch.save(my_model.state_dict(), model_save_path)
                print('Current val PQ is %.3f while the best val PQ is %.3f' %
                      (class_PQ * 100, best_val_PQ * 100))
                print('Current val miou is %.3f' % (miou * 100))

                if start_training:
                    # mean per-component losses accumulated since the
                    # previous validation
                    sem_l, hm_l, os_l = np.mean(
                        loss_fn.lost_dict['semantic_loss']), np.mean(
                            loss_fn.lost_dict['heatmap_loss']), np.mean(
                                loss_fn.lost_dict['offset_loss'])
                    print(
                        'epoch %d iter %5d, loss: %.3f, semantic loss: %.3f, heatmap loss: %.3f, offset loss: %.3f\n'
                        % (epoch, i_iter, sem_l + hm_l + os_l, sem_l, hm_l,
                           os_l))
                print('%d exceptions encountered during last training\n' %
                      exce_counter)
                exce_counter = 0
                loss_fn.reset_loss_dict()

            # ---- training step ---------------------------------------------
            try:
                train_vox_fea_ten = train_vox_fea.to(pytorch_device)
                train_label_tensor = SemKITTI2train(train_label_tensor)
                train_pt_fea_ten = [
                    torch.from_numpy(i).type(
                        torch.FloatTensor).to(pytorch_device)
                    for i in train_pt_fea
                ]
                train_grid_ten = [
                    torch.from_numpy(i[:, :2]).to(pytorch_device)
                    for i in train_grid
                ]
                train_label_tensor = train_label_tensor.type(
                    torch.LongTensor).to(pytorch_device)
                train_gt_center_tensor = train_gt_center.to(pytorch_device)
                train_gt_offset_tensor = train_gt_offset.to(pytorch_device)

                # SAP needs gradients w.r.t. the input point features
                if args['model']['enable_SAP'] and epoch >= args['model'][
                        'SAP']['start_epoch']:
                    for fea in train_pt_fea_ten:
                        fea.requires_grad_()

                # forward
                if visibility:
                    sem_prediction, center, offset = my_model(
                        train_pt_fea_ten, train_grid_ten, train_vox_fea_ten)
                else:
                    sem_prediction, center, offset = my_model(
                        train_pt_fea_ten, train_grid_ten)
                # loss
                loss = loss_fn(sem_prediction, center, offset,
                               train_label_tensor, train_gt_center_tensor,
                               train_gt_offset_tensor)

                # self adversarial pruning: backprop to the input features,
                # drop the top-'rate' fraction of points with the largest
                # gradient norm, then redo the forward/loss on the rest
                if args['model']['enable_SAP'] and epoch >= args['model'][
                        'SAP']['start_epoch']:
                    loss.backward()
                    for i, fea in enumerate(train_pt_fea_ten):
                        fea_grad = torch.norm(fea.grad, dim=1)
                        top_k_grad, _ = torch.topk(
                            fea_grad,
                            int(args['model']['SAP']['rate'] *
                                fea_grad.shape[0]))
                        # delete high influential points
                        train_pt_fea_ten[i] = train_pt_fea_ten[i][
                            fea_grad < top_k_grad[-1]]
                        train_grid_ten[i] = train_grid_ten[i][
                            fea_grad < top_k_grad[-1]]
                    optimizer.zero_grad()

                    # second pass
                    # forward
                    if visibility:
                        sem_prediction, center, offset = my_model(
                            train_pt_fea_ten, train_grid_ten,
                            train_vox_fea_ten)
                    else:
                        sem_prediction, center, offset = my_model(
                            train_pt_fea_ten, train_grid_ten)
                    # loss
                    loss = loss_fn(sem_prediction, center, offset,
                                   train_label_tensor, train_gt_center_tensor,
                                   train_gt_offset_tensor)

                # backward + optimize
                loss.backward()
                optimizer.step()
            except Exception as error:
                # best-effort training: count exceptions (printing only the
                # first per validation interval) instead of aborting the run
                if exce_counter == 0:
                    print(error)
                exce_counter += 1

            # zero the parameter gradients (also clears any partial grads
            # left behind when the try-block failed mid-step)
            optimizer.zero_grad()
            pbar.update(1)
            start_training = True
            global_iter += 1
        pbar.close()
        epoch += 1