def test_single_malfunction_generator():
    """
    Verify that ``single_malfunction_generator`` yields one malfunction per episode.

    The environment is reset ten times.  After each reset every agent is sent
    ``RailEnvActions(2)`` for ten steps, and afterwards the ``nr_malfunctions``
    counters of all agents must add up to exactly 1.

    Returns
    -------
    None
    """
    rail, rail_map = make_simple_rail2()
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=10,
        malfunction_generator_and_process_data=single_malfunction_generator(
            earlierst_malfunction=10, malfunction_duration=5))

    n_episodes = 10
    n_steps = 10
    for episode in range(n_episodes):
        env.reset()
        print(episode)

        # Every agent receives the same command on every step.
        action_dict = {}
        for _ in range(n_steps):
            action_dict.update(
                (agent.handle, RailEnvActions(2)) for agent in env.agents)
            env.step(action_dict)

        # Malfunctions are counted per agent; add them up over the episode.
        tot_malfunctions = sum(
            agent.malfunction_data['nr_malfunctions'] for agent in env.agents)
        assert tot_malfunctions == 1
示例#2
0
def test_path_not_exists(rendering=False):
    """
    Check that ``check_path`` reports no path on the unconnected simple rail.

    Parameters
    ----------
    rendering : bool
        When true, the environment is shown with ``RenderTool`` and the test
        blocks on ``input`` until the user confirms.
    """
    rail, rail_map = make_simple_rail_unconnected()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
    )
    env.reset()

    start = (5, 6)   # south dead-end
    heading = 0      # north
    goal = (0, 3)    # north dead-end
    check_path(env, rail, start, heading, goal, False)

    if not rendering:
        return
    renderer = RenderTool(env, gl="PILSVG")
    renderer.render_env(show=True, show_observations=False)
    input("Continue?")
def test_seeding_and_observations():
    """
    Check that two identically seeded envs stay in sync with different observation builders.

    One env uses ``GlobalObsForRailEnv`` and the other ``TreeObsForRailEnv``.
    Both are reset with ``random_seed=12`` and stepped with the same random
    actions.  The first ten agents must share initial positions and, after
    ten steps, current positions.
    """
    rail, rail_map = make_simple_rail2()
    shared_kwargs = dict(width=25, height=30, number_of_agents=10)
    n_compared = 10

    # Env observed through the global observation builder.
    env = RailEnv(rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=12),
                  obs_builder_object=GlobalObsForRailEnv(),
                  **shared_kwargs)
    # Env observed through the tree observation builder.
    env2 = RailEnv(rail_generator=rail_from_grid_transition_map(rail),
                   schedule_generator=random_schedule_generator(seed=12),
                   obs_builder_object=TreeObsForRailEnv(
                       max_depth=2,
                       predictor=ShortestPathPredictorForRailEnv()),
                   **shared_kwargs)

    env.reset(False, False, False, random_seed=12)
    env2.reset(False, False, False, random_seed=12)

    # Both environments must start the agents at the same cells.
    for handle in range(n_compared):
        assert env.agents[handle].initial_position == env2.agents[handle].initial_position

    # Feed the very same random actions to both environments.
    action_dict = {}
    for step in range(10):
        for a in range(env.get_num_agents()):
            action_dict[a] = np.random.randint(4)
        env.step(action_dict)
        env2.step(action_dict)

    # Both environments must end up with the agents at the same cells.
    for handle in range(n_compared):
        assert env.agents[handle].position == env2.agents[handle].position

    # Emit one assertion template line per agent.
    template = "assert env.agents[{}].position == env2.agents[{}].position"
    for a in range(env.get_num_agents()):
        print(template.format(a, a))
def test_random_seeding():
    """
    Check that resetting with a fixed seed reproduces the same initial positions.

    A fresh environment is built and reset with ``random_seed=1`` one hundred
    times.  After ten steps in which agent 0 is sent action ``2``, the
    ``initial_position`` of each of the ten agents is compared against a
    fixed expected value.

    The comparisons used to be bare ``==`` expressions whose result was
    discarded, so the test could never fail; they are real assertions now.
    """
    rail, rail_map = make_simple_rail2()

    # Expected initial position per agent handle for random_seed=1.
    expected_initial_positions = [
        (3, 2),
        (3, 5),
        (3, 6),
        (5, 6),
        (3, 4),
        (3, 1),
        (3, 9),
        (4, 6),
        (0, 3),
        (3, 7),
    ]

    for idx in range(100):
        env = RailEnv(width=25,
                      height=30,
                      rail_generator=rail_from_grid_transition_map(rail),
                      schedule_generator=random_schedule_generator(seed=12),
                      number_of_agents=10)
        env.reset(True, True, False, random_seed=1)

        # Move target to unreachable position in order to not interfere with test
        env.agents[0].target = (0, 0)
        for step in range(10):
            env.step({0: 2})

        for handle, expected in enumerate(expected_initial_positions):
            actual = env.agents[handle].initial_position
            assert actual == expected, \
                "run {}: agent {} initial_position actual={}, expected={}".format(
                    idx, handle, actual, expected)
示例#5
0
def create_testfiles(parameters, test_nr=0, nr_trials_per_test=100):
    """
    Generate a batch of levels and save each one to ``./Tests/<test_nr>/``.

    Parameters
    ----------
    parameters
        Indexable sequence: ``[0]`` env width (x_dim), ``[1]`` env height
        (y_dim), ``[2]`` number of agents, ``[3]`` random seed.
    test_nr
        Identifier used in the progress message and in the output directory.
    nr_trials_per_test
        Number of times the env is reset and saved.

    Returns
    -------
    None
    """
    x_dim = parameters[0]
    y_dim = parameters[1]
    n_agents = parameters[2]
    print('Creating {} with (x_dim,y_dim) = ({},{}) and {} Agents.'.format(
        test_nr, x_dim, y_dim, n_agents))

    # Seed both random sources before the environment is built.
    seed = parameters[3]
    random.seed(seed)
    np.random.seed(seed)

    nr_paths = max(4, n_agents + int(0.5 * n_agents))
    min_dist = int(min(x_dim, y_dim) * 0.75)
    env = RailEnv(width=x_dim,
                  height=y_dim,
                  rail_generator=complex_rail_generator(nr_start_goal=nr_paths,
                                                        nr_extra=5,
                                                        min_dist=min_dist,
                                                        max_dist=99999,
                                                        seed=seed),
                  schedule_generator=complex_schedule_generator(),
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
                  number_of_agents=n_agents)

    def _report(completed):
        # Single place for the progress-bar layout.
        printProgressBar(completed,
                         nr_trials_per_test,
                         prefix='Progress:',
                         suffix='Complete',
                         length=20)

    _report(0)
    for trial in range(nr_trials_per_test):
        # Each reset produces the level that is written for this trial.
        env.reset(True, True)
        env.save("./Tests/{}/Level_{}.pkl".format(test_nr, trial))
        _report(trial + 1)
def test_malfanction_from_params():
    """
    Check that ``malfunction_from_params`` hands its parameters to the env.

    After a reset, ``env.malfunction_process_data`` must expose the same
    rate and min/max duration that were put into ``MalfunctionParameters``.

    Returns
    -------
    None
    """
    rate, shortest, longest = 1000, 2, 5
    stochastic_data = MalfunctionParameters(
        malfunction_rate=rate,
        min_duration=shortest,
        max_duration=longest,
    )
    rail, rail_map = make_simple_rail2()

    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=10,
        malfunction_generator_and_process_data=malfunction_from_params(
            stochastic_data))
    env.reset()

    process_data = env.malfunction_process_data
    assert process_data.malfunction_rate == 1000
    assert process_data.min_duration == 2
    assert process_data.max_duration == 5
示例#7
0
def test_get_shortest_paths_unreachable():
    """
    Check that ``get_shortest_paths`` maps an agent with no route to ``None``.

    A single agent is placed on the disconnected simple rail with start and
    target on opposite dead ends; the expected result is ``{0: None}``.
    """
    rail, rail_map = make_disconnected_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())
    env.reset()

    # Put the only agent on the west dead-end, heading west, aiming east.
    west_dead_end = (3, 1)
    east_dead_end = (3, 9)
    agent = env.agents[0]
    agent.position = west_dead_end
    agent.initial_position = west_dead_end
    agent.direction = Grid4TransitionsEnum.WEST
    agent.target = east_dead_end
    agent.moving = True

    # Second reset keeps rail and schedule (both regenerate flags are False).
    env.reset(False, False)

    expected = {0: None}
    actual = get_shortest_paths(env.distance_map)

    assert actual == expected, "actual={},expected={}".format(actual, expected)
def test_get_entry_directions():
    """
    Check ``env.get_valid_directions_on_grid`` for representative cells of the simple rail.

    Each case pairs a grid position with the expected list of four booleans.
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    cases = [
        ((0, 3), [True, False, False, False]),   # north dead end
        ((3, 0), [False, False, False, True]),   # west dead end
        ((3, 3), [False, True, True, True]),     # switch
        ((3, 2), [False, True, False, True]),    # horizontal
        ((2, 3), [True, False, True, False]),    # vertical
        ((0, 0), [False, False, False, False]),  # nowhere
    ]

    for position, expected in cases:
        actual = env.get_valid_directions_on_grid(*position)
        assert actual == expected, "[{},{}] actual={}, expected={}".format(
            *position, actual, expected)
示例#9
0
def test_save_load():
    """
    Round-trip an env through ``RailEnvPersister`` and compare agent state.

    Position, direction and target of the two agents are recorded before
    saving and compared with the env returned by ``RailEnvPersister.load_new``.
    The env is also written once through ``env.save``.
    """
    env = RailEnv(width=10, height=10,
                  rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
                  schedule_generator=complex_schedule_generator(), number_of_agents=2)
    env.reset()

    # Snapshot (position, direction, target) for agents 0 and 1.
    snapshots = [
        (env.agents[idx].position, env.agents[idx].direction, env.agents[idx].target)
        for idx in (0, 1)
    ]

    os.makedirs("tmp", exist_ok=True)

    RailEnvPersister.save(env, "tmp/test_save.pkl")
    env.save("tmp/test_save_2.pkl")

    # From here on ``env`` is the freshly loaded environment.
    env, env_dict = RailEnvPersister.load_new("tmp/test_save.pkl")
    assert env.width == 10
    assert env.height == 10
    assert len(env.agents) == 2
    for idx, (saved_pos, saved_dir, saved_tar) in enumerate(snapshots):
        assert saved_pos == env.agents[idx].position
        assert saved_dir == env.agents[idx].direction
        assert saved_tar == env.agents[idx].target
示例#10
0
def gen_env(number_agents, width, height, n_start_goal, seed):
    """
    Build a complex-rail env, reset it and advance it by one step.

    Parameters
    ----------
    number_agents
        Number of agents; also the number of entries in the first action dict.
    width, height
        Grid dimensions passed to ``RailEnv``.
    n_start_goal
        Forwarded as ``nr_start_goal`` to ``complex_rail_generator``.
    seed
        Seed forwarded to ``complex_rail_generator``.

    Returns
    -------
    The environment after one ``reset`` and one ``step`` in which every
    agent handle ``0 .. number_agents - 1`` is given action ``2``.
    """
    # Four speed classes, each mapped to 0.25.
    speed_classes = {
        1.: 0.25,       # Fast passenger train
        1. / 2.: 0.25,  # Fast freight train
        1. / 3.: 0.25,  # Slow commuter train
        1. / 4.: 0.25,  # Slow freight train
    }

    rail_generator = complex_rail_generator(nr_start_goal=n_start_goal,
                                            nr_extra=3,
                                            min_dist=6,
                                            max_dist=99999,
                                            seed=seed)
    schedule_generator = complex_schedule_generator(
        speed_ratio_map=speed_classes)
    env = RailEnv(width=width,
                  height=height,
                  rail_generator=rail_generator,
                  schedule_generator=schedule_generator,
                  number_of_agents=number_agents,
                  obs_builder_object=TreeObsForRailEnv(max_depth=5))

    env.reset()
    first_actions = {handle: 2 for handle in range(number_agents)}
    env.step(first_actions)

    return env
示例#11
0
def test_malfunction_before_entry():
    """Tests that malfunctions are working properly for agents before entering the environment!"""
    # min_duration == max_duration, so every malfunction lasts 10 steps.
    stochastic_data = MalfunctionParameters(malfunction_rate=2,
                                            min_duration=10,
                                            max_duration=10)

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=10,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    env.reset(False, False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Expected 'malfunction' value per agent handle right after the reset:
    # a mix of working (0) and already malfunctioning (10) agents.
    expected_malfunction = [0, 10, 0, 10, 10, 10, 10, 10, 10, 10]
    for handle, expected in enumerate(expected_malfunction):
        assert env.agents[handle].malfunction_data['malfunction'] == expected
def test_save_load():
    """
    Round-trip an env through ``env.save`` / ``env.load`` and compare agent state.

    Position, direction and target of the two agents are recorded before
    saving to ``test_save.dat`` and compared after loading the file back
    into the same env object.
    """
    rail_generator = complex_rail_generator(nr_start_goal=2,
                                            nr_extra=5,
                                            min_dist=6,
                                            seed=1)
    env = RailEnv(width=10,
                  height=10,
                  rail_generator=rail_generator,
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=2)
    env.reset()

    # Snapshot (position, direction, target) for agents 0 and 1.
    snapshots = [
        (env.agents[idx].position, env.agents[idx].direction, env.agents[idx].target)
        for idx in (0, 1)
    ]

    env.save("test_save.dat")
    env.load("test_save.dat")

    assert env.width == 10
    assert env.height == 10
    assert len(env.agents) == 2
    for idx, (saved_pos, saved_dir, saved_tar) in enumerate(snapshots):
        assert saved_pos == env.agents[idx].position
        assert saved_dir == env.agents[idx].direction
        assert saved_tar == env.agents[idx].target
示例#13
0
def render_test(parameters, test_nr=0, nr_examples=5):
    """
    Load saved levels from ``./Tests/<test_nr>/`` and display each one briefly.

    Parameters
    ----------
    parameters
        Indexable sequence; elements ``[0]``, ``[1]`` and ``[2]`` are only
        used for the printed message (x_dim, y_dim, number of agents).
    test_nr
        Identifier of the test directory to read levels from.
    nr_examples
        Number of levels (``Level_0`` .. ``Level_<nr_examples - 1>``) to show.

    Returns
    -------
    None
    """
    message = 'Showing {} Level {} with (x_dim,y_dim) = ({},{}) and {} Agents.'
    for trial in range(nr_examples):
        print(message.format(test_nr, trial, parameters[0], parameters[1],
                             parameters[2]))
        file_name = "./Tests/{}/Level_{}.pkl".format(test_nr, trial)

        # Width, height and agent count are given as 1 here; the rail itself
        # comes from the file-based generator.
        env = RailEnv(
            width=1,
            height=1,
            rail_generator=rail_from_file(file_name),
            obs_builder_object=TreeObsForRailEnv(max_depth=2),
            number_of_agents=1,
        )
        env_renderer = RenderTool(env, gl="PILSVG")
        env_renderer.set_new_rail()

        env.reset(False, False)
        env_renderer.render_env(show=True, show_observations=False)

        # Keep the window up for a moment before closing it.
        time.sleep(0.1)
        env_renderer.close_window()
示例#14
0
def test_malfunction_values_and_behavior():
    """
    Check that the malfunction counter of a single agent counts down as expected.

    The env is stepped fifteen times with an empty action dict; after each
    step the agent's ``malfunction_data['malfunction']`` value must match
    the corresponding entry of the expected sequence.

    Returns
    -------
    None
    """
    rail, rail_map = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}
    # min_duration == max_duration, so every malfunction lasts 10 steps.
    stochastic_data = MalfunctionParameters(malfunction_rate=0.001,
                                            min_duration=10,
                                            max_duration=10)
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    env.reset(False, False, activate_agents=True, random_seed=10)

    # Expected counter value after each step; only the first 15 are checked.
    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
    n_steps = 15
    print("[")
    for expected in assert_list[:n_steps]:
        env.step(action_dict)
        assert env.agents[0].malfunction_data['malfunction'] == expected
示例#15
0
def test_path_exists(rendering=False):
    """
    Check that ``check_path`` reports a path for several routes on the simple rail.

    Parameters
    ----------
    rendering : bool
        Accepted for signature parity with the sibling path tests; not used
        in this function.
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
    )
    env.reset()

    # (start position, start direction, end position); every route is
    # expected to exist.  Directions used here: 0 north, 2 south, 3 west.
    routes = [
        ((5, 6), 0, (3, 9)),
        ((6, 6), 2, (3, 9)),
        ((3, 0), 3, (0, 3)),
        ((5, 6), 0, (1, 3)),
        ((1, 3), 2, (3, 3)),
        ((1, 3), 0, (3, 3)),
    ]
    for start, heading, end in routes:
        check_path(env, rail, start, heading, end, True)
示例#16
0
def load_env(env_dict, obs_builder_object=None):
    """
    Build a ``RailEnv`` and populate it from a saved state.

    Parameters
    ----------
    env_dict
        State handed to ``RailEnvPersister.set_full_state``.
    obs_builder_object
        Observation builder passed to ``RailEnv``.  When omitted (``None``),
        a new ``GlobalObsForRailEnv`` is created for this call.  A default
        instance built in the signature would be created once at definition
        time and shared by every environment loaded without an explicit
        builder.

    Returns
    -------
    The environment after ``set_full_state`` has been applied.
    """
    if obs_builder_object is None:
        obs_builder_object = GlobalObsForRailEnv()

    # 4x4 is only the construction-time size.
    # NOTE(review): presumably overwritten by set_full_state — confirm.
    env = RailEnv(height=4, width=4, obs_builder_object=obs_builder_object)
    env.reset(regenerate_rail=False, regenerate_schedule=False)
    RailEnvPersister.set_full_state(env, env_dict)
    return env
def test_load_env():
    """Load a packaged 10x10 env resource, add one agent and check the agent count."""
    env = RailEnv(10, 10)
    env.reset()
    env.load_resource('env_data.tests', 'test-10x10.mpk')

    # After adding a single agent the env must report exactly one.
    new_agent = EnvAgent((0, 0), 2, (5, 5), False)
    env.add_agent(new_agent)
    n_agents_now = env.get_num_agents()
    assert n_agents_now == 1
示例#18
0
def test_random_rail_generator():
    """Check grid shape and agent count of an env built with ``random_rail_generator``."""
    x_dim, y_dim = 5, 10
    n_agents = 1

    env = RailEnv(width=x_dim,
                  height=y_dim,
                  rail_generator=random_rail_generator(),
                  number_of_agents=n_agents)
    env.reset()

    # The grid is indexed (row, column), i.e. (height, width).
    expected_shape = (y_dim, x_dim)
    assert env.rail.grid.shape == expected_shape
    assert env.get_num_agents() == n_agents
def test_shortest_path_predictor_conflicts(rendering=False):
    """
    Check the conflicts found in the tree observations of two active agents.

    Two agents are placed by hand on the invalid simple rail, the env is
    reset without regenerating rail or schedule, and the resulting
    observation trees are compared against the expected conflict branches.

    Parameters
    ----------
    rendering : bool
        When true, the environment is shown with ``RenderTool`` and the test
        blocks on ``input`` until the user confirms.
    """
    rail, rail_map = make_invalid_simple_rail()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=2,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
    )
    env.reset()

    def _place_active(agent, cell, heading, target):
        # Current and initial state are set to the same values.
        agent.initial_position = cell
        agent.position = cell
        agent.direction = heading
        agent.initial_direction = heading
        agent.target = target
        agent.moving = True
        agent.status = RailAgentStatus.ACTIVE

    # Agent 0 heads north (0) towards (3, 9); agent 1 heads west (3)
    # towards (6, 6).
    _place_active(env.agents[0], (5, 6), 0, (3, 9))
    _place_active(env.agents[1], (3, 8), 3, (6, 6))

    observations, info = env.reset(False, False, True)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # Fetch both observation trees and print them.
    obs_builder: TreeObsForRailEnv = env.obs_builder
    pp = pprint.PrettyPrinter(indent=4)  # not used below; kept as in the original
    tree_0, tree_1 = observations[0], observations[1]
    obs_builder.util_print_obs_subtree(tree_0)
    obs_builder.util_print_obs_subtree(tree_1)

    # Compare each tree against its expected conflicts.
    expectations = (
        ([('F', 'R')], tree_0, "agent[0]: "),
        ([('F', 'L')], tree_1, "agent[1]: "),
    )
    for expected_conflicts, tree, label in expectations:
        _check_expected_conflicts(expected_conflicts, obs_builder, tree, label)
示例#20
0
def main():
    """Build a small random env with ``SimpleObs``, take one step and print each agent's observation."""
    env = RailEnv(width=7,
                  height=7,
                  rail_generator=random_rail_generator(),
                  number_of_agents=3,
                  obs_builder_object=SimpleObs())
    env.reset()

    # A single step in which only agent 0 is given an action (0).
    first_actions = {0: 0}
    obs, all_rewards, done, _ = env.step(first_actions)

    n_agents = env.get_num_agents()
    for handle in range(n_agents):
        print("Agent ", handle, "'s observation: ", obs[handle])
示例#21
0
def test_render_env(save_new_images=False):
    """
    Render a stored transition map with two backends and compare against frozen images.

    Parameters
    ----------
    save_new_images : bool
        Forwarded as ``resave`` to ``checkFrozenImage``.
    """
    np.random.seed(100)
    env = RailEnv(width=10,
                  height=10,
                  rail_generator=empty_rail_generator(),
                  number_of_agents=0,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2))
    env.reset()
    env.rail.load_transition_map('env_data.tests', "test1.npy")

    # First backend: "PILSVG", rendered without showing a window.
    svg_renderer = rt.RenderTool(env, gl="PILSVG")
    svg_renderer.render_env(show=False)
    checkFrozenImage(svg_renderer, "basic-env.npz", resave=save_new_images)

    # Second backend: "PIL", rendered with default arguments.
    pil_renderer = rt.RenderTool(env, gl="PIL")
    pil_renderer.render_env()
    checkFrozenImage(pil_renderer, "basic-env-PIL.npz", resave=save_new_images)
示例#22
0
def test_empty_rail_generator():
    """
    Check an env built with ``empty_rail_generator``.

    The grid must have the requested shape and contain only zeros, and the
    env must report zero agents although one was requested.
    """
    x_dim, y_dim = 5, 10
    n_agents = 1

    env = RailEnv(width=x_dim,
                  height=y_dim,
                  rail_generator=empty_rail_generator(),
                  number_of_agents=n_agents)
    env.reset()

    grid = env.rail.grid
    # Shape is (rows, columns), i.e. (height, width).
    assert grid.shape == (y_dim, x_dim)
    # Not a single non-zero cell may exist.
    assert np.count_nonzero(grid) == 0
    # No agents were placed.
    assert env.get_num_agents() == 0
示例#23
0
    def __init__(self, env=None, sGL="PIL", env_filename="temp.pkl"):
        """
        Wire up the editor model, view and controller around a rail env.

        env          -- env to edit; when None, an empty 10x10 env with no
                        agents and a depth-2 tree observation is created.
        sGL          -- forwarded to the View.
        env_filename -- forwarded to the EditorModel.

        The env is reset in either case, including one passed in by the caller.
        """
        if env is None:
            env = RailEnv(width=10,
                          height=10,
                          rail_generator=empty_rail_generator(),
                          number_of_agents=0,
                          obs_builder_object=TreeObsForRailEnv(max_depth=2))

        env.reset()

        self.editor = EditorModel(env, env_filename=env_filename)

        # The view is shared between the model and this object.
        view = View(self.editor, sGL=sGL)
        self.editor.view = view
        self.view = view

        # The controller is shared between view, model and this object.
        controller = Controller(self.editor, self.view)
        self.view.controller = controller
        self.editor.controller = controller
        self.controller = controller

        self.view.init_canvas()
        self.view.init_widgets()  # has to be done after controller
示例#24
0
def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!"""
    # min_duration == max_duration, so every malfunction lasts 5 steps.
    stochastic_data = MalfunctionParameters(
        malfunction_rate=1 / 5,
        min_duration=5,
        max_duration=5,
    )

    rail, rail_map = make_simple_rail2()

    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=10,
        malfunction_generator_and_process_data=malfunction_from_params(
            stochastic_data),
        obs_builder_object=SingleAgentNavigationObs())

    env.reset(True, True, False, random_seed=10)

    env.agents[0].target = (0, 0)

    # Expected 'malfunction' value per agent (row) before each step (column).
    agent_malfunction_list = [
        [0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4],
        [0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2],
        [0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
        [0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
        [5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 5],
        [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2],
        [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4],
    ]

    n_steps = 20
    for step in range(n_steps):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent_idx in range(env.get_num_agents()):
            # Draw a random action for this agent, then check its counter
            # before the env is advanced.
            action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
            observed = env.agents[agent_idx].malfunction_data['malfunction']
            assert observed == agent_malfunction_list[agent_idx][step]
        env.step(action_dict)
示例#25
0
    def _launch(self):
        """
        Build and reset a ``RailEnv`` from ``self._config``.

        Malfunctions are enabled only when all three malfunction keys are
        present in the config; otherwise ``no_malfunction_generator`` is used.
        An optional ``speed_ratio_map`` entry is converted to float keys and
        values.

        Returns the environment.  A ``ValueError`` raised while creating or
        resetting the env is logged instead of propagated; in that case the
        return value is whatever ``env`` held when the error occurred
        (``None`` if the constructor itself failed).
        """
        cfg = self._config

        rail_generator = sparse_rail_generator(
            seed=cfg['seed'],
            max_num_cities=cfg['max_num_cities'],
            grid_mode=cfg['grid_mode'],
            max_rails_between_cities=cfg['max_rails_between_cities'],
            max_rails_in_city=cfg['max_rails_in_city'])

        # Default: no malfunctions; replaced below when fully configured.
        malfunction_generator = no_malfunction_generator()
        malfunction_keys = {
            'malfunction_rate',
            'malfunction_min_duration',
            'malfunction_max_duration',
        }
        if malfunction_keys <= cfg.keys():
            malfunction_generator = malfunction_from_params({
                'malfunction_rate': cfg['malfunction_rate'],
                'min_duration': cfg['malfunction_min_duration'],
                'max_duration': cfg['malfunction_max_duration'],
            })

        speed_ratio_map = None
        if 'speed_ratio_map' in cfg:
            speed_ratio_map = {}
            for raw_speed, raw_ratio in cfg['speed_ratio_map'].items():
                speed_ratio_map[float(raw_speed)] = float(raw_ratio)
        schedule_generator = sparse_schedule_generator(speed_ratio_map)

        env = None
        try:
            env = RailEnv(
                width=cfg['width'],
                height=cfg['height'],
                rail_generator=rail_generator,
                schedule_generator=schedule_generator,
                number_of_agents=cfg['number_of_agents'],
                malfunction_generator_and_process_data=malfunction_generator,
                obs_builder_object=self._observation.builder(),
                remove_agents_at_target=False,
                random_seed=cfg['seed'])

            env.reset()
        except ValueError as e:
            separator = "=" * 50
            logging.error(separator)
            logging.error(f"Error while creating env: {e}")
            logging.error(separator)

        return env
示例#26
0
def test_rail_from_grid_transition_map():
    """
    Check an env built from the simple rail's grid transition map.

    The grid must contain 16 non-empty cells, every agent must stand on a
    non-empty cell, and the env must hold the requested number of agents.
    """
    rail, rail_map = make_simple_rail()
    n_agents = 3
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=n_agents)
    env.reset(False, False, True)

    # Count the cells that carry any transition bits.
    nr_rail_elements = np.count_nonzero(env.rail.grid)
    assert nr_rail_elements == 16

    # Each agent's position must index a non-empty cell.
    for agent in env.agents:
        assert env.rail.grid[agent.position] != 0

    assert env.get_num_agents() == n_agents
示例#27
0
# File: cli.py  Project: hagrid67/flatland
def demo(args=None):
    """
    Demo script to check installation.

    Builds a 15x15 complex-rail env with five agents and keeps running
    episodes with random actions, rendering every step.  The outer loop has
    no exit condition, so the trailing ``return 0`` is never reached.

    Parameters
    ----------
    args
        Not used by this function.
    """
    rail_generator = complex_rail_generator(nr_start_goal=10,
                                            nr_extra=1,
                                            min_dist=8,
                                            max_dist=99999)
    env = RailEnv(width=15,
                  height=15,
                  rail_generator=rail_generator,
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=5)

    env._max_episode_steps = int(15 * (env.width + env.height))
    env_renderer = RenderTool(env)

    while True:
        obs, info = env.reset()
        step = 0
        episode_over = False
        # One episode: act randomly until the env reports '__all__' done.
        while not episode_over:
            random_actions = {
                handle: np.random.randint(0, 5)
                for handle, _ in enumerate(env.agents)
            }
            obs, all_rewards, done, _ = env.step(random_actions)
            episode_over = done['__all__']
            step += 1
            env_renderer.render_env(show=True,
                                    frames=False,
                                    show_observations=False,
                                    show_predictions=False)
            time.sleep(0.3)
    return 0
示例#28
0
def test_global_obs():
    """
    Check the global observation of a single agent on the simple rail.

    Verifies that the first observation component has the rail map's shape
    plus a trailing axis of 16 bits, that those bits decode back to the rail
    map cell by cell, and that the agent-state component overlaps the rail.

    The reconstruction check used to compare ``a.all() == b.all()``, i.e.
    two scalar booleans; it now compares the arrays element-wise.
    """
    rail, rail_map = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())

    global_obs, info = env.reset()

    # we have to take step for the agent to enter the grid.
    global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})

    transition_obs = global_obs[0][0]
    assert transition_obs.shape == rail_map.shape + (16, )

    # Decode each cell's 16 bits back into the integer transition value.
    rail_map_recons = np.zeros_like(rail_map)
    for i in range(transition_obs.shape[0]):
        for j in range(transition_obs.shape[1]):
            bits = ''.join(transition_obs[i, j].astype(int).astype(str))
            rail_map_recons[i, j] = int(bits, 2)

    # Element-wise comparison of the reconstructed and the original map.
    assert np.array_equal(rail_map_recons, rail_map)

    # If this assertion is wrong, it means that the observation returned
    # places the agent on an empty cell
    obs_agents_state = global_obs[0][1]
    obs_agents_state = obs_agents_state + 1
    assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0)
示例#29
0
def test_dead_end():
    """Build dead-end railways in a horizontal and a vertical layout.

    Each layout is a single line of five cells: a dead end, three straight
    cells and another dead end. For every layout an environment is created
    and, twice in a row, reset and given one hand-placed agent in the middle
    cell, facing each of the two possible directions.

    NOTE: this function only constructs the environments and agents; it
    performs no environment steps and contains no assertions.
    """
    transitions = RailEnvTransitions()

    straight_vertical = int('1000000000100000', 2)  # Case 1 - straight
    straight_horizontal = transitions.rotate_transition(straight_vertical, 90)
    dead_end_from_south = int('0010000000000000', 2)  # Case 7 - dead end

    def build_env(grid):
        # Wrap the raw grid in a transition map and create a one-agent env.
        transition_map = GridTransitionMap(width=grid.shape[1],
                                           height=grid.shape[0],
                                           transitions=transitions)
        transition_map.grid = grid
        return RailEnv(
            width=grid.shape[1],
            height=grid.shape[0],
            rail_generator=rail_from_grid_transition_map(transition_map),
            schedule_generator=random_schedule_generator(),
            number_of_agents=1,
            obs_builder_object=GlobalObsForRailEnv())

    def place_agent(target_env, position, heading, target):
        # Reset first, then overwrite the agent list with a single agent.
        target_env.reset()
        target_env.agents = [
            EnvAgent(initial_position=position,
                     initial_direction=heading,
                     direction=heading,
                     target=target,
                     moving=False)
        ]

    # We instantiate the following railway
    # O->-- where > is the train and O the target. After 6 steps,
    # the train should be done.
    left_end = transitions.rotate_transition(dead_end_from_south, 270)
    right_end = transitions.rotate_transition(dead_end_from_south, 90)
    horizontal_map = np.array(
        [[left_end, straight_horizontal, straight_horizontal,
          straight_horizontal, right_end]],
        dtype=np.uint16)
    rail_env = build_env(horizontal_map)

    # We try the configuration in the 4 directions:
    place_agent(rail_env, (0, 2), 1, (0, 0))
    place_agent(rail_env, (0, 2), 3, (0, 4))

    # In the vertical configuration:
    bottom_end = transitions.rotate_transition(dead_end_from_south, 180)
    vertical_map = np.array(
        [[dead_end_from_south],
         [straight_vertical],
         [straight_vertical],
         [straight_vertical],
         [bottom_end]],
        dtype=np.uint16)
    rail_env = build_env(vertical_map)

    place_agent(rail_env, (2, 0), 2, (0, 0))
    place_agent(rail_env, (2, 0), 0, (4, 0))
def test_multi_speed_init():
    """Check that agents with different speeds keep position between updates.

    Agent ``i`` is assigned the speed ``1 / (i + 1)``. Before every
    environment step each agent's current position must equal the position
    last recorded for it; the record for agent ``i`` is refreshed after every
    step whose 1-based index is a multiple of ``i + 1``.
    """
    rail_generator = complex_rail_generator(nr_start_goal=10,
                                            nr_extra=1,
                                            min_dist=8,
                                            max_dist=99999,
                                            seed=1)
    env = RailEnv(width=50,
                  height=50,
                  rail_generator=rail_generator,
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=5)

    # Random policy used to pick an action for every agent.
    agent = RandomAgent(218, 4)

    # Actions of all agents, keyed by agent index.
    action_dict = {}

    # Reset the environment before assigning the per-agent speeds.
    env.reset(False, False, True)

    # Give each agent its own speed and remember where it starts.
    old_pos = []
    for i_agent in range(env.get_num_agents()):
        env_agent = env.agents[i_agent]
        env_agent.speed_data['speed'] = 1. / (i_agent + 1)
        old_pos.append(env_agent.position)

    # Run episode
    for step in range(100):

        for a in range(env.get_num_agents()):
            # Pick an action for this agent.
            action_dict[a] = agent.act(0)

            # The agent must not have moved in between its speed updates.
            assert old_pos[a] == env.agents[a].position

        # Advance the environment; the four returned values are not needed.
        _, _, _, _ = env.step(action_dict)

        # Refresh the recorded position whenever an agent was allowed to move
        for i_agent in range(env.get_num_agents()):
            if (step + 1) % (i_agent + 1) == 0:
                current_position = env.agents[i_agent].position
                print(step, i_agent, current_position)
                old_pos[i_agent] = current_position