async def run(self):
    """Fetch the full stock list and persist it to both Stock and StockDetail."""
    # Pull the security list for the stock entity type.
    raw_df = self.bao_get_all_securities(
        to_bao_entity_type(EntityType.Stock))
    if not pd_valid(raw_df):
        return

    stock_df = self.to_entity(raw_df, entity_type=EntityType.Stock)

    # The same frame is written twice: once into Stock, once into StockDetail.
    for schema in (Stock, StockDetail):
        await df_to_db(region=self.region,
                       provider=self.provider,
                       data_schema=schema,
                       db_session=get_db_session(self.region, self.provider, schema),
                       df=stock_df)

    self.logger.info("persist stock list success")
async def on_finish(self):
    """Back-fill `rights_raising_fund` on DividendFinancing rows that are still
    NULL, by summing this runner's schema over the same calendar year, then
    commit and delegate to the base class.
    """
    last_year = str(now_pd_timestamp(self.region).year)
    codes = [item.code for item in self.entities]
    db_session = get_db_session(self.region, self.provider, DividendFinancing)

    # Rows up to last year whose rights_raising_fund has not been filled yet.
    need_filleds, column_names = DividendFinancing.query_data(
        region=self.region,
        provider=self.provider,
        db_session=db_session,
        codes=codes,
        end_timestamp=last_year,
        filters=[DividendFinancing.rights_raising_fund.is_(None)])

    if need_filleds:
        desc = self.data_schema.__name__ + ": update relevant table"
        db_session_1 = get_db_session(self.region, self.provider, self.data_schema)
        kafka_producer = connect_kafka_producer(findy_config['kafka'])

        for item in need_filleds:
            # Aggregate rights_raising_fund from the item's timestamp through
            # year-end; with func.sum the query presumably returns a scalar —
            # the isinstance guard below skips NULL/absent results.
            result, column_names = self.data_schema.query_data(
                region=self.region,
                provider=self.provider,
                db_session=db_session_1,
                entity_id=item.entity_id,
                start_timestamp=item.timestamp,
                end_timestamp=f"{item.timestamp.year}-12-31",
                func=func.sum(self.data_schema.rights_raising_fund))
            if isinstance(result, (int, float)):
                # Mutate the ORM row in place; persisted by the commit below.
                item.rights_raising_fund = result

            # Progress tick (one per processed row) to the kafka progress topic.
            data = {
                "task": 'rig',
                "total": len(need_filleds),
                "desc": desc,
                "leave": True,
                "update": 1
            }
            publish_message(kafka_producer, progress_topic,
                            bytes(progress_key, encoding='utf-8'),
                            bytes(json.dumps(data), encoding='utf-8'))

        try:
            db_session.commit()
        except Exception as e:
            self.logger.error(f'{self.__class__.__name__}, error: {e}')
            db_session.rollback()

    await super().on_finish()
async def on_finish(self):
    """Back-fill `ipo_raising_fund` from each entity's raising_fund for
    DividendFinancing rows where it is still NULL, then commit and
    delegate to the base class.
    """
    desc = DividendFinancing.__name__ + ": update relevant table"
    db_session = get_db_session(self.region, self.provider, DividendFinancing)
    kafka_producer = connect_kafka_producer(findy_config['kafka'])

    for entity in self.entities:
        # Rows for this entity's code missing ipo_raising_fund but with issues.
        rows, _ = DividendFinancing.query_data(
            region=self.region,
            provider=self.provider,
            db_session=db_session,
            codes=[entity.code],
            filters=[
                DividendFinancing.ipo_raising_fund.is_(None),
                DividendFinancing.ipo_issues != 0])

        if rows:
            for row in rows:
                # Mutate in place; persisted by the commit below.
                row.ipo_raising_fund = entity.raising_fund

        # One progress tick per entity.
        progress = {"task": 'div',
                    "total": len(self.entities),
                    "desc": desc,
                    "leave": True,
                    "update": 1}
        publish_message(kafka_producer, progress_topic,
                        bytes(progress_key, encoding='utf-8'),
                        bytes(json.dumps(progress), encoding='utf-8'))

    try:
        db_session.commit()
    except Exception as e:
        self.logger.error(f'{self.__class__.__name__}, error: {e}')
        db_session.rollback()

    await super().on_finish()
async def run(self):
    """Initialize entities if needed, then process every entity concurrently
    (bounded by a semaphore) and finish with on_finish().
    """
    db_session = get_db_session(self.region, self.provider, self.data_schema)
    kafka_producer = connect_kafka_producer(findy_config['kafka'])

    if not hasattr(self, 'entities'):
        self.entities: List = None
        await self.init_entities(db_session)

    # Nothing to do when entity initialization produced no work.
    if not self.entities:
        return

    http_session = get_async_http_session()
    throttler = asyncio.Semaphore(self.share_para[0])
    (taskid, desc) = self.share_para[1]

    # Announce the task on the progress topic before starting.
    progress = {"task": taskid,
                "total": len(self.entities),
                "desc": desc,
                "leave": True,
                "update": 0}
    publish_message(kafka_producer, progress_topic,
                    bytes(progress_key, encoding='utf-8'),
                    bytes(json.dumps(progress), encoding='utf-8'))

    coros = [self.process_loop(entity, http_session, db_session,
                               kafka_producer, throttler)
             for entity in self.entities]
    # Drain all coroutines as they complete; the semaphore inside
    # process_loop limits actual concurrency.
    for finished in asyncio.as_completed(coros):
        await finished

    await self.on_finish()
    return await http_session.close()
async def get_stocks(cls, region: Region, provider: Provider, timestamp,
                     code=None, codes=None, ids=None):
    """
    The publishing policy of portfolio positions differs per portfolio type;
    overwrite this function to get the holding stocks at a specific date.

    :param code: portfolio(etf/block/index...) code
    :param codes: portfolio(etf/block/index...) codes
    :param ids: portfolio(etf/block/index...) ids
    :param timestamp: the date of the holding stocks
    :param provider: the data provider
    :return: a DataFrame of holdings (empty when nothing matches)
    """
    from findy.database.plugins.register import get_schema_by_name
    from findy.database.context import get_db_session

    # The companion schema is named '<Portfolio>Stock' by convention.
    stock_schema = get_schema_by_name(f'{cls.__name__}Stock')
    session = get_db_session(region, provider, data_schema=stock_schema)

    rows, names = stock_schema.query_data(
        region=region,
        provider=provider,
        db_session=session,
        code=code,
        codes=codes,
        timestamp=timestamp,
        ids=ids)

    if not rows:
        return pd.DataFrame()
    return pd.DataFrame([r.__dict__ for r in rows], columns=names)
async def run(self):
    """Crawl index definitions from both Chinese exchanges."""
    web_session = get_sync_http_session()
    orm_session = get_db_session(self.region, self.provider, self.data_schema)

    # Shanghai / CSI indices.
    await self.fetch_csi_index(web_session, orm_session)
    # Shenzhen indices.
    await self.fetch_szse_index(web_session, orm_session)
def load_data(self):
    """Query self.data_schema into self.data_df and notify every registered
    data listener; resolves exchanges/codes to entity_ids first when needed.
    """
    self.logger.info('load_data start')
    start_time = time.time()
    db_session = get_db_session(self.region, self.provider, self.data_schema)

    # Resolve exchanges/codes into canonical entity_ids when they were not
    # given explicitly.
    if self.entity_schema and not self.entity_ids:
        entities, column_names = get_entities(
            region=self.region,
            provider=self.provider,
            db_session=db_session,
            entity_schema=self.entity_schema,
            exchanges=self.exchanges,
            codes=self.codes,
            columns=[self.entity_schema.entity_id])
        if len(entities) > 0:
            self.entity_ids = [entity.entity_id for entity in entities]
        else:
            # No matching entities: leave data_df untouched and skip the
            # listener callbacks entirely.
            return

    data, column_names = self.data_schema.query_data(
        region=self.region,
        provider=self.provider,
        db_session=db_session,
        entity_ids=self.entity_ids,
        columns=self.columns,
        start_timestamp=self.start_timestamp,
        end_timestamp=self.end_timestamp,
        filters=self.filters,
        order=self.order,
        limit=self.limit,
        level=self.level,
        index=[self.category_field, self.time_field],
        time_field=self.time_field)

    # Without a column projection the query yields ORM objects (use __dict__);
    # with one it yields plain rows DataFrame can consume directly.
    if data and not self.columns:
        self.data_df = pd.DataFrame([s.__dict__ for s in data], columns=column_names)
    else:
        self.data_df = pd.DataFrame(data, columns=column_names)

    cost_time = time.time() - start_time
    self.logger.info(f'load_data finished, cost_time:{cost_time}')

    for listener in self.data_listeners:
        listener.on_data_loaded(self.data_df)
def load_company_info(tickers=None):
    """Return a DataFrame of Yahoo US StockDetail records, optionally
    restricted to the given ticker codes.
    """
    detail_schema = get_entity_schema_by_type(EntityType.StockDetail)
    session = get_db_session(Region.US, Provider.Yahoo, detail_schema)

    records, names = get_entities(region=Region.US,
                                  provider=Provider.Yahoo,
                                  entity_schema=detail_schema,
                                  db_session=session,
                                  codes=tickers)

    frame = pd.DataFrame([r.__dict__ for r in records], columns=names)
    frame.reset_index(drop=True, inplace=True)
    return frame
async def run(self):
    """Cache the list of trade days (newest first) before delegating to the
    base runner; bail out early when none could be loaded.
    """
    db_session = get_db_session(self.region, self.provider, self.data_schema)

    days, _ = StockTradeDay.query_data(
        region=self.region,
        provider=self.provider,
        db_session=db_session,
        order=StockTradeDay.timestamp.desc())

    if days:
        self.trade_day = [day.timestamp for day in days]
    else:
        # Without trade days the downstream logic cannot work — stop here.
        self.trade_day = []
        self.logger.warning("load trade days failed")
        return

    await super().run()
async def persist_index(self, df) -> None:
    """Normalize an index frame (ids, exchange, entity type) and upsert it
    into the Index table.
    """
    frame = df
    frame['timestamp'] = frame['timestamp'].apply(to_pd_timestamp)
    frame['list_date'] = frame['list_date'].apply(to_pd_timestamp)
    frame['id'] = frame['code'].apply(lambda code: f'index_cn_{code}')
    frame['entity_id'] = frame['id']
    frame['exchange'] = 'cn'
    frame['entity_type'] = EntityType.Index.value

    # Drop incomplete records, then keep only the latest row per id.
    frame = frame.dropna(axis=0, how='any')
    frame = frame.drop_duplicates(subset='id', keep='last')

    await df_to_db(region=self.region,
                   provider=self.provider,
                   data_schema=Index,
                   db_session=get_db_session(self.region, self.provider, Index),
                   df=frame)
async def persist(self, df, db_session):
    """Write `df` into the runner's own schema and mirror it into StockDetail.

    Returns (True, elapsed_seconds, saved_count).
    """
    began = time.time()

    # Primary persist into the runner's own schema.
    saved = await df_to_db(region=self.region,
                           provider=self.provider,
                           data_schema=self.data_schema,
                           db_session=db_session,
                           df=df)

    # Mirror into StockDetail, overwriting any existing rows.
    detail_session = get_db_session(self.region, self.provider, StockDetail)
    await df_to_db(region=self.region,
                   provider=self.provider,
                   data_schema=StockDetail,
                   db_session=detail_session,
                   df=df,
                   force_update=True)

    return True, time.time() - began, saved
async def init_main_index(region: Region, provider=Provider.Exchange):
    """Seed the Index table with the built-in main-index definitions for a region.

    :param region: Region.CHN or Region.US; anything else is a no-op.
    :param provider: data provider tag to store with the rows.

    Fix: the original converted item['timestamp'] **in place** on the
    module-level constant lists (CHINA_STOCK_MAIN_INDEX / US_STOCK_MAIN_INDEX),
    mutating shared global data on every call. We now build the frame from
    per-record copies and leave the constants untouched.
    """
    if region == Region.CHN:
        index_list = CHINA_STOCK_MAIN_INDEX
    elif region == Region.US:
        index_list = US_STOCK_MAIN_INDEX
    else:
        print("index not initialized, in file: init_main_index")
        index_list = []

    # Copy each record so the shared constants are never mutated.
    df = pd.DataFrame([{**item, 'timestamp': to_pd_timestamp(item['timestamp'])}
                       for item in index_list])

    if pd_valid(df):
        await df_to_db(region=region,
                       provider=provider,
                       data_schema=Index,
                       db_session=get_db_session(region, provider, Index),
                       df=df)
async def run(self):
    """Crawl ETF lists and their constituents from the Shanghai (SSE) and
    Shenzhen (SZSE) exchanges and persist them.
    """
    http_session = get_sync_http_session()
    db_session = get_db_session(self.region, self.provider, self.data_schema)

    # Shanghai: the ETF list endpoint returns demjson-encoded text.
    sh_url = 'http://query.sse.com.cn/commonQuery.do?sqlId=COMMON_SSE_ZQPZ_ETFLB_L_NEW'
    text = sync_get(http_session, sh_url,
                    headers=DEFAULT_SH_ETF_LIST_HEADER,
                    return_type='text')
    if text is None:
        return
    payload = demjson.decode(text)
    sh_df = pd.DataFrame(payload.get('result', []))
    await self.persist_etf_list(sh_df, ChnExchange.SSE.value, db_session)
    self.logger.info('沪市 ETF 列表抓取完成...')

    # Shanghai ETF constituents.
    await self.download_sh_etf_component(sh_df, http_session, db_session)
    self.logger.info('沪市 ETF 成分股抓取完成...')

    # Shenzhen: the ETF list is served as an xlsx report.
    sz_url = 'http://www.szse.cn/api/report/ShowReport?SHOWTYPE=xlsx&CATALOGID=1945'
    content = sync_get(http_session, sz_url, return_type='content')
    if content is None:
        return
    sz_df = pd.read_excel(io.BytesIO(content), dtype=str)
    await self.persist_etf_list(sz_df, ChnExchange.SZSE.value, db_session)
    self.logger.info('深市 ETF 列表抓取完成...')

    # Shenzhen ETF constituents.
    await self.download_sz_etf_component(sz_df, http_session, db_session)
    self.logger.info('深市 ETF 成分股抓取完成...')
async def load_window_df(self, data_schema, window):
    """Fetch the most recent `window` rows per entity and concatenate them
    into one frame sorted by the (category, time) index; None when nothing
    was found.
    """
    db_session = get_db_session(self.region, self.provider, data_schema)

    frames = []
    for entity_id in self.entity_ids:
        rows, names = data_schema.query_data(
            region=self.region,
            provider=self.provider,
            db_session=db_session,
            index=[self.category_field, self.time_field],
            order=data_schema.timestamp.desc(),
            entity_id=entity_id,
            limit=window)
        if rows:
            frames.append(pd.DataFrame([r.__dict__ for r in rows], columns=names))

    if not frames:
        return None
    combined = pd.concat(frames)
    return combined.sort_index(level=[0, 1])
async def get_portfolio_stocks(region: Region, provider: Provider, timestamp,
                               portfolio_entity=Fund, code=None, codes=None,
                               ids=None):
    """Resolve a portfolio's stock holdings as of `timestamp`.

    Quarterly reports only disclose partial holdings, so when the latest
    report is quarterly we walk back through earlier report dates (at most
    21 steps) and merge in rows until an annual or semi-annual disclosure is
    reached. Returns a DataFrame, or implicitly None when no data is found
    or the walk-back is exhausted.

    Fix: `DataFrame.append` was deprecated in pandas 1.4 and removed in
    pandas 2.0 — replaced with the equivalent `pd.concat`.
    """
    portfolio_stock = f'{portfolio_entity.__name__}Stock'
    data_schema: PortfolioStockHistory = get_schema_by_name(portfolio_stock)
    db_session = get_db_session(region, provider, data_schema)

    # Newest report row on or before the requested date.
    latests, column_names = data_schema.query_data(
        region=region,
        provider=provider,
        db_session=db_session,
        code=code,
        end_timestamp=timestamp,
        order=data_schema.timestamp.desc(),
        limit=1)

    if latests and len(latests) > 0:
        latest_record = latests[0]
        # All rows belonging to that latest report.
        data, column_names = data_schema.query_data(
            region=region,
            provider=provider,
            db_session=db_session,
            code=code,
            codes=codes,
            ids=ids,
            end_timestamp=timestamp,
            filters=[data_schema.report_date == latest_record.report_date])

        if data and len(data) > 0:
            df = pd.DataFrame([s.__dict__ for s in data], columns=column_names)

            # Annual / semi-annual reports disclose full holdings — done.
            if latest_record.report_period == ReportPeriod.year or latest_record.report_period == ReportPeriod.half_year:
                return df

            # Quarterly report: combine with earlier disclosures.
            step = 0
            while step <= 20:
                report_date = get_recent_report_date(
                    latest_record.report_date, step=step)
                data, column_names = data_schema.query_data(
                    region=region,
                    provider=provider,
                    db_session=db_session,
                    code=code,
                    codes=codes,
                    ids=ids,
                    end_timestamp=timestamp,
                    filters=[
                        data_schema.report_date == to_pd_timestamp(report_date)
                    ])
                if data and len(data) > 0:
                    pre_df = pd.DataFrame.from_records(
                        [s.__dict__ for s in data], columns=column_names)
                    # DataFrame.append was removed in pandas 2.x; pd.concat
                    # with default ignore_index=False matches its behavior.
                    df = pd.concat([df, pre_df])

                    # Stop once an annual or semi-annual report was merged.
                    if (ReportPeriod.half_year.value in pre_df['report_period'].tolist()) or (
                            ReportPeriod.year.value in pre_df['report_period'].tolist()):
                        # Keep the most recent position per stock.
                        df = df.drop_duplicates(subset=['stock_code'],
                                                keep='first')
                        return df
                step = step + 1
async def on_finish_entity(self, entity, http_session, db_session):
    """Post-process one entity: back-fill the real publish timestamp for rows
    still stamped with their report_date, using FinanceFactor as the source
    of publish dates. Returns the accumulated processing time in seconds.
    """
    total_time = await super().on_finish_entity(entity, http_session, db_session)

    if not self.fetch_jq_timestamp:
        return total_time

    now = time.time()

    # fill the timestamp for report published date
    # Rows where timestamp == report_date have not had their publish date
    # resolved yet.
    the_data_list, column_names = self.data_schema.query_data(
        region=self.region,
        provider=self.provider,
        db_session=db_session,
        entity_id=entity.id,
        order=self.data_schema.timestamp.asc(),
        filters=[
            self.data_schema.timestamp == self.data_schema.report_date
        ])
    if the_data_list and len(the_data_list) > 0:
        # FinanceFactor rows whose timestamp differs from report_date carry
        # the actual publish timestamp for the covered report-date range.
        data, column_names = FinanceFactor.query_data(
            region=self.region,
            provider=self.provider,
            db_session=get_db_session(self.region, self.provider, FinanceFactor),
            entity_id=entity.id,
            columns=[
                FinanceFactor.timestamp,
                FinanceFactor.report_date,
                FinanceFactor.id
            ],
            filters=[
                FinanceFactor.timestamp != FinanceFactor.report_date,
                FinanceFactor.report_date >= the_data_list[0].report_date,
                FinanceFactor.report_date <= the_data_list[-1].report_date
            ])

        if data and len(data) > 0:
            df = pd.DataFrame(data, columns=column_names)
            # Re-index by report_date so rows can be joined by lookup below.
            df = index_df(df, index='report_date', time_field='report_date')

            if pd_valid(df):
                for the_data in the_data_list:
                    if the_data.report_date in df.index:
                        # Mutate the ORM row; persisted by the commit below.
                        the_data.timestamp = df.at[the_data.report_date, 'timestamp']
                        self.logger.info(
                            'db fill {} {} timestamp:{} for report_date:{}'
                            .format(self.data_schema.__name__, entity.id,
                                    the_data.timestamp, the_data.report_date))

                try:
                    db_session.commit()
                except Exception as e:
                    self.logger.error(
                        f'{self.__class__.__name__}, error: {e}')
                    db_session.rollback()

    return total_time + time.time() - now