コード例 #1
0
class DiscreteDQNPredictorPolicy(Policy):
    """Policy that scores actions with a wrapped discrete DQN predictor.

    The action sampler is chosen at construction time: a
    ``SoftmaxActionSampler`` when ``rl_parameters.softmax_policy`` is truthy,
    otherwise a ``GreedyActionSampler``.
    """

    def __init__(self, wrapped_dqn_predictor,
                 rl_parameters: Optional[RLParameters]):
        """Build the sampler and the serving scorer.

        Args:
            wrapped_dqn_predictor: predictor handed to
                ``DiscreteDqnPredictorUnwrapper`` to obtain the q-network.
            rl_parameters: optional RL settings; only ``softmax_policy`` and
                ``temperature`` are read here.
        """
        # Falsy rl_parameters (e.g. None) short-circuits to the greedy sampler.
        use_softmax = rl_parameters and rl_parameters.softmax_policy
        self.sampler = (
            SoftmaxActionSampler(temperature=rl_parameters.temperature)
            if use_softmax
            else GreedyActionSampler()
        )
        q_network = DiscreteDqnPredictorUnwrapper(wrapped_dqn_predictor)
        self.scorer = discrete_dqn_serving_scorer(q_network=q_network)

    # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
    #  its type `no_grad` is not callable.
    @torch.no_grad()
    def act(
        self,
        obs: Union[rlt.ServingFeatureData, Tuple[torch.Tensor, torch.Tensor]],
        possible_actions_mask: Optional[np.ndarray],
    ) -> rlt.ActorOutput:
        """Score the observation and sample an action, without gradients.

        Args:
            obs: either a ``ServingFeatureData`` (sparse-feature case) or a
                state-with-presence tuple of two tensors. A plain tuple is
                wrapped into a ``ServingFeatureData`` with empty id-list and
                id-score-list features.
            possible_actions_mask: forwarded unchanged to the scorer.

        Returns:
            The sampler's output after ``.cpu().detach()``.
        """
        # The assertion applies to both accepted forms, so ServingFeatureData
        # is presumably a tuple subclass (e.g. a NamedTuple) -- TODO confirm.
        assert isinstance(obs, tuple)
        serving_input = obs
        if not isinstance(serving_input, rlt.ServingFeatureData):
            serving_input = rlt.ServingFeatureData(
                float_features_with_presence=obs,
                id_list_features={},
                id_score_list_features={},
            )
        action_scores = self.scorer(serving_input, possible_actions_mask)
        sampled = self.sampler.sample_action(action_scores)
        return sampled.cpu().detach()
コード例 #2
0
class DiscreteDQNPredictorPolicy(Policy):
    """Policy that scores actions with a wrapped discrete DQN predictor and
    samples them with a ``GreedyActionSampler``."""

    def __init__(self, wrapped_dqn_predictor):
        """Build the greedy sampler and the serving scorer.

        Args:
            wrapped_dqn_predictor: predictor handed to
                ``DiscreteDqnPredictorUnwrapper`` to obtain the q-network.
        """
        self.sampler = GreedyActionSampler()
        q_network = DiscreteDqnPredictorUnwrapper(wrapped_dqn_predictor)
        self.scorer = discrete_dqn_serving_scorer(q_network=q_network)

    @torch.no_grad()
    def act(
        self, obs: Union[rlt.ServingFeatureData, Tuple[torch.Tensor, torch.Tensor]]
    ) -> rlt.ActorOutput:
        """Score the observation and sample an action, without gradients.

        Args:
            obs: either a ``ServingFeatureData`` (sparse-feature case) or a
                state-with-presence tuple of two tensors. A plain tuple is
                wrapped into a ``ServingFeatureData`` with empty id-list and
                id-score-list features.

        Returns:
            The sampler's output after ``.cpu().detach()``.
        """
        # The assertion applies to both accepted forms, so ServingFeatureData
        # is presumably a tuple subclass (e.g. a NamedTuple) -- TODO confirm.
        assert isinstance(obs, tuple)
        serving_input = obs
        if not isinstance(serving_input, rlt.ServingFeatureData):
            serving_input = rlt.ServingFeatureData(
                float_features_with_presence=obs,
                id_list_features={},
                id_score_list_features={},
            )
        action_scores = self.scorer(serving_input)
        sampled = self.sampler.sample_action(action_scores)
        return sampled.cpu().detach()