def tushare_to_sqlite_tot_select(file_name, table_name, field_pair_list):
    """
    Export a MySQL table into a sqlite file, one sqlite table per ts_code.

    Reads the whole source table into memory at once, then writes it out:
    fast, but memory hungry.

    :param file_name: target sqlite file name (inside the 'sqlite_db' folder)
    :param table_name: source MySQL table name
    :param field_pair_list: optional list of (source_field, target_field)
        pairs used to filter and rename columns; 'ts_code' is always kept
    :return: None
    """
    logger.info('mysql %s 导入到 sqlite %s 开始', table_name, file_name)
    sqlite_db_folder_path = get_folder_path('sqlite_db',
                                            create_if_not_found=False)
    db_file_path = os.path.join(sqlite_db_folder_path, file_name)
    conn = sqlite3.connect(db_file_path)
    try:
        sql_str = f"select * from {table_name}"
        df_tot = pd.read_sql(sql_str, engine_md)
        # Filter and rename the requested fields; ts_code is always kept
        # because it drives the per-code grouping below.
        if field_pair_list is not None:
            field_list = [_[0] for _ in field_pair_list]
            field_list.append('ts_code')
            df_tot = df_tot[field_list].rename(columns=dict(field_pair_list))

        dfg = df_tot.groupby('ts_code')
        num, code_count, data_count = 0, len(dfg), 0
        for num, (ts_code, df) in enumerate(dfg, start=1):
            # '000001.SZ' -> sqlite table name 'SZ000001'
            code_exchange = ts_code.split('.')
            sqlite_table_name = f"{code_exchange[1]}{code_exchange[0]}"
            df_len = df.shape[0]
            data_count += df_len
            logger.debug('%4d/%d) mysql %s -> sqlite %s %s %d 条记录', num,
                         code_count, table_name, file_name, sqlite_table_name,
                         df_len)
            df.to_sql(sqlite_table_name, conn, index=False,
                      if_exists='replace')
    finally:
        # Fix: the connection used to leak; always release it.
        conn.close()

    logger.info('mysql %s 导入到 sqlite %s 结束,导出数据 %d 条', table_name, file_name,
                data_count)
def tushare_to_sqlite_pre_ts_code(file_name, table_name, field_pair_list):
    """
    Export a MySQL table into a sqlite file, one sqlite table per ts_code.

    Queries one ts_code at a time: slow, but needs little memory. Only worth
    using when memory is tight or the amount of data is small.

    :param file_name: target sqlite file name (inside the 'sqlite_db' folder)
    :param table_name: source MySQL table name
    :param field_pair_list: optional list of (source_field, target_field)
        pairs used to filter and rename columns; 'ts_code' is always kept
    :return: None
    """
    logger.info('mysql %s 导入到 sqlite %s 开始', table_name, file_name)
    sqlite_db_folder_path = get_folder_path('sqlite_db', create_if_not_found=False)
    db_file_path = os.path.join(sqlite_db_folder_path, file_name)
    conn = sqlite3.connect(db_file_path)
    try:
        sql_str = f"select ts_code from {table_name} group by ts_code"
        with with_db_session(engine_md) as session:
            table = session.execute(sql_str)
            code_list = list([row[0] for row in table.fetchall()])

        code_count, data_count = len(code_list), 0
        for num, ts_code in enumerate(code_list, start=1):
            # '000001.SZ' -> sqlite table name 'SZ000001'
            code_exchange = ts_code.split('.')
            sqlite_table_name = f"{code_exchange[1]}{code_exchange[0]}"
            sql_str = f"select * from {table_name} where ts_code=%s"  # where code = '000001.XSHE'
            df = pd.read_sql(sql_str, engine_md, params=[ts_code])
            if field_pair_list is not None:
                field_list = [_[0] for _ in field_pair_list]
                field_list.append('ts_code')
                # Fix: the original filtered/renamed an undefined `df_tot`
                # here (NameError at runtime); operate on the per-code frame.
                df = df[field_list].rename(columns=dict(field_pair_list))

            df_len = df.shape[0]
            data_count += df_len
            logger.debug('%4d/%d) mysql %s -> sqlite %s %s %d 条记录',
                         num, code_count, table_name, file_name, sqlite_table_name, df_len)
            df.to_sql(sqlite_table_name, conn, index=False, if_exists='replace')
    finally:
        # Fix: the connection used to leak; always release it.
        conn.close()

    logger.info('mysql %s 导入到 sqlite %s 结束,导出数据 %d 条', table_name, file_name, data_count)
Exemple #3
0
def get_cache_folder_path() -> str:
    """
    Return the path of the output/report/_cache_ folder, creating it if missing.

    :return: path of the cache folder
    """
    import os
    folder_path = get_folder_path('output', create_if_not_found=True)
    _report_file_path = os.path.join(folder_path, 'report', "_cache_")
    # exist_ok avoids the check-then-create race of the original
    # os.path.exists + os.makedirs pair.
    os.makedirs(_report_file_path, exist_ok=True)

    return _report_file_path
Exemple #4
0
def get_report_folder_path(stg_run_id=None) -> str:
    """
    Return the path of the report folder, creating it if missing.

    :param stg_run_id: optional strategy run id; when given, a per-run
        sub-folder output/report/<stg_run_id> is used
    :return: path of the report folder
    """
    import os
    folder_path = get_folder_path('output', create_if_not_found=True)
    if stg_run_id is None:
        _report_file_path = os.path.join(folder_path, 'report')
    else:
        _report_file_path = os.path.join(folder_path, 'report',
                                         str(stg_run_id))
    # exist_ok avoids the check-then-create race of the original
    # os.path.exists + os.makedirs pair.
    os.makedirs(_report_file_path, exist_ok=True)

    return _report_file_path
Exemple #5
0
def get_sqlite_file_path(file_name):
    """
    Build the full path of a sqlite database file.

    :param file_name: bare name of the sqlite database file
    :return: path of the file inside the configured sqlite folder
    """
    # Prefer an explicitly configured folder; otherwise fall back to the
    # conventional 'sqlite_db' directory.
    folder_path = (config.SQLITE_FOLDER_PATH
                   if config.SQLITE_FOLDER_PATH is not None
                   else get_folder_path('sqlite_db', create_if_not_found=False))
    return os.path.join(folder_path, file_name)
Exemple #6
0
 def save_model(self):
     """
     Export the current TensorFlow model to a checkpoint file.

     :return: the path the checkpoint was saved to
     """
     saver = tf.train.Saver()
     # The 'my_net' folder must already exist — it is looked up with
     # create_if_not_found=False and may therefore come back as None.
     folder_path = get_folder_path('my_net', create_if_not_found=False)
     if folder_path is None:
         raise ValueError('folder_path: "my_net" not exist')
     file_path = os.path.join(
         folder_path,
         f"save_net_{self.output_size}_{self.normalization_model}.ckpt")
     save_path = saver.save(self.session, file_path)
     logger.info("Save to path: %s", save_path)
     # Remember where the checkpoint lives so the model can be reloaded later.
     self.model_file_path = save_path
     return save_path
def tushare_to_sqlite_batch(file_name, table_name, field_pair_list, batch_size=500, **kwargs):
    """
    Export a MySQL table into a sqlite file, one sqlite table per ts_code.

    Processes ts_codes in batches; memory usage can be tuned via batch_size.

    :param file_name: target sqlite file name (inside the 'sqlite_db' folder)
    :param table_name: source MySQL table name
    :param field_pair_list: optional list of (source_field, target_field)
        pairs used to filter and rename columns; 'ts_code' is always kept
    :param batch_size: number of ts_codes loaded per database round trip
    :param kwargs: ignored; kept for call-site compatibility
    :return: None
    """
    logger.info('mysql %s 导入到 sqlite %s 开始', table_name, file_name)
    sqlite_db_folder_path = get_folder_path('sqlite_db', create_if_not_found=False)
    db_file_path = os.path.join(sqlite_db_folder_path, file_name)
    conn = sqlite3.connect(db_file_path)
    try:
        sql_str = f"select ts_code from {table_name} group by ts_code"
        with with_db_session(engine_md) as session:
            table = session.execute(sql_str)
            code_list = list([row[0] for row in table.fetchall()])

        code_count, data_count, num = len(code_list), 0, 0
        for code_sub_list in split_chunk(code_list, batch_size):
            in_clause = ", ".join([r'%s' for _ in code_sub_list])
            sql_str = f"select * from {table_name} where ts_code in ({in_clause})"
            df_tot = pd.read_sql(sql_str, engine_md, params=code_sub_list)
            # Filter and rename the requested fields; ts_code is always kept
            if field_pair_list is not None:
                field_list = [_[0] for _ in field_pair_list]
                field_list.append('ts_code')
                df_tot = df_tot[field_list].rename(columns=dict(field_pair_list))

            dfg = df_tot.groupby('ts_code')
            # start=num + 1 keeps the counter running across batches
            for num, (ts_code, df) in enumerate(dfg, start=num + 1):
                # '000001.SZ' -> sqlite table name 'SZ000001'
                code_exchange = ts_code.split('.')
                sqlite_table_name = f"{code_exchange[1]}{code_exchange[0]}"
                df_len = df.shape[0]
                data_count += df_len
                logger.debug('%4d/%d) mysql %s -> sqlite %s %s %d 条记录',
                             num, code_count, table_name, file_name, sqlite_table_name, df_len)
                # Fix: drop via reassignment instead of inplace on a groupby
                # slice, which raises pandas' SettingWithCopyWarning.
                df = df.drop('ts_code', axis=1)
                df.to_sql(sqlite_table_name, conn, index=False, if_exists='replace')
    finally:
        # Fix: the connection used to leak; always release it.
        conn.close()

    logger.info('mysql %s 导入到 sqlite %s 结束,导出数据 %d 条', table_name, file_name, data_count)
Exemple #8
0
 def __init__(self, unit=1, train=True):
     """
     Configure the network's hyper-parameters and the checkpoint file path.

     :param unit: trading unit used by the strategy
     :param train: kept for call-site compatibility (not used here)
     """
     super().__init__()
     self.unit = unit
     # network dimensions
     self.input_size = 13
     self.batch_size = 50
     self.n_step = 20
     self.output_size = 2
     self.n_hidden_units = 10
     # learning rate
     self.lr = 0.006
     self.normalization_model = True
     self._model = None
     # tf.Session()
     self._session = None
     self.train_validation_rate = 0.8
     self.is_load_model_if_exist = True
     folder_path = get_folder_path('my_net', create_if_not_found=False)
     if folder_path is None:
         # Fix: fail fast with a clear message (consistent with save_model)
         # instead of an opaque TypeError from os.path.join(None, ...).
         raise ValueError('folder_path: "my_net" not exist')
     file_path = os.path.join(folder_path, f"save_net_{self.normalization_model}.ckpt")
     self.model_file_path = file_path
     self.training_iters = 600
"""
@author  : MG
@Time    : 19-5-21 下午2:21
@File    : export.py
@contact : [email protected]
@desc    : 
"""
from functools import lru_cache
import pandas as pd
from tasks.backend import with_db_session_p, engine_md
from tasks.wind.future_reorg.reorg_md_2_db import wind_future_continuous_md
from ibats_utils.mess import get_folder_path, date_2_str
import os
import re

# Locate the 'tasks' package root folder by name, then derive its parent
# directory for use as a project-relative base path.
module_root_path = get_folder_path(re.compile(r'^tasks$'),
                                   create_if_not_found=False)  # 'tasks'
root_parent_path = os.path.abspath(
    os.path.join(module_root_path, os.path.pardir))


@lru_cache()
def get_export_path(file_name, create_folder_if_no_exist=True):
    """
    Return the path of *file_name* inside the 'export_files' folder.

    :param file_name: bare name of the export file
    :param create_folder_if_no_exist: create the folder when it is missing
    :return: full path of the export file
    """
    folder_path = os.path.join(root_parent_path, 'export_files')
    if create_folder_if_no_exist:
        # exist_ok avoids the check-then-create race of the original
        # os.path.exists + os.makedirs pair.
        os.makedirs(folder_path, exist_ok=True)

    return os.path.join(folder_path, file_name)


def trade_date_list(file_path=None, future_or_stock='future'):
    if file_path is None:
def tushare_to_sqlite_batch(file_name,
                            table_name,
                            field_pair_list,
                            batch_size=500,
                            sort_by='trade_date',
                            clean_old_file_first=True,
                            **kwargs):
    """
    Export a MySQL table into a sqlite file, one sqlite table per ts_code.

    Processes ts_codes in batches; memory usage can be tuned via batch_size.

    :param file_name: target sqlite file name (inside the 'sqlite_db' folder)
    :param table_name: source MySQL table name
    :param field_pair_list: optional list of (source_field, target_field)
        pairs used to filter and rename columns; 'ts_code' is always kept
    :param batch_size: number of ts_codes loaded per database round trip
    :param sort_by: column each per-code frame is sorted by before writing
    :param clean_old_file_first: delete any pre-existing sqlite file first
    :param kwargs: ignored; kept for call-site compatibility
    :return: None
    """
    logger.info('mysql %s 导入到 sqlite %s 开始', table_name, file_name)
    sqlite_db_folder_path = get_folder_path('sqlite_db',
                                            create_if_not_found=False)
    db_file_path = os.path.join(sqlite_db_folder_path, file_name)
    # Removing the stale file first speeds up the import (no old tables
    # need to be replaced row by row).
    if clean_old_file_first and os.path.exists(
            db_file_path) and os.path.isfile(db_file_path):
        os.remove(db_file_path)

    conn = sqlite3.connect(db_file_path)
    try:
        # Prepare the field filter / rename mapping
        if field_pair_list is not None:
            field_list = [_[0] for _ in field_pair_list]
            field_list.append('ts_code')
            field_pair_dic = dict(field_pair_list)
            # Fix: tolerate a sort_by column that field_pair_list does not
            # rename, instead of raising KeyError.
            sort_by = field_pair_dic.get(sort_by, sort_by) if sort_by is not None else None
        else:
            field_list = None
            field_pair_dic = None

        if table_name == 'tushare_stock_index_daily_md':
            # tushare_stock_index_daily_md is special-cased: only a fixed set
            # of index codes is exported, each to an explicitly named table.
            ts_code_sqlite_table_name_dic = {
                # "": "CBIndex",  #
                "h30024.CSI": "CYBZ",  # CSI 800 insurance index
                "399300.SZ": "HS300",  # CSI 300
                "000016.SH": "HS50",  # SSE 50
                "399905.SZ": "HS500",  # CSI 500
                "399678.SZ": "SCXG",  # SZSE sub-new shares index
                "399101.SZ": "ZXBZ",  # SME composite
            }
            code_list = [_ for _ in ts_code_sqlite_table_name_dic.keys()]
            in_clause = ", ".join([r'%s' for _ in code_list])
            sql_str = f"select * from {table_name} where ts_code in ({in_clause})"
            df_tot = pd.read_sql(sql_str, engine_md, params=code_list)
            # Filter and rename the requested fields
            if field_pair_dic is not None:
                df_tot = df_tot[field_list].rename(columns=field_pair_dic)

            dfg = df_tot.groupby('ts_code')
            code_count, data_count = len(code_list), 0
            for num, (ts_code, df) in enumerate(dfg, start=1):
                sqlite_table_name = ts_code_sqlite_table_name_dic[ts_code]
                df_len = df.shape[0]
                data_count += df_len
                logger.debug('%2d/%d) mysql %s -> sqlite %s %s %d 条记录', num,
                             code_count, table_name, file_name, sqlite_table_name,
                             df_len)
                df = df.drop('ts_code', axis=1)
                # Sort before writing
                if sort_by is not None:
                    df = df.sort_values(sort_by)

                df.to_sql(sqlite_table_name,
                          conn,
                          index=False,
                          if_exists='replace')
        else:
            # Generic path: export every ts_code, loaded batch_size at a time
            sql_str = f"select ts_code from {table_name} group by ts_code"
            with with_db_session(engine_md) as session:
                table = session.execute(sql_str)
                code_list = list([row[0] for row in table.fetchall()])

            code_count, data_count, num = len(code_list), 0, 0
            for code_sub_list in split_chunk(code_list, batch_size):
                in_clause = ", ".join([r'%s' for _ in code_sub_list])
                sql_str = f"select * from {table_name} where ts_code in ({in_clause})"
                df_tot = pd.read_sql(sql_str, engine_md, params=code_sub_list)
                # Filter and rename the requested fields
                if field_pair_dic is not None:
                    df_tot = df_tot[field_list].rename(columns=field_pair_dic)

                dfg = df_tot.groupby('ts_code')
                # start=num + 1 keeps the counter running across batches
                for num, (ts_code, df) in enumerate(dfg, start=num + 1):
                    # '000001.SZ' -> sqlite table name 'SZ000001'
                    code_exchange = ts_code.split('.')
                    sqlite_table_name = f"{code_exchange[1]}{code_exchange[0]}"
                    df_len = df.shape[0]
                    data_count += df_len
                    logger.debug('%4d/%d) mysql %s -> sqlite %s %s %d 条记录', num,
                                 code_count, table_name, file_name,
                                 sqlite_table_name, df_len)
                    df = df.drop('ts_code', axis=1)
                    # Sort before writing
                    if sort_by is not None:
                        df = df.sort_values(sort_by)

                    df.to_sql(sqlite_table_name,
                              conn,
                              index=False,
                              if_exists='replace')
    finally:
        # Fix: the connection used to leak; always release it.
        conn.close()

    logger.info('mysql %s 导入到 sqlite %s 结束,导出数据 %d 条', table_name, file_name,
                data_count)
def get_sqlite_conn(file_name):
    """
    Open and return a sqlite3 connection to *file_name* in the sqlite_db folder.

    :param file_name: bare name of the sqlite database file
    :return: an open sqlite3 connection
    """
    folder_path = get_folder_path('sqlite_db',
                                  create_if_not_found=False)
    return sqlite3.connect(os.path.join(folder_path, file_name))
Exemple #12
0
def _test_use(is_plot):
    """
    Smoke-test driver: run a backtest of AIStg on RB 1-minute data from a
    local CSV file and optionally plot the results.

    :param is_plot: when True, show order / cash-and-margin / return charts
        after the run finishes
    :return: stg_run_id of the finished backtest
    """
    from ibats_common.backend.mess import get_folder_path
    import os
    # Parameter setup
    run_mode = RunMode.Backtest
    strategy_params = {'unit': 100}
    md_agent_params_list = [{
        'md_period':
        PeriodType.Min1,
        'instrument_id_list': ['RB'],
        'datetime_key':
        'trade_date',
        'init_md_date_from':
        '1995-1-1',  # preload historical market data for strategy analysis
        'init_md_date_to':
        '2014-1-1',
        # 'C:\GitHub\IBATS_Common\ibats_common\example\ru_price2.csv'
        'file_path':
        os.path.abspath(
            os.path.join(get_folder_path('example', create_if_not_found=False),
                         'data', 'RB.csv')),
        'symbol_key':
        'instrument_type',
    }]
    if run_mode == RunMode.Realtime:
        trade_agent_params = {}
        strategy_handler_param = {}
    else:
        trade_agent_params = {
            'trade_mode': BacktestTradeMode.Order_2_Deal,
            'init_cash': 1000000,
            "calc_mode": CalcMode.Margin,
        }
        strategy_handler_param = {
            'date_from': '2014-1-1',  # backtest over this historical period
            'date_to': '2018-10-18',
        }
    # Build the strategy handler
    stghandler = strategy_handler_factory(
        stg_class=AIStg,
        strategy_params=strategy_params,
        md_agent_params_list=md_agent_params_list,
        exchange_name=ExchangeName.LocalFile,
        run_mode=RunMode.Backtest,
        trade_agent_params=trade_agent_params,
        strategy_handler_param=strategy_handler_param,
    )
    stghandler.start()
    time.sleep(10)
    stghandler.keep_running = False
    stghandler.join()
    stg_run_id = stghandler.stg_run_id
    logging.info("执行结束 stg_run_id = %d", stg_run_id)

    if is_plot:
        from ibats_common.analysis.plot_db import show_order, show_cash_and_margin, show_rr_with_md
        show_order(stg_run_id)
        show_cash_and_margin(stg_run_id)
        show_rr_with_md(stg_run_id)

    return stg_run_id
Exemple #13
0
#! /usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author  : MG
@Time    : 2018/6/26 10:27
@File    : __init__.py.py
@contact : [email protected]
@desc    : 
"""
from ibats_utils.mess import get_folder_path
import os
import re

# Locate this package's root folder (name starting with 'ibats') and derive
# its parent directory for use as a project-relative base path.
module_root_path = get_folder_path(re.compile(r'^ibats[\w]+'),
                                   create_if_not_found=False)  # 'ibats_common'
root_parent_path = os.path.abspath(
    os.path.join(module_root_path, os.path.pardir))