Esempio n. 1
0
 def test_basic_mock(self):
     """BasicMultiAgent(4): constant obs/rewards for 24 steps, done at step 25."""
     env = BasicMultiAgent(4)
     actions = {agent: 0 for agent in range(4)}
     all_zero = {agent: 0 for agent in range(4)}
     all_one = {agent: 1 for agent in range(4)}
     still_running = dict({agent: False for agent in range(4)}, **{"__all__": False})
     self.assertEqual(env.reset(), all_zero)
     for _ in range(24):
         obs, rew, done, info = env.step(actions)
         self.assertEqual(obs, all_zero)
         self.assertEqual(rew, all_one)
         self.assertEqual(done, still_running)
     # The 25th step ends the episode for every agent.
     obs, rew, done, info = env.step(actions)
     finished = dict({agent: True for agent in range(4)}, **{"__all__": True})
     self.assertEqual(done, finished)
Esempio n. 2
0
    def test_wrap_multi_agent_env(self):
        """record_env_wrapper yields a VideoMonitor that writes exactly one
        .mp4 for a finished 25-step BasicMultiAgent episode."""
        import tempfile

        # Use tempfile.mkdtemp() instead of shelling out to `mktemp -d`:
        # portable (no POSIX shell required), no subprocess, and no manual
        # trailing-newline stripping.
        record_env_dir = tempfile.mkdtemp()
        print(f"tmp dir for videos={record_env_dir}")

        # Fail this test instead of sys.exit(1), which would kill the whole
        # test process (and every other test) on a setup problem.
        self.assertTrue(os.path.exists(record_env_dir))

        wrapped = record_env_wrapper(env=BasicMultiAgent(3),
                                     record_env=record_env_dir,
                                     log_dir="",
                                     policy_config={
                                         "in_evaluation": False,
                                     })
        # Type is VideoMonitor.
        self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
        self.assertTrue(isinstance(wrapped, VideoMonitor))

        wrapped.reset()

        # BasicMultiAgent is hardcoded to run 25-step episodes.
        for _ in range(25):
            wrapped.step({0: 0, 1: 0, 2: 0})

        # Expect one video file to have been produced in the tmp dir.
        # Glob with an absolute pattern rather than os.chdir(), which would
        # leak a changed working directory into subsequent tests.
        ls = glob.glob(os.path.join(record_env_dir, "*.mp4"))
        self.assertEqual(len(ls), 1)
Esempio n. 3
0
 def test_external_multi_agent_env_truncate_episodes(self):
     """Truncated-episode sampling returns fixed-size batches covering all agents."""
     num_agents = 4
     worker = RolloutWorker(
         env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(num_agents)),
         policy_spec=MockPolicy,
         rollout_fragment_length=40,
         batch_mode="truncate_episodes")
     for _ in range(3):
         sample = worker.sample()
         # 40 env steps x 4 agents = 160 rows per batch.
         self.assertEqual(sample.count, 160)
         self.assertEqual(len(np.unique(sample["agent_index"])), num_agents)
Esempio n. 4
0
 def test_multi_agent_sample_with_horizon(self):
     """With episode_horizon set, sampling still yields exactly
     rollout_fragment_length timesteps."""
     ev = RolloutWorker(
         env_creator=lambda _: BasicMultiAgent(5),
         policy_spec={
             "p0": PolicySpec(policy_class=MockPolicy),
             "p1": PolicySpec(policy_class=MockPolicy),
         },
         # `**kwargs` (was `**kwarg`) for consistency with the other
         # mapping fns in this file; agents alternate between p0/p1.
         policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)),
         episode_horizon=10,  # test with episode horizon set
         rollout_fragment_length=50)
     batch = ev.sample()
     self.assertEqual(batch.count, 50)
Esempio n. 5
0
 def test_multi_agent_sample_async_remote(self):
     """Async remote vectorization: 4 envs x fragment length 50 -> 200 rows."""
     def mapping_fn(aid, **kwargs):
         return "p{}".format(aid % 2)

     worker = RolloutWorker(
         env_creator=lambda _: BasicMultiAgent(5),
         policy_spec={
             "p0": PolicySpec(policy_class=MockPolicy),
             "p1": PolicySpec(policy_class=MockPolicy),
         },
         policy_mapping_fn=mapping_fn,
         rollout_fragment_length=50,
         num_envs=4,
         remote_worker_envs=True)
     self.assertEqual(worker.sample().count, 200)
Esempio n. 6
0
    def test_vectorize_basic(self):
        """Two vectorized BasicMultiAgent(2) sub-envs: poll/send_actions runs
        a full 25-step episode, then reset handling is exercised."""
        env = MultiAgentEnvWrapper(lambda v: BasicMultiAgent(2), [], 2)

        # Expected per-sub-env payloads, hoisted out of the loop.
        zero_obs = {i: {0: 0, 1: 0} for i in range(2)}
        unit_rew = {i: {0: 1, 1: 1} for i in range(2)}
        actions = {i: {0: 0, 1: 0} for i in range(2)}
        running = {i: {0: False, 1: False, "__all__": False} for i in range(2)}

        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, zero_obs)
        # First poll: no step taken yet, so no per-agent rewards/dones.
        self.assertEqual(rew, {0: {}, 1: {}})
        self.assertEqual(dones, {0: {"__all__": False}, 1: {"__all__": False}})

        for _ in range(24):
            env.send_actions(actions)
            obs, rew, dones, _, _ = env.poll()
            self.assertEqual(obs, zero_obs)
            self.assertEqual(rew, unit_rew)
            self.assertEqual(dones, running)

        # Step 25 ends the episode in both sub-envs.
        env.send_actions(actions)
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(
            dones,
            {i: {0: True, 1: True, "__all__": True} for i in range(2)})

        # Reset processing: acting on a finished episode must raise until
        # each sub-env has been reset explicitly.
        self.assertRaises(ValueError, lambda: env.send_actions(actions))
        self.assertEqual(env.try_reset(0), {0: {0: 0, 1: 0}})
        self.assertEqual(env.try_reset(1), {1: {0: 0, 1: 0}})
        env.send_actions(actions)
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, zero_obs)
        self.assertEqual(rew, unit_rew)
        self.assertEqual(dones, running)
Esempio n. 7
0
 def test_external_multi_agent_env_sample(self):
     """An external (serving) multi-agent env samples the requested fragment."""
     num_agents = 2
     action_space = gym.spaces.Discrete(2)
     observation_space = gym.spaces.Discrete(2)
     worker = RolloutWorker(
         env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(num_agents)),
         policy_spec={
             "p0": (MockPolicy, observation_space, action_space, {}),
             "p1": (MockPolicy, observation_space, action_space, {}),
         },
         policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
         rollout_fragment_length=50)
     self.assertEqual(worker.sample().count, 50)
Esempio n. 8
0
 def test_multi_agent_sample_with_horizon(self):
     """episode_horizon must not shrink the sampled batch below
     rollout_fragment_length."""
     action_space = gym.spaces.Discrete(2)
     observation_space = gym.spaces.Discrete(2)
     worker = RolloutWorker(
         env_creator=lambda _: BasicMultiAgent(5),
         policy={
             "p0": (MockPolicy, observation_space, action_space, {}),
             "p1": (MockPolicy, observation_space, action_space, {}),
         },
         policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
         episode_horizon=10,  # test with episode horizon set
         rollout_fragment_length=50)
     self.assertEqual(worker.sample().count, 50)
Esempio n. 9
0
 def test_wrap_multi_agent_env(self):
     """Auto-wrapping a MultiAgentEnv keeps it sampleable and correctly typed."""
     ev = RolloutWorker(
         env_creator=lambda _: BasicMultiAgent(10),
         policy_spec=MockPolicy,
         policy_config={
             "in_evaluation": False,
         },
     )
     # Make sure we can properly sample from the wrapped env.
     ev.sample()
     # Make sure the resulting environment is indeed still a MultiAgentEnv
     # under the hood, while the wrapper itself conforms to gym.Env.
     self.assertTrue(isinstance(ev.env.unwrapped, MultiAgentEnv))
     self.assertTrue(isinstance(ev.env, gym.Env))
     ev.stop()
Esempio n. 10
0
    def test_wrap_multi_agent_env(self):
        """record_env_wrapper should return a VideoMonitor around the env."""
        monitored = record_env_wrapper(
            env=BasicMultiAgent(3),
            record_env=tempfile.gettempdir(),
            log_dir="",
            policy_config={"in_evaluation": False},
        )
        # The result must be both a generic Monitor and the VideoMonitor subtype.
        self.assertTrue(isinstance(monitored, wrappers.Monitor))
        self.assertTrue(isinstance(monitored, VideoMonitor))

        monitored.reset()
        # BasicMultiAgent is hardcoded to run 25-step episodes.
        for _ in range(25):
            monitored.step({0: 0, 1: 0, 2: 0})
Esempio n. 11
0
 def test_multi_agent_sample_async_remote(self):
     """4 async remote envs x fragment length 50 -> batch of 200 rows."""
     # Allow to be run via Unittest.
     ray.init(num_cpus=4, ignore_reinit_error=True)
     action_space = gym.spaces.Discrete(2)
     observation_space = gym.spaces.Discrete(2)
     worker = RolloutWorker(
         env_creator=lambda _: BasicMultiAgent(5),
         policy={
             "p0": (MockPolicy, observation_space, action_space, {}),
             "p1": (MockPolicy, observation_space, action_space, {}),
         },
         policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
         rollout_fragment_length=50,
         num_envs=4,
         remote_worker_envs=True)
     self.assertEqual(worker.sample().count, 200)
Esempio n. 12
0
 def test_multi_agent_sample(self):
     """Agent-to-policy mapping splits a 50-step batch 150/100 across p0/p1."""
     action_space = gym.spaces.Discrete(2)
     observation_space = gym.spaces.Discrete(2)
     worker = RolloutWorker(
         env_creator=lambda _: BasicMultiAgent(5),
         policy={
             "p0": (MockPolicy, observation_space, action_space, {}),
             "p1": (MockPolicy, observation_space, action_space, {}),
         },
         policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
         rollout_fragment_length=50)
     batch = worker.sample()
     self.assertEqual(batch.count, 50)
     # Of 5 agents, three map to p0 and two to p1 (50 steps each).
     self.assertEqual(batch.policy_batches["p0"].count, 150)
     self.assertEqual(batch.policy_batches["p1"].count, 100)
     # Timesteps restart at 0 every 25-step episode.
     self.assertEqual(batch.policy_batches["p0"]["t"].tolist(),
                      list(range(25)) * 6)
Esempio n. 13
0
    def test_multi_agent_sample(self):
        """Multi-agent sampling with the new-signature policy mapping fn."""
        rollout_worker = RolloutWorker(
            env_creator=lambda _: BasicMultiAgent(5),
            policy_spec={
                "p0": PolicySpec(policy_class=MockPolicy),
                "p1": PolicySpec(policy_class=MockPolicy),
            },
            # New-style mapping signature: (agent_id, episode, worker, **kwargs).
            policy_mapping_fn=(
                lambda agent_id, episode, worker, **kwargs:
                "p{}".format(agent_id % 2)),
            rollout_fragment_length=50)
        batch = rollout_worker.sample()
        self.assertEqual(batch.count, 50)
        # 5 agents alternate p0/p1: 3 agents x 50 steps vs 2 agents x 50 steps.
        self.assertEqual(batch.policy_batches["p0"].count, 150)
        self.assertEqual(batch.policy_batches["p1"].count, 100)
        self.assertEqual(batch.policy_batches["p0"]["t"].tolist(),
                         list(range(25)) * 6)
Esempio n. 14
0
 def test_multi_agent_sample_sync_remote(self):
     """Sync remote sampling with an effectively infinite batch wait:
     4 envs x fragment length 50 -> 200 rows."""
     worker = RolloutWorker(
         env_creator=lambda _: BasicMultiAgent(5),
         policy_spec={
             "p0": PolicySpec(policy_class=MockPolicy),
             "p1": PolicySpec(policy_class=MockPolicy),
         },
         # This signature will raise a soft-deprecation warning due
         # to the new signature we are using (agent_id, episode, **kwargs),
         # but should not break this test.
         policy_mapping_fn=(lambda agent_id: "p{}".format(agent_id % 2)),
         rollout_fragment_length=50,
         num_envs=4,
         remote_worker_envs=True,
         remote_env_batch_wait_ms=99999999)
     self.assertEqual(worker.sample().count, 200)
Esempio n. 15
0
 def test_multi_agent_sample_sync_remote(self):
     """Sync remote sampling (tuple policy specs): 4 envs x 50 steps = 200."""
     # Allow to be run via Unittest.
     ray.init(num_cpus=4, ignore_reinit_error=True)
     action_space = gym.spaces.Discrete(2)
     observation_space = gym.spaces.Discrete(2)
     worker = RolloutWorker(
         env_creator=lambda _: BasicMultiAgent(5),
         policy_spec={
             "p0": (MockPolicy, observation_space, action_space, {}),
             "p1": (MockPolicy, observation_space, action_space, {}),
         },
         # This signature will raise a soft-deprecation warning due
         # to the new signature we are using (agent_id, episode, **kwargs),
         # but should not break this test.
         policy_mapping_fn=(lambda agent_id: "p{}".format(agent_id % 2)),
         rollout_fragment_length=50,
         num_envs=4,
         remote_worker_envs=True,
         remote_env_batch_wait_ms=99999999)
     self.assertEqual(worker.sample().count, 200)
Esempio n. 16
0
 def test_no_reset_until_poll(self):
     """Sub-envs are reset lazily: only the first poll() triggers reset."""
     vector_env = MultiAgentEnvWrapper(lambda v: BasicMultiAgent(2), [], 1)
     # No poll yet -> the sub-env must still be un-reset.
     self.assertFalse(vector_env.get_sub_environments()[0].resetted)
     vector_env.poll()
     # The first poll performs the reset.
     self.assertTrue(vector_env.get_sub_environments()[0].resetted)
Esempio n. 17
0
 def test_no_reset_until_poll(self):
     """The base-env wrapper must defer sub-env resets until the first poll()."""
     base_env = _MultiAgentEnvToBaseEnv(lambda v: BasicMultiAgent(2), [], 1)
     # Before polling, the underlying env has not been reset.
     self.assertFalse(base_env.get_unwrapped()[0].resetted)
     base_env.poll()
     # Polling forces the initial reset.
     self.assertTrue(base_env.get_unwrapped()[0].resetted)