Example #1
    def __init__(self, manager: SharedMemoryManager, cache_size, unit,
                 shape) -> None:
        self.unit = unit
        self.shape = shape
        self.cache_size = cache_size
        self.total_bytes = self.unit * self.cache_size
        self.cache_block = manager.SharedMemory(size=self.total_bytes)
        self.lock = Manager().Lock()
        self.st_id = Manager().Value('i', cache_size)
        self.et_id = Manager().Value('i', -1)
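
The fragment above is an __init__ lifted out of its enclosing class: a manager-owned block is carved into cache_size slots of unit bytes each, guarded by Manager-backed lock and counters. A minimal self-contained sketch of the same pattern (the names below are illustrative, not from the original class):

from multiprocessing import Manager
from multiprocessing.managers import SharedMemoryManager

with SharedMemoryManager() as smm:
    unit, cache_size = 1024, 16              # 16 slots of 1 KiB each
    block = smm.SharedMemory(size=unit * cache_size)
    # Note: every Manager() call spawns its own server process; sharing one
    # Manager instance would be cheaper than the per-primitive calls above.
    lock = Manager().Lock()
    with lock:
        block.buf[0:unit] = bytes(unit)      # zero-fill slot 0
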
Example #2
    def multiprocess(self, n, nthread=8):
        """

        多进程 并发

        :return:
        """

        print('Parent process %s.' % os.getpid())

        p = Pool(nthread)  # process pool: request nthread worker processes from the system

        smm = SharedMemoryManager()  # TODO: available only in Python 3.8+
        smm.start()  # Start the process that manages the shared memory blocks

        cache_list = smm.ShareableList([0] * n)
        # Values stored in a ShareableList are restricted to the builtin types
        # int, float, bool, str (< 10 MB each), bytes (< 10 MB each) and None.
        # Another notable difference from the builtin list is that its length
        # cannot change (no append, insert, etc.), and slicing does not
        # dynamically create new ShareableList instances.

        shm_a = smm.SharedMemory(size=n)
        # shm_a.buf[:] = bytearray([0]*n)
        # shm_a.buf[:] = [0] * n

        print('shm_a id in main process: {} '.format(id(shm_a)))

        # Compare object identities in the parent process and the child processes
        self.global_array = [0] * n
        print('array id in main process: {} '.format(id(self.global_array)))

        self.global_string = 'abc'
        print('string id in main process: {} '.format(id(self.global_string)))

        self.global_int = 10
        print('int id in main process: {} '.format(id(self.global_int)))

        for i in range(n):

            # p.apply_async(task, args=(cache_name, i))  # apply_async fetches results asynchronously
            p.apply_async(self.task, args=(cache_list, shm_a, i))

        print('Waiting for all subprocesses done...')
        p.close()

        p.join()

        print('All subprocesses done.')

        smm.shutdown()

        return cache_list, shm_a
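
A quick self-contained demonstration of the ShareableList restrictions described in the comments above (Python 3.8+):

from multiprocessing import shared_memory

sl = shared_memory.ShareableList([0, 1.5, True, "text", b"bytes", None])
sl[0] = 42                      # in-place writes are allowed
try:
    sl.append(7)                # fixed length: no append/insert
except AttributeError as exc:
    print(exc)
sl.shm.close()
sl.shm.unlink()
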
Example #3
class SMBase:
    def __init__(self, addr=None, manager=None, mutex=None, format_list=None,
                 size=0, ratio=2):
        if mutex is None:
            self.mutex = RLock()
        else:
            self.mutex = mutex
        if manager is None:
            self._manager = DummyManager()
        elif isinstance(manager, (SharedMemoryManager, DummyManager)):
            self._manager = manager
        else:
            self._manager = SharedMemoryManager(manager)
        capacity = int(size*ratio)
        if capacity == 0:
            capacity = ratio
        with self.mutex:
            if addr is None:
                if format_list is None:
                    raise ValueError("Either addr or format_list must be provided")
                self._shl = self._manager.ShareableList(format_list)
                self._shl_addr = self._shl.shm.name
                self._shm = self._manager.SharedMemory(capacity)
                self._shm_addr = self._shm.name
                self._shl[0] = self._shm_addr
                self._shl[1] = int(size)
                self._shl[2] = int(capacity)
            else:
                self._shl_addr = addr
                self._shl = shared_memory.ShareableList(name=addr)
                self._shm_addr = self._shl[0]
                self._shm = shared_memory.SharedMemory(name=self._shm_addr)

    @locked
    def size(self):
        return self._shl[1]

    @locked
    def capacity(self):
        return self._shl[2]

    @locked
    def name(self):
        return self._shl.shm.name

    def check_memory(self) -> bool:
        updated = False
        with self.mutex:
            if self._shm_addr != self._shl[0]:
                updated = True
                try:
                    self._shm.close()
                except Exception:
                    traceback.print_exc()
                self._shm = shared_memory.SharedMemory(name=self._shl[0])
                self._shm_addr = self._shl[0]
        return updated

    def _recap(self, cap: int):
        if cap > self._shm.size:
            new_shm = self._manager.SharedMemory(size=int(cap))
            new_shm.buf[:self._shm.size] = self._shm.buf[:]
            try:
                self._shm.close()
                self._shm.unlink()
            except Exception:
                traceback.print_exc()
            self._shm = new_shm
            self._shm_addr = new_shm.name
            self._shl[0] = new_shm.name
            self._shl[2] = int(cap)

    def __del__(self):
        """
        Note: this does not unlink data. That is expected to be handled
        by the manager.
        """
        self._shm.close()
        self._shl.shm.close()
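
A minimal sketch of the addr= branch above: a second process can attach to the same ShareableList purely by name, which is how SMBase bootstraps its metadata (the values below are illustrative):

from multiprocessing import shared_memory

owner = shared_memory.ShareableList(["shm-name-placeholder", 0, 2])
name = owner.shm.name                              # pass this between processes

attached = shared_memory.ShareableList(name=name)  # the addr= path in SMBase
attached[1] = 10
assert owner[1] == 10                              # both views share one buffer

attached.shm.close()
owner.shm.close()
owner.shm.unlink()
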
Example #4
class ParallelSimulation(Simulation):
    """ Parallel simulation of Barnes-Hut algorithm realised using shared memory. """

    def __init__(self, positions: np.ndarray, velocities: np.ndarray, masses: np.ndarray, params: Namespace):
        super().__init__(positions, velocities, masses, params)
        self._theta = params.theta

        self._init_memory(positions, velocities, masses)
        self._init_workers()

        atexit.register(self._cleanup)

    def _init_memory(self, positions: np.ndarray, velocities: np.ndarray, masses: np.ndarray):
        """ Prepares shared memory arrays. """
        # setup process that sets up shared memory
        self._memory_manager = SharedMemoryManager()
        self._memory_manager.start()

        max_nodes = self.bodies + 64

        # create shared memory buffers
        self._positions_shm = self._memory_manager.SharedMemory(positions.nbytes)
        self._velocities_shm = self._memory_manager.SharedMemory(velocities.nbytes)
        self._accelerations_shm = self._memory_manager.SharedMemory(velocities.nbytes)
        self._masses_shm = self._memory_manager.SharedMemory(masses.nbytes)

        self._nodes_positions_shm = self._memory_manager.SharedMemory(np.empty((max_nodes, 3), np.float64).nbytes)
        self._nodes_masses_shm = self._memory_manager.SharedMemory(np.empty((max_nodes, ), np.float64).nbytes)
        self._nodes_sizes_shm = self._memory_manager.SharedMemory(np.empty((max_nodes, ), np.float64).nbytes)
        self._nodes_children_types_shm = self._memory_manager.SharedMemory(np.empty((max_nodes, 8), np.int64).nbytes)
        self._nodes_children_ids_shm = self._memory_manager.SharedMemory(np.empty((max_nodes, 8), np.int64).nbytes)

        # setup NumPy arrays
        self._data = SharedData(
            time_step=self.time_step,
            theta=self._theta,
            gravitational_constant=self.gravitational_constant,
            softening=self.softening,

            nodes_count=Value('i', 0),

            positions=np.ndarray((self.bodies, 3), dtype=np.float64, buffer=self._positions_shm.buf),
            velocities=np.ndarray((self.bodies, 3), dtype=np.float64, buffer=self._velocities_shm.buf),
            accelerations=np.ndarray((self.bodies, 3), dtype=np.float64, buffer=self._accelerations_shm.buf),
            masses=np.ndarray((self.bodies, ), dtype=np.float64, buffer=self._masses_shm.buf),

            nodes_positions=np.ndarray((max_nodes, 3), dtype=np.float64, buffer=self._nodes_positions_shm.buf),
            nodes_masses=np.ndarray((max_nodes, ), dtype=np.float64, buffer=self._nodes_masses_shm.buf),
            nodes_sizes=np.ndarray((max_nodes, ), dtype=np.float64, buffer=self._nodes_sizes_shm.buf),
            nodes_children_types=np.ndarray((max_nodes, 8), dtype=np.int64, buffer=self._nodes_children_types_shm.buf),
            nodes_children_ids=np.ndarray((max_nodes, 8), dtype=np.int64, buffer=self._nodes_children_ids_shm.buf)
        )

        # copy data into shared arrays
        self._data.positions[:] = positions[:]
        self._data.velocities[:] = velocities[:]
        self._data.masses[:] = masses[:]

    def _init_workers(self):
        """ Prepares pool of workers. """
        self._pool = Pool(
            processes=self._params.processes,
            initializer=worker.initialize,
            initargs=(self._data, )
        )

    def _cleanup(self):
        """ Cleans up shared memory and pool of workers. """
        self._pool.terminate()
        self._memory_manager.shutdown()
        print('Memory manager was shut down.')

    def simulate(self) -> Iterable[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """ Runs parallel implementation of Barnes-Hut simulation. """
        while True:
            self._build_octree()
            self._update_accelerations()
            self._update_positions()

            yield self._data.positions, self._data.velocities, self._data.accelerations

    def _build_octree(self):
        """ Builds octree used in Barnes-Hut. """
        global_coords_min = np.repeat(np.min(self._data.positions), 3)
        global_coords_max = np.repeat(np.max(self._data.positions), 3)
        global_coords_mid = (global_coords_min + global_coords_max) / 2

        # manually build first node
        self._data.nodes_count.value = 1
        self._data.nodes_positions[0] = np.average(self._data.positions, axis=0, weights=self._data.masses)
        self._data.nodes_masses[0] = np.sum(self._data.masses)
        self._data.nodes_sizes[0] = global_coords_max[0] - global_coords_min[0]

        # calculate base octant for each body
        bodies_base_octant = np.sum((self._data.positions > global_coords_mid) * [1, 2, 4], axis=1)

        tasks_targets = []
        tasks_args = []

        # build second layer of nodes and collect tasks
        for octant in range(8):
            coords_min, coords_max = octant_coords(global_coords_min, global_coords_max, octant)
            coords_mid = (coords_min + coords_max) / 2

            # get indices of bodies in this octant
            octant_bodies = np.argwhere(bodies_base_octant == octant).flatten()

            # if node is empty or has one body handle it separately
            if octant_bodies.size == 0:
                self._data.nodes_children_types[0, octant] = OCTANT_EMPTY
                continue
            if octant_bodies.size == 1:
                self._data.nodes_children_types[0, octant] = OCTANT_BODY
                self._data.nodes_children_ids[0, octant] = octant_bodies[0]
                continue

            # create node
            node_id = self._data.nodes_count.value
            self._data.nodes_count.value = node_id + 1
            self._data.nodes_children_types[0, octant] = OCTANT_NODE
            self._data.nodes_children_ids[0, octant] = node_id

            self._data.nodes_positions[node_id] = np.average(self._data.positions[octant_bodies], axis=0, weights=self._data.masses[octant_bodies])
            self._data.nodes_masses[node_id] = np.sum(self._data.masses[octant_bodies])
            self._data.nodes_sizes[node_id] = coords_max[0] - coords_min[0]

            # split bodies into sub octants
            bodies_sub_octant = np.sum((self._data.positions[octant_bodies] > coords_mid) * [1, 2, 4], axis=1)

            # create tasks
            for i in range(8):
                tasks_targets.append((node_id, i))
                tasks_args.append((
                    octant_bodies[bodies_sub_octant == i],
                    *octant_coords(coords_min, coords_max, i)
                ))

        # run tasks
        results = self._pool.starmap(worker.build_octree_branch, tasks_args)

        # update references in nodes
        for (node_id, i), (sub_node_type, sub_node_id) in zip(tasks_targets, results):
            self._data.nodes_children_types[node_id, i] = sub_node_type
            self._data.nodes_children_ids[node_id, i] = sub_node_id

    def _update_accelerations(self):
        """ Calculates accelerations of the bodies. """
        if self.bodies < 2:
            return

        self._pool.map(worker.update_acceleration, range(self.bodies))

    def _update_positions(self):
        """ Calculates positions of the bodies. """
        self._pool.map(worker.update_position, range(self.bodies))

    @property
    def positions(self) -> np.ndarray:
        return self._data.positions

    @property
    def velocities(self) -> np.ndarray:
        return self._data.velocities

    @property
    def masses(self) -> np.ndarray:
        return self._data.masses

    @property
    def accelerations(self) -> np.ndarray:
        return self._data.accelerations
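
A self-contained sketch of the pattern _init_memory relies on: the manager owns the raw block, NumPy merely views it, so writes through the array are visible to every process that maps the same block:

import numpy as np
from multiprocessing.managers import SharedMemoryManager

with SharedMemoryManager() as smm:
    template = np.zeros((4, 3), dtype=np.float64)
    shm = smm.SharedMemory(size=template.nbytes)
    positions = np.ndarray(template.shape, dtype=template.dtype, buffer=shm.buf)
    positions[:] = template[:]          # copy the initial data into shared memory
    positions[0] = [1.0, 2.0, 3.0]      # any attached process sees this write
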
Example #5
class DataManager:
    def __init__(self, obs=1000, config=None):
        self._datasets = dict()
        self.smm = SharedMemoryManager()
        self.smm.start()
        self.result = Manager().list()
        self.conns = dict()
        self.obs = obs
        self.config = config

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()

    def shutdown(self):
        self.smm.shutdown()
        for conn in self.conns.values():
            conn.close()

    def _add_to_shared_memory(self, nparray: recarray) -> SharedMemory:
        """Internal function to copy an array into shared memory.

        Parameters
        ----------
        nparray : recarray
            The array to be copied into shared memory.

        Returns
        -------
        SharedMemory
            The shared memory object.
        """
        shm = self.smm.SharedMemory(nparray.nbytes)
        array = recarray(shape=nparray.shape, dtype=nparray.dtype, buf=shm.buf)
        copyto(array, nparray)
        return shm

    def _download_dataset(self, dataset: Dataset) -> recarray:
        """Internal function to download the dataset.

        Parameters
        ----------
        dataset : Dataset
            Dataset information including source, library, table, vars, etc.

        Returns
        -------
        recarray
            `numpy.recarray` of the downloaded dataset.
        """
        # TODO: generic login data for different data sources
        if dataset.source == 'wrds':
            usr = self.config.get('wrds_username')
            pwd = self.config.get('wrds_password')
        # If there exists a connection for the data source, use it!
        if (conn := self.conns.get(dataset.source, None)) is None:
            module = import_module(f'frds.data.{dataset.source}')
            conn = module.Connection(usr=usr, pwd=pwd)
            self.conns.update({dataset.source: conn})
        df = conn.get_table(library=dataset.library,
                            table=dataset.table,
                            columns=dataset.vars,
                            date_cols=dataset.date_vars,
                            obs=self.obs)
        assert isinstance(df, DataFrame)
        return df.to_records(index=False)
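
A minimal sketch of how a consumer might rebuild the recarray from the SharedMemory object that _add_to_shared_memory returns; shape and dtype are not stored in the block, so they have to travel alongside it:

import numpy as np
from multiprocessing.managers import SharedMemoryManager

with SharedMemoryManager() as smm:
    src = np.rec.fromrecords([(1, 2.0), (3, 4.0)], names="a,b")
    shm = smm.SharedMemory(src.nbytes)
    np.copyto(np.recarray(src.shape, dtype=src.dtype, buf=shm.buf), src)

    # consumer side: reattach a view over the same buffer
    view = np.recarray(src.shape, dtype=src.dtype, buf=shm.buf)
    assert view.a[1] == 3
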
Example #6
class ParallelTransformStep(ProcessingStep[TPayload]):
    def __init__(
        self,
        function: Callable[[Message[TPayload]], TTransformed],
        next_step: ProcessingStep[TTransformed],
        processes: int,
        max_batch_size: int,
        max_batch_time: float,
        input_block_size: int,
        output_block_size: int,
        metrics: MetricsBackend,
    ) -> None:
        self.__transform_function = function
        self.__next_step = next_step
        self.__max_batch_size = max_batch_size
        self.__max_batch_time = max_batch_time

        self.__shared_memory_manager = SharedMemoryManager()
        self.__shared_memory_manager.start()

        self.__pool = Pool(
            processes,
            initializer=parallel_transform_worker_initializer,
            context=multiprocessing.get_context("spawn"),
        )

        self.__input_blocks = [
            self.__shared_memory_manager.SharedMemory(input_block_size)
            for _ in range(processes)
        ]

        self.__output_blocks = [
            self.__shared_memory_manager.SharedMemory(output_block_size)
            for _ in range(processes)
        ]

        self.__batch_builder: Optional[BatchBuilder[TPayload]] = None

        self.__results: Deque[
            Tuple[MessageBatch[TPayload],
                  AsyncResult[Tuple[int, MessageBatch[TTransformed]]]]
        ] = deque()

        self.__metrics = metrics
        self.__batches_in_progress = Gauge(metrics, "batches_in_progress")
        self.__pool_waiting_time: Optional[float] = None

        self.__closed = False

        def handle_sigchld(signum: int, frame: Any) -> None:
            # Terminates the consumer if any child process of the
            # consumer is terminated.
            # This is meant to detect the unexpected termination of
            # multiprocessor pool workers.
            if not self.__closed:
                self.__metrics.increment("sigchld.detected")
                raise ChildProcessTerminated()

        signal.signal(signal.SIGCHLD, handle_sigchld)

    def __submit_batch(self) -> None:
        assert self.__batch_builder is not None
        batch = self.__batch_builder.build()
        logger.debug("Submitting %r to %r...", batch, self.__pool)
        self.__results.append((
            batch,
            self.__pool.apply_async(
                parallel_transform_worker_apply,
                (self.__transform_function, batch, self.__output_blocks.pop()),
            ),
        ))
        self.__batches_in_progress.increment()
        self.__metrics.timing("batch.size.msg", len(batch))
        self.__metrics.timing("batch.size.bytes", batch.get_content_size())
        self.__batch_builder = None

    def __check_for_results(self, timeout: Optional[float] = None) -> None:
        input_batch, result = self.__results[0]

        # If this call is being made in a context where it is intended to be
        # nonblocking, checking if the result is ready (rather than trying to
        # retrieve the result itself) avoids costly synchronization.
        if timeout == 0 and not result.ready():
            # ``multiprocessing.TimeoutError`` (rather than builtin
            # ``TimeoutError``) maintains consistency with ``AsyncResult.get``.
            raise multiprocessing.TimeoutError()

        i, output_batch = result.get(timeout=timeout)

        # TODO: This does not handle rejections from the next step!
        for message in output_batch:
            self.__next_step.poll()
            self.__next_step.submit(message)

        if i != len(input_batch):
            logger.warning(
                "Received incomplete batch (%0.2f%% complete), resubmitting...",
                i / len(input_batch) * 100,
            )
            # TODO: This reserializes all the ``SerializedMessage`` data prior
            # to the processed index even though the values at those indices
            # will never be unpacked. It probably makes sense to remove that
            # data from the batch to avoid unnecessary serialization overhead.
            self.__results[0] = (
                input_batch,
                self.__pool.apply_async(
                    parallel_transform_worker_apply,
                    (
                        self.__transform_function,
                        input_batch,
                        output_batch.block,
                        i,
                    ),
                ),
            )
            return

        logger.debug("Completed %r, reclaiming blocks...", input_batch)
        self.__input_blocks.append(input_batch.block)
        self.__output_blocks.append(output_batch.block)
        self.__batches_in_progress.decrement()

        del self.__results[0]

    def poll(self) -> None:
        self.__next_step.poll()

        while self.__results:
            try:
                self.__check_for_results(timeout=0)
            except multiprocessing.TimeoutError:
                if self.__pool_waiting_time is None:
                    self.__pool_waiting_time = time.time()
                else:
                    current_time = time.time()
                    if current_time - self.__pool_waiting_time > LOG_THRESHOLD_TIME:
                        logger.warning(
                            "Waited on the process pool longer than %d seconds. Waiting for %d results. Pool: %r",
                            LOG_THRESHOLD_TIME,
                            len(self.__results),
                            self.__pool,
                        )
                        self.__pool_waiting_time = current_time
                break
            else:
                self.__pool_waiting_time = None

        if self.__batch_builder is not None and self.__batch_builder.ready():
            self.__submit_batch()

    def __reset_batch_builder(self) -> None:
        try:
            input_block = self.__input_blocks.pop()
        except IndexError as e:
            raise MessageRejected("no available input blocks") from e

        self.__batch_builder = BatchBuilder(
            MessageBatch(input_block),
            self.__max_batch_size,
            self.__max_batch_time,
        )

    def submit(self, message: Message[TPayload]) -> None:
        assert not self.__closed

        if self.__batch_builder is None:
            self.__reset_batch_builder()
            assert self.__batch_builder is not None

        try:
            self.__batch_builder.append(message)
        except ValueTooLarge as e:
            logger.debug("Caught %r, closing batch and retrying...", e)
            self.__submit_batch()

            # This may raise ``MessageRejected`` (if all of the shared memory
            # is in use) and create backpressure.
            self.__reset_batch_builder()
            assert self.__batch_builder is not None

            # If this raises ``ValueTooLarge``, that means the input block
            # size is too small (smaller than the Kafka payload limit without
            # compression).
            self.__batch_builder.append(message)

    def close(self) -> None:
        self.__closed = True

        if self.__batch_builder is not None and len(self.__batch_builder) > 0:
            self.__submit_batch()

    def terminate(self) -> None:
        self.__closed = True

        logger.debug("Terminating %r...", self.__pool)
        self.__pool.terminate()

        logger.debug("Shutting down %r...", self.__shared_memory_manager)
        self.__shared_memory_manager.shutdown()

        logger.debug("Terminating %r...", self.__next_step)
        self.__next_step.terminate()

    def join(self, timeout: Optional[float] = None) -> None:
        deadline = time.time() + timeout if timeout is not None else None

        logger.debug("Waiting for %s batches...", len(self.__results))
        while self.__results:
            self.__check_for_results(
                timeout=max(deadline -
                            time.time(), 0) if deadline is not None else None)

        self.__pool.close()

        logger.debug("Waiting for %s...", self.__pool)
        # ``Pool.join`` doesn't accept a timeout (?!) but this really shouldn't
        # block for any significant amount of time unless something really went
        # wrong (i.e. we lost track of a task)
        self.__pool.join()

        self.__shared_memory_manager.shutdown()

        self.__next_step.close()
        self.__next_step.join(
            timeout=max(deadline -
                        time.time(), 0) if deadline is not None else None)
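
A compact sketch of the nonblocking readiness check used in __check_for_results above: AsyncResult.ready() is a cheap peek, so the zero-timeout path never has to block on pool synchronization:

import multiprocessing

def square(x: int) -> int:
    return x * x

if __name__ == "__main__":
    with multiprocessing.Pool(2) as pool:
        result = pool.apply_async(square, (7,))
        while not result.ready():       # cheap peek; the zero-timeout path above
            pass                        # raises multiprocessing.TimeoutError instead
        print(result.get(timeout=0))    # 49; get() no longer blocks
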
Example #7
class DataManager(Singleton):
    """DataManager loads data from sources and manages the shared memory"""
    def __init__(self, obs=-1, config=None):
        self._datasets = dict()
        self.smm = SharedMemoryManager()
        self.smm.start()
        self.conns = dict()
        self.obs = obs
        self.config = config

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()

    def shutdown(self):
        """Shutdown the shared memory manager and all data connections"""
        self.smm.shutdown()
        for conn in self.conns.values():
            conn.close()

    def _add_to_shared_memory(self, nparray: np.recarray) -> SharedMemory:
        """Internal function to copy an array into shared memory.

        Parameters
        ----------
        nparray : np.recarray
            The array to be copied into shared memory.

        Returns
        -------
        SharedMemory
            The shared memory object.
        """
        shm = self.smm.SharedMemory(nparray.nbytes)
        array = np.recarray(nparray.shape, dtype=nparray.dtype, buf=shm.buf)
        np.copyto(array, nparray)
        return shm

    def _download_dataset(self, dataset: Dataset) -> pd.DataFrame:
        """Internal function to download the dataset.

        Parameters
        ----------
        dataset : Dataset
            Dataset information including source, library, table, vars, etc.

        Returns
        -------
        pd.DataFrame
            DataFrame of the downloaded dataset.
        """
        # TODO: generic login data for different data sources
        if dataset.source == "wrds":
            usr = self.config.get("wrds_username")
            pwd = self.config.get("wrds_password")
        # If there exists a connection for the data source, use it!
        if (conn := self.conns.get(dataset.source, None)) is None:
            module = import_module(f"frds.data.{dataset.source}")
            conn = module.Connection(usr=usr, pwd=pwd)
            self.conns.update({dataset.source: conn})
        return conn.get_table(
            library=dataset.library,
            table=dataset.table,
            columns=dataset.vars,
            date_cols=dataset.date_vars,
            obs=self.obs,
        )
Example #8
class PathFinder:
    def __init__(self, maze: Maze = None, memory_manager: SharedMemoryManager = None):
        self._maze = maze
        self.display = None
        self.memory_manager = memory_manager
        self.pos = [0, 0]
        self._data = {}

        if self.memory_manager is None:
            self.memory_manager = SharedMemoryManager()
            self.memory_manager.start()

    def assign_maze(self, maze: Maze):
        """Assign a new maze AND clear data"""
        self._maze = maze
        # self._data = {}

    def start_pygame_display(self):
        self.display = Process(
            target=pygame_display,
            args=[
                self._data["maze"].dtype,
                self._data["maze_mem"],
                self._data["maze"].shape,
                self._data["path"].dtype,
                self._data["path_mem"],
                self._data["path"].shape,
                self._data["display_state_mem"],
            ],
            daemon=True,
        )
        self.display.start()

    def solve(self, *args, **kwargs):
        start_time = time.time()
        time.sleep(0.000001)  # guard against division by zero when computing steps/second

        def _setup_solve_data(**kwargs):

            self._data["node_depth"] = 0
            self._data["total_nodes"] = sum([len(x) for x in self._data["nodes"]])
            self._data["node_factorial"] = math.factorial(len(self._data["nodes"]))
            self._data["step_count"] = 0

            if not self._data.get("path_mem", False):
                print("[info] setting up shared memory")
                path_len = len(self._maze.tiles.flat)

                sys.setrecursionlimit(len(self._maze.tiles.flat) * 2)

                self._data["path_mem"] = self.memory_manager.SharedMemory(
                    self._maze.tiles.nbytes * 2
                )
                self._data["path"] = np.ndarray(
                    (path_len, 2), dtype=int, buffer=self._data["path_mem"].buf
                )

                self._data["maze_mem"] = self.memory_manager.SharedMemory(
                    self._maze.tiles.nbytes
                )

                self._data["maze"] = np.ndarray(
                    self._maze.tiles.shape,
                    dtype=self._maze.tiles.dtype,
                    buffer=self._data["maze_mem"].buf,
                )

                # state = running, step_count, steps_per_sec (three uint64 values)
                self._data["display_state_mem"] = self.memory_manager.SharedMemory(
                    (3 * 64)  # generously over-allocated; 3 * 8 bytes would hold three uint64s
                )
                self._data["display_state"] = np.ndarray(
                    (3,), dtype="uint64", buffer=self._data["display_state_mem"].buf
                )
                self._data["display_state"][0] = True
                self._data["display_state"][1:] = 0

            self._data["maze"][:] = self._maze.tiles[:]
            self._data["path"].fill(0)

            self._config = {
                "progress_style": kwargs.get("progress_style", "bar"),
                "interval": kwargs.get("interval", 1),
                "interval_type": kwargs.get("interval_type", "time"),
                "patience": kwargs.get("patience", 0),
            }

            if not self.display and self._config["progress_style"] == "pygame":
                self.start_pygame_display()

            elif self.display and not self.display.is_alive():
                del self.display
                self.start_pygame_display()
                time.sleep(2)

        def _validate_maze(self) -> bool:
            if not isinstance(self._maze, Maze):
                print(f"[error] {self._maze} is not a {Maze}")
                return False
            return True

        def _check_unwalked(x, y, path_number) -> bool:
            return (x, y) not in self._data["path"][path_number]

        def _update_progress(self):
            def _show_progress(self):
                def _node_count():
                    sys.stdout.write(
                        f'\r{self._data["node_depth"]} of {self._data["total_nodes"]} nodes'
                    )
                    sys.stdout.flush()

                def _bar():
                    printProgressBar(
                        len(path),
                        self._data["total_nodes"],
                        prefix="Finding path...",
                        suffix="node depth",
                    )

                def _path():
                    current_time = time.time()
                    self.show_path(path)
                    print(
                        f"Steps: {self._data['step_count']:,} ",
                        f"Time: {(time.time()-start_time):.0f} sec",
                        f"\nSteps/second: {int(self._data['step_count']/(current_time-start_time)):,}",
                    )

                def _pygame():
                    if not self.display:
                        self.start_pygame_display()

                    else:
                        self._data["display_state"][1] = self._data["step_count"]
                        self._data["display_state"][2] = int(
                            self._data["step_count"] / (time.time() - start_time)
                        )
                        formatted_time = time.strftime("%H:%M:%S", time.gmtime())
                        print(f"solving... {formatted_time}", end="\r")

                # call function based on progress style
                {
                    "node_count": _node_count,
                    "bar": _bar,
                    "path": _path,
                    "pygame": _pygame,
                }.get(self._config["progress_style"])()

            if self._config["interval_type"] == "time":
                if (
                    time.time() - self._data.get("progress_timer", start_time)
                    > self._config["interval"]
                ):
                    _show_progress(self)
                    self._data["progress_timer"] = time.time()

            elif self._config["interval_type"] == "step_count":
                if self._data["step_count"] % self._config["interval"] == 0:
                    _show_progress(self)

        def _traverse(self, coords: tuple, path: set) -> bool:
            """Recursively explores the maze, returns True if the end is found,
            returns False when the end cannot be found"""
            # self._data["node_depth"] += 1
            self._data["step_count"] += 1
            path.add(coords)

            _update_progress(self)
            if (
                self._config["patience"]
                and time.time() - start_time > self._config["patience"]
            ):
                raise TimeoutError

            if coords == self._maze.end:
                return True

            connections = [c for c in self._data["nodes"][coords] if c not in path]

            # recursively traverse each connection
            for new_coords in connections:
                # set next point in path
                self._data["node_depth"] += 1
                self._data["path"][self._data["node_depth"]] = new_coords

                if _traverse(self, new_coords, path):
                    return True

                else:
                    # remove point from path
                    self._data["path"][self._data["node_depth"]] = (0, 0)
                    self._data["node_depth"] -= 1
                    path.remove(new_coords)

            # Failed to find path at this point
            return False

        # start pathing the maze
        _setup_solve_data(**kwargs)
        path = set()

        self._data["path"][0] = self._maze.start  # first point in path
        try:
            if _traverse(self, self._maze.start, path):
                print("[success] Path found!")

            else:
                print("[fail] No valid path could be found.")

        except TimeoutError:
            print("[timeout] took too long to solve")

        end_time = time.time()

        self._data["display_state"][1] = self._data["step_count"]
        self._data["display_state"][2] = int(
            self._data["step_count"] / (time.time() - start_time)
        )
        self._data["solve_time"] = end_time - start_time

    def create_nodes(self):
        """ Map the maze into a list of nodes"""

        self._data["nodes"] = {}
        i = 0
        map_size = self._maze.x * self._maze.y
        for x in range(self._maze.x):
            for y in range(self._maze.y):
                i += 1
                if self._is_node((x, y)):
                    self._data["nodes"][(x, y)] = []
                printProgressBar(
                    i, map_size, "Creating nodes...  ", f"{i}/{map_size}", length=50
                )

        number_of_nodes = len(self._data["nodes"])
        for i, node in enumerate(self._data["nodes"].keys()):
            self._connect_nodes(node)
            printProgressBar(
                i,
                number_of_nodes - 1,
                "Connecting nodes...",
                f"{i+1}/{number_of_nodes}",
                length=50,
            )

    def _connect_nodes(self, node: tuple):
        """connect nodes together"""
        x, y = node

        def _down():
            dy = y
            while dy > 0 and self._maze.tiles[x, dy] == 0:
                dy -= 1
                # if self._is_node((x, dy)):
                if (x, dy) in self._data["nodes"]:
                    self._data["nodes"][(x, y)].append((x, dy))
                    break

        def _left():
            dx = x
            while dx > 0 and self._maze.tiles[dx, y] == 0:
                dx -= 1
                # if self._is_node((dx, y)):
                if (dx, y) in self._data["nodes"]:
                    self._data["nodes"][(x, y)].append((dx, y))
                    break

        def _right():
            dx = x
            while dx < self._maze.x - 1 and self._maze.tiles[dx, y] == 0:
                dx += 1
                # if self._is_node((dx, y)):
                if (dx, y) in self._data["nodes"]:
                    self._data["nodes"][(x, y)].append((dx, y))
                    break

        def _up():
            dy = y
            while dy < self._maze.y - 1 and self._maze.tiles[x, dy] == 0:
                dy += 1
                # if self._is_node((x, dy)):
                if (x, dy) in self._data["nodes"]:
                    self._data["nodes"][(x, y)].append((x, dy))
                    break

        # order determines order of pathing
        _down()
        _left()
        _right()
        _up()

    def _is_node(self, coords: tuple) -> bool:
        # The start and end of a maze are nodes automatically
        if coords == self._maze.start or coords == self._maze.end:
            return True

        x, y = coords

        # If all 8 surrounding tiles are clear then ignore
        if not self._maze.tiles[x - 1 : x + 2, y - 1 : y + 2].any():
            return False

        # Check if the cardinal directions are open
        left, right, up, down = False, False, False, False
        if x > 0 and self._maze.tiles[x - 1, y] == 0:
            left = True
        if x < self._maze.x - 1 and self._maze.tiles[x + 1, y] == 0:
            right = True
        if y > 0 and self._maze.tiles[x, y - 1] == 0:
            down = True
        if y < self._maze.y - 1 and self._maze.tiles[x, y + 1] == 0:
            up = True

        # straight paths and deadends are not nodes, but corners are!
        for xbool in [left, right]:
            for ybool in [up, down]:
                if xbool and ybool:
                    return True

        return False

    def display_nodes(self):
        if not "nodes" in self._data:
            print("[error] cannont display nodes, data missing")
            return

        for x in range(self._maze.x):
            line = [" "]
            for y in range(self._maze.y):
                if self._maze.tiles[x, y] == 1:
                    line.append("#")
                elif (x, y) in self._data["nodes"].keys():
                    line.append("+")
                elif self._maze.tiles[x, y] == 0:
                    line.append(" ")
                else:
                    line.append("e")
            print(" ".join(line))

    def show_path(self, path: list = None):
        """Print the maze to the terminal with the given path drawn on it"""
        if not isinstance(path, list):
            path = [tuple(p) for p in self._data["path"][:] if tuple(p) != (0, 0)]

        # tile type to character lookup dict
        characters = {0: " ", 1: "#", 2: "2", 3: "."}

        output = []

        def _draw_path_between_points(a, b, maze):
            # if a is None then b should be the start
            if a is None:
                maze[b] = 3

            else:
                dx = 0
                if b[0] > a[0]:
                    dx = 1
                elif b[0] < a[0]:
                    dx = -1

                dy = 0
                if b[1] > a[1]:
                    dy = 1
                elif b[1] < a[1]:
                    dy = -1

                # paint path onto maze
                x, y = a
                while (x, y) != b:
                    x += dx
                    y += dy
                    maze[x, y] = 3

        if path:
            maze = copy.deepcopy(self._maze.tiles)
            point_a = None
            for point_b in path:
                _draw_path_between_points(point_a, point_b, maze)
                point_a = point_b

            for x in range(maze.shape[0]):
                line = [" "]
                for y in range(maze.shape[1]):
                    if (x, y) == path[-1]:
                        line.append("X")
                    else:
                        line.append(characters[maze[x, y]])
                output.append(" ".join(line))

        if output:
            print("\n".join(output))

        else:
            print("[error] No path to display")