def train(self):
    db = ArticleDB()

    # Load the corpus (one segmented, POS-filtered document per line)
    print "reading corpus.txt ..."
    corpus_name = self.proj_name + "/seg_join/corpus_pos.txt"
    with open(corpus_name, "r") as corpus:
        self.texts = corpus.readlines()
    sql = "select id, category from %s" % self.proj_name
    results = db.execute(sql)
    db.close()
    self.labels = [row[1] for row in results]

    # Add title and tag features (weighted by repetition)
    print "reading title & tags..."
    doc_num = len(self.labels)
    for i in xrange(doc_num):
        tt_name = "%s/tt/%d" % (self.proj_name, i + 1)
        with open(tt_name, "r") as tt_file:
            title = tt_file.readline()
            tags = tt_file.readline()
        self.texts[i] += (" " + title) * 5 + (" " + tags) * 3

    # Split into training and test data
    x = self.texts
    y = self.labels
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=1)

    # Train
    print "training..."
    self.pipeline.fit(x_train, y_train)

    # Test
    print "testing..."
    y_pred = self.pipeline.predict(x_test)
    print(metrics.classification_report(y_test, y_pred))

    # Draw the confusion heatmap (columns normalized per true category)
    print "drawing..."
    num = 16
    m = np.zeros([num, num])
    for x, y in zip(y_pred, y_test):
        m[x - 1, y - 1] += 1
    for y in xrange(num):
        total = sum(m[:, y])
        if total > 0:
            m[:, y] /= total
    category = Category()
    # categories are assumed to be numbered 1..num, matching m[x - 1, y - 1] above
    label_list = [category.n2c[i] for i in xrange(1, num + 1)]
    index = pd.Index(label_list, dtype=str)
    df = pd.DataFrame(m, index=index, columns=label_list)
    sns.set(style="white")
    f, ax = plt.subplots(figsize=(11, 9))
    cmap = sns.diverging_palette(220, 10, as_cmap=True)
    # plot the labelled DataFrame so the axes show category names
    sns.heatmap(df, cmap=cmap, vmax=1.0, square=True, linewidths=.5,
                cbar_kws={"shrink": .5}, ax=ax)
    plt.xticks(rotation=-90)
    plt.yticks(rotation=0)
    f.savefig("tfidf_clf_result.png")
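# --- Sketch: a possible shape for the self.pipeline used by train() above. ---
# Hypothetical reconstruction, not the project's actual configuration; the real
# vectorizer/classifier settings live elsewhere in the class. LogisticRegression
# is chosen here only because the labelling code later calls predict_proba on
# the pipeline's "clf" step.
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

def build_pipeline():
    return Pipeline([
        ("tfidf", TfidfVectorizer()),   # segmented, space-joined text -> sparse tf-idf matrix
        ("clf", LogisticRegression()),  # linear classifier with predict_proba support
    ])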
def tranform(self):
    # Infer the topic distribution of every article and save it to disk as a feature vector
    corpus_vecs = []
    for i, doc_bow in enumerate(self.corpus_bow):
        print "infer topic vec: %d/%d" % (i + 1, self.doc_num)
        # minimum_probability=-1.0 makes gensim return a weight for every topic,
        # so all vectors have the same length
        topic_id_weights = self.lda.get_document_topics(doc_bow, minimum_probability=-1.0)
        topic_weights = [item[1] for item in topic_id_weights]
        corpus_vecs.append(topic_weights)
        obj_name = self.obj_dir + str(i + 1)
        Dumper.save(topic_weights, obj_name)

    cluster_num1 = 10
    cluster_num2 = 5
    category_offset = 0

    # First-level clustering
    print "first clustering..."
    corpus_vecs = np.asarray(corpus_vecs)
    clt = KMeans(n_clusters=cluster_num1)
    clt.fit(corpus_vecs)

    # Write the first-level clustering result to mysql
    print "writing clustering result to mysql..."
    db = ArticleDB()
    for i in xrange(self.doc_num):
        db.execute("update %s set category1=%d where id=%d" %
                   (self.proj_name, clt.labels_[i], i + 1))
    category_offset += cluster_num1

    # Group articles by their first-level cluster
    clusters = [[] for i in xrange(cluster_num1)]
    for i in xrange(self.doc_num):
        clusters[clt.labels_[i]].append(i + 1)

    # Second-level clustering (one run per first-level group)
    for i in xrange(cluster_num1):
        print "second clustering: %d/%d ..." % (i + 1, cluster_num1)
        sub_vecs = [corpus_vecs[j - 1] for j in clusters[i]]
        clt = KMeans(n_clusters=cluster_num2)
        clt.fit(sub_vecs)
        # Write the second-level clustering result to mysql
        print "writing clustering result to mysql..."
        for j in xrange(len(clusters[i])):
            db.execute("update %s set category2=%d where id=%d" %
                       (self.proj_name, category_offset + clt.labels_[j], clusters[i][j]))
        # advance the starting offset for second-level cluster ids
        category_offset += cluster_num2

    db.commit()
    db.close()
    print "ok, successfully complete!"
def set_subcategory():
    proj_name = "article_cat"
    db = ArticleDB()
    cats = db.execute("select distinct category from %s order by category" % proj_name)
    cats = [row[0] for row in cats]
    for cat in cats:
        with open("tags/tag_%d.txt" % cat, "w") as tag_file:
            ids = db.execute("select id from %s where category=%d order by id" % (proj_name, cat))
            ids = [row[0] for row in ids]
            tag_dict = defaultdict(int)
            has_tag = 0
            for id in ids:
                attr_name = "%s/attr/%d" % (proj_name, id)
                with open(attr_name, "r") as attr_file:
                    lines = attr_file.readlines()
                tags = lines[3].strip()
                if len(tags) > 0:
                    has_tag += 1
                    tags = tags.split(" ")
                    for tag in tags:
                        tag_dict[tag] += 1
            id_num = len(ids)
            tag_list = tag_dict.items()
            tag_list.sort(key=lambda x: x[1], reverse=True)
            tag_list = ["%s\t%d\n" % (tag, num) for tag, num in tag_list]
            tag_file.writelines(tag_list)
            print("category %d: %d/%d have tags" % (cat, has_tag, id_num))
    db.close()
def train(self):
    # Load article ids and category labels
    db = ArticleDB()
    id_cats = db.execute("select id, category from %s" % self.project_name)
    ids = [row[0] for row in id_cats]
    cats = [row[1] for row in id_cats]
    db.close()

    freq_util = FreqCharUtil()

    # Build a character-frequency vector for every article and track the per-dimension maximum
    print "reading raw texts..."
    corpus = []
    mms = [0.0] * freq_util.freq_char_num
    for id, cat in zip(ids, cats):
        txt_name = self.project_name + "/txt/" + str(id)
        with open(txt_name, "r") as txt_file:
            text = txt_file.read()
        vec = freq_util.get_vec(text)
        for i in xrange(freq_util.freq_char_num):
            mms[i] = max(mms[i], vec[i])
        corpus.append(vec)

    # Scale every dimension to [0, 1] by its maximum (skip dimensions that never occur)
    for vec in corpus:
        for i in xrange(freq_util.freq_char_num):
            if mms[i] > 0:
                vec[i] /= mms[i]

    # Split into training and test data
    x = np.asarray(corpus)
    y = np.asarray(cats)
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=1)

    # Train
    print "training classifier using character-frequency features..."
    self.clf.fit(x_train, y_train)

    # Test
    print "testing classifier..."
    y_pred = self.clf.predict(x_test)
    print(metrics.classification_report(y_test, y_pred))
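# --- Usage note (sketch, not part of the original class): the manual per-feature ---
# max normalization above is equivalent, for non-negative frequency vectors, to
# scikit-learn's MaxAbsScaler, which divides each column by its maximum absolute
# value and leaves all-zero columns unchanged. The toy data below is illustrative.
import numpy as np
from sklearn.preprocessing import MaxAbsScaler

toy_vecs = np.asarray([[3.0, 0.0, 2.0],
                       [1.0, 5.0, 4.0]])
scaled = MaxAbsScaler().fit_transform(toy_vecs)   # each column divided by its column maximum
print scaled                                      # values now lie in [0, 1]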
def train(self):
    db = ArticleDB()

    # Train (or load) the doc2vec model
    print "training doc2vec model..."
    corpus_name = self.project_name + "/seg_join/corpus.txt"
    if not os.path.exists('./news.d2v'):
        self.model = Doc2Vec(min_count=1, window=10, size=400, sample=1e-4, negative=5, workers=8)
        sources = {corpus_name: 'TRAIN'}
        sentences = LabeledLineSentence(sources)
        self.model.build_vocab(sentences.to_array())
        for epoch in range(10):
            # gensim's Doc2Vec exposes train(), not fit(); reshuffle the sentences each epoch
            self.model.train(sentences.sentences_perm())
        self.model.save('./news.d2v')
    else:
        self.model = Doc2Vec.load('./news.d2v')

    # Collect labels and document vectors, then split into training and test data
    print "transform docs to vecs..."
    results = db.execute("select id, category from %s" % self.project_name)
    y = [row[1] for row in results]
    results = db.execute("select count(id) from %s" % self.project_name)
    doc_num = results[0][0]
    db.close()
    x = [self.model.docvecs["TRAIN_%d" % i] for i in range(doc_num)]
    Dumper.save(x, "doc_vec_all.dat")
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=1)

    # Train
    print "training classifier using doc2vec features..."
    self.clf.fit(x_train, y_train)

    # Test
    print "testing classifier..."
    y_pred = self.clf.predict(x_test)
    print(metrics.classification_report(y_test, y_pred))

    # Draw the confusion heatmap (columns normalized per true category)
    num = 16
    m = np.zeros([num, num])
    for x, y in zip(y_pred, y_test):
        m[x - 1, y - 1] += 1
    for y in xrange(num):
        total = sum(m[:, y])
        if total > 0:
            m[:, y] /= total
    # categories are assumed to be numbered 1..num, matching m[x - 1, y - 1] above
    label_list = [Category.n2c[i] for i in xrange(1, num + 1)]
    print label_list
    index = pd.Index(label_list, dtype=str)
    df = pd.DataFrame(m, index=index, columns=label_list)
    sns.set(style="white")
    f, ax = plt.subplots(figsize=(11, 9))
    cmap = sns.diverging_palette(220, 10, as_cmap=True)
    sns.heatmap(df, cmap=cmap, vmax=1.0, square=True, linewidths=.5,
                cbar_kws={"shrink": .5}, ax=ax)
    plt.xticks(rotation=-90)
    plt.yticks(rotation=0)
    f.savefig("doc2vec_clf_result.png")
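# --- Sketch: the LabeledLineSentence helper assumed by the doc2vec training above. ---
# Hypothetical reconstruction following the common gensim doc2vec tutorial pattern;
# the project's real helper may differ. Each line of every source file becomes one
# tagged document whose tag is "<PREFIX>_<line index>", matching the "TRAIN_%d"
# lookups used when collecting document vectors.
from random import shuffle
from gensim.models.doc2vec import TaggedDocument

class LabeledLineSentence(object):
    def __init__(self, sources):
        self.sources = sources          # {file path: tag prefix}
        self.sentences = []

    def to_array(self):
        self.sentences = []
        for source, prefix in self.sources.items():
            with open(source, "r") as fin:
                for i, line in enumerate(fin):
                    self.sentences.append(
                        TaggedDocument(line.split(), ["%s_%d" % (prefix, i)]))
        return self.sentences

    def sentences_perm(self):
        shuffle(self.sentences)         # new presentation order for each training epoch
        return self.sentences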
def __init__(self, proj_name, rm_exist=False):
    # Create the directories for segmentation output; if rm_exist, delete and recreate them
    self.proj_name = proj_name
    self.seg_dir = proj_name + "/seg"
    self.seg_join_dir = proj_name + "/seg_join"
    self.tt_dir = proj_name + "/tt"            # segmented titles and tags
    self.seg_pos_dir = proj_name + "/seg_pos"  # segmentation results after POS filtering
    if rm_exist:
        shutil.rmtree(self.seg_dir, ignore_errors=True)
        shutil.rmtree(self.seg_join_dir, ignore_errors=True)
        shutil.rmtree(self.tt_dir, ignore_errors=True)
        shutil.rmtree(self.seg_pos_dir, ignore_errors=True)
    try:
        os.makedirs(self.seg_dir)
        os.makedirs(self.seg_join_dir)
        os.makedirs(self.tt_dir)
        os.makedirs(self.seg_pos_dir)
    except OSError:
        # directories already exist
        pass

    # Sentence-splitting punctuation
    self.stop_sent = StopSent()
    # Stop words
    self.stop_word = StopWord("./stopwords_general.txt")
    # Part-of-speech filter
    self.pos_filter = PosFilter()

    # Connect to the database and fetch the total number of articles
    db = ArticleDB()
    try:
        sql = "select count(*) from %s" % proj_name
        self.doc_count = db.execute(sql)[0][0]
        db.close()
    except MySQLdb.Error, e:
        print "Mysql Error %d: %s" % (e.args[0], e.args[1])
def generate_train_test_small():
    proj_name = "article_cat"
    db = ArticleDB()
    results = db.execute("select distinct category from article_cat")
    categories = [row[0] for row in results]
    categories.sort()  # sorted() returns a new list, so sort in place instead
    train_ids = []
    test_ids = []
    for cat in categories:
        # take up to 150 articles per category: the first 100 for training, the rest for testing
        results = db.execute(
            "select id, category from article_cat where category=%d limit 0,150" % cat)
        for item, row in enumerate(results):
            if item < 100:
                train_ids.append((row[0], row[1]))
            else:
                test_ids.append((row[0], row[1]))
    db.close()

    max_len = 0
    with open("%s/seg_join/corpus.txt" % proj_name, "r") as corpus_file, \
            open("%s/seg_join/text_train.csv" % proj_name, "w") as train_file, \
            open("%s/seg_join/text_test.csv" % proj_name, "w") as test_file:
        corpus = corpus_file.readlines()
        train_sentences = [
            '%d,%d,"%s"\n' % (cat, id, corpus[id - 1].strip())
            for id, cat in train_ids
        ]
        test_sentences = [
            '%d,%d,"%s"\n' % (cat, id, corpus[id - 1].strip())
            for id, cat in test_ids
        ]
        train_file.writelines(train_sentences)
        test_file.writelines(test_sentences)
        for sen in train_sentences:
            cur_len = len(sen.split())
            if cur_len > max_len:
                max_len = cur_len
        for sen in test_sentences:
            cur_len = len(sen.split())
            if cur_len > max_len:
                max_len = cur_len
    print "max_words_in_sentence: %d" % max_len
def generate_train_test():
    proj_name = "article_cat"
    db = ArticleDB()
    results = db.execute("select count(*) from article_cat")
    doc_num = results[0][0]
    results = db.execute("select id, category from article_cat order by id")
    db.close()
    train_num = int(doc_num * 0.8)
    test_num = int(doc_num * 0.2)

    max_len = 0
    with open("%s/seg_join/corpus.txt" % proj_name, "r") as corpus_file, \
            open("C:/Users/text_train.csv", "w") as train_file, \
            open("C:/Users/text_test.csv", "w") as test_file:
        corpus = corpus_file.readlines()
        new_corpus = []
        for line, row in zip(corpus, results):
            new_corpus.append('%d,%d,"%s"\n' % (row[1], row[0], line.strip()))
        from random import shuffle
        shuffle(new_corpus)
        train = new_corpus[0:train_num]
        test = new_corpus[train_num:train_num + test_num]
        train_file.writelines(train)
        test_file.writelines(test)
        for sen in train:
            cur_len = len(sen.split())
            if cur_len > max_len:
                max_len = cur_len
        for sen in test:
            cur_len = len(sen.split())
            if cur_len > max_len:
                max_len = cur_len
    print "max_words_in_sentence: %d" % max_len
def tranform(self):
    # Assign every article to its most probable first-level topic
    db = ArticleDB()
    topic_doc_ids = [[] for i in xrange(self.topic_num)]
    topic_docs = [[] for i in xrange(self.topic_num)]
    for i, doc_bow in enumerate(self.corpus_bow):
        doc_topics = self.lda[doc_bow]
        if len(doc_topics) == 0:
            print "no topics for doc %d " % (i + 1)
            continue
        topic_item = max(doc_topics, key=lambda item: item[1])
        topic_id = topic_item[0]
        topic_doc_ids[topic_id].append(i + 1)
        topic_docs[topic_id].append(self.corpus[i])
        db.execute("update %s set lda_category1=%d where id = %d" %
                   (self.proj_name, topic_id, i + 1))
    db.commit()
    db.close()

    # Run a second round of topic analysis inside each first-level topic group
    topic_offset = self.topic_num
    for topic_fid in xrange(self.topic_num):
        sub_ids = topic_doc_ids[topic_fid]
        sub_corpus = topic_docs[topic_fid]

        # Build the dictionary
        print "creating dictionary"
        sub_id2word = corpora.Dictionary(sub_corpus)
        # Optionally drop rare and overly common words:
        # ignore words that appear in less than 20 documents or more than 10% of documents
        # sub_id2word.filter_extremes(no_below=20, no_above=0.1)

        # Convert each document to its bag-of-words vector
        print "transforming doc to vector"
        sub_corpus_bow = [sub_id2word.doc2bow(doc) for doc in sub_corpus]

        # Train (or load) the second-level LDA model
        print "training lda model"
        sub_lda_model_name = "lda_models/lda_%d.dat" % topic_fid
        if not os.path.exists(sub_lda_model_name):
            sub_lda = LdaModel(corpus=sub_corpus_bow, id2word=sub_id2word,
                               num_topics=self.sub_topic_num, alpha='auto')
            Dumper.save(sub_lda, sub_lda_model_name)
        else:
            sub_lda = Dumper.load(sub_lda_model_name)

        # Name each sub-topic after its two strongest words and add it to the topic tree
        sub_topics = sub_lda.show_topics(num_topics=self.sub_topic_num, num_words=2,
                                         log=False, formatted=False)
        sub_topic_names = [
            sub_topic[1][0][0] + "+" + sub_topic[1][1][0] for sub_topic in sub_topics
        ]
        for i, sub_topic_name in enumerate(sub_topic_names):
            self.tree.create_node((topic_offset + i, sub_topic_name),
                                  topic_offset + i, parent=topic_fid)

        # Print the discovered sub-topics and dump them to a text file
        sub_topics = sub_lda.print_topics(num_topics=self.sub_topic_num, num_words=10)
        for sub_topic in sub_topics:
            print "sub topic %d: %s" % (sub_topic[0], sub_topic[1].encode("utf-8"))
        with open("sub_topics_%d.txt" % topic_fid, "w") as topic_file:
            for sub_topic in sub_topics:
                print >> topic_file, "topic %d: %s" % (sub_topic[0], sub_topic[1].encode("utf-8"))

        # Assign every article in the group to its most probable sub-topic
        db = ArticleDB()
        for i, doc_bow in enumerate(sub_corpus_bow):
            doc_topics = sub_lda[doc_bow]
            if len(doc_topics) == 0:
                print "no sub topics for doc %d " % sub_ids[i]
                continue
            topic_item = max(doc_topics, key=lambda item: item[1])
            topic_id = topic_item[0]
            db.execute("update %s set lda_category2=%d where id = %d" %
                       (self.proj_name, topic_offset + topic_id, sub_ids[i]))
        db.commit()
        db.close()
        topic_offset += self.sub_topic_num
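# --- Sketch: the two-level topic tree that tranform() above extends. ---
# Hypothetical reconstruction of how self.tree is assumed to be initialised
# (a root node plus one node per first-level topic, identified by its topic id);
# the real setup lives elsewhere in the class. treelib's
# create_node(tag, identifier, parent) is used exactly as in the method body.
from treelib import Tree

def build_topic_tree(topic_names):
    tree = Tree()
    tree.create_node("root", "root")
    for topic_id, name in enumerate(topic_names):   # first-level topic ids 0..topic_num-1
        tree.create_node((topic_id, name), topic_id, parent="root")
    return tree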
def train(self):
    db = ArticleDB()
    category = Category()
    db.execute("update %s set subcluster=null" % self.proj_name)
    db.commit()

    # Topic modelling, one run per first-level category
    subclt_offset, cat2subclt = read_subclt("subclt")
    for fcat in subclt_offset.keys():
        ids = db.execute("select id from %s where category=%s" % (self.proj_name, fcat))
        ids = [row[0] for row in ids]
        if len(ids) < 10:
            continue

        # Read the segmented texts of this category
        print "category %d: reading corpus..." % fcat
        x = []
        for id in ids:
            seg_name = "%s/seg/%d" % (self.proj_name, id)
            with open(seg_name, "r") as seg_file:
                lines = [
                    line.strip() for line in seg_file.readlines()
                    if len(line.strip()) > 0
                ]
            text = " ".join(lines)
            x.append(text.split())

        # Dictionary
        print "category %d: dictionary..." % fcat
        id2word = corpora.Dictionary(x)
        # Bag of words
        print "category %d: bag of word..." % fcat
        corpus_bow = [id2word.doc2bow(doc) for doc in x]
        # LDA
        print "category %d: lda modeling..." % fcat
        lda = LdaModel(corpus=corpus_bow, id2word=id2word, num_topics=5, alpha='auto')

        # Show the discovered topics
        print "category %d: 【%s】show topics..." % (fcat, category.n2c[fcat])
        topics = lda.print_topics(num_topics=-1, num_words=10)
        for topic in topics:
            print "topic %d: %s" % (topic[0], topic[1].encode("utf-8"))

        # Write back to SQL (disabled)
        # print "category %d: writing cluster result to sql..." % fcat
        # offset = subclt_offset[fcat]
        # for i, id in enumerate(ids):
        #     db.execute("update %s set subcluster=%d where id=%d" % (self.proj_name, offset + clt.labels_[i], id))

    db.commit()
    db.close()
    # All done
    print "OK, all done!"
def train(self):
    db = ArticleDB()
    db.execute("update %s set subcluster=null" % self.proj_name)
    db.commit()

    # Clustering, one run per first-level category
    subclt_offset, cat2subclt = read_subclt("subclt")
    for fcat in subclt_offset.keys():
        ids = db.execute("select id from %s where category=%s" % (self.proj_name, fcat))
        ids = [row[0] for row in ids]
        if len(ids) < 10:
            continue

        # Read the segmented texts of this category
        print "category %d: reading corpus..." % fcat
        x = []
        for id in ids:
            seg_name = "%s/seg/%d" % (self.proj_name, id)
            with open(seg_name, "r") as seg_file:
                lines = [
                    line.strip() for line in seg_file.readlines()
                    if len(line.strip()) > 0
                ]
            text = " ".join(lines)
            x.append(text)

        # Compute tf-idf features
        print "category %d: calc tfidf..." % fcat
        vectorizer = TfidfVectorizer()
        x = vectorizer.fit_transform(x)

        # Dimensionality reduction (disabled)
        # print "category %d: decomposition..." % fcat
        # svd = TruncatedSVD(1000)
        # normalizer = Normalizer(copy=False)
        # lsa = make_pipeline(svd, normalizer)
        # x = lsa.fit_transform(x)

        # Choosing k by silhouette score (disabled)
        # range_n_clusters = [4, 6, 8, 10, 12, 14, 16]
        # silhouette_avgs = []
        # for n_clusters in range_n_clusters:
        #     clusterer = KMeans(n_clusters=n_clusters, random_state=10)
        #     cluster_labels = clusterer.fit_predict(x)
        #     silhouette_avg = silhouette_score(x, cluster_labels)
        #     print("For n_clusters =", n_clusters, "The average silhouette_score is :", silhouette_avg)
        #     silhouette_avgs.append(silhouette_avg)
        # max_idx = np.argmax(silhouette_avgs)
        # cluster_num = range_n_clusters[max_idx]

        # Cluster with k taken from the sub-cluster profile
        subcats = cat2subclt[fcat]
        cluster_num = len(subcats)
        print "category %d: clustering (n_cluster=%d)..." % (fcat, cluster_num)
        clt = KMeans(n_clusters=cluster_num)
        clt.fit(x)

        # Write the cluster labels back to SQL
        print "category %d: writing cluster result to sql..." % fcat
        offset = subclt_offset[fcat]
        for i, id in enumerate(ids):
            db.execute("update %s set subcluster=%d where id=%d" %
                       (self.proj_name, offset + clt.labels_[i], id))

        # Finding discriminative keywords via chi2 (disabled)
        # feature_names = vectorizer.get_feature_names()
        # y = clt.labels_
        # ch2 = SelectKBest(chi2, k=20)
        # x = ch2.fit(x, y)
        # key_words = [feature_names[i] for i in ch2.get_support(indices=True)]
        # key_words = " ".join(key_words)
        # print "keywords:"
        # print key_words

        # Name each cluster by the top tf-idf terms of its concatenated documents
        print "category %d: reading corpus..." % fcat
        x = []
        for id in ids:
            seg_name = "%s/seg/%d" % (self.proj_name, id)
            with open(seg_name, "r") as seg_file:
                lines = [
                    line.strip() for line in seg_file.readlines()
                    if len(line.strip()) > 0
                ]
            text = " ".join(lines)
            x.append(text)
        text_group = [""] * cluster_num
        doc_num = len(ids)
        for i in xrange(doc_num):
            text_group[clt.labels_[i]] += x[i]
        vectorizer = TfidfVectorizer()
        matrix = vectorizer.fit_transform(text_group)
        feature_names = vectorizer.get_feature_names()
        for i in xrange(cluster_num):
            row = matrix[i].toarray().flatten()
            idxs = np.argsort(row)
            idxs = idxs[::-1]
            idxs = idxs[:20]
            keywords = [feature_names[j] for j in idxs]
            keywords = " ".join(keywords)
            print "subcluster-", offset + i, subcats[i].name, ":\t", keywords

    db.commit()
    db.close()
    # All done
    print "OK, all done!"
from flask import Flask, render_template, request, jsonify, url_for, send_from_directory, redirect
from flask import send_file
import time
from myutils import Category, CompareUnit, read_subcat, read_subclt
import sys
from myutils import ArticleDB
from myutils import TopkHeap, Dumper
from sklearn.metrics.pairwise import cosine_similarity
from treelib import Tree

reload(sys)
sys.setdefaultencoding('utf-8')

app = Flask(__name__)
db = ArticleDB()
project_name = "article150801160830"
# project_name = "article_cat"
txt_dir = project_name + "/txt/"
attr_dir = project_name + "/attr/"
category_dict = Category()
thumbs = [
    "063226153482.jpg", "071142757085.jpg", "072728052607.jpg", "073704072438.jpg",
    "081042492216.jpg", "090854778371.jpeg", "092220134851.jpg", "102652963888.png",
    "103550760868.jpg", "111531748717.jpg", "112713761995.png", "123006128256.jpg",
    "123641032783.jpg", "125123423041.jpg", "125425018839.png", "143948746006.jpg",
    "144655265068.jpg", "152502600146.jpg", "160004485992.png", "171128413415.jpg",
    "210804890144.png", "220548955698.jpg", "224636881111.jpg", "230826808516.png"
]
def train_label():
    db = ArticleDB()
    test_proj_name = "article150801160830"
    test_seg_dir = test_proj_name + "/seg/"
    test_obj_dir = test_proj_name + "/clf_tfidf/"
    shutil.rmtree(test_obj_dir, ignore_errors=True)
    os.makedirs(test_obj_dir)
    corpus = []

    # Load (or retrain) the tf-idf classifier
    print "1. training tfidf clf..."
    # clf = TextClassifierTfidf(proj_name="article_cat")
    # clf.train()
    # Dumper.save(clf, "tfidf_clf.dat")
    clf = Dumper.load("tfidf_clf.dat")

    # Number of documents in the test corpus
    results = db.execute("select count(id) from %s" % test_proj_name)
    test_count = results[0][0]

    # Read the test corpus
    print "2. reading corpus..."
    for i in xrange(1, test_count + 1):
        test_seg_name = test_seg_dir + str(i)
        with open(test_seg_name, "r") as test_seg_file:
            lines = [
                line.strip() for line in test_seg_file.readlines()
                if len(line.strip()) > 0
            ]
        doc = " ".join(lines)
        corpus.append(doc)

    # Predict the category of every test document
    print "3. predicting corpus"
    categories = clf.predict(corpus)

    # Predict the per-category probabilities
    print "4. predicting corpus (probability)"
    categories_probs = clf.predict_proba(corpus)

    # Write the predicted categories to mysql
    print "5. writing corpus prediction to mysql..."
    for i, a_category in enumerate(categories):
        a_id = i + 1
        sql = "update %s set category=%d where id=%s" % (test_proj_name, a_category, a_id)
        db.execute(sql)

    # Write the predicted probabilities to mysql
    print "6. writing corpus prediction probability to mysql..."
    for i, a_category_probs in enumerate(categories_probs):
        a_id = i + 1
        set_statements = []
        for class_label, a_category_prob in zip(
                clf.pipeline.named_steps["clf"].classes_, a_category_probs):
            set_statements.append("p%d=%f" % (class_label, a_category_prob))
        set_statements = ", ".join(set_statements)
        sql = "update %s set %s where id=%s" % (test_proj_name, set_statements, a_id)
        db.execute(sql)
    db.commit()
    db.close()

    # Save the tf-idf vectors of the test corpus to binary object files
    print "7. generating corpus tfidf vectors..."
    tfidf_vectors = clf.transform(corpus)
    print "8. writing corpus tfidf vectors to disk..."
    for i in xrange(test_count):
        print "%d/%d" % (i + 1, test_count)
        obj_name = test_obj_dir + str(i + 1)
        Dumper.save(tfidf_vectors[i, :], obj_name)
    print "ok, successfully complete!"
def train(self):
    # Read the sub-category specification file
    cat2subcat, tag2id = read_subcat(self.subcat_profile)

    # Label the training data: map article tags to sub-categories
    print "dividing category into sub-categories..."
    db = ArticleDB()
    db.execute("update %s set subcategory=null" % self.project_name)
    db.commit()
    for fcat, subcats in cat2subcat.items():
        ids = db.execute("select id from %s where category=%s" % (self.project_name, fcat))
        ids = [id[0] for id in ids]
        print "\tcategory %d has %d files, divided into %d subcategories" % (fcat, len(ids), len(subcats))
        for id in ids:
            attr_name = "%s/attr/%d" % (self.project_name, id)
            with open(attr_name, "r") as attr_file:
                tags = attr_file.readlines()[3].strip()
            if len(tags) > 0:
                tags = tags.split(" ")
                subtag2id = tag2id[fcat]
                for tag in tags:
                    if tag in subtag2id:
                        subcat = subtag2id[tag]
                        db.execute("update %s set subcategory=%d where id=%d" %
                                   (self.project_name, subcat, id))
                        break
    db.commit()

    # Train one classifier per first-level category, using its sub-categories as labels
    fcats = cat2subcat.keys()
    for fcat in fcats:
        id_subcats = db.execute(
            "select id, subcategory from %s where category=%s and subcategory is not null" %
            (self.project_name, fcat))
        ids = [row[0] for row in id_subcats]
        y = [row[1] for row in id_subcats]
        x = []
        print "category %d: reading corpus..." % fcat
        for id in ids:
            seg_name = "%s/seg/%d" % (self.project_name, id)
            with open(seg_name, "r") as seg_file:
                lines = [line.strip() for line in seg_file.readlines() if len(line.strip()) > 0]
            text = " ".join(lines)
            x.append(text)

        # Split into training and test data
        print "category %d: splitting train and test..." % fcat
        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=1)

        # Train
        print "category %d: training..." % fcat
        self.pipeline.fit(x_train, y_train)

        # Test
        print "category %d: testing..." % fcat
        y_pred = self.pipeline.predict(x_test)
        print(metrics.classification_report(y_test, y_pred))

        # Predict sub-categories for the unlabeled corpus
        test_proj_name = "article150801160830"
        ids = db.execute("select id from %s where category=%d" % (test_proj_name, fcat))
        ids = [row[0] for row in ids]
        new_x = []
        print "category %d: predicting new corpus..." % fcat
        for id in ids:
            seg_name = "%s/seg/%d" % (test_proj_name, id)
            with open(seg_name, "r") as seg_file:
                lines = [line.strip() for line in seg_file.readlines() if len(line.strip()) > 0]
            text = " ".join(lines)
            new_x.append(text)
        new_y = self.pipeline.predict(new_x)
        print "category %d: writing predict result to sql..." % fcat
        for id, label in zip(ids, new_y):
            db.execute("update %s set subcategory=%d where id=%d" % (test_proj_name, label, id))
        db.commit()

    db.close()
    # All done
    print "OK, all done!"