def get_sentences_model(rnn_dim, use_elmo, keep_rate=0.8):
    """Assemble a ``SingleContextMaxSentenceModel``.

    The embedding mapper is a ``VariationalDropoutLayer`` followed by a
    ``CudnnGru``; when ELMo is requested that mapper is wrapped in an
    ``ElmoWrapper``.

    Args:
        rnn_dim: First argument of ``CudnnGru`` (presumably the hidden
            size -- confirm against ``CudnnGru``).
        use_elmo: When truthy, prints a notice, loads the model returned by
            ``get_wiki_elmo()`` and wraps the embedding mapper with it.
        keep_rate: Value handed to the ``VariationalDropoutLayer`` placed
            in front of the GRU.

    Returns:
        The configured ``SingleContextMaxSentenceModel``.
    """
    gru = CudnnGru(rnn_dim, w_init=TruncatedNormal(stddev=0.05))
    binary_answers = BinaryAnswerEncoder()

    mapper = SequenceMapperSeq(VariationalDropoutLayer(keep_rate), gru)

    if use_elmo:
        print("Using Elmo!")
        elmo = get_wiki_elmo()
        elmo_reduce = MapperSeq(
            ElmoLayer(0, layer_norm=False, top_layer_only=False),
            DropoutLayer(0.5),
        )
        mapper = ElmoWrapper(input_append=True,
                             output_append=False,
                             rnn_layer=mapper,
                             lm_reduce=elmo_reduce)
    else:
        elmo = None

    # Components are built in the same order the model receives them.
    input_encoder = QuestionsAndParagraphsEncoder(binary_answers,
                                                  use_sentence_segments=True,
                                                  paragraph_as_sentence=True)
    glove_embedder = FixedWordEmbedder(vec_name="glove.840B.300d",
                                       word_vec_init_scale=0,
                                       learn_unk=False,
                                       cpu=True)
    char_embedder = CharWordEmbedder(
        LearnedCharEmbedder(word_size_th=14,
                            char_th=50,
                            char_dim=20,
                            init_scale=0.05,
                            force_cpu=True),
        MaxPool(Conv1d(100, 5, 0.8)),
        shared_parameters=True,
    )
    pooled_encoder = MaxPool(map_layer=None,
                             min_val=VERY_NEGATIVE_NUMBER,
                             regular_reshape=True)

    return SingleContextMaxSentenceModel(
        encoder=input_encoder,
        word_embed=glove_embedder,
        char_embed=char_embedder,
        elmo_model=elmo,
        embed_mapper=mapper,
        sequence_encoder=pooled_encoder,
        sentences_encoder=SentenceMaxEncoder(),
        post_merger=None,
        merger=WithConcatOptions(sub=False, hadamard=True, dot=True, raw=True),
        max_batch_size=256,
    )
def get_context_with_bottleneck_to_question_model(rnn_dim: int, q2c: bool, res_rnn: bool, res_self_att: bool):
    """Assemble a ``SingleContextWithBottleneckToQuestionModel``.

    Args:
        rnn_dim: Size argument given to the embedding ``CudnnGru``, to the
            ``CudnnGruEncoder`` and to ``get_res_fc_seq_fc``.
        q2c: If true, a question-to-context attention layer is created;
            otherwise that slot is ``None``.
        res_rnn: Forwarded to ``get_res_fc_seq_fc`` as ``rnn``.
        res_self_att: Forwarded to ``get_res_fc_seq_fc`` as ``self_att``.

    Returns:
        The configured ``SingleContextWithBottleneckToQuestionModel``.
    """
    gru = CudnnGru(rnn_dim, w_init=TruncatedNormal(stddev=0.05))
    binary_answers = BinaryAnswerEncoder()

    # The same post-mapper object is given to both attention wrappers.
    post_mapper = get_res_fc_seq_fc(model_rnn_dim=rnn_dim, rnn=res_rnn, self_att=res_self_att)

    q2c_attention = None
    if q2c:
        q2c_attention = AttentionWithPostMapper(
            BiAttention(TriLinear(bias=True), True), post_mapper=post_mapper)
    c2q_attention = AttentionWithPostMapper(
        BiAttention(TriLinear(bias=True), True), post_mapper=post_mapper)

    input_encoder = QuestionsAndParagraphsEncoder(binary_answers)
    glove_embedder = FixedWordEmbedder(vec_name="glove.840B.300d",
                                       word_vec_init_scale=0,
                                       learn_unk=False,
                                       cpu=True)
    char_embedder = CharWordEmbedder(
        LearnedCharEmbedder(word_size_th=14,
                            char_th=50,
                            char_dim=20,
                            init_scale=0.05,
                            force_cpu=True),
        MaxPool(Conv1d(100, 5, 0.8)),
        shared_parameters=True,
    )
    # Dropout is applied both before and after the GRU.
    mapper = SequenceMapperSeq(
        VariationalDropoutLayer(0.8),
        gru,
        VariationalDropoutLayer(0.8),
    )
    # Disabled alternative encoder:
    #   MaxPool(map_layer=None, min_val=0, regular_reshape=True)
    gru_encoder = CudnnGruEncoder(rnn_dim, w_init=TruncatedNormal(stddev=0.05))

    return SingleContextWithBottleneckToQuestionModel(
        encoder=input_encoder,
        word_embed=glove_embedder,
        char_embed=char_embedder,
        embed_mapper=mapper,
        question_to_context_attention=q2c_attention,
        context_to_question_attention=c2q_attention,
        sequence_encoder=gru_encoder,
        rep_merge=ConcatLayer(),
        predictor=BinaryFixedPredictor(),
    )
def get_multi_encode_softmax_weighting_model(rnn_dim, multi_rnn_dim, num_encodings, keep_rate=0.8, map_embed=True):
    """Assemble a ``SingleContextMultipleEncodingWeightedSoftmaxModel``.

    Args:
        rnn_dim: Size argument of the ``CudnnGru`` used by the embedding
            mapper.
        multi_rnn_dim: Size argument of the ``CudnnGru`` used inside the
            ``MultiMapThenEncode`` mapper.
        num_encodings: Forwarded to ``MultiMapThenEncode``.
        keep_rate: Value given to every ``VariationalDropoutLayer`` created
            here.
        map_embed: When falsy, no embedding mapper is used (``None``).

    Returns:
        The configured ``SingleContextMultipleEncodingWeightedSoftmaxModel``.
    """
    # Both GRUs are created up front, even if ``map_embed`` is falsy.
    embed_gru = CudnnGru(rnn_dim, w_init=TruncatedNormal(stddev=0.05))
    multi_gru = CudnnGru(multi_rnn_dim, w_init=TruncatedNormal(stddev=0.05))
    binary_answers = BinaryAnswerEncoder()

    input_encoder = QuestionsAndParagraphsEncoder(binary_answers)
    glove_embedder = FixedWordEmbedder(vec_name="glove.840B.300d",
                                       word_vec_init_scale=0,
                                       learn_unk=False,
                                       cpu=True)
    char_embedder = CharWordEmbedder(
        LearnedCharEmbedder(word_size_th=14,
                            char_th=50,
                            char_dim=20,
                            init_scale=0.05,
                            force_cpu=True),
        MaxPool(Conv1d(100, 5, 0.8)),
        shared_parameters=True,
    )

    if map_embed:
        mapper = SequenceMapperSeq(VariationalDropoutLayer(keep_rate), embed_gru)
    else:
        mapper = None

    multi_encoder = MultiMapThenEncode(
        mapper=SequenceMapperSeq(VariationalDropoutLayer(keep_rate), multi_gru),
        encoder=MaxPool(map_layer=None, min_val=0, regular_reshape=True),
        num_encodings=num_encodings,
    )

    return SingleContextMultipleEncodingWeightedSoftmaxModel(
        encoder=input_encoder,
        word_embed=glove_embedder,
        char_embed=char_embedder,
        embed_mapper=mapper,
        sequence_multi_encoder=multi_encoder,
        weight_layer=MultiEncodingWeights(weight_mode='mlp'),
        merger=ConcatWithProduct(),
        post_merger=None,
        predictor=BinaryWeightedMultipleFixedPredictor(),
    )
def get_model(rnn_dim):
    """Assemble a ``ContextPairRelevanceModel`` with all attention disabled.

    Args:
        rnn_dim: First argument of the ``CudnnGru`` used in the embedding
            mapper.

    Returns:
        The configured ``ContextPairRelevanceModel``; its three attention
        slots are all ``None``.
    """
    gru = CudnnGru(rnn_dim, w_init=TruncatedNormal(stddev=0.05))
    binary_answers = BinaryAnswerEncoder()

    input_encoder = QuestionsAndParagraphsEncoder(binary_answers)
    glove_embedder = FixedWordEmbedder(
        vec_name="glove.840B.300d", word_vec_init_scale=0, learn_unk=False, cpu=True)
    char_embedder = CharWordEmbedder(
        LearnedCharEmbedder(word_size_th=14, char_th=50, char_dim=20, init_scale=0.05, force_cpu=True),
        MaxPool(Conv1d(100, 5, 0.8)),
        shared_parameters=True,
    )
    mapper = SequenceMapperSeq(
        VariationalDropoutLayer(0.8),
        gru,
        # FIXME: a trailing VariationalDropoutLayer(0.8) was tried here and
        # is disabled; it probably does not belong in this mapper.
    )
    pooled_encoder = MaxPool(map_layer=None, min_val=0, regular_reshape=True)

    return ContextPairRelevanceModel(
        encoder=input_encoder,
        word_embed=glove_embedder,
        char_embed=char_embedder,
        embed_mapper=mapper,
        question_to_context_attention=None,
        context_to_context_attention=None,
        context_to_question_attention=None,
        sequence_encoder=pooled_encoder,
        merger=MergeTwoContextsConcatQuestion(),
        predictor=BinaryFixedPredictor(),
    )
def get_multi_hop_model(rnn_dim, c2c: bool, q2c: bool, res_rnn: bool,
                        res_self_att: bool, post_merge: bool, encoder: str,
                        merge_type: str, num_c2c_hops: int):
    """Assemble a ``MultiHopContextsToQuestionModel``.

    Args:
        rnn_dim: Size argument given to the embedding ``CudnnGru``, to
            ``get_res_fc_seq_fc`` and (for ``encoder='rnn'``) to
            ``CudnnGruEncoder``.
        c2c: If true, a context-to-context attention layer is created;
            otherwise that slot is ``None``.
        q2c: If true, a question-to-context attention layer is created;
            otherwise that slot is ``None``.
        res_rnn: Forwarded to ``get_res_fc_seq_fc`` as ``rnn``.
        res_self_att: Forwarded to ``get_res_fc_seq_fc`` as ``self_att``.
        post_merge: If true, the residual model is also used as the
            attention merger's ``post_map_layer``.
        encoder: ``'max'`` selects a ``MaxPool`` sequence encoder, ``'rnn'``
            a ``CudnnGruEncoder``.
        merge_type: ``'max'`` selects ``MaxMerge``; any other value is
            passed to ``WeightedMerge`` as ``weight_type``.
        num_c2c_hops: Forwarded to the model as ``c2c_hops``.

    Returns:
        The configured ``MultiHopContextsToQuestionModel``.

    Raises:
        NotImplementedError: If ``encoder`` is neither ``'max'`` nor
            ``'rnn'``.
    """
    # Fail fast, before any layer is built, and say which value was rejected.
    if encoder not in ('max', 'rnn'):
        raise NotImplementedError(
            "Unsupported encoder %r; expected 'max' or 'rnn'" % (encoder,))

    recurrent_layer = CudnnGru(rnn_dim, w_init=TruncatedNormal(stddev=0.05))
    answer_encoder = BinaryAnswerEncoder()

    # The same residual model object is reused by every attention wrapper
    # below and, when ``post_merge`` is set, by the attention merger.
    res_model = get_res_fc_seq_fc(model_rnn_dim=rnn_dim,
                                  rnn=res_rnn,
                                  self_att=res_self_att)
    context_to_context = \
        AttentionWithPostMapper(BiAttention(TriLinear(bias=True), True), post_mapper=res_model) if c2c else None
    question_to_context = \
        AttentionWithPostMapper(BiAttention(TriLinear(bias=True), True), post_mapper=res_model) if q2c else None

    if encoder == 'max':
        sequence_encoder = MaxPool(map_layer=None,
                                   min_val=0,
                                   regular_reshape=True)
    else:  # encoder == 'rnn', guaranteed by the check at the top
        sequence_encoder = CudnnGruEncoder(rnn_dim,
                                           w_init=TruncatedNormal(stddev=0.05))

    post_map_layer = res_model if post_merge else None
    if merge_type == 'max':
        attention_merger = MaxMerge(
            pre_map_layer=None,
            post_map_layer=post_map_layer)
    else:
        attention_merger = WeightedMerge(
            pre_map_layer=None,
            post_map_layer=post_map_layer,
            weight_type=merge_type)

    return MultiHopContextsToQuestionModel(
        encoder=QuestionsAndParagraphsEncoder(answer_encoder),
        word_embed=FixedWordEmbedder(vec_name="glove.840B.300d",
                                     word_vec_init_scale=0,
                                     learn_unk=False,
                                     cpu=True),
        char_embed=CharWordEmbedder(LearnedCharEmbedder(word_size_th=14,
                                                        char_th=50,
                                                        char_dim=20,
                                                        init_scale=0.05,
                                                        force_cpu=True),
                                    MaxPool(Conv1d(100, 5, 0.8)),
                                    shared_parameters=True),
        embed_mapper=SequenceMapperSeq(
            VariationalDropoutLayer(0.8),
            recurrent_layer,
            VariationalDropoutLayer(0.8),
        ),
        question_to_context_attention=question_to_context,
        context_to_context_attention=context_to_context,
        c2c_hops=num_c2c_hops,
        context_to_question_attention=BiAttention(TriLinear(bias=True), True),
        attention_merger=attention_merger,
        sequence_encoder=sequence_encoder,
        predictor=BinaryFixedPredictor())
def get_reread_model(rnn_dim, use_elmo, encoder_keep_rate=0.8, reread_keep_rate=0.8,
                     two_phase_att=False, res_rnn=True, res_self_att=False,
                     multiply_iteration_probs=False, reformulate_by_context=False,
                     rank_first=False, rank_second=False, reread_rnn_dim=None,
                     first_rank_lambda=1.0, second_rank_lambda=1.0,
                     ranking_gamma=1.0):
    """Assemble an ``IterativeContextReReadModel``.

    Args:
        rnn_dim: Size argument of the embedding ``CudnnGru``; also used for
            the attention post-mapper.
        use_elmo: When truthy, prints a notice, loads ``get_hotpot_elmo()``
            and wraps the embedding mapper in an ``ElmoWrapper``.
        encoder_keep_rate: Value given to the ``VariationalDropoutLayer`` in
            front of the embedding GRU.
        reread_keep_rate: Forwarded to ``get_res_fc_seq_fc`` as
            ``keep_rate`` (only when a residual post-mapper is built).
        two_phase_att: When truthy, both attention directions are enabled.
        res_rnn: Forwarded to ``get_res_fc_seq_fc`` as ``rnn``.
        res_self_att: Forwarded to ``get_res_fc_seq_fc`` as ``self_att``.
            If both ``res_rnn`` and ``res_self_att`` are falsy, a single
            ``FullyConnected`` layer is used as the post-mapper instead.
        multiply_iteration_probs: Forwarded to the model unchanged.
        reformulate_by_context: Forwarded to the model; without
            ``two_phase_att`` it also selects which single attention
            direction is enabled (question-to-context when truthy,
            context-to-question otherwise).
        rank_first: First argument of the first ``BinaryNullPredictor``.
        rank_second: First argument of the second ``BinaryNullPredictor``.
        reread_rnn_dim: If not ``None``, a ``CudnnGru`` of this size is used
            as ``reread_mapper``; otherwise ``reread_mapper`` is ``None``.
        first_rank_lambda: ``ranking_lambda`` of the first predictor.
        second_rank_lambda: ``ranking_lambda`` of the second predictor.
        ranking_gamma: ``gamma`` of both predictors.

    Returns:
        The configured ``IterativeContextReReadModel``.
    """
    gru = CudnnGru(rnn_dim, w_init=TruncatedNormal(stddev=0.05))
    iterative_answers = IterativeAnswerEncoder(group=rank_first or rank_second)

    mapper = SequenceMapperSeq(VariationalDropoutLayer(encoder_keep_rate), gru)

    if not (res_rnn or res_self_att):
        post_mapper = FullyConnected(rnn_dim * 2, activation="relu")
    else:
        post_mapper = get_res_fc_seq_fc(model_rnn_dim=rnn_dim,
                                        rnn=res_rnn,
                                        self_att=res_self_att,
                                        keep_rate=reread_keep_rate)
    # A single attention object serves whichever direction(s) are enabled.
    shared_attention = AttentionWithPostMapper(
        BiAttention(TriLinear(bias=True), True), post_mapper=post_mapper)
    enable_c2q = two_phase_att or not reformulate_by_context
    enable_q2c = two_phase_att or reformulate_by_context

    if use_elmo:
        print("Using Elmo!")
        elmo = get_hotpot_elmo()
        elmo_reduce = MapperSeq(
            ElmoLayer(0, layer_norm=False, top_layer_only=False),
            DropoutLayer(0.5),
        )
        mapper = ElmoWrapper(input_append=True,
                             output_append=False,
                             rnn_layer=mapper,
                             lm_reduce=elmo_reduce)
    else:
        elmo = None

    input_encoder = QuestionsAndParagraphsEncoder(iterative_answers,
                                                  use_sentence_segments=True,
                                                  paragraph_as_sentence=True)
    glove_embedder = FixedWordEmbedder(vec_name="glove.840B.300d",
                                       word_vec_init_scale=0,
                                       learn_unk=False,
                                       cpu=True)
    char_embedder = CharWordEmbedder(
        LearnedCharEmbedder(word_size_th=14,
                            char_th=50,
                            char_dim=20,
                            init_scale=0.05,
                            force_cpu=True),
        MaxPool(Conv1d(100, 5, 0.8)),
        shared_parameters=True,
    )
    pooled_encoder = MaxPool(map_layer=None,
                             min_val=VERY_NEGATIVE_NUMBER,
                             regular_reshape=True)
    sentence_encoder = SentenceMaxEncoder()
    concat_merger = WithConcatOptions(sub=False, hadamard=True, dot=True, raw=True)

    if reread_rnn_dim is None:
        reread_gru = None
    else:
        reread_gru = CudnnGru(reread_rnn_dim, w_init=TruncatedNormal(stddev=0.05))

    return IterativeContextReReadModel(
        encoder=input_encoder,
        word_embed=glove_embedder,
        char_embed=char_embedder,
        elmo_model=elmo,
        embed_mapper=mapper,
        sequence_encoder=pooled_encoder,
        sentences_encoder=sentence_encoder,
        sentence_mapper=None,
        post_merger=None,
        merger=concat_merger,
        reread_mapper=reread_gru,
        # Disabled alternative: VariationalDropoutLayer(reread_keep_rate)
        pre_attention_mapper=None,
        context_to_question_attention=shared_attention if enable_c2q else None,
        question_to_context_attention=shared_attention if enable_q2c else None,
        reformulate_by_context=reformulate_by_context,
        multiply_iteration_probs=multiply_iteration_probs,
        first_predictor=BinaryNullPredictor(rank_first,
                                            ranking_lambda=first_rank_lambda,
                                            gamma=ranking_gamma),
        second_predictor=BinaryNullPredictor(rank_second,
                                             ranking_lambda=second_rank_lambda,
                                             gamma=ranking_gamma),
        max_batch_size=512,
    )
# Example #7
# 0
def get_model_with_yes_no(rnn_dim: int, use_elmo, keep_rate=0.8):
    """Assemble an ``AttentionQAWithYesNo`` model.

    One ``CudnnGru`` object is reused by the embedding mapper, by the
    residual block of the match encoder and by both stages of the
    predictor's ``ChainBiMapper``.

    Args:
        rnn_dim: Size argument of the shared ``CudnnGru``; ``rnn_dim * 2``
            is used for the ``FullyConnected`` layers.
        use_elmo: When truthy, prints a notice, loads ``get_hotpot_elmo()``
            and wraps the embedding mapper in an ``ElmoWrapper``.
        keep_rate: Value given to every ``VariationalDropoutLayer`` created
            here.

    Returns:
        The configured ``AttentionQAWithYesNo`` model.
    """
    gru = CudnnGru(rnn_dim, w_init=TruncatedNormal(stddev=0.05))

    mapper = SequenceMapperSeq(
        VariationalDropoutLayer(keep_rate),
        gru,
        VariationalDropoutLayer(keep_rate),
    )

    if use_elmo:
        print("Using Elmo!")
        elmo = get_hotpot_elmo()
        elmo_reduce = MapperSeq(
            ElmoLayer(0, layer_norm=False, top_layer_only=False),
            DropoutLayer(0.5),
        )
        mapper = ElmoWrapper(input_append=True,
                             output_append=False,
                             rnn_layer=mapper,
                             lm_reduce=elmo_reduce)
    else:
        elmo = None

    span_answers = GroupedSpanAnswerEncoderWithYesNo(group=True)
    bounds_predictor = BoundsPredictor(
        ChainBiMapper(first_layer=gru, second_layer=gru),
        span_predictor=IndependentBoundsGroupedWithYesNo(),
    )

    input_encoder = QuestionsAndParagraphsEncoder(span_answers, use_sentence_segments=False)
    glove_embedder = FixedWordEmbedder(vec_name="glove.840B.300d",
                                       word_vec_init_scale=0,
                                       learn_unk=False,
                                       cpu=True)
    char_embedder = CharWordEmbedder(
        LearnedCharEmbedder(word_size_th=14,
                            char_th=50,
                            char_dim=20,
                            init_scale=0.05,
                            force_cpu=True),
        MaxPool(Conv1d(100, 5, 0.8)),
        shared_parameters=True,
    )
    memory = NullBiMapper()
    bi_attention = BiAttention(TriLinear(bias=True), True)

    # Match encoder: projection, then a residual block (dropout, GRU,
    # dropout, self-attention, projection), then a final dropout.
    projection = FullyConnected(rnn_dim * 2, activation="relu")
    residual_block = ResidualLayer(SequenceMapperSeq(
        VariationalDropoutLayer(keep_rate),
        gru,
        VariationalDropoutLayer(keep_rate),
        StaticAttentionSelf(TriLinear(bias=True), ConcatWithProduct()),
        FullyConnected(rnn_dim * 2, activation="relu"),
    ))
    matcher = SequenceMapperSeq(projection,
                                residual_block,
                                VariationalDropoutLayer(keep_rate))

    return AttentionQAWithYesNo(
        encoder=input_encoder,
        word_embed=glove_embedder,
        char_embed=char_embedder,
        elmo_model=elmo,
        embed_mapper=mapper,
        question_mapper=None,
        context_mapper=None,
        memory_builder=memory,
        attention=bi_attention,
        match_encoder=matcher,
        predictor=bounds_predictor,
        yes_no_question_encoder=MaxPool(map_layer=None, min_val=0, regular_reshape=True),
        yes_no_context_encoder=MaxPool(map_layer=None, min_val=0, regular_reshape=True),
    )
def get_reread_merge_model(rnn_dim, use_elmo, keep_rate=0.8, res_rnn=True, res_self_att=False,
                           multiply_iteration_probs=False):
    """Assemble an ``IterativeContextReReadMergeModel``.

    Args:
        rnn_dim: Size argument of the embedding ``CudnnGru``; also used for
            the attention post-mapper.
        use_elmo: When truthy, prints a notice, loads ``get_hotpot_elmo()``
            and wraps the embedding mapper in an ``ElmoWrapper``.
        keep_rate: Value given to the ``VariationalDropoutLayer`` in front
            of the embedding GRU.
        res_rnn: Forwarded to ``get_res_fc_seq_fc`` as ``rnn``.
        res_self_att: Forwarded to ``get_res_fc_seq_fc`` as ``self_att``.
            If both flags are falsy, a single ``FullyConnected`` layer is
            used as the post-mapper instead.
        multiply_iteration_probs: Forwarded to the model unchanged.

    Returns:
        The configured ``IterativeContextReReadMergeModel``.
    """
    gru = CudnnGru(rnn_dim, w_init=TruncatedNormal(stddev=0.05))
    iterative_answers = IterativeAnswerEncoder()

    mapper = SequenceMapperSeq(VariationalDropoutLayer(keep_rate), gru)

    if not (res_rnn or res_self_att):
        post_mapper = FullyConnected(rnn_dim * 2, activation="relu")
    else:
        post_mapper = get_res_fc_seq_fc(model_rnn_dim=rnn_dim, rnn=res_rnn, self_att=res_self_att)
    # The same attention object is used for both attention directions.
    shared_attention = AttentionWithPostMapper(
        BiAttention(TriLinear(bias=True), True), post_mapper=post_mapper)

    if use_elmo:
        print("Using Elmo!")
        elmo = get_hotpot_elmo()
        elmo_reduce = MapperSeq(
            ElmoLayer(0, layer_norm=False, top_layer_only=False),
            DropoutLayer(0.5),
        )
        mapper = ElmoWrapper(input_append=True,
                             output_append=False,
                             rnn_layer=mapper,
                             lm_reduce=elmo_reduce)
    else:
        elmo = None

    input_encoder = QuestionsAndParagraphsEncoder(iterative_answers, use_sentence_segments=True)
    glove_embedder = FixedWordEmbedder(vec_name="glove.840B.300d",
                                       word_vec_init_scale=0,
                                       learn_unk=False,
                                       cpu=True)
    char_embedder = CharWordEmbedder(
        LearnedCharEmbedder(word_size_th=14,
                            char_th=50,
                            char_dim=20,
                            init_scale=0.05,
                            force_cpu=True),
        MaxPool(Conv1d(100, 5, 0.8)),
        shared_parameters=True,
    )
    pooled_encoder = MaxPool(map_layer=None, min_val=0, regular_reshape=True)
    sentence_encoder = SentenceMaxEncoder()
    concat_merger = WithConcatOptions(sub=False, hadamard=True, dot=False, raw=True)

    return IterativeContextReReadMergeModel(
        encoder=input_encoder,
        word_embed=glove_embedder,
        char_embed=char_embedder,
        elmo_model=elmo,
        embed_mapper=mapper,
        sequence_encoder=pooled_encoder,
        sentences_encoder=sentence_encoder,
        sentence_mapper=None,
        post_merger=None,
        merger=concat_merger,
        context_to_question_attention=shared_attention,
        question_to_context_attention=shared_attention,
        reread_merger=ConcatWithProduct(),
        multiply_iteration_probs=multiply_iteration_probs,
        max_batch_size=128,
    )
def get_model(rnn_dim, use_elmo, keep_rate=0.8):
    """Assemble an ``IterativeContextMaxSentenceModel``.

    Args:
        rnn_dim: Size argument of the ``CudnnGru`` (used in the embedding
            mapper and as ``sentence_mapper``) and of the reformulation
            ``CudnnGruEncoder``.
        use_elmo: When truthy, prints a notice, loads ``get_hotpot_elmo()``
            and wraps the embedding mapper in an ``ElmoWrapper``.
        keep_rate: Value given to the ``VariationalDropoutLayer`` in front
            of the embedding GRU.

    Returns:
        The configured ``IterativeContextMaxSentenceModel``.
    """
    gru = CudnnGru(rnn_dim, w_init=TruncatedNormal(stddev=0.05))
    iterative_answers = IterativeAnswerEncoder()

    mapper = SequenceMapperSeq(VariationalDropoutLayer(keep_rate), gru)

    if use_elmo:
        print("Using Elmo!")
        elmo = get_hotpot_elmo()
        elmo_reduce = MapperSeq(
            ElmoLayer(0, layer_norm=False, top_layer_only=False),
            DropoutLayer(0.5),
        )
        mapper = ElmoWrapper(input_append=True,
                             output_append=False,
                             rnn_layer=mapper,
                             lm_reduce=elmo_reduce)
    else:
        elmo = None

    # Disabled alternatives for the reformulation layer:
    #   WeightedSumThenProjectReformulation(rnn_dim*2, activation='relu')
    #   ProjectThenWeightedSumReformulation(rnn_dim*2, activation='relu')
    reformulator = ProjectMapEncodeReformulation(
        project_layer=None,
        sequence_mapper=None,
        encoder=CudnnGruEncoder(rnn_dim, w_init=TruncatedNormal(stddev=0.05)),
    )

    input_encoder = QuestionsAndParagraphsEncoder(iterative_answers, use_sentence_segments=True)
    glove_embedder = FixedWordEmbedder(vec_name="glove.840B.300d",
                                       word_vec_init_scale=0,
                                       learn_unk=False,
                                       cpu=True)
    char_embedder = CharWordEmbedder(
        LearnedCharEmbedder(word_size_th=14,
                            char_th=50,
                            char_dim=20,
                            init_scale=0.05,
                            force_cpu=True),
        MaxPool(Conv1d(100, 5, 0.8)),
        shared_parameters=True,
    )
    pooled_encoder = MaxPool(map_layer=None, min_val=0, regular_reshape=True)

    return IterativeContextMaxSentenceModel(
        encoder=input_encoder,
        word_embed=glove_embedder,
        char_embed=char_embedder,
        elmo_model=elmo,
        embed_mapper=mapper,
        sequence_encoder=pooled_encoder,
        sentences_encoder=SentenceMaxEncoder(),
        # The embedding GRU object doubles as the sentence mapper.
        sentence_mapper=gru,
        post_merger=None,
        merger=WithConcatOptions(sub=False, hadamard=True, dot=True, raw=True),
        reformulation_layer=reformulator,
        max_batch_size=128,
    )
def get_basic_model(rnn_dim, post_merger_params: Optional[dict] = None, use_elmo=False, keep_rate=0.8):
    """Assemble a ``BasicSingleContextAndQuestionIndependentModel``.

    Args:
        rnn_dim: Size argument of the ``CudnnGru`` in the embedding mapper.
        post_merger_params: Keyword arguments for ``get_mlp``; when
            ``None`` the model gets no post-merger.
        use_elmo: When truthy, prints a notice, loads ``get_squad_elmo()``
            and wraps the embedding mapper in an ``ElmoWrapper`` (with
            ``output_append=True``).
        keep_rate: Value given to the ``VariationalDropoutLayer`` in front
            of the embedding GRU.

    Returns:
        The configured ``BasicSingleContextAndQuestionIndependentModel``.
    """
    gru = CudnnGru(rnn_dim, w_init=TruncatedNormal(stddev=0.05))
    binary_answers = BinaryAnswerEncoder()

    # Disabled alternative: the same dropout+GRU stage followed by two
    # ResidualLayer(SequenceMapperSeq(VariationalDropoutLayer(keep_rate), gru))
    # stages, all wrapped in one SequenceMapperSeq.
    mapper = SequenceMapperSeq(VariationalDropoutLayer(keep_rate), gru)

    if use_elmo:
        print("Using Elmo!")
        elmo = get_squad_elmo()
        elmo_reduce = MapperSeq(
            ElmoLayer(0, layer_norm=False, top_layer_only=False),
            DropoutLayer(0.5),
        )
        mapper = ElmoWrapper(input_append=True,
                             output_append=True,
                             rnn_layer=mapper,
                             lm_reduce=elmo_reduce)
    else:
        elmo = None

    if post_merger_params is None:
        mlp = None
    else:
        mlp = get_mlp(**post_merger_params)

    input_encoder = QuestionsAndParagraphsEncoder(binary_answers)
    glove_embedder = FixedWordEmbedder(vec_name="glove.840B.300d",
                                       word_vec_init_scale=0,
                                       learn_unk=False,
                                       cpu=True)
    char_embedder = CharWordEmbedder(
        LearnedCharEmbedder(word_size_th=14,
                            char_th=50,
                            char_dim=20,
                            init_scale=0.05,
                            force_cpu=True),
        MaxPool(Conv1d(100, 5, 0.8)),
        shared_parameters=True,
    )
    pooled_encoder = MaxPool(map_layer=None, min_val=0, regular_reshape=True)

    return BasicSingleContextAndQuestionIndependentModel(
        encoder=input_encoder,
        word_embed=glove_embedder,
        char_embed=char_embedder,
        elmo_model=elmo,
        embed_mapper=mapper,
        sequence_encoder=pooled_encoder,
        merger=ConcatWithProductSub(),
        post_merger=mlp,
        predictor=BinaryFixedPredictor(sigmoid=True),
        max_batch_size=128,
    )
def get_fixed_context_to_question(rnn_dim):
    """Assemble a ``SingleFixedContextToQuestionModel``.

    One ``CudnnGru`` object is reused in all three stages of the embedding
    mapper and in the post-merger's residual block.

    Args:
        rnn_dim: Size argument of the shared ``CudnnGru``; ``rnn_dim * 2``
            is used for the ``FullyConnected`` layers of the post-merger.

    Returns:
        The configured ``SingleFixedContextToQuestionModel``.
    """
    gru = CudnnGru(rnn_dim, w_init=TruncatedNormal(stddev=0.05))
    binary_answers = BinaryAnswerEncoder()

    input_encoder = QuestionsAndParagraphsEncoder(binary_answers)
    glove_embedder = FixedWordEmbedder(vec_name="glove.840B.300d",
                                       word_vec_init_scale=0,
                                       learn_unk=False,
                                       cpu=True)
    char_embedder = CharWordEmbedder(
        LearnedCharEmbedder(word_size_th=14,
                            char_th=50,
                            char_dim=20,
                            init_scale=0.05,
                            force_cpu=True),
        MaxPool(Conv1d(100, 5, 0.8)),
        shared_parameters=True,
    )

    # Embedding mapper: a plain dropout+GRU stage followed by two residual
    # dropout+GRU stages.
    first_stage = SequenceMapperSeq(VariationalDropoutLayer(0.8), gru)
    second_stage = ResidualLayer(SequenceMapperSeq(VariationalDropoutLayer(0.8), gru))
    third_stage = ResidualLayer(SequenceMapperSeq(VariationalDropoutLayer(0.8), gru))
    mapper = SequenceMapperSeq(first_stage, second_stage, third_stage)

    # Disabled alternative for both ``context_mapper`` and
    # ``question_mapper`` (currently ``None``):
    #   ResidualLayer(SequenceMapperSeq(
    #       VariationalDropoutLayer(0.8),
    #       StaticAttentionSelf(TriLinear(bias=True), ConcatWithProduct()),
    #       FullyConnected(rnn_dim*2, activation=None)))
    context_pool = MaxPool(map_layer=None, min_val=0, regular_reshape=True)
    concat_merger = WithConcatOptions(dot=True, sub=True, hadamard=True, raw=True, project=False)

    merge_projection = FullyConnected(rnn_dim * 2, activation='relu')
    merge_residual = ResidualLayer(SequenceMapperSeq(
        VariationalDropoutLayer(0.8),
        gru,
        FullyConnected(rnn_dim * 2, activation='relu'),
    ))
    merged_mapper = SequenceMapperSeq(merge_projection, merge_residual)

    return SingleFixedContextToQuestionModel(
        encoder=input_encoder,
        word_embed=glove_embedder,
        char_embed=char_embedder,
        embed_mapper=mapper,
        context_mapper=None,
        context_encoder=context_pool,
        question_mapper=None,
        merger=concat_merger,
        post_merger=merged_mapper,
        final_encoder=MaxPool(map_layer=None, min_val=0, regular_reshape=True),
        predictor=BinaryFixedPredictor(),
    )
def get_bottleneck_to_seq_model(rnn_dim, q2c: bool, res_rnn: bool, res_self_att: bool, seq_len=50):
    """Assemble a ``SingleContextBottleneckToSeqQuestionModel``.

    Args:
        rnn_dim: Size argument of the embedding ``CudnnGru``, of
            ``get_res_fc_seq_fc`` and ``num_units`` of the generator's
            ``LSTMCell``; ``rnn_dim * 2`` is used for the generator's
            ``FullyConnected`` layers.
        q2c: If true, a question-to-context attention layer is created;
            otherwise that slot is ``None``.
        res_rnn: Forwarded to ``get_res_fc_seq_fc`` as ``rnn``.
        res_self_att: Forwarded to ``get_res_fc_seq_fc`` as ``self_att``.
        seq_len: Forwarded to ``GenerativeRNN`` as ``seq_len``.

    Returns:
        The configured ``SingleContextBottleneckToSeqQuestionModel``.
    """
    gru = CudnnGru(rnn_dim, w_init=TruncatedNormal(stddev=0.05))
    binary_answers = BinaryAnswerEncoder()

    # The same post-mapper object is given to both attention wrappers.
    post_mapper = get_res_fc_seq_fc(model_rnn_dim=rnn_dim, rnn=res_rnn, self_att=res_self_att)

    q2c_attention = None
    if q2c:
        q2c_attention = AttentionWithPostMapper(
            BiAttention(TriLinear(bias=True), True), post_mapper=post_mapper)
    c2q_attention = AttentionWithPostMapper(
        BiAttention(TriLinear(bias=True), True), post_mapper=post_mapper)

    lstm_cell = tf.contrib.rnn.LSTMCell(
        num_units=rnn_dim,
        initializer=tf.initializers.truncated_normal(stddev=0.05),
    )
    generator = GenerativeRNN(
        lstm_cell,
        output_layer=FullyConnected(rnn_dim * 2, activation='relu'),
        vec_to_in=FullyConnected(rnn_dim * 2, activation='relu'),
        seq_len=seq_len,
        include_original_vec=False,
    )

    input_encoder = QuestionsAndParagraphsEncoder(binary_answers)
    glove_embedder = FixedWordEmbedder(vec_name="glove.840B.300d",
                                       word_vec_init_scale=0,
                                       learn_unk=False,
                                       cpu=True)
    char_embedder = CharWordEmbedder(
        LearnedCharEmbedder(word_size_th=14,
                            char_th=50,
                            char_dim=20,
                            init_scale=0.05,
                            force_cpu=True),
        MaxPool(Conv1d(100, 5, 0.8)),
        shared_parameters=True,
    )
    mapper = SequenceMapperSeq(VariationalDropoutLayer(0.8), gru)

    return SingleContextBottleneckToSeqQuestionModel(
        encoder=input_encoder,
        word_embed=glove_embedder,
        char_embed=char_embedder,
        embed_mapper=mapper,
        sequence_generator=generator,
        pre_attention=VariationalDropoutLayer(0.8),
        question_to_context_attention=q2c_attention,
        context_to_question_attention=c2q_attention,
        sequence_encoder=MaxPool(map_layer=None, min_val=0, regular_reshape=True),
        predictor=BinaryFixedPredictor(),
    )
def get_contexts_to_question_model(rnn_dim, post_merge):
    """Assemble a ``ContextsToQuestionModel``.

    The attention merger is a ``MaxMerge`` whose ``post_map_layer`` is a
    ``FullyConnected`` projection followed by a residual block. The
    contents of that block are selected by ``post_merge``:

    * ``'res_rnn_self_att'``: dropout, GRU, dropout, self-attention,
      fully-connected.
    * ``'res_rnn'``: dropout, GRU, fully-connected.
    * ``'res_self_att'``: dropout, self-attention, fully-connected.

    Args:
        rnn_dim: Size argument of the ``CudnnGru`` (shared between the
            embedding mapper and the residual block); ``rnn_dim * 2`` is
            used for the ``FullyConnected`` layers.
        post_merge: One of the three names listed above.

    Returns:
        The configured ``ContextsToQuestionModel``.

    Raises:
        NotImplementedError: If ``post_merge`` is not one of the supported
            names.
    """
    supported = ('res_rnn_self_att', 'res_rnn', 'res_self_att')
    # Fail fast, before any layer is built, and say which value was rejected.
    if post_merge not in supported:
        raise NotImplementedError(
            "Unsupported post_merge %r; expected one of %s"
            % (post_merge, ", ".join(repr(name) for name in supported)))

    with_rnn = post_merge in ('res_rnn_self_att', 'res_rnn')
    with_self_att = post_merge in ('res_rnn_self_att', 'res_self_att')

    recurrent_layer = CudnnGru(rnn_dim, w_init=TruncatedNormal(stddev=0.05))
    answer_encoder = BinaryAnswerEncoder()

    # Build the residual stack once instead of spelling out each variant.
    residual_layers = [VariationalDropoutLayer(0.8)]
    if with_rnn:
        residual_layers.append(recurrent_layer)
    if with_rnn and with_self_att:
        # Only the combined variant has a second dropout, between the GRU
        # and the self-attention layer.
        residual_layers.append(VariationalDropoutLayer(0.8))
    if with_self_att:
        residual_layers.append(
            StaticAttentionSelf(TriLinear(bias=True), ConcatWithProduct()))
    residual_layers.append(FullyConnected(rnn_dim * 2, activation="relu"))

    post_map_layer = SequenceMapperSeq(
        FullyConnected(rnn_dim * 2, activation="relu"),
        ResidualLayer(SequenceMapperSeq(*residual_layers)))

    return ContextsToQuestionModel(
        encoder=QuestionsAndParagraphsEncoder(answer_encoder),
        word_embed=FixedWordEmbedder(vec_name="glove.840B.300d",
                                     word_vec_init_scale=0,
                                     learn_unk=False,
                                     cpu=True),
        char_embed=CharWordEmbedder(LearnedCharEmbedder(word_size_th=14,
                                                        char_th=50,
                                                        char_dim=20,
                                                        init_scale=0.05,
                                                        force_cpu=True),
                                    MaxPool(Conv1d(100, 5, 0.8)),
                                    shared_parameters=True),
        embed_mapper=SequenceMapperSeq(
            VariationalDropoutLayer(0.8),
            recurrent_layer,
            VariationalDropoutLayer(0.8),
        ),
        attention_merger=MaxMerge(pre_map_layer=None,
                                  post_map_layer=post_map_layer),
        context_to_question_attention=BiAttention(TriLinear(bias=True), True),
        sequence_encoder=MaxPool(map_layer=None,
                                 min_val=0,
                                 regular_reshape=True),
        predictor=BinaryFixedPredictor())