Пример #1
0
def build_decoder(opt, embeddings):
    """
    Construct the decoder selected by the options.

    Args:
        opt: the option in current environment. ``opt.decoder_type`` selects
            the "transformer" or "cnn" decoder; any other value builds an
            RNN decoder, the input-feeding variant when ``opt.input_feed``
            is truthy and the standard variant otherwise.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The newly constructed decoder instance.
    """
    decoder_type = opt.decoder_type

    if decoder_type == "transformer":
        return TransformerDecoder(
            opt.dec_layers, opt.rnn_size, opt.heads, opt.transformer_ff,
            opt.global_attention, opt.copy_attn, opt.self_attn_type,
            opt.dropout, embeddings)

    if decoder_type == "cnn":
        return CNNDecoder(
            opt.dec_layers, opt.rnn_size, opt.global_attention,
            opt.copy_attn, opt.cnn_kernel_width, opt.dropout, embeddings)

    # Everything else is an RNN decoder; both variants are called with the
    # same argument list, so only the class needs to be chosen.
    rnn_decoder_cls = InputFeedRNNDecoder if opt.input_feed else StdRNNDecoder
    return rnn_decoder_cls(
        opt.rnn_type, opt.brnn, opt.dec_layers, opt.rnn_size,
        opt.global_attention, opt.global_attention_function,
        opt.coverage_attn, opt.context_gate, opt.copy_attn,
        opt.dropout, embeddings, opt.reuse_copy_attn)
Пример #2
0
def build_decoder(opt, embeddings):
    """
    Construct the RNN decoder described by the options.

    Args:
        opt: the option in current environment. Only ``opt.decoder_type ==
            "rnn"`` is supported; ``opt.input_feed`` then chooses between
            the input-feeding and the standard RNN decoder.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The newly constructed decoder instance.

    Raises:
        ModuleNotFoundError: if ``opt.decoder_type`` is not "rnn".
    """
    # Reject unsupported decoder types up front.
    if opt.decoder_type != "rnn":
        raise ModuleNotFoundError("Decoder type not found")

    # Both variants are called with the same argument list.
    rnn_decoder_cls = InputFeedRNNDecoder if opt.input_feed else StdRNNDecoder
    return rnn_decoder_cls(
        opt.rnn_type, opt.brnn, opt.dec_layers, opt.rnn_size,
        opt.global_attention, opt.coverage_attn, opt.context_gate,
        opt.copy_attn, opt.dropout, embeddings, opt.reuse_copy_attn)
Пример #3
0
def _build_attention(model_opt, attn_dim):
    """
    Build the decoder attention module selected by the options.

    Args:
        model_opt: model options (reads the attention-related attributes).
        attn_dim: dimension passed as first argument to every
            ``from_options`` constructor.

    Returns:
        tuple: ``(attn, global_head)``. ``global_head`` is ``None`` except
        in the gated configuration with ``global_gate_heads`` enabled and
        ``separate_outputs`` disabled, where a separate global-gate
        attention module is built as well.
    """
    common = dict(attn_type=model_opt.global_attention,
                  attn_func=model_opt.global_attention_function)

    if not model_opt.inflection_attention:
        return Attention.from_options(attn_dim, **common), None

    gate_func = model_opt.inflection_gate
    if gate_func is None:
        return TwoHeadedAttention.from_options(attn_dim, **common), None

    if model_opt.separate_outputs:
        attn = UnsharedGatedTwoHeadedAttention.from_options(
            attn_dim, gate_func=gate_func, **common)
        return attn, None

    if model_opt.global_gate_heads:
        attn = DecoderUnsharedGatedFourHeadedAttention.from_options(
            attn_dim,
            gate_func=gate_func,
            combine_gate_input=model_opt.global_gate_head_combine,
            n_global_heads=model_opt.global_gate_heads_number,
            # Older option sets may lack this attribute.
            infl_attn_func=getattr(model_opt, 'infl_attention_function',
                                   None),
            **common)
        global_head = GlobalGateTowHeadAttention.from_options(
            attn_dim,
            n_global_heads=model_opt.global_gate_heads_number,
            tahn_transform=model_opt.global_gate_head_nonlin,
            **common)
        return attn, global_head

    attn = GatedTwoHeadedAttention.from_options(
        attn_dim, gate_func=gate_func, **common)
    return attn, None


def build_model(model_opt, fields, gpu, checkpoint=None):
    """
    Assemble embeddings, encoder, attention, decoder and output layer
    into a model and move it to the requested device.

    Args:
        model_opt: model options. Mutated in place: ``enc_rnn_size`` and
            ``dec_rnn_size`` are overwritten with ``rnn_size`` when the
            latter is not -1, and ``inflection_rnn_layers`` is set to 1
            if it is missing when an inflection RNN encoder is built.
        fields: mapping from field name ("src", "tgt", and optionally
            "inflection" and "language"); each entry is indexed as
            ``fields[name][0][1]`` to obtain what is passed to
            ``build_embedding``.
        gpu: truthy to place the model on CUDA, falsy for CPU.
        checkpoint: optional mapping; when given, ``checkpoint['model']``
            is loaded non-strictly instead of initialising parameters.

    Returns:
        The constructed model, already moved to the selected device.

    Raises:
        ValueError: if ``inflection_attention`` and ``global_gate_heads``
            are enabled in a configuration that builds no global head
            (``inflection_gate`` unset or ``separate_outputs`` enabled).
    """
    # A single rnn_size overrides the separate encoder/decoder sizes.
    if model_opt.rnn_size != -1:
        model_opt.enc_rnn_size = model_opt.rnn_size
        model_opt.dec_rnn_size = model_opt.rnn_size

    # Language / inflection vector sizes default to the source size.
    if model_opt.lang_vec_size is not None:
        lang_vec_size = model_opt.lang_vec_size
    else:
        lang_vec_size = model_opt.src_word_vec_size

    if model_opt.infl_vec_size is not None:
        infl_vec_size = model_opt.infl_vec_size
    else:
        infl_vec_size = model_opt.src_word_vec_size

    # Make atomic embedding modules (i.e. not multispace ones)
    lemma_character_embedding = build_embedding(fields["src"][0][1],
                                                model_opt.src_word_vec_size)
    inflected_character_embedding = build_embedding(
        fields["tgt"][0][1], model_opt.tgt_word_vec_size)
    if "inflection" in fields:
        inflection_embedding = build_embedding(fields["inflection"][0][1],
                                               infl_vec_size)
    else:
        inflection_embedding = None

    # One language embedding per requested location.
    lang_field = fields["language"][0][1] if "language" in fields else None
    lang_embeddings = dict()
    if lang_field is not None and model_opt.lang_location is not None:
        lang_locations = set(model_opt.lang_location)
        if "tgt" in lang_locations:
            assert model_opt.lang_rep != "token", \
                "Can only use a feature for tgt language representation"
        for loc in lang_locations:
            lang_embeddings[loc] = build_embedding(lang_field, lang_vec_size)

    num_langs = len(lang_field.vocab) if "language" in fields else 1

    # Build the full, multispace embeddings
    encoder_embedding = MultispaceEmbedding(
        lemma_character_embedding,
        language=lang_embeddings.get("src", None),
        mode=model_opt.lang_rep)

    decoder_embedding = MultispaceEmbedding(
        inflected_character_embedding,
        language=lang_embeddings.get("tgt", None),
        mode=model_opt.lang_rep  # only 'feature' should be allowed here
    )

    if inflection_embedding is not None:
        inflection_embedding = MultispaceEmbedding(
            inflection_embedding,
            language=lang_embeddings.get("inflection", None),
            mode=model_opt.lang_rep)
        if model_opt.inflection_rnn:
            # Option sets without this attribute get a single layer.
            if not hasattr(model_opt, "inflection_rnn_layers"):
                model_opt.inflection_rnn_layers = 1
            inflection_embedding = InflectionLSTMEncoder(
                inflection_embedding,
                model_opt.dec_rnn_size,  # need to think about this
                num_layers=model_opt.inflection_rnn_layers,
                dropout=model_opt.dropout,
                bidirectional=model_opt.brnn)

    # Build encoder. A bidirectional RNN splits the size across directions.
    if model_opt.brnn:
        assert model_opt.enc_rnn_size % 2 == 0
        hidden_size = model_opt.enc_rnn_size // 2
    else:
        hidden_size = model_opt.enc_rnn_size

    enc_rnn = getattr(nn, model_opt.rnn_type)(
        input_size=encoder_embedding.embedding_dim,
        hidden_size=hidden_size,
        num_layers=model_opt.enc_layers,
        dropout=model_opt.dropout,
        bidirectional=model_opt.brnn)
    encoder = RNNEncoder(encoder_embedding, enc_rnn)

    # Build decoder.
    attn, global_head = _build_attention(model_opt, model_opt.dec_rnn_size)

    dec_input_size = decoder_embedding.embedding_dim
    if model_opt.input_feed:
        # Input feeding widens the RNN input by the decoder size.
        dec_input_size += model_opt.dec_rnn_size
        stacked_cell = StackedLSTM if model_opt.rnn_type == "LSTM" \
            else StackedGRU
        # Use dec_rnn_size (not rnn_size, which may be -1) so the hidden
        # size agrees with the attention and the output layer.
        dec_rnn = stacked_cell(model_opt.dec_layers, dec_input_size,
                               model_opt.dec_rnn_size, model_opt.dropout)
        decoder_cls = InputFeedRNNDecoder
    else:
        dec_rnn = getattr(nn, model_opt.rnn_type)(
            input_size=dec_input_size,
            hidden_size=model_opt.dec_rnn_size,
            num_layers=model_opt.dec_layers,
            dropout=model_opt.dropout)
        decoder_cls = RNNDecoder

    decoder = decoder_cls(decoder_embedding,
                          dec_rnn,
                          attn,
                          dropout=model_opt.dropout)

    # Number of output bias vectors: one per language, one, or none.
    if model_opt.out_bias == 'multi':
        bias_vectors = num_langs
    elif model_opt.out_bias == 'single':
        bias_vectors = 1
    else:
        bias_vectors = 0

    if model_opt.loss == "sparsemax":
        output_transform = LogSparsemax(dim=-1)
    else:
        output_transform = nn.LogSoftmax(dim=-1)
    generator = OutputLayer(model_opt.dec_rnn_size,
                            decoder_embedding.num_embeddings, output_transform,
                            bias_vectors)
    if model_opt.share_decoder_embeddings:
        # Tie the output projection to the main decoder embedding.
        generator.weight_matrix.weight = decoder.embeddings["main"].weight

    if model_opt.inflection_attention:
        if model_opt.global_gate_heads:
            if global_head is None:
                raise ValueError(
                    "global_gate_heads requires inflection_gate to be set "
                    "and separate_outputs to be disabled")
            model = onmt.models.model.InflectionGGHAttentionModel(
                encoder, inflection_embedding, global_head, decoder, generator)
        else:
            model = onmt.models.model.InflectionAttentionModel(
                encoder, inflection_embedding, decoder, generator)
    else:
        model = onmt.models.NMTModel(encoder, decoder, generator)

    if checkpoint is not None:
        # Non-strict: keys missing from / unexpected in the checkpoint
        # are tolerated.
        model.load_state_dict(checkpoint['model'], strict=False)
    elif model_opt.param_init != 0.0:
        for p in model.parameters():
            p.data.uniform_(-model_opt.param_init, model_opt.param_init)

    device = torch.device("cuda" if gpu else "cpu")
    model.to(device)

    return model
Пример #4
0
def build_decoder(opt, embeddings):
    """
    Construct the decoder selected by the options, logging its settings.

    A summary line is logged for the transformer and RNN decoders; the CNN
    decoder is built without logging.

    Args:
        opt: the option in current environment. ``opt.decoder_type`` selects
            the "transformer" or "cnn" decoder; any other value builds an
            RNN decoder, the input-feeding variant when ``opt.input_feed``
            is truthy and the standard variant otherwise.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The newly constructed decoder instance.
    """
    decoder_type = opt.decoder_type

    if decoder_type == "transformer":
        summary = ('TransformerDecoder: layers %d, input dim %d, '
                   'fat relu hidden dim %d, num heads %d, %s global attn, '
                   'copy attn %d, self attn type %s, dropout %.2f' % (
                       opt.dec_layers, opt.dec_rnn_size, opt.transformer_ff,
                       opt.heads, opt.global_attention, opt.copy_attn,
                       opt.self_attn_type, opt.dropout))
        logger.info(summary)

        # dec_rnn_size   = dimension of keys/values/queries (input to FF)
        # transformer_ff = dimension of fat relu
        return TransformerDecoder(
            opt.dec_layers, opt.dec_rnn_size, opt.heads, opt.transformer_ff,
            opt.global_attention, opt.copy_attn, opt.self_attn_type,
            opt.dropout, embeddings)

    if decoder_type == "cnn":
        return CNNDecoder(
            opt.dec_layers, opt.dec_rnn_size, opt.global_attention,
            opt.copy_attn, opt.cnn_kernel_width, opt.dropout, embeddings)

    # RNN decoders: the two variants differ only in class and log prefix.
    if opt.input_feed:
        rnn_decoder_cls = InputFeedRNNDecoder
        summary_format = ('InputFeedRNNDecoder: type %s, bidir %d, layers %d, '
                          'hidden size %d, %s global attn (%s), '
                          'coverage attn %d, copy attn %d, dropout %.2f')
    else:
        rnn_decoder_cls = StdRNNDecoder
        summary_format = ('StdRNNDecoder: type %s, bidir %d, layers %d, '
                          'hidden size %d, %s global attn (%s), '
                          'coverage attn %d, copy attn %d, dropout %.2f')

    logger.info(summary_format % (
        opt.rnn_type, opt.brnn, opt.dec_layers, opt.dec_rnn_size,
        opt.global_attention, opt.global_attention_function,
        opt.coverage_attn, opt.copy_attn, opt.dropout))

    return rnn_decoder_cls(
        opt.rnn_type, opt.brnn, opt.dec_layers, opt.dec_rnn_size,
        opt.global_attention, opt.global_attention_function,
        opt.coverage_attn, opt.context_gate, opt.copy_attn,
        opt.dropout, embeddings, opt.reuse_copy_attn)
Пример #5
0
                                             word_padding_idx=src_padding)
# Single-layer bidirectional LSTM encoder over the source embeddings.
encoder = onmt.encoders.RNNEncoder(hidden_size=rnn_size,
                                   num_layers=1,
                                   rnn_type="LSTM",
                                   bidirectional=True,
                                   embeddings=encoder_embeddings)

# Target-side embeddings, sized by the target vocabulary.
decoder_embeddings = onmt.modules.Embeddings(emb_size,
                                             len(vocab["tgt"]),
                                             word_padding_idx=tgt_padding)

from onmt.decoders.decoder import InputFeedRNNDecoder

# Single-layer LSTM decoder; bidirectional_encoder matches the encoder above.
decoder = InputFeedRNNDecoder(hidden_size=rnn_size,
                              num_layers=1,
                              bidirectional_encoder=True,
                              rnn_type="LSTM",
                              embeddings=decoder_embeddings)

model = onmt.models.model.NMTModel(encoder, decoder)

# Specify the tgt word generator and loss computation module.
# LogSoftmax is given an explicit dim: the implicit choice is deprecated
# and does not pick the last (vocabulary) axis for every input rank.
model.generator = nn.Sequential(nn.Linear(rnn_size, len(vocab["tgt"])),
                                nn.LogSoftmax(dim=-1))

loss = onmt.utils.loss.NMTLossCompute(model.generator, vocab["tgt"])

# Set up the optimizer

optim = onmt.utils.optimizers.Optim(method="sgd",