コード例 #1
0
ファイル: model_builder.py プロジェクト: zoudajia/rencos
def build_decoder(opt, embeddings):
    """
    Dispatch on ``opt.decoder_type`` and construct the matching decoder.

    "transformer" yields a TransformerDecoder and "cnn" a CNNDecoder. Any
    other value yields an RNN decoder: InputFeedRNNDecoder when
    ``opt.input_feed`` is truthy, StdRNNDecoder otherwise.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The newly constructed decoder.
    """
    kind = opt.decoder_type
    if kind == "transformer":
        return TransformerDecoder(
            opt.dec_layers, opt.dec_rnn_size, opt.heads,
            opt.transformer_ff, opt.global_attention, opt.copy_attn,
            opt.self_attn_type, opt.dropout, embeddings)
    if kind == "cnn":
        return CNNDecoder(
            opt.dec_layers, opt.dec_rnn_size, opt.global_attention,
            opt.copy_attn, opt.cnn_kernel_width, opt.dropout, embeddings)

    # Both RNN decoder classes are called with the same positional
    # arguments; only the class itself depends on opt.input_feed.
    rnn_decoder_cls = InputFeedRNNDecoder if opt.input_feed else StdRNNDecoder
    return rnn_decoder_cls(
        opt.rnn_type, opt.brnn, opt.dec_layers, opt.dec_rnn_size,
        opt.global_attention, opt.global_attention_function,
        opt.coverage_attn, opt.context_gate, opt.copy_attn,
        opt.dropout, embeddings, opt.reuse_copy_attn)
コード例 #2
0
def build_decoder(opt, embeddings):
    """
    Various decoder dispatcher function.

    Selects the decoder from ``opt.decoder_type``: "transformer" or "cnn".
    Any other value builds an RNN decoder. When ``opt.input_feed`` is truthy,
    ``opt.key_model`` picks between InputFeedRNNDecoder ("key_generator")
    and MyInputFeedRNNDecoder ("key_end2end"). Otherwise StdRNNDecoder is
    used.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The constructed decoder instance.

    Raises:
        ValueError: if ``opt.input_feed`` is set and ``opt.key_model`` is
            neither "key_generator" nor "key_end2end".
    """
    if opt.decoder_type == "transformer":
        return TransformerDecoder(opt.dec_layers, opt.rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.global_attention, opt.copy_attn,
                                  opt.self_attn_type,
                                  opt.dropout, embeddings)
    elif opt.decoder_type == "cnn":
        return CNNDecoder(opt.dec_layers, opt.rnn_size,
                          opt.global_attention, opt.copy_attn,
                          opt.cnn_kernel_width, opt.dropout,
                          embeddings)
    elif opt.input_feed:
        # Explicit check rather than ``assert``: under ``python -O`` an
        # assert is stripped and an unknown key_model would silently fall
        # through to MyInputFeedRNNDecoder below.
        if opt.key_model not in ("key_generator", "key_end2end"):
            raise ValueError(
                "Unsupported key_model %r; expected 'key_generator' or "
                "'key_end2end'" % (opt.key_model,))
        # Leading positional arguments shared by both input-feeding decoders.
        common_args = (opt.rnn_type, opt.brnn,
                       opt.dec_layers, opt.rnn_size,
                       opt.global_attention,
                       opt.coverage_attn,
                       opt.context_gate,
                       opt.copy_attn,
                       opt.dropout,
                       embeddings,
                       opt.reuse_copy_attn)
        if opt.key_model == "key_generator":
            return InputFeedRNNDecoder(
                *common_args,
                no_sftmax_bf_rescale=opt.no_sftmax_bf_rescale,
                use_retrieved_keys=opt.use_retrieved_keys)
        return MyInputFeedRNNDecoder(
            *common_args,
            not_use_sel_probs=opt.not_use_sel_probs,
            no_sftmax_bf_rescale=opt.no_sftmax_bf_rescale,
            use_retrieved_keys=opt.use_retrieved_keys,
            only_rescale_copy=opt.only_rescale_copy)
    else:
        return StdRNNDecoder(opt.rnn_type, opt.brnn,
                             opt.dec_layers, opt.rnn_size,
                             opt.global_attention,
                             opt.coverage_attn,
                             opt.context_gate,
                             opt.copy_attn,
                             opt.dropout,
                             embeddings,
                             opt.reuse_copy_attn)
コード例 #3
0
ファイル: model_builder.py プロジェクト: yyht/ExHiRD-DKG
def _require_input_feed(opt):
    """Raise ``ValueError`` unless ``opt.input_feed`` is truthy.

    Used instead of ``assert opt.input_feed`` so the check is not stripped
    when Python runs with ``-O``.
    """
    if not opt.input_feed:
        raise ValueError("decoder_type %r requires opt.input_feed"
                         % (opt.decoder_type,))


def _rnn_decoder_common_args(opt, embeddings):
    """Return the leading positional arguments shared by the RNN decoders.

    The second element (``bi_enc``) is True when ``'brnn'`` occurs in
    ``opt.encoder_type``.
    """
    bi_enc = 'brnn' in opt.encoder_type
    return (opt.rnn_type, bi_enc, opt.dec_layers, opt.dec_rnn_size,
            opt.global_attention, opt.global_attention_function,
            opt.coverage_attn, opt.context_gate, opt.copy_attn,
            opt.dropout, embeddings, opt.reuse_copy_attn)


def build_decoder(opt,
                  embeddings,
                  eok_idx=None,
                  eos_idx=None,
                  pad_idx=None,
                  sep_idx=None,
                  p_end_idx=None,
                  a_end_idx=None,
                  position_enc=None,
                  position_enc_embsize=None):
    """
    Various decoder dispatcher function.

    Builds the decoder selected by ``opt.decoder_type``: "transformer",
    "cnn", "hre_rnn", "seq_hre_rnn", "hrd_rnn" / "seq_hre_hrd_rnn",
    "CatSeqD_rnn" or "CatSeqCorr_rnn". Any other value builds an
    ``InputFeedRNNDecoder``.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.
        eok_idx, eos_idx, pad_idx, sep_idx, p_end_idx, a_end_idx: forwarded
            unchanged to the HRD decoder (presumably vocabulary indices of
            special tokens). ``sep_idx`` is also forwarded to the CatSeqD
            decoder. ``eos_idx`` and ``pad_idx`` are required for the HRD
            decoder types.
        position_enc, position_enc_embsize: forwarded unchanged to the HRD
            decoder.

    Returns:
        The constructed decoder instance.

    Raises:
        ValueError: if the selected decoder type requires ``opt.input_feed``
            and it is not set, or if an HRD decoder type is requested
            without ``eos_idx`` / ``pad_idx``.
    """
    decoder_type = opt.decoder_type
    if decoder_type == "transformer":
        decoder = TransformerDecoder(opt.dec_layers, opt.dec_rnn_size,
                                     opt.heads, opt.transformer_ff,
                                     opt.global_attention, opt.copy_attn,
                                     opt.self_attn_type, opt.dropout,
                                     embeddings)
    elif decoder_type == "cnn":
        decoder = CNNDecoder(opt.dec_layers, opt.dec_rnn_size,
                             opt.global_attention, opt.copy_attn,
                             opt.cnn_kernel_width, opt.dropout, embeddings)
    elif decoder_type == "hre_rnn":
        _require_input_feed(opt)
        # hr_attn_type is passed positionally, right after the common args.
        decoder = HREInputFeedRNNDecoder(
            *(_rnn_decoder_common_args(opt, embeddings) +
              (opt.hr_attn_type,)))
    elif decoder_type == "seq_hre_rnn":
        _require_input_feed(opt)
        decoder = SeqHREInputFeedRNNDecoder(
            *_rnn_decoder_common_args(opt, embeddings),
            hr_attn_type=opt.hr_attn_type,
            seqHRE_attn_rescale=opt.seqHRE_attn_rescale)
    elif decoder_type in ("hrd_rnn", "seq_hre_hrd_rnn"):
        _require_input_feed(opt)
        # Only eos_idx and pad_idx are checked here; eok_idx may be None.
        if eos_idx is None or pad_idx is None:
            raise ValueError("decoder_type %r requires eos_idx and pad_idx"
                             % (decoder_type,))
        hr_enc = 'hr' in opt.encoder_type
        seqhr_enc = opt.encoder_type == "seq_hr_brnn"
        decoder = HRDInputFeedRNNDecoder(
            *_rnn_decoder_common_args(opt, embeddings),
            hr_attn_type=opt.hr_attn_type,
            word_dec_init_type=opt.word_dec_init_type,
            remove_input_feed_w=opt.remove_input_feed_w,
            input_feed_w_type=opt.sent_dec_input_feed_w_type,
            hr_enc=hr_enc,
            seqhr_enc=seqhr_enc,
            seqE_HRD_rescale_attn=opt.seqE_HRD_rescale_attn,
            seqHRE_attn_rescale=opt.seqHRE_attn_rescale,
            use_zero_s_emb=opt.use_zero_s_emb,
            not_detach_coverage=opt.not_detach_coverage,
            eok_idx=eok_idx,
            eos_idx=eos_idx,
            pad_idx=pad_idx,
            sep_idx=sep_idx,
            p_end_idx=p_end_idx,
            a_end_idx=a_end_idx,
            position_enc=position_enc,
            position_enc_word_init=opt.use_position_enc_word_init_state,
            position_enc_sent_feed_w=opt.use_position_enc_sent_input_feed_w,
            position_enc_first_word_feed=opt.use_position_enc_first_word_feed,
            position_enc_embsize=position_enc_embsize,
            # NOTE(review): "opsition" looks like a typo, but the option is
            # defined outside this file; rename both together or not at all.
            position_enc_start_token=opt.use_opsition_enc_start_token,
            position_enc_sent_state=opt.use_position_enc_sent_state,
            position_enc_all_first_valid_word_dec_inputs=(
                opt.use_position_enc_first_valid_word_dec_inputs),
            sent_dec_init_type=opt.sent_dec_init_type,
            remove_input_feed_h=opt.remove_input_feed_h,
            detach_input_feed_w=opt.detach_input_feed_w,
            use_target_encoder=opt.use_target_encoder,
            src_states_capacity=opt.src_states_capacity,
            src_states_sample_size=opt.src_states_sample_size)
    elif decoder_type == "CatSeqD_rnn":
        # NOTE(review): unlike the other RNN branches, this one does not
        # require opt.input_feed; confirm whether that is intentional.
        decoder = CatSeqDInputFeedRNNDecoder(
            *_rnn_decoder_common_args(opt, embeddings),
            sep_idx=sep_idx,
            use_target_encoder=opt.use_target_encoder,
            target_hidden_size=opt.target_hidden_size,
            src_states_capacity=opt.src_states_capacity,
            src_states_sample_size=opt.src_states_sample_size,
            use_catSeq_dp=opt.use_catSeq_dp)
    elif decoder_type == "CatSeqCorr_rnn":
        _require_input_feed(opt)
        decoder = CatSeqCorrInputFeedRNNDecoder(
            *_rnn_decoder_common_args(opt, embeddings))
    else:
        # Input feeding is required here, so the decoder is always the
        # input-feeding RNN decoder (StdRNNDecoder is never selected).
        _require_input_feed(opt)
        decoder = InputFeedRNNDecoder(
            *_rnn_decoder_common_args(opt, embeddings),
            use_catSeq_dp=opt.use_catSeq_dp)

    return decoder
コード例 #4
0
def build_decoder(opt, embeddings):
    """
    Various decoder dispatcher function.

    Logs the chosen configuration at INFO level through ``logger``. It then
    builds the decoder selected by ``opt.decoder_type`` ("transformer" or
    "cnn"). Any other value builds an RNN decoder: InputFeedRNNDecoder when
    ``opt.input_feed`` is truthy, StdRNNDecoder otherwise.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The constructed decoder instance.
    """
    # Log calls pass their arguments to the logger instead of pre-formatting
    # with ``%``. The message is then only built when INFO is enabled, and a
    # formatting problem is reported by logging instead of raised here.
    if opt.decoder_type == "transformer":
        logger.info('TransformerDecoder: layers %d, input dim %d, '
                    'fat relu hidden dim %d, num heads %d, %s global attn, '
                    'copy attn %d, self attn type %s, dropout %.2f',
                    opt.dec_layers, opt.dec_rnn_size, opt.transformer_ff,
                    opt.heads, opt.global_attention, opt.copy_attn,
                    opt.self_attn_type, opt.dropout)

        # dec_rnn_size   = dimension of keys/values/queries (input to FF)
        # transformer_ff = dimension of fat relu
        return TransformerDecoder(opt.dec_layers, opt.dec_rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.global_attention, opt.copy_attn,
                                  opt.self_attn_type,
                                  opt.dropout, embeddings)

    if opt.decoder_type == "cnn":
        logger.info('CNNDecoder: layers %d, hidden size %d, '
                    '%s global attn, copy attn %d, kernel width %d, '
                    'dropout %.2f',
                    opt.dec_layers, opt.dec_rnn_size, opt.global_attention,
                    opt.copy_attn, opt.cnn_kernel_width, opt.dropout)
        return CNNDecoder(opt.dec_layers, opt.dec_rnn_size,
                          opt.global_attention, opt.copy_attn,
                          opt.cnn_kernel_width, opt.dropout,
                          embeddings)

    # Both RNN decoders are logged and constructed with the same arguments;
    # only the class (and the name in the log line) differ.
    if opt.input_feed:
        rnn_cls, rnn_name = InputFeedRNNDecoder, 'InputFeedRNNDecoder'
    else:
        rnn_cls, rnn_name = StdRNNDecoder, 'StdRNNDecoder'

    logger.info('%s: type %s, bidir %d, layers %d, '
                'hidden size %d, %s global attn (%s), '
                'coverage attn %d, copy attn %d, dropout %.2f',
                rnn_name, opt.rnn_type, opt.brnn, opt.dec_layers,
                opt.dec_rnn_size, opt.global_attention,
                opt.global_attention_function, opt.coverage_attn,
                opt.copy_attn, opt.dropout)
    return rnn_cls(opt.rnn_type, opt.brnn,
                   opt.dec_layers, opt.dec_rnn_size,
                   opt.global_attention,
                   opt.global_attention_function,
                   opt.coverage_attn,
                   opt.context_gate,
                   opt.copy_attn,
                   opt.dropout,
                   embeddings,
                   opt.reuse_copy_attn)