def create_predictor_policy_from_model(serving_module, **kwargs) -> Policy:
    """
    Create a Policy for gym environments from a serving module.

    The kind of policy is chosen by matching the suffix of
    ``serving_module.original_name`` against the known predictor wrappers.

    Args:
        serving_module: the result of ModelManager.build_serving_module().
            Only its ``original_name`` attribute is read here; the module
            itself is handed on to the matching policy / unwrapper.
        **kwargs: optional, wrapper-specific settings:
            rl_parameters: passed to DiscreteDQNPredictorPolicy for a
                "DiscreteDqnPredictorWrapper" (None when omitted).
            max_num_actions: required for a "ParametricDqnPredictorWrapper".
            slate_size: Parametric DQN only. When given (not None), the
                SlateQ serving scorer is paired with a TopKSampler of that
                size; otherwise the parametric DQN serving scorer is paired
                with a GreedyActionSampler.

    Returns:
        The Policy built for the detected wrapper type.

    Raises:
        AssertionError: the module is a Parametric DQN wrapper and
            ``max_num_actions`` was not supplied (or is None).
        NotImplementedError: the module name matches no known wrapper.
    """
    module_name = serving_module.original_name

    if module_name.endswith("DiscreteDqnPredictorWrapper"):
        rl_parameters = kwargs.get("rl_parameters")
        return DiscreteDQNPredictorPolicy(serving_module, rl_parameters)

    if module_name.endswith("ActorPredictorWrapper"):
        return ActorPredictorPolicy(predictor=ActorPredictorUnwrapper(serving_module))

    if module_name.endswith("ParametricDqnPredictorWrapper"):
        # TODO: remove this dependency
        max_num_actions = kwargs.get("max_num_actions")
        if max_num_actions is None:
            # Raised explicitly rather than via `assert` so the check is not
            # stripped under `python -O`. AssertionError is kept so callers
            # that relied on the previous exception type still work.
            raise AssertionError("max_num_actions not given for Parametric DQN.")
        q_network = ParametricDqnPredictorUnwrapper(serving_module)

        # TODO: write SlateQ Wrapper
        slate_size = kwargs.get("slate_size")
        if slate_size is not None:
            scorer = slate_q_serving_scorer(
                num_candidates=max_num_actions, q_network=q_network
            )
            sampler = TopKSampler(k=slate_size)
        else:
            sampler = GreedyActionSampler()
            scorer = parametric_dqn_serving_scorer(
                max_num_actions=max_num_actions, q_network=q_network
            )
        return Policy(scorer=scorer, sampler=sampler)

    raise NotImplementedError(
        f"Predictor policy for serving module {serving_module} not available."
    )
def create_policy(self, serving: bool) -> Policy:
    """
    Build the Policy for this model.

    Args:
        serving: when True, build the serving module with
            ``self.build_serving_module()`` and derive the policy from it
            via ``create_predictor_policy_from_model``; when False, score
            with ``self._q_network`` directly.

    Returns:
        A Policy. In the non-serving case it combines a ``slate_q_scorer``
        over ``self.num_candidates`` with a ``TopKSampler`` of size
        ``self.slate_size``.
    """
    if not serving:
        # Training-time policy: score with the in-memory Q-network.
        return Policy(
            scorer=slate_q_scorer(
                num_candidates=self.num_candidates,
                q_network=self._q_network,
            ),
            sampler=TopKSampler(k=self.slate_size),
        )

    # Serving policy: go through the exported serving module.
    serving_module = self.build_serving_module()
    return create_predictor_policy_from_model(
        serving_module,
        max_num_actions=self.num_candidates,
        slate_size=self.slate_size,
    )
def create_policy(
    self,
    trainer_module: ReAgentLightningModule,
    serving: bool = False,
    normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
):
    """
    Build the Policy for this model from a trainer module.

    Args:
        trainer_module: supplies ``q_network`` for the non-serving policy
            and is passed to ``self.build_serving_module`` when serving.
        serving: when True, derive the policy from the serving module;
            otherwise score with ``trainer_module.q_network`` directly.
        normalization_data_map: must be provided (truthy) when ``serving``
            is True; it is passed to ``self.build_serving_module``.
            Not used otherwise.

    Returns:
        A Policy. In the non-serving case it combines a ``slate_q_scorer``
        over ``self.num_candidates`` with a ``TopKSampler`` of size
        ``self.slate_size``.

    Raises:
        AssertionError: ``serving`` is True and ``normalization_data_map``
            is None or empty.
    """
    if serving:
        assert normalization_data_map
        serving_module = self.build_serving_module(
            trainer_module, normalization_data_map
        )
        return create_predictor_policy_from_model(
            serving_module,
            # pyre-fixme[16]: `SlateQBase` has no attribute `num_candidates`.
            max_num_actions=self.num_candidates,
            # pyre-fixme[16]: `SlateQBase` has no attribute `slate_size`.
            slate_size=self.slate_size,
        )

    # Non-serving: score with the trainer's Q-network directly.
    training_scorer = slate_q_scorer(
        num_candidates=self.num_candidates,
        # pyre-fixme[6]: Expected `ModelBase` for 2nd param but got
        #  `Union[torch.Tensor, torch.nn.Module]`.
        q_network=trainer_module.q_network,
    )
    top_k_sampler = TopKSampler(k=self.slate_size)
    return Policy(scorer=training_scorer, sampler=top_k_sampler)