Example #1
    def _process_training_data(self, data, timing, wait_stats=None):
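        """Run one training iteration on a batch of collected experience and report stats."""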
        self.is_training = True

        buffer, batch_size, samples, env_steps = data
        assert samples == batch_size * self.cfg.num_batches_per_iteration

        self.env_steps += env_steps
        experience_size = buffer.rewards.shape[0]

        stats = dict(learner_env_steps=self.env_steps,
                     policy_id=self.policy_id)

        with timing.add_time('train'):
            discarding_rate = self._discarding_rate()

            self._update_pbt()

            train_stats = self._train(buffer, batch_size, experience_size,
                                      timing)

            if train_stats is not None:
                stats['train'] = train_stats

                if wait_stats is not None:
                    wait_avg, wait_min, wait_max = wait_stats
                    stats['train']['wait_avg'] = wait_avg
                    stats['train']['wait_min'] = wait_min
                    stats['train']['wait_max'] = wait_max

                stats['train']['discarded_rollouts'] = self.num_discarded_rollouts
                stats['train']['discarding_rate'] = discarding_rate

                stats['stats'] = memory_stats('learner', self.device)

        self.is_training = False

        try:
            self.report_queue.put(stats)
        except Full:
            log.warning('Could not report training stats, the report queue is full!')
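
For reference, `wait_stats` is consumed above as an `(avg, min, max)` tuple of the learner's wait times. A minimal sketch of how a caller might assemble it (the `batch_wait_times` list is hypothetical):

    import numpy as np

    batch_wait_times = [0.012, 0.048, 0.020]  # hypothetical per-batch wait times, in seconds
    wait_stats = (np.mean(batch_wait_times), np.min(batch_wait_times), np.max(batch_wait_times))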
Example #2
    def _run(self):
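        """Main loop of the policy worker: batch incoming inference requests and serve actions."""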
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        # run at a slightly lower scheduling priority (higher niceness) than the default
        psutil.Process().nice(min(self.cfg.default_niceness + 2, 20))

        # make only this policy's GPU visible to the process via CUDA env vars
        cuda_envvars(self.policy_id)
        # 'file_system' sharing avoids exhausting open file descriptors
        # when tensors are passed between many processes
        torch.multiprocessing.set_sharing_strategy('file_system')

        timing = Timing()

        with timing.timeit('init'):
            # initialize the Torch modules
            log.info('Initializing model on the policy worker %d-%d...',
                     self.policy_id, self.worker_idx)

            torch.set_num_threads(1)

            if self.cfg.device == 'gpu':
                # we should already see only one CUDA device, because of env vars
                assert torch.cuda.device_count() == 1
                self.device = torch.device('cuda', index=0)
            else:
                self.device = torch.device('cpu')

            self.actor_critic = create_actor_critic(self.cfg, self.obs_space,
                                                    self.action_space, timing)
            self.actor_critic.model_to_device(self.device)
            for p in self.actor_critic.parameters():
                p.requires_grad = False  # we don't train anything here

            log.info('Initialized model on the policy worker %d-%d!',
                     self.policy_id, self.worker_idx)

        last_report = last_cache_cleanup = time.time()
        last_report_samples = 0
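        # rolling window of recent request batch sizes, reported as avg_request_count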
        request_count = deque(maxlen=50)

        # Very conservative limit on the minimum number of requests to wait for.
        # This all but guarantees that the system keeps collecting experience at the
        # max rate even when 2/3 of the workers are stuck for some reason (e.g. a long env reset).
        # If your workflow involves very lengthy operations that often freeze workers, it can be
        # beneficial to set min_num_requests to 1, at the cost of potential inefficiency
        # (the policy worker will process very small batches).
        min_num_requests = self.cfg.num_workers // (self.cfg.num_policies * self.cfg.policy_workers_per_policy)
        min_num_requests //= 3
        min_num_requests = max(1, min_num_requests)
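        # e.g. with a hypothetical config of num_workers=72, num_policies=1,
        # policy_workers_per_policy=2: 72 // (1 * 2) = 36, then 36 // 3 = 12,
        # so this worker batches at least 12 requests before stepping the policy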

        # Again, very conservative timer. Only wait a little bit, then continue operation.
        wait_for_min_requests = 0.025

        while not self.terminate:
            try:
                while self.stop_experience_collection[self.policy_id]:
                    with self.resume_experience_collection_cv:
                        self.resume_experience_collection_cv.wait(timeout=0.05)

                waiting_started = time.time()
                while len(self.requests) < min_num_requests and time.time() - waiting_started < wait_for_min_requests:
                    try:
                        with timing.timeit('wait_policy'), timing.add_time('wait_policy_total'):
                            policy_requests = self.policy_queue.get_many(timeout=0.005)
                        self.requests.extend(policy_requests)
                    except Empty:
                        pass

                # pick up new policy weights if the learner has published an update
                self._update_weights(timing)

                with timing.timeit('one_step'), timing.add_time('handle_policy_step'):
                    if self.initialized:
                        if len(self.requests) > 0:
                            request_count.append(len(self.requests))
                            self._handle_policy_steps(timing)

                try:
                    task_type, data = self.task_queue.get_nowait()

                    # task from the task_queue
                    if task_type == TaskType.INIT:
                        self._init()
                    elif task_type == TaskType.TERMINATE:
                        self.terminate = True
                        break
                    elif task_type == TaskType.INIT_MODEL:
                        self._init_model(data)

                    self.task_queue.task_done()
                except Empty:
                    pass

                if time.time() - last_report > 3.0 and 'one_step' in timing:
                    timing_stats = dict(wait_policy=timing.wait_policy, step_policy=timing.one_step)
                    samples_since_last_report = self.total_num_samples - last_report_samples

                    stats = memory_stats('policy_worker', self.device)
                    if len(request_count) > 0:
                        stats['avg_request_count'] = np.mean(request_count)

                    self.report_queue.put(
                        dict(
                            timing=timing_stats,
                            samples=samples_since_last_report,
                            policy_id=self.policy_id,
                            stats=stats,
                        ))
                    last_report = time.time()
                    last_report_samples = self.total_num_samples

                # release cached GPU memory every ~5 minutes; also clean up on every
                # iteration during the first 1000 samples, while allocation patterns
                # are still settling (skipped in benchmark mode)
                if time.time() - last_cache_cleanup > 300.0 or (not self.cfg.benchmark and self.total_num_samples < 1000):
                    if self.cfg.device == 'gpu':
                        torch.cuda.empty_cache()
                    last_cache_cleanup = time.time()

            except KeyboardInterrupt:
                log.warning('Keyboard interrupt detected on worker %d-%d',
                            self.policy_id, self.worker_idx)
                self.terminate = True
            except Exception:
                log.exception('Unknown exception on policy worker')
                self.terminate = True

        time.sleep(0.2)
        log.info('Policy worker avg. requests %.2f, timing: %s',
                 np.mean(request_count), timing)
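
The `torch.cuda.device_count() == 1` assertion above holds because `cuda_envvars` pins the process to a single GPU before torch initializes CUDA. The helper's real implementation lives elsewhere in the codebase; a minimal sketch of the idea, assuming a simple round-robin mapping of policies to available GPUs (the `num_gpus` parameter is hypothetical):

    import os

    def cuda_envvars(policy_id, num_gpus=1):
        # hypothetical round-robin pinning: the worker process then sees exactly
        # one CUDA device, which inside the process is always torch.device('cuda', 0)
        os.environ['CUDA_VISIBLE_DEVICES'] = str(policy_id % num_gpus)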