def tushare_to_sqlite_tot_select(file_name, table_name, field_pair_list):
    """
    Export MySQL data into sqlite: read the whole table in one query, then write out per code.
    Fast, but requires more memory.
    :param file_name: target sqlite file name (created inside the 'sqlite_db' folder)
    :param table_name: source MySQL table name
    :param field_pair_list: list of (source_field, renamed_field) pairs, or None to keep all fields
    :return: None
    """
    logger.info('mysql %s 导入到 sqlite %s 开始', table_name, file_name)
    sqlite_db_folder_path = get_folder_path('sqlite_db', create_if_not_found=False)
    db_file_path = os.path.join(sqlite_db_folder_path, file_name)
    conn = sqlite3.connect(db_file_path)
    try:
        sql_str = f"select * from {table_name}"
        df_tot = pd.read_sql(sql_str, engine_md)
        # Filter and rename fields when a mapping is supplied; ts_code is always kept
        # because it is the group-by key below.
        if field_pair_list is not None:
            field_list = [_[0] for _ in field_pair_list]
            field_list.append('ts_code')
            df_tot = df_tot[field_list].rename(columns=dict(field_pair_list))
        dfg = df_tot.groupby('ts_code')
        num, code_count, data_count = 0, len(dfg), 0
        for num, (ts_code, df) in enumerate(dfg, start=1):
            # "000001.SZ" -> sqlite table "SZ000001" (exchange prefix first)
            code_exchange = ts_code.split('.')
            sqlite_table_name = f"{code_exchange[1]}{code_exchange[0]}"
            df_len = df.shape[0]
            data_count += df_len
            logger.debug('%4d/%d) mysql %s -> sqlite %s %s %d 条记录',
                         num, code_count, table_name, file_name, sqlite_table_name, df_len)
            df.to_sql(sqlite_table_name, conn, index=False, if_exists='replace')
    finally:
        # Fix: the connection was previously never closed (resource leak).
        conn.close()
    logger.info('mysql %s 导入到 sqlite %s 结束,导出数据 %d 条', table_name, file_name, data_count)
def tushare_to_sqlite_pre_ts_code(file_name, table_name, field_pair_list):
    """
    Export MySQL data into sqlite, querying one ts_code at a time.
    Slow; use only when memory is tight or little data needs to be exported.
    :param file_name: target sqlite file name (created inside the 'sqlite_db' folder)
    :param table_name: source MySQL table name
    :param field_pair_list: list of (source_field, renamed_field) pairs, or None to keep all fields
    :return: None
    """
    logger.info('mysql %s 导入到 sqlite %s 开始', table_name, file_name)
    sqlite_db_folder_path = get_folder_path('sqlite_db', create_if_not_found=False)
    db_file_path = os.path.join(sqlite_db_folder_path, file_name)
    conn = sqlite3.connect(db_file_path)
    sql_str = f"select ts_code from {table_name} group by ts_code"
    with with_db_session(engine_md) as session:
        table = session.execute(sql_str)
        code_list = [row[0] for row in table.fetchall()]
    # Hoisted out of the loop: the field selection/rename mapping is loop-invariant
    # (the original rebuilt it on every iteration).
    if field_pair_list is not None:
        field_list = [_[0] for _ in field_pair_list]
        field_list.append('ts_code')
        rename_dic = dict(field_pair_list)
    else:
        field_list, rename_dic = None, None
    code_count, data_count = len(code_list), 0
    try:
        for num, ts_code in enumerate(code_list, start=1):
            # "000001.SZ" -> sqlite table "SZ000001" (exchange prefix first)
            code_exchange = ts_code.split('.')
            sqlite_table_name = f"{code_exchange[1]}{code_exchange[0]}"
            sql_str = f"select * from {table_name} where ts_code=%s"
            df = pd.read_sql(sql_str, engine_md, params=[ts_code])
            # BUG fixed: the original referenced the undefined name `df_tot` here,
            # which raised NameError whenever field_pair_list was supplied.
            if field_list is not None:
                df = df[field_list].rename(columns=rename_dic)
            df_len = df.shape[0]
            data_count += df_len
            logger.debug('%4d/%d) mysql %s -> sqlite %s %s %d 条记录',
                         num, code_count, table_name, file_name, sqlite_table_name, df_len)
            df.to_sql(sqlite_table_name, conn, index=False, if_exists='replace')
    finally:
        # Fix: the connection was previously never closed (resource leak).
        conn.close()
    logger.info('mysql %s 导入到 sqlite %s 结束,导出数据 %d 条', table_name, file_name, data_count)
def get_cache_folder_path() -> str:
    """
    Return the output/report/_cache_ folder path, creating it if necessary.
    :return: absolute or relative path of the cache folder (as produced by get_folder_path)
    """
    import os
    folder_path = get_folder_path('output', create_if_not_found=True)
    _report_file_path = os.path.join(folder_path, 'report', "_cache_")
    # exist_ok avoids the check-then-create race of the original exists()/makedirs pair
    os.makedirs(_report_file_path, exist_ok=True)
    return _report_file_path
def get_report_folder_path(stg_run_id=None) -> str:
    """
    Return the report folder path, creating it if necessary.
    :param stg_run_id: when given, a per-run subfolder output/report/<stg_run_id> is used;
        otherwise the shared output/report folder.
    :return: path of the (possibly newly created) report folder
    """
    import os
    folder_path = get_folder_path('output', create_if_not_found=True)
    if stg_run_id is None:
        _report_file_path = os.path.join(folder_path, 'report')
    else:
        _report_file_path = os.path.join(folder_path, 'report', str(stg_run_id))
    # exist_ok avoids the check-then-create race of the original exists()/makedirs pair
    os.makedirs(_report_file_path, exist_ok=True)
    return _report_file_path
def get_sqlite_file_path(file_name):
    """
    Resolve the full path of a sqlite database file.

    Uses config.SQLITE_FOLDER_PATH when configured; otherwise falls back to the
    default 'sqlite_db' folder located by get_folder_path.
    :param file_name: sqlite file name (no directory component)
    :return: full path to the sqlite file
    """
    folder_path = config.SQLITE_FOLDER_PATH
    if folder_path is None:
        folder_path = get_folder_path('sqlite_db', create_if_not_found=False)
    return os.path.join(folder_path, file_name)
def save_model(self):
    """
    Persist the current TensorFlow session to a checkpoint file under 'my_net'.

    Also records the checkpoint path on self.model_file_path.
    :return: the path the checkpoint was written to
    :raises ValueError: when the 'my_net' folder cannot be located
    """
    saver = tf.train.Saver()
    checkpoint_dir = get_folder_path('my_net', create_if_not_found=False)
    if checkpoint_dir is None:
        raise ValueError('folder_path: "my_net" not exist')
    # Checkpoint name encodes the model configuration so variants don't collide.
    checkpoint_path = os.path.join(
        checkpoint_dir,
        f"save_net_{self.output_size}_{self.normalization_model}.ckpt")
    saved_to = saver.save(self.session, checkpoint_path)
    logger.info("Save to path: %s", saved_to)
    self.model_file_path = saved_to
    return saved_to
def tushare_to_sqlite_batch(file_name, table_name, field_pair_list, batch_size=500, **kwargs):
    """
    Export MySQL data into sqlite in batches of ts_codes.
    Medium speed; batch_size tunes the memory/speed trade-off.
    :param file_name: target sqlite file name (created inside the 'sqlite_db' folder)
    :param table_name: source MySQL table name
    :param field_pair_list: list of (source_field, renamed_field) pairs, or None to keep all fields
    :param batch_size: number of ts_codes fetched per MySQL query
    :param kwargs: ignored; accepted for call-site compatibility
    :return: None
    """
    logger.info('mysql %s 导入到 sqlite %s 开始', table_name, file_name)
    sqlite_db_folder_path = get_folder_path('sqlite_db', create_if_not_found=False)
    db_file_path = os.path.join(sqlite_db_folder_path, file_name)
    conn = sqlite3.connect(db_file_path)
    try:
        sql_str = f"select ts_code from {table_name} group by ts_code"
        with with_db_session(engine_md) as session:
            table = session.execute(sql_str)
            code_list = [row[0] for row in table.fetchall()]
        code_count, data_count, num = len(code_list), 0, 0
        for code_sub_list in split_chunk(code_list, batch_size):
            # Parameterized IN clause — one placeholder per code in this chunk
            in_clause = ", ".join([r'%s' for _ in code_sub_list])
            sql_str = f"select * from {table_name} where ts_code in ({in_clause})"
            df_tot = pd.read_sql(sql_str, engine_md, params=code_sub_list)
            # Filter and rename fields when a mapping is supplied
            if field_pair_list is not None:
                field_list = [_[0] for _ in field_pair_list]
                field_list.append('ts_code')
                df_tot = df_tot[field_list].rename(columns=dict(field_pair_list))
            dfg = df_tot.groupby('ts_code')
            # `num` continues across chunks so the progress counter stays global
            for num, (ts_code, df) in enumerate(dfg, start=num + 1):
                code_exchange = ts_code.split('.')
                sqlite_table_name = f"{code_exchange[1]}{code_exchange[0]}"
                df_len = df.shape[0]
                data_count += df_len
                logger.debug('%4d/%d) mysql %s -> sqlite %s %s %d 条记录',
                             num, code_count, table_name, file_name, sqlite_table_name, df_len)
                # ts_code is constant within the group and redundant in the per-code table
                df.drop('ts_code', axis=1, inplace=True)
                df.to_sql(sqlite_table_name, conn, index=False, if_exists='replace')
    finally:
        # Fix: the connection was previously never closed (resource leak).
        conn.close()
    logger.info('mysql %s 导入到 sqlite %s 结束,导出数据 %d 条', table_name, file_name, data_count)
def __init__(self, unit=1, train=True):
    """Initialize model hyper-parameters and the checkpoint file path."""
    super().__init__()
    self.unit = unit
    # Network/training hyper-parameters
    self.input_size = 13       # number of input features
    self.batch_size = 50
    self.n_step = 20           # sequence length per sample
    self.output_size = 2
    self.n_hidden_units = 10
    self.lr = 0.006            # learning rate
    self.normalization_model = True
    self._model = None
    # tf.Session()
    self._session = None
    self.train_validation_rate = 0.8   # train/validation split ratio
    self.is_load_model_if_exist = True
    # Checkpoint path under the 'my_net' folder; name encodes the normalization flag
    folder_path = get_folder_path('my_net', create_if_not_found=False)
    file_path = os.path.join(folder_path, f"save_net_{self.normalization_model}.ckpt")
    self.model_file_path = file_path
    self.training_iters = 600
""" @author : MG @Time : 19-5-21 下午2:21 @File : export.py @contact : [email protected] @desc : """ from functools import lru_cache import pandas as pd from tasks.backend import with_db_session_p, engine_md from tasks.wind.future_reorg.reorg_md_2_db import wind_future_continuous_md from ibats_utils.mess import get_folder_path, date_2_str import os import re module_root_path = get_folder_path(re.compile(r'^tasks$'), create_if_not_found=False) # 'tasks' root_parent_path = os.path.abspath( os.path.join(module_root_path, os.path.pardir)) @lru_cache() def get_export_path(file_name, create_folder_if_no_exist=True): folder_path = os.path.join(root_parent_path, 'export_files') if create_folder_if_no_exist and not os.path.exists(folder_path): os.makedirs(folder_path) return os.path.join(folder_path, file_name) def trade_date_list(file_path=None, future_or_stock='future'): if file_path is None:
def tushare_to_sqlite_batch(file_name, table_name, field_pair_list, batch_size=500, sort_by='trade_date',
                            clean_old_file_first=True, **kwargs):
    """
    Export MySQL data into sqlite in batches of ts_codes.
    Medium speed; batch_size tunes the memory/speed trade-off.
    :param file_name: target sqlite file name (created inside the 'sqlite_db' folder)
    :param table_name: source MySQL table name
    :param field_pair_list: list of (source_field, renamed_field) pairs, or None to keep all fields
    :param batch_size: number of ts_codes fetched per MySQL query
    :param sort_by: column to sort each per-code table by (renamed name is looked up
        via field_pair_list when a mapping is supplied); None disables sorting
    :param clean_old_file_first: delete any existing sqlite file first (speeds up the import)
    :param kwargs: ignored; accepted for call-site compatibility
    :return: None
    """
    logger.info('mysql %s 导入到 sqlite %s 开始', table_name, file_name)
    sqlite_db_folder_path = get_folder_path('sqlite_db', create_if_not_found=False)
    db_file_path = os.path.join(sqlite_db_folder_path, file_name)
    # Removing the old file avoids sqlite having to update existing tables in place
    if clean_old_file_first and os.path.exists(db_file_path) and os.path.isfile(db_file_path):
        os.remove(db_file_path)
    conn = sqlite3.connect(db_file_path)
    try:
        # Build the field filter/rename mapping once; translate sort_by to its renamed form
        if field_pair_list is not None:
            field_list = [_[0] for _ in field_pair_list]
            field_list.append('ts_code')
            field_pair_dic = dict(field_pair_list)
            sort_by = field_pair_dic[sort_by] if sort_by is not None else None
        else:
            field_list = None
            field_pair_dic = None
        if table_name == 'tushare_stock_index_daily_md':
            # tushare_stock_index_daily_md is special-cased: only a fixed set of indices is
            # exported, each to a hand-named sqlite table.
            ts_code_sqlite_table_name_dic = {
                # "": "CBIndex",
                # "h30024.CSI": "CYBZ",  # 中证800保险
                "399300.SZ": "HS300",  # 沪深300
                "000016.SH": "HS50",  # 上证50
                "399905.SZ": "HS500",  # 中证500
                "399678.SZ": "SCXG",  # 深次新股
                "399101.SZ": "ZXBZ",  # 中小板综
            }
            code_list = [_ for _ in ts_code_sqlite_table_name_dic.keys()]
            in_clause = ", ".join([r'%s' for _ in code_list])
            sql_str = f"select * from {table_name} where ts_code in ({in_clause})"
            df_tot = pd.read_sql(sql_str, engine_md, params=code_list)
            # Filter and rename fields when a mapping is supplied
            if field_pair_dic is not None:
                df_tot = df_tot[field_list].rename(columns=field_pair_dic)
            dfg = df_tot.groupby('ts_code')
            code_count, data_count = len(code_list), 0
            for num, (ts_code, df) in enumerate(dfg, start=1):
                sqlite_table_name = ts_code_sqlite_table_name_dic[ts_code]
                df_len = df.shape[0]
                data_count += df_len
                logger.debug('%2d/%d) mysql %s -> sqlite %s %s %d 条记录',
                             num, code_count, table_name, file_name, sqlite_table_name, df_len)
                df = df.drop('ts_code', axis=1)
                # Sort rows before writing, when requested
                if sort_by is not None:
                    df = df.sort_values(sort_by)
                df.to_sql(sqlite_table_name, conn, index=False, if_exists='replace')
        else:
            # Generic path: enumerate all ts_codes, fetch them in chunks
            sql_str = f"select ts_code from {table_name} group by ts_code"
            with with_db_session(engine_md) as session:
                table = session.execute(sql_str)
                code_list = [row[0] for row in table.fetchall()]
            code_count, data_count, num = len(code_list), 0, 0
            for code_sub_list in split_chunk(code_list, batch_size):
                in_clause = ", ".join([r'%s' for _ in code_sub_list])
                sql_str = f"select * from {table_name} where ts_code in ({in_clause})"
                df_tot = pd.read_sql(sql_str, engine_md, params=code_sub_list)
                # Filter and rename fields when a mapping is supplied
                if field_pair_dic is not None:
                    df_tot = df_tot[field_list].rename(columns=field_pair_dic)
                dfg = df_tot.groupby('ts_code')
                # `num` continues across chunks so the progress counter stays global
                for num, (ts_code, df) in enumerate(dfg, start=num + 1):
                    # "000001.SZ" -> sqlite table "SZ000001" (exchange prefix first)
                    code_exchange = ts_code.split('.')
                    sqlite_table_name = f"{code_exchange[1]}{code_exchange[0]}"
                    df_len = df.shape[0]
                    data_count += df_len
                    logger.debug('%4d/%d) mysql %s -> sqlite %s %s %d 条记录',
                                 num, code_count, table_name, file_name, sqlite_table_name, df_len)
                    df = df.drop('ts_code', axis=1)
                    # Sort rows before writing, when requested
                    if sort_by is not None:
                        df = df.sort_values(sort_by)
                    df.to_sql(sqlite_table_name, conn, index=False, if_exists='replace')
    finally:
        # Fix: the connection was previously never closed (resource leak).
        conn.close()
    logger.info('mysql %s 导入到 sqlite %s 结束,导出数据 %d 条', table_name, file_name, data_count)
def get_sqlite_conn(file_name):
    """
    Open a sqlite3 connection to *file_name* inside the default 'sqlite_db' folder.
    :param file_name: sqlite file name (no directory component)
    :return: an open sqlite3.Connection (caller is responsible for closing it)
    """
    folder_path = get_folder_path('sqlite_db', create_if_not_found=False)
    return sqlite3.connect(os.path.join(folder_path, file_name))
def _test_use(is_plot):
    """Run a backtest of AIStg on the bundled RB.csv data; optionally plot the results.

    :param is_plot: when truthy, show order/cash-margin/return charts after the run
    :return: the stg_run_id of the finished backtest
    """
    from ibats_common.backend.mess import get_folder_path
    import os
    # Parameter setup
    run_mode = RunMode.Backtest
    strategy_params = {'unit': 100}
    md_agent_params_list = [{
        'md_period': PeriodType.Min1,
        'instrument_id_list': ['RB'],
        'datetime_key': 'trade_date',
        'init_md_date_from': '1995-1-1',  # history preloaded at init for strategy analysis
        'init_md_date_to': '2014-1-1',
        # 'C:\GitHub\IBATS_Common\ibats_common\example\ru_price2.csv'
        'file_path': os.path.abspath(
            os.path.join(get_folder_path('example', create_if_not_found=False), 'data', 'RB.csv')),
        'symbol_key': 'instrument_type',
    }]
    if run_mode == RunMode.Realtime:
        trade_agent_params = {}
        strategy_handler_param = {}
    else:
        trade_agent_params = {
            'trade_mode': BacktestTradeMode.Order_2_Deal,
            'init_cash': 1000000,
            "calc_mode": CalcMode.Margin,
        }
        strategy_handler_param = {
            'date_from': '2014-1-1',  # backtest window over the historical data
            'date_to': '2018-10-18',
        }
    # Initialize the strategy handler
    stghandler = strategy_handler_factory(
        stg_class=AIStg,
        strategy_params=strategy_params,
        md_agent_params_list=md_agent_params_list,
        exchange_name=ExchangeName.LocalFile,
        run_mode=RunMode.Backtest,
        trade_agent_params=trade_agent_params,
        strategy_handler_param=strategy_handler_param,
    )
    stghandler.start()
    # Give the backtest time to run, then signal shutdown and wait for the thread
    time.sleep(10)
    stghandler.keep_running = False
    stghandler.join()
    stg_run_id = stghandler.stg_run_id
    logging.info("执行结束 stg_run_id = %d", stg_run_id)
    if is_plot:
        from ibats_common.analysis.plot_db import show_order, show_cash_and_margin, show_rr_with_md
        show_order(stg_run_id)
        show_cash_and_margin(stg_run_id)
        show_rr_with_md(stg_run_id)
    return stg_run_id
#! /usr/bin/env python # -*- coding:utf-8 -*- """ @author : MG @Time : 2018/6/26 10:27 @File : __init__.py.py @contact : [email protected] @desc : """ from ibats_utils.mess import get_folder_path import os import re module_root_path = get_folder_path(re.compile(r'^ibats[\w]+'), create_if_not_found=False) # 'ibats_common' root_parent_path = os.path.abspath( os.path.join(module_root_path, os.path.pardir))