def _fix_transitions(city_cells: IntVector2DArray, inter_city_lines: List[IntVector2DArray], grid_map: GridTransitionMap, vector_field): """ Check and fix transitions of all the cells that were modified. This is necessary because we ignore validity while drawing the rails. Parameters ---------- city_cells: IntVector2DArray Cells within cities. All of these might have changed and are thus checked inter_city_lines: List[IntVector2DArray] All cells within rails drawn between cities vector_field: IntVector2DArray Vectorfield of the size of the environment. It is used to generate preferred orienations for each cell. Each cell contains the prefered orientation of cells. If no prefered orientation is present it is set to -1 grid_map: RailEnvTransitions The grid map containing the rails. Used to draw new rails """ # Fix all cities with illegal transition maps rails_to_fix = np.zeros(3 * grid_map.height * grid_map.width * 2, dtype='int') rails_to_fix_cnt = 0 cells_to_fix = city_cells + inter_city_lines for cell in cells_to_fix: try: cell_valid = grid_map.cell_neighbours_valid(cell, True) except: import pdb pdb.set_trace() if not cell_valid: rails_to_fix[3 * rails_to_fix_cnt] = cell[0] rails_to_fix[3 * rails_to_fix_cnt + 1] = cell[1] rails_to_fix[3 * rails_to_fix_cnt + 2] = vector_field[cell] rails_to_fix_cnt += 1 # Fix all other cells for cell in range(rails_to_fix_cnt): grid_map.fix_transitions( (rails_to_fix[3 * cell], rails_to_fix[3 * cell + 1]), rails_to_fix[3 * cell + 2])
def test_fix_inner_nodes():
    """Draw three nested straight tracks, join their ends and check the repaired transition of every cell."""
    rail_trans = RailEnvTransitions()
    grid_map = GridTransitionMap(width=6, height=10, transitions=rail_trans)
    grid_map.grid.fill(0)

    # (start, target) of each straight track; the tracks run along columns 2, 3 and 4.
    track_ends = [
        ((2, 2), (8, 2)),
        ((3, 3), (7, 3)),
        ((4, 4), (6, 4)),
    ]
    for track_start, track_target in track_ends:
        connect_straight_line_in_grid_map(grid_map, track_start, track_target, rail_trans)

    # Fix the ends of the inner nodes.
    # This is not a fix in transition type but rather makes the necessary connections to the parallel tracks.
    inner_nodes = [node for pair in track_ends for node in pair]
    for inner_node in inner_nodes:
        fix_inner_nodes(grid_map, inner_node, rail_trans)

    def preferred_orientation(position):
        # Rows in the first half of the grid prefer orientation 2, all others orientation 0.
        return 2 if position[0] < grid_map.grid.shape[0] / 2 else 0

    # Fix all the different transitions to legal elements, column by column.
    for col in range(grid_map.grid.shape[1]):
        for row in range(grid_map.grid.shape[0]):
            grid_map.fix_transitions((row, col), preferred_orientation((row, col)))

    # Non-zero transition values expected after fixing; every other checked cell must be empty.
    expected_rails = {
        (2, 2): 8192, (3, 2): 49186, (4, 2): 32800, (5, 2): 32800,
        (6, 2): 32800, (7, 2): 32872, (8, 2): 128,
        (3, 3): 4608, (4, 3): 49186, (5, 3): 32800, (6, 3): 32872, (7, 3): 2064,
        (4, 4): 4608, (5, 4): 32800, (6, 4): 2064,
    }
    # Checked column by column over the 10 x 6 grid; cell (0, 0) is not asserted.
    checked_cells = [(row, col) for col in range(6) for row in range(10) if (row, col) != (0, 0)]
    for cell in checked_cells:
        assert grid_map.grid[cell] == expected_rails.get(cell, 0)