def __call__(self, batch):
    """Build a ``rlt.ParametricDqnInput`` from a transition batch.

    The batch is read through its ``state``, ``next_state``, ``action``,
    ``next_action``, ``reward``, ``terminal`` and ``log_prob`` attributes.

    Args:
        batch: Transition batch whose ``state`` must be 2-dimensional,
            i.e. ``(batch_size, state_dim)``.

    Returns:
        A ``rlt.ParametricDqnInput`` in which both action masks are
        all ones, ``step`` and ``time_diff`` are ``None``, and
        ``extras.action_probability`` is ``exp(batch.log_prob)``.

    Raises:
        AssertionError: If ``batch.state`` is not 2-dimensional.
    """
    # Flip the terminal flag: 1.0 where the transition is NOT terminal.
    not_terminal = 1.0 - batch.terminal.float()

    assert len(batch.state.shape) == 2, (
        f"{batch.state.shape} is not (batch_size, state_dim).")
    batch_size, _state_dim = batch.state.shape
    num_actions = self.num_actions

    # presumably converts action indices to one-hot vectors of width
    # num_actions -- TODO confirm against one_hot_actions.
    one_hot_action, one_hot_next_action = one_hot_actions(
        num_actions, batch.action, batch.next_action, batch.terminal)

    # The same candidate set is used for the current and the next step;
    # clones keep the two tensors independent of each other.
    candidate_actions = get_possible_actions_for_gym(batch_size, num_actions)
    candidate_next_actions = candidate_actions.clone()

    # All-ones masks: no action is masked out for any row.
    # NOTE(review): torch.ones is called without a device argument, so the
    # masks live on torch's default device -- confirm that matches batch.
    mask_shape = (batch_size, num_actions)
    candidate_mask = torch.ones(mask_shape)
    candidate_next_mask = candidate_mask.clone()

    state = rlt.FeatureData(float_features=batch.state)
    action = rlt.FeatureData(float_features=one_hot_action)
    next_state = rlt.FeatureData(float_features=batch.next_state)
    next_action = rlt.FeatureData(float_features=one_hot_next_action)

    # Only the action probability is known here; it is recovered from the
    # stored log-probability.
    extras = rlt.ExtraData(
        mdp_id=None,
        sequence_number=None,
        action_probability=batch.log_prob.exp(),
        max_num_actions=None,
        metrics=None,
    )

    return rlt.ParametricDqnInput(
        state=state,
        action=action,
        next_state=next_state,
        next_action=next_action,
        possible_actions=candidate_actions,
        possible_actions_mask=candidate_mask,
        possible_next_actions=candidate_next_actions,
        possible_next_actions_mask=candidate_next_mask,
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=extras,
    )
Example #2
0
 def __call__(self, batch: Dict[str,
                                torch.Tensor]) -> rlt.ParametricDqnInput:
     """Turn a dictionary batch into a ``rlt.ParametricDqnInput``.

     The batch is first passed through ``batch_to_device`` with
     ``self.device``. State and action features (with their presence
     tensors) are then run through ``self.state_preprocessor`` and
     ``self.action_preprocessor``. The possible-action tensors and masks
     are forwarded unchanged, while the remaining per-row fields get
     ``unsqueeze(1)`` applied.

     Args:
         batch: Mapping from field name to tensor. The keys read here
             are the feature/presence pairs for ``state``, ``next_state``,
             ``action`` and ``next_action``, plus ``reward``,
             ``time_diff``, ``step``, ``not_terminal``, the four
             ``possible_*`` entries, ``mdp_id``, ``sequence_number`` and
             ``action_probability``.

     Returns:
         The assembled ``rlt.ParametricDqnInput``.
     """
     batch = batch_to_device(batch, self.device)

     def as_column(key):
         # Insert a new axis at position 1.
         # assumes the stored tensor is 1-D per row, giving (batch, 1)
         # -- TODO confirm against the data source.
         return batch[key].unsqueeze(1)

     # Run raw features and their presence tensors through the
     # preprocessors: states first, then actions.
     state_out = self.state_preprocessor(
         batch["state_features"], batch["state_features_presence"])
     next_state_out = self.state_preprocessor(
         batch["next_state_features"],
         batch["next_state_features_presence"])
     action_out = self.action_preprocessor(
         batch["action"], batch["action_presence"])
     next_action_out = self.action_preprocessor(
         batch["next_action"], batch["next_action_presence"])

     state = rlt.FeatureData(state_out)
     next_state = rlt.FeatureData(next_state_out)
     action = rlt.FeatureData(action_out)
     next_action = rlt.FeatureData(next_action_out)

     return rlt.ParametricDqnInput(
         state=state,
         next_state=next_state,
         action=action,
         next_action=next_action,
         reward=as_column("reward"),
         time_diff=as_column("time_diff"),
         step=as_column("step"),
         not_terminal=as_column("not_terminal"),
         possible_actions=batch["possible_actions"],
         possible_actions_mask=batch["possible_actions_mask"],
         possible_next_actions=batch["possible_next_actions"],
         possible_next_actions_mask=batch["possible_next_actions_mask"],
         extras=rlt.ExtraData(
             mdp_id=as_column("mdp_id"),
             sequence_number=as_column("sequence_number"),
             action_probability=as_column("action_probability"),
         ),
     )