Example #1
    def reset(self):
        obs, info = self.rail_env.reset(regenerate_rail=self._regenerate_rail_on_reset,
                                        regenerate_schedule=self._regenerate_schedule_on_reset)
        # Reset rendering
        if self.render:
            self.env_renderer = RenderTool(self.rail_env, gl="PGL")
            self.env_renderer.set_new_rail()

        # Reset custom observations
        self.observation_normalizer.reset_custom_obs(self.rail_env)

        # Compute deadlocks
        self.deadlocks_detector.reset(self.rail_env.get_num_agents())
        info["deadlocks"] = {}

        for agent in range(self.rail_env.get_num_agents()):
            info["deadlocks"][agent] = self.deadlocks_detector.deadlocks[agent]

        # Normalization
        for agent in obs:
            if obs[agent] is not None:
                obs[agent] = self.observation_normalizer.normalize_observation(obs[agent], self.rail_env,
                                                                               agent, info["deadlocks"][agent])

        return obs, info
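All of the examples on this page share the same basic pattern: build a RailEnv, wrap it in a RenderTool, call render_env(), and optionally grab the rendered frame with get_image(). For orientation, here is a minimal self-contained sketch of that pattern; it assumes a flatland 2.x-style API, and the generator parameters are illustrative rather than taken from any example below.

from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.utils.rendertools import RenderTool

env = RailEnv(width=25,
              height=25,
              rail_generator=sparse_rail_generator(max_num_cities=2),
              number_of_agents=1)
env.reset()

renderer = RenderTool(env, gl="PILSVG")  # "PILSVG" and "PGL" are the backends used below
renderer.render_env(show=True, show_observations=False)
frame = renderer.get_image()  # RGB array of the current frame
renderer.close_window()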
Example #2
File: cli.py  Project: hagrid67/flatland
def demo(args=None):
    """Demo script to check installation"""
    env = RailEnv(width=15,
                  height=15,
                  rail_generator=complex_rail_generator(nr_start_goal=10,
                                                        nr_extra=1,
                                                        min_dist=8,
                                                        max_dist=99999),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=5)

    env._max_episode_steps = int(15 * (env.width + env.height))
    env_renderer = RenderTool(env)

    while True:
        obs, info = env.reset()
        _done = False
        # Run a single episode here
        step = 0
        while not _done:
            # Compute Action
            _action = {}
            for _idx, _ in enumerate(env.agents):
                _action[_idx] = np.random.randint(0, 5)
            obs, all_rewards, done, _ = env.step(_action)
            _done = done['__all__']
            step += 1
            env_renderer.render_env(show=True,
                                    frames=False,
                                    show_observations=False,
                                    show_predictions=False)
            time.sleep(0.3)
    return 0
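The random policy above samples np.random.randint(0, 5) because RailEnv exposes five discrete actions (flatland's RailEnvActions: 0 = DO_NOTHING, 1 = MOVE_LEFT, 2 = MOVE_FORWARD, 3 = MOVE_RIGHT, 4 = STOP_MOVING). The same action space is why later examples hard-code action_size = 5 and force an agent to halt by assigning action 4.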
Example #3
def render_env(env, fname):
    env_renderer = RenderTool(env, gl="PGL")
    env_renderer.render_env()

    image = env_renderer.get_image()
    pil_image = PIL.Image.fromarray(image)
    pil_image.save(fname)
Example #4
def test_path_not_exists(rendering=False):
    rail, rail_map = make_simple_rail_unconnected()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
    )
    env.reset()

    check_path(
        env,
        rail,
        (5, 6),  # south dead-end
        0,  # north
        (0, 3),  # north dead-end
        False)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")
Example #5
    def reset_renderer(self):
        self.renderer = RenderTool(self.flatland_env,
                                   gl="PILSVG",
                                   agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
                                   show_debug=True,
                                   screen_height=700,
                                   screen_width=1300)
Example #6
def solve(env, width, height, naive, predictor):
    env_renderer = RenderTool(env)
    solver = r2_solver.Solver(1)

    obs, _ = env.reset()
    env.obs_builder.find_safe_edges(env)

    predictor.env = env
    predictor.get()
    for step in range(100):

        # print(obs)
        # print(obs.shape)

        if naive:
            _action = naive_solver(env, obs)
        else:
            _action = solver.GetMoves(env.agents, obs)

        obs_paths = TL_detector(env, obs, _action)
        for k in obs_paths.keys():
            if obs_paths[k] is not None and improved_solver(obs_paths[k]) == 0:
                _action[k] = 4

        for k in _action.keys():
            if env.agents[k].position is None:
                continue

            pos = (env.agents[k].position[0], env.agents[k].position[1],
                   env.agents[k].direction)
            if _action[k] != 0 and _action[k] != 4 and pos in env.dev_pred_dict[k]:
                env.dev_pred_dict[k].remove(pos)

        next_obs, all_rewards, done, _ = env.step(_action)

        print("Rewards: {}, [done={}]".format(all_rewards, done))
        img = env_renderer.render_env(show=True,
                                      show_inactive_agents=False,
                                      show_predictions=True,
                                      show_observations=False,
                                      frames=True,
                                      return_image=True)
        cv2.imwrite("./env_images/" + str(step).zfill(3) + ".jpg", img)

        obs = next_obs.copy()
        if obs is None or done['__all__']:
            break

    unfinished_agents = []
    for k in done.keys():
        if not done[k] and type(k) is int:
            unfinished_agents.append(k)

    with open('observations_and_agents.pickle', 'wb') as f:
        pickle.dump((env.obs_builder.obs_dict, unfinished_agents,
                     env.obs_builder.branches, env.obs_builder.safe_map), f)
    return
Example #7
def render_env(env):
    env_renderer = RenderTool(env, gl="PGL")
    env_renderer.render_env()

    image = env_renderer.get_image()
    pil_image = PIL.Image.fromarray(image)
    #print("RENDER")
    #pil_image.show()
    images.append(pil_image)
    print(len(images))
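Example #7 only accumulates the rendered frames in the module-level images list. A possible follow-up, sketched here under the assumption that images holds those PIL frames, is to write them out as an animated GIF with Pillow:

if images:
    images[0].save("episode.gif",
                   save_all=True,
                   append_images=images[1:],  # remaining frames
                   duration=300,  # milliseconds per frame
                   loop=0)  # 0 = repeat forever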
Example #8
    def initialize_renderer(self, mode="human"):
        # Initiate the renderer
        from flatland.utils.rendertools import RenderTool, AgentRenderVariant
        self.renderer = RenderTool(
            self,
            gl="PGL",  # gl="TKPILSVG",
            agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
            show_debug=False,
            screen_height=600,  # Adjust these parameters to fit your resolution
            screen_width=800)  # Adjust these parameters to fit your resolution
Example #9
def test_shortest_path_predictor_conflicts(rendering=False):
    rail, rail_map = make_invalid_simple_rail()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=2,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
    )
    env.reset()

    # set the initial position
    agent = env.agents[0]
    agent.initial_position = (5, 6)  # south dead-end
    agent.position = (5, 6)  # south dead-end
    agent.direction = 0  # north
    agent.initial_direction = 0  # north
    agent.target = (3, 9)  # east dead-end
    agent.moving = True
    agent.status = RailAgentStatus.ACTIVE

    agent = env.agents[1]
    agent.initial_position = (3, 8)  # east dead-end
    agent.position = (3, 8)  # east dead-end
    agent.direction = 3  # west
    agent.initial_direction = 3  # west
    agent.target = (6, 6)  # south dead-end
    agent.moving = True
    agent.status = RailAgentStatus.ACTIVE

    observations, info = env.reset(False, False, True)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # get the trees to test
    obs_builder: TreeObsForRailEnv = env.obs_builder
    pp = pprint.PrettyPrinter(indent=4)
    tree_0 = observations[0]
    tree_1 = observations[1]
    env.obs_builder.util_print_obs_subtree(tree_0)
    env.obs_builder.util_print_obs_subtree(tree_1)

    # check the expectations
    expected_conflicts_0 = [('F', 'R')]
    expected_conflicts_1 = [('F', 'L')]
    _check_expected_conflicts(expected_conflicts_0, obs_builder, tree_0,
                              "agent[0]: ")
    _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1,
                              "agent[1]: ")
Example #10
class OurEnv(RailEnv):
    def reset(self, *args, **kwargs):
        return_val = super().reset(*args, **kwargs)
        self.env_renderer = RenderTool(self)
        self.step({0: RailEnvActions.MOVE_FORWARD})
        return return_val

    def step(self, *args, **kwargs):
        self.env_renderer.render_env(show=True)
        print(args[0])
        observation, reward, done, info = super().step(*args, **kwargs)
        return observation, reward, done["__all__"], info
Example #11
    def render(self, mode='human'):
        # TODO: Merge both strategies (Jupyter vs .py)
        # In .py files
        # self.renderer.render_env(show=False, show_observations=False, show_predictions=False)
        # In Jupyter Notebooks
        env_renderer = RenderTool(self.flatland_env, gl="PILSVG")
        env_renderer.render_env()

        image = env_renderer.get_image()
        pil_image = Image.fromarray(image)
        display(pil_image)
        return image
Example #12
def render_test(parameters, test_nr=0, nr_examples=5):
    for trial in range(nr_examples):
        # Reset the env
        print(
            'Showing {} Level {} with (x_dim,y_dim) = ({},{}) and {} Agents.'.
            format(test_nr, trial, parameters[0], parameters[1],
                   parameters[2]))
        file_name = "./Tests/{}/Level_{}.pkl".format(test_nr, trial)

        env = RailEnv(
            width=1,
            height=1,
            rail_generator=rail_from_file(file_name),
            obs_builder_object=TreeObsForRailEnv(max_depth=2),
            number_of_agents=1,
        )
        env_renderer = RenderTool(
            env,
            gl="PILSVG",
        )
        env_renderer.set_new_rail()

        env.reset(False, False)
        env_renderer.render_env(show=True, show_observations=False)

        time.sleep(0.1)
        env_renderer.close_window()
    return
Example #13
def evaluate(n_episodes):
    run = SUBMISSIONS["rlps-tcpr"]
    config, run = init_run(run)
    agent = ShortestPathRllibAgent(get_agent(config, run))
    env = get_env(config, rl=True)
    env_renderer = RenderTool(env, screen_width=8800)
    returns = []
    pcs = []
    malfs = []

    for _ in tqdm(range(n_episodes)):

        obs, _ = env.reset(regenerate_schedule=True, regenerate_rail=True)
        if RENDER:
            env_renderer.reset()
            env_renderer.render_env(show=True,
                                    frames=True,
                                    show_observations=False)

        if not obs:
            break

        steps = 0
        ep_return = 0
        done = defaultdict(lambda: False)
        robust_env = RobustFlatlandGymEnv(rail_env=env,
                                          max_nr_active_agents=200,
                                          observation_space=None,
                                          priorizer=DistToTargetPriorizer(),
                                          allow_noop=True)

        sorted_handles = robust_env.priorizer.priorize(handles=list(obs.keys()),
                                                       rail_env=env)

        while not done['__all__']:
            actions = agent.compute_actions(obs, env)
            robust_actions = robust_env.get_robust_actions(
                actions, sorted_handles)
            obs, all_rewards, done, info = env.step(robust_actions)
            if RENDER:
                env_renderer.render_env(show=True,
                                        frames=True,
                                        show_observations=False)
            print('.', end='', flush=True)
            steps += 1
            ep_return += np.sum(list(all_rewards.values()))

        pc = np.sum(np.array([1 for a in env.agents if is_done(a)
                              ])) / env.get_num_agents()
        print("EPISODE PC:", pc)
        n_episodes += 1
        pcs.append(pc)
        returns.append(ep_return /
                       (env._max_episode_steps * env.get_num_agents()))
        malfs.append(
            np.sum([a.malfunction_data['nr_malfunctions']
                    for a in env.agents]))
    return pcs, returns, malfs
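Note how the episode return is normalised before being stored: ep_return is divided by env._max_episode_steps * env.get_num_agents(), giving a per-agent, per-step figure that stays comparable across environments of different size and agent count. The second evaluate variant further down applies the same normalisation.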
Example #14
    def render(self, **kwargs):
        from flatland.utils.rendertools import RenderTool

        if not self.env_renderer:
            self.env_renderer = RenderTool(self.env, gl="PILSVG")
            self.env_renderer.set_new_rail()
        self.env_renderer.render_env(show=True,
                                     frames=False,
                                     show_observations=False,
                                     **kwargs)
        time.sleep(0.1)
        self.env_renderer.render_env(show=True,
                                     frames=False,
                                     show_observations=False,
                                     **kwargs)
        return self.env_renderer.get_image()
Example #15
def check_path(env,
               rail,
               position,
               direction,
               target,
               expected,
               rendering=False):
    agent = env.agents[0]
    agent.position = position  # south dead-end
    agent.direction = direction  # north
    agent.target = target  # east dead-end
    agent.moving = True
    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")
    assert rail.check_path_exists(agent.position, agent.direction,
                                  agent.target) == expected
Example #16
    def get_renderer(self):
        '''
        Return a renderer for the current environment
        '''
        return RenderTool(self,
                          agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
                          show_debug=True,
                          screen_height=1080,
                          screen_width=1920)
Example #17
File: ENV.py  Project: Zeii2024/RL
    def env_renderer(self, env):
        # RailEnv.DEPOT_POSITION = lambda agent, agent_handle : (agent_handle % env.height,0)
        # To show the screen

        env_renderer = RenderTool(env,
                                  gl="PILSVG",
                                  agent_render_variant=AgentRenderVariant.
                                  AGENT_SHOWS_OPTIONS_AND_BOX,
                                  show_debug=True,
                                  screen_height=800,
                                  screen_width=800)
        return env_renderer
Example #18
def createEnvSet(nStart, nEnd, sDir, bSmall=True):
    # print("Generate small envs in train-envs-small:")
    print(f"Generate envs (small={bSmall}) in dir {sDir}:")

    sDirImages = "train-envs-small/images/"
    if not os.path.exists(sDirImages):
        os.makedirs(sDirImages)

    for test_id in range(nStart, nEnd, 1):
        env = create_test_env(RandomTestParams_small, test_id, sDir)

        oRender = RenderTool(env, gl="PILSVG")

        # oRender.envs = envs
        # oRender.set_new_rail()
        oRender.render_env()
        g2img = oRender.get_image()
        imgPIL = Image.fromarray(g2img)
        # imgPIL.show()

        imgPIL.save(sDirImages + "Level_{}.png".format(test_id))
Example #19
    def replay_verify(max_episode_steps: int, ctl: ControllerFromTrainRuns,
                      env: RailEnv, rendering: bool):
        """Replays this deterministic `ActionPlan` and verifies whether it is feasible."""
        if rendering:
            renderer = RenderTool(env,
                                  gl="PILSVG",
                                  agent_render_variant=AgentRenderVariant.
                                  AGENT_SHOWS_OPTIONS_AND_BOX,
                                  show_debug=True,
                                  clear_debug_text=True,
                                  screen_height=1000,
                                  screen_width=1000)
            renderer.render_env(show=True,
                                show_observations=False,
                                show_predictions=False)
        i = 0
        while not env.dones['__all__'] and i <= max_episode_steps:
            for agent_id, agent in enumerate(env.agents):
                way_point: WayPoint = ctl.get_way_point_before_or_at_step(
                    agent_id, i)
                assert agent.position == way_point.position, \
                    "before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
                                                                    way_point.position)
            actions = ctl.act(i)
            print("actions for {}: {}".format(i, actions))

            obs, all_rewards, done, _ = env.step(actions)

            if rendering:
                renderer.render_env(show=True,
                                    show_observations=False,
                                    show_predictions=False)

            i += 1
Example #20
def evaluate(n_episodes):
    run = SUBMISSIONS["ato"]
    config, run = init_run(run)
    agent = get_agent(config, run)
    env = get_env(config, rl=True)
    env_renderer = RenderTool(env, screen_width=8800)
    returns = []
    pcs = []
    malfs = []

    for _ in tqdm(range(n_episodes)):

        obs, _ = env.reset(regenerate_schedule=True, regenerate_rail=True)
        if RENDER:
            env_renderer.reset()
            env_renderer.render_env(show=True,
                                    frames=True,
                                    show_observations=False)

        if not obs:
            break

        steps = 0
        ep_return = 0
        done = defaultdict(lambda: False)

        while not done['__all__']:
            actions = agent.compute_actions(obs, explore=False)
            obs, all_rewards, done, info = env.step(actions)
            if RENDER:
                env_renderer.render_env(show=True,
                                        frames=True,
                                        show_observations=False)
            print('.', end='', flush=True)
            steps += 1
            ep_return += np.sum(list(all_rewards.values()))

        pc = np.sum(np.array([1 for a in env.agents if is_done(a)
                              ])) / env.get_num_agents()
        print("EPISODE PC:", pc)
        n_episodes += 1
        pcs.append(pc)
        returns.append(ep_return /
                       (env._max_episode_steps * env.get_num_agents()))
        malfs.append(
            np.sum([a.malfunction_data['nr_malfunctions']
                    for a in env.agents]))
    return pcs, returns, malfs
Example #21
    def __init__(self,
                 n_cars=3,
                 n_acts=5,
                 min_obs=-1,
                 max_obs=1,
                 n_nodes=2,
                 ob_radius=10,
                 x_dim=36,
                 y_dim=36,
                 feats='all'):

        self.tree_obs = tree_observation.TreeObservation(n_nodes)
        self.n_cars = n_cars
        self.n_nodes = n_nodes
        self.ob_radius = ob_radius
        self.feats = feats

        rail_gen = sparse_rail_generator(max_num_cities=3,
                                         seed=666,
                                         grid_mode=False,
                                         max_rails_between_cities=2,
                                         max_rails_in_city=3)

        self._rail_env = RailEnv(
            width=x_dim,
            height=y_dim,
            rail_generator=rail_gen,
            schedule_generator=sparse_schedule_generator(speed_ration_map),
            number_of_agents=n_cars,
            malfunction_generator_and_process_data=malfunction_from_params(
                stochastic_data),
            obs_builder_object=self.tree_obs)

        self.renderer = RenderTool(self._rail_env, gl="PILSVG")
        self.action_dict = dict()
        self.info = dict()
        self.old_obs = dict()
Example #22
def main(args):
    try:
        opts, args = getopt.getopt(args, "", ["sleep-for-animation=", ""])
    except getopt.GetoptError as err:
        print(str(err))  # will print something like "option -a not recognized"
        sys.exit(2)
    sleep_for_animation = True
    for o, a in opts:
        if o in ("--sleep-for-animation",):
            sleep_for_animation = str2bool(a)
        else:
            assert False, "unhandled option"

    # Initiate the Predictor
    custom_predictor = ShortestPathPredictorForRailEnv(10)

    # Pass the Predictor to the observation builder
    custom_obs_builder = ObservePredictions(custom_predictor)

    # Initiate Environment
    env = RailEnv(width=10,
                  height=10,
                  rail_generator=complex_rail_generator(nr_start_goal=5,
                                                        nr_extra=1,
                                                        min_dist=8,
                                                        max_dist=99999,
                                                        seed=1),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=3,
                  obs_builder_object=custom_obs_builder)

    obs, info = env.reset()
    env_renderer = RenderTool(env, gl="PILSVG")

    # We render the initial step and show the observed cells as colored boxes
    env_renderer.render_env(show=True,
                            frames=True,
                            show_observations=True,
                            show_predictions=False)

    action_dict = {}
    for step in range(100):
        for a in range(env.get_num_agents()):
            action = np.random.randint(0, 5)
            action_dict[a] = action
        obs, all_rewards, done, _ = env.step(action_dict)
        print("Rewards: ", all_rewards, "  [done=", done, "]")
        env_renderer.render_env(show=True,
                                frames=True,
                                show_observations=True,
                                show_predictions=False)
        if sleep_for_animation:
            time.sleep(0.5)
Example #23
def env_gradual_update(input_env, agent=False, hardness_lvl=1):

    agent_num = input_env.number_of_agents
    env_width = input_env.width + 4
    env_height = input_env.height + 4

    map_agent_ratio = int(np.round(((env_width + env_height) / 2) / 5 - 2))

    if map_agent_ratio > 0:
        agent_num = int(np.round(((env_width + env_height) / 2) / 5 - 2))
    else:
        agent_num = 1

    if hardness_lvl == 1:

        rail_generator = complex_rail_generator(nr_start_goal=20,
                                                nr_extra=1,
                                                min_dist=9,
                                                max_dist=99999,
                                                seed=0)

        schedule_generator = complex_schedule_generator()
    else:

        rail_generator = sparse_rail_generator(nr_start_goal=9,
                                               nr_extra=1,
                                               min_dist=9,
                                               max_dist=99999,
                                               seed=0)

        schedule_generator = sparse_schedule_generator()

    global env, env_renderer, render

    if render:
        env_renderer.close_window()

    env = RailEnv(width=env_width,
                  height=env_height,
                  rail_generator=rail_generator,
                  schedule_generator=schedule_generator,
                  obs_builder_object=GlobalObsForRailEnv(),
                  number_of_agents=agent_num)

    env_renderer = RenderTool(env)
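As a concrete instance of the sizing rule above: an input environment of 36x36 grows to 40x40, so map_agent_ratio = int(round(((40 + 40) / 2) / 5 - 2)) = 6 and the regenerated environment is created with 6 agents; for very small maps the ratio drops to 0 or below and the agent count is clamped to 1.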
Example #24
def create_env(nr_start_goal=10,
               nr_extra=2,
               min_dist=8,
               max_dist=99999,
               nr_agent=10,
               seed=0,
               render_mode='PIL'):
    env = RailEnv(width=30,
                  height=30,
                  rail_generator=complex_rail_generator(
                      nr_start_goal, nr_extra, min_dist, max_dist, seed),
                  schedule_generator=complex_schedule_generator(),
                  obs_builder_object=GlobalObsForRailEnv(),
                  number_of_agents=nr_agent)
    env_renderer = RenderTool(env, gl=render_mode)
    obs = env.reset()

    return env, env_renderer, obs
Example #25
def main(args):
    try:
        opts, args = getopt.getopt(args, "", ["sleep-for-animation=", ""])
    except getopt.GetoptError as err:
        print(str(err))  # will print something like "option -a not recognized"
        sys.exit(2)
    sleep_for_animation = True
    for o, a in opts:
        if o in ("--sleep-for-animation",):
            sleep_for_animation = str2bool(a)
        else:
            assert False, "unhandled option"

    env = RailEnv(width=7,
                  height=7,
                  rail_generator=complex_rail_generator(nr_start_goal=10,
                                                        nr_extra=1,
                                                        min_dist=5,
                                                        max_dist=99999,
                                                        seed=1),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=SingleAgentNavigationObs())

    obs, info = env.reset()
    env_renderer = RenderTool(env)
    env_renderer.render_env(show=True, frames=True, show_observations=True)
    for step in range(100):
        action = np.argmax(obs[0]) + 1
        obs, all_rewards, done, _ = env.step({0: action})
        print("Rewards: ", all_rewards, "  [done=", done, "]")
        env_renderer.render_env(show=True, frames=True, show_observations=True)
        if sleep_for_animation:
            time.sleep(0.1)
        if done["__all__"]:
            break
    env_renderer.close_window()
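Here np.argmax(obs[0]) + 1 works because SingleAgentNavigationObs (flatland's custom single-agent observation example) returns a three-element indicator of which branch leads toward the target, so adding 1 maps its index straight onto the movement actions MOVE_LEFT (1), MOVE_FORWARD (2) and MOVE_RIGHT (3).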
Example #26
def env_random_update(input_env, decay, agent=False, hardness_lvl=1):

    agent_num = np.random.randint(1, 5)
    env_width = (agent_num + 2) * 5
    env_height = (agent_num + 2) * 5

    if hardness_lvl == 1:

        rail_generator = complex_rail_generator(nr_start_goal=20,
                                                nr_extra=1,
                                                min_dist=9,
                                                max_dist=99999,
                                                seed=0)

        schedule_generator = complex_schedule_generator()
    else:

        rail_generator = sparse_rail_generator(nr_start_goal=9,
                                               nr_extra=1,
                                               min_dist=9,
                                               max_dist=99999,
                                               seed=0)

        schedule_generator = sparse_schedule_generator()

    global env, env_renderer, render

    if render:
        env_renderer.close_window()

    env = RailEnv(width=env_width,
                  height=env_height,
                  rail_generator=rail_generator,
                  schedule_generator=schedule_generator,
                  obs_builder_object=GlobalObsForRailEnv(),
                  number_of_agents=agent_num)

    env_renderer = RenderTool(env)
Example #27
                  max_num_cities=n_cities,
                  grid_mode=False,
                  max_rails_between_cities=max_rails_between_cities,
                  max_rails_in_city=max_rails_in_city),
              schedule_generator=sparse_schedule_generator(speed_profiles),
              number_of_agents=n_agents,
              malfunction_generator_and_process_data=malfunction_from_params(
                  malfunction_parameters),
              obs_builder_object=tree_observation,
              random_seed=seed)

env.reset(regenerate_schedule=True, regenerate_rail=True)

# Setup renderer

env_renderer = RenderTool(env)
'''
env_renderer.render_env(show=True,show_predictions=False)
time.sleep(5)
env_renderer.close_window()
'''
n_features_per_node = env.obs_builder.observation_dim
n_nodes = 0
for i in range(observation_tree_depth + 1):
    n_nodes += np.power(4, i)
state_size = n_features_per_node * n_nodes

action_size = 5

# Max number of steps per episode
# This is the official formula used during evaluations
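A quick check of the state-size computation above: with observation_tree_depth = 2 the loop sums 4^0 + 4^1 + 4^2 = 21 tree nodes, so state_size = 21 * n_features_per_node. Example #28 below repeats the same calculation with tree_depth = 2.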
Example #28
env = RailEnv(width=x_dim,
              height=y_dim,
              rail_generator=sparse_rail_generator(max_num_cities=3,
                                                   # Number of cities in map (where train stations are)
                                                   seed=1,  # Random seed
                                                   grid_mode=False,
                                                   max_rails_between_cities=2,
                                                   max_rails_in_city=4),
              schedule_generator=sparse_schedule_generator(speed_ration_map),
              number_of_agents=n_agents,
              malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
              obs_builder_object=TreeObservation)
env.reset()

env_renderer = RenderTool(env, gl="PILSVG", )
num_features_per_node = env.obs_builder.observation_dim

tree_depth = 2
nr_nodes = 0
for i in range(tree_depth + 1):
    nr_nodes += np.power(4, i)
state_size = num_features_per_node * nr_nodes
action_size = 5

# We set the number of episodes we would like to train on
if 'n_trials' not in locals():
    n_trials = 60000
max_steps = int(3 * (env.height + env.width))
eps = 1.
eps_end = 0.005
Example #29
def main(args, dir):
    '''
    :param args:
    :return:
    Episodes to debug (set a breakpoint in the episode loop to debug):
    - ep = 3, agent 1 spawns in front of 3, blocking its path; 0 and 2 are in a deadlock since they have the same priority
    - ep = 4, agents stop because of wrong priorities even though the conflict zone wasn't entered
    - ep = 14,
    '''
    rail_generator = sparse_rail_generator(
        max_num_cities=args.max_num_cities,
        seed=args.seed,
        grid_mode=args.grid_mode,
        max_rails_between_cities=args.max_rails_between_cities,
        max_rails_in_city=args.max_rails_in_city,
    )

    # Maps speeds to % of appearance in the env
    speed_ration_map = {
        1.: 0.25,  # Fast passenger train
        1. / 2.: 0.25,  # Fast freight train
        1. / 3.: 0.25,  # Slow commuter train
        1. / 4.: 0.25
    }  # Slow freight train

    observation_builder = GraphObsForRailEnv(
        predictor=ShortestPathPredictorForRailEnv(
            max_depth=args.prediction_depth),
        bfs_depth=4)

    env = RailEnv(
        width=args.width,
        height=args.height,
        rail_generator=rail_generator,
        schedule_generator=sparse_schedule_generator(speed_ration_map),
        number_of_agents=args.num_agents,
        obs_builder_object=observation_builder,
        malfunction_generator_and_process_data=malfunction_from_params(
            parameters={
                'malfunction_rate':
                args.malfunction_rate,  # Rate of malfunction occurrence
                'min_duration':
                args.min_duration,  # Minimal duration of malfunction
                'max_duration':
                args.max_duration  # Max duration of malfunction
            }))

    if args.render:
        env_renderer = RenderTool(env,
                                  agent_render_variant=AgentRenderVariant.
                                  AGENT_SHOWS_OPTIONS_AND_BOX,
                                  show_debug=True)

    sm = stateMachine()
    tb = TestBattery(env, observation_builder)

    state_machine_action_dict = {}
    railenv_action_dict = {}
    # max_time_steps = env.compute_max_episode_steps(args.width, args.height)
    max_time_steps = 200
    T_rewards = []  # List of episodes rewards
    T_Qs = []  # List of q values
    T_num_done_agents = []  # List of number of done agents for each episode
    T_all_done = []  # If all agents completed in each episode
    T_episodes = []  # Time taken for each episode

    if args.save_image and not os.path.isdir("image_dump"):
        os.makedirs("image_dump")

    step_taken = 0
    total_step_taken = 0
    total_episodes = 0
    step_times = []  # Time taken for each step

    for ep in range(args.num_episodes):
        # Reset info at the beginning of an episode
        start_time = time.time()  # Take time of one episode

        if args.generate_baseline:
            if not os.path.isdir("image_dump/" + str(dir)) and args.save_image:
                os.makedirs("image_dump/" + str(dir))
        else:
            if not os.path.isdir("image_dump/" + str(ep)) and args.save_image:
                os.makedirs("image_dump/" + str(ep))

        state, info = env.reset()
        tb.reset()

        if args.render:
            env_renderer.reset()
        reward_sum, all_done = 0, False  # reward_sum contains the cumulative reward obtained as sum during the steps
        num_done_agents = 0

        state_machine_action = {}
        for i in range(env.number_of_agents):
            state_machine_action[i] = 0

        for step in range(max_time_steps):
            start_step_time = time.time()

            #if step % 10 == 0:
            #	print(step)

            # Test battery
            # see test_battery.py
            triggers = tb.tests(state, args.prediction_depth,
                                state_machine_action)
            # state machine based on triggers of test battery
            # see state_machine.py
            state_machine_action = sm.act(
                triggers)  # State machine picks action

            for a in range(env.get_num_agents()):
                #if info['action_required'][a]:
                #	#railenv_action = observation_builder.choose_railenv_action(a, state_machine_action[a])
                #	railenv_action = observation_builder.choose_railenv_action(a, state_machine_action[a])
                #	state_machine_action_dict.update({a: state_machine_action})
                #	railenv_action_dict.update({a: railenv_action})
                # railenv_action = observation_builder.choose_railenv_action(a, state_machine_action[a])
                railenv_action = observation_builder.choose_railenv_action(
                    a, state_machine_action[a])
                state_machine_action_dict.update({a: state_machine_action})
                railenv_action_dict.update({a: railenv_action})

            state, reward, done, info = env.step(
                railenv_action_dict)  # Env step

            if args.generate_baseline:
                #env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
                env_renderer.render_env(show=False,
                                        show_observations=False,
                                        show_predictions=True)
            else:
                env_renderer.render_env(show=True,
                                        show_observations=False,
                                        show_predictions=True)

            if args.generate_baseline:
                if args.save_image:
                    env_renderer.save_image("image_dump/" + str(dir) +
                                            "/image_" + str(step) + "_.png")
            else:
                if args.save_image:
                    env_renderer.save_image("image_dump/" + str(ep) +
                                            "/image_" + str(step) + "_.png")

            if args.debug:
                for a in range(env.get_num_agents()):
                    log('\n\n#########################################')
                    log('\nInfo for agent {}'.format(a))
                    #log('\npath : {}'.format(state[a]["path"]))
                    log('\noverlap : {}'.format(state[a]["overlap"]))
                    log('\ndirection : {}'.format(state[a]["direction"]))
                    log('\nOccupancy, first layer: {}'.format(
                        state[a]["occupancy"]))
                    log('\nOccupancy, second layer: {}'.format(
                        state[a]["conflict"]))
                    log('\nForks: {}'.format(state[a]["forks"]))
                    log('\nTarget: {}'.format(state[a]["target"]))
                    log('\nPriority: {}'.format(state[a]["priority"]))
                    log('\nMax priority encountered: {}'.format(
                        state[a]["max_priority"]))
                    log('\nNum malfunctioning agents (globally): {}'.format(
                        state[a]["n_malfunction"]))
                    log('\nNum agents ready to depart (globally): {}'.format(
                        state[a]["ready_to_depart"]))
                    log('\nStatus: {}'.format(info['status'][a]))
                    log('\nPosition: {}'.format(env.agents[a].position))
                    log('\nTarget: {}'.format(env.agents[a].target))
                    log('\nMoving? {} at speed: {}'.format(
                        env.agents[a].moving, info['speed'][a]))
                    log('\nAction required? {}'.format(
                        info['action_required'][a]))
                    log('\nState machine action: {}'.format(
                        state_machine_action_dict[a]))
                    log('\nRailenv action: {}'.format(railenv_action_dict[a]))
                    log('\nRewards: {}'.format(reward[a]))
                    log('\n\n#########################################')

            reward_sum += sum(reward[a] for a in range(env.get_num_agents()))

            step_taken = step
            time_taken_step = time.time() - start_step_time
            step_times.append(time_taken_step)

            if done['__all__']:
                all_done = True
                break

        total_step_taken += step_taken

        time_taken = time.time() - start_time  # Time taken for one episode
        total_episodes = ep

        # Time metrics - too precise
        avg_time_step = sum(step_times) / step_taken
        #print("Avg time step: " + str(avg_time_step))

        # No need to close the renderer since env parameter sizes stay the same
        T_rewards.append(reward_sum)
        # Compute num of agents that reached their target
        for a in range(env.get_num_agents()):
            if done[a]:
                num_done_agents += 1
        percentage_done_agents = num_done_agents / env.get_num_agents()
        log("\nDone agents in episode: {}".format(percentage_done_agents))
        T_num_done_agents.append(
            percentage_done_agents)  # In proportion to total
        T_all_done.append(all_done)

    # Average number of agents that reached their target
    avg_done_agents = sum(T_num_done_agents) / len(T_num_done_agents) if len(
        T_num_done_agents) > 0 else 0
    avg_reward = sum(T_rewards) / len(T_rewards) if len(T_rewards) > 0 else 0
    avg_norm_reward = avg_reward / (max_time_steps / env.get_num_agents())

    avg_ep_time = sum(T_episodes) / args.num_episodes

    if total_episodes == 0:
        total_episodes = 1

    log("\nSeed: " + str(args.seed) \
      + "\t | Avg_done_agents: " + str(avg_done_agents)\
      + "\t | Avg_reward: " + str(avg_reward)\
      + "\t | Avg_norm_reward: " + str(avg_norm_reward)\
      + "\t | Max_num_time_steps: " + str(max_time_steps)\
      + "\t | Avg_num_time_steps: " + str(total_step_taken/total_episodes)
            + "\t | Avg episode time: " + str(avg_ep_time))
    rail_generator=rail_generator,
    schedule_generator=schedule_generator,
    number_of_agents=nr_trains,
    obs_builder_object=observation_builder,
    malfunction_generator_and_process_data=malfunction_from_params(
        stochastic_data),
    remove_agents_at_target=True)
observations, information = local_env.reset()

# print("observations:", observations)

# Initiate the renderer
env_renderer = RenderTool(
    local_env,
    gl="PILSVG",
    agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
    show_debug=False,
    screen_height=1200,  # Adjust these parameters to fit your resolution
    screen_width=1800)  # Adjust these parameters to fit your resolution

######### Get arguments of the script #########
parser = argparse.ArgumentParser()
parser.add_argument("-step", type=int, help="steps")
args = parser.parse_args()

######### Custom controller: imports #########
controller = GreedyAgent(218, local_env.action_space[0])


######### Define custom controller #########
def my_controller(obs, number_of_agents, astar_paths_readable, timestamp):