def train(self):
        db = ArticleDB()
        # Load the data
        print "reading corpus_pos.txt ..."
        corpus_name = self.proj_name + "/seg_join/corpus_pos.txt"
        with open(corpus_name, "r") as corpus:
            self.texts = corpus.readlines()
        sql = "select id, category from %s" % self.proj_name
        results = db.execute(sql)
        db.close()
        self.labels = [row[1] for row in results]

        # Add title and tag features (repeating them boosts their weight in the bag-of-words)
        print "reading title & tags..."
        doc_num = len(self.labels)
        for i in xrange(doc_num):
            tt_name = "%s/tt/%d" % (self.proj_name, i + 1)
            with open(tt_name, "r") as tt_file:
                title = tt_file.readline()
                tags = tt_file.readline()
                self.texts[i] += (" " + title) * 5 + (" " + tags) * 3

        # Split into training and test sets
        x = self.texts
        y = self.labels
        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=1)

        # Train
        print "training..."
        self.pipeline.fit(x_train, y_train)

        # Test
        print "testing..."
        y_pred = self.pipeline.predict(x_test)
        print(metrics.classification_report(y_test, y_pred))

        # Generate the heatmap
        print "drawing..."
        num = 16
        m = np.zeros([num, num])
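        # rows = predicted class, columns = true class; each column is normalized
        # below so cell (p, t) is the fraction of true class t predicted as class p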
        for x, y in zip(y_pred, y_test):
            m[x - 1, y - 1] += 1
        for y in xrange(num):
            total = sum(m[:, y])
            if total > 0:
                m[:, y] /= total
        category = Category()
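        # Category.n2c maps a numeric category id to its name (used here as heatmap axis labels)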
        label_list = [category.n2c[i + 1] for i in xrange(num)]
        index = pd.Index(label_list, dtype=str)
        df = pd.DataFrame(m, index=index, columns=label_list)
        sns.set(style="white")
        f, ax = plt.subplots(figsize=(11, 9))
        cmap = sns.diverging_palette(220, 10, as_cmap=True)
        sns.heatmap(df, cmap=cmap, vmax=1.0,
                    square=True,
                    linewidths=.5, cbar_kws={"shrink": .5}, ax=ax)
        plt.xticks(rotation=-90)
        plt.yticks(rotation=0)
        f.savefig("tfidf_clf_result.png")
    def tranform(self):
        # Infer each article's topic distribution and save it to disk as a feature vector
        corpus_vecs = []
        for i, doc_bow in enumerate(self.corpus_bow):
            print "infer topic vec: %d/%d" % (i + 1, self.doc_num)
            topic_id_weights = self.lda.get_document_topics(
                doc_bow, minimum_probability=-1.0)
            topic_weights = [item[1] for item in topic_id_weights]
            corpus_vecs.append(topic_weights)
            obj_name = self.obj_dir + str(i + 1)
            Dumper.save(topic_weights, obj_name)

        cluster_num1 = 10
        cluster_num2 = 5
        category_offset = 0
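        # category_offset keeps the cluster ids written by the two passes in disjoint ranges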
        # First-level clustering
        print "first clustering..."
        corpus_vecs = np.asarray(corpus_vecs)
        clt = KMeans(n_clusters=cluster_num1)
        clt.fit(corpus_vecs)

        # Write the first-level clustering results to MySQL
        print "writing clustering result to mysql..."
        db = ArticleDB()
        for i in xrange(self.doc_num):
            db.execute("update %s set category1=%d where id=%d" %
                       (self.proj_name, clt.labels_[i], i + 1))
        category_offset += cluster_num1

        # Group articles by their first-level cluster
        clusters = [[] for i in xrange(cluster_num1)]
        for i in xrange(self.doc_num):
            clusters[clt.labels_[i]].append(i + 1)

        # Second-level clustering (within each group)
        for i in xrange(cluster_num1):
            print "second clustering: %d/%d ..." % (i + 1, cluster_num1)
            # Second-level clustering
            sub_vecs = [corpus_vecs[j - 1] for j in clusters[i]]
            clt = KMeans(n_clusters=cluster_num2)
            clt.fit(sub_vecs)

            # Write the second-level clustering results to MySQL
            print "writing clustering result to mysql..."
            for j in xrange(len(clusters[i])):
                db.execute("update %s set category2=%d where id=%d" %
                           (self.proj_name, category_offset + clt.labels_[j],
                            clusters[i][j]))

            # Advance the starting id for the next group's sub-clusters
            category_offset += cluster_num2

        db.commit()
        db.close()
        print "ok, successfully complete!"
Example #3
def set_subcategory():
    proj_name = "article_cat"
    db = ArticleDB()
    cats = db.execute("select distinct category from %s order by category" %
                      proj_name)
    cats = [row[0] for row in cats]
    for cat in cats:
        with open("tags/tag_%d.txt" % cat, "w") as tag_file:
            ids = db.execute(
                "select id from %s where category=%d order by id" %
                (proj_name, cat))
            ids = [row[0] for row in ids]
            tag_dict = defaultdict(lambda: 0)
            has_tag = 0
            for id in ids:
                attr_name = "%s/attr/%d" % (proj_name, id)
                with open(attr_name, "r") as attr_file:
                    lines = attr_file.readlines()
                    tags = lines[3].strip()
                    if len(tags) > 0:
                        has_tag += 1
                        tags = tags.split(" ")
                        for tag in tags:
                            tag_dict[tag] += 1

            id_num = len(ids)
            tag_list = tag_dict.items()
            tag_list.sort(key=lambda x: x[1], reverse=True)
            tag_list = ["%s\t%d\n" % (tag, num) for tag, num in tag_list]
            tag_file.writelines(tag_list)
            print("category %d: %d/%d has tags" % (cat, has_tag, id_num))
    db.close()
    def train(self):
        # Fetch the article ids and categories
        db = ArticleDB()
        id_cats = db.execute("select id, category from %s" % self.project_name)
        ids = [row[0] for row in id_cats]
        cats = [row[1] for row in id_cats]
        db.close()
        freq_util = FreqCharUtil()

        # Load the data
        print "reading article texts..."
        corpus = []
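        # mms[i] tracks the maximum seen in dimension i; used below to scale each dimension to [0, 1]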
        mms = [0.0] * freq_util.freq_char_num
        for id, cat in zip(ids, cats):
            txt_name = self.project_name + "/txt/" + str(id)
            with open(txt_name, "r") as txt_file:
                text = txt_file.read()
                vec = freq_util.get_vec(text)
                for i in xrange(freq_util.freq_char_num):
                    mms[i] = max(mms[i], vec[i])
                corpus.append(vec)

        for vec in corpus:
            for i in xrange(freq_util.freq_char_num):
                vec[i] /= mms[i]

        # Split into training and test sets
        x = np.asarray(corpus)
        y = np.asarray(cats)
        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=1)

        # Train
        print "training classifier using frequent-character features..."
        self.clf.fit(x_train, y_train)

        # Test
        print "testing classifier..."
        y_pred = self.clf.predict(x_test)
        print(metrics.classification_report(y_test, y_pred))
    def train(self):
        # Load the data
        db = ArticleDB()
        # Train doc2vec
        print "training doc2vec model..."
        corpus_name = self.project_name + "/seg_join/corpus.txt"
        if not os.path.exists('./news.d2v'):
            self.model = Doc2Vec(min_count=1, window=10, size=400, sample=1e-4, negative=5, workers=8)
            sources = {corpus_name: 'TRAIN'}
            sentences = LabeledLineSentence(sources)
            self.model.build_vocab(sentences.to_array())
            for epoch in range(10):
                self.model.train(sentences.sentences_perm())
                self.model.save('./news.d2v')
        else:
            self.model = Doc2Vec.load('./news.d2v')

        # Split into training and test sets
        print "transform docs to vecs..."
        results = db.execute("select id, category from %s" % self.project_name)
        y = [row[1] for row in results]
        results = db.execute("select count(id) from %s" % self.project_name)
        doc_num = results[0][0]
        db.close()
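        # LabeledLineSentence is assumed to tag line i of corpus.txt as "TRAIN_i",
        # so each document vector can be looked up by that tag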
        x = [self.model.docvecs["TRAIN_%d" % i] for i in range(doc_num)]
        Dumper.save(x, "doc_vec_all.dat")
        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=1)

        # Train
        print "training classifier using doc2vec features..."
        self.clf.fit(x_train, y_train)

        # Test
        print "testing classifier..."
        y_pred = self.clf.predict(x_test)
        print(metrics.classification_report(y_test, y_pred))

        # Generate the heatmap
        num = 16
        m = np.zeros([num, num])
        for x, y in zip(y_pred, y_test):
            m[x - 1, y - 1] += 1
        for y in xrange(num):
            total = sum(m[:, y])
            if total > 0:
                m[:, y] /= total
        label_list = [Category.n2c[i + 1] for i in xrange(num)]
        print label_list
        index = pd.Index(label_list, dtype=str)
        df = pd.DataFrame(m, index=index, columns=label_list)
        sns.set(style="white")
        f, ax = plt.subplots(figsize=(11, 9))
        cmap = sns.diverging_palette(220, 10, as_cmap=True)
        sns.heatmap(df, cmap=cmap, vmax=1.0,
                    square=True,
                    linewidths=.5, cbar_kws={"shrink": .5}, ax=ax)
        plt.xticks(rotation=-90)
        plt.yticks(rotation=0)
        f.savefig("doc2vec_clf_result.png")
    def __init__(self, proj_name, rm_exist=False):
        # Create directories for the segmentation output; delete and recreate them if rm_exist is set
        self.proj_name = proj_name
        self.seg_dir = proj_name + "/seg"
        self.seg_join_dir = proj_name + "/seg_join"
        self.tt_dir = proj_name + "/tt"  # segmented titles and tags
        self.seg_pos_dir = proj_name + "/seg_pos"  # segmentation results after POS filtering
        if rm_exist:
            shutil.rmtree(self.seg_dir, ignore_errors=True)
            shutil.rmtree(self.seg_join_dir, ignore_errors=True)
            shutil.rmtree(self.tt_dir, ignore_errors=True)
            shutil.rmtree(self.seg_pos_dir, ignore_errors=True)

        try:
            os.makedirs(self.seg_dir)
            os.makedirs(self.seg_join_dir)
            os.makedirs(self.tt_dir)
            os.makedirs(self.seg_pos_dir)
        except OSError:
            pass

        # Sentence-splitting punctuation
        self.stop_sent = StopSent()

        # Stop words
        self.stop_word = StopWord("./stopwords_general.txt")

        # Part-of-speech filter
        self.pos_filter = PosFilter()

        # Connect to the database and get the total article count
        db = ArticleDB()
        try:
            sql = "select count(*) from %s" % proj_name
            self.doc_count = db.execute(sql)[0][0]
            db.close()
        except MySQLdb.Error, e:
            print "Mysql Error %d: %s" % (e.args[0], e.args[1])
Example #7
def generate_train_test_small():
    proj_name = "article_cat"
    db = ArticleDB()

    results = db.execute("select distinct category from article_cat")
    categories = [row[0] for row in results]
    categories.sort()

    train_ids = []
    test_ids = []

    for cat in categories:
        results = db.execute(
            "select id, category from article_cat where category=%d limit 0,150"
            % cat)
        for idx, row in enumerate(results):
            if idx < 100:
                train_ids.append((row[0], row[1]))
            else:
                test_ids.append((row[0], row[1]))
    db.close()

    max_len = 0
    with open("%s/seg_join/corpus.txt" % proj_name, "r") as corpus_file, \
            open("%s/seg_join/text_train.csv" % proj_name, "w") as train_file, \
            open("%s/seg_join/text_test.csv" % proj_name, "w") as test_file:
        corpus = corpus_file.readlines()
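        # each CSV row has the form: category,id,"space-separated tokens"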
        train_sentences = [
            '%d,%d,"%s"\n' % (cat, id, corpus[id - 1].strip())
            for id, cat in train_ids
        ]
        test_sentences = [
            '%d,%d,"%s"\n' % (cat, id, corpus[id - 1].strip())
            for id, cat in test_ids
        ]
        train_file.writelines(train_sentences)
        test_file.writelines(test_sentences)

        for sen in train_sentences:
            cur_len = len(sen.split())
            if cur_len > max_len:
                max_len = cur_len
        for sen in test_sentences:
            cur_len = len(sen.split())
            if cur_len > max_len:
                max_len = cur_len
        print "max_words_in_sentence: %d" % max_len
Example #8
def generate_train_test():
    proj_name = "article_cat"
    db = ArticleDB()
    results = db.execute("select count(*) from article_cat")
    doc_num = results[0][0]
    results = db.execute("select id, category from article_cat order by id")
    db.close()

    train_num = int(doc_num * 0.8)
    test_num = int(doc_num * 0.2)

    max_len = 0
    with open("%s/seg_join/corpus.txt" % proj_name, "r") as corpus_file, \
            open("C:/Users/text_train.csv", "w") as train_file, \
            open("C:/Users/text_test.csv", "w") as test_file:
        corpus = corpus_file.readlines()
        new_corpus = []
        for line, row in zip(corpus, results):
            new_corpus.append('%d,%d,"%s"\n' % (row[1], row[0], line.strip()))
        from random import shuffle
        shuffle(new_corpus)

        train = new_corpus[0:train_num]
        test = new_corpus[train_num:train_num + test_num]

        train_file.writelines(train)
        test_file.writelines(test)

        for sen in train:
            cur_len = len(sen.split())
            if cur_len > max_len:
                max_len = cur_len
        for sen in test:
            cur_len = len(sen.split())
            if cur_len > max_len:
                max_len = cur_len
        print "max_words_in_sentence: %d" % max_len
Example #9
    def tranform(self):
        # Determine each article's dominant topic
        db = ArticleDB()
        topic_doc_ids = [[] for i in xrange(self.topic_num)]
        topic_docs = [[] for i in xrange(self.topic_num)]
        for i, doc_bow in enumerate(self.corpus_bow):
            doc_topics = self.lda[doc_bow]
            if len(doc_topics) == 0:
                print "no topics for doc %d " % i + 1
                continue
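            # assign the document to its highest-weight topic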
            topic_item = max(doc_topics, key=lambda topic_item: topic_item[1])
            topic_id = topic_item[0]
            topic_doc_ids[topic_id].append(i + 1)
            topic_docs[topic_id].append(self.corpus[i])
            db.execute("update %s set lda_category1=%d where id = %d" %
                       (self.proj_name, topic_id, i + 1))
        db.commit()
        db.close()

        # Run topic analysis again within each group
        topic_offset = self.topic_num
        for topic_fid in xrange(self.topic_num):
            sub_ids = topic_doc_ids[topic_fid]
            sub_corpus = topic_docs[topic_fid]

            # Build the dictionary
            print "creating dictionary"
            sub_id2word = corpora.Dictionary(sub_corpus)

            # Drop low-frequency words
            # ignore words that appear in less than 20 documents or more than 10% documents
            # id2word.filter_extremes(no_below=20, no_above=0.1)

            # Count term frequencies and convert documents to bag-of-words vectors
            print "transforming doc to vector"
            sub_corpus_bow = [
                sub_id2word.doc2bow(doc) for doc in sub_corpus
            ]

            # Train the LDA model
            print "training lda model"
            sub_lda_model_name = "lda_models/lda_%d.dat" % topic_fid
            if not os.path.exists(sub_lda_model_name):
                sub_lda = LdaModel(corpus=sub_corpus_bow,
                                   id2word=sub_id2word,
                                   num_topics=self.sub_topic_num,
                                   alpha='auto')
                Dumper.save(sub_lda, sub_lda_model_name)
            else:
                sub_lda = Dumper.load(sub_lda_model_name)

            # Name each sub-topic using its two highest-weight words
            sub_topics = sub_lda.show_topics(num_topics=self.sub_topic_num,
                                             num_words=2,
                                             log=False,
                                             formatted=False)
            sub_topic_names = [
                sub_topic[1][0][0] + "+" + sub_topic[1][1][0]
                for sub_topic in sub_topics
            ]
            for i, sub_topic_name in enumerate(sub_topic_names):
                self.tree.create_node((topic_offset + i, sub_topic_name),
                                      topic_offset + i,
                                      parent=topic_fid)

            # Print the discovered sub-topics
            sub_topics = sub_lda.print_topics(num_topics=self.sub_topic_num,
                                              num_words=10)
            for sub_topic in sub_topics:
                print "sub topic %d: %s" % (sub_topic[0],
                                            sub_topic[1].encode("utf-8"))
            with open("sub_topics_%d.txt" % topic_fid, "w") as topic_file:
                for sub_topic in sub_topics:
                    print >> topic_file, "topic %d: %s" % (
                        sub_topic[0], sub_topic[1].encode("utf-8"))

            # Determine each article's dominant sub-topic
            db = ArticleDB()
            for i, doc_bow in enumerate(sub_corpus_bow):
                doc_topics = sub_lda[doc_bow]
                if len(doc_topics) == 0:
                    print "no sub topics for doc %d " % sub_ids[i]
                    continue
                topic_item = max(doc_topics,
                                 key=lambda topic_item: topic_item[1])
                topic_id = topic_item[0]
                db.execute(
                    "update %s set lda_category2=%d where id = %d" %
                    (self.proj_name, topic_offset + topic_id, sub_ids[i]))
            db.commit()
            db.close()
            topic_offset += self.sub_topic_num
    def train(self):
        db = ArticleDB()
        category = Category()
        db.execute("update %s set subcluster=null" % self.proj_name)
        db.commit()
        # Clustering (one pass per category)
        subclt_offset, cat2subclt = read_subclt("subclt")
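        # read_subclt is taken to return the starting sub-cluster id per parent category
        # (subclt_offset) and the sub-cluster definitions per category (cat2subclt),
        # inferred from how they are used below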
        for fcat in subclt_offset.keys():
            ids = db.execute("select id from %s where category=%s" %
                             (self.proj_name, fcat))
            ids = [row[0] for row in ids]
            if len(ids) < 10:
                continue

            # Read the text
            print "category %d: reading corpus..." % fcat
            x = []
            for id in ids:
                seg_name = "%s/seg/%d" % (self.proj_name, id)
                with open(seg_name, "r") as seg_file:
                    lines = [
                        line.strip() for line in seg_file.readlines()
                        if len(line.strip()) > 0
                    ]
                    text = " ".join(lines)
                    text = text.split()
                    x.append(text)

            # dict
            print "category %d: dictionary..." % fcat
            id2word = corpora.Dictionary(x)
            # bow
            print "category %d: bag of word..." % fcat
            corpus_bow = [id2word.doc2bow(doc) for doc in x]
            # lda
            print "category %d: lda modeling..." % fcat
            lda = LdaModel(corpus=corpus_bow,
                           id2word=id2word,
                           num_topics=5,
                           alpha='auto')
            # show topics
            print "category %d: 【%s】show topics..." % (fcat,
                                                       category.n2c[fcat])
            topics = lda.print_topics(num_topics=-1, num_words=10)
            for topic in topics:
                print "topic %d: %s" % (topic[0], topic[1].encode("utf-8"))

            # Write the results back to SQL
            # print "category %d: writing cluster result to sql..." % fcat
            # offset = subclt_offset[fcat]
            # for i, id in enumerate(ids):
            #     db.execute("update %s set subcluster=%d where id=%d" % (self.proj_name, offset + clt.labels_[i], id))
        db.commit()
        db.close()

        # All done
        print "OK, all done!"
    def train(self):
        db = ArticleDB()
        db.execute("update %s set subcluster=null" % self.proj_name)
        db.commit()
        # Clustering (one pass per category)
        subclt_offset, cat2subclt = read_subclt("subclt")
        for fcat in subclt_offset.keys():
            ids = db.execute("select id from %s where category=%s" %
                             (self.proj_name, fcat))
            ids = [row[0] for row in ids]
            if len(ids) < 10:
                continue

            # Read the text
            print "category %d: reading corpus..." % fcat
            x = []
            for id in ids:
                seg_name = "%s/seg/%d" % (self.proj_name, id)
                with open(seg_name, "r") as seg_file:
                    lines = [
                        line.strip() for line in seg_file.readlines()
                        if len(line.strip()) > 0
                    ]
                    text = " ".join(lines)
                    x.append(text)

            # Compute TF-IDF
            print "category %d: calc tfidf..." % fcat
            vectorizer = TfidfVectorizer()
            x = vectorizer.fit_transform(x)

            # Dimensionality reduction
            # print "category %d: decomposition..." % fcat
            # svd = TruncatedSVD(1000)
            # normalizer = Normalizer(copy=False)
            # lsa = make_pipeline(svd, normalizer)
            # x = lsa.fit_transform(x)

            # Choose a suitable k
            # range_n_clusters = [4, 6, 8, 10, 12, 14, 16]
            # silhouette_avgs = []
            # for n_clusters in range_n_clusters:
            #     clusterer = KMeans(n_clusters=n_clusters, random_state=10)
            #     cluster_labels = clusterer.fit_predict(x)
            #     silhouette_avg = silhouette_score(x, cluster_labels)
            #     print("For n_clusters =", n_clusters, "The average silhouette_score is :", silhouette_avg)
            #     silhouette_avgs.append(silhouette_avg)
            # max_idx = np.argmax(silhouette_avgs)
            # cluster_num = range_n_clusters[max_idx]

            # Train
            subcats = cat2subclt[fcat]
            cluster_num = len(subcats)
            print "category %d: clustering (n_cluster=%d)..." % (fcat,
                                                                 cluster_num)
            # cluster_num = len(cat2subclt[fcat])
            clt = KMeans(n_clusters=cluster_num)
            clt.fit(x)

            # Write the cluster results back to SQL
            print "category %d: writing cluster result to sql..." % fcat
            offset = subclt_offset[fcat]
            for i, id in enumerate(ids):
                db.execute("update %s set subcluster=%d where id=%d" %
                           (self.proj_name, offset + clt.labels_[i], id))

            # Find discriminative keywords per cluster
            # feature_names = vectorizer.get_feature_names()
            # y = clt.labels_
            # ch2 = SelectKBest(chi2, k=20)
            # x = ch2.fit(x, y)
            # key_words = [feature_names[i] for i in ch2.get_support(indices=True)]
            # key_words = " ".join(key_words)
            # print "关键词:"
            # print key_words

            # Find a name for each cluster
            print "category %d: reading corpus..." % fcat
            x = []
            for id in ids:
                seg_name = "%s/seg/%d" % (self.proj_name, id)
                with open(seg_name, "r") as seg_file:
                    lines = [
                        line.strip() for line in seg_file.readlines()
                        if len(line.strip()) > 0
                    ]
                    text = " ".join(lines)
                    x.append(text)
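            # concatenate all documents of a cluster into one pseudo-document;
            # TF-IDF over these pseudo-documents surfaces per-cluster keywords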
            text_group = [""] * cluster_num
            doc_num = len(ids)
            for i in xrange(doc_num):
                text_group[clt.labels_[i]] += x[i]

            vectorizer = TfidfVectorizer()
            matrix = vectorizer.fit_transform(text_group)
            feature_names = vectorizer.get_feature_names()
            for i in xrange(cluster_num):
                row = matrix[i].toarray().flatten()
                idxs = np.argsort(row)
                idxs = idxs[::-1]
                idxs = idxs[:20]
                keywords = [feature_names[j] for j in idxs]
                keywords = " ".join(keywords)
                print "subcluster-", offset + i, subcats[
                    i].name, ":\t", keywords

        db.commit()
        db.close()

        # All done
        print "OK, all done!"
Example #12
from flask import Flask, render_template, request, jsonify, url_for, send_from_directory, redirect
from flask import send_file
import time
from myutils import Category, CompareUnit, read_subcat, read_subclt
import sys
from myutils import ArticleDB
from myutils import TopkHeap, Dumper
from sklearn.metrics.pairwise import cosine_similarity
from treelib import Tree

reload(sys)
sys.setdefaultencoding('utf-8')

app = Flask(__name__)

db = ArticleDB()

project_name = "article150801160830"
# project_name = "article_cat"
txt_dir = project_name + "/txt/"
attr_dir = project_name + "/attr/"
category_dict = Category()
thumbs = [
    "063226153482.jpg", "071142757085.jpg", "072728052607.jpg",
    "073704072438.jpg", "081042492216.jpg", "090854778371.jpeg",
    "092220134851.jpg", "102652963888.png", "103550760868.jpg",
    "111531748717.jpg", "112713761995.png", "123006128256.jpg",
    "123641032783.jpg", "125123423041.jpg", "125425018839.png",
    "143948746006.jpg", "144655265068.jpg", "152502600146.jpg",
    "160004485992.png", "171128413415.jpg", "210804890144.png",
    "220548955698.jpg", "224636881111.jpg", "230826808516.png"
Example #13
def train_label():
    db = ArticleDB()
    test_proj_name = "article150801160830"
    test_seg_dir = test_proj_name + "/seg/"
    test_obj_dir = test_proj_name + "/clf_tfidf/"
    shutil.rmtree(test_obj_dir, ignore_errors=True)
    os.makedirs(test_obj_dir)
    corpus = []

    # Train the classifier
    print "1. trainning tfidf clf..."
    # clf = TextClassifierTfidf(proj_name="article_cat")
    # clf.train()
    # Dumper.save(clf, "tfidf_clf.dat")
    clf = Dumper.load("tfidf_clf.dat")

    # Get the number of documents in the test corpus
    results = db.execute("select count(id) from %s" % test_proj_name)
    test_count = results[0][0]

    # Read the test corpus
    print "2. reading corpus..."
    for i in xrange(1, test_count + 1):
        test_seg_name = test_seg_dir + str(i)
        with open(test_seg_name, "r") as test_seg_file:
            lines = [
                line.strip() for line in test_seg_file.readlines()
                if len(line.strip()) > 0
            ]
            doc = " ".join(lines)
            corpus.append(doc)

    # Predict categories for the test corpus
    print "3. predicting corpus"
    categories = clf.predict(corpus)

    # Predict category probabilities for the test corpus
    print "4. predicting corpus (probability)"
    categories_probs = clf.predict_proba(corpus)

    # Write the predicted categories back to the database
    print "5. writing corpus prediction to mysql..."
    for i, a_category in enumerate(categories):
        a_id = i + 1
        sql = "update %s set category=%d where id=%s" % (test_proj_name,
                                                         a_category, a_id)
        db.execute(sql)

    print "6. writing corpus prediction probability to mysql..."
    for i, a_category_probs in enumerate(categories_probs):
        a_id = i + 1

        set_statements = []
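        # one "p<label>=<prob>" assignment per class column in the table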
        for class_label, a_category_prob in zip(
                clf.pipeline.named_steps["clf"].classes_, a_category_probs):
            set_statements.append("p%d=%f" % (class_label, a_category_prob))
        set_statements = ", ".join(set_statements)
        sql = "update %s set %s where id=%s" % (test_proj_name, set_statements,
                                                a_id)
        db.execute(sql)
    db.commit()
    db.close()

    # Save the test-corpus TF-IDF vectors to binary object files
    print "7. generating corpus tfidf vectors..."
    tfidf_vectors = clf.transform(corpus)

    print "6. writing corpus tfidf vectors to disk..."
    for i in xrange(test_count):
        print "%d/%d" % (i + 1, test_count)
        obj_name = test_obj_dir + str(i + 1)
        Dumper.save(tfidf_vectors[i, :], obj_name)

    print "ok, successfully complete!"
    def train(self):
        # Read the sub-category specification file
        cat2subcat, tag2id = read_subcat(self.subcat_profile)
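        # read_subcat is taken to return cat2subcat (parent category -> sub-categories) and
        # tag2id (parent category -> {tag: sub-category id}), inferred from the usage below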

        # Label the training data
        print "dividing category into sub-categories..."
        db = ArticleDB()
        db.execute("update %s set subcategory=null" % self.project_name)
        db.commit()
        for fcat, subcats in cat2subcat.items():
            ids = db.execute("select id from %s where category=%s" % (self.project_name, fcat))
            ids = [id[0] for id in ids]
            print "\tcategory %d has %d files, divide into %d subcategories" % (fcat, len(ids), len(subcats))
            for id in ids:
                attr_name = "%s/attr/%d" % (self.project_name, id)
                with open(attr_name, "r") as attr_file:
                    tags = attr_file.readlines()[3].strip()
                    if len(tags) > 0:
                        tags = tags.split(" ")
                        subtag2id = tag2id[fcat]
                        for tag in tags:
                            if tag in subtag2id:
                                subcat = subtag2id[tag]
                                db.execute("update %s set subcategory=%d where id=%d" % (self.project_name, subcat, id))
                                break
        db.commit()

        # Train a classifier for each category, using its sub-categories as labels
        fcats = cat2subcat.keys()
        for fcat in fcats:
            id_subcats = db.execute("select id, subcategory from %s where category=%s and subcategory is not null" % (self.project_name, fcat))
            ids = [row[0] for row in id_subcats]
            y = [row[1] for row in id_subcats]
            x = []
            print "category %d: reading corpus..." % fcat
            for id in ids:
                seg_name = "%s/seg/%d" % (self.project_name, id)
                with open(seg_name, "r") as seg_file:
                    lines = [line.strip() for line in seg_file.readlines() if len(line.strip()) > 0]
                    text = " ".join(lines)
                    x.append(text)

            # Split into training and test sets
            print "category %d: splitting train and test..." % fcat
            x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=1)

            # Train
            print "category %d: training..." % fcat
            self.pipeline.fit(x_train, y_train)

            # Test
            print "category %d: testing..." % fcat
            y_pred = self.pipeline.predict(x_test)
            print(metrics.classification_report(y_test, y_pred))

            # Predict on the new corpus
            test_proj_name = "article150801160830"
            ids = db.execute("select id from %s where category=%d" % (test_proj_name, fcat))
            ids = [row[0] for row in ids]
            new_x = []
            print "category %d: predicting new corpus..." % fcat
            for id in ids:
                seg_name = "%s/seg/%d" % (test_proj_name, id)
                with open(seg_name, "r") as seg_file:
                    lines = [line.strip() for line in seg_file.readlines() if len(line.strip()) > 0]
                    text = " ".join(lines)
                    new_x.append(text)
            new_y = self.pipeline.predict(new_x)

            print "category %d: writing predict result to sql..." % fcat
            for id, label in zip(ids, new_y):
                db.execute("update %s set subcategory=%d where id=%d" % (test_proj_name, label, id))

        db.commit()
        db.close()

        # All done
        print "OK, all done!"