def _shutdown_workers(self):
    if not self.shutdown:
        self.shutdown = True
        # Remove pids from the C side data structure first so that worker
        # termination afterwards won't trigger a false positive error report.
        if self.worker_pids_set:
            _remove_worker_pids(id(self))
            self.worker_pids_set = False
        self.done_event.set()
        if self.pin_memory:
            # Sending `None` to `pin_memory_thread` must happen before
            # stopping the worker processes because the workers may leave
            # corrupted data in `worker_result_queue`, leaving
            # `pin_memory_thread` unable to read and terminate properly.
            self.worker_result_queue.put(None)
        # Workers can't be waiting to put because their output queue
        # is a multiprocessing.Queue and its .put is non-blocking.
        # They can only be waiting to get, so we put `None` here.
        for _w in self.workers:
            # Put as many `None`s as there are workers so every worker gets one.
            self.batch_queue.put(None)
        for w in self.workers:
            w.join()
        if self.pin_memory:
            self.pin_memory_thread.join()
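The core trick above is the sentinel pattern: a worker blocked in `get()` can only be woken by putting something into its queue, so the parent puts one `None` per worker before joining. A minimal, self-contained sketch of that pattern (the names `worker_loop`, `task_queue`, and `result_queue` are illustrative, not taken from the DataLoader source):

import multiprocessing as mp

def worker_loop(task_queue, result_queue):
    while True:
        task = task_queue.get()        # blocks until a task or a sentinel arrives
        if task is None:               # sentinel: exit cleanly
            break
        result_queue.put(task * task)  # stand-in for real work

if __name__ == '__main__':
    task_queue, result_queue = mp.Queue(), mp.Queue()
    workers = [mp.Process(target=worker_loop, args=(task_queue, result_queue))
               for _ in range(4)]
    for w in workers:
        w.start()
    for i in range(8):
        task_queue.put(i)
    results = [result_queue.get() for _ in range(8)]  # drain results first
    for _ in workers:
        task_queue.put(None)           # one sentinel per worker, as above
    for w in workers:
        w.join()                       # no worker is left blocked in get()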
def _shutdown_workers(self):
    # Called when shutting down this `_MultiProcessingDataLoaderIter`.
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
    # the logic of this function.
    python_exit_status = torch_data_utils.python_exit_status
    if python_exit_status is True or python_exit_status is None:
        # See (2) of the note. If Python is shutting down, do no-op.
        return
    # Normal exit when last reference is gone / iterator is depleted.
    # See (1) and the second half of the note.
    if not self._shutdown:
        self._shutdown = True
        try:
            # Exit `pin_memory_thread` first because exiting workers may leave
            # corrupted data in `worker_result_queue` which `pin_memory_thread`
            # reads from.
            if hasattr(self, '_pin_memory_thread'):
                # Use hasattr in case error happens before we set the attribute.
                self._pin_memory_thread_done_event.set()
                # Send something to pin_memory_thread in case it is waiting
                # so that it can wake up and check `pin_memory_thread_done_event`.
                for worker_id in range(len(self._workers)):
                    self._worker_result_queues[worker_id].put((None, None))
                self._pin_memory_thread.join()
                for worker_id in range(len(self._workers)):
                    self._worker_result_queues[worker_id].cancel_join_thread()
                    self._worker_result_queues[worker_id].close()

            # Exit workers now.
            self._workers_done_event.set()
            for worker_id in range(len(self._workers)):
                # Get number of workers from `len(self._workers)` instead of
                # `self.num_workers` in case we error before starting all
                # workers.
                if self._worker_is_active[worker_id]:
                    self._shutdown_worker(worker_id)
            for w in self._workers:
                w.join()
            for q in self._task_queues:
                q.cancel_join_thread()
                q.close()
        finally:
            # Even though all this function does is putting into queues that
            # we have called `cancel_join_thread` on, weird things can
            # happen when a worker is killed by a signal, e.g., hanging in
            # `Event.set()`. So we need to guard this with SIGCHLD handler,
            # and remove pids from the C side data structure only at the
            # end.
            #
            # FIXME: Unfortunately, for Windows, we are missing a worker
            # error detection mechanism here in this function, as it
            # doesn't provide a SIGCHLD handler.
            if self._worker_pids_set:
                signal_handling._remove_worker_pids(id(self))
                self._worker_pids_set = False
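The `pin_memory_thread` shutdown above combines a done event with a wake-up item: setting the event alone is not enough because the thread may be blocked in `get()`, so the parent also puts a dummy entry to wake it, after which `join()` returns promptly. A minimal sketch of that pattern with a plain thread and queue (names are illustrative, not from the DataLoader source):

import queue
import threading

def consumer_loop(in_queue, done_event):
    while not done_event.is_set():
        item = in_queue.get()    # may block until data or a wake-up sentinel
        if item is None:         # wake-up item: re-check `done_event` above
            continue
        _ = item                 # stand-in for real processing (e.g. pinning)

if __name__ == '__main__':
    in_queue = queue.Queue()
    done_event = threading.Event()
    t = threading.Thread(target=consumer_loop, args=(in_queue, done_event))
    t.start()
    in_queue.put(42)             # some normal work
    done_event.set()             # 1) signal shutdown
    in_queue.put(None)           # 2) wake the thread if it is blocked in get()
    t.join()                     # returns once the thread has seen the event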