Example #1
    def __init__(self, cfg):
        super().__init__(cfg)

        self.processes = []
        self.terminate = RawValue(ctypes.c_bool, False)

        self.start_event = multiprocessing.Event()
        self.start_event.clear()

        self.report_queue = MpQueue()
        self.report_every_sec = 1.0
        self.last_report = 0

        self.avg_stats_intervals = (1, 10, 60, 300, 600)
        self.fps_stats = deque([], maxlen=max(self.avg_stats_intervals))
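The fields above (report_every_sec, avg_stats_intervals, fps_stats) support a sliding-window FPS estimate: each report appends a (timestamp, total_frames) pair to the deque, and the rate over each interval is measured against an older entry. A minimal self-contained sketch of that pattern (the standalone names here are illustrative, not part of the original class):

import time
from collections import deque

avg_stats_intervals = (1, 10, 60, 300, 600)            # measured in report periods (report_every_sec each)
fps_stats = deque([], maxlen=max(avg_stats_intervals))

def report_fps(total_frames):
    now = time.time()
    fps_stats.append((now, total_frames))
    if len(fps_stats) <= 1:
        return []  # need at least two samples to estimate a rate

    fps = []
    for interval in avg_stats_intervals:
        # look back `interval` reports, clamped to the oldest entry still in the deque
        past_moment, past_frames = fps_stats[max(0, len(fps_stats) - 1 - interval)]
        fps.append((total_frames - past_frames) / (now - past_moment))
    return fps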
Example #2
class APPO(ReinforcementLearningAlgorithm):
    """Async PPO."""
    @classmethod
    def add_cli_args(cls, parser):
        p = parser
        super().add_cli_args(p)
        p.add_argument(
            '--experiment_summaries_interval',
            default=20,
            type=int,
            help=
            'How often in seconds we write avg. statistics about the experiment (reward, episode length, extra stats...)'
        )

        p.add_argument(
            '--adam_eps',
            default=1e-6,
            type=float,
            help=
            'Adam epsilon parameter (1e-8 to 1e-5 seem to reliably work okay, 1e-3 and up does not work)'
        )
        p.add_argument('--adam_beta1',
                       default=0.9,
                       type=float,
                       help='Adam momentum decay coefficient')
        p.add_argument('--adam_beta2',
                       default=0.999,
                       type=float,
                       help='Adam second momentum decay coefficient')

        p.add_argument(
            '--gae_lambda',
            default=0.95,
            type=float,
            help=
            'Generalized Advantage Estimation discounting (only used when V-trace is False)'
        )

        p.add_argument(
            '--rollout',
            default=32,
            type=int,
            help='Length of the rollout from each environment in timesteps. '
            'Once we collect this many timesteps on the actor worker, we send this trajectory to the learner. '
            'The length of the rollout determines how many timesteps are used to calculate bootstrapped '
            'Monte-Carlo estimates of discounted rewards, advantages, GAE, or V-trace targets. Shorter rollouts '
            'reduce variance, but the estimates are less precise (bias vs variance tradeoff). '
            'For RNN policies, this should be a multiple of --recurrence, so every rollout will be split '
            'into (n = rollout / recurrence) segments for backpropagation. The V-trace algorithm currently requires that '
            'rollout == recurrence, which is what you want most of the time anyway. '
            'Rollout length is independent of the episode length. Episodes can be shorter or longer than '
            'the rollout, although for PBT training it is currently recommended that rollout << episode_len '
            '(see function finalize_trajectory in actor_worker.py)',
        )

        p.add_argument(
            '--num_workers',
            default=multiprocessing.cpu_count(),
            type=int,
            help=
            'Number of parallel environment workers. Should be less than num_envs and should divide num_envs'
        )

        p.add_argument(
            '--recurrence',
            default=32,
            type=int,
            help=
            'Trajectory length for backpropagation through time. If recurrence=1 there is no backpropagation through time, and experience is shuffled completely randomly. '
            'For V-trace, recurrence should be equal to the rollout length.',
        )

        p.add_argument('--use_rnn',
                       default=True,
                       type=str2bool,
                       help='Whether to use RNN core in a policy or not')
        p.add_argument('--rnn_type',
                       default='gru',
                       choices=['gru', 'lstm'],
                       type=str,
                       help='Type of RNN cell to use if use_rnn is True')

        p.add_argument(
            '--ppo_clip_ratio',
            default=0.1,
            type=float,
            help=
            'We use the unbiased clip(x, 1/(1+e), 1+e) instead of the clip(x, 1-e, 1+e) from the PPO paper'
        )
        p.add_argument(
            '--ppo_clip_value',
            default=1.0,
            type=float,
            help=
            'Maximum absolute change in value estimate until it is clipped. Sensitive to value magnitude'
        )
        p.add_argument('--batch_size',
                       default=1024,
                       type=int,
                       help='Minibatch size for SGD')
        p.add_argument(
            '--num_batches_per_iteration',
            default=1,
            type=int,
            help=
            'How many minibatches we collect before training on the collected experience. It is generally recommended to set this to 1 for most experiments, because any higher value will increase the policy lag. '
            'But in some specific circumstances it can be beneficial to have a larger macro-batch in order to shuffle and decorrelate the minibatches. '
            'Here and throughout the codebase: a macro batch is the portion of experience that the learner processes per iteration (consisting of 1 or several minibatches)',
        )
        p.add_argument(
            '--ppo_epochs',
            default=1,
            type=int,
            help=
            'Number of training epochs before a new batch of experience is collected'
        )

        p.add_argument(
            '--num_minibatches_to_accumulate',
            default=-1,
            type=int,
            help=
            'This parameter governs the maximum number of minibatches the learner can accumulate before further experience collection is stopped. '
            'The default value (-1) will set this to 2 * num_batches_per_iteration, so if the experience collection is faster than the training, '
            'the learner will accumulate enough minibatches for 2 iterations of training (but no more). This is a good balance between policy-lag and throughput. '
            'When the limit is reached, the learner will notify the policy workers that they ought to stop the experience collection until accumulated minibatches '
            'are processed. Set this parameter to 1 * num_batches_per_iteration to further reduce policy-lag. '
            'If the experience collection is very non-uniform, increasing this parameter can increase overall throughput, at the cost of increased policy-lag. '
            'A value of 0 is treated specially: experience accumulation is turned off, and all experience collection is halted during training. '
            'This is the regime with potentially the lowest policy-lag. '
            'When this parameter is 0 and num_workers * num_envs_per_worker * rollout == num_batches_per_iteration * batch_size, the algorithm is similar to '
            'regular synchronous PPO.',
        )

        p.add_argument('--max_grad_norm',
                       default=4.0,
                       type=float,
                       help='Max L2 norm of the gradient vector')

        # components of the loss function
        p.add_argument(
            '--entropy_loss_coeff',
            default=0.003,
            type=float,
            help=
            'Coefficient for the exploration component of the loss function.')
        p.add_argument('--value_loss_coeff',
                       default=0.5,
                       type=float,
                       help='Coefficient for the critic loss')

        # APPO-specific
        p.add_argument(
            '--num_envs_per_worker',
            default=2,
            type=int,
            help=
            'Number of envs on a single CPU actor; in high-throughput configurations this should be in the 10-30 range for Atari/VizDoom. '
            'Must be even for double-buffered sampling!',
        )
        p.add_argument(
            '--worker_num_splits',
            default=2,
            type=int,
            help=
            'Typically we split a vector of envs into two parts for "double buffered" experience collection. '
            'Set this to 1 to disable double buffering. Set this to 3 for triple buffering!',
        )

        p.add_argument('--num_policies',
                       default=1,
                       type=int,
                       help='Number of policies to train jointly')
        p.add_argument(
            '--policy_workers_per_policy',
            default=1,
            type=int,
            help=
            'Number of policy workers that compute forward pass (per policy)')
        p.add_argument(
            '--max_policy_lag',
            default=100,
            type=int,
            help=
            'Max policy lag in policy versions. Discard all experience that is older than this. This should be increased for configurations with multiple epochs of SGD, because naturally '
            'policy-lag may exceed this value.',
        )
        p.add_argument(
            '--min_traj_buffers_per_worker',
            default=2,
            type=int,
            help=
            'How many shared rollout tensors to allocate per actor worker to exchange information between actors and learners. '
            'The default value of 2 is fine for most workloads, except when differences in 1-step simulation time are extreme, like with some DMLab environments. '
            'If you see a lot of warnings about actor workers having to wait for trajectory buffers, try increasing this to 4-6; this should eliminate the problem at the cost of more RAM.',
        )
        p.add_argument(
            '--decorrelate_experience_max_seconds',
            default=10,
            type=int,
            help=
            'Decorrelating experience has two benefits. First, it is better for learning because samples from workers come from random moments in the episode, making them more "i.i.d.". '
            'Second, and more importantly, it helps environments with highly non-uniform one-step times, including long and expensive episode resets. If experience is not decorrelated, '
            'training batches will come in bursts, e.g. after a bunch of environments finished resets, and many iterations on the learner might be required, '
            'which will increase the policy-lag of the newly collected experience. Sample Factory performs best when experience is generated as a more-or-less '
            'uniform stream. Try increasing this to 100-200 seconds to smooth out the experience distribution in time right from the beginning (it will eventually spread out and settle anyway)',
        )
        p.add_argument(
            '--decorrelate_envs_on_one_worker',
            default=True,
            type=str2bool,
            help=
            'In addition to temporal decorrelation of worker processes, also decorrelate envs within one worker process. '
            'For environments with a fixed episode length it can prevent the reset from happening in the same rollout for all envs simultaneously, which makes experience collection more uniform.',
        )

        p.add_argument(
            '--with_vtrace',
            default=True,
            type=str2bool,
            help=
            'Enables V-trace off-policy correction. If this is True, then GAE is not used'
        )
        p.add_argument(
            '--vtrace_rho',
            default=1.0,
            type=float,
            help=
            'rho_hat clipping parameter of the V-trace algorithm (importance sampling truncation)'
        )
        p.add_argument(
            '--vtrace_c',
            default=1.0,
            type=float,
            help=
            'c_hat clipping parameter of the V-trace algorithm. Low values for c_hat can reduce variance of the advantage estimates (similar to GAE lambda < 1)'
        )

        p.add_argument(
            '--set_workers_cpu_affinity',
            default=True,
            type=str2bool,
            help=
            'Whether to assign workers to specific CPU cores or not. The logic is beneficial for most workloads because it prevents a lot of context switching. '
            'However, for some environments it can be better to disable it, to allow one worker to use all cores some of the time. This can be the case for some DMLab environments with very expensive episode resets '
            'that can use parallel CPU cores for level generation.',
        )
        p.add_argument(
            '--force_envs_single_thread',
            default=True,
            type=str2bool,
            help=
            'Some environments may themselves use parallel libraries such as OpenMP or MKL. Since we parallelize environments on the level of workers, there is no need for this kind of parallelism. '
            'This flag uses threadpoolctl to force libraries such as OpenMP and MKL to use only a single thread within the environment. '
            'The default value (True) is recommended unless you are running fewer workers than CPU cores.',
        )
        p.add_argument(
            '--reset_timeout_seconds',
            default=120,
            type=int,
            help=
            'Fail worker on initialization if not a single environment was reset in this time (worker probably got stuck)'
        )

        p.add_argument(
            '--default_niceness',
            default=0,
            type=int,
            help=
            'Niceness of the highest priority process (the learner). Values below zero require elevated privileges.'
        )

        p.add_argument(
            '--train_in_background_thread',
            default=True,
            type=str2bool,
            help=
            'Using a background thread for training is faster and allows preparing the next batch while training is in progress. '
            'Unfortunately, debugging can become very tricky in this case, so there is an option to use only a single thread on the learner to simplify debugging.',
        )
        p.add_argument(
            '--learner_main_loop_num_cores',
            default=1,
            type=int,
            help=
            'When batching on the learner is the bottleneck, increasing the number of cores PyTorch uses can improve the performance'
        )
        p.add_argument(
            '--actor_worker_gpus',
            default=[],
            type=int,
            nargs='*',
            help=
            'By default, actor workers only use CPUs. Change this if, e.g., you need GPU-based rendering on the actors'
        )

        # PBT stuff
        p.add_argument('--with_pbt',
                       default=False,
                       type=str2bool,
                       help='Enables population-based training basic features')
        p.add_argument(
            '--pbt_period_env_steps',
            default=int(5e6),
            type=int,
            help=
            'Periodically replace the worst policies with the best ones and perturb the hyperparameters'
        )
        p.add_argument(
            '--pbt_start_mutation',
            default=int(2e7),
            type=int,
            help=
            'Allow initial diversification, start PBT after this many env steps'
        )
        p.add_argument(
            '--pbt_replace_fraction',
            default=0.3,
            type=float,
            help=
            'The fraction of worst-performing policies to be replaced by better policies (rounded up)'
        )
        p.add_argument('--pbt_mutation_rate',
                       default=0.15,
                       type=float,
                       help='Probability that a parameter mutates')
        p.add_argument(
            '--pbt_replace_reward_gap',
            default=0.1,
            type=float,
            help=
            'Relative gap in true reward when replacing weights of the policy with a better performing one'
        )
        p.add_argument(
            '--pbt_replace_reward_gap_absolute',
            default=1e-6,
            type=float,
            help=
            'Absolute gap in true reward when replacing weights of the policy with a better performing one'
        )
        p.add_argument(
            '--pbt_optimize_batch_size',
            default=False,
            type=str2bool,
            help='Whether to optimize batch size or not (experimental)')
        p.add_argument(
            '--pbt_target_objective',
            default='true_reward',
            type=str,
            help=
            'Policy stat to optimize with PBT. true_reward (default) is equal to the raw env reward if not specified, but can also be any other per-policy stat. '
            'For DMLab-30, use the value "dmlab_target_objective" (which is the capped human-normalized score)',
        )

        # debugging options
        p.add_argument('--benchmark',
                       default=False,
                       type=str2bool,
                       help='Benchmark mode')
        p.add_argument(
            '--sampler_only',
            default=False,
            type=str2bool,
            help=
            'Do not send experience to the learner; only measure sampling throughput'
        )

    def __init__(self, cfg):
        super().__init__(cfg)

        # we should not use CUDA in the main thread, only on the workers
        set_global_cuda_envvars(cfg)

        tmp_env = make_env_func(self.cfg, env_config=None)
        self.obs_space = tmp_env.observation_space
        self.action_space = tmp_env.action_space
        self.num_agents = tmp_env.num_agents

        self.reward_shaping_scheme = None
        if self.cfg.with_pbt:
            if hasattr(tmp_env.unwrapped, '_reward_shaping_wrapper'):
                # noinspection PyProtectedMember
                self.reward_shaping_scheme = tmp_env.unwrapped._reward_shaping_wrapper.reward_shaping_scheme
            else:
                try:
                    from envs.doom.multiplayer.doom_multiagent_wrapper import MultiAgentEnv
                    if isinstance(tmp_env.unwrapped, MultiAgentEnv):
                        self.reward_shaping_scheme = tmp_env.unwrapped.default_reward_shaping
                except ImportError:
                    pass

        tmp_env.close()

        # shared memory allocation
        self.traj_buffers = SharedBuffers(self.cfg, self.num_agents,
                                          self.obs_space, self.action_space)

        self.actor_workers = None

        self.report_queue = MpQueue(20 * 1000 * 1000)
        self.policy_workers = dict()
        self.policy_queues = dict()

        self.learner_workers = dict()

        self.workers_by_handle = None

        self.policy_inputs = [[] for _ in range(self.cfg.num_policies)]
        self.policy_outputs = dict()
        for worker_idx in range(self.cfg.num_workers):
            for split_idx in range(self.cfg.worker_num_splits):
                self.policy_outputs[(worker_idx, split_idx)] = dict()

        self.policy_avg_stats = dict()
        self.policy_lag = [dict() for _ in range(self.cfg.num_policies)]

        self.last_timing = dict()
        self.env_steps = dict()
        self.samples_collected = [0 for _ in range(self.cfg.num_policies)]
        self.total_env_steps_since_resume = 0

        # currently this applies only to the current run, not experiment as a whole
        # to change this behavior we'd need to save the state of the main loop to a filesystem
        self.total_train_seconds = 0

        self.last_report = time.time()
        self.last_experiment_summaries = 0

        self.report_interval = 5.0  # sec
        self.experiment_summaries_interval = self.cfg.experiment_summaries_interval  # sec

        self.avg_stats_intervals = (2, 12, 60)  # 10 seconds, 1 minute, 5 minutes

        self.fps_stats = deque([], maxlen=max(self.avg_stats_intervals))
        self.throughput_stats = [
            deque([], maxlen=5) for _ in range(self.cfg.num_policies)
        ]
        self.avg_stats = dict()
        self.stats = dict()  # regular (non-averaged) stats

        self.writers = dict()
        writer_keys = list(range(self.cfg.num_policies))
        for key in writer_keys:
            summary_dir = join(summaries_dir(experiment_dir(cfg=self.cfg)),
                               str(key))
            summary_dir = ensure_dir_exists(summary_dir)
            self.writers[key] = SummaryWriter(summary_dir, flush_secs=20)

        self.pbt = PopulationBasedTraining(self.cfg,
                                           self.reward_shaping_scheme,
                                           self.writers)

    def _cfg_dict(self):
        if isinstance(self.cfg, dict):
            return self.cfg
        else:
            return vars(self.cfg)

    def _save_cfg(self):
        cfg_dict = self._cfg_dict()
        with open(cfg_file(self.cfg), 'w') as json_file:
            json.dump(cfg_dict, json_file, indent=2)

    def initialize(self):
        self._save_cfg()
        save_git_diff(experiment_dir(cfg=self.cfg))

    def finalize(self):
        pass

    def create_actor_worker(self, idx, actor_queue):
        learner_queues = {
            p: w.task_queue
            for p, w in self.learner_workers.items()
        }

        return ActorWorker(
            self.cfg,
            self.obs_space,
            self.action_space,
            self.num_agents,
            idx,
            self.traj_buffers,
            task_queue=actor_queue,
            policy_queues=self.policy_queues,
            report_queue=self.report_queue,
            learner_queues=learner_queues,
        )

    # noinspection PyProtectedMember
    def init_subset(self, indices, actor_queues):
        """
        Initialize a subset of actor workers (rollout workers) and wait until the first reset() is completed for all
        envs on these workers.

        This function will retry if the worker process crashes during the initial reset.

        :param indices: indices of actor workers to initialize
        :param actor_queues: task queues corresponding to these workers
        :return: initialized workers
        """

        reset_timelimit_seconds = self.cfg.reset_timeout_seconds  # fail worker if not a single env was reset in that time

        workers = dict()
        last_env_initialized = dict()
        for i in indices:
            w = self.create_actor_worker(i, actor_queues[i])
            w.init()
            w.request_reset()
            workers[i] = w
            last_env_initialized[i] = time.time()

        total_num_envs = self.cfg.num_workers * self.cfg.num_envs_per_worker
        envs_initialized = [0] * self.cfg.num_workers
        workers_finished = set()

        while len(workers_finished) < len(workers):
            failed_worker = -1

            try:
                report = self.report_queue.get(timeout=1.0)

                if 'initialized_env' in report:
                    worker_idx, split_idx, env_i = report['initialized_env']
                    last_env_initialized[worker_idx] = time.time()
                    envs_initialized[worker_idx] += 1

                    log.debug(
                        'Progress for %d workers: %d/%d envs initialized...',
                        len(indices),
                        sum(envs_initialized),
                        total_num_envs,
                    )
                elif 'finished_reset' in report:
                    workers_finished.add(report['finished_reset'])
                elif 'critical_error' in report:
                    failed_worker = report['critical_error']
            except Empty:
                pass

            for worker_idx, w in workers.items():
                if worker_idx in workers_finished:
                    continue

                time_passed = time.time() - last_env_initialized[worker_idx]
                timeout = time_passed > reset_timelimit_seconds

                if timeout or failed_worker == worker_idx or not w.process.is_alive():
                    envs_initialized[worker_idx] = 0

                    log.error('Worker %d is stuck or failed (%.3f). Reset!',
                              w.worker_idx, time_passed)
                    log.debug('Status: %r', w.process.is_alive())
                    stuck_worker = w
                    stuck_worker.process.kill()

                    new_worker = self.create_actor_worker(
                        worker_idx, actor_queues[worker_idx])
                    new_worker.init()
                    new_worker.request_reset()

                    last_env_initialized[worker_idx] = time.time()
                    workers[worker_idx] = new_worker
                    del stuck_worker

        return workers.values()

    # noinspection PyUnresolvedReferences
    def init_workers(self):
        """
        Initialize all types of workers and start their worker processes.
        """

        actor_queues = [MpQueue() for _ in range(self.cfg.num_workers)]

        policy_worker_queues = dict()
        for policy_id in range(self.cfg.num_policies):
            policy_worker_queues[policy_id] = []
            for i in range(self.cfg.policy_workers_per_policy):
                policy_worker_queues[policy_id].append(TorchJoinableQueue())

        log.info('Initializing learners...')
        policy_locks = [
            multiprocessing.Lock() for _ in range(self.cfg.num_policies)
        ]
        resume_experience_collection_cv = [
            multiprocessing.Condition() for _ in range(self.cfg.num_policies)
        ]

        learner_idx = 0
        for policy_id in range(self.cfg.num_policies):
            learner_worker = LearnerWorker(
                learner_idx,
                policy_id,
                self.cfg,
                self.obs_space,
                self.action_space,
                self.report_queue,
                policy_worker_queues[policy_id],
                self.traj_buffers,
                policy_locks[policy_id],
                resume_experience_collection_cv[policy_id],
            )
            learner_worker.start_process()
            learner_worker.init()

            self.learner_workers[policy_id] = learner_worker
            learner_idx += 1

        log.info('Initializing policy workers...')
        for policy_id in range(self.cfg.num_policies):
            self.policy_workers[policy_id] = []

            policy_queue = MpQueue()
            self.policy_queues[policy_id] = policy_queue

            for i in range(self.cfg.policy_workers_per_policy):
                policy_worker = PolicyWorker(
                    i,
                    policy_id,
                    self.cfg,
                    self.obs_space,
                    self.action_space,
                    self.traj_buffers,
                    policy_queue,
                    actor_queues,
                    self.report_queue,
                    policy_worker_queues[policy_id][i],
                    policy_locks[policy_id],
                    resume_experience_collection_cv[policy_id],
                )
                self.policy_workers[policy_id].append(policy_worker)
                policy_worker.start_process()

        log.info('Initializing actors...')

        # We support actor worker initialization in groups, which can be useful for some envs that
        # e.g. crash when too many environments are being initialized in parallel.
        # Currently the limit is not used since it is not required for any envs supported out of the box,
        # so we parallelize initialization as hard as we can.
        # If this is required for your environment, perhaps a better solution would be to use global locks,
        # like FileLock (see doom_gym.py)
        self.actor_workers = []
        max_parallel_init = int(1e9)  # might be useful to limit this for some envs
        worker_indices = list(range(self.cfg.num_workers))
        for i in range(0, self.cfg.num_workers, max_parallel_init):
            workers = self.init_subset(worker_indices[i:i + max_parallel_init],
                                       actor_queues)
            self.actor_workers.extend(workers)

    def init_pbt(self):
        if self.cfg.with_pbt:
            self.pbt.init(self.learner_workers, self.actor_workers)

    def finish_initialization(self):
        """Wait until policy workers are fully initialized."""
        for policy_id, workers in self.policy_workers.items():
            for w in workers:
                log.debug(
                    'Waiting for policy worker %d-%d to finish initialization...',
                    policy_id, w.worker_idx)
                w.init()
                log.debug('Policy worker %d-%d initialized!', policy_id,
                          w.worker_idx)

    def process_report(self, report):
        """Process stats from various types of workers."""

        if 'policy_id' in report:
            policy_id = report['policy_id']

            if 'learner_env_steps' in report:
                if policy_id in self.env_steps:
                    delta = report['learner_env_steps'] - self.env_steps[
                        policy_id]
                    self.total_env_steps_since_resume += delta
                self.env_steps[policy_id] = report['learner_env_steps']

            if 'episodic' in report:
                s = report['episodic']
                for _, key, value in iterate_recursively(s):
                    if key not in self.policy_avg_stats:
                        self.policy_avg_stats[key] = [
                            deque(maxlen=self.cfg.stats_avg)
                            for _ in range(self.cfg.num_policies)
                        ]

                    self.policy_avg_stats[key][policy_id].append(value)

                    for extra_stat_func in EXTRA_EPISODIC_STATS_PROCESSING:
                        extra_stat_func(policy_id, key, value, self.cfg)

            if 'train' in report:
                self.report_train_summaries(report['train'], policy_id)

            if 'samples' in report:
                self.samples_collected[policy_id] += report['samples']

        if 'timing' in report:
            for k, v in report['timing'].items():
                if k not in self.avg_stats:
                    self.avg_stats[k] = deque([], maxlen=50)
                self.avg_stats[k].append(v)

        if 'stats' in report:
            self.stats.update(report['stats'])

    def report(self):
        """
        Called periodically (every X seconds, see report_interval).
        Print experiment stats (FPS, avg rewards) to console and dump TF summaries collected from workers to disk.
        """

        if len(self.env_steps) < self.cfg.num_policies:
            return

        now = time.time()
        self.fps_stats.append((now, self.total_env_steps_since_resume))
        if len(self.fps_stats) <= 1:
            return

        fps = []
        for avg_interval in self.avg_stats_intervals:
            past_moment, past_frames = self.fps_stats[max(
                0,
                len(self.fps_stats) - 1 - avg_interval)]
            fps.append((self.total_env_steps_since_resume - past_frames) /
                       (now - past_moment))

        sample_throughput = dict()
        for policy_id in range(self.cfg.num_policies):
            self.throughput_stats[policy_id].append(
                (now, self.samples_collected[policy_id]))
            if len(self.throughput_stats[policy_id]) > 1:
                past_moment, past_samples = self.throughput_stats[policy_id][0]
                sample_throughput[policy_id] = (
                    self.samples_collected[policy_id] -
                    past_samples) / (now - past_moment)
            else:
                sample_throughput[policy_id] = math.nan

        total_env_steps = sum(self.env_steps.values())
        self.print_stats(fps, sample_throughput, total_env_steps)

        if time.time() - self.last_experiment_summaries > self.experiment_summaries_interval:
            self.report_experiment_summaries(fps[0], sample_throughput)
            self.last_experiment_summaries = time.time()

    def print_stats(self, fps, sample_throughput, total_env_steps):
        fps_str = []
        for interval, fps_value in zip(self.avg_stats_intervals, fps):
            fps_str.append(
                f'{int(interval * self.report_interval)} sec: {fps_value:.1f}')
        fps_str = f'({", ".join(fps_str)})'

        samples_per_policy = ', '.join(
            [f'{p}: {s:.1f}' for p, s in sample_throughput.items()])

        lag_stats = self.policy_lag[0]
        lag = AttrDict()
        for key in ['min', 'avg', 'max']:
            lag[key] = lag_stats.get(f'version_diff_{key}', -1)
        policy_lag_str = f'min: {lag.min:.1f}, avg: {lag.avg:.1f}, max: {lag.max:.1f}'

        log.debug(
            'Fps is %s. Total num frames: %d. Throughput: %s. Samples: %d. Policy #0 lag: (%s)',
            fps_str,
            total_env_steps,
            samples_per_policy,
            sum(self.samples_collected),
            policy_lag_str,
        )

        if 'reward' in self.policy_avg_stats:
            policy_reward_stats = []
            for policy_id in range(self.cfg.num_policies):
                reward_stats = self.policy_avg_stats['reward'][policy_id]
                if len(reward_stats) > 0:
                    policy_reward_stats.append(
                        (policy_id, f'{np.mean(reward_stats):.3f}'))
            log.debug('Avg episode reward: %r', policy_reward_stats)

    def report_train_summaries(self, stats, policy_id):
        for key, scalar in stats.items():
            self.writers[policy_id].add_scalar(f'train/{key}', scalar,
                                               self.env_steps[policy_id])
            if 'version_diff' in key:
                self.policy_lag[policy_id][key] = scalar

    def report_experiment_summaries(self, fps, sample_throughput):
        memory_mb = memory_consumption_mb()

        default_policy = 0
        for policy_id, env_steps in self.env_steps.items():
            if policy_id == default_policy:
                self.writers[policy_id].add_scalar('0_aux/_fps', fps,
                                                   env_steps)
                self.writers[policy_id].add_scalar(
                    '0_aux/master_process_memory_mb', float(memory_mb),
                    env_steps)
                for key, value in self.avg_stats.items():
                    if len(value) >= value.maxlen or (
                            len(value) > 10
                            and self.total_train_seconds > 300):
                        self.writers[policy_id].add_scalar(
                            f'stats/{key}', np.mean(value), env_steps)

                for key, value in self.stats.items():
                    self.writers[policy_id].add_scalar(f'stats/{key}', value,
                                                       env_steps)

            if not math.isnan(sample_throughput[policy_id]):
                self.writers[policy_id].add_scalar(
                    '0_aux/_sample_throughput', sample_throughput[policy_id],
                    env_steps)

            for key, stat in self.policy_avg_stats.items():
                if len(stat[policy_id]) >= stat[policy_id].maxlen or (
                        len(stat[policy_id]) > 10
                        and self.total_train_seconds > 300):
                    stat_value = np.mean(stat[policy_id])
                    writer = self.writers[policy_id]
                    writer.add_scalar(f'0_aux/avg_{key}', float(stat_value),
                                      env_steps)

                    # for key stats report min/max as well
                    if key in ('reward', 'true_reward', 'len'):
                        writer.add_scalar(f'0_aux/avg_{key}_min',
                                          float(min(stat[policy_id])),
                                          env_steps)
                        writer.add_scalar(f'0_aux/avg_{key}_max',
                                          float(max(stat[policy_id])),
                                          env_steps)

            for extra_summaries_func in EXTRA_PER_POLICY_SUMMARIES:
                extra_summaries_func(policy_id, self.policy_avg_stats,
                                     env_steps, self.writers[policy_id],
                                     self.cfg)

    def _should_end_training(self):
        end = len(self.env_steps) > 0 and all(s > self.cfg.train_for_env_steps
                                              for s in self.env_steps.values())
        end |= self.total_train_seconds > self.cfg.train_for_seconds

        if self.cfg.benchmark:
            end |= self.total_env_steps_since_resume >= int(2e6)
            end |= sum(self.samples_collected) >= int(1e6)

        return end

    def run(self):
        """
        This function contains the main loop of the algorithm, as well as initialization/cleanup code.

        :return: ExperimentStatus (SUCCESS, FAILURE, INTERRUPTED). Useful in testing.
        """

        status = ExperimentStatus.SUCCESS

        if os.path.isfile(done_filename(self.cfg)):
            log.warning(
                'Training already finished! Remove "done" file to continue training'
            )
            return status

        self.init_workers()
        self.init_pbt()
        self.finish_initialization()

        log.info('Collecting experience...')

        timing = Timing()
        with timing.timeit('experience'):
            # noinspection PyBroadException
            try:
                while not self._should_end_training():
                    try:
                        reports = self.report_queue.get_many(timeout=0.1)
                        for report in reports:
                            self.process_report(report)
                    except Empty:
                        pass

                    if time.time() - self.last_report > self.report_interval:
                        self.report()

                        now = time.time()
                        self.total_train_seconds += now - self.last_report
                        self.last_report = now

                    self.pbt.update(self.env_steps, self.policy_avg_stats)

            except Exception:
                log.exception('Exception in driver loop')
                status = ExperimentStatus.FAILURE
            except KeyboardInterrupt:
                log.warning(
                    'Keyboard interrupt detected in driver loop, exiting...')
                status = ExperimentStatus.INTERRUPTED

        for learner in self.learner_workers.values():
            # timeout is needed here because some environments may crash on KeyboardInterrupt (e.g. VizDoom)
            # Therefore the learner train loop will never do another iteration and will never save the model.
            # This is not an issue with normal exit, e.g. due to desired number of frames reached.
            learner.save_model(timeout=5.0)

        all_workers = self.actor_workers
        for workers in self.policy_workers.values():
            all_workers.extend(workers)
        all_workers.extend(self.learner_workers.values())

        child_processes = list_child_processes()

        time.sleep(0.1)
        log.debug('Closing workers...')
        for i, w in enumerate(all_workers):
            w.close()
            time.sleep(0.01)
        for i, w in enumerate(all_workers):
            w.join()
        log.debug('Workers joined!')

        # VizDoom processes often refuse to die for an unidentified reason, so we're force killing them with a hack
        kill_processes(child_processes)

        fps = self.total_env_steps_since_resume / timing.experience
        log.info('Collected %r, FPS: %.1f', self.env_steps, fps)
        log.info('Timing: %s', timing)

        if self._should_end_training():
            with open(done_filename(self.cfg), 'w') as fobj:
                fobj.write(f'{self.env_steps}')

        time.sleep(0.5)
        log.info('Done!')

        return status
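A hedged usage sketch (not part of the example above): the flags registered in add_cli_args are parsed into a config namespace that drives __init__, initialize, run and finalize. Environment-related flags added by the base class's add_cli_args are assumed to exist but are not shown here.

import argparse

def run_appo_sketch(argv=None):
    parser = argparse.ArgumentParser()
    APPO.add_cli_args(parser)          # registers --rollout, --recurrence, --with_vtrace, --with_pbt, ...
    cfg = parser.parse_args(argv)

    algo = APPO(cfg)                   # allocates shared buffers, queues and per-policy summary writers
    algo.initialize()                  # saves the config and git diff to the experiment directory
    status = algo.run()                # spawns workers and drives the report/PBT loop until done
    algo.finalize()
    return status

# e.g. run_appo_sketch(['--num_workers', '8', ...])  # plus whatever base-class flags (env name, etc.) are required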
Example #3
    def __init__(self, cfg):
        super().__init__(cfg)

        # we should not use CUDA in the main thread, only on the workers
        set_global_cuda_envvars(cfg)

        tmp_env = make_env_func(self.cfg, env_config=None)
        self.obs_space = tmp_env.observation_space
        self.action_space = tmp_env.action_space
        self.num_agents = tmp_env.num_agents

        self.reward_shaping_scheme = None
        if self.cfg.with_pbt:
            if hasattr(tmp_env.unwrapped, '_reward_shaping_wrapper'):
                # noinspection PyProtectedMember
                self.reward_shaping_scheme = tmp_env.unwrapped._reward_shaping_wrapper.reward_shaping_scheme
            else:
                try:
                    from envs.doom.multiplayer.doom_multiagent_wrapper import MultiAgentEnv
                    if isinstance(tmp_env.unwrapped, MultiAgentEnv):
                        self.reward_shaping_scheme = tmp_env.unwrapped.default_reward_shaping
                except ImportError:
                    pass

        tmp_env.close()

        # shared memory allocation
        self.traj_buffers = SharedBuffers(self.cfg, self.num_agents,
                                          self.obs_space, self.action_space)

        self.actor_workers = None

        self.report_queue = MpQueue(20 * 1000 * 1000)
        self.policy_workers = dict()
        self.policy_queues = dict()

        self.learner_workers = dict()

        self.workers_by_handle = None

        self.policy_inputs = [[] for _ in range(self.cfg.num_policies)]
        self.policy_outputs = dict()
        for worker_idx in range(self.cfg.num_workers):
            for split_idx in range(self.cfg.worker_num_splits):
                self.policy_outputs[(worker_idx, split_idx)] = dict()

        self.policy_avg_stats = dict()
        self.policy_lag = [dict() for _ in range(self.cfg.num_policies)]

        self.last_timing = dict()
        self.env_steps = dict()
        self.samples_collected = [0 for _ in range(self.cfg.num_policies)]
        self.total_env_steps_since_resume = 0

        # currently this applies only to the current run, not experiment as a whole
        # to change this behavior we'd need to save the state of the main loop to a filesystem
        self.total_train_seconds = 0

        self.last_report = time.time()
        self.last_experiment_summaries = 0

        self.report_interval = 5.0  # sec
        self.experiment_summaries_interval = self.cfg.experiment_summaries_interval  # sec

        self.avg_stats_intervals = (2, 12, 60)  # 10 seconds, 1 minute, 5 minutes

        self.fps_stats = deque([], maxlen=max(self.avg_stats_intervals))
        self.throughput_stats = [
            deque([], maxlen=5) for _ in range(self.cfg.num_policies)
        ]
        self.avg_stats = dict()
        self.stats = dict()  # regular (non-averaged) stats

        self.writers = dict()
        writer_keys = list(range(self.cfg.num_policies))
        for key in writer_keys:
            summary_dir = join(summaries_dir(experiment_dir(cfg=self.cfg)),
                               str(key))
            summary_dir = ensure_dir_exists(summary_dir)
            self.writers[key] = SummaryWriter(summary_dir, flush_secs=20)

        self.pbt = PopulationBasedTraining(self.cfg,
                                           self.reward_shaping_scheme,
                                           self.writers)
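One detail worth spelling out from the constructor above: avg_stats_intervals is measured in report periods, not seconds, so (2, 12, 60) with report_interval = 5.0 corresponds to the 10 second / 1 minute / 5 minute windows mentioned in the comment. A tiny illustration of that relationship:

report_interval = 5.0              # seconds between consecutive report() calls
avg_stats_intervals = (2, 12, 60)  # number of report periods per averaging window

window_seconds = [int(i * report_interval) for i in avg_stats_intervals]
print(window_seconds)              # [10, 60, 300] -> 10 seconds, 1 minute, 5 minutes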
Example #4
class DummySampler(AlgorithmBase):
    @classmethod
    def add_cli_args(cls, parser):
        p = parser
        super().add_cli_args(p)

        p.add_argument('--num_workers', default=multiprocessing.cpu_count(), type=int, help='Number of processes to use to sample the environment.')
        p.add_argument('--num_envs_per_worker', default=1, type=int, help='Number of envs on a single CPU sampled sequentially.')

        p.add_argument('--sample_env_frames', default=int(2e6), type=int, help='Stop after sampling this many env frames (this takes frameskip into account)')
        p.add_argument('--sample_env_frames_per_worker', default=int(1e5), type=int, help='Stop after sampling this many env frames per worker (this takes frameskip into account)')

        p.add_argument(
            '--set_workers_cpu_affinity', default=True, type=str2bool,
            help=(
                'Whether to assign workers to specific CPU cores or not. The logic is beneficial for most workloads because it prevents a lot of context switching. '
                'However, for some environments it can be better to disable it, to allow one worker to use all cores some of the time. This can be the case for some DMLab environments with very expensive episode resets '
                'that can use parallel CPU cores for level generation.'),
        )

    def __init__(self, cfg):
        super().__init__(cfg)

        self.processes = []
        self.terminate = RawValue(ctypes.c_bool, False)

        self.start_event = multiprocessing.Event()
        self.start_event.clear()

        self.report_queue = MpQueue()
        self.report_every_sec = 1.0
        self.last_report = 0

        self.avg_stats_intervals = (1, 10, 60, 300, 600)
        self.fps_stats = deque([], maxlen=max(self.avg_stats_intervals))

    def initialize(self):
        # creating an environment in the main process tends to fix some very weird issues further down the line
        # https://stackoverflow.com/questions/60963839/importing-opencv-after-importing-pytorch-messes-with-cpu-affinity
        # do not delete this unless you know what you're doing
        tmp_env = create_env(self.cfg.env, cfg=self.cfg, env_config=None)
        tmp_env.close()

        for i in range(self.cfg.num_workers):
            p = multiprocessing.Process(target=self.sample, args=(i, ))
            self.processes.append(p)

    def sample(self, proc_idx):
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        timing = Timing()

        from threadpoolctl import threadpool_limits
        with threadpool_limits(limits=1, user_api=None):
            if self.cfg.set_workers_cpu_affinity:
                set_process_cpu_affinity(proc_idx, self.cfg.num_workers)

            initial_cpu_affinity = psutil.Process().cpu_affinity() if platform != 'darwin' else None
            psutil.Process().nice(10)

            with timing.timeit('env_init'):
                envs = []
                env_key = ['env' for _ in range(self.cfg.num_envs_per_worker)]

                for env_idx in range(self.cfg.num_envs_per_worker):
                    global_env_id = proc_idx * self.cfg.num_envs_per_worker + env_idx
                    env_config = AttrDict(worker_index=proc_idx, vector_index=env_idx, env_id=global_env_id)

                    env = make_env_func(cfg=self.cfg, env_config=env_config)
                    log.debug('CPU affinity after create_env: %r', psutil.Process().cpu_affinity() if platform != 'darwin' else 'MacOS - None')
                    env.seed(global_env_id)
                    envs.append(env)

                    # this is to track the performance for individual DMLab levels
                    if hasattr(env.unwrapped, 'level_name'):
                        env_key[env_idx] = env.unwrapped.level_name

                episode_length = [0 for _ in envs]
                episode_lengths = [deque([], maxlen=20) for _ in envs]

            try:
                with timing.timeit('first_reset'):
                    for env_idx, env in enumerate(envs):
                        env.reset()
                        log.info('Process %d finished resetting %d/%d envs', proc_idx, env_idx + 1, len(envs))

                    self.report_queue.put(dict(proc_idx=proc_idx, finished_reset=True))

                self.start_event.wait()

                with timing.timeit('work'):
                    last_report = last_report_frames = total_env_frames = 0
                    while not self.terminate.value and total_env_frames < self.cfg.sample_env_frames_per_worker:
                        for env_idx, env in enumerate(envs):
                            actions = [env.action_space.sample() for _ in range(env.num_agents)]
                            with timing.add_time(f'{env_key[env_idx]}.step'):
                                obs, rewards, dones, infos = env.step(actions)

                            num_frames = sum([info.get('num_frames', 1) for info in infos])
                            total_env_frames += num_frames
                            episode_length[env_idx] += num_frames

                            if all(dones):
                                episode_lengths[env_idx].append(episode_length[env_idx])
                                episode_length[env_idx] = 0

                        with timing.add_time('report'):
                            now = time.time()
                            if now - last_report > self.report_every_sec:
                                last_report = now
                                frames_since_last_report = total_env_frames - last_report_frames
                                last_report_frames = total_env_frames
                                self.report_queue.put(dict(proc_idx=proc_idx, env_frames=frames_since_last_report))

                # Extra check to make sure cpu affinity is preserved throughout the execution.
                # I observed weird effect when some environments tried to alter affinity of the current process, leading
                # to decreased performance.
                # This can be caused by some interactions between deep learning libs, OpenCV, MKL, OpenMP, etc.
                # At least user should know about it if this is happening.
                cpu_affinity = psutil.Process().cpu_affinity() if platform != 'darwin' else None
                assert initial_cpu_affinity == cpu_affinity, \
                    f'Worker CPU affinity was changed from {initial_cpu_affinity} to {cpu_affinity}! ' \
                    f'This can significantly affect performance!'

            except:
                log.exception('Unknown exception')
                log.error('Unknown exception in worker %d, terminating...', proc_idx)
                self.report_queue.put(dict(proc_idx=proc_idx, crash=True))

            time.sleep(proc_idx * 0.01 + 0.01)
            log.info('Process %d finished sampling. Timing: %s', proc_idx, timing)

            for env_idx, env in enumerate(envs):
                if len(episode_lengths[env_idx]) > 0:
                    log.warning('Level %s avg episode len %d', env_key[env_idx], np.mean(episode_lengths[env_idx]))

            for env in envs:
                env.close()

    def report(self, env_frames):
        now = time.time()
        self.last_report = now

        self.fps_stats.append((now, env_frames))
        if len(self.fps_stats) <= 1:
            return

        fps = []
        for avg_interval in self.avg_stats_intervals:
            past_moment, past_frames = self.fps_stats[max(0, len(self.fps_stats) - 1 - avg_interval)]
            fps.append((env_frames - past_frames) / (now - past_moment))

        fps_str = []
        for interval, fps_value in zip(self.avg_stats_intervals, fps):
            fps_str.append(f'{int(interval * self.report_every_sec)} sec: {fps_value:.1f}')
        fps_str = f'({", ".join(fps_str)})'
        log.info('Sampling FPS: %s. Total frames collected: %d', fps_str, env_frames)

    def run(self):
        for p in self.processes:
            p.start()

        finished_reset = np.zeros([self.cfg.num_workers], dtype=bool)
        while not all(finished_reset):
            try:
                msg = self.report_queue.get(timeout=0.1)
                if 'finished_reset' in msg:
                    finished_reset[msg['proc_idx']] = True
                    log.debug('Process %d finished reset! Status %r', msg['proc_idx'], finished_reset)
            except Empty:
                pass

        log.debug('All workers finished reset!')
        time.sleep(3)
        self.start_event.set()

        start = time.time()
        env_frames = 0
        last_process_report = [time.time() for _ in self.processes]

        while not self.terminate.value:
            try:
                try:
                    msgs = self.report_queue.get_many(timeout=self.report_every_sec * 1.5)
                    for msg in msgs:
                        last_process_report[msg['proc_idx']] = time.time()

                        if 'crash' in msg:
                            self.terminate.value = True
                            log.error('Terminating due to process %d crashing...', msg['proc_idx'])
                            break

                        env_frames += msg['env_frames']

                    if env_frames >= self.cfg.sample_env_frames:
                        self.terminate.value = True
                except Empty:
                    pass
            except KeyboardInterrupt:
                self.terminate.value = True
                log.error('KeyboardInterrupt in main loop! Terminating...')
                break

            if time.time() - self.last_report > self.report_every_sec:
                self.report(env_frames)

            for proc_idx, p in enumerate(self.processes):
                delay = time.time() - last_process_report[proc_idx]
                if delay > 600:
                    # killing the whole script is the best way to know that some of the processes froze
                    log.error('Process %d had not responded in %.1f s!!! Terminating...', proc_idx, delay)
                    self.terminate.value = True

            for p in self.processes:
                if not p.is_alive():
                    self.terminate.value = True
                    log.error('Process %r died! terminating...', p)

        total_time = time.time() - start
        log.info('Collected %d frames in %.1f s, avg FPS: %.1f', env_frames, total_time, env_frames / total_time)
        log.debug('Done sampling...')

    def finalize(self):
        try:
            self.report_queue.get_many_nowait()
        except Empty:
            pass

        log.debug('Joining worker processes...')
        for p in self.processes:
            p.join()
        log.debug('Done joining!')
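A hedged sketch of driving this sampler end to end: parse the flags defined in add_cli_args, then call initialize, run, and finalize. Any base-class arguments (e.g. the env name consumed by create_env) come from AlgorithmBase.add_cli_args and are assumed here rather than shown.

import argparse

def run_dummy_sampler_sketch(argv=None):
    parser = argparse.ArgumentParser()
    DummySampler.add_cli_args(parser)   # --num_workers, --num_envs_per_worker, --sample_env_frames, ...
    cfg = parser.parse_args(argv)

    sampler = DummySampler(cfg)
    sampler.initialize()                # creates a throwaway env and the (not yet started) worker processes
    sampler.run()                       # starts workers, waits for resets, then reports sampling FPS
    sampler.finalize()                  # drains the report queue and joins the workers

# e.g. run_dummy_sampler_sketch(['--num_workers', '4', ...])  # plus the base-class flags (env name, etc.)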
Example #5
    def __init__(
        self, worker_idx, policy_id, cfg, obs_space, action_space, report_queue, policy_worker_queues, shared_buffers,
        policy_lock, resume_experience_collection_cv,
    ):
        log.info('Initializing the learner %d for policy %d', worker_idx, policy_id)

        self.worker_idx = worker_idx
        self.policy_id = policy_id

        self.cfg = cfg

        # PBT-related stuff
        self.should_save_model = True  # set to true if we need to save the model to disk on the next training iteration
        self.load_policy_id = None  # non-None when we need to replace our parameters with another policy's parameters
        self.pbt_mutex = threading.Lock()
        self.new_cfg = None  # non-None when we need to update the learning hyperparameters

        self.terminate = False
        self.num_batches_processed = 0

        self.obs_space = obs_space
        self.action_space = action_space

        self.rollout_tensors = shared_buffers.tensor_trajectories
        self.traj_tensors_available = shared_buffers.is_traj_tensor_available
        self.policy_versions = shared_buffers.policy_versions
        self.stop_experience_collection = shared_buffers.stop_experience_collection

        self.stop_experience_collection_num_msgs = self.resume_experience_collection_num_msgs = 0

        self.device = None
        self.dqn = None
        self.optimizer = None
        self.policy_lock = policy_lock
        self.resume_experience_collection_cv = resume_experience_collection_cv

        self.task_queue = MpQueue()
        self.report_queue = report_queue

        self.initialized_event = MultiprocessingEvent()
        self.initialized_event.clear()

        self.model_saved_event = MultiprocessingEvent()
        self.model_saved_event.clear()

        # queues corresponding to policy workers using the same policy
        # we send weight updates via these queues
        self.policy_worker_queues = policy_worker_queues

        self.experience_buffer_queue = Queue()

        self.tensor_batch_pool = ObjectPool()
        self.tensor_batcher = TensorBatcher(self.tensor_batch_pool)

        self.with_training = True  # set to False for debugging no-training regime
        self.train_in_background = self.cfg.train_in_background_thread  # set to False for debugging

        self.training_thread = Thread(target=self._train_loop) if self.train_in_background else None
        self.train_thread_initialized = threading.Event()

        self.is_training = False

        self.train_step = self.env_steps = 0

        # decay rate at which summaries are collected
        # save summaries every 20 seconds in the beginning, but decay to every 4 minutes in the limit, because we
        # do not need frequent summaries for longer experiments
        self.summary_rate_decay_seconds = LinearDecay([(0, 20), (100000, 120), (1000000, 240)])
        self.last_summary_time = 0

        self.last_saved_time = self.last_milestone_time = 0

        self.discarded_experience_over_time = deque([], maxlen=30)
        self.discarded_experience_timer = time.time()
        self.num_discarded_rollouts = 0

        self.process = Process(target=self._run, daemon=True)

        if is_continuous_action_space(self.action_space) and self.cfg.exploration_loss == 'symmetric_kl':
            raise NotImplementedError('KL-divergence exploration loss is not supported with '
                                      'continuous action spaces. Use entropy exploration loss')

        if self.cfg.exploration_loss_coeff == 0.0:
            self.exploration_loss_func = lambda action_distr: 0.0
        elif self.cfg.exploration_loss == 'entropy':
            self.exploration_loss_func = self.entropy_exploration_loss
        elif self.cfg.exploration_loss == 'symmetric_kl':
            self.exploration_loss_func = self.symmetric_kl_exploration_loss
        else:
            raise NotImplementedError(f'{self.cfg.exploration_loss} not supported!')
Example #6
0
class LearnerWorker:
    def __init__(
        self, worker_idx, policy_id, cfg, obs_space, action_space, report_queue, policy_worker_queues, shared_buffers,
        policy_lock, resume_experience_collection_cv,
    ):
        log.info('Initializing the learner %d for policy %d', worker_idx, policy_id)

        self.worker_idx = worker_idx
        self.policy_id = policy_id

        self.cfg = cfg

        # PBT-related stuff
        self.should_save_model = True  # set to true if we need to save the model to disk on the next training iteration
        self.load_policy_id = None  # non-None when we need to replace our parameters with another policy's parameters
        self.pbt_mutex = threading.Lock()
        self.new_cfg = None  # non-None when we need to update the learning hyperparameters

        self.terminate = False
        self.num_batches_processed = 0

        self.obs_space = obs_space
        self.action_space = action_space

        self.rollout_tensors = shared_buffers.tensor_trajectories
        self.traj_tensors_available = shared_buffers.is_traj_tensor_available
        self.policy_versions = shared_buffers.policy_versions
        self.stop_experience_collection = shared_buffers.stop_experience_collection

        self.stop_experience_collection_num_msgs = self.resume_experience_collection_num_msgs = 0

        self.device = None
        self.actor_critic = None
        self.optimizer = None
        self.policy_lock = policy_lock
        self.resume_experience_collection_cv = resume_experience_collection_cv

        self.task_queue = MpQueue()
        self.report_queue = report_queue

        self.initialized_event = MultiprocessingEvent()
        self.initialized_event.clear()

        self.model_saved_event = MultiprocessingEvent()
        self.model_saved_event.clear()

        # queues corresponding to policy workers using the same policy
        # we send weight updates via these queues
        self.policy_worker_queues = policy_worker_queues

        self.experience_buffer_queue = Queue()

        self.tensor_batch_pool = ObjectPool()
        self.tensor_batcher = TensorBatcher(self.tensor_batch_pool)

        self.with_training = True  # set to False for debugging no-training regime
        self.train_in_background = self.cfg.train_in_background_thread  # set to False for debugging

        self.training_thread = Thread(target=self._train_loop) if self.train_in_background else None
        self.train_thread_initialized = threading.Event()

        self.is_training = False

        self.train_step = self.env_steps = 0

        # decay rate at which summaries are collected
        # save summaries every 20 seconds in the beginning, but decay to every 4 minutes in the limit, because we
        # do not need frequent summaries for longer experiments
        self.summary_rate_decay_seconds = LinearDecay([(0, 20), (100000, 120), (1000000, 240)])
        self.last_summary_time = 0

        self.last_saved_time = self.last_milestone_time = 0

        self.discarded_experience_over_time = deque([], maxlen=30)
        self.discarded_experience_timer = time.time()
        self.num_discarded_rollouts = 0

        self.process = Process(target=self._run, daemon=True)

        if is_continuous_action_space(self.action_space) and self.cfg.exploration_loss == 'symmetric_kl':
            raise NotImplementedError('KL-divergence exploration loss is not supported with '
                                      'continuous action spaces. Use entropy exploration loss')

        if self.cfg.exploration_loss_coeff == 0.0:
            self.exploration_loss_func = lambda action_distr: 0.0
        elif self.cfg.exploration_loss == 'entropy':
            self.exploration_loss_func = self.entropy_exploration_loss
        elif self.cfg.exploration_loss == 'symmetric_kl':
            self.exploration_loss_func = self.symmetric_kl_exploration_loss
        else:
            raise NotImplementedError(f'{self.cfg.exploration_loss} not supported!')

    def start_process(self):
        self.process.start()

    def _init(self):
        log.info('Waiting for the learner to initialize...')
        self.train_thread_initialized.wait()
        log.info('Learner %d initialized', self.worker_idx)
        self.initialized_event.set()

    def _terminate(self):
        self.terminate = True

    def _broadcast_model_weights(self):
        state_dict = self.actor_critic.state_dict()
        policy_version = self.train_step
        log.debug('Broadcast model weights for model version %d', policy_version)
        model_state = (policy_version, state_dict)
        for q in self.policy_worker_queues:
            q.put((TaskType.INIT_MODEL, model_state))

    def _mark_rollout_buffer_free(self, rollout):
        r = rollout
        self.traj_tensors_available[r.worker_idx, r.split_idx][r.env_idx, r.agent_idx, r.traj_buffer_idx] = 1

    def _prepare_train_buffer(self, rollouts, macro_batch_size, timing):
        trajectories = [AttrDict(r['t']) for r in rollouts]

        with timing.add_time('buffers'):
            buffer = AttrDict()

            # by the end of this loop the buffer is a dictionary containing lists of numpy arrays
            for i, t in enumerate(trajectories):
                for key, x in t.items():
                    if key not in buffer:
                        buffer[key] = []
                    buffer[key].append(x)

            # convert lists of dict observations to a single dictionary of lists
            for key, x in buffer.items():
                if isinstance(x[0], (dict, OrderedDict)):
                    buffer[key] = list_of_dicts_to_dict_of_lists(x)

        if not self.cfg.with_vtrace:
            with timing.add_time('calc_gae'):
                buffer = self._calculate_gae(buffer)

        with timing.add_time('batching'):
            # concatenate rollouts from different workers into a single batch efficiently
            # that is, if we already have memory for the buffers allocated, we can just copy the data into
            # existing cached tensors instead of creating new ones. This is a performance optimization.
            use_pinned_memory = self.cfg.device == 'gpu'
            buffer = self.tensor_batcher.cat(buffer, macro_batch_size, use_pinned_memory, timing)

        with timing.add_time('buff_ready'):
            for r in rollouts:
                self._mark_rollout_buffer_free(r)

        with timing.add_time('tensors_gpu_float'):
            device_buffer = self._copy_train_data_to_device(buffer)

        with timing.add_time('squeeze'):
            # will squeeze actions only in simple categorical case
            tensors_to_squeeze = [
                'actions', 'log_prob_actions', 'policy_version', 'values',
                'rewards', 'dones', 'rewards_cpu', 'dones_cpu',
            ]
            for tensor_name in tensors_to_squeeze:
                device_buffer[tensor_name].squeeze_()

        # we no longer need the cached buffer, and can put it back into the pool
        self.tensor_batch_pool.put(buffer)
        return device_buffer

    def _macro_batch_size(self, batch_size):
        return self.cfg.num_batches_per_iteration * batch_size

    def _process_macro_batch(self, rollouts, batch_size, timing):
        macro_batch_size = self._macro_batch_size(batch_size)

        assert macro_batch_size % self.cfg.rollout == 0
        assert self.cfg.rollout % self.cfg.recurrence == 0
        assert macro_batch_size % self.cfg.recurrence == 0

        samples = env_steps = 0
        for rollout in rollouts:
            samples += rollout['length']
            env_steps += rollout['env_steps']

        with timing.add_time('prepare'):
            buffer = self._prepare_train_buffer(rollouts, macro_batch_size, timing)
            self.experience_buffer_queue.put((buffer, batch_size, samples, env_steps))

            if not self.cfg.benchmark and self.cfg.train_in_background_thread:
                # in PyTorch 1.4.0 there is an intense memory spike when the very first batch is being processed
                # we wait here until this is over so we can continue queueing more batches onto a GPU without having
                # a risk to run out of GPU memory
                while self.num_batches_processed < 1:
                    log.debug('Waiting for the first batch to be processed')
                    time.sleep(0.5)

    def _process_rollouts(self, rollouts, timing):
        # batch_size can potentially change through PBT, so we fix it here and pass it around
        # as a function argument instead of re-reading it from self.cfg mid-iteration

        batch_size = self.cfg.batch_size
        rollouts_in_macro_batch = self._macro_batch_size(batch_size) // self.cfg.rollout

        if len(rollouts) < rollouts_in_macro_batch:
            return rollouts

        discard_rollouts = 0
        policy_version = self.train_step
        for r in rollouts:
            rollout_min_version = r['t']['policy_version'].min().item()
            if policy_version - rollout_min_version >= self.cfg.max_policy_lag:
                discard_rollouts += 1
                self._mark_rollout_buffer_free(r)
            else:
                break

        if discard_rollouts > 0:
            log.warning(
                'Discarding %d old rollouts, cut by policy lag threshold %d (learner %d)',
                discard_rollouts, self.cfg.max_policy_lag, self.policy_id,
            )
            rollouts = rollouts[discard_rollouts:]
            self.num_discarded_rollouts += discard_rollouts

        if len(rollouts) >= rollouts_in_macro_batch:
            # process newest rollouts
            rollouts_to_process = rollouts[:rollouts_in_macro_batch]
            rollouts = rollouts[rollouts_in_macro_batch:]

            self._process_macro_batch(rollouts_to_process, batch_size, timing)
            # log.info('Unprocessed rollouts: %d (%d samples)', len(rollouts), len(rollouts) * self.cfg.rollout)

        return rollouts

    def _get_minibatches(self, batch_size, experience_size):
        """Generating minibatches for training."""
        assert self.cfg.rollout % self.cfg.recurrence == 0
        assert experience_size % batch_size == 0, f'experience size: {experience_size}, batch size: {batch_size}'

        if self.cfg.num_batches_per_iteration == 1:
            return [None]  # single minibatch is actually the entire buffer, we don't need indices

        # indices that will start the mini-trajectories from the same episode (for bptt)
        indices = np.arange(0, experience_size, self.cfg.recurrence)
        indices = np.random.permutation(indices)

        # complete indices of mini trajectories, e.g. with recurrence==4: [4, 16] -> [4, 5, 6, 7, 16, 17, 18, 19]
        indices = [np.arange(i, i + self.cfg.recurrence) for i in indices]
        indices = np.concatenate(indices)

        assert len(indices) == experience_size

        num_minibatches = experience_size // batch_size
        minibatches = np.split(indices, num_minibatches)
        return minibatches

    @staticmethod
    def _get_minibatch(buffer, indices):
        if indices is None:
            # handle the case of a single batch, where the entire buffer is a minibatch
            return buffer

        mb = AttrDict()

        for item, x in buffer.items():
            if isinstance(x, (dict, OrderedDict)):
                mb[item] = AttrDict()
                for key, x_elem in x.items():
                    mb[item][key] = x_elem[indices]
            else:
                mb[item] = x[indices]

        return mb

    def _should_save_summaries(self):
        summaries_every_seconds = self.summary_rate_decay_seconds.at(self.train_step)
        if time.time() - self.last_summary_time < summaries_every_seconds:
            return False

        return True

    def _after_optimizer_step(self):
        """A hook to be called after each optimizer step."""
        self.train_step += 1
        self._maybe_save()

    def _maybe_save(self):
        if time.time() - self.last_saved_time >= self.cfg.save_every_sec or self.should_save_model:
            self._save()
            self.model_saved_event.set()
            self.should_save_model = False
            self.last_saved_time = time.time()

    @staticmethod
    def checkpoint_dir(cfg, policy_id):
        checkpoint_dir = join(experiment_dir(cfg=cfg), f'checkpoint_p{policy_id}')
        return ensure_dir_exists(checkpoint_dir)

    @staticmethod
    def get_checkpoints(checkpoints_dir):
        checkpoints = glob.glob(join(checkpoints_dir, 'checkpoint_*'))
        return sorted(checkpoints)

    def _get_checkpoint_dict(self):
        checkpoint = {
            'train_step': self.train_step,
            'env_steps': self.env_steps,
            'model': self.actor_critic.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        return checkpoint

    def _save(self):
        checkpoint = self._get_checkpoint_dict()
        assert checkpoint is not None

        checkpoint_dir = self.checkpoint_dir(self.cfg, self.policy_id)
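        # write to a temporary file first and rename it at the end, so that a crash mid-write
        # cannot leave behind a truncated file that looks like the latest checkpoint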
        tmp_filepath = join(checkpoint_dir, '.temp_checkpoint')
        checkpoint_name = f'checkpoint_{self.train_step:09d}_{self.env_steps}.pth'
        filepath = join(checkpoint_dir, checkpoint_name)
        log.info('Saving %s...', tmp_filepath)
        torch.save(checkpoint, tmp_filepath)
        log.info('Renaming %s to %s', tmp_filepath, filepath)
        os.rename(tmp_filepath, filepath)

        while len(self.get_checkpoints(checkpoint_dir)) > self.cfg.keep_checkpoints:
            oldest_checkpoint = self.get_checkpoints(checkpoint_dir)[0]
            if os.path.isfile(oldest_checkpoint):
                log.debug('Removing %s', oldest_checkpoint)
                os.remove(oldest_checkpoint)

        if self.cfg.save_milestones_sec > 0:
            # milestones enabled
            if time.time() - self.last_milestone_time >= self.cfg.save_milestones_sec:
                milestones_dir = ensure_dir_exists(join(checkpoint_dir, 'milestones'))
                milestone_path = join(milestones_dir, f'{checkpoint_name}.milestone')
                log.debug('Saving a milestone %s', milestone_path)
                shutil.copy(filepath, milestone_path)
                self.last_milestone_time = time.time()

    def _prepare_observations(self, obs_tensors, gpu_buffer_obs):
        for d, gpu_d, k, v, _ in iter_dicts_recursively(obs_tensors, gpu_buffer_obs):
            device, dtype = self.actor_critic.device_and_type_for_input_tensor(k) #TODO
            tensor = v.detach().to(device, copy=True).type(dtype)
            gpu_d[k] = tensor

    def _copy_train_data_to_device(self, buffer):
        device_buffer = copy_dict_structure(buffer)

        for key, item in buffer.items():
            if key == 'obs':
                self._prepare_observations(item, device_buffer['obs'])
            else:
                device_tensor = item.detach().to(self.device, copy=True, non_blocking=True)
                device_buffer[key] = device_tensor.float()

        device_buffer['dones_cpu'] = buffer.dones.to('cpu', copy=True, non_blocking=True).float()
        device_buffer['rewards_cpu'] = buffer.rewards.to('cpu', copy=True, non_blocking=True).float()

        return device_buffer

    def _train(self, gpu_buffer, batch_size, experience_size, timing):
        with torch.no_grad():
            early_stopping_tolerance = 1e-6
            early_stop = False
            prev_epoch_actor_loss = 1e9
            epoch_actor_losses = []

            # V-trace parameters
            # noinspection PyArgumentList
            rho_hat = torch.Tensor([self.cfg.vtrace_rho])
            # noinspection PyArgumentList
            c_hat = torch.Tensor([self.cfg.vtrace_c])

            clip_ratio_high = 1.0 + self.cfg.ppo_clip_ratio  # e.g. 1.1
            # this still works with e.g. clip_ratio = 2, while PPO's 1-r would give negative ratio
            clip_ratio_low = 1.0 / clip_ratio_high
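            # e.g. ppo_clip_ratio=0.1 gives a clip range of [1/1.1, 1.1] ~= [0.909, 1.1],
            # and ppo_clip_ratio=2 gives [0.333, 3.0] instead of a negative lower bound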

            clip_value = self.cfg.ppo_clip_value
            gamma = self.cfg.gamma
            recurrence = self.cfg.recurrence

            if self.cfg.with_vtrace:
                assert recurrence == self.cfg.rollout and recurrence > 1, \
                    'V-trace requires recurrence and rollout to be equal'

            num_sgd_steps = 0

            stats_and_summaries = None
            if not self.with_training:
                return stats_and_summaries

        for epoch in range(self.cfg.ppo_epochs):
            with timing.add_time('epoch_init'):
                if early_stop or self.terminate:
                    break

                summary_this_epoch = force_summaries = False

                minibatches = self._get_minibatches(batch_size, experience_size)

            for batch_num in range(len(minibatches)):
                with timing.add_time('minibatch_init'):
                    indices = minibatches[batch_num]

                    # current minibatch consisting of short trajectory segments with length == recurrence
                    mb = self._get_minibatch(gpu_buffer, indices)

                # calculate policy head outside of recurrent loop
                with timing.add_time('forward_head'):
                    head_outputs = self.actor_critic.forward_head(mb.obs)

                # initial rnn states
                with timing.add_time('bptt_initial'):
                    if self.cfg.use_rnn:
                        head_output_seq, rnn_states, inverted_select_inds = build_rnn_inputs(
                            head_outputs, mb.dones_cpu, mb.rnn_states, recurrence,
                        )
                    else:
                        rnn_states = mb.rnn_states[::recurrence]

                # calculate RNN outputs for each timestep in a loop
                with timing.add_time('bptt'):
                    if self.cfg.use_rnn:
                        with timing.add_time('bptt_forward_core'):
                            core_output_seq, _ = self.actor_critic.forward_core(head_output_seq, rnn_states)
                        core_outputs = build_core_out_from_seq(core_output_seq, inverted_select_inds)
                    else:
                        core_outputs, _ = self.actor_critic.forward_core(head_outputs, rnn_states)

                with timing.add_time('tail'):
                    assert core_outputs.shape[0] == head_outputs.shape[0]

                    # calculate policy tail outside of recurrent loop
                    result = self.actor_critic.forward_tail(core_outputs, with_action_distribution=True)

                    action_distribution = result.action_distribution
                    log_prob_actions = action_distribution.log_prob(mb.actions)
                    ratio = torch.exp(log_prob_actions - mb.log_prob_actions)  # pi / pi_old

                    # super large/small values can cause numerical problems and are probably noise anyway
                    ratio = torch.clamp(ratio, 0.05, 20.0)

                    values = result.values.squeeze()

                num_trajectories = head_outputs.size(0) // recurrence
                with torch.no_grad():  # these computations are not the part of the computation graph
                    if self.cfg.with_vtrace:
                        ratios_cpu = ratio.cpu()
                        values_cpu = values.cpu()
                        rewards_cpu = mb.rewards_cpu
                        dones_cpu = mb.dones_cpu

                        vtrace_rho = torch.min(rho_hat, ratios_cpu)
                        vtrace_c = torch.min(c_hat, ratios_cpu)

                        vs = torch.zeros((num_trajectories * recurrence))
                        adv = torch.zeros((num_trajectories * recurrence))

                        next_values = (values_cpu[recurrence - 1::recurrence] - rewards_cpu[recurrence - 1::recurrence]) / gamma
                        next_vs = next_values

                        with timing.add_time('vtrace'):
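                            # backward pass over the rollout implementing the V-trace recursion
                            # (Espeholt et al., 2018), with gamma zeroed out on terminal steps:
                            #   delta_s = rho_s * (r_s + gamma * V(x_{s+1}) - V(x_s))
                            #   v_s     = V(x_s) + delta_s + gamma * c_s * (v_{s+1} - V(x_{s+1}))
                            # and advantages rho_s * (r_s + gamma * v_{s+1} - V(x_s))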
                            for i in reversed(range(self.cfg.recurrence)):
                                rewards = rewards_cpu[i::recurrence]
                                dones = dones_cpu[i::recurrence]
                                not_done = 1.0 - dones
                                not_done_times_gamma = not_done * gamma

                                curr_values = values_cpu[i::recurrence]
                                curr_vtrace_rho = vtrace_rho[i::recurrence]
                                curr_vtrace_c = vtrace_c[i::recurrence]

                                delta_s = curr_vtrace_rho * (rewards + not_done_times_gamma * next_values - curr_values)
                                adv[i::recurrence] = curr_vtrace_rho * (rewards + not_done_times_gamma * next_vs - curr_values)
                                next_vs = curr_values + delta_s + not_done_times_gamma * curr_vtrace_c * (next_vs - next_values)
                                vs[i::recurrence] = next_vs

                                next_values = curr_values

                        targets = vs
                    else:
                        # using regular GAE
                        adv = mb.advantages
                        targets = mb.returns

                    adv_mean = adv.mean()
                    adv_std = adv.std()
                    adv = (adv - adv_mean) / max(1e-3, adv_std)  # normalize advantage
                    adv = adv.to(self.device)

                with timing.add_time('losses'):
                    policy_loss = self._policy_loss(ratio, adv, clip_ratio_low, clip_ratio_high)

                    exploration_loss = self.exploration_loss_func(action_distribution)
                    actor_loss = policy_loss + exploration_loss
                    epoch_actor_losses.append(actor_loss.item())

                    targets = targets.to(self.device)
                    old_values = mb.values
                    value_loss = self._value_loss(values, old_values, targets, clip_value)
                    critic_loss = value_loss

                    loss = actor_loss + critic_loss

                    high_loss = 30.0
                    if abs(to_scalar(policy_loss)) > high_loss or abs(to_scalar(value_loss)) > high_loss or abs(to_scalar(exploration_loss)) > high_loss:
                        log.warning(
                            'High loss value: %.4f %.4f %.4f %.4f (recommended to adjust the --reward_scale parameter)',
                            to_scalar(loss), to_scalar(policy_loss), to_scalar(value_loss), to_scalar(exploration_loss),
                        )
                        force_summaries = True

                with timing.add_time('update'):
                    # update the weights
                    self.optimizer.zero_grad()
                    loss.backward()

                    if self.cfg.max_grad_norm > 0.0:
                        with timing.add_time('clip'):
                            torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.cfg.max_grad_norm)

                    curr_policy_version = self.train_step  # policy version before the weight update
                    with self.policy_lock:
                        self.optimizer.step()

                    num_sgd_steps += 1

                with torch.no_grad():
                    with timing.add_time('after_optimizer'):
                        self._after_optimizer_step()

                        # collect and report summaries
                        with_summaries = self._should_save_summaries() or force_summaries
                        if with_summaries and not summary_this_epoch:
                            stats_and_summaries = self._record_summaries(AttrDict(locals()))
                            summary_this_epoch = True
                            force_summaries = False

            # end of an epoch
            # this will force policy update on the inference worker (policy worker)
            self.policy_versions[self.policy_id] = self.train_step

            new_epoch_actor_loss = np.mean(epoch_actor_losses)
            loss_delta_abs = abs(prev_epoch_actor_loss - new_epoch_actor_loss)
            if loss_delta_abs < early_stopping_tolerance:
                early_stop = True
                log.debug(
                    'Early stopping after %d epochs (%d sgd steps), loss delta %.7f',
                    epoch + 1, num_sgd_steps, loss_delta_abs,
                )
                break

            prev_epoch_actor_loss = new_epoch_actor_loss
            epoch_actor_losses = []

        return stats_and_summaries

    def _record_summaries(self, train_loop_vars):
        var = train_loop_vars

        self.last_summary_time = time.time()
        stats = AttrDict()

        grad_norm = sum(
            p.grad.data.norm(2).item() ** 2
            for p in self.actor_critic.parameters()
            if p.grad is not None
        ) ** 0.5
        stats.grad_norm = grad_norm
        stats.loss = var.loss
        stats.value = var.result.values.mean()
        stats.entropy = var.action_distribution.entropy().mean()
        stats.policy_loss = var.policy_loss
        stats.value_loss = var.value_loss
        stats.exploration_loss = var.exploration_loss
        stats.adv_min = var.adv.min()
        stats.adv_max = var.adv.max()
        stats.adv_std = var.adv_std
        stats.max_abs_logprob = torch.abs(var.mb.action_logits).max()

        if hasattr(var.action_distribution, 'summaries'):
            stats.update(var.action_distribution.summaries())

        if var.epoch == self.cfg.ppo_epochs - 1 and var.batch_num == len(var.minibatches) - 1:
            # we collect these stats only for the last PPO batch, or every time if we're only doing one batch, IMPALA-style
            ratio_mean = torch.abs(1.0 - var.ratio).mean().detach()
            ratio_min = var.ratio.min().detach()
            ratio_max = var.ratio.max().detach()
            # log.debug('Learner %d ratio mean min max %.4f %.4f %.4f', self.policy_id, ratio_mean.cpu().item(), ratio_min.cpu().item(), ratio_max.cpu().item())

            value_delta = torch.abs(var.values - var.old_values)
            value_delta_avg, value_delta_max = value_delta.mean(), value_delta.max()

            # calculate KL-divergence with the behaviour policy action distribution
            old_action_distribution = get_action_distribution(
                self.actor_critic.action_space, var.mb.action_logits,
            )
            kl_old = var.action_distribution.kl_divergence(old_action_distribution)
            kl_old_mean = kl_old.mean()

            stats.kl_divergence = kl_old_mean
            stats.value_delta = value_delta_avg
            stats.value_delta_max = value_delta_max
            stats.fraction_clipped = ((var.ratio < var.clip_ratio_low).float() + (var.ratio > var.clip_ratio_high).float()).mean()
            stats.ratio_mean = ratio_mean
            stats.ratio_min = ratio_min
            stats.ratio_max = ratio_max
            stats.num_sgd_steps = var.num_sgd_steps

        # this caused numerical issues on some versions of PyTorch with second moment reaching infinity
        adam_max_second_moment = 0.0
        for key, tensor_state in self.optimizer.state.items():
            adam_max_second_moment = max(tensor_state['exp_avg_sq'].max().item(), adam_max_second_moment)
        stats.adam_max_second_moment = adam_max_second_moment

        version_diff = var.curr_policy_version - var.mb.policy_version
        stats.version_diff_avg = version_diff.mean()
        stats.version_diff_min = version_diff.min()
        stats.version_diff_max = version_diff.max()

        for key, value in stats.items():
            stats[key] = to_scalar(value)

        return stats

    def _update_pbt(self):
        """To be called from the training loop, same thread that updates the model!"""
        with self.pbt_mutex:
            if self.load_policy_id is not None:
                assert self.cfg.with_pbt

                log.debug('Learner %d loads policy from %d', self.policy_id, self.load_policy_id)
                self.load_from_checkpoint(self.load_policy_id)
                self.load_policy_id = None

            if self.new_cfg is not None:
                for key, value in self.new_cfg.items():
                    if self.cfg[key] != value:
                        log.debug('Learner %d replacing cfg parameter %r with new value %r', self.policy_id, key, value)
                        self.cfg[key] = value

                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.cfg.learning_rate
                    param_group['betas'] = (self.cfg.adam_beta1, self.cfg.adam_beta2)
                    log.debug('Updated optimizer lr to value %.7f, betas: %r', param_group['lr'], param_group['betas'])

                self.new_cfg = None

    @staticmethod
    def load_checkpoint(checkpoints, device):
        if len(checkpoints) <= 0:
            log.warning('No checkpoints found')
            return None
        else:
            latest_checkpoint = checkpoints[-1]

            # extra safety mechanism to recover from spurious filesystem errors
            num_attempts = 3
            for attempt in range(num_attempts):
                try:
                    log.warning('Loading state from checkpoint %s...', latest_checkpoint)
                    checkpoint_dict = torch.load(latest_checkpoint, map_location=device)
                    return checkpoint_dict
                except Exception:
                    log.exception(f'Could not load from checkpoint, attempt {attempt}')

    def _load_state(self, checkpoint_dict, load_progress=True):
        if load_progress:
            self.train_step = checkpoint_dict['train_step']
            self.env_steps = checkpoint_dict['env_steps']
        self.actor_critic.load_state_dict(checkpoint_dict['model'])
        self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
        log.info('Loaded experiment state at training iteration %d, env step %d', self.train_step, self.env_steps)

    def init_model(self, timing):
        self.actor_critic = create_actor_critic(self.cfg, self.obs_space, self.action_space, timing)
        self.actor_critic.model_to_device(self.device)
        self.actor_critic.share_memory()

    def load_from_checkpoint(self, policy_id):
        checkpoints = self.get_checkpoints(self.checkpoint_dir(self.cfg, policy_id))
        checkpoint_dict = self.load_checkpoint(checkpoints, self.device)
        if checkpoint_dict is None:
            log.debug('Did not load from checkpoint, starting from scratch!')
        else:
            log.debug('Loading model from checkpoint')

            # if we're replacing our policy with another policy (under PBT), let's not reload the env_steps
            load_progress = policy_id == self.policy_id
            self._load_state(checkpoint_dict, load_progress=load_progress)

    def initialize(self, timing):
        with timing.timeit('init'):
            # initialize the Torch modules
            if self.cfg.seed is None:
                log.info('Starting seed is not provided')
            else:
                log.info('Setting fixed seed %d', self.cfg.seed)
                torch.manual_seed(self.cfg.seed)
                np.random.seed(self.cfg.seed)

            # this does not help with a single experiment
            # but seems to do better when we're running more than one experiment in parallel
            torch.set_num_threads(1)

            if self.cfg.device == 'gpu':
                torch.backends.cudnn.benchmark = True

                # we should already see only one CUDA device, because of env vars
                assert torch.cuda.device_count() == 1
                self.device = torch.device('cuda', index=0)
            else:
                self.device = torch.device('cpu')
            self.init_model(timing)

            self.optimizer = torch.optim.Adam(
                self.actor_critic.parameters(),
                self.cfg.learning_rate,
                betas=(self.cfg.adam_beta1, self.cfg.adam_beta2),
                eps=self.cfg.adam_eps,
            )

            self.load_from_checkpoint(self.policy_id)

            self._broadcast_model_weights()  # sync the very first version of the weights

        self.train_thread_initialized.set()

    def _process_training_data(self, data, timing, wait_stats=None):
        self.is_training = True

        buffer, batch_size, samples, env_steps = data
        assert samples == batch_size * self.cfg.num_batches_per_iteration

        self.env_steps += env_steps
        experience_size = buffer.rewards.shape[0]

        stats = dict(learner_env_steps=self.env_steps, policy_id=self.policy_id)

        with timing.add_time('train'):
            discarding_rate = self._discarding_rate()

            self._update_pbt()

            train_stats = self._train(buffer, batch_size, experience_size, timing)

            if train_stats is not None:
                stats['train'] = train_stats

                if wait_stats is not None:
                    wait_avg, wait_min, wait_max = wait_stats
                    stats['train']['wait_avg'] = wait_avg
                    stats['train']['wait_min'] = wait_min
                    stats['train']['wait_max'] = wait_max

                stats['train']['discarded_rollouts'] = self.num_discarded_rollouts
                stats['train']['discarding_rate'] = discarding_rate

                stats['stats'] = memory_stats('learner', self.device)

        self.is_training = False

        try:
            self.report_queue.put(stats)
        except Full:
            log.warning('Could not report training stats, the report queue is full!')

    def _train_loop(self):
        timing = Timing()
        self.initialize(timing)

        wait_times = deque([], maxlen=self.cfg.num_workers)
        last_cache_cleanup = time.time()

        while not self.terminate:
            with timing.timeit('train_wait'):
                data = safe_get(self.experience_buffer_queue)

            if self.terminate:
                break

            wait_stats = None
            wait_times.append(timing.train_wait)

            if len(wait_times) >= wait_times.maxlen:
                wait_times_arr = np.asarray(wait_times)
                wait_avg = np.mean(wait_times_arr)
                wait_min, wait_max = wait_times_arr.min(), wait_times_arr.max()
                # log.debug(
                #     'Training thread had to wait %.5f s for the new experience buffer (avg %.5f)',
                #     timing.train_wait, wait_avg,
                # )
                wait_stats = (wait_avg, wait_min, wait_max)

            self._process_training_data(data, timing, wait_stats)
            self.num_batches_processed += 1

            if time.time() - last_cache_cleanup > 300.0 or (not self.cfg.benchmark and self.num_batches_processed < 50):
                if self.cfg.device == 'gpu':
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
                last_cache_cleanup = time.time()

        time.sleep(0.3)
        log.info('Train loop timing: %s', timing)
        del self.actor_critic
        del self.device

    def _experience_collection_rate_stats(self):
        now = time.time()
        if now - self.discarded_experience_timer > 1.0:
            self.discarded_experience_timer = now
            self.discarded_experience_over_time.append((now, self.num_discarded_rollouts))

    def _discarding_rate(self):
        if len(self.discarded_experience_over_time) <= 1:
            return 0

        first, last = self.discarded_experience_over_time[0], self.discarded_experience_over_time[-1]
        delta_rollouts = last[1] - first[1]
        delta_time = last[0] - first[0]
        discarding_rate = delta_rollouts / (delta_time + EPS)
        return discarding_rate

    def _extract_rollouts(self, data):
        data = AttrDict(data)
        worker_idx, split_idx, traj_buffer_idx = data.worker_idx, data.split_idx, data.traj_buffer_idx

        rollouts = []
        for rollout_data in data.rollouts:
            env_idx, agent_idx = rollout_data['env_idx'], rollout_data['agent_idx']
            tensors = self.rollout_tensors.index((worker_idx, split_idx, env_idx, agent_idx, traj_buffer_idx))

            rollout_data['t'] = tensors
            rollout_data['worker_idx'] = worker_idx
            rollout_data['split_idx'] = split_idx
            rollout_data['traj_buffer_idx'] = traj_buffer_idx
            rollouts.append(AttrDict(rollout_data))

        return rollouts

    def _process_pbt_task(self, pbt_task):
        task_type, data = pbt_task

        with self.pbt_mutex:
            if task_type == PbtTask.SAVE_MODEL:
                policy_id = data
                assert policy_id == self.policy_id
                self.should_save_model = True
            elif task_type == PbtTask.LOAD_MODEL:
                policy_id, new_policy_id = data
                assert policy_id == self.policy_id
                assert new_policy_id is not None
                self.load_policy_id = new_policy_id
            elif task_type == PbtTask.UPDATE_CFG:
                policy_id, new_cfg = data
                assert policy_id == self.policy_id
                self.new_cfg = new_cfg

    def _accumulated_too_much_experience(self, rollouts):
        max_minibatches_to_accumulate = self.cfg.num_minibatches_to_accumulate
        if max_minibatches_to_accumulate == -1:
            # default value
            max_minibatches_to_accumulate = 2 * self.cfg.num_batches_per_iteration

        # allow the max batches to accumulate, plus the minibatches we're currently training on
        max_minibatches_on_learner = max_minibatches_to_accumulate + self.cfg.num_batches_per_iteration

        minibatches_currently_training = int(self.is_training) * self.cfg.num_batches_per_iteration

        rollouts_per_minibatch = self.cfg.batch_size / self.cfg.rollout
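        # e.g. with batch_size=2048 and rollout=64 (hypothetical values), one minibatch corresponds to 32 rollouts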

        # count contribution from unprocessed rollouts
        minibatches_currently_accumulated = len(rollouts) / rollouts_per_minibatch

        # count minibatches ready for training
        minibatches_currently_accumulated += self.experience_buffer_queue.qsize() * self.cfg.num_batches_per_iteration

        total_minibatches_on_learner = minibatches_currently_training + minibatches_currently_accumulated

        return total_minibatches_on_learner >= max_minibatches_on_learner

    def _run(self):
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        try:
            psutil.Process().nice(self.cfg.default_niceness)
        except psutil.AccessDenied:
            log.error('Low niceness requires sudo!')

        if self.cfg.device == 'gpu':
            cuda_envvars_for_policy(self.policy_id, 'learner')

        torch.multiprocessing.set_sharing_strategy('file_system')
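        # 'file_system' sharing avoids "too many open files" errors when large numbers of tensors
        # are exchanged between processes (the default strategy keeps one fd per shared tensor)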
        torch.set_num_threads(self.cfg.learner_main_loop_num_cores)

        timing = Timing()

        rollouts = []

        if self.train_in_background:
            self.training_thread.start()
        else:
            self.initialize(timing)
            log.error(
                'train_in_background set to False on learner %d! This is slow, use only for testing!', self.policy_id,
            )

        while not self.terminate:
            while True:
                try:
                    tasks = self.task_queue.get_many(timeout=0.005)

                    for task_type, data in tasks:
                        if task_type == TaskType.TRAIN:
                            with timing.add_time('extract'):
                                rollouts.extend(self._extract_rollouts(data))
                                # log.debug('Learner %d has %d rollouts', self.policy_id, len(rollouts))
                        elif task_type == TaskType.INIT:
                            self._init()
                        elif task_type == TaskType.TERMINATE:
                            time.sleep(0.3)
                            log.info('GPU learner timing: %s', timing)
                            self._terminate()
                            break
                        elif task_type == TaskType.PBT:
                            self._process_pbt_task(data)
                except Empty:
                    break

            if self._accumulated_too_much_experience(rollouts):
                # if we accumulated too much experience, signal the policy workers to stop experience collection
                if not self.stop_experience_collection[self.policy_id]:
                    self.stop_experience_collection_num_msgs += 1
                    # TODO: add a logger function for this
                    if self.stop_experience_collection_num_msgs >= 50:
                        log.info(
                            'Learner %d accumulated too much experience, stop experience collection! '
                            'Learner is likely a bottleneck in your experiment (%d times)',
                            self.policy_id, self.stop_experience_collection_num_msgs,
                        )
                        self.stop_experience_collection_num_msgs = 0

                self.stop_experience_collection[self.policy_id] = True
            elif self.stop_experience_collection[self.policy_id]:
                # otherwise, resume the experience collection if it was stopped
                self.stop_experience_collection[self.policy_id] = False
                with self.resume_experience_collection_cv:
                    self.resume_experience_collection_num_msgs += 1
                    if self.resume_experience_collection_num_msgs >= 50:
                        log.debug('Learner %d is resuming experience collection!', self.policy_id)
                        self.resume_experience_collection_num_msgs = 0
                    self.resume_experience_collection_cv.notify_all()

            with torch.no_grad():
                rollouts = self._process_rollouts(rollouts, timing)

            if not self.train_in_background:
                while not self.experience_buffer_queue.empty():
                    training_data = self.experience_buffer_queue.get()
                    self._process_training_data(training_data, timing)

            self._experience_collection_rate_stats()

        if self.train_in_background:
            self.experience_buffer_queue.put(None)
            self.training_thread.join()

    def init(self):
        self.task_queue.put((TaskType.INIT, None))
        self.initialized_event.wait()

    def save_model(self, timeout=None):
        self.model_saved_event.clear()
        save_task = (PbtTask.SAVE_MODEL, self.policy_id)
        self.task_queue.put((TaskType.PBT, save_task))
        log.debug('Wait while learner %d saves the model...', self.policy_id)
        if self.model_saved_event.wait(timeout=timeout):
            log.debug('Learner %d saved the model!', self.policy_id)
        else:
            log.warning('Model saving request timed out!')
        self.model_saved_event.clear()

    def close(self):
        self.task_queue.put((TaskType.TERMINATE, None))

    def join(self):
        join_or_kill(self.process)
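
Below is a minimal, hypothetical sketch of how a LearnerWorker of this shape is driven by the orchestrating process. It is not part of the example above: the config, spaces, shared buffers, queues, locks, and the resume_cv condition variable are assumed to be constructed by the caller and are not shown.

# hypothetical driver code; cfg, obs_space, action_space, report_queue,
# policy_worker_queues, shared_buffers, policy_lock and resume_cv are
# assumed to be created elsewhere by the orchestrator
learner = LearnerWorker(
    worker_idx=0, policy_id=0, cfg=cfg, obs_space=obs_space, action_space=action_space,
    report_queue=report_queue, policy_worker_queues=policy_worker_queues,
    shared_buffers=shared_buffers, policy_lock=policy_lock,
    resume_experience_collection_cv=resume_cv,
)
learner.start_process()          # start the learner process (Process(target=self._run))
learner.init()                   # send TaskType.INIT and block until the model is initialized
# ... actor/policy workers now push (TaskType.TRAIN, rollout_data) onto learner.task_queue ...
learner.save_model(timeout=60)   # request a checkpoint through the PBT task channel
learner.close()                  # send TaskType.TERMINATE
learner.join()                   # wait for the learner process to exit

Note that task_queue is the only direct control channel into the running learner; the actual rollout tensors are exchanged through shared_buffers, and weight updates travel back to the policy workers via policy_worker_queues.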