def _process_training_data(self, data, timing, wait_stats=None):
    self.is_training = True

    buffer, batch_size, samples, env_steps = data
    assert samples == batch_size * self.cfg.num_batches_per_iteration

    self.env_steps += env_steps
    experience_size = buffer.rewards.shape[0]

    stats = dict(learner_env_steps=self.env_steps, policy_id=self.policy_id)

    with timing.add_time('train'):
        discarding_rate = self._discarding_rate()

        self._update_pbt()

        train_stats = self._train(buffer, batch_size, experience_size, timing)

        if train_stats is not None:
            stats['train'] = train_stats

            if wait_stats is not None:
                wait_avg, wait_min, wait_max = wait_stats
                stats['train']['wait_avg'] = wait_avg
                stats['train']['wait_min'] = wait_min
                stats['train']['wait_max'] = wait_max

            stats['train']['discarded_rollouts'] = self.num_discarded_rollouts
            stats['train']['discarding_rate'] = discarding_rate

            stats['stats'] = memory_stats('learner', self.device)

    self.is_training = False

    try:
        self.report_queue.put(stats)
    except Full:
        log.warning('Could not report training stats, the report queue is full!')
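
# Illustrative only: a hypothetical example (not part of the original module) of the payload
# shape _process_training_data() pushes to the report queue. Key names mirror the code above;
# all numbers are made up. `Full` in the except clause is the standard-library queue.Full
# exception; catching it means a report is simply dropped rather than blocking training when
# the report queue is backed up (assuming the queue's put is bounded or can time out).
_example_learner_report = dict(
    learner_env_steps=1_250_000,    # cumulative env steps consumed by this learner
    policy_id=0,
    train=dict(                     # train_stats returned by self._train(), plus:
        wait_avg=0.003, wait_min=0.001, wait_max=0.012,  # from wait_stats, seconds
        discarded_rollouts=4, discarding_rate=0.002,     # stale-rollout accounting
    ),
    stats=dict(),                   # memory_stats('learner', self.device) goes here
)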
def _run(self):
    # workers should ignore Ctrl+C because termination is handled in the event loop by a special message
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    psutil.Process().nice(min(self.cfg.default_niceness + 2, 20))

    cuda_envvars(self.policy_id)
    torch.multiprocessing.set_sharing_strategy('file_system')

    timing = Timing()

    with timing.timeit('init'):
        # initialize the Torch modules
        log.info('Initializing model on the policy worker %d-%d...', self.policy_id, self.worker_idx)

        torch.set_num_threads(1)

        if self.cfg.device == 'gpu':
            # we should already see only one CUDA device, because of env vars
            assert torch.cuda.device_count() == 1
            self.device = torch.device('cuda', index=0)
        else:
            self.device = torch.device('cpu')

        self.actor_critic = create_actor_critic(self.cfg, self.obs_space, self.action_space, timing)
        self.actor_critic.model_to_device(self.device)
        for p in self.actor_critic.parameters():
            p.requires_grad = False  # we don't train anything here

        log.info('Initialized model on the policy worker %d-%d!', self.policy_id, self.worker_idx)

    last_report = last_cache_cleanup = time.time()
    last_report_samples = 0
    request_count = deque(maxlen=50)

    # Very conservative limit on the minimum number of requests to wait for.
    # This will almost guarantee that the system will continue collecting experience
    # at max rate even when 2/3 of workers are stuck for some reason (e.g. doing a long env reset).
    # Although if your workflow involves very lengthy operations that often freeze workers, it can be beneficial
    # to set min_num_requests to 1 (at the cost of potential inefficiency, i.e. the policy worker will use very
    # small batches).
    min_num_requests = self.cfg.num_workers // (self.cfg.num_policies * self.cfg.policy_workers_per_policy)
    min_num_requests //= 3
    min_num_requests = max(1, min_num_requests)

    # Again, a very conservative timer. Only wait a little bit, then continue operation.
    wait_for_min_requests = 0.025

    while not self.terminate:
        try:
            while self.stop_experience_collection[self.policy_id]:
                with self.resume_experience_collection_cv:
                    self.resume_experience_collection_cv.wait(timeout=0.05)

            waiting_started = time.time()
            while len(self.requests) < min_num_requests and time.time() - waiting_started < wait_for_min_requests:
                try:
                    with timing.timeit('wait_policy'), timing.add_time('wait_policy_total'):
                        policy_requests = self.policy_queue.get_many(timeout=0.005)
                    self.requests.extend(policy_requests)
                except Empty:
                    pass

            self._update_weights(timing)

            with timing.timeit('one_step'), timing.add_time('handle_policy_step'):
                if self.initialized:
                    if len(self.requests) > 0:
                        request_count.append(len(self.requests))
                        self._handle_policy_steps(timing)

            try:
                # task from the task_queue
                task_type, data = self.task_queue.get_nowait()

                if task_type == TaskType.INIT:
                    self._init()
                elif task_type == TaskType.TERMINATE:
                    self.terminate = True
                    break
                elif task_type == TaskType.INIT_MODEL:
                    self._init_model(data)

                self.task_queue.task_done()
            except Empty:
                pass

            if time.time() - last_report > 3.0 and 'one_step' in timing:
                timing_stats = dict(wait_policy=timing.wait_policy, step_policy=timing.one_step)
                samples_since_last_report = self.total_num_samples - last_report_samples

                stats = memory_stats('policy_worker', self.device)
                if len(request_count) > 0:
                    stats['avg_request_count'] = np.mean(request_count)

                self.report_queue.put(dict(
                    timing=timing_stats,
                    samples=samples_since_last_report,
                    policy_id=self.policy_id,
                    stats=stats,
                ))
                last_report = time.time()
                last_report_samples = self.total_num_samples

            if time.time() - last_cache_cleanup > 300.0 or (not self.cfg.benchmark and self.total_num_samples < 1000):
                if self.cfg.device == 'gpu':
                    torch.cuda.empty_cache()
                last_cache_cleanup = time.time()

        except KeyboardInterrupt:
            log.warning('Keyboard interrupt detected on worker %d-%d', self.policy_id, self.worker_idx)
            self.terminate = True
        except Exception:
            log.exception('Unknown exception on policy worker')
            self.terminate = True

    time.sleep(0.2)
    log.info('Policy worker avg. requests %.2f, timing: %s', np.mean(request_count), timing)
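
# A standalone sketch of the min_num_requests heuristic used by _run() above, with
# hypothetical config values (num_workers=72, num_policies=1, policy_workers_per_policy=2)
# purely to show the arithmetic; in the worker the values come from self.cfg.
num_workers, num_policies, policy_workers_per_policy = 72, 1, 2
min_num_requests = num_workers // (num_policies * policy_workers_per_policy)  # 36 rollout workers served per policy worker
min_num_requests //= 3                        # only wait for roughly a third of them...
min_num_requests = max(1, min_num_requests)   # -> 12
# ...and even then for at most wait_for_min_requests (25 ms), so inference proceeds with
# whatever requests have arrived even if many rollout workers are stuck in a long env reset.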