Ejemplo n.º 1
0
    def load_pngs(self,
                  file_directory,
                  rotate=False,
                  agent_colors=False,
                  background_image=None,
                  whitefilter=None):
        """Load rail PNG images and index them by transition bitmap.

        Parameters
        ----------
        file_directory : dict
            Maps an ASCII transition description (space-separated in/out
            direction pairs over ``NESW``, e.g. ``"NE WS"``) to a PNG file
            name handed to ``self.pil_from_png_file``.
        rotate : bool
            When true, each image is stored under its own 16-bit transition
            and, rotated, under the transition rotated by 90, 180 and 270
            degrees.
        agent_colors : bool
            When true, the image is recoloured via ``self.recolor_image`` for
            ``self.agent_colors`` and each result is stored under the key
            ``(transition, color_idx)``.
        background_image : str, optional
            PNG file name composited underneath every rail image.
        whitefilter : str, optional
            PNG file name composited on top of every rail image.

        Returns
        -------
        dict
            The images keyed as described above.  Nothing is stored for an
            entry when both ``rotate`` and ``agent_colors`` are false.
        """
        pil = {}

        transitions = RailEnvTransitions()

        directions = list("NESW")

        # The two overlays do not depend on the transition, so read them once
        # rather than once per entry.  They are skipped for an empty
        # file_directory so that no file is touched in that case.
        background = None
        overlay = None
        if file_directory:
            if background_image is not None:
                background = self.pil_from_png_file(
                    'flatland.png', background_image).convert("RGBA")
            if whitefilter is not None:
                overlay = self.pil_from_png_file(
                    'flatland.png', whitefilter).convert("RGBA")

        for transition, file_name in file_directory.items():

            # Translate the ascii transition description in the format "NE WS"
            # to the 16-bit transition bitmap: NESW (in) x NESW (out).
            transition_16_bit = ["0"] * 16
            for pair in transition.split(" "):
                if len(pair) == 2:
                    in_direction = directions.index(pair[0])
                    out_direction = directions.index(pair[1])
                    transition_16_bit[4 * in_direction + out_direction] = "1"
            binary_trans = int("".join(transition_16_bit), 2)

            pil_rail = self.pil_from_png_file('flatland.png',
                                              file_name).convert("RGBA")

            if background is not None:
                pil_rail = Image.alpha_composite(background, pil_rail)

            if overlay is not None:
                pil_rail = Image.alpha_composite(pil_rail, overlay)

            if rotate:
                # For rotations, we also store the base image
                pil[binary_trans] = pil_rail
                # Rotate both the transition bitmap and the image
                for degrees in (90, 180, 270):
                    rotated_trans = transitions.rotate_transition(
                        binary_trans, degrees)
                    # PIL rotates anticlockwise for positive theta
                    pil[rotated_trans] = pil_rail.rotate(-degrees)

            if agent_colors:
                # For recoloring, we don't store the base image.
                base_color = self.rgb_s2i("d50000")
                recolored = self.recolor_image(pil_rail, base_color,
                                               self.agent_colors)
                for color_idx, colored_rail in enumerate(recolored):
                    pil[(binary_trans, color_idx)] = colored_rail

        return pil
    def generator(width: int,
                  height: int,
                  num_agents: int,
                  num_resets: int = 0,
                  np_random: RandomState = None) -> RailGenerator:
        """
        Build a GridTransitionMap from ``rail_spec``.

        ``rail_spec`` is not a parameter; it is read from the enclosing scope
        and indexed as ``rail_spec[row][col]``, each cell providing
        ``(basic transition index, rotation)``.  The ``width`` and ``height``
        arguments are overwritten with the dimensions of ``rail_spec``;
        ``num_agents``, ``num_resets`` and ``np_random`` are not used.

        Returns
        -------
        ``[rail, None]`` on success, or ``[]`` (after printing an error) when
        a cell names a basic transition index outside
        ``RailEnvTransitions().transitions``.
        """
        env_transitions = RailEnvTransitions()

        # The grid size is dictated by rail_spec, not by the caller.
        height = len(rail_spec)
        width = len(rail_spec[0])
        rail = GridTransitionMap(width=width,
                                 height=height,
                                 transitions=env_transitions)

        for row in range(height):
            spec_row = rail_spec[row]
            for col in range(width):
                cell_spec = spec_row[col]
                basic_index = cell_spec[0]
                rotation = cell_spec[1]

                if not (0 <= basic_index < len(env_transitions.transitions)):
                    print("ERROR - invalid rail_spec_of_cell type=",
                          basic_index)
                    return []

                basic_transition = env_transitions.transitions[basic_index]
                rail.set_transitions(
                    (row, col),
                    env_transitions.rotate_transition(basic_transition,
                                                      rotation))

        return [rail, None]
Ejemplo n.º 3
0
def test_get_entry_directions():
    """Check ``Grid4Transitions.get_entry_directions`` on turns, a straight and switches.

    Each case pairs a transition bitmap with the expected list of four
    booleans returned by ``get_entry_directions``.
    """
    transitions = RailEnvTransitions()
    basic_cells = transitions.transition_list

    vertical_line = basic_cells[1]
    south_symmetrical_switch = basic_cells[6]
    north_symmetrical_switch = transitions.rotate_transition(
        south_symmetrical_switch, 180)

    # The three other turns are rotations of the south-east turn.
    south_east_turn = int('0100000000000010', 2)
    south_west_turn = transitions.rotate_transition(south_east_turn, 90)
    north_east_turn = transitions.rotate_transition(south_east_turn, 270)
    north_west_turn = transitions.rotate_transition(south_east_turn, 180)

    cases = [
        (south_east_turn, [True, False, False, True]),
        (south_west_turn, [True, True, False, False]),
        (north_east_turn, [False, False, True, True]),
        (north_west_turn, [False, True, True, False]),
        (vertical_line, [True, False, True, False]),
        (south_symmetrical_switch, [True, True, False, True]),
        (north_symmetrical_switch, [False, True, True, True]),
    ]

    for transition, expected in cases:
        actual = Grid4Transitions.get_entry_directions(transition)
        assert actual == expected, "Found {}, expected {}.".format(
            actual, expected)
Ejemplo n.º 4
0
def compute_all_possible_transitions():
    """Map every rotated basic transition to ``np.array([cell index, degrees])``.

    Each entry of ``RailEnvTransitions().transition_list`` is rotated by
    0, 90, 180 and 270 degrees.  The first (index, rotation) pair that yields
    a given bitmap wins; later duplicates are ignored.

    Returns
    -------
    dict
        Rotated transition bitmap -> ``np.array([index, rotation_in_degrees])``.
    """
    # Bitmaps are read in decimal numbers
    transitions = RailEnvTransitions()

    # For reference, the list was once spelled out inline as:
    #   int('0000000000000000', 2)  # empty cell - Case 0
    #   int('1000000000100000', 2)  # Case 1 - straight
    #   int('1001001000100000', 2)  # Case 2 - simple switch
    #   int('1000010000100001', 2)  # Case 3 - diamond crossing
    #   int('1001011000100001', 2)  # Case 4 - single slip
    #   int('1100110000110011', 2)  # Case 5 - double slip
    #   int('0101001000000010', 2)  # Case 6 - symmetrical
    #   int('0010000000000000', 2)  # Case 7 - dead end
    #   int('0100000000000010', 2)  # Case 1b (8)  - simple turn right
    #   int('0001001000000000', 2)  # Case 1c (9)  - simple turn left
    #   int('1100000000100010', 2)  # Case 2b (10) - simple switch mirrored

    rotations_by_bitmap = {}
    for cell_index, bitmap in enumerate(transitions.transition_list):
        for degrees in (0, 90, 180, 270):
            rotated = transitions.rotate_transition(bitmap, degrees)
            if rotated in rotations_by_bitmap:
                continue
            rotations_by_bitmap[rotated] = np.array([cell_index, degrees])

    return rotations_by_bitmap
Ejemplo n.º 5
0
def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array]:
    """Build a small fixed rail network that offers alternative routes.

    Returns
    -------
    (rail, rail_map)
        ``rail`` is a GridTransitionMap whose ``grid`` is ``rail_map``, a
        7 x 10 ``np.uint16`` array of transition bitmaps.
    """
    # Layout (7 rows x 10 columns); "." is an empty cell:
    #   row 0:    . . . T - - - - - T
    #   rows 1-2: . . . | . . . . . |
    #   row 3:    D - - S - - T . . |
    #   row 4:    . . . . . . S - - T
    #   row 5:    . . . . . . | . . .
    #   row 6:    . . . . . . D . . .
    # T = turn, S = simple switch, D = dead end, - and | = straight track.
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    empty = cells[0]
    dead_end_from_south = cells[7]
    right_turn_from_south = cells[8]
    right_turn_from_west = transitions.rotate_transition(
        right_turn_from_south, 90)
    right_turn_from_north = transitions.rotate_transition(
        right_turn_from_south, 180)
    dead_end_from_north = transitions.rotate_transition(
        dead_end_from_south, 180)
    dead_end_from_east = transitions.rotate_transition(
        dead_end_from_south, 270)
    vertical_straight = cells[1]
    simple_switch_north_left = cells[2]
    simple_switch_north_right = cells[10]
    simple_switch_left_east = transitions.rotate_transition(
        simple_switch_north_left, 90)
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)

    top_row = ([empty] * 3 + [right_turn_from_south] +
               [horizontal_straight] * 5 + [right_turn_from_west])
    side_rows = ([empty] * 3 + [vertical_straight] + [empty] * 5 +
                 [vertical_straight])
    main_row = ([dead_end_from_east] + [horizontal_straight] * 2 +
                [simple_switch_left_east] + [horizontal_straight] * 2 +
                [right_turn_from_west] + [empty] * 2 + [vertical_straight])
    rejoin_row = ([empty] * 6 + [simple_switch_north_right] +
                  [horizontal_straight] * 2 + [right_turn_from_north])
    stem_row = [empty] * 6 + [vertical_straight] + [empty] * 3
    bottom_row = [empty] * 6 + [dead_end_from_north] + [empty] * 3

    rail_map = np.array(
        [top_row, side_rows, side_rows, main_row, rejoin_row, stem_row,
         bottom_row],
        dtype=np.uint16)

    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0],
                             transitions=transitions)
    rail.grid = rail_map
    return rail, rail_map
Ejemplo n.º 6
0
def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
    """Build a fixed 7 x 10 rail network: one horizontal line with two vertical branches.

    Returns
    -------
    (rail, rail_map)
        ``rail`` is a GridTransitionMap whose ``grid`` is ``rail_map``, an
        ``np.uint16`` array of transition bitmaps.
    """
    # Sketch of the network:
    #        |
    #        |
    #        |
    # _ _ _ _\ _ _  _  _ _ _
    #               \
    #                |
    #                |
    #                |
    transitions = RailEnvTransitions()
    basic_cells = transitions.transition_list

    empty = basic_cells[0]
    dead_end_from_south = basic_cells[7]
    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
    vertical_straight = basic_cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
    simple_switch_north_right = basic_cells[10]
    simple_switch_east_west_north = transitions.rotate_transition(
        simple_switch_north_right, 270)
    simple_switch_west_east_south = transitions.rotate_transition(
        simple_switch_north_right, 90)

    # Rows are named top to bottom; repeated rows are listed twice below.
    upper_end = [empty] * 3 + [dead_end_from_south] + [empty] * 6
    upper_branch = [empty] * 3 + [vertical_straight] + [empty] * 6
    main_line = ([dead_end_from_east] + [horizontal_straight] * 2 +
                 [simple_switch_east_west_north] +
                 [horizontal_straight] * 2 +
                 [simple_switch_west_east_south] +
                 [horizontal_straight] * 2 + [dead_end_from_west])
    lower_branch = [empty] * 6 + [vertical_straight] + [empty] * 3
    lower_end = [empty] * 6 + [dead_end_from_north] + [empty] * 3

    rail_map = np.array(
        [upper_end, upper_branch, upper_branch, main_line, lower_branch,
         lower_branch, lower_end],
        dtype=np.uint16)

    n_rows, n_cols = rail_map.shape[0], rail_map.shape[1]
    rail = GridTransitionMap(width=n_cols, height=n_rows,
                             transitions=transitions)
    rail.grid = rail_map
    return rail, rail_map
def test_rail_env_has_deadend():
    """Check that ``has_deadend`` is true exactly for the four dead-end bitmaps.

    Every entry of ``RailEnvTransitions().transitions_all`` is tested against
    the explicit set of dead-end transitions below.
    """
    # rw is defined elsewhere; presumably it strips the spaces used here for
    # readability before the string is parsed as binary.
    deadends = {
        int(rw('0010 0000 0000 0000'), 2),
        int(rw('0000 0001 0000 0000'), 2),
        int(rw('0000 0000 1000 0000'), 2),
        int(rw('0000 0000 0000 0100'), 2),
    }
    ret = RailEnvTransitions()
    for t in ret.transitions_all:
        expected_has_deadend = t in deadends
        actual_has_deadend = ret.has_deadend(t)
        # All three placeholders must be supplied, otherwise a failing
        # assertion raises IndexError instead of reporting the mismatch.
        assert actual_has_deadend == expected_has_deadend, \
            "{} should be deadend = {}, actual = {}".format(
                t, expected_has_deadend, actual_has_deadend)
 def generator(width: int,
               height: int,
               num_agents: int = 0,
               num_resets: int = 0) -> RailGeneratorProduct:
     """
     Build a zero-filled grid with one transition written to cells (0, 0) and (0, 1).

     The transition is ``RailEnvTransitions().set_transition(1, 1, 1, 1)``;
     its value is printed.  ``num_agents`` and ``num_resets`` are not used.

     Returns
     -------
     ``(grid_map, None)``
     """
     transitions = RailEnvTransitions()
     grid_map = GridTransitionMap(width=width,
                                  height=height,
                                  transitions=transitions)
     cells = grid_map.grid
     cells.fill(0)

     new_tran = transitions.set_transition(1, 1, 1, 1)
     print(new_tran)
     for position in ((0, 0), (0, 1)):
         cells[position] = new_tran

     return grid_map, None
Ejemplo n.º 9
0
    def compute_all_possible_transitions(self):
        """
        Map every rotated basic transition to ``np.array([index, rot_type])``.

        ``index`` is the position of the cell type in
        ``RailEnvTransitions().transition_list`` and ``rot_type`` is the
        position of the rotation in ``[0, 90, 180, 270]`` (not the angle
        itself).  The first pair producing a given bitmap is kept.
        """
        # Bitmaps are read in decimal numbers
        rail_transitions = RailEnvTransitions()
        lookup = {}

        for cell_index, bitmap in enumerate(rail_transitions.transition_list):
            for rotation_index, degrees in enumerate((0, 90, 180, 270)):
                rotated = rail_transitions.rotate_transition(bitmap, degrees)
                if rotated in lookup:
                    continue
                lookup[rotated] = (np.array([cell_index, rotation_index]))

        return lookup
Ejemplo n.º 10
0
def test_walker():
    """Check distance-map values on a three-cell horizontal track.

    The track is dead end - straight - dead end.  A single agent is placed at
    (0, 1) facing direction 1 with target (0, 0), the environment is reset,
    and two entries of ``env.distance_map.get()`` are asserted.
    """
    # _ _ _

    transitions = RailEnvTransitions()
    basic_cells = transitions.transition_list
    dead_end_from_south = basic_cells[7]
    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south,
                                                       270)
    vertical_straight = basic_cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)

    rail_map = np.array(
        [[dead_end_from_east, horizontal_straight, dead_end_from_west]],
        dtype=np.uint16)
    n_rows, n_cols = rail_map.shape[0], rail_map.shape[1]

    rail = GridTransitionMap(width=n_cols,
                             height=n_rows,
                             transitions=transitions)
    rail.grid = rail_map

    rail_generator = rail_from_grid_transition_map(rail)
    schedule_generator = random_schedule_generator()
    obs_builder = TreeObsForRailEnv(
        max_depth=2,
        predictor=ShortestPathPredictorForRailEnv(max_depth=10))
    env = RailEnv(
        width=n_cols,
        height=n_rows,
        rail_generator=rail_generator,
        schedule_generator=schedule_generator,
        number_of_agents=1,
        obs_builder_object=obs_builder,
    )
    env.reset()

    # set initial position and direction for testing...
    env.agents[0].position = (0, 1)
    env.agents[0].direction = 1
    env.agents[0].target = (0, 0)

    # reset to set agents from agents_static
    env.reset(False, False)

    # Index layout used below: (agent, row, column, direction).
    print(env.distance_map.get()[(0, 0, 1, 1)])
    assert env.distance_map.get()[(0, 0, 1, 1)] == 3
    # NOTE(review): the value printed is for direction 3 while the assertion
    # checks direction 1 -- confirm which one is intended.
    print(env.distance_map.get()[(0, 0, 2, 3)])
    assert env.distance_map.get()[(0, 0, 2, 1)] == 2
    def generator(width: int,
                  height: int,
                  num_agents: int,
                  num_resets: int = 0,
                  np_random: RandomState = None) -> RailGenerator:
        """
        Build an empty rail grid of the requested size.

        The grid of a fresh GridTransitionMap is filled with zeros;
        ``num_agents``, ``num_resets`` and ``np_random`` are not used.

        Returns
        -------
        ``(grid_map, None)``
        """
        empty_map = GridTransitionMap(width=width,
                                      height=height,
                                      transitions=RailEnvTransitions())
        empty_map.grid.fill(0)
        return empty_map, None
Ejemplo n.º 12
0
def test_dead_end():
    """Set up a five-cell dead-end track horizontally and vertically.

    For each orientation an environment is built, reset, and given a single
    agent, once per travel direction.  The test contains no assertions; it
    only exercises construction and reset.
    """
    transitions = RailEnvTransitions()

    straight_vertical = int('1000000000100000', 2)  # Case 1 - straight
    straight_horizontal = transitions.rotate_transition(straight_vertical,
                                                        90)

    dead_end_from_south = int('0010000000000000', 2)  # Case 7 - dead end

    def build_env(rail_map):
        # Wrap the given bitmap array in a GridTransitionMap and a RailEnv.
        rail = GridTransitionMap(width=rail_map.shape[1],
                                 height=rail_map.shape[0],
                                 transitions=transitions)
        rail.grid = rail_map
        return RailEnv(width=rail_map.shape[1],
                       height=rail_map.shape[0],
                       rail_generator=rail_from_grid_transition_map(rail),
                       schedule_generator=random_schedule_generator(),
                       number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())

    def restart_with_agent(env, position, direction, target):
        # Reset the environment, then replace its agents with a single one.
        env.reset()
        env.agents = [EnvAgent(initial_position=position,
                               initial_direction=direction,
                               direction=direction,
                               target=target,
                               moving=False)]

    # We instantiate the following railway
    # O->-- where > is the train and O the target. After 6 steps,
    # the train should be done.
    horizontal_map = np.array(
        [[transitions.rotate_transition(dead_end_from_south, 270)] +
         [straight_horizontal] * 3 +
         [transitions.rotate_transition(dead_end_from_south, 90)]],
        dtype=np.uint16)
    rail_env = build_env(horizontal_map)

    # We try the configuration in the 4 directions:
    restart_with_agent(rail_env, (0, 2), 1, (0, 0))
    restart_with_agent(rail_env, (0, 2), 3, (0, 4))

    # In the vertical configuration:
    vertical_map = np.array(
        [[dead_end_from_south]] + [[straight_vertical]] * 3 +
        [[transitions.rotate_transition(dead_end_from_south, 180)]],
        dtype=np.uint16)
    rail_env = build_env(vertical_map)

    restart_with_agent(rail_env, (2, 0), 2, (0, 0))
    restart_with_agent(rail_env, (2, 0), 0, (4, 0))
def test_is_valid_railenv_transitions():
    """Check ``RailEnvTransitions.is_valid`` on basic, rotated and bogus bitmaps.

    Every entry of ``RailEnvTransitions().transitions`` and each of its four
    quarter-turn rotations must be valid; three hand-written bitmaps must be
    rejected.
    """
    rail_env_trans = RailEnvTransitions()
    transition_list = rail_env_trans.transitions

    for t in transition_list:
        assert (rail_env_trans.is_valid(t) is True)
        # Cover all four quarter turns, including 270 degrees.
        for degrees in (0, 90, 180, 270):
            rot_trans = rail_env_trans.rotate_transition(t, degrees)
            assert (rail_env_trans.is_valid(rot_trans) is True)

    assert (rail_env_trans.is_valid(int('1111111111110010', 2)) is False)
    assert (rail_env_trans.is_valid(int('1001111111110010', 2)) is False)
    assert (rail_env_trans.is_valid(int('1001111001110110', 2)) is False)
    def generator(width: int,
                  height: int,
                  num_agents: int,
                  num_resets: int = 0,
                  np_random: RandomState = None) -> List:
        """
        Rebuild a rail grid from a persisted environment dictionary.

        ``filename`` and ``load_from_package`` are read from the enclosing
        scope.  None of the parameters are used; the grid size comes from the
        stored ``"grid"`` entry.

        Returns
        -------
        ``(rail, {'distance_map': ...})`` when the dictionary holds a
        non-empty ``"distance_map"``, otherwise ``[rail, None]``.
        """
        env_dict = persistence.RailEnvPersister.load_env_dict(
            filename, load_from_package=load_from_package)
        rail_env_transitions = RailEnvTransitions()

        grid = np.array(env_dict["grid"])
        grid_shape = np.shape(grid)
        rail = GridTransitionMap(width=grid_shape[1],
                                 height=grid_shape[0],
                                 transitions=rail_env_transitions)
        rail.grid = grid

        has_distance_map = ("distance_map" in env_dict
                            and len(env_dict["distance_map"]) > 0)
        if not has_distance_map:
            return [rail, None]
        return rail, {'distance_map': env_dict["distance_map"]}
def test_adding_new_valid_transition():
    """Check ``GridTransitionMap.validate_new_transition`` around cell (5, 5).

    The cell starts out empty on a 15 x 15 grid and is then overwritten with
    several basic transitions to check which new connections are accepted.
    """
    rail_trans = RailEnvTransitions()
    grid_map = GridTransitionMap(width=15, height=15, transitions=rail_trans)

    centre = (5, 5)

    def check(prev_pos, next_pos, expected, end_pos=(10, 10)):
        # Validate prev -> centre -> next and compare against the expectation.
        outcome = grid_map.validate_new_transition(prev_pos, centre, next_pos,
                                                   end_pos)
        assert (outcome is expected)

    # adding straight
    check((4, 5), (6, 5), True)

    # adding valid right turn
    check((5, 4), (5, 6), True)
    # adding valid left turn
    check((5, 6), (5, 6), True)

    # adding invalid turn
    grid_map.grid[centre] = rail_trans.transitions[2]
    check((4, 5), (5, 6), False)

    # should create #4 -> valid
    grid_map.grid[centre] = rail_trans.transitions[3]
    check((4, 5), (5, 6), True)

    # adding invalid turn
    grid_map.grid[centre] = rail_trans.transitions[7]
    check((4, 5), (5, 6), False)

    # test path start condition
    grid_map.grid[centre] = rail_trans.transitions[0]
    check(None, (5, 6), True)

    # test path end condition
    grid_map.grid[centre] = rail_trans.transitions[0]
    check((5, 4), (6, 5), True, end_pos=(6, 5))
Ejemplo n.º 16
0
    def generator(width: int,
                  height: int,
                  num_agents: int,
                  num_resets: int = 0,
                  np_random: RandomState = np.random) -> RailGenerator:
        """
        Build a rail grid of connected cities.

        ``seed`` and ``num_cities`` are read from the enclosing scope, and the
        ``_generate_*`` / ``_connect_cities`` / ``_build_inner_cities`` /
        ``_set_trainstation_positions`` / ``_fix_transitions`` helpers are
        defined elsewhere.

        Parameters
        ----------
        width: int
            Width of the environment
        height: int
            Height of the environment
        num_agents:
            Number of agents to be placed within the environment; only passed
            on in the returned hints
        num_resets: int
            Count for how often the environment has been reset (not used here)
        np_random:
            Random generator; it is re-seeded with ``seed`` on every call.
            With the default ``np.random`` this seeds numpy's global state.

        Returns
        -------
        grid_map:
            The populated GridTransitionMap, together with a dictionary
            ``{'agents_hints': {...}}`` holding ``num_agents``,
            ``city_positions``, ``train_stations`` and ``city_orientations``.

        Calls ``sys.exit`` when fewer than ``num_cities`` cities fit the map.
        """
        rail_trans = RailEnvTransitions()
        grid_map = GridTransitionMap(width=width,
                                     height=height,
                                     transitions=rail_trans)
        vector_field = np.zeros(shape=(height, width)) - 1.
        city_padding = 2
        city_radius = city_padding
        rails_between_cities = 1
        rails_in_city = 2
        np_random.seed(seed)

        # Calculate the max number of cities allowed
        # and reduce the number of cities to build to avoid problems
        city_span = 2 * (city_radius + 1)
        cities_per_column = (height - 2) // city_span
        cities_per_row = (width - 2) // city_span
        max_feasible_cities = min(num_cities,
                                  cities_per_column * cities_per_row)

        if max_feasible_cities < num_cities:
            sys.exit(
                f"[ABORT] Cannot fit more than {max_feasible_cities} city in this map, no feasible environment possible! Aborting."
            )

        # obtain city positions
        city_positions = _generate_evenly_distr_city_positions(
            max_feasible_cities, city_radius, width, height)

        # Set up connection points for all cities
        connection_data = _generate_city_connection_points(
            city_positions, city_radius, vector_field, rails_between_cities,
            rails_in_city, np_random=np_random)
        (inner_connection_points, outer_connection_points, city_orientations,
         city_cells) = connection_data

        # connect the cities through the connection points
        inter_city_lines = _connect_cities(city_positions,
                                           outer_connection_points,
                                           city_cells, rail_trans, grid_map)

        # Build inner cities
        free_rails = _build_inner_cities(city_positions,
                                         inner_connection_points,
                                         outer_connection_points,
                                         rail_trans, grid_map)

        # Populate cities
        train_stations = _set_trainstation_positions(city_positions,
                                                     city_radius, free_rails)

        # Fix all transition elements
        _fix_transitions(city_cells, inter_city_lines, grid_map, vector_field)

        agents_hints = {
            'num_agents': num_agents,
            'city_positions': city_positions,
            'train_stations': train_stations,
            'city_orientations': city_orientations
        }
        return grid_map, {'agents_hints': agents_hints}
Ejemplo n.º 17
0
    def __init__(self):
        """Load the rail SVGs and index them in ``self.dSvg`` by transition bitmap.

        Each file is stored under the 16-bit bitmap of its ASCII transition
        description and, as rotated copies, under that bitmap rotated by 90,
        180 and 270 degrees.  Every non-empty transition is merged with the
        background SVG first.  One line is printed per file.
        """
        svg_files = {
            "": "Background_#9CCB89.svg",
            "WE": "Gleis_Deadend.svg",
            "WW EE NN SS": "Gleis_Diamond_Crossing.svg",
            "WW EE": "Gleis_horizontal.svg",
            "EN SW": "Gleis_Kurve_oben_links.svg",
            "WN SE": "Gleis_Kurve_oben_rechts.svg",
            "ES NW": "Gleis_Kurve_unten_links.svg",
            "NE WS": "Gleis_Kurve_unten_rechts.svg",
            "NN SS": "Gleis_vertikal.svg",
            "NN SS EE WW ES NW SE WN": "Weiche_Double_Slip.svg",
            "EE WW EN SW": "Weiche_horizontal_oben_links.svg",
            "EE WW SE WN": "Weiche_horizontal_oben_rechts.svg",
            "EE WW ES NW": "Weiche_horizontal_unten_links.svg",
            "EE WW NE WS": "Weiche_horizontal_unten_rechts.svg",
            "NN SS EE WW NW ES": "Weiche_Single_Slip.svg",
            "NE NW ES WS": "Weiche_Symetrical.svg",
            "NN SS EN SW": "Weiche_vertikal_oben_links.svg",
            "NN SS SE WN": "Weiche_vertikal_oben_rechts.svg",
            "NN SS NW ES": "Weiche_vertikal_unten_links.svg",
            "NN SS NE WS": "Weiche_vertikal_unten_rechts.svg"
        }

        self.dSvg = {}

        transitions = RailEnvTransitions()

        directions = list("NESW")

        background_svg = SVG("./svg/Background_#9CCB89.svg")

        for description, file_name in svg_files.items():
            svg = SVG("./svg/" + file_name)

            # Translate the ascii transition description in the format "NE WS"
            # to the 16-bit transition bitmap: NESW (in) x NESW (out).
            bits = ["0"] * 16
            for pair in description.split(" "):
                if len(pair) != 2:
                    continue
                dir_in = directions.index(pair[0])
                dir_out = directions.index(pair[1])
                bits[4 * dir_in + dir_out] = "1"
            bit_string = "".join(bits)
            bitmap = int(bit_string, 2)
            print(description, bit_string, file_name)

            # Merge the transition svg image with the background colour.
            # This is a shortcut / hack and will need re-working.
            if bitmap > 0:
                svg = svg.merge(background_svg)

            self.dSvg[bitmap] = svg

            # Rotate both the transition bitmap and the image
            for degrees in (90, 180, 270):
                rotated_bitmap = transitions.rotate_transition(bitmap, degrees)
                rotated_svg = svg.copy()
                rotated_svg.set_rotate(degrees)
                self.dSvg[rotated_bitmap] = rotated_svg
def test_rail_environment_single_agent():
    """Drive a single agent with random actions on a 3x3 double-loop network.

    For each of 200 resets the test checks that the agent starts on a cell
    with at least one allowed transition, that after six position changes it
    is back on its starting row, and that ten further episodes driven by
    random actions each run until ``dones['__all__']`` is set.
    """
    # We instantiate the following map on a 3x3 grid
    #  _  _
    # / \/ \
    # | |  |
    # \_/\_/

    transitions = RailEnvTransitions()
    cells = transitions.transition_list
    vertical_line = cells[1]
    south_symmetrical_switch = cells[6]
    north_symmetrical_switch = transitions.rotate_transition(
        south_symmetrical_switch, 180)
    # The three other turns are rotations of the south-east turn.
    south_east_turn = int('0100000000000010', 2)
    south_west_turn = transitions.rotate_transition(south_east_turn, 90)
    north_east_turn = transitions.rotate_transition(south_east_turn, 270)
    north_west_turn = transitions.rotate_transition(south_east_turn, 180)

    rail_map = np.array(
        [[south_east_turn, south_symmetrical_switch, south_west_turn],
         [vertical_line, vertical_line, vertical_line],
         [north_east_turn, north_symmetrical_switch, north_west_turn]],
        dtype=np.uint16)

    rail = GridTransitionMap(width=3, height=3, transitions=transitions)
    rail.grid = rail_map
    rail_env = RailEnv(width=3,
                       height=3,
                       rail_generator=rail_from_grid_transition_map(rail),
                       schedule_generator=random_schedule_generator(),
                       number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())

    for _ in range(200):
        _ = rail_env.reset(False, False, True)

        # We do not care about target for the moment
        agent = rail_env.agents[0]
        agent.target = [-1, -1]

        # Check that trains are always initialized at a consistent position
        # or direction.
        # They should always be able to go somewhere.
        assert (transitions.get_transitions(rail_map[agent.position],
                                            agent.direction) != (0, 0, 0, 0))

        initial_pos = agent.position

        # Count only the steps that actually changed the agent's position.
        valid_active_actions_done = 0
        pos = initial_pos
        while valid_active_actions_done < 6:
            # We randomly select an action
            action = np.random.randint(4)

            _, _, _, _ = rail_env.step({0: action})

            prev_pos = pos
            pos = agent.position  # rail_env.agents_position[0]
            if prev_pos != pos:
                valid_active_actions_done += 1

        # After 6 movements on this railway network, the train should be back
        # to its original height on the map.
        assert (initial_pos[0] == agent.position[0])

        # We check that the train always attains its target after some time
        # NOTE(review): this loop is nested inside the 200-iteration loop
        # above and reuses `_` as its loop variable.
        for _ in range(10):
            _ = rail_env.reset()

            done = False
            while not done:
                # We randomly select an action
                action = np.random.randint(4)

                _, _, dones, _ = rail_env.step({0: action})
                done = dones['__all__']
Ejemplo n.º 19
0
def test_rail_environment_single_agent(show=False):
    """
    Drive a single agent with random actions on a stored environment.

    The environment is obtained from
    ``RailEnvPersister.load_new("test_env_loop.pkl", "env_data.tests")``.
    Two things are checked:

    1. Five times over: after a reset (and with the agent's direction forced
       to 0), random actions must change the agent's position six times
       within fewer than 100 steps.
    2. Inside each of those rounds, ten times over: after a plain reset,
       random actions must make ``dones['__all__']`` true within fewer than
       100 steps.

    :param show: when true, print diagnostics and render every step with a
        short sleep in between.
    """
    # The original sketch of the map (a figure-of-eight loop) is kept below.
    # NOTE(review): the environment now comes from the pickle, so confirm the
    # stored map still looks like this.
    #  _  _
    # / \/ \
    # | |  |
    # \_/\_/

    transitions = RailEnvTransitions()

    rail_env, _ = RailEnvPersister.load_new("test_env_loop.pkl", "env_data.tests")
    rail_map = rail_env.rail.grid

    rail_env._max_episode_steps = 1000

    _ = rail_env.reset(False, False, True)

    liActions = [int(a) for a in RailEnvActions]

    env_renderer = RenderTool(rail_env)

    for _ in range(5):
        _ = rail_env.reset(False, False, True)

        # We do not care about target for the moment
        agent = rail_env.agents[0]
        agent.target = [-1, -1]

        if show:
            print("After reset - agent pos:", agent.position, "dir: ", agent.direction)
            print(transitions.get_transitions(rail_map[agent.position], agent.direction))

        # HACK - force the direction to one we know is good.
        agent.initial_direction = agent.direction = 0

        if show:
            print("handle:", agent.handle)

        valid_active_actions_done = 0
        pos = agent.position

        if show:
            env_renderer.render_env(show=show, show_agents=True)
            time.sleep(0.01)

        iStep = 0
        while valid_active_actions_done < 6:
            # We randomly select an action
            action = np.random.choice(liActions)

            _, _, dict_done, _ = rail_env.step({0: action})

            prev_pos = pos
            pos = agent.position

            # Per-step diagnostics are only wanted in interactive runs.
            if show:
                print("action:", action, "pos:", agent.position, "prev:", prev_pos, agent.direction)
                print(dict_done)

            # Only steps that actually moved the agent count.
            if prev_pos != pos:
                valid_active_actions_done += 1
            iStep += 1

            if show:
                env_renderer.render_env(show=show, show_agents=True, step=iStep)
                time.sleep(0.01)
            assert iStep < 100, "valid actions should have been performed by now - hung agent"

        # We check that the train always attains its target after some time
        for _ in range(10):
            _ = rail_env.reset()

            rail_env.agents[0].direction = 0

            iStep = 0
            while iStep < 100:
                # We randomly select an action
                action = np.random.choice(liActions)

                _, _, dones, _ = rail_env.step({0: action})
                if dones['__all__']:
                    break
                iStep += 1
                # Fails on the 100th unfinished step, so the loop never ends
                # silently through its own condition.
                assert iStep < 100, "agent should have finished by now"
                env_renderer.render_env(show=show)
Ejemplo n.º 20
0
def test_build_railway_infrastructure():
    """
    Lay four rail connections on an empty 20x20 grid and verify the result.

    Each connection is created with ``connect_rail_in_grid_map``; the returned
    paths are compared with the expected cell sequences, and afterwards every
    row of the grid is compared with the expected transition values.
    """
    rail_trans = RailEnvTransitions()
    grid_map = GridTransitionMap(width=20, height=20, transitions=rail_trans)
    grid_map.grid.fill(0)

    # One entry per connection:
    # (start, end, flip_start_node_trans, flip_end_node_trans, expected path)
    cases = [
        # dead-ends on both sides
        ((2, 2), (8, 8), True, True,
         [(2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8),
          (3, 8), (4, 8), (5, 8), (6, 8), (7, 8), (8, 8)]),
        # open ends on both sides
        ((1, 3), (1, 7), False, False,
         [(1, 3), (1, 4), (1, 5), (1, 6), (1, 7)]),
        # open end at the beginning, dead end at the end
        ((6, 2), (6, 5), False, True,
         [(6, 2), (6, 3), (6, 4), (6, 5)]),
        # dead end at the start, open end at the end
        ((7, 5), (8, 9), True, False,
         [(7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (8, 9)]),
    ]

    # All four connections are laid first (in this order) and only checked
    # afterwards, since each call writes into the shared grid.
    paths = [
        connect_rail_in_grid_map(grid_map,
                                 begin,
                                 finish,
                                 rail_trans,
                                 flip_start_node_trans=flip_begin,
                                 flip_end_node_trans=flip_finish,
                                 respect_transition_validity=True,
                                 forbidden_cells=None)
        for begin, finish, flip_begin, flip_finish, _ in cases
    ]

    for actual, case in zip(paths, cases):
        expected = case[-1]
        assert actual == expected, \
            "actual={}, expected={}".format(actual, expected)

    # Expected grid, given by its non-zero cells only: (row, col) -> value.
    nonzero_cells = {
        (1, 4): 1025, (1, 5): 1025, (1, 6): 1025,
        (2, 2): 4, (2, 3): 1025, (2, 4): 1025, (2, 5): 1025,
        (2, 6): 1025, (2, 7): 1025, (2, 8): 4608,
        (3, 8): 32800,
        (4, 8): 32800,
        (5, 8): 32800,
        (6, 3): 1025, (6, 4): 1025, (6, 5): 256, (6, 8): 32800,
        (7, 5): 4, (7, 6): 1025, (7, 7): 1025, (7, 8): 33825, (7, 9): 4608,
        (8, 8): 128,
    }
    expected_rows = [[nonzero_cells.get((r, c), 0) for c in range(20)]
                     for r in range(20)]

    for row_idx, expected_row in enumerate(expected_rows):
        assert np.all(grid_map.grid[row_idx] == expected_row)
Ejemplo n.º 21
0
def test_fix_inner_nodes():
    """
    Connect three straight tracks, fix their end nodes and all transitions,
    then compare the resulting grid cells with known values.
    """
    rail_trans = RailEnvTransitions()
    grid_map = GridTransitionMap(width=6, height=10, transitions=rail_trans)
    grid_map.grid.fill(0)

    # (start, target) of the main track and of the two parallel tracks.
    segments = [
        ((2, 2), (8, 2)),
        ((3, 3), (7, 3)),
        ((4, 4), (6, 4)),
    ]

    for seg_start, seg_target in segments:
        connect_straight_line_in_grid_map(grid_map, seg_start, seg_target,
                                          rail_trans)

    # Fix the ends of the inner nodes.
    # This is not a fix in transition type but rather makes the necessary
    # connections to the parallel tracks.
    for seg_start, seg_target in segments:
        fix_inner_nodes(grid_map, seg_start, rail_trans)
        fix_inner_nodes(grid_map, seg_target, rail_trans)

    # Fix all the different transitions to legal elements. Cells in the upper
    # half of the grid are fixed with orientation 2, the others with 0.
    for c in range(grid_map.grid.shape[1]):
        for r in range(grid_map.grid.shape[0]):
            heading = 2 if r < grid_map.grid.shape[0] / 2 else 0
            grid_map.fix_transitions((r, c), heading)

    # Expected non-zero cells, (row, col) -> value; all other checked cells
    # must be 0.
    expected_nonzero = {
        (2, 2): 8192,
        (3, 2): 49186,
        (4, 2): 32800,
        (5, 2): 32800,
        (6, 2): 32800,
        (7, 2): 32872,
        (8, 2): 128,
        (3, 3): 4608,
        (4, 3): 49186,
        (5, 3): 32800,
        (6, 3): 32872,
        (7, 3): 2064,
        (4, 4): 4608,
        (5, 4): 32800,
        (6, 4): 2064,
    }

    # Column by column, every cell except (0, 0) is checked.
    for c in range(6):
        for r in range(10):
            if (r, c) == (0, 0):
                continue
            assert grid_map.grid[(r, c)] == expected_nonzero.get((r, c), 0)
Ejemplo n.º 22
0
    def generator() -> RailGenerator:
        """
        Build a rail grid and agent hints from the enclosing ``curriculum``.

        The function takes no arguments. Grid size (``x_dim`` / ``y_size``),
        ``n_agents``, ``n_cities``, ``min_dist``, ``max_dist`` and ``n_extra``
        are all read through ``curriculum.get``; random draws use the
        module-level ``random`` object.

        Steps:
          1. Until ``n_cities`` connections exist (or the retry budget is
             used up): draw a start/goal pair whose distance lies within
             ``[min_dist, max_dist]`` and whose points are at least 2 away
             from every stored start/goal, then connect the pair with rail.
          2. Try to lay up to ``n_extra`` further connections between
             randomly drawn cells.
          3. Return the grid map together with the agent hints.

        :return: ``(grid_map, {'agents_hints': {'start_goal': ..., 'start_dir': ...}})``
            where ``start_goal`` is a list of ``[start, goal]`` pairs and
            ``start_dir`` the matching list of start directions.
        :raises Exception: if ``n_agents`` is larger than ``n_cities``.
        """

        if curriculum.get("n_agents") > curriculum.get("n_cities"):
            raise Exception("complex_rail_generator: n_agents > n_cities!")
        # NOTE(review): width is read from "x_dim" but height from "y_size";
        # confirm both key names against the curriculum definition.
        grid_map = GridTransitionMap(width=curriculum.get("x_dim"),
                                     height=curriculum.get("y_size"),
                                     transitions=RailEnvTransitions())
        rail_array = grid_map.grid
        rail_array.fill(0)

        rail_trans = grid_map.transitions
        start_goal = []
        start_dir = []
        sanity_max = 9000

        # Cells that may still be drawn as a start or goal. Only cells chosen
        # as start/goal are removed later; cells merely covered by laid rail
        # stay in this set.
        free_cells = {(r, c)
                      for r, row in enumerate(rail_array)
                      for c, col in enumerate(row) if col == 0}

        def far_from_existing(candidate_pair):
            """Return True if no candidate point is closer than 2 to a stored start/goal."""
            return not any(
                distance_on_rail(point, placed) < 2
                for placed_pair in start_goal
                for point in candidate_pair
                for placed in placed_pair)

        def connect(begin, end):
            """Lay rail between two cells and return the resulting path."""
            return connect_rail_in_grid_map(
                grid_map,
                begin,
                end,
                rail_trans,
                Vec2d.get_chebyshev_distance,
                flip_start_node_trans=True,
                flip_end_node_trans=True,
                respect_transition_validity=True,
                forbidden_cells=None)

        # step 1: connect start/goal pairs
        nr_created = 0
        created_sanity = 0
        while nr_created < curriculum.get(
                "n_cities") and created_sanity < sanity_max:
            if not free_cells:
                break

            # random.sample() no longer accepts a set (TypeError on Python
            # 3.11+), so draw from a sequence snapshot instead. free_cells is
            # not modified while candidates are being drawn, so one snapshot
            # per round is enough.
            candidates = tuple(free_cells)

            all_ok = False
            for _ in range(sanity_max):
                start = random.choice(candidates)
                goal = random.choice(candidates)

                # check min/max distance
                dist_sg = distance_on_rail(start, goal)
                if dist_sg < curriculum.get("min_dist"):
                    continue
                if dist_sg > curriculum.get("max_dist"):
                    continue

                # check distance to existing points
                if far_from_existing((start, goal)):
                    all_ok = True
                    # discard(): start and goal are drawn independently and
                    # can be the same cell, in which case a second remove()
                    # would raise KeyError.
                    free_cells.discard(start)
                    free_cells.discard(goal)
                    break

            if not all_ok:
                # we might as well give up at this point
                break

            new_path = connect(start, goal)
            if len(new_path) >= 2:
                nr_created += 1
                start_goal.append([start, goal])
                start_dir.append(
                    mirror(get_direction(new_path[0], new_path[1])))
            else:
                # after too many failures we will give up
                created_sanity += 1

        # step 2: add extra connections
        # NOTE(review): the end points are drawn from free_cells, i.e. from
        # cells never used as a start/goal, not specifically from cells that
        # already carry rail; confirm this is intended.
        created_sanity = 0
        nr_created = 0
        # free_cells does not change any more, so a single snapshot suffices.
        candidates = tuple(free_cells)
        while nr_created < curriculum.get(
                "n_extra") and created_sanity < sanity_max:
            if not candidates:
                break

            # One draw per attempt; redrawing repeatedly before use would
            # only discard all but the last pair.
            start = random.choice(candidates)
            goal = random.choice(candidates)

            new_path = connect(start, goal)

            if len(new_path) >= 2:
                nr_created += 1
            else:
                # after too many failures we will give up
                created_sanity += 1

        # step 3: transition map + hints
        return grid_map, {
            'agents_hints': {
                'start_goal': start_goal,
                'start_dir': start_dir
            }
        }
    def fix_transitions(self,
                        rcPos: IntVector2DArray,
                        direction: IntVector2D = -1):
        """
        Rebuild the transitions of one cell from the rails of its neighbours.

        A side of ``rcPos`` counts as "incoming" when the neighbouring cell on
        that side lies inside the grid and has, for any orientation, a
        transition heading back towards ``rcPos``. Depending on the number of
        incoming sides the cell is rewritten as follows:

        * 0: the cell is left untouched.
        * 1: the cell is cleared; if it held any transition before, a single
          transition from heading ``mirror(d)`` to heading ``d`` is set for
          the incoming side ``d`` (dead end).
        * 2: the cell is cleared and the two incoming sides are connected in
          both travel directions.
        * 3: one of two simple switches is placed, rotated by the index of
          the side without a connection.
        * 4: a double slip switch is placed with a random rotation of 0 or 90
          degrees.

        :param rcPos: (row, column) of the cell to fix.
        :param direction: only used for three incoming sides. When ``>= 0`` it
            selects the switch type together with the unconnected side;
            otherwise (or when it selects neither type) the switch is drawn
            from ``self.random_generator``.
        :return: always ``True``.
        """
        gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
        grcPos = array(rcPos)
        grcMax = self.grid.shape

        # Transition elements
        transitions = RailEnvTransitions()
        cells = transitions.transition_list
        simple_switch_east_south = transitions.rotate_transition(cells[10], 90)
        simple_switch_west_south = transitions.rotate_transition(cells[2], 270)
        double_slip = cells[5]
        three_way_transitions = [
            simple_switch_east_south, simple_switch_west_south
        ]

        # loop over available outbound directions (indices) for rcPos
        incoming_connections = np.zeros(4)
        for iDirOut in np.arange(4):
            gdRC = gDir2dRC[iDirOut]  # row,col increment
            gPos2 = grcPos + gdRC  # next cell in that direction

            # Neighbours outside the grid do not count as incoming.
            if np.any(gPos2 < 0) or np.any(gPos2 >= grcMax):
                continue

            # Sum the transitions of gPos2 that lead back towards rcPos,
            # over all four orientations of the neighbour.
            connected = sum(
                self.get_transition((gPos2[0], gPos2[1], orientation),
                                    mirror(iDirOut))
                for orientation in range(4))
            if connected > 0:
                incoming_connections[iDirOut] = 1

        number_of_incoming = np.sum(incoming_connections)

        if number_of_incoming == 1:
            # Only one incoming direction --> Straight line set deadend.
            # An empty cell stays empty; a non-empty one is reset to the
            # single dead-end transition.
            had_transitions = self.get_full_transitions(*rcPos) != 0
            self.set_transitions(rcPos, 0)
            if had_transitions:
                # A separate loop name keeps the `direction` parameter intact.
                for incoming_dir in range(4):
                    if incoming_connections[incoming_dir] > 0:
                        self.set_transition(
                            (rcPos[0], rcPos[1], mirror(incoming_dir)),
                            incoming_dir, 1)

        elif number_of_incoming == 2:
            # Connect all incoming connections
            self.set_transitions(rcPos, 0)

            # Plain ints instead of the one-element rows argwhere() yields,
            # so set_transition() receives scalar indices.
            first_dir, second_dir = (
                int(d)
                for d in np.argwhere(incoming_connections > 0).ravel())
            self.set_transition((rcPos[0], rcPos[1], mirror(first_dir)),
                                second_dir, 1)
            self.set_transition((rcPos[0], rcPos[1], mirror(second_dir)),
                                first_dir, 1)

        elif number_of_incoming == 3:
            # Find feasible connection for three entries
            self.set_transitions(rcPos, 0)
            hole = np.argwhere(incoming_connections < 1)[0][0]

            transition = None
            if direction >= 0:
                switch_type_idx = (direction - hole + 3) % 4
                if switch_type_idx == 0:
                    transition = simple_switch_west_south
                elif switch_type_idx == 2:
                    transition = simple_switch_east_south
            if transition is None:
                # choice(..., 1) returns a one-element array; take the
                # element so a scalar transition value is stored in the grid.
                transition = self.random_generator.choice(
                    three_way_transitions, 1)[0]

            transition = transitions.rotate_transition(transition,
                                                       int(hole * 90))
            self.set_transitions((rcPos[0], rcPos[1]), transition)

        elif number_of_incoming == 4:
            # Make a double slip switch
            rotation = self.random_generator.randint(2)
            transition = transitions.rotate_transition(double_slip,
                                                       int(rotation * 90))
            self.set_transitions((rcPos[0], rcPos[1]), transition)

        return True
def test_valid_railenv_transitions():
    """
    Check reading, setting and rotating transitions on ``RailEnvTransitions``
    using hand-built 16-bit cell values.
    """
    trans = RailEnvTransitions()

    # Direction indices: 'N': 0, 'E': 1, 'S': 2, 'W': 3
    north, east, south, west = 0, 1, 2, 3

    mixed_cell = int('1100110000110011', 2)
    for offset in (0, 1):
        found = trans.get_transitions(mixed_cell, offset)
        assert found == (1, 1, 0, 0)
        found = trans.get_transitions(mixed_cell, 2 + offset)
        assert found == (0, 0, 1, 1)

    empty_cell = int('0000000000000000', 2)

    for heading in (north, east, south, west):
        found = trans.get_transitions(empty_cell, heading)
        assert found == (0, 0, 0, 0)

    # Facing south, going south
    ns_cell = trans.set_transitions(empty_cell, south, (0, 0, 1, 0))
    cleared = trans.set_transition(ns_cell, south, south, 0)
    assert cleared == empty_cell
    assert trans.get_transition(ns_cell, south, south)

    # Facing north, going east
    se_cell = trans.set_transition(empty_cell, north, east, 1)
    assert trans.get_transition(se_cell, north, east)

    # The opposite transitions are not feasible
    assert not trans.get_transition(ns_cell, south, north)
    assert not trans.get_transition(se_cell, south, east)

    ew_cell = trans.rotate_transition(ns_cell, 90)
    nw_cell = trans.rotate_transition(se_cell, 180)

    # Facing west, going west
    assert trans.get_transition(ew_cell, west, west)
    # Facing south, going west
    assert trans.get_transition(nw_cell, south, west)

    # A full turn gives back the same cell value.
    full_turn = trans.rotate_transition(se_cell, 360)
    assert se_cell == full_turn
    def generator(width: int,
                  height: int,
                  num_agents: int,
                  num_resets: int = 0,
                  np_random: RandomState = None) -> RailGenerator:
        t_utils = RailEnvTransitions()

        transition_probability = cell_type_relative_proportion

        transitions_templates_ = []
        transition_probabilities = []
        for i in range(len(t_utils.transitions)):  # don't include dead-ends
            if t_utils.transitions[i] == int('0010000000000000', 2):
                continue

            all_transitions = 0
            for dir_ in range(4):
                trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
                all_transitions |= (trans[0] << 3) | \
                                   (trans[1] << 2) | \
                                   (trans[2] << 1) | \
                                   (trans[3])

            template = [int(x) for x in bin(all_transitions)[2:]]
            template = [0] * (4 - len(template)) + template

            # add all rotations
            for rot in [0, 90, 180, 270]:
                transitions_templates_.append(
                    (template,
                     t_utils.rotate_transition(t_utils.transitions[i], rot)))
                transition_probabilities.append(transition_probability[i])
                template = [template[-1]] + template[:-1]

        def get_matching_templates(template):
            """
            Returns a list of possible transition maps for a given template

            Parameters:
            ------
            template:List[int]

            Returns:
            ------
            List[int]
            """
            ret = []
            for i in range(len(transitions_templates_)):
                is_match = True
                for j in range(4):
                    if template[j] >= 0 and template[
                            j] != transitions_templates_[i][0][j]:
                        is_match = False
                        break
                if is_match:
                    ret.append((transitions_templates_[i][1],
                                transition_probabilities[i]))
            return ret

        MAX_INSERTIONS = (width - 2) * (height - 2) * 10
        MAX_ATTEMPTS_FROM_SCRATCH = 10

        attempt_number = 0
        while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
            cells_to_fill = []
            rail = []
            for r in range(height):
                rail.append([None] * width)
                if r > 0 and r < height - 1:
                    cells_to_fill = cells_to_fill + [
                        (r, c) for c in range(1, width - 1)
                    ]

            num_insertions = 0
            while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
                cell = cells_to_fill[np_random.choice(len(cells_to_fill),
                                                      1)[0]]
                cells_to_fill.remove(cell)
                row = cell[0]
                col = cell[1]

                # look at its neighbors and see what are the possible transitions
                # that can be chosen from, if any.
                valid_template = [-1, -1, -1, -1]

                for el in [(0, 2, (-1, 0)), (1, 3, (0, 1)), (2, 0, (1, 0)),
                           (3, 1, (0, -1))]:  # N, E, S, W
                    neigh_trans = rail[row + el[2][0]][col + el[2][1]]
                    if neigh_trans is not None:
                        # select transition coming from facing direction el[1] and
                        # moving to direction el[1]
                        max_bit = 0
                        for k in range(4):
                            max_bit |= t_utils.get_transition(
                                neigh_trans, k, el[1])

                        if max_bit:
                            valid_template[el[0]] = 1
                        else:
                            valid_template[el[0]] = 0

                possible_cell_transitions = get_matching_templates(
                    valid_template)

                if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
                    # no cell can be filled in without violating some transitions
                    # can a dead-end solve the problem?
                    if valid_template.count(1) == 1:
                        for k in range(4):
                            if valid_template[k] == 1:
                                rot = 0
                                if k == 0:
                                    rot = 180
                                elif k == 1:
                                    rot = 270
                                elif k == 2:
                                    rot = 0
                                elif k == 3:
                                    rot = 90

                                rail[row][col] = t_utils.rotate_transition(
                                    int('0010000000000000', 2), rot)
                                num_insertions += 1

                                break

                    else:
                        # can I get valid transitions by removing a single
                        # neighboring cell?
                        bestk = -1
                        besttrans = []
                        for k in range(4):
                            tmp_template = valid_template[:]
                            tmp_template[k] = -1
                            possible_cell_transitions = get_matching_templates(
                                tmp_template)
                            if len(possible_cell_transitions) > len(besttrans):
                                besttrans = possible_cell_transitions
                                bestk = k

                        if bestk >= 0:
                            # Replace the corresponding cell with None, append it
                            # to cells to fill, fill in a transition in the current
                            # cell.
                            replace_row = row - 1
                            replace_col = col
                            if bestk == 1:
                                replace_row = row
                                replace_col = col + 1
                            elif bestk == 2:
                                replace_row = row + 1
                                replace_col = col
                            elif bestk == 3:
                                replace_row = row
                                replace_col = col - 1

                            cells_to_fill.append((replace_row, replace_col))
                            rail[replace_row][replace_col] = None

                            possible_transitions, possible_probabilities = zip(
                                *besttrans)
                            possible_probabilities = [
                                p / sum(possible_probabilities)
                                for p in possible_probabilities
                            ]

                            rail[row][col] = np_random.choice(
                                possible_transitions, p=possible_probabilities)
                            num_insertions += 1

                        else:
                            print('WARNING: still nothing!')
                            rail[row][col] = int('0000000000000000', 2)
                            num_insertions += 1
                            pass

                else:
                    possible_transitions, possible_probabilities = zip(
                        *possible_cell_transitions)
                    possible_probabilities = [
                        p / sum(possible_probabilities)
                        for p in possible_probabilities
                    ]

                    rail[row][col] = np_random.choice(possible_transitions,
                                                      p=possible_probabilities)
                    num_insertions += 1

            if num_insertions == MAX_INSERTIONS:
                # Failed to generate a valid level; try again for a number of times
                attempt_number += 1
            else:
                break

        if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
            print('ERROR: failed to generate level')

        # Finally pad the border of the map with dead-ends to avoid border issues;
        # at most 1 transition in the neigh cell
        for r in range(height):
            # Check for transitions coming from [r][1] to WEST
            max_bit = 0
            neigh_trans = rail[r][1]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >>
                                                  ((3 - k) * 4)) & (2**4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & 1)
            if max_bit:
                rail[r][0] = t_utils.rotate_transition(
                    int('0010000000000000', 2), 270)
            else:
                rail[r][0] = int('0000000000000000', 2)

            # Check for transitions coming from [r][-2] to EAST
            max_bit = 0
            neigh_trans = rail[r][-2]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >>
                                                  ((3 - k) * 4)) & (2**4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
            if max_bit:
                rail[r][-1] = t_utils.rotate_transition(
                    int('0010000000000000', 2), 90)
            else:
                rail[r][-1] = int('0000000000000000', 2)

        for c in range(width):
            # Check for transitions coming from [1][c] to NORTH
            max_bit = 0
            neigh_trans = rail[1][c]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >>
                                                  ((3 - k) * 4)) & (2**4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
            if max_bit:
                rail[0][c] = int('0010000000000000', 2)
            else:
                rail[0][c] = int('0000000000000000', 2)

            # Check for transitions coming from [-2][c] to SOUTH
            max_bit = 0
            neigh_trans = rail[-2][c]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >>
                                                  ((3 - k) * 4)) & (2**4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
            if max_bit:
                rail[-1][c] = t_utils.rotate_transition(
                    int('0010000000000000', 2), 180)
            else:
                rail[-1][c] = int('0000000000000000', 2)

        # For display only, wrong levels
        for r in range(height):
            for c in range(width):
                if rail[r][c] is None:
                    rail[r][c] = int('0000000000000000', 2)

        tmp_rail = np.asarray(rail, dtype=np.uint16)

        return_rail = GridTransitionMap(width=width,
                                        height=height,
                                        transitions=t_utils)
        return_rail.grid = tmp_rail

        return return_rail, None
    def generator(width: int,
                  height: int,
                  num_agents: int,
                  num_resets: int = 0,
                  np_random: RandomState = None) -> RailGenerator:
        """
        Build a rail grid by connecting randomly drawn pairs of cells.

        Phase 1 draws up to ``nr_start_goal`` start/goal pairs and connects
        each of them with ``connect_rail_in_grid_map``. A pair is accepted
        only when both cells are still 0 in the grid, ``distance_on_rail``
        between them is neither below ``min_dist`` nor above ``max_dist``,
        and neither cell is closer than 2 to a cell of an already accepted
        pair. Phase 2 connects up to ``nr_extra`` additional pairs of cells
        that are both non-zero in the grid.

        :param width: number of columns of the grid
        :param height: number of rows of the grid
        :param num_agents: requested number of agents; it is only compared
            with ``nr_start_goal`` (and clamped with a printed notice)
        :param num_resets: not used by this generator
        :param np_random: random state the cells are drawn from; its
            ``randint`` method is called, so ``None`` is not usable here
        :return: the grid map and a dict whose ``'agents_hints'`` entry holds
            the accepted ``start_goal`` pairs and the matching ``start_dir``
        """
        # Upper bound for both the sampling tries and the tolerated number of
        # failed connection attempts.
        max_tries = 9000

        if num_agents > nr_start_goal:
            num_agents = nr_start_goal
            print(
                "complex_rail_generator: num_agents > nr_start_goal, changing num_agents"
            )

        grid_map = GridTransitionMap(width=width,
                                     height=height,
                                     transitions=RailEnvTransitions())
        cells = grid_map.grid
        cells.fill(0)
        transitions = grid_map.transitions

        start_goal = []
        start_dir = []

        def draw_cell():
            # Row is drawn before column.
            return (np_random.randint(0, height), np_random.randint(0, width))

        def draw_pair(is_acceptable):
            """Return the first acceptable (start, goal) draw, or None."""
            for _ in range(max_tries):
                # Start is drawn before goal so the random stream is consumed
                # in a fixed order.
                first = draw_cell()
                second = draw_cell()
                if is_acceptable(first, second):
                    return first, second
            return None

        def keeps_clear_of_known_pairs(first, second):
            """True unless a cell is closer than 2 to an accepted cell."""
            candidate = (first, second)
            return not any(
                distance_on_rail(candidate[i], known[j]) < 2
                for known in start_goal
                for i in range(2)
                for j in range(2))

        def is_new_start_goal(first, second):
            # Both cells must still be empty.
            if cells[second] != 0 or cells[first] != 0:
                return False
            span = distance_on_rail(first, second)
            if span < min_dist or span > max_dist:
                return False
            return keeps_clear_of_known_pairs(first, second)

        def is_on_existing_rail(first, second):
            # Both cells must already carry a transition.
            return cells[second] != 0 and cells[first] != 0

        def lay_rail(first, second):
            return connect_rail_in_grid_map(
                grid_map,
                first,
                second,
                transitions,
                Vec2d.get_chebyshev_distance,
                flip_start_node_trans=True,
                flip_end_node_trans=True,
                respect_transition_validity=True,
                forbidden_cells=None)

        # Phase 1: start/goal connections on empty cells.
        built = 0
        failed = 0
        while built < nr_start_goal and failed < max_tries:
            pair = draw_pair(is_new_start_goal)
            if pair is None:
                # No acceptable pair found within the try budget: give up.
                break
            start, goal = pair
            path = lay_rail(start, goal)
            if len(path) < 2:
                # Count the failure; too many of them end the phase.
                failed += 1
                continue
            built += 1
            start_goal.append([start, goal])
            start_dir.append(mirror(get_direction(path[0], path[1])))

        # Phase 2: extra connections between cells that already have rail.
        built = 0
        failed = 0
        while built < nr_extra and failed < max_tries:
            pair = draw_pair(is_on_existing_rail)
            if pair is None:
                break
            path = lay_rail(*pair)
            if len(path) >= 2:
                built += 1
            else:
                failed += 1

        hints = {'start_goal': start_goal, 'start_dir': start_dir}
        return grid_map, {'agents_hints': hints}
def test_rotate_railenv_transition():
    """
    Check ``RailEnvTransitions.rotate_transition`` against known cycles.

    Every entry of ``transition_cycles`` lists the values expected when its
    first element is rotated by 0, 90, 180 and 270 degrees. Entries shorter
    than four elements wrap around (``i % len(cycle)``), which covers cells
    whose rotation repeats earlier.

    On a mismatch the expected and the actual transition are printed with
    ``RailEnvTransitions.print`` before the assertion fails.
    """
    rail_env_transitions = RailEnvTransitions()

    # TODO test all cases
    transition_cycles = [
        # empty cell - Case 0
        [
            int('0000000000000000', 2),
            int('0000000000000000', 2),
            int('0000000000000000', 2),
            int('0000000000000000', 2)
        ],
        # Case 1 - straight
        #     |
        #     |
        #     |
        [int(rw('1000 0000 0010 0000'), 2),
         int(rw('0000 0100 0000 0001'), 2)],
        # Case 1b (8)  - simple turn right
        #      _
        #     |
        #     |
        [
            int(rw('0100 0000 0000 0010'), 2),
            int(rw('0001 0010 0000 0000'), 2),
            int(rw('0000 1000 0001 0000'), 2),
            int(rw('0000 0000 0100 1000'), 2),
        ],
        # Case 1c (9)  - simple turn left
        #    _
        #     |
        #     |

        # int('0001001000000000', 2),\ #  noqa: E800

        # Case 2 - simple left switch
        #  _ _|
        #     |
        #     |
        [
            int(rw('1001 0010 0010 0000'), 2),
            int(rw('0000 1100 0001 0001'), 2),
            int(rw('1000 0000 0110 1000'), 2),
            int(rw('0100 0100 0000 0011'), 2),
        ],
        # Case 2b (10) - simple right switch
        #     |
        #     |
        #     |

        # int('1100000000100010', 2) \ #  noqa: E800

        # Case 3 - diamond crossing
        # int('1000010000100001', 2),   \ #  noqa: E800
        # Case 4 - single slip
        # int('1001011000100001', 2),   \ #  noqa: E800
        # Case 5 - double slip
        # int('1100110000110011', 2),   \ #  noqa: E800
        # Case 6 - symmetrical
        # int('0101001000000010', 2),   \ #  noqa: E800

        # Case 7 - dead end
        #
        #
        #     |
        [
            int(rw('0010 0000 0000 0000'), 2),
            int(rw('0000 0001 0000 0000'), 2),
            int(rw('0000 0000 1000 0000'), 2),
            int(rw('0000 0000 0000 0100'), 2),
        ],
    ]

    for index, cycle in enumerate(transition_cycles):
        for i in range(4):
            rotation = i * 90
            actual_transition = rail_env_transitions.rotate_transition(
                cycle[0], rotation)
            expected_transition = cycle[i % len(cycle)]

            if actual_transition != expected_transition:
                # Show both transitions in readable form before failing.
                print("expected:")
                rail_env_transitions.print(expected_transition)
                print("actual:")
                rail_env_transitions.print(actual_transition)

            # The message reports the position of the cycle in
            # transition_cycles and the rotation in degrees that was
            # actually passed to rotate_transition.
            assert actual_transition == expected_transition, \
                "Case {}: rotate_transition({}, {}) should equal {} but was {}.".format(
                    index, cycle[0], rotation, expected_transition,
                    actual_transition)