def reg_fields_from_local_step(cls):
    """Register each init-able dataclass field of ``cls`` as a target of ``GetContextStep.val``.

    For every field of ``cls`` whose ``init`` flag is set, one mapping from
    ``GetContextStep`` to ``cls`` is registered in ``GlobalGSStepMapping``.
    The rule name is the field's name and the mapping pairs
    ``GetContextStep.val`` with that field object.

    Parameters
    ----------
    cls
        A dataclass that is a subclass of ``GSStep``; both conditions are
        checked with ``assert``.
    """
    assert issubclass(cls, GSStep)
    assert is_dataclass(cls)
    # Fields declared with init=False cannot be passed to the constructor,
    # so no rule is registered for them.
    init_fields = [candidate for candidate in fields(cls) if candidate.init]
    for target_field in init_fields:
        GlobalGSStepMapping.register(
            GetContextStep,
            cls,
            rule_name=target_field.name,
            diff_name={GetContextStep.val: target_field},
        )
def _input_steps_to_kwargs(_input_steps: Union['GSStep', Tuple['GSStep', str], Iterable[Union['GSStep', Tuple['GSStep', str]]]],
                           to_step_type: Type['GSStep']) -> Mapping[str, Any]:
    """Build a kwargs dict for ``to_step_type`` from the given input steps.

    Each input step is looked up in ``GlobalGSStepMapping`` together with its
    rule name, and the step converts the registered mapping into keyword
    arguments via ``map_to_kwargs``. The per-step results are merged in
    iteration order, so a later step overrides an earlier one on a key clash.

    Parameters
    ----------
    _input_steps
        A single step, a ``(step, rule_name)`` pair, or an iterable of either;
        normalised by ``_input_steps_to_list``.
    to_step_type
        The step class the kwargs are being built for.

    Returns
    -------
    Mapping[str, Any]
        The merged keyword arguments (empty when no step contributes any).
    """
    from gs_research_workflow.core.gs_step_mapping import GlobalGSStepMapping

    # Step 1: normalise the input into an iterable of (step, rule_name) pairs.
    normalised_steps = _input_steps_to_list(_input_steps)

    # Step 2: collect the mapped field values of every input step.
    merged_kwargs: Mapping[str, Any] = dict()
    for source_step, rule_name in normalised_steps:
        registered_mapping = GlobalGSStepMapping.get_registered(
            source_step.__class__, to_step_type, rule_name)
        step_kwargs = source_step.map_to_kwargs(registered_mapping)
        if not step_kwargs:
            continue
        merged_kwargs.update(step_kwargs)
    return merged_kwargs
def get_init_value_dict(self, out_self_cls: bool = False) -> Mapping[str, Any]:
    """Return the dictionary of values this step was initialised with.

    Notes
    -----
    * Deliberately a method rather than a property, to avoid adding a
      property that has nothing to do with the business meaning of the
      dataclass definition.
    * The recursively nested dict is produced lazily; initialisation itself
      only keeps the data-link relationships.

    Directly supplied field values are deep-copied, except ``GSStep`` values,
    which are expanded recursively. Each entry of ``_ls_input_steps`` is
    stored under a key of the form ``"#field1,field2#"`` whose value is the
    source step's own init dict (wrapped in its class name). That nested dict
    also receives ``_KEY_PROPERTIES`` and, when a rule name is present,
    ``_KEY_RULE_NAME``.

    Parameters
    ----------
    out_self_cls : bool
        Whether to wrap the result in one extra level keyed by the current
        class (as rendered by ``cls_to_str``).
    """
    import copy
    from gs_research_workflow.core.gs_step_mapping import GlobalGSStepMapping

    # TODO: the composition case still needs to be handled here
    init_values = dict()

    for field_name, field_value in (self._direct_init_field_value or {}).items():
        init_values[field_name] = (
            field_value.get_init_value_dict(out_self_cls)
            if isinstance(field_value, GSStep)
            else copy.deepcopy(field_value)
        )

    for step_entry in (self._ls_input_steps or ()):
        source_step, rule_name = step_entry[0], step_entry[1]
        registered_mapping = GlobalGSStepMapping.get_registered(
            source_step.__class__, self.__class__, rule_name)
        # Piped inputs use the key format "#field1,field2#".
        pipe_key = "#" + ",".join(registered_mapping.field_names) + "#"
        nested_values = source_step.get_init_value_dict(True)
        nested_values[_KEY_PROPERTIES] = ",".join(registered_mapping.property_names)
        if rule_name is not None:
            nested_values[_KEY_RULE_NAME] = rule_name
        init_values[pipe_key] = nested_values

    return {cls_to_str(self.__class__): init_values} if out_self_cls else init_values
# Tail of the routine that performs the train/validation split. Its ``def``
# line and the creation of ``random_df`` lie above this excerpt.
train_pos = int(len(random_df) * self.split_ratio)
# Rows before ``train_pos`` become the training set, the rest the validation set.
self._train_set, self._val_set = random_df.iloc[:train_pos], random_df.iloc[train_pos:]
if self.f_after_split is not None:
    self._train_set = self.f_after_split(self._train_set)
    # Fix: post-process the validation split itself. The original passed
    # ``self._train_set`` here, which discarded the validation rows and
    # stored a second transformation of the training data instead.
    self._val_set = self.f_after_split(self._val_set)


@property
def train_set(self) -> pd.DataFrame:
    """Return the training split stored in ``_train_set``."""
    return self._train_set


@property
def val_set(self) -> pd.DataFrame:
    """Return the validation split stored in ``_val_set``."""
    return self._val_set


# Default rule: a SymbolTSStep's ``ts_data`` supplies the data to be split.
GlobalGSStepMapping.register(SymbolTSStep, TrainValSpiltStep,
                             diff_name={SymbolTSStep.ts_data: TrainValSpiltStep.train_val_orig_data})

# Default rule: a FuncStrStep's ``func`` supplies the post-split hook.
GlobalGSStepMapping.register(FuncStrStep, TrainValSpiltStep,
                             diff_name={FuncStrStep.func: TrainValSpiltStep.f_after_split})

# Named rule: a FuncStrStep's ``func_result`` supplies the data to be split.
GlobalGSStepMapping.register(FuncStrStep, TrainValSpiltStep, rule_name="train_val_orig_data",
                             diff_name={FuncStrStep.func_result: TrainValSpiltStep.train_val_orig_data})

if __name__ == "__main__":
    import pprint

    pp = pprint.PrettyPrinter(indent=1, compact=True)
    symbol_ts_step = SymbolTSStep(api="index_weight",
                                  symbols={"BigCap": "000043.SH", "MidCap": "000044.SH", "SmlCap": "000045.SH"},
                                  cols=["con_code"])
    f_after_split_step = FuncStrStep(
        func_body="lambda df : {t: set(df[df.index == t][df.columns[0]].tolist()) for t in df.index.unique()}")
    train_val_set = TrainValSpiltStep(_input_steps=[symbol_ts_step, f_after_split_step], split_ratio=0.85)
break @property def tf_ds(self) -> tf.data.Dataset: if self.ds_pip is None: return self._tf_ds else: return self._tf_ds_with_pip reg_fields_from_local_step(TSPortfolioWeightTFDSStep) GlobalGSStepMapping.register( TSPortfolioWeightInputStep, TSPortfolioWeightTFDSStep, diff_name={ TSPortfolioWeightInputStep.all_portfolio_callable: TSPortfolioWeightTFDSStep.all_portfolio_callable }) GlobalGSStepMapping.register(TSPortfolioWeightInputStep, TSPortfolioWeightTFDSStep, diff_name={ TSPortfolioWeightInputStep.market_indicator: TSPortfolioWeightTFDSStep.market_indicator }, rule_name="market_indicator") GlobalGSStepMapping.register(SymbolMultipleTSStep, TSPortfolioWeightTFDSStep, diff_name={
df = df_align.join(df_stk) if self.f_fill_na: df = self.f_fill_na(df) df_rlt = df[df.columns[1:]] return df_rlt @property def ts_callable_by_lookback(self) -> SymbolPeriodTSByLookbackCallable: return self._get_ts_by_lookback GlobalGSStepMapping.register( SymbolTSStep, TSPeriodTSByLookbackStep, rule_name="period_ts_callable", diff_name={ SymbolTSStep.symbol_period_ts_callable: TSPeriodTSByLookbackStep.symbol_period_ts_callable }) GlobalGSStepMapping.register( SymbolMultipleTSStep, TSPeriodTSByLookbackStep, rule_name="period_ts_callable", diff_name={ SymbolMultipleTSStep.symbol_period_ts_callable: TSPeriodTSByLookbackStep.symbol_period_ts_callable }) GlobalGSStepMapping.register( SymbolTSStep,
b_append_pred=False) # print(f"\t {start_t} - {end_t} - {idx_to_mask} - {start_idx}") yield (input_ids, position_id, token_id, attention_mask_id), y_true # return for debug # return (input_ids, position_id, token_id, attention_mask_id), y_true @property def tf_ds(self) -> tf.data.Dataset: if self.ds_pip is None: return self._tf_ds else: return self._tf_ds_with_pip GlobalGSStepMapping.register(ChnEquityInputStep, FinancialStatementCSMaskedTFDatasetStep, diff_name={ ChnEquityInputStep.train_items: FinancialStatementCSMaskedTFDatasetStep.df_equities}, rule_name="train") GlobalGSStepMapping.register(ChnEquityInputStep, FinancialStatementCSMaskedTFDatasetStep, diff_name={ ChnEquityInputStep.val_items: FinancialStatementCSMaskedTFDatasetStep.df_equities}, rule_name="validation") GlobalGSStepMapping.register(ChnEquityInputStep, FinancialStatementCSMaskedTFDatasetStep, diff_name={ ChnEquityInputStep.eval_items: FinancialStatementCSMaskedTFDatasetStep.df_equities}, rule_name="evaluate") if __name__ == "__main__": def cs_financial_statement_model_evaluate(): from gs_research_workflow.time_series.models.ts_bert import TSBertForMaskedCS from gs_research_workflow.time_series.gs_steps.model_steps import TFModelStep # 显示所有列 pd.set_option('display.max_columns', None) # 显示所有行
@property
def func_result(self) -> Any:
    """Call ``self.func`` and return whatever it returns.

    Positional arguments are ``[self.single_input]`` when ``single_input`` is
    not ``None``, otherwise ``self.args`` (or no arguments when that is
    ``None``). Keyword arguments are ``self.kwargs``, with ``None`` or an
    empty value treated as no keyword arguments.
    """
    positional = self.args
    if self.single_input is not None:
        # A single input takes precedence over any configured ``args``.
        positional = [self.single_input]
    elif positional is None:
        positional = []
    keyword = self.kwargs or {}
    return self.func(*positional, **keyword)


# Rules for chaining functions into a pipeline: the result of one FuncStrStep
# feeds the ``single_input``, ``args`` or ``kwargs`` of the next one.
GlobalGSStepMapping.register(
    FuncStrStep, FuncStrStep, rule_name="single_ret_pip",
    diff_name={FuncStrStep.func_result: FuncStrStep.single_input})
GlobalGSStepMapping.register(
    FuncStrStep, FuncStrStep, rule_name="args_ret_pip",
    diff_name={FuncStrStep.func_result: FuncStrStep.args})
GlobalGSStepMapping.register(
    FuncStrStep, FuncStrStep, rule_name="kwargs_ret_pip",
    diff_name={FuncStrStep.func_result: FuncStrStep.kwargs})
checkpoint_path = self.model_cls.model_checkpoint_path( self.model_init_hp) eval_model = self.model_cls.from_pre_saved(checkpoint_path) eval_result = eval_model.evaluate(self.test_ds) logger.info(f"eval result : {eval_result}") loss, default_metrics, *_ = eval_result if getattr(self, "metrics_reporter", None) is not None: logger.info(f"report model final result {default_metrics}") await self.report_final_result(default_metrics) await asyncio.sleep(60.) GlobalGSStepMapping.register( FuncStrStep, TFTrainStep, rule_name="train_ds", diff_name={FuncStrStep.func_result: TFTrainStep.train_ds}) GlobalGSStepMapping.register( FuncStrStep, TFTrainStep, rule_name="val_ds", diff_name={FuncStrStep.func_result: TFTrainStep.val_ds}) GlobalGSStepMapping.register( FuncStrStep, TFTrainStep, rule_name="test_ds", diff_name={FuncStrStep.func_result: TFTrainStep.test_ds})
assert self.symbols is None return lambda symbol: self._func_after_query( getattr(self._sdk, self.api)(symbol, self.start_t, self.end_t, cols=self.cols)) @property def symbol_period_ts_callable(self) -> SymbolPeriodTSCallable: """按照 symbol 获取一段时间的 callable 对象 """ assert self.symbols is None and self.start_t is None and self.end_t is None return lambda symbol, start_t, end_t: self._func_after_query( getattr(self._sdk, self.api)(symbol, start_t, end_t, cols=self.cols)) reg_fields_from_local_step(SymbolTSStep) # ts 取值的结果能够套入一个函数的输出内容 GlobalGSStepMapping.register(SymbolTSStep, FuncStrStep, rule_name="ts_process", diff_name={SymbolTSStep.ts_data: FuncStrStep.single_input}) GlobalGSStepMapping.register(KeyValueListToMappingStep, SymbolTSStep, rule_name="symbols", diff_name={KeyValueListToMappingStep.mapping_data: SymbolTSStep.symbols}) @dataclass class SymbolMultipleTSStep(GSStep): """适用于将多个 Symbol TS 的数据按照 t Join 之后的 callable 或者 ts_data 先简化一些,只能来自于一个 sdk wrapper 的多个接口 """ data_query_class: str """query的类定义,如 tushare""" apis_and_columns: Dict[str, Tuple[str, List[str]]]