def evaluate(self, algo, test_rollouts_per_task=None):
    """Evaluate the Meta-RL algorithm on the test tasks.

    Args:
        algo (garage.np.algos.MetaRLAlgorithm): The algorithm to evaluate.
        test_rollouts_per_task (int or None): Number of rollouts per task.

    """
    if test_rollouts_per_task is None:
        test_rollouts_per_task = self._n_test_rollouts
    adapted_trajectories = []
    logger.log('Sampling for adaptation and meta-testing...')
    if self._test_sampler is None:
        self._test_sampler = LocalSampler.from_worker_factory(
            WorkerFactory(seed=get_seed(),
                          max_episode_length=self._max_episode_length,
                          n_workers=1,
                          worker_class=self._worker_class,
                          worker_args=self._worker_args),
            agents=algo.get_exploration_policy(),
            envs=self._test_task_sampler.sample(1))
    for env_up in self._test_task_sampler.sample(self._n_test_tasks):
        policy = algo.get_exploration_policy()
        traj = TrajectoryBatch.concatenate(*[
            self._test_sampler.obtain_samples(self._eval_itr, 1, policy,
                                              env_up)
            for _ in range(self._n_exploration_traj)
        ])
        adapted_policy = algo.adapt_policy(policy, traj)
        adapted_traj = self._test_sampler.obtain_samples(
            self._eval_itr,
            test_rollouts_per_task * self._max_episode_length,
            adapted_policy)
        adapted_trajectories.append(adapted_traj)
    logger.log('Finished meta-testing...')
    if self._test_task_names is not None:
        name_map = dict(enumerate(self._test_task_names))
    else:
        name_map = None
    with tabular.prefix(self._prefix + '/' if self._prefix else ''):
        log_multitask_performance(
            self._eval_itr,
            TrajectoryBatch.concatenate(*adapted_trajectories),
            getattr(algo, 'discount', 1.0),
            name_map=name_map)
    self._eval_itr += 1

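# Usage sketch (illustrative, not from garage): driving the evaluator above
# from a meta-RL training loop. The `train_once` hook is hypothetical, and
# a MetaEvaluator is assumed to have been constructed with a test task
# sampler matching the attributes used above; check the actual constructor
# signature before relying on this.
def _example_meta_evaluation(meta_evaluator, algo, n_epochs=10):
    """Run meta-evaluation after every training epoch (sketch)."""
    for epoch in range(n_epochs):
        algo.train_once(epoch)  # hypothetical per-epoch training hook
        # Adapt to each test task, roll out the adapted policy, and log
        # per-task performance under the evaluator's prefix.
        meta_evaluator.evaluate(algo, test_rollouts_per_task=10)
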
def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
    """Collect at least a given number of transitions (timesteps).

    Args:
        itr (int): The current iteration number. Using this argument is
            deprecated.
        num_samples (int): Minimum number of transitions / timesteps to
            sample.
        agent_update (object): Value which will be passed into the
            `agent_update_fn` before doing rollouts. If a list is passed in,
            it must have length exactly `factory.n_workers`, and will be
            spread across the workers.
        env_update (object): Value which will be passed into the
            `env_update_fn` before doing rollouts. If a list is passed in,
            it must have length exactly `factory.n_workers`, and will be
            spread across the workers.

    Returns:
        garage.TrajectoryBatch: The batch of collected trajectories.

    """
    self._update_workers(agent_update, env_update)
    batches = []
    completed_samples = 0
    while True:
        for worker in self._workers:
            batch = worker.rollout()
            completed_samples += len(batch.actions)
            batches.append(batch)
            # Stop as soon as the minimum sample count is reached; only
            # whole rollouts are collected, so the batch may overshoot.
            if completed_samples >= num_samples:
                return TrajectoryBatch.concatenate(*batches)

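# Usage sketch for the method above. Because only whole rollouts are
# collected, the returned batch can overshoot `num_samples`; the assertion
# only illustrates the "at least" contract. `sampler` and `policy` are
# assumed to be a constructed LocalSampler and a compatible policy.
def _example_obtain_samples(sampler, policy):
    batch = sampler.obtain_samples(itr=0, num_samples=4000,
                                   agent_update=policy)
    # `lengths` holds the per-trajectory timestep counts.
    assert batch.lengths.sum() >= 4000
    return batch
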
def _evaluate_policy(self, epoch):
    """Evaluate the performance of the policy via deterministic rollouts.

    Statistics such as (average) discounted return and success rate are
    recorded.

    Args:
        epoch (int): The current training epoch.

    Returns:
        float: The average return across
            self._num_evaluation_trajectories trajectories.

    """
    eval_trajs = []
    for _ in range(self._num_tasks):
        eval_trajs.append(
            obtain_evaluation_samples(
                self.policy,
                self._eval_env,
                num_trajs=self._num_evaluation_trajectories))
    eval_trajs = TrajectoryBatch.concatenate(*eval_trajs)
    last_return = log_multitask_performance(epoch, eval_trajs,
                                            self._discount)
    return last_return

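# Note on the loop above: it samples from the same `_eval_env` once per
# task, which only covers every task if the environment itself rotates
# tasks on reset (e.g. a round-robin multi-task wrapper). A minimal
# illustration of that assumed behavior, using a hypothetical wrapper:
class _RoundRobinEnvSketch:
    """Advances to the next wrapped env on every reset (illustration)."""

    def __init__(self, envs):
        self._envs = envs
        self._index = -1

    def reset(self):
        self._index = (self._index + 1) % len(self._envs)
        return self._envs[self._index].reset()
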
def obtain_exact_trajectories(self,
                              n_traj_per_worker,
                              agent_update,
                              env_update=None):
    """Sample an exact number of trajectories per worker.

    Args:
        n_traj_per_worker (int): Exact number of trajectories to gather
            for each worker.
        agent_update (object): Value which will be passed into the
            `agent_update_fn` before doing rollouts. If a list is passed in,
            it must have length exactly `factory.n_workers`, and will be
            spread across the workers.
        env_update (object): Value which will be passed into the
            `env_update_fn` before doing rollouts. If a list is passed in,
            it must have length exactly `factory.n_workers`, and will be
            spread across the workers.

    Returns:
        TrajectoryBatch: Batch of gathered trajectories. Always in worker
            order. In other words, first all trajectories from worker 0,
            then all trajectories from worker 1, etc.

    """
    self._update_workers(agent_update, env_update)
    batches = []
    for worker in self._workers:
        for _ in range(n_traj_per_worker):
            batch = worker.rollout()
            batches.append(batch)
    return TrajectoryBatch.concatenate(*batches)

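# Usage sketch: passing a list as `env_update` assigns one update per
# worker, so each worker can roll out a different task. `sampler`,
# `policy`, and `task_updates` are assumed to exist, with
# len(task_updates) == factory.n_workers.
def _example_exact_trajectories(sampler, policy, task_updates):
    batch = sampler.obtain_exact_trajectories(n_traj_per_worker=2,
                                              agent_update=policy,
                                              env_update=task_updates)
    # The batch is in worker order: worker 0's trajectories first, etc.
    return batch
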
def collect_rollout(self):
    """Collect all completed rollouts.

    Returns:
        garage.TrajectoryBatch: A batch of the trajectories completed
            since the last call to collect_rollout().

    """
    if len(self._completed_rollouts) == 1:
        result = self._completed_rollouts[0]
    else:
        result = TrajectoryBatch.concatenate(*self._completed_rollouts)
    self._completed_rollouts = []
    return result

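# Context sketch: collect_rollout() is the final step of the worker rollout
# API. A simplified driver following the garage Worker interface
# (start_rollout / step_rollout / collect_rollout) might look like this;
# this is an illustration, not the library's exact control flow.
def _example_worker_loop(worker):
    worker.start_rollout()
    while not worker.step_rollout():  # True once the rollout terminates
        pass
    return worker.collect_rollout()   # batch of completed trajectories
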
def evaluate(self, algo):
    """Evaluate the Meta-RL algorithm on the test tasks.

    Args:
        algo (garage.np.algos.MetaRLAlgorithm): The algorithm to evaluate.

    """
    adapted_trajectories = []
    for env_up in self._test_task_sampler.sample(self._n_test_tasks):
        policy = algo.get_exploration_policy()
        traj = TrajectoryBatch.concatenate(*[
            self._test_sampler.obtain_samples(self._eval_itr, 1, policy,
                                              env_up)
            for _ in range(self._n_exploration_traj)
        ])
        adapted_policy = algo.adapt_policy(policy, traj)
        adapted_traj = self._test_sampler.obtain_samples(
            self._eval_itr, 1, adapted_policy)
        adapted_trajectories.append(adapted_traj)
    log_performance(self._eval_itr,
                    TrajectoryBatch.concatenate(*adapted_trajectories),
                    getattr(algo, 'discount', 1.0),
                    prefix=self._prefix)
    self._eval_itr += 1

def collect_rollout(self):
    """Gather fragments from all in-progress rollouts.

    Returns:
        garage.TrajectoryBatch: A batch of the trajectory fragments.

    """
    for i, frag in enumerate(self._fragments):
        assert frag.env is self._envs[i]
        if len(frag.rewards) > 0:
            complete_frag = frag.to_batch()
            self._complete_fragments.append(complete_frag)
            self._fragments[i] = InProgressTrajectory(frag.env,
                                                      frag.last_obs)
    assert len(self._complete_fragments) > 0
    result = TrajectoryBatch.concatenate(*self._complete_fragments)
    self._complete_fragments = []
    return result

def obtain_exact_trajectories(self,
                              n_traj_per_worker,
                              agent_update,
                              env_update=None):
    """Sample an exact number of trajectories per worker.

    Args:
        n_traj_per_worker (int): Exact number of trajectories to gather
            for each worker.
        agent_update (object): Value which will be passed into the
            `agent_update_fn` before doing rollouts. If a list is passed in,
            it must have length exactly `factory.n_workers`, and will be
            spread across the workers.
        env_update (object): Value which will be passed into the
            `env_update_fn` before doing rollouts. If a list is passed in,
            it must have length exactly `factory.n_workers`, and will be
            spread across the workers.

    Returns:
        TrajectoryBatch: Batch of gathered trajectories. Always in worker
            order. In other words, first all trajectories from worker 0,
            then all trajectories from worker 1, etc.

    """
    active_workers = []
    pbar = ProgBarCounter(self._worker_factory.n_workers)
    trajectories = defaultdict(list)
    # Update the policy params of each worker before sampling for the
    # current iteration.
    idle_worker_ids = []
    updating_workers = self._update_workers(agent_update, env_update)
    while any(
            len(trajectories[i]) < n_traj_per_worker
            for i in range(self._worker_factory.n_workers)):
        # If any workers are still being updated, take those that have
        # finished updating and start collecting trajectories on them.
        if updating_workers:
            updated, updating_workers = ray.wait(updating_workers,
                                                 num_returns=1,
                                                 timeout=0.1)
            upd = [ray.get(up) for up in updated]
            idle_worker_ids.extend(upd)

        # If there are idle workers, use them to collect trajectories and
        # mark the newly busy workers as active.
        while idle_worker_ids:
            idle_worker_id = idle_worker_ids.pop()
            worker = self._all_workers[idle_worker_id]
            active_workers.append(worker.rollout.remote())

        # Check which workers have finished a rollout; process any
        # completed trajectories and keep polling the rest.
        ready, not_ready = ray.wait(active_workers,
                                    num_returns=1,
                                    timeout=0.001)
        active_workers = not_ready
        for result in ready:
            ready_worker_id, trajectory_batch = ray.get(result)
            pbar.inc(1)
            trajectories[ready_worker_id].append(trajectory_batch)
            if len(trajectories[ready_worker_id]) < n_traj_per_worker:
                idle_worker_ids.append(ready_worker_id)
    pbar.stop()
    ordered_trajectories = list(
        itertools.chain(*[
            trajectories[i] for i in range(self._worker_factory.n_workers)
        ]))
    return TrajectoryBatch.concatenate(*ordered_trajectories)

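# A self-contained illustration of the ray.wait polling pattern used above:
# submit remote tasks, then repeatedly wait with a short timeout so the loop
# can interleave other bookkeeping while results trickle in. The task itself
# is a toy stand-in for a worker rollout.
import ray


@ray.remote
def _square(x):
    return x * x


def _example_ray_polling():
    ray.init(ignore_reinit_error=True)
    pending = [_square.remote(i) for i in range(8)]
    results = []
    while pending:
        # ready may be empty if the timeout expires first.
        ready, pending = ray.wait(pending, num_returns=1, timeout=0.001)
        results.extend(ray.get(r) for r in ready)
    return results
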
def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
    """Sample the policy for new trajectories.

    Args:
        itr (int): Iteration number.
        num_samples (int): Number of steps the sampler should collect.
        agent_update (object): Value which will be passed into the
            `agent_update_fn` before doing rollouts. If a list is passed in,
            it must have length exactly `factory.n_workers`, and will be
            spread across the workers.
        env_update (object): Value which will be passed into the
            `env_update_fn` before doing rollouts. If a list is passed in,
            it must have length exactly `factory.n_workers`, and will be
            spread across the workers.

    Returns:
        TrajectoryBatch: Batch of gathered trajectories.

    """
    active_workers = []
    pbar = ProgBarCounter(num_samples)
    completed_samples = 0
    batches = []
    # Update the policy params of each worker before sampling for the
    # current iteration.
    idle_worker_ids = []
    updating_workers = self._update_workers(agent_update, env_update)
    while completed_samples < num_samples:
        # If any workers are still being updated, take those that have
        # finished updating and start collecting trajectories on them.
        if updating_workers:
            updated, updating_workers = ray.wait(updating_workers,
                                                 num_returns=1,
                                                 timeout=0.1)
            upd = [ray.get(up) for up in updated]
            idle_worker_ids.extend(upd)

        # If there are idle workers, use them to collect trajectories and
        # mark the newly busy workers as active.
        while idle_worker_ids:
            idle_worker_id = idle_worker_ids.pop()
            worker = self._all_workers[idle_worker_id]
            active_workers.append(worker.rollout.remote())

        # Check which workers have finished a rollout; process any
        # completed trajectories and keep polling the rest.
        ready, not_ready = ray.wait(active_workers,
                                    num_returns=1,
                                    timeout=0.001)
        active_workers = not_ready
        for result in ready:
            ready_worker_id, trajectory_batch = ray.get(result)
            idle_worker_ids.append(ready_worker_id)
            num_returned_samples = trajectory_batch.lengths.sum()
            completed_samples += num_returned_samples
            pbar.inc(num_returned_samples)
            batches.append(trajectory_batch)
    pbar.stop()
    return TrajectoryBatch.concatenate(*batches)

def obtain_exact_trajectories(self,
                              n_traj_per_worker,
                              agent_update,
                              env_update=None):
    """Sample an exact number of trajectories per worker.

    Args:
        n_traj_per_worker (int): Exact number of trajectories to gather
            for each worker.
        agent_update (object): Value which will be passed into the
            `agent_update_fn` before doing rollouts. If a list is passed in,
            it must have length exactly `factory.n_workers`, and will be
            spread across the workers.
        env_update (object): Value which will be passed into the
            `env_update_fn` before doing rollouts. If a list is passed in,
            it must have length exactly `factory.n_workers`, and will be
            spread across the workers.

    Returns:
        TrajectoryBatch: Batch of gathered trajectories. Always in worker
            order. In other words, first all trajectories from worker 0,
            then all trajectories from worker 1, etc.

    Raises:
        AssertionError: On internal errors.

    """
    pbar = ProgBarCounter(self._factory.n_workers)
    self._agent_version += 1
    updated_workers = set()
    agent_ups = self._factory.prepare_worker_messages(
        agent_update, cloudpickle.dumps)
    env_ups = self._factory.prepare_worker_messages(env_update)
    trajectories = defaultdict(list)
    while any(
            len(trajectories[i]) < n_traj_per_worker
            for i in range(self._factory.n_workers)):
        self._push_updates(updated_workers, agent_ups, env_ups)
        tag, contents = self._to_sampler.get()
        if tag == 'trajectory':
            batch, version, worker_n = contents
            if version == self._agent_version:
                if len(trajectories[worker_n]) < n_traj_per_worker:
                    trajectories[worker_n].append(batch)
                if len(trajectories[worker_n]) == n_traj_per_worker:
                    pbar.inc(1)
                    try:
                        self._to_worker[worker_n].put_nowait(('stop', ()))
                    except queue.Full:
                        pass
        else:
            raise AssertionError(
                'Unknown tag {} with contents {}'.format(tag, contents))
    for q in self._to_worker:
        try:
            q.put_nowait(('stop', ()))
        except queue.Full:
            pass
    pbar.stop()
    ordered_trajectories = list(
        itertools.chain(
            *[trajectories[i] for i in range(self._factory.n_workers)]))
    return TrajectoryBatch.concatenate(*ordered_trajectories)

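# A self-contained sketch of the tagged-message queue protocol used above:
# workers put ('trajectory', payload) messages on a shared queue, and the
# sampler sends ('stop', ()) with put_nowait, tolerating a full queue since
# the stop message is best-effort. Names here are illustrative only.
import queue


def _example_tagged_protocol():
    to_sampler = queue.Queue()
    to_worker = queue.Queue(maxsize=1)
    to_sampler.put(('trajectory', 'payload'))
    tag, contents = to_sampler.get()
    assert tag == 'trajectory'
    try:
        to_worker.put_nowait(('stop', ()))
    except queue.Full:
        pass  # a stop is already pending; dropping this one is safe
    return contents
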
def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
    """Collect at least a given number of transitions (timesteps).

    Args:
        itr (int): The current iteration number. Using this argument is
            deprecated.
        num_samples (int): Minimum number of transitions / timesteps to
            sample.
        agent_update (object): Value which will be passed into the
            `agent_update_fn` before doing rollouts. If a list is passed in,
            it must have length exactly `factory.n_workers`, and will be
            spread across the workers.
        env_update (object): Value which will be passed into the
            `env_update_fn` before doing rollouts. If a list is passed in,
            it must have length exactly `factory.n_workers`, and will be
            spread across the workers.

    Returns:
        garage.TrajectoryBatch: The batch of collected trajectories.

    Raises:
        AssertionError: On internal errors.

    """
    del itr
    pbar = ProgBarCounter(num_samples)
    batches = []
    completed_samples = 0
    self._agent_version += 1
    updated_workers = set()
    agent_ups = self._factory.prepare_worker_messages(
        agent_update, cloudpickle.dumps)
    env_ups = self._factory.prepare_worker_messages(env_update)
    while completed_samples < num_samples:
        self._push_updates(updated_workers, agent_ups, env_ups)
        for _ in range(self._factory.n_workers):
            try:
                tag, contents = self._to_sampler.get_nowait()
                if tag == 'trajectory':
                    batch, version, worker_n = contents
                    del worker_n
                    if version == self._agent_version:
                        batches.append(batch)
                        num_returned_samples = batch.lengths.sum()
                        completed_samples += num_returned_samples
                        pbar.inc(num_returned_samples)
                    else:
                        # Receiving paths from previous iterations is
                        # normal. Potentially, we could gather them here,
                        # if an off-policy method wants them.
                        pass
                else:
                    raise AssertionError(
                        'Unknown tag {} with contents {}'.format(
                            tag, contents))
            except queue.Empty:
                pass
    for q in self._to_worker:
        try:
            q.put_nowait(('stop', ()))
        except queue.Full:
            pass
    pbar.stop()
    return TrajectoryBatch.concatenate(*batches)

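# Sketch of the version check above: each batch arrives tagged with the
# agent version it was sampled under, and batches from a stale version are
# dropped so every returned sample is on-policy for the current agent.
# `messages` is a hypothetical list of (batch, version) pairs.
def _example_version_filter(messages, current_version):
    """Keep only batches sampled under the current agent version."""
    return [batch for batch, version in messages
            if version == current_version]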