Example no. 1
    def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue, lock,
                      rx: Connection) -> None:
        Tqdm.set_lock(lock)
        try:
            self.reader._set_worker_info(
                WorkerInfo(self.num_workers, worker_id))
            instances = self.reader.read(self.data_path)
            for batch in self._instances_to_batches(
                    instances, move_to_device=self._worker_cuda_safe):
                if self._safe_queue_put(worker_id, (batch, None), queue, rx):
                    continue
                else:
                    # Couldn't put item on queue because parent process has exited.
                    return
        except Exception as e:
            if not self._safe_queue_put(
                    worker_id, (None, (repr(e), traceback.format_exc())),
                    queue, rx):
                return

        # Indicate to the consumer (main thread) that this worker is finished.
        queue.put((None, None))

        # Wait until this process can safely exit.
        queue.join()
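A note on the handshake this example relies on: each worker puts a (None, None)
sentinel when it has produced all of its batches, then blocks on queue.join()
until the consumer has called task_done() for every item, so the worker never
exits while the consumer may still be reading shared tensors off the queue.
A minimal self-contained sketch of that protocol (illustrative names, not part
of the original code):

import multiprocessing as mp

def worker(queue: mp.JoinableQueue) -> None:
    for item in range(3):
        queue.put((item, None))
    queue.put((None, None))  # sentinel: this worker is finished
    queue.join()             # block until the consumer has processed everything

if __name__ == "__main__":
    queue = mp.JoinableQueue()
    process = mp.Process(target=worker, args=(queue,))
    process.start()
    while True:
        item, error = queue.get()
        queue.task_done()  # acknowledge every item, including the sentinel
        if item is None and error is None:
            break
    process.join()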
Example no. 2
    def _call_with_qiterable(self, qiterable: QIterable, num_epochs: int,
                             shuffle: bool) -> Iterator[TensorDict]:
        # JoinableQueue needed here as sharing tensors across processes
        # requires that the creating process not exit prematurely.
        output_queue = JoinableQueue(self.output_queue_size)

        for _ in range(num_epochs):
            qiterable.start()

            # Start the tensor-dict workers.
            for i in range(self.num_workers):
                args = (qiterable, output_queue, self.iterator, shuffle, i)
                process = Process(target=_create_tensor_dicts_from_qiterable,
                                  args=args)
                process.start()
                self.processes.append(process)

            num_finished = 0
            while num_finished < self.num_workers:
                item = output_queue.get()
                output_queue.task_done()
                if isinstance(item, int):
                    num_finished += 1
                    logger.info(
                        f"worker {item} finished ({num_finished} / {self.num_workers})"
                    )
                else:
                    yield item

            for process in self.processes:
                process.join()
            self.processes.clear()

            qiterable.join()
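The worker target _create_tensor_dicts_from_qiterable is not shown on this
page. The consumer loop above implies its contract: put tensor dicts on
output_queue, and when finished put the worker's own integer index, which is
the int the consumer counts toward num_finished. A hypothetical minimal body
honoring that contract (the real AllenNLP function differs in detail):

def _create_tensor_dicts_from_qiterable(qiterable, output_queue, iterator,
                                        shuffle, index):
    # Sketch only: tensorize instances coming off the shared qiterable and
    # hand the resulting tensor dicts to the consumer.
    for tensor_dict in iterator(qiterable, num_epochs=1, shuffle=shuffle):
        output_queue.put(tensor_dict)
    # An int on the queue signals "worker `index` is finished".
    output_queue.put(index)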
Example no. 3
def read_img(path_queue: multiprocessing.JoinableQueue,
             data_queue: multiprocessing.SimpleQueue):
    torch.set_num_threads(1)
    while True:
        img_path = path_queue.get()
        img = Image.open(img_path)
        data_queue.put(T(img))
        path_queue.task_done()
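This worker never exits on its own, so the parent is expected to make it
daemonic and use path_queue.join() to learn when every image has been decoded.
A hypothetical driver for it (illustrative paths):

import multiprocessing

if __name__ == "__main__":
    path_queue = multiprocessing.JoinableQueue()
    data_queue = multiprocessing.SimpleQueue()
    worker = multiprocessing.Process(target=read_img,
                                     args=(path_queue, data_queue),
                                     daemon=True)
    worker.start()
    for path in ["img0.jpg", "img1.jpg"]:
        path_queue.put(path)
    path_queue.join()  # returns once task_done() has been called for each path
    while not data_queue.empty():  # rough check, fine for a sketch
        tensor = data_queue.get()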
Example no. 4
    def _gather_instances(self, queue: mp.JoinableQueue) -> Iterable[Instance]:
        done_count: int = 0
        while done_count < self.num_workers:
            for instance, worker_error in iter(queue.get, (None, None)):
                if worker_error is not None:
                    e, tb = worker_error
                    raise WorkerError(e, tb)

                self.reader.apply_token_indexers(instance)
                if self._vocab is not None:
                    instance.index_fields(self._vocab)
                yield instance
                queue.task_done()
            # Note: the (None, None) sentinel consumed by iter() is not marked
            # done here, so the caller must call task_done() once per worker
            # during cleanup for each worker's queue.join() to return.
            done_count += 1
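The consumer above leans on the two-argument form of iter():
iter(queue.get, (None, None)) calls queue.get() repeatedly and stops as soon
as a call returns the (None, None) sentinel, which is consumed but never
yielded. A standalone demonstration (a thread-safe queue.Queue keeps it short):

from queue import Queue

q = Queue()
q.put(("a", None))
q.put(("b", None))
q.put((None, None))  # sentinel

for item, error in iter(q.get, (None, None)):
    print(item)  # prints "a" then "b"; the loop ends at the sentinel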
Example no. 5
    def _instance_worker(self, worker_id: int, queue: mp.JoinableQueue, lock) -> None:
        Tqdm.set_lock(lock)
        try:
            self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
            instances = self.reader.read(self.data_path)
            checked_for_token_indexers: bool = False
            for instance in instances:
                # Check the first instance to make sure it doesn't contain any TextFields with
                # token_indexers because we don't want to be duplicating those by sending
                # them across processes.
                if not checked_for_token_indexers:
                    for field_name, field in instance.fields.items():
                        if isinstance(field, TextField) and field._token_indexers is not None:
                            raise ValueError(
                                f"Found a TextField ({field_name}) with token_indexers already "
                                "applied, but you're using num_workers > 0 in your data loader. "
                                "Make sure your dataset reader's text_to_instance() method doesn't "
                                "add any token_indexers to the TextFields it creates. Instead, the token_indexers "
                                "should be added to the instances in the apply_token_indexers() method of your "
                                "dataset reader (which you'll have to implement if you haven't done "
                                "so already)."
                            )
                    checked_for_token_indexers = True
                queue.put((instance, None))
        except Exception as e:
            queue.put((None, (repr(e), traceback.format_exc())))

        # Indicate to the consumer that this worker is finished.
        queue.put((None, None))

        # Wait until this process can safely exit.
        queue.join()
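The check above enforces AllenNLP's division of labor for multi-worker
loading: text_to_instance() must create TextFields without token_indexers so
instances stay cheap to send across processes, and apply_token_indexers()
attaches the indexers on the consumer side. A schematic reader following that
rule (hypothetical names, assuming AllenNLP 2.x, where TextField's
token_indexers argument is optional):

from allennlp.data import DatasetReader, Instance
from allennlp.data.fields import TextField

class MyReader(DatasetReader):
    def __init__(self, tokenizer, token_indexers, **kwargs):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers

    def text_to_instance(self, text: str) -> Instance:
        # No token_indexers here: the TextField must cross the process
        # boundary without them.
        return Instance({"tokens": TextField(self.tokenizer.tokenize(text))})

    def apply_token_indexers(self, instance: Instance) -> None:
        # Attach indexers on the consumer side, after unpickling.
        instance.fields["tokens"].token_indexers = self.token_indexers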
Example no. 6
    def _gather_instances(self, queue: mp.JoinableQueue) -> Iterable[Instance]:
        done_count: int = 0
        while done_count < self.num_workers:
            for instances_chunk, worker_error in iter(queue.get, (None, None)):
                if worker_error is not None:
                    e, tb = worker_error
                    sys.stderr.write("".join(tb))
                    raise e

                for instance in instances_chunk:
                    self.reader.apply_token_indexers(instance)
                    if self._vocab is not None:
                        instance.index_fields(self._vocab)
                    yield instance
                queue.task_done()
            queue.task_done()
            done_count += 1
Example no. 7
    def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue) -> None:
        try:
            self.reader._set_worker_info(
                WorkerInfo(self.num_workers, worker_id))
            instances = self.reader.read(self.data_path)
            for batch in self._instances_to_batches(
                    instances, move_to_device=self._worker_cuda_safe):
                queue.put((batch, None))
        except Exception as e:
            queue.put((None, (repr(e), traceback.format_exc())))

        # Indicate to the consumer (main thread) that this worker is finished.
        queue.put((None, None))

        # Wait until this process can safely exit.
        queue.join()
Example no. 8
    def _batch_worker(self, instance_queue: mp.JoinableQueue,
                      batch_queue: mp.JoinableQueue) -> None:
        try:
            for batch_chunk in lazy_groups_of(
                    self._instances_to_batches(
                        self._gather_instances(instance_queue)),
                    self._batch_chunk_size,
            ):
                batch_queue.put((batch_chunk, None))
        except Exception as e:
            batch_queue.put((None, (e, traceback.format_exc())))

        # Indicate to the consumer (main thread) that this worker is finished.
        batch_queue.put((None, None))

        # Wait for the consumer (in the main process) to finish receiving all batch groups
        # to avoid prematurely closing the queue.
        batch_queue.join()
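lazy_groups_of chunks the batch stream so that each queue put carries
self._batch_chunk_size batches at once, amortizing pickling and IPC overhead.
It behaves like this small islice-based helper (a sketch equivalent in spirit
to AllenNLP's allennlp.common.util.lazy_groups_of):

from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def lazy_groups_of(iterable: Iterable[T], group_size: int) -> Iterator[List[T]]:
    # Yield successive lists of up to group_size items without
    # materializing the whole iterable in memory.
    iterator = iter(iterable)
    while True:
        group = list(islice(iterator, group_size))
        if not group:
            return
        yield group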
Example no. 9
def run_experiment(experiment_id, experiment_directory, run_id,
                   experiment_config, agents_config, seed):
    np.random.seed(seed)
    torch.manual_seed(seed)

    results_path = f'{experiment_directory}/run-{run_id}'
    if not os.path.exists(results_path):
        os.mkdir(results_path)
    base_path = results_path

    createNewEnvironment = EnvironmentCreationFunction(
        experiment_config['environment'])

    checkpoint_at_iterations = [
        int(i) for i in experiment_config['checkpoint_at_iterations']
    ]
    benchmarking_episodes = int(experiment_config['benchmarking_episodes'])

    training_schemes = util.experiment_parsing.initialize_training_schemes(
        experiment_config['self_play_training_schemes'])
    algorithms = util.experiment_parsing.initialize_algorithms(
        createNewEnvironment(), agents_config)
    fixed_agents = util.experiment_parsing.initialize_fixed_agents(
        experiment_config['fixed_agents'])

    training_jobs = enumerate_training_jobs(training_schemes, algorithms)

    (initial_fixed_agents_to_benchmark,
     fixed_agents_for_confusion) = preprocess_fixed_agents(
         fixed_agents, checkpoint_at_iterations)
    agent_queue, benchmark_queue, matrix_queue = JoinableQueue(
    ), JoinableQueue(), JoinableQueue()

    for fixed_agent in initial_fixed_agents_to_benchmark:
        agent_queue.put(fixed_agent)

    (training_processes, mm_process,
     benchmark_process, cfm_process) = create_all_initial_processes(
         training_jobs, createNewEnvironment, checkpoint_at_iterations,
         agent_queue, benchmark_queue, matrix_queue, benchmarking_episodes,
         fixed_agents_for_confusion, results_path, seed)

    run_processes(training_processes, mm_process, benchmark_process,
                  cfm_process)
Example no. 10
    def _safe_queue_put(self, worker_id: int, item: Any,
                        queue: mp.JoinableQueue, rx: Connection) -> bool:
        while True:
            # First we have to check to make sure the parent process is still alive
            # and consuming from the queue, because there are circumstances where the
            # parent process can exit or stop consuming without automatically cleaning up
            # its children (the workers).
            # For example, when the parent process is killed with `kill -9`.
            # So the first thing we do is check to see if the parent has notified
            # us (the worker) to stop through the rx (receiver) connection.
            # Of course this only works if the parent was able to send out a notification,
            # which may not always be the case. So we have a backup check below.
            if rx.poll():
                logger.warning(
                    "worker %d received stop message from parent, exiting now",
                    worker_id)
                queue.cancel_join_thread()
                return False
            # This is the backup check.
            # The file descriptor associated with the rx (receiver) connection will
            # be readable if and only if the parent process has exited.
            # NOTE (epwalsh): this doesn't work on Mac OS X with `start_method == "fork"`
            # for some reason, i.e. the file descriptor doesn't show as readable
            # after the parent process has died.
            fds, _, _ = select.select([rx.fileno()], [], [], 0)
            if fds:
                logger.warning(
                    "worker %d parent process has died, exiting now",
                    worker_id)
                queue.cancel_join_thread()
                return False
            # If we're down here, the parent process is still alive to the best of our
            # knowledge, so we can continue putting things on the queue.
            try:
                queue.put(item, True, 0.1)
                return True
            except Full:
                continue
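The rx connection is the worker's end of a multiprocessing.Pipe; the parent
keeps the tx end and either sends an explicit stop message during cleanup or,
if it dies outright, the pipe's file descriptor becomes readable, which the
select() backup check detects. A minimal sketch of the parent side
(hypothetical helper names; the actual AllenNLP cleanup logic differs in
detail):

import multiprocessing as mp

def start_worker(target):
    # rx (receive end) goes to the worker; tx (send end) stays in the parent.
    rx, tx = mp.Pipe(duplex=False)
    queue = mp.JoinableQueue()
    worker = mp.Process(target=target, args=(queue, rx), daemon=True)
    worker.start()
    return worker, queue, tx

def stop_worker(worker, tx):
    try:
        tx.send("stop")  # the worker sees this via rx.poll()
    except BrokenPipeError:
        pass  # the worker is already gone
    worker.join(timeout=2)
    if worker.is_alive():
        worker.terminate()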
Example no. 11
    def buildCache(self, limit):
        # print("Building cache: ",
        #       self.cache_names[self.current_cache_build.value]
        # )
        dataset = MultimodalPatchesDatasetAll(
            self.dataset_dir,
            self.dataset_list,
            rejection_radius_position=self.rejection_radius_position,
            #self.images_path, list=train_sampled,
            numpatches=self.numpatches,
            numneg=self.numneg,
            pos_thr=self.pos_thr,
            reject=self.reject,
            mode=self.mode,
            rejection_radius=self.rejection_radius,
            dist_type=self.dist_type,
            patch_radius=self.patch_radius,
            use_depth=self.use_depth,
            use_normals=self.use_normals,
            use_silhouettes=self.use_silhouettes,
            color_jitter=self.color_jitter,
            greyscale=self.greyscale,
            maxres=self.maxres,
            scale_jitter=self.scale_jitter,
            photo_jitter=self.photo_jitter,
            uniform_negatives=self.uniform_negatives,
            needles=self.needles,
            render_only=self.render_only)
        n_triplets = len(dataset)

        if limit == -1:
            limit = n_triplets

        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=1,  # self.num_workers
            collate_fn=MultimodalPatchesCache.my_collate)

        qmaxsize = 15
        data_queue = JoinableQueue(maxsize=qmaxsize)

        # cannot load to cuda from background, therefore use cpu device
        preloader_resume = Event()
        preloader = Process(target=MultimodalPatchesCache.generateTrainingData,
                            args=(data_queue, dataset, dataloader,
                                  self.batch_size, qmaxsize, preloader_resume,
                                  True, True))
        preloader.do_run_generate = True
        preloader.start()
        preloader_resume.set()

        i_batch = 0
        data = data_queue.get()
        i_batch = data[0]

        counter = 0
        while i_batch != -1:

            self.cache_builder_resume.wait()

            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            batch_fname = os.path.join(build_dataset_dir,
                                       'batch_' + str(counter) + '.pt')

            # print("ibatch", i_batch,
            #        "___data___", data[3].shape, data[6].shape)

            anchor = data[1]
            pos = data[2]
            neg = data[3]
            anchor_r = data[4]
            pos_p = data[5]
            neg_p = data[6]
            c1 = data[7]
            c2 = data[8]
            cneg = data[9]
            sample_id = data[10]  # avoid shadowing the builtin id()

            if not (self.use_depth or self.use_normals):
                # no need to store image data as float, convert to uint8
                anchor = (anchor * 255.0).to(torch.uint8)
                pos = (pos * 255.0).to(torch.uint8)
                neg = (neg * 255.0).to(torch.uint8)
                anchor_r = (anchor_r * 255.0).to(torch.uint8)
                pos_p = (pos_p * 255.0).to(torch.uint8)
                neg_p = (neg_p * 255.0).to(torch.uint8)

            tosave = {
                'anchor': anchor,
                'pos': pos,
                'neg': neg,
                'anchor_r': anchor_r,
                'pos_p': pos_p,
                'neg_p': neg_p,
                'c1': c1,
                'c2': c2,
                'cneg': cneg,
                'id': sample_id
            }

            try:
                torch.save(tosave, batch_fname)
                torch.load(batch_fname)
                counter += 1
            except Exception as e:
                print("Could not save ",
                      batch_fname,
                      ", due to:",
                      e,
                      "skipping...",
                      file=sys.stderr)
                if os.path.isfile(batch_fname):
                    os.remove(batch_fname)

            data_queue.task_done()

            if counter >= limit:
                self.cache_done_lock.acquire()
                self.cache_done.value = 1  # 1 is True
                self.cache_done_lock.release()
                counter = 0
                # sleep until calling thread wakes us
                self.cache_builder_resume.clear()
                # resume calling thread so that it can work
                self.wait_for_cache_builder.set()

            data = data_queue.get()
            i_batch = data[0]
            #print("ibatch", i_batch)

        data_queue.task_done()

        self.cache_done_lock.acquire()
        self.cache_done.value = 1  # 1 is True
        self.all_done.value = 1
        print("Cache done ALL")
        self.cache_done_lock.release()
        # resume calling thread so that it can work
        self.wait_for_cache_builder.set()
        preloader.join()
        preloader = None
        data_queue = None
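generateTrainingData is not shown here; the consumer loop in buildCache
implies its contract: put tuples whose first element is the batch index
(followed by the ten tensors unpacked above), put a tuple with first element
-1 as a terminator, and join the queue so the producer does not exit before
every batch has been stored. A hypothetical minimal producer honoring that
contract:

def generate_training_data(data_queue, dataloader):
    # Sketch only: each item is (batch_index, *batch_tensors);
    # a first element of -1 tells buildCache() to stop.
    for i_batch, batch in enumerate(dataloader):
        data_queue.put((i_batch, *batch))
    data_queue.put((-1,))  # terminator checked via data[0] != -1
    data_queue.join()      # wait until the consumer has called task_done()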
Example no. 12
import sys
import time

import torch
from torch.multiprocessing import Process
from torch.multiprocessing import Queue, SimpleQueue
from torch.multiprocessing import JoinableQueue

#q = SimpleQueue()
#q = Queue()
q = JoinableQueue()

def torch_shared_mem_process(shared_memory):
    counter = 0
    start = time.time()
    while True:
        data = q.get()
        counter += 1
        if data is None:
            print(f'[torch_shared_mem_process_q1] Received with shared memory {shared_memory}: {time.time() - start}')
            return
        # assert data.is_shared()
        del data

def test_mem_share(share_memory):
    p = Process(target=torch_shared_mem_process, args=(share_memory, ))
    p.start()

    start = time.time()
    n = 100
    for i in range(n):