Example #1
import importlib
from pathlib import Path

import torch
# NVIDIA apex supplies the SyncBN conversion and its own DDP wrapper.
from apex.parallel import DistributedDataParallel, convert_syncbn_model

# Project-local dependencies; the exact import paths below are assumptions.
import radam                                    # RAdam optimizer implementation
from learning import Learning                   # training-loop runner
from losses import SegmentationLosses           # criterion factory
from utils import init_logger, init_tb_logger   # file / TensorBoard loggers


def train_fold(train_config, distrib_config, pipeline_name, log_dir, fold_id,
               train_dataloader, valid_dataloader, evaluator):

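    # Only the master process (rank 0) writes log files and TensorBoard events.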
    if distrib_config['LOCAL_RANK'] == 0:
        fold_logger = init_logger(log_dir, 'train_fold_{}.log'.format(fold_id))
        fold_tb_logger = init_tb_logger(log_dir,
                                        'train_fold_{}'.format(fold_id))

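    # Prepare two directories: one for the single best checkpoint and one
    # for the rolling top-K checkpoint history of this fold.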
    best_checkpoint_folder = Path(log_dir,
                                  train_config['CHECKPOINTS']['BEST_FOLDER'])
    best_checkpoint_folder.mkdir(exist_ok=True, parents=True)

    checkpoints_history_folder = Path(
        log_dir, train_config['CHECKPOINTS']['FULL_FOLDER'],
        'fold{}'.format(fold_id))
    checkpoints_history_folder.mkdir(exist_ok=True, parents=True)
    checkpoints_topk = train_config['CHECKPOINTS']['TOPK']

    calculation_name = '{}_fold{}'.format(pipeline_name, fold_id)

    device = train_config['DEVICE']

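    # Build the model dynamically from the module/class names in the config.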
    module = importlib.import_module(train_config['MODEL']['PY'])
    model_function = getattr(module, train_config['MODEL']['CLASS'])
    model = model_function(**train_config['MODEL']['ARGS'])

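    # Multi-GPU run: convert BatchNorm layers to synchronized BatchNorm and
    # wrap the model in apex's DistributedDataParallel.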
    if len(train_config['DEVICE_LIST']) > 1:
        model.cuda()
        model = convert_syncbn_model(model)
        model = DistributedDataParallel(model, delay_allreduce=True)

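    # Resume from the best checkpoint of this fold if one already exists.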
    pretrained_model_path = best_checkpoint_folder / f'{calculation_name}.pth'
    if pretrained_model_path.is_file():
        state_dict = torch.load(pretrained_model_path,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict)

        if distrib_config['LOCAL_RANK'] == 0:
            fold_logger.info(
                'load model from {}'.format(pretrained_model_path))

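    # Build the segmentation criterion from the CRITERION section of the config.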
    loss_args = train_config['CRITERION']
    loss_fn = SegmentationLosses(
        weight=loss_args['weight'],
        size_average=loss_args['size_average'],
        batch_average=loss_args['batch_average'],
        ignore_index=loss_args['ignore_index'],
        cuda=loss_args['cuda']).build_loss(mode=loss_args['mode'])

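    # RAdam is resolved from the external radam module; any other optimizer
    # is looked up in torch.optim by class name.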
    if train_config['OPTIMIZER']['CLASS'] == 'RAdam':
        optimizer_class = getattr(radam, train_config['OPTIMIZER']['CLASS'])
    else:
        optimizer_class = getattr(torch.optim,
                                  train_config['OPTIMIZER']['CLASS'])

    # NB: once the model is wrapped in DistributedDataParallel, the custom
    # LR-group helpers (DeepLab-style get_1x_lr_params / get_10x_lr_params)
    # live on the underlying .module, not on the wrapper itself.
    raw_model = model.module if hasattr(model, 'module') else model
    train_params = [{
        'params': raw_model.get_1x_lr_params(),
        'lr': train_config['OPTIMIZER']['ARGS']['lr']
    }, {
        'params': raw_model.get_10x_lr_params(),
        'lr': train_config['OPTIMIZER']['ARGS']['lr'] * 10
    }]
    # Per-group 'lr' above overrides the default lr coming through ARGS;
    # ARGS still supplies the remaining optimizer defaults.
    optimizer = optimizer_class(train_params,
                                **train_config['OPTIMIZER']['ARGS'])

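    # The LR scheduler is likewise resolved from torch.optim.lr_scheduler.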
    scheduler_class = getattr(torch.optim.lr_scheduler,
                              train_config['SCHEDULER']['CLASS'])
    scheduler = scheduler_class(optimizer, **train_config['SCHEDULER']['ARGS'])

    n_epochs = train_config['EPOCHS']
    accumulation_step = train_config['ACCUMULATION_STEP']
    early_stopping = train_config['EARLY_STOPPING']

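    # Non-master ranks train silently: they get no file or TensorBoard logger.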
    if distrib_config['LOCAL_RANK'] != 0:
        fold_logger = None
        fold_tb_logger = None

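    # Learning (project-local) drives the epoch loop, including scheduling,
    # checkpointing and early stopping.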
    learning = Learning(distrib_config, optimizer, loss_fn, evaluator, device,
                        n_epochs, scheduler, accumulation_step, early_stopping,
                        fold_logger, fold_tb_logger, best_checkpoint_folder,
                        checkpoints_history_folder, checkpoints_topk,
                        calculation_name)
    best_epoch, best_score = learning.run_train(model, train_dataloader,
                                                valid_dataloader)

    if distrib_config['LOCAL_RANK'] == 0:
        fold_logger.info(
            f'Best Epoch: {best_epoch}, Best Score: {best_score}')
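
For context, here is a minimal sketch of how train_fold might be driven from a cross-validation loop. The load_config and make_fold_dataloaders helpers, the Evaluator constructor, the 'N_FOLDS' key, and the pipeline name below are illustrative assumptions, not part of the pipeline above.

from pathlib import Path

# Hypothetical driver loop; load_config, make_fold_dataloaders and Evaluator
# are assumed helpers, shown for illustration only.
train_config = load_config('configs/train.yaml')
distrib_config = {'LOCAL_RANK': 0}

for fold_id in range(train_config['N_FOLDS']):
    train_loader, valid_loader = make_fold_dataloaders(train_config, fold_id)
    train_fold(train_config, distrib_config, 'unet_pipeline', Path('logs'),
               fold_id, train_loader, valid_loader,
               evaluator=Evaluator(num_class=2))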