def get_start_date(
    self,
    code_list: Union[str, Tuple[str], List[str]],
    factor_time_range: list
):
    """
    Look up the listing date of each stock and repeat it for every factor date.

    Parameters
    ---
    :param code_list: stock code(s); normalised with ``utils.QA_fmt_code_list``
    :param factor_time_range: factor timestamps; each element is reduced to
                              its calendar date through ``.date()``

    Returns
    ---
    :return: Series indexed by ['date', 'code'] that holds the ``list_date``
             value of each code. Codes that are absent from the table returned
             by ``QA_fetch_stock_basic`` are left out of the result.
    """
    wanted_codes = utils.QA_fmt_code_list(code_list)
    basics = QA_fetch_stock_basic(status=None).set_index("code")

    # Only codes present in the stock-basic table can be looked up.
    known_codes = list(basics.index.intersection(wanted_codes))
    listing_dates = basics.loc[known_codes]["list_date"]

    # Date handling: one entry per factor timestamp, as plain calendar dates.
    calendar_days = [moment.date() for moment in factor_time_range]
    date_code_index = pd.MultiIndex.from_product(
        [calendar_days, utils.QA_fmt_code_list(known_codes)]
    )

    def _listing_date_of(key):
        # key is a (date, code) tuple; the listing date depends on the code only
        return listing_dates.loc[key[1]]

    result = pd.Series(
        index=date_code_index,
        data=date_code_index.map(_listing_date_of).tolist(),
    )
    result.index.names = ["date", "code"]
    return result
def get_groupby(
    self,
    code_list: Union[str, Tuple[str], List[str]],
    factor_time_range: list = None,
    industry_cls: str = None,
    industry_data: Union[dict, pd.Series, pd.DataFrame] = None,
    frequence: str = 'DAY',
    detailed: bool = False,
) -> pd.Series:
    """
    Build the industry (group) label for every (date, code) pair.

    Parameters
    ---
    :param code_list: stock code(s); a single string is treated as one code
    :param factor_time_range: factor dates (required)
    :param industry_cls: industry classification name; falls back to
                         ``self.industry_cls`` when empty
    :param industry_data: user supplied industry labels (dict of code -> label,
                          a date x code DataFrame, or a ready-made Series);
                          falls back to ``self.industry_data`` when empty
    :param frequence: accepted for interface compatibility; not used here
    :param detailed: fetching the classification per date is slow. When False,
                     only ``max(factor_time_range)`` is used as the date.

    Returns
    ---
    :return: industry labels indexed by ['date', 'code']. When neither a
             classification nor industry data is available, a warning is
             issued and every label is "NA".

    Raises
    ---
    :raises TypeError: if ``factor_time_range`` is None
    """

    def _is_missing(value) -> bool:
        # bool() on a Series/DataFrame raises ValueError, so test pandas
        # objects through .empty instead of plain truthiness.
        if value is None:
            return True
        if isinstance(value, (pd.Series, pd.DataFrame)):
            return value.empty
        return not value

    # 1. stock codes: a bare string would otherwise be iterated per character
    if isinstance(code_list, str):
        code_list = [code_list]
    elif isinstance(code_list, tuple):
        code_list = list(code_list)

    # 2. dates
    if factor_time_range is None:
        raise TypeError(
            "factor_time_range is required: pass the list of factor dates"
        )
    if not detailed:
        date_range = [pd.Timestamp(max(factor_time_range))]
    else:
        date_range = factor_time_range

    # 3. industry settings, falling back to the instance-level configuration
    if not industry_cls:
        industry_cls = self.industry_cls
    if _is_missing(industry_data):
        industry_data = self.industry_data

    if (not industry_cls) and _is_missing(industry_data):
        warnings.warn("没有指定行业分类方式,也没有输入行业分类信息", UserWarning)
        return pd.Series(
            index=pd.MultiIndex.from_product([date_range, code_list]),
            name="group",
            data="NA",
        )

    # 4. no classification name: use the supplied industry data directly
    if not industry_cls:
        if isinstance(industry_data, dict):
            # the same code -> label mapping applies to every date
            df_local = pd.DataFrame(
                index=date_range,
                data=[industry_data] * len(date_range)
            ).stack(level=-1)
        elif isinstance(industry_data, pd.DataFrame):
            df_local = industry_data.stack(level=-1)
        else:
            df_local = industry_data
        # rename_axis returns a new object, so the caller's data keeps its
        # own index names
        return df_local.rename_axis(["date", "code"])

    # A classification name was given: query the provider for each date.
    # FIXME: temporarily relies on JoinQuant (jqdatasdk) industry data
    stock_list = utils.QA_fmt_code_list(code_list, style="jq")
    industries = {}
    for cursor_date in date_range:
        by_stock = jqdatasdk.get_industry(stock_list, cursor_date)
        industries[cursor_date] = {
            # a stock missing from the answer is labelled "NA" instead of
            # crashing on None.get(...)
            stock: (by_stock.get(stock) or {})
            .get(industry_cls, {})
            .get("industry_name", "NA")
            for stock in stock_list
        }
    df_local = pd.DataFrame(industries).T.sort_index()
    # keep only the first six characters of each jq-style code
    df_local.columns = df_local.columns.map(lambda x: x[0:6])
    df_local = df_local.stack(level=-1)
    df_local.index.names = ["date", "code"]
    return df_local
def __gen_clean_factor_and_forward_returns(self):
    """
    Format the factor data and attach forward returns, groups and weights.

    ``self.prices``, ``self.groupby``, ``self.stock_start_date`` and
    ``self.weights`` may each be either ready-made data or a callable loader.
    Loaders are called here with the factor's codes and dates, and the
    attribute is overwritten with the loaded data. The combined result of
    ``get_clean_factor_and_forward_returns`` is stored in
    ``self._clean_factor_data``.
    """

    def _resolve(source, **kwargs):
        # Call loaders with the given arguments; pass ready-made data through.
        return source(**kwargs) if callable(source) else source

    factor_data = self.factor

    # Stock codes: converted to the format QA supports
    code_list = utils.QA_fmt_code_list(
        list(factor_data.index.get_level_values("code").drop_duplicates())
    )

    # Factor dates. MultiIndex.levels can keep values that no longer occur in
    # the index (e.g. after the factor was sliced), so prune unused level
    # values before reading the dates.
    factor_time_range = list(
        factor_data.index.remove_unused_levels().levels[0]
    )
    start_time = min(factor_time_range)
    end_time = max(factor_time_range)

    # Attached data
    if callable(self.prices):
        prices = self.prices(
            code_list=code_list,
            start_time=start_time,
            end_time=end_time,
            frequence=self.frequence,
        )
        # drop rows whose index label repeats, keeping the first occurrence
        prices = prices.loc[~prices.index.duplicated()]
    else:
        prices = self.prices
    self.prices = prices

    self.groupby = _resolve(
        self.groupby,
        code_list=code_list,
        factor_time_range=factor_time_range,
    )
    self.stock_start_date = _resolve(
        self.stock_start_date,
        code_list=code_list,
        factor_time_range=factor_time_range,
    )
    self.weights = _resolve(
        self.weights,
        code_list=code_list,
        factor_time_range=factor_time_range,
        frequence=self.frequence,
    )

    # Factor processing
    self._clean_factor_data = get_clean_factor_and_forward_returns(
        factor=factor_data,
        prices=self.prices,
        groupby=self.groupby,
        stock_start_date=self.stock_start_date,
        weights=self.weights,
        binning_by_group=self.binning_by_group,
        quantiles=self.quantiles,
        bins=self.bins,
        periods=self.periods,
        max_loss=self.max_loss,
        zero_aware=self.zero_aware,
        frequence=self.frequence,
    )
def get_groupby(
    self,
    code_list: Union[str, Tuple[str], List[str]],
    factor_time_range: list = None,
    industry_cls: str = None,
    industry_data: Union[dict, pd.Series, pd.DataFrame] = None,
    frequence: str = 'DAY',
    detailed: bool = False,
) -> pd.Series:
    """
    Build the industry (group) label for every (date, code) pair.

    Parameters
    ---
    :param code_list: stock code(s); a single string is treated as one code
    :param factor_time_range: factor timestamps (required); reduced to
                              calendar dates through ``.date()``
    :param industry_cls: industry classification written as
                         ``"<src>_<levels>"``; the part before the first
                         underscore is passed to ``QA_fetch_industry_adv`` as
                         ``src`` and the part after it as ``levels``. Falls
                         back to ``self.industry_cls`` when empty.
    :param industry_data: user supplied industry labels (dict of code -> label,
                          a date x code DataFrame, or a ready-made Series);
                          falls back to ``self.industry_data`` when empty
    :param frequence: accepted for interface compatibility; not used here
    :param detailed: fetching the classification per date is slow. When False,
                     only ``max(factor_time_range)`` is used as the date.

    Returns
    ---
    :return: industry labels indexed by ['date', 'code']. When neither a
             classification nor industry data is available, a warning is
             issued and every label is "NA".

    Raises
    ---
    :raises TypeError: if ``factor_time_range`` is None
    :raises ValueError: if ``industry_cls`` contains no underscore
    """

    def _is_missing(value) -> bool:
        # bool() on a Series/DataFrame raises ValueError, so test pandas
        # objects through .empty instead of plain truthiness.
        if value is None:
            return True
        if isinstance(value, (pd.Series, pd.DataFrame)):
            return value.empty
        return not value

    # 1. stock codes: a bare string would otherwise be iterated per character
    if isinstance(code_list, str):
        code_list = [code_list]
    elif isinstance(code_list, tuple):
        code_list = list(code_list)

    # 2. dates
    if factor_time_range is None:
        raise TypeError(
            "factor_time_range is required: pass the list of factor dates"
        )
    if not detailed:
        date_range = [pd.Timestamp(max(factor_time_range)).date()]
    else:
        date_range = [moment.date() for moment in factor_time_range]

    # 3. industry settings, falling back to the instance-level configuration
    if not industry_cls:
        industry_cls = self.industry_cls
    if _is_missing(industry_data):
        industry_data = self.industry_data

    if (not industry_cls) and _is_missing(industry_data):
        warnings.warn("没有指定行业分类方式,也没有输入行业分类信息", UserWarning)
        return pd.Series(
            index=pd.MultiIndex.from_product([date_range, code_list]),
            name="group",
            data="NA",
        )

    # 4. no classification name: use the supplied industry data directly
    if not industry_cls:
        if isinstance(industry_data, dict):
            # the same code -> label mapping applies to every date
            df_local = pd.DataFrame(
                index=date_range,
                data=[industry_data] * len(date_range)
            ).stack(level=-1)
        elif isinstance(industry_data, pd.DataFrame):
            df_local = industry_data.stack(level=-1)
        else:
            df_local = industry_data
        # rename_axis returns a new object, so the caller's data keeps its
        # own index names
        return df_local.rename_axis(["date", "code"])

    # A classification name was given: fetch the labels for each date.
    cls_parts = industry_cls.split("_")
    if len(cls_parts) < 2:
        raise ValueError(
            "industry_cls must look like '<src>_<levels>', got {!r}".format(
                industry_cls
            )
        )
    src, levels = cls_parts[0], cls_parts[1]

    # NOTE(review): the normalised codes are passed to the fetch; the original
    # computed them but passed the raw code_list. Confirm that
    # QA_fetch_industry_adv expects QA-formatted codes.
    stock_list = utils.QA_fmt_code_list(code_list)
    frames = [
        QA_fetch_industry_adv(
            code=stock_list,
            cursor_date=cursor_date,
            levels=levels,
            src=src,
        )[["code", "industry_name"]].assign(date=cursor_date)
        for cursor_date in date_range
    ]
    if not frames:
        # empty date range: return an empty, correctly shaped result
        return pd.Series(
            index=pd.MultiIndex.from_arrays([[], []], names=["date", "code"]),
            name="industry_name",
            dtype=object,
        )

    # One concat replaces the per-date DataFrame.append, which was removed in
    # pandas 2.0 and copied the accumulated frame on every iteration.
    df_local = pd.concat(frames).set_index(["date", "code"])
    return df_local["industry_name"]