Example #1
0
 def test_sample_from_early_done_env(self):
     """Sampling from an early-terminating multi-agent env must work.

     Builds a RolloutWorker over an env that may terminate one agent
     without publishing a final observation for the other, samples a
     batch, and checks the agents' timestep alternation pattern.
     """
     worker = RolloutWorker(
         env_creator=lambda _: EarlyDoneMultiAgent(),
         policy_spec={
             "p0": PolicySpec(policy_class=MockPolicy),
             "p1": PolicySpec(policy_class=MockPolicy),
         },
         # Even agent ids -> "p0", odd -> "p1".
         policy_mapping_fn=(lambda aid, **kwargs: f"p{aid % 2}"),
         batch_mode="complete_episodes",
         rollout_fragment_length=1,
     )
     # This used to raise an Error due to the EarlyDoneMultiAgent
     # terminating at e.g. agent0 w/o publishing the observation for
     # agent1 anymore. This limitation is fixed and an env may
     # terminate at any time (as well as return rewards for any agent
     # at any time, even when that agent doesn't have an obs returned
     # in the same call to `step()`).
     batch = worker.sample()
     # Agents must have taken alternating timesteps, except for the
     # very last step, on which both agents got terminated together.
     ts_p0 = batch.policy_batches["p0"]["t"]
     ts_p1 = batch.policy_batches["p1"]["t"]
     step_gaps = np.abs(ts_p0[:-1] - ts_p1[:-1])
     self.assertTrue(np.all(step_gaps == 1.0))
     self.assertTrue(ts_p0[-1] == ts_p1[-1])
Example #2
0
 def test_sample_from_early_done_env(self):
     """Sampling an early-done multi-agent env must raise a ValueError.

     The env terminates an agent without publishing that agent's last
     observation; under this (older) worker API such an episode is
     expected to be rejected during `sample()`.
     """
     act_space = gym.spaces.Discrete(2)
     obs_space = gym.spaces.Discrete(2)
     ev = RolloutWorker(
         env_creator=lambda _: EarlyDoneMultiAgent(),
         policy={
             "p0": (MockPolicy, obs_space, act_space, {}),
             "p1": (MockPolicy, obs_space, act_space, {}),
         },
         # Even agent ids -> "p0", odd -> "p1".
         policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
         batch_mode="complete_episodes",
         rollout_fragment_length=1)
     # `assertRaisesRegexp` is a deprecated alias (removed in Python
     # 3.12); use `assertRaisesRegex` as a context manager instead of
     # passing a lambda.
     with self.assertRaisesRegex(ValueError,
                                 ".*don't have a last observation.*"):
         ev.sample()