def as_parametric_maxq_training_batch(self):
    state_dim = self.states.shape[1]
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedParametricDqnInput(
            state=rlt.PreprocessedFeatureVector(float_features=self.states),
            action=rlt.PreprocessedFeatureVector(float_features=self.actions),
            next_state=rlt.PreprocessedFeatureVector(
                float_features=self.next_states
            ),
            next_action=rlt.PreprocessedFeatureVector(
                float_features=self.next_actions
            ),
            tiled_next_state=rlt.PreprocessedFeatureVector(
                float_features=self.possible_next_actions_state_concat[
                    :, :state_dim
                ]
            ),
            possible_actions=None,
            possible_actions_mask=self.possible_actions_mask,
            possible_next_actions=rlt.PreprocessedFeatureVector(
                float_features=self.possible_next_actions_state_concat[
                    :, state_dim:
                ]
            ),
            possible_next_actions_mask=self.possible_next_actions_mask,
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(),
    )
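# The slicing in as_parametric_maxq_training_batch assumes that
# possible_next_actions_state_concat holds one row per (next_state, candidate
# action) pair, with the tiled next state in the first state_dim columns and
# the candidate action's features after it. A minimal, standalone sketch of
# that assumed layout; every name below is illustrative and not part of the
# original class:
import torch

batch_size, state_dim, num_actions = 4, 3, 2
next_states = torch.randn(batch_size, state_dim)
# Repeat each next state once per candidate action, then append one-hot actions.
tiled_next_states = torch.repeat_interleave(next_states, num_actions, dim=0)
candidate_actions = torch.eye(num_actions).repeat(batch_size, 1)
possible_next_actions_state_concat = torch.cat(
    [tiled_next_states, candidate_actions], dim=1
)
# Columns [:state_dim] recover the tiled next states; columns [state_dim:] the actions.
assert possible_next_actions_state_concat.shape == (
    batch_size * num_actions,
    state_dim + num_actions,
)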
def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
    # Unpack the raw replay-buffer sample. next_reward, idxs, and log_prob are
    # not used below but are kept so the tuple unpacks cleanly.
    (
        obs,
        action,
        reward,
        next_obs,
        next_action,
        next_reward,
        terminal,
        idxs,
        possible_actions_mask,
        log_prob,
    ) = train_batch
    batch_size = obs.shape[0]

    obs = torch.tensor(obs).squeeze(2)
    action = torch.tensor(action).float()
    next_obs = torch.tensor(next_obs).squeeze(2)
    next_action = torch.tensor(next_action).to(torch.float32)
    reward = torch.tensor(reward).unsqueeze(1)
    not_terminal = 1 - torch.tensor(terminal).unsqueeze(1).to(torch.uint8)
    possible_actions_mask = torch.ones_like(action).to(torch.bool)

    # Tile each next state once per action so every (next_state, action) pair
    # can be scored by the parametric Q-network. num_actions is assumed to be
    # defined at module scope (the size of the discrete action space).
    tiled_next_state = torch.repeat_interleave(
        next_obs, repeats=num_actions, dim=0
    )
    possible_next_actions = torch.eye(num_actions).repeat(batch_size, 1)
    possible_next_actions_mask = not_terminal.repeat(1, num_actions).to(torch.bool)

    return rlt.PreprocessedTrainingBatch(
        rlt.PreprocessedParametricDqnInput(
            state=rlt.PreprocessedFeatureVector(float_features=obs),
            action=rlt.PreprocessedFeatureVector(float_features=action),
            next_state=rlt.PreprocessedFeatureVector(float_features=next_obs),
            next_action=rlt.PreprocessedFeatureVector(float_features=next_action),
            possible_actions=None,
            possible_actions_mask=possible_actions_mask,
            possible_next_actions=rlt.PreprocessedFeatureVector(
                float_features=possible_next_actions
            ),
            possible_next_actions_mask=possible_next_actions_mask,
            tiled_next_state=rlt.PreprocessedFeatureVector(
                float_features=tiled_next_state
            ),
            reward=reward,
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
        ),
        extras=rlt.ExtraData(),
    )
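# A minimal usage sketch for preprocess_batch, assuming it and num_actions
# live in the same module and that the replay buffer yields the ten-field
# tuple unpacked above. The shapes (in particular the trailing singleton
# dimension on obs/next_obs that .squeeze(2) removes) are illustrative
# assumptions, not part of the original code.
import numpy as np

num_actions = 2  # assumed module-level constant used inside preprocess_batch
batch_size, state_dim = 4, 8

obs = np.random.randn(batch_size, state_dim, 1).astype(np.float32)
next_obs = np.random.randn(batch_size, state_dim, 1).astype(np.float32)
action = np.eye(num_actions)[np.random.randint(num_actions, size=batch_size)]
next_action = np.eye(num_actions)[np.random.randint(num_actions, size=batch_size)]
reward = np.random.randn(batch_size).astype(np.float32)
next_reward = np.zeros(batch_size, dtype=np.float32)
terminal = np.zeros(batch_size, dtype=np.float32)
idxs = np.arange(batch_size)
possible_actions_mask = np.ones((batch_size, num_actions), dtype=np.float32)
log_prob = np.zeros(batch_size, dtype=np.float32)

train_batch = (
    obs, action, reward, next_obs, next_action,
    next_reward, terminal, idxs, possible_actions_mask, log_prob,
)
preprocessed = preprocess_batch(train_batch)  # -> rlt.PreprocessedTrainingBatch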