Ejemplo n.º 1
0
  def test_td3_fd(self):
    """Smoke-tests the single-process TD3fD agent on a fake environment.

    The test contains no assertions; it passes when the agent can be driven
    through the environment loop without raising.
    """
    # Fake continuous environment plus the spec derived from it.
    env = fakes.ContinuousEnvironment(
        episode_length=10,
        action_dim=3,
        observation_dim=5,
        bounded=True,
    )
    env_spec = specs.make_environment_spec(env)

    # Networks are built from the environment spec.
    networks = td3.make_networks(env_spec)

    # Combined configuration: the TD3 settings are built first, then the
    # learning-from-demonstrations settings, and both are wrapped together.
    config = lfd.TD3fDConfig(
        td3_config=td3.TD3Config(batch_size=10, min_replay_size=1),
        lfd_config=lfd.LfdConfig(
            initial_insert_count=0, demonstration_ratio=0.2),
    )

    # The same counter instance is handed to both the agent and the loop.
    shared_counter = counting.Counter()
    td3_fd_agent = lfd.TD3fD(
        spec=env_spec,
        td3_network=networks,
        td3_fd_config=config,
        lfd_iterator_fn=fake_demonstration_iterator,
        seed=0,
        counter=shared_counter,
    )

    # Drive the agent for a fixed number of episodes.
    env_loop = acme.EnvironmentLoop(env, td3_fd_agent, counter=shared_counter)
    env_loop.run(num_episodes=20)
Ejemplo n.º 2
0
def main(_):
  """Builds a distributed TD3 program for the task in FLAGS.task and launches it."""
  # Read the flag once; the factory below closes over this local value.
  task_name = FLAGS.task

  def env_factory(seed):
    # The seed argument is accepted but ignored when creating the environment.
    del seed
    return helpers.make_environment(task_name)

  # One environment is instantiated here only so that its spec can be derived.
  spec = specs.make_environment_spec(env_factory(True))

  distributed_agent = td3.DistributedTD3(
      environment_factory=env_factory,
      environment_spec=spec,
      network_factory=td3.make_networks,
      config=td3.TD3Config(),
      num_actors=4,
      seed=1,
      max_number_of_steps=100,
  )
  program = distributed_agent.build()

  resources = lp_utils.make_xm_docker_resources(program)
  lp.launch(program, xm_resources=resources)
Ejemplo n.º 3
0
def build_experiment_config():
    """Builds TD3 experiment config which can be executed in different ways.

    Reads FLAGS.env_name (split once on ':' into a suite and a task),
    FLAGS.seed and FLAGS.num_steps.

    Returns:
        The experiments.ExperimentConfig wiring together the TD3 builder, the
        environment factory and the network factory.
    """
    suite_name, task_name = FLAGS.env_name.split(':', 1)

    def make_networks(spec):
        # TD3 networks with three hidden layers of 256 units each.
        return td3.make_networks(spec, hidden_layer_sizes=(256, 256, 256))

    def make_environment(seed):
        # The seed is not forwarded to the environment helper.
        return helpers.make_environment(suite_name, task_name)

    # Construct the agent builder with matching policy/critic learning rates.
    builder = td3.TD3Builder(
        td3.TD3Config(
            policy_learning_rate=3e-4,
            critic_learning_rate=3e-4,
        ))

    return experiments.ExperimentConfig(
        builder=builder,
        environment_factory=make_environment,
        network_factory=make_networks,
        seed=FLAGS.seed,
        max_num_actor_steps=FLAGS.num_steps,
    )
Ejemplo n.º 4
0
def main(_):
    """Trains a TD3 agent on FLAGS.env_name with periodic evaluation.

    Training runs in chunks of FLAGS.eval_every steps; five evaluation
    episodes are run before each chunk and once more after the last chunk.

    Raises:
        ValueError: If FLAGS.eval_every is not positive or does not evenly
            divide FLAGS.num_steps.
    """
    # Validate the schedule up front with a real exception: an `assert` would
    # be stripped under `python -O`, and eval_every == 0 would otherwise
    # surface as an unexplained ZeroDivisionError.
    num_steps = FLAGS.num_steps
    eval_every = FLAGS.eval_every
    if eval_every <= 0 or num_steps % eval_every != 0:
        raise ValueError(
            f'num_steps ({num_steps}) must be a multiple of a positive '
            f'eval_every ({eval_every}).')

    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = td3.make_networks(environment_spec)

    # Construct the agent.
    config = td3.TD3Config(num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step)
    agent = td3.TD3(environment_spec,
                    agent_networks,
                    config=config,
                    seed=FLAGS.seed)

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop. The evaluation policy is built
    # with sigma=0. and pulls its variables from the training agent.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=td3.get_default_behavior_policy(
            agent_networks, environment_spec.actions, sigma=0.),
        variable_source=agent)
    # Evaluation uses its own environment instance, separate from training.
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    # Alternate evaluation and training, then evaluate one final time.
    for _ in range(num_steps // eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=eval_every)
    eval_loop.run(num_episodes=5)
Ejemplo n.º 5
0
  def test_distributed_td3_fd(self):
    """Smoke-tests distributed TD3fD by launching it and stepping the learner.

    The learner node's own run loop is disabled so that the test can step the
    learner manually a fixed number of times after launch.
    """
    num_learner_steps = 5

    def make_env(seed):
      # Every call yields an identically configured fake environment.
      del seed
      return fakes.ContinuousEnvironment(
          episode_length=10,
          action_dim=3,
          observation_dim=5,
          bounded=True,
      )

    # TD3 settings first, then the demonstration settings, wrapped together.
    config = lfd.TD3fDConfig(
        td3_config=td3.TD3Config(
            batch_size=10, min_replay_size=16, samples_per_insert=2),
        lfd_config=lfd.LfdConfig(
            initial_insert_count=0, demonstration_ratio=0.2),
    )

    env_spec = specs.make_environment_spec(make_env(0))

    distributed_agent = lfd.DistributedTD3fD(
        environment_factory=make_env,
        environment_spec=env_spec,
        network_factory=td3.make_networks,
        td3_fd_config=config,
        lfd_iterator_fn=fake_demonstration_iterator,
        seed=0,
        num_actors=2,
    )

    lp_program = distributed_agent.build()

    # The 'learner' group must hold exactly one node; unpacking enforces that.
    (node,) = lp_program.groups['learner']
    node.disable_run()  # pytype: disable=attribute-error

    lp.launch(lp_program, launch_type='test_mt')

    learner: acme.Learner = node.create_handle().dereference()

    # Step the learner by hand now that its run loop is disabled.
    for _ in range(num_learner_steps):
      learner.step()
Ejemplo n.º 6
0
    def test_td3(self):
        """Smoke-tests the single-process TD3 agent on a fake environment.

        No assertions are made; the test succeeds when the environment loop
        finishes without raising.
        """
        # Fake continuous environment and the spec derived from it.
        env = fakes.ContinuousEnvironment(
            episode_length=10,
            action_dim=3,
            observation_dim=5,
            bounded=True,
        )
        env_spec = specs.make_environment_spec(env)

        # Networks and agent configuration.
        networks = td3.make_networks(env_spec)
        td3_config = td3.TD3Config(batch_size=10, min_replay_size=1)

        # The agent and the loop share one counter instance.
        shared_counter = counting.Counter()
        td3_agent = td3.TD3(
            spec=env_spec,
            network=networks,
            config=td3_config,
            seed=0,
            counter=shared_counter,
        )

        # Run a couple of episodes end to end.
        env_loop = acme.EnvironmentLoop(
            env, td3_agent, counter=shared_counter)
        env_loop.run(num_episodes=2)
Ejemplo n.º 7
0
def main(_):
    """Trains an ail.DAC agent (AIL with TD3 as direct RL) with evaluation.

    Training runs in chunks of FLAGS.eval_every steps; five evaluation
    episodes are run before each chunk and once more after the last chunk.
    When FLAGS.pretrain is set, the TD3 networks are first passed through
    add_bc_pretraining.

    Raises:
        ValueError: If FLAGS.eval_every is not positive or does not evenly
            divide FLAGS.num_steps.
    """
    # Validate the schedule up front with a real exception: an `assert` would
    # be stripped under `python -O`, and eval_every == 0 would otherwise
    # surface as an unexplained ZeroDivisionError.
    num_steps = FLAGS.num_steps
    eval_every = FLAGS.eval_every
    if eval_every <= 0 or num_steps % eval_every != 0:
        raise ValueError(
            f'num_steps ({num_steps}) must be a multiple of a positive '
            f'eval_every ({eval_every}).')

    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Construct the agent.
    # Local layout makes sure that we populate the buffer with min_replay_size
    # initial transitions and that there's no need for tolerance_rate. In order
    # for deadlocks not to happen we need to disable rate limiting that happens
    # inside the TD3Builder. This is achieved by the min_replay_size and
    # samples_per_insert_tolerance_rate arguments.
    td3_config = td3.TD3Config(
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
        min_replay_size=1,
        samples_per_insert_tolerance_rate=float('inf'))
    td3_networks = td3.make_networks(environment_spec)
    if FLAGS.pretrain:
        td3_networks = add_bc_pretraining(td3_networks)

    # The AIL config and the agent must agree on this batch size, so it is
    # computed once and passed to both.
    direct_rl_batch_size = (td3_config.batch_size *
                            td3_config.num_sgd_steps_per_step)

    ail_config = ail.AILConfig(direct_rl_batch_size=direct_rl_batch_size)
    dac_config = ail.DACConfig(ail_config, td3_config)

    def discriminator(*args, **kwargs) -> networks_lib.Logits:
        # Discriminator over observations, actions and next observations,
        # using an MLP core with layer sizes [4, 4].
        return ail.DiscriminatorModule(environment_spec=environment_spec,
                                       use_action=True,
                                       use_next_obs=True,
                                       network_core=ail.DiscriminatorMLP(
                                           [4, 4], ))(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator))

    ail_network = ail.AILNetworks(
        ail.make_discriminator(environment_spec, discriminator_transformed),
        imitation_reward_fn=ail.rewards.gail_reward(),
        direct_rl_networks=td3_networks)

    agent = ail.DAC(spec=environment_spec,
                    network=ail_network,
                    config=dac_config,
                    seed=FLAGS.seed,
                    batch_size=direct_rl_batch_size,
                    make_demonstrations=functools.partial(
                        helpers.make_demonstration_iterator,
                        dataset_name=FLAGS.dataset_name),
                    policy_network=td3.get_default_behavior_policy(
                        td3_networks,
                        action_specs=environment_spec.actions,
                        sigma=td3_config.sigma))

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    # TODO(lukstafi): sigma=0 for eval?
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=td3.get_default_behavior_policy(
            td3_networks, action_specs=environment_spec.actions, sigma=0.),
        variable_source=agent)
    # Evaluation uses its own environment instance, separate from training.
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    # Alternate evaluation and training, then evaluate one final time.
    for _ in range(num_steps // eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=eval_every)
    eval_loop.run(num_episodes=5)