def state_feature_config_provider(self) -> ModelFeatureConfigProvider__Union:
    """For online gym"""
    raw = RawModelFeatureConfigProvider(
        float_feature_infos=[
            rlt.FloatFeatureInfo(name="arm0_sample", feature_id=0),
            rlt.FloatFeatureInfo(name="arm1_sample", feature_id=1),
            rlt.FloatFeatureInfo(name="arm2_sample", feature_id=2),
            rlt.FloatFeatureInfo(name="arm3_sample", feature_id=3),
            rlt.FloatFeatureInfo(name="arm4_sample", feature_id=4),
        ],
        id_list_feature_configs=[
            rlt.IdListFeatureConfig(
                name="legal", feature_id=100, id_mapping_name="legal_actions"
            )
        ],
        id_score_list_feature_configs=[
            rlt.IdScoreListFeatureConfig(
                name="mu_changes", feature_id=1000, id_mapping_name="arms_list"
            )
        ],
        id_mapping_config={
            "legal_actions": rlt.IdMapping(ids=[0, 1, 2, 3, 4, 5]),
            "arms_list": rlt.IdMapping(ids=[0, 1, 2, 3, 4]),
        },
    )
    # pyre-fixme[16]: `ModelFeatureConfigProvider__Union` has no attribute
    #  `make_union_instance`.
    return ModelFeatureConfigProvider__Union.make_union_instance(raw)
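# Illustrative sketch (not ReAgent's actual implementation): the `__Union`
# types used above behave like tagged unions, where exactly one member is set
# and `.value` returns it. `MiniUnion`, `MiniRawProvider`, and this
# `make_union_instance` are hypothetical stand-ins that mimic how
# `ModelFeatureConfigProvider__Union.make_union_instance(raw)` wraps a
# concrete provider so callers can later do
# `union.value.get_model_feature_config()`.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class MiniRawProvider:
    float_feature_ids: List[int]

    def get_model_feature_config(self):
        # A real provider would return an rlt.ModelFeatureConfig.
        return {"float_features": self.float_feature_ids}


@dataclass
class MiniUnion:
    raw: Optional[MiniRawProvider] = None

    @classmethod
    def make_union_instance(cls, instance):
        # Store the instance in the union field matching its type.
        return cls(raw=instance)

    @property
    def value(self):
        # Return the single member that is set.
        assert self.raw is not None
        return self.raw


union = MiniUnion.make_union_instance(MiniRawProvider(float_feature_ids=[0, 1, 2]))
print(union.value.get_model_feature_config())  # {'float_features': [0, 1, 2]}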
class Reinforce(ModelManager):
    __hash__ = param_hash

    trainer_param: ReinforceTrainerParameters = field(
        default_factory=ReinforceTrainerParameters
    )
    # using DQN net here because it supports `possible_actions_mask`
    policy_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-ignore
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-ignore
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])
        )
    )
    sampler_temperature: float = 1.0

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self.action_names = self.trainer_param.actions
        self._policy: Optional[Policy] = None
        assert (
            len(self.action_names) > 1
        ), f"REINFORCE needs at least 2 actions. Got {self.action_names}."

    # pyre-ignore
    def build_trainer(self) -> ReinforceTrainer:
        policy_net_builder = self.policy_net_builder.value
        # pyre-ignore
        self._policy_network = policy_net_builder.build_q_network(
            self.state_feature_config,
            self.state_normalization_data,
            len(self.action_names),
        )
        value_net = None
        if self.value_net_builder:
            value_net_builder = self.value_net_builder.value
            # pyre-ignore
            value_net = value_net_builder.build_value_network(
                self.state_normalization_data
            )
        trainer = ReinforceTrainer(
            policy=self.create_policy(),
            value_net=value_net,
            **self.trainer_param.asdict(),  # pyre-ignore
        )
        return trainer

    def create_policy(self, serving: bool = False):
        if serving:
            return create_predictor_policy_from_model(self.build_serving_module())
        else:
            if self._policy is None:
                sampler = SoftmaxActionSampler(temperature=self.sampler_temperature)
                # pyre-ignore
                self._policy = Policy(scorer=self._policy_network, sampler=sampler)
            return self._policy

    def build_serving_module(self) -> torch.nn.Module:
        assert self._policy_network is not None
        policy_serving_module = self.policy_net_builder.value.build_serving_module(
            q_network=self._policy_network,
            state_normalization_data=self.state_normalization_data,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )
        return policy_serving_module

    def run_feature_identification(
        self, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        raise NotImplementedError

    @property
    def required_normalization_keys(self) -> List[str]:
        return [NormalizationKey.STATE]

    @property
    def should_generate_eval_dataset(self) -> bool:
        raise NotImplementedError

    def query_data(
        self,
        input_table_spec: TableSpec,
        sample_range: Optional[Tuple[float, float]],
        reward_options: RewardOptions,
    ) -> Dataset:
        raise NotImplementedError

    def train(
        self,
        train_dataset: Optional[Dataset],
        eval_dataset: Optional[Dataset],
        test_dataset: Optional[Dataset],
        data_module: Optional[ReAgentDataModule],
        num_epochs: int,
        reader_options: ReaderOptions,
        resource_options: Optional[ResourceOptions],
    ) -> RLTrainingOutput:
        raise NotImplementedError

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()
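# A minimal sketch of what a temperature-controlled softmax sampler does
# conceptually (ReAgent's real `SoftmaxActionSampler` has extra machinery
# around log-probs and batching; this is not its API). Lower values of
# `sampler_temperature` sharpen the distribution toward the greedy action,
# higher values flatten it toward uniform exploration.
import torch


def softmax_sample(scores: torch.Tensor, temperature: float) -> int:
    # Dividing scores by the temperature before softmax rescales how
    # peaked the resulting action distribution is.
    probs = torch.softmax(scores / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())


scores = torch.tensor([1.0, 2.0, 3.0])
print(torch.softmax(scores / 0.1, dim=-1))   # near one-hot: ~greedy
print(torch.softmax(scores / 10.0, dim=-1))  # near uniform: exploratory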
class ActorCriticBase(ModelManager):
    state_preprocessing_options: Optional[PreprocessingOptions] = None
    action_preprocessing_options: Optional[PreprocessingOptions] = None
    action_feature_override: Optional[str] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `raw`.
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])
        )
    )
    action_float_features: List[Tuple[int, str]] = field(default_factory=list)
    reader_options: Optional[ReaderOptions] = None
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)
    save_critic_bool: bool = True

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        assert (
            self.state_preprocessing_options is None
            or self.state_preprocessing_options.allowedlist_features is None
        ), (
            "Please set state allowlist features in state_float_features field of "
            "config instead"
        )
        assert (
            self.action_preprocessing_options is None
            or self.action_preprocessing_options.allowedlist_features is None
        ), (
            "Please set action allowlist features in action_float_features field of "
            "config instead"
        )

    def create_policy(
        self,
        trainer_module: ReAgentLightningModule,
        serving: bool = False,
        normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
    ) -> Policy:
        """Create online actor critic policy."""
        if serving:
            assert normalization_data_map
            return create_predictor_policy_from_model(
                self.build_serving_module(trainer_module, normalization_data_map)
            )
        else:
            return ActorPolicyWrapper(trainer_module.actor_network)

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()

    @property
    def action_feature_config(self) -> rlt.ModelFeatureConfig:
        assert (
            len(self.action_float_features) > 0
        ), "You must set action_float_features"
        return get_feature_config(self.action_float_features)

    def get_state_preprocessing_options(self) -> PreprocessingOptions:
        state_preprocessing_options = (
            self.state_preprocessing_options or PreprocessingOptions()
        )
        state_features = [
            ffi.feature_id for ffi in self.state_feature_config.float_feature_infos
        ]
        logger.info(f"state allowedlist_features: {state_features}")
        state_preprocessing_options = replace(
            state_preprocessing_options, allowedlist_features=state_features
        )
        return state_preprocessing_options

    def get_action_preprocessing_options(self) -> PreprocessingOptions:
        action_preprocessing_options = (
            self.action_preprocessing_options or PreprocessingOptions()
        )
        action_features = [
            ffi.feature_id for ffi in self.action_feature_config.float_feature_infos
        ]
        logger.info(f"action allowedlist_features: {action_features}")
        # pyre-fixme
        actor_net_builder = self.actor_net_builder.value
        action_feature_override = actor_net_builder.default_action_preprocessing
        logger.info(f"Default action_feature_override is {action_feature_override}")
        if self.action_feature_override is not None:
            action_feature_override = self.action_feature_override
        assert action_preprocessing_options.feature_overrides is None
        action_preprocessing_options = replace(
            action_preprocessing_options,
            allowedlist_features=action_features,
            feature_overrides={
                fid: action_feature_override for fid in action_features
            },
        )
        return action_preprocessing_options

    def get_data_module(
        self,
        *,
        input_table_spec: Optional[TableSpec] = None,
        reward_options: Optional[RewardOptions] = None,
        reader_options: Optional[ReaderOptions] = None,
        setup_data: Optional[Dict[str, bytes]] = None,
        saved_setup_data: Optional[Dict[str, bytes]] = None,
        resource_options: Optional[ResourceOptions] = None,
    ) -> Optional[ReAgentDataModule]:
        return ActorCriticDataModule(
            input_table_spec=input_table_spec,
            reward_options=reward_options,
            setup_data=setup_data,
            saved_setup_data=saved_setup_data,
            reader_options=reader_options,
            resource_options=resource_options,
            model_manager=self,
        )

    def get_reporter(self):
        return ActorCriticReporter()
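# Illustrative sketch of the `replace(...)` idiom used in the two
# preprocessing-options methods above, with a hypothetical
# `MiniPreprocessingOptions` standing in for ReAgent's `PreprocessingOptions`.
# `dataclasses.replace` returns a modified copy, so the manager never mutates
# user-supplied options in place; the feature_overrides comprehension maps
# every action feature id to the same override, as
# `get_action_preprocessing_options` does.
from dataclasses import dataclass, replace
from typing import Dict, List, Optional


@dataclass(frozen=True)
class MiniPreprocessingOptions:
    allowedlist_features: Optional[List[int]] = None
    feature_overrides: Optional[Dict[int, str]] = None


opts = MiniPreprocessingOptions()
action_features = [1001, 1002]  # hypothetical action feature ids
opts = replace(
    opts,
    allowedlist_features=action_features,
    feature_overrides={fid: "CONTINUOUS_ACTION" for fid in action_features},
)
print(opts)  # original opts object is unchanged; this is a new copy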
class DiscreteDQNBase(ModelManager):
    target_action_distribution: Optional[List[float]] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `raw`.
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])
        )
    )
    preprocessing_options: Optional[PreprocessingOptions] = None
    reader_options: Optional[ReaderOptions] = None
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        super().__init__()
        self._metrics_to_score = None
        self._q_network: Optional[ModelBase] = None

    def create_policy(self, serving: bool) -> Policy:
        """Create an online DiscreteDQN Policy from env."""
        if serving:
            return create_predictor_policy_from_model(
                self.build_serving_module(), rl_parameters=self.rl_parameters
            )
        else:
            sampler = SoftmaxActionSampler(temperature=self.rl_parameters.temperature)
            # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
            scorer = discrete_dqn_scorer(self.trainer.q_network)
            return Policy(scorer=scorer, sampler=sampler)

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()

    @property
    def metrics_to_score(self) -> List[str]:
        assert self._reward_options is not None
        if self._metrics_to_score is None:
            # pyre-fixme[16]: `DiscreteDQNBase` has no attribute `_metrics_to_score`.
            self._metrics_to_score = get_metrics_to_score(
                # pyre-fixme[16]: `Optional` has no attribute `metric_reward_values`.
                self._reward_options.metric_reward_values
            )
        return self._metrics_to_score

    @property
    def should_generate_eval_dataset(self) -> bool:
        return self.eval_parameters.calc_cpe_in_training

    @property
    def required_normalization_keys(self) -> List[str]:
        return [NormalizationKey.STATE]

    def run_feature_identification(
        self, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        preprocessing_options = self.preprocessing_options or PreprocessingOptions()
        logger.info("Overriding whitelist_features")
        state_features = [
            ffi.feature_id for ffi in self.state_feature_config.float_feature_infos
        ]
        preprocessing_options = preprocessing_options._replace(
            whitelist_features=state_features
        )
        return {
            NormalizationKey.STATE: NormalizationData(
                dense_normalization_parameters=identify_normalization_parameters(
                    input_table_spec, InputColumn.STATE_FEATURES, preprocessing_options
                )
            )
        }

    def query_data(
        self,
        input_table_spec: TableSpec,
        sample_range: Optional[Tuple[float, float]],
        reward_options: RewardOptions,
    ) -> Dataset:
        return query_data(
            input_table_spec=input_table_spec,
            discrete_action=True,
            actions=self.action_names,
            include_possible_actions=True,
            sample_range=sample_range,
            custom_reward_expression=reward_options.custom_reward_expression,
            multi_steps=self.multi_steps,
            gamma=self.rl_parameters.gamma,
        )

    @property
    def multi_steps(self) -> Optional[int]:
        return self.rl_parameters.multi_steps

    def build_batch_preprocessor(self) -> BatchPreprocessor:
        state_preprocessor = Preprocessor(
            self.state_normalization_data.dense_normalization_parameters,
            use_gpu=self.use_gpu,
        )
        return DiscreteDqnBatchPreprocessor(
            num_actions=len(self.action_names),
            state_preprocessor=state_preprocessor,
            use_gpu=self.use_gpu,
        )

    def get_reporter(self):
        return DiscreteDQNReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )

    def train(
        self,
        train_dataset: Dataset,
        eval_dataset: Optional[Dataset],
        num_epochs: int,
        reader_options: ReaderOptions,
    ) -> RLTrainingOutput:
        """
        Train the model.

        Returns a partially filled RLTrainingOutput.
        The fields that should not be filled are:
        - output_path
        """
        batch_preprocessor = self.build_batch_preprocessor()
        reporter = self.get_reporter()
        # pyre-fixme[16]: `RLTrainer` has no attribute `add_observer`.
        self.trainer.add_observer(reporter)
        train_eval_lightning(
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            trainer_module=self.trainer,
            num_epochs=num_epochs,
            use_gpu=self.use_gpu,
            batch_preprocessor=batch_preprocessor,
            reader_options=self.reader_options,
            checkpoint_path=self._lightning_checkpoint_path,
        )
        # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
        training_report = RLTrainingReport.make_union_instance(
            reporter.generate_training_report()
        )
        return RLTrainingOutput(training_report=training_report)
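# Sketch of the observer/reporter flow in `train` above, using hypothetical
# `MiniTrainer`/`MiniReporter` classes (ReAgent's real reporters aggregate far
# more than a loss). The trainer pushes metrics to any attached observer
# during training; afterwards the reporter turns what it collected into a
# training report.
class MiniReporter:
    def __init__(self):
        self.losses = []

    def update(self, loss: float) -> None:
        self.losses.append(loss)

    def generate_training_report(self) -> dict:
        return {"mean_loss": sum(self.losses) / len(self.losses)}


class MiniTrainer:
    def __init__(self):
        self.observers = []

    def add_observer(self, observer) -> None:
        self.observers.append(observer)

    def train_step(self, loss: float) -> None:
        # Notify every attached observer of this step's metrics.
        for observer in self.observers:
            observer.update(loss)


trainer = MiniTrainer()
reporter = MiniReporter()
trainer.add_observer(reporter)
for loss in (0.9, 0.5, 0.1):
    trainer.train_step(loss)
print(reporter.generate_training_report())  # {'mean_loss': 0.5}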
class DiscreteDQNBase(ModelManager):
    target_action_distribution: Optional[List[float]] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `raw`.
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])
        )
    )
    preprocessing_options: Optional[PreprocessingOptions] = None
    reader_options: Optional[ReaderOptions] = None
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self._metrics_to_score = None
        self._q_network: Optional[ModelBase] = None

    def create_policy(self, serving: bool) -> Policy:
        """Create an online DiscreteDQN Policy from env."""
        if serving:
            return create_predictor_policy_from_model(
                self.build_serving_module(), rl_parameters=self.rl_parameters
            )
        else:
            sampler = GreedyActionSampler()
            # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
            scorer = discrete_dqn_scorer(self.trainer.q_network)
            return Policy(scorer=scorer, sampler=sampler)

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()

    @property
    def metrics_to_score(self) -> List[str]:
        assert self._reward_options is not None
        if self._metrics_to_score is None:
            # pyre-fixme[16]: `DiscreteDQNBase` has no attribute `_metrics_to_score`.
            self._metrics_to_score = get_metrics_to_score(
                self._reward_options.metric_reward_values
            )
        return self._metrics_to_score

    @property
    def should_generate_eval_dataset(self) -> bool:
        raise RuntimeError

    @property
    def required_normalization_keys(self) -> List[str]:
        return [NormalizationKey.STATE]

    def run_feature_identification(
        self, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        raise RuntimeError

    def query_data(
        self,
        input_table_spec: TableSpec,
        sample_range: Optional[Tuple[float, float]],
        reward_options: RewardOptions,
        data_fetcher: DataFetcher,
    ) -> Dataset:
        raise RuntimeError

    @property
    def multi_steps(self) -> Optional[int]:
        return self.rl_parameters.multi_steps

    def build_batch_preprocessor(self, use_gpu: bool) -> BatchPreprocessor:
        raise RuntimeError

    def get_data_module(
        self,
        *,
        input_table_spec: Optional[TableSpec] = None,
        reward_options: Optional[RewardOptions] = None,
        reader_options: Optional[ReaderOptions] = None,
        setup_data: Optional[Dict[str, bytes]] = None,
        saved_setup_data: Optional[Dict[str, bytes]] = None,
        resource_options: Optional[ResourceOptions] = None,
    ) -> Optional[ReAgentDataModule]:
        return DiscreteDqnDataModule(
            input_table_spec=input_table_spec,
            reward_options=reward_options,
            setup_data=setup_data,
            saved_setup_data=saved_setup_data,
            reader_options=reader_options,
            resource_options=resource_options,
            model_manager=self,
        )

    def get_reporter(self):
        return DiscreteDQNReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )

    def train(
        self,
        train_dataset: Optional[Dataset],
        eval_dataset: Optional[Dataset],
        test_dataset: Optional[Dataset],
        data_module: Optional[ReAgentDataModule],
        num_epochs: int,
        reader_options: ReaderOptions,
        resource_options: Optional[ResourceOptions] = None,
    ) -> RLTrainingOutput:
        """
        Train the model.

        Returns a partially filled RLTrainingOutput.
        The fields that should not be filled are:
        - output_path
        """
        reporter = self.get_reporter()
        # pyre-fixme[16]: `RLTrainer` has no attribute `set_reporter`.
        self.trainer.set_reporter(reporter)
        assert data_module

        # pyre-fixme[16]: `DiscreteDQNBase` has no attribute `_lightning_trainer`.
        self._lightning_trainer = train_eval_lightning(
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            test_dataset=test_dataset,
            trainer_module=self.trainer,
            data_module=data_module,
            num_epochs=num_epochs,
            logger_name="DiscreteDqn",
            reader_options=reader_options,
            checkpoint_path=self._lightning_checkpoint_path,
            resource_options=resource_options,
        )
        rank = get_rank()
        if rank == 0:
            # pyre-fixme[16]: `RLTrainingReport` has no attribute
            #  `make_union_instance`.
            training_report = RLTrainingReport.make_union_instance(
                reporter.generate_training_report()
            )
            logger_data = self._lightning_trainer.logger.line_plot_aggregated
            self._lightning_trainer.logger.clear_local_data()
            return RLTrainingOutput(
                training_report=training_report, logger_data=logger_data
            )
        # Output from processes with non-0 rank is not used
        return RLTrainingOutput()
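# Sketch of the rank-gating pattern at the end of `train` above: in
# distributed data-parallel training every process runs `train`, but only
# rank 0 should assemble and return the report, so duplicates are not
# emitted. `fake_get_rank` and `finalize` are hypothetical stand-ins for
# `get_rank()` and the report-assembly step (with torch.distributed, the
# rank would come from `torch.distributed.get_rank()` once a process group
# is initialized).
def fake_get_rank() -> int:
    return 0  # pretend this process is the coordinator


def finalize(report: dict) -> dict:
    if fake_get_rank() == 0:
        return {"training_report": report}
    # Output from non-zero ranks is discarded by the caller.
    return {}


print(finalize({"mean_loss": 0.5}))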
class PPO(ModelManager):
    __hash__ = param_hash

    trainer_param: PPOTrainerParameters = field(default_factory=PPOTrainerParameters)
    # using DQN net here because it supports `possible_actions_mask`
    policy_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-ignore
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-ignore
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])
        )
    )
    sampler_temperature: float = 1.0

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self._policy: Optional[Policy] = None
        assert (
            len(self.action_names) > 1
        ), f"PPO needs at least 2 actions. Got {self.action_names}."

    @property
    def action_names(self):
        return self.trainer_param.actions

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> PPOTrainer:
        policy_net_builder = self.policy_net_builder.value
        policy_network = policy_net_builder.build_q_network(
            self.state_feature_config,
            normalization_data_map[NormalizationKey.STATE],
            len(self.action_names),
        )
        value_net = None
        value_net_builder = self.value_net_builder
        if value_net_builder:
            value_net_builder = value_net_builder.value
            value_net = value_net_builder.build_value_network(
                normalization_data_map[NormalizationKey.STATE]
            )
        trainer = PPOTrainer(
            policy=self._create_policy(policy_network),
            value_net=value_net,
            **self.trainer_param.asdict(),  # pyre-ignore
        )
        return trainer

    def create_policy(
        self,
        trainer_module: ReAgentLightningModule,
        serving: bool = False,
        normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
    ):
        assert isinstance(trainer_module, PPOTrainer)
        if serving:
            assert normalization_data_map is not None
            return create_predictor_policy_from_model(
                self.build_serving_module(trainer_module, normalization_data_map)
            )
        else:
            return self._create_policy(trainer_module.scorer)

    def _create_policy(self, policy_network):
        if self._policy is None:
            sampler = SoftmaxActionSampler(temperature=self.sampler_temperature)
            self._policy = Policy(scorer=policy_network, sampler=sampler)
        return self._policy

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        assert isinstance(trainer_module, PPOTrainer)
        policy_serving_module = self.policy_net_builder.value.build_serving_module(
            q_network=trainer_module.scorer,
            state_normalization_data=normalization_data_map[NormalizationKey.STATE],
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )
        return policy_serving_module

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()
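# Sketch of the policy-caching idiom in `_create_policy` above: the first call
# builds the Policy and every later call returns the same object, so the
# PPOTrainer and the acting loop share one policy whose scorer is the live
# network (which matters for an on-policy algorithm like PPO). `MiniPolicy`
# and `PolicyHolder` are hypothetical stand-ins for ReAgent's `Policy` and the
# manager's caching logic.
from typing import Optional


class MiniPolicy:
    def __init__(self, scorer, sampler):
        self.scorer = scorer
        self.sampler = sampler


class PolicyHolder:
    def __init__(self):
        self._policy: Optional[MiniPolicy] = None

    def _create_policy(self, network) -> MiniPolicy:
        # Build once, then always return the cached instance.
        if self._policy is None:
            self._policy = MiniPolicy(scorer=network, sampler="softmax")
        return self._policy


holder = PolicyHolder()
p1 = holder._create_policy(network="net")
p2 = holder._create_policy(network="other-net")  # ignored: already cached
assert p1 is p2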
class DiscreteDQNBase(ModelManager):
    target_action_distribution: Optional[List[float]] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `raw`.
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])
        )
    )
    preprocessing_options: Optional[PreprocessingOptions] = None
    reader_options: Optional[ReaderOptions] = None
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()

    @property
    @abc.abstractmethod
    def rl_parameters(self) -> RLParameters:
        pass

    @property
    @abc.abstractmethod
    def action_names(self) -> List[str]:
        # Returns the list of possible actions for this instance of problem
        pass

    def create_policy(
        self,
        trainer_module: ReAgentLightningModule,
        serving: bool = False,
        normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
    ) -> Policy:
        """Create an online DiscreteDQN Policy from env."""
        if serving:
            assert normalization_data_map
            return create_predictor_policy_from_model(
                self.build_serving_module(trainer_module, normalization_data_map),
                rl_parameters=self.rl_parameters,
            )
        else:
            sampler = GreedyActionSampler()
            # pyre-fixme[6]: Expected `ModelBase` for 1st param but got
            #  `Union[torch.Tensor, torch.nn.Module]`.
            scorer = discrete_dqn_scorer(trainer_module.q_network)
            return Policy(scorer=scorer, sampler=sampler)

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()

    def get_state_preprocessing_options(self) -> PreprocessingOptions:
        state_preprocessing_options = (
            self.preprocessing_options or PreprocessingOptions()
        )
        state_features = [
            ffi.feature_id for ffi in self.state_feature_config.float_feature_infos
        ]
        logger.info(f"state allowedlist_features: {state_features}")
        state_preprocessing_options = replace(
            state_preprocessing_options, allowedlist_features=state_features
        )
        return state_preprocessing_options

    @property
    def multi_steps(self) -> Optional[int]:
        return self.rl_parameters.multi_steps

    def get_data_module(
        self,
        *,
        input_table_spec: Optional[TableSpec] = None,
        reward_options: Optional[RewardOptions] = None,
        reader_options: Optional[ReaderOptions] = None,
        setup_data: Optional[Dict[str, bytes]] = None,
        saved_setup_data: Optional[Dict[str, bytes]] = None,
        resource_options: Optional[ResourceOptions] = None,
    ) -> Optional[ReAgentDataModule]:
        return DiscreteDqnDataModule(
            input_table_spec=input_table_spec,
            reward_options=reward_options,
            setup_data=setup_data,
            saved_setup_data=saved_setup_data,
            reader_options=reader_options,
            resource_options=resource_options,
            model_manager=self,
        )

    def get_reporter(self):
        return DiscreteDQNReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )
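# Sketch of the abstract-property pattern used by `rl_parameters` and
# `action_names` above, with hypothetical `Mini*` names: subclasses of the
# base manager must supply these properties (typically reading them from
# their trainer parameters) before the class can be instantiated.
import abc
from typing import List


class MiniBase(abc.ABC):
    @property
    @abc.abstractmethod
    def action_names(self) -> List[str]:
        ...


class MiniDQN(MiniBase):
    @property
    def action_names(self) -> List[str]:
        return ["left", "right"]


print(MiniDQN().action_names)  # ['left', 'right']
# MiniBase() would raise TypeError: can't instantiate an abstract class.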