def wrap_environment(
    env: dm_env.Environment,
    bsuite_id: Text,
    results_dir: Text,
    overwrite: bool = False,
    log_by_step: bool = False,
) -> dm_env.Environment:
    """Returns a wrapped environment that logs using CSV.

    Args:
      env: Environment to wrap.
      bsuite_id: Passed to `Logger` as its first positional argument;
        presumably identifies the experiment being logged.
      results_dir: Passed to `Logger`; presumably the directory that receives
        the CSV output (verify against `Logger`).
      overwrite: Passed to `Logger`; presumably whether existing results may
        be overwritten (verify against `Logger`).
      log_by_step: Forwarded unchanged to `wrappers.Logging`.

    Returns:
      `env` wrapped in `wrappers.Logging`, writing through a CSV `Logger`.
    """
    return wrappers.Logging(
        env,
        Logger(bsuite_id, results_dir, overwrite),
        log_by_step=log_by_step,
    )
def wrap_environment(
    env: dm_env.Environment,
    pretty_print: bool = True,
    log_every: bool = False,
    log_by_step: bool = False,
) -> dm_env.Environment:
    """Returns a wrapped environment that logs to terminal.

    Args:
      env: Environment to wrap.
      pretty_print: Passed straight through to the terminal `Logger`.
      log_every: Forwarded unchanged to `wrappers.Logging`.
      log_by_step: Forwarded unchanged to `wrappers.Logging`.

    Returns:
      `env` wrapped in `wrappers.Logging`, writing through a terminal `Logger`.
    """
    terminal_logger = Logger(pretty_print)
    return wrappers.Logging(
        env, terminal_logger, log_every=log_every, log_by_step=log_by_step)
def test_unwrap(self):
    """Checks that `raw_env` on a stack of wrappers returns the base env.

    The base environment is wrapped three times (RewardScale, RewardNoise,
    Logging). The test asserts that `raw_env` on the outermost wrapper is
    the very same object as the original environment.
    """
    raw_env = FakeEnvironment([dm_env.restart([])])
    scale_env = wrappers.RewardScale(raw_env, reward_scale=1.)
    noise_env = wrappers.RewardNoise(scale_env, noise_scale=1.)
    logging_env = wrappers.Logging(noise_env, logger=None)  # pytype: disable=wrong-arg-types
    unwrapped = logging_env.raw_env
    # assertIs tests object identity directly and reports both objects on
    # failure; comparing id() integers would only report two opaque numbers.
    self.assertIs(raw_env, unwrapped)
def wrap_environment(
    env: environments.Environment,
    db_path: str,
    experiment_name: str,
    setting_index: int,
    log_by_step: bool = False,
) -> dm_env.Environment:
    """Returns a wrapped environment that logs using SQLite.

    Args:
      env: Environment to wrap.
      db_path: Passed to `Logger`; presumably the location of the SQLite
        database (verify against `Logger`).
      experiment_name: Passed to `Logger`.
      setting_index: Passed to `Logger`; presumably selects which setting of
        the experiment is being logged (verify against `Logger`).
      log_by_step: Forwarded unchanged to `wrappers.Logging`.

    Returns:
      `env` wrapped in `wrappers.Logging`, writing through a SQLite `Logger`.
    """
    return wrappers.Logging(
        env,
        Logger(db_path, experiment_name, setting_index),
        log_by_step=log_by_step,
    )
def wrap_environment(
    env: environments.Environment,
    pretty_print: bool = True,
    log_every: bool = False,
    log_by_step: bool = False,
) -> dm_env.Environment:
    """Returns a wrapped environment that logs to terminal.

    As a side effect, this attaches an absl `PythonHandler` to the Python
    root logger. It does so only if no such handler is present yet, so that
    log records are shown on STDERR.

    Args:
      env: Environment to wrap.
      pretty_print: Passed to `Logger`.
      log_every: Forwarded unchanged to `wrappers.Logging`.
      log_by_step: Forwarded unchanged to `wrappers.Logging`.

    Returns:
      `env` wrapped in `wrappers.Logging`, writing through a `Logger`
      constructed with `absl_logging=True`.
    """
    # Set logging up to show up in STDERR. Install the handler only once. If
    # it were added unconditionally, every call (e.g. wrapping several
    # environments in one process) would stack another handler on the root
    # logger and each record would be printed once per call.
    root_logger = std_logging.getLogger()
    if not any(isinstance(handler, logging.PythonHandler)
               for handler in root_logger.handlers):
        root_logger.addHandler(logging.PythonHandler())
    logger = Logger(pretty_print, absl_logging=True)
    return wrappers.Logging(
        env, logger, log_by_step=log_by_step, log_every=log_every)
def test_wrapper(self):
    """Tests that the wrapper computes and logs the correct data.

    A fake environment replays a fixed episode several times through
    `wrappers.Logging` with `log_every=True`. The test expects the logger's
    `write` to have been called, in order, with the cumulative and
    per-episode statistics for each episode.
    """
    mock_logger = mock.MagicMock()
    mock_logger.write = mock.MagicMock()

    # The fake environment cycles through these time steps: a restart
    # followed by three transitions whose rewards sum to 6.
    episode_timesteps = [
        dm_env.restart([]),
        dm_env.transition(1, []),
        dm_env.transition(2, []),
        dm_env.termination(3, []),
    ]
    episode_return = 6
    env = wrappers.Logging(
        env=FakeEnvironment(episode_timesteps),
        logger=mock_logger,
        log_every=True,
    )

    num_episodes = 5
    for _ in range(num_episodes):
        timestep = env.reset()
        while not timestep.last():
            timestep = env.step(action=0)

    # Only transitions count towards the episode length, so the initial
    # restart step is excluded.
    episode_length = len(episode_timesteps) - 1

    expected_calls = [
        mock.call({
            'steps': episode_length * episode,
            'episode': episode,
            'total_return': episode_return * episode,
            'episode_len': episode_length,
            'episode_return': episode_return,
        })
        for episode in range(1, num_episodes + 1)
    ]
    mock_logger.write.assert_has_calls(expected_calls)