def test_seeding_and_observations():
    """Two identically-seeded environments with *different* observation
    builders must produce identical agent placements and trajectories —
    i.e. the observation builder must not consume the env's RNG stream.
    """
    rail, rail_map = make_simple_rail2()
    # Make two separate envs with different observation builders
    # Global Observation
    env = RailEnv(width=25, height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=12),
                  number_of_agents=10,
                  obs_builder_object=GlobalObsForRailEnv())
    # Tree Observation
    env2 = RailEnv(width=25, height=30,
                   rail_generator=rail_from_grid_transition_map(rail),
                   schedule_generator=random_schedule_generator(seed=12),
                   number_of_agents=10,
                   obs_builder_object=TreeObsForRailEnv(
                       max_depth=2,
                       predictor=ShortestPathPredictorForRailEnv()))

    env.reset(False, False, False, random_seed=12)
    env2.reset(False, False, False, random_seed=12)

    # Check that both environments produce the same initial start positions
    # (loop replaces ten copy-pasted asserts)
    for a in range(env.get_num_agents()):
        assert env.agents[a].initial_position == env2.agents[a].initial_position

    # Drive both envs with the same random actions for 10 steps
    action_dict = {}
    for step in range(10):
        for a in range(env.get_num_agents()):
            action_dict[a] = np.random.randint(4)
        env.step(action_dict)
        env2.step(action_dict)

    # Check that both environments end up in the same position
    # (leftover test-generation print loop removed)
    for a in range(env.get_num_agents()):
        assert env.agents[a].position == env2.agents[a].position
def test_dead_end():
    """Build a 1x5 horizontal and a 5x1 vertical rail whose ends are
    dead-ends, then place a single agent on them facing each of the four
    directions.

    NOTE(review): this test contains no assertions — it only exercises
    environment construction / reset without raising.  The comment below
    suggests the train was originally expected to be done after 6 steps;
    consider asserting on that.
    """
    transitions = RailEnvTransitions()
    straight_vertical = int('1000000000100000', 2)  # Case 1 - straight
    straight_horizontal = transitions.rotate_transition(straight_vertical,
                                                        90)
    dead_end_from_south = int('0010000000000000', 2)  # Case 7 - dead end
    # We instantiate the following railway
    # O->-- where > is the train and O the target. After 6 steps,
    # the train should be done.
    rail_map = np.array(
        [[transitions.rotate_transition(dead_end_from_south, 270)] +
         [straight_horizontal] * 3 +
         [transitions.rotate_transition(dead_end_from_south, 90)]],
        dtype=np.uint16)
    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0],
                             transitions=transitions)
    rail.grid = rail_map
    rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                       rail_generator=rail_from_grid_transition_map(rail),
                       schedule_generator=random_schedule_generator(),
                       number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())
    # We try the configuration in the 4 directions:
    rail_env.reset()
    # heading east (direction=1) toward the west dead-end target
    rail_env.agents = [EnvAgent(initial_position=(0, 2), initial_direction=1,
                                direction=1, target=(0, 0), moving=False)]
    rail_env.reset()
    # heading west (direction=3) toward the east dead-end target
    rail_env.agents = [EnvAgent(initial_position=(0, 2), initial_direction=3,
                                direction=3, target=(0, 4), moving=False)]
    # In the vertical configuration:
    rail_map = np.array(
        [[dead_end_from_south]] + [[straight_vertical]] * 3 +
        [[transitions.rotate_transition(dead_end_from_south, 180)]],
        dtype=np.uint16)
    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0],
                             transitions=transitions)
    rail.grid = rail_map
    rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                       rail_generator=rail_from_grid_transition_map(rail),
                       schedule_generator=random_schedule_generator(),
                       number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())
    rail_env.reset()
    # heading south (direction=2) toward the north dead-end target
    rail_env.agents = [EnvAgent(initial_position=(2, 0), initial_direction=2,
                                direction=2, target=(0, 0), moving=False)]
    rail_env.reset()
    # heading north (direction=0) toward the south dead-end target
    rail_env.agents = [EnvAgent(initial_position=(2, 0), initial_direction=0,
                                direction=0, target=(4, 0), moving=False)]
def tests_random_interference_from_outside(): """Tests that malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test rail, rail_map = make_simple_rail2() env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() env.agents[0].speed_data['speed'] = 0.33 env.reset(False, False, False, random_seed=10) env_data = [] for step in range(200): action_dict: Dict[int, RailEnvActions] = {} for agent in env.agents: # We randomly select an action action_dict[agent.handle] = RailEnvActions(2) _, reward, _, _ = env.step(action_dict) # Append the rewards of the first trial env_data.append((reward[0], env.agents[0].position)) assert reward[0] == env_data[step][0] assert env.agents[0].position == env_data[step][1] # Run the same test as above but with an external random generator running # Check that the reward stays the same rail, rail_map = make_simple_rail2() random.seed(47) np.random.seed(1234) env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() env.agents[0].speed_data['speed'] = 0.33 env.reset(False, False, False, random_seed=10) dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4] for step in range(200): action_dict: Dict[int, RailEnvActions] = {} for agent in env.agents: # We randomly select an action action_dict[agent.handle] = RailEnvActions(2) # Do dummy random number generations random.shuffle(dummy_list) np.random.rand() _, reward, _, _ = env.step(action_dict) assert reward[0] == env_data[step][0] assert env.agents[0].position == env_data[step][1]
def test_random_seeding():
    """Identical generator seed + reset seed must yield the same fixed agent
    placements on every one of 100 fresh environment constructions.
    """
    # Set fixed malfunction duration for this test
    rail, rail_map = make_simple_rail2()

    # Move target to unreachable position in order to not interfere with test
    for idx in range(100):
        env = RailEnv(width=25, height=30,
                      rail_generator=rail_from_grid_transition_map(rail),
                      schedule_generator=random_schedule_generator(seed=12),
                      number_of_agents=10)
        env.reset(True, True, False, random_seed=1)

        env.agents[0].target = (0, 0)
        for step in range(10):
            actions = {0: 2}
            env.step(actions)

        # Expected placements for seed=12 / random_seed=1.
        # BUG FIX: these were bare `==` comparisons (no-op expressions),
        # so the test could never fail; they are now real assertions.
        expected_initial_positions = [(3, 2), (3, 5), (3, 6), (5, 6), (3, 4),
                                      (3, 1), (3, 9), (4, 6), (0, 3), (3, 7)]
        for handle, expected in enumerate(expected_initial_positions):
            assert env.agents[handle].initial_position == expected
def __init__(self, width, height, rail_generator, number_of_agents,
             remove_agents_at_target, obs_builder_object, wait_for_all_done,
             schedule_generator=random_schedule_generator(), name=None):
    """Wrap a RailEnv with rendering/bookkeeping state.

    Parameters
    ----------
    width, height : int
        Grid dimensions forwarded to RailEnv.
    rail_generator, schedule_generator
        Generators forwarded to RailEnv.
        NOTE(review): ``schedule_generator``'s default is *called once* at
        function-definition time, so every instance constructed without an
        explicit value shares that single generator object.
    number_of_agents : int
        Number of agents, also kept on ``self`` for the ``at_target`` map.
    remove_agents_at_target : bool
        Forwarded to RailEnv.
    obs_builder_object
        Observation builder forwarded to RailEnv.
    wait_for_all_done : bool
        When True, ``at_target`` is used for reward propagation (see below).
    name : str, optional
        Free-form identifier for this wrapper.
    """
    super().__init__()
    self.env = RailEnv(
        width=width,
        height=height,
        rail_generator=rail_generator,
        schedule_generator=schedule_generator,
        number_of_agents=number_of_agents,
        obs_builder_object=obs_builder_object,
        remove_agents_at_target=remove_agents_at_target,
    )
    self.wait_for_all_done = wait_for_all_done
    self.env_renderer = None   # created lazily elsewhere
    self.agents_done = []
    self.frame_step = 0        # rendering frame counter
    self.name = name
    self.number_of_agents = number_of_agents

    # Track when targets are reached. Only used for correct reward propagation
    # when using wait_for_all_done=True
    self.at_target = dict(
        zip(list(np.arange(self.number_of_agents)),
            [False for _ in range(self.number_of_agents)]))
def test_path_not_exists(rendering=False):
    """On an unconnected rail, the north dead-end must be unreachable from
    the south dead-end."""
    rail, rail_map = make_simple_rail_unconnected()

    tree_obs = TreeObsForRailEnv(max_depth=2,
                                 predictor=ShortestPathPredictorForRailEnv())
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=tree_obs)
    env.reset()

    # start: south dead-end (5, 6) facing north (0);
    # target: north dead-end (0, 3) — expected NOT reachable.
    check_path(env, rail, (5, 6), 0, (0, 3), False)

    if rendering:
        view = RenderTool(env, gl="PILSVG")
        view.render_env(show=True, show_observations=False)
        input("Continue?")
def test_get_entry_directions():
    """get_valid_directions_on_grid must report the correct
    [north, east, south, west] entry flags for each kind of cell."""
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    # (cell, expected [N, E, S, W] entry flags)
    expectations = [
        ((0, 3), [True, False, False, False]),   # north dead end
        ((3, 0), [False, False, False, True]),   # west dead end
        ((3, 3), [False, True, True, True]),     # switch
        ((3, 2), [False, True, False, True]),    # horizontal straight
        ((2, 3), [True, False, True, False]),    # vertical straight
        ((0, 0), [False, False, False, False]),  # empty cell — no entries
    ]
    for position, expected in expectations:
        actual = env.get_valid_directions_on_grid(*position)
        assert actual == expected, "[{},{}] actual={}, expected={}".format(
            *position, actual, expected)
def test_get_shortest_paths_unreachable():
    """When an agent's target sits on a disconnected component,
    get_shortest_paths must map its handle to None."""
    rail, rail_map = make_disconnected_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())
    env.reset()

    # Pin the agent to the west dead-end, targeting the (disconnected)
    # east dead-end.
    walker = env.agents[0]
    walker.position = (3, 1)
    walker.initial_position = (3, 1)
    walker.direction = Grid4TransitionsEnum.WEST
    walker.target = (3, 9)
    walker.moving = True

    # Re-run reset without regenerating rail/schedule so the distance map
    # reflects the hand-placed agent.
    env.reset(False, False)

    expected = {0: None}
    actual = get_shortest_paths(env.distance_map)
    assert actual == expected, "actual={},expected={}".format(actual, expected)
def test_malfanction_from_params():
    """Check that malfunction parameters supplied through
    ``malfunction_from_params`` are exposed unchanged on
    ``env.malfunction_process_data`` after reset.

    NOTE(review): the function name misspells "malfunction"; it is kept
    as-is so existing test selections keep working.
    """
    stochastic_data = MalfunctionParameters(
        malfunction_rate=1000,  # Rate of malfunction occurrence
        min_duration=2,  # Minimal duration of malfunction
        max_duration=5  # Max duration of malfunction
    )

    rail, rail_map = make_simple_rail2()

    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=10,
        malfunction_generator_and_process_data=malfunction_from_params(
            stochastic_data))
    env.reset()

    # The process data must echo the parameters verbatim.
    assert env.malfunction_process_data.malfunction_rate == 1000
    assert env.malfunction_process_data.min_duration == 2
    assert env.malfunction_process_data.max_duration == 5
def test_malfunction_values_and_behavior():
    """Check that an agent's malfunction counter counts down one per step
    and that a fresh malfunction restarts the countdown at the configured
    duration (10 here).
    """
    # Set fixed malfunction duration for this test
    rail, rail_map = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}
    stochastic_data = MalfunctionParameters(
        malfunction_rate=0.001,  # Rate of malfunction occurrence
        min_duration=10,  # Minimal duration of malfunction
        max_duration=10  # Max duration of malfunction
    )
    env = RailEnv(width=25, height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    env.reset(False, False, activate_agents=True, random_seed=10)

    # Expected countdown: 9..0, then a fresh 10-step malfunction starts.
    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
    # (leftover debug `print("[")` removed)
    for time_step in range(15):
        # Move in the env; action_dict stays empty — no actions are issued
        env.step(action_dict)
        # Check that the malfunction counter decreases as expected
        assert env.agents[0].malfunction_data['malfunction'] == assert_list[time_step]
def test_single_malfunction_generator():
    """Check that ``single_malfunction_generator`` produces exactly one
    malfunction per episode, summed over all agents, on every reset.
    """
    rail, rail_map = make_simple_rail2()

    # NOTE: `earlierst_malfunction` reproduces the (misspelled) keyword of the
    # project API and must stay as-is.
    env = RailEnv(width=25, height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=10,
                  malfunction_generator_and_process_data=single_malfunction_generator(
                      earlierst_malfunction=10, malfunction_duration=5))
    for test in range(10):
        env.reset()
        action_dict = dict()
        tot_malfunctions = 0
        # (leftover debug `print(test)` removed)
        for i in range(10):
            for agent in env.agents:
                # Go forward all the time
                action_dict[agent.handle] = RailEnvActions(2)
            env.step(action_dict)
        # Sum the malfunction counts over all agents of this episode
        for agent in env.agents:
            tot_malfunctions += agent.malfunction_data['nr_malfunctions']
        assert tot_malfunctions == 1
def test_global_obs():
    """The global observation's 16 binary transition channels must encode the
    rail grid exactly, and the agent channels must place the agent on a rail
    cell.
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())
    global_obs, info = env.reset()

    # we have to take step for the agent to enter the grid.
    global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})

    assert (global_obs[0][0].shape == rail_map.shape + (16,))

    # Rebuild the uint16 transition map from the 16 binary channels.
    rail_map_recons = np.zeros_like(rail_map)
    for i in range(global_obs[0][0].shape[0]):
        for j in range(global_obs[0][0].shape[1]):
            rail_map_recons[i, j] = int(
                ''.join(global_obs[0][0][i, j].astype(int).astype(str)), 2)

    # BUG FIX: the original `rail_map_recons.all() == rail_map.all()` compared
    # two scalar truth values (both False, since the maps contain empty
    # cells) and therefore passed vacuously.  Compare element-wise instead.
    assert np.array_equal(rail_map_recons, rail_map)

    # If this assertion is wrong, it means that the observation returned
    # places the agent on an empty cell
    obs_agents_state = global_obs[0][1]
    obs_agents_state = obs_agents_state + 1
    assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0)
def test_malfunction_before_entry(): """Tests that malfunctions are working properly for agents before entering the environment!""" # Set fixed malfunction duration for this test stochastic_data = MalfunctionParameters(malfunction_rate=2, # Rate of malfunction occurence min_duration=10, # Minimal duration of malfunction max_duration=10 # Max duration of malfunction ) rail, rail_map = make_simple_rail2() env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=10, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=SingleAgentNavigationObs() ) env.reset(False, False, False, random_seed=10) env.agents[0].target = (0, 0) # Test initial malfunction values for all agents # we want some agents to be malfuncitoning already and some to be working # we want different next_malfunction values for the agents assert env.agents[0].malfunction_data['malfunction'] == 0 assert env.agents[1].malfunction_data['malfunction'] == 10 assert env.agents[2].malfunction_data['malfunction'] == 0 assert env.agents[3].malfunction_data['malfunction'] == 10 assert env.agents[4].malfunction_data['malfunction'] == 10 assert env.agents[5].malfunction_data['malfunction'] == 10 assert env.agents[6].malfunction_data['malfunction'] == 10 assert env.agents[7].malfunction_data['malfunction'] == 10 assert env.agents[8].malfunction_data['malfunction'] == 10 assert env.agents[9].malfunction_data['malfunction'] == 10
def test_path_exists(rendering=False):
    """Several (start, direction, target) triples on the simple rail must
    all be connected."""
    rail, rail_map = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    # (start cell, facing direction, target cell) — each expected reachable.
    # Directions: 0=north, 2=south, 3=west.
    reachable_cases = [
        ((5, 6), 0, (3, 9)),  # north of south dead-end -> east dead-end
        ((6, 6), 2, (3, 9)),  # south dead-end -> east dead-end
        ((3, 0), 3, (0, 3)),  # west dead-end -> north dead-end
        ((5, 6), 0, (1, 3)),
        ((1, 3), 2, (3, 3)),
        ((1, 3), 0, (3, 3)),
    ]
    for start, direction, target in reachable_cases:
        check_path(env, rail, start, direction, target, True)
def test_shortest_path_predictor_conflicts(rendering=False):
    """Two agents on a head-on course must see each other as a predicted
    conflict in their tree observations.

    The agent-state assignments below are order-sensitive setup performed
    *between* the two resets; the final ``env.reset(False, False, True)``
    rebuilds the observations for the hand-placed agents.
    """
    rail, rail_map = make_invalid_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()),
                  )
    env.reset()

    # set the initial position
    agent = env.agents[0]
    agent.initial_position = (5, 6)  # south dead-end
    agent.position = (5, 6)  # south dead-end
    agent.direction = 0  # north
    agent.initial_direction = 0  # north
    agent.target = (3, 9)  # east dead-end
    agent.moving = True
    agent.status = RailAgentStatus.ACTIVE

    agent = env.agents[1]
    agent.initial_position = (3, 8)  # east dead-end
    agent.position = (3, 8)  # east dead-end
    agent.direction = 3  # west
    agent.initial_direction = 3  # west
    agent.target = (6, 6)  # south dead-end
    agent.moving = True
    agent.status = RailAgentStatus.ACTIVE

    observations, info = env.reset(False, False, True)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # get the trees to test
    obs_builder: TreeObsForRailEnv = env.obs_builder
    # NOTE(review): `pp` is created but never used below.
    pp = pprint.PrettyPrinter(indent=4)
    tree_0 = observations[0]
    tree_1 = observations[1]
    env.obs_builder.util_print_obs_subtree(tree_0)
    env.obs_builder.util_print_obs_subtree(tree_1)

    # check the expectations: each agent must see the conflict on the listed
    # (branch, sub-branch) paths of its observation tree
    expected_conflicts_0 = [('F', 'R')]
    expected_conflicts_1 = [('F', 'L')]
    _check_expected_conflicts(expected_conflicts_0, obs_builder, tree_0,
                              "agent[0]: ")
    _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1,
                              "agent[1]: ")
def test_malfunction_process():
    """A malfunctioning agent must not move, and over 100 seeded steps the
    expected number of malfunctions (23) must occur with non-zero total
    downtime.
    """
    # Set fixed malfunction duration for this test
    stochastic_data = MalfunctionParameters(
        malfunction_rate=1,  # Rate of malfunction occurrence
        min_duration=3,  # Minimal duration of malfunction
        max_duration=3  # Max duration of malfunction
    )

    rail, rail_map = make_simple_rail2()

    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        malfunction_generator_and_process_data=malfunction_from_params(
            stochastic_data),
        obs_builder_object=SingleAgentNavigationObs())

    obs, info = env.reset(False, False, True, random_seed=10)

    total_down_time = 0
    agent_old_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)
    for step in range(100):
        # Follow the navigation observation deterministically (argmax of obs)
        actions = {}
        for i in range(len(obs)):
            actions[i] = np.argmax(obs[i]) + 1

        obs, all_rewards, done, _ = env.step(actions)

        agent_malfunctioning = env.agents[0].malfunction_data['malfunction'] > 0
        if agent_malfunctioning:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data[
        'nr_malfunctions'] == 23, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that malfunctioning data was standing around
    assert total_down_time > 0
def test_rail_env_reset():
    """Save an environment with RailEnvPersister and check that reloading it —
    via ``load_new`` and via ``rail_from_file`` / ``schedule_from_file`` under
    different regeneration flags — restores an identical rail grid and agents.
    """
    file_name = "test_rail_env_reset.pkl"

    # Test to save and load file.
    rail, rail_map = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=3,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()))
    env.reset()
    RailEnvPersister.save(env, file_name)

    rails_initial = env.rail.grid
    agents_initial = env.agents

    # Direct reload through the persister
    env2, env2_dict = RailEnvPersister.load_new(file_name)
    assert np.array_equal(rails_initial, env2.rail.grid)
    assert agents_initial == env2.agents

    # Reload through file-based generators; the (regenerate_rail,
    # regenerate_schedule) flags must not change the restored content.
    for regenerate_rail, regenerate_schedule in ((False, True), (True, False)):
        env_loaded = RailEnv(width=1,
                             height=1,
                             rail_generator=rail_from_file(file_name),
                             schedule_generator=schedule_from_file(file_name),
                             number_of_agents=1,
                             obs_builder_object=TreeObsForRailEnv(
                                 max_depth=2,
                                 predictor=ShortestPathPredictorForRailEnv()))
        env_loaded.reset(regenerate_rail, regenerate_schedule, False)
        assert np.array_equal(rails_initial, env_loaded.rail.grid)
        assert agents_initial == env_loaded.agents
def test_malfunction_process_statistically(): """Tests that malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test stochastic_data = MalfunctionParameters( malfunction_rate=1 / 5, # Rate of malfunction occurence min_duration=5, # Minimal duration of malfunction max_duration=5 # Max duration of malfunction ) rail, rail_map = make_simple_rail2() env = RailEnv( width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=10, malfunction_generator_and_process_data=malfunction_from_params( stochastic_data), obs_builder_object=SingleAgentNavigationObs()) env.reset(True, True, False, random_seed=10) env.agents[0].target = (0, 0) # Next line only for test generation # agent_malfunction_list = [[] for i in range(10)] agent_malfunction_list = [ [0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4], [0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2], [0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1], [0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1], [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0], [5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 5], [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2], [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4] ] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} for agent_idx in range(env.get_num_agents()): # We randomly select an action action_dict[agent_idx] = RailEnvActions(np.random.randint(4)) # For generating tests only: # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction']) assert env.agents[agent_idx].malfunction_data[ 'malfunction'] == agent_malfunction_list[agent_idx][step] env.step(action_dict)
def test_rail_from_grid_transition_map():
    """A rail built from an explicit transition map must keep its non-empty
    cell count, and every agent must be placed on a rail cell."""
    rail, rail_map = make_simple_rail()
    agent_count = 3

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=agent_count)
    env.reset(False, False, True)

    # The simple rail consists of exactly 16 non-empty cells.
    assert np.count_nonzero(env.rail.grid) == 16

    # No agent may start on an empty (zero-transition) cell.
    for agent in env.agents:
        assert env.rail.grid[agent.position] != 0

    assert env.get_num_agents() == agent_count
def test_last_malfunction_step(): """ Test to check that agent moves when it is not malfunctioning """ # Set fixed malfunction duration for this test rail, rail_map = make_simple_rail2() env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() env.agents[0].speed_data['speed'] = 1. / 3. env.agents[0].target = (0, 0) env.reset(False, False, True) # Force malfunction to be off at beginning and next malfunction to happen in 2 steps env.agents[0].malfunction_data['next_malfunction'] = 2 env.agents[0].malfunction_data['malfunction'] = 0 env_data = [] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} for agent in env.agents: # Go forward all the time action_dict[agent.handle] = RailEnvActions(2) if env.agents[0].malfunction_data['malfunction'] < 1: agent_can_move = True # Store the position before and after the step pre_position = env.agents[0].speed_data['position_fraction'] _, reward, _, _ = env.step(action_dict) # Check if the agent is still allowed to move in this step if env.agents[0].malfunction_data['malfunction'] > 0: agent_can_move = False post_position = env.agents[0].speed_data['position_fraction'] # Assert that the agent moved while it was still allowed if agent_can_move: assert pre_position != post_position else: assert post_position == pre_position
def __init__(
        self,
        width,
        height,
        rail_generator: RailGenerator = random_rail_generator(),
        schedule_generator: ScheduleGenerator = random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object: ObservationBuilder = TreeObsForRailEnv(max_depth=2),
        max_episode_steps=None,
        stochastic_data=None):
    """Initialise the graph-based environment on top of the base rail env.

    NOTE(review): the generator/obs-builder defaults are *called once* at
    class-definition time, so all instances built without explicit values
    share those objects.  ``max_episode_steps`` and ``stochastic_data`` are
    accepted but not used in this constructor — presumably consumed
    elsewhere; TODO confirm.
    """
    super().__init__(width, height, rail_generator, schedule_generator,
                     number_of_agents, obs_builder_object)
    # Directed low-level graph and undirected high-level graph derived
    # from the rail layout via create_graph_from_env.
    self.graph_low_level = nx.DiGraph()
    self.graph_high_level = nx.Graph()
    self.create_graph_from_env(self.obs_builder)
def test_walker():
    """Distance map on a 1x3 rail (dead ends on both sides): distances from
    each cell/direction to the target at (0, 0) must match the expected
    step counts.
    """
    # _ _ _
    transitions = RailEnvTransitions()
    cells = transitions.transition_list
    dead_end_from_south = cells[7]
    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
    vertical_straight = cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)

    rail_map = np.array(
        [[dead_end_from_east] + [horizontal_straight] + [dead_end_from_west]],
        dtype=np.uint16)
    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0],
                             transitions=transitions)
    rail.grid = rail_map
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
                  )
    env.reset()

    # set initial position and direction for testing...
    env.agents[0].position = (0, 1)
    env.agents[0].direction = 1
    env.agents[0].target = (0, 0)

    # reset to set agents from agents_static
    env.reset(False, False)

    # Distance-map lookups are keyed (handle, row, column, direction).
    # (Debug prints removed — one of them printed direction 3 while the
    # assertion checks direction 1, which was misleading.)
    assert env.distance_map.get()[(0, *[0, 1], 1)] == 3
    assert env.distance_map.get()[(0, *[0, 2], 1)] == 2
def test_schedule_from_file_random():
    """Round-trip a randomly generated environment through disk and check
    that agent count and max episode steps survive the reload."""
    # Different agent types (trains) with different speeds.
    speed_ration_map = {
        1.: 0.25,       # Fast passenger train
        1. / 2.: 0.25,  # Fast freight train
        1. / 3.: 0.25,  # Slow commuter train
        1. / 4.: 0.25,  # Slow freight train
    }

    # Generate and persist a random test env.
    create_and_save_env(file_name="./random_env_test.pkl",
                        rail_generator=random_rail_generator(),
                        schedule_generator=random_schedule_generator(speed_ration_map))

    # Rebuild the environment purely from the saved file.
    random_env_from_file = RailEnv(width=1, height=1,
                                   rail_generator=rail_from_file("./random_env_test.pkl"),
                                   schedule_generator=schedule_from_file("./random_env_test.pkl"))
    random_env_from_file.reset(True, True)

    # Assert loaded agent number is correct
    assert random_env_from_file.get_num_agents() == 10

    # Assert max steps is correct
    assert random_env_from_file._max_episode_steps == 1350
def test_initial_status():
    """Test that agent lifecycle works correctly ready-to-depart -> active -> done.

    Replays a fixed step-by-step script (agent speed 0.5, so two replay
    entries per cell) and checks position, direction, status and reward at
    every step via run_replay_config.
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()),
                  remove_agents_at_target=False)
    env.reset()

    set_penalties_for_replay(env)
    test_config = ReplayConfig(
        replay=[
            Replay(
                position=None,  # not entered grid yet
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.READY_TO_DEPART,
                action=RailEnvActions.DO_NOTHING,
                reward=env.step_penalty * 0.5,
            ),
            Replay(
                position=None,  # not entered grid yet before step
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.READY_TO_DEPART,
                action=RailEnvActions.MOVE_LEFT,
                reward=env.step_penalty * 0.5,  # auto-correction left to forward without penalty!
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.ACTIVE,
                action=RailEnvActions.MOVE_LEFT,
                reward=env.start_penalty + env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.ACTIVE,
                action=None,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                status=RailAgentStatus.ACTIVE,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                status=RailAgentStatus.ACTIVE,
                action=None,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
                status=RailAgentStatus.ACTIVE),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
                status=RailAgentStatus.ACTIVE),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_RIGHT,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
                status=RailAgentStatus.ACTIVE),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.global_reward,  # target reached on this step
                status=RailAgentStatus.ACTIVE),
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.global_reward,  # already done
                status=RailAgentStatus.DONE),
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.global_reward,  # already done
                status=RailAgentStatus.DONE)
        ],
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
        target=(3, 5),
        speed=0.5)

    run_replay_config(env, [test_config], activate_agents=False)
def test_multispeed_actions_no_malfunction_no_blocking():
    """Test that actions are correctly performed on cell exit for a single agent."""
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                       predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    set_penalties_for_replay(env)
    # Agent at speed 0.5 needs two steps per cell; an action is only required
    # on the step that exits a cell (action=None in between).
    test_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 9),  # east dead-end
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_LEFT,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.STOP_MOVING,
                reward=env.stop_penalty + env.step_penalty * 0.5  # stopping and step penalty
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.STOP_MOVING,
                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when stopped
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 0.5  # starting + running at speed 0.5
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(5, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
        ],
        target=(3, 0),  # west dead-end
        speed=0.5,
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [test_config])
def test_multispeed_actions_malfunction_no_blocking():
    """Test on a single agent whether action on cell exit work correctly despite malfunction."""
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                       predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    set_penalties_for_replay(env)
    # Agent at speed 0.5 needs two steps per cell; `set_malfunction` injects a
    # breakdown at that step and `malfunction` is the expected remaining duration.
    test_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 9),  # east dead-end
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            # add additional step in the cell
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                set_malfunction=2,  # recovers in two steps from now!
                malfunction=2,
                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
            ),
            # agent recovers in this step
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                malfunction=1,
                reward=env.step_penalty * 0.5  # recovered: running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                set_malfunction=2,  # recovers in two steps from now!
                malfunction=2,
                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
            ),
            # agent recovers in this step; since we're at the beginning, we provide a different action although we're broken!
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                malfunction=1,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.STOP_MOVING,
                reward=env.stop_penalty + env.step_penalty * 0.5  # stopping and step penalty for speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.STOP_MOVING,
                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 while stopped
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            # DO_NOTHING keeps moving!
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.DO_NOTHING,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 4),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
        ],
        target=(3, 0),  # west dead-end
        speed=0.5,
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [test_config])
def test_multispeed_actions_no_malfunction_blocking():
    """The second agent blocks the first because it is slower."""
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                       predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    set_penalties_for_replay(env)
    # Two scripted agents on the same westbound track: the first (speed 1/3,
    # starting at (3, 8)) is ahead of the second (speed 0.5, starting at the
    # east dead-end (3, 9)), so the faster one must wait behind it.
    test_configs = [
        ReplayConfig(
            replay=[
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.start_penalty + env.step_penalty * 1.0 / 3.0  # starting and running at speed 1/3
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                )
            ],
            target=(3, 0),  # west dead-end
            speed=1 / 3,
            initial_position=(3, 8),
            initial_direction=Grid4TransitionsEnum.WEST,
        ),
        ReplayConfig(
            replay=[
                Replay(
                    position=(3, 9),  # east dead-end
                    direction=Grid4TransitionsEnum.EAST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
                ),
                Replay(
                    position=(3, 9),
                    direction=Grid4TransitionsEnum.EAST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 9),
                    direction=Grid4TransitionsEnum.EAST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_LEFT,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # not blocked, action required!
                Replay(
                    position=(4, 6),
                    direction=Grid4TransitionsEnum.SOUTH,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
            ],
            target=(3, 0),  # west dead-end
            speed=0.5,
            initial_position=(3, 9),  # east dead-end
            initial_direction=Grid4TransitionsEnum.EAST,
        )
    ]

    run_replay_config(env, test_configs)
def test_initial_malfunction_do_nothing():
    """DO_NOTHING while recovering from an initial malfunction must leave the
    agent in place; it only moves once MOVE_FORWARD is issued afterwards."""
    stochastic_data = MalfunctionParameters(malfunction_rate=70,  # Rate of malfunction occurence
                                            min_duration=2,  # Minimal duration of malfunction
                                            max_duration=5  # Max duration of malfunction
                                            )

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  # Malfunction data generator
                  )
    env.reset()
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,  # break down for 3 steps right at departure
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action DO_NOTHING, agent should restart without moving
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we haven't started moving yet --> stay here
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False)
def test_initial_malfunction_stop_moving():
    """STOP_MOVING while recovering from an initial malfunction must leave the
    agent in place; it only moves once MOVE_FORWARD is issued afterwards."""
    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=SingleAgentNavigationObs())
    env.reset()

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,  # break down for 3 steps right at departure
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False)
def test_initial_malfunction():
    """An agent that starts in a malfunctioning state must stand still until
    the malfunction is over, then start moving normally on MOVE_FORWARD."""
    stochastic_data = MalfunctionParameters(malfunction_rate=1000,  # Rate of malfunction occurence
                                            min_duration=2,  # Minimal duration of malfunction
                                            max_duration=5  # Max duration of malfunction
                                            )

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=10),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, True, random_seed=10)
    env.agents[0].target = (0, 5)
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,  # break down for 3 steps on the first move
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunctioning ends: starting and running at speed 1.0
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config])