Example #1
0
def test_logged_single_env():
    """
    Test LoggedEnv for a single environment.

    Runs four episodes against a SimpleEnv wrapped in a LoggedEnv, closes
    the environment, and then checks the CSV log at ``log_file``: the test
    expects four rows, each with ``r == 2`` and ``l == 3``.
    """
    with tempfile.TemporaryDirectory() as dirpath:
        log_file = os.path.join(dirpath, 'monitor.csv')
        env = LoggedEnv(SimpleEnv(2, (3, ), 'float32'), log_file)
        try:
            for _ in range(4):
                env.reset()
                # Index 2 of the step result is used as the end-of-episode
                # flag; keep stepping with random actions until it is truthy.
                while not env.step(env.action_space.sample())[2]:
                    pass
        finally:
            # Close even if an episode raises, so the environment is not
            # left open (presumably holding the log file) while the
            # temporary directory is being removed.
            env.close()

        # read_csv opens the path itself; no separate open() is required.
        log_contents = pandas.read_csv(log_file)
        assert list(log_contents['r']) == [2] * 4
        assert list(log_contents['l']) == [3] * 4
Example #2
0
 def test_single_env(self):
     """
     Test monitoring for a single environment.

     Runs four episodes against a SimpleEnv wrapped in a LoggedEnv, closes
     the environment, and then checks the CSV log at ``log_file``: the test
     expects four rows, each with ``r == 2`` and ``l == 3``.
     """
     dirpath = tempfile.mkdtemp()
     try:
         log_file = os.path.join(dirpath, 'monitor.csv')
         env = LoggedEnv(SimpleEnv(2, (3,), 'float32'), log_file)
         try:
             for _ in range(4):
                 env.reset()
                 # Index 2 of the step result is used as the end-of-episode
                 # flag; keep stepping with random actions until it is truthy.
                 while not env.step(env.action_space.sample())[2]:
                     pass
         finally:
             # Close before the directory is removed, even on failure, so
             # rmtree below does not run against a still-open environment.
             env.close()

         # read_csv opens the path itself; no separate open() is required.
         log_contents = pandas.read_csv(log_file)
         self.assertEqual(list(log_contents['r']), [2] * 4)
         self.assertEqual(list(log_contents['l']), [3] * 4)
     finally:
         shutil.rmtree(dirpath)
Example #3
0
    def test_multi_env(self):
        """
        Test monitoring for concurrent environments.

        Two LoggedEnv instances (created with ``use_locking=True``) share one
        log file. They are stepped alternately for 13 rounds, resetting
        whenever an episode ends, and the combined CSV log is then compared
        against the expected interleaved ``r`` and ``l`` columns.
        """
        dirpath = tempfile.mkdtemp()
        try:
            log_file = os.path.join(dirpath, 'monitor.csv')
            # Track every environment that was successfully created so that
            # all of them get closed, even if a later constructor or a step
            # raises.
            envs = []
            try:
                envs.append(LoggedEnv(SimpleEnv(2, (3,), 'float32'), log_file, use_locking=True))
                envs.append(LoggedEnv(SimpleEnv(3, (3,), 'float32'), log_file, use_locking=True))

                for env in envs:
                    env.reset()
                for _ in range(13):
                    for env in envs:
                        # Index 2 of the step result is used as the
                        # end-of-episode flag.
                        if env.step(env.action_space.sample())[2]:
                            env.reset()
            finally:
                # Closed in creation order, matching the original sequence.
                for opened_env in envs:
                    opened_env.close()

            # read_csv opens the path itself; no separate open() is required.
            log_contents = pandas.read_csv(log_file)
            self.assertEqual(list(log_contents['r']), [2, 2.5, 2, 2.5, 2, 2, 2.5])
            self.assertEqual(list(log_contents['l']), [3, 4, 3, 4, 3, 3, 4])
        finally:
            shutil.rmtree(dirpath)
Example #4
0
def test_multi_env():
    """
    Test monitoring for concurrent environments.

    Two LoggedEnv instances (created with ``use_locking=True``) share one
    log file. They are stepped alternately for 13 rounds, resetting whenever
    an episode ends, and the combined CSV log is then compared against the
    expected interleaved ``r`` and ``l`` columns.
    """
    with tempfile.TemporaryDirectory() as dirpath:
        log_file = os.path.join(dirpath, 'monitor.csv')
        # Track every environment that was successfully created so that all
        # of them get closed, even if a later constructor or a step raises.
        envs = []
        try:
            envs.append(LoggedEnv(SimpleEnv(2, (3, ), 'float32'),
                                  log_file,
                                  use_locking=True))
            envs.append(LoggedEnv(SimpleEnv(3, (3, ), 'float32'),
                                  log_file,
                                  use_locking=True))

            for env in envs:
                env.reset()
            for _ in range(13):
                for env in envs:
                    # Index 2 of the step result is used as the
                    # end-of-episode flag.
                    if env.step(env.action_space.sample())[2]:
                        env.reset()
        finally:
            # Closed in creation order, matching the original sequence.
            for opened_env in envs:
                opened_env.close()

        # read_csv opens the path itself; no separate open() is required.
        log_contents = pandas.read_csv(log_file)
        assert list(log_contents['r']) == [2, 2.5, 2, 2.5, 2, 2, 2.5]
        assert list(log_contents['l']) == [3, 4, 3, 4, 3, 3, 4]