import logging
import os
import random
import sys
from datetime import datetime
from time import time

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

# NOTE: the imports below assume a GRNet-style repository layout plus the
# PF-Net ShapeNet-Part loader; adjust the paths to your project.
import utils.helpers
import shapenet_part_loader
from models.grnet import GRNet
from utils.average_meter import AverageMeter
from extensions.chamfer_dist import ChamferDistance
from extensions.gridding_loss import GriddingLoss

# NOTE: these module-level switches are referenced but never defined in this
# file; the values below are assumptions. `get_data_seg` (used in the training
# loop) is likewise expected to be provided elsewhere.
train_seg = True        # supervise the segmentation heads during training
draw_mode = False       # plot intermediate point clouds, then exit
save_crop_mode = False  # dump one original/cropped pair to .npy, then exit
save_mode = False       # dump test-time predictions to .npy, then exit
save_name = "Pistol"    # class name used for stand-alone testing and .npy dumps
seg_class_no = 4        # number of part labels (test_net_new builds GRNet(cfg, 4))


def train_net_new(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Set up data loaders
    pnum = 2048            # points per complete cloud
    crop_point_num = 512   # points removed to create the partial input
    workers = 1
    batchSize = 16
    class_name = "Pistol"
    train_dataset_loader = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=False,
        class_choice=class_name,
        npoints=pnum,
        split='train')
    train_data_loader = torch.utils.data.DataLoader(train_dataset_loader,
                                                    batch_size=batchSize,
                                                    shuffle=True,
                                                    num_workers=int(workers))
    test_dataset_loader = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=False,
        class_choice=class_name,
        npoints=pnum,
        split='test')
    val_data_loader = torch.utils.data.DataLoader(test_dataset_loader,
                                                  batch_size=batchSize,
                                                  shuffle=True,
                                                  num_workers=int(workers))

    # Set up folders for logs and checkpoints
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', datetime.now().isoformat())
    cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints'
    cfg.DIR.LOGS = output_dir % 'logs'
    if not os.path.exists(cfg.DIR.CHECKPOINTS):
        os.makedirs(cfg.DIR.CHECKPOINTS)

    # Create tensorboard writers
    train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train'))
    val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test'))

    # Create the network
    grnet = GRNet(cfg, seg_class_no)
    grnet.apply(utils.helpers.init_weights)
    logging.debug('Parameters in GRNet: %d.' % utils.helpers.count_parameters(grnet))

    # Move the network to GPU if possible
    grnet = grnet.to(device)

    # Create the optimizer and learning-rate scheduler
    grnet_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, grnet.parameters()),
                                       lr=cfg.TRAIN.LEARNING_RATE,
                                       weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                                       betas=cfg.TRAIN.BETAS)
    grnet_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        grnet_optimizer, milestones=cfg.TRAIN.LR_MILESTONES, gamma=cfg.TRAIN.GAMMA)

    # Set up loss functions
    chamfer_dist = ChamferDistance()
    gridding_loss = GriddingLoss(
        scales=cfg.NETWORK.GRIDDING_LOSS_SCALES,
        alphas=cfg.NETWORK.GRIDDING_LOSS_ALPHAS)
    seg_criterion = torch.nn.CrossEntropyLoss().to(device)

    # Load pretrained model if it exists
    init_epoch = 0
    best_metrics = None
    if 'WEIGHTS' in cfg.CONST:
        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        grnet.load_state_dict(checkpoint['grnet'])
        logging.info('Recover complete. Current epoch = #%d; best metrics = %s.'
                     % (init_epoch, best_metrics))

    train_seg_on_sparse = False
    train_seg_on_dense = False
    miou = 0

    # Train/validate the network
    for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1):
        epoch_start_time = time()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter(['SparseLoss', 'DenseLoss'])

        grnet.train()

        # Start supervising the sparse/dense segmentation heads a few
        # epochs into training
        if epoch_idx == 5:
            train_seg_on_sparse = True
        if epoch_idx == 7:
            train_seg_on_dense = True

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (data, seg, model_ids) in enumerate(train_data_loader):
            data_time.update(time() - batch_end_time)
            input_cropped1 = torch.FloatTensor(data.size()[0], pnum, 3)
            input_cropped1 = input_cropped1.data.copy_(data)
            data = data.to(device)
            seg = seg.to(device)
            input_cropped1 = input_cropped1.to(device)

            # Make the input incomplete: zero out the crop_point_num points
            # closest to a randomly chosen viewpoint
            choice = [
                torch.Tensor([1, 0, 0]),
                torch.Tensor([0, 0, 1]),
                torch.Tensor([1, 0, 1]),
                torch.Tensor([-1, 0, 0]),
                torch.Tensor([-1, 1, 0])
            ]
            for m in range(data.size()[0]):
                index = random.sample(choice, 1)
                p_center = index[0].to(device)
                distances = torch.sum((data[m] - p_center) ** 2, dim=1)
                order = torch.argsort(distances)
                zero_point = torch.FloatTensor([0, 0, 0]).to(device)
                input_cropped1.data[m, order[:crop_point_num]] = zero_point

            if save_crop_mode:
                np.save(class_name + "_orig", data[0].detach().cpu().numpy())
                np.save(class_name + "_cropped", input_cropped1[0].detach().cpu().numpy())
                sys.exit()

            sparse_ptcloud, dense_ptcloud, sparse_seg, full_seg, dense_seg = grnet(input_cropped1)

            # Segmentation loss on the completed cloud, evaluated at the
            # ground-truth points
            data_seg = get_data_seg(data, full_seg)
            seg_loss = seg_criterion(torch.transpose(data_seg, 1, 2), seg)
            if train_seg_on_sparse and train_seg:
                gt_seg = get_seg_gts(seg, data, sparse_ptcloud)
                seg_loss += seg_criterion(torch.transpose(sparse_seg, 1, 2), gt_seg)
                seg_loss /= 2
            if train_seg_on_dense and train_seg:
                gt_seg = get_seg_gts(seg, data, dense_ptcloud)
                dense_seg_loss = seg_criterion(torch.transpose(dense_seg, 1, 2), gt_seg)

            if draw_mode:
                plot_ptcloud(data[0], seg[0], "orig")
                plot_ptcloud(input_cropped1[0], seg[0], "cropped")
                plot_ptcloud(sparse_ptcloud[0], torch.argmax(sparse_seg[0], dim=1), "sparse_pred")
                if not train_seg_on_sparse:
                    gt_seg = get_seg_gts(seg, data, sparse_ptcloud)
                # plot_ptcloud(sparse_ptcloud[0], gt_seg[0], "sparse_gt")
                plot_ptcloud(dense_ptcloud[0], torch.argmax(dense_seg[0], dim=1), "dense_pred")
                sys.exit()

            # Weighted sum of the completion (Chamfer + gridding) and
            # segmentation losses
            lamb = 0.8
            sparse_loss = chamfer_dist(sparse_ptcloud, data)
            dense_loss = chamfer_dist(dense_ptcloud, data)
            grid_loss = gridding_loss(sparse_ptcloud, data)
            if train_seg:
                _loss = lamb * (sparse_loss + dense_loss + grid_loss) + (1 - lamb) * seg_loss
            else:
                _loss = sparse_loss + dense_loss + grid_loss
            if train_seg_on_dense and train_seg:
                _loss += (1 - lamb) * dense_seg_loss
            losses.update([sparse_loss.item() * 1000, dense_loss.item() * 1000])

            grnet.zero_grad()
            _loss.backward()
            grnet_optimizer.step()

            n_itr = (epoch_idx - 1) * n_batches + batch_idx
            train_writer.add_scalar('Loss/Batch/Sparse', sparse_loss.item() * 1000, n_itr)
            train_writer.add_scalar('Loss/Batch/Dense', dense_loss.item() * 1000, n_itr)

            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            logging.info(
                '[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Losses = %s'
                % (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches, batch_time.val(),
                   data_time.val(), ['%.4f' % l for l in losses.val()]))

        # Validate the current model
        if train_seg:
            miou_new = test_net_new(cfg, epoch_idx, val_data_loader, val_writer, grnet)
        else:
            miou_new = 0

        grnet_lr_scheduler.step()

        epoch_end_time = time()
        train_writer.add_scalar('Loss/Epoch/Sparse', losses.avg(0), epoch_idx)
        train_writer.add_scalar('Loss/Epoch/Dense', losses.avg(1), epoch_idx)
        logging.info('[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s'
                     % (epoch_idx, cfg.TRAIN.N_EPOCHS, epoch_end_time - epoch_start_time,
                        ['%.4f' % l for l in losses.avg()]))

        # Save a checkpoint whenever the validation mIoU improves (or every
        # epoch when the segmentation heads are disabled)
        if not train_seg or miou_new > miou:
            file_name = class_name + 'noseg-ckpt-epoch.pth'
            output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name)
            torch.save({
                'epoch_index': epoch_idx,
                'grnet': grnet.state_dict()
            }, output_path)
            logging.info('Saved checkpoint to %s ...' % output_path)
            miou = miou_new

    train_writer.close()
    val_writer.close()
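

# `get_seg_gts` is called above but not defined in this file. Below is a
# minimal sketch, assuming it transfers the ground-truth part labels of the
# complete cloud onto a predicted cloud by nearest-neighbour lookup; the
# original helper may differ.
def get_seg_gts(seg, data, ptcloud):
    """Label each predicted point with the part label of its nearest
    ground-truth point.

    seg:     (B, N) integer part labels for `data`
    data:    (B, N, 3) complete ground-truth clouds
    ptcloud: (B, M, 3) predicted clouds
    returns: (B, M) integer part labels for `ptcloud`
    """
    dists = torch.cdist(ptcloud, data)    # (B, M, N) pairwise distances
    nearest = torch.argmin(dists, dim=2)  # (B, M) index of the nearest GT point
    return torch.gather(seg, 1, nearest)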
def test_net_new(cfg, epoch_idx=-1, test_data_loader=None, test_writer=None, grnet=None):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    pnum = 2048
    crop_point_num = 512
    workers = 1
    batchSize = 16

    if test_data_loader is None:
        test_dataset_loader = shapenet_part_loader.PartDataset(
            root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
            classification=False,
            class_choice=save_name,
            npoints=pnum,
            split='test')
        test_data_loader = torch.utils.data.DataLoader(test_dataset_loader,
                                                       batch_size=batchSize,
                                                       shuffle=True,
                                                       num_workers=int(workers))

    # Set up and initialize the network if none was passed in
    if grnet is None:
        grnet = GRNet(cfg, 4)
        if torch.cuda.is_available():
            grnet = grnet.to(device)
        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        grnet.load_state_dict(checkpoint['grnet'])

    # Switch the model to evaluation mode
    grnet.eval()

    # Set up loss functions
    chamfer_dist = ChamferDistance()
    gridding_loss = GriddingLoss(
        scales=cfg.NETWORK.GRIDDING_LOSS_SCALES,
        alphas=cfg.NETWORK.GRIDDING_LOSS_ALPHAS)
    seg_criterion = torch.nn.CrossEntropyLoss().to(device)

    total_sparse_cd = 0
    total_dense_cd = 0
    total_sparse_ce = 0
    total_dense_ce = 0
    total_sparse_miou = 0
    total_dense_miou = 0
    total_sparse_acc = 0
    total_dense_acc = 0

    # Testing loop
    for batch_idx, (data, seg, model_ids) in enumerate(test_data_loader):
        model_id = model_ids[0]

        with torch.no_grad():
            input_cropped1 = torch.FloatTensor(data.size()[0], pnum, 3)
            input_cropped1 = input_cropped1.data.copy_(data)
            data = data.to(device)
            seg = seg.to(device)
            input_cropped1 = input_cropped1.to(device)

            # Make the input incomplete (same cropping scheme as in training)
            choice = [
                torch.Tensor([1, 0, 0]),
                torch.Tensor([0, 0, 1]),
                torch.Tensor([1, 0, 1]),
                torch.Tensor([-1, 0, 0]),
                torch.Tensor([-1, 1, 0])
            ]
            for m in range(data.size()[0]):
                index = random.sample(choice, 1)
                p_center = index[0].to(device)
                distances = torch.sum((data[m] - p_center) ** 2, dim=1)
                order = torch.argsort(distances)
                zero_point = torch.FloatTensor([0, 0, 0]).to(device)
                input_cropped1.data[m, order[:crop_point_num]] = zero_point

            sparse_ptcloud, dense_ptcloud, sparse_seg, full_seg, dense_seg = grnet(input_cropped1)

            if save_mode:
                np.save("./saved_results/original_" + save_name, data.detach().cpu().numpy())
                np.save("./saved_results/original_seg_" + save_name, seg.detach().cpu().numpy())
                np.save("./saved_results/cropped_" + save_name, input_cropped1.detach().cpu().numpy())
                np.save("./saved_results/sparse_" + save_name, sparse_ptcloud.detach().cpu().numpy())
                np.save("./saved_results/sparse_seg_" + save_name, sparse_seg.detach().cpu().numpy())
                np.save("./saved_results/dense_" + save_name, dense_ptcloud.detach().cpu().numpy())
                np.save("./saved_results/dense_seg_" + save_name, dense_seg.detach().cpu().numpy())
                sys.exit()

            total_sparse_cd += chamfer_dist(sparse_ptcloud, data).item()
            total_dense_cd += chamfer_dist(dense_ptcloud, data).item()

            sparse_seg_gt = get_seg_gts(seg, data, sparse_ptcloud)
            sparse_miou, sparse_acc = miou(torch.argmax(sparse_seg, dim=2), sparse_seg_gt)
            total_sparse_miou += sparse_miou
            total_sparse_acc += sparse_acc
            total_sparse_ce += seg_criterion(torch.transpose(sparse_seg, 1, 2), sparse_seg_gt).item()

            dense_seg_gt = get_seg_gts(seg, data, dense_ptcloud)
            dense_miou, dense_acc = miou(torch.argmax(dense_seg, dim=2), dense_seg_gt)
            total_dense_miou += dense_miou
            total_dense_acc += dense_acc
            total_dense_ce += seg_criterion(torch.transpose(dense_seg, 1, 2), dense_seg_gt).item()

    # Report metrics averaged over the test set
    length = len(test_data_loader)
    print("sparse cd: " + str(total_sparse_cd * 1000 / length))
    print("dense cd: " + str(total_dense_cd * 1000 / length))
    print("sparse acc: " + str(total_sparse_acc / length))
    print("dense acc: " + str(total_dense_acc / length))
    print("sparse miou: " + str(total_sparse_miou / length))
    print("dense miou: " + str(total_dense_miou / length))
    print("sparse ce: " + str(total_sparse_ce / length))
    print("dense ce: " + str(total_dense_ce / length))

    return total_dense_miou / length
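

# `miou` and `plot_ptcloud` are used above but not defined in this file.
# Minimal sketches follow, assuming `miou` returns (mean IoU over part labels,
# overall point accuracy) for one batch and `plot_ptcloud` draws a labelled
# scatter plot; the original helpers may differ.
def miou(pred, gt, n_classes=seg_class_no):
    """Mean IoU over part labels plus overall accuracy.

    pred, gt: (B, M) integer label tensors
    returns:  (mean_iou, accuracy) as Python floats
    """
    ious = []
    for c in range(n_classes):
        inter = ((pred == c) & (gt == c)).sum().item()
        union = ((pred == c) | (gt == c)).sum().item()
        if union > 0:  # skip labels absent from both prediction and GT
            ious.append(inter / union)
    acc = (pred == gt).float().mean().item()
    return sum(ious) / max(len(ious), 1), acc


def plot_ptcloud(ptcloud, labels, title):
    """Scatter-plot one point cloud, coloured by part label."""
    import matplotlib.pyplot as plt  # local import: only needed in draw_mode
    pts = ptcloud.detach().cpu().numpy()
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2],
               c=labels.detach().cpu().numpy(), s=2)
    ax.set_title(title)
    plt.show()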