Exemplo n.º 1
0
def create_predictor_policy_from_model(serving_module, **kwargs) -> Policy:
    """
    Create a Policy for gym environments from a serving module.

    The kind of policy is selected from the suffix of
    ``serving_module.original_name``.

    Args:
        serving_module: the result of ModelManager.build_serving_module().
        **kwargs: extra options. ``max_num_actions`` is required when the
            module name ends with "ParametricDqnPredictorWrapper".

    Returns:
        A DiscreteDQNPredictorPolicy, an ActorPredictorPolicy, or a Policy
        built from a parametric DQN scorer and a GreedyActionSampler,
        depending on the module name.

    Raises:
        AssertionError: if a parametric DQN module is given without
            ``max_num_actions`` (or with it set to None).
        NotImplementedError: if the module name matches no known wrapper.
    """
    module_name = serving_module.original_name
    if module_name.endswith("DiscreteDqnPredictorWrapper"):
        return DiscreteDQNPredictorPolicy(serving_module)
    if module_name.endswith("ActorPredictorWrapper"):
        return ActorPredictorPolicy(predictor=ActorPredictorUnwrapper(serving_module))
    if module_name.endswith("ParametricDqnPredictorWrapper"):
        # TODO: remove this dependency
        max_num_actions = kwargs.get("max_num_actions")
        # Explicit raise instead of an `assert` statement: asserts are removed
        # when Python runs with -O, which would let None reach the scorer.
        # AssertionError is kept so existing callers see the same exception.
        if max_num_actions is None:
            raise AssertionError("max_num_actions not given for Parametric DQN.")
        sampler = GreedyActionSampler()
        scorer = parametric_dqn_serving_scorer(
            max_num_actions=max_num_actions,
            q_network=ParametricDqnPredictorUnwrapper(serving_module),
        )
        return Policy(scorer=scorer, sampler=sampler)
    raise NotImplementedError(
        f"Predictor policy for serving module {serving_module} not available."
    )
Exemplo n.º 2
0
    def create_policy(self, serving: bool) -> Policy:
        """Create the online actor critic policy.

        Args:
            serving: if true, the policy is built from the module returned by
                ``self.build_serving_module()``; otherwise it wraps
                ``self._actor_network`` directly.

        Returns:
            An ActorPredictorPolicy when ``serving`` is true, else an
            ActorPolicyWrapper.
        """
        # Imported at call time, before either branch, as in the original.
        from reagent.gym.policies import ActorPredictorPolicy

        if not serving:
            # pyre-fixme[16]: `ActorCriticBase` has no attribute `_actor_network`.
            return ActorPolicyWrapper(self._actor_network)

        serving_module = self.build_serving_module()
        return ActorPredictorPolicy(ActorPredictorUnwrapper(serving_module))