def setup_datasets_and_dataloaders(config, training_mode=None, non_spatial_aug=None,
                                   interval=1, partition=None):
    """Prepare the training dataset and dataloader.

    Parameters
    ----------
    config : Config
        Dataset configuration node; reads ``config.augmentation`` (image_shape,
        jittering) and ``config.train`` (path, repeat, batch_size, num_workers).
    training_mode, non_spatial_aug, interval, partition : optional
        Accepted for compatibility with the call site in ``main``, which passes
        these keywords (the original signature rejected them with a TypeError).
        NOTE(review): currently unused here — confirm whether they should be
        forwarded to the dataset constructor.

    Returns
    -------
    tuple
        ``(train_dataset, train_loader)``.
    """
    def _worker_init_fn(worker_id):
        """Worker init fn to fix the seed of the workers"""
        _set_seeds(42 + worker_id)

    data_transforms = image_transforms(shape=config.augmentation.image_shape,
                                       jittering=config.augmentation.jittering)
    train_dataset = COCOLoader(config.train.path, data_transform=data_transforms['train'])

    # Concatenate dataset to produce a larger one
    if config.train.repeat > 1:
        train_dataset = ConcatDataset([train_dataset for _ in range(config.train.repeat)])

    # Create loaders: a DistributedSampler shards the data across ranks in
    # multi-GPU runs; shuffling is then delegated to the sampler, not the loader.
    if world_size() > 1:
        sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=world_size(), rank=rank())
    else:
        sampler = None
    train_loader = DataLoader(train_dataset,
                              batch_size=config.train.batch_size,
                              pin_memory=True,
                              shuffle=not (world_size() > 1),
                              num_workers=config.train.num_workers,
                              worker_init_fn=_worker_init_fn,
                              sampler=sampler)
    return train_dataset, train_loader
def evaluation(config, completed_epoch, model, summary):
    """Evaluate the model on HPatches and Hypersim, log metrics, and checkpoint.

    NOTE(review): this definition is overridden by a later ``evaluation``
    definition in this same file (the HPatches-only variant), so this version
    is dead code as written — confirm which one is intended.

    Parameters
    ----------
    config : Config
        Full training configuration (reads ``config.datasets`` and ``config.model``).
    completed_epoch : int
        Number of epochs completed, used only in the checkpoint log message.
    model :
        Wrapper model; the keypoint network is reached via ``model_submodule``.
    summary :
        Logger with ``add_scalar`` (expected None on non-root ranks).
    """
    # Set to eval mode
    model.eval()
    model.training = False

    use_color = config.model.params.use_color

    # Evaluation only runs on the root rank; other ranks fall through.
    if rank() == 0:
        # eval_shape = config.datasets.augmentation.image_shape[::-1]
        eval_shape = (320, 240)
        eval_params = [{'res': eval_shape, 'top_k': 300}]
        for params in eval_params:
            # HPatches evaluation at the configured resolution / top-k.
            hp_dataset = PatchesDataset(root_dir=config.datasets.val.path, use_color=use_color,
                                        output_shape=params['res'], type='a')
            data_loader = DataLoader(hp_dataset,
                                     batch_size=1,
                                     pin_memory=False,
                                     shuffle=False,
                                     num_workers=8,
                                     worker_init_fn=None,
                                     sampler=None)
            print('Loaded {} image pairs '.format(len(data_loader)))
            printcolor('HPatches: Evaluating for {} -- top_k {}'.format(params['res'], params['top_k']))
            rep, loc, c1, c3, c5, mscore = evaluate_keypoint_net(data_loader,
                                                                 model_submodule(model).keypoint_net,
                                                                 output_shape=params['res'],
                                                                 top_k=params['top_k'],
                                                                 use_color=use_color)
            if summary:
                summary.add_scalar('hpatches_repeatability_'+str(params['res']), rep)
                summary.add_scalar('hpatches_localization_' + str(params['res']), loc)
                summary.add_scalar('hpatches_correctness_'+str(params['res'])+'_'+str(1), c1)
                summary.add_scalar('hpatches_correctness_'+str(params['res'])+'_'+str(3), c3)
                summary.add_scalar('hpatches_correctness_'+str(params['res'])+'_'+str(5), c5)
                summary.add_scalar('hpatches_mscore' + str(params['res']), mscore)
            print('Hpatches Repeatability {0:.3f}'.format(rep))
            print('Hpatches Localization Error {0:.3f}'.format(loc))
            print('Hpatches Correctness d1 {:.3f}'.format(c1))
            print('Hpatches Correctness d3 {:.3f}'.format(c3))
            print('Hpatches Correctness d5 {:.3f}'.format(c5))
            print('Hpatches MScore {:.3f}'.format(mscore))

        # Hypersim evaluation at a fixed resolution / top-k.
        # NOTE(review): this reads config.datasets.train.path (not val.path) but
        # selects the 'val+test' partition — confirm this is intentional.
        params = {'res': (1024, 768), 'top_k': 1000}
        # hp_dataset = HypersimLoader(config.datasets.train.path, training_mode='consecutive', data_transform=to_tensor_sample, partition='val+test')
        hp_dataset = HypersimLoader(config.datasets.train.path,
                                    training_mode='con',
                                    center_crop=False,
                                    data_transform=to_tensor_sample,
                                    interval=1,
                                    partition='val+test')
        data_loader = DataLoader(hp_dataset,
                                 batch_size=1,
                                 pin_memory=False,
                                 shuffle=False,
                                 num_workers=8,
                                 worker_init_fn=None,
                                 sampler=None)
        print('Loaded {} image pairs '.format(len(data_loader)))
        printcolor('Hypersim: Evaluating for {} -- top_k {}'.format(params['res'], params['top_k']))
        rep, loc, mscore = evaluate_keypoint_net_hypersim(data_loader,
                                                          model_submodule(model).keypoint_net,
                                                          output_shape=params['res'],
                                                          top_k=params['top_k'],
                                                          use_color=use_color)
        if summary:
            summary.add_scalar('hypersim_repeatability_'+str(params['res']), rep)
            summary.add_scalar('hypersim_localization_' + str(params['res']), loc)
            summary.add_scalar('hypersim_mscore' + str(params['res']), mscore)
        print('Hypersim Repeatability {0:.3f}'.format(rep))
        print('Hypersim Localization Error {0:.3f}'.format(loc))
        print('Hypersim MScore {:.3f}'.format(mscore))

    # Save checkpoint (root rank only); stores the inner keypoint network
    # weights plus the config used to produce them.
    if config.model.save_checkpoint and rank() == 0:
        current_model_path = os.path.join(config.model.checkpoint_path, 'model.ckpt')
        printcolor('\nSaving model (epoch:{}) at {}'.format(completed_epoch, current_model_path), 'green')
        torch.save(
            {
                'state_dict': model_submodule(model_submodule(model).keypoint_net).state_dict(),
                'config': config
            }, current_model_path)
def main(file, training_mode, non_spatial_aug, wandb_name, interval, partition, pretrained_model=None):
    """
    KP2D training script.

    Parameters
    ----------
    file : str
        Filepath, can be either a **.yaml** for a yacs configuration file or a
        **.ckpt** for a pre-trained checkpoint file.
    training_mode :
        Forwarded to the model and dataset setup.
    non_spatial_aug, interval, partition :
        Forwarded to the dataset setup.
    wandb_name : str
        Run name stored into ``config.wandb.name``.
    pretrained_model : optional
        Forwarded to the model constructor for weight initialization.
    """
    # Parse config
    config = parse_train_file(file)
    print(config)
    print(config.arch)
    config.wandb.name = wandb_name

    # Initialize horovod (distributed backend) and cap intra-op threads.
    hvd_init()
    n_threads = int(os.environ.get("OMP_NUM_THREADS", 1))
    torch.set_num_threads(n_threads)
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # Pin each process to its local GPU in multi-GPU runs.
    if world_size() > 1:
        printcolor('-'*18 + 'DISTRIBUTED DATA PARALLEL ' + '-'*18, 'cyan')
        device_id = local_rank()
        torch.cuda.set_device(device_id)
    else:
        printcolor('-'*25 + 'SINGLE GPU ' + '-'*25, 'cyan')

    if config.arch.seed is not None:
        _set_seeds(config.arch.seed)

    if rank() == 0:
        printcolor('-'*25 + ' MODEL PARAMS ' + '-'*25)
        printcolor(config.model.params, 'red')

    # Setup model and datasets/dataloaders
    model = KeypointNetwithIOLoss(pretrained_model=pretrained_model,
                                  training_mode=training_mode,
                                  keypoint_net_learning_rate=config.model.optimizer.learning_rate,
                                  **config.model.params)
    # NOTE(review): these extra keywords require setup_datasets_and_dataloaders
    # to accept them — its original one-argument signature would raise here.
    train_dataset, train_loader = setup_datasets_and_dataloaders(config.datasets,
                                                                 training_mode=training_mode,
                                                                 non_spatial_aug=non_spatial_aug,
                                                                 interval=interval,
                                                                 partition=partition)
    printcolor('({}) length: {}'.format("Train", len(train_dataset)))

    model = model.cuda()
    optimizer = optim.Adam(model.optim_params)
    compression = hvd.Compression.none  # or hvd.Compression.fp16
    optimizer = hvd.DistributedOptimizer(optimizer,
                                         named_parameters=model.named_parameters(),
                                         compression=compression)

    # Synchronize model weights from all ranks
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # checkpoint model: only the root rank creates a logger and a run-specific
    # checkpoint directory; other ranks keep summary=None.
    log_path = os.path.join(config.model.checkpoint_path, 'logs')
    os.makedirs(log_path, exist_ok=True)
    if rank() == 0:
        if not config.wandb.dry_run:
            summary = SummaryWriter(log_path, config,
                                    project=config.wandb.project,
                                    entity=config.wandb.entity,
                                    job_type='training',
                                    name=config.wandb.name,
                                    mode=os.getenv('WANDB_MODE', 'run'))
            config.model.checkpoint_path = os.path.join(config.model.checkpoint_path,
                                                        summary.run_name)
        else:
            summary = None
            # Dry run: fall back to a timestamped directory name instead of the
            # wandb run name.
            date_time = datetime.now().strftime("%m_%d_%Y__%H_%M_%S")
            date_time = model_submodule(model).__class__.__name__ + '_' + date_time
            config.model.checkpoint_path = os.path.join(config.model.checkpoint_path, date_time)
        print('Saving models at {}'.format(config.model.checkpoint_path))
        os.makedirs(config.model.checkpoint_path, exist_ok=True)
    else:
        summary = None

    # Initial evaluation
    # evaluation(config, 0, model, summary)

    # Train
    for epoch in range(config.arch.epochs):
        # train for one epoch (only log if eval to have aligned steps...)
        printcolor("\n--------------------------------------------------------------")
        train(config, train_loader, model, optimizer, epoch, summary)
        # Model checkpointing, eval, and logging
        evaluation(config, epoch + 1, model, summary)

    printcolor('Training complete, models saved in {}'.format(config.model.checkpoint_path), "green")
def evaluation(config, completed_epoch, model, summary):
    """Run HPatches evaluation on the root rank, log metrics, and checkpoint.

    NOTE(review): this redefinition replaces the earlier ``evaluation`` in
    this file (the HPatches+Hypersim variant).

    Parameters
    ----------
    config : Config
        Full training configuration (reads ``config.datasets`` and ``config.model``).
    completed_epoch : int
        Number of completed epochs, used only in the checkpoint log message.
    model :
        Wrapper model; the keypoint network is reached via ``model_submodule``.
    summary :
        Metric logger with ``add_scalar``/``commit_log`` (None off the root rank).
    """
    # Switch the whole wrapper into eval mode before measuring.
    model.eval()
    model.training = False

    use_color = config.model.params.use_color

    # Only the root rank evaluates; the others skip straight to the end.
    if rank() == 0:
        for opts in [{'res': (320, 240), 'top_k': 300}]:
            resolution, top_k = opts['res'], opts['top_k']

            dataset = PatchesDataset(root_dir=config.datasets.val.path,
                                     use_color=use_color,
                                     output_shape=resolution,
                                     type='a')
            loader = DataLoader(dataset,
                                batch_size=1,
                                pin_memory=False,
                                shuffle=False,
                                num_workers=8,
                                worker_init_fn=None,
                                sampler=None)
            print('Loaded {} image pairs '.format(len(loader)))
            printcolor('Evaluating for {} -- top_k {}'.format(resolution, top_k))

            repeat, localization, c1, c3, c5, match_score = evaluate_keypoint_net(
                loader,
                model_submodule(model).keypoint_net,
                output_shape=resolution,
                top_k=top_k,
                use_color=use_color)

            if summary:
                summary.add_scalar('repeatability_' + str(resolution), repeat)
                summary.add_scalar('localization_' + str(resolution), localization)
                # Correctness is logged once per pixel-distance threshold.
                for dist, corr in ((1, c1), (3, c3), (5, c5)):
                    summary.add_scalar(
                        'correctness_' + str(resolution) + '_' + str(dist), corr)
                summary.add_scalar('mscore' + str(resolution), match_score)

            print('Repeatability {0:.3f}'.format(repeat))
            print('Localization Error {0:.3f}'.format(localization))
            print('Correctness d1 {:.3f}'.format(c1))
            print('Correctness d3 {:.3f}'.format(c3))
            print('Correctness d5 {:.3f}'.format(c5))
            print('MScore {:.3f}'.format(match_score))

    if summary:
        summary.commit_log()

    # Persist the inner keypoint network's weights together with the config
    # that produced them (root rank only).
    if config.model.save_checkpoint and rank() == 0:
        ckpt_file = os.path.join(config.model.checkpoint_path, 'model.ckpt')
        printcolor(
            '\nSaving model (epoch:{}) at {}'.format(completed_epoch, ckpt_file),
            'green')
        payload = {
            'state_dict': model_submodule(
                model_submodule(model).keypoint_net).state_dict(),
            'config': config
        }
        torch.save(payload, ckpt_file)