def build_decoder(opt, embeddings):
    """ Various decoder dispatcher function.
    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.
    """
    if opt.decoder_type == "transformer":
        return TransformerDecoder(opt.dec_layers, opt.rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.global_attention, opt.copy_attn,
                                  opt.self_attn_type,
                                  opt.dropout, embeddings)
    elif opt.decoder_type == "cnn":
        return CNNDecoder(opt.dec_layers, opt.rnn_size,
                          opt.global_attention, opt.copy_attn,
                          opt.cnn_kernel_width, opt.dropout,
                          embeddings)
    elif opt.input_feed:
        return InputFeedRNNDecoder(opt.rnn_type, opt.brnn,
                                   opt.dec_layers, opt.rnn_size,
                                   opt.global_attention,
                                   opt.global_attention_function,
                                   opt.coverage_attn,
                                   opt.context_gate,
                                   opt.copy_attn,
                                   opt.dropout,
                                   embeddings,
                                   opt.reuse_copy_attn)
    else:
        return StdRNNDecoder(opt.rnn_type, opt.brnn,
                             opt.dec_layers, opt.rnn_size,
                             opt.global_attention,
                             opt.global_attention_function,
                             opt.coverage_attn,
                             opt.context_gate,
                             opt.copy_attn,
                             opt.dropout,
                             embeddings,
                             opt.reuse_copy_attn)
def build_decoder(opt, embeddings):
    """ Various decoder dispatcher function.
    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.
    """
    if opt.decoder_type == "rnn":
        if opt.input_feed:
            return InputFeedRNNDecoder(opt.rnn_type, opt.brnn,
                                       opt.dec_layers, opt.rnn_size,
                                       opt.global_attention,
                                       opt.coverage_attn,
                                       opt.context_gate,
                                       opt.copy_attn,
                                       opt.dropout,
                                       embeddings,
                                       opt.reuse_copy_attn)
        else:
            return StdRNNDecoder(opt.rnn_type, opt.brnn,
                                 opt.dec_layers, opt.rnn_size,
                                 opt.global_attention,
                                 opt.coverage_attn,
                                 opt.context_gate,
                                 opt.copy_attn,
                                 opt.dropout,
                                 embeddings,
                                 opt.reuse_copy_attn)
    else:
        raise ModuleNotFoundError("Decoder type not found")
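# --- Usage sketch (not part of the original file; values are illustrative) ---
# Shows how the RNN-only dispatcher above might be driven with an
# argparse-style options object. The attribute names mirror exactly what
# build_decoder reads; the vocab size, padding index, and hyperparameters
# below are made-up toy numbers, not project defaults.
from argparse import Namespace

import onmt

opt = Namespace(
    decoder_type="rnn",      # anything else raises ModuleNotFoundError above
    input_feed=1,            # truthy -> InputFeedRNNDecoder, else StdRNNDecoder
    rnn_type="LSTM",
    brnn=True,
    dec_layers=2,
    rnn_size=500,
    global_attention="general",
    coverage_attn=False,
    context_gate=None,
    copy_attn=False,
    dropout=0.3,
    reuse_copy_attn=False,
)

# Toy character embeddings (vocab of 120 symbols, padding index 1),
# built the same way as in the quickstart snippet further below.
embeddings = onmt.modules.Embeddings(opt.rnn_size, 120, word_padding_idx=1)
decoder = build_decoder(opt, embeddings)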
def build_model(model_opt, fields, gpu, checkpoint=None):
    if model_opt.rnn_size != -1:
        model_opt.enc_rnn_size = model_opt.rnn_size
        model_opt.dec_rnn_size = model_opt.rnn_size

    if model_opt.lang_vec_size is not None:
        lang_vec_size = model_opt.lang_vec_size
    else:
        lang_vec_size = model_opt.src_word_vec_size

    if model_opt.infl_vec_size is not None:
        infl_vec_size = model_opt.infl_vec_size
    else:
        infl_vec_size = model_opt.src_word_vec_size

    # Make atomic embedding modules (i.e. not multispace ones)
    lemma_character_embedding = build_embedding(
        fields["src"][0][1], model_opt.src_word_vec_size)
    inflected_character_embedding = build_embedding(
        fields["tgt"][0][1], model_opt.tgt_word_vec_size)
    if "inflection" in fields:
        inflection_embedding = build_embedding(
            fields["inflection"][0][1], infl_vec_size)
    else:
        inflection_embedding = None

    lang_field = fields["language"][0][1] if "language" in fields else None
    lang_embeddings = dict()
    if lang_field is not None and model_opt.lang_location is not None:
        lang_locations = set(model_opt.lang_location)
        if "tgt" in lang_locations:
            assert model_opt.lang_rep != "token", \
                "Can only use a feature for tgt language representation"
        for loc in lang_locations:
            lang_embeddings[loc] = build_embedding(lang_field, lang_vec_size)

    num_langs = len(lang_field.vocab) if "language" in fields else 1

    # Build the full, multispace embeddings
    encoder_embedding = MultispaceEmbedding(
        lemma_character_embedding,
        language=lang_embeddings.get("src", None),
        mode=model_opt.lang_rep)
    decoder_embedding = MultispaceEmbedding(
        inflected_character_embedding,
        language=lang_embeddings.get("tgt", None),
        mode=model_opt.lang_rep  # only 'feature' should be allowed here
    )

    if inflection_embedding is not None:
        inflection_embedding = MultispaceEmbedding(
            inflection_embedding,
            language=lang_embeddings.get("inflection", None),
            mode=model_opt.lang_rep)
        if model_opt.inflection_rnn:
            if not hasattr(model_opt, "inflection_rnn_layers"):
                model_opt.inflection_rnn_layers = 1
            inflection_embedding = InflectionLSTMEncoder(
                inflection_embedding,
                model_opt.dec_rnn_size,  # need to think about this
                num_layers=model_opt.inflection_rnn_layers,
                dropout=model_opt.dropout,
                bidirectional=model_opt.brnn)  # ,
                # fused_msd=model_opt.fused_msd)

    # Build encoder
    if model_opt.brnn:
        assert model_opt.enc_rnn_size % 2 == 0
        hidden_size = model_opt.enc_rnn_size // 2
    else:
        hidden_size = model_opt.enc_rnn_size
    enc_rnn = getattr(nn, model_opt.rnn_type)(
        input_size=encoder_embedding.embedding_dim,
        hidden_size=hidden_size,
        num_layers=model_opt.enc_layers,
        dropout=model_opt.dropout,
        bidirectional=model_opt.brnn)
    encoder = RNNEncoder(encoder_embedding, enc_rnn)

    # Build decoder.
    attn_dim = model_opt.dec_rnn_size
    if model_opt.inflection_attention:
        if model_opt.inflection_gate is not None:
            if model_opt.separate_outputs:
                # if model_opt.separate_heads:
                #     attn = UnsharedGatedFourHeadedAttention.from_options(
                #         attn_dim,
                #         attn_type=model_opt.global_attention,
                #         attn_func=model_opt.global_attention_function,
                #         gate_func=model_opt.inflection_gate
                #     )
                # else:
                attn = UnsharedGatedTwoHeadedAttention.from_options(
                    attn_dim,
                    attn_type=model_opt.global_attention,
                    attn_func=model_opt.global_attention_function,
                    gate_func=model_opt.inflection_gate)
            elif model_opt.global_gate_heads:
                attn = DecoderUnsharedGatedFourHeadedAttention.from_options(
                    attn_dim,
                    attn_type=model_opt.global_attention,
                    attn_func=model_opt.global_attention_function,
                    gate_func=model_opt.inflection_gate,
                    combine_gate_input=model_opt.global_gate_head_combine,
                    n_global_heads=model_opt.global_gate_heads_number,
                    infl_attn_func=model_opt.infl_attention_function
                    if hasattr(model_opt, 'infl_attention_function') else None)
                global_head = GlobalGateTowHeadAttention.from_options(
                    attn_dim,
                    attn_type=model_opt.global_attention,
                    attn_func=model_opt.global_attention_function,
                    n_global_heads=model_opt.global_gate_heads_number,
                    tahn_transform=model_opt.global_gate_head_nonlin)
            # elif model_opt.global_gate_heads_mix:
            #     attn = DecoderUnsharedGatedFourHeadedAttention.from_options(
            #         attn_dim,
            #         attn_type=model_opt.global_attention,
            #         attn_func=model_opt.global_attention_function,
            #         gate_func=model_opt.inflection_gate,
            #         combine_gate_input=model_opt.global_gate_head_combine,
            #         n_global_heads=model_opt.global_gate_heads_number * 2,
            #         infl_attn_func=model_opt.infl_attention_function
            #         if hasattr(model_opt, 'infl_attention_function') else None
            #     )
            #     global_head_subw = GlobalGateTowHeadAttention.from_options(
            #         attn_dim,
            #         attn_type=model_opt.global_attention,
            #         attn_func=model_opt.global_attention_function,
            #         n_global_heads=model_opt.global_gate_heads_number,
            #         tahn_transform=model_opt.global_gate_head_nonlin,
            #         att_name='gate_lemma_subw_global_'
            #     )
            #     global_head_char = GlobalGateTowHeadAttention.from_options(
            #         attn_dim,
            #         attn_type=model_opt.global_attention,
            #         attn_func=model_opt.global_attention_function,
            #         n_global_heads=model_opt.global_gate_heads_number,
            #         tahn_transform=model_opt.global_gate_head_nonlin,
            #         att_name='gate_lemma_char_global_'
            #     )
            else:
                attn = GatedTwoHeadedAttention.from_options(
                    attn_dim,
                    attn_type=model_opt.global_attention,
                    attn_func=model_opt.global_attention_function,
                    gate_func=model_opt.inflection_gate)
        else:
            attn = TwoHeadedAttention.from_options(
                attn_dim,
                attn_type=model_opt.global_attention,
                attn_func=model_opt.global_attention_function)
    else:
        attn = Attention.from_options(
            attn_dim,
            attn_type=model_opt.global_attention,
            attn_func=model_opt.global_attention_function)

    # Decoder RNN: with input feeding, the previous attentional state is
    # concatenated to the embedding, so the input size grows by dec_rnn_size.
    dec_input_size = decoder_embedding.embedding_dim
    if model_opt.input_feed:
        dec_input_size += model_opt.dec_rnn_size
        stacked_cell = StackedLSTM if model_opt.rnn_type == "LSTM" \
            else StackedGRU
        dec_rnn = stacked_cell(model_opt.dec_layers, dec_input_size,
                               model_opt.rnn_size, model_opt.dropout)
    else:
        dec_rnn = getattr(nn, model_opt.rnn_type)(
            input_size=dec_input_size,
            hidden_size=model_opt.dec_rnn_size,
            num_layers=model_opt.dec_layers,
            dropout=model_opt.dropout)

    # dec_class = InputFeedRNNDecoder if model_opt.input_feed else RNNDecoder
    if model_opt.input_feed:
        decoder = InputFeedRNNDecoder(decoder_embedding, dec_rnn, attn,
                                      dropout=model_opt.dropout)
        # if model_opt.global_gate_heads:
        #     decoder = InputFeedRNNDecoderGlobalHead(
        #         decoder_embedding, dec_rnn, attn, global_head,
        #         dropout=model_opt.dropout
        #     )
    else:
        decoder = RNNDecoder(decoder_embedding, dec_rnn, attn,
                             dropout=model_opt.dropout)
    # decoder = dec_class(
    #     decoder_embedding, dec_rnn, attn, dropout=model_opt.dropout
    # )

    if model_opt.out_bias == 'multi':
        bias_vectors = num_langs
    elif model_opt.out_bias == 'single':
        bias_vectors = 1
    else:
        bias_vectors = 0

    if model_opt.loss == "sparsemax":
        output_transform = LogSparsemax(dim=-1)
        # if model_opt.global_attention_function == "sparsemax":
        #     # output_transform = LogSparsemax(dim=-1)
        #     output_transform = Sparsemax(dim=-1)  # adaptation, log is applied within OutputLayer class
        # elif model_opt.global_attention_function == "fusedmax":
        #     output_transform = Fusedmax()
        # elif model_opt.global_attention_function == "oscarmax":
        #     output_transform = Oscarmax()
        # else:
        #     print('Unknown global_attention:', model_opt.global_attention)
    else:
        output_transform = nn.LogSoftmax(dim=-1)

    generator = OutputLayer(model_opt.dec_rnn_size,
                            decoder_embedding.num_embeddings,
                            output_transform,
                            bias_vectors)
    if model_opt.share_decoder_embeddings:
        generator.weight_matrix.weight = decoder.embeddings["main"].weight

    if model_opt.inflection_attention:
        if model_opt.global_gate_heads:
            model = onmt.models.model.InflectionGGHAttentionModel(
                encoder, inflection_embedding, global_head, decoder, generator)
        # elif model_opt.global_gate_heads_mix:
        #     model = onmt.models.model.InflectionGGHMixedAttentionModel(
        #         encoder, inflection_embedding, global_head_subw,
        #         global_head_char, decoder, generator)
        else:
            model = onmt.models.model.InflectionAttentionModel(
                encoder, inflection_embedding, decoder, generator)
    else:
        model = onmt.models.NMTModel(encoder, decoder, generator)

    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'], strict=False)
    elif model_opt.param_init != 0.0:
        for p in model.parameters():
            p.data.uniform_(-model_opt.param_init, model_opt.param_init)

    device = torch.device("cuda" if gpu else "cpu")
    model.to(device)

    return model
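# --- Usage sketch for build_model (assumptions, not part of the original) ---
# Illustrates driving build_model from a saved checkpoint, mirroring the
# load_state_dict(checkpoint['model'], strict=False) call above. The file
# name and the 'opt'/'vocab' checkpoint keys are assumptions here; only the
# 'model' key is taken from the code itself.
import torch

checkpoint = torch.load("model_step_10000.pt",
                        map_location=lambda storage, loc: storage)
model_opt = checkpoint["opt"]    # assumed key holding the training options
fields = checkpoint["vocab"]     # assumed key holding the fields/vocab dict
model = build_model(model_opt, fields, gpu=torch.cuda.is_available(),
                    checkpoint=checkpoint)
model.eval()                     # inference mode after loading the weights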
def build_decoder(opt, embeddings):
    """ Various decoder dispatcher function.
    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.
    """
    if opt.decoder_type == "transformer":
        logger.info('TransformerDecoder: layers %d, input dim %d, '
                    'fat relu hidden dim %d, num heads %d, %s global attn, '
                    'copy attn %d, self attn type %s, dropout %.2f' %
                    (opt.dec_layers, opt.dec_rnn_size,
                     opt.transformer_ff, opt.heads,
                     opt.global_attention, opt.copy_attn,
                     opt.self_attn_type, opt.dropout))
        # dec_rnn_size = dimension of keys/values/queries (input to FF)
        # transformer_ff = dimension of fat relu
        return TransformerDecoder(opt.dec_layers, opt.dec_rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.global_attention, opt.copy_attn,
                                  opt.self_attn_type,
                                  opt.dropout, embeddings)
    elif opt.decoder_type == "cnn":
        return CNNDecoder(opt.dec_layers, opt.dec_rnn_size,
                          opt.global_attention, opt.copy_attn,
                          opt.cnn_kernel_width, opt.dropout,
                          embeddings)
    elif opt.input_feed:
        logger.info('InputFeedRNNDecoder: type %s, bidir %d, layers %d, '
                    'hidden size %d, %s global attn (%s), '
                    'coverage attn %d, copy attn %d, dropout %.2f' %
                    (opt.rnn_type, opt.brnn, opt.dec_layers,
                     opt.dec_rnn_size, opt.global_attention,
                     opt.global_attention_function, opt.coverage_attn,
                     opt.copy_attn, opt.dropout))
        return InputFeedRNNDecoder(opt.rnn_type, opt.brnn,
                                   opt.dec_layers, opt.dec_rnn_size,
                                   opt.global_attention,
                                   opt.global_attention_function,
                                   opt.coverage_attn,
                                   opt.context_gate,
                                   opt.copy_attn,
                                   opt.dropout,
                                   embeddings,
                                   opt.reuse_copy_attn)
    else:
        logger.info('StdRNNDecoder: type %s, bidir %d, layers %d, '
                    'hidden size %d, %s global attn (%s), '
                    'coverage attn %d, copy attn %d, dropout %.2f' %
                    (opt.rnn_type, opt.brnn, opt.dec_layers,
                     opt.dec_rnn_size, opt.global_attention,
                     opt.global_attention_function, opt.coverage_attn,
                     opt.copy_attn, opt.dropout))
        return StdRNNDecoder(opt.rnn_type, opt.brnn,
                             opt.dec_layers, opt.dec_rnn_size,
                             opt.global_attention,
                             opt.global_attention_function,
                             opt.coverage_attn,
                             opt.context_gate,
                             opt.copy_attn,
                             opt.dropout,
                             embeddings,
                             opt.reuse_copy_attn)
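# --- Note on the `logger` dependency (sketch; file name is illustrative) ---
# This build_decoder variant logs its configuration before constructing each
# decoder, so the module-level `logger` must be initialised first. In
# OpenNMT-py that is normally the shared logger from onmt.utils.logging.
from onmt.utils.logging import init_logger, logger

init_logger("train.log")      # console handler plus an optional log file
logger.info("logger ready")   # subsequent build_decoder calls log here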
                                             word_padding_idx=src_padding)
encoder = onmt.encoders.RNNEncoder(hidden_size=rnn_size, num_layers=1,
                                   rnn_type="LSTM", bidirectional=True,
                                   embeddings=encoder_embeddings)

decoder_embeddings = onmt.modules.Embeddings(emb_size, len(vocab["tgt"]),
                                             word_padding_idx=tgt_padding)
from onmt.decoders.decoder import InputFeedRNNDecoder
decoder = InputFeedRNNDecoder(hidden_size=rnn_size, num_layers=1,
                              bidirectional_encoder=True, rnn_type="LSTM",
                              embeddings=decoder_embeddings)
# from onmt.models.model import NMTModel as NMTModel
model = onmt.models.model.NMTModel(encoder, decoder)

# Specify the tgt word generator and loss computation module
model.generator = nn.Sequential(nn.Linear(rnn_size, len(vocab["tgt"])),
                                nn.LogSoftmax(dim=-1))
loss = onmt.utils.loss.NMTLossCompute(model.generator, vocab["tgt"])

# Set up the optimizer
optim = onmt.utils.optimizers.Optim(method="sgd",