def create_policy(self, serving: bool) -> Policy:
    """Build a DiscreteDQN Policy for this model.

    Args:
        serving: when True, return the policy produced by
            ``create_predictor_policy_from_model`` from the exported serving
            module (with ``self.rl_parameters``). When False, score actions
            with the trainer's Q-network and choose them with a
            ``GreedyActionSampler``.

    Returns:
        The constructed Policy.
    """
    if not serving:
        greedy_sampler = GreedyActionSampler()
        # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
        q_scorer = discrete_dqn_scorer(self.trainer.q_network)
        return Policy(scorer=q_scorer, sampler=greedy_sampler)

    serving_module = self.build_serving_module()
    return create_predictor_policy_from_model(
        serving_module, rl_parameters=self.rl_parameters
    )
def create_policy(self) -> Policy:
    """Build a DiscreteDQN Policy over the trainer's current Q-network.

    Scores come from ``discrete_dqn_scorer`` applied to
    ``self.trainer.q_network``. Actions are chosen by a
    ``SoftmaxActionSampler`` whose temperature is read from
    ``self.rl_parameters.temperature``.
    """
    # Imported here rather than at module level to avoid potentially
    # importing gym when it's not installed.
    from reagent.gym.policies.samplers.discrete_sampler import (
        SoftmaxActionSampler as _Softmax,
    )
    from reagent.gym.policies.scorers.discrete_scorer import (
        discrete_dqn_scorer as _dqn_scorer,
    )

    softmax_sampler = _Softmax(temperature=self.rl_parameters.temperature)
    q_scorer = _dqn_scorer(self.trainer.q_network)
    return Policy(scorer=q_scorer, sampler=softmax_sampler)
def create_policy(self, serving: bool) -> Policy:
    """Build a DiscreteDQN Policy.

    Args:
        serving: when True, actions are scored by ``discrete_dqn_serving_scorer``
            over the serving module (wrapped in ``DiscreteDqnPredictorUnwrapper``)
            and chosen with a ``GreedyActionSampler``. When False, actions are
            scored by ``discrete_dqn_scorer`` over the trainer's Q-network and
            chosen with a ``SoftmaxActionSampler`` at
            ``self.rl_parameters.temperature``.

    Returns:
        The constructed Policy.
    """
    if not serving:
        # pyre-fixme[16]: `DiscreteDQNBase` has no attribute `rl_parameters`.
        temperature = self.rl_parameters.temperature
        softmax_sampler = SoftmaxActionSampler(temperature=temperature)
        # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
        q_scorer = discrete_dqn_scorer(self.trainer.q_network)
        return Policy(scorer=q_scorer, sampler=softmax_sampler)

    greedy_sampler = GreedyActionSampler()
    unwrapped_predictor = DiscreteDqnPredictorUnwrapper(self.build_serving_module())
    serving_scorer = discrete_dqn_serving_scorer(unwrapped_predictor)
    return Policy(scorer=serving_scorer, sampler=greedy_sampler)
def create_policy(self, serving: bool) -> Policy:
    """Build a DiscreteDQN Policy.

    Both modes use a ``SoftmaxActionSampler`` at
    ``self.rl_parameters.temperature``; only the scorer differs.

    Args:
        serving: when True, score with ``discrete_dqn_serving_scorer`` over the
            serving module (wrapped in ``DiscreteDqnPredictorUnwrapper``). When
            False, score with ``discrete_dqn_scorer`` over the trainer's
            Q-network.

    Returns:
        The constructed Policy.
    """
    # Imported inside the function so reagent.gym is only loaded when a policy
    # is actually built (presumably to avoid requiring gym at import time).
    from reagent.gym.policies.samplers.discrete_sampler import (
        SoftmaxActionSampler as _Softmax,
    )
    from reagent.gym.policies.scorers.discrete_scorer import (
        discrete_dqn_scorer as _train_scorer,
        discrete_dqn_serving_scorer as _serve_scorer,
    )

    softmax_sampler = _Softmax(temperature=self.rl_parameters.temperature)
    action_scorer = (
        _serve_scorer(DiscreteDqnPredictorUnwrapper(self.build_serving_module()))
        if serving
        else _train_scorer(self.trainer.q_network)
    )
    return Policy(scorer=action_scorer, sampler=softmax_sampler)
def create_policy(
    self,
    trainer_module: ReAgentLightningModule,
    serving: bool = False,
    normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
) -> Policy:
    """Create an online DiscreteDQN Policy.

    Args:
        trainer_module: the trained module. When ``serving`` is True it is
            passed to ``build_serving_module``; otherwise its ``q_network``
            is used to score actions.
        serving: when True, return the policy produced by
            ``create_predictor_policy_from_model`` from the serving module
            (with ``self.rl_parameters``). When False, return a Policy that
            scores with ``discrete_dqn_scorer`` and picks actions with a
            ``GreedyActionSampler``.
        normalization_data_map: required (and must be non-empty) when
            ``serving`` is True; forwarded to ``build_serving_module``.
            Not used when ``serving`` is False.

    Returns:
        The constructed Policy.

    Raises:
        ValueError: if ``serving`` is True and ``normalization_data_map`` is
            None or empty.
    """
    if serving:
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let None reach build_serving_module. The
        # check also narrows the Optional for the type checker.
        if not normalization_data_map:
            raise ValueError(
                "normalization_data_map must be provided when serving=True"
            )
        return create_predictor_policy_from_model(
            self.build_serving_module(trainer_module, normalization_data_map),
            rl_parameters=self.rl_parameters,
        )

    sampler = GreedyActionSampler()
    # pyre-fixme[6]: Expected `ModelBase` for 1st param but got
    #  `Union[torch.Tensor, torch.nn.Module]`.
    scorer = discrete_dqn_scorer(trainer_module.q_network)
    return Policy(scorer=scorer, sampler=sampler)