Пример #1
0
    def test_check_env_step_incorrect_error(self):
        """Verify check_env rejects malformed ``step`` return tuples.

        ``env.step`` is replaced by mocks whose return value has one bad
        element at a time (observation, reward, done, info), and each case
        asserts the exception type and message fragment check_env raises.
        """
        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        # Capture reset()'s observation before step is mocked out; the later
        # cases reuse it as the observation element of the step result.
        sampled_obs = env.reset()

        # Observation element is a bare int.
        env.step = MagicMock(return_value=(5, 5, True, {}))
        with pytest.raises(ValueError, match="The element returned by step"):
            check_env(env)

        # (mocked step result, message fragment expected from check_env)
        assertion_cases = [
            ((sampled_obs, "Not a reward", True, {}),
             "Your step function must return a reward "),
            ((sampled_obs, 5, "Not a bool", {}),
             "Your step function must return a done"),
            ((sampled_obs, 5, False, "Not a Dict"),
             "Your step function must return a info"),
        ]
        for step_result, fragment in assertion_cases:
            env.step = MagicMock(return_value=step_result)
            with pytest.raises(AssertionError, match=fragment):
                check_env(env)
Пример #2
0
 def test_reset(self):
     """Verify check_env raises ValueError for bad ``reset`` observations.

     ``env.reset`` is mocked to return, in turn, an int and a float; both
     must produce the same "observation collected from env.reset" error.
     """
     env = RandomEnv()
     expected = ".*The observation collected from env.reset().*"
     bogus_observations = (
         5,         # out-of-bounds observation
         float(1),  # observation of incorrect type
     )
     for bogus_obs in bogus_observations:
         env.reset = MagicMock(return_value=bogus_obs)
         with pytest.raises(ValueError, match=expected):
             check_env(env)
     del env
Пример #3
0
 def test_check_env_reset_incorrect_error(self):
     """Verify check_env raises ValueError for bad ``reset`` results.

     Two cases are exercised on a two-agent CartPole env: ``reset``
     returning a bare int, and ``reset`` returning a per-agent dict whose
     observations are all ``inf``.
     """
     env = make_multi_agent("CartPole-v1")({"num_agents": 2})

     # reset() returns an int instead of a dict.
     env.reset = MagicMock(return_value=5)
     with pytest.raises(ValueError, match="The element returned by reset"):
         check_env(env)

     # reset() returns a dict keyed 0 and 1 with all-inf observations.
     infinite_obs = {
         agent_id: np.array([np.inf] * 4) for agent_id in (0, 1)
     }
     env.reset = lambda *_: infinite_obs
     with pytest.raises(
             ValueError, match="The observation collected from env"):
         check_env(env)
Пример #4
0
    def test_check_space_contains_functions_errors(self):
        """Verify check_env reports failing ``*_space_contains`` hooks.

        A fresh base env is built for each case and one of its
        ``observation_space_contains`` / ``action_space_contains`` attributes
        is replaced with a function that always raises. check_env is expected
        to surface this as a ValueError naming the offending hook.
        """

        def bad_contains_function(*_):
            # This function is assigned as an *instance* attribute below, so
            # Python does not bind it: the caller's arguments arrive as-is
            # with no implicit ``self``. Accepting any arity guarantees the
            # ValueError below is what check_env sees, rather than a
            # TypeError about a missing positional argument.
            raise ValueError("This is a bad contains function")

        env = self._make_base_env()

        env.observation_space_contains = bad_contains_function
        with pytest.raises(
                ValueError,
                match="Your observation_space_contains function has some"):
            check_env(env)

        # Start from a clean env so only the action hook is broken.
        env = self._make_base_env()
        env.action_space_contains = bad_contains_function
        with pytest.raises(
                ValueError,
                match="Your action_space_contains function has some error"):
            check_env(env)
Пример #5
0
 def test_bad_sample_function(self):
     """Verify check_env raises ValueError for bad ``*_space_sample`` hooks.

     On a two-agent CartPole env, first ``action_space_sample`` and then
     (on a fresh env) ``observation_space_sample`` is replaced with a
     lambda returning a fixed dict, and check_env must raise for each.
     """
     action_env = make_multi_agent("CartPole-v1")({"num_agents": 2})
     invalid_actions = {0: 2, 1: 2}
     action_env.action_space_sample = lambda *_: invalid_actions
     with pytest.raises(
             ValueError,
             match="The action collected from action_space_sample"):
         check_env(action_env)

     obs_env = make_multi_agent("CartPole-v1")({"num_agents": 2})
     # Every agent's observation is all-inf.
     invalid_obs = {
         agent_id: np.array([np.inf] * 4) for agent_id in (0, 1)
     }
     obs_env.observation_space_sample = lambda *_: invalid_obs
     with pytest.raises(
             ValueError,
             match="The observation collected from observation_space_sample",
     ):
         check_env(obs_env)
Пример #6
0
    def test_bad_sample_function(self):
        """Verify check_env raises ValueError for bad ``*_space_sample`` hooks.

        Each case builds a fresh base env via ``_make_base_env`` and replaces
        one sample hook with a lambda returning a fixed nested dict (outer key
        0, inner keys 0 and 1), then expects check_env to raise.
        """
        action_env = self._make_base_env()
        invalid_actions = {0: {0: 2, 1: 2}}
        action_env.action_space_sample = lambda *_: invalid_actions
        with pytest.raises(
                ValueError,
                match="The action collected from action_space_sample"):
            check_env(action_env)

        obs_env = self._make_base_env()
        # All-inf observations for both inner keys.
        invalid_obs = {
            0: {agent_id: np.array([np.inf] * 4) for agent_id in (0, 1)},
        }
        obs_env.observation_space_sample = lambda *_: invalid_obs
        with pytest.raises(
                ValueError,
                match="The observation collected from observation_space_sample",
        ):
            check_env(obs_env)
Пример #7
0
    def test_check_env_step_incorrect_error(self):
        """Verify check_env rejects malformed ``poll`` results on a base env.

        ``env.poll`` is mocked to return a 5-tuple in which exactly one of
        observation / reward / done / info is malformed, and each case
        asserts the exception type and message check_env raises.
        """
        good_reward = {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}
        good_done = {0: {0: False, 1: False}, 1: {0: False, 1: False}}
        good_info = {0: {0: {}, 1: {}}, 1: {0: {}, 1: {}}}

        env = self._make_base_env()

        def expect_rejection(poll_result, exc_type, message):
            # Install a mocked poll() and assert check_env fails as expected.
            env.poll = MagicMock(return_value=poll_result)
            with pytest.raises(exc_type, match=message):
                check_env(env)

        # Observation: the value under outer key 0 is an int, not a dict.
        malformed_obs = {0: 1, 1: {0: np.zeros(4)}}
        expect_rejection(
            (malformed_obs, good_reward, good_done, good_info, {}),
            ValueError,
            "The element returned by step, next_obs has values that are not"
            " MultiAgentDicts",
        )

        # Sampled only after the first check, matching the original ordering.
        good_obs = env.observation_space_sample()

        # Reward: one entry is a string.
        malformed_reward = {0: {0: "not_reward", 1: 1}}
        expect_rejection(
            (good_obs, malformed_reward, good_done, good_info, {}),
            AssertionError,
            "Your step function must return a rewards that are",
        )

        # Done: one entry is a string.
        malformed_done = {0: {0: "not_done", 1: False}}
        expect_rejection(
            (good_obs, good_reward, malformed_done, good_info, {}),
            AssertionError,
            "Your step function must return a done that is boolean.",
        )

        # Info: one entry is a string rather than a dict.
        malformed_info = {0: {0: "not_info", 1: {}}}
        expect_rejection(
            (good_obs, good_reward, good_done, malformed_info, {}),
            AssertionError,
            "Your step function must return a info that is a dict.",
        )
Пример #8
0
    def test_step(self):
        """Verify check_env rejects malformed ``step`` results on RandomEnv.

        ``env.step`` is mocked with a sequence of 4-tuples, each with one
        problematic element, and every case asserts the exception type and
        message pattern raised by check_env.
        """
        env = RandomEnv()

        obs_error = ".*The observation collected from env.step.*"
        reward_error = ("Your step function must return a reward that is "
                        "integer or float.")
        done_error = "Your step function must return a done that is a boolean."
        info_error = "Your step function must return a info that is a dict."

        # (mocked step result, expected exception type, expected pattern)
        cases = [
            # int observation 5
            ((5, 5, True, {}), ValueError, obs_error),
            # float observation from step
            ((float(1), 5, True, {}), ValueError, obs_error),
            # reward that is neither float nor int
            ((1, "Not a valid reward", True, {}),
             AssertionError, reward_error),
            # done flag that is not a bool
            ((1, float(5), "not a valid done signal", {}),
             AssertionError, done_error),
            # info that is not a dict
            ((1, float(5), True, "not a valid env info"),
             AssertionError, info_error),
        ]
        for step_result, exc_type, pattern in cases:
            env.step = MagicMock(return_value=step_result)
            with pytest.raises(exc_type, match=pattern):
                check_env(env)
        del env
Пример #9
0
    def test_check_env_reset_incorrect_error(self):
        """Verify check_env rejects malformed ``try_reset`` results.

        ``env.try_reset`` on a base env is replaced four times: with an int,
        with a dict keyed by a string env id, with a dict whose inner keys
        are 2 and 1, and with all-inf observations. Each case must raise a
        ValueError with the corresponding message fragment.
        """
        env = self._make_base_env()

        # Case 1: not a dict at all.
        env.try_reset = MagicMock(return_value=5)
        with pytest.raises(
                ValueError, match="MultiEnvDict. Instead, it is of type"):
            check_env(env)

        # Inner dict keyed 2 and 1, reused by cases 2 and 3.
        wrong_agent_ids = {
            agent_id: np.array([np.inf] * 4) for agent_id in (2, 1)
        }

        # Case 2: outer key is a string env id.
        env.try_reset = MagicMock(
            return_value={"bad_env_id": wrong_agent_ids})
        with pytest.raises(
                ValueError, match="has dict keys that don't correspond to"):
            check_env(env)

        # Case 3: outer key 0, but inner keys are 2 and 1.
        env.try_reset = MagicMock(return_value={0: wrong_agent_ids})
        with pytest.raises(
                ValueError,
                match="The element returned by try_reset has agent_ids that "
                "are not the names of the agents",
        ):
            check_env(env)

        # Case 4: keys 0 -> {0, 1}, but every observation is all-inf.
        all_inf_obs = {
            0: {agent_id: np.array([np.inf] * 4) for agent_id in (0, 1)},
        }
        env.try_reset = lambda *_: all_inf_obs
        with pytest.raises(
                ValueError, match="The observation collected from try_reset"):
            check_env(env)
        del env
Пример #10
0
    def test_check_env_step_incorrect_error(self):
        """Verify check_env raises ValueError for malformed ``step`` results.

        ``env.step`` on a two-agent CartPole env is mocked with tuples whose
        observation, rewards, dones or infos element is malformed; each case
        asserts the message fragment carried by the ValueError.
        """
        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        # Observation from reset(), taken before step is mocked out, reused
        # as the observation element in the later cases.
        sampled_obs = env.reset()

        # (mocked step result, message fragment expected from check_env)
        cases = [
            # Observation element is a bare int.
            ((5, 5, True, {}),
             "The element returned by step"),
            # Reward value is a string.
            ((sampled_obs, {0: "Not a reward"}, {0: True}, {}),
             "Your step function must return rewards"),
            # Done value is a string.
            ((sampled_obs, {0: 5}, {0: "Not a bool"}, {}),
             "Your step function must return dones"),
            # Info value is a string rather than a dict.
            ((sampled_obs, {0: 5}, {0: False}, {0: "Not a Dict"}),
             "Your step function must return infos"),
        ]
        for step_result, fragment in cases:
            env.step = MagicMock(return_value=step_result)
            with pytest.raises(ValueError, match=fragment):
                check_env(env)
Пример #11
0
    def test_check_incorrect_space_contains_functions_error(self):
        """Verify check_env reports failing space hooks on a multi-agent env.

        Three cases on a two-agent CartPole env:
        1. ``observation_space_contains`` always raises.
        2. ``action_space_sample`` returns a fixed action dict ``{0: 2, 1: 2}``.
        3. On the same env as case 2, ``action_space_contains`` always raises.
        Each case must make check_env raise a ValueError with the matching
        message fragment.
        """

        def bad_contains_function(*_):
            # This function is assigned as an *instance* attribute below, so
            # Python does not bind it: the caller's arguments arrive as-is
            # with no implicit ``self``. Accepting any arity guarantees the
            # ValueError below is what check_env sees, rather than a
            # TypeError about a missing positional argument.
            raise ValueError("This is a bad contains function")

        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        env.observation_space_contains = bad_contains_function
        with pytest.raises(
                ValueError,
                match="Your observation_space_contains function has some"):
            check_env(env)

        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        bad_action = {0: 2, 1: 2}
        env.action_space_sample = lambda *_: bad_action
        with pytest.raises(
                ValueError,
                match="The action collected from action_space_sample"):
            check_env(env)

        # NOTE: the env is intentionally not rebuilt here, so the overridden
        # action_space_sample from the previous case is still installed.
        env.action_space_contains = bad_contains_function
        with pytest.raises(
                ValueError,
                match="Your action_space_contains function has some error"):
            check_env(env)
Пример #12
0
 def test_check_correct_env(self):
     """Run check_env on two envs and expect no exception from either.

     The first is whatever ``_make_base_env`` builds; the second is a gym
     "CartPole-v0" env passed through ``convert_to_base_env``.
     """
     check_env(self._make_base_env())

     cartpole = gym.make("CartPole-v0")
     check_env(convert_to_base_env(cartpole))
Пример #13
0
 def test_check_correct_env(self):
     """Run check_env on the env from ``_make_base_env``; expect no error."""
     base_env = self._make_base_env()
     check_env(base_env)