Example #1
    def log_fn(self, stop_event: Event):
        try:
            self._super_create_loggers()
            self.resposne_queue.put({
                k: self.__dict__[k]
                for k in ["save_dir", "tb_logdir", "is_sweep"]
            })

            while True:
                try:
                    cmd = self.draw_queue.get(True, 0.1)
                except EmptyQueue:
                    if stop_event.is_set():
                        break
                    else:
                        continue

                self._super_log(*cmd)
                self.resposne_queue.put(True)
        except:
            print("Logger process crashed.")
            raise
        finally:
            print("Logger: syncing")
            if self.use_wandb:
                wandb.join()

            stop_event.set()
            print("Logger process terminating...")
Example #2
def run_in_process_group(world_size, filename, fn, inputs):
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
    processes = []
    q = Queue()
    wait_event = Event()

    # run the worker processes
    # for rank in range(world_size - 1):
    for rank in range(world_size):
        p = Process(
            target=init_and_run_process,
            args=(rank, world_size, filename, fn, inputs[rank], q, wait_event),
        )
        p.start()
        processes.append(p)

    # fetch the results from the queue before joining; the background processes
    # need to be alive if the queue contains tensors. See
    # https://discuss.pytorch.org/t/using-torch-tensor-over-multiprocessing-queue-process-fails/2847/3  # noqa: B950
    results = []
    for _ in range(len(processes)):
        results.append(q.get())

    wait_event.set()

    for p in processes:
        p.join()
    return results
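
The init_and_run_process helper is not shown in this example. A plausible sketch, assuming it initializes the process group from the shared file, runs fn, posts the result, and then waits on the Event so any tensors left on the queue stay alive until the parent has read them (the fn(rank, inputs) signature is an assumption):

# Hypothetical sketch of the missing init_and_run_process helper; the real one may differ.
import torch.distributed as dist


def init_and_run_process(rank, world_size, filename, fn, inputs, q, wait_event):
    dist.init_process_group(
        backend="gloo",
        init_method=f"file://{filename}",
        rank=rank,
        world_size=world_size,
    )
    result = fn(rank, inputs)
    q.put(result)
    # Keep this process (and the tensors it put on the queue) alive until the
    # parent has drained the queue and sets the event.
    wait_event.wait()
    dist.destroy_process_group()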
Example #3
class TensorEvent:
    """Basically a tuple of several torch.Tensors and a multiprocessing.Event.

  The Tensors can be used as "shared tensors" for passing intermediate tensors
  across processes.

  The Event should be used to signal that the consumer process has finished
  reading from the Tensors. When writing values to Tensors, the producer
  process should first check if Tensors are free, by calling event.wait(). If
  the Tensors are indeed free, then event.wait() will return at once. If not,
  then event.wait() will block until the consumer process calls event.set().
  Thus, the consumer should make sure that it calls event.set() AFTER the
  Tensors' contents have been copied to a safe area, such as the consumer's own
  local tensors.

  This class also includes an Array object living on shared memory, consisting
  of integers for indicating the valid region in each tensor. For example, if
  a process uses only 3 rows of a 4-row tensor, then the corresponding entry
  in the Array would be set to 3. Later, when values are read from the tensor
  by another process, that process would first check the Array value and know
  that it can ignore the final row.
  """
    def __init__(self, shapes, device, dtype=torch.float32):
        self.tensors = tuple(
            torch.empty(*shape, dtype=dtype, device=device)
            for shape in shapes)
        self.event = Event()
        self.event.set()
        self.valid_batch_sizes = Array('i', len(shapes))
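
A rough sketch of the producer/consumer handshake described in the docstring (which side calls event.clear() is an assumption here; the class itself only initializes the Event to the "free" state):

# Illustrative handshake around a TensorEvent instance `te`; not part of the class.
def producer_step(te, batch):
    te.event.wait()  # block until the consumer has released the shared tensors
    te.event.clear()  # mark the shared tensors as in use
    rows = batch.shape[0]
    te.tensors[0][:rows].copy_(batch)  # write into the shared tensor
    te.valid_batch_sizes[0] = rows  # record how many rows are valid


def consumer_step(te, local_out):
    rows = te.valid_batch_sizes[0]
    local_out[:rows].copy_(te.tensors[0][:rows])  # copy to a safe local tensor first
    te.event.set()  # only then release the shared tensors back to the producer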
Example #4
def test_routine_as_process():
    r = DummyRoutine()
    e = Event()
    r.stop_event = e
    r.as_process()
    r.start()
    e.set()
    r.runner.join()
Example #5
class DataLoader:
    def __init__(self, data_store, epochs=1):
        """
          Start the batchCreator and sampleCreator.
          Read the memory config file, and create the right number
          of processes.
        """
        self.ds = data_store
        # Event to stop batch creator and sample creator
        self.stop_sc = Event()
        self.stop_bc = Event()

        # Start separate processes for sample_creator(s)
        self.epochs = epochs
        self.sc = SampleCreator(self.ds,
                                event=self.stop_sc,
                                epochs=self.epochs,
                                sampled=self.ds.points_sampled)
        self.sc.start()

        # Start batch_creator(s)
        self.bc = BatchCreator(self.ds, self.stop_bc)
        self.bc.start()

    def get_next_batch(self):
        """
         Get the next batch from the queue
        """
        if self.ds.batch_creator_done.full():
            return None

        # Access batches from data_store.batches and return the batch
        try:
            batch = self.ds.batches.get()
            return batch
        except Exception as e:
            print(e)
            return None

    def stop_batch_creation(self):
        # Attempt to gracefully terminate the processes
        self.stop_bc.set()
        self.stop_sc.set()
        time.sleep(1)

        # Terminate the processes forcefully
        self.bc.terminate()
        self.sc.terminate()

        # Wait for child processes to end
        self.bc.join()
        self.sc.join()
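
A short usage sketch of this DataLoader, assuming a data_store object that provides the attributes the class relies on (batches, points_sampled, batch_creator_done) and a placeholder train_step:

# Illustrative driver loop; data_store construction is application-specific.
loader = DataLoader(data_store, epochs=2)
try:
    while True:
        batch = loader.get_next_batch()
        if batch is None:  # the batch creator signalled that it is done
            break
        train_step(batch)  # hypothetical consumer of the batch
finally:
    loader.stop_batch_creation()  # set the stop Events, then terminate/join the workers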
Example #6
def test_routine_no_runner():
    r = DummyRoutine(name="dummy")
    with pytest.raises(NoRunnerException):
        r.start()
    e = Event()
    r.stop_event = e
    r.as_thread()
    try:
        r.start()
    except NoRunnerException:
        pytest.fail("NoRunnerException was thrown...")
    e.set()
    r.runner.join()
Example #7
class Worker:
    def __init__(self, queue: Queue, collect_event: Event, actor_net, args,
                 seed):
        self._queue = queue
        self._collect_event = collect_event
        self._actor = actor_net
        self._args = args
        self.event = Event()
        self.seed = seed

    def run(self, episode):
        env = ManyUavEnv(self._args.agents, self.seed, self._args.reward_type)
        state = env.reset()
        while True:
            self.event.set()
            self._collect_event.wait()
            actions = []
            for i in range(self._args.agents):
                action = self._choose_action_with_exploration(state[i])
                actions.append(action)
            next_state, reward, done, info = env.step(
                np.array(actions) * self._args.action_bound)

            transition = []
            for i in range(self._args.agents):
                transition.append(
                    (state[i], actions[i], reward[i], next_state[i], done))

            self._queue.put(transition)

            state = next_state
            if done:
                state = env.reset()
                with episode.get_lock():
                    episode.value += 1
            if self._queue.qsize() >= self._args.update_interval:
                self._collect_event.clear()

    def _choose_action_with_exploration(self, state):
        action = self._choose_action(state)
        noise = np.random.normal(0, self._args.scale, (2, ))
        action = np.clip(action + noise, -1, 1)  # clip action between [-1, 1]
        return action

    def _choose_action(self, state):
        with torch.no_grad():
            state = torch.from_numpy(state).float().to(CHIP)
            action = self._actor(state)
        action = action.detach().cpu().numpy()
        return action
Example #8
def test_routine_crash_message(caplog):
    r = DummyCrashingRoutine()
    e = Event()
    r.stop_event = e
    r.as_thread()
    r.start()
    time.sleep(0.01)
    e.set()
    r.runner.join()
    logs_text_list = [record_tuple[2] for record_tuple in caplog.record_tuples]
    assert any("The routine has crashed" in log for log in logs_text_list)
Example #9
class Worker:
    def __init__(self, queue: Queue, collect_event: Event, actor_net, args):
        self._env = ManyUavEnv(1, True)
        self._queue = queue
        self._collect_event = collect_event
        self._actor = actor_net
        self._args = args
        self.event = Event()

    def run(self, episode):
        state = self._env.reset()

        while True:
            self.event.set()
            self._collect_event.wait()

            action = self._choose_action_with_exploration(state)
            next_state, reward, done, info = self._env.step(
                action * self._args.action_bound)

            self._queue.put((state, action, reward, next_state, done))

            state = next_state
            if done:
                state = self._env.reset()
                with episode.get_lock():
                    episode.value += 1

            if self._queue.qsize() >= self._args.update_interval:
                self._collect_event.clear()

    def _choose_action_with_exploration(self, state):
        action = self._choose_action(state)
        noise = np.random.normal(0, self._args.scale, (2, ))
        action = np.clip(action + noise, -1, 1)  # clip action between [-1, 1]
        return action

    def _choose_action(self, state):
        with torch.no_grad():
            state = torch.from_numpy(state).float()
            action = self._actor(state)
        action = action.detach().numpy()
        return action
Example #10
class WorkerManager:
    def __init__(self, n_workers, actor, args):
        self._now_episode = Value('i', 0)

        self.queue = Queue()
        self.collect_event = Event()

        self.worker = []
        for i in range(n_workers):
            self.worker.append(
                Worker(self.queue, self.collect_event, actor, args))
            time.sleep(1)

        self.process = [
            Process(target=self.worker[i].run, args=(self._now_episode, ))
            for i in range(n_workers)
        ]

        for p in self.process:
            p.start()
        print(f'Start {n_workers} workers.')

    def collect(self):
        result = []
        self.collect_event.set()
        while self.collect_event.is_set():
            # WAIT FOR DATA COLLECT END
            pass

        for w in self.worker:
            w.event.wait()

        while not self.queue.empty():
            result.append(self.queue.get())

        for w in self.worker:
            w.event.clear()
        return result

    def now_episode(self):
        value = self._now_episode.value
        return value
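
A hedged sketch of how such a manager might be driven from a training loop (actor, args, replay_buffer, and update_policy are placeholders, not part of the example). Note that collect() busy-waits on collect_event, so it only returns once the workers have filled the queue up to update_interval and cleared the event:

# Illustrative training loop around WorkerManager; the helpers are placeholders.
manager = WorkerManager(n_workers=4, actor=actor, args=args)
while manager.now_episode() < args.max_episodes:
    transitions = manager.collect()  # returns once all workers pause on collect_event
    for t in transitions:
        replay_buffer.add(t)  # hypothetical buffer
    update_policy(actor, replay_buffer)  # hypothetical learner update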
Example #11
class MultimodalPatchesCache(object):
    def __init__(self,
                 cache_dir,
                 dataset_dir,
                 dataset_list,
                 cuda,
                 batch_size=500,
                 num_workers=3,
                 renew_frequency=5,
                 rejection_radius_position=0,
                 numpatches=900,
                 numneg=3,
                 pos_thr=50.0,
                 reject=True,
                 mode='train',
                 rejection_radius=3000,
                 dist_type='3D',
                 patch_radius=None,
                 use_depth=False,
                 use_normals=False,
                 use_silhouettes=False,
                 color_jitter=False,
                 greyscale=False,
                 maxres=4096,
                 scale_jitter=False,
                 photo_jitter=False,
                 uniform_negatives=False,
                 needles=0,
                 render_only=False,
                 maxitems=200,
                 cache_once=False):
        super(MultimodalPatchesCache, self).__init__()
        self.cache_dir = cache_dir
        self.dataset_dir = dataset_dir
        #self.images_path = images_path
        self.dataset_list = dataset_list
        self.cuda = cuda
        self.batch_size = batch_size

        self.num_workers = num_workers
        self.renew_frequency = renew_frequency
        self.rejection_radius_position = rejection_radius_position
        self.numpatches = numpatches
        self.numneg = numneg
        self.pos_thr = pos_thr
        self.reject = reject
        self.mode = mode
        self.rejection_radius = rejection_radius
        self.dist_type = dist_type
        self.patch_radius = patch_radius
        self.use_depth = use_depth
        self.use_normals = use_normals
        self.use_silhouettes = use_silhouettes
        self.color_jitter = color_jitter
        self.greyscale = greyscale
        self.maxres = maxres
        self.scale_jitter = scale_jitter
        self.photo_jitter = photo_jitter
        self.uniform_negatives = uniform_negatives
        self.needles = needles
        self.render_only = render_only

        self.cache_done_lock = Lock()
        self.all_done = Value('B', 0)  # 0 is False
        self.cache_done = Value('B', 0)  # 0 is False

        self.wait_for_cache_builder = Event()
        # prepare for wait until initial cache is built
        self.wait_for_cache_builder.clear()
        self.cache_builder_resume = Event()

        self.maxitems = maxitems
        self.cache_once = cache_once

        if self.mode == 'eval':
            self.maxitems = -1
        self.cache_builder = Process(target=self.buildCache,
                                     args=[self.maxitems])
        self.current_cache_build = Value('B', 0)  # 0th cache
        self.current_cache_use = Value('B', 1)  # 1th cache

        self.cache_names = ["cache1", "cache2"]  # constant

        rebuild_cache = True
        if self.mode == 'eval':
            validation_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            if os.path.isdir(validation_dir):
                # we don't need to rebuild validation cache
                # TODO: check if cache is VALID
                rebuild_cache = False
        elif cache_once:
            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            if os.path.isdir(build_dataset_dir):
                # we don't need to rebuild the training cache if we are training
                # on a limited subset of the training set
                rebuild_cache = False

        if rebuild_cache:
            # clear the caches if they already exist
            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            if os.path.isdir(build_dataset_dir):
                shutil.rmtree(build_dataset_dir)
            use_dataset_dir = os.path.join(
                self.cache_dir, self.cache_names[self.current_cache_use.value])
            if os.path.isdir(use_dataset_dir):
                shutil.rmtree(use_dataset_dir)

            os.makedirs(build_dataset_dir)

            self.cache_builder_resume.set()
            self.cache_builder.start()

            # wait until initial cache is built
            # print("before wait to build")
            # print("wait for cache builder state",
            #       self.wait_for_cache_builder.is_set())
            self.wait_for_cache_builder.wait()
            # print("after wait to build")

        # we have been resumed
        if self.mode != 'eval' and (not self.cache_once):
            # for training, we can set up the cache builder to build
            # the second cache
            self.restart()
        else:
            # else for validation we don't need second cache
            # we just need to switch the built cache to the use cache in order
            # to use it
            tmp = self.current_cache_build.value
            self.current_cache_build.value = self.current_cache_use.value
            self.current_cache_use.value = tmp

        # initialization finished, now this dataset can be used

    def getCurrentCache(self):
        # Lock should not be needed - cache_done is not touched
        # and cache_len is read only for the cache in use, which should not
        # be touched by other threads
        # self.cache_done_lock.acquire()
        h5_dataset_filename = os.path.join(
            self.cache_dir, self.cache_names[self.current_cache_use.value])
        # self.cache_done_lock.release()
        return h5_dataset_filename

    def restart(self):
        # print("Restarting - waiting for lock...")
        self.cache_done_lock.acquire()
        # print("Restarting cached dataset...")
        if self.cache_done.value and (not self.cache_once):
            cache_changed = True
            tmp_cache_name = self.current_cache_use.value
            self.current_cache_use.value = self.current_cache_build.value
            self.current_cache_build.value = tmp_cache_name
            # clear the old cache if exists
            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            if os.path.isdir(build_dataset_dir):
                shutil.rmtree(build_dataset_dir)
            os.makedirs(build_dataset_dir)
            self.cache_done.value = 0  # 0 is False
            self.cache_builder_resume.set()
            # print("Switched cache to: ",
            #       self.cache_names[self.current_cache_use.value]
            # )
        else:
            cache_changed = False
            # print(
            #     "New cache not ready, continuing with old cache:",
            #     self.cache_names[self.current_cache_use.value]
            # )
        all_done_value = self.all_done.value
        self.cache_done_lock.release()
        # returns true if no more items are available to be loaded
        # this object should be destroyed and new dataset should be created
        # in order to start over.
        return cache_changed, all_done_value

    def buildCache(self, limit):
        # print("Building cache: ",
        #       self.cache_names[self.current_cache_build.value]
        # )
        dataset = MultimodalPatchesDatasetAll(
            self.dataset_dir,
            self.dataset_list,
            rejection_radius_position=self.rejection_radius_position,
            #self.images_path, list=train_sampled,
            numpatches=self.numpatches,
            numneg=self.numneg,
            pos_thr=self.pos_thr,
            reject=self.reject,
            mode=self.mode,
            rejection_radius=self.rejection_radius,
            dist_type=self.dist_type,
            patch_radius=self.patch_radius,
            use_depth=self.use_depth,
            use_normals=self.use_normals,
            use_silhouettes=self.use_silhouettes,
            color_jitter=self.color_jitter,
            greyscale=self.greyscale,
            maxres=self.maxres,
            scale_jitter=self.scale_jitter,
            photo_jitter=self.photo_jitter,
            uniform_negatives=self.uniform_negatives,
            needles=self.needles,
            render_only=self.render_only)
        n_triplets = len(dataset)

        if limit == -1:
            limit = n_triplets

        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=1,  # self.num_workers
            collate_fn=MultimodalPatchesCache.my_collate)

        qmaxsize = 15
        data_queue = JoinableQueue(maxsize=qmaxsize)

        # cannot load to cuda from background, therefore use cpu device
        preloader_resume = Event()
        preloader = Process(target=MultimodalPatchesCache.generateTrainingData,
                            args=(data_queue, dataset, dataloader,
                                  self.batch_size, qmaxsize, preloader_resume,
                                  True, True))
        preloader.do_run_generate = True
        preloader.start()
        preloader_resume.set()

        i_batch = 0
        data = data_queue.get()
        i_batch = data[0]

        counter = 0
        while i_batch != -1:

            self.cache_builder_resume.wait()

            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            batch_fname = os.path.join(build_dataset_dir,
                                       'batch_' + str(counter) + '.pt')

            # print("ibatch", i_batch,
            #        "___data___", data[3].shape, data[6].shape)

            anchor = data[1]
            pos = data[2]
            neg = data[3]
            anchor_r = data[4]
            pos_p = data[5]
            neg_p = data[6]
            c1 = data[7]
            c2 = data[8]
            cneg = data[9]
            id = data[10]

            if not (self.use_depth or self.use_normals):
                # no need to store image data as float, convert to uint8
                anchor = (anchor * 255.0).to(torch.uint8)
                pos = (pos * 255.0).to(torch.uint8)
                neg = (neg * 255.0).to(torch.uint8)
                anchor_r = (anchor_r * 255.0).to(torch.uint8)
                pos_p = (pos_p * 255.0).to(torch.uint8)
                neg_p = (neg_p * 255.0).to(torch.uint8)

            tosave = {
                'anchor': anchor,
                'pos': pos,
                'neg': neg,
                'anchor_r': anchor_r,
                'pos_p': pos_p,
                'neg_p': neg_p,
                'c1': c1,
                'c2': c2,
                'cneg': cneg,
                'id': id
            }

            try:
                torch.save(tosave, batch_fname)
                torch.load(batch_fname)
                counter += 1
            except Exception as e:
                print("Could not save ",
                      batch_fname,
                      ", due to:",
                      e,
                      "skipping...",
                      file=sys.stderr)
                if os.path.isfile(batch_fname):
                    os.remove(batch_fname)

            data_queue.task_done()

            if counter >= limit:
                self.cache_done_lock.acquire()
                self.cache_done.value = 1  # 1 is True
                self.cache_done_lock.release()
                counter = 0
                # sleep until calling thread wakes us
                self.cache_builder_resume.clear()
                # resume calling thread so that it can work
                self.wait_for_cache_builder.set()

            data = data_queue.get()
            i_batch = data[0]
            #print("ibatch", i_batch)

        data_queue.task_done()

        self.cache_done_lock.acquire()
        self.cache_done.value = 1  # 1 is True
        self.all_done.value = 1
        print("Cache done ALL")
        self.cache_done_lock.release()
        # resume calling thread so that it can work
        self.wait_for_cache_builder.set()
        preloader.join()
        preloader = None
        data_queue = None

    @staticmethod
    def loadBatch(sample_batched, mode, device, keep_all=False):
        if mode == 'eval':
            coords1 = sample_batched[6]
            coords2 = sample_batched[7]
            coords_neg = sample_batched[8]
            keep = sample_batched[10]
            item_id = sample_batched[11]
        else:
            coords1 = sample_batched[6]
            coords2 = sample_batched[7]
            coords_neg = sample_batched[8]
            keep = sample_batched[9]
            item_id = sample_batched[10]
        if keep_all:
            # requested to return the full batch
            batchsize = sample_batched[0].shape[0]
            keep = torch.ones(batchsize).byte()
        keep = keep.reshape(-1)
        keep = keep.bool()
        anchor = sample_batched[0]
        pos = sample_batched[1]
        neg = sample_batched[2]

        # swapped photo to render
        anchor_r = sample_batched[3]
        pos_p = sample_batched[4]
        neg_p = sample_batched[5]

        anchor = anchor[keep].to(device)
        pos = pos[keep].to(device)
        neg = neg[keep].to(device)

        anchor_r = anchor_r[keep]
        pos_p = pos_p[keep]
        neg_p = neg_p[keep]

        coords1 = coords1[keep]
        coords2 = coords2[keep]
        coords_neg = coords_neg[keep]
        item_id = item_id[keep]
        return anchor, pos, neg, anchor_r, pos_p, neg_p, coords1, coords2, \
            coords_neg, item_id

    @staticmethod
    def generateTrainingData(queue,
                             dataset,
                             dataloader,
                             batch_size,
                             qmaxsize,
                             resume,
                             shuffle=True,
                             disable_tqdm=False):
        local_buffer_a = []
        local_buffer_p = []
        local_buffer_n = []

        local_buffer_ar = []
        local_buffer_pp = []
        local_buffer_np = []

        local_buffer_c1 = []
        local_buffer_c2 = []
        local_buffer_cneg = []
        local_buffer_id = []
        nbatches = 10
        # cannot load to cuda in background process!
        device = torch.device('cpu')

        buffer_size = min(qmaxsize * batch_size, nbatches * batch_size)
        bidx = 0
        for i_batch, sample_batched in enumerate(dataloader):
            # tqdm(dataloader, disable=disable_tqdm)
            resume.wait()
            anchor, pos, neg, anchor_r, \
                pos_p, neg_p, c1, c2, cneg, id = \
                MultimodalPatchesCache.loadBatch(
                    sample_batched, dataset.mode, device
                )
            if anchor.shape[0] == 0:
                continue
            local_buffer_a.extend(list(anchor))  # [:current_batches]
            local_buffer_p.extend(list(pos))
            local_buffer_n.extend(list(neg))

            local_buffer_ar.extend(list(anchor_r))
            local_buffer_pp.extend(list(pos_p))
            local_buffer_np.extend(list(neg_p))

            local_buffer_c1.extend(list(c1))
            local_buffer_c2.extend(list(c2))
            local_buffer_cneg.extend(list(cneg))
            local_buffer_id.extend(list(id))
            if len(local_buffer_a) >= buffer_size:
                if shuffle:
                    local_buffer_a, local_buffer_p, local_buffer_n, \
                        local_buffer_ar, local_buffer_pp, local_buffer_np, \
                        local_buffer_c1, local_buffer_c2, local_buffer_cneg, \
                        local_buffer_id = sklearn.utils.shuffle(
                            local_buffer_a,
                            local_buffer_p,
                            local_buffer_n,
                            local_buffer_ar,
                            local_buffer_pp,
                            local_buffer_np,
                            local_buffer_c1,
                            local_buffer_c2,
                            local_buffer_cneg,
                            local_buffer_id
                        )
                curr_nbatches = int(np.floor(len(local_buffer_a) / batch_size))
                for i in range(0, curr_nbatches):
                    queue.put([
                        bidx,
                        torch.stack(local_buffer_a[:batch_size]),
                        torch.stack(local_buffer_p[:batch_size]),
                        torch.stack(local_buffer_n[:batch_size]),
                        torch.stack(local_buffer_ar[:batch_size]),
                        torch.stack(local_buffer_pp[:batch_size]),
                        torch.stack(local_buffer_np[:batch_size]),
                        torch.stack(local_buffer_c1[:batch_size]),
                        torch.stack(local_buffer_c2[:batch_size]),
                        torch.stack(local_buffer_cneg[:batch_size]),
                        torch.stack(local_buffer_id[:batch_size])
                    ])
                    del local_buffer_a[:batch_size]
                    del local_buffer_p[:batch_size]
                    del local_buffer_n[:batch_size]
                    del local_buffer_ar[:batch_size]
                    del local_buffer_pp[:batch_size]
                    del local_buffer_np[:batch_size]
                    del local_buffer_c1[:batch_size]
                    del local_buffer_c2[:batch_size]
                    del local_buffer_cneg[:batch_size]
                    del local_buffer_id[:batch_size]
                    bidx += 1
        remaining_batches = len(local_buffer_a) // batch_size
        for i in range(0, remaining_batches):
            queue.put([
                bidx,
                torch.stack(local_buffer_a[:batch_size]),
                torch.stack(local_buffer_p[:batch_size]),
                torch.stack(local_buffer_n[:batch_size]),
                torch.stack(local_buffer_ar[:batch_size]),
                torch.stack(local_buffer_pp[:batch_size]),
                torch.stack(local_buffer_np[:batch_size]),
                torch.stack(local_buffer_c1[:batch_size]),
                torch.stack(local_buffer_c2[:batch_size]),
                torch.stack(local_buffer_cneg[:batch_size]),
                torch.stack(local_buffer_id[:batch_size])
            ])
            del local_buffer_a[:batch_size]
            del local_buffer_p[:batch_size]
            del local_buffer_n[:batch_size]
            del local_buffer_ar[:batch_size]
            del local_buffer_pp[:batch_size]
            del local_buffer_np[:batch_size]
            del local_buffer_c1[:batch_size]
            del local_buffer_c2[:batch_size]
            del local_buffer_cneg[:batch_size]
            del local_buffer_id[:batch_size]
        ra = torch.randn(batch_size, 3, 64, 64)
        queue.put([-1, ra, ra, ra])
        queue.join()

    @staticmethod
    def my_collate(batch):
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
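
The two caches above are meant to be swapped between training epochs: the builder process fills one directory in the background while batches are read from the other, and restart() switches them. A hedged sketch of the intended call pattern (paths, the dataset list, and the per-batch training code are placeholders):

# Illustrative epoch loop over the cache; paths and train_on are placeholders.
import os
import torch

cache = MultimodalPatchesCache("cache_dir", "dataset_dir", dataset_list, cuda=False)
all_done = False
while not all_done:
    current_dir = cache.getCurrentCache()  # directory of the cache currently in use
    for fname in sorted(os.listdir(current_dir)):  # each file is one saved batch_*.pt
        batch = torch.load(os.path.join(current_dir, fname))
        train_on(batch)  # hypothetical training step
    # swap to the cache built in the background; the builder then refills the other one
    cache_changed, all_done = cache.restart()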
Example #12
class LearnerWorker:
    def __init__(
        self,
        worker_idx,
        policy_id,
        cfg,
        obs_space,
        action_space,
        report_queue,
        policy_worker_queues,
        shared_buffers,
        policy_lock,
        resume_experience_collection_cv,
    ):
        log.info('Initializing the learner %d for policy %d', worker_idx,
                 policy_id)

        self.worker_idx = worker_idx
        self.policy_id = policy_id

        self.cfg = cfg

        # PBT-related stuff
        self.should_save_model = True  # set to true if we need to save the model to disk on the next training iteration
        self.load_policy_id = None  # non-None when we need to replace our parameters with another policy's parameters
        self.pbt_mutex = threading.Lock()
        self.new_cfg = None  # non-None when we need to update the learning hyperparameters

        self.terminate = False

        self.obs_space = obs_space
        self.action_space = action_space

        self.rollout_tensors = shared_buffers.tensor_trajectories
        self.traj_tensors_available = shared_buffers.is_traj_tensor_available
        self.policy_versions = shared_buffers.policy_versions
        self.stop_experience_collection = shared_buffers.stop_experience_collection

        self.device = None
        self.actor_critic = None
        self.optimizer = None
        self.policy_lock = policy_lock
        self.resume_experience_collection_cv = resume_experience_collection_cv

        self.task_queue = faster_fifo.Queue()
        self.report_queue = report_queue

        self.initialized_event = MultiprocessingEvent()
        self.initialized_event.clear()

        self.model_saved_event = MultiprocessingEvent()
        self.model_saved_event.clear()

        # queues corresponding to policy workers using the same policy
        # we send weight updates via these queues
        self.policy_worker_queues = policy_worker_queues

        self.experience_buffer_queue = Queue()

        self.tensor_batch_pool = ObjectPool()
        self.tensor_batcher = TensorBatcher(self.tensor_batch_pool)

        self.with_training = True  # set to False for debugging no-training regime
        self.train_in_background = self.cfg.train_in_background_thread  # set to False for debugging

        self.training_thread = Thread(
            target=self._train_loop) if self.train_in_background else None
        self.train_thread_initialized = threading.Event()

        self.is_training = False

        self.train_step = self.env_steps = 0

        # decay rate at which summaries are collected
        # save summaries every 20 seconds in the beginning, but decay to every 4 minutes in the limit, because we
        # do not need frequent summaries for longer experiments
        self.summary_rate_decay_seconds = LinearDecay([(0, 20), (100000, 120),
                                                       (1000000, 240)])
        self.last_summary_time = 0

        self.last_saved_time = self.last_milestone_time = 0

        self.discarded_experience_over_time = deque([], maxlen=30)
        self.discarded_experience_timer = time.time()
        self.num_discarded_rollouts = 0

        self.process = Process(target=self._run, daemon=True)

    def start_process(self):
        self.process.start()

    def _init(self):
        log.info('Waiting for the learner to initialize...')
        self.train_thread_initialized.wait()
        log.info('Learner %d initialized', self.worker_idx)
        self.initialized_event.set()

    def _terminate(self):
        self.terminate = True

    def _broadcast_model_weights(self):
        state_dict = self.actor_critic.state_dict()
        policy_version = self.train_step
        log.debug('Broadcast model weights for model version %d',
                  policy_version)
        model_state = (policy_version, state_dict)
        for q in self.policy_worker_queues:
            q.put((TaskType.INIT_MODEL, model_state))

    def _calculate_gae(self, buffer):
        """
        Calculate advantages using Generalized Advantage Estimation.
        This is a leftover from the previous version of the algorithm.
        Perhaps it should be re-implemented in PyTorch tensors, similar to V-trace, for uniformity.
        """

        rewards = torch.stack(buffer.rewards).numpy().squeeze()  # [E, T]
        dones = torch.stack(buffer.dones).numpy().squeeze()  # [E, T]
        values_arr = torch.stack(buffer.values).numpy().squeeze()  # [E, T]

        # calculating fake values for the last step in the rollout
        # this will make sure that advantage of the very last action is always zero
        values = []
        for i in range(len(values_arr)):
            last_value, last_reward = values_arr[i][-1], rewards[i, -1]
            next_value = (last_value - last_reward) / self.cfg.gamma
            values.append(list(values_arr[i]))
            values[i].append(float(next_value))  # [T] -> [T+1]

        # calculating returns and GAE
        rewards = rewards.transpose((1, 0))  # [E, T] -> [T, E]
        dones = dones.transpose((1, 0))  # [E, T] -> [T, E]
        values = np.asarray(values).transpose((1, 0))  # [E, T+1] -> [T+1, E]

        advantages, returns = calculate_gae(rewards, dones, values,
                                            self.cfg.gamma,
                                            self.cfg.gae_lambda)

        # transpose tensors back to [E, T] before creating a single experience buffer
        buffer.advantages = advantages.transpose((1, 0))  # [T, E] -> [E, T]
        buffer.returns = returns.transpose((1, 0))  # [T, E] -> [E, T]
        buffer.returns = buffer.returns[:, :,
                                        np.newaxis]  # [E, T] -> [E, T, 1]

        buffer.advantages = [torch.tensor(buffer.advantages).reshape(-1)]
        buffer.returns = [torch.tensor(buffer.returns).reshape(-1)]

        return buffer
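
    # For reference, the calculate_gae helper called above is not shown in this
    # example; the standard GAE recursion it presumably implements looks roughly
    # like the sketch below (illustrative, not the project's actual implementation):
    #
    #     def calculate_gae(rewards, dones, values, gamma, gae_lambda):
    #         # rewards/dones: [T, E]; values: [T+1, E]
    #         advantages = np.zeros_like(rewards)
    #         last_adv = np.zeros(rewards.shape[1])
    #         for t in reversed(range(rewards.shape[0])):
    #             not_done = 1.0 - dones[t]
    #             delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
    #             last_adv = delta + gamma * gae_lambda * not_done * last_adv
    #             advantages[t] = last_adv
    #         returns = advantages + values[:-1]
    #         return advantages, returns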

    def _mark_rollout_buffer_free(self, rollout):
        r = rollout
        self.traj_tensors_available[r.worker_idx,
                                    r.split_idx][r.env_idx, r.agent_idx,
                                                 r.traj_buffer_idx] = 1

    def _prepare_train_buffer(self, rollouts, macro_batch_size, timing):
        trajectories = [AttrDict(r['t']) for r in rollouts]

        with timing.add_time('buffers'):
            buffer = AttrDict()

            # by the end of this loop the buffer is a dictionary containing lists of numpy arrays
            for i, t in enumerate(trajectories):
                for key, x in t.items():
                    if key not in buffer:
                        buffer[key] = []
                    buffer[key].append(x)

            # convert lists of dict observations to a single dictionary of lists
            for key, x in buffer.items():
                if isinstance(x[0], (dict, OrderedDict)):
                    buffer[key] = list_of_dicts_to_dict_of_lists(x)

        if not self.cfg.with_vtrace:
            with timing.add_time('calc_gae'):
                buffer = self._calculate_gae(buffer)

        with timing.add_time('batching'):
            # concatenate rollouts from different workers into a single batch efficiently
            # that is, if we already have memory for the buffers allocated, we can just copy the data into
            # existing cached tensors instead of creating new ones. This is a performance optimization.
            use_pinned_memory = self.cfg.device == 'gpu'
            buffer = self.tensor_batcher.cat(buffer, macro_batch_size,
                                             use_pinned_memory, timing)

        with timing.add_time('buff_ready'):
            for r in rollouts:
                self._mark_rollout_buffer_free(r)

        with timing.add_time('tensors_gpu_float'):
            device_buffer = self._copy_train_data_to_device(buffer)

        with timing.add_time('squeeze'):
            # will squeeze actions only in simple categorical case
            tensors_to_squeeze = [
                'actions', 'log_prob_actions', 'policy_version', 'values',
                'rewards', 'dones'
            ]
            for tensor_name in tensors_to_squeeze:
                device_buffer[tensor_name].squeeze_()

        # we no longer need the cached buffer, and can put it back into the pool
        self.tensor_batch_pool.put(buffer)
        return device_buffer

    def _macro_batch_size(self, batch_size):
        return self.cfg.num_batches_per_iteration * batch_size

    def _process_macro_batch(self, rollouts, batch_size, timing):
        macro_batch_size = self._macro_batch_size(batch_size)

        assert macro_batch_size % self.cfg.rollout == 0
        assert self.cfg.rollout % self.cfg.recurrence == 0
        assert macro_batch_size % self.cfg.recurrence == 0

        samples = env_steps = 0
        for rollout in rollouts:
            samples += rollout['length']
            env_steps += rollout['env_steps']

        with timing.add_time('prepare'):
            buffer = self._prepare_train_buffer(rollouts, macro_batch_size,
                                                timing)
            self.experience_buffer_queue.put(
                (buffer, batch_size, samples, env_steps))

    def _process_rollouts(self, rollouts, timing):
        # batch_size can potentially change through PBT, so we should keep it the same and pass it around
        # using function arguments, instead of using global self.cfg

        batch_size = self.cfg.batch_size
        rollouts_in_macro_batch = self._macro_batch_size(
            batch_size) // self.cfg.rollout

        if len(rollouts) < rollouts_in_macro_batch:
            return rollouts

        discard_rollouts = 0
        policy_version = self.train_step
        for r in rollouts:
            rollout_min_version = r['t']['policy_version'].min().item()
            if policy_version - rollout_min_version >= self.cfg.max_policy_lag:
                discard_rollouts += 1
                self._mark_rollout_buffer_free(r)
            else:
                break

        if discard_rollouts > 0:
            log.warning(
                'Discarding %d old rollouts, cut by policy lag threshold %d (learner %d)',
                discard_rollouts,
                self.cfg.max_policy_lag,
                self.policy_id,
            )
            rollouts = rollouts[discard_rollouts:]
            self.num_discarded_rollouts += discard_rollouts

        if len(rollouts) >= rollouts_in_macro_batch:
            # process newest rollouts
            rollouts_to_process = rollouts[:rollouts_in_macro_batch]
            rollouts = rollouts[rollouts_in_macro_batch:]

            self._process_macro_batch(rollouts_to_process, batch_size, timing)
            # log.info('Unprocessed rollouts: %d (%d samples)', len(rollouts), len(rollouts) * self.cfg.rollout)

        return rollouts

    def _get_minibatches(self, batch_size, experience_size):
        """Generating minibatches for training."""
        assert self.cfg.rollout % self.cfg.recurrence == 0
        assert experience_size % batch_size == 0, f'experience size: {experience_size}, batch size: {batch_size}'

        if self.cfg.num_batches_per_iteration == 1:
            return [
                None
            ]  # single minibatch is actually the entire buffer, we don't need indices

        # indices that will start the mini-trajectories from the same episode (for bptt)
        indices = np.arange(0, experience_size, self.cfg.recurrence)
        indices = np.random.permutation(indices)

        # complete indices of mini trajectories, e.g. with recurrence==4: [4, 16] -> [4, 5, 6, 7, 16, 17, 18, 19]
        indices = [np.arange(i, i + self.cfg.recurrence) for i in indices]
        indices = np.concatenate(indices)

        assert len(indices) == experience_size

        num_minibatches = experience_size // batch_size
        minibatches = np.split(indices, num_minibatches)
        return minibatches

    @staticmethod
    def _get_minibatch(buffer, indices):
        if indices is None:
            # handle the case of a single batch, where the entire buffer is a minibatch
            return buffer

        mb = AttrDict()

        for item, x in buffer.items():
            if isinstance(x, (dict, OrderedDict)):
                mb[item] = AttrDict()
                for key, x_elem in x.items():
                    mb[item][key] = x_elem[indices]
            else:
                mb[item] = x[indices]

        return mb

    def _should_save_summaries(self):
        summaries_every_seconds = self.summary_rate_decay_seconds.at(
            self.train_step)
        if time.time() - self.last_summary_time < summaries_every_seconds:
            return False

        return True

    def _after_optimizer_step(self):
        """A hook to be called after each optimizer step."""
        self.train_step += 1
        self._maybe_save()

    def _maybe_save(self):
        if time.time(
        ) - self.last_saved_time >= self.cfg.save_every_sec or self.should_save_model:
            self._save()
            self.model_saved_event.set()
            self.should_save_model = False
            self.last_saved_time = time.time()

    @staticmethod
    def checkpoint_dir(cfg, policy_id):
        checkpoint_dir = join(experiment_dir(cfg=cfg),
                              f'checkpoint_p{policy_id}')
        return ensure_dir_exists(checkpoint_dir)

    @staticmethod
    def get_checkpoints(checkpoints_dir):
        checkpoints = glob.glob(join(checkpoints_dir, 'checkpoint_*'))
        return sorted(checkpoints)

    def _get_checkpoint_dict(self):
        checkpoint = {
            'train_step': self.train_step,
            'env_steps': self.env_steps,
            'model': self.actor_critic.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        return checkpoint

    def _save(self):
        checkpoint = self._get_checkpoint_dict()
        assert checkpoint is not None

        checkpoint_dir = self.checkpoint_dir(self.cfg, self.policy_id)
        tmp_filepath = join(checkpoint_dir, '.temp_checkpoint')
        checkpoint_name = f'checkpoint_{self.train_step:09d}_{self.env_steps}.pth'
        filepath = join(checkpoint_dir, checkpoint_name)
        log.info('Saving %s...', tmp_filepath)
        torch.save(checkpoint, tmp_filepath)
        log.info('Renaming %s to %s', tmp_filepath, filepath)
        os.rename(tmp_filepath, filepath)

        while len(self.get_checkpoints(
                checkpoint_dir)) > self.cfg.keep_checkpoints:
            oldest_checkpoint = self.get_checkpoints(checkpoint_dir)[0]
            if os.path.isfile(oldest_checkpoint):
                log.debug('Removing %s', oldest_checkpoint)
                os.remove(oldest_checkpoint)

        if self.cfg.save_milestones_sec > 0:
            # milestones enabled
            if time.time(
            ) - self.last_milestone_time >= self.cfg.save_milestones_sec:
                milestones_dir = ensure_dir_exists(
                    join(checkpoint_dir, 'milestones'))
                milestone_path = join(milestones_dir,
                                      f'{checkpoint_name}.milestone')
                log.debug('Saving a milestone %s', milestone_path)
                shutil.copy(filepath, milestone_path)
                self.last_milestone_time = time.time()

    @staticmethod
    def _policy_loss(ratio, adv, clip_ratio_low, clip_ratio_high):
        clipped_ratio = torch.clamp(ratio, clip_ratio_low, clip_ratio_high)
        loss_unclipped = ratio * adv
        loss_clipped = clipped_ratio * adv
        loss = torch.min(loss_unclipped, loss_clipped)
        loss = -loss.mean()

        return loss

    def _value_loss(self, new_values, old_values, target, clip_value):
        value_clipped = old_values + torch.clamp(new_values - old_values,
                                                 -clip_value, clip_value)
        value_original_loss = (new_values - target).pow(2)
        value_clipped_loss = (value_clipped - target).pow(2)
        value_loss = torch.max(value_original_loss, value_clipped_loss)
        value_loss = value_loss.mean()
        value_loss *= self.cfg.value_loss_coeff

        return value_loss

    def _prepare_observations(self, obs_tensors, gpu_buffer_obs):
        for d, gpu_d, k, v, _ in iter_dicts_recursively(
                obs_tensors, gpu_buffer_obs):
            device, dtype = self.actor_critic.device_and_type_for_input_tensor(
                k)
            tensor = v.detach().to(device, copy=True).type(dtype)
            gpu_d[k] = tensor

    def _copy_train_data_to_device(self, buffer):
        device_buffer = copy_dict_structure(buffer)

        for key, item in buffer.items():
            if key == 'obs':
                self._prepare_observations(item, device_buffer['obs'])
            else:
                device_tensor = item.detach().to(self.device,
                                                 copy=True,
                                                 non_blocking=True)
                device_buffer[key] = device_tensor.float()

        return device_buffer

    def _train(self, gpu_buffer, batch_size, experience_size, timing):
        with torch.no_grad():
            early_stopping_tolerance = 1e-6
            early_stop = False
            prev_epoch_actor_loss = 1e9
            epoch_actor_losses = []

            # V-trace parameters
            # noinspection PyArgumentList
            rho_hat = torch.Tensor([self.cfg.vtrace_rho])
            # noinspection PyArgumentList
            c_hat = torch.Tensor([self.cfg.vtrace_c])

            clip_ratio_high = 1.0 + self.cfg.ppo_clip_ratio  # e.g. 1.1
            # this still works with e.g. clip_ratio = 2, while PPO's 1-r would give negative ratio
            clip_ratio_low = 1.0 / clip_ratio_high

            clip_value = self.cfg.ppo_clip_value
            gamma = self.cfg.gamma
            recurrence = self.cfg.recurrence

            if self.cfg.with_vtrace:
                assert recurrence == self.cfg.rollout and recurrence > 1, \
                    'V-trace requires recurrence and rollout to be equal'

            num_sgd_steps = 0

            stats_and_summaries = None
            if not self.with_training:
                return stats_and_summaries

        for epoch in range(self.cfg.ppo_epochs):
            with timing.add_time('epoch_init'):
                if early_stop or self.terminate:
                    break

                summary_this_epoch = force_summaries = False

                minibatches = self._get_minibatches(batch_size,
                                                    experience_size)

            for batch_num in range(len(minibatches)):
                with timing.add_time('minibatch_init'):
                    indices = minibatches[batch_num]

                    # current minibatch consisting of short trajectory segments with length == recurrence
                    mb = self._get_minibatch(gpu_buffer, indices)

                # calculate policy head outside of recurrent loop
                with timing.add_time('forward_head'):
                    head_outputs = self.actor_critic.forward_head(mb.obs)

                # initial rnn states
                with timing.add_time('bptt_initial'):
                    rnn_states = mb.rnn_states[::recurrence]
                    is_same_episode = 1.0 - mb.dones.unsqueeze(dim=1)

                # calculate RNN outputs for each timestep in a loop
                with timing.add_time('bptt'):
                    core_outputs = []
                    for i in range(recurrence):
                        # indices of head outputs corresponding to the current timestep
                        step_head_outputs = head_outputs[i::recurrence]

                        with timing.add_time('bptt_forward_core'):
                            core_output, rnn_states = self.actor_critic.forward_core(
                                step_head_outputs, rnn_states)
                            core_outputs.append(core_output)

                        if self.cfg.use_rnn:
                            # zero-out RNN states on the episode boundary
                            with timing.add_time('bptt_rnn_states'):
                                is_same_episode_step = is_same_episode[
                                    i::recurrence]
                                rnn_states = rnn_states * is_same_episode_step

                with timing.add_time('tail'):
                    # transform core outputs from [T, Batch, D] to [Batch, T, D] and then to [Batch x T, D]
                    # which is the same shape as the minibatch
                    core_outputs = torch.stack(core_outputs)

                    num_timesteps, num_trajectories = core_outputs.shape[:2]
                    assert num_timesteps == recurrence
                    assert num_timesteps * num_trajectories == batch_size
                    core_outputs = core_outputs.transpose(0, 1).reshape(
                        -1, *core_outputs.shape[2:])
                    assert core_outputs.shape[0] == head_outputs.shape[0]

                    # calculate policy tail outside of recurrent loop
                    result = self.actor_critic.forward_tail(
                        core_outputs, with_action_distribution=True)

                    action_distribution = result.action_distribution
                    log_prob_actions = action_distribution.log_prob(mb.actions)
                    ratio = torch.exp(log_prob_actions -
                                      mb.log_prob_actions)  # pi / pi_old

                    # super large/small values can cause numerical problems and are probably noise anyway
                    ratio = torch.clamp(ratio, 0.05, 20.0)

                    values = result.values.squeeze()

                with torch.no_grad(
                ):  # these computations are not part of the computation graph
                    if self.cfg.with_vtrace:
                        ratios_cpu = ratio.cpu()
                        values_cpu = values.cpu()
                        rewards_cpu = mb.rewards.cpu(
                        )  # we only need this on CPU, potential minor optimization
                        dones_cpu = mb.dones.cpu()

                        vtrace_rho = torch.min(rho_hat, ratios_cpu)
                        vtrace_c = torch.min(c_hat, ratios_cpu)

                        vs = torch.zeros((num_trajectories * recurrence))
                        adv = torch.zeros((num_trajectories * recurrence))

                        next_values = (
                            values_cpu[recurrence - 1::recurrence] -
                            rewards_cpu[recurrence - 1::recurrence]) / gamma
                        next_vs = next_values

                        with timing.add_time('vtrace'):
                            for i in reversed(range(self.cfg.recurrence)):
                                rewards = rewards_cpu[i::recurrence]
                                dones = dones_cpu[i::recurrence]
                                not_done = 1.0 - dones
                                not_done_times_gamma = not_done * gamma

                                curr_values = values_cpu[i::recurrence]
                                curr_vtrace_rho = vtrace_rho[i::recurrence]
                                curr_vtrace_c = vtrace_c[i::recurrence]

                                delta_s = curr_vtrace_rho * (
                                    rewards + not_done_times_gamma *
                                    next_values - curr_values)
                                adv[i::recurrence] = curr_vtrace_rho * (
                                    rewards + not_done_times_gamma * next_vs -
                                    curr_values)
                                next_vs = curr_values + delta_s + not_done_times_gamma * curr_vtrace_c * (
                                    next_vs - next_values)
                                vs[i::recurrence] = next_vs

                                next_values = curr_values

                        targets = vs
                    else:
                        # using regular GAE
                        adv = mb.advantages
                        targets = mb.returns

                    adv_mean = adv.mean()
                    adv_std = adv.std()
                    adv = (adv - adv_mean) / max(
                        1e-3, adv_std)  # normalize advantage
                    adv = adv.to(self.device)

                with timing.add_time('losses'):
                    policy_loss = self._policy_loss(ratio, adv, clip_ratio_low,
                                                    clip_ratio_high)

                    entropy = action_distribution.entropy()
                    if self.cfg.entropy_loss_coeff > 0.0:
                        entropy_loss = -self.cfg.entropy_loss_coeff * entropy.mean(
                        )
                    else:
                        entropy_loss = 0.0

                    actor_loss = policy_loss + entropy_loss
                    epoch_actor_losses.append(actor_loss.item())

                    targets = targets.to(self.device)
                    old_values = mb.values
                    value_loss = self._value_loss(values, old_values, targets,
                                                  clip_value)
                    critic_loss = value_loss

                    loss = actor_loss + critic_loss

                    high_loss = 30.0
                    if abs(to_scalar(policy_loss)) > high_loss or abs(
                            to_scalar(value_loss)) > high_loss or abs(
                                to_scalar(entropy_loss)) > high_loss:
                        log.warning(
                            'High loss value: %.4f %.4f %.4f %.4f',
                            to_scalar(loss),
                            to_scalar(policy_loss),
                            to_scalar(value_loss),
                            to_scalar(entropy_loss),
                        )
                        force_summaries = True

                with timing.add_time('update'):
                    # update the weights
                    self.optimizer.zero_grad()
                    loss.backward()

                    if self.cfg.max_grad_norm > 0.0:
                        with timing.add_time('clip'):
                            torch.nn.utils.clip_grad_norm_(
                                self.actor_critic.parameters(),
                                self.cfg.max_grad_norm)

                    curr_policy_version = self.train_step  # policy version before the weight update
                    with self.policy_lock:
                        self.optimizer.step()

                    num_sgd_steps += 1

                with torch.no_grad():
                    with timing.add_time('after_optimizer'):
                        self._after_optimizer_step()

                        # collect and report summaries
                        with_summaries = self._should_save_summaries(
                        ) or force_summaries
                        if with_summaries and not summary_this_epoch:
                            stats_and_summaries = self._record_summaries(AttrDict(locals()))
                            summary_this_epoch = True
                            force_summaries = False

            # end of an epoch
            # this will force policy update on the inference worker (policy worker)
            self.policy_versions[self.policy_id] = self.train_step

            new_epoch_actor_loss = np.mean(epoch_actor_losses)
            loss_delta_abs = abs(prev_epoch_actor_loss - new_epoch_actor_loss)
            if loss_delta_abs < early_stopping_tolerance:
                early_stop = True
                log.debug(
                    'Early stopping after %d epochs (%d sgd steps), loss delta %.7f',
                    epoch + 1,
                    num_sgd_steps,
                    loss_delta_abs,
                )
                break

            prev_epoch_actor_loss = new_epoch_actor_loss
            epoch_actor_losses = []

        return stats_and_summaries

    def _record_summaries(self, train_loop_vars):
        var = train_loop_vars

        self.last_summary_time = time.time()
        stats = AttrDict()

        grad_norm = sum(
            p.grad.data.norm(2).item()**2
            for p in self.actor_critic.parameters() if p.grad is not None)**0.5
        stats.grad_norm = grad_norm
        stats.loss = var.loss
        stats.value = var.result.values.mean()
        stats.entropy = var.action_distribution.entropy().mean()
        stats.policy_loss = var.policy_loss
        stats.value_loss = var.value_loss
        stats.entropy_loss = var.entropy_loss
        stats.adv_min = var.adv.min()
        stats.adv_max = var.adv.max()
        stats.adv_std = var.adv_std
        stats.max_abs_logprob = torch.abs(var.mb.action_logits).max()

        if hasattr(var.action_distribution, 'summaries'):
            stats.update(var.action_distribution.summaries())

        if var.epoch == self.cfg.ppo_epochs - 1 and var.batch_num == len(
                var.minibatches) - 1:
            # we collect these stats only for the last PPO batch, or every time if we're only doing one batch, IMPALA-style
            ratio_mean = torch.abs(1.0 - var.ratio).mean().detach()
            ratio_min = var.ratio.min().detach()
            ratio_max = var.ratio.max().detach()
            # log.debug('Learner %d ratio mean min max %.4f %.4f %.4f', self.policy_id, ratio_mean.cpu().item(), ratio_min.cpu().item(), ratio_max.cpu().item())

            value_delta = torch.abs(var.values - var.old_values)
            value_delta_avg, value_delta_max = value_delta.mean(), value_delta.max()

            # calculate KL-divergence with the behaviour policy action distribution
            old_action_distribution = get_action_distribution(
                self.actor_critic.action_space,
                var.mb.action_logits,
            )
            kl_old = var.action_distribution.kl_divergence(
                old_action_distribution)
            kl_old_mean = kl_old.mean()

            stats.kl_divergence = kl_old_mean
            stats.value_delta = value_delta_avg
            stats.value_delta_max = value_delta_max
            stats.fraction_clipped = (
                (var.ratio < var.clip_ratio_low).float() +
                (var.ratio > var.clip_ratio_high).float()).mean()
            stats.ratio_mean = ratio_mean
            stats.ratio_min = ratio_min
            stats.ratio_max = ratio_max
            stats.num_sgd_steps = var.num_sgd_steps

        # this caused numerical issues on some versions of PyTorch with second moment reaching infinity
        adam_max_second_moment = 0.0
        for key, tensor_state in self.optimizer.state.items():
            adam_max_second_moment = max(
                tensor_state['exp_avg_sq'].max().item(),
                adam_max_second_moment)
        stats.adam_max_second_moment = adam_max_second_moment

        version_diff = var.curr_policy_version - var.mb.policy_version
        stats.version_diff_avg = version_diff.mean()
        stats.version_diff_min = version_diff.min()
        stats.version_diff_max = version_diff.max()

        for key, value in stats.items():
            stats[key] = to_scalar(value)

        return stats

    def _update_pbt(self):
        """To be called from the training loop, same thread that updates the model!"""
        with self.pbt_mutex:
            if self.load_policy_id is not None:
                assert self.cfg.with_pbt

                log.debug('Learner %d loads policy from %d', self.policy_id,
                          self.load_policy_id)
                self.load_from_checkpoint(self.load_policy_id)
                self.load_policy_id = None

            if self.new_cfg is not None:
                for key, value in self.new_cfg.items():
                    if self.cfg[key] != value:
                        log.debug(
                            'Learner %d replacing cfg parameter %r with new value %r',
                            self.policy_id, key, value)
                        self.cfg[key] = value

                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.cfg.learning_rate
                    param_group['betas'] = (self.cfg.adam_beta1,
                                            self.cfg.adam_beta2)
                    log.debug('Updated optimizer lr to value %.7f, betas: %r',
                              param_group['lr'], param_group['betas'])

                self.new_cfg = None

    @staticmethod
    def load_checkpoint(checkpoints, device):
        if len(checkpoints) <= 0:
            log.warning('No checkpoints found')
            return None
        else:
            latest_checkpoint = checkpoints[-1]

            # extra safety mechanism to recover from spurious filesystem errors
            num_attempts = 3
            for attempt in range(num_attempts):
                try:
                    log.warning('Loading state from checkpoint %s...',
                                latest_checkpoint)
                    checkpoint_dict = torch.load(latest_checkpoint,
                                                 map_location=device)
                    return checkpoint_dict
                except Exception:
                    log.exception(f'Could not load from checkpoint, attempt {attempt}')

            # all attempts failed; the caller falls back to starting from scratch
            return None

    def _load_state(self, checkpoint_dict, load_progress=True):
        if load_progress:
            self.train_step = checkpoint_dict['train_step']
            self.env_steps = checkpoint_dict['env_steps']
        self.actor_critic.load_state_dict(checkpoint_dict['model'])
        self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
        log.info(
            'Loaded experiment state at training iteration %d, env step %d',
            self.train_step, self.env_steps)

    def init_model(self, timing):
        self.actor_critic = create_actor_critic(self.cfg, self.obs_space,
                                                self.action_space, timing)
        self.actor_critic.model_to_device(self.device)
        self.actor_critic.share_memory()

    def load_from_checkpoint(self, policy_id):
        checkpoints = self.get_checkpoints(
            self.checkpoint_dir(self.cfg, policy_id))
        checkpoint_dict = self.load_checkpoint(checkpoints, self.device)
        if checkpoint_dict is None:
            log.debug('Did not load from checkpoint, starting from scratch!')
        else:
            log.debug('Loading model from checkpoint')

            # if we're replacing our policy with another policy (under PBT), let's not reload the env_steps
            load_progress = policy_id == self.policy_id
            self._load_state(checkpoint_dict, load_progress=load_progress)

    def initialize(self, timing):
        with timing.timeit('init'):
            # initialize the Torch modules
            if self.cfg.seed is None:
                log.info('Starting seed is not provided')
            else:
                log.info('Setting fixed seed %d', self.cfg.seed)
                torch.manual_seed(self.cfg.seed)
                np.random.seed(self.cfg.seed)

            # this does not help with a single experiment
            # but seems to do better when we're running more than one experiment in parallel
            torch.set_num_threads(1)

            if self.cfg.device == 'gpu':
                torch.backends.cudnn.benchmark = True

                # we should already see only one CUDA device, because of env vars
                assert torch.cuda.device_count() == 1
                self.device = torch.device('cuda', index=0)
            else:
                self.device = torch.device('cpu')
            self.init_model(timing)

            self.optimizer = torch.optim.Adam(
                self.actor_critic.parameters(),
                self.cfg.learning_rate,
                betas=(self.cfg.adam_beta1, self.cfg.adam_beta2),
                eps=self.cfg.adam_eps,
            )

            self.load_from_checkpoint(self.policy_id)

            self._broadcast_model_weights()  # sync the very first version of the weights

        self.train_thread_initialized.set()

    def _process_training_data(self, data, timing, wait_stats=None):
        self.is_training = True

        buffer, batch_size, samples, env_steps = data
        assert samples == batch_size * self.cfg.num_batches_per_iteration

        self.env_steps += env_steps
        experience_size = buffer.rewards.shape[0]

        stats = dict(learner_env_steps=self.env_steps,
                     policy_id=self.policy_id)

        with timing.add_time('train'):
            discarding_rate = self._discarding_rate()

            self._update_pbt()

            train_stats = self._train(buffer, batch_size, experience_size,
                                      timing)

            if train_stats is not None:
                stats['train'] = train_stats

                if wait_stats is not None:
                    wait_avg, wait_min, wait_max = wait_stats
                    stats['train']['wait_avg'] = wait_avg
                    stats['train']['wait_min'] = wait_min
                    stats['train']['wait_max'] = wait_max

                stats['train']['discarded_rollouts'] = self.num_discarded_rollouts
                stats['train']['discarding_rate'] = discarding_rate

                stats['stats'] = memory_stats('learner', self.device)

        self.is_training = False

        try:
            self.report_queue.put(stats)
        except Full:
            log.warning(
                'Could not report training stats, the report queue is full!')

    def _train_loop(self):
        timing = Timing()
        self.initialize(timing)

        wait_times = deque([], maxlen=self.cfg.num_workers)
        last_cache_cleanup = time.time()
        num_batches_processed = 0

        while not self.terminate:
            with timing.timeit('train_wait'):
                data = safe_get(self.experience_buffer_queue)

            if self.terminate:
                break

            wait_stats = None
            wait_times.append(timing.train_wait)

            if len(wait_times) >= wait_times.maxlen:
                wait_times_arr = np.asarray(wait_times)
                wait_avg = np.mean(wait_times_arr)
                wait_min, wait_max = wait_times_arr.min(), wait_times_arr.max()
                # log.debug(
                #     'Training thread had to wait %.5f s for the new experience buffer (avg %.5f)',
                #     timing.train_wait, wait_avg,
                # )
                wait_stats = (wait_avg, wait_min, wait_max)

            self._process_training_data(data, timing, wait_stats)
            num_batches_processed += 1

            if time.time() - last_cache_cleanup > 300.0 or (
                    not self.cfg.benchmark and num_batches_processed < 50):
                if self.cfg.device == 'gpu':
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
                last_cache_cleanup = time.time()

        time.sleep(0.3)
        log.info('Train loop timing: %s', timing)
        del self.actor_critic
        del self.device

    def _experience_collection_rate_stats(self):
        now = time.time()
        if now - self.discarded_experience_timer > 1.0:
            self.discarded_experience_timer = now
            self.discarded_experience_over_time.append(
                (now, self.num_discarded_rollouts))

    def _discarding_rate(self):
        if len(self.discarded_experience_over_time) <= 1:
            return 0

        first = self.discarded_experience_over_time[0]
        last = self.discarded_experience_over_time[-1]
        delta_rollouts = last[1] - first[1]
        delta_time = last[0] - first[0]
        discarding_rate = delta_rollouts / (delta_time + EPS)
        return discarding_rate

    def _extract_rollouts(self, data):
        data = AttrDict(data)
        worker_idx, split_idx, traj_buffer_idx = data.worker_idx, data.split_idx, data.traj_buffer_idx

        rollouts = []
        for rollout_data in data.rollouts:
            env_idx, agent_idx = rollout_data['env_idx'], rollout_data['agent_idx']
            tensors = self.rollout_tensors.index(
                (worker_idx, split_idx, env_idx, agent_idx, traj_buffer_idx))

            rollout_data['t'] = tensors
            rollout_data['worker_idx'] = worker_idx
            rollout_data['split_idx'] = split_idx
            rollout_data['traj_buffer_idx'] = traj_buffer_idx
            rollouts.append(AttrDict(rollout_data))

        return rollouts

    def _process_pbt_task(self, pbt_task):
        task_type, data = pbt_task

        with self.pbt_mutex:
            if task_type == PbtTask.SAVE_MODEL:
                policy_id = data
                assert policy_id == self.policy_id
                self.should_save_model = True
            elif task_type == PbtTask.LOAD_MODEL:
                policy_id, new_policy_id = data
                assert policy_id == self.policy_id
                assert new_policy_id is not None
                self.load_policy_id = new_policy_id
            elif task_type == PbtTask.UPDATE_CFG:
                policy_id, new_cfg = data
                assert policy_id == self.policy_id
                self.new_cfg = new_cfg

    def _accumulated_too_much_experience(self, rollouts):
        max_minibatches_to_accumulate = self.cfg.num_minibatches_to_accumulate
        if max_minibatches_to_accumulate == -1:
            # default value
            max_minibatches_to_accumulate = 2 * self.cfg.num_batches_per_iteration

        # allow the max batches to accumulate, plus the minibatches we're currently training on
        max_minibatches_on_learner = max_minibatches_to_accumulate + self.cfg.num_batches_per_iteration

        minibatches_currently_training = int(self.is_training) * self.cfg.num_batches_per_iteration

        rollouts_per_minibatch = self.cfg.batch_size / self.cfg.rollout

        # count contribution from unprocessed rollouts
        minibatches_currently_accumulated = len(rollouts) / rollouts_per_minibatch

        # count minibatches ready for training
        minibatches_currently_accumulated += (
            self.experience_buffer_queue.qsize() * self.cfg.num_batches_per_iteration)

        total_minibatches_on_learner = minibatches_currently_training + minibatches_currently_accumulated

        return total_minibatches_on_learner >= max_minibatches_on_learner

    def _run(self):
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        try:
            psutil.Process().nice(self.cfg.default_niceness)
        except psutil.AccessDenied:
            log.error('Low niceness requires sudo!')

        if self.cfg.device == 'gpu':
            cuda_envvars(self.policy_id)

        torch.multiprocessing.set_sharing_strategy('file_system')
        torch.set_num_threads(self.cfg.learner_main_loop_num_cores)

        timing = Timing()

        rollouts = []

        if self.train_in_background:
            self.training_thread.start()
        else:
            self.initialize(timing)
            log.error(
                'train_in_background set to False on learner %d! This is slow, use only for testing!',
                self.policy_id,
            )

        while not self.terminate:
            while True:
                try:
                    tasks = self.task_queue.get_many(timeout=0.005)

                    for task_type, data in tasks:
                        if task_type == TaskType.TRAIN:
                            with timing.add_time('extract'):
                                rollouts.extend(self._extract_rollouts(data))
                                # log.debug('Learner %d has %d rollouts', self.policy_id, len(rollouts))
                        elif task_type == TaskType.INIT:
                            self._init()
                        elif task_type == TaskType.TERMINATE:
                            time.sleep(0.3)
                            log.info('GPU learner timing: %s', timing)
                            self._terminate()
                            break
                        elif task_type == TaskType.PBT:
                            self._process_pbt_task(data)
                except Empty:
                    break

            if self._accumulated_too_much_experience(rollouts):
                # if we accumulated too much experience, signal the policy workers to stop experience collection
                if not self.stop_experience_collection[self.policy_id]:
                    log.debug(
                        'Learner %d accumulated too much experience, stop experience collection!',
                        self.policy_id)
                self.stop_experience_collection[self.policy_id] = True
            elif self.stop_experience_collection[self.policy_id]:
                # otherwise, resume the experience collection if it was stopped
                self.stop_experience_collection[self.policy_id] = False
                with self.resume_experience_collection_cv:
                    log.debug('Learner %d is resuming experience collection!',
                              self.policy_id)
                    self.resume_experience_collection_cv.notify_all()

            with torch.no_grad():
                rollouts = self._process_rollouts(rollouts, timing)

            if not self.train_in_background:
                while not self.experience_buffer_queue.empty():
                    training_data = self.experience_buffer_queue.get()
                    self._process_training_data(training_data, timing)

            self._experience_collection_rate_stats()

        if self.train_in_background:
            self.experience_buffer_queue.put(None)
            self.training_thread.join()

    def init(self):
        self.task_queue.put((TaskType.INIT, None))
        self.initialized_event.wait()

    def save_model(self, timeout=None):
        self.model_saved_event.clear()
        save_task = (PbtTask.SAVE_MODEL, self.policy_id)
        self.task_queue.put((TaskType.PBT, save_task))
        log.debug('Wait while learner %d saves the model...', self.policy_id)
        if self.model_saved_event.wait(timeout=timeout):
            log.debug('Learner %d saved the model!', self.policy_id)
        else:
            log.warning('Model saving request timed out!')
        self.model_saved_event.clear()

    def close(self):
        self.task_queue.put((TaskType.TERMINATE, None))

    def join(self):
        join_or_kill(self.process)
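For reference, the _policy_loss and _value_loss helpers called in the training loop above are not part of this excerpt. Below is a minimal sketch of what such PPO-style clipped losses typically compute; the signatures mirror the calls above, but the bodies are assumptions rather than the project's actual implementation.

import torch


def _policy_loss(ratio, adv, clip_ratio_low, clip_ratio_high):
    # PPO clipped surrogate: take the element-wise worst case of the raw and
    # clipped importance-weighted advantages, negated because we minimize
    clipped_ratio = torch.clamp(ratio, clip_ratio_low, clip_ratio_high)
    return -torch.min(ratio * adv, clipped_ratio * adv).mean()


def _value_loss(values, old_values, targets, clip_value):
    # clip the value update around the old prediction and take the element-wise
    # worse of the clipped/unclipped squared errors (PPO2-style value clipping)
    values_clipped = old_values + torch.clamp(values - old_values, -clip_value, clip_value)
    loss_unclipped = (values - targets).pow(2)
    loss_clipped = (values_clipped - targets).pow(2)
    return 0.5 * torch.max(loss_unclipped, loss_clipped).mean()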
Example #13
class BaseComponent:
    def __init__(self,
                 endpoint="tcp://0.0.0.0:4242",
                 name="",
                 metrics_collector=NullCollector(),
                 *args,
                 **kwargs):
        """
        Args:
            endpoint: the endpoint that the component's zerorpc server will
                listen on.
            *args: TBD
            **kwargs: TBD
        """
        super().__init__()
        self.name = name
        self.metrics_collector = metrics_collector
        self.stop_event = Event()
        self.endpoint = endpoint
        self._routines = []
        self.zrpc = zerorpc.Server(self)
        self.zrpc.bind(endpoint)

    def _start(self):
        """
        Goes over the component's routines registered in self._routines and
        starts running them.
        """
        for routine in self._routines:
            routine.start()

    def run(self):
        """
        Starts running all the component's routines and the zerorpc server.
        """
        self._start()
        gevent.signal(signal.SIGTERM, self.stop_run)
        self.metrics_collector.setup()
        self.zrpc.run()
        self.zrpc.close()

    def register_routine(self, routine: Union[Routine, Process, Thread]):
        """
        Registers routine to the list of component's routines
        Args:
            routine: the routine to register
        """
        # TODO - write this function in a cleaner way?
        if isinstance(routine, Routine):
            if routine.stop_event is None:
                routine.stop_event = self.stop_event
            else:
                raise RegisteredException("routine is already registered")
        self._routines.append(routine)

    def _teardown_callback(self, *args, **kwargs):
        """
        Implemented by subclasses of BaseComponent. Used for stopping or
        tearing down things that are not stopped by setting the stop_event.
        Returns: None
        """
        pass

    def stop_run(self):
        """
        Signals all the component's routines to stop and then stops the zerorpc
        server.
        """
        try:
            self.zrpc.stop()
            self.stop_event.set()
            self._teardown_callback()
            for routine in self._routines:
                if isinstance(routine, Routine):
                    routine.runner.join()
                elif isinstance(routine, (Process, Thread)):
                    routine.join()
            return 0
        except RuntimeError:
            return 1
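A hedged usage sketch for the zerorpc-based component above. register_routine assigns the component's shared stop Event to any Routine whose stop_event is still unset, while plain Process/Thread objects are accepted as-is; the worker below is just a placeholder.

from threading import Thread

component = BaseComponent(endpoint="tcp://0.0.0.0:4242", name="demo")

# a plain Thread is accepted as-is; a Routine subclass would additionally pick
# up the component's shared stop_event automatically
worker = Thread(target=lambda: None)
component.register_routine(worker)

component.run()   # blocks in the zerorpc event loop until stop_run() is invoked
# stop_run() (exposed over zerorpc, or wired to SIGTERM) sets stop_event,
# joins every registered routine and stops the server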
Example #14
class Sink(Process):
    def __init__(self, port_out, front_sink_addr, verbose=False):
        super().__init__()
        self.port = port_out
        self.exit_flag = Event()
        self.logger = set_logger(colored('SINK', 'green'), verbose)
        self.front_sink_addr = front_sink_addr
        self.is_ready = Event()
        self.verbose = verbose

    def close(self):
        self.logger.info('shutting down...')
        self.is_ready.clear()
        self.exit_flag.set()
        self.terminate()
        self.join()
        self.logger.info('terminated!')

    def run(self):
        self._run()

    @zmqd.socket(zmq.PULL)
    @zmqd.socket(zmq.PAIR)
    @zmqd.socket(zmq.PUB)
    def _run(self, receiver, frontend, sender):
        receiver_addr = auto_bind(receiver)
        frontend.connect(self.front_sink_addr)
        sender.bind('tcp://*:%d' % self.port)

        pending_jobs: Dict[str, SinkJob] = defaultdict(lambda: SinkJob())

        poller = zmq.Poller()
        poller.register(frontend, zmq.POLLIN)
        poller.register(receiver, zmq.POLLIN)

        # send worker receiver address back to frontend
        frontend.send(receiver_addr.encode('ascii'))

        # Windows does not support the logger in an MP environment, thus get a new logger
        # inside the process for better compatibility
        logger = set_logger(colored('SINK', 'green'), self.verbose)
        logger.info('ready')
        self.is_ready.set()

        while not self.exit_flag.is_set():
            socks = dict(poller.poll())
            if socks.get(receiver) == zmq.POLLIN:
                msg = receiver.recv_multipart()
                job_id = msg[0]
                # parsing job_id and partial_id
                job_info = job_id.split(b'@')
                job_id = job_info[0]
                partial_id = int(job_info[1]) if len(job_info) == 2 else 0

                if msg[2] == ServerCmd.data_embed:
                    x = jsonapi.loads(msg[1])
                    pending_jobs[job_id].add_output(x, partial_id)
                else:
                    logger.error(
                        'received a wrongly-formatted request (expected 4 frames, got %d)' % len(msg))
                    logger.error('\n'.join('field %d: %s' % (idx, k)
                                           for idx, k in enumerate(msg)), exc_info=True)

                logger.info('collect %s %s (E:%d/A:%d)' % (msg[2], job_id,
                                                           pending_jobs[job_id].progress_outputs,
                                                           pending_jobs[job_id].checksum))

                # check if there are finished jobs, then send it back to workers

                finished = [(k, v)
                            for k, v in pending_jobs.items() if v.is_done]
                for job_info, tmp in finished:
                    client_addr, req_id = job_info.split(b'#')
                    x = tmp.result
                    sender.send_multipart([client_addr, x, req_id])
                    logger.info('send back\tsize: %d\tjob id: %s' %
                                (tmp.checksum, job_info))
                    # release the job
                    tmp.clear()
                    pending_jobs.pop(job_info)

            if socks.get(frontend) == zmq.POLLIN:
                client_addr, msg_type, msg_info, req_id = frontend.recv_multipart()
                if msg_type == ServerCmd.new_job:
                    job_info = client_addr + b'#' + req_id
                    # register a new job
                    pending_jobs[job_info].checksum = int(msg_info)
                    logger.info('job register\tsize: %d\tjob id: %s' %
                                (int(msg_info), job_info))
                elif msg_type == ServerCmd.show_config:
                    # dirty fix of slow-joiner: sleep so that client receiver can connect.
                    time.sleep(0.1)
                    logger.info('send config\tclient %s' % client_addr)
                    sender.send_multipart([client_addr, msg_info, req_id])
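A hedged lifecycle sketch for the Sink process above; the port and frontend address are placeholder values, and the actual message flow depends on the surrounding server.

sink = Sink(port_out=5556, front_sink_addr='tcp://127.0.0.1:5557', verbose=True)
sink.start()           # Process.start(): runs _run() with its own ZMQ sockets
sink.is_ready.wait()   # block until the child process signals it is serving
# ... workers push results to the receiver socket, clients subscribe on port_out ...
sink.close()           # clears is_ready, sets exit_flag, terminates and joins the child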
Example #15
class BaseComponent:

    def __init__(self, component_config, start_component=False):
        self.name = ""
        self.ROUTINES_FOLDER_PATH = "pipert/contrib/routines"
        self.MONITORING_SYSTEMS_FOLDER_PATH = "pipert/contrib/metrics_collectors"
        self.use_memory = False
        self.stop_event = Event()
        self.stop_event.set()
        self.queues = {}
        self._routines = {}
        self.metrics_collector = NullCollector()
        self.parent_logger = None
        self.logger = None
        self.setup_component(component_config)
        self.metrics_collector.setup()
        if start_component:
            self.run_comp()

    def setup_component(self, component_config):
        if (component_config is None) or (type(component_config) is not dict) or\
                (component_config == {}):
            return
        component_name, component_parameters = list(component_config.items())[0]
        self.name = component_name

        self.parent_logger = create_parent_logger(self.name)
        self.logger = self.parent_logger.getChild(self.name)

        if ("shared_memory" in component_parameters) and \
                (component_parameters["shared_memory"]):
            self.use_memory = True
            self.generator = smGen(self.name)

        if "monitoring_system" in component_parameters:
            self.set_monitoring_system(component_parameters["monitoring_system"])

        for queue in component_parameters["queues"]:
            self.create_queue(queue_name=queue, queue_size=1)

        routine_factory = ClassFactory(self.ROUTINES_FOLDER_PATH)
        for routine_name, routine_parameters_real in component_parameters["routines"].items():
            routine_parameters = routine_parameters_real.copy()
            routine_parameters["name"] = routine_name
            routine_parameters['metrics_collector'] = self.metrics_collector
            routine_parameters["logger"] = self.parent_logger.getChild(routine_name)
            routine_class = routine_factory.get_class(routine_parameters.pop("routine_type_name", ""))
            if routine_class is None:
                continue
            try:
                self._replace_queue_names_with_queue_objects(routine_parameters)
            except QueueDoesNotExist:
                # skip routines that reference a queue which was not created
                continue

            routine_parameters["component_name"] = self.name

            self.register_routine(routine_class(**routine_parameters).as_thread())

    def _replace_queue_names_with_queue_objects(self, routine_parameters_kwargs):
        for key, value in routine_parameters_kwargs.items():
            if 'queue' in key.lower():
                routine_parameters_kwargs[key] = self.get_queue(queue_name=value)

    def _start(self):
        """
        Goes over the component's routines registered in self._routines and
        starts running them.
        """
        self.logger.info("Running all routines")
        for routine in self._routines.values():
            routine.start()
            self.logger.info("{0} Started".format(routine.name))

    def run_comp(self):
        """
        Starts running all the component's routines.
        """
        self.logger.info("Running component")
        self.stop_event.clear()
        if self.use_memory and sys.version_info.minor < 8:
            self.generator.create_memories()
        self._start()
        gevent.signal_handler(signal.SIGTERM, self.stop_run)

    def register_routine(self, routine: Union[Routine, Process, Thread]):
        """
        Registers routine to the list of component's routines
        Args:
            routine: the routine to register
        """
        self.logger.info("Registering routine")
        self.logger.info(routine)
        # TODO - write this function in a cleaner way?
        if isinstance(routine, Routine):
            if routine.name in self._routines:
                self.logger.error("Routine name already exist")
                raise RegisteredException("routine name already exist")
            if routine.stop_event is None:
                routine.stop_event = self.stop_event
                if self.use_memory:
                    routine.use_memory = self.use_memory
                    routine.generator = self.generator
            else:
                self.logger.error("Routine is already registered")
                raise RegisteredException("routine is already registered")
            self.logger.info("Routine registered")
            self._routines[routine.name] = routine
        else:
            self.logger.info("Routine registered")
            self._routines[routine.__str__()] = routine

    def _teardown_callback(self, *args, **kwargs):
        """
        Implemented by subclasses of BaseComponent. Used for stopping or
        tearing down things that are not stopped by setting the stop_event.
        Returns: None
        """
        pass

    def stop_run(self):
        """
        Signals all the component's routines to stop.
        """
        self.logger.info("Stopping component")
        if self.stop_event.is_set():
            return 0
        self.stop_event.set()

        try:
            self._teardown_callback()
            if self.use_memory:
                self.logger.info("Cleaning shared memory")
                self.generator.cleanup()
            for routine in self._routines.values():
                self.logger.info("Stopping routine {0}".format(routine.name))
                if isinstance(routine, Routine):
                    routine.runner.join()
                elif isinstance(routine, (Process, Thread)):
                    routine.join()
                self.logger.info("Routine {0} stopped".format(routine.name))
            return 0
        except RuntimeError:
            return 1

    def create_queue(self, queue_name, queue_size=1):
        """
           Create a new queue for the component.
           Returns True if created, or False otherwise
           Args:
               queue_name: the name of the queue, must be unique
               queue_size: the size of the queue
        """
        if queue_name in self.queues:
            return False
        self.queues[queue_name] = Queue(maxsize=queue_size)
        return True

    def get_queue(self, queue_name):
        """
           Returns the queue object by its name
           Args:
               queue_name: the name of the queue
           Raises:
               QueueDoesNotExist - if no queue has the given name
        """
        try:
            return self.queues[queue_name]
        except KeyError:
            raise QueueDoesNotExist(queue_name)

    def get_all_queue_names(self):
        """
           Returns the list of names of the queues that
           the component exposes.
        """
        return list(self.queues.keys())

    def does_queue_exist(self, queue_name):
        """
           Returns True if the component has a queue named
           queue_name, or False otherwise
           Args:
               queue_name: the name of the queue to check
        """
        return queue_name in self.queues

    def delete_queue(self, queue_name):
        """
           Deletes the queue named queue_name.
           Returns True if the queue was deleted, or False if a routine
           still uses it.
           Args:
               queue_name: the name of the queue to delete
           Raises:
               QueueDoesNotExist - if no queue has the name queue_name
        """
        if queue_name not in self.queues:
            raise QueueDoesNotExist(queue_name)
        if self.does_routines_use_queue(queue_name=queue_name):
            return False
        try:
            del self.queues[queue_name]
            return True
        except KeyError:
            raise QueueDoesNotExist(queue_name)

    def does_routine_name_exist(self, routine_name):
        return routine_name in self._routines

    def remove_routine(self, routine_name):
        if self.does_routine_name_exist(routine_name):
            del self._routines[routine_name]
            return True
        else:
            return False

    def does_routines_use_queue(self, queue_name):
        for routine in self._routines.values():
            if routine.does_routine_use_queue(self.queues[queue_name]):
                return True
        return False

    def is_component_running(self):
        return not self.stop_event.is_set()

    def get_routines(self):
        return self._routines

    def get_component_configuration(self):
        component_dict = {
            "shared_memory": self.use_memory,
            "queues":
                list(self.get_all_queue_names()),
            "routines": {}
        }

        if type(self).__name__ != BaseComponent.__name__:
            component_dict["component_type_name"] = type(self).__name__
        for current_routine_object in self._routines.values():
            routine_creation_dict = \
                self._get_routine_creation(current_routine_object)
            routine_name = routine_creation_dict.pop("name")
            component_dict["routines"][routine_name] = \
                routine_creation_dict
        return {self.name: component_dict}

    def _get_routine_creation(self, routine):
        routine_dict = routine.get_creation_dictionary()
        routine_dict["routine_type_name"] = routine.__class__.__name__
        for routine_param_name in routine_dict.keys():
            if "queue" in routine_param_name:
                for queue_name in self.queues.keys():
                    if getattr(routine, routine_param_name) is \
                            self.queues[queue_name]:
                        routine_dict[routine_param_name] = queue_name

        return routine_dict

    def set_monitoring_system(self, monitoring_system_parameters):
        monitoring_system_factory = ClassFactory(self.MONITORING_SYSTEMS_FOLDER_PATH)
        if "name" not in monitoring_system_parameters:
            print("No name parameter found inside the monitoring system")
            return
        monitoring_system_name = monitoring_system_parameters.pop("name") + "Collector"
        monitoring_system_class = monitoring_system_factory.get_class(monitoring_system_name)
        if monitoring_system_class is None:
            return
        try:
            self.metrics_collector = monitoring_system_class(**monitoring_system_parameters)
        except TypeError:
            print("Bad parameters given for the monitoring system " + monitoring_system_name)

    def set_routine_attribute(self, routine_name, attribute_name, attribute_value):
        routine = self._routines.get(routine_name, None)
        if routine is not None:
            setattr(routine, attribute_name, attribute_value)
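A hedged sketch of the configuration dictionary that setup_component above expects; the routine type and queue names are placeholders, and the routine class has to be discoverable by the ClassFactory under pipert/contrib/routines.

config = {
    "my_component": {
        "queues": ["in_q", "out_q"],
        "routines": {
            "reader": {
                "routine_type_name": "SomeRoutine",  # placeholder routine class name
                "in_queue": "in_q",    # parameters containing 'queue' are resolved
                "out_queue": "out_q",  # to the corresponding Queue objects
            },
        },
    },
}

component = BaseComponent(config, start_component=True)  # builds queues/routines, then calls run_comp()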
Example #16
    def buildCache(self, limit):
        # print("Building cache: ",
        #       self.cache_names[self.current_cache_build.value]
        # )
        dataset = MultimodalPatchesDatasetAll(
            self.dataset_dir,
            self.dataset_list,
            rejection_radius_position=self.rejection_radius_position,
            #self.images_path, list=train_sampled,
            numpatches=self.numpatches,
            numneg=self.numneg,
            pos_thr=self.pos_thr,
            reject=self.reject,
            mode=self.mode,
            rejection_radius=self.rejection_radius,
            dist_type=self.dist_type,
            patch_radius=self.patch_radius,
            use_depth=self.use_depth,
            use_normals=self.use_normals,
            use_silhouettes=self.use_silhouettes,
            color_jitter=self.color_jitter,
            greyscale=self.greyscale,
            maxres=self.maxres,
            scale_jitter=self.scale_jitter,
            photo_jitter=self.photo_jitter,
            uniform_negatives=self.uniform_negatives,
            needles=self.needles,
            render_only=self.render_only)
        n_triplets = len(dataset)

        if limit == -1:
            limit = n_triplets

        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=1,  # self.num_workers
            collate_fn=MultimodalPatchesCache.my_collate)

        qmaxsize = 15
        data_queue = JoinableQueue(maxsize=qmaxsize)

        # cannot load to cuda from background, therefore use cpu device
        preloader_resume = Event()
        preloader = Process(target=MultimodalPatchesCache.generateTrainingData,
                            args=(data_queue, dataset, dataloader,
                                  self.batch_size, qmaxsize, preloader_resume,
                                  True, True))
        preloader.do_run_generate = True
        preloader.start()
        preloader_resume.set()

        i_batch = 0
        data = data_queue.get()
        i_batch = data[0]

        counter = 0
        while i_batch != -1:

            self.cache_builder_resume.wait()

            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            batch_fname = os.path.join(build_dataset_dir,
                                       'batch_' + str(counter) + '.pt')

            # print("ibatch", i_batch,
            #        "___data___", data[3].shape, data[6].shape)

            anchor = data[1]
            pos = data[2]
            neg = data[3]
            anchor_r = data[4]
            pos_p = data[5]
            neg_p = data[6]
            c1 = data[7]
            c2 = data[8]
            cneg = data[9]
            id = data[10]

            if not (self.use_depth or self.use_normals):
                # no need to store the image data as float, convert to uint8
                anchor = (anchor * 255.0).to(torch.uint8)
                pos = (pos * 255.0).to(torch.uint8)
                neg = (neg * 255.0).to(torch.uint8)
                anchor_r = (anchor_r * 255.0).to(torch.uint8)
                pos_p = (pos_p * 255.0).to(torch.uint8)
                neg_p = (neg_p * 255.0).to(torch.uint8)

            tosave = {
                'anchor': anchor,
                'pos': pos,
                'neg': neg,
                'anchor_r': anchor_r,
                'pos_p': pos_p,
                'neg_p': neg_p,
                'c1': c1,
                'c2': c2,
                'cneg': cneg,
                'id': id
            }

            try:
                torch.save(tosave, batch_fname)
                torch.load(batch_fname)
                counter += 1
            except Exception as e:
                print("Could not save ",
                      batch_fname,
                      ", due to:",
                      e,
                      "skipping...",
                      file=sys.stderr)
                if os.path.isfile(batch_fname):
                    os.remove(batch_fname)

            data_queue.task_done()

            if counter >= limit:
                self.cache_done_lock.acquire()
                self.cache_done.value = 1  # 1 is True
                self.cache_done_lock.release()
                counter = 0
                # sleep until calling thread wakes us
                self.cache_builder_resume.clear()
                # resume calling thread so that it can work
                self.wait_for_cache_builder.set()

            data = data_queue.get()
            i_batch = data[0]
            #print("ibatch", i_batch)

        data_queue.task_done()

        self.cache_done_lock.acquire()
        self.cache_done.value = 1  # 1 is True
        self.all_done.value = 1
        print("Cache done ALL")
        self.cache_done_lock.release()
        # resume calling thread so that it can work
        self.wait_for_cache_builder.set()
        preloader.join()
        preloader = None
        data_queue = None
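MultimodalPatchesCache.generateTrainingData is not shown in this excerpt. Below is a minimal sketch of the producer-side protocol it has to satisfy for the consumer loop above: per-batch tuples whose first element is the batch index, a final sentinel with index -1, and honoring the resume Event. The real implementation will differ in its details.

    @staticmethod
    def generateTrainingData(queue, dataset, dataloader, batch_size, qmaxsize,
                             resume_event, *flags):
        for i_batch, batch in enumerate(dataloader):
            resume_event.wait()           # pause while the consumer keeps the event cleared
            queue.put((i_batch, *batch))  # consumer indexes the fields as data[1]..data[10]
        queue.put((-1,))                  # sentinel: tells buildCache to finish
        queue.join()                      # wait until task_done() was called for every item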
Example #17
class DataQueue(object):
    '''Queue for data prefetching.
       DataQueue launches a subprocess to avoid Python's GIL.
       # Arguments
            generator: instance of a generator which feeds data infinitely
            max_queue_size: maximum queue size
            nb_worker: controls concurrency;
                       only takes effect during preprocessing
    '''
    def __init__(self, generator, max_queue_size=5, nb_worker=1):
        self.generator = generator
        self.nb_worker = nb_worker
        self.max_queue_size = max_queue_size

        self._queue = Queue()
        self._signal = Event()
        self._available_cv = Condition()
        self._full_cv = Condition()

        args = (generator, self._queue, self._signal, self._available_cv,
                self._full_cv, self.nb_worker, self.max_queue_size)
        self.working_process = Process(target=self.generator_process,
                                       args=args)
        self.working_process.daemon = True
        self.working_process.start()

    def get(self, timeout=None):
        with self._available_cv:
            if not self._signal.is_set() and self._queue.qsize() == 0:
                self._available_cv.wait()

        if self._signal.is_set():
            raise Exception("prefetch process terminated!")

        try:
            data = self._queue.get()
            with self._full_cv:
                self._full_cv.notify()
        except Exception as e:
            with self._full_cv:
                self._signal.set()
                self._full_cv.notify_all()
                raise e

        return data

    def qsize(self):
        return self._queue.qsize()

    def __del__(self):
        with self._full_cv:
            self._signal.set()
            self._full_cv.notify_all()
        #self.working_process.terminate()
        self.working_process.join()

    @staticmethod
    def generator_process(generator, queue, signal, available_cv, full_cv,
                          nb_worker, max_qsize):
        preprocess = generator.preprocess
        generator = BackgroundGenerator(generator())  # invoke call()

        # put data in the queue
        def enqueue_fn(generator, preprocess, queue, signal, available_cv,
                       full_cv, lock, max_qsize):
            while True:
                try:
                    with lock:
                        data = next(generator)
                    data = preprocess(data)

                    if not isinstance(data, types.GeneratorType):
                        data = [data]

                    for ele in data:
                        ele = np2tensor(ele)  # numpy array to pytorch's tensor
                        with full_cv:
                            while not signal.is_set() and queue.qsize() >= max_qsize:
                                full_cv.wait()

                        if signal.is_set(): return

                        queue.put(ele)

                        with available_cv:
                            available_cv.notify()
                except Exception as e:
                    print("Error Message", e, file=sys.stderr)
                    with full_cv:
                        signal.set()
                        full_cv.notify_all()
                    with available_cv:
                        signal.set()
                        available_cv.notify_all()
                    raise Exception("generator thread went wrong!")

        # start threads
        lock = threading.Lock()
        args = (generator, preprocess, queue, signal, available_cv, full_cv,
                lock, max_qsize)
        generator_threads = [
            threading.Thread(target=enqueue_fn, args=args)
            for _ in range(nb_worker)
        ]

        for thread in generator_threads:
            thread.daemon = True
            thread.start()

        for thread in generator_threads:
            thread.join()
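A hedged usage sketch for DataQueue: the generator object must be callable (its __call__ returns an endless iterator, per the docstring) and must expose a preprocess method, exactly as used by generator_process above; the batch contents here are made up.

import numpy as np


class RandomBatches:
    def preprocess(self, batch):
        # runs inside the worker threads of the prefetching subprocess
        return batch / 255.0

    def __call__(self):
        while True:  # must feed data infinitely
            yield np.random.randint(0, 256, size=(32, 3, 32, 32)).astype(np.float32)


data_queue = DataQueue(RandomBatches(), max_queue_size=5, nb_worker=2)
for _ in range(100):
    batch = data_queue.get()  # np2tensor has already converted it to a torch tensor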
Example #18
class DaliIteratorCPU(DaliIterator):
    """
    Wrapper class to decode the DALI iterator output & provide an iterator that functions the same as torchvision's
    Note that the permutation to channels-first, the conversion from 8-bit to float & normalization are all performed on the GPU

    pipelines (Pipeline): DALI pipelines
    size (int): Number of examples in set
    fp16 (bool): Use fp16 as output format, f32 otherwise
    mean (tuple): Image mean value for each channel
    std (tuple): Image standard deviation value for each channel
    pin_memory (bool): Transfer input tensor to pinned memory, before moving to GPU
    """
    def __init__(self,
                 fp16=False,
                 mean=(0., 0., 0.),
                 std=(1., 1., 1.),
                 pin_memory=True,
                 pca_jitter=False,
                 **kwargs):
        super().__init__(**kwargs)
        print('Using DALI CPU iterator')
        self.stream = torch.cuda.Stream()

        self.fp16 = fp16
        self.mean = torch.tensor(mean).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor(std).cuda().view(1, 3, 1, 1)
        self.pin_memory = pin_memory
        self.pca_jitter = pca_jitter

        if self.fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()

        self.proc_next_input = Event()
        self.done_event = Event()
        self.output_queue = queue.Queue(maxsize=5)
        self.preproc_thread = threading.Thread(
            target=_preproc_worker,
            kwargs={
                'dali_iterator': self._dali_iterator,
                'cuda_stream': self.stream,
                'fp16': self.fp16,
                'mean': self.mean,
                'std': self.std,
                'proc_next_input': self.proc_next_input,
                'done_event': self.done_event,
                'output_queue': self.output_queue,
                'pin_memory': self.pin_memory,
                'pca_jitter': self.pca_jitter
            })
        self.preproc_thread.daemon = True
        self.preproc_thread.start()

        self.proc_next_input.set()

    def __next__(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        data = self.output_queue.get()
        self.proc_next_input.set()
        if data is None:
            raise StopIteration
        return data

    def __del__(self):
        self.done_event.set()
        self.proc_next_input.set()
        torch.cuda.current_stream().wait_stream(self.stream)
        self.preproc_thread.join()
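A hedged usage sketch for the iterator above. The remaining constructor arguments (the DALI pipeline, dataset size, etc.) are forwarded to the DaliIterator base class, which is not shown here, and the mean/std values below are simply the usual ImageNet statistics scaled to the 0-255 range; both are assumptions.

train_loader = DaliIteratorCPU(
    pipelines=train_pipe,            # a DALI pipeline built elsewhere (not shown)
    size=train_set_size,
    fp16=False,
    mean=(0.485 * 255, 0.456 * 255, 0.406 * 255),
    std=(0.229 * 255, 0.224 * 255, 0.225 * 255),
    pin_memory=True,
)

# assuming the DaliIterator base class makes the object iterable;
# __next__ raises StopIteration at the end of the epoch
for input, target in train_loader:
    output = model(input)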
Example #19
class AsyncLogger(Logger):
    @staticmethod
    def log_fn(self, stop_event: Event):
        try:
            self._super_create_loggers()
            self.resposne_queue.put({
                k: self.__dict__[k]
                for k in ["save_dir", "tb_logdir", "is_sweep"]
            })

            while True:
                try:
                    cmd = self.draw_queue.get(True, 0.1)
                except EmptyQueue:
                    if stop_event.is_set():
                        break
                    else:
                        continue

                self._super_log(*cmd)
                self.resposne_queue.put(True)
        except:
            print("Logger process crashed.")
            raise
        finally:
            print("Logger: syncing")
            if self.use_wandb:
                wandb.join()

            stop_event.set()
            print("Logger process terminating...")

    def create_loggers(self):
        self._super_create_loggers = super().create_loggers
        self.stop_event = Event()
        self.proc = Process(target=self.log_fn, args=(self, self.stop_event))
        self.proc.start()

        atexit.register(self.finish)

    def __init__(self, *args, **kwargs):
        self.queue = []

        self.draw_queue = Queue()
        self.resposne_queue = Queue()
        self._super_log = super().log
        self.waiting = 0

        super().__init__(*args, **kwargs)

        self.__dict__.update(self.resposne_queue.get(True))

    def log(self, plotlist, step=None):
        if self.stop_event.is_set():
            return

        if not isinstance(plotlist, list):
            plotlist = [plotlist]

        plotlist = [p for p in plotlist if p]
        if not plotlist:
            return

        plotlist = U.apply_to_tensors(plotlist, lambda x: x.detach().cpu())
        self.queue.append((plotlist, step))
        self.flush(wait=False)

    def enqueue(self, data, step: Optional[int]):
        self.draw_queue.put((data, step))
        self.waiting += 1

    def wait_logger(self, wait=False):
        cond = (lambda: not self.resposne_queue.empty()) if not wait else (lambda: self.waiting > 0)
        already_printed = False
        while cond() and not self.stop_event.is_set():
            will_wait = self.resposne_queue.empty()
            if will_wait and not already_printed:
                already_printed = True
                sys.stdout.write("Warning: waiting for logger... ")
                sys.stdout.flush()
            try:
                self.resposne_queue.get(True, 0.2)
            except EmptyQueue:
                continue
            self.waiting -= 1

        if already_printed:
            print("done.")

    def flush(self, wait: bool = True):
        while self.queue:
            plotlist, step = self.queue[0]

            for i, p in enumerate(plotlist):
                if isinstance(p, PlotAsync):
                    res = p.get(wait)
                    if res is not None:
                        plotlist[i] = res
                    else:
                        if wait:
                            assert p.failed
                            # Exception in the worker thread
                            print(
                                "Exception detected in a PlotAsync object. Syncing logger and ignoring further plots."
                            )
                            self.wait_logger(True)
                            self.stop_event.set()
                            self.proc.join()

                        return

            self.queue.pop(0)
            self.enqueue(plotlist, step)

        self.wait_logger(wait)

    def finish(self):
        if self.stop_event.is_set():
            return

        self.flush(True)
        self.stop_event.set()
        self.proc.join()
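A hedged usage sketch for AsyncLogger. The base Logger constructor arguments are project-specific and not visible in this example, so the constructor call and the logged object below are assumptions; the point of the sketch is the Event/queue handshake: log() detaches tensors and enqueues work for the background process, flush(wait=True) blocks on the response queue, and finish() sets the stop Event and joins the process.

logger = AsyncLogger()                   # base Logger kwargs omitted (assumption)

logger.log(some_plot_object, step=1000)  # detached on the caller side, drawn in the child process
logger.flush(wait=True)                  # block until the child has processed the queue
logger.finish()                          # sets stop_event and joins; also registered via atexit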