コード例 #1
0
ファイル: py_expression_eval.py プロジェクト: sukeyisme/JAQS
 def _align_bivariate(self, df1, df2, force_align=False):
     if isinstance(df1, pd.DataFrame) and isinstance(df2, pd.DataFrame):
         len1 = len(df1.index)
         len2 = len(df2.index)
         if (self.ann_dts is not None) and (self.trade_dts is not None):
             if len1 > len2:
                 df2 = align(df2, self.ann_dts, self.trade_dts)
             elif len1 < len2:
                 df1 = align(df1, self.ann_dts, self.trade_dts)
             elif force_align:
                 df1 = align(df1, self.ann_dts, self.trade_dts)
                 df2 = align(df2, self.ann_dts, self.trade_dts)
     return (df1, df2)
コード例 #2
0
ファイル: py_expression_eval.py プロジェクト: znxkxx/JAQS
 def _align_bivariate(self, df1, df2, force_align=False):
     if isinstance(df1, pd.DataFrame) and isinstance(df2, pd.DataFrame):
         len1 = len(df1.index)
         len2 = len(df2.index)
         if (self.ann_dts is not None) and (self.trade_dts is not None):
             if len1 > len2:
                 df2 = align(df2, self.ann_dts, self.trade_dts)
             elif len1 < len2:
                 df1 = align(df1, self.ann_dts, self.trade_dts)
             elif force_align:
                 df1 = align(df1, self.ann_dts, self.trade_dts)
                 df2 = align(df2, self.ann_dts, self.trade_dts)
     return (df1, df2)
コード例 #3
0
ファイル: test_dataservice.py プロジェクト: smartgang/JAQS
def test_remote_data_service_industry():
    """Smoke test: fetch CSI 300 ZZ industry data and align it onto trade dates.

    Builds per-symbol tables of ``in_date`` (passed to ``align`` as the
    announcement dates) and ``industry1_code`` (the values). It then aligns
    them onto the trade dates 20140101-20170505 twice: once for all symbols
    together, and once per symbol for the first 10 symbols. Uses the
    module-level ``ds`` and ``jutil``. Nothing is asserted, so the test only
    fails if a call raises.
    """
    from jaqs.data.align import align
    import pandas as pd

    members = ds.get_index_comp(index='000300.SH', start_date=20130101, end_date=20170505)
    raw = ds.get_industry_raw(symbol=','.join(members), type_='ZZ')

    from jaqs.data import DataView
    per_symbol = {}
    for sec, sub in jutil.group_df_to_dict(raw, by='symbol').items():
        per_symbol[sec] = sub.reset_index()

    ann_columns = []
    for sec, sub in per_symbol.items():
        ann_columns.append(sub.loc[:, 'in_date'].rename(sec))
    df_ann = pd.concat(ann_columns, axis=1)

    value_columns = []
    for sec, sub in per_symbol.items():
        value_columns.append(sub.loc[:, 'industry1_code'].rename(sec))
    df_value = pd.concat(value_columns, axis=1)

    trade_dates = ds.get_trade_date_range(20140101, 20170505)
    res = align(df_value, df_ann, trade_dates)

    def _align_one(sub):
        # One symbol at a time: value column first, announcement column second.
        return align(sub.loc[:, ['industry1_code']], sub.loc[:, ['in_date']], trade_dates)

    first_ten = list(per_symbol.values())[:10]
    res = pd.concat([_align_one(sub) for sub in first_ten], axis=1)
コード例 #4
0
    def get_industry_daily(self,
                           symbol,
                           start_date,
                           end_date,
                           type_='SW',
                           level=1):
        """
        Get the industry classification of each symbol on each trade day
        between start_date and end_date.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        type_ : {'SW', 'ZZ'}
            Classification system, passed to ``get_industry_raw``.
        level : int
            Classification level; values are read from the raw column
            ``industry{level}_code``.

        Returns
        -------
        res : pd.DataFrame
            index dates, columns symbols
            values are industry code (converted to str)

        """
        df_raw = self.get_industry_raw(symbol, type_=type_, level=level)

        dic_sec = jutil.group_df_to_dict(df_raw, by='symbol')
        # dict.items() works on Python 2 and 3; dict.viewitems() is Python-2 only.
        dic_sec = {
            sec: df.sort_values(by='in_date', axis=0).reset_index()
            for sec, df in dic_sec.items()
        }
        value_col = 'industry{:d}_code'.format(level)

        df_ann_tmp = pd.concat(
            {sec: df.loc[:, 'in_date'] for sec, df in dic_sec.items()},
            axis=1)
        df_value_tmp = pd.concat(
            {sec: df.loc[:, value_col] for sec, df in dic_sec.items()},
            axis=1)

        # After reset_index each per-symbol frame is indexed by row position;
        # build full tables with one column per requested symbol so that
        # symbols without records still appear (as NaN).
        idx = np.unique(
            np.concatenate([df.index.values for df in dic_sec.values()]))
        symbol_arr = np.sort(symbol.split(','))
        df_ann = pd.DataFrame(index=idx, columns=symbol_arr, data=np.nan)
        df_ann.loc[df_ann_tmp.index, df_ann_tmp.columns] = df_ann_tmp
        df_value = pd.DataFrame(index=idx, columns=symbol_arr, data=np.nan)
        df_value.loc[df_value_tmp.index, df_value_tmp.columns] = df_value_tmp

        dates_arr = self.get_trade_date_range(start_date, end_date)
        df_industry = align.align(df_value, df_ann, dates_arr)

        # TODO before industry classification is available, we assume they belong to their first group.
        # DataFrame.bfill() replaces the deprecated fillna(method='bfill').
        df_industry = df_industry.bfill()
        df_industry = df_industry.astype(str)

        return df_industry
コード例 #5
0
ファイル: py_expression_eval.py プロジェクト: znxkxx/JAQS
 def _align_univariate(self, df1):
     if isinstance(df1, pd.DataFrame):
         if (self.ann_dts is not None) and (self.trade_dts is not None):
             len1 = len(df1.index)
             len2 = len(self.trade_dts)
             if len1 != len2:
                 return align(df1, self.ann_dts, self.trade_dts)
     return df1
コード例 #6
0
ファイル: py_expression_eval.py プロジェクト: sukeyisme/JAQS
 def _align_univariate(self, df1):
     if isinstance(df1, pd.DataFrame):
         if (self.ann_dts is not None) and (self.trade_dts is not None):
             len1 = len(df1.index)
             len2 = len(self.trade_dts)
             if len1 != len2:
                 return align(df1, self.ann_dts, self.trade_dts)
     return df1
コード例 #7
0
ファイル: dataview.py プロジェクト: williamsyb/jaqs-fxdayu
 def append_df_quarter(self, df, field_name, overwrite=True):
     """Store a quarterly DataFrame under ``field_name``, plus a daily copy.

     If ``field_name`` already exists, it is removed first when ``overwrite``
     is true; otherwise a message is printed and nothing is stored. The daily
     copy is produced by re-indexing ``df`` to the announcement-date table
     and aligning it onto ``self.dates``.
     """
     if field_name in self.fields:
         if not overwrite:
             print("Append df failed: name [{:s}] exist. Try another name.".format(field_name))
             return
         self.remove_field(field_name)
         print("Field [{:s}] is overwritten.".format(field_name))

     self.append_df(df, field_name, is_quarterly=True)

     ann_frame = self._get_ann_df()
     daily = align(df.reindex(ann_frame.index), ann_frame, self.dates)
     self.append_df(daily, field_name, is_quarterly=False)
コード例 #8
0
ファイル: dataservice.py プロジェクト: tianhm/jaqs
    def get_industry_daily(self, symbol, start_date, end_date, type_='SW'):
        """
        Get the level-1 industry code of each symbol on each trade day
        between start_date and end_date.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        type_ : {'SW', 'ZZ'}
            Classification system, passed to ``get_industry_raw``.

        Returns
        -------
        res : pd.DataFrame
            index dates, columns symbols
            values are industry code (column ``industry1_code``, as str)

        """
        df_raw = self.get_industry_raw(symbol, type_=type_)

        dic_sec = self._group_df_to_dict(df_raw, by='symbol')
        # dict.items() works on Python 2 and 3; dict.viewitems() is Python-2 only.
        dic_sec = {
            sec: df.sort_values(by='in_date', axis=0).reset_index()
            for sec, df in dic_sec.items()
        }

        df_ann = pd.concat(
            [df.loc[:, 'in_date'].rename(sec) for sec, df in dic_sec.items()],
            axis=1)
        df_value = pd.concat(
            [df.loc[:, 'industry1_code'].rename(sec) for sec, df in dic_sec.items()],
            axis=1)

        dates_arr = self.get_trade_date(start_date, end_date)
        df_industry = align.align(df_value, df_ann, dates_arr)

        # TODO before industry classification is available, we assume they belong to their first group.
        # DataFrame.bfill() replaces the deprecated fillna(method='bfill').
        df_industry = df_industry.bfill()
        df_industry = df_industry.astype(str)

        return df_industry
コード例 #9
0
ファイル: dataservice.py プロジェクト: sukeyisme/JAQS
    def query_industry_daily(self, symbol, start_date, end_date, type_='SW', level=1):
        """
        Get the industry classification of each symbol on each trade day
        between start_date and end_date.

        Parameters
        ----------
        symbol : str
            separated by ','
        start_date : int
        end_date : int
        type_ : {'SW', 'ZZ'}
            Classification system, passed to ``query_industry_raw``.
        level : int
            Classification level; values are read from the raw column
            ``industry{level}_code``.

        Returns
        -------
        res : pd.DataFrame
            index dates, columns symbols
            values are industry code (converted to str)

        """
        df_raw = self.query_industry_raw(symbol, type_=type_, level=level)
        
        dic_sec = jutil.group_df_to_dict(df_raw, by='symbol')
        dic_sec = {sec: df.sort_values(by='in_date', axis=0).reset_index()
                   for sec, df in dic_sec.items()}
        value_col = 'industry{:d}_code'.format(level)

        df_ann_tmp = pd.concat({sec: df.loc[:, 'in_date'] for sec, df in dic_sec.items()}, axis=1)
        df_value_tmp = pd.concat({sec: df.loc[:, value_col] for sec, df in dic_sec.items()},
                                 axis=1)
        
        # After reset_index each per-symbol frame is indexed by row position;
        # build full tables with one column per requested symbol so that
        # symbols without records still appear (as NaN).
        idx = np.unique(np.concatenate([df.index.values for df in dic_sec.values()]))
        symbol_arr = np.sort(symbol.split(','))
        df_ann = pd.DataFrame(index=idx, columns=symbol_arr, data=np.nan)
        df_ann.loc[df_ann_tmp.index, df_ann_tmp.columns] = df_ann_tmp
        df_value = pd.DataFrame(index=idx, columns=symbol_arr, data=np.nan)
        df_value.loc[df_value_tmp.index, df_value_tmp.columns] = df_value_tmp

        dates_arr = self.query_trade_dates(start_date, end_date)
        df_industry = align.align(df_value, df_ann, dates_arr)
        
        # TODO before industry classification is available, we assume they belong to their first group.
        # DataFrame.bfill() replaces the deprecated fillna(method='bfill').
        df_industry = df_industry.bfill()
        df_industry = df_industry.astype(str)
        
        return df_industry
コード例 #10
0
ファイル: test_dataservice.py プロジェクト: sukeyisme/JAQS
def test_remote_data_service_industry():
    """Exercise industry-classification queries and ``align`` on the result.

    Checks that both 'SW' and 'ZZ' queries run and that invalid arguments
    (an unsupported level, an unknown type) raise ValueError. It then aligns
    ZZ ``industry1_code`` values (keyed by ``in_date``) onto trade dates, once
    for all symbols and once per symbol for the first 10 symbols. Uses the
    module-level ``ds`` and ``jutil``.
    """
    from jaqs.data.align import align
    import pandas as pd

    arr = ds.query_index_member(index='000300.SH', start_date=20130101, end_date=20170505)
    symbols = ','.join(arr)
    # Smoke-check the SW system; the ZZ result is the one used below.
    ds.query_industry_raw(symbol=symbols, type_='SW')
    df = ds.query_industry_raw(symbol=symbols, type_='ZZ')

    # Invalid arguments must be rejected. The `else` branch makes the test
    # fail when no ValueError is raised at all.
    for bad_kwargs in ({'type_': 'ZZ', 'level': 5}, {'type_': 'blabla'}):
        try:
            ds.query_industry_raw(symbol=symbols, **bad_kwargs)
        except ValueError:
            pass
        else:
            raise AssertionError(
                "query_industry_raw did not reject invalid arguments: {}".format(bad_kwargs))

    dic_sec = jutil.group_df_to_dict(df, by='symbol')
    dic_sec = {sec: sub.reset_index() for sec, sub in dic_sec.items()}

    df_ann = pd.concat([sub.loc[:, 'in_date'].rename(sec) for sec, sub in dic_sec.items()], axis=1)
    df_value = pd.concat([sub.loc[:, 'industry1_code'].rename(sec) for sec, sub in dic_sec.items()],
                         axis=1)

    dates_arr = ds.query_trade_dates(20140101, 20170505)
    res = align(df_value, df_ann, dates_arr)

    def align_single_df(df_one_sec):
        df_value = df_one_sec.loc[:, ['industry1_code']]
        df_ann = df_one_sec.loc[:, ['in_date']]
        return align(df_value, df_ann, dates_arr)

    res_list = [align_single_df(sub) for sub in list(dic_sec.values())[:10]]
    res = pd.concat(res_list, axis=1)
コード例 #11
0
ファイル: test_dataservice.py プロジェクト: sukeyisme/JAQS
 def align_single_df(df_one_sec):
     """Align one symbol's ``industry1_code`` history onto ``dates_arr``.

     ``df_one_sec`` must contain ``industry1_code`` and ``in_date`` columns.
     ``in_date`` is passed to ``align`` as the announcement-date frame.
     ``align`` and ``dates_arr`` come from the enclosing scope.
     """
     values = df_one_sec.loc[:, ['industry1_code']]
     announced = df_one_sec.loc[:, ['in_date']]
     return align(values, announced, dates_arr)
コード例 #12
0
ファイル: test_dataservice.py プロジェクト: wocclyl/JAQS
 def align_single_df(df_one_sec):
     # Per-symbol alignment: the value frame goes first, then the announcement
     # (in_date) frame, onto the trade dates captured from the enclosing scope.
     return align(
         df_one_sec.loc[:, ['industry1_code']],
         df_one_sec.loc[:, ['in_date']],
         dates_arr,
     )
コード例 #13
0
    def add_formula(self,
                    field_name,
                    formula,
                    is_quarterly,
                    add_data=False,
                    overwrite=True,
                    formula_func_name_style='camel',
                    data_api=None,
                    register_funcs=None,
                    within_index=True):
        """
        Add a new field, which is calculated using existing fields.

        Parameters
        ----------
        field_name : str or unicode
            A custom name for the new field.
        formula : str or unicode
            A formula contains operations and function calls.
        is_quarterly : bool
            Whether to evaluate on quarterly data. A variable is read with
            ``get_ts_quarter`` only when this is True AND the variable itself
            is a quarterly field; otherwise daily data is used.
        add_data : bool
            Whether to add the result to the data set, or only return it.
        overwrite : bool, optional
            Whether to overwrite an existing field. True by default. Only
            consulted when ``add_data`` is True.
        formula_func_name_style : str, optional
            Capitalization style of function names, passed to
            ``Parser.set_capital``. Default ``'camel'``.
        data_api : RemoteDataService, optional
            If given, replaces ``self.data_api``.
        register_funcs : dict, optional
            User-defined functions, e.g. ``{"name1": func1}``. Names must not
            collide with the parser's built-in operators, functions,
            constants or values.
        within_index : bool
            When doing cross-section operations, whether to restrict them to
            index components (the ``index_member`` field).

        Returns
        -------
        DataFrame-like or None
            The evaluated result sliced to ``[start_date, end_date]``. When
            every variable was read as quarterly data, the result is first
            expanded onto ``self.dates`` with ``align``. Returns None if a
            referenced variable could not be added.

        Raises
        ------
        ValueError
            If ``add_data`` is True and ``field_name`` exists with
            ``overwrite=False`` or is a pre-defined field name; or if a
            ``register_funcs`` name collides with a built-in name.

        Notes
        -----
        Time cost of this function:
            For a simple formula (like 'a + 1'), almost all time is consumed by append_df;
            For a complex formula (like 'GroupRank'), half of time is consumed by evaluation and half by append_df.
        """
        if data_api is not None:
            self.data_api = data_api

        if add_data:
            if field_name in self.fields:
                if overwrite:
                    self.remove_field(field_name)
                    print("Field [{:s}] is overwritten.".format(field_name))
                else:
                    raise ValueError(
                        "Add formula failed: name [{:s}] exist. Try another name."
                        .format(field_name))
            elif self._is_predefined_field(field_name):
                raise ValueError(
                    "[{:s}] is already a pre-defined field. Please use another name."
                    .format(field_name))

        parser = Parser()
        parser.set_capital(formula_func_name_style)

        # Register user-defined functions, refusing to shadow any built-in name.
        if register_funcs is not None:
            for func, impl in register_funcs.items():
                if (func in parser.ops1 or func in parser.ops2
                        or func in parser.functions or func in parser.consts
                        or func in parser.values):
                    raise ValueError(
                        "注册的自定义函数名%s与内置的函数名称重复,请更换register_funcs中定义的相关函数名称." %
                        (func, ))
                parser.functions[func] = impl

        expr = parser.parse(formula)

        var_df_dic = dict()
        var_list = expr.variables()

        # TODO: users do not need to prepare data before add_formula
        if not self.fields:
            self.fields.extend(var_list)
            self.prepare_data()
        else:
            for var in var_list:
                if var not in self.fields:
                    print("Variable [{:s}] is not recognized (it may be wrong), "
                          "try to fetch from the server...".format(var))
                    success = self.add_field(var)
                    if not success:
                        return

        all_quarterly = True
        for var in var_list:
            if self._is_quarter_field(var) and is_quarterly:
                df_var = self.get_ts_quarter(
                    var, start_date=self.extended_start_date_q)
            else:
                # must use extended date. Default is start_date
                df_var = self.get_ts(var,
                                     start_date=self.extended_start_date_d,
                                     end_date=self.end_date)
                all_quarterly = False
            var_df_dic[var] = df_var

        # TODO: send ann_date into expr.evaluate. We assume that ann_date of all fields of a symbol is the same
        df_ann = self._get_ann_df()
        if within_index:
            df_index_member = self.get_ts(
                'index_member',
                start_date=self.extended_start_date_d,
                end_date=self.end_date)
            if df_index_member.size == 0:
                df_index_member = None
            df_eval = parser.evaluate(var_df_dic,
                                      ann_dts=df_ann,
                                      trade_dts=self.dates,
                                      index_member=df_index_member)
        else:
            df_eval = parser.evaluate(var_df_dic,
                                      ann_dts=df_ann,
                                      trade_dts=self.dates)

        if add_data:
            if all_quarterly:
                self.append_df_quarter(df_eval, field_name)
            else:
                self.append_df(df_eval, field_name, is_quarterly=False)

        if all_quarterly:
            df_ann = self._get_ann_df()
            df_expanded = align(df_eval.reindex(df_ann.index), df_ann,
                                self.dates)
            return df_expanded.loc[self.start_date:self.end_date]
        else:
            return df_eval.loc[self.start_date:self.end_date]
コード例 #14
0
    def add_field(self, field_name, data_api=None):
        """
        Query and append new field to DataView.

        Parameters
        ----------
        field_name : str
            Must be a known field name (which is given in documents).
        data_api : BaseDataServer, optional
            If given, replaces ``self.data_api``. Otherwise the existing
            ``self.data_api`` is used and must not be None.

        Returns
        -------
        bool
            whether add successfully.

        Notes
        -----
        Group fields ('sw1', 'sw2', 'sw3', 'sw4', 'zz1', 'zz2') are prepared
        with ``_prepare_group``. Quarterly fields are stored twice: as
        quarterly data, and expanded onto ``self.dates`` with ``align`` as
        daily data.
        """
        if data_api is None:
            if self.data_api is None:
                print(
                    "Add field failed. No data_api available. Please specify one in parameter."
                )
                return False
        else:
            self.data_api = data_api

        if field_name in self.fields:
            if self.data_d is None:
                self.fields = []
            else:
                print("Field name [{:s}] already exists.".format(field_name))
                return False

        if not self._is_predefined_field(field_name):
            print("Field name [{}] not valid, ignore.".format(field_name))
            return False

        if self.data_d is None:
            self.data_d, _ = self._prepare_daily_quarterly(["trade_status"])
            self._add_field("trade_status")
            trade_status = self.get_ts("trade_status")
            if trade_status.size > 0:
                try:
                    trade_status = trade_status.astype(int)
                except (ValueError, TypeError):
                    # Non-numeric status: map u"交易" ("trading") to 1, any
                    # other non-empty status to 0, and missing status to NaN.
                    tmp = (trade_status.fillna("") == u"交易").astype(int)
                    # np.nan (np.NaN was removed in NumPy 2.0).
                    tmp[trade_status.fillna("") == ""] = np.nan
                    self.append_df(tmp, "trade_status")

        # prepare group type
        group_map = ['sw1', 'sw2', 'sw3', 'sw4', 'zz1', 'zz2']
        if field_name in group_map:
            self._prepare_group([field_name])
            return True

        if self._is_daily_field(field_name):
            merge, _ = self._prepare_daily_quarterly([field_name])
            is_quarterly = False
        else:
            if self.data_q is None:
                _, self.data_q = self._prepare_daily_quarterly(["ann_date"])
                self._add_field("ann_date")
                self._prepare_report_date()
                self._align_and_merge_q_into_d()
            _, merge = self._prepare_daily_quarterly([field_name])
            is_quarterly = True

        df = merge.loc[:, pd.IndexSlice[:, field_name]]
        df.columns = df.columns.droplevel(level=1)
        # whether contain only trade days is decided by existing data.

        # Quarterly fields go to data_q, daily fields to data_d.
        self.append_df(df, field_name, is_quarterly=is_quarterly)
        if is_quarterly:
            df_ann = merge.loc[:, pd.IndexSlice[:, self.ANN_DATE_FIELD_NAME]]
            df_ann.columns = df_ann.columns.droplevel(level='field')
            df_expanded = align(df, df_ann, self.dates)
            self.append_df(df_expanded, field_name, is_quarterly=False)
        return True
コード例 #15
0
ファイル: test_data_basic.py プロジェクト: sukeyisme/JAQS
def test_align():
    """Check ``align`` on quarterly income data for four symbols.

    For every symbol, the aligned result must be NaN before the first
    announcement date. Each reported value must then be shown from its
    announcement date up to the day before the next announcement; the last
    value holds until the end. Relies on the module-level
    ``RemoteDataService``, ``data_config`` and ``align``, and needs access to
    the remote data service.
    """
    # -------------------------------------------------------------------------------------
    # input and pre-process demo data
    ds = RemoteDataService()
    ds.init_from_config(data_config)
    
    raw, msg = ds.query_lb_fin_stat('income', '000001.SZ,600000.SH,601328.SH,601988.SH',
                         20160505, 20170505, fields='oper_rev,oper_cost')
    
    idx_list = ['report_date', 'symbol']
    raw_idx = raw.set_index(idx_list)
    raw_idx.sort_index(axis=0, level=idx_list, inplace=True)
    
    # -------------------------------------------------------------------------------------
    # get DataFrames
    df_ann = raw_idx.loc[pd.IndexSlice[:, :], 'ann_date']
    df_ann = df_ann.unstack(level=1)
    
    df_value = raw_idx.loc[pd.IndexSlice[:, :], 'oper_rev']
    df_value = df_value.unstack(level=1)
    
    # -------------------------------------------------------------------------------------
    # get data array and align
    # Hard-coded trade dates 20160325-20170623 (presumably the output of
    # ds.get_trade_date(20160325, 20170625)), so the test does not query them.
    date_arr = np.array([20160325, 20160328, 20160329, 20160330, 20160331, 20160401, 20160405, 20160406,
                         20160407, 20160408, 20160411, 20160412, 20160413, 20160414, 20160415, 20160418,
                         20160419, 20160420, 20160421, 20160422, 20160425, 20160426, 20160427, 20160428,
                         20160429, 20160503, 20160504, 20160505, 20160506, 20160509, 20160510, 20160511,
                         20160512, 20160513, 20160516, 20160517, 20160518, 20160519, 20160520, 20160523,
                         20160524, 20160525, 20160526, 20160527, 20160530, 20160531, 20160601, 20160602,
                         20160603, 20160606, 20160607, 20160608, 20160613, 20160614, 20160615, 20160616,
                         20160617, 20160620, 20160621, 20160622, 20160623, 20160624, 20160627, 20160628,
                         20160629, 20160630, 20160701, 20160704, 20160705, 20160706, 20160707, 20160708,
                         20160711, 20160712, 20160713, 20160714, 20160715, 20160718, 20160719, 20160720,
                         20160721, 20160722, 20160725, 20160726, 20160727, 20160728, 20160729, 20160801,
                         20160802, 20160803, 20160804, 20160805, 20160808, 20160809, 20160810, 20160811,
                         20160812, 20160815, 20160816, 20160817, 20160818, 20160819, 20160822, 20160823,
                         20160824, 20160825, 20160826, 20160829, 20160830, 20160831, 20160901, 20160902,
                         20160905, 20160906, 20160907, 20160908, 20160909, 20160912, 20160913, 20160914,
                         20160919, 20160920, 20160921, 20160922, 20160923, 20160926, 20160927, 20160928,
                         20160929, 20160930, 20161010, 20161011, 20161012, 20161013, 20161014, 20161017,
                         20161018, 20161019, 20161020, 20161021, 20161024, 20161025, 20161026, 20161027,
                         20161028, 20161031, 20161101, 20161102, 20161103, 20161104, 20161107, 20161108,
                         20161109, 20161110, 20161111, 20161114, 20161115, 20161116, 20161117, 20161118,
                         20161121, 20161122, 20161123, 20161124, 20161125, 20161128, 20161129, 20161130,
                         20161201, 20161202, 20161205, 20161206, 20161207, 20161208, 20161209, 20161212,
                         20161213, 20161214, 20161215, 20161216, 20161219, 20161220, 20161221, 20161222,
                         20161223, 20161226, 20161227, 20161228, 20161229, 20161230, 20170103, 20170104,
                         20170105, 20170106, 20170109, 20170110, 20170111, 20170112, 20170113, 20170116,
                         20170117, 20170118, 20170119, 20170120, 20170123, 20170124, 20170125, 20170126,
                         20170203, 20170206, 20170207, 20170208, 20170209, 20170210, 20170213, 20170214,
                         20170215, 20170216, 20170217, 20170220, 20170221, 20170222, 20170223, 20170224,
                         20170227, 20170228, 20170301, 20170302, 20170303, 20170306, 20170307, 20170308,
                         20170309, 20170310, 20170313, 20170314, 20170315, 20170316, 20170317, 20170320,
                         20170321, 20170322, 20170323, 20170324, 20170327, 20170328, 20170329, 20170330,
                         20170331, 20170405, 20170406, 20170407, 20170410, 20170411, 20170412, 20170413,
                         20170414, 20170417, 20170418, 20170419, 20170420, 20170421, 20170424, 20170425,
                         20170426, 20170427, 20170428, 20170502, 20170503, 20170504, 20170505, 20170508,
                         20170509, 20170510, 20170511, 20170512, 20170515, 20170516, 20170517, 20170518,
                         20170519, 20170522, 20170523, 20170524, 20170525, 20170526, 20170531, 20170601,
                         20170602, 20170605, 20170606, 20170607, 20170608, 20170609, 20170612, 20170613,
                         20170614, 20170615, 20170616, 20170619, 20170620, 20170621, 20170622, 20170623])
    
    res_align = align(df_value, df_ann, date_arr)
    
    # DataFrame.items() replaces iteritems(), which was removed in pandas 2.0.
    for symbol, ser_value in df_value.items():
        ser_ann = df_ann[symbol]
        n_reports = len(ser_value)
        
        # Nothing may be visible before the first announcement.
        assert res_align.loc[: ser_ann.iat[0]-1, symbol].isnull().all()
        for i in range(n_reports):
            value = ser_value.iat[i]
            ann_date = ser_ann.iat[i]
            # The last report stays in effect to the end of the date range.
            ann_date_next = ser_ann.iat[i+1] if i+1 < n_reports else 99999999
            assert (res_align.loc[ann_date: ann_date_next-1, symbol] == value).all()
コード例 #16
0
ファイル: test_data_basic.py プロジェクト: zorro430/JAQS
def test_align():
    """Check ``align`` on quarterly income data for four symbols.

    For every symbol, the aligned result must be NaN before the first
    announcement date. Each reported value must then be shown from its
    announcement date up to the day before the next announcement; the last
    value holds until the end. Relies on the module-level
    ``RemoteDataService``, ``data_config`` and ``align``, and needs access to
    the remote data service.
    """
    # -------------------------------------------------------------------------------------
    # input and pre-process demo data
    ds = RemoteDataService()
    ds.init_from_config(data_config)

    raw, msg = ds.query_lb_fin_stat('income',
                                    '000001.SZ,600000.SH,601328.SH,601988.SH',
                                    20160505,
                                    20170505,
                                    fields='oper_rev,oper_cost')

    idx_list = ['report_date', 'symbol']
    raw_idx = raw.set_index(idx_list)
    raw_idx.sort_index(axis=0, level=idx_list, inplace=True)

    # -------------------------------------------------------------------------------------
    # get DataFrames
    df_ann = raw_idx.loc[pd.IndexSlice[:, :], 'ann_date']
    df_ann = df_ann.unstack(level=1)

    df_value = raw_idx.loc[pd.IndexSlice[:, :], 'oper_rev']
    df_value = df_value.unstack(level=1)

    # -------------------------------------------------------------------------------------
    # get data array and align
    # Hard-coded trade dates 20160325-20170623 (presumably the output of
    # ds.get_trade_date(20160325, 20170625)), so the test does not query them.
    date_arr = np.array([
        20160325, 20160328, 20160329, 20160330, 20160331, 20160401, 20160405,
        20160406, 20160407, 20160408, 20160411, 20160412, 20160413, 20160414,
        20160415, 20160418, 20160419, 20160420, 20160421, 20160422, 20160425,
        20160426, 20160427, 20160428, 20160429, 20160503, 20160504, 20160505,
        20160506, 20160509, 20160510, 20160511, 20160512, 20160513, 20160516,
        20160517, 20160518, 20160519, 20160520, 20160523, 20160524, 20160525,
        20160526, 20160527, 20160530, 20160531, 20160601, 20160602, 20160603,
        20160606, 20160607, 20160608, 20160613, 20160614, 20160615, 20160616,
        20160617, 20160620, 20160621, 20160622, 20160623, 20160624, 20160627,
        20160628, 20160629, 20160630, 20160701, 20160704, 20160705, 20160706,
        20160707, 20160708, 20160711, 20160712, 20160713, 20160714, 20160715,
        20160718, 20160719, 20160720, 20160721, 20160722, 20160725, 20160726,
        20160727, 20160728, 20160729, 20160801, 20160802, 20160803, 20160804,
        20160805, 20160808, 20160809, 20160810, 20160811, 20160812, 20160815,
        20160816, 20160817, 20160818, 20160819, 20160822, 20160823, 20160824,
        20160825, 20160826, 20160829, 20160830, 20160831, 20160901, 20160902,
        20160905, 20160906, 20160907, 20160908, 20160909, 20160912, 20160913,
        20160914, 20160919, 20160920, 20160921, 20160922, 20160923, 20160926,
        20160927, 20160928, 20160929, 20160930, 20161010, 20161011, 20161012,
        20161013, 20161014, 20161017, 20161018, 20161019, 20161020, 20161021,
        20161024, 20161025, 20161026, 20161027, 20161028, 20161031, 20161101,
        20161102, 20161103, 20161104, 20161107, 20161108, 20161109, 20161110,
        20161111, 20161114, 20161115, 20161116, 20161117, 20161118, 20161121,
        20161122, 20161123, 20161124, 20161125, 20161128, 20161129, 20161130,
        20161201, 20161202, 20161205, 20161206, 20161207, 20161208, 20161209,
        20161212, 20161213, 20161214, 20161215, 20161216, 20161219, 20161220,
        20161221, 20161222, 20161223, 20161226, 20161227, 20161228, 20161229,
        20161230, 20170103, 20170104, 20170105, 20170106, 20170109, 20170110,
        20170111, 20170112, 20170113, 20170116, 20170117, 20170118, 20170119,
        20170120, 20170123, 20170124, 20170125, 20170126, 20170203, 20170206,
        20170207, 20170208, 20170209, 20170210, 20170213, 20170214, 20170215,
        20170216, 20170217, 20170220, 20170221, 20170222, 20170223, 20170224,
        20170227, 20170228, 20170301, 20170302, 20170303, 20170306, 20170307,
        20170308, 20170309, 20170310, 20170313, 20170314, 20170315, 20170316,
        20170317, 20170320, 20170321, 20170322, 20170323, 20170324, 20170327,
        20170328, 20170329, 20170330, 20170331, 20170405, 20170406, 20170407,
        20170410, 20170411, 20170412, 20170413, 20170414, 20170417, 20170418,
        20170419, 20170420, 20170421, 20170424, 20170425, 20170426, 20170427,
        20170428, 20170502, 20170503, 20170504, 20170505, 20170508, 20170509,
        20170510, 20170511, 20170512, 20170515, 20170516, 20170517, 20170518,
        20170519, 20170522, 20170523, 20170524, 20170525, 20170526, 20170531,
        20170601, 20170602, 20170605, 20170606, 20170607, 20170608, 20170609,
        20170612, 20170613, 20170614, 20170615, 20170616, 20170619, 20170620,
        20170621, 20170622, 20170623
    ])

    res_align = align(df_value, df_ann, date_arr)

    # DataFrame.items() replaces iteritems(), which was removed in pandas 2.0.
    for symbol, ser_value in df_value.items():
        ser_ann = df_ann[symbol]
        n_reports = len(ser_value)

        # Nothing may be visible before the first announcement.
        assert res_align.loc[:ser_ann.iat[0] - 1, symbol].isnull().all()
        for i in range(n_reports):
            value = ser_value.iat[i]
            ann_date = ser_ann.iat[i]
            # The last report stays in effect to the end of the date range.
            ann_date_next = ser_ann.iat[i + 1] if i + 1 < n_reports else 99999999
            assert (res_align.loc[ann_date:ann_date_next - 1,
                                  symbol] == value).all()