Code example #1
    def _call_with_qiterable(self, qiterable: QIterable, num_epochs: int,
                             shuffle: bool) -> Iterator[TensorDict]:
        # JoinableQueue needed here as sharing tensors across processes
        # requires that the creating process not exit prematurely.
        output_queue = JoinableQueue(self.output_queue_size)

        for _ in range(num_epochs):
            qiterable.start()

            # Start the tensor-dict workers.
            for i in range(self.num_workers):
                args = (qiterable, output_queue, self.iterator, shuffle, i)
                process = Process(target=_create_tensor_dicts_from_qiterable,
                                  args=args)
                process.start()
                self.processes.append(process)

            num_finished = 0
            while num_finished < self.num_workers:
                item = output_queue.get()
                output_queue.task_done()
                if isinstance(item, int):
                    num_finished += 1
                    logger.info(
                        f"worker {item} finished ({num_finished} / {self.num_workers})"
                    )
                else:
                    yield item

            for process in self.processes:
                process.join()
            self.processes.clear()

            qiterable.join()
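
The worker target _create_tensor_dicts_from_qiterable is not part of this excerpt, but the consumer loop above implies a simple protocol: each worker puts tensor dicts on the JoinableQueue and finally puts its own integer index as a completion sentinel, which the consumer counts to know when all workers are done. The following is a hypothetical, self-contained sketch of that protocol with a toy producer standing in for the real AllenNLP worker; names like toy_worker and the payload contents are illustrative only.

from multiprocessing import JoinableQueue, Process

def toy_worker(output_queue: JoinableQueue, index: int) -> None:
    # Stand-in for _create_tensor_dicts_from_qiterable: emit a few payloads,
    # then put the worker's own index so the consumer can count finished workers.
    for i in range(3):
        output_queue.put({"worker": index, "item": i})
    output_queue.put(index)

def consume(num_workers: int = 2):
    output_queue = JoinableQueue()
    processes = [Process(target=toy_worker, args=(output_queue, i))
                 for i in range(num_workers)]
    for process in processes:
        process.start()

    num_finished = 0
    while num_finished < num_workers:
        item = output_queue.get()
        output_queue.task_done()   # mirrors the task_done() call in the excerpt
        if isinstance(item, int):
            num_finished += 1      # integer payloads are completion sentinels
        else:
            yield item

    for process in processes:
        process.join()

if __name__ == "__main__":
    for payload in consume():
        print(payload)

Because every get() is paired with task_done(), and each worker's sentinel is enqueued after its payloads, the consumer is guaranteed to have drained all payloads by the time it has counted every sentinel, so the final process.join() calls cannot deadlock.
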
Code example #2
def read_img(path_queue: multiprocessing.JoinableQueue,
             data_queue: multiprocessing.SimpleQueue):
    # Keep each worker single-threaded so parallelism comes from the
    # worker processes rather than intra-op threading.
    torch.set_num_threads(1)
    while True:
        # Block until an image path arrives, load and transform the image,
        # and hand the resulting tensor to the consumer.
        img_path = path_queue.get()
        img = Image.open(img_path)
        data_queue.put(T(img))
        # Acknowledge the path so path_queue.join() can eventually return.
        path_queue.task_done()
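
read_img above is only the worker loop; the driver that feeds it is not included in the excerpt. Below is a minimal sketch of such a driver, assuming that T is a module-level image-to-tensor transform (torchvision's ToTensor is used here purely as a placeholder for whatever the original module defines) and that four worker processes suffice; both choices are illustrative, not taken from the source. path_queue.join() blocks until task_done() has been called once for every path that was put.

import multiprocessing

import torch
from PIL import Image
from torchvision import transforms

T = transforms.ToTensor()  # assumed transform; the original module defines its own T

def load_images(image_paths):
    path_queue = multiprocessing.JoinableQueue()
    data_queue = multiprocessing.SimpleQueue()

    # Daemon workers loop forever in read_img and die with the parent process.
    workers = [multiprocessing.Process(target=read_img,
                                       args=(path_queue, data_queue),
                                       daemon=True)
               for _ in range(4)]
    for worker in workers:
        worker.start()

    for path in image_paths:
        path_queue.put(path)

    # Returns once every path has been acknowledged with task_done().
    path_queue.join()

    # Each task_done() happens after the corresponding put() on data_queue,
    # so all tensors are available at this point.
    return [data_queue.get() for _ in image_paths]
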
Code example #3
    def buildCache(self, limit):
        # print("Building cache: ",
        #       self.cache_names[self.current_cache_build.value]
        # )
        dataset = MultimodalPatchesDatasetAll(
            self.dataset_dir,
            self.dataset_list,
            rejection_radius_position=self.rejection_radius_position,
            #self.images_path, list=train_sampled,
            numpatches=self.numpatches,
            numneg=self.numneg,
            pos_thr=self.pos_thr,
            reject=self.reject,
            mode=self.mode,
            rejection_radius=self.rejection_radius,
            dist_type=self.dist_type,
            patch_radius=self.patch_radius,
            use_depth=self.use_depth,
            use_normals=self.use_normals,
            use_silhouettes=self.use_silhouettes,
            color_jitter=self.color_jitter,
            greyscale=self.greyscale,
            maxres=self.maxres,
            scale_jitter=self.scale_jitter,
            photo_jitter=self.photo_jitter,
            uniform_negatives=self.uniform_negatives,
            needles=self.needles,
            render_only=self.render_only)
        n_triplets = len(dataset)

        if limit == -1:
            limit = n_triplets

        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=1,  # self.num_workers
            collate_fn=MultimodalPatchesCache.my_collate)

        qmaxsize = 15
        data_queue = JoinableQueue(maxsize=qmaxsize)

        # Cannot load to CUDA from a background process, therefore use the CPU device.
        preloader_resume = Event()
        preloader = Process(target=MultimodalPatchesCache.generateTrainingData,
                            args=(data_queue, dataset, dataloader,
                                  self.batch_size, qmaxsize, preloader_resume,
                                  True, True))
        preloader.do_run_generate = True
        preloader.start()
        preloader_resume.set()

        i_batch = 0
        data = data_queue.get()
        i_batch = data[0]

        counter = 0
        while i_batch != -1:

            self.cache_builder_resume.wait()

            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            batch_fname = os.path.join(build_dataset_dir,
                                       'batch_' + str(counter) + '.pt')

            # print("ibatch", i_batch,
            #        "___data___", data[3].shape, data[6].shape)

            anchor = data[1]
            pos = data[2]
            neg = data[3]
            anchor_r = data[4]
            pos_p = data[5]
            neg_p = data[6]
            c1 = data[7]
            c2 = data[8]
            cneg = data[9]
            id = data[10]

            if not (self.use_depth or self.use_normals):
                # No need to store image data as float; convert to uint8.
                anchor = (anchor * 255.0).to(torch.uint8)
                pos = (pos * 255.0).to(torch.uint8)
                neg = (neg * 255.0).to(torch.uint8)
                anchor_r = (anchor_r * 255.0).to(torch.uint8)
                pos_p = (pos_p * 255.0).to(torch.uint8)
                neg_p = (neg_p * 255.0).to(torch.uint8)

            tosave = {
                'anchor': anchor,
                'pos': pos,
                'neg': neg,
                'anchor_r': anchor_r,
                'pos_p': pos_p,
                'neg_p': neg_p,
                'c1': c1,
                'c2': c2,
                'cneg': cneg,
                'id': id
            }

            try:
                torch.save(tosave, batch_fname)
                torch.load(batch_fname)
                counter += 1
            except Exception as e:
                print("Could not save ",
                      batch_fname,
                      ", due to:",
                      e,
                      "skipping...",
                      file=sys.stderr)
                if os.path.isfile(batch_fname):
                    os.remove(batch_fname)

            data_queue.task_done()

            if counter >= limit:
                self.cache_done_lock.acquire()
                self.cache_done.value = 1  # 1 is True
                self.cache_done_lock.release()
                counter = 0
                # sleep until calling thread wakes us
                self.cache_builder_resume.clear()
                # resume calling thread so that it can work
                self.wait_for_cache_builder.set()

            data = data_queue.get()
            i_batch = data[0]
            #print("ibatch", i_batch)

        data_queue.task_done()

        self.cache_done_lock.acquire()
        self.cache_done.value = 1  # 1 is True
        self.all_done.value = 1
        print("Cache done ALL")
        self.cache_done_lock.release()
        # resume calling thread so that it can work
        self.wait_for_cache_builder.set()
        preloader.join()
        preloader = None
        data_queue = None
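
MultimodalPatchesCache.generateTrainingData, which buildCache launches as the preloader process, is not included in this excerpt. The consumer loop above implies its contract: every queue item is a tuple whose first element is the batch index (the remaining elements are unpacked as data[1]..data[10]), a final item whose first element is -1 ends the loop, and because task_done() is called for every get(), the producer can finish with data_queue.join() to be sure the builder has consumed everything before it exits. The following is a hypothetical, stripped-down producer following that contract; the real method takes more arguments and builds the full ten-element batches.

from multiprocessing import Event, JoinableQueue

def toy_preloader(data_queue: JoinableQueue, dataloader, resume_event) -> None:
    # Hypothetical stand-in for MultimodalPatchesCache.generateTrainingData.
    resume_event.wait()  # run only once the cache builder has called set()
    for i_batch, batch in enumerate(dataloader):
        # Element 0 is the batch index; the rest mirrors the data[1]..data[10]
        # layout unpacked in buildCache() (placeholder contents here).
        data_queue.put((i_batch, *batch))
    # A first element of -1 tells buildCache() that there are no more batches.
    data_queue.put((-1,))
    # Wait until task_done() has been called for every item, so the shared
    # tensors stay alive in this process until the consumer is finished.
    data_queue.join()

Ending the producer with data_queue.join() matters for the same reason given in code example #1: tensors shared across processes require the producing process to stay alive until the consumer has taken them.
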