# Imports for the MDQN example. `helpers` is the example-local module that
# provides the Atari environment and network factories; the remaining modules
# come from Acme and absl.
from absl import flags
from acme import specs
from acme.agents.jax import dqn
from acme.agents.jax.dqn import losses
from acme.jax import experiments
import helpers

FLAGS = flags.FLAGS


def build_experiment_config():
  """Builds MDQN experiment config which can be executed in different ways."""
  # Create an environment, grab the spec, and use it to create networks.
  env_name = FLAGS.env_name

  def env_factory(seed):
    del seed
    return helpers.make_atari_environment(
        level=env_name, sticky_actions=True, zero_discount_on_life_loss=False)

  environment_spec = specs.make_environment_spec(env_factory(0))

  # Create the network.
  network = helpers.make_dqn_atari_network(environment_spec)

  # Construct the agent: a DQN whose loss is replaced by Munchausen Q-learning.
  config = dqn.DQNConfig(
      discount=0.99,
      learning_rate=5e-5,
      n_step=1,
      epsilon=0.01,
      target_update_period=2000,
      min_replay_size=20_000,
      max_replay_size=1_000_000,
      samples_per_insert=8,
      batch_size=32)
  loss_fn = losses.MunchausenQLearning(
      discount=config.discount,
      max_abs_reward=1.,
      huber_loss_parameter=1.,
      entropy_temperature=0.03,
      munchausen_coefficient=0.9)
  dqn_builder = dqn.DQNBuilder(config, loss_fn=loss_fn)

  return experiments.ExperimentConfig(
      builder=dqn_builder,
      environment_factory=env_factory,
      network_factory=lambda spec: network,
      evaluator_factories=[],
      seed=FLAGS.seed,
      max_num_actor_steps=FLAGS.num_steps)
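# One way to execute the config above (per the docstring, it "can be executed
# in different ways"): run it in-process with acme.jax.experiments. The helper
# name `run_mdqn_locally` is hypothetical and not part of the original script.
def run_mdqn_locally():
  experiments.run_experiment(experiment=build_experiment_config())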
# Imports for the AQuaDem example. The `aquadem` module paths below are
# assumed from this example's package layout (builder, config, networks,
# utils); the rest come from Acme and JAX.
import acme
from acme import specs
from acme.agents.jax import dqn
from acme.jax import networks as networks_lib
from acme.jax.layouts import local_layout
from acme.utils import loggers
import jax

from aquadem import builder as aquadem_builder
from aquadem import config
from aquadem import networks as aquadem_networks
from aquadem import utils


def main(_):
  """Trains an AQuaDem agent and periodically evaluates it."""
  # Create an environment and grab its spec; AQuaDem additionally needs an
  # action spec discretized into `num_actions` candidate actions.
  environment = utils.make_environment(task=FLAGS.env_name)
  aqua_config = config.AquademConfig()
  spec = specs.make_environment_spec(environment)
  discretized_spec = aquadem_builder.discretize_spec(spec,
                                                     aqua_config.num_actions)

  # Create the AQuaDem builder, wrapping a discrete Munchausen-DQN agent.
  loss_fn = dqn.losses.MunchausenQLearning(max_abs_reward=100.)
  dqn_config = dqn.DQNConfig(
      min_replay_size=1000,
      n_step=3,
      num_sgd_steps_per_step=8,
      learning_rate=1e-4,
      samples_per_insert=256)
  rl_agent = dqn.DQNBuilder(config=dqn_config, loss_fn=loss_fn)
  make_demonstrations = utils.get_make_demonstrations_fn(
      FLAGS.env_name, FLAGS.num_demonstrations, FLAGS.seed)
  builder = aquadem_builder.AquademBuilder(
      rl_agent=rl_agent,
      config=aqua_config,
      make_demonstrations=make_demonstrations)

  # Create networks.
  q_network = aquadem_networks.make_q_network(spec=discretized_spec)
  dqn_networks = dqn.DQNNetworks(
      policy_network=networks_lib.non_stochastic_network_to_typed(q_network))
  networks = aquadem_networks.make_action_candidates_network(
      spec=spec,
      num_actions=aqua_config.num_actions,
      discrete_rl_networks=dqn_networks)
  exploration_epsilon = 0.01
  discrete_policy = dqn.default_behavior_policy(dqn_networks,
                                                exploration_epsilon)
  behavior_policy = aquadem_builder.get_aquadem_policy(discrete_policy,
                                                       networks)

  # Create the environment loop used for training.
  agent = local_layout.LocalLayout(
      seed=FLAGS.seed,
      environment_spec=spec,
      builder=builder,
      networks=networks,
      policy_network=behavior_policy,
      batch_size=dqn_config.batch_size * dqn_config.num_sgd_steps_per_step)
  train_logger = loggers.CSVLogger(FLAGS.workdir, label='train')
  train_loop = acme.EnvironmentLoop(environment, agent, logger=train_logger)

  # Create the evaluation actor and loop; evaluation is greedy (epsilon = 0).
  eval_policy = dqn.default_behavior_policy(dqn_networks, 0.)
  eval_policy = aquadem_builder.get_aquadem_policy(eval_policy, networks)
  eval_actor = builder.make_actor(
      random_key=jax.random.PRNGKey(FLAGS.seed),
      policy=eval_policy,
      environment_spec=spec,
      variable_source=agent)
  eval_env = utils.make_environment(task=FLAGS.env_name, evaluation=True)
  eval_logger = loggers.CSVLogger(FLAGS.workdir, label='eval')
  eval_loop = acme.EnvironmentLoop(eval_env, eval_actor, logger=eval_logger)

  # Alternate between evaluation and training, with a final evaluation pass.
  assert FLAGS.num_steps % FLAGS.eval_every == 0
  for _ in range(FLAGS.num_steps // FLAGS.eval_every):
    eval_loop.run(num_episodes=10)
    train_loop.run(num_steps=FLAGS.eval_every)
  eval_loop.run(num_episodes=10)
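# Hypothetical module scaffolding to make the example runnable end-to-end:
# flag definitions matching the FLAGS references above, plus an absl entry
# point. The default values here are placeholder assumptions, not the original
# script's defaults.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('env_name', 'door-human-v0', 'Environment task name.')
flags.DEFINE_string('workdir', '/tmp/aquadem', 'Directory for CSV logs.')
flags.DEFINE_integer('num_demonstrations', 200, 'Number of demonstrations.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('num_steps', 1_000_000, 'Total number of training steps.')
flags.DEFINE_integer('eval_every', 50_000, 'Training steps between evals.')

if __name__ == '__main__':
  app.run(main)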