示例#1
0
 def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
     """Convert a raw 10-field ``train_batch`` into a ``PreprocessedTrainingBatch``.

     ``train_batch`` must unpack into exactly ten items, in this order:
     obs, action, reward, next_obs, next_action, next_reward, terminal,
     idxs, possible_actions_mask, log_prob. ``next_reward`` and ``idxs`` are
     unpacked but never used.

     Each used field is turned into a tensor with ``torch.tensor``. The
     observations additionally have dimension 2 squeezed, and reward/terminal
     get a new dimension inserted at index 1.
     NOTE(review): this presumably means obs arrives as (batch, dim, 1) and
     reward/terminal as (batch,) -- confirm against the caller.

     ``num_actions`` is not a parameter; it is read from the enclosing scope.

     Returns:
         A ``rlt.PreprocessedTrainingBatch`` whose ``training_input`` is a
         ``PreprocessedDiscreteDqnInput`` and whose ``extras`` carries
         ``exp(log_prob)`` as ``action_probability``.

     Raises:
         AssertionError: if ``action.size(1)`` differs from ``num_actions``.
     """
     (
         raw_obs,
         raw_action,
         raw_reward,
         raw_next_obs,
         raw_next_action,
         _next_reward,  # present in the tuple but not consumed here
         raw_terminal,
         _idxs,  # present in the tuple but not consumed here
         raw_actions_mask,
         raw_log_prob,
     ) = train_batch

     state_t = torch.tensor(raw_obs).squeeze(2)
     action_t = torch.tensor(raw_action)
     reward_t = torch.tensor(raw_reward).unsqueeze(1)
     next_state_t = torch.tensor(raw_next_obs).squeeze(2)
     next_action_t = torch.tensor(raw_next_action)

     # 1.0 where the transition is not terminal, 0.0 where it is.
     terminal_col = torch.tensor(raw_terminal).unsqueeze(1).float()
     not_terminal_t = 1.0 - terminal_col

     actions_mask_t = torch.tensor(raw_actions_mask)
     # Tile the not-terminal column num_actions times along dim 1: every
     # action is marked possible after a non-terminal step, none after a
     # terminal one.
     next_actions_mask_t = not_terminal_t.repeat(1, num_actions)
     log_prob_t = torch.tensor(raw_log_prob)

     action_width = action_t.size(1)
     assert (
         action_width == num_actions
     ), f"action size(1) is {action_width} while num_actions is {num_actions}"

     dqn_input = rlt.PreprocessedDiscreteDqnInput(
         state=rlt.PreprocessedFeatureVector(float_features=state_t),
         action=action_t,
         next_state=rlt.PreprocessedFeatureVector(float_features=next_state_t),
         next_action=next_action_t,
         possible_actions_mask=actions_mask_t,
         possible_next_actions_mask=next_actions_mask_t,
         reward=reward_t,
         not_terminal=not_terminal_t,
         step=None,
         time_diff=None,
     )
     extra_data = rlt.ExtraData(
         mdp_id=None,
         sequence_number=None,
         # The batch stores log-probabilities; exponentiate to get probabilities.
         action_probability=log_prob_t.exp(),
         max_num_actions=None,
         metrics=None,
     )
     return rlt.PreprocessedTrainingBatch(training_input=dqn_input, extras=extra_data)
示例#2
0
 def as_discrete_maxq_training_batch(self):
     """Repackage this object's fields as a ``rlt.PreprocessedTrainingBatch``.

     The returned batch has a ``PreprocessedDiscreteDqnInput`` as its
     ``training_input`` and an ``ExtraData`` as its ``extras``. Every field
     is taken directly from an attribute of ``self`` without transformation;
     ``states`` and ``next_states`` are wrapped in
     ``PreprocessedFeatureVector`` as ``float_features``.
     """
     current_state = rlt.PreprocessedFeatureVector(float_features=self.states)
     following_state = rlt.PreprocessedFeatureVector(
         float_features=self.next_states
     )

     dqn_input = rlt.PreprocessedDiscreteDqnInput(
         state=current_state,
         action=self.actions,
         next_state=following_state,
         next_action=self.next_actions,
         possible_actions_mask=self.possible_actions_mask,
         possible_next_actions_mask=self.possible_next_actions_mask,
         reward=self.rewards,
         not_terminal=self.not_terminal,
         step=self.step,
         time_diff=self.time_diffs,
     )

     extra_data = rlt.ExtraData(
         mdp_id=self.mdp_ids,
         sequence_number=self.sequence_numbers,
         # self.propensities is passed through as the action probability.
         action_probability=self.propensities,
         max_num_actions=self.max_num_actions,
         metrics=self.metrics,
     )

     return rlt.PreprocessedTrainingBatch(training_input=dqn_input, extras=extra_data)