Exemple #1
0
    def get(self, handle: int = 0) -> Node:
        """Build the tree observation for the agent identified by ``handle``.

        The root node describes the agent itself. Each of its four child slots
        holds the observation of the branch entered through one outgoing
        direction, or stays ``None`` when that direction is not an allowed
        transition from the agent's cell.

        Args:
            handle: Index of the agent in ``self.env.agents``.

        Returns:
            The root ``TreeObsForRailEnv.Node``, or ``None`` when the agent's
            status is none of READY_TO_DEPART, ACTIVE or DONE.

        Raises:
            IndexError: If ``handle`` lies past the end of ``self.env.agents``
                (a diagnostic line is printed first).
        """
        # Valid handles are 0 .. len(agents) - 1, so ``handle == len(agents)`` is
        # already out of range; the check must therefore be ``>=`` rather than ``>``.
        if handle >= len(self.env.agents):
            print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
        agent = self.env.agents[handle]

        # Pick the cell the observation is rooted at, depending on the agent status.
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Root node
        # Here information about the agent itself is stored
        distance_map = self.env.distance_map.get()
        root_node_observation = TreeObsForRailEnv.Node(dist_other_agent_encountered=0,
                                                       dist_to_unusable_switch=0,
                                                       dist_to_next_branch=0,
                                                       dist_min_to_target=distance_map[(handle, *agent_virtual_position, agent.direction)],
                                                       num_agents_same_direction=0, num_agents_opposite_direction=0,
                                                       num_agents_malfunctioning=agent.malfunction_data['malfunction'],
                                                       max_index_oppposite_direction=0, max_handle_agent_not_opposite=0,
                                                       has_deadlocked_agent=0,
                                                       first_agent_handle=0,
                                                       first_agent_not_opposite=0,
                                                       childs=[None]*4)

        # Child nodes
        orientation = agent.direction

        # With exactly one allowed transition the children are laid out relative to
        # that transition's direction instead of the agent's heading.
        # TODO: confirm this re-orientation is intended.
        if num_transitions == 1:
            orientation = np.argmax(possible_transitions)

        # Child slot i corresponds to direction orientation - 1, orientation,
        # orientation + 1, orientation + 2 (modulo 4), in that order.
        for i, branch_direction in enumerate([(orientation + 4 + i) % 4 for i in range(-1, 3)]):
            if possible_transitions[branch_direction]:
                new_cell = get_new_position(agent_virtual_position, branch_direction)

                branch_observation = self._explore_branch(handle, new_cell, branch_direction, 1)
                root_node_observation.childs[i] = branch_observation

        return root_node_observation
Exemple #2
0
 def dfs(self, h, w, d):
     """Depth-first fill of ``dist_to_node`` / ``next_node`` starting at state (h, w, d).

     For every allowed outgoing direction of the state, record the distance to
     the next graph node reached through it and which node that is. A successor
     that is itself a node is one step away; otherwise the values are derived
     from the successor's own entries (recursing into it first if it has not
     been visited yet).
     """
     self.dfs_used[h, w, d] = 1
     transitions = self.env.rail.get_transitions(h, w, d)
     for out_dir in range(4):
         if not transitions[out_dir]:
             continue

         nh, nw = get_new_position((h, w), out_dir)
         successor = (nh, nw, out_dir)

         if successor in self.nodes:
             # The neighbouring state is a node: the edge has length one.
             self.dist_to_node[h, w, d, out_dir] = 1
             self.next_node[h][w][d][out_dir] = successor
             continue

         if not self.dfs_used[successor]:
             self.dfs(nh, nw, out_dir)

         # Inherit from the successor; entries still at -1 are unknown and
         # skipped. When several are known, the last one wins.
         for via_dir in range(4):
             hops = self.dist_to_node[nh, nw, out_dir, via_dir]
             if hops > -0.5:
                 self.dist_to_node[h, w, d, out_dir] = hops + 1
                 self.next_node[h][w][d][out_dir] = (
                     self.next_node[nh][nw][out_dir][via_dir])
Exemple #3
0
def _build_shortest_path(env, handle):
    agent = env.agents[handle]
    pos = agent.initial_position
    dir = agent.initial_direction

    dist_min_to_target = env.obs_builder.rail_graph.dist_to_target(handle, pos[0], pos[1], dir)

    path = set()
    while dist_min_to_target:
        path.add((*pos, dir))
        possible_transitions = env.rail.get_transitions(*pos, dir)
        for new_dir in range(4):
            if possible_transitions[new_dir]:
                new_pos = get_new_position(pos, new_dir)
                new_min_dist = env.obs_builder.rail_graph.dist_to_target(handle, new_pos[0], new_pos[1], new_dir)
                if new_min_dist + 1 == dist_min_to_target:
                    dist_min_to_target = new_min_dist
                    pos, dir = new_pos, new_dir
                    break

    return path
Exemple #4
0
    def reinit_greedy(self):
        """Recompute the per-environment lookup tables used by the greedy logic.

        Resets ``greedy_way`` and derives from the current rail layout and agents:

        * ``switches``: cells whose full transition bitmask has more than two
          bits set.
        * ``target_neighbors`` / ``switches_neighbors``: ``(row, col, orientation)``
          states from which one allowed move leads onto a target cell / a switch.
        * ``decision_cells``: boolean array of shape ``(height, width, 4)`` that is
          true for those neighbouring states and for every orientation of a
          switch or target cell.
        * ``location_has_target`` (also stored as ``location_has_target_array``):
          boolean array of shape ``(height, width)`` marking agent target cells.
        """
        env = self.env
        rail = env.rail
        height, width = env.height, env.width

        self.greedy_way = defaultdict(int)
        # Kept as a set while the neighbour tables are built; replaced by a
        # boolean grid at the end of this method.
        self.location_has_target = {agent.target for agent in env.agents}

        self.switches = set()
        for h in range(height):
            for w in range(width):
                total_transitions = bin(rail.get_full_transitions(h, w)).count("1")
                if total_transitions > 2:
                    self.switches.add((h, w))

        self.target_neighbors = set()
        self.switches_neighbors = set()

        for h in range(height):
            for w in range(width):
                pos = (h, w)
                for orientation in range(4):
                    possible_transitions = rail.get_transitions(h, w, orientation)
                    for ndir in range(4):
                        if not possible_transitions[ndir]:
                            continue
                        nxt = get_new_position(pos, ndir)
                        if nxt in self.location_has_target:
                            self.target_neighbors.add((h, w, orientation))
                        if nxt in self.switches:
                            self.switches_neighbors.add((h, w, orientation))

        # ``np.bool`` was only an alias of the builtin and has been removed from
        # NumPy; the builtin ``bool`` yields the same dtype on every version.
        self.decision_cells = np.zeros((height, width, 4), dtype=bool)
        for posdir in self.switches_neighbors.union(self.target_neighbors):
            self.decision_cells[posdir] = 1
        for pos in self.switches.union(self.location_has_target):
            self.decision_cells[pos[0], pos[1], :] = 1

        self.location_has_target_array = np.zeros((height, width), dtype=bool)
        for pos in self.location_has_target:
            self.location_has_target_array[pos] = 1
        self.location_has_target = self.location_has_target_array
Exemple #5
0
    def _explore_branch(self, handle, position, direction, depth, was_target=False):
        start_position = (position[0], position[1], direction)
        if start_position in self.cached_nodes:
            node, position, direction = self.cached_nodes[start_position]
        else:
            node, position, direction = self._explore_line(position, direction)
            self.cached_nodes[start_position] = node, position, direction

        last_is_target = (position == self.env.agents[handle].target)

        if was_target:
            dist_min_to_target = -1
        elif last_is_target:
            dist_min_to_target = 0
        else:
            dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]

        node = node._replace(dist_min_to_target=dist_min_to_target, childs=[None]*4) # copy returned here!

        if depth == self.max_depth:
            return node
        if node.has_deadlocked_agent: # thee end
            return node

        # #############################
        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # Get the possible transitions

        possible_transitions = self.env.rail.get_transitions(*position, direction)
        for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
            if possible_transitions[branch_direction]:
                new_cell = get_new_position(position, branch_direction)
                branch_observation = self._explore_branch(handle, new_cell, branch_direction,
                                                                    depth + 1, was_target=last_is_target or was_target)
                node.childs[i] = branch_observation

        return node
Exemple #6
0
    def _build(self):
        self.nodes = set(agent.target for agent in self.env.agents)
        height, width = self.env.height, self.env.width
        self.valid_pos = list()

        for h in range(height):
            for w in range(width):
                pos = (h, w)
                transition_bit = bin(self.env.rail.get_full_transitions(*pos))
                total_transitions = transition_bit.count("1")
                if total_transitions > 2:
                    self.nodes.add(pos)
                if total_transitions > 0:
                    self.valid_pos.append((h, w))

        n_nodes = set()
        for h, w in self.nodes:
            for d in range(4):
                cell_transitions = self.env.rail.get_transitions(h, w, d)
                if np.any(cell_transitions):
                    n_nodes.add((h, w, d))

        self.nodes = n_nodes

        self.dist_to_node = -np.ones((height, width, 4, 4))
        self.next_node = [[[[None for _ in range(4)] for _ in range(4)]
                           for _ in range(width)] for _ in range(height)]
        self.dfs_used = np.zeros((height, width, 4))
        for h in range(height):
            for w in range(width):
                for d in range(4):
                    if not self.dfs_used[h, w, d]:
                        self.dfs(h, w, d)

        self.n_nodes = len(self.nodes)
        self.nodes_dict = np.empty((height, width, 4), dtype=np.int)
        for i, (h, w, d) in enumerate(self.nodes):
            self.nodes_dict[h, w, d] = i

        self.cell_to_edge = [[list() for _ in range(width)]
                             for _ in range(height)]

        self.amatrix = np.ones((self.n_nodes, self.n_nodes)) * np.inf
        self.amatrix[np.arange(self.n_nodes), np.arange(self.n_nodes)] = 0
        for i, (h, w, d) in enumerate(self.nodes):
            for dd in range(4):
                nnode = self.next_node[h][w][d][dd]
                if nnode is not None:
                    self.amatrix[i][
                        self.nodes_dict[nnode]] = self.dist_to_node[h, w, d,
                                                                    dd]

                    cell = (h, w, d)
                    nnode_i = self.nodes_dict[nnode]
                    while cell != nnode:
                        possible_transitions = self.env.rail.get_transitions(
                            *cell)
                        for ndir in range(4):
                            if possible_transitions[ndir] and (cell !=
                                                               (h, w, d)
                                                               or ndir == dd):
                                nh, nw = get_new_position((cell[0], cell[1]),
                                                          ndir)
                                cell = (nh, nw, ndir)
                                self.cell_to_edge[nh][nw].append((i, nnode_i))
                                break
Exemple #7
0
    def _explore_line(self, position, direction):
        """Walk along the track from ``position``/``direction`` and summarise it.

        The walk advances while ``self.graph.get_if_one_transition`` reports a
        single way forward, and stops on a target cell, a switch, a dead-end or a
        cell with no usable transition. Along the way it gathers statistics about
        the agents standing on the visited cells.

        Returns:
            A tuple ``(node, position, direction)``: the summary node (with
            ``dist_min_to_target`` left as ``None``) and the state at which the
            walk stopped.
        """
        # We treat dead-ends as nodes, instead of going back, to avoid loops.
        # TODO: reconsider whether dead-ends should be nodes at all.
        agents = self.env.agents

        steps = 1
        nearest_agent_dist = np.inf
        unusable_switch_dist = np.inf
        same_direction_count = 0
        opposite_direction_count = 0
        malfunctioning_count = 0
        highest_opposite_handle = 0
        deadlock_seen = 0
        first_handle = -1
        first_not_opposite = True
        highest_not_opposite = True
        highest_handle = -1

        while True:
            occupant = self.location_with_agent[position]
            if occupant != -1:
                if steps < nearest_agent_dist:
                    nearest_agent_dist = steps

                other = agents[occupant]

                # Count every agent on the line that is currently malfunctioning.
                if other.malfunction_data['malfunction'] > 0:
                    malfunctioning_count += 1

                # hack: relies on the observation builder owning a deadlock_checker
                if self.deadlock_checker.is_deadlocked(occupant):
                    deadlock_seen = 1
                elif first_handle == -1:
                    # First non-deadlocked agent met on this line.
                    first_handle = occupant
                    first_not_opposite = (other.direction == direction)

                heading_same_way = (other.direction == direction)
                if heading_same_way:
                    same_direction_count += 1
                else:
                    opposite_direction_count += 1
                    highest_opposite_handle = max(highest_opposite_handle, occupant)

                # Remember whether the agent with the largest handle travels our way.
                if occupant > highest_handle:
                    highest_handle = occupant
                    highest_not_opposite = bool(heading_same_way)

            # A cell that is some agent's target always terminates the line.
            if self.greedy_checker.location_has_target[position]:
                break

            forced_direction = self.graph.get_if_one_transition(position, direction)
            if forced_direction != -1:
                position = get_new_position(position, forced_direction)
                direction = forced_direction
                steps += 1
                continue

            # Usable transitions for this heading versus all transitions of the cell.
            cell_transitions = self.env.rail.get_transitions(*position, direction)
            usable = np.count_nonzero(cell_transitions)
            total = self.graph.get_total_transitions(position)

            # A switch that only agents coming from elsewhere can use.
            if usable < 2 < total and steps < unusable_switch_dist:
                unusable_switch_dist = steps

            if usable == 0:
                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                      position[1], direction)
                break

            if usable != 1 or total == 1:
                # Either a switch (several usable transitions) or a dead-end.
                break

            # A single usable transition on a non-dead-end cell should have been
            # consumed by get_if_one_transition above.
            assert False

        node = TreeObsForRailEnv.Node(
            dist_other_agent_encountered=nearest_agent_dist,
            dist_to_unusable_switch=unusable_switch_dist,
            dist_to_next_branch=steps,
            dist_min_to_target=None,
            num_agents_same_direction=same_direction_count,
            num_agents_opposite_direction=opposite_direction_count,
            num_agents_malfunctioning=malfunctioning_count,
            max_index_oppposite_direction=highest_opposite_handle,
            max_handle_agent_not_opposite=highest_not_opposite,
            has_deadlocked_agent=deadlock_seen,
            first_agent_handle=first_handle,
            first_agent_not_opposite=first_not_opposite,
            childs=[None]*4,
        )

        return node, position, direction