def test(test_path):
    """Evaluate the saved cross-validation model on a labelled text file.

    Every line of ``test_path`` is expected to look like
    ``<label>|text|<text>``.  The raw label is translated through the
    module-level ``classify_dict`` and ``yiji_label`` tables before the
    sample is handed to the classifier.  The model name is derived from the
    module-level ``fold``.

    Prints overall accuracy plus macro-averaged precision / recall / F-score,
    then collects per-class precision and recall into dictionaries keyed by
    class name.
    """
    # The original also carried a disabled ``custom_tokenize=segment`` option.
    model = Grocery('cv_' + str(fold) + '_model')
    model.load()

    samples = []
    with open(test_path) as handle:
        for raw in handle:
            raw_label, body = raw.strip().split("|text|")
            samples.append((yiji_label[classify_dict[raw_label]], body))

    test_result = model.test(samples)
    recalls = test_result.recall_labels  # rebound further down
    predictlabels = test_result.predicted_y
    truelabels = test_result.true_y

    acc = accuracy_score(truelabels, predictlabels)
    macro_scores = precision_recall_fscore_support(
        truelabels, predictlabels, average='macro')
    macro_precision, macro_recall, macro_fscore, _ = macro_scores

    # One space-joined line, exactly what the multi-item print produced.
    summary = (
        "Accuracy: ", acc,
        "Macro-average Precision:", macro_precision,
        "Macro-average Recall:", macro_recall,
        "Macro-average Fscore:", macro_fscore,
    )
    print(" ".join(str(part) for part in summary))

    labellist = [
        'safe_and_stable',
        'industrial_information',
        'politics',
        'culture_health',
        'social_livelihood',
        'economic_and_financial',
    ]
    precision, recall, fscore, _ = precision_recall_fscore_support(
        truelabels, predictlabels, average=None, labels=labellist)

    # Scores come back in ``labellist`` order, so pair them up positionally.
    precisions = dict(zip(labellist, precision))
    recalls = dict(zip(labellist, recall))
class GroceryModel(object):
    """Small wrapper around a ``Grocery`` classifier named ``'TextClassify'``.

    Training and test files are UTF-8 text with one sample per line in the
    form ``<label>\\t<text>``.
    """

    def __init__(self):
        self.grocery = Grocery('TextClassify')

    @staticmethod
    def _read_dataset(path):
        """Return the ``(label, text)`` pairs stored in the file at ``path``.

        The first tab-separated field is the label; all remaining fields are
        concatenated (without the tabs) to form the text.  The trailing
        newline of each line is kept as part of the text, as before.

        Raises:
            IOError/OSError: if the file cannot be opened.
            UnicodeDecodeError: if the file is not valid UTF-8.
        """
        import io  # local import so the block does not depend on file-level imports

        dataset = []
        # ``newline='\n'`` keeps line endings untranslated; the context
        # manager closes the file even when decoding or splitting fails.
        with io.open(path, 'r', encoding='utf8', newline='\n') as handle:
            for line in handle:
                fields = line.split('\t')
                dataset.append((fields[0], ''.join(fields[1:])))
        return dataset

    def train(self, train_file):
        """Train on the samples in ``train_file`` and save the model."""
        self.grocery.train(self._read_dataset(train_file))
        self.grocery.save()

    def load_model(self):
        """Load the previously saved model."""
        self.grocery.load()

    def test(self, test_src):
        """Load the saved model, evaluate it on ``test_src`` and print the result."""
        self.load_model()
        result = self.grocery.test(self._read_dataset(test_src))
        print(result)

    def predict(self, text):
        """Print the prediction for a single piece of text."""
        print(self.grocery.predict(text))
def test_grocery():
    """Train, save, reload and evaluate the ``model_redian`` classifier.

    Trains on ``trdata_4.txt``, saves the model, loads it into a second
    ``Grocery`` instance, evaluates on ``tedata_4.txt`` and prints the
    per-label accuracy, the per-label recall and the result table.
    """
    model_name = 'model_redian'

    trainer = Grocery(model_name)
    trainer.train('trdata_4.txt')
    trainer.save()

    # A separate instance proves the model round-trips through disk.
    evaluator = Grocery(model_name)
    evaluator.load()
    report = evaluator.test('tedata_4.txt')

    print(report.accuracy_labels)
    print(report.recall_labels)
    report.show_result()
#!/usr/bin/env python
# coding=utf-8
# Load the saved "age" model and evaluate it on the "test4_age" file, whose
# label and text columns are separated by a single space.
from tgrocery import Grocery

# Training step, disabled in the original:
#grocery = Grocery('age56')
#grocery.train('train4_age_56', ' ')
#grocery.save()

new_grocery = Grocery("age")
new_grocery.load()
predict_result = new_grocery.test('test4_age', delimiter=' ')

# Per-sample dump, disabled in the original:
#print len(predict_result.true_y)
#for i in range(len(predict_result.predicted_y)):
    #print predict_result.predicted_y[i]

print(predict_result)
predict_result.show_result()
# Close the training-data files opened earlier in the script.
word_train.close()
pinyin_train.close()
kmer_train.close()

# Train on the comma-delimited k-mer file and report elapsed CPU time
# (``start`` is set earlier in the script).
grocery = Grocery('sample')
train_src = r'E:\classify\plan2\train_kmer.txt'
grocery.train(train_src, delimiter=',')
print('Training finished! Time consumption:')
mid = time.process_time()
print(str(mid - start))

# Round-trip the model through disk, then evaluate it.
grocery.save()
grocery.load()
test_src = r'E:\classify\plan2\test_kmer.txt'
print('Classification accuracy:')
print(grocery.test(test_src, delimiter=','))

# Route every token line to the "words" or "pinyin" output file according to
# the predicted class of its 2-mers.  The context manager closes all three
# files even if prediction fails part-way through.
with open(r'E:\classify\tokens.txt', mode='r', encoding='utf-8') as classifile, \
        open(r'E:\classify\pinyin_grocery.txt', mode='w', encoding='utf-8') as pinyin, \
        open(r'E:\classify\words_grocery.txt', mode='w', encoding='utf-8') as words:
    for line in classifile:
        # Predict once per line; the original ran the classifier twice.
        predicted = grocery.predict(getkmer(line, 2))
        if predicted == 'word':
            words.write(line)
        elif predicted == 'pinyin':
            pinyin.write(line)
class TagPredictor(object):
    """Train, persist, evaluate and apply a ``Grocery`` tag classifier.

    Constructor keyword arguments (all required, read from ``kwargs``):
        grocery_name: name handed to ``Grocery``; also part of the model
            file name ``<grocery_name>_train.svm``.
        method: ``"normal"`` (tokenize with ``keyExt``), ``"jieba"``
            (``Grocery``'s default tokenizer) or ``"processed"`` (documents
            are already comma-separated tokens).
        train_src: training source passed to ``Grocery.train``; the
            evaluation methods also open it as a ``label<TAB>document`` file.

    ``PREFIX`` and ``MODEL_DIR`` come from ``conf.load("predict_label")``.
    """

    @staticmethod
    def _emit(*parts):
        """Print ``parts`` on one line separated by single spaces."""
        print(" ".join(str(part) for part in parts))

    def _custom_tokenize(self, line, **kwargs):
        """Split ``line`` into tokens.

        A ``method`` keyword given on the call takes precedence over the one
        passed to the constructor.

        Raises:
            ValueError: if the method is neither ``"normal"`` nor
                ``"processed"`` (previously an accidental UnboundLocalError).
        """
        if "method" in kwargs:
            method = str(kwargs["method"])
        else:
            method = str(self.kwargs["method"])

        if method == "normal":
            return self.key_ext.calculateTokens(line,
                                                doc_len_lower_bound=5,
                                                doc_len_upper_bound=500,
                                                method="normal")
        if method == "processed":
            return line.split(',')
        raise ValueError("unsupported tokenize method: %r" % (method,))

    def __init__(self, *args, **kwargs):
        self.grocery_name = str(kwargs["grocery_name"])
        method = str(kwargs["method"])
        # Kept for its validation effect: a missing "train_src" fails here.
        str(kwargs["train_src"])
        settings = conf.load("predict_label")
        self.PREFIX = settings["prefix"]
        self.MODEL_DIR = settings["model_dir"]
        self.kwargs = kwargs

        if method == "normal":
            self.key_ext = keyExt()
            self.grocery = Grocery(self.grocery_name,
                                   custom_tokenize=self._custom_tokenize)
        elif method == "jieba":
            self.grocery = Grocery(self.grocery_name)
        elif method == "processed":
            self.grocery = Grocery(self.grocery_name,
                                   custom_tokenize=self._custom_tokenize)
        # NOTE(review): any other method leaves ``self.grocery`` unset, as in
        # the original; later calls then fail with AttributeError.

    @staticmethod
    def _read_labelled_docs(path):
        """Return ``(label, document)`` pairs parsed from ``path``.

        Each line is split on its first tab; lines without a tab are skipped.
        The document is cut at its first newline.  The file is read in binary
        mode, as before.
        """
        docs = []
        with open(path, 'rb') as f:
            for line in f:
                fields = line.split('\t', 1)
                if len(fields) < 2:
                    continue
                docs.append((fields[0], fields[1].split('\n', 1)[0]))
        return docs

    def trainFromDocs(self, *args, **kwargs):
        """Train on the constructor's ``train_src`` and return the result."""
        model = self.grocery.train(self.kwargs["train_src"])
        return model

    def autoEvaluation(self, *args, **kwargs):
        """Train, measure, prune low-recall labels, then train and measure again.

        Keyword arguments:
            threshold: labels whose recall is below this value are pruned.
            excluded_labels: labels to drop before training.
            excluded_docs: documents to drop before training.

        Returns:
            ``(grocery, test_result)`` for the pruned model.
        """
        prune_threshold = float(kwargs["threshold"])
        excluded_labels = kwargs["excluded_labels"]
        excluded_docs = kwargs["excluded_docs"]

        train_data = self._read_labelled_docs(self.kwargs["train_src"])
        self._emit("#items before filtering:", len(train_data))
        self._emit("-- Now we filter out the excluded docs --")
        train_data = [i for i in train_data if i[1] not in excluded_docs]
        self._emit("#items after filtering:", len(train_data))
        self._emit("-- Now we filter out the excluded labels --")
        train_data = [i for i in train_data if i[0] not in excluded_labels]
        self._emit("#items after filtering:", len(train_data))

        n = len(train_data)  # number of rows in the dataset
        indices = shuffle(list(range(n)))
        # NOTE(review): both slices take 10/10 of the shuffled rows, so the
        # model is evaluated on the same rows it was trained on.
        train_set = [train_data[x] for x in indices[:n * 10 // 10]]
        test_set = [train_data[x] for x in indices[:n * 10 // 10]]

        self.grocery.train(train_set)
        test_result = self.grocery.test(test_set)
        self._emit('-- Accuracy after training --')
        self._emit('Accuracy, A-0:', test_result)

        low_recall_label = [
            label for label, recall in test_result.recall_labels.items()
            if recall < prune_threshold
        ]
        new_train_set = [
            item for item in train_set if item[0] not in low_recall_label
        ]
        # The pruned test set is derived from the training rows as well.
        new_test_set = [
            item for item in train_set if item[0] not in low_recall_label
        ]

        self.grocery.train(new_train_set)
        new_test_result = self.grocery.test(new_test_set)
        self._emit('-- Accuracy after training, with low-recall labels (less than',
                   str(prune_threshold * 100), '%) pruned --')
        self._emit('Accuracy, A-1:', new_test_result)
        return self.grocery, new_test_result

    def manualEvaluation(self, *args, **kwargs):
        """Evaluate the stored model on ``n_docs`` randomly chosen rows.

        Keyword arguments:
            n_docs: number of shuffled rows to test on.
            excluded_labels: labels to drop first.
            excluded_docs: documents to drop first.

        Returns:
            ``(test_set, test_result)``.
        """
        n_docs = int(kwargs["n_docs"])
        excluded_labels = kwargs["excluded_labels"]
        excluded_docs = kwargs["excluded_docs"]

        train_data = self._read_labelled_docs(self.kwargs["train_src"])
        train_data = [
            item for item in train_data if item[0] not in excluded_labels
        ]
        train_data = [i for i in train_data if i[1] not in excluded_docs]

        n = len(train_data)  # number of rows in the dataset
        indices = shuffle(list(range(n)))
        test_set = [train_data[x] for x in indices[0:n_docs]]
        g = self.loadTrainModel()
        test_result = g.test(test_set)
        return test_set, test_result

    def saveTrainModel(self, *args, **kwargs):
        """Save the model, then move it into the model directory."""
        self.grocery.save()
        os.rename(
            self.PREFIX + self.grocery_name + '_train.svm',
            self.PREFIX + self.MODEL_DIR + self.grocery_name + '_train.svm')
        return

    def loadTrainModel(self, *args, **kwargs):
        """Load the model stored in the model directory and return the grocery.

        The model is moved next to ``PREFIX`` for the duration of the load
        and moved back afterwards, even when loading raises.
        """
        working = self.PREFIX + self.grocery_name + '_train.svm'
        stored = (self.PREFIX + self.MODEL_DIR + self.grocery_name +
                  '_train.svm')
        os.rename(stored, working)
        try:
            self.grocery.load()
        finally:
            # Always restore the file so a failed load does not strand it.
            os.rename(working, stored)
        return self.grocery

    def predict(self, line, **kwargs):
        """Return the predicted tag for ``line``."""
        tag = self.grocery.predict(line)
        return tag

    def test(self, *args, **kwargs):
        """Evaluate on ``kwargs["test_src"]``, print and return the result."""
        test_src = str(kwargs["test_src"])
        test_result = self.grocery.test(test_src)
        self._emit("Total Accuracy", test_result)
        return test_result
class Solution:
    """Split a JSON-lines corpus into train/test files and train a classifier.

    ``__init__`` opens ``../data/train_raw.json``; ``get_train_test`` reads it
    once and writes a shuffled 70/30 split; ``train`` fits the ``"sample"``
    ``Grocery`` model on title/label pairs from the THUCTC split files.
    """

    def __init__(self):
        self.grocery = Grocery("sample")
        # Stays open until ``get_train_test`` has consumed it.
        self.input = open("../data/train_raw.json", "r")

    @staticmethod
    def _write_records(path, rows):
        """Write ``rows`` of (label, info, url, content) as JSON lines to ``path``."""
        with open(path, "w") as out:
            for label, info, url, content in rows:
                record = {}
                record['url'] = url
                record['label'] = label
                record['info'] = info
                record['content'] = content
                out.write(json.dumps(record) + "\n")

    @staticmethod
    def _load_pairs(path):
        """Return ``(label, title)`` pairs from the JSON-lines file at ``path``.

        Records lacking either key (or that are not JSON objects) are
        skipped, matching the original best-effort behaviour.
        """
        pairs = []
        with open(path, "r") as src:
            for line in src:
                in_dict = json.loads(line.strip())
                try:
                    # A disabled variant in the original used 'info' plus the
                    # second-to-last URL path segment as the text.
                    pairs.append((in_dict['label'], in_dict['title']))
                except (KeyError, TypeError):
                    continue
        return pairs

    def get_train_test(self):
        """Shuffle the input records and write a 70/30 train/test split.

        Reads one JSON object per line from ``self.input`` (keys ``label``,
        ``title``, ``url``, ``content``) and writes ``train_file.json`` and
        ``test_file.json`` with keys ``url``, ``label``, ``info`` (the title)
        and ``content``.  The output files are only created once the whole
        input has been parsed.
        """
        info_list = []
        for line in self.input:
            info_dict = json.loads(line.strip())
            label = info_dict['label']
            info = info_dict['title']
            url = info_dict['url']
            content = info_dict['content']
            info_list.append((label, info, url, content))

        n = len(info_list)
        num_array = np.random.permutation(n)
        # Apply the permutation once instead of once per slice.
        shuffled = np.array(info_list)[num_array]
        split_at = int(n * 0.7)

        self._write_records("train_file.json", shuffled[:split_at])
        self._write_records("test_file.json", shuffled[split_at:])

    def train(self):
        """Train on the THUCTC training split, save, and print the test result.

        Prints the number of training pairs, then the result of evaluating
        on the THUCTC test split.
        """
        # Load both files up front so a missing test file fails before training.
        train_src = self._load_pairs("../THUCTC/train_7_3.json")
        test_src = self._load_pairs("../THUCTC/test_7_3.json")

        self.grocery.train(train_src)
        print(len(train_src))
        self.grocery.save()
        print(self.grocery.test(test_src))


# NOTE(review): in the original this line ended with a dangling `"""` opener
# that is never closed in the visible source; it is omitted here -- confirm
# no closing delimiter further down the file depended on it.
# coding: utf-8
from tgrocery import Grocery

# --- Part 1: train and test on in-memory (label, text) pairs ---------------
grocery = Grocery('test')
train_src = [
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
]
grocery.train(train_src)
print(grocery.get_load_status())

test_src = [
    ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
    ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
]
test_result = grocery.test(test_src)
print(test_result.accuracy_labels)
print(test_result.recall_labels)

# --- Part 2: the same flow, reading samples from files ---------------------
grocery = Grocery('text_src')
train_src = 'train_ch.txt'
grocery.train(train_src)
print(grocery.get_load_status())

test_src = 'test_ch.txt'
test_result = grocery.test(test_src)
print(test_result.accuracy_labels)
print(test_result.recall_labels)
test_result.show_result()
def test_sample(path, test_path):
    """Evaluate the saved model ``path`` on a sample file and return the result as text.

    ``test_path`` is resolved relative to ``<BASE_DIR>/learn``.  Both the
    model name and the resolved file path are UTF-8 encoded before being
    handed to ``Grocery``.
    """
    model = Grocery(path.encode('utf-8'))
    model.load()

    sample_file = os.path.join(BASE_DIR, 'learn', test_path)
    outcome = model.test(sample_file.encode('utf-8'))
    return str(outcome)
# -*- coding: utf-8 -*-
# Test file
# Author: Alex
# Created Time: 2017-06-02 (Friday) 11:15:12
from tgrocery import Grocery

# Train on the ';'-delimited file and persist the model.
grocery = Grocery('sample')
grocery.train('train_data.txt', delimiter=';')
grocery.save()

print("*" * 40)

# Reload under the same name and evaluate on the training file itself.
new_grocery = Grocery('sample')
new_grocery.load()
print(new_grocery.test('train_data.txt', delimiter=';'))

print("*" * 40)

# Classify a single headline.
print(new_grocery.predict("考生必读:新托福写作考试评分标准"))
# -*- coding: utf-8 -*-
# Test file
# Author: Alex
# Created Time: 2017-06-02 (Friday) 11:15:12
from tgrocery import Grocery

train_src = [
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
]

# Train on the in-memory samples and persist the model.
grocery = Grocery('sample')
grocery.train(train_src)
grocery.save()

print("*" * 40)

# Reload the model and inspect the test result on the training samples.
new_grocery = Grocery('sample')
new_grocery.load()
res = new_grocery.test(train_src)
print(type(res))
print(res)
print(res.accuracy_labels)
print(res.show_result())

print("*" * 40)

# Classify a single headline and show its decision values.
res = new_grocery.predict("考生必读:新托福写作考试评分标准")
print(res)
print(res.dec_values)
from tgrocery import Grocery

# --- Part 1: train and test on in-memory (label, text) pairs ---------------
grocery = Grocery('test')
train_src = [
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
]
grocery.train(train_src)
print(grocery.get_load_status())

test_src = [
    ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
    ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
]
test_result = grocery.test(test_src)
print(test_result.accuracy_labels)
print(test_result.recall_labels)

# --- Part 2: the same flow, reading samples from files ---------------------
grocery = Grocery('text_src')
train_src = '../text_src/train_ch.txt'
grocery.train(train_src)
print(grocery.get_load_status())

test_src = '../text_src/test_ch.txt'
test_result = grocery.test(test_src)
print(test_result.accuracy_labels)
print(test_result.recall_labels)
test_result.show_result()
# Ad-hoc single-sentence predictions, disabled in the original:
# print new_grocery.predict('吃饱没有')
# print new_grocery.predict('周杰伦')
# print new_grocery.predict('黑色衣服好看')
# print new_grocery.predict('王力宏')
# print new_grocery.predict('波哥')
# print new_grocery.predict('播歌')
# print new_grocery.predict('我要听张含韵的歌')
# print new_grocery.predict('放一首:富士山下')
# print new_grocery.predict('点播:兄弟')
# print new_grocery.predict('听歌')
# print new_grocery.predict('听歌。')
# print new_grocery.predict('我要听歌')
# print new_grocery.predict('我要听音乐。')
# print new_grocery.predict('播放歌曲。')
# print new_grocery.predict('音乐播放。')
# print new_grocery.predict('Music.')
# print new_grocery.predict('音乐电台。')
# print new_grocery.predict('单曲循环当前歌曲')
# print new_grocery.predict('顺序播放歌曲。')

# In-memory test set, disabled in the original:
# test_src = [
#     ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
#     ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
# ]
#new_grocery.test(test_src)  # outputs the test accuracy

# Evaluate on ``text_set`` (defined earlier in the script) and show the
# result table.
result = new_grocery.test(text_set)
result.show_result()

# Custom tokenizer (must be a function):
#custom_grocery = Grocery('custom', custom_tokenize=list)
# Train from ``train_src`` (defined earlier in the script); a file name can
# be passed instead.
grocery.train(train_src)
#grocery.train('train_ch.txt')

# Save the model.
grocery.save()

# Load the model (the name must match the one it was saved under).
new_grocery = Grocery('sample')
new_grocery.load()

# Predict.
print(new_grocery.predict('考生必读:新托福写作考试评分标准'))

# ------------- test data
test_src = [
    ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
    ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
]

# Accuracy.
print(new_grocery.test(test_src))

# Passing a text file instead:
#new_grocery.test('test_ch.txt')

# Custom tokenizer:
#custom_grocery = Grocery('custom', custom_tokenize=list)
# coding=utf-8
from tgrocery import Grocery

# Train the 'sample' model on four in-memory headlines and persist it.
grocery = Grocery('sample')
train_src = [
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
]
grocery.train(train_src)
# File-based training alternatives, disabled in the original:
#grocery.train('/home/wangjianfei/git/data/train_ch.txt')
# grocery.train('train_ch.txt')
grocery.save()

# Reload the model under the same name and classify an English headline.
new_grocery = Grocery('sample')
new_grocery.load()
print(
    new_grocery.predict(
        'Abbott government spends $8 million on higher education media blitz'
    )
)

test_src = [
    ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
    ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
]
print("start test..................")
# Disabled in the original:
#grocery.test('/home/wangjianfei/git/data/test.txt')
# grocery.train('train_ch.txt'))
# custom_grocery = Grocery('custom', custom_tokenize=list)
print(new_grocery.test(test_src))
('education', '名师指导托福语法技巧:名词的复数形式'), ('education', '中国高考成绩海外认可 是“狼来了”吗?'), ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与') ] grocery.train(train_src) # 也可以用文件传入(默认以tab为分隔符,也支持自定义) #grocery.train('train_ch.txt') # 保存模型 grocery.save() # 加载模型(名字和保存的一样) new_grocery = Grocery('sample') new_grocery.load() # 预测 new_grocery.predict('考生必读:新托福写作考试评分标准') #education # 测试 test_src = [ ('education', '福建春季公务员考试报名18日截止 2月6日考试'), ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'), ] new_grocery.test(test_src) # 输出测试的准确率 #0.5 # 同样可支持文件传入 #new_grocery.test('test_ch.txt') # 自定义分词模块(必须是一个函数) #custom_grocery = Grocery('custom', custom_tokenize=list)
random.shuffle(LIST_valid) for line in LIST_valid: print >> o, ('%s,%s' % (line[0], line[1])).encode('utf-8') # grocery = Grocery('sample', custom_tokenize=jieba.cut) grocery = Grocery('all') grocery.train('train-sent.txt', delimiter=',') # # 保存模型 grocery.save() new_grocery = Grocery('all') new_grocery.load() acc = new_grocery.test('valid-sent.txt', delimiter=',').accuracy_labels for i in acc: print i, acc[i] file = open('valid-sent.txt') result = open('result.txt', 'w+') DICT_res_stat = dict() mapping_dict = {'1': 'province', '2': 'city', '3': 'address', '4': 'town', '5': 'name', '6': 'shouji', '7': 'dianhua', '8': 'number', '9': 'leibie'} total_corr = 0 total_count = 0 with open('result.txt', 'w') as o: for line in file: line = line.strip() line_label = line.split(',')[0] text = line.split(',')[1]