Exemple #1
0
 def __init__(self):
     """Build the vocabulary, category list, config and TextCNN model.

     Reads the module-level ``vocab_dir``, ``train_dir`` and ``config``
     (the latter only for the vocabulary size handed to ``build_vocab``).
     The category list is taken from the ``name`` column of
     ``predict_check_data.xls``: unique values, in order of first
     appearance. Its length is passed to ``TCNNConfig``.

     Attributes set: ``word_to_id``, ``table``, ``config``, ``model``,
     ``categories``.
     """
     print('Configuring CNN model...')
     if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
         build_vocab(train_dir, vocab_dir, config.vocab_size)
     # Provisional value only; it is replaced by the spreadsheet-derived
     # list at the end of this method.
     self.categories, _ = read_category()
     words, self.word_to_id = read_vocab(vocab_dir)
     self.table = pd.read_excel('predict_check_data.xls')

     # De-duplicate the names once, keeping first-seen order. This matches
     # sorting the unique set by each name's first index, without the
     # repeated tolist() calls and the quadratic list.index lookups.
     seen = set()
     ordered_categories = []
     for name in self.table['name'].tolist():
         if name not in seen:
             seen.add(name)
             ordered_categories.append(name)

     self.config = TCNNConfig(len(ordered_categories))
     self.config.vocab_size = len(words)
     self.model = TextCNN(self.config)
     self.categories = ordered_categories
Exemple #2
0
    print(y_pred_cls)
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    print(cm)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


# Script entry point: builds/loads the vocabulary once, then trains and tests
# a TextCNN for each batch size in dataNums, each in its own graph/session.
# The names bound here (config, categories, cat_to_id, word_to_id, model) are
# module-level globals; train()/test() are called without arguments.
if __name__ == '__main__':
    # if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']:
    #     raise ValueError("""usage: python run_cnn.py [train / test]""")

    print('Configuring CNN model...')
    config = TCNNConfig()
    if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
        build_vocab(train_dir, vocab_dir, config.vocab_size)
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_dir)
    # Replace the configured vocabulary size with the number of words actually read.
    config.vocab_size = len(words)
    # Batch sizes to sweep over.
    dataNums = [16, 32, 64, 128, 256]
    for i in dataNums:
        # NOTE(review): never true for the list above; only matters if 0 is added.
        if i == 0:
            continue
        # Fresh graph and session per batch size, so each model is built in isolation.
        g1 = tf.Graph()
        sess1 = tf.Session(graph=g1)
        # NOTE(review): sess1 is not closed anywhere in the visible code.
        with sess1.as_default():
            with g1.as_default():
                model = TextCNN(config, batch_size=i)
                train()
                test()
                # NOTE(review): xx / yy1 are not defined in this block; presumably
                # populated by train()/test() as globals. Confirm.
                plt.plot(xx, yy1)
Exemple #3
0
        # Command-line driven setup for a TextRNN run. Positional arguments as
        # read below:
        #   argv[1]  'train' runs train(); any other value runs test()
        #   argv[2]  run index, used only in the checkpoint directory name
        #   argv[3]  task name (sub-directory under the data directory)
        #   argv[4]  data directory name under 'data/'
        #   argv[5]  class labels joined with '-'
        config = TRNNConfig()
        t_name = sys.argv[3]
        t_th = sys.argv[2]
        data_dir = sys.argv[4]
        base_dir = 'data/' + data_dir + '/' + t_name
        classes = sys.argv[5].split('-')


        train_dir = os.path.join(base_dir, 'train.csv')
        test_dir = os.path.join(base_dir, 'test.csv')
        val_dir = os.path.join(base_dir, 'dev.csv')
        # The vocabulary lives under 'data/data_orginal/<t_name>' regardless of data_dir.
        vocab_dir = os.path.join('data/data_orginal/'+t_name, 'vocab.csv')

        if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
            print(' vocab_dir not exists: ',vocab_dir)
            build_vocab('data/data_orginal/'+t_name+'/whole.csv', vocab_dir, config.vocab_size)
        categories, cat_to_id = read_category(classes)
        words, word_to_id = read_vocab(vocab_dir)
        # Sync the config with the vocabulary actually read and the classes given.
        config.vocab_size = len(words)
        config.num_classes = len(classes)

        mode_name = 'textrnn'
        save_dir = 'checkpoints/' + t_name + '/' + mode_name + '_' + t_name + "_" + data_dir + '_' + t_th + 'th'
        print('save_dir:', save_dir)
        save_path = os.path.join(save_dir, 'best_validation')  # path where the best validation result is saved

        model = TextRNN(config)
        if sys.argv[1] == 'train':
            train()
        else:
            test()
Exemple #4
0
                msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                      + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
                print(
                    msg.format(total_batch, loss_train, acc_train, loss_val,
                               acc_val, time_dif, improved_str))

            session.run(train_op, feed_dict=feed_dict)  # 运行优化
            total_batch += 1

            if total_batch - last_improved > require_improvement:
                # 验证集正确率长期不提升,提前结束训练
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break  # 跳出循环
        if flag:  # 同上
            break


# Script entry point: constructs the model, parses file paths from the
# command line via Config().parse, makes sure a vocabulary exists, then trains.
if __name__ == '__main__':
    model = TextCNN()
    config = model.config
    # parse_known_args() also returns the unrecognised arguments; they are discarded here.
    file_config, _ = Config().parse.parse_known_args()
    print('Configuring CNN model...')
    if not os.path.exists(file_config.vocab_path):  # rebuild the vocabulary if it does not exist
        build_vocab(file_config.train_path, file_config.vocab_path,
                    config.vocab_size)
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(file_config.vocab_path)
    # NOTE(review): the model was already constructed above, so this update only
    # helps if TextCNN reads config.vocab_size lazily. Confirm.
    config.vocab_size = len(words)
    train(model, file_config)
Exemple #5
0
    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    print(cm)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


# Script entry point: asks interactively whether to train or test, prepares the
# vocabulary and config, builds a TextCNN and dispatches to train()/test().
if __name__ == '__main__':
    # if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']:
    #     raise ValueError("""usage: python run_cnn.py [train / test]""")
    # Any answer other than exactly 'train' falls through to test() below.
    choice = input("train or test:")
    # if choice=='train':
    #     create_file(data_dir,train_dir,test_dir,val_dir,4000,1000)
    print('Configuring CNN model...')
    config = TCNNConfig()
    if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
        # NOTE(review): uses a literal 100000 rather than config.vocab_size.
        build_vocab(train_dir, vocab_dir, 100000)
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_dir)
    # Replace the configured vocabulary size with the number of words actually read.
    config.vocab_size = len(words)
    # Within this block these two are referenced only by the commented-out CNN(...) call.
    filter_sizes = [3, 4, 5]  # 3
    num_filters = 32
    #model = CNN(config.seq_length,config.num_classes,config.vocab_size,config.embedding_dim,filter_sizes,num_filters,0.0)
    model = TextCNN(config)

    if choice == 'train':
        train()
    else:
        test()
    # 评估
    print("Precision, Recall and F1-Score...")
    print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))

    # 混淆矩阵
    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    print(cm)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


# Script entry point: requires exactly one command-line argument, 'train' or
# 'test', prepares the vocabulary and config, builds a TextRNN and dispatches.
if __name__ == '__main__':
    # Reject anything other than a single 'train' / 'test' argument.
    if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']:
        raise ValueError("""usage: python run_rnn.py [train / test]""")

    print('Configuring RNN model...')
    config = TRNNConfig()
    if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
        build_vocab(train_dir, vocab_dir, config.vocab_size)
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_dir)
    # Replace the configured vocabulary size with the number of words actually read.
    config.vocab_size = len(words)
    model = TextRNN(config)

    if sys.argv[1] == 'train':
        train()
    else:
        test()
        metrics.classification_report(y_test_cls,
                                      y_pred_cls,
                                      target_names=categories))

    # 混淆矩阵
    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    print(cm)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


# Script entry point: requires exactly one command-line argument, 'train' or
# 'test', prepares the vocabulary (at vocab_path) and config, builds a TextRNN
# and dispatches.
if __name__ == '__main__':
    # Reject anything other than a single 'train' / 'test' argument.
    if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']:
        raise ValueError("""usage: python run_rnn.py [train / test]""")

    print('Configuring RNN model...')
    config = TRNNConfig()
    if not os.path.exists(vocab_path):  # rebuild the vocabulary if it does not exist
        build_vocab(train_path, vocab_path, config.vocab_size)
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_path)
    # Replace the configured vocabulary size with the number of words actually read.
    config.vocab_size = len(words)
    model = TextRNN(config)

    if sys.argv[1] == 'train':
        train()
    else:
        test()
Exemple #8
0
# Module-level setup for the word2vec-initialised TextRNN: data/vocabulary
# paths, checkpoint location, vocabulary, pre-trained vectors, model and the
# variable initialiser. Everything here runs at import time.
val_dir = os.path.join(base_dir, 'cnewsval.txt')
vocab_dir = os.path.join(base_dir, 'cnewsvocab.txt')
vector_word_dir= os.path.join(base_dir, 'vector_word.txt')  # word vectors trained by word2vec
vector_word_npz=os.path.join(base_dir, 'vector_word.npz')  # the same vectors saved as a numpy file
# Path where the best validation result is saved.
# NOTE(review): 'HOME' below is literal text in a raw string, not an expanded
# environment variable, and the separators are backslashes. Confirm intent.
save_dir = r'HOME\mydata\lstm\checkpoints'
save_path = os.path.join(save_dir, 'best_validation') 
# Build / load the vocabulary.
# The bare string literal below is an earlier version of this setup kept as
# text; as an expression statement it has no effect.
'''build_vocab(train_dir,vocab_dir)
_,word_to_id=read_vocab(vocab_dir)
categories,cat_to_id=read_category()

config=TRNNConfig()
model=TextRNN(config)'''
config=TRNNConfig()
# Unlike an exists-check, this rebuilds the vocabulary on every run.
build_vocab(train_dir,vocab_dir)
words,word_to_id=read_vocab(vocab_dir)
categories,cat_to_id=read_category()
# Replace the configured vocabulary size with the number of words actually read.
config.vocab_size = len(words)
# Export the word2vec vectors to the .npz cache only when it is missing.
if not os.path.exists(vector_word_npz):
   export_word2vec_vectors(word_to_id, vector_word_dir, vector_word_npz)
# NOTE(review): the attribute is spelled 'pre_trianing'; presumably TextRNN
# reads it under the same spelling, so check before renaming.
config.pre_trianing = get_training_word2vec_vectors(vector_word_npz)
model=TextRNN(config)
init=tf.global_variables_initializer()

def get_time_dif(start_time):
    """Return the time elapsed since ``start_time`` as a ``timedelta``.

    ``start_time`` is a ``time.time()`` timestamp; the elapsed interval is
    rounded to the nearest whole second.
    """
    elapsed_seconds = int(round(time.time() - start_time))
    return timedelta(seconds=elapsed_seconds)
Exemple #9
0

def loaddata(data_dir='C:\\Users\\chenshuai\\Desktop\\py\\dataset\\individual\\',
             test_size=0.2):
    """Load a labelled text dataset from disk and split it into train/test.

    Args:
        data_dir: Directory handed to ``sklearn.datasets.load_files``.
            Defaults to the path this function originally hard-coded.
        test_size: Passed through to ``train_test_split``; defaults to the
            original 0.2.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. Note this order differs from
        the ``x_train, x_test, y_train, y_test`` order that
        ``train_test_split`` itself returns.
    """
    # Imported lazily so scikit-learn is only needed when data is loaded.
    from sklearn.datasets import load_files
    from sklearn.model_selection import train_test_split

    paper = load_files(data_dir, encoding='UTF-8')
    # No random_state is given, so the split is not reproducible across runs.
    x_train, x_test, y_train, y_test = train_test_split(
        paper.data, paper.target, test_size=test_size)
    return x_train, y_train, x_test, y_test

# Script entry point: loads the raw dataset via loaddata(), makes sure a
# vocabulary exists, builds a TextCNN, then trains and tests in sequence
# (the train/test command-line switch is commented out).
if __name__ == '__main__':
    # if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']:
    #     raise ValueError("""usage: python run_cnn.py [train / test]""")

    print('Configuring CNN model...')
    config = TCNNConfig()
    print('Loading raw data...')
    x_train, y_train, x_test, y_test = loaddata()
    if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
        print('Building vocab...')
        # Here the vocabulary is built from the in-memory training texts, not a file path.
        build_vocab(x_train, vocab_dir, config.vocab_size)
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_dir)
    # Replace the configured vocabulary size with the number of words actually read.
    config.vocab_size = len(words)
    model = TextCNN(config)

    # if sys.argv[1] == 'train':
    train(x_train, y_train)
    # else:
    test(x_test, y_test)
def train(model, data):
    """Train ``model`` with periodic validation, checkpointing and early stopping.

    Training and validation sets are loaded from the module-level
    ``train_dir`` / ``val_dir`` via ``process_file``; hyper-parameters come
    from the module-level ``config``. The best checkpoint (highest validation
    accuracy so far) is written to ``save_path``. Training stops early once
    more than ``require_improvement`` batches pass without improvement.

    Args:
        model: Object exposing ``loss``, ``acc``, ``keep_prob`` and ``optim``.
        data: Not referenced anywhere in this function body.
    """
    print("Configuring TensorBoard and Saver...")
    # Configure TensorBoard. Delete the tensorboard folder before retraining,
    # otherwise the plots overlap.
    tensorboard_dir = 'tensorboard/textcnn'
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)

    tf.summary.scalar("loss", model.loss)
    tf.summary.scalar("accuracy", model.acc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)

    # Configure the Saver.
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    print("Loading training and validation data...")
    # Load the training and validation sets.
    start_time = time.time()
    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
    x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # Create the session.
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    writer.add_graph(session.graph)

    print('Training and evaluating...')
    start_time = time.time()
    total_batch = 0  # number of batches processed so far
    best_acc_val = 0.0  # best validation accuracy seen so far
    last_improved = 0  # batch index of the last improvement
    require_improvement = 1000  # stop early after more than 1000 batches without improvement

    flag = False
    for epoch in range(config.num_epochs):
        print('Epoch:', epoch + 1)
        batch_train = batch_iter(x_train, y_train, config.batch_size)
        for x_batch, y_batch in batch_train:
            feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)

            if total_batch % config.save_per_batch == 0:
                # Every save_per_batch batches, write the summaries to TensorBoard.
                s = session.run(merged_summary, feed_dict=feed_dict)
                writer.add_summary(s, total_batch)

            if total_batch % config.print_per_batch == 0:
                # Every print_per_batch batches, report performance on the
                # current training batch and on the validation set.
                # Dropout is disabled (keep_prob = 1.0) for this measurement;
                # note the same feed_dict is reused by the optimisation step below.
                feed_dict[model.keep_prob] = 1.0
                loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)
                loss_val, acc_val = evaluate(session, x_val, y_val)  # todo

                if acc_val > best_acc_val:
                    # Save the best result so far.
                    best_acc_val = acc_val
                    last_improved = total_batch
                    saver.save(sess=session, save_path=save_path)
                    improved_str = '*'
                else:
                    improved_str = ''

                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                      + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
                print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))

            session.run(model.optim, feed_dict=feed_dict)  # run the optimisation step
            total_batch += 1

            if total_batch - last_improved > require_improvement:
                # Validation accuracy has not improved for a long time: stop early.
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break  # leave the batch loop
        if flag:  # leave the epoch loop as well
            break


def test():
    """Restore the saved checkpoint and report metrics on the test set.

    Loads the test set from the module-level ``test_dir``, restores the
    weights stored at ``save_path``, prints overall loss/accuracy, then
    predicts in batches of 128 and prints the per-class report and the
    confusion matrix.
    """
    print("Loading test data...")
    start_time = time.time()
    x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)

    session = tf.Session()
    session.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    # Load the previously saved model.
    saver.restore(sess=session, save_path=save_path)

    print('Testing...')
    loss_test, acc_test = evaluate(session, x_test, y_test)
    print('Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'.format(loss_test, acc_test))

    chunk = 128
    total = len(x_test)
    n_chunks = int((total - 1) / chunk) + 1

    y_test_cls = np.argmax(y_test, 1)
    # Holds the predicted class index for every test sample.
    y_pred_cls = np.zeros(shape=total, dtype=np.int32)
    # Predict one batch at a time.
    for batch_no in range(n_chunks):
        lo = batch_no * chunk
        hi = min(lo + chunk, total)
        y_pred_cls[lo:hi] = session.run(
            model.y_pred_cls,
            feed_dict={model.input_x: x_test[lo:hi], model.keep_prob: 1.0},
        )

    # Evaluation report.
    print("Precision, Recall and F1-Score...")
    print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))

    # Confusion matrix.
    print("Confusion Matrix...")
    print(metrics.confusion_matrix(y_test_cls, y_pred_cls))

    print("Time usage:", get_time_dif(start_time))


# Script entry point: prepares the vocabulary and config, builds a TextCNN and
# always trains (the train/test command-line switch is commented out).
if __name__ == '__main__':
    #if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']:
    #    raise ValueError("""usage: python run_cnn.py [train / test]""")

    print('Configuring CNN model...')
    config = TCNNConfig()
    if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
        build_vocab(train_dir, vocab_dir, config.vocab_size)
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_dir)
    # Replace the configured vocabulary size with the number of words actually read.
    config.vocab_size = len(words)
    model = TextCNN(config)

    #if sys.argv[1] == 'train':
    #    train()
    #else:
    #    test()
    # NOTE(review): the train definition visible earlier in this file is declared
    # as train(model, data); confirm which signature is in scope for this
    # zero-argument call.
    train()