def __call__(self, batch):
    # 1.0 for non-terminal transitions, 0.0 for terminal ones
    not_terminal = 1.0 - batch.terminal.float()
    action, next_action = one_hot_actions(
        self.num_actions, batch.action, batch.next_action, batch.terminal
    )
    # Run the optional trainer preprocessor; otherwise wrap the raw features directly
    if self.trainer_preprocessor is not None:
        state = self.trainer_preprocessor(batch.state)
        next_state = self.trainer_preprocessor(batch.next_state)
    else:
        state = rlt.FeatureData(float_features=batch.state)
        next_state = rlt.FeatureData(float_features=batch.next_state)

    return rlt.DiscreteDqnInput(
        state=state,
        action=action,
        next_state=next_state,
        next_action=next_action,
        possible_actions_mask=torch.ones_like(action).float(),
        possible_next_actions_mask=torch.ones_like(next_action).float(),
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            # convert the stored log-probability back to a probability
            action_probability=batch.log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
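The one_hot_actions helper called above is not shown in this snippet. Below is a minimal sketch of what such a helper could look like, assuming (as in Example #2 below) that the next action of a terminal transition is stored as the placeholder value num_actions and should map to an all-zero row; the real helper may differ.

import torch
import torch.nn.functional as F


def one_hot_actions(num_actions, action, next_action, terminal):
    # Hypothetical sketch only.
    # action / next_action: integer tensors of shape (batch, 1); next_action
    # may hold the placeholder value num_actions for terminal transitions.
    action = F.one_hot(action.squeeze(1).long(), num_actions).float()
    # Encode with one extra class, then drop the placeholder column so that
    # placeholder next actions become all-zero rows.
    next_action = F.one_hot(
        next_action.squeeze(1).long(), num_actions + 1
    )[:, :num_actions].float()
    # Also zero out next_action rows for terminal transitions.
    next_action = next_action * (1.0 - terminal.float())
    return action, next_action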
Example #2
def __call__(  # type: ignore
    self, batch: Dict[str, torch.Tensor]
) -> rlt.DiscreteDqnInput:
    batch = batch_to_device(batch, self.device)
    preprocessed_state = self.state_preprocessor(
        batch["state_features"], batch["state_features_presence"]
    )
    preprocessed_next_state = self.state_preprocessor(
        batch["next_state_features"], batch["next_state_features_presence"]
    )
    # not terminal iff at least one next action is possible
    not_terminal = batch["possible_next_actions_mask"].max(dim=1)[0].float()
    action = F.one_hot(batch["action"].to(torch.int64), self.num_actions)
    # next_action may hold the placeholder value self.num_actions when no next
    # action is available; the extra class is sliced off below
    next_action = F.one_hot(
        batch["next_action"].to(torch.int64), self.num_actions + 1
    )[:, : self.num_actions]
    return rlt.DiscreteDqnInput(
        state=rlt.PreprocessedFeatureVector(preprocessed_state),
        next_state=rlt.PreprocessedFeatureVector(preprocessed_next_state),
        action=action,
        next_action=next_action,
        reward=batch["reward"].unsqueeze(1),
        time_diff=batch["time_diff"].unsqueeze(1),
        step=batch["step"].unsqueeze(1),
        not_terminal=not_terminal.unsqueeze(1),
        possible_actions_mask=batch["possible_actions_mask"],
        possible_next_actions_mask=batch["possible_next_actions_mask"],
        extras=rlt.ExtraData(
            mdp_id=batch["mdp_id"].unsqueeze(1).cpu().numpy(),
            sequence_number=batch["sequence_number"].unsqueeze(1),
            action_probability=batch["action_probability"].unsqueeze(1),
        ),
    )
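The num_actions + 1 trick above relies on the convention that a missing next action is stored as the value num_actions: one-hot encoding with one extra class and then slicing that class off yields an all-zero row for those transitions. A small standalone illustration with toy values (not taken from the library):

import torch
import torch.nn.functional as F

num_actions = 3
# The second transition has no next action; it is stored as num_actions (= 3).
next_action = torch.tensor([1, 3])
encoded = F.one_hot(next_action, num_actions + 1)[:, :num_actions]
print(encoded)
# tensor([[0, 1, 0],
#         [0, 0, 0]])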
Example #3
def __call__(self, batch):
    not_terminal = 1.0 - batch.terminal.float()
    action = F.one_hot(batch.action, self.num_actions).squeeze(1).float()
    # next action is garbage for terminal transitions (so just zero them)
    next_action = torch.zeros_like(action)
    non_terminal_indices = (batch.terminal == 0).squeeze(1)
    next_action[non_terminal_indices] = (
        F.one_hot(batch.next_action[non_terminal_indices], self.num_actions)
        .squeeze(1)
        .float()
    )
    return rlt.DiscreteDqnInput(
        state=rlt.FeatureData(float_features=batch.state),
        action=action,
        next_state=rlt.FeatureData(float_features=batch.next_state),
        next_action=next_action,
        possible_actions_mask=torch.ones_like(action).float(),
        possible_next_actions_mask=torch.ones_like(next_action).float(),
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            action_probability=batch.log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
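Example #3 instead fills next_action only at the non-terminal indices and leaves terminal rows as zeros. A small self-contained check of that masking pattern, using toy tensors shaped like the ones assumed above:

import torch
import torch.nn.functional as F

num_actions = 3
action = torch.tensor([[0], [2]])       # shape (batch, 1)
next_action = torch.tensor([[1], [0]])  # second row is garbage (terminal)
terminal = torch.tensor([[0], [1]])

one_hot_action = F.one_hot(action, num_actions).squeeze(1).float()
next_one_hot = torch.zeros_like(one_hot_action)
non_terminal_indices = (terminal == 0).squeeze(1)
next_one_hot[non_terminal_indices] = (
    F.one_hot(next_action[non_terminal_indices], num_actions).squeeze(1).float()
)
print(next_one_hot)
# tensor([[0., 1., 0.],
#         [0., 0., 0.]])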
Example #4
def as_discrete_maxq_training_batch(self):
    return rlt.DiscreteDqnInput(
        state=rlt.FeatureData(float_features=self.states),
        action=self.actions,
        next_state=rlt.FeatureData(float_features=self.next_states),
        next_action=self.next_actions,
        possible_actions_mask=self.possible_actions_mask,
        possible_next_actions_mask=self.possible_next_actions_mask,
        reward=self.rewards,
        not_terminal=self.not_terminal,
        step=self.step,
        time_diff=self.time_diffs,
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )