def test_malfunction_before_entry():
    """Tests that malfunctions are working properly for agents before entering the environment!

    With a fixed malfunction duration (min == max == 10) and ``random_seed=10``
    the agents start with a known mix of malfunctioning and working states.
    """
    # Fixed malfunction duration keeps the expected counters exact.
    params = MalfunctionParameters(
        malfunction_rate=2,  # Rate of malfunction occurence
        min_duration=10,  # Minimal duration of malfunction
        max_duration=10,  # Max duration of malfunction
    )
    rail, rail_map = make_simple_rail2()
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=10,
        malfunction_generator_and_process_data=malfunction_from_params(params),
        obs_builder_object=SingleAgentNavigationObs(),
    )
    env.reset(False, False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Expected initial 'malfunction' counter per agent handle: some agents are
    # already malfunctioning (10), others are working (0).
    expected_per_handle = (0, 10, 0, 10, 10, 10, 10, 10, 10, 10)
    for handle, expected in enumerate(expected_per_handle):
        assert env.agents[handle].malfunction_data['malfunction'] == expected
def test_path_not_exists(rendering=False):
    """``check_path`` must report no route between unconnected parts of the rail."""
    rail, rail_map = make_simple_rail_unconnected()
    grid_height, grid_width = rail_map.shape
    env = RailEnv(
        width=grid_width,
        height=grid_height,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
    )
    env.reset()

    start = (5, 6)  # south dead-end
    heading = 0  # north
    target = (0, 3)  # north dead-end
    check_path(env, rail, start, heading, target, False)

    if not rendering:
        return
    renderer = RenderTool(env, gl="PILSVG")
    renderer.render_env(show=True, show_observations=False)
    input("Continue?")
def test_malfanction_from_params():
    """Test that ``malfunction_from_params`` hands its parameters to the env.

    After ``env.reset()`` the env's ``malfunction_process_data`` must carry the
    malfunction rate and the min/max durations given in
    ``MalfunctionParameters``.
    """
    stochastic_data = MalfunctionParameters(
        malfunction_rate=1000,  # Rate of malfunction occurence
        min_duration=2,  # Minimal duration of malfunction
        max_duration=5  # Max duration of malfunction
    )
    rail, rail_map = make_simple_rail2()

    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=10,
        malfunction_generator_and_process_data=malfunction_from_params(
            stochastic_data))
    env.reset()

    assert env.malfunction_process_data.malfunction_rate == 1000
    assert env.malfunction_process_data.min_duration == 2
    assert env.malfunction_process_data.max_duration == 5
def demo(args=None, max_episodes=None):
    """Demo script to check installation.

    Builds a 15x15 ``RailEnv`` with a complex rail generator and five agents,
    then plays episodes with uniformly random actions and renders every step.

    Parameters
    ----------
    args : optional
        Unused; kept for compatibility with existing callers.
    max_episodes : int or None, optional
        Number of episodes to play before returning. ``None`` (the default)
        keeps the original behaviour of looping forever.

    Returns
    -------
    int
        ``0`` once ``max_episodes`` episodes have been played.
    """
    env = RailEnv(width=15,
                  height=15,
                  rail_generator=complex_rail_generator(nr_start_goal=10,
                                                        nr_extra=1,
                                                        min_dist=8,
                                                        max_dist=99999),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=5)

    # Presumably caps the length of each episode of random play.
    env._max_episode_steps = int(15 * (env.width + env.height))
    env_renderer = RenderTool(env)

    episodes_played = 0
    while max_episodes is None or episodes_played < max_episodes:
        obs, info = env.reset()
        all_done = False
        # Run a single episode here
        while not all_done:
            # One random action in [0, 5) per agent handle.
            actions = {handle: np.random.randint(0, 5)
                       for handle in range(len(env.agents))}
            obs, all_rewards, done, _ = env.step(actions)
            all_done = done['__all__']
            env_renderer.render_env(show=True,
                                    frames=False,
                                    show_observations=False,
                                    show_predictions=False)
            time.sleep(0.3)
        episodes_played += 1

    # Only reached when a finite max_episodes was requested.
    env_renderer.close_window()
    return 0
def test_global_obs():
    """GlobalObsForRailEnv must encode the rail map and place the agent on track."""
    rail, rail_map = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())

    global_obs, info = env.reset()

    # we have to take step for the agent to enter the grid.
    global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})

    # First component of agent 0's observation: 16 bits per grid cell.
    assert (global_obs[0][0].shape == rail_map.shape + (16, ))

    # Rebuild each cell's integer value from its bit vector (MSB first).
    rail_map_recons = np.zeros_like(rail_map)
    for i in range(global_obs[0][0].shape[0]):
        for j in range(global_obs[0][0].shape[1]):
            rail_map_recons[i, j] = int(
                ''.join(global_obs[0][0][i, j].astype(int).astype(str)), 2)

    # Compare element-wise. ``a.all() == b.all()`` only compared two booleans
    # ("no zero cell anywhere") and passed for almost any pair of maps.
    assert np.array_equal(rail_map_recons, rail_map)

    # If this assertion is wrong, it means that the observation returned
    # places the agent on an empty cell
    obs_agents_state = global_obs[0][1]
    obs_agents_state = obs_agents_state + 1
    assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0)
def render_test(parameters, test_nr=0, nr_examples=5):
    """Load and render the saved levels ``0..nr_examples-1`` of test ``test_nr``.

    ``parameters[0:3]`` are only used in the progress message (printed as
    x_dim, y_dim and number of agents).
    """
    for level in range(nr_examples):
        # Reset the env
        print(
            'Showing {} Level {} with (x_dim,y_dim) = ({},{}) and {} Agents.'.
            format(test_nr, level, parameters[0], parameters[1], parameters[2]))

        level_file = "./Tests/{}/Level_{}.pkl".format(test_nr, level)
        env = RailEnv(
            width=1,
            height=1,
            rail_generator=rail_from_file(level_file),
            obs_builder_object=TreeObsForRailEnv(max_depth=2),
            number_of_agents=1,
        )

        renderer = RenderTool(env, gl="PILSVG")
        renderer.set_new_rail()
        env.reset(False, False)
        renderer.render_env(show=True, show_observations=False)

        time.sleep(0.1)
        renderer.close_window()
def test_normalize_features():
    """Normalized tree observations must match the stored reference CSV arrays.

    Both RNGs are seeded so the sequence of rail-generator seeds, and thus the
    generated envs, is reproducible across runs.
    """
    random.seed(1)
    np.random.seed(1)
    max_depth = 4

    for case_index in range(10):
        tree_observer = TreeObsForRailEnv(max_depth=max_depth)
        rail_seed = random.randint(0, 100)

        env = RailEnv(
            width=10,
            height=10,
            rail_generator=complex_rail_generator(
                nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=rail_seed),
            schedule_generator=complex_schedule_generator(),
            number_of_agents=1,
            obs_builder_object=tree_observer,
        )

        _obs, _rewards, _done, _info = env.step({0: 0})

        raw_tree = tree_observer.get()
        normalized = normalize_observation(raw_tree, max_depth, observation_radius=10)

        reference = np.loadtxt('testdata/test_array_{}.csv'.format(case_index), delimiter=',')
        assert np.allclose(reference, normalized)
def __init__(self, width, height, rail_generator, number_of_agents,
             remove_agents_at_target, obs_builder_object, wait_for_all_done,
             schedule_generator=None, name=None):
    """Wrap a ``RailEnv`` built from the given settings and set up bookkeeping.

    Parameters
    ----------
    width, height, rail_generator, number_of_agents, remove_agents_at_target,
    obs_builder_object
        Forwarded unchanged to ``RailEnv``.
    wait_for_all_done
        Stored as ``self.wait_for_all_done``.
    schedule_generator : optional
        Forwarded to ``RailEnv``. ``None`` (the default) creates a fresh
        ``random_schedule_generator()`` for this instance.
    name : optional
        Stored as ``self.name``.
    """
    super().__init__()
    if schedule_generator is None:
        # Built per instance. A default in the signature is evaluated once at
        # definition time and would be shared by every wrapper.
        schedule_generator = random_schedule_generator()
    self.env = RailEnv(
        width=width,
        height=height,
        rail_generator=rail_generator,
        schedule_generator=schedule_generator,
        number_of_agents=number_of_agents,
        obs_builder_object=obs_builder_object,
        remove_agents_at_target=remove_agents_at_target,
    )

    self.wait_for_all_done = wait_for_all_done
    self.env_renderer = None
    self.agents_done = []
    self.frame_step = 0
    self.name = name
    self.number_of_agents = number_of_agents

    # Track when targets are reached. Only used for correct reward propagation
    # when using wait_for_all_done=True
    self.at_target = {handle: False for handle in range(self.number_of_agents)}
def regenerate(self, method=None, nAgents=0, env=None):
    """Replace the edited env with a new one and refresh the view.

    ``method`` selects the rail generator: ``None``/``"Empty"`` -> empty rail,
    ``"Random Cell"`` -> random cells, anything else -> complex rail seeded from
    the current time. If ``env`` is given it is used as-is instead of building
    a new ``RailEnv``; the generator is still constructed either way.
    """
    self.log("Regenerate size", self.regen_size_width, self.regen_size_height)

    if method in (None, "Empty"):
        rail_gen = empty_rail_generator()
    elif method == "Random Cell":
        rail_gen = random_rail_generator(cell_type_relative_proportion=[1] * 11)
    else:
        rail_gen = complex_rail_generator(nr_start_goal=nAgents,
                                          nr_extra=20,
                                          min_dist=12,
                                          seed=int(time.time()))

    if env is not None:
        self.env = env
    else:
        self.env = RailEnv(width=self.regen_size_width,
                           height=self.regen_size_height,
                           rail_generator=rail_gen,
                           number_of_agents=nAgents,
                           obs_builder_object=TreeObsForRailEnv(max_depth=2))

    self.env.reset(regenerate_rail=True)
    self.fix_env()
    self.set_env(self.env)
    self.view.new_env()
    self.redraw()
def env_create(self, obs_builder_object):
    """
    Create a local env and remote env on which the local agent can operate.
    The observation builder is only used in the local env and the remote
    env uses a DummyObservationBuilder

    Returns
    -------
    (observation, info)
        The locally computed observation and info after resetting the local
        env. If the remote reply carries a falsy observation, evaluation is
        over and the remote ``(observation, info)`` is returned unchanged.

    Raises
    ------
    FileNotFoundError
        If the env file named by the server does not exist under
        ``self.test_envs_root``. It subclasses ``Exception``, so callers that
        caught the previous bare ``Exception`` still work.
    """
    time_start = time.time()
    _request = {
        'type': messages.FLATLAND_RL.ENV_CREATE,
        'payload': {},
    }
    _response = self._remote_request(_request)
    payload = _response['payload']
    observation = payload['observation']
    info = payload['info']
    random_seed = payload['random_seed']
    test_env_file_path = payload['env_file_path']
    time_diff = time.time() - time_start
    self.update_running_mean_stats("env_creation_wait_time", time_diff)

    if not observation:
        # If the observation is False,
        # then the evaluations are complete
        # hence return false
        return observation, info

    if self.verbose:
        print("Received Env : ", test_env_file_path)

    test_env_file_path = os.path.join(
        self.test_envs_root,
        test_env_file_path
    )
    if not os.path.exists(test_env_file_path):
        raise FileNotFoundError(
            "\nWe cannot seem to find the env file paths at the required location.\n"
            "Did you remember to set the AICROWD_TESTS_FOLDER environment variable "
            "to point to the location of the Tests folder ? \n"
            "We are currently looking at `{}` for the tests".format(self.test_envs_root)
        )

    if self.verbose:
        print("Current env path : ", test_env_file_path)
    self.current_env_path = test_env_file_path
    self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path),
                       schedule_generator=schedule_from_file(test_env_file_path),
                       malfunction_generator_and_process_data=malfunction_from_file(test_env_file_path),
                       obs_builder_object=obs_builder_object)

    time_start = time.time()
    local_observation, info = self.env.reset(
        regenerate_rail=True,
        regenerate_schedule=True,
        activate_agents=False,
        random_seed=random_seed
    )
    time_diff = time.time() - time_start
    self.update_running_mean_stats("internal_env_reset_time", time_diff)
    # Use the local observation
    # as the remote server uses a dummy observation builder
    return local_observation, info
def test_get_entry_directions():
    """``get_valid_directions_on_grid`` must report the expected per-direction flags."""
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(
                      max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    # (cell, expected flag for each of the four direction indices)
    cases = [
        ((0, 3), [True, False, False, False]),  # north dead end
        ((3, 0), [False, False, False, True]),  # west dead end
        ((3, 3), [False, True, True, True]),  # switch
        ((3, 2), [False, True, False, True]),  # horizontal
        ((2, 3), [True, False, True, False]),  # vertical
        ((0, 0), [False, False, False, False]),  # nowhere
    ]
    for cell, expected in cases:
        actual = env.get_valid_directions_on_grid(*cell)
        assert actual == expected, "[{},{}] actual={}, expected={}".format(
            *cell, actual, expected)
def test_path_exists(rendering=False):
    """``check_path`` must find a route between connected cells of the simple rail.

    Each call passes (start cell, start direction, target cell, expected
    result). Direction indices: 0 = north, 2 = south, 3 = west.
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
    )
    env.reset()

    check_path(
        env,
        rail,
        (5, 6),  # north of south dead-end
        0,  # north
        (3, 9),  # east dead-end
        True)

    check_path(
        env,
        rail,
        (6, 6),  # south dead-end
        2,  # south
        (3, 9),  # east dead-end
        True)

    check_path(
        env,
        rail,
        (3, 0),  # west dead-end
        3,  # west
        (0, 3),  # north dead-end
        True)

    check_path(
        env,
        rail,
        (5, 6),  # north of south dead-end
        0,  # north
        (1, 3),  # vertical track just south of the north dead-end
        True)

    check_path(
        env,
        rail,
        (1, 3),  # vertical track just south of the north dead-end
        2,  # south
        (3, 3),  # switch
        True)

    check_path(
        env,
        rail,
        (1, 3),  # vertical track just south of the north dead-end
        0,  # north
        (3, 3),  # switch
        True)
def load_flatland_env(env_config: Dict[str, Any]) -> RailEnv:
    """Build a ``RailEnv`` from keyword settings and record its possible agents.

    ``env_config`` is splatted directly into the ``RailEnv`` constructor.
    ``possible_agents`` is a shallow copy of ``env.agents`` taken right after
    construction.
    """
    rail_env = RailEnv(**env_config)
    agents_snapshot = rail_env.agents[:]
    rail_env.possible_agents = agents_snapshot
    return rail_env
def load_env(env_dict, obs_builder_object=None):
    """Build a ``RailEnv`` and restore its full state from ``env_dict``.

    Parameters
    ----------
    env_dict
        Full env state, applied via ``RailEnvPersister.set_full_state``.
    obs_builder_object : optional
        Observation builder for the env. ``None`` (the default) creates a new
        ``GlobalObsForRailEnv`` for this env.

    Returns
    -------
    RailEnv
        The env with the loaded state.
    """
    if obs_builder_object is None:
        # A fresh builder per env. The old default ``GlobalObsForRailEnv()`` in
        # the signature was created once and shared by every env loaded with
        # the default, although a builder is attached to the env it serves.
        obs_builder_object = GlobalObsForRailEnv()
    env = RailEnv(height=4, width=4, obs_builder_object=obs_builder_object)
    env.reset(regenerate_rail=False, regenerate_schedule=False)
    RailEnvPersister.set_full_state(env, env_dict)
    return env
def test_random_rail_generator():
    """random_rail_generator must honour the requested grid size and agent count."""
    n_agents, x_dim, y_dim = 1, 5, 10

    env = RailEnv(
        width=x_dim,
        height=y_dim,
        rail_generator=random_rail_generator(),
        number_of_agents=n_agents,
    )
    env.reset()

    # The grid array is indexed (rows, columns) == (height, width).
    assert env.rail.grid.shape == (y_dim, x_dim)
    assert env.get_num_agents() == n_agents
def test_shortest_path_predictor_conflicts(rendering=False):
    """Two agents heading toward each other must see the predicted conflict.

    Agents are placed by hand on an intentionally invalid simple rail. The env
    is then reset with activation, and each agent's tree observation must
    contain the expected conflict branches.
    """
    rail, rail_map = make_invalid_simple_rail()
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=2,
        obs_builder_object=TreeObsForRailEnv(
            max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
    )
    env.reset()

    def _place(agent, position, direction, target):
        # Put an agent on the grid as an active, moving agent.
        agent.initial_position = position
        agent.position = position
        agent.direction = direction
        agent.initial_direction = direction
        agent.target = target
        agent.moving = True
        agent.status = RailAgentStatus.ACTIVE

    # set the initial position
    _place(env.agents[0],
           (5, 6),  # south dead-end
           0,  # north
           (3, 9))  # east dead-end
    _place(env.agents[1],
           (3, 8),  # east dead-end
           3,  # west
           (6, 6))  # south dead-end

    observations, info = env.reset(False, False, True)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # get the trees to test
    obs_builder: TreeObsForRailEnv = env.obs_builder
    tree_0 = observations[0]
    tree_1 = observations[1]
    env.obs_builder.util_print_obs_subtree(tree_0)
    env.obs_builder.util_print_obs_subtree(tree_1)

    # check the expectations
    expected_conflicts_0 = [('F', 'R')]
    expected_conflicts_1 = [('F', 'L')]
    _check_expected_conflicts(expected_conflicts_0, obs_builder, tree_0, "agent[0]: ")
    _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ")
def test_malfunction_process():
    """A malfunctioning agent must stand still, and the malfunction count is fixed by the seed."""
    # Set fixed malfunction duration for this test
    params = MalfunctionParameters(
        malfunction_rate=1,  # Rate of malfunction occurence
        min_duration=3,  # Minimal duration of malfunction
        max_duration=3,  # Max duration of malfunction
    )
    rail, rail_map = make_simple_rail2()
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        malfunction_generator_and_process_data=malfunction_from_params(params),
        obs_builder_object=SingleAgentNavigationObs(),
    )
    obs, info = env.reset(False, False, True, random_seed=10)

    total_down_time = 0
    previous_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)

    for _step in range(100):
        # Action = argmax of each agent's navigation observation, shifted by one.
        actions = {handle: np.argmax(obs[handle]) + 1 for handle in range(len(obs))}
        obs, all_rewards, done, _ = env.step(actions)

        remaining = env.agents[0].malfunction_data['malfunction']
        if remaining > 0:
            # Check that agent is not moving while malfunctioning
            assert previous_position == env.agents[0].position

        previous_position = env.agents[0].position
        total_down_time += remaining

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 23, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # The agent must have spent some steps malfunctioning.
    assert total_down_time > 0
def test_load_env():
    """After loading a packaged env and adding one agent, the env reports one agent."""
    env = RailEnv(10, 10)
    env.reset()
    env.load_resource('env_data.tests', 'test-10x10.mpk')

    new_agent = EnvAgent((0, 0), 2, (5, 5), False)
    env.add_agent(new_agent)

    assert env.get_num_agents() == 1
def main(args):
    """Run a 100-step random-action demo with a prediction-aware observation builder.

    Parameters
    ----------
    args : list of str
        Command-line arguments. Only ``--sleep-for-animation=<bool>`` is
        recognised; an unknown option prints the getopt error and exits with
        status 2.
    """
    try:
        opts, args = getopt.getopt(args, "", ["sleep-for-animation=", ""])
    except getopt.GetoptError as err:
        print(str(err))  # will print something like "option -a not recognized"
        sys.exit(2)
    sleep_for_animation = True
    for o, a in opts:
        # Use ``==``: ``o in ("--sleep-for-animation")`` was a substring test
        # against a plain string (no tuple), not a membership test.
        if o == "--sleep-for-animation":
            sleep_for_animation = str2bool(a)
        else:
            assert False, "unhandled option"

    # Initiate the Predictor
    custom_predictor = ShortestPathPredictorForRailEnv(10)

    # Pass the Predictor to the observation builder
    custom_obs_builder = ObservePredictions(custom_predictor)

    # Initiate Environment
    env = RailEnv(width=10,
                  height=10,
                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999,
                                                        seed=1),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=3,
                  obs_builder_object=custom_obs_builder)

    obs, info = env.reset()
    env_renderer = RenderTool(env, gl="PILSVG")

    # We render the initial step and show the observed cells as colored boxes
    env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)

    action_dict = {}
    for step in range(100):
        for a in range(env.get_num_agents()):
            action = np.random.randint(0, 5)
            action_dict[a] = action
        obs, all_rewards, done, _ = env.step(action_dict)
        print("Rewards: ", all_rewards, " [done=", done, "]")
        env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
        if sleep_for_animation:
            time.sleep(0.5)
def test_render_env(save_new_images=False):
    """Render a fixed transition map with the PILSVG and PIL backends and
    compare each frame against its frozen reference image.

    ``save_new_images=True`` passes ``resave=True`` to ``checkFrozenImage``.
    """
    np.random.seed(100)
    env = RailEnv(
        width=10,
        height=10,
        rail_generator=empty_rail_generator(),
        number_of_agents=0,
        obs_builder_object=TreeObsForRailEnv(max_depth=2),
    )
    env.reset()
    env.rail.load_transition_map('env_data.tests', "test1.npy")

    # One name is reused so the first renderer is released before the second
    # is created, as before.
    renderer = rt.RenderTool(env, gl="PILSVG")
    renderer.render_env(show=False)
    checkFrozenImage(renderer, "basic-env.npz", resave=save_new_images)

    renderer = rt.RenderTool(env, gl="PIL")
    renderer.render_env()
    checkFrozenImage(renderer, "basic-env-PIL.npz", resave=save_new_images)
def test_empty_rail_generator():
    """empty_rail_generator must yield a correctly sized, all-zero grid and no agents."""
    n_agents = 1
    x_dim, y_dim = 5, 10

    env = RailEnv(
        width=x_dim,
        height=y_dim,
        rail_generator=empty_rail_generator(),
        number_of_agents=n_agents,
    )
    env.reset()

    # The grid array is indexed (rows, columns) == (height, width).
    assert env.rail.grid.shape == (y_dim, x_dim)
    # No rail was laid: every cell is zero.
    assert np.count_nonzero(env.rail.grid) == 0
    # Although one agent was requested, none are placed.
    assert env.get_num_agents() == 0
def decorate_step_method(env: RailEnv) -> None:
    """Make ``env.step`` accept action dicts keyed by agent id.

    Flatland's ``step`` expects agent handles as keys. After this call,
    ``env.step`` converts every key with ``get_agent_handle`` and every action
    value with ``int``, then delegates to the original method. The original
    method stays available as ``env.step_``.

    Calling this function again on an env that is already decorated is a
    no-op. Previously a second call rebound ``step_`` to the wrapper itself,
    so ``step`` recursed forever.
    """
    if getattr(env, "step_", None) is not None:
        return

    env.step_ = env.step

    # Returns whatever the original ``RailEnv.step`` returns (the
    # obs/rewards/dones/info tuple), not a ``dm_env.TimeStep``.
    def _step(self: RailEnv, actions: Dict[str, Union[int, float, Any]]) -> Any:
        actions_ = {get_agent_handle(k): int(v) for k, v in actions.items()}
        return self.step_(actions_)

    env.step = tp.MethodType(_step, env)
def __init__(self, env=None, sGL="PIL", env_filename="temp.pkl"):
    """Assemble the editor's model, view and controller around ``env``.

    If ``env`` is None, a blank 10x10 ``RailEnv`` (empty rail, no agents,
    depth-2 tree observations) is created. The env is then reset and handed to
    a new ``EditorModel`` saved as ``env_filename``; ``sGL`` selects the view's
    graphics backend.
    """
    rail_env = env if env is not None else RailEnv(
        width=10,
        height=10,
        rail_generator=empty_rail_generator(),
        number_of_agents=0,
        obs_builder_object=TreeObsForRailEnv(max_depth=2),
    )
    rail_env.reset()

    self.editor = EditorModel(rail_env, env_filename=env_filename)

    view = View(self.editor, sGL=sGL)
    self.editor.view = view
    self.view = view

    controller = Controller(self.editor, self.view)
    self.view.controller = controller
    self.editor.controller = controller
    self.controller = controller

    self.view.init_canvas()
    # Widgets need the controller, so they are created after it is wired up.
    self.view.init_widgets()
def test_save_load():
    """Saving and reloading an env must preserve its size and agent placement."""
    env = RailEnv(
        width=10,
        height=10,
        rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
        schedule_generator=complex_schedule_generator(),
        number_of_agents=2,
    )
    env.reset()

    def _state(agent):
        # The per-agent fields that must survive a save/load round trip.
        return agent.position, agent.direction, agent.target

    saved = [_state(env.agents[0]), _state(env.agents[1])]

    env.save("test_save.dat")
    env.load("test_save.dat")

    assert env.width == 10
    assert env.height == 10
    assert len(env.agents) == 2
    for handle, (position, direction, target) in enumerate(saved):
        agent = env.agents[handle]
        assert position == agent.position
        assert direction == agent.direction
        assert target == agent.target
def _launch(self):
    """Build and reset a sparse ``RailEnv`` from ``self._config``.

    Malfunctions are enabled only when all three ``malfunction_*`` keys are
    present. ``speed_ratio_map`` keys and values are coerced to ``float``.

    Returns
    -------
    RailEnv or None
        ``None`` if the ``RailEnv`` constructor raised ``ValueError``. If the
        constructor succeeded but ``reset()`` raised ``ValueError``, the
        constructed env is still returned. Errors are logged, not re-raised.
    """
    cfg = self._config
    rail_generator = sparse_rail_generator(
        seed=cfg['seed'],
        max_num_cities=cfg['max_num_cities'],
        grid_mode=cfg['grid_mode'],
        max_rails_between_cities=cfg['max_rails_between_cities'],
        max_rails_in_city=cfg['max_rails_in_city'])

    malfunction_generator = no_malfunction_generator()
    malfunction_keys = ('malfunction_rate', 'malfunction_min_duration', 'malfunction_max_duration')
    if all(key in cfg for key in malfunction_keys):
        stochastic_data = {
            'malfunction_rate': cfg['malfunction_rate'],
            'min_duration': cfg['malfunction_min_duration'],
            'max_duration': cfg['malfunction_max_duration'],
        }
        malfunction_generator = malfunction_from_params(stochastic_data)

    speed_ratio_map = None
    if 'speed_ratio_map' in cfg:
        speed_ratio_map = {float(speed): float(ratio)
                           for speed, ratio in cfg['speed_ratio_map'].items()}
    schedule_generator = sparse_schedule_generator(speed_ratio_map)

    env = None
    try:
        env = RailEnv(
            width=cfg['width'],
            height=cfg['height'],
            rail_generator=rail_generator,
            schedule_generator=schedule_generator,
            number_of_agents=cfg['number_of_agents'],
            malfunction_generator_and_process_data=malfunction_generator,
            obs_builder_object=self._observation.builder(),
            remove_agents_at_target=False,
            random_seed=cfg['seed'])
        env.reset()
    except ValueError as e:
        logging.error("=" * 50)
        logging.error(f"Error while creating env: {e}")
        logging.error("=" * 50)

    return env
def create_rail_env(env_params, tree_observation):
    """Build a sparse ``RailEnv`` with random malfunctions from ``env_params``.

    Malfunction durations are fixed to 20..50 steps; only the rate comes from
    ``env_params.malfunction_rate``. ``env_params.seed`` seeds the ``RailEnv``.
    """
    n_agents, x_dim, y_dim = env_params.n_agents, env_params.x_dim, env_params.y_dim
    n_cities = env_params.n_cities
    rails_between = env_params.max_rails_between_cities
    rails_in_city = env_params.max_rails_in_city
    seed = env_params.seed

    # Break agents from time to time
    malfunction_parameters = MalfunctionParameters(
        malfunction_rate=env_params.malfunction_rate,
        min_duration=20,
        max_duration=50,
    )

    rail_generator = sparse_rail_generator(
        max_num_cities=n_cities,
        grid_mode=False,
        max_rails_between_cities=rails_between,
        max_rails_in_city=rails_in_city,
    )
    return RailEnv(
        width=x_dim,
        height=y_dim,
        rail_generator=rail_generator,
        schedule_generator=sparse_schedule_generator(),
        number_of_agents=n_agents,
        malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters),
        obs_builder_object=tree_observation,
        random_seed=seed,
    )
def train_validate_env_generator(train_set, observation):
    """Pick a random saved test level and return an env for it plus its seed.

    Training draws seeds from [0, 1000), validation from [1000, 2000). The
    level is ``./test-envs/Test_<0..8>/Level_<0..1>.pkl``. Both ``random`` and
    ``np.random`` are re-seeded with the chosen seed before the env is built.
    """
    seed_low, seed_high = (0, 1000) if train_set else (1000, 2000)
    random_seed = np.random.randint(seed_low, seed_high)
    test_env_no = np.random.randint(9)
    level_no = np.random.randint(2)

    random.seed(random_seed)
    np.random.seed(random_seed)

    test_env_file_path = os.path.join(f"./test-envs/Test_{test_env_no}",
                                      f"Level_{level_no}.pkl")
    print(f"Testing Environment: {test_env_file_path} with seed: {random_seed}")

    env = RailEnv(
        width=1,
        height=1,
        rail_generator=rail_from_file(test_env_file_path),
        schedule_generator=schedule_from_file(test_env_file_path),
        malfunction_generator_and_process_data=malfunction_from_file(test_env_file_path),
        obs_builder_object=observation,
    )
    return env, random_seed
def replay_verify(
        ctl: ControllerFromTrainruns,
        env: RailEnv,
        call_back: ControllerFromTrainrunsReplayerRenderCallback = lambda *a, **k: None):
    """Replays this deterministic `ActionPlan` and verifies whether it is feasible.

    Before every step, each agent's position must equal the controller's
    waypoint at or before that step; otherwise an ``AssertionError`` is raised.
    Stops when all agents are done or after ``env._max_episode_steps + 1``
    steps.

    Parameters
    ----------
    ctl
    env
    call_back
        Called once before the first step and after every step() call. The env is passed to it.
    """
    call_back(env)
    step = 0
    while not env.dones['__all__'] and step <= env._max_episode_steps:
        for agent_id, agent in enumerate(env.agents):
            waypoint: Waypoint = ctl.get_waypoint_before_or_at_step(agent_id, step)
            assert agent.position == waypoint.position, \
                "before {}, agent {} at {}, expected {}".format(
                    step, agent_id, agent.position, waypoint.position)
        actions = ctl.act(step)
        print("actions for {}: {}".format(step, actions))

        _obs, _rewards, _done, _info = env.step(actions)
        call_back(env)
        step += 1
def load_flatland_environment_from_file(
        file_name: str,
        load_from_package: str = None,
        obs_builder_object: ObservationBuilder = None) -> RailEnv:
    """Create a ``RailEnv`` whose rail and schedule are read from a pickle file.

    Parameters
    ----------
    file_name : str
        The pickle file.
    load_from_package : str
        The python module to import from. Example: 'env_data.tests'.
        This requires that there are `__init__.py` files in the folder
        structure we load the file from.
    obs_builder_object: ObservationBuilder
        The obs builder for the `RailEnv` that is created. Defaults to a
        depth-2 ``TreeObsForRailEnv`` with a depth-10 shortest-path predictor.

    Returns
    -------
    RailEnv
        The environment loaded from the pickle file. The 1x1 size passed here
        is presumably replaced by the file's dimensions on reset.
    """
    builder = obs_builder_object
    if builder is None:
        builder = TreeObsForRailEnv(
            max_depth=2,
            predictor=ShortestPathPredictorForRailEnv(max_depth=10))

    return RailEnv(
        width=1,
        height=1,
        rail_generator=rail_from_file(file_name, load_from_package),
        schedule_generator=schedule_from_file(file_name, load_from_package),
        number_of_agents=1,
        obs_builder_object=builder,
    )
def fine_tune(config, run, env: RailEnv):
    """
    Fine-tune the agent on a static env at evaluation time

    The env is saved to ``CURRENT_ENV_PATH`` and trained on with Ray Tune for
    a time budget that depends on the number of agents. The agent is then
    restored from the first reported checkpoint of the first trial.
    NOTE: ``config`` is modified in place (workers, envs per worker, lr).
    """
    RailEnvPersister.save(env, CURRENT_ENV_PATH)
    num_agents = env.get_num_agents()
    tune_time = get_tune_time(num_agents)

    def env_creator(env_config):
        return FlatlandSparse(env_config,
                              fine_tune_env_path=CURRENT_ENV_PATH,
                              max_steps=num_agents * 100)

    register_env("flatland_sparse", env_creator)

    config.update(num_workers=3,
                  num_envs_per_worker=1,
                  lr=0.00001 * num_agents)

    analysis = ray.tune.run(
        run["agent"],
        config=config,
        restore=run["checkpoint_path"],
        stop={"time_since_restore": tune_time},
        checkpoint_freq=1,
        keep_checkpoints_num=1,
        checkpoint_score_attr="episode_reward_mean",
        reuse_actors=True,
        verbose=1,
    )

    trial: Trial = analysis.trials[0]
    # The restored agent runs without remote rollout workers.
    trial.config['num_workers'] = 0
    trainable_cls = trial.get_trainable_cls()
    agent = trainable_cls(env=config["env"], config=trial.config)

    checkpoints = analysis.get_trial_checkpoints_paths(trial, metric="episode_reward_mean")
    agent.restore(checkpoints[0][0])
    return agent