Example #1
0
def demo(args=None, max_episodes=None):
    """Demo script to check installation.

    Builds a 15x15 ``RailEnv`` with 5 agents and repeatedly plays episodes in
    which every agent takes a random action each step. Every step is rendered
    on screen, followed by a 0.3 s pause.

    Args:
        args: Not used by this function; accepted for interface compatibility.
        max_episodes: Number of episodes to play before returning. The default
            ``None`` loops forever.

    Returns:
        0 once ``max_episodes`` episodes have been played.
    """
    env = RailEnv(width=15,
                  height=15,
                  rail_generator=complex_rail_generator(nr_start_goal=10,
                                                        nr_extra=1,
                                                        min_dist=8,
                                                        max_dist=99999),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=5)

    # Cap episode length relative to the map size (15 * (15 + 15) = 450 steps).
    # NOTE(review): presumably the env reports done['__all__'] once this cap is
    # hit, which is what bounds the inner loop below -- confirm.
    env._max_episode_steps = int(15 * (env.width + env.height))
    env_renderer = RenderTool(env)

    episodes_played = 0
    while max_episodes is None or episodes_played < max_episodes:
        env.reset()
        all_done = False
        while not all_done:
            # One random action per agent, keyed by agent index. randint's upper
            # bound is exclusive, so actions are drawn from 0..4.
            actions = {handle: np.random.randint(0, 5)
                       for handle in range(len(env.agents))}
            _, _, done, _ = env.step(actions)
            all_done = done['__all__']
            env_renderer.render_env(show=True,
                                    frames=False,
                                    show_observations=False,
                                    show_predictions=False)
            time.sleep(0.3)
        episodes_played += 1
    return 0
Example #2
0
def test_rail_environment_single_agent(show=False):
    """Drive a single agent with random actions on a small looped rail map.

    The environment is loaded with ``RailEnvPersister.load_new`` from
    ``test_env_loop.pkl``. The second argument, ``"env_data.tests"``, is
    presumably the resource package holding the file. Then, for each of 5
    rounds:

    1. Reset the env, overwrite the agent's target with ``[-1, -1]``, force
       its direction to 0, and take random actions until the agent's position
       has changed 6 times. Fails if 100 steps are reached first (hung agent).
    2. Ten times: reset the env and take random actions until
       ``dones['__all__']`` is set. Fails on the 100th consecutive unfinished
       step.

    Args:
        show: If true, render phase-1 steps on screen and print debug output.
    """
    # Intended map on a 3x3 grid:
    #  _  _
    # / \/ \
    # | |  |
    # \_/\_/
    # Building it by hand from RailEnvTransitions cells "didn't quite work
    # right", so the test loads a saved env instead -- presumably of this
    # layout; TODO confirm against test_env_loop.pkl.
    transitions = RailEnvTransitions()  # only used by the show=True debug print

    rail_env, _ = RailEnvPersister.load_new("test_env_loop.pkl", "env_data.tests")
    rail_map = rail_env.rail.grid
    rail_env._max_episode_steps = 1000

    # NOTE(review): positional flags presumably mean (regenerate_rail=False,
    # regenerate_schedule=False, activate_agents=True) -- confirm.
    rail_env.reset(False, False, True)

    actions = [int(a) for a in RailEnvActions]
    env_renderer = RenderTool(rail_env)

    # Step budget shared by both phases.
    max_steps = 100

    for _ in range(5):
        rail_env.reset(False, False, True)

        # Overwrite the target; this phase does not care about reaching it.
        agent = rail_env.agents[0]
        agent.target = [-1, -1]

        if show:
            print("After reset - agent pos:", agent.position, "dir: ", agent.direction)
            print(transitions.get_transitions(rail_map[agent.position], agent.direction))

        # HACK - force the direction to one we know is good. (A check that the
        # reset direction always has an outgoing transition is not enforced.)
        agent.initial_direction = agent.direction = 0

        if show:
            print("handle:", agent.handle)
            env_renderer.render_env(show=show, show_agents=True)
            time.sleep(0.01)

        # Phase 1: random actions until the agent has changed position 6 times.
        moves = 0
        step = 0
        pos = agent.position
        while moves < 6:
            action = np.random.choice(actions)
            _, _, dones, _ = rail_env.step({0: action})

            prev_pos, pos = pos, agent.position
            if prev_pos != pos:
                moves += 1
            step += 1

            if show:
                print("action:", action, "pos:", agent.position, "prev:", prev_pos, agent.direction)
                print(dones)
                env_renderer.render_env(show=show, show_agents=True, step=step)
                time.sleep(0.01)
            assert step < max_steps, "valid actions should have been performed by now - hung agent"

        # Phase 2: the train should finish within the step budget after a reset.
        for _ in range(10):
            # NOTE(review): reset() with default arguments presumably also
            # restores the target overwritten in phase 1 -- confirm.
            rail_env.reset()
            rail_env.agents[0].direction = 0

            for step in range(1, max_steps + 1):
                action = np.random.choice(actions)
                _, _, dones, _ = rail_env.step({0: action})
                if dones['__all__']:
                    break
                assert step < max_steps, "agent should have finished by now"
                # Called on every unfinished step regardless of `show`
                # (phase 1 only renders when `show` is set).
                env_renderer.render_env(show=show)