Example #1
    def run_ppo_agent(self, make_networks_fn):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        distribution_value_networks = make_networks_fn(spec)
        ppo_networks = ppo.make_ppo_networks(distribution_value_networks)
        config = ppo.PPOConfig(unroll_length=4,
                               num_epochs=2,
                               num_minibatches=2)
        workdir = self.create_tempdir()
        counter = counting.Counter()
        logger = loggers.make_default_logger('learner')
        # Construct the agent.
        agent = ppo.PPO(
            spec=spec,
            networks=ppo_networks,
            config=config,
            seed=0,
            workdir=workdir.full_path,
            normalize_input=True,
            counter=counter,
            logger=logger,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        loop.run(num_episodes=20)
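
The test snippets in Example #1 above and Example #2 below omit their imports and the test-local network helpers (make_networks_fn, make_haiku_networks). A minimal import block, assuming the usual dm-acme module layout (an assumption; the originals do not show it):

# Imports assumed by the surrounding test snippets (a sketch).
import acme
import numpy as np
from acme import specs
from acme.agents.jax import ppo
from acme.testing import fakes
from acme.utils import counting
from acme.utils import loggers
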
Example #2
    def test_ppo_nest_safety(self):
        # Create a fake environment with nested observations.
        environment = fakes.NestedDiscreteEnvironment(num_observations={
            'lat': 2,
            'long': 3
        },
                                                      num_actions=5,
                                                      obs_shape=(10, 5),
                                                      obs_dtype=np.float32,
                                                      episode_length=10)
        spec = specs.make_environment_spec(environment)

        distribution_value_networks = make_haiku_networks(spec)
        ppo_networks = ppo.make_ppo_networks(distribution_value_networks)
        config = ppo.PPOConfig(unroll_length=4,
                               num_epochs=2,
                               num_minibatches=2)
        workdir = self.create_tempdir()
        # Construct the agent.
        agent = ppo.PPO(
            spec=spec,
            networks=ppo_networks,
            config=config,
            seed=0,
            workdir=workdir.full_path,
            normalize_input=True,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
Example #3
def build_experiment_config():
  """Builds PPO experiment config which can be executed in different ways."""
  # Build environment and network factories; the experiment runner will use
  # them to create the environment, grab its spec, and build networks.
  suite, task = FLAGS.env_name.split(':', 1)

  config = ppo.PPOConfig(entropy_cost=0, learning_rate=1e-4)
  ppo_builder = ppo.PPOBuilder(config)

  layer_sizes = (256, 256, 256)
  return experiments.ExperimentConfig(
      builder=ppo_builder,
      environment_factory=lambda seed: helpers.make_environment(suite, task),
      network_factory=lambda spec: ppo.make_networks(spec, layer_sizes),
      seed=FLAGS.seed,
      max_num_actor_steps=FLAGS.num_steps)
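
The ExperimentConfig built above is meant to be handed to a runner. A minimal sketch, assuming acme.jax.experiments.run_experiment is available in this version of Acme (the evaluation settings below are illustrative, not taken from the example):

from acme.jax import experiments

# Run the experiment in the current process (illustrative eval settings).
experiment_config = build_experiment_config()
experiments.run_experiment(
    experiment=experiment_config,
    eval_every=10_000,
    num_eval_episodes=5)
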
Example #4
def main(_):
    task = FLAGS.task
    environment_factory = lambda seed: helpers.make_environment(task)
    config = ppo.PPOConfig(unroll_length=16,
                           num_minibatches=32,
                           num_epochs=10,
                           batch_size=2048 // 16)
    program = ppo.DistributedPPO(environment_factory=environment_factory,
                                 network_factory=ppo.make_continuous_networks,
                                 config=config,
                                 seed=FLAGS.seed,
                                 num_actors=4,
                                 max_number_of_steps=100).build()

    # Launch experiment.
    lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program))
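
This entry point relies on absl/Launchpad scaffolding that the snippet does not show. A sketch of that scaffolding, assuming standard module paths; the flag defaults are illustrative, and helpers is a local module of the original example, so it is not imported here:

# Surrounding boilerplate assumed by the snippet above (a sketch).
import launchpad as lp
from absl import app
from absl import flags
from acme.agents.jax import ppo
from acme.utils import lp_utils

FLAGS = flags.FLAGS
flags.DEFINE_string('task', 'HalfCheetah-v2', 'Environment task (illustrative default).')
flags.DEFINE_integer('seed', 0, 'Random seed.')

if __name__ == '__main__':
    app.run(main)  # 'main' is the function defined in the example above.
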
Example #5
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = ppo.make_continuous_networks(environment_spec)

    # Construct the agent.
    config = ppo.PPOConfig(unroll_length=FLAGS.unroll_length,
                           num_minibatches=FLAGS.num_minibatches,
                           num_epochs=FLAGS.num_epochs,
                           batch_size=FLAGS.batch_size)

    learner_logger = experiment_utils.make_experiment_logger(
        label='learner', steps_key='learner_steps')
    agent = ppo.PPO(environment_spec,
                    agent_networks,
                    config=config,
                    seed=FLAGS.seed,
                    counter=counting.Counter(prefix='learner'),
                    logger=learner_logger)

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=ppo.make_inference_fn(agent_networks, evaluation=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
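
The assert-and-loop pattern at the end alternates evaluation and training. With illustrative flag values (not the script's real defaults), the bookkeeping works out as follows:

# Hypothetical values, chosen only to illustrate the arithmetic above.
num_steps, eval_every = 1_000_000, 50_000
assert num_steps % eval_every == 0      # required by the script
num_cycles = num_steps // eval_every    # 20 cycles of (5 eval episodes, 50k training steps)
# ...plus one final evaluation after the loop.
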
Example #6
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = ppo.make_continuous_networks(environment_spec)

    # Construct the agent.
    ppo_config = ppo.PPOConfig(unroll_length=FLAGS.unroll_length,
                               num_minibatches=FLAGS.ppo_num_minibatches,
                               num_epochs=FLAGS.ppo_num_epochs,
                               batch_size=FLAGS.transition_batch_size //
                               FLAGS.unroll_length,
                               learning_rate=0.0003,
                               entropy_cost=0,
                               gae_lambda=0.8,
                               value_cost=0.25)
    ppo_networks = ppo.make_continuous_networks(environment_spec)
    if FLAGS.pretrain:
        ppo_networks = add_bc_pretraining(ppo_networks)

    discriminator_batch_size = FLAGS.transition_batch_size
    ail_config = ail.AILConfig(
        direct_rl_batch_size=ppo_config.batch_size * ppo_config.unroll_length,
        discriminator_batch_size=discriminator_batch_size,
        is_sequence_based=True,
        num_sgd_steps_per_step=FLAGS.num_discriminator_steps_per_step,
        share_iterator=FLAGS.share_iterator,
    )

    def discriminator(*args, **kwargs) -> networks_lib.Logits:
        # Note: an observation embedding is not needed for, e.g., MuJoCo.
        return ail.DiscriminatorModule(
            environment_spec=environment_spec,
            use_action=True,
            use_next_obs=True,
            network_core=ail.DiscriminatorMLP([4, 4]),
        )(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator))

    ail_network = ail.AILNetworks(
        ail.make_discriminator(environment_spec, discriminator_transformed),
        imitation_reward_fn=ail.rewards.gail_reward(),
        direct_rl_networks=ppo_networks)

    agent = ail.GAIL(spec=environment_spec,
                     network=ail_network,
                     config=ail.GAILConfig(ail_config, ppo_config),
                     seed=FLAGS.seed,
                     batch_size=ppo_config.batch_size,
                     make_demonstrations=functools.partial(
                         helpers.make_demonstration_iterator,
                         dataset_name=FLAGS.dataset_name),
                     policy_network=ppo.make_inference_fn(ppo_networks))

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=ppo.make_inference_fn(agent_networks, evaluation=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
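
As the divide and multiply above suggest, the two batch sizes in this example are expressed in different units: PPOConfig.batch_size counts unrolls, while AILConfig.direct_rl_batch_size counts transitions. A minimal sketch of that bookkeeping with hypothetical flag values (not the script's real flags):

# Hypothetical values, chosen only to illustrate the batch-size bookkeeping.
transition_batch_size = 2048   # stands in for FLAGS.transition_batch_size
unroll_length = 16             # stands in for FLAGS.unroll_length

ppo_batch_size = transition_batch_size // unroll_length   # 128 unrolls per PPO batch
direct_rl_batch_size = ppo_batch_size * unroll_length     # 2048 transitions handed to PPO by AIL
assert direct_rl_batch_size == transition_batch_size
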
Example #7
    def test_ail(self,
                 algo,
                 airl_discriminator=False,
                 subtract_logpi=False,
                 dropout=0.,
                 lipschitz_coeff=None):
        shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
        batch_size = 8
        # Fake environment standing in for the MuJoCo environment and
        # its demonstration dataset.
        if algo == 'ppo':
            environment = fakes.DiscreteEnvironment(
                num_actions=NUM_DISCRETE_ACTIONS,
                num_observations=NUM_OBSERVATIONS,
                obs_shape=OBS_SHAPE,
                obs_dtype=OBS_DTYPE,
                episode_length=EPISODE_LENGTH)
        else:
            environment = fakes.ContinuousEnvironment(
                episode_length=EPISODE_LENGTH,
                action_dim=CONTINUOUS_ACTION_DIM,
                observation_dim=CONTINUOUS_OBS_DIM,
                bounded=True)
        spec = specs.make_environment_spec(environment)

        if algo == 'sac':
            networks = sac.make_networks(spec=spec)
            config = sac.SACConfig(
                batch_size=batch_size,
                samples_per_insert_tolerance_rate=float('inf'),
                min_replay_size=1)
            base_builder = sac.SACBuilder(config=config)
            direct_rl_batch_size = batch_size
            behavior_policy = sac.apply_policy_and_sample(networks)
        elif algo == 'ppo':
            unroll_length = 5
            distribution_value_networks = make_ppo_networks(spec)
            networks = ppo.make_ppo_networks(distribution_value_networks)
            config = ppo.PPOConfig(unroll_length=unroll_length,
                                   num_minibatches=2,
                                   num_epochs=4,
                                   batch_size=batch_size)
            base_builder = ppo.PPOBuilder(config=config)
            direct_rl_batch_size = batch_size * unroll_length
            behavior_policy = jax.jit(ppo.make_inference_fn(networks),
                                      backend='cpu')
        else:
            raise ValueError(f'Unexpected algorithm {algo}')

        if subtract_logpi:
            assert algo == 'sac'
            logpi_fn = make_sac_logpi(networks)
        else:
            logpi_fn = None

        if algo == 'ppo':
            embedding = lambda x: jnp.reshape(x, list(x.shape[:-2]) + [-1])
        else:
            embedding = lambda x: x

        def discriminator(*args, **kwargs) -> networks_lib.Logits:
            if airl_discriminator:
                return ail.AIRLModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    discount=.99,
                    g_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    h_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    observation_embedding=embedding)(*args, **kwargs)
            else:
                return ail.DiscriminatorModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    network_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    observation_embedding=embedding)(*args, **kwargs)

        discriminator_transformed = hk.without_apply_rng(
            hk.transform_with_state(discriminator))

        discriminator_network = ail.make_discriminator(
            environment_spec=spec,
            discriminator_transformed=discriminator_transformed,
            logpi_fn=logpi_fn)

        networks = ail.AILNetworks(discriminator_network, lambda x: x,
                                   networks)

        builder = ail.AILBuilder(
            base_builder,
            config=ail.AILConfig(
                is_sequence_based=(algo == 'ppo'),
                share_iterator=True,
                direct_rl_batch_size=direct_rl_batch_size,
                discriminator_batch_size=2,
                policy_variable_name='policy' if subtract_logpi else None,
                min_replay_size=1),
            discriminator_loss=ail.losses.gail_loss(),
            make_demonstrations=fakes.transition_iterator(environment))

        # Construct the agent.
        agent = local_layout.LocalLayout(seed=0,
                                         environment_spec=spec,
                                         builder=builder,
                                         networks=networks,
                                         policy_network=behavior_policy,
                                         min_replay_size=1,
                                         batch_size=batch_size)

        # Train the agent.
        train_loop = acme.EnvironmentLoop(environment, agent)
        train_loop.run(num_episodes=(10 if algo == 'ppo' else 1))