示例#1
0
 def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
     """Convert a raw replay-buffer sample into a policy-network training batch.

     Args:
         train_batch: A 10-element sequence unpacked, in order, as
             ``(obs, action, reward, next_obs, next_action, next_reward,
             terminal, idxs, possible_actions_mask, log_prob)``. Each used
             element must be accepted by ``torch.tensor``. Only ``obs``,
             ``action``, ``reward``, ``next_obs``, ``next_action`` and
             ``terminal`` are consumed; the remaining fields are unpacked
             only so that a batch of the wrong length is rejected.

     Returns:
         An ``rlt.PreprocessedTrainingBatch`` wrapping an
         ``rlt.PreprocessedPolicyNetworkInput`` (``step`` and ``time_diff``
         set to ``None``) with an empty ``rlt.ExtraData``.

     Raises:
         ValueError: If ``train_batch`` does not unpack into exactly 10 items.
     """
     (
         obs,
         action,
         reward,
         next_obs,
         next_action,
         _next_reward,
         terminal,
         _idxs,
         _possible_actions_mask,
         _log_prob,
     ) = train_batch

     # NOTE(review): squeeze(2) assumes observations carry a singleton axis at
     # dim 2 (e.g. (batch, state_dim, 1)) — TODO confirm. It is a no-op when
     # that axis is not size 1 and raises if the tensor has fewer than 3 dims.
     obs = torch.tensor(obs).squeeze(2)
     next_obs = torch.tensor(next_obs).squeeze(2)

     # Both action tensors become float_features of the same input type, so
     # cast both to float32 to keep their dtypes consistent.
     action = torch.tensor(action).float()
     next_action = torch.tensor(next_action).float()

     # Presumably reward arrives as (batch,); unsqueeze(1) makes it a column.
     reward = torch.tensor(reward).unsqueeze(1)
     # Flip terminal flags (presumably 0/1 or bool) into a float "continue" mask.
     not_terminal = 1.0 - torch.tensor(terminal).unsqueeze(1).float()

     return rlt.PreprocessedTrainingBatch(
         training_input=rlt.PreprocessedPolicyNetworkInput(
             state=rlt.PreprocessedFeatureVector(float_features=obs),
             action=rlt.PreprocessedFeatureVector(float_features=action),
             next_state=rlt.PreprocessedFeatureVector(float_features=next_obs),
             next_action=rlt.PreprocessedFeatureVector(
                 float_features=next_action
             ),
             reward=reward,
             not_terminal=not_terminal,
             step=None,
             time_diff=None,
         ),
         extras=rlt.ExtraData(),
     )
示例#2
0
 def as_policy_network_training_batch(self):
     """Package this batch's tensors for policy-network training.

     The state/action tensors (``states``, ``actions``, ``next_states``,
     ``next_actions``) are each wrapped as the ``float_features`` of an
     ``rlt.PreprocessedFeatureVector``. ``rewards``, ``not_terminal``,
     ``step`` and ``time_diffs`` are passed through unchanged.

     Returns:
         An ``rlt.PreprocessedTrainingBatch`` whose ``training_input`` is an
         ``rlt.PreprocessedPolicyNetworkInput`` and whose ``extras`` is an
         empty ``rlt.ExtraData``.
     """

     def _as_features(tensor):
         # Every feature slot of the policy input uses the same wrapper.
         return rlt.PreprocessedFeatureVector(float_features=tensor)

     policy_input = rlt.PreprocessedPolicyNetworkInput(
         state=_as_features(self.states),
         action=_as_features(self.actions),
         next_state=_as_features(self.next_states),
         next_action=_as_features(self.next_actions),
         reward=self.rewards,
         not_terminal=self.not_terminal,
         step=self.step,
         time_diff=self.time_diffs,
     )
     return rlt.PreprocessedTrainingBatch(
         training_input=policy_input,
         extras=rlt.ExtraData(),
     )