コード例 #1
0
def _create_single_agent_environment(seed, x_dim, y_dim, n_agents, n_cities,
                                     timed, max_rails_between_cities,
                                     max_rails_in_city, observation_builder):
    """Build a seeded sparse-rail ``RailEnv`` together with its step budget.

    The stdlib ``random`` module and NumPy's global RNG are both seeded with
    ``seed`` first. The same seed is then given to the rail generator and to
    ``RailEnv``.

    Returns:
        tuple: ``(env, max_steps, x_dim, y_dim)``, where ``max_steps`` is
        ``int(8 * (env.height + env.width + n_agents / n_cities))``.
    """
    for rng in (random, np.random):
        rng.seed(seed)

    rail_gen = sparse_rail_generator(
        max_num_cities=n_cities,
        seed=seed,
        grid_mode=False,
        max_rails_between_cities=max_rails_between_cities,
        max_rails_in_city=max_rails_in_city)
    schedule_gen = sparse_schedule_generator(timed=timed)

    env = RailEnv(width=x_dim,
                  height=y_dim,
                  rail_generator=rail_gen,
                  schedule_generator=schedule_gen,
                  number_of_agents=n_agents,
                  obs_builder_object=observation_builder,
                  random_seed=seed)

    # Episode length budget grows with map size and agents per city.
    step_budget = 8 * (env.height + env.width + n_agents / n_cities)
    return env, int(step_budget), x_dim, y_dim
コード例 #2
0
def train_validate_env_generator_params(train_set,
                                        n_agents,
                                        x_dim,
                                        y_dim,
                                        observation,
                                        stochastic_data,
                                        speed_ration_map,
                                        seed=1):
    """Build a sparse-rail ``RailEnv`` for the training or validation pool.

    A seed is drawn from NumPy's global RNG. It lies in ``[0, 1000)`` when
    ``train_set`` is truthy and in ``[1000, 2000)`` otherwise. Both ``random``
    and ``np.random`` are then reseeded with it.

    NOTE(review): the drawn seed is not passed to the rail generator, which
    uses ``seed`` (default 1). It is not passed to ``RailEnv`` either.
    Confirm that training and validation maps are meant to share the same
    rail seed.

    Returns:
        tuple: ``(env, random_seed)``.
    """
    low, high = (0, 1000) if train_set else (1000, 2000)
    random_seed = np.random.randint(low, high)
    for rng in (random, np.random):
        rng.seed(random_seed)

    rail_generator = sparse_rail_generator(
        max_num_cities=3,  # cities are where the train stations are
        seed=seed,
        grid_mode=False,
        max_rails_between_cities=2,
        max_rails_in_city=3)
    schedule_generator = sparse_schedule_generator(speed_ration_map)
    malfunction_generator = malfunction_from_params(stochastic_data)

    env = RailEnv(
        width=x_dim,
        height=y_dim,
        rail_generator=rail_generator,
        schedule_generator=schedule_generator,
        number_of_agents=n_agents,
        malfunction_generator_and_process_data=malfunction_generator,
        obs_builder_object=observation)
    return env, random_seed
コード例 #3
0
ファイル: ENV.py プロジェクト: Zeii2024/RL
    def env(self):
        """Construct a new ``RailEnv`` from this object's settings.

        The observation builder is looked up in ``self.obs_builder_dict``
        under the key ``self.obs_builder``. The rail uses a fixed seed of 19
        with grid-mode city placement. Agents stay on the map after reaching
        their target, and steps are recorded.
        """
        builder = self.obs_builder_dict[self.obs_builder]

        rail_gen = sparse_rail_generator(
            max_num_cities=self.max_num_cities,  # cities host the stations
            seed=19,
            grid_mode=True,
            max_rails_between_cities=2,
            max_rails_in_city=2,
        )
        schedule_gen = sparse_schedule_generator(self.speed_ration_map)
        malfunction_gen = malfunction_from_params(self.stochastic_data)

        return RailEnv(
            width=self.width,  # width and height are counts of grid cells
            height=self.height,
            rail_generator=rail_gen,
            schedule_generator=schedule_gen,
            number_of_agents=self.number_of_agents,
            malfunction_generator_and_process_data=malfunction_gen,
            obs_builder_object=builder,
            remove_agents_at_target=False,
            record_steps=True)
コード例 #4
0
    def __init__(self, observation_builder, width=12, height=12, num_agents=2):
        """Initialise wrapper state and build the initial environment.

        Args:
            observation_builder: stored on ``self.observation_builder``.
            width: accepted but not forwarded below (see NOTE).
            height: accepted but not forwarded below (see NOTE).
            num_agents: stored on ``self.num_agents``; not forwarded below.
        """
        self.num_agents = num_agents

        # 25% share for each of the four train speed classes.
        self.schedule_gen = sparse_schedule_generator({
            1.: 0.25,  # Fast passenger train
            1. / 2.: 0.25,  # Fast freight train
            1. / 3.: 0.25,  # Slow commuter train
            1. / 4.: 0.25  # Slow freight train
        })

        # Malfunction settings; presumably consumed by update_env_with_params
        # (not visible here) -- confirm.
        self.stochastic_data = {
            'prop_malfunction': 0.3,  # Share of defective agents (0.3 = 30%)
            'malfunction_rate': 30,  # Rate of malfunction occurrence
            'min_duration': 3,  # Minimal duration of malfunction
            'max_duration': 20  # Max duration of malfunction
        }

        # Per-episode bookkeeping.
        self.done_last_step = {}
        self.observation_builder = observation_builder
        self.dist = {}

        self.num_of_done_agents = 0
        self.episode_step_count = 0
        self.max_steps = 40
        # NOTE(review): the width/height/num_agents constructor arguments are
        # not forwarded; a fixed 30x30 map with one agent is requested instead.
        # max_steps=200 is also passed although self.max_steps was just set to
        # 40. Whether that value is overwritten depends on
        # update_env_with_params, which is not visible here. Confirm this is
        # intentional.
        self.update_env_with_params(width=30,
                                    height=30,
                                    num_agents=1,
                                    max_steps=200,
                                    rail_type='sparse',
                                    rail_gen_params={
                                        'num_cities': 2,
                                        'grid_mode': False,
                                        'max_rails_between_cities': 1,
                                        'max_rails_in_city': 2
                                    })
コード例 #5
0
def create_rail_env(env_params, tree_observation):
    """Build a sparse-rail ``RailEnv`` described by ``env_params``.

    ``env_params`` supplies the map size (``x_dim``/``y_dim``), agent and
    city counts, rail density limits, ``seed`` and ``malfunction_rate``.
    Malfunctions use fixed durations between 20 and 50. The seed is passed
    only to ``RailEnv`` (``random_seed``), not to the rail generator.
    """
    p = env_params
    n_agents, width, height = p.n_agents, p.x_dim, p.y_dim
    n_cities = p.n_cities
    rails_between, rails_in_city = (p.max_rails_between_cities,
                                    p.max_rails_in_city)
    seed = p.seed

    # Trains break down from time to time.
    breakdowns = MalfunctionParameters(malfunction_rate=p.malfunction_rate,
                                       min_duration=20,
                                       max_duration=50)

    rail_gen = sparse_rail_generator(max_num_cities=n_cities,
                                     grid_mode=False,
                                     max_rails_between_cities=rails_between,
                                     max_rails_in_city=rails_in_city)
    schedule_gen = sparse_schedule_generator()
    malfunction_gen = malfunction_from_params(breakdowns)

    return RailEnv(width=width,
                   height=height,
                   rail_generator=rail_gen,
                   schedule_generator=schedule_gen,
                   number_of_agents=n_agents,
                   malfunction_generator_and_process_data=malfunction_gen,
                   obs_builder_object=tree_observation,
                   random_seed=seed)
コード例 #6
0
def create_rail_env(args, load_env=""):
    """Build a ``RailEnvWrapper`` from the parameters in the ``.yml`` config.

    Args:
        args: parsed configuration. Its ``env``, ``policy``, ``predictor``
            and ``observator`` sections are read.
        load_env: path of a saved environment. When non-empty, the rail is
            loaded from that file with ``rail_from_file``. Otherwise a sparse
            rail is generated from ``args.env``.

    Returns:
        The ``RailEnvWrapper``, created with ``remove_agents_at_target=True``.
    """
    if not load_env:
        rail_generator = sparse_rail_generator(
            max_num_cities=args.env.max_cities,
            grid_mode=args.env.grid,
            max_rails_between_cities=args.env.max_rails_between_cities,
            max_rails_in_city=args.env.max_rails_in_cities,
            seed=args.env.seed)
    else:
        rail_generator = rail_from_file(load_env)

    # Predictor and observator are both selected by the policy type.
    obs_type = args.policy.type.get_true_key()
    predictor_cls = PREDICTORS[obs_type]
    if predictor_cls is ShortestDeviationPathPredictor:
        # NOTE(review): depth comes from the *observator* section and the
        # deviation count from predictor.max_depth. Confirm this pairing is
        # intended and not swapped.
        predictor = predictor_cls(
            max_depth=args.observator.max_depth,
            max_deviations=args.predictor.max_depth)
    else:
        predictor = predictor_cls(max_depth=args.predictor.max_depth)
    observator = OBSERVATORS[obs_type](args.observator.max_depth, predictor)

    env_cfg = args.env

    # Malfunctions are optional.
    mf_cfg = env_cfg.malfunctions
    malfunctions = (
        ParamMalfunctionGen(
            MalfunctionParameters(malfunction_rate=mf_cfg.rate,
                                  min_duration=mf_cfg.min_duration,
                                  max_duration=mf_cfg.max_duration))
        if mf_cfg.enabled else None)

    # Variable speeds: equal share for speeds 1, 1/2, 1/3 and 1/4.
    speed_map = ({1. / k: 0.25 for k in (1., 2., 3., 4.)}
                 if env_cfg.variable_speed else None)
    # NOTE(review): the schedule is generated with sparse_schedule_generator
    # even when the rail was loaded from a file; confirm this matches the
    # saved environment (schedule_from_file may be intended).
    schedule_generator = sparse_schedule_generator(speed_map,
                                                   seed=env_cfg.seed)

    return RailEnvWrapper(params=args,
                          width=env_cfg.width,
                          height=env_cfg.height,
                          rail_generator=rail_generator,
                          schedule_generator=schedule_generator,
                          number_of_agents=env_cfg.num_trains,
                          obs_builder_object=observator,
                          malfunction_generator=malfunctions,
                          remove_agents_at_target=True,
                          random_seed=env_cfg.seed)
コード例 #7
0
    def _launch(self):
        """Create and reset a ``RailEnv`` from ``self._config``.

        Malfunctions are configured only when ``malfunction_rate``,
        ``malfunction_min_duration`` and ``malfunction_max_duration`` are
        all present. An optional ``speed_ratio_map`` has its keys and values
        coerced to ``float``.

        Returns:
            The reset environment. Returns ``None`` if ``RailEnv``
            construction raised ``ValueError`` (logged). If only ``reset``
            raised, the un-reset environment is returned.
        """
        cfg = self._config
        rail_generator = sparse_rail_generator(
            seed=cfg['seed'],
            max_num_cities=cfg['max_num_cities'],
            grid_mode=cfg['grid_mode'],
            max_rails_between_cities=cfg['max_rails_between_cities'],
            max_rails_in_city=cfg['max_rails_in_city'])

        malfunction_generator = no_malfunction_generator()
        malfunction_keys = {'malfunction_rate', 'malfunction_min_duration',
                            'malfunction_max_duration'}
        if malfunction_keys <= cfg.keys():
            malfunction_generator = malfunction_from_params({
                'malfunction_rate': cfg['malfunction_rate'],
                'min_duration': cfg['malfunction_min_duration'],
                'max_duration': cfg['malfunction_max_duration']
            })

        speed_ratio_map = None
        if 'speed_ratio_map' in cfg:
            speed_ratio_map = {float(speed): float(share)
                               for speed, share in cfg['speed_ratio_map'].items()}
        schedule_generator = sparse_schedule_generator(speed_ratio_map)

        env = None
        try:
            # Rendering is enabled by the env config; it may be wasteful
            # when many configurations are tried, and the renderer must be
            # closed by the caller.
            env = RailEnv(
                width=cfg['width'],
                height=cfg['height'],
                rail_generator=rail_generator,
                schedule_generator=schedule_generator,
                number_of_agents=cfg['number_of_agents'],
                malfunction_generator_and_process_data=malfunction_generator,
                obs_builder_object=self._observation.builder(),
                remove_agents_at_target=False,
                random_seed=cfg['seed'],
                use_renderer=self._env_config.get('render'))
            env.reset()
        except ValueError as err:
            separator = "=" * 50
            logging.error(separator)
            logging.error(f"Error while creating env: {err}")
            logging.error(separator)

        return env
コード例 #8
0
    def _launch(self):
        """Create and reset a ``RailEnv`` sized by ``get_round_2_env()``.

        ``get_round_2_env()`` supplies the agent count, the number of cities
        and the side length of the square map. The remaining rail,
        malfunction and speed settings are read from ``self._config``.
        Malfunctions are enabled only when ``malfunction_rate``,
        ``malfunction_min_duration`` and ``malfunction_max_duration`` are
        all configured.

        Returns:
            The reset environment. Returns ``None`` if ``RailEnv``
            construction raised ``ValueError`` (logged). If only ``reset``
            raised, the un-reset environment is returned.
        """
        print("NEW ENV LAUNCHED")
        n_agents, n_cities, dim = get_round_2_env()

        rail_generator = sparse_rail_generator(
            seed=self._config['seed'],
            max_num_cities=n_cities,
            grid_mode=self._config['grid_mode'],
            max_rails_between_cities=self._config['max_rails_between_cities'],
            max_rails_in_city=self._config['max_rails_in_city'])

        malfunction_generator = NoMalfunctionGen()
        if {
                'malfunction_rate', 'malfunction_min_duration',
                'malfunction_max_duration'
        } <= self._config.keys():
            # ParamMalfunctionGen needs a MalfunctionParameters record. It
            # reads malfunction_rate/min_duration/max_duration as attributes,
            # and a plain dict does not provide them. The resulting failure
            # would only appear once a malfunction is sampled during a step.
            from flatland.envs.malfunction_generators import \
                MalfunctionParameters
            malfunction_generator = ParamMalfunctionGen(
                MalfunctionParameters(
                    malfunction_rate=self._config['malfunction_rate'],
                    min_duration=self._config['malfunction_min_duration'],
                    max_duration=self._config['malfunction_max_duration']))

        speed_ratio_map = None
        if 'speed_ratio_map' in self._config:
            speed_ratio_map = {
                float(k): float(v)
                for k, v in self._config['speed_ratio_map'].items()
            }
        schedule_generator = sparse_schedule_generator(speed_ratio_map)

        env = None
        try:
            env = RailEnv(width=dim,
                          height=dim,
                          rail_generator=rail_generator,
                          schedule_generator=schedule_generator,
                          number_of_agents=n_agents,
                          malfunction_generator=malfunction_generator,
                          obs_builder_object=self._observation.builder(),
                          remove_agents_at_target=False,
                          random_seed=self._config['seed'],
                          use_renderer=self._env_config.get('render'))

            env.reset()
        except ValueError as e:
            logging.error("=" * 50)
            logging.error(f"Error while creating env: {e}")
            logging.error("=" * 50)

        return env
コード例 #9
0
def create_random_railways(project_root):
    """Return a ``(rail_generator, schedule_generator)`` pair for sparse maps.

    Every train is given speed 1.0; the other speed classes get a zero
    share. The rail generator allows up to 3 cities, 2 rails between cities
    and 3 rails inside a city. ``project_root`` is accepted but not used.
    """
    # Share of trains per speed class. The speeds 1, 1/2, 1/3 and 1/4 are
    # the fast passenger, fast freight, slow commuter and slow freight trains.
    speed_ration_map = {1 / 1: 1.0, 1 / 2.: 0.0, 1 / 3.: 0.0, 1 / 4.: 0.0}

    rails = sparse_rail_generator(grid_mode=False,
                                  max_num_cities=3,
                                  max_rails_between_cities=2,
                                  max_rails_in_city=3)
    schedules = sparse_schedule_generator(speed_ration_map)
    return rails, schedules
コード例 #10
0
    def _launch(self):
        """Create the environment for this wrapper.

        When ``self._fine_tune_env_path`` is ``None``, a new ``RailEnv`` is
        generated from ``self._config`` and reset. Otherwise a saved
        environment is loaded with ``RailEnvPersister.load_new``. It is reset
        without regenerating its rail or schedule and then given a fresh
        observation builder.

        Returns:
            The environment. Returns ``None`` if a ``ValueError`` (logged)
            occurred before it was assigned.
        """
        rail_generator = self.get_rail_generator()
        config = self._config

        malfunction_generator = NoMalfunctionGen()
        required = {'malfunction_rate', 'malfunction_min_duration',
                    'malfunction_max_duration'}
        if required <= config.keys():
            print("MALFUNCTIONS POSSIBLE")
            # NOTE(review): the configured rate is inverted before use
            # (presumably it is a mean interval between malfunctions).
            # Confirm.
            malfunction_generator = ParamMalfunctionGen(
                MalfunctionParameters(
                    malfunction_rate=1 / config['malfunction_rate'],
                    max_duration=config['malfunction_max_duration'],
                    min_duration=config['malfunction_min_duration']))

        speed_ratio_map = None
        if 'speed_ratio_map' in config:
            speed_ratio_map = {float(ratio): float(share)
                               for ratio, share in config['speed_ratio_map'].items()}

        # The sequential gym env uses its own schedule generator (seed fixed at 1).
        schedule_generator = (
            SequentialSparseSchedGen(speed_ratio_map, seed=1)
            if self._gym_env_class == SequentialFlatlandGymEnv
            else sparse_schedule_generator(speed_ratio_map))

        env = None
        try:
            if self._fine_tune_env_path is not None:
                # Reuse the saved environment: keep its rail and schedule and
                # only attach a new observation builder.
                env, _ = RailEnvPersister.load_new(self._fine_tune_env_path)
                env.reset(regenerate_rail=False, regenerate_schedule=False)
                env.obs_builder = self._observation.builder()
                env.obs_builder.set_env(env)
            else:
                env = RailEnv(
                    width=config['width'],
                    height=config['height'],
                    rail_generator=rail_generator,
                    schedule_generator=schedule_generator,
                    number_of_agents=config['number_of_agents'],
                    malfunction_generator=malfunction_generator,
                    obs_builder_object=self._observation.builder(),
                    remove_agents_at_target=True,
                    random_seed=config['seed'],
                    use_renderer=self._env_config.get('render')
                )
                env.reset()
        except ValueError as err:
            separator = "=" * 50
            logging.error(separator)
            logging.error(f"Error while creating env: {err}")
            logging.error(separator)

        return env
コード例 #11
0
 def _launch(self, env_params, observation):
     """Build a sparse-rail ``RailEnv`` from ``env_params``.

     ``env_params.seed`` seeds both the rail generator and ``RailEnv``.
     Speed profiles and malfunction parameters are also read from
     ``env_params``. ``observation`` is used as the observation builder.
     """
     width, height = env_params.x_dim, env_params.y_dim
     rail_gen = sparse_rail_generator(
         max_num_cities=env_params.n_cities,
         grid_mode=False,
         max_rails_between_cities=env_params.max_rails_between_cities,
         max_rails_in_city=env_params.max_rails_in_city,
         seed=env_params.seed)
     schedule_gen = sparse_schedule_generator(env_params.speed_profiles)
     n_agents = env_params.n_agents
     malfunction_gen = malfunction_from_params(env_params.malfunction_parameters)
     return RailEnv(width=width,
                    height=height,
                    rail_generator=rail_gen,
                    schedule_generator=schedule_gen,
                    number_of_agents=n_agents,
                    malfunction_generator_and_process_data=malfunction_gen,
                    obs_builder_object=observation,
                    random_seed=env_params.seed)
コード例 #12
0
ファイル: main.py プロジェクト: MelsHakobyan96/flatland_2.0
def env_gradual_update(input_env, agent=False, hardness_lvl=1):
    """Rebind the global ``env`` to a new environment 4 cells larger per side.

    The agent count is derived from the new map size and is at least 1.
    The module globals ``env`` and ``env_renderer`` are rebound. When the
    global ``render`` flag is set, the previous renderer window is closed
    first.

    Args:
        input_env: environment whose ``width``/``height`` set the new size.
        agent: unused; kept for call compatibility.
        hardness_lvl: ``1`` selects the complex rail/schedule generators;
            any other value selects the sparse ones.
    """
    global env, env_renderer, render

    env_width = input_env.width + 4
    env_height = input_env.height + 4

    # One agent per 5 cells of mean side length, minus 2, but never fewer
    # than one agent.
    agent_num = max(int(np.round(((env_width + env_height) / 2) / 5 - 2)), 1)

    if hardness_lvl == 1:
        rail_generator = complex_rail_generator(nr_start_goal=20,
                                                nr_extra=1,
                                                min_dist=9,
                                                max_dist=99999,
                                                seed=0)
        schedule_generator = complex_schedule_generator()
    else:
        # NOTE(review): nr_start_goal/nr_extra/min_dist/max_dist are
        # complex_rail_generator parameters. sparse_rail_generator takes
        # city/rail-count parameters instead (e.g. max_num_cities,
        # max_rails_between_cities, max_rails_in_city), so this call is
        # expected to raise TypeError. Pick sparse parameters that match
        # the installed flatland version.
        rail_generator = sparse_rail_generator(nr_start_goal=9,
                                               nr_extra=1,
                                               min_dist=9,
                                               max_dist=99999,
                                               seed=0)
        schedule_generator = sparse_schedule_generator()

    if render:
        env_renderer.close_window()

    env = RailEnv(width=env_width,
                  height=env_height,
                  rail_generator=rail_generator,
                  schedule_generator=schedule_generator,
                  obs_builder_object=GlobalObsForRailEnv(),
                  number_of_agents=agent_num)

    env_renderer = RenderTool(env)
コード例 #13
0
    def _launch(self):
        """Create and reset a ``RailEnv`` with ``self._prev_num_agents`` agents.

        Malfunctions are enabled only when ``malfunction_rate``,
        ``malfunction_min_duration`` and ``malfunction_max_duration`` are all
        present in ``self._config``. An optional ``speed_ratio_map`` has its
        keys and values coerced to ``float``.

        Returns:
            The reset environment. Returns ``None`` if ``RailEnv``
            construction raised ``ValueError`` (logged). If only ``reset``
            raised, the un-reset environment is returned.
        """
        rail_generator = self.get_rail_generator()

        malfunction_generator = NoMalfunctionGen()
        if {
                'malfunction_rate', 'malfunction_min_duration',
                'malfunction_max_duration'
        } <= self._config.keys():
            # ParamMalfunctionGen needs a MalfunctionParameters record. It
            # reads malfunction_rate/min_duration/max_duration as attributes,
            # and a plain dict does not provide them. The resulting failure
            # would only appear once a malfunction is sampled during a step.
            from flatland.envs.malfunction_generators import \
                MalfunctionParameters
            malfunction_generator = ParamMalfunctionGen(
                MalfunctionParameters(
                    malfunction_rate=self._config['malfunction_rate'],
                    min_duration=self._config['malfunction_min_duration'],
                    max_duration=self._config['malfunction_max_duration']))

        speed_ratio_map = None
        if 'speed_ratio_map' in self._config:
            speed_ratio_map = {
                float(k): float(v)
                for k, v in self._config['speed_ratio_map'].items()
            }
        schedule_generator = sparse_schedule_generator(speed_ratio_map)

        env = None
        try:
            print("GENERATE NEW ENV WITH", self._prev_num_agents, "AGENTS")
            env = RailEnv(width=self._config['width'],
                          height=self._config['height'],
                          rail_generator=rail_generator,
                          schedule_generator=schedule_generator,
                          number_of_agents=self._prev_num_agents,
                          malfunction_generator=malfunction_generator,
                          obs_builder_object=self._observation.builder(),
                          remove_agents_at_target=False,
                          random_seed=self._config['seed'],
                          use_renderer=self._env_config.get('render'))

            env.reset()
        except ValueError as e:
            logging.error("=" * 50)
            logging.error(f"Error while creating env: {e}")
            logging.error("=" * 50)

        return env
コード例 #14
0
def test_schedule_from_file_sparse():
    """Check that a saved sparse environment reloads with its parameters.

    A sparse environment is generated and saved to
    ``./sparse_env_test.pkl``. It is then reloaded through
    ``rail_from_file``/``schedule_from_file``, and its agent count and
    episode step limit are compared against expected values.
    """
    env_file = "./sparse_env_test.pkl"

    # Equal shares of four train speed classes: fast passenger (1),
    # fast freight (1/2), slow commuter (1/3) and slow freight (1/4).
    speed_ration_map = {1. / k: 0.25 for k in (1., 2., 3., 4.)}

    # Generate and save the sparse test env.
    create_and_save_env(
        file_name=env_file,
        rail_generator=sparse_rail_generator(max_num_cities=5,
                                             seed=1,
                                             grid_mode=False,
                                             max_rails_between_cities=3,
                                             max_rails_in_city=6),
        schedule_generator=sparse_schedule_generator(speed_ration_map))

    # Reload it; width/height are placeholders since the rail comes from file.
    reloaded = RailEnv(width=1,
                       height=1,
                       rail_generator=rail_from_file(env_file),
                       schedule_generator=schedule_from_file(env_file))
    reloaded.reset(True, True)

    # Expected values presumably follow create_and_save_env's defaults
    # (not visible here).
    assert reloaded.get_num_agents() == 10
    assert reloaded._max_episode_steps == 500
コード例 #15
0
ファイル: flatland_single.py プロジェクト: wullli/flatlander
    def _launch(self):
        """Create and reset a ``RailEnv`` from ``self._config``.

        Malfunctions are enabled only when the config provides all of
        ``malfunction_rate``, ``malfunction_min_duration`` and
        ``malfunction_max_duration``. Otherwise no malfunctions are
        generated.

        Returns:
            The reset environment. Returns ``None`` if ``RailEnv``
            construction raised ``ValueError`` (logged). If only ``reset``
            raised, the un-reset environment is returned.
        """
        rail_generator = self.get_rail_generator()

        malfunction_generator = no_malfunction_generator()
        # The presence check must name exactly the keys read below. Checking
        # 'min_duration'/'max_duration' while reading 'malfunction_*' either
        # raised KeyError or silently disabled malfunctions.
        malfunction_keys = {'malfunction_rate', 'malfunction_min_duration',
                            'malfunction_max_duration'}
        if malfunction_keys <= self._config.keys():
            stochastic_data = {
                'malfunction_rate': self._config['malfunction_rate'],
                'min_duration': self._config['malfunction_min_duration'],
                'max_duration': self._config['malfunction_max_duration']
            }
            malfunction_generator = malfunction_from_params(stochastic_data)

        speed_ratio_map = None
        if 'speed_ratio_map' in self._config:
            speed_ratio_map = {
                float(k): float(v)
                for k, v in self._config['speed_ratio_map'].items()
            }
        schedule_generator = sparse_schedule_generator(speed_ratio_map)

        env = None
        try:
            env = RailEnv(
                width=self._config['width'],
                height=self._config['height'],
                rail_generator=rail_generator,
                schedule_generator=schedule_generator,
                number_of_agents=self._config['number_of_agents'],
                malfunction_generator_and_process_data=malfunction_generator,
                obs_builder_object=self._observation.builder(),
                remove_agents_at_target=False,
                random_seed=self._config['seed'])

            env.reset()
        except ValueError as e:
            logging.error("=" * 50)
            logging.error(f"Error while creating env: {e}")
            logging.error("=" * 50)

        return env
コード例 #16
0
    def _thunk():
        """Return a deep copy of a freshly built ``RailEnvWrapper``.

        ``x_dim``, ``y_dim``, ``speed_ration_map``, ``n_agents``,
        ``stochastic_data`` and ``TreeObservation`` come from the enclosing
        scope. The wrapper is deep-copied before being returned. Presumably
        this gives each environment its own copy of shared objects such as
        the observation builder (confirm).
        """
        rails = sparse_rail_generator(
            max_num_cities=3,  # cities are where the train stations are
            seed=1,
            grid_mode=False,
            max_rails_between_cities=2,
            max_rails_in_city=3)
        schedules = sparse_schedule_generator(speed_ration_map)

        prototype = RailEnvWrapper(
            width=x_dim,
            height=y_dim,
            rail_generator=rails,
            schedule_generator=schedules,
            number_of_agents=n_agents,
            stochastic_data=stochastic_data,  # malfunction settings
            obs_builder_object=TreeObservation)

        return copy.deepcopy(prototype)
コード例 #17
0
def get_env(config=None, rl=False):
    """Build a 28x28 sparse-rail ``RailEnv`` with 16 agents.

    ``seed`` is read from module scope. When ``rl`` is true, the observation
    builder comes from ``make_obs("combined", ...)``; otherwise ``DummyObs``
    is used. Malfunctions use rate 1/1000 with durations of 20-50, and
    agents are removed at their target. ``config`` is not used.
    """
    schedules = sparse_schedule_generator(None)
    rails = sparse_rail_generator(
        seed=seed,
        max_num_cities=3,
        grid_mode=False,
        max_rails_between_cities=2,
        max_rails_in_city=4,
    )

    if not rl:
        obs_builder = DummyObs()
    else:
        obs_spec = {"path": None, "simple_meta": None}
        obs_builder = make_obs("combined", obs_spec).builder()

    breakdowns = ParamMalfunctionGen(
        MalfunctionParameters(malfunction_rate=1 / 1000,
                              max_duration=50,
                              min_duration=20))

    return RailEnv(
        width=28,
        height=28,
        rail_generator=rails,
        schedule_generator=schedules,
        number_of_agents=16,
        malfunction_generator=breakdowns,
        obs_builder_object=obs_builder,
        remove_agents_at_target=True,
        random_seed=seed,
    )
コード例 #18
0
ファイル: main.py プロジェクト: MelsHakobyan96/flatland_2.0
def env_random_update(input_env, decay, agent=False, hardness_lvl=1):
    """Rebind the global ``env`` to a randomly sized new environment.

    Draws 1-4 agents with ``np.random.randint(1, 5)`` and uses a square map
    with side ``(agents + 2) * 5``. ``input_env``, ``decay`` and ``agent``
    are not used. The module globals ``env`` and ``env_renderer`` are
    rebound. When the global ``render`` flag is set, the old renderer window
    is closed first.

    Args:
        hardness_lvl: ``1`` selects the complex rail/schedule generators;
            any other value selects the sparse ones.
    """
    global env, env_renderer, render

    agent_num = np.random.randint(1, 5)
    env_width = env_height = (agent_num + 2) * 5

    if hardness_lvl == 1:
        rail_generator, schedule_generator = (
            complex_rail_generator(nr_start_goal=20,
                                   nr_extra=1,
                                   min_dist=9,
                                   max_dist=99999,
                                   seed=0),
            complex_schedule_generator(),
        )
    else:
        # NOTE(review): nr_start_goal/nr_extra/min_dist/max_dist are
        # complex_rail_generator parameters, which sparse_rail_generator does
        # not accept. This call is expected to raise TypeError; choose sparse
        # parameters matching the installed flatland version.
        rail_generator, schedule_generator = (
            sparse_rail_generator(nr_start_goal=9,
                                  nr_extra=1,
                                  min_dist=9,
                                  max_dist=99999,
                                  seed=0),
            sparse_schedule_generator(),
        )

    if render:
        env_renderer.close_window()

    env = RailEnv(width=env_width,
                  height=env_height,
                  rail_generator=rail_generator,
                  schedule_generator=schedule_generator,
                  obs_builder_object=GlobalObsForRailEnv(),
                  number_of_agents=agent_num)

    env_renderer = RenderTool(env)
コード例 #19
0
def demo_lpg_planing():
    """Run the LPG planner demo on a sequence of PDDL Flatland problems.

    Builds a single-agent 25x25 sparse-rail ``PDDLFlatlandEnv`` (seed 42,
    depth-2 tree observations). The environment uses
    ``./pddl/flatland.pddl`` as the domain and the problems under
    ``./pddl/flatland``. Problems 0-5 are then solved in turn with
    ``run_planning_flatland_demo(env, 'lpg')``.
    """
    from flatland.envs.rail_generators import sparse_rail_generator
    from flatland.envs.schedule_generators import sparse_schedule_generator
    from flatland.envs.observations import TreeObsForRailEnv

    seed = 42
    num_problems = 6

    tree_obs = TreeObsForRailEnv(max_depth=2)

    rails = sparse_rail_generator(max_num_cities=4,
                                  seed=seed,
                                  grid_mode=False,
                                  max_rails_between_cities=2,
                                  max_rails_in_city=3)
    schedules = sparse_schedule_generator()

    env = PDDLFlatlandEnv(width=25,
                          height=25,
                          rail_generator=rails,
                          schedule_generator=schedules,
                          number_of_agents=1,
                          obs_builder_object=tree_obs,
                          domain_file="./pddl/flatland.pddl",
                          problem_dir="./pddl/flatland")

    for index in range(num_problems):
        env.fix_problem_index(index)
        run_planning_flatland_demo(env, 'lpg')
コード例 #20
0
def get_env(config=None, rl=False):
    """Build a 32x32 sparse-rail ``RailEnv`` with 32 agents.

    ``seed`` is read from module scope. When ``rl`` is true, the observation
    builder is created by ``make_obs`` from ``config["env_config"]``
    (``observation`` and optional ``observation_config``). Otherwise
    ``DummyObs`` is used. Malfunctions use rate 1/1000 with durations of
    20-50, and agents are removed at their target.
    """
    schedules = sparse_schedule_generator(None)
    rails = sparse_rail_generator(
        seed=seed,
        max_num_cities=4,
        grid_mode=False,
        max_rails_between_cities=2,
        max_rails_in_city=4,
    )

    if not rl:
        obs_builder = DummyObs()
    else:
        obs_name = config["env_config"]['observation']
        obs_params = config["env_config"].get('observation_config')
        obs_builder = make_obs(obs_name, obs_params).builder()

    breakdowns = ParamMalfunctionGen(
        MalfunctionParameters(malfunction_rate=1 / 1000,
                              max_duration=50,
                              min_duration=20))

    return RailEnv(
        width=32,
        height=32,
        rail_generator=rails,
        schedule_generator=schedules,
        number_of_agents=32,
        malfunction_generator=breakdowns,
        obs_builder_object=obs_builder,
        remove_agents_at_target=True,
        random_seed=seed,
    )
コード例 #21
0
ファイル: fl_environment.py プロジェクト: yk/youtube-flatland
    def __init__(self,
                 n_cars=3,
                 n_acts=5,
                 min_obs=-1,
                 max_obs=1,
                 n_nodes=2,
                 ob_radius=10,
                 x_dim=36,
                 y_dim=36,
                 feats='all'):
        """Create the wrapped ``RailEnv``, its renderer and per-step buffers.

        Builds a ``TreeObservation`` over ``n_nodes`` and a sparse-rail map
        (3 cities, seed 666) of size ``x_dim`` x ``y_dim`` with ``n_cars``
        trains. ``speed_ration_map`` and ``stochastic_data`` are read from
        module scope. ``n_acts``, ``min_obs`` and ``max_obs`` are accepted
        but not used here.
        """
        self.tree_obs = tree_observation.TreeObservation(n_nodes)
        self.n_cars, self.n_nodes = n_cars, n_nodes
        self.ob_radius, self.feats = ob_radius, feats

        rails = sparse_rail_generator(max_num_cities=3,
                                      seed=666,
                                      grid_mode=False,
                                      max_rails_between_cities=2,
                                      max_rails_in_city=3)
        schedules = sparse_schedule_generator(speed_ration_map)
        malfunctions = malfunction_from_params(stochastic_data)

        self._rail_env = RailEnv(
            width=x_dim,
            height=y_dim,
            rail_generator=rails,
            schedule_generator=schedules,
            number_of_agents=n_cars,
            malfunction_generator_and_process_data=malfunctions,
            obs_builder_object=self.tree_obs)

        self.renderer = RenderTool(self._rail_env, gl="PILSVG")
        self.action_dict, self.info, self.old_obs = {}, {}, {}
コード例 #22
0
def create_multi_agent_rail_env(seed, timed):
    """Build a seeded 4-agent, 25x25 sparse-rail ``RailEnv``.

    Observations come from ``TreeObsForRailEnvExtended`` (depth 2), backed
    by a ``ShortestPathPredictorForRailEnv`` with max depth 30. ``random``
    and NumPy's global RNG are seeded with ``seed`` before the env is
    built. No malfunction generator is configured.

    Returns:
        tuple: ``(env, max_steps, x_dim, y_dim, observation_tree_depth,
        observation_max_path_depth)``, where ``max_steps`` is
        ``int(8 * (env.height + env.width + n_agents / n_cities))``.
    """
    n_agents, n_cities = 4, 4
    x_dim = y_dim = 25
    max_rails_between_cities, max_rails_in_city = 2, 3
    observation_tree_depth, observation_max_path_depth = 2, 30

    predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
    tree_observation = TreeObsForRailEnvExtended(
        max_depth=observation_tree_depth, predictor=predictor)

    for rng in (random, np.random):
        rng.seed(seed)

    rails = sparse_rail_generator(
        max_num_cities=n_cities,
        seed=seed,
        grid_mode=False,
        max_rails_between_cities=max_rails_between_cities,
        max_rails_in_city=max_rails_in_city)
    schedules = sparse_schedule_generator(timed=timed)

    env = RailEnv(width=x_dim,
                  height=y_dim,
                  rail_generator=rails,
                  schedule_generator=schedules,
                  number_of_agents=n_agents,
                  malfunction_generator_and_process_data=None,
                  obs_builder_object=tree_observation,
                  random_seed=seed)

    step_budget = 8 * (env.height + env.width + n_agents / n_cities)
    return (env, int(step_budget), x_dim, y_dim,
            observation_tree_depth, observation_max_path_depth)