def load_env_dict(cls, filename, load_from_package=None):
        """Load a serialized environment dict from a file or package resource.

        The decoder is chosen from the end of ``filename``: names ending in
        ``mpk`` are decoded with msgpack, names ending in ``pkl`` with
        pickle. Any other name prints a message and yields an empty dict.

        Parameters
        ----------
        filename : str
            Path of the file to read, or the resource name inside
            ``load_from_package`` when that is given.
        load_from_package : optional
            When not None, ``filename`` is read as a resource of this
            package via ``importlib_resources.read_binary`` instead of from
            the filesystem.

        Returns
        -------
        dict
            The decoded environment dict. A legacy ``"agents_static"``
            entry is converted with ``EnvAgent.load_legacy_static_agent``,
            stored under ``"agents"``, and removed. Otherwise an existing
            ``"agents"`` entry is rebuilt as ``EnvAgent`` objects from the
            first 12 fields of each record.
        """
        if load_from_package is not None:
            from importlib_resources import read_binary
            load_data = read_binary(load_from_package, filename)
        else:
            with open(filename, "rb") as file_in:
                load_data = file_in.read()

        if filename.endswith("mpk"):
            # raw=False decodes msgpack strings as UTF-8 ``str``. It replaces
            # the ``encoding="utf-8"`` argument, which msgpack 1.0 removed
            # (passing it there raises TypeError). use_list=False makes
            # msgpack arrays decode as tuples.
            env_dict = msgpack.unpackb(load_data,
                                       use_list=False,
                                       raw=False)
        elif filename.endswith("pkl"):
            # SECURITY: pickle.loads can execute arbitrary code while
            # unpickling; only load .pkl files from trusted sources.
            env_dict = pickle.loads(load_data)
        else:
            print(f"filename {filename} must end with either pkl or mpk")
            env_dict = {}

        # Replace the agents tuple with EnvAgent objects
        if "agents_static" in env_dict:
            env_dict["agents"] = EnvAgent.load_legacy_static_agent(
                env_dict["agents_static"])
            # remove the legacy key
            del env_dict["agents_static"]
        elif "agents" in env_dict:
            # Only the first 12 fields of each stored agent record are
            # passed to the EnvAgent constructor.
            env_dict["agents"] = [
                EnvAgent(*d[0:12]) for d in env_dict["agents"]
            ]

        return env_dict
    def generator(rail: GridTransitionMap,
                  num_agents: int,
                  hints: Any = None,
                  num_resets: int = 0) -> Schedule:
        """Build a Schedule from the agents stored in a msgpack env file.

        ``filename`` and ``load_from_package`` come from the enclosing
        scope. The ``rail``, ``num_agents``, ``hints`` and ``num_resets``
        arguments are not used in this body. Presumably they are accepted
        to match the schedule-generator call signature (TODO confirm
        against callers).

        Returns
        -------
        Schedule
            Positions (from ``initial_position``), directions, targets and
            speeds of the stored agents. Malfunction rates are passed as
            None.
        """
        if load_from_package is not None:
            from importlib_resources import read_binary
            load_data = read_binary(load_from_package, filename)
        else:
            with open(filename, "rb") as file_in:
                load_data = file_in.read()
        # raw=False decodes strings as UTF-8 and replaces encoding='utf-8',
        # which msgpack 1.0 removed (it raises TypeError there).
        data = msgpack.unpackb(load_data, use_list=False, raw=False)
        if "agents_static" in data:
            agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
        else:
            agents = [EnvAgent(*d[0:12]) for d in data["agents"]]

        # setup with loaded data
        agents_position = [a.initial_position for a in agents]
        # NOTE(review): positions come from ``initial_position`` but
        # directions come from ``direction``; confirm the starting direction
        # was not intended here.
        agents_direction = [a.direction for a in agents]
        agents_target = [a.target for a in agents]
        agents_speed = [a.speed_data['speed'] for a in agents]

        # NOTE(review): any malfunction rates stored with the agents are not
        # forwarded (agent_malfunction_rates=None); confirm this is intended.
        return Schedule(agent_positions=agents_position,
                        agent_directions=agents_direction,
                        agent_targets=agents_target,
                        agent_speeds=agents_speed,
                        agent_malfunction_rates=None)
Example no. 3
0
    def set_full_state_msg(self, msg_data):
        """
        Sets environment state from a msgpack-encoded message.

        Parameters
        ----------
        msg_data : bytes-like
            msgpack payload holding a ``"grid"`` entry and either a legacy
            ``"agents_static"`` entry or an ``"agents"`` entry.

        Notes
        -----
        Overwrites ``self.rail.grid``, ``self.agents``, ``self.height``,
        ``self.width``, ``self.rail.height``, ``self.rail.width`` and
        ``self.dones``. Every agent handle and ``"__all__"`` in ``dones``
        starts as False.
        """
        # raw=False decodes strings as UTF-8 and replaces encoding='utf-8',
        # which msgpack 1.0 removed (it raises TypeError there).
        data = msgpack.unpackb(msg_data, use_list=False, raw=False)
        self.rail.grid = np.array(data["grid"])
        # agents are always reset as not moving
        if "agents_static" in data:
            self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
        else:
            # Only the first 12 fields of each stored agent record are used.
            self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)