import os
import time

import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm
import wandb

# Project-specific builders and helpers used below. The exact import paths are an
# assumption based on a Cylinder3D-style repository layout and may need adjusting.
from config.config import load_config_data
from builder import data_builder, model_builder, loss_builder
from dataloader.pc_dataset import get_SemKITTI_label_name, get_nuScenes_label_name
from utils.load_save_util import load_checkpoint, load_checkpoint_1b1
from utils.metric_util import per_class_iu, fast_hist_crop


# NOTE: debugging variant of main() that builds the dataloaders, prints the batches,
# and sets up the model, optimizer and losses. It is shadowed by the full training
# main() defined below.
def main(args):
    pytorch_device = torch.device('cuda:0')

    config_path = args.config_path
    configs = load_config_data(config_path)

    dataset_config = configs['dataset_params']
    train_dataloader_config = configs['train_data_loader']
    val_dataloader_config = configs['val_data_loader']

    val_batch_size = val_dataloader_config['batch_size']
    train_batch_size = train_dataloader_config['batch_size']

    model_config = configs['model_params']
    train_hypers = configs['train_params']

    grid_size = model_config['output_shape']
    num_class = model_config['num_class']
    ignore_label = dataset_config['ignore_label']

    model_load_path = train_hypers['model_load_path']
    model_save_path = train_hypers['model_save_path']

    SemKITTI_label_name = get_nuScenes_label_name(dataset_config["label_mapping"])
    unique_label = np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str = [SemKITTI_label_name[x] for x in unique_label + 1]

    train_dataset_loader, val_dataset_loader = data_builder.build(dataset_config,
                                                                  train_dataloader_config,
                                                                  val_dataloader_config,
                                                                  grid_size=grid_size)

    # Debug: inspect the batches produced by the train loader.
    # print('train_dataset_loader:', train_dataset_loader)
    # _, train_vox_label, train_grid, _, train_pt_fea = train_dataset_loader[1]
    for i_iter, (_, train_vox_label, train_grid, _, train_pt_fea) in enumerate(train_dataset_loader):
        print('train_grid:', train_grid)
        print('train_pt_fea', train_pt_fea)
        print('train_vox_label', train_vox_label)

    my_model = model_builder.build_pt(model_config)
    if os.path.exists(model_load_path):
        my_model = load_checkpoint_1b1(model_load_path, my_model)

    my_model.to(pytorch_device)
    optimizer = optim.Adam(my_model.parameters(), lr=train_hypers["learning_rate"])

    loss_func, lovasz_softmax = loss_builder.build(wce=True, lovasz=True,
                                                   num_class=num_class, ignore_label=ignore_label)
def main(args):
    pytorch_device = torch.device('cuda:0')

    config_path = args.config_path
    configs = load_config_data(config_path)

    dataset_config = configs['dataset_params']
    train_dataloader_config = configs['train_data_loader']
    val_dataloader_config = configs['val_data_loader']

    val_batch_size = val_dataloader_config['batch_size']
    train_batch_size = train_dataloader_config['batch_size']

    model_config = configs['model_params']
    train_hypers = configs['train_params']

    grid_size = model_config['output_shape']
    num_class = model_config['num_class']
    ignore_label = dataset_config['ignore_label']

    model_load_path = train_hypers['model_load_path']
    model_save_path = train_hypers['model_save_path']

    SemKITTI_label_name = get_SemKITTI_label_name(dataset_config["label_mapping"])
    unique_label = np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str = [SemKITTI_label_name[x] for x in unique_label + 1]

    my_model = model_builder.build(model_config)
    if os.path.exists(model_load_path):
        my_model = load_checkpoint(model_load_path, my_model)

    my_model.to(pytorch_device)
    optimizer = optim.Adam(my_model.parameters(), lr=train_hypers["learning_rate"])

    loss_func, lovasz_softmax = loss_builder.build(wce=True, lovasz=True,
                                                   num_class=num_class, ignore_label=ignore_label)

    train_dataset_loader, val_dataset_loader = data_builder.build(dataset_config,
                                                                  train_dataloader_config,
                                                                  val_dataloader_config,
                                                                  grid_size=grid_size)

    # training
    epoch = 0
    best_val_miou = 0
    my_model.train()
    global_iter = 0
    check_iter = train_hypers['eval_every_n_steps']

    while epoch < train_hypers['max_num_epochs']:
        loss_list = []
        pbar = tqdm(total=len(train_dataset_loader))
        time.sleep(10)
        # lr_scheduler.step(epoch)
        for i_iter, (_, train_vox_label, train_grid, _, train_pt_fea) in enumerate(train_dataset_loader):
            # Run validation every check_iter steps (after the first epoch).
            if global_iter % check_iter == 0 and epoch >= 1:
                my_model.eval()
                hist_list = []
                val_loss_list = []
                with torch.no_grad():
                    for i_iter_val, (_, val_vox_label, val_grid, val_pt_labs, val_pt_fea) in enumerate(
                            val_dataset_loader):
                        val_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device)
                                          for i in val_pt_fea]
                        val_grid_ten = [torch.from_numpy(i).to(pytorch_device) for i in val_grid]
                        val_label_tensor = val_vox_label.type(torch.LongTensor).to(pytorch_device)

                        predict_labels = my_model(val_pt_fea_ten, val_grid_ten, val_batch_size)
                        # aux_loss = loss_fun(aux_outputs, point_label_tensor)
                        loss = lovasz_softmax(torch.nn.functional.softmax(predict_labels, dim=1).detach(),
                                              val_label_tensor, ignore=0) \
                            + loss_func(predict_labels.detach(), val_label_tensor)
                        predict_labels = torch.argmax(predict_labels, dim=1)
                        predict_labels = predict_labels.cpu().detach().numpy()
                        # Accumulate per-class confusion histograms at the labelled points.
                        for count, i_val_grid in enumerate(val_grid):
                            hist_list.append(fast_hist_crop(predict_labels[count,
                                                                           val_grid[count][:, 0],
                                                                           val_grid[count][:, 1],
                                                                           val_grid[count][:, 2]],
                                                            val_pt_labs[count],
                                                            unique_label))
                        val_loss_list.append(loss.detach().cpu().numpy())

                # Sets the module back in training mode.
                my_model.train()

                iou = per_class_iu(sum(hist_list))
                print('Validation per class iou: ')
                for class_name, class_iou in zip(unique_label_str, iou):
                    print('%s : %.2f%%' % (class_name, class_iou * 100))
                val_miou = np.nanmean(iou) * 100
                del val_vox_label, val_grid, val_pt_fea, val_grid_ten

                # Save the model if validation performance improved.
                if best_val_miou < val_miou:
                    best_val_miou = val_miou
                    torch.save(my_model.state_dict(), model_save_path)

                print('Current val miou is %.3f while the best val miou is %.3f' %
                      (val_miou, best_val_miou))
                print('Current val loss is %.3f' % (np.mean(val_loss_list)))
                wandb.log({"val_miou": val_miou, "val_loss_list": np.mean(val_loss_list)})

            train_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device)
                                for i in train_pt_fea]
            # train_grid_ten = [torch.from_numpy(i[:, :2]).to(pytorch_device) for i in train_grid]
            train_vox_ten = [torch.from_numpy(i).to(pytorch_device) for i in train_grid]
            point_label_tensor = train_vox_label.type(torch.LongTensor).to(pytorch_device)

            # forward + backward + optimize
            outputs = my_model(train_pt_fea_ten, train_vox_ten, train_batch_size)
            loss = lovasz_softmax(torch.nn.functional.softmax(outputs, dim=1), point_label_tensor, ignore=0) \
                + loss_func(outputs, point_label_tensor)
            loss.backward()
            # All optimizers implement a step() method that updates the parameters.
            optimizer.step()
            loss_list.append(loss.item())

            if global_iter % 1000 == 0:
                if len(loss_list) > 0:
                    print('epoch %d iter %5d, loss: %.3f\n' % (epoch, i_iter, np.mean(loss_list)))
                    wandb.log({"train_loss": np.mean(loss_list)})
                else:
                    print('loss error')

            # Clear the accumulated gradients before the next forward/backward pass.
            optimizer.zero_grad()
            pbar.update(1)
            global_iter += 1

            if global_iter % check_iter == 0:
                if len(loss_list) > 0:
                    print('epoch %d iter %5d, loss: %.3f\n' % (epoch, i_iter, np.mean(loss_list)))
                    wandb.log({"train_loss": np.mean(loss_list)})
                else:
                    print('loss error')

        pbar.close()
        epoch += 1
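
# A minimal command-line entry point sketch, assuming main() is called with an
# argparse namespace carrying a `config_path` attribute as used above. The flag
# name, the default config path, and the wandb.init() project name are
# illustrative assumptions, not part of the original script.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Point cloud segmentation training')
    parser.add_argument('-y', '--config_path', default='config/semantickitti.yaml',
                        help='path to the YAML config consumed by load_config_data')
    args = parser.parse_args()
    print(args)

    # wandb.log() is called inside main(), so a run has to be initialised first.
    wandb.init(project='cylinder3d-training', config={'config_path': args.config_path})

    main(args)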