def test_logged_single_env():
    """
    Test LoggedEnv for a single environment.

    Runs four complete episodes of a SimpleEnv wrapped in a LoggedEnv,
    closes the wrapper, and checks that the CSV log holds one row per
    episode with the expected 'r' and 'l' column values.
    """
    with tempfile.TemporaryDirectory() as dirpath:
        log_file = os.path.join(dirpath, 'monitor.csv')
        env = LoggedEnv(SimpleEnv(2, (3, ), 'float32'), log_file)
        try:
            for _ in range(4):
                env.reset()
                # Index 2 of the step result is used as the "episode
                # finished" flag; keep stepping until it is truthy.
                while not env.step(env.action_space.sample())[2]:
                    pass
        finally:
            # Close even when an episode raises, so the wrapper is shut
            # down before the temporary directory is removed.
            env.close()
        # read_csv opens the path itself; a separate open() is not needed.
        log_contents = pandas.read_csv(log_file)
        # Four episodes are expected, each logged with r == 2 and l == 3.
        assert list(log_contents['r']) == [2] * 4
        assert list(log_contents['l']) == [3] * 4
def test_single_env(self):
    """
    Test monitoring for a single environment.

    Runs four complete episodes of a SimpleEnv wrapped in a LoggedEnv,
    closes the wrapper, and checks that the CSV log holds one row per
    episode with the expected 'r' and 'l' column values.
    """
    # TemporaryDirectory removes the directory on exit, including when an
    # assertion fails, replacing the manual mkdtemp/rmtree pair.
    with tempfile.TemporaryDirectory() as dirpath:
        log_file = os.path.join(dirpath, 'monitor.csv')
        env = LoggedEnv(SimpleEnv(2, (3,), 'float32'), log_file)
        try:
            for _ in range(4):
                env.reset()
                # Index 2 of the step result is used as the "episode
                # finished" flag; keep stepping until it is truthy.
                while not env.step(env.action_space.sample())[2]:
                    pass
        finally:
            # Close even when an episode raises, so the wrapper is shut
            # down before the temporary directory is removed.
            env.close()
        # read_csv opens the path itself; a separate open() is not needed.
        log_contents = pandas.read_csv(log_file)
        # Four episodes are expected, each logged with r == 2 and l == 3.
        self.assertEqual(list(log_contents['r']), [2] * 4)
        self.assertEqual(list(log_contents['l']), [3] * 4)
def test_multi_env(self):
    """
    Test monitoring for concurrent environments.

    Two LoggedEnv wrappers write to the same log file with
    use_locking=True. They are stepped alternately for 13 rounds, each
    being reset whenever its episode finishes, and the combined log is
    then compared against the expected interleaved 'r' and 'l' values.
    """
    # TemporaryDirectory removes the directory on exit, including when an
    # assertion fails, replacing the manual mkdtemp/rmtree pair.
    with tempfile.TemporaryDirectory() as dirpath:
        log_file = os.path.join(dirpath, 'monitor.csv')
        # Track every wrapper that was successfully created so that all of
        # them get closed, in creation order, even if a later step raises.
        envs = []
        try:
            envs.append(LoggedEnv(SimpleEnv(2, (3,), 'float32'), log_file,
                                  use_locking=True))
            envs.append(LoggedEnv(SimpleEnv(3, (3,), 'float32'), log_file,
                                  use_locking=True))
            for env in envs:
                env.reset()
            for _ in range(13):
                for env in envs:
                    # Index 2 of the step result is used as the "episode
                    # finished" flag; start a new episode when it is set.
                    if env.step(env.action_space.sample())[2]:
                        env.reset()
        finally:
            for env in envs:
                env.close()
        # read_csv opens the path itself; a separate open() is not needed.
        log_contents = pandas.read_csv(log_file)
        # Rows are expected in the order the episodes finished across both
        # environments.
        self.assertEqual(list(log_contents['r']),
                         [2, 2.5, 2, 2.5, 2, 2, 2.5])
        self.assertEqual(list(log_contents['l']), [3, 4, 3, 4, 3, 3, 4])
def test_multi_env():
    """
    Test monitoring for concurrent environments.

    Two LoggedEnv wrappers write to the same log file with
    use_locking=True. They are stepped alternately for 13 rounds, each
    being reset whenever its episode finishes, and the combined log is
    then compared against the expected interleaved 'r' and 'l' values.
    """
    with tempfile.TemporaryDirectory() as dirpath:
        log_file = os.path.join(dirpath, 'monitor.csv')
        # Track every wrapper that was successfully created so that all of
        # them get closed, in creation order, even if a later step raises.
        envs = []
        try:
            envs.append(LoggedEnv(SimpleEnv(2, (3, ), 'float32'), log_file,
                                  use_locking=True))
            envs.append(LoggedEnv(SimpleEnv(3, (3, ), 'float32'), log_file,
                                  use_locking=True))
            for env in envs:
                env.reset()
            for _ in range(13):
                for env in envs:
                    # Index 2 of the step result is used as the "episode
                    # finished" flag; start a new episode when it is set.
                    if env.step(env.action_space.sample())[2]:
                        env.reset()
        finally:
            for env in envs:
                env.close()
        # read_csv opens the path itself; a separate open() is not needed.
        log_contents = pandas.read_csv(log_file)
        # Rows are expected in the order the episodes finished across both
        # environments.
        assert list(log_contents['r']) == [2, 2.5, 2, 2.5, 2, 2, 2.5]
        assert list(log_contents['l']) == [3, 4, 3, 4, 3, 3, 4]