Beispiel #1
0
def test_malfunction_before_entry():
    """Check the initial malfunction counters of agents before they enter the environment.

    Every malfunction lasts exactly 10 steps (min and max duration are equal),
    and the env is reset with ``random_seed=10``. The test pins, per agent,
    whether the ``'malfunction'`` counter starts at 0 or at 10.
    """
    fixed_duration = 10
    malfunction_params = MalfunctionParameters(
        malfunction_rate=2,  # Rate of malfunction occurrence
        min_duration=fixed_duration,  # Minimal duration of malfunction
        max_duration=fixed_duration,  # Max duration of malfunction
    )

    rail, rail_map = make_simple_rail2()

    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=10,
        malfunction_generator_and_process_data=malfunction_from_params(malfunction_params),
        obs_builder_object=SingleAgentNavigationObs(),
    )
    env.reset(False, False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Expected counter per agent handle: a mix of agents that are already
    # malfunctioning (10) and agents that are still working (0).
    expected_counters = [0, 10, 0, 10, 10, 10, 10, 10, 10, 10]
    for handle, expected in enumerate(expected_counters):
        assert env.agents[handle].malfunction_data['malfunction'] == expected
Beispiel #2
0
def test_path_not_exists(rendering=False):
    """Check that ``check_path`` reports no path on the unconnected simple rail.

    Parameters
    ----------
    rendering : bool
        When true, the env is rendered and the test waits for keyboard input.
    """
    rail, rail_map = make_simple_rail_unconnected()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2,
            predictor=ShortestPathPredictorForRailEnv(),
        ),
    )
    env.reset()

    start = (5, 6)  # south dead-end
    heading = 0  # north
    goal = (0, 3)  # north dead-end
    check_path(env, rail, start, heading, goal, False)

    if not rendering:
        return

    renderer = RenderTool(env, gl="PILSVG")
    renderer.render_env(show=True, show_observations=False)
    input("Continue?")
def test_malfanction_from_params():
    """Check that malfunction parameters given at construction are exposed by the env.

    The env is built with ``malfunction_from_params``; after ``reset()`` the
    values on ``env.malfunction_process_data`` must equal the ones supplied.
    """
    expected_rate, expected_min, expected_max = 1000, 2, 5
    params = MalfunctionParameters(
        malfunction_rate=expected_rate,  # Rate of malfunction occurrence
        min_duration=expected_min,  # Minimal duration of malfunction
        max_duration=expected_max,  # Max duration of malfunction
    )
    rail, rail_map = make_simple_rail2()

    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=10,
        malfunction_generator_and_process_data=malfunction_from_params(params),
    )
    env.reset()

    assert env.malfunction_process_data.malfunction_rate == expected_rate
    assert env.malfunction_process_data.min_duration == expected_min
    assert env.malfunction_process_data.max_duration == expected_max
Beispiel #4
0
def demo(args=None):
    """Demo script to check installation.

    Builds a 15x15 env with 5 agents and plays episodes with uniformly random
    actions (``np.random.randint(0, 5)`` per agent), rendering every step and
    sleeping 0.3 s between steps.

    NOTE: the outer ``while True`` loop has no exit, so the function only ends
    when it is interrupted; the trailing ``return 0`` is never reached.
    ``args`` is accepted but not used.
    """
    rail_gen = complex_rail_generator(nr_start_goal=10,
                                      nr_extra=1,
                                      min_dist=8,
                                      max_dist=99999)
    env = RailEnv(width=15,
                  height=15,
                  rail_generator=rail_gen,
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=5)

    env._max_episode_steps = int(15 * (env.width + env.height))
    env_renderer = RenderTool(env)

    while True:
        _, _ = env.reset()
        episode_over = False
        # Run a single episode here
        while not episode_over:
            # One random action per agent handle
            random_actions = {
                handle: np.random.randint(0, 5)
                for handle, _ in enumerate(env.agents)
            }
            _, _, done, _ = env.step(random_actions)
            episode_over = done['__all__']
            env_renderer.render_env(show=True,
                                    frames=False,
                                    show_observations=False,
                                    show_predictions=False)
            time.sleep(0.3)
    return 0
Beispiel #5
0
def test_global_obs():
    """Check the global observation of a single agent on the simple rail.

    Verifies that
    * the first observation component has the rail map's shape plus 16 channels,
    * reading each cell's 16 channels as a binary number reproduces the rail map,
    * the agent-state component does not place the agent on an empty cell.
    """
    rail, rail_map = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())

    global_obs, info = env.reset()

    # we have to take step for the agent to enter the grid.
    global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})

    transition_obs = global_obs[0][0]
    assert transition_obs.shape == rail_map.shape + (16, )

    # Rebuild every cell value by reading its 16 channels as a bit string.
    rail_map_recons = np.zeros_like(rail_map)
    for i in range(transition_obs.shape[0]):
        for j in range(transition_obs.shape[1]):
            rail_map_recons[i, j] = int(
                ''.join(transition_obs[i, j].astype(int).astype(str)), 2)

    # Compare element-wise. `a.all() == b.all()` would only compare two
    # booleans and pass for almost any pair of arrays.
    assert np.array_equal(rail_map_recons, rail_map)

    # If this assertion is wrong, it means that the observation returned
    # places the agent on an empty cell
    obs_agents_state = global_obs[0][1]
    obs_agents_state = obs_agents_state + 1
    assert np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0
def render_test(parameters, test_nr=0, nr_examples=5):
    """Load and briefly display a series of saved levels.

    Parameters
    ----------
    parameters : sequence
        Only used for the printed banner: indices 0 and 1 are printed as the
        grid dimensions and index 2 as the number of agents.
    test_nr
        Test identifier; selects the ``./Tests/<test_nr>/`` folder.
    nr_examples : int
        Number of levels (``Level_0`` .. ``Level_<nr_examples - 1>``) to show.
    """
    for trial in range(nr_examples):
        banner = 'Showing {} Level {} with (x_dim,y_dim) = ({},{}) and {} Agents.'.format(
            test_nr, trial, parameters[0], parameters[1], parameters[2])
        print(banner)
        file_name = "./Tests/{}/Level_{}.pkl".format(test_nr, trial)

        # Build the env from the level file
        env = RailEnv(
            width=1,
            height=1,
            rail_generator=rail_from_file(file_name),
            obs_builder_object=TreeObsForRailEnv(max_depth=2),
            number_of_agents=1,
        )
        env_renderer = RenderTool(env, gl="PILSVG")
        env_renderer.set_new_rail()

        env.reset(False, False)
        env_renderer.render_env(show=True, show_observations=False)

        # Keep the window visible for a moment before closing it
        time.sleep(0.1)
        env_renderer.close_window()
    return
def test_normalize_features():
    """Compare normalized tree observations against stored reference arrays.

    Ten 10x10 single-agent envs are built with rail-generator seeds drawn from
    a seeded ``random``. For each, one step is taken, the tree observation is
    normalized and compared (``np.allclose``) with
    ``testdata/test_array_<i>.csv``.
    """
    random.seed(1)
    np.random.seed(1)
    max_depth = 4

    for case_idx in range(10):
        tree_observer = TreeObsForRailEnv(max_depth=max_depth)
        generator_seed = random.randint(0, 100)

        rail_gen = complex_rail_generator(nr_start_goal=10,
                                          nr_extra=1,
                                          min_dist=8,
                                          max_dist=99999,
                                          seed=generator_seed)
        env = RailEnv(width=10,
                      height=10,
                      rail_generator=rail_gen,
                      schedule_generator=complex_schedule_generator(),
                      number_of_agents=1,
                      obs_builder_object=tree_observer)

        # Only the side effect of stepping matters; the observation is
        # fetched from the builder directly below.
        _, _, _, _ = env.step({0: 0})

        normalized = normalize_observation(tree_observer.get(),
                                           max_depth,
                                           observation_radius=10)

        reference = np.loadtxt('testdata/test_array_{}.csv'.format(case_idx),
                               delimiter=',')

        assert np.allclose(reference, normalized)
Beispiel #8
0
    def __init__(self,
                 width,
                 height,
                 rail_generator,
                 number_of_agents,
                 remove_agents_at_target,
                 obs_builder_object,
                 wait_for_all_done,
                 schedule_generator=random_schedule_generator(),
                 name=None):
        """Wrap a freshly constructed ``RailEnv`` and initialise bookkeeping state.

        ``width``, ``height``, ``rail_generator``, ``schedule_generator``,
        ``number_of_agents``, ``obs_builder_object`` and
        ``remove_agents_at_target`` are forwarded unchanged to ``RailEnv``.
        ``wait_for_all_done`` and ``name`` are stored on the instance.
        """
        super().__init__()

        rail_env_kwargs = dict(
            width=width,
            height=height,
            rail_generator=rail_generator,
            schedule_generator=schedule_generator,
            number_of_agents=number_of_agents,
            obs_builder_object=obs_builder_object,
            remove_agents_at_target=remove_agents_at_target,
        )
        self.env = RailEnv(**rail_env_kwargs)

        self.wait_for_all_done = wait_for_all_done
        self.name = name
        self.number_of_agents = number_of_agents
        self.env_renderer = None
        self.agents_done = []
        self.frame_step = 0

        # Track when targets are reached. Only used for correct reward
        # propagation when using wait_for_all_done=True.
        # Keys come from np.arange, i.e. they are numpy integer scalars.
        self.at_target = {
            handle: False
            for handle in np.arange(self.number_of_agents)
        }
Beispiel #9
0
    def regenerate(self, method=None, nAgents=0, env=None):
        """Replace the current env with a regenerated one and refresh the view.

        Parameters
        ----------
        method : str or None
            ``None`` or ``"Empty"`` selects the empty rail generator,
            ``"Random Cell"`` the random rail generator; any other value
            selects the complex rail generator seeded with the current time.
        nAgents : int
            Number of agents for a newly created env (also used as
            ``nr_start_goal`` of the complex generator).
        env
            If given, this env is adopted instead of creating a new one; the
            selected generator is then not used.
        """
        self.log("Regenerate size", self.regen_size_width,
                 self.regen_size_height)

        if method in (None, "Empty"):
            rail_gen = empty_rail_generator()
        elif method == "Random Cell":
            rail_gen = random_rail_generator(
                cell_type_relative_proportion=[1] * 11)
        else:
            rail_gen = complex_rail_generator(nr_start_goal=nAgents,
                                              nr_extra=20,
                                              min_dist=12,
                                              seed=int(time.time()))

        if env is not None:
            self.env = env
        else:
            self.env = RailEnv(
                width=self.regen_size_width,
                height=self.regen_size_height,
                rail_generator=rail_gen,
                number_of_agents=nAgents,
                obs_builder_object=TreeObsForRailEnv(max_depth=2))

        self.env.reset(regenerate_rail=True)
        self.fix_env()
        self.set_env(self.env)
        self.view.new_env()
        self.redraw()
Beispiel #10
0
    def env_create(self, obs_builder_object):
        """
            Create a local env and remote env on which the
            local agent can operate.
            The observation builder is only used in the local env
            and the remote env uses a DummyObservationBuilder

            Returns the remote ``(observation, info)`` unchanged when the
            remote observation is falsy (evaluations are complete);
            otherwise returns the ``(observation, info)`` produced by
            resetting the local env.

            Raises ``Exception`` when the env file announced by the remote
            side does not exist below ``self.test_envs_root``.
        """
        request_sent_at = time.time()
        _request = {
            'type': messages.FLATLAND_RL.ENV_CREATE,
            'payload': {},
        }
        payload = self._remote_request(_request)['payload']
        observation = payload['observation']
        info = payload['info']
        random_seed = payload['random_seed']
        test_env_file_path = payload['env_file_path']
        self.update_running_mean_stats("env_creation_wait_time",
                                       time.time() - request_sent_at)

        if not observation:
            # If the observation is False,
            # then the evaluations are complete
            # hence return false
            return observation, info

        if self.verbose:
            print("Received Env : ", test_env_file_path)

        # The remote side sends a path relative to the local tests root.
        test_env_file_path = os.path.join(self.test_envs_root,
                                          test_env_file_path)
        if not os.path.exists(test_env_file_path):
            raise Exception(
                "\nWe cannot seem to find the env file paths at the required location.\n"
                "Did you remember to set the AICROWD_TESTS_FOLDER environment variable "
                "to point to the location of the Tests folder ? \n"
                "We are currently looking at `{}` for the tests".format(self.test_envs_root)
            )

        if self.verbose:
            print("Current env path : ", test_env_file_path)
        self.current_env_path = test_env_file_path
        self.env = RailEnv(
            width=1,
            height=1,
            rail_generator=rail_from_file(test_env_file_path),
            schedule_generator=schedule_from_file(test_env_file_path),
            malfunction_generator_and_process_data=malfunction_from_file(test_env_file_path),
            obs_builder_object=obs_builder_object,
        )

        reset_started_at = time.time()
        local_observation, info = self.env.reset(
            regenerate_rail=True,
            regenerate_schedule=True,
            activate_agents=False,
            random_seed=random_seed
        )
        self.update_running_mean_stats("internal_env_reset_time",
                                       time.time() - reset_started_at)
        # Use the local observation
        # as the remote server uses a dummy observation builder
        return local_observation, info
def test_get_entry_directions():
    """Check ``get_valid_directions_on_grid`` for representative cells of the simple rail.

    Each case pairs a grid position with the expected list of four booleans
    returned for that cell.
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2,
                      predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    cases = [
        ((0, 3), [True, False, False, False]),  # north dead end
        ((3, 0), [False, False, False, True]),  # west dead end
        ((3, 3), [False, True, True, True]),  # switch
        ((3, 2), [False, True, False, True]),  # horizontal
        ((2, 3), [True, False, True, False]),  # vertical
        ((0, 0), [False, False, False, False]),  # nowhere
    ]

    for position, expected in cases:
        actual = env.get_valid_directions_on_grid(*position)
        assert actual == expected, "[{},{}] actual={}, expected={}".format(
            *position, actual, expected)
Beispiel #12
0
def test_path_exists(rendering=False):
    """Check that ``check_path`` finds a path for several start/target pairs on the simple rail.

    Each case is ``(start position, start direction, target position)``;
    every one of them is expected to be connected. Directions are encoded as
    0 = north, 2 = south, 3 = west. ``rendering`` is accepted but not used.
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
    )
    env.reset()

    connected_cases = [
        ((5, 6), 0, (3, 9)),  # north of south dead-end, heading north -> east dead-end
        ((6, 6), 2, (3, 9)),  # south dead-end, heading south -> east dead-end
        ((3, 0), 3, (0, 3)),  # west dead-end, heading west -> north dead-end
        ((5, 6), 0, (1, 3)),  # heading north
        ((1, 3), 2, (3, 3)),  # heading south -> switch
        ((1, 3), 0, (3, 3)),  # heading north -> switch
    ]

    for start, direction, target in connected_cases:
        check_path(env, rail, start, direction, target, True)
Beispiel #13
0
def load_flatland_env(env_config: Dict[str, Any]) -> RailEnv:
    """Build a ``RailEnv`` from keyword arguments and record its possible agents.

    ``env_config`` is unpacked as keyword arguments to ``RailEnv``. A shallow
    copy of ``env.agents``, taken right after construction, is stored on the
    env as ``possible_agents``.
    """
    flatland_env = RailEnv(**env_config)
    agents_snapshot = flatland_env.agents[:]
    flatland_env.possible_agents = agents_snapshot
    return flatland_env
def load_env(env_dict, obs_builder_object=None):
    """
    Build a 4x4 ``RailEnv`` and overwrite its state from ``env_dict``.

    Parameters
    ----------
    env_dict
        Saved state passed to ``RailEnvPersister.set_full_state``.
    obs_builder_object
        Observation builder for the new env. When omitted, a fresh
        ``GlobalObsForRailEnv`` is created for each call. A default instance
        built in the signature would be created once at definition time and
        shared by every env returned from this function.

    Returns
    -------
    The env, after ``reset`` and ``set_full_state``.
    """
    if obs_builder_object is None:
        obs_builder_object = GlobalObsForRailEnv()
    env = RailEnv(height=4, width=4, obs_builder_object=obs_builder_object)
    env.reset(regenerate_rail=False, regenerate_schedule=False)
    RailEnvPersister.set_full_state(env, env_dict)
    return env
Beispiel #15
0
def test_random_rail_generator():
    """Check grid shape and agent count of an env built with ``random_rail_generator``."""
    n_agents = 1
    x_dim, y_dim = 5, 10

    # Generate a random level with the requested dimensions
    env = RailEnv(
        width=x_dim,
        height=y_dim,
        rail_generator=random_rail_generator(),
        number_of_agents=n_agents,
    )
    env.reset()

    # The grid is indexed (row, column), i.e. (height, width)
    assert env.rail.grid.shape == (y_dim, x_dim)
    assert env.get_num_agents() == n_agents
def test_shortest_path_predictor_conflicts(rendering=False):
    """Check the conflicts reported in the tree observations of two agents.

    Two agents are placed by hand on the rail returned by
    ``make_invalid_simple_rail``. After a reset that keeps rail and schedule,
    the tree observation of each agent is checked with
    ``_check_expected_conflicts``.

    Parameters
    ----------
    rendering : bool
        When true, the env is rendered and the test waits for keyboard input.
    """
    rail, rail_map = make_invalid_simple_rail()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=2,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
    )
    env.reset()

    def _place(agent, position, direction, target):
        # Pin both the initial and the current pose, then mark the agent
        # as moving and active.
        agent.initial_position = position
        agent.position = position
        agent.direction = direction
        agent.initial_direction = direction
        agent.target = target
        agent.moving = True
        agent.status = RailAgentStatus.ACTIVE

    # agent 0: heading north (0); agent 1: heading west (3)
    _place(env.agents[0], (5, 6), 0, (3, 9))
    _place(env.agents[1], (3, 8), 3, (6, 6))

    observations, info = env.reset(False, False, True)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # get the trees to test
    obs_builder: TreeObsForRailEnv = env.obs_builder
    tree_0, tree_1 = observations[0], observations[1]
    for tree in (tree_0, tree_1):
        env.obs_builder.util_print_obs_subtree(tree)

    # check the expectations
    _check_expected_conflicts([('F', 'R')], obs_builder, tree_0,
                              "agent[0]: ")
    _check_expected_conflicts([('F', 'L')], obs_builder, tree_1,
                              "agent[1]: ")
Beispiel #17
0
def test_malfunction_process():
    """Check that a malfunctioning agent stands still and that malfunctions are counted.

    A single agent runs for 100 steps with a fixed malfunction duration of 3
    and ``random_seed=10``. While its ``'malfunction'`` counter is positive,
    its position must not change. At the end, ``'nr_malfunctions'`` must be
    23 and the summed counters must be positive.
    """
    # Set fixed malfunction duration for this test
    stochastic_data = MalfunctionParameters(
        malfunction_rate=1,  # Rate of malfunction occurrence
        min_duration=3,  # Minimal duration of malfunction
        max_duration=3,  # Max duration of malfunction
    )

    rail, rail_map = make_simple_rail2()

    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        malfunction_generator_and_process_data=malfunction_from_params(
            stochastic_data),
        obs_builder_object=SingleAgentNavigationObs())
    obs, info = env.reset(False, False, True, random_seed=10)

    total_down_time = 0
    previous_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)
    for _ in range(100):
        actions = {
            handle: np.argmax(obs[handle]) + 1
            for handle in range(len(obs))
        }

        obs, _, _, _ = env.step(actions)

        if env.agents[0].malfunction_data['malfunction'] > 0:
            # Check that agent is not moving while malfunctioning
            assert previous_position == env.agents[0].position

        previous_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    malfunction_count = env.agents[0].malfunction_data['nr_malfunctions']
    assert malfunction_count == 23, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that the agent actually spent time malfunctioning
    assert total_down_time > 0
def test_load_env():
    """Load a packaged 10x10 env resource and check that one added agent is counted."""
    env = RailEnv(10, 10)
    env.reset()
    env.load_resource('env_data.tests', 'test-10x10.mpk')

    # Add a single agent to the loaded env
    env.add_agent(EnvAgent((0, 0), 2, (5, 5), False))
    assert env.get_num_agents() == 1
Beispiel #19
0
def main(args):
    """Run a 100-step random-action demo with a prediction-based observation builder.

    Parameters
    ----------
    args : list of str
        Command-line arguments without the program name. The only recognised
        option is ``--sleep-for-animation=<value>``; the value is converted
        with ``str2bool``. It defaults to ``True``, in which case the loop
        sleeps 0.5 s after every rendered step.

    Exits with status 2 when the options cannot be parsed.
    """
    try:
        opts, args = getopt.getopt(args, "", ["sleep-for-animation="])
    except getopt.GetoptError as err:
        print(str(err))  # will print something like "option -a not recognized"
        sys.exit(2)
    sleep_for_animation = True
    for o, a in opts:
        # Exact comparison: `o in ("--sleep-for-animation")` is a substring
        # test against a plain string, because the parentheses do not make
        # a tuple.
        if o == "--sleep-for-animation":
            sleep_for_animation = str2bool(a)
        else:
            assert False, "unhandled option"

    # Initiate the Predictor
    custom_predictor = ShortestPathPredictorForRailEnv(10)

    # Pass the Predictor to the observation builder
    custom_obs_builder = ObservePredictions(custom_predictor)

    # Initiate Environment
    env = RailEnv(width=10,
                  height=10,
                  rail_generator=complex_rail_generator(nr_start_goal=5,
                                                        nr_extra=1,
                                                        min_dist=8,
                                                        max_dist=99999,
                                                        seed=1),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=3,
                  obs_builder_object=custom_obs_builder)

    env.reset()
    env_renderer = RenderTool(env, gl="PILSVG")

    # We render the initial step and show the observed cells as colored boxes
    env_renderer.render_env(show=True,
                            frames=True,
                            show_observations=True,
                            show_predictions=False)

    action_dict = {}
    for _ in range(100):
        # One uniformly random action per agent
        for a in range(env.get_num_agents()):
            action_dict[a] = np.random.randint(0, 5)
        _, all_rewards, done, _ = env.step(action_dict)
        print("Rewards: ", all_rewards, "  [done=", done, "]")
        env_renderer.render_env(show=True,
                                frames=True,
                                show_observations=True,
                                show_predictions=False)
        if sleep_for_animation:
            time.sleep(0.5)
Beispiel #20
0
def test_render_env(save_new_images=False):
    """Render a stored transition map with two backends and compare against frozen images.

    Parameters
    ----------
    save_new_images : bool
        Forwarded as ``resave`` to ``checkFrozenImage``.
    """
    np.random.seed(100)
    oEnv = RailEnv(
        width=10,
        height=10,
        rail_generator=empty_rail_generator(),
        number_of_agents=0,
        obs_builder_object=TreeObsForRailEnv(max_depth=2),
    )
    oEnv.reset()
    oEnv.rail.load_transition_map('env_data.tests', "test1.npy")

    # "PILSVG" backend, rendered without showing a window
    svg_renderer = rt.RenderTool(oEnv, gl="PILSVG")
    svg_renderer.render_env(show=False)
    checkFrozenImage(svg_renderer, "basic-env.npz", resave=save_new_images)

    # "PIL" backend, rendered with default arguments
    pil_renderer = rt.RenderTool(oEnv, gl="PIL")
    pil_renderer.render_env()
    checkFrozenImage(pil_renderer, "basic-env-PIL.npz", resave=save_new_images)
Beispiel #21
0
def test_empty_rail_generator():
    """Check that ``empty_rail_generator`` yields an all-zero grid and places no agents."""
    n_agents = 1
    x_dim, y_dim = 5, 10

    # Build an env on top of the empty rail generator
    env = RailEnv(
        width=x_dim,
        height=y_dim,
        rail_generator=empty_rail_generator(),
        number_of_agents=n_agents,
    )
    env.reset()

    grid = env.rail.grid
    # Check the dimensions
    assert grid.shape == (y_dim, x_dim)
    # Check that no grid was generated
    assert np.count_nonzero(env.rail.grid) == 0
    # Check that no agents were placed, although one was requested
    assert env.get_num_agents() == 0
Beispiel #22
0
def decorate_step_method(env: RailEnv) -> None:
    """Make ``env.step`` accept action dicts keyed by agent id.

    The original bound ``step`` is kept on the env as ``step_``. The
    replacement converts every key with ``get_agent_handle`` and every value
    with ``int`` before delegating to ``step_``. The env is modified in place.
    """
    # Keep the untouched implementation reachable for the wrapper below.
    env.step_ = env.step

    def _step(self: RailEnv,
              actions: Dict[str, Union[int, float, Any]]) -> dm_env.TimeStep:
        handle_actions = {}
        for agent_id, action in actions.items():
            handle = get_agent_handle(agent_id)
            handle_actions[handle] = int(action)
        return self.step_(handle_actions)

    # Bind the wrapper to this env instance only.
    env.step = tp.MethodType(_step, env)
Beispiel #23
0
    def __init__(self, env=None, sGL="PIL", env_filename="temp.pkl"):
        """Create an Editor MVC assembly around a railenv, or create one if None.

        When ``env`` is None, an empty 10x10 env without agents is built. The
        env is reset in either case. Model, view and controller are then
        created and cross-linked, and the view's canvas and widgets are
        initialised.
        """
        if env is None:
            env = RailEnv(
                width=10,
                height=10,
                rail_generator=empty_rail_generator(),
                number_of_agents=0,
                obs_builder_object=TreeObsForRailEnv(max_depth=2),
            )

        env.reset()

        self.editor = EditorModel(env, env_filename=env_filename)

        view = View(self.editor, sGL=sGL)
        self.editor.view = view
        self.view = view

        controller = Controller(self.editor, self.view)
        self.view.controller = controller
        self.editor.controller = controller
        self.controller = controller

        self.view.init_canvas()
        self.view.init_widgets()  # has to be done after controller
def test_save_load():
    """Save an env to ``test_save.dat``, load it back, and check it is unchanged.

    Compares width, height, number of agents and each agent's position,
    direction and target before saving and after loading.
    """
    rail_gen = complex_rail_generator(nr_start_goal=2,
                                      nr_extra=5,
                                      min_dist=6,
                                      seed=1)
    env = RailEnv(width=10,
                  height=10,
                  rail_generator=rail_gen,
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=2)
    env.reset()

    # Snapshot (position, direction, target) of both agents before saving
    snapshots = []
    for idx in (0, 1):
        agent = env.agents[idx]
        snapshots.append((agent.position, agent.direction, agent.target))

    env.save("test_save.dat")
    env.load("test_save.dat")

    assert env.width == 10
    assert env.height == 10
    assert len(env.agents) == 2
    for idx, (position, direction, target) in enumerate(snapshots):
        assert position == env.agents[idx].position
        assert direction == env.agents[idx].direction
        assert target == env.agents[idx].target
Beispiel #25
0
    def _launch(self):
        """Build and reset a ``RailEnv`` from ``self._config``.

        Malfunctions are enabled only when ``malfunction_rate``,
        ``malfunction_min_duration`` and ``malfunction_max_duration`` are all
        present in the config. An optional ``speed_ratio_map`` is converted to
        float keys and values.

        Returns the env. A ``ValueError`` raised while constructing or
        resetting it is logged rather than propagated; the value returned is
        then whatever ``env`` holds at that point (``None`` if construction
        itself failed).
        """
        cfg = self._config

        rail_generator = sparse_rail_generator(
            seed=cfg['seed'],
            max_num_cities=cfg['max_num_cities'],
            grid_mode=cfg['grid_mode'],
            max_rails_between_cities=cfg['max_rails_between_cities'],
            max_rails_in_city=cfg['max_rails_in_city'])

        # Default to no malfunctions; override only with a complete parameter set.
        malfunction_generator = no_malfunction_generator()
        required_keys = {
            'malfunction_rate',
            'malfunction_min_duration',
            'malfunction_max_duration',
        }
        if required_keys.issubset(cfg.keys()):
            stochastic_data = {
                'malfunction_rate': cfg['malfunction_rate'],
                'min_duration': cfg['malfunction_min_duration'],
                'max_duration': cfg['malfunction_max_duration']
            }
            malfunction_generator = malfunction_from_params(stochastic_data)

        if 'speed_ratio_map' in cfg:
            speed_ratio_map = {}
            for k, v in cfg['speed_ratio_map'].items():
                speed_ratio_map[float(k)] = float(v)
        else:
            speed_ratio_map = None
        schedule_generator = sparse_schedule_generator(speed_ratio_map)

        env = None
        try:
            env = RailEnv(
                width=cfg['width'],
                height=cfg['height'],
                rail_generator=rail_generator,
                schedule_generator=schedule_generator,
                number_of_agents=cfg['number_of_agents'],
                malfunction_generator_and_process_data=malfunction_generator,
                obs_builder_object=self._observation.builder(),
                remove_agents_at_target=False,
                random_seed=cfg['seed'])

            env.reset()
        except ValueError as e:
            separator = "=" * 50
            logging.error(separator)
            logging.error(f"Error while creating env: {e}")
            logging.error(separator)

        return env
Beispiel #26
0
def create_rail_env(env_params, tree_observation):
    """Build a sparse-rail ``RailEnv`` from an ``env_params`` object.

    Reads ``n_agents``, ``x_dim``, ``y_dim``, ``n_cities``,
    ``max_rails_between_cities``, ``max_rails_in_city``, ``seed`` and
    ``malfunction_rate`` from ``env_params``. Malfunction durations are fixed
    to the range 20..50. ``tree_observation`` is used as the observation
    builder. The env is returned without being reset.
    """
    # Break agents from time to time
    malfunction_parameters = MalfunctionParameters(
        malfunction_rate=env_params.malfunction_rate,
        min_duration=20,
        max_duration=50)

    rail_gen = sparse_rail_generator(
        max_num_cities=env_params.n_cities,
        grid_mode=False,
        max_rails_between_cities=env_params.max_rails_between_cities,
        max_rails_in_city=env_params.max_rails_in_city)
    schedule_gen = sparse_schedule_generator()
    malfunction_gen = malfunction_from_params(malfunction_parameters)

    return RailEnv(
        width=env_params.x_dim,
        height=env_params.y_dim,
        rail_generator=rail_gen,
        schedule_generator=schedule_gen,
        number_of_agents=env_params.n_agents,
        malfunction_generator_and_process_data=malfunction_gen,
        obs_builder_object=tree_observation,
        random_seed=env_params.seed)
def train_validate_env_generator(train_set, observation):
    """Pick a random stored test level and build a ``RailEnv`` from it.

    Parameters
    ----------
    train_set : bool
        Truthy draws the seed from ``[0, 1000)``, falsy from ``[1000, 2000)``.
    observation
        Observation builder passed to the env.

    Returns
    -------
    tuple
        ``(env, random_seed)``. Both ``random`` and ``np.random`` have been
        seeded with ``random_seed``; the env is not reset here.
    """
    random_seed = (np.random.randint(1000)
                   if train_set else np.random.randint(1000, 2000))

    # Level selection is drawn before the generators are re-seeded.
    test_env_no = np.random.randint(9)
    level_no = np.random.randint(2)
    random.seed(random_seed)
    np.random.seed(random_seed)

    test_envs_root = f"./test-envs/Test_{test_env_no}"
    test_env_file_path = os.path.join(test_envs_root, f"Level_{level_no}.pkl")
    print(
        f"Testing Environment: {test_env_file_path} with seed: {random_seed}")

    env = RailEnv(
        width=1,
        height=1,
        rail_generator=rail_from_file(test_env_file_path),
        schedule_generator=schedule_from_file(test_env_file_path),
        malfunction_generator_and_process_data=malfunction_from_file(
            test_env_file_path),
        obs_builder_object=observation,
    )
    return env, random_seed
Beispiel #28
0
    def replay_verify(
        ctl: ControllerFromTrainruns,
        env: RailEnv,
        call_back: ControllerFromTrainrunsReplayerRenderCallback = lambda *a,
        **k: None):
        """Replays this deterministic `ActionPlan` and verifies whether it is feasible.

        Before every step, each agent's position must equal the position of
        the waypoint the controller reports for that agent at that step;
        otherwise an ``AssertionError`` is raised. The replay stops once
        ``env.dones['__all__']`` is set or the step counter exceeds
        ``env._max_episode_steps``.

        Parameters
        ----------
        ctl
            Controller providing ``get_waypoint_before_or_at_step`` and ``act``.
        env
            The env to replay in.
        call_back
            Called once before the first step and after each step() call.
            The env is passed to it.
        """
        call_back(env)
        step = 0
        while True:
            if env.dones['__all__'] or step > env._max_episode_steps:
                break

            # Verify all agents are where the plan expects them to be.
            for agent_id, agent in enumerate(env.agents):
                waypoint: Waypoint = ctl.get_waypoint_before_or_at_step(
                    agent_id, step)
                assert agent.position == waypoint.position, \
                    "before {}, agent {} at {}, expected {}".format(step, agent_id, agent.position,
                                                                    waypoint.position)

            actions = ctl.act(step)
            print("actions for {}: {}".format(step, actions))

            _, _, _, _ = env.step(actions)

            call_back(env)
            step += 1
def load_flatland_environment_from_file(
        file_name: str,
        load_from_package: str = None,
        obs_builder_object: ObservationBuilder = None) -> RailEnv:
    """
    Create a single-agent `RailEnv` whose rail and schedule come from a file.

    Parameters
    ----------
    file_name : str
        The pickle file.
    load_from_package : str
        The python module to import from. Example: 'env_data.tests'
        This requires that there are `__init__.py` files in the folder structure we load the file from.
    obs_builder_object: ObservationBuilder
        The obs builder for the `RailEnv` that is created. When None, a
        `TreeObsForRailEnv` of depth 2 with a shortest-path predictor of
        depth 10 is used.

    Returns
    -------
    RailEnv
        The environment loaded from the pickle file.
    """
    builder = obs_builder_object
    if builder is None:
        predictor = ShortestPathPredictorForRailEnv(max_depth=10)
        builder = TreeObsForRailEnv(max_depth=2, predictor=predictor)

    rail_gen = rail_from_file(file_name, load_from_package)
    schedule_gen = schedule_from_file(file_name, load_from_package)
    return RailEnv(
        width=1,
        height=1,
        rail_generator=rail_gen,
        schedule_generator=schedule_gen,
        number_of_agents=1,
        obs_builder_object=builder,
    )
Beispiel #30
0
def fine_tune(config, run, env: RailEnv):
    """
    Fine-tune the agent on a static env at evaluation time

    The env is saved to ``CURRENT_ENV_PATH`` and a ``"flatland_sparse"`` env
    creator that uses that file is registered. ``ray.tune.run`` then resumes
    from ``run["checkpoint_path"]`` until ``time_since_restore`` reaches the
    budget returned by ``get_tune_time``. ``config`` is modified in place
    (``num_workers``, ``num_envs_per_worker``, ``lr``).

    Returns an agent built from the first trial and restored from the first
    entry of that trial's checkpoint paths.
    """
    RailEnvPersister.save(env, CURRENT_ENV_PATH)
    num_agents = env.get_num_agents()
    tune_time = get_tune_time(num_agents)

    def env_creator(env_config):
        return FlatlandSparse(env_config,
                              fine_tune_env_path=CURRENT_ENV_PATH,
                              max_steps=num_agents * 100)

    register_env("flatland_sparse", env_creator)

    # Tuning overrides; the learning rate scales with the number of agents.
    config['num_workers'] = 3
    config['num_envs_per_worker'] = 1
    config['lr'] = 0.00001 * num_agents

    stop_criteria = {"time_since_restore": tune_time}
    analysis = ray.tune.run(run["agent"],
                            reuse_actors=True,
                            verbose=1,
                            stop=stop_criteria,
                            checkpoint_freq=1,
                            keep_checkpoints_num=1,
                            checkpoint_score_attr="episode_reward_mean",
                            config=config,
                            restore=run["checkpoint_path"])

    trial: Trial = analysis.trials[0]
    # Mutates trial.config itself: the agent below is built with
    # num_workers set to 0.
    trial.config['num_workers'] = 0
    trainable_cls = trial.get_trainable_cls()
    agent = trainable_cls(env=config["env"], config=trial.config)

    checkpoint = analysis.get_trial_checkpoints_paths(
        trial, metric="episode_reward_mean")
    agent.restore(checkpoint[0][0])
    return agent