Example #1
class Deduplication(object):
    def __init__(self, database_name, collection_name):
        self.database = Database()
        self.database_name = database_name
        self.collection_name = collection_name
        self.delete_num = 0

    def run(self):
        date_list = self.database.get_data(self.database_name,
                                           self.collection_name,
                                           keys=["Date"])["Date"].tolist()
        collection = self.database.get_collection(self.database_name,
                                                  self.collection_name)
        date_list.sort()  # ascending order
        # start_date, end_date = date_list[1].split(" ")[0], date_list[-1].split(" ")[0]
        start_date, end_date = min(date_list).split(" ")[0], max(
            date_list).split(" ")[0]
        for _date in utils.get_date_list_from_range(start_date, end_date):
            # fetch the records for this date and deduplicate them by URL
            # logging.info(_date)
            try:
                data_df = self.database.get_data(
                    self.database_name,
                    self.collection_name,
                    query={"Date": {
                        "$regex": _date
                    }})
            except Exception:
                continue
            if data_df is None:
                continue
            data_df_drop_duplicate = data_df.drop_duplicates(["Url"])
            for _id in list(
                    set(data_df["_id"]) - set(data_df_drop_duplicate["_id"])):
                collection.delete_one({'_id': _id})
                self.delete_num += 1
            # logging.info("{} finished ... ".format(_date))
        logging.info(
            "DB:{} - COL:{} originally had {} records, {} duplicates have now been deleted ... "
            .format(self.database_name, self.collection_name,
                    str(len(date_list)), self.delete_num))
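A minimal usage sketch (not part of the original source): it assumes logging has been configured and that Database() can reach the target MongoDB; "stock_news" and "cnstock" are placeholder names.

import logging

logging.basicConfig(level=logging.INFO)
# Remove duplicate URLs from one news collection; placeholder database/collection names.
Deduplication("stock_news", "cnstock").run()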
Example #2
class GenStockNewsDB(object):
    def __init__(self):
        self.database = Database()

    def get_all_news_about_specific_stock(self, database_name,
                                          collection_name):
        # Check the keys of collection_name for RelatedStockCodes; if it is missing, the
        # stock codes mentioned in each news article have not yet been stored in a
        # separate column, so build that column first
        _keys_list = list(
            next(
                self.database.get_collection(database_name,
                                             collection_name).find()).keys())
        if "RelatedStockCodes" not in _keys_list:
            tokenization = Tokenization(import_module="jieba",
                                        user_dict="./Leorio/financedict.txt")
            tokenization.update_news_database_rows(database_name,
                                                   collection_name)
        # create one collection per stock, named after the stock code
        stock_code_list = self.database.get_data(
            "stock", "basic_info", keys=["code"])["code"].to_list()
        for code in stock_code_list:
            _collection = self.database.get_collection(
                config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, code)
            _tmp_num_stat = 0
            for row in self.database.get_collection(
                    database_name, collection_name).find():  # cursor (iterator)
                if code in row["RelatedStockCodes"].split(" "):
                    _collection.insert_one({
                        "Date": row["Date"],
                        "Url": row["Url"],
                        "Title": row["Title"],
                        "Article": row["Article"],
                        "OriDB": database_name,
                        "OriCOL": collection_name
                    })
                    _tmp_num_stat += 1
            logging.info(
                "there are {} news articles mentioning {} in the {} collection ... "
                .format(_tmp_num_stat, code, collection_name))
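A hedged usage sketch for this class: the database and collection names below are placeholders, and config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE is assumed to be defined in the project's config.

# Fan the news of one source collection out into per-stock collections (placeholder names).
gen_db = GenStockNewsDB()
gen_db.get_all_news_about_specific_stock("stock_news", "cnstock")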
Example #3
class Tokenization(object):
    def __init__(self,
                 import_module="jieba",
                 user_dict=None,
                 chn_stop_words_dir=None):
        #self.database = Database().conn[config.DATABASE_NAME]  #.get_collection(config.COLLECTION_NAME_CNSTOCK)
        self.database = Database()
        self.import_module = import_module
        self.user_dict = user_dict
        if self.user_dict:
            self.update_user_dict(self.user_dict)
        if chn_stop_words_dir:
            self.stop_words_list = utils.get_chn_stop_words(chn_stop_words_dir)
        else:
            self.stop_words_list = list()

    def update_user_dict(self, old_user_dict_dir, new_user_dict_dir=None):
        # add missing (or new) stock names, financial neologisms, etc. to the finance dictionary
        word_list = []
        with open(old_user_dict_dir, "r", encoding="utf-8") as file:
            for row in file:
                word_list.append(row.split("\n")[0])
        name_code_df = self.database.get_data(
            config.STOCK_DATABASE_NAME,
            config.COLLECTION_NAME_STOCK_BASIC_INFO,
            keys=["name", "code"])
        new_words_list = list(set(name_code_df["name"].tolist()))
        for word in new_words_list:
            if word not in word_list:
                word_list.append(word)
        new_user_dict_dir = old_user_dict_dir if not new_user_dict_dir else new_user_dict_dir
        with open(new_user_dict_dir, "w", encoding="utf-8") as file:
            for word in word_list:
                file.write(word + "\n")

    def cut_words(self, text):
        outstr = list()
        sentence_seged = None
        if self.import_module == "jieba":
            if self.user_dict:
                jieba.load_userdict(self.user_dict)
            sentence_seged = list(jieba.cut(text))
        elif self.import_module == "pkuseg":
            seg = pkuseg.pkuseg(user_dict=self.user_dict)  # load the user-defined dictionary
            sentence_seged = seg.cut(text)  # segment the text
        if sentence_seged:
            for word in sentence_seged:
                if word not in self.stop_words_list and word != "\t" and word != " ":
                    outstr.append(word)
            return outstr
        else:
            return False

    def find_relevant_stock_codes_in_article(self, article,
                                             stock_name_code_dict):
        stock_codes_set = list()
        cut_words_list = self.cut_words(article)
        if cut_words_list:
            for word in cut_words_list:
                try:
                    stock_codes_set.append(stock_name_code_dict[word])
                except Exception:
                    pass
        return list(set(stock_codes_set))

    def update_news_database_rows(self, database_name, collection_name):
        name_code_df = self.database.get_data(
            config.STOCK_DATABASE_NAME,
            config.COLLECTION_NAME_STOCK_BASIC_INFO,
            keys=["name", "code"])
        name_code_dict = dict(name_code_df.values)
        data = self.database.get_collection(database_name,
                                            collection_name).find()
        for row in data:
            # if row["Date"] > "2019-05-20 00:00:00":
            related_stock_codes_list = self.find_relevant_stock_codes_in_article(
                row["Article"], name_code_dict)
            self.database.update_row(
                database_name, collection_name, {"_id": row["_id"]},
                {"RelatedStockCodes": " ".join(related_stock_codes_list)})
            logging.info(
                "[{} -> {} -> {}] updated RelatedStockCodes key value ... ".
                format(database_name, collection_name, row["Date"]))
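A hedged usage sketch, assuming jieba is installed; the dictionary path is taken from Example #2, while the stop-word path and the name-to-code mapping are illustrative only.

# Segment a headline and extract the stock codes it mentions.
tok = Tokenization(import_module="jieba",
                   user_dict="./Leorio/financedict.txt",
                   chn_stop_words_dir="./Leorio/chnstopwords.txt")  # assumed path
name_code_dict = {"浦发银行": "600000"}  # illustrative; normally built from the basic_info collection
words = tok.cut_words("浦发银行发布年度业绩公告")
codes = tok.find_relevant_stock_codes_in_article("浦发银行发布年度业绩公告", name_code_dict)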
Example #4
class NbdSpyder(Spyder):
    def __init__(self, database_name, collection_name):
        super(NbdSpyder, self).__init__()
        self.db_obj = Database()
        self.col = self.db_obj.conn[database_name].get_collection(
            collection_name)
        self.terminated_amount = 0
        self.db_name = database_name
        self.col_name = collection_name
        self.tokenization = Tokenization(
            import_module="jieba", user_dict=config.USER_DEFINED_DICT_PATH)

    def get_url_info(self, url):
        try:
            bs = utils.html_parser(url)
        except Exception:
            return False
        span_list = bs.find_all("span")
        part = bs.find_all("p")
        article = ""
        date = ""
        for span in span_list:
            if "class" in span.attrs and span.text \
                    and span["class"] == ["time"]:
                string = span.text.split()
                for dt in string:
                    if dt.find("-") != -1:
                        date += dt + " "
                    elif dt.find(":") != -1:
                        date += dt
                break
        for paragraph in part:
            chn_status = utils.count_chn(str(paragraph))
            possible = chn_status[1]
            if possible > self.is_article_prob:
                article += str(paragraph)
        while article.find("<") != -1 and article.find(">") != -1:
            string = article[article.find("<"):article.find(">") + 1]
            article = article.replace(string, "")
        while article.find("\u3000") != -1:
            article = article.replace("\u3000", "")
        article = " ".join(re.split(" +|\n+", article)).strip()

        return [date, article]

    def get_historical_news(self, start_page=684):
        date_list = self.db_obj.get_data(self.db_name,
                                         self.col_name,
                                         keys=["Date"])["Date"].to_list()
        name_code_df = self.db_obj.get_data(
            config.STOCK_DATABASE_NAME,
            config.COLLECTION_NAME_STOCK_BASIC_INFO,
            keys=["name", "code"])
        name_code_dict = dict(name_code_df.values)
        if len(date_list) == 0:
            # no historical data yet, so crawl from the very beginning
            crawled_urls_list = []
            page_urls = [
                "{}/{}".format(config.WEBSITES_LIST_TO_BE_CRAWLED_NBD, page_id)
                for page_id in range(start_page, 0, -1)
            ]
            for page_url in page_urls:
                bs = utils.html_parser(page_url)
                a_list = bs.find_all("a")
                for a in a_list:
                    if "click-statistic" in a.attrs and a.string \
                            and a["click-statistic"].find("Article_") != -1 \
                            and a["href"].find("http://www.nbd.com.cn/articles/") != -1:
                        if a["href"] not in crawled_urls_list:
                            result = self.get_url_info(a["href"])
                            while not result:
                                self.terminated_amount += 1
                                if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS:
                                    # record URLs that could never be crawled
                                    with open(
                                            config.
                                            RECORD_NBD_FAILED_URL_TXT_FILE_PATH,
                                            "a+") as file:
                                        file.write("{}\n".format(a["href"]))
                                    logging.info(
                                        "rejected by remote server longer than {} minutes, "
                                        "and the failed url has been written in path {}"
                                        .format(
                                            config.NBD_MAX_REJECTED_AMOUNTS,
                                            config.
                                            RECORD_NBD_FAILED_URL_TXT_FILE_PATH
                                        ))
                                    break
                                logging.info(
                                    "rejected by remote server, request {} again after "
                                    "{} seconds...".format(
                                        a["href"],
                                        60 * self.terminated_amount))
                                time.sleep(60 * self.terminated_amount)
                                result = self.get_url_info(a["href"])
                            if not result:
                                # the crawl ultimately failed
                                logging.info("[FAILED] {} {}".format(
                                    a.string, a["href"]))
                            else:
                                # got a result, but the article may still be empty
                                date, article = result
                                while article == "" and self.is_article_prob >= .1:
                                    self.is_article_prob -= .1
                                    result = self.get_url_info(a["href"])
                                    while not result:
                                        self.terminated_amount += 1
                                        if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS:
                                            # record URLs that could never be crawled
                                            with open(
                                                    config.
                                                    RECORD_NBD_FAILED_URL_TXT_FILE_PATH,
                                                    "a+") as file:
                                                file.write("{}\n".format(
                                                    a["href"]))
                                            logging.info(
                                                "rejected by remote server longer than {} minutes, "
                                                "and the failed url has been written in path {}"
                                                .format(
                                                    config.
                                                    NBD_MAX_REJECTED_AMOUNTS,
                                                    config.
                                                    RECORD_NBD_FAILED_URL_TXT_FILE_PATH
                                                ))
                                            break
                                        logging.info(
                                            "rejected by remote server, request {} again after "
                                            "{} seconds...".format(
                                                a["href"],
                                                60 * self.terminated_amount))
                                        time.sleep(60 * self.terminated_amount)
                                        result = self.get_url_info(a["href"])
                                    date, article = result
                                self.is_article_prob = .5
                                if article != "":
                                    related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article(
                                        article, name_code_dict)
                                    data = {
                                        "Date": date,
                                        # "PageId": page_url.split("/")[-1],
                                        "Url": a["href"],
                                        "Title": a.string,
                                        "Article": article,
                                        "RelatedStockCodes": " ".join(
                                            related_stock_codes_list)
                                    }
                                    # self.col.insert_one(data)
                                    self.db_obj.insert_data(
                                        self.db_name, self.col_name, data)
                                    logging.info("[SUCCESS] {} {} {}".format(
                                        date, a.string, a["href"]))
        else:
            is_stop = False
            start_date = max(date_list)
            page_start_id = 1
            while not is_stop:
                page_url = "{}/{}".format(
                    config.WEBSITES_LIST_TO_BE_CRAWLED_NBD, page_start_id)
                bs = utils.html_parser(page_url)
                a_list = bs.find_all("a")
                for a in a_list:
                    if "click-statistic" in a.attrs and a.string \
                            and a["click-statistic"].find("Article_") != -1 \
                            and a["href"].find("http://www.nbd.com.cn/articles/") != -1:
                        result = self.get_url_info(a["href"])
                        while not result:
                            self.terminated_amount += 1
                            if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS:
                                # record URLs that could never be crawled
                                with open(
                                        config.
                                        RECORD_NBD_FAILED_URL_TXT_FILE_PATH,
                                        "a+") as file:
                                    file.write("{}\n".format(a["href"]))
                                logging.info(
                                    "rejected by remote server longer than {} minutes, "
                                    "and the failed url has been written in path {}"
                                    .format(
                                        config.NBD_MAX_REJECTED_AMOUNTS,
                                        config.
                                        RECORD_NBD_FAILED_URL_TXT_FILE_PATH))
                                break
                            logging.info(
                                "rejected by remote server, request {} again after "
                                "{} seconds...".format(
                                    a["href"], 60 * self.terminated_amount))
                            time.sleep(60 * self.terminated_amount)
                            result = self.get_url_info(a["href"])
                        if not result:
                            # the crawl ultimately failed
                            logging.info("[FAILED] {} {}".format(
                                a.string, a["href"]))
                        else:
                            # got a result, but the article may still be empty
                            date, article = result
                            if date > start_date:
                                while article == "" and self.is_article_prob >= .1:
                                    self.is_article_prob -= .1
                                    result = self.get_url_info(a["href"])
                                    while not result:
                                        self.terminated_amount += 1
                                        if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS:
                                            # record URLs that could never be crawled
                                            with open(
                                                    config.
                                                    RECORD_NBD_FAILED_URL_TXT_FILE_PATH,
                                                    "a+") as file:
                                                file.write("{}\n".format(
                                                    a["href"]))
                                            logging.info(
                                                "rejected by remote server longer than {} minutes, "
                                                "and the failed url has been written in path {}"
                                                .format(
                                                    config.
                                                    NBD_MAX_REJECTED_AMOUNTS,
                                                    config.
                                                    RECORD_NBD_FAILED_URL_TXT_FILE_PATH
                                                ))
                                            break
                                        logging.info(
                                            "rejected by remote server, request {} again after "
                                            "{} seconds...".format(
                                                a["href"],
                                                60 * self.terminated_amount))
                                        time.sleep(60 * self.terminated_amount)
                                        result = self.get_url_info(a["href"])
                                    date, article = result
                                self.is_article_prob = .5
                                if article != "":
                                    related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article(
                                        article, name_code_dict)
                                    data = {
                                        "Date": date,
                                        "Url": a["href"],
                                        "Title": a.string,
                                        "Article": article,
                                        "RelatedStockCodes": " ".join(
                                            related_stock_codes_list)
                                    }
                                    self.db_obj.insert_data(
                                        self.db_name, self.col_name, data)
                                    logging.info("[SUCCESS] {} {} {}".format(
                                        date, a.string, a["href"]))
                            else:
                                is_stop = True
                                break
                if not is_stop:
                    page_start_id += 1

    def get_realtime_news(self, interval=60):
        page_url = "{}/1".format(config.WEBSITES_LIST_TO_BE_CRAWLED_NBD)
        logging.info(
            "start real-time crawling of URL -> {}, request every {} secs ... "
            .format(page_url, interval))
        name_code_df = self.db_obj.get_data(
            config.STOCK_DATABASE_NAME,
            config.COLLECTION_NAME_STOCK_BASIC_INFO,
            keys=["name", "code"])
        name_code_dict = dict(name_code_df.values)
        crawled_urls = []
        date_list = self.db_obj.get_data(self.db_name,
                                         self.col_name,
                                         keys=["Date"])["Date"].to_list()
        latest_date = max(date_list)
        while True:
            # poll the page at a fixed interval
            if len(crawled_urls) > 100:
                # cap the list at roughly 100 entries to avoid unbounded memory growth
                crawled_urls.pop(0)
            bs = utils.html_parser(page_url)
            a_list = bs.find_all("a")
            for a in a_list:
                if "click-statistic" in a.attrs and a.string \
                        and a["click-statistic"].find("Article_") != -1 \
                        and a["href"].find("http://www.nbd.com.cn/articles/") != -1:
                    if a["href"] not in crawled_urls:
                        result = self.get_url_info(a["href"])
                        while not result:
                            self.terminated_amount += 1
                            if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS:
                                # record URLs that could never be crawled
                                with open(
                                        config.
                                        RECORD_NBD_FAILED_URL_TXT_FILE_PATH,
                                        "a+") as file:
                                    file.write("{}\n".format(a["href"]))
                                logging.info(
                                    "rejected by remote server longer than {} minutes, "
                                    "and the failed url has been written in path {}"
                                    .format(
                                        config.NBD_MAX_REJECTED_AMOUNTS,
                                        config.
                                        RECORD_NBD_FAILED_URL_TXT_FILE_PATH))
                                break
                            logging.info(
                                "rejected by remote server, request {} again after "
                                "{} seconds...".format(
                                    a["href"], 60 * self.terminated_amount))
                            time.sleep(60 * self.terminated_amount)
                            result = self.get_url_info(a["href"])
                        if not result:
                            # the crawl ultimately failed
                            logging.info("[FAILED] {} {}".format(
                                a.string, a["href"]))
                        else:
                            # got a result, but the article may still be empty
                            date, article = result
                            if date > latest_date:
                                while article == "" and self.is_article_prob >= .1:
                                    self.is_article_prob -= .1
                                    result = self.get_url_info(a["href"])
                                    while not result:
                                        self.terminated_amount += 1
                                        if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS:
                                            # record URLs that could never be crawled
                                            with open(
                                                    config.
                                                    RECORD_NBD_FAILED_URL_TXT_FILE_PATH,
                                                    "a+") as file:
                                                file.write("{}\n".format(
                                                    a["href"]))
                                            logging.info(
                                                "rejected by remote server longer than {} minutes, "
                                                "and the failed url has been written in path {}"
                                                .format(
                                                    config.
                                                    NBD_MAX_REJECTED_AMOUNTS,
                                                    config.
                                                    RECORD_NBD_FAILED_URL_TXT_FILE_PATH
                                                ))
                                            break
                                        logging.info(
                                            "rejected by remote server, request {} again after "
                                            "{} seconds...".format(
                                                a["href"],
                                                60 * self.terminated_amount))
                                        time.sleep(60 * self.terminated_amount)
                                        result = self.get_url_info(a["href"])
                                    date, article = result
                                self.is_article_prob = .5
                                if article != "":
                                    related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article(
                                        article, name_code_dict)
                                    data = {
                                        "Date": date,
                                        # "PageId": page_url.split("/")[-1],
                                        "Url": a["href"],
                                        "Title": a.string,
                                        "Article": article,
                                        "RelatedStockCodes": " ".join(
                                            related_stock_codes_list)
                                    }
                                    # self.col.insert_one(data)
                                    self.db_obj.insert_data(
                                        self.db_name, self.col_name, data)
                                    crawled_urls.append(a["href"])
                                    logging.info("[SUCCESS] {} {} {}".format(
                                        date, a.string, a["href"]))
            # logging.info("sleep {} secs then request again ... ".format(interval))
            time.sleep(interval)
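The retry-with-backoff block above is copy-pasted in several places. As a sketch only (an assumption, not the author's code), it could be factored into one helper shared by the spyders; the helper name and parameters below are hypothetical, while the backoff logic mirrors the original.

import time
import logging

def fetch_with_backoff(spyder, href, max_rejects, failed_url_path):
    # Retry spyder.get_url_info(href) with a linearly growing wait; after more than
    # max_rejects failures, record the URL in failed_url_path and give up.
    result = spyder.get_url_info(href)
    while not result:
        spyder.terminated_amount += 1
        if spyder.terminated_amount > max_rejects:
            with open(failed_url_path, "a+") as file:
                file.write("{}\n".format(href))
            logging.info("giving up on {}, recorded in {}".format(href, failed_url_path))
            break
        wait = 60 * spyder.terminated_amount
        logging.info("rejected by remote server, retrying {} after {} seconds".format(href, wait))
        time.sleep(wait)
        result = spyder.get_url_info(href)
    return result  # [date, article] on success, False otherwise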
Example #5
class GenStockNewsDB(object):
    def __init__(self):
        self.database = Database()
        # trading-day calendar from 1990-12-19 to 2020-12-31
        self.trade_date = ak.tool_trade_date_hist_sina()["trade_date"].tolist()
        self.label_range = {
            3: "3DaysLabel",
            5: "5DaysLabel",
            10: "10DaysLabel",
            15: "15DaysLabel",
            30: "30DaysLabel",
            60: "60DaysLabel"
        }

    def get_all_news_about_specific_stock(self, database_name,
                                          collection_name):
        # Check the keys of collection_name for RelatedStockCodes; if it is missing, the
        # stock codes mentioned in each news article have not yet been stored in a
        # separate column, so build that column first
        _keys_list = list(
            next(
                self.database.get_collection(database_name,
                                             collection_name).find()).keys())
        if "RelatedStockCodes" not in _keys_list:
            tokenization = Tokenization(import_module="jieba",
                                        user_dict="./Leorio/financedict.txt")
            tokenization.update_news_database_rows(database_name,
                                                   collection_name)
        # create one collection per stock, named after the stock code
        stock_symbol_list = self.database.get_data(
            "stock", "basic_info", keys=["symbol"])["symbol"].to_list()
        for symbol in stock_symbol_list:
            _collection = self.database.get_collection(
                config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, symbol)
            _tmp_num_stat = 0
            for row in self.database.get_collection(
                    database_name, collection_name).find():  # cursor (iterator)
                if symbol[2:] in row["RelatedStockCodes"].split(" "):
                    # label the news by the price move n trading days after publication
                    _tmp_dict = {}
                    for label_days, key_name in self.label_range.items():
                        _tmp_res = self._label_news(
                            datetime.datetime.strptime(
                                row["Date"].split(" ")[0], "%Y-%m-%d"), symbol,
                            label_days)
                        if _tmp_res:
                            _tmp_dict.update({key_name: _tmp_res})
                    _data = {
                        "Date": row["Date"],
                        "Url": row["Url"],
                        "Title": row["Title"],
                        "Article": row["Article"],
                        "OriDB": database_name,
                        "OriCOL": collection_name
                    }
                    _data.update(_tmp_dict)
                    _collection.insert_one(_data)
                    _tmp_num_stat += 1
            logging.info(
                "there are {} news articles mentioning {} in the {} collection ... "
                .format(_tmp_num_stat, symbol, collection_name))

    def _label_news(self, date, symbol, n_days):
        """
        :param date: datetime.datetime, the publication date of the news (date only,
            no time of day), e.g. datetime.datetime(2015, 1, 5, 0, 0)
        :param symbol: str, the stock symbol, e.g. sh600000
        :param n_days: int, how many days after publication the closing price is
            compared; if the close has risen n_days after the news, the news is
            considered positive
        """
        # work out the concrete calendar date n_days after the news was published
        this_date_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                symbol,
                                                query={"date": date})
        # Case: the publication date may be a non-trading day with no price data, so
        # search backwards. E.g. if the news came out on Saturday 2020-12-12, use the
        # close of 2020-12-11 as the price at publication time.
        tmp_date = date
        if this_date_data is None:
            i = 1
            while this_date_data is None and i <= 10:
                tmp_date = date - datetime.timedelta(days=i)  # step back one day at a time
                # only query the database if the date is a trading day; if this_date_data
                # stays None, the database has no data for that trading day
                if tmp_date.strftime("%Y-%m-%d") in self.trade_date:
                    this_date_data = self.database.get_data(
                        config.STOCK_DATABASE_NAME,
                        symbol,
                        query={"date": tmp_date})
                i += 1
        try:
            close_price_this_date = this_date_data["close"][0]
        except Exception:
            close_price_this_date = None
        # Case: the date n_days after publication may be a non-trading day, or no data
        # was collected for it, so search forward. E.g. for news published on 2020-12-08,
        # five days later is Sunday 2020-12-13, so use Monday 2020-12-14's close instead.
        new_date = date + datetime.timedelta(days=n_days)
        n_days_later_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                   symbol,
                                                   query={"date": new_date})
        if n_days_later_data is None:
            i = 1
            while n_days_later_data is None and i <= 10:
                new_date = date + datetime.timedelta(days=n_days + i)
                if new_date.strftime("%Y-%m-%d") in self.trade_date:
                    n_days_later_data = self.database.get_data(
                        config.STOCK_DATABASE_NAME,
                        symbol,
                        query={"date": new_date})
                i += 1
        try:
            close_price_n_days_later = n_days_later_data["close"][0]
        except Exception:
            close_price_n_days_later = None
        # Labeling rules:
        # (1) n_days <= 10: a rise (fall) of more than 3% over n_days trading days makes
        #     the news positive (negative); within the 3% band it is neutral
        # (2) 10 < n_days <= 15: the threshold is 5%
        # (3) 15 < n_days <= 30: the threshold is 10%
        # (4) 30 < n_days <= 60: the threshold is 15%
        # Note: a "neutral" news item is one quickly absorbed by the market with no lasting impact
        param = 0.01
        if n_days <= 10:
            param = 0.03
        elif 10 < n_days <= 15:
            param = 0.05
        elif 15 < n_days <= 30:
            param = 0.10
        elif 30 < n_days <= 60:
            param = 0.15
        if close_price_this_date is not None and close_price_n_days_later is not None:
            if (close_price_n_days_later -
                    close_price_this_date) / close_price_this_date > param:
                return "利好"  # positive (bullish)
            elif (close_price_n_days_later -
                  close_price_this_date) / close_price_this_date < -param:
                return "利空"  # negative (bearish)
            else:
                return "中性"  # neutral
        else:
            return False
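A worked sketch of the threshold table in _label_news, with made-up prices (not data from the source); it assumes n_days is one of the horizons in label_range, i.e. at most 60.

def label(close_then, close_later, n_days):
    # Same thresholds as _label_news: 3% / 5% / 10% / 15% depending on the horizon.
    param = 0.03 if n_days <= 10 else 0.05 if n_days <= 15 else 0.10 if n_days <= 30 else 0.15
    change = (close_later - close_then) / close_then
    if change > param:
        return "利好"  # positive
    elif change < -param:
        return "利空"  # negative
    return "中性"  # neutral

print(label(10.00, 10.50, 5))  # +5% over a 3% threshold -> 利好
print(label(10.00, 9.80, 5))   # -2%, inside the band -> 中性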
Example #6
import redis

from Kite import config
from Kite.database import Database
from Killua.denull import DeNull
from Killua.deduplication import Deduplication
from Gon.cnstockspyder import CnStockSpyder

redis_client = redis.StrictRedis(
    config.REDIS_IP,
    port=config.REDIS_PORT,
    db=config.CACHE_RECORED_OPENED_PYTHON_PROGRAM_DB_ID)
redis_client.lpush(config.CACHE_RECORED_OPENED_PYTHON_PROGRAM_VAR,
                   "realtime_starter_cnstock.py")

obj = Database()
df = obj.get_data(config.DATABASE_NAME,
                  config.COLLECTION_NAME_CNSTOCK,
                  keys=["Date", "Category"])

cnstock_spyder = CnStockSpyder(config.DATABASE_NAME,
                               config.COLLECTION_NAME_CNSTOCK)
# Backfill historical data first: e.g. if news has been crawled up to 2020-12-01 but the
# real-time crawler is started on 2020-12-23, automatically backfill 2020-12-02 to 2020-12-23
for url_to_be_crawled, type_chn in config.WEBSITES_LIST_TO_BE_CRAWLED_CNSTOCK.items():
    # look up the timestamp of the most recent record in this category (type_chn)
    latest_date_in_db = max(df[df.Category == type_chn]["Date"].to_list())
    cnstock_spyder.get_historical_news(url_to_be_crawled,
                                       category_chn=type_chn,
                                       start_date=latest_date_in_db)

Deduplication(config.DATABASE_NAME, config.COLLECTION_NAME_CNSTOCK).run()
class GenStockNewsDB(object):

    def __init__(self):
        self.database = Database()
        # trading-day calendar from 1990-12-19 to 2020-12-31
        self.trade_date = ak.tool_trade_date_hist_sina()["trade_date"].tolist()
        self.label_range = {3: "3DaysLabel",
                            5: "5DaysLabel",
                            10: "10DaysLabel",
                            15: "15DaysLabel",
                            30: "30DaysLabel",
                            60: "60DaysLabel"}
        self.redis_client = redis.StrictRedis(host=config.REDIS_IP,
                                              port=config.REDIS_PORT,
                                              db=config.CACHE_NEWS_REDIS_DB_ID)
        self.redis_client.set("today_date", datetime.datetime.now().strftime("%Y-%m-%d"))
        self.redis_client.delete("stock_news_num_over_{}".format(config.MINIMUM_STOCK_NEWS_NUM_FOR_ML))
        self._stock_news_nums_stat()

    def get_all_news_about_specific_stock(self, database_name, collection_name):
        # Check the keys of collection_name for RelatedStockCodes; if it is missing, the
        # stock codes mentioned in each news article have not yet been stored in a
        # separate column, so build that column first
        _keys_list = list(next(self.database.get_collection(database_name, collection_name).find()).keys())
        if "RelatedStockCodes" not in _keys_list:
            tokenization = Tokenization(import_module="jieba", user_dict="./Leorio/financedict.txt")
            tokenization.update_news_database_rows(database_name, collection_name)
        # create one collection per stock, named after the stock code
        stock_symbol_list = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                   config.COLLECTION_NAME_STOCK_BASIC_INFO,
                                                   keys=["symbol"])["symbol"].to_list()
        col_names = self.database.connect_database(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE).list_collection_names(session=None)
        for symbol in stock_symbol_list:
            if symbol not in col_names:
                # if int(symbol[2:]) > 837:
                _collection = self.database.get_collection(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, symbol)
                _tmp_num_stat = 0
                for row in self.database.get_collection(database_name, collection_name).find():  # cursor (iterator)
                    if symbol[2:] in row["RelatedStockCodes"].split(" "):
                        # label the news by the price move n trading days after publication
                        _tmp_dict = {}
                        for label_days, key_name in self.label_range.items():
                            _tmp_res = self._label_news(
                                datetime.datetime.strptime(row["Date"].split(" ")[0], "%Y-%m-%d"), symbol, label_days)
                            _tmp_dict.update({key_name: _tmp_res})
                        _data = {"Date": row["Date"],
                                 "Url": row["Url"],
                                 "Title": row["Title"],
                                 "Article": row["Article"],
                                 "OriDB": database_name,
                                 "OriCOL": collection_name}
                        _data.update(_tmp_dict)
                        _collection.insert_one(_data)
                        _tmp_num_stat += 1
                logging.info("there are {} news articles mentioning {} in the {} collection that need to be fetched ... "
                             .format(_tmp_num_stat, symbol, collection_name))
            # else:
            #     logging.info("{} has fetched all related news from {}...".format(symbol, collection_name))

    def listen_redis_queue(self):
        # Listen to the Redis message queue; when a new real-time record arrives, save the
        # news into the collection of every stock listed in its "RelatedStockCodes" field.
        # E.g. if that field is "603386 603003 600111 603568", the news is stored into the
        # collections of all four of those stocks.
        crawled_url_today = set()
        while True:
            date_now = datetime.datetime.now().strftime("%Y-%m-%d")
            if date_now != self.redis_client.get("today_date").decode():
                crawled_url_today = set()
                self.redis_client.set("today_date", date_now)
            if self.redis_client.llen(config.CACHE_NEWS_LIST_NAME) != 0:
                data = json.loads(self.redis_client.lindex(config.CACHE_NEWS_LIST_NAME, -1))
                if data["Url"] not in crawled_url_today:  # skip URLs already handled today to avoid duplicate inserts
                    crawled_url_today.update({data["Url"]})
                    if data["RelatedStockCodes"] != "":
                        for stock_code in data["RelatedStockCodes"].split(" "):
                            # route the news into each related stock's collection
                            symbol = "sh{}".format(stock_code) if stock_code[0] == "6" else "sz{}".format(stock_code)
                            _collection = self.database.get_collection(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, symbol)
                            _tmp_dict = {}
                            for label_days, key_name in self.label_range.items():
                                _tmp_res = self._label_news(
                                    datetime.datetime.strptime(data["Date"].split(" ")[0], "%Y-%m-%d"), symbol, label_days)
                                _tmp_dict.update({key_name: _tmp_res})
                            _data = {"Date": data["Date"],
                                     "Url": data["Url"],
                                     "Title": data["Title"],
                                     "Article": data["Article"],
                                     "OriDB": data["OriDB"],
                                     "OriCOL": data["OriCOL"]}
                            _data.update(_tmp_dict)
                            _collection.insert_one(_data)
                            logging.info("the real-time fetched news {}, which was saved in [DB:{} - COL:{}] ...".format(data["Title"],
                                                                                                                         config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE,
                                                                                                                         symbol))
                            #
                            # if symbol.encode() in self.redis_client.lrange("stock_news_num_over_{}".format(config.MINIMUM_STOCK_NEWS_NUM_FOR_ML), 0, -1):
                            #     label_name = "3DaysLabel"
                            #     # classifier_save_path = "{}_classifier.pkl".format(symbol)
                            #     ori_dict_path = "{}_docs_dict.dict".format(symbol)
                            #     bowvec_save_path = "{}_bowvec.mm".format(symbol)
                            #
                            #     topicmodelling = TopicModelling()
                            #     chn_label = topicmodelling.classify_stock_news(data["Article"],
                            #                                                    config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE,
                            #                                                    symbol,
                            #                                                    label_name=label_name,
                            #                                                    topic_model_type="lsi",
                            #                                                    classifier_model="rdforest",  # rdforest / svm
                            #                                                    ori_dict_path=ori_dict_path,
                            #                                                    bowvec_save_path=bowvec_save_path)
                            #     logging.info(
                            #         "document '{}...' was classified with label '{}' for symbol {} ... ".format(
                            #             data["Article"][:20], chn_label, symbol))

                    self.redis_client.rpop(config.CACHE_NEWS_LIST_NAME)
                    logging.info("now pop {} from redis queue of [DB:{} - KEY:{}] ... ".format(data["Title"],
                                                                                               config.CACHE_NEWS_REDIS_DB_ID,
                                                                                               config.CACHE_NEWS_LIST_NAME))

    def _label_news(self, date, symbol, n_days):
        """
        :param date: datetime.datetime, the publication date of the news (date only,
            no time of day), e.g. datetime.datetime(2015, 1, 5, 0, 0)
        :param symbol: str, the stock symbol, e.g. sh600000
        :param n_days: int, how many days after publication the closing price is
            compared; if the close has risen n_days after the news, the news is
            considered positive
        """
        # work out the concrete calendar date n_days after the news was published
        this_date_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                symbol,
                                                query={"date": date})
        # Case: the publication date may be a non-trading day with no price data, so
        # search backwards. E.g. if the news came out on Saturday 2020-12-12, use the
        # close of 2020-12-11 as the price at publication time.
        tmp_date = date
        if this_date_data is None:
            i = 1
            while this_date_data is None and i <= 10:
                tmp_date = date - datetime.timedelta(days=i)  # step back one day at a time
                # only query the database if the date is a trading day; if this_date_data
                # stays None, the database has no data for that trading day
                if tmp_date.strftime("%Y-%m-%d") in self.trade_date:
                    this_date_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                            symbol,
                                                            query={"date": tmp_date})
                i += 1
        try:
            close_price_this_date = this_date_data["close"][0]
        except Exception:
            close_price_this_date = None
        # Case: the date n_days after publication may be a non-trading day, or no data
        # was collected for it, so search forward. E.g. for news published on 2020-12-08,
        # five days later is Sunday 2020-12-13, so use Monday 2020-12-14's close instead.
        new_date = date + datetime.timedelta(days=n_days)
        n_days_later_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                   symbol,
                                                   query={"date": new_date})
        if n_days_later_data is None:
            i = 1
            while n_days_later_data is None and i <= 10:
                new_date = date + datetime.timedelta(days=n_days+i)
                if new_date.strftime("%Y-%m-%d") in self.trade_date:
                    n_days_later_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                               symbol,
                                                               query={"date": new_date})
                i += 1
        try:
            close_price_n_days_later = n_days_later_data["close"][0]
        except Exception:
            close_price_n_days_later = None
        # Labeling rules:
        # (1) n_days <= 10: a rise (fall) of more than 3% over n_days trading days makes
        #     the news positive (negative); within the 3% band it is neutral
        # (2) 10 < n_days <= 15: the threshold is 5%
        # (3) 15 < n_days <= 30: the threshold is 10%
        # (4) 30 < n_days <= 60: the threshold is 15%
        # Note: a "neutral" news item is one quickly absorbed by the market with no lasting impact
        param = 0.01
        if n_days <= 10:
            param = 0.03
        elif 10 < n_days <= 15:
            param = 0.05
        elif 15 < n_days <= 30:
            param = 0.10
        elif 30 < n_days <= 60:
            param = 0.15
        if close_price_this_date is not None and close_price_n_days_later is not None:
            if (close_price_n_days_later - close_price_this_date) / close_price_this_date > param:
                return "利好"  # positive (bullish)
            elif (close_price_n_days_later - close_price_this_date) / close_price_this_date < -param:
                return "利空"  # negative (bearish)
            else:
                return "中性"  # neutral
        else:
            return ""

    def _stock_news_nums_stat(self):
        cols_list = self.database.connect_database(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE).list_collection_names(session=None)
        for sym in cols_list:
            if self.database.get_collection(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, sym).estimated_document_count() > config.MINIMUM_STOCK_NEWS_NUM_FOR_ML:
                self.redis_client.lpush("stock_news_num_over_{}".format(config.MINIMUM_STOCK_NEWS_NUM_FOR_ML), sym)
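listen_redis_queue above reads the oldest item from the right end of config.CACHE_NEWS_LIST_NAME and pops it with rpop, so a producer (for example, one of the real-time spyders) would push JSON records onto the left end. A hedged sketch of that producer side; every field value below is illustrative.

import json
import redis

client = redis.StrictRedis(host=config.REDIS_IP,
                           port=config.REDIS_PORT,
                           db=config.CACHE_NEWS_REDIS_DB_ID)
news = {
    "Date": "2020-12-23 09:30:00",                       # illustrative values only
    "Url": "http://www.nbd.com.cn/articles/example.html",
    "Title": "示例新闻",
    "Article": "示例正文",
    "RelatedStockCodes": "600000 000001",
    "OriDB": config.DATABASE_NAME,
    "OriCOL": config.COLLECTION_NAME_CNSTOCK
}
client.lpush(config.CACHE_NEWS_LIST_NAME, json.dumps(news))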
class CnStockSpyder(Spyder):
    def __init__(self, database_name, collection_name):
        super(CnStockSpyder, self).__init__()
        self.db_obj = Database()
        self.col = self.db_obj.conn[database_name].get_collection(
            collection_name)
        self.terminated_amount = 0
        self.db_name = database_name
        self.col_name = collection_name
        self.tokenization = Tokenization(
            import_module="jieba", user_dict=config.USER_DEFINED_DICT_PATH)
        self.redis_client = redis.StrictRedis(host=config.REDIS_IP,
                                              port=config.REDIS_PORT,
                                              db=config.CACHE_NEWS_REDIS_DB_ID)

    def get_url_info(self, url):
        try:
            bs = utils.html_parser(url)
        except Exception:
            return False
        span_list = bs.find_all("span")
        part = bs.find_all("p")
        article = ""
        date = ""
        for span in span_list:
            if "class" in span.attrs and span["class"] == ["timer"]:
                date = span.text
                break
        for paragraph in part:
            chn_status = utils.count_chn(str(paragraph))
            possible = chn_status[1]
            if possible > self.is_article_prob:
                article += str(paragraph)
        while article.find("<") != -1 and article.find(">") != -1:
            string = article[article.find("<"):article.find(">") + 1]
            article = article.replace(string, "")
        while article.find("\u3000") != -1:
            article = article.replace("\u3000", "")
        article = " ".join(re.split(" +|\n+", article)).strip()

        return [date, article]

    def get_historical_news(self, url, category_chn=None, start_date=None):
        """
        :param url: the page to crawl
        :param category_chn: category name as a Chinese string, one of '公司聚焦',
            '公告解读', '公告快讯', '利好公告'
        :param start_date: timestamp of the latest record of category category_chn
            already in the database
        """
        assert category_chn is not None
        driver = webdriver.Chrome(executable_path=config.CHROME_DRIVER)
        btn_more_text = ""
        crawled_urls_list = self.extract_data(["Url"])[0]
        logging.info("historical data length -> {} ... ".format(
            len(crawled_urls_list)))
        # crawled_urls_list = []
        driver.get(url)
        name_code_df = self.db_obj.get_data(
            config.STOCK_DATABASE_NAME,
            config.COLLECTION_NAME_STOCK_BASIC_INFO,
            keys=["name", "code"])
        name_code_dict = dict(name_code_df.values)
        if start_date is None:
            while btn_more_text != "没有更多":  # button text: "no more"
                more_btn = driver.find_element_by_id('j_more_btn')
                btn_more_text = more_btn.text
                logging.info("1-{}".format(more_btn.text))
                if btn_more_text == "加载更多":  # "load more"
                    more_btn.click()
                    time.sleep(random.random())  # sleep a random time under 1 s
                elif btn_more_text == "加载中...":  # "loading..."
                    time.sleep(random.random() + 2)
                    more_btn = driver.find_element_by_id('j_more_btn')
                    btn_more_text = more_btn.text
                    logging.info("2-{}".format(more_btn.text))
                    if btn_more_text == "加载更多":
                        more_btn.click()
                else:
                    more_btn.click()
                    break
            bs = BeautifulSoup(driver.page_source, "html.parser")
            for li in bs.find_all("li", attrs={"class": ["newslist"]}):
                a = li.find_all("h2")[0].find("a")
                if a["href"] not in crawled_urls_list:
                    result = self.get_url_info(a["href"])
                    while not result:
                        self.terminated_amount += 1
                        if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                            # record URLs that could never be crawled
                            with open(
                                    config.
                                    RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH,
                                    "a+") as file:
                                file.write("{}\n".format(a["href"]))
                            logging.info(
                                "rejected by remote server longer than {} minutes, "
                                "and the failed url has been written in path {}"
                                .format(
                                    config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                                    config.
                                    RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                            break
                        logging.info(
                            "rejected by remote server, request {} again after "
                            "{} seconds...".format(a["href"], 60 *
                                                   self.terminated_amount))
                        time.sleep(60 * self.terminated_amount)
                        result = self.get_url_info(a["href"])
                    if not result:
                        # the crawl ultimately failed
                        logging.info("[FAILED] {} {}".format(
                            a["title"], a["href"]))
                    else:
                        # got a result, but the article may still be empty
                        date, article = result
                        while article == "" and self.is_article_prob >= .1:
                            self.is_article_prob -= .1
                            result = self.get_url_info(a["href"])
                            while not result:
                                self.terminated_amount += 1
                                if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                                    # record URLs that could never be crawled
                                    with open(
                                            config.
                                            RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH,
                                            "a+") as file:
                                        file.write("{}\n".format(a["href"]))
                                    logging.info(
                                        "rejected by remote server longer than {} minutes, "
                                        "and the failed url has been written in path {}"
                                        .format(
                                            config.
                                            CNSTOCK_MAX_REJECTED_AMOUNTS,
                                            config.
                                            RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH
                                        ))
                                    break
                                logging.info(
                                    "rejected by remote server, request {} again after "
                                    "{} seconds...".format(
                                        a["href"],
                                        60 * self.terminated_amount))
                                time.sleep(60 * self.terminated_amount)
                                result = self.get_url_info(a["href"])
                            date, article = result
                        self.is_article_prob = .5
                        if article != "":
                            related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article(
                                article, name_code_dict)
                            data = {
                                "Date": date,
                                "Category": category_chn,
                                "Url": a["href"],
                                "Title": a["title"],
                                "Article": article,
                                "RelatedStockCodes": " ".join(
                                    related_stock_codes_list)
                            }
                            # self.col.insert_one(data)
                            self.db_obj.insert_data(self.db_name,
                                                    self.col_name, data)
                            logging.info("[SUCCESS] {} {} {}".format(
                                date, a["title"], a["href"]))
        else:
            # When start_date is not None, backfill the historical data published after that date
            is_click_button = True
            start_get_url_info = False
            tmp_a = None
            while is_click_button:
                bs = BeautifulSoup(driver.page_source, "html.parser")
                for li in bs.find_all("li", attrs={"class": ["newslist"]}):
                    a = li.find_all("h2")[0].find("a")
                    if tmp_a is not None and a["href"] != tmp_a:
                        continue
                    elif tmp_a is not None and a["href"] == tmp_a:
                        start_get_url_info = True
                    if start_get_url_info:
                        date, _ = self.get_url_info(a["href"])
                        if date <= start_date:
                            is_click_button = False
                            break
                tmp_a = a["href"]
                if is_click_button:
                    more_btn = driver.find_element_by_id('j_more_btn')
                    more_btn.click()
            # Everything from the first news item down to tmp_a is newly added, excluding tmp_a itself
            bs = BeautifulSoup(driver.page_source, "html.parser")
            for li in bs.find_all("li", attrs={"class": ["newslist"]}):
                a = li.find_all("h2")[0].find("a")
                if a["href"] != tmp_a:
                    result = self.get_url_info(a["href"])
                    while not result:
                        self.terminated_amount += 1
                        if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                            # Save URLs that could never be crawled
                            with open(config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH,
                                      "a+") as file:
                                file.write("{}\n".format(a["href"]))
                            logging.info(
                                "rejected by remote server longer than {} minutes, "
                                "and the failed url has been written in path {}"
                                .format(config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                                        config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                            break
                        logging.info(
                            "rejected by remote server, request {} again after "
                            "{} seconds...".format(a["href"], 60 *
                                                   self.terminated_amount))
                        time.sleep(60 * self.terminated_amount)
                        result = self.get_url_info(a["href"])
                    if not result:
                        # Case: crawling failed
                        logging.info("[FAILED] {} {}".format(
                            a["title"], a["href"]))
                    else:
                        # Case: a result was returned but the article text is empty
                        date, article = result
                        while article == "" and self.is_article_prob >= .1:
                            self.is_article_prob -= .1
                            result = self.get_url_info(a["href"])
                            while not result:
                                self.terminated_amount += 1
                                if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                                    # Save URLs that could never be crawled
                                    with open(config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH,
                                              "a+") as file:
                                        file.write("{}\n".format(a["href"]))
                                    logging.info(
                                        "rejected by remote server longer than {} minutes, "
                                        "and the failed url has been written in path {}"
                                        .format(config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                                                config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                                    break
                                logging.info(
                                    "rejected by remote server, request {} again after "
                                    "{} seconds...".format(
                                        a["href"],
                                        60 * self.terminated_amount))
                                time.sleep(60 * self.terminated_amount)
                                result = self.get_url_info(a["href"])
                            date, article = result
                        self.is_article_prob = .5
                        if article != "":
                            related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article(
                                article, name_code_dict)
                            data = {
                                "Date": date,
                                "Category": category_chn,
                                "Url": a["href"],
                                "Title": a["title"],
                                "Article": article,
                                "RelatedStockCodes": " ".join(related_stock_codes_list)
                            }
                            # self.col.insert_one(data)
                            self.db_obj.insert_data(self.db_name,
                                                    self.col_name, data)
                            logging.info("[SUCCESS] {} {} {}".format(
                                date, a["title"], a["href"]))
                else:
                    break
        driver.quit()

    def get_realtime_news(self, url, category_chn=None, interval=60):
        logging.info(
            "start real-time crawling of URL -> {}, request every {} secs ... "
            .format(url, interval))
        assert category_chn is not None
        # TODO: since the volume of data crawled from cnstock is small, all historical URLs are currently loaded for deduplication; the dedup strategy will be revised later
        name_code_df = self.db_obj.get_data(
            config.STOCK_DATABASE_NAME,
            config.COLLECTION_NAME_STOCK_BASIC_INFO,
            keys=["name", "code"])
        name_code_dict = dict(name_code_df.values)
        crawled_urls = self.db_obj.get_data(self.db_name,
                                            self.col_name,
                                            keys=["Url"])["Url"].to_list()
        while True:
            # Poll this URL at a fixed interval
            bs = utils.html_parser(url)
            for li in bs.find_all("li", attrs={"class": ["newslist"]}):
                a = li.find_all("h2")[0].find("a")
                if a["href"] not in crawled_urls:  # latest_3_days_crawled_href
                    result = self.get_url_info(a["href"])
                    while not result:
                        self.terminated_amount += 1
                        if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                            # Save URLs that could never be crawled
                            with open(config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH,
                                      "a+") as file:
                                file.write("{}\n".format(a["href"]))
                            logging.info(
                                "rejected by remote server longer than {} minutes, "
                                "and the failed url has been written in path {}"
                                .format(config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                                        config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                            break
                        logging.info(
                            "rejected by remote server, request {} again after "
                            "{} seconds...".format(a["href"], 60 *
                                                   self.terminated_amount))
                        time.sleep(60 * self.terminated_amount)
                        result = self.get_url_info(a["href"])
                    if not result:
                        # Case: crawling failed
                        logging.info("[FAILED] {} {}".format(
                            a["title"], a["href"]))
                    else:
                        # Case: a result was returned but the article text is empty
                        date, article = result
                        while article == "" and self.is_article_prob >= .1:
                            self.is_article_prob -= .1
                            result = self.get_url_info(a["href"])
                            while not result:
                                self.terminated_amount += 1
                                if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                                    # Save URLs that could never be crawled
                                    with open(config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH,
                                              "a+") as file:
                                        file.write("{}\n".format(a["href"]))
                                    logging.info(
                                        "rejected by remote server longer than {} minutes, "
                                        "and the failed url has been written in path {}"
                                        .format(config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                                                config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                                    break
                                logging.info(
                                    "rejected by remote server, request {} again after "
                                    "{} seconds...".format(
                                        a["href"],
                                        60 * self.terminated_amount))
                                time.sleep(60 * self.terminated_amount)
                                result = self.get_url_info(a["href"])
                            date, article = result
                        self.is_article_prob = .5
                        if article != "":
                            related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article(
                                article, name_code_dict)
                            self.db_obj.insert_data(
                                self.db_name, self.col_name, {
                                    "Date": date,
                                    "Category": category_chn,
                                    "Url": a["href"],
                                    "Title": a["title"],
                                    "Article": article,
                                    "RelatedStockCodes": " ".join(related_stock_codes_list)
                                })
                            self.redis_client.lpush(
                                config.CACHE_NEWS_LIST_NAME,
                                json.dumps({
                                    "Date": date,
                                    "Category": category_chn,
                                    "Url": a["href"],
                                    "Title": a["title"],
                                    "Article": article,
                                    "RelatedStockCodes": " ".join(related_stock_codes_list),
                                    "OriDB": config.DATABASE_NAME,
                                    "OriCOL": config.COLLECTION_NAME_CNSTOCK
                                }))
                            logging.info("[SUCCESS] {} {} {}".format(
                                date, a["title"], a["href"]))
                            crawled_urls.append(a["href"])
            # logging.info("sleep {} secs then request {} again ... ".format(interval, url))
            time.sleep(interval)
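
The retry logic above (sleep 60 * self.terminated_amount seconds, give up after config.CNSTOCK_MAX_REJECTED_AMOUNTS attempts and append the URL to the failure file) is repeated around every call to get_url_info. A minimal sketch of how that pattern could be factored into a single helper is shown below; the name fetch_with_backoff and its parameters are illustrative assumptions, not part of the original code.

import logging
import time


def fetch_with_backoff(get_url_info, href, max_rejected, failed_url_path):
    """Retry get_url_info(href) with a growing delay, mirroring the loops above.

    Returns the parsed result, or None once max_rejected attempts have failed,
    in which case the URL is appended to failed_url_path.
    """
    attempts = 0
    result = get_url_info(href)
    while not result:
        attempts += 1
        if attempts > max_rejected:
            # Persist URLs that could never be crawled, as the spiders above do
            with open(failed_url_path, "a+") as file:
                file.write("{}\n".format(href))
            logging.info("rejected by remote server longer than {} minutes, "
                         "failed url written to {}".format(max_rejected, failed_url_path))
            return None
        logging.info("rejected by remote server, request {} again after "
                     "{} seconds...".format(href, 60 * attempts))
        time.sleep(60 * attempts)
        result = get_url_info(href)
    return result
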
class JrjSpyder(Spyder):

    def __init__(self, database_name, collection_name):
        super(JrjSpyder, self).__init__()
        self.db_obj = Database()
        self.col = self.db_obj.conn[database_name].get_collection(collection_name)
        self.terminated_amount = 0
        self.db_name = database_name
        self.col_name = collection_name
        self.tokenization = Tokenization(import_module="jieba", user_dict=config.USER_DEFINED_DICT_PATH)

    def get_url_info(self, url, specific_date):
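        # Returns [date, article], where article is the concatenation of plain <p> paragraphs;
        # the publish date comes from the span whose first child is "jrj_final_date_start" and
        # falls back to specific_date (the archive date being crawled) when that marker is
        # missing. Returns False if the page cannot be fetched or parsed.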
        try:
            bs = utils.html_parser(url)
        except Exception:
            return False
        date = ""
        for span in bs.find_all("span"):
            if span.contents[0] == "jrj_final_date_start":
                date = span.text.replace("\r", "").replace("\n", "")
                break
        if date == "":
            date = specific_date
        article = ""
        for p in bs.find_all("p"):
            if not p.find_all("jrj_final_daohang_start") and p.attrs == {} and \
                    not p.find_all("input") and not p.find_all("a", attrs={"class": "red"}) and not p.find_all("i") and not p.find_all("span"):
                # if p.contents[0] != "jrj_final_daohang_start1" and p.attrs == {} and \
                #         not p.find_all("input") and not p.find_all("a", attrs={"class": "red"}) and not p.find_all("i"):
                article += p.text.replace("\r", "").replace("\n", "").replace("\u3000", "")

        return [date, article]

    def get_historical_news(self, url, start_date=None, end_date=None):
        name_code_df = self.db_obj.get_data(config.STOCK_DATABASE_NAME,
                                            config.COLLECTION_NAME_STOCK_BASIC_INFO,
                                            keys=["name", "code"])
        name_code_dict = dict(name_code_df.values)

        crawled_urls_list = []
        if end_date is None:
            end_date = datetime.datetime.now().strftime("%Y-%m-%d")

        if start_date is None:
            # If start_date is None, backfill from the day after the latest date already in the database up to end_date
            # e.g. history_latest_date_str -> "2020-12-08"
            #      history_latest_date_dt -> datetime.date(2020, 12, 8)
            #      start_date -> "2020-12-09"
            history_latest_date_list = self.db_obj.get_data(self.db_name,
                                                            self.col_name,
                                                            keys=["Date"])["Date"].to_list()
            if len(history_latest_date_list) != 0:
                history_latest_date_str = max(history_latest_date_list).split(" ")[0]
                history_latest_date_dt = datetime.datetime.strptime(history_latest_date_str, "%Y-%m-%d").date()
                offset = datetime.timedelta(days=1)
                start_date = (history_latest_date_dt + offset).strftime('%Y-%m-%d')
            else:
                start_date = config.JRJ_REQUEST_DEFAULT_DATE

        dates_list = utils.get_date_list_from_range(start_date, end_date)
        dates_separated_into_ranges_list = utils.gen_dates_list(dates_list, config.JRJ_DATE_RANGE)

        for dates_range in dates_separated_into_ranges_list:
            for date in dates_range:
                first_url = "{}/{}/{}_1.shtml".format(url, date.replace("-", "")[0:6], date.replace("-", ""))
                max_pages_num = utils.search_max_pages_num(first_url, date)
                for num in range(1, max_pages_num + 1):
                    _url = "{}/{}/{}_{}.shtml".format(url, date.replace("-", "")[0:6], date.replace("-", ""), str(num))
                    bs = utils.html_parser(_url)
                    a_list = bs.find_all("a")
                    for a in a_list:
                        if "href" in a.attrs and a.string and \
                                a["href"].find("/{}/{}/".format(date.replace("-", "")[:4],
                                                                date.replace("-", "")[4:6])) != -1:
                            if a["href"] not in crawled_urls_list:
                                # Only write to the database if the title does not contain phrases such as "收盘" or "报于", since such titles usually mark machine-generated news
                                if a.string.find("收盘") == -1 and a.string.find("报于") == -1 and \
                                        a.string.find("新三板挂牌上市") == -1:
                                    result = self.get_url_info(a["href"], date)
                                    while not result:
                                        self.terminated_amount += 1
                                        if self.terminated_amount > config.JRJ_MAX_REJECTED_AMOUNTS:
                                            # Save URLs that could never be crawled
                                            with open(config.RECORD_JRJ_FAILED_URL_TXT_FILE_PATH, "a+") as file:
                                                file.write("{}\n".format(a["href"]))
                                            logging.info("rejected by remote server longer than {} minutes, "
                                                         "and the failed url has been written in path {}"
                                                         .format(config.JRJ_MAX_REJECTED_AMOUNTS,
                                                                 config.RECORD_JRJ_FAILED_URL_TXT_FILE_PATH))
                                            break
                                        logging.info("rejected by remote server, request {} again after "
                                                     "{} seconds...".format(a["href"], 60 * self.terminated_amount))
                                        time.sleep(60 * self.terminated_amount)
                                        result = self.get_url_info(a["href"], date)
                                    if not result:
                                        # Case: crawling failed
                                        logging.info("[FAILED] {} {}".format(a.string, a["href"]))
                                    else:
                                        # Case: a result was returned but the article text is empty
                                        article_specific_date, article = result
                                        while article == "" and self.is_article_prob >= .1:
                                            self.is_article_prob -= .1
                                            result = self.get_url_info(a["href"], date)
                                            while not result:
                                                self.terminated_amount += 1
                                                if self.terminated_amount > config.JRJ_MAX_REJECTED_AMOUNTS:
                                                    # Save URLs that could never be crawled
                                                    with open(config.RECORD_JRJ_FAILED_URL_TXT_FILE_PATH, "a+") as file:
                                                        file.write("{}\n".format(a["href"]))
                                                    logging.info("rejected by remote server longer than {} minutes, "
                                                                 "and the failed url has been written in path {}"
                                                                 .format(config.JRJ_MAX_REJECTED_AMOUNTS,
                                                                         config.RECORD_JRJ_FAILED_URL_TXT_FILE_PATH))
                                                    break
                                                logging.info("rejected by remote server, request {} again after "
                                                             "{} seconds...".format(a["href"],
                                                                                    60 * self.terminated_amount))
                                                time.sleep(60 * self.terminated_amount)
                                                result = self.get_url_info(a["href"], date)
                                            article_specific_date, article = result
                                        self.is_article_prob = .5
                                        if article != "":
                                            related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article(
                                                article, name_code_dict)
                                            data = {"Date": article_specific_date,
                                                    "Url": a["href"],
                                                    "Title": a.string,
                                                    "Article": article,
                                                    "RelatedStockCodes": " ".join(related_stock_codes_list)}
                                            # self.col.insert_one(data)
                                            self.db_obj.insert_data(self.db_name, self.col_name, data)
                                            logging.info("[SUCCESS] {} {} {}".format(article_specific_date,
                                                                                     a.string,
                                                                                     a["href"]))
                                    self.terminated_amount = 0  # reset this counter once the crawl of this URL finishes
                                else:
                                    logging.info("[QUIT] {}".format(a.string))

    def get_realtime_news(self, interval=60):
        name_code_df = self.db_obj.get_data(config.STOCK_DATABASE_NAME,
                                            config.COLLECTION_NAME_STOCK_BASIC_INFO,
                                            keys=["name", "code"])
        name_code_dict = dict(name_code_df.values)
        crawled_urls_list = []
        is_change_date = False
        last_date = datetime.datetime.now().strftime("%Y-%m-%d")
        while True:
            today_date = datetime.datetime.now().strftime("%Y-%m-%d")
            if today_date != last_date:
                is_change_date = True
                last_date = today_date
            if is_change_date:
                crawled_urls_list = []
                is_change_date = False
            _url = "{}/{}/{}_1.shtml".format(config.WEBSITES_LIST_TO_BE_CRAWLED_JRJ,
                                             today_date.replace("-", "")[0:6],
                                             today_date.replace("-", ""))
            max_pages_num = utils.search_max_pages_num(_url, today_date)
            for num in range(1, max_pages_num + 1):
                _url = "{}/{}/{}_{}.shtml".format(config.WEBSITES_LIST_TO_BE_CRAWLED_JRJ,
                                                  today_date.replace("-", "")[0:6],
                                                  today_date.replace("-", ""),
                                                  str(num))
                bs = utils.html_parser(_url)
                a_list = bs.find_all("a")
                for a in a_list:
                    if "href" in a.attrs and a.string and \
                            a["href"].find("/{}/{}/".format(today_date.replace("-", "")[:4],
                                                            today_date.replace("-", "")[4:6])) != -1:
                        if a["href"] not in crawled_urls_list:
                            # Only write to the database if the title does not contain phrases such as "收盘" or "报于", since such titles usually mark machine-generated news
                            if a.string.find("收盘") == -1 and a.string.find("报于") == -1 and \
                                    a.string.find("新三板挂牌上市") == -1:
                                result = self.get_url_info(a["href"], today_date)
                                while not result:
                                    self.terminated_amount += 1
                                    if self.terminated_amount > config.JRJ_MAX_REJECTED_AMOUNTS:
                                        # Save URLs that could never be crawled
                                        with open(config.RECORD_JRJ_FAILED_URL_TXT_FILE_PATH, "a+") as file:
                                            file.write("{}\n".format(a["href"]))
                                        logging.info("rejected by remote server longer than {} minutes, "
                                                     "and the failed url has been written in path {}"
                                                     .format(config.JRJ_MAX_REJECTED_AMOUNTS,
                                                             config.RECORD_JRJ_FAILED_URL_TXT_FILE_PATH))
                                        break
                                    logging.info("rejected by remote server, request {} again after "
                                                 "{} seconds...".format(a["href"], 60 * self.terminated_amount))
                                    time.sleep(60 * self.terminated_amount)
                                    result = self.get_url_info(a["href"], today_date)
                                if not result:
                                    # Case: crawling failed
                                    logging.info("[FAILED] {} {}".format(a.string, a["href"]))
                                else:
                                    # Case: a result was returned but the article text is empty
                                    article_specific_date, article = result
                                    while article == "" and self.is_article_prob >= .1:
                                        self.is_article_prob -= .1
                                        result = self.get_url_info(a["href"], today_date)
                                        while not result:
                                            self.terminated_amount += 1
                                            if self.terminated_amount > config.JRJ_MAX_REJECTED_AMOUNTS:
                                                # Save URLs that could never be crawled
                                                with open(config.RECORD_JRJ_FAILED_URL_TXT_FILE_PATH, "a+") as file:
                                                    file.write("{}\n".format(a["href"]))
                                                logging.info("rejected by remote server longer than {} minutes, "
                                                             "and the failed url has been written in path {}"
                                                             .format(config.JRJ_MAX_REJECTED_AMOUNTS,
                                                                     config.RECORD_JRJ_FAILED_URL_TXT_FILE_PATH))
                                                break
                                            logging.info("rejected by remote server, request {} again after "
                                                         "{} seconds...".format(a["href"],
                                                                                60 * self.terminated_amount))
                                            time.sleep(60 * self.terminated_amount)
                                            result = self.get_url_info(a["href"], today_date)
                                        article_specific_date, article = result
                                    self.is_article_prob = .5
                                    if article != "":
                                        related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article(article,
                                                                                                                          name_code_dict)
                                        data = {"Date": article_specific_date,
                                                "Url": a["href"],
                                                "Title": a.string,
                                                "Article": article,
                                                "RelatedStockCodes": " ".join(related_stock_codes_list)}
                                        # self.col.insert_one(data)
                                        self.db_obj.insert_data(self.db_name, self.col_name, data)
                                        logging.info("[SUCCESS] {} {} {}".format(article_specific_date,
                                                                                 a.string,
                                                                                 a["href"]))
                                self.terminated_amount = 0  # reset this counter once the crawl of this URL finishes
                            else:
                                logging.info("[QUIT] {}".format(a.string))
                            crawled_urls_list.append(a["href"])
            # logging.info("sleep {} secs then request again ... ".format(interval))
            time.sleep(interval)
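
JrjSpyder walks the JRJ daily archive by requesting one page per date and page number, using the "{base}/{yyyymm}/{yyyymmdd}_{page}.shtml" pattern visible in both methods above. A standalone sketch of that URL construction follows; build_jrj_page_url is a hypothetical helper name introduced only for illustration.

def build_jrj_page_url(base_url, date_str, page_num):
    """Build the '{base}/{yyyymm}/{yyyymmdd}_{page}.shtml' archive URL used above.

    date_str is expected in 'YYYY-MM-DD' form, matching what
    utils.get_date_list_from_range yields in get_historical_news.
    """
    compact = date_str.replace("-", "")  # '2020-12-09' -> '20201209'
    return "{}/{}/{}_{}.shtml".format(base_url, compact[0:6], compact, page_num)


# With base taken from config.WEBSITES_LIST_TO_BE_CRAWLED_JRJ (whatever it is configured to),
# build_jrj_page_url(base, "2020-12-09", 1) ends with "/202012/20201209_1.shtml".
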
class CnStockSpyder(Spyder):

    def __init__(self, database_name, collection_name):
        super(CnStockSpyder, self).__init__()
        self.db_obj = Database()
        self.col = self.db_obj.conn[database_name].get_collection(collection_name)
        self.terminated_amount = 0
        self.db_name = database_name
        self.col_name = collection_name

    def get_url_info(self, url):
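        # Returns [date, article] or False if the page cannot be fetched. The date comes from
        # the <span class="timer"> element; a paragraph is kept in the article only when its
        # Chinese-character ratio (utils.count_chn) exceeds self.is_article_prob, after which
        # residual HTML tags, full-width spaces and extra whitespace are stripped.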
        try:
            bs = utils.html_parser(url)
        except Exception:
            return False
        span_list = bs.find_all("span")
        part = bs.find_all("p")
        article = ""
        date = ""
        for span in span_list:
            if "class" in span.attrs and span["class"] == ["timer"]:
                date = span.text
                break
        for paragraph in part:
            chn_status = utils.count_chn(str(paragraph))
            possible = chn_status[1]
            if possible > self.is_article_prob:
                article += str(paragraph)
        while article.find("<") != -1 and article.find(">") != -1:
            string = article[article.find("<"):article.find(">")+1]
            article = article.replace(string, "")
        while article.find("\u3000") != -1:
            article = article.replace("\u3000", "")
        article = " ".join(re.split(" +|\n+", article)).strip()

        return [date, article]

    def get_historical_news(self, url, category_chn=None, start_date=None):
        """
        :param url: page to crawl
        :param category_chn: category name as a Chinese string, one of '公司聚焦', '公告解读', '公告快讯', '利好公告'
        :param start_date: date of the most recent news item of category category_chn already stored in the database
        """
        assert category_chn is not None
        driver = webdriver.Chrome(executable_path=config.CHROME_DRIVER)
        btn_more_text = ""
        crawled_urls_list = self.extract_data(["Url"])[0]
        logging.info("historical data length -> {} ... ".format(len(crawled_urls_list)))
        # crawled_urls_list = []
        driver.get(url)
        if start_date is None:
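            # With no start_date to resume from, keep clicking the "加载更多" (load more) button
            # until it reads "没有更多" (no more), then parse the fully expanded list in one pass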
            while btn_more_text != "没有更多":
                more_btn = driver.find_element_by_id('j_more_btn')
                btn_more_text = more_btn.text
                logging.info("1-{}".format(more_btn.text))
                if btn_more_text == "加载更多":
                    more_btn.click()
                    time.sleep(random.random())  # sleep random time less 1s
                elif btn_more_text == "加载中...":
                    time.sleep(random.random()+2)
                    more_btn = driver.find_element_by_id('j_more_btn')
                    btn_more_text = more_btn.text
                    logging.info("2-{}".format(more_btn.text))
                    if btn_more_text == "加载更多":
                        more_btn.click()
                else:
                    more_btn.click()
                    break
            bs = BeautifulSoup(driver.page_source, "html.parser")
            for li in bs.find_all("li", attrs={"class": ["newslist"]}):
                a = li.find_all("h2")[0].find("a")
                if a["href"] not in crawled_urls_list:
                    result = self.get_url_info(a["href"])
                    while not result:
                        self.terminated_amount += 1
                        if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                            # Save URLs that could never be crawled
                            with open(config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH, "a+") as file:
                                file.write("{}\n".format(a["href"]))
                            logging.info("rejected by remote server longer than {} minutes, "
                                         "and the failed url has been written in path {}"
                                         .format(config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                                                 config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                            break
                        logging.info("rejected by remote server, request {} again after "
                                     "{} seconds...".format(a["href"], 60 * self.terminated_amount))
                        time.sleep(60 * self.terminated_amount)
                        result = self.get_url_info(a["href"])
                    if not result:
                        # Case: crawling failed
                        logging.info("[FAILED] {} {}".format(a["title"], a["href"]))
                    else:
                        # Case: a result was returned but the article text is empty
                        date, article = result
                        while article == "" and self.is_article_prob >= .1:
                            self.is_article_prob -= .1
                            result = self.get_url_info(a["href"])
                            while not result:
                                self.terminated_amount += 1
                                if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                                    # Save URLs that could never be crawled
                                    with open(config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH, "a+") as file:
                                        file.write("{}\n".format(a["href"]))
                                    logging.info("rejected by remote server longer than {} minutes, "
                                                 "and the failed url has been written in path {}"
                                                 .format(config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                                                         config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                                    break
                                logging.info("rejected by remote server, request {} again after "
                                             "{} seconds...".format(a["href"], 60 * self.terminated_amount))
                                time.sleep(60 * self.terminated_amount)
                                result = self.get_url_info(a["href"])
                            date, article = result
                        self.is_article_prob = .5
                        if article != "":
                            data = {"Date": date,
                                    "Category": category_chn,
                                    "Url": a["href"],
                                    "Title": a["title"],
                                    "Article": article}
                            # self.col.insert_one(data)
                            self.db_obj.insert_data(self.db_name, self.col_name, data)
                            logging.info("[SUCCESS] {} {} {}".format(date, a["title"], a["href"]))
        else:
            # When start_date is not None, backfill the historical data published after that date
            is_click_button = True
            start_get_url_info = False
            tmp_a = None
            while is_click_button:
                bs = BeautifulSoup(driver.page_source, "html.parser")
                for li in bs.find_all("li", attrs={"class": ["newslist"]}):
                    a = li.find_all("h2")[0].find("a")
                    if tmp_a is not None and a["href"] != tmp_a:
                        continue
                    elif tmp_a is not None and a["href"] == tmp_a:
                        start_get_url_info = True
                    if start_get_url_info:
                        date, _ = self.get_url_info(a["href"])
                        if date <= start_date:
                            is_click_button = False
                            break
                tmp_a = a["href"]
                if is_click_button:
                    more_btn = driver.find_element_by_id('j_more_btn')
                    more_btn.click()
            # Everything from the first news item down to tmp_a is newly added, excluding tmp_a itself
            bs = BeautifulSoup(driver.page_source, "html.parser")
            for li in bs.find_all("li", attrs={"class": ["newslist"]}):
                a = li.find_all("h2")[0].find("a")
                if a["href"] != tmp_a:
                    result = self.get_url_info(a["href"])
                    while not result:
                        self.terminated_amount += 1
                        if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                            # Save URLs that could never be crawled
                            with open(config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH, "a+") as file:
                                file.write("{}\n".format(a["href"]))
                            logging.info("rejected by remote server longer than {} minutes, "
                                         "and the failed url has been written in path {}"
                                         .format(config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                                                 config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                            break
                        logging.info("rejected by remote server, request {} again after "
                                     "{} seconds...".format(a["href"], 60 * self.terminated_amount))
                        time.sleep(60 * self.terminated_amount)
                        result = self.get_url_info(a["href"])
                    if not result:
                        # Case: crawling failed
                        logging.info("[FAILED] {} {}".format(a["title"], a["href"]))
                    else:
                        # Case: a result was returned but the article text is empty
                        date, article = result
                        while article == "" and self.is_article_prob >= .1:
                            self.is_article_prob -= .1
                            result = self.get_url_info(a["href"])
                            while not result:
                                self.terminated_amount += 1
                                if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                                    # Save URLs that could never be crawled
                                    with open(config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH, "a+") as file:
                                        file.write("{}\n".format(a["href"]))
                                    logging.info("rejected by remote server longer than {} minutes, "
                                                 "and the failed url has been written in path {}"
                                                 .format(config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                                                         config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                                    break
                                logging.info("rejected by remote server, request {} again after "
                                             "{} seconds...".format(a["href"], 60 * self.terminated_amount))
                                time.sleep(60 * self.terminated_amount)
                                result = self.get_url_info(a["href"])
                            date, article = result
                        self.is_article_prob = .5
                        if article != "":
                            data = {"Date": date,
                                    "Category": category_chn,
                                    "Url": a["href"],
                                    "Title": a["title"],
                                    "Article": article}
                            # self.col.insert_one(data)
                            self.db_obj.insert_data(self.db_name, self.col_name, data)
                            logging.info("[SUCCESS] {} {} {}".format(date, a["title"], a["href"]))
                else:
                    break
        driver.quit()

    def get_realtime_news(self, url, category_chn=None, interval=60):
        logging.info("start real-time crawling of URL -> {} ... ".format(url))
        assert category_chn is not None
        # today_date = time.strftime("%Y-%m-%d", time.localtime(time.time()))
        # last_date = utils.get_date_before(1)
        # last_2_date = utils.get_date_before(2)
        # latest_3_days_crawled_href = self.db_obj.get_data(self.db_name,
        #                                                   self.col_name,
        #                                                   query={"Date": {"$regex": today_date}},
        #                                                   keys=["Url"])["Url"].to_list()
        # latest_3_days_crawled_href.extend(self.db_obj.get_data(self.db_name,
        #                                                        self.col_name,
        #                                                        query={"Date": {"$regex": last_date}},
        #                                                        keys=["Url"])["Url"].to_list())
        # latest_3_days_crawled_href.extend(self.db_obj.get_data(self.db_name,
        #                                                        self.col_name,
        #                                                        query={"Date": {"$regex": last_2_date}},
        #                                                        keys=["Url"])["Url"].to_list())
        crawled_urls = self.db_obj.get_data(self.db_name,
                                            self.col_name,
                                            keys=["Url"])["Url"].to_list()
        while True:
            # Poll this URL at a fixed interval
            bs = utils.html_parser(url)
            for li in bs.find_all("li", attrs={"class": ["newslist"]}):
                a = li.find_all("h2")[0].find("a")
                if a["href"] not in crawled_urls:  # latest_3_days_crawled_href
                    result = self.get_url_info(a["href"])
                    while not result:
                        self.terminated_amount += 1
                        if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                            # Save URLs that could never be crawled
                            with open(config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH, "a+") as file:
                                file.write("{}\n".format(a["href"]))
                            logging.info("rejected by remote server longer than {} minutes, "
                                         "and the failed url has been written in path {}"
                                         .format(config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                                                 config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                            break
                        logging.info("rejected by remote server, request {} again after "
                                     "{} seconds...".format(a["href"], 60 * self.terminated_amount))
                        time.sleep(60 * self.terminated_amount)
                        result = self.get_url_info(a["href"])
                    if not result:
                        # Case: crawling failed
                        logging.info("[FAILED] {} {}".format(a["title"], a["href"]))
                    else:
                        # Case: a result was returned but the article text is empty
                        date, article = result
                        while article == "" and self.is_article_prob >= .1:
                            self.is_article_prob -= .1
                            result = self.get_url_info(a["href"])
                            while not result:
                                self.terminated_amount += 1
                                if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS:
                                    # Save URLs that could never be crawled
                                    with open(config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH, "a+") as file:
                                        file.write("{}\n".format(a["href"]))
                                    logging.info("rejected by remote server longer than {} minutes, "
                                                 "and the failed url has been written in path {}"
                                                 .format(config.CNSTOCK_MAX_REJECTED_AMOUNTS,
                                                         config.RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH))
                                    break
                                logging.info("rejected by remote server, request {} again after "
                                             "{} seconds...".format(a["href"], 60 * self.terminated_amount))
                                time.sleep(60 * self.terminated_amount)
                                result = self.get_url_info(a["href"])
                            date, article = result
                        self.is_article_prob = .5
                        if article != "":
                            data = {"Date": date,
                                    "Category": category_chn,
                                    "Url": a["href"],
                                    "Title": a["title"],
                                    "Article": article}
                            # self.col.insert_one(data)
                            self.db_obj.insert_data(self.db_name, self.col_name, data)
                            logging.info("[SUCCESS] {} {} {}".format(date, a["title"], a["href"]))
                            crawled_urls.append(a["href"])
            logging.info("sleep {} secs then request {} again ... ".format(interval, url))
            time.sleep(interval)
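
A hedged usage sketch for the CnStockSpyder defined above; the column URL and category are placeholders (the real values live in the project's config), and running get_historical_news before get_realtime_news is an assumption based on how both methods de-duplicate against the same collection.

if __name__ == "__main__":
    # config.DATABASE_NAME and config.COLLECTION_NAME_CNSTOCK appear earlier in this listing
    # and are reused here as an assumption; the spider writes into that database and collection
    cnstock_spyder = CnStockSpyder(config.DATABASE_NAME, config.COLLECTION_NAME_CNSTOCK)
    column_url = "..."    # a cnstock column page to crawl (placeholder, see config)
    category = "公司聚焦"  # one of the categories listed in the get_historical_news docstring
    cnstock_spyder.get_historical_news(column_url, category_chn=category)  # backfill first
    cnstock_spyder.get_realtime_news(column_url, category_chn=category, interval=60)  # then poll
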
class StockInfoSpyder(Spyder):
    def __init__(self, database_name, collection_name):
        super(StockInfoSpyder, self).__init__()
        self.db_obj = Database()
        self.col_basic_info = self.db_obj.get_collection(
            database_name, collection_name)
        self.database_name = database_name
        self.collection_name = collection_name

    def get_stock_code_info(self):
        stock_info_df = ak.stock_info_a_code_name()  # codes and names of all A-share stocks
        stock_symbol_code = ak.stock_zh_a_spot().get(
            ["symbol", "code"])  # symbols and codes of all A-share stocks
        for _id in range(stock_info_df.shape[0]):
            _symbol = stock_symbol_code[stock_symbol_code.code == stock_info_df
                                        .iloc[_id].code].symbol.values
            if len(_symbol) != 0:
                _dict = {"symbol": _symbol[0]}
                _dict.update(stock_info_df.iloc[_id].to_dict())
                self.col_basic_info.insert_one(_dict)

    def get_historical_news(self, start_date=None, end_date=None, freq="day"):
        if end_date is None:
            end_date = datetime.datetime.now().strftime("%Y%m%d")
        stock_symbol_list = self.col_basic_info.distinct("symbol")
        if len(stock_symbol_list) == 0:
            self.get_stock_code_info()
            stock_symbol_list = self.col_basic_info.distinct("symbol")
        if freq == "day":
            if os.path.exists(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH):
                with open(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH,
                          "r") as file:
                    start_stock_code = file.read()
                logging.info(
                    "read {} to get start code number is {} ... ".format(
                        config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH,
                        start_stock_code))
            else:
                start_stock_code = 0
            for symbol in stock_symbol_list:
                if int(symbol[2:]) >= int(start_stock_code):
                    try:
                        symbol_start_date = start_date
                        if symbol_start_date is None:
                            # If this symbol already has price history in the database, fetch
                            # prices from the day after its latest stored date up to now;
                            # otherwise fetch everything from the configured default start
                            # date (2015-01-01) up to now
                            symbol_date_list = self.db_obj.get_data(
                                self.database_name, symbol,
                                keys=["date"])["date"].to_list()
                            if len(symbol_date_list) == 0:
                                symbol_start_date = config.STOCK_PRICE_REQUEST_DEFAULT_DATE
                            else:
                                tmp_date = str(
                                    max(symbol_date_list)).split(" ")[0]
                                tmp_date_dt = datetime.datetime.strptime(
                                    tmp_date, "%Y-%m-%d").date()
                                offset = datetime.timedelta(days=1)
                                symbol_start_date = (tmp_date_dt +
                                                     offset).strftime('%Y%m%d')
                        stock_zh_a_daily_hfq_df = ak.stock_zh_a_daily(
                            symbol=symbol,
                            start_date=symbol_start_date,
                            end_date=end_date,
                            adjust="hfq")
                        stock_zh_a_daily_hfq_df.insert(
                            0, 'date', stock_zh_a_daily_hfq_df.index.tolist())
                        stock_zh_a_daily_hfq_df.index = range(
                            len(stock_zh_a_daily_hfq_df))
                        _col = self.db_obj.get_collection(
                            self.database_name, symbol)
                        for _id in range(stock_zh_a_daily_hfq_df.shape[0]):
                            _col.insert_one(
                                stock_zh_a_daily_hfq_df.iloc[_id].to_dict())
                        logging.info("{} finished saving ... ".format(symbol))
                    except Exception:
                        with open(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH,
                                  "w") as file:
                            file.write(symbol[2:])
        elif freq == "week":
            pass
        elif freq == "month":
            pass
        elif freq == "5mins":
            pass
        elif freq == "15mins":
            pass
        elif freq == "30mins":
            pass
        elif freq == "60mins":
            pass
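
Only the freq="day" branch of StockInfoSpyder.get_historical_news is implemented above (the week, month and intraday branches are still pass), so a usage sketch reduces to the following; the "stock" / "basic_info" database and collection names are assumptions used only for illustration.

if __name__ == "__main__":
    stock_spyder = StockInfoSpyder("stock", "basic_info")  # assumed database/collection names
    # get_historical_news fills the basic-info collection itself if it is empty, then stores
    # hfq-adjusted daily bars per symbol, resuming from the latest date already in the database
    stock_spyder.get_historical_news(freq="day")
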