args.feature_extraction_cnn + '.pth.tar') else: checkpoint_name = os.path.join( args.trained_models_dir, args.trained_models_fn + '_' + args.geometric_model + '_grid_loss' + args.feature_extraction_cnn + '.pth.tar') best_test_loss = float("inf") print('Starting training...') for epoch in range(1, args.num_epochs + 1): train_loss = train(epoch, model, loss, optimizer, dataloader, pair_generation_tnf, log_interval=100) test_loss = test(model, loss, dataloader_test, pair_generation_tnf) # remember best loss is_best = test_loss < best_test_loss best_test_loss = min(test_loss, best_test_loss) save_checkpoint( { 'epoch': epoch + 1, 'args': args, 'state_dict': model.state_dict(), 'best_test_loss': best_test_loss, 'optimizer': optimizer.state_dict(),
print("\t\t......Train config......") print("\t\t CNN model: ", args.feature_extraction_cnn) print("\t\t Geometric model: ", args.geometric_model) print("\t\t Dataset: ", args.training_dataset) print() print("\t\t Learning rate: ", args.lr) print("\t\t Batch size: ", args.batch_size) print("\t\t Maximum epoch: ", args.num_epochs) print("# ===================================== #\n") for epoch in range(1, args.num_epochs + 1): train_loss = train(epoch=epoch, cost=cost, optimizer=optimizer, dataset=dataset, pair_generation_tnf=pair_generation_tnf, sess=sess, batch_size=args.batch_size, source_train=source_train, target_train=target_train, theta_GT=theta_GT) # Save checkpoint if args.use_mse_loss: checkpoint_name = join( args.trained_models_dir, args.trained_models_fn + '_' + args.geometric_model + '_mse_loss_' + args.feature_extraction_cnn + '_' + args.training_dataset + '_epoch_' + str(epoch) + '.ckpt') else: checkpoint_name = join( args.trained_models_dir, args.trained_models_fn + '_' + args.geometric_model +
def main(): args, arg_groups = ArgumentParser(mode='train').parse() print(args) use_cuda = torch.cuda.is_available() device = torch.device('cuda') if use_cuda else torch.device('cpu') # Seed torch.manual_seed(args.seed) if use_cuda: torch.cuda.manual_seed(args.seed) # Download dataset if needed and set paths if args.training_dataset == 'pascal': if args.dataset_image_path == '' and not os.path.exists( 'datasets/pascal-voc11/TrainVal'): download_pascal('datasets/pascal-voc11/') if args.dataset_image_path == '': args.dataset_image_path = 'datasets/pascal-voc11/' args.dataset_csv_path = 'training_data/pascal-random' # CNN model and loss print('Creating CNN model...') if args.geometric_model == 'affine': cnn_output_dim = 6 elif args.geometric_model == 'hom' and args.four_point_hom: cnn_output_dim = 8 elif args.geometric_model == 'hom' and not args.four_point_hom: cnn_output_dim = 9 elif args.geometric_model == 'tps': cnn_output_dim = 18 model = CNNGeometric(use_cuda=use_cuda, output_dim=cnn_output_dim, **arg_groups['model']) if args.geometric_model == 'hom' and not args.four_point_hom: init_theta = torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1], device=device) model.FeatureRegression.linear.bias.data += init_theta if args.geometric_model == 'hom' and args.four_point_hom: init_theta = torch.tensor([-1, -1, 1, 1, -1, 1, -1, 1], device=device) model.FeatureRegression.linear.bias.data += init_theta if args.use_mse_loss: print('Using MSE loss...') loss = nn.MSELoss() else: print('Using grid loss...') loss = TransformedGridLoss(use_cuda=use_cuda, geometric_model=args.geometric_model) # Initialize Dataset objects dataset = SynthDataset(geometric_model=args.geometric_model, dataset_csv_path=args.dataset_csv_path, dataset_csv_file='train.csv', dataset_image_path=args.dataset_image_path, transform=NormalizeImageDict(['image']), random_sample=args.random_sample) dataset_val = SynthDataset(geometric_model=args.geometric_model, dataset_csv_path=args.dataset_csv_path, dataset_csv_file='val.csv', dataset_image_path=args.dataset_image_path, transform=NormalizeImageDict(['image']), random_sample=args.random_sample) # Set Tnf pair generation func pair_generation_tnf = SynthPairTnf(geometric_model=args.geometric_model, use_cuda=use_cuda) # Initialize DataLoaders dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=True, num_workers=4) # Optimizer and eventual scheduler optimizer = optim.Adam(model.FeatureRegression.parameters(), lr=args.lr) if args.lr_scheduler: scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=args.lr_max_iter, eta_min=1e-6) # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') else: scheduler = False # Train # Set up names for checkpoints if args.use_mse_loss: ckpt = args.trained_model_fn + '_' + args.geometric_model + '_mse_loss' + args.feature_extraction_cnn checkpoint_path = os.path.join(args.trained_model_dir, args.trained_model_fn, ckpt + '.pth.tar') else: ckpt = args.trained_model_fn + '_' + args.geometric_model + '_grid_loss' + args.feature_extraction_cnn checkpoint_path = os.path.join(args.trained_model_dir, args.trained_model_fn, ckpt + '.pth.tar') if not os.path.exists(args.trained_model_dir): os.mkdir(args.trained_model_dir) # Set up TensorBoard writer if not args.log_dir: tb_dir = os.path.join(args.trained_model_dir, args.trained_model_fn + '_tb_logs') else: tb_dir = os.path.join(args.log_dir, args.trained_model_fn + '_tb_logs') logs_writer = SummaryWriter(tb_dir) # add graph, to do so we have to generate a dummy input to pass along with the graph dummy_input = { 'source_image': torch.rand([args.batch_size, 3, 240, 240], device=device), 'target_image': torch.rand([args.batch_size, 3, 240, 240], device=device), 'theta_GT': torch.rand([16, 2, 3], device=device) } logs_writer.add_graph(model, dummy_input) # Start of training print('Starting training...') best_val_loss = float("inf") for epoch in range(1, args.num_epochs + 1): # we don't need the average epoch loss so we assign it to _ _ = train(epoch, model, loss, optimizer, dataloader, pair_generation_tnf, log_interval=args.log_interval, scheduler=scheduler, tb_writer=logs_writer) val_loss = validate_model(model, loss, dataloader_val, pair_generation_tnf, epoch, logs_writer) # remember best loss is_best = val_loss < best_val_loss best_val_loss = min(val_loss, best_val_loss) save_checkpoint( { 'epoch': epoch + 1, 'args': args, 'state_dict': model.state_dict(), 'best_val_loss': best_val_loss, 'optimizer': optimizer.state_dict(), }, is_best, checkpoint_path) logs_writer.close() print('Done!')
def main(args): # checkpoint_path = "/home/zale/project/registration_cnn_ntg/trained_weight/voc2011_multi_gpu/checkpoint_voc2011_multi_gpu_paper_NTG_resnet101.pth.tar" # checkpoint_path = "/home/zale/project/registration_cnn_ntg/trained_weight/coco2017_multi_gpu/checkpoint_coco2017_multi_gpu_paper30_NTG_resnet101.pth.tar" #args.training_image_path = '/home/zale/datasets/vocdata/VOC_train_2011/VOCdevkit/VOC2011/JPEGImages' # args.training_image_path = '/media/disk2/zale/datasets/coco2017/train2017' checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011_multi_gpu/checkpoint_voc2011_multi_gpu_three_channel_paper_origin_NTG_resnet101.pth.tar" args.training_image_path = '/home/zlk/datasets/vocdata/VOC_train_2011/VOCdevkit/VOC2011/JPEGImages' random_seed = 10021 init_seeds(random_seed + random.randint(0, 10000)) mixed_precision = True utils.init_distributed_mode(args) print(args) #device,local_rank = torch_util.select_device(multi_process =True,apex=mixed_precision) device = torch.device(args.device) use_cuda = True # Data loading code print("Loading data") RandomTnsDataset = RandomTnsData(args.training_image_path, cache_images=False, paper_affine_generator=True, transform=NormalizeImageDict(["image"])) # train_dataloader = DataLoader(RandomTnsDataset, batch_size=args.batch_size, shuffle=True, num_workers=4, # pin_memory=True) print("Creating data loaders") if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler( RandomTnsDataset) # test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) else: train_sampler = torch.utils.data.RandomSampler(RandomTnsDataset) # test_sampler = torch.utils.data.SequentialSampler(dataset_test) # train_batch_sampler = torch.utils.data.BatchSampler( # train_sampler, args.batch_size, drop_last=True) data_loader = DataLoader(RandomTnsDataset, sampler=train_sampler, num_workers=4, shuffle=(train_sampler is None), pin_memory=False, batch_size=args.batch_size) # data_loader_test = torch.utils.data.DataLoader( # dataset_test, batch_size=1, # sampler=test_sampler, num_workers=args.workers, # collate_fn=utils.collate_fn) print("Creating model") model = CNNRegistration(use_cuda=use_cuda) model.to(device) # 优化器 和scheduler params = [p for p in model.parameters() if p.requires_grad] optimizer = torch.optim.Adam(params, lr=args.lr) # 学习率小于1e-6 ntg损失下降很慢,所以最小设置为1e-6 lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=args.lr_max_iter, eta_min=1e-6) # if mixed_precision: # model,optimizer = amp.initialize(model,optimizer,opt_level='O1',verbosity=0) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module minium_loss, saved_epoch = load_checkpoint(model_without_ddp, optimizer, lr_scheduler, checkpoint_path, args.rank) vis_env = "multi_gpu_rgb_train_paper_30" loss = NTGLoss() pair_generator = RandomTnsPair(use_cuda=use_cuda) gridGen = AffineGridGen() vis = VisdomHelper(env_name=vis_env) print('Starting training...') start_time = time.time() draw_test_loss = False log_interval = 20 for epoch in range(saved_epoch, args.num_epochs): start_time = time.time() if args.distributed: train_sampler.set_epoch(epoch) train_loss = train(epoch, model, loss, optimizer, data_loader, pair_generator, gridGen, vis, use_cuda=use_cuda, log_interval=log_interval, lr_scheduler=lr_scheduler, rank=args.rank) if draw_test_loss: #test_loss = test(model,loss,test_dataloader,pair_generator,gridGen,use_cuda=use_cuda) #vis.drawBothLoss(epoch,train_loss,test_loss,'loss_table') pass else: vis.drawLoss(epoch, train_loss) end_time = time.time() print("epoch:", str(end_time - start_time), '秒') is_best = train_loss < minium_loss minium_loss = min(train_loss, minium_loss) state_dict = model_without_ddp.state_dict() if is_main_process(): save_checkpoint( { 'epoch': epoch + 1, 'args': args, # 'state_dict': model.state_dict(), 'state_dict': state_dict, 'minium_loss': minium_loss, 'model_loss': train_loss, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), }, is_best, checkpoint_path)
def start_train(training_path,test_image_path,load_from,out_path,vis_env,paper_affine_generator = False, random_seed=666,log_interval=100,multi_gpu=True,use_cuda=True): init_seeds(random_seed+random.randint(0,10000)) device,local_rank = torch_util.select_device(multi_process =multi_gpu,apex=mixed_precision) # args.batch_size = args.batch_size * torch.cuda.device_count() args.batch_size = 16 args.lr_scheduler = True draw_test_loss = False print(args.batch_size) print("创建模型中") model = CNNRegistration(use_cuda=use_cuda) model = model.to(device) # 优化器 和scheduler optimizer = optim.Adam(model.FeatureRegression.parameters(), lr=args.lr) if args.lr_scheduler: scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.lr_max_iter, eta_min=1e-7) else: scheduler = False print("加载权重") minium_loss,saved_epoch= load_checkpoint(model,optimizer,load_from,0) # Mixed precision training https://github.com/NVIDIA/apex if mixed_precision: model,optimizer = amp.initialize(model,optimizer,opt_level='01',verbosity=0) if multi_gpu: model = nn.DataParallel(model) loss = NTGLoss() pair_generator = RandomTnsPair(use_cuda=use_cuda) gridGen = AffineGridGen() vis = VisdomHelper(env_name=vis_env) print("创建dataloader") RandomTnsDataset = RandomTnsData(training_path, cache_images=False,paper_affine_generator = paper_affine_generator, transform=NormalizeImageDict(["image"])) train_dataloader = DataLoader(RandomTnsDataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True) if draw_test_loss: testDataset = RandomTnsData(test_image_path, cache_images=False, paper_affine_generator=paper_affine_generator, transform=NormalizeImageDict(["image"])) test_dataloader = DataLoader(testDataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=False) print('Starting training...') for epoch in range(saved_epoch, args.num_epochs): start_time = time.time() train_loss = train(epoch, model, loss, optimizer, train_dataloader, pair_generator, gridGen, vis, use_cuda=use_cuda, log_interval=log_interval,scheduler = scheduler) if draw_test_loss: test_loss = test(model,loss,test_dataloader,pair_generator,gridGen,use_cuda=use_cuda) vis.drawBothLoss(epoch,train_loss,test_loss,'loss_table') else: vis.drawLoss(epoch,train_loss) end_time = time.time() print("epoch:", str(end_time - start_time),'秒') is_best = train_loss < minium_loss minium_loss = min(train_loss, minium_loss) state_dict = model.module.state_dict() if multi_gpu else model.state_dict() save_checkpoint({ 'epoch': epoch + 1, 'args': args, #'state_dict': model.state_dict(), 'state_dict': state_dict, 'minium_loss': minium_loss, 'model_loss':train_loss, 'optimizer': optimizer.state_dict(), }, is_best, out_path)
def main(): args, arg_groups = ArgumentParser(mode='train').parse() print(args) use_cuda = torch.cuda.is_available() use_me = args.use_me device = torch.device('cuda') if use_cuda else torch.device('cpu') # Seed # torch.manual_seed(args.seed) # if use_cuda: # torch.cuda.manual_seed(args.seed) # CNN model and loss print('Creating CNN model...') if args.geometric_model == 'affine_simple': cnn_output_dim = 3 elif args.geometric_model == 'affine_simple_4': cnn_output_dim = 4 else: raise NotImplementedError('Specified geometric model is unsupported') model = CNNGeometric(use_cuda=use_cuda, output_dim=cnn_output_dim, **arg_groups['model']) if args.geometric_model == 'affine_simple': init_theta = torch.tensor([0.0, 1.0, 0.0], device=device) elif args.geometric_model == 'affine_simple_4': init_theta = torch.tensor([0.0, 1.0, 0.0, 0.0], device=device) try: model.FeatureRegression.linear.bias.data += init_theta except: model.FeatureRegression.resnet.fc.bias.data += init_theta args.load_images = False if args.loss == 'mse': print('Using MSE loss...') loss = nn.MSELoss() elif args.loss == 'weighted_mse': print('Using weighted MSE loss...') loss = WeightedMSELoss(use_cuda=use_cuda) elif args.loss == 'reconstruction': print('Using reconstruction loss...') loss = ReconstructionLoss( int(np.rint(args.input_width * (1 - args.crop_factor) / 16) * 16), int(np.rint(args.input_height * (1 - args.crop_factor) / 16) * 16), args.input_height, use_cuda=use_cuda) args.load_images = True elif args.loss == 'combined': print('Using combined loss...') loss = CombinedLoss(args, use_cuda=use_cuda) if args.use_reconstruction_loss: args.load_images = True elif args.loss == 'grid': print('Using grid loss...') loss = SequentialGridLoss(use_cuda=use_cuda) else: raise NotImplementedError('Specifyed loss %s is not supported' % args.loss) # Initialize Dataset objects if use_me: dataset = MEDataset(geometric_model=args.geometric_model, dataset_csv_path=args.dataset_csv_path, dataset_csv_file='train.csv', dataset_image_path=args.dataset_image_path, input_height=args.input_height, input_width=args.input_width, crop=args.crop_factor, use_conf=args.use_conf, use_random_patch=args.use_random_patch, normalize_inputs=args.normalize_inputs, random_sample=args.random_sample, load_images=args.load_images) dataset_val = MEDataset(geometric_model=args.geometric_model, dataset_csv_path=args.dataset_csv_path, dataset_csv_file='val.csv', dataset_image_path=args.dataset_image_path, input_height=args.input_height, input_width=args.input_width, crop=args.crop_factor, use_conf=args.use_conf, use_random_patch=args.use_random_patch, normalize_inputs=args.normalize_inputs, random_sample=args.random_sample, load_images=args.load_images) else: dataset = SynthDataset(geometric_model=args.geometric_model, dataset_csv_path=args.dataset_csv_path, dataset_csv_file='train.csv', dataset_image_path=args.dataset_image_path, transform=NormalizeImageDict(['image']), random_sample=args.random_sample) dataset_val = SynthDataset(geometric_model=args.geometric_model, dataset_csv_path=args.dataset_csv_path, dataset_csv_file='val.csv', dataset_image_path=args.dataset_image_path, transform=NormalizeImageDict(['image']), random_sample=args.random_sample) # Set Tnf pair generation func if use_me: pair_generation_tnf = BatchTensorToVars(use_cuda=use_cuda) elif args.geometric_model == 'affine_simple' or args.geometric_model == 'affine_simple_4': pair_generation_tnf = SynthPairTnf(geometric_model='affine', use_cuda=use_cuda) else: raise NotImplementedError('Specified geometric model is unsupported') # Initialize DataLoaders dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=True, num_workers=4) # Optimizer optimizer = optim.Adam(model.FeatureRegression.parameters(), lr=args.lr) # Train # Set up names for checkpoints ckpt = args.trained_model_fn + '_' + args.geometric_model + '_' + args.loss + '_loss_' checkpoint_path = os.path.join(args.trained_model_dir, args.trained_model_fn, ckpt + '.pth.tar') if not os.path.exists(args.trained_model_dir): os.mkdir(args.trained_model_dir) # Set up TensorBoard writer if not args.log_dir: tb_dir = os.path.join(args.trained_model_dir, args.trained_model_fn + '_tb_logs') else: tb_dir = os.path.join(args.log_dir, args.trained_model_fn + '_tb_logs') logs_writer = SummaryWriter(tb_dir) # add graph, to do so we have to generate a dummy input to pass along with the graph if use_me: dummy_input = { 'mv_L2R': torch.rand([args.batch_size, 2, 216, 384], device=device), 'mv_R2L': torch.rand([args.batch_size, 2, 216, 384], device=device), 'grid_L2R': torch.rand([args.batch_size, 2, 216, 384], device=device), 'grid_R2L': torch.rand([args.batch_size, 2, 216, 384], device=device), 'grid': torch.rand([args.batch_size, 2, 216, 384], device=device), 'conf_L': torch.rand([args.batch_size, 1, 216, 384], device=device), 'conf_R': torch.rand([args.batch_size, 1, 216, 384], device=device), 'theta_GT': torch.rand([args.batch_size, 4], device=device), } if args.load_images: dummy_input['img_R_orig'] = torch.rand( [args.batch_size, 1, 216, 384], device=device) dummy_input['img_R'] = torch.rand([args.batch_size, 1, 216, 384], device=device) else: dummy_input = { 'source_image': torch.rand([args.batch_size, 3, 240, 240], device=device), 'target_image': torch.rand([args.batch_size, 3, 240, 240], device=device), 'theta_GT': torch.rand([args.batch_size, 2, 3], device=device) } logs_writer.add_graph(model, dummy_input) # Start of training print('Starting training...') best_val_loss = float("inf") max_batch_iters = len(dataloader) print('Iterations for one epoch:', max_batch_iters) epoch_to_change_lr = int(args.lr_max_iter / max_batch_iters * 2 + 0.5) # Loading checkpoint model, optimizer, start_epoch, best_val_loss, last_epoch = load_checkpoint( checkpoint_path, model, optimizer, device) # Scheduler if args.lr_scheduler == 'cosine': is_cosine_scheduler = True scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=args.lr_max_iter, eta_min=1e-7, last_epoch=last_epoch) elif args.lr_scheduler == 'cosine_restarts': is_cosine_scheduler = True scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_0=args.lr_max_iter, T_mult=2, last_epoch=last_epoch) elif args.lr_scheduler == 'exp': is_cosine_scheduler = False if last_epoch > 0: last_epoch /= max_batch_iters scheduler = torch.optim.lr_scheduler.ExponentialLR( optimizer, gamma=args.lr_decay, last_epoch=last_epoch) # elif args.lr_scheduler == 'step': # step_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.1) # scheduler = False else: is_cosine_scheduler = False scheduler = False for epoch in range(1, start_epoch): if args.lr_scheduler == 'cosine' and (epoch % epoch_to_change_lr == 0): scheduler.state_dict()['base_lrs'][0] *= args.lr_decay torch.autograd.set_detect_anomaly(True) for epoch in range(start_epoch, args.num_epochs + 1): print('Current epoch: ', epoch) # we don't need the average epoch loss so we assign it to _ _ = train(epoch, model, loss, optimizer, dataloader, pair_generation_tnf, log_interval=args.log_interval, scheduler=scheduler, is_cosine_scheduler=is_cosine_scheduler, tb_writer=logs_writer) # Step non-cosine scheduler if scheduler and not is_cosine_scheduler: scheduler.step() val_loss = validate_model(model, loss, dataloader_val, pair_generation_tnf, epoch, logs_writer) # Change lr_max in cosine annealing if args.lr_scheduler == 'cosine' and (epoch % epoch_to_change_lr == 0): scheduler.state_dict()['base_lrs'][0] *= args.lr_decay if (epoch % epoch_to_change_lr == epoch_to_change_lr // 2) or epoch == 1: compute_metric('absdiff', model, args.geometric_model, None, None, dataset_val, dataloader_val, pair_generation_tnf, args.batch_size, args) # remember best loss is_best = val_loss < best_val_loss best_val_loss = min(val_loss, best_val_loss) save_checkpoint( { 'epoch': epoch + 1, 'args': args, 'state_dict': model.state_dict(), 'best_val_loss': best_val_loss, 'optimizer': optimizer.state_dict(), }, is_best, checkpoint_path) logs_writer.close() print('Done!')
def main(): args = parse_flags() use_cuda = torch.cuda.is_available() # Seed torch.manual_seed(args.seed) if use_cuda: torch.cuda.manual_seed(args.seed) # Download dataset if needed and set paths if args.training_dataset == 'pascal': if args.training_image_path == '': download_pascal('datasets/pascal-voc11/') args.training_image_path = 'datasets/pascal-voc11/' if args.training_tnf_csv == '' and args.geometric_model == 'affine': args.training_tnf_csv = 'training_data/pascal-synth-aff' elif args.training_tnf_csv == '' and args.geometric_model == 'tps': args.training_tnf_csv = 'training_data/pascal-synth-tps' # CNN model and loss if not args.pretrained: if args.light_model: print('Creating light CNN model...') model = LightCNN(use_cuda=use_cuda, geometric_model=args.geometric_model) else: print('Creating CNN model...') model = CNNGeometric( use_cuda=use_cuda, geometric_model=args.geometric_model, feature_extraction_cnn=args.feature_extraction_cnn) else: model = load_torch_model(args, use_cuda) if args.loss == 'mse': print('Using MSE loss...') loss = MSELoss() elif args.loss == 'sum': print('Using the sum of MSE and grid loss...') loss = GridLossWithMSE(use_cuda=use_cuda, geometric_model=args.geometric_model) else: print('Using grid loss...') loss = TransformedGridLoss(use_cuda=use_cuda, geometric_model=args.geometric_model) # Initialize csv paths train_csv_path_list = glob( os.path.join(args.training_tnf_csv, '*train.csv')) if len(train_csv_path_list) > 1: print( "!!!!WARNING!!!! multiple train csv files found, using first in glob order" ) elif not len(train_csv_path_list): raise FileNotFoundError( "No training csv where found in the specified path!!!") train_csv_path = train_csv_path_list[0] val_csv_path_list = glob(os.path.join(args.training_tnf_csv, '*val.csv')) if len(val_csv_path_list) > 1: print( "!!!!WARNING!!!! multiple train csv files found, using first in glob order" ) elif not len(val_csv_path_list): raise FileNotFoundError( "No training csv where found in the specified path!!!") val_csv_path = val_csv_path_list[0] # Initialize Dataset objects if args.coupled_dataset: # Dataset for train and val if dataset is already coupled dataset = CoupledDataset(geometric_model=args.geometric_model, csv_file=train_csv_path, training_image_path=args.training_image_path, transform=NormalizeImageDict( ['image_a', 'image_b'])) dataset_val = CoupledDataset( geometric_model=args.geometric_model, csv_file=val_csv_path, training_image_path=args.training_image_path, transform=NormalizeImageDict(['image_a', 'image_b'])) # Set Tnf pair generation func pair_generation_tnf = CoupledPairTnf(use_cuda=use_cuda) else: # Standard Dataset for train and val dataset = SynthDataset(geometric_model=args.geometric_model, csv_file=train_csv_path, training_image_path=args.training_image_path, transform=NormalizeImageDict(['image']), random_sample=args.random_sample) dataset_val = SynthDataset( geometric_model=args.geometric_model, csv_file=val_csv_path, training_image_path=args.training_image_path, transform=NormalizeImageDict(['image']), random_sample=args.random_sample) # Set Tnf pair generation func pair_generation_tnf = SynthPairTnf( geometric_model=args.geometric_model, use_cuda=use_cuda) # Initialize DataLoaders dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=True, num_workers=4) # Optimizer and eventual scheduler optimizer = Adam(model.FeatureRegression.parameters(), lr=args.lr) if args.lr_scheduler: if args.scheduler_type == 'cosine': print('Using cosine learning rate scheduler') scheduler = CosineAnnealingLR(optimizer, T_max=args.lr_max_iter, eta_min=args.lr_min) elif args.scheduler_type == 'decay': print('Using decay learning rate scheduler') scheduler = ReduceLROnPlateau(optimizer, 'min') else: print( 'Using truncated cosine with decay learning rate scheduler...') scheduler = TruncateCosineScheduler(optimizer, len(dataloader), args.num_epochs - 1) else: scheduler = False # Train # Set up names for checkpoints if args.loss == 'mse': ckpt = args.trained_models_fn + '_' + args.geometric_model + '_mse_loss' + args.feature_extraction_cnn checkpoint_path = os.path.join(args.trained_models_dir, args.trained_models_fn, ckpt + '.pth.tar') elif args.loss == 'sum': ckpt = args.trained_models_fn + '_' + args.geometric_model + '_sum_loss' + args.feature_extraction_cnn checkpoint_path = os.path.join(args.trained_models_dir, args.trained_models_fn, ckpt + '.pth.tar') else: ckpt = args.trained_models_fn + '_' + args.geometric_model + '_grid_loss' + args.feature_extraction_cnn checkpoint_path = os.path.join(args.trained_models_dir, args.trained_models_fn, ckpt + '.pth.tar') if not os.path.exists(args.trained_models_dir): os.mkdir(args.trained_models_dir) # Set up TensorBoard writer if not args.log_dir: tb_dir = os.path.join(args.trained_models_dir, args.trained_models_fn + '_tb_logs') else: tb_dir = os.path.join(args.log_dir, args.trained_models_fn + '_tb_logs') logs_writer = SummaryWriter(tb_dir) # add graph, to do so we have to generate a dummy input to pass along with the graph dummy_input = { 'source_image': torch.rand([args.batch_size, 3, 240, 240]), 'target_image': torch.rand([args.batch_size, 3, 240, 240]), 'theta_GT': torch.rand([16, 2, 3]) } logs_writer.add_graph(model, dummy_input) # START OF TRAINING # print('Starting training...') best_val_loss = float("inf") for epoch in range(1, args.num_epochs + 1): # we don't need the average epoch loss so we assign it to _ _ = train(epoch, model, loss, optimizer, dataloader, pair_generation_tnf, log_interval=args.log_interval, scheduler=scheduler, tb_writer=logs_writer) val_loss = validate_model(model, loss, dataloader_val, pair_generation_tnf, epoch, logs_writer, coupled=args.coupled_dataset) # remember best loss is_best = val_loss < best_val_loss best_val_loss = min(val_loss, best_val_loss) save_checkpoint( { 'epoch': epoch + 1, 'args': args, 'state_dict': model.state_dict(), 'best_val_loss': best_val_loss, 'optimizer': optimizer.state_dict(), }, is_best, checkpoint_path) logs_writer.close() print('Done!')
checkpoint = torch.load(args.load_model, map_location=lambda storage, loc: storage) # Load model state dict model.load_state_dict(checkpoint['state_dict']) # Load optimizer state dict optimizer.load_state_dict(checkpoint['optimizer']) # Load epoch information start_epoch = checkpoint['epoch'] print("Reloading from--[%s]" % args.load_model) for epoch in range(start_epoch, args.num_epochs + 1): # Call train, test function train_loss = train(epoch, model, loss, optimizer, dataloader_train, use_cuda, log_interval=100) test_acc = test(model, dataloader_test, len(dataset_test), use_cuda) checkpoint_name = os.path.join( args.trained_models_dir, args.model_type + '_epoch_' + str(epoch) + '.pth.tar') # Save checkpoint save_checkpoint( { 'epoch': epoch + 1, 'args': args, 'state_dict': model.state_dict(),