def adjust_learning_rate(config, optimizer, epoch, decay=0.5, max_decays=4):
    """Apply a capped step-decay schedule to every parameter group of ``optimizer``.

    Each group's learning rate is reset to ``original_lr * decay ** k`` with::

        period = lr_epoch_divide_frequency / repeat
        k = min(epoch // period, max_decays)

    The decay period is divided by ``config.datasets.train.repeat`` — presumably
    because one epoch iterates the training set ``repeat`` times (TODO confirm).

    Parameters
    ----------
    config
        Reads ``config.model.scheduler.lr_epoch_divide_frequency`` and
        ``config.datasets.train.repeat``.
    optimizer
        Optimizer whose ``param_groups`` each carry ``'original_lr'`` and ``'name'``
        keys in addition to ``'lr'``.
    epoch : int
        Current epoch index.
    decay : float, optional
        Multiplicative factor applied per decay step (default ``0.5``).
    max_decays : int, optional
        Maximum number of decay steps ever applied (default ``4``).
    """
    period = config.model.scheduler.lr_epoch_divide_frequency / config.datasets.train.repeat
    exponent = min(epoch // period, max_decays)
    decay_factor = decay ** exponent
    for param_group in optimizer.param_groups:
        # Scale from the stored initial LR, not the current one, so repeated calls
        # never compound the decay.
        param_group['lr'] = param_group['original_lr'] * decay_factor
        printcolor('Changing {} network learning rate to {:8.6f}'.format(param_group['name'], param_group['lr']), 'red')
def main(file, training_mode, non_spatial_aug, wandb_name, interval, partition, pretrained_model=None):
    """
    KP2D training script.

    Parses the config, initializes Horovod, builds the model, optimizer and
    training loader, then alternates one training epoch with one evaluation
    call for ``config.arch.epochs`` epochs.

    Parameters
    ----------
    file : str
        Filepath, can be either a **.yaml** for a yacs configuration file or a **.ckpt** for a pre-trained checkpoint file.
    training_mode
        Forwarded to both ``KeypointNetwithIOLoss`` and ``setup_datasets_and_dataloaders``.
    non_spatial_aug
        Forwarded to ``setup_datasets_and_dataloaders``.
    wandb_name
        Stored in ``config.wandb.name``; used as the logger run name.
    interval
        Forwarded to ``setup_datasets_and_dataloaders``.
    partition
        Forwarded to ``setup_datasets_and_dataloaders``.
    pretrained_model : optional
        Forwarded to ``KeypointNetwithIOLoss`` (default ``None``).
    """
    # Parse config
    config = parse_train_file(file)
    print(config)
    print(config.arch)
    config.wandb.name = wandb_name

    # Initialize horovod
    hvd_init()
    # Torch intra-op threads follow OMP_NUM_THREADS (1 if unset).
    n_threads = int(os.environ.get("OMP_NUM_THREADS", 1))
    torch.set_num_threads(n_threads)
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    if world_size() > 1:
        printcolor('-'*18 + 'DISTRIBUTED DATA PARALLEL ' + '-'*18, 'cyan')
        # Bind this process to the GPU given by its local rank.
        device_id = local_rank()
        torch.cuda.set_device(device_id)
    else:
        printcolor('-'*25 + 'SINGLE GPU ' + '-'*25, 'cyan')

    if config.arch.seed is not None:
        _set_seeds(config.arch.seed)

    if rank() == 0:
        printcolor('-'*25 + ' MODEL PARAMS ' + '-'*25)
        printcolor(config.model.params, 'red')

    # Setup model and datasets/dataloaders
    model = KeypointNetwithIOLoss(pretrained_model=pretrained_model, training_mode=training_mode,
                                  keypoint_net_learning_rate=config.model.optimizer.learning_rate,
                                  **config.model.params)
    train_dataset, train_loader = setup_datasets_and_dataloaders(config.datasets, training_mode=training_mode,
                                                                 non_spatial_aug=non_spatial_aug,
                                                                 interval=interval, partition=partition)
    printcolor('({}) length: {}'.format("Train", len(train_dataset)))

    model = model.cuda()
    # Parameter groups come from the model itself (model.optim_params).
    optimizer = optim.Adam(model.optim_params)
    compression = hvd.Compression.none  # or hvd.Compression.fp16
    optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters(),
                                         compression=compression)

    # Broadcast rank 0's model weights to every other rank so all workers start identical
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # Logging / checkpoint directories (log dir is created on every rank)
    log_path = os.path.join(config.model.checkpoint_path, 'logs')
    os.makedirs(log_path, exist_ok=True)

    # Only rank 0 creates a logger and a run-specific checkpoint sub-directory;
    # other ranks keep summary=None and the base checkpoint_path.
    if rank() == 0:
        if not config.wandb.dry_run:
            summary = SummaryWriter(log_path, config,
                                    project=config.wandb.project, entity=config.wandb.entity,
                                    job_type='training', name=config.wandb.name,
                                    mode=os.getenv('WANDB_MODE', 'run'))
            config.model.checkpoint_path = os.path.join(config.model.checkpoint_path, summary.run_name)
        else:
            # Dry run: no logger; name the checkpoint dir after the model class and a timestamp.
            summary = None
            date_time = datetime.now().strftime("%m_%d_%Y__%H_%M_%S")
            date_time = model_submodule(model).__class__.__name__ + '_' + date_time
            config.model.checkpoint_path = os.path.join(config.model.checkpoint_path, date_time)
        print('Saving models at {}'.format(config.model.checkpoint_path))
        os.makedirs(config.model.checkpoint_path, exist_ok=True)
    else:
        summary = None

    # Initial evaluation (disabled)
    # evaluation(config, 0, model, summary)

    # Train
    for epoch in range(config.arch.epochs):
        # train for one epoch (only log if eval to have aligned steps...)
        printcolor("\n--------------------------------------------------------------")
        train(config, train_loader, model, optimizer, epoch, summary)

        # Model checkpointing, eval, and logging
        evaluation(config, epoch + 1, model, summary)
    printcolor('Training complete, models saved in {}'.format(config.model.checkpoint_path), "green")
def evaluation(config, completed_epoch, model, summary):
    """Evaluate the keypoint network on HPatches and Hypersim, then optionally checkpoint.

    NOTE(review): this module contains a second ``def evaluation`` with the same
    signature further down. Because it is defined later, it rebinds the name, so this
    HPatches+Hypersim variant is never the one called. Delete or rename one of the
    two definitions so the intended evaluation actually runs.

    Parameters
    ----------
    config
        Experiment config. Reads ``model.params.use_color``, ``datasets.val.path``
        (HPatches root), ``datasets.train.path`` (Hypersim root),
        ``model.save_checkpoint`` and ``model.checkpoint_path``.
    completed_epoch : int
        Epoch number; only used in the "Saving model" message.
    model
        Training wrapper whose ``keypoint_net`` is evaluated and saved.
    summary
        Metric logger or ``None``. When truthy, it receives one ``add_scalar`` per
        metric and a single ``commit_log`` once all metrics have been added.
    """
    # Set to eval mode (not restored here)
    model.eval()
    model.training = False

    use_color = config.model.params.use_color

    def _make_loader(dataset):
        # One image pair per batch, fixed order (shuffle=False).
        return DataLoader(dataset, batch_size=1, pin_memory=False, shuffle=False,
                          num_workers=8, worker_init_fn=None, sampler=None)

    # Metrics are computed and logged on rank 0 only.
    if rank() == 0:
        keypoint_net = model_submodule(model).keypoint_net

        # --- HPatches ---
        eval_params = [{'res': (320, 240), 'top_k': 300}]
        for params in eval_params:
            hp_dataset = PatchesDataset(root_dir=config.datasets.val.path, use_color=use_color,
                                        output_shape=params['res'], type='a')
            data_loader = _make_loader(hp_dataset)
            print('Loaded {} image pairs '.format(len(data_loader)))

            printcolor('HPatches: Evaluating for {} -- top_k {}'.format(params['res'], params['top_k']))
            rep, loc, c1, c3, c5, mscore = evaluate_keypoint_net(data_loader, keypoint_net,
                                                                 output_shape=params['res'],
                                                                 top_k=params['top_k'],
                                                                 use_color=use_color)
            res = str(params['res'])
            if summary:
                summary.add_scalar('hpatches_repeatability_' + res, rep)
                summary.add_scalar('hpatches_localization_' + res, loc)
                summary.add_scalar('hpatches_correctness_' + res + '_' + str(1), c1)
                summary.add_scalar('hpatches_correctness_' + res + '_' + str(3), c3)
                summary.add_scalar('hpatches_correctness_' + res + '_' + str(5), c5)
                summary.add_scalar('hpatches_mscore' + res, mscore)

            print('Hpatches Repeatability {0:.3f}'.format(rep))
            print('Hpatches Localization Error {0:.3f}'.format(loc))
            print('Hpatches Correctness d1 {:.3f}'.format(c1))
            print('Hpatches Correctness d3 {:.3f}'.format(c3))
            print('Hpatches Correctness d5 {:.3f}'.format(c5))
            print('Hpatches MScore {:.3f}'.format(mscore))

        # --- Hypersim ---
        params = {'res': (1024, 768), 'top_k': 1000}
        # The 'val+test' partition is loaded from the *training* dataset root.
        hs_dataset = HypersimLoader(config.datasets.train.path, training_mode='con', center_crop=False,
                                    data_transform=to_tensor_sample, interval=1, partition='val+test')
        data_loader = _make_loader(hs_dataset)
        print('Loaded {} image pairs '.format(len(data_loader)))

        printcolor('Hypersim: Evaluating for {} -- top_k {}'.format(params['res'], params['top_k']))
        rep, loc, mscore = evaluate_keypoint_net_hypersim(data_loader, keypoint_net,
                                                          output_shape=params['res'],
                                                          top_k=params['top_k'],
                                                          use_color=use_color)
        res = str(params['res'])
        if summary:
            summary.add_scalar('hypersim_repeatability_' + res, rep)
            summary.add_scalar('hypersim_localization_' + res, loc)
            summary.add_scalar('hypersim_mscore' + res, mscore)

        print('Hypersim Repeatability {0:.3f}'.format(rep))
        print('Hypersim Localization Error {0:.3f}'.format(loc))
        print('Hypersim MScore {:.3f}'.format(mscore))

        # Flush the scalars added above (the other `evaluation` definition in this
        # module does the same; without it the metrics may never be committed).
        if summary:
            summary.commit_log()

    # Save checkpoint (rank 0 only)
    if config.model.save_checkpoint and rank() == 0:
        current_model_path = os.path.join(config.model.checkpoint_path, 'model.ckpt')
        printcolor('\nSaving model (epoch:{}) at {}'.format(completed_epoch, current_model_path), 'green')
        torch.save(
            {
                'state_dict': model_submodule(model_submodule(model).keypoint_net).state_dict(),
                'config': config
            }, current_model_path)
def evaluation(config, completed_epoch, model, summary):
    """Run the HPatches benchmark on rank 0, log the results, and optionally save a checkpoint.

    Parameters
    ----------
    config
        Experiment config; reads ``model.params.use_color``, ``datasets.val.path``,
        ``model.save_checkpoint`` and ``model.checkpoint_path``.
    completed_epoch : int
        Epoch number shown in the checkpoint message.
    model
        Training wrapper; its ``keypoint_net`` is evaluated and serialized.
    summary
        Metric logger or ``None``. When truthy, every metric is sent through
        ``add_scalar`` and the batch is finalized with ``commit_log``.

    Notes
    -----
    ``model`` is left in eval mode when this returns.
    """
    model.eval()
    model.training = False
    use_color = config.model.params.use_color

    if rank() == 0:
        for setting in [{'res': (320, 240), 'top_k': 300}]:
            shape, top_k = setting['res'], setting['top_k']
            loader = DataLoader(
                PatchesDataset(root_dir=config.datasets.val.path, use_color=use_color,
                               output_shape=shape, type='a'),
                batch_size=1, pin_memory=False, shuffle=False,
                num_workers=8, worker_init_fn=None, sampler=None)
            print('Loaded {} image pairs '.format(len(loader)))
            printcolor('Evaluating for {} -- top_k {}'.format(shape, top_k))

            rep, loc, c1, c3, c5, mscore = evaluate_keypoint_net(
                loader, model_submodule(model).keypoint_net,
                output_shape=shape, top_k=top_k, use_color=use_color)

            tag = str(shape)
            if summary:
                scalars = (
                    ('repeatability_' + tag, rep),
                    ('localization_' + tag, loc),
                    ('correctness_' + tag + '_' + str(1), c1),
                    ('correctness_' + tag + '_' + str(3), c3),
                    ('correctness_' + tag + '_' + str(5), c5),
                    ('mscore' + tag, mscore),
                )
                for scalar_name, scalar_value in scalars:
                    summary.add_scalar(scalar_name, scalar_value)

            report = (
                ('Repeatability {0:.3f}', rep),
                ('Localization Error {0:.3f}', loc),
                ('Correctness d1 {:.3f}', c1),
                ('Correctness d3 {:.3f}', c3),
                ('Correctness d5 {:.3f}', c5),
                ('MScore {:.3f}', mscore),
            )
            for template, value in report:
                print(template.format(value))

        if summary:
            summary.commit_log()

    # Checkpointing happens only when enabled, and only on rank 0.
    if not (config.model.save_checkpoint and rank() == 0):
        return

    ckpt_path = os.path.join(config.model.checkpoint_path, 'model.ckpt')
    printcolor('\nSaving model (epoch:{}) at {}'.format(completed_epoch, ckpt_path), 'green')
    net = model_submodule(model).keypoint_net
    torch.save({'state_dict': model_submodule(net).state_dict(), 'config': config}, ckpt_path)