def main():
    """Train a TriviaQA-web model using the mode given on the command line.

    Reads ``mode``, ``name`` and ``--n_processes`` from the command line,
    builds the model and the mode-specific paragraph extraction, batching and
    evaluation configuration, preprocesses the corpus and then calls
    ``trainer.start_training``.  The text of this script, prefixed with the
    selected mode, is passed along as the run notes.
    """
    parser = argparse.ArgumentParser(description='Train a model on TriviaQA web')
    parser.add_argument('mode', choices=["paragraph-level", "confidence", "merge",
                                         "shared-norm", "sigmoid", "shared-norm-600"])
    parser.add_argument("name", help="Where to store the model")
    parser.add_argument('-n', '--n_processes', type=int, default=2,
                        help="Number of processes to preprocess the data with "
                             "(i.e., select which paragraphs to train on)")
    args = parser.parse_args()
    mode = args.mode

    # The timestamp suffix distinguishes separate runs that use the same name.
    out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")

    model = get_model(100, 140, mode, WithIndicators())

    stop = NltkPlusStopWords(True)

    if mode == "paragraph-level":
        extract = ExtractSingleParagraph(MergeParagraphs(400), TopTfIdf(stop, 1), model.preprocessor, intern=True)
        n_epochs = 16
        train = ParagraphAndQuestionsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True))
        test = ParagraphAndQuestionsBuilder(ClusteredBatcher(60, ContextLenKey(), False))
        n_dev, n_train = 21000, 12000
        evaluators = [LossEvaluator(), SpanEvaluator([4, 8], "triviaqa")]
    else:
        # "shared-norm-600" uses MergeParagraphs(600); every other multi-paragraph mode uses 400.
        merge_size = 600 if mode == "shared-norm-600" else 400
        extract = ExtractMultiParagraphs(MergeParagraphs(merge_size), TopTfIdf(stop, 4),
                                         model.preprocessor, intern=True)

        evaluators = [LossEvaluator(), MultiParagraphSpanEvaluator(8, "triviaqa", mode != "merge")]
        # we sample two paragraphs per a (question, doc) pair, so evaluate on fewer questions
        n_dev, n_train = 15000, 8000

        if mode in ("confidence", "sigmoid"):
            # "sigmoid" trains very slowly, do this at your own risk
            n_epochs = 71 if mode == "sigmoid" else 28
            test = RandomParagraphSetDatasetBuilder(120, "flatten", True, 1)
            train = StratifyParagraphsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True), 0, 1)
        else:
            n_epochs = 14
            test = RandomParagraphSetDatasetBuilder(120, "merge" if mode == "merge" else "group", True, 1)
            train = StratifyParagraphSetsBuilder(35, mode == "merge", True, 1)

    data = TriviaQaWebDataset()

    params = get_triviaqa_train_params(n_epochs, n_dev, n_train)

    data = PreprocessedData(data, extract, train, test, eval_on_verified=False)

    data.preprocess(args.n_processes, 1000)

    # Keep a copy of this script with the run so the exact configuration is recorded.
    with open(__file__, "r") as f:
        notes = f.read()
    notes = "*" * 10 + "\nMode: " + args.mode + "\n" + "*"*10 + "\n" + notes

    trainer.start_training(data, model, params, evaluators, model_dir.ModelDir(out), notes)
Exemple #2
0
def main():
    """Train the ELMo-augmented attention model on SQuAD.

    ``loss_mode`` selects between the single-span setup (``default``) and the
    confidence setup (``confidence``); the latter preprocesses the corpus and,
    unless ``--no-tfidf`` is given, uses ``SquadTfIdfRanker`` as the
    preprocessor.  The sorted command-line arguments followed by the text of
    this script are passed to ``trainer.start_training`` as the run notes.
    """
    # The text must be passed as ``description``: the first positional
    # parameter of ArgumentParser is ``prog``, which would replace the program
    # name in the usage line instead of describing the script.
    parser = argparse.ArgumentParser(description="Train our ELMo model on SQuAD")
    parser.add_argument("loss_mode", choices=['default', 'confidence'])
    parser.add_argument("output_dir")
    parser.add_argument("--dim", type=int, default=90)
    parser.add_argument("--l2", type=float, default=0)
    parser.add_argument("--mode",
                        choices=["input", "output", "both", "none"],
                        default="both")
    parser.add_argument("--top_layer_only", action="store_true")
    parser.add_argument("--no-tfidf",
                        action='store_true',
                        help="Don't add TF-IDF negative examples")
    args = parser.parse_args()

    # The timestamp suffix distinguishes separate runs that use the same output_dir.
    out = args.output_dir + "-" + datetime.now().strftime("%m%d-%H%M%S")

    dim = args.dim
    # This single layer object is passed everywhere below that a recurrent layer is needed.
    recurrent_layer = CudnnGru(dim, w_init=TruncatedNormal(stddev=0.05))

    if args.loss_mode == 'default':
        n_epochs = 24
        answer_encoder = SingleSpanAnswerEncoder()
        predictor = BoundsPredictor(
            ChainBiMapper(first_layer=recurrent_layer,
                          second_layer=recurrent_layer))
        # The same batcher is used for both batcher arguments of the training data.
        batcher = ClusteredBatcher(45, ContextLenKey(), False, False)
        data = DocumentQaTrainingData(SquadCorpus(), None, batcher, batcher)
    elif args.loss_mode == 'confidence':
        if args.no_tfidf:
            prepro = SquadDefault()
            n_epochs = 15
        else:
            prepro = SquadTfIdfRanker(NltkPlusStopWords(True), 4, True)
            n_epochs = 50
        answer_encoder = DenseMultiSpanAnswerEncoder()
        predictor = ConfidencePredictor(
            ChainBiMapper(
                first_layer=recurrent_layer,
                second_layer=recurrent_layer,
            ),
            AttentionEncoder(),
            FullyConnected(80, activation="tanh"),
            aggregate="sum")
        eval_dataset = RandomParagraphSetDatasetBuilder(
            100, 'flatten', True, 0)
        train_batching = ClusteredBatcher(45, ContextLenBucketedKey(3), True,
                                          False)
        data = PreprocessedData(SquadCorpus(),
                                prepro,
                                StratifyParagraphsBuilder(train_batching, 1),
                                eval_dataset,
                                eval_on_verified=False)
        data.preprocess(1)

    params = trainer.TrainParams(
        trainer.SerializableOptimizer("Adadelta", dict(learning_rate=1.0)),
        ema=0.999,
        max_checkpoints_to_keep=2,
        async_encoding=10,
        num_epochs=n_epochs,
        log_period=30,
        eval_period=1200,
        save_period=1200,
        best_weights=("dev", "b17/text-f1"),
        eval_samples=dict(dev=None, train=8000))

    lm_reduce = MapperSeq(
        ElmoLayer(args.l2,
                  layer_norm=False,
                  top_layer_only=args.top_layer_only),
        DropoutLayer(0.5),
    )
    model = AttentionWithElmo(
        encoder=DocumentAndQuestionEncoder(answer_encoder),
        lm_model=SquadContextConcatSkip(),
        # --mode decides which of the two append flags are enabled.
        append_before_atten=(args.mode == "both" or args.mode == "output"),
        append_embed=(args.mode == "both" or args.mode == "input"),
        max_batch_size=128,
        word_embed=FixedWordEmbedder(vec_name="glove.840B.300d",
                                     word_vec_init_scale=0,
                                     learn_unk=False,
                                     cpu=True),
        char_embed=CharWordEmbedder(LearnedCharEmbedder(word_size_th=14,
                                                        char_th=49,
                                                        char_dim=20,
                                                        init_scale=0.05,
                                                        force_cpu=True),
                                    MaxPool(Conv1d(100, 5, 0.8)),
                                    shared_parameters=True),
        embed_mapper=SequenceMapperSeq(
            VariationalDropoutLayer(0.8),
            recurrent_layer,
            VariationalDropoutLayer(0.8),
        ),
        lm_reduce=None,
        lm_reduce_shared=lm_reduce,
        per_sentence=False,
        memory_builder=NullBiMapper(),
        attention=BiAttention(TriLinear(bias=True), True),
        match_encoder=SequenceMapperSeq(
            FullyConnected(dim * 2, activation="relu"),
            ResidualLayer(
                SequenceMapperSeq(
                    VariationalDropoutLayer(0.8),
                    recurrent_layer,
                    VariationalDropoutLayer(0.8),
                    StaticAttentionSelf(TriLinear(bias=True),
                                        ConcatWithProduct()),
                    FullyConnected(dim * 2, activation="relu"),
                )), VariationalDropoutLayer(0.8)),
        predictor=predictor)

    # Record the arguments (sorted by name) and a copy of this script with the run.
    with open(__file__, "r") as f:
        notes = f.read()
        notes = str(sorted(args.__dict__.items(),
                           key=lambda x: x[0])) + "\n" + notes

    trainer.start_training(
        data, model, params,
        [LossEvaluator(),
         SpanEvaluator(bound=[17], text_eval="squad")], ModelDir(out), notes)
Exemple #3
0
def main():
    """Evaluate a trained model on SQuAD-style questions and print its metrics.

    The corpus split (``dev``, ``train``, ``ja_test`` or ``pred``) and the
    weights to load are chosen on the command line.  Scalar results are
    printed as a two-column table; when ``--official_output`` is given, a JSON
    mapping of question id to predicted answer text is written to that path.
    """
    parser = argparse.ArgumentParser(description='Evaluate a model on SQuAD')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument("-o", "--official_output", type=str,
                        help="where to output an official result file")
    parser.add_argument('-n', '--sample_questions', type=int, default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument('--answer_bounds', nargs='+', type=int, default=[17],
                        help="Max size of answer")
    parser.add_argument('-b', '--batch_size', type=int, default=200,
                        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument('-s', '--step', default=None,
                        help="Weights to load, can be a checkpoint step or 'latest'")
    # "ja_test" and "pred" are the choices added for the Multilingual QA dataset / pipeline.
    parser.add_argument(
        '-c', '--corpus', choices=["dev", "train", "ja_test", "pred"], default="dev")
    parser.add_argument('--no_ema', action="store_true",
                        help="Don't use EMA weights even if they exist")
    # Input file consumed by the "pred" corpus choice.
    parser.add_argument('-p', '--pred_filepath', default=None,
                        help="The csv file path if you try pred mode")
    args = parser.parse_args()

    model_dir = ModelDir(args.model)

    corpus = SquadCorpus()
    if args.corpus == "dev":
        questions = corpus.get_dev()
    elif args.corpus == "ja_test":
        # Multilingual QA test split.
        questions = corpus.get_ja_test()
    elif args.corpus == "pred":
        # Prediction mode for the MLQA pipeline.
        # NOTE(review): --pred_filepath defaults to None; confirm create_pred_dataset handles that.
        questions = create_pred_dataset(args.pred_filepath)
    else:
        questions = corpus.get_train()
    questions = split_docs(questions)

    if args.sample_questions:
        # Sort by id first so the seeded shuffle does not depend on the order
        # the questions were loaded in, then shuffle the list that is actually
        # kept.  Shuffling a temporary sorted copy would leave `questions`
        # untouched and the sample would simply be the first N questions.
        questions = sorted(questions, key=lambda x: x.question_id)
        np.random.RandomState(0).shuffle(questions)
        questions = questions[:args.sample_questions]

    # Order questions by descending context length before fixed-order batching.
    questions.sort(key=lambda x: x.n_context_words, reverse=True)
    dataset = ParagraphAndQuestionDataset(
        questions, FixedOrderBatcher(args.batch_size, True))

    evaluators = [SpanEvaluator(args.answer_bounds, text_eval="squad")]
    if args.official_output is not None:
        # Only the first answer bound is used for the recorded predictions.
        evaluators.append(RecordSpanPrediction(args.answer_bounds[0]))

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        # No explicit step: prefer the best weights, fall back to the latest checkpoint.
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {args.corpus: dataset},
                              corpus.get_resource_loader(), checkpoint, not args.no_ema)[args.corpus]

    # Print the scalar results in a two column table
    scalars = evaluation.scalars
    cols = sorted(scalars.keys())
    header = ["Metric", ""]
    table = [cols, ["%s" % scalars[x] for x in cols]]
    print_table([header] + transpose_lists(table))

    # Save the official output
    if args.official_output is not None:
        quid_to_para = {x.question_id: x.paragraph for x in questions}

        q_ids = evaluation.per_sample["question_id"]
        spans = evaluation.per_sample["predicted_span"]
        q_id_to_answers = {}
        for q_id, (start, end) in zip(q_ids, spans):
            # Map the predicted span back to the paragraph's original text.
            q_id_to_answers[q_id] = quid_to_para[q_id].get_original_text(start, end)

        with open(args.official_output, "w") as f:
            json.dump(q_id_to_answers, f)
def main():
    """Evaluate a trained model on SQuAD, optionally dumping per-question results.

    Scalar results are printed as a two-column table.  ``--official_output``
    writes a JSON mapping of question id to predicted answer text, and
    ``--per_question_loss_file`` writes a JSON mapping of question id to either
    its f1/em/loss values or, with ``--none_prob``, its none probability.
    Unrecognised command-line arguments are ignored (``parse_known_args``).
    """
    parser = argparse.ArgumentParser(description='Evaluate a model on SQuAD')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument("-o", "--official_output", type=str, help="where to output an official result file")
    parser.add_argument('-n', '--sample_questions', type=int, default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument('--answer_bounds', nargs='+', type=int, default=[17],
                        help="Max size of answer")
    parser.add_argument('-b', '--batch_size', type=int, default=200,
                        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument('-s', '--step', default=None,
                        help="Weights to load, can be a checkpoint step or 'latest'")
    parser.add_argument('-c', '--corpus', choices=["dev", "train"], default="dev")
    parser.add_argument('--no_ema', action="store_true", help="Don't use EMA weights even if they exist")
    parser.add_argument('--none_prob', action="store_true", help="Output none probability for samples")
    parser.add_argument('--elmo', action="store_true", help="Use elmo model")
    parser.add_argument('--per_question_loss_file', type=str, default=None,
            help="Run question by question and output a question_id -> loss output to this file")
    args = parser.parse_known_args()[0]

    model_dir = ModelDir(args.model)

    corpus = SquadCorpus()
    if args.corpus == "dev":
        questions = corpus.get_dev()
    else:
        questions = corpus.get_train()
    questions = split_docs(questions)

    if args.sample_questions:
        # Sort by id first so the seeded shuffle does not depend on the order
        # the questions were loaded in, then shuffle the list that is actually
        # kept.  Shuffling a temporary sorted copy would leave `questions`
        # untouched and the sample would simply be the first N questions.
        questions = sorted(questions, key=lambda x: x.question_id)
        np.random.RandomState(0).shuffle(questions)
        questions = questions[:args.sample_questions]

    # Order questions by descending context length before fixed-order batching.
    questions.sort(key=lambda x: x.n_context_words, reverse=True)
    dataset = ParagraphAndQuestionDataset(questions, FixedOrderBatcher(args.batch_size, True))

    evaluators = [SpanEvaluator(args.answer_bounds, text_eval="squad")]
    if args.official_output is not None:
        evaluators.append(RecordSpanPrediction(args.answer_bounds[0]))
    if args.per_question_loss_file is not None:
        evaluators.append(RecordSpanPredictionScore(args.answer_bounds[0], args.batch_size, args.none_prob))

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        # No explicit step: prefer the best weights, fall back to the latest checkpoint.
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()
    if args.elmo:
        # Point the language model at the local ELMo parameter files.
        model.lm_model.lm_vocab_file = './elmo-params/squad_train_dev_all_unique_tokens.txt'
        model.lm_model.options_file = './elmo-params/options_squad_lm_2x4096_512_2048cnn_2xhighway_skip.json'
        model.lm_model.weight_file = './elmo-params/squad_context_concat_lm_2x4096_512_2048cnn_2xhighway_skip.hdf5'
        model.lm_model.embed_weights_file = None

    evaluation = trainer.test(model, evaluators, {args.corpus: dataset},
                              corpus.get_resource_loader(), checkpoint, not args.no_ema)[args.corpus]

    # Print the scalar results in a two column table
    scalars = evaluation.scalars
    cols = sorted(scalars.keys())
    header = ["Metric", ""]
    table = [cols, ["%s" % scalars[x] for x in cols]]
    print_table([header] + transpose_lists(table))

    # Save the official output
    if args.official_output is not None:
        quid_to_para = {x.question_id: x.paragraph for x in questions}

        q_ids = evaluation.per_sample["question_id"]
        spans = evaluation.per_sample["predicted_span"]
        q_id_to_answers = {}
        for q_id, (start, end) in zip(q_ids, spans):
            # Map the predicted span back to the paragraph's original text.
            q_id_to_answers[q_id] = quid_to_para[q_id].get_original_text(start, end)

        with open(args.official_output, "w") as f:
            json.dump(q_id_to_answers, f)

    if args.per_question_loss_file is not None:
        print("Saving result")
        output_file = args.per_question_loss_file
        # NOTE(review): this reads "question_ids" while the official-output path
        # above reads "question_id" -- confirm which key RecordSpanPredictionScore emits.
        ids = evaluation.per_sample["question_ids"]
        f1s = evaluation.per_sample["text_f1"]
        ems = evaluation.per_sample["text_em"]
        losses = evaluation.per_sample["loss"]

        if args.none_prob:
            # With --none_prob only the none probability is written per question.
            none_probs = evaluation.per_sample["none_probs"]
            results = {question_id: float(none_prob) for question_id, none_prob in zip(ids, none_probs)}
        else:
            results = {question_id: {'f1': float(f1), 'em': float(em), 'loss': float(loss)}
                       for question_id, f1, em, loss in zip(ids, f1s, ems, losses)}

        with open(output_file, 'w') as f:
            json.dump(results, f)
Exemple #5
0
def main():
    """Train a document-level SQuAD model in the mode given on the command line.

    ``paragraph`` trains in the standard known-paragraph setting; every other
    mode builds preprocessed, TF-IDF ranked multi-paragraph data and runs
    ``data.preprocess(1)`` before training.  The mode followed by the text of
    this script is passed to ``trainer.start_training`` as the run notes.
    """
    parser = argparse.ArgumentParser(
        description='Train a model on document-level SQuAD')
    parser.add_argument(
        'mode',
        choices=["paragraph", "confidence", "shared-norm", "merge", "sigmoid"])
    parser.add_argument("name", help="Output directory")
    args = parser.parse_args()
    mode = args.mode
    out_dir = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")

    corpus = SquadCorpus()
    # "merge" adds paragraph start tokens, since paragraphs will be concatenated together
    pre = WithIndicators(True, para_tokens=False, doc_start_token=False) if mode == "merge" else None

    model = get_model(50, 100, args.mode, pre)

    if mode == "paragraph":
        # Run in the "standard" known-paragraph setting
        if model.preprocessor is not None:
            raise NotImplementedError()
        n_epochs = 26
        train_batching = ClusteredBatcher(45, ContextLenBucketedKey(3), True, False)
        eval_batching = ClusteredBatcher(45, ContextLenKey(), False, False)
        data = DocumentQaTrainingData(corpus, None, train_batching, eval_batching)
        evaluators = [LossEvaluator(), SpanEvaluator(bound=[17], text_eval="squad")]
    else:
        eval_set_modes = {
            "confidence": "flatten",
            "sigmoid": "flatten",
            "shared-norm": "group",
            "merge": "merge"
        }
        eval_dataset = RandomParagraphSetDatasetBuilder(100, eval_set_modes[mode], True, 0)

        if mode in ("confidence", "sigmoid"):
            # "sigmoid" needs to be trained for a really long time for reasons unknown, even 100 might
            # be too small; "confidence" gets more epochs since we only "see" the label every other
            # epoch-ish
            n_epochs = 100 if mode == "sigmoid" else 50
            train_batching = ClusteredBatcher(45, ContextLenBucketedKey(3), True, False)
            squad = SquadCorpus()
            ranker = SquadTfIdfRanker(NltkPlusStopWords(True), 4, True, model.preprocessor)
            train_builder = StratifyParagraphsBuilder(train_batching, 1)
        else:
            n_epochs = 26
            squad = SquadCorpus()
            ranker = SquadTfIdfRanker(NltkPlusStopWords(True), 4, True, model.preprocessor)
            train_builder = StratifyParagraphSetsBuilder(25, args.mode == "merge", True, 1)

        data = PreprocessedData(squad, ranker, train_builder, eval_dataset, eval_on_verified=False)

        evaluators = [LossEvaluator(), MultiParagraphSpanEvaluator(17, "squad")]
        data.preprocess(1)

    # Keep the mode and a copy of this script with the run.
    with open(__file__, "r") as f:
        notes = args.mode + "\n" + f.read()

    trainer.start_training(data, model, train_params(n_epochs), evaluators,
                           model_dir.ModelDir(out_dir), notes)
Exemple #6
0
def main():
    """
    A close-as-possible implementation of BiDAF. It is based on the `dev` tensorflow 1.1 branch of Ming's repo
    which, in particular, uses Adam not Adadelta. I was not able to replicate the results in the paper using
    Adadelta, but with Adam I was able to get to 78.0 F1 on the dev set with this script. I believe this approach
    is an exact reproduction of the code in the repo, up to initializations.

    Notes: Exponential Moving Average is very important, as is early stopping. This is also in particular best run
    on a GPU due to the large number of parameters and batch size involved.
    """
    out = get_output_name_from_cli()

    optimizer = SerializableOptimizer("Adam", dict(learning_rate=0.001))
    params = TrainParams(optimizer,
                         num_epochs=12,
                         ema=0.999,
                         async_encoding=10,
                         log_period=30,
                         eval_period=1000,
                         save_period=1000,
                         eval_samples=dict(dev=None, train=8000))

    # Alternative recurrent layers left in by the original author:
    #   BiRecurrentMapper(LstmCellSpec(100, keep_probs=0.8))
    #   FusedLstm()
    # The one layer object built here is passed everywhere a recurrent layer is needed below.
    recurrent_layer = SequenceMapperSeq(DropoutLayer(0.8), CudnnLstm(100))

    # Build the model components one by one, then assemble them.
    encoder = DocumentAndQuestionEncoder(SingleSpanAnswerEncoder())
    word_embed = FixedWordEmbedder(vec_name="glove.6B.100d",
                                   word_vec_init_scale=0,
                                   learn_unk=False)
    char_embed = CharWordEmbedder(
        embedder=LearnedCharEmbedder(16, 49, 8),
        layer=ReduceLayer("max", Conv1d(100, 5, 0.8), mask=False),
        shared_parameters=True)
    embed_mapper = SequenceMapperSeq(HighwayLayer(activation="relu"),
                                     HighwayLayer(activation="relu"),
                                     recurrent_layer)
    memory_builder = NullBiMapper()
    attention = BiAttention(TriLinear(bias=True), True)
    match_encoder = NullMapper()
    start_layer = SequenceMapperSeq(recurrent_layer, recurrent_layer)
    predictor = BoundsPredictor(
        ChainConcat(start_layer=start_layer, end_layer=recurrent_layer))

    model = Attention(
        encoder=encoder,
        word_embed=word_embed,
        char_embed=char_embed,
        word_embed_layer=None,
        embed_mapper=embed_mapper,
        preprocess=None,
        question_mapper=None,
        context_mapper=None,
        memory_builder=memory_builder,
        attention=attention,
        match_encoder=match_encoder,
        predictor=predictor,
    )

    # Keep a copy of this script with the run.
    with open(__file__, "r") as script:
        notes = script.read()

    evaluators = [LossEvaluator(), SpanEvaluator(bound=[17], text_eval="squad")]

    corpus = SquadCorpus()
    train_batching = ClusteredBatcher(60, ContextLenBucketedKey(3), True, False)
    eval_batching = ClusteredBatcher(60, ContextLenKey(), False, False)
    data = DocumentQaTrainingData(corpus, None, train_batching, eval_batching)

    trainer.start_training(data, model, params, evaluators,
                           model_dir.ModelDir(out), notes)
def main():
    """Train the ELMo-augmented rejection model on SQuAD.

    Corpus, output and language-model directories, the learning rate, epoch
    count, hidden dimension and batch size all come from the command line;
    ``~`` in the directory arguments is expanded.  The sorted command-line
    arguments followed by the text of this script are passed to
    ``trainer.start_training`` as the run notes.
    """
    # The text must be passed as ``description``: the first positional
    # parameter of ArgumentParser is ``prog``, which would replace the program
    # name in the usage line instead of describing the script.
    parser = argparse.ArgumentParser(description="Train rejection model on SQuAD")

    parser.add_argument("--corpus_dir", type=str, default="~/data/document-qa")
    parser.add_argument("--output_dir",
                        type=str,
                        default="~/model/document-qa/squad")
    parser.add_argument("--lm_dir", type=str, default="~/data/lm")
    parser.add_argument("--exp_id", type=str, default="rejection")

    parser.add_argument("--lr", type=float, default=0.5)
    parser.add_argument("--epoch", type=int, default=20)

    parser.add_argument("--dim", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=45)

    parser.add_argument("--l2", type=float, default=0)
    parser.add_argument("--mode",
                        choices=["input", "output", "both", "none"],
                        default="both")
    parser.add_argument("--top_layer_only", action="store_true")

    args = parser.parse_args()

    print("Arguments : ", args)

    # Output directory name: <output_dir>_<exp_id>_lr<lr>-<timestamp>.
    out = args.output_dir + "_" + args.exp_id + "_lr" + str(
        args.lr) + "-" + datetime.now().strftime("%m%d-%H%M%S")
    dim = args.dim
    batch_size = args.batch_size
    # The directory defaults start with "~", so expand it before use.
    out = expanduser(out)
    lm_dir = expanduser(args.lm_dir)
    corpus_dir = expanduser(args.corpus_dir)

    print("Make global recurrent_layer...")
    # This single layer object is passed everywhere below that a recurrent layer is needed.
    recurrent_layer = CudnnGru(
        dim, w_init=tf.keras.initializers.TruncatedNormal(stddev=0.05))
    params = trainer.TrainParams(
        trainer.SerializableOptimizer("Adadelta", dict(learning_rate=args.lr)),
        ema=0.999,
        max_checkpoints_to_keep=2,
        async_encoding=10,
        num_epochs=args.epoch,
        log_period=30,
        eval_period=1200,
        save_period=1200,
        best_weights=("dev", "b17/text-f1"),
        eval_samples=dict(dev=None, train=8000))

    lm_reduce = MapperSeq(
        ElmoLayer(args.l2,
                  layer_norm=False,
                  top_layer_only=args.top_layer_only),
        DropoutLayer(0.5),
    )

    model = AttentionWithElmo(
        encoder=DocumentAndQuestionEncoder(SingleSpanAnswerEncoder()),
        lm_model=SquadContextConcatSkip(lm_dir=lm_dir),
        # --mode decides which of the two append flags are enabled.
        append_before_atten=(args.mode == "both" or args.mode == "output"),
        append_embed=(args.mode == "both" or args.mode == "input"),
        max_batch_size=128,
        word_embed=FixedWordEmbedder(vec_name="glove.840B.300d",
                                     word_vec_init_scale=0,
                                     learn_unk=False,
                                     cpu=True),
        char_embed=CharWordEmbedder(LearnedCharEmbedder(word_size_th=14,
                                                        char_th=49,
                                                        char_dim=20,
                                                        init_scale=0.05,
                                                        force_cpu=True),
                                    MaxPool(Conv1d(100, 5, 0.8)),
                                    shared_parameters=True),
        embed_mapper=SequenceMapperSeq(
            VariationalDropoutLayer(0.8),
            recurrent_layer,
            VariationalDropoutLayer(0.8),
        ),
        lm_reduce=None,
        lm_reduce_shared=lm_reduce,
        per_sentence=False,
        memory_builder=NullBiMapper(),
        attention=BiAttention(TriLinear(bias=True), True),
        match_encoder=SequenceMapperSeq(
            FullyConnected(dim * 2, activation="relu"),
            ResidualLayer(
                SequenceMapperSeq(
                    VariationalDropoutLayer(0.8),
                    recurrent_layer,
                    VariationalDropoutLayer(0.8),
                    StaticAttentionSelf(TriLinear(bias=True),
                                        ConcatWithProduct()),
                    FullyConnected(dim * 2, activation="relu"),
                )), VariationalDropoutLayer(0.8)),
        predictor=BoundsPredictor(
            ChainBiMapper(first_layer=recurrent_layer,
                          second_layer=recurrent_layer)))

    # The same batcher is used for both batcher arguments of the training data.
    batcher = ClusteredBatcher(batch_size, ContextLenKey(), False, False)
    data = DocumentQaTrainingData(SquadCorpus(corpus_dir), None, batcher,
                                  batcher)

    # Record the arguments (sorted by name) and a copy of this script with the run.
    with open(__file__, "r") as f:
        notes = f.read()
        notes = str(sorted(args.__dict__.items(),
                           key=lambda x: x[0])) + "\n" + notes

    trainer.start_training(
        data, model, params,
        [LossEvaluator(),
         SpanEvaluator(bound=[17], text_eval="squad")], ModelDir(out), notes)