def get(self, handle: int = 0) -> "Optional[Node]":
    """Build the observation tree rooted at the agent identified by ``handle``.

    The root node stores information about the agent itself; its ``childs``
    list has four slots, one per branch direction tried from the agent's
    (virtual) position.

    Args:
        handle: Index of the agent in ``self.env.agents``.

    Returns:
        The root ``TreeObsForRailEnv.Node``, or ``None`` when the agent status
        is none of READY_TO_DEPART, ACTIVE or DONE.

    Raises:
        IndexError: If ``handle`` is not a valid index into ``self.env.agents``
            (a diagnostic line is printed first).
    """
    num_agents = len(self.env.agents)
    # Valid handles are 0 .. num_agents - 1, so `handle == num_agents` is
    # already out of range and must be reported as well.
    if handle >= num_agents:
        print("ERROR: obs _get - handle ", handle, " len(agents)", num_agents)

    agent = self.env.agents[handle]

    # Pick the cell the agent is treated as occupying for this observation.
    if agent.status == RailAgentStatus.READY_TO_DEPART:
        agent_virtual_position = agent.initial_position
    elif agent.status == RailAgentStatus.ACTIVE:
        agent_virtual_position = agent.position
    elif agent.status == RailAgentStatus.DONE:
        agent_virtual_position = agent.target
    else:
        return None

    possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
    num_transitions = np.count_nonzero(possible_transitions)

    # Root node: information about the agent itself is stored here.
    distance_map = self.env.distance_map.get()
    root_node_observation = TreeObsForRailEnv.Node(
        dist_other_agent_encountered=0,
        dist_to_unusable_switch=0,
        dist_to_next_branch=0,
        dist_min_to_target=distance_map[(handle, *agent_virtual_position, agent.direction)],
        num_agents_same_direction=0,
        num_agents_opposite_direction=0,
        num_agents_malfunctioning=agent.malfunction_data['malfunction'],
        max_index_oppposite_direction=0,
        max_handle_agent_not_opposite=0,
        has_deadlocked_agent=0,
        first_agent_handle=0,
        first_agent_not_opposite=0,
        childs=[None] * 4,
    )

    # Child nodes: with a single possible transition, orient the branch
    # ordering along that transition instead of the agent's own direction.
    orientation = agent.direction
    if num_transitions == 1:
        orientation = np.argmax(possible_transitions)

    # Slot i of `childs` corresponds to orientation offset i - 1 (offsets -1, 0, 1, 2).
    branch_directions = [(orientation + 4 + offset) % 4 for offset in range(-1, 3)]
    for i, branch_direction in enumerate(branch_directions):
        if possible_transitions[branch_direction]:
            new_cell = get_new_position(agent_virtual_position, branch_direction)
            branch_observation = self._explore_branch(handle, new_cell, branch_direction, 1)
            root_node_observation.childs[i] = branch_observation

    return root_node_observation
def dfs(self, h, w, d):
    """Depth-first fill of ``dist_to_node`` / ``next_node`` starting at ``(h, w, d)``.

    For every transition ``ndir`` allowed from the state ``(h, w, d)``:

    * if the successor state is in ``self.nodes``, the distance is 1 and the
      successor itself is recorded as the next node;
    * otherwise the successor is explored first (unless already visited) and
      any distance it knows is propagated back, incremented by one.

    Args:
        h, w: Cell coordinates passed to ``self.env.rail.get_transitions``.
        d: Direction index (0..3) of the state.
    """
    self.dfs_used[h, w, d] = 1
    transitions = self.env.rail.get_transitions(h, w, d)

    for ndir in range(4):
        if not transitions[ndir]:
            continue

        nh, nw = get_new_position((h, w), ndir)
        successor = (nh, nw, ndir)

        if successor in self.nodes:
            # Direct neighbour is a graph node: one step away.
            self.dist_to_node[h, w, d, ndir] = 1
            self.next_node[h][w][d][ndir] = successor
            continue

        if not self.dfs_used[nh, nw, ndir]:
            self.dfs(nh, nw, ndir)

        # Entries are initialised negative; anything above -0.5 is a real distance.
        for last_dir in range(4):
            known_dist = self.dist_to_node[nh, nw, ndir, last_dir]
            if known_dist > -0.5:
                self.dist_to_node[h, w, d, ndir] = known_dist + 1
                self.next_node[h][w][d][ndir] = self.next_node[nh][nw][ndir][last_dir]
def _build_shortest_path(env, handle):
    """Collect the states visited while following decreasing distance to the target.

    Starting from the agent's initial position and direction, repeatedly step
    to the first allowed neighbour whose distance to the target is exactly one
    less than the current distance, recording each ``(*pos, direction)`` state
    that is left behind. The state at which the distance reaches zero is not
    added.

    The walk stops early (returning what was collected so far) when no
    neighbour makes progress or when a state would be visited twice; without
    these guards the loop would never terminate in those situations.

    Args:
        env: Environment exposing ``agents``, ``rail`` and
            ``obs_builder.rail_graph.dist_to_target``.
        handle: Index of the agent in ``env.agents``.

    Returns:
        A set of ``(*pos, direction)`` tuples.
    """
    agent = env.agents[handle]
    rail_graph = env.obs_builder.rail_graph
    pos = agent.initial_position
    direction = agent.initial_direction
    dist_min_to_target = rail_graph.dist_to_target(handle, pos[0], pos[1], direction)

    path = set()
    while dist_min_to_target:
        state = (*pos, direction)
        if state in path:
            # Revisiting a state means the distance is not strictly decreasing
            # (e.g. an infinite distance); stop instead of looping forever.
            break
        path.add(state)

        possible_transitions = env.rail.get_transitions(*pos, direction)
        for new_dir in range(4):
            if not possible_transitions[new_dir]:
                continue
            new_pos = get_new_position(pos, new_dir)
            new_min_dist = rail_graph.dist_to_target(handle, new_pos[0], new_pos[1], new_dir)
            if new_min_dist + 1 == dist_min_to_target:
                dist_min_to_target = new_min_dist
                pos, direction = new_pos, new_dir
                break
        else:
            # No neighbour brings us closer: nothing more to follow.
            break

    return path
def reinit_greedy(self):
    """Recompute the lookup structures derived from the current ``self.env``.

    Sets the following attributes:

    * ``greedy_way``: reset to an empty ``defaultdict(int)``.
    * ``switches``: cells whose full transition bitmask has more than two bits set.
    * ``target_neighbors`` / ``switches_neighbors``: ``(h, w, orientation)``
      states with an allowed transition leading into a target cell / a switch.
    * ``decision_cells``: boolean array of shape ``(height, width, 4)``, true
      for the neighbour states above and for every orientation of switch and
      target cells.
    * ``location_has_target_array`` and ``location_has_target``: the same
      boolean ``(height, width)`` array marking agent target cells
      (``location_has_target`` is a set only while this method runs).
    """
    self.greedy_way = defaultdict(int)
    rail_env = self.env
    height, width = rail_env.height, rail_env.width

    self.location_has_target = {agent.target for agent in rail_env.agents}

    # A cell is a switch when more than two transition bits are set.
    self.switches = set()
    for h in range(height):
        for w in range(width):
            pos = (h, w)
            transition_bit = bin(rail_env.rail.get_full_transitions(*pos))
            total_transitions = transition_bit.count("1")
            if total_transitions > 2:
                self.switches.add(pos)

    # States that have a target cell or a switch as a directly reachable neighbour.
    self.target_neighbors = set()
    self.switches_neighbors = set()
    for h in range(height):
        for w in range(width):
            pos = (h, w)
            for orientation in range(4):
                possible_transitions = rail_env.rail.get_transitions(*pos, orientation)
                for ndir in range(4):
                    if not possible_transitions[ndir]:
                        continue
                    nxt = get_new_position(pos, ndir)
                    if nxt in self.location_has_target:
                        self.target_neighbors.add((h, w, orientation))
                    if nxt in self.switches:
                        self.switches_neighbors.add((h, w, orientation))

    # `np.bool` was removed from NumPy; the builtin `bool` is the supported dtype.
    self.decision_cells = np.zeros((height, width, 4), dtype=bool)
    for posdir in self.switches_neighbors.union(self.target_neighbors):
        self.decision_cells[posdir] = 1
    for pos in self.switches.union(self.location_has_target):
        self.decision_cells[pos[0], pos[1], :] = 1

    self.location_has_target_array = np.zeros((height, width), dtype=bool)
    for pos in self.location_has_target:
        self.location_has_target_array[pos] = 1
    # From here on `location_has_target` is the array, indexable by position.
    self.location_has_target = self.location_has_target_array
def _explore_branch(self, handle, position, direction, depth, was_target=False):
    """Recursively build the observation subtree for one branch.

    The straight stretch starting at ``(position, direction)`` is explored via
    ``self._explore_line`` (results are memoised in ``self.cached_nodes``,
    keyed by the starting state), then the per-agent ``dist_min_to_target`` is
    filled in and the children at the end of the stretch are explored.

    Args:
        handle: Index of the agent the tree is built for.
        position: Cell at which the branch starts.
        direction: Direction of travel when entering ``position``.
        depth: Depth of this node in the tree (the root's children use 1).
        was_target: True when an ancestor branch already ended on the agent's
            target.

    Returns:
        A node whose ``dist_min_to_target`` is ``-1`` if ``was_target``, ``0``
        if this stretch ends on the agent's target, else the distance-map
        value at the end of the stretch. ``childs`` holds four slots ordered
        [left, forward, right, back] relative to the end direction; slots
        without an allowed transition stay ``None``.
    """
    start_position = (position[0], position[1], direction)
    if start_position in self.cached_nodes:
        node, position, direction = self.cached_nodes[start_position]
    else:
        node, position, direction = self._explore_line(position, direction)
        self.cached_nodes[start_position] = node, position, direction

    last_is_target = (position == self.env.agents[handle].target)

    if was_target:
        dist_min_to_target = -1
    elif last_is_target:
        dist_min_to_target = 0
    else:
        dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]

    # `_replace` returns a copy with a fresh `childs` list, so the cached node
    # is never mutated by the assignments below.
    node = node._replace(dist_min_to_target=dist_min_to_target, childs=[None] * 4)

    # `>=` rather than `==`: a max_depth below the starting depth must still
    # stop the recursion instead of never matching.
    if depth >= self.max_depth:
        return node

    if node.has_deadlocked_agent:
        # Nothing beyond a deadlocked agent is explored.
        return node

    # Start from the current orientation and see which transitions are
    # available; organise them as [left, forward, right, back].
    possible_transitions = self.env.rail.get_transitions(*position, direction)
    branch_directions = [(direction + 4 + offset) % 4 for offset in range(-1, 3)]
    for i, branch_direction in enumerate(branch_directions):
        if possible_transitions[branch_direction]:
            new_cell = get_new_position(position, branch_direction)
            branch_observation = self._explore_branch(
                handle,
                new_cell,
                branch_direction,
                depth + 1,
                was_target=last_is_target or was_target,
            )
            node.childs[i] = branch_observation

    return node
def _build(self):
    """Build the node graph of the rail network from ``self.env``.

    Populates:

    * ``valid_pos``: cells with at least one transition bit set.
    * ``nodes``: ``(h, w, d)`` states located on agent targets or on cells
      with more than two transition bits, restricted to directions from which
      at least one transition is possible.
    * ``dist_to_node`` / ``next_node``: filled by ``self.dfs`` for every state.
    * ``n_nodes``, ``nodes_dict``: node count and state -> node index lookup
      (entries for non-node states are left uninitialised).
    * ``amatrix``: ``n_nodes x n_nodes`` matrix, ``inf`` where no direct edge
      was recorded, ``0`` on the diagonal.
    * ``cell_to_edge``: for each cell, the list of ``(from_node, to_node)``
      index pairs of edges whose walk passes through (or ends on) that cell.
    """
    env = self.env
    height, width = env.height, env.width

    # Candidate node cells: agent targets plus cells with more than two transition bits.
    self.nodes = {agent.target for agent in env.agents}
    self.valid_pos = list()
    for h in range(height):
        for w in range(width):
            pos = (h, w)
            transition_bit = bin(env.rail.get_full_transitions(*pos))
            total_transitions = transition_bit.count("1")
            if total_transitions > 2:
                self.nodes.add(pos)
            if total_transitions > 0:
                self.valid_pos.append((h, w))

    # Expand cells into directed states, keeping only directions that can move on.
    self.nodes = {
        (h, w, d)
        for h, w in self.nodes
        for d in range(4)
        if np.any(env.rail.get_transitions(h, w, d))
    }

    self.dist_to_node = -np.ones((height, width, 4, 4))
    self.next_node = [
        [[[None for _ in range(4)] for _ in range(4)] for _ in range(width)]
        for _ in range(height)
    ]

    self.dfs_used = np.zeros((height, width, 4))
    for h in range(height):
        for w in range(width):
            for d in range(4):
                if not self.dfs_used[h, w, d]:
                    self.dfs(h, w, d)

    self.n_nodes = len(self.nodes)
    # `np.int` was removed from NumPy; the builtin `int` is the supported dtype.
    self.nodes_dict = np.empty((height, width, 4), dtype=int)
    for i, (h, w, d) in enumerate(self.nodes):
        self.nodes_dict[h, w, d] = i

    self.cell_to_edge = [[list() for _ in range(width)] for _ in range(height)]

    self.amatrix = np.ones((self.n_nodes, self.n_nodes)) * np.inf
    self.amatrix[np.arange(self.n_nodes), np.arange(self.n_nodes)] = 0
    for i, (h, w, d) in enumerate(self.nodes):
        for dd in range(4):
            nnode = self.next_node[h][w][d][dd]
            if nnode is None:
                continue

            nnode_i = self.nodes_dict[nnode]
            self.amatrix[i][nnode_i] = self.dist_to_node[h, w, d, dd]

            # Walk the edge cell by cell and register it on every cell entered.
            cell = (h, w, d)
            while cell != nnode:
                possible_transitions = env.rail.get_transitions(*cell)
                for ndir in range(4):
                    # At the starting state only the edge's own direction `dd`
                    # may be taken; afterwards the first allowed transition is.
                    if possible_transitions[ndir] and (cell != (h, w, d) or ndir == dd):
                        nh, nw = get_new_position((cell[0], cell[1]), ndir)
                        cell = (nh, nw, ndir)
                        self.cell_to_edge[nh][nw].append((i, nnode_i))
                        break
def _explore_line(self, position, direction):
    """Walk along the track from ``position`` until the next stopping point.

    The walk continues along ``direction`` while
    ``self.graph.get_if_one_transition`` reports a single way forward, and
    stops at a target cell, a switch, a dead-end, or a cell with no possible
    transition. Dead-ends are treated as nodes instead of turning back, to
    avoid loops. Agent statistics are accumulated over every cell visited.

    Args:
        position: Cell at which the walk starts.
        direction: Direction of travel when entering ``position``.

    Returns:
        A tuple ``(node, position, direction)``: the ``TreeObsForRailEnv.Node``
        summarising the stretch (``dist_min_to_target`` left as ``None`` for
        the caller to fill in), and the state at which the walk stopped.

    Raises:
        AssertionError: If a cell offers exactly one transition from
            ``direction``, is not a dead-end, and was not advanced through by
            ``get_if_one_transition``; the loop could not make progress there.
    """
    tot_dist = 1
    other_agent_encountered = np.inf
    unusable_switch = np.inf
    other_agent_same_direction = 0
    other_agent_opposite_direction = 0
    num_agents_malfunctioning = 0
    max_index_oppposite_direction = 0
    has_deadlocked_agent = 0
    first_agent_handle = -1
    first_agent_not_opposite = True
    max_handle_agent_not_opposite = True
    max_handle = -1

    while True:
        other_handle = self.location_with_agent[position]
        if other_handle != -1:
            if tot_dist < other_agent_encountered:
                other_agent_encountered = tot_dist

            other_agent = self.env.agents[other_handle]

            # Count observed agents that are currently malfunctioning.
            if other_agent.malfunction_data['malfunction'] > 0:
                num_agents_malfunctioning += 1

            if self.deadlock_checker.is_deadlocked(other_handle):
                has_deadlocked_agent = 1
            elif first_agent_handle == -1:
                # First non-deadlocked agent met on this stretch.
                first_agent_handle = other_handle
                first_agent_not_opposite = (other_agent.direction == direction)

            if other_agent.direction == direction:
                # Accumulate the number of agents travelling in the same direction.
                other_agent_same_direction += 1
                if other_handle > max_handle:
                    max_handle = other_handle
                    max_handle_agent_not_opposite = True
            else:
                # Any agent not travelling in our direction counts as opposite.
                other_agent_opposite_direction += 1
                max_index_oppposite_direction = max(max_index_oppposite_direction, other_handle)
                if other_handle > max_handle:
                    max_handle = other_handle
                    max_handle_agent_not_opposite = False

        if self.greedy_checker.location_has_target[position]:
            break

        only_direction = self.graph.get_if_one_transition(position, direction)
        if only_direction != -1:
            # Plain track: keep walking.
            position = get_new_position(position, only_direction)
            direction = only_direction
            tot_dist += 1
            continue

        # Number of transitions possible for the agent and total number of
        # transitions in the cell (cell type).
        cell_transitions = self.env.rail.get_transitions(*position, direction)
        num_transitions = np.count_nonzero(cell_transitions)
        total_transitions = self.graph.get_total_transitions(position)

        # Detect switches that can only be used by other agents.
        if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
            unusable_switch = tot_dist

        if num_transitions == 1:
            if total_transitions == 1:
                # Dead-end!
                break
            # Raised explicitly rather than via `assert False`: a stripped
            # assert (python -O) would leave this `while True` spinning on the
            # same cell forever.
            raise AssertionError(
                "single-transition cell was not advanced through by "
                f"get_if_one_transition at {position}, direction {direction}"
            )
        elif num_transitions > 0:
            # Switch detected
            break
        elif num_transitions == 0:
            # Wrong cell type, but cover it and treat it as a dead-end, just in case.
            print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                  position[1], direction)
            break

    node = TreeObsForRailEnv.Node(
        dist_other_agent_encountered=other_agent_encountered,
        dist_to_unusable_switch=unusable_switch,
        dist_to_next_branch=tot_dist,
        dist_min_to_target=None,
        num_agents_same_direction=other_agent_same_direction,
        num_agents_opposite_direction=other_agent_opposite_direction,
        num_agents_malfunctioning=num_agents_malfunctioning,
        max_index_oppposite_direction=max_index_oppposite_direction,
        max_handle_agent_not_opposite=max_handle_agent_not_opposite,
        has_deadlocked_agent=has_deadlocked_agent,
        first_agent_handle=first_agent_handle,
        first_agent_not_opposite=first_agent_not_opposite,
        childs=[None] * 4,
    )

    return node, position, direction