Ejemplo n.º 1
0
class Pipeline():
    """Coordination object shared by a ``DatasetReader`` and its consumer.

    The consumer side sends commands through ``command_queue``
    (``start_reading``, ``scan_to``, ``set_indices``, ``set_shuffle``,
    ``close``) and receives batches through ``get_batch``.  The reader side
    fetches commands with ``get_command``, feeds ``molecule_pipeline`` and
    publishes batches with ``put_batch``.

    Shared state:

    - ``knows``: semaphore, > 0 once it is known whether examples are coming.
    - ``working``: semaphore, == 0 while the reader is processing a command.
    - ``finished_reading``: lock, held while the reader is still reading.
    - ``in_pipe``: number of examples sent into the pipe and not yet
      retrieved by ``get_batch``.
    - ``_example_number``: running count of examples handed to the consumer.
    - ``_close``: final kill switch polled through ``time_to_close``.

    NOTE(review): ``check_semaphore``, ``set_semaphore`` and
    ``wait_semaphore`` are helpers defined outside this class; comments below
    describe them only as far as their use here implies.
    """

    def __init__(self,
                 config,
                 share_batches=True,
                 manager=None,
                 new_process=True):
        """Create the shared primitives and the attached ``DatasetReader``.

        Args:
            config: configuration object; ``config.data.batch_queue_cap`` is
                used as the capacity of the batch queue and the whole object
                is forwarded to ``DatasetReader``.
            share_batches: stored on the instance; currently not acted upon
                by ``put_batch``.
            manager: object providing ``Queue(maxsize)`` for the command
                queue.  A fresh ``Manager()`` is created when omitted.
            new_process: when truthy the reader is started immediately via
                ``dataset_reader.start()``.
        """
        # The command queue always comes from a manager, so one is required
        # regardless of ``new_process``.  (Previously a missing manager with
        # ``new_process=False`` crashed on ``None.Queue``.)
        if manager is None:
            manager = Manager()

        self.knows = Semaphore(0)  # > 0 if we know if any are coming
        # == 0 if DatasetReader is processing a command
        self.working = Semaphore(1 if new_process else 100)
        # locked if we're still reading from file
        self.finished_reading = Lock()
        # number of molecules that have been sent to the pipe:
        self.in_pipe = Value('i', 0)

        # Tracking what's already been sent through the pipe:
        self._example_number = Value('i', 0)

        # The final kill switch:
        self._close = Value('i', 0)

        self.command_queue = manager.Queue(10)
        # Not created here; it must be assigned before the put_*/ext methods
        # below are used.
        self.molecule_pipeline = None
        # Plain queue rather than a manager queue for the batches.
        self.batch_queue = Queue(config.data.batch_queue_cap)
        self.share_batches = share_batches

        self.dataset_reader = DatasetReader("dataset_reader",
                                            self,
                                            config,
                                            new_process=new_process)
        if new_process:
            self.dataset_reader.start()

    def __getstate__(self):
        """Return the instance dict for pickling, without the reader."""
        self_dict = self.__dict__.copy()
        # The reader holds a reference back to this pipeline; pickled copies
        # therefore carry ``dataset_reader = None``.
        self_dict['dataset_reader'] = None
        return self_dict

    # methods for pipeline user/consumer:
    def start_reading(self,
                      examples_to_read,
                      make_molecules=True,
                      batch_size=None,
                      wait=False):
        """Queue a ``StartReading`` command for the reader.

        Args:
            examples_to_read: forwarded to ``StartReading``.
            make_molecules: forwarded to ``StartReading``.
            batch_size: forwarded to ``StartReading``.
            wait: when true, block in ``wait_till_done`` after queueing.

        Raises:
            AssertionError: if a read is already in progress or examples are
                still in the pipe.
        """
        assert check_semaphore(
            self.finished_reading
        ), "Tried to start reading file, but already reading!"
        with self.in_pipe.get_lock():
            assert self.in_pipe.value == 0, "Tried to start reading, but examples already in pipe!"
        # Mark "reading in progress" and "unknown whether examples are coming"
        # before handing the command over.
        set_semaphore(self.finished_reading, False)
        set_semaphore(self.knows, False)
        self.working.acquire()
        self.command_queue.put(
            StartReading(examples_to_read, make_molecules, batch_size))
        if wait:
            self.wait_till_done()

    def wait_till_done(self):
        """Block until ``working`` can be acquired, then verify the pipe.

        Raises:
            Exception: if examples are still reported in the pipe afterwards.
        """
        # Acquire-and-release acts as a barrier on the ``working`` semaphore.
        self.working.acquire()
        self.working.release()
        if self.any_coming():
            with self.in_pipe.get_lock():
                ip = self.in_pipe.value
            raise Exception(f"Waiting with {ip} examples in pipe!")

    def scan_to(self, index):
        """Queue a ``ScanTo(index)`` command and reset the example counter.

        Raises:
            AssertionError: if reading has not finished or the pipe is not
                empty.
        """
        assert check_semaphore(
            self.knows), "Tried to scan to index, but don't know if finished!"
        assert check_semaphore(
            self.finished_reading
        ), "Tried to scan to index, but not finished reading!"
        assert not self.any_coming(
        ), "Tried to scan to index, but pipeline not empty!"
        self.working.acquire()
        self.command_queue.put(ScanTo(index))
        with self._example_number.get_lock():
            self._example_number.value = index
        # What to do if things are still in the pipe???

    def set_indices(self, test_set_indices):
        """Queue ``SetIndices`` for the given indices, then ``ScanTo(0)``."""
        self.working.acquire()
        self.command_queue.put(SetIndices(torch.tensor(test_set_indices)))
        self.working.acquire()
        self.command_queue.put(ScanTo(0))

    def set_shuffle(self, shuffle):
        """Queue a ``SetShuffle`` command.

        NOTE(review): unlike the other commands this one does not acquire
        ``working`` first -- confirm against the reader that this is intended.
        """
        self.command_queue.put(SetShuffle(shuffle))

    def any_coming(self):  # returns True if at least one example is coming
        """Wait on ``knows`` and report whether ``in_pipe`` is positive."""
        wait_semaphore(self.knows)
        with self.in_pipe.get_lock():
            return self.in_pipe.value > 0

    def get_batch(self, timeout=None):
        """Take the next batch from ``batch_queue`` and update the counters.

        Args:
            timeout: passed to the blocking ``batch_queue.get``.

        Returns:
            The batch object; its ``n_examples`` attribute is subtracted from
            ``in_pipe`` and added to the example counter.
        """
        x = self.batch_queue.get(True, timeout)

        with self.in_pipe.get_lock():
            self.in_pipe.value -= x.n_examples
            # Pipe drained while ``finished_reading`` does not check out:
            # reset ``knows`` (presumably because more examples may still
            # arrive -- TODO confirm helper semantics).
            if self.in_pipe.value == 0 and not check_semaphore(
                    self.finished_reading):
                set_semaphore(self.knows, False)
        with self._example_number.get_lock():
            self._example_number.value += x.n_examples
        return x

    @property
    def example_number(self):
        """Current value of the example counter."""
        with self._example_number.get_lock():
            return self._example_number.value

    def close(self):
        """Ask the reader to stop, set the kill switch and reap the reader.

        The reader is joined for up to 4 seconds and then killed.  Instances
        whose ``dataset_reader`` is ``None`` (as produced by ``__getstate__``)
        only send the command and set the flag.

        NOTE(review): a reader that was never started (``new_process=False``)
        may not support ``join``/``kill`` -- confirm against ``DatasetReader``.
        """
        self.command_queue.put(CloseReader())
        with self._close.get_lock():
            self._close.value = 1  # integer flag; same value ``True`` stored
        reader = self.dataset_reader
        if reader is None:
            return
        reader.join(4)
        reader.kill()

    # methods for DatasetReader:
    def get_command(self):
        """Block until a command is available and return it."""
        return self.command_queue.get()

    def put_molecule_to_ext(self, m, block=True):
        """Hand one molecule to ``molecule_pipeline``.

        Returns:
            False if ``molecule_pipeline.put_molecule`` reported failure,
            True otherwise (after counting the example in ``in_pipe``).
        """
        r = self.molecule_pipeline.put_molecule(m, block)
        if not r:
            return False
        with self.in_pipe.get_lock():
            # First example entering an empty pipe: it is now known that
            # something is coming.
            if self.in_pipe.value == 0:
                set_semaphore(self.knows, True)
            self.in_pipe.value += 1
        return True

    def put_molecule_data(self, data, atomic_numbers, weights, ID, block=True):
        """Hand raw molecule data to ``molecule_pipeline``.

        Returns:
            False if ``molecule_pipeline.put_molecule_data`` reported
            failure, True otherwise (after updating ``in_pipe``).
        """
        r = self.molecule_pipeline.put_molecule_data(data, atomic_numbers,
                                                     weights, ID, block)
        if not r:
            return False
        with self.in_pipe.get_lock():
            if self.in_pipe.value == 0:
                set_semaphore(self.knows, True)
            # 3-D data counts as ``data.shape[0]`` examples (presumably a
            # stack of examples along axis 0 -- TODO confirm); anything else
            # counts as one.
            if data.ndim == 3:
                self.in_pipe.value += data.shape[0]
            else:
                self.in_pipe.value += 1
        return True

    def get_batch_from_ext(self, block=True):
        """Return ``molecule_pipeline.get_next_batch(block)``."""
        return self.molecule_pipeline.get_next_batch(block)

    def ext_batch_ready(self):
        """Return ``molecule_pipeline.batch_ready()``."""
        return self.molecule_pipeline.batch_ready()

    # !!! Call only after you've put the molecules !!!
    def set_finished_reading(self):
        """Mark reading as finished and notify ``molecule_pipeline``."""
        set_semaphore(self.finished_reading, True)
        set_semaphore(self.knows, True)
        self.molecule_pipeline.notify_finished()

    def put_batch(self, x):
        """Publish a finished batch on ``batch_queue``.

        ``share_batches`` is intentionally not consulted: moving batches to
        shared memory was disabled in the original code.
        """
        self.batch_queue.put(x)

    def time_to_close(self):
        """Return the kill-switch value (non-zero once ``close`` ran)."""
        with self._close.get_lock():
            return self._close.value
Ejemplo n.º 2
0
class HogwildWorld(World):
    """World that fans work out to one ``HogwildProcess`` per thread slot.

    The following objects are created here and handed to every worker so
    that state can be tracked across processes:

    - ``queued_items``: a Semaphore counting examples waiting to be
      processed.  ``parley`` releases it once per call; workers are expected
      to acquire it when they claim an example.

    - ``epochDone``: a Condition used to signal that no queued examples
      remain.

    - ``terminate``: a boolean Value telling the inner worlds to shut down.

    - ``cnt``: an integer Value holding the number of examples queued but not
      yet fully processed (claiming an example via the semaphore does not
      lower it; it is meant to drop only when processing completes).
    """

    def __init__(self, world_class, opt, agents):
        """Build the inner world, the shared state and start all workers.

        ``opt['numthreads']`` determines how many workers are created.
        """
        super().__init__(opt)
        self.inner_world = world_class(opt, agents)

        # Shared synchronisation state.
        self.queued_items = Semaphore(0)  # examples waiting to be claimed
        self.epochDone = Condition()  # signalled when examples are finished
        self.terminate = Value('b', False)  # shutdown flag for the workers
        self.cnt = Value('i', 0)  # examples still to be processed

        # Construct every worker first, then launch them.
        self.threads = [
            HogwildProcess(idx, world_class, opt, agents, self.queued_items,
                           self.epochDone, self.terminate, self.cnt)
            for idx in range(opt['numthreads'])
        ]
        for worker in self.threads:
            worker.start()

    def display(self):
        """Unsupported: shuts the workers down and raises."""
        self.shutdown()
        message = ('Hogwild does not support displaying in-run'
                   ' task data. Use `--numthreads 1`.')
        raise NotImplementedError(message)

    def episode_done(self):
        """Always reports that the episode is not done."""
        return False

    def parley(self):
        """Enqueue a single item for the workers."""
        with self.cnt.get_lock():
            self.cnt.value = self.cnt.value + 1
        self.queued_items.release()
        self.total_parleys += 1

    def getID(self):
        """Return the ID of the inner world."""
        return self.inner_world.getID()

    def report(self, compute_time=False):
        """Delegate reporting to the inner world."""
        return self.inner_world.report(compute_time)

    def save_agents(self):
        """Delegate agent saving to the inner world."""
        self.inner_world.save_agents()

    def synchronize(self):
        """Barrier: block until the unprocessed-example counter reads zero."""

        def nothing_pending():
            return self.cnt.value == 0

        with self.epochDone:
            self.epochDone.wait_for(nothing_pending)

    def shutdown(self):
        """Raise the shutdown flag, wake every worker and wait for them."""
        # Raise the flag under its lock.
        with self.terminate.get_lock():
            self.terminate.value = True
        # One fake queued example per worker so each one wakes up.
        for _worker in self.threads:
            self.queued_items.release()
        # Wait for all workers to exit.
        for worker in self.threads:
            worker.join()