Exemplo n.º 1
0
 def create_policy(self, serving: bool) -> Policy:
     """Construct a policy for this SlateQ configuration.

     Args:
         serving: When True, the module returned by
             ``self.build_serving_module()`` is wrapped via
             ``create_predictor_policy_from_model`` with
             ``max_num_actions=self.num_candidates`` and
             ``slate_size=self.slate_size``. When False, a ``Policy`` is
             assembled directly from ``self._q_network``.

     Returns:
         The constructed ``Policy``.
     """
     if not serving:
         # Score with the in-memory Q-network, then sample with a
         # TopKSampler sized to the slate.
         q_scorer = slate_q_scorer(
             num_candidates=self.num_candidates, q_network=self._q_network
         )
         return Policy(scorer=q_scorer, sampler=TopKSampler(k=self.slate_size))

     serving_module = self.build_serving_module()
     return create_predictor_policy_from_model(
         serving_module,
         max_num_actions=self.num_candidates,
         slate_size=self.slate_size,
     )
Exemplo n.º 2
0
 def create_policy(
     self,
     trainer_module: ReAgentLightningModule,
     serving: bool = False,
     normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
 ):
     """Construct a policy from a trained SlateQ module.

     Args:
         trainer_module: The trained module. When ``serving`` is True it is
             passed to ``self.build_serving_module``; otherwise its
             ``q_network`` attribute is used to build the scorer.
         serving: When True, build a predictor policy from the serving
             module via ``create_predictor_policy_from_model``. When False,
             build a ``Policy`` from ``slate_q_scorer`` and a ``TopKSampler``
             with ``k=self.slate_size``.
         normalization_data_map: Required and must be non-empty when
             ``serving`` is True; forwarded to ``self.build_serving_module``.
             Ignored when ``serving`` is False.

     Returns:
         The constructed policy.

     Raises:
         ValueError: If ``serving`` is True and ``normalization_data_map`` is
             ``None`` or empty.
     """
     if serving:
         # Explicit check rather than ``assert``: asserts are removed under
         # ``python -O``, which would let a missing map reach
         # ``build_serving_module`` unchecked. Truthiness matches the old
         # assert, so an empty map is still rejected.
         if not normalization_data_map:
             raise ValueError(
                 "normalization_data_map must be provided and non-empty "
                 "when serving=True"
             )
         return create_predictor_policy_from_model(
             self.build_serving_module(trainer_module, normalization_data_map),
             # pyre-fixme[16]: `SlateQBase` has no attribute `num_candidates`.
             max_num_actions=self.num_candidates,
             # pyre-fixme[16]: `SlateQBase` has no attribute `slate_size`.
             slate_size=self.slate_size,
         )
     else:
         scorer = slate_q_scorer(
             num_candidates=self.num_candidates,
             # pyre-fixme[6]: Expected `ModelBase` for 2nd param but got
             #  `Union[torch.Tensor, torch.nn.Module]`.
             q_network=trainer_module.q_network,
         )
         sampler = TopKSampler(k=self.slate_size)
         return Policy(scorer=scorer, sampler=sampler)