def get_model(char_th: int, dim: int, mode: str, preprocess: Optional[TextPreprocessor]):
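    # Build the document-QA model; `mode` selects the answer encoder and span predictor
    # (shared-norm grouped spans, confidence, sigmoid loss, or multiple-choice).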
    recurrent_layer = CudnnGru(dim, w_init=TruncatedNormal(stddev=0.05))
    if mode.startswith("shared-norm"):
        answer_encoder = GroupedSpanAnswerEncoder()
        predictor = BoundsPredictor(
            ChainBiMapper(
                first_layer=recurrent_layer,
                second_layer=recurrent_layer
            ),
            span_predictor=IndependentBoundsGrouped(aggregate="sum")
        )
    elif mode == "confidence":
        answer_encoder = DenseMultiSpanAnswerEncoder()
        predictor = ConfidencePredictor(
            ChainBiMapper(
                first_layer=recurrent_layer,
                second_layer=recurrent_layer,
            ),
            AttentionEncoder(),
            FullyConnected(80, activation="tanh"),
            aggregate="sum"
        )
    elif mode == "sigmoid":
        answer_encoder = DenseMultiSpanAnswerEncoder()
        predictor = BoundsPredictor(
            ChainBiMapper(
                first_layer=recurrent_layer,
                second_layer=recurrent_layer
            ),
            span_predictor=IndependentBoundsSigmoidLoss()
        )
    elif mode == "paragraph" or mode == "merge":
        answer_encoder = MultiChoiceAnswerEncoder()
        predictor = MultiChoicePredictor(4)
    else:
        raise NotImplementedError(mode)

    return Attention(
        encoder=DocumentAndQuestionEncoder(answer_encoder),
        word_embed=FixedWordEmbedder(vec_name="glove.840B.300d", word_vec_init_scale=0, learn_unk=False, cpu=True),
        char_embed=CharWordEmbedder(
            LearnedCharEmbedder(word_size_th=14, char_th=char_th, char_dim=20, init_scale=0.05, force_cpu=True),
            MaxPool(Conv1d(100, 5, 0.8)),
            shared_parameters=True
        ),
        preprocess=preprocess,
        word_embed_layer=None,
        embed_mapper=SequenceMapperSeq(
            VariationalDropoutLayer(0.8),
            recurrent_layer,
            VariationalDropoutLayer(0.8),
        ),
        question_mapper=None,
        context_mapper=None,
        memory_builder=NullBiMapper(),
        attention=BiAttention(TriLinear(bias=True), True),
        match_encoder=SequenceMapperSeq(FullyConnected(dim * 2, activation="relu"),
                                        ResidualLayer(SequenceMapperSeq(
                                            VariationalDropoutLayer(0.8),
                                            recurrent_layer,
                                            VariationalDropoutLayer(0.8),
                                            StaticAttentionSelf(TriLinear(bias=True), ConcatWithProduct()),
                                            FullyConnected(dim * 2, activation="relu"),
                                        )),
                                        VariationalDropoutLayer(0.8)),
        predictor=predictor
    )
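For illustration only, a hypothetical call (the argument values below are assumptions, not taken from this file) might build the shared-norm variant like this:

model = get_model(char_th=14, dim=100, mode="shared-norm", preprocess=None)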
Example 2
def main():
    """
    A close-as-possible impelemntation of BiDaF, its based on the `dev` tensorflow 1.1 branch of Ming's repo
    which, in particular, uses Adam not Adadelta. I was not able to replicate the results in paper using Adadelta,
    but with Adam i was able to get to 78.0 F1 on the dev set with this scripts. I believe this approach is
    an exact reproduction up the code in the repo, up to initializations.

    Notes: Exponential Moving Average is very important, as is early stopping. This is also in particualr best run
    on a GPU due to the large number of parameters and batch size involved.
    """
    out = get_output_name_from_cli()

    train_params = TrainParams(
        SerializableOptimizer("Adam", dict(learning_rate=0.001)),
        num_epochs=12, ema=0.999, async_encoding=10,
        log_period=30, eval_period=1000, save_period=1000,
        eval_samples=dict(dev=None, train=8000)
    )

    # recurrent_layer = BiRecurrentMapper(LstmCellSpec(100, keep_probs=0.8))
    # recurrent_layer = FusedLstm()
    recurrent_layer = SequenceMapperSeq(DropoutLayer(0.8), CudnnLstm(100))

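    # BiDAF-style model: GloVe + learned character embeddings, two highway layers,
    # an LSTM encoder, bi-attention over the context, and a start/end bounds predictor.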
    model = Attention(
        encoder=DocumentAndQuestionEncoder(SingleSpanAnswerEncoder()),
        word_embed=FixedWordEmbedder(vec_name="glove.6B.100d",
                                     word_vec_init_scale=0,
                                     learn_unk=False),
        char_embed=CharWordEmbedder(embedder=LearnedCharEmbedder(16, 49, 8),
                                    layer=ReduceLayer("max",
                                                      Conv1d(100, 5, 0.8),
                                                      mask=False),
                                    shared_parameters=True),
        word_embed_layer=None,
        embed_mapper=SequenceMapperSeq(HighwayLayer(activation="relu"),
                                       HighwayLayer(activation="relu"),
                                       recurrent_layer),
        preprocess=None,
        question_mapper=None,
        context_mapper=None,
        memory_builder=NullBiMapper(),
        attention=BiAttention(TriLinear(bias=True), True),
        match_encoder=NullMapper(),
        predictor=BoundsPredictor(
            ChainConcat(start_layer=SequenceMapperSeq(recurrent_layer,
                                                      recurrent_layer),
                        end_layer=recurrent_layer)),
    )

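    # Save this script's own source as the notes stored alongside the model.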
    with open(__file__, "r") as f:
        notes = f.read()

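    # Evaluators: training loss plus SQuAD-style span scoring (spans up to length 17).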
    eval = [LossEvaluator(), SpanEvaluator(bound=[17], text_eval="squad")]

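    # Batch size 60; examples of similar context length are batched together to reduce padding.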
    corpus = SquadCorpus()
    train_batching = ClusteredBatcher(60, ContextLenBucketedKey(3), True,
                                      False)
    eval_batching = ClusteredBatcher(60, ContextLenKey(), False, False)
    data = DocumentQaTrainingData(corpus, None, train_batching, eval_batching)

    trainer.start_training(data, model, train_params, eval,
                           model_dir.ModelDir(out), notes)


if __name__ == "__main__":
    main()
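The docstring above notes that the Exponential Moving Average (EMA) of the weights is important. The snippet below is a minimal, self-contained sketch of how a weight EMA with decay 0.999 is typically maintained in TensorFlow 1.x; it only illustrates the idea behind TrainParams(ema=0.999) and is not the library's actual trainer code.

import tensorflow as tf

# Toy parameters and loss standing in for the real model.
w = tf.Variable(tf.zeros([10]), name="w")
loss = tf.reduce_sum(tf.square(w - 1.0))
train_op = tf.train.AdamOptimizer(0.001).minimize(loss)

# Shadow (averaged) copies of the weights, using the same decay as ema=0.999 above.
ema = tf.train.ExponentialMovingAverage(decay=0.999)
with tf.control_dependencies([train_op]):
    # Refresh the averaged weights after every optimizer step.
    train_op = ema.apply(tf.trainable_variables())

# At evaluation time, restore the averaged weights in place of the raw ones.
eval_saver = tf.train.Saver(ema.variables_to_restore())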