Exemplo n.º 1
0
 def create_policy(self, serving: bool) -> Policy:
     """Build an online DiscreteDQN policy.

     Args:
         serving: If true, the result comes from
             ``create_predictor_policy_from_model`` applied to the serving
             module. Otherwise the trainer's Q-network is wrapped by
             ``discrete_dqn_scorer`` and paired with a greedy sampler.

     Returns:
         The policy for the requested mode.
     """
     if not serving:
         greedy = GreedyActionSampler()
         # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
         q_scorer = discrete_dqn_scorer(self.trainer.q_network)
         return Policy(scorer=q_scorer, sampler=greedy)

     serving_module = self.build_serving_module()
     return create_predictor_policy_from_model(
         serving_module, rl_parameters=self.rl_parameters
     )
Exemplo n.º 2
0
    def create_policy(self) -> Policy:
        """Build an online DiscreteDQN policy around the trainer's Q-network.

        The sampler is a ``SoftmaxActionSampler`` configured with
        ``rl_parameters.temperature``; the scorer is ``discrete_dqn_scorer``
        applied to ``self.trainer.q_network``.
        """
        # Imported here rather than at module level so that gym is not
        # pulled in when it is not installed.
        from reagent.gym.policies.samplers.discrete_sampler import (
            SoftmaxActionSampler,
        )
        from reagent.gym.policies.scorers.discrete_scorer import (
            discrete_dqn_scorer,
        )

        softmax_temperature = self.rl_parameters.temperature
        return Policy(
            sampler=SoftmaxActionSampler(temperature=softmax_temperature),
            scorer=discrete_dqn_scorer(self.trainer.q_network),
        )
Exemplo n.º 3
0
 def create_policy(self, serving: bool) -> Policy:
     """Build an online DiscreteDQN policy.

     Args:
         serving: If true, pair a greedy sampler with
             ``discrete_dqn_serving_scorer`` over the unwrapped serving
             module. Otherwise pair a softmax sampler (temperature taken
             from ``rl_parameters``) with ``discrete_dqn_scorer`` over the
             trainer's Q-network.

     Returns:
         The policy assembled from the selected scorer and sampler.
     """
     if serving:
         greedy = GreedyActionSampler()
         unwrapped = DiscreteDqnPredictorUnwrapper(self.build_serving_module())
         return Policy(
             scorer=discrete_dqn_serving_scorer(unwrapped), sampler=greedy
         )

     # pyre-fixme[16]: `DiscreteDQNBase` has no attribute `rl_parameters`.
     softmax = SoftmaxActionSampler(temperature=self.rl_parameters.temperature)
     # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
     q_scorer = discrete_dqn_scorer(self.trainer.q_network)
     return Policy(scorer=q_scorer, sampler=softmax)
Exemplo n.º 4
0
    def create_policy(self, serving: bool) -> Policy:
        """Build an online DiscreteDQN policy with a softmax sampler.

        Args:
            serving: If true, actions are scored by
                ``discrete_dqn_serving_scorer`` over the unwrapped serving
                module; otherwise by ``discrete_dqn_scorer`` over the
                trainer's Q-network.

        Returns:
            The policy combining the chosen scorer with a
            ``SoftmaxActionSampler`` at ``rl_parameters.temperature``.
        """
        from reagent.gym.policies.samplers.discrete_sampler import (
            SoftmaxActionSampler,
        )
        from reagent.gym.policies.scorers.discrete_scorer import discrete_dqn_scorer
        from reagent.gym.policies.scorers.discrete_scorer import (
            discrete_dqn_serving_scorer,
        )

        # The sampler is the same in both modes; only the scorer differs.
        softmax = SoftmaxActionSampler(temperature=self.rl_parameters.temperature)
        action_scorer = (
            discrete_dqn_serving_scorer(
                DiscreteDqnPredictorUnwrapper(self.build_serving_module())
            )
            if serving
            else discrete_dqn_scorer(self.trainer.q_network)
        )
        return Policy(scorer=action_scorer, sampler=softmax)
Exemplo n.º 5
0
 def create_policy(
     self,
     trainer_module: ReAgentLightningModule,
     serving: bool = False,
     normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
 ) -> Policy:
     """Build an online DiscreteDQN policy from a trainer module.

     Args:
         trainer_module: Module whose ``q_network`` is scored in the
             non-serving path, and which is passed on to
             ``build_serving_module`` in the serving path.
         serving: If true, the result comes from
             ``create_predictor_policy_from_model`` applied to the serving
             module. Otherwise ``discrete_dqn_scorer`` is paired with a
             greedy sampler.
         normalization_data_map: Passed to ``build_serving_module``; it is
             asserted to be truthy when ``serving`` is true and is not read
             otherwise.

     Returns:
         The policy for the requested mode.
     """
     if not serving:
         greedy = GreedyActionSampler()
         # pyre-fixme[6]: Expected `ModelBase` for 1st param but got
         #  `Union[torch.Tensor, torch.nn.Module]`.
         q_scorer = discrete_dqn_scorer(trainer_module.q_network)
         return Policy(scorer=q_scorer, sampler=greedy)

     assert normalization_data_map
     serving_module = self.build_serving_module(
         trainer_module, normalization_data_map
     )
     return create_predictor_policy_from_model(
         serving_module,
         rl_parameters=self.rl_parameters,
     )