from typing import Any

import torch

# `rlt` is the RL types module these batch containers come from
# (e.g. ml.rl.types in Horizon / reagent types in ReAgent).


def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
    # Unpack the raw replay-buffer sample into its component arrays.
    (
        obs,
        action,
        reward,
        next_obs,
        next_action,
        next_reward,
        terminal,
        idxs,
        possible_actions_mask,
        log_prob,
    ) = train_batch

    # Convert to tensors with the shapes the trainer expects: observations
    # drop their trailing stacking dimension; reward and the terminal mask
    # gain a trailing column so they broadcast against per-sample outputs.
    obs = torch.tensor(obs).squeeze(2)
    action = torch.tensor(action).float()
    reward = torch.tensor(reward).unsqueeze(1)
    next_obs = torch.tensor(next_obs).squeeze(2)
    next_action = torch.tensor(next_action)
    not_terminal = 1.0 - torch.tensor(terminal).unsqueeze(1).float()
    # Converted for completeness, but not carried into the returned batch.
    idxs = torch.tensor(idxs)
    possible_actions_mask = torch.tensor(possible_actions_mask).float()
    log_prob = torch.tensor(log_prob)

    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedPolicyNetworkInput(
            state=rlt.PreprocessedFeatureVector(float_features=obs),
            action=rlt.PreprocessedFeatureVector(float_features=action),
            next_state=rlt.PreprocessedFeatureVector(float_features=next_obs),
            next_action=rlt.PreprocessedFeatureVector(
                float_features=next_action
            ),
            reward=reward,
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
        ),
        extras=rlt.ExtraData(),
    )
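# A minimal usage sketch with synthetic data, assuming the upstream buffer
# yields the ten arrays unpacked above, in that order, with observations
# stacked as (batch, obs_dim, 1). The shapes and the `obs_dim`/`action_dim`
# names are illustrative assumptions, not part of the surrounding code.
if __name__ == "__main__":
    import numpy as np

    batch_size, obs_dim, action_dim = 4, 3, 2
    synthetic_batch = (
        np.random.randn(batch_size, obs_dim, 1).astype(np.float32),  # obs
        np.random.randn(batch_size, action_dim).astype(np.float32),  # action
        np.random.randn(batch_size).astype(np.float32),              # reward
        np.random.randn(batch_size, obs_dim, 1).astype(np.float32),  # next_obs
        np.random.randn(batch_size, action_dim).astype(np.float32),  # next_action
        np.random.randn(batch_size).astype(np.float32),              # next_reward
        np.zeros(batch_size, dtype=np.float32),                      # terminal
        np.arange(batch_size),                                       # idxs
        np.ones((batch_size, action_dim), np.float32),  # possible_actions_mask
        np.zeros(batch_size, dtype=np.float32),                      # log_prob
    )
    preprocessed = preprocess_batch(synthetic_batch)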
def as_policy_network_training_batch(self):
    # Repackage this batch's fields into the PreprocessedTrainingBatch
    # layout consumed by policy-network trainers.
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedPolicyNetworkInput(
            state=rlt.PreprocessedFeatureVector(float_features=self.states),
            action=rlt.PreprocessedFeatureVector(float_features=self.actions),
            next_state=rlt.PreprocessedFeatureVector(
                float_features=self.next_states
            ),
            next_action=rlt.PreprocessedFeatureVector(
                float_features=self.next_actions
            ),
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(),
    )
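# Hedged usage sketch: assuming `page` is an instance of the container class
# this method belongs to (one exposing the fields referenced above), the
# conversion is a single call. `page` and `trainer` are hypothetical names.
#
#     batch = page.as_policy_network_training_batch()
#     loss = trainer.train(batch)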