async def run(self):
    # fetch the stock list
    df_entity = self.bao_get_all_securities(
        to_bao_entity_type(EntityType.Stock))

    if pd_valid(df_entity):
        df_stock = self.to_entity(df_entity, entity_type=EntityType.Stock)

        # persist to Stock
        await df_to_db(region=self.region,
                       provider=self.provider,
                       data_schema=Stock,
                       db_session=get_db_session(self.region, self.provider, Stock),
                       df=df_stock)

        # persist StockDetail too
        await df_to_db(region=self.region,
                       provider=self.provider,
                       data_schema=StockDetail,
                       db_session=get_db_session(self.region, self.provider, StockDetail),
                       df=df_stock)

        self.logger.info("persist stock list success")

async def eval_fetch_timestamps(self, entity, ref_record, http_session):
    timestamps = self.security_timestamps_map.get(entity.id)
    if not timestamps:
        timestamps = self.init_timestamps(entity, http_session)
        if self.start_timestamp:
            timestamps = [t for t in timestamps if t >= self.start_timestamp]
        if self.end_timestamp:
            timestamps = [t for t in timestamps if t <= self.end_timestamp]
        self.security_timestamps_map[entity.id] = timestamps

    if not timestamps:
        return None, None, 0, timestamps

    timestamps.sort()

    latest_timestamp = None
    try:
        if pd_valid(ref_record):
            time_field = self.get_evaluated_time_field()
            latest_timestamp = ref_record[time_field].max(axis=0)
    except Exception as e:
        self.logger.warning(f'get ref_record failed with error: {e}')

    if latest_timestamp is not None and isinstance(latest_timestamp, pd.Timestamp):
        timestamps = [t for t in timestamps if t >= latest_timestamp]
        if timestamps:
            return timestamps[0], timestamps[-1], len(timestamps), timestamps
        return None, None, 0, None

    return timestamps[0], timestamps[-1], len(timestamps), timestamps

async def eval_fetch_timestamps(self, entity, ref_record, http_session):
    latest_timestamp = None
    try:
        if pd_valid(ref_record):
            time_field = self.get_evaluated_time_field()
            latest_timestamp = ref_record[time_field].max(axis=0)
    except Exception as e:
        self.logger.warning(f"get ref_record failed with error: {e}")

    if not latest_timestamp:
        latest_timestamp = entity.timestamp

    if not latest_timestamp:
        return self.start_timestamp, self.end_timestamp, self.default_size, None

    if latest_timestamp.date() >= now_pd_timestamp(self.region).date():
        return latest_timestamp, None, 0, None

    if len(self.trade_day) > 0 and \
            latest_timestamp.date() >= self.trade_day[0].date():
        return latest_timestamp, None, 0, None

    if self.start_timestamp:
        latest_timestamp = max(latest_timestamp, self.start_timestamp)

    if self.end_timestamp and latest_timestamp > self.end_timestamp:
        size = 0
    else:
        size = self.default_size

    return latest_timestamp, self.end_timestamp, size, None

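# Both eval_fetch_timestamps variants above return the 4-tuple
# (start, end, size, timestamps) that the record() implementations below
# unpack from `para` (prefixed with ref_record). A minimal driver sketch,
# assuming a `recorder` instance and framework-provided sessions -- the
# wiring below is illustrative, not part of the source:
#
#   start, end, size, timestamps = await recorder.eval_fetch_timestamps(
#       entity, ref_record, http_session)
#   if size == 0 and not timestamps:
#       return  # nothing new to fetch
#   para = (ref_record, start, end, size, timestamps)
#   is_finished, cost, result = await recorder.record(
#       entity, http_session, db_session, para)
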
async def record(self, entity, http_session, db_session, para):
    start_point = time.time()

    (ref_record, start, end, size, timestamps) = para

    if timestamps:
        original_list = []
        for the_timestamp in timestamps:
            param = self.generate_request_param(
                entity, start, end, size, the_timestamp, http_session)

            tmp_list = None
            try:
                tmp_list = await self.api_wrapper.request(
                    http_session, url=self.url, param=param,
                    method=self.request_method, path_fields=self.path_fields)
            except Exception as e:
                self.logger.error(f"url: {self.url}, error: {e}")

            if tmp_list is None:
                continue

            # fill timestamp field
            for tmp in tmp_list:
                tmp[self.get_evaluated_time_field()] = the_timestamp

            original_list += tmp_list
            if len(original_list) == self.batch_size:
                break

        return False, time.time() - start_point, (
            ref_record, pd.DataFrame.from_records(original_list))
    else:
        param = self.generate_request_param(
            entity, start, end, size, None, http_session)

        try:
            result = await self.api_wrapper.request(
                http_session, url=self.url, param=param,
                method=self.request_method, path_fields=self.path_fields)
            df = pd.DataFrame.from_records(result)
            if pd_valid(df):
                time_field = self.get_original_time_field()
                df[time_field] = pd.to_datetime(
                    df[time_field], format=PD_TIME_FORMAT_DAY)
                return False, time.time() - start_point, (
                    ref_record, self.format(entity, df))
        except Exception as e:
            self.logger.error(f"url: {self.url}, error: {e}")

    return True, time.time() - start_point, None

async def record(self, entity, http_session, db_session, para):
    start_point = time.time()

    # get stock info
    info = self.tushare_get_info(entity)
    if pd_valid(info):
        return False, time.time() - start_point, self.format(entity, info)

    return True, time.time() - start_point, None

async def record(self, entity, http_session, db_session, para):
    start_point = time.time()

    (ref_record, start, end, size, timestamps) = para

    end_timestamp = to_time_str(self.end_timestamp) if self.end_timestamp else None

    df = await self.yh_get_bars(http_session, entity, start=start, end=end_timestamp)
    if pd_valid(df):
        return False, time.time() - start_point, (ref_record, self.format(entity, df))

    return True, time.time() - start_point, None

async def record(self, entity, http_session, db_session, para):
    start_point = time.time()

    trade_day, column_names = StockTradeDay.query_data(
        region=self.region,
        provider=self.provider,
        db_session=db_session,
        func=func.max(StockTradeDay.timestamp))

    start = to_time_str(trade_day) if trade_day else "1990-12-19"

    df = self.bao_get_trade_days(start_date=start)
    if pd_valid(df):
        return False, time.time() - start_point, self.format(entity, df)

    return True, time.time() - start_point, None

# declared async to match the other record() implementations awaited by the framework
async def record(self, entity, http_session, db_session, para):
    start_point = time.time()

    (ref_record, start, end, size, timestamps) = para

    # get the balance sheet for the stock
    balance_sheet = self.yh_get_balance_sheet(entity.code)
    if balance_sheet is None or len(balance_sheet) == 0:
        return True, time.time() - start_point, None

    # columns are report dates; transpose so each row is one report
    balance_sheet = balance_sheet.T
    balance_sheet['timestamp'] = balance_sheet.index

    if pd_valid(balance_sheet):
        return False, time.time() - start_point, (ref_record, self.format(entity, balance_sheet))

    return True, time.time() - start_point, None

async def init_main_index(region: Region, provider=Provider.Exchange):
    if region == Region.CHN:
        for item in CHINA_STOCK_MAIN_INDEX:
            item['timestamp'] = to_pd_timestamp(item['timestamp'])
        df = pd.DataFrame(CHINA_STOCK_MAIN_INDEX)
    elif region == Region.US:
        for item in US_STOCK_MAIN_INDEX:
            item['timestamp'] = to_pd_timestamp(item['timestamp'])
        df = pd.DataFrame(US_STOCK_MAIN_INDEX)
    else:
        print("index not initialized, in file: init_main_index")
        df = pd.DataFrame()

    if pd_valid(df):
        await df_to_db(region=region,
                       provider=provider,
                       data_schema=Index,
                       db_session=get_db_session(region, provider, Index),
                       df=df)

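# A minimal usage sketch for init_main_index, assuming an asyncio entry
# point (the call site is illustrative, not from the source):
#
#   import asyncio
#
#   asyncio.run(init_main_index(Region.CHN, provider=Provider.Exchange))
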
async def persist(self, entity, http_session, db_session, para):
    start_point = time.time()

    (ref_record, df_record) = para

    saved_counts = 0
    is_finished = False

    if pd_valid(df_record):
        assert 'id' in df_record.columns
        saved_counts = await df_to_db(region=self.region,
                                      provider=self.provider,
                                      data_schema=self.data_schema,
                                      db_session=db_session,
                                      df=df_record,
                                      ref_df=ref_record,
                                      fix_duplicate_way=self.fix_duplicate_way)
        if saved_counts == 0:
            is_finished = True
    # could not get more data
    else:
        # not realtime
        if not self.real_time:
            is_finished = True
        # realtime and to the close time
        elif (self.close_hour is not None) and (self.close_minute is not None):
            now = now_pd_timestamp(self.region)
            if now.hour >= self.close_hour:
                if now.minute - self.close_minute >= 5:
                    self.logger.info(f'{entity.id} now is the close time: {now}')
                    is_finished = True

    if isinstance(self, KDataRecorder):
        is_finished = True

    # only summarize the saved range when there is data; reading
    # df_record['timestamp'] unconditionally would crash on an empty frame
    if pd_valid(df_record):
        start_timestamp = to_time_str(df_record['timestamp'].min(axis=0))
        end_timestamp = to_time_str(df_record['timestamp'].max(axis=0))
        self.result = [saved_counts, start_timestamp, end_timestamp]

    return is_finished, time.time() - start_point, saved_counts

async def process_loop(self, entity, http_session, db_session, kafka_producer, throttler):
    url = self.category_map_url.get(entity, None)
    if url is None:
        return

    async with throttler:
        async with http_session.get(url, headers=self.category_map_header[entity]) as response:
            if response.status != 200:
                return

            content = await response.read()

            df = self.format(resp=content, exchange=entity)
            if pd_valid(df):
                await self.persist(df, db_session)

    (taskid, desc) = self.share_para[1]
    data = {"task": taskid,
            "total": len(self.entities),
            "desc": desc,
            "leave": True,
            "update": 1}
    publish_message(kafka_producer,
                    progress_topic,
                    bytes(progress_key, encoding='utf-8'),
                    bytes(json.dumps(data), encoding='utf-8'))

async def eval_fetch_timestamps(self, entity, referenced_record, http_session):
    # get latest record
    latest_record = None
    try:
        if pd_valid(referenced_record):
            latest_record = referenced_record.loc[
                referenced_record[self.get_evaluated_time_field()].idxmax()]
    except Exception as e:
        self.logger.warning(f"get referenced_record failed with error: {e}")

    if latest_record is not None:
        remote_record = self.get_remote_latest_record(entity, http_session)
        # compare the saved row's id with the remote latest id
        # (latest_record is a row Series, so its .index is the column labels)
        if remote_record is None or (latest_record['id'] == remote_record.id):
            return None, None, 0, None
        else:
            return None, None, 10, None

    return None, None, 1000, None

async def record(self, entity, http_session, db_session, para):
    start_point = time.time()

    (ref_record, start, end, size, timestamps) = para

    start = to_time_str(start)
    if self.bao_trading_level in ['d', 'w', 'm']:
        start = max(start, "1990-12-19")
    else:
        start = max(start, "1999-07-26")

    df = self.bao_get_bars(to_bao_entity_id(entity),
                           start=start,
                           end=end if end is None else to_time_str(end),
                           frequency=self.bao_trading_level,
                           fields=to_bao_trading_field(self.bao_trading_level),
                           adjustflag=to_bao_adjust_flag(self.adjust_type))
    if pd_valid(df):
        return False, time.time() - start_point, (ref_record, self.format(entity, df))

    return True, time.time() - start_point, None

async def eval_fetch_timestamps(self, entity, ref_record, http_session):
    latest_timestamp = None
    try:
        if pd_valid(ref_record):
            time_field = self.get_evaluated_time_field()
            latest_timestamp = ref_record[time_field].max(axis=0)
    except Exception as e:
        self.logger.warning(f'get ref record failed with error: {e}')

    if not latest_timestamp:
        latest_timestamp = entity.timestamp

    if not latest_timestamp:
        return self.start_timestamp, self.end_timestamp, self.default_size, None

    now = now_pd_timestamp(self.region)
    now_end = now.replace(hour=18, minute=0, second=0)

    trade_day_index = 0
    if len(self.trade_day) > 0:
        if is_same_date(self.trade_day[trade_day_index], now) and now < now_end:
            trade_day_index = 1
        end = self.trade_day[trade_day_index]
    else:
        end = now

    start_timestamp = next_dates(latest_timestamp)
    start = max(self.start_timestamp, start_timestamp) if self.start_timestamp else start_timestamp

    if start >= end:
        return start, end, 0, None

    size = self.eval_size_of_timestamp(start_timestamp=start,
                                       end_timestamp=end,
                                       level=self.level,
                                       one_day_trading_minutes=self.one_day_trading_minutes)

    return start, end, size, None

async def df_to_db(region: Region,
                   provider: Provider,
                   data_schema: DeclarativeMeta,
                   db_session,
                   df: pd.DataFrame,
                   ref_df: pd.DataFrame = None,
                   drop_duplicates: bool = True,
                   fix_duplicate_way: str = 'ignore',
                   force_update=False) -> object:
    now = time.time()

    if not pd_valid(df):
        return 0

    if drop_duplicates and df.duplicated(subset='id').any():
        df.drop_duplicates(subset='id', keep='last', inplace=True)

    schema_cols = get_schema_columns(data_schema)
    cols = list(set(df.columns.tolist()) & set(schema_cols))

    if not cols:
        logger.error(f"{data_schema.__tablename__} get columns failed")
        return 0

    df = df[cols]

    # force update mode: delete rows with duplicate ids, then write the new data back
    if force_update:
        ids = df["id"].tolist()
        if len(ids) == 1:
            sql = f"delete from {data_schema.__tablename__} where id = '{ids[0]}'"
        else:
            sql = f"delete from {data_schema.__tablename__} where id in {tuple(ids)}"

        try:
            db_session.execute(sql)
        except Exception as e:
            logger.error(f"query {data_schema.__tablename__} failed with error: {e}")

        try:
            db_session.commit()
        except Exception as e:
            logger.error(f'df_to_db {data_schema.__tablename__}, error: {e}')
            db_session.rollback()

        df_new = df
    else:
        if ref_df is None:
            data, column_names = data_schema.query_data(
                region=region,
                provider=provider,
                db_session=db_session,
                columns=[data_schema.id, data_schema.timestamp])
            if data and len(data) > 0:
                ref_df = pd.DataFrame(data, columns=column_names)

        if pd_valid(ref_df):
            df_new = df[~df.id.isin(ref_df.id)]
        else:
            df_new = df

        # whether a record is new cannot be decided by id alone; a full comparison is needed
        # if fix_duplicate_way == 'add':
        #     df_add = df[df.id.isin(ref_df.id)]
        #     if not df_add.empty:
        #         df_add.id = uuid.uuid1()
        #         df_new = pd.concat([df_new, df_add])

    cost = PRECISION_STR.format(time.time() - now)
    logger.debug(f"remove duplicated: {cost}")

    saved = 0
    if pd_valid(df_new):
        saved = to_postgresql(region, df_new, data_schema.__tablename__)

    cost = PRECISION_STR.format(time.time() - now)
    logger.debug(f"write db: {cost}, size: {saved}")

    return saved

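# A minimal usage sketch for df_to_db, assuming Stock is a mapped schema and
# the frame already carries the mandatory 'id' column (the sample row is
# illustrative only):
#
#   df = pd.DataFrame([{'id': 'stock_sh_600000',
#                       'timestamp': pd.Timestamp('2021-01-04'),
#                       'code': '600000'}])
#   saved = await df_to_db(region=Region.CHN,
#                          provider=Provider.Exchange,
#                          data_schema=Stock,
#                          db_session=get_db_session(Region.CHN, Provider.Exchange, Stock),
#                          df=df)
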
def to_open(s):
    # positional access via iloc; s[0] breaks on non-integer indexes
    if pd_valid(s):
        return s.iloc[0]

def to_close(s):
    if pd_valid(s):
        return s.iloc[-1]

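# to_open/to_close are aggregation helpers: given the Series of bar values
# for one period, they return its first/last value. A resampling sketch,
# assuming a kline frame with a 'timestamp' column (the column names are
# assumptions):
#
#   weekly = df.resample('W', on='timestamp').agg({'open': to_open,
#                                                  'close': to_close,
#                                                  'high': 'max',
#                                                  'low': 'min',
#                                                  'volume': 'sum'})
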
async def on_finish_entity(self, entity, http_session, db_session):
    total_time = await super().on_finish_entity(entity, http_session, db_session)

    if not self.fetch_jq_timestamp:
        return total_time

    now = time.time()

    # fill the timestamp for report published date
    the_data_list, column_names = self.data_schema.query_data(
        region=self.region,
        provider=self.provider,
        db_session=db_session,
        entity_id=entity.id,
        order=self.data_schema.timestamp.asc(),
        filters=[self.data_schema.timestamp == self.data_schema.report_date])

    if the_data_list and len(the_data_list) > 0:
        data, column_names = FinanceFactor.query_data(
            region=self.region,
            provider=self.provider,
            db_session=get_db_session(self.region, self.provider, FinanceFactor),
            entity_id=entity.id,
            columns=[FinanceFactor.timestamp,
                     FinanceFactor.report_date,
                     FinanceFactor.id],
            filters=[FinanceFactor.timestamp != FinanceFactor.report_date,
                     FinanceFactor.report_date >= the_data_list[0].report_date,
                     FinanceFactor.report_date <= the_data_list[-1].report_date])

        if data and len(data) > 0:
            df = pd.DataFrame(data, columns=column_names)
            df = index_df(df, index='report_date', time_field='report_date')

            if pd_valid(df):
                for the_data in the_data_list:
                    if the_data.report_date in df.index:
                        the_data.timestamp = df.at[the_data.report_date, 'timestamp']
                        self.logger.info(
                            f'db fill {self.data_schema.__name__} {entity.id} '
                            f'timestamp:{the_data.timestamp} for report_date:{the_data.report_date}')

                try:
                    db_session.commit()
                except Exception as e:
                    self.logger.error(f'{self.__class__.__name__}, error: {e}')
                    db_session.rollback()

    return total_time + time.time() - now