Example #1
def build_encoder(opt, embeddings):
    """
    Dispatch among the available encoder types; returns a
    (word_encoder, sentence_encoder) pair, where the sentence encoder is
    None for all but the RNN types.
    Args:
        opt: the options for the current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
    """
    if opt.encoder_type == "transformer":
        return TransformerEncoder(opt.enc_layers, opt.rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, embeddings), None
    elif opt.encoder_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.rnn_size,
                          opt.cnn_kernel_width,
                          opt.dropout, embeddings), None
    elif opt.encoder_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings), None
    else:
        # "rnn" or "brnn"
        word_encoder = RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                                  opt.rnn_size, opt.dropout, embeddings,
                                  None, opt.bridge)
        # An LSTM state holds both hidden and cell tensors, so the flattened
        # encoder state fed to the sentence-level encoder is twice as large.
        if opt.rnn_type == "LSTM":
            emb_size = opt.enc_layers * opt.rnn_size * 2
        else:
            emb_size = opt.enc_layers * opt.rnn_size
        sen_encoder = RNNEncoder(opt.rnn_type, opt.brnn, opt.sen_enc_layers,
                                 opt.sen_rnn_size, opt.dropout, None,
                                 emb_size, opt.bridge)
        return word_encoder, sen_encoder
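
A minimal sketch, not from the source repository, of how this dispatcher might be driven; `Namespace` stands in for OpenNMT-py's parsed options object, and every field name below is assumed from the attributes read above.

from argparse import Namespace

# Hypothetical options object; the field names mirror the attributes that
# build_encoder reads.
opt = Namespace(encoder_type="brnn", rnn_type="LSTM", brnn=True,
                enc_layers=2, rnn_size=512, dropout=0.3, bridge=False,
                sen_enc_layers=1, sen_rnn_size=512)

# Given an Embeddings instance `embeddings`, the call returns a pair:
# word_encoder, sen_encoder = build_encoder(opt, embeddings)
# sen_encoder is None for every encoder type except "rnn"/"brnn".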
Example #2
    def __init__(self, rnn_type, bidirectional, num_layers,
                 hidden_size, dropout=0.0, embeddings=None,
                 use_bridge=False, n_memory_layers=3):
        super(ReDREncoder, self).__init__()
        self.reference_encoder = RNNEncoder(
            rnn_type, bidirectional, num_layers, hidden_size,
            dropout, embeddings, use_bridge)
        self.history_encoder = RNNEncoder(
            rnn_type, bidirectional, num_layers, hidden_size,
            dropout, embeddings, use_bridge)
        self.redr_layer = ReDRLayer(hidden_size, rnn_type,
                                    num_layers=num_layers,
                                    n_memory_layers=n_memory_layers)
Example #3
def build_encoder(opt, embeddings):
    """
    Dispatch among the available encoder types.
    Args:
        opt: the options for the current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
    """
    if opt.encoder_type == "transformer":
        return TransformerEncoder(opt.enc_layers,
                                  opt.enc_rnn_size,
                                  opt.heads,
                                  opt.transformer_ff,
                                  opt.dropout,
                                  embeddings,
                                  ablation=opt.ablation)
    elif opt.encoder_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.enc_rnn_size,
                          opt.cnn_kernel_width, opt.dropout, embeddings)
    elif opt.encoder_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings)
    else:
        # "rnn" or "brnn"
        return RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                          opt.enc_rnn_size, opt.dropout, embeddings,
                          opt.bridge)
Example #4
def build_encoder(opt, embeddings):
    """
    Dispatch among the available encoder types.
    Args:
        opt: the options for the current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
    """
    if opt.encoder_type == "transformer":
        return TransformerEncoder(opt.enc_layers, opt.enc_rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, embeddings)
    elif opt.encoder_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.enc_rnn_size,
                          opt.cnn_kernel_width,
                          opt.dropout, embeddings)
    elif opt.encoder_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings)
    else:
        # "rnn" or "brnn"
        logger.info('RNNEncoder: type %s, bidir %d, layers %d, '
                    'hidden size %d, dropout %.2f' %
                    (opt.rnn_type, opt.brnn, opt.enc_layers,
                     opt.enc_rnn_size, opt.dropout))
        return RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                          opt.enc_rnn_size, opt.dropout, embeddings,
                          opt.bridge)
Example #5
def build_encoder(opt, embeddings, embeddings_latt=False, feat_vec_size=512):
    """
    Dispatch among the available encoder types.
    Args:
        opt: the options for the current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
        embeddings_latt: sense embeddings when a lattice input is used.
        feat_vec_size: adaptable feature-vector size for the lattice case.
    """
    if opt.encoder_type == "transformer":
        # The lattice (sense) embeddings and feat_vec_size are forwarded
        # as-is; both branches of the embeddings_latt check passed identical
        # arguments, so a single call suffices.
        return TransformerEncoder(opt.enc_layers, opt.rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, embeddings,
                                  embeddings_latt, feat_vec_size)
    elif opt.encoder_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.rnn_size,
                          opt.cnn_kernel_width,
                          opt.dropout, embeddings)
    elif opt.encoder_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings)
    else:
        # "rnn" or "brnn"
        return RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                          opt.rnn_size, opt.dropout, embeddings,
                          opt.bridge)
Example #6
def build_encoder(opt, embeddings, main_encoder=None):
    """
    Dispatch among the available encoder types.
    Args:
        opt: the options for the current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
        main_encoder: an existing encoder to share parameters with in a
            multi-task setting; passed through to the encoder classes.
    """
    if opt.encoder_type == "transformer":
        encoder = TransformerEncoder(opt.enc_layers,
                                     opt.enc_rnn_size,
                                     opt.heads,
                                     opt.transformer_ff,
                                     opt.dropout,
                                     embeddings,
                                     main_encoder=main_encoder,
                                     mtl_opt=opt)
    elif opt.encoder_type == "cnn":
        encoder = CNNEncoder(opt.enc_layers, opt.enc_rnn_size,
                             opt.cnn_kernel_width, opt.dropout, embeddings)
    elif opt.encoder_type == "mean":
        encoder = MeanEncoder(opt.enc_layers, embeddings)
    else:
        encoder = RNNEncoder(opt.rnn_type,
                             opt.brnn,
                             opt.enc_layers,
                             opt.enc_rnn_size,
                             opt.dropout,
                             embeddings,
                             opt.bridge,
                             main_encoder=main_encoder,
                             mtl_opt=opt)
    return encoder
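
A minimal sketch, assumed rather than taken from the source, of how the main_encoder argument could share one encoder between a main and an auxiliary task:

# Hypothetical multi-task wiring; main_opt, aux_opt, and shared_embeddings
# are placeholder names, not identifiers from the original code.
main_encoder = build_encoder(main_opt, shared_embeddings)
aux_encoder = build_encoder(aux_opt, shared_embeddings,
                            main_encoder=main_encoder)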
Example #7
    def __init__(self,
                 rnn_type,
                 encoder_type,
                 passage_enc_layers,
                 qa_word_enc_layers,
                 qa_sent_enc_layers,
                 hidden_size,
                 dropout=0.0,
                 embeddings=None,
                 self_attn=1):
        super(MemoryEncoder, self).__init__()

        assert embeddings is not None
        self.embeddings = embeddings

        self.rnn_type = rnn_type

        bidirectional = encoder_type == 'brnn'

        # passage
        self.passage_encoder = RNNEncoder(rnn_type, bidirectional,
                                          passage_enc_layers, hidden_size,
                                          dropout, embeddings.embedding_size)
        # qa history
        # Single-layer RNNs get no dropout (PyTorch warns when dropout is
        # set with num_layers == 1).
        qa_word_dropout = dropout if qa_word_enc_layers > 1 else 0.0
        qa_sent_dropout = dropout if qa_sent_enc_layers > 1 else 0.0
        self.qa_word_encoder = RNNEncoder(
            rnn_type, bidirectional, qa_word_enc_layers, hidden_size,
            qa_word_dropout,
            embeddings.word_vec_size)  # here the qa history has feature
        # For utterance-level modeling, only a unidirectional RNN is used.
        sent_brnn = False
        self.qa_sent_encoder = RNNEncoder(rnn_type, sent_brnn,
                                          qa_sent_enc_layers, hidden_size,
                                          qa_sent_dropout, hidden_size)

        self.self_attn = self_attn
        if self.self_attn:
            # weight for self attention
            self.selfattn_ws = nn.Linear(hidden_size, hidden_size, bias=False)
            self.selfattn_wf = nn.Linear(hidden_size * 2, hidden_size)
            self.selfattn_wg = nn.Linear(hidden_size * 2, hidden_size)
            self.softmax = nn.Softmax(dim=-1)
Example #8
    def __init__(self, rnn_type, bidirectional, num_layers,
                 hidden_size, dropout=0.0, embeddings=None,
                 content_selection_attn_hidden=None):
        super(MacroPlanEncoder, self).__init__()
        assert embeddings is not None
        self.rnn_encoder = RNNEncoder(rnn_type, bidirectional, num_layers,
                                      hidden_size, dropout, embeddings)
        self.attn = GlobalAttentionContext(hidden_size, attn_type="general")
        self.content_selection_attn = GlobalSelfAttention(
            hidden_size, attn_type="general",
            attn_hidden=content_selection_attn_hidden)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
Example #9
def build_encoder(opt, embeddings):
    """
    Build an RNN encoder from the given options.
    Args:
        opt: the options for the current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
    """
    encoder = RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                         opt.enc_rnn_size, opt.dropout, embeddings, opt.bridge)
    return encoder
Example #10
def build_encoder(opt, embeddings):
    """
    Encoder dispatcher; only the "rnn"/"brnn" encoder types are supported.
    Args:
        opt: the options for the current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
    """
    if opt.encoder_type == "rnn" or opt.encoder_type == "brnn":
        # "rnn" or "brnn"
        logger.info("opt.encoder_type " + opt.encoder_type + " opt.brnn " +
                    str(opt.brnn))
        return RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers, opt.rnn_size,
                          opt.dropout, embeddings, opt.bridge)
    else:
        raise ModuleNotFoundError("Unsupported model type")
Example #11
    def __init__(self,
                 rnn_type,
                 encoder_type,
                 enc_layers,
                 hidden_size,
                 dropout=0.0,
                 embeddings=None):
        super(QueryEncoder, self).__init__()

        assert embeddings is not None
        self.embeddings = embeddings

        self.rnn_type = rnn_type

        bidirectional = encoder_type == 'brnn'

        self.src_encoder = RNNEncoder(rnn_type, bidirectional, enc_layers,
                                      hidden_size, dropout,
                                      embeddings.embedding_size)
        self.answer_encoder = MeanEncoder()
Example #12
def build_encoder(opt, embeddings):
    """
    Dispatch among the available encoder types.
    Args:
        opt: the options for the current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
    """
    if opt.encoder_type == "transformer":
        encoder = TransformerEncoder(opt.enc_layers, opt.enc_rnn_size,
                                     opt.heads, opt.transformer_ff,
                                     opt.dropout, embeddings)
    elif opt.encoder_type == "cnn":
        encoder = CNNEncoder(opt.enc_layers, opt.enc_rnn_size,
                             opt.cnn_kernel_width, opt.dropout, embeddings)
    elif opt.encoder_type == "mean":
        encoder = MeanEncoder(opt.enc_layers, embeddings)
    elif opt.encoder_type == "hr_brnn":
        bi_enc = True
        encoder = HREncoder(opt.rnn_type, bi_enc, opt.enc_layers,
                            opt.enc_rnn_size, opt.dropout, embeddings,
                            opt.bridge)
    elif opt.encoder_type == "seq_hr_brnn":
        bi_enc = True
        encoder = SeqHREncoder(opt.rnn_type, bi_enc, opt.enc_layers,
                               opt.enc_rnn_size, opt.dropout, embeddings,
                               opt.bridge)
    elif opt.encoder_type == "tg_brnn":
        bi_enc = True
        encoder = TGEncoder(opt.rnn_type, bi_enc, opt.enc_layers,
                            opt.enc_rnn_size, opt.dropout, embeddings)
    else:
        bi_enc = 'brnn' in opt.encoder_type
        encoder = RNNEncoder(opt.rnn_type,
                             bi_enc,
                             opt.enc_layers,
                             opt.enc_rnn_size,
                             opt.dropout,
                             embeddings,
                             opt.bridge,
                             use_catSeq_dp=opt.use_catSeq_dp)
    return encoder
Example #13
def build_encoder(opt, embeddings):
    """
    Dispatch among the available encoder types.
    Args:
        opt: the options for the current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
    """
    if opt.encoder_type == "transformer":
        return TransformerEncoder(opt.enc_layers, opt.enc_rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, embeddings)
    elif opt.encoder_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.enc_rnn_size,
                          opt.cnn_kernel_width,
                          opt.dropout, embeddings)
    elif opt.encoder_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings)
    elif opt.encoder_type == "rnntreelstm" or opt.encoder_type == "treelstm":
        opt.brnn = True if opt.encoder_type == "rnntreelstm" else False
        return TreeLSTMEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                        opt.rnn_size, opt.dropout, embeddings,
                        opt.bridge, False)
    elif opt.encoder_type == "rnnbitreelstm" or opt.encoder_type == "bitreelstm":
        opt.brnn = True if opt.encoder_type == "rnnbitreelstm" else False
        return TreeLSTMEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                        opt.rnn_size, opt.dropout, embeddings,
                        opt.bridge, True)    
    elif opt.encoder_type == "rnngcn" or opt.encoder_type == "gcn":
        opt.brnn = True if opt.encoder_type == "rnngcn" else False
        return GCNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                          opt.rnn_size, opt.dropout, embeddings,
                          opt.bridge, opt.gcn_dropout, 
                          opt.gcn_edge_dropout, opt.n_gcn_layers, 
                          opt.activation, opt.highway)    
    else:
        # "rnn" or "brnn"
        return RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                          opt.enc_rnn_size, opt.dropout, embeddings,
                          opt.bridge)
Example #14
def build_encoder(opt, embeddings, fields=None):
    """
    Dispatch among the available encoder types.
    Args:
        opt: the options for the current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
        fields: optional dataset fields (unused in this dispatcher).
    """
    if opt.encoder_type == 'simple_context_0':
        # bottom n-1 layers are shared
        return SimpleContextTransformerEncoder(
                                  opt.enc_layers - 1, opt.rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, embeddings,
                                  selected_ctx=0)
    elif opt.encoder_type == 'simple_context_1':
        # bottom n-1 layers are shared
        return SimpleContextTransformerEncoder(
                                  opt.enc_layers - 1, opt.rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, embeddings,
                                  selected_ctx=1)

    elif opt.encoder_type == "transformer":
        return TransformerEncoder(opt.enc_layers, opt.rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, embeddings)
    elif opt.encoder_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.rnn_size,
                          opt.cnn_kernel_width,
                          opt.dropout, embeddings)
    elif opt.encoder_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings)
    else:
        # "rnn" or "brnn"
        return RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                          opt.rnn_size, opt.dropout, embeddings,
                          opt.bridge)
Example #15
def build_model(model_opt, fields, gpu, checkpoint=None):
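    # A single rnn_size option, when set, overrides the separate
    # encoder/decoder hidden sizes.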
    if model_opt.rnn_size != -1:
        model_opt.enc_rnn_size = model_opt.rnn_size
        model_opt.dec_rnn_size = model_opt.rnn_size

    if model_opt.lang_vec_size is not None:
        lang_vec_size = model_opt.lang_vec_size
    else:
        lang_vec_size = model_opt.src_word_vec_size

    if model_opt.infl_vec_size is not None:
        infl_vec_size = model_opt.infl_vec_size
    else:
        infl_vec_size = model_opt.src_word_vec_size

    # Make atomic embedding modules (i.e. not multispace ones)
    lemma_character_embedding = build_embedding(fields["src"][0][1],
                                                model_opt.src_word_vec_size)
    inflected_character_embedding = build_embedding(
        fields["tgt"][0][1], model_opt.tgt_word_vec_size)
    if "inflection" in fields:
        inflection_embedding = build_embedding(fields["inflection"][0][1],
                                               infl_vec_size)
    else:
        inflection_embedding = None

    lang_field = fields["language"][0][1] if "language" in fields else None
    lang_embeddings = dict()
    if lang_field is not None and model_opt.lang_location is not None:
        lang_locations = set(model_opt.lang_location)
        if "tgt" in lang_locations:
            assert model_opt.lang_rep != "token", \
                "Can only use a feature for tgt language representation"
        for loc in lang_locations:
            lang_embeddings[loc] = build_embedding(lang_field, lang_vec_size)

    num_langs = len(lang_field.vocab) if "language" in fields else 1

    # Build the full, multispace embeddings
    encoder_embedding = MultispaceEmbedding(lemma_character_embedding,
                                            language=lang_embeddings.get(
                                                "src", None),
                                            mode=model_opt.lang_rep)

    decoder_embedding = MultispaceEmbedding(
        inflected_character_embedding,
        language=lang_embeddings.get("tgt", None),
        mode=model_opt.lang_rep  # only 'feature' should be allowed here
    )

    if inflection_embedding is not None:
        inflection_embedding = MultispaceEmbedding(
            inflection_embedding,
            language=lang_embeddings.get("inflection", None),
            mode=model_opt.lang_rep)
        if model_opt.inflection_rnn:
            if not hasattr(model_opt, "inflection_rnn_layers"):
                model_opt.inflection_rnn_layers = 1
            inflection_embedding = InflectionLSTMEncoder(
                inflection_embedding,
                model_opt.dec_rnn_size,  # need to think about this
                num_layers=model_opt.inflection_rnn_layers,
                dropout=model_opt.dropout,
                bidirectional=model_opt.brnn)
            #,
            #fused_msd = model_opt.fused_msd)

    # Build encoder
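    # With a bidirectional encoder each direction gets half the hidden size,
    # so the concatenated forward/backward states match enc_rnn_size.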
    if model_opt.brnn:
        assert model_opt.enc_rnn_size % 2 == 0
        hidden_size = model_opt.enc_rnn_size // 2
    else:
        hidden_size = model_opt.enc_rnn_size

    enc_rnn = getattr(nn, model_opt.rnn_type)(
        input_size=encoder_embedding.embedding_dim,
        hidden_size=hidden_size,
        num_layers=model_opt.enc_layers,
        dropout=model_opt.dropout,
        bidirectional=model_opt.brnn)
    encoder = RNNEncoder(encoder_embedding, enc_rnn)

    # Build decoder.
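    # The nested branches below pick an attention variant based on whether
    # inflection attention, gating, separate outputs, or global gate heads
    # are enabled.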
    attn_dim = model_opt.dec_rnn_size
    if model_opt.inflection_attention:
        if model_opt.inflection_gate is not None:
            if model_opt.separate_outputs:
                # if model_opt.separate_heads:
                #     attn = UnsharedGatedFourHeadedAttention.from_options(
                #         attn_dim,
                #         attn_type=model_opt.global_attention,
                #         attn_func=model_opt.global_attention_function,
                #         gate_func=model_opt.inflection_gate
                # )
                # else:
                attn = UnsharedGatedTwoHeadedAttention.from_options(
                    attn_dim,
                    attn_type=model_opt.global_attention,
                    attn_func=model_opt.global_attention_function,
                    gate_func=model_opt.inflection_gate)
            elif model_opt.global_gate_heads:
                attn = DecoderUnsharedGatedFourHeadedAttention.from_options(
                    attn_dim,
                    attn_type=model_opt.global_attention,
                    attn_func=model_opt.global_attention_function,
                    gate_func=model_opt.inflection_gate,
                    combine_gate_input=model_opt.global_gate_head_combine,
                    n_global_heads=model_opt.global_gate_heads_number,
                    infl_attn_func=model_opt.infl_attention_function
                    if hasattr(model_opt, 'infl_attention_function') else None)
                global_head = GlobalGateTowHeadAttention.from_options(
                    attn_dim,
                    attn_type=model_opt.global_attention,
                    attn_func=model_opt.global_attention_function,
                    n_global_heads=model_opt.global_gate_heads_number,
                    tahn_transform=model_opt.global_gate_head_nonlin)
            # elif model_opt.global_gate_heads_mix:
            #         attn = DecoderUnsharedGatedFourHeadedAttention.from_options(
            #             attn_dim,
            #             attn_type=model_opt.global_attention,
            #             attn_func=model_opt.global_attention_function,
            #             gate_func=model_opt.inflection_gate,
            #             combine_gate_input=model_opt.global_gate_head_combine,
            #             n_global_heads=model_opt.global_gate_heads_number * 2,
            #             infl_attn_func=model_opt.infl_attention_function if hasattr(model_opt, 'infl_attention_function') else None
            #         )
            #         global_head_subw = GlobalGateTowHeadAttention.from_options(
            #             attn_dim,
            #             attn_type=model_opt.global_attention,
            #             attn_func=model_opt.global_attention_function,
            #             n_global_heads=model_opt.global_gate_heads_number,
            #             tahn_transform=model_opt.global_gate_head_nonlin,
            #             att_name='gate_lemma_subw_global_'
            #         )
            #         global_head_char = GlobalGateTowHeadAttention.from_options(
            #             attn_dim,
            #             attn_type=model_opt.global_attention,
            #             attn_func=model_opt.global_attention_function,
            #             n_global_heads=model_opt.global_gate_heads_number,
            #             tahn_transform=model_opt.global_gate_head_nonlin,
            #             att_name='gate_lemma_char_global_'
            #         )
            else:
                attn = GatedTwoHeadedAttention.from_options(
                    attn_dim,
                    attn_type=model_opt.global_attention,
                    attn_func=model_opt.global_attention_function,
                    gate_func=model_opt.inflection_gate)
        else:
            attn = TwoHeadedAttention.from_options(
                attn_dim,
                attn_type=model_opt.global_attention,
                attn_func=model_opt.global_attention_function)
    else:
        attn = Attention.from_options(
            attn_dim,
            attn_type=model_opt.global_attention,
            attn_func=model_opt.global_attention_function)

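    # With input feeding, the previous attentional output is concatenated to
    # each decoder input, so the RNN input size grows by dec_rnn_size.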
    dec_input_size = decoder_embedding.embedding_dim
    if model_opt.input_feed:
        dec_input_size += model_opt.dec_rnn_size
        stacked_cell = StackedLSTM if model_opt.rnn_type == "LSTM" \
            else StackedGRU
        dec_rnn = stacked_cell(model_opt.dec_layers, dec_input_size,
                               model_opt.dec_rnn_size, model_opt.dropout)
    else:
        dec_rnn = getattr(nn, model_opt.rnn_type)(
            input_size=dec_input_size,
            hidden_size=model_opt.dec_rnn_size,
            num_layers=model_opt.dec_layers,
            dropout=model_opt.dropout)

    #dec_class = InputFeedRNNDecoder if model_opt.input_feed else RNNDecoder
    if model_opt.input_feed:
        decoder = InputFeedRNNDecoder(decoder_embedding,
                                      dec_rnn,
                                      attn,
                                      dropout=model_opt.dropout)
        # if model_opt.global_gate_heads:
        #     decoder = InputFeedRNNDecoderGlobalHead(
        #         decoder_embedding, dec_rnn, attn, global_head, dropout=model_opt.dropout
        #         )
    else:
        decoder = RNNDecoder(decoder_embedding,
                             dec_rnn,
                             attn,
                             dropout=model_opt.dropout)
    #decoder = dec_class(
    #    decoder_embedding, dec_rnn, attn, dropout=model_opt.dropout
    #)

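    # Output-layer bias vectors: one per language ("multi"), a single shared
    # one ("single"), or none.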
    if model_opt.out_bias == 'multi':
        bias_vectors = num_langs
    elif model_opt.out_bias == 'single':
        bias_vectors = 1
    else:
        bias_vectors = 0
    if model_opt.loss == "sparsemax":
        output_transform = LogSparsemax(dim=-1)
        # if model_opt.global_attention_function == "sparsemax":
        #     #output_transform = LogSparsemax(dim=-1)
        #     output_transform = Sparsemax(dim=-1) # adpation, log is applied within OutputLayer class
        # elif model_opt.global_attention_function == "fusedmax":
        #     output_transform = Fusedmax()
        # elif model_opt.global_attention_function == "oscarmax":
        #     output_transform = Oscarmax()
        # else:
        #     print('Unknown global_attention:', model_opt.global_attention)
    else:
        output_transform = nn.LogSoftmax(dim=-1)
    generator = OutputLayer(model_opt.dec_rnn_size,
                            decoder_embedding.num_embeddings, output_transform,
                            bias_vectors)
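    # Weight tying: reuse the decoder's main embedding matrix as the output
    # projection.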
    if model_opt.share_decoder_embeddings:
        generator.weight_matrix.weight = decoder.embeddings["main"].weight

    if model_opt.inflection_attention:
        if model_opt.global_gate_heads:
            model = onmt.models.model.InflectionGGHAttentionModel(
                encoder, inflection_embedding, global_head, decoder, generator)
        # elif model_opt.global_gate_heads_mix:
        #     model = onmt.models.model.InflectionGGHMixedAttentionModel(
        #     encoder, inflection_embedding, global_head_subw, global_head_char, decoder, generator)
        else:
            model = onmt.models.model.InflectionAttentionModel(
                encoder, inflection_embedding, decoder, generator)
    else:
        model = onmt.models.NMTModel(encoder, decoder, generator)

    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'], strict=False)
    elif model_opt.param_init != 0.0:
        for p in model.parameters():
            p.data.uniform_(-model_opt.param_init, model_opt.param_init)

    device = torch.device("cuda" if gpu else "cpu")
    model.to(device)

    return model
Example #16
def build_base_model(model_opt, fields, gpu, checkpoint=None):
    """
    Args:
        model_opt: the options loaded from the checkpoint.
        fields: `Field` objects for the model.
        gpu (bool): whether to use the gpu.
        checkpoint: the model generated by the training phase, or a
                    snapshot resumed from a stopped training run.
    Returns:
        the NMTModel.
    """
    assert model_opt.model_type in ["text", "img", "audio", "vector"], \
        ("Unsupported model type %s" % (model_opt.model_type))

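    # Whether the decoder consumes the source vectors directly; only the
    # "vector" model type below enables this.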
    use_src_directly_for_dec = False
    # Build encoder.
    if model_opt.model_type == "text":
        src_dict = fields["src"].vocab
        feature_dicts = inputters.collect_feature_vocabs(fields, 'src')
        src_embeddings = build_embeddings(model_opt, src_dict, feature_dicts)
        encoder = build_encoder(model_opt, src_embeddings)
    elif model_opt.model_type == "img":
        if ("image_channel_size" not in model_opt.__dict__):
            image_channel_size = 3
        else:
            image_channel_size = model_opt.image_channel_size

        encoder = ImageEncoder(model_opt.enc_layers, model_opt.brnn,
                               model_opt.enc_rnn_size, model_opt.dropout,
                               image_channel_size)
    elif model_opt.model_type == "audio":
        encoder = AudioEncoder(model_opt.rnn_type, model_opt.enc_layers,
                               model_opt.dec_layers, model_opt.brnn,
                               model_opt.enc_rnn_size, model_opt.dec_rnn_size,
                               model_opt.audio_enc_pooling, model_opt.dropout,
                               model_opt.sample_rate, model_opt.window_size)
    elif model_opt.model_type == "vector":
        use_src_directly_for_dec = True
        if not hasattr(fields["src"], 'vocab'):
            fields["src"].vocab = fields["tgt"].vocab
        src_dict = fields["src"].vocab
        #self.word_lut.weight.requires_grad = False
        feature_dicts = inputters.collect_feature_vocabs(fields, 'src')
        tgt_embeddings = build_embeddings(model_opt, src_dict, feature_dicts)
        if model_opt.encoder_type == "rnn" or model_opt.encoder_type == "brnn":
            encoder = RNNEncoder(model_opt.rnn_type, model_opt.brnn,
                                 model_opt.enc_layers, model_opt.enc_rnn_size,
                                 model_opt.dropout, None, model_opt.bridge)
            tgt_embeddings = None
        elif model_opt.decoder_type == "cnn":
            use_src_directly_for_dec = False
            encoder = CNNEncoder(model_opt.enc_layers, model_opt.enc_rnn_size,
                                 model_opt.cnn_kernel_width, model_opt.dropout,
                                 None)
            tgt_embeddings = None
        else:
            encoder = None

    # Build decoder.
    tgt_dict = fields["tgt"].vocab
    feature_dicts = inputters.collect_feature_vocabs(fields, 'tgt')
    if model_opt.model_type != "vector":
        tgt_embeddings = build_embeddings(model_opt,
                                          tgt_dict,
                                          feature_dicts,
                                          for_encoder=False)
    # else:
    #     tgt_embeddings = None

    # Share the embedding matrix - preprocess with share_vocab required.
    if model_opt.share_embeddings:
        # src/tgt vocab should be the same if `-share_vocab` is specified.
        if src_dict != tgt_dict:
            raise AssertionError('The `-share_vocab` should be set during '
                                 'preprocess if you use share_embeddings!')

        tgt_embeddings.word_lut.weight = src_embeddings.word_lut.weight

    decoder = build_decoder(model_opt, tgt_embeddings)

    # Build NMTModel(= encoder + decoder).
    device = torch.device("cuda" if gpu else "cpu")
    if model_opt.decoder_type.startswith("vecdif"):
        model = onmt.models.VecModel(
            encoder,
            decoder,
            use_src_directly_for_dec=use_src_directly_for_dec)
    else:
        model = onmt.models.NMTModel(
            encoder,
            decoder,
            use_src_directly_for_dec=use_src_directly_for_dec)

    # Build Generator.
    if not model_opt.copy_attn:
        if model_opt.generator_function == "sparsemax":
            gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1)
        elif model_opt.generator_function == "sigmoid":
            gen_func = nn.Sigmoid()
        else:
            gen_func = nn.LogSoftmax(dim=-1)
        if model_opt.model_type == "vector":
            if model_opt.generator_function == "none":
                # if model_opt.final_vec_size != model_opt.dec_rnn_size:
                #     generator = nn.Sequential(
                #         nn.Linear(model_opt.dec_rnn_size, model_opt.final_vec_size))
                # else:
                generator = None
            else:
                generator = nn.Sequential(
                    nn.Linear(model_opt.dec_rnn_size,
                              model_opt.final_vec_size), gen_func)
        else:
            generator = nn.Sequential(
                nn.Linear(model_opt.dec_rnn_size, len(fields["tgt"].vocab)),
                gen_func)
        if model_opt.share_decoder_embeddings:
            generator[0].weight = decoder.embeddings.word_lut.weight
    else:
        generator = CopyGenerator(model_opt.dec_rnn_size, fields["tgt"].vocab)

    # Load the model states from checkpoint or initialize them.
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'], strict=False)
        if generator is not None:
            generator.load_state_dict(checkpoint['generator'], strict=False)
    else:
        if model_opt.param_init != 0.0:
            for p in model.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
            if generator is not None:
                for p in generator.parameters():
                    p.data.uniform_(-model_opt.param_init,
                                    model_opt.param_init)
        if model_opt.param_init_glorot:
            for p in model.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
            if generator is not None:
                for p in generator.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)

        if hasattr(model.encoder,
                   'embeddings') and model_opt.model_type != "vector":
            model.encoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_enc, model_opt.fix_word_vecs_enc)
        if hasattr(model.decoder,
                   'embeddings') and model_opt.model_type != "vector":
            model.decoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_dec, model_opt.fix_word_vecs_dec)

    # Add generator to model (this registers it as parameter of model).
    model.generator = generator
    model.to(device)

    return model