def run_ppo_agent(self, make_networks_fn):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)
  distribution_value_networks = make_networks_fn(spec)
  ppo_networks = ppo.make_ppo_networks(distribution_value_networks)
  config = ppo.PPOConfig(unroll_length=4, num_epochs=2, num_minibatches=2)
  workdir = self.create_tempdir()
  counter = counting.Counter()
  logger = loggers.make_default_logger('learner')
  # Construct the agent.
  agent = ppo.PPO(
      spec=spec,
      networks=ppo_networks,
      config=config,
      seed=0,
      workdir=workdir.full_path,
      normalize_input=True,
      counter=counter,
      logger=logger,
  )
  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=20)
def test_ppo_nest_safety(self):
  # Create a fake environment with nested observations.
  environment = fakes.NestedDiscreteEnvironment(
      num_observations={
          'lat': 2,
          'long': 3
      },
      num_actions=5,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)
  distribution_value_networks = make_haiku_networks(spec)
  ppo_networks = ppo.make_ppo_networks(distribution_value_networks)
  config = ppo.PPOConfig(unroll_length=4, num_epochs=2, num_minibatches=2)
  workdir = self.create_tempdir()
  # Construct the agent.
  agent = ppo.PPO(
      spec=spec,
      networks=ppo_networks,
      config=config,
      seed=0,
      workdir=workdir.full_path,
      normalize_input=True,
  )
  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)
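# `make_haiku_networks` is referenced above (and is a natural candidate for
# the `make_networks_fn` argument of `run_ppo_agent`) but is defined elsewhere
# in the test file. Below is a minimal sketch of what such a factory plausibly
# looks like; this is an assumption, not the original definition. It builds a
# Haiku-transformed network whose apply function returns an
# (action distribution, value) pair, wrapped for `ppo.make_ppo_networks`.
def make_haiku_networks(spec):
  num_actions = spec.actions.num_values

  def forward_fn(inputs):
    policy_network = hk.Sequential([
        utils.batch_concat,
        hk.nets.MLP([64, 64]),
        networks_lib.CategoricalHead(num_actions),
    ])
    value_network = hk.Sequential([
        utils.batch_concat,
        hk.nets.MLP([64, 64]),
        hk.Linear(1),
        lambda x: jnp.squeeze(x, axis=-1),
    ])
    return policy_network(inputs), value_network(inputs)

  # Transform into pure init/apply functions.
  forward_fn = hk.without_apply_rng(hk.transform(forward_fn))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
  return networks_lib.FeedForwardNetwork(
      lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)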
def build_experiment_config():
  """Builds PPO experiment config which can be executed in different ways."""
  # Environment names take the form 'suite:task', e.g. 'gym:HalfCheetah-v2';
  # the factories below create the environment and networks lazily.
  suite, task = FLAGS.env_name.split(':', 1)

  config = ppo.PPOConfig(entropy_cost=0, learning_rate=1e-4)
  ppo_builder = ppo.PPOBuilder(config)

  layer_sizes = (256, 256, 256)
  return experiments.ExperimentConfig(
      builder=ppo_builder,
      environment_factory=lambda seed: helpers.make_environment(suite, task),
      network_factory=lambda spec: ppo.make_networks(spec, layer_sizes),
      seed=FLAGS.seed,
      max_num_actor_steps=FLAGS.num_steps)
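# A minimal sketch of how the config above is typically consumed, using
# absl's app runner and Acme's experiment runner (real APIs; the eval
# settings here are illustrative, not from the original script). For
# distributed execution, the same config can instead be passed to
# experiments.make_distributed_experiment and launched with Launchpad.
def main(_):
  experiment_config = build_experiment_config()
  experiments.run_experiment(
      experiment=experiment_config,
      eval_every=50_000,
      num_eval_episodes=10)


if __name__ == '__main__':
  app.run(main)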
def main(_):
  task = FLAGS.task
  environment_factory = lambda seed: helpers.make_environment(task)

  config = ppo.PPOConfig(
      unroll_length=16,
      num_minibatches=32,
      num_epochs=10,
      batch_size=2048 // 16)
  program = ppo.DistributedPPO(
      environment_factory=environment_factory,
      network_factory=ppo.make_continuous_networks,
      config=config,
      seed=FLAGS.seed,
      num_actors=4,
      max_number_of_steps=100).build()

  # Launch experiment.
  lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program))
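# Plausible flag definitions and entry point for the launcher above. The flag
# names match the FLAGS accesses in main, but the defaults and help strings
# are assumptions.
flags.DEFINE_string('task', 'MountainCarContinuous-v0', 'Environment task.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
  app.run(main)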
def main(_):
  # Create an environment, grab the spec, and use it to create networks.
  environment = helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)
  agent_networks = ppo.make_continuous_networks(environment_spec)

  # Construct the agent.
  config = ppo.PPOConfig(
      unroll_length=FLAGS.unroll_length,
      num_minibatches=FLAGS.num_minibatches,
      num_epochs=FLAGS.num_epochs,
      batch_size=FLAGS.batch_size)
  learner_logger = experiment_utils.make_experiment_logger(
      label='learner', steps_key='learner_steps')
  agent = ppo.PPO(
      environment_spec,
      agent_networks,
      config=config,
      seed=FLAGS.seed,
      counter=counting.Counter(prefix='learner'),
      logger=learner_logger)

  # Create the environment loop used for training.
  train_logger = experiment_utils.make_experiment_logger(
      label='train', steps_key='train_steps')
  train_loop = acme.EnvironmentLoop(
      environment,
      agent,
      counter=counting.Counter(prefix='train'),
      logger=train_logger)

  # Create the evaluation actor and loop.
  eval_logger = experiment_utils.make_experiment_logger(
      label='eval', steps_key='eval_steps')
  eval_actor = agent.builder.make_actor(
      random_key=jax.random.PRNGKey(FLAGS.seed),
      policy_network=ppo.make_inference_fn(agent_networks, evaluation=True),
      variable_source=agent)
  eval_env = helpers.make_environment(task=FLAGS.env_name)
  eval_loop = acme.EnvironmentLoop(
      eval_env,
      eval_actor,
      counter=counting.Counter(prefix='eval'),
      logger=eval_logger)

  assert FLAGS.num_steps % FLAGS.eval_every == 0
  for _ in range(FLAGS.num_steps // FLAGS.eval_every):
    eval_loop.run(num_episodes=5)
    train_loop.run(num_steps=FLAGS.eval_every)
  eval_loop.run(num_episodes=5)
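# Worked example of the train/eval cadence above (illustrative numbers, not
# from the original): with FLAGS.num_steps = 1_000_000 and
# FLAGS.eval_every = 50_000, the assertion holds and the loop runs 20 blocks
# of 50_000 training steps, evaluating for 5 episodes before each block and
# once more after the final block (21 evaluations in total).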
def main(_):
  # Create an environment, grab the spec, and use it to create networks.
  environment = helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)
  agent_networks = ppo.make_continuous_networks(environment_spec)

  # Construct the agent.
  ppo_config = ppo.PPOConfig(
      unroll_length=FLAGS.unroll_length,
      num_minibatches=FLAGS.ppo_num_minibatches,
      num_epochs=FLAGS.ppo_num_epochs,
      batch_size=FLAGS.transition_batch_size // FLAGS.unroll_length,
      learning_rate=0.0003,
      entropy_cost=0,
      gae_lambda=0.8,
      value_cost=0.25)
  ppo_networks = ppo.make_continuous_networks(environment_spec)
  if FLAGS.pretrain:
    ppo_networks = add_bc_pretraining(ppo_networks)

  discriminator_batch_size = FLAGS.transition_batch_size
  # With sequence-based replay, the direct RL batch is counted in transitions:
  # batch_size sequences of unroll_length steps each.
  ail_config = ail.AILConfig(
      direct_rl_batch_size=ppo_config.batch_size * ppo_config.unroll_length,
      discriminator_batch_size=discriminator_batch_size,
      is_sequence_based=True,
      num_sgd_steps_per_step=FLAGS.num_discriminator_steps_per_step,
      share_iterator=FLAGS.share_iterator,
  )

  def discriminator(*args, **kwargs) -> networks_lib.Logits:
    # Note: observation embedding is not needed for e.g. Mujoco.
    return ail.DiscriminatorModule(
        environment_spec=environment_spec,
        use_action=True,
        use_next_obs=True,
        network_core=ail.DiscriminatorMLP([4, 4]),
    )(*args, **kwargs)

  discriminator_transformed = hk.without_apply_rng(
      hk.transform_with_state(discriminator))

  ail_network = ail.AILNetworks(
      ail.make_discriminator(environment_spec, discriminator_transformed),
      imitation_reward_fn=ail.rewards.gail_reward(),
      direct_rl_networks=ppo_networks)

  agent = ail.GAIL(
      spec=environment_spec,
      network=ail_network,
      config=ail.GAILConfig(ail_config, ppo_config),
      seed=FLAGS.seed,
      batch_size=ppo_config.batch_size,
      make_demonstrations=functools.partial(
          helpers.make_demonstration_iterator,
          dataset_name=FLAGS.dataset_name),
      policy_network=ppo.make_inference_fn(ppo_networks))

  # Create the environment loop used for training.
  train_logger = experiment_utils.make_experiment_logger(
      label='train', steps_key='train_steps')
  train_loop = acme.EnvironmentLoop(
      environment,
      agent,
      counter=counting.Counter(prefix='train'),
      logger=train_logger)

  # Create the evaluation actor and loop.
  eval_logger = experiment_utils.make_experiment_logger(
      label='eval', steps_key='eval_steps')
  eval_actor = agent.builder.make_actor(
      random_key=jax.random.PRNGKey(FLAGS.seed),
      policy_network=ppo.make_inference_fn(agent_networks, evaluation=True),
      variable_source=agent)
  eval_env = helpers.make_environment(task=FLAGS.env_name)
  eval_loop = acme.EnvironmentLoop(
      eval_env,
      eval_actor,
      counter=counting.Counter(prefix='eval'),
      logger=eval_logger)

  assert FLAGS.num_steps % FLAGS.eval_every == 0
  for _ in range(FLAGS.num_steps // FLAGS.eval_every):
    eval_loop.run(num_episodes=5)
    train_loop.run(num_steps=FLAGS.eval_every)
  eval_loop.run(num_episodes=5)
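# The test below references module-level constants defined elsewhere in its
# file. Plausible definitions (the values are assumptions, chosen only to
# keep the fake environments small):
EPISODE_LENGTH = 10
NUM_DISCRETE_ACTIONS = 5
NUM_OBSERVATIONS = 10
OBS_SHAPE = (10, 5)
OBS_DTYPE = np.float32
CONTINUOUS_ACTION_DIM = 2
CONTINUOUS_OBS_DIM = 3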
def test_ail(self,
             algo,
             airl_discriminator=False,
             subtract_logpi=False,
             dropout=0.,
             lipschitz_coeff=None):
  shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
  batch_size = 8
  # Mujoco environment and associated demonstration dataset.
  if algo == 'ppo':
    environment = fakes.DiscreteEnvironment(
        num_actions=NUM_DISCRETE_ACTIONS,
        num_observations=NUM_OBSERVATIONS,
        obs_shape=OBS_SHAPE,
        obs_dtype=OBS_DTYPE,
        episode_length=EPISODE_LENGTH)
  else:
    environment = fakes.ContinuousEnvironment(
        episode_length=EPISODE_LENGTH,
        action_dim=CONTINUOUS_ACTION_DIM,
        observation_dim=CONTINUOUS_OBS_DIM,
        bounded=True)
  spec = specs.make_environment_spec(environment)

  if algo == 'sac':
    networks = sac.make_networks(spec=spec)
    config = sac.SACConfig(
        batch_size=batch_size,
        samples_per_insert_tolerance_rate=float('inf'),
        min_replay_size=1)
    base_builder = sac.SACBuilder(config=config)
    direct_rl_batch_size = batch_size
    behavior_policy = sac.apply_policy_and_sample(networks)
  elif algo == 'ppo':
    unroll_length = 5
    distribution_value_networks = make_ppo_networks(spec)
    networks = ppo.make_ppo_networks(distribution_value_networks)
    config = ppo.PPOConfig(
        unroll_length=unroll_length,
        num_minibatches=2,
        num_epochs=4,
        batch_size=batch_size)
    base_builder = ppo.PPOBuilder(config=config)
    direct_rl_batch_size = batch_size * unroll_length
    behavior_policy = jax.jit(ppo.make_inference_fn(networks), backend='cpu')
  else:
    raise ValueError(f'Unexpected algorithm {algo}')

  if subtract_logpi:
    assert algo == 'sac'
    logpi_fn = make_sac_logpi(networks)
  else:
    logpi_fn = None

  if algo == 'ppo':
    # The discrete fake environment has 2-D observations; flatten them into
    # feature vectors for the discriminator.
    embedding = lambda x: jnp.reshape(x, list(x.shape[:-2]) + [-1])
  else:
    embedding = lambda x: x

  def discriminator(*args, **kwargs) -> networks_lib.Logits:
    if airl_discriminator:
      return ail.AIRLModule(
          environment_spec=spec,
          use_action=True,
          use_next_obs=True,
          discount=.99,
          g_core=ail.DiscriminatorMLP(
              [4, 4],
              hidden_dropout_rate=dropout,
              spectral_normalization_lipschitz_coeff=lipschitz_coeff),
          h_core=ail.DiscriminatorMLP(
              [4, 4],
              hidden_dropout_rate=dropout,
              spectral_normalization_lipschitz_coeff=lipschitz_coeff),
          observation_embedding=embedding)(*args, **kwargs)
    else:
      return ail.DiscriminatorModule(
          environment_spec=spec,
          use_action=True,
          use_next_obs=True,
          network_core=ail.DiscriminatorMLP(
              [4, 4],
              hidden_dropout_rate=dropout,
              spectral_normalization_lipschitz_coeff=lipschitz_coeff),
          observation_embedding=embedding)(*args, **kwargs)

  discriminator_transformed = hk.without_apply_rng(
      hk.transform_with_state(discriminator))

  discriminator_network = ail.make_discriminator(
      environment_spec=spec,
      discriminator_transformed=discriminator_transformed,
      logpi_fn=logpi_fn)

  networks = ail.AILNetworks(discriminator_network, lambda x: x, networks)

  builder = ail.AILBuilder(
      base_builder,
      config=ail.AILConfig(
          is_sequence_based=(algo == 'ppo'),
          share_iterator=True,
          direct_rl_batch_size=direct_rl_batch_size,
          discriminator_batch_size=2,
          policy_variable_name='policy' if subtract_logpi else None,
          min_replay_size=1),
      discriminator_loss=ail.losses.gail_loss(),
      make_demonstrations=fakes.transition_iterator(environment))

  # Construct the agent.
  agent = local_layout.LocalLayout(
      seed=0,
      environment_spec=spec,
      builder=builder,
      networks=networks,
      policy_network=behavior_policy,
      min_replay_size=1,
      batch_size=batch_size)

  # Train the agent.
  train_loop = acme.EnvironmentLoop(environment, agent)
  train_loop.run(num_episodes=(10 if algo == 'ppo' else 1))
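# A plausible harness for the test above (an assumption: in the original file
# test_ail is presumably a method on a parameterized TestCase, with `algo` and
# the discriminator options supplied by @parameterized decorators), run in the
# usual absl way:
if __name__ == '__main__':
  absltest.main()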