def test_single_malfunction_generator():
    """Check that the single-malfunction generator yields one malfunction.

    The environment is reset ten times. After each reset all agents are
    driven with ``RailEnvActions(2)`` for ten steps, and the sum of the
    agents' ``nr_malfunctions`` counters must equal exactly one.
    """
    rail, rail_map = make_simple_rail2()
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=10,
        malfunction_generator_and_process_data=single_malfunction_generator(
            earlierst_malfunction=10,
            malfunction_duration=5,
        ),
    )

    for episode in range(10):
        env.reset()
        action_dict = dict()
        print(episode)

        for _ in range(10):
            # Go forward all the time
            action_dict.update(
                (agent.handle, RailEnvActions(2)) for agent in env.agents)
            env.step(action_dict)

        # Add up the malfunction counters of all agents for this episode.
        tot_malfunctions = sum(
            agent.malfunction_data['nr_malfunctions'] for agent in env.agents)
        assert tot_malfunctions == 1
def test_path_not_exists(rendering=False):
    """Run ``check_path`` on the unconnected simple rail expecting ``False``.

    The check starts at cell (5, 6) with direction 0 and targets cell (0, 3).

    Parameters
    ----------
    rendering : bool
        When true, the environment is rendered with the "PILSVG" backend and
        the test blocks on ``input`` until the user confirms.
    """
    rail, rail_map = make_simple_rail_unconnected()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2,
            predictor=ShortestPathPredictorForRailEnv(),
        ),
    )
    env.reset()

    start = (5, 6)  # south dead-end
    heading = 0  # north
    goal = (0, 3)  # north dead-end
    check_path(env, rail, start, heading, goal, False)

    if not rendering:
        return

    renderer = RenderTool(env, gl="PILSVG")
    renderer.render_env(show=True, show_observations=False)
    input("Continue?")
def test_seeding_and_observations():
    """Check that the observation builder does not influence seeded dynamics.

    Two environments are built from the same rail and the same seeds, one
    with ``GlobalObsForRailEnv`` and one with ``TreeObsForRailEnv``. Their
    first ten agents must start at the same cells and, after ten steps with
    identical random actions, occupy the same positions.
    """
    rail, rail_map = make_simple_rail2()

    # Environment observed through the global observation builder.
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(seed=12),
        number_of_agents=10,
        obs_builder_object=GlobalObsForRailEnv(),
    )
    # Same setup, observed through the tree observation builder.
    env2 = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(seed=12),
        number_of_agents=10,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2,
            predictor=ShortestPathPredictorForRailEnv(),
        ),
    )

    env.reset(False, False, False, random_seed=12)
    env2.reset(False, False, False, random_seed=12)

    # Both environments must produce the same initial start positions.
    for handle in range(10):
        assert (env.agents[handle].initial_position
                == env2.agents[handle].initial_position)

    action_dict = {}
    for _ in range(10):
        for handle in range(env.get_num_agents()):
            action_dict[handle] = np.random.randint(4)
        # The very same actions are fed to both environments.
        env.step(action_dict)
        env2.step(action_dict)

    # Both environments must end up in the same positions.
    for handle in range(10):
        assert env.agents[handle].position == env2.agents[handle].position

    for handle in range(env.get_num_agents()):
        print("assert env.agents[{}].position == env2.agents[{}].position".
              format(handle, handle))
def test_random_seeding():
    """Check that a fixed seed always yields the same initial agent positions.

    A fresh environment is built and reset with ``random_seed=1`` one hundred
    times. After ten steps the ``initial_position`` of each of the ten agents
    is compared against the recorded reference values.

    The original code wrote these comparisons as bare ``a == b`` expression
    statements, whose result was discarded, so the test could never fail.
    They are real assertions now.
    """
    rail, rail_map = make_simple_rail2()

    # Reference initial positions, indexed by agent handle.
    expected_initial_positions = [
        (3, 2),
        (3, 5),
        (3, 6),
        (5, 6),
        (3, 4),
        (3, 1),
        (3, 9),
        (4, 6),
        (0, 3),
        (3, 7),
    ]

    for idx in range(100):
        env = RailEnv(width=25,
                      height=30,
                      rail_generator=rail_from_grid_transition_map(rail),
                      schedule_generator=random_schedule_generator(seed=12),
                      number_of_agents=10)
        env.reset(True, True, False, random_seed=1)

        # Move target to unreachable position in order to not interfere
        # with test (per the original author's note).
        env.agents[0].target = (0, 0)

        for step in range(10):
            env.step({0: 2})

        for handle, expected in enumerate(expected_initial_positions):
            actual = env.agents[handle].initial_position
            assert actual == expected, (
                "run {}: agent {} initial_position actual={}, expected={}"
                .format(idx, handle, actual, expected))
def create_testfiles(parameters, test_nr=0, nr_trials_per_test=100):
    """Generate and save a series of levels for one test configuration.

    Parameters
    ----------
    parameters : sequence
        Indexed as ``[x_dim, y_dim, number_of_agents, seed]``.
    test_nr : int
        Test identifier; used in the log line and in the output path.
    nr_trials_per_test : int
        Number of levels to generate. Each one is saved to
        ``./Tests/<test_nr>/Level_<trial>.pkl``.
    """
    x_dim = parameters[0]
    y_dim = parameters[1]
    n_agents = parameters[2]
    print('Creating {} with (x_dim,y_dim) = ({},{}) and {} Agents.'.format(
        test_nr, x_dim, y_dim, n_agents))

    # Seed both random sources before building the environment.
    seed = parameters[3]
    random.seed(seed)
    np.random.seed(seed)

    nr_paths = max(4, n_agents + int(0.5 * n_agents))
    min_dist = int(min(x_dim, y_dim) * 0.75)
    env = RailEnv(width=x_dim,
                  height=y_dim,
                  rail_generator=complex_rail_generator(nr_start_goal=nr_paths,
                                                        nr_extra=5,
                                                        min_dist=min_dist,
                                                        max_dist=99999,
                                                        seed=seed),
                  schedule_generator=complex_schedule_generator(),
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
                  number_of_agents=n_agents)

    def _report(done):
        # Single place for the progress-bar formatting.
        printProgressBar(done, nr_trials_per_test, prefix='Progress:',
                         suffix='Complete', length=20)

    _report(0)
    for trial in range(nr_trials_per_test):
        # Regenerate the level, then persist it.
        env.reset(True, True)
        env.save("./Tests/{}/Level_{}.pkl".format(test_nr, trial))
        _report(trial + 1)
def test_malfanction_from_params():
    """Check that malfunction parameters reach the environment unchanged.

    A ``MalfunctionParameters`` tuple is passed through
    ``malfunction_from_params``; after ``reset`` the environment's
    ``malfunction_process_data`` must expose the same three values.
    """
    rate, shortest, longest = 1000, 2, 5
    stochastic_data = MalfunctionParameters(
        malfunction_rate=rate,  # Rate of malfunction occurrence
        min_duration=shortest,  # Minimal duration of malfunction
        max_duration=longest,  # Max duration of malfunction
    )

    rail, rail_map = make_simple_rail2()
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=10,
        malfunction_generator_and_process_data=malfunction_from_params(
            stochastic_data),
    )
    env.reset()

    assert env.malfunction_process_data.malfunction_rate == 1000
    assert env.malfunction_process_data.min_duration == 2
    assert env.malfunction_process_data.max_duration == 5
def test_get_shortest_paths_unreachable():
    """Expect ``get_shortest_paths`` to map agent 0 to ``None``.

    The single agent is placed by hand at (3, 1) on the disconnected simple
    rail with target (3, 9); the resulting shortest-path dict must be
    ``{0: None}``.
    """
    rail, rail_map = make_disconnected_simple_rail()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=GlobalObsForRailEnv(),
    )
    env.reset()

    # Override the generated placement of the only agent.
    west_dead_end = (3, 1)
    east_dead_end = (3, 9)
    agent = env.agents[0]
    agent.position = west_dead_end
    agent.initial_position = west_dead_end
    agent.direction = Grid4TransitionsEnum.WEST
    agent.target = east_dead_end
    agent.moving = True

    # Reset without regenerating rail or schedule.
    env.reset(False, False)

    expected = {0: None}
    actual = get_shortest_paths(env.distance_map)
    assert actual == expected, "actual={},expected={}".format(actual, expected)
def test_get_entry_directions():
    """Check ``get_valid_directions_on_grid`` for typical simple-rail cells."""
    rail, rail_map = make_simple_rail()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2,
            predictor=ShortestPathPredictorForRailEnv(),
        ),
    )
    env.reset()

    # (cell, expected four-element direction mask)
    cases = [
        ((0, 3), [True, False, False, False]),  # north dead end
        ((3, 0), [False, False, False, True]),  # west dead end
        ((3, 3), [False, True, True, True]),  # switch
        ((3, 2), [False, True, False, True]),  # horizontal
        ((2, 3), [True, False, True, False]),  # vertical
        ((0, 0), [False, False, False, False]),  # nowhere
    ]
    for position, expected in cases:
        actual = env.get_valid_directions_on_grid(*position)
        assert actual == expected, "[{},{}] actual={}, expected={}".format(
            *position, actual, expected)
def test_save_load():
    """Round-trip an environment through ``RailEnvPersister``.

    The env is saved with both ``RailEnvPersister.save`` and ``env.save``
    under ``tmp/``. The persister file is then loaded with ``load_new`` and
    the size, agent count and each agent's position, direction and target
    are compared with the values recorded before saving.
    """
    env = RailEnv(
        width=10,
        height=10,
        rail_generator=complex_rail_generator(
            nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
        schedule_generator=complex_schedule_generator(),
        number_of_agents=2,
    )
    env.reset()

    # Record (position, direction, target) of both agents before saving.
    recorded = []
    for handle in (0, 1):
        agent = env.agents[handle]
        recorded.append((agent.position, agent.direction, agent.target))

    os.makedirs("tmp", exist_ok=True)
    RailEnvPersister.save(env, "tmp/test_save.pkl")
    env.save("tmp/test_save_2.pkl")

    env, _env_dict = RailEnvPersister.load_new("tmp/test_save.pkl")

    assert env.width == 10
    assert env.height == 10
    assert len(env.agents) == 2
    for handle, (position, direction, target) in enumerate(recorded):
        assert position == env.agents[handle].position
        assert direction == env.agents[handle].direction
        assert target == env.agents[handle].target
def gen_env(number_agents, width, height, n_start_goal, seed):
    """Build a complex-rail environment, reset it and advance it one step.

    Parameters
    ----------
    number_agents : int
        Number of agents; also the number of entries in the first action dict.
    width, height : int
        Grid size passed to ``RailEnv``.
    n_start_goal : int
        Forwarded as ``nr_start_goal`` to ``complex_rail_generator``.
    seed : int
        Seed forwarded to ``complex_rail_generator``.

    Returns
    -------
    The environment after ``reset`` and one ``step`` in which every agent
    handle in ``range(number_agents)`` is given action ``2``.
    """
    # Four speed classes, each with weight 0.25: fast passenger (1),
    # fast freight (1/2), slow commuter (1/3), slow freight (1/4).
    speed_ratio_map = dict.fromkeys((1., 1. / 2., 1. / 3., 1. / 4.), 0.25)

    env = RailEnv(
        width=width,
        height=height,
        rail_generator=complex_rail_generator(nr_start_goal=n_start_goal,
                                              nr_extra=3,
                                              min_dist=6,
                                              max_dist=99999,
                                              seed=seed),
        schedule_generator=complex_schedule_generator(
            speed_ratio_map=speed_ratio_map),
        number_of_agents=number_agents,
        obs_builder_object=TreeObsForRailEnv(max_depth=5),
    )
    env.reset()

    first_actions = {handle: 2 for handle in range(number_agents)}
    env.step(first_actions)
    return env
def test_malfunction_before_entry():
    """Tests that malfunctions are working properly for agents before entering the environment!"""
    # Fixed malfunction duration (min == max == 10) for this test.
    stochastic_data = MalfunctionParameters(
        malfunction_rate=2,  # Rate of malfunction occurrence
        min_duration=10,  # Minimal duration of malfunction
        max_duration=10,  # Max duration of malfunction
    )

    rail, rail_map = make_simple_rail2()
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=10,
        malfunction_generator_and_process_data=malfunction_from_params(
            stochastic_data),
        obs_builder_object=SingleAgentNavigationObs(),
    )
    env.reset(False, False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Initial 'malfunction' value per agent handle right after reset: some
    # agents are already malfunctioning while others are still working.
    expected_malfunction = [0, 10, 0, 10, 10, 10, 10, 10, 10, 10]
    for handle, expected in enumerate(expected_malfunction):
        assert env.agents[handle].malfunction_data['malfunction'] == expected
def test_save_load():
    """Round-trip an environment through ``env.save`` / ``env.load``.

    The env is written to ``test_save.dat`` and loaded back into the same
    object. Size, agent count and each agent's position, direction and
    target must match the values recorded before saving.
    """
    env = RailEnv(
        width=10,
        height=10,
        rail_generator=complex_rail_generator(
            nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
        schedule_generator=complex_schedule_generator(),
        number_of_agents=2,
    )
    env.reset()

    first, second = env.agents[0], env.agents[1]
    before_first = (first.position, first.direction, first.target)
    before_second = (second.position, second.direction, second.target)

    save_path = "test_save.dat"
    env.save(save_path)
    env.load(save_path)

    assert env.width == 10
    assert env.height == 10
    assert len(env.agents) == 2

    # Agents are re-read from the env after loading.
    for handle, before in ((0, before_first), (1, before_second)):
        position, direction, target = before
        assert position == env.agents[handle].position
        assert direction == env.agents[handle].direction
        assert target == env.agents[handle].target
def render_test(parameters, test_nr=0, nr_examples=5):
    """Load and briefly display saved levels of one test configuration.

    Parameters
    ----------
    parameters : sequence
        Only indices 0, 1 and 2 are read, and only for the log line
        (x_dim, y_dim, number of agents).
    test_nr : int
        Test identifier; selects the ``./Tests/<test_nr>/`` directory.
    nr_examples : int
        Number of levels (``Level_0`` .. ``Level_<nr_examples - 1>``) to show.
    """
    x_dim, y_dim, n_agents = parameters[0], parameters[1], parameters[2]

    for trial in range(nr_examples):
        print(
            'Showing {} Level {} with (x_dim,y_dim) = ({},{}) and {} Agents.'.
            format(test_nr, trial, x_dim, y_dim, n_agents))

        file_name = "./Tests/{}/Level_{}.pkl".format(test_nr, trial)
        # The rail is read from file_name by rail_from_file.
        env = RailEnv(
            width=1,
            height=1,
            rail_generator=rail_from_file(file_name),
            obs_builder_object=TreeObsForRailEnv(max_depth=2),
            number_of_agents=1,
        )

        env_renderer = RenderTool(env, gl="PILSVG")
        env_renderer.set_new_rail()

        env.reset(False, False)
        env_renderer.render_env(show=True, show_observations=False)

        # Keep the window visible for a moment before closing it.
        time.sleep(0.1)
        env_renderer.close_window()
def test_malfunction_values_and_behavior():
    """Check the step-by-step value of the agent's 'malfunction' counter.

    With a fixed duration of 10 and ``random_seed=10`` the single agent's
    ``malfunction_data['malfunction']`` must follow the recorded sequence
    for fifteen steps taken with an empty action dict.
    """
    rail, rail_map = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}

    # Fixed malfunction duration (min == max == 10) for this test.
    stochastic_data = MalfunctionParameters(
        malfunction_rate=0.001,  # Rate of malfunction occurrence
        min_duration=10,  # Minimal duration of malfunction
        max_duration=10,  # Max duration of malfunction
    )
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        malfunction_generator_and_process_data=malfunction_from_params(
            stochastic_data),
        obs_builder_object=SingleAgentNavigationObs(),
    )
    env.reset(False, False, activate_agents=True, random_seed=10)

    # Expected counter value after each step; only the first 15 are checked.
    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
    print("[")
    for expected in assert_list[:15]:
        env.step(action_dict)
        assert env.agents[0].malfunction_data['malfunction'] == expected
def test_path_exists(rendering=False):
    """Run ``check_path`` on the simple rail for six start/target pairs.

    Every case is passed with the expected flag ``True``.

    Parameters
    ----------
    rendering : bool
        Accepted for signature parity with the sibling tests; not read here.
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2,
            predictor=ShortestPathPredictorForRailEnv(),
        ),
    )
    env.reset()

    # (start cell, start direction, target cell)
    # Direction codes used here: 0 = north, 2 = south, 3 = west.
    cases = (
        ((5, 6), 0, (3, 9)),
        ((6, 6), 2, (3, 9)),
        ((3, 0), 3, (0, 3)),
        ((5, 6), 0, (1, 3)),
        ((1, 3), 2, (3, 3)),
        ((1, 3), 0, (3, 3)),
    )
    for start, direction, target in cases:
        check_path(env, rail, start, direction, target, True)
def load_env(env_dict, obs_builder_object=None):
    """Create a ``RailEnv`` and restore its full state from ``env_dict``.

    Parameters
    ----------
    env_dict
        State passed to ``RailEnvPersister.set_full_state``.
    obs_builder_object
        Observation builder for the new environment. When omitted, a fresh
        ``GlobalObsForRailEnv`` is created for this call. The previous
        default was built once at function-definition time and therefore
        shared by every environment loaded through the default.

    Returns
    -------
    The environment after ``reset`` (no rail or schedule regeneration) and
    ``set_full_state``.
    """
    if obs_builder_object is None:
        obs_builder_object = GlobalObsForRailEnv()

    # Placeholder 4x4 env; the state from env_dict is applied afterwards.
    env = RailEnv(height=4, width=4, obs_builder_object=obs_builder_object)
    env.reset(regenerate_rail=False, regenerate_schedule=False)
    RailEnvPersister.set_full_state(env, env_dict)
    return env
def test_load_env():
    """Load a packaged 10x10 env resource, add one agent and count agents."""
    env = RailEnv(10, 10)
    env.reset()
    env.load_resource('env_data.tests', 'test-10x10.mpk')

    # After loading, adding a single agent must give an agent count of one.
    env.add_agent(EnvAgent((0, 0), 2, (5, 5), False))
    agent_count = env.get_num_agents()
    assert agent_count == 1
def test_random_rail_generator():
    """Check grid shape and agent count of a randomly generated level."""
    n_agents = 1
    x_dim, y_dim = 5, 10

    env = RailEnv(
        width=x_dim,
        height=y_dim,
        rail_generator=random_rail_generator(),
        number_of_agents=n_agents,
    )
    env.reset()

    # The grid is indexed (row, column), i.e. (height, width).
    expected_shape = (y_dim, x_dim)
    assert env.rail.grid.shape == expected_shape
    assert env.get_num_agents() == n_agents
def test_shortest_path_predictor_conflicts(rendering=False):
    """Check the conflicts found in the tree observations of two agents.

    Two agents are placed by hand on the invalid simple rail. After a reset
    that keeps rail and schedule, the tree observations of both agents are
    printed and handed to ``_check_expected_conflicts``.

    Parameters
    ----------
    rendering : bool
        When true, the environment is rendered with the "PILSVG" backend and
        the test blocks on ``input`` until the user confirms.
    """
    rail, rail_map = make_invalid_simple_rail()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=2,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2,
            predictor=ShortestPathPredictorForRailEnv(),
        ),
    )
    env.reset()

    def _place(agent, cell, heading, target):
        # Pin an agent's placement and mark it as active and moving.
        agent.initial_position = cell
        agent.position = cell
        agent.direction = heading
        agent.initial_direction = heading
        agent.target = target
        agent.moving = True
        agent.status = RailAgentStatus.ACTIVE

    # Agent 0 starts at (5, 6) heading north (0) towards (3, 9).
    _place(env.agents[0], (5, 6), 0, (3, 9))
    # Agent 1 starts at (3, 8) heading west (3) towards (6, 6).
    _place(env.agents[1], (3, 8), 3, (6, 6))

    observations, info = env.reset(False, False, True)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # Get the trees to test.
    obs_builder: TreeObsForRailEnv = env.obs_builder
    pp = pprint.PrettyPrinter(indent=4)
    tree_0, tree_1 = observations[0], observations[1]
    for tree in (tree_0, tree_1):
        env.obs_builder.util_print_obs_subtree(tree)

    # Check the expectations.
    expected_conflicts_0 = [('F', 'R')]
    expected_conflicts_1 = [('F', 'L')]
    _check_expected_conflicts(expected_conflicts_0, obs_builder, tree_0,
                              "agent[0]: ")
    _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1,
                              "agent[1]: ")
def main():
    """Build a small random env, take one step and print each observation."""
    env = RailEnv(
        width=7,
        height=7,
        rail_generator=random_rail_generator(),
        number_of_agents=3,
        obs_builder_object=SimpleObs(),
    )
    env.reset()

    # Only agent 0 is given an action (0); the observations of all agents
    # are printed afterwards.
    obs, _rewards, _dones, _ = env.step({0: 0})
    for handle in range(env.get_num_agents()):
        print("Agent ", handle, "'s observation: ", obs[handle])
def test_render_env(save_new_images=False):
    """Render a stored transition map and compare it with frozen images.

    Parameters
    ----------
    save_new_images : bool
        Forwarded as ``resave`` to ``checkFrozenImage``.
    """
    np.random.seed(100)
    env = RailEnv(
        width=10,
        height=10,
        rail_generator=empty_rail_generator(),
        number_of_agents=0,
        obs_builder_object=TreeObsForRailEnv(max_depth=2),
    )
    env.reset()
    env.rail.load_transition_map('env_data.tests', "test1.npy")

    # "PILSVG" backend, rendered without showing a window.
    svg_renderer = rt.RenderTool(env, gl="PILSVG")
    svg_renderer.render_env(show=False)
    checkFrozenImage(svg_renderer, "basic-env.npz", resave=save_new_images)

    # "PIL" backend, rendered with the default arguments.
    pil_renderer = rt.RenderTool(env, gl="PIL")
    pil_renderer.render_env()
    checkFrozenImage(pil_renderer, "basic-env-PIL.npz", resave=save_new_images)
def test_empty_rail_generator():
    """Check that the empty rail generator yields an empty grid and no agents.

    One agent is requested, but the environment is expected to hold none
    after the reset.
    """
    n_agents = 1
    x_dim, y_dim = 5, 10

    env = RailEnv(
        width=x_dim,
        height=y_dim,
        rail_generator=empty_rail_generator(),
        number_of_agents=n_agents,
    )
    env.reset()

    grid = env.rail.grid
    # Check the dimensions.
    assert grid.shape == (y_dim, x_dim)
    # Check that no rail cell was generated.
    assert np.count_nonzero(grid) == 0
    # Check that no agents were placed.
    assert env.get_num_agents() == 0
def __init__(self, env=None, sGL="PIL", env_filename="temp.pkl"):
    """Assemble the editor model, view and controller around a rail env.

    Parameters
    ----------
    env
        Environment to edit. When ``None``, an empty 10x10 ``RailEnv``
        without agents is created.
    sGL : str
        Forwarded to ``View`` as ``sGL``.
    env_filename : str
        Forwarded to ``EditorModel`` as ``env_filename``.
    """
    if env is None:
        env = RailEnv(
            width=10,
            height=10,
            rail_generator=empty_rail_generator(),
            number_of_agents=0,
            obs_builder_object=TreeObsForRailEnv(max_depth=2),
        )

    env.reset()

    self.editor = EditorModel(env, env_filename=env_filename)

    # The view is referenced by both the model and this assembly.
    view = View(self.editor, sGL=sGL)
    self.editor.view = view
    self.view = view

    # The controller is referenced by the view, the model and this assembly.
    controller = Controller(self.editor, self.view)
    self.view.controller = controller
    self.editor.controller = controller
    self.controller = controller

    self.view.init_canvas()
    self.view.init_widgets()  # has to be done after controller
def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!"""
    # Fixed malfunction duration (min == max == 5) for this test.
    stochastic_data = MalfunctionParameters(
        malfunction_rate=1 / 5,  # Rate of malfunction occurrence
        min_duration=5,  # Minimal duration of malfunction
        max_duration=5,  # Max duration of malfunction
    )

    rail, rail_map = make_simple_rail2()
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=10,
        malfunction_generator_and_process_data=malfunction_from_params(
            stochastic_data),
        obs_builder_object=SingleAgentNavigationObs(),
    )
    env.reset(True, True, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Recorded 'malfunction' value per agent (row) and step (column).
    agent_malfunction_list = [
        [0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4],
        [0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2],
        [0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
        [0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
        [5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 5],
        [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2],
        [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4],
    ]

    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent_idx in range(env.get_num_agents()):
            # We randomly select an action.
            action_dict[agent_idx] = RailEnvActions(np.random.randint(4))

            # The counter is checked before the step is taken.
            observed = env.agents[agent_idx].malfunction_data['malfunction']
            assert observed == agent_malfunction_list[agent_idx][step]
        env.step(action_dict)
def _launch(self):
    """Create and reset a ``RailEnv`` from ``self._config``.

    Malfunctions are enabled only when all three ``malfunction_*`` keys are
    present in the config. A ``speed_ratio_map`` entry, if present, is
    converted to a float-to-float dict.

    Returns
    -------
    The environment, or ``None`` when ``RailEnv(...)`` raised ``ValueError``.
    NOTE(review): if construction succeeds but ``reset()`` raises
    ``ValueError``, the un-reset environment is returned.
    """
    config = self._config

    rail_generator = sparse_rail_generator(
        seed=config['seed'],
        max_num_cities=config['max_num_cities'],
        grid_mode=config['grid_mode'],
        max_rails_between_cities=config['max_rails_between_cities'],
        max_rails_in_city=config['max_rails_in_city'],
    )

    malfunction_generator = no_malfunction_generator()
    required_malfunction_keys = {
        'malfunction_rate',
        'malfunction_min_duration',
        'malfunction_max_duration',
    }
    if required_malfunction_keys <= config.keys():
        stochastic_data = {
            'malfunction_rate': config['malfunction_rate'],
            'min_duration': config['malfunction_min_duration'],
            'max_duration': config['malfunction_max_duration'],
        }
        malfunction_generator = malfunction_from_params(stochastic_data)

    if 'speed_ratio_map' in config:
        speed_ratio_map = {
            float(k): float(v)
            for k, v in config['speed_ratio_map'].items()
        }
    else:
        speed_ratio_map = None
    schedule_generator = sparse_schedule_generator(speed_ratio_map)

    env = None
    try:
        env = RailEnv(
            width=config['width'],
            height=config['height'],
            rail_generator=rail_generator,
            schedule_generator=schedule_generator,
            number_of_agents=config['number_of_agents'],
            malfunction_generator_and_process_data=malfunction_generator,
            obs_builder_object=self._observation.builder(),
            remove_agents_at_target=False,
            random_seed=config['seed'],
        )
        env.reset()
    except ValueError as e:
        # Log the failure between two separator lines and fall through.
        separator = "=" * 50
        logging.error(separator)
        logging.error(f"Error while creating env: {e}")
        logging.error(separator)
    return env
def test_rail_from_grid_transition_map():
    """Check rail cell count and agent placement on the simple rail."""
    rail, rail_map = make_simple_rail()
    n_agents = 3
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=n_agents,
    )
    env.reset(False, False, True)

    # The simple rail is expected to have 16 non-empty cells.
    assert np.count_nonzero(env.rail.grid) == 16

    # Every agent must stand on a non-empty rail cell.
    for agent in env.agents:
        assert env.rail.grid[agent.position] != 0

    assert env.get_num_agents() == n_agents
def demo(args=None):
    """Demo script to check installation.

    Runs random-action episodes on a 15x15 complex rail in an endless loop,
    rendering after every step.

    Parameters
    ----------
    args
        Not read by this function.
    """
    env = RailEnv(
        width=15,
        height=15,
        rail_generator=complex_rail_generator(nr_start_goal=10,
                                              nr_extra=1,
                                              min_dist=8,
                                              max_dist=99999),
        schedule_generator=complex_schedule_generator(),
        number_of_agents=5,
    )
    env._max_episode_steps = int(15 * (env.width + env.height))
    env_renderer = RenderTool(env)

    while True:
        obs, info = env.reset()

        # Run a single episode until the env reports done for '__all__'.
        episode_over = False
        while not episode_over:
            # One random action in [0, 5) per agent.
            actions = {
                handle: np.random.randint(0, 5)
                for handle, _ in enumerate(env.agents)
            }
            obs, all_rewards, done, _ = env.step(actions)
            episode_over = done['__all__']

            env_renderer.render_env(show=True,
                                    frames=False,
                                    show_observations=False,
                                    show_predictions=False)
            time.sleep(0.3)

    # Not reached: the loop above has no exit.
    return 0
def test_global_obs():
    """Check shape and content of the global observation on the simple rail.

    The first observation component of agent 0 must have the rail map's
    shape plus a trailing axis of 16 entries. Reading each cell's 16 entries
    as a binary number must reproduce the rail map, and the agent-state
    component must place something on a non-empty rail cell.

    The original compared ``rail_map_recons.all() == rail_map.all()``, i.e.
    two scalar booleans, which passes for almost any pair of arrays. The
    arrays are now compared element-wise.
    """
    rail, rail_map = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())

    global_obs, info = env.reset()

    # we have to take step for the agent to enter the grid.
    global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})

    transition_obs = global_obs[0][0]
    assert transition_obs.shape == rail_map.shape + (16,)

    # Rebuild every cell value from its 16 observation entries, read as a
    # binary string.
    rail_map_recons = np.zeros_like(rail_map)
    for i in range(transition_obs.shape[0]):
        for j in range(transition_obs.shape[1]):
            bits = ''.join(transition_obs[i, j].astype(int).astype(str))
            rail_map_recons[i, j] = int(bits, 2)

    assert np.array_equal(rail_map_recons, rail_map)

    # If this assertion is wrong, it means that the observation returned
    # places the agent on an empty cell
    obs_agents_state = global_obs[0][1]
    obs_agents_state = obs_agents_state + 1
    assert np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0
def test_dead_end():
    """Set up dead-end railways with a single hand-placed agent.

    A straight track closed by a dead end at each side is built once
    horizontally and once vertically, with the agent placed in both
    travel directions for each layout.

    NOTE(review): the environments are only constructed and reset; no step
    is taken and nothing is asserted, so this only checks that the setup
    does not raise.
    """
    transitions = RailEnvTransitions()

    straight_vertical = int('1000000000100000', 2)  # Case 1 - straight
    straight_horizontal = transitions.rotate_transition(straight_vertical, 90)
    dead_end_from_south = int('0010000000000000', 2)  # Case 7 - dead end

    def _make_env(rail_map):
        # Wrap the raw grid in a transition map and build a one-agent env.
        rail = GridTransitionMap(width=rail_map.shape[1],
                                 height=rail_map.shape[0],
                                 transitions=transitions)
        rail.grid = rail_map
        return RailEnv(width=rail_map.shape[1],
                       height=rail_map.shape[0],
                       rail_generator=rail_from_grid_transition_map(rail),
                       schedule_generator=random_schedule_generator(),
                       number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())

    def _single_agent(position, direction, target):
        # Agent list holding one stationary agent.
        return [EnvAgent(initial_position=position,
                         initial_direction=direction,
                         direction=direction,
                         target=target,
                         moving=False)]

    # We instantiate the following railway
    # O->-- where > is the train and O the target.
    horizontal_map = np.array(
        [[transitions.rotate_transition(dead_end_from_south, 270)] +
         [straight_horizontal] * 3 +
         [transitions.rotate_transition(dead_end_from_south, 90)]],
        dtype=np.uint16)
    rail_env = _make_env(horizontal_map)

    # Horizontal layout, both travel directions.
    rail_env.reset()
    rail_env.agents = _single_agent((0, 2), 1, (0, 0))
    rail_env.reset()
    rail_env.agents = _single_agent((0, 2), 3, (0, 4))

    # In the vertical configuration:
    vertical_map = np.array(
        [[dead_end_from_south]] + [[straight_vertical]] * 3 +
        [[transitions.rotate_transition(dead_end_from_south, 180)]],
        dtype=np.uint16)
    rail_env = _make_env(vertical_map)

    # Vertical layout, both travel directions.
    rail_env.reset()
    rail_env.agents = _single_agent((2, 0), 2, (0, 0))
    rail_env.reset()
    rail_env.agents = _single_agent((2, 0), 0, (4, 0))
def test_multi_speed_init():
    """Check that agents only change cell on their own speed schedule.

    Agent ``i`` is given the speed ``1 / (i + 1)``. For 100 steps, before
    every environment step, each agent's position must equal the last
    position recorded for it; the record of agent ``i`` is refreshed only
    when ``(step + 1)`` is a multiple of ``i + 1``.
    """
    env = RailEnv(
        width=50,
        height=50,
        rail_generator=complex_rail_generator(nr_start_goal=10,
                                              nr_extra=1,
                                              min_dist=8,
                                              max_dist=99999,
                                              seed=1),
        schedule_generator=complex_schedule_generator(),
        number_of_agents=5,
    )

    # Random policy used to pick an action for every agent.
    agent = RandomAgent(218, 4)
    action_dict = dict()

    # Reset the environment without regenerating rail or schedule.
    env.reset(False, False, True)

    # Assign the per-agent speeds and remember where every agent starts.
    old_pos = []
    for handle in range(env.get_num_agents()):
        env.agents[handle].speed_data['speed'] = 1. / (handle + 1)
        old_pos.append(env.agents[handle].position)

    for step in range(100):
        for handle in range(env.get_num_agents()):
            action_dict[handle] = agent.act(0)
            # Check that agent did not move in between its speed updates.
            assert old_pos[handle] == env.agents[handle].position

        env.step(action_dict)

        # Update old position whenever an agent was allowed to move.
        for handle in range(env.get_num_agents()):
            if (step + 1) % (handle + 1) != 0:
                continue
            current = env.agents[handle].position
            print(step, handle, current)
            old_pos[handle] = current