def get(self, handle):
    """
    Build the shortest-path and deviation-path predictions for one agent.

    As a side effect the positions of the shortest-path prediction are stored
    in `self.env.dev_pred_dict[handle]` so that a GUI can display them.

    :param handle: index of the agent in `self.env.agents`
    :return: `None` if the agent's status is DONE or DONE_REMOVED, otherwise
        the tuple `(shortest_path_prediction, deviation_paths_prediction)`
    """
    status = self.env.agents[handle].status
    if status in (RailAgentStatus.DONE_REMOVED, RailAgentStatus.DONE):
        return None

    # Shortest path, truncated to the prediction horizon for the encodings.
    # NOTE: `lenght` is the (misspelled) keyword expected by `Prediction`.
    path_length, full_path = self.get_shortest_path(handle)
    horizon = self.max_depth
    path_edges = self.env.railway_encoding.edges_from_path(full_path[:horizon])
    path_positions = self.env.railway_encoding.positions_from_path(
        full_path[:horizon]
    )
    shortest_path_prediction = Prediction(
        lenght=path_length,
        path=full_path[:horizon],
        edges=path_edges,
        positions=path_positions,
    )

    # Deviation paths are computed from the full (untruncated) shortest path.
    deviation_paths_prediction = self.get_deviation_paths(
        handle, path_length, full_path
    )

    # Expose the predicted cells to the renderer.
    predicted_cells = OrderedSet()
    predicted_cells.update(shortest_path_prediction.positions)
    self.env.dev_pred_dict[handle] = predicted_cells

    return (shortest_path_prediction, deviation_paths_prediction)
def check_path_exists(self, start: IntVector2DArray, direction: int,
                      end: IntVector2DArray):
    """
    Search for any path from an oriented start cell to a target cell.

    The frontier is a list used as a stack (last in, first out), so the
    traversal order is depth-first. States are `(position, direction)` pairs
    and each state is expanded at most once.

    :param start: Start cell from where we want to check the path
    :param direction: Start direction for the path we are testing
    :param end: Cell that we try to reach from the start cell
    :return: True if a path exists, False otherwise
    """
    expanded = OrderedSet()
    frontier = [(start, direction)]

    while frontier:
        state = frontier.pop()
        cell, heading = state

        if Vec2d.is_equal(cell, end):
            return True
        if state in expanded:
            continue
        expanded.add(state)

        # Push every neighbour reachable from this cell with this heading.
        allowed = self.get_transitions(cell[0], cell[1], heading)
        for next_heading in range(4):
            if allowed[next_heading]:
                frontier.append(
                    (get_new_position(cell, next_heading), next_heading))

    return False
def is_simple_turn(trans):
    """
    Tell whether `trans` is one of the rotated simple-turn transitions.

    The set of accepted values is built by rotating the two basic simple-turn
    encodings (cases 1b and 1c) by 90, 180 and 270 degrees.

    NOTE(review): `self` is not a parameter; this function reads
    `self.transitions` from an enclosing scope, so it presumably is a nested
    helper inside a method - confirm against the surrounding code.

    Fix: the original loop reused the name `trans` for the loop variable,
    overwriting the argument. The final membership test therefore always
    checked the last rotated value (which had just been inserted) and the
    function returned True for every input.

    :param trans: cell transition bit field to classify
    :return: True if `trans` is in the set of rotated simple turns
    """
    basic_turns = (
        int('0100000000000010', 2),  # Case 1b (8) - simple turn right
        int('0001001000000000', 2),  # Case 1c (9) - simple turn left
    )

    all_simple_turns = OrderedSet()
    for rotated in basic_turns:
        for _ in range(3):
            rotated = self.transitions.rotate_transition(rotated, rotation=90)
            all_simple_turns.add(rotated)

    return trans in all_simple_turns
def __init__(self):
    """
    Initialise the base class with `transition_list` and precompute
    `self.transitions_all`, the set used to validate cell transitions.

    Each basic transition is added as is. Cases 2, 4, 6, 7, 8, 9 and 10 are
    additionally added in their three further 90-degree rotations, cases 1
    and 5 in a single further 90-degree rotation; every other case
    contributes only its basic form.
    """
    super(RailEnvTransitions, self).__init__(transitions=self.transition_list)

    # Number of extra 90-degree rotations to register per basic case index.
    three_rotations = (2, 4, 6, 7, 8, 9, 10)
    one_rotation = (1, 5)

    # create this to make validation faster
    self.transitions_all = OrderedSet()
    for case_index, basic in enumerate(self.transitions):
        self.transitions_all.add(basic)

        if case_index in three_rotations:
            extra = 3
        elif case_index in one_rotation:
            extra = 1
        else:
            extra = 0

        rotated = basic
        for _ in range(extra):
            rotated = self.rotate_transition(rotated, rotation=90)
            self.transitions_all.add(rotated)
def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
                            agent_position: Tuple[int, int],
                            rail: GridTransitionMap) -> Set[RailEnvNextAction]:
    """
    Collect the move actions (forward, left, right) available to an agent.

    TODO https://gitlab.aicrowd.com/flatland/flatland/issues/299
        The implementation could probably be more efficient
        and more elegant. But given the few calls this has no priority now.

    Parameters
    ----------
    agent_direction : Grid4TransitionsEnum
    agent_position: Tuple[int,int]
    rail : GridTransitionMap

    Returns
    -------
    Set of `RailEnvNextAction` (tuples of (action,position,direction))
        Possible move actions (forward,left,right) and the next
        position/direction they lead to.
        It is not checked that the next cell is free.
    """
    found: Set[RailEnvNextAction] = OrderedSet()
    transitions = rail.get_transitions(*agent_position, agent_direction)
    open_count = np.count_nonzero(transitions)

    def register(move, heading):
        # Record `move` together with the cell and heading it leads to.
        target_cell = get_new_position(agent_position, heading)
        found.add(RailEnvNextAction(move, target_cell, heading))

    # Dead end: the only way out is backwards, reported as MOVE_FORWARD.
    if rail.is_dead_end(agent_position):
        reverse = (agent_direction + 2) % 4
        if transitions[reverse]:
            register(RailEnvActions.MOVE_FORWARD, reverse)
        return found

    # Candidate headings in the order left, forward, right relative to the
    # current orientation.
    for turn in (-1, 0, 1):
        heading = (agent_direction + turn) % 4
        if not transitions[heading]:
            continue

        if open_count == 1:
            # With a single possible transition the forward action is
            # aligned with it, whichever way it points.
            move = RailEnvActions.MOVE_FORWARD
        elif heading == agent_direction:
            move = RailEnvActions.MOVE_FORWARD
        elif heading == (agent_direction + 1) % 4:
            move = RailEnvActions.MOVE_RIGHT
        elif heading == (agent_direction - 1) % 4:
            move = RailEnvActions.MOVE_LEFT
        else:
            raise Exception("Illegal state")

        register(move, heading)

    return found
def get(self, handle: int = 0) -> np.ndarray:
    """
    Build a simple observation that flags, per prediction step, whether the
    agent's own predicted cell is also predicted for another agent.

    This is useless for the task of navigation but might help when looking
    for conflicts. A more complex implementation can be found in the
    TreeObsForRailEnv class.

    Each agent receives an observation of length 10, where element `k`
    describes prediction step `k`:
        - 0 if no overlap is happening at that step
        - 1 if another agent is predicted on the same cell at that step

    Fix: the original wrote the flag to `observation[handle]`, i.e. it
    indexed the per-step vector by the agent handle. That marked the wrong
    slot and raised IndexError for handles >= 10. The flag is now written to
    the slot of the prediction step being inspected.

    Side effect: the cells considered are stored in
    `self.env.dev_obs_dict[handle]` for the renderer.

    :param handle: handled as an index of an agent
    :return: Observation of handle
    """
    observation = np.zeros(10)

    # Track which cells were considered while building the observation and
    # make them accessible for rendering.
    visited = OrderedSet()

    for _idx in range(10):
        # Own predicted cell at this step (columns 1 and 2 of the prediction).
        x_coord = self.predictions[handle][_idx][1]
        y_coord = self.predictions[handle][_idx][2]
        visited.add((x_coord, y_coord))

        # Compare the own predicted position with those of all other agents
        # at the same predicted time step.
        others = np.delete(self.predicted_pos[_idx], handle, 0)
        if self.predicted_pos[_idx][handle] in others:
            observation[_idx] = 1

    # This variable will be accessed by the renderer to visualize the
    # observation.
    self.env.dev_obs_dict[handle] = visited

    return observation
def prediction_from_path(self, handle, path):
    """
    Turn a path into a per-time-step prediction array for one agent.

    Row `t` of the result is `[t, row, col, direction, action]`. Row 0
    describes the current state. The agent advances one path element every
    `times_per_cell` steps (derived from its speed); once the target is
    reached or the path is exhausted the remaining rows repeat the last
    position with `RailEnvActions.STOP_MOVING`.

    Change: the original also filled a local `visited` set that was never
    read, stored or returned; that dead bookkeeping has been removed.

    :param handle: index of the agent in `self.env.agents`
    :param path: sequence of path elements exposing `.position`,
        `.direction` and an element at index 2 with an `.action` attribute;
        the first element is treated as the agent's current cell and dropped
    :return: integer array of shape `(self.max_depth + 1, 5)`, or
        `self.empty_prediction` when the agent's status is none of
        READY_TO_DEPART, ACTIVE or DONE
    """
    agent = self.env.agents[handle]
    prediction = np.zeros(shape=(self.max_depth + 1, 5), dtype=int)

    if agent.status == RailAgentStatus.READY_TO_DEPART:
        agent_virtual_position = agent.initial_position
    elif agent.status == RailAgentStatus.ACTIVE:
        agent_virtual_position = agent.position
    elif agent.status == RailAgentStatus.DONE:
        agent_virtual_position = agent.target
    else:
        # agent.status == DONE_REMOVED, prediction must be None
        return self.empty_prediction

    agent_virtual_direction = agent.direction
    agent_speed = agent.speed_data["speed"]
    times_per_cell = int(np.reciprocal(agent_speed))

    # First row is the info relative to the actual time step.
    # TODO: the action stored here is a placeholder.
    prediction[0] = [0, *agent_virtual_position, agent_virtual_direction,
                     RailEnvActions.MOVE_FORWARD]

    # If there is a shortest path, remove the initial position.
    if path:
        path = path[1:]

    new_direction = agent_virtual_direction
    new_position = agent_virtual_position
    for index in range(1, self.max_depth + 1):
        # At the target or out of path: stand still until max_depth.
        # (`agent.moving` is deliberately not consulted here.)
        if new_position == agent.target or not path:
            prediction[index] = [index, *new_position, new_direction,
                                 RailEnvActions.STOP_MOVING]
            continue

        action = RailEnvActions.MOVE_FORWARD
        if index % times_per_cell == 0:
            # Enough steps elapsed to enter the next path element.
            next_step = path[0]
            new_position = next_step.position
            new_direction = next_step.direction
            action = next_step[2].action
            path = path[1:]

        prediction[index] = [index, *new_position, new_direction, action]

    return prediction
def get(self, handle: int = None):
    """
    Called whenever get_many in the observation build is called.
    Requires distance_map to extract the shortest path.
    Does not take into account future positions of other agents!

    If there is no shortest path, the agent just stands still and stops
    moving.

    Fix: the original tested `if handle:`, so `handle=0` was treated like
    "no handle given" and predictions for all agents were computed. The
    check is now `handle is not None`.

    Side effect: for every predicted agent the visited
    `(row, col, direction)` triples are stored in
    `self.env.dev_pred_dict[agent.handle]`.

    Parameters
    ----------
    handle : int, optional
        Handle of the agent for which to compute the prediction. When
        omitted, all agents of the environment are processed.

    Returns
    -------
    dict
        Dictionary indexed by the agent handle; for each agent an array of
        (max_depth + 1)x5 elements:
        - time_offset
        - position axis 0
        - position axis 1
        - direction
        - action taken to come here (not implemented yet)
        The prediction at 0 is the current position, direction etc.
    """
    agents = self.env.agents
    if handle is not None:
        agents = [self.env.agents[handle]]

    distance_map: DistanceMap = self.env.distance_map
    shortest_paths = get_shortest_paths(distance_map,
                                        max_depth=self.max_depth)

    prediction_dict = {}
    for agent in agents:
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            # Any other status: emit a placeholder prediction whose rows
            # 0 .. max_depth - 1 hold the time offset followed by None
            # entries; the last row keeps its zeros.
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            for i in range(self.max_depth):
                prediction[i] = [i, None, None, None, None]
            prediction_dict[agent.handle] = prediction
            continue

        agent_virtual_direction = agent.direction
        agent_speed = agent.speed_data["speed"]
        times_per_cell = int(np.reciprocal(agent_speed))

        prediction = np.zeros(shape=(self.max_depth + 1, 5))
        prediction[0] = [
            0, *agent_virtual_position, agent_virtual_direction, 0
        ]

        shortest_path = shortest_paths[agent.handle]
        # if there is a shortest path, remove the initial position
        if shortest_path:
            shortest_path = shortest_path[1:]

        new_direction = agent_virtual_direction
        new_position = agent_virtual_position
        visited = OrderedSet()
        for index in range(1, self.max_depth + 1):
            # if we're at the target (or out of path), stop moving until
            # max_depth is reached
            if new_position == agent.target or not shortest_path:
                prediction[index] = [
                    index, *new_position, new_direction,
                    RailEnvActions.STOP_MOVING
                ]
                visited.add((*new_position, agent.direction))
                continue

            # Advance one path element every `times_per_cell` steps.
            if index % times_per_cell == 0:
                new_position = shortest_path[0].position
                new_direction = shortest_path[0].direction
                shortest_path = shortest_path[1:]

            # prediction is ready
            prediction[index] = [index, *new_position, new_direction, 0]
            visited.add((*new_position, new_direction))

        # TODO: very bad side effect for visualization only: hand the
        # dev_pred_dict back instead of setting it on env!
        self.env.dev_pred_dict[agent.handle] = visited
        prediction_dict[agent.handle] = prediction

    return prediction_dict
def get(self, handle: int = None, handles=None, positions=None,
        directions=None):
    """
    Called whenever get_many in the observation build is called.
    Requires distance_map to extract the shortest path.
    Does not take into account future positions of other agents!

    If there is no shortest path, the agent just stands still and stops
    moving.

    Fix: the original tested `if handle:`, so `handle=0` was treated like
    "no handle given". The check is now `handle is not None`.

    The selected agents are deep-copied before the simulation, so the
    malfunction counter decremented below does not touch the environment's
    agents. Side effect: the visited `(row, col, direction)` triples are
    stored in `self.env.dev_pred_dict[agent.handle]`.

    Parameters
    ----------
    handle : int, optional
        Handle of the single agent to predict. Ignored when `handles` is
        given.
    handles : iterable of int, optional
        Handles of the agents to predict; also forwarded to
        `get_shortest_paths`.
    positions : dict, optional
        Per-handle start position overriding the agent's own position for
        the simulated steps (row 0 still uses the agent's position).
    directions : dict, optional
        Per-handle start direction; read for every handle that has a
        non-None entry in `positions`.

    Returns
    -------
    dict
        Dictionary indexed by the agent handle; for each agent an array of
        (max_depth + 1)x5 elements:
        - time_offset
        - position axis 0
        - position axis 1
        - direction
        - action taken to come here (not implemented yet)
        The prediction at 0 is the current position, direction etc.
    """
    agents = self.env.agents
    if handle is not None:
        agents = [self.env.agents[handle]]
    if handles is not None:
        agents = [self.env.agents[h] for h in handles]

    distance_map: DistanceMap = self.env.distance_map
    shortest_paths = get_shortest_paths(distance_map,
                                        handles=handles,
                                        max_depth=self.max_depth,
                                        branch_only=self.branch_only)

    prediction_dict = {}
    # Work on copies: the loop below mutates `agent.malfunction_data`.
    agents = deepcopy(agents)
    for agent in agents:
        if not agent.status == RailAgentStatus.ACTIVE and not agent.status == RailAgentStatus.READY_TO_DEPART:
            # Not on the grid / not about to depart: placeholder prediction
            # whose rows 0 .. max_depth - 1 hold the time offset followed by
            # None entries.
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            for i in range(self.max_depth):
                prediction[i] = [i, None, None, None, None]
            prediction_dict[agent.handle] = prediction
            continue

        agent_virtual_direction = agent.direction
        agent_virtual_position = agent.position
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position

        agent_speed = agent.speed_data["speed"]
        times_per_cell = int(np.reciprocal(agent_speed))

        prediction = np.zeros(shape=(self.max_depth + 1, 5))
        prediction[0] = [
            0, *agent_virtual_position, agent_virtual_direction, 0
        ]

        shortest_path = shortest_paths[agent.handle]
        # if there is a shortest path, remove the initial position
        if shortest_path:
            shortest_path = shortest_path[1:]

        # Optional caller-supplied starting state for the simulation.
        if positions is not None and positions.get(agent.handle, None) is not None:
            new_direction = directions[agent.handle]
            new_position = positions[agent.handle]
        else:
            new_direction = agent_virtual_direction
            new_position = agent_virtual_position

        visited = OrderedSet()
        for index in range(1, self.max_depth + 1):
            if self.branch_only:
                # Stop predicting as soon as the current cell offers more
                # than one transition.
                cell_transitions = self.env.rail.get_transitions(
                    *new_position, new_direction)
                if np.count_nonzero(cell_transitions) > 1:
                    break

            if not shortest_path:
                # Path exhausted: stand still.
                prediction[index] = [
                    index, *new_position, new_direction,
                    RailEnvActions.STOP_MOVING
                ]
                visited.add((*new_position, agent.direction))
                continue

            if agent.malfunction_data["malfunction"] > 0:
                # Consume one malfunction step on the copied agent; the
                # direction entry of this row is None.
                agent.malfunction_data["malfunction"] -= 1
                prediction[index] = [
                    index, *new_position, None, RailEnvActions.STOP_MOVING
                ]
                visited.add((*new_position, agent.direction))
                continue

            if new_position == agent.target:
                # Target reached: record it and end the prediction.
                prediction[index] = [
                    index, *new_position, new_direction,
                    RailEnvActions.STOP_MOVING
                ]
                visited.add((*new_position, agent.direction))
                break

            # Advance one path element every `times_per_cell` steps.
            if index % times_per_cell == 0:
                new_position = shortest_path[0].position
                new_direction = shortest_path[0].direction
                shortest_path = shortest_path[1:]

            # prediction is ready
            prediction[index] = [index, *new_position, new_direction, 0]
            visited.add((*new_position, new_direction))

        self.env.dev_pred_dict[agent.handle] = visited
        prediction_dict[agent.handle] = prediction

    return prediction_dict
def get_k_shortest_paths(env: RailEnv,
                         source_position: Tuple[int, int],
                         source_direction: int,
                         target_position=Tuple[int, int],
                         k: int = 1,
                         debug=False) -> List[Tuple[Waypoint]]:
    """
    Computes the k shortest paths using modified Dijkstra
    following pseudo-code https://en.wikipedia.org/wiki/K_shortest_path_routing
    In contrast to the pseudo-code in wikipedia, we do not allow for loopy
    paths.

    Path cost is the number of waypoints in the path.

    NOTE(review): `target_position=Tuple[int, int]` makes the typing object
    the *default value* rather than an annotation; it is kept unchanged here
    to preserve the signature - confirm whether an annotation was intended.

    Parameters
    ----------
    env :             RailEnv
    source_position:  Tuple[int,int]
    source_direction: int
    target_position:  Tuple[int,int]
    k :               int
        max number of shortest paths
    debug:            bool
        print debug statements

    Returns
    -------
    List[Tuple[WalkingElement]]
        We use tuples since we need the path elements to be hashable.
        We use a list of paths in order to keep the order of length.
    """
    # P: the shortest paths from s to t found so far
    shortest_paths: List[Tuple[Waypoint]] = []

    # count_u: how often a path ending in (row, col, direction) was expanded
    count = {(r, c, d): 0
             for r in range(env.height)
             for c in range(env.width)
             for d in range(4)}

    # B: candidate paths.
    # N.B. use OrderedSet to make result deterministic!
    heap: OrderedSet[Tuple[Waypoint]] = OrderedSet()
    heap.add((Waypoint(source_position, source_direction), ))

    while len(heap) > 0 and len(shortest_paths) < k:
        if debug:
            print("iteration heap={}, shortest_paths={}".format(
                heap, shortest_paths))

        # Take the cheapest candidate; on ties the first one in insertion
        # order wins.
        pu = min(heap, key=len)
        u: Waypoint = pu[-1]
        if debug:
            print(" looking at pu={}".format(pu))
        heap.remove(pu)

        urcd = (*u.position, u.direction)
        count[urcd] += 1

        if u.position == target_position:
            # Path ends at the target: it is one of the results.
            if debug:
                print(" found of length {} {}".format(len(pu), pu))
            shortest_paths.append(pu)
            continue

        if count[urcd] > k:
            continue

        possible_transitions = env.rail.get_transitions(*urcd)
        if debug:
            print(" looking at neighbors of u={}, transitions are {}".
                  format(u, possible_transitions))

        # Extend the path by every neighbour reachable from u.
        for new_direction in range(4):
            if debug:
                print(" looking at new_direction={}".format(
                    new_direction))
            if not possible_transitions[new_direction]:
                continue

            new_position = get_new_position(u.position, new_direction)
            if debug:
                print(" looking at neighbor v={}".format(
                    (*new_position, new_direction)))
            v = Waypoint(position=new_position, direction=new_direction)

            # CAVEAT: do not allow for loopy paths
            if v not in pu:
                heap.add(pu + (v, ))

    return shortest_paths
def _explore_branch(self, handle, position, direction, tot_dist, depth):
    """
    Utility function to compute tree-based observations.

    Walks along the rail from `position` in `direction` until a switch, the
    agent's target, a dead end, an already visited state or an invalid cell
    is reached, collecting per-branch statistics on the way. A `MyNode` is
    then built for the end of the walk and each possible continuation is
    explored recursively.

    NOTE(review): this block was reconstructed from whitespace-collapsed
    text; places where the original nesting could not be determined with
    certainty are marked below - confirm against the original file.

    :param handle: index of the agent the observation is built for
    :param position: cell at which the walk starts
    :param direction: heading at the start cell
    :param tot_dist: distance walked so far from the root, in cells
    :param depth: recursion depth of this branch
    :return: `([], [])` once `depth` exceeds `self.max_depth`, otherwise
        `(node, visited)` where `visited` holds the
        `(row, col, direction)` triples seen on this branch and below
    """
    # [Recursive branch opened]
    if depth >= self.max_depth + 1:
        return [], []

    # Continue along direction until next switch or
    # until no transitions are possible along the current direction (i.e., dead-ends)
    # We treat dead-ends as nodes, instead of going back, to avoid loops
    exploring = True
    last_is_switch = False
    last_is_dead_end = False
    last_is_terminal = False  # wrong cell OR cycle; either way, we don't want the agent to land here
    last_is_target = False

    visited = OrderedSet()
    agent = self.env.agents[handle]
    time_per_cell = np.reciprocal(agent.speed_data["speed"])

    # Accumulators for the node features. `potential_conflict` and
    # `num_steps` are never read for anything but being passed on / updated.
    own_target_encountered = 0
    other_agent_encountered = 0
    other_target_encountered = 0
    potential_conflict = np.inf
    unusable_switch = np.inf
    other_agent_same_direction = 0
    other_agent_opposite_direction = 0
    malfunctioning_agent = 0
    min_fractional_speed = 1.
    num_steps = 1
    total_cells = 0
    total_switch = 0
    total_switches_neighbors = 0
    other_agent_ready_to_depart_encountered = 0
    index_comparision = 0
    first_switch_free = 0
    first_switch_neighbor = 0

    while exploring:
        # Count walked cells, and how many of them are switches or
        # neighbours of switches.
        total_cells += 1
        if len(self.switches_list) > 0:
            if position in self.switches_list:
                total_switch += 1
        if len(self.switches_neightbors_list) > 0:
            if position in self.switches_neightbors_list:
                total_switches_neighbors += 1

        # Handle of the agent occupying this cell (-1 means none in the
        # comparisons below).
        opp_a = self.env.agent_positions[position]
        if opp_a > -1 and opp_a != handle and index_comparision == 0:
            # +1.0 if the first other agent met has a smaller handle than
            # ours, -1.0 otherwise.
            index_comparision = 2.0 * int(opp_a < handle) - 1.0

        # First switch on the branch: 1.0 if occupied, -1.0 if not.
        # NOTE(review): `else` assumed to pair with the innermost `if`.
        if first_switch_free == 0:
            if position in self.switches_list:
                if opp_a > -1:
                    first_switch_free = 1.0
                else:
                    first_switch_free = -1.0

        # Same for the first switch-neighbour cell on the branch.
        if first_switch_neighbor == 0:
            if position in self.switches_neightbors_list:
                if opp_a > -1:
                    first_switch_neighbor = 1.0
                else:
                    first_switch_neighbor = -1.0

        # #############################
        # #############################
        # Modify here to compute any useful data required to build the end node's features. This code is called
        # for each cell visited between the previous branching node and the next switch / target / dead-end.
        if position in self.location_has_agent:
            # Check if any of the observed agents is malfunctioning, store agent with longest duration left
            if self.location_has_agent_malfunction[
                    position] > malfunctioning_agent:
                malfunctioning_agent = self.location_has_agent_malfunction[
                    position]

            other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(
                position, 0)

            if self.location_has_agent_direction[position] == direction:
                # Cumulate the number of agents on branch with same direction
                # (disabled) other_agent_same_direction += self.location_has_agent_direction.get((position, direction), 0)

                # Check fractional speed of agents
                current_fractional_speed = self.location_has_agent_speed[
                    position]
                if current_fractional_speed < min_fractional_speed:
                    min_fractional_speed = current_fractional_speed

        # (disabled) possible_transitions = self.envs.rail.get_transitions(*position, direction)

        # Count cells occupied by an agent other than ourselves.
        # NOTE(review): placed at loop level; the original may have nested
        # this under `if position in self.location_has_agent` - confirm.
        x = self.env.agent_positions[position]
        if x != -1:
            if x != handle:
                other_agent_encountered += 1.0

        if position in self.location_has_target and position != agent.target:
            other_target_encountered += 1.0

        if position == agent.target:
            own_target_encountered = 1.0

        # The transition count is hard-coded to 1 here, so the block below
        # always runs and inspects the current cell itself.
        num_transitions = 1  # (disabled) np.count_nonzero(possible_transitions)
        if num_transitions == 1:
            new_direction_me = direction  # (disabled) np.argmax(possible_transitions)
            new_cell_me = position  # (disabled) get_new_position(position, new_direction_me)

            a = self.env.agent_positions[new_cell_me]
            if a != -1 and a != handle:
                opp_agent = self.env.agents[a]
                # look one step forward: can the other agent move in our
                # walking direction from where it stands?
                possible_transitions = self.env.rail.get_transitions(
                    *opp_agent.position, opp_agent.direction)
                if possible_transitions[new_direction_me] == 0:
                    self.opposite.append(position)
                    other_agent_opposite_direction += 1
                    break  # path is broken
                else:
                    other_agent_same_direction += 1

        # Check number of possible transitions for agent and total number of transitions in cell (type)
        cell_transitions = self.env.rail.get_transitions(
            *position, direction)
        transition_bit = bin(self.env.rail.get_full_transitions(*position))
        total_transitions = transition_bit.count("1")
        crossing_found = False
        if int(transition_bit, 2) == int('1000010000100001', 2):
            crossing_found = True

        # Register possible future conflict
        # (the value is computed but not used further in this function)
        predicted_time = int(tot_dist * time_per_cell)

        # #############################
        # #############################
        # Seeing the same (cell, direction) twice means we walked a cycle.
        if (position[0], position[1], direction) in visited:
            last_is_terminal = True
            break
        visited.add((position[0], position[1], direction))

        # If the target node is encountered, pick that as node. Also, no further branching is possible.
        if np.array_equal(position, self.env.agents[handle].target):
            last_is_target = True
            break

        # Check if crossing is found --> Not an unusable switch
        if crossing_found:
            # Treat the crossing as a straight rail cell
            total_transitions = 2
        num_transitions = np.count_nonzero(cell_transitions)

        exploring = False

        # Detect Switches that can only be used by other agents.
        if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
            unusable_switch = tot_dist

        if num_transitions == 1:
            # Check if dead-end, or if we can go forward along direction
            nbits = total_transitions
            if nbits == 1:
                # Dead-end!
                last_is_dead_end = True

            if not last_is_dead_end:
                # Keep walking through the tree along `direction`
                exploring = True
                # convert one-hot encoding to 0,1,2,3
                direction = np.argmax(cell_transitions)
                position = get_new_position(position, direction)
                num_steps += 1
                tot_dist += 1
        elif num_transitions > 0:
            # Switch detected
            last_is_switch = True
            break

        elif num_transitions == 0:
            # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
            print(
                "WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell",
                position[0], position[1], direction)
            last_is_terminal = True
            break

    # `position` is either a terminal node or a switch

    # #############################
    # #############################
    # Modify here to append new / different features for each visited cell!

    if last_is_target:
        dist_to_next_branch = tot_dist
        dist_min_to_target = 0
    elif last_is_terminal:
        dist_to_next_branch = np.inf
        dist_min_to_target = self.env.distance_map.get()[handle,
                                                         position[0],
                                                         position[1],
                                                         direction]
    else:
        dist_to_next_branch = tot_dist
        dist_min_to_target = self.env.distance_map.get()[handle,
                                                         position[0],
                                                         position[1],
                                                         direction]

    # Note: despite their `dist_` names, the three `*_encountered` values
    # passed here are the flags/counters accumulated above, not distances.
    node = MyNode(
        dist_own_target_encountered=own_target_encountered,
        dist_other_target_encountered=other_target_encountered,
        dist_other_agent_encountered=other_agent_encountered,
        dist_potential_conflict=potential_conflict,
        dist_unusable_switch=unusable_switch,
        dist_to_next_branch=dist_to_next_branch,
        dist_min_to_target=dist_min_to_target,
        num_agents_same_direction=other_agent_same_direction,
        num_agents_opposite_direction=other_agent_opposite_direction,
        num_agents_malfunctioning=malfunctioning_agent,
        speed_min_fractional=min_fractional_speed,
        num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
        total_cells=total_cells,
        total_switch=total_switch,
        total_switches_neighbors=total_switches_neighbors,
        index_comparision=index_comparision,
        first_switch_free=first_switch_free,
        first_switch_neighbor=first_switch_neighbor,
        childs={})

    # #############################
    # #############################
    # Start from the current orientation, and see which transitions are available;
    # organize them as [left, forward, right, back], relative to the current orientation
    # Get the possible transitions
    possible_transitions = self.env.rail.get_transitions(
        *position, direction)
    for i, branch_direction in enumerate([(direction + 4 + i) % 4
                                          for i in range(-1, 3)]):
        if last_is_dead_end and self.env.rail.get_transition(
            (*position, direction), (branch_direction + 2) % 4):
            # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
            # it back
            new_cell = get_new_position(position,
                                        (branch_direction + 2) % 4)
            branch_observation, branch_visited = self._explore_branch(
                handle, new_cell, (branch_direction + 2) % 4, tot_dist + 1,
                depth + 1)
            node.childs[
                self.tree_explored_actions_char[i]] = branch_observation
            if len(branch_visited) != 0:
                visited |= branch_visited
        elif last_is_switch and possible_transitions[branch_direction]:
            new_cell = get_new_position(position, branch_direction)
            branch_observation, branch_visited = self._explore_branch(
                handle, new_cell, branch_direction, tot_dist + 1,
                depth + 1)
            node.childs[
                self.tree_explored_actions_char[i]] = branch_observation
            if len(branch_visited) != 0:
                visited |= branch_visited
        else:
            # no exploring possible, add just cells with infinity
            node.childs[self.tree_explored_actions_char[i]] = -np.inf

    # Nodes at the maximum depth are leaves: drop their children.
    if depth == self.max_depth:
        node.childs.clear()
    return node, visited
class RailEnvTransitions(Grid4Transitions):
    """
    Special case of `GridTransitions` over a 2D-grid, with a pre-defined set
    of transitions mimicking the types of real Swiss rail connections.

    As no diagonal transitions are allowed in the RailEnv environment, the
    possible transitions for RailEnv from a cell to its neighboring ones
    are represented over 16 bits.

    The 16 bits are organized in 4 blocks of 4 bits each, one block per
    direction the agent is facing. The most-significant 4 bits hold the
    allowed movements (NESW) when the agent is facing North, the next block
    those when facing East, and so on.

    agent's direction:          North    East   South   West
    agent's allowed movements:  [nesw]   [nesw] [nesw]  [nesw]
    example:                     1000     0000   0010    0000

    In the example, an agent facing North may move North and an agent
    facing South may move South, i.e. a straight North-South track.
    """

    # The basic transitions, indexed by case number. The set of all valid
    # transitions is obtained by successive 90-degree rotations of these.
    transition_list = [
        0b0000_0000_0000_0000,  # Case 0 - empty cell
        0b1000_0000_0010_0000,  # Case 1 - straight
        0b1001_0010_0010_0000,  # Case 2 - simple switch
        0b1000_0100_0010_0001,  # Case 3 - diamond crossing
        0b1001_0110_0010_0001,  # Case 4 - single slip
        0b1100_1100_0011_0011,  # Case 5 - double slip
        0b0101_0010_0000_0010,  # Case 6 - symmetrical
        0b0010_0000_0000_0000,  # Case 7 - dead end
        0b0100_0000_0000_0010,  # Case 1b (8) - simple turn right
        0b0001_0010_0000_0000,  # Case 1c (9) - simple turn left
        0b1100_0000_0010_0010,  # Case 2b (10) - simple switch mirrored
    ]

    def __init__(self):
        """
        Initialise the base class with `transition_list` and precompute
        `self.transitions_all`, the lookup set used by `is_valid`.

        Cases 2, 4, 6, 7, 8, 9 and 10 are registered with their three
        further 90-degree rotations, cases 1 and 5 with one further
        rotation; every other case contributes only its basic form.
        """
        super(RailEnvTransitions, self).__init__(
            transitions=self.transition_list)

        three_rotations = (2, 4, 6, 7, 8, 9, 10)
        one_rotation = (1, 5)

        # create this to make validation faster
        self.transitions_all = OrderedSet()
        for case_index, basic in enumerate(self.transitions):
            self.transitions_all.add(basic)

            if case_index in three_rotations:
                extra = 3
            elif case_index in one_rotation:
                extra = 1
            else:
                extra = 0

            rotated = basic
            for _ in range(extra):
                rotated = self.rotate_transition(rotated, rotation=90)
                self.transitions_all.add(rotated)

    def print(self, cell_transition):
        """
        Print the four 4-bit blocks of `cell_transition`, one row per facing
        direction (N, E, S, W), most-significant block first.
        """
        print(" NESW")
        for row, facing in enumerate("NESW"):
            block = cell_transition >> ((3 - row) * 4) & 0xF
            print(facing, format(block, '04b'))

    def is_valid(self, cell_transition):
        """
        Checks if a cell transition is a valid cell setup.

        Parameters
        ----------
        cell_transition : int
            Bit field encoding the valid transitions for a cell (16 bits,
            see the class description).

        Returns
        -------
        Boolean
            True if `cell_transition` is one of the precomputed valid
            transitions, False otherwise.
        """
        return cell_transition in self.transitions_all
def get(self, handle: int = 0) -> Node:
    """
    Compute the tree observation for agent `handle`.

    The root node stores information about the agent itself; one child per
    relative direction (left, forward, right, back) is obtained from
    `self._explore_branch`, or set to `-np.inf` where no transition is
    possible. If exactly one transition is possible, the tree is oriented
    with this transition as the forward branch.

    Fix: the bounds check used `handle > len(agents)`, so the first invalid
    index (`handle == len(agents)`) was not reported; it now uses `>=`.
    The check only prints a diagnostic; an out-of-range handle still raises
    IndexError on the lookup below, as before.

    Side effect: the visited `(row, col, direction)` triples collected from
    the branches are stored in `self.env.dev_obs_dict[handle]`.

    :param handle: index of the agent in `self.env.agents`
    :return: the root `Node`, or `None` if the agent's status is none of
        READY_TO_DEPART, ACTIVE, DONE or DONE_REMOVED
    """
    if handle >= len(self.env.agents):
        print("ERROR: obs _get - handle ", handle, " len(agents)",
              len(self.env.agents))
    agent = self.env.agents[handle]

    if agent.status == RailAgentStatus.READY_TO_DEPART:
        agent_virtual_position = agent.initial_position
    elif agent.status == RailAgentStatus.ACTIVE:
        agent_virtual_position = agent.position
    elif agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
        agent_virtual_position = agent.target
    else:
        return None

    possible_transitions = self.env.rail.get_transitions(
        *agent_virtual_position, agent.direction)
    num_transitions = np.count_nonzero(possible_transitions)

    # Here information about the agent itself is stored
    distance_map = self.env.distance_map.get()

    root_node_observation = Node(
        dist_own_target_encountered=0,
        dist_other_target_encountered=0,
        dist_other_agent_encountered=0,
        dist_potential_conflict=0,
        dist_unusable_switch=0,
        dist_to_next_branch=0,
        dist_min_to_target=distance_map[(handle, *agent_virtual_position,
                                         agent.direction)],
        num_agents_same_direction=0,
        num_agents_opposite_direction=0,
        num_agents_malfunctioning=agent.malfunction_data['malfunction'],
        speed_min_fractional=agent.speed_data['speed'],
        num_agents_ready_to_depart=0,
        childs={})

    visited = OrderedSet()

    # Start from the current orientation, and see which transitions are
    # available; organize them as [left, forward, right, back], relative to
    # the current orientation.
    orientation = agent.direction
    if num_transitions == 1:
        orientation = np.argmax(possible_transitions)

    for i, branch_direction in enumerate([(orientation + i) % 4
                                          for i in range(-1, 3)]):
        if possible_transitions[branch_direction]:
            new_cell = get_new_position(agent_virtual_position,
                                        branch_direction)
            branch_observation, branch_visited = \
                self._explore_branch(handle, new_cell, branch_direction, 1, 1)
            root_node_observation.childs[
                self.tree_explored_actions_char[i]] = branch_observation
            visited |= branch_visited
        else:
            # add cells filled with infinity if no transition is possible
            root_node_observation.childs[
                self.tree_explored_actions_char[i]] = -np.inf

    self.env.dev_obs_dict[handle] = visited

    return root_node_observation
def get(self, handle: int = 0) -> (Dict[str, Node], np.ndarray):
    """
    Compute the tree observation for agent `handle`.

    The root node describes the agent itself; its `childs` dict holds one
    entry per explored movement, keyed by `self.tree_explored_actions_char`
    and ordered relative to the agent's orientation::

        [left, forward, right, back]

    Each child is either the node returned by `_explore_branch` for that
    direction or `-np.inf` when the transition is not possible.

    Per-node values described by the original authors (the authoritative
    field list is the `Node` definition):

    #1  distance at which the agent's own target lies on the branch
    #2  distance at which another agent's target is detected
    #3  distance at which another agent is detected
    #4  distance of a predicted potential conflict
    #5  distance to a switch the agent cannot use
    #6  distance to the next branching point
    #7  minimum distance from the node to the agent's target
    #8  number of agents travelling in the same direction
    #9  number of agents travelling in the opposite direction
    #10 remaining blocked time steps of a malfunctioning agent
    #11 slowest observed speed of an agent in the same direction
    #12 number of agents ready to depart but not yet active

    Side effects visible in this method: the visited cells are stored in
    `self.env.dev_obs_dict[handle]`, and the conflict handles reported by
    the child branch with the smallest `dist_min_to_target` are appended to
    `self._conflict_map[handle]`.

    :param handle: index of the agent in `self.env.agents`
    :return: the root `Node`, or None when the agent's status is none of
        READY_TO_DEPART, ACTIVE, DONE or DONE_REMOVED
    """
    # A handle equal to len(agents) is already out of range, hence `>=`.
    if handle >= len(self.env.agents):
        print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
    agent = self.env.agents[handle]  # TODO: handle being treated as index

    if agent.status == RailAgentStatus.READY_TO_DEPART:
        agent_virtual_position = agent.initial_position
    elif agent.status == RailAgentStatus.ACTIVE:
        agent_virtual_position = agent.position
    elif agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
        agent_virtual_position = agent.target
    else:
        return None

    possible_transitions = self.env.rail.get_transitions(
        *agent_virtual_position, agent.direction)
    num_transitions = np.count_nonzero(possible_transitions)

    # Here information about the agent itself is stored
    distance_map = self.env.distance_map.get()
    agent_not_started = int(
        agent.status.value == RailAgentStatus.READY_TO_DEPART.value)

    # TODO: this is hacky, find a better way
    root_node_observation = Node(
        dist_own_target_encountered=0,
        dist_other_target_encountered=0,
        own_target_encountered=0,
        shortest_path_direction=0,
        dist_other_agent_encountered=0,
        dist_potential_conflict=0,
        dist_unusable_switch=0,
        dist_to_next_branch=0,
        dist_min_to_target=distance_map[(handle, *agent_virtual_position, agent.direction)],
        num_agents_same_direction=0,
        num_agents_opposite_direction=0,
        num_agents_malfunctioning=agent.malfunction_data['malfunction'],
        num_agents_ready_to_depart=agent_not_started,
        childs={})

    visited = OrderedSet()

    # If only one transition is possible, orient the tree so that this
    # transition becomes the forward branch.
    orientation = agent.direction
    if num_transitions == 1:
        orientation = np.argmax(possible_transitions)

    min_dist = np.inf
    conflict_handles = []
    for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
        action_char = self.tree_explored_actions_char[i]
        if possible_transitions[branch_direction]:
            new_cell = get_new_position(agent_virtual_position, branch_direction)
            branch_observation, branch_visited, c_handles = self._explore_branch(
                handle, new_cell, branch_direction, 1, 1)
            root_node_observation.childs[action_char] = branch_observation
            visited |= branch_visited
            # Keep the conflicts reported along the branch closest to the target
            if branch_observation.dist_min_to_target < min_dist:
                min_dist = branch_observation.dist_min_to_target
                conflict_handles = c_handles
        else:
            # add cells filled with infinity if no transition is possible
            root_node_observation.childs[action_char] = -np.inf
    self.env.dev_obs_dict[handle] = visited

    for ch in conflict_handles:
        self._conflict_map[handle].append(ch)

    # Mark the shortest path through the tree. `_replace` returns a NEW
    # node, so the flagged copy must be written back into its parent's
    # `childs` dict; otherwise the flag never reaches the returned tree.
    level_childs = root_node_observation.childs
    for _ in range(self.max_depth):
        candidates = [
            (key, child) for key, child in level_childs.items()
            if child != -np.inf
        ]
        if len(candidates) < 1:
            break
        key, shortest_path_node = min(
            candidates, key=lambda item: item[1].dist_min_to_target)
        shortest_path_node = shortest_path_node._replace(
            shortest_path_direction=1.)
        level_childs[key] = shortest_path_node
        level_childs = shortest_path_node.childs

    return root_node_observation
def _explore_branch(self, handle, position, direction, tot_dist, depth):
    """
    Walk along one branch and build the tree node that terminates it.

    Starting at `position` heading in `direction`, the walk continues cell
    by cell until a switch, the agent's target, a dead-end, an already
    visited (cell, direction) pair or a cell without transitions is
    reached. While walking, the features documented in `get()` are
    collected. At a switch (or dead-end) each possible continuation is
    explored recursively and stored in the node's `childs`.

    :param handle: index of the observing agent in `self.env.agents`
    :param position: cell where the branch starts
    :param direction: heading when entering `position`
    :param tot_dist: distance (in cells) already travelled from the agent
    :param depth: current depth in the observation tree
    :return: `(node, visited)`; `([], [])` when `depth` exceeds
        `self.max_depth`
    """
    # [Recursive branch opened]
    if depth >= self.max_depth + 1:
        return [], []

    # Continue along direction until next switch or
    # until no transitions are possible along the current direction (i.e., dead-ends)
    # We treat dead-ends as nodes, instead of going back, to avoid loops
    exploring = True
    last_is_switch = False
    last_is_dead_end = False
    last_is_terminal = False  # wrong cell OR cycle; either way, we don't want the agent to land here
    last_is_target = False

    visited = OrderedSet()
    agent = self.env.agents[handle]
    time_per_cell = np.reciprocal(agent.speed_data["speed"])
    own_target_encountered = np.inf
    other_agent_encountered = np.inf
    other_target_encountered = np.inf
    potential_conflict = np.inf
    unusable_switch = np.inf
    other_agent_same_direction = 0
    other_agent_opposite_direction = 0
    malfunctioning_agent = 0
    min_fractional_speed = 1.
    other_agent_ready_to_depart_encountered = 0
    found_closest_communication = False
    communication = None

    while exploring:
        # #############################
        # Per-cell feature collection: executed for every cell between the
        # previous branching node and the next switch / target / dead-end.
        if position in self.location_has_agent:
            # Only the communication of the closest agent is kept
            if self.location_has_agent_communication[position] is not None \
                    and not found_closest_communication:
                found_closest_communication = True
                communication = self.location_has_agent_communication[position]

            if tot_dist < other_agent_encountered:
                other_agent_encountered = tot_dist

            # Check if any of the observed agents is malfunctioning, store agent with longest duration left
            if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                malfunctioning_agent = self.location_has_agent_malfunction[position]

            other_agent_ready_to_depart_encountered += \
                self.location_has_agent_ready_to_depart.get(position, 0)

            # The lookup tables are keyed by position only, so the recorded
            # direction decides whether the agents on this cell count as
            # same-direction or opposite-direction traffic.
            agents_on_cell = self.location_has_agent[position]
            if self.location_has_agent_direction[position] == direction:
                # Accumulate the number of agents on branch with same direction
                other_agent_same_direction += agents_on_cell

                # Check fractional speed of agents
                current_fractional_speed = self.location_has_agent_speed[position]
                if current_fractional_speed < min_fractional_speed:
                    min_fractional_speed = current_fractional_speed
            else:
                # No agent in the same direction: all agents on the cell go the other way
                other_agent_opposite_direction += agents_on_cell

        # Check number of possible transitions for agent and total number of transitions in cell (type)
        cell_transitions = self.env.rail.get_transitions(*position, direction)
        transition_bit = bin(self.env.rail.get_full_transitions(*position))
        total_transitions = transition_bit.count("1")
        crossing_found = int(transition_bit, 2) == int('1000010000100001', 2)

        # Register possible future conflict
        predicted_time = int(tot_dist * time_per_cell)
        if self.predictor and predicted_time < self.max_prediction_depth:
            int_position = coordinate_to_position(self.env.width, [position])
            if tot_dist < self.max_prediction_depth:
                pre_step = max(0, predicted_time - 1)
                post_step = min(self.max_prediction_depth - 1, predicted_time + 1)

                # Look for conflicting paths at the predicted time first, then
                # one step earlier, then one step later; only the first time
                # step at which another agent occupies the cell is evaluated.
                for step in (predicted_time, pre_step, post_step):
                    if int_position not in np.delete(self.predicted_pos[step], handle, 0):
                        continue
                    conflicting_agent = np.where(self.predicted_pos[step] == int_position)
                    for ca in conflicting_agent[0]:
                        if direction != self.predicted_dir[step][ca] \
                                and cell_transitions[self._reverse_dir(self.predicted_dir[step][ca])] == 1 \
                                and tot_dist < potential_conflict:
                            potential_conflict = tot_dist
                        if self.env.agents[ca].status == RailAgentStatus.DONE \
                                and tot_dist < potential_conflict:
                            potential_conflict = tot_dist
                    break

        if position in self.location_has_target and position != agent.target:
            if tot_dist < other_target_encountered:
                other_target_encountered = tot_dist

        if position == agent.target and tot_dist < own_target_encountered:
            own_target_encountered = tot_dist

        # #############################
        # A repeated (cell, direction) pair means we are walking in a cycle
        if (position[0], position[1], direction) in visited:
            last_is_terminal = True
            break
        visited.add((position[0], position[1], direction))

        # If the target node is encountered, pick that as node. Also, no further branching is possible.
        if np.array_equal(position, self.env.agents[handle].target):
            last_is_target = True
            break

        # Check if crossing is found --> Not an unusable switch
        if crossing_found:
            # Treat the crossing as a straight rail cell
            total_transitions = 2
        num_transitions = np.count_nonzero(cell_transitions)

        exploring = False

        # Detect Switches that can only be used by other agents.
        if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
            unusable_switch = tot_dist

        if num_transitions == 1:
            # Check if dead-end, or if we can go forward along direction
            if total_transitions == 1:
                # Dead-end!
                last_is_dead_end = True

            if not last_is_dead_end:
                # Keep walking through the tree along `direction`
                exploring = True
                # convert one-hot encoding to 0,1,2,3
                direction = np.argmax(cell_transitions)
                position = get_new_position(position, direction)
                tot_dist += 1
        elif num_transitions > 0:
            # Switch detected
            last_is_switch = True
            break
        elif num_transitions == 0:
            # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
            print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                  position[1], direction)
            last_is_terminal = True
            break

    # `position` is either a terminal node or a switch

    # #############################
    # Features of the node that ends this branch
    if last_is_target:
        dist_to_next_branch = tot_dist
        dist_min_to_target = 0
    elif last_is_terminal:
        dist_to_next_branch = np.inf
        dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
    else:
        dist_to_next_branch = tot_dist
        dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]

    node = CustomTreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered,
                                        dist_other_target_encountered=other_target_encountered,
                                        dist_other_agent_encountered=other_agent_encountered,
                                        dist_potential_conflict=potential_conflict,
                                        dist_unusable_switch=unusable_switch,
                                        dist_to_next_branch=dist_to_next_branch,
                                        dist_min_to_target=dist_min_to_target,
                                        num_agents_same_direction=other_agent_same_direction,
                                        num_agents_opposite_direction=other_agent_opposite_direction,
                                        num_agents_malfunctioning=malfunctioning_agent,
                                        speed_min_fractional=min_fractional_speed,
                                        num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
                                        communication=communication,
                                        childs={})

    # #############################
    # Start from the current orientation, and see which transitions are available;
    # organize them as [left, forward, right, back], relative to the current orientation
    possible_transitions = self.env.rail.get_transitions(*position, direction)
    for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
        action_char = self.tree_explored_actions_char[i]
        reversed_direction = (branch_direction + 2) % 4
        if last_is_dead_end and self.env.rail.get_transition((*position, direction), reversed_direction):
            # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
            # it back
            next_direction = reversed_direction
        elif last_is_switch and possible_transitions[branch_direction]:
            next_direction = branch_direction
        else:
            # no exploring possible, add just cells with infinity
            node.childs[action_char] = -np.inf
            continue

        new_cell = get_new_position(position, next_direction)
        branch_observation, branch_visited = self._explore_branch(handle,
                                                                  new_cell,
                                                                  next_direction,
                                                                  tot_dist + 1,
                                                                  depth + 1)
        node.childs[action_char] = branch_observation
        if len(branch_visited) != 0:
            visited |= branch_visited

    # Nodes at the maximum depth are leaves: drop the placeholder children
    if depth == self.max_depth:
        node.childs.clear()

    return node, visited
def a_star(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D,
           a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
           avoid_rails=False, respect_transition_validity=True,
           forbidden_cells: IntVector2DArray = None) -> IntVector2DArray:
    """
    Search a path from `start` to `end` on the grid with A*.

    :param grid_map: Grid map in which the path is searched
    :param start: Start position as (row, column)
    :param end: End position as (row, column)
    :param a_star_distance_function: Distance function used as heuristic:
        -get_euclidean_distance
        -get_manhattan_distance
        -get_chebyshev_distance
    :param avoid_rails: If True, the clipped grid value of a cell (0 or 1)
        is added to its heuristic
    :param respect_transition_validity: Whether or not a-star respects the
        allowed transitions on the grid map.
        - True: only steps accepted by `validate_new_transition` are taken
        - False: the validity result is ignored; the resulting path might be
          illegal and thus needs to be fixed afterwards
    :param forbidden_cells: Cells the path must not pass through (start and
        end cell excepted)
    :return: Ordered list of all cells of the path from start to end, or an
        empty list when the open list runs empty without reaching the end
    """
    n_rows, n_cols = rail_shape = grid_map.grid.shape[0], grid_map.grid.shape[1]
    start_node = AStarNode(start, None)
    end_node = AStarNode(end, None)
    open_nodes = OrderedSet()
    closed_nodes = OrderedSet()
    open_nodes.add(start_node)

    while len(open_nodes) > 0:
        # Expand the open node with the lowest f (first one wins on ties)
        current_node = min(open_nodes, key=lambda candidate: candidate.f)

        # Move it from the open list to the closed list
        open_nodes.remove(current_node)
        closed_nodes.add(current_node)

        # Goal reached: follow the parent links back to the start
        if current_node == end_node:
            path = []
            step = current_node
            while step is not None:
                path.append(step.pos)
                step = step.parent
            path.reverse()
            return path

        prev_pos = current_node.parent.pos if current_node.parent is not None else None

        # Collect the admissible neighbours
        children = []
        for offset in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
            node_pos: IntVector2D = Vec2d.add(current_node.pos, offset)

            # Neighbour must lie inside the grid
            if not (0 <= node_pos[0] < n_rows and 0 <= node_pos[1] < n_cols):
                continue

            # Neighbour must be reachable by a valid transition (if required)
            if not grid_map.validate_new_transition(
                    prev_pos, current_node.pos, node_pos,
                    end_node.pos) and respect_transition_validity:
                continue

            new_node = AStarNode(node_pos, current_node)

            # Skip paths through forbidden regions if they are provided
            if forbidden_cells is not None \
                    and node_pos in forbidden_cells \
                    and new_node != start_node \
                    and new_node != end_node:
                continue

            children.append(new_node)

        # Score the neighbours and queue the new ones
        for child in children:
            if child in closed_nodes:
                continue

            child.g = current_node.g + 1.0
            heuristic = a_star_distance_function(child.pos, end_node.pos)
            if avoid_rails:
                # this heuristic avoids diagonal paths
                heuristic = heuristic + np.clip(grid_map.grid[child.pos], 0, 1)
            child.h = heuristic
            child.f = child.g + child.h

            if child in open_nodes:
                continue
            open_nodes.add(child)

    # no full path found
    return []
def get(self, handle: int = 0) -> Node:
    """
    Compute the tree observation for agent `handle`.

    Before the tree is built, the per-cell lookup tables used by
    `_explore_branch` (`location_has_agent`, `..._direction`, `..._speed`,
    `..._malfunction`, `..._communication`, `..._ready_to_depart`) are
    rebuilt from the current state of all agents.

    The root node describes the agent itself; its `childs` dict holds one
    entry per explored movement, keyed by `self.tree_explored_actions_char`
    and ordered relative to the agent's orientation::

        [left, forward, right, back]

    Each child is either the node returned by `_explore_branch` or `-np.inf`
    when the transition is not possible.

    Per-node values described by the original authors:

    #1  distance at which the agent's own target lies on the branch
    #2  distance at which another agent's target is detected
    #3  distance at which another agent is detected
    #4  distance of a predicted potential conflict
    #5  distance to a switch the agent cannot use
    #6  distance to the next branching point
    #7  minimum distance from the node to the agent's target
    #8  number of agents travelling in the same direction
    #9  number of agents travelling in the opposite direction
    #10 remaining blocked time steps of a malfunctioning agent
    #11 slowest observed speed of an agent in the same direction
    #12 number of agents ready to depart but not yet active

    The visited cells are stored in `self.env.dev_obs_dict[handle]`.

    :param handle: index of the agent in `self.env.agents`
    :return: the root `Node`, or None when the agent's status is none of
        READY_TO_DEPART, ACTIVE or DONE
    """
    # Update local lookup tables for all agents' positions;
    # agents not on the grid are ignored (only status active and done)
    self.location_has_agent = {}
    self.location_has_agent_direction = {}
    self.location_has_agent_speed = {}
    self.location_has_agent_malfunction = {}
    self.location_has_agent_ready_to_depart = {}
    self.location_has_agent_communication = {}

    for _agent in self.env.agents:
        if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
                _agent.position:
            # Agents that never received a communication get an explicit None
            if not hasattr(_agent, 'communication'):
                _agent.communication = None
            cell = tuple(_agent.position)
            self.location_has_agent[cell] = 1
            self.location_has_agent_direction[cell] = _agent.direction
            self.location_has_agent_speed[cell] = _agent.speed_data['speed']
            self.location_has_agent_malfunction[cell] = _agent.malfunction_data['malfunction']
            self.location_has_agent_communication[cell] = _agent.communication

        if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
                _agent.initial_position:
            start_cell = tuple(_agent.initial_position)
            self.location_has_agent_ready_to_depart[start_cell] = \
                self.location_has_agent_ready_to_depart.get(start_cell, 0) + 1

    # A handle equal to len(agents) is already out of range, hence `>=`.
    if handle >= len(self.env.agents):
        print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
    agent = self.env.agents[handle]  # TODO: handle being treated as index

    if agent.status == RailAgentStatus.READY_TO_DEPART:
        agent_virtual_position = agent.initial_position
    elif agent.status == RailAgentStatus.ACTIVE:
        agent_virtual_position = agent.position
    elif agent.status == RailAgentStatus.DONE:
        agent_virtual_position = agent.target
    else:
        return None

    possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
    num_transitions = np.count_nonzero(possible_transitions)

    # Here information about the agent itself is stored
    distance_map = self.env.distance_map.get()

    root_node_observation = CustomTreeObsForRailEnv.Node(
        dist_own_target_encountered=0,
        dist_other_target_encountered=0,
        dist_other_agent_encountered=0,
        dist_potential_conflict=0,
        dist_unusable_switch=0,
        dist_to_next_branch=0,
        dist_min_to_target=distance_map[(handle, *agent_virtual_position, agent.direction)],
        num_agents_same_direction=0,
        num_agents_opposite_direction=0,
        num_agents_malfunctioning=agent.malfunction_data['malfunction'],
        speed_min_fractional=agent.speed_data['speed'],
        num_agents_ready_to_depart=0,
        communication=np.zeros(5),
        childs={})

    visited = OrderedSet()

    # Start from the current orientation, and see which transitions are available;
    # organize them as [left, forward, right, back], relative to the current orientation
    # If only one transition is possible, the tree is oriented with this transition as the forward branch.
    orientation = agent.direction
    if num_transitions == 1:
        orientation = np.argmax(possible_transitions)

    for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
        action_char = self.tree_explored_actions_char[i]
        if possible_transitions[branch_direction]:
            new_cell = get_new_position(agent_virtual_position, branch_direction)
            branch_observation, branch_visited = \
                self._explore_branch(handle, new_cell, branch_direction, 1, 1)
            root_node_observation.childs[action_char] = branch_observation
            visited |= branch_visited
        else:
            # add cells filled with infinity if no transition is possible
            root_node_observation.childs[action_char] = -np.inf

    self.env.dev_obs_dict[handle] = visited

    return root_node_observation
def get(self, handle: int = 0):
    """
    Build the local observation node for agent `handle`.

    The observation combines the normalised distance to the target, two
    one-hot vectors over [left, forward, right] marking the shortest and
    the next-shortest continuation, the extra distance of the alternative,
    the agent's malfunction and speed data, and the predictions held in
    `self.predictions` / `self.predicted_pos`.

    The first 10 predicted cells of the agent are stored in
    `self.env.dev_obs_dict[handle]` for rendering.

    :param handle: index of the agent in `self.env.agents`
    :return: a `LocalConflictObsForRailEnv.Node`, or None when the agent's
        status is none of READY_TO_DEPART, ACTIVE or DONE
    """
    agent = self.env.agents[handle]

    if agent.status == RailAgentStatus.READY_TO_DEPART:
        agent_virtual_position = agent.initial_position
    elif agent.status == RailAgentStatus.ACTIVE:
        agent_virtual_position = agent.position
    elif agent.status == RailAgentStatus.DONE:
        agent_virtual_position = agent.target
    else:
        return None

    if agent.position:
        possible_transitions = self.env.rail.get_transitions(
            *agent.position, agent.direction)
    else:
        possible_transitions = self.env.rail.get_transitions(
            *agent.initial_position, agent.direction)

    num_transitions = np.count_nonzero(possible_transitions)

    # Start from the current orientation, and see which transitions
    # are available; organize them as [left, forward, right], relative to
    # the current orientation.
    # If only one transition is possible, the forward branch is
    # aligned with it.
    distance_map = self.env.distance_map.get()
    max_distance = self.env.width + self.env.height

    visited = OrderedSet()
    for _idx in range(10):
        # Columns 1 and 2 of a prediction row are used as the cell coordinates
        x_coord = self.predictions[handle][_idx][1]
        y_coord = self.predictions[handle][_idx][2]

        # We add every observed cell to the observation rendering
        visited.add((x_coord, y_coord))

    # This variable will be accessed by the renderer to
    # visualize the observation
    self.env.dev_obs_dict[handle] = visited

    # min_distances stores the distance to target along the shortest path
    # and any alternate path if one exists (np.inf for impossible moves)
    min_distances = []
    for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
        if possible_transitions[direction]:
            new_position = get_new_position(agent_virtual_position, direction)
            min_distances.append(distance_map[handle, new_position[0], new_position[1], direction])
        else:
            min_distances.append(np.inf)

    # Default to "no direction" so that cells with an unexpected number of
    # transitions (0, or more than 2) do not leave the vectors undefined.
    observation1 = [0, 0, 0]
    observation2 = [0, 0, 0]
    if num_transitions == 1:
        observation1 = [0, 1, 0]
        observation2 = observation1
    elif num_transitions >= 2:
        # A full (stable) sort is required: a partition only guarantees the
        # position of the k-th element, not the order of the ones before it.
        idx = np.argsort(np.array(min_distances), kind='stable')
        observation1[idx[0]] = 1
        observation2[idx[1]] = 1

    # Extra distance of each alternative over the next-better one;
    # inf and nan differences (caused by impossible moves) are zeroed.
    incremental_distances = np.diff(np.sort(min_distances))
    incremental_distances[incremental_distances == np.inf] = 0
    incremental_distances[np.isnan(incremental_distances)] = 0

    distance_target = distance_map[(handle, *agent_virtual_position, agent.direction)]

    root_node_observation = LocalConflictObsForRailEnv.Node(
        distance_target=distance_target / max_distance,
        observation_shortest=observation1,
        observation_next_shortest=observation2,
        extra_distance=incremental_distances[0] / max_distance,
        malfunction=agent.malfunction_data['malfunction'] / max_distance,
        malfunction_rate=agent.malfunction_data['malfunction_rate'],
        next_malfunction=agent.malfunction_data['next_malfunction'] / max_distance,
        nr_malfunctions=agent.malfunction_data['nr_malfunctions'],
        speed=agent.speed_data['speed'],
        position_fraction=agent.speed_data['position_fraction'],
        transition_action_on_cellexit=agent.speed_data['transition_action_on_cellexit'],
        num_transitions=num_transitions,
        moving=agent.moving,
        status=agent.status,
        action_required=action_required(agent),
        width=self.env.width,
        height=self.env.height,
        n_agents=self.get_number_of_agents(),
        predictions=self.predictions[handle],
        predicted_pos=self.predicted_pos)

    return root_node_observation