Esempio n. 1
0
    def get_start_date(
            self,
            code_list: Union[str,
                             Tuple[str],
                             List[str]],
            factor_time_range: list
    ):
        """
        Get the listing date of each stock, broadcast over the factor dates.

        :param code_list: stock code(s); normalized with
            ``utils.QA_fmt_code_list``
        :param factor_time_range: factor dates; every element must provide a
            ``.date()`` method (e.g. ``pd.Timestamp``)
        :return: Series indexed by ``["date", "code"]`` whose values are the
            ``list_date`` column of ``QA_fetch_stock_basic`` for each code.
            Codes absent from the basic-info table are dropped.
        """
        stock_list = utils.QA_fmt_code_list(code_list)
        df_basic = QA_fetch_stock_basic(status=None).set_index("code")
        # Keep only codes that actually have basic info.
        intersection = list(df_basic.index.intersection(stock_list))
        list_dates = df_basic.loc[intersection, "list_date"]
        # A code may appear more than once in the basic table; keep the first
        # row so every code maps to exactly one listing date (otherwise the
        # lookup below would yield a Series instead of a scalar).
        list_dates = list_dates[~list_dates.index.duplicated(keep="first")]

        # Date handling: index by plain ``datetime.date`` values.
        date_range = [x.date() for x in factor_time_range]

        multiindex = pd.MultiIndex.from_product(
            [date_range, utils.QA_fmt_code_list(intersection)],
            names=["date", "code"],
        )
        # One vectorized reindex instead of a Python-level ``.loc`` per
        # (date, code) pair, which is very slow for many dates x stocks.
        values = list_dates.reindex(
            multiindex.get_level_values("code")
        ).to_numpy()
        return pd.Series(index=multiindex, data=values)
Esempio n. 2
0
    def get_groupby(
        self,
        code_list: Union[str, Tuple[str], List[str]],
        factor_time_range: list = None,
        industry_cls: str = None,
        industry_data: Union[dict, pd.Series, pd.DataFrame] = None,
        frequence: str = 'DAY',
        detailed: bool = False,
    ) -> pd.Series:
        """
        Get industry (group) information.

        Parameters
        ---
        :param code_list: stock code(s); a single string is treated as one code
        :param factor_time_range: factor dates
        :param industry_cls: industry classification key, looked up in each
            code's entry of the ``jqdatasdk.get_industry`` result; falls back
            to ``self.industry_cls``
        :param industry_data: user-supplied industry data, used when no
            classification is available; falls back to ``self.industry_data``.
            Accepted forms: a dict ``{code: industry}`` (repeated for every
            date), a DataFrame with dates as rows and codes as columns, or a
            Series already indexed by (date, code)
        :param frequence: not used by this method; kept for interface
            compatibility
        :param detailed: fetching industry data per date is very slow; when
            False, the classification at the last date of
            ``factor_time_range`` is used for all dates

        Returns
        ---
        :return: industry information indexed by ``["date", "code"]``
        """

        def _is_missing(obj) -> bool:
            # pandas objects have no unambiguous truth value (``not df``
            # raises ValueError), so test emptiness explicitly.
            if obj is None:
                return True
            if isinstance(obj, (pd.Series, pd.DataFrame)):
                return obj.empty
            return not obj

        # 1. Stock codes
        if isinstance(code_list, str):
            # Otherwise MultiIndex.from_product would iterate the characters.
            code_list = [code_list]
        elif isinstance(code_list, tuple):
            code_list = list(code_list)

        # 2. Dates
        if not detailed:
            date_range = [pd.Timestamp(max(factor_time_range))]
        else:
            date_range = factor_time_range

        # 3. Industry source: fall back to the instance-level defaults
        if not industry_cls:
            industry_cls = self.industry_cls

        if _is_missing(industry_data):
            industry_data = self.industry_data

        if (not industry_cls) and _is_missing(industry_data):
            warnings.warn("没有指定行业分类方式,也没有输入行业分类信息", UserWarning)
            return pd.Series(
                index=pd.MultiIndex.from_product(
                    [date_range, code_list], names=["date", "code"]
                ),
                name="group",
                data="NA",
            )

        # 4a. No classification given: use the supplied industry data
        if not industry_cls:
            if isinstance(industry_data, dict):
                grouped = pd.DataFrame(
                    index=date_range,
                    data=[industry_data] * len(date_range),
                ).stack(level=-1)
            elif isinstance(industry_data, pd.DataFrame):
                grouped = industry_data.stack(level=-1)
            else:
                grouped = industry_data
            # rename_axis returns a new object, so a caller-supplied Series
            # (or self.industry_data) is not mutated.
            return grouped.rename_axis(["date", "code"])

        # 4b. Classification given: query JoinQuant once per date.
        # FIXME: temporarily uses JoinQuant industry data
        stock_list = utils.QA_fmt_code_list(code_list, style="jq")
        per_date = map(partial(jqdatasdk.get_industry, stock_list), date_range)
        industries = {
            d: {
                # A code missing from the response maps to "NA" instead of
                # raising AttributeError on None.
                s: (ind.get(s) or {}).get(industry_cls, {}).get(
                    "industry_name", "NA"
                )
                for s in stock_list
            }
            for d, ind in zip(date_range, per_date)
        }
        df_local = pd.DataFrame(industries).T.sort_index()
        # Keep the first 6 characters of each code (presumably dropping the
        # JoinQuant exchange suffix).
        df_local.columns = df_local.columns.map(lambda x: x[0:6])
        df_local = df_local.stack(level=-1)
        df_local.index.names = ["date", "code"]
        return df_local
Esempio n. 3
0
    def __gen_clean_factor_and_forward_returns(self):
        """
        Format the factor data and attach forward returns, group, listing-date
        and weight information.

        Each of ``self.prices``, ``self.groupby``, ``self.stock_start_date``
        and ``self.weights`` may be precomputed data or a callable. A callable
        is invoked once and the attribute is replaced by its result. The
        cleaned data is stored in ``self._clean_factor_data``.
        """
        factor_data = self.factor

        # Stock codes: converted to the QA-supported format by default.
        code_list = utils.QA_fmt_code_list(
            list(factor_data.index.get_level_values("code").drop_duplicates())
        )

        # Factor dates. Use the values actually present in the index instead
        # of ``index.levels[0]``: MultiIndex levels may keep unused values
        # after the factor has been sliced/filtered, which would wrongly
        # widen the start/end range.
        factor_time_range = list(
            factor_data.index.get_level_values(0).unique().sort_values()
        )
        start_time = min(factor_time_range)
        end_time = max(factor_time_range)

        def _resolve(source, **kwargs):
            # Call ``source`` with ``kwargs`` if it is callable; otherwise
            # return it unchanged.
            return source(**kwargs) if callable(source) else source

        # Extra data
        if callable(self.prices):
            prices = self.prices(
                code_list=code_list,
                start_time=start_time,
                end_time=end_time,
                frequence=self.frequence,
            )
            # Drop rows whose index entry is duplicated, keeping the first.
            self.prices = prices.loc[~prices.index.duplicated()]

        self.groupby = _resolve(
            self.groupby,
            code_list=code_list,
            factor_time_range=factor_time_range,
        )

        self.stock_start_date = _resolve(
            self.stock_start_date,
            code_list=code_list,
            factor_time_range=factor_time_range,
        )

        self.weights = _resolve(
            self.weights,
            code_list=code_list,
            factor_time_range=factor_time_range,
            frequence=self.frequence,
        )

        # Factor processing
        self._clean_factor_data = get_clean_factor_and_forward_returns(
            factor=factor_data,
            prices=self.prices,
            groupby=self.groupby,
            stock_start_date=self.stock_start_date,
            weights=self.weights,
            binning_by_group=self.binning_by_group,
            quantiles=self.quantiles,
            bins=self.bins,
            periods=self.periods,
            max_loss=self.max_loss,
            zero_aware=self.zero_aware,
            frequence=self.frequence,
        )
Esempio n. 4
0
    def get_groupby(
            self,
            code_list: Union[str,
                             Tuple[str],
                             List[str]],
            factor_time_range: list = None,
            industry_cls: str = None,
            industry_data: Union[dict,
                                 pd.Series,
                                 pd.DataFrame] = None,
            frequence: str = 'DAY',
            detailed: bool = False,
    ) -> pd.Series:
        """
        Get industry (group) information.

        Parameters
        ---
        :param code_list: stock code(s); a single string is treated as one code
        :param factor_time_range: factor dates; every element must provide a
            ``.date()`` method
        :param industry_cls: industry classification of the form
            ``"<src>_<level>"``; the part before the first ``_`` is passed as
            ``src`` and the second part as ``levels`` to
            ``QA_fetch_industry_adv``. Falls back to ``self.industry_cls``
        :param industry_data: user-supplied industry data, used when no
            classification is available; falls back to ``self.industry_data``.
            Accepted forms: a dict ``{code: industry}`` (repeated for every
            date), a DataFrame with dates as rows and codes as columns, or a
            Series already indexed by (date, code)
        :param frequence: not used by this method; kept for interface
            compatibility
        :param detailed: fetching industry data per date is very slow; when
            False, the classification at the last date of
            ``factor_time_range`` is used for all dates

        Returns
        ---
        :return: industry information indexed by ``["date", "code"]``
        """

        def _is_missing(obj) -> bool:
            # pandas objects have no unambiguous truth value (``not df``
            # raises ValueError), so test emptiness explicitly.
            if obj is None:
                return True
            if isinstance(obj, (pd.Series, pd.DataFrame)):
                return obj.empty
            return not obj

        # 1. Stock codes
        if isinstance(code_list, str):
            # Otherwise MultiIndex.from_product would iterate the characters.
            code_list = [code_list]
        elif isinstance(code_list, tuple):
            code_list = list(code_list)

        # 2. Dates (as plain datetime.date values)
        if not detailed:
            date_range = [pd.Timestamp(max(factor_time_range)).date()]
        else:
            date_range = [x.date() for x in factor_time_range]

        # 3. Industry source: fall back to the instance-level defaults
        if not industry_cls:
            industry_cls = self.industry_cls

        if _is_missing(industry_data):
            industry_data = self.industry_data

        if (not industry_cls) and _is_missing(industry_data):
            warnings.warn("没有指定行业分类方式,也没有输入行业分类信息", UserWarning)
            return pd.Series(
                index=pd.MultiIndex.from_product(
                    [date_range, code_list], names=["date", "code"]
                ),
                name="group",
                data="NA",
            )

        # 4a. No classification given: use the supplied industry data
        if not industry_cls:
            if isinstance(industry_data, dict):
                grouped = pd.DataFrame(
                    index=date_range,
                    data=[industry_data] * len(date_range),
                ).stack(level=-1)
            elif isinstance(industry_data, pd.DataFrame):
                grouped = industry_data.stack(level=-1)
            else:
                grouped = industry_data
            # rename_axis returns a new object, so a caller-supplied Series
            # (or self.industry_data) is not mutated.
            return grouped.rename_axis(["date", "code"])

        # 4b. Classification given: query QA industry data once per date.
        parts = industry_cls.split("_")
        src, level = parts[0], parts[1]
        frames = []
        for cursor_date in date_range:
            df_tmp = QA_fetch_industry_adv(
                code=code_list,
                cursor_date=cursor_date,
                levels=level,
                src=src,
            )[["code", "industry_name"]]
            # ``assign`` avoids writing into a slice of the fetched frame.
            frames.append(df_tmp.assign(date=cursor_date))

        if not frames:
            # No dates to query: return an empty, correctly indexed result.
            return pd.Series(
                [],
                index=pd.MultiIndex.from_tuples([], names=["date", "code"]),
                name="industry_name",
                dtype=object,
            )

        # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
        df_local = pd.concat(frames, ignore_index=True)
        df_local = df_local.set_index(["date", "code"])
        return df_local["industry_name"]