예제 #1
0
def load_data(vocab_dir):
    """Load the vocabulary and category lookup tables needed to build input matrices.

    :param vocab_dir: path of the vocabulary file handed to ``read_vocab``
    :return: tuple ``(words, word_to_id, categories, cat_to_id)``
    """
    word_list, word_index = read_vocab(vocab_dir)
    label_list, label_index = read_category()
    return word_list, word_index, label_list, label_index
예제 #2
0
    def __init__(self):
        """Build the TextCNN graph and restore its trained weights.

        Reads the category and vocabulary tables, then restores variables from
        the checkpoint at ``save_path``. ``vocab_dir`` and ``save_path`` are
        module-level names defined elsewhere in the file.
        """
        config = TCNNConfig()
        self.config = config
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        # Presumably sizes the model's embedding to the vocabulary read from disk.
        config.vocab_size = len(self.words)
        self.model = TextCNN(config)

        session = tf.Session()
        self.session = session
        session.run(tf.global_variables_initializer())
        # Replace the freshly initialised variables with the saved model.
        tf.train.Saver().restore(sess=session, save_path=save_path)
예제 #3
0
    def __init__(self):
        """Build the TextCNN graph with pre-trained embeddings and restore a checkpoint.

        ``vocab_dir``, ``pre_training`` and ``save_dir`` are module-level names
        defined elsewhere in the file.

        :raises SystemExit: with a non-zero status and a message when
            ``load_checkpoint`` reports failure.
        """
        import sys

        self.config = TCNNConfig()
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.config.vocab_size = len(self.words)
        # Array loaded with np.load from the ``pre_training`` path; presumably
        # the pre-trained embedding matrix -- TODO confirm against TextCNN.
        self.config.pre_training = np.load(pre_training)
        self.model = TextCNN(self.config)

        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())
        if not load_checkpoint(save_dir, self.session):
            # The bare ``exit()`` helper is intended for the interactive shell,
            # and without an argument it exits with status 0, making a missing
            # model look like success. sys.exit(msg) still raises SystemExit,
            # but reports the problem and exits with status 1.
            sys.exit('Failed to load model checkpoint from {}'.format(save_dir))
예제 #4
0
  def load_model(self, vocab_dir='cnews/cnews.vocab.txt'):
    """Build the TextCNN graph, restore its latest checkpoint and load lookup tables.

    Sets ``self.words``, ``self.word_to_id``, ``self.categories``,
    ``self.cat_to_id``, ``self.cnn_model`` and ``self.sess``.

    :param vocab_dir: vocabulary file passed to ``read_vocab``. The default is
        the path this method previously hard-coded, so existing callers are
        unaffected.
    :raises ValueError: if ``self.model_dir`` contains no checkpoint.
    """
    sess = tf.Session()
    print('Configuring CNN model...')
    # Read the vocabulary before building the graph so that config.vocab_size
    # matches the vocabulary on disk, as the other TextCNN/TextRNN setup code
    # does. Otherwise TCNNConfig's default is used, which may not match the
    # variable shapes stored in the checkpoint.
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_dir)
    config = TCNNConfig()
    config.vocab_size = len(words)
    cnn_model = TextCNN(config)

    saver = tf.train.Saver()
    params_file = tf.train.latest_checkpoint(self.model_dir)
    if params_file is None:
      # latest_checkpoint returns None when no checkpoint is found. Fail here
      # with a message naming the directory rather than inside restore().
      sess.close()
      raise ValueError('No checkpoint found in {!r}'.format(self.model_dir))
    saver.restore(sess, params_file)

    self.words = words
    self.word_to_id = word_to_id
    self.categories = categories
    self.cat_to_id = cat_to_id

    self.cnn_model = cnn_model
    self.sess = sess
    # Debug output of the loaded objects.
    print(self.cnn_model)
    print(self.sess)
예제 #5
0
    # 评估
    print("Precision, Recall and F1-Score...")
    print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))

    # 混淆矩阵
    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    print(cm)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


if __name__ == '__main__':
    mode = sys.argv[1] if len(sys.argv) == 2 else None
    # Exactly one argument, naming the action to run, is required.
    if mode not in ('train', 'test'):
        raise ValueError("""usage: python run_rnn.py [train / test]""")

    print('Configuring RNN model...')
    config = TRNNConfig()
    # Rebuild the vocabulary file from the training data when it is missing.
    if not os.path.exists(vocab_dir):
        build_vocab(train_dir, vocab_dir, config.vocab_size)
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_dir)
    # Use the size of the vocabulary actually read from disk.
    config.vocab_size = len(words)
    model = TextRNN(config)

    (train if mode == 'train' else test)()
예제 #6
0
 def __init__(self, model_path='model_params.pkl'):
     """Load the lookup tables and a TextRNN whose weights come from ``model_path``.

     ``vocab_dir`` is a module-level name defined elsewhere in the file.

     :param model_path: file holding the saved ``state_dict``. The default is
         the path previously hard-coded, so existing callers are unaffected.
     """
     self.categories, self.cat_to_id = read_category()
     self.words, self.word_to_id = read_vocab(vocab_dir)
     self.model = TextRNN()
     # map_location='cpu' lets a checkpoint saved from CUDA tensors be
     # deserialised on a CPU-only machine. load_state_dict copies the values
     # into the model's existing parameters, so the loaded weights are the same.
     # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
     # NOTE(review): the model is left in training mode; if it is only used for
     # inference, consider self.model.eval() -- TODO confirm with the caller.
     state_dict = torch.load(model_path, map_location='cpu')
     self.model.load_state_dict(state_dict)
예제 #7
0
import numpy as np
import torch
import torch.utils.data as Data

from cnews_loader import read_category, read_vocab, process_file
from model import TextRNN

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Data file locations
vocab_file = 'cnews.vocab.txt'
train_file = 'cnews.train1.txt'
test_file = 'cnews.test.txt'
val_file = 'cnews.val.txt'
# Length argument passed to process_file; presumably the padded/truncated
# sequence length of each document -- TODO confirm in cnews_loader.
SEQ_LENGTH = 600

# Categories of the texts and their category -> id mapping
categories, cat_to_id = read_category()
# Every character seen in the training text and its id
words, word_to_id = read_vocab(vocab_file)
# Number of distinct characters
vocab_size = len(words)

# Data loading and batching:
# character ids of each document plus the one-hot form of its label
x_train, y_train = process_file(train_file, word_to_id, cat_to_id, SEQ_LENGTH)
x_val, y_val = process_file(val_file, word_to_id, cat_to_id, SEQ_LENGTH)

# GPU device handle (kept as a module-level name for any later code using it)
cuda = torch.device('cuda')
x_train, y_train = torch.LongTensor(x_train), torch.Tensor(y_train)