Пример #1
0
    def test_wrapper_env_reset(self, env_spec: EnvSpec,
                               helpers: Helpers) -> None:
        """Check that reset() on a wrapped env returns a valid first timestep.

        The test builds the wrapped environment via ``helpers``, resets it,
        and asserts that the returned timestep and its observation are set,
        that the step type is ``FIRST``, that there is one observation per
        agent, and that ``_reset_next_step`` is ``False`` afterwards.

        Args:
            env_spec: Specification of the environment under test. Its
                ``env_name``, ``env_source`` and ``env_type`` decide whether
                the per-agent observation check is skipped.
            helpers: Test helper object used to build the wrapped environment
                and to run the shared assertions.
        """
        wrapped_env, _ = helpers.get_wrapped_env(env_spec)
        num_agents = len(wrapped_env.agents)

        timestep = wrapped_env.reset()
        # reset() may hand back either a bare timestep or a plain
        # ``(timestep, extras)`` pair. The exact-type check is deliberate.
        # NOTE(review): the timestep is presumably a dm_env.TimeStep, which is
        # itself a tuple subclass, so isinstance() would wrongly unpack it.
        if type(timestep) is tuple:
            # Extras are not inspected by this test.
            dm_env_timestep, _ = timestep
        else:
            dm_env_timestep = timestep
        props_which_should_not_be_none = [
            dm_env_timestep, dm_env_timestep.observation
        ]

        assert helpers.verify_all_props_not_none(
            props_which_should_not_be_none), "Failed to ini dm_env_timestep."
        assert (dm_env_timestep.step_type == dm_env.StepType.FIRST
                ), "Failed to have correct StepType."
        # The remaining checks are skipped for sequential OpenSpiel
        # tic_tac_toe; the checks above still run for it.
        if (env_spec.env_name == "tic_tac_toe"
                and env_spec.env_source == EnvSource.OpenSpiel
                and env_spec.env_type == EnvType.Sequential):
            pytest.skip(
                "This test is only applicable to parallel wrappers and only works "
                "for the provided PZ sequential envs because they have 3 agents, and "
                "an OLT has length of 3 (a bug, i'd say)")
        assert (len(dm_env_timestep.observation) == num_agents
                ), "Failed to generate observation for all agents."
        assert wrapped_env._reset_next_step is False, "_reset_next_step not set."

        helpers.assert_env_reset(wrapped_env, dm_env_timestep, env_spec)
Пример #2
0
    def test_step_2_invalid_when_env_done(self, env_spec: EnvSpec,
                                          helpers: Helpers,
                                          monkeypatch: MonkeyPatch) -> None:
        """Step a wrapped env with ``env_done`` mocked and expect a LAST step.

        For parallel environments a single joint action is stepped after
        ``env_done`` has been replaced with ``helpers.mock_done``. For
        sequential environments the underlying environment's ``last`` and
        ``step`` are replaced with local mocks, every agent is stepped once,
        and ``env_done`` is replaced only when the final agent is reached.
        In both cases the test ends by asserting that ``_reset_next_step`` is
        ``True`` and that the last timestep has step type ``LAST``.

        Args:
            env_spec: Specification of the environment under test. OpenSpiel
                sources are skipped.
            helpers: Test helper object providing env construction, seeding,
                ``mock_done`` and the shared reset assertion.
            monkeypatch: pytest fixture used to install the mocks.
        """
        wrapped_env, _ = helpers.get_wrapped_env(env_spec)

        if env_spec.env_source == EnvSource.OpenSpiel:
            pytest.skip("Open Spiel does not use the .last() method")

        # Actions are sampled below, so both the environment and its action
        # spaces are seeded first.
        seed = 42
        wrapped_env.seed(seed)
        helpers.seed_action_space(wrapped_env, seed)

        # Reset before reading the agent names from the env.
        wrapped_env.reset()
        agents = wrapped_env.agents

        if env_spec.env_type == EnvType.Parallel:
            # Parallel env types: one sampled action per agent, one step.
            joint_action = {}
            for agent_name in agents:
                joint_action[agent_name] = (
                    wrapped_env.action_spaces[agent_name].sample())

            monkeypatch.setattr(wrapped_env, "env_done", helpers.mock_done)

            curr_dm_timestep = wrapped_env.step(joint_action)

            helpers.assert_env_reset(wrapped_env, curr_dm_timestep, env_spec)

        # TODO (Kale-ab): Make this part below less reliant on PZ.
        elif env_spec.env_type == EnvType.Sequential:
            # Sequential env types: agents are stepped one at a time.
            agent_count = wrapped_env.num_agents

            def mock_environment_last() -> Any:
                # Stand-in for the PZ environment's ``last`` that always
                # reports done. ``acting_agent`` is looked up at call time,
                # so it refers to the agent currently bound by the loop below.
                return (
                    wrapped_env.observation_spaces[acting_agent].sample(),
                    0.0,
                    True,
                    {},
                )

            def mock_step(action: types.Action) -> None:
                # Stand-in for the PZ environment's ``step``; does nothing.
                return None

            # raising=False: no error is thrown if the attribute being
            # replaced does not exist on the underlying environment.
            for attr_name, replacement in (("last", mock_environment_last),
                                           ("step", mock_step)):
                monkeypatch.setattr(wrapped_env._environment,
                                    attr_name,
                                    replacement,
                                    raising=False)

            final_index = agent_count - 1
            for step_index, acting_agent in enumerate(
                    wrapped_env.agent_iter(agent_count)):
                sampled_action = (
                    wrapped_env.action_spaces[acting_agent].sample())

                # The whole env is only mocked as done for the final agent.
                if step_index == final_index:
                    monkeypatch.setattr(wrapped_env, "env_done",
                                        helpers.mock_done)

                # Mock that an update has occurred in step.
                monkeypatch.setattr(wrapped_env._environment,
                                    "_has_updated",
                                    True,
                                    raising=False)

                curr_dm_timestep = wrapped_env.step(sampled_action)

                # Every agent should already be on its last step.
                assert (curr_dm_timestep.step_type is
                        dm_env.StepType.LAST), "Failed to update step type."

            helpers.assert_env_reset(wrapped_env, curr_dm_timestep, env_spec)

        assert wrapped_env._reset_next_step is True, (
            "Failed to set _reset_next_step correctly.")
        assert curr_dm_timestep.step_type is dm_env.StepType.LAST, (
            "Failed to update step type.")