Example #1
def test_global_obs():
    rail, rail_map = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())

    global_obs, info = env.reset()

    # we have to take a step for the agent to enter the grid.
    global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})

    assert (global_obs[0][0].shape == rail_map.shape + (16, ))

    rail_map_recons = np.zeros_like(rail_map)
    for i in range(global_obs[0][0].shape[0]):
        for j in range(global_obs[0][0].shape[1]):
            rail_map_recons[i, j] = int(
                ''.join(global_obs[0][0][i, j].astype(int).astype(str)), 2)

    assert np.array_equal(rail_map_recons, rail_map)

    # If this assertion fails, the observation returned places the agent
    # on an empty (non-rail) cell
    obs_agents_state = global_obs[0][1]
    obs_agents_state = obs_agents_state + 1
    assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0)
Example #2
def test_get_entry_directions():
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    def _assert(position, expected):
        actual = env.get_valid_directions_on_grid(*position)
        assert actual == expected, "[{},{}] actual={}, expected={}".format(
            *position, actual, expected)

    # north dead end
    _assert((0, 3), [True, False, False, False])

    # west dead end
    _assert((3, 0), [False, False, False, True])

    # switch
    _assert((3, 3), [False, True, True, True])

    # horizontal
    _assert((3, 2), [False, True, False, True])

    # vertical
    _assert((2, 3), [True, False, True, False])

    # nowhere
    _assert((0, 0), [False, False, False, False])
Example #3
def test_path_exists(rendering=False):
    rail, rail_map = make_simple_rail()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
    )
    env.reset()

    check_path(
        env,
        rail,
        (5, 6),  # north of south dead-end
        0,  # north
        (3, 9),  # east dead-end
        True)

    check_path(
        env,
        rail,
        (6, 6),  # south dead-end
        2,  # south
        (3, 9),  # east dead-end
        True)

    check_path(
        env,
        rail,
        (3, 0),  # west dead-end
        3,  # west
        (0, 3),  # north dead-end
        True)
    check_path(
        env,
        rail,
        (5, 6),  # north of south dead-end
        0,  # north
        (1, 3),  # south of north dead-end
        True)

    check_path(
        env,
        rail,
        (1, 3),  # south of north dead-end
        2,  # south
        (3, 3),  # switch
        True)

    check_path(
        env,
        rail,
        (1, 3),  # south of north dead-end
        0,  # north
        (3, 3),  # switch
        True)
Example #4
def test_rail_env_reset():
    file_name = "test_rail_env_reset.pkl"

    # Test to save and load file.

    rail, rail_map = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=3,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    RailEnvPersister.save(env, file_name)

    dist_map_shape = np.shape(env.distance_map.get())
    rails_initial = env.rail.grid
    agents_initial = env.agents

    env2, env2_dict = RailEnvPersister.load_new(file_name)

    rails_loaded = env2.rail.grid
    agents_loaded = env2.agents

    assert np.all(np.array_equal(rails_initial, rails_loaded))
    assert agents_initial == agents_loaded

    env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
                   schedule_generator=schedule_from_file(file_name), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env3.reset(False, True, False)
    rails_loaded = env3.rail.grid
    agents_loaded = env3.agents

    assert np.all(np.array_equal(rails_initial, rails_loaded))
    assert agents_initial == agents_loaded

    env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
                   schedule_generator=schedule_from_file(file_name), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env4.reset(True, False, False)
    rails_loaded = env4.rail.grid
    agents_loaded = env4.agents

    assert np.all(np.array_equal(rails_initial, rails_loaded))
    assert agents_initial == agents_loaded
Example #5
def test_rail_from_grid_transition_map():
    rail, rail_map = make_simple_rail()
    n_agents = 3
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=n_agents)
    env.reset(False, False, True)
    nr_rail_elements = np.count_nonzero(env.rail.grid)

    # Check if the number of non-empty rail cells is ok
    assert nr_rail_elements == 16

    # Check that agents are placed on a rail
    for a in env.agents:
        assert env.rail.grid[a.position] != 0

    assert env.get_num_agents() == n_agents
Example #6
def test_initial_status():
    """Test that the agent lifecycle works correctly: ready-to-depart -> active -> done."""
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()),
                  remove_agents_at_target=False)
    env.reset()
    set_penalties_for_replay(env)
    test_config = ReplayConfig(
        replay=[
            Replay(
                position=None,  # not entered grid yet
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.READY_TO_DEPART,
                action=RailEnvActions.DO_NOTHING,
                reward=env.step_penalty * 0.5,
            ),
            Replay(
                position=None,  # still not on the grid before this step
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.READY_TO_DEPART,
                action=RailEnvActions.MOVE_LEFT,
                reward=env.step_penalty *
                0.5,  # auto-correction left to forward without penalty!
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.ACTIVE,
                action=RailEnvActions.MOVE_LEFT,
                reward=env.start_penalty +
                env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.ACTIVE,
                action=None,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                status=RailAgentStatus.ACTIVE,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                status=RailAgentStatus.ACTIVE,
                action=None,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
                status=RailAgentStatus.ACTIVE),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty *
                0.5,  # wrong action is corrected to forward without penalty!
                status=RailAgentStatus.ACTIVE),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_RIGHT,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
                status=RailAgentStatus.ACTIVE),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.global_reward,  # arrives at target in this step
                status=RailAgentStatus.ACTIVE),
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.global_reward,  # already done
                status=RailAgentStatus.DONE),
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.global_reward,  # already done
                status=RailAgentStatus.DONE)
        ],
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
        target=(3, 5),
        speed=0.5)

    run_replay_config(env, [test_config], activate_agents=False)
Example #7
def test_multispeed_actions_no_malfunction_no_blocking():
    """Test that actions are correctly performed on cell exit for a single agent."""
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    set_penalties_for_replay(env)
    test_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 9),  # east dead-end
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty +
                env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_LEFT,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.STOP_MOVING,
                reward=env.stop_penalty +
                env.step_penalty * 0.5  # stopping and step penalty
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.STOP_MOVING,
                reward=env.step_penalty *
                0.5  # step penalty for speed 0.5 when stopped
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty +
                env.step_penalty * 0.5  # starting + running at speed 0.5
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(5, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
        ],
        target=(3, 0),  # west dead-end
        speed=0.5,
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [test_config])
Example #8
def test_multispeed_actions_malfunction_no_blocking():
    """Test on a single agent whether actions on cell exit work correctly despite malfunction."""
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    set_penalties_for_replay(env)
    test_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 9),  # east dead-end
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty +
                env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            # add additional step in the cell
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                set_malfunction=2,  # recovers in two steps from now!
                malfunction=2,
                reward=env.step_penalty *
                0.5  # step penalty for speed 0.5 when malfunctioning
            ),
            # agent recovers in this step
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                malfunction=1,
                reward=env.step_penalty *
                0.5  # recovered: running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                set_malfunction=2,  # recovers in two steps from now!
                malfunction=2,
                reward=env.step_penalty *
                0.5  # step penalty for speed 0.5 when malfunctioning
            ),
            # agent recovers in this step
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                malfunction=1,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.STOP_MOVING,
                reward=env.stop_penalty + env.step_penalty *
                0.5  # stopping and step penalty for speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.STOP_MOVING,
                reward=env.step_penalty *
                0.5  # step penalty for speed 0.5 while stopped
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty +
                env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            # DO_NOTHING keeps moving!
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.DO_NOTHING,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 4),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
        ],
        target=(3, 0),  # west dead-end
        speed=0.5,
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [test_config])
Example #9
def test_multispeed_actions_no_malfunction_blocking():
    """The second (faster) agent is blocked by the first (slower) agent ahead of it."""
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()))
    env.reset()
    set_penalties_for_replay(env)
    test_configs = [
        ReplayConfig(
            replay=[
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.start_penalty + env.step_penalty * 1.0 /
                    3.0  # starting and running at speed 1/3
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                )
            ],
            target=(3, 0),  # west dead-end
            speed=1 / 3,
            initial_position=(3, 8),
            initial_direction=Grid4TransitionsEnum.WEST,
        ),
        ReplayConfig(
            replay=[
                Replay(
                    position=(3, 9),  # east dead-end
                    direction=Grid4TransitionsEnum.EAST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.start_penalty +
                    env.step_penalty * 0.5  # starting and running at speed 0.5
                ),
                Replay(
                    position=(3, 9),
                    direction=Grid4TransitionsEnum.EAST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 9),
                    direction=Grid4TransitionsEnum.EAST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_LEFT,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # not blocked, action required!
                Replay(
                    position=(4, 6),
                    direction=Grid4TransitionsEnum.SOUTH,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
            ],
            target=(3, 0),  # west dead-end
            speed=0.5,
            initial_position=(3, 9),  # east dead-end
            initial_direction=Grid4TransitionsEnum.EAST,
        )
    ]
    run_replay_config(env, test_configs)
Example #10
def test_shortest_path_predictor(rendering=False):
    rail, rail_map = make_simple_rail()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
    )
    env.reset()

    # set the initial position
    agent = env.agents[0]
    agent.initial_position = (5, 6)  # north of south dead-end
    agent.position = (5, 6)  # north of south dead-end
    agent.direction = 0  # north
    agent.initial_direction = 0  # north
    agent.target = (3, 9)  # east dead-end
    agent.moving = True
    agent.status = RailAgentStatus.ACTIVE

    env.reset(False, False)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # compute the observations and predictions
    distance_map = env.distance_map.get()
    assert distance_map[0, agent.initial_position[0], agent.initial_position[1], agent.direction] == 5.0, \
        "found {} instead of {}".format(
            distance_map[agent.handle, agent.initial_position[0], agent.initial_position[1], agent.direction], 5.0)

    paths = get_shortest_paths(env.distance_map)[0]
    assert paths == [
        Waypoint((5, 6), 0),
        Waypoint((4, 6), 0),
        Waypoint((3, 6), 0),
        Waypoint((3, 7), 1),
        Waypoint((3, 8), 1),
        Waypoint((3, 9), 1)
    ]

    # extract the data
    predictions = env.obs_builder.predictions
    positions = np.array(
        list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
    directions = np.array(
        list(map(lambda prediction: [prediction[3]], predictions[0])))
    time_offsets = np.array(
        list(map(lambda prediction: [prediction[0]], predictions[0])))

    # test if data meets expectations
    expected_positions = [
        [5, 6],
        [4, 6],
        [3, 6],
        [3, 7],
        [3, 8],
        [3, 9],
        [3, 9],
        [3, 9],
        [3, 9],
        [3, 9],
        [3, 9],
        [3, 9],
        [3, 9],
        [3, 9],
        [3, 9],
        [3, 9],
        [3, 9],
        [3, 9],
        [3, 9],
        [3, 9],
        [3, 9],
    ]
    expected_directions = [
        [Grid4TransitionsEnum.NORTH],  # next is [5,6] heading north
        [Grid4TransitionsEnum.NORTH],  # next is [4,6] heading north
        [Grid4TransitionsEnum.NORTH],  # next is [3,6] heading north
        [Grid4TransitionsEnum.EAST],  # next is [3,7] heading east
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
        [Grid4TransitionsEnum.EAST],
    ]

    expected_time_offsets = np.array([
        [0.],
        [1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.],
        [8.],
        [9.],
        [10.],
        [11.],
        [12.],
        [13.],
        [14.],
        [15.],
        [16.],
        [17.],
        [18.],
        [19.],
        [20.],
    ])

    assert np.array_equal(time_offsets, expected_time_offsets), \
        "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)

    assert np.array_equal(positions, expected_positions), \
        "positions {}, expected {}".format(positions, expected_positions)
    assert np.array_equal(directions, expected_directions), \
        "directions {}, expected {}".format(directions, expected_directions)
Example #11
def test_reward_function_waiting(rendering=False):
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()),
                  remove_agents_at_target=False)
    obs_builder: TreeObsForRailEnv = env.obs_builder
    env.reset()

    # set the initial position
    agent = env.agents[0]
    agent.initial_position = (3, 8)  # west of east dead-end
    agent.position = (3, 8)  # west of east dead-end
    agent.direction = 3  # west
    agent.initial_direction = 3  # west
    agent.target = (3, 1)  # east of west dead-end
    agent.moving = True
    agent.status = RailAgentStatus.ACTIVE

    agent = env.agents[1]
    agent.initial_position = (5, 6)  # north of south dead-end
    agent.position = (5, 6)  # north of south dead-end
    agent.direction = 0  # north
    agent.initial_direction = 0  # north
    agent.target = (3, 8)  # west of east dead-end
    agent.moving = True
    agent.status = RailAgentStatus.ACTIVE

    env.reset(False, False)
    env.agents[0].moving = True
    env.agents[1].moving = True
    env.agents[0].status = RailAgentStatus.ACTIVE
    env.agents[1].status = RailAgentStatus.ACTIVE
    env.agents[0].position = (3, 8)
    env.agents[1].position = (5, 6)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=True)

    iteration = 0
    expectations = {
        0: {
            'positions': {
                0: (3, 8),
                1: (5, 6),
            },
            'rewards': [-1, -1],
        },
        1: {
            'positions': {
                0: (3, 7),
                1: (4, 6),
            },
            'rewards': [-1, -1],
        },
        # second agent has to wait for first, first can continue
        2: {
            'positions': {
                0: (3, 6),
                1: (4, 6),
            },
            'rewards': [-1, -1],
        },
        # both can move again
        3: {
            'positions': {
                0: (3, 5),
                1: (3, 6),
            },
            'rewards': [-1, -1],
        },
        4: {
            'positions': {
                0: (3, 4),
                1: (3, 7),
            },
            'rewards': [-1, -1],
        },
        # second reached target
        5: {
            'positions': {
                0: (3, 3),
                1: (3, 8),
            },
            'rewards': [-1, 0],
        },
        6: {
            'positions': {
                0: (3, 2),
                1: (3, 8),
            },
            'rewards': [-1, 0],
        },
        # first reaches its target, too
        7: {
            'positions': {
                0: (3, 1),
                1: (3, 8),
            },
            'rewards': [1, 1],
        },
        8: {
            'positions': {
                0: (3, 1),
                1: (3, 8),
            },
            'rewards': [1, 1],
        },
    }
    while iteration < 7:

        rewards = _step_along_shortest_path(env, obs_builder, rail)

        if rendering:
            renderer.render_env(show=True, show_observations=True)

        print(env.dones["__all__"])
        for agent in env.agents:
            agent: EnvAgent
            print("[{}] agent {} at {}, target {} ".format(
                iteration + 1, agent.handle, agent.position, agent.target))
        print(
            np.all([
                np.array_equal(agent2.position, agent2.target)
                for agent2 in env.agents
            ]))
        for agent in env.agents:
            expected_position = expectations[iteration +
                                             1]['positions'][agent.handle]
            assert agent.position == expected_position, \
                "[{}] agent {} at {}, expected {}".format(iteration + 1,
                                                          agent.handle,
                                                          agent.position,
                                                          expected_position)
            expected_reward = expectations[iteration +
                                           1]['rewards'][agent.handle]
            actual_reward = rewards[agent.handle]
            assert expected_reward == actual_reward, "[{}] agent {} reward {}, expected {}".format(
                iteration + 1, agent.handle, actual_reward, expected_reward)
        iteration += 1
Example #12
def test_reward_function_conflict(rendering=False):
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()))
    obs_builder: TreeObsForRailEnv = env.obs_builder
    env.reset()

    # set the initial position
    agent = env.agents[0]
    agent.position = (5, 6)  # north of south dead-end
    agent.initial_position = (5, 6)  # north of south dead-end
    agent.direction = 0  # north
    agent.initial_direction = 0  # north
    agent.target = (3, 9)  # east dead-end
    agent.moving = True
    agent.status = RailAgentStatus.ACTIVE

    agent = env.agents[1]
    agent.position = (3, 8)  # west of east dead-end
    agent.initial_position = (3, 8)  # west of east dead-end
    agent.direction = 3  # west
    agent.initial_direction = 3  # west
    agent.target = (6, 6)  # south dead-end
    agent.moving = True
    agent.status = RailAgentStatus.ACTIVE

    env.reset(False, False)
    env.agents[0].moving = True
    env.agents[1].moving = True
    env.agents[0].status = RailAgentStatus.ACTIVE
    env.agents[1].status = RailAgentStatus.ACTIVE
    env.agents[0].position = (5, 6)
    env.agents[1].position = (3, 8)
    print("\n")
    print(env.agents[0])
    print(env.agents[1])

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=True)

    iteration = 0
    expected_positions = {
        0: {
            0: (5, 6),
            1: (3, 8)
        },
        # both can move
        1: {
            0: (4, 6),
            1: (3, 7)
        },
        # first can move, second stuck
        2: {
            0: (3, 6),
            1: (3, 7)
        },
        # both stuck from now on
        3: {
            0: (3, 6),
            1: (3, 7)
        },
        4: {
            0: (3, 6),
            1: (3, 7)
        },
        5: {
            0: (3, 6),
            1: (3, 7)
        },
    }
    while iteration < 5:
        rewards = _step_along_shortest_path(env, obs_builder, rail)

        for agent in env.agents:
            assert rewards[agent.handle] == -1
            expected_position = expected_positions[iteration + 1][agent.handle]
            assert agent.position == expected_position, "[{}] agent {} at {}, expected {}".format(
                iteration + 1, agent.handle, agent.position, expected_position)
        if rendering:
            renderer.render_env(show=True, show_observations=True)

        iteration += 1
Example #13
def tests_rail_from_file():
    file_name = "test_with_distance_map.pkl"

    # Test to save and load file with distance map.

    rail, rail_map = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=3,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()
    env.save(file_name)
    dist_map_shape = np.shape(env.distance_map.get())
    rails_initial = env.rail.grid
    agents_initial = env.agents

    env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
                  schedule_generator=schedule_from_file(file_name), number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()
    rails_loaded = env.rail.grid
    agents_loaded = env.agents

    assert np.all(np.array_equal(rails_initial, rails_loaded))
    assert agents_initial == agents_loaded

    # Check that distance map was not recomputed
    assert np.shape(env.distance_map.get()) == dist_map_shape
    assert env.distance_map.get() is not None

    # Test to save and load file without distance map.

    file_name_2 = "test_without_distance_map.pkl"

    env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(),
                   number_of_agents=3, obs_builder_object=GlobalObsForRailEnv())
    env2.reset()
    env2.save(file_name_2)

    rails_initial_2 = env2.rail.grid
    agents_initial_2 = env2.agents

    env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2),
                   schedule_generator=schedule_from_file(file_name_2), number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv())
    env2.reset()
    rails_loaded_2 = env2.rail.grid
    agents_loaded_2 = env2.agents

    assert np.all(np.array_equal(rails_initial_2, rails_loaded_2))
    assert agents_initial_2 == agents_loaded_2
    assert not hasattr(env2.obs_builder, "distance_map")

    # Test to save with distance map and load without

    env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
                   schedule_generator=schedule_from_file(file_name), number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv())
    env3.reset()
    rails_loaded_3 = env3.rail.grid
    agents_loaded_3 = env3.agents

    assert np.all(np.array_equal(rails_initial, rails_loaded_3))
    assert agents_initial == agents_loaded_3
    assert not hasattr(env3.obs_builder, "distance_map")

    # Test to save without distance map and load with generating distance map

    env4 = RailEnv(width=1,
                   height=1,
                   rail_generator=rail_from_file(file_name_2),
                   schedule_generator=schedule_from_file(file_name_2),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2),
                   )
    env4.reset()
    rails_loaded_4 = env4.rail.grid
    agents_loaded_4 = env4.agents

    # Check that no distance map was saved
    assert not hasattr(env4.obs_builder, "distance_map")
    assert np.all(np.array_equal(rails_initial_2, rails_loaded_4))
    assert agents_initial_2 == agents_loaded_4

    # Check that distance map was generated with correct shape
    assert env4.distance_map.get() is not None
    assert np.shape(env4.distance_map.get()) == dist_map_shape
Example #14
def test_action_plan(rendering: bool = False):
    """Tests ActionPlanReplayer: does action plan generation and replay work as expected."""
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=77),
                  number_of_agents=2,
                  obs_builder_object=GlobalObsForRailEnv(),
                  remove_agents_at_target=True)
    env.reset()
    env.agents[0].initial_position = (3, 0)
    env.agents[0].target = (3, 8)
    env.agents[0].initial_direction = Grid4TransitionsEnum.WEST
    env.agents[1].initial_position = (3, 8)
    env.agents[1].initial_direction = Grid4TransitionsEnum.WEST
    env.agents[1].target = (0, 3)
    env.agents[1].speed_data['speed'] = 0.5  # two time steps per cell
    env.reset(False, False, False)
    for handle, agent in enumerate(env.agents):
        print("[{}] {} -> {}".format(handle, agent.initial_position,
                                     agent.target))

    chosen_path_dict = {
        0: [
            TrainRunWayPoint(scheduled_at=0,
                             way_point=WayPoint(position=(3, 0), direction=3)),
            TrainRunWayPoint(scheduled_at=2,
                             way_point=WayPoint(position=(3, 1), direction=1)),
            TrainRunWayPoint(scheduled_at=3,
                             way_point=WayPoint(position=(3, 2), direction=1)),
            TrainRunWayPoint(scheduled_at=14,
                             way_point=WayPoint(position=(3, 3), direction=1)),
            TrainRunWayPoint(scheduled_at=15,
                             way_point=WayPoint(position=(3, 4), direction=1)),
            TrainRunWayPoint(scheduled_at=16,
                             way_point=WayPoint(position=(3, 5), direction=1)),
            TrainRunWayPoint(scheduled_at=17,
                             way_point=WayPoint(position=(3, 6), direction=1)),
            TrainRunWayPoint(scheduled_at=18,
                             way_point=WayPoint(position=(3, 7), direction=1)),
            TrainRunWayPoint(scheduled_at=19,
                             way_point=WayPoint(position=(3, 8), direction=1)),
            TrainRunWayPoint(scheduled_at=20,
                             way_point=WayPoint(position=(3, 8), direction=5))
        ],
        1: [
            TrainRunWayPoint(scheduled_at=0,
                             way_point=WayPoint(position=(3, 8), direction=3)),
            TrainRunWayPoint(scheduled_at=3,
                             way_point=WayPoint(position=(3, 7), direction=3)),
            TrainRunWayPoint(scheduled_at=5,
                             way_point=WayPoint(position=(3, 6), direction=3)),
            TrainRunWayPoint(scheduled_at=7,
                             way_point=WayPoint(position=(3, 5), direction=3)),
            TrainRunWayPoint(scheduled_at=9,
                             way_point=WayPoint(position=(3, 4), direction=3)),
            TrainRunWayPoint(scheduled_at=11,
                             way_point=WayPoint(position=(3, 3), direction=3)),
            TrainRunWayPoint(scheduled_at=13,
                             way_point=WayPoint(position=(2, 3), direction=0)),
            TrainRunWayPoint(scheduled_at=15,
                             way_point=WayPoint(position=(1, 3), direction=0)),
            TrainRunWayPoint(scheduled_at=17,
                             way_point=WayPoint(position=(0, 3), direction=0))
        ]
    }
    expected_action_plan = [
        [
            # take action to enter the grid
            ActionPlanElement(0, RailEnvActions.MOVE_FORWARD),
            # take action to enter the cell properly
            ActionPlanElement(1, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(2, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(3, RailEnvActions.STOP_MOVING),
            ActionPlanElement(13, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(14, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(15, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(16, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(17, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(18, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(19, RailEnvActions.STOP_MOVING)
        ],
        [
            ActionPlanElement(0, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(1, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(3, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(5, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(7, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(9, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(11, RailEnvActions.MOVE_RIGHT),
            ActionPlanElement(13, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(15, RailEnvActions.MOVE_FORWARD),
            ActionPlanElement(17, RailEnvActions.STOP_MOVING),
        ]
    ]

    MAX_EPISODE_STEPS = 50

    deterministic_controller = ControllerFromTrainRuns(env, chosen_path_dict)
    deterministic_controller.print_action_plan()
    ControllerFromTrainRuns.assert_actions_plans_equal(
        expected_action_plan, deterministic_controller.action_plan)
    ControllerFromTrainRunsReplayer.replay_verify(MAX_EPISODE_STEPS,
                                                  deterministic_controller,
                                                  env, rendering)