def get_index_constituent_2_dic(index_code, index_name, date_start, idx_start,
                                date_constituent_df_dict,
                                idx_constituent_set_dic):
    """
    Fetch index constituents and weights via Wind, caching the results in
    date_constituent_df_dict / idx_constituent_set_dic.
    :param index_code:
    :param index_name:
    :param date_start:
    :param idx_start:
    :param date_constituent_df_dict: cache keyed by date
    :param idx_constituent_set_dic: cache keyed by index position
    :return: (sec_df, constituent_set), or (None, None) when no data is available
    """
    cached = (date_start in date_constituent_df_dict
              and idx_start in idx_constituent_set_dic)
    if cached:
        logger.debug('%s %s %s 成分股 已经存在 直接返回', date_2_str(date_start),
                     index_name, index_code)
        return (date_constituent_df_dict[date_start],
                idx_constituent_set_dic[idx_start])

    sec_df = get_sectorconstituent(index_code, index_name, date_start)
    if sec_df is None or sec_df.shape[0] == 0:
        logger.warning('%s 无法获取到 %s %s 成分股', date_2_str(date_start),
                       index_name, index_code)
        # raise ValueError('%s 无法获取到 %s %s 成分股' % (date_start_str, index_name, index_code))
        return None, None

    date_constituent_df_dict[date_start] = sec_df
    # One (wind_code, weight) tuple per constituent row.
    constituent_set = {
        tuple(row)
        for _, row in sec_df[['wind_code', 'weight']].T.items()
    }
    idx_constituent_set_dic[idx_start] = constituent_set
    return sec_df, constituent_set
def get_sectorconstituent_2_dic(sector_code, sector_name, date_start,
                                idx_start, trade_date_list_sorted,
                                date_constituent_df_dict,
                                idx_constituent_set_dic):
    """
    Fetch sector constituents via Wind, caching the results in
    date_constituent_df_dict / idx_constituent_set_dic.
    :param sector_code:
    :param sector_name:
    :param date_start:
    :param idx_start:
    :param trade_date_list_sorted: unused here; kept for interface compatibility
    :param date_constituent_df_dict: cache keyed by date
    :param idx_constituent_set_dic: cache keyed by index position
    :return: (sec_df, constituent_set)
    :raises ValueError: when no constituent data can be fetched
    """
    cached = (date_start in date_constituent_df_dict
              and idx_start in idx_constituent_set_dic)
    if cached:
        logger.debug('%s %s %s 成分股 已经存在 直接返回', date_2_str(date_start),
                     sector_name, sector_code)
        return (date_constituent_df_dict[date_start],
                idx_constituent_set_dic[idx_start])

    sec_df = get_sectorconstituent(sector_code, sector_name, date_start)
    if sec_df is None or sec_df.shape[0] == 0:
        date_start_str = date_2_str(date_start)
        logger.warning('%s 无法获取到 %s %s 成分股', date_start_str, sector_name,
                       sector_code)
        raise ValueError('%s 无法获取到 %s %s 成分股' %
                         (date_start_str, sector_name, sector_code))

    date_constituent_df_dict[date_start] = sec_df
    constituent_set_left = set(sec_df['wind_code'])
    idx_constituent_set_dic[idx_start] = constituent_set_left
    return sec_df, constituent_set_left
Пример #3
0
 def valid_model_acc(self, factor_df: pd.DataFrame):
     """Evaluate the model on an 80/20 train/validation split of factor_df.

     :param factor_df: factor DataFrame; its index supplies the date range shown in logs
     :return: (train_acc, val_acc)
     """
     xs, ys, _ = self.get_x_y(factor_df)
     trade_date_from_str, trade_date_to_str = date_2_str(
         factor_df.index[0]), date_2_str(factor_df.index[-1])
     # Fixed random_state keeps the split reproducible across calls.
     random_state = self.predict_test_random_state
     xs_train, xs_validation, ys_train, ys_validation = train_test_split(
         xs, ys, test_size=0.2, random_state=random_state)
     self.logger.debug(
         'random_state=%d, xs_train %s, ys_train %s, xs_validation %s, ys_validation %s, [%s, %s]',
         random_state, xs_train.shape, ys_train.shape, xs_validation.shape,
         ys_validation.shape, trade_date_from_str, trade_date_to_str)
     sess = self.get_session(renew=True)
     with sess.as_default():
         # with tf.Graph().as_default():
         # self.logger.debug('sess.graph:%s tf.get_default_graph():%s', sess.graph, tf.get_default_graph())
         # NOTE(review): Keras-style evaluate() usually returns [loss, *metrics];
         # result[0] is used as accuracy here — confirm the model's metric order.
         result = self.model.evaluate(xs_validation,
                                      ys_validation,
                                      batch_size=self.batch_size)
         val_acc = result[0]
         result = self.model.evaluate(xs_train,
                                      ys_train,
                                      batch_size=self.batch_size)
         train_acc = result[0]
         self.logger.info("[%s - %s] 训练集准确率: %.2f%%", trade_date_from_str,
                          trade_date_to_str, train_acc * 100)
         self.logger.info("[%s - %s] 验证集准确率: %.2f%%", trade_date_from_str,
                          trade_date_to_str, val_acc * 100)
     return train_acc, val_acc
Пример #4
0
 def __init__(self, date_from, date_to, get_stg_handler, q_table_key=None):
     """Set up the RL environment: action space, training date range, strategy factory.

     :param date_from: training range start (converted to string)
     :param date_to: training range end (converted to string)
     :param get_stg_handler: factory producing a strategy handler
     :param q_table_key: optional key identifying the Q-table to use
     """
     self.action_space = ['empty', 'hold_long', 'hold_short']
     self.n_actions = len(self.action_space)
     self.q_table_key = q_table_key
     self.stg_handler = None
     self.get_stg_handler = get_stg_handler
     self.train_date_from = date_2_str(date_from)
     self.train_date_to = date_2_str(date_to)
     self._build()
Пример #5
0
def get_wind_code_list_by_types(
        instrument_types: list,
        all_if_none=True,
        lasttrade_date_lager_than_n_days_before: Optional[int] = 30) -> list:
    """
    Return all contract codes (wind_code) matching the given instrument types.
    :param instrument_types: list of instrument_type, or of (instrument_type, exchange) tuples
    :param all_if_none: if instrument_types is None, return every contract code
    :param lasttrade_date_lager_than_n_days_before: only keep contracts whose last
        trade date is later than N days before today; None disables the filter
    :return: wind_code_list
    """
    wind_code_list = []
    # Hoisted: the cutoff date is loop-invariant.
    if lasttrade_date_lager_than_n_days_before is not None:
        date_from_str = date_2_str(date.today() - timedelta(
            days=lasttrade_date_lager_than_n_days_before))
    else:
        date_from_str = None

    def _collect_codes(sql_str, conjunction):
        # Execute sql_str (appending the optional last-trade-date filter joined
        # by `conjunction`) and append every returned wind_code to the result.
        with with_db_session(engine_md) as session:
            if date_from_str is not None:
                sql_str += " %s lasttrade_date > :last_trade_date" % conjunction
                table = session.execute(
                    sql_str, params={"last_trade_date": date_from_str})
            else:
                table = session.execute(sql_str)
            for row in table.fetchall():
                wind_code_list.append(row[0])

    if all_if_none and instrument_types is None:
        _collect_codes("select wind_code from wind_future_info", "where")
    else:
        for instrument_type in instrument_types:
            if isinstance(instrument_type, tuple):
                instrument_type, exchange = instrument_type
            else:
                exchange = None
            # Fix: MySQL POSIX character classes are only valid inside a bracket
            # expression — [[:digit:]], not [:digit:] (the bare form is invalid or
            # matches the literal characters ':digit'). The `.` before the exchange
            # is kept as "any char" to preserve the previous matching scope.
            exchange_pattern = '[[:alpha:]]+' if exchange is None else exchange
            sql_str = (
                f"select wind_code from wind_future_info where wind_code "
                f"REGEXP '^{instrument_type}[[:digit:]]+.{exchange_pattern}'"
            )
            _collect_codes(sql_str, "and")

    return wind_code_list
Пример #6
0
def get_wind_kv_per_year(wind_code, wind_indictor_str, date_from, date_to,
                         params):
    """
    Fetch a Wind wsd indicator year by year and concatenate the results.
    :param wind_code:
    :param wind_indictor_str: wsd indicator field string
    :param date_from:
    :param date_to:
    :param params: "year=%(year)d;westPeriod=180" == > "year=2018;westPeriod=180"
    :return: concatenated DataFrame, or None when no data was returned at all
    """
    date_from, date_to = str_2_date(date_from), str_2_date(date_to)
    # Split [date_from, date_to] into sub-ranges cut at each natural-year end (Dec 31).
    date_pair = []
    if date_from <= date_to:
        date_curr = date_from
        while True:
            date_new_year = str_2_date("%d-01-01" % (date_curr.year + 1))
            date_year_end = date_new_year - timedelta(days=1)
            if date_to < date_year_end:
                date_pair.append((date_curr, date_to))
                break
            else:
                date_pair.append((date_curr, date_year_end))
            date_curr = date_new_year
    data_df_list = []
    for date_from_sub, date_to_sub in date_pair:
        # NOTE(review): the year placeholder is filled with year + 1 —
        # presumably the forecast/consensus year; confirm against the API usage.
        params_sub = params % {'year': (date_from_sub.year + 1)}
        try:
            data_df = invoker.wsd(wind_code, wind_indictor_str, date_from_sub,
                                  date_to_sub, params_sub)
        except APIError as exp:
            logger.exception("%s %s [%s ~ %s] %s 执行异常", wind_code,
                             wind_indictor_str, date_2_str(date_from_sub),
                             date_2_str(date_to_sub), params_sub)
            # Tolerate "no data"-style errors and continue with the next year.
            if exp.ret_dic.setdefault('error_code', 0) in (
                    -40520007,  # no data available
                    -40521009,  # data decode failure; check parameters (e.g. month-end dates, short February)
            ):
                continue
            else:
                raise exp
        if data_df is None:
            logger.warning('%s %s [%s ~ %s] has no data', wind_code,
                           wind_indictor_str, date_2_str(date_from_sub),
                           date_2_str(date_to_sub))
            continue
        data_df.dropna(inplace=True)
        if data_df.shape[0] == 0:
            # logger.warning('%s %s [%s ~ %s] has 0 data',
            #                wind_code, wind_indictor_str, date_2_str(date_from_sub), date_2_str(date_to_sub))
            continue
        data_df_list.append(data_df)

    # Merge the per-year results.
    data_df_tot = pd.concat(data_df_list) if len(data_df_list) > 0 else None
    return data_df_tot
Пример #7
0
def import_coin_info(chain_param=None):
    """Fetch basic info of global trading coins (tushare coinlist) and upsert into the DB.

    :param chain_param: unused here; kept for the task-chain interface
    """
    table_name = 'tushare_coin_info'
    has_table = engine_md.has_table(table_name)
    # Column dtypes used for table creation / insertion.
    dtype = {
        'coin': String(60),
        'en_name': String(60),
        'cn_name': String(60),
        'issue_date': Date,
        'amount': DOUBLE,
    }
    coinlist_df = pro.coinlist(start_date='20170101',
                               end_date=date_2_str(date.today(),
                                                   DATE_FORMAT_STR))
    data_count = bunch_insert_on_duplicate_update(coinlist_df, table_name,
                                                  engine_md, dtype)
    logging.info("更新 %s 完成 新增数据 %d 条", table_name, data_count)

    # If the table was newly created, convert it to MyISAM and add the primary key.
    if not has_table and engine_md.has_table(table_name):
        alter_table_2_myisam(engine_md, [table_name])
        create_pk_str = """ALTER TABLE {table_name}
            CHANGE COLUMN `coin` `coin` VARCHAR(60) NOT NULL FIRST,
            CHANGE COLUMN `en_name` `en_name` VARCHAR(60) NOT NULL AFTER `coin`,
            ADD PRIMARY KEY (`coin`, `en_name`)""".format(
            table_name=table_name)
        with with_db_session(engine_md) as session:
            session.execute(create_pk_str)
Пример #8
0
    def default(self, obj):
        """JSON-serialize SQLAlchemy declarative instances and date values.

        For SQLAlchemy models: returns a dict of json-encodable public fields,
        converting datetime/date/timedelta to ISO strings and mapping any other
        non-encodable value to None. Plain date objects become strings.
        """
        if isinstance(obj.__class__, DeclarativeMeta):
            # an SQLAlchemy class
            fields = {}
            for field in [
                    x for x in dir(obj)
                    if not x.startswith('_') and x != 'metadata'
            ]:
                data = obj.__getattribute__(field)
                try:
                    json.dumps(
                        data
                    )  # this will fail on non-encodable values, like other classes
                    fields[field] = data
                except TypeError:
                    # datetime is a subclass of date; isoformat covers both.
                    if isinstance(data, (datetime, date)):
                        fields[field] = data.isoformat()
                    elif isinstance(data, timedelta):
                        fields[field] = (datetime.min +
                                         data).time().isoformat()
                    else:
                        fields[field] = None
            # a json-encodable dict
            return fields
        elif isinstance(obj, date):
            # Fix: default() must return a serializable OBJECT; wrapping it in
            # json.dumps() double-encoded the date (extra quotes in the output).
            return date_2_str(obj)

        return json.JSONEncoder.default(self, obj)
def get_sectorconstituent(index_code, index_name, target_date) -> pd.DataFrame:
    """
    Fetch index constituent stocks and weights from Wind for one date.
    :param index_code:
    :param index_name:
    :param target_date:
    :return: renamed constituent DataFrame, or None when nothing matches
    """
    target_date_str = date_2_str(target_date)
    logger.info('获取 %s %s %s 板块信息', index_code, index_name, target_date)
    sec_df = invoker.wset(
        "indexconstituent",
        f"date={target_date_str};windcode={index_code}")
    if sec_df is not None and sec_df.shape[0] > 0:
        # Wind occasionally returns rows dated differently from target_date;
        # keep only the exact matches.
        date_matches = sec_df['date'].apply(
            lambda d: str_2_date(d) == target_date)
        sec_df = sec_df[date_matches]
    if sec_df is None or sec_df.shape[0] == 0:
        return None
    sec_df["index_code"] = index_code
    sec_df["index_name"] = index_name
    column_mapping = {
        'date': 'trade_date',
        'sec_name': 'stock_name',
        'i_weight': 'weight',
    }
    sec_df.rename(columns=column_mapping, inplace=True)
    return sec_df
Пример #10
0
    def save(self):
        """Fetch fundamentals for every pending trade date and upsert them into self.table_name."""
        self.logger.info("更新 %s 开始", self.table_name)
        has_table = engine_md.has_table(self.table_name)
        # Check whether the target table already exists.
        if has_table:
            with with_db_session(engine_md) as session:
                sql_str = f"""select trade_date from jq_trade_date where trade_date>(select max(day) from {self.table_name}) order by trade_date"""
                # NOTE(review): this sql_str contains no :trade_date placeholder,
                # so the params dict below is ignored — confirm intent.
                table = session.execute(sql_str,
                                        params={"trade_date": self.BASE_DATE})
                trade_date_list = [_[0] for _ in table.fetchall()]
            # NOTE(review): re-runs the same list query and takes its first scalar
            # (the earliest pending trade date); used only for the log line below.
            date_start = execute_scalar(sql_str, engine_md)
            self.logger.info('查询 %s 数据使用起始日期 %s', self.table_name,
                             date_2_str(date_start))
        else:
            with with_db_session(engine_md) as session:
                sql_str = "select trade_date from jq_trade_date where trade_date>=:trade_date order by trade_date"
                table = session.execute(sql_str,
                                        params={"trade_date": self.BASE_DATE})
                trade_date_list = [_[0] for _ in table.fetchall()]
            self.logger.warning('%s 不存在,使用基础日期 %s', self.table_name,
                                self.BASE_DATE)

        # Iterate the pending trade dates in ascending order.
        trade_date_list.sort()
        data_count_tot, for_count = 0, len(trade_date_list)
        try:
            for num, trade_date in enumerate(trade_date_list):
                q = query(self.statement)
                df = get_fundamentals(q, date=date_2_str(trade_date))
                if df is None or df.shape[0] == 0:
                    continue
                logger.debug('%d/%d) %s 包含 %d 条数据', num, for_count, trade_date,
                             df.shape[0])
                data_count = bunch_insert_on_duplicate_update(
                    df,
                    self.table_name,
                    engine_md,
                    dtype=self.dtype,
                    myisam_if_create_table=True,
                    primary_keys=['id'],
                    schema=config.DB_SCHEMA_MD)
                data_count_tot += data_count
        except:
            # Best-effort: log the failure and still report the final count below.
            logger.exception("更新 %s 异常", self.table_name)
        finally:
            # Report how many rows were written.
            logging.info("更新 %s 结束 %d 条信息被更新", self.table_name, data_count_tot)
Пример #11
0
    def get_df_iter(self,
                    date_start,
                    date_end,
                    step,
                    df_len_limit=3000,
                    deep=0):
        """
        Fetch data over a date range; whenever one chunk reaches the provider's
        row limit, bisect the date range and recurse.
        :param date_start:
        :param date_end:
        :param step: range step in days
        :param df_len_limit: per-query row limit of the data provider
        :param deep: recursion depth (used for log indentation)
        :return: yields (df, date_from, date_to) tuples
        """
        for num, (date_from, date_to) in enumerate(iter_2_range(
                range_date(date_start, date_end, step),
                has_left_outer=False,
                has_right_outer=False),
                                                   start=1):
            q = query(self.statement).filter(
                self.statement.pub_date > date_2_str(date_from),
                self.statement.pub_date <= date_2_str(date_to))

            df = finance.run_query(q)
            df_len = df.shape[0]
            if df_len >= df_len_limit:
                if step >= 2:
                    # Likely truncated at the provider limit: halve the step and recurse.
                    self.logger.warning(
                        '%s%s%d) [%s ~ %s] 包含 %d 条数据,可能已经超越 %d 条提取上限,开始进一步分割日期',
                        self.table_name, '  ' * deep, num, date_from, date_to,
                        df_len, df_len_limit)
                    yield from self.get_df_iter(date_from,
                                                date_to,
                                                step // 2,
                                                deep=deep + 1)
                else:
                    # Cannot split a 1-day range further; yield what we have and
                    # warn that the remainder must be fetched manually.
                    self.logger.warning(
                        '%s%s%d) [%s ~ %s] 包含 %d 条数据,可能已经超越 %d 条提取上限且无法再次分割日期范围,手动需要补充提取剩余数据',
                        self.table_name, '  ' * deep, num, date_from, date_to,
                        df_len, df_len_limit)
                    yield df, date_from, date_to
            else:
                self.logger.debug('%s%s%d) [%s ~ %s] 包含 %d 条数据',
                                  self.table_name, '  ' * deep, num, date_from,
                                  date_to, df_len)
                yield df, date_from, date_to
Пример #12
0
def stat_fund(date_from, date_to):
    """
    Compute describe() statistics of funds' accumulated-nav percentage change
    between the latest nav on/before date_from and the latest on/before date_to.

    Excludes A/B/C share classes and filters outliers (pct_chg == 1 or >= 2).
    :param date_from:
    :param date_to:
    :return: pandas describe() Series for the pct_chg column
    """
    sql_str = """
        SELECT (@rowNum:=@rowNum+1) AS rowNo, t.* FROM 
        (
            SELECT date_from.ts_code, basic.name, basic.management, date_from.date_from, nav_to.end_date, 
                nav_from.accum_nav nav_from, nav_to.accum_nav nav_to, nav_to.accum_nav/nav_from.accum_nav pct_chg
            FROM
            (
                SELECT ts_code, max(end_date) date_from FROM tushare_fund_nav WHERE end_date<= :date_from 
                GROUP BY ts_code
            ) date_from
            JOIN
            (
                SELECT ts_code, max(end_date) date_to FROM tushare_fund_nav WHERE end_date<= :date_to 
                GROUP BY ts_code
            ) date_to
            ON date_from.ts_code = date_to.ts_code
            JOIN tushare_fund_nav nav_from
            ON date_from.ts_code = nav_from.ts_code
            AND date_from.date_from = nav_from.end_date
            JOIN tushare_fund_nav nav_to
            ON date_to.ts_code = nav_to.ts_code
            AND date_to.date_to = nav_to.end_date
            JOIN tushare_fund_basic basic
            ON date_from.ts_code = basic.ts_code
            WHERE basic.name NOT LIKE '%B%' and basic.name NOT LIKE '%A%' and basic.name NOT LIKE '%C%'
            HAVING nav_to.accum_nav IS NOT NULL AND nav_from.accum_nav IS NOT NULL and pct_chg != 1 and pct_chg < 2
            ORDER BY nav_to.accum_nav/nav_from.accum_nav
        ) t"""
    # data_df = pd.read_sql(sql_str, engine_md)
    with with_db_session(engine_md) as session:
        # Initialize the MySQL user variable that numbers the rows in the outer query.
        session.execute("Select (@rowNum :=0) ;")
        table = session.execute(sql_str,
                                params={
                                    'date_from': date_2_str(date_from),
                                    'date_to': date_2_str(date_to)
                                })
        data = [[d for d in row] for row in table.fetchall()]
    data_df = pd.DataFrame(data,
                           columns=[
                               'rowNo', 'ts_code', 'name', 'management',
                               'date_from', 'date_to', 'nav_from', 'nav_to',
                               'pct_chg'
                           ])
    return data_df.describe()['pct_chg']
Пример #13
0
def trade_date_list(file_path=None):
    """Export all SSE open trade dates to a CSV file (single column: trade_date).

    :param file_path: output CSV path; defaults to the exported 'trade_date.csv'
    """
    if file_path is None:
        file_path = get_export_path('trade_date.csv')

    sql_str = "select cal_date from tushare_trade_date where exchange='SSE' and is_open=1"
    with with_db_session_p() as session:
        rows = session.execute(sql_str).fetchall()
        date_str_list = [date_2_str(row[0]) for row in rows]

    pd.DataFrame({'trade_date': date_str_list}).to_csv(file_path, index=False)
Пример #14
0
def save_adj_factor(instrument_type: str,
                    method: ReversionRightsMethod,
                    db_table_name='wind_future_adj_factor',
                    to_csv_dir_path=None,
                    generate_reversion_rights_factors_func:
                    Callable = generate_reversion_rights_factors):
    """
    Generate reversion-rights (adjustment) factors for one instrument type and persist them.

    :param instrument_type: instrument type
    :param method: reversion-rights calculation method
    :param db_table_name: target database table; None = do not save to DB
    :param to_csv_dir_path: directory for CSV output; None = do not save CSV
    :param generate_reversion_rights_factors_func: function producing (factor_df, latest_trade_date)
    :return:
    """
    logger.info("生成 %s 复权因子[%s]", instrument_type, method.name)
    adj_factor_df, trade_date_latest = generate_reversion_rights_factors_func(
        instrument_type, method=method)
    if adj_factor_df is None:
        return

    if to_csv_dir_path is not None:
        # CSV goes into a sub-folder named after the latest trade date.
        csv_file_name = f'adj_factor_{instrument_type}_{method.name}.csv'
        folder_path = os.path.join(to_csv_dir_path,
                                   date_2_str(trade_date_latest))
        csv_file_path = os.path.join(folder_path, csv_file_name)
        os.makedirs(folder_path, exist_ok=True)
        adj_factor_df.to_csv(csv_file_path, index=False)

    if db_table_name is not None:
        # Column dtypes for the target table.
        dtype = {
            'trade_date': Date,
            'instrument_id_main': String(20),
            'adj_factor_main': DOUBLE,
            'instrument_id_secondary': String(20),
            'adj_factor_secondary': DOUBLE,
            'instrument_type': String(20),
            'method': String(20),
        }
        adj_factor_df['method'] = method.name
        update_df_2_db(instrument_type,
                       db_table_name,
                       adj_factor_df,
                       method=method,
                       dtype=dtype)

    logger.info(
        "生成 %s 复权因子 %d 条记录[%s]",  # \n%s
        instrument_type,
        adj_factor_df.shape[0],
        method.name
        # , adj_factor_df
    )
Пример #15
0
    def init_state(self, md_df: pd.DataFrame):
        """Reset per-episode state/action, optionally retrain, then warm up via choose_action."""
        # Every new episode must reset state and action.
        self.last_state = None
        self.last_action = None
        if self.do_train:
            # When do_train is False the strategy runs live; retraining then
            # happens in on_prepare and periodically in on_period instead.
            trade_date_s = md_df['trade_date']
            self.train_date_from = pd.to_datetime(trade_date_s.iloc[0]) + self.retrain_period
            trade_date_to = trade_date_s.iloc[-1]
            self.train(date_2_str(trade_date_to))

        _ = self.choose_action(md_df)
Пример #16
0
def import_jq_stock_income(chain_param=None, ts_code_set=None):
    """
    Incrementally import records published since the last recorded pub_date.
    NOTE(review): the original docstring described a daily-bar importer
    ("插入股票日线数据到最近一个工作日-1 ... BASE_LINE_HOUR"); the body imports
    income-statement data — verify which description is correct.
    :param chain_param: unused here; kept for the task-chain interface
    :param ts_code_set: unused here; kept for interface compatibility
    :return:
    """
    logger.info("更新 %s 开始", TABLE_NAME)
    has_table = engine_md.has_table(TABLE_NAME)
    # Check whether the table already exists.
    if has_table:
        sql_str = f"""select max(pub_date) from {TABLE_NAME}"""
        date_start = execute_scalar(sql_str, engine_md)
        logger.info('查询 %s 数据使用起始日期 %s', TABLE_NAME, date_2_str(date_start))
    else:
        date_start = BASE_DATE
        logger.warning('%s 不存在,使用基础日期 %s', TABLE_NAME, date_2_str(date_start))

    # Determine the latest pub_date to fetch up to.
    date_end = datetime.date.today()
    if date_start >= date_end:
        logger.info('%s 已经是最新数据,无需进一步获取', date_start)
        return
    data_count_tot = 0
    try:
        for num, (df, date_from, date_to) in enumerate(
                get_df_iter(date_start, date_end, LOOP_STEP)):
            # logger.debug('%d) [%s ~ %s] 包含 %d 条数据', num, date_from, date_to, df.shape[0])
            data_count = bunch_insert_on_duplicate_update(
                df,
                TABLE_NAME,
                engine_md,
                dtype=DTYPE,
                myisam_if_create_table=True,
                primary_keys=['id'],
                schema=config.DB_SCHEMA_MD)
            data_count_tot += data_count
    finally:
        # Report how many rows were written.
        logging.info("更新 %s 结束 %d 条信息被更新", TABLE_NAME, data_count_tot)
Пример #17
0
def save_adj_factor(instrument_types: list, to_db=True, to_csv=True):
    """
    Generate and persist reversion-rights factors for every instrument type, per Method.

    :param instrument_types: instrument types to process
    :param to_db: whether to save to the database
    :param to_csv: whether to save to csv files
    :return:
    """
    dir_path = 'output'
    if to_csv:
        # Create the output folder.
        os.makedirs(dir_path, exist_ok=True)

    for method in Method:
        for n, instrument_type in enumerate(instrument_types):
            logger.info("生成 %s 复权因子", instrument_type)
            adj_factor_df, trade_date_latest = generate_reversion_rights_factors(
                instrument_type, method=method)
            if adj_factor_df is None:
                continue

            if to_csv:
                # CSV goes into a sub-folder named after the latest trade date.
                csv_file_name = f'adj_factor_{instrument_type}_{method.name}.csv'
                folder_path = os.path.join(dir_path,
                                           date_2_str(trade_date_latest))
                csv_file_path = os.path.join(folder_path, csv_file_name)
                os.makedirs(folder_path, exist_ok=True)
                adj_factor_df.to_csv(csv_file_path, index=False)

            if to_db:
                table_name = 'wind_future_adj_factor'
                # Column dtypes for the target table.
                dtype = {
                    'trade_date': Date,
                    'instrument_id_main': String(20),
                    'adj_factor_main': DOUBLE,
                    'instrument_id_secondary': String(20),
                    'adj_factor_secondary': DOUBLE,
                    'instrument_type': String(20),
                    'method': String(20),
                }
                adj_factor_df['method'] = method.name
                # NOTE(review): dtype is passed positionally as the 4th argument
                # here, while a sibling variant calls
                # update_df_2_db(..., method=..., dtype=...) — confirm this matches
                # the function's signature.
                update_df_2_db(instrument_type, table_name, adj_factor_df,
                               dtype)

            logger.info(
                "生成 %s 复权因子 %s 条记录",  # \n%s
                instrument_type,
                adj_factor_df.shape[0]
                # , adj_factor_df
            )
Пример #18
0
    def save(self):
        """Incrementally fetch records since the last pub_date and upsert them into self.table_name."""
        self.logger.info("更新 %s 开始", self.table_name)
        has_table = engine_md.has_table(self.table_name)
        # Check whether the target table already exists.
        if has_table:
            sql_str = f"""select max(pub_date) from {self.table_name}"""
            date_start = execute_scalar(sql_str, engine_md)
            self.logger.info('查询 %s 数据使用起始日期 %s', self.table_name,
                             date_2_str(date_start))
        else:
            date_start = self.BASE_DATE
            self.logger.warning('%s 不存在,使用基础日期 %s', self.table_name,
                                date_2_str(date_start))

        # Determine the latest pub_date to fetch up to.
        date_end = datetime.date.today()
        if date_start >= date_end:
            self.logger.info('%s %s 已经是最新数据,无需进一步获取', self.table_name,
                             date_start)
            return
        data_count_tot = 0
        try:
            for num, (df, date_from, date_to) in enumerate(
                    self.get_df_iter(date_start, date_end, self.loop_step)):
                # logger.debug('%d) [%s ~ %s] 包含 %d 条数据', num, date_from, date_to, df.shape[0])
                if df is not None and df.shape[0] > 0:
                    data_count = bunch_insert_on_duplicate_update(
                        df,
                        self.table_name,
                        engine_md,
                        dtype=self.dtype,
                        myisam_if_create_table=True,
                        primary_keys=['id'],
                        schema=config.DB_SCHEMA_MD)
                    data_count_tot += data_count
        finally:
            # Report how many rows were written.
            logging.info("更新 %s 结束 %d 条信息被更新", self.table_name, data_count_tot)
Пример #19
0
    def choose_action(self, md_df: pd.DataFrame):
        """Choose the next action from the Q-table for the latest market state.

        Retrains when the retrain period has elapsed (training mode), or lazily
        initializes the Q-table, then performs one learn step and picks an action.
        """
        latest_trade_date = pd.to_datetime(md_df['trade_date'].iloc[-1])
        if self.do_train and latest_trade_date > (self.train_date_latest + self.retrain_period):
            # In live (non-training) mode, retraining is handled in on_prepare / on_period.
            self.train(date_2_str(latest_trade_date))
        elif self.ql_table is None:
            self.init_ql_table(latest_trade_date)

        current_state, reward = self.get_state_reward(md_df)
        if self.last_state is not None:
            # Update Q-values using the reward observed for the previous transition.
            self.ql_table.learn(self.last_state, self.last_action, reward,
                                current_state)

        chosen = self.ql_table.choose_action(current_state)
        self.last_state = current_state
        self.last_action = chosen
        return chosen
Пример #20
0
 def save_model(self, trade_date):
     """
     Export the trained model to a file.
     :param trade_date: used as the sub-folder name under model_folder_path
     :return: path of the saved model file
     """
     folder_path = os.path.join(self.model_folder_path, date_2_str(trade_date))
     # exist_ok avoids the check-then-create race of the previous
     # os.path.exists()/os.makedirs() pair.
     os.makedirs(folder_path, exist_ok=True)
     file_path = os.path.join(
         folder_path,
         f"model_{int(self.label_func_min_rr * 10000)}_{int(self.label_func_max_rr * 10000)}.tfl")
     self.model.save(file_path)
     self.logger.info("模型训练截止日期: %s 保存到: %s", self.trade_date_last_train, file_path)
     return file_path
def get_sectorconstituent(sector_code, sector_name,
                          target_date) -> pd.DataFrame:
    """
    Fetch sector constituent stocks from Wind.
    :param sector_code:
    :param sector_name:
    :param target_date:
    :return: renamed constituent DataFrame, or None when Wind returns no data
    """
    target_date_str = date_2_str(target_date)
    logger.info('获取 %s %s %s 板块信息', sector_code, sector_name, target_date)
    sec_df = invoker.wset(
        "sectorconstituent",
        "date=%s;sectorid=%s" % (target_date_str, sector_code))
    # Guard against empty responses, mirroring the index-constituent variant;
    # previously a None/empty result raised on the column assignment below.
    # Callers already treat a None return as "no data".
    if sec_df is None or sec_df.shape[0] == 0:
        return None
    sec_df["sector_code"] = sector_code
    sec_df["sector_name"] = sector_name
    sec_df.rename(columns={
        'date': 'trade_date',
        'sec_name': 'stock_name',
    },
                  inplace=True)
    return sec_df
Пример #22
0
def df_2_table(doc,
               df,
               format_by_index=None,
               format_by_col=None,
               max_col_count=None,
               mark_top_n=None,
               mark_top_n_on_cols=None):
    """
    Render *df* into one or more Word tables appended to *doc*.

    :param doc: python-docx Document (or anything with add_table/add_paragraph)
    :param df: DataFrame to render; its index becomes the first column
    :param format_by_index: dict index -> format (str template or callable),
        applied to every cell of that row
    :param format_by_col: dict column -> format, used when no row format applies
    :param max_col_count: max data columns per table (excluding the index);
        wider frames are split into several tables
    :param mark_top_n: highlight the top N values (red + bold)
    :param mark_top_n_on_cols: columns to rank for top-N marking; None ranks all
    :return:
    """
    if max_col_count is None:
        max_col_count = df.shape[1]

    if mark_top_n is not None:
        if mark_top_n_on_cols is not None:
            rank_df = df[mark_top_n_on_cols]
        else:
            rank_df = df
        rank_df = rank_df.rank(ascending=False)
        is_in_rank_df = rank_df <= mark_top_n
    else:
        is_in_rank_df = None

    for table_num, col_name_list in enumerate(
            split_chunk(list(df.columns), max_col_count)):
        if table_num > 0:
            # Separate the second/third/... table from the previous one
            # with an empty paragraph.
            doc.add_paragraph('')

        sub_df = df[col_name_list]
        row_num, col_num = sub_df.shape
        t = doc.add_table(row_num + 1, col_num + 1)

        # Write the header row (column 0 is reserved for the index).
        for j in range(col_num):
            paragraph = t.cell(0, j + 1).paragraphs[0]
            paragraph.add_run(str(col_name_list[j])).bold = True
            paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER

        # Header background color.
        for j in range(col_num + 1):
            t.cell(0, j)._tc.get_or_add_tcPr().append(
                parse_xml(r'<w:shd {} w:fill="00A2E8"/>'.format(nsdecls('w'))))

        # Format the table style to be a grid.
        t.style = 'TableGrid'

        # Populate the table with the dataframe.
        for i in range(row_num):
            index = sub_df.index[i]
            paragraph = t.cell(i + 1, 0).paragraphs[0]
            index_str = str(date_2_str(index))
            paragraph.add_run(index_str).bold = True
            paragraph.alignment = WD_ALIGN_PARAGRAPH.LEFT
            # Row-level format wins over column-level format.
            if format_by_index is not None and index in format_by_index:
                format_row = format_by_index[index]
            else:
                format_row = None

            for j in range(col_num):
                col_name = col_name_list[j]
                if format_row is None and format_by_col is not None and col_name in format_by_col:
                    format_cell = format_by_col[col_name]
                else:
                    format_cell = format_row

                content = sub_df.values[i, j]
                if format_cell is None:
                    text = str(content)
                elif isinstance(format_cell, str):
                    text = str.format(format_cell, content)
                elif callable(format_cell):
                    text = format_cell(content)
                else:
                    # Format the message (the original passed the values as
                    # extra positional args, producing a tuple message).
                    raise ValueError('%s: %s 无效' % (index, format_cell))

                paragraph = t.cell(i + 1, j + 1).paragraphs[0]
                paragraph.alignment = WD_ALIGN_PARAGRAPH.RIGHT
                try:
                    style = paragraph.add_run(text)
                    # Highlight top-N cells in red bold.
                    if is_in_rank_df is not None and col_name in is_in_rank_df and is_in_rank_df.loc[
                            index, col_name]:
                        style.font.color.rgb = RGBColor(0xed, 0x1c, 0x24)
                        style.bold = True
                except TypeError as exp:
                    logger.exception('df.iloc[%d, %d] = df["%s", "%s"] = %s',
                                     i, j, index, col_name, text)
                    # Re-raise as-is; `raise exp from exp` chained the
                    # exception to itself.
                    raise

        # Zebra-stripe even data rows.
        for i in range(1, row_num + 1):
            for j in range(col_num + 1):
                if i % 2 == 0:
                    t.cell(i, j)._tc.get_or_add_tcPr().append(
                        parse_xml(r'<w:shd {} w:fill="A3D9EA"/>'.format(
                            nsdecls('w'))))
Пример #23
0
def import_index_daily(chain_param=None):
    """Import index daily quotes from Wind into wind_index_daily.
    :param chain_param: only used to pass results along a celery task chain
    :return:
    """
    table_name = "wind_index_daily"
    has_table = engine_md.has_table(table_name)
    # (column name, SQL type) pairs for the Wind wsd indicators to fetch.
    col_name_param_list = [
        ('open', DOUBLE),
        ('high', DOUBLE),
        ('low', DOUBLE),
        ('close', DOUBLE),
        ('volume', DOUBLE),
        ('amt', DOUBLE),
        ('turn', DOUBLE),
        ('free_turn', DOUBLE),
    ]
    wind_indictor_str = ",".join([key for key, _ in col_name_param_list])
    # Wind returns upper-case column names; map them back to lower case.
    rename_col_dic = {
        key.upper(): key.lower()
        for key, _ in col_name_param_list
    }
    dtype = {key: val for key, val in col_name_param_list}
    dtype['wind_code'] = String(20)
    # TODO: declaring 'trade_date' as Date raises on insert; cause unknown, revisit later
    # dtype['trade_date'] = Date,

    # yesterday = date.today() - timedelta(days=1)
    # date_ending = date.today() - ONE_DAY if datetime.now().hour < BASE_LINE_HOUR else date.today()
    # sql_str = """select wii.wind_code, wii.sec_name, ifnull(adddate(latest_date, INTERVAL 1 DAY), wii.basedate) date_from
    #     from wind_index_info wii left join
    #     (
    #         select wind_code,index_name, max(trade_date) as latest_date
    #         from wind_index_daily group by wind_code
    #     ) daily
    #     on wii.wind_code=daily.wind_code"""
    # with with_db_session(engine_md) as session:
    #     table = session.execute(sql_str)
    #     wind_code_date_from_dic = {wind_code: (sec_name, date_from) for wind_code, sec_name, date_from in table.fetchall()}
    # with with_db_session(engine_md) as session:
    #     # fetch the market's valid trading days
    #     sql_str = "select trade_date from wind_trade_date where trade_date > '2005-1-1'"
    #     table = session.execute(sql_str)
    #     trade_date_sorted_list = [t[0] for t in table.fetchall()]
    #     trade_date_sorted_list.sort()
    # date_to = get_last(trade_date_sorted_list, lambda x: x <= date_ending)
    # data_len = len(wind_code_date_from_dic)
    # NOTE(review): `if(null<end_date, null, end_date)` compares against the
    # literal NULL, which is always NULL in MySQL, so date_to is NULL for
    # every row — verify this is intentional (date_to is only used in a
    # later comparison where the row is skipped when date_from > date_to).
    if has_table:
        sql_str = """
              SELECT wind_code, date_frm, if(null<end_date, null, end_date) date_to
              FROM
              (
                  SELECT info.wind_code, ifnull(trade_date, basedate) date_frm, null,
                  if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                  FROM 
                      wind_index_info info 
                  LEFT OUTER JOIN
                      (SELECT wind_code, adddate(max(trade_date),1) trade_date FROM {table_name} GROUP BY wind_code) daily
                  ON info.wind_code = daily.wind_code
              ) tt
              WHERE date_frm <= if(null<end_date, null, end_date) 
              ORDER BY wind_code""".format(table_name=table_name)
    else:
        logger.warning('%s 不存在,仅使用 wind_index_info 表进行计算日期范围', table_name)
        sql_str = """
              SELECT wind_code, date_frm, if(null<end_date, null, end_date) date_to
              FROM
              (
                  SELECT info.wind_code, basedate date_frm, null,
                  if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                  FROM wind_index_info info 
              ) tt
              WHERE date_frm <= if(null<end_date, null, end_date) 
              ORDER BY wind_code;"""

    with with_db_session(engine_md) as session:
        # Per-index date range still missing from the daily table.
        table = session.execute(sql_str)
        # NOTE(review): begin_time is always None here, so the min() branch
        # below never fires — presumably leftover from a removed parameter;
        # confirm before relying on it.
        begin_time = None
        # NOTE(review): wind_code_set is not defined inside this function —
        # presumably a module-level global used as an optional filter; verify.
        wind_code_date_from_dic = {
            wind_code:
            (date_from if begin_time is None else min([date_from, begin_time]),
             date_to)
            for wind_code, date_from, date_to in table.fetchall()
            if wind_code_set is None or wind_code in wind_code_set
        }

    data_len = len(wind_code_date_from_dic)

    logger.info('%d indexes will been import', data_len)
    for data_num, (wind_code,
                   (date_from,
                    date_to)) in enumerate(wind_code_date_from_dic.items()):
        if str_2_date(date_from) > date_to:
            logger.warning("%d/%d) %s %s - %s 跳过", data_num, data_len,
                           wind_code, date_from, date_to)
            continue
        try:
            temp = invoker.wsd(wind_code, wind_indictor_str, date_from,
                               date_to)
        except APIError as exp:
            logger.exception("%d/%d) %s 执行异常", data_num, data_len, wind_code)
            # Skip recoverable Wind errors; abort the whole loop otherwise.
            if exp.ret_dic.setdefault('error_code', 0) in (
                    -40520007,  # no data available
                    -40521009,  # data decode failure; check date params (month-end / short February)
            ):
                continue
            else:
                break
        temp.reset_index(inplace=True)
        temp.rename(columns={'index': 'trade_date'}, inplace=True)
        temp.rename(columns=rename_col_dic, inplace=True)
        temp.trade_date = temp.trade_date.apply(str_2_date)
        temp['wind_code'] = wind_code
        bunch_insert_on_duplicate_update(temp,
                                         table_name,
                                         engine_md,
                                         dtype=dtype)
        logger.info('更新指数 %s 至 %s 成功', wind_code, date_2_str(date_to))
        # First successful insert creates the table: convert engine/keys once.
        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            build_primary_key([table_name])
Пример #24
0
def import_jq_stock_daily(chain_param=None, code_set=None):
    """
    Import daily stock bars up to the last workday - 1
    (or up to today once past the base-line hour).

    :param chain_param: only used to pass results along a celery task chain
    :param code_set: optional subset of jq_code to import; None = all
    :return:
    """
    table_name_info = TABLE_NAME_INFO
    table_name = TABLE_NAME
    table_name_bak = get_bak_table_name(table_name)
    logging.info("更新 %s 开始", table_name)

    # Per-stock date range taken from the info table only.
    sql_info_str = f"""
        SELECT jq_code, date_frm, if(date_to<end_date, date_to, end_date) date_to
        FROM
          (
            SELECT info.jq_code, start_date date_frm, end_date date_to,
            if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
            FROM {table_name_info} info
          ) tt
        WHERE date_frm <= if(date_to<end_date, date_to, end_date)
        ORDER BY jq_code"""

    has_table = engine_md.has_table(table_name)
    has_bak_table = engine_md.has_table(table_name_bak)
    # Decide whether the daily table already exists.
    if has_table:
        # The SQL was adjusted from the original logic:
        # before: start from each stock's max trade date + 1 day
        # now:    start from each stock's max trade date itself,
        # so the adjustment factor of the latest stored day can be
        # re-fetched and compared against the database value.
        sql_trade_date_range_str = f"""
            SELECT jq_code, date_frm, if(date_to<end_date, date_to, end_date) date_to
            FROM
            (
                SELECT info.jq_code, ifnull(trade_date, info.start_date) date_frm, info.end_date date_to,
                if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM 
                    {table_name_info} info 
                LEFT OUTER JOIN
                    (SELECT jq_code, max(trade_date) trade_date FROM {table_name} GROUP BY jq_code) daily
                ON info.jq_code = daily.jq_code
            ) tt
            WHERE date_frm < if(date_to<end_date, date_to, end_date) 
            ORDER BY jq_code"""

    else:
        sql_trade_date_range_str = sql_info_str
        logger.warning('%s 不存在,仅使用 %s 表进行计算日期范围', table_name, table_name_info)

    sql_trade_date_str = """SELECT trade_date FROM jq_trade_date trddate 
       WHERE trade_date <= if(hour(now())<16, subdate(curdate(),1), curdate()) ORDER BY trade_date"""

    with with_db_session(engine_md) as session:
        # All trading days up to the current cut-off.
        table = session.execute(sql_trade_date_str)
        trade_date_list = [row[0] for row in table.fetchall()]
        trade_date_list.sort()
        # Per-stock (date_from, date_to) range still to be fetched.
        table = session.execute(sql_trade_date_range_str)
        code_date_range_dic = {
            key_code: (date_from, date_to)
            for key_code, date_from, date_to in table.fetchall()
            if code_set is None or key_code in code_set
        }

        # Full date ranges from the info table (used when history must be
        # re-downloaded after an adjustment-factor change).
        if sql_info_str == sql_trade_date_range_str:
            code_date_range_from_info_dic = code_date_range_dic
        else:
            table = session.execute(sql_info_str)
            code_date_range_from_info_dic = {
                key_code: (date_from, date_to)
                for key_code, date_from, date_to in table.fetchall()
                if code_set is None or key_code in code_set
            }

    data_df_list, data_count, all_data_count, data_len = [], 0, 0, len(
        code_date_range_dic)
    logger.info('%d stocks will been import into %s', data_len, table_name)
    # Accumulate data_df frames into data_df_list and bulk-insert in batches.

    try:
        for num, (key_code,
                  (date_from_tmp,
                   date_to_tmp)) in enumerate(code_date_range_dic.items(),
                                              start=1):
            data_df = None
            try:
                for loop_count in range(2):
                    # Clamp the range to actual trading days to avoid
                    # useless requests.
                    date_from = get_first(trade_date_list,
                                          lambda x: x >= date_from_tmp)
                    date_to = get_last(trade_date_list,
                                       lambda x: x <= date_to_tmp)
                    if date_from is None or date_to is None or date_from >= date_to:
                        logger.debug('%d/%d) %s [%s - %s] 跳过', num, data_len,
                                     key_code, date_from, date_to)
                        break
                    logger.debug('%d/%d) %s [%s - %s] %s', num, data_len,
                                 key_code, date_from, date_to,
                                 '第二次查询' if loop_count > 0 else '')
                    data_df = invoke_daily(key_code=key_code,
                                           start_date=date_2_str(date_from),
                                           end_date=date_2_str(date_to))

                    # Only checked on the first pass:
                    # if the earliest fetched bar's adjustment factor is not 1,
                    # the stock had a corporate action — all stored (forward
                    # adjusted) history is stale, so wipe it and re-download.
                    if loop_count == 0 and has_table:
                        factor_value = data_df.sort_values('trade_date').iloc[
                            0, :]['factor']
                        if factor_value != 1 and (
                                code_date_range_from_info_dic[key_code][0] !=
                                code_date_range_dic[key_code][0]):
                            # Delete this stock's stored history.
                            sql_str = f"delete from {table_name} where jq_code=:jq_code"
                            row_count = execute_sql_commit(
                                sql_str, params={'jq_code': key_code})
                            date_from_tmp, date_to_tmp = code_date_range_from_info_dic[
                                key_code]
                            if has_bak_table:
                                sql_str = f"delete from {table_name_bak} where jq_code=:jq_code"
                                row_count = execute_sql_commit(
                                    sql_str, params={'jq_code': key_code})
                                date_from_tmp, date_to_tmp = code_date_range_from_info_dic[
                                    key_code]
                                logger.info(
                                    '%d/%d) %s %d 条历史记录被清除,重新加载前复权历史数据 [%s - %s] 同时清除bak表中相应记录',
                                    num, data_len, key_code, row_count,
                                    date_from_tmp, date_to_tmp)
                            else:
                                logger.info(
                                    '%d/%d) %s %d 条历史记录被清除,重新加载前复权历史数据 [%s - %s]',
                                    num, data_len, key_code, row_count,
                                    date_from_tmp, date_to_tmp)

                            # Range was reset to the full history: second pass.
                            continue

                    # Leave the `for loop_count in range(2)` loop.
                    break
            except Exception as exp:
                data_df = None
                logger.exception('%s [%s - %s]', key_code,
                                 date_2_str(date_from_tmp),
                                 date_2_str(date_to_tmp))
                # BUGFIX: str.find returns -1 (truthy) when absent, so the
                # original `if exp.args[0].find(...)` aborted the whole import
                # on ANY error (and would NOT abort on a match at position 0).
                # Only stop when the daily-quota message is actually present.
                if exp.args and '超过了每日最大查询限制' in str(exp.args[0]):
                    break

            # Accumulate the fetched frame.
            if data_df is not None and data_df.shape[0] > 0:
                data_count += data_df.shape[0]
                data_df_list.append(data_df)

            # Bulk-insert once the accumulated row count passes the threshold.
            if data_count >= 500:
                data_df_all = pd.concat(data_df_list)
                bunch_insert(data_df_all,
                             table_name,
                             dtype=DTYPE,
                             primary_keys=['jq_code', 'trade_date'])
                all_data_count += data_count
                data_df_list, data_count = [], 0

            if DEBUG and num >= 2:
                break

    finally:
        # Flush the remaining frames to the database.
        if len(data_df_list) > 0:
            data_df_all = pd.concat(data_df_list)
            data_count = bunch_insert(data_df_all,
                                      table_name,
                                      dtype=DTYPE,
                                      primary_keys=['jq_code', 'trade_date'])
            all_data_count = all_data_count + data_count
            logging.info("更新 %s 结束 %d 条信息被更新", table_name, all_data_count)
Пример #25
0
def transfer_2_batch(df: pd.DataFrame,
                     n_step,
                     labels=None,
                     date_from=None,
                     date_to=None):
    """
    [num, factor_count] -> [num - n_step + 1, n_step, factor_count]
    Slice *df* into overlapping windows of length *n_step*.

    *labels* is aligned with df's rows and trimmed the same way as the index;
    when given, the trimmed labels are returned as an extra value (new_ys).

    :param df: factor DataFrame indexed by date (assumed sorted ascending)
    :param n_step: window length
    :param labels: optional; must have the same length as df.shape[0]
    :param date_from: optional lower bound for the output dates; up to
        n_step rows before it are kept as history for the first window
    :param date_to: optional inclusive upper bound for the input dates
    :return: (df_index, df_columns, data_arr_batch[, new_ys]);
        all-None tuple of the same arity when no data falls in the range
    """
    df_len = df.shape[0]
    if labels is not None and df_len != len(labels):
        # Format the message (the original passed the lengths as extra
        # positional args, producing a tuple message).
        raise ValueError("ys 长度 %d 必须与 df 长度 %d 保持一致" % (len(labels), df_len))
    # TODO: the date_from / date_to logic could be tightened further; kept
    # as-is for now to save time.
    # Trim factors by date_from, keeping n_step rows of leading history.
    if date_from is not None:
        date_from = pd.to_datetime(date_from)
        is_fit = df.index >= date_from
        if np.any(is_fit):
            start_idx = np.argmax(is_fit) - n_step

            if start_idx < 0:
                start_idx = 0
                # Not enough leading history: the first producible window ends
                # at df.index[n_step - 1] (the original hard-coded index 60).
                logger.warning("%s 为起始日期的数据,前向历史数据不足 %d 条,因此,起始日期向后推移至 %s",
                               date_2_str(date_from), n_step,
                               date_2_str(df.index[n_step - 1]))

            df = df.iloc[start_idx:]
            df_len = df.shape[0]
            if labels is not None:
                labels = labels[start_idx:]

        else:
            logger.warning("没有 %s 之后的数据,当前数据最晚日期为 %s", date_2_str(date_from),
                           date_2_str(max(df.index)))
            if labels is not None:
                return None, None, None, None
            else:
                return None, None, None

    # Trim factors by date_to.
    if date_to is not None:
        date_to = pd.to_datetime(date_to)
        is_fit = df.index <= date_to
        if np.any(is_fit):
            # np.argmin on an all-True mask returns 0, which would wrongly
            # discard every row; keep the whole frame in that case.
            to_idx = df_len if np.all(is_fit) else np.argmin(is_fit)
            df = df.iloc[:to_idx]
            df_len = df.shape[0]
            if labels is not None:
                labels = labels[:to_idx]

        else:
            logger.warning("没有 %s 之前的数据,当前数据最晚日期为 %s", date_2_str(date_to),
                           date_2_str(min(df.index)))
            if labels is not None:
                return None, None, None, None
            else:
                return None, None, None

    new_shape = [df_len - n_step + 1, n_step]
    new_shape.extend(df.shape[1:])
    # Each output window ends at its index date, hence the (n_step-1) offset.
    df_index, df_columns = df.index[(n_step - 1):], df.columns
    data_arr_batch, factor_arr = np.zeros(new_shape), df.to_numpy(
        dtype=np.float32)

    for idx_from, idx_to in enumerate(range(n_step, factor_arr.shape[0] + 1)):
        data_arr_batch[idx_from] = factor_arr[idx_from:idx_to]

    if labels is not None:
        new_ys = labels[(n_step - 1):]
        return df_index, df_columns, data_arr_batch, new_ys
    else:
        return df_index, df_columns, data_arr_batch
Пример #26
0
def merge_tushare_stock_daily(ths_code_set: set = None, date_from=None):
    """Merge A-share daily quotes and financial data into one daily table.

    :param ths_code_set: optional subset of ts_code to process; None = all
    :param date_from: start date; defaults to max(trade_date)+1 in the table
    :return:
    """
    table_name = 'tushare_stock_daily'
    logging.info("合成 %s 开始", table_name)
    has_table = engine_md.has_table(table_name)
    if date_from is None and has_table:
        sql_str = "select adddate(max(`trade_date`),1) from {table_name}".format(table_name=table_name)
        with with_db_session(engine_md) as session:
            date_from = date_2_str(session.execute(sql_str).scalar())

    # Daily-level quote data.
    # TODO: pass ths_code_set down into the merged-daily query as well
    daily_df, dtype_daily = get_tushare_daily_merged_df(ths_code_set, date_from)

    daily_df_g = daily_df.groupby('ts_code')
    ths_code_set_4_daily = set(daily_df_g.size().index)

    # Merged financial-report data.
    ifind_fin_df, dtype_fin = get_tushre_merge_stock_fin_df()

    # Combined dtype for the insert.
    dtype = dtype_daily.copy()
    dtype.update(dtype_fin)
    logging.debug("提取财务数据完成")
    # Index each stock's financial rows by announcement date:
    # {ths_code: {report_date: first row of that announcement}}.
    report_date_dic_dic = {}
    for num, ((ths_code, report_date), data_df) in enumerate(
            ifind_fin_df.groupby(['ts_code', 'f_ann_date']), start=1):
        if ths_code_set is not None and ths_code not in ths_code_set:
            continue
        if is_nan_or_none(report_date):
            continue
        report_date_dic = report_date_dic_dic.setdefault(ths_code, {})
        # BUGFIX: the original tested `report_date not in report_date_dic_dic`
        # (the outer dict, keyed by ths_code), which was always True; test the
        # per-stock inner dict instead.
        if report_date not in report_date_dic:
            if data_df.shape[0] > 0:
                report_date_dic[report_date] = data_df.iloc[0]

    logger.debug("计算财报日期完成")
    # Stitch financial values onto daily rows, announcement period by period.
    tot_data_count, data_count, data_df_list, for_count = 0, 0, [], len(report_date_dic_dic)
    try:
        for num, (ths_code, report_date_dic) in enumerate(report_date_dic_dic.items(), start=1):  # key: ths_code
            # TODO: check whether ths_code exists in the daily groupby; the
            # size()-based set is a stopgap to be improved later
            if ths_code not in ths_code_set_4_daily:
                logger.error('fin 表中不存在 %s 的財務數據', ths_code)
                continue

            daily_df_cur_ts_code = daily_df_g.get_group(ths_code)
            logger.debug('%d/%d) 处理 %s %d 条数据', num, for_count, ths_code, daily_df_cur_ts_code.shape[0])
            report_date_list = list(report_date_dic.keys())
            report_date_list.sort()
            report_date_list_len = len(report_date_list)
            for num_sub, (report_date_from, report_date_to) in enumerate(iter_2_range(report_date_list)):
                logger.debug('%d/%d) %d/%d) 处理 %s [%s - %s]',
                             num, for_count, num_sub, report_date_list_len,
                             ths_code, date_2_str(report_date_from), date_2_str(report_date_to))
                # Daily rows governed by this announcement period
                # [report_date_from, report_date_to).
                if report_date_from is None:
                    is_fit = daily_df_cur_ts_code['trade_date'] < report_date_to
                elif report_date_to is None:
                    is_fit = daily_df_cur_ts_code['trade_date'] >= report_date_from
                else:
                    is_fit = (daily_df_cur_ts_code['trade_date'] < report_date_to) & (
                            daily_df_cur_ts_code['trade_date'] >= report_date_from)
                # Slice the rows inside the period.
                ifind_his_ds_df_segment = daily_df_cur_ts_code[is_fit].copy()
                segment_count = ifind_his_ds_df_segment.shape[0]
                if segment_count == 0:
                    continue
                fin_s = report_date_dic[report_date_from] if report_date_from is not None else None
                for key in dtype_fin.keys():
                    if key in ('ts_code', 'trade_date'):
                        continue
                    ifind_his_ds_df_segment[key] = fin_s[key] if fin_s is not None and key in fin_s else None

                ifind_his_ds_df_segment['report_date'] = report_date_from
                # Queue the segment for batch insert.
                data_df_list.append(ifind_his_ds_df_segment)
                data_count += segment_count

            if DEBUG and len(data_df_list) > 1:
                break

            # Batch insert once enough rows are queued.
            if data_count > 1000:
                data_df = pd.concat(data_df_list)
                data_count = bunch_insert_on_duplicate_update(
                    data_df, table_name, engine_md, dtype, myisam_if_create_table=True)
                tot_data_count += data_count
                data_count, data_df_list = 0, []

    finally:
        # Flush the remainder to the database.
        if len(data_df_list) > 0:
            data_df = pd.concat(data_df_list)
            data_count = bunch_insert_on_duplicate_update(
                data_df, table_name, engine_md, dtype, myisam_if_create_table=True)
            tot_data_count += data_count

        logger.info('%s 新增或更新记录 %d 条', table_name, tot_data_count)
        if not has_table and engine_md.has_table(table_name):
            build_primary_key([table_name])
Пример #27
0
def invoke_index_weight(index_code, trade_date):
    """
    Query index constituent weights from the tushare `pro` client.

    :param index_code: index code recognised by tushare
    :param trade_date: date-like; converted to tushare's date string format
    :return: DataFrame returned by pro.index_weight
    """
    # Use a distinct local name: the original shadowed the function itself.
    trade_date_str = date_2_str(trade_date, STR_FORMAT_DATE_TS)
    return pro.index_weight(index_code=index_code, trade_date=trade_date_str)
Пример #28
0
def import_future_info(chain_param=None):
    """
    Refresh the futures contract list and basic info.

    :param chain_param: only used to pass results along a celery task chain
    :return:
    """
    table_name = 'ifind_future_info'
    has_table = engine_md.has_table(table_name)
    logger.info("更新 %s [%s] 开始", table_name, has_table)
    # Known contracts and latest IPO date per exchange (skip re-download).
    if has_table:
        sql_str = f'SELECT ths_code, ths_start_trade_date_future FROM {table_name}'
        with with_db_session(engine_md) as session:
            table = session.execute(sql_str)
            code_ipo_date_dic = dict(table.fetchall())
        exchange_latest_ipo_date_dic = get_exchange_latest_data()
    else:
        code_ipo_date_dic = {}
        exchange_latest_ipo_date_dic = {}

    exchange_sectorid_dic_list = [
        {
            'exch_eng': 'SHFE',
            'exchange_name': '上海期货交易所',
            'sectorid': '091001',
            'date_establish': '1995-05-10'
        },
        {
            'exch_eng': 'CFFEX',
            'exchange_name': '中国金融期货交易所',
            'sectorid': '091004',
            'date_establish': '2013-09-10'
        },
        {
            'exch_eng': 'DCE',
            'exchange_name': '大连商品交易所',
            'sectorid': '091002',
            'date_establish': '1999-01-10'
        },
        {
            'exch_eng': 'CZCE',
            'exchange_name': '郑州商品交易所',
            'sectorid': '091003',
            'date_establish': '1999-01-10'
        },
    ]

    # (iFinD indicator, request param, SQL type) triples.
    indicator_param_list = [
        ('ths_future_short_name_future', '', String(50)),
        ('ths_future_code_future', '', String(20)),
        ('ths_sec_type_future', '', String(20)),
        ('ths_td_variety_future', '', String(20)),
        ('ths_td_unit_future', '', DOUBLE),
        ('ths_pricing_unit_future', '', String(20)),
        ('ths_mini_chg_price_future', '', DOUBLE),
        ('ths_chg_ratio_lmit_future', '', DOUBLE),
        ('ths_td_deposit_future', '', DOUBLE),
        ('ths_start_trade_date_future', '', Date),
        ('ths_last_td_date_future', '', Date),
        ('ths_last_delivery_date_future', '', Date),
        ('ths_delivery_month_future', '', String(10)),
        ('ths_listing_benchmark_price_future', '', DOUBLE),
        ('ths_initial_td_deposit_future', '', DOUBLE),
        ('ths_contract_month_explain_future', '', String(120)),
        ('ths_td_time_explain_future', '', String(120)),
        ('ths_last_td_date_explian_future', '', String(120)),
        ('ths_delivery_date_explain_future', '', String(120)),
        ('ths_exchange_short_name_future', '', String(50)),
        ('ths_contract_en_short_name_future', '', String(50)),
        ('ths_contract_en_name_future', '', String(50)),
    ]
    json_indicator, json_param = unzip_join(
        [(key, val) for key, val, _ in indicator_param_list], sep=';')

    # dtype for the insert.
    dtype = {key: val for key, _, val in indicator_param_list}
    dtype['ths_code'] = String(20)
    dtype['exch_eng'] = String(20)

    # Collect contract codes by scanning each exchange's sector in
    # ndays_per_update-day steps since its (latest known) start date.
    code_set = set()
    ndays_per_update = 90
    sector_count = len(exchange_sectorid_dic_list)
    for num, exchange_sectorid_dic in enumerate(exchange_sectorid_dic_list,
                                                start=1):
        exchange_name = exchange_sectorid_dic['exchange_name']
        exch_eng = exchange_sectorid_dic['exch_eng']
        sector_id = exchange_sectorid_dic['sectorid']
        date_establish = exchange_sectorid_dic['date_establish']
        # Start where the previous import left off (or at establishment date).
        date_since = str_2_date(
            exchange_latest_ipo_date_dic.setdefault(exch_eng, date_establish))
        date_yesterday = date.today() - timedelta(days=1)
        logger.info("%d/%d) %s[%s][%s] %s ~ %s", num, sector_count,
                    exchange_name, exch_eng, sector_id, date_since,
                    date_yesterday)
        while date_since <= date_yesterday:
            date_since_str = date_2_str(date_since)
            # Sector constituents (the commodity contracts) on this date, e.g.
            # THS_DP('block','2021-01-15;091002003','date:Y,thscode:Y,security_name:Y,security_name_in_time:Y')
            try:
                future_info_df = invoker.THS_DataPool(
                    'block', '%s;%s' % (date_since_str, sector_id),
                    'thscode:Y,security_name:Y')
            except APIError as exp:
                if exp.ret_dic['error_code'] in (
                        -4001,
                        -4210,
                ):
                    future_info_df = None
                else:
                    logger.exception("THS_DataPool %s 获取失败, '%s;%s'",
                                     exchange_name, date_since_str, sector_id)
                    break
            if future_info_df is not None and future_info_df.shape[0] > 0:
                code_set |= set(future_info_df['THSCODE'])

            if date_since >= date_yesterday:
                break
            else:
                date_since += timedelta(days=ndays_per_update)
                if date_since > date_yesterday:
                    date_since = date_yesterday

        if DEBUG:
            break

    # Codes not yet in the database.
    code_list = [wc for wc in code_set if wc not in code_ipo_date_dic]
    # Fetch basic info in chunks of 500 codes.
    if len(code_list) > 0:
        for code_chunk in split_chunk(code_list, 500):
            future_info_df = invoker.THS_BasicData(code_chunk, json_indicator,
                                                   json_param)
            if future_info_df is None or future_info_df.shape[0] == 0:
                data_count = 0
                logger.warning("更新 %s 结束 %d 条记录被更新", table_name, data_count)
            else:
                # Fill in exch_eng by exchange name.  Use .loc: the original
                # chained assignment triggers SettingWithCopy and is silently
                # lost under pandas copy-on-write.
                future_info_df['exch_eng'] = ''
                for exchange_sectorid_dic in exchange_sectorid_dic_list:
                    future_info_df.loc[
                        future_info_df['ths_exchange_short_name_future'] ==
                        exchange_sectorid_dic['exchange_name'],
                        'exch_eng'] = exchange_sectorid_dic['exch_eng']

                data_count = bunch_insert_on_duplicate_update(
                    future_info_df,
                    table_name,
                    engine_md,
                    dtype,
                    primary_keys=['ths_code'],
                    schema=config.DB_SCHEMA_MD)
                logger.info("更新 %s 结束 %d 条记录被更新", table_name, data_count)
Пример #29
0
def import_index_daily_ds(chain_param=None,
                          ths_code_set: set = None,
                          begin_time=None):
    """
    Save historical index data via the date-series interface into the
    ``ifind_index_daily_ds`` table.  This supplements the History data with
    fields such as the adjustment factor (af), limit-up/limit-down flags,
    suspension status and suspension reason.

    :param chain_param: only used to pass parameters along a ``task.chain``
        of serial celery tasks
    :param ths_code_set: optional whitelist; when not ``None`` only codes in
        this set are updated
    :param begin_time: optional override — data is fetched from
        ``min(date_from, begin_time)`` for each code
    :return: None
    """
    table_name = 'ifind_index_daily_ds'
    has_table = engine_md.has_table(table_name)
    json_indicator, json_param = unzip_join(
        [(key, val) for key, val, _ in INDICATOR_PARAM_LIST_INDEX_DAILY_DS],
        sep=';')
    # NOTE(review): `if(NULL<end_date, NULL, end_date)` always yields end_date
    # (comparison with NULL is NULL/false in MySQL) — looks like a templating
    # leftover; kept byte-identical to preserve behavior.
    if has_table:
        sql_str = """SELECT ths_code, date_frm, if(NULL<end_date, NULL, end_date) date_to
            FROM
            (
                SELECT info.ths_code, ifnull(trade_date_max_1, ths_index_base_period_index) date_frm, NULL,
                if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM 
                    ifind_index_info info 
                LEFT OUTER JOIN
                    (SELECT ths_code, adddate(max(time),1) trade_date_max_1 FROM {table_name} GROUP BY ths_code) daily
                ON info.ths_code = daily.ths_code
            ) tt
            WHERE date_frm <= if(NULL<end_date, NULL, end_date) 
            ORDER BY ths_code""".format(table_name=table_name)
    else:
        sql_str = """SELECT ths_code, date_frm, if(NULL<end_date, NULL, end_date) date_to
            FROM
            (
                SELECT info.ths_code, ths_index_base_period_index date_frm, NULL,
                if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM ifind_index_info info 
            ) tt
            WHERE date_frm <= if(NULL<end_date, NULL, end_date) 
            ORDER BY ths_code;"""
        # Lazy %-style logging args instead of eager string interpolation
        logger.warning('%s 不存在,仅使用 ifind_index_info 表进行计算日期范围', table_name)
    with with_db_session(engine_md) as session:
        table = session.execute(sql_str)

        # For each code, the date range over which daily data still needs
        # to be fetched (optionally extended back to begin_time)
        code_date_range_dic = {
            ths_code:
            (date_from if begin_time is None else min([date_from, begin_time]),
             date_to)
            for ths_code, date_from, date_to in table.fetchall()
            if ths_code_set is None or ths_code in ths_code_set
        }

    if TRIAL:
        date_from_min = date.today() - timedelta(days=(365 * 5))
        # Trial accounts can only fetch the most recent 5 years of data
        code_date_range_dic = {
            ths_code: (max([date_from, date_from_min]), date_to)
            for ths_code, (date_from, date_to) in code_date_range_dic.items()
            if date_to is not None and date_from_min <= date_to
        }

    data_df_list, data_count, tot_data_count, code_count = [], 0, 0, len(
        code_date_range_dic)
    try:
        for num, (ths_code,
                  (begin_time,
                   end_time)) in enumerate(code_date_range_dic.items(),
                                           start=1):
            logger.debug('%d/%d) %s [%s - %s]', num, code_count, ths_code,
                         begin_time, end_time)
            end_time = date_2_str(end_time)
            data_df = invoker.THS_DateSerial(
                ths_code, json_indicator, json_param,
                'Days:Tradedays,Fill:Previous,Interval:D', begin_time,
                end_time)
            if data_df is not None and data_df.shape[0] > 0:
                data_count += data_df.shape[0]
                data_df_list.append(data_df)
            # Flush to the database once the buffered row count passes the threshold
            if data_count >= 10000:
                data_df_all = pd.concat(data_df_list)
                data_count = bunch_insert_on_duplicate_update(
                    data_df_all, table_name, engine_md, DTYPE_INDEX_DAILY_DS)
                tot_data_count += data_count
                data_df_list, data_count = [], 0

            # Debugging only: stop after a couple of codes
            if DEBUG and len(data_df_list) > 1:
                break
    finally:
        # Flush whatever is still buffered, even on error/interrupt
        if data_count > 0:
            data_df_all = pd.concat(data_df_list)
            data_count = bunch_insert_on_duplicate_update(
                data_df_all, table_name, engine_md, DTYPE_INDEX_DAILY_DS)
            tot_data_count += data_count

        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            build_primary_key([table_name])

        # BUGFIX: was `logging.info` (root logger) — use the module logger
        # for consistency with every other log call in this file
        logger.info("更新 %s 完成 新增数据 %d 条", table_name, tot_data_count)
def update_private_fund_nav(chain_param=None,
                            get_df=False,
                            wind_code_list=None):
    """
    Fetch private-fund NAV data (nav, NAV_acc, NAV_date) from wind and
    insert it into the ``wind_fund_nav`` table.

    :param chain_param: used by celery to pass the previous task's result
        along a chain; not read here
    :param get_df: when True, also accumulate and return the fetched data
    :param wind_code_list: optional subset of fund codes to update; when
        ``None`` all codes from ``fund_info`` are processed
    :return: a concatenated DataFrame of all fetched NAV rows when
        ``get_df`` is True and data was fetched, otherwise an empty list
    """
    table_name = 'wind_fund_nav'
    has_table = engine_md.has_table(table_name)
    # Per fund: date range still missing from wind_fund_nav, bounded by
    # setup date / maturity date / yesterday
    if has_table:
        fund_info_df = pd.read_sql_query(
            """SELECT DISTINCT fi.wind_code AS wind_code, 
               IFNULL(trade_date_from, if(trade_date_latest BETWEEN '1900-01-01' AND ADDDATE(CURDATE(), -1), ADDDATE(trade_date_latest,1) , fund_setupdate) ) date_from,
               if(fund_maturitydate BETWEEN '1900-01-01' AND ADDDATE(CURDATE(), -1),fund_maturitydate,ADDDATE(CURDATE(), -1)) date_to 
               FROM fund_info fi
               LEFT JOIN
               (
               SELECT wind_code, ADDDATE(max(trade_date),1) trade_date_from FROM wind_fund_nav
               GROUP BY wind_code
               ) wfn
               ON fi.wind_code = wfn.wind_code""", engine_md)
    else:
        logger.warning('wind_fund_nav 不存在,仅使用 fund_info 表进行计算日期范围')
        fund_info_df = pd.read_sql_query(
            """SELECT DISTINCT fi.wind_code AS wind_code, 
                     fund_setupdate date_from,
                    if(fund_maturitydate BETWEEN '1900-01-01' AND ADDDATE(CURDATE(), -1),fund_maturitydate,ADDDATE(CURDATE(), -1)) date_to 
                    FROM fund_info fi ORDER BY wind_code""", engine_md)
    wind_code_date_frm_to_dic = {
        wind_code: (str_2_date(date_from), str_2_date(date_to))
        for wind_code, date_from, date_to in
        zip(fund_info_df['wind_code'], fund_info_df['date_from'],
            fund_info_df['date_to'])
    }
    fund_info_df.set_index('wind_code', inplace=True)
    if wind_code_list is None:
        wind_code_list = list(fund_info_df.index)
    else:
        # Only keep requested codes that actually exist in fund_info
        wind_code_list = list(set(wind_code_list) & set(fund_info_df.index))

    fund_nav_df_list = []  # accumulated per-fund DataFrames (when get_df)
    no_data_count = 0
    code_count = len(wind_code_list)
    # Latest successfully-updated trade date per fund
    wind_code_trade_date_latest_dic = {}
    date_gap = timedelta(days=10)
    try:
        for num, wind_code in enumerate(wind_code_list):
            date_begin, date_end = wind_code_date_frm_to_dic[wind_code]
            if date_begin > date_end:
                continue

            # Validate the begin date: must be a real date after 1900
            if isinstance(date_begin, date):
                if date_begin.year < 1900:
                    continue
                if date_begin > date_end:
                    continue
                date_begin_str = date_begin.strftime('%Y-%m-%d')
            else:
                logger.error("%s date_begin:%s", wind_code, date_begin)
                continue

            # Validate the end date
            # BUGFIX: original re-checked date_begin here (copy-paste);
            # validate date_end instead
            if isinstance(date_end, date):
                if date_end.year < 1900:
                    continue
                date_end_str = date_end.strftime('%Y-%m-%d')
            else:
                logger.error("%s date_end:%s", wind_code, date_end)
                continue

            # Try to fetch fund_nav data, at most two attempts
            for _ in range(2):
                try:
                    fund_nav_tmp_df = invoker.wsd(
                        codes=wind_code,
                        fields='nav,NAV_acc,NAV_date',
                        beginTime=date_2_str(date_begin_str),
                        endTime=date_2_str(date_end_str),
                        options='Fill=Previous')
                    trade_date_latest = datetime.strptime(
                        date_end_str, '%Y-%m-%d').date() - date_gap
                    wind_code_trade_date_latest_dic[
                        wind_code] = trade_date_latest
                    break
                except APIError as exp:
                    # -40520007: still record a (backed-off) latest trade date
                    if exp.ret_dic.setdefault('error_code', 0) == -40520007:
                        trade_date_latest = datetime.strptime(
                            date_end_str, '%Y-%m-%d').date() - date_gap
                        wind_code_trade_date_latest_dic[
                            wind_code] = trade_date_latest
                    logger.error("%s Failed, ErrorMsg: %s" %
                                 (wind_code, str(exp)))
                    continue
                except Exception as exp:
                    logger.error("%s Failed, ErrorMsg: %s" %
                                 (wind_code, str(exp)))
                    continue
            else:
                # Both attempts failed
                fund_nav_tmp_df = None

            if fund_nav_tmp_df is None:
                logger.info('%s No data', wind_code)
                no_data_count += 1
                logger.warning('%d funds no data', no_data_count)
            else:
                fund_nav_tmp_df.dropna(how='all', inplace=True)
                if fund_nav_tmp_df.shape[0] == 0:
                    continue
                fund_nav_tmp_df['wind_code'] = wind_code
                trade_date_latest = fund_nav_df_2_sql(table_name,
                                                      fund_nav_tmp_df,
                                                      engine_md,
                                                      is_append=True)
                if trade_date_latest is None:
                    # BUGFIX: format string had two placeholders ('%s[%d]')
                    # but only one argument, which broke the log call
                    logger.error('%s data insert failed', wind_code)
                else:
                    wind_code_trade_date_latest_dic[
                        wind_code] = trade_date_latest
                    logger.info('%d) %s updated, %d funds left', num,
                                wind_code, code_count - num)
                    if get_df:
                        # BUGFIX: original assigned list.append()'s return
                        # value (None) back to the accumulator; append in
                        # place and concat on return instead
                        fund_nav_df_list.append(fund_nav_tmp_df)
            if DEBUG and num > 4:  # debugging only
                break
    finally:
        # NOTE(review): the try/except this message belonged to is commented
        # out; logger.exception outside a handler logs a spurious
        # "NoneType: None" traceback — confirm whether it can be removed
        logger.exception('新功能上线前由于数据库表不存在,可能导致更新失败,属于正常现象')
        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            build_primary_key([table_name])
    if get_df and fund_nav_df_list:
        return pd.concat(fund_nav_df_list)
    return []