Пример #1
0
    def _fix_transitions(city_cells: IntVector2DArray,
                         inter_city_lines: List[IntVector2DArray],
                         grid_map: GridTransitionMap, vector_field):
        """
        Check and fix transitions of all the cells that were modified. This is necessary because we ignore validity
        while drawing the rails.

        The check is done in two phases: every cell is validated first, and only afterwards are the invalid cells
        repaired. Repairs therefore never influence which of the remaining cells are considered invalid.

        Parameters
        ----------
        city_cells: IntVector2DArray
            Cells within cities. All of these might have changed and are thus checked
        inter_city_lines: List[IntVector2DArray]
            All cells within rails drawn between cities
        grid_map: GridTransitionMap
            The grid map containing the rails. Used to draw new rails
        vector_field: IntVector2DArray
            Vector field of the size of the environment. It is used to generate preferred orientations for each
            cell. Each cell contains the preferred orientation of cells. If no preferred orientation is present it
            is set to -1

        """
        # Phase 1: collect every invalid cell together with its preferred orientation.
        # A growable list replaces the former pre-sized flat numpy buffer, which had a fixed
        # capacity and could overflow when city cells and inter-city cells overlap heavily.
        rails_to_fix = []
        for cell in city_cells + inter_city_lines:
            # Errors from the validity check are allowed to propagate. The previous bare `except:`
            # dropped into pdb, which hangs non-interactive runs and then continued with an unbound
            # (or stale) validity flag.
            if not grid_map.cell_neighbours_valid(cell, True):
                position = (int(cell[0]), int(cell[1]))
                # int() mirrors the integer storage of the original buffer
                rails_to_fix.append((position, int(vector_field[cell])))

        # Phase 2: repair the collected cells in the order they were found.
        for position, orientation in rails_to_fix:
            grid_map.fix_transitions(position, orientation)
Пример #2
0
def test_fix_inner_nodes():
    """Connected parallel tracks must be turned into legal transition elements.

    Three straight tracks are drawn in columns 2, 3 and 4. Their end points are connected with
    ``fix_inner_nodes``, and every cell of the grid is then passed through ``fix_transitions``.
    Afterwards the cells listed in ``expected_nonzero`` must hold exactly those transition
    values, and every other cell of the grid must be empty (0).
    """
    rail_trans = RailEnvTransitions()
    grid_map = GridTransitionMap(width=6, height=10, transitions=rail_trans)
    grid_map.grid.fill(0)

    start = (2, 2)
    target = (8, 2)
    parallel_start = (3, 3)
    parallel_target = (7, 3)
    parallel_start_1 = (4, 4)
    parallel_target_1 = (6, 4)

    inner_nodes = [
        start, target, parallel_start, parallel_target, parallel_start_1,
        parallel_target_1
    ]
    connect_straight_line_in_grid_map(grid_map, start, target, rail_trans)
    connect_straight_line_in_grid_map(grid_map, parallel_start,
                                      parallel_target, rail_trans)
    connect_straight_line_in_grid_map(grid_map, parallel_start_1,
                                      parallel_target_1, rail_trans)

    # Fix the ends of the inner node
    # This is not a fix in transition type but rather makes the necessary connections to the parallel tracks
    for node in inner_nodes:
        fix_inner_nodes(grid_map, node, rail_trans)

    def orientation(pos):
        # Orientation 2 for rows in the upper half of the grid, 0 for the rest.
        return 2 if pos[0] < grid_map.grid.shape[0] / 2 else 0

    n_rows, n_cols = grid_map.grid.shape

    # Fix all the different transitions to legal elements (column by column)
    for c in range(n_cols):
        for r in range(n_rows):
            grid_map.fix_transitions((r, c), orientation((r, c)))

    # Every cell not listed here must be empty; checking the full grid also covers (0, 0),
    # which the previous per-cell asserts skipped.
    expected_nonzero = {
        (2, 2): 8192,
        (3, 2): 49186,
        (4, 2): 32800,
        (5, 2): 32800,
        (6, 2): 32800,
        (7, 2): 32872,
        (8, 2): 128,
        (3, 3): 4608,
        (4, 3): 49186,
        (5, 3): 32800,
        (6, 3): 32872,
        (7, 3): 2064,
        (4, 4): 4608,
        (5, 4): 32800,
        (6, 4): 2064,
    }
    for r in range(n_rows):
        for c in range(n_cols):
            actual = grid_map.grid[(r, c)]
            expected = expected_nonzero.get((r, c), 0)
            assert actual == expected, f"cell {(r, c)}: expected {expected}, got {actual}"