def get_index_constituent_2_dic(index_code, index_name, date_start, idx_start,
                                date_constituent_df_dict, idx_constituent_set_dic):
    """
    Fetch index constituents and weights from Wind and cache them in
    date_constituent_df_dict / idx_constituent_set_dic
    :param index_code:
    :param index_name:
    :param date_start:
    :param idx_start:
    :param date_constituent_df_dict:
    :param idx_constituent_set_dic:
    :return:
    """
    if date_start in date_constituent_df_dict and idx_start in idx_constituent_set_dic:
        date_start_str = date_2_str(date_start)
        logger.debug('%s %s %s constituents already cached, returning directly',
                     date_start_str, index_name, index_code)
        sec_df = date_constituent_df_dict[date_start]
        constituent_set = idx_constituent_set_dic[idx_start]
    else:
        sec_df = get_sectorconstituent(index_code, index_name, date_start)
        if sec_df is None or sec_df.shape[0] == 0:
            date_start_str = date_2_str(date_start)
            logger.warning('%s unable to fetch constituents of %s %s',
                           date_start_str, index_name, index_code)
            # raise ValueError('%s unable to fetch constituents of %s %s' % (date_start_str, index_name, index_code))
            return None, None
        date_constituent_df_dict[date_start] = sec_df
        # constituent_set = set(sec_df['wind_code'])
        # build a set of (wind_code, weight) tuples, one per constituent
        constituent_set = {
            tuple(row) for row in sec_df[['wind_code', 'weight']].itertuples(index=False)
        }
        idx_constituent_set_dic[idx_start] = constituent_set
    return sec_df, constituent_set
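# A minimal sketch (sample data is made up) showing how the (wind_code, weight)
# constituent set above is built from a constituent DataFrame.
import pandas as pd

_demo_sec_df = pd.DataFrame({
    'wind_code': ['600000.SH', '600036.SH'],  # hypothetical codes
    'weight': [1.2, 3.4],
})
_demo_constituent_set = {
    tuple(row) for row in _demo_sec_df[['wind_code', 'weight']].itertuples(index=False)
}
# -> {('600000.SH', 1.2), ('600036.SH', 3.4)}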
def get_sectorconstituent_2_dic(sector_code, sector_name, date_start, idx_start,
                                trade_date_list_sorted, date_constituent_df_dict,
                                idx_constituent_set_dic):
    """
    Fetch sector constituents from Wind and cache them in
    date_constituent_df_dict / idx_constituent_set_dic
    :param sector_code:
    :param sector_name:
    :param date_start:
    :param idx_start:
    :param trade_date_list_sorted:
    :param date_constituent_df_dict:
    :param idx_constituent_set_dic:
    :return:
    """
    if date_start in date_constituent_df_dict and idx_start in idx_constituent_set_dic:
        date_start_str = date_2_str(date_start)
        logger.debug('%s %s %s constituents already cached, returning directly',
                     date_start_str, sector_name, sector_code)
        sec_df = date_constituent_df_dict[date_start]
        constituent_set_left = idx_constituent_set_dic[idx_start]
    else:
        sec_df = get_sectorconstituent(sector_code, sector_name, date_start)
        if sec_df is None or sec_df.shape[0] == 0:
            date_start_str = date_2_str(date_start)
            logger.warning('%s unable to fetch constituents of %s %s',
                           date_start_str, sector_name, sector_code)
            raise ValueError('%s unable to fetch constituents of %s %s'
                             % (date_start_str, sector_name, sector_code))
        date_constituent_df_dict[date_start] = sec_df
        constituent_set_left = set(sec_df['wind_code'])
        idx_constituent_set_dic[idx_start] = constituent_set_left
    return sec_df, constituent_set_left
def valid_model_acc(self, factor_df: pd.DataFrame):
    xs, ys, _ = self.get_x_y(factor_df)
    trade_date_from_str, trade_date_to_str = date_2_str(factor_df.index[0]), date_2_str(factor_df.index[-1])
    random_state = self.predict_test_random_state
    xs_train, xs_validation, ys_train, ys_validation = train_test_split(
        xs, ys, test_size=0.2, random_state=random_state)
    self.logger.debug(
        'random_state=%d, xs_train %s, ys_train %s, xs_validation %s, ys_validation %s, [%s, %s]',
        random_state, xs_train.shape, ys_train.shape,
        xs_validation.shape, ys_validation.shape,
        trade_date_from_str, trade_date_to_str)
    sess = self.get_session(renew=True)
    with sess.as_default():
        # with tf.Graph().as_default():
        # self.logger.debug('sess.graph:%s tf.get_default_graph():%s', sess.graph, tf.get_default_graph())
        # tflearn's DNN.evaluate returns a list of metric values (accuracy first)
        result = self.model.evaluate(xs_validation, ys_validation, batch_size=self.batch_size)
        val_acc = result[0]
        result = self.model.evaluate(xs_train, ys_train, batch_size=self.batch_size)
        train_acc = result[0]

    self.logger.info("[%s - %s] training-set accuracy: %.2f%%",
                     trade_date_from_str, trade_date_to_str, train_acc * 100)
    self.logger.info("[%s - %s] validation-set accuracy: %.2f%%",
                     trade_date_from_str, trade_date_to_str, val_acc * 100)
    return train_acc, val_acc
def __init__(self, date_from, date_to, get_stg_handler, q_table_key=None):
    self.action_space = ['empty', 'hold_long', 'hold_short']
    self.n_actions = len(self.action_space)
    self.q_table_key = q_table_key
    self.stg_handler = None
    self.train_date_from, self.train_date_to = date_2_str(date_from), date_2_str(date_to)
    self.get_stg_handler = get_stg_handler
    self._build()
def get_wind_code_list_by_types(
        instrument_types: list,
        all_if_none=True,
        lasttrade_date_lager_than_n_days_before: Optional[int] = 30) -> list:
    """
    Given a list of instrument types, return all matching contract codes
    :param instrument_types: a list of instrument_type values, or of (instrument_type, exchange) tuples
    :param all_if_none: if instrument_types is None, return all contract codes
    :param lasttrade_date_lager_than_n_days_before: only return contracts whose last
        trading day is later than N days before today
    :return: wind_code_list
    """
    wind_code_list = []
    if all_if_none and instrument_types is None:
        sql_str = "select wind_code from wind_future_info"
        with with_db_session(engine_md) as session:
            if lasttrade_date_lager_than_n_days_before is not None:
                date_from_str = date_2_str(date.today() - timedelta(
                    days=lasttrade_date_lager_than_n_days_before))
                sql_str += " where lasttrade_date > :last_trade_date"
                table = session.execute(sql_str, params={"last_trade_date": date_from_str})
            else:
                table = session.execute(sql_str)
            # collect the matching codes
            for row in table.fetchall():
                wind_code_list.append(row[0])
    else:
        for instrument_type in instrument_types:
            if isinstance(instrument_type, tuple):
                instrument_type, exchange = instrument_type
            else:
                exchange = None
            # equivalent Python regex: re.search(r"(?<=RB)\d{4}(?=\.SHF)", 'RB2101.SHF')
            # MySQL REGEXP requires POSIX classes inside a bracket expression,
            # i.e. [[:digit:]] rather than [:digit:]
            # reference: https://blog.csdn.net/qq_22238021/article/details/80929518
            sql_str = f"select wind_code from wind_future_info where wind_code " \
                      f"REGEXP '^{instrument_type}[[:digit:]]+.{'[[:alpha:]]+' if exchange is None else exchange}'"
            with with_db_session(engine_md) as session:
                if lasttrade_date_lager_than_n_days_before is not None:
                    date_from_str = date_2_str(date.today() - timedelta(
                        days=lasttrade_date_lager_than_n_days_before))
                    sql_str += " and lasttrade_date > :last_trade_date"
                    table = session.execute(sql_str, params={"last_trade_date": date_from_str})
                else:
                    table = session.execute(sql_str)
                # collect the matching codes
                for row in table.fetchall():
                    wind_code_list.append(row[0])
    return wind_code_list
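# A minimal sketch of the corrected MySQL REGEXP pattern built above; the pattern
# string can be inspected without a database (values here are illustrative).
_instrument_type, _exchange = 'RB', None
_pattern = f"^{_instrument_type}[[:digit:]]+.{'[[:alpha:]]+' if _exchange is None else _exchange}"
# -> ^RB[[:digit:]]+.[[:alpha:]]+   matches e.g. RB2101.SHF under MySQL REGEXP,
# since POSIX character classes are only valid inside a bracket expression.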
def get_wind_kv_per_year(wind_code, wind_indictor_str, date_from, date_to, params):
    """
    :param wind_code:
    :param wind_indictor_str:
    :param date_from:
    :param date_to:
    :param params: e.g. "year=%(year)d;westPeriod=180" -> "year=2018;westPeriod=180"
    :return:
    """
    date_from, date_to = str_2_date(date_from), str_2_date(date_to)
    # split the date range into segments bounded by calendar year-ends
    date_pair = []
    if date_from <= date_to:
        date_curr = date_from
        while True:
            date_new_year = str_2_date("%d-01-01" % (date_curr.year + 1))
            date_year_end = date_new_year - timedelta(days=1)
            if date_to < date_year_end:
                date_pair.append((date_curr, date_to))
                break
            else:
                date_pair.append((date_curr, date_year_end))
                date_curr = date_new_year

    data_df_list = []
    for date_from_sub, date_to_sub in date_pair:
        params_sub = params % {'year': (date_from_sub.year + 1)}
        try:
            data_df = invoker.wsd(wind_code, wind_indictor_str, date_from_sub, date_to_sub, params_sub)
        except APIError as exp:
            logger.exception("%s %s [%s ~ %s] %s raised an exception",
                             wind_code, wind_indictor_str,
                             date_2_str(date_from_sub), date_2_str(date_to_sub), params_sub)
            if exp.ret_dic.setdefault('error_code', 0) in (
                    -40520007,  # no data available
                    -40521009,  # data decoding failed; check input arguments, e.g. month-end dates and short February
            ):
                continue
            else:
                raise exp
        if data_df is None:
            logger.warning('%s %s [%s ~ %s] has no data',
                           wind_code, wind_indictor_str,
                           date_2_str(date_from_sub), date_2_str(date_to_sub))
            continue
        data_df.dropna(inplace=True)
        if data_df.shape[0] == 0:
            # logger.warning('%s %s [%s ~ %s] has 0 data',
            #                wind_code, wind_indictor_str, date_2_str(date_from_sub), date_2_str(date_to_sub))
            continue
        data_df_list.append(data_df)

    # merge the per-year segments
    data_df_tot = pd.concat(data_df_list) if len(data_df_list) > 0 else None
    return data_df_tot
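# A self-contained sketch of the year-end splitting used by get_wind_kv_per_year
# (pure Python, no Wind dependency; the function name here is illustrative).
from datetime import date, timedelta

def _split_by_year(date_from: date, date_to: date):
    pairs, date_curr = [], date_from
    while date_curr <= date_to:
        date_year_end = date(date_curr.year + 1, 1, 1) - timedelta(days=1)
        pairs.append((date_curr, min(date_to, date_year_end)))
        date_curr = date_year_end + timedelta(days=1)
    return pairs

# _split_by_year(date(2017, 6, 1), date(2019, 3, 1)) ->
# [(2017-06-01, 2017-12-31), (2018-01-01, 2018-12-31), (2019-01-01, 2019-03-01)]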
def import_coin_info(chain_param=None):
    """Fetch basic information on globally traded coins"""
    table_name = 'tushare_coin_info'
    has_table = engine_md.has_table(table_name)
    # set dtype
    dtype = {
        'coin': String(60),
        'en_name': String(60),
        'cn_name': String(60),
        'issue_date': Date,
        'amount': DOUBLE,
    }
    coinlist_df = pro.coinlist(start_date='20170101',
                               end_date=date_2_str(date.today(), DATE_FORMAT_STR))
    data_count = bunch_insert_on_duplicate_update(coinlist_df, table_name, engine_md, dtype)
    logging.info("update %s finished, %d rows inserted", table_name, data_count)

    if not has_table and engine_md.has_table(table_name):
        alter_table_2_myisam(engine_md, [table_name])
        create_pk_str = """ALTER TABLE {table_name}
            CHANGE COLUMN `coin` `coin` VARCHAR(60) NOT NULL FIRST,
            CHANGE COLUMN `en_name` `en_name` VARCHAR(60) NOT NULL AFTER `coin`,
            ADD PRIMARY KEY (`coin`, `en_name`)""".format(table_name=table_name)
        with with_db_session(engine_md) as session:
            session.execute(create_pk_str)
def default(self, obj):
    # print("obj.__class__", obj.__class__, "isinstance(obj.__class__, DeclarativeMeta)",
    #       isinstance(obj.__class__, DeclarativeMeta))
    if isinstance(obj.__class__, DeclarativeMeta):
        # an SQLAlchemy class
        fields = {}
        for field in [x for x in dir(obj) if not x.startswith('_') and x != 'metadata']:
            data = obj.__getattribute__(field)
            try:
                # this will fail on non-encodable values, like other classes
                json.dumps(data)
                fields[field] = data
            except TypeError:
                # added handling for datetime / date / timedelta values
                if isinstance(data, datetime):
                    fields[field] = data.isoformat()
                elif isinstance(data, date):
                    fields[field] = data.isoformat()
                elif isinstance(data, timedelta):
                    fields[field] = (datetime.min + data).time().isoformat()
                else:
                    fields[field] = None
        # a json-encodable dict
        return fields
    elif isinstance(obj, date):
        # return the string itself; wrapping it in json.dumps() would double-encode it
        return date_2_str(obj)
    return json.JSONEncoder.default(self, obj)
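# A minimal, self-contained sketch of the date-handling branch above: default()
# returns the ISO string itself and json.dumps() adds the quoting. (DateEncoder is
# a hypothetical name, not part of this codebase.)
import json
from datetime import date, datetime

class DateEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, (date, datetime)):
            return obj.isoformat()  # plain string, not json.dumps(...)
        return super().default(obj)

# json.dumps({'d': date(2020, 1, 2)}, cls=DateEncoder) -> '{"d": "2020-01-02"}'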
def get_sectorconstituent(index_code, index_name, target_date) -> pd.DataFrame:
    """
    Fetch index constituents and weights from Wind
    :param index_code:
    :param index_name:
    :param target_date:
    :return:
    """
    target_date_str = date_2_str(target_date)
    logger.info('fetching constituents of %s %s %s', index_code, index_name, target_date)
    sec_df = invoker.wset(
        "indexconstituent",
        "date=%s;windcode=%s" % (target_date_str, index_code))
    if sec_df is not None and sec_df.shape[0] > 0:
        # in some cases the returned rows carry a date other than target_date; filter them out
        sec_df = sec_df[sec_df['date'].apply(lambda x: str_2_date(x) == target_date)]
    if sec_df is None or sec_df.shape[0] == 0:
        return None
    sec_df["index_code"] = index_code
    sec_df["index_name"] = index_name
    sec_df.rename(columns={
        'date': 'trade_date',
        'sec_name': 'stock_name',
        'i_weight': 'weight',
    }, inplace=True)
    return sec_df
def save(self):
    self.logger.info("updating %s, started", self.table_name)
    has_table = engine_md.has_table(self.table_name)  # check whether the table already exists
    if has_table:
        with with_db_session(engine_md) as session:
            sql_str = f"""select trade_date from jq_trade_date
                where trade_date>(select max(day) from {self.table_name})
                order by trade_date"""
            # this statement has no bind parameter, so none is passed
            table = session.execute(sql_str)
            trade_date_list = [_[0] for _ in table.fetchall()]
        # take the first pending trade date as the start date
        date_start = execute_scalar(sql_str, engine_md)
        self.logger.info('querying %s starting from %s', self.table_name, date_2_str(date_start))
    else:
        with with_db_session(engine_md) as session:
            sql_str = "select trade_date from jq_trade_date where trade_date>=:trade_date order by trade_date"
            table = session.execute(sql_str, params={"trade_date": self.BASE_DATE})
            trade_date_list = [_[0] for _ in table.fetchall()]
        self.logger.warning('%s does not exist, using base date %s', self.table_name, self.BASE_DATE)

    # fetch the latest data
    trade_date_list.sort()
    data_count_tot, for_count = 0, len(trade_date_list)
    try:
        for num, trade_date in enumerate(trade_date_list):
            q = query(self.statement)
            df = get_fundamentals(q, date=date_2_str(trade_date))
            if df is None or df.shape[0] == 0:
                continue
            logger.debug('%d/%d) %s has %d rows', num, for_count, trade_date, df.shape[0])
            data_count = bunch_insert_on_duplicate_update(
                df, self.table_name, engine_md,
                dtype=self.dtype,
                myisam_if_create_table=True,
                primary_keys=['id'],
                schema=config.DB_SCHEMA_MD)
            data_count_tot += data_count
    except:
        logger.exception("updating %s failed", self.table_name)
    finally:
        # write to the database
        logging.info("update %s finished, %d rows updated", self.table_name, data_count_tot)
def get_df_iter(self, date_start, date_end, step, df_len_limit=3000, deep=0):
    """
    Fetch data over a date range; whenever a window returns more rows than the
    limit, bisect the window and query the halves recursively
    :param date_start:
    :param date_end:
    :param step:
    :param df_len_limit:
    :param deep:
    :return:
    """
    for num, (date_from, date_to) in enumerate(iter_2_range(
            range_date(date_start, date_end, step),
            has_left_outer=False, has_right_outer=False), start=1):
        q = query(self.statement).filter(
            self.statement.pub_date > date_2_str(date_from),
            self.statement.pub_date <= date_2_str(date_to))
        df = finance.run_query(q)
        df_len = df.shape[0]
        if df_len >= df_len_limit:
            if step >= 2:
                self.logger.warning(
                    '%s%s%d) [%s ~ %s] has %d rows, which may exceed the %d-row fetch '
                    'limit; splitting the date range further',
                    self.table_name, ' ' * deep, num, date_from, date_to, df_len, df_len_limit)
                yield from self.get_df_iter(date_from, date_to, step // 2, deep=deep + 1)
            else:
                self.logger.warning(
                    '%s%s%d) [%s ~ %s] has %d rows, which may exceed the %d-row fetch '
                    'limit, and the date range cannot be split further; the remaining '
                    'data must be fetched manually',
                    self.table_name, ' ' * deep, num, date_from, date_to, df_len, df_len_limit)
                yield df, date_from, date_to
        else:
            self.logger.debug('%s%s%d) [%s ~ %s] has %d rows',
                              self.table_name, ' ' * deep, num, date_from, date_to, df_len)
            yield df, date_from, date_to
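# A pure-Python sketch of the bisecting-fetch pattern above, with a fake query in
# place of finance.run_query (all names and the _LIMIT value are illustrative).
from datetime import date, timedelta

_LIMIT = 5

def _fake_query(d_from: date, d_to: date):
    # stand-in for the real query: one record per day, capped at _LIMIT
    return list(range(min((d_to - d_from).days, _LIMIT)))

def _rows_iter(d_from: date, d_to: date, step: int):
    curr = d_from
    while curr < d_to:
        nxt = min(curr + timedelta(days=step), d_to)
        rows = _fake_query(curr, nxt)
        if len(rows) >= _LIMIT and step >= 2:
            # hit the cap: bisect the window and recurse, as get_df_iter does
            yield from _rows_iter(curr, nxt, step // 2)
        else:
            yield rows, curr, nxt
        curr = nxt

# for rows, d_from, d_to in _rows_iter(date(2020, 1, 1), date(2020, 1, 20), 8): ...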
def stat_fund(date_from, date_to):
    sql_str = """SELECT (@rowNum:=@rowNum+1) AS rowNo, t.*
        FROM (
            SELECT date_from.ts_code, basic.name, basic.management,
                   date_from.date_from, nav_to.end_date,
                   nav_from.accum_nav nav_from, nav_to.accum_nav nav_to,
                   nav_to.accum_nav/nav_from.accum_nav pct_chg
            FROM (
                SELECT ts_code, max(end_date) date_from FROM tushare_fund_nav
                WHERE end_date<= :date_from GROUP BY ts_code
            ) date_from
            JOIN (
                SELECT ts_code, max(end_date) date_to FROM tushare_fund_nav
                WHERE end_date<= :date_to GROUP BY ts_code
            ) date_to
            ON date_from.ts_code = date_to.ts_code
            JOIN tushare_fund_nav nav_from
            ON date_from.ts_code = nav_from.ts_code AND date_from.date_from = nav_from.end_date
            JOIN tushare_fund_nav nav_to
            ON date_to.ts_code = nav_to.ts_code AND date_to.date_to = nav_to.end_date
            JOIN tushare_fund_basic basic
            ON date_from.ts_code = basic.ts_code
            WHERE basic.name NOT LIKE '%B%' and basic.name NOT LIKE '%A%' and basic.name NOT LIKE '%C%'
            HAVING nav_to.accum_nav IS NOT NULL AND nav_from.accum_nav IS NOT NULL
                and pct_chg != 1 and pct_chg < 2
            ORDER BY nav_to.accum_nav/nav_from.accum_nav
        ) t"""
    # data_df = pd.read_sql(sql_str, engine_md)
    with with_db_session(engine_md) as session:
        # initialize the user variable used for the row number
        session.execute("Select (@rowNum :=0) ;")
        table = session.execute(sql_str, params={
            'date_from': date_2_str(date_from),
            'date_to': date_2_str(date_to)
        })
        data = [[d for d in row] for row in table.fetchall()]
        data_df = pd.DataFrame(data, columns=[
            'rowNo', 'ts_code', 'name', 'management', 'date_from', 'date_to',
            'nav_from', 'nav_to', 'pct_chg'
        ])
    return data_df.describe()['pct_chg']
def trade_date_list(file_path=None):
    if file_path is None:
        file_path = get_export_path('trade_date.csv')
    with with_db_session_p() as session:
        sql_str = "select cal_date from tushare_trade_date where exchange='SSE' and is_open=1"
        table = session.execute(sql_str)
        ret_list = [date_2_str(_[0]) for _ in table.fetchall()]
    pd.DataFrame({'trade_date': ret_list}).to_csv(file_path, index=False)
def save_adj_factor(instrument_type: str, method: ReversionRightsMethod,
                    db_table_name='wind_future_adj_factor',
                    to_csv_dir_path=None,
                    generate_reversion_rights_factors_func: Callable = generate_reversion_rights_factors):
    """
    :param instrument_type: instrument type
    :param method: adjustment method
    :param db_table_name: target database table; None means do not save to the database
    :param to_csv_dir_path: directory for csv output; None means do not save to csv
    :param generate_reversion_rights_factors_func: function that generates the adjustment factors
    :return:
    """
    logger.info("generating %s adjustment factors [%s]", instrument_type, method.name)
    adj_factor_df, trade_date_latest = generate_reversion_rights_factors_func(
        instrument_type, method=method)
    if adj_factor_df is None:
        return
    if to_csv_dir_path is not None:
        csv_file_name = f'adj_factor_{instrument_type}_{method.name}.csv'
        folder_path = os.path.join(to_csv_dir_path, date_2_str(trade_date_latest))
        csv_file_path = os.path.join(folder_path, csv_file_name)
        os.makedirs(folder_path, exist_ok=True)
        adj_factor_df.to_csv(csv_file_path, index=False)
    if db_table_name is not None:
        dtype = {
            'trade_date': Date,
            'instrument_id_main': String(20),
            'adj_factor_main': DOUBLE,
            'instrument_id_secondary': String(20),
            'adj_factor_secondary': DOUBLE,
            'instrument_type': String(20),
            'method': String(20),
        }
        adj_factor_df['method'] = method.name
        update_df_2_db(instrument_type, db_table_name, adj_factor_df, method=method, dtype=dtype)
    logger.info("generated %d adjustment-factor rows for %s [%s]",  # \n%s
                adj_factor_df.shape[0], instrument_type, method.name  # , adj_factor_df
                )
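# A toy sketch of a division-style back-adjustment factor at contract roll dates
# (illustrative only; generate_reversion_rights_factors is the real implementation,
# and the prices below are made up).
import pandas as pd

_rolls = pd.DataFrame({
    'trade_date': pd.to_datetime(['2020-01-15', '2020-05-15']),
    'close_old': [4000.0, 4100.0],  # expiring contract close at the roll
    'close_new': [4050.0, 4200.0],  # incoming contract close at the roll
})
_ratio = _rolls['close_old'] / _rolls['close_new']
# multiply the ratios backwards so the earliest history is scaled the most
_adj_factor = _ratio[::-1].cumprod()[::-1]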
def init_state(self, md_df: pd.DataFrame):
    # each new episode must reset state and action
    self.last_state = None
    self.last_action = None
    if self.do_train:
        # when not training, the current run is live strategy execution; retraining
        # happens in on_prepare and periodically in on_period
        trade_date_s = md_df['trade_date']
        self.train_date_from = pd.to_datetime(trade_date_s.iloc[0]) + self.retrain_period
        trade_date_to = trade_date_s.iloc[-1]
        self.train(date_2_str(trade_date_to))
    _ = self.choose_action(md_df)
def import_jq_stock_income(chain_param=None, ts_code_set=None):
    """
    Insert stock income statements up to the most recent workday - 1;
    if past BASE_LINE_HOUR, fetch the current day's data as well
    :return:
    """
    logger.info("updating %s, started", TABLE_NAME)
    has_table = engine_md.has_table(TABLE_NAME)  # check whether the table already exists
    if has_table:
        sql_str = f"""select max(pub_date) from {TABLE_NAME}"""
        date_start = execute_scalar(sql_str, engine_md)
        logger.info('querying %s starting from %s', TABLE_NAME, date_2_str(date_start))
    else:
        date_start = BASE_DATE
        logger.warning('%s does not exist, using base date %s', TABLE_NAME, date_2_str(date_start))

    # fetch the latest pub_date
    date_end = datetime.date.today()
    if date_start >= date_end:
        logger.info('%s is already up to date, nothing to fetch', date_start)
        return
    data_count_tot = 0
    try:
        for num, (df, date_from, date_to) in enumerate(
                get_df_iter(date_start, date_end, LOOP_STEP)):
            # logger.debug('%d) [%s ~ %s] has %d rows', num, date_from, date_to, df.shape[0])
            data_count = bunch_insert_on_duplicate_update(
                df, TABLE_NAME, engine_md, dtype=DTYPE,
                myisam_if_create_table=True, primary_keys=['id'],
                schema=config.DB_SCHEMA_MD)
            data_count_tot += data_count
    finally:
        # write to the database
        logging.info("update %s finished, %d rows updated", TABLE_NAME, data_count_tot)
def save_adj_factor(instrument_types: list, to_db=True, to_csv=True):
    """
    :param instrument_types: instrument types
    :param to_db: whether to save to the database
    :param to_csv: whether to save to csv files
    :return:
    """
    dir_path = 'output'
    if to_csv:
        # create the output folder
        os.makedirs(dir_path, exist_ok=True)
    # loop over every adjustment method (e.g. division, diff)
    for method in Method:
        for n, instrument_type in enumerate(instrument_types):
            logger.info("generating %s adjustment factors", instrument_type)
            adj_factor_df, trade_date_latest = generate_reversion_rights_factors(
                instrument_type, method=method)
            if adj_factor_df is None:
                continue
            if to_csv:
                csv_file_name = f'adj_factor_{instrument_type}_{method.name}.csv'
                folder_path = os.path.join(dir_path, date_2_str(trade_date_latest))
                csv_file_path = os.path.join(folder_path, csv_file_name)
                os.makedirs(folder_path, exist_ok=True)
                adj_factor_df.to_csv(csv_file_path, index=False)
            if to_db:
                table_name = 'wind_future_adj_factor'
                dtype = {
                    'trade_date': Date,
                    'instrument_id_main': String(20),
                    'adj_factor_main': DOUBLE,
                    'instrument_id_secondary': String(20),
                    'adj_factor_secondary': DOUBLE,
                    'instrument_type': String(20),
                    'method': String(20),
                }
                adj_factor_df['method'] = method.name
                update_df_2_db(instrument_type, table_name, adj_factor_df, dtype)
            logger.info("generated %s adjustment-factor rows for %s",  # \n%s
                        adj_factor_df.shape[0], instrument_type  # , adj_factor_df
                        )
def save(self):
    self.logger.info("updating %s, started", self.table_name)
    has_table = engine_md.has_table(self.table_name)  # check whether the table already exists
    if has_table:
        sql_str = f"""select max(pub_date) from {self.table_name}"""
        date_start = execute_scalar(sql_str, engine_md)
        self.logger.info('querying %s starting from %s', self.table_name, date_2_str(date_start))
    else:
        date_start = self.BASE_DATE
        self.logger.warning('%s does not exist, using base date %s',
                            self.table_name, date_2_str(date_start))

    # fetch the latest pub_date
    date_end = datetime.date.today()
    if date_start >= date_end:
        self.logger.info('%s %s is already up to date, nothing to fetch', self.table_name, date_start)
        return
    data_count_tot = 0
    try:
        for num, (df, date_from, date_to) in enumerate(
                self.get_df_iter(date_start, date_end, self.loop_step)):
            # logger.debug('%d) [%s ~ %s] has %d rows', num, date_from, date_to, df.shape[0])
            if df is not None and df.shape[0] > 0:
                data_count = bunch_insert_on_duplicate_update(
                    df, self.table_name, engine_md, dtype=self.dtype,
                    myisam_if_create_table=True, primary_keys=['id'],
                    schema=config.DB_SCHEMA_MD)
                data_count_tot += data_count
    finally:
        # write to the database
        logging.info("update %s finished, %d rows updated", self.table_name, data_count_tot)
def choose_action(self, md_df: pd.DataFrame):
    trade_date_to = pd.to_datetime(md_df['trade_date'].iloc[-1])
    if self.do_train and trade_date_to > (self.train_date_latest + self.retrain_period):
        # when not training, the current run is live strategy execution; retraining
        # happens in on_prepare and periodically in on_period
        self.train(date_2_str(trade_date_to))
    elif self.ql_table is None:
        self.init_ql_table(trade_date_to)

    state, reward = self.get_state_reward(md_df)
    if self.last_state is not None:
        self.ql_table.learn(self.last_state, self.last_action, reward, state)
    self.last_action = self.ql_table.choose_action(state)
    self.last_state = state
    return self.last_action
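# A toy Q-table sketch mirroring the learn/choose_action interface used above
# (MiniQTable and its hyper-parameters are illustrative, not the project's class).
import random
from collections import defaultdict

class MiniQTable:
    def __init__(self, actions, alpha=0.1, gamma=0.9, epsilon=0.1):
        self.q = defaultdict(lambda: {a: 0.0 for a in actions})
        self.actions, self.alpha, self.gamma, self.epsilon = actions, alpha, gamma, epsilon

    def choose_action(self, state):
        # epsilon-greedy: explore occasionally, otherwise exploit the best-known action
        if random.random() < self.epsilon:
            return random.choice(self.actions)
        return max(self.q[state], key=self.q[state].get)

    def learn(self, state, action, reward, new_state):
        # standard one-step Q-learning update
        target = reward + self.gamma * max(self.q[new_state].values())
        self.q[state][action] += self.alpha * (target - self.q[state][action])

# _qt = MiniQTable(['empty', 'hold_long', 'hold_short'])
# _a = _qt.choose_action('s0'); _qt.learn('s0', _a, 1.0, 's1')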
def save_model(self, trade_date):
    """
    Export the model to a file
    :param trade_date:
    :return:
    """
    folder_path = os.path.join(self.model_folder_path, date_2_str(trade_date))
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(
        folder_path,
        f"model_{int(self.label_func_min_rr * 10000)}_{int(self.label_func_max_rr * 10000)}.tfl")
    self.model.save(file_path)
    self.logger.info("model trained up to %s saved to %s", self.trade_date_last_train, file_path)
    return file_path
def get_sectorconstituent(sector_code, sector_name, target_date) -> pd.DataFrame:
    """
    Fetch sector constituents from Wind
    :param sector_code:
    :param sector_name:
    :param target_date:
    :return:
    """
    target_date_str = date_2_str(target_date)
    logger.info('fetching constituents of sector %s %s %s', sector_code, sector_name, target_date)
    sec_df = invoker.wset(
        "sectorconstituent",
        "date=%s;sectorid=%s" % (target_date_str, sector_code))
    sec_df["sector_code"] = sector_code
    sec_df["sector_name"] = sector_name
    sec_df.rename(columns={
        'date': 'trade_date',
        'sec_name': 'stock_name',
    }, inplace=True)
    return sec_df
def df_2_table(doc, df, format_by_index=None, format_by_col=None,
               max_col_count=None, mark_top_n=None, mark_top_n_on_cols=None):
    """
    :param doc:
    :param df:
    :param format_by_index: formatting rules keyed by index
    :param format_by_col: formatting rules keyed by column
    :param max_col_count: maximum columns per table (excluding the index)
    :param mark_top_n: highlight the top N values
    :param mark_top_n_on_cols: columns on which to highlight the top N; None means no highlighting
    :return:
    """
    if max_col_count is None:
        max_col_count = df.shape[1]
    if mark_top_n is not None:
        if mark_top_n_on_cols is not None:
            rank_df = df[mark_top_n_on_cols]
        else:
            rank_df = df
        rank_df = rank_df.rank(ascending=False)
        is_in_rank_df = rank_df <= mark_top_n
    else:
        is_in_rank_df = None

    for table_num, col_name_list in enumerate(split_chunk(list(df.columns), max_col_count)):
        if table_num > 0:
            # when wrapping to a second/third/... table, emit an empty paragraph first
            doc.add_paragraph('')
        sub_df = df[col_name_list]
        row_num, col_num = sub_df.shape
        t = doc.add_table(row_num + 1, col_num + 1)

        # write head
        for j in range(col_num):
            # t.cell(0, j).text = df.columns[j]
            # paragraph = t.cell(0, j).add_paragraph()
            paragraph = t.cell(0, j + 1).paragraphs[0]
            paragraph.add_run(str(col_name_list[j])).bold = True
            paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER

        # write head background color
        for j in range(col_num + 1):
            t.cell(0, j)._tc.get_or_add_tcPr().append(
                parse_xml(r'<w:shd {} w:fill="00A2E8"/>'.format(nsdecls('w'))))

        # format table style to be a grid
        t.style = 'TableGrid'

        # populate the table with the dataframe
        for i in range(row_num):
            index = sub_df.index[i]
            paragraph = t.cell(i + 1, 0).paragraphs[0]
            index_str = str(date_2_str(index))
            paragraph.add_run(index_str).bold = True
            paragraph.alignment = WD_ALIGN_PARAGRAPH.LEFT
            if format_by_index is not None and index in format_by_index:
                format_row = format_by_index[index]
            else:
                format_row = None
            for j in range(col_num):
                col_name = col_name_list[j]
                if format_row is None and format_by_col is not None and col_name in format_by_col:
                    format_cell = format_by_col[col_name]
                else:
                    format_cell = format_row
                content = sub_df.values[i, j]
                if format_cell is None:
                    text = str(content)
                elif isinstance(format_cell, str):
                    text = str.format(format_cell, content)
                elif callable(format_cell):
                    text = format_cell(content)
                else:
                    # interpolate the message; ValueError does not apply % formatting itself
                    raise ValueError('%s: %s is not a valid formatter' % (index, format_cell))
                paragraph = t.cell(i + 1, j + 1).paragraphs[0]
                paragraph.alignment = WD_ALIGN_PARAGRAPH.RIGHT
                try:
                    style = paragraph.add_run(text)
                    if is_in_rank_df is not None and col_name in is_in_rank_df \
                            and is_in_rank_df.loc[index, col_name]:
                        style.font.color.rgb = RGBColor(0xed, 0x1c, 0x24)
                        style.bold = True
                except TypeError:
                    logger.exception('df.iloc[%d, %d] = df["%s", "%s"] = %s',
                                     i, j, index, col_name, text)
                    raise

        # shade every other data row
        for i in range(1, row_num + 1):
            for j in range(col_num + 1):
                if i % 2 == 0:
                    t.cell(i, j)._tc.get_or_add_tcPr().append(
                        parse_xml(r'<w:shd {} w:fill="A3D9EA"/>'.format(nsdecls('w'))))
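# A minimal python-docx sketch of the header/body writing pattern in df_2_table
# ('demo.docx' and the sample frame are illustrative; note that recent python-docx
# versions expect the spaced style name 'Table Grid').
import pandas as pd
from docx import Document
from docx.enum.text import WD_ALIGN_PARAGRAPH

_df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}, index=['r1', 'r2'])
_doc = Document()
_t = _doc.add_table(_df.shape[0] + 1, _df.shape[1] + 1)
_t.style = 'Table Grid'
for _j, _col in enumerate(_df.columns):  # bold, centered header row
    _p = _t.cell(0, _j + 1).paragraphs[0]
    _p.add_run(str(_col)).bold = True
    _p.alignment = WD_ALIGN_PARAGRAPH.CENTER
for _i, (_idx, _row) in enumerate(_df.iterrows()):  # bold index column + body cells
    _t.cell(_i + 1, 0).paragraphs[0].add_run(str(_idx)).bold = True
    for _j, _val in enumerate(_row):
        _t.cell(_i + 1, _j + 1).paragraphs[0].add_run(str(_val))
_doc.save('demo.docx')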
def import_index_daily(chain_param=None, wind_code_set=None):
    """Import index daily data
    :param chain_param: used when chaining celery tasks, to pass the previous result downstream
    :param wind_code_set: optional set of codes to restrict the import; None means all
        (the original body referenced this name without defining it, so it is added
        as a parameter here)
    :return:
    """
    table_name = "wind_index_daily"
    has_table = engine_md.has_table(table_name)
    col_name_param_list = [
        ('open', DOUBLE),
        ('high', DOUBLE),
        ('low', DOUBLE),
        ('close', DOUBLE),
        ('volume', DOUBLE),
        ('amt', DOUBLE),
        ('turn', DOUBLE),
        ('free_turn', DOUBLE),
    ]
    wind_indictor_str = ",".join([key for key, _ in col_name_param_list])
    rename_col_dic = {key.upper(): key.lower() for key, _ in col_name_param_list}
    dtype = {key: val for key, val in col_name_param_list}
    dtype['wind_code'] = String(20)
    # TODO: declaring 'trade_date' as Date currently breaks the insert for reasons
    # not yet understood; revisit later
    # dtype['trade_date'] = Date,
    # yesterday = date.today() - timedelta(days=1)
    # date_ending = date.today() - ONE_DAY if datetime.now().hour < BASE_LINE_HOUR else date.today()
    # sql_str = """select wii.wind_code, wii.sec_name, ifnull(adddate(latest_date, INTERVAL 1 DAY), wii.basedate) date_from
    #     from wind_index_info wii left join
    #     (
    #         select wind_code,index_name, max(trade_date) as latest_date
    #         from wind_index_daily group by wind_code
    #     ) daily
    #     on wii.wind_code=daily.wind_code"""
    # with with_db_session(engine_md) as session:
    #     table = session.execute(sql_str)
    #     wind_code_date_from_dic = {wind_code: (sec_name, date_from)
    #                                for wind_code, sec_name, date_from in table.fetchall()}
    # with with_db_session(engine_md) as session:
    #     # fetch valid market trading dates
    #     sql_str = "select trade_date from wind_trade_date where trade_date > '2005-1-1'"
    #     table = session.execute(sql_str)
    #     trade_date_sorted_list = [t[0] for t in table.fetchall()]
    #     trade_date_sorted_list.sort()
    # date_to = get_last(trade_date_sorted_list, lambda x: x <= date_ending)
    # data_len = len(wind_code_date_from_dic)
    if has_table:
        sql_str = """SELECT wind_code, date_frm, if(NULL<end_date, NULL, end_date) date_to
            FROM (
                SELECT info.wind_code, ifnull(trade_date, basedate) date_frm, NULL,
                    if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM wind_index_info info
                LEFT OUTER JOIN
                    (SELECT wind_code, adddate(max(trade_date),1) trade_date
                     FROM {table_name} GROUP BY wind_code) daily
                ON info.wind_code = daily.wind_code
            ) tt
            WHERE date_frm <= if(NULL<end_date, NULL, end_date)
            ORDER BY wind_code""".format(table_name=table_name)
    else:
        logger.warning('%s does not exist, calculating date ranges from wind_index_info only', table_name)
        sql_str = """SELECT wind_code, date_frm, if(NULL<end_date, NULL, end_date) date_to
            FROM (
                SELECT info.wind_code, basedate date_frm, NULL,
                    if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM wind_index_info info
            ) tt
            WHERE date_frm <= if(NULL<end_date, NULL, end_date)
            ORDER BY wind_code;"""
    with with_db_session(engine_md) as session:
        # fetch, per index, the date range over which daily data is needed
        table = session.execute(sql_str)
        begin_time = None
        wind_code_date_from_dic = {
            wind_code: (date_from if begin_time is None else min([date_from, begin_time]), date_to)
            for wind_code, date_from, date_to in table.fetchall()
            if wind_code_set is None or wind_code in wind_code_set
        }
    data_len = len(wind_code_date_from_dic)
    logger.info('%d indexes will be imported', data_len)
    for data_num, (wind_code, (date_from, date_to)) in enumerate(wind_code_date_from_dic.items()):
        if str_2_date(date_from) > date_to:
            logger.warning("%d/%d) %s %s - %s skipped", data_num, data_len, wind_code, date_from, date_to)
            continue
        try:
            temp = invoker.wsd(wind_code, wind_indictor_str, date_from, date_to)
        except APIError as exp:
            logger.exception("%d/%d) %s raised an exception", data_num, data_len, wind_code)
            if exp.ret_dic.setdefault('error_code', 0) in (
                    -40520007,  # no data available
                    -40521009,  # data decoding failed; check input arguments, e.g. month-end dates and short February
            ):
                continue
            else:
                break
        temp.reset_index(inplace=True)
        temp.rename(columns={'index': 'trade_date'}, inplace=True)
        temp.rename(columns=rename_col_dic, inplace=True)
        temp.trade_date = temp.trade_date.apply(str_2_date)
        temp['wind_code'] = wind_code
        bunch_insert_on_duplicate_update(temp, table_name, engine_md, dtype=dtype)
        logger.info('index %s updated through %s', wind_code, date_2_str(date_to))

    if not has_table and engine_md.has_table(table_name):
        alter_table_2_myisam(engine_md, [table_name])
        build_primary_key([table_name])
def import_jq_stock_daily(chain_param=None, code_set=None):
    """
    Insert stock daily bars up to the most recent workday - 1;
    if past BASE_LINE_HOUR, fetch the current day's data as well
    :return:
    """
    table_name_info = TABLE_NAME_INFO
    table_name = TABLE_NAME
    table_name_bak = get_bak_table_name(table_name)
    logging.info("updating %s, started", table_name)
    # look up each stock's date range from the info table
    sql_info_str = f"""SELECT jq_code, date_frm, if(date_to<end_date, date_to, end_date) date_to
        FROM (
            SELECT info.jq_code, start_date date_frm, end_date date_to,
                if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
            FROM {table_name_info} info
        ) tt
        WHERE date_frm <= if(date_to<end_date, date_to, end_date)
        ORDER BY jq_code"""
    has_table = engine_md.has_table(table_name)
    has_bak_table = engine_md.has_table(table_name_bak)
    # decide, based on table existence, whether jq_stock_daily_md is available
    if has_table:
        # The original SQL was adjusted here.
        # Old logic: start from each stock's max trade date + 1 day.
        # New logic: start from each stock's max trade date itself, so the factor
        # of the latest stored trade date can be fetched again through the API
        # and compared.
        sql_trade_date_range_str = f"""SELECT jq_code, date_frm, if(date_to<end_date, date_to, end_date) date_to
            FROM (
                SELECT info.jq_code, ifnull(trade_date, info.start_date) date_frm, info.end_date date_to,
                    if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM {table_name_info} info
                LEFT OUTER JOIN
                    (SELECT jq_code, max(trade_date) trade_date FROM {table_name} GROUP BY jq_code) daily
                ON info.jq_code = daily.jq_code
            ) tt
            WHERE date_frm < if(date_to<end_date, date_to, end_date)
            ORDER BY jq_code"""
    else:
        sql_trade_date_range_str = sql_info_str
        logger.warning('%s does not exist, calculating date ranges from %s only',
                       table_name, table_name_info)

    sql_trade_date_str = """SELECT trade_date FROM jq_trade_date trddate
        WHERE trade_date <= if(hour(now())<16, subdate(curdate(),1), curdate())
        ORDER BY trade_date"""
    with with_db_session(engine_md) as session:
        # fetch all trading dates up to the current period
        table = session.execute(sql_trade_date_str)
        trade_date_list = [row[0] for row in table.fetchall()]
        trade_date_list.sort()
        # fetch each stock's daily-data date range as {code: (date_from, date_to)}
        table = session.execute(sql_trade_date_range_str)
        code_date_range_dic = {
            key_code: (date_from, date_to)
            for key_code, date_from, date_to in table.fetchall()
            if code_set is None or key_code in code_set
        }
        # fetch the full date ranges from the info table
        if sql_info_str == sql_trade_date_range_str:
            code_date_range_from_info_dic = code_date_range_dic
        else:
            table = session.execute(sql_info_str)
            code_date_range_from_info_dic = {
                key_code: (date_from, date_to)
                for key_code, date_from, date_to in table.fetchall()
                if code_set is None or key_code in code_set
            }

    data_df_list, data_count, all_data_count, data_len = [], 0, 0, len(code_date_range_dic)
    logger.info('%d stocks will be imported into %s', data_len, table_name)
    # accumulate data_df frames into data_df_list
    try:
        for num, (key_code, (date_from_tmp, date_to_tmp)) in enumerate(code_date_range_dic.items(), start=1):
            data_df = None
            try:
                for loop_count in range(2):
                    # intersect with actual trading dates to avoid pointless requests
                    date_from = get_first(trade_date_list, lambda x: x >= date_from_tmp)
                    date_to = get_last(trade_date_list, lambda x: x <= date_to_tmp)
                    if date_from is None or date_to is None or date_from >= date_to:
                        logger.debug('%d/%d) %s [%s - %s] skipped', num, data_len, key_code, date_from, date_to)
                        break
                    logger.debug('%d/%d) %s [%s - %s] %s', num, data_len, key_code, date_from, date_to,
                                 'second query' if loop_count > 0 else '')
                    data_df = invoke_daily(key_code=key_code,
                                           start_date=date_2_str(date_from),
                                           end_date=date_2_str(date_to))
                    # this check only runs on the first loop iteration
                    if loop_count == 0 and has_table:
                        # Factor check: if the factor of the earliest returned trade
                        # date is not 1, delete the stock's entire stored history and
                        # re-download it. The downloaded prices are forward-adjusted,
                        # so any new adjustment event invalidates the stored history.
                        factor_value = data_df.sort_values('trade_date').iloc[0, :]['factor']
                        if factor_value != 1 and (
                                code_date_range_from_info_dic[key_code][0] !=
                                code_date_range_dic[key_code][0]):
                            # delete the stock's history
                            sql_str = f"delete from {table_name} where jq_code=:jq_code"
                            row_count = execute_sql_commit(sql_str, params={'jq_code': key_code})
                            date_from_tmp, date_to_tmp = code_date_range_from_info_dic[key_code]
                            if has_bak_table:
                                sql_str = f"delete from {table_name_bak} where jq_code=:jq_code"
                                row_count = execute_sql_commit(sql_str, params={'jq_code': key_code})
                                date_from_tmp, date_to_tmp = code_date_range_from_info_dic[key_code]
                                logger.info(
                                    '%d/%d) %s: %d historical rows purged; reloading forward-adjusted '
                                    'history [%s - %s] and clearing matching rows from the bak table',
                                    num, data_len, key_code, row_count, date_from_tmp, date_to_tmp)
                            else:
                                logger.info(
                                    '%d/%d) %s: %d historical rows purged; reloading '
                                    'forward-adjusted history [%s - %s]',
                                    num, data_len, key_code, row_count, date_from_tmp, date_to_tmp)
                            # reset the date range and run the second loop iteration
                            continue
                    # leave the for loop_count in range(2) loop
                    break
            except Exception as exp:
                data_df = None
                logger.exception('%s [%s - %s]', key_code,
                                 date_2_str(date_from_tmp), date_2_str(date_to_tmp))
                # str.find returns -1 when not found, so compare explicitly
                # (the original truthiness test was almost always True)
                if exp.args and isinstance(exp.args[0], str) \
                        and exp.args[0].find('超过了每日最大查询限制') != -1:
                    break
            # buffer the data
            if data_df is not None and data_df.shape[0] > 0:
                data_count += data_df.shape[0]
                data_df_list.append(data_df)
            # once past the threshold, flush to the database
            if data_count >= 500:
                data_df_all = pd.concat(data_df_list)
                bunch_insert(data_df_all, table_name, dtype=DTYPE, primary_keys=['jq_code', 'trade_date'])
                all_data_count += data_count
                data_df_list, data_count = [], 0
            if DEBUG and num >= 2:
                break
    finally:
        # write to the database
        if len(data_df_list) > 0:
            data_df_all = pd.concat(data_df_list)
            data_count = bunch_insert(data_df_all, table_name, dtype=DTYPE,
                                      primary_keys=['jq_code', 'trade_date'])
            all_data_count = all_data_count + data_count
        logging.info("update %s finished, %d rows updated", table_name, all_data_count)
def transfer_2_batch(df: pd.DataFrame, n_step, labels=None, date_from=None, date_to=None):
    """
    [num, factor_count] -> [num - n_step + 1, n_step, factor_count]
    Slice df into overlapping windows of length n_step.
    labels is data aligned with df and is handled the same way as the index;
    if labels is not None, new_ys is appended to the returned values
    :param df:
    :param n_step:
    :param labels: if not None, its length must equal df.shape[0]
    :param date_from:
    :param date_to:
    :return:
    """
    df_len = df.shape[0]
    if labels is not None and df_len != len(labels):
        raise ValueError("labels length %d must match df length %d" % (len(labels), df_len))
    # TODO: the date_from / date_to logic can be optimized further; kept simple for now to save time
    # truncate the factors according to date_from
    if date_from is not None:
        date_from = pd.to_datetime(date_from)
        is_fit = df.index >= date_from
        if np.any(is_fit):
            start_idx = np.argmax(is_fit) - n_step
            if start_idx < 0:
                start_idx = 0
                # the original message hard-coded index 60 here; n_step - 1 is the
                # first date that actually survives the windowing
                logger.warning("fewer than %d rows of history before %s; start date pushed back to %s",
                               n_step, date_2_str(date_from), date_2_str(df.index[n_step - 1]))
            df = df.iloc[start_idx:]
            df_len = df.shape[0]
            if labels is not None:
                labels = labels[start_idx:]
        else:
            logger.warning("no data after %s; latest available date is %s",
                           date_2_str(date_from), date_2_str(max(df.index)))
            if labels is not None:
                return None, None, None, None
            else:
                return None, None, None
    # truncate the factors according to date_to
    if date_to is not None:
        date_to = pd.to_datetime(date_to)
        is_fit = df.index <= date_to
        if np.any(is_fit):
            # np.argmin would return 0 when every row fits, so guard that case explicitly
            to_idx = df_len if np.all(is_fit) else np.argmin(is_fit)
            df = df.iloc[:to_idx]
            df_len = df.shape[0]
            if labels is not None:
                labels = labels[:to_idx]
        else:
            logger.warning("no data before %s; earliest available date is %s",
                           date_2_str(date_to), date_2_str(min(df.index)))
            if labels is not None:
                return None, None, None, None
            else:
                return None, None, None

    new_shape = [df_len - n_step + 1, n_step]
    new_shape.extend(df.shape[1:])
    df_index, df_columns = df.index[(n_step - 1):], df.columns
    data_arr_batch, factor_arr = np.zeros(new_shape), df.to_numpy(dtype=np.float32)
    for idx_from, idx_to in enumerate(range(n_step, factor_arr.shape[0] + 1)):
        data_arr_batch[idx_from] = factor_arr[idx_from:idx_to]
    if labels is not None:
        new_ys = labels[(n_step - 1):]
        return df_index, df_columns, data_arr_batch, new_ys
    else:
        return df_index, df_columns, data_arr_batch
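# A numpy sketch of the [num, factors] -> [num - n_step + 1, n_step, factors]
# windowing performed by transfer_2_batch (shapes and data are illustrative).
import numpy as np
import pandas as pd

_n_step = 3
_df = pd.DataFrame(np.arange(10).reshape(5, 2), columns=['f1', 'f2'])
_arr = _df.to_numpy()
_batch = np.stack([_arr[i:i + _n_step] for i in range(_df.shape[0] - _n_step + 1)])
# _batch.shape == (3, 3, 2); _batch[k] holds rows k .. k + _n_step - 1,
# and the surviving index starts at _df.index[_n_step - 1]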
def merge_tushare_stock_daily(ths_code_set: set = None, date_from=None):
    """Merge A-share market data and financial data into daily-level records"""
    table_name = 'tushare_stock_daily'
    logging.info("merging %s, started", table_name)
    has_table = engine_md.has_table(table_name)
    if date_from is None and has_table:
        sql_str = "select adddate(max(`trade_date`),1) from {table_name}".format(table_name=table_name)
        with with_db_session(engine_md) as session:
            date_from = date_2_str(session.execute(sql_str).scalar())
    # fetch the daily-level data
    # TODO: support the ths_code_set parameter here
    daily_df, dtype_daily = get_tushare_daily_merged_df(ths_code_set, date_from)
    daily_df_g = daily_df.groupby('ts_code')
    ths_code_set_4_daily = set(daily_df_g.size().index)
    # fetch the merged financial data
    ifind_fin_df, dtype_fin = get_tushre_merge_stock_fin_df()
    # assemble dtype
    dtype = dtype_daily.copy()
    dtype.update(dtype_fin)
    logging.debug("financial data extracted")
    # compute financial-report announcement dates
    report_date_dic_dic = {}
    for num, ((ths_code, report_date), data_df) in enumerate(
            ifind_fin_df.groupby(['ts_code', 'f_ann_date']), start=1):
        if ths_code_set is not None and ths_code not in ths_code_set:
            continue
        if is_nan_or_none(report_date):
            continue
        report_date_dic = report_date_dic_dic.setdefault(ths_code, {})
        # note: the original checked report_date_dic_dic here, which looks like a
        # slip; the per-code dict is the one that matters
        if report_date not in report_date_dic:
            if data_df.shape[0] > 0:
                report_date_dic[report_date] = data_df.iloc[0]
    logger.debug("report dates computed")

    # assemble the data_df records
    tot_data_count, data_count, data_df_list, for_count = 0, 0, [], len(report_date_dic_dic)
    try:
        for num, (ths_code, report_date_dic) in enumerate(report_date_dic_dic.items(), start=1):
            # key: ths_code
            # TODO: properly check whether ths_code exists in ifind_fin_df_g;
            # the size-based check is a stopgap, improve later
            if ths_code not in ths_code_set_4_daily:
                logger.error('fin table has no financial data for %s', ths_code)
                continue
            daily_df_cur_ts_code = daily_df_g.get_group(ths_code)
            logger.debug('%d/%d) processing %s, %d rows', num, for_count, ths_code,
                         daily_df_cur_ts_code.shape[0])
            report_date_list = list(report_date_dic.keys())
            report_date_list.sort()
            report_date_list_len = len(report_date_list)
            for num_sub, (report_date_from, report_date_to) in enumerate(iter_2_range(report_date_list)):
                logger.debug('%d/%d) %d/%d) processing %s [%s - %s]',
                             num, for_count, num_sub, report_date_list_len, ths_code,
                             date_2_str(report_date_from), date_2_str(report_date_to))
                # compute the applicable date window
                if report_date_from is None:
                    is_fit = daily_df_cur_ts_code['trade_date'] < report_date_to
                elif report_date_to is None:
                    is_fit = daily_df_cur_ts_code['trade_date'] >= report_date_from
                else:
                    is_fit = (daily_df_cur_ts_code['trade_date'] < report_date_to) & (
                            daily_df_cur_ts_code['trade_date'] >= report_date_from)
                # slice the rows inside that window
                ifind_his_ds_df_segment = daily_df_cur_ts_code[is_fit].copy()
                segment_count = ifind_his_ds_df_segment.shape[0]
                if segment_count == 0:
                    continue
                fin_s = report_date_dic[report_date_from] if report_date_from is not None else None
                for key in dtype_fin.keys():
                    if key in ('ts_code', 'trade_date'):
                        continue
                    ifind_his_ds_df_segment[key] = fin_s[key] if fin_s is not None and key in fin_s else None
                ifind_his_ds_df_segment['report_date'] = report_date_from
                # buffer the segment
                data_df_list.append(ifind_his_ds_df_segment)
                data_count += segment_count
                if DEBUG and len(data_df_list) > 1:
                    break
            # flush to the database
            if data_count > 1000:
                data_df = pd.concat(data_df_list)
                data_count = bunch_insert_on_duplicate_update(
                    data_df, table_name, engine_md, dtype, myisam_if_create_table=True)
                tot_data_count += data_count
                data_count, data_df_list = 0, []
    finally:
        # flush to the database
        if len(data_df_list) > 0:
            data_df = pd.concat(data_df_list)
            data_count = bunch_insert_on_duplicate_update(
                data_df, table_name, engine_md, dtype, myisam_if_create_table=True)
            tot_data_count += data_count
        logger.info('%s: %d rows inserted or updated', table_name, tot_data_count)

    if not has_table and engine_md.has_table(table_name):
        build_primary_key([table_name])
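# For reference: pd.merge_asof achieves the same "latest report as of each trade
# date" alignment that the segment loop above implements manually (sample frames
# are made up; the real code also keeps extra bookkeeping such as report_date).
import pandas as pd

_daily = pd.DataFrame({'trade_date': pd.to_datetime(['2020-01-02', '2020-04-02', '2020-09-01'])})
_fin = pd.DataFrame({'f_ann_date': pd.to_datetime(['2020-03-31', '2020-08-30']), 'eps': [0.5, 0.7]})
_merged = pd.merge_asof(_daily, _fin, left_on='trade_date', right_on='f_ann_date')
# rows before the first announcement get NaN; later rows carry the latest eps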
def invoke_index_weight(index_code, trade_date):
    trade_date = date_2_str(trade_date, STR_FORMAT_DATE_TS)
    # use a local name distinct from the function name to avoid shadowing
    index_weight_df = pro.index_weight(index_code=index_code, trade_date=trade_date)
    return index_weight_df
def import_future_info(chain_param=None):
    """
    Update the futures contract list
    :param chain_param: used only when chaining tasks via task.chain, to pass parameters downstream
    :return:
    """
    table_name = 'ifind_future_info'
    has_table = engine_md.has_table(table_name)
    logger.info("updating %s [%s], started", table_name, has_table)
    # fetch the list of contracts already stored
    if has_table:
        sql_str = f'SELECT ths_code, ths_start_trade_date_future FROM {table_name}'
        with with_db_session(engine_md) as session:
            table = session.execute(sql_str)
            code_ipo_date_dic = dict(table.fetchall())
        exchange_latest_ipo_date_dic = get_exchange_latest_data()
    else:
        code_ipo_date_dic = {}
        exchange_latest_ipo_date_dic = {}

    exchange_sectorid_dic_list = [
        {'exch_eng': 'SHFE', 'exchange_name': '上海期货交易所',
         'sectorid': '091001', 'date_establish': '1995-05-10'},
        {'exch_eng': 'CFFEX', 'exchange_name': '中国金融期货交易所',
         'sectorid': '091004', 'date_establish': '2013-09-10'},
        {'exch_eng': 'DCE', 'exchange_name': '大连商品交易所',
         'sectorid': '091002', 'date_establish': '1999-01-10'},
        {'exch_eng': 'CZCE', 'exchange_name': '郑州商品交易所',
         'sectorid': '091003', 'date_establish': '1999-01-10'},
    ]
    # indicator list and parameters
    indicator_param_list = [
        ('ths_future_short_name_future', '', String(50)),
        ('ths_future_code_future', '', String(20)),
        ('ths_sec_type_future', '', String(20)),
        ('ths_td_variety_future', '', String(20)),
        ('ths_td_unit_future', '', DOUBLE),
        ('ths_pricing_unit_future', '', String(20)),
        ('ths_mini_chg_price_future', '', DOUBLE),
        ('ths_chg_ratio_lmit_future', '', DOUBLE),
        ('ths_td_deposit_future', '', DOUBLE),
        ('ths_start_trade_date_future', '', Date),
        ('ths_last_td_date_future', '', Date),
        ('ths_last_delivery_date_future', '', Date),
        ('ths_delivery_month_future', '', String(10)),
        ('ths_listing_benchmark_price_future', '', DOUBLE),
        ('ths_initial_td_deposit_future', '', DOUBLE),
        ('ths_contract_month_explain_future', '', String(120)),
        ('ths_td_time_explain_future', '', String(120)),
        ('ths_last_td_date_explian_future', '', String(120)),
        ('ths_delivery_date_explain_future', '', String(120)),
        ('ths_exchange_short_name_future', '', String(50)),
        ('ths_contract_en_short_name_future', '', String(50)),
        ('ths_contract_en_name_future', '', String(50)),
    ]
    json_indicator, json_param = unzip_join(
        [(key, val) for key, val, _ in indicator_param_list], sep=';')
    # set dtype
    dtype = {key: val for key, _, val in indicator_param_list}
    dtype['ths_code'] = String(20)
    dtype['exch_eng'] = String(20)

    # collect the contract list
    code_set = set()
    ndays_per_update = 90
    # fetch the historical futures contract lists
    sector_count = len(exchange_sectorid_dic_list)
    for num, exchange_sectorid_dic in enumerate(exchange_sectorid_dic_list, start=1):
        exchange_name = exchange_sectorid_dic['exchange_name']
        exch_eng = exchange_sectorid_dic['exch_eng']
        sector_id = exchange_sectorid_dic['sectorid']
        date_establish = exchange_sectorid_dic['date_establish']
        # compute the start date for fetching the contract list
        date_since = str_2_date(
            exchange_latest_ipo_date_dic.setdefault(exch_eng, date_establish))
        date_yestoday = date.today() - timedelta(days=1)
        logger.info("%d/%d) %s[%s][%s] %s ~ %s", num, sector_count, exchange_name,
                    exch_eng, sector_id, date_since, date_yestoday)
        while date_since <= date_yestoday:
            date_since_str = date_2_str(date_since)
            # Data pool, sector constituents by date (iFinD data API): fetch the
            # sector's constituents, i.e. the exchange's futures contracts, e.g.
            # THS_DP('block','2021-01-15;091002003','date:Y,thscode:Y,security_name:Y,security_name_in_time:Y')
            try:
                future_info_df = invoker.THS_DataPool(
                    'block', '%s;%s' % (date_since_str, sector_id),
                    'thscode:Y,security_name:Y')
            except APIError as exp:
                if exp.ret_dic['error_code'] in (-4001, -4210,):
                    future_info_df = None
                else:
                    logger.exception("THS_DataPool %s failed, '%s;%s'",
                                     exchange_name, date_since_str, sector_id)
                    break
            # if future_info_df is None or future_info_df.shape[0] == 0:
            #     break
            if future_info_df is not None and future_info_df.shape[0] > 0:
                code_set |= set(future_info_df['THSCODE'])
            if date_since >= date_yestoday:
                break
            else:
                date_since += timedelta(days=ndays_per_update)
                if date_since > date_yestoday:
                    date_since = date_yestoday
        if DEBUG:
            break

    # keep only contracts that are not yet in the database
    code_list = [wc for wc in code_set if wc not in code_ipo_date_dic]
    # fetch the contracts' basic data
    if len(code_list) > 0:
        for code_sub_list in split_chunk(code_list, 500):
            future_info_df = invoker.THS_BasicData(code_sub_list, json_indicator, json_param)
            if future_info_df is None or future_info_df.shape[0] == 0:
                data_count = 0
                logger.warning("update %s finished, %d rows updated", table_name, data_count)
            else:
                # fill in the exch_eng column
                future_info_df['exch_eng'] = ''
                for exchange_sectorid_dic in exchange_sectorid_dic_list:
                    # use .loc to avoid pandas chained-assignment issues
                    future_info_df.loc[
                        future_info_df['ths_exchange_short_name_future'] ==
                        exchange_sectorid_dic['exchange_name'],
                        'exch_eng'] = exchange_sectorid_dic['exch_eng']
                data_count = bunch_insert_on_duplicate_update(
                    future_info_df, table_name, engine_md, dtype,
                    primary_keys=['ths_code'], schema=config.DB_SCHEMA_MD)
                logger.info("update %s finished, %d rows updated", table_name, data_count)
def import_index_daily_ds(chain_param=None, ths_code_set: set = None, begin_time=None):
    """
    Save history through the date-serial interface into ifind_index_daily_ds; this
    supplements the History data with, e.g., the adjustment factor af, limit-up/down
    flags, and suspension status and reasons
    :param chain_param: used only when chaining tasks via task.chain, to pass parameters downstream
    :param ths_code_set:
    :param begin_time:
    :return:
    """
    table_name = 'ifind_index_daily_ds'
    has_table = engine_md.has_table(table_name)
    json_indicator, json_param = unzip_join(
        [(key, val) for key, val, _ in INDICATOR_PARAM_LIST_INDEX_DAILY_DS], sep=';')
    if has_table:
        sql_str = """SELECT ths_code, date_frm, if(NULL<end_date, NULL, end_date) date_to
            FROM (
                SELECT info.ths_code,
                    ifnull(trade_date_max_1, ths_index_base_period_index) date_frm, NULL,
                    if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM ifind_index_info info
                LEFT OUTER JOIN
                    (SELECT ths_code, adddate(max(time),1) trade_date_max_1
                     FROM {table_name} GROUP BY ths_code) daily
                ON info.ths_code = daily.ths_code
            ) tt
            WHERE date_frm <= if(NULL<end_date, NULL, end_date)
            ORDER BY ths_code""".format(table_name=table_name)
    else:
        sql_str = """SELECT ths_code, date_frm, if(NULL<end_date, NULL, end_date) date_to
            FROM (
                SELECT info.ths_code, ths_index_base_period_index date_frm, NULL,
                    if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM ifind_index_info info
            ) tt
            WHERE date_frm <= if(NULL<end_date, NULL, end_date)
            ORDER BY ths_code;"""
        logger.warning('%s does not exist, calculating date ranges from ifind_index_info only', table_name)

    with with_db_session(engine_md) as session:
        # fetch, per index, the date range over which daily data is needed
        table = session.execute(sql_str)
        code_date_range_dic = {
            ths_code: (date_from if begin_time is None else min([date_from, begin_time]), date_to)
            for ths_code, date_from, date_to in table.fetchall()
            if ths_code_set is None or ths_code in ths_code_set
        }

    if TRIAL:
        date_from_min = date.today() - timedelta(days=(365 * 5))  # trial accounts can only fetch the last 5 years
        code_date_range_dic = {
            ths_code: (max([date_from, date_from_min]), date_to)
            for ths_code, (date_from, date_to) in code_date_range_dic.items()
            if date_to is not None and date_from_min <= date_to
        }

    data_df_list, data_count, tot_data_count, code_count = [], 0, 0, len(code_date_range_dic)
    try:
        for num, (ths_code, (begin_time, end_time)) in enumerate(code_date_range_dic.items(), start=1):
            logger.debug('%d/%d) %s [%s - %s]', num, code_count, ths_code, begin_time, end_time)
            end_time = date_2_str(end_time)
            data_df = invoker.THS_DateSerial(
                ths_code, json_indicator, json_param,
                'Days:Tradedays,Fill:Previous,Interval:D',
                begin_time, end_time)
            if data_df is not None and data_df.shape[0] > 0:
                data_count += data_df.shape[0]
                data_df_list.append(data_df)
            # once past the threshold, flush to the database
            if data_count >= 10000:
                data_df_all = pd.concat(data_df_list)
                # data_df_all.to_sql(table_name, engine_md, if_exists='append', index=False, dtype=dtype)
                data_count = bunch_insert_on_duplicate_update(
                    data_df_all, table_name, engine_md, DTYPE_INDEX_DAILY_DS)
                tot_data_count += data_count
                data_df_list, data_count = [], 0
            # debugging only
            if DEBUG and len(data_df_list) > 1:
                break
    finally:
        if data_count > 0:
            data_df_all = pd.concat(data_df_list)
            data_count = bunch_insert_on_duplicate_update(
                data_df_all, table_name, engine_md, DTYPE_INDEX_DAILY_DS)
            tot_data_count += data_count
        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            build_primary_key([table_name])
        logging.info("update %s finished, %d rows inserted", table_name, tot_data_count)
def update_private_fund_nav(chain_param=None, get_df=False, wind_code_list=None):
    """
    :param chain_param: used when chaining celery tasks, to pass the previous result downstream
    :param get_df:
    :param wind_code_list:
    :return:
    """
    table_name = 'wind_fund_nav'
    # initialize the data download port and database engine,
    # connect to the database, and read the legacy fundnav table
    # with get_db_session(engine) as session:
    #     table = session.execute('select wind_code, ADDDATE(max(trade_date),1) from wind_fund_nav group by wind_code')
    #     fund_trade_date_begin_dic = dict(table.fetchall())
    # read wind_fund_info
    has_table = engine_md.has_table(table_name)
    if has_table:
        fund_info_df = pd.read_sql_query(
            """SELECT DISTINCT fi.wind_code AS wind_code,
                IFNULL(trade_date_from,
                    if(trade_date_latest BETWEEN '1900-01-01' AND ADDDATE(CURDATE(), -1),
                       ADDDATE(trade_date_latest,1), fund_setupdate)
                ) date_from,
                if(fund_maturitydate BETWEEN '1900-01-01' AND ADDDATE(CURDATE(), -1),
                   fund_maturitydate, ADDDATE(CURDATE(), -1)) date_to
            FROM fund_info fi
            LEFT JOIN (
                SELECT wind_code, ADDDATE(max(trade_date),1) trade_date_from
                FROM wind_fund_nav GROUP BY wind_code
            ) wfn
            ON fi.wind_code = wfn.wind_code""", engine_md)
    else:
        logger.warning('wind_fund_nav does not exist, calculating date ranges from fund_info only')
        fund_info_df = pd.read_sql_query(
            """SELECT DISTINCT fi.wind_code AS wind_code, fund_setupdate date_from,
                if(fund_maturitydate BETWEEN '1900-01-01' AND ADDDATE(CURDATE(), -1),
                   fund_maturitydate, ADDDATE(CURDATE(), -1)) date_to
            FROM fund_info fi ORDER BY wind_code""", engine_md)

    wind_code_date_frm_to_dic = {
        wind_code: (str_2_date(date_from), str_2_date(date_to))
        for wind_code, date_from, date_to in zip(
            fund_info_df['wind_code'], fund_info_df['date_from'], fund_info_df['date_to'])
    }
    fund_info_df.set_index('wind_code', inplace=True)
    if wind_code_list is None:
        wind_code_list = list(fund_info_df.index)
    else:
        wind_code_list = list(set(wind_code_list) & set(fund_info_df.index))

    # end date
    date_last_day = date.today() - timedelta(days=1)
    # date_end_str = date_end.strftime(STR_FORMAT_DATE)
    fund_nav_all_df = []
    no_data_count = 0
    code_count = len(wind_code_list)
    # for each fund, if it already exists in fundnav, fetch only the missing nav range
    wind_code_trade_date_latest_dic = {}
    date_gap = timedelta(days=10)
    try:
        for num, wind_code in enumerate(wind_code_list):
            date_begin, date_end = wind_code_date_frm_to_dic[wind_code]
            # if date_end > date_last_day:
            #     date_end = date_last_day
            if date_begin > date_end:
                continue
            # set the start date for the fetch
            # wind_code_trade_date_latest_dic[wind_code] = date_to
            # if wind_code in fund_trade_date_begin_dic:
            #     trade_latest = fund_trade_date_begin_dic[wind_code]
            #     if trade_latest > date_end:
            #         continue
            #     date_begin = max([date_begin, trade_latest])
            # if date_begin is None:
            #     continue
            # elif isinstance(date_begin, str):
            #     date_begin = datetime.strptime(date_begin, STR_FORMAT_DATE).date()
            if isinstance(date_begin, date):
                if date_begin.year < 1900:
                    continue
                if date_begin > date_end:
                    continue
                date_begin_str = date_begin.strftime('%Y-%m-%d')
            else:
                logger.error("%s date_begin:%s", wind_code, date_begin)
                continue
            if isinstance(date_end, date):
                # note: the original checked date_begin here, which looks like a copy-paste slip
                if date_end.year < 1900:
                    continue
                if date_begin > date_end:
                    continue
                date_end_str = date_end.strftime('%Y-%m-%d')
            else:
                logger.error("%s date_end:%s", wind_code, date_end)
                continue

            # try to fetch fund_nav data
            for k in range(2):
                try:
                    fund_nav_tmp_df = invoker.wsd(
                        codes=wind_code, fields='nav,NAV_acc,NAV_date',
                        beginTime=date_2_str(date_begin_str),
                        endTime=date_2_str(date_end_str),
                        options='Fill=Previous')
                    trade_date_latest = datetime.strptime(date_end_str, '%Y-%m-%d').date() - date_gap
                    wind_code_trade_date_latest_dic[wind_code] = trade_date_latest
                    break
                except APIError as exp:
                    # -40520007
                    if exp.ret_dic.setdefault('error_code', 0) == -40520007:
                        trade_date_latest = datetime.strptime(date_end_str, '%Y-%m-%d').date() - date_gap
                        wind_code_trade_date_latest_dic[wind_code] = trade_date_latest
                    logger.error("%s Failed, ErrorMsg: %s" % (wind_code, str(exp)))
                    continue
                except Exception as exp:
                    logger.error("%s Failed, ErrorMsg: %s" % (wind_code, str(exp)))
                    continue
            else:
                fund_nav_tmp_df = None

            if fund_nav_tmp_df is None:
                logger.info('%s No data', wind_code)
                # del wind_code_trade_date_latest_dic[wind_code]
                no_data_count += 1
                logger.warning('%d funds no data', no_data_count)
            else:
                fund_nav_tmp_df.dropna(how='all', inplace=True)
                df_len = fund_nav_tmp_df.shape[0]
                if df_len == 0:
                    continue
                fund_nav_tmp_df['wind_code'] = wind_code
                # trade_date_latest is removed here and re-added afterwards, so that a
                # thrown exception does not leave this record marked as updated
                # del wind_code_trade_date_latest_dic[wind_code]
                trade_date_latest = fund_nav_df_2_sql(table_name, fund_nav_tmp_df, engine_md, is_append=True)
                if trade_date_latest is None:
                    logger.error('%s data insert failed', wind_code)
                else:
                    wind_code_trade_date_latest_dic[wind_code] = trade_date_latest
                    logger.info('%d) %s updated, %d funds left', num, wind_code, code_count - num)
                if get_df:
                    # list.append returns None, so do not reassign (the original did)
                    fund_nav_all_df.append(fund_nav_tmp_df)
            if DEBUG and num > 4:
                # debugging only
                break
    finally:
        # import_wind_fund_nav_to_fund_nav()
        # update_trade_date_latest(wind_code_trade_date_latest_dic)
        # try:
        #     update_fund_mgrcomp_info()
        # except:
        # before this feature goes live the table may not exist, so a failed update here is expected
        logger.exception('update may fail if the table does not exist before the new feature goes live; this is expected')
        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            build_primary_key([table_name])
    return fund_nav_all_df