Example #1
 def _shutdown_workers(self):
     if not self.shutdown:
         self.shutdown = True
         # Remove pids from the C-side data structure first so that worker
         # termination afterwards won't trigger a false-positive error report.
         if self.worker_pids_set:
             _remove_worker_pids(id(self))
             self.worker_pids_set = False
         self.done_event.set()
         if self.pin_memory:
             # Sending `None` to `pin_memory_thread` must happen before
             # stopping the worker processes, because exiting workers may
             # leave corrupted data in `worker_result_queue`, making
             # `pin_memory_thread` unable to read it and terminate properly.
             self.worker_result_queue.put(None)
         # Workers can't be waiting to put because their output queue
         # is a multiprocessing.Queue and its .put is non-blocking.
         # They can only be waiting to get, so we put `None` here.
         for _w in self.workers:
             # Put one `None` per worker so that every worker is guaranteed
             # to get one.
             self.batch_queue.put(None)
         for w in self.workers:
             w.join()
         if self.pin_memory:
             self.pin_memory_thread.join()
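
Below is a minimal, self-contained sketch of the same sentinel-per-worker shutdown pattern. The names (`_worker`, `task_queue`, `result_queue`) are hypothetical stand-ins, not part of the DataLoader code above; the point is only that a worker can block nowhere but in `get()`, so putting one `None` per worker guarantees that every worker wakes up and exits, after which `join()` cannot hang.

import multiprocessing as mp

def _worker(task_queue, result_queue):
    # Loop until the `None` sentinel arrives.
    while True:
        task = task_queue.get()        # the only place a worker can block
        if task is None:               # sentinel: exit cleanly
            break
        result_queue.put(task * 2)     # stand-in for real work

if __name__ == '__main__':
    task_queue, result_queue = mp.Queue(), mp.Queue()
    workers = [mp.Process(target=_worker, args=(task_queue, result_queue))
               for _ in range(4)]
    for w in workers:
        w.start()
    for i in range(8):
        task_queue.put(i)
    results = [result_queue.get() for _ in range(8)]
    # Shutdown: one sentinel per worker, then join -- mirroring the
    # `self.batch_queue.put(None)` loop above.
    for _ in workers:
        task_queue.put(None)
    for w in workers:
        w.join()
    print(sorted(results))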
Example #2
    def _shutdown_workers(self):
        # Called when shutting down this `_MultiProcessingDataLoaderIter`.
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
        # the logic of this function.
        python_exit_status = torch_data_utils.python_exit_status
        if python_exit_status is True or python_exit_status is None:
            # See (2) of the note. If Python is shutting down, this is a no-op.
            return
        # Normal exit when last reference is gone / iterator is depleted.
        # See (1) and the second half of the note.
        if not self._shutdown:
            self._shutdown = True
            try:
                # Exit `pin_memory_thread` first because exiting workers may
                # leave corrupted data in the `_worker_result_queues` that
                # `pin_memory_thread` reads from.
                if hasattr(self, '_pin_memory_thread'):
                    # Use hasattr in case an error happened before the attribute was set.
                    self._pin_memory_thread_done_event.set()
                    # Send something to `pin_memory_thread` in case it is
                    # blocked waiting, so that it can wake up and check
                    # `_pin_memory_thread_done_event`.
                    for worker_id in range(len(self._workers)):
                        self._worker_result_queues[worker_id].put((None, None))

                    self._pin_memory_thread.join()
                    for worker_id in range(len(self._workers)):
                        self._worker_result_queues[
                            worker_id].cancel_join_thread()
                        self._worker_result_queues[worker_id].close()

                # Exit workers now.
                self._workers_done_event.set()
                for worker_id in range(len(self._workers)):
                    # Get number of workers from `len(self._workers)` instead of
                    # `self.num_workers` in case we error before starting all
                    # workers.
                    if self._worker_is_active[worker_id]:
                        self._shutdown_worker(worker_id)
                for w in self._workers:
                    w.join()
                for q in self._task_queues:
                    q.cancel_join_thread()
                    q.close()
            finally:
                # Even though all this function does is put into queues that
                # we have called `cancel_join_thread` on, weird things can
                # happen when a worker is killed by a signal, e.g., hanging in
                # `Event.set()`. So we need to guard this with the SIGCHLD
                # handler, and remove pids from the C-side data structure only
                # at the end.
                #
                # FIXME: Unfortunately, for Windows, we are missing a worker
                #        error detection mechanism in this function, as Windows
                #        doesn't provide a SIGCHLD handler.
                if self._worker_pids_set:
                    signal_handling._remove_worker_pids(id(self))
                    self._worker_pids_set = False
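
Below is a minimal sketch of the done-event-plus-wake-up pattern Example #2 uses for `pin_memory_thread`. The names (`_pin_memory_loop`, `in_queue`, `out_queue`) are hypothetical, and the timeout-based `get()` is an assumption: one common way to make the event check race-free. The shutdown side mirrors the order above: set the done event first, then send a dummy `(None, None)` so a blocked thread wakes promptly, then join.

import queue
import threading

def _pin_memory_loop(in_queue, out_queue, done_event):
    # Re-check the done event periodically; the timeout bounds how long we
    # can sit in `get()` without noticing that shutdown has started.
    while not done_event.is_set():
        try:
            item = in_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        if item == (None, None):       # dummy wake-up item sent at shutdown
            continue
        out_queue.put(item)            # stand-in for pinning the batch

if __name__ == '__main__':
    in_q, out_q = queue.Queue(), queue.Queue()
    done = threading.Event()
    t = threading.Thread(target=_pin_memory_loop, args=(in_q, out_q, done))
    t.start()
    in_q.put('batch0')
    print(out_q.get())                 # wait until the item is processed
    # Shutdown, in the same order as above: set the done event first, then
    # send a dummy item so a blocked `get()` wakes up promptly, then join.
    done.set()
    in_q.put((None, None))
    t.join()

Example #2 additionally calls `cancel_join_thread()` on each `multiprocessing.Queue` before `close()`: by default, a process that has put data on such a queue will not exit until the queue's feeder thread has flushed its buffer, and `cancel_join_thread()` drops that guarantee so shutdown can never block there.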