def __init__(self, manager: SharedMemoryManager, cache_size, unit, shape) -> None:
    """Allocate a shared-memory cache block and its cross-process state.

    :param manager: a started ``SharedMemoryManager`` that owns (and on
        shutdown unlinks) the cache block.
    :param cache_size: number of cache slots.
    :param unit: size of one slot in bytes.
    :param shape: stored as-is on the instance; not used here.
    """
    self.unit = unit
    self.shape = shape
    self.cache_size = cache_size
    self.total_bytes = self.unit * self.cache_size
    self.cache_block = manager.SharedMemory(size=self.total_bytes)
    # Use a single Manager server process for all three proxies instead of
    # one process per proxy. Every proxy holds a reference to the manager,
    # so the server stays alive as long as any proxy does. The manager is
    # deliberately not stored on ``self`` so instances remain picklable.
    sync_manager = Manager()
    self.lock = sync_manager.Lock()
    # st_id starts at cache_size, et_id at -1.
    self.st_id = sync_manager.Value('i', cache_size)
    self.et_id = sync_manager.Value('i', -1)
def multiprocess(self, n, nthread=8):
    """Run ``self.task`` for every ``i in range(n)`` on a process pool.

    Each task receives a ``ShareableList`` of ``n`` zeros, a ``SharedMemory``
    block of ``n`` bytes and its index ``i``. Along the way the ``id()`` of
    several parent-process objects is printed so it can be compared with
    what the child processes observe.

    Requires Python 3.8+ (``SharedMemoryManager``).

    :param n: number of tasks; also the length of the shared list and the
        size in bytes of the shared block.
    :param nthread: number of worker processes in the pool.
    :return: ``(cache_list, shm_a)``. The owning ``SharedMemoryManager`` has
        already been shut down when this returns, so the segments are unlinked.
    """
    print('Parent process %s.' % os.getpid())

    def report_task_error(exc):
        # The AsyncResult objects are never inspected, so without this
        # callback an exception raised inside a task would vanish silently.
        print('Task failed: {!r}'.format(exc))

    # Leaving the ``with`` shuts the manager down and terminates the pool
    # even if something below raises.
    with Pool(nthread) as p, SharedMemoryManager() as smm:
        # ShareableList only stores int, float, bool, str (< 10 MB each),
        # bytes (< 10 MB each) and None. Its length is fixed (no append or
        # insert), and slicing does not create new ShareableList instances.
        cache_list = smm.ShareableList([0] * n)
        shm_a = smm.SharedMemory(size=n)
        print('shm_a id in main process: {} '.format(id(shm_a)))

        # Objects in the parent's address space, for comparison with the
        # ids seen inside the child processes.
        self.global_array = [0] * n
        print('array id in main process: {} '.format(id(self.global_array)))
        self.global_string = 'abc'
        print('string id in main process: {} '.format(id(self.global_string)))
        self.global_int = 10
        print('int id in main process: {} '.format(id(self.global_int)))

        for i in range(n):
            # apply_async: results are collected asynchronously
            p.apply_async(
                self.task,
                args=(cache_list, shm_a, i),
                error_callback=report_task_error,
            )
        print('Waiting for all subprocesses done...')
        p.close()
        p.join()
        print('All subprocesses done.')
    return cache_list, shm_a
class SMBase:
    """Base for objects backed by a header ``ShareableList`` plus a data block.

    The header list stores the data ``SharedMemory`` block's name in slot 0,
    the logical size in slot 1 and the capacity in slot 2. Another process
    attaches to the same object by passing the header's name as ``addr``.
    """

    def __init__(self, addr=None, manager=None, mutex=None, format_list=None, size=0, ratio=2):
        """Create a new shared object, or attach to an existing one.

        :param addr: name of an existing header ``ShareableList``. When
            given, attach to it instead of allocating anything.
        :param manager: ``SharedMemoryManager`` or ``DummyManager`` used for
            new allocations. Any other non-None value is passed to
            ``SharedMemoryManager(...)``. Defaults to a new ``DummyManager``.
        :param mutex: lock guarding header access (default: a new ``RLock``).
        :param format_list: initial contents of the header list; required
            when ``addr`` is None. Slots 0-2 are overwritten. NOTE(review):
            ShareableList slots have a fixed size, so slot 0 must be a str
            placeholder long enough to hold a segment name.
        :param size: initial logical size stored in header slot 1.
        :param ratio: capacity multiplier: capacity = int(size * ratio).
        :raises ValueError: if neither ``addr`` nor ``format_list`` is given.
        """
        self.mutex = RLock() if mutex is None else mutex
        if manager is None:
            self._manager = DummyManager()
        elif isinstance(manager, (SharedMemoryManager, DummyManager)):
            self._manager = manager
        else:
            self._manager = SharedMemoryManager(manager)
        capacity = int(size * ratio)
        if capacity == 0:
            # SharedMemory needs a positive int size; a fractional ratio
            # (e.g. 1.5) must not leak through as a float, nor may it be 0.
            capacity = max(int(ratio), 1)
        with self.mutex:
            if addr is None:
                if format_list is None:
                    raise ValueError("Either addr or format_list must be provided")
                self._shl = self._manager.ShareableList(format_list)
                self._shl_addr = self._shl.shm.name
                self._shm = self._manager.SharedMemory(capacity)
                self._shm_addr = self._shm.name
                self._shl[0] = self._shm_addr
                self._shl[1] = int(size)
                self._shl[2] = int(capacity)
            else:
                self._shl_addr = addr
                self._shl = shared_memory.ShareableList(name=addr)
                self._shm_addr = self._shl[0]
                self._shm = shared_memory.SharedMemory(name=self._shm_addr)

    @locked
    def size(self):
        """Logical size stored in header slot 1."""
        return self._shl[1]

    @locked
    def capacity(self):
        """Capacity stored in header slot 2."""
        return self._shl[2]

    @locked
    def name(self):
        """Name of the header ``ShareableList`` (pass as ``addr`` to attach)."""
        return self._shl.shm.name

    def check_memory(self) -> bool:
        """Re-attach to the data block if the header names a different one.

        :return: True if the header's slot 0 differed from the block this
            instance was attached to (the instance has then re-attached).
        """
        updated = False
        with self.mutex:
            current_addr = self._shl[0]
            if self._shm_addr != current_addr:
                updated = True
                try:
                    self._shm.close()
                except Exception:
                    # best effort: failing to close the stale mapping must
                    # not prevent attaching to the new one
                    traceback.print_exc()
                self._shm = shared_memory.SharedMemory(name=current_addr)
                self._shm_addr = current_addr
        return updated

    def _recap(self, cap: int):
        """Grow the data block to ``cap`` bytes, preserving its contents.

        Does nothing unless ``cap`` exceeds the current block size. The old
        block is closed and unlinked, and header slots 0 (name) and 2
        (capacity) are updated. This method takes no lock itself;
        presumably callers hold ``self.mutex`` (verify).
        """
        if cap > self._shm.size:
            new_shm = self._manager.SharedMemory(size=int(cap))
            new_shm.buf[:self._shm.size] = self._shm.buf[:]
            try:
                self._shm.close()
                self._shm.unlink()
            except Exception:
                traceback.print_exc()
            self._shm = new_shm
            self._shm_addr = new_shm.name
            self._shl[0] = new_shm.name
            self._shl[2] = int(cap)

    def __del__(self):
        """
        Close this process's mappings.

        Note: this does not unlink data. That is expected to be handled by
        the manager. ``getattr`` guards against instances whose ``__init__``
        raised before the attributes were set.
        """
        shm = getattr(self, "_shm", None)
        if shm is not None:
            shm.close()
        shl = getattr(self, "_shl", None)
        if shl is not None:
            shl.shm.close()
class ParallelSimulation(Simulation):
    """
    Parallel simulation of Barnes-Hut algorithm realised using shared memory.
    """

    def __init__(self, positions: np.ndarray, velocities: np.ndarray, masses: np.ndarray, params: Namespace):
        super().__init__(positions, velocities, masses, params)
        self._theta = params.theta
        self._init_memory(positions, velocities, masses)
        self._init_workers()
        # release the worker pool and shared memory at interpreter exit
        atexit.register(self._cleanup)

    def _allocate_shared(self, shape, dtype):
        """
        Allocates a shared memory block able to hold an array of ``shape`` and
        ``dtype``.

        The size comes from the target dtype, not from the caller's input
        arrays: sizing from e.g. a float32 input would produce a buffer too
        small for the float64 view placed on it. At least one byte is
        requested because ``SharedMemory`` rejects a size of zero.
        """
        nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
        return self._memory_manager.SharedMemory(max(nbytes, 1))

    def _init_memory(self, positions: np.ndarray, velocities: np.ndarray, masses: np.ndarray):
        """
        Prepares shared memory arrays.
        """
        # setup process that owns the shared memory blocks
        self._memory_manager = SharedMemoryManager()
        self._memory_manager.start()

        # NOTE(review): fixed upper bound on octree nodes; presumably the
        # worker's tree builder must stay within it -- verify for tightly
        # clustered bodies, which produce deep trees.
        max_nodes = self.bodies + 64

        body_vectors = (self.bodies, 3)
        body_scalars = (self.bodies, )
        node_vectors = (max_nodes, 3)
        node_scalars = (max_nodes, )
        node_children = (max_nodes, 8)

        # ``float``/``int`` are exactly what the removed ``np.float``/``np.int``
        # aliases referred to, so the array dtypes are unchanged.
        self._positions_shm = self._allocate_shared(body_vectors, float)
        self._velocities_shm = self._allocate_shared(body_vectors, float)
        self._accelerations_shm = self._allocate_shared(body_vectors, float)
        self._masses_shm = self._allocate_shared(body_scalars, float)
        self._nodes_positions_shm = self._allocate_shared(node_vectors, float)
        self._nodes_masses_shm = self._allocate_shared(node_scalars, float)
        self._nodes_sizes_shm = self._allocate_shared(node_scalars, float)
        self._nodes_children_types_shm = self._allocate_shared(node_children, int)
        self._nodes_children_ids_shm = self._allocate_shared(node_children, int)

        # setup NumPy arrays viewing the shared buffers
        self._data = SharedData(
            time_step=self.time_step,
            theta=self._theta,
            gravitational_constant=self.gravitational_constant,
            softening=self.softening,
            nodes_count=Value('i', 0),
            positions=np.ndarray(body_vectors, dtype=float, buffer=self._positions_shm.buf),
            velocities=np.ndarray(body_vectors, dtype=float, buffer=self._velocities_shm.buf),
            accelerations=np.ndarray(body_vectors, dtype=float, buffer=self._accelerations_shm.buf),
            masses=np.ndarray(body_scalars, dtype=float, buffer=self._masses_shm.buf),
            nodes_positions=np.ndarray(node_vectors, dtype=float, buffer=self._nodes_positions_shm.buf),
            nodes_masses=np.ndarray(node_scalars, dtype=float, buffer=self._nodes_masses_shm.buf),
            nodes_sizes=np.ndarray(node_scalars, dtype=float, buffer=self._nodes_sizes_shm.buf),
            nodes_children_types=np.ndarray(node_children, dtype=int, buffer=self._nodes_children_types_shm.buf),
            nodes_children_ids=np.ndarray(node_children, dtype=int, buffer=self._nodes_children_ids_shm.buf),
        )

        # copy data into shared arrays
        self._data.positions[:] = positions[:]
        self._data.velocities[:] = velocities[:]
        self._data.masses[:] = masses[:]

    def _init_workers(self):
        """
        Prepares pool of workers.
        """
        # NOTE(review): the shared ndarrays travel to the workers as
        # ``initargs``. With the "fork" start method they keep viewing the
        # shared mapping; with "spawn"/"forkserver" they would be pickled,
        # i.e. copied -- confirm worker.initialize handles that.
        self._pool = Pool(
            processes=self._params.processes,
            initializer=worker.initialize,
            initargs=(self._data, ),
        )

    def _cleanup(self):
        """
        Cleans up shared memory and pool of workers.
        """
        self._pool.terminate()
        self._memory_manager.shutdown()
        print('Memory manager was shut down.')

    def simulate(self) -> Iterable[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """
        Runs parallel implementation of Barnes-Hut simulation.

        Yields the shared positions, velocities and accelerations arrays after
        every step; the same array objects are overwritten by the next step.
        """
        while True:
            self._build_octree()
            self._update_accelerations()
            self._update_positions()
            yield self._data.positions, self._data.velocities, self._data.accelerations

    def _build_octree(self):
        """
        Builds octree used in Barnes-Hut.

        The root and its eight children are built here; each child octant
        holding two or more bodies is split into eight sub-octant tasks that
        are built in the worker pool.
        """
        # a cube spanning the smallest and largest coordinate over all axes
        global_coords_min = np.repeat(np.min(self._data.positions), 3)
        global_coords_max = np.repeat(np.max(self._data.positions), 3)
        global_coords_mid = (global_coords_min + global_coords_max) / 2

        # manually build first node
        self._data.nodes_count.value = 1
        self._data.nodes_positions[0] = np.average(self._data.positions, axis=0, weights=self._data.masses)
        self._data.nodes_masses[0] = np.sum(self._data.masses)
        self._data.nodes_sizes[0] = global_coords_max[0] - global_coords_min[0]

        # calculate base octant for each body (bit 0: x, bit 1: y, bit 2: z above mid)
        bodies_base_octant = np.sum((self._data.positions > global_coords_mid) * [1, 2, 4], axis=1)

        tasks_targets = []
        tasks_args = []

        # build second layer of nodes and collect tasks
        for octant in range(8):
            coords_min, coords_max = octant_coords(global_coords_min, global_coords_max, octant)
            coords_mid = (coords_min + coords_max) / 2

            # get indices of bodies in this octant
            octant_bodies = np.argwhere(bodies_base_octant == octant).flatten()

            # if node is empty or has one body handle it separately
            if octant_bodies.size == 0:
                self._data.nodes_children_types[0, octant] = OCTANT_EMPTY
                continue
            if octant_bodies.size == 1:
                self._data.nodes_children_types[0, octant] = OCTANT_BODY
                self._data.nodes_children_ids[0, octant] = octant_bodies[0]
                continue

            # create node
            node_id = self._data.nodes_count.value
            self._data.nodes_count.value = node_id + 1
            self._data.nodes_children_types[0, octant] = OCTANT_NODE
            self._data.nodes_children_ids[0, octant] = node_id
            self._data.nodes_positions[node_id] = np.average(
                self._data.positions[octant_bodies], axis=0, weights=self._data.masses[octant_bodies]
            )
            self._data.nodes_masses[node_id] = np.sum(self._data.masses[octant_bodies])
            self._data.nodes_sizes[node_id] = coords_max[0] - coords_min[0]

            # split bodies into sub octants
            bodies_sub_octant = np.sum((self._data.positions[octant_bodies] > coords_mid) * [1, 2, 4], axis=1)

            # create tasks
            for i in range(8):
                tasks_targets.append((node_id, i))
                tasks_args.append((
                    octant_bodies[bodies_sub_octant == i],
                    *octant_coords(coords_min, coords_max, i)
                ))

        # run tasks
        results = self._pool.starmap(worker.build_octree_branch, tasks_args)

        # update references in nodes
        for (node_id, i), (sub_node_type, sub_node_id) in zip(tasks_targets, results):
            self._data.nodes_children_types[node_id, i] = sub_node_type
            self._data.nodes_children_ids[node_id, i] = sub_node_id

    def _update_accelerations(self):
        """
        Calculates accelerations of the bodies.
        """
        if self.bodies < 2:
            return
        self._pool.map(worker.update_acceleration, range(self.bodies))

    def _update_positions(self):
        """
        Calculates positions of the bodies.
        """
        self._pool.map(worker.update_position, range(self.bodies))

    @property
    def positions(self) -> np.ndarray:
        return self._data.positions

    @property
    def velocities(self) -> np.ndarray:
        return self._data.velocities

    @property
    def masses(self) -> np.ndarray:
        return self._data.masses

    @property
    def accelerations(self) -> np.ndarray:
        return self._data.accelerations
class DataManager:
    """Download datasets and keep copies of them in shared memory.

    Usable as a context manager; leaving the ``with`` block calls ``shutdown``.
    """

    def __init__(self, obs=1000, config=None):
        """
        Parameters
        ----------
        obs : int, optional
            Passed as ``obs`` to each connection's ``get_table`` (presumably a
            row limit -- confirm against the connection classes), by default 1000.
        config : object with a ``get`` method (e.g. dict), optional
            Credentials; ``wrds_username``/``wrds_password`` are read for the
            ``wrds`` source.
        """
        self._datasets = dict()
        self.smm = SharedMemoryManager()
        self.smm.start()
        self.result = Manager().list()
        self.conns = dict()
        self.obs = obs
        self.config = config

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()

    def shutdown(self):
        """Shut down the shared memory manager and close all data connections.

        Connections are closed even if shutting down the manager raises.
        """
        try:
            self.smm.shutdown()
        finally:
            for conn in self.conns.values():
                conn.close()

    def _add_to_shared_memory(self, nparray: recarray) -> SharedMemory:
        """Internal function to copy an array into shared memory.

        Parameters
        ----------
        nparray : recarray
            The array to be copied into shared memory.

        Returns
        -------
        SharedMemory
            The shared memory object (owned by ``self.smm``).
        """
        shm = self.smm.SharedMemory(nparray.nbytes)
        array = recarray(shape=nparray.shape, dtype=nparray.dtype, buf=shm.buf)
        copyto(array, nparray)
        return shm

    def _download_dataset(self, dataset: Dataset) -> recarray:
        """Internal function to download the dataset.

        Parameters
        ----------
        dataset : Dataset
            Dataset information including source, library, table, vars, etc.

        Returns
        -------
        recarray
            `numpy.recarray` of the downloaded dataset.

        Raises
        ------
        TypeError
            If the connection does not return a ``DataFrame``.
        """
        # Credentials default to None so sources other than 'wrds' do not hit
        # an unbound local when a new connection is created.
        usr = pwd = None
        # TODO: generic login data for different data sources
        if dataset.source == 'wrds':
            config = self.config or {}
            usr = config.get('wrds_username')
            pwd = config.get('wrds_password')
        # If there exists a connection for the data source, use it!
        if (conn := self.conns.get(dataset.source, None)) is None:
            module = import_module(f'frds.data.{dataset.source}')
            conn = module.Connection(usr=usr, pwd=pwd)
            self.conns.update({dataset.source: conn})
        df = conn.get_table(library=dataset.library,
                            table=dataset.table,
                            columns=dataset.vars,
                            date_cols=dataset.date_vars,
                            obs=self.obs)
        # explicit check: an ``assert`` would be stripped under ``python -O``
        if not isinstance(df, DataFrame):
            raise TypeError(
                f'Expected a DataFrame from {dataset.source!r}, got {type(df).__name__}')
        return df.to_records(index=False)
class ParallelTransformStep(ProcessingStep[TPayload]):
    """Apply ``function`` to messages in batches on a multiprocessing pool.

    Incoming messages are appended to a ``MessageBatch`` backed by a shared
    memory input block. A finished batch is submitted to the pool together
    with a shared memory output block. Results are forwarded to
    ``next_step`` in submission order, because only the oldest pending
    result is ever inspected. Both blocks go back on their free lists once
    the batch is fully processed. When no input block is free, ``submit``
    raises ``MessageRejected`` so the caller applies backpressure.
    """

    def __init__(
        self,
        function: Callable[[Message[TPayload]], TTransformed],
        next_step: ProcessingStep[TTransformed],
        processes: int,
        max_batch_size: int,
        max_batch_time: float,
        input_block_size: int,
        output_block_size: int,
        metrics: MetricsBackend,
    ) -> None:
        """
        :param function: transform for each message; it is sent to the pool
            together with every batch.
        :param next_step: receives the transformed messages.
        :param processes: number of pool workers. One input and one output
            shared memory block is allocated per worker.
        :param max_batch_size: batch size limit handed to ``BatchBuilder``.
        :param max_batch_time: batch time limit handed to ``BatchBuilder``.
        :param input_block_size: size in bytes of each input block.
        :param output_block_size: size in bytes of each output block.
        :param metrics: backend for gauges, counters and timings.
        """
        self.__transform_function = function
        self.__next_step = next_step
        self.__max_batch_size = max_batch_size
        self.__max_batch_time = max_batch_time

        # Owns every shared memory block allocated below.
        self.__shared_memory_manager = SharedMemoryManager()
        self.__shared_memory_manager.start()

        self.__pool = Pool(
            processes,
            initializer=parallel_transform_worker_initializer,
            context=multiprocessing.get_context("spawn"),
        )

        # Free lists: a block is popped while a batch uses it and appended
        # back by ``__check_for_results`` once that batch completes.
        self.__input_blocks = [
            self.__shared_memory_manager.SharedMemory(input_block_size)
            for _ in range(processes)
        ]

        self.__output_blocks = [
            self.__shared_memory_manager.SharedMemory(output_block_size)
            for _ in range(processes)
        ]

        # Batch currently being filled by ``submit``, if any.
        self.__batch_builder: Optional[BatchBuilder[TPayload]] = None

        # Submitted batches and their pending results, oldest first.
        self.__results: Deque[
            Tuple[
                MessageBatch[TPayload],
                AsyncResult[Tuple[int, MessageBatch[TTransformed]]],
            ]
        ] = deque()

        self.__metrics = metrics
        self.__batches_in_progress = Gauge(metrics, "batches_in_progress")

        # When ``poll`` first found the oldest result not ready; used to
        # rate-limit the "waited too long" warning.
        self.__pool_waiting_time: Optional[float] = None

        self.__closed = False

        def handle_sigchld(signum: int, frame: Any) -> None:
            # Terminates the consumer if any child process of the
            # consumer is terminated.
            # This is meant to detect the unexpected termination of
            # multiprocessor pool workers.
            # Once ``close``/``terminate`` have set ``__closed``, child exits
            # are expected and ignored.
            if not self.__closed:
                self.__metrics.increment("sigchld.detected")
                raise ChildProcessTerminated()

        # NOTE: ``signal.signal`` only works from the main thread, so this
        # step must be constructed there.
        signal.signal(signal.SIGCHLD, handle_sigchld)

    def __submit_batch(self) -> None:
        """Build the current batch and hand it to the pool with a free output block."""
        assert self.__batch_builder is not None
        batch = self.__batch_builder.build()
        logger.debug("Submitting %r to %r...", batch, self.__pool)
        # Output blocks are reclaimed together with input blocks, so
        # presumably one is free whenever a builder holds an input block.
        self.__results.append((
            batch,
            self.__pool.apply_async(
                parallel_transform_worker_apply,
                (self.__transform_function, batch, self.__output_blocks.pop()),
            ),
        ))
        self.__batches_in_progress.increment()
        self.__metrics.timing("batch.size.msg", len(batch))
        self.__metrics.timing("batch.size.bytes", batch.get_content_size())
        self.__batch_builder = None

    def __check_for_results(self, timeout: Optional[float] = None) -> None:
        """Process the oldest pending result (``self.__results`` must be non-empty).

        Forwards the result's output messages to ``next_step``. If only the
        first ``i`` messages were processed, the batch is resubmitted with
        the same output block and start index ``i``. Otherwise both blocks
        are reclaimed and the entry is removed.

        :param timeout: passed to ``AsyncResult.get``; 0 means nonblocking.
        :raises multiprocessing.TimeoutError: if the result is not ready in time.
        """
        input_batch, result = self.__results[0]

        # If this call is being made in a context where it is intended to be
        # nonblocking, checking if the result is ready (rather than trying to
        # retrieve the result itself) avoids costly synchronization.
        if timeout == 0 and not result.ready():
            # ``multiprocessing.TimeoutError`` (rather than builtin
            # ``TimeoutError``) maintains consistency with ``AsyncResult.get``.
            raise multiprocessing.TimeoutError()

        # ``i`` is the number of input messages the worker got through.
        i, output_batch = result.get(timeout=timeout)

        # TODO: This does not handle rejections from the next step!
        for message in output_batch:
            self.__next_step.poll()
            self.__next_step.submit(message)

        if i != len(input_batch):
            logger.warning(
                "Received incomplete batch (%0.2f%% complete), resubmitting...",
                i / len(input_batch) * 100,
            )
            # TODO: This reserializes all the ``SerializedMessage`` data prior
            # to the processed index even though the values at those indices
            # will never be unpacked. It probably makes sense to remove that
            # data from the batch to avoid unnecessary serialization overhead.
            self.__results[0] = (
                input_batch,
                self.__pool.apply_async(
                    parallel_transform_worker_apply,
                    (
                        self.__transform_function,
                        input_batch,
                        output_batch.block,
                        i,
                    ),
                ),
            )
            return

        logger.debug("Completed %r, reclaiming blocks...", input_batch)
        self.__input_blocks.append(input_batch.block)
        self.__output_blocks.append(output_batch.block)
        self.__batches_in_progress.decrement()

        del self.__results[0]

    def poll(self) -> None:
        """Poll ``next_step``, drain ready results without blocking, then
        submit the current batch if the builder reports it ready."""
        self.__next_step.poll()

        while self.__results:
            try:
                self.__check_for_results(timeout=0)
            except multiprocessing.TimeoutError:
                # Oldest result not ready yet: remember when the wait began
                # and warn at most once per LOG_THRESHOLD_TIME.
                if self.__pool_waiting_time is None:
                    self.__pool_waiting_time = time.time()
                else:
                    current_time = time.time()
                    if current_time - self.__pool_waiting_time > LOG_THRESHOLD_TIME:
                        logger.warning(
                            "Waited on the process pool longer than %d seconds. Waiting for %d results. Pool: %r",
                            LOG_THRESHOLD_TIME,
                            len(self.__results),
                            self.__pool,
                        )
                        self.__pool_waiting_time = current_time
                break
            else:
                self.__pool_waiting_time = None

        if self.__batch_builder is not None and self.__batch_builder.ready():
            self.__submit_batch()

    def __reset_batch_builder(self) -> None:
        """Start a new batch on a free input block.

        :raises MessageRejected: if every input block is in use.
        """
        try:
            input_block = self.__input_blocks.pop()
        except IndexError as e:
            raise MessageRejected("no available input blocks") from e

        self.__batch_builder = BatchBuilder(
            MessageBatch(input_block),
            self.__max_batch_size,
            self.__max_batch_time,
        )

    def submit(self, message: Message[TPayload]) -> None:
        """Append ``message`` to the current batch.

        If it does not fit, the current batch is submitted and the message
        goes into a new one.

        :raises MessageRejected: if no input block is free (backpressure).
        """
        assert not self.__closed

        if self.__batch_builder is None:
            self.__reset_batch_builder()
            assert self.__batch_builder is not None

        try:
            self.__batch_builder.append(message)
        except ValueTooLarge as e:
            logger.debug("Caught %r, closing batch and retrying...", e)
            self.__submit_batch()

            # This may raise ``MessageRejected`` (if all of the shared memory
            # is in use) and create backpressure.
            self.__reset_batch_builder()
            assert self.__batch_builder is not None

            # If this raises ``ValueTooLarge``, that means that the input block
            # size is too small (smaller than the Kafka payload limit without
            # compression.)
            self.__batch_builder.append(message)

    def close(self) -> None:
        """Stop accepting messages and submit any partially filled batch.

        Does not wait for outstanding results; see ``join``.
        """
        self.__closed = True

        if self.__batch_builder is not None and len(self.__batch_builder) > 0:
            self.__submit_batch()

    def terminate(self) -> None:
        """Stop immediately: kill the workers, release shared memory and
        terminate the next step.

        ``__closed`` is set first so the SIGCHLD handler ignores the worker
        exits caused here.
        """
        self.__closed = True

        logger.debug("Terminating %r...", self.__pool)
        self.__pool.terminate()

        logger.debug("Shutting down %r...", self.__shared_memory_manager)
        self.__shared_memory_manager.shutdown()

        logger.debug("Terminating %r...", self.__next_step)
        self.__next_step.terminate()

    def join(self, timeout: Optional[float] = None) -> None:
        """Wait for all pending batches, then shut down the pool, the shared
        memory and the next step.

        :param timeout: overall deadline in seconds shared by all waits;
            None waits indefinitely.
        :raises multiprocessing.TimeoutError: if a pending result is not
            ready before the deadline.
        """
        deadline = time.time() + timeout if timeout is not None else None

        logger.debug("Waiting for %s batches...", len(self.__results))
        while self.__results:
            self.__check_for_results(
                timeout=max(deadline - time.time(), 0) if deadline is not None else None
            )

        self.__pool.close()

        logger.debug("Waiting for %s...", self.__pool)
        # ``Pool.join`` doesn't accept a timeout (?!) but this really shouldn't
        # block for any significant amount of time unless something really went
        # wrong (i.e. we lost track of a task)
        self.__pool.join()

        self.__shared_memory_manager.shutdown()

        self.__next_step.close()
        self.__next_step.join(
            timeout=max(deadline - time.time(), 0) if deadline is not None else None
        )
class DataManager(Singleton):
    """DataManager loads data from sources and manages the shared memory"""

    def __init__(self, obs=-1, config=None):
        """
        Parameters
        ----------
        obs : int, optional
            Passed as ``obs`` to each connection's ``get_table`` (presumably a
            row limit -- confirm against the connection classes), by default -1.
        config : object with a ``get`` method (e.g. dict), optional
            Credentials; ``wrds_username``/``wrds_password`` are read for the
            ``wrds`` source.
        """
        self._datasets = dict()
        self.smm = SharedMemoryManager()
        self.smm.start()
        self.conns = dict()
        self.obs = obs
        self.config = config

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()

    def shutdown(self):
        """Shutdown the shared memory manager and all data connections.

        Connections are closed even if shutting down the manager raises.
        """
        try:
            self.smm.shutdown()
        finally:
            for conn in self.conns.values():
                conn.close()

    def _add_to_shared_memory(self, nparray: np.recarray) -> SharedMemory:
        """Internal function to copy an array into shared memory.

        Parameters
        ----------
        nparray : np.recarray
            The array to be copied into shared memory.

        Returns
        -------
        SharedMemory
            The shared memory object (owned by ``self.smm``).
        """
        shm = self.smm.SharedMemory(nparray.nbytes)
        array = np.recarray(nparray.shape, dtype=nparray.dtype, buf=shm.buf)
        np.copyto(array, nparray)
        return shm

    def _download_dataset(self, dataset: Dataset) -> pd.DataFrame:
        """Internal function to download the dataset.

        Parameters
        ----------
        dataset : Dataset
            Dataset information including source, library, table, vars, etc.

        Returns
        -------
        pd.DataFrame
            DataFrame of the downloaded dataset.
        """
        # Credentials default to None so sources other than "wrds" do not hit
        # an unbound local when a new connection is created.
        usr = pwd = None
        # TODO: generic login data for different data sources
        if dataset.source == "wrds":
            config = self.config or {}
            usr = config.get("wrds_username")
            pwd = config.get("wrds_password")
        # If there exists a connection for the data source, use it!
        if (conn := self.conns.get(dataset.source, None)) is None:
            module = import_module(f"frds.data.{dataset.source}")
            conn = module.Connection(usr=usr, pwd=pwd)
            self.conns.update({dataset.source: conn})
        return conn.get_table(
            library=dataset.library,
            table=dataset.table,
            columns=dataset.vars,
            date_cols=dataset.date_vars,
            obs=self.obs,
        )
class PathFinder:
    """Solve a ``Maze`` by recursive depth-first search over a graph of
    junction/corner nodes, mirroring progress into shared memory so a
    separate pygame display process can render it."""

    def __init__(self, maze: Maze = None, memory_manager: SharedMemoryManager = None):
        """
        :param maze: maze to solve; can also be set later with ``assign_maze``.
        :param memory_manager: started ``SharedMemoryManager`` for the shared
            buffers. If omitted, a new one is created and started here.
        """
        self._maze = maze
        self.display = None
        self.memory_manager = memory_manager
        self.pos = [0, 0]
        self._data = {}
        if self.memory_manager is None:
            self.memory_manager = SharedMemoryManager()
            self.memory_manager.start()

    def assign_maze(self, maze: Maze):
        """Assign a new maze.

        Previously computed data is NOT cleared: call ``create_nodes`` again
        before solving. NOTE(review): the shared buffers created by the first
        ``solve`` keep that maze's size, so a maze of a different shape will
        not fit them.
        """
        self._maze = maze

    def start_pygame_display(self):
        """Start ``pygame_display`` in a daemon process, handing it the dtype,
        shared memory block and shape of the maze and path arrays plus the
        display-state block."""
        self.display = Process(
            target=pygame_display,
            args=[
                self._data["maze"].dtype,
                self._data["maze_mem"],
                self._data["maze"].shape,
                self._data["path"].dtype,
                self._data["path_mem"],
                self._data["path"].shape,
                self._data["display_state_mem"],
            ],
            daemon=True,
        )
        self.display.start()

    def solve(self, *args, **kwargs):
        """Search for a path from ``maze.start`` to ``maze.end``.

        Requires ``create_nodes`` to have been called for the current maze.

        Keyword options:
            progress_style: "bar" (default), "node_count", "path" or "pygame".
            interval: how often progress is shown (default 1).
            interval_type: "time" (interval in seconds, default) or
                "step_count" (every ``interval`` steps).
            patience: give up after this many seconds; 0 (default) = never.

        The current path lives in ``self._data["path"]`` (shared memory) and
        the elapsed time is stored in ``self._data["solve_time"]``. The
        outcome is printed.
        """
        start_time = time.time()
        time.sleep(0.000001)  # prevent division by zero error

        def _setup_solve_data(**kwargs):
            self._data["node_depth"] = 0
            # number of nodes in the graph
            self._data["total_nodes"] = len(self._data["nodes"])
            self._data["node_factorial"] = math.factorial(len(self._data["nodes"]))
            self._data["step_count"] = 0

            if not self._data.get("path_mem", False):
                print("[info] setting up shared memory")
                path_len = len(self._maze.tiles.flat)
                # _traverse recurses once per node on the current path; only
                # ever raise the limit, never lower it
                sys.setrecursionlimit(max(sys.getrecursionlimit(), path_len * 2))
                # The path is a (path_len, 2) array of the default int dtype,
                # so size its buffer from that dtype, not from the tile dtype.
                self._data["path_mem"] = self.memory_manager.SharedMemory(
                    path_len * 2 * np.dtype(int).itemsize
                )
                self._data["path"] = np.ndarray(
                    (path_len, 2), dtype=int, buffer=self._data["path_mem"].buf
                )
                self._data["maze_mem"] = self.memory_manager.SharedMemory(
                    self._maze.tiles.nbytes
                )
                self._data["maze"] = np.ndarray(
                    self._maze.tiles.shape,
                    dtype=self._maze.tiles.dtype,
                    buffer=self._data["maze_mem"].buf,
                )
                # state = running, step_count, step_per_sec (3 x uint64)
                self._data["display_state_mem"] = self.memory_manager.SharedMemory(
                    (3 * 64)
                )
                self._data["display_state"] = np.ndarray(
                    (3,), dtype="uint64", buffer=self._data["display_state_mem"].buf
                )
                self._data["display_state"][0] = True
                self._data["display_state"][1:] = 0

            # reset the shared maze and path for this run
            self._data["maze"][:] = self._maze.tiles[:]
            self._data["path"].fill(0)

            self._config = {
                "progress_style": kwargs.get("progress_style", "bar"),
                "interval": kwargs.get("interval", 1),
                "interval_type": kwargs.get("interval_type", "time"),
                "patience": kwargs.get("patience", 0),
            }

            display_started = False
            if not self.display and self._config["progress_style"] == "pygame":
                self.start_pygame_display()
                display_started = True
            elif self.display is not None and not self.display.is_alive():
                # restart a display process that has exited
                self.start_pygame_display()
                display_started = True
            if display_started:
                # give the display process time to come up
                time.sleep(2)

        def _validate_maze(self) -> bool:
            if not isinstance(self._maze, Maze):
                print(f"[error] {self._maze} is not a {Maze}")
                return False
            return True

        def _check_unwalked(x, y, path_number) -> bool:
            return (x, y) not in self._data["path"][path_number]

        def _update_progress(self):
            def _show_progress(self):
                def _node_count():
                    sys.stdout.write(
                        f'\r{self._data["node_depth"]} of {self._data["total_nodes"]} nodes'
                    )
                    sys.stdout.flush()

                def _bar():
                    printProgressBar(
                        len(path),
                        self._data["total_nodes"],
                        prefix="Finding path...",
                        suffix="node depth",
                    )

                def _path():
                    current_time = time.time()
                    self.show_path(path)
                    print(
                        f"Steps: {self._data['step_count']:,} ",
                        f"Time: {(time.time()-start_time):.0f} sec",
                        f"\nSteps/second: {int(self._data['step_count']/(current_time-start_time)):,}",
                    )

                def _pygame():
                    if not self.display:
                        self.start_pygame_display()
                    else:
                        self._data["display_state"][1] = self._data["step_count"]
                        self._data["display_state"][2] = int(
                            self._data["step_count"] / (time.time() - start_time)
                        )
                        formatted_time = time.strftime("%H:%M:%S", time.gmtime())
                        print(f"solving... {formatted_time}", end="\r")

                # call function based on progress style
                {
                    "node_count": _node_count,
                    "bar": _bar,
                    "path": _path,
                    "pygame": _pygame,
                }.get(self._config["progress_style"])()

            if self._config["interval_type"] == "time":
                if (
                    time.time() - self._data.get("progress_timer", start_time)
                    > self._config["interval"]
                ):
                    _show_progress(self)
                    self._data["progress_timer"] = time.time()
            elif self._config["interval_type"] == "step_count":
                if self._data["step_count"] % self._config["interval"] == 0:
                    _show_progress(self)

        def _traverse(self, coords: tuple, path: set) -> bool:
            """Recursively explores the maze, returns True if the end is found,
            returns False when the end cannot be found"""
            self._data["step_count"] += 1
            path.add(coords)
            _update_progress(self)
            if (
                self._config["patience"]
                and time.time() - start_time > self._config["patience"]
            ):
                raise TimeoutError
            if coords == self._maze.end:
                return True
            connections = [c for c in self._data["nodes"][coords] if c not in path]
            # recursively traverse each connection
            for new_coords in connections:
                # set next point in path
                self._data["node_depth"] += 1
                self._data["path"][self._data["node_depth"]] = new_coords
                if _traverse(self, new_coords, path):
                    return True
                # dead end: remove point from path
                self._data["path"][self._data["node_depth"]] = (0, 0)
                self._data["node_depth"] -= 1
                path.remove(new_coords)
            # Failed to find path at this point
            return False

        # start pathing the maze
        _setup_solve_data(**kwargs)
        path = set()
        self._data["path"][0] = self._maze.start  # first point in path
        try:
            if _traverse(self, self._maze.start, path):
                print("[success] Path found!")
            else:
                print("[fail] No valid path could be found.")
        except TimeoutError:
            print("[timeout] took too long to solve")
        end_time = time.time()
        self._data["display_state"][1] = self._data["step_count"]
        self._data["display_state"][2] = int(
            self._data["step_count"] / (time.time() - start_time)
        )
        self._data["solve_time"] = end_time - start_time

    def create_nodes(self):
        """Map the maze into a dict of node coords -> list of connected nodes"""
        self._data["nodes"] = {}
        i = 0
        map_size = self._maze.x * self._maze.y
        for x in range(self._maze.x):
            for y in range(self._maze.y):
                i += 1
                if self._is_node((x, y)):
                    self._data["nodes"][(x, y)] = []
                printProgressBar(
                    i, map_size, "Creating nodes... ", f"{i}/{map_size}", length=50
                )

        numer_of_nodes = len(self._data["nodes"])
        for i, node in enumerate(self._data["nodes"]):
            self._connect_nodes(node)
            # guard the total so a single-node maze does not divide by zero
            printProgressBar(
                i,
                max(numer_of_nodes - 1, 1),
                "Connecting nodes...",
                f"{i+1}/{numer_of_nodes}",
                length=50,
            )

    def _connect_nodes(self, node: tuple):
        """Connect ``node`` to the first node reachable in each cardinal
        direction along open (0) tiles."""
        x, y = node

        def _down():
            dy = y
            while dy > 0 and self._maze.tiles[x, dy] == 0:
                dy -= 1
                if (x, dy) in self._data["nodes"]:
                    self._data["nodes"][(x, y)].append((x, dy))
                    break

        def _left():
            dx = x
            while dx > 0 and self._maze.tiles[dx, y] == 0:
                dx -= 1
                if (dx, y) in self._data["nodes"]:
                    self._data["nodes"][(x, y)].append((dx, y))
                    break

        def _right():
            dx = x
            while dx < self._maze.x - 1 and self._maze.tiles[dx, y] == 0:
                dx += 1
                if (dx, y) in self._data["nodes"]:
                    self._data["nodes"][(x, y)].append((dx, y))
                    break

        def _up():
            dy = y
            while dy < self._maze.y - 1 and self._maze.tiles[x, dy] == 0:
                dy += 1
                if (x, dy) in self._data["nodes"]:
                    self._data["nodes"][(x, y)].append((x, dy))
                    break

        # order determines order of pathing
        _down()
        _left()
        _right()
        _up()

    def _is_node(self, coords: tuple) -> bool:
        """Return True for the start, the end, and open tiles that are corners
        or junctions (at least one horizontal AND one vertical opening)."""
        # The start and end of a maze are nodes automatically
        if coords == self._maze.start or coords == self._maze.end:
            return True
        x, y = coords
        # If all 8 surrounding tiles are clear then ignore. The lower bounds
        # are clamped: a negative start index would wrap to the far edge and
        # yield an empty slice on the first row/column.
        neighbourhood = self._maze.tiles[max(x - 1, 0) : x + 2, max(y - 1, 0) : y + 2]
        if not neighbourhood.any():
            return False
        # Check if the cardinal directions are open
        left = x > 0 and self._maze.tiles[x - 1, y] == 0
        right = x < self._maze.x - 1 and self._maze.tiles[x + 1, y] == 0
        down = y > 0 and self._maze.tiles[x, y - 1] == 0
        up = y < self._maze.y - 1 and self._maze.tiles[x, y + 1] == 0
        # straight paths and deadends are not nodes, but corners are!
        return bool((left or right) and (up or down))

    def display_nodes(self):
        """Print the maze with walls as '#' and nodes as '+'."""
        if "nodes" not in self._data:
            print("[error] cannont display nodes, data missing")
            return
        for x in range(self._maze.x):
            line = [" "]
            for y in range(self._maze.y):
                if self._maze.tiles[x, y] == 1:
                    line.append("#")
                elif (x, y) in self._data["nodes"]:
                    line.append("+")
                elif self._maze.tiles[x, y] == 0:
                    line.append(" ")
                else:
                    line.append("e")
            print(" ".join(line))

    def show_path(self, path: tuple = None):
        """Prints to terminal the map with the path drawn on it.

        :param path: list of (x, y) points. Anything that is not a list
            (including the default None) falls back to the shared-memory path
            of the last/ongoing solve, with (0, 0) entries dropped.
        """
        if not isinstance(path, list):
            path = [tuple(p) for p in self._data["path"][:] if tuple(p) != (0, 0)]
        # tile type to character lookup dict
        characters = {0: " ", 1: "#", 2: "2", 3: "."}
        output = []

        def _draw_path_between_points(a, b, maze):
            # if a is None then b should be the start
            if a is None:
                maze[b] = 3
                return
            dx = 0
            if b[0] > a[0]:
                dx = 1
            elif b[0] < a[0]:
                dx = -1
            dy = 0
            if b[1] > a[1]:
                dy = 1
            elif b[1] < a[1]:
                dy = -1
            # paint path onto maze (points are expected to be axis-aligned)
            x, y = a
            while (x, y) != b:
                x += dx
                y += dy
                maze[x, y] = 3

        if path:
            maze = copy.deepcopy(self._maze.tiles)
            point_a = None
            for point_b in path:
                _draw_path_between_points(point_a, point_b, maze)
                point_a = point_b
            for x in range(maze.shape[0]):
                line = [" "]
                for y in range(maze.shape[1]):
                    if (x, y) == path[-1]:
                        line.append("X")
                    else:
                        line.append(characters[maze[x, y]])
                output.append(" ".join(line))
        if output:
            print("\n".join(output))
        else:
            print("[error] No path to display")