def main(args):
    """Evaluate a trained BEV segmentation network on the SemanticKITTI
    validation split, then write test-split predictions as .label files.

    Fields read from ``args``: data_dir, test_batch_size, model_save_path,
    test_output_path, grid_size (3 elements; the last doubles as the BEV
    height compression) and model ('polar' or 'traditional').
    """
    data_path = args.data_dir
    test_batch_size = args.test_batch_size
    model_save_path = args.model_save_path
    output_path = args.test_output_path
    # the last grid dimension is reused as the compressed-height of the BEV net
    compression_model = args.grid_size[2]
    grid_size = args.grid_size
    pytorch_device = torch.device('cuda:0')
    model = args.model
    # NOTE(review): an unrecognized args.model leaves fea_dim and
    # circular_padding unbound -> NameError below; confirm argparse
    # restricts the choices upstream.
    if model == 'polar':
        fea_dim = 9
        circular_padding = True
    elif model == 'traditional':
        fea_dim = 7
        circular_padding = False

    # prepare miou fun: drop label 0 ('unlabeled') and shift to 0-based ids
    unique_label = np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str = [SemKITTI_label_name[x] for x in unique_label + 1]

    # prepare model
    my_BEV_model = BEV_Unet(n_class=len(unique_label),
                            n_height=compression_model,
                            input_batch_norm=True,
                            dropout=0.5,
                            circular_padding=circular_padding)
    my_model = ptBEVnet(my_BEV_model,
                        pt_model='pointnet',
                        grid_size=grid_size,
                        fea_dim=fea_dim,
                        max_pt_per_encode=256,
                        out_pt_fea_dim=512,
                        kernal_size=1,
                        pt_selection='random',
                        fea_compre=compression_model)
    # silently runs from random weights when the checkpoint file is missing
    if os.path.exists(model_save_path):
        my_model.load_state_dict(torch.load(model_save_path))
    my_model.to(pytorch_device)

    # prepare dataset
    test_pt_dataset = SemKITTI(data_path + '/sequences/',
                               imageset='test',
                               return_ref=True)
    val_pt_dataset = SemKITTI(data_path + '/sequences/',
                              imageset='val',
                              return_ref=True)
    if model == 'polar':
        test_dataset = spherical_dataset(test_pt_dataset,
                                         grid_size=grid_size,
                                         ignore_label=0,
                                         fixed_volume_space=True,
                                         return_test=True)
        val_dataset = spherical_dataset(val_pt_dataset,
                                        grid_size=grid_size,
                                        ignore_label=0,
                                        fixed_volume_space=True)
    elif model == 'traditional':
        test_dataset = voxel_dataset(test_pt_dataset,
                                     grid_size=grid_size,
                                     ignore_label=0,
                                     fixed_volume_space=True,
                                     return_test=True)
        val_dataset = voxel_dataset(val_pt_dataset,
                                    grid_size=grid_size,
                                    ignore_label=0,
                                    fixed_volume_space=True)
    test_dataset_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=test_batch_size,
        collate_fn=collate_fn_BEV_test,  # test collate also yields sample index
        shuffle=False,
        num_workers=4)
    val_dataset_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=test_batch_size,
        collate_fn=collate_fn_BEV,
        shuffle=False,
        num_workers=4)

    # validation
    print('*' * 80)
    print('Test network performance on validation split')
    print('*' * 80)
    pbar = tqdm(total=len(val_dataset_loader))
    my_model.eval()
    hist_list = []  # per-scan confusion histograms
    time_list = []  # per-batch forward times (seconds)
    with torch.no_grad():
        for i_iter_val, (_, val_vox_label, val_grid, val_pt_labs,
                         val_pt_fea) in enumerate(val_dataset_loader):
            val_vox_label = SemKITTI2train(val_vox_label)
            val_pt_labs = SemKITTI2train(val_pt_labs)
            val_pt_fea_ten = [
                torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device)
                for i in val_pt_fea
            ]
            # only the first two grid coordinates (BEV cell) feed the network
            val_grid_ten = [
                torch.from_numpy(i[:, :2]).to(pytorch_device)
                for i in val_grid
            ]
            val_label_tensor = val_vox_label.type(
                torch.LongTensor).to(pytorch_device)
            # synchronize so wall-clock timing covers only the forward pass
            torch.cuda.synchronize()
            start_time = time.time()
            predict_labels = my_model(val_pt_fea_ten, val_grid_ten)
            torch.cuda.synchronize()
            time_list.append(time.time() - start_time)
            predict_labels = torch.argmax(predict_labels, dim=1)
            predict_labels = predict_labels.cpu().detach().numpy()
            # gather per-point predictions from the voxel grid and accumulate
            # one confusion histogram per scan
            for count, i_val_grid in enumerate(val_grid):
                hist_list.append(
                    fast_hist_crop(
                        predict_labels[count, val_grid[count][:, 0],
                                       val_grid[count][:, 1],
                                       val_grid[count][:, 2]],
                        val_pt_labs[count], unique_label))
            pbar.update(1)
    iou = per_class_iu(sum(hist_list))
    print('Validation per class iou: ')
    for class_name, class_iou in zip(unique_label_str, iou):
        print('%s : %.2f%%' % (class_name, class_iou * 100))
    val_miou = np.nanmean(iou) * 100
    del val_vox_label, val_grid, val_pt_fea, val_grid_ten
    pbar.close()
    print('Current val miou is %.3f ' % val_miou)
    print('Inference time per %d is %.4f seconds\n' %
          (test_batch_size, np.mean(time_list)))

    # test
    print('*' * 80)
    print('Generate predictions for test split')
    print('*' * 80)
    pbar = tqdm(total=len(test_dataset_loader))
    with torch.no_grad():
        for i_iter_test, (_, _, test_grid, _, test_pt_fea,
                          test_index) in enumerate(test_dataset_loader):
            # predict
            test_pt_fea_ten = [
                torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device)
                for i in test_pt_fea
            ]
            test_grid_ten = [
                torch.from_numpy(i[:, :2]).to(pytorch_device)
                for i in test_grid
            ]
            predict_labels = my_model(test_pt_fea_ten, test_grid_ten)
            predict_labels = torch.argmax(predict_labels, 1).type(torch.uint8)
            predict_labels = predict_labels.cpu().detach().numpy()
            # write to label file
            for count, i_test_grid in enumerate(test_grid):
                test_pred_label = predict_labels[count, test_grid[count][:, 0],
                                                 test_grid[count][:, 1],
                                                 test_grid[count][:, 2]]
                # map 0-based training ids back to SemanticKITTI label ids
                test_pred_label = train2SemKITTI(test_pred_label)
                test_pred_label = np.expand_dims(test_pred_label, axis=1)
                save_dir = test_pt_dataset.im_idx[test_index[count]]
                # mirror input layout: .../velodyne/xxx.bin -> .../predictions/xxx.label
                _, dir2 = save_dir.split('/sequences/', 1)
                new_save_dir = output_path + '/sequences/' + dir2.replace(
                    'velodyne', 'predictions')[:-3] + 'label'
                if not os.path.exists(os.path.dirname(new_save_dir)):
                    try:
                        os.makedirs(os.path.dirname(new_save_dir))
                    except OSError as exc:
                        # tolerate a concurrent mkdir; re-raise anything else
                        if exc.errno != errno.EEXIST:
                            raise
                test_pred_label = test_pred_label.astype(np.uint32)
                test_pred_label.tofile(new_save_dir)
            pbar.update(1)
            del test_grid, test_pt_fea, test_index
    pbar.close()
    print(
        'Predicted test labels are saved in %s. Need to be shifted to original label format before submitting to the Competition website.'
        % output_path)
    print('Remapping script can be found in semantic-kitti-api.')
def main(args):
    """Evaluate a trained BEV network on the nuScenes val split and write
    test-split predictions as ``<lidar_token>_lidarseg.bin`` files.

    Same pipeline as the SemanticKITTI variant, but reads nuScenes data and
    optionally feeds a visibility feature volume to the network
    (``args.visibilty``; spelling matches the CLI flag).
    """
    data_path = args.data_dir
    test_batch_size = args.test_batch_size
    model_save_path = args.model_save_path
    output_path = args.test_output_path
    # last grid dimension doubles as the BEV height compression
    compression_model = args.grid_size[2]
    grid_size = args.grid_size
    visibilty = args.visibilty
    pytorch_device = torch.device('cuda:0')
    model = args.model
    if model == 'polar':
        fea_dim = 9
        circular_padding = True
    elif model == 'traditional':
        fea_dim = 7
        circular_padding = False

    #prepare miou fun
    # class names from the nuScenes mapping; entry 0 (noise class) is dropped,
    # indices shifted to 0-based training ids
    unique_label_str = list(
        map_name_from_segmentation_class_to_segmentation_index)[1:]
    unique_label = np.asarray([
        map_name_from_segmentation_class_to_segmentation_index[s]
        for s in unique_label_str
    ]) - 1

    # prepare model
    my_BEV_model = BEV_Unet(n_class=len(unique_label),
                            n_height=compression_model,
                            input_batch_norm=True,
                            dropout=0.5,
                            circular_padding=circular_padding,
                            use_vis_fea=visibilty)
    my_model = ptBEVnet(my_BEV_model,
                        pt_model='pointnet',
                        grid_size=grid_size,
                        fea_dim=fea_dim,
                        max_pt_per_encode=256,
                        out_pt_fea_dim=512,
                        kernal_size=1,
                        pt_selection='random',
                        fea_compre=compression_model)
    # silently runs from random weights when the checkpoint file is missing
    if os.path.exists(model_save_path):
        my_model.load_state_dict(torch.load(model_save_path))
    my_model.to(pytorch_device)

    # prepare dataset
    test_pt_dataset = Nuscenes(data_path + '/test/',
                               version='v1.0-test',
                               split='test',
                               return_ref=True)
    val_pt_dataset = Nuscenes(data_path + '/trainval/',
                              version='v1.0-trainval',
                              split='val',
                              return_ref=True)
    if model == 'polar':
        # polar grid: radius x angle x height
        test_dataset = spherical_dataset(test_pt_dataset,
                                         grid_size=grid_size,
                                         ignore_label=0,
                                         fixed_volume_space=True,
                                         return_test=True,
                                         max_volume_space=[50, np.pi, 3],
                                         min_volume_space=[0, -np.pi, -5])
        val_dataset = spherical_dataset(val_pt_dataset,
                                        grid_size=grid_size,
                                        ignore_label=0,
                                        fixed_volume_space=True,
                                        max_volume_space=[50, np.pi, 3],
                                        min_volume_space=[0, -np.pi, -5])
    elif model == 'traditional':
        # Cartesian grid: x, y in [-50, 50] m, z in [-5, 3] m
        test_dataset = voxel_dataset(test_pt_dataset,
                                     grid_size=grid_size,
                                     ignore_label=0,
                                     fixed_volume_space=True,
                                     return_test=True,
                                     max_volume_space=[50, 50, 3],
                                     min_volume_space=[-50, -50, -5])
        val_dataset = voxel_dataset(val_pt_dataset,
                                    grid_size=grid_size,
                                    ignore_label=0,
                                    fixed_volume_space=True,
                                    max_volume_space=[50, 50, 3],
                                    min_volume_space=[-50, -50, -5])
    test_dataset_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=test_batch_size,
        collate_fn=collate_fn_BEV_test,
        shuffle=False,
        num_workers=4)
    val_dataset_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=test_batch_size,
        collate_fn=collate_fn_BEV,
        shuffle=False,
        num_workers=4)

    # validation
    print('*' * 80)
    print('Test network performance on validation split')
    print('*' * 80)
    pbar = tqdm(total=len(val_dataset_loader))
    my_model.eval()
    hist_list = []  # per-scan confusion histograms
    time_list = []  # per-batch forward times (seconds)
    with torch.no_grad():
        for i_iter_val, (val_vox_fea, val_vox_label, val_grid, val_pt_labs,
                         val_pt_fea) in enumerate(val_dataset_loader):
            val_vox_fea_ten = val_vox_fea.to(pytorch_device)
            # label remapping helper shared with the SemanticKITTI pipeline
            val_vox_label = SemKITTI2train(val_vox_label)
            val_pt_labs = SemKITTI2train(val_pt_labs)
            val_pt_fea_ten = [
                torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device)
                for i in val_pt_fea
            ]
            val_grid_ten = [
                torch.from_numpy(i[:, :2]).to(pytorch_device)
                for i in val_grid
            ]
            val_label_tensor = val_vox_label.type(
                torch.LongTensor).to(pytorch_device)
            # synchronize so wall-clock timing covers only the forward pass
            torch.cuda.synchronize()
            start_time = time.time()
            # feed the visibility feature volume only when the flag is set
            if visibilty:
                predict_labels = my_model(val_pt_fea_ten, val_grid_ten,
                                          val_vox_fea_ten)
            else:
                predict_labels = my_model(val_pt_fea_ten, val_grid_ten)
            torch.cuda.synchronize()
            time_list.append(time.time() - start_time)
            predict_labels = torch.argmax(predict_labels, dim=1)
            predict_labels = predict_labels.cpu().detach().numpy()
            for count, i_val_grid in enumerate(val_grid):
                hist_list.append(
                    fast_hist_crop(
                        predict_labels[count, val_grid[count][:, 0],
                                       val_grid[count][:, 1],
                                       val_grid[count][:, 2]],
                        val_pt_labs[count], unique_label))
            pbar.update(1)
    iou = per_class_iu(sum(hist_list))
    print('Validation per class iou: ')
    for class_name, class_iou in zip(unique_label_str, iou):
        print('%s : %.2f%%' % (class_name, class_iou * 100))
    val_miou = np.nanmean(iou) * 100
    del val_vox_label, val_grid, val_pt_fea, val_grid_ten
    pbar.close()
    print('Current val miou is %.3f ' % val_miou)
    print('Inference time per %d is %.4f seconds\n' %
          (test_batch_size, np.mean(time_list)))

    # test
    print('*' * 80)
    print('Generate predictions for test split')
    print('*' * 80)
    pbar = tqdm(total=len(test_dataset_loader))
    with torch.no_grad():
        for i_iter_test, (test_vox_fea, _, test_grid, _, test_pt_fea,
                          test_index) in enumerate(test_dataset_loader):
            # predict
            test_vox_fea_ten = test_vox_fea.to(pytorch_device)
            test_pt_fea_ten = [
                torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device)
                for i in test_pt_fea
            ]
            test_grid_ten = [
                torch.from_numpy(i[:, :2]).to(pytorch_device)
                for i in test_grid
            ]
            if visibilty:
                predict_labels = my_model(test_pt_fea_ten, test_grid_ten,
                                          test_vox_fea_ten)
            else:
                predict_labels = my_model(test_pt_fea_ten, test_grid_ten)
            predict_labels = torch.argmax(predict_labels, 1).type(torch.uint8)
            predict_labels = predict_labels.cpu().detach().numpy()
            # write to label file
            for count, i_test_grid in enumerate(test_grid):
                test_pred_label = predict_labels[count, test_grid[count][:, 0],
                                                 test_grid[count][:, 1],
                                                 test_grid[count][:, 2]]
                test_pred_label = train2SemKITTI(test_pred_label)
                test_pred_label = np.expand_dims(test_pred_label, axis=1)
                # one file per lidar sample token (nuScenes lidarseg layout)
                label_token = test_pt_dataset.train_token_list[
                    test_index[count]]
                new_save_dir = os.path.join(output_path,
                                            label_token + '_lidarseg.bin')
                if not os.path.exists(os.path.dirname(new_save_dir)):
                    try:
                        os.makedirs(os.path.dirname(new_save_dir))
                    except OSError as exc:
                        # tolerate a concurrent mkdir; re-raise anything else
                        if exc.errno != errno.EEXIST:
                            raise
                # NOTE(review): official nuScenes lidarseg submissions expect
                # uint8 bins — confirm uint32 output is intended here.
                test_pred_label = test_pred_label.astype(np.uint32)
                test_pred_label.tofile(new_save_dir)
            pbar.update(1)
            del test_grid, test_pt_fea, test_index
    pbar.close()
    print('Predicted test labels are saved in %s.' % output_path)
def main(args):
    """Evaluate a trained Panoptic-PolarNet on the SemanticKITTI val split
    (PQ/SQ/RQ/IoU) and write panoptic test-split predictions to .label files.

    ``args`` is a nested config dict with 'dataset' and 'model' sections
    (grid size, post-processing thresholds, pretrained checkpoint path, ...).
    """
    data_path = args['dataset']['path']
    test_batch_size = args['model']['test_batch_size']
    pretrained_model = args['model']['pretrained_model']
    output_path = args['dataset']['output_path']
    # last grid dimension doubles as the BEV height compression
    compression_model = args['dataset']['grid_size'][2]
    grid_size = args['dataset']['grid_size']
    visibility = args['model']['visibility']
    pytorch_device = torch.device('cuda:0')
    if args['model']['polar']:
        fea_dim = 9
        circular_padding = True
    else:
        fea_dim = 7
        circular_padding = False

    # prepare miou fun: drop label 0 ('unlabeled'), shift to 0-based ids
    unique_label=np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str=[SemKITTI_label_name[x] for x in unique_label+1]

    # prepare model
    my_BEV_model=BEV_Unet(n_class=len(unique_label),
                          n_height = compression_model,
                          input_batch_norm = True,
                          dropout = 0.5,
                          circular_padding = circular_padding,
                          use_vis_fea=visibility)
    my_model = ptBEVnet(my_BEV_model,
                        pt_model = 'pointnet',
                        grid_size = grid_size,
                        fea_dim = fea_dim,
                        max_pt_per_encode = 256,
                        out_pt_fea_dim = 512,
                        kernal_size = 1,
                        pt_selection = 'random',
                        fea_compre = compression_model)
    # silently runs from random weights when the checkpoint file is missing
    if os.path.exists(pretrained_model):
        my_model.load_state_dict(torch.load(pretrained_model))
    pytorch_total_params = sum(p.numel() for p in my_model.parameters())
    print('params: ',pytorch_total_params)
    my_model.to(pytorch_device)
    my_model.eval()

    # prepare dataset
    test_pt_dataset = SemKITTI(
        data_path + '/sequences/',
        imageset = 'test',
        return_ref = True,
        instance_pkl_path=args['dataset']['instance_pkl_path'])
    val_pt_dataset = SemKITTI(
        data_path + '/sequences/',
        imageset = 'val',
        return_ref = True,
        instance_pkl_path=args['dataset']['instance_pkl_path'])
    if args['model']['polar']:
        test_dataset=spherical_dataset(test_pt_dataset,
                                       args['dataset'],
                                       grid_size = grid_size,
                                       ignore_label = 0,
                                       return_test= True)
        val_dataset=spherical_dataset(val_pt_dataset,
                                      args['dataset'],
                                      grid_size = grid_size,
                                      ignore_label = 0)
    else:
        test_dataset=voxel_dataset(test_pt_dataset,
                                   args['dataset'],
                                   grid_size = grid_size,
                                   ignore_label = 0,
                                   return_test= True)
        val_dataset=voxel_dataset(val_pt_dataset,
                                  args['dataset'],
                                  grid_size = grid_size,
                                  ignore_label = 0)
    test_dataset_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                                      batch_size = test_batch_size,
                                                      collate_fn = collate_fn_BEV_test,
                                                      shuffle = False,
                                                      num_workers = 4)
    val_dataset_loader = torch.utils.data.DataLoader(dataset = val_dataset,
                                                     batch_size = test_batch_size,
                                                     collate_fn = collate_fn_BEV,
                                                     shuffle = False,
                                                     num_workers = 4)

    # validation
    print('*'*80)
    print('Test network performance on validation split')
    print('*'*80)
    pbar = tqdm(total=len(val_dataset_loader))
    time_list = []     # network forward time per batch
    pp_time_list = []  # panoptic post-processing time per scan
    # class 0 is ignored by the panoptic evaluator
    evaluator = PanopticEval(len(unique_label)+1, None, [0], min_points=50)
    with torch.no_grad():
        for i_iter_val,(val_vox_fea,val_vox_label,val_gt_center,val_gt_offset,val_grid,val_pt_labels,val_pt_ints,val_pt_fea) in enumerate(val_dataset_loader):
            val_vox_fea_ten = val_vox_fea.to(pytorch_device)
            val_vox_label = SemKITTI2train(val_vox_label)
            val_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in val_pt_fea]
            # only the first two grid coordinates (BEV cell) feed the network
            val_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in val_grid]
            val_label_tensor=val_vox_label.type(torch.LongTensor).to(pytorch_device)
            val_gt_center_tensor = val_gt_center.to(pytorch_device)
            val_gt_offset_tensor = val_gt_offset.to(pytorch_device)
            # synchronize so wall-clock timing covers only the forward pass
            torch.cuda.synchronize()
            start_time = time.time()
            if visibility:
                predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten, val_vox_fea_ten)
            else:
                predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten)
            torch.cuda.synchronize()
            time_list.append(time.time()-start_time)
            for count,i_val_grid in enumerate(val_grid):
                # get foreground_mask: mark only the occupied voxels of this scan
                for_mask = torch.zeros(1,grid_size[0],grid_size[1],grid_size[2],dtype=torch.bool).to(pytorch_device)
                for_mask[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]] = True
                # post processing (fuse semantics with instance centers; timed separately)
                torch.cuda.synchronize()
                start_time = time.time()
                panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),val_pt_dataset.thing_list,\
                                                                          threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\
                                                                          top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)
                torch.cuda.synchronize()
                pp_time_list.append(time.time()-start_time)
                panoptic_labels = panoptic_labels.cpu().detach().numpy().astype(np.uint32)
                panoptic = panoptic_labels[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]]
                # low 16 bits of the panoptic id hold the semantic class
                evaluator.addBatch(panoptic & 0xFFFF,panoptic,np.squeeze(val_pt_labels[count]),np.squeeze(val_pt_ints[count]))
            del val_vox_label,val_pt_fea_ten,val_label_tensor,val_grid_ten,val_gt_center,val_gt_center_tensor,val_gt_offset,val_gt_offset_tensor,predict_labels,center,offset,panoptic_labels,center_points
            pbar.update(1)
    class_PQ, class_SQ, class_RQ, class_all_PQ, class_all_SQ, class_all_RQ = evaluator.getPQ()
    miou,ious = evaluator.getSemIoU()
    print('Validation per class PQ, SQ, RQ and IoU: ')
    for class_name, class_pq, class_sq, class_rq, class_iou in zip(unique_label_str,class_all_PQ[1:],class_all_SQ[1:],class_all_RQ[1:],ious[1:]):
        print('%15s : %6.2f%% %6.2f%% %6.2f%% %6.2f%%' % (class_name, class_pq*100, class_sq*100, class_rq*100, class_iou*100))
    pbar.close()
    print('Current val PQ is %.3f' % (class_PQ*100))
    print('Current val miou is %.3f'% (miou*100))
    print('Inference time per %d is %.4f seconds\n, postprocessing time is %.4f seconds per scan' % (test_batch_size,np.mean(time_list),np.mean(pp_time_list)))

    # test
    print('*'*80)
    print('Generate predictions for test split')
    print('*'*80)
    pbar = tqdm(total=len(test_dataset_loader))
    with torch.no_grad():
        for i_iter_test,(test_vox_fea,_,_,_,test_grid,_,_,test_pt_fea,test_index) in enumerate(test_dataset_loader):
            # predict
            test_vox_fea_ten = test_vox_fea.to(pytorch_device)
            test_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in test_pt_fea]
            test_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in test_grid]
            if visibility:
                predict_labels,center,offset = my_model(test_pt_fea_ten,test_grid_ten,test_vox_fea_ten)
            else:
                predict_labels,center,offset = my_model(test_pt_fea_ten,test_grid_ten)
            # write to label file
            for count,i_test_grid in enumerate(test_grid):
                # get foreground_mask
                for_mask = torch.zeros(1,grid_size[0],grid_size[1],grid_size[2],dtype=torch.bool).to(pytorch_device)
                for_mask[0,test_grid[count][:,0],test_grid[count][:,1],test_grid[count][:,2]] = True
                # post processing
                panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),test_pt_dataset.thing_list,\
                                                                          threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\
                                                                          top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)
                panoptic_labels = panoptic_labels.cpu().detach().numpy().astype(np.uint32)
                panoptic = panoptic_labels[0,test_grid[count][:,0],test_grid[count][:,1],test_grid[count][:,2]]
                save_dir = test_pt_dataset.im_idx[test_index[count]]
                # mirror input layout: .../velodyne/xxx.bin -> .../predictions/xxx.label
                _,dir2 = save_dir.split('/sequences/',1)
                new_save_dir = output_path + '/sequences/' +dir2.replace('velodyne','predictions')[:-3]+'label'
                if not os.path.exists(os.path.dirname(new_save_dir)):
                    try:
                        os.makedirs(os.path.dirname(new_save_dir))
                    except OSError as exc:
                        # tolerate a concurrent mkdir; re-raise anything else
                        if exc.errno != errno.EEXIST:
                            raise
                panoptic.tofile(new_save_dir)
            del test_pt_fea_ten,test_grid_ten,test_pt_fea,predict_labels,center,offset
            pbar.update(1)
    pbar.close()
    print('Predicted test labels are saved in %s. Need to be shifted to original label format before submitting to the Competition website.' % output_path)
    print('Remapping script can be found in semantic-kitti-api.')
def main(args):
    """Train the BEV segmentation network on SemanticKITTI.

    Validates every ``check_iter`` iterations and checkpoints whenever the
    validation mIoU improves. The outer ``while True`` loop has no epoch
    cap — training runs until interrupted.
    """
    data_path = args.data_dir
    train_batch_size = args.train_batch_size
    val_batch_size = args.val_batch_size
    check_iter = args.check_iter  # validate every N training iterations
    model_save_path = args.model_save_path
    # last grid dimension doubles as the BEV height compression
    compression_model = args.grid_size[2]
    grid_size = args.grid_size
    pytorch_device = torch.device('cuda:0')
    model = args.model
    if model == 'polar':
        fea_dim = 9
        circular_padding = True
    elif model == 'traditional':
        fea_dim = 7
        circular_padding = False

    #prepare miou fun: drop label 0 ('unlabeled'), shift to 0-based ids
    unique_label = np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str = [SemKITTI_label_name[x] for x in unique_label + 1]

    #prepare model
    my_BEV_model = BEV_Unet(n_class=len(unique_label),
                            n_height=compression_model,
                            input_batch_norm=True,
                            dropout=0.5,
                            circular_padding=circular_padding)
    my_model = ptBEVnet(my_BEV_model,
                        pt_model='pointnet',
                        grid_size=grid_size,
                        fea_dim=fea_dim,
                        max_pt_per_encode=256,
                        out_pt_fea_dim=512,
                        kernal_size=1,
                        pt_selection='random',
                        fea_compre=compression_model)
    # resume from a previous checkpoint when one exists
    if os.path.exists(model_save_path):
        my_model.load_state_dict(torch.load(model_save_path))
    my_model.to(pytorch_device)
    optimizer = optim.Adam(my_model.parameters())
    # 255 marks invalid voxels produced by SemKITTI2train
    loss_fun = torch.nn.CrossEntropyLoss(ignore_index=255)

    #prepare dataset (train split gets rotation + flip augmentation)
    train_pt_dataset = SemKITTI(data_path + '/sequences/',
                                imageset='train',
                                return_ref=True)
    val_pt_dataset = SemKITTI(data_path + '/sequences/',
                              imageset='val',
                              return_ref=True)
    if model == 'polar':
        train_dataset = spherical_dataset(train_pt_dataset,
                                          grid_size=grid_size,
                                          flip_aug=True,
                                          ignore_label=0,
                                          rotate_aug=True,
                                          fixed_volume_space=True)
        val_dataset = spherical_dataset(val_pt_dataset,
                                        grid_size=grid_size,
                                        ignore_label=0,
                                        fixed_volume_space=True)
    elif model == 'traditional':
        train_dataset = voxel_dataset(train_pt_dataset,
                                      grid_size=grid_size,
                                      flip_aug=True,
                                      ignore_label=0,
                                      rotate_aug=True,
                                      fixed_volume_space=True)
        val_dataset = voxel_dataset(val_pt_dataset,
                                    grid_size=grid_size,
                                    ignore_label=0,
                                    fixed_volume_space=True)
    train_dataset_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        collate_fn=collate_fn_BEV,
        shuffle=True,
        num_workers=4)
    val_dataset_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                     batch_size=val_batch_size,
                                                     collate_fn=collate_fn_BEV,
                                                     shuffle=False,
                                                     num_workers=4)

    # training
    epoch = 0
    best_val_miou = 0
    start_training = False  # becomes True after the first training step
    my_model.train()
    global_iter = 0
    exce_counter = 0  # batches skipped because of an exception

    while True:
        loss_list = []
        pbar = tqdm(total=len(train_dataset_loader))
        for i_iter, (_, train_vox_label, train_grid, _,
                     train_pt_fea) in enumerate(train_dataset_loader):
            # validation (also runs at iteration 0, before any training)
            if global_iter % check_iter == 0:
                my_model.eval()
                hist_list = []
                val_loss_list = []
                with torch.no_grad():
                    for i_iter_val, (
                            _, val_vox_label, val_grid, val_pt_labs,
                            val_pt_fea) in enumerate(val_dataset_loader):
                        val_vox_label = SemKITTI2train(val_vox_label)
                        val_pt_labs = SemKITTI2train(val_pt_labs)
                        val_pt_fea_ten = [
                            torch.from_numpy(i).type(
                                torch.FloatTensor).to(pytorch_device)
                            for i in val_pt_fea
                        ]
                        val_grid_ten = [
                            torch.from_numpy(i[:, :2]).to(pytorch_device)
                            for i in val_grid
                        ]
                        val_label_tensor = val_vox_label.type(
                            torch.LongTensor).to(pytorch_device)
                        predict_labels = my_model(val_pt_fea_ten, val_grid_ten)
                        # NOTE(review): softmax without an explicit dim uses
                        # the deprecated implicit-dim behaviour; dim=1 (class
                        # channel) is presumably intended — confirm.
                        loss = lovasz_softmax(torch.nn.functional.softmax(
                            predict_labels).detach(),
                                              val_label_tensor,
                                              ignore=255) + loss_fun(
                                                  predict_labels.detach(),
                                                  val_label_tensor)
                        predict_labels = torch.argmax(predict_labels, dim=1)
                        predict_labels = predict_labels.cpu().detach().numpy()
                        # per-scan confusion histograms at point level
                        for count, i_val_grid in enumerate(val_grid):
                            hist_list.append(
                                fast_hist_crop(
                                    predict_labels[count,
                                                   val_grid[count][:, 0],
                                                   val_grid[count][:, 1],
                                                   val_grid[count][:, 2]],
                                    val_pt_labs[count], unique_label))
                        val_loss_list.append(loss.detach().cpu().numpy())
                my_model.train()
                iou = per_class_iu(sum(hist_list))
                print('Validation per class iou: ')
                for class_name, class_iou in zip(unique_label_str, iou):
                    print('%s : %.2f%%' % (class_name, class_iou * 100))
                val_miou = np.nanmean(iou) * 100
                del val_vox_label, val_grid, val_pt_fea, val_grid_ten
                # save model if performance is improved
                if best_val_miou < val_miou:
                    best_val_miou = val_miou
                    torch.save(my_model.state_dict(), model_save_path)
                print(
                    'Current val miou is %.3f while the best val miou is %.3f'
                    % (val_miou, best_val_miou))
                print('Current val loss is %.3f' % (np.mean(val_loss_list)))
                if start_training:
                    print('epoch %d iter %5d, loss: %.3f\n' %
                          (epoch, i_iter, np.mean(loss_list)))
                print('%d exceptions encountered during last training\n' %
                      exce_counter)
                exce_counter = 0
                loss_list = []

            # training
            try:
                train_vox_label = SemKITTI2train(train_vox_label)
                train_pt_fea_ten = [
                    torch.from_numpy(i).type(
                        torch.FloatTensor).to(pytorch_device)
                    for i in train_pt_fea
                ]
                train_grid_ten = [
                    torch.from_numpy(i[:, :2]).to(pytorch_device)
                    for i in train_grid
                ]
                # NOTE(review): train_vox_ten is never used afterwards.
                train_vox_ten = [
                    torch.from_numpy(i).to(pytorch_device) for i in train_grid
                ]
                point_label_tensor = train_vox_label.type(
                    torch.LongTensor).to(pytorch_device)

                # forward + backward + optimize
                outputs = my_model(train_pt_fea_ten, train_grid_ten)
                # NOTE(review): same implicit-dim softmax as in validation.
                loss = lovasz_softmax(torch.nn.functional.softmax(outputs),
                                      point_label_tensor,
                                      ignore=255) + loss_fun(
                                          outputs, point_label_tensor)
                loss.backward()
                optimizer.step()
                loss_list.append(loss.item())
            except Exception:
                # NOTE(review): deliberately counts-and-skips any failing
                # batch; this also hides genuine bugs — the count is reported
                # at the next validation.
                exce_counter += 1

            # zero the parameter gradients
            optimizer.zero_grad()
            pbar.update(1)
            start_training = True
            global_iter += 1
        pbar.close()
        epoch += 1
def main(args):
    """Train Panoptic-PolarNet on SemanticKITTI.

    ``args`` is a nested config dict with 'dataset' and 'model' sections.
    Validates (with full panoptic post-processing) every ``check_iter``
    iterations and checkpoints whenever validation PQ improves. Optionally
    applies self-adversarial pruning (SAP): after a first backward pass the
    highest-gradient input points are dropped and the loss is recomputed.
    """
    data_path = args['dataset']['path']
    train_batch_size = args['model']['train_batch_size']
    val_batch_size = args['model']['val_batch_size']
    check_iter = args['model']['check_iter']  # validate every N iterations
    model_save_path = args['model']['model_save_path']
    pretrained_model = args['model']['pretrained_model']
    # last grid dimension doubles as the BEV height compression
    compression_model = args['dataset']['grid_size'][2]
    grid_size = args['dataset']['grid_size']
    visibility = args['model']['visibility']
    pytorch_device = torch.device('cuda:0')
    if args['model']['polar']:
        fea_dim = 9
        circular_padding = True
    else:
        fea_dim = 7
        circular_padding = False

    #prepare miou fun: drop label 0 ('unlabeled'), shift to 0-based ids
    unique_label = np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str = [SemKITTI_label_name[x] for x in unique_label + 1]

    #prepare model
    my_BEV_model = BEV_Unet(n_class=len(unique_label),
                            n_height=compression_model,
                            input_batch_norm=True,
                            dropout=0.5,
                            circular_padding=circular_padding,
                            use_vis_fea=visibility)
    my_model = ptBEVnet(my_BEV_model,
                        pt_model='pointnet',
                        grid_size=grid_size,
                        fea_dim=fea_dim,
                        max_pt_per_encode=256,
                        out_pt_fea_dim=512,
                        kernal_size=1,
                        pt_selection='random',
                        fea_compre=compression_model)
    # resume own checkpoint first; otherwise warm-start from a pretrained model
    if os.path.exists(model_save_path):
        my_model = load_pretrained_model(my_model, torch.load(model_save_path))
    elif os.path.exists(pretrained_model):
        my_model = load_pretrained_model(my_model,
                                         torch.load(pretrained_model))
    my_model.to(pytorch_device)
    optimizer = optim.Adam(my_model.parameters())
    # combined semantic + center-heatmap + offset loss
    loss_fn = panoptic_loss(center_loss_weight = args['model']['center_loss_weight'], offset_loss_weight = args['model']['offset_loss_weight'],\
                            center_loss = args['model']['center_loss'], offset_loss=args['model']['offset_loss'])

    #prepare dataset (train split uses augmentation)
    train_pt_dataset = SemKITTI(
        data_path + '/sequences/',
        imageset='train',
        return_ref=True,
        instance_pkl_path=args['dataset']['instance_pkl_path'])
    val_pt_dataset = SemKITTI(
        data_path + '/sequences/',
        imageset='val',
        return_ref=True,
        instance_pkl_path=args['dataset']['instance_pkl_path'])
    if args['model']['polar']:
        train_dataset = spherical_dataset(train_pt_dataset,
                                          args['dataset'],
                                          grid_size=grid_size,
                                          ignore_label=0,
                                          use_aug=True)
        val_dataset = spherical_dataset(val_pt_dataset,
                                        args['dataset'],
                                        grid_size=grid_size,
                                        ignore_label=0)
    else:
        train_dataset = voxel_dataset(train_pt_dataset,
                                      args['dataset'],
                                      grid_size=grid_size,
                                      ignore_label=0,
                                      use_aug=True)
        val_dataset = voxel_dataset(val_pt_dataset,
                                    args['dataset'],
                                    grid_size=grid_size,
                                    ignore_label=0)
    train_dataset_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        collate_fn=collate_fn_BEV,
        shuffle=True,
        num_workers=4)
    val_dataset_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                     batch_size=val_batch_size,
                                                     collate_fn=collate_fn_BEV,
                                                     shuffle=False,
                                                     num_workers=4)

    # training
    epoch = 0
    best_val_PQ = 0
    start_training = False  # becomes True after the first training step
    my_model.train()
    global_iter = 0
    exce_counter = 0  # batches skipped because of an exception
    # class 0 is ignored by the panoptic evaluator
    evaluator = PanopticEval(len(unique_label) + 1, None, [0], min_points=50)

    while epoch < args['model']['max_epoch']:
        pbar = tqdm(total=len(train_dataset_loader))
        for i_iter, (train_vox_fea, train_label_tensor, train_gt_center,
                     train_gt_offset, train_grid, _, _,
                     train_pt_fea) in enumerate(train_dataset_loader):
            # validation (also runs at iteration 0, before any training)
            if global_iter % check_iter == 0:
                my_model.eval()
                evaluator.reset()
                with torch.no_grad():
                    for i_iter_val, (
                            val_vox_fea, val_vox_label, val_gt_center,
                            val_gt_offset, val_grid, val_pt_labels,
                            val_pt_ints,
                            val_pt_fea) in enumerate(val_dataset_loader):
                        val_vox_fea_ten = val_vox_fea.to(pytorch_device)
                        val_vox_label = SemKITTI2train(val_vox_label)
                        val_pt_fea_ten = [
                            torch.from_numpy(i).type(
                                torch.FloatTensor).to(pytorch_device)
                            for i in val_pt_fea
                        ]
                        val_grid_ten = [
                            torch.from_numpy(i[:, :2]).to(pytorch_device)
                            for i in val_grid
                        ]
                        val_label_tensor = val_vox_label.type(
                            torch.LongTensor).to(pytorch_device)
                        val_gt_center_tensor = val_gt_center.to(pytorch_device)
                        val_gt_offset_tensor = val_gt_offset.to(pytorch_device)
                        if visibility:
                            predict_labels, center, offset = my_model(
                                val_pt_fea_ten, val_grid_ten,
                                val_vox_fea_ten)
                        else:
                            predict_labels, center, offset = my_model(
                                val_pt_fea_ten, val_grid_ten)
                        for count, i_val_grid in enumerate(val_grid):
                            # get foreground_mask (occupied voxels only)
                            for_mask = torch.zeros(
                                1,
                                grid_size[0],
                                grid_size[1],
                                grid_size[2],
                                dtype=torch.bool).to(pytorch_device)
                            for_mask[0, val_grid[count][:, 0],
                                     val_grid[count][:, 1],
                                     val_grid[count][:, 2]] = True
                            # post processing: fuse semantics + instance centers
                            panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),\
                                                                                      val_pt_dataset.thing_list, threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\
                                                                                      top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)
                            panoptic_labels = panoptic_labels.cpu().detach(
                            ).numpy().astype(np.int32)
                            panoptic = panoptic_labels[0,
                                                       val_grid[count][:, 0],
                                                       val_grid[count][:, 1],
                                                       val_grid[count][:, 2]]
                            # low 16 bits of the panoptic id = semantic class
                            evaluator.addBatch(
                                panoptic & 0xFFFF, panoptic,
                                np.squeeze(val_pt_labels[count]),
                                np.squeeze(val_pt_ints[count]))
                        del val_vox_label, val_pt_fea_ten, val_label_tensor, val_grid_ten, val_gt_center, val_gt_center_tensor, val_gt_offset, val_gt_offset_tensor, predict_labels, center, offset, panoptic_labels, center_points
                my_model.train()
                class_PQ, class_SQ, class_RQ, class_all_PQ, class_all_SQ, class_all_RQ = evaluator.getPQ(
                )
                miou, ious = evaluator.getSemIoU()
                print('Validation per class PQ, SQ, RQ and IoU: ')
                for class_name, class_pq, class_sq, class_rq, class_iou in zip(
                        unique_label_str, class_all_PQ[1:], class_all_SQ[1:],
                        class_all_RQ[1:], ious[1:]):
                    print('%15s : %6.2f%% %6.2f%% %6.2f%% %6.2f%%' %
                          (class_name, class_pq * 100, class_sq * 100,
                           class_rq * 100, class_iou * 100))
                # save model if performance is improved
                if best_val_PQ < class_PQ:
                    best_val_PQ = class_PQ
                    torch.save(my_model.state_dict(), model_save_path)
                print('Current val PQ is %.3f while the best val PQ is %.3f' %
                      (class_PQ * 100, best_val_PQ * 100))
                print('Current val miou is %.3f' % (miou * 100))
                if start_training:
                    # loss_fn accumulates per-component losses since the
                    # last reset; report their running means
                    sem_l, hm_l, os_l = np.mean(
                        loss_fn.lost_dict['semantic_loss']), np.mean(
                            loss_fn.lost_dict['heatmap_loss']), np.mean(
                                loss_fn.lost_dict['offset_loss'])
                    print(
                        'epoch %d iter %5d, loss: %.3f, semantic loss: %.3f, heatmap loss: %.3f, offset loss: %.3f\n'
                        % (epoch, i_iter, sem_l + hm_l + os_l, sem_l, hm_l,
                           os_l))
                print('%d exceptions encountered during last training\n' %
                      exce_counter)
                exce_counter = 0
                loss_fn.reset_loss_dict()

            # training
            try:
                train_vox_fea_ten = train_vox_fea.to(pytorch_device)
                train_label_tensor = SemKITTI2train(train_label_tensor)
                train_pt_fea_ten = [
                    torch.from_numpy(i).type(
                        torch.FloatTensor).to(pytorch_device)
                    for i in train_pt_fea
                ]
                train_grid_ten = [
                    torch.from_numpy(i[:, :2]).to(pytorch_device)
                    for i in train_grid
                ]
                train_label_tensor = train_label_tensor.type(
                    torch.LongTensor).to(pytorch_device)
                train_gt_center_tensor = train_gt_center.to(pytorch_device)
                train_gt_offset_tensor = train_gt_offset.to(pytorch_device)
                # track input gradients when self-adversarial pruning is on
                if args['model']['enable_SAP'] and epoch >= args['model'][
                        'SAP']['start_epoch']:
                    for fea in train_pt_fea_ten:
                        fea.requires_grad_()

                # forward
                if visibility:
                    sem_prediction, center, offset = my_model(
                        train_pt_fea_ten, train_grid_ten, train_vox_fea_ten)
                else:
                    sem_prediction, center, offset = my_model(
                        train_pt_fea_ten, train_grid_ten)
                # loss
                loss = loss_fn(sem_prediction, center, offset,
                               train_label_tensor, train_gt_center_tensor,
                               train_gt_offset_tensor)

                # self adversarial pruning: backprop once to rank points by
                # input-gradient norm, drop the top fraction, then redo the
                # forward/loss on the pruned input
                if args['model']['enable_SAP'] and epoch >= args['model'][
                        'SAP']['start_epoch']:
                    loss.backward()
                    for i, fea in enumerate(train_pt_fea_ten):
                        fea_grad = torch.norm(fea.grad, dim=1)
                        top_k_grad, _ = torch.topk(
                            fea_grad,
                            int(args['model']['SAP']['rate'] *
                                fea_grad.shape[0]))
                        # delete high influential points
                        train_pt_fea_ten[i] = train_pt_fea_ten[i][
                            fea_grad < top_k_grad[-1]]
                        train_grid_ten[i] = train_grid_ten[i][
                            fea_grad < top_k_grad[-1]]
                    optimizer.zero_grad()

                    # second pass
                    # forward
                    if visibility:
                        sem_prediction, center, offset = my_model(
                            train_pt_fea_ten, train_grid_ten,
                            train_vox_fea_ten)
                    else:
                        sem_prediction, center, offset = my_model(
                            train_pt_fea_ten, train_grid_ten)
                    # loss
                    loss = loss_fn(sem_prediction, center, offset,
                                   train_label_tensor, train_gt_center_tensor,
                                   train_gt_offset_tensor)

                # backward + optimize
                loss.backward()
                optimizer.step()
            except Exception as error:
                # only the first exception per validation window is printed;
                # the rest are merely counted
                if exce_counter == 0:
                    print(error)
                exce_counter += 1

            # zero the parameter gradients
            optimizer.zero_grad()
            pbar.update(1)
            start_training = True
            global_iter += 1
        pbar.close()
        epoch += 1