def __init__(self):
    self.tokenization = Tokenization(
        import_module="jieba",
        user_dict=config.USER_DEFINED_DICT_PATH,
        chn_stop_words_dir=config.CHN_STOP_WORDS_PATH)
    self.database = Database()
    self.classifier = Classifier()
def __init__(self, database_name, collection_name):
    super(NbdSpyder, self).__init__()
    self.db_obj = Database()
    self.col = self.db_obj.conn[database_name].get_collection(collection_name)
    self.terminated_amount = 0
    self.db_name = database_name
    self.col_name = collection_name
def __init__(self, database_name, collection_name):
    super(StockInfoSpyder, self).__init__()
    self.db_obj = Database()
    self.col_basic_info = self.db_obj.get_collection(database_name, collection_name)
    self.database_name = database_name
    self.collection_name = collection_name
class StockInfoSpyder(Spyder):
    def __init__(self, database_name, collection_name):
        super(StockInfoSpyder, self).__init__()
        self.db_obj = Database()
        self.col_basic_info = self.db_obj.get_collection(database_name, collection_name)
        self.database_name = database_name
        self.collection_name = collection_name

    def get_stock_code_info(self):
        stock_info_df = ak.stock_info_a_code_name()  # code and name of every A-share stock
        stock_symbol_code = ak.stock_zh_a_spot().get(["symbol", "code"])  # symbol and code of every A-share stock
        for _id in range(stock_info_df.shape[0]):
            _symbol = stock_symbol_code[
                stock_symbol_code.code == stock_info_df.iloc[_id].code].symbol.values
            if len(_symbol) != 0:
                _dict = {"symbol": _symbol[0]}
                _dict.update(stock_info_df.iloc[_id].to_dict())
                self.col_basic_info.insert_one(_dict)
        return stock_info_df

    def get_historical_news(self, start_date, end_date, freq="day"):
        stock_symbol_list = self.col_basic_info.distinct("symbol")
        if len(stock_symbol_list) == 0:
            # populate the basic-info collection first, then re-read the symbols
            # (get_stock_code_info returns the raw code/name frame, not the symbol list)
            self.get_stock_code_info()
            stock_symbol_list = self.col_basic_info.distinct("symbol")
        if freq == "day":
            if os.path.exists(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH):
                with open(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH, "r") as file:
                    start_stock_code = file.read()
                logging.info("read {} to get the start code number {} ... "
                             .format(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH, start_stock_code))
            else:
                start_stock_code = 0
            for symbol in stock_symbol_list:
                if int(symbol[2:]) >= int(start_stock_code):
                    try:
                        stock_zh_a_daily_hfq_df = ak.stock_zh_a_daily(symbol=symbol,
                                                                      start_date=start_date,
                                                                      end_date=end_date,
                                                                      adjust="hfq")
                        stock_zh_a_daily_hfq_df.insert(0, 'date', stock_zh_a_daily_hfq_df.index.tolist())
                        stock_zh_a_daily_hfq_df.index = range(len(stock_zh_a_daily_hfq_df))
                        _col = self.db_obj.get_collection(self.database_name, symbol)
                        for _id in range(stock_zh_a_daily_hfq_df.shape[0]):
                            _col.insert_one(stock_zh_a_daily_hfq_df.iloc[_id].to_dict())
                        logging.info("{} finished saving ... ".format(symbol))
                    except Exception:
                        # remember where the crawl broke so the next run can resume from this code
                        with open(config.STOCK_DAILY_EXCEPTION_TXT_FILE_PATH, "w") as file:
                            file.write(symbol[2:])
        elif freq == "week":
            pass
        elif freq == "month":
            pass
        elif freq == "5mins":
            pass
        elif freq == "15mins":
            pass
        elif freq == "30mins":
            pass
        elif freq == "60mins":
            pass
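# A minimal usage sketch for the class above (not part of the original module). The
# "stock"/"basic_info" names are the ones other snippets in this listing query; the date
# strings follow the YYYYMMDD format the later revision compares against. MongoDB and
# akshare must be reachable for this to run.
if __name__ == "__main__":
    stock_info_spyder = StockInfoSpyder("stock", "basic_info")
    stock_info_spyder.get_stock_code_info()                    # fill symbol/code/name basic info
    stock_info_spyder.get_historical_news(start_date="20150101",
                                          end_date="20201231",
                                          freq="day")          # only the daily branch is implemented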
def __init__(self, database_name, collection_name):
    super(JrjSpyder, self).__init__()
    self.db_obj = Database()
    self.col = self.db_obj.conn[database_name].get_collection(collection_name)
    self.terminated_amount = 0
    self.db_name = database_name
    self.col_name = collection_name
    self.tokenization = Tokenization(import_module="jieba",
                                     user_dict=config.USER_DEFINED_DICT_PATH)
def __init__(self, database_name, collection_name):
    super(CnStockSpyder, self).__init__()
    self.db_obj = Database()
    self.col = self.db_obj.conn[database_name].get_collection(collection_name)
    self.driver = webdriver.Chrome(executable_path=config.CHROME_DRIVER)
    self.btn_more_text = ""
    self.terminated_amount = 0
    self.db_name = database_name
    self.col_name = collection_name
def __init__(self, database_name, collection_name):
    super(NbdSpyder, self).__init__()
    self.db_obj = Database()
    self.col = self.db_obj.conn[database_name].get_collection(collection_name)
    self.terminated_amount = 0
    self.db_name = database_name
    self.col_name = collection_name
    self.tokenization = Tokenization(import_module="jieba",
                                     user_dict=config.USER_DEFINED_DICT_PATH)
    self.redis_client = redis.StrictRedis(host=config.REDIS_IP,
                                          port=config.REDIS_PORT,
                                          db=config.CACHE_NEWS_REDIS_DB_ID)
def __init__(self, import_module="jieba", user_dict=None, chn_stop_words_dir=None):
    # self.database = Database().conn[config.DATABASE_NAME]
    #     .get_collection(config.COLLECTION_NAME_CNSTOCK)
    self.database = Database()
    self.import_module = import_module
    self.user_dict = user_dict
    if self.user_dict:
        self.update_user_dict(self.user_dict)
    if chn_stop_words_dir:
        self.stop_words_list = utils.get_chn_stop_words(chn_stop_words_dir)
    else:
        self.stop_words_list = list()
def __init__(self):
    self.database = Database()
    # trading-calendar dates from 1990-12-19 to 2020-12-31
    self.trade_date = ak.tool_trade_date_hist_sina()["trade_date"].tolist()
    self.label_range = {3: "3DaysLabel",
                        5: "5DaysLabel",
                        10: "10DaysLabel",
                        15: "15DaysLabel",
                        30: "30DaysLabel",
                        60: "60DaysLabel"}
def __init__(self, database_name, collection_name):
    super(StockInfoSpyder, self).__init__()
    self.db_obj = Database()
    self.col_basic_info = self.db_obj.get_collection(database_name, collection_name)
    self.database_name = database_name
    self.collection_name = collection_name
    self.start_program_date = datetime.datetime.now().strftime("%Y%m%d")
    self.redis_client = redis.StrictRedis(
        host="localhost",
        port=6379,
        db=config.REDIS_CLIENT_FOR_CACHING_STOCK_INFO_DB_ID)
    self.redis_client.set("today_date", datetime.datetime.now().strftime("%Y-%m-%d"))
class Spyder(object):
    def __init__(self):
        self.db_obj = Database()
        self.db = self.db_obj.create_db(config.DATABASE_NAME)
        self.is_article_prob = .5

    def extract_data(self, tag_list):
        # collect the distinct values of each tag in the collection
        data = list()
        for tag in tag_list:
            exec(tag + " = self.col.distinct('" + tag + "')")
            exec("data.append(" + tag + ")")
        return data

    def query_news(self, _key, param):
        # fuzzy query by regex
        return self.col.find({_key: {'$regex': ".*{}.*".format(param)}})

    def get_url_info(self, url):
        pass

    def get_historical_news(self, url):
        pass

    def get_realtime_news(self, url):
        pass
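# The exec-based extract_data above creates a throwaway variable per tag before appending
# it. A functionally equivalent sketch without exec (an alternative, not the project's
# code) relies only on pymongo's Collection.distinct, which extract_data already uses:
def extract_data_no_exec(col, tag_list):
    # return the distinct values for each tag, in the same order as tag_list
    return [col.distinct(tag) for tag in tag_list]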
def __init__(self):
    self.database = Database()
    # trading-calendar dates from 1990-12-19 to 2020-12-31
    self.trade_date = ak.tool_trade_date_hist_sina()["trade_date"].tolist()
    self.label_range = {3: "3DaysLabel",
                        5: "5DaysLabel",
                        10: "10DaysLabel",
                        15: "15DaysLabel",
                        30: "30DaysLabel",
                        60: "60DaysLabel"}
    self.redis_client = redis.StrictRedis(host=config.REDIS_IP,
                                          port=config.REDIS_PORT,
                                          db=config.CACHE_NEWS_REDIS_DB_ID)
    self.redis_client.set("today_date", datetime.datetime.now().strftime("%Y-%m-%d"))
    self.redis_client.delete("stock_news_num_over_{}".format(config.MINIMUM_STOCK_NEWS_NUM_FOR_ML))
    self._stock_news_nums_stat()
def __init__(self):
    self.database = Database()
    # trading-calendar dates from 1990-12-19 to 2020-12-31
    self.trade_date = ak.tool_trade_date_hist_sina()["trade_date"].tolist()
    self.label_range = {3: "3DaysLabel",
                        5: "5DaysLabel",
                        10: "10DaysLabel",
                        15: "15DaysLabel",
                        30: "30DaysLabel",
                        60: "60DaysLabel"}
    self.redis_client = redis.StrictRedis(host=config.REDIS_IP,
                                          port=config.REDIS_PORT,
                                          db=config.CACHE_NEWS_REDIS_DB_ID)
    self.redis_client.set("today_date", datetime.datetime.now().strftime("%Y-%m-%d"))
class DeNull(object):
    def __init__(self, database_name, collection_name):
        self.database = Database()
        self.database_name = database_name
        self.collection_name = collection_name
        self.delete_num = 0

    def run(self):
        collection = self.database.get_collection(self.database_name, self.collection_name)
        for row in self.database.get_collection(self.database_name, self.collection_name).find():
            for _key in list(row.keys()):
                if _key != "RelatedStockCodes" and row[_key] == "":
                    collection.delete_one({'_id': row["_id"]})
                    self.delete_num += 1
                    break
        logging.info(
            "there are {} news items containing NULL values in the {} collection ... ".format(
                self.delete_num, self.collection_name))
class GenStockNewsDB(object):
    def __init__(self):
        self.database = Database()

    def get_all_news_about_specific_stock(self, database_name, collection_name):
        # Check the keys of the collection: if RelatedStockCodes is missing, the stock codes
        # mentioned in each news article have not yet been extracted into a separate field.
        _keys_list = list(
            next(self.database.get_collection(database_name, collection_name).find()).keys())
        if "RelatedStockCodes" not in _keys_list:
            tokenization = Tokenization(import_module="jieba",
                                        user_dict="./Leorio/financedict.txt")
            tokenization.update_news_database_rows(database_name, collection_name)
        # create one collection named after each stock code
        stock_code_list = self.database.get_data("stock", "basic_info",
                                                 keys=["code"])["code"].to_list()
        for code in stock_code_list:
            _collection = self.database.get_collection(
                config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, code)
            _tmp_num_stat = 0
            for row in self.database.get_collection(database_name, collection_name).find():  # cursor
                if code in row["RelatedStockCodes"].split(" "):
                    _collection.insert_one({"Date": row["Date"],
                                            "Url": row["Url"],
                                            "Title": row["Title"],
                                            "Article": row["Article"],
                                            "OriDB": database_name,
                                            "OriCOL": collection_name})
                    _tmp_num_stat += 1
            logging.info("there are {} news items mentioning {} in the {} collection ... ".format(
                _tmp_num_stat, code, collection_name))
class Deduplication(object):
    def __init__(self, database_name, collection_name):
        self.database = Database()
        self.database_name = database_name
        self.collection_name = collection_name
        self.delete_num = 0

    def run(self):
        date_list = self.database.get_data(self.database_name,
                                           self.collection_name,
                                           keys=["Date"])["Date"].tolist()
        collection = self.database.get_collection(self.database_name, self.collection_name)
        date_list.sort()  # ascending order
        # start_date, end_date = date_list[1].split(" ")[0], date_list[-1].split(" ")[0]
        start_date, end_date = min(date_list).split(" ")[0], max(date_list).split(" ")[0]
        for _date in utils.get_date_list_from_range(start_date, end_date):
            # fetch the records of this particular date and deduplicate them by URL
            # logging.info(_date)
            try:
                data_df = self.database.get_data(self.database_name,
                                                 self.collection_name,
                                                 query={"Date": {"$regex": _date}})
            except Exception:
                continue
            if data_df is None:
                continue
            data_df_drop_duplicate = data_df.drop_duplicates(["Url"])
            for _id in list(set(data_df["_id"]) - set(data_df_drop_duplicate["_id"])):
                collection.delete_one({'_id': _id})
                self.delete_num += 1
            # logging.info("{} finished ... ".format(_date))
        logging.info(
            "DB:{} - COL:{} had {} rows originally; {} duplicates have now been deleted ... ".format(
                self.database_name, self.collection_name, str(len(date_list)), self.delete_num))
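# A minimal sketch of how the two cleaning passes above could be chained (assumed usage,
# not taken from the original repo); "finnews" and "cnstock" are placeholder database and
# collection names.
def clean_collection(database_name="finnews", collection_name="cnstock"):
    Deduplication(database_name, collection_name).run()  # drop rows sharing the same Url
    DeNull(database_name, collection_name).run()         # drop rows with empty fields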
def __init__(self, database_name, collection_name):
    super(StockInfoSpyder, self).__init__()
    self.db_obj = Database()
    self.col_basic_info = self.db_obj.get_collection(database_name, collection_name)
    self.database_name = database_name
    self.collection_name = collection_name
    self.start_program_date = datetime.datetime.now().strftime("%Y%m%d")
    self.redis_client = redis.StrictRedis(
        host="localhost",
        port=6379,
        db=config.REDIS_CLIENT_FOR_CACHING_STOCK_INFO_DB_ID)
    self.redis_client.set("today_date", datetime.datetime.now().strftime("%Y-%m-%d"))
    while True:
        if_updated = input("Has the stock price dataset been updated today? (Y/N) \n")
        if if_updated == "Y":
            self.redis_client.set("is_today_updated", "1")
            break
        elif if_updated == "N":
            self.redis_client.set("is_today_updated", "")
            break
class NbdSpyder(Spyder): def __init__(self, database_name, collection_name): super(NbdSpyder, self).__init__() self.db_obj = Database() self.col = self.db_obj.conn[database_name].get_collection(collection_name) self.terminated_amount = 0 self.db_name = database_name self.col_name = collection_name def get_url_info(self, url): try: bs = utils.html_parser(url) except Exception: return False span_list = bs.find_all("span") part = bs.find_all("p") article = "" date = "" for span in span_list: if "class" in span.attrs and span.text and span["class"] == ["time"]: string = span.text.split() for dt in string: if dt.find("-") != -1: date += dt + " " elif dt.find(":") != -1: date += dt break for paragraph in part: chn_status = utils.count_chn(str(paragraph)) possible = chn_status[1] if possible > self.is_article_prob: article += str(paragraph) while article.find("<") != -1 and article.find(">") != -1: string = article[article.find("<"):article.find(">")+1] article = article.replace(string, "") while article.find("\u3000") != -1: article = article.replace("\u3000", "") article = " ".join(re.split(" +|\n+", article)).strip() return [date, article] def get_historical_news(self, start_page): # extracted_data_list = self.extract_data(["PageId"])[0] # if len(extracted_data_list) != 0: # latest_page_id = min(extracted_data_list) # else: # latest_page_id = start_page # crawled_urls_list = list() # for page_id in range(start_page, int(latest_page_id)-1, -1): # query_results = self.query_news("PageId", page_id) # for qr in query_results: # crawled_urls_list.append(qr["Url"]) # logging.info("the length of crawled data from page {} to page {} is {} ... ".format(start_page, # latest_page_id, # len(crawled_urls_list))) crawled_urls_list = [] page_urls = ["{}/{}".format(config.WEBSITES_LIST_TO_BE_CRAWLED_NBD, page_id) for page_id in range(start_page, 0, -1)] for page_url in page_urls: bs = utils.html_parser(page_url) a_list = bs.find_all("a") for a in a_list: if "click-statistic" in a.attrs and a.string \ and a["click-statistic"].find("Article_") != -1 \ and a["href"].find("http://www.nbd.com.cn/articles/") != -1: if a["href"] not in crawled_urls_list: result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open(config.RECORD_NBD_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format(a["href"])) logging.info("rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format(config.NBD_MAX_REJECTED_AMOUNTS, config.RECORD_NBD_FAILED_URL_TXT_FILE_PATH)) break logging.info("rejected by remote server, request {} again after " "{} seconds...".format(a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) if not result: # 爬取失败的情况 logging.info("[FAILED] {} {}".format(a.string, a["href"])) else: # 有返回但是article为null的情况 date, article = result while article == "" and self.is_article_prob >= .1: self.is_article_prob -= .1 result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open(config.RECORD_NBD_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format(a["href"])) logging.info("rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format(config.NBD_MAX_REJECTED_AMOUNTS, config.RECORD_NBD_FAILED_URL_TXT_FILE_PATH)) break 
logging.info("rejected by remote server, request {} again after " "{} seconds...".format(a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) date, article = result self.is_article_prob = .5 if article != "": data = {"Date": date, "PageId": page_url.split("/")[-1], "Url": a["href"], "Title": a.string, "Article": article} # self.col.insert_one(data) self.db_obj.insert_data(self.db_name, self.col_name, data) logging.info("[SUCCESS] {} {} {}".format(date, a.string, a["href"])) def get_realtime_news(self, url): pass
class Tokenization(object): def __init__(self, import_module="jieba", user_dict=None, chn_stop_words_dir=None): #self.database = Database().conn[config.DATABASE_NAME] #.get_collection(config.COLLECTION_NAME_CNSTOCK) self.database = Database() self.import_module = import_module self.user_dict = user_dict if self.user_dict: self.update_user_dict(self.user_dict) if chn_stop_words_dir: self.stop_words_list = utils.get_chn_stop_words(chn_stop_words_dir) else: self.stop_words_list = list() def update_user_dict(self, old_user_dict_dir, new_user_dict_dir=None): # 将缺失的(或新的)股票名称、金融新词等,添加进金融词典中 word_list = [] with open(old_user_dict_dir, "r", encoding="utf-8") as file: for row in file: word_list.append(row.split("\n")[0]) name_code_df = self.database.get_data( config.STOCK_DATABASE_NAME, config.COLLECTION_NAME_STOCK_BASIC_INFO, keys=["name", "code"]) new_words_list = list(set(name_code_df["name"].tolist())) for word in new_words_list: if word not in word_list: word_list.append(word) new_user_dict_dir = old_user_dict_dir if not new_user_dict_dir else new_user_dict_dir with open(new_user_dict_dir, "w", encoding="utf-8") as file: for word in word_list: file.write(word + "\n") def cut_words(self, text): outstr = list() sentence_seged = None if self.import_module == "jieba": if self.user_dict: jieba.load_userdict(self.user_dict) sentence_seged = list(jieba.cut(text)) elif self.import_module == "pkuseg": seg = pkuseg.pkuseg(user_dict=self.user_dict) # 添加自定义词典 sentence_seged = seg.cut(text) # 进行分词 if sentence_seged: for word in sentence_seged: if word not in self.stop_words_list and word != "\t" and word != " ": outstr.append(word) return outstr else: return False def find_relevant_stock_codes_in_article(self, article, stock_name_code_dict): stock_codes_set = list() cut_words_list = self.cut_words(article) if cut_words_list: for word in cut_words_list: try: stock_codes_set.append(stock_name_code_dict[word]) except Exception: pass return list(set(stock_codes_set)) def update_news_database_rows(self, database_name, collection_name): name_code_df = self.database.get_data( config.STOCK_DATABASE_NAME, config.COLLECTION_NAME_STOCK_BASIC_INFO, keys=["name", "code"]) name_code_dict = dict(name_code_df.values) data = self.database.get_collection(database_name, collection_name).find() for row in data: # if row["Date"] > "2019-05-20 00:00:00": related_stock_codes_list = self.find_relevant_stock_codes_in_article( row["Article"], name_code_dict) self.database.update_row( database_name, collection_name, {"_id": row["_id"]}, {"RelatedStockCodes": " ".join(related_stock_codes_list)}) logging.info( "[{} -> {} -> {}] updated RelatedStockCodes key value ... ". format(database_name, collection_name, row["Date"]))
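# A small usage sketch for the Tokenization helpers above (assumed, with a toy
# name-to-code mapping; the config paths must point to an existing user dictionary and
# stop-word file, and the stock name must be present in the user dictionary for jieba to
# keep it as one token):
tokenization = Tokenization(import_module="jieba",
                            user_dict=config.USER_DEFINED_DICT_PATH,
                            chn_stop_words_dir=config.CHN_STOP_WORDS_PATH)
tokens = tokenization.cut_words("贵州茅台发布三季度业绩公告")   # list of tokens, stop words removed
name_code_dict = {"贵州茅台": "600519"}                         # toy mapping for illustration
codes = tokenization.find_relevant_stock_codes_in_article(
    "贵州茅台发布三季度业绩公告", name_code_dict)                # expected: ["600519"]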
class NbdSpyder(Spyder): def __init__(self, database_name, collection_name): super(NbdSpyder, self).__init__() self.db_obj = Database() self.col = self.db_obj.conn[database_name].get_collection( collection_name) self.terminated_amount = 0 self.db_name = database_name self.col_name = collection_name self.tokenization = Tokenization( import_module="jieba", user_dict=config.USER_DEFINED_DICT_PATH) def get_url_info(self, url): try: bs = utils.html_parser(url) except Exception: return False span_list = bs.find_all("span") part = bs.find_all("p") article = "" date = "" for span in span_list: if "class" in span.attrs and span.text and span["class"] == [ "time" ]: string = span.text.split() for dt in string: if dt.find("-") != -1: date += dt + " " elif dt.find(":") != -1: date += dt break for paragraph in part: chn_status = utils.count_chn(str(paragraph)) possible = chn_status[1] if possible > self.is_article_prob: article += str(paragraph) while article.find("<") != -1 and article.find(">") != -1: string = article[article.find("<"):article.find(">") + 1] article = article.replace(string, "") while article.find("\u3000") != -1: article = article.replace("\u3000", "") article = " ".join(re.split(" +|\n+", article)).strip() return [date, article] def get_historical_news(self, start_page=684): date_list = self.db_obj.get_data(self.db_name, self.col_name, keys=["Date"])["Date"].to_list() name_code_df = self.db_obj.get_data( config.STOCK_DATABASE_NAME, config.COLLECTION_NAME_STOCK_BASIC_INFO, keys=["name", "code"]) name_code_dict = dict(name_code_df.values) if len(date_list) == 0: # 说明没有历史数据,从头开始爬取 crawled_urls_list = [] page_urls = [ "{}/{}".format(config.WEBSITES_LIST_TO_BE_CRAWLED_NBD, page_id) for page_id in range(start_page, 0, -1) ] for page_url in page_urls: bs = utils.html_parser(page_url) a_list = bs.find_all("a") for a in a_list: if "click-statistic" in a.attrs and a.string \ and a["click-statistic"].find("Article_") != -1 \ and a["href"].find("http://www.nbd.com.cn/articles/") != -1: if a["href"] not in crawled_urls_list: result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_NBD_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format(a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config.NBD_MAX_REJECTED_AMOUNTS, config. RECORD_NBD_FAILED_URL_TXT_FILE_PATH )) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format( a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) if not result: # 爬取失败的情况 logging.info("[FAILED] {} {}".format( a.string, a["href"])) else: # 有返回但是article为null的情况 date, article = result while article == "" and self.is_article_prob >= .1: self.is_article_prob -= .1 result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_NBD_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format( a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config. NBD_MAX_REJECTED_AMOUNTS, config. 
RECORD_NBD_FAILED_URL_TXT_FILE_PATH )) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format( a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) date, article = result self.is_article_prob = .5 if article != "": related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article( article, name_code_dict) data = { "Date": date, # "PageId": page_url.split("/")[-1], "Url": a["href"], "Title": a.string, "Article": article, "RelatedStockCodes": " ".join(related_stock_codes_list) } # self.col.insert_one(data) self.db_obj.insert_data( self.db_name, self.col_name, data) logging.info("[SUCCESS] {} {} {}".format( date, a.string, a["href"])) else: is_stop = False start_date = max(date_list) page_start_id = 1 while not is_stop: page_url = "{}/{}".format( config.WEBSITES_LIST_TO_BE_CRAWLED_NBD, page_start_id) bs = utils.html_parser(page_url) a_list = bs.find_all("a") for a in a_list: if "click-statistic" in a.attrs and a.string \ and a["click-statistic"].find("Article_") != -1 \ and a["href"].find("http://www.nbd.com.cn/articles/") != -1: result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_NBD_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format(a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config.NBD_MAX_REJECTED_AMOUNTS, config. RECORD_NBD_FAILED_URL_TXT_FILE_PATH)) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format( a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) if not result: # 爬取失败的情况 logging.info("[FAILED] {} {}".format( a.string, a["href"])) else: # 有返回但是article为null的情况 date, article = result if date > start_date: while article == "" and self.is_article_prob >= .1: self.is_article_prob -= .1 result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_NBD_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format( a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config. NBD_MAX_REJECTED_AMOUNTS, config. RECORD_NBD_FAILED_URL_TXT_FILE_PATH )) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format( a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) date, article = result self.is_article_prob = .5 if article != "": related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article( article, name_code_dict) data = { "Date": date, "Url": a["href"], "Title": a.string, "Article": article, "RelatedStockCodes": " ".join(related_stock_codes_list) } self.db_obj.insert_data( self.db_name, self.col_name, data) logging.info("[SUCCESS] {} {} {}".format( date, a.string, a["href"])) else: is_stop = True break if not is_stop: page_start_id += 1 def get_realtime_news(self, interval=60): page_url = "{}/1".format(config.WEBSITES_LIST_TO_BE_CRAWLED_NBD) logging.info( "start real-time crawling of URL -> {}, request every {} secs ... 
" .format(page_url, interval)) name_code_df = self.db_obj.get_data( config.STOCK_DATABASE_NAME, config.COLLECTION_NAME_STOCK_BASIC_INFO, keys=["name", "code"]) name_code_dict = dict(name_code_df.values) crawled_urls = [] date_list = self.db_obj.get_data(self.db_name, self.col_name, keys=["Date"])["Date"].to_list() latest_date = max(date_list) while True: # 每隔一定时间轮询该网址 if len(crawled_urls) > 100: # 防止list过长,内存消耗大,维持list在100条 crawled_urls.pop(0) bs = utils.html_parser(page_url) a_list = bs.find_all("a") for a in a_list: if "click-statistic" in a.attrs and a.string \ and a["click-statistic"].find("Article_") != -1 \ and a["href"].find("http://www.nbd.com.cn/articles/") != -1: if a["href"] not in crawled_urls: result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_NBD_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format(a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config.NBD_MAX_REJECTED_AMOUNTS, config. RECORD_NBD_FAILED_URL_TXT_FILE_PATH)) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format( a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) if not result: # 爬取失败的情况 logging.info("[FAILED] {} {}".format( a.string, a["href"])) else: # 有返回但是article为null的情况 date, article = result if date > latest_date: while article == "" and self.is_article_prob >= .1: self.is_article_prob -= .1 result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.NBD_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_NBD_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format( a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config. NBD_MAX_REJECTED_AMOUNTS, config. RECORD_NBD_FAILED_URL_TXT_FILE_PATH )) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format( a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) date, article = result self.is_article_prob = .5 if article != "": related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article( article, name_code_dict) data = { "Date": date, # "PageId": page_url.split("/")[-1], "Url": a["href"], "Title": a.string, "Article": article, "RelatedStockCodes": " ".join(related_stock_codes_list) } # self.col.insert_one(data) self.db_obj.insert_data( self.db_name, self.col_name, data) crawled_urls.append(a["href"]) logging.info("[SUCCESS] {} {} {}".format( date, a.string, a["href"])) # logging.info("sleep {} secs then request again ... ".format(interval)) time.sleep(interval)
class StockInfoSpyder(Spyder): def __init__(self, database_name, collection_name): super(StockInfoSpyder, self).__init__() self.db_obj = Database() self.col_basic_info = self.db_obj.get_collection( database_name, collection_name) self.database_name = database_name self.collection_name = collection_name self.start_program_date = datetime.datetime.now().strftime("%Y%m%d") self.redis_client = redis.StrictRedis( host="localhost", port=6379, db=config.REDIS_CLIENT_FOR_CACHING_STOCK_INFO_DB_ID) self.redis_client.set("today_date", datetime.datetime.now().strftime("%Y-%m-%d")) def get_stock_code_info(self): # TODO:每半年需要更新一次 stock_info_df = ak.stock_info_a_code_name() # 获取所有A股code和name stock_symbol_code = ak.stock_zh_a_spot().get( ["symbol", "code"]) # 获取A股所有股票的symbol和code for _id in range(stock_info_df.shape[0]): _symbol = stock_symbol_code[stock_symbol_code.code == stock_info_df .iloc[_id].code].symbol.values if len(_symbol) != 0: _dict = {"symbol": _symbol[0]} _dict.update(stock_info_df.iloc[_id].to_dict()) self.col_basic_info.insert_one(_dict) def get_historical_news(self, start_date=None, end_date=None, freq="day"): if end_date is None: end_date = datetime.datetime.now().strftime("%Y%m%d") stock_symbol_list = self.col_basic_info.distinct("symbol") if len(stock_symbol_list) == 0: self.get_stock_code_info() stock_symbol_list = self.col_basic_info.distinct("symbol") if freq == "day": start_stock_code = 0 if self.redis_client.get( "start_stock_code") is None else int( self.redis_client.get("start_stock_code").decode()) for symbol in stock_symbol_list: if int(symbol[2:]) > start_stock_code: if start_date is None: # 如果该symbol有历史数据,如果有则从API获取从数据库中最近的时间开始直到现在的所有价格数据 # 如果该symbol无历史数据,则从API获取从2015年1月1日开始直到现在的所有价格数据 _latest_date = self.redis_client.get(symbol) if _latest_date is None: symbol_start_date = config.STOCK_PRICE_REQUEST_DEFAULT_DATE else: tmp_date_dt = datetime.datetime.strptime( _latest_date.decode(), "%Y-%m-%d").date() offset = datetime.timedelta(days=1) symbol_start_date = (tmp_date_dt + offset).strftime('%Y%m%d') if symbol_start_date < end_date: stock_zh_a_daily_hfq_df = ak.stock_zh_a_daily( symbol=symbol, start_date=symbol_start_date, end_date=end_date, adjust="qfq") stock_zh_a_daily_hfq_df.insert( 0, 'date', stock_zh_a_daily_hfq_df.index.tolist()) stock_zh_a_daily_hfq_df.index = range( len(stock_zh_a_daily_hfq_df)) _col = self.db_obj.get_collection( self.database_name, symbol) for _id in range(stock_zh_a_daily_hfq_df.shape[0]): _tmp_dict = stock_zh_a_daily_hfq_df.iloc[ _id].to_dict() _tmp_dict.pop("outstanding_share") _tmp_dict.pop("turnover") _col.insert_one(_tmp_dict) self.redis_client.set( symbol, str(_tmp_dict["date"]).split(" ")[0]) logging.info( "{} finished saving from {} to {} ... ".format( symbol, symbol_start_date, end_date)) self.redis_client.set("start_stock_code", int(symbol[2:])) self.redis_client.set("start_stock_code", 0) elif freq == "week": pass elif freq == "month": pass elif freq == "5mins": pass elif freq == "15mins": pass elif freq == "30mins": pass elif freq == "60mins": pass def get_realtime_news(self, freq="day"): while True: if_updated = input( "Has the stock price dataset been updated today? 
(Y/N) \n") if if_updated == "Y": self.redis_client.set("is_today_updated", "1") break elif if_updated == "N": self.redis_client.set("is_today_updated", "") break self.get_historical_news() # 对所有股票补充数据到最新 while True: if freq == "day": time_now = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S") if time_now.split(" ")[0] != self.redis_client.get( "today_date").decode(): self.redis_client.set("today_date", time_now.split(" ")[0]) self.redis_client.set("is_today_updated", "") # 过了凌晨,该参数设置回空值,表示今天未进行数据更新 if not bool( self.redis_client.get("is_today_updated").decode()): update_time = "{} {}".format( time_now.split(" ")[0], "15:30:00") if time_now >= update_time: stock_zh_a_spot_df = ak.stock_zh_a_spot() # 当天的日数据行情下载 for _id, sym in enumerate( stock_zh_a_spot_df["symbol"]): _col = self.db_obj.get_collection( self.database_name, sym) _tmp_dict = {} _tmp_dict.update({ "date": Timestamp("{} 00:00:00".format( time_now.split(" ")[0])) }) _tmp_dict.update( {"open": stock_zh_a_spot_df.iloc[_id].open}) _tmp_dict.update( {"high": stock_zh_a_spot_df.iloc[_id].high}) _tmp_dict.update( {"low": stock_zh_a_spot_df.iloc[_id].low}) _tmp_dict.update( {"close": stock_zh_a_spot_df.iloc[_id].trade}) _tmp_dict.update({ "volume": stock_zh_a_spot_df.iloc[_id].volume }) _col.insert_one(_tmp_dict) self.redis_client.set(sym, time_now.split(" ")[0]) logging.info( "finished updating {} price data of {} ... ". format(sym, time_now.split(" ")[0])) self.redis_client.set("is_today_updated", "1")
class GenStockNewsDB(object): def __init__(self): self.database = Database() # 获取从1990-12-19至2020-12-31股票交易日数据 self.trade_date = ak.tool_trade_date_hist_sina()["trade_date"].tolist() self.label_range = { 3: "3DaysLabel", 5: "5DaysLabel", 10: "10DaysLabel", 15: "15DaysLabel", 30: "30DaysLabel", 60: "60DaysLabel" } def get_all_news_about_specific_stock(self, database_name, collection_name): # 获取collection_name的key值,看是否包含RelatedStockCodes,如果没有说明,没有做将新闻中所涉及的 # 股票代码保存在新的一列 _keys_list = list( next( self.database.get_collection(database_name, collection_name).find()).keys()) if "RelatedStockCodes" not in _keys_list: tokenization = Tokenization(import_module="jieba", user_dict="./Leorio/financedict.txt") tokenization.update_news_database_rows(database_name, collection_name) # 创建stock_code为名称的collection stock_symbol_list = self.database.get_data( "stock", "basic_info", keys=["symbol"])["symbol"].to_list() for symbol in stock_symbol_list: _collection = self.database.get_collection( config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, symbol) _tmp_num_stat = 0 for row in self.database.get_collection( database_name, collection_name).find(): # 迭代器 if symbol[2:] in row["RelatedStockCodes"].split(" "): # 返回新闻发布后n天的标签 _tmp_dict = {} for label_days, key_name in self.label_range.items(): _tmp_res = self._label_news( datetime.datetime.strptime( row["Date"].split(" ")[0], "%Y-%m-%d"), symbol, label_days) if _tmp_res: _tmp_dict.update({key_name: _tmp_res}) _data = { "Date": row["Date"], "Url": row["Url"], "Title": row["Title"], "Article": row["Article"], "OriDB": database_name, "OriCOL": collection_name } _data.update(_tmp_dict) _collection.insert_one(_data) _tmp_num_stat += 1 logging.info( "there are {} news mentioned {} in {} collection ... ".format( _tmp_num_stat, symbol, collection_name)) def _label_news(self, date, symbol, n_days): """ :param date: 类型datetime.datetime,表示新闻发布的日期,只包括年月日,不包括具体时刻,如datetime.datetime(2015, 1, 5, 0, 0) :param symbol: 类型str,表示股票标的,如sh600000 :param n_days: 类型int,表示根据多少天后的价格设定标签,如新闻发布后n_days天,如果收盘价格上涨,则认为该则新闻是利好消息 """ # 计算新闻发布当天经过n_days天后的具体年月日 this_date_data = self.database.get_data(config.STOCK_DATABASE_NAME, symbol, query={"date": date}) # 考虑情况:新闻发布日期是非交易日,因此该日期没有价格数据,则往前寻找,比如新闻发布日期是2020-12-12是星期六, # 则考虑2020-12-11日的收盘价作为该新闻发布时的数据 tmp_date = date if this_date_data is None: i = 1 while this_date_data is None and i <= 10: tmp_date -= datetime.timedelta(days=i) # 判断日期是否是交易日,如果是再去查询数据库;如果this_date_data还是NULL值,则说明数据库没有该交易日数据 if tmp_date.strftime("%Y-%m-%d") in self.trade_date: this_date_data = self.database.get_data( config.STOCK_DATABASE_NAME, symbol, query={"date": tmp_date}) i += 1 try: close_price_this_date = this_date_data["close"][0] except Exception: close_price_this_date = None # 考虑情况:新闻发布后n_days天是非交易日,或者没有采集到数据,因此向后寻找,如新闻发布日期是2020-12-08,5天 # 后的日期是2020-12-13是周日,因此将2020-12-14日周一的收盘价作为n_days后的数据 new_date = date + datetime.timedelta(days=n_days) n_days_later_data = self.database.get_data(config.STOCK_DATABASE_NAME, symbol, query={"date": new_date}) if n_days_later_data is None: i = 1 while n_days_later_data is None and i <= 10: new_date = date + datetime.timedelta(days=n_days + i) if new_date.strftime("%Y-%m-%d") in self.trade_date: n_days_later_data = self.database.get_data( config.STOCK_DATABASE_NAME, symbol, query={"date": new_date}) i += 1 try: close_price_n_days_later = n_days_later_data["close"][0] except Exception: close_price_n_days_later = None # 判断条件: # (1)如果n_days个交易日后且n_days<=10天,则价格上涨(下跌)超过3%,则认为该新闻是利好(利空)消息;如果价格在3%的范围内,则为中性消息 # 
(2)如果n_days个交易日后且10<n_days<=15天,则价格上涨(下跌)超过5%,则认为该新闻是利好(利空)消息;如果价格在5%的范围内,则为中性消息 # (3)如果n_days个交易日后且15<n_days<=30天,则价格上涨(下跌)超过10%,则认为该新闻是利好(利空)消息;如果价格在10%的范围内,则为中性消息 # (4)如果n_days个交易日后且30<n_days<=60天,则价格上涨(下跌)超过15%,则认为该新闻是利好(利空)消息;如果价格在15%的范围内,则为中性消息 # Note:中性消息定义为,该消息迅速被市场消化,并没有持续性影响 param = 0.01 if n_days <= 10: param = 0.03 elif 10 < n_days <= 15: param = 0.05 elif 15 < n_days <= 30: param = 0.10 elif 30 < n_days <= 60: param = 0.15 if close_price_this_date is not None and close_price_n_days_later is not None: if (close_price_n_days_later - close_price_this_date) / close_price_this_date > param: return "利好" elif (close_price_n_days_later - close_price_this_date) / close_price_this_date < -param: return "利空" else: return "中性" else: return False
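# A standalone sketch of the threshold rule used by _label_news above: the band widens
# with the horizon (3% up to 10 days, 5% up to 15, 10% up to 30, 15% up to 60).
def label_by_return(close_now, close_later, n_days):
    param = 0.03 if n_days <= 10 else 0.05 if n_days <= 15 else 0.10 if n_days <= 30 else 0.15
    change = (close_later - close_now) / close_now
    if change > param:
        return "利好"   # positive news
    if change < -param:
        return "利空"   # negative news
    return "中性"       # neutral: absorbed by the market with no lasting effect

# e.g. a 4% rise over 5 trading days exceeds the 3% band, so the news is labelled 利好:
# label_by_return(10.0, 10.4, 5) -> "利好"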
class CnStockSpyder(Spyder): def __init__(self, database_name, collection_name): super(CnStockSpyder, self).__init__() self.db_obj = Database() self.col = self.db_obj.conn[database_name].get_collection( collection_name) self.terminated_amount = 0 self.db_name = database_name self.col_name = collection_name self.tokenization = Tokenization( import_module="jieba", user_dict=config.USER_DEFINED_DICT_PATH) self.redis_client = redis.StrictRedis(host=config.REDIS_IP, port=config.REDIS_PORT, db=config.CACHE_NEWS_REDIS_DB_ID) def get_url_info(self, url): try: bs = utils.html_parser(url) except Exception: return False span_list = bs.find_all("span") part = bs.find_all("p") article = "" date = "" for span in span_list: if "class" in span.attrs and span["class"] == ["timer"]: date = span.text break for paragraph in part: chn_status = utils.count_chn(str(paragraph)) possible = chn_status[1] if possible > self.is_article_prob: article += str(paragraph) while article.find("<") != -1 and article.find(">") != -1: string = article[article.find("<"):article.find(">") + 1] article = article.replace(string, "") while article.find("\u3000") != -1: article = article.replace("\u3000", "") article = " ".join(re.split(" +|\n+", article)).strip() return [date, article] def get_historical_news(self, url, category_chn=None, start_date=None): """ :param url: 爬虫网页 :param category_chn: 所属类别, 中文字符串, 包括'公司聚焦', '公告解读', '公告快讯', '利好公告' :param start_date: 数据库中category_chn类别新闻最近一条数据的时间 """ assert category_chn is not None driver = webdriver.Chrome(executable_path=config.CHROME_DRIVER) btn_more_text = "" crawled_urls_list = self.extract_data(["Url"])[0] logging.info("historical data length -> {} ... ".format( len(crawled_urls_list))) # crawled_urls_list = [] driver.get(url) name_code_df = self.db_obj.get_data( config.STOCK_DATABASE_NAME, config.COLLECTION_NAME_STOCK_BASIC_INFO, keys=["name", "code"]) name_code_dict = dict(name_code_df.values) if start_date is None: while btn_more_text != "没有更多": more_btn = driver.find_element_by_id('j_more_btn') btn_more_text = more_btn.text logging.info("1-{}".format(more_btn.text)) if btn_more_text == "加载更多": more_btn.click() time.sleep(random.random()) # sleep random time less 1s elif btn_more_text == "加载中...": time.sleep(random.random() + 2) more_btn = driver.find_element_by_id('j_more_btn') btn_more_text = more_btn.text logging.info("2-{}".format(more_btn.text)) if btn_more_text == "加载更多": more_btn.click() else: more_btn.click() break bs = BeautifulSoup(driver.page_source, "html.parser") for li in bs.find_all("li", attrs={"class": ["newslist"]}): a = li.find_all("h2")[0].find("a") if a["href"] not in crawled_urls_list: result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format(a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config.CNSTOCK_MAX_REJECTED_AMOUNTS, config. 
RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH)) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format(a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) if not result: # 爬取失败的情况 logging.info("[FAILED] {} {}".format( a["title"], a["href"])) else: # 有返回但是article为null的情况 date, article = result while article == "" and self.is_article_prob >= .1: self.is_article_prob -= .1 result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format(a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config. CNSTOCK_MAX_REJECTED_AMOUNTS, config. RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH )) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format( a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) date, article = result self.is_article_prob = .5 if article != "": related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article( article, name_code_dict) data = { "Date": date, "Category": category_chn, "Url": a["href"], "Title": a["title"], "Article": article, "RelatedStockCodes": " ".join(related_stock_codes_list) } # self.col.insert_one(data) self.db_obj.insert_data(self.db_name, self.col_name, data) logging.info("[SUCCESS] {} {} {}".format( date, a["title"], a["href"])) else: # 当start_date不为None时,补充历史数据 is_click_button = True start_get_url_info = False tmp_a = None while is_click_button: bs = BeautifulSoup(driver.page_source, "html.parser") for li in bs.find_all("li", attrs={"class": ["newslist"]}): a = li.find_all("h2")[0].find("a") if tmp_a is not None and a["href"] != tmp_a: continue elif tmp_a is not None and a["href"] == tmp_a: start_get_url_info = True if start_get_url_info: date, _ = self.get_url_info(a["href"]) if date <= start_date: is_click_button = False break tmp_a = a["href"] if is_click_button: more_btn = driver.find_element_by_id('j_more_btn') more_btn.click() # 从一开始那条新闻到tmp_a都是新增新闻,不包括tmp_a bs = BeautifulSoup(driver.page_source, "html.parser") for li in bs.find_all("li", attrs={"class": ["newslist"]}): a = li.find_all("h2")[0].find("a") if a["href"] != tmp_a: result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format(a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config.CNSTOCK_MAX_REJECTED_AMOUNTS, config. 
RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH)) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format(a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) if not result: # 爬取失败的情况 logging.info("[FAILED] {} {}".format( a["title"], a["href"])) else: # 有返回但是article为null的情况 date, article = result while article == "" and self.is_article_prob >= .1: self.is_article_prob -= .1 result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format(a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config. CNSTOCK_MAX_REJECTED_AMOUNTS, config. RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH )) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format( a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) date, article = result self.is_article_prob = .5 if article != "": related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article( article, name_code_dict) data = { "Date": date, "Category": category_chn, "Url": a["href"], "Title": a["title"], "Article": article, "RelatedStockCodes": " ".join(related_stock_codes_list) } # self.col.insert_one(data) self.db_obj.insert_data(self.db_name, self.col_name, data) logging.info("[SUCCESS] {} {} {}".format( date, a["title"], a["href"])) else: break driver.quit() def get_realtime_news(self, url, category_chn=None, interval=60): logging.info( "start real-time crawling of URL -> {}, request every {} secs ... " .format(url, interval)) assert category_chn is not None # TODO: 由于cnstock爬取的数据量并不大,这里暂时是抽取历史所有数据进行去重,之后会修改去重策略 name_code_df = self.db_obj.get_data( config.STOCK_DATABASE_NAME, config.COLLECTION_NAME_STOCK_BASIC_INFO, keys=["name", "code"]) name_code_dict = dict(name_code_df.values) crawled_urls = self.db_obj.get_data(self.db_name, self.col_name, keys=["Url"])["Url"].to_list() while True: # 每隔一定时间轮询该网址 bs = utils.html_parser(url) for li in bs.find_all("li", attrs={"class": ["newslist"]}): a = li.find_all("h2")[0].find("a") if a["href"] not in crawled_urls: # latest_3_days_crawled_href result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format(a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config.CNSTOCK_MAX_REJECTED_AMOUNTS, config. 
RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH)) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format(a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) if not result: # 爬取失败的情况 logging.info("[FAILED] {} {}".format( a["title"], a["href"])) else: # 有返回但是article为null的情况 date, article = result while article == "" and self.is_article_prob >= .1: self.is_article_prob -= .1 result = self.get_url_info(a["href"]) while not result: self.terminated_amount += 1 if self.terminated_amount > config.CNSTOCK_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format(a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config. CNSTOCK_MAX_REJECTED_AMOUNTS, config. RECORD_CNSTOCK_FAILED_URL_TXT_FILE_PATH )) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format( a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info(a["href"]) date, article = result self.is_article_prob = .5 if article != "": related_stock_codes_list = self.tokenization.find_relevant_stock_codes_in_article( article, name_code_dict) self.db_obj.insert_data( self.db_name, self.col_name, { "Date": date, "Category": category_chn, "Url": a["href"], "Title": a["title"], "Article": article, "RelatedStockCodes": " ".join(related_stock_codes_list) }) self.redis_client.lpush( config.CACHE_NEWS_LIST_NAME, json.dumps({ "Date": date, "Category": category_chn, "Url": a["href"], "Title": a["title"], "Article": article, "RelatedStockCodes": " ".join(related_stock_codes_list), "OriDB": config.DATABASE_NAME, "OriCOL": config.COLLECTION_NAME_CNSTOCK })) logging.info("[SUCCESS] {} {} {}".format( date, a["title"], a["href"])) crawled_urls.append(a["href"]) # logging.info("sleep {} secs then request {} again ... ".format(interval, url)) time.sleep(interval)
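# Assumed invocation of CnStockSpyder above; it needs a working chromedriver at
# config.CHROME_DRIVER. category_url is a placeholder for one of the cnstock column pages,
# and category_chn one of the labels named in the docstring
# ('公司聚焦', '公告解读', '公告快讯', '利好公告'):
category_url = "..."  # cnstock column page to crawl (placeholder)
cnstock_spyder = CnStockSpyder("finnews", "cnstock")
cnstock_spyder.get_historical_news(category_url, category_chn="公司聚焦")
cnstock_spyder.get_realtime_news(category_url, category_chn="公司聚焦", interval=60)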
class JrjSpyder(Spyder): def __init__(self, database_name, collection_name): super(JrjSpyder, self).__init__() self.db_obj = Database() self.col = self.db_obj.conn[database_name].get_collection( collection_name) self.terminated_amount = 0 self.db_name = database_name self.col_name = collection_name def get_url_info(self, url, specific_date): try: bs = utils.html_parser(url) except Exception: return False date = "" for span in bs.find_all("span"): if span.contents[0] == "jrj_final_date_start": date = span.text.replace("\r", "").replace("\n", "") break if date == "": date = specific_date article = "" for p in bs.find_all("p"): if not p.find_all("jrj_final_daohang_start") and p.attrs == {} and \ not p.find_all("input") and not p.find_all("a", attrs={"class": "red"}) and not p.find_all("i") and not p.find_all("span"): # if p.contents[0] != "jrj_final_daohang_start1" and p.attrs == {} and \ # not p.find_all("input") and not p.find_all("a", attrs={"class": "red"}) and not p.find_all("i"): article += p.text.replace("\r", "").replace("\n", "").replace( "\u3000", "") return [date, article] def get_historical_news(self, url, start_date, end_date): # # 抽取数据库中已爬取的从start_date到latest_date_str所有新闻,避免重复爬取 # # 比如数据断断续续爬到了2016-10-10 15:00:00时间节点,但是在没调整参数的情 # # 况下,从2015-01-01(自己设定)开始重跑程序会导致大量重复数据,因此在这里稍 # # 作去重。直接从最新的时间节点开始跑是完全没问题,但从2015-01-01(自己设定) # # 开始重跑程序可以尝试将前面未成功爬取的URL重新再试一遍 # extracted_data_list = self.extract_data(["Date"])[0] # if len(extracted_data_list) != 0: # latest_date_str = max(extracted_data_list).split(" ")[0] # else: # latest_date_str = start_date # logging.info("latest time in database is {} ... ".format(latest_date_str)) # crawled_urls_list = list() # for _date in utils.get_date_list_from_range(start_date, latest_date_str): # query_results = self.query_news("Date", _date) # for qr in query_results: # crawled_urls_list.append(qr["Url"]) # # crawled_urls_list = self.extract_data(["Url"])[0] # abandoned # logging.info("the length of crawled data from {} to {} is {} ... ".format(start_date, # latest_date_str, # len(crawled_urls_list))) crawled_urls_list = [] dates_list = utils.get_date_list_from_range(start_date, end_date) dates_separated_into_ranges_list = utils.gen_dates_list( dates_list, config.JRJ_DATE_RANGE) for dates_range in dates_separated_into_ranges_list: for date in dates_range: first_url = "{}/{}/{}_1.shtml".format( url, date.replace("-", "")[0:6], date.replace("-", "")) max_pages_num = utils.search_max_pages_num(first_url, date) for num in range(1, max_pages_num + 1): _url = "{}/{}/{}_{}.shtml".format( url, date.replace("-", "")[0:6], date.replace("-", ""), str(num)) bs = utils.html_parser(_url) a_list = bs.find_all("a") for a in a_list: if "href" in a.attrs and a.string and \ a["href"].find("/{}/{}/".format(date.replace("-", "")[:4], date.replace("-", "")[4:6])) != -1: if a["href"] not in crawled_urls_list: # 如果标题不包含"收盘","报于"等字样,即可写入数据库,因为包含这些字样标题的新闻多为机器自动生成 if a.string.find("收盘") == -1 and a.string.find("报于") == -1 and \ a.string.find("新三板挂牌上市") == -1: result = self.get_url_info(a["href"], date) while not result: self.terminated_amount += 1 if self.terminated_amount > config.JRJ_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_JRJ_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write("{}\n".format( a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config. JRJ_MAX_REJECTED_AMOUNTS, config. 
RECORD_JRJ_FAILED_URL_TXT_FILE_PATH )) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format( a["href"], 60 * self.terminated_amount)) time.sleep(60 * self.terminated_amount) result = self.get_url_info( a["href"], date) if not result: # 爬取失败的情况 logging.info("[FAILED] {} {}".format( a.string, a["href"])) else: # 有返回但是article为null的情况 article_specific_date, article = result while article == "" and self.is_article_prob >= .1: self.is_article_prob -= .1 result = self.get_url_info( a["href"], date) while not result: self.terminated_amount += 1 if self.terminated_amount > config.JRJ_MAX_REJECTED_AMOUNTS: # 始终无法爬取的URL保存起来 with open( config. RECORD_JRJ_FAILED_URL_TXT_FILE_PATH, "a+") as file: file.write( "{}\n".format( a["href"])) logging.info( "rejected by remote server longer than {} minutes, " "and the failed url has been written in path {}" .format( config. JRJ_MAX_REJECTED_AMOUNTS, config. RECORD_JRJ_FAILED_URL_TXT_FILE_PATH )) break logging.info( "rejected by remote server, request {} again after " "{} seconds...".format( a["href"], 60 * self.terminated_amount) ) time.sleep( 60 * self.terminated_amount) result = self.get_url_info( a["href"], date) article_specific_date, article = result self.is_article_prob = .5 if article != "": data = { "Date": article_specific_date, "Url": a["href"], "Title": a.string, "Article": article } # self.col.insert_one(data) self.db_obj.insert_data( self.db_name, self.col_name, data) logging.info( "[SUCCESS] {} {} {}".format( article_specific_date, a.string, a["href"])) self.terminated_amount = 0 # 爬取结束后重置该参数 else: logging.info("[QUIT] {}".format(a.string)) def get_realtime_news(self, url): pass
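# Assumed invocation of JrjSpyder above; the listing URL is left as a placeholder and the
# dates follow the YYYY-MM-DD format that get_date_list_from_range appears to expect:
jrj_listing_url = "..."  # jrj stock-news listing root (placeholder)
jrj_spyder = JrjSpyder("finnews", "jrj")
jrj_spyder.get_historical_news(jrj_listing_url, start_date="2015-01-01", end_date="2020-12-31")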
class GenStockNewsDB(object):

    def __init__(self):
        self.database = Database()
        # trading-day calendar from 1990-12-19 to 2020-12-31
        self.trade_date = ak.tool_trade_date_hist_sina()["trade_date"].tolist()
        self.label_range = {3: "3DaysLabel",
                            5: "5DaysLabel",
                            10: "10DaysLabel",
                            15: "15DaysLabel",
                            30: "30DaysLabel",
                            60: "60DaysLabel"}
        self.redis_client = redis.StrictRedis(host=config.REDIS_IP,
                                              port=config.REDIS_PORT,
                                              db=config.CACHE_NEWS_REDIS_DB_ID)
        self.redis_client.set("today_date", datetime.datetime.now().strftime("%Y-%m-%d"))
        self.redis_client.delete("stock_news_num_over_{}".format(config.MINIMUM_STOCK_NEWS_NUM_FOR_ML))
        self._stock_news_nums_stat()

    def get_all_news_about_specific_stock(self, database_name, collection_name):
        # Check the keys of the collection: if "RelatedStockCodes" is missing, the stock codes mentioned
        # in each news item have not yet been extracted into a separate field, so extract them first.
        _keys_list = list(next(self.database.get_collection(database_name, collection_name).find()).keys())
        if "RelatedStockCodes" not in _keys_list:
            tokenization = Tokenization(import_module="jieba", user_dict="./Leorio/financedict.txt")
            tokenization.update_news_database_rows(database_name, collection_name)
        # create one collection per stock code
        stock_symbol_list = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                   config.COLLECTION_NAME_STOCK_BASIC_INFO,
                                                   keys=["symbol"])["symbol"].to_list()
        col_names = self.database.connect_database(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE).list_collection_names(session=None)
        for symbol in stock_symbol_list:
            if symbol not in col_names:
                # if int(symbol[2:]) > 837:
                _collection = self.database.get_collection(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, symbol)
                _tmp_num_stat = 0
                for row in self.database.get_collection(database_name, collection_name).find():  # cursor
                    if symbol[2:] in row["RelatedStockCodes"].split(" "):
                        # label the news n days after its publication date
                        _tmp_dict = {}
                        for label_days, key_name in self.label_range.items():
                            _tmp_res = self._label_news(
                                datetime.datetime.strptime(row["Date"].split(" ")[0], "%Y-%m-%d"),
                                symbol,
                                label_days)
                            _tmp_dict.update({key_name: _tmp_res})
                        _data = {"Date": row["Date"],
                                 "Url": row["Url"],
                                 "Title": row["Title"],
                                 "Article": row["Article"],
                                 "OriDB": database_name,
                                 "OriCOL": collection_name}
                        _data.update(_tmp_dict)
                        _collection.insert_one(_data)
                        _tmp_num_stat += 1
                logging.info("there are {} news mentioned {} in {} collection need to be fetched ... "
                             .format(_tmp_num_stat, symbol, collection_name))
                # else:
                #     logging.info("{} has fetched all related news from {}...".format(symbol, collection_name))

    def listen_redis_queue(self):
        # Listen to the Redis message queue: whenever a new real-time item arrives, route the news to the
        # collection of every stock mentioned in its "RelatedStockCodes" field.
        # e.g. if that field of a cached item is "603386 603003 600111 603568", the news is stored in the
        # collections of all four stocks.
        crawled_url_today = set()
        while True:
            date_now = datetime.datetime.now().strftime("%Y-%m-%d")
            if date_now != self.redis_client.get("today_date").decode():
                crawled_url_today = set()
                self.redis_client.set("today_date", date_now)
            if self.redis_client.llen(config.CACHE_NEWS_LIST_NAME) != 0:
                data = json.loads(self.redis_client.lindex(config.CACHE_NEWS_LIST_NAME, -1))
                if data["Url"] not in crawled_url_today:  # skip urls already handled today
                    crawled_url_today.update({data["Url"]})
                    if data["RelatedStockCodes"] != "":
                        for stock_code in data["RelatedStockCodes"].split(" "):
                            # route the news into the collection of each related stock
                            symbol = "sh{}".format(stock_code) if stock_code[0] == "6" else "sz{}".format(stock_code)
                            _collection = self.database.get_collection(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, symbol)
                            _tmp_dict = {}
                            for label_days, key_name in self.label_range.items():
                                _tmp_res = self._label_news(
                                    datetime.datetime.strptime(data["Date"].split(" ")[0], "%Y-%m-%d"),
                                    symbol,
                                    label_days)
                                _tmp_dict.update({key_name: _tmp_res})
                            _data = {"Date": data["Date"],
                                     "Url": data["Url"],
                                     "Title": data["Title"],
                                     "Article": data["Article"],
                                     "OriDB": data["OriDB"],
                                     "OriCOL": data["OriCOL"]}
                            _data.update(_tmp_dict)
                            _collection.insert_one(_data)
                            logging.info("the real-time fetched news {}, which was saved in [DB:{} - COL:{}] ..."
                                         .format(data["Title"], config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, symbol))
                            # if symbol.encode() in self.redis_client.lrange("stock_news_num_over_{}".format(config.MINIMUM_STOCK_NEWS_NUM_FOR_ML), 0, -1):
                            #     label_name = "3DaysLabel"
                            #     # classifier_save_path = "{}_classifier.pkl".format(symbol)
                            #     ori_dict_path = "{}_docs_dict.dict".format(symbol)
                            #     bowvec_save_path = "{}_bowvec.mm".format(symbol)
                            #     topicmodelling = TopicModelling()
                            #     chn_label = topicmodelling.classify_stock_news(data["Article"],
                            #                                                    config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE,
                            #                                                    symbol,
                            #                                                    label_name=label_name,
                            #                                                    topic_model_type="lsi",
                            #                                                    classifier_model="rdforest",  # rdforest / svm
                            #                                                    ori_dict_path=ori_dict_path,
                            #                                                    bowvec_save_path=bowvec_save_path)
                            #     logging.info(
                            #         "document '{}...' was classified with label '{}' for symbol {} ... ".format(
                            #             data["Article"][:20], chn_label, symbol))
                self.redis_client.rpop(config.CACHE_NEWS_LIST_NAME)
                logging.info("now pop {} from redis queue of [DB:{} - KEY:{}] ... "
                             .format(data["Title"], config.CACHE_NEWS_REDIS_DB_ID, config.CACHE_NEWS_LIST_NAME))

    def _label_news(self, date, symbol, n_days):
        """
        :param date: datetime.datetime, publication date of the news (date only, no time of day),
                     e.g. datetime.datetime(2015, 1, 5, 0, 0)
        :param symbol: str, stock symbol, e.g. sh600000
        :param n_days: int, labelling horizon in days; if the close price has risen n_days after
                       publication, the news is regarded as positive
        """
        # close price on the publication date
        this_date_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                symbol,
                                                query={"date": date})
        # Case: the publication date is not a trading day, so there is no price for it; search backwards,
        # e.g. if the news was published on Saturday 2020-12-12, take the close of 2020-12-11 as the price
        # at publication time.
        tmp_date = date
        if this_date_data is None:
            i = 1
            while this_date_data is None and i <= 10:
                tmp_date = date - datetime.timedelta(days=i)  # step back one day at a time
                # query the database only if the date is a trading day; if this_date_data is still None
                # afterwards, the database has no data for that trading day
                if tmp_date.strftime("%Y-%m-%d") in self.trade_date:
                    this_date_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                            symbol,
                                                            query={"date": tmp_date})
                i += 1
        try:
            close_price_this_date = this_date_data["close"][0]
        except Exception:
            close_price_this_date = None
        # Case: the date n_days after publication is not a trading day, or no data was collected for it,
        # so search forwards, e.g. for news published on 2020-12-08, 5 days later is Sunday 2020-12-13,
        # so take the close of Monday 2020-12-14 as the price n_days later.
        new_date = date + datetime.timedelta(days=n_days)
        n_days_later_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                   symbol,
                                                   query={"date": new_date})
        if n_days_later_data is None:
            i = 1
            while n_days_later_data is None and i <= 10:
                new_date = date + datetime.timedelta(days=n_days + i)
                if new_date.strftime("%Y-%m-%d") in self.trade_date:
                    n_days_later_data = self.database.get_data(config.STOCK_DATABASE_NAME,
                                                               symbol,
                                                               query={"date": new_date})
                i += 1
        try:
            close_price_n_days_later = n_days_later_data["close"][0]
        except Exception:
            close_price_n_days_later = None
        # Labelling rules:
        # (1) n_days <= 10:      a rise (fall) of more than 3% marks the news as positive (negative); within ±3% it is neutral
        # (2) 10 < n_days <= 15: threshold 5%
        # (3) 15 < n_days <= 30: threshold 10%
        # (4) 30 < n_days <= 60: threshold 15%
        # Note: "neutral" means the news is digested by the market quickly and has no lasting impact.
        param = 0.01
        if n_days <= 10:
            param = 0.03
        elif 10 < n_days <= 15:
            param = 0.05
        elif 15 < n_days <= 30:
            param = 0.10
        elif 30 < n_days <= 60:
            param = 0.15
        if close_price_this_date is not None and close_price_n_days_later is not None:
            if (close_price_n_days_later - close_price_this_date) / close_price_this_date > param:
                return "利好"  # positive
            elif (close_price_n_days_later - close_price_this_date) / close_price_this_date < -param:
                return "利空"  # negative
            else:
                return "中性"  # neutral
        else:
            return ""

    def _stock_news_nums_stat(self):
        cols_list = self.database.connect_database(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE).list_collection_names(session=None)
        for sym in cols_list:
            if self.database.get_collection(config.ALL_NEWS_OF_SPECIFIC_STOCK_DATABASE, sym).estimated_document_count() > config.MINIMUM_STOCK_NEWS_NUM_FOR_ML:
                self.redis_client.lpush("stock_news_num_over_{}".format(config.MINIMUM_STOCK_NEWS_NUM_FOR_ML), sym)
class TopicModelling(object):

    def __init__(self):
        self.tokenization = Tokenization(
            import_module="jieba",
            user_dict=config.USER_DEFINED_DICT_PATH,
            chn_stop_words_dir=config.CHN_STOP_WORDS_PATH)
        self.database = Database()
        self.classifier = Classifier()

    def create_dictionary(self, raw_documents_list, save_path=None, is_saved=False):
        """
        Associate every token in the corpus with a unique id, i.e. build the vocabulary.
        :param raw_documents_list: list of raw documents, each element is one text,
                                   e.g. ["洗尽铅华...", "风雨赶路人...", ...]
        :param save_path: path for persisting the corpora.Dictionary object
        :param is_saved: whether to persist the dictionary to save_path
        """
        documents_token_list = []
        for doc in raw_documents_list:
            documents_token_list.append(self.tokenization.cut_words(doc))
        _dict = corpora.Dictionary(documents_token_list)
        # tokens that occur only once in the whole corpus
        once_items = [_dict[tokenid] for tokenid, docfreq in _dict.dfs.items()
                      if docfreq == 1]
        # drop the once-only tokens from every document's token list
        for _id, token_list in enumerate(documents_token_list):
            documents_token_list[_id] = list(
                filter(lambda token: token not in once_items, token_list))
        # edge case: if every token of a document occurs only once, its token list becomes empty, so drop it
        documents_token_list = [token_list for token_list in documents_token_list
                                if len(token_list) != 0]
        # ids of the once-only tokens
        once_ids = [tokenid for tokenid, docfreq in _dict.dfs.items()
                    if docfreq == 1]
        # remove the once-only tokens from the dictionary
        _dict.filter_tokens(once_ids)
        # close the id gaps left by the removed tokens
        _dict.compactify()
        if is_saved and save_path:
            _dict.save(save_path)
            logging.info("new generated dictionary saved in path -> {} ...".format(save_path))
        return _dict, documents_token_list

    def renew_dictionary(self, old_dict_path, new_raw_documents_list, new_dict_path=None, is_saved=False):
        documents_token_list = []
        for doc in new_raw_documents_list:
            documents_token_list.append(self.tokenization.cut_words(doc))
        _dict = corpora.Dictionary.load(old_dict_path)
        _dict.add_documents(documents_token_list)
        if new_dict_path:
            old_dict_path = new_dict_path
        if is_saved:
            _dict.save(old_dict_path)
            logging.info("updated dictionary by another raw documents serialized in {} ... ".format(old_dict_path))
        return _dict, documents_token_list

    def create_bag_of_word_representation(self, raw_documents_list, old_dict_path=None,
                                          new_dict_path=None, bow_vector_save_path=None,
                                          is_saved_dict=False):
        if old_dict_path:
            # if an old dictionary exists, update it with tokens it has not seen before
            corpora_dictionary, documents_token_list = self.renew_dictionary(
                old_dict_path, raw_documents_list, new_dict_path=new_dict_path)
        else:
            # otherwise build a new dictionary from scratch
            start_time = time.time()
            corpora_dictionary, documents_token_list = self.create_dictionary(
                raw_documents_list, save_path=new_dict_path, is_saved=is_saved_dict)
            end_time = time.time()
            logging.info("there are {} mins spent to create a new dictionary ... ".format((end_time - start_time) / 60))
        # build the bag-of-words vector of each document with the (updated) dictionary
        start_time = time.time()
        bow_vector = [corpora_dictionary.doc2bow(doc_token)
                      for doc_token in documents_token_list]
        end_time = time.time()
        logging.info("there are {} mins spent to calculate bow-vector ... ".format((end_time - start_time) / 60))
        if bow_vector_save_path:
            corpora.MmCorpus.serialize(bow_vector_save_path, bow_vector)
        return documents_token_list, corpora_dictionary, bow_vector

    @staticmethod
    def transform_vectorized_corpus(corpora_dictionary, bow_vector, model_type="lda", model_save_path=None):
        # Use this function when no model has been persisted yet and the topic model is (re)trained from scratch.
        model_vector = None
        if model_type == "lsi":
            # LSI (Latent Semantic Indexing) maps the corpus from bag-of-words or (better) TF-IDF vectors
            # into a low-dimensional latent space; for real corpora a target dimensionality of 200-500 is
            # considered the "golden standard".
            model_tfidf = models.TfidfModel(bow_vector)
            # model_tfidf.save("model_tfidf.tfidf")
            tfidf_vector = model_tfidf[bow_vector]
            model = models.LsiModel(tfidf_vector,
                                    id2word=corpora_dictionary,
                                    num_topics=config.TOPIC_NUMBER)  # initialise the model
            model_vector = model[tfidf_vector]
            if model_save_path:
                model.save(model_save_path)
        elif model_type == "lda":
            model = models.LdaModel(bow_vector,
                                    id2word=corpora_dictionary,
                                    num_topics=config.TOPIC_NUMBER)  # initialise the model
            model_vector = model[bow_vector]
            if model_save_path:
                model.save(model_save_path)
        elif model_type == "tfidf":
            model = models.TfidfModel(bow_vector)  # initialise the model
            # model = models.TfidfModel.load("model_tfidf.tfidf")
            model_vector = model[bow_vector]  # transform the whole corpus
            if model_save_path:
                model.save(model_save_path)
        return model_vector

    def classify_stock_news(self, unseen_raw_document, database_name, collection_name,
                            label_name="60DaysLabel", topic_model_type="lda",
                            classifier_model="svm", ori_dict_path=None,
                            bowvec_save_path=None, is_saved_bow_vector=False):
        historical_raw_documents_list = []
        Y = []
        for row in self.database.get_collection(database_name, collection_name).find():
            if label_name in row.keys():
                if row[label_name] != "":
                    historical_raw_documents_list.append(row["Article"])
                    Y.append(row[label_name])
        logging.info("fetch symbol '{}' historical news with label '{}' from [DB:'{}' - COL:'{}'] ... "
                     .format(collection_name, label_name, database_name, collection_name))

        le = preprocessing.LabelEncoder()
        Y = le.fit_transform(Y)
        logging.info("encode historical label list by sklearn preprocessing for training ... ")
        label_name_list = le.classes_  # ['中性' '利好' '利空'] -> [0, 1, 2]

        # Build the dictionary and the bag-of-words vectors of the historical news; if a dictionary built
        # from the historical database already exists, load it into memory and update it later with the
        # tokens of the unseen news.
        if not os.path.exists(ori_dict_path):
            if not os.path.exists(bowvec_save_path):
                _, _, historical_bow_vec = self.create_bag_of_word_representation(
                    historical_raw_documents_list,
                    new_dict_path=ori_dict_path,
                    bow_vector_save_path=bowvec_save_path,
                    is_saved_dict=True)
                logging.info("create dictionary of historical news, and serialized in path -> {} ... ".format(ori_dict_path))
                logging.info("create bow-vector of historical news, and serialized in path -> {} ... ".format(bowvec_save_path))
            else:
                # NOTE: in this branch only the dictionary is rebuilt; the existing bow-vector file is
                # neither reloaded nor regenerated here.
                _, _, _ = self.create_bag_of_word_representation(
                    historical_raw_documents_list,
                    new_dict_path=ori_dict_path,
                    is_saved_dict=True)
                logging.info("create dictionary of historical news, and serialized in path -> {} ... ".format(ori_dict_path))
        else:
            if not os.path.exists(bowvec_save_path):
                _, _, historical_bow_vec = self.create_bag_of_word_representation(
                    historical_raw_documents_list,
                    new_dict_path=ori_dict_path,
                    bow_vector_save_path=bowvec_save_path,
                    is_saved_dict=True)
                logging.info("historical news dictionary existed, which saved in path -> {}, but not the historical bow-vector ... ".format(ori_dict_path))
            else:
                historical_bow_vec_mmcorpus = corpora.MmCorpus(bowvec_save_path)  # type -> <gensim.corpora.mmcorpus.MmCorpus>
                historical_bow_vec = []
                for _bow in historical_bow_vec_mmcorpus:
                    historical_bow_vec.append(_bow)
                logging.info("both historical news dictionary and bow-vector existed, load historical bow-vector to memory ... ")

        start_time = time.time()
        updated_dictionary_with_old_and_unseen_news, unseen_documents_token_list = self.renew_dictionary(
            ori_dict_path, [unseen_raw_document], is_saved=True)
        end_time = time.time()
        logging.info("renew dictionary with unseen news tokens, and serialized in path -> {}, "
                     "which took {} mins ... ".format(ori_dict_path, (end_time - start_time) / 60))

        unseen_bow_vector = [updated_dictionary_with_old_and_unseen_news.doc2bow(doc_token)
                             for doc_token in unseen_documents_token_list]
        updated_bow_vector_with_old_and_unseen_news = []
        updated_bow_vector_with_old_and_unseen_news.extend(historical_bow_vec)
        updated_bow_vector_with_old_and_unseen_news.extend(unseen_bow_vector)
        # updated_bow_vector_with_old_and_unseen_news is a plain list here; after the serialization below
        # it would be reloaded as gensim.corpora.mmcorpus.MmCorpus
        if is_saved_bow_vector and bowvec_save_path:
            # persist the updated bow-vectors, i.e. the historical news plus the unseen one
            corpora.MmCorpus.serialize(bowvec_save_path, updated_bow_vector_with_old_and_unseen_news)
        logging.info("combined bow vector(type -> 'list') generated by historical news with unseen bow "
                     "vector to create a new one ... ")

        if topic_model_type == "lsi":
            start_time = time.time()
            updated_tfidf_model_vector = self.transform_vectorized_corpus(
                updated_dictionary_with_old_and_unseen_news,
                updated_bow_vector_with_old_and_unseen_news,
                model_type="tfidf")  # type -> <gensim.interfaces.TransformedCorpus object>
            end_time = time.time()
            logging.info("regenerated TF-IDF model vector by updated dictionary and updated bow-vector, "
                         "which took {} mins ... ".format((end_time - start_time) / 60))

            start_time = time.time()
            model = models.LsiModel(updated_tfidf_model_vector,
                                    id2word=updated_dictionary_with_old_and_unseen_news,
                                    num_topics=config.TOPIC_NUMBER)  # initialise the model
            model_vector = model[updated_tfidf_model_vector]  # type -> <gensim.interfaces.TransformedCorpus object>
            end_time = time.time()
            logging.info("regenerated LSI model vector space by updated TF-IDF model vector space, "
                         "which took {} mins ... ".format((end_time - start_time) / 60))
        elif topic_model_type == "lda":
            start_time = time.time()
            model_vector = self.transform_vectorized_corpus(
                updated_dictionary_with_old_and_unseen_news,
                updated_bow_vector_with_old_and_unseen_news,
                model_type="lda")
            end_time = time.time()
            logging.info("regenerated LDA model vector space by updated dictionary and bow-vector, "
                         "which took {} mins ... ".format((end_time - start_time) / 60))

        # convert the topic-model vectors (gensim.interfaces.TransformedCorpus) into a dense numpy matrix
        start_time = time.time()
        latest_matrix = corpus2dense(model_vector, num_terms=model_vector.obj.num_terms).T
        end_time = time.time()
        logging.info("transform {} model vector space to numpy.ndarray, "
                     "which took {} mins ... ".format(topic_model_type.upper(), (end_time - start_time) / 60))

        # use the topic-model vectors (features) of the historical data to train the news classifier
        start_time = time.time()
        train_x, train_y, test_x, test_y = utils.generate_training_set(latest_matrix[:-1, :], Y)
        clf = self.classifier.train(train_x, train_y, test_x, test_y, model_type=classifier_model)
        end_time = time.time()
        logging.info("finished training by sklearn {} using latest {} model vector space, which took {} mins ... "
                     .format(classifier_model.upper(), topic_model_type.upper(), (end_time - start_time) / 60))

        label_id = clf.predict(latest_matrix[-1, :].reshape(1, -1))[0]
        return label_name_list[label_id]
def __init__(self, database_name, collection_name):
    self.database = Database()
    self.database_name = database_name
    self.collection_name = collection_name
    self.delete_num = 0
import logging
import threading

import redis

from Kite import config
from Kite.database import Database
from Killua.denull import DeNull
from Killua.deduplication import Deduplication
from Gon.cnstockspyder import CnStockSpyder

redis_client = redis.StrictRedis(
    config.REDIS_IP,
    port=config.REDIS_PORT,
    db=config.CACHE_RECORED_OPENED_PYTHON_PROGRAM_DB_ID)
redis_client.lpush(config.CACHE_RECORED_OPENED_PYTHON_PROGRAM_VAR, "realtime_starter_cnstock.py")

obj = Database()
df = obj.get_data(config.DATABASE_NAME,
                  config.COLLECTION_NAME_CNSTOCK,
                  keys=["Date", "Category"])

cnstock_spyder = CnStockSpyder(config.DATABASE_NAME, config.COLLECTION_NAME_CNSTOCK)

# Backfill history first: e.g. if news have been crawled up to 2020-12-01 but this real-time starter
# is launched on 2020-12-23, the gap from 2020-12-02 to 2020-12-23 is crawled automatically.
for url_to_be_crawled, type_chn in config.WEBSITES_LIST_TO_BE_CRAWLED_CNSTOCK.items():
    # latest stored date for this category
    latest_date_in_db = max(df[df.Category == type_chn]["Date"].to_list())
    cnstock_spyder.get_historical_news(url_to_be_crawled,
                                       category_chn=type_chn,
                                       start_date=latest_date_in_db)
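# --- Illustrative sketch (an addition, not part of the original script) -----------------------------
# Each starter script registers its own file name in the shared Redis list via lpush above; the same
# connection can be reused to list every program registered so far (purely informational):
for program_name in redis_client.lrange(config.CACHE_RECORED_OPENED_PYTHON_PROGRAM_VAR, 0, -1):
    logging.info("registered starter script: {}".format(program_name.decode()))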
def __init__(self):
    self.db_obj = Database()
    self.db = self.db_obj.create_db(config.DATABASE_NAME)
    self.is_article_prob = .5
def load_transform_model(self, model_path):
    if ".tfidf" in model_path:
        return models.TfidfModel.load(model_path)
    elif ".lsi" in model_path:
        return models.LsiModel.load(model_path)
    elif ".lda" in model_path:
        return models.LdaModel.load(model_path)


if __name__ == "__main__":
    from Hisoka.classifier import Classifier
    from Kite.database import Database
    from sklearn import preprocessing

    database = Database()
    topicmodelling = TopicModelling()

    raw_documents_list = []
    Y = []
    for row in database.get_collection("stocknews", "sz000001").find():
        if "30DaysLabel" in row.keys():
            raw_documents_list.append(row["Article"])
            Y.append(row["30DaysLabel"])
    le = preprocessing.LabelEncoder()
    Y = le.fit_transform(Y)

    _, corpora_dictionary, corpus = topicmodelling.create_bag_of_word_representation(raw_documents_list)
    model_vector = topicmodelling.transform_vectorized_corpus(corpora_dictionary, corpus, model_type="lsi")
    csr_matrix = utils.convert_to_csr_matrix(model_vector)
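    # --- Illustrative sketch (an addition, not part of the original __main__ block) ----------------
    # load_transform_model dispatches on the file extension (.tfidf / .lsi / .lda). Re-run the LSI
    # transform with a save path so the model is persisted, then reload it; the file name
    # "sz000001_topic.lsi" is a placeholder chosen for this sketch.
    topicmodelling.transform_vectorized_corpus(corpora_dictionary,
                                               corpus,
                                               model_type="lsi",
                                               model_save_path="sz000001_topic.lsi")
    reloaded_lsi_model = topicmodelling.load_transform_model("sz000001_topic.lsi")
    print(reloaded_lsi_model.num_topics)  # equals config.TOPIC_NUMBER used when the model was trained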