Code Example #1
    def __init__(self, bert_question_encoder, bert_paragraph_encoder,
                 tokenizer, max_question_len, max_paragraph_len, debug,
                 model_hyper_params, previous_hidden_size):
        super(CNNRetriever,
              self).__init__(bert_question_encoder, bert_paragraph_encoder,
                             tokenizer, max_question_len, max_paragraph_len,
                             debug)
        self.returns_embeddings = False

        check_and_log_hp(['retriever_layer_sizes'], model_hyper_params)
        # Shape notes (cleaned up from the author's scratch comments):
        #   input:   [batch, 3, 542, 768] - 3 candidate paragraphs, 542 tokens,
        #            768-dim BERT token embeddings
        #   reshape: view as [batch * 3, ...] so the 1-D convolutions run along
        #            the sequence dimension
        #   conv/pool layers shrink the sequence dimension, then the result is
        #            flattened and fed to the fully-connected layers
        #   output:  one score per paragraph -> [batch, 3] (softmax applied later)
        input_channels = 768  # BERT hidden size
        out_channel = 256
        kernel_size = 3
        stride = 1
        self.conv1 = nn.Conv1d(input_channels, out_channel, kernel_size,
                               stride)
        self.pool = nn.MaxPool1d(kernel_size)
        self.conv2 = nn.Conv1d(256, 128, kernel_size)

        self.fc1 = nn.Linear(128 * 59, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)
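
The fully-connected input size 128 * 59 above is not explained in the code; the following is a minimal sketch (an assumption, not code from the project) of how it can be derived, assuming the forward pass applies conv1 -> pool -> conv2 -> pool to sequences of 542 tokens (the length in the shape notes):

def out_len(length, kernel_size, stride):
    # output length of nn.Conv1d / nn.MaxPool1d with no padding, dilation 1
    return (length - kernel_size) // stride + 1

length = 542
length = out_len(length, 3, 1)  # conv1 -> 540
length = out_len(length, 3, 3)  # pool  -> 180 (MaxPool1d default stride = kernel_size)
length = out_len(length, 3, 1)  # conv2 -> 178
length = out_len(length, 3, 3)  # pool  -> 59
assert length == 59             # matches nn.Linear(128 * 59, 120)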
Code Example #2
    def __init__(self, hyper_params, encoder_hidden_size):
        super(GeneralEncoder, self).__init__()

        model_hparams = hyper_params['model']
        check_and_log_hp([
            'layers_pre_pooling', 'layers_post_pooling', 'dropout',
            'normalize_model_result', 'pooling_type'
        ], model_hparams)

        self.pooling_type = model_hparams['pooling_type']
        self.normalize_model_result = model_hparams['normalize_model_result']

        pre_pooling_seq = _get_layers(encoder_hidden_size,
                                      model_hparams['dropout'],
                                      model_hparams['layers_pre_pooling'],
                                      True)
        self.pre_pooling_net = nn.Sequential(*pre_pooling_seq)

        pre_pooling_last_hidden_size = model_hparams['layers_pre_pooling'][-1] if \
            model_hparams['layers_pre_pooling'] else encoder_hidden_size
        post_pooling_seq = _get_layers(pre_pooling_last_hidden_size,
                                       model_hparams['dropout'],
                                       model_hparams['layers_post_pooling'],
                                       False)
        self.post_pooling_last_hidden_size = model_hparams['layers_post_pooling'][-1] if \
            model_hparams['layers_post_pooling'] else pre_pooling_last_hidden_size
        self.post_pooling_net = nn.Sequential(*post_pooling_seq)
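
For illustration (hypothetical values, not the project's configuration): with encoder_hidden_size = 768, layers_pre_pooling = [512] and layers_post_pooling = [], the pre-pooling net maps 768 -> 512, pre_pooling_last_hidden_size becomes 512, and, because the post-pooling list is empty, post_pooling_last_hidden_size falls back to 512 as well.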
Code Example #3
def load_model(hyper_params, tokenizer, debug):
    check_and_log_hp(['name'], hyper_params['model'])
    if hyper_params['model']['name'] == 'bert_encoder':
        bert_question_encoder = BertEncoder(hyper_params)
        bert_paragraph_encoder = BertEncoder(hyper_params)
        model = EmbeddingRetriever(bert_question_encoder,
                                   bert_paragraph_encoder, tokenizer,
                                   hyper_params['max_question_len'],
                                   hyper_params['max_paragraph_len'], debug)
    elif hyper_params['model']['name'] == 'bert_ffw':
        bert_question_encoder = BertEncoder(hyper_params)
        bert_paragraph_encoder = BertEncoder(hyper_params)
        if bert_question_encoder.post_pooling_last_hidden_size != \
                bert_paragraph_encoder.post_pooling_last_hidden_size:
            raise ValueError(
                "question/paragraph encoder should have the same output hidden size"
            )
        previous_hidden_size = bert_question_encoder.post_pooling_last_hidden_size
        model = FeedForwardRetriever(bert_question_encoder,
                                     bert_paragraph_encoder,
                                     tokenizer,
                                     hyper_params['max_question_len'],
                                     hyper_params['max_paragraph_len'],
                                     debug,
                                     hyper_params['model'],
                                     previous_hidden_size=previous_hidden_size)
    elif hyper_params['model']['name'] == 'bert_cnn':
        bert_question_encoder = BertEncoder(hyper_params)
        bert_paragraph_encoder = BertEncoder(hyper_params)
        if bert_question_encoder.post_pooling_last_hidden_size != \
                bert_paragraph_encoder.post_pooling_last_hidden_size:
            raise ValueError(
                "question/paragraph encoder should have the same output hidden size"
            )
        previous_hidden_size = bert_question_encoder.post_pooling_last_hidden_size
        model = CNNRetriever(bert_question_encoder,
                             bert_paragraph_encoder,
                             tokenizer,
                             hyper_params['max_question_len'],
                             hyper_params['max_paragraph_len'],
                             debug,
                             hyper_params['model'],
                             previous_hidden_size=previous_hidden_size)
    elif hyper_params['model']['name'] == 'cnn':
        cnn_question_encoder = CNNEncoder(hyper_params, tokenizer.vocab_size)
        cnn_paragraph_encoder = CNNEncoder(hyper_params, tokenizer.vocab_size)
        model = EmbeddingRetriever(cnn_question_encoder, cnn_paragraph_encoder,
                                   tokenizer, hyper_params['max_question_len'],
                                   hyper_params['max_paragraph_len'], debug)
    else:
        raise ValueError('model name {} not supported'.format(
            hyper_params['model']['name']))
    return model
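
A hypothetical hyper_params dict for the 'bert_ffw' branch above; the keys are the ones required by the check_and_log_hp calls in these examples, but all values are illustrative guesses rather than the project's real configuration:

from transformers import AutoTokenizer

hyper_params = {
    'max_question_len': 30,
    'max_paragraph_len': 512,
    'model': {
        'name': 'bert_ffw',
        'bert_base': 'bert-base-uncased',
        'dropout': 0.1,
        'dropout_bert': None,            # keep the original BERT dropout
        'freeze_bert': False,
        'pooling_type': 'cls',           # assumed value
        'normalize_model_result': False,
        'layers_pre_pooling': [],
        'layers_post_pooling': [],
        'retriever_layer_sizes': [128],
    },
}
tokenizer = AutoTokenizer.from_pretrained(hyper_params['model']['bert_base'])
model = load_model(hyper_params, tokenizer, debug=False)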
Code Example #4
File: retriever.py  Project: jghosn/bert_reranker
    def __init__(self, bert_question_encoder, bert_paragraph_encoder, tokenizer, max_question_len,
                 max_paragraph_len, debug, model_hyper_params, previous_hidden_size):
        super(FeedForwardRetriever, self).__init__(
            bert_question_encoder, bert_paragraph_encoder, tokenizer, max_question_len,
            max_paragraph_len, debug)
        self.returns_embeddings = False

        check_and_log_hp(['retriever_layer_sizes'], model_hyper_params)
        ffw_layers = get_ffw_layers(
            previous_hidden_size * 2, model_hyper_params['dropout'],
            model_hyper_params['retriever_layer_sizes'] + [1], False)
        self.ffw_net = nn.Sequential(*ffw_layers)
Code Example #5
File: bert_encoder.py  Project: jghosn/bert_reranker
    def __init__(self, hyper_params, bert_model, name=''):
        model_hparams = hyper_params['model']
        check_and_log_hp(
            ['bert_base', 'dropout_bert', 'freeze_bert', 'cache_size'],
            model_hparams)
        super(CachedBertEncoder, self).__init__(hyper_params, bert_model, name=name)

        if not model_hparams['freeze_bert'] or model_hparams['dropout_bert'] != 0.0:
            raise ValueError('to cache results, set freeze_bert=True and dropout_bert=0.0')
        self.cache = {}
        self.cache_hit = 0
        self.cache_miss = 0
        self.max_cache_size = model_hparams['cache_size']
Code Example #6
    def __init__(self, hyper_params, voc_size):
        # NOTE: raised unconditionally - this encoder is disabled until padding
        # in the CNN layers is handled, so the code below is never reached.
        raise ValueError(
            'need to understand how to handle padding in the cnn layers')
        model_hparams = hyper_params['model']
        check_and_log_hp(['cnn_layer_sizes', 'emb_size'], model_hparams)
        super(CNNEncoder,
              self).__init__(hyper_params,
                             hyper_params['model']['cnn_layer_sizes'][-1])

        emb_size = hyper_params['model']['emb_size']
        self.embedding = nn.Embedding(voc_size, emb_size)
        self.cnn = nn.Conv1d(emb_size,
                             hyper_params['model']['cnn_layer_sizes'][-1],
                             10,
                             stride=2)
Code Example #7
    def __init__(self, hyper_params, name=''):
        model_hparams = hyper_params['model']
        check_and_log_hp(['bert_base', 'dropout_bert', 'freeze_bert'],
                         model_hparams)
        bert = AutoModel.from_pretrained(model_hparams['bert_base'])
        super(BertEncoder, self).__init__(hyper_params,
                                          bert.config.hidden_size)
        self.bert = bert
        self.name = name

        bert_dropout = model_hparams['dropout_bert']
        if bert_dropout is not None:
            logger.info('setting bert dropout to {}'.format(bert_dropout))
            self.bert.config.attention_probs_dropout_prob = bert_dropout
            self.bert.config.hidden_dropout_prob = bert_dropout
        else:
            logger.info('using the original bert model dropout')

        self.freeze_bert = model_hparams['freeze_bert']
Code Example #8
def load_model(hyper_params, tokenizer, debug):
    check_and_log_hp(['name', 'single_encoder'], hyper_params['model'])
    if hyper_params['model']['name'] == 'bert_encoder':
        if hyper_params['model'].get('cache_size', 0) > 0:
            encoder_class = CachedBertEncoder
        else:
            encoder_class = BertEncoder

        bert_paragraph_encoder, bert_question_encoder = _create_encoders(
            encoder_class, hyper_params)

        model = EmbeddingRetriever(bert_question_encoder,
                                   bert_paragraph_encoder, tokenizer,
                                   hyper_params['max_question_len'],
                                   hyper_params['max_paragraph_len'], debug)
    elif hyper_params['model']['name'] == 'bert_ffw':

        if hyper_params['model'].get('cache_size', 0) > 0:
            encoder_class = CachedBertEncoder
        else:
            encoder_class = BertEncoder

        bert_paragraph_encoder, bert_question_encoder = _create_encoders(
            encoder_class, hyper_params)

        if bert_question_encoder.post_pooling_last_hidden_size != \
                bert_paragraph_encoder.post_pooling_last_hidden_size:
            raise ValueError(
                "question/paragraph encoder should have the same output hidden size"
            )
        previous_hidden_size = bert_question_encoder.post_pooling_last_hidden_size
        model = FeedForwardRetriever(bert_question_encoder,
                                     bert_paragraph_encoder,
                                     tokenizer,
                                     hyper_params['max_question_len'],
                                     hyper_params['max_paragraph_len'],
                                     debug,
                                     hyper_params['model'],
                                     previous_hidden_size=previous_hidden_size)
    else:
        raise ValueError('model name {} not supported'.format(
            hyper_params['model']['name']))
    return model
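
For example (illustrative values): with hyper_params['model'] = {'name': 'bert_encoder', 'cache_size': 1000, ...} this version selects CachedBertEncoder, while leaving cache_size out (or setting it to 0) falls back to the plain BertEncoder.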
Code Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        help='config file with generic hyper-parameters, such as optimizer, '
                             'batch_size, ... - in yaml format', required=True)
    parser.add_argument('--gpu', help='list of gpu ids to use. default is cpu. example: --gpu 0 1',
                        type=int, nargs='+', default=0)
    parser.add_argument('--validation-interval', help='how often to run validation in one epoch - '
                                                      'e.g., 0.5 means halfway - default 0.5',
                        type=float, default=0.5)
    parser.add_argument('--output', help='where to store models', required=True)
    parser.add_argument('--no-model-restoring', help='will not restore any previous model weights ('
                                                     'even if present)', action='store_true')
    parser.add_argument('--train', help='will train the model',
                        action='store_true')
    parser.add_argument('--validate', help='will not train - will just evaluate on dev',
                        action='store_true')
    parser.add_argument('--predict', help='will predict on the json file you provide as an arg')
    parser.add_argument('--redirect-log', help='will intercept any stdout/err and log it',
                        action='store_true')
    parser.add_argument('--debug', help='will log more info', action='store_true')
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    if args.redirect_log:
        sys.stdout = LoggerWriter(logger.info)
        sys.stderr = LoggerWriter(logger.warning)

    with open(args.config, 'r') as stream:
        hyper_params = load(stream, Loader=yaml.FullLoader)

    check_and_log_hp(
        ['train_file', 'dev_file', 'cache_folder', 'batch_size', 'model_name',
         'max_question_len', 'max_paragraph_len', 'patience', 'gradient_clipping',
         'loss_type', 'optimizer_type', 'freeze_bert', 'pooling_type', 'precision',
         'top_layer_sizes', 'dropout'],
        hyper_params)

    os.makedirs(hyper_params['cache_folder'], exist_ok=True)

    model_name = hyper_params['model_name']
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataloader, dev_dataloader = generate_dataloaders(
        hyper_params['train_file'], hyper_params['dev_file'], hyper_params['cache_folder'],
        hyper_params['max_question_len'], hyper_params['max_paragraph_len'],
        tokenizer, hyper_params['batch_size'])

    bert_question = AutoModel.from_pretrained(model_name)
    bert_paragraph = AutoModel.from_pretrained(model_name)

    bert_question_encoder = BertEncoder(bert_question, hyper_params['max_question_len'],
                                        hyper_params['freeze_bert'], hyper_params['pooling_type'],
                                        hyper_params['top_layer_sizes'], hyper_params['dropout'])
    bert_paragraph_encoder = BertEncoder(bert_paragraph, hyper_params['max_paragraph_len'],
                                         hyper_params['freeze_bert'], hyper_params['pooling_type'],
                                         hyper_params['top_layer_sizes'], hyper_params['dropout'])

    ret = Retriever(bert_question_encoder, bert_paragraph_encoder, tokenizer,
                    hyper_params['max_question_len'], hyper_params['max_paragraph_len'], args.debug)

    os.makedirs(args.output, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(args.output, '{epoch}-{val_loss:.2f}-{val_acc:.2f}'),
        save_top_k=1,
        verbose=True,
        monitor='val_acc',
        mode='max'
    )

    early_stopping = EarlyStopping('val_acc', mode='max', patience=hyper_params['patience'])

    if hyper_params['precision'] not in {16, 32}:
        raise ValueError('precision should be either 16 or 32')

    if not args.no_model_restoring:
        ckpt_to_resume = try_to_restore_model_weights(args.output)

    else:
        ckpt_to_resume = None
        logger.info('will not try to restore previous models because --no-model-restoring')

    trainer = pl.Trainer(
        gpus=args.gpu,
        distributed_backend='dp',
        val_check_interval=args.validation_interval,
        min_epochs=1,
        gradient_clip_val=hyper_params['gradient_clipping'],
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=early_stopping,
        precision=hyper_params['precision'],
        resume_from_checkpoint=ckpt_to_resume)

    # note we are passing dev_dataloader for both dev and test
    ret_trainee = RetrieverTrainer(ret, train_dataloader, dev_dataloader, dev_dataloader,
                                   hyper_params['loss_type'], hyper_params['optimizer_type'])

    if args.train:
        trainer.fit(ret_trainee)
    elif args.validate:
        trainer.test(ret_trainee)
    elif args.predict:
        model_ckpt = torch.load(
            ckpt_to_resume, map_location=torch.device("cpu")
        )
        ret_trainee.load_state_dict(model_ckpt["state_dict"])
        evaluate_model(ret_trainee, qa_pairs_json_file=args.predict)
    else:
        logger.warning('please select one between --train / --validate / --predict')
Code Example #10
def init_model(hyper_params, num_workers, output, validation_interval, gpu,
               no_model_restoring, debug):

    check_and_log_hp([
        'train_file', 'dev_files', 'test_file', 'batch_size', 'tokenizer_name',
        'model', 'max_question_len', 'max_paragraph_len', 'patience',
        'gradient_clipping', 'max_epochs', 'loss_type', 'optimizer',
        'precision', 'accumulate_grad_batches', 'seed'
    ], hyper_params)

    if hyper_params['seed'] is not None:
        # fix the seed
        torch.manual_seed(hyper_params['seed'])
        np.random.seed(hyper_params['seed'])
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    tokenizer_name = hyper_params['tokenizer_name']
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    ret = load_model(hyper_params, tokenizer, debug)

    os.makedirs(output, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(filepath=os.path.join(
        output, '{epoch}-{val_acc_0:.2f}-{val_loss_0:.2f}'),
                                          save_top_k=1,
                                          verbose=True,
                                          monitor='val_acc_0',
                                          mode='max',
                                          period=0)
    early_stopping = EarlyStopping('val_acc_0',
                                   mode='max',
                                   patience=hyper_params['patience'])

    if (hyper_params['model'].get('name') == 'bert_encoder'
            and hyper_params['model'].get('cache_size', 0) > 0):
        cbs = [CacheManagerCallback(ret, output)]
    else:
        cbs = []

    if hyper_params['precision'] not in {16, 32}:
        raise ValueError('precision should be either 16 or 32')
    if not no_model_restoring:
        ckpt_to_resume = try_to_restore_model_weights(output)
    else:
        ckpt_to_resume = None
        logger.info(
            'will not try to restore previous models because --no-model-restoring'
        )
    tb_logger = loggers.TensorBoardLogger('experiment_logs')
    for hparam in list(hyper_params):
        tb_logger.experiment.add_text(hparam, str(hyper_params[hparam]))

    trainer = pl.Trainer(
        logger=tb_logger,
        gpus=gpu,
        distributed_backend='dp',
        val_check_interval=validation_interval,
        min_epochs=1,
        gradient_clip_val=hyper_params['gradient_clipping'],
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=early_stopping,
        callbacks=cbs,
        precision=hyper_params['precision'],
        resume_from_checkpoint=ckpt_to_resume,
        accumulate_grad_batches=hyper_params['accumulate_grad_batches'],
        max_epochs=hyper_params['max_epochs'])

    dev_dataloaders, test_dataloader, train_dataloader = get_data_loaders(
        hyper_params, num_workers, tokenizer)

    ret_trainee = RetrieverTrainer(ret, train_dataloader, dev_dataloaders,
                                   test_dataloader, hyper_params['loss_type'],
                                   hyper_params['optimizer'])
    return ckpt_to_resume, ret_trainee, trainer
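
A hedged sketch of how init_model's return values would typically be consumed, mirroring the training branch of the main() examples elsewhere on this page (hyper_params as loaded from the yaml config; argument values are illustrative):

ckpt_to_resume, ret_trainee, trainer = init_model(
    hyper_params, num_workers=0, output='output_dir', validation_interval=0.5,
    gpu=0, no_model_restoring=False, debug=False)
trainer.fit(ret_trainee)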
Code Example #11
File: main.py  Project: kiminh/bert_reranker
def init_model(
    hyper_params,
    num_workers,
    output,
    validation_interval,
    gpu,
    no_model_restoring,
    debug,
    print_sentence_stats
):

    check_and_log_hp(
        [
            "train_file",
            "dev_files",
            "test_file",
            "batch_size",
            "tokenizer_name",
            "model",
            "max_question_len",
            "max_paragraph_len",
            "patience",
            "gradient_clipping",
            "max_epochs",
            "loss_type",
            "optimizer",
            "precision",
            "accumulate_grad_batches",
            "seed",
            "logging",
            "keep_ood"
        ],
        hyper_params,
    )

    if hyper_params["seed"] is not None:
        # fix the seed
        torch.manual_seed(hyper_params["seed"])
        np.random.seed(hyper_params["seed"])
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    tokenizer_name = hyper_params["tokenizer_name"]
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    ret = load_model(hyper_params, tokenizer, debug)

    os.makedirs(output, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(output, "{epoch}-{val_acc_0:.2f}-{val_loss_0:.2f}"),
        save_top_k=1,
        verbose=True,
        monitor="val_acc_0",
        mode="max",
        period=0,
    )
    early_stopping = EarlyStopping(
        "val_acc_0", mode="max", patience=hyper_params["patience"]
    )

    if (
        hyper_params["model"].get("name") == "bert_encoder"
        and hyper_params["model"].get("cache_size", 0) > 0
    ):
        cbs = [CacheManagerCallback(ret, output)]
    else:
        cbs = []

    if hyper_params["precision"] not in {16, 32}:
        raise ValueError("precision should be either 16 or 32")
    if not no_model_restoring:
        ckpt_to_resume = try_to_restore_model_weights(output)
    else:
        ckpt_to_resume = None
        logger.info(
            "will not try to restore previous models because --no-model-restoring"
        )
    if hyper_params["logging"]["logger"] == "tensorboard":
        pl_logger = loggers.TensorBoardLogger("experiment_logs")
        for hparam in list(hyper_params):
            pl_logger.experiment.add_text(hparam, str(hyper_params[hparam]))
    elif hyper_params["logging"]["logger"] == "wandb":
        orion_trial_id = os.environ.get('ORION_TRIAL_ID')
        name = orion_trial_id if orion_trial_id else hyper_params["logging"]["name"]
        pl_logger = WandbLogger(
            name=name,
            project=hyper_params["logging"]["project"],
            group=hyper_params["logging"]["group"],
        )
        pl_logger.log_hyperparams(hyper_params)
    else:
        raise ValueError(
            "logger {} is not implemented".format(hyper_params["logging"]["logger"])
        )

    trainer = pl.Trainer(
        logger=pl_logger,
        gpus=gpu,
        distributed_backend="dp",
        val_check_interval=validation_interval,
        min_epochs=1,
        gradient_clip_val=hyper_params["gradient_clipping"],
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=early_stopping,
        callbacks=cbs,
        precision=hyper_params["precision"],
        resume_from_checkpoint=ckpt_to_resume,
        accumulate_grad_batches=hyper_params["accumulate_grad_batches"],
        max_epochs=hyper_params["max_epochs"],
    )

    dev_dataloaders, test_dataloader, train_dataloader = get_data_loaders(
        hyper_params, num_workers, tokenizer
    )

    if print_sentence_stats:
        evaluate_tokenizer_cutoff(
            hyper_params["train_file"],
            tokenizer,
            hyper_params["max_question_len"],
            hyper_params["max_paragraph_len"],
        )

    ret_trainee = RetrieverTrainer(
        ret,
        train_dataloader,
        dev_dataloaders,
        test_dataloader,
        hyper_params["loss_type"],
        hyper_params["optimizer"],
    )
    return ckpt_to_resume, ret_trainee, trainer
Code Example #12
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        help='config file with generic hyper-parameters, such as optimizer, '
        'batch_size, ... - in yaml format',
        required=True)
    parser.add_argument(
        '--gpu',
        help='list of gpu ids to use. default is cpu. example: --gpu 0 1',
        type=int,
        nargs='+',
        default=0)
    parser.add_argument('--validation-interval',
                        help='how often to run validation in one epoch - '
                        'e.g., 0.5 means halfway - default 0.5',
                        type=float,
                        default=0.5)
    parser.add_argument('--output',
                        help='where to store models',
                        required=True)
    parser.add_argument('--no-model-restoring',
                        help='will not restore any previous model weights ('
                        'even if present)',
                        action='store_true')
    parser.add_argument('--train',
                        help='will train the model',
                        action='store_true')
    parser.add_argument('--validate',
                        help='will not train - will just evaluate on dev',
                        action='store_true')
    parser.add_argument(
        '--predict',
        help='will predict on the json file you provide as an arg')
    parser.add_argument('--predict-to',
                        help='(optional) write predictions here')
    parser.add_argument('--redirect-log',
                        help='will intercept any stdout/err and log it',
                        action='store_true')
    parser.add_argument('--debug',
                        help='will log more info',
                        action='store_true')
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    if args.redirect_log:
        sys.stdout = LoggerWriter(logger.info)
        sys.stderr = LoggerWriter(logger.warning)

    with open(args.config, 'r') as stream:
        hyper_params = load(stream, Loader=yaml.FullLoader)

    check_and_log_hp([
        'train_file', 'dev_files', 'test_file', 'cache_folder', 'batch_size',
        'tokenizer_name', 'model', 'max_question_len', 'max_paragraph_len',
        'patience', 'gradient_clipping', 'max_epochs', 'loss_type',
        'optimizer', 'precision', 'accumulate_grad_batches', 'seed'
    ], hyper_params)

    if hyper_params['seed'] is not None:
        # fix the seed
        torch.manual_seed(hyper_params['seed'])
        np.random.seed(hyper_params['seed'])
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.makedirs(hyper_params['cache_folder'], exist_ok=True)

    tokenizer_name = hyper_params['tokenizer_name']
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    train_dataloader = generate_dataloader(hyper_params['train_file'],
                                           hyper_params['cache_folder'],
                                           hyper_params['max_question_len'],
                                           hyper_params['max_paragraph_len'],
                                           tokenizer,
                                           hyper_params['batch_size'])

    dev_dataloaders = []
    for dev_file in hyper_params['dev_files'].values():
        dev_dataloaders.append(
            generate_dataloader(dev_file, hyper_params['cache_folder'],
                                hyper_params['max_question_len'],
                                hyper_params['max_paragraph_len'], tokenizer,
                                hyper_params['batch_size']))

    test_dataloader = generate_dataloader(hyper_params['test_file'],
                                          hyper_params['cache_folder'],
                                          hyper_params['max_question_len'],
                                          hyper_params['max_paragraph_len'],
                                          tokenizer,
                                          hyper_params['batch_size'])

    ret = load_model(hyper_params, tokenizer, args.debug)

    os.makedirs(args.output, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(filepath=os.path.join(
        args.output, '{epoch}-{val_acc_0:.2f}-{val_loss_0:.2f}'),
                                          save_top_k=1,
                                          verbose=True,
                                          monitor='val_acc_0',
                                          mode='max')

    early_stopping = EarlyStopping('val_acc_0',
                                   mode='max',
                                   patience=hyper_params['patience'])

    if hyper_params['precision'] not in {16, 32}:
        raise ValueError('precision should be either 16 or 32')

    if not args.no_model_restoring:
        ckpt_to_resume = try_to_restore_model_weights(args.output)
    else:
        ckpt_to_resume = None
        logger.info(
            'will not try to restore previous models because --no-model-restoring'
        )

    tb_logger = loggers.TensorBoardLogger('experiment_logs')
    for hparam in list(hyper_params):
        tb_logger.experiment.add_text(hparam, str(hyper_params[hparam]))

    trainer = pl.Trainer(
        logger=tb_logger,
        gpus=args.gpu,
        distributed_backend='dp',
        val_check_interval=args.validation_interval,
        min_epochs=1,
        gradient_clip_val=hyper_params['gradient_clipping'],
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=early_stopping,
        precision=hyper_params['precision'],
        resume_from_checkpoint=ckpt_to_resume,
        accumulate_grad_batches=hyper_params['accumulate_grad_batches'],
        max_epochs=hyper_params['max_epochs'])

    ret_trainee = RetrieverTrainer(ret, train_dataloader, dev_dataloaders,
                                   test_dataloader, hyper_params['loss_type'],
                                   hyper_params['optimizer'])

    if args.train:
        trainer.fit(ret_trainee)
    elif args.validate:
        trainer.test(ret_trainee)
    elif args.predict:
        model_ckpt = torch.load(ckpt_to_resume,
                                map_location=torch.device("cpu"))
        ret_trainee.load_state_dict(model_ckpt["state_dict"])
        evaluate_model(ret_trainee,
                       qa_pairs_json_file=args.predict,
                       predict_to=args.predict_to)
    else:
        logger.warning(
            'please select one between --train / --validate / --predict')
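
A hedged usage sketch for this script (the actual module path is not shown in this excerpt, so main.py is assumed); every flag comes from the argparse definitions above, and the file names are placeholders:

    python main.py --config config.yaml --output output_dir --gpu 0 --train
    python main.py --config config.yaml --output output_dir --predict questions.json --predict-to predictions.json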