def random_batch(self, batch_size):
    """Sample a training batch with hindsight-relabeled goals.

    For each sampled transition, one of its stored future observations is
    chosen (per ``self.resampling_strategy``) and converted into a
    ``resampled_goals`` entry; a leading fraction of the batch
    (``self.fraction_goals_are_rollout_goals``) keeps the goal actually
    used during the rollout instead.

    :param batch_size: number of transitions to sample.
    :return: dict of numpy arrays keyed by batch component; also includes
        the replay-buffer ``indices`` as a column vector.
    :raises ValueError: if ``self.resampling_strategy`` is unrecognized.
    """
    indices = self._sample_indices(batch_size)

    def _pick_future(candidates):
        # Mirrors the original selection: index 0 when there is no
        # choice, otherwise draw per the configured strategy.
        count = len(candidates)
        if count == 1:
            return candidates[0]
        if self.resampling_strategy == 'uniform':
            # np.random.randint is used instead of random.choice for
            # speed (see original author's note).
            return candidates[int(np.random.randint(0, count))]
        if self.resampling_strategy == 'truncated_geometric':
            chosen = int(truncated_geometric(
                p=self.truncated_geom_factor / count,
                truncate_threshold=count - 1,
                size=1,
                new_value=0,
            ))
            return candidates[chosen]
        raise ValueError("Invalid resampling strategy: {}".format(
            self.resampling_strategy
        ))

    next_obs_idxs = np.array([
        _pick_future(self._idx_to_future_obs_idx[i]) for i in indices
    ])
    resampled_goals = self.env.convert_obs_to_goals(
        self._next_obs[next_obs_idxs]
    )

    # Keep the original rollout goals for the first slice of the batch.
    num_rollout_goals = int(
        batch_size * self.fraction_goals_are_rollout_goals
    )
    if num_rollout_goals > 0:
        resampled_goals[:num_rollout_goals] = self._goals[
            indices[:num_rollout_goals]
        ]

    return dict(
        observations=self._observations[indices],
        actions=self._actions[indices],
        rewards=self._rewards[indices],
        terminals=self._terminals[indices],
        next_observations=self._next_obs[indices],
        goals_used_for_rollout=self._goals[indices],
        resampled_goals=resampled_goals,
        num_steps_left=self._num_steps_left[indices],
        indices=np.array(indices).reshape(-1, 1),
    )
def _sample_taus_for_training(self, batch): if self.finite_horizon: if self.tau_sample_strategy == 'uniform': num_steps_left = np.random.randint(0, self.max_tau + 1, (self.batch_size, 1)) elif self.tau_sample_strategy == 'truncated_geometric': num_steps_left = truncated_geometric( p=self.truncated_geom_factor / self.max_tau, truncate_threshold=self.max_tau, size=(self.batch_size, 1), new_value=0) elif self.tau_sample_strategy == 'no_resampling': num_steps_left = batch['num_steps_left'] elif self.tau_sample_strategy == 'all_valid': num_steps_left = np.tile(np.arange(0, self.max_tau + 1), self.batch_size) num_steps_left = np.expand_dims(num_steps_left, 1) else: raise TypeError("Invalid tau_sample_strategy: {}".format( self.tau_sample_strategy)) else: num_steps_left = np.zeros((self.batch_size, 1)) return num_steps_left
def random_batch_random_tau(self, batch_size, max_tau):
    """Sample a batch whose goals are relabeled from future observations
    within a randomly drawn tau-step horizon.

    :param batch_size: number of transitions to sample.
    :param max_tau: inclusive upper bound on the sampled horizon tau.
    :return: dict of numpy arrays, including ``training_goals`` built by
        ``self.env.convert_obs_to_goals`` from the chosen future obs.
    :raises ValueError: if ``self.resampling_strategy`` is unrecognized.
    """
    indices = np.random.randint(0, self._size, batch_size)
    next_obs_idxs = []
    for i in indices:
        possible_next_obs_idxs = self._idx_to_future_obs_idx[i]
        # This is generally faster than random.choice. Makes you wonder
        # what random.choice is doing
        num_options = len(possible_next_obs_idxs)
        # tau is drawn unconditionally (even when num_options == 1) to
        # keep the RNG stream identical to the original implementation.
        tau = np.random.randint(0, min(max_tau + 1, num_options))
        if num_options == 1:
            next_obs_i = 0
        else:
            if self.resampling_strategy == 'uniform':
                # Uniform over the first tau + 1 future observations.
                next_obs_i = int(np.random.randint(0, tau + 1))
            elif self.resampling_strategy == 'truncated_geometric':
                if tau == 0:
                    # Bug fix: the old code computed
                    # p = truncated_geom_factor / tau, which raised
                    # ZeroDivisionError whenever tau == 0. With a
                    # zero-step horizon the only valid choice is the
                    # immediate next observation.
                    next_obs_i = 0
                else:
                    # NOTE(review): this branch truncates at
                    # num_options - 1 rather than tau, unlike the
                    # uniform branch — preserved as-is; confirm intent.
                    next_obs_i = int(truncated_geometric(
                        p=self.truncated_geom_factor / tau,
                        truncate_threshold=num_options - 1,
                        size=1,
                        new_value=0,
                    ))
            else:
                raise ValueError("Invalid resampling strategy: {}".format(
                    self.resampling_strategy))
        next_obs_idxs.append(possible_next_obs_idxs[next_obs_i])
    next_obs_idxs = np.array(next_obs_idxs)
    training_goals = self.env.convert_obs_to_goals(
        self._next_obs[next_obs_idxs])
    return dict(
        observations=self._observations[indices],
        actions=self._actions[indices],
        rewards=self._rewards[indices],
        terminals=self._terminals[indices],
        next_observations=self._next_obs[indices],
        training_goals=training_goals,
        num_steps_left=self._num_steps_left[indices],
    )
def random_batch(self, batch_size):
    """Sample a batch, relabel goals in hindsight, and recompute rewards.

    Goals are relabeled in up to three stages: (1) a future observation
    of the same trajectory (per ``self.resampling_strategy``); (2) the
    leading ``fraction_goals_are_rollout_goals`` slice keeps its original
    rollout goal; (3) each element is then independently replaced with a
    freshly sampled environment goal with probability
    ``fraction_resampled_goals_are_env_goals``. Every reward is
    recomputed via ``self.env.compute_her_reward_np`` against the final
    goal.

    :param batch_size: number of transitions to sample.
    :return: dict of numpy arrays; ``goals`` and ``resampled_goals``
        refer to the same relabeled-goal array, and any keys in
        ``self._env_info_keys`` are appended as well.
    :raises ValueError: if ``self.resampling_strategy`` is unrecognized.
    """
    indices = self._sample_indices(batch_size)

    # Stage 1: choose one future observation per transition.
    future_idxs = []
    for idx in indices:
        candidates = self._idx_to_future_obs_idx[idx]
        count = len(candidates)
        if count == 1:
            chosen = 0
        elif self.resampling_strategy == 'uniform':
            # np.random.randint beats random.choice here (speed).
            chosen = int(np.random.randint(0, count))
        elif self.resampling_strategy == 'truncated_geometric':
            chosen = int(truncated_geometric(
                p=self.truncated_geom_factor / count,
                truncate_threshold=count - 1,
                size=1,
                new_value=0,
            ))
        else:
            raise ValueError("Invalid resampling strategy: {}".format(
                self.resampling_strategy))
        future_idxs.append(candidates[chosen])
    future_idxs = np.array(future_idxs)

    resampled_goals = self.env.convert_obs_to_goals(
        self._next_obs[future_idxs])

    # Stage 2: the first slice of the batch keeps its rollout goals.
    num_rollout_goals = int(
        batch_size * self.fraction_goals_are_rollout_goals)
    if num_rollout_goals > 0:
        resampled_goals[:num_rollout_goals] = self._goals[
            indices[:num_rollout_goals]]

    new_obs = self._observations[indices]
    new_next_obs = self._next_obs[indices]
    new_actions = self._actions[indices]
    # Copied only as a correctly-shaped buffer; every entry is
    # overwritten in the loop below.
    new_rewards = self._rewards[indices].copy()
    env_info_dicts = [
        self.rebuild_env_info_dict(idx) for idx in indices]

    # Stage 3: optional env-goal replacement, then reward recomputation
    # for every element against its final goal.
    random_numbers = np.random.rand(batch_size)
    for i in range(batch_size):
        if random_numbers[i] < self.fraction_resampled_goals_are_env_goals:
            resampled_goals[i, :] = self.env.sample_goal_for_rollout()
        new_rewards[i] = self.env.compute_her_reward_np(
            new_obs[i, :],
            new_actions[i, :],
            new_next_obs[i, :],
            resampled_goals[i, :],
            env_info_dicts[i],
        )

    batch = dict(
        observations=new_obs,
        actions=new_actions,
        rewards=new_rewards,
        terminals=self._terminals[indices],
        next_observations=new_next_obs,
        goals_used_for_rollout=self._goals[indices],
        resampled_goals=resampled_goals,
        num_steps_left=self._num_steps_left[indices],
        indices=np.array(indices).reshape(-1, 1),
        goals=resampled_goals,
    )
    for key in self._env_info_keys:
        assert key not in batch.keys()
        batch[key] = self._env_infos[key][indices]
    return batch
# Quick visual sanity check for the truncated-geometric tau sampler:
# draw a large batch of taus and eyeball the resulting histogram.
import matplotlib.pyplot as plt
from railrl.misc.np_util import truncated_geometric

truncated_geom_factor = 1
max_tau = 5
batch_size = 10000

# Same parameterization used by _sample_taus_for_training's
# 'truncated_geometric' branch.
num_steps_left = truncated_geometric(
    p=truncated_geom_factor / max_tau,
    truncate_threshold=max_tau,
    size=(batch_size, 1),
    new_value=0,
)

# Presumably samples are capped at max_tau by the truncate_threshold —
# TODO confirm against truncated_geometric's contract.
print(num_steps_left.max())

a = plt.hist(num_steps_left, bins=max_tau)
print(a)
plt.show()