def test_valid_railenv_transitions():
    """Check reading, writing and rotating transitions on RailEnvTransitions.

    Cells are plain integers; the binary literals below spell out all
    sixteen bits so the per-heading groups of four are easy to see.
    Heading / direction indices: N=0, E=1, S=2, W=3.
    """
    transitions = RailEnvTransitions()

    # Sample cell: headings N and E expect (1, 1, 0, 0), headings S and W
    # expect (0, 0, 1, 1).
    sample_cell = 0b1100110000110011
    for offset in (0, 1):
        assert transitions.get_transitions(sample_cell, offset) == (1, 1, 0, 0)
        assert transitions.get_transitions(sample_cell,
                                           2 + offset) == (0, 0, 1, 1)

    # A cell with no bits set must report no transitions for any heading.
    empty_cell = 0b0000000000000000
    for heading in (0, 1, 2, 3):
        assert transitions.get_transitions(empty_cell, heading) == (0, 0, 0, 0)

    # Heading south, continuing south: set all four exits for heading 2 at
    # once, then clear the single bit again and expect the empty cell back.
    south_straight = transitions.set_transitions(empty_cell, 2, (0, 0, 1, 0))
    cleared = transitions.set_transition(south_straight, 2, 2, 0)
    assert cleared == empty_cell
    assert transitions.get_transition(south_straight, 2, 2)

    # Heading north, leaving towards the east.
    north_to_east = transitions.set_transition(empty_cell, 0, 1, 1)
    assert transitions.get_transition(north_to_east, 0, 1)

    # Transitions that were never set must not be reported as feasible.
    assert not transitions.get_transition(south_straight, 2, 0)
    assert not transitions.get_transition(north_to_east, 2, 1)

    quarter_turn = transitions.rotate_transition(south_straight, 90)
    half_turn = transitions.rotate_transition(north_to_east, 180)

    # After a 90 degree rotation: heading west, continuing west.
    assert transitions.get_transition(quarter_turn, 3, 3)
    # After a 180 degree rotation: heading south, leaving towards the west.
    assert transitions.get_transition(half_turn, 2, 3)

    # A full 360 degree rotation must give back the original cell.
    full_turn = transitions.rotate_transition(north_to_east, 360)
    assert north_to_east == full_turn
 def generator(width: int,
               height: int,
               num_agents: int = 0,
               num_resets: int = 0) -> RailGeneratorProduct:
     """Build a tiny fixed rail layout on a ``width`` x ``height`` grid.

     The grid is zero-filled, then cells ``[0, 0]`` and ``[0, 1]`` are both
     set to the value returned by ``set_transition(1, 1, 1, 1)``.
     ``num_agents`` and ``num_resets`` are accepted but not used here.

     Returns:
         A ``(grid_map, None)`` pair.
     """
     transitions = RailEnvTransitions()
     transition_map = GridTransitionMap(width=width,
                                        height=height,
                                        transitions=transitions)

     # NOTE(review): ``grid`` is presumably a numpy array (it supports
     # ``fill`` and tuple indexing) -- confirm against GridTransitionMap.
     cells = transition_map.grid
     cells.fill(0)

     # NOTE(review): the starting cell value passed here is 1, not 0;
     # confirm that this is intentional.
     cell_value = transitions.set_transition(1, 1, 1, 1)
     print(cell_value)

     # Write the same cell value into the first two cells of row 0.
     for column in (0, 1):
         cells[0, column] = cell_value

     return transition_map, None