Code Example #1
File: meta_evaluator.py Project: songanz/garage
    def evaluate(self, algo, test_rollouts_per_task=None):
        """Evaluate the Meta-RL algorithm on the test tasks.

        Args:
            algo (garage.np.algos.MetaRLAlgorithm): The algorithm to evaluate.
            test_rollouts_per_task (int or None): Number of rollouts per task.

        """
        if test_rollouts_per_task is None:
            test_rollouts_per_task = self._n_test_rollouts
        adapted_trajectories = []
        logger.log('Sampling for adaptation and meta-testing...')
        if self._test_sampler is None:
            self._test_sampler = LocalSampler.from_worker_factory(
                WorkerFactory(seed=get_seed(),
                              max_episode_length=self._max_episode_length,
                              n_workers=1,
                              worker_class=self._worker_class,
                              worker_args=self._worker_args),
                agents=algo.get_exploration_policy(),
                envs=self._test_task_sampler.sample(1))
        for env_up in self._test_task_sampler.sample(self._n_test_tasks):
            policy = algo.get_exploration_policy()
            traj = TrajectoryBatch.concatenate(*[
                self._test_sampler.obtain_samples(self._eval_itr, 1, policy,
                                                  env_up)
                for _ in range(self._n_exploration_traj)
            ])
            adapted_policy = algo.adapt_policy(policy, traj)
            adapted_traj = self._test_sampler.obtain_samples(
                self._eval_itr,
                test_rollouts_per_task * self._max_episode_length,
                adapted_policy)
            adapted_trajectories.append(adapted_traj)
        logger.log('Finished meta-testing...')

        if self._test_task_names is not None:
            name_map = dict(enumerate(self._test_task_names))
        else:
            name_map = None

        with tabular.prefix(self._prefix + '/' if self._prefix else ''):
            log_multitask_performance(
                self._eval_itr,
                TrajectoryBatch.concatenate(*adapted_trajectories),
                getattr(algo, 'discount', 1.0),
                name_map=name_map)
        self._eval_itr += 1
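Both calls to TrajectoryBatch.concatenate above merge several batches into a single batch whose lengths array keeps one entry per trajectory. A minimal sanity check of that behavior, assuming two batches a and b obtained from a sampler as in this example (the variable names are illustrative):

combined = TrajectoryBatch.concatenate(a, b)
# One lengths entry per trajectory survives the merge...
assert len(combined.lengths) == len(a.lengths) + len(b.lengths)
# ...and no timesteps are lost or duplicated.
assert combined.lengths.sum() == a.lengths.sum() + b.lengths.sum()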
Code Example #2
    def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
        """Collect at least a given number transitions (timesteps).

        Args:
            itr(int): The current iteration number. Using this argument is
                deprecated.
            num_samples(int): Minimum number of transitions / timesteps to
                sample.
            agent_update(object): Value which will be passed into the
                `agent_update_fn` before doing rollouts. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update(object): Value which will be passed into the
                `env_update_fn` before doing rollouts. If a list is passed in,
                it must have length exactly `factory.n_workers`, and will be
                spread across the workers.

        Returns:
            garage.TrajectoryBatch: The batch of collected trajectories.

        """
        self._update_workers(agent_update, env_update)
        batches = []
        completed_samples = 0
        while True:
            for worker in self._workers:
                batch = worker.rollout()
                completed_samples += len(batch.actions)
                batches.append(batch)
                if completed_samples > num_samples:
                    return TrajectoryBatch.concatenate(*batches)
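A hedged usage sketch of this method, assuming sampler was built with LocalSampler.from_worker_factory as in Code Example #1 and policy is the agent to roll out (the variable names are illustrative):

batch = sampler.obtain_samples(itr=0, num_samples=4000, agent_update=policy)
# Whole trajectories are collected until the requested minimum timestep
# count is reached, so the total may overshoot num_samples.
assert batch.lengths.sum() >= 4000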
Code Example #3
    def _evaluate_policy(self, epoch):
        """Evaluate the performance of the policy via deterministic rollouts.

        Statistics such as (average) discounted return and success rate are
        recorded.

        Args:
            epoch (int): The current training epoch.

        Returns:
            float: The average return across self._num_evaluation_trajectories
                trajectories.

        """
        eval_trajs = []
        for _ in range(self._num_tasks):
            eval_trajs.append(
                obtain_evaluation_samples(
                    self.policy,
                    self._eval_env,
                    num_trajs=self._num_evaluation_trajectories))
        eval_trajs = TrajectoryBatch.concatenate(*eval_trajs)
        last_return = log_multitask_performance(epoch, eval_trajs,
                                                self._discount)
        return last_return
Code Example #4
    def obtain_exact_trajectories(self,
                                  n_traj_per_worker,
                                  agent_update,
                                  env_update=None):
        """Sample an exact number of trajectories per worker.

        Args:
            n_traj_per_worker (int): Exact number of trajectories to gather for
                each worker.
            agent_update(object): Value which will be passed into the
                `agent_update_fn` before doing rollouts. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update(object): Value which will be passed into the
                `env_update_fn` before doing rollouts. If a list is passed in,
                it must have length exactly `factory.n_workers`, and will be
                spread across the workers.

        Returns:
            TrajectoryBatch: Batch of gathered trajectories. Always in worker
                order. In other words, first all trajectories from worker 0,
                then all trajectories from worker 1, etc.

        """
        self._update_workers(agent_update, env_update)
        batches = []
        for worker in self._workers:
            for _ in range(n_traj_per_worker):
                batch = worker.rollout()
                batches.append(batch)
        return TrajectoryBatch.concatenate(*batches)
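A hedged sketch of the worker-order guarantee documented above, assuming the sampler was created with n_workers workers (both names are illustrative):

batch = sampler.obtain_exact_trajectories(n_traj_per_worker=2,
                                          agent_update=policy)
# Exactly two trajectories come back per worker, grouped by worker:
# all of worker 0's trajectories first, then worker 1's, and so on.
assert len(batch.lengths) == 2 * n_workers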
Code Example #5
File: vec_worker.py Project: AzOuss96/garage
    def collect_rollout(self):
        """Collect all completed rollouts.

        Returns:
            garage.TrajectoryBatch: A batch of the trajectories completed since
                the last call to collect_rollout().

        """
        if len(self._completed_rollouts) == 1:
            result = self._completed_rollouts[0]
        else:
            result = TrajectoryBatch.concatenate(*self._completed_rollouts)
        self._completed_rollouts = []
        return result
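The length-one branch above skips an unnecessary concatenation (and copy) when only a single rollout completed since the last call. The same guard can be factored into a small helper, shown here as a sketch (the helper name is illustrative):

def concat_or_single(batches):
    # Return one TrajectoryBatch, avoiding a copy when there is only one.
    if len(batches) == 1:
        return batches[0]
    return TrajectoryBatch.concatenate(*batches)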
Code Example #6
    def evaluate(self, algo):
        """Evaluate the Meta-RL algorithm on the test tasks.

        Args:
            algo (garage.np.algos.MetaRLAlgorithm): The algorithm to evaluate.

        """
        adapted_trajectories = []
        for env_up in self._test_task_sampler.sample(self._n_test_tasks):
            policy = algo.get_exploration_policy()
            traj = TrajectoryBatch.concatenate(*[
                self._test_sampler.obtain_samples(self._eval_itr, 1, policy,
                                                  env_up)
                for _ in range(self._n_exploration_traj)
            ])
            adapted_policy = algo.adapt_policy(policy, traj)
            adapted_traj = self._test_sampler.obtain_samples(
                self._eval_itr, 1, adapted_policy)
            adapted_trajectories.append(adapted_traj)
        log_performance(self._eval_itr,
                        TrajectoryBatch.concatenate(*adapted_trajectories),
                        getattr(algo, 'discount', 1.0),
                        prefix=self._prefix)
        self._eval_itr += 1
Code Example #7
    def collect_rollout(self):
        """Gather fragments from all in-progress rollouts.

        Returns:
            garage.TrajectoryBatch: A batch of the trajectory fragments.

        """
        for i, frag in enumerate(self._fragments):
            assert frag.env is self._envs[i]
            if len(frag.rewards) > 0:
                complete_frag = frag.to_batch()
                self._complete_fragments.append(complete_frag)
                self._fragments[i] = InProgressTrajectory(
                    frag.env, frag.last_obs)
        assert len(self._complete_fragments) > 0
        result = TrajectoryBatch.concatenate(*self._complete_fragments)
        self._complete_fragments = []
        return result
Code Example #8
File: ray_sampler.py Project: yus-nas/garage
    def obtain_exact_trajectories(self,
                                  n_traj_per_worker,
                                  agent_update,
                                  env_update=None):
        """Sample an exact number of trajectories per worker.

        Args:
            n_traj_per_worker (int): Exact number of trajectories to gather for
                each worker.
            agent_update(object): Value which will be passed into the
                `agent_update_fn` before doing rollouts. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update(object): Value which will be passed into the
                `env_update_fn` before doing rollouts. If a list is passed in,
                it must have length exactly `factory.n_workers`, and will be
                spread across the workers.

        Returns:
            TrajectoryBatch: Batch of gathered trajectories. Always in worker
                order. In other words, first all trajectories from worker 0,
                then all trajectories from worker 1, etc.

        """
        active_workers = []
        pbar = ProgBarCounter(self._worker_factory.n_workers)
        trajectories = defaultdict(list)

        # update the policy params of each worker before sampling
        # for the current iteration
        idle_worker_ids = []
        updating_workers = self._update_workers(agent_update, env_update)

        while any(
                len(trajectories[i]) < n_traj_per_worker
                for i in range(self._worker_factory.n_workers)):
            # if there are workers still being updated, check
            # which ones are still updating and take the workers that
            # are done updating, and start collecting trajectories on
            # those workers.
            if updating_workers:
                updated, updating_workers = ray.wait(updating_workers,
                                                     num_returns=1,
                                                     timeout=0.1)
                upd = [ray.get(up) for up in updated]
                idle_worker_ids.extend(upd)

            # if there are idle workers, use them to collect trajectories
            # mark the newly busy workers as active
            while idle_worker_ids:
                idle_worker_id = idle_worker_ids.pop()
                worker = self._all_workers[idle_worker_id]
                active_workers.append(worker.rollout.remote())

            # check which workers are done/not done collecting a sample
            # if any are done, send them to process the collected trajectory
            # if they are not, keep checking if they are done
            ready, not_ready = ray.wait(active_workers,
                                        num_returns=1,
                                        timeout=0.001)
            active_workers = not_ready
            for result in ready:
                ready_worker_id, trajectory_batch = ray.get(result)
                pbar.inc(1)
                trajectories[ready_worker_id].append(trajectory_batch)
                if len(trajectories[ready_worker_id]) < n_traj_per_worker:
                    idle_worker_ids.append(ready_worker_id)
        pbar.stop()
        ordered_trajectories = list(
            itertools.chain(*[
                trajectories[i] for i in range(self._worker_factory.n_workers)
            ]))
        return TrajectoryBatch.concatenate(*ordered_trajectories)
Code Example #9
File: ray_sampler.py Project: yus-nas/garage
    def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
        """Sample the policy for new trajectories.

        Args:
            itr(int): Iteration number.
            num_samples(int): Number of steps the sampler should collect.
            agent_update(object): Value which will be passed into the
                `agent_update_fn` before doing rollouts. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update(object): Value which will be passed into the
                `env_update_fn` before doing rollouts. If a list is passed in,
                it must have length exactly `factory.n_workers`, and will be
                spread across the workers.

        Returns:
            TrajectoryBatch: Batch of gathered trajectories.

        """
        active_workers = []
        pbar = ProgBarCounter(num_samples)
        completed_samples = 0
        batches = []

        # update the policy params of each worker before sampling
        # for the current iteration
        idle_worker_ids = []
        updating_workers = self._update_workers(agent_update, env_update)

        while completed_samples < num_samples:
            # if there are workers still being updated, check
            # which ones are still updating and take the workers that
            # are done updating, and start collecting trajectories on
            # those workers.
            if updating_workers:
                updated, updating_workers = ray.wait(updating_workers,
                                                     num_returns=1,
                                                     timeout=0.1)
                upd = [ray.get(up) for up in updated]
                idle_worker_ids.extend(upd)

            # if there are idle workers, use them to collect trajectories
            # mark the newly busy workers as active
            while idle_worker_ids:
                idle_worker_id = idle_worker_ids.pop()
                worker = self._all_workers[idle_worker_id]
                active_workers.append(worker.rollout.remote())

            # check which workers are done/not done collecting a sample
            # if any are done, send them to process the collected trajectory
            # if they are not, keep checking if they are done
            ready, not_ready = ray.wait(active_workers,
                                        num_returns=1,
                                        timeout=0.001)
            active_workers = not_ready
            for result in ready:
                ready_worker_id, trajectory_batch = ray.get(result)
                idle_worker_ids.append(ready_worker_id)
                num_returned_samples = trajectory_batch.lengths.sum()
                completed_samples += num_returned_samples
                pbar.inc(num_returned_samples)
                batches.append(trajectory_batch)
        pbar.stop()
        return TrajectoryBatch.concatenate(*batches)
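Both Ray-based methods above interleave worker updates with sampling by polling ray.wait. Stripped of the per-worker bookkeeping, the collection loop reduces to the following sketch (the actor list and function name are illustrative; rollout.remote mirrors the worker method used above):

import ray

def collect_all(actors):
    # Launch one rollout per actor, then gather results as they finish.
    pending = [actor.rollout.remote() for actor in actors]
    results = []
    while pending:
        # Wait briefly for at least one rollout, collect it, keep polling.
        ready, pending = ray.wait(pending, num_returns=1, timeout=0.1)
        for ref in ready:
            results.append(ray.get(ref))
    return results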
Code Example #10
    def obtain_exact_trajectories(self,
                                  n_traj_per_worker,
                                  agent_update,
                                  env_update=None):
        """Sample an exact number of trajectories per worker.

        Args:
            n_traj_per_worker (int): Exact number of trajectories to gather for
                each worker.
            agent_update(object): Value which will be passed into the
                `agent_update_fn` before doing rollouts. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update(object): Value which will be passed into the
                `env_update_fn` before doing rollouts. If a list is passed in,
                it must have length exactly `factory.n_workers`, and will be
                spread across the workers.

        Returns:
            TrajectoryBatch: Batch of gathered trajectories. Always in worker
                order. In other words, first all trajectories from worker 0,
                then all trajectories from worker 1, etc.

        Raises:
            AssertionError: On internal errors.

        """
        pbar = ProgBarCounter(self._factory.n_workers)
        self._agent_version += 1
        updated_workers = set()
        agent_ups = self._factory.prepare_worker_messages(
            agent_update, cloudpickle.dumps)
        env_ups = self._factory.prepare_worker_messages(env_update)
        trajectories = defaultdict(list)

        while any(
                len(trajectories[i]) < n_traj_per_worker
                for i in range(self._factory.n_workers)):
            self._push_updates(updated_workers, agent_ups, env_ups)
            tag, contents = self._to_sampler.get()
            if tag == 'trajectory':
                batch, version, worker_n = contents
                if version == self._agent_version:
                    if len(trajectories[worker_n]) < n_traj_per_worker:
                        trajectories[worker_n].append(batch)
                    if len(trajectories[worker_n]) == n_traj_per_worker:
                        pbar.inc(1)
                        try:
                            self._to_worker[worker_n].put_nowait(('stop', ()))
                        except queue.Full:
                            pass
            else:
                raise AssertionError('Unknown tag {} with contents {}'.format(
                    tag, contents))

        for q in self._to_worker:
            try:
                q.put_nowait(('stop', ()))
            except queue.Full:
                pass
        pbar.stop()
        ordered_trajectories = list(
            itertools.chain(
                *[trajectories[i] for i in range(self._factory.n_workers)]))
        return TrajectoryBatch.concatenate(*ordered_trajectories)
Code Example #11
    def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
        """Collect at least a given number transitions (timesteps).

        Args:
            itr(int): The current iteration number. Using this argument is
                deprecated.
            num_samples(int): Minimum number of transitions / timesteps to
                sample.
            agent_update(object): Value which will be passed into the
                `agent_update_fn` before doing rollouts. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update(object): Value which will be passed into the
                `env_update_fn` before doing rollouts. If a list is passed in,
                it must have length exactly `factory.n_workers`, and will be
                spread across the workers.

        Returns:
            garage.TrajectoryBatch: The batch of collected trajectories.

        Raises:
            AssertionError: On internal errors.

        """
        del itr
        pbar = ProgBarCounter(num_samples)
        batches = []
        completed_samples = 0
        self._agent_version += 1
        updated_workers = set()
        agent_ups = self._factory.prepare_worker_messages(
            agent_update, cloudpickle.dumps)
        env_ups = self._factory.prepare_worker_messages(env_update)

        while completed_samples < num_samples:
            self._push_updates(updated_workers, agent_ups, env_ups)
            for _ in range(self._factory.n_workers):
                try:
                    tag, contents = self._to_sampler.get_nowait()
                    if tag == 'trajectory':
                        batch, version, worker_n = contents
                        del worker_n
                        if version == self._agent_version:
                            batches.append(batch)
                            num_returned_samples = batch.lengths.sum()
                            completed_samples += num_returned_samples
                            pbar.inc(num_returned_samples)
                        else:
                            # Receiving paths from previous iterations is
                            # normal.  Potentially, we could gather them here,
                            # if an off-policy method wants them.
                            pass
                    else:
                        raise AssertionError(
                            'Unknown tag {} with contents {}'.format(
                                tag, contents))
                except queue.Empty:
                    pass
        for q in self._to_worker:
            try:
                q.put_nowait(('stop', ()))
            except queue.Full:
                pass
        pbar.stop()
        return TrajectoryBatch.concatenate(*batches)
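The version check above (version == self._agent_version) discards trajectories collected with stale policy parameters from a previous iteration. In isolation, that filtering idea looks like the following sketch (the function and variable names are illustrative):

def keep_current(tagged_batches, current_version):
    # Keep only batches collected with the current agent parameters.
    return [batch for batch, version in tagged_batches
            if version == current_version]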