Code example #1
 def random_batch(self, batch_size):
     indices = self._sample_indices(batch_size)
     next_obs_idxs = []
     for i in indices:
         possible_next_obs_idxs = self._idx_to_future_obs_idx[i]
         # This is generally faster than random.choice. Makes you wonder what
         # random.choice is doing
         num_options = len(possible_next_obs_idxs)
         if num_options == 1:
             next_obs_i = 0
         else:
             if self.resampling_strategy == 'uniform':
                 next_obs_i = int(np.random.randint(0, num_options))
             elif self.resampling_strategy == 'truncated_geometric':
                 next_obs_i = int(truncated_geometric(
                     p=self.truncated_geom_factor/num_options,
                     truncate_threshold=num_options-1,
                     size=1,
                     new_value=0
                 ))
             else:
                 raise ValueError("Invalid resampling strategy: {}".format(
                     self.resampling_strategy
                 ))
         next_obs_idxs.append(possible_next_obs_idxs[next_obs_i])
     next_obs_idxs = np.array(next_obs_idxs)
     resampled_goals = self.env.convert_obs_to_goals(
         self._next_obs[next_obs_idxs]
     )
     num_goals_are_from_rollout = int(
         batch_size * self.fraction_goals_are_rollout_goals
     )
     if num_goals_are_from_rollout > 0:
         resampled_goals[:num_goals_are_from_rollout] = self._goals[
             indices[:num_goals_are_from_rollout]
         ]
     return dict(
         observations=self._observations[indices],
         actions=self._actions[indices],
         rewards=self._rewards[indices],
         terminals=self._terminals[indices],
         next_observations=self._next_obs[indices],
         goals_used_for_rollout=self._goals[indices],
         resampled_goals=resampled_goals,
         num_steps_left=self._num_steps_left[indices],
         indices=np.array(indices).reshape(-1, 1),
     )
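
Both resampling strategies above end by picking one entry of possible_next_obs_idxs per sampled transition. The standalone sketch below (the index array and truncated_geom_factor value are hypothetical) contrasts the two choices: assuming _idx_to_future_obs_idx stores future observations in chronological order, the truncated-geometric strategy biases the pick toward temporally closer observations, while the uniform strategy treats all of them equally.

import numpy as np
from railrl.misc.np_util import truncated_geometric

possible_next_obs_idxs = np.arange(100, 140)  # hypothetical future-observation indices
num_options = len(possible_next_obs_idxs)
truncated_geom_factor = 1.0

# 'uniform': every future observation is equally likely to become the goal
uniform_i = int(np.random.randint(0, num_options))

# 'truncated_geometric': small offsets (near-future observations) are favored
geom_i = int(truncated_geometric(
    p=truncated_geom_factor / num_options,
    truncate_threshold=num_options - 1,
    size=1,
    new_value=0,
))

print(possible_next_obs_idxs[uniform_i], possible_next_obs_idxs[geom_i])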
Code example #2
File: tdm.py  Project: Asap7772/rail-rl-franka-eval
 def _sample_taus_for_training(self, batch):
     if self.finite_horizon:
         if self.tau_sample_strategy == 'uniform':
             num_steps_left = np.random.randint(0, self.max_tau + 1,
                                                (self.batch_size, 1))
         elif self.tau_sample_strategy == 'truncated_geometric':
             num_steps_left = truncated_geometric(
                 p=self.truncated_geom_factor / self.max_tau,
                 truncate_threshold=self.max_tau,
                 size=(self.batch_size, 1),
                 new_value=0)
         elif self.tau_sample_strategy == 'no_resampling':
             num_steps_left = batch['num_steps_left']
         elif self.tau_sample_strategy == 'all_valid':
             num_steps_left = np.tile(np.arange(0, self.max_tau + 1),
                                      self.batch_size)
             num_steps_left = np.expand_dims(num_steps_left, 1)
         else:
             raise TypeError("Invalid tau_sample_strategy: {}".format(
                 self.tau_sample_strategy))
     else:
         num_steps_left = np.zeros((self.batch_size, 1))
     return num_steps_left
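
The two array-producing strategies differ in the effective batch size they return: 'uniform' draws one tau per transition, while 'all_valid' enumerates every tau in [0, max_tau] for every transition. A minimal standalone sketch (with made-up max_tau and batch_size) showing the resulting shapes:

import numpy as np

max_tau, batch_size = 5, 4

uniform = np.random.randint(0, max_tau + 1, (batch_size, 1))
print(uniform.shape)  # (4, 1): one tau per transition

all_valid = np.expand_dims(np.tile(np.arange(0, max_tau + 1), batch_size), 1)
print(all_valid.shape)  # (24, 1): all max_tau + 1 values repeated for each transition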
Code example #3
 def random_batch_random_tau(self, batch_size, max_tau):
     indices = np.random.randint(0, self._size, batch_size)
     next_obs_idxs = []
     for i in indices:
         possible_next_obs_idxs = self._idx_to_future_obs_idx[i]
         # This is generally faster than random.choice. Makes you wonder what
         # random.choice is doing
         num_options = len(possible_next_obs_idxs)
         tau = np.random.randint(0, min(max_tau + 1, num_options))
         if num_options == 1:
             next_obs_i = 0
         else:
             if self.resampling_strategy == 'uniform':
                 next_obs_i = int(np.random.randint(0, tau + 1))
             elif self.resampling_strategy == 'truncated_geometric':
                 next_obs_i = int(
                     truncated_geometric(p=self.truncated_geom_factor / tau,
                                         truncate_threshold=num_options - 1,
                                         size=1,
                                         new_value=0))
             else:
                 raise ValueError("Invalid resampling strategy: {}".format(
                     self.resampling_strategy))
         next_obs_idxs.append(possible_next_obs_idxs[next_obs_i])
     next_obs_idxs = np.array(next_obs_idxs)
     training_goals = self.env.convert_obs_to_goals(
         self._next_obs[next_obs_idxs])
     return dict(
         observations=self._observations[indices],
         actions=self._actions[indices],
         rewards=self._rewards[indices],
         terminals=self._terminals[indices],
         next_observations=self._next_obs[indices],
         training_goals=training_goals,
         num_steps_left=self._num_steps_left[indices],
     )
Code example #4
    def random_batch(self, batch_size):
        indices = self._sample_indices(batch_size)
        next_obs_idxs = []
        for i in indices:
            possible_next_obs_idxs = self._idx_to_future_obs_idx[i]
            # This is generally faster than random.choice. Makes you wonder what
            # random.choice is doing
            num_options = len(possible_next_obs_idxs)
            if num_options == 1:
                next_obs_i = 0
            else:
                if self.resampling_strategy == 'uniform':
                    next_obs_i = int(np.random.randint(0, num_options))
                elif self.resampling_strategy == 'truncated_geometric':
                    next_obs_i = int(
                        truncated_geometric(p=self.truncated_geom_factor /
                                            num_options,
                                            truncate_threshold=num_options - 1,
                                            size=1,
                                            new_value=0))
                else:
                    raise ValueError("Invalid resampling strategy: {}".format(
                        self.resampling_strategy))
            next_obs_idxs.append(possible_next_obs_idxs[next_obs_i])
        next_obs_idxs = np.array(next_obs_idxs)
        resampled_goals = self.env.convert_obs_to_goals(
            self._next_obs[next_obs_idxs])
        num_goals_are_from_rollout = int(batch_size *
                                         self.fraction_goals_are_rollout_goals)
        if num_goals_are_from_rollout > 0:
            resampled_goals[:num_goals_are_from_rollout] = self._goals[
                indices[:num_goals_are_from_rollout]]
        # recompute rewards
        new_obs = self._observations[indices]
        new_next_obs = self._next_obs[indices]
        new_actions = self._actions[indices]
        new_rewards = self._rewards[indices].copy()  # needs to be recomputed
        env_info_dicts = [self.rebuild_env_info_dict(idx) for idx in indices]
        random_numbers = np.random.rand(batch_size)
        for i in range(batch_size):
            if random_numbers[i] < self.fraction_resampled_goals_are_env_goals:
                resampled_goals[i, :] = self.env.sample_goal_for_rollout(
                )  # env_goals[i, :]

            new_reward = self.env.compute_her_reward_np(
                new_obs[i, :],
                new_actions[i, :],
                new_next_obs[i, :],
                resampled_goals[i, :],
                env_info_dicts[i],
            )
            new_rewards[i] = new_reward
        # new_rewards = self.env.computer_her_reward_np_batch(
        #     new_obs,
        #     new_actions,
        #     new_next_obs,
        #     resampled_goals,
        #     env_infos,
        # )

        batch = dict(
            observations=new_obs,
            actions=new_actions,
            rewards=new_rewards,
            terminals=self._terminals[indices],
            next_observations=new_next_obs,
            goals_used_for_rollout=self._goals[indices],
            resampled_goals=resampled_goals,
            num_steps_left=self._num_steps_left[indices],
            indices=np.array(indices).reshape(-1, 1),
            goals=resampled_goals,
        )
        for key in self._env_info_keys:
            assert key not in batch.keys()
            batch[key] = self._env_infos[key][indices]
        return batch
Code example #5
import matplotlib.pyplot as plt
from railrl.misc.np_util import truncated_geometric

truncated_geom_factor = 1
max_tau = 5
batch_size = 10000

num_steps_left = truncated_geometric(
    p=truncated_geom_factor / max_tau,
    truncate_threshold=max_tau,
    size=(batch_size, 1),
    new_value=0,
)
print(num_steps_left.max())

a = plt.hist(num_steps_left, bins=max_tau)
print(a)
plt.show()
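
All of the examples above rely on railrl.misc.np_util.truncated_geometric, whose implementation is not shown on this page. The sketch below is only an assumption based on the parameter names and the call sites above (draw from a geometric distribution with success probability p, counting from 0, and map any draw above truncate_threshold to new_value); the real function may differ.

import numpy as np

def truncated_geometric_sketch(p, truncate_threshold, size, new_value=0):
    # Hypothetical reimplementation: number of failures before the first
    # success (so the smallest value is 0), with out-of-range draws replaced.
    samples = np.random.geometric(p, size=size) - 1
    samples[samples > truncate_threshold] = new_value
    return samples

With p = 1 / max_tau and truncate_threshold = max_tau, as in the last example, every sample stays in [0, max_tau], which is consistent with the max() printed above.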