def create_rail_env(args, load_env=""):
    """Build a ``RailEnvWrapper`` from the settings parsed out of the .yml file.

    Args:
        args: Parsed configuration object. Its ``env``, ``policy``,
            ``observator`` and ``predictor`` sections are read.
        load_env: Optional path to a saved environment. When non-empty, the
            rail is loaded with ``rail_from_file(load_env)``; otherwise it is
            generated with ``sparse_rail_generator``.

    Returns:
        The constructed ``RailEnvWrapper`` (agents are removed at their target).
    """
    env_cfg = args.env

    # Rail network: loaded from disk when a file is given, generated otherwise.
    rail_generator = (
        rail_from_file(load_env)
        if load_env
        else sparse_rail_generator(
            max_num_cities=env_cfg.max_cities,
            grid_mode=env_cfg.grid,
            max_rails_between_cities=env_cfg.max_rails_between_cities,
            max_rails_in_city=env_cfg.max_rails_in_cities,
            seed=env_cfg.seed,
        )
    )

    # The policy type selects both the predictor class and the observator class.
    obs_key = args.policy.type.get_true_key()
    predictor_cls = PREDICTORS[obs_key]
    if predictor_cls is ShortestDeviationPathPredictor:
        # NOTE(review): max_depth is read from the observator section and
        # max_deviations from predictor.max_depth -- confirm these are not
        # swapped / not meant to use a dedicated max_deviations setting.
        predictor = predictor_cls(
            max_depth=args.observator.max_depth,
            max_deviations=args.predictor.max_depth,
        )
    else:
        predictor = predictor_cls(max_depth=args.predictor.max_depth)
    observator = OBSERVATORS[obs_key](args.observator.max_depth, predictor)

    # Parameterised malfunctions only when enabled; None otherwise.
    mal_cfg = env_cfg.malfunctions
    malfunctions = (
        ParamMalfunctionGen(
            MalfunctionParameters(
                malfunction_rate=mal_cfg.rate,
                min_duration=mal_cfg.min_duration,
                max_duration=mal_cfg.max_duration,
            )
        )
        if mal_cfg.enabled
        else None
    )

    # With variable speed, four speeds (1, 1/2, 1/3, 1/4) each get a 0.25 share.
    speed_map = (
        {1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25}
        if env_cfg.variable_speed
        else None
    )
    # NOTE(review): the schedule is always sparse-generated, even when the rail
    # was loaded from ``load_env``; confirm this is intended.
    schedule_generator = sparse_schedule_generator(speed_map, seed=env_cfg.seed)

    return RailEnvWrapper(
        params=args,
        width=env_cfg.width,
        height=env_cfg.height,
        rail_generator=rail_generator,
        schedule_generator=schedule_generator,
        number_of_agents=env_cfg.num_trains,
        obs_builder_object=observator,
        malfunction_generator=malfunctions,
        remove_agents_at_target=True,
        random_seed=env_cfg.seed,
    )
Beispiel #2
0
    def _launch(self):
        """Create and reset a RailEnv whose size comes from ``get_round_2_env()``.

        ``get_round_2_env()`` supplies the number of agents, the number of
        cities and the side length of the square grid. The remaining rail
        generator settings and the seed are read from ``self._config``.

        Malfunctions are enabled only when ``self._config`` contains all of
        ``malfunction_rate``, ``malfunction_min_duration`` and
        ``malfunction_max_duration``; otherwise ``NoMalfunctionGen`` is used.

        Returns:
            The environment. If ``RailEnv(...)`` raises ``ValueError``, the
            error is logged and ``None`` is returned. NOTE(review): if
            ``env.reset()`` raises ``ValueError``, the error is logged but the
            already-constructed (not successfully reset) env is returned.
        """
        # Local import keeps this method self-contained.
        from flatland.envs.malfunction_generators import MalfunctionParameters

        print("NEW ENV LAUNCHED")
        n_agents, n_cities, dim = get_round_2_env()

        rail_generator = sparse_rail_generator(
            seed=self._config['seed'],
            max_num_cities=n_cities,
            grid_mode=self._config['grid_mode'],
            max_rails_between_cities=self._config['max_rails_between_cities'],
            max_rails_in_city=self._config['max_rails_in_city'])

        malfunction_generator = NoMalfunctionGen()
        if {
                'malfunction_rate', 'malfunction_min_duration',
                'malfunction_max_duration'
        } <= self._config.keys():
            # ParamMalfunctionGen needs a MalfunctionParameters namedtuple, not
            # a plain dict: flatland reads ``.malfunction_rate`` etc. as
            # attributes, so a dict fails with AttributeError as soon as a
            # malfunction is sampled.
            malfunction_generator = ParamMalfunctionGen(
                MalfunctionParameters(
                    malfunction_rate=self._config['malfunction_rate'],
                    min_duration=self._config['malfunction_min_duration'],
                    max_duration=self._config['malfunction_max_duration']))

        # Config keys/values may come in as strings; normalise to floats.
        speed_ratio_map = None
        if 'speed_ratio_map' in self._config:
            speed_ratio_map = {
                float(k): float(v)
                for k, v in self._config['speed_ratio_map'].items()
            }
        schedule_generator = sparse_schedule_generator(speed_ratio_map)

        env = None
        try:
            env = RailEnv(width=dim,
                          height=dim,
                          rail_generator=rail_generator,
                          schedule_generator=schedule_generator,
                          number_of_agents=n_agents,
                          malfunction_generator=malfunction_generator,
                          obs_builder_object=self._observation.builder(),
                          remove_agents_at_target=False,
                          random_seed=self._config['seed'],
                          use_renderer=self._env_config.get('render'))

            env.reset()
        except ValueError as e:
            logging.error("=" * 50)
            logging.error(f"Error while creating env: {e}")
            logging.error("=" * 50)

        return env
Beispiel #3
0
    def _launch(self):
        """Create the RailEnv used by this wrapper.

        There are two paths:
          * ``self._fine_tune_env_path is None``: build a new ``RailEnv`` from
            ``self._config`` and reset it.
          * otherwise: load a saved env with ``RailEnvPersister.load_new``,
            reset it without regenerating rail or schedule, and attach a fresh
            observation builder bound to it.

        Malfunctions are enabled only when ``self._config`` has all of
        ``malfunction_rate``, ``malfunction_min_duration`` and
        ``malfunction_max_duration``. The configured rate is inverted
        (``1 / rate``) before it reaches ``MalfunctionParameters``; presumably
        the config stores a mean interval between malfunctions (confirm).

        Returns:
            The environment, or ``None`` if a ``ValueError`` was raised before
            ``env`` was assigned. ``ValueError``s are logged, not re-raised.
        """
        cfg = self._config
        rail_generator = self.get_rail_generator()

        malfunction_keys = {'malfunction_rate', 'malfunction_min_duration', 'malfunction_max_duration'}
        if malfunction_keys <= cfg.keys():
            print("MALFUNCTIONS POSSIBLE")
            malfunction_generator = ParamMalfunctionGen(
                MalfunctionParameters(
                    malfunction_rate=1 / cfg['malfunction_rate'],
                    max_duration=cfg['malfunction_max_duration'],
                    min_duration=cfg['malfunction_min_duration'],
                )
            )
        else:
            malfunction_generator = NoMalfunctionGen()

        # Config keys/values may be strings; normalise both sides to floats.
        speed_ratio_map = (
            {float(speed): float(share) for speed, share in cfg['speed_ratio_map'].items()}
            if 'speed_ratio_map' in cfg
            else None
        )

        # The sequential gym env uses its own schedule generator (fixed seed 1).
        if self._gym_env_class == SequentialFlatlandGymEnv:
            schedule_generator = SequentialSparseSchedGen(speed_ratio_map, seed=1)
        else:
            schedule_generator = sparse_schedule_generator(speed_ratio_map)

        env = None
        try:
            if self._fine_tune_env_path is not None:
                # Reuse a persisted env; keep its rail and schedule as saved.
                env, _ = RailEnvPersister.load_new(self._fine_tune_env_path)
                env.reset(regenerate_rail=False, regenerate_schedule=False)
                env.obs_builder = self._observation.builder()
                env.obs_builder.set_env(env)
            else:
                env = RailEnv(
                    width=cfg['width'],
                    height=cfg['height'],
                    rail_generator=rail_generator,
                    schedule_generator=schedule_generator,
                    number_of_agents=cfg['number_of_agents'],
                    malfunction_generator=malfunction_generator,
                    obs_builder_object=self._observation.builder(),
                    remove_agents_at_target=True,
                    random_seed=cfg['seed'],
                    use_renderer=self._env_config.get('render'),
                )
                env.reset()
        except ValueError as err:
            banner = "=" * 50
            logging.error(banner)
            logging.error(f"Error while creating env: {err}")
            logging.error(banner)

        return env
    def _launch(self):
        """Create and reset a RailEnv with ``self._prev_num_agents`` agents.

        Grid size, seed and the optional malfunction / speed settings are read
        from ``self._config``. Malfunctions are enabled only when the config
        contains all of ``malfunction_rate``, ``malfunction_min_duration`` and
        ``malfunction_max_duration``; otherwise ``NoMalfunctionGen`` is used.

        Returns:
            The environment. If ``RailEnv(...)`` raises ``ValueError``, the
            error is logged and ``None`` is returned. NOTE(review): if
            ``env.reset()`` raises ``ValueError``, the error is logged but the
            already-constructed (not successfully reset) env is returned.
        """
        # Local import keeps this method self-contained.
        from flatland.envs.malfunction_generators import MalfunctionParameters

        rail_generator = self.get_rail_generator()

        malfunction_generator = NoMalfunctionGen()
        if {
                'malfunction_rate', 'malfunction_min_duration',
                'malfunction_max_duration'
        } <= self._config.keys():
            # ParamMalfunctionGen needs a MalfunctionParameters namedtuple, not
            # a plain dict: flatland reads ``.malfunction_rate`` etc. as
            # attributes, so a dict fails with AttributeError as soon as a
            # malfunction is sampled.
            malfunction_generator = ParamMalfunctionGen(
                MalfunctionParameters(
                    malfunction_rate=self._config['malfunction_rate'],
                    min_duration=self._config['malfunction_min_duration'],
                    max_duration=self._config['malfunction_max_duration']))

        # Config keys/values may come in as strings; normalise to floats.
        speed_ratio_map = None
        if 'speed_ratio_map' in self._config:
            speed_ratio_map = {
                float(k): float(v)
                for k, v in self._config['speed_ratio_map'].items()
            }
        schedule_generator = sparse_schedule_generator(speed_ratio_map)

        env = None
        try:
            print("GENERATE NEW ENV WITH", self._prev_num_agents, "AGENTS")
            env = RailEnv(width=self._config['width'],
                          height=self._config['height'],
                          rail_generator=rail_generator,
                          schedule_generator=schedule_generator,
                          number_of_agents=self._prev_num_agents,
                          malfunction_generator=malfunction_generator,
                          obs_builder_object=self._observation.builder(),
                          remove_agents_at_target=False,
                          random_seed=self._config['seed'],
                          use_renderer=self._env_config.get('render'))

            env.reset()
        except ValueError as e:
            logging.error("=" * 50)
            logging.error(f"Error while creating env: {e}")
            logging.error("=" * 50)

        return env
Beispiel #5
0
def get_env(config=None, rl=False):
    """Build a 28x28 RailEnv with 16 agents and parameterised malfunctions.

    Args:
        config: Not read by this variant; accepted for signature compatibility.
        rl: When True, attach the observation builder from
            ``make_obs("combined", {"path": None, "simple_meta": None})``;
            otherwise attach a ``DummyObs``.

    Returns:
        The ``RailEnv``; this function does not call ``reset()`` on it.

    Note:
        Uses a module-level ``seed`` defined elsewhere in the module.
    """
    agent_count = 16
    grid_side = 28

    schedule_generator = sparse_schedule_generator(None)
    rail_generator = sparse_rail_generator(
        seed=seed,
        max_num_cities=3,
        grid_mode=False,
        max_rails_between_cities=2,
        max_rails_in_city=4,
    )

    obs_builder = (
        make_obs("combined", {"path": None, "simple_meta": None}).builder()
        if rl
        else DummyObs()
    )

    # Per-step malfunction rate 1/1000, outages lasting 20-50 steps.
    malfunction_generator = ParamMalfunctionGen(
        MalfunctionParameters(malfunction_rate=1 / 1000,
                              max_duration=50,
                              min_duration=20)
    )

    return RailEnv(
        width=grid_side,
        height=grid_side,
        rail_generator=rail_generator,
        schedule_generator=schedule_generator,
        number_of_agents=agent_count,
        malfunction_generator=malfunction_generator,
        obs_builder_object=obs_builder,
        remove_agents_at_target=True,
        random_seed=seed,
    )
def get_env(config=None, rl=False):
    """Build a 32x32 RailEnv with 32 agents and parameterised malfunctions.

    Args:
        config: Experiment configuration. Required when ``rl`` is True: its
            ``"env_config"`` entry must provide ``'observation'`` and may
            provide ``'observation_config'``, both passed to ``make_obs``.
        rl: When True, build the observation builder from ``config``;
            otherwise attach a ``DummyObs``.

    Returns:
        The ``RailEnv``; this function does not call ``reset()`` on it.

    Raises:
        TypeError: If ``rl`` is True but no ``config`` was given. Previously
            this surfaced as an opaque "'NoneType' object is not subscriptable".

    Note:
        Uses a module-level ``seed`` defined elsewhere in the module.
    """
    if rl and config is None:
        raise TypeError(
            "get_env(rl=True) requires a config with an 'env_config' entry")

    n_agents = 32
    schedule_generator = sparse_schedule_generator(None)

    rail_generator = sparse_rail_generator(
        seed=seed,
        max_num_cities=4,
        grid_mode=False,
        max_rails_between_cities=2,
        max_rails_in_city=4,
    )

    if rl:
        env_config = config["env_config"]
        obs_builder = make_obs(
            env_config['observation'],
            env_config.get('observation_config')).builder()
    else:
        obs_builder = DummyObs()

    # Per-step malfunction rate 1/1000, outages lasting 20-50 steps.
    params = MalfunctionParameters(malfunction_rate=1 / 1000,
                                   max_duration=50,
                                   min_duration=20)
    malfunction_generator = ParamMalfunctionGen(params)

    env = RailEnv(
        width=32,
        height=32,
        rail_generator=rail_generator,
        schedule_generator=schedule_generator,
        number_of_agents=n_agents,
        malfunction_generator=malfunction_generator,
        obs_builder_object=obs_builder,
        remove_agents_at_target=True,
        random_seed=seed,
    )

    return env
Beispiel #7
0
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import MalfunctionParameters, ParamMalfunctionGen
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import random_rail_generator, complex_rail_generator, sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool

# Sparse rail network: at most 4 cities, no grid layout, seeded for
# reproducibility.
rail_generator = sparse_rail_generator(seed=0,
                                       max_num_cities=4,
                                       grid_mode=False,
                                       max_rails_between_cities=2,
                                       max_rails_in_city=2)

# NOTE(review): malfunction_rate=10 is orders of magnitude above the 1/1000 and
# 1/8000 rates used by the other setups in this file; confirm it is intended
# (e.g. a stress test) and not an interval that should be written as 1 / 10.
malfunction_generator = ParamMalfunctionGen(
    MalfunctionParameters(malfunction_rate=10,
                          min_duration=20,
                          max_duration=50))

# Single speed class. Presumably key = speed, value = share of agents, so
# every train runs at speed 1 (confirm against sparse_schedule_generator).
# The former `speed_ratio_map = None` was a dead store, immediately overwritten.
speed_ratio_map = {1: 1}
schedule_generator = sparse_schedule_generator(speed_ratio_map)

n_agents = 5
env = RailEnv(
    width=25,
    height=25,
    rail_generator=rail_generator,
    schedule_generator=schedule_generator,
    number_of_agents=n_agents,
    malfunction_generator=malfunction_generator,
    obs_builder_object=TreeObsForRailEnv(
flags = parser.parse_args()

# Fix NumPy's global RNG seed so results can be replicated.
np.random.seed(0)

# Railways come either from pre-generated files on disk or from a fresh
# random railway generator.
if flags.load_railways:
    rail_generator, schedule_generator = load_precomputed_railways(project_root, flags)
else:
    rail_generator, schedule_generator = create_random_railways(project_root)

# Flatland environment with rare malfunctions (rate 1/8000, 15-50 steps) and
# tree observations of depth flags.tree_depth.
env = RailEnv(
    width=flags.grid_width,
    height=flags.grid_height,
    number_of_agents=flags.num_agents,
    rail_generator=rail_generator,
    schedule_generator=schedule_generator,
    malfunction_generator=ParamMalfunctionGen(MalfunctionParameters(1 / 8000, 15, 50)),
    obs_builder_object=TreeObservation(max_depth=flags.tree_depth),
)

# Renderer used to visualise the results once training is done.
env_renderer = RenderTool(
    env,
    gl="PILSVG",
    screen_width=800,
    screen_height=800,
    agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
)

# State size = (nodes in a complete 4-ary tree of depth tree_depth) x (features per node).
num_features_per_node = env.obs_builder.observation_dim
num_nodes = sum(np.power(4, level) for level in range(flags.tree_depth + 1))
state_size = num_nodes * num_features_per_node
action_size = 5

# Rolling windows (last 200 episodes) and per-agent observation buffers.
scores_window, steps_window, collisions_window, done_window = (
    deque(maxlen=200) for _ in range(4)
)
agent_obs = [None for _ in range(flags.num_agents)]
agent_obs_buffer = [None] * flags.num_agents