Example #1
def state_feature_config_provider(
         self) -> ModelFeatureConfigProvider__Union:
     """ For online gym """
     raw = RawModelFeatureConfigProvider(
         float_feature_infos=[
             rlt.FloatFeatureInfo(name="arm0_sample", feature_id=0),
             rlt.FloatFeatureInfo(name="arm1_sample", feature_id=1),
             rlt.FloatFeatureInfo(name="arm2_sample", feature_id=2),
             rlt.FloatFeatureInfo(name="arm3_sample", feature_id=3),
             rlt.FloatFeatureInfo(name="arm4_sample", feature_id=4),
         ],
         id_list_feature_configs=[
             rlt.IdListFeatureConfig(name="legal",
                                     feature_id=100,
                                     id_mapping_name="legal_actions")
         ],
         id_score_list_feature_configs=[
             rlt.IdScoreListFeatureConfig(name="mu_changes",
                                          feature_id=1000,
                                          id_mapping_name="arms_list")
         ],
         id_mapping_config={
             "legal_actions": rlt.IdMapping(ids=[0, 1, 2, 3, 4, 5]),
             "arms_list": rlt.IdMapping(ids=[0, 1, 2, 3, 4]),
         },
     )
     # pyre-fixme[16]: `ModelFeatureConfigProvider__Union` has no attribute
     #  `make_union_instance`.
     return ModelFeatureConfigProvider__Union.make_union_instance(raw)
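
For orientation, a minimal sketch of how a provider built this way is typically consumed. The rlt import path is an assumption (it differs between ReAgent versions); make_union_instance and get_model_feature_config are the calls that appear in the examples on this page.

# Sketch only: the rlt import path is an assumption, and RawModelFeatureConfigProvider /
# ModelFeatureConfigProvider__Union are assumed to be imported as in the example above.
from reagent.core import types as rlt

raw = RawModelFeatureConfigProvider(
    float_feature_infos=[
        rlt.FloatFeatureInfo(name="arm0_sample", feature_id=0),
    ]
)
# Wrap the raw provider in the union, then resolve the concrete provider and ask it
# for the model feature config, as the state_feature_config properties below do.
provider_union = ModelFeatureConfigProvider__Union.make_union_instance(raw)
feature_config = provider_union.value.get_model_feature_config()
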
Example #2
class Reinforce(ModelManager):
    __hash__ = param_hash

    trainer_param: ReinforceTrainerParameters = field(
        default_factory=ReinforceTrainerParameters)
    # using DQN net here because it supports `possible_actions_mask`
    policy_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-ignore
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-ignore
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])))
    sampler_temperature: float = 1.0

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self.action_names = self.trainer_param.actions
        self._policy: Optional[Policy] = None
        assert (
            len(self.action_names) > 1
        ), f"REINFORCE needs at least 2 actions. Got {self.action_names}."

    # pyre-ignore
    def build_trainer(self) -> ReinforceTrainer:
        policy_net_builder = self.policy_net_builder.value
        # pyre-ignore
        self._policy_network = policy_net_builder.build_q_network(
            self.state_feature_config,
            self.state_normalization_data,
            len(self.action_names),
        )
        value_net = None
        if self.value_net_builder:
            value_net_builder = self.value_net_builder.value  # pyre-ignore
            value_net = value_net_builder.build_value_network(
                self.state_normalization_data)
        trainer = ReinforceTrainer(
            policy=self.create_policy(),
            value_net=value_net,
            **self.trainer_param.asdict(),  # pyre-ignore
        )
        return trainer

    def create_policy(self, serving: bool = False):
        if serving:
            return create_predictor_policy_from_model(
                self.build_serving_module())
        else:
            if self._policy is None:
                sampler = SoftmaxActionSampler(
                    temperature=self.sampler_temperature)
                # pyre-ignore
                self._policy = Policy(scorer=self._policy_network,
                                      sampler=sampler)
            return self._policy

    def build_serving_module(self) -> torch.nn.Module:
        assert self._policy_network is not None
        policy_serving_module = self.policy_net_builder.value.build_serving_module(
            q_network=self._policy_network,
            state_normalization_data=self.state_normalization_data,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )
        return policy_serving_module

    def run_feature_identification(
            self, input_table_spec: TableSpec) -> Dict[str, NormalizationData]:
        raise NotImplementedError

    @property
    def required_normalization_keys(self) -> List[str]:
        return [NormalizationKey.STATE]

    @property
    def should_generate_eval_dataset(self) -> bool:
        raise NotImplementedError

    def query_data(
        self,
        input_table_spec: TableSpec,
        sample_range: Optional[Tuple[float, float]],
        reward_options: RewardOptions,
    ) -> Dataset:
        raise NotImplementedError

    def train(
        self,
        train_dataset: Optional[Dataset],
        eval_dataset: Optional[Dataset],
        test_dataset: Optional[Dataset],
        data_module: Optional[ReAgentDataModule],
        num_epochs: int,
        reader_options: ReaderOptions,
        resource_options: Optional[ResourceOptions],
    ) -> RLTrainingOutput:
        raise NotImplementedError

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()
Example #3
class ActorCriticBase(ModelManager):
    state_preprocessing_options: Optional[PreprocessingOptions] = None
    action_preprocessing_options: Optional[PreprocessingOptions] = None
    action_feature_override: Optional[str] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `raw`.
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])))
    action_float_features: List[Tuple[int, str]] = field(default_factory=list)
    reader_options: Optional[ReaderOptions] = None
    eval_parameters: EvaluationParameters = field(
        default_factory=EvaluationParameters)
    save_critic_bool: bool = True

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        assert (
            self.state_preprocessing_options is None
            or self.state_preprocessing_options.allowedlist_features is None
        ), ("Please set state allowlist features in state_float_features field of "
            "config instead")
        assert (
            self.action_preprocessing_options is None
            or self.action_preprocessing_options.allowedlist_features is None
        ), ("Please set action allowlist features in action_float_features field of "
            "config instead")

    def create_policy(
        self,
        trainer_module: ReAgentLightningModule,
        serving: bool = False,
        normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
    ) -> Policy:
        """Create online actor critic policy."""

        if serving:
            assert normalization_data_map
            return create_predictor_policy_from_model(
                self.build_serving_module(trainer_module,
                                          normalization_data_map))
        else:
            return ActorPolicyWrapper(trainer_module.actor_network)

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()

    @property
    def action_feature_config(self) -> rlt.ModelFeatureConfig:
        assert (
            len(self.action_float_features) > 0
        ), "You must set action_float_features"
        return get_feature_config(self.action_float_features)

    def get_state_preprocessing_options(self) -> PreprocessingOptions:
        state_preprocessing_options = (self.state_preprocessing_options
                                       or PreprocessingOptions())
        state_features = [
            ffi.feature_id
            for ffi in self.state_feature_config.float_feature_infos
        ]
        logger.info(f"state allowedlist_features: {state_features}")
        state_preprocessing_options = replace(
            state_preprocessing_options, allowedlist_features=state_features)
        return state_preprocessing_options

    def get_action_preprocessing_options(self) -> PreprocessingOptions:
        action_preprocessing_options = (self.action_preprocessing_options
                                        or PreprocessingOptions())
        action_features = [
            ffi.feature_id
            for ffi in self.action_feature_config.float_feature_infos
        ]
        logger.info(f"action allowedlist_features: {action_features}")

        # pyre-fixme
        actor_net_builder = self.actor_net_builder.value
        action_feature_override = actor_net_builder.default_action_preprocessing
        logger.info(
            f"Default action_feature_override is {action_feature_override}")
        if self.action_feature_override is not None:
            action_feature_override = self.action_feature_override

        assert action_preprocessing_options.feature_overrides is None
        action_preprocessing_options = replace(
            action_preprocessing_options,
            allowedlist_features=action_features,
            feature_overrides={
                fid: action_feature_override
                for fid in action_features
            },
        )
        return action_preprocessing_options

    def get_data_module(
        self,
        *,
        input_table_spec: Optional[TableSpec] = None,
        reward_options: Optional[RewardOptions] = None,
        reader_options: Optional[ReaderOptions] = None,
        setup_data: Optional[Dict[str, bytes]] = None,
        saved_setup_data: Optional[Dict[str, bytes]] = None,
        resource_options: Optional[ResourceOptions] = None,
    ) -> Optional[ReAgentDataModule]:
        return ActorCriticDataModule(
            input_table_spec=input_table_spec,
            reward_options=reward_options,
            setup_data=setup_data,
            saved_setup_data=saved_setup_data,
            reader_options=reader_options,
            resource_options=resource_options,
            model_manager=self,
        )

    def get_reporter(self):
        return ActorCriticReporter()
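
As a rough illustration of the action_float_features field above: it holds (feature_id, name) pairs, and the action_feature_config property simply forwards the list to get_feature_config. A minimal sketch, assuming get_feature_config and the rlt types module are importable as in the example above; the ids and names are made up for illustration.

# Sketch only: imports, ids, and names below are assumptions, not values from the source.
action_float_features = [(1001, "steering"), (1002, "throttle")]  # (feature_id, name)
action_config = get_feature_config(action_float_features)  # returns rlt.ModelFeatureConfig
# Callers such as get_action_preprocessing_options() then read the ids back out:
allowlist = [ffi.feature_id for ffi in action_config.float_feature_infos]
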
Example #4
class DiscreteDQNBase(ModelManager):
    target_action_distribution: Optional[List[float]] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `raw`.
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])))
    preprocessing_options: Optional[PreprocessingOptions] = None
    reader_options: Optional[ReaderOptions] = None
    eval_parameters: EvaluationParameters = field(
        default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        super().__init__()
        self._metrics_to_score = None
        self._q_network: Optional[ModelBase] = None

    def create_policy(self, serving: bool) -> Policy:
        """ Create an online DiscreteDQN Policy from env. """
        if serving:
            return create_predictor_policy_from_model(
                self.build_serving_module(), rl_parameters=self.rl_parameters)
        else:
            sampler = SoftmaxActionSampler(
                temperature=self.rl_parameters.temperature)
            # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
            scorer = discrete_dqn_scorer(self.trainer.q_network)
            return Policy(scorer=scorer, sampler=sampler)

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()

    @property
    def metrics_to_score(self) -> List[str]:
        assert self._reward_options is not None
        if self._metrics_to_score is None:
            # pyre-fixme[16]: `DiscreteDQNBase` has no attribute `_metrics_to_score`.
            self._metrics_to_score = get_metrics_to_score(
                # pyre-fixme[16]: `Optional` has no attribute `metric_reward_values`.
                self._reward_options.metric_reward_values)
        return self._metrics_to_score

    @property
    def should_generate_eval_dataset(self) -> bool:
        return self.eval_parameters.calc_cpe_in_training

    @property
    def required_normalization_keys(self) -> List[str]:
        return [NormalizationKey.STATE]

    def run_feature_identification(
            self, input_table_spec: TableSpec) -> Dict[str, NormalizationData]:
        preprocessing_options = self.preprocessing_options or PreprocessingOptions()
        logger.info("Overriding whitelist_features")
        state_features = [
            ffi.feature_id
            for ffi in self.state_feature_config.float_feature_infos
        ]
        preprocessing_options = preprocessing_options._replace(
            whitelist_features=state_features)
        return {
            NormalizationKey.STATE:
            NormalizationData(dense_normalization_parameters=
                              identify_normalization_parameters(
                                  input_table_spec, InputColumn.STATE_FEATURES,
                                  preprocessing_options))
        }

    def query_data(
        self,
        input_table_spec: TableSpec,
        sample_range: Optional[Tuple[float, float]],
        reward_options: RewardOptions,
    ) -> Dataset:
        return query_data(
            input_table_spec=input_table_spec,
            discrete_action=True,
            actions=self.action_names,
            include_possible_actions=True,
            sample_range=sample_range,
            custom_reward_expression=reward_options.custom_reward_expression,
            multi_steps=self.multi_steps,
            gamma=self.rl_parameters.gamma,
        )

    @property
    def multi_steps(self) -> Optional[int]:
        return self.rl_parameters.multi_steps

    def build_batch_preprocessor(self) -> BatchPreprocessor:
        state_preprocessor = Preprocessor(
            self.state_normalization_data.dense_normalization_parameters,
            use_gpu=self.use_gpu,
        )
        return DiscreteDqnBatchPreprocessor(
            num_actions=len(self.action_names),
            state_preprocessor=state_preprocessor,
            use_gpu=self.use_gpu,
        )

    def get_reporter(self):
        return DiscreteDQNReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )

    def train(
        self,
        train_dataset: Dataset,
        eval_dataset: Optional[Dataset],
        num_epochs: int,
        reader_options: ReaderOptions,
    ) -> RLTrainingOutput:
        """
        Train the model

        Returns partially filled RLTrainingOutput.
        The field that should not be filled are:
        - output_path
        """
        batch_preprocessor = self.build_batch_preprocessor()
        reporter = self.get_reporter()
        # pyre-fixme[16]: `RLTrainer` has no attribute `add_observer`.
        self.trainer.add_observer(reporter)

        train_eval_lightning(
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            trainer_module=self.trainer,
            num_epochs=num_epochs,
            use_gpu=self.use_gpu,
            batch_preprocessor=batch_preprocessor,
            reader_options=self.reader_options,
            checkpoint_path=self._lightning_checkpoint_path,
        )
        # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
        training_report = RLTrainingReport.make_union_instance(
            reporter.generate_training_report())
        return RLTrainingOutput(training_report=training_report)
Example #5
class DiscreteDQNBase(ModelManager):
    target_action_distribution: Optional[List[float]] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `raw`.
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])))
    preprocessing_options: Optional[PreprocessingOptions] = None
    reader_options: Optional[ReaderOptions] = None
    eval_parameters: EvaluationParameters = field(
        default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self._metrics_to_score = None
        self._q_network: Optional[ModelBase] = None

    def create_policy(self, serving: bool) -> Policy:
        """Create an online DiscreteDQN Policy from env."""
        if serving:
            return create_predictor_policy_from_model(
                self.build_serving_module(), rl_parameters=self.rl_parameters)
        else:
            sampler = GreedyActionSampler()
            # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
            scorer = discrete_dqn_scorer(self.trainer.q_network)
            return Policy(scorer=scorer, sampler=sampler)

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()

    @property
    def metrics_to_score(self) -> List[str]:
        assert self._reward_options is not None
        if self._metrics_to_score is None:
            # pyre-fixme[16]: `DiscreteDQNBase` has no attribute `_metrics_to_score`.
            self._metrics_to_score = get_metrics_to_score(
                self._reward_options.metric_reward_values)
        return self._metrics_to_score

    @property
    def should_generate_eval_dataset(self) -> bool:
        raise RuntimeError

    @property
    def required_normalization_keys(self) -> List[str]:
        return [NormalizationKey.STATE]

    def run_feature_identification(
            self, input_table_spec: TableSpec) -> Dict[str, NormalizationData]:
        raise RuntimeError

    def query_data(
        self,
        input_table_spec: TableSpec,
        sample_range: Optional[Tuple[float, float]],
        reward_options: RewardOptions,
        data_fetcher: DataFetcher,
    ) -> Dataset:
        raise RuntimeError

    @property
    def multi_steps(self) -> Optional[int]:
        return self.rl_parameters.multi_steps

    def build_batch_preprocessor(self, use_gpu: bool) -> BatchPreprocessor:
        raise RuntimeError

    def get_data_module(
        self,
        *,
        input_table_spec: Optional[TableSpec] = None,
        reward_options: Optional[RewardOptions] = None,
        reader_options: Optional[ReaderOptions] = None,
        setup_data: Optional[Dict[str, bytes]] = None,
        saved_setup_data: Optional[Dict[str, bytes]] = None,
        resource_options: Optional[ResourceOptions] = None,
    ) -> Optional[ReAgentDataModule]:
        return DiscreteDqnDataModule(
            input_table_spec=input_table_spec,
            reward_options=reward_options,
            setup_data=setup_data,
            saved_setup_data=saved_setup_data,
            reader_options=reader_options,
            resource_options=resource_options,
            model_manager=self,
        )

    def get_reporter(self):
        return DiscreteDQNReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )

    def train(
        self,
        train_dataset: Optional[Dataset],
        eval_dataset: Optional[Dataset],
        test_dataset: Optional[Dataset],
        data_module: Optional[ReAgentDataModule],
        num_epochs: int,
        reader_options: ReaderOptions,
        resource_options: Optional[ResourceOptions] = None,
    ) -> RLTrainingOutput:
        """
        Train the model

        Returns partially filled RLTrainingOutput.
        The field that should not be filled are:
        - output_path
        """
        reporter = self.get_reporter()
        # pyre-fixme[16]: `RLTrainer` has no attribute `set_reporter`.
        self.trainer.set_reporter(reporter)
        assert data_module

        # pyre-fixme[16]: `DiscreteDQNBase` has no attribute `_lightning_trainer`.
        self._lightning_trainer = train_eval_lightning(
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            test_dataset=test_dataset,
            trainer_module=self.trainer,
            data_module=data_module,
            num_epochs=num_epochs,
            logger_name="DiscreteDqn",
            reader_options=reader_options,
            checkpoint_path=self._lightning_checkpoint_path,
            resource_options=resource_options,
        )
        rank = get_rank()
        if rank == 0:
            # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
            training_report = RLTrainingReport.make_union_instance(
                reporter.generate_training_report())
            logger_data = self._lightning_trainer.logger.line_plot_aggregated
            self._lightning_trainer.logger.clear_local_data()
            return RLTrainingOutput(training_report=training_report,
                                    logger_data=logger_data)
        # Output from processes with non-0 rank is not used
        return RLTrainingOutput()
Example #6
class PPO(ModelManager):
    __hash__ = param_hash

    trainer_param: PPOTrainerParameters = field(default_factory=PPOTrainerParameters)
    # using DQN net here because it supports `possible_actions_mask`
    policy_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-ignore
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-ignore
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])
        )
    )
    sampler_temperature: float = 1.0

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self._policy: Optional[Policy] = None
        assert (
            len(self.action_names) > 1
        ), f"PPO needs at least 2 actions. Got {self.action_names}."

    @property
    def action_names(self):
        return self.trainer_param.actions

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> PPOTrainer:
        policy_net_builder = self.policy_net_builder.value
        policy_network = policy_net_builder.build_q_network(
            self.state_feature_config,
            normalization_data_map[NormalizationKey.STATE],
            len(self.action_names),
        )
        value_net = None
        value_net_builder = self.value_net_builder
        if value_net_builder:
            value_net_builder = value_net_builder.value
            value_net = value_net_builder.build_value_network(
                normalization_data_map[NormalizationKey.STATE]
            )
        trainer = PPOTrainer(
            policy=self._create_policy(policy_network),
            value_net=value_net,
            **self.trainer_param.asdict(),  # pyre-ignore
        )
        return trainer

    def create_policy(
        self,
        trainer_module: ReAgentLightningModule,
        serving: bool = False,
        normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
    ):
        assert isinstance(trainer_module, PPOTrainer)
        if serving:
            assert normalization_data_map is not None
            return create_predictor_policy_from_model(
                self.build_serving_module(trainer_module, normalization_data_map)
            )
        else:
            return self._create_policy(trainer_module.scorer)

    def _create_policy(self, policy_network):
        if self._policy is None:
            sampler = SoftmaxActionSampler(temperature=self.sampler_temperature)
            self._policy = Policy(scorer=policy_network, sampler=sampler)
        return self._policy

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        assert isinstance(trainer_module, PPOTrainer)
        policy_serving_module = self.policy_net_builder.value.build_serving_module(
            q_network=trainer_module.scorer,
            state_normalization_data=normalization_data_map[NormalizationKey.STATE],
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )
        return policy_serving_module

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()
Example #7
class DiscreteDQNBase(ModelManager):
    target_action_distribution: Optional[List[float]] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `raw`.
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])
        )
    )
    preprocessing_options: Optional[PreprocessingOptions] = None
    reader_options: Optional[ReaderOptions] = None
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()

    @property
    @abc.abstractmethod
    def rl_parameters(self) -> RLParameters:
        pass

    @property
    @abc.abstractmethod
    def action_names(self) -> List[str]:
        # Returns the list of possible actions for this instance of problem
        pass

    def create_policy(
        self,
        trainer_module: ReAgentLightningModule,
        serving: bool = False,
        normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
    ) -> Policy:
        """Create an online DiscreteDQN Policy from env."""
        if serving:
            assert normalization_data_map
            return create_predictor_policy_from_model(
                self.build_serving_module(trainer_module, normalization_data_map),
                rl_parameters=self.rl_parameters,
            )
        else:
            sampler = GreedyActionSampler()
            # pyre-fixme[6]: Expected `ModelBase` for 1st param but got
            #  `Union[torch.Tensor, torch.nn.Module]`.
            scorer = discrete_dqn_scorer(trainer_module.q_network)
            return Policy(scorer=scorer, sampler=sampler)

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()

    def get_state_preprocessing_options(self) -> PreprocessingOptions:
        state_preprocessing_options = (
            self.preprocessing_options or PreprocessingOptions()
        )
        state_features = [
            ffi.feature_id for ffi in self.state_feature_config.float_feature_infos
        ]
        logger.info(f"state allowedlist_features: {state_features}")
        state_preprocessing_options = replace(
            state_preprocessing_options, allowedlist_features=state_features
        )
        return state_preprocessing_options

    @property
    def multi_steps(self) -> Optional[int]:
        return self.rl_parameters.multi_steps

    def get_data_module(
        self,
        *,
        input_table_spec: Optional[TableSpec] = None,
        reward_options: Optional[RewardOptions] = None,
        reader_options: Optional[ReaderOptions] = None,
        setup_data: Optional[Dict[str, bytes]] = None,
        saved_setup_data: Optional[Dict[str, bytes]] = None,
        resource_options: Optional[ResourceOptions] = None,
    ) -> Optional[ReAgentDataModule]:
        return DiscreteDqnDataModule(
            input_table_spec=input_table_spec,
            reward_options=reward_options,
            setup_data=setup_data,
            saved_setup_data=saved_setup_data,
            reader_options=reader_options,
            resource_options=resource_options,
            model_manager=self,
        )

    def get_reporter(self):
        return DiscreteDQNReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )