def train_compare_result(train_src, test_src):
    """Train a throw-away tgrocery model and measure its accuracy on a test set.

    :param train_src: training data passed straight to ``Grocery.train``
        (e.g. a list of ``(label, text)`` tuples — confirm against callers).
    :param test_src: sequence of items whose ``[0]`` is the true label and
        ``[1]`` is the text to classify.
    :return: ``(predict_precision, history)``. *predict_precision* is the
        fraction of test items whose predicted label (compared as ``str``)
        equals the true label, or ``0.0`` when *test_src* is empty.
        *history* has one dict per test item with keys ``'predict_title'``
        (the text), ``'predict_class'`` (the value returned by
        ``Grocery.predict``) and ``'true_class'`` (the label from the item).
    """
    grocery = Grocery('test')
    grocery.train(train_src)
    print(grocery.get_load_status())
    len_test = len(test_src)
    print(len_test)

    correct = 0
    history = []
    for test in test_src:
        true_class, title = test[0], test[1]
        predicted = grocery.predict(title)
        # Previously these two fields were assigned in swapped order.
        history.append({
            'predict_title': title,
            'predict_class': predicted,
            'true_class': true_class,
        })
        # predict() returns a result object; compare string forms.
        if str(predicted) == str(true_class):
            correct += 1

    if not len_test:
        # Avoid ZeroDivisionError on an empty test set.
        return 0.0, history
    return float(correct) / len_test, history
class AutoGrocery(object):
    """Wrapper around a tgrocery ``Grocery`` that loads or trains on demand.

    The model is named ``project_dir + '/models/model_data/' + name``, where
    ``project_dir`` is a module-level value defined outside this class. On
    the first ``predicate`` call the saved model is loaded; if loading fails
    the model is trained from *train_data* and saved.
    """

    def __init__(self, name, train_data):
        self._train_data = train_data
        self._grocery = Grocery(project_dir + '/models/model_data/' + name)

    def train(self):
        """Train the underlying model on the data given at construction."""
        self._grocery.train(self._train_data)

    def save(self):
        """Persist the underlying model under its name."""
        self._grocery.save()

    def load(self):
        """Load the previously saved model for this name."""
        self._grocery.load()

    def predicate(self, src):
        """Classify *src*, loading or training the model first if needed.

        :param src: text to classify.
        :return: ``(label, score)`` where *label* is ``predicted_y`` of the
            prediction and *score* is ``dec_values[label]``.
        """
        if not self._grocery.get_load_status():
            try:
                self.load()
            # NOTE(review): a missing saved model surfaces as a file error
            # (IOError/OSError), which the previous ValueError-only handler
            # let escape; ValueError is kept for whatever the original
            # author observed. Confirm against the installed tgrocery.
            except (IOError, OSError, ValueError):
                self.train()
                self.save()
        pr = self._grocery.predict(src)
        label = pr.predicted_y
        return label, pr.dec_values[label]

    # Correctly spelled alias; ``predicate`` is kept for existing callers.
    predict = predicate
def test_main(self):
    """Train a model, save it, reload it and check a known prediction.

    Uses ``self.grocery_name`` and ``self.train_src``, which are set up
    outside this method (presumably in ``setUp`` — confirm). The directory
    named ``self.grocery_name`` is removed afterwards, even if an assertion
    fails.
    """
    try:
        grocery = Grocery(self.grocery_name)
        grocery.train(self.train_src)
        grocery.save()
        # Reload from the name just saved under. The old code loaded 'test',
        # which only matched when grocery_name happened to be 'test'.
        # Assert against the reloaded copy so the save/load round trip is
        # actually exercised.
        new_grocery = Grocery(self.grocery_name)
        new_grocery.load()
        assert new_grocery.get_load_status()
        # predict() returns a result object, not a str; compare its string form.
        assert str(new_grocery.predict('考生必读:新托福写作考试评分标准')) == 'education'
    finally:
        if self.grocery_name and os.path.exists(self.grocery_name):
            shutil.rmtree(self.grocery_name)
def sentiment_train(gro_name, train_set):
    """Train a tgrocery classifier on *train_set* and save it.

    :param gro_name: name given to ``Grocery``; used as the save location.
    :param train_set: training data passed straight to ``Grocery.train``.
    :return: None. Prints whether the model reports itself as trained.
    """
    classifier = Grocery(gro_name)
    classifier.train(train_set)
    trained = classifier.get_load_status()
    print("Is trained? ", trained)
    classifier.save()
def sentiment_train(gro_name, train_set):
    """Train a tgrocery (SVM-based) classifier and persist it.

    :param gro_name: model name handed to ``Grocery``; the model is saved
        under this name.
    :param train_set: training data forwarded to ``Grocery.train``.
    :return: None. The training status is printed before saving.
    """
    model = Grocery(gro_name)
    model.train(train_set)
    status = model.get_load_status()
    print("Is trained? ", status)
    model.save()
def test_main(self):
    """Train, save and reload a model, then check the reloaded predictions.

    Uses ``self.grocery_name`` and ``self.train_src``, which are set up
    outside this method (presumably in ``setUp`` — confirm). The directory
    named ``self.grocery_name`` is removed afterwards, even on failure.
    """
    try:
        grocery = Grocery(self.grocery_name)
        grocery.train(self.train_src)
        grocery.save()
        assert grocery.get_load_status()

        # Reload from the same name (the old code loaded 'test') and run
        # every check against the reloaded copy, so save/load is tested.
        new_grocery = Grocery(self.grocery_name)
        new_grocery.load()
        assert new_grocery.get_load_status()

        # Smoke check: arbitrary text must still produce a prediction.
        print(new_grocery.predict('just a testing'))
        result = new_grocery.predict('考生必读:新托福写作考试评分标准')
        print(result)
        print("type of result is :", type(result))
        # predict() returns a result object; compare its string form.
        assert str(result) == 'education'
        assert str(new_grocery.predict('法网')) == 'sports'
    finally:
        if self.grocery_name and os.path.exists(self.grocery_name):
            shutil.rmtree(self.grocery_name)
# NOTE(review): this chunk starts mid-script. The four statements below look
# like the end of a loop body that collects test records into `tdic`. The
# loop header and the definitions of `dic`, `tdic`, `i`, `_id`, `_type` and
# `contents` are outside this view — TODO confirm their original indentation.
tdic['id'].append(_id)
tdic['type'].append(_type)
tdic['contents'].append(contents)
i +=1
#train = pd.read_csv( train_file, header = 0, delimiter = "\t", quoting = 3 )
#test = pd.read_csv( test_file, header = 1, delimiter = "\t", quoting = 3 )
# Build the training frame from `dic` and the test frame from `tdic`.
train = DataFrame(dic)
test = DataFrame(tdic)
# classfynews_instance is the path the model is saved under
grocery = Grocery('classfynews_instance')
# NOTE(review): other tgrocery usages train on a list of (label, text)
# tuples. A two-element list [contents column, type column] looks wrong;
# presumably list(zip(train['type'], train['contents'])) was intended —
# TODO confirm.
train_in = [train['contents'],train['type']]
grocery.train(train_in)
print grocery.get_load_status()
# NOTE(review): with save() commented out, the load() below presumably reads
# whatever model was saved earlier under 'classfynews_instance', not the one
# just trained.
#grocery.save()
copy_grocery = Grocery('classfynews_instance')
copy_grocery.load()
#copy_grocery = grocery
# NOTE(review): test_in is only referenced by the commented-out test() call.
test_in = [test['contents'],test['type']]
# input looks like ['我是中国人','台北*****']
# output [11,12]
# NOTE(review): predict() is normally given a single string. Passing the
# whole `contents` column is presumably unsupported — TODO confirm.
test_result = copy_grocery.predict(test['contents'])
print test_result.predicted_y
#test_result = copy_grocery.test(test_in)
#print test_result.show_result()
# coding: utf-8
from tgrocery import Grocery

# Train one classifier from in-memory samples and one from a training file,
# then print each one's prediction for the same headline.
query = '考生必读:新托福写作考试评分标准'
in_memory_samples = [
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
]

for model_name, source in (('test', in_memory_samples),
                           ('read_text', '../text_src/train_ch.txt')):
    grocery = Grocery(model_name)
    grocery.train(source)
    print(grocery.get_load_status())
    predict_result = grocery.predict(query)
    print(predict_result)
    print(predict_result.dec_values)
class Cat:
    """Holds a tgrocery classifier named 'autohome'."""

    def __init__(self):
        self.grocery = Grocery('autohome')

    def test(self):
        """Print the classifier's load status.

        Nothing in this class trains or loads the model.
        """
        status = self.grocery.get_load_status()
        print(status)
# coding: utf-8
from tgrocery import Grocery


def _trained(model_name, source):
    """Create a Grocery called *model_name*, train it on *source*, report status."""
    model = Grocery(model_name)
    model.train(source)
    print(model.get_load_status())
    return model


# Evaluate a model trained on in-memory samples against in-memory test data.
grocery = _trained('test', [
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
])
test_result = grocery.test([
    ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
    ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
])
print(test_result.accuracy_labels)
print(test_result.recall_labels)

# Evaluate a model trained and tested on text files.
grocery = _trained('text_src', '../text_src/train_ch.txt')
test_result = grocery.test('../text_src/test_ch.txt')
print(test_result.accuracy_labels)
# coding: utf-8
from tgrocery import Grocery


def _fit_and_report(model_name, training_data, text):
    """Train a Grocery on *training_data*, then print its prediction for *text*."""
    model = Grocery(model_name)
    model.train(training_data)
    print(model.get_load_status())
    outcome = model.predict(text)
    print(outcome)
    print(outcome.dec_values)
    return model


headline = '考生必读:新托福写作考试评分标准'

# First from an in-memory list of (label, text) pairs...
grocery = _fit_and_report('test', [
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
], headline)

# ...then from a training file on disk.
grocery = _fit_and_report('read_text', '../text_src/train_ch.txt', headline)
# coding: utf-8
from tgrocery import Grocery

# The tokenizer must be a Python callable; `list` splits a string into its
# individual characters.
custom_grocery = Grocery('custom', custom_tokenize=list)
custom_grocery.train([
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
])
print(custom_grocery.get_load_status())
print(custom_grocery.predict('考生必读:新托福写作考试评分标准'))
# coding: utf-8
from tgrocery import Grocery

samples = (
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
)

# Any Python callable can serve as the tokenizer; `list` yields one token
# per character.
custom_grocery = Grocery('custom', custom_tokenize=list)
custom_grocery.train(list(samples))
loaded = custom_grocery.get_load_status()
print(loaded)
prediction = custom_grocery.predict('考生必读:新托福写作考试评分标准')
print(prediction)
def _load_address_grocery():
    """Return the 'Addrss_NLP' classifier, loading it from disk or training it.

    If a path named after the model exists, the saved model is loaded into a
    ``GroceryTextModel``. Otherwise the classifier is trained on the
    concatenation of the lists pickled in ``pkl_data/address1.pkl`` and
    ``pkl_data/others1.pkl``, and then saved.
    """
    grocery = Grocery('Addrss_NLP')
    model_name = grocery.name
    if os.path.exists(model_name):
        text_converter = None
        tgM = GroceryTextModel(text_converter, model_name)
        tgM.load(model_name)
        grocery.model = tgM
        print('load!!!!!')
        return grocery

    # NOTE(review): pickle.load can execute arbitrary code; these files must
    # be trusted, locally produced artifacts.
    with open('pkl_data/address1.pkl', 'rb') as add_file:
        add_list = pickle.load(add_file)
    with open('pkl_data/others1.pkl', 'rb') as other_file:
        other_list = pickle.load(other_file)
    add_list.extend(other_list)
    grocery.train(add_list)
    print(grocery.get_load_status())
    grocery.save()
    return grocery


def _is_obviously_not_address(text):
    """Return True if *text* matches one of the hand-written 'others' rules.

    *text* must be non-empty, because the digit ratio divides by its length.
    """
    no_digit = len(list(filter(str.isdigit, text)))
    no_alpha = len(list(filter(is_alphabet, text)))
    # Registration-number / authority / phone lines, including common OCR
    # misreadings of '注册', or lines that are mostly digits.
    if ('注册' in text or '洼册' in text or '洼·册' in text or '洼.册' in text
            or '汪·册' in text
            or len(set('登记机关') & set(text)) >= 3
            or '电话' in text
            or (no_digit / len(text) > 0.7 and no_digit > 5)):
        return True
    # Latin-heavy lines, business-scope lines, and date lines.
    return (no_alpha > 5
            or len(set('经营范围化学品') & set(text)) >= 3
            or len(set('年月日') & set(text)) >= 2)


def demo_flask(image_file):
    """OCR *image_file* and extract the address text it contains.

    Every OCR'd line is stripped of whitespace and punctuation. It is then
    classified as 'others' by hand-written rules or by the 'Addrss_NLP'
    tgrocery classifier. Lines classified as 'address' are normalised. A
    line longer than the text accumulated so far replaces it; otherwise it
    is appended. The result is then cleaned with regex substitutions and
    ``stupid_revise``.

    The framed image returned by ``ocr_whole.model`` is written to
    ``/data/share/nginx/html/bbox`` under the input's base name.

    :param image_file: path to the input image ('/'-separated).
    :return: ``(output_file, ret_total)``: the path of the saved framed
        image and the extracted address string ('' if none was found).
    """
    grocery = _load_address_grocery()

    t = time.time()
    result_dir = '/data/share/nginx/html/bbox'
    image = np.array(Image.open(image_file).convert('RGB'))
    result, image_framed = ocr_whole.model(image)
    output_file = os.path.join(result_dir, image_file.split('/')[-1])
    Image.fromarray(image_framed).save(output_file)

    ret_total = ''
    for key in result:
        string1 = result[key][1]
        string2 = re.sub(r"[\s+\.\!\/_,$%^*(+\"\']+|[+——!,。?、~@#¥%……&*{}[]+", "", string1)
        if not string2:
            # Nothing left after stripping; skipping also avoids the
            # ZeroDivisionError the digit-ratio rule used to raise here.
            continue
        if _is_obviously_not_address(string2):
            predict_result = 'others'
        else:
            predict_result = grocery.predict(string2)
        if str(predict_result) != 'address':
            continue
        string1 = string1.replace('《', '(')
        string1 = string1.replace('》', ')')
        string1 = string1.replace('(', '(')
        string1 = string1.replace(')', ')')
        string1 = string1.replace('((', '(')
        if not ret_total or len(string1) > len(ret_total):
            ret_total = string1
        else:
            ret_total += string1

    # NOTE(review): the post-processing below is reconstructed as running
    # once after the loop — confirm it was not meant to run per line.
    if ')' in ret_total:
        if '(' not in ret_total:
            # OCR sometimes reads the opening bracket as 'C'.
            ret_total = ret_total.replace('C', '(')
    # NOTE(review): as rendered here, the leading '(' in several of these
    # patterns is an unbalanced ASCII paren, which re rejects. It is
    # presumably a full-width bracket lost in transcription — verify the
    # literals against the original file.
    ret_total = re.sub(r'((\w)住所(.*)', '', ret_total)
    ret_total = re.sub(r'((\w)住房(.*)', '', ret_total)
    ret_total = re.sub(r'(不作为(.*)', '', ret_total)
    ret_total = re.sub(r'(有效期(.*)', '', ret_total)
    ret_total = re.sub(r'(仅限(.*)', '', ret_total)
    ret_total = re.sub(r'(临时经营(.*)', '', ret_total)
    ret_total = re.sub(r'(仅限办公(.*)', '', ret_total)
    ret_total = re.sub(r'(经营场所(.*)', '', ret_total)
    ret_total = re.sub(r"^[经]*[营]*[场/住]*[所]*", "", ret_total)
    ret_total = stupid_revise(ret_total)
    print("Mission complete, it took {:.3f}s".format(time.time() - t))
    print('\nRecongition Result:\n')
    print(ret_total)
    return output_file, ret_total