예제 #1
0
def _test_rl_handler_4_train(trade_date_from='2010-1-1', trade_date_to='2018-10-18'):
    """
    Smoke-test RLHandler4Train by replaying RB.csv market data through it.

    For every trade date in ``[trade_date_from, trade_date_to]`` the handler is asked for an
    action on the history up to and including that date. The collected actions are printed.
    A ``QLearningTable`` built with the same key is then compared, by shape, against the
    handler's own table.

    :param trade_date_from: first trade date (inclusive) for which an action is requested
    :param trade_date_to: last trade date (inclusive) of the loaded data; also the Q-table key
    """
    from ibats_common.example.reinforcement_learning.v1.rl_stg import get_stg_handler
    from ibats_common.example.data import load_data
    date_from, date_to = str_2_date(trade_date_from), str_2_date(trade_date_to)

    raw_df = load_data('RB.csv')
    # keep only bars up to (and including) trade_date_to
    md_df = raw_df[raw_df['trade_date'].apply(lambda s: str_2_date(s) <= date_to)]
    dates = md_df['trade_date'].apply(str_2_date)

    handler = RLHandler4Train(retrain_period=360, get_stg_handler=get_stg_handler, q_table_key=trade_date_to)
    handler.init_state(md_df)

    # feed an expanding window of history, one trade date at a time
    action_by_date = {}
    for day in dates[dates >= date_from]:
        action_by_date[day] = handler.choose_action(md_df[dates <= day])

    action_df = pd.DataFrame([action_by_date]).T
    print("action_df = \n", action_df)

    # build a QLearningTable by class path with the same key and compare it with the handler's table
    table_cls = load_class(module_name='ibats_common.example.reinforcement_learning.v1.q_learn',
                           class_name='QLearningTable')
    ql_table = table_cls(actions=handler.actions, key=trade_date_to)
    assert ql_table.q_table.shape[0] > 0
    print("ql_table.q_table.shape=", ql_table.q_table.shape)
    assert ql_table.q_table.shape == handler.ql_table.q_table.shape
예제 #2
0
 def __init__(self,
              instrument_id_list,
              md_period: PeriodType,
              exchange_name,
              agent_name=None,
              init_load_md_count=None,
              init_md_date_from=None,
              init_md_date_to=None,
              **kwargs):
     """
     Configure a named, daemon agent for the given exchange and instruments.

     :param instrument_id_list: instrument ids handled by this agent
     :param md_period: market-data period; also part of the default agent name
     :param exchange_name: exchange name; also part of the default agent name
     :param agent_name: explicit agent name, defaults to ``f'{exchange_name}.{md_period}'``
     :param init_load_md_count: initial number of bars to load, cast to ``int`` (None keeps None)
     :param init_md_date_from: initial start date, normalised with ``str_2_date``
     :param init_md_date_to: initial end date, normalised with ``str_2_date``
     :param kwargs: extra parameters, kept in ``self.params``; ``timestamp_key``,
         ``symbol_key`` and ``close_key`` are also exposed as attributes (None if absent)
     """
     agent_name = f'{exchange_name}.{md_period}' if agent_name is None else agent_name
     # set before the base-class constructor runs, as the original did
     self.exchange_name = exchange_name
     # presumably threading.Thread (name/daemon keywords) -- confirm against the class header
     super().__init__(name=agent_name, daemon=True)
     self.md_period = md_period
     self.keep_running = None
     self.instrument_id_list = instrument_id_list
     self.init_load_md_count = None if init_load_md_count is None else int(init_load_md_count)
     self.init_md_date_from = str_2_date(init_md_date_from)
     self.init_md_date_to = str_2_date(init_md_date_to)
     self.logger = logging.getLogger(str(self.__class__))
     self.agent_name = agent_name
     self.params = kwargs
     # key column names supplied through kwargs
     self.timestamp_key = kwargs.get('timestamp_key')
     self.symbol_key = kwargs.get('symbol_key')
     self.close_key = kwargs.get('close_key')
예제 #3
0
    def load_model_if_exist(self,
                            trade_date,
                            enable_load_model_if_exist=False):
        """
        Load the newest saved model whose training date is not after ``trade_date``.

        A load is attempted only when ``self.enable_load_model_if_exist`` or the
        ``enable_load_model_if_exist`` argument is truthy. Candidates come from
        ``self.get_date_file_path_pair_list()``. When ``self.retrain_period`` is set,
        candidates older than ``trade_date - retrain_period`` days are ignored. When it
        is None, there is no lower bound.

        On success this sets ``self.trade_date_last_train`` and
        ``self.predict_test_random_state`` from the chosen entry.

        Directory layout:
        tf_saves_2019-06-05_16_21_39
          *   model_tfls
          *       *   2012-12-31
          *       *       *   checkpoint
          *       *       *   model_-54_51.tfl.data-00000-of-00001
          *       *       *   model_-54_51.tfl.index
          *       *       *   model_-54_51.tfl.meta
          *       *   2013-02-28
          *       *       *   checkpoint
          *       *       *   model_-54_51.tfl.data-00000-of-00001
          *       *       *   model_-54_51.tfl.index
          *       *       *   model_-54_51.tfl.meta
          *   tensorboard_logs
          *       *   2012-12-31_496[1]_20190605_184316
          *       *       *   events.out.tfevents.1559731396.mg-ubuntu64
          *       *   2013-02-28_496[1]_20190605_184716
          *       *       *   events.out.tfevents.1559731396.mg-ubuntu64
        :param trade_date: current trade date (date, or anything ``str_2_date`` accepts)
        :param enable_load_model_if_exist: force a load attempt even if the instance flag is off
        :return: True if a model was loaded, otherwise False
        """
        if not (self.enable_load_model_if_exist or enable_load_model_if_exist):
            return False

        trade_date = str_2_date(trade_date)
        # retrain_period may be None (on_min1 explicitly guards for that); then any older model qualifies
        if self.retrain_period is None:
            min_available_date = None
            self.logger.debug('尝试加载现有模型,[- %s]', trade_date)
        else:
            min_available_date = trade_date - timedelta(days=self.retrain_period)
            self.logger.debug('尝试加载现有模型,[%s - %s] %d 天', min_available_date,
                              trade_date, self.retrain_period)

        date_file_path_pair_list = [
            pair for pair in self.get_date_file_path_pair_list()
            if min_available_date is None or pair[0] >= min_available_date
        ]
        if not date_file_path_pair_list:
            return False

        # sort by training date; the intent is to pick the latest entry dated on or before trade_date
        # (get_last presumably returns the last element matching the predicate, or None)
        date_file_path_pair_list.sort(key=lambda x: x[0])
        ret = get_last(date_file_path_pair_list, lambda x: x[0] <= trade_date)
        if ret is None:
            return False

        key, folder_path, predict_test_random_state = ret
        if folder_path is None:
            return False

        # the model graph has to be rebuilt before saved weights can be loaded into it
        model = self.get_model(rebuild_model=True)
        model.load(folder_path)
        self.trade_date_last_train = key
        self.predict_test_random_state = predict_test_random_state
        self.logger.info(
            "加载模型成功。trade_date_last_train: %s load from path: %s",
            key, folder_path)
        return True
예제 #4
0
def _split_date_range_by_year(date_from, date_to):
    """
    Split the closed range ``[date_from, date_to]`` into closed per-calendar-year sub-ranges.

    Assumes ``datetime.date`` inputs, as produced by ``str_2_date``.
    Returns an empty list when ``date_from > date_to``.
    """
    date_pair = []
    date_curr = date_from
    while date_curr <= date_to:
        date_year_end = date_curr.replace(month=12, day=31)
        # '<=' matters: when date_to falls on Dec 31, the range must stop here rather than
        # emitting an inverted (Jan 1 next year, date_to) pair
        if date_to <= date_year_end:
            date_pair.append((date_curr, date_to))
            break
        date_pair.append((date_curr, date_year_end))
        date_curr = date_year_end + timedelta(days=1)
    return date_pair


def get_wind_kv_per_year(wind_code, wind_indictor_str, date_from, date_to,
                         params):
    """
    Fetch a Wind ``wsd`` series one calendar year at a time and concatenate the results.

    :param wind_code: Wind security code
    :param wind_indictor_str: indicator string passed to ``invoker.wsd``
    :param date_from: start date (anything ``str_2_date`` accepts), inclusive
    :param date_to: end date (anything ``str_2_date`` accepts), inclusive
    :param params: ``%``-style template with a ``year`` key, e.g.
        "year=%(year)d;westPeriod=180" ==> "year=2018;westPeriod=180".
        ``year`` is filled with the sub-range's calendar year plus one.
    :return: concatenation of the non-empty yearly frames (rows with NaN dropped),
        or None if nothing was fetched
    :raises APIError: re-raised for Wind errors other than the two codes skipped below
    """
    date_from, date_to = str_2_date(date_from), str_2_date(date_to)
    data_df_list = []
    for date_from_sub, date_to_sub in _split_date_range_by_year(date_from, date_to):
        # NOTE(review): year + 1 looks intentional (forecast year); confirm with the data definition
        params_sub = params % {'year': (date_from_sub.year + 1)}
        try:
            data_df = invoker.wsd(wind_code, wind_indictor_str, date_from_sub,
                                  date_to_sub, params_sub)
        except APIError as exp:
            logger.exception("%s %s [%s ~ %s] %s 执行异常", wind_code,
                             wind_indictor_str, date_2_str(date_from_sub),
                             date_2_str(date_to_sub), params_sub)
            # .get instead of .setdefault: don't mutate the exception's payload while inspecting it
            if exp.ret_dic.get('error_code', 0) in (
                    -40520007,  # no data available
                    -40521009,  # data decoding failed; check input params, e.g. month-end dates of short months / February
            ):
                continue
            raise
        if data_df is None:
            logger.warning('%s %s [%s ~ %s] has no data', wind_code,
                           wind_indictor_str, date_2_str(date_from_sub),
                           date_2_str(date_to_sub))
            continue
        data_df.dropna(inplace=True)
        if data_df.shape[0] == 0:
            continue
        data_df_list.append(data_df)

    # merge the yearly pieces
    return pd.concat(data_df_list) if data_df_list else None
예제 #5
0
 def __init__(self, table_name, dtype, statement):
     """
     Store the target table description together with a few shared defaults.

     :param table_name: name of the table this object works on
     :param dtype: column type description for the table (presumably a name -> SQL type mapping; confirm)
     :param statement: statement associated with the table (semantics not visible here; confirm)
     """
     self.table_name = table_name
     self.dtype = dtype
     self.statement = statement
     # NOTE(review): meaning of the step size and base date is not visible in this snippet
     self.loop_step = 20
     self.BASE_DATE = str_2_date('1989-12-01')
     self.logger = logging.getLogger(__name__)
def get_sectorconstituent(index_code, index_name, target_date) -> pd.DataFrame:
    """
    Fetch index constituents and their weights from Wind for one date.

    :param index_code: Wind index code
    :param index_name: index name, copied into the ``index_name`` column
    :param target_date: query date (date, or anything ``str_2_date`` accepts)
    :return: constituent DataFrame with ``date``/``sec_name``/``i_weight`` renamed to
        ``trade_date``/``stock_name``/``weight`` plus ``index_code``/``index_name`` columns,
        or None if Wind returned no rows for ``target_date``
    """
    # normalise first: comparing str_2_date(x) against a *string* target_date would never match
    target_date = str_2_date(target_date)
    target_date_str = date_2_str(target_date)
    logger.info('获取 %s %s %s 板块信息', index_code, index_name, target_date)
    sec_df = invoker.wset(
        "indexconstituent",
        "date=%s;windcode=%s" % (target_date_str, index_code))
    if sec_df is None or sec_df.shape[0] == 0:
        return None
    # Wind sometimes returns rows whose date differs from target_date: keep only matching rows.
    # .copy() so the column assignments below act on an independent frame, not a filtered slice.
    sec_df = sec_df[sec_df['date'].apply(lambda x: str_2_date(x) == target_date)].copy()
    if sec_df.shape[0] == 0:
        return None
    sec_df["index_code"] = index_code
    sec_df["index_name"] = index_name
    sec_df = sec_df.rename(columns={
        'date': 'trade_date',
        'sec_name': 'stock_name',
        'i_weight': 'weight',
    })
    return sec_df
예제 #7
0
    def on_min1(self, md_df, context):
        """
        Minute-bar callback: retrain if due, predict a signal, then adjust the position.

        ``predict_latest`` marks are interpreted as 0 = hold, 1 = buy, 2 = sell.
        On buy, every short position is closed and ``self.unit`` is opened long unless a
        long position already exists. Sell is the mirror image.

        :param md_df: market data with at least ``trade_date``, ``instrument_type`` and ``close`` columns
        :param context: strategy context; ``context[ContextKey.instrument_id_list][0]`` is the traded instrument
        """
        if self.do_nothing_on_min_bar:  # debugging switch only
            return

        # index by trade_date (as DatetimeIndex) and drop instrument_type
        indexed_df = md_df.set_index('trade_date').drop('instrument_type', axis=1)
        indexed_df.index = pd.DatetimeIndex(indexed_df.index)
        # latest trade date in this batch
        trade_date = str_2_date(indexed_df.index[-1])
        days_after_last_train = (trade_date - self.trade_date_last_train).days
        need_retrain = (self.retrain_period is not None
                        and 0 < self.retrain_period < days_after_last_train)
        if need_retrain:
            self.logger.info('当前日期 %s 距离上一次训练 %s 已经过去 %d 天,重新训练',
                             trade_date, self.trade_date_last_train, days_after_last_train)
            factor_df = self.load_train_test(indexed_df, rebuild_model=True,
                                             enable_load_model=self.enable_load_model_if_exist)
        else:
            factor_df = get_factor(indexed_df, ohlcav_col_name_list=self.ohlcav_col_name_list,
                                   trade_date_series=self.trade_date_series,
                                   delivery_date_series=self.delivery_date_series)

        # predict
        pred_mark = self.predict_latest(factor_df)
        is_holding, is_buy, is_sell = pred_mark == 0, pred_mark == 1, pred_mark == 2
        close = md_df['close'].iloc[-1]
        instrument_id = context[ContextKey.instrument_id_list][0]

        def follow_signal(target_direction, opposite_direction, close_opposite, open_target):
            # close every opposite-side position; open the target side only if it is not held yet
            position_dic = self.get_position(instrument_id)
            already_held = False
            if position_dic is not None:
                for _, pos_info in position_dic.items():
                    pos_direction = pos_info.direction
                    if pos_direction == opposite_direction:
                        close_opposite(instrument_id, close, pos_info.position)
                    elif pos_direction == target_direction:
                        already_held = True
            if already_held:
                self.logger.debug("%s %s     %.2f holding", self.trade_agent.curr_timestamp, instrument_id, close)
            else:
                open_target(instrument_id, close, self.unit)

        if is_buy:
            follow_signal(Direction.Long, Direction.Short, self.close_short, self.open_long)

        if is_sell:
            follow_signal(Direction.Short, Direction.Long, self.close_long, self.open_short)

        if is_holding:
            self.logger.debug("%s %s * * %.2f holding", self.trade_agent.curr_timestamp, instrument_id, close)
예제 #8
0
    def get_date_file_path_pair_list(self):
        """
        Scan saved-model folders and return ``[trade_date, model_path, predict_test_random_state]`` entries.

        Each sub-folder of ``self.model_folder_path`` whose name parses with ``str_2_date``
        is treated as the model saved for that date. The first file in it matching
        ``model_<int>_<int>.tfl`` yields the model path (folder + matched prefix,
        e.g. ``.../2012-12-31/model_-54_51.tfl``). The random state is parsed from the
        ``[...]`` part of the first ``self.tensorboard_dir`` entry whose name starts with the
        folder name. It is None when no such entry exists, it cannot be parsed, or the
        tensorboard directory cannot be listed. Only the first folder per date is kept.

        Directory layout:
        tf_saves_2019-06-05_16_21_39
          *   model_tfls
          *       *   2012-12-31
          *       *       *   checkpoint
          *       *       *   model_-54_51.tfl.data-00000-of-00001
          *       *       *   model_-54_51.tfl.index
          *       *       *   model_-54_51.tfl.meta
          *       *   2013-02-28
          *       *       *   checkpoint
          *       *       *   model_-54_51.tfl.data-00000-of-00001
          *       *       *   model_-54_51.tfl.index
          *       *       *   model_-54_51.tfl.meta
          *   tensorboard_logs
          *       *   2012-12-31_496[1]_20190605_184316
          *       *       *   events.out.tfevents.1559731396.mg-ubuntu64
          *       *   2013-02-28_496[1]_20190605_184716
          *       *       *   events.out.tfevents.1559731396.mg-ubuntu64
        :return: list of ``[date, model_path, int | None]`` in ``os.listdir`` order (unsorted)
        """
        # "model_-54_51.tfl.data-00000-of-00001", ".index" and ".meta" all match "model_-54_51.tfl"
        pattern = re.compile(r'model_-?\d+_\d+\.tfl')
        # list the tensorboard logs once; a missing directory only means "no random state"
        try:
            log_folder_names = os.listdir(self.tensorboard_dir)
        except OSError:
            self.logger.warning('cannot list tensorboard dir %s', self.tensorboard_dir)
            log_folder_names = []

        date_file_path_pair_list, seen_keys = [], set()
        for folder_name in os.listdir(self.model_folder_path):
            folder_path = os.path.join(self.model_folder_path, folder_name)
            if not os.path.isdir(folder_path):
                continue
            # the folder name is the trade_date_last_train; non-date folders are not model folders
            try:
                key = str_2_date(folder_name)
            except Exception:
                self.logger.debug('skip non-date folder %s', folder_path)
                continue
            if key is None or key in seen_keys:
                continue
            try:
                file_names = os.listdir(folder_path)
            except OSError:
                self.logger.warning('cannot list model folder %s', folder_path)
                continue
            model_name = next(
                (m.group() for m in map(pattern.search, file_names) if m is not None), None)
            if model_name is None:
                continue
            seen_keys.add(key)
            file_path = os.path.join(folder_path, model_name)

            # tensorboard run ids look like "<date>_<n>[<random_state>]_<timestamp>"
            predict_test_random_state = None
            for log_folder_name in log_folder_names:
                if log_folder_name.startswith(folder_name):
                    try:
                        predict_test_random_state = int(log_folder_name.split('[')[1].split(']')[0])
                    except (IndexError, ValueError):
                        self.logger.warning('cannot parse random state from %s', log_folder_name)
                    break

            date_file_path_pair_list.append([key, file_path, predict_test_random_state])

        return date_file_path_pair_list
예제 #9
0
    def load_train_test(self, indexed_df, enable_load_model, rebuild_model=False, enable_train_if_load_not_suss=True,
                        enable_train_even_load_succ=False, enable_test=False):
        """
        Load and/or train the model as of the last date of ``indexed_df`` and return its factors.

        Training happens when ``enable_train_even_load_succ`` is set, or when
        ``enable_train_if_load_not_suss`` is set and no saved model was loaded. Training is
        retried with a rebuilt model and seed ``attempt`` while train accuracy exceeds
        ``self.over_fitting_train_acc`` or validation accuracy is below
        ``self.validation_accuracy_base_line``. NOTE(review): there is no retry cap, so the
        loop never ends if those thresholds can never be met. Otherwise the
        loaded model's accuracy is measured with ``valid_model_acc``.

        :param indexed_df: market data indexed by trade date
        :param enable_load_model: try ``load_model_if_exist`` first
        :param rebuild_model: rebuild the model before anything else
        :param enable_train_if_load_not_suss: train when loading did not succeed
        :param enable_train_even_load_succ: train even when a model was loaded
        :param enable_test: also run ``predict_test`` (internal testing only)
        :return: factor DataFrame (adjustment factor 1) used for training/validation
        """
        if rebuild_model:
            self.get_model(rebuild_model=True)

        trade_date = str_2_date(indexed_df.index[-1])
        is_load = self.load_model_if_exist(trade_date) if enable_load_model else False

        should_train = enable_train_even_load_succ or (enable_train_if_load_not_suss and not is_load)
        if not should_train:
            factor_df = get_factor(indexed_df, ohlcav_col_name_list=self.ohlcav_col_name_list,
                                   trade_date_series=self.trade_date_series,
                                   delivery_date_series=self.delivery_date_series)
            train_acc, val_acc = self.valid_model_acc(factor_df)
        else:
            factor_df_dic = get_factor(indexed_df, ohlcav_col_name_list=self.ohlcav_col_name_list,
                                       trade_date_series=self.trade_date_series,
                                       delivery_date_series=self.delivery_date_series, do_multiple_factors=True)
            factor_df = factor_df_dic[1]
            attempt = 0
            while True:
                attempt += 1
                if attempt > 1:
                    # each retry starts from a freshly built model
                    self.get_model(rebuild_model=True)
                train_acc, val_acc = self.train(factor_df_dic, predict_test_random_state=attempt)
                if self.over_fitting_train_acc is not None and train_acc > self.over_fitting_train_acc:
                    self.logger.warning('第 %d 次训练,训练集精度 train_acc=%.2f%% 过高,可能存在过拟合,重新采样训练',
                                        attempt, train_acc * 100)
                    continue
                if self.validation_accuracy_base_line is not None and val_acc < self.validation_accuracy_base_line:
                    self.logger.warning('第 %d 次训练,训练结果不及预期,重新采样训练', attempt)
                    continue
                break

            self.save_model(trade_date)
            self.trade_date_last_train = trade_date

        self.trade_date_acc_list[trade_date] = [train_acc, val_acc]

        # valid_model_acc has superseded predict_test; the latter is for internal testing only
        if enable_test:
            self.predict_test(factor_df)

        return factor_df
예제 #10
0
def _test_bunch_insert_sqlite():
    """
    Check bunch_insert_sqlite: insert two overlapping batches, then verify the stored values.

    The second batch overlaps the first on some (ts_code, trade_date) pairs. The checks
    expect the value from the later batch to be stored for those pairs.
    """
    mysql_table_name = 'test_only'
    file_name = TABLE_NAME_SQLITE_FILE_NAME_DIC[mysql_table_name]
    file_path = get_sqlite_file_path(file_name)
    if os.path.exists(file_path):
        os.remove(file_path)

    df = pd.DataFrame({
        'ts_code': ['600010.SH', '600010.SH', '600010.SH', '000010.SZ', '000010.SZ'],
        'trade_date': [str_2_date(_) for _ in ['2018-1-3', '2018-1-4', '2018-1-5', '2018-1-3', '2018-1-4']],
        'close': [111, 222, 333, 444, 555],
    })
    primary_keys = ['Date']
    bunch_insert_sqlite(df, mysql_table_name=mysql_table_name,
                        table_name_key='ts_code', primary_keys=primary_keys)

    df = pd.DataFrame({
        'ts_code': ['600010.SH', '600010.SH', '000010.SZ', '000010.SZ'],
        'trade_date': [str_2_date(_) for _ in ['2018-1-5', '2018-1-6', '2018-1-4', '2018-1-5']],
        'close': [555, 666, 44400, 55500],
    })
    bunch_insert_sqlite(df, mysql_table_name=mysql_table_name,
                        table_name_key='ts_code', primary_keys=primary_keys)

    # NOTE(review): the frames carry `close`/`trade_date`, yet the queries read `adj_factor` and filter
    # on `Date`; confirm bunch_insert_sqlite performs that column mapping for this table
    expected_list = [
        ('SH600010', '2018-01-03', 111),
        ('SH600010', '2018-01-05', 555),
        ('SH600010', '2018-01-06', 666),
        ('SZ000010', '2018-01-05', 55500),
    ]
    with with_sqlite_conn(file_name=file_name) as conn:
        for table_name, date_str, expected in expected_list:
            row = conn.execute(f'select adj_factor from {table_name} where Date = ?', [date_str]).fetchone()
            # `assert value, 111` only checks truthiness; compare with the expected value explicitly
            assert row is not None and row[0] == expected, \
                f'{table_name} {date_str}: expected {expected}, got {row}'

    logger.info('检查完成')
예제 #11
0
def _test_fill_season_data():
    """
    Test fill_season_data.

    Input data
                        code report_date  revenue
    report_date
    2000-12-31   000001.XSHE  2000-12-31    400.0
    2001-03-31   000001.XSHE  2001-03-31      NaN
    2001-06-30   000001.XSHE  2001-06-30    600.0
    2001-09-30   000001.XSHE  2001-09-30      NaN
    2001-12-31   000001.XSHE  2001-12-31   1400.0
    2002-12-31   000001.XSHE  2002-12-31   1600.0

    Expected result
                        code report_date  revenue  revenue_season
    report_date
    2000-12-31   000001.XSHE  2000-12-31    400.0           100.0
    2001-03-31   000001.XSHE  2001-03-31    100.0           100.0
    2001-06-30   000001.XSHE  2001-06-30    600.0           500.0
    2001-09-30   000001.XSHE  2001-09-30   1500.0           500.0
    2001-12-31   000001.XSHE  2001-12-31   1400.0          -100.0
    2002-12-31   000001.XSHE  2002-12-31   1600.0           400.0

    Only the filled 2001-03-31 revenue (a quarter of the previous year's 400) is asserted.
    :return:
    """
    label = 'revenue'
    df = pd.DataFrame({
        'report_date': [
            str_2_date('2000-12-31'),
            str_2_date('2001-3-31'),
            str_2_date('2001-6-30'),
            str_2_date('2001-9-30'),
            str_2_date('2001-12-31'),
            str_2_date('2002-12-31')
        ],
        label: [400, np.nan, 600, np.nan, 1400, 1600],
    })
    df['code'] = '000001.XSHE'
    df = df[['code', 'report_date', label]]
    df.set_index('report_date', drop=False, inplace=True)
    print(df)
    df_new = fill_season_data(df, label)
    print(df_new)
    # check the function's *result*; the input df still holds NaN at this date
    actual = df_new.loc[str_2_date('2001-3-31'), label]
    assert actual == 100, \
        f"{label} {str_2_date('2001-3-31')} 应该等于前一年的 1/4,当前 {actual}"
예제 #12
0
    def train(self, factor_df_dic: dict, predict_test_random_state):
        """
        Train ``self.model`` on all factor frames and validate on the held-out part of frame ``1``.

        Each frame in ``factor_df_dic`` is split 80/20 with ``train_test_split``. The training
        parts of all frames are stacked into one training set. Only the validation split of the
        frame under key ``1`` is used for validation. Up to ``self.max_loop_4_futher_train``
        rounds are run: the first with ``self.n_epoch`` epochs, later ones with
        ``self.n_epoch // max_loop``. A round stops early on suspected over-fitting
        (train accuracy above ``self.over_fitting_train_acc``) or once validation accuracy
        exceeds ``self.validation_accuracy_base_line``.

        :param factor_df_dic: mapping adj_factor -> factor DataFrame; must contain key ``1``
        :param predict_test_random_state: split seed, used only when ``self.predict_test_random_state`` is None
        :return: ``(train_acc, val_acc)`` of the last evaluated round (``(0, 0)`` if no round ran)
        """
        import tflearn
        base_factor_df = factor_df_dic[1]

        trade_date_from_str, trade_date_to_str = date_2_str(base_factor_df.index[0]), date_2_str(base_factor_df.index[-1])
        # a seed already stored on the instance (e.g. set by load_model_if_exist) takes precedence
        if self.predict_test_random_state is None:
            random_state = predict_test_random_state
        else:
            random_state = self.predict_test_random_state

        # train on every adjusted variant; validate only on the held-out part of the adj_factor == 1 frame
        arr_list, xs_validation, ys_validation = [], None, None
        for adj_factor, adj_factor_df in factor_df_dic.items():
            xs, ys, _ = self.get_x_y(adj_factor_df)
            xs_train_tmp, xs_validation_tmp, ys_train_tmp, ys_validation_tmp = train_test_split(
                xs, ys, test_size=0.2, random_state=random_state)
            arr_list.append([xs_train_tmp, ys_train_tmp])
            if adj_factor == 1:
                xs_validation, ys_validation = xs_validation_tmp, ys_validation_tmp

        self.xs_train = xs_train = np.vstack([_[0] for _ in arr_list])
        self.ys_train = ys_train = np.vstack([_[1] for _ in arr_list])

        sess = self.get_session(renew=True)
        train_acc, val_acc = 0, 0
        with sess.as_default():
            self.logger.debug('[%d], xs_train %s, ys_train %s, xs_validation %s, ys_validation %s, [%s, %s]',
                              random_state, xs_train.shape, ys_train.shape, xs_validation.shape, ys_validation.shape,
                              trade_date_from_str, trade_date_to_str)
            max_loop = self.max_loop_4_futher_train
            for num in range(max_loop):
                n_epoch = self.n_epoch if num == 0 else self.n_epoch // max_loop

                self.logger.info('[%d]第 %d/%d 轮训练,开始 [%s, %s] n_epoch=%d', random_state, num + 1, max_loop,
                                 trade_date_from_str, trade_date_to_str, n_epoch)
                # record the seed actually used for the split: get_date_file_path_pair_list parses the
                # "[...]" part of this run_id back as predict_test_random_state, so logging the caller's
                # seed here would restore a different train/validation split later
                run_id = f'{trade_date_to_str}_{xs_train.shape[0]}[{random_state}]' \
                         f'_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
                tflearn.is_training(True)
                self.model.fit(
                    xs_train, ys_train, validation_set=(xs_validation, ys_validation),
                    show_metric=True, batch_size=self.batch_size, n_epoch=n_epoch,
                    run_id=run_id)
                tflearn.is_training(False)

                train_acc = self.model.evaluate(xs_train, ys_train, batch_size=self.batch_size)[0]
                val_acc = self.model.evaluate(xs_validation, ys_validation, batch_size=self.batch_size)[0]
                self.logger.info("[%d]第 %d/%d 轮训练,[%s - %s],训练集准确率(train_acc):%.2f%%, 样本外准确率(val_acc): %.2f%%",
                                 random_state, num + 1, max_loop, trade_date_from_str, trade_date_to_str,
                                 train_acc * 100, val_acc * 100)
                if self.over_fitting_train_acc is not None and train_acc > self.over_fitting_train_acc:
                    self.logger.warning('[%d]第 %d/%d 轮训练,训练集精度 %.2f%% > %.2f%% 可能存在过拟合 [%s, %s]',
                                        random_state, num + 1, max_loop, train_acc * 100,
                                        self.over_fitting_train_acc * 100,
                                        trade_date_from_str, trade_date_to_str)
                    break
                if self.validation_accuracy_base_line is None:
                    break
                if val_acc > self.validation_accuracy_base_line:
                    break
                if num < max_loop - 1:
                    self.logger.warning('[%d]第 %d/%d 轮训练,[%s - %s],样本外训练准确率 %.2f%% < %.0f%%,继续训练',
                                        random_state, num + 1, max_loop, trade_date_from_str, trade_date_to_str,
                                        val_acc * 100, self.validation_accuracy_base_line * 100)

        self.trade_date_last_train = str_2_date(trade_date_to_str)
        return train_acc, val_acc
예제 #13
0
def import_stock_fin():
    """
    Download financial report data through the pytdx interface and upsert it into
    the ``pytdx_stock_fin`` table.

    When the table already exists, report archives whose file name carries a date
    (``gpcw<YYYYMMDD>.zip``) that is not later than the latest ``report_date`` stored
    in the table are skipped. Archives whose name cannot be parsed are logged and
    downloaded anyway. If the table is created by this call, a primary key
    ``(ts_code, report_date)`` is added at the end, even if the download loop fails.
    :return:
    """
    table_name = 'pytdx_stock_fin'
    has_table = engine_md.has_table(table_name)
    report_date_latest = None
    if has_table:
        with with_db_session(engine_md) as session:
            sql_str = "select max(report_date) from {table_name}".format(table_name=table_name)
            report_date_latest = session.execute(sql_str).scalar()
    # Meaning of the financial indicator names, see issue #163
    # https://github.com/QUANTAXIS/QUANTAXIS/blob/master/QUANTAXIS/QAData/financial_mean.py
    # The 3-digit prefix N of each key is mapped below to the crawler column 'colN'.
    financial_dict = {
        # 1. Per-share indicators
        '001基本每股收益': 'EPS',
        '002扣除非经常性损益每股收益': 'deductEPS',
        '003每股未分配利润': 'undistributedProfitPerShare',
        '004每股净资产': 'netAssetsPerShare',
        '005每股资本公积金': 'capitalReservePerShare',
        '006净资产收益率': 'ROE',
        '007每股经营现金流量': 'operatingCashFlowPerShare',
        # 2. Balance sheet
        # 2.1 Assets
        # 2.1.1 Current assets
        '008货币资金': 'moneyFunds',
        '009交易性金融资产': 'tradingFinancialAssets',
        '010应收票据': 'billsReceivables',
        '011应收账款': 'accountsReceivables',
        '012预付款项': 'prepayments',
        '013其他应收款': 'otherReceivables',
        '014应收关联公司款': 'interCompanyReceivables',
        '015应收利息': 'interestReceivables',
        '016应收股利': 'dividendsReceivables',
        '017存货': 'inventory',
        '018其中:消耗性生物资产': 'expendableBiologicalAssets',
        '019一年内到期的非流动资产': 'noncurrentAssetsDueWithinOneYear',
        '020其他流动资产': 'otherLiquidAssets',
        '021流动资产合计': 'totalLiquidAssets',
        # 2.1.2 Non-current assets
        '022可供出售金融资产': 'availableForSaleSecurities',
        '023持有至到期投资': 'heldToMaturityInvestments',
        '024长期应收款': 'longTermReceivables',
        '025长期股权投资': 'longTermEquityInvestment',
        '026投资性房地产': 'investmentRealEstate',
        '027固定资产': 'fixedAssets',
        '028在建工程': 'constructionInProgress',
        '029工程物资': 'engineerMaterial',
        '030固定资产清理': 'fixedAssetsCleanUp',
        '031生产性生物资产': 'productiveBiologicalAssets',
        '032油气资产': 'oilAndGasAssets',
        '033无形资产': 'intangibleAssets',
        '034开发支出': 'developmentExpenditure',
        '035商誉': 'goodwill',
        '036长期待摊费用': 'longTermDeferredExpenses',
        '037递延所得税资产': 'deferredIncomeTaxAssets',
        '038其他非流动资产': 'otherNonCurrentAssets',
        '039非流动资产合计': 'totalNonCurrentAssets',
        '040资产总计': 'totalAssets',
        # 2.2 Liabilities
        # 2.2.1 Current liabilities
        '041短期借款': 'shortTermLoan',
        '042交易性金融负债': 'tradingFinancialLiabilities',
        '043应付票据': 'billsPayable',
        '044应付账款': 'accountsPayable',
        '045预收款项': 'advancedReceivable',
        '046应付职工薪酬': 'employeesPayable',
        '047应交税费': 'taxPayable',
        '048应付利息': 'interestPayable',
        '049应付股利': 'dividendPayable',
        '050其他应付款': 'otherPayable',
        '051应付关联公司款': 'interCompanyPayable',
        '052一年内到期的非流动负债': 'noncurrentLiabilitiesDueWithinOneYear',
        '053其他流动负债': 'otherCurrentLiabilities',
        '054流动负债合计': 'totalCurrentLiabilities',
        # 2.2.2 Non-current liabilities
        '055长期借款': 'longTermLoans',
        '056应付债券': 'bondsPayable',
        '057长期应付款': 'longTermPayable',
        '058专项应付款': 'specialPayable',
        '059预计负债': 'estimatedLiabilities',
        '060递延所得税负债': 'defferredIncomeTaxLiabilities',
        '061其他非流动负债': 'otherNonCurrentLiabilities',
        '062非流动负债合计': 'totalNonCurrentLiabilities',
        '063负债合计': 'totalLiabilities',
        # 2.3 Owners' equity
        '064实收资本(或股本)': 'totalShare',
        '065资本公积': 'capitalReserve',
        '066盈余公积': 'surplusReserve',
        '067减:库存股': 'treasuryStock',
        '068未分配利润': 'undistributedProfits',
        '069少数股东权益': 'minorityEquity',
        '070外币报表折算价差': 'foreignCurrencyReportTranslationSpread',
        '071非正常经营项目收益调整': 'abnormalBusinessProjectEarningsAdjustment',
        '072所有者权益(或股东权益)合计': 'totalOwnersEquity',
        '073负债和所有者(或股东权益)合计': 'totalLiabilitiesAndOwnersEquity',
        # 3. Income statement
        '074其中:营业收入': 'operatingRevenue',
        '075其中:营业成本': 'operatingCosts',
        '076营业税金及附加': 'taxAndSurcharges',
        '077销售费用': 'salesCosts',
        '078管理费用': 'managementCosts',
        '079堪探费用': 'explorationCosts',
        '080财务费用': 'financialCosts',
        '081资产减值损失': 'assestsDevaluation',
        '082加:公允价值变动净收益': 'profitAndLossFromFairValueChanges',
        '083投资收益': 'investmentIncome',
        '084其中:对联营企业和合营企业的投资收益': 'investmentIncomeFromAffiliatedBusinessAndCooperativeEnterprise',
        '085影响营业利润的其他科目': 'otherSubjectsAffectingOperatingProfit',
        '086三、营业利润': 'operatingProfit',
        '087加:补贴收入': 'subsidyIncome',
        '088营业外收入': 'nonOperatingIncome',
        '089减:营业外支出': 'nonOperatingExpenses',
        '090其中:非流动资产处置净损失': 'netLossFromDisposalOfNonCurrentAssets',
        '091加:影响利润总额的其他科目': 'otherSubjectsAffectTotalProfit',
        '092四、利润总额': 'totalProfit',
        '093减:所得税': 'incomeTax',
        '094加:影响净利润的其他科目': 'otherSubjectsAffectNetProfit',
        '095五、净利润': 'netProfit',
        '096归属于母公司所有者的净利润': 'netProfitsBelongToParentCompanyOwner',
        '097少数股东损益': 'minorityProfitAndLoss',

        # 4. Cash flow statement
        # 4.1 Operating activities
        '098销售商品、提供劳务收到的现金': 'cashFromGoodsSalesorOrRenderingOfServices',
        '099收到的税费返还': 'refundOfTaxAndFeeReceived',
        '100收到其他与经营活动有关的现金': 'otherCashRelatedBusinessActivitiesReceived',
        '101经营活动现金流入小计': 'cashInflowsFromOperatingActivities',
        '102购买商品、接受劳务支付的现金': 'buyingGoodsReceivingCashPaidForLabor',
        '103支付给职工以及为职工支付的现金': 'paymentToEmployeesAndCashPaidForEmployees',
        '104支付的各项税费': 'paymentsOfVariousTaxes',
        '105支付其他与经营活动有关的现金': 'paymentOfOtherCashRelatedToBusinessActivities',
        '106经营活动现金流出小计': 'cashOutflowsFromOperatingActivities',
        '107经营活动产生的现金流量净额': 'netCashFlowsFromOperatingActivities',
        # 4.2 Investing activities
        '108收回投资收到的现金': 'cashReceivedFromInvestmentReceived',
        '109取得投资收益收到的现金': 'cashReceivedFromInvestmentIncome',
        '110处置固定资产、无形资产和其他长期资产收回的现金净额': 'disposalOfNetCashForRecoveryOfFixedAssets',
        '111处置子公司及其他营业单位收到的现金净额': 'disposalOfNetCashReceivedFromSubsidiariesAndOtherBusinessUnits',
        '112收到其他与投资活动有关的现金': 'otherCashReceivedRelatingToInvestingActivities',
        '113投资活动现金流入小计': 'cashinFlowsFromInvestmentActivities',
        '114购建固定资产、无形资产和其他长期资产支付的现金': 'cashForThePurchaseConstructionPaymentOfFixedAssets',
        '115投资支付的现金': 'cashInvestment',
        '116取得子公司及其他营业单位支付的现金净额': 'acquisitionOfNetCashPaidBySubsidiariesAndOtherBusinessUnits',
        '117支付其他与投资活动有关的现金': 'otherCashPaidRelatingToInvestingActivities',
        '118投资活动现金流出小计': 'cashOutflowsFromInvestmentActivities',
        '119投资活动产生的现金流量净额': 'netCashFlowsFromInvestingActivities',
        # 4.3 Financing activities
        '120吸收投资收到的现金': 'cashReceivedFromInvestors',
        '121取得借款收到的现金': 'cashFromBorrowings',
        '122收到其他与筹资活动有关的现金': 'otherCashReceivedRelatingToFinancingActivities',
        '123筹资活动现金流入小计': 'cashInflowsFromFinancingActivities',
        '124偿还债务支付的现金': 'cashPaymentsOfAmountBorrowed',
        '125分配股利、利润或偿付利息支付的现金': 'cashPaymentsForDistrbutionOfDividendsOrProfits',
        '126支付其他与筹资活动有关的现金': 'otherCashPaymentRelatingToFinancingActivities',
        '127筹资活动现金流出小计': 'cashOutflowsFromFinancingActivities',
        '128筹资活动产生的现金流量净额': 'netCashFlowsFromFinancingActivities',
        # 4.4 Effect of exchange rate changes
        '129四、汇率变动对现金的影响': 'effectOfForeignExchangRateChangesOnCash',
        '130四(2)、其他原因对现金的影响': 'effectOfOtherReasonOnCash',
        # 4.5 Net increase in cash and cash equivalents
        '131五、现金及现金等价物净增加额': 'netIncreaseInCashAndCashEquivalents',
        '132期初现金及现金等价物余额': 'initialCashAndCashEquivalentsBalance',
        # 4.6 Ending balance of cash and cash equivalents
        '133期末现金及现金等价物余额': 'theFinalCashAndCashEquivalentsBalance',
        # 4.x Supplementary schedule:
        # cash flow supplementary items (indirect method)
        # 4.x.1 Reconcile net profit to cash flow from operating activities
        '134净利润': 'netProfitFromOperatingActivities',
        '135资产减值准备': 'provisionForAssetsLosses',
        '136固定资产折旧、油气资产折耗、生产性生物资产折旧': 'depreciationForFixedAssets',
        '137无形资产摊销': 'amortizationOfIntangibleAssets',
        '138长期待摊费用摊销': 'amortizationOfLong_termDeferredExpenses',
        '139处置固定资产、无形资产和其他长期资产的损失': 'lossOfDisposingFixedAssetsIntangibleAssetsAndOtherLongTermAssets',
        '140固定资产报废损失': 'scrapLossOfFixedAssets',
        '141公允价值变动损失': 'lossFromFairValueChange',
        '142财务费用': 'financialExpenses',
        '143投资损失': 'investmentLosses',
        '144递延所得税资产减少': 'decreaseOfDeferredTaxAssets',
        '145递延所得税负债增加': 'increaseOfDeferredTaxLiabilities',
        '146存货的减少': 'decreaseOfInventory',
        '147经营性应收项目的减少': 'decreaseOfOperationReceivables',
        '148经营性应付项目的增加': 'increaseOfOperationPayables',
        '149其他': 'others',
        '150经营活动产生的现金流量净额2': 'netCashFromOperatingActivities2',
        # 4.x.2 Investing and financing activities not involving cash
        '151债务转为资本': 'debtConvertedToCSapital',
        '152一年内到期的可转换公司债券': 'convertibleBondMaturityWithinOneYear',
        '153融资租入固定资产': 'leaseholdImprovements',
        # 4.x.3 Net increase of cash and cash equivalents
        '154现金的期末余额': 'cashEndingBal',
        '155现金的期初余额': 'cashBeginingBal',
        '156现金等价物的期末余额': 'cashEquivalentsEndingBal',
        '157现金等价物的期初余额': 'cashEquivalentsBeginningBal',
        '158现金及现金等价物净增加额': 'netIncreaseOfCashAndCashEquivalents',
        # 5. Solvency analysis
        '159流动比率': 'currentRatio',  # current assets / current liabilities
        '160速动比率': 'acidTestRatio',  # (current assets - inventory) / current liabilities
        '161现金比率(%)': 'cashRatio',  # (monetary funds + marketable securities) / current liabilities
        '162利息保障倍数': 'interestCoverageRatio',  # (total profit + interest part of financial expenses) / interest expense
        '163非流动负债比率(%)': 'noncurrentLiabilitiesRatio',
        '164流动负债比率(%)': 'currentLiabilitiesRatio',
        '165现金到期债务比率(%)': 'cashDebtRatio',  # net operating cash inflow / (long-term debt due this period + notes payable this period)
        '166有形资产净值债务率(%)': 'debtToTangibleAssetsRatio',
        '167权益乘数(%)': 'equityMultiplier',  # total assets / total shareholders' equity
        '168股东的权益/负债合计(%)': 'equityDebtRatio',  # equity-to-debt ratio
        '169有形资产/负债合计(%)': 'tangibleAssetDebtRatio',  # tangible-assets-to-debt ratio
        '170经营活动产生的现金流量净额/负债合计(%)': 'netCashFlowsFromOperatingActivitiesDebtRatio',
        '171EBITDA/负债合计(%)': 'EBITDA_Liabilities',
        # 6. Operating efficiency analysis
        # sales revenue / average receivables = sales revenue / (0.5 x (opening + closing receivables))
        '172应收帐款周转率': 'turnoverRatioOfReceivable',
        '173存货周转率': 'turnoverRatioOfInventory',
        # (inventory days + receivable days - payable days + prepayment days - advance-receipt days) / 365
        '174运营资金周转率': 'turnoverRatioOfOperatingAssets',
        '175总资产周转率': 'turnoverRatioOfTotalAssets',
        '176固定资产周转率': 'turnoverRatioOfFixedAssets',  # sales revenue / net fixed assets
        '177应收帐款周转天数': 'daysSalesOutstanding',  # time from acquiring a receivable until it is collected as cash
        '178存货周转天数': 'daysSalesOfInventory',  # days from acquiring inventory until it is consumed or sold
        '179流动资产周转率': 'turnoverRatioOfCurrentAssets',  # main business revenue / average total current assets
        '180流动资产周转天数': 'daysSalesofCurrentAssets',
        '181总资产周转天数': 'daysSalesofTotalAssets',
        '182股东权益周转率': 'equityTurnover',  # sales revenue / average shareholders' equity
        # 7. Growth analysis
        '183营业收入增长率(%)': 'operatingIncomeGrowth',
        '184净利润增长率(%)': 'netProfitGrowthRate',  # NPGR; net profit = total profit - income tax
        '185净资产增长率(%)': 'netAssetsGrowthRate',
        '186固定资产增长率(%)': 'fixedAssetsGrowthRate',
        '187总资产增长率(%)': 'totalAssetsGrowthRate',
        '188投资收益增长率(%)': 'investmentIncomeGrowthRate',
        '189营业利润增长率(%)': 'operatingProfitGrowthRate',
        '190暂无': 'None1',
        '191暂无': 'None2',
        '192暂无': 'None3',
        # 8. Profitability analysis
        '193成本费用利润率(%)': 'rateOfReturnOnCost',
        '194营业利润率': 'rateOfReturnOnOperatingProfit',
        '195营业税金率': 'rateOfReturnOnBusinessTax',
        '196营业成本率': 'rateOfReturnOnOperatingCost',
        '197净资产收益率': 'rateOfReturnOnCommonStockholdersEquity',
        '198投资收益率': 'rateOfReturnOnInvestmentIncome',
        '199销售净利率(%)': 'rateOfReturnOnNetSalesProfit',
        '200总资产报酬率': 'rateOfReturnOnTotalAssets',
        '201净利润率': 'netProfitMargin',
        '202销售毛利率(%)': 'rateOfReturnOnGrossProfitFromSales',
        '203三费比重': 'threeFeeProportion',
        '204管理费用率': 'ratioOfChargingExpense',
        '205财务费用率': 'ratioOfFinancialExpense',
        '206扣除非经常性损益后的净利润': 'netProfitAfterExtraordinaryGainsAndLosses',
        '207息税前利润(EBIT)': 'EBIT',
        '208息税折旧摊销前利润(EBITDA)': 'EBITDA',
        '209EBITDA/营业总收入(%)': 'EBITDA_GrossRevenueRate',
        # 9. Capital structure analysis
        '210资产负债率(%)': 'assetsLiabilitiesRatio',
        '211流动资产比率': 'currentAssetsRatio',  # ending current assets divided by owners' equity
        '212货币资金比率': 'monetaryFundRatio',
        '213存货比率': 'inventoryRatio',
        '214固定资产比率': 'fixedAssetsRatio',
        '215负债结构比': 'liabilitiesStructureRatio',
        '216归属于母公司股东权益/全部投入资本(%)': 'shareholdersOwnershipOfAParentCompany_TotalCapital',
        '217股东的权益/带息债务(%)': 'shareholdersInterest_InterestRateDebtRatio',
        '218有形资产/净债务(%)': 'tangibleAssets_NetDebtRatio',
        # 10. Cash flow analysis
        '219每股经营性现金流(元)': 'operatingCashFlowPerShareY',
        '220营业收入现金含量(%)': 'cashOfOperatingIncome',
        '221经营活动产生的现金流量净额/经营活动净收益(%)': 'netOperatingCashFlow_netOperationProfit',
        '222销售商品提供劳务收到的现金/营业收入(%)': 'cashFromGoodsSales_OperatingRevenue',
        '223经营活动产生的现金流量净额/营业收入': 'netOperatingCashFlow_OperatingRevenue',
        '224资本支出/折旧和摊销': 'capitalExpenditure_DepreciationAndAmortization',
        '225每股现金流量净额(元)': 'netCashFlowPerShare',
        '226经营净现金比率(短期债务)': 'operatingCashFlow_ShortTermDebtRatio',
        '227经营净现金比率(全部债务)': 'operatingCashFlow_LongTermDebtRatio',
        '228经营活动现金净流量与净利润比率': 'cashFlowRateAndNetProfitRatioOfOperatingActivities',
        '229全部资产现金回收率': 'cashRecoveryForAllAssets',
        # 11. Single-quarter financial indicators
        '230营业收入': 'operatingRevenueSingle',
        '231营业利润': 'operatingProfitSingle',
        '232归属于母公司所有者的净利润': 'netProfitBelongingToTheOwnerOfTheParentCompanySingle',
        '233扣除非经常性损益后的净利润': 'netProfitAfterExtraordinaryGainsAndLossesSingle',
        '234经营活动产生的现金流量净额': 'netCashFlowsFromOperatingActivitiesSingle',
        '235投资活动产生的现金流量净额': 'netCashFlowsFromInvestingActivitiesSingle',
        '236筹资活动产生的现金流量净额': 'netCashFlowsFromFinancingActivitiesSingle',
        '237现金及现金等价物净增加额': 'netIncreaseInCashAndCashEquivalentsSingle',
        # 12. Share capital and shareholders
        '238总股本': 'totalCapital',
        '239已上市流通A股': 'listedAShares',
        '240已上市流通B股': 'listedBShares',
        '241已上市流通H股': 'listedHShares',
        '242股东人数(户)': 'numberOfShareholders',
        '243第一大股东的持股数量': 'theNumberOfFirstMajorityShareholder',
        '244十大流通股东持股数量合计(股)': 'totalNumberOfTopTenCirculationShareholders',
        '245十大股东持股数量合计(股)': 'totalNumberOfTopTenMajorShareholders',
        # 13. Institutional holdings
        '246机构总量(家)': 'institutionNumber',
        '247机构持股总量(股)': 'institutionShareholding',
        '248QFII机构数': 'QFIIInstitutionNumber',
        '249QFII持股量': 'QFIIShareholding',
        '250券商机构数': 'brokerNumber',
        '251券商持股量': 'brokerShareholding',
        '252保险机构数': 'securityNumber',
        '253保险持股量': 'securityShareholding',
        '254基金机构数': 'fundsNumber',
        '255基金持股量': 'fundsShareholding',
        '256社保机构数': 'socialSecurityNumber',
        '257社保持股量': 'socialSecurityShareholding',
        '258私募机构数': 'privateEquityNumber',
        '259私募持股量': 'privateEquityShareholding',
        '260财务公司机构数': 'financialCompanyNumber',
        '261财务公司持股量': 'financialCompanyShareholding',
        '262年金机构数': 'pensionInsuranceAgencyNumber',
        '263年金持股量': 'pensionInsuranceAgencyShareholfing',
        # 14. Newly added indicators
        # [Note: in quarterly reports, if a shareholder also holds shares of non-tradable-A-share
        #  nature (e.g. both tradable A and tradable B shares), the count includes those shares]
        '264十大流通股东中持有A股合计(股)': 'totalNumberOfTopTenCirculationShareholdersForA',
        '265第一大流通股东持股量(股)': 'firstLargeCirculationShareholdersNumber',
        # [Note: 1. free-float shares = tradable A shares - A shares held by top-10 tradable
        #  shareholders owning more than 5%; 2. in quarterly reports, if a shareholder also holds
        #  shares of non-tradable-A-share nature (e.g. tradable A and tradable H shares), the >5%
        #  holding excludes them, so the result may be overstated; 3. values are reported per
        #  reporting period, so a new listing only has data from the period after its listing date]
        '266自由流通股(股)': 'freeCirculationStock',
        '267受限流通A股(股)': 'limitedCirculationAShares',
        '268一般风险准备(金融类)': 'generalRiskPreparation',
        '269其他综合收益(利润表)': 'otherComprehensiveIncome',
        '270综合收益总额(利润表)': 'totalComprehensiveIncome',
        '271归属于母公司股东权益(资产负债表)': 'shareholdersOwnershipOfAParentCompany',
        '272银行机构数(家)(机构持股)': 'bankInstutionNumber',
        '273银行持股量(股)(机构持股)': 'bankInstutionShareholding',
        '274一般法人机构数(家)(机构持股)': 'corporationNumber',
        '275一般法人持股量(股)(机构持股)': 'corporationShareholding',
        '276近一年净利润(元)': 'netProfitLastYear',
        '277信托机构数(家)(机构持股)': 'trustInstitutionNumber',
        '278信托持股量(股)(机构持股)': 'trustInstitutionShareholding',
        '279特殊法人机构数(家)(机构持股)': 'specialCorporationNumber',
        '280特殊法人持股量(股)(机构持股)': 'specialCorporationShareholding',
        '281加权净资产收益率(每股指标)': 'weightedROE',
        '282扣非每股收益(单季度财务指标)': 'nonEPSSingle',
    }
    # Normalize column names: '/' and '-' become '_', truncated to 64 characters
    # (presumably the database identifier length limit).
    financial_dict = {
        key: val.strip().replace('/', '_').replace('-', '_')[:64]
        for key, val in financial_dict.items()
    }
    # Map the crawler's positional columns ('col1', 'col2', ...) to the names above via the
    # 3-digit key prefix, and declare every such column as DOUBLE.
    _pattern = re.compile(r'\d{3}')
    col_name_dic = {}
    dtype = {}
    for key, val in financial_dict.items():
        m = _pattern.search(key)
        if m is None:
            continue
        col_name_dic['col%d' % int(m.group())] = val
        dtype[val] = DOUBLE
    dtype['ts_code'] = String(10)
    dtype['report_date'] = Date
    # Download the financial data
    crawler = HistoryFinancialListCrawler()
    list_data = crawler.fetch_and_parse()
    list_count = len(list_data)
    logger.debug('%d 财务数据包可用', list_count)
    datacrawler = HistoryFinancialCrawler()
    pd.set_option('display.max_columns', None)
    tot_data_count = 0
    # Extracts the 8-digit date from archive names such as 'gpcw20180930.zip'
    _pattern_file_date = re.compile(r'(?<=gpcw)\d{8}(?=\.zip)')
    try:
        for num, file_info in enumerate(list_data, start=1):
            filename = file_info['filename']
            # Only download archives newer than the latest report_date already in the table
            if report_date_latest is not None:
                m = _pattern_file_date.search(filename)
                if m is None:
                    # The date cannot be parsed: log it and fall through to downloading the file
                    logger.error('filename:%s 格式匹配失败 %s', filename, _pattern_file_date)
                else:
                    report_date_cur = str_2_date(m.group(), '%Y%m%d')
                    if report_date_cur <= report_date_latest:
                        continue
            logger.info('%d/%d) 开始下载 %s 数据', num, list_count, filename)
            result = fetch_and_parse(datacrawler, reporthook=demo_reporthook, filename=filename)
            if result is None:
                continue
            data_df = datacrawler.to_df(data=result)
            data_df.rename(columns=col_name_dic, inplace=True)
            data_df.reset_index(inplace=True)
            # Codes starting with '6' get the '.SH' suffix, all others '.SZ'
            data_df['ts_code'] = data_df['code'].apply(lambda x: x + ".SH" if x[0] == '6' else x + '.SZ')
            data_df.drop(['code'], axis=1, inplace=True)
            data_count = bunch_insert_on_duplicate_update(data_df, table_name, engine_md, dtype=dtype,
                                                          myisam_if_create_table=True)
            tot_data_count += data_count
    finally:
        logger.info("更新 %s 结束 %d 条信息被更新", table_name, tot_data_count)
        # The table was created by this run: add the (ts_code, report_date) primary key
        if not has_table and engine_md.has_table(table_name):
            create_pk_str = """ALTER TABLE {table_name}
                CHANGE COLUMN `ts_code` `ts_code` VARCHAR(20) NOT NULL FIRST,
                CHANGE COLUMN `report_date` `report_date` DATE NOT NULL AFTER `ts_code`,
                ADD PRIMARY KEY (`ts_code`, `report_date`)""".format(table_name=table_name)
            execute_sql(create_pk_str, engine_md, commit=True)
            logger.info('%s 建立主键 [ts_code, report_date]', table_name)
def import_sectorconstituent(sector_code,
                             sector_name,
                             date_start,
                             chain_param=None,
                             exch_code='SZSE'):
    """
    Import the constituent stocks of sector ``sector_code`` into ``wind_sectorconstituent``.

    Import starts at ``date_start``, or at the latest date already stored for this sector if
    that is not earlier. Constituent lists are then gathered over the trade-date range ending
    at the last trade date on or before yesterday, and appended to the table.
    :param sector_code: sector code
    :param sector_name: sector name (passed to the fetch helpers and used in logs)
    :param date_start: first date to import (str or date)
    :param chain_param: used by celery to pass the previous task's result to the next task
    :param exch_code: exchange whose trade-date calendar is used, default "SZSE" (Shenzhen)
    :return:
    :raises ValueError: if no trade dates are available for ``exch_code``
    """
    # Trade-date calendar of exch_code
    trade_date_list_sorted = get_trade_date_list_sorted(exch_code)
    if trade_date_list_sorted is None or len(trade_date_list_sorted) == 0:
        raise ValueError("没有交易日数据")
    date_start = str_2_date(date_start)

    date_constituent_df_dict = {}
    idx_constituent_set_dic = {}
    # Latest stored constituent list of this sector; an empty result means a fresh import
    date_latest, constituent_df = get_latest_constituent_df(sector_code)
    date_latest = str_2_date(date_latest)
    if date_latest is None or date_latest < date_start:
        idx_start = get_last_idx(trade_date_list_sorted,
                                 lambda x: x <= date_start)
        sec_df, _ = get_sectorconstituent_2_dic(sector_code, sector_name,
                                                date_start, idx_start,
                                                trade_date_list_sorted,
                                                date_constituent_df_dict,
                                                idx_constituent_set_dic)
        # Save the constituents of the start date
        sec_df.to_sql("wind_sectorconstituent",
                      engine_md,
                      if_exists='append',
                      index=False)
    else:
        # Resume from the latest stored date, seeding the caches with the stored list
        date_start = date_latest
        idx_start = get_last_idx(trade_date_list_sorted,
                                 lambda x: x <= date_start)
        date_constituent_df_dict[date_latest] = constituent_df
        idx_constituent_set_dic[idx_start] = set(constituent_df['wind_code'])

    # End of the range: last trade date on or before yesterday
    yesterday = date.today() - timedelta(days=1)
    idx_end = get_last_idx(trade_date_list_sorted, lambda x: x <= yesterday)
    if idx_start >= idx_end:
        return

    left_or_right = 1
    recursion_get_sectorconstituent(idx_start, idx_end, trade_date_list_sorted,
                                    date_constituent_df_dict,
                                    idx_constituent_set_dic, left_or_right,
                                    sector_code, sector_name)

    # Drop the date_start entry: its data is already in the database
    del date_constituent_df_dict[date_start]
    # Store the remaining dates
    for num, (date_cur, sec_df) in enumerate(date_constituent_df_dict.items(),
                                             start=1):
        sec_df.to_sql("wind_sectorconstituent",
                      engine_md,
                      if_exists='append',
                      index=False)
        logger.info("%d) %s %d 条 %s 成分股数据导入数据库", num, date_cur,
                    sec_df.shape[0], sector_name)
        # Debug only: stop after 20 dates
        if DEBUG and num >= 20:
            break
예제 #15
0
def import_index_daily(chain_param=None, wind_code_set=None):
    """Import daily index bars from Wind into ``wind_index_daily``.

    For every index in ``wind_index_info``, the fetch range starts the day after the latest
    stored ``trade_date``. If the index has no rows yet, or the table does not exist, it
    starts at the index ``basedate``. The range ends yesterday when run before 16:00, and
    today from 16:00 on.
    :param chain_param: used by celery to pass the previous task's result to the next task
    :param wind_code_set: optional collection of wind codes to restrict the import to;
        None (default) imports every index
    :return:
    """
    table_name = "wind_index_daily"
    has_table = engine_md.has_table(table_name)
    col_name_param_list = [
        ('open', DOUBLE),
        ('high', DOUBLE),
        ('low', DOUBLE),
        ('close', DOUBLE),
        ('volume', DOUBLE),
        ('amt', DOUBLE),
        ('turn', DOUBLE),
        ('free_turn', DOUBLE),
    ]
    wind_indictor_str = ",".join(key for key, _ in col_name_param_list)
    # Wind returns upper-case column names
    rename_col_dic = {key.upper(): key.lower() for key, _ in col_name_param_list}
    dtype = dict(col_name_param_list)
    dtype['wind_code'] = String(20)
    # TODO: declaring 'trade_date' as Date made inserts fail; the cause was not found.
    # NOTE(review): the disabled line below ends with a trailing comma, which assigns the
    # tuple (Date,) rather than Date -- possibly the cause; confirm before re-enabling.
    # dtype['trade_date'] = Date,

    if has_table:
        sql_str = """
              SELECT wind_code, date_frm, if(null<end_date, null, end_date) date_to
              FROM
              (
                  SELECT info.wind_code, ifnull(trade_date, basedate) date_frm, null,
                  if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                  FROM 
                      wind_index_info info 
                  LEFT OUTER JOIN
                      (SELECT wind_code, adddate(max(trade_date),1) trade_date FROM {table_name} GROUP BY wind_code) daily
                  ON info.wind_code = daily.wind_code
              ) tt
              WHERE date_frm <= if(null<end_date, null, end_date) 
              ORDER BY wind_code""".format(table_name=table_name)
    else:
        logger.warning('%s 不存在,仅使用 wind_index_info 表进行计算日期范围', table_name)
        sql_str = """
              SELECT wind_code, date_frm, if(null<end_date, null, end_date) date_to
              FROM
              (
                  SELECT info.wind_code, basedate date_frm, null,
                  if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                  FROM wind_index_info info 
              ) tt
              WHERE date_frm <= if(null<end_date, null, end_date) 
              ORDER BY wind_code;"""

    with with_db_session(engine_md) as session:
        # Date range to fetch for each index
        table = session.execute(sql_str)
        wind_code_date_from_dic = {
            wind_code: (date_from, date_to)
            for wind_code, date_from, date_to in table.fetchall()
            if wind_code_set is None or wind_code in wind_code_set
        }

    data_len = len(wind_code_date_from_dic)

    logger.info('%d indexes will been import', data_len)
    for data_num, (wind_code, (date_from, date_to)) in enumerate(
            wind_code_date_from_dic.items(), start=1):
        if str_2_date(date_from) > date_to:
            logger.warning("%d/%d) %s %s - %s 跳过", data_num, data_len,
                           wind_code, date_from, date_to)
            continue
        try:
            temp = invoker.wsd(wind_code, wind_indictor_str, date_from,
                               date_to)
        except APIError as exp:
            logger.exception("%d/%d) %s 执行异常", data_num, data_len, wind_code)
            if exp.ret_dic.setdefault('error_code', 0) in (
                    -40520007,  # no data available
                    -40521009,  # data decoding failed; check inputs, e.g. month-end / February dates
            ):
                continue
            # Any other API error aborts the whole import
            break
        temp.reset_index(inplace=True)
        temp.rename(columns={'index': 'trade_date'}, inplace=True)
        temp.rename(columns=rename_col_dic, inplace=True)
        temp.trade_date = temp.trade_date.apply(str_2_date)
        temp['wind_code'] = wind_code
        bunch_insert_on_duplicate_update(temp,
                                         table_name,
                                         engine_md,
                                         dtype=dtype)
        logger.info('更新指数 %s 至 %s 成功', wind_code, date_2_str(date_to))
        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            build_primary_key([table_name])
            # The table now exists with its key; do not redo this for every later index
            has_table = True
예제 #16
0
"""
@author  : MG
@Time    : 2020/9/13 19:03
@File    : backtest.py
@contact : [email protected]
@desc    : 用于
"""
from ibats_utils.mess import str_2_date
from vnpy.app.cta_strategy.backtesting import BacktestingEngine
from vnpy.trader.constant import Interval
from strategy.trandition.atr_rsi_strategy import AtrRsiStrategy

# Minute-bar backtest of AtrRsiStrategy on the SHFE rebar contract rb2101 with vnpy's
# BacktestingEngine. As before, the backtest runs when this module is executed or imported.
from datetime import datetime

engine = BacktestingEngine()

engine.set_parameters(
    # vnpy expects vt_symbol as "<symbol>.<exchange>" and splits it on "."
    vt_symbol='rb2101.SHFE',
    interval=Interval.MINUTE,
    # load_data() compares start against a datetime end, so start must be a datetime, not a date
    start=datetime(2020, 1, 1),
    rate=3e-5,  # commission rate
    slippage=0.001,  # slippage
    size=1,  # contract multiplier
    pricetick=1,  # minimum price tick
    capital=1000000,
)
engine.add_strategy(AtrRsiStrategy, setting={})
# History must be loaded first; otherwise run_backtesting() has no bars to replay
engine.load_data()
engine.run_backtesting()
engine.calculate_result()
engine.calculate_statistics()
예제 #17
0
def import_future_info(chain_param=None):
    """
    Update the futures contract list stored in ``wind_future_info``.

    For each exchange, Wind sector constituents are sampled every 90 days. Sampling runs from
    the date returned by ``get_exchange_latest_data()`` for that exchange (falling back to the
    exchange's establishment date) up to and including yesterday. Wind codes not yet in the
    table are then queried with ``wss`` and upserted, after which ``update_from_info_table``
    is called. If no new codes are found, nothing is written.
    :param chain_param: used by celery to pass the previous task's result to the next task
    :return:
    """
    table_name = "wind_future_info"
    has_table = engine_md.has_table(table_name)
    logger.info("更新 %s 开始", table_name)
    # Contracts already stored (wind_code -> ipo_date)
    if has_table:
        sql_str = 'select wind_code, ipo_date from {table_name}'.format(
            table_name=table_name)
        with with_db_session(engine_md) as session:
            table = session.execute(sql_str)
            wind_code_ipo_date_dic = dict(table.fetchall())
    else:
        wind_code_ipo_date_dic = {}

    # Wind sector ids listing each exchange's contracts, e.g.
    # SHFE:
    # w.wset("sectorconstituent","date=1995-05-10;sectorid=a599010201000000")
    # CFFEX:
    # w.wset("sectorconstituent","date=2013-09-10;sectorid=a599010101000000")
    # DCE:
    # w.wset("sectorconstituent","date=1999-01-10;sectorid=a599010301000000")
    # CZCE:
    # w.wset("sectorconstituent","date=1999-01-10;sectorid=a599010401000000")
    exchange_sectorid_dic_list = [
        {
            'exch_eng': 'SHFE',
            'exchange_name': '上期所',
            'sectorid': 'a599010201000000',
            'date_establish': '1995-05-10'
        },
        {
            'exch_eng': 'CFFEX',
            'exchange_name': '金交所',
            'sectorid': 'a599010101000000',
            'date_establish': '2013-09-10'
        },
        {
            'exch_eng': 'DCE',
            'exchange_name': '大商所',
            'sectorid': 'a599010301000000',
            'date_establish': '1999-01-10'
        },
        {
            'exch_eng': 'CZCE',
            'exchange_name': '郑商所',
            'sectorid': 'a599010401000000',
            'date_establish': '1999-01-10'
        },
    ]
    exchange_latest_ipo_date_dic = get_exchange_latest_data()
    wind_code_set = set()
    ndays_per_update = 90
    # Wind indicator names and the column types they are stored as
    col_name_param_list = [
        ("ipo_date", Date),
        ("sec_name", String(50)),
        ("sec_englishname", String(200)),
        ("exch_eng", String(200)),
        ("lasttrade_date", Date),
        ("lastdelivery_date", Date),
        ("dlmonth", String(20)),
        ("lprice", DOUBLE),
        ("sccode", String(20)),
        ("margin", DOUBLE),
        ("punit", String(200)),
        ("changelt", DOUBLE),
        ("mfprice", DOUBLE),
        ("contractmultiplier", DOUBLE),
        ("ftmargins", String(100)),
        ("trade_code", String(200)),
    ]
    wind_indictor_str = ",".join(col_name
                                 for col_name, _ in col_name_param_list)
    dtype = dict(col_name_param_list)
    dtype['wind_code'] = String(20)
    # Collect historical contract codes of every exchange
    logger.info("获取历史期货合约列表信息")
    date_yesterday = date.today() - timedelta(days=1)
    for exchange_sectorid_dic in exchange_sectorid_dic_list:
        exchange_name = exchange_sectorid_dic['exchange_name']
        exch_eng = exchange_sectorid_dic['exch_eng']
        sector_id = exchange_sectorid_dic['sectorid']
        date_establish = exchange_sectorid_dic['date_establish']
        date_since = str_2_date(
            exchange_latest_ipo_date_dic.get(exch_eng, date_establish))
        logger.info("%s[%s][%s] %s ~ %s", exchange_name, exch_eng, sector_id,
                    date_since, date_yesterday)
        while date_since <= date_yesterday:
            date_since_str = date_since.strftime(STR_FORMAT_DATE)
            future_info_df = invoker.wset(
                "sectorconstituent",
                "date=%s;sectorid=%s" % (date_since_str, sector_id))
            data_count = 0 if future_info_df is None else future_info_df.shape[0]
            logger.info("subject_name=%s[%s] %s 返回 %d 条数据", exchange_name,
                        sector_id, date_since_str, data_count)
            if data_count > 0:
                wind_code_set |= set(future_info_df['wind_code'])

            if date_since >= date_yesterday:
                break
            # Step forward, clamping so the last sample is exactly yesterday
            date_since = min(date_since + timedelta(days=ndays_per_update),
                             date_yesterday)

    # Only contracts not yet in the table
    wind_code_list = [
        wc for wc in wind_code_set if wc not in wind_code_ipo_date_dic
    ]
    # Fetch the contracts' basic information, e.g.
    # w.wss("AU1706.SHF,AG1612.SHF,AU0806.SHF", "ipo_date,sec_name,sec_englishname,exch_eng,lasttrade_date,lastdelivery_date,dlmonth,lprice,sccode,margin,punit,changelt,mfprice,contractmultiplier,ftmargins,trade_code")
    if wind_code_list:
        logger.info("%d wind_code will be invoked by wss, wind_code_list=%s",
                    len(wind_code_list), wind_code_list)
        future_info_df = invoker.wss(wind_code_list, wind_indictor_str)
        # Columns are still upper-case here; they are lower-cased below
        future_info_df['MFPRICE'] = future_info_df['MFPRICE'].apply(
            mfprice_2_num)
        future_info_count = future_info_df.shape[0]

        future_info_df.rename(
            columns={c: c.lower() for c in future_info_df.columns},
            inplace=True)
        future_info_df.index.rename('wind_code', inplace=True)
        future_info_df.reset_index(inplace=True)
        data_count = bunch_insert_on_duplicate_update(future_info_df,
                                                      table_name,
                                                      engine_md,
                                                      dtype=dtype)
        logger.info("更新 %s 结束 %d 条信息被更新", table_name, data_count)
        if not has_table and engine_md.has_table(table_name):
            build_primary_key([table_name])

        logger.info("更新 wind_future_info 结束 %d 条记录被更新", future_info_count)
        update_from_info_table(table_name)
def import_index_constituent(index_code,
                             index_name,
                             date_start,
                             exch_code='SZSE',
                             date_end=None,
                             method='loop'):
    """
    Import the constituent stocks of an index into ``wind_index_constituent``.

    :param index_code: code of the index whose constituents are imported
    :param index_name: display name of the index, stored with the constituents
    :param date_start: first date to import (str or date); when the table already
        holds a constituent snapshot at or after this date, the import resumes
        from that stored snapshot instead
    :param exch_code: exchange whose trade calendar is used, default "SZSE" (Shenzhen)
    :param date_end: optional last date; None means up to the latest trade date
        that is not after yesterday
    :param method: 'loop' to fetch with ``loop_get_data`` or 'recursion' to fetch
        with ``recursion_dichotomy_get_data``
    :return: None
    :raises ValueError: when no trade dates are available or ``method`` is unknown
    """
    table_name = 'wind_index_constituent'
    param_list = [
        ('trade_date', Date),
        ('weight', DOUBLE),
        ('stock_name', String(80)),
        ('index_code', String(20)),
        ('index_name', String(80)),
    ]
    dtype = {key: val for key, val in param_list}
    # constituent frames carry a 'wind_code' column (see constituent_df['wind_code'] below)
    dtype['wind_code'] = String(20)
    # Trade calendar of the requested exchange
    trade_date_list_sorted = get_trade_date_list_sorted(exch_code)
    if trade_date_list_sorted is None or len(trade_date_list_sorted) == 0:
        raise ValueError("没有交易日数据")
    # Normalise the date arguments
    date_start = str_2_date(date_start)
    if date_end is not None:
        date_end = str_2_date(date_end)
        idx_end = get_first_idx(trade_date_list_sorted,
                                lambda x: x >= date_end)
        if idx_end is not None:
            # keep the calendar up to and including the first trade date >= date_end
            trade_date_list_sorted = trade_date_list_sorted[:(idx_end + 1)]

    date_constituent_df_dict = OrderedDict()
    idx_constituent_set_dic = {}
    # Latest constituent snapshot already stored; a None date means a fresh import
    date_latest, constituent_df = get_latest_constituent_df(index_code)
    date_latest = str_2_date(date_latest)
    if date_latest is None or date_latest < date_start:
        idx_start = get_last_idx(trade_date_list_sorted,
                                 lambda x: x <= date_start)
        sec_df, _ = get_index_constituent_2_dic(index_code, index_name,
                                                date_start, idx_start,
                                                date_constituent_df_dict,
                                                idx_constituent_set_dic)
        if sec_df is None or sec_df.shape[0] == 0:
            return
        # Persist the starting snapshot right away
        bunch_insert_on_duplicate_update(sec_df,
                                         table_name,
                                         engine_md,
                                         dtype=dtype)
    else:
        # Resume from the latest snapshot stored in the database
        date_start = date_latest
        idx_start = get_last_idx(trade_date_list_sorted,
                                 lambda x: x <= date_start)
        date_constituent_df_dict[date_latest] = constituent_df
        idx_constituent_set_dic[idx_start] = set(constituent_df['wind_code'])

    # Fetch up to the last trade date that is not after yesterday
    yesterday = date.today() - timedelta(days=1)
    idx_end = get_last_idx(trade_date_list_sorted, lambda x: x <= yesterday)
    if idx_start >= idx_end:
        return

    if method == 'loop':
        try:
            loop_get_data(idx_start + 1, idx_end, trade_date_list_sorted,
                          date_constituent_df_dict, idx_constituent_set_dic,
                          index_code, index_name)
        except APIError:
            logger.exception(
                'loop_get_data (idx_start=%d, idx_end=%d, index_code=%s, index_name=%s)',
                idx_start, idx_end, index_code, index_name)
    elif method == 'recursion':
        left_or_right = 1
        recursion_dichotomy_get_data(idx_start, idx_end,
                                     trade_date_list_sorted,
                                     date_constituent_df_dict,
                                     idx_constituent_set_dic, left_or_right,
                                     index_code, index_name)
    else:
        raise ValueError('method = %s error' % method)

    # The date_start snapshot is already in the database; drop it (if present)
    # so only the newly fetched snapshots are saved below.
    date_constituent_df_dict.pop(date_start, None)
    for num, (date_cur, sec_df) in enumerate(date_constituent_df_dict.items(),
                                             start=1):
        bunch_insert_on_duplicate_update(sec_df,
                                         table_name,
                                         engine_md,
                                         dtype=dtype)
        logger.info("%d) %s %d 条 %s 成分股数据导入数据库", num, date_cur,
                    sec_df.shape[0], index_name)
예제 #19
0
def import_smfund_info(chain_param=None):
    """
    Import structured (graded) fund information into ``wind_smfund_info``.

    Fund codes are collected from the "sectorconstituent" lists at several
    historical dates. Only codes not already in the table (and not containing
    '!') are queried and saved. If there are no new codes, nothing is done.

    :param chain_param: only used to pass results along a celery task chain
    :return: None
    :raises Exception: when the wss query for the new codes returns no data
    """
    table_name = "wind_smfund_info"
    has_table = engine_md.has_table(table_name)
    # Wind sector ids of the structured mother-fund categories (reference only):
    #   active equity 1000007766000000, passive equity 1000007767000000,
    #   pure bond 1000007768000000, hybrid bond 1000007769000000,
    #   hybrid 1000026143000000, QDII 1000019779000000
    col_name_param_list = [
        ('wind_code', String(20)),
        ('fund_type', String(20)),
        ('sec_name', String(50)),
        ('class_a_code', String(20)),
        ('class_a_name', String(50)),
        ('class_b_code', String(20)),
        ('class_b_name', String(50)),
        ('track_indexcode', String(20)),
        ('track_indexname', String(50)),
        ('a_pct', DOUBLE),
        ('b_pct', DOUBLE),
        ('upcv_nav', DOUBLE),
        ('downcv_nav', DOUBLE),
        ('max_purchasefee', DOUBLE),
        ('max_redemptionfee', DOUBLE),
    ]
    # Field list requested from the "leveragedfundinfo" report
    field = ",".join(name for name, _ in col_name_param_list)
    # Snapshot dates at which the market-wide structured fund list is fetched
    dates = [
        '2011-01-01', '2013-01-01', '2015-01-01', '2017-01-01', '2018-01-01'
    ]
    sector_df = pd.concat([
        invoker.wset("sectorconstituent",
                     "date=%s;sectorid=1000006545000000" % date_p)
        for date_p in dates
    ])
    wind_code_all = sector_df['wind_code'].unique()
    # Column types
    dtype = {key: val for key, val in col_name_param_list}
    dtype['wind_code'] = String(20)
    dtype["tradable"] = String(20)
    dtype["fund_setupdate"] = Date
    dtype["fund_maturitydate"] = Date
    # Skip codes that are already stored
    if has_table:
        with with_db_session(engine_md) as session:
            table = session.execute(f"SELECT wind_code FROM {table_name}")
            wind_code_existed = {content[0] for content in table.fetchall()}
    else:
        wind_code_existed = set()
    wind_code_new = [
        code for code in set(wind_code_all) - wind_code_existed
        if '!' not in code
    ]
    if len(wind_code_new) == 0:
        logging.info("%s: no new structured fund to import", table_name)
        return

    info_df = invoker.wss(wind_code_new, 'fund_setupdate, fund_maturitydate')
    if info_df is None:
        raise Exception('no data')
    info_df['FUND_SETUPDATE'] = info_df['FUND_SETUPDATE'].apply(
        lambda x: str_2_date(x))
    info_df['FUND_MATURITYDATE'] = info_df['FUND_MATURITYDATE'].apply(
        lambda x: str_2_date(x))
    info_df.rename(columns={
        'FUND_SETUPDATE': 'fund_setupdate',
        'FUND_MATURITYDATE': 'fund_maturitydate'
    },
                   inplace=True)

    # Query the structured-fund details of each new code as of its setup date.
    # Kept separate from sector_df so the sector snapshot rows are not mixed in.
    fund_info_df_list, fund_info_row_count = [], 0
    for code in info_df.index:
        begin_date = info_df.loc[code, 'fund_setupdate'].strftime('%Y-%m-%d')
        temp_df = invoker.wset("leveragedfundinfo",
                               "date=%s;windcode=%s;field=%s" %
                               (begin_date, code, field))
        if temp_df is not None:
            fund_info_df_list.append(temp_df)
            fund_info_row_count += temp_df.shape[0]
        if DEBUG and fund_info_row_count > 10:
            break
    if fund_info_df_list:
        fund_df = pd.concat(fund_info_df_list)
    else:
        fund_df = pd.DataFrame(columns=['wind_code'])
    fund_df.set_index('wind_code', inplace=True)
    # Codes containing 'S' are recorded as tradable
    fund_df['tradable'] = fund_df.index.map(lambda x: x if 'S' in x else None)
    info_df = info_df.join(fund_df, how='outer')
    info_df.rename(
        columns={
            'a_nav': 'nav_a',
            'b_nav': 'nav_b',
            'a_fs_inc': 'fs_inc_a',
            'b_fs_inc': 'fs_inc_b'
        },
        inplace=True)
    info_df.index.rename('wind_code', inplace=True)
    info_df.reset_index(inplace=True)
    bunch_insert_on_duplicate_update(info_df,
                                     table_name,
                                     engine_md,
                                     dtype=dtype)
    logging.info("更新 %s 完成 存量数据 %d 条", table_name, len(info_df))
    if not has_table and engine_md.has_table(table_name):
        alter_table_2_myisam(engine_md, [table_name])
        build_primary_key([table_name])

    # Refresh the code_mapping table
    update_from_info_table(table_name)
예제 #20
0
def import_edb_monthly():
    """
    Incrementally import monthly Wind EDB macro series (PMI, CPI, PPI, ...)
    into ``wind_edb_monthly``.

    Each field code in ``PMI_FIELD_CODE_2_CN_DIC`` is requested from a start
    date up to yesterday. The start date is the day after the latest stored
    ``trade_date``, or the field's configured start date when nothing is stored.
    Fields that return the "no data" or "decode failure" API errors are
    skipped. Any other API error stops the update.
    """
    table_name = 'wind_edb_monthly'
    has_table = engine_md.has_table(table_name)
    # field code -> (field name, first date available)
    PMI_FIELD_CODE_2_CN_DIC = {
        "M0017126": ("PMI", date(2005, 1, 1)),
        "M0017127": ("PMI:生产", date(2005, 1, 1)),
        "M0017128": ("PMI:新订单", date(2005, 1, 1)),
        "M0017129": ("PMI:新出口订单", date(2005, 1, 1)),
        "M0017130": ("PMI:在手订单", date(2005, 1, 1)),
        "M0017131": ("PMI:产成品库存", date(2005, 1, 1)),
        "M0017132": ("PMI:采购量", date(2005, 1, 1)),
        "M0017133": ("PMI:进口", date(2005, 1, 1)),
        "M5766711": ("PMI:出厂价格", date(2005, 1, 1)),
        "M0017134": ("PMI:主要原材料购进价格", date(2005, 1, 1)),
        "M0017135": ("PMI:原材料库存", date(2005, 1, 1)),
        "M0017136": ("PMI:从业人员", date(2005, 1, 1)),
        "M0017137": ("PMI:供货商配送时间", date(2005, 1, 1)),
        "M5207790": ("PMI:生产经营活动预期", date(2005, 1, 1)),
        "M5206738": ("PMI:大型企业", date(2005, 1, 1)),
        "M5206739": ("PMI:中型企业", date(2005, 1, 1)),
        "M5206740": ("PMI:小型企业", date(2005, 1, 1)),
        "M5407921": ("克强指数:当月值", date(2009, 7, 1)),
        "M0000612": ("CPI:当月同比", date(1990, 1, 1)),
        "M0000616": ("CPI:食品:当月同比", date(1990, 1, 1)),
        "M0000613": ("CPI:非食品:当月同比", date(1990, 1, 1)),
        "M0000614": ("CPI:消费品:当月同比", date(1990, 1, 1)),
        "M0000615": ("CPI:服务:当月同比", date(1990, 1, 1)),
        "M0000705": ("CPI:环比", date(1990, 1, 1)),
        "M0000706": ("CPI:食品:环比", date(1990, 1, 1)),
        "M0061581": ("CPI:非食品:环比", date(1990, 1, 1)),
        "M0061583": ("CPI:消费品:环比", date(1990, 1, 1)),
        "M0001227": ("PPI:全部工业品:当月同比", date(1996, 10, 1)),
        "M0061585": ("PPI:全部工业品:环比", date(2002, 1, 1)),
        "M0001228": ("PPI:生产资料:当月同比", date(1996, 10, 1)),
        "M0066329": ("PPI:生产资料:环比", date(2011, 1, 1)),
        "M0001232": ("PPI:生活资料:当月同比", date(1996, 10, 1)),
        "M0066333": ("PPI:生活资料:环比", date(2011, 1, 1)),
    }
    # Column types
    param_list = [
        ('field_name', String(45)),
        ('trade_date', Date),
        ('val', DOUBLE),
    ]
    dtype = {key: val for key, val in param_list}
    dtype['field_code'] = String(20)
    data_len = len(PMI_FIELD_CODE_2_CN_DIC)
    if has_table:
        sql_str = f"""select field_code, max(trade_date) trade_date_max from {table_name} group by field_code"""
        # Latest stored date of every field
        with with_db_session(engine_md) as session:
            table = session.execute(sql_str)
            field_date_dic = {row[0]: row[1] for row in table.fetchall()}
    else:
        # Nothing stored yet: every field starts from its configured date.
        # The table is presumably created by the first insert below and is then
        # converted to MyISAM and given its primary key once.
        field_date_dic = {}

    for data_num, (wind_code,
                   (field_name,
                    date_from)) in enumerate(PMI_FIELD_CODE_2_CN_DIC.items(),
                                             start=1):
        if wind_code in field_date_dic:
            date_from = field_date_dic[wind_code] + timedelta(days=1)
        date_to = date.today() - timedelta(days=1)
        logger.info('%d/%d) %s %s [%s %s]', data_num, data_len, wind_code,
                    field_name, date_from, date_to)
        if date_from > date_to:
            # already up to date
            continue
        try:
            data_df = invoker.edb(wind_code, date_from, date_to,
                                  "Fill=Previous")
        except APIError as exp:
            logger.exception("%d/%d) %s 执行异常", data_num, data_len, wind_code)
            if exp.ret_dic.get('error_code', 0) in (
                    -40520007,  # no data available
                    -40521009,  # data decode failed (check parameters, e.g. month-end dates)
            ):
                continue
            else:
                break
        if data_df is None or data_df.shape[0] == 0:
            continue
        trade_date_max = str_2_date(max(data_df.index))
        if trade_date_max <= date_from:
            continue
        data_df.index.rename('trade_date', inplace=True)
        data_df.reset_index(inplace=True)
        data_df.rename(columns={wind_code.upper(): 'val'}, inplace=True)
        data_df['field_code'] = wind_code
        data_df['field_name'] = field_name
        bunch_insert_on_duplicate_update(data_df,
                                         table_name,
                                         engine_md,
                                         dtype=dtype)
        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            build_primary_key([table_name])
            # only convert / key the freshly created table once
            has_table = True
예제 #21
0
def import_future_info(chain_param=None):
    """
    Update the futures contract list in ``ifind_future_info``.

    Contract codes come from each exchange's iFinD sector (``THS_DataPool``).
    The sector is sampled every ``ndays_per_update`` days, from the date given
    by ``get_exchange_latest_data()`` (or the exchange's establishment date)
    up to yesterday. Basic data is then fetched with ``THS_BasicData`` for
    codes not yet in the table, in chunks of 500.

    :param chain_param: only used to pass values along a task.chain
    :return: None
    """
    table_name = 'ifind_future_info'
    has_table = engine_md.has_table(table_name)
    logger.info("更新 %s [%s] 开始", table_name, has_table)
    # Contracts already stored
    if has_table:
        sql_str = f'SELECT ths_code, ths_start_trade_date_future FROM {table_name}'
        with with_db_session(engine_md) as session:
            table = session.execute(sql_str)
            code_ipo_date_dic = dict(table.fetchall())
        exchange_latest_ipo_date_dic = get_exchange_latest_data()
    else:
        code_ipo_date_dic = {}
        exchange_latest_ipo_date_dic = {}

    exchange_sectorid_dic_list = [
        {
            'exch_eng': 'SHFE',
            'exchange_name': '上海期货交易所',
            'sectorid': '091001',
            'date_establish': '1995-05-10'
        },
        {
            'exch_eng': 'CFFEX',
            'exchange_name': '中国金融期货交易所',
            'sectorid': '091004',
            'date_establish': '2013-09-10'
        },
        {
            'exch_eng': 'DCE',
            'exchange_name': '大连商品交易所',
            'sectorid': '091002',
            'date_establish': '1999-01-10'
        },
        {
            'exch_eng': 'CZCE',
            'exchange_name': '郑州商品交易所',
            'sectorid': '091003',
            'date_establish': '1999-01-10'
        },
    ]
    # Exchange (Chinese short name) -> English code, used to fill 'exch_eng'
    exch_name_2_eng_dic = {
        dic['exchange_name']: dic['exch_eng']
        for dic in exchange_sectorid_dic_list
    }

    # Indicator list and parameters
    indicator_param_list = [
        ('ths_future_short_name_future', '', String(50)),
        ('ths_future_code_future', '', String(20)),
        ('ths_sec_type_future', '', String(20)),
        ('ths_td_variety_future', '', String(20)),
        ('ths_td_unit_future', '', DOUBLE),
        ('ths_pricing_unit_future', '', String(20)),
        ('ths_mini_chg_price_future', '', DOUBLE),
        ('ths_chg_ratio_lmit_future', '', DOUBLE),
        ('ths_td_deposit_future', '', DOUBLE),
        ('ths_start_trade_date_future', '', Date),
        ('ths_last_td_date_future', '', Date),
        ('ths_last_delivery_date_future', '', Date),
        ('ths_delivery_month_future', '', String(10)),
        ('ths_listing_benchmark_price_future', '', DOUBLE),
        ('ths_initial_td_deposit_future', '', DOUBLE),
        ('ths_contract_month_explain_future', '', String(120)),
        ('ths_td_time_explain_future', '', String(120)),
        ('ths_last_td_date_explian_future', '', String(120)),
        ('ths_delivery_date_explain_future', '', String(120)),
        ('ths_exchange_short_name_future', '', String(50)),
        ('ths_contract_en_short_name_future', '', String(50)),
        ('ths_contract_en_name_future', '', String(50)),
    ]
    json_indicator, json_param = unzip_join(
        [(key, val) for key, val, _ in indicator_param_list], sep=';')

    # Column types
    dtype = {key: val for key, _, val in indicator_param_list}
    dtype['ths_code'] = String(20)
    dtype['exch_eng'] = String(20)

    # Collect historical contract codes of every exchange
    code_set = set()
    ndays_per_update = 90
    sector_count = len(exchange_sectorid_dic_list)
    for num, exchange_sectorid_dic in enumerate(exchange_sectorid_dic_list,
                                                start=1):
        exchange_name = exchange_sectorid_dic['exchange_name']
        exch_eng = exchange_sectorid_dic['exch_eng']
        sector_id = exchange_sectorid_dic['sectorid']
        date_establish = exchange_sectorid_dic['date_establish']
        # Start date for sampling the sector constituents
        date_since = str_2_date(
            exchange_latest_ipo_date_dic.get(exch_eng, date_establish))
        date_yestoday = date.today() - timedelta(days=1)
        logger.info("%d/%d) %s[%s][%s] %s ~ %s", num, sector_count,
                    exchange_name, exch_eng, sector_id, date_since,
                    date_yestoday)
        while date_since <= date_yestoday:
            date_since_str = date_2_str(date_since)
            # Sector constituents (futures contracts) on date_since, e.g.
            # THS_DP('block','2021-01-15;091002003','date:Y,thscode:Y,security_name:Y,security_name_in_time:Y')
            try:
                future_info_df = invoker.THS_DataPool(
                    'block', '%s;%s' % (date_since_str, sector_id),
                    'thscode:Y,security_name:Y')
            except APIError as exp:
                if exp.ret_dic['error_code'] in (
                        -4001,
                        -4210,
                ):
                    future_info_df = None
                else:
                    logger.exception("THS_DataPool %s 获取失败, '%s;%s'",
                                     exchange_name, date_since_str, sector_id)
                    break
            if future_info_df is not None and future_info_df.shape[0] > 0:
                code_set |= set(future_info_df['THSCODE'])

            if date_since >= date_yestoday:
                break
            else:
                # step forward, clamping the final sample to yesterday
                date_since += timedelta(days=ndays_per_update)
                if date_since > date_yestoday:
                    date_since = date_yestoday

        if DEBUG:
            break

    # Fetch basic data only for contracts not yet stored
    code_list = [wc for wc in code_set if wc not in code_ipo_date_dic]
    if len(code_list) > 0:
        for code_chunk in split_chunk(code_list, 500):
            future_info_df = invoker.THS_BasicData(code_chunk, json_indicator,
                                                   json_param)
            if future_info_df is None or future_info_df.shape[0] == 0:
                data_count = 0
                logger.warning("更新 %s 结束 %d 条记录被更新", table_name, data_count)
                continue
            # Fill exch_eng from the exchange short name ('' when unknown).
            # A single vectorised assignment replaces chained indexing, which
            # pandas does not guarantee to write back to the frame.
            future_info_df['exch_eng'] = future_info_df[
                'ths_exchange_short_name_future'].map(
                    exch_name_2_eng_dic).fillna('')

            data_count = bunch_insert_on_duplicate_update(
                future_info_df,
                table_name,
                engine_md,
                dtype,
                primary_keys=['ths_code'],
                schema=config.DB_SCHEMA_MD)
            logger.info("更新 %s 结束 %d 条记录被更新", table_name, data_count)
예제 #22
0
def _test_check_accumulation_cols():
    """
    Check that ``check_accumulation_cols`` reports exactly the cumulative column.

    'revenue' grows cumulatively within each year (periodic growth), while
    'revenue_season' does not, so only 'revenue' is expected in the result.
    """
    cumulative_col = 'revenue'  # periodic (cumulative) growth
    plain_col = 'revenue_season'  # non-periodic growth
    report_date_str_list = [
        '2000-3-31', '2000-6-30', '2000-9-30', '2000-12-31',
        '2001-3-31', '2001-6-30', '2001-12-31',
        '2002-6-30', '2002-12-31',
        '2003-3-31', '2003-12-31',
    ]
    frame = pd.DataFrame({
        'report_date': [str_2_date(s) for s in report_date_str_list],
        cumulative_col: [200, 400, 600, 800, np.nan, 600, 1200, 700, 1400, 400, 1600],
        plain_col: [200, 200, 200, 200, 200, 400, 600, 700, 700, 400, 400],
    })
    frame.set_index('report_date', drop=False, inplace=True)
    frame.sort_index(inplace=True)
    print(frame)

    found_col_list = check_accumulation_cols(frame)
    print("accumulation_col_name_list", found_col_list)
    assert len(found_col_list) == 1, f'{found_col_list} 长度错误'
    assert 'revenue' in found_col_list
예제 #23
0
def generate_md_with_adj_factor(instrument_type: str,
                                method: typing.Optional[Method] = Method.division):
    """
    Build continuous daily bars (main + secondary contract side by side) for
    one futures product. The bars are saved twice: unadjusted to
    ``wind_future_continuous_no_adj`` and price-adjusted to
    ``wind_future_continuous_adj``.

    The main contract is named f"{instrument_type.upper()}9999" and the
    secondary contract f"{instrument_type.upper()}8888". Only the four OHLC
    prices are adjusted; Volume and OI are stored unchanged.

    Each row of ``ifind_future_adj_factor`` (ordered by trade_date) applies to
    the daily bars after the previous row's trade_date, up to and including its
    own trade_date. Two rows are special:

    * the first row looks back 365 days;
    * the last row extends 365 days past the previous switch date, or covers
      1990-01-01 ~ 2090-12-31 when it is the only row.

    :param instrument_type: product code used to select the factor rows
    :param method: adjustment method; its ``name`` selects the factor rows
    """
    factor_table_name = 'ifind_future_adj_factor'
    # ORDER BY is required: shift(1) below relies on chronological order
    sql_str = f"""SELECT trade_date, instrument_id_main, adj_factor_main, 
        instrument_id_secondary, adj_factor_secondary
        FROM {factor_table_name}
        where instrument_type=%s and method=%s
        ORDER BY trade_date"""
    adj_factor_df = pd.read_sql(sql_str, engine_md, params=[instrument_type, method.name])
    # previous switch date: each factor row covers (trade_date_larger_than, trade_date]
    adj_factor_df['trade_date_larger_than'] = adj_factor_df['trade_date'].shift(1)
    adj_factor_df.set_index(['trade_date_larger_than', 'trade_date'], inplace=True)

    instrument_id_set = set(adj_factor_df['instrument_id_main']) | set(adj_factor_df['instrument_id_secondary'])
    in_clause_str = "'" + "', '".join(instrument_id_set) + "'"
    sql_str = f"""SELECT ths_code Contract, `time` trade_date, `Open`, `High`, `Low`, `Close`, Volume, openInterest OI
    FROM ifind_future_daily 
    where ths_code in ({in_clause_str})
    """
    daily_df = pd.read_sql(sql_str, engine_md)
    # Normalise to datetime64 so the range filter can compare against Timestamps
    daily_df['trade_date'] = pd.to_datetime(daily_df['trade_date'])

    dtype = {
        'trade_date': Date,
        'Contract': String(20),
        'ContractNext': String(20),
        'instrument_type': String(20),
        'Close': DOUBLE,
        'CloseNext': DOUBLE,
        'Volume': DOUBLE,
        'VolumeNext': DOUBLE,
        'OI': DOUBLE,
        'OINext': DOUBLE,
        'Open': DOUBLE,
        'OpenNext': DOUBLE,
        'High': DOUBLE,
        'HighNext': DOUBLE,
        'Low': DOUBLE,
        'LowNext': DOUBLE,
        'adj_factor_main': DOUBLE,
        'adj_factor_secondary': DOUBLE
    }
    price_col_list = ['Open', 'High', 'Low', 'Close']
    no_adj_table_name = 'wind_future_continuous_no_adj'
    adj_table_name = 'wind_future_continuous_adj'
    n_final = adj_factor_df.shape[0]
    data_count = 0
    for n, ((trade_date_larger_than, trade_date), adj_factor_s) in enumerate(adj_factor_df.iterrows(), start=1):
        # Window of daily bars this factor row applies to: [start, end]
        if pd.isna(trade_date_larger_than):
            start = None
        else:
            start = trade_date_larger_than + timedelta(days=1)

        if n == n_final:
            if start is None:
                start = str_2_date('1990-01-01')
                end = str_2_date('2090-12-31')
            else:
                end = start + timedelta(days=365)
        else:
            end = trade_date
            if start is None:
                start = end - timedelta(days=365)

        # end is inclusive so the switch date itself is not skipped
        is_match = (
            daily_df['trade_date'] >= pd.to_datetime(start)
        ) & (
            daily_df['trade_date'] <= pd.to_datetime(end)
        )
        # Main and secondary contract bars, joined on trade_date
        instrument_id_main = adj_factor_s['instrument_id_main']
        adj_factor_main = adj_factor_s['adj_factor_main']
        main_df = daily_df[(daily_df['Contract'] == instrument_id_main) & is_match].copy()
        instrument_id_secondary = adj_factor_s['instrument_id_secondary']
        adj_factor_secondary = adj_factor_s['adj_factor_secondary']
        sec_df = daily_df[(daily_df['Contract'] == instrument_id_secondary) & is_match].copy()
        rename_dic = {_: f"{_}Next" for _ in sec_df.columns if _ != 'trade_date'}
        sec_df.rename(columns=rename_dic, inplace=True)
        main_sec_df = pd.merge(main_df, sec_df, on='trade_date')
        main_sec_df['instrument_type'] = instrument_type
        main_sec_df['adj_factor_main'] = adj_factor_main
        main_sec_df['adj_factor_secondary'] = adj_factor_secondary
        bunch_insert_on_duplicate_update(
            main_sec_df, no_adj_table_name, engine_md, dtype=dtype,
            primary_keys=['instrument_type', 'trade_date'], schema=config.DB_SCHEMA_MD
        )

        # Adjust prices only (Volume / OI are quantities, not prices)
        for col in price_col_list:
            main_sec_df[col] *= adj_factor_main
            main_sec_df[f'{col}Next'] *= adj_factor_secondary
        bunch_insert_on_duplicate_update(
            main_sec_df, adj_table_name, engine_md, dtype=dtype,
            primary_keys=['instrument_type', 'trade_date'], schema=config.DB_SCHEMA_MD
        )
        data_count += main_sec_df.shape[0]
        logger.info("%s [%s ~ %s] 包含 %d 条数据,复权保存完成",
                    instrument_type, start, end, main_sec_df.shape[0])

    logger.info(f'{instrument_type.upper()} {data_count} 条记录 复权保存完成')
예제 #24
0
def import_coinbar_on_freq_min(freq, code_set=None, base_begin_time=None):
    """
    Fetch tushare coin bars of an intraday frequency (60min, 30min, 15min,
    5min, 1min) into ``tushare_coin_md_{freq}``.

    :param freq: bar frequency passed to ``pro.coinbar``. It is also used as
        the table-name suffix and as the suffix of the ``trade_date_latest_*`` /
        ``delist_date_*`` columns of ``tushare_coin_pair_info``.
    :param code_set: optional set of (exchange, exchange_pair) tuples that
        restricts which pairs are updated
    :param base_begin_time: optional date (or date string). When given, each
        pair's start date is the earlier of its computed start date and this date.
    :return: None
    """
    if base_begin_time is not None and not isinstance(base_begin_time, date):
        base_begin_time = str_2_date(base_begin_time)
    table_name = 'tushare_coin_md_' + freq
    info_table_name = 'tushare_coin_pair_info'
    has_table = engine_md.has_table(table_name)
    if has_table:
        sql_str = """SELECT exchange, exchange_pair, date_frm, 
                if(delist_date<end_date, delist_date, end_date) date_to
            FROM
            (
                SELECT info.exchange, info.exchange_pair, 
                    ifnull(trade_date_max_1, adddate(trade_date_latest,1)) date_frm, 
                    delist_date,
                    if(hour(now())<8, subdate(curdate(),2), subdate(curdate(),1)) end_date
                FROM 
                (
                    select exchange, exchange_pair,
                    ifnull(trade_date_latest_{freq},'2010-01-01') trade_date_latest,
                    delist_date_{freq} delist_date
                    from {info_table_name}
                ) info
                LEFT OUTER JOIN
                    (SELECT exchange, symbol, adddate(max(`date`),1) trade_date_max_1 
                     FROM {table_name} GROUP BY exchange, symbol) daily
                ON info.exchange = daily.exchange
                AND info.exchange_pair = daily.symbol
            ) tt
            WHERE date_frm <= if(delist_date<end_date, delist_date, end_date) 
            ORDER BY exchange, exchange_pair""".format(
            table_name=table_name, info_table_name=info_table_name, freq=freq)
    else:
        sql_str = """SELECT exchange, exchange_pair, date_frm, 
            if(delist_date<end_date, delist_date, end_date) date_to
        FROM
        (
            SELECT exchange, exchange_pair, 
                ifnull(trade_date_latest_{freq},date('2010-01-01')) date_frm, 
                delist_date_{freq} delist_date, 
                if(hour(now())<8, subdate(curdate(),2), subdate(curdate(),1)) end_date
            FROM {info_table_name} info 
            ORDER BY exchange, exchange_pair
        ) tt
        WHERE date_frm <= if(delist_date<end_date, delist_date, end_date) 
        ORDER BY exchange, exchange_pair""".format(
            info_table_name=info_table_name, freq=freq)
        logger.warning('%s 不存在,仅使用 %s 表进行计算日期范围', table_name, info_table_name)

    with with_db_session(engine_md) as session:
        # Date range to fetch for every (exchange, pair)
        table = session.execute(sql_str)
        code_date_range_dic = {
            (exchange, symbol): (date_from if base_begin_time is None else min(
                [date_from, base_begin_time]), date_to)
            for exchange, symbol, date_from, date_to in table.fetchall()
            if code_set is None or (exchange, symbol) in code_set
        }

    # Column types
    dtype = {
        'exchange': String(60),
        'symbol': String(60),
        'date': Date,
        'datetime': DateTime,
        'open': DOUBLE,
        'high': DOUBLE,
        'low': DOUBLE,
        'close': DOUBLE,
        'vol': DOUBLE,
    }

    # Progress is tracked in the same freq-specific column the range query reads
    trade_date_latest_list = []
    update_trade_date_latest_str = f"""UPDATE {info_table_name} info
        SET info.trade_date_latest_{freq} = :trade_date_latest
        WHERE info.exchange = :exchange AND exchange_pair=:exchange_pair"""

    data_df_list, data_count, tot_data_count, code_count = [], 0, 0, len(
        code_date_range_dic)
    try:
        for num, ((exchange, exchange_pair),
                  (begin_time,
                   end_time)) in enumerate(code_date_range_dic.items(),
                                           start=1):
            begin_time_str = date_2_str(begin_time, DATE_FORMAT_STR)
            end_time_str = date_2_str(end_time, DATE_FORMAT_STR)
            logger.debug('%d/%d) %s %s [%s - %s]', num, code_count, exchange,
                         exchange_pair, begin_time, end_time)
            try:
                # e.g. pro.coinbar(exchange='huobi', symbol='gxsbtc', freq='1min', start_date='20180701', end_date='20180801')
                data_df = pro.coinbar(exchange=exchange,
                                      symbol=exchange_pair,
                                      freq=freq,
                                      start_date=begin_time_str,
                                      end_date=end_time_str)
            except Exception as exp:
                if len(exp.args) >= 1 and exp.args[0] == '系统内部错误':
                    # Internal server error: record a placeholder latest date so
                    # the pair is not retried from the same start next time.
                    # NOTE(review): '2020-02-02' looks like a hand-picked placeholder — confirm.
                    trade_date_latest_list.append({
                        'exchange': exchange,
                        'exchange_pair': exchange_pair,
                        'trade_date_latest': '2020-02-02',
                    })
                    logger.warning(
                        "coinbar(exchange='%s', symbol='%s', freq='%s', start_date='%s', end_date='%s') 系统内部错误",
                        exchange, exchange_pair, freq, begin_time_str,
                        end_time_str)
                    continue
                logger.exception(
                    "coinbar(exchange='%s', symbol='%s', freq='%s', start_date='%s', end_date='%s')",
                    exchange, exchange_pair, freq, begin_time_str,
                    end_time_str)
                raise

            if data_df is not None and data_df.shape[0] > 0:
                data_count += data_df.shape[0]
                data_df['exchange'] = exchange
                data_df['datetime'] = data_df['date']
                data_df['date'] = data_df['date'].apply(
                    lambda x: str_2_datetime(x).date())
                data_df_list.append(data_df)

            # Record the new latest date of this pair
            trade_date_latest_list.append({
                'exchange': exchange,
                'exchange_pair': exchange_pair,
                'trade_date_latest': end_time_str,
            })
            # Flush to the database once enough rows are buffered
            if data_count >= 10000:
                data_df_all = pd.concat(data_df_list)
                data_count = bunch_insert_on_duplicate_update(
                    data_df_all, table_name, engine_md, dtype)
                tot_data_count += data_count
                data_df_list, data_count = [], 0

                with with_db_session(engine_md) as session:
                    result = session.execute(update_trade_date_latest_str,
                                             params=trade_date_latest_list)
                    update_count = result.rowcount
                    session.commit()
                    logger.info('更新 %d 条交易对的最新交易 %s 信息', update_count, freq)
                trade_date_latest_list = []

            # Debug only
            if DEBUG and len(data_df_list) > 1:
                break
    finally:
        # Save whatever is still buffered, even when the loop raised
        if data_count > 0:
            data_df_all = pd.concat(data_df_list)
            data_count = bunch_insert_on_duplicate_update(
                data_df_all, table_name, engine_md, dtype)
            tot_data_count += data_count

        if len(trade_date_latest_list) > 0:
            with with_db_session(engine_md) as session:
                result = session.execute(update_trade_date_latest_str,
                                         params=trade_date_latest_list)
                update_count = result.rowcount
                session.commit()
                logger.info('更新 %d 条交易对的最新交易日信息', update_count)

        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            create_pk_str = """ALTER TABLE {table_name}
                CHANGE COLUMN `exchange` `exchange` VARCHAR(60) NOT NULL FIRST,
                CHANGE COLUMN `symbol` `symbol` VARCHAR(60) NOT NULL AFTER `exchange`,
                CHANGE COLUMN `datetime` `datetime` DATETIME NOT NULL AFTER `symbol`,
                ADD PRIMARY KEY (`exchange`, `symbol`, `datetime`)""".format(
                table_name=table_name)
            with with_db_session(engine_md) as session:
                session.execute(create_pk_str)

        logger.info("更新 %s 完成 新增数据 %d 条", table_name, tot_data_count)
예제 #25
0
def import_index_daily_his(chain_param=None,
                           ths_code_set: set = None,
                           begin_time=None):
    """
    Save index daily history into ifind_index_daily_his via the iFinD history interface.

    :param chain_param: only used to pass parameters between tasks when run inside task.chain
    :param ths_code_set: optional set of ths_code to update; None means every index in ifind_index_info
    :param begin_time: None by default; when not None, no code's update start date will be later
        than this date (the start date becomes min(computed start date, begin_time))
    :return: None
    """
    table_name = 'ifind_index_daily_his'
    # Normalise begin_time to a date so it can be compared with the dates returned by SQL.
    # (The original check was inverted and only "converted" values that were already dates.)
    if begin_time is not None and type(begin_time) != date:
        begin_time = str_2_date(begin_time)
    # Example call:
    # THS_HistoryQuotes('600006.SH,600010.SH',
    # 'preClose,open,high,low,close,avgPrice,changeRatio,volume,amount,turnoverRatio,transactionAmount,totalShares,totalCapital,floatSharesOfAShares,floatSharesOfBShares,floatCapitalOfAShares,floatCapitalOfBShares,pe_ttm,pe,pb,ps,pcf',
    # 'Interval:D,CPS:1,baseDate:1900-01-01,Currency:YSHB,fill:Previous',
    # '2018-06-30','2018-07-30')
    json_indicator, _ = unzip_join(
        [(key, val) for key, val, _ in INDICATOR_PARAM_LIST_INDEX_DAILY_HIS],
        sep=';')
    has_table = engine_md.has_table(table_name)
    # NOTE: `if(NULL<end_date, NULL, end_date)` always evaluates to end_date (NULL comparison is
    # not true); it is kept as-is from the template these queries were copied from.
    if has_table:
        # resume each index from the day after its latest stored row
        sql_str = """SELECT ths_code, date_frm, if(NULL<end_date, NULL, end_date) date_to
            FROM
            (
                SELECT info.ths_code, ifnull(trade_date_max_1, ths_index_base_period_index) date_frm, NULL,
                if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM 
                    ifind_index_info info 
                LEFT OUTER JOIN
                    (SELECT ths_code, adddate(max(time),1) trade_date_max_1 FROM ifind_index_daily_his GROUP BY ths_code) daily
                ON info.ths_code = daily.ths_code
            ) tt
            WHERE date_frm <= if(NULL<end_date, NULL, end_date) 
            ORDER BY ths_code;"""
    else:
        logger.warning('%s 不存在,仅使用 ifind_index_info 表进行计算日期范围', table_name)
        sql_str = """SELECT ths_code, date_frm, if(NULL<end_date, NULL, end_date) date_to
            FROM
            (
                SELECT info.ths_code, ths_index_base_period_index date_frm, NULL,
                if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM ifind_index_info info 
            ) tt
            WHERE date_frm <= if(NULL<end_date, NULL, end_date) 
            ORDER BY ths_code"""

    # Compute the [date_from, date_to] range to download for each index
    with with_db_session(engine_md) as session:
        table = session.execute(sql_str)
        code_date_range_dic = {
            ths_code:
            (date_from if begin_time is None else min([date_from, begin_time]),
             date_to)
            for ths_code, date_from, date_to in table.fetchall()
            if ths_code_set is None or ths_code in ths_code_set
        }

    if TRIAL:
        # trial accounts can only fetch the most recent 5 years of data
        date_from_min = date.today() - timedelta(days=(365 * 5))
        code_date_range_dic = {
            ths_code: (max([date_from, date_from_min]), date_to)
            for ths_code, (date_from, date_to) in code_date_range_dic.items()
            if date_to is not None and date_from_min <= date_to
        }

    data_df_list, data_count, tot_data_count, code_count = [], 0, 0, len(
        code_date_range_dic)
    try:
        for num, (ths_code, (date_from_fetch, date_to_fetch)) in enumerate(
                code_date_range_dic.items(), start=1):
            logger.debug('%d/%d) %s [%s - %s]', num, code_count, ths_code,
                         date_from_fetch, date_to_fetch)
            data_df = invoker.THS_HistoryQuotes(
                ths_code, json_indicator,
                'Interval:D,CPS:1,baseDate:1900-01-01,Currency:YSHB,fill:Previous',
                date_from_fetch, date_to_fetch)
            if data_df is not None and data_df.shape[0] > 0:
                data_count += data_df.shape[0]
                data_df_list.append(data_df)
            # Flush once the buffer reaches the threshold. Insert ALL buffered frames;
            # inserting only the latest `data_df` would drop the earlier ones.
            if data_count >= 10000:
                data_df_all = pd.concat(data_df_list)
                data_count = bunch_insert_on_duplicate_update(
                    data_df_all, table_name, engine_md, DTYPE_INDEX_DAILY_HIS)
                tot_data_count += data_count
                data_df_list, data_count = [], 0

            # debug only
            if DEBUG and len(data_df_list) > 5:
                break
    finally:
        # flush whatever is still buffered, even if the loop raised
        if data_count > 0:
            data_df_all = pd.concat(data_df_list)
            data_count = bunch_insert_on_duplicate_update(
                data_df_all, table_name, engine_md, DTYPE_INDEX_DAILY_HIS)
            tot_data_count += data_count

        logger.info("更新 %s 完成 新增数据 %d 条", table_name, tot_data_count)
        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            build_primary_key([table_name])


def import_private_fund_info(table_name, chain_param=None, get_df=False):
    """
    Download private-fund basic info from Wind and upsert it into ``table_name``.

    Constituents of every Wind private-fund strategy sector are fetched via
    ``wset("sectorconstituent")`` and then enriched in batches via ``wss``.

    :param table_name: target table name (e.g. 'fund_info')
    :param chain_param: only used to pass parameters between tasks inside a celery task.chain
    :param get_df: when True, return the DataFrame that was written
    :return: the written DataFrame when ``get_df`` is True, otherwise None
    """
    has_table = engine_md.has_table(table_name)
    # Wind sector id of each private-fund strategy category
    types = {
        u'股票多头策略': 1000023122000000,
        u'股票多空策略': 1000023123000000,
        u'其他股票策略': 1000023124000000,
        u'阿尔法策略': 1000023125000000,
        u'其他市场中性策略': 1000023126000000,
        u'事件驱动策略': 1000023113000000,
        u'债券策略': 1000023114000000,
        u'套利策略': 1000023115000000,
        u'宏观策略': 1000023116000000,
        u'管理期货': 1000023117000000,
        u'组合基金策略': 1000023118000000,
        u'货币市场策略': 1000023119000000,
        # was 100002312000000 (one digit short); the id sequence implies ...120000000
        u'多策略': 1000023120000000,
        u'其他策略': 1000023121000000
    }
    today = date.today().strftime('%Y-%m-%d')
    param_list = [
        ('FUND_SETUPDATE', Date),
        ('FUND_MATURITYDATE', Date),
        ('FUND_MGRCOMP', String(800)),
        ('FUND_EXISTINGYEAR', String(500)),
        ('FUND_PTMYEAR', String(30)),
        ('FUND_TYPE', String(20)),
        ('FUND_FUNDMANAGER', String(500)),
        ('SEC_NAME', String(2000)),
        ('STRATEGY_TYPE', String(200)),
        ('TRADE_DATE_LATEST', String(200)),
    ]
    # wss returns upper-case column names; the table uses lower-case ones
    col_name_dic = {
        col_name.upper(): col_name.lower()
        for col_name, _ in param_list
    }
    col_name_list = [col_name.lower() for col_name in col_name_dic.keys()]
    # only the first 8 fields come from wss; strategy_type / trade_date_latest are filled locally
    param_str = ",".join(col_name_list[:8])
    dtype = {key.lower(): val for key, val in param_list}
    dtype['wind_code'] = String(20)

    # Collect the constituents of every strategy sector
    sector_df_list = []
    sector_row_count = 0
    for strategy_type, sector_id in types.items():
        temp = invoker.wset("sectorconstituent",
                            "date=%s;sectorid=%s" % (today, str(sector_id)))
        temp['strategy_type'] = strategy_type
        sector_df_list.append(temp)
        sector_row_count += len(temp)
        if DEBUG and sector_row_count > 1000:
            break
    df = pd.concat(sector_df_list, axis=0)

    fund_types_df = df[['wind_code', 'sec_name', 'strategy_type']].set_index('wind_code')
    code_list = list(fund_types_df.index)
    code_count = len(code_list)
    if code_count == 0:
        logger.warning('%s: no fund found in any strategy sector, nothing to import', table_name)
        return pd.DataFrame() if get_df else None

    # Fetch basic info in segments of seg_count codes per wss call
    seg_count = 5000
    info_df_list = []
    info_row_count = 0
    for n, num_start in enumerate(range(0, code_count, seg_count)):
        num_end = min(num_start + seg_count, code_count)
        codes = ','.join(code_list[num_start:num_end])
        info2_df = invoker.wss(codes, param_str)
        logger.info('%05d ) [%d %d]', n, num_start, num_end)
        info_df_list.append(info2_df)
        info_row_count += len(info2_df)
        if DEBUG and info_row_count > 1000:
            break
    info_df = pd.concat(info_df_list)

    info_df['FUND_SETUPDATE'] = info_df['FUND_SETUPDATE'].apply(
        lambda x: str_2_date(x))
    info_df['FUND_MATURITYDATE'] = info_df['FUND_MATURITYDATE'].apply(
        lambda x: str_2_date(x))
    info_df = fund_types_df.join(info_df, how='right')
    info_df.rename(columns=col_name_dic, inplace=True)
    # wset already provides 'sec_name'; renaming wss's SEC_NAME creates a duplicate column
    # name. Keep the first occurrence so the insert does not see duplicate columns.
    info_df = info_df.loc[:, ~info_df.columns.duplicated()].copy()
    info_df['trade_date_latest'] = None
    info_df.index.names = ['wind_code']
    info_df.reset_index(inplace=True)
    info_df.drop_duplicates(inplace=True)
    bunch_insert_on_duplicate_update(info_df,
                                     table_name,
                                     engine_md,
                                     dtype=dtype)
    logger.info('%d funds inserted', len(info_df))
    if not has_table and engine_md.has_table(table_name):
        alter_table_2_myisam(engine_md, [table_name])
        build_primary_key([table_name])

    # update the code_mapping table
    update_from_info_table(table_name)
    if get_df:
        return info_df
예제 #27
0
def import_smfund_daily(chain_param=None):
    """
    Import daily data of Wind structured (graded) funds into ``wind_smfund_daily``.

    For each parent fund in ``wind_smfund_info`` that has a setup date and both A/B share
    codes, this function does the following:

    - downloads OHLCV for the parent, A and B shares via ``wsd``;
    - downloads per-day leveraged-fund info via ``wset("leveragedfundinfo")``;
    - joins them by date and upserts the result.

    :param chain_param: only used to pass the previous task's result along a celery chain
    :return: None
    """
    table_name = "wind_smfund_daily"
    has_table = engine_md.has_table(table_name)
    col_name_param_list = [
        ('next_pcvdate', Date),
        ('a_nav', DOUBLE),
        ('b_nav', DOUBLE),
        ('a_fs_inc', DOUBLE),
        ('b_fs_inc', DOUBLE),
        ('cur_interest', DOUBLE),
        ('next_interest', DOUBLE),
        ('ptm_year', DOUBLE),
        ('anal_pricelever', DOUBLE),
        ('anal_navlevel', DOUBLE),
        ('t1_premium', DOUBLE),
        ('t2_premium', DOUBLE),
        ('dq_status', String(50)),
        ('tm_type', TEXT),
        ('code_p', String(20)),
        ('trade_date', Date),
        ('open', DOUBLE),
        ('high', DOUBLE),
        ('low', DOUBLE),
        ('close', DOUBLE),
        ('volume', DOUBLE),
        ('amt', DOUBLE),
        ('pct_chg', DOUBLE),
        ('open_a', DOUBLE),
        ('high_a', DOUBLE),
        ('low_a', DOUBLE),
        ('close_a', DOUBLE),
        ('volume_a', DOUBLE),
        ('amt_a', DOUBLE),
        ('pct_chg_a', DOUBLE),
        ('open_b', DOUBLE),
        ('high_b', DOUBLE),
        ('low_b', DOUBLE),
        ('close_b', DOUBLE),
        ('volume_b', DOUBLE),
        ('amt_b', DOUBLE),
        ('pct_chg_b', DOUBLE),
    ]
    # the first 14 columns are the fields requested from wset("leveragedfundinfo")
    wind_indictor_str = ",".join([key for key, _ in col_name_param_list[:14]])
    dtype = {key: val for key, val in col_name_param_list}

    # before BASE_LINE_HOUR today's data is not considered available yet, so stop at yesterday
    date_ending = date.today() - ONE_DAY if datetime.now(
    ).hour < BASE_LINE_HOUR else date.today()
    date_ending_str = date_ending.strftime('%Y-%m-%d')
    if has_table:
        # resume each fund from the day after its latest stored trade_date
        # (previously `ifnull(date, ...)` ignored the computed trade_date_max)
        sql_str = """
            SELECT wind_code, ifnull(trade_date_max, fund_setupdate) date_start, class_a_code, class_b_code
            FROM wind_smfund_info fi LEFT OUTER JOIN
            (SELECT code_p, adddate(max(trade_date), 1) trade_date_max FROM wind_smfund_daily GROUP BY code_p) smd
            ON fi.wind_code = smd.code_p
            WHERE fund_setupdate IS NOT NULL
            AND class_a_code IS NOT NULL
            AND class_b_code IS NOT NULL"""
    else:
        # NOTE(review): relies on wind_smfund_info having a `date` column — confirm
        sql_str = """
            SELECT wind_code, ifnull(date, fund_setupdate) date_start, class_a_code, class_b_code
            FROM wind_smfund_info
            WHERE fund_setupdate IS NOT NULL
            AND class_a_code IS NOT NULL
            AND class_b_code IS NOT NULL"""
    df = pd.read_sql(sql_str, engine_md)
    df.set_index('wind_code', inplace=True)

    data_len = df.shape[0]
    logger.info('分级基金数量: %d', data_len)
    index_start = 1  # adjust to resume from a given position
    # fields requested from wsd for the parent, A and B shares
    wsd_field = "open,high,low,close,volume,amt,pct_chg"
    for data_num, wind_code in enumerate(df.index, start=1):
        if data_num < index_start:
            continue
        logger.info('%d/%d) %s start to import', data_num, data_len, wind_code)
        date_from = str_2_date(df.loc[wind_code, 'date_start'])
        if type(date_from) not in (date, datetime, Timestamp):
            logger.info('%d/%d) %s has no fund_setupdate will be ignored',
                        data_num, data_len, wind_code)
            continue
        date_from_str = date_from.strftime('%Y-%m-%d')
        if date_from > date_ending:
            logger.info('%d/%d) %s %s %s 跳过', data_num, data_len, wind_code,
                        date_from_str, date_ending_str)
            continue
        try:
            df_p = invoker.wsd(wind_code, wsd_field, date_from_str,
                               date_ending_str, "")
        except APIError as exp:
            logger.exception("%d/%d) %s 执行异常", data_num, data_len, wind_code)
            if exp.ret_dic.setdefault('error_code', 0) in (
                    -40520007,  # no data available
                    -40521009,  # data decode failed; check the input params (e.g. month-end dates)
            ):
                continue
            else:
                break
        if df_p is None:
            continue
        df_p.rename(columns=lambda x: x.swapcase(), inplace=True)
        df_p['code_p'] = wind_code
        code_a = df.loc[wind_code, 'class_a_code']
        if code_a is None:
            logger.warning('%d %s has no code_a will be ignored', data_num, wind_code)
            continue
        # A-share quotes, columns suffixed with _a
        df_a = invoker.wsd(code_a, wsd_field, date_from_str, date_ending_str, "")
        if df_a is None:
            logger.warning('%d/%d) %s class A %s has no data, ignored',
                           data_num, data_len, wind_code, code_a)
            continue
        df_a.rename(columns=lambda x: x.swapcase() + '_a', inplace=True)
        # B-share quotes, columns suffixed with _b
        code_b = df.loc[wind_code, 'class_b_code']
        df_b = invoker.wsd(code_b, wsd_field, date_from_str, date_ending_str, "")
        if df_b is None:
            logger.warning('%d/%d) %s class B %s has no data, ignored',
                           data_num, data_len, wind_code, code_b)
            continue
        df_b.columns = df_b.columns.map(lambda x: x.swapcase() + '_b')

        # per-day leveraged-fund info for every date returned for the parent fund
        info_df_list = []
        info_row_count = 0
        for date_str in df_p.index:
            wset_options = "date=%s;windcode=%s;field=%s" % (date_str, wind_code,
                                                             wind_indictor_str)
            temp = invoker.wset("leveragedfundinfo", wset_options)
            temp['date'] = date_str
            info_df_list.append(temp)
            info_row_count += len(temp)
            if DEBUG and info_row_count > 8:
                break
        if not info_df_list:
            logger.warning('%d/%d) %s has no trading days in range, ignored',
                           data_num, data_len, wind_code)
            continue
        new_df = pd.concat(info_df_list)
        new_df['next_pcvdate'] = new_df['next_pcvdate'].map(
            lambda x: str_2_date(x) if x is not None else x)
        new_df.set_index('date', inplace=True)
        # join parent, A, B and leveraged-fund info on the date index
        one_df = pd.concat([df_p, df_a, df_b, new_df], axis=1)
        one_df.index.rename('trade_date', inplace=True)
        one_df.reset_index(inplace=True)
        one_df.rename(columns={'date': 'trade_date'}, inplace=True)
        bunch_insert_on_duplicate_update(one_df,
                                         table_name,
                                         engine_md,
                                         dtype=dtype)
        logger.info('%d/%d) %s import success', data_num, data_len, wind_code)
        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            # create the primary key manually: it is (code_p, trade_date), not wind_code
            create_pk_str = """ALTER TABLE {table_name}
                CHANGE COLUMN `code_p` `code_p` VARCHAR(20) NOT NULL ,
                CHANGE COLUMN `trade_date` `trade_date` DATE NOT NULL ,
                ADD PRIMARY KEY (`code_p`, `trade_date`)""".format(
                table_name=table_name)
            with with_db_session(engine_md) as session:
                session.execute(create_pk_str)
            # mark the table as existing so the primary key is not added again next iteration
            has_table = True


def import_private_fund_nav_daily(chain_param=None, wind_code_list=None):
    """
    Download daily NAV data of private funds from Wind into ``wind_fund_nav_daily``.

    :param chain_param: only used to pass the previous task's result along a celery chain
    :param wind_code_list: optional subset of wind codes to update; codes absent from
        ``fund_info`` are ignored. None means every fund in ``fund_info``.
    :return: an (always empty) list, kept for backward compatibility; data goes to the database
    """
    table_name = 'wind_fund_nav_daily'
    # target table columns: wind_code, trade_date, nav, nav_acc, nav_date
    has_table = engine_md.has_table(table_name)
    if has_table:
        # Resume each fund from the day after its latest stored NAV when one exists;
        # otherwise fall back to the info-table based start date.
        fund_info_df = pd.read_sql_query(
            """SELECT DISTINCT fi.wind_code AS wind_code, 
                IFNULL(wfn.trade_date_from,
                    IFNULL(fund_setupdate, if(trade_date_latest BETWEEN '1900-01-01' AND ADDDATE(CURDATE(), -1), ADDDATE(trade_date_latest,1) , fund_setupdate) )
                ) date_from,
                if(fund_maturitydate BETWEEN '1900-01-01' AND ADDDATE(CURDATE(), -1),fund_maturitydate,ADDDATE(CURDATE(), -1)) date_to 
                FROM fund_info fi
                LEFT JOIN
                (
                SELECT wind_code, ADDDATE(max(trade_date),1) trade_date_from FROM wind_fund_nav_daily
                GROUP BY wind_code
                ) wfn
                ON fi.wind_code = wfn.wind_code""", engine_md)
    else:
        fund_info_df = pd.read_sql_query(
            """SELECT DISTINCT fi.wind_code AS wind_code, 
                IFNULL(fund_setupdate, if(trade_date_latest BETWEEN '1900-01-01' AND ADDDATE(CURDATE(), -1), ADDDATE(trade_date_latest,1) , fund_setupdate) ) date_from,
                if(fund_maturitydate BETWEEN '1900-01-01' AND ADDDATE(CURDATE(), -1),fund_maturitydate,ADDDATE(CURDATE(), -1)) date_to 
                FROM fund_info fi
                ORDER BY wind_code;""", engine_md)

    # NOTE: everything below used to be indented under `else:`, so nothing was imported
    # once the table existed.
    wind_code_date_frm_to_dic = {
        wind_code: (str_2_date(date_from), str_2_date(date_to))
        for wind_code, date_from, date_to in
        zip(fund_info_df['wind_code'], fund_info_df['date_from'],
            fund_info_df['date_to'])
    }
    fund_info_df.set_index('wind_code', inplace=True)
    if wind_code_list is None:
        wind_code_list = list(fund_info_df.index)
    else:
        wind_code_list = list(
            set(wind_code_list) & set(fund_info_df.index))

    fund_nav_all_df = []
    no_data_count = 0
    code_count = len(wind_code_list)
    # wind_code -> trade_date_latest to write back to the info table afterwards
    wind_code_trade_date_latest_dic = {}
    # when a fetch returns nothing new, roll trade_date_latest back by this gap
    date_gap = timedelta(days=10)
    try:
        for num, wind_code in enumerate(wind_code_list, start=1):
            date_begin, date_end = wind_code_date_frm_to_dic[wind_code]

            # validate types before comparing: a missing date would make `>` raise TypeError
            if not isinstance(date_begin, date):
                logger.error("%s date_begin:%s", wind_code, date_begin)
                continue
            if not isinstance(date_end, date):
                logger.error("%s date_end:%s", wind_code, date_end)
                continue
            if date_begin.year < 1900 or date_begin > date_end:
                continue
            date_begin_str = date_begin.strftime('%Y-%m-%d')
            date_end_str = date_end.strftime('%Y-%m-%d')
            rollback_date = datetime.strptime(date_end_str, '%Y-%m-%d').date() - date_gap

            # try fetching fund NAV data, at most twice
            fund_nav_tmp_df = None
            for _ in range(2):
                try:
                    fund_nav_tmp_df = invoker.wsd(
                        codes=wind_code,
                        fields='nav,NAV_acc,NAV_date',
                        beginTime=date_begin_str,
                        endTime=date_end_str,
                        options='Fill=Previous')
                except APIError as exp:
                    # -40520007: no data available
                    if exp.ret_dic.setdefault('error_code', 0) == -40520007:
                        wind_code_trade_date_latest_dic[wind_code] = rollback_date
                    logger.error("%s Failed, ErrorMsg: %s", wind_code, exp)
                    continue
                except Exception as exp:
                    logger.error("%s Failed, ErrorMsg: %s", wind_code, exp)
                    continue
                wind_code_trade_date_latest_dic[wind_code] = rollback_date
                break

            if fund_nav_tmp_df is None:
                logger.info('%s No data', wind_code)
                no_data_count += 1
                logger.warning('%d funds no data', no_data_count)
            else:
                fund_nav_tmp_df.dropna(how='all', inplace=True)
                if fund_nav_tmp_df.shape[0] == 0:
                    continue
                fund_nav_tmp_df['wind_code'] = wind_code
                trade_date_latest = fund_nav_df_2_sql(table_name,
                                                      fund_nav_tmp_df,
                                                      engine_md,
                                                      is_append=True)
                if trade_date_latest is None:
                    logger.error('%d) %s data insert failed', num, wind_code)
                else:
                    wind_code_trade_date_latest_dic[wind_code] = trade_date_latest
                    logger.info('%d) %s updated, %d funds left', num,
                                wind_code, code_count - num)

            if DEBUG and num > 2:  # debug only: stop after 3 funds
                break
    finally:
        import_wind_fund_nav_to_nav()
        update_trade_date_latest(wind_code_trade_date_latest_dic)
        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            build_primary_key([table_name])
    return fund_nav_all_df
예제 #29
0
    def load_history(self, date_from=None, date_to=None, load_md_count=None) -> (pd.DataFrame, dict):
        """
        Load historical market data from the CSV file at ``self.file_path``.

        :param date_from: inclusive lower bound; None means use ``self.init_md_date_from``
        :param date_to: inclusive upper bound; None means use ``self.init_md_date_to``
        :param load_md_count: documented as a row limit (0 = unlimited, None = use
            ``init_load_md_count``). NOTE(review): this method does not apply it.
        :return: ret_data {
            'md_df': md_df (None when no history is loaded), 'datetime_key': ...,
            'date_key': ..., 'time_key': ..., 'microseconds_key': ...,
            'symbol_key': ..., 'close_key': 'close'
            }
        """
        # When both init_md_date_from and init_md_date_to are None, no history is loaded
        if self.init_md_date_from is None and self.init_md_date_to is None:
            ret_data = {'md_df': None, 'datetime_key': self.datetime_key, 'date_key': self.date_key,
                        'time_key': self.time_key, 'microseconds_key': self.microseconds_key,
                        'symbol_key': self.symbol_key, 'close_key': 'close'}
            return ret_data

        # load history
        self.logger.debug("加载历史数据 %s", self.file_path)
        date_col_name_set = {self.datetime_key, self.date_key, self.timestamp_key}
        date_col_name_set.discard(None)
        md_df = pd.read_csv(self.file_path, parse_dates=list(date_col_name_set))
        if self.ffill_on_load_history:
            md_df = md_df.ffill()
        # pick the column used for date filtering, by priority
        if self.timestamp_key is not None:
            timestamp_key = self.timestamp_key
        elif self.datetime_key is not None:
            timestamp_key = self.datetime_key
        elif self.date_key is not None:
            timestamp_key = self.date_key
        else:
            timestamp_key = None
            self.logger.warning('没有设置 timestamp_key、datetime_key、date_key 中的任何一个,无法进行日期过滤')

        # Filter by date range. Convert only non-None bounds: pd.Timestamp(None) is NaT, and
        # comparing against NaT is always False, which used to empty the result whenever
        # only one of init_md_date_from / init_md_date_to was set.
        filter_mark = None
        if timestamp_key is not None:
            if date_from is None:
                date_from = self.init_md_date_from
            if date_from is not None:
                date_from = pd.Timestamp(str_2_date(date_from))
                filter_mark = md_df[timestamp_key] >= date_from

            if date_to is None:
                date_to = self.init_md_date_to
            if date_to is not None:
                date_to = pd.Timestamp(str_2_date(date_to))
                if filter_mark is None:
                    filter_mark = md_df[timestamp_key] <= date_to
                else:
                    filter_mark &= md_df[timestamp_key] <= date_to

        ret_df = md_df[filter_mark] if filter_mark is not None else md_df

        ret_data = {'md_df': ret_df, 'datetime_key': self.datetime_key, 'date_key': self.date_key,
                    'time_key': self.time_key, 'microseconds_key': self.microseconds_key,
                    'symbol_key': self.symbol_key, 'close_key': 'close'}
        return ret_data
예제 #30
0
def import_pub_fund_daily(chain_param=None,
                          ths_code_set: set = None,
                          begin_time=None):
    """
    Fetch daily NAV history for public funds through ``THS_HistoryQuotes`` and
    save it into the ``ifind_pub_fund_daily`` table.

    The per-fund date range is computed in SQL:

    * start: the day after the latest ``time`` already stored for that fund, or
      ``ths_lof_listed_date_fund`` when nothing is stored yet (or the table does
      not exist yet);
    * end: the earlier of ``ths_fund_expiry_date_fund`` and ``end_date``, where
      ``end_date`` is today when the current hour is 16 or later, otherwise
      yesterday.

    Funds whose start is after their end are skipped by the query.

    :param chain_param: only used to pass values between tasks in a
        ``task.chain`` pipeline; not used by this function.
    :param ths_code_set: optional set of ``ths_code`` values; when given, only
        these funds are updated.
    :param begin_time: optional ``date`` (or a value ``str_2_date`` accepts, e.g.
        a date string). When given, each fund's start date is moved back to be
        no later than this date.
    :return: None
    """
    table_name = 'ifind_pub_fund_daily'
    has_table = engine_md.has_table(table_name)
    # Normalize begin_time to a plain ``date`` so it can be compared with the
    # DATE values returned by the query below. An exact type check is used on
    # purpose: ``datetime`` is a subclass of ``date`` but cannot be compared
    # with it, so it must also go through str_2_date.
    if begin_time is not None and type(begin_time) is not date:
        begin_time = str_2_date(begin_time)

    # (indicator name, indicator parameter, SQLAlchemy column type)
    indicator_param_list = [('netAssetValue', '', DOUBLE),
                            ('adjustedNAV', '', DOUBLE),
                            ('accumulatedNAV', '', DOUBLE)]
    # THS_HistoryQuotes(codes, 'ind1,ind2,...', '<options>', 'yyyy-mm-dd', 'yyyy-mm-dd')
    json_indicator, _ = unzip_join([(key, val)
                                    for key, val, _ in indicator_param_list],
                                   sep=';')
    if has_table:
        # Resume each fund from the day after its latest stored record.
        sql_str = """SELECT ths_code, date_frm, if(ths_fund_expiry_date_fund<end_date, ths_fund_expiry_date_fund, end_date) date_to
            FROM
            (
                SELECT info.ths_code, ifnull(trade_date_max_1, ths_lof_listed_date_fund) date_frm, ths_fund_expiry_date_fund,
                if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM 
                    ifind_pub_fund_info info 
                LEFT OUTER JOIN
                    (SELECT ths_code, adddate(max(time),1) trade_date_max_1 FROM {table_name} GROUP BY ths_code) daily
                ON info.ths_code = daily.ths_code
            ) tt
            WHERE date_frm <= if(ths_fund_expiry_date_fund<end_date, ths_fund_expiry_date_fund, end_date) 
            ORDER BY ths_code""".format(table_name=table_name)
    else:
        logger.warning('%s 不存在,仅使用 ifind_pub_fund_info 表进行计算日期范围', table_name)
        # Target table missing: every fund starts from its listing date.
        sql_str = """SELECT ths_code, date_frm, if(ths_fund_expiry_date_fund<end_date, ths_fund_expiry_date_fund, end_date) date_to
            FROM
            (
                SELECT info.ths_code, ths_lof_listed_date_fund date_frm, ths_fund_expiry_date_fund,
                if(hour(now())<16, subdate(curdate(),1), curdate()) end_date
                FROM ifind_pub_fund_info info 
            ) tt
            WHERE date_frm <= if(ths_fund_expiry_date_fund<end_date, ths_fund_expiry_date_fund, end_date) 
            ORDER BY ths_code"""

    # Compute the [date_from, date_to] range to download for each fund.
    with with_db_session(engine_md) as session:
        table = session.execute(sql_str)
        code_date_range_dic = {
            ths_code: (date_from if begin_time is None else min(date_from, begin_time),
                       date_to)
            for ths_code, date_from, date_to in table.fetchall()
            if ths_code_set is None or ths_code in ths_code_set
        }

    if TRIAL:
        # Trial accounts can only fetch roughly the last 5 years of data:
        # clamp each start date and drop funds whose range ends before that.
        date_from_min = date.today() - timedelta(days=(365 * 5))
        code_date_range_dic = {
            ths_code: (max(date_from, date_from_min), date_to)
            for ths_code, (date_from, date_to) in code_date_range_dic.items()
            if date_from_min <= date_to
        }

    # Column types used when writing to the database.
    dtype = {key: val for key, _, val in indicator_param_list}
    dtype['ths_code'] = String(20)
    dtype['time'] = Date

    data_df_list, data_count, tot_data_count = [], 0, 0
    code_count = len(code_date_range_dic)
    try:
        for num, (ths_code, (fetch_from, fetch_to)) in enumerate(
                code_date_range_dic.items(), start=1):
            logger.debug('%d/%d) %s [%s - %s]', num, code_count, ths_code,
                         fetch_from, fetch_to)
            data_df = invoker.THS_HistoryQuotes(
                ths_code, json_indicator,
                'Interval:D,CPS:1,baseDate:1900-01-01,Currency:YSHB,fill:Previous',
                fetch_from, fetch_to)
            if data_df is not None and data_df.shape[0] > 0:
                data_count += data_df.shape[0]
                # Lower-case column names so they match the table's columns.
                data_df.rename(
                    columns={col: col.lower() for col in data_df.columns},
                    inplace=True)
                data_df_list.append(data_df)

            # In DEBUG mode stop once more than one non-empty batch is collected.
            if DEBUG and len(data_df_list) > 1:
                break

            # Flush to the database once the buffered row count reaches the threshold.
            if data_count >= 10000:
                tot_data_df = pd.concat(data_df_list)
                data_count = bunch_insert_on_duplicate_update(
                    tot_data_df, table_name, engine_md, dtype)
                tot_data_count += data_count
                data_df_list, data_count = [], 0
    finally:
        # Persist whatever is still buffered, even if the download loop raised.
        if len(data_df_list) > 0:
            tot_data_df = pd.concat(data_df_list)
            data_count = bunch_insert_on_duplicate_update(
                tot_data_df, table_name, engine_md, dtype)
            tot_data_count += data_count

        logger.info("更新 %s 完成 新增数据 %d 条", table_name, tot_data_count)
        # The table was just created by this run: convert it and add the primary key.
        if not has_table and engine_md.has_table(table_name):
            alter_table_2_myisam(engine_md, [table_name])
            build_primary_key([table_name])