コード例 #1
0
def test_multi_env():
    """
    Test monitoring for concurrent environments.

    Two LoggedEnv wrappers are pointed at the same CSV file
    with use_locking=True and stepped in an interleaved
    fashion for 13 rounds. After both are closed, the 'r'
    and 'l' columns of the log are compared with the values
    this test expects.
    """
    with tempfile.TemporaryDirectory() as dirpath:
        log_file = os.path.join(dirpath, 'monitor.csv')
        env1 = LoggedEnv(SimpleEnv(2, (3, ), 'float32'),
                         log_file,
                         use_locking=True)
        env2 = LoggedEnv(SimpleEnv(3, (3, ), 'float32'),
                         log_file,
                         use_locking=True)

        env1.reset()
        env2.reset()
        for _ in range(13):
            for env in [env1, env2]:
                # Element 2 of the step result is treated as the
                # episode-finished flag: reset when it is truthy.
                if env.step(env.action_space.sample())[2]:
                    env.reset()
        # Close both wrappers before inspecting the log file.
        env1.close()
        env2.close()

        # read_csv opens the path itself, so no separate file
        # handle needs to be held open around it.
        log_contents = pandas.read_csv(log_file)
        assert list(log_contents['r']) == [2, 2.5, 2, 2.5, 2, 2, 2.5]
        assert list(log_contents['l']) == [3, 4, 3, 4, 3, 3, 4]
コード例 #2
0
    def test_multi_env(self):
        """
        Test monitoring for concurrent environments.

        Two LoggedEnv wrappers are pointed at the same CSV file
        with use_locking=True and stepped in an interleaved
        fashion for 13 rounds. After both are closed, the 'r'
        and 'l' columns of the log are compared with the values
        this test expects. The temporary directory is removed
        whether or not the test passes.
        """
        dirpath = tempfile.mkdtemp()
        try:
            log_file = os.path.join(dirpath, 'monitor.csv')
            env1 = LoggedEnv(SimpleEnv(2, (3,), 'float32'), log_file, use_locking=True)
            env2 = LoggedEnv(SimpleEnv(3, (3,), 'float32'), log_file, use_locking=True)

            env1.reset()
            env2.reset()
            for _ in range(13):
                for env in [env1, env2]:
                    # Element 2 of the step result is treated as the
                    # episode-finished flag: reset when it is truthy.
                    if env.step(env.action_space.sample())[2]:
                        env.reset()
            # Close both wrappers before inspecting the log file.
            env1.close()
            env2.close()

            # read_csv opens the path itself, so no separate file
            # handle needs to be held open around it.
            log_contents = pandas.read_csv(log_file)
            self.assertEqual(list(log_contents['r']), [2, 2.5, 2, 2.5, 2, 2, 2.5])
            self.assertEqual(list(log_contents['l']), [3, 4, 3, 4, 3, 3, 4])
        finally:
            shutil.rmtree(dirpath)
コード例 #3
0
def test_skip():
    """
    Check FrameSkipEnv against manually repeated steps.

    The unwrapped environment is first driven by hand: one
    action for 3 steps, then another for 2 steps, after which
    the episode is expected to be done. The same environment
    is then wrapped with num_frames=3, and each wrapped step
    must reproduce the last observation and the summed reward
    of the corresponding manual steps.
    """
    # The manual episode below is expected to finish after 5 steps.
    env = SimpleEnv(4, (3, 2, 5), 'float32')
    first_action = np.random.uniform(high=255.0, size=(3, 2, 5))
    second_action = np.random.uniform(high=255.0, size=(3, 2, 5))
    start_obs = env.reset()

    first_total = 0.0
    for _ in range(3):
        mid_obs, reward, _, _ = env.step(first_action)
        first_total += reward

    second_total = 0.0
    for _ in range(2):
        final_obs, reward, done, _ = env.step(second_action)
        second_total += reward
    assert done

    # Replay the same two actions through the frame-skipping wrapper.
    env = FrameSkipEnv(env, num_frames=3)
    wrapped_start = env.reset()
    assert np.allclose(wrapped_start, start_obs)

    wrapped_mid, wrapped_first_total, done, _ = env.step(first_action)
    assert not done
    assert wrapped_first_total == first_total
    assert np.allclose(wrapped_mid, mid_obs)

    wrapped_final, wrapped_second_total, done, _ = env.step(second_action)
    assert done
    assert wrapped_second_total == second_total
    assert np.allclose(wrapped_final, final_obs)
コード例 #4
0
 def test_resize_even(self):
     """
     Test resizing for an even number of pixels.

     The frame produced by ResizeImageEnv(size=(5, 4)) on reset
     is compared with the result of tf.image.resize_images
     (AREA method) applied to the unwrapped reset frame.
     """
     env = SimpleEnv(5, (13, 5, 3), 'float32')
     frame = env.reset()
     actual = ResizeImageEnv(env, size=(5, 4)).reset()
     # Use the session as a context manager so that it is closed
     # once the reference value has been computed.
     with tf.Session() as sess:
         expected = sess.run(
             tf.image.resize_images(frame, [5, 4], method=tf.image.ResizeMethod.AREA))
     self.assertEqual(actual.shape, (5, 4, 3))
     self.assertTrue(np.allclose(actual, expected))
コード例 #5
0
ファイル: test_rollers.py プロジェクト: decoderkurt/anyrl-py
def test_truncation(stateful, state_tuple):
    """
    Check TruncatedRoller truncation and continuation with a
    batch holding a single environment.

    A BasicRoller supplies the reference rollouts. The
    TruncatedRoller is asked for just over half of the
    reference timesteps per call, so the first batch must end
    with a truncated rollout and the second batch must pick
    up where the first left off.
    """
    def env_fn():
        return SimpleEnv(7, (5, 3), 'uint8')

    reference_env = env_fn()
    model = SimpleModel(reference_env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    expected = BasicRoller(reference_env, model, min_episodes=5).rollouts()
    total_timesteps = sum(rollout.num_steps for rollout in expected)

    # Third roller argument: just over half of the reference timesteps.
    per_batch = total_timesteps // 2 + 1
    trunc_roller = TruncatedRoller(batched_gym_env([env_fn], sync=True),
                                   model,
                                   per_batch)
    first_batch = trunc_roller.rollouts()
    assert first_batch[-1].trunc_end
    second_batch = trunc_roller.rollouts()

    # Split the reference rollouts at the point where the first
    # batch was cut off.
    cut_index = len(first_batch) - 1
    cut_steps = first_batch[-1].num_steps
    expected1, expected2 = _artificial_truncation(expected, cut_index,
                                                  cut_steps)

    # The second batch holds one more rollout than the reference
    # remainder; that trailing rollout is excluded from the comparison.
    assert len(second_batch) == len(expected2) + 1
    second_batch = second_batch[:-1]
    _compare_rollout_batch(first_batch, expected1)
    _compare_rollout_batch(second_batch, expected2)
コード例 #6
0
ファイル: test_wrappers.py プロジェクト: decoderkurt/anyrl-py
def test_batched_stack(concat):
    """
    Check that BatchedFrameStack around a batched environment
    behaves like a batched environment built from individually
    wrapped FrameStackEnv instances.

    Both are driven with the same actions across 3 sub-batches
    for 50 rounds, with periodic resets, and their
    observations, rewards and done flags are compared.
    """
    base_fns = [
        lambda num=n: SimpleEnv(num + 2, (3, 2, 5), 'float32')
        for n in range(6)
    ]
    batched_stack_env = BatchedFrameStack(
        batched_gym_env(base_fns, num_sub_batches=3, sync=True),
        concat=concat)
    wrapped_fns = [
        lambda make=fn: FrameStackEnv(make(), concat=concat)
        for fn in base_fns
    ]
    reference_env = batched_gym_env(wrapped_fns,
                                    num_sub_batches=3,
                                    sync=True)

    for round_idx in range(50):
        for sub in range(3):
            # Reset on the very first round and periodically afterwards.
            needs_reset = round_idx == 0 or (round_idx + sub) % 17 == 0
            if needs_reset:
                batched_stack_env.reset_start(sub_batch=sub)
                reference_env.reset_start(sub_batch=sub)
                stack_obs = batched_stack_env.reset_wait(sub_batch=sub)
                ref_obs = reference_env.reset_wait(sub_batch=sub)
                assert np.allclose(stack_obs, ref_obs)

            # Two actions per sub-batch (6 environments in 3 sub-batches).
            actions = [
                batched_stack_env.action_space.sample() for _ in range(2)
            ]
            batched_stack_env.step_start(actions, sub_batch=sub)
            reference_env.step_start(actions, sub_batch=sub)
            stack_obs, stack_rews, stack_dones, _ = \
                batched_stack_env.step_wait(sub_batch=sub)
            ref_obs, ref_rews, ref_dones, _ = \
                reference_env.step_wait(sub_batch=sub)
            assert np.allclose(stack_obs, ref_obs)
            assert np.array(stack_rews == ref_rews).all()
            assert np.array(stack_dones == ref_dones).all()
コード例 #7
0
    def _test_batch_equivalence_case(self, stateful, state_tuple):
        """
        Check that TruncatedRoller yields the same rollouts
        whether environments are rolled out one at a time or
        together in one batch, for the given model settings.
        """
        def gather(roller, rounds=3):
            # Concatenate the rollouts of several consecutive calls.
            collected = []
            for _ in range(rounds):
                collected.extend(roller.rollouts())
            return collected

        env_fns = [
            lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)
        ]
        model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)

        # One roller per environment, each over a batch of size one.
        unbatched_rollouts = []
        for make_env in env_fns:
            solo_env = batched_gym_env([make_env], sync=True)
            unbatched_rollouts.extend(
                gather(TruncatedRoller(solo_env, model, 17)))

        # A single roller over all environments at once.
        full_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
        batched_rollouts = gather(TruncatedRoller(full_env, model, 17))

        _compare_rollout_batch(self,
                               unbatched_rollouts,
                               batched_rollouts,
                               ordered=False)
コード例 #8
0
    def _test_truncation_case(self, stateful, state_tuple):
        """
        Check rollout truncation and continuation for one
        combination of model parameters.

        A BasicRoller supplies the reference rollouts. The
        TruncatedRoller is asked for just over half of the
        reference timesteps per call, so the first batch must
        end with a truncated rollout and the second batch
        must continue from it.
        """
        def env_fn():
            return SimpleEnv(7, (5, 3), 'uint8')

        reference_env = env_fn()
        model = SimpleModel(reference_env.action_space.low.shape,
                            stateful=stateful,
                            state_tuple=state_tuple)
        expected = BasicRoller(reference_env, model, min_episodes=5).rollouts()
        total_timesteps = sum(rollout.num_steps for rollout in expected)

        # Third roller argument: just over half of the reference timesteps.
        per_batch = total_timesteps // 2 + 1
        trunc_roller = TruncatedRoller(batched_gym_env([env_fn], sync=True),
                                       model,
                                       per_batch)
        first_batch = trunc_roller.rollouts()
        self.assertTrue(first_batch[-1].trunc_end)
        second_batch = trunc_roller.rollouts()

        # Split the reference rollouts where the first batch was cut off.
        cut_index = len(first_batch) - 1
        cut_steps = first_batch[-1].num_steps
        expected1, expected2 = _artificial_truncation(expected, cut_index,
                                                      cut_steps)

        # The trailing rollout of the second batch is excluded from
        # the comparison.
        self.assertEqual(len(second_batch), len(expected2) + 1)
        second_batch = second_batch[:-1]
        _compare_rollout_batch(self, first_batch, expected1)
        _compare_rollout_batch(self, second_batch, expected2)
コード例 #9
0
    def _test_batch_equivalence_case(self, stateful, state_tuple,
                                     **roller_kwargs):
        """
        Test that EpisodeRoller on a batch of environments is
        equivalent to BasicRoller on a single environment.

        Args:
          stateful: forwarded to SimpleModel.
          state_tuple: forwarded to SimpleModel.
          roller_kwargs: keyword arguments forwarded to
            EpisodeRoller.
        """
        def env_fn():
            return SimpleEnv(3, (4, 5), 'uint8')

        model = SimpleModel((4, 5), stateful=stateful, state_tuple=state_tuple)

        batched_env = batched_gym_env([env_fn] * 21,
                                      num_sub_batches=7,
                                      sync=True)
        ep_roller = EpisodeRoller(batched_env, model, **roller_kwargs)
        actual = ep_roller.rollouts()

        # Dedicated comparison assertions report both operands on failure.
        total_steps = sum(r.num_steps for r in actual)
        self.assertGreaterEqual(len(actual), ep_roller.min_episodes)
        self.assertGreaterEqual(total_steps, ep_roller.min_steps)

        if 'min_steps' not in roller_kwargs:
            # Without a step minimum, this test expects exactly
            # min_episodes + num_envs - 1 episodes.
            num_eps = ep_roller.min_episodes + batched_env.num_envs - 1
            self.assertEqual(len(actual), num_eps)

        # The reference roller gathers as many episodes as the
        # batched roller produced.
        basic_roller = BasicRoller(env_fn(), model, min_episodes=len(actual))
        expected = basic_roller.rollouts()

        _compare_rollout_batch(self, actual, expected)
コード例 #10
0
 def test_max_2(self):
     """
     Check MaxEnv with num_images=2.

     The reset frame and the four frames after it are recorded
     from the unwrapped environment. The same actions are then
     replayed through the wrapper: its first output must equal
     the reset frame, and each later output must equal the
     element-wise maximum of the current and previous raw
     frames.
     """
     env = SimpleEnv(5, (3, 2, 5), 'float32')
     actions = [env.action_space.sample() for _ in range(4)]

     # Raw frames: the reset frame followed by one frame per action.
     frames = [env.reset()]
     for action in actions:
         frames.append(env.step(action)[0])

     # Wrapped outputs for the same sequence of actions.
     wrapped = MaxEnv(env, num_images=2)
     maxed = [wrapped.reset()]
     for action in actions:
         maxed.append(wrapped.step(action)[0])

     self.assertTrue((maxed[0] == frames[0]).all())
     for idx in range(1, len(frames)):
         pair_max = np.max([frames[idx - 1], frames[idx]], axis=0)
         self.assertTrue((maxed[idx] == pair_max).all())
コード例 #11
0
def test_logged_single_env():
    """
    Test LoggedEnv for a single environment.

    Four episodes are run to completion through the wrapper.
    After it is closed, the 'r' and 'l' columns of the log
    are compared with the values this test expects.
    """
    with tempfile.TemporaryDirectory() as dirpath:
        log_file = os.path.join(dirpath, 'monitor.csv')
        env = LoggedEnv(SimpleEnv(2, (3, ), 'float32'), log_file)
        for _ in range(4):
            env.reset()
            # Element 2 of the step result is treated as the
            # episode-finished flag: step until it is truthy.
            while not env.step(env.action_space.sample())[2]:
                pass
        env.close()

        # read_csv opens the path itself, so no separate file
        # handle needs to be held open around it.
        log_contents = pandas.read_csv(log_file)
        assert list(log_contents['r']) == [2] * 4
        assert list(log_contents['l']) == [3] * 4
コード例 #12
0
ファイル: test_rollers.py プロジェクト: decoderkurt/anyrl-py
def test_ep_basic_equivalence(stateful, state_tuple, limits):
    """
    Check that EpisodeRoller over a one-environment batch
    produces the same rollouts as BasicRoller.

    stateful and state_tuple are forwarded to SimpleModel;
    limits is passed as keyword arguments to both rollers.
    """
    def env_fn():
        return SimpleEnv(3, (4, 5), 'uint8')

    reference_env = env_fn()
    model = SimpleModel(reference_env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    expected = BasicRoller(reference_env, model, **limits).rollouts()

    ep_roller = EpisodeRoller(batched_gym_env([env_fn], sync=True),
                              model,
                              **limits)
    _compare_rollout_batch(ep_roller.rollouts(), expected)
コード例 #13
0
    def _test_basic_equivalence_case(self, stateful, state_tuple,
                                     **roller_kwargs):
        """
        Compare EpisodeRoller on a one-environment batch with
        BasicRoller for one combination of settings.

        stateful and state_tuple are forwarded to SimpleModel;
        roller_kwargs are forwarded to both rollers.
        """
        def env_fn():
            return SimpleEnv(3, (4, 5), 'uint8')

        reference_env = env_fn()
        model = SimpleModel(reference_env.action_space.low.shape,
                            stateful=stateful,
                            state_tuple=state_tuple)
        expected = BasicRoller(reference_env, model,
                               **roller_kwargs).rollouts()

        ep_roller = EpisodeRoller(batched_gym_env([env_fn], sync=True),
                                  model,
                                  **roller_kwargs)
        _compare_rollout_batch(self, ep_roller.rollouts(), expected)
コード例 #14
0
    def _test_basic_equivalence_case(self, stateful, state_tuple):
        """
        Compare TruncatedRoller on a one-environment batch with
        BasicRoller for one combination of model settings.

        The TruncatedRoller is given the total number of
        timesteps in the BasicRoller rollouts as its third
        argument.
        """
        def env_fn():
            return SimpleEnv(3, (4, 5), 'uint8')

        reference_env = env_fn()
        model = SimpleModel(reference_env.action_space.low.shape,
                            stateful=stateful,
                            state_tuple=state_tuple)
        expected = BasicRoller(reference_env, model, min_episodes=5).rollouts()
        total_timesteps = sum(rollout.num_steps for rollout in expected)

        trunc_roller = TruncatedRoller(batched_gym_env([env_fn], sync=True),
                                       model,
                                       total_timesteps)
        _compare_rollout_batch(self, trunc_roller.rollouts(), expected)
コード例 #15
0
 def test_rewards(self):
     """
     Check the reward masking of RL2Env with num_eps=5 and
     warmup_eps=-2.

     Rewards must be zero until element 3 of the observation
     has been truthy three times, and non-zero on every step
     after that until the wrapped episode reports done.
     """
     real_env = SimpleEnv(1, (3, 4), 'uint8')
     first_action = real_env.action_space.sample()
     env = RL2Env(real_env, first_action, num_eps=5, warmup_eps=-2)
     env.reset()

     # Phase 1: rewards are expected to be masked to zero.
     finished = 0
     while finished < 3:
         obs, rew, _, _ = env.step(env.action_space.sample())
         self.assertEqual(rew, 0)
         if obs[3]:
             finished += 1

     # Phase 2: rewards are expected to be non-zero until done.
     done = False
     while not done:
         _, rew, done, _ = env.step(env.action_space.sample())
         self.assertNotEqual(rew, 0)
コード例 #16
0
ファイル: test_rollers.py プロジェクト: decoderkurt/anyrl-py
def test_trunc_basic_equivalence(stateful, state_tuple):
    """
    Check that TruncatedRoller over a one-environment batch
    matches BasicRoller when episodes end cleanly.

    The TruncatedRoller is given the total number of
    timesteps in the BasicRoller rollouts as its third
    argument.
    """
    def env_fn():
        return SimpleEnv(3, (4, 5), 'uint8')

    reference_env = env_fn()
    model = SimpleModel(reference_env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    expected = BasicRoller(reference_env, model, min_episodes=5).rollouts()
    total_timesteps = sum(rollout.num_steps for rollout in expected)

    trunc_roller = TruncatedRoller(batched_gym_env([env_fn], sync=True),
                                   model,
                                   total_timesteps)
    _compare_rollout_batch(trunc_roller.rollouts(), expected)
コード例 #17
0
 def test_single_env(self):
     """
     Test monitoring for a single environment.

     Four episodes are run to completion through LoggedEnv.
     After it is closed, the 'r' and 'l' columns of the log
     are compared with the values this test expects. The
     temporary directory is removed whether or not the test
     passes.
     """
     dirpath = tempfile.mkdtemp()
     try:
         log_file = os.path.join(dirpath, 'monitor.csv')
         env = LoggedEnv(SimpleEnv(2, (3,), 'float32'), log_file)
         for _ in range(4):
             env.reset()
             # Element 2 of the step result is treated as the
             # episode-finished flag: step until it is truthy.
             while not env.step(env.action_space.sample())[2]:
                 pass
         env.close()

         # read_csv opens the path itself, so no separate file
         # handle needs to be held open around it.
         log_contents = pandas.read_csv(log_file)
         self.assertEqual(list(log_contents['r']), [2] * 4)
         self.assertEqual(list(log_contents['l']), [3] * 4)
     finally:
         shutil.rmtree(dirpath)
コード例 #18
0
 def test_num_eps(self):
     """
     Check that an RL2Env meta-episode with num_eps=3 reports
     done exactly when the third sub-episode finishes.

     Element 3 of the observation is used to count finished
     sub-episodes.
     """
     real_env = SimpleEnv(1, (3, 4), 'uint8')
     first_action = real_env.action_space.sample()
     env = RL2Env(real_env, first_action, num_eps=3, warmup_eps=1)
     env.reset()

     finished = 0
     while finished < 3:
         obs, _, done, _ = env.step(env.action_space.sample())
         if obs[3]:
             finished += 1
         if finished != 3:
             # The meta-episode must still be running.
             self.assertFalse(done)
             continue
         # The third sub-episode just ended, so the meta-episode must too.
         self.assertTrue(done)
         break
コード例 #19
0
 def test_multiple_batches(self):
     """
     Check that repeated calls to EpisodeRoller.rollouts()
     keep producing batches equal to the first one.
     """
     def env_fn():
         return SimpleEnv(3, (4, 5), 'uint8')

     # A throwaway environment is only needed for its action shape.
     probe_env = env_fn()
     try:
         model = SimpleModel(probe_env.action_space.low.shape)
     finally:
         probe_env.close()

     batched_env = batched_gym_env([env_fn], sync=True)
     try:
         ep_roller = EpisodeRoller(batched_env, model,
                                   min_episodes=5, min_steps=7)
         baseline = ep_roller.rollouts()
         for _ in range(3):
             repeat = ep_roller.rollouts()
             _compare_rollout_batch(self, baseline, repeat)
     finally:
         batched_env.close()
コード例 #20
0
def test_rl2_num_eps():
    """
    Check that an RL^2 meta-episode with num_eps=3 reports
    done exactly when the third sub-episode finishes.

    Element 3 of the observation is used to count finished
    sub-episodes.
    """
    real_env = SimpleEnv(1, (3, 4), 'uint8')
    first_action = real_env.action_space.sample()
    env = RL2Env(real_env, first_action, num_eps=3, warmup_eps=1)
    env.reset()

    finished = 0
    while finished < 3:
        obs, _, done, _ = env.step(env.action_space.sample())
        if obs[3]:
            finished += 1
        if finished != 3:
            # The meta-episode must still be running.
            assert not done
            continue
        # The third sub-episode just ended, so the meta-episode must too.
        assert done
        break
コード例 #21
0
ファイル: test_rollers.py プロジェクト: decoderkurt/anyrl-py
def test_ep_batches(stateful, state_tuple, limits):
    """
    Check that EpisodeRoller over a batch of environments
    produces the same rollouts as BasicRoller.

    stateful and state_tuple are forwarded to SimpleModel;
    limits is passed as keyword arguments to EpisodeRoller.
    """
    def env_fn():
        return SimpleEnv(3, (4, 5), 'uint8')

    model = SimpleModel((4, 5), stateful=stateful, state_tuple=state_tuple)

    batched_env = batched_gym_env([env_fn] * 21, num_sub_batches=7, sync=True)
    ep_roller = EpisodeRoller(batched_env, model, **limits)
    actual = ep_roller.rollouts()

    # The roller must honour both of its minimums.
    step_count = sum(rollout.num_steps for rollout in actual)
    assert len(actual) >= ep_roller.min_episodes
    assert step_count >= ep_roller.min_steps

    if 'min_steps' not in limits:
        # Without a step minimum, the episode count is expected to be exact.
        exact_count = ep_roller.min_episodes + batched_env.num_envs - 1
        assert len(actual) == exact_count

    # The reference roller gathers as many episodes as the batched one did.
    expected = BasicRoller(env_fn(), model,
                           min_episodes=len(actual)).rollouts()

    _compare_rollout_batch(actual, expected)
コード例 #22
0
def test_trunc_drop_states():
    """
    Check TruncatedRoller with drop_states=True.

    Reference rollouts are gathered without drop_states and
    then have the 'states' entry of every model output set to
    None. A second roller over the same batched environment,
    created with drop_states=True, must produce matching
    rollouts.
    """
    env_fns = [
        lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)
    ]
    model = SimpleModel((5, 3), stateful=True, state_tuple=True)
    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)

    # Reference: the default roller, with states blanked out afterwards.
    reference_roller = TruncatedRoller(batched_env, model, 17)
    expected_rollouts = []
    for _ in range(3):
        expected_rollouts += reference_roller.rollouts()
    for rollout in expected_rollouts:
        for model_out in rollout.model_outs:
            model_out['states'] = None

    # Under test: a roller asked to drop states itself.
    dropping_roller = TruncatedRoller(batched_env, model, 17,
                                      drop_states=True)
    actual_rollouts = []
    for _ in range(3):
        actual_rollouts += dropping_roller.rollouts()

    _compare_rollout_batch(actual_rollouts, expected_rollouts)
コード例 #23
0
def test_trunc_batches(stateful, state_tuple):
    """
    Check that TruncatedRoller yields the same rollouts
    whether environments are rolled out one at a time or
    together in one batch.
    """
    def gather(roller, rounds=3):
        # Concatenate the rollouts of several consecutive calls.
        collected = []
        for _ in range(rounds):
            collected.extend(roller.rollouts())
        return collected

    env_fns = [
        lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)
    ]
    model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)

    # One roller per environment, each over a batch of size one.
    unbatched_rollouts = []
    for make_env in env_fns:
        solo_env = batched_gym_env([make_env], sync=True)
        unbatched_rollouts.extend(gather(TruncatedRoller(solo_env, model, 17)))

    # A single roller over all environments at once.
    full_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
    batched_rollouts = gather(TruncatedRoller(full_env, model, 17))

    _compare_rollout_batch(unbatched_rollouts, batched_rollouts, ordered=False)
コード例 #24
0
 def env_fn():
     """
     Build a new SimpleEnv(3, (4, 5), 'uint8').
     """
     env = SimpleEnv(3, (4, 5), 'uint8')
     return env
コード例 #25
0
ファイル: test_env.py プロジェクト: unixpickle/anyrl-py
 def make_fn(seed):
     """
     Return a zero-argument callable that builds a SimpleEnv
     from the given seed together with SHAPE and dtype from
     the surrounding scope.
     """
     def construct():
         return SimpleEnv(seed, SHAPE, dtype)
     return construct
コード例 #26
0
 def env_fn():
     """
     Build a new SimpleEnv(7, (5, 3), 'uint8').
     """
     env = SimpleEnv(7, (5, 3), 'uint8')
     return env