Example #1
elif (model_name == 'CNN_CNN_LSTM'):
    print ('CNN_CNN_LSTM')
    word_vocab_size = len(word_to_id)
    word_embedding_dim = parameters['wrdim']
    word_out_channels = parameters['wdchl']
    char_vocab_size = len(char_to_id)
    char_embedding_dim = parameters['chdim']
    char_out_channels = parameters['cnchl']
    decoder_hidden_units = parameters['dchid']

    model = CNN_CNN_LSTM(word_vocab_size, word_embedding_dim, word_out_channels, char_vocab_size,
                         char_embedding_dim, char_out_channels, decoder_hidden_units,
                         tag_to_id, pretrained = word_embeds)

elif (model_name == 'CNN_CNN_LSTM_MC'):
    print ('CNN_CNN_LSTM_MC')
    word_vocab_size = len(word_to_id)
    word_embedding_dim = parameters['wrdim']
    word_out_channels = parameters['wdchl']
    char_vocab_size = len(char_to_id)
    char_embedding_dim = parameters['chdim']
    char_out_channels = parameters['cnchl']
    decoder_hidden_units = parameters['dchid']

    model = CNN_CNN_LSTM_MC(word_vocab_size, word_embedding_dim, word_out_channels, char_vocab_size, 
                            char_embedding_dim, char_out_channels, decoder_hidden_units,
                            tag_to_id, pretrained = word_embeds)
    
elif (model_name == 'CNN_CNN_LSTM_BB'):
    print ('CNN_CNN_LSTM_BB')
    word_vocab_size = len(word_to_id)
    word_embedding_dim = parameters['wrdim']
    word_out_channels = parameters['wdchl']
    char_vocab_size = len(char_to_id)
    char_embedding_dim = parameters['chdim']
    char_out_channels = parameters['cnchl']
    decoder_hidden_units = parameters['dchid']
    sigma_prior = parameters['sigmp']

    model = CNN_CNN_LSTM_BB(word_vocab_size, word_embedding_dim, word_out_channels, char_vocab_size,
                            char_embedding_dim, char_out_channels, decoder_hidden_units,
                            tag_to_id, sigma_prior = sigma_prior, pretrained = word_embeds)
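
For reference, the abbreviated keys read from `parameters` above map onto the constructor arguments as shown in the sketch below. The key meanings follow directly from the assignments in the snippet; the numeric values are placeholders, not the project's actual defaults.

# Illustrative sketch only: key names and meanings come from the snippet above,
# the values are placeholders rather than the repository's configuration.
parameters = {
    'wrdim': 300,   # word_embedding_dim
    'wdchl': 200,   # word_out_channels (word-level CNN)
    'chdim': 25,    # char_embedding_dim
    'cnchl': 50,    # char_out_channels (character-level CNN)
    'dchid': 200,   # decoder_hidden_units (LSTM decoder)
    'wldim': 200,   # word_hidden_dim (used by the CNN_BiLSTM_CRF* models in Example #2)
    'sigmp': 0.05,  # sigma_prior for the '_BB' variants
}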
Example #2
def make_model(config, mappings, result_path):
    word_to_id = mappings['word_to_id']
    tag_to_id = mappings['tag_to_id']
    char_to_id = mappings['char_to_id']
    word_embeds = mappings['word_embeds']
    if config.opt.reload:
        log.info(
            'Loading Saved Weights....................................................................'
        )
        model_path = os.path.join(result_path, config.opt.usemodel,
                                  config.opt.checkpoint, 'modelweights')
        model = torch.load(model_path)
    else:
        log.info(
            'Building Model............................................................................'
        )
        log.info(config.opt.usemodel)
        word_vocab_size = len(word_to_id)
        char_vocab_size = len(char_to_id)
        word_embedding_dim = config.parameters.wrdim
        char_out_channels = config.parameters.cnchl
        char_embedding_dim = config.parameters.chdim

        if (config.opt.usemodel == 'CNN_BiLSTM_CRF'):
            word_hidden_dim = config.parameters.wldim
            model = CNN_BiLSTM_CRF(word_vocab_size,
                                   word_embedding_dim,
                                   word_hidden_dim,
                                   char_vocab_size,
                                   char_embedding_dim,
                                   char_out_channels,
                                   tag_to_id,
                                   pretrained=word_embeds)

        elif (config.opt.usemodel == 'CNN_BiLSTM_CRF_MC'):
            word_hidden_dim = config.parameters['wldim']
            model = CNN_BiLSTM_CRF_MC(word_vocab_size,
                                      word_embedding_dim,
                                      word_hidden_dim,
                                      char_vocab_size,
                                      char_embedding_dim,
                                      char_out_channels,
                                      tag_to_id,
                                      pretrained=word_embeds)

        elif (config.opt.usemodel == 'CNN_BiLSTM_CRF_BB'):
            word_hidden_dim = config.parameters['wldim']
            sigma_prior = config.parameters['sigmp']
            model = CNN_BiLSTM_CRF_BB(word_vocab_size,
                                      word_embedding_dim,
                                      word_hidden_dim,
                                      char_vocab_size,
                                      char_embedding_dim,
                                      char_out_channels,
                                      tag_to_id,
                                      sigma_prior=sigma_prior,
                                      pretrained=word_embeds)

        elif (config.opt.usemodel == 'CNN_CNN_LSTM'):
            word_out_channels = config.parameters['wdchl']
            decoder_hidden_units = config.parameters['dchid']
            model = CNN_CNN_LSTM(word_vocab_size,
                                 word_embedding_dim,
                                 word_out_channels,
                                 char_vocab_size,
                                 char_embedding_dim,
                                 char_out_channels,
                                 decoder_hidden_units,
                                 tag_to_id,
                                 pretrained=word_embeds)

        elif (config.opt.usemodel == 'CNN_CNN_LSTM_MC'):
            word_out_channels = config.parameters['wdchl']
            decoder_hidden_units = config.parameters['dchid']
            model = CNN_CNN_LSTM_MC(word_vocab_size,
                                    word_embedding_dim,
                                    word_out_channels,
                                    char_vocab_size,
                                    char_embedding_dim,
                                    char_out_channels,
                                    decoder_hidden_units,
                                    tag_to_id,
                                    pretrained=word_embeds)

        elif (config.opt.usemodel == 'CNN_CNN_LSTM_BB'):
            word_out_channels = config.parameters['wdchl']
            decoder_hidden_units = config.parameters['dchid']
            sigma_prior = config.parameters['sigmp']
            model = CNN_CNN_LSTM_BB(word_vocab_size,
                                    word_embedding_dim,
                                    word_out_channels,
                                    char_vocab_size,
                                    char_embedding_dim,
                                    char_out_channels,
                                    decoder_hidden_units,
                                    tag_to_id,
                                    sigma_prior=sigma_prior,
                                    pretrained=word_embeds)
        else:
            raise KeyError('Unknown model: {}'.format(config.opt.usemodel))
    return model
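
The reload branch of `make_model` calls `torch.load` on a checkpoint that was presumably written with `torch.save(model, path)`, i.e. the entire module object is pickled rather than just its parameters. A minimal, self-contained sketch of that round trip follows; the tiny stand-in network and the 'results' directory are illustrative, not part of the project.

import os
import torch
import torch.nn as nn

# Stand-in model; any nn.Module behaves the same way for this purpose.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 5))

result_path = 'results'  # placeholder directory
os.makedirs(result_path, exist_ok=True)
model_path = os.path.join(result_path, 'modelweights')

# What the reload branch expects to find on disk: the whole model object.
torch.save(model, model_path)
# Note: PyTorch >= 2.6 defaults torch.load to weights_only=True; loading a pickled
# module there requires torch.load(model_path, weights_only=False).
reloaded = torch.load(model_path)

# Common alternative: persist only the parameters and rebuild the architecture in code.
torch.save(model.state_dict(), model_path + '.statedict')
model.load_state_dict(torch.load(model_path + '.statedict'))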