Example #1
    def obtain_exact_trajectories(self,
                                  n_traj_per_worker,
                                  agent_update,
                                  env_update=None):
        """Collect exactly n_traj_per_worker trajectories from each worker."""
        # Broadcast the latest agent and environment state to every worker.
        self._update_workers(agent_update, env_update)
        batches = []
        for worker in self._workers:
            for _ in range(n_traj_per_worker):
                batch = worker.rollout()
                batches.append(batch)
        # Merge all per-rollout batches into one SkillTrajectoryBatch.
        return SkillTrajectoryBatch.concatenate(*batches)
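
A minimal sketch of a call site for the method above, assuming sampler is an instance of this sampler class and policy is the trained skill-conditioned policy; both names, and the use of get_param_values() as the agent update, are assumptions rather than part of the original code:

    # Hypothetical call site; names and argument values are illustrative only.
    trajs = sampler.obtain_exact_trajectories(
        n_traj_per_worker=4,
        agent_update=policy.get_param_values())
    # lengths is a field of SkillTrajectoryBatch (see Example #4), so this
    # gives the total number of collected time steps.
    total_steps = trajs.lengths.sum()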
Example #2
    def obtain_samples(self, itr, num_samples, agent_update, env_update=None,
                       skill=None):
        """Collect at least num_samples transitions, conditioned on skill."""
        self._update_workers(agent_update, env_update)
        # Validate the worker types once, before any sampling work is done.
        for worker in self._workers:
            if not isinstance(worker, (SkillWorker, KantWorker)):
                raise ValueError('Worker used by Local Skill Sampler class'
                                 ' must be a Skill/Kant Worker object, but got '
                                 '{}'.format(type(worker)))
        batches = []
        completed_samples = 0
        # Cycle through the workers until enough transitions are collected.
        while True:
            for worker in self._workers:
                batch = worker.rollout(skill)
                completed_samples += len(batch.actions)
                batches.append(batch)
                if completed_samples >= num_samples:
                    return SkillTrajectoryBatch.concatenate(*batches)
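
A hedged sketch of how obtain_samples might be invoked, again assuming hypothetical sampler and policy objects; the skill argument selects which latent skill the workers condition their rollouts on:

    # Hypothetical call site; the skill id 3 is arbitrary.
    batch = sampler.obtain_samples(itr=0,
                                   num_samples=2000,
                                   agent_update=policy.get_param_values(),
                                   skill=3)
    # The loop above only returns once at least num_samples transitions exist.
    assert len(batch.actions) >= 2000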
Example #3
    def _obtain_evaluation_samples(self, env, num_trajs=100):
        """Roll out num_trajs deterministic trajectories for evaluation."""
        paths = []
        max_path_length = self.max_eval_path_length
        if max_path_length is None:
            max_path_length = self.max_path_length
        # Use a finite length rollout for evaluation.
        if max_path_length is None or np.isinf(max_path_length):
            max_path_length = 1000

        for _ in range(num_trajs):
            path = self._rollout(env,
                                 self.policy,
                                 max_path_length=max_path_length,
                                 deterministic=True)
            paths.append(path)

        return SkillTrajectoryBatch.from_trajectory_list(
            self.env_spec, self.skills_num, paths)
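
One possible way to summarize the evaluation batch returned above, assuming algo is an instance of the algorithm class and eval_env is a held-out evaluation environment (both names are hypothetical); env_rewards and lengths are fields of SkillTrajectoryBatch, as used in Example #4:

    import numpy as np

    # Hypothetical evaluation step.
    eval_batch = algo._obtain_evaluation_samples(eval_env, num_trajs=10)
    # Split the flat reward array back into per-trajectory segments and
    # report the mean undiscounted return.
    per_traj = np.split(eval_batch.env_rewards,
                        np.cumsum(eval_batch.lengths)[:-1])
    mean_return = np.mean([r.sum() for r in per_traj])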
Example #4
    def collect_rollout(self):
        """Flush the per-step buffers into a SkillTrajectoryBatch and reset them."""
        # Swap each buffer out for a fresh one so the next rollout starts clean.
        states = self._states
        self._states = []
        last_states = self._last_states
        self._last_states = []
        skills = self._skills
        self._skills = []
        actions = self._actions
        self._actions = []
        rewards = self._rewards
        self._rewards = []
        terminals = self._terminals
        self._terminals = []
        env_infos = self._env_infos
        self._env_infos = defaultdict(list)
        agent_infos = self._agent_infos
        self._agent_infos = defaultdict(list)
        # Convert the accumulated per-step infos to arrays; the raw "dist"
        # distribution object is kept unconverted.
        for k, v in agent_infos.items():
            if k == "dist":
                continue
            agent_infos[k] = np.asarray(v)
        for k, v in env_infos.items():
            env_infos[k] = np.asarray(v)
        lengths = self._lengths
        self._lengths = []

        return SkillTrajectoryBatch(
            env_spec=self.env.spec,
            num_skills=self._num_skills,
            # Flatten the collected skill ids into a 1-D array.
            skills=np.asarray(skills).reshape((np.asarray(skills).shape[0], )),
            states=np.asarray(states),
            last_states=np.asarray(last_states),
            actions=np.asarray(actions),
            env_rewards=np.asarray(rewards),
            terminals=np.asarray(terminals),
            env_infos=dict(env_infos),
            agent_infos=dict(agent_infos),
            lengths=np.asarray(lengths, dtype='i'))
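
For context, a sketch of how a worker's rollout() method, which Examples #1 and #2 call, could feed into collect_rollout(); start_rollout and step_rollout are assumed helper names and may not match the actual worker implementation:

    # Hypothetical rollout loop; only collect_rollout() comes from the example
    # above, the other method names are assumptions.
    def rollout(self, skill=None):
        self.start_rollout(skill)        # assumed: reset env/policy and buffers
        while not self.step_rollout():   # assumed: append one step to each buffer
            pass
        return self.collect_rollout()    # flush buffers into a SkillTrajectoryBatch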