def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray):
    """
    Search for a possible path from one oriented cell to a target cell.

    The frontier is a LIFO stack, so (cell, direction) states are expanded
    depth-first. A state is expanded at most once.

    :param start: Start cell from where we want to check the path
    :param direction: Start direction for the path we are testing
    :param end: Cell that we try to reach from the start cell
    :return: True if a path exists, False otherwise
    """
    seen = OrderedSet()
    frontier = [(start, direction)]

    while frontier:
        state = frontier.pop()
        cell, heading = state

        # The target counts as reached regardless of the arrival direction.
        if Vec2d.is_equal(cell, end):
            return True

        if state in seen:
            continue
        seen.add(state)

        allowed = self.get_transitions(cell[0], cell[1], heading)
        # Every permitted move leads to the neighbouring cell, now facing
        # the direction of that move.
        frontier.extend(
            (get_new_position(cell, new_heading), new_heading)
            for new_heading in range(4)
            if allowed[new_heading]
        )

    return False
def is_simple_turn(trans):
    """
    Tell whether ``trans`` is one of the rotated simple-turn transitions.

    The candidate set is built from the two basic simple turns by applying
    ``self.transitions.rotate_transition`` three times in 90-degree steps;
    as before, only the rotated variants are added to the set.

    :param trans: transition value to classify
    :return: True if ``trans`` is in the set of rotated simple turns
    """
    all_simple_turns = OrderedSet()
    # A dedicated loop variable is used here: re-using the name `trans`
    # overwrote the argument, so the membership test below was always True.
    for base_turn in [int('0100000000000010', 2),  # Case 1b (8) - simple turn right
                      int('0001001000000000', 2)]:  # Case 1c (9) - simple turn left
        rotated = base_turn
        for _ in range(3):
            rotated = self.transitions.rotate_transition(rotated, rotation=90)
            all_simple_turns.add(rotated)
    return trans in all_simple_turns
def prediction_from_path(self, handle, path):
    """
    Turn a given path into a per-time-step prediction table for one agent.

    :param handle: index of the agent in ``self.env.agents``
    :param path: sequence of waypoints; the first entry is dropped as the
        current cell. Each later waypoint is read through ``.position``,
        ``.direction`` and ``[2].action``.
    :return: integer array of shape ``(self.max_depth + 1, 5)`` whose rows are
        ``[time_offset, position axis 0, position axis 1, direction, action]``,
        or ``self.empty_prediction`` when the agent status is none of
        READY_TO_DEPART, ACTIVE or DONE.
    """
    agent = self.env.agents[handle]
    prediction = np.zeros(shape=(self.max_depth + 1, 5), dtype=int)

    status = agent.status
    if status == RailAgentStatus.READY_TO_DEPART:
        current_position = agent.initial_position
    elif status == RailAgentStatus.ACTIVE:
        current_position = agent.position
    elif status == RailAgentStatus.DONE:
        current_position = agent.target
    else:
        # Any other status (e.g. DONE_REMOVED): there is nothing to predict.
        return self.empty_prediction

    current_direction = agent.direction
    times_per_cell = int(np.reciprocal(agent.speed_data["speed"]))

    # Row 0 describes the current time step.
    prediction[0] = [0, *current_position, current_direction, RailEnvActions.MOVE_FORWARD]

    # Skip the first waypoint (the cell the agent is already on) and walk the
    # rest with a cursor instead of re-slicing the path on every move.
    waypoints = path[1:] if path else []
    next_waypoint = 0

    # Collected alongside the prediction; it is not returned or stored here.
    visited = OrderedSet()

    for step in range(1, self.max_depth + 1):
        # At the target or out of waypoints: stand still for the rest of the horizon.
        if current_position == agent.target or next_waypoint >= len(waypoints):
            prediction[step] = [step, *current_position, current_direction, RailEnvActions.STOP_MOVING]
            visited.add((*current_position, agent.direction))
            continue

        action = RailEnvActions.MOVE_FORWARD
        # The agent advances to the next waypoint only every `times_per_cell` steps.
        if step % times_per_cell == 0:
            waypoint = waypoints[next_waypoint]
            next_waypoint += 1
            current_position = waypoint.position
            current_direction = waypoint.direction
            action = waypoint[2].action

        prediction[step] = [step, *current_position, current_direction, action]
        visited.add((*current_position, current_direction))

    return prediction
def get(self, handle: int = 0) -> np.ndarray:
    '''
    Build a simple observation that indicates, for each of the first 10
    prediction steps, whether the agent's own predicted cell overlaps with the
    predicted cell of another agent at the same step.

    This is useless for the task of navigation but might help when looking for
    conflicts. A more complex implementation can be found in the
    TreeObsForRailEnv class.

    Each agent receives an observation of length 10, where element ``k``
    refers to prediction step ``k`` and its value is:
        - 0 if no overlap is predicted at that step
        - 1 if at least one other agent is predicted on the same cell at that step

    The cells of the agent's own predicted path are also stored in
    ``self.env.dev_obs_dict[handle]`` so that a renderer can display them.

    :param handle: index of the agent
    :return: observation vector of length 10 for agent ``handle``
    '''
    observation = np.zeros(10)

    # Track which cells were considered while building the observation and
    # make them accessible for rendering.
    visited = OrderedSet()
    for _idx in range(10):
        x_coord = self.predictions[handle][_idx][1]
        y_coord = self.predictions[handle][_idx][2]

        # Every cell on the own predicted path is part of the rendered observation.
        visited.add((x_coord, y_coord))

        # Compare the own predicted position with those of all other agents
        # (own entry removed) at the same prediction step.
        if self.predicted_pos[_idx][handle] in np.delete(self.predicted_pos[_idx], handle, 0):
            # The flag belongs to the prediction step, not to the agent index:
            # indexing by `handle` marked the wrong element and raised
            # IndexError for handles >= 10.
            observation[_idx] = 1

    # This variable is accessed by the renderer to visualize the observation.
    self.env.dev_obs_dict[handle] = visited

    return observation
def _explore_branch(self, handle, position, direction, tot_dist, depth):
    """
    Utility function to compute tree-based observations.

    Walks along the branch starting at ``position`` / ``direction`` until the
    next switch, the agent's target, a dead-end or an invalid / already visited
    cell, collecting the features stored in the resulting node. At a switch
    (or dead-end) each possible branch is explored recursively.

    :param handle: index of the agent the observation is built for
    :param position: cell at which the walk starts
    :param direction: direction in which the walk starts
    :param tot_dist: distance already travelled from the root to ``position``
    :param depth: current recursion depth
    :return: ``([], [])`` when ``depth`` exceeds ``self.max_depth``; otherwise
        ``(node, visited)`` where ``visited`` contains the
        ``(row, column, direction)`` triples seen on this branch and its children
    """

    # [Recursive branch opened]
    if depth >= self.max_depth + 1:
        return [], []

    # Continue along direction until next switch or
    # until no transitions are possible along the current direction (i.e., dead-ends)
    # We treat dead-ends as nodes, instead of going back, to avoid loops
    exploring = True
    last_is_switch = False
    last_is_dead_end = False
    last_is_terminal = False  # wrong cell OR cycle; either way, we don't want the agent to land here
    last_is_target = False

    visited = OrderedSet()
    agent = self.env.agents[handle]
    time_per_cell = np.reciprocal(agent.speed_data["speed"])
    own_target_encountered = np.inf
    other_agent_encountered = np.inf
    other_target_encountered = np.inf
    potential_conflict = np.inf
    unusable_switch = np.inf
    other_agent_same_direction = 0
    other_agent_opposite_direction = 0
    malfunctioning_agent = 0
    min_fractional_speed = 1.
    num_steps = 1
    other_agent_ready_to_depart_encountered = 0
    found_closest_communication = False
    communication = None
    while exploring:
        # #############################
        # #############################
        # Modify here to compute any useful data required to build the end node's features. This code is called
        # for each cell visited between the previous branching node and the next switch / target / dead-end.
        if position in self.location_has_agent:
            # Keep only the communication of the first agent met on this branch that has one.
            if self.location_has_agent_communication[position] is not None and not found_closest_communication:
                # Plain boolean: a trailing comma here used to store the tuple `(True,)`.
                found_closest_communication = True
                communication = self.location_has_agent_communication[position]

            if tot_dist < other_agent_encountered:
                other_agent_encountered = tot_dist

            # Check if any of the observed agents is malfunctioning, store agent with longest duration left
            if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                malfunctioning_agent = self.location_has_agent_malfunction[position]

            other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0)

            if self.location_has_agent_direction[position] == direction:
                # Cumulate the number of agents on branch with same direction
                other_agent_same_direction += self.location_has_agent_direction.get((position, direction), 0)

                # Check fractional speed of agents
                current_fractional_speed = self.location_has_agent_speed[position]
                if current_fractional_speed < min_fractional_speed:
                    min_fractional_speed = current_fractional_speed

                # Other direction agents
                # TODO: Test that this behavior is as expected
                other_agent_opposite_direction += \
                    self.location_has_agent[position] - self.location_has_agent_direction.get((position, direction), 0)

            else:
                # If no agent in the same direction was found all agents in that position are other direction
                other_agent_opposite_direction += self.location_has_agent[position]

        # Check number of possible transitions for agent and total number of transitions in cell (type)
        cell_transitions = self.env.rail.get_transitions(*position, direction)
        transition_bit = bin(self.env.rail.get_full_transitions(*position))
        total_transitions = transition_bit.count("1")
        crossing_found = False
        if int(transition_bit, 2) == int('1000010000100001', 2):
            crossing_found = True

        # Register possible future conflict
        predicted_time = int(tot_dist * time_per_cell)
        if self.predictor and predicted_time < self.max_prediction_depth:
            int_position = coordinate_to_position(self.env.width, [position])
            if tot_dist < self.max_prediction_depth:

                # Neighbouring prediction steps, clamped to the prediction horizon.
                pre_step = max(0, predicted_time - 1)
                post_step = min(self.max_prediction_depth - 1, predicted_time + 1)

                # Look for conflicting paths at distance tot_dist
                if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0):
                    conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position)
                    for ca in conflicting_agent[0]:
                        if (direction != self.predicted_dir[predicted_time][ca]
                                and cell_transitions[
                                    self._reverse_dir(self.predicted_dir[predicted_time][ca])] == 1
                                and tot_dist < potential_conflict):
                            potential_conflict = tot_dist
                        if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                            potential_conflict = tot_dist

                # Look for conflicting paths at distance num_step-1
                elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                    conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                    for ca in conflicting_agent[0]:
                        if (direction != self.predicted_dir[pre_step][ca]
                                and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1
                                and tot_dist < potential_conflict):
                            potential_conflict = tot_dist
                        if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                            potential_conflict = tot_dist

                # Look for conflicting paths at distance num_step+1
                elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
                    conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                    for ca in conflicting_agent[0]:
                        if (direction != self.predicted_dir[post_step][ca]
                                and cell_transitions[self._reverse_dir(self.predicted_dir[post_step][ca])] == 1
                                and tot_dist < potential_conflict):
                            potential_conflict = tot_dist
                        if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                            potential_conflict = tot_dist

        if position in self.location_has_target and position != agent.target:
            if tot_dist < other_target_encountered:
                other_target_encountered = tot_dist

        if position == agent.target and tot_dist < own_target_encountered:
            own_target_encountered = tot_dist

        # #############################
        # #############################
        # A repeated (cell, direction) state means the walk has entered a cycle.
        if (position[0], position[1], direction) in visited:
            last_is_terminal = True
            break
        visited.add((position[0], position[1], direction))

        # If the target node is encountered, pick that as node. Also, no further branching is possible.
        if np.array_equal(position, self.env.agents[handle].target):
            last_is_target = True
            break

        # Check if crossing is found --> Not an unusable switch
        if crossing_found:
            # Treat the crossing as a straight rail cell
            total_transitions = 2
        num_transitions = np.count_nonzero(cell_transitions)

        exploring = False

        # Detect Switches that can only be used by other agents.
        if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
            unusable_switch = tot_dist

        if num_transitions == 1:
            # Check if dead-end, or if we can go forward along direction
            nbits = total_transitions
            if nbits == 1:
                # Dead-end!
                last_is_dead_end = True

            if not last_is_dead_end:
                # Keep walking through the tree along `direction`
                exploring = True
                # convert one-hot encoding to 0,1,2,3
                direction = np.argmax(cell_transitions)
                position = get_new_position(position, direction)
                num_steps += 1
                tot_dist += 1
        elif num_transitions > 0:
            # Switch detected
            last_is_switch = True
            break

        elif num_transitions == 0:
            # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
            print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                  position[1], direction)
            last_is_terminal = True
            break

    # `position` is either a terminal node or a switch

    # #############################
    # #############################
    # Modify here to append new / different features for each visited cell!

    if last_is_target:
        dist_to_next_branch = tot_dist
        dist_min_to_target = 0
    elif last_is_terminal:
        dist_to_next_branch = np.inf
        dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
    else:
        dist_to_next_branch = tot_dist
        dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]

    node = CustomTreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered,
                                        dist_other_target_encountered=other_target_encountered,
                                        dist_other_agent_encountered=other_agent_encountered,
                                        dist_potential_conflict=potential_conflict,
                                        dist_unusable_switch=unusable_switch,
                                        dist_to_next_branch=dist_to_next_branch,
                                        dist_min_to_target=dist_min_to_target,
                                        num_agents_same_direction=other_agent_same_direction,
                                        num_agents_opposite_direction=other_agent_opposite_direction,
                                        num_agents_malfunctioning=malfunctioning_agent,
                                        speed_min_fractional=min_fractional_speed,
                                        num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
                                        communication=communication,
                                        childs={})

    # #############################
    # #############################
    # Start from the current orientation, and see which transitions are available;
    # organize them as [left, forward, right, back], relative to the current orientation
    # Get the possible transitions
    possible_transitions = self.env.rail.get_transitions(*position, direction)
    for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
        if last_is_dead_end and self.env.rail.get_transition((*position, direction),
                                                             (branch_direction + 2) % 4):
            # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
            # it back
            new_cell = get_new_position(position, (branch_direction + 2) % 4)
            branch_observation, branch_visited = self._explore_branch(handle,
                                                                      new_cell,
                                                                      (branch_direction + 2) % 4,
                                                                      tot_dist + 1,
                                                                      depth + 1)
            node.childs[self.tree_explored_actions_char[i]] = branch_observation
            if len(branch_visited) != 0:
                visited |= branch_visited
        elif last_is_switch and possible_transitions[branch_direction]:
            new_cell = get_new_position(position, branch_direction)
            branch_observation, branch_visited = self._explore_branch(handle,
                                                                      new_cell,
                                                                      branch_direction,
                                                                      tot_dist + 1,
                                                                      depth + 1)
            node.childs[self.tree_explored_actions_char[i]] = branch_observation
            if len(branch_visited) != 0:
                visited |= branch_visited
        else:
            # no exploring possible, add just cells with infinity
            node.childs[self.tree_explored_actions_char[i]] = -np.inf

    # Nodes at the maximum depth are leaves: drop their child entries.
    if depth == self.max_depth:
        node.childs.clear()
    return node, visited
def get(self, handle: int = None):
    """
    Called whenever get_many in the observation build is called.
    Requires distance_map to extract the shortest path.
    Does not take into account future positions of other agents!

    If there is no shortest path, the agent just stands still and stops moving.

    Parameters
    ----------
    handle : int, optional
        Handle of the agent for which to compute the prediction.
        When omitted (None), predictions are computed for all agents.

    Returns
    -------
    dict
        Dictionary indexed by the agent handle; for each agent an array of
        (max_depth + 1)x5 elements:
        - time_offset
        - position axis 0
        - position axis 1
        - direction
        - action taken to come here (not implemented yet)
        The prediction at 0 is the current position, direction etc.
    """
    agents = self.env.agents
    # Compare against None explicitly: handle 0 is a valid agent and must not
    # fall through to "predict for every agent".
    if handle is not None:
        agents = [self.env.agents[handle]]
    distance_map: DistanceMap = self.env.distance_map

    shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth)

    prediction_dict = {}
    for agent in agents:

        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            # Any other status: emit a placeholder table. Only the first
            # `max_depth` rows are filled; the last row keeps its zeros.
            # NOTE(review): the None entries are written into a float array
            # (presumably stored as NaN by NumPy) -- confirm.
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            for i in range(self.max_depth):
                prediction[i] = [i, None, None, None, None]
            prediction_dict[agent.handle] = prediction
            continue

        agent_virtual_direction = agent.direction
        agent_speed = agent.speed_data["speed"]
        times_per_cell = int(np.reciprocal(agent_speed))
        prediction = np.zeros(shape=(self.max_depth + 1, 5))
        prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]

        shortest_path = shortest_paths[agent.handle]

        # if there is a shortest path, remove the initial position
        if shortest_path:
            shortest_path = shortest_path[1:]

        new_direction = agent_virtual_direction
        new_position = agent_virtual_position
        visited = OrderedSet()
        for index in range(1, self.max_depth + 1):
            # if we're at the target, stop moving until max_depth is reached
            if new_position == agent.target or not shortest_path:
                prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
                visited.add((*new_position, agent.direction))
                continue

            # The agent advances one cell every `times_per_cell` steps.
            if index % times_per_cell == 0:
                new_position = shortest_path[0].position
                new_direction = shortest_path[0].direction

                shortest_path = shortest_path[1:]

            # prediction is ready
            prediction[index] = [index, *new_position, new_direction, 0]
            visited.add((*new_position, new_direction))

        # TODO: very bad side effects for visualization only: hand the dev_pred_dict back instead of setting on env!
        self.env.dev_pred_dict[agent.handle] = visited
        prediction_dict[agent.handle] = prediction

    return prediction_dict
def get(self, handle: int = None, handles=None, positions=None, directions=None):
    """
    Called whenever get_many in the observation build is called.
    Requires distance_map to extract the shortest path.
    Does not take into account future positions of other agents!

    If there is no shortest path, the agent just stands still and stops moving.

    Parameters
    ----------
    handle : int, optional
        Handle of a single agent for which to compute the prediction.
    handles : iterable of int, optional
        Handles of the agents to predict for. Takes precedence over ``handle``
        and is forwarded to ``get_shortest_paths``.
    positions : mapping, optional
        Per-handle position from which the walk along the shortest path starts,
        used when it holds a non-None entry for the agent.
    directions : mapping, optional
        Per-handle direction, read whenever ``positions`` supplies an entry
        for the agent.

    Returns
    -------
    dict
        Dictionary indexed by the agent handle; for each agent an array of
        (max_depth + 1)x5 elements:
        - time_offset
        - position axis 0
        - position axis 1
        - direction
        - action taken to come here (not implemented yet)
        The prediction at 0 is the current position, direction etc.
    """
    agents = self.env.agents
    # Compare against None explicitly: handle 0 is a valid agent and must not
    # fall through to "predict for every agent".
    if handle is not None:
        agents = [self.env.agents[handle]]
    if handles is not None:
        agents = [self.env.agents[h] for h in handles]
    distance_map: DistanceMap = self.env.distance_map

    shortest_paths = get_shortest_paths(distance_map,
                                        handles=handles,
                                        max_depth=self.max_depth,
                                        branch_only=self.branch_only)

    prediction_dict = {}
    # Work on copies: the malfunction counter is decremented below while
    # simulating, and that must not leak into the environment's agents.
    agents = deepcopy(agents)
    for agent in agents:
        if not agent.status == RailAgentStatus.ACTIVE and not agent.status == RailAgentStatus.READY_TO_DEPART:
            # Neither active nor ready to depart: emit a placeholder table.
            # Only the first `max_depth` rows are filled.
            # NOTE(review): the None entries are written into a float array
            # (presumably stored as NaN by NumPy) -- confirm.
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            for i in range(self.max_depth):
                prediction[i] = [i, None, None, None, None]
            prediction_dict[agent.handle] = prediction
            continue

        agent_virtual_direction = agent.direction
        agent_virtual_position = agent.position
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        agent_speed = agent.speed_data["speed"]
        times_per_cell = int(np.reciprocal(agent_speed))
        prediction = np.zeros(shape=(self.max_depth + 1, 5))
        prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]

        shortest_path = shortest_paths[agent.handle]

        # if there is a shortest path, remove the initial position
        if shortest_path:
            shortest_path = shortest_path[1:]

        # An explicitly supplied position/direction overrides the agent's own
        # state as the starting point of the walk (row 0 is not affected).
        if positions is not None and positions.get(agent.handle, None) is not None:
            new_direction = directions[agent.handle]
            new_position = positions[agent.handle]
        else:
            new_direction = agent_virtual_direction
            new_position = agent_virtual_position

        visited = OrderedSet()
        for index in range(1, self.max_depth + 1):
            if self.branch_only:
                # Stop predicting at the first cell offering more than one transition.
                cell_transitions = self.env.rail.get_transitions(*new_position, new_direction)
                if np.count_nonzero(cell_transitions) > 1:
                    break

            if not shortest_path:
                prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
                visited.add((*new_position, agent.direction))
                continue

            # While malfunctioning the agent stays in place; each predicted
            # step consumes one step of the remaining malfunction.
            if agent.malfunction_data["malfunction"] > 0:
                agent.malfunction_data["malfunction"] -= 1
                prediction[index] = [index, *new_position, None, RailEnvActions.STOP_MOVING]
                visited.add((*new_position, agent.direction))
                continue

            if new_position == agent.target:
                prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
                visited.add((*new_position, agent.direction))
                break

            # The agent advances one cell every `times_per_cell` steps.
            if index % times_per_cell == 0:
                new_position = shortest_path[0].position
                new_direction = shortest_path[0].direction

                shortest_path = shortest_path[1:]

            # prediction is ready
            prediction[index] = [index, *new_position, new_direction, 0]
            visited.add((*new_position, new_direction))

        self.env.dev_pred_dict[agent.handle] = visited
        prediction_dict[agent.handle] = prediction

    return prediction_dict
def _explore_branch(self, handle, position, direction, tot_dist, depth):
    """
    Utility function to compute tree-based observations.

    Walks along the branch starting at ``position`` / ``direction`` until the
    next switch, the agent's target, a dead-end, an invalid / already visited
    cell, or a cell occupied by an agent judged to be heading the opposite way.
    The features collected on the way are stored in a ``MyNode``; at a switch
    (or dead-end) each possible branch is explored recursively.

    :param handle: index of the agent the observation is built for
    :param position: cell at which the walk starts
    :param direction: direction in which the walk starts
    :param tot_dist: distance already travelled from the root to ``position``
    :param depth: current recursion depth
    :return: ``([], [])`` when ``depth`` exceeds ``self.max_depth``; otherwise
        ``(node, visited)`` where ``visited`` contains the
        ``(row, column, direction)`` triples seen on this branch and its children

    NOTE(review): the statement nesting in this function was reconstructed from
    a whitespace-collapsed source -- confirm against the repository version.
    """

    # [Recursive branch opened]
    if depth >= self.max_depth + 1:
        return [], []

    # Continue along direction until next switch or
    # until no transitions are possible along the current direction (i.e., dead-ends)
    # We treat dead-ends as nodes, instead of going back, to avoid loops
    exploring = True
    last_is_switch = False
    last_is_dead_end = False
    last_is_terminal = False  # wrong cell OR cycle; either way, we don't want the agent to land here
    last_is_target = False

    visited = OrderedSet()
    agent = self.env.agents[handle]
    time_per_cell = np.reciprocal(agent.speed_data["speed"])
    # In this variant the three "encountered" features below are counters /
    # flags starting at 0, not distances.
    own_target_encountered = 0
    other_agent_encountered = 0
    other_target_encountered = 0
    potential_conflict = np.inf  # never updated below; stays at infinity
    unusable_switch = np.inf
    other_agent_same_direction = 0
    other_agent_opposite_direction = 0
    malfunctioning_agent = 0
    min_fractional_speed = 1.
    num_steps = 1
    total_cells = 0
    total_switch = 0
    total_switches_neighbors = 0
    other_agent_ready_to_depart_encountered = 0
    # The next three stay 0 until first set; afterwards they hold +1.0 / -1.0.
    index_comparision = 0
    first_switch_free = 0
    first_switch_neighbor = 0
    while exploring:
        # Count the cells, switches and switch neighbours seen on this branch.
        total_cells += 1
        if len(self.switches_list) > 0:
            if position in self.switches_list:
                total_switch += 1
        if len(self.switches_neightbors_list) > 0:
            if position in self.switches_neightbors_list:
                total_switches_neighbors += 1
        # Agent index stored for this cell; the checks below treat values
        # greater than -1 as "an agent is here".
        opp_a = self.env.agent_positions[position]
        # For the first other agent met: +1.0 if its handle is lower than ours, -1.0 otherwise.
        if opp_a > -1 and opp_a != handle and index_comparision == 0:
            index_comparision = 2.0 * int(opp_a < handle) - 1.0
        # First switch on the branch: 1.0 when an agent index is stored there, -1.0 when not.
        if first_switch_free == 0:
            if position in self.switches_list:
                if opp_a > -1:
                    first_switch_free = 1.0
                else:
                    first_switch_free = -1.0
        # Same encoding for the first switch-neighbour cell on the branch.
        if first_switch_neighbor == 0:
            if position in self.switches_neightbors_list:
                if opp_a > -1:
                    first_switch_neighbor = 1.0
                else:
                    first_switch_neighbor = -1.0
        # #############################
        # #############################
        # Modify here to compute any useful data required to build the end node's features. This code is called
        # for each cell visited between the previous branching node and the next switch / target / dead-end.
        if position in self.location_has_agent:
            # Check if any of the observed agents is malfunctioning, store agent with longest duration left
            if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                malfunctioning_agent = self.location_has_agent_malfunction[position]
            other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0)
            if self.location_has_agent_direction[position] == direction:
                # Cumulate the number of agents on branch with same direction
                # other_agent_same_direction += self.location_has_agent_direction.get((position, direction), 0)
                # Check fractional speed of agents
                current_fractional_speed = self.location_has_agent_speed[position]
                if current_fractional_speed < min_fractional_speed:
                    min_fractional_speed = current_fractional_speed
        # possible_transitions = self.envs.rail.get_transitions(*position, direction)
        # Count other agents, other agents' targets and the own target along the branch.
        x = self.env.agent_positions[position]
        if x != -1:
            if x != handle:
                other_agent_encountered += 1.0
        if position in self.location_has_target and position != agent.target:
            other_target_encountered += 1.0
        if position == agent.target:
            own_target_encountered = 1.0
        # `num_transitions` is pinned to 1 here, so the block below always runs;
        # it is recomputed from the cell transitions further down.
        num_transitions = 1  # np.count_nonzero(possible_transitions)
        if num_transitions == 1:
            new_direction_me = direction  # np.argmax(possible_transitions)
            new_cell_me = position  # get_new_position(position, new_direction_me)
            a = self.env.agent_positions[new_cell_me]
            if a != -1 and a != handle:
                opp_agent = self.env.agents[a]
                # look one step forward
                possible_transitions = self.env.rail.get_transitions(*opp_agent.position, opp_agent.direction)
                # If the other agent cannot move in our direction it is counted
                # as opposing, and the walk along this branch ends here.
                if possible_transitions[new_direction_me] == 0:
                    self.opposite.append(position)
                    other_agent_opposite_direction += 1
                    break  # path is broken
                else:
                    other_agent_same_direction += 1
        # Check number of possible transitions for agent and total number of transitions in cell (type)
        cell_transitions = self.env.rail.get_transitions(*position, direction)
        transition_bit = bin(self.env.rail.get_full_transitions(*position))
        total_transitions = transition_bit.count("1")
        crossing_found = False
        if int(transition_bit, 2) == int('1000010000100001', 2):
            crossing_found = True

        # Register possible future conflict
        # (computed but not used any further in this variant)
        predicted_time = int(tot_dist * time_per_cell)

        # #############################
        # #############################
        # A repeated (cell, direction) state means the walk has entered a cycle.
        if (position[0], position[1], direction) in visited:
            last_is_terminal = True
            break
        visited.add((position[0], position[1], direction))

        # If the target node is encountered, pick that as node. Also, no further branching is possible.
        if np.array_equal(position, self.env.agents[handle].target):
            last_is_target = True
            break

        # Check if crossing is found --> Not an unusable switch
        if crossing_found:
            # Treat the crossing as a straight rail cell
            total_transitions = 2
        num_transitions = np.count_nonzero(cell_transitions)

        exploring = False

        # Detect Switches that can only be used by other agents.
        if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
            unusable_switch = tot_dist

        if num_transitions == 1:
            # Check if dead-end, or if we can go forward along direction
            nbits = total_transitions
            if nbits == 1:
                # Dead-end!
                last_is_dead_end = True

            if not last_is_dead_end:
                # Keep walking through the tree along `direction`
                exploring = True
                # convert one-hot encoding to 0,1,2,3
                direction = np.argmax(cell_transitions)
                position = get_new_position(position, direction)
                num_steps += 1
                tot_dist += 1
        elif num_transitions > 0:
            # Switch detected
            last_is_switch = True
            break

        elif num_transitions == 0:
            # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
            print(
                "WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell",
                position[0], position[1], direction)
            last_is_terminal = True
            break

    # `position` is either a terminal node or a switch

    # #############################
    # #############################
    # Modify here to append new / different features for each visited cell!

    if last_is_target:
        dist_to_next_branch = tot_dist
        dist_min_to_target = 0
    elif last_is_terminal:
        dist_to_next_branch = np.inf
        dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
    else:
        dist_to_next_branch = tot_dist
        dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]

    node = MyNode(
        dist_own_target_encountered=own_target_encountered,
        dist_other_target_encountered=other_target_encountered,
        dist_other_agent_encountered=other_agent_encountered,
        dist_potential_conflict=potential_conflict,
        dist_unusable_switch=unusable_switch,
        dist_to_next_branch=dist_to_next_branch,
        dist_min_to_target=dist_min_to_target,
        num_agents_same_direction=other_agent_same_direction,
        num_agents_opposite_direction=other_agent_opposite_direction,
        num_agents_malfunctioning=malfunctioning_agent,
        speed_min_fractional=min_fractional_speed,
        num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
        total_cells=total_cells,
        total_switch=total_switch,
        total_switches_neighbors=total_switches_neighbors,
        index_comparision=index_comparision,
        first_switch_free=first_switch_free,
        first_switch_neighbor=first_switch_neighbor,
        childs={})

    # #############################
    # #############################
    # Start from the current orientation, and see which transitions are available;
    # organize them as [left, forward, right, back], relative to the current orientation
    # Get the possible transitions
    possible_transitions = self.env.rail.get_transitions(*position, direction)
    for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
        if last_is_dead_end and self.env.rail.get_transition(
                (*position, direction), (branch_direction + 2) % 4):
            # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
            # it back
            new_cell = get_new_position(position, (branch_direction + 2) % 4)
            branch_observation, branch_visited = self._explore_branch(
                handle, new_cell, (branch_direction + 2) % 4, tot_dist + 1, depth + 1)
            node.childs[self.tree_explored_actions_char[i]] = branch_observation
            if len(branch_visited) != 0:
                visited |= branch_visited
        elif last_is_switch and possible_transitions[branch_direction]:
            new_cell = get_new_position(position, branch_direction)
            branch_observation, branch_visited = self._explore_branch(
                handle, new_cell, branch_direction, tot_dist + 1, depth + 1)
            node.childs[self.tree_explored_actions_char[i]] = branch_observation
            if len(branch_visited) != 0:
                visited |= branch_visited
        else:
            # no exploring possible, add just cells with infinity
            node.childs[self.tree_explored_actions_char[i]] = -np.inf

    # Nodes at the maximum depth are leaves: drop their child entries.
    if depth == self.max_depth:
        node.childs.clear()
    return node, visited
def a_star(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D,
           a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, avoid_rails=False,
           respect_transition_validity=True, forbidden_cells: IntVector2DArray = None) -> IntVector2DArray:
    """
    A* search on the grid from ``start`` to ``end`` over the four axis-aligned neighbours.

    :param grid_map: Grid Map where the path is found in
    :param start: Start position as (row,column)
    :param end: End position as (row,column)
    :param a_star_distance_function: Distance function to use as heuristic:
        -get_euclidean_distance
        -get_manhattan_distance
        -get_chebyshev_distance
    :param avoid_rails: If True, the heuristic of a cell is increased by its grid value clipped to [0, 1]
    :param respect_transition_validity: Whether or not a-star respects allowed transitions on the grid map.
        - True: Only moves accepted by ``grid_map.validate_new_transition`` are expanded
        - False: Moves are expanded regardless of that validation; the path might be illegal and thus
          needs to be fixed afterwards
    :param forbidden_cells: List of cells where the path cannot pass through. Used to avoid certain areas of Grid map
    :return: Ordered list of all cells on the path from start to end, or an empty list if no path is found
    """
    grid_shape = grid_map.grid.shape

    start_node = AStarNode(start, None)
    end_node = AStarNode(end, None)
    open_nodes = OrderedSet()
    closed_nodes = OrderedSet()
    open_nodes.add(start_node)

    while len(open_nodes) > 0:
        # Expand the open node with the lowest estimated total cost f
        # (the first such node in iteration order on ties).
        current_node = min(open_nodes, key=lambda candidate: candidate.f)
        open_nodes.remove(current_node)
        closed_nodes.add(current_node)

        # Goal reached: follow the parent links back to the start.
        if current_node == end_node:
            backtrack = []
            cursor = current_node
            while cursor is not None:
                backtrack.append(cursor.pos)
                cursor = cursor.parent
            return backtrack[::-1]

        parent = current_node.parent
        prev_pos = parent.pos if parent is not None else None

        # Collect the admissible neighbours of the current cell.
        children = []
        for offset in ((0, -1), (0, 1), (-1, 0), (1, 0)):
            node_pos: IntVector2D = Vec2d.add(current_node.pos, offset)

            # Skip positions outside the grid.
            if not (0 <= node_pos[0] < grid_shape[0] and 0 <= node_pos[1] < grid_shape[1]):
                continue

            # Skip moves rejected by the grid map when validity has to be respected.
            if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos,
                                                    end_node.pos) and respect_transition_validity:
                continue

            new_node = AStarNode(node_pos, current_node)

            # Forbidden cells are skipped, except for the start and the end cell.
            if (forbidden_cells is not None and node_pos in forbidden_cells
                    and new_node != start_node and new_node != end_node):
                continue

            children.append(new_node)

        for child in children:
            if child in closed_nodes:
                continue

            # g: cost so far, h: heuristic to the goal, f: their sum.
            child.g = current_node.g + 1.0
            heuristic = a_star_distance_function(child.pos, end_node.pos)
            if avoid_rails:
                heuristic = heuristic + np.clip(grid_map.grid[child.pos], 0, 1)
            child.h = heuristic
            child.f = child.g + child.h

            if child in open_nodes:
                continue
            open_nodes.add(child)

    # The open set ran empty without reaching the goal: no full path found.
    return []
class RailEnvTransitions(Grid4Transitions):
    """
    Special case of `GridTransitions` over a 2D-grid, with a pre-defined set of
    transitions mimicking the types of real Swiss rail connections.

    As no diagonal transitions are allowed in the RailEnv environment, the
    possible transitions for RailEnv from a cell to its neighboring ones
    are represented over 16 bits.

    The 16 bits are organized in 4 blocks of 4 bits each, one block per direction
    that the agent is facing.
    E.g., the most-significant 4-bits represent the possible movements (NESW)
    if the agent is facing North, etc...

    agent's direction:          North    East   South   West
    agent's allowed movements:  [nesw]   [nesw] [nesw]  [nesw]
    example:                     1000     0000   0010    0000

    In the example, the agent can move from North to South and viceversa.
    """

    # Contains the basic transitions;
    # the set of all valid transitions is obtained by successive 90-degree rotation
    # of one of these basic transitions.
    transition_list = [
        int('0000000000000000', 2),  # Case 0 - empty cell
        int('1000000000100000', 2),  # Case 1 - straight
        int('1001001000100000', 2),  # Case 2 - simple switch
        int('1000010000100001', 2),  # Case 3 - diamond crossing
        int('1001011000100001', 2),  # Case 4 - single slip
        int('1100110000110011', 2),  # Case 5 - double slip
        int('0101001000000010', 2),  # Case 6 - symmetrical
        int('0010000000000000', 2),  # Case 7 - dead end
        int('0100000000000010', 2),  # Case 1b (8) - simple turn right
        int('0001001000000000', 2),  # Case 1c (9) - simple turn left
        int('1100000000100010', 2),  # Case 2b (10) - simple switch mirrored
    ]

    def __init__(self):
        super().__init__(transitions=self.transition_list)

        # Number of successive 90-degree rotations registered per basic case.
        # Cases that are not listed (0 and 3) are only registered unrotated.
        rotations_per_case = {
            1: 1, 5: 1,
            2: 3, 4: 3, 6: 3, 7: 3, 8: 3, 9: 3, 10: 3,
        }

        # pre-computed set of every accepted cell value to make validation faster
        self.transitions_all = OrderedSet()
        for case, transition in enumerate(self.transitions):
            self.transitions_all.add(transition)
            for _ in range(rotations_per_case.get(case, 0)):
                transition = self.rotate_transition(transition, rotation=90)
                self.transitions_all.add(transition)

    def print(self, cell_transition):
        """
        Print the transitions of a cell as four rows of four bits.

        Each row is labelled with the direction the agent is facing (N, E, S, W) and
        shows the 4-bit block of allowed movements for that direction.
        """
        print(" NESW")
        # the block for North sits in the most significant nibble, West in the lowest
        for block, facing in enumerate("NESW"):
            shift = (3 - block) * 4
            print(facing, format((cell_transition >> shift) & 0xF, '04b'))

    def is_valid(self, cell_transition):
        """
        Checks if a cell transition is a valid cell setup.

        Parameters
        ----------
        cell_transition : int
            16 bits used to encode the valid transitions for a cell.

        Returns
        -------
        Boolean
            True if the value is one of the basic transitions or one of their
            registered rotations, False otherwise
        """
        return cell_transition in self.transitions_all
def get(self, handle: int = 0):
    """
    Build the root node observation for one agent.

    The observation combines the distance to the target, one-hot markers
    ([left, forward, right] relative to the agent's direction) for the shortest and the
    next shortest branch, malfunction and speed data of the agent and its predictions.
    As a side effect the first 10 predicted cells of the agent are stored in
    `self.env.dev_obs_dict[handle]` for the renderer.

    :param handle: Index of the agent in `self.env.agents`
    :return: A `LocalConflictObsForRailEnv.Node`, or None if the agent status is none of
        READY_TO_DEPART, ACTIVE or DONE
    """
    agent = self.env.agents[handle]

    if agent.status == RailAgentStatus.READY_TO_DEPART:
        agent_virtual_position = agent.initial_position
    elif agent.status == RailAgentStatus.ACTIVE:
        agent_virtual_position = agent.position
    elif agent.status == RailAgentStatus.DONE:
        agent_virtual_position = agent.target
    else:
        return None

    # transitions are looked up at the current cell, or at the initial cell while the
    # agent has no position
    if agent.position:
        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
    else:
        possible_transitions = self.env.rail.get_transitions(*agent.initial_position, agent.direction)
    num_transitions = np.count_nonzero(possible_transitions)

    distance_map = self.env.distance_map.get()
    max_distance = self.env.width + self.env.height

    # Collect the cells of the agent's own predictions for the observation rendering.
    # NOTE(review): the hard-coded 10 assumes at least 10 predicted steps per agent.
    own_predictions = self.predictions[handle]
    visited = OrderedSet()
    for _idx in range(10):
        x_coord = own_predictions[_idx][1]
        y_coord = own_predictions[_idx][2]
        visited.add((x_coord, y_coord))

    # This variable will be accessed by the renderer to visualize the observation
    self.env.dev_obs_dict[handle] = visited

    # Distance to the target over each branch, organized as [left, forward, right]
    # relative to the current orientation; unavailable branches get np.inf.
    min_distances = []
    for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
        if possible_transitions[direction]:
            new_position = get_new_position(agent_virtual_position, direction)
            min_distances.append(distance_map[handle, new_position[0], new_position[1], direction])
        else:
            min_distances.append(np.inf)

    # One-hot markers of the shortest and the next shortest branch. They default to
    # "no branch" so they are always defined, whatever the number of transitions is.
    observation1 = [0, 0, 0]
    observation2 = [0, 0, 0]
    if num_transitions == 1:
        # if only one transition is possible, the forward branch is aligned with it
        observation1 = [0, 1, 0]
        observation2 = observation1
    elif num_transitions >= 2:
        # A full (stable) sort is needed here: a partition around the last element
        # would leave the order of the two smallest entries unspecified.
        order = np.argsort(np.array(min_distances), kind="stable")
        observation1[order[0]] = 1
        observation2[order[1]] = 1

    # extra distance of each branch compared to the next better one; differences
    # involving np.inf (inf or nan) are reported as 0
    incremental_distances = np.diff(np.sort(min_distances))
    incremental_distances[incremental_distances == np.inf] = 0
    incremental_distances[np.isnan(incremental_distances)] = 0

    distance_target = distance_map[(handle, *agent_virtual_position, agent.direction)]

    root_node_observation = LocalConflictObsForRailEnv.Node(
        distance_target=distance_target / max_distance,
        observation_shortest=observation1,
        observation_next_shortest=observation2,
        extra_distance=incremental_distances[0] / max_distance,
        malfunction=agent.malfunction_data['malfunction'] / max_distance,
        malfunction_rate=agent.malfunction_data['malfunction_rate'],
        next_malfunction=agent.malfunction_data['next_malfunction'] / max_distance,
        nr_malfunctions=agent.malfunction_data['nr_malfunctions'],
        speed=agent.speed_data['speed'],
        position_fraction=agent.speed_data['position_fraction'],
        transition_action_on_cellexit=agent.speed_data['transition_action_on_cellexit'],
        num_transitions=num_transitions,
        moving=agent.moving,
        status=agent.status,
        action_required=action_required(agent),
        width=self.env.width,
        height=self.env.height,
        n_agents=self.get_number_of_agents(),
        predictions=self.predictions[handle],
        predicted_pos=self.predicted_pos)

    return root_node_observation