Code Example #1
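This snippet appears to be a training-batch preprocessor from a ReAgent-style continuous-control setup: it rescales actions into the training range, zeroes next_action on terminal transitions, assembles a PolicyNetworkInput, and calls rlt._embed_states to fold candidate (document) features into the state when the batch provides them.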
def __call__(self, batch):
    not_terminal = 1.0 - batch.terminal.float()
    # rescale actions from the environment's range into the training range
    action = rescale_actions(
        batch.action,
        new_min=self.train_low,
        new_max=self.train_high,
        prev_min=self.action_low,
        prev_max=self.action_high,
    )
    # only rescale next_action on non-terminal transitions; terminal rows stay zero
    non_terminal_indices = (batch.terminal == 0).squeeze(1)
    next_action = torch.zeros_like(action)
    next_action[non_terminal_indices] = rescale_actions(
        batch.next_action[non_terminal_indices],
        new_min=self.train_low,
        new_max=self.train_high,
        prev_min=self.action_low,
        prev_max=self.action_high,
    )
    dict_batch = {
        InputColumn.STATE_FEATURES: batch.state,
        InputColumn.NEXT_STATE_FEATURES: batch.next_state,
        InputColumn.ACTION: action,
        InputColumn.NEXT_ACTION: next_action,
        InputColumn.REWARD: batch.reward,
        InputColumn.NOT_TERMINAL: not_terminal,
        InputColumn.STEP: None,
        InputColumn.TIME_DIFF: None,
        InputColumn.EXTRAS: rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            action_probability=batch.log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    }
    # slate batches carry candidate (document) features; batches without them
    # raise AttributeError here, so the columns are treated as optional
    has_candidate_features = False
    try:
        dict_batch.update(
            {
                InputColumn.CANDIDATE_FEATURES: batch.doc,
                InputColumn.NEXT_CANDIDATE_FEATURES: batch.next_doc,
            }
        )
        has_candidate_features = True
    except AttributeError:
        pass
    output = rlt.PolicyNetworkInput.from_dict(dict_batch)
    if has_candidate_features:
        # fold candidate features into the (next_)state representation
        output.state = rlt._embed_states(output.state)
        output.next_state = rlt._embed_states(output.next_state)
    return output
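For context, a rescale_actions helper with the signature used above is typically just an affine map from one interval to another. The following is a minimal sketch inferred from the call site, not ReAgent's actual implementation:

import torch

def rescale_actions(
    actions: torch.Tensor,
    new_min: torch.Tensor,
    new_max: torch.Tensor,
    prev_min: torch.Tensor,
    prev_max: torch.Tensor,
) -> torch.Tensor:
    # map values linearly from [prev_min, prev_max] to [new_min, new_max]
    scale = (new_max - new_min) / (prev_max - prev_min)
    return (actions - prev_min) * scale + new_min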
Code Example #2
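A serving-side forward pass for a slate/ranking predictor: state and candidate features are preprocessed, the candidates are stacked into a single tensor, and rlt._embed_states merges them into the model input before the action is computed and optionally postprocessed.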
def forward(
    self,
    state_with_presence: Tuple[torch.Tensor, torch.Tensor],
    candidate_with_presence_list: List[Tuple[torch.Tensor, torch.Tensor]],
):
    assert (
        len(candidate_with_presence_list) == self.num_candidates
    ), f"{len(candidate_with_presence_list)} != {self.num_candidates}"
    preprocessed_state = self.state_preprocessor(*state_with_presence)
    # each candidate is batch_size x candidate_dim; stacking on dim=1 yields
    # batch_size x num_candidates x candidate_dim
    preprocessed_candidates = torch.stack(
        [self.candidate_preprocessor(*x) for x in candidate_with_presence_list],
        dim=1,
    )
    input = rlt.FeatureData(
        float_features=preprocessed_state,
        candidate_docs=rlt.DocList(
            float_features=preprocessed_candidates,
            # placeholder mask and value tensors
            mask=torch.tensor(-1),
            value=torch.tensor(-1),
        ),
    )
    input = rlt._embed_states(input)
    action = self.model(input).action
    if self.action_postprocessor is not None:
        # pyre-fixme[29]: `Optional[Postprocessor]` is not a function.
        action = self.action_postprocessor(action)
    return action
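To make the calling convention concrete, here is a hypothetical invocation with made-up shapes (batch of 2, state dim 4, 3 candidates of dim 6). The name predictor_wrapper stands in for an instance of the class above; each input is a (values, presence-mask) pair:

import torch

batch_size, state_dim, num_candidates, candidate_dim = 2, 4, 3, 6
state_with_presence = (
    torch.randn(batch_size, state_dim),   # feature values
    torch.ones(batch_size, state_dim),    # presence mask (1 = present)
)
candidate_with_presence_list = [
    (
        torch.randn(batch_size, candidate_dim),
        torch.ones(batch_size, candidate_dim),
    )
    for _ in range(num_candidates)
]
action = predictor_wrapper(state_with_presence, candidate_with_presence_list)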
Code Example #3
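A Gym-environment hook that converts a raw RecSim observation into rlt.FeatureData and embeds its candidate docs via rlt._embed_states.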
def obs_preprocessor(self, obs: np.ndarray) -> rlt.FeatureData:
    # build a RecSim observation preprocessor for this env, convert the raw
    # observation to FeatureData, then embed candidate docs into the state
    preprocessor = RecsimObsPreprocessor.create_from_env(self)
    preprocessed_obs = preprocessor(obs)
    return rlt._embed_states(preprocessed_obs)
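Note that this builds a fresh RecsimObsPreprocessor on every call. If that construction is non-trivial, one could cache it per env instance; the mixin below is a hypothetical refactor (the class name is invented, and it assumes the preprocessor is safe to reuse across calls):

from functools import cached_property

import numpy as np

class CachedObsPreprocessorMixin:
    # hypothetical variant: construct the preprocessor once per env instance
    @cached_property
    def _recsim_preprocessor(self):
        return RecsimObsPreprocessor.create_from_env(self)

    def obs_preprocessor(self, obs: np.ndarray) -> rlt.FeatureData:
        return rlt._embed_states(self._recsim_preprocessor(obs))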