def multiprocess_training_loader(process_number: int, _config,
                                 _queue: mp.Queue, _wait_for_exit: mp.Event,
                                 _local_file, _fasttext_vocab_cached_mapping,
                                 _fasttext_vocab_cached_data):

    # workflow: we tokenize the data files with the costly spaCy pipeline in a preprocessing step
    # before training (and join the tokens with single whitespaces), so here we only split on whitespaces
    _tokenizer = None
    if _config["preprocessed_tokenized"] == True:
        _tokenizer = WordTokenizer(word_splitter=JustSpacesWordSplitter())

    if _config["token_embedder_type"] == "embedding":
        _token_indexers = {
            "tokens": SingleIdTokenIndexer(lowercase_tokens=True)
        }
        _vocab = Vocabulary.from_files(_config["vocab_directory"])

    elif _config["token_embedder_type"] == "fasttext":
        _token_indexers = {
            "tokens": FastTextNGramIndexer(_config["fasttext_max_subwords"])
        }
        _vocab = FastTextVocab(_fasttext_vocab_cached_mapping,
                               _fasttext_vocab_cached_data,
                               _config["fasttext_max_subwords"])

    elif _config["token_embedder_type"] == "elmo":
        _token_indexers = {"tokens": ELMoTokenCharactersIndexer()}
        _vocab = None

    _triple_loader = IrTripleDatasetReader(
        lazy=True,
        tokenizer=_tokenizer,
        token_indexers=_token_indexers,
        max_doc_length=_config["max_doc_length"],
        max_query_length=_config["max_query_length"])

    _iterator = BucketIterator(batch_size=int(_config["batch_size_train"]),
                               sorting_keys=[("doc_pos_tokens", "num_tokens"),
                                             ("doc_neg_tokens", "num_tokens")])

    _iterator.index_with(_vocab)

    for training_batch in _iterator(_triple_loader.read(_local_file),
                                    num_epochs=1):

        _queue.put(training_batch)  # this moves the tensors into shared memory

    _queue.close()  # indicate this loader process is done
    # keep this process alive until all the shared memory it produced is used and no longer needed
    _wait_for_exit.wait()
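
# --- Usage sketch (added for illustration; not part of the original example) ---------------------
# A minimal sketch of how this loader is typically driven from the parent process, assuming the
# function lives in the current module. All config values and the file name below are illustrative.
if __name__ == "__main__":
    import torch.multiprocessing as mp

    config = {
        "preprocessed_tokenized": True,
        "token_embedder_type": "embedding",
        "vocab_directory": "data/vocab",       # illustrative path
        "max_doc_length": 200,
        "max_query_length": 30,
        "batch_size_train": 32,
    }

    queue = mp.Queue(10)                       # bounded queue applies back-pressure to the loader
    exit_event = mp.Event()
    loader = mp.Process(target=multiprocess_training_loader,
                        args=(0, config, queue, exit_event, "triples.part0.tsv", None, None))
    loader.start()

    for _ in range(100):                       # consume a fixed number of batches for illustration;
        training_batch = queue.get()           # a real driver would track epochs / batch counts
        # ... run one training step on `training_batch` here ...

    exit_event.set()                           # allow the loader process to shut down
    loader.join()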
Example #2
class MultimodalPatchesCache(object):
    def __init__(self,
                 cache_dir,
                 dataset_dir,
                 dataset_list,
                 cuda,
                 batch_size=500,
                 num_workers=3,
                 renew_frequency=5,
                 rejection_radius_position=0,
                 numpatches=900,
                 numneg=3,
                 pos_thr=50.0,
                 reject=True,
                 mode='train',
                 rejection_radius=3000,
                 dist_type='3D',
                 patch_radius=None,
                 use_depth=False,
                 use_normals=False,
                 use_silhouettes=False,
                 color_jitter=False,
                 greyscale=False,
                 maxres=4096,
                 scale_jitter=False,
                 photo_jitter=False,
                 uniform_negatives=False,
                 needles=0,
                 render_only=False,
                 maxitems=200,
                 cache_once=False):
        super(MultimodalPatchesCache, self).__init__()
        self.cache_dir = cache_dir
        self.dataset_dir = dataset_dir
        #self.images_path = images_path
        self.dataset_list = dataset_list
        self.cuda = cuda
        self.batch_size = batch_size

        self.num_workers = num_workers
        self.renew_frequency = renew_frequency
        self.rejection_radius_position = rejection_radius_position
        self.numpatches = numpatches
        self.numneg = numneg
        self.pos_thr = pos_thr
        self.reject = reject
        self.mode = mode
        self.rejection_radius = rejection_radius
        self.dist_type = dist_type
        self.patch_radius = patch_radius
        self.use_depth = use_depth
        self.use_normals = use_normals
        self.use_silhouettes = use_silhouettes
        self.color_jitter = color_jitter
        self.greyscale = greyscale
        self.maxres = maxres
        self.scale_jitter = scale_jitter
        self.photo_jitter = photo_jitter
        self.uniform_negatives = uniform_negatives
        self.needles = needles
        self.render_only = render_only

        self.cache_done_lock = Lock()
        self.all_done = Value('B', 0)  # 0 is False
        self.cache_done = Value('B', 0)  # 0 is False

        self.wait_for_cache_builder = Event()
        # prepare for wait until initial cache is built
        self.wait_for_cache_builder.clear()
        self.cache_builder_resume = Event()

        self.maxitems = maxitems
        self.cache_once = cache_once

        if self.mode == 'eval':
            self.maxitems = -1
        self.cache_builder = Process(target=self.buildCache,
                                     args=[self.maxitems])
        self.current_cache_build = Value('B', 0)  # index of the cache being built
        self.current_cache_use = Value('B', 1)  # index of the cache being read from

        self.cache_names = ["cache1", "cache2"]  # constant
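        # double-buffering: one cache directory is read by training while the
        # other is (re)built in the background, then the two are swapped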

        rebuild_cache = True
        if self.mode == 'eval':
            validation_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            if os.path.isdir(validation_dir):
                # we don't need to rebuild validation cache
                # TODO: check if cache is VALID
                rebuild_cache = False
        elif cache_once:
            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            if os.path.isdir(build_dataset_dir):
                # we don't need to rebuild the training cache if we are
                # training on a limited subset of the training set
                rebuild_cache = False

        if rebuild_cache:
            # clear the caches if they already exist
            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            if os.path.isdir(build_dataset_dir):
                shutil.rmtree(build_dataset_dir)
            use_dataset_dir = os.path.join(
                self.cache_dir, self.cache_names[self.current_cache_use.value])
            if os.path.isdir(use_dataset_dir):
                shutil.rmtree(use_dataset_dir)

            os.makedirs(build_dataset_dir)

            self.cache_builder_resume.set()
            self.cache_builder.start()

            # wait until initial cache is built
            # print("before wait to build")
            # print("wait for cache builder state",
            #       self.wait_for_cache_builder.is_set())
            self.wait_for_cache_builder.wait()
            # print("after wait to build")

        # we have been resumed
        if self.mode != 'eval' and (not self.cache_once):
            # for training, we can set up the cache builder to build
            # the second cache
            self.restart()
        else:
            # else for validation we don't need second cache
            # we just need to switch the built cache to the use cache in order
            # to use it
            tmp = self.current_cache_build.value
            self.current_cache_build.value = self.current_cache_use.value
            self.current_cache_use.value = tmp

        # initialization finished, now this dataset can be used

    def getCurrentCache(self):
        # Lock should not be needed - cache_done is not touched here
        # and cache_len is read only for the cache in use, which should not
        # be touched by other threads
        # self.cache_done_lock.acquire()
        h5_dataset_filename = os.path.join(
            self.cache_dir, self.cache_names[self.current_cache_use.value])
        # self.cache_done_lock.release()
        return h5_dataset_filename

    def restart(self):
        # print("Restarting - waiting for lock...")
        self.cache_done_lock.acquire()
        # print("Restarting cached dataset...")
        if self.cache_done.value and (not self.cache_once):
            cache_changed = True
            tmp_cache_name = self.current_cache_use.value
            self.current_cache_use.value = self.current_cache_build.value
            self.current_cache_build.value = tmp_cache_name
            # clear the old cache if exists
            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            if os.path.isdir(build_dataset_dir):
                shutil.rmtree(build_dataset_dir)
            os.makedirs(build_dataset_dir)
            self.cache_done.value = 0  # 0 is False
            self.cache_builder_resume.set()
            # print("Switched cache to: ",
            #       self.cache_names[self.current_cache_use.value]
            # )
        else:
            cache_changed = False
            # print(
            #     "New cache not ready, continuing with old cache:",
            #     self.cache_names[self.current_cache_use.value]
            # )
        all_done_value = self.all_done.value
        self.cache_done_lock.release()
        # all_done_value is true when no more items are available to be loaded;
        # in that case this object should be destroyed and a new dataset created
        # in order to start over.
        return cache_changed, all_done_value

    def buildCache(self, limit):
        # print("Building cache: ",
        #       self.cache_names[self.current_cache_build.value]
        # )
        dataset = MultimodalPatchesDatasetAll(
            self.dataset_dir,
            self.dataset_list,
            rejection_radius_position=self.rejection_radius_position,
            #self.images_path, list=train_sampled,
            numpatches=self.numpatches,
            numneg=self.numneg,
            pos_thr=self.pos_thr,
            reject=self.reject,
            mode=self.mode,
            rejection_radius=self.rejection_radius,
            dist_type=self.dist_type,
            patch_radius=self.patch_radius,
            use_depth=self.use_depth,
            use_normals=self.use_normals,
            use_silhouettes=self.use_silhouettes,
            color_jitter=self.color_jitter,
            greyscale=self.greyscale,
            maxres=self.maxres,
            scale_jitter=self.scale_jitter,
            photo_jitter=self.photo_jitter,
            uniform_negatives=self.uniform_negatives,
            needles=self.needles,
            render_only=self.render_only)
        n_triplets = len(dataset)

        if limit == -1:
            limit = n_triplets

        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=1,  # self.num_workers
            collate_fn=MultimodalPatchesCache.my_collate)

        qmaxsize = 15
        data_queue = JoinableQueue(maxsize=qmaxsize)

        # cannot load to cuda from background, therefore use cpu device
        preloader_resume = Event()
        preloader = Process(target=MultimodalPatchesCache.generateTrainingData,
                            args=(data_queue, dataset, dataloader,
                                  self.batch_size, qmaxsize, preloader_resume,
                                  True, True))
        preloader.do_run_generate = True
        preloader.start()
        preloader_resume.set()

        i_batch = 0
        data = data_queue.get()
        i_batch = data[0]

        counter = 0
        while i_batch != -1:

            self.cache_builder_resume.wait()

            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            batch_fname = os.path.join(build_dataset_dir,
                                       'batch_' + str(counter) + '.pt')

            # print("ibatch", i_batch,
            #        "___data___", data[3].shape, data[6].shape)

            anchor = data[1]
            pos = data[2]
            neg = data[3]
            anchor_r = data[4]
            pos_p = data[5]
            neg_p = data[6]
            c1 = data[7]
            c2 = data[8]
            cneg = data[9]
            id = data[10]

            if not (self.use_depth or self.use_normals):
                # no need to store image data as float, convert to uint8
                anchor = (anchor * 255.0).to(torch.uint8)
                pos = (pos * 255.0).to(torch.uint8)
                neg = (neg * 255.0).to(torch.uint8)
                anchor_r = (anchor_r * 255.0).to(torch.uint8)
                pos_p = (pos_p * 255.0).to(torch.uint8)
                neg_p = (neg_p * 255.0).to(torch.uint8)

            tosave = {
                'anchor': anchor,
                'pos': pos,
                'neg': neg,
                'anchor_r': anchor_r,
                'pos_p': pos_p,
                'neg_p': neg_p,
                'c1': c1,
                'c2': c2,
                'cneg': cneg,
                'id': id
            }

            try:
                torch.save(tosave, batch_fname)
                torch.load(batch_fname)  # sanity check: verify the saved batch can be read back
                counter += 1
            except Exception as e:
                print("Could not save ",
                      batch_fname,
                      ", due to:",
                      e,
                      "skipping...",
                      file=sys.stderr)
                if os.path.isfile(batch_fname):
                    os.remove(batch_fname)

            data_queue.task_done()

            if counter >= limit:
                self.cache_done_lock.acquire()
                self.cache_done.value = 1  # 1 is True
                self.cache_done_lock.release()
                counter = 0
                # sleep until calling thread wakes us
                self.cache_builder_resume.clear()
                # resume calling thread so that it can work
                self.wait_for_cache_builder.set()

            data = data_queue.get()
            i_batch = data[0]
            #print("ibatch", i_batch)

        data_queue.task_done()

        self.cache_done_lock.acquire()
        self.cache_done.value = 1  # 1 is True
        self.all_done.value = 1
        print("Cache done ALL")
        self.cache_done_lock.release()
        # resume calling thread so that it can work
        self.wait_for_cache_builder.set()
        preloader.join()
        preloader = None
        data_queue = None

    @staticmethod
    def loadBatch(sample_batched, mode, device, keep_all=False):
        if mode == 'eval':
            coords1 = sample_batched[6]
            coords2 = sample_batched[7]
            coords_neg = sample_batched[8]
            keep = sample_batched[10]
            item_id = sample_batched[11]
        else:
            coords1 = sample_batched[6]
            coords2 = sample_batched[7]
            coords_neg = sample_batched[8]
            keep = sample_batched[9]
            item_id = sample_batched[10]
        if keep_all:
            # requested to return the full batch
            batchsize = sample_batched[0].shape[0]
            keep = torch.ones(batchsize).byte()
        keep = keep.reshape(-1)
        keep = keep.bool()
        anchor = sample_batched[0]
        pos = sample_batched[1]
        neg = sample_batched[2]

        # the same patches from the swapped modality: rendered anchor, photo positive/negative
        anchor_r = sample_batched[3]
        pos_p = sample_batched[4]
        neg_p = sample_batched[5]

        anchor = anchor[keep].to(device)
        pos = pos[keep].to(device)
        neg = neg[keep].to(device)

        anchor_r = anchor_r[keep]
        pos_p = pos_p[keep]
        neg_p = neg_p[keep]

        coords1 = coords1[keep]
        coords2 = coords2[keep]
        coords_neg = coords_neg[keep]
        item_id = item_id[keep]
        return anchor, pos, neg, anchor_r, pos_p, neg_p, coords1, coords2, \
            coords_neg, item_id

    @staticmethod
    def generateTrainingData(queue,
                             dataset,
                             dataloader,
                             batch_size,
                             qmaxsize,
                             resume,
                             shuffle=True,
                             disable_tqdm=False):
        local_buffer_a = []
        local_buffer_p = []
        local_buffer_n = []

        local_buffer_ar = []
        local_buffer_pp = []
        local_buffer_np = []

        local_buffer_c1 = []
        local_buffer_c2 = []
        local_buffer_cneg = []
        local_buffer_id = []
        nbatches = 10
        # cannot load to cuda in background process!
        device = torch.device('cpu')

        buffer_size = min(qmaxsize * batch_size, nbatches * batch_size)
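        # the local buffer is flushed once it holds buffer_size samples (at most nbatches batches),
        # so shuffling mixes samples across dataloader batches without unbounded memory growth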
        bidx = 0
        for i_batch, sample_batched in enumerate(dataloader):
            # tqdm(dataloader, disable=disable_tqdm)
            resume.wait()
            anchor, pos, neg, anchor_r, \
                pos_p, neg_p, c1, c2, cneg, id = \
                MultimodalPatchesCache.loadBatch(
                    sample_batched, dataset.mode, device
                )
            if anchor.shape[0] == 0:
                continue
            local_buffer_a.extend(list(anchor))  # [:current_batches]
            local_buffer_p.extend(list(pos))
            local_buffer_n.extend(list(neg))

            local_buffer_ar.extend(list(anchor_r))
            local_buffer_pp.extend(list(pos_p))
            local_buffer_np.extend(list(neg_p))

            local_buffer_c1.extend(list(c1))
            local_buffer_c2.extend(list(c2))
            local_buffer_cneg.extend(list(cneg))
            local_buffer_id.extend(list(id))
            if len(local_buffer_a) >= buffer_size:
                if shuffle:
                    local_buffer_a, local_buffer_p, local_buffer_n, \
                        local_buffer_ar, local_buffer_pp, local_buffer_np, \
                        local_buffer_c1, local_buffer_c2, local_buffer_cneg, \
                        local_buffer_id = sklearn.utils.shuffle(
                            local_buffer_a,
                            local_buffer_p,
                            local_buffer_n,
                            local_buffer_ar,
                            local_buffer_pp,
                            local_buffer_np,
                            local_buffer_c1,
                            local_buffer_c2,
                            local_buffer_cneg,
                            local_buffer_id
                        )
                curr_nbatches = len(local_buffer_a) // batch_size
                for i in range(0, curr_nbatches):
                    queue.put([
                        bidx,
                        torch.stack(local_buffer_a[:batch_size]),
                        torch.stack(local_buffer_p[:batch_size]),
                        torch.stack(local_buffer_n[:batch_size]),
                        torch.stack(local_buffer_ar[:batch_size]),
                        torch.stack(local_buffer_pp[:batch_size]),
                        torch.stack(local_buffer_np[:batch_size]),
                        torch.stack(local_buffer_c1[:batch_size]),
                        torch.stack(local_buffer_c2[:batch_size]),
                        torch.stack(local_buffer_cneg[:batch_size]),
                        torch.stack(local_buffer_id[:batch_size])
                    ])
                    del local_buffer_a[:batch_size]
                    del local_buffer_p[:batch_size]
                    del local_buffer_n[:batch_size]
                    del local_buffer_ar[:batch_size]
                    del local_buffer_pp[:batch_size]
                    del local_buffer_np[:batch_size]
                    del local_buffer_c1[:batch_size]
                    del local_buffer_c2[:batch_size]
                    del local_buffer_cneg[:batch_size]
                    del local_buffer_id[:batch_size]
                    bidx += 1
        remaining_batches = len(local_buffer_a) // batch_size
        for i in range(0, remaining_batches):
            queue.put([
                bidx,
                torch.stack(local_buffer_a[:batch_size]),
                torch.stack(local_buffer_p[:batch_size]),
                torch.stack(local_buffer_n[:batch_size]),
                torch.stack(local_buffer_ar[:batch_size]),
                torch.stack(local_buffer_pp[:batch_size]),
                torch.stack(local_buffer_np[:batch_size]),
                torch.stack(local_buffer_c1[:batch_size]),
                torch.stack(local_buffer_c2[:batch_size]),
                torch.stack(local_buffer_cneg[:batch_size]),
                torch.stack(local_buffer_id[:batch_size])
            ])
            del local_buffer_a[:batch_size]
            del local_buffer_p[:batch_size]
            del local_buffer_n[:batch_size]
            del local_buffer_ar[:batch_size]
            del local_buffer_pp[:batch_size]
            del local_buffer_np[:batch_size]
            del local_buffer_c1[:batch_size]
            del local_buffer_c2[:batch_size]
            del local_buffer_cneg[:batch_size]
            del local_buffer_id[:batch_size]
        # sentinel: i_batch == -1 tells buildCache that no more data is coming
        # (the tensor payload is a dummy and is never read by the consumer)
        ra = torch.randn(batch_size, 3, 64, 64)
        queue.put([-1, ra, ra, ra])
        queue.join()

    @staticmethod
    def my_collate(batch):
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
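
# --- Usage sketch (added for illustration; not part of the original example) ---------------------
# A minimal sketch of the intended double-buffered usage, assuming the dataset directories below
# exist and follow the layout expected by MultimodalPatchesDatasetAll; all constructor arguments
# are illustrative. Training reads batch_*.pt files from the current cache while the next cache is
# built in the background.
if __name__ == "__main__":
    # note: the constructor blocks until the initial cache has been built
    cache = MultimodalPatchesCache("cache", "datasets/alps", ["train_list.txt"],
                                   cuda=False, batch_size=300, mode='train')

    for epoch in range(10):
        cache_dir = cache.getCurrentCache()        # directory holding batch_*.pt files
        # ... load each batch_*.pt with torch.load() and train on it ...

        cache_changed, all_done = cache.restart()  # swap in the freshly built cache, if ready
        if all_done:                               # no more data to cache: recreate the dataset
            break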
Example #3
def multiprocess_single_sequence_loader(process_number: int, _config,
                                        _queue: mp.Queue,
                                        _wait_for_exit: mp.Event, _local_file,
                                        _fasttext_vocab_cached_mapping,
                                        _fasttext_vocab_cached_data):

    torch.manual_seed(_config["random_seed"])
    numpy.random.seed(_config["random_seed"])
    random.seed(_config["random_seed"])

    if _config["token_embedder_type"] == "bert_cls":
        _tokenizer = BlingFireTokenizer()
        _ind = PretrainedBertIndexer(
            pretrained_model=_config["bert_pretrained_model"],
            do_lowercase=True)
        _token_indexers = {"tokens": _ind}

        _tuple_loader = IrSingleSequenceDatasetReader(
            lazy=True,
            tokenizer=_tokenizer,
            token_indexers=_token_indexers,
            max_seq_length=_config["max_doc_length"],
            min_seq_length=_config["min_doc_length"],
        )

        _iterator = BucketIterator(batch_size=int(_config["batch_size_eval"]),
                                   sorting_keys=[("seq_tokens", "num_tokens")])

        _iterator.index_with(Vocabulary.from_files(_config["vocab_directory"]))

    else:
        _tokenizer = BlingFireTokenizer()

        if _config["token_embedder_type"] == "embedding":
            _token_indexers = {
                "tokens": SingleIdTokenIndexer(lowercase_tokens=True)
            }
            _vocab = Vocabulary.from_files(_config["vocab_directory"])

        elif _config["token_embedder_type"] == "fasttext":
            _token_indexers = {
                "tokens":
                FastTextNGramIndexer(_config["fasttext_max_subwords"])
            }
            _vocab = FastTextVocab(_fasttext_vocab_cached_mapping,
                                   _fasttext_vocab_cached_data,
                                   _config["fasttext_max_subwords"])

        elif _config["token_embedder_type"] == "elmo":
            _token_indexers = {"tokens": ELMoTokenCharactersIndexer()}
            _vocab = None

        _tuple_loader = IrSingleSequenceDatasetReader(
            lazy=True,
            tokenizer=_tokenizer,
            token_indexers=_token_indexers,
            max_seq_length=_config["max_doc_length"],
            min_seq_length=_config["min_doc_length"],
        )

        _iterator = BucketIterator(batch_size=int(_config["batch_size_eval"]),
                                   sorting_keys=[("seq_tokens", "num_tokens")])

        _iterator.index_with(_vocab)

    for training_batch in _iterator(_tuple_loader.read(_local_file),
                                    num_epochs=1):

        _queue.put(training_batch)  # this moves the tensors into shared memory

    _queue.put(None)  # sentinel: signal the end of the queue

    _queue.close()  # indicate this loader process is done
    # keep this process alive until all the shared memory it produced is used and no longer needed
    _wait_for_exit.wait()
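
# --- Usage sketch (added for illustration; not part of the original example) ---------------------
# Unlike the triple loader above, this loader puts a `None` sentinel once it has emitted its last
# batch, so the parent process can simply drain the queue until it sees the sentinel. Config
# values and the file name below are illustrative.
if __name__ == "__main__":
    import torch.multiprocessing as mp

    config = {
        "random_seed": 42,
        "token_embedder_type": "bert_cls",
        "bert_pretrained_model": "bert-base-uncased",
        "vocab_directory": "data/vocab",       # illustrative path
        "max_doc_length": 200,
        "min_doc_length": 5,
        "batch_size_eval": 64,
    }

    queue = mp.Queue(10)
    exit_event = mp.Event()
    loader = mp.Process(target=multiprocess_single_sequence_loader,
                        args=(0, config, queue, exit_event, "collection.part0.tsv", None, None))
    loader.start()

    while True:
        batch = queue.get()
        if batch is None:                      # sentinel: the loader has no more batches
            break
        # ... run evaluation / inference on `batch` here ...

    exit_event.set()                           # release the loader so it can exit
    loader.join()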
Example #4
class LearnerWorker:
    def __init__(
        self,
        worker_idx,
        policy_id,
        cfg,
        obs_space,
        action_space,
        report_queue,
        policy_worker_queues,
        shared_buffers,
        policy_lock,
        resume_experience_collection_cv,
    ):
        log.info('Initializing the learner %d for policy %d', worker_idx,
                 policy_id)

        self.worker_idx = worker_idx
        self.policy_id = policy_id

        self.cfg = cfg

        # PBT-related stuff
        self.should_save_model = True  # set to true if we need to save the model to disk on the next training iteration
        self.load_policy_id = None  # non-None when we need to replace our parameters with another policy's parameters
        self.pbt_mutex = threading.Lock()
        self.new_cfg = None  # non-None when we need to update the learning hyperparameters

        self.terminate = False

        self.obs_space = obs_space
        self.action_space = action_space

        self.rollout_tensors = shared_buffers.tensor_trajectories
        self.traj_tensors_available = shared_buffers.is_traj_tensor_available
        self.policy_versions = shared_buffers.policy_versions
        self.stop_experience_collection = shared_buffers.stop_experience_collection

        self.device = None
        self.actor_critic = None
        self.optimizer = None
        self.policy_lock = policy_lock
        self.resume_experience_collection_cv = resume_experience_collection_cv

        self.task_queue = faster_fifo.Queue()
        self.report_queue = report_queue

        self.initialized_event = MultiprocessingEvent()
        self.initialized_event.clear()

        self.model_saved_event = MultiprocessingEvent()
        self.model_saved_event.clear()

        # queues corresponding to policy workers using the same policy
        # we send weight updates via these queues
        self.policy_worker_queues = policy_worker_queues

        self.experience_buffer_queue = Queue()

        self.tensor_batch_pool = ObjectPool()
        self.tensor_batcher = TensorBatcher(self.tensor_batch_pool)

        self.with_training = True  # set to False for debugging no-training regime
        self.train_in_background = self.cfg.train_in_background_thread  # set to False for debugging

        self.training_thread = Thread(
            target=self._train_loop) if self.train_in_background else None
        self.train_thread_initialized = threading.Event()

        self.is_training = False

        self.train_step = self.env_steps = 0

        # decay rate at which summaries are collected
        # save summaries every 20 seconds in the beginning, but decay to every 4 minutes in the limit, because we
        # do not need frequent summaries for longer experiments
        self.summary_rate_decay_seconds = LinearDecay([(0, 20), (100000, 120),
                                                       (1000000, 240)])
        self.last_summary_time = 0

        self.last_saved_time = self.last_milestone_time = 0

        self.discarded_experience_over_time = deque([], maxlen=30)
        self.discarded_experience_timer = time.time()
        self.num_discarded_rollouts = 0

        self.process = Process(target=self._run, daemon=True)

    def start_process(self):
        self.process.start()

    def _init(self):
        log.info('Waiting for the learner to initialize...')
        self.train_thread_initialized.wait()
        log.info('Learner %d initialized', self.worker_idx)
        self.initialized_event.set()

    def _terminate(self):
        self.terminate = True

    def _broadcast_model_weights(self):
        state_dict = self.actor_critic.state_dict()
        policy_version = self.train_step
        log.debug('Broadcast model weights for model version %d',
                  policy_version)
        model_state = (policy_version, state_dict)
        for q in self.policy_worker_queues:
            q.put((TaskType.INIT_MODEL, model_state))

    def _calculate_gae(self, buffer):
        """
        Calculate advantages using Generalized Advantage Estimation.
        This is a leftover from the previous version of the algorithm.
        Perhaps should be re-implemented in PyTorch tensors, similar to V-trace for uniformity.
        """

        rewards = torch.stack(buffer.rewards).numpy().squeeze()  # [E, T]
        dones = torch.stack(buffer.dones).numpy().squeeze()  # [E, T]
        values_arr = torch.stack(buffer.values).numpy().squeeze()  # [E, T]

        # calculating fake values for the last step in the rollout
        # this will make sure that advantage of the very last action is always zero
        values = []
        for i in range(len(values_arr)):
            last_value, last_reward = values_arr[i][-1], rewards[i, -1]
            next_value = (last_value - last_reward) / self.cfg.gamma
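            # chosen so that r_T + gamma * next_value == last_value, making the GAE delta
            # (and thus the advantage) of the final step zero for non-terminal rollouts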
            values.append(list(values_arr[i]))
            values[i].append(float(next_value))  # [T] -> [T+1]

        # calculating returns and GAE
        rewards = rewards.transpose((1, 0))  # [E, T] -> [T, E]
        dones = dones.transpose((1, 0))  # [E, T] -> [T, E]
        values = np.asarray(values).transpose((1, 0))  # [E, T+1] -> [T+1, E]

        advantages, returns = calculate_gae(rewards, dones, values,
                                            self.cfg.gamma,
                                            self.cfg.gae_lambda)

        # transpose tensors back to [E, T] before creating a single experience buffer
        buffer.advantages = advantages.transpose((1, 0))  # [T, E] -> [E, T]
        buffer.returns = returns.transpose((1, 0))  # [T, E] -> [E, T]
        buffer.returns = buffer.returns[:, :,
                                        np.newaxis]  # [E, T] -> [E, T, 1]

        buffer.advantages = [torch.tensor(buffer.advantages).reshape(-1)]
        buffer.returns = [torch.tensor(buffer.returns).reshape(-1)]

        return buffer

    def _mark_rollout_buffer_free(self, rollout):
        r = rollout
        self.traj_tensors_available[r.worker_idx,
                                    r.split_idx][r.env_idx, r.agent_idx,
                                                 r.traj_buffer_idx] = 1

    def _prepare_train_buffer(self, rollouts, macro_batch_size, timing):
        trajectories = [AttrDict(r['t']) for r in rollouts]

        with timing.add_time('buffers'):
            buffer = AttrDict()

            # by the end of this loop the buffer is a dictionary containing lists of numpy arrays
            for i, t in enumerate(trajectories):
                for key, x in t.items():
                    if key not in buffer:
                        buffer[key] = []
                    buffer[key].append(x)

            # convert lists of dict observations to a single dictionary of lists
            for key, x in buffer.items():
                if isinstance(x[0], (dict, OrderedDict)):
                    buffer[key] = list_of_dicts_to_dict_of_lists(x)

        if not self.cfg.with_vtrace:
            with timing.add_time('calc_gae'):
                buffer = self._calculate_gae(buffer)

        with timing.add_time('batching'):
            # concatenate rollouts from different workers into a single batch efficiently
            # that is, if we already have memory for the buffers allocated, we can just copy the data into
            # existing cached tensors instead of creating new ones. This is a performance optimization.
            use_pinned_memory = self.cfg.device == 'gpu'
            buffer = self.tensor_batcher.cat(buffer, macro_batch_size,
                                             use_pinned_memory, timing)

        with timing.add_time('buff_ready'):
            for r in rollouts:
                self._mark_rollout_buffer_free(r)

        with timing.add_time('tensors_gpu_float'):
            device_buffer = self._copy_train_data_to_device(buffer)

        with timing.add_time('squeeze'):
            # will squeeze actions only in simple categorical case
            tensors_to_squeeze = [
                'actions', 'log_prob_actions', 'policy_version', 'values',
                'rewards', 'dones'
            ]
            for tensor_name in tensors_to_squeeze:
                device_buffer[tensor_name].squeeze_()

        # we no longer need the cached buffer, and can put it back into the pool
        self.tensor_batch_pool.put(buffer)
        return device_buffer

    def _macro_batch_size(self, batch_size):
        return self.cfg.num_batches_per_iteration * batch_size

    def _process_macro_batch(self, rollouts, batch_size, timing):
        macro_batch_size = self._macro_batch_size(batch_size)

        assert macro_batch_size % self.cfg.rollout == 0
        assert self.cfg.rollout % self.cfg.recurrence == 0
        assert macro_batch_size % self.cfg.recurrence == 0

        samples = env_steps = 0
        for rollout in rollouts:
            samples += rollout['length']
            env_steps += rollout['env_steps']

        with timing.add_time('prepare'):
            buffer = self._prepare_train_buffer(rollouts, macro_batch_size,
                                                timing)
            self.experience_buffer_queue.put(
                (buffer, batch_size, samples, env_steps))

    def _process_rollouts(self, rollouts, timing):
        # batch_size can potentially change through PBT, so we should keep it the same and pass it around
        # using function arguments, instead of using global self.cfg

        batch_size = self.cfg.batch_size
        rollouts_in_macro_batch = self._macro_batch_size(
            batch_size) // self.cfg.rollout

        if len(rollouts) < rollouts_in_macro_batch:
            return rollouts

        discard_rollouts = 0
        policy_version = self.train_step
        for r in rollouts:
            rollout_min_version = r['t']['policy_version'].min().item()
            if policy_version - rollout_min_version >= self.cfg.max_policy_lag:
                discard_rollouts += 1
                self._mark_rollout_buffer_free(r)
            else:
                break

        if discard_rollouts > 0:
            log.warning(
                'Discarding %d old rollouts, cut by policy lag threshold %d (learner %d)',
                discard_rollouts,
                self.cfg.max_policy_lag,
                self.policy_id,
            )
            rollouts = rollouts[discard_rollouts:]
            self.num_discarded_rollouts += discard_rollouts

        if len(rollouts) >= rollouts_in_macro_batch:
            # take the next macro batch of rollouts from the front of the list
            rollouts_to_process = rollouts[:rollouts_in_macro_batch]
            rollouts = rollouts[rollouts_in_macro_batch:]

            self._process_macro_batch(rollouts_to_process, batch_size, timing)
            # log.info('Unprocessed rollouts: %d (%d samples)', len(rollouts), len(rollouts) * self.cfg.rollout)

        return rollouts

    def _get_minibatches(self, batch_size, experience_size):
        """Generating minibatches for training."""
        assert self.cfg.rollout % self.cfg.recurrence == 0
        assert experience_size % batch_size == 0, f'experience size: {experience_size}, batch size: {batch_size}'

        if self.cfg.num_batches_per_iteration == 1:
            # a single minibatch is actually the entire buffer, we don't need indices
            return [None]

        # indices that will start the mini-trajectories from the same episode (for bptt)
        indices = np.arange(0, experience_size, self.cfg.recurrence)
        indices = np.random.permutation(indices)

        # complete indices of mini trajectories, e.g. with recurrence==4: [4, 16] -> [4, 5, 6, 7, 16, 17, 18, 19]
        indices = [np.arange(i, i + self.cfg.recurrence) for i in indices]
        indices = np.concatenate(indices)

        assert len(indices) == experience_size

        num_minibatches = experience_size // batch_size
        minibatches = np.split(indices, num_minibatches)
        return minibatches

    @staticmethod
    def _get_minibatch(buffer, indices):
        if indices is None:
            # handle the case of a single batch, where the entire buffer is a minibatch
            return buffer

        mb = AttrDict()

        for item, x in buffer.items():
            if isinstance(x, (dict, OrderedDict)):
                mb[item] = AttrDict()
                for key, x_elem in x.items():
                    mb[item][key] = x_elem[indices]
            else:
                mb[item] = x[indices]

        return mb

    def _should_save_summaries(self):
        summaries_every_seconds = self.summary_rate_decay_seconds.at(
            self.train_step)
        if time.time() - self.last_summary_time < summaries_every_seconds:
            return False

        return True

    def _after_optimizer_step(self):
        """A hook to be called after each optimizer step."""
        self.train_step += 1
        self._maybe_save()

    def _maybe_save(self):
        if time.time() - self.last_saved_time >= self.cfg.save_every_sec or self.should_save_model:
            self._save()
            self.model_saved_event.set()
            self.should_save_model = False
            self.last_saved_time = time.time()

    @staticmethod
    def checkpoint_dir(cfg, policy_id):
        checkpoint_dir = join(experiment_dir(cfg=cfg),
                              f'checkpoint_p{policy_id}')
        return ensure_dir_exists(checkpoint_dir)

    @staticmethod
    def get_checkpoints(checkpoints_dir):
        checkpoints = glob.glob(join(checkpoints_dir, 'checkpoint_*'))
        return sorted(checkpoints)

    def _get_checkpoint_dict(self):
        checkpoint = {
            'train_step': self.train_step,
            'env_steps': self.env_steps,
            'model': self.actor_critic.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        return checkpoint

    def _save(self):
        checkpoint = self._get_checkpoint_dict()
        assert checkpoint is not None

        checkpoint_dir = self.checkpoint_dir(self.cfg, self.policy_id)
        tmp_filepath = join(checkpoint_dir, '.temp_checkpoint')
        checkpoint_name = f'checkpoint_{self.train_step:09d}_{self.env_steps}.pth'
        filepath = join(checkpoint_dir, checkpoint_name)
        log.info('Saving %s...', tmp_filepath)
        torch.save(checkpoint, tmp_filepath)
        log.info('Renaming %s to %s', tmp_filepath, filepath)
        os.rename(tmp_filepath, filepath)

        while len(self.get_checkpoints(
                checkpoint_dir)) > self.cfg.keep_checkpoints:
            oldest_checkpoint = self.get_checkpoints(checkpoint_dir)[0]
            if os.path.isfile(oldest_checkpoint):
                log.debug('Removing %s', oldest_checkpoint)
                os.remove(oldest_checkpoint)

        if self.cfg.save_milestones_sec > 0:
            # milestones enabled
            if time.time() - self.last_milestone_time >= self.cfg.save_milestones_sec:
                milestones_dir = ensure_dir_exists(
                    join(checkpoint_dir, 'milestones'))
                milestone_path = join(milestones_dir,
                                      f'{checkpoint_name}.milestone')
                log.debug('Saving a milestone %s', milestone_path)
                shutil.copy(filepath, milestone_path)
                self.last_milestone_time = time.time()

    @staticmethod
    def _policy_loss(ratio, adv, clip_ratio_low, clip_ratio_high):
        clipped_ratio = torch.clamp(ratio, clip_ratio_low, clip_ratio_high)
        loss_unclipped = ratio * adv
        loss_clipped = clipped_ratio * adv
        loss = torch.min(loss_unclipped, loss_clipped)
        loss = -loss.mean()

        return loss

    def _value_loss(self, new_values, old_values, target, clip_value):
        value_clipped = old_values + torch.clamp(new_values - old_values,
                                                 -clip_value, clip_value)
        value_original_loss = (new_values - target).pow(2)
        value_clipped_loss = (value_clipped - target).pow(2)
        value_loss = torch.max(value_original_loss, value_clipped_loss)
        value_loss = value_loss.mean()
        value_loss *= self.cfg.value_loss_coeff

        return value_loss

    def _prepare_observations(self, obs_tensors, gpu_buffer_obs):
        for d, gpu_d, k, v, _ in iter_dicts_recursively(
                obs_tensors, gpu_buffer_obs):
            device, dtype = self.actor_critic.device_and_type_for_input_tensor(
                k)
            tensor = v.detach().to(device, copy=True).type(dtype)
            gpu_d[k] = tensor

    def _copy_train_data_to_device(self, buffer):
        device_buffer = copy_dict_structure(buffer)

        for key, item in buffer.items():
            if key == 'obs':
                self._prepare_observations(item, device_buffer['obs'])
            else:
                device_tensor = item.detach().to(self.device,
                                                 copy=True,
                                                 non_blocking=True)
                device_buffer[key] = device_tensor.float()

        return device_buffer

    def _train(self, gpu_buffer, batch_size, experience_size, timing):
        with torch.no_grad():
            early_stopping_tolerance = 1e-6
            early_stop = False
            prev_epoch_actor_loss = 1e9
            epoch_actor_losses = []

            # V-trace parameters
            # noinspection PyArgumentList
            rho_hat = torch.Tensor([self.cfg.vtrace_rho])
            # noinspection PyArgumentList
            c_hat = torch.Tensor([self.cfg.vtrace_c])

            clip_ratio_high = 1.0 + self.cfg.ppo_clip_ratio  # e.g. 1.1
            # this still works with e.g. clip_ratio = 2, while PPO's 1-r would give negative ratio
            clip_ratio_low = 1.0 / clip_ratio_high
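            # e.g. with ppo_clip_ratio = 2: clip_ratio_high = 3.0, clip_ratio_low = 1/3,
            # whereas the standard PPO lower bound 1 - 2 = -1 would be meaningless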

            clip_value = self.cfg.ppo_clip_value
            gamma = self.cfg.gamma
            recurrence = self.cfg.recurrence

            if self.cfg.with_vtrace:
                assert recurrence == self.cfg.rollout and recurrence > 1, \
                    'V-trace requires recurrence and rollout to be equal'

            num_sgd_steps = 0

            stats_and_summaries = None
            if not self.with_training:
                return stats_and_summaries

        for epoch in range(self.cfg.ppo_epochs):
            with timing.add_time('epoch_init'):
                if early_stop or self.terminate:
                    break

                summary_this_epoch = force_summaries = False

                minibatches = self._get_minibatches(batch_size,
                                                    experience_size)

            for batch_num in range(len(minibatches)):
                with timing.add_time('minibatch_init'):
                    indices = minibatches[batch_num]

                    # current minibatch consisting of short trajectory segments with length == recurrence
                    mb = self._get_minibatch(gpu_buffer, indices)

                # calculate policy head outside of recurrent loop
                with timing.add_time('forward_head'):
                    head_outputs = self.actor_critic.forward_head(mb.obs)

                # initial rnn states
                with timing.add_time('bptt_initial'):
                    rnn_states = mb.rnn_states[::recurrence]
                    is_same_episode = 1.0 - mb.dones.unsqueeze(dim=1)

                # calculate RNN outputs for each timestep in a loop
                with timing.add_time('bptt'):
                    core_outputs = []
                    for i in range(recurrence):
                        # indices of head outputs corresponding to the current timestep
                        step_head_outputs = head_outputs[i::recurrence]

                        with timing.add_time('bptt_forward_core'):
                            core_output, rnn_states = self.actor_critic.forward_core(
                                step_head_outputs, rnn_states)
                            core_outputs.append(core_output)

                        if self.cfg.use_rnn:
                            # zero-out RNN states on the episode boundary
                            with timing.add_time('bptt_rnn_states'):
                                is_same_episode_step = is_same_episode[
                                    i::recurrence]
                                rnn_states = rnn_states * is_same_episode_step

                with timing.add_time('tail'):
                    # transform core outputs from [T, Batch, D] to [Batch, T, D] and then to [Batch x T, D]
                    # which is the same shape as the minibatch
                    core_outputs = torch.stack(core_outputs)

                    num_timesteps, num_trajectories = core_outputs.shape[:2]
                    assert num_timesteps == recurrence
                    assert num_timesteps * num_trajectories == batch_size
                    core_outputs = core_outputs.transpose(0, 1).reshape(
                        -1, *core_outputs.shape[2:])
                    assert core_outputs.shape[0] == head_outputs.shape[0]

                    # calculate policy tail outside of recurrent loop
                    result = self.actor_critic.forward_tail(
                        core_outputs, with_action_distribution=True)

                    action_distribution = result.action_distribution
                    log_prob_actions = action_distribution.log_prob(mb.actions)
                    ratio = torch.exp(log_prob_actions -
                                      mb.log_prob_actions)  # pi / pi_old

                    # super large/small values can cause numerical problems and are probably noise anyway
                    ratio = torch.clamp(ratio, 0.05, 20.0)

                    values = result.values.squeeze()

                # these computations are not part of the computation graph
                with torch.no_grad():
                    if self.cfg.with_vtrace:
                        ratios_cpu = ratio.cpu()
                        values_cpu = values.cpu()
                        rewards_cpu = mb.rewards.cpu()  # we only need this on CPU, potential minor optimization
                        dones_cpu = mb.dones.cpu()

                        vtrace_rho = torch.min(rho_hat, ratios_cpu)
                        vtrace_c = torch.min(c_hat, ratios_cpu)

                        vs = torch.zeros((num_trajectories * recurrence))
                        adv = torch.zeros((num_trajectories * recurrence))

                        next_values = (
                            values_cpu[recurrence - 1::recurrence] -
                            rewards_cpu[recurrence - 1::recurrence]) / gamma
                        next_vs = next_values

                        with timing.add_time('vtrace'):
                            for i in reversed(range(self.cfg.recurrence)):
                                rewards = rewards_cpu[i::recurrence]
                                dones = dones_cpu[i::recurrence]
                                not_done = 1.0 - dones
                                not_done_times_gamma = not_done * gamma

                                curr_values = values_cpu[i::recurrence]
                                curr_vtrace_rho = vtrace_rho[i::recurrence]
                                curr_vtrace_c = vtrace_c[i::recurrence]

                                delta_s = curr_vtrace_rho * (
                                    rewards + not_done_times_gamma *
                                    next_values - curr_values)
                                adv[i::recurrence] = curr_vtrace_rho * (
                                    rewards + not_done_times_gamma * next_vs -
                                    curr_values)
                                next_vs = curr_values + delta_s + not_done_times_gamma * curr_vtrace_c * (
                                    next_vs - next_values)
                                vs[i::recurrence] = next_vs

                                next_values = curr_values

                        targets = vs
                    else:
                        # using regular GAE
                        adv = mb.advantages
                        targets = mb.returns

                    adv_mean = adv.mean()
                    adv_std = adv.std()
                    adv = (adv - adv_mean) / max(
                        1e-3, adv_std)  # normalize advantage
                    adv = adv.to(self.device)

                with timing.add_time('losses'):
                    policy_loss = self._policy_loss(ratio, adv, clip_ratio_low,
                                                    clip_ratio_high)

                    entropy = action_distribution.entropy()
                    if self.cfg.entropy_loss_coeff > 0.0:
                        entropy_loss = -self.cfg.entropy_loss_coeff * entropy.mean()
                    else:
                        entropy_loss = 0.0

                    actor_loss = policy_loss + entropy_loss
                    epoch_actor_losses.append(actor_loss.item())

                    targets = targets.to(self.device)
                    old_values = mb.values
                    value_loss = self._value_loss(values, old_values, targets,
                                                  clip_value)
                    critic_loss = value_loss

                    loss = actor_loss + critic_loss

                    high_loss = 30.0
                    if abs(to_scalar(policy_loss)) > high_loss or abs(
                            to_scalar(value_loss)) > high_loss or abs(
                                to_scalar(entropy_loss)) > high_loss:
                        log.warning(
                            'High loss value: %.4f %.4f %.4f %.4f',
                            to_scalar(loss),
                            to_scalar(policy_loss),
                            to_scalar(value_loss),
                            to_scalar(entropy_loss),
                        )
                        force_summaries = True

                with timing.add_time('update'):
                    # update the weights
                    self.optimizer.zero_grad()
                    loss.backward()

                    if self.cfg.max_grad_norm > 0.0:
                        with timing.add_time('clip'):
                            torch.nn.utils.clip_grad_norm_(
                                self.actor_critic.parameters(),
                                self.cfg.max_grad_norm)

                    curr_policy_version = self.train_step  # policy version before the weight update
                    with self.policy_lock:
                        self.optimizer.step()

                    num_sgd_steps += 1

                with torch.no_grad():
                    with timing.add_time('after_optimizer'):
                        self._after_optimizer_step()

                        # collect and report summaries
                        with_summaries = self._should_save_summaries(
                        ) or force_summaries
                        if with_summaries and not summary_this_epoch:
                            stats_and_summaries = self._record_summaries(
                                AttrDict(locals()))
                            summary_this_epoch = True
                            force_summaries = False

            # end of an epoch
            # this will force policy update on the inference worker (policy worker)
            self.policy_versions[self.policy_id] = self.train_step

            new_epoch_actor_loss = np.mean(epoch_actor_losses)
            loss_delta_abs = abs(prev_epoch_actor_loss - new_epoch_actor_loss)
            if loss_delta_abs < early_stopping_tolerance:
                early_stop = True
                log.debug(
                    'Early stopping after %d epochs (%d sgd steps), loss delta %.7f',
                    epoch + 1,
                    num_sgd_steps,
                    loss_delta_abs,
                )
                break

            prev_epoch_actor_loss = new_epoch_actor_loss
            epoch_actor_losses = []

        return stats_and_summaries

    def _record_summaries(self, train_loop_vars):
        var = train_loop_vars

        self.last_summary_time = time.time()
        stats = AttrDict()

        grad_norm = sum(
            p.grad.data.norm(2).item()**2
            for p in self.actor_critic.parameters() if p.grad is not None)**0.5
        stats.grad_norm = grad_norm
        stats.loss = var.loss
        stats.value = var.result.values.mean()
        stats.entropy = var.action_distribution.entropy().mean()
        stats.policy_loss = var.policy_loss
        stats.value_loss = var.value_loss
        stats.entropy_loss = var.entropy_loss
        stats.adv_min = var.adv.min()
        stats.adv_max = var.adv.max()
        stats.adv_std = var.adv_std
        stats.max_abs_logprob = torch.abs(var.mb.action_logits).max()

        if hasattr(var.action_distribution, 'summaries'):
            stats.update(var.action_distribution.summaries())

        if var.epoch == self.cfg.ppo_epochs - 1 and var.batch_num == len(
                var.minibatches) - 1:
            # we collect these stats only for the last PPO batch, or every time if we're only doing one batch, IMPALA-style
            ratio_mean = torch.abs(1.0 - var.ratio).mean().detach()
            ratio_min = var.ratio.min().detach()
            ratio_max = var.ratio.max().detach()
            # log.debug('Learner %d ratio mean min max %.4f %.4f %.4f', self.policy_id, ratio_mean.cpu().item(), ratio_min.cpu().item(), ratio_max.cpu().item())

            value_delta = torch.abs(var.values - var.old_values)
            value_delta_avg, value_delta_max = value_delta.mean(
            ), value_delta.max()

            # calculate KL-divergence with the behaviour policy action distribution
            old_action_distribution = get_action_distribution(
                self.actor_critic.action_space,
                var.mb.action_logits,
            )
            kl_old = var.action_distribution.kl_divergence(
                old_action_distribution)
            kl_old_mean = kl_old.mean()

            stats.kl_divergence = kl_old_mean
            stats.value_delta = value_delta_avg
            stats.value_delta_max = value_delta_max
            stats.fraction_clipped = (
                (var.ratio < var.clip_ratio_low).float() +
                (var.ratio > var.clip_ratio_high).float()).mean()
            stats.ratio_mean = ratio_mean
            stats.ratio_min = ratio_min
            stats.ratio_max = ratio_max
            stats.num_sgd_steps = var.num_sgd_steps

        # this caused numerical issues on some versions of PyTorch with second moment reaching infinity
        adam_max_second_moment = 0.0
        for key, tensor_state in self.optimizer.state.items():
            adam_max_second_moment = max(
                tensor_state['exp_avg_sq'].max().item(),
                adam_max_second_moment)
        stats.adam_max_second_moment = adam_max_second_moment

        version_diff = var.curr_policy_version - var.mb.policy_version
        stats.version_diff_avg = version_diff.mean()
        stats.version_diff_min = version_diff.min()
        stats.version_diff_max = version_diff.max()

        for key, value in stats.items():
            stats[key] = to_scalar(value)

        return stats

    def _update_pbt(self):
        """To be called from the training loop, same thread that updates the model!"""
        with self.pbt_mutex:
            if self.load_policy_id is not None:
                assert self.cfg.with_pbt

                log.debug('Learner %d loads policy from %d', self.policy_id,
                          self.load_policy_id)
                self.load_from_checkpoint(self.load_policy_id)
                self.load_policy_id = None

            if self.new_cfg is not None:
                for key, value in self.new_cfg.items():
                    if self.cfg[key] != value:
                        log.debug(
                            'Learner %d replacing cfg parameter %r with new value %r',
                            self.policy_id, key, value)
                        self.cfg[key] = value

                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.cfg.learning_rate
                    param_group['betas'] = (self.cfg.adam_beta1,
                                            self.cfg.adam_beta2)
                    log.debug('Updated optimizer lr to value %.7f, betas: %r',
                              param_group['lr'], param_group['betas'])

                self.new_cfg = None

    @staticmethod
    def load_checkpoint(checkpoints, device):
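        """Load the state dict of the most recent checkpoint in the list.

        The last entry is treated as the latest checkpoint. Loading is retried a
        few times to survive spurious filesystem errors; returns None when the
        list is empty or every attempt fails, in which case the caller starts
        from scratch.
        """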
        if len(checkpoints) <= 0:
            log.warning('No checkpoints found')
            return None
        else:
            latest_checkpoint = checkpoints[-1]

            # extra safety mechanism to recover from spurious filesystem errors
            num_attempts = 3
            for attempt in range(num_attempts):
                try:
                    log.warning('Loading state from checkpoint %s...',
                                latest_checkpoint)
                    checkpoint_dict = torch.load(latest_checkpoint,
                                                 map_location=device)
                    return checkpoint_dict
                except Exception:
                    log.exception(
                        f'Could not load from checkpoint, attempt {attempt}')
            # all attempts failed; the caller treats None as "start from scratch"
            return None

    def _load_state(self, checkpoint_dict, load_progress=True):
        if load_progress:
            self.train_step = checkpoint_dict['train_step']
            self.env_steps = checkpoint_dict['env_steps']
        self.actor_critic.load_state_dict(checkpoint_dict['model'])
        self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
        log.info(
            'Loaded experiment state at training iteration %d, env step %d',
            self.train_step, self.env_steps)

    def init_model(self, timing):
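        # build the actor-critic model, move it to the training device and put
        # its parameters into shared memory so they can be read by other processes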
        self.actor_critic = create_actor_critic(self.cfg, self.obs_space,
                                                self.action_space, timing)
        self.actor_critic.model_to_device(self.device)
        self.actor_critic.share_memory()

    def load_from_checkpoint(self, policy_id):
        checkpoints = self.get_checkpoints(
            self.checkpoint_dir(self.cfg, policy_id))
        checkpoint_dict = self.load_checkpoint(checkpoints, self.device)
        if checkpoint_dict is None:
            log.debug('Did not load from checkpoint, starting from scratch!')
        else:
            log.debug('Loading model from checkpoint')

            # if we're replacing our policy with another policy (under PBT), let's not reload the env_steps
            load_progress = policy_id == self.policy_id
            self._load_state(checkpoint_dict, load_progress=load_progress)

    def initialize(self, timing):
        with timing.timeit('init'):
            # initialize the Torch modules
            if self.cfg.seed is None:
                log.info('Starting seed is not provided')
            else:
                log.info('Setting fixed seed %d', self.cfg.seed)
                torch.manual_seed(self.cfg.seed)
                np.random.seed(self.cfg.seed)

            # this does not help with a single experiment
            # but seems to do better when we're running more than one experiment in parallel
            torch.set_num_threads(1)

            if self.cfg.device == 'gpu':
                torch.backends.cudnn.benchmark = True

                # we should already see only one CUDA device, because of env vars
                assert torch.cuda.device_count() == 1
                self.device = torch.device('cuda', index=0)
            else:
                self.device = torch.device('cpu')
            self.init_model(timing)

            self.optimizer = torch.optim.Adam(
                self.actor_critic.parameters(),
                self.cfg.learning_rate,
                betas=(self.cfg.adam_beta1, self.cfg.adam_beta2),
                eps=self.cfg.adam_eps,
            )

            self.load_from_checkpoint(self.policy_id)

            # sync the very first version of the weights
            self._broadcast_model_weights()

        self.train_thread_initialized.set()

    def _process_training_data(self, data, timing, wait_stats=None):
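        # data is a (buffer, batch_size, samples, env_steps) tuple prepared from
        # the collected rollouts; train on it and push a stats dict to the report
        # queue (wait_stats, if provided, are the (avg, min, max) times the
        # training thread spent waiting for this data)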
        self.is_training = True

        buffer, batch_size, samples, env_steps = data
        assert samples == batch_size * self.cfg.num_batches_per_iteration

        self.env_steps += env_steps
        experience_size = buffer.rewards.shape[0]

        stats = dict(learner_env_steps=self.env_steps,
                     policy_id=self.policy_id)

        with timing.add_time('train'):
            discarding_rate = self._discarding_rate()

            self._update_pbt()

            train_stats = self._train(buffer, batch_size, experience_size,
                                      timing)

            if train_stats is not None:
                stats['train'] = train_stats

                if wait_stats is not None:
                    wait_avg, wait_min, wait_max = wait_stats
                    stats['train']['wait_avg'] = wait_avg
                    stats['train']['wait_min'] = wait_min
                    stats['train']['wait_max'] = wait_max

                stats['train']['discarded_rollouts'] = self.num_discarded_rollouts
                stats['train']['discarding_rate'] = discarding_rate

                stats['stats'] = memory_stats('learner', self.device)

        self.is_training = False

        try:
            self.report_queue.put(stats)
        except Full:
            log.warning(
                'Could not report training stats, the report queue is full!')

    def _train_loop(self):
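        """Body of the background training thread.

        Blocks on the experience buffer queue, trains on every batch that
        arrives, records how long it had to wait for data, and periodically
        releases cached GPU memory.
        """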
        timing = Timing()
        self.initialize(timing)

        wait_times = deque([], maxlen=self.cfg.num_workers)
        last_cache_cleanup = time.time()
        num_batches_processed = 0

        while not self.terminate:
            with timing.timeit('train_wait'):
                data = safe_get(self.experience_buffer_queue)

            if self.terminate:
                break

            wait_stats = None
            wait_times.append(timing.train_wait)

            if len(wait_times) >= wait_times.maxlen:
                wait_times_arr = np.asarray(wait_times)
                wait_avg = np.mean(wait_times_arr)
                wait_min, wait_max = wait_times_arr.min(), wait_times_arr.max()
                # log.debug(
                #     'Training thread had to wait %.5f s for the new experience buffer (avg %.5f)',
                #     timing.train_wait, wait_avg,
                # )
                wait_stats = (wait_avg, wait_min, wait_max)

            self._process_training_data(data, timing, wait_stats)
            num_batches_processed += 1

            if time.time() - last_cache_cleanup > 300.0 or (
                    not self.cfg.benchmark and num_batches_processed < 50):
                if self.cfg.device == 'gpu':
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
                last_cache_cleanup = time.time()

        time.sleep(0.3)
        log.info('Train loop timing: %s', timing)
        del self.actor_critic
        del self.device

    def _experience_collection_rate_stats(self):
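        # sample the discarded-rollout counter at most once per second so that
        # _discarding_rate() can estimate how quickly experience is being discarded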
        now = time.time()
        if now - self.discarded_experience_timer > 1.0:
            self.discarded_experience_timer = now
            self.discarded_experience_over_time.append(
                (now, self.num_discarded_rollouts))

    def _discarding_rate(self):
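        # rollouts discarded per second over the recorded window:
        # (last_count - first_count) / (last_time - first_time)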
        if len(self.discarded_experience_over_time) <= 1:
            return 0

        first = self.discarded_experience_over_time[0]
        last = self.discarded_experience_over_time[-1]
        delta_rollouts = last[1] - first[1]
        delta_time = last[0] - first[0]
        discarding_rate = delta_rollouts / (delta_time + EPS)
        return discarding_rate

    def _extract_rollouts(self, data):
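        # resolve each rollout descriptor to its slice of the shared rollout
        # tensors (indexed by worker, split, env, agent and trajectory buffer)
        # and attach the tensors to the descriptor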
        data = AttrDict(data)
        worker_idx, split_idx, traj_buffer_idx = data.worker_idx, data.split_idx, data.traj_buffer_idx

        rollouts = []
        for rollout_data in data.rollouts:
            env_idx = rollout_data['env_idx']
            agent_idx = rollout_data['agent_idx']
            tensors = self.rollout_tensors.index(
                (worker_idx, split_idx, env_idx, agent_idx, traj_buffer_idx))

            rollout_data['t'] = tensors
            rollout_data['worker_idx'] = worker_idx
            rollout_data['split_idx'] = split_idx
            rollout_data['traj_buffer_idx'] = traj_buffer_idx
            rollouts.append(AttrDict(rollout_data))

        return rollouts

    def _process_pbt_task(self, pbt_task):
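        # PBT messages only set flags under the mutex here; model loading and cfg
        # updates are applied later on the training thread (see _update_pbt), and
        # the save flag is consumed elsewhere in the training loop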
        task_type, data = pbt_task

        with self.pbt_mutex:
            if task_type == PbtTask.SAVE_MODEL:
                policy_id = data
                assert policy_id == self.policy_id
                self.should_save_model = True
            elif task_type == PbtTask.LOAD_MODEL:
                policy_id, new_policy_id = data
                assert policy_id == self.policy_id
                assert new_policy_id is not None
                self.load_policy_id = new_policy_id
            elif task_type == PbtTask.UPDATE_CFG:
                policy_id, new_cfg = data
                assert policy_id == self.policy_id
                self.new_cfg = new_cfg

    def _accumulated_too_much_experience(self, rollouts):
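        """Backpressure check used to pause experience collection.

        Estimates how many minibatches are currently on the learner (being
        trained on, queued for training, or still sitting in unprocessed
        rollouts) and compares that against the accumulation budget. With
        hypothetical values batch_size=1024 and rollout=32, for example, one
        minibatch corresponds to 32 rollouts.
        """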
        max_minibatches_to_accumulate = self.cfg.num_minibatches_to_accumulate
        if max_minibatches_to_accumulate == -1:
            # default value
            max_minibatches_to_accumulate = 2 * self.cfg.num_batches_per_iteration

        # allow the max batches to accumulate, plus the minibatches we're currently training on
        max_minibatches_on_learner = max_minibatches_to_accumulate + self.cfg.num_batches_per_iteration

        minibatches_currently_training = (
            int(self.is_training) * self.cfg.num_batches_per_iteration)

        rollouts_per_minibatch = self.cfg.batch_size / self.cfg.rollout

        # count contribution from unprocessed rollouts
        minibatches_currently_accumulated = len(rollouts) / rollouts_per_minibatch

        # count minibatches ready for training
        minibatches_currently_accumulated += (
            self.experience_buffer_queue.qsize() *
            self.cfg.num_batches_per_iteration)

        total_minibatches_on_learner = minibatches_currently_training + minibatches_currently_accumulated

        return total_minibatches_on_learner >= max_minibatches_on_learner

    def _run(self):
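        # main loop of the learner process: drain the task queue, throttle
        # experience collection when too much data has piled up, preprocess
        # rollouts, and (when not training in a background thread) run the
        # training step synchronously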
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        try:
            psutil.Process().nice(self.cfg.default_niceness)
        except psutil.AccessDenied:
            log.error('Low niceness requires sudo!')

        if self.cfg.device == 'gpu':
            cuda_envvars(self.policy_id)

        torch.multiprocessing.set_sharing_strategy('file_system')
        torch.set_num_threads(self.cfg.learner_main_loop_num_cores)

        timing = Timing()

        rollouts = []

        if self.train_in_background:
            self.training_thread.start()
        else:
            self.initialize(timing)
            log.error(
                'train_in_background set to False on learner %d! This is slow, use only for testing!',
                self.policy_id,
            )

        while not self.terminate:
            while True:
                try:
                    tasks = self.task_queue.get_many(timeout=0.005)

                    for task_type, data in tasks:
                        if task_type == TaskType.TRAIN:
                            with timing.add_time('extract'):
                                rollouts.extend(self._extract_rollouts(data))
                                # log.debug('Learner %d has %d rollouts', self.policy_id, len(rollouts))
                        elif task_type == TaskType.INIT:
                            self._init()
                        elif task_type == TaskType.TERMINATE:
                            time.sleep(0.3)
                            log.info('GPU learner timing: %s', timing)
                            self._terminate()
                            break
                        elif task_type == TaskType.PBT:
                            self._process_pbt_task(data)
                except Empty:
                    break

            if self._accumulated_too_much_experience(rollouts):
                # if we accumulated too much experience, signal the policy workers to stop experience collection
                if not self.stop_experience_collection[self.policy_id]:
                    log.debug(
                        'Learner %d accumulated too much experience, stop experience collection!',
                        self.policy_id)
                self.stop_experience_collection[self.policy_id] = True
            elif self.stop_experience_collection[self.policy_id]:
                # otherwise, resume the experience collection if it was stopped
                self.stop_experience_collection[self.policy_id] = False
                with self.resume_experience_collection_cv:
                    log.debug('Learner %d is resuming experience collection!',
                              self.policy_id)
                    self.resume_experience_collection_cv.notify_all()

            with torch.no_grad():
                rollouts = self._process_rollouts(rollouts, timing)

            if not self.train_in_background:
                while not self.experience_buffer_queue.empty():
                    training_data = self.experience_buffer_queue.get()
                    self._process_training_data(training_data, timing)

            self._experience_collection_rate_stats()

        if self.train_in_background:
            self.experience_buffer_queue.put(None)
            self.training_thread.join()

    def init(self):
        self.task_queue.put((TaskType.INIT, None))
        self.initialized_event.wait()

    def save_model(self, timeout=None):
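        # ask the learner process to checkpoint itself (delivered as a PBT task)
        # and wait, optionally with a timeout, until it confirms the save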
        self.model_saved_event.clear()
        save_task = (PbtTask.SAVE_MODEL, self.policy_id)
        self.task_queue.put((TaskType.PBT, save_task))
        log.debug('Wait while learner %d saves the model...', self.policy_id)
        if self.model_saved_event.wait(timeout=timeout):
            log.debug('Learner %d saved the model!', self.policy_id)
        else:
            log.warning('Model saving request timed out!')
        self.model_saved_event.clear()

    def close(self):
        self.task_queue.put((TaskType.TERMINATE, None))

    def join(self):
        join_or_kill(self.process)