def sample(self, n_tasks, with_replacement=False):
    """Sample a list of environment updates.

    Args:
        n_tasks (int): Number of updates to sample.
        with_replacement (bool): Whether tasks can repeat when sampled.
            Note that if more tasks are sampled than exist, then tasks may
            repeat, but only after every environment has been included at
            least once in this batch. Ignored for continuous task spaces.

    Returns:
        list[EnvUpdate]: Batch of sampled environment updates, which, when
            invoked on environments, will configure them with new tasks.
            See :py:class:`~EnvUpdate` for more information.

    """
    # Build one SetTaskUpdate per sampled task; the update carries the env
    # constructor so workers can (re)build an environment for the task.
    updates = []
    for sampled_task in self._env.sample_tasks(n_tasks):
        updates.append(SetTaskUpdate(self._env_constructor, sampled_task))
    return updates
def _obtain_samples(self, runner):
    """Obtain samples for each task before and after the fast-adaptation.

    For each of the ``meta_batch_size`` sampled tasks, this collects one
    batch of samples per adaptation step plus one final batch from the
    fully-adapted policy, adapting the policy in place between batches and
    restoring it to the pre-adaptation parameters before moving on to the
    next task.

    Args:
        runner (LocalRunner): A local runner instance to obtain samples.

    Returns:
        tuple: Tuple of (all_samples, all_params).
            all_samples (list[MAMLEpisodeBatch]): A list of size
                [meta_batch_size * (num_grad_updates + 1)]
            all_params (list[dict]): A list of named parameter dictionaries.

    """
    tasks = self._env.sample_tasks(self._meta_batch_size)
    # all_samples[i] accumulates the (num_grad_updates + 1) sample batches
    # collected for task i (pre-adaptation steps plus the post-adaptation one).
    all_samples = [[] for _ in range(len(tasks))]
    all_params = []
    # Snapshot of the meta-policy's parameters, used to reset the policy
    # after each task's inner-loop adaptation.
    theta = dict(self._policy.named_parameters())
    for i, task in enumerate(tasks):
        for j in range(self._num_grad_updates + 1):
            # NOTE(review): env_type is None here — presumably SetTaskUpdate
            # only reconfigures already-constructed worker envs in this code
            # path; confirm it is never asked to build a fresh env from this
            # update.
            env_up = SetTaskUpdate(None, task=task)
            episodes = runner.obtain_samples(runner.step_itr,
                                             env_update=env_up)
            batch_samples = self._process_samples(episodes)
            all_samples[i].append(batch_samples)
            # The last iteration does only sampling but no adapting
            if j < self._num_grad_updates:
                # A grad need to be kept for the next grad update
                # Except for the last grad update
                require_grad = j < self._num_grad_updates - 1
                self._adapt(batch_samples, set_grad=require_grad)
        # Record the task-adapted parameters before resetting the policy.
        all_params.append(dict(self._policy.named_parameters()))
        # Restore to pre-updated policy
        update_module_params(self._policy, theta)
    return all_samples, all_params
def sample(self, n_tasks, with_replacement=False):
    """Sample a list of environment updates.

    Note that this will always return environments in the same order, to
    make parallel sampling across workers efficient. If randomizing the
    environment order is required, shuffle the result of this method.

    Args:
        n_tasks (int): Number of updates to sample. Must be a multiple of
            the number of env classes in the benchmark (e.g. 1 for MT/ML1,
            10 for MT10, 50 for MT50). Tasks for each environment will be
            grouped to be adjacent to each other.
        with_replacement (bool): Whether tasks can repeat when sampled.
            When True, each task index is drawn uniformly at random from
            the class's task pool, so the same task may appear more than
            once in a batch.

    Raises:
        ValueError: If ``n_tasks`` is not a multiple of the number of env
            classes in the benchmark.

    Returns:
        list[EnvUpdate]: Batch of sampled environment updates, which, when
            invoked on environments, will configure them with new tasks.
            See :py:class:`~EnvUpdate` for more information.

    """
    if n_tasks % len(self._classes) != 0:
        raise ValueError('For this benchmark, n_tasks must be a multiple '
                         f'of {len(self._classes)}')
    tasks_per_class = n_tasks // len(self._classes)
    updates = []

    # Avoid pickling the entire task sampler into every EnvUpdate
    inner_wrapper = self._inner_wrapper
    add_env_onehot = self._add_env_onehot
    task_indices = self._task_indices

    def wrap(env, task):
        """Wrap an environment in a metaworld benchmark.

        Args:
            env (gym.Env): A metaworld / gym environment.
            task (metaworld.Task): A metaworld task.

        Returns:
            garage.Env: The wrapped environment.

        """
        env = GymEnv(env, max_episode_length=env.max_path_length)
        env = TaskNameWrapper(env, task_name=task.env_name)
        if add_env_onehot:
            env = TaskOnehotWrapper(env,
                                    task_index=task_indices[task.env_name],
                                    n_total_tasks=len(task_indices))
        if inner_wrapper is not None:
            env = inner_wrapper(env, task)
        return env

    for env_name, env in self._classes.items():
        order_index = self._next_order_index
        for _ in range(tasks_per_class):
            task_index = self._task_orders[env_name][order_index]
            task = self._task_map[env_name][task_index]
            updates.append(SetTaskUpdate(env, task, wrap))
            if with_replacement:
                # Uniform draw from the pool; repeats are allowed.
                order_index = np.random.randint(0, MW_TASKS_PER_ENV)
            else:
                # Advance cyclically through this class's fixed task order.
                order_index += 1
                order_index %= MW_TASKS_PER_ENV
    # Advance the shared cursor so the next batch continues where this one
    # left off; reshuffle task orders once a full pass has completed.
    self._next_order_index += tasks_per_class
    if self._next_order_index >= MW_TASKS_PER_ENV:
        self._next_order_index %= MW_TASKS_PER_ENV
        self._shuffle_tasks()
    return updates