def _compute_overlapping_paths_with_current_ts(self, handle):
    """
    Flag which cells of one agent's path also lie on other agents' paths.

    The agent's path is its entry in ``self.predicted_pos_list`` with its
    current cell prepended: ``agent.position`` when the agent is ACTIVE,
    ``agent.initial_position`` when it is READY_TO_DEPART. Only other agents
    whose status is ACTIVE are compared, and their current cell is prepended
    to their own path as well. The comparison ignores time: a cell counts as
    overlapping if it occurs anywhere in the other agent's path.

    :param handle: index of the agent in ``self.env.agents``.
    :return: int array of shape ``(num_agents, max_prediction_depth + 1)``;
        entry ``[a, i]`` is 1 when the i-th cell of this agent's path occurs
        in agent ``a``'s path, 0 otherwise. All zeros when the agent is
        neither ACTIVE nor READY_TO_DEPART.
    """
    agent = self.env.agents[handle]
    overlapping_paths = np.zeros(
        (self.env.get_num_agents(), self.max_prediction_depth + 1), dtype=int)

    # Prepend current ts
    if agent.status == RailAgentStatus.ACTIVE:
        virtual_position = agent.position
    elif agent.status == RailAgentStatus.READY_TO_DEPART:
        virtual_position = agent.initial_position
    else:
        # The agent has no cell to start from. This case used to fall through
        # with `virtual_position` unbound and raise UnboundLocalError.
        return overlapping_paths

    int_pos = coordinate_to_position(self.env.width, [virtual_position])
    cells_sequence = np.append(int_pos[0], self.predicted_pos_list[handle])

    for a, other_agent in enumerate(self.env.agents):
        if a == handle or other_agent.status != RailAgentStatus.ACTIVE:
            continue
        # Prepend other agents current ts
        other_int_pos = coordinate_to_position(
            self.env.width, [other_agent.position])
        other_agent_cells_sequence = np.append(
            other_int_pos[0], self.predicted_pos_list[a])
        for i, pos in enumerate(cells_sequence):
            if pos in other_agent_cells_sequence:
                overlapping_paths[a, i] = 1
    return overlapping_paths
def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:
    """
    Build the observations for every agent listed in `handles`.

    When a predictor is configured it is queried once up front; for each
    timestep the predicted cells (converted to integer positions) and the
    predicted directions of the requested agents are cached in
    ``self.predicted_pos`` / ``self.predicted_dir``. The observations
    themselves are produced by the parent class.

    :param handles: agent handles to observe; ``None`` means no agents.
    :return: whatever the parent ``get_many`` returns for `handles`.
    """
    handles = [] if handles is None else handles

    if self.predictor:
        self.max_prediction_depth = 0
        self.predicted_pos = {}
        self.predicted_dir = {}
        self.predictions = self.predictor.get()
        if self.predictions:
            for step in range(self.predictor.max_depth + 1):
                cells = []
                headings = []
                for handle in handles:
                    agent_prediction = self.predictions[handle]
                    # Agents without a prediction are left out of this step.
                    if agent_prediction is not None:
                        cells.append(agent_prediction[step][1:3])
                        headings.append(agent_prediction[step][3])
                self.predicted_pos[step] = coordinate_to_position(
                    self.env.width, cells)
                self.predicted_dir[step] = headings
            self.max_prediction_depth = len(self.predicted_pos)

    return super().get_many(handles)
def map_predictions(self, handles=None, positions=None, directions=None):
    """
    Query the predictor and cache its output per timestep.

    For every timestep ``t`` in ``0 .. predictor.max_depth`` this stores
    the integer-encoded predicted cells of the requested agents in
    ``self.predicted_pos[t]`` and their predicted directions in
    ``self.predicted_dir[t]``; ``self.max_prediction_depth`` ends up as the
    number of cached timesteps. Nothing is touched when no predictor is set.

    :param handles: agent handles to map; defaults to all agents of
        ``self.rail_env``.
    :param positions: forwarded unchanged to the predictor.
    :param directions: forwarded unchanged to the predictor.
    """
    if handles is None:
        handles = [agent.handle for agent in self.rail_env.agents]

    if not self._predictor:
        return

    self.max_prediction_depth = 0
    self.predicted_pos = {}
    self.predicted_dir = {}

    predictions = self._predictor.get(handles=handles,
                                      positions=positions,
                                      directions=directions)
    if not predictions:
        return

    for step in range(self._predictor.max_depth + 1):
        cells = []
        headings = []
        for handle in handles:
            agent_prediction = predictions[handle]
            # Agents without a prediction are left out of this step.
            if agent_prediction is not None:
                cells.append(agent_prediction[step][1:3])
                headings.append(agent_prediction[step][3])
        self.predicted_pos[step] = coordinate_to_position(
            self.rail_env.width, cells)
        self.predicted_dir[step] = headings
    self.max_prediction_depth = len(self.predicted_pos)
def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
    """
    Compute the observations for all agents in `handles`.

    We do not want to call the predictor separately for every agent, so
    ``get_many`` is overridden: the predictor is called just once for all
    agents and its output is cached in ``self.predicted_pos`` (one entry per
    timestep) for use while the individual observations are generated.

    :param handles: agent handles to observe; ``None`` means no agents.
    :return: whatever the parent ``get_many`` returns for `handles`.
    """
    self.predictions = self.predictor.get()
    self.predicted_pos = {}

    if handles is None:
        handles = []

    # Without this guard an empty/None prediction result crashed on
    # `self.predictions[0]`; the cache is now simply left empty.
    if self.predictions:
        # The number of timesteps is read from the prediction of agent 0.
        for t in range(len(self.predictions[0])):
            pos_list = [self.predictions[a][t][1:3] for a in handles]
            # We transform (x,y) coordinates to a single integer number for
            # simpler comparison
            self.predicted_pos[t] = coordinate_to_position(
                self.env.width, pos_list)

    return super().get_many(handles)
def _get_predictions(self, timestep: int, handles: list, predictions):
    """
    Extract the predicted cell and direction of each agent at one timestep.

    :param timestep: index into each agent's prediction sequence.
    :param handles: agent handles, in the order the output is laid out.
    :param predictions: mapping from handle to a per-timestep sequence whose
        entries hold the cell coordinates at ``[1:3]`` and the direction at
        ``[3]``.
    :return: ``(positions, directions)``, two float arrays of length
        ``len(handles)``; positions are integer-encoded cells.
    """
    count = len(handles)
    positions = np.zeros(count)
    directions = np.zeros(count)
    for slot, handle in enumerate(handles):
        predicted = predictions[handle][timestep]
        encoded = coordinate_to_position(self.rail_env.width, [predicted[1:3]])
        positions[slot] = encoded[0]
        directions[slot] = predicted[3]
    return positions, directions
def detect_conflicts_multi(self, position, agent, direction, handles=None, break_after_first=False, only_branch=False, tot_dist=1):
    """
    Recursively collect agents predicted to conflict with `agent` along its
    shortest path, starting at `position` / `direction`.

    At each visited cell the predicted positions at the estimated arrival
    time and the timestep before it are compared with the cell. The walk
    continues through the cells returned by
    ``self.get_shortest_path_position`` until the estimated arrival time
    reaches ``self.max_prediction_depth`` or the agent's target is reached.

    :param position: cell to inspect (unpacked into ``get_transitions``).
    :param agent: the agent whose path is examined; its ``handle``,
        ``target`` and ``speed_data["speed"]`` are read.
    :param direction: direction of travel at `position`.
    :param handles: list of handles matching the layout of
        ``self.predicted_pos[t]``; must contain ``agent.handle``.
        NOTE(review): the default ``None`` fails at ``len(handles)`` —
        callers apparently always pass a list; confirm.
    :param break_after_first: stop scanning a timestep after the first
        conflict and return before recursing once any conflict is found.
    :param only_branch: forwarded to ``get_shortest_path_position``.
    :param tot_dist: number of cells travelled so far.
    :return: ``(conflict_handles, malfunctions)`` — indices of conflicting
        agents and, for those whose predicted direction is NaN, the remaining
        malfunction duration clipped at 0.
    """
    conflict_handles = []
    # Number of timesteps the agent needs per cell (inverse of its speed).
    time_per_cell = int(np.reciprocal(agent.speed_data["speed"]))
    predicted_time = int(tot_dist * time_per_cell)
    # Adding +inf at the agent's own slot guarantees its own predicted
    # position never compares equal to `int_position` below.
    handle_mask = np.zeros(len(handles))
    handle_mask[handles.index(agent.handle)] = np.inf
    malfunctions = []
    if predicted_time < self.max_prediction_depth and position != agent.target:
        int_position = coordinate_to_position(self.rail_env.width, [position])
        if tot_dist < self.max_prediction_depth:
            # Check the arrival timestep and the one just before it.
            pred_times = [max(0, predicted_time - 1), predicted_time]
            for pred_time in pred_times:
                masked_preds = self.predicted_pos[pred_time] + handle_mask
                if int_position in masked_preds:
                    conflicting_agents = np.where(
                        masked_preds == int_position)
                    for ca in conflicting_agents[0]:
                        cell_transitions = self.rail_env.rail.get_transitions(
                            *position, direction)
                        # Conflict: the other agent heads a different way and
                        # either has a NaN direction or its reversed direction
                        # is a permitted transition from here.
                        if direction != self.predicted_dir[pred_time][ca] \
                                and (np.isnan(self.predicted_dir[pred_time][ca]) or cell_transitions[
                                    reverse_dir(self.predicted_dir[pred_time][ca])] == 1):
                            conflict_handles.append(ca)
                            if break_after_first:
                                # Only leaves the loop over `ca`.
                                break
                            if np.isnan(self.predicted_dir[pred_time][ca]):
                                # presumably NaN marks a malfunctioning agent
                                # — TODO confirm against the predictor
                                malf_current = self.rail_env.agents[
                                    ca].malfunction_data['malfunction']
                                malf_remaining = max(
                                    malf_current - tot_dist, 0)
                                malfunctions.append(malf_remaining)
        tot_dist += 1
        positions, directions = self.get_shortest_path_position(
            position=position, direction=direction, only_branch=only_branch, handle=agent.handle)
        if break_after_first and len(conflict_handles) > 0:
            return conflict_handles, malfunctions
        # NOTE(review): `dir` shadows the builtin; the recursive call does not
        # forward `break_after_first` or `only_branch`.
        for pos, dir in zip(positions, directions):
            new_chs, new_malfs = self.detect_conflicts_multi(
                tuple(pos), agent, dir, tot_dist=tot_dist, handles=handles)
            conflict_handles += new_chs
            malfunctions += new_malfs
    return conflict_handles, malfunctions
def get_many(
        self, handles: Optional[List[int]] = None) -> Dict[int, AgentIdNode]:
    """
    Build the observations for every agent listed in `handles`.

    Two caches are refreshed before delegating to the parent class:

    * per-timestep predicted cells and directions of the requested agents
      (only when a predictor is configured), and
    * lookup tables keyed by cell for agents currently on the grid
      (presence, direction, speed, malfunction) and for the number of agents
      waiting to depart from each initial cell.

    :param handles: agent handles to observe; ``None`` means no agents.
    :return: whatever the parent ``get_many`` returns for `handles`.
    """
    handles = [] if handles is None else handles

    if self.predictor:
        self.max_prediction_depth = 0
        self.predicted_pos = {}
        self.predicted_dir = {}
        self.predictions = self.predictor.get()
        if self.predictions:
            for step in range(self.predictor.max_depth + 1):
                cells = []
                headings = []
                for handle in handles:
                    agent_prediction = self.predictions[handle]
                    # Agents without a prediction are left out of this step.
                    if agent_prediction is not None:
                        cells.append(agent_prediction[step][1:3])
                        headings.append(agent_prediction[step][3])
                self.predicted_pos[step] = coordinate_to_position(
                    self.env.width, cells)
                self.predicted_dir[step] = headings
            self.max_prediction_depth = len(self.predicted_pos)

    # Rebuild the per-cell lookup tables. Only agents that are ACTIVE or DONE
    # and still have a position are treated as being on the grid.
    self.location_has_agent = {}
    self.location_has_agent_direction = {}
    self.location_has_agent_speed = {}
    self.location_has_agent_malfunction = {}
    self.location_has_agent_ready_to_depart = {}

    for other in self.env.agents:
        if other.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] \
                and other.position:
            cell = tuple(other.position)
            self.location_has_agent[cell] = 1
            self.location_has_agent_direction[cell] = other.direction
            self.location_has_agent_speed[cell] = other.speed_data['speed']
            self.location_has_agent_malfunction[cell] = \
                other.malfunction_data['malfunction']

        if other.status in [RailAgentStatus.READY_TO_DEPART] \
                and other.initial_position:
            start_cell = tuple(other.initial_position)
            waiting = self.location_has_agent_ready_to_depart.get(start_cell, 0)
            self.location_has_agent_ready_to_depart[start_cell] = waiting + 1

    return super().get_many(handles)
def allowed_handles(self, handles=None, positions=None, directions=None):
    """
    Decide which agents may proceed by reserving cells along their shortest
    paths, in the order given by `handles`.

    Each agent walks up to ``self.max_depth`` steps along its shortest path
    and tries to add a reservation to every cell it occupies. It is rejected
    when a cell already holds a conflicting reservation owned by an agent
    that is ACTIVE or has the same status as itself. A conflicting
    reservation owned by any other agent is overridden instead: that owner is
    removed from the allowed list.

    :param handles: iterable of agent handles, processed in order.
        NOTE(review): the default ``None`` fails at ``for h in handles`` —
        callers apparently always pass a list; confirm.
    :param positions: indexable by handle; current cell or ``None``.
    :param directions: indexable by handle; current direction.
    :return: list of handles allowed to proceed.
    """
    shortest_paths = get_shortest_paths(self.rail_env.distance_map, handles=handles, max_depth=self.max_depth)
    # Integer cell position -> list of Reservation(handle, direction).
    self.reservations = defaultdict(lambda: [])
    # NOTE(review): this local shadows the method name inside the body.
    allowed_handles = []
    for h in handles:
        position = positions[h]
        direction = directions[h]
        shortest_path = shortest_paths[h]
        agent = self.rail_env.agents[h]
        # Number of timesteps the agent spends per cell (inverse of speed).
        times_per_cell = int(np.reciprocal(agent.speed_data["speed"]))
        # Agents without a position are never added to the allowed list.
        if position is not None:
            allowed = True
            for index in range(1, self.max_depth + 1):
                int_pos = coordinate_to_position(depth=self.rail_env.width, coords=[position])[0]
                is_reserved = False
                replaced_handles = []
                for r in self.reservations[int_pos]:
                    cell_transitions = self.rail_env.rail.get_transitions(
                        *position, direction)
                    # Conflict: the reserving agent heads a different way and
                    # its reversed direction is a permitted transition here.
                    if direction != r.direction and cell_transitions[
                            reverse_dir(r.direction)] == 1:
                        if self.rail_env.agents[r.handle].status == RailAgentStatus.ACTIVE or \
                                self.rail_env.agents[r.handle].status == self.rail_env.agents[h].status:
                            is_reserved = True
                            break
                        else:
                            # The owner loses its reservation to this agent.
                            replaced_handles.append(r.handle)
                if not is_reserved:
                    self.reservations[int_pos].append(
                        Reservation(handle=h, direction=direction))
                    # `r` is rebound here: it now iterates over plain handles.
                    for r in replaced_handles:
                        if r in allowed_handles:
                            allowed_handles.remove(r)
                else:
                    # The walk continues even after a rejection, so later
                    # cells are still checked and possibly reserved.
                    allowed = False
                if position == agent.target:
                    break
                # Advance one waypoint each time the agent has spent
                # `times_per_cell` steps in the current cell.
                # NOTE(review): `shortest_path[0]` raises IndexError if the
                # path runs out before the target — presumably it never does;
                # confirm against get_shortest_paths.
                if index % times_per_cell == 0:
                    position = shortest_path[0].position
                    direction = shortest_path[0].direction
                    shortest_path = shortest_path[1:]
            if allowed:
                allowed_handles.append(h)
    return allowed_handles
def get_many(self, handles: Optional[List[int]] = None) -> dict:
    """
    Compute observations for all agents in the env.

    Before the per-agent observations are built this refreshes:

    * ``self.num_active_agents`` — number of agents with status ACTIVE;
    * ``self.prediction_dict`` / ``self.cells_sequence`` — predictor output;
    * ``self.predicted_pos_coord`` / ``self.predicted_pos`` /
      ``self.predicted_dir`` — per-timestep cell coordinates, integer cell
      positions and directions of the agents in `handles`;
    * ``self.predicted_pos_list`` — per-agent list of integer positions over
      all predicted timesteps.

    :param handles: agent handles to observe; ``None`` means no agents.
    :return: dict mapping each handle to ``self.get(handle)``.
    """
    # `handles` defaults to None but is iterated below; treat None as "no
    # agents" so the default no longer raises TypeError.
    if handles is None:
        handles = []

    self.num_active_agents = sum(
        1 for agent in self.env.agents
        if agent.status == RailAgentStatus.ACTIVE)

    self.prediction_dict = self.predictor.get()
    # Useful to check if occupancy is correctly computed
    self.cells_sequence = self.predictor.compute_cells_sequence(
        self.prediction_dict)

    if self.prediction_dict:
        self.max_prediction_depth = self.predictor.max_depth
        for t in range(self.max_prediction_depth):
            pos_list = []
            dir_list = []
            for a in handles:
                if self.prediction_dict[a] is None:
                    continue
                pos_list.append(self.prediction_dict[a][t][1:3])
                dir_list.append(self.prediction_dict[a][t][3])
            self.predicted_pos_coord.update({t: pos_list})
            self.predicted_pos.update(
                {t: coordinate_to_position(self.env.width, pos_list)})
            self.predicted_dir.update({t: dir_list})

        # NOTE(review): this indexes every env agent into lists that were
        # built from `handles` only (and skip agents without a prediction);
        # presumably `handles` always covers all agents — confirm.
        for a in range(len(self.env.agents)):
            self.predicted_pos_list.update({
                a: [self.predicted_pos[ts][a]  # Use int positions
                    for ts in range(self.max_prediction_depth)]
            })

    return {a: self.get(a) for a in handles}
def _explore_branch(self, handle, position, direction, tot_dist, depth):
    """
    Utility function to compute tree-based observations.
    We walk along the branch and collect the information documented in the get() function.
    If there is a branching point a new node is created and each possible branch is explored.

    :param handle: index of the observing agent in ``self.env.agents``.
    :param position: cell at which the walk starts (indexable pair).
    :param direction: direction of travel at `position` (0..3).
    :param tot_dist: number of cells travelled from the agent up to `position`.
    :param depth: current recursion depth of the tree.
    :return: ``(node, visited)`` where ``node`` is the
        ``CustomTreeObsForRailEnv.Node`` describing this branch and
        ``visited`` the set of ``(row, col, direction)`` triples walked here
        and in all child branches; ``([], [])`` once `depth` exceeds
        ``self.max_depth``.
    """

    # [Recursive branch opened]
    if depth >= self.max_depth + 1:
        return [], []

    # Continue along direction until next switch or
    # until no transitions are possible along the current direction (i.e., dead-ends)
    # We treat dead-ends as nodes, instead of going back, to avoid loops
    exploring = True
    last_is_switch = False
    last_is_dead_end = False
    last_is_terminal = False  # wrong cell OR cycle; either way, we don't want the agent to land here
    last_is_target = False

    visited = OrderedSet()
    agent = self.env.agents[handle]
    time_per_cell = np.reciprocal(agent.speed_data["speed"])
    own_target_encountered = np.inf
    other_agent_encountered = np.inf
    other_target_encountered = np.inf
    potential_conflict = np.inf
    unusable_switch = np.inf
    other_agent_same_direction = 0
    other_agent_opposite_direction = 0
    malfunctioning_agent = 0
    min_fractional_speed = 1.
    num_steps = 1
    other_agent_ready_to_depart_encountered = 0
    found_closest_communication = False
    communication = None
    while exploring:
        # #############################
        # #############################
        # Modify here to compute any useful data required to build the end node's features. This code is called
        # for each cell visited between the previous branching node and the next switch / target / dead-end.
        if position in self.location_has_agent:
            # Keep only the communication of the first (closest) agent found.
            if self.location_has_agent_communication[position] is not None and not found_closest_communication:
                # A stray trailing comma used to turn this flag into the
                # tuple `(True,)`; it is a plain bool now.
                found_closest_communication = True
                communication = self.location_has_agent_communication[position]

            if tot_dist < other_agent_encountered:
                other_agent_encountered = tot_dist

            # Check if any of the observed agents is malfunctioning, store agent with longest duration left
            if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                malfunctioning_agent = self.location_has_agent_malfunction[position]

            other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0)

            if self.location_has_agent_direction[position] == direction:
                # Cummulate the number of agents on branch with same direction
                # NOTE(review): the table is indexed by `position` just above
                # but looked up with `(position, direction)` here, so the
                # lookup presumably always falls back to 0 — confirm intent.
                other_agent_same_direction += self.location_has_agent_direction.get((position, direction), 0)

                # Check fractional speed of agents
                current_fractional_speed = self.location_has_agent_speed[position]
                if current_fractional_speed < min_fractional_speed:
                    min_fractional_speed = current_fractional_speed

                # Other direction agents
                # TODO: Test that this behavior is as expected
                other_agent_opposite_direction += \
                    self.location_has_agent[position] - self.location_has_agent_direction.get((position, direction), 0)

            else:
                # If no agent in the same direction was found all agents in that position are other direction
                other_agent_opposite_direction += self.location_has_agent[position]

        # Check number of possible transitions for agent and total number of transitions in cell (type)
        cell_transitions = self.env.rail.get_transitions(*position, direction)
        transition_bit = bin(self.env.rail.get_full_transitions(*position))
        total_transitions = transition_bit.count("1")
        crossing_found = False
        if int(transition_bit, 2) == int('1000010000100001', 2):
            crossing_found = True

        # Register possible future conflict
        predicted_time = int(tot_dist * time_per_cell)
        if self.predictor and predicted_time < self.max_prediction_depth:
            int_position = coordinate_to_position(self.env.width, [position])
            if tot_dist < self.max_prediction_depth:

                pre_step = max(0, predicted_time - 1)
                post_step = min(self.max_prediction_depth - 1, predicted_time + 1)

                # Look for conflicting paths at distance tot_dist
                if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0):
                    conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position)
                    for ca in conflicting_agent[0]:
                        if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[
                            self._reverse_dir(
                                self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
                            potential_conflict = tot_dist
                        if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                            potential_conflict = tot_dist

                # Look for conflicting paths at distance num_step-1
                elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                    conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                    for ca in conflicting_agent[0]:
                        if direction != self.predicted_dir[pre_step][ca] \
                            and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
                            and tot_dist < potential_conflict:  # noqa: E125
                            potential_conflict = tot_dist
                        if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                            potential_conflict = tot_dist

                # Look for conflicting paths at distance num_step+1
                elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
                    conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                    for ca in conflicting_agent[0]:
                        if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
                            self.predicted_dir[post_step][ca])] == 1 \
                            and tot_dist < potential_conflict:  # noqa: E125
                            potential_conflict = tot_dist
                        if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                            potential_conflict = tot_dist

        if position in self.location_has_target and position != agent.target:
            if tot_dist < other_target_encountered:
                other_target_encountered = tot_dist

        if position == agent.target and tot_dist < own_target_encountered:
            own_target_encountered = tot_dist

        # #############################
        # #############################
        if (position[0], position[1], direction) in visited:
            last_is_terminal = True
            break
        visited.add((position[0], position[1], direction))

        # If the target node is encountered, pick that as node. Also, no further branching is possible.
        if np.array_equal(position, self.env.agents[handle].target):
            last_is_target = True
            break

        # Check if crossing is found --> Not an unusable switch
        if crossing_found:
            # Treat the crossing as a straight rail cell
            total_transitions = 2
        num_transitions = np.count_nonzero(cell_transitions)

        exploring = False

        # Detect Switches that can only be used by other agents.
        if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
            unusable_switch = tot_dist

        if num_transitions == 1:
            # Check if dead-end, or if we can go forward along direction
            nbits = total_transitions
            if nbits == 1:
                # Dead-end!
                last_is_dead_end = True

            if not last_is_dead_end:
                # Keep walking through the tree along `direction`
                exploring = True
                # convert one-hot encoding to 0,1,2,3
                direction = np.argmax(cell_transitions)
                position = get_new_position(position, direction)
                num_steps += 1
                tot_dist += 1
        elif num_transitions > 0:
            # Switch detected
            last_is_switch = True
            break

        elif num_transitions == 0:
            # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
            print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                  position[1], direction)
            last_is_terminal = True
            break

    # `position` is either a terminal node or a switch

    # #############################
    # #############################
    # Modify here to append new / different features for each visited cell!

    if last_is_target:
        dist_to_next_branch = tot_dist
        dist_min_to_target = 0
    elif last_is_terminal:
        dist_to_next_branch = np.inf
        dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
    else:
        dist_to_next_branch = tot_dist
        dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]

    node = CustomTreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered,
                                        dist_other_target_encountered=other_target_encountered,
                                        dist_other_agent_encountered=other_agent_encountered,
                                        dist_potential_conflict=potential_conflict,
                                        dist_unusable_switch=unusable_switch,
                                        dist_to_next_branch=dist_to_next_branch,
                                        dist_min_to_target=dist_min_to_target,
                                        num_agents_same_direction=other_agent_same_direction,
                                        num_agents_opposite_direction=other_agent_opposite_direction,
                                        num_agents_malfunctioning=malfunctioning_agent,
                                        speed_min_fractional=min_fractional_speed,
                                        num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
                                        communication=communication,
                                        childs={})

    # #############################
    # #############################
    # Start from the current orientation, and see which transitions are available;
    # organize them as [left, forward, right, back], relative to the current orientation
    # Get the possible transitions
    possible_transitions = self.env.rail.get_transitions(*position, direction)
    for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
        if last_is_dead_end and self.env.rail.get_transition((*position, direction),
                                                             (branch_direction + 2) % 4):
            # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
            # it back
            new_cell = get_new_position(position, (branch_direction + 2) % 4)
            branch_observation, branch_visited = self._explore_branch(handle,
                                                                      new_cell,
                                                                      (branch_direction + 2) % 4,
                                                                      tot_dist + 1,
                                                                      depth + 1)
            node.childs[self.tree_explored_actions_char[i]] = branch_observation
            if len(branch_visited) != 0:
                visited |= branch_visited
        elif last_is_switch and possible_transitions[branch_direction]:
            new_cell = get_new_position(position, branch_direction)
            branch_observation, branch_visited = self._explore_branch(handle,
                                                                      new_cell,
                                                                      branch_direction,
                                                                      tot_dist + 1,
                                                                      depth + 1)
            node.childs[self.tree_explored_actions_char[i]] = branch_observation
            if len(branch_visited) != 0:
                visited |= branch_visited
        else:
            # no exploring possible, add just cells with infinity
            node.childs[self.tree_explored_actions_char[i]] = -np.inf

    # Nodes at the deepest level carry no children at all.
    if depth == self.max_depth:
        node.childs.clear()
    return node, visited
def get_many(self, handles: Optional[List[int]] = None):
    """
    Build the observations for every agent listed in `handles`.

    Steps, in order:

    1. reset ``self._conflict_map`` to an empty list per handle;
    2. refresh the per-timestep prediction cache (only with a predictor);
    3. rebuild the per-cell lookup tables of agents on the grid and of
       agents waiting to depart;
    4. obtain the observations from the parent class;
    5. when ``self.use_priority`` is set, colour the conflict graph and store
       each agent's colour in the ``dist_own_target_encountered`` field of
       its observation.

    :param handles: agent handles to observe; ``None`` means no agents.
    :return: dict of observations keyed by handle.
    """
    handles = [] if handles is None else handles
    self._conflict_map = {handle: [] for handle in handles}

    if self.predictor:
        self.max_prediction_depth = 0
        self.predicted_pos = {}
        self.predicted_dir = {}
        self.predictions = self.predictor.get()
        if self.predictions:
            for step in range(self.predictor.max_depth + 1):
                cells = []
                headings = []
                for handle in handles:
                    agent_prediction = self.predictions[handle]
                    # Agents without a prediction are left out of this step.
                    if agent_prediction is not None:
                        cells.append(agent_prediction[step][1:3])
                        headings.append(agent_prediction[step][3])
                self.predicted_pos[step] = coordinate_to_position(
                    self.env.width, cells)
                self.predicted_dir[step] = headings
            self.max_prediction_depth = len(self.predicted_pos)

    # Rebuild the per-cell lookup tables. Only agents that are ACTIVE or DONE
    # and still have a position are treated as being on the grid.
    self.location_has_agent = {}
    self.location_has_agent_direction = {}
    self.location_has_agent_speed = {}
    self.location_has_agent_malfunction = {}
    self.location_has_agent_ready_to_depart = {}

    for other in self.env.agents:
        if other.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] \
                and other.position:
            cell = tuple(other.position)
            self.location_has_agent[cell] = 1
            self.location_has_agent_direction[cell] = other.direction
            self.location_has_agent_speed[cell] = other.speed_data['speed']
            self.location_has_agent_malfunction[cell] = \
                other.malfunction_data['malfunction']

        if other.status in [RailAgentStatus.READY_TO_DEPART] \
                and other.initial_position:
            start_cell = tuple(other.initial_position)
            waiting = self.location_has_agent_ready_to_depart.get(start_cell, 0)
            self.location_has_agent_ready_to_depart[start_cell] = waiting + 1

    observations = super().get_many(handles)

    if not self.use_priority:
        return observations

    priorities = GreedyGraphColoring.color(
        colors=[1, 0],
        nodes=observations.keys(),
        neighbors=self._conflict_map)
    for handle, observation in observations.items():
        if observation is None:
            continue
        # Existing keys are overwritten in place, so the dict size is stable
        # while iterating.
        observations[handle] = observation._replace(
            dist_own_target_encountered=priorities[handle])
    return observations
def detect_conflicts(self, tot_dist, time_per_cell, position, cell_transitions, handle, direction):
    """
    Look for another agent predicted to occupy `position` around the time
    this agent would get there.

    Three timesteps are examined in order — the estimated arrival time, the
    step before it and the step after it — and only the first one at which
    some other agent is predicted on the cell is evaluated. An agent found
    there counts as a conflict when it heads in a different direction and its
    reversed direction is a permitted transition, or when its status is DONE.

    :param tot_dist: number of cells travelled up to `position`.
    :param time_per_cell: timesteps the agent needs per cell.
    :param position: cell being inspected.
    :param cell_transitions: transitions available at `position`, indexable
        by direction.
    :param handle: handle of the observing agent; its own entry is removed
        from the predicted positions before the membership test.
    :param direction: direction of travel of the observing agent.
    :return: ``(potential_conflict, conflict_handle)`` — `tot_dist` and the
        index of the conflicting agent, or ``(np.inf, None)`` if none.
    """
    potential_conflict = np.inf
    conflict_handle = None
    predicted_time = int(tot_dist * time_per_cell)

    if not (self.predictor and predicted_time < self.max_prediction_depth):
        return potential_conflict, conflict_handle

    int_position = coordinate_to_position(self.env.width, [position])
    if not tot_dist < self.max_prediction_depth:
        return potential_conflict, conflict_handle

    pre_step = max(0, predicted_time - 1)
    post_step = min(self.max_prediction_depth - 1, predicted_time + 1)

    # Arrival time first, then one step earlier, then one step later; the
    # first timestep with a hit is the only one evaluated.
    for step in (predicted_time, pre_step, post_step):
        others = np.delete(self.predicted_pos[step], handle, 0)
        if int_position not in others:
            continue

        hits = np.where(self.predicted_pos[step] == int_position)
        for ca in hits[0]:
            heads_elsewhere = direction != self.predicted_dir[step][ca]
            if heads_elsewhere \
                    and cell_transitions[self._reverse_dir(self.predicted_dir[step][ca])] == 1 \
                    and tot_dist < potential_conflict:
                potential_conflict = tot_dist
                conflict_handle = ca
            if self.env.agents[ca].status == RailAgentStatus.DONE \
                    and tot_dist < potential_conflict:
                potential_conflict = tot_dist
                conflict_handle = ca
        break

    return potential_conflict, conflict_handle
def get_many(self, handles: Optional[List[int]] = None):
    """
    Build the observations for every agent in `handles` and stamp them with
    conflict priorities.

    Steps, in order:

    1. reset the shortest-path / other-path conflict maps and
       ``self._conflict_map`` to an empty list per handle;
    2. refresh the per-timestep prediction cache (only with a predictor);
    3. rebuild the per-cell lookup tables of agents on the grid and of
       agents waiting to depart;
    4. compute ``self.get(handle)`` for each handle;
    5. colour both conflict graphs and write the colours into slots
       ``obs[0][6]`` (shortest path) and ``obs[0][13]`` (other path) of each
       non-``None`` observation;
    6. remember the priorities and conflict maps in the ``_prev_*``
       attributes.

    :param handles: agent handles to observe; ``None`` means no agents.
    :return: dict of observations keyed by handle.
    """
    # `handles` defaults to None but was iterated unconditionally, which
    # raised TypeError; treat None as "no agents" like the sibling builders.
    if handles is None:
        handles = []

    self._shortest_path_conflict_map = {handle: [] for handle in handles}
    self._other_path_conflict_map = {handle: [] for handle in handles}

    if self.predictor:
        self.max_prediction_depth = 0
        self.predicted_pos = {}
        self.predicted_dir = {}
        self.predictions = self.predictor.get()
        if self.predictions:
            for t in range(self.predictor.max_depth + 1):
                pos_list = []
                dir_list = []
                for a in handles:
                    if self.predictions[a] is None:
                        continue
                    pos_list.append(self.predictions[a][t][1:3])
                    dir_list.append(self.predictions[a][t][3])
                self.predicted_pos[t] = coordinate_to_position(
                    self.env.width, pos_list)
                self.predicted_dir[t] = dir_list
            self.max_prediction_depth = len(self.predicted_pos)

    # Update local lookup table for all agents' positions
    # ignore other agents not in the grid (only status active and done)
    self.location_has_agent = {}
    self.location_has_agent_direction = {}
    self.location_has_agent_speed = {}
    self.location_has_agent_malfunction = {}
    self.location_has_agent_ready_to_depart = {}

    for _agent in self.env.agents:
        if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
                _agent.position:
            cell = tuple(_agent.position)
            self.location_has_agent[cell] = 1
            self.location_has_agent_direction[cell] = _agent.direction
            self.location_has_agent_speed[cell] = _agent.speed_data['speed']
            self.location_has_agent_malfunction[cell] = \
                _agent.malfunction_data['malfunction']

        if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
                _agent.initial_position:
            start_cell = tuple(_agent.initial_position)
            self.location_has_agent_ready_to_depart[start_cell] = \
                self.location_has_agent_ready_to_depart.get(start_cell, 0) + 1

    self._conflict_map = {handle: [] for handle in handles}
    obs_dict = {handle: self.get(handle) for handle in handles}

    # the order of the colors matters
    sp_priorities = GreedyGraphColoring.color(
        colors=[1, 0],
        nodes=obs_dict.keys(),
        neighbors=self._shortest_path_conflict_map)
    op_priorities = GreedyGraphColoring.color(
        colors=[1, 0],
        nodes=obs_dict.keys(),
        neighbors=self._other_path_conflict_map)

    for handle, obs in obs_dict.items():
        if obs is not None:
            # presumably slots 6 and 13 of the root feature vector are the
            # priority features — TODO confirm against the observation layout
            obs[0][6] = sp_priorities[handle]
            obs[0][13] = op_priorities[handle]

    # Skipped without handles: asserting on an empty list always fails.
    # NOTE(review): as written these assertions are vacuous — a non-empty
    # list is always truthy, and each element compares a priority with a
    # *list* of priorities. The intent looks like "no agent shares a
    # priority with a conflicting neighbour", but two colours cannot
    # guarantee that for every graph, so the check is left as it was.
    if self._asserts and handles:
        assert [
            sp_priorities[h] != [sp_priorities[ch] for ch in chs]
            for h, chs in self._shortest_path_conflict_map.items()
        ]
        assert [
            op_priorities[h] != [op_priorities[ch] for ch in chs]
            for h, chs in self._other_path_conflict_map.items()
        ]

    self._prev_sp_prios = sp_priorities
    self._prev_other_prios = op_priorities
    self._prev_other_path_conflict_map = self._other_path_conflict_map
    self._prev_shortest_path_conflict_map = self._shortest_path_conflict_map
    return obs_dict
def test_coordinate_to_position():
    """Check that the module-level test coordinates convert to the expected positions."""
    actual_positions = coordinate_to_position(depth_to_test, coordinates_to_test)
    expected_positions = positions_to_test
    positions_match = np.array_equal(actual_positions, expected_positions)
    assert positions_match, \
        "converted positions {}, expected {}".format(actual_positions, expected_positions)