def create_predictor_policy_from_model(serving_module, **kwargs) -> Policy:
    """Create a serving-time Policy for gym environments.

    Args:
        serving_module: the result of ModelManager.build_serving_module().
        **kwargs: optional extras —
            rl_parameters (RLParameters): when given with ``softmax_policy``
                set, the discrete-DQN path samples with a temperature-scaled
                softmax instead of greedily.
            max_num_actions (int): required for the parametric-DQN wrapper.

    Returns:
        A Policy wired to the appropriate serving scorer/sampler.

    Raises:
        NotImplementedError: if the wrapper type is not recognized.
    """
    module_name = serving_module.original_name
    if module_name.endswith("DiscreteDqnPredictorWrapper"):
        # Fix: callers pass rl_parameters; honor it instead of silently
        # dropping it (mirrors DiscreteDQNPredictorPolicy's sampler choice).
        rl_parameters = kwargs.get("rl_parameters", None)
        if rl_parameters is not None and rl_parameters.softmax_policy:
            sampler = SoftmaxActionSampler(temperature=rl_parameters.temperature)
        else:
            sampler = GreedyActionSampler()
        scorer = discrete_dqn_serving_scorer(
            q_network=DiscreteDqnPredictorUnwrapper(serving_module)
        )
        return Policy(scorer=scorer, sampler=sampler)
    elif module_name.endswith("ActorPredictorWrapper"):
        return ActorPredictorPolicy(
            predictor=ActorPredictorUnwrapper(serving_module)
        )
    elif module_name.endswith("ParametricDqnPredictorWrapper"):
        # TODO: remove this dependency
        max_num_actions = kwargs.get("max_num_actions", None)
        # Fix: the original message was an f-string with no placeholders.
        assert (
            max_num_actions is not None
        ), "max_num_actions not given for Parametric DQN."
        sampler = GreedyActionSampler()
        scorer = parametric_dqn_serving_scorer(
            max_num_actions=max_num_actions,
            q_network=ParametricDqnPredictorUnwrapper(serving_module),
        )
        return Policy(scorer=scorer, sampler=sampler)
    else:
        raise NotImplementedError(
            f"Predictor policy for serving module {serving_module} not available."
        )
def __init__(self, wrapped_dqn_predictor, rl_parameters: Optional[RLParameters]):
    """Wire up sampler and scorer for a wrapped serving DQN predictor.

    Softmax sampling is used when rl_parameters.softmax_policy is set;
    otherwise the policy acts greedily.
    """
    self.sampler = (
        SoftmaxActionSampler(temperature=rl_parameters.temperature)
        if rl_parameters and rl_parameters.softmax_policy
        else GreedyActionSampler()
    )
    self.scorer = discrete_dqn_serving_scorer(
        q_network=DiscreteDqnPredictorUnwrapper(wrapped_dqn_predictor)
    )
class DiscreteDQNPredictorPolicy(Policy):
    """Serving-time policy that scores states with a wrapped discrete-DQN
    predictor and samples an action (softmax when rl_parameters.softmax_policy
    is set, greedy otherwise)."""

    def __init__(self, wrapped_dqn_predictor, rl_parameters: Optional[RLParameters]):
        if rl_parameters and rl_parameters.softmax_policy:
            self.sampler = SoftmaxActionSampler(temperature=rl_parameters.temperature)
        else:
            self.sampler = GreedyActionSampler()
        self.scorer = discrete_dqn_serving_scorer(
            q_network=DiscreteDqnPredictorUnwrapper(wrapped_dqn_predictor)
        )

    # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
    # its type `no_grad` is not callable.
    @torch.no_grad()
    def act(
        self,
        obs: Union[rlt.ServingFeatureData, Tuple[torch.Tensor, torch.Tensor]],
        possible_actions_mask: Optional[np.ndarray],
    ) -> rlt.ActorOutput:
        """Input is either state_with_presence, or ServingFeatureData (in the
        case of sparse features)"""
        assert isinstance(obs, tuple)
        state = (
            obs
            if isinstance(obs, rlt.ServingFeatureData)
            else rlt.ServingFeatureData(
                float_features_with_presence=obs,
                id_list_features={},
                id_score_list_features={},
            )
        )
        scored = self.scorer(state, possible_actions_mask)
        return self.sampler.sample_action(scored).cpu().detach()
def create_policy(self, serving: bool) -> Policy:
    """Create an online DiscreteDQN Policy from env."""
    if not serving:
        # Training-time policy: greedy over the trainer's live q-network.
        # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
        return Policy(
            scorer=discrete_dqn_scorer(self.trainer.q_network),
            sampler=GreedyActionSampler(),
        )
    return create_predictor_policy_from_model(
        self.build_serving_module(), rl_parameters=self.rl_parameters
    )
def create_policy(self, serving: bool) -> Policy:
    """Build the QR-DQN policy: greedy over the unwrapped serving module, or a
    temperature-softmax over the trainer's q-network during training."""
    if serving:
        unwrapped = DiscreteDqnPredictorUnwrapper(self.build_serving_module())
        return Policy(
            scorer=discrete_dqn_serving_scorer(unwrapped),
            sampler=GreedyActionSampler(),
        )
    # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
    return Policy(
        scorer=discrete_qrdqn_scorer(self.trainer.q_network),
        sampler=SoftmaxActionSampler(temperature=self.rl_parameters.temperature),
    )
def create_policy(
    self,
    trainer_module: ReAgentLightningModule,
    serving: bool = False,
    normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
) -> Policy:
    """Create an online DiscreteDQN Policy from env."""
    if not serving:
        # Training-time policy: greedy over the module's live q-network.
        # pyre-fixme[6]: Expected `ModelBase` for 1st param but got
        #  `Union[torch.Tensor, torch.nn.Module]`.
        return Policy(
            scorer=discrete_dqn_scorer(trainer_module.q_network),
            sampler=GreedyActionSampler(),
        )
    assert normalization_data_map
    return create_predictor_policy_from_model(
        self.build_serving_module(trainer_module, normalization_data_map),
        rl_parameters=self.rl_parameters,
    )
class DiscreteDQNPredictorPolicy(Policy):
    """Greedy serving-time policy backed by a wrapped discrete-DQN predictor."""

    def __init__(self, wrapped_dqn_predictor):
        self.sampler = GreedyActionSampler()
        self.scorer = discrete_dqn_serving_scorer(
            q_network=DiscreteDqnPredictorUnwrapper(wrapped_dqn_predictor)
        )

    @torch.no_grad()
    def act(
        self, obs: Union[rlt.ServingFeatureData, Tuple[torch.Tensor, torch.Tensor]]
    ) -> rlt.ActorOutput:
        """
        Input is either state_with_presence, or
        ServingFeatureData (in the case of sparse features)
        """
        assert isinstance(obs, tuple)
        state = (
            obs
            if isinstance(obs, rlt.ServingFeatureData)
            else rlt.ServingFeatureData(
                float_features_with_presence=obs,
                id_list_features={},
                id_score_list_features={},
            )
        )
        chosen = self.sampler.sample_action(self.scorer(state))
        return chosen.cpu().detach()
def __init__(self, wrapped_dqn_predictor):
    """Always-greedy sampler over the unwrapped serving DQN's scores."""
    unwrapped = DiscreteDqnPredictorUnwrapper(wrapped_dqn_predictor)
    self.scorer = discrete_dqn_serving_scorer(q_network=unwrapped)
    self.sampler = GreedyActionSampler()