Example #1
class StockInfoSpyder(Spyder):

    def __init__(self, database_name, collection_name):
        super(StockInfoSpyder, self).__init__()
        self.db_obj = Database()
        self.col_basic_info = self.db_obj.get_collection(database_name, collection_name)
        self.database_name = database_name
        self.collection_name = collection_name

    def get_stock_code_info(self):
        stock_info_df = ak.stock_info_a_code_name()  # fetch the code and name of every A-share stock
        stock_symbol_code = ak.stock_zh_a_spot().get(["symbol", "code"])  # fetch the symbol and code of all A-share stocks
        for _id in range(stock_info_df.shape[0]):
            _symbol = stock_symbol_code[stock_symbol_code.code == stock_info_df.iloc[_id].code].symbol.values
            if len(_symbol) != 0:
                _dict = {"symbol": _symbol[0]}
                _dict.update(stock_info_df.iloc[_id].to_dict())
                self.col_basic_info.insert_one(_dict)
        return stock_info_df

    def get_historical_news(self, start_date, end_date, freq="day"):
        stock_symbol_list = self.col_basic_info.distinct("symbol")
        if len(stock_symbol_list) == 0:
            self.get_stock_code_info()
            stock_symbol_list = self.col_basic_info.distinct("symbol")  # re-read the symbols just inserted (the method returns a DataFrame, not a symbol list)
        if freq == "day":
            if os.path.exists(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH):
                with open(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH, "r") as file:
                    start_stock_code = file.read()
                logging.info("read {} to get start code number is {} ... "
                             .format(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH, start_stock_code))
            else:
                start_stock_code = 0
            for symbol in stock_symbol_list:
                if int(symbol[2:]) >= int(start_stock_code):
                    try:
                        stock_zh_a_daily_hfq_df = ak.stock_zh_a_daily(symbol=symbol,
                                                                      start_date=start_date,
                                                                      end_date=end_date,
                                                                      adjust="hfq")
                        stock_zh_a_daily_hfq_df.insert(0, 'date', stock_zh_a_daily_hfq_df.index.tolist())
                        stock_zh_a_daily_hfq_df.index = range(len(stock_zh_a_daily_hfq_df))
                        _col = self.db_obj.get_collection(self.database_name, symbol)
                        for _id in range(stock_zh_a_daily_hfq_df.shape[0]):
                            _col.insert_one(stock_zh_a_daily_hfq_df.iloc[_id].to_dict())
                        logging.info("{} finished saving ... ".format(symbol))
                    except Exception:
                        with open(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH, "w") as file:
                            file.write(symbol[2:])
        elif freq == "week":
            pass
        elif freq == "month":
            pass
        elif freq == "5mins":
            pass
        elif freq == "15mins":
            pass
        elif freq == "30mins":
            pass
        elif freq == "60mins":
            pass
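
A minimal usage sketch, assuming the Database wrapper points at a running MongoDB; the "stock"/"basic_info" names come from the later examples, and dates follow the %Y%m%d convention used above:

spyder = StockInfoSpyder("stock", "basic_info")
spyder.get_stock_code_info()  # populate the basic_info collection first
spyder.get_historical_news("20150101", "20201231", freq="day")

Example #2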
class DeNull(object):
    def __init__(self, database_name, collection_name):
        self.database = Database()
        self.database_name = database_name
        self.collection_name = collection_name
        self.delete_num = 0

    def run(self):
        collection = self.database.get_collection(self.database_name,
                                                  self.collection_name)
        for row in self.database.get_collection(self.database_name,
                                                self.collection_name).find():
            for _key in list(row.keys()):
                if _key != "RelatedStockCodes" and row[_key] == "":
                    collection.delete_one({'_id': row["_id"]})
                    self.delete_num += 1
                    break
        logging.info(
            "there are {} news items containing NULL values in the {} collection ... ".
            format(self.delete_num, self.collection_name))
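
A hedged usage sketch; the database and collection names below are hypothetical:

denull = DeNull("finnews", "cnstock")
denull.run()  # drops every document with an empty field other than RelatedStockCodes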
Example #3
class GenStockNewsDB(object):
    def __init__(self):
        self.database = Database()

    def get_all_news_about_specific_stock(self, database_name,
                                          collection_name):
        # Check the keys of collection_name to see whether it contains RelatedStockCodes;
        # if not, the stock codes mentioned in each article have not yet been extracted
        # into a separate field
        _keys_list = list(
            next(
                self.database.get_collection(database_name,
                                             collection_name).find()).keys())
        if "RelatedStockCodes" not in _keys_list:
            tokenization = Tokenization(import_module="jieba",
                                        user_dict="./Leorio/financedict.txt")
            tokenization.update_news_database_rows(database_name,
                                                   collection_name)
        # create one collection named after each stock_code
        stock_code_list = self.database.get_data("stock",
                                                 "basic_info",
                                                 keys=["code"
                                                       ])["code"].to_list()
        for code in stock_code_list:
            _collection = self.database.get_collection(
                config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, code)
            _tmp_num_stat = 0
            for row in self.database.get_collection(
                    database_name, collection_name).find():  # cursor (iterator)
                if code in row["RelatedStockCodes"].split(" "):
                    _collection.insert_one({
                        "Date": row["Date"],
                        "Url": row["Url"],
                        "Title": row["Title"],
                        "Article": row["Article"],
                        "OriDB": database_name,
                        "OriCOL": collection_name
                    })
                    _tmp_num_stat += 1
            logging.info(
                "there are {} news items mentioning {} in the {} collection ... ".format(
                    _tmp_num_stat, code, collection_name))
Example #4
class Deduplication(object):
    def __init__(self, database_name, collection_name):
        self.database = Database()
        self.database_name = database_name
        self.collection_name = collection_name
        self.delete_num = 0

    def run(self):
        date_list = self.database.get_data(self.database_name,
                                           self.collection_name,
                                           keys=["Date"])["Date"].tolist()
        collection = self.database.get_collection(self.database_name,
                                                  self.collection_name)
        date_list.sort()  # ascending order
        # start_date, end_date = date_list[1].split(" ")[0], date_list[-1].split(" ")[0]
        start_date, end_date = min(date_list).split(" ")[0], max(
            date_list).split(" ")[0]
        for _date in utils.get_date_list_from_range(start_date, end_date):
            # fetch the records for this specific date and deduplicate them by Url
            # logging.info(_date)
            try:
                data_df = self.database.get_data(
                    self.database_name,
                    self.collection_name,
                    query={"Date": {
                        "$regex": _date
                    }})
            except Exception:
                continue
            if data_df is None:
                continue
            data_df_drop_duplicate = data_df.drop_duplicates(["Url"])
            for _id in list(
                    set(data_df["_id"]) - set(data_df_drop_duplicate["_id"])):
                collection.delete_one({'_id': _id})
                self.delete_num += 1
            # logging.info("{} finished ... ".format(_date))
        logging.info(
            "DB:{} - COL:{} originally held {} records; {} duplicates have now been deleted ... "
            .format(self.database_name, self.collection_name,
                    str(len(date_list)), self.delete_num))
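
A hedged usage sketch; the database and collection names below are hypothetical:

dedup = Deduplication("finnews", "cnstock")
dedup.run()  # removes same-Url documents one calendar day at a time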
Example #5
class Tokenization(object):
    def __init__(self,
                 import_module="jieba",
                 user_dict=None,
                 chn_stop_words_dir=None):
        #self.database = Database().conn[config.DATABASE_NAME]  #.get_collection(config.COLLECTION_NAME_CNSTOCK)
        self.database = Database()
        self.import_module = import_module
        self.user_dict = user_dict
        if self.user_dict:
            self.update_user_dict(self.user_dict)
        if chn_stop_words_dir:
            self.stop_words_list = utils.get_chn_stop_words(chn_stop_words_dir)
        else:
            self.stop_words_list = list()

    def update_user_dict(self, old_user_dict_dir, new_user_dict_dir=None):
        # add missing (or new) stock names, financial neologisms, etc. to the finance dictionary
        word_list = []
        with open(old_user_dict_dir, "r", encoding="utf-8") as file:
            for row in file:
                word_list.append(row.split("\n")[0])
        name_code_df = self.database.get_data(
            config.STOCK_DATABASE_NAME,
            config.COLLECTION_NAME_STOCK_BASIC_INFO,
            keys=["name", "code"])
        new_words_list = list(set(name_code_df["name"].tolist()))
        for word in new_words_list:
            if word not in word_list:
                word_list.append(word)
        new_user_dict_dir = old_user_dict_dir if not new_user_dict_dir else new_user_dict_dir
        with open(new_user_dict_dir, "w", encoding="utf-8") as file:
            for word in word_list:
                file.write(word + "\n")

    def cut_words(self, text):
        outstr = list()
        sentence_seged = None
        if self.import_module == "jieba":
            if self.user_dict:
                jieba.load_userdict(self.user_dict)
            sentence_seged = list(jieba.cut(text))
        elif self.import_module == "pkuseg":
            seg = pkuseg.pkuseg(user_dict=self.user_dict)  # load the user-defined dictionary
            sentence_seged = seg.cut(text)  # segment the text
        if sentence_seged:
            for word in sentence_seged:
                if word not in self.stop_words_list and word != "\t" and word != " ":
                    outstr.append(word)
            return outstr
        else:
            return False

    def find_relevant_stock_codes_in_article(self, article,
                                             stock_name_code_dict):
        stock_codes_list = list()
        cut_words_list = self.cut_words(article)
        if cut_words_list:
            for word in cut_words_list:
                try:
                    stock_codes_list.append(stock_name_code_dict[word])
                except Exception:
                    pass
        return list(set(stock_codes_list))

    def update_news_database_rows(self, database_name, collection_name):
        name_code_df = self.database.get_data(
            config.STOCK_DATABASE_NAME,
            config.COLLECTION_NAME_STOCK_BASIC_INFO,
            keys=["name", "code"])
        name_code_dict = dict(name_code_df.values)
        data = self.database.get_collection(database_name,
                                            collection_name).find()
        for row in data:
            # if row["Date"] > "2019-05-20 00:00:00":
            related_stock_codes_list = self.find_relevant_stock_codes_in_article(
                row["Article"], name_code_dict)
            self.database.update_row(
                database_name, collection_name, {"_id": row["_id"]},
                {"RelatedStockCodes": " ".join(related_stock_codes_list)})
            logging.info(
                "[{} -> {} -> {}] updated RelatedStockCodes key value ... ".
                format(database_name, collection_name, row["Date"]))
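
A short sketch of how the tokenizer is typically driven; the dictionary path mirrors the one used elsewhere in these examples, and the single name-code mapping shown is illustrative:

tokenization = Tokenization(import_module="jieba", user_dict="./Leorio/financedict.txt")
tokens = tokenization.cut_words("...")  # token list, or False when nothing survives filtering
codes = tokenization.find_relevant_stock_codes_in_article("...", {"浦发银行": "600000"})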
Example #6
class StockInfoSpyder(Spyder):
    def __init__(self, database_name, collection_name):
        super(StockInfoSpyder, self).__init__()
        self.db_obj = Database()
        self.col_basic_info = self.db_obj.get_collection(
            database_name, collection_name)
        self.database_name = database_name
        self.collection_name = collection_name
        self.start_program_date = datetime.datetime.now().strftime("%Y%m%d")
        self.redis_client = redis.StrictRedis(
            host="localhost",
            port=6379,
            db=config.REDIS_CLIENT_FOR_CACHING_STOCK_INFO_DB_ID)
        self.redis_client.set("today_date",
                              datetime.datetime.now().strftime("%Y-%m-%d"))

    def get_stock_code_info(self):
        # TODO: needs refreshing every six months
        stock_info_df = ak.stock_info_a_code_name()  # fetch the code and name of every A-share stock
        stock_symbol_code = ak.stock_zh_a_spot().get(
            ["symbol", "code"])  # fetch the symbol and code of all A-share stocks
        for _id in range(stock_info_df.shape[0]):
            _symbol = stock_symbol_code[stock_symbol_code.code == stock_info_df
                                        .iloc[_id].code].symbol.values
            if len(_symbol) != 0:
                _dict = {"symbol": _symbol[0]}
                _dict.update(stock_info_df.iloc[_id].to_dict())
                self.col_basic_info.insert_one(_dict)

    def get_historical_news(self, start_date=None, end_date=None, freq="day"):
        if end_date is None:
            end_date = datetime.datetime.now().strftime("%Y%m%d")
        stock_symbol_list = self.col_basic_info.distinct("symbol")
        if len(stock_symbol_list) == 0:
            self.get_stock_code_info()
            stock_symbol_list = self.col_basic_info.distinct("symbol")
        if freq == "day":
            start_stock_code = 0 if self.redis_client.get(
                "start_stock_code") is None else int(
                    self.redis_client.get("start_stock_code").decode())
            for symbol in stock_symbol_list:
                if int(symbol[2:]) > start_stock_code:
                    if start_date is None:
                        # if this symbol already has historical data, request prices from the day
                        # after the latest date in the database up to now; otherwise request all
                        # prices since January 1, 2015 (the configured default date)
                        _latest_date = self.redis_client.get(symbol)
                        if _latest_date is None:
                            symbol_start_date = config.STOCK_PRICE_REQUEST_DEFAULT_DATE
                        else:
                            tmp_date_dt = datetime.datetime.strptime(
                                _latest_date.decode(), "%Y-%m-%d").date()
                            offset = datetime.timedelta(days=1)
                            symbol_start_date = (tmp_date_dt +
                                                 offset).strftime('%Y%m%d')

                    else:
                        symbol_start_date = start_date  # honor the caller-supplied start date (otherwise undefined below)
                    if symbol_start_date < end_date:
                        stock_zh_a_daily_hfq_df = ak.stock_zh_a_daily(
                            symbol=symbol,
                            start_date=symbol_start_date,
                            end_date=end_date,
                            adjust="qfq")
                        stock_zh_a_daily_hfq_df.insert(
                            0, 'date', stock_zh_a_daily_hfq_df.index.tolist())
                        stock_zh_a_daily_hfq_df.index = range(
                            len(stock_zh_a_daily_hfq_df))
                        _col = self.db_obj.get_collection(
                            self.database_name, symbol)
                        for _id in range(stock_zh_a_daily_hfq_df.shape[0]):
                            _tmp_dict = stock_zh_a_daily_hfq_df.iloc[
                                _id].to_dict()
                            _tmp_dict.pop("outstanding_share")
                            _tmp_dict.pop("turnover")
                            _col.insert_one(_tmp_dict)
                            self.redis_client.set(
                                symbol,
                                str(_tmp_dict["date"]).split(" ")[0])

                        logging.info(
                            "{} finished saving from {} to {} ... ".format(
                                symbol, symbol_start_date, end_date))
                self.redis_client.set("start_stock_code", int(symbol[2:]))
            self.redis_client.set("start_stock_code", 0)
        elif freq == "week":
            pass
        elif freq == "month":
            pass
        elif freq == "5mins":
            pass
        elif freq == "15mins":
            pass
        elif freq == "30mins":
            pass
        elif freq == "60mins":
            pass

    def get_realtime_news(self, freq="day"):
        while True:
            if_updated = input(
                "Has the stock price dataset been updated today? (Y/N) \n")
            if if_updated == "Y":
                self.redis_client.set("is_today_updated", "1")
                break
            elif if_updated == "N":
                self.redis_client.set("is_today_updated", "")
                break
        self.get_historical_news()  # backfill every stock's data up to the latest date
        while True:
            if freq == "day":
                time_now = datetime.datetime.now().strftime(
                    "%Y-%m-%d %H:%M:%S")
                if time_now.split(" ")[0] != self.redis_client.get(
                        "today_date").decode():
                    self.redis_client.set("today_date", time_now.split(" ")[0])
                    self.redis_client.set("is_today_updated",
                                          "")  # 过了凌晨,该参数设置回空值,表示今天未进行数据更新
                if not bool(
                        self.redis_client.get("is_today_updated").decode()):
                    update_time = "{} {}".format(
                        time_now.split(" ")[0], "15:30:00")
                    if time_now >= update_time:
                        stock_zh_a_spot_df = ak.stock_zh_a_spot()  # download today's daily quotes
                        for _id, sym in enumerate(
                                stock_zh_a_spot_df["symbol"]):
                            _col = self.db_obj.get_collection(
                                self.database_name, sym)
                            _tmp_dict = {}
                            _tmp_dict.update({
                                "date":
                                Timestamp("{} 00:00:00".format(
                                    time_now.split(" ")[0]))
                            })
                            _tmp_dict.update(
                                {"open": stock_zh_a_spot_df.iloc[_id].open})
                            _tmp_dict.update(
                                {"high": stock_zh_a_spot_df.iloc[_id].high})
                            _tmp_dict.update(
                                {"low": stock_zh_a_spot_df.iloc[_id].low})
                            _tmp_dict.update(
                                {"close": stock_zh_a_spot_df.iloc[_id].trade})
                            _tmp_dict.update({
                                "volume":
                                stock_zh_a_spot_df.iloc[_id].volume
                            })
                            _col.insert_one(_tmp_dict)
                            self.redis_client.set(sym, time_now.split(" ")[0])
                            logging.info(
                                "finished updating {} price data of {} ... ".
                                format(sym,
                                       time_now.split(" ")[0]))
                        self.redis_client.set("is_today_updated", "1")
Example #7
class GenStockNewsDB(object):
    def __init__(self):
        self.database = Database()
        # trading-calendar dates from 1990-12-19 through 2020-12-31
        self.trade_date = ak.tool_trade_date_hist_sina()["trade_date"].tolist()
        self.label_range = {
            3: "3DaysLabel",
            5: "5DaysLabel",
            10: "10DaysLabel",
            15: "15DaysLabel",
            30: "30DaysLabel",
            60: "60DaysLabel"
        }

    def get_all_news_about_specific_stock(self, database_name,
                                          collection_name):
        # Check the keys of collection_name to see whether it contains RelatedStockCodes;
        # if not, the stock codes mentioned in each article have not yet been extracted
        # into a separate field
        _keys_list = list(
            next(
                self.database.get_collection(database_name,
                                             collection_name).find()).keys())
        if "RelatedStockCodes" not in _keys_list:
            tokenization = Tokenization(import_module="jieba",
                                        user_dict="./Leorio/financedict.txt")
            tokenization.update_news_database_rows(database_name,
                                                   collection_name)
        # create one collection named after each stock symbol
        stock_symbol_list = self.database.get_data(
            "stock", "basic_info", keys=["symbol"])["symbol"].to_list()
        for symbol in stock_symbol_list:
            _collection = self.database.get_collection(
                config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, symbol)
            _tmp_num_stat = 0
            for row in self.database.get_collection(
                    database_name, collection_name).find():  # cursor (iterator)
                if symbol[2:] in row["RelatedStockCodes"].split(" "):
                    # label the news with the price move n days after publication
                    _tmp_dict = {}
                    for label_days, key_name in self.label_range.items():
                        _tmp_res = self._label_news(
                            datetime.datetime.strptime(
                                row["Date"].split(" ")[0], "%Y-%m-%d"), symbol,
                            label_days)
                        if _tmp_res:
                            _tmp_dict.update({key_name: _tmp_res})
                    _data = {
                        "Date": row["Date"],
                        "Url": row["Url"],
                        "Title": row["Title"],
                        "Article": row["Article"],
                        "OriDB": database_name,
                        "OriCOL": collection_name
                    }
                    _data.update(_tmp_dict)
                    _collection.insert_one(_data)
                    _tmp_num_stat += 1
            logging.info(
                "there are {} news mentioned {} in {} collection ... ".format(
                    _tmp_num_stat, symbol, collection_name))

    def _label_news(self, date, symbol, n_days):
        """
        :param date: datetime.datetime, the news publication date, date only (no time of day), e.g. datetime.datetime(2015, 1, 5, 0, 0)
        :param symbol: str, the stock ticker, e.g. sh600000
        :param n_days: int, how many days after publication the price is checked to set the label; if the close has risen n_days after the news, the news is considered positive
        """
        # work out the concrete date n_days after the news was published
        this_date_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                symbol,
                                                query={"date": date})
        # Case: the news was published on a non-trading day, so that date has no price data;
        # search backwards instead, e.g. for news published on Saturday 2020-12-12,
        # take the close of 2020-12-11 as the price at publication time
        tmp_date = date
        if this_date_data is None:
            i = 1
            while this_date_data is None and i <= 10:
                tmp_date = date - datetime.timedelta(days=i)  # step back i days, recomputed from date each pass
                # only query the database on trading days; if this_date_data is still None
                # afterwards, the database has no data for that trading day
                if tmp_date.strftime("%Y-%m-%d") in self.trade_date:
                    this_date_data = self.database.get_data(
                        config.STOCK_DATABASE_NAME,
                        symbol,
                        query={"date": tmp_date})
                i += 1
        try:
            close_price_this_date = this_date_data["close"][0]
        except Exception:
            close_price_this_date = None
        # Case: n_days after publication falls on a non-trading day (or no data was collected),
        # so search forwards, e.g. for news published 2020-12-08, 5 days later is Sunday
        # 2020-12-13, so take Monday 2020-12-14's close as the n-days-later price
        new_date = date + datetime.timedelta(days=n_days)
        n_days_later_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                   symbol,
                                                   query={"date": new_date})
        if n_days_later_data is None:
            i = 1
            while n_days_later_data is None and i <= 10:
                new_date = date + datetime.timedelta(days=n_days + i)
                if new_date.strftime("%Y-%m-%d") in self.trade_date:
                    n_days_later_data = self.database.get_data(
                        config.STOCK_DATABASE_NAME,
                        symbol,
                        query={"date": new_date})
                i += 1
        try:
            close_price_n_days_later = n_days_later_data["close"][0]
        except Exception:
            close_price_n_days_later = None
        # Labeling rules:
        # (1) n_days <= 10: a rise (fall) of more than 3% n_days after publication marks the news positive (negative); within ±3% it is neutral
        # (2) 10 < n_days <= 15: the threshold is 5%
        # (3) 15 < n_days <= 30: the threshold is 10%
        # (4) 30 < n_days <= 60: the threshold is 15%
        # Note: "neutral" means the market absorbed the news quickly, with no lasting impact
        param = 0.01
        if n_days <= 10:
            param = 0.03
        elif 10 < n_days <= 15:
            param = 0.05
        elif 15 < n_days <= 30:
            param = 0.10
        elif 30 < n_days <= 60:
            param = 0.15
        if close_price_this_date is not None and close_price_n_days_later is not None:
            if (close_price_n_days_later -
                    close_price_this_date) / close_price_this_date > param:
                return "利好"
            elif (close_price_n_days_later -
                  close_price_this_date) / close_price_this_date < -param:
                return "利空"
            else:
                return "中性"
        else:
            return False
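
The threshold rule in _label_news can be restated as a standalone helper; a sketch (function name hypothetical) reproducing the same price-change-to-label mapping for the n_days values used in label_range:

def label_from_prices(close_then, close_later, n_days):
    # thresholds as in _label_news: ±3% up to 10 days, ±5% up to 15,
    # ±10% up to 30, ±15% up to 60
    if n_days <= 10:
        param = 0.03
    elif n_days <= 15:
        param = 0.05
    elif n_days <= 30:
        param = 0.10
    else:
        param = 0.15
    change = (close_later - close_then) / close_then
    if change > param:
        return "利好"  # positive
    elif change < -param:
        return "利空"  # negative
    return "中性"  # neutral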
Example #8
class TopicModelling(object):
    def __init__(self):
        self.tokenization = Tokenization(
            import_module="jieba",
            user_dict=config.USER_DEFINED_DICT_PATH,
            chn_stop_words_dir=config.CHN_STOP_WORDS_PATH)
        self.database = Database()
        self.classifier = Classifier()

    def create_dictionary(self,
                          raw_documents_list,
                          save_path=None,
                          is_saved=False):
        """
        将文中每个词汇关联唯一的ID,因此需要定义词汇表
        :param: raw_documents_list, 原始语料列表,每个元素即文本,如["洗尽铅华...", "风雨赶路人...", ...]
        :param: savepath, corpora.Dictionary对象保存路径
        """
        documents_token_list = []
        for doc in raw_documents_list:
            documents_token_list.append(self.tokenization.cut_words(doc))
        _dict = corpora.Dictionary(documents_token_list)
        # find the tokens that occur only once
        once_items = [
            _dict[tokenid] for tokenid, docfreq in _dict.dfs.items()
            if docfreq == 1
        ]
        # drop the once-only tokens from every document in documents_token_list
        for _id, token_list in enumerate(documents_token_list):
            documents_token_list[_id] = list(
                filter(lambda token: token not in once_items, token_list))
        # edge case: if every token of a document occurred only once, its token
        # list is now empty, so remove the document entirely
        documents_token_list = [
            token_list for token_list in documents_token_list
            if (len(token_list) != 0)
        ]
        # find the ids of the once-only tokens
        once_ids = [
            tokenid for tokenid, docfreq in _dict.dfs.items() if docfreq == 1
        ]
        # remove the once-only tokens from the dictionary
        _dict.filter_tokens(once_ids)
        # close the gaps the removal left in the id sequence
        _dict.compactify()
        if is_saved and save_path:
            _dict.save(save_path)
            logging.info(
                "new generated dictionary saved in path -> {} ...".format(
                    save_path))

        return _dict, documents_token_list

    def renew_dictionary(self,
                         old_dict_path,
                         new_raw_documents_list,
                         new_dict_path=None,
                         is_saved=False):
        documents_token_list = []
        for doc in new_raw_documents_list:
            documents_token_list.append(self.tokenization.cut_words(doc))
        _dict = corpora.Dictionary.load(old_dict_path)
        _dict.add_documents(documents_token_list)
        if new_dict_path:
            old_dict_path = new_dict_path
        if is_saved:
            _dict.save(old_dict_path)
            logging.info(
                "updated dictionary by another raw documents serialized in {} ... "
                .format(old_dict_path))

        return _dict, documents_token_list

    def create_bag_of_word_representation(self,
                                          raw_documents_list,
                                          old_dict_path=None,
                                          new_dict_path=None,
                                          bow_vector_save_path=None,
                                          is_saved_dict=False):
        if old_dict_path:
            # if an old corpus dictionary exists, update it in place with previously unseen words
            corpora_dictionary, documents_token_list = self.renew_dictionary(
                old_dict_path, raw_documents_list, new_dict_path=new_dict_path)
        else:
            # otherwise build a dictionary from scratch
            start_time = time.time()
            corpora_dictionary, documents_token_list = self.create_dictionary(
                raw_documents_list,
                save_path=new_dict_path,
                is_saved=is_saved_dict)
            end_time = time.time()
            logging.info(
                "there are {} mins spent to create a new dictionary ... ".
                format((end_time - start_time) / 60))
        # generate each document's bag-of-words vector against the (updated) dictionary
        start_time = time.time()
        bow_vector = [
            corpora_dictionary.doc2bow(doc_token)
            for doc_token in documents_token_list
        ]
        end_time = time.time()
        logging.info(
            "there are {} mins spent to calculate bow-vector ... ".format(
                (end_time - start_time) / 60))
        if bow_vector_save_path:
            corpora.MmCorpus.serialize(bow_vector_save_path, bow_vector)

        return documents_token_list, corpora_dictionary, bow_vector

    @staticmethod
    def transform_vectorized_corpus(corpora_dictionary,
                                    bow_vector,
                                    model_type="lda",
                                    model_save_path=None):
        # use this function when no model has been saved yet and training from scratch
        model_vector = None
        if model_type == "lsi":
            # the LSI (Latent Semantic Indexing) model maps text from a bag-of-words (or better,
            # term-frequency) vector into a lower-dimensional latent space; for real-world
            # corpora a target dimensionality of 200-500 is considered the "gold standard"
            model_tfidf = models.TfidfModel(bow_vector)
            # model_tfidf.save("model_tfidf.tfidf")
            tfidf_vector = model_tfidf[bow_vector]
            model = models.LsiModel(tfidf_vector,
                                    id2word=corpora_dictionary,
                                    num_topics=config.TOPIC_NUMBER)  # initialize the model
            model_vector = model[tfidf_vector]
            if model_save_path:
                model.save(model_save_path)
        elif model_type == "lda":
            model = models.LdaModel(bow_vector,
                                    id2word=corpora_dictionary,
                                    num_topics=config.TOPIC_NUMBER)  # initialize the model
            model_vector = model[bow_vector]
            if model_save_path:
                model.save(model_save_path)
        elif model_type == "tfidf":
            model = models.TfidfModel(bow_vector)  # initialize the model
            # model = models.TfidfModel.load("model_tfidf.tfidf")
            model_vector = model[bow_vector]  # transform the whole corpus
            if model_save_path:
                model.save(model_save_path)

        return model_vector

    def classify_stock_news(self,
                            unseen_raw_document,
                            database_name,
                            collection_name,
                            label_name="60DaysLabel",
                            topic_model_type="lda",
                            classifier_model="svm",
                            ori_dict_path=None,
                            bowvec_save_path=None,
                            is_saved_bow_vector=False):
        historical_raw_documents_list = []
        Y = []
        for row in self.database.get_collection(database_name,
                                                collection_name).find():
            if label_name in row.keys():
                if row[label_name] != "":
                    historical_raw_documents_list.append(row["Article"])
                    Y.append(row[label_name])
        logging.info(
            "fetch symbol '{}' historical news with label '{}' from [DB:'{}' - COL:'{}'] ... "
            .format(collection_name, label_name, database_name,
                    collection_name))

        le = preprocessing.LabelEncoder()
        Y = le.fit_transform(Y)
        logging.info(
            "encode historical label list by sklearn preprocessing for training ... "
        )
        label_name_list = le.classes_  # ['中性' '利好' '利空'] -> [0, 1, 2]

        # Build the dictionary from the historical news database and compute each historical
        # article's bag-of-words vector; if a dictionary built from the historical database
        # already exists, load it into memory and update it with the unseen news tokens
        if not os.path.exists(ori_dict_path):
            if not os.path.exists(bowvec_save_path):
                _, _, historical_bow_vec = self.create_bag_of_word_representation(
                    historical_raw_documents_list,
                    new_dict_path=ori_dict_path,
                    bow_vector_save_path=bowvec_save_path,
                    is_saved_dict=True)
                logging.info(
                    "create dictionary of historical news, and serialized in path -> {} ... "
                    .format(ori_dict_path))
                logging.info(
                    "create bow-vector of historical news, and serialized in path -> {} ... "
                    .format(bowvec_save_path))
            else:
                _, _, _ = self.create_bag_of_word_representation(
                    historical_raw_documents_list,
                    new_dict_path=ori_dict_path,
                    is_saved_dict=True)
                logging.info(
                    "create dictionary of historical news, and serialized in path -> {} ... "
                    .format(ori_dict_path))
        else:
            if not os.path.exists(bowvec_save_path):
                _, _, historical_bow_vec = self.create_bag_of_word_representation(
                    historical_raw_documents_list,
                    new_dict_path=ori_dict_path,
                    bow_vector_save_path=bowvec_save_path,
                    is_saved_dict=True)
                logging.info(
                    "historical news dictionary existed, which saved in path -> {}, but not the historical bow-vector"
                    " ... ".format(ori_dict_path))
            else:
                historical_bow_vec_mmcorpus = corpora.MmCorpus(
                    bowvec_save_path
                )  # type -> <gensim.corpora.mmcorpus.MmCorpus>
                historical_bow_vec = []
                for _bow in historical_bow_vec_mmcorpus:
                    historical_bow_vec.append(_bow)
                logging.info(
                    "both historical news dictionary and bow-vector existed, load historical bow-vector to memory ... "
                )

        start_time = time.time()
        updated_dictionary_with_old_and_unseen_news, unseen_documents_token_list = self.renew_dictionary(
            ori_dict_path, [unseen_raw_document], is_saved=True)
        end_time = time.time()
        logging.info(
            "renew dictionary with unseen news tokens, and serialized in path -> {}, "
            "which took {} mins ... ".format(ori_dict_path,
                                             (end_time - start_time) / 60))

        unseen_bow_vector = [
            updated_dictionary_with_old_and_unseen_news.doc2bow(doc_token)
            for doc_token in unseen_documents_token_list
        ]
        updated_bow_vector_with_old_and_unseen_news = []
        updated_bow_vector_with_old_and_unseen_news.extend(historical_bow_vec)
        updated_bow_vector_with_old_and_unseen_news.extend(unseen_bow_vector)
        # updated_bow_vector_with_old_and_unseen_news starts out as a list, but after the
        # serialization below it reloads as gensim.corpora.mmcorpus.MmCorpus
        if is_saved_bow_vector and bowvec_save_path:
            corpora.MmCorpus.serialize(
                bowvec_save_path, updated_bow_vector_with_old_and_unseen_news
            )  # persist the updated bow vectors, covering both the old and the new news
        logging.info(
            "combined bow vector(type -> 'list') generated by historical news with unseen bow "
            "vector to create a new one ... ")

        if topic_model_type == "lsi":
            start_time = time.time()
            updated_tfidf_model_vector = self.transform_vectorized_corpus(
                updated_dictionary_with_old_and_unseen_news,
                updated_bow_vector_with_old_and_unseen_news,
                model_type="tfidf"
            )  # type -> <gensim.interfaces.TransformedCorpus object>
            end_time = time.time()
            logging.info(
                "regenerated TF-IDF model vector by updated dictionary and updated bow-vector, "
                "which took {} mins ... ".format((end_time - start_time) / 60))

            start_time = time.time()
            model = models.LsiModel(
                updated_tfidf_model_vector,
                id2word=updated_dictionary_with_old_and_unseen_news,
                num_topics=config.TOPIC_NUMBER)  # initialize the model
            model_vector = model[
                updated_tfidf_model_vector]  # type -> <gensim.interfaces.TransformedCorpus object>
            end_time = time.time()
            logging.info(
                "regenerated LSI model vector space by updated TF-IDF model vector space, "
                "which took {} mins ... ".format((end_time - start_time) / 60))
        elif topic_model_type == "lda":
            start_time = time.time()
            model_vector = self.transform_vectorized_corpus(
                updated_dictionary_with_old_and_unseen_news,
                updated_bow_vector_with_old_and_unseen_news,
                model_type="lda")
            end_time = time.time()
            logging.info(
                "regenerated LDA model vector space by updated dictionary and bow-vector, "
                "which took {} mins ... ".format((end_time - start_time) / 60))

        # convert the lsi model vector (gensim.interfaces.TransformedCorpus) into a numpy matrix
        start_time = time.time()
        latest_matrix = corpus2dense(model_vector,
                                     num_terms=model_vector.obj.num_terms).T
        end_time = time.time()
        logging.info(
            "transform {} model vector space to numpy.adarray, "
            "which took {} mins ... ".format(topic_model_type.upper(),
                                             (end_time - start_time) / 60))

        # train the news classifier on the topic-model vectors (features) of the historical data
        start_time = time.time()
        train_x, train_y, test_x, test_y = utils.generate_training_set(
            latest_matrix[:-1, :], Y)
        clf = self.classifier.train(train_x,
                                    train_y,
                                    test_x,
                                    test_y,
                                    model_type=classifier_model)
        end_time = time.time()
        logging.info(
            "finished training by sklearn {} using latest {} model vector space, which took {} mins ... "
            .format(classifier_model.upper(), topic_model_type.upper(),
                    (end_time - start_time) / 60))

        label_id = clf.predict(latest_matrix[-1, :].reshape(1, -1))[0]

        return label_name_list[label_id]
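
A hypothetical invocation; the symbol, label name, and file paths below are illustrative (sz000001 and 30DaysLabel appear in the next example):

tm = TopicModelling()
label = tm.classify_stock_news(
    unseen_raw_document="...",  # raw article text
    database_name=config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE,
    collection_name="sz000001",
    label_name="30DaysLabel",
    topic_model_type="lda",
    classifier_model="svm",
    ori_dict_path="sz000001_docs_dict.dict",
    bowvec_save_path="sz000001_bowvec.mm")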
Example #9
        elif ".lsi" in model_path:
            return models.LsiModel.load(model_path)
        elif ".lda" in model_path:
            return models.LdaModel.load(model_path)


if __name__ == "__main__":
    from Hisoka.classifier import Classifier
    from Kite.database import Database
    from sklearn import preprocessing

    database = Database()
    topicmodelling = TopicModelling()
    raw_documents_list = []
    Y = []
    for row in database.get_collection("stocknews", "sz000001").find():
        if "30DaysLabel" in row.keys():
            raw_documents_list.append(row["Article"])
            Y.append(row["30DaysLabel"])
    le = preprocessing.LabelEncoder()
    Y = le.fit_transform(Y)

    _, corpora_dictionary, corpus = topicmodelling.create_bag_of_word_representation(
        raw_documents_list)
    model_vector = topicmodelling.transform_vectorized_corpus(
        corpora_dictionary, corpus, model_type="lsi")
    csr_matrix = utils.convert_to_csr_matrix(model_vector)
    train_x, train_y, test_x, test_y = utils.generate_training_set(
        csr_matrix, Y)
    classifier = Classifier()
    classifier.svm(train_x, train_y, test_x, test_y)

Example #10
class GenStockNewsDB(object):

    def __init__(self):
        self.database = Database()
        # trading-calendar dates from 1990-12-19 through 2020-12-31
        self.trade_date = ak.tool_trade_date_hist_sina()["trade_date"].tolist()
        self.label_range = {3: "3DaysLabel",
                            5: "5DaysLabel",
                            10: "10DaysLabel",
                            15: "15DaysLabel",
                            30: "30DaysLabel",
                            60: "60DaysLabel"}
        self.redis_client = redis.StrictRedis(host=config.REDIS_IP,
                                              port=config.REDIS_PORT,
                                              db=config.CACHE_NEWS_REDIS_DB_ID)
        self.redis_client.set("today_date", datetime.datetime.now().strftime("%Y-%m-%d"))
        self.redis_client.delete("stock_news_num_over_{}".format(config.MINIMUM_STOCK_NEWS_NUM_FOR_ML))
        self._stock_news_nums_stat()

    def get_all_news_about_specific_stock(self, database_name, collection_name):
        # Check the keys of collection_name to see whether it contains RelatedStockCodes;
        # if not, the stock codes mentioned in each article have not yet been extracted into a separate field
        _keys_list = list(next(self.database.get_collection(database_name, collection_name).find()).keys())
        if "RelatedStockCodes" not in _keys_list:
            tokenization = Tokenization(import_module="jieba", user_dict="./Leorio/financedict.txt")
            tokenization.update_news_database_rows(database_name, collection_name)
        # create one collection named after each stock symbol
        stock_symbol_list = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                   config.COLLECTION_NAME_STOCK_BASIC_INFO,
                                                   keys=["symbol"])["symbol"].to_list()
        col_names = self.database.connect_database(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE).list_collection_names(session=None)
        for symbol in stock_symbol_list:
            if symbol not in col_names:
                # if int(symbol[2:]) > 837:
                _collection = self.database.get_collection(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, symbol)
                _tmp_num_stat = 0
                for row in self.database.get_collection(database_name, collection_name).find():  # cursor (iterator)
                    if symbol[2:] in row["RelatedStockCodes"].split(" "):
                        # label the news with the price move n days after publication
                        _tmp_dict = {}
                        for label_days, key_name in self.label_range.items():
                            _tmp_res = self._label_news(
                                datetime.datetime.strptime(row["Date"].split(" ")[0], "%Y-%m-%d"), symbol, label_days)
                            _tmp_dict.update({key_name: _tmp_res})
                        _data = {"Date": row["Date"],
                                 "Url": row["Url"],
                                 "Title": row["Title"],
                                 "Article": row["Article"],
                                 "OriDB": database_name,
                                 "OriCOL": collection_name}
                        _data.update(_tmp_dict)
                        _collection.insert_one(_data)
                        _tmp_num_stat += 1
                logging.info("there are {} news mentioned {} in {} collection need to be fetched ... "
                             .format(_tmp_num_stat, symbol, collection_name))
            # else:
            #     logging.info("{} has fetched all related news from {}...".format(symbol, collection_name))

    def listen_redis_queue(self):
        # Listen on the redis message queue; when a new real-time item arrives, save the news
        # into each related stock's database according to its "RelatedStockCodes" field.
        # e.g. if a cached item carries "RelatedStockCodes" of "603386 603003 600111 603568",
        # the news is stored into the databases of all four of those stocks
        crawled_url_today = set()
        while True:
            date_now = datetime.datetime.now().strftime("%Y-%m-%d")
            if date_now != self.redis_client.get("today_date").decode():
                crawled_url_today = set()
                self.redis_client.set("today_date", date_now)
            if self.redis_client.llen(config.CACHE_NEWS_LIST_NAME) != 0:
                data = json.loads(self.redis_client.lindex(config.CACHE_NEWS_LIST_NAME, -1))
                if data["Url"] not in crawled_url_today:  # 排除重复插入冗余文本
                    crawled_url_today.update({data["Url"]})
                    if data["RelatedStockCodes"] != "":
                        for stock_code in data["RelatedStockCodes"].split(" "):
                            # route the news into each related stock's collection
                            symbol = "sh{}".format(stock_code) if stock_code[0] == "6" else "sz{}".format(stock_code)
                            _collection = self.database.get_collection(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, symbol)
                            _tmp_dict = {}
                            for label_days, key_name in self.label_range.items():
                                _tmp_res = self._label_news(
                                    datetime.datetime.strptime(data["Date"].split(" ")[0], "%Y-%m-%d"), symbol, label_days)
                                _tmp_dict.update({key_name: _tmp_res})
                            _data = {"Date": data["Date"],
                                     "Url": data["Url"],
                                     "Title": data["Title"],
                                     "Article": data["Article"],
                                     "OriDB": data["OriDB"],
                                     "OriCOL": data["OriCOL"]}
                            _data.update(_tmp_dict)
                            _collection.insert_one(_data)
                            logging.info("the real-time fetched news {}, which was saved in [DB:{} - COL:{}] ...".format(data["Title"],
                                                                                                                         config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE,
                                                                                                                         symbol))
                            #
                            # if symbol.encode() in self.redis_client.lrange("stock_news_num_over_{}".format(config.MINIMUM_STOCK_NEWS_NUM_FOR_ML), 0, -1):
                            #     label_name = "3DaysLabel"
                            #     # classifier_save_path = "{}_classifier.pkl".format(symbol)
                            #     ori_dict_path = "{}_docs_dict.dict".format(symbol)
                            #     bowvec_save_path = "{}_bowvec.mm".format(symbol)
                            #
                            #     topicmodelling = TopicModelling()
                            #     chn_label = topicmodelling.classify_stock_news(data["Article"],
                            #                                                    config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE,
                            #                                                    symbol,
                            #                                                    label_name=label_name,
                            #                                                    topic_model_type="lsi",
                            #                                                    classifier_model="rdforest",  # rdforest / svm
                            #                                                    ori_dict_path=ori_dict_path,
                            #                                                    bowvec_save_path=bowvec_save_path)
                            #     logging.info(
                            #         "document '{}...' was classified with label '{}' for symbol {} ... ".format(
                            #             data["Article"][:20], chn_label, symbol))

                    self.redis_client.rpop(config.CACHE_NEWS_LIST_NAME)
                    logging.info("now pop {} from redis queue of [DB:{} - KEY:{}] ... ".format(data["Title"],
                                                                                               config.CACHE_NEWS_REDIS_DB_ID,
                                                                                               config.CACHE_NEWS_LIST_NAME))

    def _label_news(self, date, symbol, n_days):
        """
        :param date: datetime.datetime, the news publication date, date only (no time of day), e.g. datetime.datetime(2015, 1, 5, 0, 0)
        :param symbol: str, the stock ticker, e.g. sh600000
        :param n_days: int, how many days after publication the price is checked to set the label; if the close has risen n_days after the news, the news is considered positive
        """
        # work out the concrete date n_days after the news was published
        this_date_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                symbol,
                                                query={"date": date})
        # Case: the news was published on a non-trading day, so that date has no price data;
        # search backwards instead, e.g. for news published on Saturday 2020-12-12,
        # take the close of 2020-12-11 as the price at publication time
        tmp_date = date
        if this_date_data is None:
            i = 1
            while this_date_data is None and i <= 10:
                tmp_date = date - datetime.timedelta(days=i)  # step back i days, recomputed from date each pass
                # only query the database on trading days; if this_date_data is still None
                # afterwards, the database has no data for that trading day
                if tmp_date.strftime("%Y-%m-%d") in self.trade_date:
                    this_date_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                            symbol,
                                                            query={"date": tmp_date})
                i += 1
        try:
            close_price_this_date = this_date_data["close"][0]
        except Exception:
            close_price_this_date = None
        # Case: n_days after publication falls on a non-trading day (or no data was collected),
        # so search forwards, e.g. for news published 2020-12-08, 5 days later is Sunday
        # 2020-12-13, so take Monday 2020-12-14's close as the n-days-later price
        new_date = date + datetime.timedelta(days=n_days)
        n_days_later_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                   symbol,
                                                   query={"date": new_date})
        if n_days_later_data is None:
            i = 1
            while n_days_later_data is None and i <= 10:
                new_date = date + datetime.timedelta(days=n_days+i)
                if new_date.strftime("%Y-%m-%d") in self.trade_date:
                    n_days_later_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                               symbol,
                                                               query={"date": new_date})
                i += 1
        try:
            close_price_n_days_later = n_days_later_data["close"][0]
        except Exception:
            close_price_n_days_later = None
        # Labeling rules:
        # (1) n_days <= 10: a rise (fall) of more than 3% n_days after publication marks the news positive (negative); within ±3% it is neutral
        # (2) 10 < n_days <= 15: the threshold is 5%
        # (3) 15 < n_days <= 30: the threshold is 10%
        # (4) 30 < n_days <= 60: the threshold is 15%
        # Note: "neutral" means the market absorbed the news quickly, with no lasting impact
        param = 0.01
        if n_days <= 10:
            param = 0.03
        elif 10 < n_days <= 15:
            param = 0.05
        elif 15 < n_days <= 30:
            param = 0.10
        elif 30 < n_days <= 60:
            param = 0.15
        if close_price_this_date is not None and close_price_n_days_later is not None:
            if (close_price_n_days_later - close_price_this_date) / close_price_this_date > param:
                return "利好"
            elif (close_price_n_days_later - close_price_this_date) / close_price_this_date < -param:
                return "利空"
            else:
                return "中性"
        else:
            return ""

    def _stock_news_nums_stat(self):
        cols_list = self.database.connect_database(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE).list_collection_names(session=None)
        for sym in cols_list:
            if self.database.get_collection(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, sym).estimated_document_count() > config.MINIMUM_STOCK_NEWS_NUM_FOR_ML:
                self.redis_client.lpush("stock_news_num_over_{}".format(config.MINIMUM_STOCK_NEWS_NUM_FOR_ML), sym)
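
listen_redis_queue above reads with lindex(-1) and rpop, so a producer pushes items with lpush; a minimal sketch of the expected payload (field values hypothetical):

import json
import redis

producer = redis.StrictRedis(host=config.REDIS_IP, port=config.REDIS_PORT,
                             db=config.CACHE_NEWS_REDIS_DB_ID)
item = {"Date": "2020-12-08 09:30:00", "Url": "http://...", "Title": "...",
        "Article": "...", "OriDB": "finnews", "OriCOL": "cnstock",
        "RelatedStockCodes": "600000 000001"}
producer.lpush(config.CACHE_NEWS_LIST_NAME, json.dumps(item))

Example #11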
class StockInfoSpyder(Spyder):
    def __init__(self, database_name, collection_name):
        super(StockInfoSpyder, self).__init__()
        self.db_obj = Database()
        self.col_basic_info = self.db_obj.get_collection(
            database_name, collection_name)
        self.database_name = database_name
        self.collection_name = collection_name

    def get_stock_code_info(self):
        stock_info_df = ak.stock_info_a_code_name()  # fetch the code and name of every A-share stock
        stock_symbol_code = ak.stock_zh_a_spot().get(
            ["symbol", "code"])  # fetch the symbol and code of all A-share stocks
        for _id in range(stock_info_df.shape[0]):
            _symbol = stock_symbol_code[stock_symbol_code.code == stock_info_df
                                        .iloc[_id].code].symbol.values
            if len(_symbol) != 0:
                _dict = {"symbol": _symbol[0]}
                _dict.update(stock_info_df.iloc[_id].to_dict())
                self.col_basic_info.insert_one(_dict)

    def get_historical_news(self, start_date=None, end_date=None, freq="day"):
        if end_date is None:
            end_date = datetime.datetime.now().strftime("%Y%m%d")
        stock_symbol_list = self.col_basic_info.distinct("symbol")
        if len(stock_symbol_list) == 0:
            self.get_stock_code_info()
            stock_symbol_list = self.col_basic_info.distinct("symbol")
        if freq == "day":
            if os.path.exists(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH):
                with open(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH,
                          "r") as file:
                    start_stock_code = file.read()
                logging.info(
                    "read {} and got start code number {} ... ".format(
                        config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH,
                        start_stock_code))
            else:
                start_stock_code = 0
            for symbol in stock_symbol_list:
                if int(symbol[2:]) >= int(start_stock_code):
                    try:
                        symbol_start_date = start_date  # per-symbol local so one symbol's start date does not leak into the next iteration
                        if symbol_start_date is None:
                            # if this symbol already has historical data, request prices from the
                            # day after the latest date in the database up to now; otherwise request
                            # all prices since January 1, 2015 (the configured default date)
                            symbol_date_list = self.db_obj.get_data(
                                self.database_name, symbol,
                                keys=["date"])["date"].to_list()
                            if len(symbol_date_list) == 0:
                                symbol_start_date = config.STOCK_PRICE_REQUEST_DEFAULT_DATE
                            else:
                                tmp_date = str(
                                    max(symbol_date_list)).split(" ")[0]
                                tmp_date_dt = datetime.datetime.strptime(
                                    tmp_date, "%Y-%m-%d").date()
                                offset = datetime.timedelta(days=1)
                                symbol_start_date = (tmp_date_dt +
                                                     offset).strftime('%Y%m%d')
                        stock_zh_a_daily_hfq_df = ak.stock_zh_a_daily(
                            symbol=symbol,
                            start_date=symbol_start_date,
                            end_date=end_date,
                            adjust="hfq")
                        stock_zh_a_daily_hfq_df.insert(
                            0, 'date', stock_zh_a_daily_hfq_df.index.tolist())
                        stock_zh_a_daily_hfq_df.index = range(
                            len(stock_zh_a_daily_hfq_df))
                        _col = self.db_obj.get_collection(
                            self.database_name, symbol)
                        for _id in range(stock_zh_a_daily_hfq_df.shape[0]):
                            _col.insert_one(
                                stock_zh_a_daily_hfq_df.iloc[_id].to_dict())
                        logging.info("{} finished saving ... ".format(symbol))
                    except Exception:
                        with open(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH,
                                  "w") as file:
                            file.write(symbol[2:])
        elif freq == "week":
            pass
        elif freq == "month":
            pass
        elif freq == "5mins":
            pass
        elif freq == "15mins":
            pass
        elif freq == "30mins":
            pass
        elif freq == "60mins":
            pass