def create_policy(
        self,
        trainer_module: ReAgentLightningModule,
        serving: bool = False,
        normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
    ):
        """Create an online DiscreteDQN Policy from env.

        When ``serving`` is True, builds a predictor policy from the serving
        module (``normalization_data_map`` is required); otherwise returns a
        softmax-sampled policy over the parametric Q-network.
        """
        # FIXME: this only works for one-hot encoded actions
        # pyre-fixme[16]: `Tensor` has no attribute `input_prototype`.
        prototype = trainer_module.q_network.input_prototype()
        # Action dimension is read off the action-feature prototype's width.
        action_dim = prototype[1].float_features.shape[1]
        if serving:
            assert normalization_data_map
            serving_module = self.build_serving_module(
                trainer_module, normalization_data_map
            )
            return create_predictor_policy_from_model(
                serving_module, max_num_actions=action_dim
            )
        # pyre-fixme[16]: `ParametricDQNBase` has no attribute `rl_parameters`.
        sampler = SoftmaxActionSampler(temperature=self.rl_parameters.temperature)
        scorer = parametric_dqn_scorer(
            max_num_actions=action_dim,
            # pyre-fixme[6]: Expected `ModelBase` for 2nd param but got
            #  `Union[torch.Tensor, torch.nn.Module]`.
            q_network=trainer_module.q_network,
        )
        return Policy(scorer=scorer, sampler=sampler)
# Example #2
def evaluate_gym(
    env_name: str,
    model: ModelManager__Union,
    publisher: ModelPublisher__Union,
    num_eval_episodes: int,
    passing_score_bar: float,
    max_steps: Optional[int] = None,
):
    """Evaluate the latest published TorchScript model on a Gym environment.

    Loads the most recently published serving module from the file-system
    publisher, wraps it in a serving policy/agent, runs ``num_eval_episodes``
    episodes, and checks that the average reward clears ``passing_score_bar``.

    Args:
        env_name: Gym environment id to evaluate on.
        model: Model manager union whose published module is evaluated.
        publisher: Must wrap a ``FileSystemPublisher``.
        num_eval_episodes: Number of evaluation episodes to run.
        passing_score_bar: Minimum acceptable average reward.
        max_steps: Optional per-episode step cap.

    Returns:
        The average reward over the evaluation episodes.

    Raises:
        AssertionError: If the publisher is not a ``FileSystemPublisher`` or
            the average reward is below ``passing_score_bar``.
            NOTE(review): assert-based checks are stripped under ``python -O``;
            raise explicit exceptions if this runs as a production gate.
    """
    publisher_manager = publisher.value
    assert isinstance(
        publisher_manager, FileSystemPublisher
    ), f"publishing manager is type {type(publisher_manager)}, not FileSystemPublisher"
    env = Gym(env_name=env_name)
    torchscript_path = publisher_manager.get_latest_published_model(model.value)
    jit_model = torch.jit.load(torchscript_path)
    policy = create_predictor_policy_from_model(jit_model)
    agent = Agent.create_for_env_with_serving_policy(env, policy)
    rewards = evaluate_for_n_episodes(
        n=num_eval_episodes, env=env, agent=agent, max_steps=max_steps
    )
    avg_reward = np.mean(rewards)
    logger.info(f"Average reward over {num_eval_episodes} is {avg_reward}.\n"
                f"List of rewards: {rewards}")
    assert (avg_reward >= passing_score_bar
            ), f"{avg_reward} fails to pass the bar of {passing_score_bar}!"
    # Return the computed average (was discarded by a bare `return`) so
    # callers can log or further inspect the score; None-ignoring callers
    # are unaffected.
    return avg_reward
# Example #3
 def create_policy(self, serving: bool) -> Policy:
     """Create online actor critic policy.

     Wraps the in-memory actor network for training; builds a predictor
     policy from the actor serving module otherwise.
     """
     if not serving:
         return ActorPolicyWrapper(self._actor_network)
     return create_predictor_policy_from_model(self.build_actor_module())
# Example #4
 def create_policy(self, serving: bool = False):
     """Return the policy for this manager.

     Serving mode builds a fresh predictor policy; training mode lazily
     constructs and caches a softmax-sampled policy over the policy network.
     """
     if serving:
         return create_predictor_policy_from_model(self.build_serving_module())
     if self._policy is None:
         action_sampler = SoftmaxActionSampler(temperature=self.sampler_temperature)
         # pyre-ignore
         self._policy = Policy(scorer=self._policy_network, sampler=action_sampler)
     return self._policy
# Example #5
 def create_policy(self, serving: bool) -> Policy:
     """Create an online DiscreteDQN Policy from env.

     Training mode pairs the trainer's Q-network with greedy action
     selection; serving mode wraps the built serving module.
     """
     if not serving:
         greedy = GreedyActionSampler()
         # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
         q_scorer = discrete_dqn_scorer(self.trainer.q_network)
         return Policy(scorer=q_scorer, sampler=greedy)
     return create_predictor_policy_from_model(
         self.build_serving_module(), rl_parameters=self.rl_parameters
     )
# Example #6
 def create_policy(self, serving: bool) -> Policy:
     """Create a SlateQ policy: top-k over per-candidate Q-scores."""
     if not serving:
         q_scorer = slate_q_scorer(
             num_candidates=self.num_candidates, q_network=self._q_network
         )
         top_k = TopKSampler(k=self.slate_size)
         return Policy(scorer=q_scorer, sampler=top_k)
     return create_predictor_policy_from_model(
         self.build_serving_module(),
         max_num_actions=self.num_candidates,
         slate_size=self.slate_size,
     )
# Example #7
 def create_policy(
     self,
     trainer_module: ReAgentLightningModule,
     serving: bool = False,
     normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
 ):
     """Create a PPO policy.

     Serving mode requires ``normalization_data_map`` and wraps the built
     serving module; training mode builds a policy from the trainer's scorer.
     """
     assert isinstance(trainer_module, PPOTrainer)
     if not serving:
         return self._create_policy(trainer_module.scorer)
     assert normalization_data_map is not None
     serving_module = self.build_serving_module(
         trainer_module, normalization_data_map
     )
     return create_predictor_policy_from_model(serving_module)
    def create_policy(self, serving: bool) -> Policy:
        """Create an online DiscreteDQN Policy from env.

        Uses the action normalization data to size the action space, then
        returns either a predictor policy (serving) or a softmax-sampled
        parametric-DQN policy (training).
        """
        # FIXME: this only works for one-hot encoded actions
        action_dim = get_num_output_features(
            self.action_normalization_data.dense_normalization_parameters
        )
        if serving:
            return create_predictor_policy_from_model(
                self.build_serving_module(), max_num_actions=action_dim
            )
        softmax = SoftmaxActionSampler(temperature=self.rl_parameters.temperature)
        q_scorer = parametric_dqn_scorer(
            max_num_actions=action_dim, q_network=self._q_network
        )
        return Policy(scorer=q_scorer, sampler=softmax)
    def create_policy(
        self,
        trainer_module: ReAgentLightningModule,
        serving: bool = False,
        normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
    ) -> Policy:
        """Create online actor critic policy.

        Training mode wraps the trainer's actor network directly; serving
        mode requires ``normalization_data_map`` and builds a predictor
        policy from the serving module.
        """
        if not serving:
            return ActorPolicyWrapper(trainer_module.actor_network)
        assert normalization_data_map
        serving_module = self.build_serving_module(
            trainer_module, normalization_data_map
        )
        return create_predictor_policy_from_model(serving_module)
# Example #10
def make_agent_from_model(
    env: Gym,
    model: ModelManager__Union,
    publisher: ModelPublisher__Union,
    module_name: str,
):
    """Load the latest published TorchScript module and wrap it in an Agent.

    The publisher must wrap a ``FileSystemPublisher``, and ``module_name``
    must be one of the model manager's serving module names.
    """
    publisher_manager = publisher.value
    assert isinstance(
        publisher_manager, FileSystemPublisher
    ), f"publishing manager is type {type(publisher_manager)}, not FileSystemPublisher"
    module_names = model.value.serving_module_names()
    assert module_name in module_names, f"{module_name} not in {module_names}"
    torchscript_path = publisher_manager.get_latest_published_model(
        model.value, module_name
    )
    jit_model = torch.jit.load(torchscript_path)
    serving_policy = create_predictor_policy_from_model(jit_model)
    return Agent.create_for_env_with_serving_policy(env, serving_policy)
 def create_policy(
     self,
     trainer_module: ReAgentLightningModule,
     serving: bool = False,
     normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
 ) -> Policy:
     """Create an online DiscreteDQN Policy from env.

     Training mode pairs the trainer's Q-network with greedy selection;
     serving mode requires ``normalization_data_map`` and builds a
     predictor policy carrying the RL parameters.
     """
     if not serving:
         greedy = GreedyActionSampler()
         # pyre-fixme[6]: Expected `ModelBase` for 1st param but got
         #  `Union[torch.Tensor, torch.nn.Module]`.
         q_scorer = discrete_dqn_scorer(trainer_module.q_network)
         return Policy(scorer=q_scorer, sampler=greedy)
     assert normalization_data_map
     return create_predictor_policy_from_model(
         self.build_serving_module(trainer_module, normalization_data_map),
         rl_parameters=self.rl_parameters,
     )
# Example #12
 def create_policy(
     self,
     trainer_module: ReAgentLightningModule,
     serving: bool = False,
     normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
 ):
     """Create a SlateQ policy: top-k over per-candidate Q-scores.

     Serving mode requires ``normalization_data_map`` and builds a predictor
     policy sized by the candidate count and slate size.
     """
     if not serving:
         q_scorer = slate_q_scorer(
             num_candidates=self.num_candidates,
             # pyre-fixme[6]: Expected `ModelBase` for 2nd param but got
             #  `Union[torch.Tensor, torch.nn.Module]`.
             q_network=trainer_module.q_network,
         )
         top_k = TopKSampler(k=self.slate_size)
         return Policy(scorer=q_scorer, sampler=top_k)
     assert normalization_data_map
     return create_predictor_policy_from_model(
         self.build_serving_module(trainer_module, normalization_data_map),
         # pyre-fixme[16]: `SlateQBase` has no attribute `num_candidates`.
         max_num_actions=self.num_candidates,
         # pyre-fixme[16]: `SlateQBase` has no attribute `slate_size`.
         slate_size=self.slate_size,
     )