def __init__(self,
             exchange: Union[Exchange, str],
             action_scheme: Union[ActionScheme, str],
             reward_scheme: Union[RewardScheme, str],
             feature_pipeline: Union[FeaturePipeline, str, None] = None,
             **kwargs):
    """
    Arguments:
        exchange: The `Exchange` that will be used to feed data from and execute trades within.
            A string is resolved through `exchanges.get`.
        action_scheme: The component for transforming an action into a `Trade` at each timestep.
            A string is resolved through `actions.get`.
        reward_scheme: The component for determining the reward at each timestep.
            A string is resolved through `rewards.get`.
        feature_pipeline (optional): The pipeline of features to pass the observations through.
            A string is resolved through `features.get`; the resolved pipeline (if any)
            is attached to the exchange.
        kwargs (optional): Additional arguments for tuning the environment, logging, etc.
            Keys read here: `render_benchmarks` (default `[]`), `logger_name`
            (default `__name__`), `log_level` (default `logging.DEBUG`) and
            `disable_tensorflow_logger` (default `True`).

    The environment is reset (`self.reset()`) once all components are wired together.
    """
    super().__init__()

    # Exchange (the trading account); a string is looked up by identifier.
    self._exchange = exchanges.get(exchange) if isinstance(
        exchange, str) else exchange

    # Action scheme; a string is looked up by identifier.
    self._action_scheme = actions.get(action_scheme) if isinstance(
        action_scheme, str) else action_scheme

    # Reward scheme; a string is looked up by identifier.
    self._reward_scheme = rewards.get(reward_scheme) if isinstance(
        reward_scheme, str) else reward_scheme

    # Feature transformation pipeline; a string is looked up by identifier.
    self._feature_pipeline = features.get(feature_pipeline) if isinstance(
        feature_pipeline, str) else feature_pipeline

    # Attach the *resolved* pipeline: the raw argument may be a string
    # identifier, which is not a usable pipeline object for the exchange.
    if self._feature_pipeline is not None:
        self._exchange.feature_pipeline = self._feature_pipeline

    # Link the action scheme and the reward scheme to the same exchange.
    self._action_scheme.exchange = self._exchange
    self._reward_scheme.exchange = self._exchange

    # Observation space comes from the exchange; action space from the action scheme.
    self.observation_space = self._exchange.observation_space
    self.action_space = self._action_scheme.action_space

    self.render_benchmarks: List[Dict] = kwargs.get(
        'render_benchmarks', [])
    self.viewer = None

    self.logger = logging.getLogger(kwargs.get('logger_name', __name__))
    self.logger.setLevel(kwargs.get('log_level', logging.DEBUG))

    # Toggles the shared, process-wide 'tensorflow' logger (disabled by default).
    logging.getLogger('tensorflow').disabled = kwargs.get(
        'disable_tensorflow_logger', True)

    self.reset()
Exemplo n.º 2
0
    def __init__(self,
                 exchange: Union[InstrumentExchange, str],
                 action_strategy: Union[ActionStrategy, str],
                 reward_strategy: Union[RewardStrategy, str],
                 feature_pipeline: Union[FeaturePipeline, str, None] = None,
                 **kwargs):
        """
        Arguments:
            exchange: The `InstrumentExchange` that will be used to feed data from and execute trades within.
                A string is resolved through `exchanges.get`.
            action_strategy:  The strategy for transforming an action into a `Trade` at each timestep.
                A string is resolved through `actions.get`.
            reward_strategy: The strategy for determining the reward at each timestep.
                A string is resolved through `rewards.get`.
            feature_pipeline (optional): The pipeline of features to pass the observations through.
                A string is resolved through `features.get`; the resolved pipeline (if any)
                is attached to the exchange before it is reset.
            kwargs (optional): Additional arguments for tuning the environment, logging, etc.
                Keys read here: `logger_name` (default `__name__`), `log_level`
                (default `logging.DEBUG`) and `disable_tensorflow_logger` (default `True`).
        """
        super().__init__()

        # Resolve each component; strings are looked up by identifier in their registry.
        self._exchange = exchanges.get(exchange) if isinstance(
            exchange, str) else exchange
        self._action_strategy = actions.get(action_strategy) if isinstance(
            action_strategy, str) else action_strategy
        self._reward_strategy = rewards.get(reward_strategy) if isinstance(
            reward_strategy, str) else reward_strategy
        self._feature_pipeline = features.get(feature_pipeline) if isinstance(
            feature_pipeline, str) else feature_pipeline

        # Attach the *resolved* pipeline: the raw argument may be a string
        # identifier, which is not a usable pipeline object for the exchange.
        if self._feature_pipeline is not None:
            self._exchange.feature_pipeline = self._feature_pipeline

        # Reset the exchange after the pipeline is attached and before the
        # strategies are linked and the spaces are read from it.
        self._exchange.reset()

        self._action_strategy.exchange = self._exchange
        self._reward_strategy.exchange = self._exchange

        self.observation_space = self._exchange.observation_space
        self.action_space = self._action_strategy.action_space

        self.logger = logging.getLogger(kwargs.get('logger_name', __name__))
        self.logger.setLevel(kwargs.get('log_level', logging.DEBUG))

        # Toggles the shared, process-wide 'tensorflow' logger (disabled by default).
        logging.getLogger('tensorflow').disabled = kwargs.get(
            'disable_tensorflow_logger', True)
Exemplo n.º 3
0
 def feature_pipeline(self,
                      feature_pipeline: Union[FeaturePipeline, str] = None):
     """Store the feature pipeline, resolving a string identifier first.

     Arguments:
         feature_pipeline (optional): A `FeaturePipeline`, or the string identifier
             of one, which is looked up via `features.get`. Any other value
             (including `None`) is stored as-is.
     """
     resolved = feature_pipeline
     if isinstance(feature_pipeline, str):
         resolved = features.get(feature_pipeline)
     self._feature_pipeline = resolved