コード例 #1
0
    def get(self, handle):
        '''
        Build the prediction for the given agent.

        Returns None once the agent is DONE or DONE_REMOVED. Otherwise returns a
        ``(shortest_path_prediction, deviation_paths_prediction)`` tuple: the first
        element is a ``Prediction`` built from the shortest path truncated to
        ``self.max_depth`` elements, the second is whatever
        ``self.get_deviation_paths`` produces from the untruncated path.

        Side effect: the positions of the shortest-path prediction are stored in
        ``self.env.dev_pred_dict[handle]`` for the GUI.
        '''
        status = self.env.agents[handle].status
        if status in (RailAgentStatus.DONE_REMOVED, RailAgentStatus.DONE):
            return None

        # Shortest path, cut to the prediction horizon.
        path_length, full_path = self.get_shortest_path(handle)
        horizon = self.max_depth
        encoding = self.env.railway_encoding
        edges = encoding.edges_from_path(full_path[:horizon])
        positions = encoding.positions_from_path(full_path[:horizon])
        shortest_path_prediction = Prediction(
            lenght=path_length,
            path=full_path[:horizon],
            edges=edges,
            positions=positions,
        )

        # Deviations are derived from the full, untruncated path.
        deviation_paths_prediction = self.get_deviation_paths(
            handle, path_length, full_path
        )

        # Hand the predicted cells over to the GUI.
        gui_cells = OrderedSet()
        gui_cells.update(shortest_path_prediction.positions)
        self.env.dev_pred_dict[handle] = gui_cells

        return shortest_path_prediction, deviation_paths_prediction
コード例 #2
0
    def check_path_exists(self, start: IntVector2DArray, direction: int,
                          end: IntVector2DArray):
        """
        Depth-first search for a possible path from one cell, entered with a certain
        orientation, to a target cell.

        :param start: Start cell from where we want to check the path
        :param direction: Start direction for the path we are testing
        :param end: Cell that we try to reach from the start cell
        :return: True if a path exists, False otherwise
        """
        # A search state is a (position, heading) pair: the same cell may be
        # reachable with several headings, and each one can lead elsewhere.
        # Only membership is needed, so a plain set suffices.
        visited = set()
        stack = [(start, direction)]
        while stack:
            node = stack.pop()  # LIFO -> depth-first traversal
            position, heading = node

            if Vec2d.is_equal(position, end):
                return True
            if node in visited:
                continue
            visited.add(node)

            moves = self.get_transitions(position[0], position[1], heading)
            for new_heading in range(4):
                if moves[new_heading]:
                    stack.append(
                        (get_new_position(position, new_heading), new_heading))

        return False
コード例 #3
0
 def is_simple_turn(trans):
     """
     Return True if ``trans`` is a simple 90-degree turn (Case 1b or 1c) in any
     of its four orientations.

     NOTE: ``self`` is not a parameter; it is presumably captured from an
     enclosing method scope (used for ``self.transitions.rotate_transition``).
     """
     all_simple_turns = OrderedSet()
     for base_turn in [int('0100000000000010', 2),  # Case 1b (8)  - simple turn right
                       int('0001001000000000', 2)]:  # Case 1c (9)  - simple turn left
         # Include the unrotated base turn as well as its 90/180/270 rotations.
         rotated = base_turn
         all_simple_turns.add(rotated)
         for _ in range(3):
             rotated = self.transitions.rotate_transition(rotated, rotation=90)
             all_simple_turns.add(rotated)
     # The loop variables above are distinct from ``trans`` so the argument is
     # not overwritten before this membership test.
     return trans in all_simple_turns
コード例 #4
0
    def __init__(self):
        super(RailEnvTransitions,
              self).__init__(transitions=self.transition_list)

        # Number of successive 90-degree rotations added on top of each basic
        # transition. Indices 1 and 5 get a single extra rotation (presumably
        # because a further 180-degree turn reproduces them); unlisted indices
        # get none.
        rotations_per_index = dict.fromkeys((2, 4, 6, 7, 8, 9, 10), 3)
        rotations_per_index.update({1: 1, 5: 1})

        # Every valid cell type, precomputed so validation is a membership test.
        self.transitions_all = OrderedSet()
        for index, basic in enumerate(self.transitions):
            variant = basic
            self.transitions_all.add(variant)
            for _ in range(rotations_per_index.get(index, 0)):
                variant = self.rotate_transition(variant, rotation=90)
                self.transitions_all.add(variant)
コード例 #5
0
def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
                            agent_position: Tuple[int, int],
                            rail: GridTransitionMap) -> Set[RailEnvNextAction]:
    """
    Enumerate the move actions (forward, left, right) available to an agent.

    TODO https://gitlab.aicrowd.com/flatland/flatland/issues/299 The implementation could probably be more efficient
    and more elegant. But given the few calls this has no priority now.

    Parameters
    ----------
    agent_direction : Grid4TransitionsEnum
        Direction the agent is currently facing.
    agent_position: Tuple[int,int]
        Cell the agent currently occupies.
    rail : GridTransitionMap
        Rail network supplying the transitions of each cell.

    Returns
    -------
    Set of `RailEnvNextAction` (tuples of (action,position,direction))
        Possible move actions (forward,left,right) and the next position/direction they lead to,
        inserted in left, forward, right order. It is not checked that the next cell is free.
        At a dead end, turning around is reported as MOVE_FORWARD; when exactly one transition
        is available it is reported as MOVE_FORWARD as well.
    """
    result: Set[RailEnvNextAction] = OrderedSet()
    transitions = rail.get_transitions(*agent_position, agent_direction)

    def _add(action, direction):
        # Record the action together with the cell and heading it leads to.
        target_cell = get_new_position(agent_position, direction)
        result.add(RailEnvNextAction(action, target_cell, direction))

    if rail.is_dead_end(agent_position):
        # The only way out of a dead end is back where we came from.
        backwards = (agent_direction + 2) % 4
        if transitions[backwards]:
            _add(RailEnvActions.MOVE_FORWARD, backwards)
        return result

    # Candidate headings relative to the current one: left, straight, right.
    candidates = [(agent_direction + offset) % 4 for offset in (-1, 0, 1)]

    if np.count_nonzero(transitions) == 1:
        # A single way on: whichever it is, it counts as going forward.
        for heading in candidates:
            if transitions[heading]:
                _add(RailEnvActions.MOVE_FORWARD, heading)
        return result

    for heading in candidates:
        if not transitions[heading]:
            continue
        if heading == agent_direction:
            chosen = RailEnvActions.MOVE_FORWARD
        elif heading == (agent_direction + 1) % 4:
            chosen = RailEnvActions.MOVE_RIGHT
        elif heading == (agent_direction - 1) % 4:
            chosen = RailEnvActions.MOVE_LEFT
        else:
            raise Exception("Illegal state")
        _add(chosen, heading)
    return result
コード例 #6
0
    def get(self, handle: int = 0) -> np.ndarray:
        '''
        Lets write a simple observation which just indicates whether or not the own predicted path
        overlaps with other predicted paths at any time. This is useless for the task of navigation but might
        help when looking for conflicts. A more complex implementation can be found in the TreeObsForRailEnv class

        Each agent receives an observation of length 10, where element ``i`` represents prediction step ``i``
        and its value is:
         - 0 if no other agent is predicted to be in the same cell at that step
         - 1 if at least one other agent is predicted to be in the same cell at that step

        Side effect: the observed cells are stored in ``self.env.dev_obs_dict[handle]`` for rendering.

        :param handle: index of the agent
        :return: Observation of handle, an array of shape (10,)
        '''

        observation = np.zeros(10)

        # Track which cells were considered while building the observation so the renderer can show them.
        visited = OrderedSet()
        for _idx in range(10):
            # Cell of the agent's own prediction at this step (columns 1 and 2 of the prediction row).
            x_coord = self.predictions[handle][_idx][1]
            y_coord = self.predictions[handle][_idx][2]
            visited.add((x_coord, y_coord))

            # Positions of all other agents at this step
            # (presumably one encoded position per agent — TODO confirm).
            other_positions = np.delete(self.predicted_pos[_idx], handle, 0)
            if self.predicted_pos[_idx][handle] in other_positions:
                # Another agent is predicted in the same cell at the same time: flag this *step*.
                observation[_idx] = 1

        # Accessed by the renderer to visualize the observation.
        self.env.dev_obs_dict[handle] = visited

        return observation
コード例 #7
0
    def prediction_from_path(self, handle, path):
        """
        Turn a path into a fixed-size prediction array for one agent.

        Row ``i`` of the returned ``(self.max_depth + 1, 5)`` int array is
        ``[i, row, col, direction, action]``. Row 0 describes the agent's current
        (virtual) cell and direction. The agent then advances one path element
        every ``int(1 / speed)`` steps. Once it reaches its target, or the path is
        exhausted, it is predicted to stay put with ``STOP_MOVING``.

        :param handle: index of the agent in ``self.env.agents``
        :param path: sequence of path elements. Element 0 is skipped (presumably the
            agent's current cell). Elements expose ``.position`` and ``.direction``,
            and ``element[2].action`` is used as the action (TODO confirm element type).
        :return: the prediction array, or ``self.empty_prediction`` when the agent is
            in none of READY_TO_DEPART / ACTIVE / DONE (i.e. DONE_REMOVED).
        """
        agent = self.env.agents[handle]

        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:  # agent.status == DONE_REMOVED, prediction must be None
            return self.empty_prediction

        prediction = np.zeros(shape=(self.max_depth + 1, 5), dtype=int)
        agent_virtual_direction = agent.direction
        agent_speed = agent.speed_data["speed"]
        # Number of time steps the agent needs to traverse one cell.
        times_per_cell = int(np.reciprocal(agent_speed))
        # First row describes the current time step.
        # TODO: the action recorded for the current step is a fixed placeholder.
        prediction[0] = [0, *agent_virtual_position,
                         agent_virtual_direction, RailEnvActions.MOVE_FORWARD]

        # If there is a path, drop its first element (the starting cell).
        if path:
            path = path[1:]

        new_direction = agent_virtual_direction
        new_position = agent_virtual_position
        for index in range(1, self.max_depth + 1):
            action = RailEnvActions.MOVE_FORWARD
            # At the target, or out of path: stand still until max_depth is reached.
            # The agent's `moving` flag is deliberately not checked, so a currently
            # stopped agent is still predicted to follow its path.
            if new_position == agent.target or not path:
                prediction[index] = [index, *new_position,
                                     new_direction, RailEnvActions.STOP_MOVING]
                continue

            # Advance to the next path element only every `times_per_cell` steps.
            if index % times_per_cell == 0:
                new_position = path[0].position
                new_direction = path[0].direction
                action = path[0][2].action
                path = path[1:]

            prediction[index] = [index, *new_position, new_direction, action]

        return prediction
コード例 #8
0
    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation build is called.
        Requires distance_map to extract the shortest path.
        Does not take into account future positions of other agents!

        If there is no shortest path, the agent just stands still and stops moving.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent to predict for. ``None`` (the default) predicts for
            all agents; ``0`` is a valid handle like any other.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle; each value is a (max_depth + 1) x 5 array with columns:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here (0 while moving, STOP_MOVING once stopped)
            The prediction at 0 is the current position, direction etc.
            For agents removed from the grid, every row is [time_offset, nan, nan, nan, nan].
        """
        agents = self.env.agents
        if handle is not None:
            agents = [self.env.agents[handle]]
        distance_map: DistanceMap = self.env.distance_map

        shortest_paths = get_shortest_paths(distance_map,
                                            max_depth=self.max_depth)

        prediction_dict = {}
        for agent in agents:

            if agent.status == RailAgentStatus.READY_TO_DEPART:
                agent_virtual_position = agent.initial_position
            elif agent.status == RailAgentStatus.ACTIVE:
                agent_virtual_position = agent.position
            elif agent.status == RailAgentStatus.DONE:
                agent_virtual_position = agent.target
            else:
                # Agent is no longer on the grid: no position/direction/action at any
                # time step. Fill *all* max_depth + 1 rows.
                prediction = np.full((self.max_depth + 1, 5), np.nan)
                prediction[:, 0] = np.arange(self.max_depth + 1)
                prediction_dict[agent.handle] = prediction
                continue

            agent_virtual_direction = agent.direction
            agent_speed = agent.speed_data["speed"]
            # Number of time steps the agent needs to traverse one cell.
            times_per_cell = int(np.reciprocal(agent_speed))
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [
                0, *agent_virtual_position, agent_virtual_direction, 0
            ]

            shortest_path = shortest_paths[agent.handle]

            # if there is a shortest path, remove the initial position
            if shortest_path:
                shortest_path = shortest_path[1:]

            new_direction = agent_virtual_direction
            new_position = agent_virtual_position
            visited = OrderedSet()
            for index in range(1, self.max_depth + 1):
                # if we're at the target (or have no path), stop moving until max_depth is reached
                if new_position == agent.target or not shortest_path:
                    prediction[index] = [
                        index, *new_position, new_direction,
                        RailEnvActions.STOP_MOVING
                    ]
                    # Record the predicted heading, consistent with the prediction row.
                    visited.add((*new_position, new_direction))
                    continue

                # Advance to the next cell only every `times_per_cell` steps.
                if index % times_per_cell == 0:
                    new_position = shortest_path[0].position
                    new_direction = shortest_path[0].direction

                    shortest_path = shortest_path[1:]

                # prediction is ready
                prediction[index] = [index, *new_position, new_direction, 0]
                visited.add((*new_position, new_direction))

            # TODO: very bad side effects for visualization only: hand the dev_pred_dict back instead of setting on env!
            self.env.dev_pred_dict[agent.handle] = visited
            prediction_dict[agent.handle] = prediction

        return prediction_dict
コード例 #9
0
    def get(self,
            handle: int = None,
            handles=None,
            positions=None,
            directions=None):
        """
        Called whenever get_many in the observation build is called.
        Requires distance_map to extract the shortest path.
        Does not take into account future positions of other agents!

        If there is no shortest path, the agent just stands still and stops moving.

        Parameters
        ----------
        handle : int, optional
            Handle of a single agent to predict for (``0`` is a valid handle).
        handles : iterable of int, optional
            Handles of the agents to predict for. Takes precedence over ``handle``
            and is also forwarded to ``get_shortest_paths``.
        positions : dict, optional
            Per-handle override of the cell the prediction starts walking from.
            Row 0 still reports the agent's own (virtual) position.
        directions : dict, optional
            Per-handle heading used together with ``positions``. It must have an
            entry for every handle whose position is overridden.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle; each value is a (max_depth + 1) x 5 array with columns:
            - time_offset
            - position axis 0
            - position axis 1
            - direction (nan while the agent is malfunctioning)
            - action taken to come here (0 while moving, STOP_MOVING once stopped)
            The prediction at 0 is the current position, direction etc.
            Rows after an early stop (target reached, or a switch when ``branch_only``) stay all zeros.
            Agents that are neither ACTIVE nor READY_TO_DEPART get [time_offset, nan, nan, nan, nan] rows.
        """
        agents = self.env.agents
        if handle is not None:
            agents = [self.env.agents[handle]]

        if handles is not None:
            agents = [self.env.agents[h] for h in handles]

        distance_map: DistanceMap = self.env.distance_map

        shortest_paths = get_shortest_paths(distance_map,
                                            handles=handles,
                                            max_depth=self.max_depth,
                                            branch_only=self.branch_only)

        prediction_dict = {}

        # Work on copies: the malfunction countdown below must not modify the real agents.
        agents = deepcopy(agents)
        for agent in agents:

            if agent.status not in (RailAgentStatus.ACTIVE,
                                    RailAgentStatus.READY_TO_DEPART):
                # Not on (or about to enter) the grid: nothing to predict.
                # Fill *all* max_depth + 1 rows.
                prediction = np.full((self.max_depth + 1, 5), np.nan)
                prediction[:, 0] = np.arange(self.max_depth + 1)
                prediction_dict[agent.handle] = prediction
                continue

            agent_virtual_direction = agent.direction
            agent_virtual_position = agent.position
            if agent.status == RailAgentStatus.READY_TO_DEPART:
                agent_virtual_position = agent.initial_position

            agent_speed = agent.speed_data["speed"]
            # Number of time steps the agent needs to traverse one cell.
            times_per_cell = int(np.reciprocal(agent_speed))
            prediction = np.zeros(shape=(self.max_depth + 1, 5))

            prediction[0] = [
                0, *agent_virtual_position, agent_virtual_direction, 0
            ]

            shortest_path = shortest_paths[agent.handle]

            # if there is a shortest path, remove the initial position
            if shortest_path:
                shortest_path = shortest_path[1:]

            if positions is not None and positions.get(agent.handle,
                                                       None) is not None:
                new_direction = directions[agent.handle]
                new_position = positions[agent.handle]
            else:
                new_direction = agent_virtual_direction
                new_position = agent_virtual_position
            visited = OrderedSet()
            for index in range(1, self.max_depth + 1):

                if self.branch_only:
                    # Stop predicting as soon as the agent faces a choice of routes.
                    cell_transitions = self.env.rail.get_transitions(
                        *new_position, new_direction)
                    if np.count_nonzero(cell_transitions) > 1:
                        break

                if not shortest_path:
                    prediction[index] = [
                        index, *new_position, new_direction,
                        RailEnvActions.STOP_MOVING
                    ]
                    visited.add((*new_position, new_direction))
                    continue

                if agent.malfunction_data["malfunction"] > 0:
                    # Broken down: stays in place while the (copied) countdown runs.
                    agent.malfunction_data["malfunction"] -= 1
                    prediction[index] = [
                        index, *new_position, None, RailEnvActions.STOP_MOVING
                    ]
                    visited.add((*new_position, new_direction))
                    continue

                if new_position == agent.target:
                    prediction[index] = [
                        index, *new_position, new_direction,
                        RailEnvActions.STOP_MOVING
                    ]
                    visited.add((*new_position, new_direction))
                    break

                # Advance to the next cell only every `times_per_cell` steps.
                if index % times_per_cell == 0:
                    new_position = shortest_path[0].position
                    new_direction = shortest_path[0].direction

                    shortest_path = shortest_path[1:]

                # prediction is ready
                prediction[index] = [index, *new_position, new_direction, 0]
                visited.add((*new_position, new_direction))

            # Side effect used for visualization only.
            self.env.dev_pred_dict[agent.handle] = visited
            prediction_dict[agent.handle] = prediction

        return prediction_dict
コード例 #10
0
def get_k_shortest_paths(env: RailEnv,
                         source_position: Tuple[int, int],
                         source_direction: int,
                         target_position: Tuple[int, int],
                         k: int = 1,
                         debug=False) -> List[Tuple[Waypoint]]:
    """
    Computes the k shortest paths using modified Dijkstra
    following pseudo-code https://en.wikipedia.org/wiki/K_shortest_path_routing
    In contrast to the pseudo-code in wikipedia, we do not allow for loopy paths.

    Parameters
    ----------
    env :             RailEnv
    source_position:  Tuple[int,int]
    source_direction: int
    target_position:  Tuple[int,int]
        Target cell; a path ends there regardless of the arrival direction.
    k :               int
        max number of shortest paths
    debug:            bool
        print debug statements

    Returns
    -------
    List[Tuple[Waypoint]]
        We use tuples since we need the path elements to be hashable.
        We use a list of paths in order to keep the order of length.
        Among paths of equal length, the one discovered first comes first, so the
        result is deterministic.
    """
    import heapq
    import itertools
    from collections import defaultdict

    # P: list of shortest paths from s to t (P = empty)
    shortest_paths: List[Tuple[Waypoint]] = []

    # count_u: number of shortest paths found to node u = (row, col, direction).
    # count_u = 0 for all u -- created lazily rather than for every cell of the grid.
    count = defaultdict(int)

    # B: min-heap of candidate paths keyed by (cost, insertion order); the cost of a
    # path is its number of waypoints. The insertion counter makes ties resolve to
    # the earliest inserted path (deterministic), and it means heapq never has to
    # compare the path tuples themselves.
    insertion_order = itertools.count()
    heap = []

    def _push(path):
        heapq.heappush(heap, (len(path), next(insertion_order), path))

    # insert path Ps = {s} into B
    _push((Waypoint(source_position, source_direction), ))

    # while B is not empty and count_t < K:
    while heap and len(shortest_paths) < k:
        if debug:
            print("iteration heap={}, shortest_paths={}".format(
                [entry[2] for entry in heap], shortest_paths))
        # – let Pu be the shortest cost path in B with cost C; B = B − {Pu}
        _, _, pu = heapq.heappop(heap)
        u: Waypoint = pu[-1]
        if debug:
            print("  looking at pu={}".format(pu))

        #     – count_u = count_u + 1
        urcd = (*u.position, u.direction)
        count[urcd] += 1

        # – if u = t then P = P U {Pu}
        if u.position == target_position:
            if debug:
                print(" found of length {} {}".format(len(pu), pu))
            shortest_paths.append(pu)

        # – if count_u ≤ K then
        # CAVEAT: do not allow for loopy paths
        elif count[urcd] <= k:
            possible_transitions = env.rail.get_transitions(*urcd)
            if debug:
                print("  looking at neighbors of u={}, transitions are {}".
                      format(u, possible_transitions))
            #     for each vertex v adjacent to u:
            for new_direction in range(4):
                if debug:
                    print("        looking at new_direction={}".format(
                        new_direction))
                if possible_transitions[new_direction]:
                    new_position = get_new_position(u.position, new_direction)
                    if debug:
                        print("        looking at neighbor v={}".format(
                            (*new_position, new_direction)))

                    v = Waypoint(position=new_position,
                                 direction=new_direction)
                    # CAVEAT: do not allow for loopy paths
                    if v in pu:
                        continue

                    # – let Pv be a new path with cost C + w(u, v) formed by concatenating edge (u, v) to path Pu
                    pv = pu + (v, )
                    #     – insert Pv into B
                    _push(pv)

    # return P
    return shortest_paths
コード例 #11
0
    def _explore_branch(self, handle, position, direction, tot_dist, depth):
        """
        Utility function to compute tree-based observations.
        We walk along the branch and collect the information documented in the get() function.
        If there is a branching point a new node is created and each possible branch is explored.

        :param handle: index of the agent the observation is built for
        :param position: cell where the walk along this branch starts
        :param direction: direction the agent is facing when entering ``position``
        :param tot_dist: number of cells accumulated so far along the tree (incremented per step)
        :param depth: current depth in the observation tree
        :return: ``(node, visited)``.
            ``node`` is a ``MyNode`` whose ``childs`` maps entries of
            ``self.tree_explored_actions_char`` to sub-nodes, or to ``-np.inf`` when
            that branch cannot be explored.
            ``visited`` is an ``OrderedSet`` of ``(row, col, direction)`` states walked
            in this branch and its sub-branches.
            Returns ``([], [])`` once ``depth`` exceeds ``self.max_depth``.
        """

        # [Recursive branch opened]
        # Stop recursing below the deepest level of the tree.
        if depth >= self.max_depth + 1:
            return [], []

        # Continue along direction until next switch or
        # until no transitions are possible along the current direction (i.e., dead-ends)
        # We treat dead-ends as nodes, instead of going back, to avoid loops
        exploring = True
        last_is_switch = False
        last_is_dead_end = False
        last_is_terminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
        last_is_target = False

        # (row, col, direction) states walked in this branch; used for cycle detection
        # and returned to the caller.
        visited = OrderedSet()
        agent = self.env.agents[handle]
        # Time steps per cell for this agent; only used for `predicted_time` below.
        time_per_cell = np.reciprocal(agent.speed_data["speed"])
        own_target_encountered = 0
        other_agent_encountered = 0
        other_target_encountered = 0
        # NOTE(review): potential_conflict is never updated in this method, so the node
        # always reports np.inf for it.
        potential_conflict = np.inf
        unusable_switch = np.inf
        other_agent_same_direction = 0
        other_agent_opposite_direction = 0
        malfunctioning_agent = 0
        min_fractional_speed = 1.
        num_steps = 1
        total_cells = 0
        total_switch = 0
        total_switches_neighbors = 0
        other_agent_ready_to_depart_encountered = 0
        # +1.0 if the first other agent met on this branch has a smaller handle than
        # `handle`, -1.0 if larger, 0 if no other agent was met.
        index_comparision = 0
        # For the first switch / switch-neighbour cell met on this branch:
        # 1.0 if some agent occupies it, -1.0 if not, 0 if none was met.
        first_switch_free = 0
        first_switch_neighbor = 0
        while exploring:
            # Per-cell counters.
            total_cells += 1
            if len(self.switches_list) > 0:
                if position in self.switches_list:
                    total_switch += 1
            if len(self.switches_neightbors_list) > 0:
                if position in self.switches_neightbors_list:
                    total_switches_neighbors += 1

            # Handle of the agent occupying this cell, or -1 if free.
            opp_a = self.env.agent_positions[position]
            if opp_a > -1 and opp_a != handle and index_comparision == 0:
                index_comparision = 2.0 * int(opp_a < handle) - 1.0

            # NOTE(review): the two flags below are 1.0 when the cell is *occupied*
            # (by any agent, including `handle` itself since opp_a == handle is not
            # excluded); the name "free" suggests the opposite -- confirm intent.
            if first_switch_free == 0:
                if position in self.switches_list:
                    if opp_a > -1:
                        first_switch_free = 1.0
                    else:
                        first_switch_free = -1.0

            if first_switch_neighbor == 0:
                if position in self.switches_neightbors_list:
                    if opp_a > -1:
                        first_switch_neighbor = 1.0
                    else:
                        first_switch_neighbor = -1.0

            # #############################
            # #############################
            # Modify here to compute any useful data required to build the end node's features. This code is called
            # for each cell visited between the previous branching node and the next switch / target / dead-end.
            if position in self.location_has_agent:

                # Check if any of the observed agents is malfunctioning, store agent with longest duration left
                if self.location_has_agent_malfunction[
                        position] > malfunctioning_agent:
                    malfunctioning_agent = self.location_has_agent_malfunction[
                        position]

                other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(
                    position, 0)

                # Only agents heading the same way contribute to the minimum speed.
                if self.location_has_agent_direction[position] == direction:
                    # Cummulate the number of agents on branch with same direction
                    # other_agent_same_direction += self.location_has_agent_direction.get((position, direction), 0)

                    # Check fractional speed of agents
                    current_fractional_speed = self.location_has_agent_speed[
                        position]
                    if current_fractional_speed < min_fractional_speed:
                        min_fractional_speed = current_fractional_speed

            # possible_transitions = self.envs.rail.get_transitions(*position, direction)
            # Count cells occupied by another agent and other agents' targets on this branch.
            x = self.env.agent_positions[position]
            if x != -1:
                if x != handle:
                    other_agent_encountered += 1.0
            if position in self.location_has_target and position != agent.target:
                other_target_encountered += 1.0

            if position == agent.target:
                own_target_encountered = 1.0

            # NOTE(review): the transition lookup is disabled -- num_transitions is
            # hard-coded to 1 and new_cell_me is the *current* cell, so this inspects the
            # other agent standing on `position` itself, not on the next cell.
            num_transitions = 1  # np.count_nonzero(possible_transitions)
            if num_transitions == 1:
                new_direction_me = direction  # np.argmax(possible_transitions)
                new_cell_me = position  # get_new_position(position, new_direction_me)
                a = self.env.agent_positions[new_cell_me]
                if a != -1 and a != handle:
                    opp_agent = self.env.agents[a]
                    # look one step forward
                    # If the other agent cannot move in our direction from its own cell and
                    # heading, it is counted as opposite-direction, the cell is recorded in
                    # self.opposite, and the walk along this branch ends.
                    possible_transitions = self.env.rail.get_transitions(
                        *opp_agent.position, opp_agent.direction)
                    if possible_transitions[new_direction_me] == 0:
                        self.opposite.append(position)
                        other_agent_opposite_direction += 1
                        break  # path is broken
                    else:
                        other_agent_same_direction += 1

            # Check number of possible transitions for agent and total number of transitions in cell (type)
            cell_transitions = self.env.rail.get_transitions(
                *position, direction)
            # bin() yields a '0b...' string: count("1") gives the number of set bits, and
            # int(..., 2) accepts the '0b' prefix. '1000010000100001' is the diamond crossing.
            transition_bit = bin(self.env.rail.get_full_transitions(*position))
            total_transitions = transition_bit.count("1")
            crossing_found = False
            if int(transition_bit, 2) == int('1000010000100001', 2):
                crossing_found = True

            # Register possible future conflict
            # NOTE(review): predicted_time is computed but never used.
            predicted_time = int(tot_dist * time_per_cell)

            # #############################
            # #############################
            # Revisiting the same (cell, direction) means a cycle: stop here as terminal.
            if (position[0], position[1], direction) in visited:
                last_is_terminal = True
                break
            visited.add((position[0], position[1], direction))

            # If the target node is encountered, pick that as node. Also, no further branching is possible.
            if np.array_equal(position, self.env.agents[handle].target):
                last_is_target = True
                break

            # Check if crossing is found --> Not an unusable switch
            if crossing_found:
                # Treat the crossing as a straight rail cell
                total_transitions = 2
            num_transitions = np.count_nonzero(cell_transitions)

            exploring = False

            # Detect Switches that can only be used by other agents.
            # Chained comparison: total_transitions > 2 and num_transitions < 2.
            if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
                unusable_switch = tot_dist

            if num_transitions == 1:
                # Check if dead-end, or if we can go forward along direction
                nbits = total_transitions
                if nbits == 1:
                    # Dead-end!
                    last_is_dead_end = True

                if not last_is_dead_end:
                    # Keep walking through the tree along `direction`
                    exploring = True
                    # convert one-hot encoding to 0,1,2,3
                    direction = np.argmax(cell_transitions)
                    position = get_new_position(position, direction)
                    num_steps += 1
                    tot_dist += 1
            elif num_transitions > 0:
                # Switch detected
                last_is_switch = True
                break

            elif num_transitions == 0:
                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                print(
                    "WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell",
                    position[0], position[1], direction)
                last_is_terminal = True
                break

        # `position` is either a terminal node or a switch

        # #############################
        # #############################
        # Modify here to append new / different features for each visited cell!

        if last_is_target:
            dist_to_next_branch = tot_dist
            dist_min_to_target = 0
        elif last_is_terminal:
            dist_to_next_branch = np.inf
            dist_min_to_target = self.env.distance_map.get()[handle,
                                                             position[0],
                                                             position[1],
                                                             direction]
        else:
            dist_to_next_branch = tot_dist
            dist_min_to_target = self.env.distance_map.get()[handle,
                                                             position[0],
                                                             position[1],
                                                             direction]

        node = MyNode(
            dist_own_target_encountered=own_target_encountered,
            dist_other_target_encountered=other_target_encountered,
            dist_other_agent_encountered=other_agent_encountered,
            dist_potential_conflict=potential_conflict,
            dist_unusable_switch=unusable_switch,
            dist_to_next_branch=dist_to_next_branch,
            dist_min_to_target=dist_min_to_target,
            num_agents_same_direction=other_agent_same_direction,
            num_agents_opposite_direction=other_agent_opposite_direction,
            num_agents_malfunctioning=malfunctioning_agent,
            speed_min_fractional=min_fractional_speed,
            num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
            total_cells=total_cells,
            total_switch=total_switch,
            total_switches_neighbors=total_switches_neighbors,
            index_comparision=index_comparision,
            first_switch_free=first_switch_free,
            first_switch_neighbor=first_switch_neighbor,
            childs={})

        # #############################
        # #############################
        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # Get the possible transitions
        possible_transitions = self.env.rail.get_transitions(
            *position, direction)

        # branch_direction runs over direction-1, direction, direction+1, direction+2 (mod 4);
        # i indexes the matching label in self.tree_explored_actions_char.
        for i, branch_direction in enumerate([(direction + 4 + i) % 4
                                              for i in range(-1, 3)]):
            if last_is_dead_end and self.env.rail.get_transition(
                (*position, direction), (branch_direction + 2) % 4):
                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                # it back
                new_cell = get_new_position(position,
                                            (branch_direction + 2) % 4)
                branch_observation, branch_visited = self._explore_branch(
                    handle, new_cell, (branch_direction + 2) % 4, tot_dist + 1,
                    depth + 1)
                node.childs[
                    self.tree_explored_actions_char[i]] = branch_observation
                if len(branch_visited) != 0:
                    visited |= branch_visited
            elif last_is_switch and possible_transitions[branch_direction]:
                new_cell = get_new_position(position, branch_direction)
                branch_observation, branch_visited = self._explore_branch(
                    handle, new_cell, branch_direction, tot_dist + 1,
                    depth + 1)
                node.childs[
                    self.tree_explored_actions_char[i]] = branch_observation
                if len(branch_visited) != 0:
                    visited |= branch_visited
            else:
                # no exploring possible, add just cells with infinity
                node.childs[self.tree_explored_actions_char[i]] = -np.inf

        # Nodes at the maximum depth carry no children.
        if depth == self.max_depth:
            node.childs.clear()
        return node, visited
コード例 #12
0
class RailEnvTransitions(Grid4Transitions):
    """
    Special case of `GridTransitions` over a 2D-grid, with a pre-defined set
    of transitions mimicking the types of real Swiss rail connections.

    As no diagonal transitions are allowed in the RailEnv environment, the
    possible transitions for RailEnv from a cell to its neighboring ones
    are represented over 16 bits.

    The 16 bits are organized in 4 blocks of 4 bits each, one block per
    direction the agent is facing.
    E.g., the most-significant 4 bits represent the possible movements (NESW)
    if the agent is facing North, etc...

    agent's direction:          North    East   South   West
    agent's allowed movements:  [nesw]   [nesw] [nesw]  [nesw]
    example:                     1000     0000   0010    0000

    In the example, the agent can move from North to South and vice versa.
    """

    # Contains the basic transitions;
    # the set of all valid transitions is obtained by successive 90-degree rotation of one of these basic transitions.
    transition_list = [
        int('0000000000000000', 2),  # empty cell - Case 0
        int('1000000000100000', 2),  # Case 1 - straight
        int('1001001000100000', 2),  # Case 2 - simple switch
        int('1000010000100001', 2),  # Case 3 - diamond crossing
        int('1001011000100001', 2),  # Case 4 - single slip
        int('1100110000110011', 2),  # Case 5 - double slip
        int('0101001000000010', 2),  # Case 6 - symmetrical
        int('0010000000000000', 2),  # Case 7 - dead end
        int('0100000000000010', 2),  # Case 1b (8)  - simple turn right
        int('0001001000000000', 2),  # Case 1c (9)  - simple turn left
        int('1100000000100010', 2),  # Case 2b (10) - simple switch mirrored
    ]

    def __init__(self):
        super(RailEnvTransitions,
              self).__init__(transitions=self.transition_list)

        # Precompute the set of every valid cell value (each basic case plus its rotations)
        # so that `is_valid` is a single membership test.
        self.transitions_all = OrderedSet()
        for index, trans in enumerate(self.transitions):
            self.transitions_all.add(trans)
            if index in (2, 4, 6, 7, 8, 9, 10):
                # Cases without rotational symmetry: add the 90, 180 and 270 degree rotations.
                for _ in range(3):
                    trans = self.rotate_transition(trans, rotation=90)
                    self.transitions_all.add(trans)
            elif index in (1, 5):
                # Straight track and double slip: only the 90-degree rotation is added,
                # presumably because their 180-degree rotation equals the original.
                trans = self.rotate_transition(trans, rotation=90)
                self.transitions_all.add(trans)
            # Cases 0 (empty) and 3 (diamond crossing) are added without rotations.

    def print(self, cell_transition):
        """Print the 16-bit `cell_transition` as a 4x4 table: one row per facing direction (N, E, S, W),
        one column per allowed movement (NESW)."""
        print("  NESW")
        print("N", format(cell_transition >> (3 * 4) & 0xF, '04b'))
        print("E", format(cell_transition >> (2 * 4) & 0xF, '04b'))
        print("S", format(cell_transition >> (1 * 4) & 0xF, '04b'))
        print("W", format(cell_transition >> (0 * 4) & 0xF, '04b'))

    def is_valid(self, cell_transition: int) -> bool:
        """
        Checks if a cell transition is a valid cell setup.

        Parameters
        ----------
        cell_transition : int
            16 bits used to encode the valid transitions for a cell.

        Returns
        -------
        Boolean
            True if `cell_transition` is one of the basic cases or one of their
            precomputed rotations, False otherwise.
        """
        return cell_transition in self.transitions_all
コード例 #13
0
    def get(self, handle: int = 0) -> Node:
        """
        Build the tree observation for the agent at index `handle`.

        The tree is rooted at the agent's "virtual" position:
        - its initial position while READY_TO_DEPART;
        - its current position while ACTIVE;
        - its target once DONE or DONE_REMOVED.

        Up to four branches are explored from that cell, ordered [left, forward, right, back] relative to
        the agent's orientation. Branches without a valid transition are filled with -np.inf. The set of
        cells visited while exploring is stored in `self.env.dev_obs_dict[handle]`.

        :param handle: index of the agent in `self.env.agents`
        :return: the root `Node` of the observation tree, or None for any other agent status
        """
        # handle == len(agents) is already out of range, so report it too (the check used to be `>`).
        if handle >= len(self.env.agents):
            print("ERROR: obs _get - handle ", handle, " len(agents)",
                  len(self.env.agents))
        agent = self.env.agents[handle]

        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status in (RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED):
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(
            *agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Here information about the agent itself is stored
        distance_map = self.env.distance_map.get()

        # was referring to TreeObsForRailEnv.Node
        root_node_observation = Node(
            dist_own_target_encountered=0,
            dist_other_target_encountered=0,
            dist_other_agent_encountered=0,
            dist_potential_conflict=0,
            dist_unusable_switch=0,
            dist_to_next_branch=0,
            dist_min_to_target=distance_map[(handle, *agent_virtual_position,
                                             agent.direction)],
            num_agents_same_direction=0,
            num_agents_opposite_direction=0,
            num_agents_malfunctioning=agent.malfunction_data['malfunction'],
            speed_min_fractional=agent.speed_data['speed'],
            num_agents_ready_to_depart=0,
            childs={})

        visited = OrderedSet()

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
        orientation = agent.direction

        if num_transitions == 1:
            orientation = np.argmax(possible_transitions)

        for i, branch_direction in enumerate([(orientation + i) % 4
                                              for i in range(-1, 3)]):

            if possible_transitions[branch_direction]:
                new_cell = get_new_position(agent_virtual_position,
                                            branch_direction)

                branch_observation, branch_visited = \
                    self._explore_branch(handle, new_cell, branch_direction, 1, 1)
                root_node_observation.childs[
                    self.tree_explored_actions_char[i]] = branch_observation

                visited |= branch_visited
            else:
                # add cells filled with infinity if no transition is possible
                root_node_observation.childs[
                    self.tree_explored_actions_char[i]] = -np.inf
        self.env.dev_obs_dict[handle] = visited

        return root_node_observation
コード例 #14
0
    def get(self, handle: int = 0) -> Node:
        """
        Computes the current observation for agent `handle` in env

        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
        the transitions. The order is::

            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']

        Each branch data is organized as::

            [root node information] +
            [recursive branch data from 'left'] +
            [... from 'forward'] +
            [... from 'right'] +
            [... from 'back']

        Each node information is composed of 12 features:

        #1:
            if own target lies on the explored branch the current distance from the agent in number of cells is stored.

        #2:
            if another agent's target is detected the distance in number of cells from the agent's current location\
            is stored

        #3:
            if another agent is detected the distance in number of cells from current agent position is stored.

        #4:
            possible conflict detected
            tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the \
             distance in number of cells from current agent position

            0 = No other agent reserve the same cell at similar time

        #5:
            if an unusable switch (for agent) is detected we store the distance.

        #6:
            This feature stores the distance in number of cells to the next branching  (current node)

        #7:
            minimum distance from node to the agent's target given the direction of the agent if this path is chosen

        #8:
            agent in the same direction
            n = number of agents present same direction \
                (possible future use: number of other agents in the same direction in this branch)
            0 = no agent present same direction

        #9:
            agent in the opposite direction
            n = number of agents present other direction than myself (so conflict) \
                (possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
            0 = no agent present other direction than myself

        #10:
            malfunctioning/blocking agents
            n = number of time steps the observed agent remains blocked

        #11:
            slowest observed speed of an agent in same direction
            1 if no agent is observed

            min_fractional speed otherwise
        #12:
            number of agents ready to depart but not yet active

        Missing/padding nodes are filled in with -inf (truncated).
        Missing values in present node are filled in with +inf (truncated).

        After the tree is built, starting below the root and for up to `self.max_depth` levels, the child with
        the smallest `dist_min_to_target` gets `shortest_path_direction=1.`.

        The root node has all distance/count features set to 0 except `dist_min_to_target` (distance from the
        agent to its target), `num_agents_malfunctioning` (own malfunction) and `num_agents_ready_to_depart`
        (1 if the agent itself has not departed yet, else 0).

        :param handle: index of the agent in `self.env.agents`
        :return: the root `Node`, or None if the agent's status is none of READY_TO_DEPART, ACTIVE, DONE,
            DONE_REMOVED
        """

        # handle == len(agents) is already out of range, so report it too (the check used to be `>`).
        if handle >= len(self.env.agents):
            print("ERROR: obs _get - handle ", handle, " len(agents)",
                  len(self.env.agents))
        agent = self.env.agents[handle]  # TODO: handle being treated as index

        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status in (RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED):
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(
            *agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Here information about the agent itself is stored
        distance_map = self.env.distance_map.get()
        agent_not_started = int(
            agent.status.value == RailAgentStatus.READY_TO_DEPART.value)

        # TODO: this is hacky, find a better way
        root_node_observation = Node(
            dist_own_target_encountered=0,
            dist_other_target_encountered=0,
            own_target_encountered=0,
            shortest_path_direction=0,
            dist_other_agent_encountered=0,
            dist_potential_conflict=0,
            dist_unusable_switch=0,
            dist_to_next_branch=0,
            dist_min_to_target=distance_map[(handle, *agent_virtual_position,
                                             agent.direction)],
            num_agents_same_direction=0,
            num_agents_opposite_direction=0,
            num_agents_malfunctioning=agent.malfunction_data['malfunction'],
            num_agents_ready_to_depart=agent_not_started,
            childs={})

        visited = OrderedSet()

        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
        orientation = agent.direction

        if num_transitions == 1:
            orientation = np.argmax(possible_transitions)

        # Conflict handles reported by the branch that leads closest to the target.
        min_dist = np.inf
        conflict_handles = []

        for i, branch_direction in enumerate([(orientation + i) % 4
                                              for i in range(-1, 3)]):

            if possible_transitions[branch_direction]:
                new_cell = get_new_position(agent_virtual_position,
                                            branch_direction)

                branch_observation, branch_visited, c_handles = self._explore_branch(
                    handle, new_cell, branch_direction, 1, 1)
                root_node_observation.childs[
                    self.tree_explored_actions_char[i]] = branch_observation

                visited |= branch_visited

                if branch_observation.dist_min_to_target < min_dist:
                    min_dist = branch_observation.dist_min_to_target
                    conflict_handles = c_handles
            else:
                # add cells filled with infinity if no transition is possible
                root_node_observation.childs[
                    self.tree_explored_actions_char[i]] = -np.inf
        self.env.dev_obs_dict[handle] = visited

        for ch in conflict_handles:
            self._conflict_map[handle].append(ch)

        # Mark the nodes on the shortest path, one tree level at a time.
        # Node is a namedtuple: `_replace` returns a modified copy, which must be written back into the
        # parent's `childs` dict or the flag is lost. The copy shares the original's `childs` dict, so
        # descending into it still walks the real tree.
        parent = root_node_observation
        for _ in range(self.max_depth):
            candidates = [
                (action, child) for action, child in parent.childs.items() if child != -np.inf
            ]
            if not candidates:
                break
            action, shortest_path_node = min(candidates, key=lambda item: item[1].dist_min_to_target)
            shortest_path_node = shortest_path_node._replace(
                shortest_path_direction=1.)
            parent.childs[action] = shortest_path_node
            parent = shortest_path_node

        return root_node_observation
コード例 #15
0
    def _explore_branch(self, handle, position, direction, tot_dist, depth):
        """
        Utility function to compute tree-based observations.

        Walk from `position` along `direction` and collect the information documented in the get() function.
        The walk stops at the first of:
        - a switch;
        - the agent's own target;
        - a dead-end;
        - an already-visited (row, col, direction) triple;
        - a cell with no possible transitions.

        The stopping cell becomes a node. At a switch (or a dead-end, where the agent is turned back), each
        usable branch is explored recursively with `depth + 1`.

        :param handle: index of the observing agent in `self.env.agents`
        :param position: (row, col) of the first cell of the branch
        :param direction: direction in which the branch is entered
        :param tot_dist: distance in cells from the agent to `position`
        :param depth: depth of the node being built in the observation tree
        :return: (node, visited), where `visited` is an OrderedSet of the (row, col, direction) triples
            walked in this subtree; ([], []) when `depth` exceeds `self.max_depth`
        """

        # [Recursive branch opened]
        if depth >= self.max_depth + 1:
            return [], []

        # Continue along direction until next switch or
        # until no transitions are possible along the current direction (i.e., dead-ends)
        # We treat dead-ends as nodes, instead of going back, to avoid loops
        exploring = True
        last_is_switch = False
        last_is_dead_end = False
        last_is_terminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
        last_is_target = False

        visited = OrderedSet()
        agent = self.env.agents[handle]
        time_per_cell = np.reciprocal(agent.speed_data["speed"])
        own_target_encountered = np.inf
        other_agent_encountered = np.inf
        other_target_encountered = np.inf
        potential_conflict = np.inf
        unusable_switch = np.inf
        other_agent_same_direction = 0
        other_agent_opposite_direction = 0
        malfunctioning_agent = 0
        min_fractional_speed = 1.
        other_agent_ready_to_depart_encountered = 0
        found_closest_communication = False
        communication = None

        while exploring:
            # #############################
            # #############################
            # Modify here to compute any useful data required to build the end node's features. This code is called
            # for each cell visited between the previous branching node and the next switch / target / dead-end.
            if position in self.location_has_agent:
                # Keep only the communication of the first (closest) agent on this branch that has one.
                if self.location_has_agent_communication[position] is not None and not found_closest_communication:
                    found_closest_communication = True
                    communication = self.location_has_agent_communication[position]

                if tot_dist < other_agent_encountered:
                    other_agent_encountered = tot_dist

                # Check if any of the observed agents is malfunctioning, store agent with longest duration left
                if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                    malfunctioning_agent = self.location_has_agent_malfunction[position]

                other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0)

                if self.location_has_agent_direction[position] == direction:
                    # Cumulate the number of agents on branch with same direction
                    # NOTE(review): this looks up a (position, direction) key while the test above indexes the
                    # same dict by position alone; confirm which key layout the dict actually uses.
                    other_agent_same_direction += self.location_has_agent_direction.get((position, direction), 0)

                    # Check fractional speed of agents
                    current_fractional_speed = self.location_has_agent_speed[position]
                    if current_fractional_speed < min_fractional_speed:
                        min_fractional_speed = current_fractional_speed

                    # Other direction agents
                    # TODO: Test that this behavior is as expected
                    other_agent_opposite_direction += \
                        self.location_has_agent[position] - self.location_has_agent_direction.get((position, direction),
                                                                                                  0)

                else:
                    # If no agent in the same direction was found all agents in that position are other direction
                    other_agent_opposite_direction += self.location_has_agent[position]

            # Check number of possible transitions for agent and total number of transitions in cell (type)
            cell_transitions = self.env.rail.get_transitions(*position, direction)
            transition_bit = bin(self.env.rail.get_full_transitions(*position))
            total_transitions = transition_bit.count("1")
            crossing_found = False
            if int(transition_bit, 2) == int('1000010000100001', 2):
                crossing_found = True

            # Register possible future conflict
            predicted_time = int(tot_dist * time_per_cell)
            if self.predictor and predicted_time < self.max_prediction_depth:
                int_position = coordinate_to_position(self.env.width, [position])
                if tot_dist < self.max_prediction_depth:

                    pre_step = max(0, predicted_time - 1)
                    post_step = min(self.max_prediction_depth - 1, predicted_time + 1)

                    # Look for conflicting paths at distance tot_dist
                    if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position)
                        for ca in conflicting_agent[0]:
                            if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[
                                self._reverse_dir(
                                    self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
                                potential_conflict = tot_dist
                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

                    # Look for conflicting paths at distance num_step-1
                    elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                        for ca in conflicting_agent[0]:
                            if direction != self.predicted_dir[pre_step][ca] \
                                and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
                                and tot_dist < potential_conflict:  # noqa: E125
                                potential_conflict = tot_dist
                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

                    # Look for conflicting paths at distance num_step+1
                    elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                        for ca in conflicting_agent[0]:
                            if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
                                self.predicted_dir[post_step][ca])] == 1 \
                                and tot_dist < potential_conflict:  # noqa: E125
                                potential_conflict = tot_dist
                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

            if position in self.location_has_target and position != agent.target:
                if tot_dist < other_target_encountered:
                    other_target_encountered = tot_dist

            if position == agent.target and tot_dist < own_target_encountered:
                own_target_encountered = tot_dist

            # #############################
            # #############################
            # Revisiting the same cell in the same direction means a cycle: stop here.
            if (position[0], position[1], direction) in visited:
                last_is_terminal = True
                break
            visited.add((position[0], position[1], direction))

            # If the target node is encountered, pick that as node. Also, no further branching is possible.
            if np.array_equal(position, self.env.agents[handle].target):
                last_is_target = True
                break

            # Check if crossing is found --> Not an unusable switch
            if crossing_found:
                # Treat the crossing as a straight rail cell
                total_transitions = 2
            num_transitions = np.count_nonzero(cell_transitions)

            exploring = False

            # Detect Switches that can only be used by other agents.
            if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
                unusable_switch = tot_dist

            if num_transitions == 1:
                # Check if dead-end, or if we can go forward along direction
                if total_transitions == 1:
                    # Dead-end!
                    last_is_dead_end = True
                else:
                    # Keep walking through the tree along `direction`
                    exploring = True
                    # convert one-hot encoding to 0,1,2,3
                    direction = np.argmax(cell_transitions)
                    position = get_new_position(position, direction)
                    tot_dist += 1
            elif num_transitions > 0:
                # Switch detected
                last_is_switch = True
                break

            elif num_transitions == 0:
                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                      position[1], direction)
                last_is_terminal = True
                break

        # `position` is either a terminal node or a switch

        # #############################
        # #############################
        # Modify here to append new / different features for each visited cell!

        if last_is_target:
            dist_to_next_branch = tot_dist
            dist_min_to_target = 0
        elif last_is_terminal:
            dist_to_next_branch = np.inf
            dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
        else:
            dist_to_next_branch = tot_dist
            dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]

        node = CustomTreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered,
                                            dist_other_target_encountered=other_target_encountered,
                                            dist_other_agent_encountered=other_agent_encountered,
                                            dist_potential_conflict=potential_conflict,
                                            dist_unusable_switch=unusable_switch,
                                            dist_to_next_branch=dist_to_next_branch,
                                            dist_min_to_target=dist_min_to_target,
                                            num_agents_same_direction=other_agent_same_direction,
                                            num_agents_opposite_direction=other_agent_opposite_direction,
                                            num_agents_malfunctioning=malfunctioning_agent,
                                            speed_min_fractional=min_fractional_speed,
                                            num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
                                            communication=communication,
                                            childs={})

        # #############################
        # #############################
        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # Get the possible transitions
        possible_transitions = self.env.rail.get_transitions(*position, direction)
        for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
            if last_is_dead_end and self.env.rail.get_transition((*position, direction),
                                                                 (branch_direction + 2) % 4):
                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                # it back
                new_cell = get_new_position(position, (branch_direction + 2) % 4)
                branch_observation, branch_visited = self._explore_branch(handle,
                                                                          new_cell,
                                                                          (branch_direction + 2) % 4,
                                                                          tot_dist + 1,
                                                                          depth + 1)
                node.childs[self.tree_explored_actions_char[i]] = branch_observation
                if len(branch_visited) != 0:
                    visited |= branch_visited
            elif last_is_switch and possible_transitions[branch_direction]:
                new_cell = get_new_position(position, branch_direction)
                branch_observation, branch_visited = self._explore_branch(handle,
                                                                          new_cell,
                                                                          branch_direction,
                                                                          tot_dist + 1,
                                                                          depth + 1)
                node.childs[self.tree_explored_actions_char[i]] = branch_observation
                if len(branch_visited) != 0:
                    visited |= branch_visited
            else:
                # no exploring possible, add just cells with infinity
                node.childs[self.tree_explored_actions_char[i]] = -np.inf

        # Nodes at the deepest level are leaves: drop any children collected above.
        if depth == self.max_depth:
            node.childs.clear()
        return node, visited
コード例 #16
0
def a_star(grid_map: GridTransitionMap,
           start: IntVector2D,
           end: IntVector2D,
           a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
           avoid_rails=False,
           respect_transition_validity=True,
           forbidden_cells: IntVector2DArray = None) -> IntVector2DArray:
    """
    Search a path from ``start`` to ``end`` on ``grid_map`` with A*.

    Moves go to the four orthogonal neighbours of a cell and every step costs 1.

    :param grid_map: Grid map in which the path is searched.
    :param start: Start position as (row, column).
    :param end: End position as (row, column).
    :param a_star_distance_function: Distance function used as heuristic, e.g.
            - Vec2d.get_euclidean_distance
            - Vec2d.get_manhattan_distance (default)
            - Vec2d.get_chebyshev_distance
    :param avoid_rails: If True, the heuristic of every cell with a positive grid entry is
            increased by 1, steering the search away from those cells (presumably cells that
            already hold rail).
    :param respect_transition_validity: Whether A* respects the allowed transitions on the grid map.
            - True: only valid transitions are used. This yields a valid path, or no path if none exists.
            - False: transition validity is ignored, so a path is found unless the grid bounds or
              forbidden cells prevent it. The path might be illegal and need to be fixed afterwards.
    :param forbidden_cells: Cells the path may not pass through (start and end are exempt).
            Used to keep the path away from certain areas of the grid map.
    :return: The ordered list of cell positions from ``start`` to ``end`` (both included) if a
            path is found, otherwise an empty list.
    """
    rail_shape = grid_map.grid.shape

    start_node = AStarNode(start, None)
    end_node = AStarNode(end, None)
    open_nodes = OrderedSet()
    closed_nodes = OrderedSet()
    open_nodes.add(start_node)

    while open_nodes:
        # Expand the open node with the lowest estimated total cost f.
        # On ties, min() keeps the node that comes first in the open set.
        current_node = min(open_nodes, key=lambda node: node.f)

        # Move the current node from the open set to the closed set.
        open_nodes.remove(current_node)
        closed_nodes.add(current_node)

        # Goal reached: walk the parent chain back to the start and reverse it.
        if current_node == end_node:
            path = []
            node = current_node
            while node is not None:
                path.append(node.pos)
                node = node.parent
            return path[::-1]

        prev_pos = current_node.parent.pos if current_node.parent is not None else None

        # Generate the orthogonal neighbours that may be entered from the current node.
        children = []
        for offset in ((0, -1), (0, 1), (-1, 0), (1, 0)):
            node_pos: IntVector2D = Vec2d.add(current_node.pos, offset)

            # Stay inside the grid.
            if not (0 <= node_pos[0] < rail_shape[0] and 0 <= node_pos[1] < rail_shape[1]):
                continue

            # Only consult the grid map when transition validity has to be respected.
            if respect_transition_validity and not grid_map.validate_new_transition(
                    prev_pos, current_node.pos, node_pos, end_node.pos):
                continue

            new_node = AStarNode(node_pos, current_node)

            # Skip paths through forbidden regions if they are provided; start and end stay allowed.
            if forbidden_cells is not None and node_pos in forbidden_cells \
                    and new_node != start_node and new_node != end_node:
                continue

            children.append(new_node)

        for child in children:
            # Already expanded.
            if child in closed_nodes:
                continue

            # Unit step cost plus heuristic estimate to the end.
            child.g = current_node.g + 1.0
            if avoid_rails:
                # Penalize cells with a positive grid entry by 1.
                child.h = a_star_distance_function(child.pos, end_node.pos) + np.clip(
                    grid_map.grid[child.pos], 0, 1)
            else:
                child.h = a_star_distance_function(child.pos, end_node.pos)
            child.f = child.g + child.h

            # A node already queued keeps the cost it was first queued with.
            if child in open_nodes:
                continue

            open_nodes.add(child)

    # Open set exhausted without reaching the end: no path exists.
    return []
コード例 #17
0
    def get(self, handle: int = 0) -> Node:
        """
        Computes the current observation for agent `handle` in env.

        The observation is a tree rooted at the agent. The root has one child per possible movement in
        a RailEnv (up to 4, because only a subset of transitions is allowed in a RailEnv). The movements
        are sorted relative to the current orientation of the agent, rather than NESW as for the
        transitions. The order is::

            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']

        Each branch is organized as::

            [root node information] +
            [recursive branch data from 'left'] +
            [... from 'forward'] +
            [... from 'right'] +
            [... from 'back']

        Each node carries the following 12 features plus a ``communication`` field:

        #1:
            if own target lies on the explored branch, the current distance from the agent in number of
            cells is stored.

        #2:
            if another agent's target is detected, the distance in number of cells from the agent's
            current location is stored.

        #3:
            if another agent is detected, the distance in number of cells from the current agent
            position is stored.

        #4:
            possible conflict detected
            tot_dist = another agent predicts to pass along this cell at the same time as the agent; we
                       store the distance in number of cells from the current agent position
            0 = no other agent reserves the same cell at a similar time

        #5:
            if a switch that is not usable by the agent is detected, we store the distance.

        #6:
            the distance in number of cells to the next branching (current node).

        #7:
            minimum distance from the node to the agent's target, given the direction of the agent, if
            this path is chosen.

        #8:
            agents in the same direction
            n = number of agents present in the same direction
                (possible future use: number of other agents in the same direction in this branch)
            0 = no agent present in the same direction

        #9:
            agents in the opposite direction
            n = number of agents present in another direction than myself (so conflict)
                (possible future use: number of other agents in other direction in this branch,
                ie. number of conflicts)
            0 = no agent present in another direction than myself

        #10:
            malfunctioning/blocking agents
            n = number of time steps the observed agent remains blocked

        #11:
            slowest observed speed of an agent in the same direction
            1 if no agent is observed, min_fractional speed otherwise

        #12:
            number of agents ready to depart but not yet active

        Missing/padding nodes are filled in with -inf (truncated).
        Missing values in a present node are filled in with +inf (truncated).

        For the root node:
            - dist_min_to_target is the agent's distance to its own target from its current (virtual)
              position and direction.
            - num_agents_malfunctioning is the agent's own malfunction counter.
            - speed_min_fractional is its own speed.
            - communication is np.zeros(5).
            - every other feature is 0.

        Side effects:
            - rebuilds the ``self.location_has_agent*`` lookup tables.
            - sets ``communication = None`` on on-grid agents lacking that attribute.
            - stores the cells visited during exploration in ``self.env.dev_obs_dict[handle]``.

        :param handle: index of the agent in ``self.env.agents``.
        :return: the root Node, or None if the agent is neither READY_TO_DEPART, ACTIVE nor DONE.
        """

        # Rebuild the per-cell lookup tables. Only agents on the grid (status ACTIVE or DONE with a
        # position) are recorded; agents waiting to depart are only counted per initial position.
        self.location_has_agent = {}
        self.location_has_agent_direction = {}
        self.location_has_agent_speed = {}
        self.location_has_agent_malfunction = {}
        self.location_has_agent_ready_to_depart = {}
        self.location_has_agent_communication = {}

        for _agent in self.env.agents:
            if _agent.status in (RailAgentStatus.ACTIVE, RailAgentStatus.DONE) and _agent.position:
                # Agents created without a communication attribute get an explicit None.
                if not hasattr(_agent, 'communication'):
                    _agent.communication = None

                position = tuple(_agent.position)
                self.location_has_agent[position] = 1
                self.location_has_agent_direction[position] = _agent.direction
                self.location_has_agent_speed[position] = _agent.speed_data['speed']
                self.location_has_agent_malfunction[position] = _agent.malfunction_data['malfunction']
                self.location_has_agent_communication[position] = _agent.communication

            if _agent.status == RailAgentStatus.READY_TO_DEPART and _agent.initial_position:
                start = tuple(_agent.initial_position)
                self.location_has_agent_ready_to_depart[start] = \
                    self.location_has_agent_ready_to_depart.get(start, 0) + 1

        # Valid handles are 0 .. len(agents) - 1; report before the indexing below fails.
        if handle >= len(self.env.agents):
            print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
        agent = self.env.agents[handle]  # TODO: handle being treated as index

        # Agents that are not on the grid yet (or anymore) are observed from their start/target cell.
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # The root node holds information about the agent itself.
        distance_map = self.env.distance_map.get()

        root_node_observation = CustomTreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
                                                       dist_other_agent_encountered=0, dist_potential_conflict=0,
                                                       dist_unusable_switch=0, dist_to_next_branch=0,
                                                       dist_min_to_target=distance_map[
                                                           (handle, *agent_virtual_position,
                                                            agent.direction)],
                                                       num_agents_same_direction=0, num_agents_opposite_direction=0,
                                                       num_agents_malfunctioning=agent.malfunction_data['malfunction'],
                                                       speed_min_fractional=agent.speed_data['speed'],
                                                       num_agents_ready_to_depart=0,
                                                       communication=np.zeros(5),
                                                       childs={})

        visited = OrderedSet()

        # Start from the current orientation and see which transitions are available; organize them
        # as [left, forward, right, back] relative to the current orientation. If only one transition
        # is possible, the tree is oriented with this transition as the forward branch.
        orientation = agent.direction

        if num_transitions == 1:
            orientation = np.argmax(possible_transitions)

        for i, branch_direction in enumerate([(orientation + offset) % 4 for offset in range(-1, 3)]):

            if possible_transitions[branch_direction]:
                new_cell = get_new_position(agent_virtual_position, branch_direction)

                branch_observation, branch_visited = \
                    self._explore_branch(handle, new_cell, branch_direction, 1, 1)
                root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation

                visited |= branch_visited
            else:
                # No transition possible in this direction: mark the child as missing.
                root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf
        self.env.dev_obs_dict[handle] = visited

        return root_node_observation
コード例 #18
0
    def get(self, handle: int = 0):
        """
        Build the local-conflict observation for agent ``handle``.

        The observation encodes, as one-hot [left, forward, right] vectors, the branch with the
        shortest and the second-shortest remaining distance to the agent's target. It also includes:
            - the normalized distance to the target.
            - the normalized extra distance of the second-best branch.
            - the agent's malfunction and speed data.
            - the environment size, number of agents and predictions.

        Distances are normalized by ``width + height`` of the environment.

        Side effect: stores the first 10 predicted cells of the agent in
        ``self.env.dev_obs_dict[handle]``, which the renderer reads to visualize the observation.

        :param handle: index of the agent in ``self.env.agents``.
        :return: a LocalConflictObsForRailEnv.Node, or None if the agent is neither READY_TO_DEPART,
            ACTIVE nor DONE.
        """
        agent = self.env.agents[handle]

        # Agents that are not on the grid yet (or anymore) are observed from their start/target cell.
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        # Read the transitions at the same cell the distances below are computed for.
        possible_transitions = self.env.rail.get_transitions(
            *agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        distance_map = self.env.distance_map.get()
        max_distance = self.env.width + self.env.height

        # Record the first 10 predicted cells of this agent so the renderer can visualize them.
        # Assumes each prediction row holds the cell coordinates in columns 1 and 2 and that at
        # least 10 rows exist -- TODO confirm against the predictor.
        agent_predictions = self.predictions[handle]
        visited = OrderedSet()
        for step in range(10):
            visited.add((agent_predictions[step][1], agent_predictions[step][2]))
        self.env.dev_obs_dict[handle] = visited

        # Remaining distance to target after taking the left / forward / right branch;
        # np.inf marks branches that are not available.
        min_distances = []
        for direction in [(agent.direction + offset) % 4 for offset in range(-1, 2)]:
            if possible_transitions[direction]:
                new_position = get_new_position(agent_virtual_position, direction)
                min_distances.append(distance_map[handle, new_position[0],
                                                  new_position[1], direction])
            else:
                min_distances.append(np.inf)

        # One-hot [left, forward, right] encodings of the shortest and second-shortest branch.
        if num_transitions == 1:
            # A single way to go: following the rail is encoded as 'forward'.
            observation1 = [0, 1, 0]
            observation2 = observation1
        elif num_transitions >= 2:
            # A full (stable) ordering is needed: argpartition would only guarantee that the first
            # two indices are the two smallest, not which of them is the smaller one.
            order = np.argsort(min_distances, kind='stable')
            observation1 = [0, 0, 0]
            observation1[order[0]] = 1

            observation2 = [0, 0, 0]
            observation2[order[1]] = 1
        else:
            # No transition available from this cell and direction: nothing to recommend.
            observation1 = [0, 0, 0]
            observation2 = [0, 0, 0]

        # Gaps between consecutive sorted branch distances. Gaps involving unavailable branches
        # (inf, or NaN from inf - inf) are zeroed.
        incremental_distances = np.diff(np.sort(min_distances))
        incremental_distances[incremental_distances == np.inf] = 0
        incremental_distances[np.isnan(incremental_distances)] = 0

        distance_target = distance_map[(handle, *agent_virtual_position,
                                        agent.direction)]

        root_node_observation = LocalConflictObsForRailEnv.Node(
            distance_target=distance_target / max_distance,
            observation_shortest=observation1,
            observation_next_shortest=observation2,
            extra_distance=incremental_distances[0] / max_distance,
            malfunction=agent.malfunction_data['malfunction'] / max_distance,
            malfunction_rate=agent.malfunction_data['malfunction_rate'],
            next_malfunction=agent.malfunction_data['next_malfunction'] / max_distance,
            nr_malfunctions=agent.malfunction_data['nr_malfunctions'],
            speed=agent.speed_data['speed'],
            position_fraction=agent.speed_data['position_fraction'],
            transition_action_on_cellexit=agent.speed_data['transition_action_on_cellexit'],
            num_transitions=num_transitions,
            moving=agent.moving,
            status=agent.status,
            action_required=action_required(agent),
            width=self.env.width,
            height=self.env.height,
            n_agents=self.get_number_of_agents(),
            predictions=self.predictions[handle],
            predicted_pos=self.predicted_pos)

        return root_node_observation