def iterDataset(dataloader, pair_generator, ntg_model, vis, threshold=10, use_cuda=True):
    """Iterate over the dataset in batches and evaluate the registration model.

    For every batch: generate an image pair, predict the affine parameters with
    ``ntg_model``, warp the source image with the predicted grid, accumulate the
    NTG loss, visualize the CNN result, and collect the per-batch grid-point
    loss against the ground-truth transform. Finally print the average grid
    loss, the average NTG value and the correct-registration rate.

    :param dataloader: yields raw image batches
    :param pair_generator: builds {'source_image', 'target_image', 'theta_GT'}
        pairs from a raw batch (image [B,1,W,H], theta_GT [B,2,3])
    :param ntg_model: CNN that predicts affine parameters theta [B,6]
    :param vis: visdom helper used for visualization
    :param threshold: grid-loss threshold below which a registration counts as
        correct (passed to ``compute_correct_rate``)
    :param use_cuda: forwarded to ``GridLoss``
    :return: None (results are printed / visualized)
    """
    loss_fn = NTGLoss()
    gridGen = AffineGridGen()
    grid_loss = GridLoss(use_cuda=use_cuda)
    grid_loss_list = []  # per-batch CNN grid-point losses
    ntg_loss_total = 0

    for batch_idx, batch in enumerate(dataloader):
        # Progress report every 5 batches.
        if batch_idx % 5 == 0:
            print('test batch: [{}/{} ({:.0f}%)]'.format(
                batch_idx, len(dataloader), 100. * batch_idx / len(dataloader)))

        pair_batch = pair_generator(batch)  # image[B,1,w,h], theta_GT[B,2,3]
        theta_estimate_batch = ntg_model(pair_batch)  # theta [B,6]

        source_image_batch = pair_batch['source_image']
        target_image_batch = pair_batch['target_image']
        theta_GT_batch = pair_batch['theta_GT']

        # Warp the source image with the predicted affine grid and measure NTG.
        sampling_grid = gridGen(theta_estimate_batch.view(-1, 2, 3))
        warped_image_batch = F.grid_sample(source_image_batch, sampling_grid)
        loss, _, _ = loss_fn(target_image_batch, warped_image_batch)
        ntg_loss_total += loss.item()

        # Show the CNN registration result.
        visualize_cnn_result(source_image_batch, target_image_batch,
                             theta_estimate_batch, vis)

        # Grid-point registration error of the CNN estimate vs. ground truth.
        loss_cnn = grid_loss.compute_grid_loss(theta_estimate_batch, theta_GT_batch)
        grid_loss_list.append(loss_cnn.detach().cpu())

    print("计算平均网格点损失")
    compute_average_grid_loss(grid_loss_list)
    print("计算平均NTG值", ntg_loss_total / len(dataloader))
    print("计算正确率")
    compute_correct_rate(grid_loss_list, threshold=threshold)
def main(args):
    """Distributed training entry point.

    Resumes from a hard-coded checkpoint, builds the VOC2011 dataset and a
    (possibly distributed) dataloader, wraps the model in DDP when
    ``args.distributed`` is set, then trains for ``args.num_epochs`` epochs,
    checkpointing (on the main process only) after every epoch.
    """
    # Alternative checkpoints / dataset paths kept for reference:
    # checkpoint_path = "/home/zale/project/registration_cnn_ntg/trained_weight/voc2011_multi_gpu/checkpoint_voc2011_multi_gpu_paper_NTG_resnet101.pth.tar"
    # checkpoint_path = "/home/zale/project/registration_cnn_ntg/trained_weight/coco2017_multi_gpu/checkpoint_coco2017_multi_gpu_paper30_NTG_resnet101.pth.tar"
    # args.training_image_path = '/home/zale/datasets/vocdata/VOC_train_2011/VOCdevkit/VOC2011/JPEGImages'
    # args.training_image_path = '/media/disk2/zale/datasets/coco2017/train2017'
    checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011_multi_gpu/checkpoint_voc2011_multi_gpu_three_channel_paper_origin_NTG_resnet101.pth.tar"
    args.training_image_path = '/home/zlk/datasets/vocdata/VOC_train_2011/VOCdevkit/VOC2011/JPEGImages'

    random_seed = 10021
    # Seed is randomized per run on purpose (base seed + random offset).
    init_seeds(random_seed + random.randint(0, 10000))
    mixed_precision = True

    utils.init_distributed_mode(args)
    print(args)
    # device,local_rank = torch_util.select_device(multi_process=True, apex=mixed_precision)
    device = torch.device(args.device)
    use_cuda = True

    # Data loading code
    print("Loading data")
    RandomTnsDataset = RandomTnsData(args.training_image_path, cache_images=False,
                                     paper_affine_generator=True,
                                     transform=NormalizeImageDict(["image"]))
    # train_dataloader = DataLoader(RandomTnsDataset, batch_size=args.batch_size, shuffle=True, num_workers=4,
    #                               pin_memory=True)
    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            RandomTnsDataset)
        # test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(RandomTnsDataset)
        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    # train_batch_sampler = torch.utils.data.BatchSampler(
    #     train_sampler, args.batch_size, drop_last=True)

    # NOTE(review): train_sampler is never None on either branch, so
    # shuffle=(train_sampler is None) is always False — shuffling is delegated
    # entirely to the sampler.
    data_loader = DataLoader(RandomTnsDataset, sampler=train_sampler, num_workers=4,
                             shuffle=(train_sampler is None), pin_memory=False,
                             batch_size=args.batch_size)

    # data_loader_test = torch.utils.data.DataLoader(
    #     dataset_test, batch_size=1,
    #     sampler=test_sampler, num_workers=args.workers,
    #     collate_fn=utils.collate_fn)

    print("Creating model")
    model = CNNRegistration(use_cuda=use_cuda)

    model.to(device)

    # Optimizer and scheduler.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=args.lr)

    # NTG loss decreases very slowly once lr < 1e-6, so 1e-6 is the minimum lr.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.lr_max_iter, eta_min=1e-6)

    # if mixed_precision:
    #     model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

    # Keep an unwrapped handle for checkpointing before DDP wrapping.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    minium_loss, saved_epoch = load_checkpoint(model_without_ddp, optimizer,
                                               lr_scheduler, checkpoint_path,
                                               args.rank)

    vis_env = "multi_gpu_rgb_train_paper_30"
    loss = NTGLoss()
    pair_generator = RandomTnsPair(use_cuda=use_cuda)
    gridGen = AffineGridGen()
    vis = VisdomHelper(env_name=vis_env)

    print('Starting training...')

    start_time = time.time()
    draw_test_loss = False
    log_interval = 20
    for epoch in range(saved_epoch, args.num_epochs):
        start_time = time.time()

        if args.distributed:
            # Required so DistributedSampler reshuffles each epoch.
            train_sampler.set_epoch(epoch)

        train_loss = train(epoch, model, loss, optimizer, data_loader,
                           pair_generator, gridGen, vis, use_cuda=use_cuda,
                           log_interval=log_interval, lr_scheduler=lr_scheduler,
                           rank=args.rank)

        if draw_test_loss:
            # test_loss = test(model, loss, test_dataloader, pair_generator, gridGen, use_cuda=use_cuda)
            # vis.drawBothLoss(epoch, train_loss, test_loss, 'loss_table')
            pass
        else:
            vis.drawLoss(epoch, train_loss)

        end_time = time.time()
        print("epoch:", str(end_time - start_time), '秒')

        is_best = train_loss < minium_loss
        minium_loss = min(train_loss, minium_loss)

        state_dict = model_without_ddp.state_dict()
        # Only the main process writes checkpoints.
        if is_main_process():
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'args': args,
                    # 'state_dict': model.state_dict(),
                    'state_dict': state_dict,
                    'minium_loss': minium_loss,
                    'model_loss': train_loss,
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                }, is_best, checkpoint_path)
def start_train(training_path, test_image_path, load_from, out_path, vis_env,
                paper_affine_generator=False, random_seed=666, log_interval=100,
                multi_gpu=True, use_cuda=True):
    """Single-node (optionally DataParallel) training loop.

    Builds the model, resumes from ``load_from``, trains on ``training_path``
    for ``args.num_epochs`` epochs, draws losses into the visdom env
    ``vis_env`` and checkpoints to ``out_path`` after every epoch.

    NOTE(review): ``mixed_precision`` and ``args`` are free names here —
    presumably module-level globals; confirm they are defined at import time.

    :param training_path: directory of training images
    :param test_image_path: directory of test images (only used when
        ``draw_test_loss`` is enabled)
    :param load_from: checkpoint path to resume from
    :param out_path: checkpoint output path
    :param vis_env: visdom environment name
    :param paper_affine_generator: forwarded to ``RandomTnsData``
    :param random_seed: base seed (a random offset is added per run)
    :param log_interval: forwarded to ``train``
    :param multi_gpu: wrap the model in ``nn.DataParallel``
    :param use_cuda: forwarded to model and pair generator
    """
    init_seeds(random_seed + random.randint(0, 10000))

    device, local_rank = torch_util.select_device(multi_process=multi_gpu,
                                                  apex=mixed_precision)

    # args.batch_size = args.batch_size * torch.cuda.device_count()
    args.batch_size = 16
    args.lr_scheduler = True
    draw_test_loss = False
    print(args.batch_size)

    print("创建模型中")
    model = CNNRegistration(use_cuda=use_cuda)
    model = model.to(device)

    # Optimizer and scheduler: only the regression head is optimized here.
    optimizer = optim.Adam(model.FeatureRegression.parameters(), lr=args.lr)
    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.lr_max_iter, eta_min=1e-7)
    else:
        scheduler = False

    print("加载权重")
    minium_loss, saved_epoch = load_checkpoint(model, optimizer, load_from, 0)

    # Mixed precision training https://github.com/NVIDIA/apex
    # BUG FIX: opt_level must be the letter-O level 'O1', not the string '01'
    # (zero-one) — apex rejects '01' with "Unexpected optimization level".
    if mixed_precision:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1',
                                          verbosity=0)

    if multi_gpu:
        model = nn.DataParallel(model)

    loss = NTGLoss()
    pair_generator = RandomTnsPair(use_cuda=use_cuda)
    gridGen = AffineGridGen()
    vis = VisdomHelper(env_name=vis_env)

    print("创建dataloader")
    RandomTnsDataset = RandomTnsData(training_path, cache_images=False,
                                     paper_affine_generator=paper_affine_generator,
                                     transform=NormalizeImageDict(["image"]))
    train_dataloader = DataLoader(RandomTnsDataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=4, pin_memory=True)

    if draw_test_loss:
        testDataset = RandomTnsData(test_image_path, cache_images=False,
                                    paper_affine_generator=paper_affine_generator,
                                    transform=NormalizeImageDict(["image"]))
        test_dataloader = DataLoader(testDataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=4,
                                     pin_memory=False)

    print('Starting training...')

    for epoch in range(saved_epoch, args.num_epochs):
        start_time = time.time()

        train_loss = train(epoch, model, loss, optimizer, train_dataloader,
                           pair_generator, gridGen, vis, use_cuda=use_cuda,
                           log_interval=log_interval, scheduler=scheduler)

        if draw_test_loss:
            test_loss = test(model, loss, test_dataloader, pair_generator,
                             gridGen, use_cuda=use_cuda)
            vis.drawBothLoss(epoch, train_loss, test_loss, 'loss_table')
        else:
            vis.drawLoss(epoch, train_loss)

        end_time = time.time()
        print("epoch:", str(end_time - start_time), '秒')

        is_best = train_loss < minium_loss
        minium_loss = min(train_loss, minium_loss)

        # DataParallel wraps the real model in .module; unwrap for saving.
        state_dict = model.module.state_dict() if multi_gpu else model.state_dict()
        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            # 'state_dict': model.state_dict(),
            'state_dict': state_dict,
            'minium_loss': minium_loss,
            'model_loss': train_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best, out_path)