def load(reward_scale, seed):
    """Build the cartpole experiment environment with reward scaling applied.

    Args:
      reward_scale: Forwarded to `wrappers.RewardScale` as `reward_scale`.
      seed: Passed to both `cartpole.Cartpole` and `wrappers.RewardScale`.

    Returns:
      The `wrappers.RewardScale` environment wrapping a `cartpole.Cartpole`,
      with its `bsuite_num_episodes` attribute set to `sweep.NUM_EPISODES`.
    """
    base_env = cartpole.Cartpole(seed=seed)
    scaled_env = wrappers.RewardScale(
        env=base_env,
        reward_scale=reward_scale,
        seed=seed,
    )
    # Attach the sweep's NUM_EPISODES value to the environment object.
    scaled_env.bsuite_num_episodes = sweep.NUM_EPISODES
    return scaled_env
def load(reward_scale, seed):
    """Build a reward-scaled MNIST bandit environment.

    Args:
      reward_scale: Forwarded to `wrappers.RewardScale` as `reward_scale`.
      seed: Passed to both `mnist.MNISTBandit` and `wrappers.RewardScale`.

    Returns:
      The `wrappers.RewardScale` environment wrapping an `mnist.MNISTBandit`,
      with its `bsuite_num_episodes` attribute set to `sweep.NUM_EPISODES`.
    """
    base_env = mnist.MNISTBandit(seed=seed)
    scaled_env = wrappers.RewardScale(
        env=base_env,
        reward_scale=reward_scale,
        seed=seed,
    )
    # Attach the sweep's NUM_EPISODES value to the environment object.
    scaled_env.bsuite_num_episodes = sweep.NUM_EPISODES
    return scaled_env
def load(reward_scale: float, seed: int):
    """Build the mountain_car experiment environment with reward scaling.

    Args:
      reward_scale: Forwarded to `wrappers.RewardScale` as `reward_scale`.
      seed: Passed to both `mountain_car.MountainCar` and
        `wrappers.RewardScale`.

    Returns:
      The `wrappers.RewardScale` environment wrapping a
      `mountain_car.MountainCar`, with its `bsuite_num_episodes` attribute set
      to `sweep.NUM_EPISODES`.
    """
    base_env = mountain_car.MountainCar(seed=seed)
    scaled_env = wrappers.RewardScale(
        env=base_env,
        reward_scale=reward_scale,
        seed=seed,
    )
    # Attach the sweep's NUM_EPISODES value to the environment object.
    scaled_env.bsuite_num_episodes = sweep.NUM_EPISODES
    return scaled_env
def test_unwrap(self):
    """Check that `raw_env` on a stacked wrapper returns the innermost env.

    A fake environment is wrapped in RewardScale, then RewardNoise, then
    Logging; the outermost wrapper's `raw_env` must be the very same object
    as the fake environment that was wrapped first.
    """
    innermost = FakeEnvironment([dm_env.restart([])])

    # Stack the wrappers from the inside out.
    wrapped = wrappers.RewardScale(innermost, reward_scale=1.)
    wrapped = wrappers.RewardNoise(wrapped, noise_scale=1.)
    wrapped = wrappers.Logging(wrapped, logger=None)  # pytype: disable=wrong-arg-types

    recovered = wrapped.raw_env
    # Compare ids so the assertion is about object identity, not equality.
    self.assertEqual(id(innermost), id(recovered))
def main(unused_arg):
    """Run the experiment loop for a PopArtAgent on reward-scaled Catch.

    All settings (seed, reward scale, agent hyperparameters and episode
    counts) are read from `FLAGS`.

    Args:
      unused_arg: Ignored; not referenced in the body.
    """
    base_env = catch.Catch(seed=FLAGS.seed)
    scaled_env = wrappers.RewardScale(base_env, reward_scale=FLAGS.reward_scale)

    # The agent is sized from the wrapped environment's specs.
    obs_spec = scaled_env.observation_spec()
    act_spec = scaled_env.action_spec()
    popart_agent = PopArtAgent(
        observation_spec=obs_spec,
        action_spec=act_spec,
        num_hidden_units=FLAGS.num_hidden_units,
        epsilon=FLAGS.epsilon,
        learning_rate=FLAGS.learning_rate,
        pop_art_step_size=FLAGS.pop_art_step_size,
    )

    transitions = TransitionAccumulator()

    experiment.run_loop(
        agent=popart_agent,
        environment=scaled_env,
        accumulator=transitions,
        seed=FLAGS.seed,
        batch_size=1,
        train_episodes=FLAGS.train_episodes,
        evaluate_every=FLAGS.evaluate_every,
        eval_episodes=FLAGS.eval_episodes,
    )