Example No. 1
def adjust_learning_rate(config, optimizer, epoch, decay=0.5, max_decays=4):
    """Sets the learning rate to the initial LR decayed by `decay` every
    `lr_epoch_divide_frequency / repeat` epochs, at most `max_decays` times."""
    # The effective decay period shrinks when the training set is repeated within an epoch.
    exponent = min(epoch // (config.model.scheduler.lr_epoch_divide_frequency / config.datasets.train.repeat), max_decays)
    decay_factor = decay ** exponent
    for param_group in optimizer.param_groups:
        # Each param group is expected to carry its initial LR under 'original_lr'.
        param_group['lr'] = param_group['original_lr'] * decay_factor
        printcolor('Changing {} network learning rate to {:8.6f}'.format(param_group['name'], param_group['lr']),
                   'red')
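
A minimal sketch of the resulting schedule. The SimpleNamespace config and the printcolor stub are hypothetical scaffolding standing in for the repository's real config object and colored-print helper:

import torch
from types import SimpleNamespace

# Hypothetical stand-in config: decay every 10 effective epochs, dataset repeated once.
config = SimpleNamespace(
    model=SimpleNamespace(scheduler=SimpleNamespace(lr_epoch_divide_frequency=10)),
    datasets=SimpleNamespace(train=SimpleNamespace(repeat=1)))

def printcolor(msg, color=None):  # stub for the repo's colored-print helper
    print(msg)

optimizer = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)
for group in optimizer.param_groups:
    group['original_lr'] = group['lr']  # the scheduler reads the initial LR from here
    group['name'] = 'keypoint_net'      # only used in the log message

for epoch in (0, 10, 25, 100):
    adjust_learning_rate(config, optimizer, epoch)
    # lr: 1.0e-3 -> 5.0e-4 -> 2.5e-4 -> 6.25e-5 (exponent capped at max_decays=4)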
Example No. 2
def main(file, training_mode, non_spatial_aug, wandb_name, interval, partition, pretrained_model=None):
    """
    KP2D training script.

    Parameters
    ----------
    file : str
        Filepath, either a **.yaml** yacs configuration file or a
        **.ckpt** pre-trained checkpoint file.
    training_mode : str
        Training mode, forwarded to the model and the dataset setup.
    non_spatial_aug : bool
        Whether to use non-spatial augmentations (forwarded to the dataset setup).
    wandb_name : str
        Run name that overrides ``config.wandb.name``.
    interval : int
        Frame interval, forwarded to the dataset setup.
    partition : str
        Dataset partition, forwarded to the dataset setup.
    pretrained_model : str, optional
        Path to a pre-trained model, passed to ``KeypointNetwithIOLoss``.
    """
    # Parse config
    config = parse_train_file(file)
    print(config)
    print(config.arch)
    config.wandb.name = wandb_name

    # Initialize horovod
    hvd_init()
    n_threads = int(os.environ.get("OMP_NUM_THREADS", 1))
    torch.set_num_threads(n_threads)    
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    if world_size() > 1:
        printcolor('-'*18 + ' DISTRIBUTED DATA PARALLEL ' + '-'*18, 'cyan')
        device_id = local_rank()
        torch.cuda.set_device(device_id)
    else:
        printcolor('-'*25 + ' SINGLE GPU ' + '-'*25, 'cyan')
    
    if config.arch.seed is not None:
        _set_seeds(config.arch.seed)

    if rank() == 0:
        printcolor('-'*25 + ' MODEL PARAMS ' + '-'*25)
        printcolor(config.model.params, 'red')

    # Setup model and datasets/dataloaders
    model = KeypointNetwithIOLoss(pretrained_model=pretrained_model,
                                  training_mode=training_mode,
                                  keypoint_net_learning_rate=config.model.optimizer.learning_rate,
                                  **config.model.params)
    train_dataset, train_loader = setup_datasets_and_dataloaders(config.datasets,
                                                                 training_mode=training_mode,
                                                                 non_spatial_aug=non_spatial_aug,
                                                                 interval=interval,
                                                                 partition=partition)
    printcolor('({}) length: {}'.format("Train", len(train_dataset)))

    model = model.cuda()
    optimizer = optim.Adam(model.optim_params)
    compression = hvd.Compression.none  # or hvd.Compression.fp16
    optimizer = hvd.DistributedOptimizer(optimizer,
                                         named_parameters=model.named_parameters(),
                                         compression=compression)

    # Synchronize model weights from all ranks
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # checkpoint model
    log_path = os.path.join(config.model.checkpoint_path, 'logs')
    os.makedirs(log_path, exist_ok=True)
    
    if rank() == 0:
        if not config.wandb.dry_run:
            summary = SummaryWriter(log_path,
                                    config,
                                    project=config.wandb.project,
                                    entity=config.wandb.entity,
                                    job_type='training',
                                    name=config.wandb.name,
                                    mode=os.getenv('WANDB_MODE', 'run'))
            config.model.checkpoint_path = os.path.join(config.model.checkpoint_path, summary.run_name)
        else:
            summary = None
            date_time = datetime.now().strftime("%m_%d_%Y__%H_%M_%S")
            date_time = model_submodule(model).__class__.__name__ + '_' + date_time
            config.model.checkpoint_path = os.path.join(config.model.checkpoint_path, date_time)
        
        print('Saving models at {}'.format(config.model.checkpoint_path))
        os.makedirs(config.model.checkpoint_path, exist_ok=True)    
    else:
        summary = None

    # Initial evaluation
    # evaluation(config, 0, model, summary)
    # Train
    for epoch in range(config.arch.epochs):
        # Train for one epoch (log only alongside eval so logged steps stay aligned)
        printcolor("\n--------------------------------------------------------------")
        train(config, train_loader, model, optimizer, epoch, summary)

        # Model checkpointing, eval, and logging
        evaluation(config, epoch + 1, model, summary)
    printcolor('Training complete, models saved in {}'.format(config.model.checkpoint_path), "green")
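
The Horovod setup in main() follows the standard horovod.torch data-parallel pattern. Below is a condensed, self-contained sketch with a toy linear model standing in for KeypointNetwithIOLoss; it would be launched with e.g. `horovodrun -np 4 python train.py` (the filename is hypothetical):

import torch
import horovod.torch as hvd

hvd.init()                               # one process per GPU
torch.cuda.set_device(hvd.local_rank())  # bind this process to its GPU

model = torch.nn.Linear(10, 2).cuda()    # toy stand-in for the keypoint model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Average gradients across ranks on every optimizer step.
optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters(),
                                     compression=hvd.Compression.none)

# Start all ranks from rank 0's weights, exactly as main() does.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)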
Example No. 3
def evaluation(config, completed_epoch, model, summary):
    # Switch to eval mode (model.eval() already sets model.training to False;
    # the explicit assignment below is redundant but harmless)
    model.eval()
    model.training = False

    use_color = config.model.params.use_color

    if rank() == 0:
        # eval_shape = config.datasets.augmentation.image_shape[::-1]
        eval_shape = (320, 240)
        eval_params = [{'res': eval_shape, 'top_k': 300}]
        for params in eval_params:
            hp_dataset = PatchesDataset(root_dir=config.datasets.val.path,
                                        use_color=use_color,
                                        output_shape=params['res'],
                                        type='a')

            data_loader = DataLoader(hp_dataset,
                                     batch_size=1,
                                     pin_memory=False,
                                     shuffle=False,
                                     num_workers=8,
                                     worker_init_fn=None,
                                     sampler=None)
            print('Loaded {} image pairs'.format(len(data_loader)))

            printcolor('HPatches: Evaluating for {} -- top_k {}'.format(params['res'], params['top_k']))
            rep, loc, c1, c3, c5, mscore = evaluate_keypoint_net(data_loader,
                                                                 model_submodule(model).keypoint_net,
                                                                 output_shape=params['res'],
                                                                 top_k=params['top_k'],
                                                                 use_color=use_color)
            if summary:
                summary.add_scalar('hpatches_repeatability_'+str(params['res']), rep)
                summary.add_scalar('hpatches_localization_' + str(params['res']), loc)
                summary.add_scalar('hpatches_correctness_'+str(params['res'])+'_'+str(1), c1)
                summary.add_scalar('hpatches_correctness_'+str(params['res'])+'_'+str(3), c3)
                summary.add_scalar('hpatches_correctness_'+str(params['res'])+'_'+str(5), c5)
                summary.add_scalar('hpatches_mscore_' + str(params['res']), mscore)

            print('HPatches Repeatability {:.3f}'.format(rep))
            print('HPatches Localization Error {:.3f}'.format(loc))
            print('HPatches Correctness d1 {:.3f}'.format(c1))
            print('HPatches Correctness d3 {:.3f}'.format(c3))
            print('HPatches Correctness d5 {:.3f}'.format(c5))
            print('HPatches MScore {:.3f}'.format(mscore))

        params = {'res': (1024, 768), 'top_k': 1000}
        # hp_dataset = HypersimLoader(config.datasets.train.path, training_mode='consecutive', data_transform=to_tensor_sample, partition='val+test')
        hp_dataset = HypersimLoader(config.datasets.train.path,
                                    training_mode='con',
                                    center_crop=False,
                                    data_transform=to_tensor_sample,
                                    interval=1,
                                    partition='val+test')
        data_loader = DataLoader(hp_dataset,
                                 batch_size=1,
                                 pin_memory=False,
                                 shuffle=False,
                                 num_workers=8,
                                 worker_init_fn=None,
                                 sampler=None)
        print('Loaded {} image pairs'.format(len(data_loader)))

        printcolor('Hypersim: Evaluating for {} -- top_k {}'.format(params['res'], params['top_k']))
        rep, loc, mscore = evaluate_keypoint_net_hypersim(data_loader,
                                                          model_submodule(model).keypoint_net,
                                                          output_shape=params['res'],
                                                          top_k=params['top_k'],
                                                          use_color=use_color)
        if summary:
            summary.add_scalar('hypersim_repeatability_'+str(params['res']), rep)
            summary.add_scalar('hypersim_localization_' + str(params['res']), loc)
            summary.add_scalar('hypersim_mscore_' + str(params['res']), mscore)

        print('Hypersim Repeatability {:.3f}'.format(rep))
        print('Hypersim Localization Error {:.3f}'.format(loc))
        print('Hypersim MScore {:.3f}'.format(mscore))

    # Save checkpoint
    if config.model.save_checkpoint and rank() == 0:
        current_model_path = os.path.join(config.model.checkpoint_path, 'model.ckpt')
        printcolor('\nSaving model (epoch:{}) at {}'.format(completed_epoch, current_model_path), 'green')
        torch.save(
            {
                'state_dict': model_submodule(model_submodule(model).keypoint_net).state_dict(),
                'config': config
            }, current_model_path)
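
Because the checkpoint bundles both 'state_dict' and 'config', it can be restored later without the original command line. A minimal loading sketch; build_keypoint_net is a hypothetical factory standing in for however the network is actually reconstructed from the config:

import torch

ckpt = torch.load('model.ckpt', map_location='cpu')
config = ckpt['config']            # the full training config travels with the weights
net = build_keypoint_net(config)   # hypothetical factory for the keypoint network
net.load_state_dict(ckpt['state_dict'])
net.eval()                         # inference mode for evaluation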
Example No. 4
def evaluation(config, completed_epoch, model, summary):
    # Set to eval mode
    model.eval()
    model.training = False

    use_color = config.model.params.use_color

    if rank() == 0:
        eval_params = [{'res': (320, 240), 'top_k': 300}]
        for params in eval_params:
            hp_dataset = PatchesDataset(root_dir=config.datasets.val.path,
                                        use_color=use_color,
                                        output_shape=params['res'],
                                        type='a')

            data_loader = DataLoader(hp_dataset,
                                     batch_size=1,
                                     pin_memory=False,
                                     shuffle=False,
                                     num_workers=8,
                                     worker_init_fn=None,
                                     sampler=None)
            print('Loaded {} image pairs'.format(len(data_loader)))

            printcolor('Evaluating for {} -- top_k {}'.format(
                params['res'], params['top_k']))
            rep, loc, c1, c3, c5, mscore = evaluate_keypoint_net(
                data_loader,
                model_submodule(model).keypoint_net,
                output_shape=params['res'],
                top_k=params['top_k'],
                use_color=use_color)
            if summary:
                summary.add_scalar('repeatability_' + str(params['res']), rep)
                summary.add_scalar('localization_' + str(params['res']), loc)
                summary.add_scalar(
                    'correctness_' + str(params['res']) + '_' + str(1), c1)
                summary.add_scalar(
                    'correctness_' + str(params['res']) + '_' + str(3), c3)
                summary.add_scalar(
                    'correctness_' + str(params['res']) + '_' + str(5), c5)
                summary.add_scalar('mscore_' + str(params['res']), mscore)

            print('Repeatability {:.3f}'.format(rep))
            print('Localization Error {:.3f}'.format(loc))
            print('Correctness d1 {:.3f}'.format(c1))
            print('Correctness d3 {:.3f}'.format(c3))
            print('Correctness d5 {:.3f}'.format(c5))
            print('MScore {:.3f}'.format(mscore))
        if summary:
            summary.commit_log()

    # Save checkpoint
    if config.model.save_checkpoint and rank() == 0:
        current_model_path = os.path.join(config.model.checkpoint_path,
                                          'model.ckpt')
        printcolor(
            '\nSaving model (epoch:{}) at {}'.format(completed_epoch,
                                                     current_model_path),
            'green')
        torch.save(
            {
                'state_dict': model_submodule(model_submodule(model).keypoint_net).state_dict(),
                'config': config
            }, current_model_path)
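
Both evaluation() variants switch the model to eval mode but never switch it back, so the subsequent train() call must restore train mode itself. A defensive sketch of the same idea that also disables gradient tracking (run_eval and its arguments are hypothetical, not part of the repository):

import torch

def run_eval(model, data_loader):
    was_training = model.training
    model.eval()                   # also sets model.training = False
    try:
        with torch.no_grad():      # metrics need no gradients or autograd graph
            for batch in data_loader:
                _ = model(batch)   # stand-in for the metric computation
    finally:
        if was_training:
            model.train()          # hand the model back in the mode we found it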