Example #1
def create_predictor_policy_from_model(serving_module, **kwargs) -> Policy:
    """
    serving_module is the result of ModelManager.build_serving_module().
    This function creates a Policy for gym environments.
    """
    module_name = serving_module.original_name
    if module_name.endswith("DiscreteDqnPredictorWrapper"):
        sampler = GreedyActionSampler()
        scorer = discrete_dqn_serving_scorer(
            q_network=DiscreteDqnPredictorUnwrapper(serving_module))
        return Policy(scorer=scorer, sampler=sampler)
    elif module_name.endswith("ActorPredictorWrapper"):
        return ActorPredictorPolicy(
            predictor=ActorPredictorUnwrapper(serving_module))
    elif module_name.endswith("ParametricDqnPredictorWrapper"):
        # TODO: remove this dependency
        max_num_actions = kwargs.get("max_num_actions", None)
        assert (
            max_num_actions is not None
        ), "max_num_actions not given for Parametric DQN."
        sampler = GreedyActionSampler()
        scorer = parametric_dqn_serving_scorer(
            max_num_actions=max_num_actions,
            q_network=ParametricDqnPredictorUnwrapper(serving_module),
        )
        return Policy(scorer=scorer, sampler=sampler)
    else:
        raise NotImplementedError(
            f"Predictor policy for serving module {serving_module} not available."
        )
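
How this factory might be called in practice: a hedged sketch, assuming a ReAgent-style model manager whose build_serving_module() returns a TorchScript module whose original_name identifies the predictor wrapper. `manager` and the import path are illustrative, not part of the example above.

# Hypothetical usage sketch; `manager` stands in for a trained
# ModelManager, and the import path is assumed from ReAgent's layout.
# from reagent.gym.policies.predictor_policies import create_predictor_policy_from_model

serving_module = manager.build_serving_module()

# Dispatch is by the scripted module's original class name, so the
# matching sampler/scorer pair is picked for DQN, actor, or parametric DQN.
policy = create_predictor_policy_from_model(
    serving_module,
    max_num_actions=4,  # consulted only for ParametricDqnPredictorWrapper
)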
Example #2
def __init__(self, wrapped_dqn_predictor, rl_parameters: Optional[RLParameters]):
    if rl_parameters and rl_parameters.softmax_policy:
        self.sampler = SoftmaxActionSampler(temperature=rl_parameters.temperature)
    else:
        self.sampler = GreedyActionSampler()
    self.scorer = discrete_dqn_serving_scorer(
        q_network=DiscreteDqnPredictorUnwrapper(wrapped_dqn_predictor)
    )
Example #3
class DiscreteDQNPredictorPolicy(Policy):
    def __init__(self, wrapped_dqn_predictor,
                 rl_parameters: Optional[RLParameters]):
        if rl_parameters and rl_parameters.softmax_policy:
            self.sampler = SoftmaxActionSampler(
                temperature=rl_parameters.temperature)
        else:
            self.sampler = GreedyActionSampler()
        self.scorer = discrete_dqn_serving_scorer(
            q_network=DiscreteDqnPredictorUnwrapper(wrapped_dqn_predictor))

    # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
    #  its type `no_grad` is not callable.
    @torch.no_grad()
    def act(
        self,
        obs: Union[rlt.ServingFeatureData, Tuple[torch.Tensor, torch.Tensor]],
        possible_actions_mask: Optional[np.ndarray],
    ) -> rlt.ActorOutput:
        """Input is either state_with_presence, or
        ServingFeatureData (in the case of sparse features)"""
        assert isinstance(obs, tuple)
        if isinstance(obs, rlt.ServingFeatureData):
            state: rlt.ServingFeatureData = obs
        else:
            state = rlt.ServingFeatureData(
                float_features_with_presence=obs,
                id_list_features={},
                id_score_list_features={},
            )
        scores = self.scorer(state, possible_actions_mask)
        return self.sampler.sample_action(scores).cpu().detach()
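
A hedged sketch of calling act() with dense features only, where the (values, presence) pair is the state_with_presence tuple mentioned in the docstring. `policy`, the shapes, and the dtype of the presence tensor are illustrative assumptions:

import torch

# Batch of one observation with four float features; presence marks
# which features are populated (shapes/dtypes are illustrative).
values = torch.rand(1, 4)
presence = torch.ones(1, 4)

# `policy` is assumed to be a DiscreteDQNPredictorPolicy instance.
actor_output = policy.act((values, presence), possible_actions_mask=None)
action = actor_output.action  # one-hot greedy action from the sampler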
Example #4
def create_policy(self, serving: bool) -> Policy:
    """Create an online DiscreteDQN Policy from env."""
    if serving:
        return create_predictor_policy_from_model(
            self.build_serving_module(), rl_parameters=self.rl_parameters)
    else:
        sampler = GreedyActionSampler()
        # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
        scorer = discrete_dqn_scorer(self.trainer.q_network)
        return Policy(scorer=scorer, sampler=sampler)
Example #5
def create_policy(self, serving: bool) -> Policy:
    if serving:
        sampler = GreedyActionSampler()
        scorer = discrete_dqn_serving_scorer(
            DiscreteDqnPredictorUnwrapper(self.build_serving_module()))
    else:
        sampler = SoftmaxActionSampler(
            temperature=self.rl_parameters.temperature)
        # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
        scorer = discrete_qrdqn_scorer(self.trainer.q_network)
    return Policy(scorer=scorer, sampler=sampler)
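
Example #5 keeps a greedy sampler for serving but switches to a softmax sampler (driven by rl_parameters.temperature) for local evaluation. A self-contained sketch, using plain torch rather than ReAgent's sampler API, of how temperature shapes the distribution the softmax sampler draws from:

import torch

q_values = torch.tensor([2.0, 1.0, 0.5])

for temperature in (0.1, 1.0, 10.0):
    probs = torch.softmax(q_values / temperature, dim=0)
    print(temperature, probs)

# Low temperature -> nearly deterministic, greedy-like choices;
# high temperature -> close to uniform, more exploration.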
Example #6
def create_policy(
    self,
    trainer_module: ReAgentLightningModule,
    serving: bool = False,
    normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
) -> Policy:
    """Create an online DiscreteDQN Policy from env."""
    if serving:
        assert normalization_data_map
        return create_predictor_policy_from_model(
            self.build_serving_module(trainer_module, normalization_data_map),
            rl_parameters=self.rl_parameters,
        )
    else:
        sampler = GreedyActionSampler()
        # pyre-fixme[6]: Expected `ModelBase` for 1st param but got
        #  `Union[torch.Tensor, torch.nn.Module]`.
        scorer = discrete_dqn_scorer(trainer_module.q_network)
        return Policy(scorer=scorer, sampler=sampler)
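
Illustrative call sites for the Lightning-era signature above; `manager`, `trainer_module`, and `normalization_data_map` are hypothetical stand-ins for objects a ReAgent workflow would construct elsewhere:

# Training/evaluation mode: score directly with the trainer's q_network.
online_policy = manager.create_policy(trainer_module, serving=False)

# Serving mode: export a serving module first; the normalization map is
# required (the assert above guards it) to build the preprocessing.
serving_policy = manager.create_policy(
    trainer_module,
    serving=True,
    normalization_data_map=normalization_data_map,
)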
Example #7
class DiscreteDQNPredictorPolicy(Policy):
    def __init__(self, wrapped_dqn_predictor):
        self.sampler = GreedyActionSampler()
        self.scorer = discrete_dqn_serving_scorer(
            q_network=DiscreteDqnPredictorUnwrapper(wrapped_dqn_predictor)
        )

    @torch.no_grad()
    def act(
        self, obs: Union[rlt.ServingFeatureData, Tuple[torch.Tensor, torch.Tensor]]
    ) -> rlt.ActorOutput:
        """ Input is either state_with_presence, or
        ServingFeatureData (in the case of sparse features) """
        assert isinstance(obs, tuple)
        if isinstance(obs, rlt.ServingFeatureData):
            state: rlt.ServingFeatureData = obs
        else:
            state = rlt.ServingFeatureData(
                float_features_with_presence=obs,
                id_list_features={},
                id_score_list_features={},
            )
        scores = self.scorer(state)
        return self.sampler.sample_action(scores).cpu().detach()
Example #8
def __init__(self, wrapped_dqn_predictor):
    self.sampler = GreedyActionSampler()
    self.scorer = discrete_dqn_serving_scorer(
        q_network=DiscreteDqnPredictorUnwrapper(wrapped_dqn_predictor)
    )