def train():
    """Train the 'version1.0' Grocery model on the module-level ``trainlist`` and save it.

    Prints a dotted banner before and after training.
    """
    banner = '.' * 30
    print('train start ' + banner)
    model = Grocery('version1.0')
    model.train(trainlist)
    model.save()
    print('train end ' + banner)
def tGrocery():
    """Train a 'test' Grocery model on a pickled split and report its accuracy.

    ``SampleSeg.pk`` holds a pickled ``[trainingSet, benchmark]`` pair. Each
    benchmark item is indexed as ``item[0]`` = label and ``item[1]`` = text.
    The function:

    1. trains and saves ``Grocery('test')``;
    2. reloads it and predicts every benchmark text;
    3. writes one ``label<-->prediction /x01text`` line per sample to
       ``testResult.tmp``;
    4. prints a per-sample verdict and the overall accuracy.
    """
    # NOTE(review): pickle.load can execute code embedded in the file; only use
    # a trusted SampleSeg.pk. Binary mode is what pickle expects.
    with open('SampleSeg.pk', 'rb') as pk:
        [trainingSet, benchmark] = pickle.load(pk)

    testingSet = []
    correctLabel = []
    for i, item in enumerate(benchmark):
        print('%d out of %d' % (i, len(benchmark)))
        testingSet.append(item[1])
        correctLabel.append(item[0])

    grocery = Grocery('test')
    grocery.train(trainingSet)
    grocery.save()

    # Reload from disk so the saved model is what gets evaluated.
    new_grocery = Grocery('test')
    new_grocery.load()

    Prediction = []
    # 'with' guarantees the result file is flushed and closed (it never was).
    with open('testResult.tmp', 'w') as outFile:
        for i, text in enumerate(testingSet):
            print('%d out of %d' % (i, len(testingSet)))
            prediction = new_grocery.predict(text)
            Prediction.append(prediction)
            # NOTE(review): ' /x01' looks like it was meant to be the '\x01'
            # control character; kept verbatim so the file format is unchanged.
            outFile.write(correctLabel[i] + '<-->' + prediction + ' /x01' + text + '\n')

    correct = 0
    for predicted, expected in zip(Prediction, correctLabel):
        if predicted == expected:
            correct += 1
            verdict = 'Correct'
        else:
            verdict = 'False'
        print('%s %s %s' % (predicted, expected, verdict))
    print('Correct Count: %d' % correct)
    if Prediction:
        print('Accuracy: %f' % (1.0 * correct / len(Prediction)))
    else:
        # Previously a ZeroDivisionError for an empty benchmark.
        print('Accuracy: n/a (empty benchmark)')
class GroceryModel(object):
    """Thin wrapper around a Grocery text classifier named 'TextClassify'.

    Training and test files are UTF-8 text with one sample per line in the
    form ``label<TAB>text``.
    """

    def __init__(self):
        self.grocery = Grocery('TextClassify')

    @staticmethod
    def _read_dataset(path):
        """Parse a ``label<TAB>text`` file into a list of ``(label, text)`` tuples.

        Any further tab-separated fields are concatenated onto the text without
        a separator. The trailing newline is kept, matching the original
        parsing. The file is always closed, even if parsing fails.
        """
        import io
        dataset = []
        # io.open decodes UTF-8 on both Python 2 and 3. newline='\n' splits
        # only on '\n' and returns line endings untranslated.
        with io.open(path, 'r', encoding='utf8', newline='\n') as f:
            for line in f:
                fields = line.split('\t')
                dataset.append((fields[0], ''.join(fields[1:])))
        return dataset

    def train(self, train_file):
        """Train on *train_file* and save the model."""
        self.grocery.train(self._read_dataset(train_file))
        self.grocery.save()

    def load_model(self):
        self.grocery.load()

    def test(self, test_src):
        """Load the saved model, evaluate it on *test_src* and print the result."""
        self.load_model()
        result = self.grocery.test(self._read_dataset(test_src))
        print(result)

    def predict(self, text):
        """Print the prediction for *text* (the model must already be loaded)."""
        print(self.grocery.predict(text))
class AutoGrocery(object):
    """Grocery classifier that loads or trains itself on first prediction.

    The model is stored at ``project_dir + '/models/model_data/' + name``.
    On the first ``predicate`` call the saved model is loaded. If loading
    raises ValueError, the model is trained on ``train_data`` and saved
    instead.
    """

    def __init__(self, name, train_data):
        self._train_data = train_data
        model_path = project_dir + '/models/model_data/' + name
        self._grocery = Grocery(model_path)

    def train(self):
        self._grocery.train(self._train_data)

    def save(self):
        self._grocery.save()

    def load(self):
        self._grocery.load()

    def predicate(self, src):
        """Classify *src*; return ``(label, decision_value_for_that_label)``."""
        if not self._grocery.get_load_status():
            try:
                self.load()
            except ValueError:
                # No usable saved model: build one from the training data.
                self.train()
                self.save()
        outcome = self._grocery.predict(src)
        best = outcome.predicted_y
        return best, outcome.dec_values[best]
def __train__model__():
    """Train the 'Classify' Grocery model from the Excel sheet and cache it.

    Reads ``Classify.__FILE_PATH__`` with pandas. The ``类型`` (type) column
    is used as the label and the ``释义`` (definition) column as the text.
    Trains and saves ``Grocery('Classify')`` and stores it on
    ``Classify.__MODEL__``.
    """
    sheet = pd.read_excel(Classify.__FILE_PATH__)
    rows = sheet[[u'类型', u'释义']].values
    samples = [(kind, meaning) for kind, meaning in rows]
    model = Grocery('Classify')
    model.train(samples)
    model.save()
    Classify.__MODEL__ = model
def __train__model__():
    """Build, persist and cache the 'Classify' text classifier.

    Labels come from the ``类型`` column and texts from the ``释义`` column
    of the workbook at ``Classify.__FILE_PATH__``. The trained model is
    assigned to ``Classify.__MODEL__``.
    """
    frame = pd.read_excel(Classify.__FILE_PATH__)
    label_text = frame[[u'类型', u'释义']]
    pairs = [(row[0], row[1]) for row in label_text.values]
    classifier = Grocery('Classify')
    classifier.train(pairs)
    classifier.save()
    Classify.__MODEL__ = classifier
def test_grocery():
    """Train 'model_redian' on trdata_4.txt, reload it and report results on tedata_4.txt.

    Prints ``accuracy_labels`` and ``recall_labels`` of the test result, then
    calls its ``show_result()``.
    """
    trainer = Grocery('model_redian')
    trainer.train('trdata_4.txt')
    trainer.save()
    loaded = Grocery('model_redian')
    loaded.load()
    report = loaded.test('tedata_4.txt')
    print(report.accuracy_labels)
    print(report.recall_labels)
    report.show_result()
def test_grocery():
    """Round-trip the 'model_redian' classifier through disk and evaluate it.

    Training data comes from trdata_4.txt and evaluation data from
    tedata_4.txt. The test result's ``accuracy_labels`` and ``recall_labels``
    are printed before ``show_result()`` is called.
    """
    model_name = 'model_redian'
    fresh = Grocery(model_name)
    fresh.train('trdata_4.txt')
    fresh.save()
    restored = Grocery(model_name)
    restored.load()
    outcome = restored.test('tedata_4.txt')
    print(outcome.accuracy_labels)
    print(outcome.recall_labels)
    outcome.show_result()
def test_main(self):
    """Train, save and reload a Grocery model, then check a known prediction.

    The assertions run against the *reloaded* model, so the save/load round
    trip is actually exercised. The model directory is removed even if an
    assertion fails.
    """
    try:
        grocery = Grocery(self.grocery_name)
        grocery.train(self.train_src)
        grocery.save()
        # Reload under the same name the model was saved with (was hard-coded 'test').
        new_grocery = Grocery(self.grocery_name)
        new_grocery.load()
        assert new_grocery.get_load_status()
        # str() normalises the predict() result to its label text before comparing.
        assert str(new_grocery.predict('考生必读:新托福写作考试评分标准')) == 'education'
    finally:
        # cleanup
        if self.grocery_name and os.path.exists(self.grocery_name):
            shutil.rmtree(self.grocery_name)
def test_main(self):
    """Train, save and reload a Grocery model, then check a known prediction.

    The assertions run against the *reloaded* model, so the save/load round
    trip is actually exercised. The model directory is removed even if an
    assertion fails.
    """
    try:
        grocery = Grocery(self.grocery_name)
        grocery.train(self.train_src)
        grocery.save()
        # Reload under the same name the model was saved with (was hard-coded 'test').
        new_grocery = Grocery(self.grocery_name)
        new_grocery.load()
        assert new_grocery.get_load_status()
        # str() normalises the predict() result to its label text before comparing.
        assert str(new_grocery.predict('考生必读:新托福写作考试评分标准')) == 'education'
    finally:
        # cleanup
        if self.grocery_name and os.path.exists(self.grocery_name):
            shutil.rmtree(self.grocery_name)
def sentiment_train(gro_name, train_set):
    """Train a tgrocery SVM classifier named *gro_name* on *train_set* and save it.

    :param gro_name: name of the Grocery model to create and save.
    :param train_set: training data passed straight to ``Grocery.train``.
    :return: None
    """
    model = Grocery(gro_name)
    model.train(train_set)
    print("Is trained? ", model.get_load_status())
    model.save()
def sentiment_train(gro_name, train_set):
    """
    Fit a tgrocery SVM classifier and persist it.

    :param gro_name: name of the Grocery model to create and save.
    :param train_set: training data handed to ``Grocery.train``.
    :return: None
    """
    classifier = Grocery(gro_name)
    classifier.train(train_set)
    status = classifier.get_load_status()
    print("Is trained? ", status)
    classifier.save()
def train(train_origin_path, fold):
    """Train and save the Grocery model for one cross-validation fold.

    Each non-blank line of *train_origin_path* looks like
    ``<raw label>|text|<text>``. The raw label is mapped through
    ``classify_dict`` and then ``yiji_label`` (presumably to its first-level
    category). The model is saved as ``cv_<fold>_model``.
    """
    grocery = Grocery('cv_' + str(fold) + '_model')
    train_src = []
    with open(train_origin_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line (e.g. at end of file) used to abort training
                # with "not enough values to unpack".
                continue
            # maxsplit=1: a text that itself contains "|text|" no longer breaks unpacking.
            label, text = line.split("|text|", 1)
            label = yiji_label[classify_dict[label]]
            train_src.append((label, text))
    grocery.train(train_src)
    grocery.save()
def train_phrasing_and_save(self, trainsets=all):
    '''
    Train a Grocery model named ``self.model_name`` on *trainsets* and save it.

    :param trainsets: training data passed to ``Grocery.train``.
        NOTE(review): the default ``all`` is the builtin ``all`` function
        unless a module-level ``all`` shadows it; confirm the intent.
    :return: True if training and saving succeeded, False if either raised.
    '''
    try:
        grocery = Grocery(self.model_name)
        grocery.train(trainsets)
        grocery.save()
    except Exception:
        # Best-effort: report failure to the caller instead of raising. Unlike
        # the former bare except, KeyboardInterrupt/SystemExit now propagate.
        return False
    return True
def test_main(self):
    """Train, save and reload a Grocery model, then check two known predictions.

    Predictions (and their type) are printed for manual inspection.
    Assertions run against the reloaded model, and the model directory is
    removed even if an assertion fails.
    """
    try:
        grocery = Grocery(self.grocery_name)
        grocery.train(self.train_src)
        grocery.save()
        # Reload under the same name the model was saved with (was hard-coded 'test').
        new_grocery = Grocery(self.grocery_name)
        new_grocery.load()
        assert new_grocery.get_load_status()
        result = new_grocery.predict('just a testing')
        print(result)
        result = new_grocery.predict('考生必读:新托福写作考试评分标准')
        print(result)
        print("type of result is :", type(result))
        # Reuse the prediction above instead of predicting the same text again.
        assert str(result) == 'education'
        assert str(new_grocery.predict('法网')) == 'sports'
    finally:
        # cleanup
        if self.grocery_name and os.path.exists(self.grocery_name):
            shutil.rmtree(self.grocery_name)
class MyGrocery(object):
    """Grocery classifier trained on and evaluated against ``label,text`` CSV files."""

    def __init__(self, name):
        super(MyGrocery, self).__init__()
        self.grocery = Grocery(name)
        self.loaded = False
        # Accuracy of the most recent test() run; 1.0 until one has run.
        self.correct = 1.0

    def train(self, src):
        """Train on CSV *src*: column 0 is the label, column 1 the UTF-8 text."""
        # 'with' closes the CSV file (previously left open).
        with open(src) as f:
            lines = [(row[0], row[1].decode('utf8')) for row in csv.reader(f)]
        self.grocery.train(lines)

    def save_model(self):
        self.grocery.save()

    def train_and_save(self, src):
        self.train(src)
        self.save_model()

    def load_model(self):
        """Load the saved model once; later calls are no-ops."""
        if not self.loaded:
            self.grocery.load()
            self.loaded = True

    def predict(self, text):
        self.load_model()
        return self.grocery.predict(text)

    def test(self, src):
        """Evaluate on CSV *src* and return a JSON summary string.

        The accuracy is also stored on ``self.correct``. An empty file yields
        an accuracy of 0.0 instead of raising ZeroDivisionError.
        """
        self.load_model()
        total, wrong_num = 0.0, 0.0
        with open(src) as f:
            for row in csv.reader(f):
                total += 1
                if row[0] != self.predict(row[1]):
                    wrong_num += 1
        print("load test file from " + src)
        correct = (total - wrong_num) / total if total else 0.0
        self.correct = correct
        print("total: %d , wrong_num: %d, success percentage: %f" % (total, wrong_num, correct))
        result = dict(type="test", total=total, wrong_num=wrong_num, correct=correct)
        return json.dumps(result)
def tgrocery_train(train_data,test_data):
    """Model prediction: train 'TextGrocery', reload it and label the test split.

    :param train_data: training samples passed straight to ``Grocery.train``.
    :param test_data: test data, split into texts and labels by ``test_split``.
    :return: ``(test_corpus, test_label, predict_label)``, where
        ``predict_label`` holds ``str()`` of each prediction in corpus order.
    """
    print("训练语料总数为: " + str(len(train_data)))
    test_corpus, test_label = test_split(test_data)
    model = Grocery('TextGrocery')
    print("start training......")
    model.train(train_data)
    model.save()
    reloaded = Grocery('TextGrocery')
    reloaded.load()
    predict_label = [str(reloaded.predict(sample)) for sample in test_corpus]
    return test_corpus, test_label, predict_label
def learn_model(file_name):
    """Train a new Grocery model from an Excel file and validate it.

    Steps:

    1. Read the sheet at ``BASE_DIR/learn/<file_name>``.
    2. Drop rows with missing values and apply ``split_comment`` to each row.
    3. Split the rows 3:2 into learning and test files via ``output_file``.
    4. Train a model named ``model_<learn file stem>`` under ``BASE_DIR/learn``.
    5. Score it with ``test_sample``.

    :return: dict with ``IsErr`` (bool) and ``ErrDesc`` (user-facing message).
    """
    path = os.path.join(BASE_DIR, 'learn', file_name)
    try:
        df = pd.read_excel(path)
    except Exception:
        return {'IsErr': True, 'ErrDesc': u'找不到文档或者读取文档出错'}
    try:
        # Drop rows that contain missing values.
        df = df.dropna(axis=0)
        df = df.apply(split_comment, axis=1)
    except Exception:
        return {'IsErr': True, 'ErrDesc': u'文档格式有误,应包含Tag(标签名字),Comment(评价内容)'}
    try:
        # Split into learning and test groups, 3:2. Floor division keeps this
        # an int on Python 3 as well (plain '/' yields a float there).
        len_learn = len(df) // 5 * 3
        # Generate the learning and test documents.
        learn_file_name, test_file_name = output_file(file_name, df, len_learn)
        tmp_learn_name = os.path.join(BASE_DIR, 'learn', 'model_' + learn_file_name.split('.')[0])
        # NOTE(review): .encode('utf-8') produces bytes on Python 3; kept as-is.
        grocery = Grocery(tmp_learn_name.encode('utf-8'))
        path = os.path.join(BASE_DIR, 'learn', learn_file_name)
        grocery.train(path.encode('utf-8'))
        grocery.save()
    except Exception:
        return {'IsErr': True, 'ErrDesc': u'学习不成功,没有生产新的模型,请再次尝试。'}
    # Test the new model on the held-out file.
    res = test_sample(tmp_learn_name, test_file_name)
    return {
        'IsErr': False,
        'ErrDesc': u'成功生产新的模型,测试验证的正确率为%s, 模型保存为:%s' % (res, os.path.split(tmp_learn_name)[1])
    }
from tgrocery import Grocery

# Open a new grocery (don't forget to give it a name).
grocery = Grocery('sample')

# Training text can be passed in as a list of (label, text) tuples.
train_src = [
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
]
grocery.train(train_src)
# It can also be passed in as a file (tab-separated by default; custom
# delimiters are supported too):
# grocery.train('train_ch.txt')

# Save the model.
grocery.save()

# Load the model (use the same name it was saved under).
new_grocery = Grocery('sample')
new_grocery.load()

# Predict. Expected label: education
new_grocery.predict('考生必读:新托福写作考试评分标准')

# Test. Expected accuracy: 0.5
test_src = [
    ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
    ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
]
new_grocery.test(test_src)
# coding=utf-8
# Demo: train a tiny 'sample' classifier, reload it, then print a prediction
# and the test result.
from tgrocery import Grocery

grocery = Grocery('sample')
train_src = [
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
]
grocery.train(train_src)
# A training file could be used instead: grocery.train('train_ch.txt')
grocery.save()

new_grocery = Grocery('sample')
new_grocery.load()
english_headline = 'Abbott government spends $8 million on higher education media blitz'
print(new_grocery.predict(english_headline))

test_src = [
    ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
    ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
]
print("start test..................")
# A custom tokenizer is also possible: Grocery('custom', custom_tokenize=list)
print(new_grocery.test(test_src))
def _load_address_grocery():
    """Return the 'Addrss_NLP' classifier, loading it from disk or training it."""
    grocery = Grocery('Addrss_NLP')
    model_name = grocery.name
    text_converter = None
    if os.path.exists(model_name):
        tgM = GroceryTextModel(text_converter, model_name)
        tgM.load(model_name)
        grocery.model = tgM
        print('load!!!!!')
    else:
        # NOTE(review): pickle.load can execute code from the file; these must
        # be trusted, locally produced pickles.
        with open('pkl_data/address1.pkl', 'rb') as add_file:
            add_list = pickle.load(add_file)
        with open('pkl_data/others1.pkl', 'rb') as other_file:
            other_list = pickle.load(other_file)
        add_list.extend(other_list)
        grocery.train(add_list)
        print(grocery.get_load_status())
        grocery.save()
    return grocery


def _is_non_address_line(string2):
    """Return True for non-empty OCR lines that keyword heuristics rule out as the address.

    These are lines mentioning registration ('注册' and OCR variants) or the
    registration authority ('登记机关'), phone numbers ('电话') or mostly
    digits, many Latin letters, business-scope/chemical wording
    ('经营范围化学品'), or date characters ('年月日').
    """
    no_digit = len(list(filter(str.isdigit, string2)))
    no_alpha = len(list(filter(is_alphabet, string2)))
    if ('注册' in string2 or '洼册' in string2 or '洼·册' in string2
            or '洼.册' in string2 or '汪·册' in string2
            or len(set('登记机关') & set(string2)) >= 3
            or '电话' in string2
            or (no_digit / len(string2) > 0.7 and no_digit > 5)):
        return True
    return (no_alpha > 5
            or len(set('经营范围化学品') & set(string2)) >= 3
            or len(set('年月日') & set(string2)) >= 2)


def _normalize_brackets(text):
    """Map OCR bracket variants onto '(' / ')' and collapse a doubled '('."""
    text = text.replace('《', '(')
    text = text.replace('》', ')')
    # NOTE(review): the next two replacements are no-ops as written; they
    # presumably once mapped full-width brackets. Confirm against the source file.
    text = text.replace('(', '(')
    text = text.replace(')', ')')
    text = text.replace('((', '(')
    return text


def _clean_address(ret_total):
    """Strip OCR artefacts and parenthesised trailing remarks, then apply stupid_revise."""
    if ')' in ret_total:
        if '(' not in ret_total:
            # A lone ')' presumably means OCR read the opening bracket as 'C'.
            ret_total = ret_total.replace('C', '(')
    # Remove everything from a parenthesised remark onwards. The leading '(' is
    # escaped: unescaped, it opened a regex group that was never closed, so
    # every re.sub raised re.error ("missing ), unterminated subpattern").
    for pattern in (r'\((\w)住所(.*)', r'\((\w)住房(.*)', r'\(不作为(.*)',
                    r'\(有效期(.*)', r'\(仅限(.*)', r'\(临时经营(.*)',
                    r'\(仅限办公(.*)', r'\(经营场所(.*)'):
        ret_total = re.sub(pattern, '', ret_total)
    ret_total = re.sub(r"^[经]*[营]*[场/住]*[所]*", "", ret_total)
    return stupid_revise(ret_total)


def demo_flask(image_file):
    """OCR a business-licence image and extract its address text.

    Steps:

    1. Obtain the 'Addrss_NLP' classifier, training it from ``pkl_data/``
       pickles if no saved model exists.
    2. Run ``ocr_whole.model`` on *image_file* and save the framed image
       under ``/data/share/nginx/html/bbox``.
    3. Classify every recognised line.
    4. Merge lines classified as 'address': a new line longer than the
       accumulated text replaces it, otherwise it is appended.
    5. Clean the merged text.

    :param image_file: path of the image to recognise.
    :return: ``(output_file, ret_total)``: the saved framed-image path and the
        extracted address ('' if none was found).
    """
    grocery = _load_address_grocery()
    t = time.time()
    result_dir = '/data/share/nginx/html/bbox'
    image = np.array(Image.open(image_file).convert('RGB'))
    result, image_framed = ocr_whole.model(image)
    output_file = os.path.join(result_dir, image_file.split('/')[-1])
    Image.fromarray(image_framed).save(output_file)

    # Punctuation/symbol stripper, compiled once. A raw string avoids the
    # invalid-escape warnings of the former plain literal; the pattern is unchanged.
    punct_re = re.compile(r"[\s+\.\!\/_,$%^*(+\"\']+|[+——!,。?、~@#¥%……&*{}[]+")
    ret_total = ''
    for key in result:
        string1 = result[key][1]
        string2 = punct_re.sub("", string1)
        if not string2:
            # Nothing left to classify; this previously hit a division by
            # len(string2) and raised ZeroDivisionError.
            continue
        if _is_non_address_line(string2):
            predict_result = 'others'
        else:
            predict_result = grocery.predict(string2)
        if str(predict_result) == 'address':
            string1 = _normalize_brackets(string1)
            if (not ret_total) or len(string1) > len(ret_total):
                ret_total = string1
            else:
                ret_total += string1
    ret_total = _clean_address(ret_total)
    print("Mission complete, it took {:.3f}s".format(time.time() - t))
    print('\nRecongition Result:\n')
    print(ret_total)
    return output_file, ret_total
class TagPredictor(object):
    """Tag classifier built on tgrocery's Grocery.

    Required constructor keyword arguments:

    - ``grocery_name``
    - ``method``: 'normal', 'jieba' or 'processed'
    - ``train_src``: path of a ``label<TAB>text`` training file

    ``PREFIX`` and ``MODEL_DIR`` come from ``conf.load("predict_label")``
    and locate the saved ``<grocery_name>_train.svm`` file.
    """

    def _custom_tokenize(self, line, **kwargs):
        """Tokenise *line* for Grocery.

        The method is ``kwargs["method"]`` if given, otherwise the
        constructor's ``method``:

        - 'normal' uses ``keyExt.calculateTokens``;
        - 'processed' treats *line* as a comma-separated token list.

        :raises ValueError: for any other method (previously an UnboundLocalError).
        """
        method = str(kwargs["method"] if "method" in kwargs else self.kwargs["method"])
        if method == "normal":
            return self.key_ext.calculateTokens(line, doc_len_lower_bound=5,
                                                doc_len_upper_bound=500, method="normal")
        if method == "processed":
            return line.split(',')
        raise ValueError("unsupported tokenize method: %r" % method)

    def __init__(self, *args, **kwargs):
        self.grocery_name = str(kwargs["grocery_name"])
        method = str(kwargs["method"])
        if "train_src" not in kwargs:
            # Fail fast, as before: training and evaluation read it later.
            raise KeyError("train_src")
        settings = conf.load("predict_label")
        self.PREFIX = settings["prefix"]
        self.MODEL_DIR = settings["model_dir"]
        self.kwargs = kwargs
        if method == "normal":
            self.key_ext = keyExt()
            self.grocery = Grocery(self.grocery_name, custom_tokenize=self._custom_tokenize)
        elif method == "jieba":
            self.grocery = Grocery(self.grocery_name)
        elif method == "processed":
            self.grocery = Grocery(self.grocery_name, custom_tokenize=self._custom_tokenize)
        else:
            # Previously self.grocery was silently left unset.
            raise ValueError("unsupported method: %r" % method)

    @staticmethod
    def _read_train_data(path):
        """Read ``label<TAB>text`` lines from *path* (binary mode) as (label, text) tuples.

        Lines without a tab are skipped and the text is cut at the first
        newline. Byte separators keep this working on Python 3, where
        splitting bytes with a str used to make every line be skipped.
        """
        train_data = []
        with open(path, 'rb') as f:
            for line in f:
                parts = line.split(b'\t', 1)
                if len(parts) < 2:
                    continue
                train_data.append((parts[0], parts[1].split(b'\n', 1)[0]))
        return train_data

    def trainFromDocs(self, *args, **kwargs):
        model = self.grocery.train(self.kwargs["train_src"])
        return model

    def autoEvaluation(self, *args, **kwargs):
        """Train, evaluate, prune labels with recall below ``threshold``, retrain and re-evaluate.

        Keyword args: ``threshold`` (float), ``excluded_labels``, ``excluded_docs``.
        Returns ``(grocery, test_result_after_pruning)``.
        """
        prune_threshold = float(kwargs["threshold"])
        excluded_labels = kwargs["excluded_labels"]
        excluded_docs = kwargs["excluded_docs"]
        train_data = self._read_train_data(self.kwargs["train_src"])
        print("#items before filtering: %d" % len(train_data))
        print("-- Now we filter out the excluded docs --")
        train_data = [i for i in train_data if i[1] not in excluded_docs]
        print("#items after filtering: %d" % len(train_data))
        print("-- Now we filter out the excluded labels --")
        train_data = [i for i in train_data if i[0] not in excluded_labels]
        print("#items after filtering: %d" % len(train_data))
        n = len(train_data)  # number of rows in the dataset
        # NOTE(review): assumes ``shuffle`` returns the shuffled sequence (as
        # sklearn.utils.shuffle does); random.shuffle would return None here.
        indices = shuffle(list(range(n)))
        # NOTE(review): both slices cover every index (n * 10 // 10 == n), so
        # the model is evaluated on its own training data.
        train_set = [train_data[x] for x in indices[:n * 10 // 10]]
        test_set = [train_data[x] for x in indices[:n * 10 // 10]]
        self.grocery.train(train_set)
        test_result = self.grocery.test(test_set)
        print('-- Accuracy after training --')
        print('Accuracy, A-0: %s' % (test_result,))
        low_recall_label = [label for label, recall in test_result.recall_labels.items()
                            if recall < prune_threshold]
        new_train_set = [item for item in train_set if item[0] not in low_recall_label]
        new_test_set = [item for item in train_set if item[0] not in low_recall_label]
        self.grocery.train(new_train_set)
        new_test_result = self.grocery.test(new_test_set)
        print('-- Accuracy after training, with low-recall labels (less than %s %%) pruned --'
              % (prune_threshold * 100,))
        print('Accuracy, A-1: %s' % (new_test_result,))
        return self.grocery, new_test_result

    def manualEvaluation(self, *args, **kwargs):
        """Test the saved model on ``n_docs`` shuffled, filtered training samples.

        Returns ``(test_set, test_result)``.
        """
        n_docs = int(kwargs["n_docs"])
        excluded_labels = kwargs["excluded_labels"]
        excluded_docs = kwargs["excluded_docs"]
        train_data = self._read_train_data(self.kwargs["train_src"])
        train_data = [item for item in train_data if item[0] not in excluded_labels]
        train_data = [i for i in train_data if i[1] not in excluded_docs]
        n = len(train_data)  # number of rows in the dataset
        indices = shuffle(list(range(n)))
        test_set = [train_data[x] for x in indices[0:n_docs]]
        g = self.loadTrainModel()
        test_result = g.test(test_set)
        return test_set, test_result

    def _model_paths(self):
        """Return ``(stored_path, working_path)`` of this model's ``_train.svm`` file."""
        filename = self.grocery_name + '_train.svm'
        return self.PREFIX + self.MODEL_DIR + filename, self.PREFIX + filename

    def saveTrainModel(self, *args, **kwargs):
        """Save the model, then move its SVM file into ``MODEL_DIR``."""
        self.grocery.save()
        stored, working = self._model_paths()
        os.rename(working, stored)
        return

    def loadTrainModel(self, *args, **kwargs):
        """Temporarily move the SVM file back next to PREFIX, load it, and move it back."""
        stored, working = self._model_paths()
        os.rename(stored, working)
        try:
            self.grocery.load()
        finally:
            # Restore the file's location even if load() fails.
            os.rename(working, stored)
        return self.grocery

    def predict(self, line, **kwargs):
        tag = self.grocery.predict(line)
        return tag

    def test(self, *args, **kwargs):
        test_src = str(kwargs["test_src"])
        test_result = self.grocery.test(test_src)
        print('Total Accuracy %s' % (test_result,))
        return test_result
class Solution: def __init__(self): self.grocery = Grocery("sample") self.input = open("../data/train_raw.json", "r") def get_train_test(self): info_list = [] n = 0 train_set_file = open("train_file.json", "w") test_set_file = open("test_file.json", "w") for line in self.input: n += 1 info_dict = json.loads(line.strip()) label = info_dict['label'] info = info_dict['title'] url = info_dict['url'] content = info_dict['content'] info_list.append((label, info, url, content)) num_array = np.random.permutation(n) info_array = np.array(info_list) n = len(info_array) train_set = info_array[num_array][:int(n * 0.7)] test_set = info_array[num_array][int(n * 0.7):] out_dict = {} for label, info, url, content in train_set: out_dict['url'] = url out_dict['label'] = label out_dict['info'] = info out_dict['content'] = content print >> train_set_file, json.dumps(out_dict) out_dict = {} for label, info, url, content in test_set: out_dict['url'] = url out_dict['label'] = label out_dict['info'] = info out_dict['content'] = content print >> test_set_file, json.dumps(out_dict) train_set_file.close() test_set_file.close() def train(self): #train_file = open("train_file.json", "r") #test_file = open("test_file.json", "r") train_file = open("../THUCTC/train_7_3.json", "r") test_file = open("../THUCTC/test_7_3.json", "r") train_src = [] test_src = [] for line in train_file: in_dict = json.loads(line.strip()) try: #train_src.append((in_dict['label'], in_dict['info']+" "+in_dict['url'].strip().split("/")[-2])) train_src.append((in_dict['label'], in_dict['title'])) except Exception as e: pass self.grocery.train(train_src) for line in test_file: in_dict = json.loads(line.strip()) try: #test_src.append((in_dict['label'], in_dict['info']+" "+in_dict['url'].strip().split("/")[-2])) test_src.append((in_dict['label'], in_dict['title'])) except Exception as e: pass print len(train_src) self.grocery.save() print self.grocery.test(test_src) """
# -*- coding: utf-8 -*-
#
# Train the classifier on the sample data.
# Author: Alex
# Created Time: Thu 2016-12-29 11:19:06
from tgrocery import Grocery

train_file = "./data/train.csv"

gr = Grocery('test')
gr.train(train_src=train_file)
gr.save()
def main(): # Get market_sentiment of word from NTUSD-Fin train_t = [] train_s = [] targetIn = {} targetDict = dict() with open('NTUSD-Fin/NTUSD_Fin_hashtag_v1.0.json', 'r', encoding='utf-8') as f: targetIn = json.load(f) N = len(targetIn) for i in range(N): word = "#" + targetIn[i]['token'] targetDict[word] = targetIn[i]['market_sentiment'] sg = str(GroupValue_s(str(targetDict[word] / 3.5))) train_s.append((sg, word)) with open('NTUSD-Fin/NTUSD_Fin_word_v1.0.json', 'r', encoding='utf-8') as f: targetIn = json.load(f) N = len(targetIn) for i in range(N): word = targetIn[i]['token'] targetDict[word] = targetIn[i]['market_sentiment'] sg = str(GroupValue_s(str(targetDict[word] / 3.5))) train_s.append((sg, word)) # Training File: Load data & Use tgrocery to train classification model TrainingFile = open('training_set.json', 'r') TrainingData = json.load(TrainingFile) TrainingFile.close() DataList = [] grocery_t = Grocery("tweet") grocery_s = Grocery("snippet") for DataElement in TrainingData: tempt = DataManager() tempt.insertData(DataElement) tempt.group_t = GroupValue_t(tempt.sentiment) tempt.group_s = GroupValue_s(tempt.sentiment) line = re.sub("https?://[\w\-]+(\.[\w\-]+)+\S*", " ", DataElement["tweet"]) train_t.append((str(tempt.group_t), line)) if isinstance(DataElement["snippet"], list): for line in DataElement["snippet"]: train_s.append((str(tempt.group_s), line)) elif DataElement["snippet"] != "": train_s.append((str(tempt.group_s), DataElement["snippet"])) else: tempt.group_s = 0.0 DataList.append(tempt) grocery_t.train(train_t + train_s) grocery_t.save() grocery_s.train(train_s) grocery_s.save() # Save training data created by WordScore() and GroupValue_*() # Data will be uesd for LinearRegression() in BOTH.py outfile = open('TG_train.txt', 'w', encoding='utf-8') dataScore = [] dataSentiment = [] for row in DataList: dataSentiment.append([float(row.sentiment)]) a = WordScore(row.tweet, targetDict) b = WordScore(row.snippet, targetDict) c = row.group_t d = 
row.group_s dataScore.append([a, b, c, d]) print(a, b, c, d, file=outfile) outfile.close() ''' # Train linear regression model model = LinearRegression() model.fit(dataScore, dataSentiment) # Test for training data print('(train)R-squared: %.3f' % model.score(dataScore, dataSentiment)) #0.915 predictions = model.predict(dataScore) rms = mean_squared_error(dataSentiment,predictions) print('RMSE: %.3f' % sqrt(rms)) #0.110 print('MSE: %.3f' % rms) #0.012 ''' # Testing File: Load data & Use tgrocery classification model to predict TestingFile = open('test_set.json', 'r') TestingData = json.load(TestingFile) TestingFile.close() DataList = [] new_grocery_t = Grocery('tweet') new_grocery_t.load() new_grocery_s = Grocery('snippet') new_grocery_s.load() for DataElement in TestingData: tempt = DataManager() tempt.insertData(DataElement) line = re.sub("https?://[\w\-]+(\.[\w\-]+)+\S*", " ", DataElement["tweet"]) tempt.group_t = float('{0}'.format(new_grocery_t.predict(line))) value = 0.0 if isinstance(DataElement["snippet"], list): for line in DataElement["snippet"]: value = value + float('{0}'.format( new_grocery_s.predict(line))) value = value / len(DataElement["snippet"]) elif DataElement["snippet"] != "": value = float('{0}'.format( new_grocery_s.predict(DataElement["snippet"]))) tempt.group_s = value DataList.append(tempt) # Save testing data created by WordScore() and classification prediction # Data will be uesd for LinearRegression() in BOTH.py outfile = open('TG_test.txt', 'w', encoding='utf-8') dataScore = [] dataSentiment = [] for row in DataList: dataSentiment.append([float(row.sentiment)]) a = WordScore(row.tweet, targetDict) b = WordScore(row.snippet, targetDict) c = row.group_t d = row.group_s dataScore.append([a, b, c, d]) print(a, b, c, d, file=outfile) outfile.close() '''
def train(path,name):
    """Train a Grocery model called *name* on the training data at *path* and save it."""
    model = Grocery(name)
    model.train(path)
    model.save()