def obtain_exact_trajectories(self,
                              n_traj_per_worker,
                              agent_update,
                              env_update=None):
    self._update_workers(agent_update, env_update)
    batches = []
    # Collect exactly n_traj_per_worker trajectories from every worker.
    for worker in self._workers:
        for _ in range(n_traj_per_worker):
            batch = worker.rollout()
            batches.append(batch)
    return SkillTrajectoryBatch.concatenate(*batches)
def obtain_samples(self,
                   itr,
                   num_samples,
                   agent_update,
                   env_update=None,
                   skill=None):
    self._update_workers(agent_update, env_update)
    batches = []
    completed_samples = 0
    # Keep dispatching rollouts round-robin over the workers until at least
    # num_samples actions have been gathered.
    while True:
        for worker in self._workers:
            if not isinstance(worker, (SkillWorker, KantWorker)):
                raise ValueError(
                    'Worker used by LocalSkillSampler class must be a '
                    'SkillWorker/KantWorker object, but got '
                    '{}'.format(type(worker)))
            batch = worker.rollout(skill)
            completed_samples += len(batch.actions)
            batches.append(batch)
            if completed_samples >= num_samples:
                return SkillTrajectoryBatch.concatenate(*batches)
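# A minimal usage sketch (assumption: ``sampler`` is an instance of this
# local skill sampler and ``policy`` exposes ``get_param_values()``; the
# names below are illustrative and not defined in this module):
#
#     samples = sampler.obtain_samples(
#         itr=0,
#         num_samples=4000,
#         agent_update=policy.get_param_values(),
#         skill=skill_id)
#
# The call blocks until at least ``num_samples`` actions have been collected
# for the given skill and returns them as one SkillTrajectoryBatch.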
def _obtain_evaluation_samples(self, env, num_trajs=100):
    paths = []
    max_path_length = self.max_eval_path_length
    if max_path_length is None:
        max_path_length = self.max_path_length
    # Use a finite-length rollout for evaluation.
    if max_path_length is None or np.isinf(max_path_length):
        max_path_length = 1000
    for _ in range(num_trajs):
        path = self._rollout(env,
                             self.policy,
                             max_path_length=max_path_length,
                             deterministic=True)
        paths.append(path)
    return SkillTrajectoryBatch.from_trajectory_list(
        self.env_spec, self.skills_num, paths)
def collect_rollout(self):
    # Drain the per-step buffers accumulated during rollouts and reset them
    # so the worker is ready for the next rollout.
    states = self._states
    self._states = []
    last_states = self._last_states
    self._last_states = []
    skills = self._skills
    self._skills = []
    actions = self._actions
    self._actions = []
    rewards = self._rewards
    self._rewards = []
    terminals = self._terminals
    self._terminals = []
    env_infos = self._env_infos
    self._env_infos = defaultdict(list)
    agent_infos = self._agent_infos
    self._agent_infos = defaultdict(list)
    # Convert collected info lists to arrays; the policy distribution entry
    # is left as-is because it is not an array-like value.
    for k, v in agent_infos.items():
        if k == 'dist':
            continue
        agent_infos[k] = np.asarray(v)
    for k, v in env_infos.items():
        env_infos[k] = np.asarray(v)
    lengths = self._lengths
    self._lengths = []
    skills = np.asarray(skills)
    return SkillTrajectoryBatch(
        env_spec=self.env.spec,
        num_skills=self._num_skills,
        skills=skills.reshape((skills.shape[0], )),
        states=np.asarray(states),
        last_states=np.asarray(last_states),
        actions=np.asarray(actions),
        env_rewards=np.asarray(rewards),
        terminals=np.asarray(terminals),
        env_infos=dict(env_infos),
        agent_infos=dict(agent_infos),
        lengths=np.asarray(lengths, dtype='i'))
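# Data-flow note (an assumption about the step-collection side, which is not
# shown in this section): a ``step_rollout``-style loop is expected to append
# one entry per environment step to ``self._states``, ``self._actions``,
# ``self._rewards``, ``self._terminals``, ``self._skills`` and the info
# defaultdicts, and one entry per finished trajectory to ``self._last_states``
# and ``self._lengths``. ``collect_rollout`` then packages those buffers into
# a single SkillTrajectoryBatch and clears them.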