def reset(self, report_queue):
    """
    Do the very first reset for all environments in a vector. Populate shared memory with initial obs.
    Note that this is called only once, at the very beginning of training. After this the envs should auto-reset.

    :param report_queue: we use the report queue to monitor reset progress (see appo.py). This can be a lengthy process.
    :return: first requests for policy workers (to generate actions for the very first env step)
    """
    for env_i, e in enumerate(self.envs):
        observations = e.reset()

        if self.cfg.decorrelate_envs_on_one_worker:
            env_i_split = self.num_envs * self.split_idx + env_i
            decorrelate_steps = self.cfg.rollout * env_i_split + self.cfg.rollout * random.randint(0, 4)

            log.info('Decorrelating experience for %d frames...', decorrelate_steps)
            for decorrelate_step in range(decorrelate_steps):
                actions = [e.action_space.sample() for _ in range(self.num_agents)]
                observations, rew, dones, info = e.step(actions)

        for agent_i, obs in enumerate(observations):
            actor_state = self.actor_states[env_i][agent_i]
            actor_state.set_trajectory_data(dict(obs=obs), self.traj_buffer_idx, self.rollout_step)
            # rnn state is already initialized at zero

        # log.debug(
        #     'Reset progress w:%d-%d finished %d/%d, still initializing envs...',
        #     self.worker_idx, self.split_idx, env_i + 1, len(self.envs),
        # )
        safe_put(report_queue, dict(initialized_env=(self.worker_idx, self.split_idx, env_i)), queue_name='report')

    policy_request = self._format_policy_request()
    return policy_request
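
# A minimal standalone sketch (not part of the worker class) of the decorrelation schedule used in reset()
# above: each env on a worker skips a different number of random steps, so episodes on the same worker do not
# end in lockstep. The function name and the default values for rollout/num_envs/num_splits/max_extra are
# illustrative assumptions, not values taken from a real config.
def _example_decorrelation_schedule(rollout=32, num_envs=4, num_splits=2, max_extra=4):
    import random

    steps = {}
    for split_idx in range(num_splits):
        for env_i in range(num_envs):
            # same formula as above: a deterministic offset that grows with the env index, plus a random extra
            env_i_split = num_envs * split_idx + env_i
            steps[(split_idx, env_i)] = rollout * env_i_split + rollout * random.randint(0, max_extra)
    return steps
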
def _handle_reset(self):
    """
    Reset all envs, one split at a time (double-buffering), and send requests to policy workers
    to get actions for the very first env step.
    """
    for split_idx, env_runner in enumerate(self.env_runners):
        policy_inputs = env_runner.reset(self.report_queue)
        self._enqueue_policy_request(split_idx, policy_inputs)

    log.info('Finished reset for worker %d', self.worker_idx)
    safe_put(self.report_queue, dict(finished_reset=self.worker_idx), queue_name='report')
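
# A rough sketch of the double-buffered reset pattern in _handle_reset() above, using hypothetical stubs
# (_StubEnvRunner, _example_double_buffered_reset) rather than the real classes: each split's policy request
# is produced as soon as that split finishes resetting, so policy workers can start acting on split 0 while
# split 1 is still initializing.
class _StubEnvRunner:
    """Hypothetical stand-in for the real env runner; only illustrates the call order."""

    def __init__(self, split_idx):
        self.split_idx = split_idx

    def reset(self, report_queue):
        # the real reset() also fills shared memory with initial observations
        return {'split_idx': self.split_idx, 'policy_inputs': '...'}


def _example_double_buffered_reset(num_splits=2):
    requests = []
    for split_idx, runner in enumerate(_StubEnvRunner(i) for i in range(num_splits)):
        # the request for this split is enqueued immediately, before the next split starts resetting
        requests.append((split_idx, runner.reset(report_queue=None)))
    return requests
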
def _run(self):
    """
    Main loop of the actor worker (rollout worker).
    Process tasks (mainly ROLLOUT_STEP) until we get the termination signal, which usually means end of training.
    Currently there is no mechanism to restart dead workers if something bad happens during training.
    We can only retry on the initial reset(). This is definitely something to work on.
    """
    log.info('Initializing vector env runner %d...', self.worker_idx)

    # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    if self.cfg.actor_worker_gpus:
        set_gpus_for_process(
            self.worker_idx,
            num_gpus_per_process=1, process_type='actor', gpu_mask=self.cfg.actor_worker_gpus,
        )

    torch.multiprocessing.set_sharing_strategy('file_system')

    timing = Timing()
    last_report = time.time()

    with torch.no_grad():
        while not self.terminate:
            try:
                try:
                    with timing.add_time('waiting'), timing.timeit('wait_actor'):
                        tasks = self.task_queue.get_many(timeout=0.1)
                except Empty:
                    tasks = []

                for task in tasks:
                    task_type, data = task

                    if task_type == TaskType.INIT:
                        self._init()
                        continue

                    if task_type == TaskType.TERMINATE:
                        self._terminate()
                        break

                    # handling actual workload
                    if task_type == TaskType.ROLLOUT_STEP:
                        if 'work' not in timing:
                            timing.waiting = 0  # measure waiting only after real work has started

                        with timing.add_time('work'), timing.timeit('one_step'):
                            self._advance_rollouts(data, timing)
                    elif task_type == TaskType.RESET:
                        with timing.add_time('reset'):
                            self._handle_reset()
                    elif task_type == TaskType.PBT:
                        self._process_pbt_task(data)
                    elif task_type == TaskType.UPDATE_ENV_STEPS:
                        for env in self.env_runners:
                            env.update_env_steps(data)

                if time.time() - last_report > 5.0 and 'one_step' in timing:
                    timing_stats = dict(wait_actor=timing.wait_actor, step_actor=timing.one_step)
                    memory_mb = memory_consumption_mb()
                    stats = dict(memory_actor=memory_mb)
                    safe_put(self.report_queue, dict(timing=timing_stats, stats=stats), queue_name='report')
                    last_report = time.time()

            except RuntimeError as exc:
                log.warning('Error while processing data w: %d, exception: %s', self.worker_idx, exc)
                log.warning('Terminate process...')
                self.terminate = True
                safe_put(self.report_queue, dict(critical_error=self.worker_idx), queue_name='report')
            except KeyboardInterrupt:
                self.terminate = True
            except:
                log.exception('Unknown exception in rollout worker')
                self.terminate = True

    if self.worker_idx <= 1:
        time.sleep(0.1)
        log.info(
            'Env runner %d, CPU aff. %r, rollouts %d: timing %s',
            self.worker_idx, psutil.Process().cpu_affinity(), self.num_complete_rollouts, timing,
        )
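
# A condensed, self-contained sketch of the event-loop pattern used in _run() above: drain the task queue
# with a short timeout, dispatch on task type, and exit on a TERMINATE message. It is an illustration only:
# it uses a plain queue.Queue with manual draining instead of the real multiprocessing queue's get_many(),
# and string constants instead of the TaskType enum. The names _example_event_loop and handle_rollout_step
# are hypothetical.
import queue

_ROLLOUT_STEP, _TERMINATE = 'ROLLOUT_STEP', 'TERMINATE'


def _example_event_loop(task_queue, handle_rollout_step):
    terminate = False
    while not terminate:
        tasks = []
        try:
            tasks.append(task_queue.get(timeout=0.1))
            while True:
                tasks.append(task_queue.get_nowait())  # drain the backlog, similar to get_many()
        except queue.Empty:
            pass

        for task_type, data in tasks:
            if task_type == _TERMINATE:
                terminate = True
                break
            if task_type == _ROLLOUT_STEP:
                handle_rollout_step(data)


# Example usage (hypothetical):
#   q = queue.Queue(); q.put((_ROLLOUT_STEP, {'split_idx': 0})); q.put((_TERMINATE, None))
#   _example_event_loop(q, handle_rollout_step=print)
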
def _report_stats(self, stats):
    for report in stats:
        safe_put(self.report_queue, report, queue_name='report')