def load_pngs(self, file_directory, rotate=False, agent_colors=False, background_image=None, whitefilter=None):
    """Load rail PNGs and index them by their 16-bit transition code.

    Args:
        file_directory: mapping from an ASCII transition description such as
            ``"NE WS"`` to the PNG file name depicting it.
        rotate: when true, store the base image under its transition code and
            also store its 90/180/270 degree rotations under the
            correspondingly rotated transition codes.
        agent_colors: when true, store every image returned by
            ``self.recolor_image`` under the key ``(transition, color_idx)``.
        background_image: optional PNG file name composited underneath each
            rail image.
        whitefilter: optional PNG file name composited on top of each rail
            image.

    Returns:
        dict mapping the keys described above to RGBA images.
    """
    images = {}
    rail_transitions = RailEnvTransitions()
    compass = list("NESW")

    for description, png_name in file_directory.items():
        # Every two-letter token "<in><out>" switches on one bit of the
        # RailEnv transition word, laid out as NESW (in) x NESW (out) with
        # index 0 being the most significant of the 16 bits.
        code = 0
        for token in description.split(" "):
            if len(token) != 2:
                continue
            bit_index = 4 * compass.index(token[0]) + compass.index(token[1])
            code |= 1 << (15 - bit_index)

        rail_image = self.pil_from_png_file('flatland.png', png_name).convert("RGBA")

        if background_image is not None:
            underlay = self.pil_from_png_file('flatland.png', background_image).convert("RGBA")
            rail_image = Image.alpha_composite(underlay, rail_image)

        if whitefilter is not None:
            overlay = self.pil_from_png_file('flatland.png', whitefilter).convert("RGBA")
            rail_image = Image.alpha_composite(rail_image, overlay)

        if rotate:
            # With rotations the unrotated image is stored as well.
            images[code] = rail_image
            for degrees in (90, 180, 270):
                rotated_code = rail_transitions.rotate_transition(code, degrees)
                # PIL rotates anticlockwise for a positive angle, hence the
                # negated angle.
                images[rotated_code] = rail_image.rotate(-degrees)

        if agent_colors:
            # With recolouring only the recoloured variants are stored.
            source_color = self.rgb_s2i("d50000")
            recolored_images = self.recolor_image(rail_image, source_color, self.agent_colors)
            for color_idx, recolored in enumerate(recolored_images):
                images[(code, color_idx)] = recolored

    return images
def generator(width: int, height: int, num_agents: int, num_resets: int = 0,
              np_random: RandomState = None) -> RailGenerator:
    """Build a rail grid from ``rail_spec``.

    ``rail_spec`` is not a parameter; it is read from the surrounding scope.
    It is indexed as ``rail_spec[row][col]`` and each cell is indexed as
    ``cell[0]`` (index into ``RailEnvTransitions().transitions``) and
    ``cell[1]`` (rotation passed to ``rotate_transition``).

    The ``width`` and ``height`` arguments are overwritten by the dimensions
    of ``rail_spec``; the remaining arguments are unused.

    Returns:
        ``[rail, None]`` on success, or ``[]`` after printing an error when a
        cell refers to a transition index that is out of range.
    """
    rail_env_transitions = RailEnvTransitions()

    # The specification, not the caller, decides the grid size.
    height = len(rail_spec)
    width = len(rail_spec[0])
    rail = GridTransitionMap(width=width, height=height, transitions=rail_env_transitions)

    for row in range(height):
        for col in range(width):
            cell_spec = rail_spec[row][col]
            type_index = cell_spec[0]
            rotation = cell_spec[1]

            if not (0 <= type_index < len(rail_env_transitions.transitions)):
                print("ERROR - invalid rail_spec_of_cell type=", type_index)
                return []

            base_transition = rail_env_transitions.transitions[type_index]
            rail.set_transitions(
                (row, col),
                rail_env_transitions.rotate_transition(base_transition, rotation))

    return [rail, None]
def test_get_entry_directions():
    """Check ``Grid4Transitions.get_entry_directions`` for turns, a straight and switches."""
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    vertical_line = cells[1]
    south_symmetrical_switch = cells[6]
    north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)

    south_east_turn = int('0100000000000010', 2)
    south_west_turn = transitions.rotate_transition(south_east_turn, 90)
    north_east_turn = transitions.rotate_transition(south_east_turn, 270)
    north_west_turn = transitions.rotate_transition(south_east_turn, 180)

    # (transition, expected entry directions), checked in this order.
    expectations = [
        (south_east_turn, [True, False, False, True]),
        (south_west_turn, [True, True, False, False]),
        (north_east_turn, [False, False, True, True]),
        (north_west_turn, [False, True, True, False]),
        (vertical_line, [True, False, True, False]),
        (south_symmetrical_switch, [True, True, False, True]),
        (north_symmetrical_switch, [False, True, True, True]),
    ]

    for transition, expected in expectations:
        actual = Grid4Transitions.get_entry_directions(transition)
        assert actual == expected, "Found {}, expected {}.".format(actual, expected)
def compute_all_possible_transitions():
    """Map every rotated transition bitmap to the cell type and rotation producing it.

    Returns:
        dict whose keys are the transitions obtained by rotating each entry of
        ``RailEnvTransitions().transition_list`` by 0, 90, 180 and 270 degrees,
        and whose values are ``np.array([type_index, rotation_degrees])`` for
        the first (type, rotation) pair that produced the key.
    """
    # Bitmaps are read in decimal numbers.
    #
    # Reference listing kept from the previous implementation:
    #   int('0000000000000000', 2)  empty cell - Case 0
    #   int('1000000000100000', 2)  Case 1 - straight
    #   int('1001001000100000', 2)  Case 2 - simple switch
    #   int('1000010000100001', 2)  Case 3 - diamond crossing
    #   int('1001011000100001', 2)  Case 4 - single slip
    #   int('1100110000110011', 2)  Case 5 - double slip
    #   int('0101001000000010', 2)  Case 6 - symmetrical
    #   int('0010000000000000', 2)  Case 7 - dead end
    #   int('0100000000000010', 2)  Case 1b (8) - simple turn right
    #   int('0001001000000000', 2)  Case 1c (9) - simple turn left
    #   int('1100000000100010', 2)  Case 2b (10) - simple switch mirrored
    transitions = RailEnvTransitions()

    lookup = {}
    for type_index, base_transition in enumerate(transitions.transition_list):
        for degrees in (0, 90, 180, 270):
            rotated = transitions.rotate_transition(base_transition, degrees)
            if rotated in lookup:
                # The first (type, rotation) pair wins.
                continue
            lookup[rotated] = np.array([type_index, degrees])

    return lookup
def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array]:
    """Build a small fixed rail network with 7 rows and 10 columns.

    Returns:
        ``(rail, rail_map)`` where ``rail_map`` is the ``uint16`` array of
        transitions and ``rail`` is a ``GridTransitionMap`` whose ``grid`` is
        that same array.
    """
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    def rotated(cell, degrees):
        return transitions.rotate_transition(cell, degrees)

    empty = cells[0]
    dead_end_from_south = cells[7]
    right_turn_from_south = cells[8]
    right_turn_from_west = rotated(right_turn_from_south, 90)
    right_turn_from_north = rotated(right_turn_from_south, 180)
    dead_end_from_west = rotated(dead_end_from_south, 90)  # not placed on the map
    dead_end_from_north = rotated(dead_end_from_south, 180)
    dead_end_from_east = rotated(dead_end_from_south, 270)
    vertical_straight = cells[1]
    simple_switch_north_left = cells[2]
    simple_switch_north_right = cells[10]
    simple_switch_left_east = rotated(simple_switch_north_left, 90)
    horizontal_straight = rotated(vertical_straight, 90)
    # The two double switches are computed but not placed on the map.
    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
    double_switch_north_horizontal_straight = rotated(
        double_switch_south_horizontal_straight, 180)

    # Row 0: top edge running from column 3 to column 9.
    top_row = ([empty] * 3 + [right_turn_from_south] + [horizontal_straight] * 5
               + [right_turn_from_west])
    # Rows 1-2: the two vertical sides at columns 3 and 9.
    side_row = [empty] * 3 + [vertical_straight] + [empty] * 5 + [vertical_straight]
    # Row 3: horizontal line from the left dead end, with a switch at column 3.
    switch_row = ([dead_end_from_east] + [horizontal_straight] * 2
                  + [simple_switch_left_east] + [horizontal_straight] * 2
                  + [right_turn_from_west] + [empty] * 2 + [vertical_straight])
    # Row 4: switch at column 6 joining the right-hand side.
    join_row = ([empty] * 6 + [simple_switch_north_right]
                + [horizontal_straight] * 2 + [right_turn_from_north])
    # Rows 5-6: stub going down from column 6 into a dead end.
    stub_row = [empty] * 6 + [vertical_straight] + [empty] * 3
    end_row = [empty] * 6 + [dead_end_from_north] + [empty] * 3

    rail_map = np.array(
        [top_row] + [side_row] * 2 + [switch_row] + [join_row] + [stub_row] + [end_row],
        dtype=np.uint16)

    height, width = rail_map.shape[0], rail_map.shape[1]
    rail = GridTransitionMap(width=width, height=height, transitions=transitions)
    rail.grid = rail_map
    return rail, rail_map
def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
    """Build a small fixed rail network with 7 rows and 10 columns.

    A horizontal line in row 3 runs between two dead ends. A branch goes up
    from column 3 to a dead end in row 0, and another goes down from column 6
    to a dead end in row 6.

    Returns:
        ``(rail, rail_map)`` where ``rail_map`` is the ``uint16`` array of
        transitions and ``rail`` is a ``GridTransitionMap`` whose ``grid`` is
        that same array.
    """
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    def rotated(cell, degrees):
        return transitions.rotate_transition(cell, degrees)

    empty = cells[0]
    dead_end_from_south = cells[7]
    dead_end_from_west = rotated(dead_end_from_south, 90)
    dead_end_from_north = rotated(dead_end_from_south, 180)
    dead_end_from_east = rotated(dead_end_from_south, 270)
    vertical_straight = cells[1]
    horizontal_straight = rotated(vertical_straight, 90)
    simple_switch_north_right = cells[10]
    simple_switch_east_west_north = rotated(simple_switch_north_right, 270)
    simple_switch_west_east_south = rotated(simple_switch_north_right, 90)

    upper_end_row = [empty] * 3 + [dead_end_from_south] + [empty] * 6
    upper_branch_row = [empty] * 3 + [vertical_straight] + [empty] * 6
    main_row = ([dead_end_from_east] + [horizontal_straight] * 2
                + [simple_switch_east_west_north] + [horizontal_straight] * 2
                + [simple_switch_west_east_south] + [horizontal_straight] * 2
                + [dead_end_from_west])
    lower_branch_row = [empty] * 6 + [vertical_straight] + [empty] * 3
    lower_end_row = [empty] * 6 + [dead_end_from_north] + [empty] * 3

    rail_map = np.array(
        [upper_end_row] + [upper_branch_row] * 2 + [main_row]
        + [lower_branch_row] * 2 + [lower_end_row],
        dtype=np.uint16)

    height, width = rail_map.shape[0], rail_map.shape[1]
    rail = GridTransitionMap(width=width, height=height, transitions=transitions)
    rail.grid = rail_map
    return rail, rail_map
def test_rail_env_has_deadend():
    """Check ``RailEnvTransitions.has_deadend`` against the four dead-end bitmaps.

    For every transition in ``transitions_all``, ``has_deadend`` must return
    true exactly when the transition is one of the four listed bitmaps.
    """
    deadends = {
        int(rw('0010 0000 0000 0000'), 2),
        int(rw('0000 0001 0000 0000'), 2),
        int(rw('0000 0000 1000 0000'), 2),
        int(rw('0000 0000 0000 0100'), 2),
    }
    ret = RailEnvTransitions()
    for t in ret.transitions_all:
        expected_has_deadend = t in deadends
        actual_had_deadend = ret.has_deadend(t)
        # The message has three placeholders and needs all three arguments;
        # with fewer, a failing check raises IndexError from str.format
        # instead of reporting the mismatch.
        assert actual_had_deadend == expected_has_deadend, \
            "{} should be deadend = {}, actual = {}".format(
                t, expected_has_deadend, actual_had_deadend)
def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct:
    """Create a zeroed grid and write one transition into cells (0, 0) and (0, 1).

    The transition is the result of ``set_transition(1, 1, 1, 1)``, and it is
    also printed. ``num_agents`` and ``num_resets`` are unused.

    Returns:
        ``(grid_map, None)``.
    """
    rail_trans = RailEnvTransitions()
    grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)

    cells = grid_map.grid
    cells.fill(0)

    transition = rail_trans.set_transition(1, 1, 1, 1)
    print(transition)
    for col in (0, 1):
        cells[0, col] = transition

    return grid_map, None
def compute_all_possible_transitions(self):
    """Map every rotated transition bitmap to the cell type and rotation step producing it.

    Each entry of ``RailEnvTransitions().transition_list`` is rotated by 0,
    90, 180 and 270 degrees.

    Returns:
        dict from rotated transition to ``np.array([type_index, rot_type])``,
        where ``rot_type`` is 0..3 for 0/90/180/270 degrees. Only the first
        (type, rotation) pair that yields a given bitmap is recorded.
    """
    # Bitmaps are read in decimal numbers.
    transitions = RailEnvTransitions()
    lookup = {}

    for type_index, base_transition in enumerate(transitions.transition_list):
        for rot_type, degrees in enumerate((0, 90, 180, 270)):
            rotated = transitions.rotate_transition(base_transition, degrees)
            if rotated in lookup:
                continue
            lookup[rotated] = np.array([type_index, rot_type])

    return lookup
def test_walker():
    """Check distance-map entries on a one-row map: dead end, straight, dead end."""
    # _ _ _
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    dead_end_from_south = cells[7]
    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
    vertical_straight = cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)

    single_row = [dead_end_from_east] + [horizontal_straight] + [dead_end_from_west]
    rail_map = np.array([single_row], dtype=np.uint16)
    n_rows, n_cols = rail_map.shape[0], rail_map.shape[1]

    rail = GridTransitionMap(width=n_cols, height=n_rows, transitions=transitions)
    rail.grid = rail_map

    env = RailEnv(
        width=n_cols,
        height=n_rows,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2,
            predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
    )
    env.reset()

    # Set initial position and direction for testing...
    agent = env.agents[0]
    agent.position = (0, 1)
    agent.direction = 1
    agent.target = (0, 0)

    # Reset to set agents from agents_static.
    env.reset(False, False)

    print(env.distance_map.get()[(0, *[0, 1], 1)])
    assert env.distance_map.get()[(0, *[0, 1], 1)] == 3
    # NOTE(review): this print reads direction 3 while the assertion below
    # reads direction 1 - confirm which was intended.
    print(env.distance_map.get()[(0, *[0, 2], 3)])
    assert env.distance_map.get()[(0, *[0, 2], 1)] == 2
def generator(width: int, height: int, num_agents: int, num_resets: int = 0,
              np_random: RandomState = None) -> RailGenerator:
    """Return an all-zero grid of the requested size.

    ``num_agents``, ``num_resets`` and ``np_random`` are unused.

    Returns:
        ``(grid_map, None)``.
    """
    grid_map = GridTransitionMap(width=width, height=height,
                                 transitions=RailEnvTransitions())
    grid_map.grid.fill(0)
    return grid_map, None
def test_dead_end():
    """Set up a dead-end-terminated line horizontally and vertically and place agents on it.

    The test contains no assertions; it only checks that environment
    construction, reset and agent assignment run without raising.
    """
    transitions = RailEnvTransitions()

    straight_vertical = int('1000000000100000', 2)  # Case 1 - straight
    straight_horizontal = transitions.rotate_transition(straight_vertical, 90)
    dead_end_from_south = int('0010000000000000', 2)  # Case 7 - dead end

    def build_env(rail_map):
        rail = GridTransitionMap(width=rail_map.shape[1],
                                 height=rail_map.shape[0],
                                 transitions=transitions)
        rail.grid = rail_map
        return RailEnv(width=rail_map.shape[1],
                       height=rail_map.shape[0],
                       rail_generator=rail_from_grid_transition_map(rail),
                       schedule_generator=random_schedule_generator(),
                       number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())

    def place_agents(rail_env, placements):
        for position, direction, target in placements:
            rail_env.reset()
            rail_env.agents = [EnvAgent(initial_position=position,
                                        initial_direction=direction,
                                        direction=direction,
                                        target=target,
                                        moving=False)]

    # We instantiate the following railway
    # O->-- where > is the train and O the target. After 6 steps,
    # the train should be done.
    horizontal_map = np.array(
        [[transitions.rotate_transition(dead_end_from_south, 270)]
         + [straight_horizontal] * 3
         + [transitions.rotate_transition(dead_end_from_south, 90)]],
        dtype=np.uint16)

    # We try the configuration in the 4 directions:
    place_agents(build_env(horizontal_map), [
        ((0, 2), 1, (0, 0)),
        ((0, 2), 3, (0, 4)),
    ])

    # In the vertical configuration:
    vertical_map = np.array(
        [[dead_end_from_south]]
        + [[straight_vertical]] * 3
        + [[transitions.rotate_transition(dead_end_from_south, 180)]],
        dtype=np.uint16)

    place_agents(build_env(vertical_map), [
        ((2, 0), 2, (0, 0)),
        ((2, 0), 0, (4, 0)),
    ])
def test_is_valid_railenv_transitions():
    """Check ``is_valid`` on the basic transitions, their rotations and three invalid bitmaps."""
    rail_env_trans = RailEnvTransitions()

    for base in rail_env_trans.transitions:
        assert (rail_env_trans.is_valid(base) is True)
        # NOTE(review): only the 0, 90 and 180 degree rotations are checked;
        # 270 degrees is not covered - confirm whether that is intended.
        for quarter_turns in range(3):
            rotated = rail_env_trans.rotate_transition(base, 90 * quarter_turns)
            assert (rail_env_trans.is_valid(rotated) is True)

    invalid_bitmaps = ('1111111111110010', '1001111111110010', '1001111001110110')
    for bits in invalid_bitmaps:
        assert (rail_env_trans.is_valid(int(bits, 2)) is False)
def generator(width: int, height: int, num_agents: int, num_resets: int = 0,
              np_random: RandomState = None) -> List:
    """Load a rail grid, and its distance map when present, from a saved environment.

    ``filename`` and ``load_from_package`` are not parameters; they are read
    from the surrounding scope. All declared parameters are unused, and the
    grid size comes from the loaded data.

    Returns:
        ``(rail, {'distance_map': ...})`` when the loaded dict has a non-empty
        ``"distance_map"`` entry, otherwise ``[rail, None]``.
    """
    env_dict = persistence.RailEnvPersister.load_env_dict(
        filename, load_from_package=load_from_package)
    rail_env_transitions = RailEnvTransitions()

    grid = np.array(env_dict["grid"])
    grid_shape = np.shape(grid)
    rail = GridTransitionMap(width=grid_shape[1], height=grid_shape[0],
                             transitions=rail_env_transitions)
    rail.grid = grid

    if "distance_map" not in env_dict:
        return [rail, None]

    distance_map = env_dict["distance_map"]
    if len(distance_map) > 0:
        return rail, {'distance_map': distance_map}
    return [rail, None]
def test_adding_new_valid_transition():
    """Check ``validate_new_transition`` on straights, turns, occupied cells and path ends."""
    rail_trans = RailEnvTransitions()
    grid_map = GridTransitionMap(width=15, height=15, transitions=rail_trans)

    # Each case: (index into rail_trans.transitions to write into cell (5, 5)
    # beforehand, or None to leave the cell untouched; arguments for
    # validate_new_transition; expected result). Cases run in this order.
    cases = [
        # adding straight
        (None, ((4, 5), (5, 5), (6, 5), (10, 10)), True),
        # adding valid right turn
        (None, ((5, 4), (5, 5), (5, 6), (10, 10)), True),
        # adding valid left turn
        (None, ((5, 6), (5, 5), (5, 6), (10, 10)), True),
        # adding invalid turn
        (2, ((4, 5), (5, 5), (5, 6), (10, 10)), False),
        # should create #4 -> valid
        (3, ((4, 5), (5, 5), (5, 6), (10, 10)), True),
        # adding invalid turn
        (7, ((4, 5), (5, 5), (5, 6), (10, 10)), False),
        # test path start condition
        (0, (None, (5, 5), (5, 6), (10, 10)), True),
        # test path end condition
        (0, ((5, 4), (5, 5), (6, 5), (6, 5)), True),
    ]

    for preset_index, call_args, expected in cases:
        if preset_index is not None:
            grid_map.grid[(5, 5)] = rail_trans.transitions[preset_index]
        assert (grid_map.validate_new_transition(*call_args) is expected)
def generator(width: int, height: int, num_agents: int, num_resets: int = 0,
              np_random: RandomState = np.random) -> RailGenerator:
    """Generate a rail grid made of evenly distributed, connected cities.

    ``seed`` and ``num_cities`` are not parameters; they are read from the
    surrounding scope. ``np_random`` is re-seeded with ``seed`` on every call.
    With the default argument that is the ``numpy.random`` module itself.

    Parameters
    ----------
    width: int
        Width of the environment
    height: int
        Height of the environment
    num_agents:
        Number of agents to be placed within the environment; only passed
        through into the returned hints
    num_resets: int
        Count for how often the environment has been reset (unused here)
    np_random:
        Random source that is seeded and handed to the connection-point
        generation

    Returns
    -------
    grid_map, hints:
        The populated grid map and a dict ``{'agents_hints': {...}}`` holding
        ``num_agents``, ``city_positions``, ``train_stations`` and
        ``city_orientations``.

    The process exits through ``sys.exit`` when fewer than ``num_cities``
    cities fit on the map.
    """
    rail_trans = RailEnvTransitions()
    grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
    vector_field = np.full(shape=(height, width), fill_value=-1.)

    city_padding = 2
    city_radius = city_padding
    rails_between_cities = 1
    rails_in_city = 2
    np_random.seed(seed)

    # Calculate the max number of cities allowed
    # and reduce the number of cities to build to avoid problems
    city_footprint = 2 * (city_radius + 1)
    cities_that_fit = ((height - 2) // city_footprint) * ((width - 2) // city_footprint)
    max_feasible_cities = min(num_cities, cities_that_fit)
    if max_feasible_cities < num_cities:
        sys.exit(
            f"[ABORT] Cannot fit more than {max_feasible_cities} city in this map, no feasible environment possible! Aborting."
        )

    # obtain city positions
    city_positions = _generate_evenly_distr_city_positions(
        max_feasible_cities, city_radius, width, height)

    # Set up connection points for all cities
    connection_data = _generate_city_connection_points(
        city_positions, city_radius, vector_field, rails_between_cities,
        rails_in_city, np_random=np_random)
    inner_connection_points, outer_connection_points, city_orientations, city_cells = connection_data

    # connect the cities through the connection points
    inter_city_lines = _connect_cities(
        city_positions, outer_connection_points, city_cells, rail_trans, grid_map)

    # Build inner cities
    free_rails = _build_inner_cities(
        city_positions, inner_connection_points, outer_connection_points,
        rail_trans, grid_map)

    # Populate cities
    train_stations = _set_trainstation_positions(city_positions, city_radius, free_rails)

    # Fix all transition elements
    _fix_transitions(city_cells, inter_city_lines, grid_map, vector_field)

    agents_hints = {
        'num_agents': num_agents,
        'city_positions': city_positions,
        'train_stations': train_stations,
        'city_orientations': city_orientations
    }
    return grid_map, {'agents_hints': agents_hints}
def __init__(self):
    """Load the rail SVG files and index them by 16-bit transition code.

    ``self.dSvg`` maps each transition code, plus its 90/180/270 degree
    rotations, to an SVG object. Every image with a non-zero transition code
    is first merged with the background SVG.
    """
    svg_files = {
        "": "Background_#9CCB89.svg",
        "WE": "Gleis_Deadend.svg",
        "WW EE NN SS": "Gleis_Diamond_Crossing.svg",
        "WW EE": "Gleis_horizontal.svg",
        "EN SW": "Gleis_Kurve_oben_links.svg",
        "WN SE": "Gleis_Kurve_oben_rechts.svg",
        "ES NW": "Gleis_Kurve_unten_links.svg",
        "NE WS": "Gleis_Kurve_unten_rechts.svg",
        "NN SS": "Gleis_vertikal.svg",
        "NN SS EE WW ES NW SE WN": "Weiche_Double_Slip.svg",
        "EE WW EN SW": "Weiche_horizontal_oben_links.svg",
        "EE WW SE WN": "Weiche_horizontal_oben_rechts.svg",
        "EE WW ES NW": "Weiche_horizontal_unten_links.svg",
        "EE WW NE WS": "Weiche_horizontal_unten_rechts.svg",
        "NN SS EE WW NW ES": "Weiche_Single_Slip.svg",
        "NE NW ES WS": "Weiche_Symetrical.svg",
        "NN SS EN SW": "Weiche_vertikal_oben_links.svg",
        "NN SS SE WN": "Weiche_vertikal_oben_rechts.svg",
        "NN SS NW ES": "Weiche_vertikal_unten_links.svg",
        "NN SS NE WS": "Weiche_vertikal_unten_rechts.svg"
    }

    self.dSvg = {}

    rail_transitions = RailEnvTransitions()
    compass = list("NESW")
    background = SVG("./svg/Background_#9CCB89.svg")

    for description, file_name in svg_files.items():
        image = SVG("./svg/" + file_name)

        # Translate the ascii transition description in the format "NE WS"
        # to the binary list of transitions as per RailEnv -
        # NESW (in) x NESW (out).
        bits = ["0"] * 16
        for token in description.split(" "):
            if len(token) != 2:
                continue
            bits[4 * compass.index(token[0]) + compass.index(token[1])] = "1"
        bit_string = "".join(bits)
        code = int(bit_string, 2)
        print(description, bit_string, file_name)

        # Merge the transition svg image with the background colour.
        # This is a shortcut / hack and will need re-working.
        if code > 0:
            image = image.merge(background)

        self.dSvg[code] = image

        # Rotate both the transition binary and the image and save in the dict.
        for degrees in (90, 180, 270):
            rotated_code = rail_transitions.rotate_transition(code, degrees)
            rotated_image = image.copy()
            rotated_image.set_rotate(degrees)
            self.dSvg[rotated_code] = rotated_image
def test_rail_environment_single_agent():
    """Drive one agent with random actions around a closed 3x3 double loop."""
    # We instantiate the following map on a 3x3 grid
    #  _  _
    # / \/ \
    # | |  |
    # \_/\_/
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    vertical_line = cells[1]
    south_symmetrical_switch = cells[6]
    north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
    south_east_turn = int('0100000000000010', 2)
    south_west_turn = transitions.rotate_transition(south_east_turn, 90)
    north_east_turn = transitions.rotate_transition(south_east_turn, 270)
    north_west_turn = transitions.rotate_transition(south_east_turn, 180)

    top_row = [south_east_turn, south_symmetrical_switch, south_west_turn]
    middle_row = [vertical_line, vertical_line, vertical_line]
    bottom_row = [north_east_turn, north_symmetrical_switch, north_west_turn]
    rail_map = np.array([top_row, middle_row, bottom_row], dtype=np.uint16)

    rail = GridTransitionMap(width=3, height=3, transitions=transitions)
    rail.grid = rail_map
    rail_env = RailEnv(width=3,
                       height=3,
                       rail_generator=rail_from_grid_transition_map(rail),
                       schedule_generator=random_schedule_generator(),
                       number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())

    for _ in range(200):
        _ = rail_env.reset(False, False, True)

        # We do not care about target for the moment
        agent = rail_env.agents[0]
        agent.target = [-1, -1]

        # Check that trains are always initialized at a consistent position
        # or direction.
        # They should always be able to go somewhere.
        start_transitions = transitions.get_transitions(rail_map[agent.position],
                                                        agent.direction)
        assert (start_transitions != (0, 0, 0, 0))

        initial_pos = agent.position
        pos = initial_pos
        valid_active_actions_done = 0
        while valid_active_actions_done < 6:
            # We randomly select an action
            action = np.random.randint(4)
            _, _, _, _ = rail_env.step({0: action})

            prev_pos, pos = pos, agent.position
            if prev_pos != pos:
                valid_active_actions_done += 1

        # After 6 movements on this railway network, the train should be back
        # to its original height on the map.
        assert (initial_pos[0] == agent.position[0])

        # We check that the train always attains its target after some time
        for _ in range(10):
            _ = rail_env.reset()

            while True:
                # We randomly select an action
                action = np.random.randint(4)
                _, _, dones, _ = rail_env.step({0: action})
                if dones['__all__']:
                    break
def test_rail_environment_single_agent(show=False):
    """Drive one agent with random actions in an environment loaded from a pickle.

    The environment comes from ``"test_env_loop.pkl"`` in ``"env_data.tests"``;
    the hand-built 3x3 map is kept in a disabled ``if False:`` branch.

    Args:
        show: when true, print agent state and render the environment after
            each step of the first phase.
    """
    # We instantiate the following map on a 3x3 grid
    #  _  _
    # / \/ \
    # | |  |
    # \_/\_/
    transitions = RailEnvTransitions()

    # Dead branch: the hand-built environment is disabled.
    if False:
        # This env creation doesn't quite work right.
        cells = transitions.transition_list
        vertical_line = cells[1]
        south_symmetrical_switch = cells[6]
        north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
        south_east_turn = int('0100000000000010', 2)
        south_west_turn = transitions.rotate_transition(south_east_turn, 90)
        north_east_turn = transitions.rotate_transition(south_east_turn, 270)
        north_west_turn = transitions.rotate_transition(south_east_turn, 180)

        rail_map = np.array([[south_east_turn, south_symmetrical_switch, south_west_turn],
                            [vertical_line, vertical_line, vertical_line],
                            [north_east_turn, north_symmetrical_switch, north_west_turn]],
                            dtype=np.uint16)

        rail = GridTransitionMap(width=3, height=3, transitions=transitions)
        rail.grid = rail_map
        rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_grid_transition_map(rail),
                        schedule_generator=random_schedule_generator(), number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
    else:
        rail_env, env_dict = RailEnvPersister.load_new("test_env_loop.pkl", "env_data.tests")
        rail_map = rail_env.rail.grid

    rail_env._max_episode_steps = 1000

    _ = rail_env.reset(False, False, True)

    # Integer values of every RailEnvActions member, used for random sampling.
    liActions = [int(a) for a in RailEnvActions]

    env_renderer = RenderTool(rail_env)

    # Disabled: RailEnvPersister.save(rail_env, "test_env_figure8.pkl")

    for _ in range(5):
        # Disabled: rail_env.agents[0].initial_position = (1,2)
        _ = rail_env.reset(False, False, True)

        # We do not care about target for the moment
        agent = rail_env.agents[0]
        agent.target = [-1, -1]

        # Check that trains are always initialized at a consistent position
        # or direction.
        # They should always be able to go somewhere.
        if show:
            print("After reset - agent pos:", agent.position, "dir: ", agent.direction)
            print(transitions.get_transitions(rail_map[agent.position], agent.direction))

        # Disabled check:
        # assert (transitions.get_transitions(
        #    rail_map[agent.position],
        #    agent.direction) != (0, 0, 0, 0))

        # HACK - force the direction to one we know is good.
        # Disabled: agent.initial_position = agent.position = (2,3)
        agent.initial_direction = agent.direction = 0

        if show:
            print ("handle:", agent.handle)
        # Disabled: agent.initial_position = initial_pos = agent.position

        valid_active_actions_done = 0
        pos = agent.position

        if show:
            env_renderer.render_env(show=show, show_agents=True)
            time.sleep(0.01)

        # Phase 1: step with random actions until the agent has changed cell
        # six times, failing if that takes 100 steps or more.
        iStep = 0
        while valid_active_actions_done < 6:
            # We randomly select an action
            action = np.random.choice(liActions)
            # Disabled: action = RailEnvActions.MOVE_FORWARD

            _, _, dict_done, _ = rail_env.step({0: action})

            prev_pos = pos
            pos = agent.position  # rail_env.agents_position[0]

            print("action:", action, "pos:", agent.position, "prev:", prev_pos, agent.direction)
            print(dict_done)
            # Only steps that actually moved the agent are counted.
            if prev_pos != pos:
                valid_active_actions_done += 1
            iStep += 1

            if show:
                env_renderer.render_env(show=show, show_agents=True, step=iStep)
                time.sleep(0.01)
            assert iStep < 100, "valid actions should have been performed by now - hung agent"

        # After 6 movements on this railway network, the train should be back
        # to its original height on the map.
        # Disabled check: assert (initial_pos[0] == agent.position[0])

        # We check that the train always attains its target after some time
        for _ in range(10):
            _ = rail_env.reset()

            rail_env.agents[0].direction = 0

            # JW - to avoid problem with random_schedule_generator.
            # Disabled: rail_env.agents[0].position = (1,2)

            # Phase 2: step with random actions until the environment reports
            # '__all__' done, failing once 100 steps have been taken.
            iStep = 0
            while iStep < 100:
                # We randomly select an action
                action = np.random.choice(liActions)

                _, _, dones, _ = rail_env.step({0: action})
                done = dones['__all__']
                if done:
                    break
                iStep +=1
                assert iStep < 100, "agent should have finished by now"
                env_renderer.render_env(show=show)
def test_build_railway_infrastructure():
    """Connect four rail segments on a 20x20 grid and check paths and the resulting grid.

    All four connections are built first; the returned paths and the grid
    contents are compared with the expected values afterwards.
    """
    rail_trans = RailEnvTransitions()
    grid_map = GridTransitionMap(width=20, height=20, transitions=rail_trans)
    grid_map.grid.fill(0)

    # (start, end, flip_start_node_trans, flip_end_node_trans, expected path)
    scenarios = [
        # Make connection with dead-ends on both sides
        ((2, 2), (8, 8), True, True,
         [(2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (3, 8),
          (4, 8), (5, 8), (6, 8), (7, 8), (8, 8)]),
        # Make connection with open ends on both sides
        ((1, 3), (1, 7), False, False,
         [(1, 3), (1, 4), (1, 5), (1, 6), (1, 7)]),
        # Make connection with open end at beginning and dead end on end
        ((6, 2), (6, 5), False, True,
         [(6, 2), (6, 3), (6, 4), (6, 5)]),
        # Make connection with dead end on start and open end
        ((7, 5), (8, 9), True, False,
         [(7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (8, 9)]),
    ]

    built_paths = []
    for start_point, end_point, flip_start, flip_end, _ in scenarios:
        built_paths.append(
            connect_rail_in_grid_map(grid_map,
                                     start_point,
                                     end_point,
                                     rail_trans,
                                     flip_start_node_trans=flip_start,
                                     flip_end_node_trans=flip_end,
                                     respect_transition_validity=True,
                                     forbidden_cells=None))

    for actual_path, scenario in zip(built_paths, scenarios):
        expected_path = scenario[4]
        assert actual_path == expected_path, \
            "actual={}, expected={}".format(actual_path, expected_path)

    blank_row = [0] * 20
    grid_map_grid_expected = [
        list(blank_row),
        [0, 0, 0, 0, 1025, 1025, 1025, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 4, 1025, 1025, 1025, 1025, 1025, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1025, 1025, 256, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 4, 1025, 1025, 33825, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ]
    # Rows 9-19 are entirely empty.
    grid_map_grid_expected += [list(blank_row) for _ in range(11)]

    for row_index, expected_row in enumerate(grid_map_grid_expected):
        assert np.all(grid_map.grid[row_index] == expected_row)
def test_fix_inner_nodes():
    """Lay three parallel vertical tracks, join their ends and check the fixed grid.

    The grid has 10 rows and 6 columns. After ``fix_inner_nodes`` is applied
    to every track end and ``fix_transitions`` to every cell, each cell except
    (0, 0) is compared with its expected transition value.
    """
    rail_trans = RailEnvTransitions()
    grid_map = GridTransitionMap(width=6, height=10, transitions=rail_trans)
    grid_map.grid.fill(0)

    # (start, target) of each straight track, outermost first.
    track_ends = [
        ((2, 2), (8, 2)),
        ((3, 3), (7, 3)),
        ((4, 4), (6, 4)),
    ]
    inner_nodes = [node for ends in track_ends for node in ends]

    for track_start, track_target in track_ends:
        connect_straight_line_in_grid_map(grid_map, track_start, track_target, rail_trans)

    # Fix the ends of the inner node
    # This is not a fix in transition type but rather makes the necessary
    # connections to the parallel tracks
    for node in inner_nodes:
        fix_inner_nodes(grid_map, node, rail_trans)

    def orientation_for(pos):
        # Cells in the upper half get orientation 2, the others 0.
        return 2 if pos[0] < grid_map.grid.shape[0] / 2 else 0

    # Fix all the different transitions to legal elements
    n_rows, n_cols = grid_map.grid.shape[0], grid_map.grid.shape[1]
    for c in range(n_cols):
        for r in range(n_rows):
            grid_map.fix_transitions((r, c), orientation_for((r, c)))

    # Expected value per column, listed by row 0..9. None marks the single
    # cell, (0, 0), that is not checked.
    expected_by_column = [
        [None, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 8192, 49186, 32800, 32800, 32800, 32872, 128, 0],
        [0, 0, 0, 4608, 49186, 32800, 32872, 2064, 0, 0],
        [0, 0, 0, 0, 4608, 32800, 2064, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ]
    for c, expected_column in enumerate(expected_by_column):
        for r, expected in enumerate(expected_column):
            if expected is None:
                continue
            assert grid_map.grid[(r, c)] == expected
def generator(*args) -> RailGenerator:
    """
    Build a rail map whose size and content are taken from the curriculum.

    Positional arguments are accepted but ignored, except for the last one
    (``args[-1]``): when it offers a ``randint`` method (as a numpy
    ``RandomState`` does) it is used to draw the start/goal cells, otherwise
    the global ``random`` module is used.

    Returns
    -------
    tuple
        The filled ``GridTransitionMap`` and a hints dictionary of the form
        ``{'agents_hints': {'start_goal': [...], 'start_dir': [...]}}``.

    Raises
    ------
    Exception
        If the curriculum asks for more agents than cities.
    """
    np_random = args[-1] if args else None
    use_np_random = hasattr(np_random, "randint")

    def pick_cell(candidates):
        """Draw one cell uniformly at random from a non-empty sequence."""
        if use_np_random:
            return candidates[np_random.randint(0, len(candidates))]
        # random.sample() needs a sequence: sets raise TypeError since Python 3.11
        return random.sample(candidates, 1)[0]

    def connect(start, goal):
        """Try to lay a rail between two cells and return the resulting path."""
        return connect_rail_in_grid_map(
            grid_map, start, goal, rail_trans,
            Vec2d.get_chebyshev_distance,
            flip_start_node_trans=True,
            flip_end_node_trans=True,
            respect_transition_validity=True,
            forbidden_cells=None)

    def check_all_dist(sg_new):
        """
        Check the distance between a new start/goal pair and the existing ones.

        :param sg_new: candidate (start, goal) pair
        :return: False if any candidate end point is closer than 2 to an
                 already accepted start or goal, True otherwise
        """
        for sg in start_goal:
            for i in range(2):
                for j in range(2):
                    if distance_on_rail(sg_new[i], sg[j]) < 2:
                        return False
        return True

    if curriculum.get("n_agents") > curriculum.get("n_cities"):
        raise Exception("complex_rail_generator: n_agents > n_cities!")

    # NOTE(review): the width key is "x_dim" but the height key is "y_size";
    # confirm that this asymmetry matches the curriculum schema.
    grid_map = GridTransitionMap(width=curriculum.get("x_dim"),
                                 height=curriculum.get("y_size"),
                                 transitions=RailEnvTransitions())
    rail_array = grid_map.grid
    rail_array.fill(0)

    # generate rail array
    # step 1:
    # - generate a start and goal position
    #   - validate min/max distance allowed
    #   - validate that start/goals are not placed too close to other start/goals
    #   - draw a rail from [start,goal]
    #     - if rail crosses existing rail then validate new connection
    #     - possibility that this fails to create a path to goal
    #     - on failure generate new start/goal
    #
    # step 2:
    # - add more rails to map randomly between the remaining candidate cells
    #   - validate all new rails, on failure don't add new rails
    #
    # step 3:
    # - return transition map + hints with the start/goal pairs and start directions
    #
    rail_trans = grid_map.transitions
    start_goal = []
    start_dir = []
    nr_created = 0
    created_sanity = 0
    sanity_max = 9000

    free_cells = {(r, c)
                  for r, row in enumerate(rail_array)
                  for c, col in enumerate(row)
                  if col == 0}

    while nr_created < curriculum.get("n_cities") and created_sanity < sanity_max:
        # two distinct end points are needed for a connection
        if len(free_cells) < 2:
            break

        # free_cells only changes when a pair is accepted, so one snapshot
        # per search is enough
        candidates = tuple(free_cells)
        all_ok = False
        for _ in range(sanity_max):
            start = pick_cell(candidates)
            goal = pick_cell(candidates)
            if start == goal:
                # a cell cannot be removed from free_cells twice
                continue

            # check min/max distance
            dist_sg = distance_on_rail(start, goal)
            if dist_sg < curriculum.get("min_dist"):
                continue
            if dist_sg > curriculum.get("max_dist"):
                continue

            # check distance to existing points
            if check_all_dist([start, goal]):
                all_ok = True
                free_cells.remove(start)
                free_cells.remove(goal)
                break

        if not all_ok:
            # we might as well give up at this point
            break

        new_path = connect(start, goal)
        if len(new_path) >= 2:
            nr_created += 1
            start_goal.append([start, goal])
            start_dir.append(mirror(get_direction(new_path[0], new_path[1])))
        else:
            # after too many failures we will give up
            created_sanity += 1

    # add extra connections
    # NOTE(review): end points are drawn from free_cells, which is not
    # updated with the cells covered by the paths laid above; confirm that
    # this is the intended candidate set for "extra" rails.
    created_sanity = 0
    nr_created = 0
    candidates = tuple(free_cells)
    while candidates and nr_created < curriculum.get("n_extra") and created_sanity < sanity_max:
        # One pair is drawn per attempt so that the limits above are
        # re-checked after every connection attempt.
        start = pick_cell(candidates)
        goal = pick_cell(candidates)
        new_path = connect(start, goal)
        if len(new_path) >= 2:
            nr_created += 1
        else:
            # after too many failures we will give up
            created_sanity += 1

    return grid_map, {
        'agents_hints': {
            'start_goal': start_goal,
            'start_dir': start_dir
        }
    }
def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1):
    """
    Rebuild the transitions of the cell at ``rcPos`` from its neighbours.

    The four adjacent cells are inspected and every neighbour that has a
    transition leading into ``rcPos`` is counted as an incoming connection.
    Depending on that count the cell is rewritten as a dead end (1, only if
    the cell was not empty), a connection of the two directions (2), a
    simple switch (3) or a double slip (4).  With no incoming connection the
    cell is left untouched.

    :param rcPos: (row, column) of the cell to fix
    :param direction: used only when three neighbours connect, to select
        which simple switch is placed; with a negative value (the default)
        the switch is chosen by ``self.random_generator``
    :return: always True
    """
    neighbour_offsets = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
    cell_pos = array(rcPos)
    grid_shape = self.grid.shape

    # Transition elements
    rail_transitions = RailEnvTransitions()
    basic_cells = rail_transitions.transition_list
    simple_switch_east_south = rail_transitions.rotate_transition(basic_cells[10], 90)
    simple_switch_west_south = rail_transitions.rotate_transition(basic_cells[2], 270)
    double_slip = basic_cells[5]
    three_way_transitions = [simple_switch_east_south, simple_switch_west_south]

    # Mark every direction in which the adjacent cell connects to rcPos
    incoming_connections = np.zeros(4)
    for out_dir in np.arange(4):
        neighbour = cell_pos + neighbour_offsets[out_dir]  # next cell in that direction

        # Neighbours outside the grid do not count as incoming connections
        if np.any(neighbour < 0) or np.any(neighbour >= grid_shape):
            continue

        # Sum, over all orientations of the neighbour, the transitions that
        # leave it towards rcPos
        towards_cell = mirror(out_dir)
        connected = sum(
            self.get_transition((neighbour[0], neighbour[1], orientation), towards_cell)
            for orientation in range(4))
        if connected > 0:
            incoming_connections[out_dir] = 1

    number_of_incoming = np.sum(incoming_connections)

    if number_of_incoming == 1:
        # Only one incoming direction --> dead end, but only where the cell
        # already carried some transition; an empty cell stays empty
        was_empty = self.get_full_transitions(*rcPos) == 0
        self.set_transitions(rcPos, 0)
        if not was_empty:
            for incoming_dir in range(4):
                if incoming_connections[incoming_dir] > 0:
                    self.set_transition((rcPos[0], rcPos[1], mirror(incoming_dir)), incoming_dir, 1)

    elif number_of_incoming == 2:
        # Connect the two incoming directions with each other
        self.set_transitions(rcPos, 0)
        first_dir, second_dir = np.argwhere(incoming_connections > 0)
        self.set_transition((rcPos[0], rcPos[1], mirror(first_dir)), second_dir, 1)
        self.set_transition((rcPos[0], rcPos[1], mirror(second_dir)), first_dir, 1)

    elif number_of_incoming == 3:
        # Find a feasible switch for three entries
        self.set_transitions(rcPos, 0)
        hole = np.argwhere(incoming_connections < 1)[0][0]
        # -1 means "no preference": the switch is then drawn at random
        switch_type_idx = (direction - hole + 3) % 4 if direction >= 0 else -1
        if switch_type_idx == 0:
            transition = simple_switch_west_south
        elif switch_type_idx == 2:
            transition = simple_switch_east_south
        else:
            transition = self.random_generator.choice(three_way_transitions, 1)
        transition = rail_transitions.rotate_transition(transition, int(hole * 90))
        self.set_transitions((rcPos[0], rcPos[1]), transition)

    elif number_of_incoming == 4:
        # Make a double slip switch, in one of its two rotations
        rotation = self.random_generator.randint(2)
        transition = rail_transitions.rotate_transition(double_slip, int(rotation * 90))
        self.set_transitions((rcPos[0], rcPos[1]), transition)

    return True
def test_valid_railenv_transitions():
    """
    Exercise reading, setting and rotating transitions on ``RailEnvTransitions``.

    Directions are numbered N=0, E=1, S=2, W=3.
    """
    transitions = RailEnvTransitions()
    NORTH, EAST, SOUTH, WEST = range(4)

    double_slip = int('1100110000110011', 2)
    empty_cell = int('0000000000000000', 2)

    # The double slip allows N/E exits for headings N and E, and S/W exits
    # for headings S and W.
    for offset in (0, 1):
        assert transitions.get_transitions(double_slip, offset) == (1, 1, 0, 0)
        assert transitions.get_transitions(double_slip, 2 + offset) == (0, 0, 1, 1)

    # An empty cell has no transition for any heading.
    for heading in range(4):
        assert transitions.get_transitions(empty_cell, heading) == (0, 0, 0, 0)

    # Facing south, going south
    north_south = transitions.set_transitions(empty_cell, SOUTH, (0, 0, 1, 0))
    assert transitions.set_transition(north_south, SOUTH, SOUTH, 0) == empty_cell
    assert transitions.get_transition(north_south, SOUTH, SOUTH)

    # Facing north, going east
    south_east = transitions.set_transition(empty_cell, NORTH, EAST, 1)
    assert transitions.get_transition(south_east, NORTH, EAST)

    # The opposite transitions are not feasible
    assert not transitions.get_transition(north_south, SOUTH, NORTH)
    assert not transitions.get_transition(south_east, SOUTH, EAST)

    east_west = transitions.rotate_transition(north_south, 90)
    north_west = transitions.rotate_transition(south_east, 180)

    # Facing west, going west
    assert transitions.get_transition(east_west, WEST, WEST)
    # Facing south, going west
    assert transitions.get_transition(north_west, SOUTH, WEST)

    # A full turn leaves the transition unchanged
    assert south_east == transitions.rotate_transition(south_east, 360)
def generator(width: int, height: int, num_agents: int, num_resets: int = 0,
              np_random: RandomState = None) -> RailGenerator:
    """
    Fill a ``width`` x ``height`` map with randomly chosen rail cells.

    Interior cells are filled one at a time, in random order, with a
    transition whose exits agree with the neighbours that are already
    placed; candidates are weighted by ``cell_type_relative_proportion``.
    When no transition fits, a dead-end is tried, then one neighbour is
    removed and re-queued, and as a last resort the cell is left empty.
    Finally the border is padded with dead-ends or empty cells.

    ``num_agents`` and ``num_resets`` are not used in this function.
    ``np_random`` is used for every random draw, so the default ``None``
    cannot actually be used.

    Returns
    -------
    The resulting ``GridTransitionMap`` and ``None``.
    """
    t_utils = RailEnvTransitions()

    transition_probability = cell_type_relative_proportion

    # Parallel lists: (template, transition) pairs and the weight of each pair
    transitions_templates_ = []
    transition_probabilities = []
    for i in range(len(t_utils.transitions)):
        # don't include dead-ends
        if t_utils.transitions[i] == int('0010000000000000', 2):
            continue

        # OR the four-entry transition tuples of all headings into a 4-bit
        # mask (tuple entry 0 becomes the most significant bit)
        all_transitions = 0
        for dir_ in range(4):
            trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
            all_transitions |= (trans[0] << 3) | \
                               (trans[1] << 2) | \
                               (trans[2] << 1) | \
                               (trans[3])

        # Expand the mask into a list of four 0/1 entries, left-padded with zeros
        template = [int(x) for x in bin(all_transitions)[2:]]
        template = [0] * (4 - len(template)) + template

        # add all rotations
        for rot in [0, 90, 180, 270]:
            transitions_templates_.append(
                (template, t_utils.rotate_transition(t_utils.transitions[i], rot)))
            transition_probabilities.append(transition_probability[i])
            # shift the template by one position for the next 90 degree rotation
            template = [template[-1]] + template[:-1]

    def get_matching_templates(template):
        """
        Return the transitions whose template matches the given one.

        Parameters
        ----------
        template : List[int]
            Four entries; a negative entry matches anything, any other
            entry must equal the stored template entry at that position.

        Returns
        -------
        List of ``(transition, probability)`` tuples.
        """
        ret = []
        for i in range(len(transitions_templates_)):
            is_match = True
            for j in range(4):
                if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]:
                    is_match = False
                    break
            if is_match:
                ret.append((transitions_templates_[i][1], transition_probabilities[i]))
        return ret

    MAX_INSERTIONS = (width - 2) * (height - 2) * 10
    MAX_ATTEMPTS_FROM_SCRATCH = 10

    attempt_number = 0
    while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
        # Start from an empty map; only the interior cells are queued for filling
        cells_to_fill = []
        rail = []
        for r in range(height):
            rail.append([None] * width)
            if r > 0 and r < height - 1:
                cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)]

        num_insertions = 0
        while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
            # Pick and dequeue a random cell that still has to be filled
            cell = cells_to_fill[np_random.choice(len(cells_to_fill), 1)[0]]
            cells_to_fill.remove(cell)
            row = cell[0]
            col = cell[1]

            # look at its neighbors and see what are the possible transitions
            # that can be chosen from, if any.
            # -1 = neighbour not placed yet (no constraint), 1 = neighbour
            # has a transition towards this cell, 0 = it has none
            valid_template = [-1, -1, -1, -1]

            # each entry: (template index, direction queried on the
            # neighbour, (row, col) offset of the neighbour)
            for el in [(0, 2, (-1, 0)),
                       (1, 3, (0, 1)),
                       (2, 0, (1, 0)),
                       (3, 1, (0, -1))]:  # N, E, S, W
                neigh_trans = rail[row + el[2][0]][col + el[2][1]]
                if neigh_trans is not None:
                    # select transition coming from facing direction el[1] and
                    # moving to direction el[1]
                    max_bit = 0
                    for k in range(4):
                        max_bit |= t_utils.get_transition(neigh_trans, k, el[1])

                    if max_bit:
                        valid_template[el[0]] = 1
                    else:
                        valid_template[el[0]] = 0

            possible_cell_transitions = get_matching_templates(valid_template)

            if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
                # no cell can be filled in without violating some transitions
                # can a dead-end solve the problem?
                if valid_template.count(1) == 1:
                    for k in range(4):
                        if valid_template[k] == 1:
                            # rotation of the dead-end depends on which
                            # neighbour connects to this cell
                            rot = 0
                            if k == 0:
                                rot = 180
                            elif k == 1:
                                rot = 270
                            elif k == 2:
                                rot = 0
                            elif k == 3:
                                rot = 90

                            rail[row][col] = t_utils.rotate_transition(
                                int('0010000000000000', 2), rot)
                            num_insertions += 1

                            break

                else:
                    # can I get valid transitions by removing a single
                    # neighboring cell?
                    bestk = -1
                    besttrans = []
                    for k in range(4):
                        tmp_template = valid_template[:]
                        tmp_template[k] = -1
                        possible_cell_transitions = get_matching_templates(tmp_template)
                        # keep the neighbour whose removal leaves the most candidates
                        if len(possible_cell_transitions) > len(besttrans):
                            besttrans = possible_cell_transitions
                            bestk = k

                    if bestk >= 0:
                        # Replace the corresponding cell with None, append it
                        # to cells to fill, fill in a transition in the current
                        # cell.
                        replace_row = row - 1
                        replace_col = col
                        if bestk == 1:
                            replace_row = row
                            replace_col = col + 1
                        elif bestk == 2:
                            replace_row = row + 1
                            replace_col = col
                        elif bestk == 3:
                            replace_row = row
                            replace_col = col - 1

                        cells_to_fill.append((replace_row, replace_col))
                        rail[replace_row][replace_col] = None

                        # normalise the weights of the candidates to probabilities
                        possible_transitions, possible_probabilities = zip(*besttrans)
                        possible_probabilities = [
                            p / sum(possible_probabilities) for p in possible_probabilities
                        ]

                        rail[row][col] = np_random.choice(possible_transitions,
                                                          p=possible_probabilities)
                        num_insertions += 1

                    else:
                        print('WARNING: still nothing!')
                        rail[row][col] = int('0000000000000000', 2)
                        num_insertions += 1
                        pass

            else:
                # normalise the weights of the candidates to probabilities
                possible_transitions, possible_probabilities = zip(*possible_cell_transitions)
                possible_probabilities = [
                    p / sum(possible_probabilities) for p in possible_probabilities
                ]

                rail[row][col] = np_random.choice(possible_transitions,
                                                  p=possible_probabilities)
                num_insertions += 1

        if num_insertions == MAX_INSERTIONS:
            # Failed to generate a valid level; try again for a number of times
            attempt_number += 1
        else:
            break

    if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
        # the map of the last attempt is still padded and returned below
        print('ERROR: failed to generate level')

    # Finally pad the border of the map with dead-ends to avoid border issues;
    # at most 1 transition in the neigh cell
    for r in range(height):
        # Check for transitions coming from [r][1] to WEST
        max_bit = 0
        neigh_trans = rail[r][1]
        if neigh_trans is not None:
            for k in range(4):
                # k-th 4-bit group of the cell value, most significant group first
                neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
                max_bit = max_bit | (neigh_trans_from_direction & 1)
        if max_bit:
            rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270)
        else:
            rail[r][0] = int('0000000000000000', 2)

        # Check for transitions coming from [r][-2] to EAST
        max_bit = 0
        neigh_trans = rail[r][-2]
        if neigh_trans is not None:
            for k in range(4):
                neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
                max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
        if max_bit:
            rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2), 90)
        else:
            rail[r][-1] = int('0000000000000000', 2)

    for c in range(width):
        # Check for transitions coming from [1][c] to NORTH
        max_bit = 0
        neigh_trans = rail[1][c]
        if neigh_trans is not None:
            for k in range(4):
                neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
                max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
        if max_bit:
            rail[0][c] = int('0010000000000000', 2)
        else:
            rail[0][c] = int('0000000000000000', 2)

        # Check for transitions coming from [-2][c] to SOUTH
        max_bit = 0
        neigh_trans = rail[-2][c]
        if neigh_trans is not None:
            for k in range(4):
                neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
                max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
        if max_bit:
            rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180)
        else:
            rail[-1][c] = int('0000000000000000', 2)

    # For display only, wrong levels
    # (any cell that is still unset is written out as an empty cell)
    for r in range(height):
        for c in range(width):
            if rail[r][c] is None:
                rail[r][c] = int('0000000000000000', 2)

    tmp_rail = np.asarray(rail, dtype=np.uint16)

    return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
    return_rail.grid = tmp_rail

    return return_rail, None
def generator(width: int, height: int, num_agents: int, num_resets: int = 0,
              np_random: RandomState = None) -> RailGenerator:
    """
    Build a rail map by connecting randomly drawn start/goal pairs.

    First up to ``nr_start_goal`` start/goal pairs are drawn on empty cells
    (respecting ``min_dist``/``max_dist`` and a minimum distance of 2 to the
    end points already accepted) and connected by a rail.  Then up to
    ``nr_extra`` additional connections are attempted between cells that
    already carry rail.

    ``np_random`` is used for every draw, so the default ``None`` cannot
    actually be used.  ``num_resets`` is not used here.

    Returns
    -------
    The ``GridTransitionMap`` and a hints dictionary of the form
    ``{'agents_hints': {'start_goal': [...], 'start_dir': [...]}}``.
    """
    if num_agents > nr_start_goal:
        num_agents = nr_start_goal
        print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")

    grid_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions())
    rail_array = grid_map.grid
    rail_array.fill(0)
    rail_trans = grid_map.transitions

    # generate rail array
    # step 1:
    # - generate a start and goal position
    #   - validate min/max distance allowed
    #   - validate that start/goals are not placed too close to other start/goals
    #   - draw a rail from [start,goal]
    #     - if rail crosses existing rail then validate new connection
    #     - possibility that this fails to create a path to goal
    #     - on failure generate new start/goal
    #
    # step 2:
    # - add more rails to map randomly between cells that have rails
    #   - validate all new rails, on failure don't add new rails
    #
    # step 3:
    # - return transition map + list of [start_pos, start_dir, goal_pos] points
    #

    start_goal = []
    start_dir = []
    sanity_max = 9000

    def draw_cell():
        """Draw a random (row, column) position; the row is drawn first."""
        row = np_random.randint(0, height)
        col = np_random.randint(0, width)
        return (row, col)

    def lay_rail(from_cell, to_cell):
        """Try to connect two cells and return the resulting path."""
        return connect_rail_in_grid_map(
            grid_map, from_cell, to_cell, rail_trans,
            Vec2d.get_chebyshev_distance,
            flip_start_node_trans=True,
            flip_end_node_trans=True,
            respect_transition_validity=True,
            forbidden_cells=None)

    def check_all_dist(sg_new):
        """
        Check the distance between a new start/goal pair and the existing ones.

        :param sg_new: start and goal tuple
        :return: False if any new end point is closer than 2 to an already
                 accepted start or goal, True otherwise
        """
        return not any(
            distance_on_rail(new_point, old_point) < 2
            for old_pair in start_goal
            for new_point in sg_new
            for old_point in old_pair)

    nr_created = 0
    created_sanity = 0
    while nr_created < nr_start_goal and created_sanity < sanity_max:
        for _ in range(sanity_max):
            start = draw_cell()
            goal = draw_cell()

            # both end points have to be on empty cells
            if rail_array[goal] != 0 or rail_array[start] != 0:
                continue
            # check min/max distance
            dist_sg = distance_on_rail(start, goal)
            if dist_sg < min_dist or dist_sg > max_dist:
                continue
            # check distance to existing points
            if check_all_dist([start, goal]):
                break
        else:
            # no acceptable pair found: we might as well give up at this point
            break

        new_path = lay_rail(start, goal)
        if len(new_path) < 2:
            # after too many failures we will give up
            created_sanity += 1
            continue
        nr_created += 1
        start_goal.append([start, goal])
        start_dir.append(mirror(get_direction(new_path[0], new_path[1])))

    # add extra connections between existing rail
    created_sanity = 0
    nr_created = 0
    while nr_created < nr_extra and created_sanity < sanity_max:
        for _ in range(sanity_max):
            start = draw_cell()
            goal = draw_cell()
            # both end points have to be on cells that already carry rail
            if rail_array[goal] != 0 and rail_array[start] != 0:
                break
        else:
            break

        new_path = lay_rail(start, goal)
        if len(new_path) >= 2:
            nr_created += 1
        else:
            # after too many failures we will give up
            created_sanity += 1

    return grid_map, {
        'agents_hints': {
            'start_goal': start_goal,
            'start_dir': start_dir
        }
    }
def test_rotate_railenv_transition():
    """
    Check ``rotate_transition`` against hand-written rotation cycles.

    Each entry of ``transition_cycles`` lists the expected cell value after
    rotating its first element by 0, 90, 180 and 270 degrees; shorter cycles
    repeat (a straight track has only two distinct rotations).  On a
    mismatch both cells are printed before the assertion error is re-raised.
    """
    rail_env_transitions = RailEnvTransitions()

    # TODO test all cases
    transition_cycles = [
        # empty cell - Case 0
        [
            int('0000000000000000', 2),
            int('0000000000000000', 2),
            int('0000000000000000', 2),
            int('0000000000000000', 2)
        ],
        # Case 1 - straight
        #     |
        #     |
        #     |
        [int(rw('1000 0000 0010 0000'), 2), int(rw('0000 0100 0000 0001'), 2)],
        # Case 1b (8)  - simple turn right
        #      _
        #     |
        #     |
        [
            int(rw('0100 0000 0000 0010'), 2),
            int(rw('0001 0010 0000 0000'), 2),
            int(rw('0000 1000 0001 0000'), 2),
            int(rw('0000 0000 0100 1000'), 2),
        ],
        # Case 1c (9)  - simple turn left
        #    _
        #     |
        #     |
        # int('0001001000000000', 2),\ # noqa: E800
        # Case 2 - simple left switch
        #  _ _|
        #     |
        #     |
        [
            int(rw('1001 0010 0010 0000'), 2),
            int(rw('0000 1100 0001 0001'), 2),
            int(rw('1000 0000 0110 1000'), 2),
            int(rw('0100 0100 0000 0011'), 2),
        ],
        # Case 2b (10) - simple right switch
        #     |
        #     |
        #     |
        # int('1100000000100010', 2) \ # noqa: E800
        # Case 3 - diamond crossing
        # int('1000010000100001', 2), \ # noqa: E800
        # Case 4 - single slip
        # int('1001011000100001', 2), \ # noqa: E800
        # Case 5 - double slip
        # int('1100110000110011', 2), \ # noqa: E800
        # Case 6 - symmetrical
        # int('0101001000000010', 2), \ # noqa: E800

        # Case 7 - dead end
        #
        #
        #     |
        [
            int(rw('0010 0000 0000 0000'), 2),
            int(rw('0000 0001 0000 0000'), 2),
            int(rw('0000 0000 1000 0000'), 2),
            int(rw('0000 0000 0000 0100'), 2),
        ],
    ]

    for index, cycle in enumerate(transition_cycles):
        for i in range(4):
            rotation = i * 90
            actual_transition = rail_env_transitions.rotate_transition(cycle[0], rotation)
            expected_transition = cycle[i % len(cycle)]
            try:
                # report the position of the cycle in the list and the angle
                # that was actually passed to rotate_transition
                assert actual_transition == expected_transition, \
                    "Case {}: rotate_transition({}, {}) should equal {} but was {}.".format(
                        index, cycle[0], rotation, expected_transition, actual_transition)
            except AssertionError:
                print("expected:")
                rail_env_transitions.print(expected_transition)
                print("actual:")
                rail_env_transitions.print(actual_transition)
                # bare raise keeps the original traceback
                raise