def create_policy(
    self,
    trainer_module: ReAgentLightningModule,
    serving: bool = False,
    normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
):
    """Create an online DiscreteDQN Policy from env.

    When ``serving`` is true, wraps the TorchScript serving module in a
    predictor policy; otherwise builds a softmax-sampled policy directly
    over the trainer's Q-network.
    """
    # FIXME: this only works for one-hot encoded actions
    # Infer the action dimension from the Q-network's action prototype.
    # pyre-fixme[16]: `Tensor` has no attribute `input_prototype`.
    action_prototype = trainer_module.q_network.input_prototype()[1]
    action_dim = action_prototype.float_features.shape[1]
    if not serving:
        # pyre-fixme[16]: `ParametricDQNBase` has no attribute `rl_parameters`.
        sampler = SoftmaxActionSampler(temperature=self.rl_parameters.temperature)
        scorer = parametric_dqn_scorer(
            max_num_actions=action_dim,
            # pyre-fixme[6]: Expected `ModelBase` for 2nd param but got
            # `Union[torch.Tensor, torch.nn.Module]`.
            q_network=trainer_module.q_network,
        )
        return Policy(scorer=scorer, sampler=sampler)
    assert normalization_data_map
    serving_module = self.build_serving_module(trainer_module, normalization_data_map)
    return create_predictor_policy_from_model(
        serving_module, max_num_actions=action_dim
    )
def evaluate_gym(
    env_name: str,
    model: ModelManager__Union,
    publisher: ModelPublisher__Union,
    num_eval_episodes: int,
    passing_score_bar: float,
    max_steps: Optional[int] = None,
):
    """Load the latest published TorchScript model and score it on a Gym env.

    Runs ``num_eval_episodes`` evaluation episodes and asserts that the
    average reward clears ``passing_score_bar``.
    """
    publisher_manager = publisher.value
    assert isinstance(
        publisher_manager, FileSystemPublisher
    ), f"publishing manager is type {type(publisher_manager)}, not FileSystemPublisher"
    # Fetch the most recently published serving module and wrap it in a policy.
    torchscript_path = publisher_manager.get_latest_published_model(model.value)
    jit_model = torch.jit.load(torchscript_path)
    policy = create_predictor_policy_from_model(jit_model)
    env = Gym(env_name=env_name)
    agent = Agent.create_for_env_with_serving_policy(env, policy)
    rewards = evaluate_for_n_episodes(
        n=num_eval_episodes, env=env, agent=agent, max_steps=max_steps
    )
    avg_reward = np.mean(rewards)
    logger.info(
        f"Average reward over {num_eval_episodes} is {avg_reward}.\n"
        f"List of rewards: {rewards}"
    )
    assert (
        avg_reward >= passing_score_bar
    ), f"{avg_reward} fails to pass the bar of {passing_score_bar}!"
def create_policy(self, serving: bool) -> Policy:
    """Create online actor critic policy.

    Serving mode wraps the TorchScript actor module; training mode wraps
    the in-memory actor network directly.
    """
    if not serving:
        return ActorPolicyWrapper(self._actor_network)
    return create_predictor_policy_from_model(self.build_actor_module())
def create_policy(self, serving: bool = False):
    """Return a policy for this model.

    Serving mode builds a fresh predictor policy each call; training mode
    lazily constructs a softmax-sampled policy once and caches it on the
    instance.
    """
    if serving:
        return create_predictor_policy_from_model(self.build_serving_module())
    if self._policy is None:
        sampler = SoftmaxActionSampler(temperature=self.sampler_temperature)
        # pyre-ignore
        self._policy = Policy(scorer=self._policy_network, sampler=sampler)
    return self._policy
def create_policy(self, serving: bool) -> Policy:
    """Create an online DiscreteDQN Policy from env.

    Serving mode wraps the serving module together with the RL parameters;
    training mode greedily samples from the trainer's Q-network scores.
    """
    if not serving:
        # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
        scorer = discrete_dqn_scorer(self.trainer.q_network)
        return Policy(scorer=scorer, sampler=GreedyActionSampler())
    return create_predictor_policy_from_model(
        self.build_serving_module(), rl_parameters=self.rl_parameters
    )
def create_policy(self, serving: bool) -> Policy:
    """Build a top-k slate-Q policy.

    Serving mode wraps the TorchScript serving module; training mode scores
    candidates with the in-memory Q-network and samples the top-k slate.
    """
    if not serving:
        slate_scorer = slate_q_scorer(
            num_candidates=self.num_candidates, q_network=self._q_network
        )
        return Policy(scorer=slate_scorer, sampler=TopKSampler(k=self.slate_size))
    return create_predictor_policy_from_model(
        self.build_serving_module(),
        max_num_actions=self.num_candidates,
        slate_size=self.slate_size,
    )
def create_policy(
    self,
    trainer_module: ReAgentLightningModule,
    serving: bool = False,
    normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
):
    """Return a PPO policy.

    Serving mode requires a normalization data map and wraps the serving
    module; training mode builds the policy from the trainer's scorer.
    """
    assert isinstance(trainer_module, PPOTrainer)
    if not serving:
        return self._create_policy(trainer_module.scorer)
    assert normalization_data_map is not None
    serving_module = self.build_serving_module(trainer_module, normalization_data_map)
    return create_predictor_policy_from_model(serving_module)
def create_policy(self, serving: bool) -> Policy:
    """Create an online DiscreteDQN Policy from env.

    The action dimension is derived from the action normalization data.
    """
    # FIXME: this only works for one-hot encoded actions
    action_dim = get_num_output_features(
        self.action_normalization_data.dense_normalization_parameters
    )
    if not serving:
        softmax = SoftmaxActionSampler(temperature=self.rl_parameters.temperature)
        scorer = parametric_dqn_scorer(
            max_num_actions=action_dim, q_network=self._q_network
        )
        return Policy(scorer=scorer, sampler=softmax)
    return create_predictor_policy_from_model(
        self.build_serving_module(), max_num_actions=action_dim
    )
def create_policy(
    self,
    trainer_module: ReAgentLightningModule,
    serving: bool = False,
    normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
) -> Policy:
    """Create online actor critic policy.

    Serving mode requires normalization data and wraps the serving module;
    training mode wraps the trainer's actor network directly.
    """
    if not serving:
        return ActorPolicyWrapper(trainer_module.actor_network)
    assert normalization_data_map
    serving_module = self.build_serving_module(trainer_module, normalization_data_map)
    return create_predictor_policy_from_model(serving_module)
def make_agent_from_model(
    env: Gym,
    model: ModelManager__Union,
    publisher: ModelPublisher__Union,
    module_name: str,
):
    """Build an Agent around the latest published TorchScript serving module.

    ``module_name`` must be one of the serving modules declared by the
    model manager.
    """
    publisher_manager = publisher.value
    assert isinstance(
        publisher_manager, FileSystemPublisher
    ), f"publishing manager is type {type(publisher_manager)}, not FileSystemPublisher"
    module_names = model.value.serving_module_names()
    assert module_name in module_names, f"{module_name} not in {module_names}"
    # Load the published TorchScript artifact and wrap it in a serving policy.
    torchscript_path = publisher_manager.get_latest_published_model(
        model.value, module_name
    )
    jit_model = torch.jit.load(torchscript_path)
    policy = create_predictor_policy_from_model(jit_model)
    return Agent.create_for_env_with_serving_policy(env, policy)
def create_policy(
    self,
    trainer_module: ReAgentLightningModule,
    serving: bool = False,
    normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
) -> Policy:
    """Create an online DiscreteDQN Policy from env.

    Serving mode requires normalization data and wraps the serving module
    with the RL parameters; training mode greedily samples from the
    trainer's Q-network scores.
    """
    if not serving:
        # pyre-fixme[6]: Expected `ModelBase` for 1st param but got
        # `Union[torch.Tensor, torch.nn.Module]`.
        scorer = discrete_dqn_scorer(trainer_module.q_network)
        return Policy(scorer=scorer, sampler=GreedyActionSampler())
    assert normalization_data_map
    return create_predictor_policy_from_model(
        self.build_serving_module(trainer_module, normalization_data_map),
        rl_parameters=self.rl_parameters,
    )
def create_policy(
    self,
    trainer_module: ReAgentLightningModule,
    serving: bool = False,
    normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
):
    """Build a top-k slate-Q policy from the trainer's Q-network.

    Serving mode requires normalization data and wraps the serving module;
    training mode scores candidates in memory and samples the top-k slate.
    """
    if not serving:
        slate_scorer = slate_q_scorer(
            num_candidates=self.num_candidates,
            # pyre-fixme[6]: Expected `ModelBase` for 2nd param but got
            # `Union[torch.Tensor, torch.nn.Module]`.
            q_network=trainer_module.q_network,
        )
        return Policy(scorer=slate_scorer, sampler=TopKSampler(k=self.slate_size))
    assert normalization_data_map
    return create_predictor_policy_from_model(
        self.build_serving_module(trainer_module, normalization_data_map),
        # pyre-fixme[16]: `SlateQBase` has no attribute `num_candidates`.
        max_num_actions=self.num_candidates,
        # pyre-fixme[16]: `SlateQBase` has no attribute `slate_size`.
        slate_size=self.slate_size,
    )