예제 #1
0
def test_dead_end():
    """Set up straight 5-cell railways capped by dead ends and place one agent.

    Two layouts are built from the same track pieces: a single row
    (O->-- where > is the train and O the target) and the same track laid out
    as a single column. For each layout a one-agent RailEnv is created. The env
    is then reset once for each of two (direction, target) combinations, with a
    single agent placed in the middle cell.

    NOTE(review): the original remark says the train should be done after
    6 steps, but this test never calls ``step`` and has no assertions. It only
    checks that building, resetting and assigning agents does not raise.
    """
    transitions = RailEnvTransitions()

    # Case 1 - straight piece, plus its 90-degree rotation for the row layout.
    straight_ns = int('1000000000100000', 2)
    straight_ew = transitions.rotate_transition(straight_ns, 90)

    # Case 7 - dead end (rotated below to cap each end of the line).
    dead_end = int('0010000000000000', 2)

    def build_env(grid):
        # Wrap the transition grid in a GridTransitionMap and a one-agent RailEnv.
        height, width = grid.shape
        track = GridTransitionMap(width=width, height=height,
                                  transitions=transitions)
        track.grid = grid
        return RailEnv(width=width, height=height,
                       rail_generator=rail_from_grid_transition_map(track),
                       schedule_generator=random_schedule_generator(),
                       number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())

    # Row layout: agent placed in cell (0, 2).
    west_cap = transitions.rotate_transition(dead_end, 270)
    east_cap = transitions.rotate_transition(dead_end, 90)
    row_grid = np.array([[west_cap] + [straight_ew] * 3 + [east_cap]],
                        dtype=np.uint16)
    env = build_env(row_grid)
    for heading, goal in ((1, (0, 0)), (3, (0, 4))):
        env.reset()
        env.agents = [EnvAgent(initial_position=(0, 2), initial_direction=heading,
                               direction=heading, target=goal, moving=False)]

    # Column layout: same track laid out vertically, agent placed in cell (2, 0).
    south_cap = transitions.rotate_transition(dead_end, 180)
    column_grid = np.array([[dead_end]] + [[straight_ns]] * 3 + [[south_cap]],
                           dtype=np.uint16)
    env = build_env(column_grid)
    for heading, goal in ((2, (0, 0)), (0, (4, 0))):
        env.reset()
        env.agents = [EnvAgent(initial_position=(2, 0), initial_direction=heading,
                               direction=heading, target=goal, moving=False)]
    def generator(rail: GridTransitionMap,
                  num_agents: int,
                  hints: Any = None,
                  num_resets: int = 0) -> Schedule:
        """Build a Schedule from the agents stored in a saved env file.

        ``filename`` and ``load_from_package`` are not parameters. They are read
        from the enclosing scope (presumably the factory returning this
        generator, which is not visible here — confirm). ``rail``, ``num_agents``,
        ``hints`` and ``num_resets`` are accepted to match the generator
        signature but are not used.

        Returns:
            Schedule with per-agent initial positions, directions, targets and
            speeds. Malfunction rates are not taken from the file; they are
            passed as None.
        """
        if load_from_package is not None:
            from importlib_resources import read_binary
            load_data = read_binary(load_from_package, filename)
        else:
            with open(filename, "rb") as file_in:
                load_data = file_in.read()

        # raw=False decodes msgpack strings as UTF-8 str. It replaces
        # encoding='utf-8', which msgpack deprecated in 0.5.x and removed in
        # 1.0 (where passing it raises TypeError).
        data = msgpack.unpackb(load_data, use_list=False, raw=False)
        if "agents_static" in data:
            # Legacy files store agents under "agents_static".
            agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
        else:
            # Only the first 12 stored fields are passed to EnvAgent.
            agents = [EnvAgent(*d[0:12]) for d in data["agents"]]

        return Schedule(agent_positions=[a.initial_position for a in agents],
                        agent_directions=[a.direction for a in agents],
                        agent_targets=[a.target for a in agents],
                        agent_speeds=[a.speed_data['speed'] for a in agents],
                        agent_malfunction_rates=None)
    def load_env_dict(cls, filename, load_from_package=None):
        """Read a saved environment file and return its contents as a dict.

        Args:
            filename: path of the file, or resource name inside
                ``load_from_package``. Its ending selects the decoder:
                "mpk" -> msgpack, "pkl" -> pickle.
            load_from_package: if not None, ``filename`` is read as a package
                resource via ``importlib_resources.read_binary`` instead of
                from disk.

        Returns:
            The decoded dict.
            - If it contains "agents_static" (legacy format), that list is
              converted to EnvAgent objects, stored under "agents", and
              "agents_static" is removed.
            - Otherwise, if it contains "agents", each stored tuple is turned
              into an EnvAgent.
            - For an unrecognised extension a message is printed and an empty
              dict is returned.
        """
        if load_from_package is not None:
            from importlib_resources import read_binary
            load_data = read_binary(load_from_package, filename)
        else:
            with open(filename, "rb") as file_in:
                load_data = file_in.read()

        if filename.endswith("mpk"):
            # raw=False decodes strings as UTF-8 str. It replaces
            # encoding="utf-8", which msgpack removed in 1.0.
            env_dict = msgpack.unpackb(load_data, use_list=False, raw=False)
        elif filename.endswith("pkl"):
            # SECURITY: pickle.loads can execute arbitrary code. Only load
            # .pkl env files that come from a trusted source.
            env_dict = pickle.loads(load_data)
        else:
            print(f"filename {filename} must end with either pkl or mpk")
            env_dict = {}

        # Replace the stored agent tuples with EnvAgent objects.
        if "agents_static" in env_dict:
            env_dict["agents"] = EnvAgent.load_legacy_static_agent(
                env_dict["agents_static"])
            # Remove the legacy key so callers only see "agents".
            del env_dict["agents_static"]
        elif "agents" in env_dict:
            # Only the first 12 stored fields are passed to EnvAgent.
            env_dict["agents"] = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]

        return env_dict
def test_load_env():
    """Load the bundled 10x10 test env into a fresh RailEnv and add one agent.

    NOTE(review): this listing defines a second ``test_load_env`` further down.
    In a single module that later definition shadows this one, so pytest
    would only collect the later one.
    """
    env = RailEnv(width=10, height=10)
    env.reset()
    env.load_resource('env_data.tests', 'test-10x10.mpk')

    # Positional arguments as in the original (older EnvAgent field order).
    new_agent = EnvAgent((0, 0), 2, (5, 5), False)
    env.add_agent(new_agent)

    # Presumably the stored env has no agents, so only the added one counts — verify.
    assert env.get_num_agents() == 1
예제 #5
0
def test_load_env():
    """Load the bundled 10x10 test env via RailEnvPersister and add one agent."""
    env, _env_dict = RailEnvPersister.load_resource("env_data.tests", "test-10x10.mpk")

    # Use keyword arguments. With the newer EnvAgent (fields presumably
    # initial_position, initial_direction, direction, target, moving, ... —
    # confirm against EnvAgent), the old positional call
    # EnvAgent((0, 0), 2, (5, 5), False) would put (5, 5) into `direction`
    # and False into `target`.
    agent = EnvAgent(initial_position=(0, 0), initial_direction=2, direction=2,
                     target=(5, 5), moving=False)
    env.add_agent(agent)

    # Presumably the stored env has no agents, so only the added one counts — verify.
    assert env.get_num_agents() == 1
예제 #6
0
    def click_agent(self, cell_row_col):
        """ The user has clicked on a cell -
            * If there is an agent, select it
              * If that agent was already selected, then deselect it
            * If there is no agent selected, and no agent in the cell, create one
            * If there is an agent selected, and no agent in the cell, move the selected agent to the cell

        Args:
            cell_row_col: the clicked cell. It is converted to a tuple before
                being stored on an agent, so a numpy array or list is accepted.
        """
        # Has the user clicked on an existing agent?
        agent_idx = self.find_agent_at(cell_row_col)

        # Drop a stale selection, e.g. when the env has been recreated with
        # fewer (or no) agents. Valid indices are 0..len(agents)-1, so an index
        # equal to len(agents) is already out of range and must be cleared too.
        if self.selected_agent is not None and self.selected_agent >= len(self.env.agents):
            self.selected_agent = None

        if agent_idx is None:
            # Defensive coding: store cell positions as a tuple, not a numpy
            # array. Numpy arrays break various things when loading the env.
            cell = tuple(cell_row_col)
            if self.selected_agent is None:
                # Empty cell, nothing selected: create a new agent here and select it.
                agent = EnvAgent(initial_position=cell,
                                 initial_direction=0,
                                 direction=0,
                                 target=cell,
                                 moving=False)
                self.selected_agent = self.env.add_agent(agent)
                self.view.oRT.update_background()
            else:
                # Empty cell with an agent selected: move the selected agent here.
                agent = self.env.agents[self.selected_agent]
                agent.initial_position = cell
                agent.position = cell
                agent.old_position = cell
        elif self.selected_agent is not None and agent_idx == self.selected_agent:
            # Clicked on the agent that is already selected: deselect it.
            self.selected_agent = None
        else:
            # Clicked on another agent: select it.
            self.selected_agent = agent_idx

        self.redraw()
예제 #7
0
    def set_full_state_msg(self, msg_data):
        """
        Sets environment state with msgdata object passed as argument

        Replaces the rail grid and the agent list. It then updates height/width
        (on self and on self.rail) from the grid shape and resets ``self.dones``
        to False for every agent and for "__all__".

        Parameters
        ----------
        msg_data: msgpack object
            msgpack-encoded bytes holding "grid" and either "agents" or the
            legacy "agents_static" entry.
        """
        # raw=False decodes strings as UTF-8 str. It replaces encoding='utf-8',
        # which msgpack deprecated in 0.5.x and removed in 1.0.
        data = msgpack.unpackb(msg_data, use_list=False, raw=False)
        self.rail.grid = np.array(data["grid"])
        if "agents_static" in data:
            # Legacy messages store agents under "agents_static"; handled the
            # same way as in deprecated_set_full_state_dist_msg.
            self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
        else:
            # Only the first 12 stored fields are passed to EnvAgent.
            self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
        # Derive dimensions from the loaded grid and keep self.rail in sync.
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
예제 #8
0
파일: editor.py 프로젝트: Zeii2024/RL
    def click_agent(self, cell_row_col):
        """ The user has clicked on a cell -
            * If there is an agent, select it
              * If that agent was already selected, then deselect it
            * If there is no agent selected, and no agent in the cell, create one
            * If there is an agent selected, and no agent in the cell, move the selected agent to the cell

        Args:
            cell_row_col: the clicked cell. It is converted to a tuple before
                being stored on an agent, so a numpy array or list is accepted.
        """
        # Has the user clicked on an existing agent?
        agent_idx = self.find_agent_at(cell_row_col)

        # Drop a selection that no longer points at an agent, e.g. when the env
        # was recreated with fewer agents. Valid indices are 0..len(agents)-1.
        if self.selected_agent is not None and self.selected_agent >= len(self.env.agents):
            self.selected_agent = None

        if agent_idx is None:
            # Store positions as a tuple rather than a numpy array / list.
            cell = tuple(cell_row_col)
            if self.selected_agent is None:
                # Empty cell, nothing selected: create a new agent here and select it.
                agent = EnvAgent(position=cell,
                                 direction=0,
                                 target=cell,
                                 moving=False)
                self.selected_agent = self.env.add_agent(agent)
                self.view.oRT.update_background()
            else:
                # Empty cell with an agent selected: move the selected agent here.
                agent = self.env.agents[self.selected_agent]
                agent.position = cell
                agent.old_position = cell
        elif self.selected_agent is not None and agent_idx == self.selected_agent:
            # Clicked on the agent that is already selected: deselect it.
            self.selected_agent = None
        else:
            # Clicked on another agent: select it.
            self.selected_agent = agent_idx

        self.redraw()
    def deprecated_set_full_state_dist_msg(self, msg_data):
        """
        Sets environment grid state and distance map with msgdata object passed as argument

        Replaces the rail grid and the agent list and loads the distance map if
        the message contains one. It then updates height/width (on self and
        self.rail) from the grid shape and resets ``self.dones`` to False for
        every agent and for "__all__".

        Parameters
        ----------
        msg_data: msgpack object
            msgpack-encoded bytes holding "grid", either "agents" or the
            legacy "agents_static" entry, and optionally "distance_map".
        """
        # raw=False decodes strings as UTF-8 str. It replaces encoding='utf-8',
        # which msgpack deprecated in 0.5.x and removed in 1.0.
        data = msgpack.unpackb(msg_data, use_list=False, raw=False)
        self.rail.grid = np.array(data["grid"])
        if "agents_static" in data:
            # Legacy messages store agents under "agents_static".
            self.agents = EnvAgent.load_legacy_static_agent(
                data["agents_static"])
        else:
            # Only the first 12 stored fields are passed to EnvAgent.
            self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
        if "distance_map" in data:
            self.distance_map.set(data["distance_map"])
        # Derive dimensions from the loaded grid and keep self.rail in sync.
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(
            list(range(self.get_num_agents())) + ["__all__"], False)