def __call__(self, batch):
    # not_terminal is 1.0 for ongoing transitions, 0.0 for terminal ones
    not_terminal = 1.0 - batch.terminal.float()
    action, next_action = one_hot_actions(
        self.num_actions, batch.action, batch.next_action, batch.terminal
    )
    # Preprocess states if a preprocessor is configured; otherwise wrap
    # the raw tensors directly.
    if self.trainer_preprocessor is not None:
        state = self.trainer_preprocessor(batch.state)
        next_state = self.trainer_preprocessor(batch.next_state)
    else:
        state = rlt.FeatureData(float_features=batch.state)
        next_state = rlt.FeatureData(float_features=batch.next_state)
    return rlt.DiscreteDqnInput(
        state=state,
        action=action,
        next_state=next_state,
        next_action=next_action,
        # All actions are treated as available at every step.
        possible_actions_mask=torch.ones_like(action).float(),
        possible_next_actions_mask=torch.ones_like(next_action).float(),
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            # Stored log-probabilities are converted back to propensities.
            action_probability=batch.log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
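
# The one_hot_actions helper above is not shown in this file. A minimal
# sketch of what it plausibly does, assuming (batch, 1) int action tensors
# and that next_action carries the sentinel value num_actions on terminal
# transitions (the signature matches the call site; the body is inferred,
# not taken from the source):
import torch
import torch.nn.functional as F

def one_hot_actions(num_actions, action, next_action, terminal):
    # One-hot encode the current action.
    action_oh = F.one_hot(action.squeeze(1).long(), num_actions).float()
    # Encode next_action with one extra class for the sentinel, then drop
    # that column so terminal transitions become all-zero rows.
    next_action_oh = F.one_hot(
        next_action.squeeze(1).long(), num_actions + 1
    )[:, :num_actions].float()
    return action_oh, next_action_oh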
def __call__(  # type: ignore
    self, batch: Dict[str, torch.Tensor]
) -> rlt.DiscreteDqnInput:
    batch = batch_to_device(batch, self.device)
    preprocessed_state = self.state_preprocessor(
        batch["state_features"], batch["state_features_presence"]
    )
    preprocessed_next_state = self.state_preprocessor(
        batch["next_state_features"], batch["next_state_features_presence"]
    )
    # not terminal iff at least one next action is possible
    not_terminal = batch["possible_next_actions_mask"].max(dim=1)[0].float()
    action = F.one_hot(batch["action"].to(torch.int64), self.num_actions)
    # next action can carry the sentinel value self.num_actions if not available
    next_action = F.one_hot(
        batch["next_action"].to(torch.int64), self.num_actions + 1
    )[:, : self.num_actions]
    return rlt.DiscreteDqnInput(
        state=rlt.PreprocessedFeatureVector(preprocessed_state),
        next_state=rlt.PreprocessedFeatureVector(preprocessed_next_state),
        action=action,
        next_action=next_action,
        reward=batch["reward"].unsqueeze(1),
        time_diff=batch["time_diff"].unsqueeze(1),
        step=batch["step"].unsqueeze(1),
        not_terminal=not_terminal.unsqueeze(1),
        possible_actions_mask=batch["possible_actions_mask"],
        possible_next_actions_mask=batch["possible_next_actions_mask"],
        extras=rlt.ExtraData(
            mdp_id=batch["mdp_id"].unsqueeze(1).cpu().numpy(),
            sequence_number=batch["sequence_number"].unsqueeze(1),
            action_probability=batch["action_probability"].unsqueeze(1),
        ),
    )
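
# The num_actions + 1 trick above works because the upstream data encodes a
# missing next action as index self.num_actions; one-hot encoding with one
# extra class and slicing it off leaves an all-zero row for those
# transitions. A self-contained demonstration with made-up values:
import torch
import torch.nn.functional as F

num_actions = 3
# Third transition has no next action, encoded with the sentinel index 3.
next_action = torch.tensor([0, 2, 3])
print(F.one_hot(next_action, num_actions + 1)[:, :num_actions])
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 0, 0]])  <- sentinel row collapses to all zeros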
def __call__(self, batch):
    not_terminal = 1.0 - batch.terminal.float()
    action = F.one_hot(batch.action, self.num_actions).squeeze(1).float()
    # next action is garbage for terminal transitions (so just zero them)
    next_action = torch.zeros_like(action)
    non_terminal_indices = (batch.terminal == 0).squeeze(1)
    next_action[non_terminal_indices] = (
        F.one_hot(batch.next_action[non_terminal_indices], self.num_actions)
        .squeeze(1)
        .float()
    )
    return rlt.DiscreteDqnInput(
        state=rlt.FeatureData(float_features=batch.state),
        action=action,
        next_state=rlt.FeatureData(float_features=batch.next_state),
        next_action=next_action,
        possible_actions_mask=torch.ones_like(action).float(),
        possible_next_actions_mask=torch.ones_like(next_action).float(),
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            action_probability=batch.log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
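
# A toy run of the masking approach above, assuming batch.terminal and the
# action tensors are shaped (batch, 1): only non-terminal rows receive a
# real one-hot next action, terminal rows stay all zeros.
import torch
import torch.nn.functional as F

num_actions = 2
terminal = torch.tensor([[0], [1]])     # second transition is terminal
next_action = torch.tensor([[1], [0]])  # garbage on the terminal row

next_action_oh = torch.zeros(2, num_actions)
non_terminal = (terminal == 0).squeeze(1)
next_action_oh[non_terminal] = (
    F.one_hot(next_action[non_terminal], num_actions).squeeze(1).float()
)
print(next_action_oh)
# tensor([[0., 1.],
#         [0., 0.]])  <- terminal row left as zeros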
def as_discrete_maxq_training_batch(self):
    return rlt.DiscreteDqnInput(
        state=rlt.FeatureData(float_features=self.states),
        action=self.actions,
        next_state=rlt.FeatureData(float_features=self.next_states),
        next_action=self.next_actions,
        possible_actions_mask=self.possible_actions_mask,
        possible_next_actions_mask=self.possible_next_actions_mask,
        reward=self.rewards,
        not_terminal=self.not_terminal,
        step=self.step,
        time_diff=self.time_diffs,
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )
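
# Hypothetical usage sketch: the names training_batch and trainer are
# illustrative, not from the snippet above. The method simply repackages
# already-built tensors into the DiscreteDqnInput a max-Q DQN trainer
# consumes, with no encoding or preprocessing of its own.
dqn_input = training_batch.as_discrete_maxq_training_batch()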