Example #1
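The function below is shown without its imports; a likely set, assuming the standard Acme JAX packages (helpers is the local Atari helper module shipped alongside Acme's baselines examples, so its exact location is an assumption):

from absl import flags

from acme import specs
from acme.agents.jax import dqn
from acme.agents.jax.dqn import losses
from acme.jax import experiments

import helpers  # Local module providing Atari environment/network factories.

FLAGS = flags.FLAGS
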
def build_experiment_config():
    """Builds MDQN experiment config which can be executed in different ways."""
    # Create an environment, grab the spec, and use it to create networks.
    env_name = FLAGS.env_name

    def env_factory(seed):
        del seed
        return helpers.make_atari_environment(level=env_name,
                                              sticky_actions=True,
                                              zero_discount_on_life_loss=False)

    environment_spec = specs.make_environment_spec(env_factory(0))

    # Create network.
    network = helpers.make_dqn_atari_network(environment_spec)

    # Construct the agent.
    config = dqn.DQNConfig(discount=0.99,
                           learning_rate=5e-5,
                           n_step=1,
                           epsilon=0.01,
                           target_update_period=2000,
                           min_replay_size=20_000,
                           max_replay_size=1_000_000,
                           samples_per_insert=8,
                           batch_size=32)
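    # Munchausen Q-learning adds a scaled log-policy bonus,
    # munchausen_coefficient * entropy_temperature * log pi(a|s), to the
    # (reward-clipped, Huber) TD target of standard Q-learning.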
    loss_fn = losses.MunchausenQLearning(discount=config.discount,
                                         max_abs_reward=1.,
                                         huber_loss_parameter=1.,
                                         entropy_temperature=0.03,
                                         munchausen_coefficient=0.9)

    dqn_builder = dqn.DQNBuilder(config, loss_fn=loss_fn)

    return experiments.ExperimentConfig(builder=dqn_builder,
                                        environment_factory=env_factory,
                                        network_factory=lambda spec: network,
                                        evaluator_factories=[],
                                        seed=FLAGS.seed,
                                        max_num_actor_steps=FLAGS.num_steps)
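
The returned ExperimentConfig is then handed to Acme's experiment runner. A minimal launch sketch, assuming an absl entry point and that the flags used above (env_name, seed, num_steps) are defined elsewhere:

from absl import app


def main(_):
    # Run the actor/learner loop in a single process.
    experiments.run_experiment(experiment=build_experiment_config())


if __name__ == '__main__':
    app.run(main)
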
Example #2
def main(_):
    # Create an environment, grab the spec.
    environment = utils.make_environment(task=FLAGS.env_name)
    aqua_config = config.AquademConfig()
    spec = specs.make_environment_spec(environment)
    discretized_spec = aquadem_builder.discretize_spec(spec,
                                                       aqua_config.num_actions)

    # Create AQuaDem builder.
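    # AQuaDem learns a small set of candidate continuous actions from the
    # demonstrations and trains a discrete-action agent (here Munchausen DQN)
    # to choose among them.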
    loss_fn = dqn.losses.MunchausenQLearning(max_abs_reward=100.)
    dqn_config = dqn.DQNConfig(min_replay_size=1000,
                               n_step=3,
                               num_sgd_steps_per_step=8,
                               learning_rate=1e-4,
                               samples_per_insert=256)
    rl_agent = dqn.DQNBuilder(config=dqn_config, loss_fn=loss_fn)
    make_demonstrations = utils.get_make_demonstrations_fn(
        FLAGS.env_name, FLAGS.num_demonstrations, FLAGS.seed)
    builder = aquadem_builder.AquademBuilder(
        rl_agent=rl_agent,
        config=aqua_config,
        make_demonstrations=make_demonstrations)

    # Create networks.
    q_network = aquadem_networks.make_q_network(spec=discretized_spec)
    dqn_networks = dqn.DQNNetworks(
        policy_network=networks_lib.non_stochastic_network_to_typed(q_network))
    networks = aquadem_networks.make_action_candidates_network(
        spec=spec,
        num_actions=aqua_config.num_actions,
        discrete_rl_networks=dqn_networks)
    exploration_epsilon = 0.01
    discrete_policy = dqn.default_behavior_policy(dqn_networks,
                                                  exploration_epsilon)
    behavior_policy = aquadem_builder.get_aquadem_policy(
        discrete_policy, networks)

    # Create the environment loop used for training.
    agent = local_layout.LocalLayout(seed=FLAGS.seed,
                                     environment_spec=spec,
                                     builder=builder,
                                     networks=networks,
                                     policy_network=behavior_policy,
                                     batch_size=dqn_config.batch_size *
                                     dqn_config.num_sgd_steps_per_step)

    train_logger = loggers.CSVLogger(FLAGS.workdir, label='train')
    train_loop = acme.EnvironmentLoop(environment, agent, logger=train_logger)

    # Create the evaluation actor and loop.
    eval_policy = dqn.default_behavior_policy(dqn_networks, 0.)
    eval_policy = aquadem_builder.get_aquadem_policy(eval_policy, networks)
    eval_actor = builder.make_actor(random_key=jax.random.PRNGKey(FLAGS.seed),
                                    policy=eval_policy,
                                    environment_spec=spec,
                                    variable_source=agent)
    eval_env = utils.make_environment(task=FLAGS.env_name, evaluation=True)

    eval_logger = loggers.CSVLogger(FLAGS.workdir, label='eval')
    eval_loop = acme.EnvironmentLoop(eval_env, eval_actor, logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=10)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=10)
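
The main function above reads several absl flags; a minimal entry-point sketch follows, where the default values are illustrative assumptions rather than values taken from the example:

from absl import app
from absl import flags

flags.DEFINE_string('env_name', 'door-human-v0', 'Environment/task name.')
flags.DEFINE_integer('num_demonstrations', 10, 'Number of demonstration episodes.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_string('workdir', '/tmp/aquadem', 'Directory for CSV logs.')
flags.DEFINE_integer('num_steps', 1_000_000, 'Total training environment steps.')
flags.DEFINE_integer('eval_every', 50_000, 'Training steps between evaluations.')
FLAGS = flags.FLAGS


if __name__ == '__main__':
    app.run(main)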