コード例 #1
0
def test_truncation(stateful, state_tuple):
    """
    Test sequence truncation for TruncatedRoller with a
    batch of one environment.

    A BasicRoller first gathers at least five reference episodes. A
    TruncatedRoller with a budget of ``total_timesteps // 2 + 1`` steps is
    then called twice, and each call's output is compared against the
    reference split at the same point by ``_artificial_truncation``.

    Every environment created here is closed before returning, even if an
    assertion fails.
    """
    def env_fn():
        return SimpleEnv(7, (5, 3), 'uint8')

    env = env_fn()
    try:
        model = SimpleModel(env.action_space.low.shape,
                            stateful=stateful,
                            state_tuple=state_tuple)
        basic_roller = BasicRoller(env, model, min_episodes=5)
        expected = basic_roller.rollouts()
    finally:
        env.close()
    total_timesteps = sum(x.num_steps for x in expected)

    batched_env = batched_gym_env([env_fn], sync=True)
    try:
        # A budget just over half the reference steps is expected to stop
        # the first call mid-episode (asserted below).
        trunc_roller = TruncatedRoller(batched_env, model,
                                       total_timesteps // 2 + 1)
        actual1 = trunc_roller.rollouts()
        assert actual1[-1].trunc_end
        actual2 = trunc_roller.rollouts()
    finally:
        batched_env.close()

    expected1, expected2 = _artificial_truncation(expected,
                                                  len(actual1) - 1,
                                                  actual1[-1].num_steps)
    # Two calls consume 2 * (total_timesteps // 2 + 1) > total_timesteps
    # steps, so actual2 ends with one rollout beyond the reference
    # episodes; drop it before comparing.
    assert len(actual2) == len(expected2) + 1
    actual2 = actual2[:-1]
    _compare_rollout_batch(actual1, expected1)
    _compare_rollout_batch(actual2, expected2)
コード例 #2
0
    def _test_batch_equivalence_case(self, stateful, state_tuple):
        """
        Test that doing things in batches is consistent,
        given the model parameters.

        Fifteen seeded environments are first rolled out one at a time,
        then together in three sub-batches. Both runs call a
        TruncatedRoller with a 17-step budget three times, and the two
        sets of rollouts must match regardless of order.

        Each batched environment is closed once it has been used, even on
        failure.
        """
        env_fns = [
            # ``seed=x`` binds the current x at definition time, avoiding
            # the late-binding closure pitfall.
            lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)
        ]
        model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)

        unbatched_rollouts = []
        for env_fn in env_fns:
            batched_env = batched_gym_env([env_fn], sync=True)
            try:
                trunc_roller = TruncatedRoller(batched_env, model, 17)
                for _ in range(3):
                    unbatched_rollouts.extend(trunc_roller.rollouts())
            finally:
                batched_env.close()

        batched_rollouts = []
        batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
        try:
            trunc_roller = TruncatedRoller(batched_env, model, 17)
            for _ in range(3):
                batched_rollouts.extend(trunc_roller.rollouts())
        finally:
            batched_env.close()

        _compare_rollout_batch(self,
                               unbatched_rollouts,
                               batched_rollouts,
                               ordered=False)
コード例 #3
0
def test_ep_batches(stateful, state_tuple, limits):
    """
    Test that EpisodeRoller is equivalent to a
    BasicRoller when run on a batch of envs.

    An EpisodeRoller over 21 identically constructed environments
    (7 sub-batches) must satisfy its own episode/step limits. It must also
    produce the same rollouts as a BasicRoller on a single such
    environment asked for the same number of episodes.

    Every environment created here is closed before returning, even if an
    assertion fails.
    """
    def env_fn():
        return SimpleEnv(3, (4, 5), 'uint8')

    model = SimpleModel((4, 5), stateful=stateful, state_tuple=state_tuple)

    batched_env = batched_gym_env([env_fn] * 21, num_sub_batches=7, sync=True)
    try:
        ep_roller = EpisodeRoller(batched_env, model, **limits)
        actual = ep_roller.rollouts()

        total_steps = sum(r.num_steps for r in actual)
        assert len(actual) >= ep_roller.min_episodes
        assert total_steps >= ep_roller.min_steps

        if 'min_steps' not in limits:
            # With only an episode limit, the test expects exactly
            # min_episodes plus one extra episode per remaining env
            # (presumably episodes already in flight when the limit is hit).
            num_eps = ep_roller.min_episodes + batched_env.num_envs - 1
            assert len(actual) == num_eps
    finally:
        batched_env.close()

    env = env_fn()
    try:
        basic_roller = BasicRoller(env, model, min_episodes=len(actual))
        expected = basic_roller.rollouts()
    finally:
        env.close()

    _compare_rollout_batch(actual, expected)
コード例 #4
0
    def _test_truncation_case(self, stateful, state_tuple):
        """
        Test rollout truncation and continuation for a
        specific set of model parameters.

        A BasicRoller first gathers at least five reference episodes. A
        TruncatedRoller with a budget of ``total_timesteps // 2 + 1``
        steps is then called twice, and each call's output is compared
        against the reference split at the same point by
        ``_artificial_truncation``.

        Every environment created here is closed before returning, even if
        an assertion fails.
        """
        def env_fn():
            return SimpleEnv(7, (5, 3), 'uint8')

        env = env_fn()
        try:
            model = SimpleModel(env.action_space.low.shape,
                                stateful=stateful,
                                state_tuple=state_tuple)
            basic_roller = BasicRoller(env, model, min_episodes=5)
            expected = basic_roller.rollouts()
        finally:
            env.close()
        total_timesteps = sum(x.num_steps for x in expected)

        batched_env = batched_gym_env([env_fn], sync=True)
        try:
            # A budget just over half the reference steps is expected to
            # stop the first call mid-episode (asserted below).
            trunc_roller = TruncatedRoller(batched_env, model,
                                           total_timesteps // 2 + 1)
            actual1 = trunc_roller.rollouts()
            self.assertTrue(actual1[-1].trunc_end)
            actual2 = trunc_roller.rollouts()
        finally:
            batched_env.close()

        expected1, expected2 = _artificial_truncation(expected,
                                                      len(actual1) - 1,
                                                      actual1[-1].num_steps)
        # Two calls consume 2 * (total_timesteps // 2 + 1) > total_timesteps
        # steps, so actual2 ends with one rollout beyond the reference
        # episodes; drop it before comparing.
        self.assertEqual(len(actual2), len(expected2) + 1)
        actual2 = actual2[:-1]
        _compare_rollout_batch(self, actual1, expected1)
        _compare_rollout_batch(self, actual2, expected2)
コード例 #5
0
    def _test_batch_equivalence_case(self, stateful, state_tuple,
                                     **roller_kwargs):
        """
        Test BasicRoller equivalence when using a batch of
        environments.

        An EpisodeRoller over 21 identically constructed environments
        (7 sub-batches) must satisfy its own limits. It must also produce
        the same rollouts as a BasicRoller on a single such environment
        asked for the same number of episodes.

        Args:
            stateful: forwarded to SimpleModel.
            state_tuple: forwarded to SimpleModel.
            **roller_kwargs: limits forwarded to EpisodeRoller (e.g.
                min_episodes / min_steps).

        Every environment created here is closed before returning, even if
        an assertion fails.
        """
        def env_fn():
            return SimpleEnv(3, (4, 5), 'uint8')

        model = SimpleModel((4, 5), stateful=stateful, state_tuple=state_tuple)

        batched_env = batched_gym_env([env_fn] * 21,
                                      num_sub_batches=7,
                                      sync=True)
        try:
            ep_roller = EpisodeRoller(batched_env, model, **roller_kwargs)
            actual = ep_roller.rollouts()

            total_steps = sum(r.num_steps for r in actual)
            # assertGreaterEqual/assertEqual report both values on failure,
            # unlike assertTrue on a comparison.
            self.assertGreaterEqual(len(actual), ep_roller.min_episodes)
            self.assertGreaterEqual(total_steps, ep_roller.min_steps)

            if 'min_steps' not in roller_kwargs:
                num_eps = ep_roller.min_episodes + batched_env.num_envs - 1
                self.assertEqual(len(actual), num_eps)
        finally:
            batched_env.close()

        env = env_fn()
        try:
            basic_roller = BasicRoller(env, model, min_episodes=len(actual))
            expected = basic_roller.rollouts()
        finally:
            env.close()

        _compare_rollout_batch(self, actual, expected)
コード例 #6
0
ファイル: test_rollers.py プロジェクト: decoderkurt/anyrl-py
def test_ep_basic_equivalence(stateful, state_tuple, limits):
    """
    Test that EpisodeRoller is equivalent to a
    BasicRoller when run on a single environment.

    Both rollers are built with the same ``limits`` keyword arguments and
    must produce matching rollouts. Every environment created here is
    closed before returning, even if an assertion fails.
    """
    def env_fn():
        return SimpleEnv(3, (4, 5), 'uint8')

    env = env_fn()
    try:
        model = SimpleModel(env.action_space.low.shape,
                            stateful=stateful,
                            state_tuple=state_tuple)
        basic_roller = BasicRoller(env, model, **limits)
        expected = basic_roller.rollouts()
    finally:
        env.close()

    batched_env = batched_gym_env([env_fn], sync=True)
    try:
        ep_roller = EpisodeRoller(batched_env, model, **limits)
        actual = ep_roller.rollouts()
    finally:
        batched_env.close()

    _compare_rollout_batch(actual, expected)
コード例 #7
0
    def _test_basic_equivalence_case(self, stateful, state_tuple):
        """
        Test BasicRoller equivalence for a specific set of
        model settings.

        A BasicRoller collects at least five episodes. A TruncatedRoller
        whose budget is exactly their total step count must reproduce them.
        Every environment created here is closed before returning, even if
        an assertion fails.
        """
        def env_fn():
            return SimpleEnv(3, (4, 5), 'uint8')

        env = env_fn()
        try:
            model = SimpleModel(env.action_space.low.shape,
                                stateful=stateful,
                                state_tuple=state_tuple)
            basic_roller = BasicRoller(env, model, min_episodes=5)
            expected = basic_roller.rollouts()
        finally:
            env.close()
        total_timesteps = sum(x.num_steps for x in expected)

        batched_env = batched_gym_env([env_fn], sync=True)
        try:
            trunc_roller = TruncatedRoller(batched_env, model, total_timesteps)
            actual = trunc_roller.rollouts()
        finally:
            batched_env.close()

        _compare_rollout_batch(self, actual, expected)
コード例 #8
0
    def _test_basic_equivalence_case(self, stateful, state_tuple,
                                     **roller_kwargs):
        """
        Test BasicRoller equivalence for a single env in a
        specific case.

        A BasicRoller and an EpisodeRoller (over a batch of one env) are
        built with the same ``roller_kwargs`` and must produce matching
        rollouts. Every environment created here is closed before
        returning, even if an assertion fails.
        """
        def env_fn():
            return SimpleEnv(3, (4, 5), 'uint8')

        env = env_fn()
        try:
            model = SimpleModel(env.action_space.low.shape,
                                stateful=stateful,
                                state_tuple=state_tuple)
            basic_roller = BasicRoller(env, model, **roller_kwargs)
            expected = basic_roller.rollouts()
        finally:
            env.close()

        batched_env = batched_gym_env([env_fn], sync=True)
        try:
            ep_roller = EpisodeRoller(batched_env, model, **roller_kwargs)
            actual = ep_roller.rollouts()
        finally:
            batched_env.close()

        _compare_rollout_batch(self, actual, expected)
コード例 #9
0
ファイル: test_rollers.py プロジェクト: decoderkurt/anyrl-py
def test_trunc_basic_equivalence(stateful, state_tuple):
    """
    Test that TruncatedRoller is equivalent to BasicRoller
    for batches of one environment when the episodes end
    cleanly.

    The TruncatedRoller's budget is exactly the total step count of the
    BasicRoller's (at least five) episodes. Every environment created here
    is closed before returning, even if an assertion fails.
    """
    def env_fn():
        return SimpleEnv(3, (4, 5), 'uint8')

    env = env_fn()
    try:
        model = SimpleModel(env.action_space.low.shape,
                            stateful=stateful,
                            state_tuple=state_tuple)
        basic_roller = BasicRoller(env, model, min_episodes=5)
        expected = basic_roller.rollouts()
    finally:
        env.close()
    total_timesteps = sum(x.num_steps for x in expected)

    batched_env = batched_gym_env([env_fn], sync=True)
    try:
        trunc_roller = TruncatedRoller(batched_env, model, total_timesteps)
        actual = trunc_roller.rollouts()
    finally:
        batched_env.close()

    _compare_rollout_batch(actual, expected)
コード例 #10
0
 def test_multiple_batches(self):
     """
     Check that repeated calls to EpisodeRoller.rollouts() keep
     producing the same batch as the first call.

     A throwaway env supplies the action shape for the model and is
     closed immediately; the batched env is closed once the comparisons
     finish, even on failure.
     """
     def make_env():
         return SimpleEnv(3, (4, 5), 'uint8')

     probe_env = make_env()
     try:
         model = SimpleModel(probe_env.action_space.low.shape)
     finally:
         probe_env.close()

     batched_env = batched_gym_env([make_env], sync=True)
     try:
         roller = EpisodeRoller(batched_env, model,
                                min_episodes=5, min_steps=7)
         reference = roller.rollouts()
         for _attempt in range(3):
             latest = roller.rollouts()
             _compare_rollout_batch(self, reference, latest)
     finally:
         batched_env.close()
コード例 #11
0
def test_trunc_drop_states():
    """
    Test TruncatedRoller with drop_states=True.

    Rollouts from a normal TruncatedRoller, with every model output's
    'states' entry overwritten to None, must match the rollouts from a
    TruncatedRoller built with drop_states=True on the same batched
    environment. The environment is closed before returning, even if an
    assertion fails.
    """
    env_fns = [
        # ``seed=x`` binds the current x at definition time, avoiding the
        # late-binding closure pitfall.
        lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)
    ]
    model = SimpleModel((5, 3), stateful=True, state_tuple=True)

    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
    try:
        expected_rollouts = []
        trunc_roller = TruncatedRoller(batched_env, model, 17)
        for _ in range(3):
            expected_rollouts.extend(trunc_roller.rollouts())
        for rollout in expected_rollouts:
            for model_out in rollout.model_outs:
                model_out['states'] = None

        # The same batched env is reused for the second roller; presumably
        # constructing a new TruncatedRoller resets it — TODO confirm.
        actual_rollouts = []
        trunc_roller = TruncatedRoller(batched_env, model, 17, drop_states=True)
        for _ in range(3):
            actual_rollouts.extend(trunc_roller.rollouts())
    finally:
        batched_env.close()

    _compare_rollout_batch(actual_rollouts, expected_rollouts)
コード例 #12
0
def test_trunc_batches(stateful, state_tuple):
    """
    Test that TruncatedRoller produces the same result for
    batches as it does for individual environments.

    Fifteen seeded environments are first rolled out one at a time, then
    together in three sub-batches. Both runs call a TruncatedRoller with a
    17-step budget three times, and the two sets of rollouts must match
    regardless of order.

    Each batched environment is closed once it has been used, even on
    failure.
    """
    env_fns = [
        # ``seed=x`` binds the current x at definition time, avoiding the
        # late-binding closure pitfall.
        lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)
    ]
    model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)

    unbatched_rollouts = []
    for env_fn in env_fns:
        batched_env = batched_gym_env([env_fn], sync=True)
        try:
            trunc_roller = TruncatedRoller(batched_env, model, 17)
            for _ in range(3):
                unbatched_rollouts.extend(trunc_roller.rollouts())
        finally:
            batched_env.close()

    batched_rollouts = []
    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
    try:
        trunc_roller = TruncatedRoller(batched_env, model, 17)
        for _ in range(3):
            batched_rollouts.extend(trunc_roller.rollouts())
    finally:
        batched_env.close()

    _compare_rollout_batch(unbatched_rollouts, batched_rollouts, ordered=False)