def get_code_list_by_types(instrument_types: list, all_if_none=True) -> list:
    """Return every contract matching the given instrument types.

    :param instrument_types: a list of instrument_type values, or a list of
        (instrument_type, exchange) tuples
    :param all_if_none: when instrument_types is None, return all contracts
    :return: list of (order_book_id, exchange, symbol) tuples
    """
    code_list = []
    if all_if_none and instrument_types is None:
        # no filter requested: dump the whole rqdatac_future_info table
        sql_str = "SELECT order_book_id, `exchange`, trading_code as symbol FROM rqdatac_future_info"
        with with_db_session(engine_md) as session:
            rows = session.execute(sql_str).fetchall()
        code_list.extend((order_book_id, exchange, symbol)
                         for order_book_id, exchange, symbol in rows)
    else:
        for item in instrument_types:
            # each entry is either a bare instrument_type or an
            # (instrument_type, exchange) pair; only the type is used in the query
            instrument_type = item[0] if isinstance(item, tuple) else item
            sql_str = "select order_book_id, `exchange`, trading_code as symbol from rqdatac_future_info " \
                      "where underlying_symbol=:instrument_type"
            with with_db_session(engine_md) as session:
                rows = session.execute(
                    sql_str, params={"instrument_type": instrument_type}).fetchall()
            code_list.extend((order_book_id, exchange, symbol)
                             for order_book_id, exchange, symbol in rows)
    return code_list
def fill_col_by_wss(col_name_dic, table_name):
    """Back-fill historical column values via the wind wss interface.

    :param col_name_dic: mapping of wind indicator name -> db column name
    :param table_name: target table whose rows (keyed by wind_code) are updated
    :return:
    """
    col_name_list = [col_name.lower() for col_name in col_name_dic.keys()]
    # collect every wind_code already present in the target table
    sql_str = """select wind_code from %s""" % table_name
    with with_db_session(engine_md) as session:
        table = session.execute(sql_str)
        wind_code_set = {content[0] for content in table.fetchall()}
    data_count = len(wind_code_set)
    data_df_list = []
    try:
        for data_num, wind_code in enumerate(wind_code_set, start=1):
            # fetch the indicator values for this code
            data_df = invoker.wss(wind_code, col_name_list)
            if data_df is None:
                # BUG FIX: format string previously had 4 placeholders but only
                # 2 arguments, which made the logging call itself fail
                logger.warning('%d) %s has no data', data_num, wind_code)
                continue
            logger.debug('%d/%d) 获取 %s', data_num, data_count, wind_code)
            data_df_list.append(data_df)
    finally:
        # flush whatever was collected so far, even on error/interrupt
        if len(data_df_list) > 0:
            data_df_all = pd.concat(data_df_list)
            data_df_all.index.rename('wind_code', inplace=True)
            data_df_all.reset_index(inplace=True)
            # drop only the rows where EVERY requested column is null
            is_na_s = None
            for col_name in col_name_dic.keys():
                col_name = col_name.upper()
                if is_na_s is None:
                    is_na_s = data_df_all[col_name].isna()
                else:
                    is_na_s = is_na_s & data_df_all[col_name].isna()
            data_df_not_null = data_df_all[~is_na_s]
            data_df_not_null.fillna('null', inplace=True)
            data_dic_list = data_df_not_null.to_dict(orient='records')
            # one UPDATE per record, bound by wind_code
            sql_str = "update %s set " % table_name + \
                      ",".join(["%s=:%s" % (db_col_name, col_name.upper())
                                for col_name, db_col_name in col_name_dic.items()]) + \
                      " where wind_code=:wind_code"
            with with_db_session(engine_md) as session:
                session.execute(sql_str, params=data_dic_list)
            logger.info('%d data updated', data_df_not_null.shape[0])
        else:
            logger.warning('no data for update')
def get_wind_code_list_by_types(
        instrument_types: list, all_if_none=True,
        lasttrade_date_lager_than_n_days_before: Optional[int] = 30) -> list:
    """Return all wind contract codes for the given instrument types.

    :param instrument_types: a list of instrument_type values, or a list of
        (instrument_type, exchange) tuples
    :param all_if_none: when instrument_types is None, return all contracts
    :param lasttrade_date_lager_than_n_days_before: only return contracts whose
        last trade date is later than N days ago; None disables the filter
    :return: wind_code_list
    """
    wind_code_list = []
    if all_if_none and instrument_types is None:
        sql_str = "select wind_code from wind_future_info"
        with with_db_session(engine_md) as session:
            if lasttrade_date_lager_than_n_days_before is not None:
                date_from_str = date_2_str(date.today() - timedelta(
                    days=lasttrade_date_lager_than_n_days_before))
                sql_str += " where lasttrade_date > :last_trade_date"
                table = session.execute(
                    sql_str, params={"last_trade_date": date_from_str})
            else:
                table = session.execute(sql_str)
            for row in table.fetchall():
                wind_code_list.append(row[0])
    else:
        for instrument_type in instrument_types:
            if isinstance(instrument_type, tuple):
                instrument_type, exchange = instrument_type
            else:
                exchange = None
            # MySQL REGEXP: match e.g. 'RB2101.SHF'
            # BUG FIX: POSIX character classes must be wrapped in a bracket
            # expression — [[:digit:]] / [[:alpha:]]; the bare [:digit:] form
            # is a regular-expression syntax error in MySQL.
            # 参考链接: https://blog.csdn.net/qq_22238021/article/details/80929518
            sql_str = f"select wind_code from wind_future_info where wind_code " \
                      f"REGEXP '^{instrument_type}[[:digit:]]+.{'[[:alpha:]]+' if exchange is None else exchange}'"
            with with_db_session(engine_md) as session:
                if lasttrade_date_lager_than_n_days_before is not None:
                    date_from_str = date_2_str(date.today() - timedelta(
                        days=lasttrade_date_lager_than_n_days_before))
                    sql_str += " and lasttrade_date > :last_trade_date"
                    table = session.execute(
                        sql_str, params={"last_trade_date": date_from_str})
                else:
                    table = session.execute(sql_str)
                for row in table.fetchall():
                    wind_code_list.append(row[0])
    return wind_code_list
def import_tushare_adj_factor(chain_param=None, ):
    """
    Insert daily adj-factor data up to the most recent business day - 1.
    If the current time is past BASE_LINE_HOUR, fetch today's data as well.
    :return:
    """
    table_name = 'tushare_stock_daily_adj_factor'
    logging.info("更新 %s 开始", table_name)
    # table existence decides incremental vs full date range below
    has_table = engine_md.has_table(table_name)
    # NOTE: be careful which table each sub-query references — mixing tables
    # up here scrambles the computed date ranges
    if has_table:
        # incremental: only trade dates later than the newest stored trade_date
        sql_str = """ select cal_date FROM ( select * from tushare_trade_date trddate where( cal_date>(SELECT max(trade_date) FROM {table_name})) )tt where (is_open=1 and cal_date <= if(hour(now())<16, subdate(curdate(),1), curdate()) and exchange='SSE') """.format(table_name=table_name)
    else:
        # full history of open SSE trade dates up to yesterday (or today after 16:00)
        sql_str = """ SELECT cal_date FROM tushare_trade_date trddate WHERE (trddate.is_open=1 AND cal_date <= if(hour(now())<16, subdate(curdate(),1), curdate()) AND exchange='SSE') ORDER BY cal_date"""
        logger.warning('%s 不存在,仅使用 tushare_stock_info 表进行计算日期范围', table_name)
    with with_db_session(engine_md) as session:
        # fetch the list of trade dates still to be imported
        table = session.execute(sql_str)
        trddate = list(row[0] for row in table.fetchall())
    try:
        for i in range(len(trddate)):
            trade_date = datetime_2_str(trddate[i], STR_FORMAT_DATE_TS)
            # pull the full-market adj factors for this single trade date
            data_df = pro.adj_factor(ts_code='', trade_date=trade_date)
            if len(data_df) > 0:
                data_count = bunch_insert_on_duplicate_update(data_df, table_name, engine_md, DTYPE_TUSHARE_STOCK_DAILY_ADJ_FACTOR)
                logging.info(" %s 表自 %s 日起的 %d 条信息被更新", table_name, trade_date, data_count)
            else:
                logging.info("无数据信息可被更新")
    finally:
        # if this run created the table, convert it to MyISAM and add the PK
        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            # build_primary_key([table_name])
            create_pk_str = """ALTER TABLE {table_name} CHANGE COLUMN `ts_code` `ts_code` VARCHAR(20) NOT NULL FIRST, CHANGE COLUMN `trade_date` `trade_date` DATE NOT NULL AFTER `ts_code`, ADD PRIMARY KEY (`ts_code`, `trade_date`)""".format(table_name=table_name)
            with with_db_session(engine_md) as session:
                session.execute(create_pk_str)
            logger.info('%s 表 `ts_code`, `trade_date` 主键设置完成', table_name)
def wind_future_daily_2_model_server(chain_param=None, instrument_types=None):
    """Incrementally copy wind_future_daily rows from the md database to the
    model-server database, one contract at a time.

    :param chain_param: task.chain pass-through parameter (unused here)
    :param instrument_types: restrict to these instrument types; None means all
    """
    from tasks.config import config
    from tasks.backend import engine_dic
    table_name = 'wind_future_daily'
    engine_model_server = engine_dic[config.DB_SCHEMA_MODEL]
    # the target table must already exist on the model server
    has_table = engine_model_server.has_table(table_name)
    if not has_table:
        logger.error('当前数据库 %s 没有 %s 表,建议使用先建立相应的数据库表后再进行导入操作', engine_model_server, table_name)
        return
    wind_code_list = get_wind_code_list_by_types(instrument_types)
    wind_code_count = len(wind_code_list)
    for n, wind_code in enumerate(wind_code_list, start=1):
        # symbol, exchange = wind_code.split('.')
        # newest trade_date already on the model server for this contract
        sql_str = f"select max(trade_date) from wind_future_daily where wind_code = :wind_code"
        with with_db_session(engine_model_server) as session:
            trade_date_max = session.scalar(sql_str, params={'wind_code': wind_code})
        # read the daily bars from the md database (full history or only the
        # dates after trade_date_max); rows with close == 0 are skipped
        if trade_date_max is None:
            sql_str = "select * from wind_future_daily where wind_code = %s and `close` <> 0"
            df = pd.read_sql(sql_str, engine_md, params=[wind_code]).dropna()
        else:
            sql_str = "select * from wind_future_daily where wind_code = %s and trade_date > %s and `close` <> 0"
            df = pd.read_sql(sql_str, engine_md, params=[wind_code, trade_date_max]).dropna()
        df_len = df.shape[0]
        if df_len == 0:
            continue
        # append the new rows to the model-server table
        df.to_sql(table_name, engine_model_server, if_exists='append', index=False)
        logger.info("%d/%d) %s %d data -> %s", n, wind_code_count, wind_code, df.shape[0], table_name)
def stg_run_ending(self):
    """
    Handle end-of-strategy-run housekeeping:
    release strategy resources, then persist the run-completion info
    (StgRunInfo.dt_to and the collected status-detail records).
    :return:
    """
    self.stg_base.release()
    # persist the run end time (dt_to field) in the database
    with with_db_session(engine_ibats) as session:
        session.query(StgRunInfo).filter(
            StgRunInfo.stg_run_id == self.stg_run_id).update(
            {StgRunInfo.dt_to: datetime.now()})
        try:
            # bulk-save all accumulated status details in the same transaction
            session.bulk_save_objects(self.stg_run_status_detail_list)
            # sql_str = StgRunInfo.update().where(
            #     StgRunInfo.c.stg_run_id == self.stg_run_id).values(dt_to=datetime.now())
            # session.execute(sql_str)
            session.commit()
            self.logger.debug("%d 条 stg_run_status_detail 被保存",
                              len(self.stg_run_status_detail_list))
        except SQLAlchemyError:
            # rollback also discards the dt_to update above
            logger.exception("%d 条 stg_run_status_detail 被保存时发生异常",
                             len(self.stg_run_status_detail_list))
            session.rollback()
    # mark the runner as finished regardless of the save outcome
    self.is_working = False
    self.is_done = True
def merge_latest(chain_param=None, ):
    """
    Merge cmc_coin_v1_daily history and cmc_coin_pro_latest latest prices
    into cmc_coin_merged_latest. The target table is truncated and fully
    rebuilt on every run.
    :return:
    """
    table_name = 'cmc_coin_merged_latest'
    logger.info("开始合并数据到 %s 表", table_name)
    has_table = engine_md.has_table(table_name)
    create_sql_str = """CREATE TABLE {table_name} ( `id` VARCHAR(60) NOT NULL, `date` DATE NOT NULL, `datetime` DATETIME NULL, `name` VARCHAR(60) NULL, `symbol` VARCHAR(20) NULL, `close` DOUBLE NULL, `volume` DOUBLE NULL, `market_cap` DOUBLE NULL, PRIMARY KEY (`id`, `date`)) ENGINE = MyISAM""".format(table_name=table_name)
    with with_db_session(engine_md) as session:
        # create the target table on first run, then always rebuild from scratch
        if not has_table:
            session.execute(create_sql_str)
            logger.info("创建 %s 表", table_name)
        session.execute('truncate table {table_name}'.format(table_name=table_name))
        # step 1: copy the daily history, joining names/symbols from v1_info
        insert_sql_str = """INSERT INTO `{table_name}` (`id`, `date`, `datetime`, `name`, `symbol`, `close`, `volume`, `market_cap`) select daily.id, `date`, `date`, `name`, `symbol`, `close`, `volume`, `market_cap` from cmc_coin_v1_daily daily left join cmc_coin_v1_info info on daily.id = info.id""".format(table_name=table_name)
        session.execute(insert_sql_str)
        session.commit()
        # step 2: overlay the most recent snapshot per (name, symbol) from
        # cmc_coin_pro_latest; existing (id, date) rows are updated in place
        insert_latest_sql_str = """INSERT INTO `{table_name}` (`id`, `date`, `datetime`, `name`, `symbol`, `close`, `volume`, `market_cap`) select info.id, date(latest.last_updated), latest.last_updated, latest.name, latest.symbol, price, volume_24h, market_cap from cmc_coin_pro_latest latest left join ( select latest.name, latest.symbol, max(latest.last_updated) last_updated from cmc_coin_pro_latest latest group by latest.name, latest.symbol ) g on latest.name = g.name and latest.symbol = g.symbol and latest.last_updated = g.last_updated left outer join cmc_coin_v1_info info on latest.name = info.name and latest.symbol = info.symbol on duplicate key update `datetime`=values(`datetime`), `name`=values(`name`), `symbol`=values(`symbol`), `close`=values(`close`), `volume`=values(`volume`), `market_cap`=values(`market_cap`)""".format(table_name=table_name)
        session.execute(insert_latest_sql_str)
        session.commit()
        # report the final row count of the merged table
        data_count = session.execute("select count(*) from {table_name}".format(table_name=table_name)).scalar()
        logger.info("%d 条记录插入到 %s", data_count, table_name)
def _record_order_detail(self, symbol, price: float, vol: int, direction: Direction, action: Action):
    """Create an OrderDetail for the current bar, optionally persist it, and
    route it to the fill logic according to the backtest trade mode."""
    now_ts = self.curr_timestamp
    detail = OrderDetail(
        stg_run_id=self.stg_run_id,
        trade_agent_key=self.agent_name,
        order_date=now_ts.date(),
        order_time=now_ts.time(),
        order_millisec=0,
        direction=int(direction),
        action=int(action),
        symbol=symbol,
        order_price=float(price),
        order_vol=int(vol),
        calc_mode=self.calc_mode,
    )
    # optionally persist each order as soon as it is created
    if config.BACKTEST_UPDATE_OR_INSERT_PER_ACTION:
        with with_db_session(engine_ibats, expire_on_commit=False) as session:
            session.add(detail)
            session.commit()
    # track the order globally and per symbol
    self.order_detail_list.append(detail)
    self._order_detail_dic.setdefault(symbol, []).append(detail)
    # Order_2_Deal mode fills immediately; otherwise the order stays pending
    if self.trade_mode == BacktestTradeMode.Order_2_Deal:
        self._record_trade_detail(detail)
    else:
        self.un_finished_order_list.append(detail)
def get_stg_run_id_latest():
    """Return the largest stg_run_id recorded in StgRunInfo."""
    engine_ibats = engines.engine_ibats
    with with_db_session(engine_ibats) as session:
        return session.query(func.max(StgRunInfo.stg_run_id)).scalar()
def add_new_col_data(col_name, param, chain_param=None, db_col_name=None, col_type_str='DOUBLE',
                     ths_code_set: set = None):
    """Add an indicator column to ifind_stock_hk_daily_ds and populate it.

    Steps: 1) alter the daily table to add the column; 2) incrementally store
    values in the ckdvp staging table; 3) once staging is complete, copy the
    staged values into the daily table.

    :param col_name: indicator name to add
    :param param: indicator parameter
    :param chain_param: only used to pass context along a task.chain
    :param db_col_name: database column name; defaults to col_name
    :param col_type_str: column type, e.g. DOUBLE, VARCHAR(20), INTEGER (case-insensitive)
    :param ths_code_set: restrict the update to these ths_code values; None = all
    :return:
    """
    table_name = 'ifind_stock_hk_daily_ds'
    db_col_name = col_name if db_col_name is None else db_col_name
    # make sure the target column exists before writing to it
    add_col_2_table(engine_md, table_name, db_col_name, col_type_str)
    # incrementally stage values into the ckdvp table; staging may span runs
    all_finished = add_data_2_ckdvp(col_name, param, ths_code_set)
    if not all_finished:
        return
    # staging done: copy the staged values into the ds table
    sql_str = """update {table_name} daily, ifind_ckdvp_stock_hk ckdvp
        set daily.{db_col_name} = ckdvp.value
        where daily.ths_code = ckdvp.ths_code
        and daily.time = ckdvp.time
        and ckdvp.key = '{db_col_name}'
        and ckdvp.param = '{param}'""".format(db_col_name=db_col_name, param=param, table_name=table_name)
    with with_db_session(engine_md) as session:
        session.execute(sql_str)
        session.commit()
    logger.info('更新 %s 字段 %s 表', db_col_name, table_name)
def tushare_to_sqlite_pre_ts_code(file_name, table_name, field_pair_list):
    """
    Export a MySQL table to sqlite, writing one sqlite table per ts_code.
    Loads each code's data fully into memory before writing — slow and
    memory-hungry, so only use it for small data volumes.

    :param file_name: sqlite database file name (created under the sqlite_db folder)
    :param table_name: source MySQL table
    :param field_pair_list: optional list of (source_field, target_field) pairs
        used to select and rename columns; None keeps all columns
    :return:
    """
    logger.info('mysql %s 导入到 sqlite %s 开始', table_name, file_name)
    sqlite_db_folder_path = get_folder_path('sqlite_db', create_if_not_found=False)
    db_file_path = os.path.join(sqlite_db_folder_path, file_name)
    conn = sqlite3.connect(db_file_path)
    try:
        # distinct ts_code values drive the per-code export loop
        sql_str = f"select ts_code from {table_name} group by ts_code"
        with with_db_session(engine_md) as session:
            table = session.execute(sql_str)
            code_list = [row[0] for row in table.fetchall()]
        code_count, data_count = len(code_list), 0
        for num, ts_code in enumerate(code_list, start=1):
            # '000001.XSHE' -> sqlite table 'XSHE000001'
            code, exchange = ts_code.split('.')
            sqlite_table_name = f"{exchange}{code}"
            sql_str = f"select * from {table_name} where ts_code=%s"
            df = pd.read_sql(sql_str, engine_md, params=[ts_code])
            # BUG FIX: this previously referenced an undefined variable
            # (df_tot) and ran even when field_pair_list was None, raising
            # NameError at runtime
            if field_pair_list is not None:
                field_list = [_[0] for _ in field_pair_list]
                field_list.append('ts_code')
                df = df[field_list].rename(columns=dict(field_pair_list))
            df_len = df.shape[0]
            data_count += df_len
            logger.debug('%4d/%d) mysql %s -> sqlite %s %s %d 条记录',
                         num, code_count, table_name, file_name, sqlite_table_name, df_len)
            df.to_sql(sqlite_table_name, conn, index=False, if_exists='replace')
    finally:
        # BUG FIX: the sqlite connection was never closed
        conn.close()
    logger.info('mysql %s 导入到 sqlite %s 结束,导出数据 %d 条', table_name, file_name, data_count)
def load_by_wind_code_desc(instrument_types):
    """Import future daily data grouped by contract year-month, newest first.

    For each (instrument_type, exchange) pair, collect the contract codes whose
    names match '<type><YYMM>.<exchange>', group them by the YYMM part, and
    import each year-month batch in descending order.
    """
    code_pool, year_month_pool = set(), set()
    for instrument_type, exchange in instrument_types:
        # e.g. re.search(r"(?<=RB)\d{4}(?=\.SHF)", 'RB2101.SHF')
        pattern = re.compile(r"(?<=" + instrument_type + r")\d{4}(?=\." + exchange + ")")
        sql_str = f"""select wind_code from wind_future_info where wind_code like '{instrument_type}%.{exchange}'"""
        with with_db_session(engine_md) as session:
            rows = session.execute(sql_str).fetchall()
        # keep only codes with an explicit YYMM part; record the YYMM values
        for row in rows:
            wind_code = row[0]
            match = pattern.search(wind_code)
            if match is not None:
                code_pool.add(wind_code)
                year_month_pool.add(match.group())
    # newest year-month first
    for year_month in sorted(year_month_pool, reverse=True):
        batch = {
            f'{instrument_type}{year_month}.{exchange}'
            for instrument_type, exchange in instrument_types
            if f'{instrument_type}{year_month}.{exchange}' in code_pool
        }
        import_future_daily(None, wind_code_set=batch)
def get_all_instrument_type():
    """Return the distinct instrument types found in ifind_future_daily,
    derived from the non-digit prefix of each ths_code."""
    sql_str = "select ths_code from ifind_future_daily group by ths_code"
    with with_db_session(engine_md) as session:
        code_list = [row[0] for row in session.execute(sql_str).fetchall()]
    # non-digit prefix followed by a 3-digit sequence, e.g. 'rb' in 'rb2101'
    pattern = re.compile(r'\D+(?=\d{3})', re.IGNORECASE)
    return list({pattern.search(code).group() for code in code_list})
def get_main_secondary_contract_by_instrument_types(instrument_types=None):
    """Return the current main and secondary contract codes.

    Reads the most recent row per instrument type from
    wind_future_continuous_adj and collects its Contract / ContractNext values.

    :param instrument_types: restrict to these instrument types; None or empty means all
    :return: flat list of contract codes (main then secondary per type)
    """
    if instrument_types is None or len(instrument_types) == 0:
        sql_str = """SELECT t.instrument_type, t.Contract, t.ContractNext 
            FROM wind_future_continuous_adj t inner join ( 
            SELECT instrument_type, max(trade_date) trade_date_max 
            FROM wind_future_continuous_adj group by instrument_type ) latest 
            on t.instrument_type = latest.instrument_type and t.trade_date = latest.trade_date_max"""
    else:
        sql_str = f"""SELECT t.instrument_type, t.Contract, t.ContractNext 
            FROM wind_future_continuous_adj t inner join ( 
            SELECT instrument_type, max(trade_date) trade_date_max 
            FROM wind_future_continuous_adj 
            where instrument_type in ('{"','".join(instrument_types)}') 
            group by instrument_type ) latest 
            on t.instrument_type = latest.instrument_type and t.trade_date = latest.trade_date_max"""
    contract_codes = []
    with with_db_session(engine_md) as session:
        for _, main_contract, secondary_contract in session.execute(sql_str).fetchall():
            if main_contract is not None:
                contract_codes.append(main_contract)
            if secondary_contract is not None:
                contract_codes.append(secondary_contract)
    return contract_codes
def import_coin_info(chain_param=None):
    """Download the global coin list from tushare and upsert it into tushare_coin_info."""
    table_name = 'tushare_coin_info'
    has_table = engine_md.has_table(table_name)
    # column types for the bulk upsert
    dtype = {
        'coin': String(60),
        'en_name': String(60),
        'cn_name': String(60),
        'issue_date': Date,
        'amount': DOUBLE,
    }
    end_date_str = date_2_str(date.today(), DATE_FORMAT_STR)
    df = pro.coinlist(start_date='20170101', end_date=end_date_str)
    insert_count = bunch_insert_on_duplicate_update(df, table_name, engine_md, dtype)
    logging.info("更新 %s 完成 新增数据 %d 条", table_name, insert_count)
    # if this run created the table, convert it to MyISAM and add the PK
    if not has_table and engine_md.has_table(table_name):
        alter_table_2_myisam(engine_md, [table_name])
        create_pk_str = """ALTER TABLE {table_name} CHANGE COLUMN `coin` `coin` VARCHAR(60) NOT NULL FIRST, CHANGE COLUMN `en_name` `en_name` VARCHAR(60) NOT NULL AFTER `coin`, ADD PRIMARY KEY (`coin`, `en_name`)""".format(
            table_name=table_name)
        with with_db_session(engine_md) as session:
            session.execute(create_pk_str)
def get_exchange_latest_data():
    """Return a dict mapping exch_eng -> latest ths_start_trade_date_future
    found in ifind_future_info."""
    sql_str = """SELECT exch_eng, max(ths_start_trade_date_future) FROM ifind_future_info group by exch_eng"""
    with with_db_session(engine_md) as session:
        return dict(session.execute(sql_str).fetchall())
def update_df_2_db(instrument_type, table_name, data_df, dtype=None):
    """Persist a continuous-contract adjustment DataFrame into the given table.

    Any previously stored rows for the instrument type are removed first, then
    the whole DataFrame is bulk-inserted.
    """
    # cast to plain float to avoid
    # "AttributeError: 'numpy.float64' object has no attribute 'translate'"
    for col in ("adj_factor_main", "adj_factor_secondary"):
        data_df[col] = data_df[col].apply(str_2_float)
    # drop the old rows for this instrument type (only if the table exists)
    with with_db_session(engine_md) as session:
        sql_str = "SELECT table_name FROM information_schema.TABLES " \
                  "WHERE table_name = :table_name and TABLE_SCHEMA=(select database())"
        existing = session.execute(sql_str, params={"table_name": table_name}).fetchone()
        if existing is not None:
            session.execute(
                "delete from %s where instrument_type = :instrument_type" % table_name,
                params={"instrument_type": instrument_type})
            logger.debug("删除 %s 中的 %s 历史数据,重新载入新的复权数据", table_name, instrument_type)
    # bulk-insert the fresh adjustment data
    bunch_insert_on_duplicate_update(
        data_df, table_name, engine_md, dtype=dtype, myisam_if_create_table=True,
        primary_keys=['trade_date', 'instrument_type'], schema=config.DB_SCHEMA_MD)
def check_contract_has_no_missing(instrument_type):
    """
    Check whether every delisted contract of the given future type listed in
    tushare_future_basic has matching rows in tushare_future_daily_md.
    :param instrument_type: future type code (fut_code)
    :return: dict mapping missing ts_code -> delist_date
    """
    sql_str = r"""select t1.ts_code, t1.delist_date from ( select ts_code, delist_date from tushare_future_basic where fut_code = :fut_code and delist_date is not null ) t1 left join ( select distinct ts_code from tushare_future_daily_md where ts_code in ( select ts_code from tushare_future_basic where fut_code=:fut_code and delist_date is not null) ) t2 on t1.ts_code=t2.ts_code where t2.ts_code is null order by delist_date"""
    with with_db_session(engine_md) as session:
        rows = session.execute(sql_str, params={"fut_code": instrument_type}).fetchall()
    missing = {}
    for ts_code, delist_date in rows:
        logger.info('缺少 %s 合约数据,交割日 %s', ts_code, delist_date)
        missing[ts_code] = delist_date
    return missing
def save(self):
    """Fetch fundamentals for each pending trade date and upsert them into
    self.table_name, using jq_trade_date to determine the date range."""
    self.logger.info("更新 %s 开始", self.table_name)
    # does the target table already exist? (incremental vs full range)
    has_table = engine_md.has_table(self.table_name)
    if has_table:
        with with_db_session(engine_md) as session:
            # trade dates newer than the latest day already stored
            sql_str = f"""select trade_date from jq_trade_date where trade_date>(select max(day) from {self.table_name}) order by trade_date"""
            # NOTE(review): this SQL has no :trade_date placeholder, so the
            # params argument appears ineffective here — confirm intent
            table = session.execute(sql_str, params={"trade_date": self.BASE_DATE})
            trade_date_list = [_[0] for _ in table.fetchall()]
            # NOTE(review): re-runs the same query only to log its first value
            date_start = execute_scalar(sql_str, engine_md)
            self.logger.info('查询 %s 数据使用起始日期 %s', self.table_name, date_2_str(date_start))
    else:
        with with_db_session(engine_md) as session:
            # full range: all trade dates from the configured base date onward
            sql_str = "select trade_date from jq_trade_date where trade_date>=:trade_date order by trade_date"
            table = session.execute(sql_str, params={"trade_date": self.BASE_DATE})
            trade_date_list = [_[0] for _ in table.fetchall()]
            self.logger.warning('%s 不存在,使用基础日期 %s', self.table_name, self.BASE_DATE)
    # process the dates in ascending order
    trade_date_list.sort()
    data_count_tot, for_count = 0, len(trade_date_list)
    try:
        for num, trade_date in enumerate(trade_date_list):
            # query the fundamentals snapshot for this trade date
            q = query(self.statement)
            df = get_fundamentals(q, date=date_2_str(trade_date))
            if df is None or df.shape[0] == 0:
                continue
            logger.debug('%d/%d) %s 包含 %d 条数据', num, for_count, trade_date, df.shape[0])
            data_count = bunch_insert_on_duplicate_update(
                df, self.table_name, engine_md, dtype=self.dtype,
                myisam_if_create_table=True, primary_keys=['id'],
                schema=config.DB_SCHEMA_MD)
            data_count_tot += data_count
    except:
        logger.exception("更新 %s 异常", self.table_name)
    finally:
        # always report how much was written
        logging.info("更新 %s 结束 %d 条信息被更新", self.table_name, data_count_tot)
def import_tushare_adj_factor(chain_param=None, ):
    """
    Insert daily adj-factor data up to the most recent business day - 1.
    If the current time is past BASE_LINE_HOUR, fetch today's data as well.
    :return:
    """
    table_name = 'tushare_stock_daily_adj_factor'
    primary_keys = ["ts_code", "trade_date"]
    logging.info("更新 %s 开始", table_name)
    # table existence decides incremental vs full date range below
    has_table = engine_md.has_table(table_name)
    # sqlite_file_name = 'eDB_adjfactor.db'
    check_sqlite_db_primary_keys(table_name, primary_keys)
    if has_table:
        # incremental: only trade dates newer than the latest stored trade_date
        sql_str = """ select cal_date FROM ( select * from tushare_trade_date trddate where( cal_date>(SELECT max(trade_date) FROM {table_name})) )tt where (is_open=1 and cal_date <= if(hour(now())<16, subdate(curdate(),1), curdate()) and exchange='SSE') """.format(table_name=table_name)
    else:
        # full history of open SSE trade dates up to yesterday (or today after 16:00)
        sql_str = """ SELECT cal_date FROM tushare_trade_date trddate WHERE (trddate.is_open=1 AND cal_date <= if(hour(now())<16, subdate(curdate(),1), curdate()) AND exchange='SSE') ORDER BY cal_date"""
        logger.warning('%s 不存在,仅使用 tushare_stock_info 表进行计算日期范围', table_name)
    with with_db_session(engine_md) as session:
        # fetch the trade dates still to be imported
        table = session.execute(sql_str)
        trade_date_list = [row[0] for row in table.fetchall()]
    trade_date_count, data_count_tot = len(trade_date_list), 0
    try:
        for num, trade_date in enumerate(trade_date_list, start=1):
            trade_date = datetime_2_str(trade_date, STR_FORMAT_DATE_TS)
            # full-market adj factors for one trade date
            data_df = pro.adj_factor(ts_code='', trade_date=trade_date)
            if data_df is not None and data_df.shape[0] > 0:
                data_count = bunch_insert(
                    data_df, table_name=table_name,
                    dtype=DTYPE_TUSHARE_STOCK_DAILY_ADJ_FACTOR,
                    primary_keys=primary_keys)
                data_count_tot += data_count
                logging.info("%d/%d) %s 表 %s %d 条信息被更新",
                             num, trade_date_count, table_name, trade_date, data_count)
            else:
                logging.info("%d/%d) %s 表 %s 数据信息可被更新",
                             num, trade_date_count, table_name, trade_date)
    except:
        logger.exception("更新 %s 异常", table_name)
    finally:
        # always report the total written so far
        logging.info("%s 表 %d 条记录更新完成", table_name, data_count_tot)
def get_rr_with_md(stg_run_id, compound_rr=True):
    """
    Fetch the strategy return curve and join it with the market data.

    :param stg_run_id: strategy run id
    :param compound_rr: use compound returns (rr_compound) instead of simple ones
    :return: (sum_df, symbol_rr_dic); both None when no return data exists
    """
    engine_ibats = engines.engine_ibats
    # build the return-curve query; compound vs simple only changes the rr columns
    with with_db_session(engine_ibats) as session:
        if compound_rr:
            sql_str = str(
                session.query(
                    StgRunStatusDetail.trade_dt.label('trade_dt'),
                    StgRunStatusDetail.cash_and_margin.label('cash and margin'),
                    (StgRunStatusDetail.cash_and_margin.label('cash_and_margin') +
                     StgRunStatusDetail.commission_tot.label('commission_tot')).label('no_commission'),
                    StgRunStatusDetail.rr_compound.label('rr'),
                    StgRunStatusDetail.rr_compound_nc.label('rr no commission'),
                ).filter(
                    StgRunStatusDetail.stg_run_id == stg_run_id
                )
            )
        else:
            sql_str = str(
                session.query(
                    StgRunStatusDetail.trade_dt.label('trade_dt'),
                    StgRunStatusDetail.cash_and_margin.label('cash and margin'),
                    (StgRunStatusDetail.cash_and_margin.label('cash_and_margin') +
                     StgRunStatusDetail.commission_tot.label('commission_tot')).label('no_commission'),
                    StgRunStatusDetail.rr.label('rr'),
                    StgRunStatusDetail.rr_nc.label('rr no commission'),
                ).filter(
                    StgRunStatusDetail.stg_run_id == stg_run_id
                )
            )
    rr_df = pd.read_sql(sql_str, engine_ibats, params=[stg_run_id]).set_index('trade_dt')
    if rr_df is None or rr_df.shape[0] == 0:
        return None, None
    # rr_df['rr'] = rr_df['cash and margin'] / rr_df['cash and margin'].iloc[0]
    # rr_df['rr without commission'] = rr_df['without commission'] / rr_df['without commission'].iloc[0]
    # shift returns to growth factors (1 + r)
    col_list_rr = ['rr', 'rr no commission']
    rr_df[col_list_rr] += 1
    # join the market data of each traded symbol
    sum_df, symbol_rr_dic = get_md(stg_run_id)
    sum_df = sum_df.join(rr_df[col_list_rr])
    col_list = ['md_rr']
    # BUG FIX: was col_list.extend(col_list), which duplicated 'md_rr' and
    # dropped the return columns from the per-symbol frames
    col_list.extend(col_list_rr)
    for num, (key, (df, close_key)) in enumerate(symbol_rr_dic.items()):
        md_df = df[[close_key]].copy()
        # normalize the close price to a growth factor starting at 1
        md_df['md_rr'] = md_df[close_key] / md_df[close_key].iloc[0]
        md_df = md_df.join(rr_df)[col_list]
        symbol_rr_dic[key] = (md_df, close_key)
    return sum_df, symbol_rr_dic
def get_all_instrument_type():
    """Return the list of distinct future type codes (fut_code) from tushare_future_basic."""
    sql_str = "select fut_code from tushare_future_basic group by fut_code"
    with with_db_session(engine_md) as session:
        rows = session.execute(sql_str).fetchall()
    return [row[0] for row in rows]
def import_tushare_namechange(chain_param=None):
    """
    Insert stock name-change records up to the most recent business day - 1.
    If the current time is past BASE_LINE_HOUR, fetch today's data as well.
    :return:
    """
    table_name = 'tushare_stock_namechange'
    logging.info("更新 %s 开始", table_name)
    has_table = engine_md.has_table(table_name)
    # incremental start date: newest record already stored, otherwise the
    # earliest listing date from stock info
    if has_table:
        sql_str = """select max(start_date) start_date FROM md_integration.tushare_stock_namechange"""
    else:
        sql_str = """select min(list_date) start_date FROM md_integration.tushare_stock_info"""
    with with_db_session(engine_md) as session:
        # resolve the start date for the namechange query
        table = session.execute(sql_str)
        start_date = list(row[0] for row in table.fetchall())
    start_date = datetime_2_str(start_date[0], STR_FORMAT_DATE_TS)
    end_date = datetime_2_str(date.today(), STR_FORMAT_DATE_TS)
    try:
        data_df = pro.namechange(
            start_date=start_date, end_date=end_date,
            fields='ts_code,name,start_date,end_date,change_reason')
        if len(data_df) > 0:
            data_count = bunch_insert_on_duplicate_update(
                data_df, table_name, engine_md, DTYPE_TUSHARE_STOCK_NAMECHANGE)
            logging.info("更新 %s 结束 %d 条上市公司更名信息被更新", table_name, data_count)
        else:
            logging.info("无数据信息可被更新")
    finally:
        # if this run created the table, convert it to MyISAM and add the PK
        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            # build_primary_key([table_name])
            create_pk_str = """ALTER TABLE {table_name} CHANGE COLUMN `ts_code` `ts_code` VARCHAR(20) NOT NULL FIRST, CHANGE COLUMN `start_date` `start_date` DATE NOT NULL AFTER `ts_code`, ADD PRIMARY KEY (`ts_code`, `start_date`)""".format(
                table_name=table_name)
            with with_db_session(engine_md) as session:
                session.execute(create_pk_str)
            logger.info('%s 表 `ts_code`, `start_date` 主键设置完成', table_name)
def repair_table():
    """Run REPAIR TABLE on pytdx_stock_tick (USE_FRM) and print the elapsed time."""
    started_at = datetime.now()
    with with_db_session(engine_md) as session:
        session.execute('REPAIR TABLE pytdx_stock_tick USE_FRM')
    elapsed = datetime.now() - started_at
    print('花费时间 ', elapsed)
def import_tushare_suspend(chain_param=None):
    """
    Insert stock suspension data up to the most recent business day - 1.
    If the current time is past BASE_LINE_HOUR, fetch today's data as well.
    :return:
    """
    table_name = 'tushare_stock_daily_suspend'
    logging.info("更新 %s 开始", table_name)
    # table existence decides incremental vs full date range below
    has_table = engine_md.has_table(table_name)
    # NOTE: be careful which table each sub-query references — mixing tables
    # up here scrambles the computed date ranges
    if has_table:
        # incremental: only dates newer than the latest stored suspend_date
        sql_str = """ select cal_date FROM ( select * from tushare_trade_date trddate where( cal_date>(SELECT max(suspend_date) FROM {table_name} )) )tt where (is_open=1 and cal_date <= if(hour(now())<16, subdate(curdate(),1), curdate()) and exchange='SSE') """.format(table_name=table_name)
    else:
        # full history of open SSE trade dates up to yesterday (or today after 16:00)
        sql_str = """ SELECT cal_date FROM tushare_trade_date trddate WHERE (trddate.is_open=1 AND cal_date <= if(hour(now())<16, subdate(curdate(),1), curdate()) AND exchange='SSE') ORDER BY cal_date"""
        logger.warning('%s 不存在,仅使用 tushare_stock_info 表进行计算日期范围', table_name)
    with with_db_session(engine_md) as session:
        # fetch the trade dates still to be imported
        table = session.execute(sql_str)
        trade_date_list = list(row[0] for row in table.fetchall())
    try:
        trade_date_list_len = len(trade_date_list)
        for num, trade_date in enumerate(trade_date_list, start=1):
            trade_date = datetime_2_str(trade_date, STR_FORMAT_DATE_TS)
            data_df = pro.suspend(ts_code='', suspend_date=trade_date, resume_date='', fields='')
            if len(data_df) > 0:
                data_count = bunch_insert_p(
                    data_df, table_name=table_name, dtype=DTYPE_TUSHARE_SUSPEND,
                    primary_keys=['ts_code', 'suspend_date'])
                logging.info("%d/%d) %s 更新 %s 结束 %d 条信息被更新",
                             num, trade_date_list_len, trade_date, table_name, data_count)
            else:
                # BUG FIX: previously logged the list length instead of the date
                logging.info("%s 当日无停牌股票", trade_date)
    except Exception:
        # narrowed from a bare except so Ctrl-C/SystemExit still propagate
        logger.exception('更新 %s 表异常', table_name)
def update_future_info_hk(chain_param=None):
    """
    Update the contract list for HK stock-index futures (wind_future_info_hk).

    Hang Seng Index (HSIF) and H-shares Index (HHIF) futures only have
    contracts from Feb 2007 onwards and cannot be retrieved via wset, so the
    contract code list is generated by hand below.

    :param chain_param: lets celery pass the previous task's result into this
        task when chained; unused in the body.
    :return: None
    """
    table_name = "wind_future_info_hk"
    has_table = engine_md.has_table(table_name)  # NOTE(review): assigned but never used below
    # (wss field name, DB column type) pairs to request and persist.
    # NOTE(review): Date for lprice/margin/changelt/mfprice looks odd — the
    # names suggest numeric fields; confirm the intended column types.
    param_list = [
        ("ipo_date", Date),
        ("sec_name", String(50)),
        ("sec_englishname", String(50)),
        ("exch_eng", String(50)),
        ("lasttrade_date", Date),
        ("lastdelivery_date", Date),
        ("dlmonth", String(50)),
        ("lprice", Date),
        ("sccode", String(50)),
        ("margin", Date),
        ("punit", String(50)),
        ("changelt", Date),
        ("mfprice", Date),
        ("contractmultiplier", DOUBLE),
        ("ftmargins", String(100)),
        ("trade_code", String(50)),
    ]
    wind_indictor_str = ",".join([key for key, _ in param_list])
    dtype = {key: val for key, val in param_list}
    dtype['wind_code'] = String(20)
    logger.info("更新 wind_future_info_hk 开始")
    # Fetch the contracts already stored.
    # NOTE(review): wind_code_ipo_date_dic is built but never referenced below.
    sql_str = 'select wind_code, ipo_date from wind_future_info_hk'
    with with_db_session(engine_md) as session:
        table = session.execute(sql_str)
        wind_code_ipo_date_dic = dict(table.fetchall())
    # Build the contract list by hand:
    # HSIF/HHIF monthly contracts for years 2007-2018; 2007-01 is excluded
    # because contracts only start in Feb 2007.
    wind_code_list = ['%s%02d%02d.HK' % (name, year, month)
                      for name, year, month in itertools.product(['HSIF', 'HHIF'], range(7, 19), range(1, 13))
                      if not (year == 7 and month == 1)]
    # Fetch basic contract info from wind, e.g.:
    # w.wss("AU1706.SHF,AG1612.SHF,AU0806.SHF", "ipo_date,sec_name,sec_englishname,exch_eng,lasttrade_date,lastdelivery_date,dlmonth,lprice,sccode,margin,punit,changelt,mfprice,contractmultiplier,ftmargins,trade_code")
    # future_info_df = wss_cache(w, wind_code_list,
    #                            "ipo_date,sec_name,sec_englishname,exch_eng,lasttrade_date,lastdelivery_date,dlmonth,lprice,sccode,margin,punit,changelt,mfprice,contractmultiplier,ftmargins,trade_code")
    if len(wind_code_list) > 0:
        future_info_df = invoker.wss(wind_code_list, wind_indictor_str)
        # Normalize MFPRICE to a number, then lower-case all column names to
        # match the DB schema.
        future_info_df['MFPRICE'] = future_info_df['MFPRICE'].apply(mfprice_2_num)
        future_info_df.rename(columns={c: str.lower(c) for c in future_info_df.columns},
                              inplace=True)
        future_info_df.index.rename('wind_code', inplace=True)
        # Drop generated codes wind knows nothing about (missing ipo or
        # last-trade date).
        future_info_df = future_info_df[~(future_info_df['ipo_date'].isna() | future_info_df['lasttrade_date'].isna())]
        future_info_df.reset_index(inplace=True)
        future_info_count = future_info_df.shape[0]
        bunch_insert_on_duplicate_update(future_info_df, table_name, engine_md, dtype=dtype)
        logger.info("更新 wind_future_info_hk 结束 %d 条记录被更新", future_info_count)
def min_to_vnpy_increment(chain_param=None, instrument_types=None):
    """Incrementally copy 1-minute future bars from ``wind_future_min`` into
    vnpy's ``dbbardata`` table.

    For each contract: if bars already exist on the vnpy side, only rows with
    ``trade_datetime`` strictly greater than the stored ``max(datetime)`` are
    copied; otherwise the whole history is copied.

    :param chain_param: placeholder so the task can be chained in celery; unused.
    :param instrument_types: forwarded to ``get_wind_code_list_by_types`` to
        select contracts. Presumably ``None`` means "all contracts" — confirm
        in that helper.
    """
    from tasks.config import config
    from tasks.backend import engine_dic
    table_name = 'dbbardata'
    interval = '1m'
    engine_vnpy = engine_dic[config.DB_SCHEMA_VNPY]
    has_table = engine_vnpy.has_table(table_name)
    if not has_table:
        # The vnpy schema is owned by vnpy itself; refuse to create it here.
        logger.error('当前数据库 %s 没有 %s 表,建议使用 vnpy先建立相应的数据库表后再进行导入操作', engine_vnpy, table_name)
        return
    # Incremental query: bars strictly after a given datetime (zero/NULL
    # closes are treated as invalid and skipped).
    sql_increment_str = "select trade_datetime `datetime`, `open` open_price, high high_price, " \
                        "`low` low_price, `close` close_price, volume, position as open_interest " \
                        "from wind_future_min where wind_code = %s and " \
                        "trade_datetime > %s and `close` is not null and `close` <> 0"
    # Full-history query for contracts not yet present on the vnpy side.
    sql_whole_str = "select trade_datetime `datetime`, `open` open_price, high high_price, " \
                    "`low` low_price, `close` close_price, volume, position as open_interest " \
                    "from wind_future_min where wind_code = %s and " \
                    "`close` is not null and `close` <> 0"
    # NOTE(review): assumes get_wind_code_list_by_types returns 'SYMBOL.EXCH'
    # strings (split('.') below) — confirm against that helper's return type.
    wind_code_list = get_wind_code_list_by_types(instrument_types)
    wind_code_count = len(wind_code_list)
    for n, wind_code in enumerate(wind_code_list, start=1):
        symbol, exchange = wind_code.split('.')
        # Map the wind exchange code to vnpy's; fall back to the raw code.
        if exchange in WIND_VNPY_EXCHANGE_DIC:
            exchange_vnpy = WIND_VNPY_EXCHANGE_DIC[exchange]
        else:
            logger.warning('%s exchange: %s 在交易所列表中不存在', wind_code, exchange)
            exchange_vnpy = exchange
        # Latest bar already stored for this symbol/interval (None if absent).
        sql_str = f"select max(`datetime`) from {table_name} where symbol=:symbol and `interval`=:interval"
        with with_db_session(engine_vnpy) as session:
            datetime_exist = session.scalar(sql_str, params={
                'symbol': symbol,
                'interval': interval
            })
        if datetime_exist is not None:
            # Incremental load: only bars after the stored maximum.
            df = pd.read_sql(sql_increment_str, engine_md, params=[wind_code, datetime_exist]).dropna()
        else:
            df = pd.read_sql(sql_whole_str, engine_md, params=[wind_code]).dropna()
        df_len = df.shape[0]
        if df_len == 0:
            continue
        df['symbol'] = symbol
        df['exchange'] = exchange_vnpy
        df['interval'] = interval
        datetime_latest = df['datetime'].max().to_pydatetime()
        # NOTE(review): plain append — de-duplication relies entirely on the
        # `trade_datetime >` filter above; a concurrent writer (or re-running
        # after a partial insert) could still create duplicate rows.
        df.to_sql(table_name, engine_vnpy, if_exists='append',
                  index=False)
        logger.info("%d/%d) %s (%s ~ %s] %d data -> %s interval %s",
                    n, wind_code_count, symbol,
                    datetime_2_str(datetime_exist), datetime_2_str(datetime_latest),
                    df_len, table_name, interval)
def import_tushare_stock_pledge_stat(chain_param=None, ts_code_set=None):
    """Import share-pledge statistics for every stock listed in
    ``tushare_stock_info`` into ``tushare_stock_pledge_stat``.

    Fetched DataFrames are accumulated in memory and flushed to the database
    once at least 10000 rows have piled up; any remainder is flushed in the
    ``finally`` block so a partially failed run still persists what it fetched.

    :param chain_param: placeholder so the task can be chained in celery; unused.
    :param ts_code_set: accepted for interface compatibility; currently unused —
        the ts_code list always comes from tushare_stock_info.
    :return: None
    """
    table_name = 'tushare_stock_pledge_stat'
    logging.info("更新 %s 开始", table_name)
    # Full universe of stocks to query pledge data for.
    sql_str = """SELECT ts_code FROM tushare_stock_info """
    logger.warning('使用 tushare_stock_info 表确认需要提取股票质押数据的范围')
    with with_db_session(engine_md) as session:
        table = session.execute(sql_str)
        ts_code_list = list(row[0] for row in table.fetchall())
    data_df_list, data_count, all_data_count, data_len = [], 0, 0, len(
        ts_code_list)
    logger.info('%d 只股票的质押信息将被插入 tushare_stock_pledge_stat 表', data_len)
    # Renamed from non-PEP8 `Cycles`; the unused `has_table` lookup was removed.
    cycle = 1
    try:
        for ts_code in ts_code_list:
            data_df = invoke_pledge_stat(ts_code=ts_code)
            logger.warning('提取 %s 质押信息 %d 条', ts_code, len(data_df))
            # Accumulate non-empty results in memory.
            if data_df.shape[0] > 0:
                data_count += data_df.shape[0]
                data_df_list.append(data_df)
            # Flush once the in-memory batch reaches the threshold.
            if data_count >= 10000 and len(data_df_list) > 0:
                data_df_all = pd.concat(data_df_list)
                data_count = bunch_insert_on_duplicate_update(
                    data_df_all, table_name, engine_md, DTYPE_TUSHARE_STOCK_PLEDGE_STAT)
                logger.warning('更新股票质押信息 %d 条', data_count)
                all_data_count += data_count
                data_df_list, data_count = [], 0
            # Debug-only early exit after a couple of iterations.
            cycle = cycle + 1
            if DEBUG and cycle > 2:
                break
    finally:
        # Flush whatever remains, even if the loop raised.
        if len(data_df_list) > 0:
            data_df_all = pd.concat(data_df_list)
            data_count = bunch_insert_on_duplicate_update(
                data_df_all, table_name, engine_md, DTYPE_TUSHARE_STOCK_PLEDGE_STAT)
            all_data_count = all_data_count + data_count
        logging.info("更新 %s 结束,总共 %d 条信息被更新", table_name, all_data_count)
def get_instrument_last_trade_date_dic() -> dict:
    """Return ``{ths_code: ths_last_td_date_future}`` for future contracts.

    Contracts missing any of the start-trade / last-trade / last-delivery
    dates are excluded by the query itself.
    """
    query = """SELECT ths_code, ths_last_td_date_future FROM ifind_future_info where ths_start_trade_date_future is not null and ths_last_td_date_future is not null and ths_last_delivery_date_future is not null"""
    with with_db_session(engine_md) as session:
        rows = session.execute(query).fetchall()
    return {code: last_td_date for code, last_td_date in rows}
def show_cash_and_margin(stg_run_id, enable_show_plot=True, enable_save_plot=False, run_mode=RunMode.Backtest, **kwargs): """ plot cash_and_margin :param stg_run_id: :param enable_show_plot: :param enable_save_plot: :param run_mode: :param kwargs: :return: """ # stg_run_id=154 engine_ibats = engines.engine_ibats with with_db_session(engine_ibats) as session: if stg_run_id is None: stg_run_id = session.query(func.max(StgRunInfo.stg_run_id)).scalar() logger.warning('没有设置 stg_run_id 参数,将输出最新的 stg_run_id=%d 对应记录', stg_run_id) sql_str = str( session.query( StgRunStatusDetail.trade_dt.label('trade_dt'), StgRunStatusDetail.cash_available.label('cash'), (StgRunStatusDetail.cash_available.label('cash') + StgRunStatusDetail.commission_tot ).label('cash + commission'), StgRunStatusDetail.curr_margin.label('margin'), StgRunStatusDetail.cash_and_margin.label('cash_and_margin'), (StgRunStatusDetail.cash_and_margin.label('cash_and_margin') + StgRunStatusDetail.commission_tot ).label('no commission'), ).filter( StgRunStatusDetail.stg_run_id == stg_run_id ) ) df = pd.read_sql(sql_str, engine_ibats, params=[stg_run_id], index_col=['trade_dt']) if df.shape[0] == 0: file_path = None return df, file_path ax = df[['cash', 'margin']].plot.area() if run_mode != RunMode.Backtest_FixPercent: df[['cash_and_margin', 'no commission', 'cash + commission']].plot(ax=ax) ax.set_title( f"Cash + Margin [{stg_run_id}] " f"{date_2_str(min(df.index))} - {date_2_str(max(df.index))} ({df.shape[0]} days)") if enable_save_plot: file_name = get_file_name(f'cash_and_margin', name=stg_run_id) file_path = os.path.join(get_cache_folder_path(), file_name) plt.savefig(file_path, dpi=75) else: file_path = None if enable_show_plot: plt.show() return df, file_path