def on_finish(self):
    """Back-fill ipo_raising_fund on DividendFinancing rows after crawling."""
    desc = DividendFinancing.__name__ + ": update relevant table"
    with tqdm(total=len(self.entities), ncols=90, desc=desc,
              position=2, leave=True) as pbar:
        session = get_db_session(region=self.region,
                                 provider=self.provider,
                                 data_schema=self.data_schema)
        for entity in self.entities:
            # rows for this code whose ipo_raising_fund is still missing
            pending = DividendFinancing.query_data(
                region=self.region,
                provider=self.provider,
                codes=[entity.code],
                return_type='domain',
                filters=[
                    DividendFinancing.ipo_raising_fund.is_(None),
                    DividendFinancing.ipo_issues != 0
                ])
            for row in pending:
                row.ipo_raising_fund = entity.raising_fund
            # persist per entity so progress survives interruptions
            session.commit()
            pbar.update()
    super().on_finish()
def on_finish(self):
    """Aggregate per-year rights-issue funds into DividendFinancing rows."""
    from sqlalchemy import func

    last_year = str(now_pd_timestamp(self.region).year)
    codes = [entity.code for entity in self.entities]
    pending = DividendFinancing.query_data(
        region=self.region,
        provider=self.provider,
        codes=codes,
        return_type='domain',
        filters=[DividendFinancing.rights_raising_fund.is_(None)],
        end_timestamp=last_year)

    desc = RightsIssueDetail.__name__ + ": update relevant table"
    with tqdm(total=len(pending), ncols=90, desc=desc,
              position=2, leave=True) as pbar:
        session = get_db_session(region=self.region,
                                 provider=self.provider,
                                 data_schema=self.data_schema)
        for record in pending:
            # sum of all detail rows inside the record's calendar year
            total = RightsIssueDetail.query_data(
                region=self.region,
                provider=self.provider,
                entity_id=record.entity_id,
                start_timestamp=record.timestamp,
                end_timestamp="{}-12-31".format(record.timestamp.year),
                return_type='func',
                func=func.sum(RightsIssueDetail.rights_raising_fund))
            # a non-numeric result (e.g. no detail rows) leaves the field unset
            if isinstance(total, (int, float)):
                record.rights_raising_fund = total
            session.commit()
            pbar.update()
    super().on_finish()
def init_account(self) -> AccountStats:
    """Create a fresh AccountStats, wiping any previous run of this trader."""
    previous = get_trader_info(self.region, trader_name=self.trader_name,
                               return_type='domain', limit=1)
    if previous:
        self.logger.warning(
            "trader:{} has run before,old result would be deleted".format(
                self.trader_name))
        session = get_db_session(region=Region.CHN,
                                 provider=Provider.ZVT,
                                 data_schema=TraderInfo)
        # every table keyed by trader_name must be cleared together
        for schema in (TraderInfo, AccountStats, Position, Order):
            session.query(schema).filter(
                schema.trader_name == self.trader_name).delete()
        session.commit()
    return AccountStats(entity_id=f'trader_zvt_{self.trader_name}',
                        timestamp=self.start_timestamp,
                        trader_name=self.trader_name,
                        cash=self.base_capital,
                        input_money=self.base_capital,
                        all_value=self.base_capital,
                        value=0,
                        closing=False)
def on_finish_entity(self, entity, http_session):
    """Back-fill cumulative net value onto the ETF's daily kdata rows."""
    kdatas = get_kdata(region=self.region,
                       provider=self.provider,
                       entity_id=entity.id,
                       level=IntervalLevel.LEVEL_1DAY.value,
                       order=Etf1dKdata.timestamp.asc(),
                       return_type='domain',
                       filters=[Etf1dKdata.cumulative_net_value.is_(None)])
    if not kdatas:
        return

    start, end = kdatas[0].timestamp, kdatas[-1].timestamp
    # fetch the fund's cumulative net value from Eastmoney for the gap range
    df = self.fetch_cumulative_net_value(entity, start, end, http_session)
    if not pd_is_not_null(df):
        return

    for kdata in kdatas:
        if kdata.timestamp in df.index:
            kdata.cumulative_net_value = df.loc[kdata.timestamp, 'LJJZ']
            kdata.change_pct = df.loc[kdata.timestamp, 'JZZZL']
    session = get_db_session(region=self.region,
                             provider=self.provider,
                             data_schema=self.data_schema)
    session.commit()
    self.logger.info(f'{entity.code} - {entity.name}累计净值更新完成...')
def on_finish_entity(self, entity, http_session):
    """Back-fill the real publish timestamp for this entity's report rows.

    Rows whose timestamp still equals report_date have not been filled yet.
    They are corrected either from already-resolved FinanceFactor rows in the
    db, or — failing that — from joinquant via fill_timestamp_with_jq.
    """
    super().on_finish_entity(entity, http_session)
    if not self.fetch_jq_timestamp:
        return

    # fill the timestamp for report published date
    the_data_list = get_data(
        region=self.region,
        data_schema=self.data_schema,
        provider=self.provider,
        entity_id=entity.id,
        order=self.data_schema.timestamp.asc(),
        return_type='domain',
        filters=[
            # timestamp == report_date marks a row whose publish date
            # has not been resolved yet
            self.data_schema.timestamp == self.data_schema.report_date
        ])
    if the_data_list:
        if self.data_schema == FinanceFactor:
            # FinanceFactor is itself the reference table: always ask jq
            for the_data in the_data_list:
                self.fill_timestamp_with_jq(entity, the_data)
        else:
            # reuse publish dates already resolved on FinanceFactor rows
            # covering the same report_date range
            df = FinanceFactor.query_data(
                region=self.region,
                entity_id=entity.id,
                columns=[
                    FinanceFactor.timestamp, FinanceFactor.report_date,
                    FinanceFactor.id
                ],
                filters=[
                    FinanceFactor.timestamp != FinanceFactor.report_date,
                    FinanceFactor.report_date >= the_data_list[0].report_date,
                    FinanceFactor.report_date <= the_data_list[-1].report_date
                ])
            if pd_is_not_null(df):
                index_df(df, index='report_date', time_field='report_date')
            for the_data in the_data_list:
                if pd_is_not_null(df) and the_data.report_date in df.index:
                    the_data.timestamp = df.at[the_data.report_date,
                                               'timestamp']
                    self.logger.info(
                        'db fill {} {} timestamp:{} for report_date:{}'.
                        format(self.data_schema, entity.id,
                               the_data.timestamp, the_data.report_date))
                    # NOTE(review): commits once per row — presumably to
                    # persist progress incrementally; confirm before batching.
                    session = get_db_session(region=self.region,
                                             provider=self.provider,
                                             data_schema=self.data_schema)
                    session.commit()
                else:
                    # not resolvable from the db yet: fall back to jq
                    self.fill_timestamp_with_jq(entity, the_data)
def get_order_securities(trader_name):
    """Return the distinct entity ids this trader has ever ordered."""
    session = get_db_session(region=Region.CHN,
                             provider=Provider.ZVT,
                             data_schema=Order)
    rows = (session.query(Order.entity_id)
            .filter(Order.trader_name == trader_name)
            .group_by(Order.entity_id)
            .all())
    # each row is a 1-tuple; unwrap it
    return [entity_id for (entity_id,) in rows]
def clear_trader(trader_name, session=None):
    """Delete all persisted state (info, stats, positions, orders) for a trader."""
    if not session:
        session = get_db_session(region=Region.CHN,
                                 provider=Provider.ZVT,
                                 data_schema=TraderInfo)
    # all four tables are keyed by trader_name and must be wiped together
    for schema in (TraderInfo, AccountStats, Position, Order):
        session.query(schema).filter(schema.trader_name == trader_name).delete()
    session.commit()
def fill_timestamp_with_jq(self, security_item, the_data):
    """Set the_data.timestamp to the report publish date fetched from joinquant."""
    # get report published date from jq
    df = jq_get_fundamentals(table='indicator',
                             code=to_jq_entity_id(security_item),
                             columns='pubDate',
                             date=to_jq_report_period(the_data.report_date),
                             count=None,
                             parse_dates=['pubDate'])
    if not pd_is_not_null(df):
        return

    the_data.timestamp = to_pd_timestamp(df['pubDate'][0])
    self.logger.info('jq fill {} {} timestamp:{} for report_date:{}'.format(
        self.data_schema.__name__, security_item.id,
        the_data.timestamp, the_data.report_date))
    session = get_db_session(region=self.region,
                             provider=self.provider,
                             data_schema=self.data_schema)
    session.commit()
def get_group(region: Region,
              provider: Provider,
              data_schema,
              column,
              group_func=func.count,
              session=None):
    """Group data_schema rows by column, optionally aggregating with group_func.

    Returns a DataFrame with the grouped column (and the aggregate when
    group_func is given).
    """
    if not session:
        session = get_db_session(region=region, provider=provider,
                                 data_schema=data_schema)
    # with an aggregate we select (column, agg(column)); otherwise just column
    selected = (column, group_func(column)) if group_func else (column,)
    query = session.query(*selected).group_by(column)
    return pd.read_sql(query.statement, query.session.bind)
def del_data(region: Region,
             data_schema: Type[Mixin],
             filters: List = None,
             provider: Provider = None):
    """Delete data_schema rows matching filters.

    When provider is omitted, the region's first registered provider is used.
    """
    if not provider:
        provider = data_schema.providers[region][0]
    session = get_db_session(region=region, provider=provider,
                             data_schema=data_schema)
    query = session.query(data_schema)
    for condition in filters or []:
        query = query.filter(condition)
    query.delete()
    session.commit()
def on_start(self):
    """Run all selectors on history, then persist this trader's run metadata."""
    for selector in self.selectors:
        # run for the history data at first
        selector.run()

    def _as_json(value):
        # list-style filters are stored as JSON strings, None when absent
        return json.dumps(value) if value else None

    sim_account = TraderInfo(
        id=self.trader_name,
        entity_id=f'trader_zvt_{self.trader_name}',
        timestamp=self.start_timestamp,
        trader_name=self.trader_name,
        entity_type=self.entity_schema.__name__.lower(),
        entity_ids=_as_json(self.entity_ids),
        exchanges=_as_json(self.exchanges),
        codes=_as_json(self.codes),
        start_timestamp=self.start_timestamp,
        end_timestamp=self.end_timestamp,
        provider=self.provider,
        level=self.level.value,
        real_time=self.real_time,
        kdata_use_begin_time=self.kdata_use_begin_time,
        kdata_adjust_type=self.adjust_type.value)
    session = get_db_session(region=Region.CHN,
                             provider=Provider.ZVT,
                             data_schema=TraderInfo)
    session.add(sim_account)
    session.commit()
def recompute_qfq(self, entity, qfq_factor, last_timestamp):
    """Re-apply the forward-adjust (qfq) factor to all kdata before last_timestamp."""
    if qfq_factor == 0:
        return
    bars = get_kdata(region=self.region,
                     provider=self.provider,
                     entity_id=entity.id,
                     level=self.level.value,
                     order=self.data_schema.timestamp.asc(),
                     return_type='domain',
                     filters=[self.data_schema.timestamp < last_timestamp])
    if not bars:
        return
    self.logger.info('recomputing {} qfq kdata,factor is:{}'.format(
        entity.code, qfq_factor))
    # scale OHLC in place, keeping the 2-decimal rounding of the original
    for bar in bars:
        for field in ('open', 'close', 'high', 'low'):
            setattr(bar, field, round(getattr(bar, field) * qfq_factor, 2))
    session = get_db_session(region=self.region,
                             provider=self.provider,
                             data_schema=self.data_schema)
    session.bulk_save_objects(bars)
    session.commit()
def get_data(region: Region,
             data_schema,
             ids: List[str] = None,
             entity_ids: List[str] = None,
             entity_id: str = None,
             codes: List[str] = None,
             code: str = None,
             level: Union[IntervalLevel, str] = None,
             provider: Provider = Provider.Default,
             columns: List = None,
             col_label: dict = None,
             return_type: str = 'df',
             start_timestamp: Union[pd.Timestamp, str] = None,
             end_timestamp: Union[pd.Timestamp, str] = None,
             filters: List = None,
             session: Session = None,
             order=None,
             limit: int = None,
             index: Union[str, list] = None,
             time_field: str = 'timestamp',
             fun=None):
    """Query rows of data_schema with the given filters.

    :param return_type: 'df' (DataFrame indexed by id), 'domain' (ORM objects),
                        'dict' (list of instance __dict__), or 'func' (scalar
                        of an aggregate passed via fun)
    :param columns: schema columns (or their string names) to select
    :param col_label: optional mapping column-name -> label
    :param index: column(s) to re-index the DataFrame by (df mode only)
    :param time_field: name of the schema's time column
    :returns: depends on return_type; None for an unknown return_type
    """
    assert data_schema is not None
    assert provider.value is not None
    assert provider in zvt_context.providers[region]

    step1 = time.time()
    precision_str = '{' + ':>{},.{}f'.format(8, 4) + '}'

    if not session:
        session = get_db_session(region=region, provider=provider,
                                 data_schema=data_schema)

    # FIX: attribute lookup via getattr instead of eval (safer, faster)
    time_col = getattr(data_schema, time_field)

    if fun is not None:
        query = session.query(fun)
    elif columns:
        # support plain string column names
        if isinstance(columns[0], str):
            columns_ = []
            for col in columns:
                assert isinstance(col, str)
                columns_.append(getattr(data_schema, col))
            columns = columns_
        # make sure the time column is always selected
        if time_col not in columns:
            columns.append(time_col)
        if col_label:
            columns = [
                col.label(col_label[col.name]) if col.name in col_label else col
                for col in columns
            ]
        query = session.query(*columns)
    else:
        query = session.query(data_schema)

    if zvt_config['debug'] == 2:
        cost = precision_str.format(time.time() - step1)
        logger.debug("get_data query column: {}".format(cost))

    if entity_id is not None:
        query = query.filter(data_schema.entity_id == entity_id)
    if entity_ids is not None:
        query = query.filter(data_schema.entity_id.in_(entity_ids))
    if code is not None:
        query = query.filter(data_schema.code == code)
    if codes is not None:
        query = query.filter(data_schema.code.in_(codes))
    if ids is not None:
        query = query.filter(data_schema.id.in_(ids))

    # we always store different levels in different schemas, so the `level`
    # parameter is intentionally not applied as a filter any more

    query = common_filter(query,
                          data_schema=data_schema,
                          start_timestamp=start_timestamp,
                          end_timestamp=end_timestamp,
                          filters=filters,
                          order=order,
                          limit=limit,
                          time_field=time_field)

    if zvt_config['debug'] == 2:
        cost = precision_str.format(time.time() - step1)
        logger.debug("get_data query common: {}".format(cost))

    if return_type == 'func':
        return query.scalar()
    elif return_type == 'df':
        df = pd.read_sql(query.statement, query.session.bind, index_col=['id'])
        if pd_is_not_null(df) and index:
            df = index_df(df, index=index, time_field=time_field)
        if zvt_config['debug'] == 2:
            cost = precision_str.format(time.time() - step1)
            logger.debug("get_data do query cost: {} type: {} size: {}".format(
                cost, return_type, len(df)))
        return df
    elif return_type == 'domain':
        if zvt_config['debug'] == 2:
            with profiled():
                result = query.all()
        else:
            result = query.all()
    elif return_type == 'dict':
        if zvt_config['debug'] == 2:
            with profiled():
                result = [item.__dict__ for item in query.all()]
        else:
            result = [item.__dict__ for item in query.all()]
    else:
        # unknown return_type: mirror the original fall-through behavior
        return None

    if zvt_config['debug'] == 2:
        cost = precision_str.format(time.time() - step1)
        res_cnt = len(result) if result else 0
        logger.debug(
            "get_data do query cost: {} type: {} limit: {} size: {}".format(
                cost, return_type, limit, res_cnt))
    return result
def process_loop(self, entity, http_session):
    """Fetch company profile and issue info from Eastmoney onto the entity.

    :param entity: StockDetail row to enrich (mutated, then committed)
    :param http_session: HTTP session used by sync_post
    """
    assert isinstance(entity, StockDetail)

    step1 = time.time()
    precision_str = '{' + ':>{},.{}f'.format(8, 4) + '}'
    self.result = None

    # FIX: the original used two independent ifs, leaving fc unbound
    # (NameError) for any exchange other than 'sh'/'sz'
    if entity.exchange == 'sh':
        fc = "{}01".format(entity.code)
    elif entity.exchange == 'sz':
        fc = "{}02".format(entity.code)
    else:
        self.logger.warning("unsupported exchange:{} for {}".format(
            entity.exchange, entity.id))
        return

    # basic profile (基本资料)
    # NOTE(review): "SecurityCode" looks like a leftover sample value; the API
    # appears to key on "fc" — confirm before removing.
    param = {"color": "w", "fc": fc, "SecurityCode": "SZ300059"}
    url = 'https://emh5.eastmoney.com/api/GongSiGaiKuang/GetJiBenZiLiao'
    json_result = sync_post(http_session, url, json=param)
    if json_result is None:
        return

    resp_json = json_result['JiBenZiLiao']
    entity.profile = resp_json['CompRofile']
    entity.main_business = resp_json['MainBusiness']
    entity.date_of_establishment = to_pd_timestamp(resp_json['FoundDate'])
    # industry, concept blocks and area
    entity.industry = ','.join(resp_json['Industry'].split('-'))
    entity.sector = resp_json['Block']
    entity.area = resp_json['Provice']

    self.sleep()

    # issue-related data (发行相关)
    param = {"color": "w", "fc": fc}
    url = 'https://emh5.eastmoney.com/api/GongSiGaiKuang/GetFaXingXiangGuan'
    json_result = sync_post(http_session, url, json=param)
    if json_result is None:
        return

    resp_json = json_result['FaXingXiangGuan']
    entity.issue_pe = to_float(resp_json['PEIssued'])
    entity.price = to_float(resp_json['IssuePrice'])
    entity.issues = to_float(resp_json['ShareIssued'])
    entity.raising_fund = to_float((resp_json['NetCollection']))
    entity.net_winning_rate = pct_to_float(resp_json['LotRateOn'])

    session = get_db_session(region=self.region,
                             provider=self.provider,
                             data_schema=self.data_schema)
    session.commit()

    cost = precision_str.format(time.time() - step1)
    prefix = "finish~ " if zvt_config['debug'] else ""
    postfix = "\n" if zvt_config['debug'] else ""

    if self.result is not None:
        self.logger.info(
            "{}{}, {}, time: {}, size: {:>7,}, date: [ {}, {} ]{}".format(
                prefix, self.data_schema.__name__, entity.id, cost,
                self.result[0], self.result[1], self.result[2], postfix))
    else:
        self.logger.info("{}{}, {}, time: {}{}".format(
            prefix, self.data_schema.__name__, entity.id, cost, postfix))
def update_position(self, current_position, order_amount, current_price,
                    order_type, timestamp):
    """Apply an order to a position, adjust cash, and persist the Order row.

    :param current_position: Position mutated in place
    :type current_position: Position
    :param order_amount: number of units in the order
    :param current_price: execution price
    :param order_type: ORDER_TYPE_LONG / SHORT / CLOSE_LONG / CLOSE_SHORT
    :param timestamp: execution time
    :raises NotEnoughMoneyError: when cash is insufficient and rich_mode is off
    """
    if order_type == ORDER_TYPE_LONG:
        need_money = (order_amount * current_price) * (
            1 + self.slippage + self.buy_cost)
        if self.account.cash < need_money:
            if self.rich_mode:
                self.input_money()
            else:
                raise NotEnoughMoneyError()
        self.account.cash -= need_money

        # recompute the average long price
        long_amount = current_position.long_amount + order_amount
        # FIX: the original assigned 0 for long_amount == 0 but then divided
        # by long_amount unconditionally (ZeroDivisionError); guard with else
        if long_amount == 0:
            current_position.average_long_price = 0
        else:
            current_position.average_long_price = (
                current_position.average_long_price *
                current_position.long_amount +
                current_price * order_amount) / long_amount

        current_position.long_amount = long_amount
        if current_position.trading_t == 0:
            current_position.available_long += order_amount

    elif order_type == ORDER_TYPE_SHORT:
        need_money = (order_amount * current_price) * (
            1 + self.slippage + self.buy_cost)
        if self.account.cash < need_money:
            if self.rich_mode:
                self.input_money()
            else:
                raise NotEnoughMoneyError()
        self.account.cash -= need_money

        short_amount = current_position.short_amount + order_amount
        # FIX: zero-guard added for symmetry with the long branch
        if short_amount == 0:
            current_position.average_short_price = 0
        else:
            current_position.average_short_price = (
                current_position.average_short_price *
                current_position.short_amount +
                current_price * order_amount) / short_amount

        current_position.short_amount = short_amount
        if current_position.trading_t == 0:
            current_position.available_short += order_amount

    elif order_type == ORDER_TYPE_CLOSE_LONG:
        self.account.cash += (order_amount * current_price *
                              (1 - self.slippage - self.sell_cost))
        # FIXME: if not fully sold, the average price should be recomputed
        current_position.available_long -= order_amount
        current_position.long_amount -= order_amount

    elif order_type == ORDER_TYPE_CLOSE_SHORT:
        # close short: return the margin (2x entry value) minus buy-back cost
        self.account.cash += 2 * (order_amount *
                                  current_position.average_short_price)
        self.account.cash -= order_amount * current_price * (
            1 + self.slippage + self.sell_cost)
        current_position.available_short -= order_amount
        current_position.short_amount -= order_amount

    # save the order info to db
    order_id = '{}_{}_{}_{}'.format(
        self.trader_name, order_type, current_position.entity_id,
        to_time_str(timestamp, TIME_FORMAT_ISO8601))
    order = Order(id=order_id,
                  timestamp=to_pd_timestamp(timestamp),
                  trader_name=self.trader_name,
                  entity_id=current_position.entity_id,
                  order_price=current_price,
                  order_amount=order_amount,
                  order_type=order_type,
                  level=self.level.value,
                  status='success')
    session = get_db_session(region=Region.CHN,
                             provider=Provider.ZVT,
                             data_schema=Order)
    session.add(order)
    session.commit()
def on_trading_close(self, timestamp):
    """Mark-to-market all positions at the close and persist AccountStats.

    Drops emptied positions, revalues the rest against the day's closing
    kdata, refreshes account value/profit, and commits everything to db.
    """
    self.logger.info('on_trading_close:{}'.format(timestamp))
    # remove the empty position
    self.account.positions = [
        position for position in self.account.positions
        if position.long_amount > 0 or position.short_amount > 0
    ]

    # clear the data which need recomputing
    the_id = '{}_{}'.format(self.trader_name,
                            to_time_str(timestamp, TIME_FORMAT_ISO8601))
    self.account.value = 0
    self.account.all_value = 0
    for position in self.account.positions:
        entity_type, _, _ = decode_entity_id(position.entity_id)
        data_schema = get_kdata_schema(entity_type,
                                       level=IntervalLevel.LEVEL_1DAY,
                                       adjust_type=self.adjust_type)
        # latest daily bar at or before `timestamp`
        kdata = get_kdata(provider=self.provider,
                          level=IntervalLevel.LEVEL_1DAY,
                          entity_id=position.entity_id,
                          order=data_schema.timestamp.desc(),
                          end_timestamp=timestamp,
                          limit=1,
                          adjust_type=self.adjust_type)
        # NOTE(review): assumes get_kdata returns a non-empty frame with a
        # 'close' column here — confirm; an empty result would raise
        closing_price = kdata['close'][0]

        # T+1 settlement: everything held becomes available next session
        position.available_long = position.long_amount
        position.available_short = position.short_amount

        if closing_price:
            if (position.long_amount is not None) and position.long_amount > 0:
                position.value = position.long_amount * closing_price
                self.account.value += position.value
            elif (position.short_amount is not None) and position.short_amount > 0:
                # short value = margin (2x entry) minus buy-back cost
                position.value = 2 * (position.short_amount *
                                      position.average_short_price)
                position.value -= position.short_amount * closing_price
                self.account.value += position.value
            # refresh profit
            # NOTE(review): profit/profit_rate use long-side fields only;
            # for a pure short position long_amount is 0, so profit_rate
            # divides by zero — confirm whether shorts can reach here
            position.profit = (closing_price - position.average_long_price
                               ) * position.long_amount
            position.profit_rate = position.profit / (
                position.average_long_price * position.long_amount)
        else:
            self.logger.warning(
                'could not refresh close value for position:{},timestamp:{}'
                .format(position.entity_id, timestamp))

        # re-key the position row for this closing snapshot
        position.id = '{}_{}_{}'.format(
            self.trader_name, position.entity_id,
            to_time_str(timestamp, TIME_FORMAT_ISO8601))
        position.timestamp = to_pd_timestamp(timestamp)
        position.account_stats_id = the_id

    self.account.id = the_id
    self.account.all_value = self.account.value + self.account.cash
    self.account.closing = True
    self.account.timestamp = to_pd_timestamp(timestamp)
    self.account.profit = (
        self.account.all_value - self.account.input_money) / self.account.input_money

    session = get_db_session(region=Region.CHN,
                             provider=Provider.ZVT,
                             data_schema=AccountStats)
    session.add(self.account)
    session.commit()

    account_info = f'on_trading_close,holding size:{len(self.account.positions)} profit:{self.account.profit} input_money:{self.account.input_money} ' \
                   f'cash:{self.account.cash} value:{self.account.value} all_value:{self.account.all_value}'
    self.logger.info(account_info)
def df_to_db(df: pd.DataFrame,
             ref_df: pd.DataFrame,
             region: Region,
             data_schema: DeclarativeMeta,
             provider: Provider,
             drop_duplicates: bool = True,
             fix_duplicate_way: str = 'ignore',
             force_update=False) -> object:
    """Persist a DataFrame into data_schema's table.

    :param df: rows to store; must contain an 'id' column
    :param ref_df: existing rows (id-indexed) used to skip duplicates; queried
                   from db when None and force_update is off
    :param force_update: delete matching ids first and rewrite them
    :returns: number of rows written (0 when nothing to do)
    """
    step1 = time.time()
    precision_str = '{' + ':>{},.{}f'.format(8, 4) + '}'

    if not pd_is_not_null(df):
        return 0

    if drop_duplicates and df.duplicated(subset='id').any():
        df.drop_duplicates(subset='id', keep='last', inplace=True)

    # keep only columns the schema actually has
    schema_cols = get_schema_columns(data_schema)
    cols = list(set(df.columns.tolist()) & set(schema_cols))
    if not cols:
        logger.error("{} get columns failed".format(data_schema.__tablename__))
        return 0
    df = df[cols]

    # force update mode, delete duplicate id data in db, and write new data back
    if force_update:
        # FIX: delete via the ORM instead of interpolating ids into raw SQL —
        # the old f-string broke on quoted id values and produced invalid SQL
        # for a single-element tuple, and was an injection hazard.
        ids = df["id"].tolist()
        session = get_db_session(region=region, provider=provider,
                                 data_schema=data_schema)
        session.query(data_schema).filter(
            data_schema.id.in_(ids)).delete(synchronize_session=False)
        session.commit()
        df_new = df
    else:
        if ref_df is None:
            ref_df = get_data(region=region,
                              provider=provider,
                              columns=['id', 'timestamp'],
                              data_schema=data_schema,
                              return_type='df')
        if ref_df.empty:
            df_new = df
        else:
            df_new = df[~df.id.isin(ref_df.index)]
        # id alone is not enough to decide inserts; a full comparison would be
        # needed to support fix_duplicate_way == 'add' (intentionally disabled)

    cost = precision_str.format(time.time() - step1)
    logger.debug("remove duplicated: {}".format(cost))

    saved = 0
    if pd_is_not_null(df_new):
        saved = to_postgresql(region, df_new, data_schema.__tablename__)

    cost = precision_str.format(time.time() - step1)
    logger.debug("write db: {}, size: {}".format(cost, saved))
    return saved