Exemplo n.º 1
0
def create_predictor_policy_from_model(serving_module, **kwargs) -> Policy:
    """
    Create a Policy for gym environments from a serving module.

    serving_module is the result of ModelManager.build_serving_module().
    The kind of policy is chosen from the suffix of
    ``serving_module.original_name``.

    Keyword Args:
        rl_parameters: forwarded to DiscreteDQNPredictorPolicy for a
            discrete DQN serving module (None when omitted).
        max_num_actions: required for a Parametric DQN serving module.
        slate_size: optional; when given for a Parametric DQN serving
            module, a slate scorer with a TopKSampler of that size is used
            instead of the greedy action sampler.

    Raises:
        AssertionError: a Parametric DQN serving module was passed without
            ``max_num_actions``.
        NotImplementedError: the serving module's name matches none of the
            supported predictor wrappers.
    """
    module_name = serving_module.original_name

    if module_name.endswith("DiscreteDqnPredictorWrapper"):
        return DiscreteDQNPredictorPolicy(
            serving_module, kwargs.get("rl_parameters"))

    if module_name.endswith("ActorPredictorWrapper"):
        return ActorPredictorPolicy(
            predictor=ActorPredictorUnwrapper(serving_module))

    if module_name.endswith("ParametricDqnPredictorWrapper"):
        # TODO: remove this dependency
        max_num_actions = kwargs.get("max_num_actions")
        if max_num_actions is None:
            # Raised explicitly instead of using an `assert` statement so the
            # check still runs under `python -O`; the exception type and
            # message are the same as before.
            raise AssertionError(
                "max_num_actions not given for Parametric DQN.")
        q_network = ParametricDqnPredictorUnwrapper(serving_module)

        # TODO: write SlateQ Wrapper
        slate_size = kwargs.get("slate_size")
        if slate_size is not None:
            scorer = slate_q_serving_scorer(num_candidates=max_num_actions,
                                            q_network=q_network)
            sampler = TopKSampler(k=slate_size)
        else:
            sampler = GreedyActionSampler()
            scorer = parametric_dqn_serving_scorer(
                max_num_actions=max_num_actions, q_network=q_network)
        return Policy(scorer=scorer, sampler=sampler)

    raise NotImplementedError(
        f"Predictor policy for serving module {serving_module} not available."
    )
Exemplo n.º 2
0
 def create_policy(self, serving: bool) -> Policy:
     """
     Build the policy for this model.

     When ``serving`` is false, the policy scores with ``slate_q_scorer``
     over ``self._q_network`` and samples with a ``TopKSampler`` of size
     ``self.slate_size``. When ``serving`` is true, the policy is created
     from the module returned by ``self.build_serving_module()``.
     """
     if not serving:
         return Policy(
             scorer=slate_q_scorer(
                 num_candidates=self.num_candidates,
                 q_network=self._q_network,
             ),
             sampler=TopKSampler(k=self.slate_size),
         )

     serving_module = self.build_serving_module()
     return create_predictor_policy_from_model(
         serving_module,
         max_num_actions=self.num_candidates,
         slate_size=self.slate_size,
     )
Exemplo n.º 3
0
 def create_policy(
     self,
     trainer_module: ReAgentLightningModule,
     serving: bool = False,
     normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
 ):
     """
     Build the policy for this model.

     With ``serving`` set, ``normalization_data_map`` must be truthy
     (asserted); the policy is then created from the serving module built
     out of ``trainer_module`` and ``normalization_data_map``. Otherwise the
     policy scores with ``slate_q_scorer`` over ``trainer_module.q_network``
     and samples with a ``TopKSampler`` of size ``self.slate_size``.
     """
     if serving:
         assert normalization_data_map
         serving_module = self.build_serving_module(
             trainer_module, normalization_data_map
         )
         return create_predictor_policy_from_model(
             serving_module,
             # pyre-fixme[16]: `SlateQBase` has no attribute `num_candidates`.
             max_num_actions=self.num_candidates,
             # pyre-fixme[16]: `SlateQBase` has no attribute `slate_size`.
             slate_size=self.slate_size,
         )

     return Policy(
         scorer=slate_q_scorer(
             num_candidates=self.num_candidates,
             # pyre-fixme[6]: Expected `ModelBase` for 2nd param but got
             #  `Union[torch.Tensor, torch.nn.Module]`.
             q_network=trainer_module.q_network,
         ),
         sampler=TopKSampler(k=self.slate_size),
     )