def load_data(vocab_dir):
    """Collect the vocabulary and category lookup tables.

    These are the inputs needed before text can be turned into id matrices.

    :param vocab_dir: path of the vocabulary file
    :return: words, word_to_id, categories, cat_to_id
    """
    vocab_words, vocab_index = read_vocab(vocab_dir)
    label_names, label_index = read_category()
    return vocab_words, vocab_index, label_names, label_index
def __init__(self):
    """Build the TextCNN graph and restore trained weights into a new session.

    Loads the category and vocabulary lookups, sizes the model vocabulary
    from the file at ``vocab_dir``, then restores parameters from
    ``save_path``.
    """
    self.config = TCNNConfig()
    self.categories, self.cat_to_id = read_category()
    self.words, self.word_to_id = read_vocab(vocab_dir)
    # The embedding size must match the vocabulary actually read from disk.
    self.config.vocab_size = len(self.words)
    self.model = TextCNN(self.config)

    self.session = tf.Session()
    init_op = tf.global_variables_initializer()
    self.session.run(init_op)
    restorer = tf.train.Saver()
    # Overwrite the freshly initialised variables with the saved model.
    restorer.restore(sess=self.session, save_path=save_path)
def __init__(self):
    """Build the TextCNN graph with the ``pre_training`` array and restore weights.

    Loads the category and vocabulary lookups and sizes the model vocabulary
    from the file at ``vocab_dir``. It then loads the array stored at
    ``pre_training`` into ``config.pre_training``; presumably these are
    pre-trained embeddings (TODO confirm). Finally it initialises all
    variables and restores them via ``load_checkpoint(save_dir, session)``,
    which is expected to return a truthy value on success.

    If the checkpoint cannot be loaded, the session is closed and the
    process exits with a non-zero status and a message on stderr.
    """
    import sys

    self.config = TCNNConfig()
    self.categories, self.cat_to_id = read_category()
    self.words, self.word_to_id = read_vocab(vocab_dir)
    # The embedding size must match the vocabulary actually read from disk.
    self.config.vocab_size = len(self.words)
    self.config.pre_training = np.load(pre_training)
    self.model = TextCNN(self.config)
    self.session = tf.Session()
    self.session.run(tf.global_variables_initializer())
    if not load_checkpoint(save_dir, self.session):
        # The ``exit()`` builtin is added by ``site`` for interactive use, may
        # be absent (e.g. ``python -S``, frozen apps), and exits with status 0.
        # ``sys.exit(msg)`` is always available and reports the failure.
        self.session.close()
        sys.exit('Failed to load checkpoint from {}'.format(save_dir))
def load_model(self, vocab_dir='cnews/cnews.vocab.txt'):
    """Build the TextCNN graph, restore the latest checkpoint and load lookups.

    Reads ``self.model_dir``. Sets ``self.words``, ``self.word_to_id``,
    ``self.categories``, ``self.cat_to_id``, ``self.cnn_model`` and
    ``self.sess``.

    :param vocab_dir: path of the vocabulary file (defaults to the path that
        was previously hard-coded here)
    :raises ValueError: if ``self.model_dir`` contains no checkpoint
    """
    print('Configuring CNN model...')
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_dir)

    config = TCNNConfig()
    # Size the model to the vocabulary on disk, as the other loaders here do.
    # Building with the config default can make the restore below fail on an
    # embedding shape mismatch.
    config.vocab_size = len(words)
    cnn_model = TextCNN(config)

    # latest_checkpoint returns None when no checkpoint exists; fail clearly.
    params_file = tf.train.latest_checkpoint(self.model_dir)
    if params_file is None:
        raise ValueError(
            'No checkpoint found in {!r}'.format(self.model_dir))

    sess = tf.Session()
    saver = tf.train.Saver()
    try:
        saver.restore(sess, params_file)
    except Exception:
        # Do not leak the session if the checkpoint cannot be restored.
        sess.close()
        raise

    self.words = words
    self.word_to_id = word_to_id
    self.categories = categories
    self.cat_to_id = cat_to_id
    self.cnn_model = cnn_model
    self.sess = sess
    print(self.cnn_model)
    print(self.sess)
# 评估 print("Precision, Recall and F1-Score...") print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories)) # 混淆矩阵 print("Confusion Matrix...") cm = metrics.confusion_matrix(y_test_cls, y_pred_cls) print(cm) time_dif = get_time_dif(start_time) print("Time usage:", time_dif) if __name__ == '__main__': if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']: raise ValueError("""usage: python run_rnn.py [train / test]""") print('Configuring RNN model...') config = TRNNConfig() if not os.path.exists(vocab_dir): # 如果不存在词汇表,重建 build_vocab(train_dir, vocab_dir, config.vocab_size) categories, cat_to_id = read_category() words, word_to_id = read_vocab(vocab_dir) config.vocab_size = len(words) model = TextRNN(config) if sys.argv[1] == 'train': train() else: test()
def __init__(self, model_path='model_params.pkl'):
    """Load the lookups and a TextRNN whose parameters are restored from disk.

    :param model_path: state-dict file to load (defaults to the path that was
        previously hard-coded here)
    """
    self.categories, self.cat_to_id = read_category()
    self.words, self.word_to_id = read_vocab(vocab_dir)
    self.model = TextRNN()
    # torch.load unpickles the file: only load trusted checkpoints.
    # map_location='cpu' lets a state dict saved from a GPU load on a
    # CPU-only machine. The model is not moved to another device here, so
    # load_state_dict copies the tensors into its parameters wherever they
    # live, and the result matches loading without map_location.
    state_dict = torch.load(model_path, map_location='cpu')
    self.model.load_state_dict(state_dict)
    # NOTE(review): if this object is used for inference and TextRNN has
    # dropout/batch-norm layers, self.model.eval() is probably wanted — confirm.
# Prepare the cnews data for training the TextRNN classifier.
import numpy as np
# ``import torch.utils.data as Data`` binds only ``Data``; ``torch`` itself
# must be imported for the ``torch.*`` calls below.
import torch
import torch.utils.data as Data

from cnews_loader import read_category, read_vocab, process_file
from model import TextRNN

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data file locations.
vocab_file = 'cnews.vocab.txt'
train_file = 'cnews.train1.txt'
test_file = 'cnews.test.txt'
val_file = 'cnews.val.txt'
# Sequence length passed to process_file (presumably the pad/truncate
# length per document — confirm against cnews_loader).
SEQ_LENGTH = 600

# Text categories and the dict mapping each category to its id.
categories, cat_to_id = read_category()
# Every character seen in the training text, and the dict mapping it to an id.
words, word_to_id = read_vocab(vocab_file)
# Vocabulary size.
vocab_size = len(words)

# Data loading and batching.
# Ids for every character of the training data and the one-hot form of
# each label.
x_train, y_train = process_file(train_file, word_to_id, cat_to_id, SEQ_LENGTH)
x_val, y_val = process_file(val_file, word_to_id, cat_to_id, SEQ_LENGTH)

# Set the GPU.
# NOTE(review): this is hard-coded to CUDA, unlike ``device`` above which
# falls back to CPU. Check later uses of ``cuda`` before running on a
# CPU-only machine.
cuda = torch.device('cuda')
x_train, y_train = torch.LongTensor(x_train), torch.Tensor(y_train)