Ejemplo n.º 1
0
    def load_model_if_exist(self,
                            trade_date,
                            enable_load_model_if_exist=False):
        """
        Restore the newest saved model whose date lies within
        ``self.retrain_period`` days before ``trade_date`` and not after it.

        Loading is attempted only when ``self.enable_load_model_if_exist`` or the
        ``enable_load_model_if_exist`` argument is truthy.

        Example layout of the save directory:
        tf_saves_2019-06-05_16_21_39
          *   model_tfls
          *       *   2012-12-31
          *       *       *   checkpoint
          *       *       *   model_-54_51.tfl.data-00000-of-00001
          *       *       *   model_-54_51.tfl.index
          *       *       *   model_-54_51.tfl.meta
          *       *   2013-02-28
          *       *       *   checkpoint
          *       *       *   model_-54_51.tfl.data-00000-of-00001
          *       *       *   model_-54_51.tfl.index
          *       *       *   model_-54_51.tfl.meta
          *   tensorboard_logs
          *       *   2012-12-31_496[1]_20190605_184316
          *       *       *   events.out.tfevents.1559731396.mg-ubuntu64
          *       *   2013-02-28_496[1]_20190605_184716
          *       *       *   events.out.tfevents.1559731396.mg-ubuntu64

        :param trade_date: reference trade date; converted with ``str_2_date``
        :param enable_load_model_if_exist: enable loading for this call even when
            the instance-level switch is off
        :return: True if a model was loaded, in which case
            ``self.trade_date_last_train`` and ``self.predict_test_random_state``
            are updated; otherwise False
        """
        if not (self.enable_load_model_if_exist or enable_load_model_if_exist):
            return False

        target_date = str_2_date(trade_date)
        earliest_date = target_date - timedelta(days=self.retrain_period)
        self.logger.debug('尝试加载现有模型,[%s - %s] %d 天', earliest_date,
                          trade_date, self.retrain_period)
        # Entries appear to be (date, folder_path, predict_test_random_state)
        # triples, judging by the unpacking below. Keep only those inside the
        # retrain window, oldest first.
        candidates = sorted(
            (entry for entry in self.get_date_file_path_pair_list()
             if entry[0] >= earliest_date),
            key=lambda entry: entry[0])
        if not candidates:
            return False

        # get_last is assumed to return the last (i.e. newest) matching entry.
        chosen = get_last(candidates, lambda entry: entry[0] <= target_date)
        if chosen is None:
            return False
        snapshot_date, folder_path, random_state = chosen
        if folder_path is None:
            return False

        # The model must be rebuilt before the saved state can be loaded into it.
        model = self.get_model(rebuild_model=True)
        model.load(folder_path)
        self.trade_date_last_train = snapshot_date
        self.predict_test_random_state = random_state
        self.logger.info(
            "加载模型成功。trade_date_last_train: %s load from path: %s",
            snapshot_date, folder_path)
        return True
Ejemplo n.º 2
0
def _invoke_tushare_daily_all_pages(ts_code, date_from, date_to):
    """Fetch daily rows of ``ts_code`` in [date_from, date_to], paging backwards until date_from is covered."""
    data_df = invoke_daily(
        ts_code=ts_code,
        start_date=datetime_2_str(date_from, STR_FORMAT_DATE_TS),
        end_date=datetime_2_str(date_to, STR_FORMAT_DATE_TS))
    if data_df is None or len(data_df) == 0:
        return data_df
    # The last row is treated as the earliest date received. One call may not
    # cover the whole range (presumably a per-request row limit), so keep asking
    # for the days before that earliest date until date_from is reached.
    while True:
        last_date_in_df_last = try_2_date(data_df['trade_date'].iloc[-1])
        if last_date_in_df_last <= date_from:
            break
        df2 = invoke_daily(
            ts_code=ts_code,
            start_date=datetime_2_str(date_from, STR_FORMAT_DATE_TS),
            end_date=datetime_2_str(last_date_in_df_last - timedelta(days=1),
                                    STR_FORMAT_DATE_TS))
        if df2 is None or len(df2) == 0:
            break
        last_date_in_df_cur = try_2_date(df2['trade_date'].iloc[-1])
        if last_date_in_df_cur >= last_date_in_df_last:
            # No backward progress: stop instead of re-issuing the same request forever.
            break
        data_df = pd.concat([data_df, df2])
    return data_df


def import_tushare_stock_daily(chain_param=None, ts_code_set=None):
    """
    Import tushare daily quotes of every stock into ``tushare_stock_daily_md``.

    The end date is yesterday while the database clock is before 16:00, otherwise
    today (the ``hour(now())<16`` clause in the SQL below). If the target table
    already exists, each stock resumes from the day after its latest stored
    ``trade_date``; otherwise it starts from ``list_date`` in ``tushare_stock_info``.
    Rows are buffered and written in batches of at least 500 rows; whatever is still
    buffered is written in ``finally``, even when an exception interrupts the loop.

    :param chain_param: not used by this function
    :param ts_code_set: optional collection of ts_code values to restrict the
        import to; ``None`` imports every stock
    :return: None
    """
    table_name = 'tushare_stock_daily_md'
    logger.info("更新 %s 开始", table_name)

    has_table = engine_md.has_table(table_name)
    # Choose the date-range query depending on whether the target table exists.
    if has_table:
        sql_str = """
            SELECT ts_code, date_frm, if(delist_date<end_date, delist_date, end_date) date_to
            FROM
            (
            SELECT info.ts_code, ifnull(trade_date, list_date) date_frm, delist_date,
            if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
            FROM 
                tushare_stock_info info 
            LEFT OUTER JOIN
                (SELECT ts_code, adddate(max(trade_date),1) trade_date FROM {table_name} GROUP BY ts_code) daily
            ON info.ts_code = daily.ts_code
            ) tt
            WHERE date_frm <= if(delist_date<end_date, delist_date, end_date) 
            ORDER BY ts_code""".format(table_name=table_name)
    else:
        sql_str = """
            SELECT ts_code, date_frm, if(delist_date<end_date, delist_date, end_date) date_to
            FROM
              (
                SELECT info.ts_code, list_date date_frm, delist_date,
                if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM tushare_stock_info info 
              ) tt
            WHERE date_frm <= if(delist_date<end_date, delist_date, end_date) 
            ORDER BY ts_code"""
        logger.warning('%s 不存在,仅使用 tushare_stock_info 表进行计算日期范围', table_name)

    sql_trade_date_str = """
           SELECT cal_date FROM tushare_trade_date trddate WHERE (trddate.is_open=1 
        AND cal_date <= if(hour(now())<16, subdate(curdate(),1), curdate()) 
        AND exchange='SSE') ORDER BY cal_date"""

    with with_db_session(engine_md) as session:
        table = session.execute(sql_trade_date_str)
        trade_date_list = sorted(row[0] for row in table.fetchall())
        # ts_code -> (date_from, date_to) still to be downloaded
        table = session.execute(sql_str)
        code_date_range_dic = {
            ts_code: (date_from, date_to)
            for ts_code, date_from, date_to in table.fetchall()
            if ts_code_set is None or ts_code in ts_code_set
        }

    data_df_list, data_count, all_data_count = [], 0, 0
    data_len = len(code_date_range_dic)
    logger.info('%d stocks will been import into tushare_stock_daily_md',
                data_len)

    try:
        for num, (ts_code, (date_from_tmp, date_to_tmp)) in enumerate(
                code_date_range_dic.items(), start=1):
            # Clip the range to actual trading days so no request is wasted.
            date_from = get_first(trade_date_list,
                                  lambda x: x >= date_from_tmp)
            date_to = get_last(trade_date_list, lambda x: x <= date_to_tmp)
            if date_from is None or date_to is None or date_from > date_to:
                logger.debug('%d/%d) %s [%s - %s] 跳过', num, data_len, ts_code,
                             date_from, date_to)
                continue
            logger.debug('%d/%d) %s [%s - %s]', num, data_len, ts_code,
                         date_from, date_to)
            data_df = _invoke_tushare_daily_all_pages(ts_code, date_from, date_to)

            # Buffer the rows.
            if data_df is not None and data_df.shape[0] > 0:
                logger.info('%d/%d) %d data of %s between %s and %s',
                            num, data_len, data_df.shape[0], ts_code,
                            date_from, date_to)
                data_count += data_df.shape[0]
                data_df_list.append(data_df)

            # Write once the buffer reaches the threshold (single write per batch).
            if data_count >= 500:
                data_df_all = pd.concat(data_df_list)
                all_data_count += bunch_insert(
                    data_df_all,
                    table_name=table_name,
                    dtype=DTYPE_TUSHARE_STOCK_DAILY_MD,
                    primary_keys=['ts_code', 'trade_date'])
                data_df_list, data_count = [], 0

    finally:
        # Write whatever is still buffered (also after an exception).
        if data_df_list:
            data_df_all = pd.concat(data_df_list)
            all_data_count += bunch_insert(data_df_all,
                                           table_name=table_name,
                                           dtype=DTYPE_TUSHARE_STOCK_DAILY_MD,
                                           primary_keys=['ts_code', 'trade_date'])
        logger.info("更新 %s 结束 %d 条信息被更新", table_name, all_data_count)
        # The table was created during this run (by any batch, not only the last
        # one): convert it to MyISAM and build its primary key.
        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            build_primary_key([table_name])
Ejemplo n.º 3
0
def _query_jq_code_date_ranges(session, sql_str, code_set):
    """Run ``sql_str`` (rows of jq_code, date_from, date_to) and return {jq_code: (date_from, date_to)} filtered by ``code_set``."""
    table = session.execute(sql_str)
    return {
        key_code: (date_from, date_to)
        for key_code, date_from, date_to in table.fetchall()
        if code_set is None or key_code in code_set
    }


def import_jq_stock_daily(chain_param=None, code_set=None):
    """
    Import JoinQuant daily quotes of every stock into ``TABLE_NAME``.

    The end date is yesterday while the database clock is before 16:00, otherwise
    today (the ``hour(now())<16`` clause in the SQL below). If the target table
    exists, each stock restarts from its latest stored ``trade_date`` (inclusive)
    so that that day's ``factor`` can be checked. If the earliest returned row has
    ``factor != 1``, the stored history is considered stale (prices are
    forward-adjusted). The stock's rows are then deleted, from the backup table
    too if it exists, and the full range from the info table is downloaded again.
    When the error message contains '超过了每日最大查询限制', the whole import stops.

    :param chain_param: not used by this function
    :param code_set: optional collection of jq_code values to restrict the import
        to; ``None`` imports every stock
    :return: None
    """
    table_name_info = TABLE_NAME_INFO
    table_name = TABLE_NAME
    table_name_bak = get_bak_table_name(table_name)
    logger.info("更新 %s 开始", table_name)

    # Full date range of each stock, taken from the info table.
    sql_info_str = f"""
        SELECT jq_code, date_frm, if(date_to<end_date, date_to, end_date) date_to
        FROM
          (
            SELECT info.jq_code, start_date date_frm, end_date date_to,
            if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
            FROM {table_name_info} info
          ) tt
        WHERE date_frm <= if(date_to<end_date, date_to, end_date)
        ORDER BY jq_code"""

    has_table = engine_md.has_table(table_name)
    has_bak_table = engine_md.has_table(table_name_bak)
    if has_table:
        # Start from the latest stored trade_date itself (not the day after),
        # so that day's factor is returned again and can be compared.
        sql_trade_date_range_str = f"""
            SELECT jq_code, date_frm, if(date_to<end_date, date_to, end_date) date_to
            FROM
            (
                SELECT info.jq_code, ifnull(trade_date, info.start_date) date_frm, info.end_date date_to,
                if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM 
                    {table_name_info} info 
                LEFT OUTER JOIN
                    (SELECT jq_code, max(trade_date) trade_date FROM {table_name} GROUP BY jq_code) daily
                ON info.jq_code = daily.jq_code
            ) tt
            WHERE date_frm < if(date_to<end_date, date_to, end_date) 
            ORDER BY jq_code"""
    else:
        sql_trade_date_range_str = sql_info_str
        logger.warning('%s 不存在,仅使用 %s 表进行计算日期范围', table_name, table_name_info)

    sql_trade_date_str = """SELECT trade_date FROM jq_trade_date trddate 
       WHERE trade_date <= if(hour(now())<16, subdate(curdate(),1), curdate()) ORDER BY trade_date"""

    with with_db_session(engine_md) as session:
        # All trading days up to the end date.
        table = session.execute(sql_trade_date_str)
        trade_date_list = sorted(row[0] for row in table.fetchall())
        # Range still to be downloaded per stock.
        code_date_range_dic = _query_jq_code_date_ranges(
            session, sql_trade_date_range_str, code_set)
        # Full range per stock; identical to the above when the table is new.
        if has_table:
            code_date_range_from_info_dic = _query_jq_code_date_ranges(
                session, sql_info_str, code_set)
        else:
            code_date_range_from_info_dic = code_date_range_dic

    data_df_list, data_count, all_data_count = [], 0, 0
    data_len = len(code_date_range_dic)
    logger.info('%d stocks will been import into %s', data_len, table_name)

    try:
        for num, (key_code, (date_from_tmp, date_to_tmp)) in enumerate(
                code_date_range_dic.items(), start=1):
            data_df = None
            try:
                # At most two passes; the second one re-downloads the whole history
                # after the stored rows were deleted because of a factor change.
                for loop_count in range(2):
                    # Clip to actual trading days to avoid useless requests.
                    date_from = get_first(trade_date_list,
                                          lambda x: x >= date_from_tmp)
                    date_to = get_last(trade_date_list,
                                       lambda x: x <= date_to_tmp)
                    if date_from is None or date_to is None or date_from >= date_to:
                        logger.debug('%d/%d) %s [%s - %s] 跳过', num, data_len,
                                     key_code, date_from, date_to)
                        break
                    logger.debug('%d/%d) %s [%s - %s] %s', num, data_len,
                                 key_code, date_from, date_to,
                                 '第二次查询' if loop_count > 0 else '')
                    data_df = invoke_daily(key_code=key_code,
                                           start_date=date_2_str(date_from),
                                           end_date=date_2_str(date_to))

                    # Factor check: first pass only, and only if rows came back.
                    if (loop_count == 0 and has_table
                            and data_df is not None and data_df.shape[0] > 0):
                        factor_value = data_df.sort_values('trade_date').iloc[
                            0, :]['factor']
                        if factor_value != 1 and (
                                code_date_range_from_info_dic[key_code][0] !=
                                code_date_range_dic[key_code][0]):
                            # Stored history is stale: delete it and reload the full range.
                            date_from_tmp, date_to_tmp = code_date_range_from_info_dic[
                                key_code]
                            sql_str = f"delete from {table_name} where jq_code=:jq_code"
                            row_count = execute_sql_commit(
                                sql_str, params={'jq_code': key_code})
                            if has_bak_table:
                                sql_str = f"delete from {table_name_bak} where jq_code=:jq_code"
                                execute_sql_commit(
                                    sql_str, params={'jq_code': key_code})
                                logger.info(
                                    '%d/%d) %s %d 条历史记录被清除,重新加载前复权历史数据 [%s - %s] 同时清除bak表中相应记录',
                                    num, data_len, key_code, row_count,
                                    date_from_tmp, date_to_tmp)
                            else:
                                logger.info(
                                    '%d/%d) %s %d 条历史记录被清除,重新加载前复权历史数据 [%s - %s]',
                                    num, data_len, key_code, row_count,
                                    date_from_tmp, date_to_tmp)
                            # Second pass with the reset date range.
                            continue

                    # Leave the `for loop_count in range(2)` loop.
                    break
            except Exception as exp:
                data_df = None
                logger.exception('%s [%s - %s]', key_code,
                                 date_2_str(date_from_tmp),
                                 date_2_str(date_to_tmp))
                # Daily query quota exhausted: further requests cannot succeed today.
                if '超过了每日最大查询限制' in str(exp):
                    break

            # Buffer the rows.
            if data_df is not None and data_df.shape[0] > 0:
                data_count += data_df.shape[0]
                data_df_list.append(data_df)

            # Write once the buffer reaches the threshold.
            if data_count >= 500:
                data_df_all = pd.concat(data_df_list)
                bunch_insert(data_df_all,
                             table_name,
                             dtype=DTYPE,
                             primary_keys=['jq_code', 'trade_date'])
                all_data_count += data_count
                data_df_list, data_count = [], 0

            if DEBUG and num >= 2:
                break

    finally:
        # Write whatever is still buffered (also after an exception).
        if data_df_list:
            data_df_all = pd.concat(data_df_list)
            data_count = bunch_insert(data_df_all,
                                      table_name,
                                      dtype=DTYPE,
                                      primary_keys=['jq_code', 'trade_date'])
            all_data_count += data_count
        logger.info("更新 %s 结束 %d 条信息被更新", table_name, all_data_count)