コード例 #1
0
def test_dead_end():
    """Set up straight single-track lines capped by dead-ends and place agents on them.

    Two layouts are built: a horizontal 1x5 line and a vertical 5x1 line, each
    made of a dead-end cell, three straight cells and another dead-end cell.
    For each layout the env is reset twice, and each time the agent list is
    replaced by one agent in the middle cell, facing away from its target end.

    NOTE(review): the original comment says the train should be done after 6
    steps, but nothing is stepped or asserted here; the test only exercises
    the setup path.
    """
    transitions = RailEnvTransitions()

    straight_vertical = int('1000000000100000', 2)  # Case 1 - straight
    straight_horizontal = transitions.rotate_transition(straight_vertical, 90)
    dead_end_from_south = int('0010000000000000', 2)  # Case 7 - dead end

    def make_env(grid):
        # Wrap a transition grid in a GridTransitionMap and a one-agent RailEnv.
        rows, cols = grid.shape
        rail = GridTransitionMap(width=cols, height=rows, transitions=transitions)
        rail.grid = grid
        return RailEnv(width=cols, height=rows,
                       rail_generator=rail_from_grid_transition_map(rail),
                       schedule_generator=random_schedule_generator(),
                       number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())

    def place_agents(env, placements):
        # For each (position, heading, target): reset, then install a single agent.
        for position, heading, target in placements:
            env.reset()
            env.agents = [EnvAgent(initial_position=position, initial_direction=heading,
                                   direction=heading, target=target, moving=False)]

    # Horizontal line "O->--": dead end, three straights, dead end.
    horizontal = np.array(
        [[transitions.rotate_transition(dead_end_from_south, 270)]
         + [straight_horizontal] * 3
         + [transitions.rotate_transition(dead_end_from_south, 90)]],
        dtype=np.uint16)
    place_agents(make_env(horizontal), [((0, 2), 1, (0, 0)),
                                        ((0, 2), 3, (0, 4))])

    # Vertical line: dead end, three straights, dead end.
    vertical = np.array(
        [[dead_end_from_south]] + [[straight_vertical]] * 3
        + [[transitions.rotate_transition(dead_end_from_south, 180)]],
        dtype=np.uint16)
    place_agents(make_env(vertical), [((2, 0), 2, (0, 0)),
                                      ((2, 0), 0, (4, 0))])
コード例 #2
0
def test_seeding_and_observations():
    """Check that the observation builder does not change env dynamics.

    Two envs are built from the same rail with the same schedule and reset
    seeds (12). They differ only in their observation builder: global versus
    tree with a shortest-path predictor. The test asserts that every agent
    has the same initial position in both envs. It then applies the same
    random action dict to both envs for 10 steps and asserts that every
    agent ends at the same position.
    """
    rail, rail_map = make_simple_rail2()

    # Make two separate envs with different observation builders
    # Global Observation
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=12),
                  number_of_agents=10,
                  obs_builder_object=GlobalObsForRailEnv())
    # Tree Observation
    env2 = RailEnv(width=25,
                   height=30,
                   rail_generator=rail_from_grid_transition_map(rail),
                   schedule_generator=random_schedule_generator(seed=12),
                   number_of_agents=10,
                   obs_builder_object=TreeObsForRailEnv(
                       max_depth=2,
                       predictor=ShortestPathPredictorForRailEnv()))

    env.reset(False, False, False, random_seed=12)
    env2.reset(False, False, False, random_seed=12)

    num_agents = env.get_num_agents()
    assert num_agents == env2.get_num_agents(), \
        "agent counts differ: {} vs {}".format(num_agents, env2.get_num_agents())

    # Check that both environments produce the same initial start positions
    for a in range(num_agents):
        assert env.agents[a].initial_position == env2.agents[a].initial_position, \
            "agent {}: initial position {} != {}".format(
                a, env.agents[a].initial_position, env2.agents[a].initial_position)

    # Drive both envs with the same random actions; the dict is reused across
    # steps and fully overwritten each time.
    action_dict = {}
    for _ in range(10):
        for a in range(num_agents):
            action_dict[a] = np.random.randint(4)
        env.step(action_dict)
        env2.step(action_dict)

    # Check that both environments end up in the same position
    for a in range(num_agents):
        assert env.agents[a].position == env2.agents[a].position, \
            "agent {}: position {} != {}".format(
                a, env.agents[a].position, env2.agents[a].position)
コード例 #3
0
def test_get_shortest_paths_unreachable():
    """On a disconnected rail, an unreachable target yields no shortest path.

    A single agent is put on the west dead-end, facing west, with its target
    on the east dead-end. ``get_shortest_paths`` must report ``None`` for it.
    """
    rail, rail_map = make_disconnected_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())
    env.reset()

    west_dead_end = (3, 1)
    east_dead_end = (3, 9)

    # Override the generated schedule for the only agent.
    agent = env.agents[0]
    agent.position = agent.initial_position = west_dead_end
    agent.direction = Grid4TransitionsEnum.WEST
    agent.target = east_dead_end
    agent.moving = True

    env.reset(False, False)

    expected = {0: None}
    actual = get_shortest_paths(env.distance_map)
    assert actual == expected, "actual={},expected={}".format(actual, expected)
コード例 #4
0
def test_global_obs():
    """Check shape and content of GlobalObsForRailEnv observations on a simple rail.

    After one MOVE_FORWARD step, the per-cell 16-bit transition vectors in the
    first observation component must reproduce the rail map exactly. The
    agent-state component must place the agent on a rail cell.
    """
    rail, rail_map = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())

    env.reset()

    # we have to take step for the agent to enter the grid.
    global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})

    transition_obs = global_obs[0][0]
    assert transition_obs.shape == rail_map.shape + (16, )

    # Rebuild each cell's transition code from its bit vector, first element
    # being the most significant bit.
    rail_map_recons = np.zeros_like(rail_map)
    for i in range(transition_obs.shape[0]):
        for j in range(transition_obs.shape[1]):
            rail_map_recons[i, j] = int(
                ''.join(transition_obs[i, j].astype(int).astype(str)), 2)

    # Element-wise comparison. The previous `a.all() == b.all()` compared two
    # scalar booleans, so it could not detect a wrong reconstruction.
    assert np.array_equal(rail_map_recons, rail_map), \
        "reconstructed rail map differs from the original"

    # If this assertion is wrong, it means that the observation returned
    # places the agent on an empty cell
    obs_agents_state = global_obs[0][1]
    obs_agents_state = obs_agents_state + 1  # -1 ("nothing here") becomes 0
    assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0)
コード例 #5
0
def load_env(env_dict, obs_builder_object=None):
    """
    Load an env from a persisted full-state dictionary.

    Parameters
    ----------
    env_dict : dict
        Full env state, applied with ``RailEnvPersister.set_full_state``.
    obs_builder_object : ObservationBuilder, optional
        Observation builder for the new env. When omitted (or None), a new
        ``GlobalObsForRailEnv`` is created for each call.

    Returns
    -------
    RailEnv
        A 4x4 placeholder env whose state has been overwritten by ``env_dict``.
    """
    if obs_builder_object is None:
        # A builder instance used as a signature default is created once and
        # shared by every env built here. RailEnv binds its builder to itself
        # (set_env), so each env needs its own instance.
        obs_builder_object = GlobalObsForRailEnv()
    env = RailEnv(height=4, width=4, obs_builder_object=obs_builder_object)
    env.reset(regenerate_rail=False, regenerate_schedule=False)
    RailEnvPersister.set_full_state(env, env_dict)
    return env
コード例 #6
0
ファイル: main.py プロジェクト: TeamSerpentine/flatlands-2020
def env_creator():
    """
    Build and return a RailEnv (width 20, height 30) with three agents.

    Rails and schedules come from the complex generators, and observations
    from ``GlobalObsForRailEnv``.
    """
    rail_gen = complex_rail_generator(nr_start_goal=100,
                                      nr_extra=2,
                                      min_dist=8,
                                      max_dist=99999,
                                      seed=False)
    schedule_gen = complex_schedule_generator(seed=False)
    new_env = RailEnv(width=20,
                      height=30,
                      rail_generator=rail_gen,
                      schedule_generator=schedule_gen,
                      obs_builder_object=GlobalObsForRailEnv(),
                      number_of_agents=3,
                      random_seed=True)
    return new_env
コード例 #7
0
class PaddedGlobalObsForRailEnv(ObservationBuilder):
    """Global observation zero-padded to a fixed ``(max_height, max_width)`` grid.

    Wraps a ``GlobalObsForRailEnv``. Each array it returns is padded on the
    bottom and right, so observations from envs of different sizes share one
    shape.
    """

    def __init__(self, max_width, max_height):
        super().__init__()
        self._max_width = max_width
        self._max_height = max_height
        self._builder = GlobalObsForRailEnv()

    def set_env(self, env: Environment):
        # Delegate so the wrapped builder observes the same env.
        self._builder.set_env(env)

    def reset(self):
        self._builder.reset()

    def get(self, handle: int = 0):
        """Return the wrapped observation tuple for ``handle``, zero-padded.

        Every array is padded to ``(max_height, max_width, channels)``. The
        second array (agent states) is first shifted by +1, so its -1 marker
        becomes 0, the same value used for padding.

        Raises
        ------
        ValueError
            If the env grid is larger than ``(max_height, max_width)``.
        """
        obs = list(self._builder.get(handle))
        height, width = obs[0].shape[:2]
        pad_height = self._max_height - height
        pad_width = self._max_width - width
        if pad_height < 0 or pad_width < 0:
            raise ValueError(
                "env grid {}x{} exceeds padded size {}x{}".format(
                    height, width, self._max_height, self._max_width))
        obs[1] = obs[1] + 1  # get rid of -1
        # Rows are padded by pad_height and columns by pad_width. The previous
        # code used pad_height for both axes, which gave a wrong width whenever
        # the two paddings differ.
        padding = ((0, pad_height), (0, pad_width), (0, 0))
        return tuple(np.pad(o, padding, constant_values=0) for o in obs)
コード例 #8
0
ファイル: main.py プロジェクト: MelsHakobyan96/flatland_2.0
def env_gradual_update(input_env, agent=False, hardness_lvl=1):
    """Replace the global ``env`` / ``env_renderer`` with a slightly larger env.

    The new env is 4 cells wider and 4 cells taller than ``input_env``. Its
    agent count is derived from the new size and is never below 1.

    Parameters
    ----------
    input_env : RailEnv
        Env whose ``width`` / ``height`` determine the new size.
    agent : any
        Accepted for call compatibility; not read.
    hardness_lvl : int
        1 selects the complex rail/schedule generators; any other value
        selects the sparse ones.

    Side effects: rebinds the module globals ``env`` and ``env_renderer``. If
    the global ``render`` is truthy, the previous renderer window is closed
    first. Returns None.
    """
    global env, env_renderer, render

    env_width = input_env.width + 4
    env_height = input_env.height + 4

    # (mean side length / 5) - 2, rounded, with a floor of one agent.
    agent_num = max(int(np.round(((env_width + env_height) / 2) / 5 - 2)), 1)

    if hardness_lvl == 1:

        rail_generator = complex_rail_generator(nr_start_goal=20,
                                                nr_extra=1,
                                                min_dist=9,
                                                max_dist=99999,
                                                seed=0)

        schedule_generator = complex_schedule_generator()
    else:

        # NOTE(review): these keyword arguments mirror complex_rail_generator.
        # Confirm that the installed flatland's sparse_rail_generator accepts
        # them; otherwise this branch fails with a TypeError.
        rail_generator = sparse_rail_generator(nr_start_goal=9,
                                               nr_extra=1,
                                               min_dist=9,
                                               max_dist=99999,
                                               seed=0)

        schedule_generator = sparse_schedule_generator()

    if render:
        env_renderer.close_window()

    env = RailEnv(width=env_width,
                  height=env_height,
                  rail_generator=rail_generator,
                  schedule_generator=schedule_generator,
                  obs_builder_object=GlobalObsForRailEnv(),
                  number_of_agents=agent_num)

    env_renderer = RenderTool(env)
コード例 #9
0
def create_env(nr_start_goal=10,
               nr_extra=2,
               min_dist=8,
               max_dist=99999,
               nr_agent=10,
               seed=0,
               render_mode='PIL'):
    """Build a 30x30 complex-rail env, a renderer for it, and its first reset result.

    The first five generator-related arguments are forwarded positionally to
    ``complex_rail_generator``. ``render_mode`` is passed to ``RenderTool`` as
    ``gl``.

    Returns
    -------
    tuple
        ``(env, env_renderer, obs)``, where ``obs`` is whatever ``env.reset()``
        returned.
    """
    rail_gen = complex_rail_generator(nr_start_goal, nr_extra, min_dist,
                                      max_dist, seed)
    new_env = RailEnv(width=30,
                      height=30,
                      rail_generator=rail_gen,
                      schedule_generator=complex_schedule_generator(),
                      obs_builder_object=GlobalObsForRailEnv(),
                      number_of_agents=nr_agent)
    renderer = RenderTool(new_env, gl=render_mode)
    first_obs = new_env.reset()
    return new_env, renderer, first_obs
コード例 #10
0
ファイル: main.py プロジェクト: MelsHakobyan96/flatland_2.0
def env_random_update(input_env, decay, agent=False, hardness_lvl=1):
    """Replace the global ``env`` / ``env_renderer`` with a square env of random size.

    Draws 1 to 4 agents with ``np.random.randint(1, 5)`` and sizes the square
    map as ``(agents + 2) * 5`` cells per side. ``hardness_lvl == 1`` selects
    the complex generators; any other value selects the sparse ones.
    ``input_env``, ``decay`` and ``agent`` are accepted but not read.

    Side effects: rebinds the module globals ``env`` and ``env_renderer``. If
    the global ``render`` is truthy, the previous renderer window is closed
    first. Returns None.
    """
    global env, env_renderer, render

    n_agents = np.random.randint(1, 5)
    side = (n_agents + 2) * 5

    if hardness_lvl != 1:
        # NOTE(review): these keyword arguments mirror complex_rail_generator.
        # Confirm that the installed flatland's sparse_rail_generator accepts
        # them.
        rail_gen = sparse_rail_generator(nr_start_goal=9,
                                         nr_extra=1,
                                         min_dist=9,
                                         max_dist=99999,
                                         seed=0)
        schedule_gen = sparse_schedule_generator()
    else:
        rail_gen = complex_rail_generator(nr_start_goal=20,
                                          nr_extra=1,
                                          min_dist=9,
                                          max_dist=99999,
                                          seed=0)
        schedule_gen = complex_schedule_generator()

    if render:
        env_renderer.close_window()

    env = RailEnv(width=side,
                  height=side,
                  rail_generator=rail_gen,
                  schedule_generator=schedule_gen,
                  obs_builder_object=GlobalObsForRailEnv(),
                  number_of_agents=n_agents)

    env_renderer = RenderTool(env)
コード例 #11
0
def test_rail_environment_single_agent():
    """Random-walk a single agent on the 3x3 rail network drawn below.

     _  _
    / \\/ \\
    | |  |
    \\_/\\_/

    For each of 200 resets, the test checks three things:

    * the agent spawns on a cell/direction with at least one transition;
    * after 6 position changes under random actions, it is back on its
      starting row;
    * 10 fresh episodes driven by random actions each end with
      ``dones['__all__']``.
    """
    transitions = RailEnvTransitions()
    cell_types = transitions.transition_list
    vertical_line = cell_types[1]
    south_symmetrical_switch = cell_types[6]
    north_symmetrical_switch = transitions.rotate_transition(
        south_symmetrical_switch, 180)
    south_east_turn = int('0100000000000010', 2)
    # The other three corners are rotations of the south-east turn.
    south_west_turn, north_east_turn, north_west_turn = (
        transitions.rotate_transition(south_east_turn, angle)
        for angle in (90, 270, 180))

    rail_map = np.array(
        [[south_east_turn, south_symmetrical_switch, south_west_turn],
         [vertical_line] * 3,
         [north_east_turn, north_symmetrical_switch, north_west_turn]],
        dtype=np.uint16)

    rail = GridTransitionMap(width=3, height=3, transitions=transitions)
    rail.grid = rail_map
    rail_env = RailEnv(width=3,
                       height=3,
                       rail_generator=rail_from_grid_transition_map(rail),
                       schedule_generator=random_schedule_generator(),
                       number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())

    for _ in range(200):
        rail_env.reset(False, False, True)

        agent = rail_env.agents[0]
        # The target is irrelevant for the movement checks below.
        agent.target = [-1, -1]

        # Every spawn cell/direction must allow at least one move.
        assert (transitions.get_transitions(rail_map[agent.position],
                                            agent.direction) != (0, 0, 0, 0))

        start = agent.position
        last = start
        moves = 0
        while moves < 6:
            rail_env.step({0: np.random.randint(4)})
            current = agent.position
            if current != last:
                moves += 1
            last = current

        # Six moves around this network always return to the starting row.
        assert (start[0] == agent.position[0])

        # Random play must eventually finish every episode.
        for _ in range(10):
            rail_env.reset()

            finished = False
            while not finished:
                _, _, dones, _ = rail_env.step({0: np.random.randint(4)})
                finished = dones['__all__']
コード例 #12
0
def test_get_global_observation():
    """Check every channel of GlobalObsForRailEnv after one MOVE_FORWARD step.

    Builds a 50x50 sparse-rail env with 20 agents of mixed speeds and steps
    all agents forward once. For each agent's observation it then verifies:

    * targets channel 0: exactly one cell marks the agent's own target;
    * targets channel 1: a cell is 1 iff some agent's target is there;
    * agent-state channel 0: own direction at own position (initial position
      when ready to depart), -1 elsewhere;
    * agent-state channel 1: directions of other ACTIVE/DONE agents at their
      positions, -1 elsewhere;
    * agent-state channels 2 and 3: malfunction and speed of any ACTIVE/DONE
      agent in the cell, -1 elsewhere;
    * agent-state channel 4: number of READY_TO_DEPART agents starting there.
    """
    number_of_agents = 20

    speed_ration_map = {
        1.: 0.25,  # Fast passenger train
        1. / 2.: 0.25,  # Fast freight train
        1. / 3.: 0.25,  # Slow commuter train
        1. / 4.: 0.25,  # Slow freight train
    }

    env = RailEnv(
        width=50,
        height=50,
        rail_generator=sparse_rail_generator(max_num_cities=6,
                                             max_rails_between_cities=4,
                                             seed=15,
                                             grid_mode=False),
        schedule_generator=sparse_schedule_generator(speed_ration_map),
        number_of_agents=number_of_agents,
        obs_builder_object=GlobalObsForRailEnv())
    env.reset()

    obs, _, _, _ = env.step(
        {i: RailEnvActions.MOVE_FORWARD
         for i in range(number_of_agents)})
    for i, agent in enumerate(env.agents):
        print("[{}] status={}, position={}, target={}, initial_position={}".
              format(i, agent.status, agent.position, agent.target,
                     agent.initial_position))

    for i, agent in enumerate(env.agents):
        obs_agents_state = obs[i][1]
        obs_targets = obs[i][2]

        # test first channel of obs_targets: own target
        nr_agents = np.count_nonzero(obs_targets[:, :, 0])
        assert nr_agents == 1, "agent {}: something wrong with own target, found {}".format(
            i, nr_agents)

        # test second channel of obs_targets: other agent's target
        for r in range(env.height):
            for c in range(env.width):
                _other_agent_target = 0
                for other_i, other_agent in enumerate(env.agents):
                    if other_agent.target == (r, c):
                        _other_agent_target = 1
                        break
                assert obs_targets[(
                    r, c
                )][1] == _other_agent_target, "agent {}: at {} expected to be other agent's target = {}".format(
                    i, (r, c), _other_agent_target)

        # test first channel of obs_agents_state: direction at own position
        for r in range(env.height):
            for c in range(env.width):
                if (agent.status == RailAgentStatus.ACTIVE or agent.status
                        == RailAgentStatus.DONE) and (r, c) == agent.position:
                    assert np.isclose(obs_agents_state[(r, c)][0], agent.direction), \
                        "agent {} in status {} at {} expected to contain own direction {}, found {}" \
                            .format(i, agent.status, (r, c), agent.direction, obs_agents_state[(r, c)][0])
                elif (agent.status == RailAgentStatus.READY_TO_DEPART) and (
                        r, c) == agent.initial_position:
                    assert np.isclose(obs_agents_state[(r, c)][0], agent.direction), \
                        "agent {} in status {} at {} expected to contain own direction {}, found {}" \
                            .format(i, agent.status, (r, c), agent.direction, obs_agents_state[(r, c)][0])
                else:
                    assert np.isclose(obs_agents_state[(r, c)][0], -1), \
                        "agent {} in status {} at {} expected contain -1 found {}" \
                            .format(i, agent.status, (r, c), obs_agents_state[(r, c)][0])

        # test second channel of obs_agents_state: direction at other agents position
        for r in range(env.height):
            for c in range(env.width):
                has_agent = False
                for other_i, other_agent in enumerate(env.agents):
                    if i == other_i:
                        continue
                    if other_agent.status in [
                            RailAgentStatus.ACTIVE, RailAgentStatus.DONE
                    ] and (r, c) == other_agent.position:
                        assert np.isclose(obs_agents_state[(r, c)][1], other_agent.direction), \
                            "agent {} in status {} at {} should see other agent with direction {}, found = {}" \
                                .format(i, agent.status, (r, c), other_agent.direction, obs_agents_state[(r, c)][1])
                        # Only a cell actually occupied by another agent is
                        # exempt from the -1 check below. Previously this flag
                        # was set for every other agent, which disabled it.
                        has_agent = True
                if not has_agent:
                    assert np.isclose(obs_agents_state[(r, c)][1], -1), \
                        "agent {} in status {} at {} should see no other agent direction (-1), found = {}" \
                            .format(i, agent.status, (r, c), obs_agents_state[(r, c)][1])

        # test third and fourth channel of obs_agents_state: malfunction and speed of own or other agent in the grid
        for r in range(env.height):
            for c in range(env.width):
                has_agent = False
                for other_i, other_agent in enumerate(env.agents):
                    if other_agent.status in [
                            RailAgentStatus.ACTIVE, RailAgentStatus.DONE
                    ] and other_agent.position == (r, c):
                        assert np.isclose(obs_agents_state[(r, c)][2], other_agent.malfunction_data['malfunction']), \
                            "agent {} in status {} at {} should see agent malfunction {}, found = {}" \
                                .format(i, agent.status, (r, c), other_agent.malfunction_data['malfunction'],
                                        obs_agents_state[(r, c)][2])
                        assert np.isclose(obs_agents_state[(r, c)][3],
                                          other_agent.speed_data['speed'])
                        has_agent = True
                if not has_agent:
                    assert np.isclose(obs_agents_state[(r, c)][2], -1), \
                        "agent {} in status {} at {} should see no agent malfunction (-1), found = {}" \
                            .format(i, agent.status, (r, c), obs_agents_state[(r, c)][2])
                    assert np.isclose(obs_agents_state[(r, c)][3], -1), \
                        "agent {} in status {} at {} should see no agent speed (-1), found = {}" \
                            .format(i, agent.status, (r, c), obs_agents_state[(r, c)][3])

        # test fifth channel of obs_agents_state: number of agents ready to depart in to this cell
        for r in range(env.height):
            for c in range(env.width):
                count = 0
                for other_i, other_agent in enumerate(env.agents):
                    if other_agent.status == RailAgentStatus.READY_TO_DEPART and other_agent.initial_position == (
                            r, c):
                        count += 1
                assert np.isclose(obs_agents_state[(r, c)][4], count), \
                    "agent {} in status {} at {} should see {} agents ready to depart, found {}" \
                        .format(i, agent.status, (r, c), count, obs_agents_state[(r, c)][4])
コード例 #13
0
def test_get_k_shortest_paths(rendering=False):
    """Check that get_k_shortest_paths finds exactly the two routes between the dead-ends.

    The agent starts on the west dead-end (3, 1) facing west and must reach
    the east dead-end (3, 9). Asking for up to 10 paths must return exactly
    two. One runs along row 0; the other passes (6, 6) and turns back north
    there. Both start with a U-turn at the west end: (3, 1) -> (3, 0), then
    back east.

    Parameters
    ----------
    rendering : bool
        If True, render the env and wait for console input before checking.
    """
    rail, rail_map = make_simple_rail_with_alternatives()

    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=GlobalObsForRailEnv(),
    )
    env.reset()

    initial_position = (3, 1)  # west dead-end
    initial_direction = Grid4TransitionsEnum.WEST
    target_position = (3, 9)  # east dead-end

    # set the initial position
    agent = env.agents[0]
    agent.position = initial_position
    agent.initial_position = initial_position
    agent.direction = initial_direction
    agent.target = target_position
    agent.moving = True

    env.reset(False, False)
    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input()  # block until Enter so the rendered window can be inspected

    actual = set(
        get_k_shortest_paths(
            env=env,
            source_position=initial_position,  # west dead-end
            source_direction=int(initial_direction),  # facing west
            target_position=target_position,
            k=10))

    expected = {
        # Route along the top row (row 0).
        (
            WayPoint(position=(3, 1), direction=3),
            WayPoint(position=(3, 0), direction=3),
            WayPoint(position=(3, 1), direction=1),
            WayPoint(position=(3, 2), direction=1),
            WayPoint(position=(3, 3), direction=1),
            WayPoint(position=(2, 3), direction=0),
            WayPoint(position=(1, 3), direction=0),
            WayPoint(position=(0, 3), direction=0),
            WayPoint(position=(0, 4), direction=1),
            WayPoint(position=(0, 5), direction=1),
            WayPoint(position=(0, 6), direction=1),
            WayPoint(position=(0, 7), direction=1),
            WayPoint(position=(0, 8), direction=1),
            WayPoint(position=(0, 9), direction=1),
            WayPoint(position=(1, 9), direction=2),
            WayPoint(position=(2, 9), direction=2),
            WayPoint(position=(3, 9), direction=2),
        ),
        # Route through (6, 6), turning back north there.
        (
            WayPoint(position=(3, 1), direction=3),
            WayPoint(position=(3, 0), direction=3),
            WayPoint(position=(3, 1), direction=1),
            WayPoint(position=(3, 2), direction=1),
            WayPoint(position=(3, 3), direction=1),
            WayPoint(position=(3, 4), direction=1),
            WayPoint(position=(3, 5), direction=1),
            WayPoint(position=(3, 6), direction=1),
            WayPoint(position=(4, 6), direction=2),
            WayPoint(position=(5, 6), direction=2),
            WayPoint(position=(6, 6), direction=2),
            WayPoint(position=(5, 6), direction=0),
            WayPoint(position=(4, 6), direction=0),
            WayPoint(position=(4, 7), direction=1),
            WayPoint(position=(4, 8), direction=1),
            WayPoint(position=(4, 9), direction=1),
            WayPoint(position=(3, 9), direction=0),
        ),
    }

    assert actual == expected, "actual={},expected={}".format(actual, expected)
コード例 #14
0
 def __init__(self, max_width, max_height):
     """Record the maximum grid size and create the underlying GlobalObsForRailEnv builder."""
     super().__init__()
     self._builder = GlobalObsForRailEnv()
     self._max_height = max_height
     self._max_width = max_width
コード例 #15
0
ファイル: rail_env.py プロジェクト: hagrid67/flatland
    def __init__(
            self,
            width,
            height,
            rail_generator: RailGenerator = random_rail_generator(),
            schedule_generator: ScheduleGenerator = random_schedule_generator(
            ),
            number_of_agents=1,
            obs_builder_object: Optional[ObservationBuilder] = None,
            malfunction_generator_and_process_data=no_malfunction_generator(),
            remove_agents_at_target=True,
            random_seed=1,
            record_steps=False):
        """
        Environment init.

        Parameters
        ----------
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        rail_generator : function
            The rail_generator function is a function that takes the width,
            height and agents handles of a rail environment, along with the number of times
            the env has been reset, and returns a GridTransitionMap object and a list of
            starting positions, targets, and initial orientations for agent handle.
            The rail_generator can pass a distance map in the hints or information for specific schedule_generators.
            Implementations can be found in flatland/envs/rail_generators.py
        schedule_generator : function
            The schedule_generator function is a function that takes the grid, the number of agents and optional hints
            and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
            Implementations can be found in flatland/envs/schedule_generators.py
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object : ObservationBuilder, optional
            ObservationBuilder-derived object that builds observation vectors
            for each agent. It is bound to this env via ``set_env``, so one
            instance should not be shared between envs. Defaults to a new
            ``GlobalObsForRailEnv`` created for this env.
        malfunction_generator_and_process_data : tuple
            Pair unpacked into ``(malfunction_generator, malfunction_process_data)``.
            Defaults to ``no_malfunction_generator()``.
        remove_agents_at_target : bool
            If True, agents are removed by placing them at
            RailEnv.DEPOT_POSITION once they have reached their target position.
        random_seed : int or None
            If None, or any other falsy value (including 0), no explicit seed
            is applied. Otherwise the random generators are seeded with this
            number, so that stochastic operations are replicable across
            multiple runs.
        record_steps : bool
            Whether to save per-step agent positions/orientations into
            ``self.cur_episode``. The recording itself happens outside
            ``__init__``.
        """
        super().__init__()

        self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
        self.rail_generator: RailGenerator = rail_generator
        self.schedule_generator: ScheduleGenerator = schedule_generator
        self.rail: Optional[GridTransitionMap] = None
        self.width = width
        self.height = height

        self.remove_agents_at_target = remove_agents_at_target

        self.rewards = [0] * number_of_agents
        self.done = False
        if obs_builder_object is None:
            # Create a builder per env. A builder instance used as a signature
            # default is evaluated once and shared by every env relying on the
            # default, and set_env below would re-bind it to the newest env.
            obs_builder_object = GlobalObsForRailEnv()
        self.obs_builder = obs_builder_object
        self.obs_builder.set_env(self)

        self._max_episode_steps: Optional[int] = None
        self._elapsed_steps = 0

        self.dones = dict.fromkeys(
            list(range(number_of_agents)) + ["__all__"], False)

        self.obs_dict = {}
        self.rewards_dict = {}
        self.dev_obs_dict = {}
        self.dev_pred_dict = {}

        self.agents: List[EnvAgent] = []
        self.number_of_agents = number_of_agents
        self.num_resets = 0
        self.distance_map = DistanceMap(self.agents, self.height, self.width)

        self.action_space = [5]

        # Unseeded initialization first; re-seeded below when a truthy seed is given.
        self._seed()
        self.random_seed = random_seed
        if self.random_seed:
            self._seed(seed=random_seed)

        self.valid_positions = None

        # global numpy array of agents position, True means that there is an agent at that cell
        self.agent_positions: np.ndarray = np.full((height, width), False)

        # save episode timesteps ie agent positions, orientations.  (not yet actions / observations)
        self.record_steps = record_steps  # whether to save timesteps
        self.cur_episode = []  # save timesteps in here
コード例 #16
0
ファイル: run_testMCP.py プロジェクト: Jiaoyang-Li/Flatland
# We can now initialize the schedule generator with the given speed profiles
# (speed_ration_map is presumably defined earlier in the script, outside this excerpt).

schedule_generator = sparse_schedule_generator(speed_ration_map)

# We can furthermore pass stochastic data to the RailEnv constructor, which allows stochastic malfunctions
# during an episode.
stochastic_data = MalfunctionParameters(
    malfunction_rate=0,  # Rate of malfunction occurrence; NOTE(review): 0 presumably disables malfunctions - confirm
    min_duration=3,  # Minimal duration of malfunction
    max_duration=20  # Max duration of malfunction
)
print(stochastic_data)

# Custom observation builder without predictor
observation_builder = GlobalObsForRailEnv()

# Custom observation builder with predictor; uncomment the line below if you want to try this one
# observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())

# Construct the environment with the given observation builder, generators, predictors, and stochastic data
env = RailEnv(
    width=width,
    height=height,
    rail_generator=rail_generator,
    schedule_generator=schedule_generator,
    number_of_agents=nr_trains,
    malfunction_generator_and_process_data=malfunction_from_params(
        stochastic_data),
    obs_builder_object=observation_builder,
    remove_agents_at_target=
コード例 #17
0

def log_video(_images, epoch, fps=30):
    """Write a sequence of rendered frames to ``video_{epoch}.mp4``.

    Args:
        _images: Non-empty sequence of frames, each an array of shape
            ``(height, width, depth)``. Every frame is assumed to have the
            size of the first one. ``depth`` must be 4 (RGBA) or 3 (RGB).
        epoch: Value interpolated into the output file name.
        fps: Frame rate of the written video. Defaults to 30, the value that
            used to be hard-coded.

    Raises:
        ValueError: If ``_images`` is empty or the frames have a channel
            count other than 3 or 4.
    """
    if len(_images) == 0:
        raise ValueError("log_video needs at least one frame")
    height, width, depth = _images[0].shape
    print(len(_images), height, width, depth)
    if depth not in (3, 4):
        raise ValueError(
            f"unsupported channel count {depth}; expected 3 (RGB) or 4 (RGBA)")

    # OpenCV's VideoWriter consumes BGR frames, so convert straight to BGR.
    # Converting to RGB would swap the red and blue channels in the video.
    conversion = cv2.COLOR_RGBA2BGR if depth == 4 else cv2.COLOR_RGB2BGR

    out = cv2.VideoWriter(f'video_{epoch}.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    try:
        for image in _images:
            out.write(cv2.cvtColor(image, conversion))
    finally:
        # Release even if a frame fails, so the writer is not left open.
        out.release()


# Seed NumPy's global RNG for reproducible runs.
np.random.seed(1)

# Generate feasible network configurations (with matching tasks) using the
# complex_rail_generator. Small, simple tasks are the easiest way to get
# familiar with the environment.
obs_builder = GlobalObsForRailEnv()
env = RailEnv(
    width=20,
    height=20,
    rail_generator=complex_rail_generator(
        nr_start_goal=100,
        nr_extra=2,
        min_dist=8,
        max_dist=99999,
    ),
    schedule_generator=complex_schedule_generator(),
    obs_builder_object=obs_builder,
    number_of_agents=3,
)
env.reset()

env_renderer = RenderTool(env)

# Plug in your own agent here (or train one with RLlib). This example uses a
# controller that acts at random.
agent_kwargs = {"state_size": 0, "action_size": 5}
controller = RandomController(agent_kwargs["action_size"])

n_trials = 5
コード例 #18
0
def test_rail_environment_single_agent(show=False):
    """Drive a single agent around a looping track with random actions.

    The environment is loaded from ``test_env_loop.pkl`` (package
    ``env_data.tests``). For each of five resets the test:

    1. forces the agent to face direction 0 and sends random actions until it
       has changed cell six times, failing if that takes 100 steps;
    2. performs ten further resets, each sending random actions until the
       episode reports ``__all__`` done, failing after 100 steps.

    Args:
        show: When True, print debugging information and render each step.
    """
    # The test was written for this 3x3 map. Building it by hand did not work
    # reliably, so a pickled loop environment is loaded instead.
    # NOTE(review): confirm the pickled env matches this layout.
    #  _  _
    # / \/ \
    # | |  |
    # \_/\_/
    transitions = RailEnvTransitions()

    rail_env, env_dict = RailEnvPersister.load_new("test_env_loop.pkl", "env_data.tests")
    rail_map = rail_env.rail.grid

    rail_env._max_episode_steps = 1000

    _ = rail_env.reset(False, False, True)

    liActions = [int(a) for a in RailEnvActions]

    env_renderer = RenderTool(rail_env)

    for _ in range(5):
        _ = rail_env.reset(False, False, True)

        # The target is irrelevant for the first part of the test.
        agent = rail_env.agents[0]
        agent.target = [-1, -1]

        if show:
            print("After reset - agent pos:", agent.position, "dir: ", agent.direction)
            print(transitions.get_transitions(rail_map[agent.position], agent.direction))

        # HACK - force the direction to one we know is good.
        agent.initial_direction = agent.direction = 0

        if show:
            print("handle:", agent.handle)

        valid_active_actions_done = 0
        pos = agent.position

        if show:
            env_renderer.render_env(show=show, show_agents=True)
            time.sleep(0.01)

        iStep = 0
        while valid_active_actions_done < 6:
            # Pick a random action.
            action = np.random.choice(liActions)

            _, _, dict_done, _ = rail_env.step({0: action})

            prev_pos = pos
            pos = agent.position

            # Only moves that change the agent's cell count as progress.
            if prev_pos != pos:
                valid_active_actions_done += 1
            iStep += 1

            # Debug output is guarded by `show`, like the rendering.
            if show:
                print("action:", action, "pos:", agent.position, "prev:", prev_pos, agent.direction)
                print(dict_done)
                env_renderer.render_env(show=show, show_agents=True, step=iStep)
                time.sleep(0.01)
            assert iStep < 100, "valid actions should have been performed by now - hung agent"

        # NOTE: the check that the train returns to its starting row after six
        # moves is currently disabled.

        # The episode must finish within 100 random steps after each reset.
        for _ in range(10):
            _ = rail_env.reset()

            rail_env.agents[0].direction = 0

            iStep = 0
            while iStep < 100:
                action = np.random.choice(liActions)

                _, _, dones, _ = rail_env.step({0: action})
                done = dones['__all__']
                if done:
                    break
                iStep += 1
                assert iStep < 100, "agent should have finished by now"
                # NOTE(review): renders even when show is False, unlike the
                # loop above; kept to preserve existing behaviour.
                env_renderer.render_env(show=show)
コード例 #19
0
def test_seeding_and_malfunction():
    """Check that two environments reset with the same seed stay in lockstep.

    For each seed in 1..99, two identically configured ``RailEnv`` instances
    (global observations, 10 agents, same rail) are reset with that seed. They
    must produce:

    - the same initial positions;
    - the same rewards and done flags over ten steps of shared random actions;
    - the same final positions.
    """
    rail, rail_map = make_simple_rail2()

    # NOTE(review): these malfunction parameters are never passed to either
    # env, so malfunctions are not exercised by this test. Wiring them in
    # needs the env's malfunction API (e.g.
    # ``malfunction_generator_and_process_data``) -- TODO confirm and enable.
    stochastic_data = {
        'prop_malfunction': 0.4,
        'malfunction_rate': 2,
        'min_duration': 10,
        'max_duration': 10
    }
    for tests in range(1, 100):
        env = RailEnv(width=25,
                      height=30,
                      rail_generator=rail_from_grid_transition_map(rail),
                      schedule_generator=random_schedule_generator(),
                      number_of_agents=10,
                      obs_builder_object=GlobalObsForRailEnv())

        # Second, identically configured environment (also global observations).
        env2 = RailEnv(width=25,
                       height=30,
                       rail_generator=rail_from_grid_transition_map(rail),
                       schedule_generator=random_schedule_generator(),
                       number_of_agents=10,
                       obs_builder_object=GlobalObsForRailEnv())

        env.reset(True, False, True, random_seed=tests)
        env2.reset(True, False, True, random_seed=tests)

        # Both environments must place every agent identically.
        assert len(env.agents) == len(env2.agents) == 10
        for agent, agent2 in zip(env.agents, env2.agents):
            assert agent.initial_position == agent2.initial_position

        action_dict = {}
        for _step in range(10):
            for a in range(env.get_num_agents()):
                action_dict[a] = np.random.randint(4)

            _, reward1, done1, _ = env.step(action_dict)
            _, reward2, done2, _ = env2.step(action_dict)
            for a in range(env.get_num_agents()):
                assert reward1[a] == reward2[a]
                assert done1[a] == done2[a]

        # Both environments must end with every agent in the same cell.
        for agent, agent2 in zip(env.agents, env2.agents):
            assert agent.position == agent2.position
コード例 #20
0
ファイル: run_round1.py プロジェクト: Jiaoyang-Li/Flatland
                    'instance'] == folder + "_" + file_name:
                skip = True
                print("skip {} with {}".format(algo_name,
                                               folder + "_" + file_name))
        if skip:
            continue

        print("\n\n Instance " + folder + '/' + filename)

        #####################################################################
        # step loop information
        #####################################################################
        time_taken_by_controller = []
        time_taken_per_step = []
        steps = 0
        my_observation_builder = GlobalObsForRailEnv()

        # Construct the enviornment from file
        test = path + folder + '/' + filename
        local_env = RailEnv(
            width=1,
            height=1,
            rail_generator=rail_from_file(test),
            schedule_generator=schedule_from_file(test),
            malfunction_generator_and_process_data=malfunction_from_params(
                stochastic_data),
            obs_builder_object=GlobalObsForRailEnv(),
            remove_agents_at_target=True,
            record_steps=True)
        local_env.reset()
コード例 #21
0
def test_action_plan(rendering: bool = False):
    """Check ActionPlanReplayer: action plans built from train runs match and replay correctly.

    Agent 0 travels from (3, 0) to (3, 8). Agent 1 travels at half speed from
    (3, 8) to (0, 3). The controller derived from the hand-written train runs
    must produce the expected action plans, which are then replayed with a
    50-step budget.
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=77),
                  number_of_agents=2,
                  obs_builder_object=GlobalObsForRailEnv(),
                  remove_agents_at_target=True)
    env.reset()

    first = env.agents[0]
    first.initial_position = (3, 0)
    first.target = (3, 8)
    first.initial_direction = Grid4TransitionsEnum.WEST

    second = env.agents[1]
    second.initial_position = (3, 8)
    second.initial_direction = Grid4TransitionsEnum.WEST
    second.target = (0, 3)
    # Half speed: agent 1 spends two steps per cell (see its train run below).
    second.speed_data['speed'] = 0.5

    # Reset again now that the agents' initial state has been edited.
    env.reset(False, False, False)
    for handle, agent in enumerate(env.agents):
        print("[{}] {} -> {}".format(handle, agent.initial_position,
                                     agent.target))

    def _train_run(rows):
        """Build a train run from (scheduled_at, position, direction) triples."""
        return [
            TrainRunWayPoint(scheduled_at=at,
                             way_point=WayPoint(position=position, direction=direction))
            for at, position, direction in rows
        ]

    # Hand-written schedule: (time step, cell, direction) for each waypoint.
    chosen_path_dict = {
        0: _train_run([
            (0, (3, 0), 3),
            (2, (3, 1), 1),
            (3, (3, 2), 1),
            (14, (3, 3), 1),
            (15, (3, 4), 1),
            (16, (3, 5), 1),
            (17, (3, 6), 1),
            (18, (3, 7), 1),
            (19, (3, 8), 1),
            (20, (3, 8), 5),
        ]),
        1: _train_run([
            (0, (3, 8), 3),
            (3, (3, 7), 3),
            (5, (3, 6), 3),
            (7, (3, 5), 3),
            (9, (3, 4), 3),
            (11, (3, 3), 3),
            (13, (2, 3), 0),
            (15, (1, 3), 0),
            (17, (0, 3), 0),
        ]),
    }

    forward = RailEnvActions.MOVE_FORWARD
    stop = RailEnvActions.STOP_MOVING
    expected_action_plan = [
        [ActionPlanElement(step, action) for step, action in [
            (0, forward),   # take action to enter the grid
            (1, forward),   # take action to enter the cell properly
            (2, forward),
            (3, stop),
            (13, forward),
            (14, forward),
            (15, forward),
            (16, forward),
            (17, forward),
            (18, forward),
            (19, stop),
        ]],
        [ActionPlanElement(step, action) for step, action in [
            (0, forward),
            (1, forward),
            (3, forward),
            (5, forward),
            (7, forward),
            (9, forward),
            (11, RailEnvActions.MOVE_RIGHT),
            (13, forward),
            (15, forward),
            (17, stop),
        ]],
    ]

    max_episode_steps = 50

    deterministic_controller = ControllerFromTrainRuns(env, chosen_path_dict)
    deterministic_controller.print_action_plan()
    ControllerFromTrainRuns.assert_actions_plans_equal(
        expected_action_plan, deterministic_controller.action_plan)
    ControllerFromTrainRunsReplayer.replay_verify(
        max_episode_steps, deterministic_controller, env, rendering)
コード例 #22
0
def create_test_env(fnParams, nTest, sDir):
    """Generate test environment ``nTest`` and save it as ``{sDir}/Level_{nTest}.mpk``.

    Args:
        fnParams: Callable mapping ``nTest`` to the tuple ``(seed, width,
            height, nr_trains, nr_cities, max_rails_between_cities,
            max_rails_in_cities, malfunction_rate, malfunction_min_duration,
            malfunction_max_duration)``.
        nTest: Index of the test level to generate.
        sDir: Output directory. It is created if missing, and an existing
            file for the same level is overwritten.

    Returns:
        The generated ``RailEnv`` after ``reset(random_seed=seed)``.

    Raises:
        ValueError: If no environment could be generated in five attempts.
            Each failed attempt grows the grid by 5 cells in both dimensions.
    """
    (seed, width, height, nr_trains, nr_cities, max_rails_between_cities,
     max_rails_in_cities, malfunction_rate, malfunction_min_duration,
     malfunction_max_duration) = fnParams(nTest)

    rail_generator = sparse_rail_generator(
        max_num_cities=nr_cities,
        seed=seed,
        grid_mode=False,
        max_rails_between_cities=max_rails_between_cities,
        max_rails_in_city=max_rails_in_cities,
    )

    stochastic_data = MalfunctionParameters(
        malfunction_rate=malfunction_rate,
        min_duration=malfunction_min_duration,
        max_duration=malfunction_max_duration)

    observation_builder = GlobalObsForRailEnv()

    # Each of the four train speeds gets weight 0.25.
    DEFAULT_SPEED_RATIO_MAP = {
        1.: 0.25,
        1. / 2.: 0.25,
        1. / 3.: 0.25,
        1. / 4.: 0.25
    }

    schedule_generator = sparse_schedule_generator(DEFAULT_SPEED_RATIO_MAP)

    # A ValueError during construction/reset is answered by retrying on a
    # larger grid.
    last_error = None
    for _attempt in range(5):
        try:
            env = RailEnv(
                width=width,
                height=height,
                rail_generator=rail_generator,
                schedule_generator=schedule_generator,
                number_of_agents=nr_trains,
                malfunction_generator_and_process_data=malfunction_from_params(
                    stochastic_data),
                obs_builder_object=observation_builder,
                remove_agents_at_target=True)
            env.reset(random_seed=seed)
            break
        except ValueError as oErr:
            last_error = oErr
            print("Error:", oErr)
            width += 5
            height += 5
            print("Try again with larger env: (w,h):", width, height)
    else:
        # Without this, the code below would hit an unbound (or half-built) env.
        raise ValueError(
            "could not generate test env {} after 5 attempts".format(nTest)
        ) from last_error

    os.makedirs(sDir, exist_ok=True)

    sfName = "{}/Level_{}.mpk".format(sDir, nTest)
    if os.path.exists(sfName):
        os.remove(sfName)
    env.save(sfName)

    sys.stdout.write(".")
    sys.stdout.flush()

    return env
コード例 #23
0
def tests_rail_from_file():
    """Round-trip environments through ``env.save`` and ``rail_from_file``.

    Covers four cases:

    1. Saved with a distance map (tree observation + predictor) and loaded the
       same way: rails and agents match, and the distance map keeps its shape.
    2. Saved without a distance map (global observation) and loaded the same
       way.
    3. Saved with a distance map and loaded with a global observation.
    4. Saved without a distance map and loaded with a tree observation, which
       must produce a distance map of the same shape as in case 1.

    Writes ``test_with_distance_map.pkl`` and ``test_without_distance_map.pkl``
    to the current working directory.
    """
    file_name = "test_with_distance_map.pkl"

    # Case 1: save and load a file with a distance map.
    rail, rail_map = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=3,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()
    env.save(file_name)
    dist_map_shape = np.shape(env.distance_map.get())
    rails_initial = env.rail.grid
    agents_initial = env.agents

    env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
                  schedule_generator=schedule_from_file(file_name), number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()
    rails_loaded = env.rail.grid
    agents_loaded = env.agents

    assert np.array_equal(rails_initial, rails_loaded)
    assert agents_initial == agents_loaded

    # The distance map must not be recomputed with a different shape.
    assert np.shape(env.distance_map.get()) == dist_map_shape
    assert env.distance_map.get() is not None

    # Case 2: save and load a file without a distance map.
    file_name_2 = "test_without_distance_map.pkl"

    env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(),
                   number_of_agents=3, obs_builder_object=GlobalObsForRailEnv())
    env2.reset()
    env2.save(file_name_2)

    rails_initial_2 = env2.rail.grid
    agents_initial_2 = env2.agents

    env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2),
                   schedule_generator=schedule_from_file(file_name_2), number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv())
    env2.reset()
    rails_loaded_2 = env2.rail.grid
    agents_loaded_2 = env2.agents

    assert np.array_equal(rails_initial_2, rails_loaded_2)
    assert agents_initial_2 == agents_loaded_2
    assert not hasattr(env2.obs_builder, "distance_map")

    # Case 3: save with a distance map, load without one.
    env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
                   schedule_generator=schedule_from_file(file_name), number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv())
    env3.reset()
    rails_loaded_3 = env3.rail.grid
    agents_loaded_3 = env3.agents

    assert np.array_equal(rails_initial, rails_loaded_3)
    assert agents_initial == agents_loaded_3
    # Check env3's own builder; this used to re-check env2 by copy-paste mistake.
    assert not hasattr(env3.obs_builder, "distance_map")

    # Case 4: save without a distance map, load while generating one.
    env4 = RailEnv(width=1,
                   height=1,
                   rail_generator=rail_from_file(file_name_2),
                   schedule_generator=schedule_from_file(file_name_2),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2),
                   )
    env4.reset()
    rails_loaded_4 = env4.rail.grid
    agents_loaded_4 = env4.agents

    # Check that no distance map was saved.
    # NOTE(review): this repeats the env2 check above and does not inspect env4.
    assert not hasattr(env2.obs_builder, "distance_map")
    assert np.array_equal(rails_initial_2, rails_loaded_4)
    assert agents_initial_2 == agents_loaded_4

    # A distance map must have been generated with the expected shape.
    assert env4.distance_map.get() is not None
    assert np.shape(env4.distance_map.get()) == dist_map_shape
コード例 #24
0
ファイル: this_works.py プロジェクト: CatLads/Notebooks
# Rail generator for the episode. For both nr_start_goal (number of start/goal
# connections) and nr_extra (extra connections, useful for alternative paths),
# higher values should make the task easier for the trains.
random_rail_generator = complex_rail_generator(
    nr_start_goal=10,  # @param{type:"integer"}
    nr_extra=10,  # @param{type:"integer"}
    min_dist=10,
    max_dist=99999,
    seed=seed,
)
from flatland.utils.rendertools import RenderTool

env = OurEnv(
    width=width,
    height=height,
    rail_generator=random_rail_generator,
    obs_builder_object=GlobalObsForRailEnv(),
    number_of_agents=agents,
)
env_renderer = RenderTool(env)

# The env must be reset once to build its first step. The result is bound to
# `_` only to keep the notebook from echoing it.
_ = env.reset()
"""To render the env we use RenderTool. I think `gl="PILSVG"` is the lib used to actually render, using the default one doesn't work.

The function `render_env(env)` shows the env using matplotlib.
"""

import matplotlib.pyplot as plt
"""Let's perform a basic action for each train.

The actions are (as defined [here](http://flatland-rl-docs.s3-website.eu-central-1.amazonaws.com/04_specifications.html#action-space)):
* 0 Do Nothing: If the agent is moving it continues moving, if it is stopped it stays stopped
コード例 #25
0
def _normalize_tree_obs(observations, num_agents, tree_depth):
    """Replace each agent's raw observation with one clipped, normalised vector (in place)."""
    # NOTE(review): run_test builds its env with GlobalObsForRailEnv, but the
    # observations are split as trees here -- confirm the intended builder.
    for handle in range(num_agents):
        data, distance, agent_data = split_tree_into_feature_groups(
            observations[handle], tree_depth)
        data = norm_obs_clip(data)
        distance = norm_obs_clip(distance)
        agent_data = np.clip(agent_data, -1, 1)
        observations[handle] = np.concatenate((data, distance, agent_data))
    return observations


def run_test(parameters, agent, test_nr=0, tree_depth=3):
    """Evaluate ``agent`` on a generated rail env while line-profiling reset/step.

    Args:
        parameters: Sequence ``(width, height, n_agents, seed)``.
        agent: Object whose ``act(obs, eps=0)`` returns an action.
        test_nr: Test index, used only in the progress message.
        tree_depth: Depth passed to ``split_tree_into_feature_groups``.

    Returns:
        A tuple ``(test_scores, test_dones, tot_test_time)``:

        - ``test_scores``: per trial, the summed per-agent mean reward
          divided by ``max_steps``;
        - ``test_dones``: per trial, the final ``done['__all__']`` flag;
        - ``tot_test_time``: total wall-clock seconds.
    """
    lp = LineProfiler()
    start_time_scoring = time.time()
    action_dict = dict()
    nr_trials_per_test = 5
    print('Running Test {} with (x_dim,y_dim) = ({},{}) and {} Agents.'.format(
        test_nr, parameters[0], parameters[1], parameters[2]))

    # Two-frame observation history; each agent sees both frames concatenated.
    time_obs = deque(maxlen=2)
    test_scores = []
    test_dones = []

    random.seed(parameters[3])
    np.random.seed(parameters[3])
    nr_paths = max(2, parameters[2] + int(0.5 * parameters[2]))
    min_dist = int(min([parameters[0], parameters[1]]) * 0.75)
    env = RailEnv(width=parameters[0],
                  height=parameters[1],
                  rail_generator=complex_rail_generator(nr_start_goal=nr_paths,
                                                        nr_extra=5,
                                                        min_dist=min_dist,
                                                        max_dist=99999,
                                                        seed=parameters[3]),
                  schedule_generator=complex_schedule_generator(),
                  obs_builder_object=GlobalObsForRailEnv(),
                  number_of_agents=parameters[2])
    max_steps = int(3 * (env.height + env.width))
    lp_step = lp(env.step)
    lp_reset = lp(env.reset)

    printProgressBar(0,
                     nr_trials_per_test,
                     prefix='Progress:',
                     suffix='Complete',
                     length=20)
    for trial in range(nr_trials_per_test):
        # Reset once, through the profiled wrapper, and use its result.
        # Previously the env was reset twice per trial and the profiled
        # reset's output was thrown away.
        obs, _info = lp_reset(True, True)
        # Count agents after the reset, when the agent list is populated.
        n_agents = env.get_num_agents()
        _normalize_tree_obs(obs, n_agents, tree_depth)

        # Seed the history with the initial observation twice.
        time_obs.append(obs)
        time_obs.append(obs)
        agent_obs = [np.concatenate((time_obs[0][a], time_obs[1][a]))
                     for a in range(n_agents)]

        # Run episode
        trial_score = 0
        for _step in range(max_steps):
            for a in range(n_agents):
                action_dict[a] = agent.act(agent_obs[a], eps=0)

            next_obs, all_rewards, done, _ = lp_step(action_dict)

            _normalize_tree_obs(next_obs, n_agents, tree_depth)
            time_obs.append(next_obs)
            for a in range(n_agents):
                agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
                trial_score += all_rewards[a] / n_agents

            if done['__all__']:
                break
        test_scores.append(trial_score / max_steps)
        test_dones.append(done['__all__'])
        printProgressBar(trial + 1,
                         nr_trials_per_test,
                         prefix='Progress:',
                         suffix='Complete',
                         length=20)
    end_time_scoring = time.time()
    tot_test_time = end_time_scoring - start_time_scoring
    lp.print_stats()
    return test_scores, test_dones, tot_test_time