Example #1
    def __init__(self):
        self.config = TCNNConfig()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.config.vocab_size = len(self.words)
        # Load the pre-trained embedding matrix (one row per vocabulary word).
        self.config.pre_training = pd.read_csv(word_vector_dir, header=None, index_col=None).values
        self.model = TextCNN(self.config)
        session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        self.session = tf.Session(config=session_conf)
        self.session.run(tf.global_variables_initializer())  # values are overwritten by restore() below

        saver = tf.train.Saver()
        saver.restore(sess=self.session, save_path=save_path)  # restore the saved model weights
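
The saver/restore pattern above is the standard TF 1.x checkpoint mechanism. A minimal, self-contained sketch of the same round trip (the path /tmp/demo/model is illustrative):

import tensorflow as tf

# Build a trivial graph with a single variable.
v = tf.get_variable("v", shape=[2], initializer=tf.zeros_initializer())
assign_op = v.assign([1.0, 2.0])
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(assign_op)
    saver.save(sess, "/tmp/demo/model")  # writes the checkpoint files

tf.reset_default_graph()
v = tf.get_variable("v", shape=[2])  # must match the saved variable's name and shape
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "/tmp/demo/model")  # loads the values; no initializer needed
    print(sess.run(v))  # -> [1. 2.]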
Example #2
import os

import tensorflow as tf

# Project-local helpers assumed by this snippet: RNNConfig, TextRNN, read_vocab,
# process_file, get_training_word2vec_vectors, get_sequence_length.


def predict(sentences):
    config = RNNConfig()
    config.pre_trianing = get_training_word2vec_vectors(vector_word_npz)  # attribute name as spelled in RNNConfig
    model = TextRNN(config)
    save_dir = './checkpoints/textrnn'
    save_path = os.path.join(save_dir, 'best_validation')
    _, word_to_id = read_vocab(vocab_filename)
    input_x = process_file(sentences, word_to_id, max_length=config.seq_length)
    # Mapping from class id to Chinese news-category name.
    labels = {
        0: '娱乐',
        1: '游戏',
        2: '音乐',
        3: '星座运势',
        4: '体育',
        5: '食物',
        6: '时尚',
        7: '社会万象',
        8: '汽车',
        9: '农业',
        10: '母婴育儿',
        11: '科技',
        12: '军事',
        13: '教育',
        14: '健康养生',
        15: '国际视野',
        16: '搞笑',
        17: '动漫',
        18: '宠物',
        19: '财经',
        20: '历史',
        21: '家居',
        22: '房产'
    }

    feed_dict = {
        model.input_x: input_x,
        model.keep_prob: 1.0,  # disable dropout at inference time
        model.sequence_lengths: get_sequence_length(input_x)
    }
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess=session, save_path=save_path)
    y_prob = session.run(tf.nn.softmax(model.logits), feed_dict=feed_dict)
    y_prob = y_prob.tolist()
    cat = []
    for prob in y_prob:
        best = prob.index(max(prob))  # index of the highest-probability class
        cat.append(labels[best])
    session.close()
    tf.reset_default_graph()
    return cat
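
A hypothetical call site for predict (the sentence and the expected category are illustrative):

if __name__ == '__main__':
    demo = ['梅西帽子戏法带领球队获胜']
    print(predict(demo))  # e.g. ['体育']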
Example #3
import os
import sys
import time
import tensorflow as tf
import data_helper as DH
import math
"""
提供对模型的预测
"""
# Parameters
# ==================================================

#DH.Get_Save_CategoryFromOriginData()
categories, cat_to_id = DH.Get_Categories()
id_to_cat = DH.Get_Id_To_Cat()
words, word_to_id = DH.read_vocab(DH.TextConfig.vocab_filename)
vocab_size = len(words)

num_classes = len(categories)

pad_seq_len = DH.TextConfig.seq_length

logger = DH.logger_fn(
    'tflog', 'logs/predict-{0}.log'.format(
        time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))))

# Data Parameters

tf.flags.DEFINE_string("predict_data_file", DH.TextConfig.predict_File,
                       "Data source for the test data")
tf.flags.DEFINE_string("checkpoint_dir", "./runs/checkpoints",
Example #4
    print("Precision, Recall and F1-Score...")
    print(
        metrics.classification_report(y_test_cls,
                                      y_pred_cls,
                                      target_names=categories))

    # Confusion matrix
    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    print(cm)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


if __name__ == '__main__':
    config = RNNConfig()  # load the configuration parameters
    if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist yet
        build_vocab(train_dir, vocab_dir, config.vocab_size)
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_dir)
    config.vocab_size = len(words)
    if not os.path.exists(vector_word_npz):
        export_word2vec_vectors(word_to_id, words_embeding, vector_word_npz)
    config.pre_trianing = get_training_word2vec_vectors(vector_word_npz)  # attribute name as spelled in RNNConfig
    model = TextRNN(config)
    option = ''  # hardcoded: set to 'train' to retrain; anything else runs evaluation
    if option == 'train':
        train()
    else:
        test()
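
get_time_dif is not shown in this excerpt; a typical implementation, matching the helper used in the common text-classification templates (an assumption):

import time
from datetime import timedelta


def get_time_dif(start_time):
    """Elapsed wall-clock time since start_time, rounded to whole seconds."""
    end_time = time.time()
    return timedelta(seconds=int(round(end_time - start_time)))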
Example #5
    # Accuracy: fraction of predictions r that match the gold label l[1].
    for r, l in zip(res, labels):
        if r == l[1]:
            correct += 1
        num += 1
    return correct / num


# ===================== Data preprocessing =====================

# Load the datasets; each loader returns (sentences, labels).
print("Loading data...")
sents, labels = data_helper.load_data('./data/train.txt')
test_sents, test_labels = data_helper.load_data('./data/test.txt')

max_len = 1024  # pad/truncate every sentence to this length
vocab = data_helper.read_vocab()
data = data_helper.sent2idx(sents, vocab, max_len)  # word -> id sequences
test_data = data_helper.sent2idx(test_sents, vocab, max_len)
epoch = 100


with tf.Graph().as_default():
    session_conf = tf.ConfigProto()
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        cnn = TextCNN(vocab_size=len(vocab),
                      seq_len=max_len,
                      embedding_size=256,
                      num_classes=2,
                      filter_sizes=[3, 4, 5],
                      num_filters=256)
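
Example #5 ends right after the graph is built. A hypothetical continuation of the session block (it assumes TextCNN exposes input_x, input_y, dropout_keep_prob and loss, as in the widely used TextCNN template, and that labels is already in the one-hot form input_y expects):

        # Hypothetical continuation inside the `with sess.as_default():` block.
        train_op = tf.train.AdamOptimizer(1e-3).minimize(cnn.loss)
        sess.run(tf.global_variables_initializer())
        for i in range(epoch):
            feed = {cnn.input_x: data,
                    cnn.input_y: labels,
                    cnn.dropout_keep_prob: 0.5}  # dropout active during training
            _, loss = sess.run([train_op, cnn.loss], feed_dict=feed)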