Example #1
 def _compute_overlapping_paths_with_current_ts(self, handle):
     """
     """
     agent = self.env.agents[handle]
     overlapping_paths = np.zeros(
         (self.env.get_num_agents(), self.max_prediction_depth + 1),
         dtype=int)
     cells_sequence = self.predicted_pos_list[handle]
     # Prepend current ts
     if agent.status == RailAgentStatus.ACTIVE:
         virtual_position = agent.position
     elif agent.status == RailAgentStatus.READY_TO_DEPART:
         virtual_position = agent.initial_position
     else:
         # DONE/DONE_REMOVED agents have no path left to overlap with
         return overlapping_paths
     int_pos = coordinate_to_position(self.env.width, [virtual_position])
     cells_sequence = np.append(int_pos[0], cells_sequence)
     for a in range(len(self.env.agents)):
         if a != handle and self.env.agents[
                 a].status == RailAgentStatus.ACTIVE:
             # Prepend other agent's current ts
             other_agent_cells_sequence = self.predicted_pos_list[a]
             other_int_pos = coordinate_to_position(
                 self.env.width, [self.env.agents[a].position])
             other_agent_cells_sequence = np.append(
                 other_int_pos[0], other_agent_cells_sequence)
             for i, pos in enumerate(cells_sequence):
                 if pos in other_agent_cells_sequence:
                     overlapping_paths[a, i] = 1
     return overlapping_paths
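Every snippet in this collection relies on flatland's `coordinate_to_position` helper to flatten `(row, col)` pairs into single integers so that paths can be compared cheaply. As a minimal sketch, assuming flatland's column-major convention (`col * depth + row`, with `depth` being the grid width passed in the calls above), the conversion behaves like this:

import numpy as np

def coordinate_to_position_sketch(depth, coords):
    # Flatten (row, col) pairs into ints: col * depth + row.
    # A NaN row (agent off the grid) maps to -1, mirroring flatland's helper.
    position = np.empty(len(coords), dtype=int)
    for idx, (row, col) in enumerate(coords):
        position[idx] = -1 if np.isnan(row) else int(col * depth + row)
    return position

# coordinate_to_position_sketch(4, [(1, 2)]) -> array([9])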
Example #2
    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:
        """
        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
        in the `handles` list.
        """

        if handles is None:
            handles = []
        if self.predictor:
            self.max_prediction_depth = 0
            self.predicted_pos = {}
            self.predicted_dir = {}
            self.predictions = self.predictor.get()
            if self.predictions:
                for t in range(self.predictor.max_depth + 1):
                    pos_list = []
                    dir_list = []
                    for a in handles:
                        if self.predictions[a] is None:
                            continue
                        pos_list.append(self.predictions[a][t][1:3])
                        dir_list.append(self.predictions[a][t][3])
                    self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
                    self.predicted_dir.update({t: dir_list})
                self.max_prediction_depth = len(self.predicted_pos)

        observations = super().get_many(handles)

        return observations
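The loop above builds two per-timestep lookup tables: `predicted_pos[t]` holds one integer position per agent (in handle order) and `predicted_dir[t]` the matching headings. A toy illustration of the layout, using made-up prediction rows of the form `[handle, row, col, direction]` and flatland's real helper:

from flatland.core.grid.grid_utils import coordinate_to_position

# Hypothetical two-agent, two-step prediction dict, for illustration only.
predictions = {
    0: [[0, 1, 2, 0], [0, 1, 3, 1]],
    1: [[1, 5, 2, 2], [1, 4, 2, 0]],
}
predicted_pos, predicted_dir = {}, {}
for t in range(2):
    pos_list = [predictions[a][t][1:3] for a in (0, 1)]
    predicted_pos[t] = coordinate_to_position(10, pos_list)
    predicted_dir[t] = [predictions[a][t][3] for a in (0, 1)]
# predicted_pos[0] -> array([21, 25]); predicted_dir[0] -> [0, 2]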
Example #3
 def map_predictions(self, handles=None, positions=None, directions=None):
     if handles is None:
         handles = [a.handle for a in self.rail_env.agents]
     if self._predictor:
         self.max_prediction_depth = 0
         self.predicted_pos = {}
         self.predicted_dir = {}
         predictions = self._predictor.get(handles=handles,
                                           positions=positions,
                                           directions=directions)
         if predictions:
             for t in range(self._predictor.max_depth + 1):
                 pos_list = []
                 dir_list = []
                 for a in handles:
                     if predictions[a] is None:
                         continue
                     pos_list.append(predictions[a][t][1:3])
                     dir_list.append(predictions[a][t][3])
                 self.predicted_pos.update({
                     t:
                     coordinate_to_position(self.rail_env.width, pos_list)
                 })
                 self.predicted_dir.update({t: dir_list})
             self.max_prediction_depth = len(self.predicted_pos)
Example #4
    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
        '''
        Because we do not want to call the predictor separately for every agent, we implement get_many.
        Here we can call the predictor just once for all the agents and use the predictions to generate our observations.
        :param handles:
        :return:
        '''

        self.predictions = self.predictor.get()

        self.predicted_pos = {}

        if handles is None:
            handles = []

        for t in range(len(self.predictions[0])):
            pos_list = []
            for a in handles:
                pos_list.append(self.predictions[a][t][1:3])
            # We transform (x,y) coordinates to a single integer number for simpler comparison
            self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})

        observations = super().get_many(handles)

        return observations
Example #5
 def _get_predictions(self, timestep: int, handles: list, predictions):
     positions = np.zeros(len(handles))
     directions = np.zeros(len(handles))
     for i, h in enumerate(handles):
         positions[i] = coordinate_to_position(
             self.rail_env.width, [predictions[h][timestep][1:3]])[0]
         directions[i] = predictions[h][timestep][3]
     return positions, directions
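Note that the returned arrays are indexed by position within `handles`, not by handle value. A hypothetical standalone restatement plus a quick check, assuming the `[handle, row, col, direction]` prediction rows used throughout:

import numpy as np
from flatland.core.grid.grid_utils import coordinate_to_position

def get_predictions_sketch(width, timestep, handles, predictions):
    # Hypothetical free-function version of _get_predictions above.
    positions = np.zeros(len(handles))
    directions = np.zeros(len(handles))
    for i, h in enumerate(handles):
        positions[i] = coordinate_to_position(width, [predictions[h][timestep][1:3]])[0]
        directions[i] = predictions[h][timestep][3]
    return positions, directions

predictions = {0: [[0, 1, 2, 0]], 7: [[7, 5, 2, 2]]}
pos, dirs = get_predictions_sketch(10, 0, [0, 7], predictions)
# pos -> array([21., 25.]); index 1 belongs to handle 7, not agent 1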
Example #6
    def detect_conflicts_multi(self,
                               position,
                               agent,
                               direction,
                               handles=None,
                               break_after_first=False,
                               only_branch=False,
                               tot_dist=1):
        conflict_handles = []
        time_per_cell = int(np.reciprocal(agent.speed_data["speed"]))
        predicted_time = int(tot_dist * time_per_cell)
        handle_mask = np.zeros(len(handles))
        handle_mask[handles.index(agent.handle)] = np.inf
        malfunctions = []
        if predicted_time < self.max_prediction_depth and position != agent.target:
            int_position = coordinate_to_position(self.rail_env.width,
                                                  [position])
            if tot_dist < self.max_prediction_depth:

                pred_times = [max(0, predicted_time - 1), predicted_time]

                for pred_time in pred_times:
                    masked_preds = self.predicted_pos[pred_time] + handle_mask
                    if int_position in masked_preds:
                        conflicting_agents = np.where(
                            masked_preds == int_position)
                        for ca in conflicting_agents[0]:
                            cell_transitions = self.rail_env.rail.get_transitions(
                                *position, direction)
                            pred_dir = self.predicted_dir[pred_time][ca]
                            # NaN direction marks an agent that is off the grid (e.g. malfunctioning)
                            if direction != pred_dir and (
                                    np.isnan(pred_dir)
                                    or cell_transitions[reverse_dir(pred_dir)] == 1):
                                conflict_handles.append(ca)
                                if break_after_first:
                                    break
                                if np.isnan(pred_dir):
                                    malf_current = self.rail_env.agents[
                                        ca].malfunction_data['malfunction']
                                    malf_remaining = max(
                                        malf_current - tot_dist, 0)
                                    malfunctions.append(malf_remaining)

            tot_dist += 1
            positions, directions = self.get_shortest_path_position(
                position=position,
                direction=direction,
                only_branch=only_branch,
                handle=agent.handle)
            if break_after_first and len(conflict_handles) > 0:
                return conflict_handles, malfunctions

            for pos, new_dir in zip(positions, directions):
                new_chs, new_malfs = self.detect_conflicts_multi(
                    tuple(pos), agent, new_dir, tot_dist=tot_dist, handles=handles)
                conflict_handles += new_chs
                malfunctions += new_malfs

        return conflict_handles, malfunctions
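`reverse_dir` here (and the equivalent `_reverse_dir` method in later examples) flips one of the four grid headings by 180 degrees. A one-line sketch, assuming flatland's N=0, E=1, S=2, W=3 encoding:

def reverse_dir(direction):
    # Rotate a heading (N=0, E=1, S=2, W=3) by 180 degrees.
    return int((direction + 2) % 4)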
Example #7
    def get_many(
            self,
            handles: Optional[List[int]] = None) -> Dict[int, AgentIdNode]:
        """
        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
        in the `handles` list.
        """

        if handles is None:
            handles = []
        if self.predictor:
            self.max_prediction_depth = 0
            self.predicted_pos = {}
            self.predicted_dir = {}
            self.predictions = self.predictor.get()
            if self.predictions:
                for t in range(self.predictor.max_depth + 1):
                    pos_list = []
                    dir_list = []
                    for a in handles:
                        if self.predictions[a] is None:
                            continue
                        pos_list.append(self.predictions[a][t][1:3])
                        dir_list.append(self.predictions[a][t][3])
                    self.predicted_pos.update(
                        {t: coordinate_to_position(self.env.width, pos_list)})
                    self.predicted_dir.update({t: dir_list})
                self.max_prediction_depth = len(self.predicted_pos)
        # Update local lookup table for all agents' positions
        # ignore other agents not in the grid (only status active and done)
        # self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
        #                         agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]}

        self.location_has_agent = {}
        self.location_has_agent_direction = {}
        self.location_has_agent_speed = {}
        self.location_has_agent_malfunction = {}
        self.location_has_agent_ready_to_depart = {}

        for _agent in self.env.agents:
            if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
                    _agent.position:
                self.location_has_agent[tuple(_agent.position)] = 1
                self.location_has_agent_direction[tuple(
                    _agent.position)] = _agent.direction
                self.location_has_agent_speed[tuple(
                    _agent.position)] = _agent.speed_data['speed']
                self.location_has_agent_malfunction[tuple(
                    _agent.position)] = _agent.malfunction_data['malfunction']

            if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
                    _agent.initial_position:
                self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
                    self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1

        observations = super().get_many(handles)

        return observations
Example #8
    def allowed_handles(self, handles=None, positions=None, directions=None):
        shortest_paths = get_shortest_paths(self.rail_env.distance_map,
                                            handles=handles,
                                            max_depth=self.max_depth)
        self.reservations = defaultdict(list)
        allowed_handles = []

        for h in handles:
            position = positions[h]
            direction = directions[h]
            shortest_path = shortest_paths[h]
            agent = self.rail_env.agents[h]
            times_per_cell = int(np.reciprocal(agent.speed_data["speed"]))
            if position is not None:
                allowed = True
                for index in range(1, self.max_depth + 1):

                    int_pos = coordinate_to_position(depth=self.rail_env.width,
                                                     coords=[position])[0]

                    is_reserved = False
                    replaced_handles = []
                    for r in self.reservations[int_pos]:
                        cell_transitions = self.rail_env.rail.get_transitions(
                            *position, direction)
                        if direction != r.direction and cell_transitions[
                                reverse_dir(r.direction)] == 1:
                            if self.rail_env.agents[r.handle].status == RailAgentStatus.ACTIVE or \
                               self.rail_env.agents[r.handle].status == self.rail_env.agents[h].status:
                                is_reserved = True
                                break
                            else:
                                replaced_handles.append(r.handle)

                    if not is_reserved:
                        self.reservations[int_pos].append(
                            Reservation(handle=h, direction=direction))
                        for rh in replaced_handles:
                            if rh in allowed_handles:
                                allowed_handles.remove(rh)
                    else:
                        allowed = False

                    if position == agent.target:
                        break

                    if index % times_per_cell == 0:
                        if not shortest_path:
                            # Prediction horizon exceeds the remaining path
                            break
                        position = shortest_path[0].position
                        direction = shortest_path[0].direction

                        shortest_path = shortest_path[1:]

                if allowed:
                    allowed_handles.append(h)

        return allowed_handles
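`Reservation` is used but not defined in the snippet; from its call sites it only needs `handle` and `direction` fields, so a minimal stand-in (an assumption, not the original definition) could be:

from collections import namedtuple

# Hypothetical stand-in for the Reservation record used in allowed_handles.
Reservation = namedtuple("Reservation", ["handle", "direction"])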
Example #9
    def get_many(self, handles: Optional[List[int]] = None) -> Dict:
        """
        Compute observations for all agents in the env.
        :param handles: agent handles to compute observations for
        :return: dict mapping each handle to its observation
        """

        if handles is None:
            handles = []

        self.num_active_agents = sum(
            1 for a in self.env.agents if a.status == RailAgentStatus.ACTIVE)
        self.prediction_dict = self.predictor.get()
        # Useful to check if occupancy is correctly computed
        self.cells_sequence = self.predictor.compute_cells_sequence(
            self.prediction_dict)

        if self.prediction_dict:
            self.max_prediction_depth = self.predictor.max_depth
            for t in range(self.max_prediction_depth):
                pos_list = []
                dir_list = []
                for a in handles:
                    if self.prediction_dict[a] is None:
                        continue
                    pos_list.append(self.prediction_dict[a][t][1:3])
                    dir_list.append(self.prediction_dict[a][t][3])
                self.predicted_pos_coord.update({t: pos_list})
                self.predicted_pos.update(
                    {t: coordinate_to_position(self.env.width, pos_list)})
                self.predicted_dir.update({t: dir_list})

            for a in range(len(self.env.agents)):
                pos_list = []
                for ts in range(self.max_prediction_depth):
                    pos_list.append(
                        self.predicted_pos[ts][a])  # Use int positions
                self.predicted_pos_list.update({a: pos_list})

        observations = {}
        for a in handles:
            observations[a] = self.get(a)
        return observations
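The second loop simply transposes the per-timestep tables into per-agent sequences: `predicted_pos_list[a]` is the list of integer positions agent `a` is predicted to occupy, one entry per timestep, which is exactly what Example #1's overlap check consumes. A toy illustration with made-up values:

# predicted_pos[t][a]: agent a's integer position at step t (hypothetical).
predicted_pos = {0: [21, 25], 1: [31, 15]}
predicted_pos_list = {
    a: [predicted_pos[t][a] for t in range(2)] for a in range(2)
}
# predicted_pos_list -> {0: [21, 31], 1: [25, 15]}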
Example #10
    def _explore_branch(self, handle, position, direction, tot_dist, depth):
        """
        Utility function to compute tree-based observations.
        We walk along the branch and collect the information documented in the get() function.
        If there is a branching point a new node is created and each possible branch is explored.
        """

        # [Recursive branch opened]
        if depth >= self.max_depth + 1:
            return [], []

        # Continue along direction until next switch or
        # until no transitions are possible along the current direction (i.e., dead-ends)
        # We treat dead-ends as nodes, instead of going back, to avoid loops
        exploring = True
        last_is_switch = False
        last_is_dead_end = False
        last_is_terminal = False  # wrong cell OR cycle; either way, we don't want the agent to land here
        last_is_target = False

        visited = OrderedSet()
        agent = self.env.agents[handle]
        time_per_cell = np.reciprocal(agent.speed_data["speed"])
        own_target_encountered = np.inf
        other_agent_encountered = np.inf
        other_target_encountered = np.inf
        potential_conflict = np.inf
        unusable_switch = np.inf
        other_agent_same_direction = 0
        other_agent_opposite_direction = 0
        malfunctioning_agent = 0
        min_fractional_speed = 1.
        num_steps = 1
        other_agent_ready_to_depart_encountered = 0
        found_closest_communication = False
        communication = None

        while exploring:
            # #############################
            # #############################
            # Modify here to compute any useful data required to build the end node's features. This code is called
            # for each cell visited between the previous branching node and the next switch / target / dead-end.
            if position in self.location_has_agent:
                if self.location_has_agent_communication[position] is not None and not found_closest_communication:
                    found_closest_communication = True
                    communication = self.location_has_agent_communication[position]

                if tot_dist < other_agent_encountered:
                    other_agent_encountered = tot_dist

                # Check if any of the observed agents is malfunctioning, store agent with longest duration left
                if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                    malfunctioning_agent = self.location_has_agent_malfunction[position]

                other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0)

                if self.location_has_agent_direction[position] == direction:
                    # Count agents on the branch heading in the same direction.
                    # (location_has_agent_direction is keyed by position only,
                    # so the original (position, direction) lookup always
                    # missed and returned 0.)
                    other_agent_same_direction += 1

                    # Check fractional speed of agents
                    current_fractional_speed = self.location_has_agent_speed[position]
                    if current_fractional_speed < min_fractional_speed:
                        min_fractional_speed = current_fractional_speed

                else:
                    # If no agent in the same direction was found, all agents at that position face the other direction
                    other_agent_opposite_direction += self.location_has_agent[position]

            # Check number of possible transitions for agent and total number of transitions in cell (type)
            cell_transitions = self.env.rail.get_transitions(*position, direction)
            transition_bit = bin(self.env.rail.get_full_transitions(*position))
            total_transitions = transition_bit.count("1")
            crossing_found = False
            if int(transition_bit, 2) == int('1000010000100001', 2):
                crossing_found = True

            # Register possible future conflict
            predicted_time = int(tot_dist * time_per_cell)
            if self.predictor and predicted_time < self.max_prediction_depth:
                int_position = coordinate_to_position(self.env.width, [position])
                if tot_dist < self.max_prediction_depth:

                    pre_step = max(0, predicted_time - 1)
                    post_step = min(self.max_prediction_depth - 1, predicted_time + 1)

                    # Look for conflicting paths at distance tot_dist
                    if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position)
                        for ca in conflicting_agent[0]:
                            pred_dir = self.predicted_dir[predicted_time][ca]
                            if direction != pred_dir \
                                    and cell_transitions[self._reverse_dir(pred_dir)] == 1 \
                                    and tot_dist < potential_conflict:
                                potential_conflict = tot_dist
                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

                    # Look for conflicting paths at distance num_step-1
                    elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                        for ca in conflicting_agent[0]:
                            pred_dir = self.predicted_dir[pre_step][ca]
                            if direction != pred_dir \
                                    and cell_transitions[self._reverse_dir(pred_dir)] == 1 \
                                    and tot_dist < potential_conflict:
                                potential_conflict = tot_dist
                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

                    # Look for conflicting paths at distance num_step+1
                    elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                        for ca in conflicting_agent[0]:
                            pred_dir = self.predicted_dir[post_step][ca]
                            if direction != pred_dir \
                                    and cell_transitions[self._reverse_dir(pred_dir)] == 1 \
                                    and tot_dist < potential_conflict:
                                potential_conflict = tot_dist
                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

            if position in self.location_has_target and position != agent.target:
                if tot_dist < other_target_encountered:
                    other_target_encountered = tot_dist

            if position == agent.target and tot_dist < own_target_encountered:
                own_target_encountered = tot_dist

            # #############################
            # #############################
            if (position[0], position[1], direction) in visited:
                last_is_terminal = True
                break
            visited.add((position[0], position[1], direction))

            # If the target node is encountered, pick that as node. Also, no further branching is possible.
            if np.array_equal(position, self.env.agents[handle].target):
                last_is_target = True
                break

            # Check if crossing is found --> Not an unusable switch
            if crossing_found:
                # Treat the crossing as a straight rail cell
                total_transitions = 2
            num_transitions = np.count_nonzero(cell_transitions)

            exploring = False

            # Detect Switches that can only be used by other agents.
            if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
                unusable_switch = tot_dist

            if num_transitions == 1:
                # Check if dead-end, or if we can go forward along direction
                nbits = total_transitions
                if nbits == 1:
                    # Dead-end!
                    last_is_dead_end = True

                if not last_is_dead_end:
                    # Keep walking through the tree along `direction`
                    exploring = True
                    # convert one-hot encoding to 0,1,2,3
                    direction = np.argmax(cell_transitions)
                    position = get_new_position(position, direction)
                    num_steps += 1
                    tot_dist += 1
            elif num_transitions > 0:
                # Switch detected
                last_is_switch = True
                break

            elif num_transitions == 0:
                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                      position[1], direction)
                last_is_terminal = True
                break

        # `position` is either a terminal node or a switch

        # #############################
        # #############################
        # Modify here to append new / different features for each visited cell!

        if last_is_target:
            dist_to_next_branch = tot_dist
            dist_min_to_target = 0
        elif last_is_terminal:
            dist_to_next_branch = np.inf
            dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
        else:
            dist_to_next_branch = tot_dist
            dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]

        node = CustomTreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered,
                                            dist_other_target_encountered=other_target_encountered,
                                            dist_other_agent_encountered=other_agent_encountered,
                                            dist_potential_conflict=potential_conflict,
                                            dist_unusable_switch=unusable_switch,
                                            dist_to_next_branch=dist_to_next_branch,
                                            dist_min_to_target=dist_min_to_target,
                                            num_agents_same_direction=other_agent_same_direction,
                                            num_agents_opposite_direction=other_agent_opposite_direction,
                                            num_agents_malfunctioning=malfunctioning_agent,
                                            speed_min_fractional=min_fractional_speed,
                                            num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
                                            communication=communication,
                                            childs={})

        # #############################
        # #############################
        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # Get the possible transitions
        possible_transitions = self.env.rail.get_transitions(*position, direction)
        for i, branch_direction in enumerate([(direction + 4 + d) % 4 for d in range(-1, 3)]):
            if last_is_dead_end and self.env.rail.get_transition((*position, direction),
                                                                 (branch_direction + 2) % 4):
                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                # it back
                new_cell = get_new_position(position, (branch_direction + 2) % 4)
                branch_observation, branch_visited = self._explore_branch(handle,
                                                                          new_cell,
                                                                          (branch_direction + 2) % 4,
                                                                          tot_dist + 1,
                                                                          depth + 1)
                node.childs[self.tree_explored_actions_char[i]] = branch_observation
                if len(branch_visited) != 0:
                    visited |= branch_visited
            elif last_is_switch and possible_transitions[branch_direction]:
                new_cell = get_new_position(position, branch_direction)
                branch_observation, branch_visited = self._explore_branch(handle,
                                                                          new_cell,
                                                                          branch_direction,
                                                                          tot_dist + 1,
                                                                          depth + 1)
                node.childs[self.tree_explored_actions_char[i]] = branch_observation
                if len(branch_visited) != 0:
                    visited |= branch_visited
            else:
                # no exploring possible, add just cells with infinity
                node.childs[self.tree_explored_actions_char[i]] = -np.inf

        if depth == self.max_depth:
            node.childs.clear()
        return node, visited
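`CustomTreeObsForRailEnv.Node` is referenced but never defined in the snippet. Judging from the constructor call, it mirrors flatland's `TreeObsForRailEnv.Node` with an extra `communication` field; a plausible stand-in (an assumption, not the original class) is:

import collections

# Hypothetical Node definition matching the fields used in _explore_branch.
Node = collections.namedtuple(
    "Node",
    ["dist_own_target_encountered",
     "dist_other_target_encountered",
     "dist_other_agent_encountered",
     "dist_potential_conflict",
     "dist_unusable_switch",
     "dist_to_next_branch",
     "dist_min_to_target",
     "num_agents_same_direction",
     "num_agents_opposite_direction",
     "num_agents_malfunctioning",
     "speed_min_fractional",
     "num_agents_ready_to_depart",
     "communication",
     "childs"])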
Example #11
    def get_many(self, handles: Optional[List[int]] = None):
        """
        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
        in the `handles` list.
        """

        if handles is None:
            handles = []

        self._conflict_map = {handle: [] for handle in handles}

        if self.predictor:
            self.max_prediction_depth = 0
            self.predicted_pos = {}
            self.predicted_dir = {}
            self.predictions = self.predictor.get()
            if self.predictions:
                for t in range(self.predictor.max_depth + 1):
                    pos_list = []
                    dir_list = []
                    for a in handles:
                        if self.predictions[a] is None:
                            continue
                        pos_list.append(self.predictions[a][t][1:3])
                        dir_list.append(self.predictions[a][t][3])
                    self.predicted_pos.update(
                        {t: coordinate_to_position(self.env.width, pos_list)})
                    self.predicted_dir.update({t: dir_list})
                self.max_prediction_depth = len(self.predicted_pos)
        # Update local lookup table for all agents' positions
        # ignore other agents not in the grid (only status active and done)
        # self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
        #                         agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]}

        self.location_has_agent = {}
        self.location_has_agent_direction = {}
        self.location_has_agent_speed = {}
        self.location_has_agent_malfunction = {}
        self.location_has_agent_ready_to_depart = {}

        for _agent in self.env.agents:
            if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
                    _agent.position:
                self.location_has_agent[tuple(_agent.position)] = 1
                self.location_has_agent_direction[tuple(
                    _agent.position)] = _agent.direction
                self.location_has_agent_speed[tuple(
                    _agent.position)] = _agent.speed_data['speed']
                self.location_has_agent_malfunction[tuple(
                    _agent.position)] = _agent.malfunction_data['malfunction']

            if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
                    _agent.initial_position:
                self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
                    self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1

        obs_dict: Dict = super().get_many(handles)

        if self.use_priority:
            priorities = GreedyGraphColoring.color(
                colors=[1, 0],
                nodes=obs_dict.keys(),
                neighbors=self._conflict_map)

            for handle, obs in obs_dict.items():
                if obs is not None:
                    obs_dict[handle] = obs._replace(
                        dist_own_target_encountered=priorities[handle])

        return obs_dict
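`GreedyGraphColoring.color` is external to the snippet. From its call sites it assigns each node (agent handle) one of the given colors so that conflicting neighbors differ where possible, preferring colors earlier in the list. A minimal greedy sketch under those assumptions:

class GreedyGraphColoring:
    @staticmethod
    def color(colors, nodes, neighbors):
        # Greedily assign the first color not already taken by a neighbor;
        # fall back to the last color if all are taken.
        assignment = {}
        for node in nodes:
            taken = {assignment[n] for n in neighbors.get(node, []) if n in assignment}
            assignment[node] = next((c for c in colors if c not in taken), colors[-1])
        return assignment

# GreedyGraphColoring.color([1, 0], nodes=[0, 1, 2], neighbors={0: [1], 1: [0], 2: []})
# -> {0: 1, 1: 0, 2: 1}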
Example #12
    def detect_conflicts(self, tot_dist, time_per_cell, position,
                         cell_transitions, handle, direction):
        potential_conflict = np.inf
        conflict_handle = None
        predicted_time = int(tot_dist * time_per_cell)
        if self.predictor and predicted_time < self.max_prediction_depth:
            int_position = coordinate_to_position(self.env.width, [position])
            if tot_dist < self.max_prediction_depth:

                pre_step = max(0, predicted_time - 1)
                post_step = min(self.max_prediction_depth - 1,
                                predicted_time + 1)

                # Look for conflicting paths at distance tot_dist
                if int_position in np.delete(
                        self.predicted_pos[predicted_time], handle, 0):
                    conflicting_agent = np.where(
                        self.predicted_pos[predicted_time] == int_position)
                    for ca in conflicting_agent[0]:
                        pred_dir = self.predicted_dir[predicted_time][ca]
                        if direction != pred_dir \
                                and cell_transitions[self._reverse_dir(pred_dir)] == 1 \
                                and tot_dist < potential_conflict:
                            potential_conflict = tot_dist
                            conflict_handle = ca
                        if self.env.agents[ca].status == RailAgentStatus.DONE \
                                and tot_dist < potential_conflict:
                            potential_conflict = tot_dist
                            conflict_handle = ca

                # Look for conflicting paths at distance num_step-1
                elif int_position in np.delete(self.predicted_pos[pre_step],
                                               handle, 0):
                    conflicting_agent = np.where(
                        self.predicted_pos[pre_step] == int_position)
                    for ca in conflicting_agent[0]:
                        pred_dir = self.predicted_dir[pre_step][ca]
                        if direction != pred_dir \
                                and cell_transitions[self._reverse_dir(pred_dir)] == 1 \
                                and tot_dist < potential_conflict:
                            potential_conflict = tot_dist
                            conflict_handle = ca
                        if self.env.agents[ca].status == RailAgentStatus.DONE \
                                and tot_dist < potential_conflict:
                            potential_conflict = tot_dist
                            conflict_handle = ca

                # Look for conflicting paths at distance num_step+1
                elif int_position in np.delete(self.predicted_pos[post_step],
                                               handle, 0):
                    conflicting_agent = np.where(
                        self.predicted_pos[post_step] == int_position)
                    for ca in conflicting_agent[0]:
                        pred_dir = self.predicted_dir[post_step][ca]
                        if direction != pred_dir \
                                and cell_transitions[self._reverse_dir(pred_dir)] == 1 \
                                and tot_dist < potential_conflict:
                            potential_conflict = tot_dist
                            conflict_handle = ca
                        if self.env.agents[ca].status == RailAgentStatus.DONE \
                                and tot_dist < potential_conflict:
                            potential_conflict = tot_dist
                            conflict_handle = ca

        return potential_conflict, conflict_handle
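The `pre_step`/`post_step` offsets widen the match to a one-step window around `predicted_time`, which tolerates the other agent's prediction being one tick ahead or behind. A small sketch of that window, with hypothetical values:

# Hypothetical: prediction steps checked for a conflict at predicted_time,
# clamped to the valid range [0, max_prediction_depth - 1].
max_prediction_depth = 20
predicted_time = 7
pre_step = max(0, predicted_time - 1)                          # 6
post_step = min(max_prediction_depth - 1, predicted_time + 1)  # 8
checked_steps = (pre_step, predicted_time, post_step)          # (6, 7, 8)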
Example #13
    def get_many(self, handles: Optional[List[int]] = None):

        if handles is None:
            handles = []

        self._shortest_path_conflict_map = {handle: [] for handle in handles}
        self._other_path_conflict_map = {handle: [] for handle in handles}

        if self.predictor:
            self.max_prediction_depth = 0
            self.predicted_pos = {}
            self.predicted_dir = {}
            self.predictions = self.predictor.get()
            if self.predictions:
                for t in range(self.predictor.max_depth + 1):
                    pos_list = []
                    dir_list = []
                    for a in handles:
                        if self.predictions[a] is None:
                            continue
                        pos_list.append(self.predictions[a][t][1:3])
                        dir_list.append(self.predictions[a][t][3])
                    self.predicted_pos.update(
                        {t: coordinate_to_position(self.env.width, pos_list)})
                    self.predicted_dir.update({t: dir_list})
                self.max_prediction_depth = len(self.predicted_pos)
        # Update local lookup table for all agents' positions
        # ignore other agents not in the grid (only status active and done)
        # self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
        #                         agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]}

        self.location_has_agent = {}
        self.location_has_agent_direction = {}
        self.location_has_agent_speed = {}
        self.location_has_agent_malfunction = {}
        self.location_has_agent_ready_to_depart = {}

        for _agent in self.env.agents:
            if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
                    _agent.position:
                self.location_has_agent[tuple(_agent.position)] = 1
                self.location_has_agent_direction[tuple(
                    _agent.position)] = _agent.direction
                self.location_has_agent_speed[tuple(
                    _agent.position)] = _agent.speed_data['speed']
                self.location_has_agent_malfunction[tuple(
                    _agent.position)] = _agent.malfunction_data['malfunction']

            if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
                    _agent.initial_position:
                self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
                    self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1

        self._conflict_map = {handle: [] for handle in handles}
        obs_dict = {handle: self.get(handle) for handle in handles}

        # the order of the colors matters
        sp_priorities = GreedyGraphColoring.color(
            colors=[1, 0],
            nodes=obs_dict.keys(),
            neighbors=self._shortest_path_conflict_map)

        op_priorities = GreedyGraphColoring.color(
            colors=[1, 0],
            nodes=obs_dict.keys(),
            neighbors=self._other_path_conflict_map)
        for handle, obs in obs_dict.items():
            if obs is not None:
                obs[0][6] = sp_priorities[handle]
                obs[0][13] = op_priorities[handle]

        if self._asserts:
            # Conflicting agents should never share a priority
            assert all(
                sp_priorities[h] != sp_priorities[ch]
                for h, chs in self._shortest_path_conflict_map.items()
                for ch in chs)
            assert all(
                op_priorities[h] != op_priorities[ch]
                for h, chs in self._other_path_conflict_map.items()
                for ch in chs)

        self._prev_sp_prios = sp_priorities
        self._prev_other_prios = op_priorities
        self._prev_other_path_conflict_map = self._other_path_conflict_map
        self._prev_shortest_path_conflict_map = self._shortest_path_conflict_map

        return obs_dict
def test_coordinate_to_position():
    actual_positions = coordinate_to_position(depth_to_test,
                                              coordinates_to_test)
    expected_positions = positions_to_test
    assert np.array_equal(actual_positions, expected_positions), \
        "converted positions {}, expected {}".format(actual_positions, expected_positions)