Ejemplo n.º 1
0
def create_predictor_policy_from_model(serving_module, **kwargs) -> Policy:
    """
    Create a Policy for gym environments from a serving module.

    serving_module is the result of ModelManager.build_serving_module().
    The kind of policy is selected by the suffix of
    ``serving_module.original_name``:

    - ``DiscreteDqnPredictorWrapper``: greedy sampler over the discrete DQN
      serving scorer.
    - ``ActorPredictorWrapper``: an ``ActorPredictorPolicy``.
    - ``ParametricDqnPredictorWrapper``: greedy sampler over the parametric
      DQN serving scorer; needs the ``max_num_actions`` keyword argument.

    Args:
        serving_module: object exposing an ``original_name`` string.
        **kwargs: extra options. Only ``max_num_actions`` is read, and only
            in the parametric DQN case.

    Returns:
        The policy built for the given serving module.

    Raises:
        AssertionError: a parametric DQN module was given without
            ``max_num_actions`` (or with ``max_num_actions=None``).
        NotImplementedError: the module name matches no supported wrapper.
    """
    module_name = serving_module.original_name

    if module_name.endswith("DiscreteDqnPredictorWrapper"):
        sampler = GreedyActionSampler()
        scorer = discrete_dqn_serving_scorer(
            q_network=DiscreteDqnPredictorUnwrapper(serving_module)
        )
        return Policy(scorer=scorer, sampler=sampler)

    if module_name.endswith("ActorPredictorWrapper"):
        return ActorPredictorPolicy(
            predictor=ActorPredictorUnwrapper(serving_module)
        )

    if module_name.endswith("ParametricDqnPredictorWrapper"):
        # TODO: remove this dependency
        max_num_actions = kwargs.get("max_num_actions")
        if max_num_actions is None:
            # Raised explicitly instead of via `assert` so the check is not
            # stripped under `python -O`; the exception type and message are
            # kept so existing callers see the same failure.
            raise AssertionError("max_num_actions not given for Parametric DQN.")
        sampler = GreedyActionSampler()
        scorer = parametric_dqn_serving_scorer(
            max_num_actions=max_num_actions,
            q_network=ParametricDqnPredictorUnwrapper(serving_module),
        )
        return Policy(scorer=scorer, sampler=sampler)

    raise NotImplementedError(
        f"Predictor policy for serving module {serving_module} not available."
    )
Ejemplo n.º 2
0
 def __init__(self, wrapped_dqn_predictor, rl_parameters: Optional[RLParameters]):
     """Set up the action sampler and the serving scorer.

     A softmax sampler (using ``rl_parameters.temperature``) is chosen when
     ``rl_parameters`` is truthy and its ``softmax_policy`` flag is set;
     in every other case a greedy sampler is used. The scorer is the
     discrete DQN serving scorer over the unwrapped predictor.
     """
     wants_softmax = rl_parameters and rl_parameters.softmax_policy
     self.sampler = (
         SoftmaxActionSampler(temperature=rl_parameters.temperature)
         if wants_softmax
         else GreedyActionSampler()
     )
     unwrapped_predictor = DiscreteDqnPredictorUnwrapper(wrapped_dqn_predictor)
     self.scorer = discrete_dqn_serving_scorer(q_network=unwrapped_predictor)
Ejemplo n.º 3
0
 def create_policy(self, serving: bool) -> Policy:
     """Build a Policy for either serving or training-time use.

     With ``serving`` truthy, the policy pairs a greedy sampler with the
     discrete DQN serving scorer over ``self.build_serving_module()``.
     Otherwise it pairs a softmax sampler (temperature taken from
     ``self.rl_parameters``) with the QR-DQN scorer over the trainer's
     ``q_network``.
     """
     if not serving:
         softmax_sampler = SoftmaxActionSampler(
             temperature=self.rl_parameters.temperature
         )
         # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
         training_scorer = discrete_qrdqn_scorer(self.trainer.q_network)
         return Policy(scorer=training_scorer, sampler=softmax_sampler)

     greedy_sampler = GreedyActionSampler()
     unwrapped_module = DiscreteDqnPredictorUnwrapper(self.build_serving_module())
     serving_scorer = discrete_dqn_serving_scorer(unwrapped_module)
     return Policy(scorer=serving_scorer, sampler=greedy_sampler)
Ejemplo n.º 4
0
    def create_policy(self, serving: bool) -> Policy:
        """Build a DiscreteDQN Policy that samples actions with a softmax.

        The sampler temperature comes from ``self.rl_parameters``. With
        ``serving`` truthy the scorer wraps ``self.build_serving_module()``;
        otherwise it scores with the trainer's ``q_network`` directly.
        """

        # Function-scope imports, kept local as in the original.
        from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler
        from reagent.gym.policies.scorers.discrete_scorer import (
            discrete_dqn_scorer,
            discrete_dqn_serving_scorer,
        )

        softmax_sampler = SoftmaxActionSampler(
            temperature=self.rl_parameters.temperature
        )
        if not serving:
            return Policy(
                scorer=discrete_dqn_scorer(self.trainer.q_network),
                sampler=softmax_sampler,
            )

        unwrapped_module = DiscreteDqnPredictorUnwrapper(self.build_serving_module())
        return Policy(
            scorer=discrete_dqn_serving_scorer(unwrapped_module),
            sampler=softmax_sampler,
        )
Ejemplo n.º 5
0
 def __init__(self, wrapped_dqn_predictor):
     """Pair a greedy sampler with the discrete DQN serving scorer.

     The scorer is built over the unwrapped form of
     ``wrapped_dqn_predictor``.
     """
     self.sampler = GreedyActionSampler()
     unwrapped_predictor = DiscreteDqnPredictorUnwrapper(wrapped_dqn_predictor)
     self.scorer = discrete_dqn_serving_scorer(q_network=unwrapped_predictor)