Пример #1
0
    def test_step_1_valid_when_env_not_done(self, env_spec: EnvSpec,
                                            helpers: Helpers) -> None:
        """Check that ``step`` is invoked once with the sampled action(s).

        ``step`` on the wrapped environment is replaced by a mock, so this
        test only verifies the call arguments: a dict of per-agent actions
        for parallel environments, or a single agent's action (one call per
        agent) for sequential environments.

        Args:
            env_spec: specification of the environment under test.
            helpers: provides ``get_wrapped_env`` and ``seed_action_space``.
        """
        env, _ = helpers.get_wrapped_env(env_spec)

        # Actions are sampled below, so the environment and its action
        # spaces are both seeded with the same value.
        seed = 42
        env.seed(seed)
        helpers.seed_action_space(env, seed)

        # Agent names as reported by the wrapped environment.
        agent_names = env.agents

        if env_spec.env_type == EnvType.Parallel:
            # Parallel: all agents act at once through one joint action dict.
            joint_action = {
                name: env.action_spaces[name].sample()
                for name in agent_names
            }
            with patch.object(
                    env, "step",
                    return_value=(None, None, None, None)) as step_mock:
                env.step(joint_action)
                step_mock.assert_called_once_with(joint_action)

        elif env_spec.env_type == EnvType.Sequential:
            # Sequential: agents act one at a time; a fresh mock per agent
            # keeps the "called once" assertion valid on every iteration.
            for name in agent_names:
                with patch.object(env, "step",
                                  return_value=None) as step_mock:
                    single_action = env.action_spaces[name].sample()
                    env.step(single_action)
                    step_mock.assert_called_once_with(single_action)
Пример #2
0
    def test_step_0_valid_when_env_not_done(self, env_spec: EnvSpec,
                                            helpers: Helpers) -> None:
        """Check that stepping a freshly reset env gives a valid MID timestep.

        After ``reset`` the environment is stepped with sampled actions (one
        joint step for parallel environments, one step per agent for
        sequential ones). The test asserts that observations differ from the
        initial ones, that ``_reset_next_step`` is ``False``, that a reward
        is set and that the step type is ``StepType.MID``.

        Args:
            env_spec: specification of the environment under test.
            helpers: provides ``get_wrapped_env`` and ``seed_action_space``.
        """
        wrapped_env, _ = helpers.get_wrapped_env(env_spec)

        # Seed environment since we are sampling actions.
        # We need to seed env and action space.
        random_seed = 42
        wrapped_env.seed(random_seed)
        helpers.seed_action_space(wrapped_env, random_seed)

        #  Get agent names from env
        agents = wrapped_env.agents

        # ``reset`` returns either a bare timestep or a two-element
        # ``(timestep, extras)`` tuple; the extras are not needed here.
        # The check is deliberately on the exact type: presumably the
        # timestep itself is a tuple subclass (e.g. a named tuple), in which
        # case ``isinstance`` would wrongly try to unpack it — confirm.
        timestep = wrapped_env.reset()
        if type(timestep) is tuple:
            initial_dm_env_timestep, _ = timestep
        else:
            initial_dm_env_timestep = timestep

        # Parallel env_types
        if env_spec.env_type == EnvType.Parallel:
            test_agents_actions = {
                agent: wrapped_env.action_spaces[agent].sample()
                for agent in agents
            }
            curr_dm_timestep = wrapped_env.step(test_agents_actions)

            for agent in wrapped_env.agents:
                assert not np.array_equal(
                    initial_dm_env_timestep.observation[agent].observation,
                    curr_dm_timestep.observation[agent].observation,
                ), "Failed to update observations."

        # Sequential env_types
        elif env_spec.env_type == EnvType.Sequential:
            curr_dm_timestep = initial_dm_env_timestep
            for agent in agents:
                if env_spec.env_source == EnvSource.OpenSpiel:
                    # Pick one of the actions flagged as legal in the
                    # current observation.
                    # NOTE(review): this draws from NumPy's global RNG, which
                    # is not seeded explicitly here — confirm the seeding
                    # calls above cover it, otherwise the choice is random.
                    legal_actions = np.where(
                        curr_dm_timestep.observation.legal_actions)[0]
                    test_agent_actions = np.random.choice(legal_actions)
                else:
                    test_agent_actions = wrapped_env.action_spaces[
                        agent].sample()

                curr_dm_timestep = wrapped_env.step(test_agent_actions)

                assert not np.array_equal(
                    initial_dm_env_timestep.observation.observation,
                    curr_dm_timestep.observation.observation,
                ), "Failed to update observations."

        else:
            # Fail with a clear message instead of a NameError on the
            # unbound ``curr_dm_timestep`` in the assertions below.
            raise AssertionError("Unsupported env type: {}".format(
                env_spec.env_type))

        assert (wrapped_env._reset_next_step is
                False), "Failed to set _reset_next_step correctly."
        assert curr_dm_timestep.reward is not None, "Failed to set rewards."
        assert (curr_dm_timestep.step_type is
                dm_env.StepType.MID), "Failed to update step type."
Пример #3
0
    def test_step_2_invalid_when_env_done(self, env_spec: EnvSpec,
                                          helpers: Helpers,
                                          monkeypatch: MonkeyPatch) -> None:
        """Check that stepping a finished environment yields a LAST timestep.

        The environment is made to look finished by monkeypatching
        ``env_done`` with ``helpers.mock_done`` (and, for sequential
        environments, the underlying environment's ``last`` and ``step``).
        After stepping, the timestep must be ``StepType.LAST``,
        ``_reset_next_step`` must be ``True`` and
        ``helpers.assert_env_reset`` must pass.

        Args:
            env_spec: specification of the environment under test.
            helpers: provides ``get_wrapped_env``, ``seed_action_space``,
                ``mock_done`` and ``assert_env_reset``.
            monkeypatch: pytest fixture used to install the mocks.
        """
        # Skip before building the environment so a skipped source never
        # pays for (or fails during) environment construction.
        if env_spec.env_source == EnvSource.OpenSpiel:
            pytest.skip("Open Spiel does not use the .last() method")

        wrapped_env, _ = helpers.get_wrapped_env(env_spec)

        # Seed environment since we are sampling actions.
        # We need to seed env and action space.
        random_seed = 42
        wrapped_env.seed(random_seed)
        helpers.seed_action_space(wrapped_env, random_seed)

        #  Get agent names from env
        _ = wrapped_env.reset()
        agents = wrapped_env.agents

        # Parallel env_types
        if env_spec.env_type == EnvType.Parallel:
            test_agents_actions = {
                agent: wrapped_env.action_spaces[agent].sample()
                for agent in agents
            }

            monkeypatch.setattr(wrapped_env, "env_done", helpers.mock_done)

            curr_dm_timestep = wrapped_env.step(test_agents_actions)

            helpers.assert_env_reset(wrapped_env, curr_dm_timestep, env_spec)

        # Sequential env_types
        # TODO (Kale-ab): Make this part below less reliant on PZ.
        elif env_spec.env_type == EnvType.Sequential:
            n_agents = wrapped_env.num_agents

            # Mock functions to act like PZ environment is done
            def mock_environment_last() -> Any:
                # ``agent`` is resolved when this function is called, i.e. it
                # is whichever agent the stepping loop below is processing.
                observe = wrapped_env.observation_spaces[agent].sample()
                reward = 0.0
                done = True
                info: Dict = {}
                return observe, reward, done, info

            def mock_step(action: types.Action) -> None:
                return

            # Mocks certain functions - if functions don't exist, error is
            # not thrown.
            monkeypatch.setattr(wrapped_env._environment,
                                "last",
                                mock_environment_last,
                                raising=False)
            monkeypatch.setattr(wrapped_env._environment,
                                "step",
                                mock_step,
                                raising=False)

            for index, agent in enumerate(wrapped_env.agent_iter(n_agents)):
                test_agent_actions = wrapped_env.action_spaces[agent].sample()

                # Mock whole env being done when you reach final agent
                if index == n_agents - 1:
                    monkeypatch.setattr(
                        wrapped_env,
                        "env_done",
                        helpers.mock_done,
                    )

                # Mock update has occurred in step
                monkeypatch.setattr(wrapped_env._environment,
                                    "_has_updated",
                                    True,
                                    raising=False)

                curr_dm_timestep = wrapped_env.step(test_agent_actions)

                # Check each agent is on last step
                assert (curr_dm_timestep.step_type is
                        dm_env.StepType.LAST), "Failed to update step type."

            helpers.assert_env_reset(wrapped_env, curr_dm_timestep, env_spec)

        else:
            # Fail with a clear message instead of a NameError on the
            # unbound ``curr_dm_timestep`` in the assertions below.
            raise AssertionError("Unsupported env type: {}".format(
                env_spec.env_type))

        assert (wrapped_env._reset_next_step is
                True), "Failed to set _reset_next_step correctly."
        assert (curr_dm_timestep.step_type is
                dm_env.StepType.LAST), "Failed to update step type."