def test_save_load():
    """Round-trip a RailEnv through ``env.save`` / ``env.load``.

    Records each agent's position, direction and target after ``reset``,
    saves the environment, loads the file back, and checks that the grid
    size, the agent count and the recorded agent state are unchanged.
    The file is written to a temporary directory so the test leaves no
    artefact in the working directory.
    """
    import os
    import tempfile

    env = RailEnv(width=10,
                  height=10,
                  rail_generator=complex_rail_generator(nr_start_goal=2,
                                                        nr_extra=5,
                                                        min_dist=6,
                                                        seed=1),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=2)
    env.reset()
    # Snapshot the per-agent state that must survive the round trip.
    expected_agents = [(agent.position, agent.direction, agent.target)
                       for agent in env.agents]

    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "test_save.dat")
        env.save(file_name)
        # NOTE(review): the file is loaded back into the same object, so a
        # load() that silently did nothing would still pass the agent checks.
        # Loading into a separate environment would make this test stronger.
        env.load(file_name)

    assert env.width == 10
    assert env.height == 10
    assert len(env.agents) == 2
    assert expected_agents == [(agent.position, agent.direction, agent.target)
                               for agent in env.agents]
Пример #2
0
def create_testfiles(parameters, test_nr=0, nr_trials_per_test=100):
    """Generate and save ``nr_trials_per_test`` levels for one test case.

    Parameters
    ----------
    parameters : sequence
        ``(width, height, number_of_agents, seed)``; only the first four
        entries are read.
    test_nr : int
        Test index; levels are written to
        ``./Tests/{test_nr}/Level_{trial}.pkl``.
    nr_trials_per_test : int
        Number of levels (environment resets) to generate and save.
    """
    import os

    width, height, n_agents, seed = (parameters[0], parameters[1],
                                     parameters[2], parameters[3])
    print('Creating {} with (x_dim,y_dim) = ({},{}) and {} Agents.'.format(
        test_nr, width, height, n_agents))

    # Seed Python's and NumPy's global RNGs from the test's seed.
    random.seed(seed)
    np.random.seed(seed)
    # At least 4 start/goal pairs, otherwise 1.5x the number of agents.
    nr_paths = max(4, n_agents + int(0.5 * n_agents))
    min_dist = int(min(width, height) * 0.75)
    env = RailEnv(width=width,
                  height=height,
                  rail_generator=complex_rail_generator(nr_start_goal=nr_paths,
                                                        nr_extra=5,
                                                        min_dist=min_dist,
                                                        max_dist=99999,
                                                        seed=seed),
                  schedule_generator=complex_schedule_generator(),
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
                  number_of_agents=n_agents)

    out_dir = "./Tests/{}".format(test_nr)
    # Make sure the target directory exists before saving into it; saving
    # into a missing directory would otherwise fail on the first trial.
    os.makedirs(out_dir, exist_ok=True)

    printProgressBar(0,
                     nr_trials_per_test,
                     prefix='Progress:',
                     suffix='Complete',
                     length=20)
    for trial in range(nr_trials_per_test):
        # Regenerate the level (positional reset flags as in the original).
        env.reset(True, True)
        env.save("{}/Level_{}.pkl".format(out_dir, trial))
        printProgressBar(trial + 1,
                         nr_trials_per_test,
                         prefix='Progress:',
                         suffix='Complete',
                         length=20)
Пример #3
0
def test_save_load():
    """Round-trip a RailEnv through ``RailEnvPersister``.

    Saves the environment with ``RailEnvPersister.save`` and with
    ``env.save``. Only the first file is reloaded, via
    ``RailEnvPersister.load_new``, which returns ``(env, env_dict)``. The
    test then checks the grid size, the agent count, and each agent's
    position, direction and target. Files go to a temporary directory
    that is removed afterwards.
    """
    import os
    import tempfile

    env = RailEnv(width=10, height=10,
                  rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
                  schedule_generator=complex_schedule_generator(), number_of_agents=2)
    env.reset()
    # Snapshot the per-agent state that must survive the round trip.
    expected_agents = [(agent.position, agent.direction, agent.target)
                       for agent in env.agents]

    with tempfile.TemporaryDirectory() as tmp_dir:
        persister_file = os.path.join(tmp_dir, "test_save.pkl")
        RailEnvPersister.save(env, persister_file)
        # Only checks that env.save() runs without raising; this file is
        # not reloaded below.
        env.save(os.path.join(tmp_dir, "test_save_2.pkl"))

        loaded_env, _ = RailEnvPersister.load_new(persister_file)

    assert loaded_env.width == 10
    assert loaded_env.height == 10
    assert len(loaded_env.agents) == 2
    assert expected_agents == [(agent.position, agent.direction, agent.target)
                               for agent in loaded_env.agents]
Пример #4
0
def test_malfanction_to_and_from_file():
    """
    Test that malfunction parameters survive a save/load round trip.

    An environment with known malfunction parameters is saved to a file.
    The malfunction generator and process data are read back with
    ``malfunction_from_file``. A second environment is built from exactly
    that loaded data. Its ``malfunction_process_data`` must equal the
    original environment's and carry the original parameter values.

    Returns
    -------
    None
    """
    import os
    import tempfile

    stochastic_data = MalfunctionParameters(
        malfunction_rate=1000,  # Rate of malfunction occurrence
        min_duration=2,  # Minimal duration of malfunction
        max_duration=5  # Max duration of malfunction
    )

    rail, _ = make_simple_rail2()

    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=10,
        malfunction_generator_and_process_data=malfunction_from_params(
            stochastic_data))
    env.reset()

    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "malfunction_saving_loading_tests.pkl")
        env.save(file_name)

        malfunction_generator, malfunction_process_data = malfunction_from_file(
            file_name)
        # Build env2 from the data loaded from file. Previously it was built
        # from stochastic_data again, so the file round trip was never tested.
        env2 = RailEnv(
            width=25,
            height=30,
            rail_generator=rail_from_grid_transition_map(rail),
            schedule_generator=random_schedule_generator(),
            number_of_agents=10,
            malfunction_generator_and_process_data=(
                malfunction_generator, malfunction_process_data))

        env2.reset()

    assert env2.malfunction_process_data == env.malfunction_process_data
    assert env2.malfunction_process_data.malfunction_rate == 1000
    assert env2.malfunction_process_data.min_duration == 2
    assert env2.malfunction_process_data.max_duration == 5
Пример #5
0
def create_save_env(path,
                    width,
                    height,
                    num_trains,
                    max_cities,
                    max_rails_between_cities,
                    max_rails_in_cities,
                    grid=False,
                    seed=0):
    '''
    Create a RailEnv environment with the given settings and save it to ``path``.

    The rail network comes from ``sparse_rail_generator``, configured with the
    city/rail limits, ``grid`` mode and ``seed`` given here. The environment
    is reset with ``random_seed=seed`` before saving.

    Returns
    -------
    RailEnv
        The environment that was saved. Older callers that ignored the
        previous ``None`` return value are unaffected.
    '''
    rail_generator = sparse_rail_generator(
        max_num_cities=max_cities,
        seed=seed,
        grid_mode=grid,
        max_rails_between_cities=max_rails_between_cities,
        max_rails_in_city=max_rails_in_cities,
    )
    env = RailEnv(width=width,
                  height=height,
                  rail_generator=rail_generator,
                  number_of_agents=num_trains)
    # Every other save in this file is preceded by reset(); the rail grid and
    # agents are presumably only generated there, so save after resetting.
    env.reset(random_seed=seed)
    env.save(path)
    return env
Пример #6
0
def create_test_env(fnParams, nTest, sDir):
    """Generate test environment ``nTest`` and save it under ``sDir``.

    ``fnParams(nTest)`` must return the 10-tuple
    ``(seed, width, height, nr_trains, nr_cities, max_rails_between_cities,
    max_rails_in_cities, malfunction_rate, malfunction_min_duration,
    malfunction_max_duration)``.

    Creation is attempted up to five times. Whenever construction or
    ``reset`` raises ``ValueError``, width and height are each grown by 5
    and the attempt is repeated.

    The environment is written to ``{sDir}/Level_{nTest}.mpk``, replacing
    any existing file, and a progress dot is written to stdout.

    Returns
    -------
    RailEnv
        The generated environment.

    Raises
    ------
    ValueError
        If every attempt fails. The last generation error is chained as the
        cause. Previously this surfaced as an ``UnboundLocalError`` on ``env``.
    """
    (seed, width, height, nr_trains, nr_cities, max_rails_between_cities,
     max_rails_in_cities, malfunction_rate, malfunction_min_duration,
     malfunction_max_duration) = fnParams(nTest)

    rail_generator = sparse_rail_generator(
        max_num_cities=nr_cities,
        seed=seed,
        grid_mode=False,
        max_rails_between_cities=max_rails_between_cities,
        max_rails_in_city=max_rails_in_cities,
    )

    stochastic_data = MalfunctionParameters(
        malfunction_rate=malfunction_rate,
        min_duration=malfunction_min_duration,
        max_duration=malfunction_max_duration)

    observation_builder = GlobalObsForRailEnv()

    # Speed -> share of trains: the four speeds are equally likely.
    speed_ratio_map = {
        1.: 0.25,
        1. / 2.: 0.25,
        1. / 3.: 0.25,
        1. / 4.: 0.25
    }

    schedule_generator = sparse_schedule_generator(speed_ratio_map)

    max_attempts = 5
    last_error = None
    for _ in range(max_attempts):
        try:
            env = RailEnv(
                width=width,
                height=height,
                rail_generator=rail_generator,
                schedule_generator=schedule_generator,
                number_of_agents=nr_trains,
                malfunction_generator_and_process_data=malfunction_from_params(
                    stochastic_data),
                obs_builder_object=observation_builder,
                remove_agents_at_target=True)
            env.reset(random_seed=seed)
        except ValueError as oErr:
            last_error = oErr
            print("Error:", oErr)
            # A larger grid gives the generator more room to succeed.
            width += 5
            height += 5
            print("Try again with larger env: (w,h):", width, height)
        else:
            break
    else:
        # Every attempt failed; without this, env would be unbound below.
        raise ValueError(
            "Could not create test env {} after {} attempts".format(
                nTest, max_attempts)) from last_error

    os.makedirs(sDir, exist_ok=True)

    sfName = "{}/Level_{}.mpk".format(sDir, nTest)
    if os.path.exists(sfName):
        os.remove(sfName)
    env.save(sfName)

    sys.stdout.write(".")
    sys.stdout.flush()

    return env
def test_rail_env_reset():
    """Check that a file-backed RailEnv reproduces the saved rail and agents.

    An environment is saved. It is then reloaded three times through
    ``rail_from_file`` / ``schedule_from_file``, each time calling ``reset``
    with a different combination of its first three positional flags. After
    each reset the loaded rail grid and agent list must equal the originals.
    The file lives in a temporary directory that is removed afterwards.
    """
    import os
    import tempfile

    rail, rail_map = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=3,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "test_rail_env_reset.pkl")
        env.save(file_name)
        rails_initial = env.rail.grid
        agents_initial = env.agents

        # Positional reset() flags exercised by this test, in order.
        reset_flag_sets = [
            (False, False, False),
            (False, True, False),
            (True, False, False),
        ]
        for reset_flags in reset_flag_sets:
            # width/height/number_of_agents are placeholders; the rail and
            # schedule come from the saved file.
            loaded_env = RailEnv(width=1,
                                 height=1,
                                 rail_generator=rail_from_file(file_name),
                                 schedule_generator=schedule_from_file(file_name),
                                 number_of_agents=1,
                                 obs_builder_object=TreeObsForRailEnv(
                                     max_depth=2,
                                     predictor=ShortestPathPredictorForRailEnv()))
            loaded_env.reset(*reset_flags)

            assert np.array_equal(rails_initial, loaded_env.rail.grid), reset_flags
            assert agents_initial == loaded_env.agents, reset_flags
Пример #8
0
def tests_rail_from_file():
    """Check saving and loading RailEnv files with and without a distance map.

    Four combinations are covered:
      1. saved from a TreeObs env (with distance map), loaded with TreeObs;
      2. saved from a GlobalObs env (without distance map), loaded with GlobalObs;
      3. saved with distance map, loaded with GlobalObs;
      4. saved without distance map, loaded with TreeObs, which generates one.
    Files are written to a temporary directory that is removed afterwards.
    """
    import os
    import tempfile

    rail, rail_map = make_simple_rail()

    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "test_with_distance_map.pkl")
        file_name_2 = os.path.join(tmp_dir, "test_without_distance_map.pkl")

        # Test to save and load file with distance map.

        env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                      rail_generator=rail_from_grid_transition_map(rail),
                      schedule_generator=random_schedule_generator(), number_of_agents=3,
                      obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                           predictor=ShortestPathPredictorForRailEnv()))
        env.reset()
        env.save(file_name)
        dist_map_shape = np.shape(env.distance_map.get())
        rails_initial = env.rail.grid
        agents_initial = env.agents

        env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
                      schedule_generator=schedule_from_file(file_name), number_of_agents=1,
                      obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                           predictor=ShortestPathPredictorForRailEnv()))
        env.reset()
        rails_loaded = env.rail.grid
        agents_loaded = env.agents

        assert np.array_equal(rails_initial, rails_loaded)
        assert agents_initial == agents_loaded

        # The loaded distance map must exist and have the original shape.
        # (This checks the shape only; it does not prove it was not recomputed.)
        assert np.shape(env.distance_map.get()) == dist_map_shape
        assert env.distance_map.get() is not None

        # Test to save and load file without distance map.

        env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                       rail_generator=rail_from_grid_transition_map(rail),
                       schedule_generator=random_schedule_generator(),
                       number_of_agents=3, obs_builder_object=GlobalObsForRailEnv())
        env2.reset()
        env2.save(file_name_2)

        rails_initial_2 = env2.rail.grid
        agents_initial_2 = env2.agents

        env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2),
                       schedule_generator=schedule_from_file(file_name_2), number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())
        env2.reset()
        rails_loaded_2 = env2.rail.grid
        agents_loaded_2 = env2.agents

        assert np.array_equal(rails_initial_2, rails_loaded_2)
        assert agents_initial_2 == agents_loaded_2
        assert not hasattr(env2.obs_builder, "distance_map")

        # Test to save with distance map and load without

        env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
                       schedule_generator=schedule_from_file(file_name), number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())
        env3.reset()
        rails_loaded_3 = env3.rail.grid
        agents_loaded_3 = env3.agents

        assert np.array_equal(rails_initial, rails_loaded_3)
        assert agents_initial == agents_loaded_3
        # Check env3 here; the original asserted on env2 by copy-paste mistake.
        assert not hasattr(env3.obs_builder, "distance_map")

        # Test to save without distance map and load with generating distance map

        env4 = RailEnv(width=1,
                       height=1,
                       rail_generator=rail_from_file(file_name_2),
                       schedule_generator=schedule_from_file(file_name_2),
                       number_of_agents=1,
                       obs_builder_object=TreeObsForRailEnv(max_depth=2),
                       )
        env4.reset()
        rails_loaded_4 = env4.rail.grid
        agents_loaded_4 = env4.agents

        # Check that no distance map was saved (env2 was loaded from file_name_2)
        assert not hasattr(env2.obs_builder, "distance_map")
        assert np.array_equal(rails_initial_2, rails_loaded_4)
        assert agents_initial_2 == agents_loaded_4

        # Check that distance map was generated with correct shape
        assert env4.distance_map.get() is not None
        assert np.shape(env4.distance_map.get()) == dist_map_shape