Beispiel #1
0
if __name__ == '__main__':
    # Dataset selection. Other datasets this script has been pointed at
    # (substitute one for the active value to switch):
    #   'THUCNews', 'TOUTIAONews', 'weibo_senti_100k',
    #   'simplifyweibo_4_moods', 'Chinese_conversation_sentiment-master',
    #   'NLPCC2017'
    dataset = 'testtt'

    # Load the `models.<model_name>` module chosen on the command line.
    # NOTE(review): `args` is not defined in this chunk; presumably an
    # argparse namespace created elsewhere in the file — confirm.
    model_name = args.model
    x = import_module('models.' + model_name)
    config = x.Config(dataset)

    # Seed NumPy, the torch CPU RNG and all CUDA RNGs, and request
    # deterministic cuDNN kernels so that repeated runs give the same results.
    np.random.seed(2)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(4)
    torch.backends.cudnn.deterministic = True

    # Build the train/dev/test splits and one iterator per split, timing
    # the whole data-preparation phase.
    start_time = time.time()
    print('加载数据集')
    train_data, dev_data, test_data = utils.bulid_dataset(config)
    train_iter, dev_iter, test_iter = [
        utils.bulid_iterator(split, config)
        for split in (train_data, dev_data, test_data)
    ]
    time_dif = utils.get_time_dif(start_time)
    print("模型开始之前,准备数据时间:", time_dif)

    # Move the model to the configured device, then train, evaluate and test it.
    model = x.Model(config).to(config.device)
    train.train(config, model, train_iter, dev_iter, test_iter)
    # Test-only alternative: train.test(config, model, test_iter)
import numpy as np
import pandas as pd
from utils import bulid_dataset
import matplotlib.pyplot as plt
from keras.models import Model, Input, load_model
from keras.callbacks import ModelCheckpoint
from keras.layers import LSTM, Embedding, Dense, TimeDistributed, Dropout, Bidirectional
plt.style.use("ggplot")

# Step 1: input/output data locations, resolved against the current
# working directory (not against this file's location).
ner_dataset_dir = '../data/ner_dataset.csv'
dataset_dir = '../data/dataset.pkl'

# Step 2: build the dataset. `bulid_dataset` returns nine values. Judging
# by the names (TODO confirm against utils.bulid_dataset), these are the
# vocabulary size, the tag count, the padded sequence length, the word and
# tag lookups, and the train/test feature and label splits.
(
    n_words, n_tags, max_len, words, tags,
    X_train, X_test, y_train, y_test,
) = bulid_dataset(ner_dataset_dir, dataset_dir, max_len=50)


# 3 构建和训练模型
def train():
    input = Input(shape=(max_len, ))
    model = Embedding(input_dim=n_words, output_dim=50,
                      input_length=max_len)(input)
    model = Dropout(0.1)(model)
    model = Bidirectional(
        LSTM(units=100, return_sequences=True, recurrent_dropout=0.1))(model)
    out = TimeDistributed(Dense(n_tags, activation='softmax'))(
        model)  # softmax output layer

    model = Model(input, out)
    model.compile(optimizer='rmsprop',