from typing import Any

import torch

# `rlt` (the RL types module) and `num_actions` are assumed to be defined
# elsewhere in this module.


def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
    # Unpack the raw replay-buffer sample into its individual fields.
    (
        obs,
        action,
        reward,
        next_obs,
        next_action,
        next_reward,
        terminal,
        idxs,
        possible_actions_mask,
        log_prob,
    ) = train_batch

    # Convert to tensors and normalize shapes: observations drop the trailing
    # singleton dimension, rewards and terminal flags become column vectors.
    obs = torch.tensor(obs).squeeze(2)
    action = torch.tensor(action)
    reward = torch.tensor(reward).unsqueeze(1)
    next_obs = torch.tensor(next_obs).squeeze(2)
    next_action = torch.tensor(next_action)
    not_terminal = 1.0 - torch.tensor(terminal).unsqueeze(1).float()
    possible_actions_mask = torch.tensor(possible_actions_mask)
    # Every action is marked possible in non-terminal next states; terminal
    # next states get an all-zero mask.
    next_possible_actions_mask = not_terminal.repeat(1, num_actions)
    log_prob = torch.tensor(log_prob)

    assert (
        action.size(1) == num_actions
    ), f"action size(1) is {action.size(1)} while num_actions is {num_actions}"

    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedDiscreteDqnInput(
            state=rlt.PreprocessedFeatureVector(float_features=obs),
            action=action,
            next_state=rlt.PreprocessedFeatureVector(float_features=next_obs),
            next_action=next_action,
            possible_actions_mask=possible_actions_mask,
            possible_next_actions_mask=next_possible_actions_mask,
            reward=reward,
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
        ),
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            # The logged propensity is stored as a probability, not a log-prob.
            action_probability=log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
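# --- Illustrative usage (not part of the original module) -------------------
# A minimal sketch, assuming `num_actions` is a module-level constant and that
# the sampler yields the ten fields unpacked in `preprocess_batch` above. The
# shapes and the helper name `_example_preprocess_call` are hypothetical.
def _example_preprocess_call(batch_size: int = 4, obs_dim: int = 8):
    import numpy as np

    one_hot = np.eye(num_actions)[np.zeros(batch_size, dtype=int)]
    dummy_batch = (
        np.random.randn(batch_size, obs_dim, 1),  # obs; dim 2 is squeezed away
        one_hot,                                  # action (one-hot)
        np.random.randn(batch_size),              # reward
        np.random.randn(batch_size, obs_dim, 1),  # next_obs
        one_hot,                                  # next_action
        np.random.randn(batch_size),              # next_reward (unused above)
        np.zeros(batch_size),                     # terminal flags
        np.arange(batch_size),                    # replay-buffer indices (unused)
        np.ones((batch_size, num_actions)),       # possible_actions_mask
        np.zeros(batch_size),                     # log_prob of logged actions
    )
    return preprocess_batch(dummy_batch)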
def as_discrete_maxq_training_batch(self):
    # Method of the enclosing training-data container class (not shown here):
    # repackages the page as a preprocessed discrete-action DQN batch,
    # forwarding the already-tensorized fields unchanged.
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedDiscreteDqnInput(
            state=rlt.PreprocessedFeatureVector(float_features=self.states),
            action=self.actions,
            next_state=rlt.PreprocessedFeatureVector(
                float_features=self.next_states
            ),
            next_action=self.next_actions,
            possible_actions_mask=self.possible_actions_mask,
            possible_next_actions_mask=self.possible_next_actions_mask,
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )
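# Illustrative call pattern (hypothetical; the enclosing class and the trainer
# are not shown in this snippet):
#
#     page = some_training_data_page  # instance of the enclosing class
#     batch = page.as_discrete_maxq_training_batch()
#     # `batch.training_input` now carries the state/action tensors for a
#     # max-Q (DQN-style) update, and `batch.extras` carries logging metadata.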