def test_sample_from_early_done_env(self):
    ev = RolloutWorker(
        env_creator=lambda _: EarlyDoneMultiAgent(),
        policy_spec={
            "p0": PolicySpec(policy_class=MockPolicy),
            "p1": PolicySpec(policy_class=MockPolicy),
        },
        policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)),
        batch_mode="complete_episodes",
        rollout_fragment_length=1,
    )
    # This used to raise an Error due to the EarlyDoneMultiAgent
    # terminating at e.g. agent0 w/o publishing the observation for
    # agent1 anymore. This limitation is fixed and an env may
    # terminate at any time (as well as return rewards for any agent
    # at any time, even when that agent doesn't have an obs returned
    # in the same call to `step()`).
    ma_batch = ev.sample()
    # Make sure that agents took the correct (alternating timesteps)
    # path. Except for the last timestep, where both agents got
    # terminated.
    ag0_ts = ma_batch.policy_batches["p0"]["t"]
    ag1_ts = ma_batch.policy_batches["p1"]["t"]
    self.assertTrue(np.all(np.abs(ag0_ts[:-1] - ag1_ts[:-1]) == 1.0))
    self.assertTrue(ag0_ts[-1] == ag1_ts[-1])
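# --- Illustrative sketch (not part of the test suite) ----------------------
# A minimal sketch of the kind of env the test above exercises. The real
# `EarlyDoneMultiAgent` is defined in RLlib's example envs; the class name
# `EarlyDoneMultiAgentSketch`, the Discrete spaces, and the 10-step horizon
# below are illustrative assumptions, not the actual implementation. Two
# agents alternate turns; the episode terminates for all agents w/o a final
# observation ever being published for the non-active agent, and a reward is
# returned for an agent even when that agent receives no obs in the same
# `step()` call.
import gym

from ray.rllib.env.multi_agent_env import MultiAgentEnv


class EarlyDoneMultiAgentSketch(MultiAgentEnv):
    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Discrete(100)
        self.action_space = gym.spaces.Discrete(2)
        self.t = 0

    def reset(self):
        self.t = 0
        # Only agent 0 receives an obs (and thus acts) on the first step.
        return {0: self.t}

    def step(self, action_dict):
        # Agents alternate: the agent that just acted vs. the one up next.
        acted = self.t % 2
        self.t += 1
        nxt = self.t % 2
        # Publish an obs only for the next-acting agent ...
        obs = {nxt: self.t}
        # ... but a reward for the agent that just acted, which gets no
        # obs in this same `step()` call.
        rewards = {acted: 1.0}
        # Terminate at any time: the non-active agent never receives a
        # terminal observation here.
        dones = {"__all__": self.t >= 10}
        return obs, rewards, dones, {}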