Esempio n. 1
0
    def test_check_env_step_incorrect_error(self):
        """check_env must reject each malformed element of step()'s return."""
        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        sampled_obs = env.reset()

        # First element is not a valid observation.
        env.step = MagicMock(return_value=(5, 5, True, {}))
        with pytest.raises(ValueError, match="The element returned by step"):
            check_env(env)

        # Reward is not numeric.
        env.step = MagicMock(
            return_value=(sampled_obs, "Not a reward", True, {}))
        with pytest.raises(
                AssertionError,
                match="Your step function must return a reward "):
            check_env(env)

        # Done flag is not a bool.
        env.step = MagicMock(return_value=(sampled_obs, 5, "Not a bool", {}))
        with pytest.raises(
                AssertionError,
                match="Your step function must return a done"):
            check_env(env)

        # Info is not a dict.
        env.step = MagicMock(
            return_value=(sampled_obs, 5, False, "Not a Dict"))
        with pytest.raises(
                AssertionError,
                match="Your step function must return a info"):
            check_env(env)
Esempio n. 2
0
def make_multiagent(env_name_or_creator):
    """Deprecated alias for `make_multi_agent`; warns, then delegates."""
    old_path = "ray.rllib.examples.env.multi_agent.make_multiagent"
    new_path = "ray.rllib.env.multi_agent_env.make_multi_agent"
    deprecation_warning(old=old_path, new=new_path, error=False)
    return make_multi_agent(env_name_or_creator)
Esempio n. 3
0
 def test_space_in_preferred_format(self):
     """Nested env spaces map agent ids to sub-spaces; plain env's do not."""
     nested_env = NestedMultiAgentEnv()
     assert nested_env._check_if_space_maps_agent_id_to_sub_space(), \
         "Space is not in preferred format"
     plain_env = make_multi_agent("CartPole-v1")()
     assert not plain_env._check_if_space_maps_agent_id_to_sub_space(), \
         "Space should not be in preferred format but is."
Esempio n. 4
0
 def _make_base_env(self):
     """Build a MultiAgentEnvWrapper holding two 2-agent CartPole envs."""
     del self  # instance state is intentionally unused by this helper
     envs = []
     for _ in range(2):
         envs.append(make_multi_agent("CartPole-v1")({"num_agents": 2}))
     return MultiAgentEnvWrapper(None, envs, 2)
Esempio n. 5
0
 def test_bad_sample_function(self):
     """check_env flags space samples that fall outside the declared spaces."""
     # Action 2 is out of range for CartPole's Discrete(2) action space.
     env = make_multi_agent("CartPole-v1")({"num_agents": 2})
     env.action_space_sample = lambda *_: {0: 2, 1: 2}
     with pytest.raises(
             ValueError,
             match="The action collected from action_space_sample"):
         check_env(env)

     # Infinite observations lie outside the Box observation space.
     env = make_multi_agent("CartPole-v1")({"num_agents": 2})
     inf_obs = {agent_id: np.full(4, np.inf) for agent_id in (0, 1)}
     env.observation_space_sample = lambda *_: inf_obs
     with pytest.raises(
             ValueError,
             match="The observation collected from observation_space_sample",
     ):
         check_env(env)
 def test_group_agents_wrapper(self):
     """GroupAgentsWrapper merges per-agent obs into per-group lists."""
     env_cls = make_multi_agent("CartPole-v0")
     grouped = GroupAgentsWrapper(
         env=env_cls({"num_agents": 4}),
         groups={"group1": [0, 1], "group2": [2, 3]},
     )
     obs = grouped.reset()
     # One entry per group; each group holds its two members' observations.
     self.assertTrue(len(obs) == 2)
     self.assertTrue("group1" in obs and "group2" in obs)
     for group in ("group1", "group2"):
         self.assertTrue(isinstance(obs[group], list) and len(obs[group]) == 2)
Esempio n. 7
0
 def test_spaces_sample_contain_not_in_preferred_format(self):
     """Sample/contains round-trip works for spaces not keyed by agent id.

     When spaces are not in the preferred per-agent-id format, the env must
     override observation_space_contains, action_space_contains,
     observation_space_sample, and action_space_sample for checks to work.
     """
     env = make_multi_agent("CartPole-v1")({"num_agents": 2})
     sampled_obs = env.observation_space_sample()
     assert env.observation_space_contains(
         sampled_obs), "Observation space does not contain obs"
     sampled_action = env.action_space_sample()
     assert env.action_space_contains(
         sampled_action), "Action space does not contain action"
Esempio n. 8
0
    def test_check_incorrect_space_contains_functions_error(self):
        """check_env surfaces exceptions raised inside *_contains functions."""

        def bad_contains_function(self, x):
            raise ValueError("This is a bad contains function")

        # Broken observation_space_contains.
        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        env.observation_space_contains = bad_contains_function
        with pytest.raises(
                ValueError,
                match="Your observation_space_contains function has some"):
            check_env(env)

        # Out-of-space action sample (2 is invalid for Discrete(2)).
        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        env.action_space_sample = lambda *_: {0: 2, 1: 2}
        with pytest.raises(
                ValueError,
                match="The action collected from action_space_sample"):
            check_env(env)

        # Broken action_space_contains.
        env.action_space_contains = bad_contains_function
        with pytest.raises(
                ValueError,
                match="Your action_space_contains function has some error"):
            check_env(env)
Esempio n. 9
0
 def test_check_env_reset_incorrect_error(self):
     """check_env must reject malformed reset() return values."""
     env = make_multi_agent("CartPole-v1")({"num_agents": 2})
     # A non-dict is not a valid multi-agent observation.
     env.reset = MagicMock(return_value=5)
     with pytest.raises(ValueError, match="The element returned by reset"):
         check_env(env)
     # Infinite observations lie outside the Box observation space.
     inf_obs = {agent_id: np.full(4, np.inf) for agent_id in (0, 1)}
     env.reset = lambda *_: inf_obs
     with pytest.raises(ValueError,
                        match="The observation collected from env"):
         check_env(env)
Esempio n. 10
0
    def test_check_env_step_incorrect_error(self):
        """check_env must reject malformed multi-agent step() returns."""
        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        sampled_obs = env.reset()

        # First element is not a valid observation dict.
        env.step = MagicMock(return_value=(5, 5, True, {}))
        with pytest.raises(ValueError, match="The element returned by step"):
            check_env(env)

        # A reward value is not numeric.
        env.step = MagicMock(
            return_value=(sampled_obs, {0: "Not a reward"}, {0: True}, {}))
        with pytest.raises(ValueError,
                           match="Your step function must return rewards"):
            check_env(env)

        # A done value is not a bool.
        env.step = MagicMock(
            return_value=(sampled_obs, {0: 5}, {0: "Not a bool"}, {}))
        with pytest.raises(ValueError,
                           match="Your step function must return dones"):
            check_env(env)

        # An info value is not a dict.
        env.step = MagicMock(
            return_value=(sampled_obs, {0: 5}, {0: False}, {0: "Not a Dict"}))
        with pytest.raises(ValueError,
                           match="Your step function must return infos"):
            check_env(env)
Esempio n. 11
0
def make_multiagent(env_name_or_creator):
    """Backward-compatible alias that delegates to `make_multi_agent`."""
    env_cls = make_multi_agent(env_name_or_creator)
    return env_cls
Esempio n. 12
0
        obs_dict = {self.i: self.last_obs[self.i]}
        self.i = (self.i + 1) % self.num
        return obs_dict

    def step(self, action_dict):
        """Step the sub-agents given in `action_dict`, then report only the
        current round-robin agent's transition.

        Args:
            action_dict: Mapping from agent id to that agent's action.

        Returns:
            Tuple of (obs, rew, done, info) dicts, each keyed only by the
            current round-robin agent id ``self.i`` (plus the "__all__" key
            in the done dict).
        """
        # Guard: at least one agent must still be running.
        # NOTE(review): `!=` acts like `<` here assuming `self.dones` only
        # ever holds ids from `self.agents` — confirm that is the intent.
        assert len(self.dones) != len(self.agents)
        # Advance every sub-env an action was provided for, caching results.
        for i, action in action_dict.items():
            (
                self.last_obs[i],
                self.last_rew[i],
                self.last_done[i],
                self.last_info[i],
            ) = self.agents[i].step(action)
        # Emit only the current round-robin agent's cached transition.
        obs = {self.i: self.last_obs[self.i]}
        rew = {self.i: self.last_rew[self.i]}
        done = {self.i: self.last_done[self.i]}
        info = {self.i: self.last_info[self.i]}
        if done[self.i]:
            # Zero the final reward and mark this agent finished.
            rew[self.i] = 0
            self.dones.add(self.i)
        # Advance the round-robin pointer for the next call.
        self.i = (self.i + 1) % self.num
        # The episode ends only when every agent has finished.
        done["__all__"] = len(self.dones) == len(self.agents)
        return obs, rew, done, info


# Pre-built multi-agent versions of common example environments.
MultiAgentCartPole = make_multi_agent("CartPole-v0")
MultiAgentMountainCar = make_multi_agent("MountainCarContinuous-v0")
MultiAgentPendulum = make_multi_agent("Pendulum-v1")
# StatelessCartPole is a custom env class (not a registered gym id), so a
# creator lambda is passed instead of an env-name string.
MultiAgentStatelessCartPole = make_multi_agent(
    lambda config: StatelessCartPole(config))
        Args:
            mode (str): One of "rgb", "human", or "ascii". See gym.Env for
                more information.

        Returns:
            Union[np.ndarray, bool]: An image to render or True (if rendering
                is handled entirely in here).
        """

        # Just generate a random image here for demonstration purposes.
        # Also see `gym/envs/classic_control/cartpole.py` for
        # an example on how to use a Viewer object.
        return np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)


# Multi-agent version of the custom-rendered env, built via a creator lambda.
MultiAgentCustomRenderedEnv = make_multi_agent(
    lambda config: CustomRenderedEnv(config))

if __name__ == "__main__":
    # Note: Recording and rendering in this example
    # should work for both local_mode=True|False.
    ray.init(num_cpus=4)
    args = parser.parse_args()

    # Example config causing
    config = {
        # Also try common gym envs like: "CartPole-v0" or "Pendulum-v0".
        "env": (MultiAgentCustomRenderedEnv
                if args.multi_agent else CustomRenderedEnv),
        "env_config": {
            "corridor_length": 10,
            "max_steps": 100,