class DiscreteDQN(DiscreteDQNBase):
    __hash__ = param_hash

    trainer_param: DQNTrainerParameters = field(default_factory=DQNTrainerParameters)
    net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `Dueling`.
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    cpe_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: DiscreteDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self.rl_parameters = self.trainer_param.rl
        self.action_names = self.trainer_param.actions
        assert (
            len(self.action_names) > 1
        ), f"DiscreteDQNModel needs at least 2 actions. Got {self.action_names}."
        if self.trainer_param.minibatch_size % 8 != 0:
            logger.warning(
                f"minibatch size ({self.trainer_param.minibatch_size}) "
                "should be divisible by 8 for performance reasons!"
            )

    # pyre-fixme[15]: `build_trainer` overrides method defined in `ModelManager`
    #  inconsistently.
    def build_trainer(self) -> DQNTrainer:
        net_builder = self.net_builder.value
        q_network = net_builder.build_q_network(
            self.state_feature_config,
            self.state_normalization_data,
            len(self.action_names),
        )

        q_network_target = q_network.get_target_network()

        reward_network, q_network_cpe, q_network_cpe_target = None, None, None
        if self.eval_parameters.calc_cpe_in_training:
            # Metrics + reward
            num_output_nodes = (len(self.metrics_to_score) + 1) * len(
                # pyre-fixme[16]: `DQNTrainerParameters` has no attribute `actions`.
                self.trainer_param.actions
            )

            cpe_net_builder = self.cpe_net_builder.value
            reward_network = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                num_output_nodes,
            )
            q_network_cpe = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                num_output_nodes,
            )

            q_network_cpe_target = q_network_cpe.get_target_network()

        # pyre-fixme[16]: `DiscreteDQN` has no attribute `_q_network`.
        self._q_network = q_network

        trainer = DQNTrainer(
            q_network=q_network,
            q_network_target=q_network_target,
            reward_network=reward_network,
            q_network_cpe=q_network_cpe,
            q_network_cpe_target=q_network_cpe_target,
            metrics_to_score=self.metrics_to_score,
            evaluation=self.eval_parameters,
            # pyre-fixme[16]: `DQNTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )
        return trainer

    def get_reporter(self):
        return DiscreteDQNReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )

    def build_serving_module(self) -> torch.nn.Module:
        """
        Returns a TorchScript predictor module
        """
        assert self._q_network is not None, "_q_network was not initialized"
        net_builder = self.net_builder.value
        return net_builder.build_serving_module(
            self._q_network,
            self.state_normalization_data,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )
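# --- Illustrative sketch (not part of ReAgent) ---
# The fields above use a tagged-union pattern: a `*__Union` wrapper holds exactly one
# concrete net builder (e.g. `Dueling`), and `.value` returns whichever member is set,
# so the chosen builder is selectable purely from config. The standalone sketch below
# imitates that pattern with plain dataclasses; all names in it (NetBuilderUnion,
# SketchDueling, SketchFullyConnected, SketchDiscreteDQNConfig, hidden_sizes) are
# hypothetical stand-ins, not the library's actual classes.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class SketchDueling:
    hidden_sizes: List[int] = field(default_factory=lambda: [256, 128])


@dataclass
class SketchFullyConnected:
    hidden_sizes: List[int] = field(default_factory=lambda: [128, 64])


@dataclass
class NetBuilderUnion:
    # Exactly one member should be set; `value` returns it.
    Dueling: Optional[SketchDueling] = None
    FullyConnected: Optional[SketchFullyConnected] = None

    @property
    def value(self):
        chosen = [v for v in (self.Dueling, self.FullyConnected) if v is not None]
        assert len(chosen) == 1, "exactly one union member must be set"
        return chosen[0]


@dataclass
class SketchDiscreteDQNConfig:
    net_builder: NetBuilderUnion = field(
        default_factory=lambda: NetBuilderUnion(Dueling=SketchDueling())
    )


if __name__ == "__main__":
    cfg = SketchDiscreteDQNConfig()
    print(type(cfg.net_builder.value).__name__)  # -> "SketchDueling"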
class Reinforce(ModelManager):
    __hash__ = param_hash

    trainer_param: ReinforceTrainerParameters = field(
        default_factory=ReinforceTrainerParameters
    )
    # using DQN net here because it supports `possible_actions_mask`
    policy_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-ignore
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-ignore
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])
        )
    )
    sampler_temperature: float = 1.0

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self.action_names = self.trainer_param.actions
        self._policy: Optional[Policy] = None
        assert (
            len(self.action_names) > 1
        ), f"REINFORCE needs at least 2 actions. Got {self.action_names}."

    # pyre-ignore
    def build_trainer(self) -> ReinforceTrainer:
        policy_net_builder = self.policy_net_builder.value
        # pyre-ignore
        self._policy_network = policy_net_builder.build_q_network(
            self.state_feature_config,
            self.state_normalization_data,
            len(self.action_names),
        )
        value_net = None
        if self.value_net_builder:
            value_net_builder = self.value_net_builder.value  # pyre-ignore
            value_net = value_net_builder.build_value_network(
                self.state_normalization_data
            )
        trainer = ReinforceTrainer(
            policy=self.create_policy(),
            value_net=value_net,
            **self.trainer_param.asdict(),  # pyre-ignore
        )
        return trainer

    def create_policy(self, serving: bool = False):
        if serving:
            return create_predictor_policy_from_model(self.build_serving_module())
        else:
            if self._policy is None:
                sampler = SoftmaxActionSampler(temperature=self.sampler_temperature)
                # pyre-ignore
                self._policy = Policy(scorer=self._policy_network, sampler=sampler)
            return self._policy

    def build_serving_module(self) -> torch.nn.Module:
        assert self._policy_network is not None
        policy_serving_module = self.policy_net_builder.value.build_serving_module(
            q_network=self._policy_network,
            state_normalization_data=self.state_normalization_data,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )
        return policy_serving_module

    def run_feature_identification(
        self, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        raise NotImplementedError

    @property
    def required_normalization_keys(self) -> List[str]:
        return [NormalizationKey.STATE]

    @property
    def should_generate_eval_dataset(self) -> bool:
        raise NotImplementedError

    def query_data(
        self,
        input_table_spec: TableSpec,
        sample_range: Optional[Tuple[float, float]],
        reward_options: RewardOptions,
    ) -> Dataset:
        raise NotImplementedError

    def train(
        self,
        train_dataset: Optional[Dataset],
        eval_dataset: Optional[Dataset],
        test_dataset: Optional[Dataset],
        data_module: Optional[ReAgentDataModule],
        num_epochs: int,
        reader_options: ReaderOptions,
        resource_options: Optional[ResourceOptions],
    ) -> RLTrainingOutput:
        raise NotImplementedError

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()
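# --- Illustrative sketch (not part of ReAgent) ---
# `create_policy` above pairs the policy network ("scorer") with a SoftmaxActionSampler
# whose temperature controls exploration. The snippet below is a minimal, self-contained
# version of that idea using plain PyTorch; the class and names (TinySoftmaxSampler,
# temperature) are illustrative, not ReAgent's actual sampler.
import torch


class TinySoftmaxSampler:
    def __init__(self, temperature: float = 1.0):
        assert temperature > 0.0
        self.temperature = temperature

    def sample(self, scores: torch.Tensor) -> torch.Tensor:
        # scores: (batch, num_actions) unnormalized logits from the scorer.
        # Lower temperature -> sharper (more greedy) action distribution.
        probs = torch.softmax(scores / self.temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)


if __name__ == "__main__":
    logits = torch.tensor([[2.0, 0.5, -1.0]])
    sampler = TinySoftmaxSampler(temperature=1.0)
    print(sampler.sample(logits))  # index of the sampled action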
class DiscreteDQN(DiscreteDQNBase):
    __hash__ = param_hash

    trainer_param: DQNTrainerParameters = field(default_factory=DQNTrainerParameters)
    net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `Dueling`.
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    cpe_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: DiscreteDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    # TODO: move evaluation parameters here from trainer_param.evaluation.
    # Note that only DiscreteDQN and QRDQN call RLTrainer._initialize_cpe,
    # so maybe it can be removed from the RLTrainer class.

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self.rl_parameters = self.trainer_param.rl
        self.eval_parameters = self.trainer_param.evaluation
        self.action_names = self.trainer_param.actions
        assert (
            len(self.action_names) > 1
        ), "DiscreteDQNModel needs at least 2 actions"
        assert (
            self.trainer_param.minibatch_size % 8 == 0
        ), "The minibatch size must be divisible by 8 for performance reasons."

    def build_trainer(self) -> DQNTrainer:
        net_builder = self.net_builder.value
        q_network = net_builder.build_q_network(
            self.state_feature_config,
            self.state_normalization_data,
            len(self.action_names),
        )

        if self.use_gpu:
            q_network = q_network.cuda()

        q_network_target = q_network.get_target_network()

        reward_network, q_network_cpe, q_network_cpe_target = None, None, None
        if self.trainer_param.evaluation.calc_cpe_in_training:
            # Metrics + reward
            num_output_nodes = (len(self.metrics_to_score) + 1) * len(
                self.trainer_param.actions
            )

            cpe_net_builder = self.cpe_net_builder.value
            reward_network = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                num_output_nodes,
            )
            q_network_cpe = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                num_output_nodes,
            )

            if self.use_gpu:
                reward_network.cuda()
                q_network_cpe.cuda()

            q_network_cpe_target = q_network_cpe.get_target_network()

        # pyre-fixme[16]: `DiscreteDQN` has no attribute `_q_network`.
        self._q_network = q_network
        trainer = DQNTrainer(
            q_network,
            q_network_target,
            reward_network,
            self.trainer_param,
            self.use_gpu,
            q_network_cpe=q_network_cpe,
            q_network_cpe_target=q_network_cpe_target,
            metrics_to_score=self.metrics_to_score,
            loss_reporter=NoOpLossReporter(),
        )
        return trainer

    def build_serving_module(self) -> torch.nn.Module:
        """
        Returns a TorchScript predictor module
        """
        assert self._q_network is not None, "_q_network was not initialized"
        net_builder = self.net_builder.value
        return net_builder.build_serving_module(
            self._q_network,
            self.state_normalization_data,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )
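# --- Illustrative sketch (not part of ReAgent) ---
# When counterfactual policy evaluation (CPE) is enabled, the CPE networks above get
# one output head per (metric + reward) per action, i.e.
# num_output_nodes = (len(metrics_to_score) + 1) * num_actions.
# The helper below just makes that arithmetic explicit; the function and example
# metric/action names are illustrative.
from typing import List


def cpe_num_output_nodes(metrics_to_score: List[str], action_names: List[str]) -> int:
    # "+ 1" accounts for the reward itself, predicted alongside each metric.
    return (len(metrics_to_score) + 1) * len(action_names)


# e.g. 2 metrics + reward, 2 actions -> 6 output nodes
assert cpe_num_output_nodes(["clicks", "dwell_time"], ["no_op", "recommend"]) == 6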
class DiscreteCRR(DiscreteDQNBase):
    __hash__ = param_hash

    trainer_param: CRRTrainerParameters = field(default_factory=CRRTrainerParameters)
    actor_net_builder: DiscreteActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: DiscreteActorNetBuilder__Union(
            FullyConnected=DiscreteFullyConnected()
        )
    )
    critic_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `Dueling`.
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    cpe_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: DiscreteDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self._actor_network: Optional[ModelBase] = None
        self._q1_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl
        self.action_names = self.trainer_param.actions
        assert (
            len(self.action_names) > 1
        ), f"DiscreteCRRModel needs at least 2 actions. Got {self.action_names}."

    # pyre-fixme[15]: `build_trainer` overrides method defined in `ModelManager`
    #  inconsistently.
    def build_trainer(self, use_gpu: bool) -> DiscreteCRRTrainer:
        actor_net_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `DiscreteCRR` has no attribute `_actor_network`.
        self._actor_network = actor_net_builder.build_actor(
            self.state_normalization_data, len(self.action_names)
        )

        # The arguments to q_network1 and q_network2 below are modeled after those
        # in discrete_dqn.py. The target networks will be created in
        # DiscreteCRRTrainer.
        critic_net_builder = self.critic_net_builder.value

        # pyre-fixme[16]: `DiscreteCRR` has no attribute `_q1_network`.
        self._q1_network = critic_net_builder.build_q_network(
            self.state_feature_config,
            self.state_normalization_data,
            len(self.action_names),
        )

        q2_network = (
            critic_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                len(self.action_names),
            )
            # pyre-fixme[16]: `CRRTrainerParameters` has no attribute
            #  `double_q_learning`.
            if self.trainer_param.double_q_learning
            else None
        )

        reward_network, q_network_cpe, q_network_cpe_target = None, None, None
        if self.eval_parameters.calc_cpe_in_training:
            # Metrics + reward
            num_output_nodes = (len(self.metrics_to_score) + 1) * len(
                # pyre-fixme[16]: `CRRTrainerParameters` has no attribute `actions`.
                self.trainer_param.actions
            )

            cpe_net_builder = self.cpe_net_builder.value
            reward_network = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                num_output_nodes,
            )
            q_network_cpe = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                num_output_nodes,
            )

            q_network_cpe_target = q_network_cpe.get_target_network()

        trainer = DiscreteCRRTrainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            reward_network=reward_network,
            q2_network=q2_network,
            q_network_cpe=q_network_cpe,
            q_network_cpe_target=q_network_cpe_target,
            metrics_to_score=self.metrics_to_score,
            evaluation=self.eval_parameters,
            # pyre-fixme[16]: `CRRTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )
        return trainer

    def create_policy(self, serving: bool) -> Policy:
        """Create an online actor-critic policy."""
        if serving:
            return create_predictor_policy_from_model(self.build_actor_module())
        else:
            return ActorPolicyWrapper(self._actor_network)

    def get_reporter(self):
        return DiscreteCRRReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )

    # Note: when using test_gym.py as the entry point, the normalization data is set
    # when the line `normalization = build_normalizer(env)` is executed. The code then
    # calls build_state_normalizer() and build_action_normalizer() in utils.py.

    def serving_module_names(self):
        module_names = ["default_model", "dqn", "actor_dqn"]
        if len(self.action_names) == 2:
            module_names.append("binary_difference_scorer")
        return module_names

    def build_serving_modules(self):
        """
        `actor_dqn` is the actor module wrapped in the DQN predictor wrapper.
        This helps putting the actor in places where the DQN predictor wrapper is
        expected. If the policy is greedy, then this wrapper works.
        """
        serving_modules = {
            "default_model": self.build_actor_module(),
            "dqn": self._build_dqn_module(self._q1_network),
            "actor_dqn": self._build_dqn_module(ActorDQN(self._actor_network)),
        }
        if len(self.action_names) == 2:
            serving_modules.update(
                {
                    "binary_difference_scorer": self._build_binary_difference_scorer(
                        ActorDQN(self._actor_network)
                    ),
                }
            )
        return serving_modules

    def _build_dqn_module(self, network):
        critic_net_builder = self.critic_net_builder.value
        assert network is not None
        return critic_net_builder.build_serving_module(
            network,
            self.state_normalization_data,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )

    def _build_binary_difference_scorer(self, network):
        critic_net_builder = self.critic_net_builder.value
        assert network is not None
        return critic_net_builder.build_binary_difference_scorer(
            network,
            self.state_normalization_data,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )

    # Also, even though build_serving_module below is directed to
    # discrete_actor_net_builder.py, which returns ActorPredictorWrapper just like
    # continuous_actor_net_builder.py, the outputs of the discrete actor are still
    # computed differently from those of the continuous actor, because during serving
    # the act() function of the Agent class in gym/agents/agents.py returns
    # self.action_extractor(actor_output), which is created in
    # create_for_env_with_serving_policy when env.get_serving_action_extractor() is
    # called. During serving, action_extractor calls serving_action_extractor() in
    # env_wrapper.py, which checks the type of action_space at serving time and treats
    # spaces.Discrete differently from spaces.Box (continuous).
    def build_actor_module(self) -> torch.nn.Module:
        net_builder = self.actor_net_builder.value
        assert self._actor_network is not None
        return net_builder.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            action_feature_ids=list(range(len(self.action_names))),
        )
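# --- Illustrative sketch (not part of ReAgent) ---
# `build_serving_modules` above wraps the actor in ActorDQN so it can be served through
# the DQN predictor wrapper: the actor's per-action scores are exposed as if they were
# Q-values, which is sufficient for a greedy policy. The standalone module below shows
# the shape of that adapter; `SketchActorAsDQN` and the toy `actor` are hypothetical
# stand-ins for ReAgent's ActorDQN and actor network.
import torch
import torch.nn as nn


class SketchActorAsDQN(nn.Module):
    def __init__(self, actor: nn.Module):
        super().__init__()
        self.actor = actor

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        # Treat the actor's per-action scores/logits as Q-values so that
        # argmax-based (greedy) serving code works unchanged.
        return self.actor(state)


if __name__ == "__main__":
    actor = nn.Linear(4, 2)  # toy actor: 4 state features -> 2 action scores
    wrapped = SketchActorAsDQN(actor)
    print(wrapped(torch.randn(1, 4)).shape)  # torch.Size([1, 2])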
class PPO(ModelManager):
    __hash__ = param_hash

    trainer_param: PPOTrainerParameters = field(default_factory=PPOTrainerParameters)
    # using DQN net here because it supports `possible_actions_mask`
    policy_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-ignore
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-ignore
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])
        )
    )
    sampler_temperature: float = 1.0

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self._policy: Optional[Policy] = None
        assert (
            len(self.action_names) > 1
        ), f"PPO needs at least 2 actions. Got {self.action_names}."

    @property
    def action_names(self):
        return self.trainer_param.actions

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> PPOTrainer:
        policy_net_builder = self.policy_net_builder.value
        policy_network = policy_net_builder.build_q_network(
            self.state_feature_config,
            normalization_data_map[NormalizationKey.STATE],
            len(self.action_names),
        )
        value_net = None
        value_net_builder = self.value_net_builder
        if value_net_builder:
            value_net_builder = value_net_builder.value
            value_net = value_net_builder.build_value_network(
                normalization_data_map[NormalizationKey.STATE]
            )
        trainer = PPOTrainer(
            policy=self._create_policy(policy_network),
            value_net=value_net,
            **self.trainer_param.asdict(),  # pyre-ignore
        )
        return trainer

    def create_policy(
        self,
        trainer_module: ReAgentLightningModule,
        serving: bool = False,
        normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
    ):
        assert isinstance(trainer_module, PPOTrainer)
        if serving:
            assert normalization_data_map is not None
            return create_predictor_policy_from_model(
                self.build_serving_module(trainer_module, normalization_data_map)
            )
        else:
            return self._create_policy(trainer_module.scorer)

    def _create_policy(self, policy_network):
        if self._policy is None:
            sampler = SoftmaxActionSampler(temperature=self.sampler_temperature)
            self._policy = Policy(scorer=policy_network, sampler=sampler)
        return self._policy

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        assert isinstance(trainer_module, PPOTrainer)
        policy_serving_module = self.policy_net_builder.value.build_serving_module(
            q_network=trainer_module.scorer,
            state_normalization_data=normalization_data_map[NormalizationKey.STATE],
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )
        return policy_serving_module

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()
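# --- Illustrative sketch (not part of ReAgent) ---
# `value_net_builder` above is optional: when a value network is provided, policy
# updates can use a baseline (advantage = return - V(s)); without it, the raw return is
# used. The snippet below shows that distinction in isolation with plain PyTorch; the
# function name and the simple advantage formula are illustrative assumptions, not
# ReAgent's exact PPO/GAE computation.
from typing import Optional

import torch
import torch.nn as nn


def sketch_advantages(
    returns: torch.Tensor,           # (batch,) discounted returns
    states: torch.Tensor,            # (batch, state_dim)
    value_net: Optional[nn.Module],  # None -> no baseline
) -> torch.Tensor:
    if value_net is None:
        return returns
    with torch.no_grad():
        baseline = value_net(states).squeeze(-1)
    return returns - baseline


if __name__ == "__main__":
    states = torch.randn(3, 4)
    returns = torch.tensor([1.0, 0.5, 2.0])
    print(sketch_advantages(returns, states, value_net=None))
    print(sketch_advantages(returns, states, value_net=nn.Linear(4, 1)))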
class DiscreteDQN(DiscreteDQNBase):
    __hash__ = param_hash

    trainer_param: DQNTrainerParameters = field(default_factory=DQNTrainerParameters)
    net_builder: DiscreteDQNNetBuilder__Union = field(
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    cpe_net_builder: DiscreteDQNNetBuilder__Union = field(
        default_factory=lambda: DiscreteDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self.rl_parameters = self.trainer_param.rl
        self.eval_parameters = self.trainer_param.evaluation
        self.action_names = self.trainer_param.actions
        assert (
            len(self.action_names) > 1
        ), "DiscreteDQNModel needs at least 2 actions"
        assert (
            self.trainer_param.minibatch_size % 8 == 0
        ), "The minibatch size must be divisible by 8 for performance reasons."

    def build_trainer(self) -> DQNTrainer:
        net_builder = self.net_builder.value
        q_network = net_builder.build_q_network(
            self.state_feature_config,
            self.state_normalization_parameters,
            len(self.action_names),
        )

        if self.use_gpu:
            q_network = q_network.cuda()

        q_network_target = q_network.get_target_network()

        reward_network, q_network_cpe, q_network_cpe_target = None, None, None
        if self.trainer_param.evaluation.calc_cpe_in_training:
            # Metrics + reward
            num_output_nodes = (len(self.metrics_to_score) + 1) * len(
                self.trainer_param.actions
            )

            cpe_net_builder = self.cpe_net_builder.value
            reward_network = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_parameters,
                num_output_nodes,
            )
            q_network_cpe = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_parameters,
                num_output_nodes,
            )

            if self.use_gpu:
                reward_network.cuda()
                q_network_cpe.cuda()

            q_network_cpe_target = q_network_cpe.get_target_network()

        self._q_network = q_network
        trainer = DQNTrainer(
            q_network,
            q_network_target,
            reward_network,
            self.trainer_param,
            self.use_gpu,
            q_network_cpe=q_network_cpe,
            q_network_cpe_target=q_network_cpe_target,
            metrics_to_score=self.metrics_to_score,
            loss_reporter=NoOpLossReporter(),
        )
        return trainer

    def build_serving_module(self) -> torch.nn.Module:
        """
        Returns a TorchScript predictor module
        """
        assert self._q_network is not None, "_q_network was not initialized"
        net_builder = self.net_builder.value
        return net_builder.build_serving_module(
            self._q_network,
            self.state_normalization_parameters,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )
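# --- Illustrative sketch (not part of ReAgent) ---
# `q_network.get_target_network()` above produces the frozen copy used for DQN's
# bootstrap targets. The helpers below sketch the usual pattern (a deep copy whose
# parameters are then periodically hard-copied or softly blended toward the online
# network); they are a generic illustration, not ReAgent's actual implementation.
import copy

import torch
import torch.nn as nn


def make_target_network(online: nn.Module) -> nn.Module:
    target = copy.deepcopy(online)
    for p in target.parameters():
        p.requires_grad_(False)  # the target net is never trained directly
    return target


@torch.no_grad()
def soft_update(target: nn.Module, online: nn.Module, tau: float = 0.005) -> None:
    # target <- tau * online + (1 - tau) * target
    for tp, op in zip(target.parameters(), online.parameters()):
        tp.mul_(1.0 - tau).add_(tau * op)


if __name__ == "__main__":
    q = nn.Linear(4, 2)
    q_target = make_target_network(q)
    soft_update(q_target, q, tau=0.01)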