def test_td3_fd(self):
    """Smoke-tests the TD3fD agent on a small fake continuous environment.

    There are no assertions: the test passes as long as constructing the
    agent and running the environment loop raises no error.
    """
    # Fake environment with bounded actions and short episodes.
    env = fakes.ContinuousEnvironment(
        episode_length=10,
        action_dim=3,
        observation_dim=5,
        bounded=True,
    )
    env_spec = specs.make_environment_spec(env)

    # Networks are derived from the environment spec.
    networks = td3.make_networks(env_spec)

    # The combined config wraps the plain TD3 settings together with the
    # learning-from-demonstrations settings.
    base_config = td3.TD3Config(batch_size=10, min_replay_size=1)
    demo_config = lfd.LfdConfig(
        initial_insert_count=0,
        demonstration_ratio=0.2,
    )
    combined_config = lfd.TD3fDConfig(
        lfd_config=demo_config,
        td3_config=base_config,
    )

    # The same counter instance is handed to both the agent and the loop.
    shared_counter = counting.Counter()
    agent = lfd.TD3fD(
        spec=env_spec,
        td3_network=networks,
        td3_fd_config=combined_config,
        lfd_iterator_fn=fake_demonstration_iterator,
        seed=0,
        counter=shared_counter,
    )

    # Run a handful of episodes; success means nothing raised.
    env_loop = acme.EnvironmentLoop(env, agent, counter=shared_counter)
    env_loop.run(num_episodes=20)
Exemple #2
0
def main(_):
    """Trains a TD3 learner from a demonstration dataset, evaluating forever.

    The learner only ever consumes batches sampled from the demonstrations
    iterator. After every FLAGS.evaluate_every learner steps the current
    policy is run for FLAGS.evaluation_episodes episodes in the environment.
    The loop has no exit condition, so it runs until the process is stopped.

    Args:
      _: Unused positional argument.
    """
    root_key = jax.random.PRNGKey(FLAGS.seed)
    demo_key, learner_key = jax.random.split(root_key, 2)

    # Environment and its spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Demonstrations: windows of two transitions (size=2, shift=1) are mapped
    # through _add_next_action_extras to attach the next_actions extra.
    raw_transitions = tfds.get_tfds_dataset(
        FLAGS.dataset_name, FLAGS.num_demonstrations)
    transition_pairs = rlds.transformations.batch(
        raw_transitions, size=2, shift=1, drop_remainder=True)
    transitions_with_extras = transition_pairs.map(_add_next_action_extras)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions_with_extras,
        key=demo_key,
        batch_size=FLAGS.batch_size,
    )

    # Networks to optimize.
    networks = td3.make_networks(environment_spec)

    # Both critics use the same learning rate but get separate optimizers.
    policy_optimizer = optax.adam(FLAGS.policy_learning_rate)
    critic_optimizer = optax.adam(FLAGS.critic_learning_rate)
    twin_critic_optimizer = optax.adam(FLAGS.critic_learning_rate)
    learner = td3.TD3Learner(
        networks=networks,
        random_key=learner_key,
        discount=FLAGS.discount,
        iterator=demonstrations,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        twin_critic_optimizer=twin_critic_optimizer,
        use_sarsa_target=FLAGS.use_sarsa_target,
        bc_alpha=FLAGS.bc_alpha,
        num_sgd_steps_per_step=1,
    )

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        """Applies the policy network; the random key is ignored."""
        del key
        return networks.policy_network.apply(params, observation)

    # Evaluation actor reading the 'policy' variables from the learner.
    eval_actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    policy_client = variable_utils.VariableClient(
        learner, 'policy', device='cpu')
    evaluator = actors.GenericActor(
        eval_actor_core, root_key, policy_client, backend='cpu')

    eval_logger = loggers.TerminalLogger('evaluation', time_delta=0.)
    eval_loop = acme.EnvironmentLoop(
        environment=environment,
        actor=evaluator,
        logger=eval_logger,
    )

    # Alternate between a block of learner steps and an evaluation run.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Exemple #3
0
def build_experiment_config():
    """Assembles the TD3 experiment config for the task in FLAGS.env_name.

    FLAGS.env_name is split on its first ':' into a suite and a task, which
    are forwarded to helpers.make_environment by the environment factory.

    Returns:
      An experiments.ExperimentConfig holding the TD3 builder, the
      environment and network factories, the seed from FLAGS.seed and the
      actor-step budget from FLAGS.num_steps.
    """
    suite, task = FLAGS.env_name.split(':', 1)

    def network_factory(spec):
        """Creates TD3 networks with three hidden layers of 256 units."""
        return td3.make_networks(spec, hidden_layer_sizes=(256, 256, 256))

    def environment_factory(seed):
        """Creates the environment; the seed argument is not used."""
        del seed
        return helpers.make_environment(suite, task)

    # Agent builder with matching policy and critic learning rates.
    agent_config = td3.TD3Config(
        policy_learning_rate=3e-4,
        critic_learning_rate=3e-4,
    )
    builder = td3.TD3Builder(agent_config)

    return experiments.ExperimentConfig(
        builder=builder,
        environment_factory=environment_factory,
        network_factory=network_factory,
        seed=FLAGS.seed,
        max_num_actor_steps=FLAGS.num_steps,
    )
Exemple #4
0
def main(_):
    """Trains a TD3 agent, evaluating before each training segment and at the end.

    FLAGS.num_steps must be a multiple of FLAGS.eval_every. Training runs in
    segments of FLAGS.eval_every steps; a five-episode evaluation precedes
    each segment and one more follows the last segment.

    Args:
      _: Unused positional argument.
    """
    # Environment, spec and networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = td3.make_networks(environment_spec)

    # The agent itself.
    agent_config = td3.TD3Config(
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step)
    agent = td3.TD3(
        environment_spec,
        agent_networks,
        config=agent_config,
        seed=FLAGS.seed,
    )

    # Training loop with its own logger and 'train'-prefixed counter.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(
        environment,
        agent,
        counter=counting.Counter(prefix='train'),
        logger=train_logger,
    )

    # Evaluation uses a separate environment and an actor whose behavior
    # policy is built with sigma=0.; it pulls its variables from the agent.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_policy = td3.get_default_behavior_policy(
        agent_networks, environment_spec.actions, sigma=0.)
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=eval_policy,
        variable_source=agent,
    )
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(
        eval_env,
        eval_actor,
        counter=counting.Counter(prefix='eval'),
        logger=eval_logger,
    )

    # Evaluate, train one segment, repeat; finish with a final evaluation.
    num_segments, leftover_steps = divmod(FLAGS.num_steps, FLAGS.eval_every)
    assert leftover_steps == 0
    for _ in range(num_segments):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
Exemple #5
0
    def test_td3(self):
        """Smoke-tests the TD3 agent on a small fake continuous environment.

        There are no assertions: the test passes as long as constructing the
        agent and running the environment loop raises no error.
        """
        # Fake environment with bounded actions and short episodes.
        env = fakes.ContinuousEnvironment(
            episode_length=10,
            action_dim=3,
            observation_dim=5,
            bounded=True,
        )
        env_spec = specs.make_environment_spec(env)

        # Networks are derived from the environment spec.
        networks = td3.make_networks(env_spec)

        agent_config = td3.TD3Config(batch_size=10, min_replay_size=1)

        # The same counter instance is handed to both the agent and the loop.
        shared_counter = counting.Counter()
        agent = td3.TD3(
            spec=env_spec,
            network=networks,
            config=agent_config,
            seed=0,
            counter=shared_counter,
        )

        # Run a couple of episodes; success means nothing raised.
        env_loop = acme.EnvironmentLoop(env, agent, counter=shared_counter)
        env_loop.run(num_episodes=2)
Exemple #6
0
def main(_):
    """Trains a DAC agent (AIL with TD3 as the direct RL algorithm).

    FLAGS.num_steps must be a multiple of FLAGS.eval_every. Training runs in
    segments of FLAGS.eval_every steps; a five-episode evaluation precedes
    each segment and one more follows the last segment.

    Args:
      _: Unused positional argument.
    """
    # Environment and its spec.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # With the local layout the buffer is filled with min_replay_size initial
    # transitions and no tolerance rate is required. To avoid deadlocks, the
    # rate limiting that happens inside the TD3Builder has to be disabled,
    # which is what min_replay_size and samples_per_insert_tolerance_rate
    # below are for.
    td3_config = td3.TD3Config(
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
        min_replay_size=1,
        samples_per_insert_tolerance_rate=float('inf'),
    )
    td3_networks = td3.make_networks(environment_spec)
    if FLAGS.pretrain:
        td3_networks = add_bc_pretraining(td3_networks)

    # One agent step consumes batch_size * num_sgd_steps_per_step samples;
    # the same number is used for the AIL config and the agent batch size.
    samples_per_step = (
        td3_config.batch_size * td3_config.num_sgd_steps_per_step)
    ail_config = ail.AILConfig(direct_rl_batch_size=samples_per_step)
    dac_config = ail.DACConfig(ail_config, td3_config)

    def discriminator(*args, **kwargs) -> networks_lib.Logits:
        """Builds the discriminator module and applies it to the inputs."""
        mlp_core = ail.DiscriminatorMLP([4, 4], )
        module = ail.DiscriminatorModule(
            environment_spec=environment_spec,
            use_action=True,
            use_next_obs=True,
            network_core=mlp_core,
        )
        return module(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator))

    ail_network = ail.AILNetworks(
        ail.make_discriminator(environment_spec, discriminator_transformed),
        imitation_reward_fn=ail.rewards.gail_reward(),
        direct_rl_networks=td3_networks,
    )

    # Demonstrations come from the dataset named in FLAGS.dataset_name; the
    # training policy is built with the sigma taken from the TD3 config.
    demonstrations_factory = functools.partial(
        helpers.make_demonstration_iterator,
        dataset_name=FLAGS.dataset_name,
    )
    train_policy = td3.get_default_behavior_policy(
        td3_networks,
        action_specs=environment_spec.actions,
        sigma=td3_config.sigma,
    )
    agent = ail.DAC(
        spec=environment_spec,
        network=ail_network,
        config=dac_config,
        seed=FLAGS.seed,
        batch_size=samples_per_step,
        make_demonstrations=demonstrations_factory,
        policy_network=train_policy,
    )

    # Training loop with its own logger and 'train'-prefixed counter.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(
        environment,
        agent,
        counter=counting.Counter(prefix='train'),
        logger=train_logger,
    )

    # Evaluation uses a separate environment and an actor whose behavior
    # policy is built with sigma=0.; it pulls its variables from the agent.
    # TODO(lukstafi): sigma=0 for eval?
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_policy = td3.get_default_behavior_policy(
        td3_networks, action_specs=environment_spec.actions, sigma=0.)
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=eval_policy,
        variable_source=agent,
    )
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(
        eval_env,
        eval_actor,
        counter=counting.Counter(prefix='eval'),
        logger=eval_logger,
    )

    # Evaluate, train one segment, repeat; finish with a final evaluation.
    num_segments, leftover_steps = divmod(FLAGS.num_steps, FLAGS.eval_every)
    assert leftover_steps == 0
    for _ in range(num_segments):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)