コード例 #1
0
def reg_fields_from_local_step(cls):
    """Register one ``GetContextStep`` -> ``cls`` mapping rule per init field of ``cls``.

    For every dataclass field of ``cls`` that takes part in ``__init__``, a
    rule named after the field is registered in ``GlobalGSStepMapping``. The
    rule maps ``GetContextStep.val`` onto that field (the ``dataclasses.Field``
    object itself is used as the target). This lets each such field be filled
    from a ``GetContextStep``.

    :param cls: a dataclass subclass of ``GSStep``; anything else trips the
        assertions below (assertions are stripped under ``python -O``).
    """
    assert issubclass(cls, GSStep)
    assert is_dataclass(cls)
    # Fields declared with ``init=False`` are not constructor arguments; skip them.
    for init_field in (fld for fld in fields(cls) if fld.init):
        GlobalGSStepMapping.register(
            GetContextStep,
            cls,
            rule_name=init_field.name,
            diff_name={GetContextStep.val: init_field},
        )
コード例 #2
0
        train_pos = int(len(random_df) * self.split_ratio)
        self._train_set, self._val_set = random_df.iloc[:train_pos], random_df.iloc[train_pos:]
        if self.f_after_split is not None:
            self._train_set = self.f_after_split(self._train_set)
            self._val_set = self.f_after_split(self._train_set)

    @property
    def train_set(self) -> pd.DataFrame:
        """Training partition held in ``self._train_set``.

        NOTE(review): when ``f_after_split`` is set, this holds whatever that
        callable returned, which need not be a ``pd.DataFrame``.
        """
        train_part = self._train_set
        return train_part

    @property
    def val_set(self) -> pd.DataFrame:
        """Validation partition held in ``self._val_set``.

        NOTE(review): the split code assigns
        ``self._val_set = self.f_after_split(self._train_set)``. That
        post-processes the *training* slice, so whenever ``f_after_split`` is
        set this property currently returns processed training data rather
        than validation data. Looks like a copy/paste bug in the split code;
        confirm and fix there. With ``f_after_split`` set, the value also need
        not be a ``pd.DataFrame``.
        """
        validation_part = self._val_set
        return validation_part


# A SymbolTSStep's ``ts_data`` becomes the data to split (no explicit rule_name).
GlobalGSStepMapping.register(
    SymbolTSStep,
    TrainValSpiltStep,
    diff_name={SymbolTSStep.ts_data: TrainValSpiltStep.train_val_orig_data},
)
# A FuncStrStep's ``func`` becomes the post-split transform ``f_after_split``.
GlobalGSStepMapping.register(
    FuncStrStep,
    TrainValSpiltStep,
    diff_name={FuncStrStep.func: TrainValSpiltStep.f_after_split},
)
# Under rule "train_val_orig_data", a FuncStrStep's ``func_result`` becomes the data to split.
GlobalGSStepMapping.register(
    FuncStrStep,
    TrainValSpiltStep,
    rule_name="train_val_orig_data",
    diff_name={FuncStrStep.func_result: TrainValSpiltStep.train_val_orig_data},
)

if __name__ == "__main__":
    # Ad-hoc demo wiring a SymbolTSStep and a FuncStrStep into a TrainValSpiltStep.
    # NOTE(review): this snippet may continue beyond what is shown; `pp` and
    # `train_val_set` are not used within the visible lines.
    import pprint

    pp = pprint.PrettyPrinter(indent=1, compact=True)

    # Query the "index_weight" api for three labelled symbols, keeping only "con_code".
    symbol_ts_step = SymbolTSStep(api="index_weight",
                                  symbols={"BigCap": "000043.SH", "MidCap": "000044.SH", "SmlCap": "000045.SH"},
                                  cols=["con_code"])
    # Transform applied after the split: for each distinct index value t, the set of
    # first-column values of the rows indexed by t.
    f_after_split_step = FuncStrStep(func_body="lambda df : {t: set(df[df.index == t][df.columns[0]].tolist()) for t in df.index.unique()}")
    # split_ratio=0.85: the first 85% of rows go to the training partition, the rest to validation.
    train_val_set = TrainValSpiltStep(_input_steps=[symbol_ts_step, f_after_split_step], split_ratio=0.85)
コード例 #3
0
                break

    @property
    def tf_ds(self) -> tf.data.Dataset:
        """Dataset to consume, picking the piped variant when a pipeline is set.

        Returns ``self._tf_ds`` when ``self.ds_pip`` is ``None``, otherwise
        ``self._tf_ds_with_pip``.
        """
        has_pipeline = self.ds_pip is not None
        return self._tf_ds_with_pip if has_pipeline else self._tf_ds


# Per-field GetContextStep rules for TSPortfolioWeightTFDSStep (via reg_fields_from_local_step).
reg_fields_from_local_step(TSPortfolioWeightTFDSStep)

# all_portfolio_callable -> same-named field (no explicit rule_name).
GlobalGSStepMapping.register(
    TSPortfolioWeightInputStep,
    TSPortfolioWeightTFDSStep,
    diff_name={
        TSPortfolioWeightInputStep.all_portfolio_callable: TSPortfolioWeightTFDSStep.all_portfolio_callable,
    },
)

# market_indicator -> same-named field, under rule "market_indicator".
GlobalGSStepMapping.register(
    TSPortfolioWeightInputStep,
    TSPortfolioWeightTFDSStep,
    diff_name={
        TSPortfolioWeightInputStep.market_indicator: TSPortfolioWeightTFDSStep.market_indicator,
    },
    rule_name="market_indicator",
)

GlobalGSStepMapping.register(SymbolMultipleTSStep,
                             TSPortfolioWeightTFDSStep,
                             diff_name={
コード例 #4
0
        df = df_align.join(df_stk)
        if self.f_fill_na:
            df = self.f_fill_na(df)
        df_rlt = df[df.columns[1:]]

        return df_rlt

    @property
    def ts_callable_by_lookback(self) -> SymbolPeriodTSByLookbackCallable:
        """Expose ``self._get_ts_by_lookback`` as this step's lookback callable.

        The attribute is returned as-is; nothing is called here.
        """
        lookback_fn = self._get_ts_by_lookback
        return lookback_fn


# Both SymbolTSStep and SymbolMultipleTSStep can supply
# TSPeriodTSByLookbackStep.symbol_period_ts_callable; each registration uses
# the rule name "period_ts_callable".
GlobalGSStepMapping.register(
    SymbolTSStep,
    TSPeriodTSByLookbackStep,
    rule_name="period_ts_callable",
    diff_name={SymbolTSStep.symbol_period_ts_callable: TSPeriodTSByLookbackStep.symbol_period_ts_callable},
)

GlobalGSStepMapping.register(
    SymbolMultipleTSStep,
    TSPeriodTSByLookbackStep,
    rule_name="period_ts_callable",
    diff_name={SymbolMultipleTSStep.symbol_period_ts_callable: TSPeriodTSByLookbackStep.symbol_period_ts_callable},
)

GlobalGSStepMapping.register(
    SymbolTSStep,
コード例 #5
0
                    b_append_pred=False)

                # print(f"\t {start_t} - {end_t} - {idx_to_mask} - {start_idx}")
                yield (input_ids, position_id, token_id, attention_mask_id), y_true
                # return for debug
                # return (input_ids, position_id, token_id, attention_mask_id), y_true

    @property
    def tf_ds(self) -> tf.data.Dataset:
        """Dataset to consume, picking the piped variant when a pipeline is set.

        Returns ``self._tf_ds`` when ``self.ds_pip`` is ``None``, otherwise
        ``self._tf_ds_with_pip``.
        """
        use_pipeline = self.ds_pip is not None
        return self._tf_ds_with_pip if use_pipeline else self._tf_ds


# ChnEquityInputStep feeds FinancialStatementCSMaskedTFDatasetStep.df_equities from a
# different item list per rule: "train", "validation" and "evaluate".
GlobalGSStepMapping.register(
    ChnEquityInputStep,
    FinancialStatementCSMaskedTFDatasetStep,
    diff_name={ChnEquityInputStep.train_items: FinancialStatementCSMaskedTFDatasetStep.df_equities},
    rule_name="train",
)

GlobalGSStepMapping.register(
    ChnEquityInputStep,
    FinancialStatementCSMaskedTFDatasetStep,
    diff_name={ChnEquityInputStep.val_items: FinancialStatementCSMaskedTFDatasetStep.df_equities},
    rule_name="validation",
)

GlobalGSStepMapping.register(
    ChnEquityInputStep,
    FinancialStatementCSMaskedTFDatasetStep,
    diff_name={ChnEquityInputStep.eval_items: FinancialStatementCSMaskedTFDatasetStep.df_equities},
    rule_name="evaluate",
)

if __name__ == "__main__":
    def cs_financial_statement_model_evaluate():
        from gs_research_workflow.time_series.models.ts_bert import TSBertForMaskedCS
        from gs_research_workflow.time_series.gs_steps.model_steps import TFModelStep

        # 显示所有列
        pd.set_option('display.max_columns', None)
        # 显示所有行
コード例 #6
0
    @property
    def func_result(self) -> Any:
        """Call ``self.func`` and return its result.

        Positional arguments come from ``self.args``. If ``self.single_input``
        is not ``None``, it replaces ``self.args`` entirely as the sole
        positional argument. Keyword arguments come from ``self.kwargs``.

        A ``None`` ``args`` means no positional arguments. A ``None`` or empty
        ``kwargs`` means no keyword arguments.
        """
        args = self.args
        single = self.single_input
        if single is not None:
            # A single piped-in value takes precedence over ``args``.
            args = [single]
        if args is None:
            args = []
        # Equivalent to the former `{} if kwargs is None else kwargs or {}`.
        kwargs = self.kwargs or {}
        return self.func(*args, **kwargs)


# Rules for chaining FuncStrStep -> FuncStrStep into a function pipeline: the
# upstream ``func_result`` becomes the downstream ``single_input``, ``args`` or
# ``kwargs``, depending on which rule is selected.
GlobalGSStepMapping.register(
    FuncStrStep, FuncStrStep,
    rule_name="single_ret_pip",
    diff_name={FuncStrStep.func_result: FuncStrStep.single_input},
)

GlobalGSStepMapping.register(
    FuncStrStep, FuncStrStep,
    rule_name="args_ret_pip",
    diff_name={FuncStrStep.func_result: FuncStrStep.args},
)

GlobalGSStepMapping.register(
    FuncStrStep, FuncStrStep,
    rule_name="kwargs_ret_pip",
    diff_name={FuncStrStep.func_result: FuncStrStep.kwargs},
)
コード例 #7
0
        checkpoint_path = self.model_cls.model_checkpoint_path(
            self.model_init_hp)
        eval_model = self.model_cls.from_pre_saved(checkpoint_path)
        eval_result = eval_model.evaluate(self.test_ds)
        logger.info(f"eval result : {eval_result}")
        loss, default_metrics, *_ = eval_result

        if getattr(self, "metrics_reporter", None) is not None:
            logger.info(f"report model final result {default_metrics}")
            await self.report_final_result(default_metrics)
            await asyncio.sleep(60.)


# A FuncStrStep's ``func_result`` can supply each of TFTrainStep's datasets;
# the rule name selects which one ("train_ds", "val_ds" or "test_ds").
GlobalGSStepMapping.register(
    FuncStrStep, TFTrainStep,
    rule_name="train_ds",
    diff_name={FuncStrStep.func_result: TFTrainStep.train_ds},
)

GlobalGSStepMapping.register(
    FuncStrStep, TFTrainStep,
    rule_name="val_ds",
    diff_name={FuncStrStep.func_result: TFTrainStep.val_ds},
)

GlobalGSStepMapping.register(
    FuncStrStep, TFTrainStep,
    rule_name="test_ds",
    diff_name={FuncStrStep.func_result: TFTrainStep.test_ds},
)
コード例 #8
0
        assert self.symbols is None
        return lambda symbol: self._func_after_query(
            getattr(self._sdk, self.api)(symbol, self.start_t, self.end_t, cols=self.cols))

    @property
    def symbol_period_ts_callable(self) -> SymbolPeriodTSCallable:
        """Return a callable that fetches one symbol over a caller-chosen period.

        Only valid when ``symbols``, ``start_t`` and ``end_t`` are all ``None``
        (asserted). The returned ``f(symbol, start_t, end_t)`` calls
        ``getattr(self._sdk, self.api)(symbol, start_t, end_t, cols=self.cols)``.
        It then passes that result through ``self._func_after_query``.
        """
        assert self.symbols is None and self.start_t is None and self.end_t is None

        def query_period(symbol, start_t, end_t):
            # Resolve attributes in the same order the call expression would.
            post_process = self._func_after_query
            query = getattr(self._sdk, self.api)
            return post_process(query(symbol, start_t, end_t, cols=self.cols))

        return query_period


# Per-field GetContextStep rules for SymbolTSStep (via reg_fields_from_local_step).
reg_fields_from_local_step(SymbolTSStep)

# The time series fetched by a SymbolTSStep can be fed into a FuncStrStep as its
# single input argument (rule "ts_process").
GlobalGSStepMapping.register(
    SymbolTSStep,
    FuncStrStep,
    rule_name="ts_process",
    diff_name={SymbolTSStep.ts_data: FuncStrStep.single_input},
)

# SymbolTSStep.symbols can be filled from a KeyValueListToMappingStep's
# ``mapping_data`` (rule "symbols").
GlobalGSStepMapping.register(
    KeyValueListToMappingStep,
    SymbolTSStep,
    rule_name="symbols",
    diff_name={KeyValueListToMappingStep.mapping_data: SymbolTSStep.symbols},
)


@dataclass
class SymbolMultipleTSStep(GSStep):
    """适用于将多个 Symbol TS 的数据按照 t Join 之后的 callable 或者 ts_data
    先简化一些,只能来自于一个 sdk wrapper 的多个接口
    """

    data_query_class: str
    """query的类定义,如 tushare"""

    apis_and_columns: Dict[str, Tuple[str, List[str]]]