예제 #1
0
 def __init__(self):
     """Attach a jieba-based tokenizer, a database handle and a classifier."""
     user_dict_path = config.USER_DEFINED_DICT_PATH
     stop_words_dir = config.CHN_STOP_WORDS_PATH
     self.tokenization = Tokenization(import_module="jieba",
                                      user_dict=user_dict_path,
                                      chn_stop_words_dir=stop_words_dir)
     self.database = Database()
     self.classifier = Classifier()
 def __init__(self, database_name, collection_name):
     """Set up the database handles, tokenizer and rejection counter.

     :param database_name: name of the database the news are stored in
     :param collection_name: collection inside ``database_name``
     """
     super(JrjSpyder, self).__init__()
     self.db_name = database_name
     self.col_name = collection_name
     self.terminated_amount = 0
     db_obj = Database()
     self.db_obj = db_obj
     self.col = db_obj.conn[database_name].get_collection(collection_name)
     self.tokenization = Tokenization(import_module="jieba",
                                      user_dict=config.USER_DEFINED_DICT_PATH)
 def get_all_news_about_specific_stock(self, database_name,
                                       collection_name):
     """
     Copy every news row that mentions a stock into a per-stock collection.

     The per-stock collections live in
     ``config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE`` and are named after the
     stock symbol; symbols that already have such a collection are skipped.
     Every copied row additionally gets one field per entry of
     ``self.label_range`` (``key_name -> self._label_news(date, symbol,
     label_days)``).

     :param database_name: database holding the news rows
     :param collection_name: collection holding the news rows
     """
     source_col = self.database.get_collection(database_name,
                                               collection_name)
     first_row = next(source_col.find(), None)
     if first_row is None:
         # Nothing to distribute (previously this died with StopIteration).
         logging.info("{} collection is empty, nothing to fetch ... ".format(
             collection_name))
         return
     # Only the first row is inspected: if it lacks RelatedStockCodes, the
     # stock codes mentioned in each article have not been stored yet.
     if "RelatedStockCodes" not in first_row:
         tokenization = Tokenization(import_module="jieba",
                                     user_dict="./Leorio/financedict.txt")
         tokenization.update_news_database_rows(database_name,
                                                collection_name)
     stock_symbol_list = self.database.get_data(
         config.STOCK_DATABASE_NAME,
         config.COLLECTION_NAME_STOCK_BASIC_INFO,
         keys=["symbol"])["symbol"].to_list()
     # A set keeps the per-symbol "already fetched?" test O(1).
     col_names = set(
         self.database.connect_database(
             config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE).list_collection_names(
                 session=None))
     for symbol in stock_symbol_list:
         if symbol in col_names:
             logging.info(
                 "{} has fetched all related news from {}...".format(
                     symbol, collection_name))
             continue
         _collection = self.database.get_collection(
             config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, symbol)
         # RelatedStockCodes is matched against the symbol minus its first
         # two characters (presumably an exchange prefix such as "sh").
         code = symbol[2:]
         _tmp_num_stat = 0
         for row in source_col.find():
             related_codes = (row.get("RelatedStockCodes") or "").split(" ")
             if code not in related_codes:
                 continue
             _data = {
                 "Date": row["Date"],
                 "Url": row["Url"],
                 "Title": row["Title"],
                 "Article": row["Article"],
                 "OriDB": database_name,
                 "OriCOL": collection_name
             }
             # Labels for the n days after the news was published.
             for label_days, key_name in self.label_range.items():
                 _data[key_name] = self._label_news(
                     datetime.datetime.strptime(
                         row["Date"].split(" ")[0], "%Y-%m-%d"),
                     symbol, label_days)
             _collection.insert_one(_data)
             _tmp_num_stat += 1
         logging.info(
             "there are {} news mentioned {} in {} collection need to be fetched ... "
             .format(_tmp_num_stat, symbol, collection_name))
         # NOTE: the stray `break` that stopped after the first symbol is gone.
예제 #4
0
 def __init__(self, database_name, collection_name):
     """Set up database handles, tokenizer, rejection counter and Redis client.

     :param database_name: name of the database the news are stored in
     :param collection_name: collection inside ``database_name``
     """
     super(NbdSpyder, self).__init__()
     self.db_name, self.col_name = database_name, collection_name
     self.terminated_amount = 0
     db_obj = Database()
     self.db_obj = db_obj
     self.col = db_obj.conn[database_name].get_collection(collection_name)
     self.tokenization = Tokenization(import_module="jieba",
                                      user_dict=config.USER_DEFINED_DICT_PATH)
     # Redis client bound to the database id config.CACHE_NEWS_REDIS_DB_ID.
     self.redis_client = redis.StrictRedis(
         host=config.REDIS_IP,
         port=config.REDIS_PORT,
         db=config.CACHE_NEWS_REDIS_DB_ID,
     )
예제 #5
0
 def get_all_news_about_specific_stock(self, database_name,
                                       collection_name):
     """
     Copy every news row that mentions a stock code into a per-code collection.

     For each code of the ``stock``/``basic_info`` collection, the rows of
     ``database_name``/``collection_name`` whose RelatedStockCodes contain
     the code are inserted into the collection named after the code in
     ``config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE``.

     :param database_name: database holding the news rows
     :param collection_name: collection holding the news rows
     """
     source_col = self.database.get_collection(database_name,
                                               collection_name)
     first_row = next(source_col.find(), None)
     if first_row is None:
         # Nothing to distribute (previously this died with StopIteration).
         logging.info("{} collection is empty, nothing to fetch ... ".format(
             collection_name))
         return
     # Only the first row is inspected: if it lacks RelatedStockCodes, the
     # stock codes mentioned in each article have not been stored yet.
     if "RelatedStockCodes" not in first_row:
         tokenization = Tokenization(import_module="jieba",
                                     user_dict="./Leorio/financedict.txt")
         tokenization.update_news_database_rows(database_name,
                                                collection_name)
     stock_code_list = self.database.get_data(
         "stock", "basic_info", keys=["code"])["code"].to_list()
     for code in stock_code_list:
         _collection = self.database.get_collection(
             config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, code)
         _tmp_num_stat = 0
         for row in source_col.find():
             # Rows without the field are skipped instead of raising KeyError.
             related_codes = (row.get("RelatedStockCodes") or "").split(" ")
             if code in related_codes:
                 _collection.insert_one({
                     "Date": row["Date"],
                     "Url": row["Url"],
                     "Title": row["Title"],
                     "Article": row["Article"],
                     "OriDB": database_name,
                     "OriCOL": collection_name
                 })
                 _tmp_num_stat += 1
         logging.info(
             "there are {} news mentioned {} in {} collection ... ".format(
                 _tmp_num_stat, code, collection_name))
예제 #6
0
class NbdSpyder(Spyder):
    """Crawler for the article listing pages of www.nbd.com.cn.

    Fetched articles are written with ``self.db_obj.insert_data`` into
    ``database_name``/``collection_name`` together with the stock codes that
    ``self.tokenization`` finds in the article text.
    """

    def __init__(self, database_name, collection_name):
        super(NbdSpyder, self).__init__()
        self.db_obj = Database()
        self.col = self.db_obj.conn[database_name].get_collection(
            collection_name)
        # Number of consecutive failed fetches; drives the back-off delay.
        self.terminated_amount = 0
        self.db_name = database_name
        self.col_name = collection_name
        self.tokenization = Tokenization(
            import_module="jieba", user_dict=config.USER_DEFINED_DICT_PATH)

    def get_url_info(self, url):
        """Download an article page and extract its publish time and text.

        ``date`` is assembled from the tokens of the first ``<span
        class="time">`` element (tokens containing "-" followed by a space,
        tokens containing ":" appended as-is). ``article`` is the text of the
        ``<p>`` elements whose ``utils.count_chn(...)[1]`` score exceeds
        ``self.is_article_prob`` (presumably a Chinese-character ratio; the
        attribute is expected to be set by ``Spyder``), with HTML tags,
        ideographic spaces (U+3000) and repeated whitespace removed.

        :param url: article URL
        :return: ``[date, article]``, or ``False`` if the page could not be
            fetched/parsed (callers treat that as a server rejection)
        """
        try:
            bs = utils.html_parser(url)
        except Exception:
            return False
        date = ""
        for span in bs.find_all("span"):
            if "class" in span.attrs and span.text and span["class"] == [
                    "time"
            ]:
                for token in span.text.split():
                    if token.find("-") != -1:
                        date += token + " "
                    elif token.find(":") != -1:
                        date += token
                break
        article = ""
        for paragraph in bs.find_all("p"):
            html = str(paragraph)
            if utils.count_chn(html)[1] > self.is_article_prob:
                article += html
        # Strip tags in one pass; the old find/replace loop never terminated
        # when a ">" appeared before the first "<".
        article = re.sub(r"<[^>]*>", "", article)
        article = article.replace("\u3000", "")
        article = " ".join(re.split(" +|\n+", article)).strip()

        return [date, article]

    @staticmethod
    def _is_article_link(a):
        """Return True if listing-page anchor ``a`` points to an NBD article."""
        return bool("click-statistic" in a.attrs and a.string
                    and a["click-statistic"].find("Article_") != -1
                    and a["href"].find("http://www.nbd.com.cn/articles/") != -1)

    def _load_name_code_dict(self):
        """Return the {stock name: stock code} mapping from the basic-info data."""
        name_code_df = self.db_obj.get_data(
            config.STOCK_DATABASE_NAME,
            config.COLLECTION_NAME_STOCK_BASIC_INFO,
            keys=["name", "code"])
        return dict(name_code_df.values)

    def _get_url_info_with_retry(self, url):
        """Call ``get_url_info`` and retry with a growing back-off on failure.

        Each failure increments ``self.terminated_amount`` and sleeps
        ``60 * self.terminated_amount`` seconds before retrying. Once the
        counter exceeds ``config.NBD_MAX_REJECTED_AMOUNTS`` the URL is
        appended to ``config.RECORD_NBD_FAILED_URL_TXT_FILE_PATH`` and
        ``False`` is returned. The counter is reset after a successful fetch
        so that earlier rejections do not shorten the retries for later URLs.

        :return: ``[date, article]`` or ``False``
        """
        result = self.get_url_info(url)
        while not result:
            self.terminated_amount += 1
            if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS:
                # Keep a record of URLs that could never be fetched.
                with open(config.RECORD_NBD_FAILED_URL_TXT_FILE_PATH,
                          "a+") as file:
                    file.write("{}\n".format(url))
                logging.info(
                    "rejected by remote server longer than {} minutes, "
                    "and the failed url has been written in path {}".format(
                        config.NBD_MAX_REJECTED_AMOUNTS,
                        config.RECORD_NBD_FAILED_URL_TXT_FILE_PATH))
                return False
            logging.info(
                "rejected by remote server, request {} again after "
                "{} seconds...".format(url, 60 * self.terminated_amount))
            time.sleep(60 * self.terminated_amount)
            result = self.get_url_info(url)
        self.terminated_amount = 0
        return result

    def _retry_empty_article(self, url, date, article):
        """Re-fetch ``url`` with a lower text threshold while ``article`` is empty.

        ``self.is_article_prob`` is lowered by 0.1 per attempt while it is
        >= 0.1, and reset to 0.5 afterwards.

        :return: the (possibly updated) ``(date, article)`` pair
        """
        while article == "" and self.is_article_prob >= .1:
            self.is_article_prob -= .1
            result = self._get_url_info_with_retry(url)
            if not result:
                # Give up on this URL (unpacking False used to crash here).
                break
            date, article = result
        self.is_article_prob = .5
        return date, article

    def _save_article(self, date, title, url, article, name_code_dict):
        """Tag ``article`` with the stock codes it mentions and store it."""
        related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article(
            article, name_code_dict)
        data = {
            "Date": date,
            "Url": url,
            "Title": title,
            "Article": article,
            "RelatedStockCodes": " ".join(related_stock_codes_list),
        }
        self.db_obj.insert_data(self.db_name, self.col_name, data)
        logging.info("[SUCCESS] {} {} {}".format(date, title, url))

    def get_historical_news(self, start_page=684):
        """Crawl NBD listing pages into ``self.db_name``/``self.col_name``.

        With an empty collection, listing pages ``start_page`` down to 1 are
        crawled in full. Otherwise pages are walked from 1 upward and the
        crawl stops at the first fetched article whose date is not later than
        the newest stored ``Date``.

        :param start_page: listing page to start from on a full crawl
        """
        date_list = self.db_obj.get_data(self.db_name,
                                         self.col_name,
                                         keys=["Date"])["Date"].to_list()
        name_code_dict = self._load_name_code_dict()
        if len(date_list) == 0:
            # No history yet: crawl everything, starting from the oldest page.
            crawled_urls = set()
            for page_id in range(start_page, 0, -1):
                page_url = "{}/{}".format(
                    config.WEBSITES_LIST_TO_BE_CRAWLED_NBD, page_id)
                bs = utils.html_parser(page_url)
                for a in bs.find_all("a"):
                    if not self._is_article_link(a) or a["href"] in crawled_urls:
                        continue
                    # Remember the link so repeated anchors are not re-inserted.
                    crawled_urls.add(a["href"])
                    result = self._get_url_info_with_retry(a["href"])
                    if not result:
                        logging.info("[FAILED] {} {}".format(
                            a.string, a["href"]))
                        continue
                    date, article = result
                    date, article = self._retry_empty_article(
                        a["href"], date, article)
                    if article != "":
                        self._save_article(date, a.string, a["href"],
                                           article, name_code_dict)
        else:
            # Incremental update: newest pages first, stop at stored news.
            start_date = max(date_list)
            page_start_id = 1
            is_stop = False
            while not is_stop:
                page_url = "{}/{}".format(
                    config.WEBSITES_LIST_TO_BE_CRAWLED_NBD, page_start_id)
                bs = utils.html_parser(page_url)
                for a in bs.find_all("a"):
                    if not self._is_article_link(a):
                        continue
                    result = self._get_url_info_with_retry(a["href"])
                    if not result:
                        logging.info("[FAILED] {} {}".format(
                            a.string, a["href"]))
                        continue
                    date, article = result
                    if date <= start_date:
                        is_stop = True
                        break
                    date, article = self._retry_empty_article(
                        a["href"], date, article)
                    if article != "":
                        self._save_article(date, a.string, a["href"],
                                           article, name_code_dict)
                if not is_stop:
                    page_start_id += 1

    def get_realtime_news(self, interval=60):
        """Poll listing page 1 forever and store articles newer than the DB.

        :param interval: seconds to sleep between two polls
        """
        page_url = "{}/1".format(config.WEBSITES_LIST_TO_BE_CRAWLED_NBD)
        logging.info(
            "start real-time crawling of URL -> {}, request every {} secs ... "
            .format(page_url, interval))
        name_code_dict = self._load_name_code_dict()
        crawled_urls = []
        date_list = self.db_obj.get_data(self.db_name,
                                         self.col_name,
                                         keys=["Date"])["Date"].to_list()
        # An empty collection used to crash max(); then any dated article is new.
        latest_date = max(date_list, default="")
        while True:
            # Keep only the 100 most recently stored URLs to bound memory.
            if len(crawled_urls) > 100:
                del crawled_urls[:-100]
            bs = utils.html_parser(page_url)
            for a in bs.find_all("a"):
                if not self._is_article_link(a) or a["href"] in crawled_urls:
                    continue
                result = self._get_url_info_with_retry(a["href"])
                if not result:
                    logging.info("[FAILED] {} {}".format(
                        a.string, a["href"]))
                    continue
                date, article = result
                if date > latest_date:
                    date, article = self._retry_empty_article(
                        a["href"], date, article)
                    if article != "":
                        self._save_article(date, a.string, a["href"],
                                           article, name_code_dict)
                        crawled_urls.append(a["href"])
            time.sleep(interval)
예제 #7
0
class TopicModelling(object):
    """Classify stock news with a gensim topic model plus a trained classifier.

    Raw articles are tokenized with jieba (through ``Tokenization``), turned
    into a gensim dictionary and bag-of-words (BoW) vectors, projected into an
    LSI or LDA topic space, and the topic vectors of labelled historical news
    are used to train ``Classifier``, which then labels an unseen article.
    """

    def __init__(self):
        self.tokenization = Tokenization(
            import_module="jieba",
            user_dict=config.USER_DEFINED_DICT_PATH,
            chn_stop_words_dir=config.CHN_STOP_WORDS_PATH)
        self.database = Database()
        self.classifier = Classifier()

    def create_dictionary(self,
                          raw_documents_list,
                          save_path=None,
                          is_saved=False):
        """Build a gensim dictionary (token -> unique id) from raw documents.

        Tokens that occur in only one document are removed from both the
        dictionary and the returned token lists. Documents left with no
        tokens are dropped entirely.

        :param raw_documents_list: raw corpus, one text per element,
            e.g. ["<news text 1>", "<news text 2>", ...]
        :param save_path: where to save the ``corpora.Dictionary``.
        :param is_saved: save only when this is true and ``save_path`` is set.
        :return: ``(dictionary, documents_token_list)``. Because emptied
            documents are dropped, ``documents_token_list`` may be shorter
            than ``raw_documents_list``.
        """
        documents_token_list = [
            self.tokenization.cut_words(doc) for doc in raw_documents_list
        ]
        _dict = corpora.Dictionary(documents_token_list)
        # Ids of the tokens that appear in exactly one document.
        once_ids = [
            tokenid for tokenid, docfreq in _dict.dfs.items() if docfreq == 1
        ]
        # A set keeps the per-token membership test O(1). With a list, the
        # filtering cost grew with (#tokens x #singleton terms).
        once_items = {_dict[tokenid] for tokenid in once_ids}
        documents_token_list = [
            [token for token in token_list if token not in once_items]
            for token_list in documents_token_list
        ]
        # Edge case: a document whose tokens were all singletons is now empty.
        documents_token_list = [
            token_list for token_list in documents_token_list if token_list
        ]
        _dict.filter_tokens(once_ids)
        # Close the gaps in the id sequence left by the removed tokens.
        _dict.compactify()
        if is_saved and save_path:
            _dict.save(save_path)
            logging.info(
                "new generated dictionary saved in path -> {} ...".format(
                    save_path))

        return _dict, documents_token_list

    def renew_dictionary(self,
                         old_dict_path,
                         new_raw_documents_list,
                         new_dict_path=None,
                         is_saved=False):
        """Load a saved dictionary and extend it with tokens of new documents.

        Unlike ``create_dictionary``, no singleton filtering is applied.

        :param old_dict_path: path of a previously saved ``corpora.Dictionary``.
        :param new_raw_documents_list: raw texts whose tokens are added.
        :param new_dict_path: if given, save there instead of over
            ``old_dict_path``.
        :param is_saved: whether to save the updated dictionary.
        :return: ``(dictionary, documents_token_list)`` for the new documents.
        """
        documents_token_list = [
            self.tokenization.cut_words(doc) for doc in new_raw_documents_list
        ]
        _dict = corpora.Dictionary.load(old_dict_path)
        _dict.add_documents(documents_token_list)
        target_path = new_dict_path or old_dict_path
        if is_saved:
            _dict.save(target_path)
            logging.info(
                "updated dictionary by another raw documents serialized in {} ... "
                .format(target_path))

        return _dict, documents_token_list

    def create_bag_of_word_representation(self,
                                          raw_documents_list,
                                          old_dict_path=None,
                                          new_dict_path=None,
                                          bow_vector_save_path=None,
                                          is_saved_dict=False):
        """Tokenize documents, build or extend a dictionary, compute BoW vectors.

        :param raw_documents_list: raw texts to vectorize.
        :param old_dict_path: if given, load this dictionary and extend it with
            the documents' unseen tokens; otherwise build a fresh dictionary.
        :param new_dict_path: save path for the new or extended dictionary.
        :param bow_vector_save_path: if given, serialize the BoW vectors with
            ``corpora.MmCorpus`` (Matrix Market format).
        :param is_saved_dict: whether to save the dictionary. This is now
            honoured in both branches; previously the extend branch ignored it.
        :return: ``(documents_token_list, dictionary, bow_vector)``.
        """
        if old_dict_path:
            # Extend the existing dictionary with tokens it has not seen yet.
            corpora_dictionary, documents_token_list = self.renew_dictionary(
                old_dict_path,
                raw_documents_list,
                new_dict_path=new_dict_path,
                is_saved=is_saved_dict)
        else:
            start_time = time.time()
            corpora_dictionary, documents_token_list = self.create_dictionary(
                raw_documents_list,
                save_path=new_dict_path,
                is_saved=is_saved_dict)
            end_time = time.time()
            logging.info(
                "there are {} mins spent to create a new dictionary ... ".
                format((end_time - start_time) / 60))
        # BoW vector of every document under the (new or extended) dictionary.
        start_time = time.time()
        bow_vector = [
            corpora_dictionary.doc2bow(doc_token)
            for doc_token in documents_token_list
        ]
        end_time = time.time()
        logging.info(
            "there are {} mins spent to calculate bow-vector ... ".format(
                (end_time - start_time) / 60))
        if bow_vector_save_path:
            corpora.MmCorpus.serialize(bow_vector_save_path, bow_vector)

        return documents_token_list, corpora_dictionary, bow_vector

    @staticmethod
    def transform_vectorized_corpus(corpora_dictionary,
                                    bow_vector,
                                    model_type="lda",
                                    model_save_path=None):
        """Train a fresh ``lsi``/``lda``/``tfidf`` model and transform the corpus.

        :param corpora_dictionary: dictionary used as ``id2word`` for LSI/LDA.
        :param bow_vector: BoW corpus to train on and transform.
        :param model_type: ``"lsi"``, ``"lda"`` or ``"tfidf"``.
        :param model_save_path: if given, the trained model is saved there.
        :return: the transformed corpus, or None for an unknown ``model_type``.
        """
        if model_type == "lsi":
            # LSI projects (preferably TF-IDF weighted) vectors into a
            # low-dimensional latent space; for real corpora 200-500 target
            # dimensions are considered the "gold standard".
            tfidf_vector = models.TfidfModel(bow_vector)[bow_vector]
            model = models.LsiModel(tfidf_vector,
                                    id2word=corpora_dictionary,
                                    num_topics=config.TOPIC_NUMBER)
            model_vector = model[tfidf_vector]
        elif model_type == "lda":
            model = models.LdaModel(bow_vector,
                                    id2word=corpora_dictionary,
                                    num_topics=config.TOPIC_NUMBER)
            model_vector = model[bow_vector]
        elif model_type == "tfidf":
            model = models.TfidfModel(bow_vector)
            model_vector = model[bow_vector]  # transform the whole corpus
        else:
            # Unknown model type: nothing is trained.
            return None
        if model_save_path:
            model.save(model_save_path)

        return model_vector

    def _fetch_labelled_news(self, database_name, collection_name,
                             label_name):
        """Return ``(articles, labels)`` for rows with a non-empty ``label_name``."""
        historical_raw_documents_list = []
        labels = []
        for row in self.database.get_collection(database_name,
                                                collection_name).find():
            if row.get(label_name, "") != "":
                historical_raw_documents_list.append(row["Article"])
                labels.append(row[label_name])
        logging.info(
            "fetch symbol '{}' historical news with label '{}' from [DB:'{}' - COL:'{}'] ... "
            .format(collection_name, label_name, database_name,
                    collection_name))
        return historical_raw_documents_list, labels

    def _load_or_build_historical_bow_vector(self,
                                             historical_raw_documents_list,
                                             ori_dict_path, bowvec_save_path):
        """Return the historical BoW vectors as a list, (re)building files as needed.

        - Both the dictionary and the BoW file exist: load the BoW vectors.
        - Otherwise: build a fresh dictionary and BoW vectors from the raw
          documents and save both. A BoW file left over without its
          dictionary would use another dictionary's ids, so it is overwritten.
        """
        dict_exists = os.path.exists(ori_dict_path)
        if dict_exists and os.path.exists(bowvec_save_path):
            # Iterating an MmCorpus yields each document's (id, weight) list.
            historical_bow_vec = list(corpora.MmCorpus(bowvec_save_path))
            logging.info(
                "both historical news dictionary and bow-vector existed, load historical bow-vector to memory ... "
            )
            return historical_bow_vec

        _, _, historical_bow_vec = self.create_bag_of_word_representation(
            historical_raw_documents_list,
            new_dict_path=ori_dict_path,
            bow_vector_save_path=bowvec_save_path,
            is_saved_dict=True)
        if dict_exists:
            logging.info(
                "historical news dictionary existed, which saved in path -> {}, but not the historical bow-vector"
                " ... ".format(ori_dict_path))
        else:
            logging.info(
                "create dictionary of historical news, and serialized in path -> {} ... "
                .format(ori_dict_path))
            logging.info(
                "create bow-vector of historical news, and serialized in path -> {} ... "
                .format(bowvec_save_path))
        return historical_bow_vec

    def _build_topic_model_vector(self, dictionary, bow_vector,
                                  topic_model_type):
        """Fit the requested topic model on ``bow_vector`` and transform it.

        :param topic_model_type: ``"lsi"`` (trained on TF-IDF weighted vectors)
            or ``"lda"`` (trained on raw BoW); validated by the caller.
        :return: ``gensim.interfaces.TransformedCorpus`` of topic vectors.
        """
        if topic_model_type == "lsi":
            start_time = time.time()
            tfidf_model_vector = self.transform_vectorized_corpus(
                dictionary, bow_vector, model_type="tfidf")
            end_time = time.time()
            logging.info(
                "regenerated TF-IDF model vector by updated dictionary and updated bow-vector, "
                "which took {} mins ... ".format((end_time - start_time) / 60))

            start_time = time.time()
            model = models.LsiModel(tfidf_model_vector,
                                    id2word=dictionary,
                                    num_topics=config.TOPIC_NUMBER)
            model_vector = model[tfidf_model_vector]
            end_time = time.time()
            logging.info(
                "regenerated LSI model vector space by updated TF-IDF model vector space, "
                "which took {} mins ... ".format((end_time - start_time) / 60))
        else:
            start_time = time.time()
            model_vector = self.transform_vectorized_corpus(dictionary,
                                                            bow_vector,
                                                            model_type="lda")
            end_time = time.time()
            logging.info(
                "regenerated LDA model vector space by updated dictionary and bow-vector, "
                "which took {} mins ... ".format((end_time - start_time) / 60))
        return model_vector

    def classify_stock_news(self,
                            unseen_raw_document,
                            database_name,
                            collection_name,
                            label_name="60DaysLabel",
                            topic_model_type="lda",
                            classifier_model="svm",
                            ori_dict_path=None,
                            bowvec_save_path=None,
                            is_saved_bow_vector=False):
        """Predict the label of an unseen news article from labelled history.

        Steps:
          1. Load every labelled historical article of the collection.
          2. Load, or build and save, the historical dictionary and BoW vectors.
          3. Extend the saved dictionary with the unseen article's tokens.
          4. Fit an LSI or LDA model on the historical plus unseen BoW vectors.
          5. Train ``classifier_model`` on the historical topic vectors and
             predict the unseen article (the last row).

        :param unseen_raw_document: raw text of the article to classify.
        :param database_name: database holding the labelled news.
        :param collection_name: collection holding the labelled news.
        :param label_name: field holding each article's label; rows without
            it, or with an empty value, are skipped.
        :param topic_model_type: ``"lsi"`` or ``"lda"``.
        :param classifier_model: model type forwarded to ``Classifier.train``.
        :param ori_dict_path: required path of the historical dictionary. It
            is created if missing and always re-saved with the unseen tokens.
        :param bowvec_save_path: required path of the historical BoW vectors
            (Matrix Market). It is created if missing.
        :param is_saved_bow_vector: if true, overwrite ``bowvec_save_path``
            with the historical + unseen BoW vectors.
        :return: the predicted label, decoded back to its original value.
        :raises ValueError: on an unsupported ``topic_model_type``, a missing
            ``ori_dict_path``/``bowvec_save_path``, or when no labelled news
            is found.
        """
        # Validate before touching the database. Previously a None path hit
        # os.path.exists(None) (TypeError), and an unknown model type
        # surfaced late as an unbound ``model_vector``.
        if topic_model_type not in ("lsi", "lda"):
            raise ValueError(
                "unsupported topic_model_type {!r}; expected 'lsi' or 'lda'".
                format(topic_model_type))
        if not ori_dict_path or not bowvec_save_path:
            raise ValueError(
                "ori_dict_path and bowvec_save_path must both be provided")

        historical_raw_documents_list, raw_labels = self._fetch_labelled_news(
            database_name, collection_name, label_name)
        if not historical_raw_documents_list:
            raise ValueError(
                "no news labelled with '{}' found in [DB:'{}' - COL:'{}']".
                format(label_name, database_name, collection_name))

        le = preprocessing.LabelEncoder()
        encoded_labels = le.fit_transform(raw_labels)
        logging.info(
            "encode historical label list by sklearn preprocessing for training ... "
        )
        # Sorted distinct label values; index i decodes class id i, e.g.
        # [neutral, bullish, bearish] (stored as Chinese strings) -> [0, 1, 2].
        label_name_list = le.classes_

        historical_bow_vec = self._load_or_build_historical_bow_vector(
            historical_raw_documents_list, ori_dict_path, bowvec_save_path)
        if len(historical_bow_vec) != len(encoded_labels):
            # create_dictionary drops documents that end up empty, and a saved
            # BoW file may already hold previously appended unseen news (see
            # is_saved_bow_vector), so rows and labels can drift apart.
            logging.warning(
                "historical bow-vector has {} rows but {} labels were fetched; "
                "training rows and labels may be misaligned ... ".format(
                    len(historical_bow_vec), len(encoded_labels)))

        # Extend the saved dictionary with the unseen article's tokens; this
        # re-saves it to ori_dict_path.
        start_time = time.time()
        updated_dictionary, unseen_documents_token_list = self.renew_dictionary(
            ori_dict_path, [unseen_raw_document], is_saved=True)
        end_time = time.time()
        logging.info(
            "renew dictionary with unseen news tokens, and serialized in path -> {}, "
            "which took {} mins ... ".format(ori_dict_path,
                                             (end_time - start_time) / 60))

        unseen_bow_vector = [
            updated_dictionary.doc2bow(doc_token)
            for doc_token in unseen_documents_token_list
        ]
        # Historical rows first, the unseen article last (predicted below).
        combined_bow_vector = list(historical_bow_vec) + unseen_bow_vector
        # The in-memory value is a list; reloading the serialized file later
        # yields a gensim.corpora.mmcorpus.MmCorpus instead.
        if is_saved_bow_vector:
            corpora.MmCorpus.serialize(bowvec_save_path, combined_bow_vector)
        logging.info(
            "combined bow vector(type -> 'list') generated by historical news with unseen bow "
            "vector to create a new one ... ")

        model_vector = self._build_topic_model_vector(updated_dictionary,
                                                      combined_bow_vector,
                                                      topic_model_type)

        # Convert the TransformedCorpus into a (documents x topics) matrix.
        # ``obj`` is the fitted model; one column per topic. Sizing by the
        # vocabulary (num_terms), as before, built a mostly-zero
        # documents x vocabulary dense matrix and could raise IndexError when
        # the topic count exceeded the vocabulary size.
        start_time = time.time()
        latest_matrix = corpus2dense(model_vector,
                                     num_terms=model_vector.obj.num_topics).T
        end_time = time.time()
        logging.info(
            "transform {} model vector space to numpy.adarray, "
            "which took {} mins ... ".format(topic_model_type.upper(),
                                             (end_time - start_time) / 60))

        # Train the classifier on the historical rows (all but the last).
        start_time = time.time()
        train_x, train_y, test_x, test_y = utils.generate_training_set(
            latest_matrix[:-1, :], encoded_labels)
        clf = self.classifier.train(train_x,
                                    train_y,
                                    test_x,
                                    test_y,
                                    model_type=classifier_model)
        end_time = time.time()
        logging.info(
            "finished training by sklearn {} using latest {} model vector space, which took {} mins ... "
            .format(classifier_model.upper(), topic_model_type.upper(),
                    (end_time - start_time) / 60))

        # Predict the unseen article (last row) and decode its class id.
        label_id = clf.predict(latest_matrix[-1, :].reshape(1, -1))[0]

        return label_name_list[label_id]
예제 #8
0
class TopicModelling(object):
    """Build gensim dictionaries, BoW vectors and TF-IDF/LSI/LDA models for news."""

    def __init__(self):
        self.tokenization = Tokenization(
            import_module="jieba",
            user_dict=config.USER_DEFINED_DICT_PATH,
            chn_stop_words_dir=config.CHN_STOP_WORDS_PATH)

    def create_dictionary(self, raw_documents_list, savepath=None):
        """Build a gensim dictionary (token -> unique id) from raw documents.

        Tokens that occur in only one document are removed from both the
        dictionary and the returned token lists. Documents left with no
        tokens are dropped entirely.

        :param raw_documents_list: raw corpus, one text per element,
            e.g. ["<news text 1>", "<news text 2>", ...]
        :param savepath: if given, the ``corpora.Dictionary`` is saved there.
        :return: ``(dictionary, documents_token_list)``.
        """
        documents_token_list = [
            self.tokenization.cut_words(doc) for doc in raw_documents_list
        ]
        _dict = corpora.Dictionary(documents_token_list)
        # Ids of the tokens that appear in exactly one document.
        once_ids = [
            tokenid for tokenid, docfreq in _dict.dfs.items() if docfreq == 1
        ]
        # A set keeps the per-token membership test O(1). With a list, the
        # filtering cost grew with (#tokens x #singleton terms).
        once_items = {_dict[tokenid] for tokenid in once_ids}
        documents_token_list = [
            [token for token in token_list if token not in once_items]
            for token_list in documents_token_list
        ]
        # Edge case: a document whose tokens were all singletons is now empty.
        documents_token_list = [
            token_list for token_list in documents_token_list if token_list
        ]
        _dict.filter_tokens(once_ids)
        # Close the gaps in the id sequence left by the removed tokens.
        _dict.compactify()
        if savepath:
            _dict.save(savepath)
        return _dict, documents_token_list

    def create_bag_of_word_representation(self,
                                          raw_documents_list,
                                          dict_save_path=None,
                                          bow_vector_save_path=None):
        """Build a fresh dictionary and the BoW vector of every document.

        :param dict_save_path: if given, the dictionary is saved there.
        :param bow_vector_save_path: if given, the BoW vectors are serialized
            with ``corpora.MmCorpus`` (Matrix Market format).
        :return: ``(documents_token_list, dictionary, bow_vector)``.
        """
        corpora_dictionary, documents_token_list = self.create_dictionary(
            raw_documents_list, savepath=dict_save_path)
        bow_vector = [
            corpora_dictionary.doc2bow(doc_token)
            for doc_token in documents_token_list
        ]
        if bow_vector_save_path:
            corpora.MmCorpus.serialize(bow_vector_save_path, bow_vector)
        return documents_token_list, corpora_dictionary, bow_vector

    def transform_vectorized_corpus(self,
                                    corpora_dictionary,
                                    bow_vector,
                                    model_type="lda",
                                    model_save_path=None):
        """Train a fresh ``lsi``/``lda``/``tfidf`` model and transform the corpus.

        :param corpora_dictionary: dictionary used as ``id2word`` for LSI/LDA.
        :param bow_vector: BoW corpus to train on and transform.
        :param model_type: ``"lsi"``, ``"lda"`` or ``"tfidf"``.
        :param model_save_path: if given, the trained model is saved there.
        :return: the transformed corpus, or None for an unknown ``model_type``.
        """
        if model_type == "lsi":
            # LSI projects (preferably TF-IDF weighted) vectors into a
            # low-dimensional latent space; for real corpora 200-500 target
            # dimensions are considered the "gold standard".
            tfidf_vector = models.TfidfModel(bow_vector)[bow_vector]
            model = models.LsiModel(tfidf_vector,
                                    id2word=corpora_dictionary,
                                    num_topics=config.TOPIC_NUMBER)
            model_vector = model[tfidf_vector]
        elif model_type == "lda":
            model = models.LdaModel(bow_vector,
                                    id2word=corpora_dictionary,
                                    num_topics=config.TOPIC_NUMBER)
            model_vector = model[bow_vector]
        elif model_type == "tfidf":
            model = models.TfidfModel(bow_vector)
            model_vector = model[bow_vector]  # transform the whole corpus
        else:
            # Unknown model type: nothing is trained.
            return None
        if model_save_path:
            model.save(model_save_path)
        return model_vector

    def add_documents_to_serialized_model(self,
                                          old_model_path,
                                          another_raw_documents_list,
                                          latest_model_path=None,
                                          model_type="lsi"):
        """Load a saved LSI/LDA model and save it again, optionally elsewhere.

        NOTE(review): incremental (online) training is not performed yet. The
        ``add_documents`` call is disabled and ``another_raw_documents_list``
        is unused, so this currently only re-saves the loaded model.

        :param old_model_path: path of the saved model.
        :param another_raw_documents_list: currently unused (see note).
        :param latest_model_path: if given, save there instead of over
            ``old_model_path``.
        :param model_type: ``"lsi"`` or ``"lda"``.
        :raises ValueError: for any other ``model_type``. Previously this
            failed with an UnboundLocalError on ``loaded_model``.
        :raises Exception: if ``old_model_path`` does not exist.
        """
        if model_type not in ("lsi", "lda"):
            raise ValueError(
                "unsupported model_type {!r}; expected 'lsi' or 'lda'".format(
                    model_type))
        if not os.path.exists(old_model_path):
            raise Exception(
                "the file path {} does not exist ... ".format(old_model_path))
        if model_type == "lsi":
            loaded_model = models.LsiModel.load(old_model_path)
        else:
            loaded_model = models.LdaModel.load(old_model_path)

        # TODO: loaded_model.add_documents(another_tfidf_corpus)

        loaded_model.save(latest_model_path or old_model_path)

    def load_transform_model(self, model_path):
        """Load a saved model, picking its class from a marker in ``model_path``.

        The markers ".tfidf", ".lsi" and ".lda" are checked in that order as
        substrings anywhere in the path.

        :return: the loaded model, or None if no marker matches.
        """
        if ".tfidf" in model_path:
            return models.TfidfModel.load(model_path)
        if ".lsi" in model_path:
            return models.LsiModel.load(model_path)
        if ".lda" in model_path:
            return models.LdaModel.load(model_path)
        return None
예제 #9
0
# The commented-out block below appears to be the crawling step; it is
# currently disabled and kept for reference.
# for url_to_be_crawled, type_chn in config.WEBSITES_LIST_TO_BE_CRAWLED_CNSTOCK.items():
#     logging.info("start crawling {} ...".format(url_to_be_crawled))
#     cnstock_spyder.get_historical_news(url_to_be_crawled, category_chn=type_chn)
#     logging.info("finished ...")
#     time.sleep(30)
#
# jrj_spyder = JrjSpyder(config.DATABASE_NAME, config.COLLECTION_NAME_JRJ)
# jrj_spyder.get_historical_news(config.WEBSITES_LIST_TO_BE_CRAWLED_JRJ, "2020-12-04", "2020-12-08")
#
# nbd_spyder = NbdSpyder(config.DATABASE_NAME, config.COLLECTION_NAME_NBD)
# nbd_spyder.get_historical_news(684)

# 2. Extract the stocks mentioned in each news article and store their stock
#    codes in a new column of the collection.
#    NOTE(review): the user dictionary path is hard-coded here, while the
#    spider classes use config.USER_DEFINED_DICT_PATH; confirm they match.
from Leorio.tokenization import Tokenization

tokenization = Tokenization(import_module="jieba",
                            user_dict="./Leorio/financedict.txt")
tokenization.update_news_database_rows(config.DATABASE_NAME, "cnstock")
# tokenization.update_news_database_rows(config.DATABASE_NAME, "nbd")
# tokenization.update_news_database_rows(config.DATABASE_NAME, "jrj")

# 3. De-duplicate the historical data.
#    NOTE(review): the database name is the literal "finnewshunter" here but
#    config.DATABASE_NAME above; presumably the same value, so verify.
from Killua.deduplication import Deduplication

Deduplication("finnewshunter", "cnstock").run()
# Deduplication("finnewshunter", "nbd").run()
# Deduplication("finnewshunter", "jrj").run()  # for now only jrj needs de-duplication

# 4. Drop rows of the historical data that contain null values
#    (DeNull is imported, but its call is currently commented out).
from Killua.denull import DeNull

# DeNull("finnewshunter", "cnstock").run()
class CnStockSpyder(Spyder):
    def __init__(self, database_name, collection_name):
        """Set up storage, tokenizer and cache handles for the cnstock spider.

        :param database_name: database that receives the crawled news.
        :param collection_name: collection inside ``database_name``.
        """
        super(CnStockSpyder, self).__init__()
        self.db_obj = Database()
        # Direct collection handle (the conn[...]/get_collection API looks like
        # a pymongo client; verify against Database).
        target_database = self.db_obj.conn[database_name]
        self.col = target_database.get_collection(collection_name)
        # Number of server rejections so far; the retry sleep in
        # get_historical_news grows linearly with it.
        self.terminated_amount = 0
        self.db_name = database_name
        self.col_name = collection_name
        tokenizer_options = dict(import_module="jieba",
                                 user_dict=config.USER_DEFINED_DICT_PATH)
        self.tokenization = Tokenization(**tokenizer_options)
        self.redis_client = redis.StrictRedis(
            host=config.REDIS_IP,
            port=config.REDIS_PORT,
            db=config.CACHE_NEWS_REDIS_DB_ID,
        )

    def get_url_info(self, url):
        """Download an article page and extract its date and body text.

        :param url: article URL, fetched and parsed with ``utils.html_parser``.
        :return: ``[date, article]``.
            - ``date`` is the text of the first ``<span>`` whose class list is
              exactly ``["timer"]`` (``""`` if none).
            - ``article`` is the markup-free text of every ``<p>`` whose
              ``utils.count_chn(...)[1]`` score exceeds ``self.is_article_prob``
              (``""`` if none qualifies).
            Returns ``False`` when fetching/parsing raises; callers treat
            that as a rejection and retry after a delay.
        """
        try:
            bs = utils.html_parser(url)
        except Exception:
            # Deliberate best-effort: the caller backs off and retries on False.
            return False
        span_list = bs.find_all("span")
        part = bs.find_all("p")
        date = ""
        for span in span_list:
            if "class" in span.attrs and span["class"] == ["timer"]:
                date = span.text
                break
        # Keep paragraphs whose count_chn score clears the threshold.
        # Presumably the score is the share of Chinese characters (TODO
        # confirm in utils.count_chn).
        article = "".join(
            str(paragraph) for paragraph in part
            if utils.count_chn(str(paragraph))[1] > self.is_article_prob)
        # Remove every "<...>" tag in one linear pass. The previous
        # find/replace loop was quadratic. It also never terminated when a
        # ">" preceded the first "<", because the slice between them was
        # empty and nothing was ever removed.
        article = re.sub(r"<[^>]*>", "", article)
        # Drop ideographic (full-width) spaces; one replace removes them all.
        article = article.replace("\u3000", "")
        # Collapse runs of spaces/newlines into single spaces.
        article = " ".join(re.split(" +|\n+", article)).strip()

        return [date, article]

    def get_historical_news(self, url, category_chn=None, start_date=None):
        """
        :param url: 爬虫网页
        :param category_chn: 所属类别, 中文字符串, 包括'公司聚焦', '公告解读', '公告快讯', '利好公告'
        :param start_date: 数据库中category_chn类别新闻最近一条数据的时间
        """
        assert category_chn is not None
        driver = webdriver.Chrome(executable_path=config.CHROME_DRIVER)
        btn_more_text = ""
        crawled_urls_list = self.extract_data(["Url"])[0]
        logging.info("historical data length -> {} ... ".format(
            len(crawled_urls_list)))
        # crawled_urls_list = []
        driver.get(url)
        name_code_df = self.db_obj.get_data(
            config.STOCK_DATABASE_NAME,
            config.COLLECTION_NAME_STOCK_BASIC_INFO,
            keys=["name", "code"])
        name_code_dict = dict(name_code_df.values)
        if start_date is None:
            while btn_more_text != "没有更多":
                more_btn = driver.find_element_by_id('j_more_btn')
                btn_more_text = more_btn.text
                logging.info("1-{}".format(more_btn.text))
                if btn_more_text == "加载更多":
                    more_btn.click()
                    time.sleep(random.random())  # sleep random time less 1s
                elif btn_more_text == "加载中...":
                    time.sleep(random.random() + 2)
                    more_btn = driver.find_element_by_id('j_more_btn')
                    btn_more_text = more_btn.text
                    logging.info("2-{}".format(more_btn.text))
                    if btn_more_text == "加载更多":
                        more_btn.click()
                else:
                    more_btn.click()
                    break
            bs = BeautifulSoup(driver.page_source, "html.parser")
            for li in bs.find_all("li", attrs={"class": ["newslist"]}):
                a = li.find_all("h2")[0].find("a")
                if a["href"] not in crawled_urls_list:
                    result = self.get_url_info(a["href"])
                    while not result:
                        self.terminated_amount += 1
                        if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                            # 始终无法爬取的URL保存起来
                            with open(
                                    config.
                                    RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH,
                                    "a+") as file:
                                file.write("{}\n".format(a["href"]))
                            logging.info(
                                "rejected by remote server longer than {} minutes, "
                                "and the failed url has been written in path {}"
                                .format(
                                    config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                                    config.
                                    RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                            break
                        logging.info(
                            "rejected by remote server, request {} again after "
                            "{} seconds...".format(a["href"], 60 *
                                                   self.terminated_amount))
                        time.sleep(60 * self.terminated_amount)
                        result = self.get_url_info(a["href"])
                    if not result:
                        # 爬取失败的情况
                        logging.info("[FAILED] {} {}".format(
                            a["title"], a["href"]))
                    else:
                        # 有返回但是article为null的情况
                        date, article = result
                        while article == "" and self.is_article_prob >= .1:
                            self.is_article_prob -= .1
                            result = self.get_url_info(a["href"])
                            while not result:
                                self.terminated_amount += 1
                                if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                                    # 始终无法爬取的URL保存起来
                                    with open(
                                            config.
                                            RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH,
                                            "a+") as file:
                                        file.write("{}\n".format(a["href"]))
                                    logging.info(
                                        "rejected by remote server longer than {} minutes, "
                                        "and the failed url has been written in path {}"
                                        .format(
                                            config.
                                            CNSTOCK_MAX_REJECTED_AMOUNTS,
                                            config.
                                            RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH
                                        ))
                                    break
                                logging.info(
                                    "rejected by remote server, request {} again after "
                                    "{} seconds...".format(
                                        a["href"],
                                        60 * self.terminated_amount))
                                time.sleep(60 * self.terminated_amount)
                                result = self.get_url_info(a["href"])
                            date, article = result
                        self.is_article_prob = .5
                        if article != "":
                            related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article(
                                article, name_code_dict)
                            data = {
                                "Date":
                                date,
                                "Category":
                                category_chn,
                                "Url":
                                a["href"],
                                "Title":
                                a["title"],
                                "Article":
                                article,
                                "RelatedStockCodes":
                                " ".join(related_stock_codes_list)
                            }
                            # self.col.insert_one(data)
                            self.db_obj.insert_data(self.db_name,
                                                    self.col_name, data)
                            logging.info("[SUCCESS] {} {} {}".format(
                                date, a["title"], a["href"]))
        else:
            # 当start_date不为None时,补充历史数据
            is_click_button = True
            start_get_url_info = False
            tmp_a = None
            while is_click_button:
                bs = BeautifulSoup(driver.page_source, "html.parser")
                for li in bs.find_all("li", attrs={"class": ["newslist"]}):
                    a = li.find_all("h2")[0].find("a")
                    if tmp_a is not None and a["href"] != tmp_a:
                        continue
                    elif tmp_a is not None and a["href"] == tmp_a:
                        start_get_url_info = True
                    if start_get_url_info:
                        date, _ = self.get_url_info(a["href"])
                        if date <= start_date:
                            is_click_button = False
                            break
                tmp_a = a["href"]
                if is_click_button:
                    more_btn = driver.find_element_by_id('j_more_btn')
                    more_btn.click()
            # 从一开始那条新闻到tmp_a都是新增新闻,不包括tmp_a
            bs = BeautifulSoup(driver.page_source, "html.parser")
            for li in bs.find_all("li", attrs={"class": ["newslist"]}):
                a = li.find_all("h2")[0].find("a")
                if a["href"] != tmp_a:
                    result = self.get_url_info(a["href"])
                    while not result:
                        self.terminated_amount += 1
                        if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                            # 始终无法爬取的URL保存起来
                            with open(
                                    config.
                                    RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH,
                                    "a+") as file:
                                file.write("{}\n".format(a["href"]))
                            logging.info(
                                "rejected by remote server longer than {} minutes, "
                                "and the failed url has been written in path {}"
                                .format(
                                    config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                                    config.
                                    RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                            break
                        logging.info(
                            "rejected by remote server, request {} again after "
                            "{} seconds...".format(a["href"], 60 *
                                                   self.terminated_amount))
                        time.sleep(60 * self.terminated_amount)
                        result = self.get_url_info(a["href"])
                    if not result:
                        # 爬取失败的情况
                        logging.info("[FAILED] {} {}".format(
                            a["title"], a["href"]))
                    else:
                        # 有返回但是article为null的情况
                        date, article = result
                        while article == "" and self.is_article_prob >= .1:
                            self.is_article_prob -= .1
                            result = self.get_url_info(a["href"])
                            while not result:
                                self.terminated_amount += 1
                                if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                                    # 始终无法爬取的URL保存起来
                                    with open(
                                            config.
                                            RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH,
                                            "a+") as file:
                                        file.write("{}\n".format(a["href"]))
                                    logging.info(
                                        "rejected by remote server longer than {} minutes, "
                                        "and the failed url has been written in path {}"
                                        .format(
                                            config.
                                            CNSTOCK_MAX_REJECTED_AMOUNTS,
                                            config.
                                            RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH
                                        ))
                                    break
                                logging.info(
                                    "rejected by remote server, request {} again after "
                                    "{} seconds...".format(
                                        a["href"],
                                        60 * self.terminated_amount))
                                time.sleep(60 * self.terminated_amount)
                                result = self.get_url_info(a["href"])
                            date, article = result
                        self.is_article_prob = .5
                        if article != "":
                            related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article(
                                article, name_code_dict)
                            data = {
                                "Date":
                                date,
                                "Category":
                                category_chn,
                                "Url":
                                a["href"],
                                "Title":
                                a["title"],
                                "Article":
                                article,
                                "RelatedStockCodes":
                                " ".join(related_stock_codes_list)
                            }
                            # self.col.insert_one(data)
                            self.db_obj.insert_data(self.db_name,
                                                    self.col_name, data)
                            logging.info("[SUCCESS] {} {} {}".format(
                                date, a["title"], a["href"]))
                else:
                    break
        driver.quit()

    def get_realtime_news(self, url, category_chn=None, interval=60):
        """
        Poll a cnstock news-list page forever and store every article not seen before.

        Every ``interval`` seconds the page at ``url`` is parsed. Each ``li.newslist`` entry whose
        link is not yet stored is downloaded (with back-off when the server rejects the request).
        The article is tokenized to find the stock codes it mentions, inserted into
        ``self.db_name``/``self.col_name``, and pushed as JSON onto the Redis list
        ``config.CACHE_NEWS_LIST_NAME``.

        :param url: URL of the news-list page to poll.
        :param category_chn: category label stored with every article; must not be None.
        :param interval: seconds to sleep between two polls.

        NOTE(review): the start-up code in this file runs this method for several categories
        in threads on one shared instance. ``self.terminated_amount`` and
        ``self.is_article_prob`` are therefore shared between those threads.
        """
        logging.info(
            "start real-time crawling of URL -> {}, request every {} secs ... "
            .format(url, interval))
        assert category_chn is not None
        # TODO: cnstock produces little data, so for now all historically crawled URLs are
        #       loaded for de-duplication; the de-duplication strategy will be revised later.
        name_code_df = self.db_obj.get_data(
            config.STOCK_DATABASE_NAME,
            config.COLLECTION_NAME_STOCK_BASIC_INFO,
            keys=["name", "code"])
        name_code_dict = dict(name_code_df.values)
        # A set gives O(1) membership tests; the collection of known URLs only grows.
        crawled_urls = set(
            self.db_obj.get_data(self.db_name, self.col_name, keys=["Url"])["Url"].to_list())
        while True:
            # Poll the news-list page once per interval.
            bs = utils.html_parser(url)
            for li in bs.find_all("li", attrs={"class": ["newslist"]}):
                a = li.find_all("h2")[0].find("a")
                href = a["href"]
                if href in crawled_urls:
                    continue
                result = self._get_url_info_with_backoff(href)
                if result:
                    date, article = result
                    # The page sometimes comes back with an empty body: re-request it while
                    # the retry budget (is_article_prob) lasts.
                    while article == "" and self.is_article_prob >= .1:
                        self.is_article_prob -= .1
                        result = self._get_url_info_with_backoff(href)
                        if not result:
                            # Back-off gave up (the URL was recorded as failed); unpacking the
                            # falsy result here used to raise TypeError and kill the thread.
                            break
                        date, article = result
                    self.is_article_prob = .5
                if not result:
                    # Crawling failed.
                    logging.info("[FAILED] {} {}".format(a["title"], href))
                elif article != "":
                    related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article(
                        article, name_code_dict)
                    record = {
                        "Date": date,
                        "Category": category_chn,
                        "Url": href,
                        "Title": a["title"],
                        "Article": article,
                        "RelatedStockCodes": " ".join(related_stock_codes_list),
                    }
                    # Insert a copy: if insert_data forwards to pymongo's insert_one, that call
                    # adds an "_id" field to the dict it receives, which must not end up in the
                    # JSON cache entry below.
                    self.db_obj.insert_data(self.db_name, self.col_name, dict(record))
                    self.redis_client.lpush(
                        config.CACHE_NEWS_LIST_NAME,
                        json.dumps(dict(record,
                                        OriDB=config.DATABASE_NAME,
                                        OriCOL=config.COLLECTION_NAME_CNSTOCK)))
                    logging.info("[SUCCESS] {} {} {}".format(date, a["title"], href))
                    crawled_urls.add(href)
                # Reset the rejection counter after every article so each one gets the full
                # back-off budget; without this a long-running poller eventually gives up on
                # every URL after a single rejection.
                self.terminated_amount = 0
            time.sleep(interval)

    def _get_url_info_with_backoff(self, url):
        """
        Call ``self.get_url_info(url)``, retrying with a linearly growing back-off.

        Each falsy result increments ``self.terminated_amount`` and sleeps
        ``60 * self.terminated_amount`` seconds before retrying. Once the counter exceeds
        ``config.CNSTOCK_MAX_REJECTED_AMOUNTS``, the URL is appended to
        ``config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH`` and the falsy result is returned.

        :return: the first truthy result of ``get_url_info``, or its last falsy result.
        """
        result = self.get_url_info(url)
        while not result:
            self.terminated_amount += 1
            if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                # Persist URLs that could never be crawled.
                with open(config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH, "a+") as file:
                    file.write("{}\n".format(url))
                logging.info(
                    "rejected by remote server longer than {} minutes, "
                    "and the failed url has been written in path {}"
                    .format(config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                            config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                break
            logging.info(
                "rejected by remote server, request {} again after "
                "{} seconds...".format(url, 60 * self.terminated_amount))
            time.sleep(60 * self.terminated_amount)
            result = self.get_url_info(url)
        return result
class JrjSpyder(Spyder):
    """Crawler for JRJ news listing pages.

    Listing pages are addressed as ``{base_url}/{YYYYMM}/{YYYYMMDD}_{page}.shtml``. Every
    article linked from them is downloaded, tokenized to find the stock codes it mentions,
    and stored in ``database_name.collection_name``.

    NOTE(review): ``self.is_article_prob`` is read but not initialised in this class;
    presumably it is set by ``Spyder.__init__``. Verify.
    """

    # Titles containing any of these phrases are skipped, because news with such titles is
    # mostly machine-generated.
    _SKIPPED_TITLE_KEYWORDS = ("收盘", "报于", "新三板挂牌上市")

    def __init__(self, database_name, collection_name):
        """
        :param database_name: database the crawled articles are written to.
        :param collection_name: collection the crawled articles are written to.
        """
        super(JrjSpyder, self).__init__()
        self.db_obj = Database()
        self.col = self.db_obj.conn[database_name].get_collection(collection_name)
        # Rejection counter for the URL being crawled; drives the linear back-off.
        self.terminated_amount = 0
        self.db_name = database_name
        self.col_name = collection_name
        self.tokenization = Tokenization(import_module="jieba", user_dict=config.USER_DEFINED_DICT_PATH)

    def get_url_info(self, url, specific_date):
        """
        Download one article page and extract its date and body text.

        :param url: article URL.
        :param specific_date: date string used when the page has no date marker.
        :return: ``[date, article]`` (``article`` may be ``""``), or ``False`` when
            ``utils.html_parser`` raised.
        """
        try:
            bs = utils.html_parser(url)
        except Exception:
            return False
        date = ""
        for span in bs.find_all("span"):
            # The date span is the one whose first child equals "jrj_final_date_start".
            # Empty spans have no children, so guard before indexing (was an IndexError).
            if span.contents and span.contents[0] == "jrj_final_date_start":
                date = span.text.replace("\r", "").replace("\n", "")
                break
        if date == "":
            date = specific_date
        paragraphs = []
        for p in bs.find_all("p"):
            # Keep only plain body paragraphs: no attributes and no nested input / red link /
            # <i> / <span>.
            # NOTE(review): find_all("jrj_final_daohang_start") searches for a *tag* of that
            # name, which probably never matches. An earlier variant compared p.contents[0]
            # with "jrj_final_daohang_start1" instead; confirm the intended navigation filter.
            if (not p.find_all("jrj_final_daohang_start")
                    and p.attrs == {}
                    and not p.find_all("input")
                    and not p.find_all("a", attrs={"class": "red"})
                    and not p.find_all("i")
                    and not p.find_all("span")):
                paragraphs.append(p.text.replace("\r", "").replace("\n", "").replace("\u3000", ""))
        return [date, "".join(paragraphs)]

    def _get_url_info_with_backoff(self, url, specific_date):
        """
        Call :meth:`get_url_info`, retrying with a linearly growing back-off.

        Each failed attempt increments ``self.terminated_amount`` and sleeps
        ``60 * self.terminated_amount`` seconds before retrying. Once the counter exceeds
        ``config.JRJ_MAX_REJECTED_AMOUNTS``, the URL is appended to
        ``config.RECORD_JRJ_FAILED_URL_TXT_FILE_PATH`` and the falsy result is returned.
        """
        result = self.get_url_info(url, specific_date)
        while not result:
            self.terminated_amount += 1
            if self.terminated_amount > config.JRJ_MAX_REJECTED_AMOUNTS:
                # Persist URLs that could never be crawled.
                with open(config.RECORD_JRJ_FAILED_URL_TXT_FILE_PATH, "a+") as file:
                    file.write("{}\n".format(url))
                logging.info("rejected by remote server longer than {} minutes, "
                             "and the failed url has been written in path {}"
                             .format(config.JRJ_MAX_REJECTED_AMOUNTS,
                                     config.RECORD_JRJ_FAILED_URL_TXT_FILE_PATH))
                break
            logging.info("rejected by remote server, request {} again after "
                         "{} seconds...".format(url, 60 * self.terminated_amount))
            time.sleep(60 * self.terminated_amount)
            result = self.get_url_info(url, specific_date)
        return result

    def _iter_listing_links(self, base_url, date):
        """
        Yield the article ``<a>`` tags from every listing page of ``date`` ("%Y-%m-%d").

        Only links with an ``href``, a non-empty string title, and ``/YYYY/MM/`` in the href
        are yielded. Pages are fetched lazily, one at a time.
        """
        compact_date = date.replace("-", "")
        month_dir = compact_date[0:6]
        href_marker = "/{}/{}/".format(compact_date[:4], compact_date[4:6])
        first_url = "{}/{}/{}_1.shtml".format(base_url, month_dir, compact_date)
        max_pages_num = utils.search_max_pages_num(first_url, date)
        for num in range(1, max_pages_num + 1):
            page_url = "{}/{}/{}_{}.shtml".format(base_url, month_dir, compact_date, num)
            bs = utils.html_parser(page_url)
            for a in bs.find_all("a"):
                if "href" in a.attrs and a.string and href_marker in a["href"]:
                    yield a

    def _crawl_link(self, a, date, name_code_dict):
        """
        Fetch, tokenize and store the article behind one listing-page link.

        :param a: listing-page ``<a>`` tag with an ``href`` and a string title.
        :param date: listing date ("%Y-%m-%d"), used as the fallback article date.
        :param name_code_dict: stock name -> stock code mapping passed to the tokenizer.
        """
        title = a.string
        href = a["href"]
        if any(keyword in title for keyword in self._SKIPPED_TITLE_KEYWORDS):
            logging.info("[QUIT] {}".format(title))
            return
        result = self._get_url_info_with_backoff(href, date)
        if result:
            article_specific_date, article = result
            # The page sometimes comes back with an empty body: re-request it while the
            # retry budget (is_article_prob) lasts.
            while article == "" and self.is_article_prob >= .1:
                self.is_article_prob -= .1
                result = self._get_url_info_with_backoff(href, date)
                if not result:
                    # Back-off gave up; unpacking the falsy result used to raise TypeError.
                    break
                article_specific_date, article = result
            self.is_article_prob = .5
        if not result:
            # Crawling failed.
            logging.info("[FAILED] {} {}".format(title, href))
        elif article != "":
            related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article(article,
                                                                                              name_code_dict)
            data = {"Date": article_specific_date,
                    "Url": href,
                    "Title": title,
                    "Article": article,
                    "RelatedStockCodes": " ".join(related_stock_codes_list)}
            self.db_obj.insert_data(self.db_name, self.col_name, data)
            logging.info("[SUCCESS] {} {} {}".format(article_specific_date, title, href))
        self.terminated_amount = 0  # reset the rejection counter once this link is done

    def get_historical_news(self, url, start_date=None, end_date=None):
        """
        Crawl the listing pages for the dates from ``start_date`` to ``end_date`` and store
        the linked articles. The exact date list comes from ``utils.get_date_list_from_range``.

        :param url: base URL of the listing pages.
        :param start_date: "%Y-%m-%d". When None, resume from the day after the newest stored
            article, or from ``config.JRJ_REQUEST_DEFAULT_DATE`` if nothing is stored yet.
        :param end_date: "%Y-%m-%d"; defaults to today.
        """
        name_code_df = self.db_obj.get_data(config.STOCK_DATABASE_NAME,
                                            config.COLLECTION_NAME_STOCK_BASIC_INFO,
                                            keys=["name", "code"])
        name_code_dict = dict(name_code_df.values)

        # URLs already handled in this call; previously never filled, so duplicates were stored.
        crawled_urls = set()
        if end_date is None:
            end_date = datetime.datetime.now().strftime("%Y-%m-%d")

        if start_date is None:
            # With no start_date, continue from the newest date in the stored data up to now.
            # e.g. history_latest_date_str -> "2020-12-08"
            #      history_latest_date_dt -> datetime.date(2020, 12, 08)
            #      start_date -> "2020-12-09"
            history_latest_date_list = self.db_obj.get_data(self.db_name,
                                                            self.col_name,
                                                            keys=["Date"])["Date"].to_list()
            if len(history_latest_date_list) != 0:
                history_latest_date_str = max(history_latest_date_list).split(" ")[0]
                history_latest_date_dt = datetime.datetime.strptime(history_latest_date_str, "%Y-%m-%d").date()
                offset = datetime.timedelta(days=1)
                start_date = (history_latest_date_dt + offset).strftime('%Y-%m-%d')
            else:
                start_date = config.JRJ_REQUEST_DEFAULT_DATE

        dates_list = utils.get_date_list_from_range(start_date, end_date)
        dates_separated_into_ranges_list = utils.gen_dates_list(dates_list, config.JRJ_DATE_RANGE)

        for dates_range in dates_separated_into_ranges_list:
            for date in dates_range:
                for a in self._iter_listing_links(url, date):
                    if a["href"] in crawled_urls:
                        continue
                    self._crawl_link(a, date, name_code_dict)
                    crawled_urls.add(a["href"])

    def get_realtime_news(self, interval=60):
        """
        Poll today's listing pages forever and store every article not seen today.

        :param interval: seconds to sleep between two polls.
        """
        name_code_df = self.db_obj.get_data(config.STOCK_DATABASE_NAME,
                                            config.COLLECTION_NAME_STOCK_BASIC_INFO,
                                            keys=["name", "code"])
        name_code_dict = dict(name_code_df.values)
        crawled_urls = set()
        last_date = datetime.datetime.now().strftime("%Y-%m-%d")
        while True:
            today_date = datetime.datetime.now().strftime("%Y-%m-%d")
            if today_date != last_date:
                # Listing pages are per date, so forget the previous day's URLs.
                crawled_urls = set()
                last_date = today_date
            for a in self._iter_listing_links(config.WEBSITES_LIST_TO_BE_CRAWLED_JRJ, today_date):
                if a["href"] in crawled_urls:
                    continue
                self._crawl_link(a, today_date, name_code_dict)
                crawled_urls.add(a["href"])
            time.sleep(interval)
    # Start-up sequence: backfill missing cnstock history, post-process the collection,
    # then crawl every category in real time.
    from Kite.database import Database
    from Kite import config
    from concurrent import futures  # NOTE(review): `futures` is not used in this block
    import threading

    obj = Database()
    # Date and Category of every stored cnstock article; used to find where each category's
    # backfill has to start.
    df = obj.get_data(config.DATABASE_NAME, config.COLLECTION_NAME_CNSTOCK, keys=["Date", "Category"])

    cnstock_spyder = CnStockSpyder(config.DATABASE_NAME, config.COLLECTION_NAME_CNSTOCK)
    # Backfill historical data first. E.g. if news has been crawled up to 2020-12-01 but the
    # real-time crawler is started on 2020-12-23, the news from 2020-12-02 to 2020-12-23 is
    # crawled automatically before real-time crawling begins.
    for url_to_be_crawled, type_chn in config.WEBSITES_LIST_TO_BE_CRAWLED_CNSTOCK.items():
        # Look up the date of the most recent stored article of category type_chn.
        # NOTE(review): max() raises ValueError when the category has no stored rows yet, and
        # df.Category may not exist when the collection is empty. Confirm this is acceptable.
        latets_date_in_db = max(df[df.Category == type_chn]["Date"].to_list())
        cnstock_spyder.get_historical_news(url_to_be_crawled, category_chn=type_chn, start_date=latets_date_in_db)

    # Post-process the whole collection. Judging by the names (verify against those classes),
    # this fills RelatedStockCodes, removes duplicate rows and removes rows with null fields.
    tokenization = Tokenization(import_module="jieba", user_dict=config.USER_DEFINED_DICT_PATH)
    tokenization.update_news_database_rows(config.DATABASE_NAME, config.COLLECTION_NAME_CNSTOCK)
    Deduplication(config.DATABASE_NAME, config.COLLECTION_NAME_CNSTOCK).run()
    DeNull(config.DATABASE_NAME, config.COLLECTION_NAME_CNSTOCK).run()

    # Start one thread per news category and crawl them in parallel in real time.
    # NOTE(review): all threads share the single cnstock_spyder instance and its mutable
    # attributes. get_realtime_news loops with `while True`, so the joins below only return
    # if a thread dies with an exception.
    thread_list = []
    for url, type_chn in config.WEBSITES_LIST_TO_BE_CRAWLED_CNSTOCK.items():
        thread = threading.Thread(target=cnstock_spyder.get_realtime_news, args=(url, type_chn, 60))
        thread_list.append(thread)
    for thread in thread_list:
        thread.start()
    for thread in thread_list:
        thread.join()
예제 #13
0
 def __init__(self):
     """Create the jieba-backed tokenizer with the financedict user dictionary and the
     chnstopwords stop-word list (relative paths; presumably resolved against the process's
     working directory by Tokenization — verify)."""
     tokenizer_options = {
         "import_module": "jieba",
         "user_dict": "../Leorio/financedict.txt",
         "chn_stop_words_dir": "../Leorio/chnstopwords.txt",
     }
     self.tokenization = Tokenization(**tokenizer_options)
예제 #14
0
class TopicModelling(object):

    def __init__(self):
        """Set up the jieba tokenizer used to split documents into words.

        The user dictionary (financedict.txt) and the stop-word list (chnstopwords.txt) are
        given as relative paths; presumably resolved against the working directory — verify.
        """
        user_dict_path = "../Leorio/financedict.txt"
        stop_words_path = "../Leorio/chnstopwords.txt"
        self.tokenization = Tokenization(import_module="jieba",
                                         user_dict=user_dict_path,
                                         chn_stop_words_dir=stop_words_path)

    def create_dictionary(self, raw_documents_list, savepath=None):
        """
        Build the vocabulary: associate every token of the corpus with a unique integer ID.

        :param: raw_documents_list, list of raw documents, each element a text string,
                e.g. ["洗尽铅华...", "风雨赶路人...", ...]
        :param: savepath, optional path where the corpora.Dictionary object is saved
        :return: (corpora.Dictionary, list of token lists, one per input document)
        """
        documents_token_list = [self.tokenization.cut_words(text) for text in raw_documents_list]
        vocabulary = corpora.Dictionary(documents_token_list)
        if savepath:
            vocabulary.save(savepath)
        return vocabulary, documents_token_list

    def create_bag_of_word_representation(self, raw_documents_list, dict_save_path=None, bow_vector_save_path=None):
        """
        Convert raw documents into bag-of-words vectors.

        :param raw_documents_list: list of raw text documents.
        :param dict_save_path: optional path where the corpora.Dictionary is saved.
        :param bow_vector_save_path: optional path where the BoW corpus is serialized with
            corpora.MmCorpus.
        :return: (token lists, corpora.Dictionary, list of BoW vectors), in that order.
        """
        vocabulary, token_lists = self.create_dictionary(raw_documents_list, savepath=dict_save_path)
        bow_vector = list(map(vocabulary.doc2bow, token_lists))
        if bow_vector_save_path:
            corpora.MmCorpus.serialize(bow_vector_save_path, bow_vector)
        return token_lists, vocabulary, bow_vector

    def transform_vectorized_corpus(self, corpora_dictionary, bow_vector, model_type="lda", model_save_path=None):
        if model_type == "lsi":
            tfidf_vector = models.TfidfModel(bow_vector)[bow_vector]
            model = models.LsiModel(tfidf_vector,
                                    id2word=corpora_dictionary,
                                    num_topics=config.TOPIC_NUMBER)  # 初始化模型
            model_vector = model[tfidf_vector]
            if model_save_path:
                model.save(model_save_path)
            return tfidf_vector, model_vector
        elif model_type == "lda":
            tfidf_vector = models.TfidfModel(bow_vector)[bow_vector]
            model = models.LdaModel(tfidf_vector,
                                    id2word=corpora_dictionary,
                                    num_topics=config.TOPIC_NUMBER)  # 初始化模型
            model_vector = model[tfidf_vector]
            if model_save_path:
                model.save(model_save_path)
            return tfidf_vector, model_vector
        elif model_type == "tfidf":
            model = models.TfidfModel(bow_vector)  # 初始化
            tfidf_vector = model[bow_vector]  # 将整个语料进行转换
            if model_save_path:
                model.save(model_save_path)
            return tfidf_vector

    def load_transform_model(self, model_path):
        if ".tfidf" in model_path:
            return models.TfidfModel.load(model_path)
        elif ".lsi" in model_path:
            return models.LsiModel.load(model_path)
        elif ".lda" in model_path:
            return models.LdaModel.load(model_path)