def test_td3_fd(self):
  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(
      episode_length=10, action_dim=3, observation_dim=5, bounded=True)
  spec = specs.make_environment_spec(environment)

  # Create the networks.
  td3_network = td3.make_networks(spec)

  batch_size = 10
  td3_config = td3.TD3Config(
      batch_size=batch_size,
      min_replay_size=1)
  lfd_config = lfd.LfdConfig(initial_insert_count=0,
                             demonstration_ratio=0.2)
  td3_fd_config = lfd.TD3fDConfig(lfd_config=lfd_config,
                                  td3_config=td3_config)
  counter = counting.Counter()
  agent = lfd.TD3fD(
      spec=spec,
      td3_network=td3_network,
      td3_fd_config=td3_fd_config,
      lfd_iterator_fn=fake_demonstration_iterator,
      seed=0,
      counter=counter)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=20)
def main(_):
  key = jax.random.PRNGKey(FLAGS.seed)
  key_demonstrations, key_learner = jax.random.split(key, 2)

  # Create an environment and grab the spec.
  environment = gym_helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)

  # Get a demonstrations dataset with next_actions extra.
  transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                      FLAGS.num_demonstrations)
  double_transitions = rlds.transformations.batch(
      transitions, size=2, shift=1, drop_remainder=True)
  transitions = double_transitions.map(_add_next_action_extras)
  demonstrations = tfds.JaxInMemoryRandomSampleIterator(
      transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

  # Create the networks to optimize.
  networks = td3.make_networks(environment_spec)

  # Create the learner.
  learner = td3.TD3Learner(
      networks=networks,
      random_key=key_learner,
      discount=FLAGS.discount,
      iterator=demonstrations,
      policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
      critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      use_sarsa_target=FLAGS.use_sarsa_target,
      bc_alpha=FLAGS.bc_alpha,
      num_sgd_steps_per_step=1)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    del key
    return networks.policy_network.apply(params, observation)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
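# `_add_next_action_extras` is referenced above but not shown. A plausible
# sketch, assuming each element of `double_transitions` stacks two consecutive
# `types.Transition`s along the leading axis (the rlds batch of size 2): keep
# the first transition and expose the second transition's action as a
# `next_action` extra for the SARSA-style target. Assumes `from acme import
# types` at module level; the real script may differ in details.
def _add_next_action_extras(double_transitions):
  return types.Transition(
      observation=double_transitions.observation[0],
      action=double_transitions.action[0],
      reward=double_transitions.reward[0],
      discount=double_transitions.discount[0],
      next_observation=double_transitions.next_observation[0],
      extras={'next_action': double_transitions.action[1]})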
def build_experiment_config():
  """Builds TD3 experiment config which can be executed in different ways."""
  # Parse the environment name and define the network factory.
  suite, task = FLAGS.env_name.split(':', 1)
  network_factory = (
      lambda spec: td3.make_networks(spec, hidden_layer_sizes=(256, 256, 256)))

  # Construct the agent.
  config = td3.TD3Config(
      policy_learning_rate=3e-4,
      critic_learning_rate=3e-4,
  )
  td3_builder = td3.TD3Builder(config)
  # pylint:disable=g-long-lambda
  return experiments.ExperimentConfig(
      builder=td3_builder,
      environment_factory=lambda seed: helpers.make_environment(suite, task),
      network_factory=network_factory,
      seed=FLAGS.seed,
      max_num_actor_steps=FLAGS.num_steps)
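# A minimal sketch of how the config above is typically consumed, using the
# acme.jax.experiments single-process runner. `FLAGS.eval_every` is assumed to
# be defined alongside the other flags; the real script may also offer a
# distributed launch path.
def main(_):
  experiments.run_experiment(
      experiment=build_experiment_config(),
      eval_every=FLAGS.eval_every,
      num_eval_episodes=5)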
def main(_):
  # Create an environment, grab the spec, and use it to create networks.
  environment = helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)
  agent_networks = td3.make_networks(environment_spec)

  # Construct the agent.
  config = td3.TD3Config(num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step)
  agent = td3.TD3(
      environment_spec, agent_networks, config=config, seed=FLAGS.seed)

  # Create the environment loop used for training.
  train_logger = experiment_utils.make_experiment_logger(
      label='train', steps_key='train_steps')
  train_loop = acme.EnvironmentLoop(
      environment,
      agent,
      counter=counting.Counter(prefix='train'),
      logger=train_logger)

  # Create the evaluation actor and loop.
  eval_logger = experiment_utils.make_experiment_logger(
      label='eval', steps_key='eval_steps')
  eval_actor = agent.builder.make_actor(
      random_key=jax.random.PRNGKey(FLAGS.seed),
      policy_network=td3.get_default_behavior_policy(
          agent_networks, environment_spec.actions, sigma=0.),
      variable_source=agent)
  eval_env = helpers.make_environment(task=FLAGS.env_name)
  eval_loop = acme.EnvironmentLoop(
      eval_env,
      eval_actor,
      counter=counting.Counter(prefix='eval'),
      logger=eval_logger)

  assert FLAGS.num_steps % FLAGS.eval_every == 0
  for _ in range(FLAGS.num_steps // FLAGS.eval_every):
    eval_loop.run(num_episodes=5)
    train_loop.run(num_steps=FLAGS.eval_every)
  eval_loop.run(num_episodes=5)
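# The script above assumes absl flags along these lines. The names mirror the
# FLAGS attributes referenced in main(); the defaults below are only
# illustrative guesses, and `from absl import flags` is assumed.
flags.DEFINE_string('env_name', 'MountainCarContinuous-v0',
                    'What environment to run.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('num_sgd_steps_per_step', 1,
                     'Number of SGD steps per learner step.')
flags.DEFINE_integer('num_steps', 1_000_000,
                     'Number of environment steps to run for.')
flags.DEFINE_integer('eval_every', 10_000,
                     'Number of training steps between evaluations.')
FLAGS = flags.FLAGS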
def test_td3(self):
  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(
      episode_length=10, action_dim=3, observation_dim=5, bounded=True)
  spec = specs.make_environment_spec(environment)

  # Create the networks.
  network = td3.make_networks(spec)

  config = td3.TD3Config(batch_size=10, min_replay_size=1)
  counter = counting.Counter()
  agent = td3.TD3(
      spec=spec, network=network, config=config, seed=0, counter=counter)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=2)
def main(_):
  # Create an environment, grab the spec, and use it to create networks.
  environment = helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)

  # Construct the agent.
  # The local layout ensures the replay buffer is populated with
  # min_replay_size initial transitions, so no samples-per-insert tolerance is
  # needed. To avoid deadlocks we also disable the rate limiting that happens
  # inside the TD3Builder; this is achieved by the min_replay_size and
  # samples_per_insert_tolerance_rate arguments below.
  td3_config = td3.TD3Config(
      num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
      min_replay_size=1,
      samples_per_insert_tolerance_rate=float('inf'))
  td3_networks = td3.make_networks(environment_spec)
  if FLAGS.pretrain:
    td3_networks = add_bc_pretraining(td3_networks)

  ail_config = ail.AILConfig(
      direct_rl_batch_size=td3_config.batch_size *
      td3_config.num_sgd_steps_per_step)
  dac_config = ail.DACConfig(ail_config, td3_config)

  def discriminator(*args, **kwargs) -> networks_lib.Logits:
    return ail.DiscriminatorModule(
        environment_spec=environment_spec,
        use_action=True,
        use_next_obs=True,
        network_core=ail.DiscriminatorMLP([4, 4],))(*args, **kwargs)

  discriminator_transformed = hk.without_apply_rng(
      hk.transform_with_state(discriminator))

  ail_network = ail.AILNetworks(
      ail.make_discriminator(environment_spec, discriminator_transformed),
      imitation_reward_fn=ail.rewards.gail_reward(),
      direct_rl_networks=td3_networks)

  agent = ail.DAC(
      spec=environment_spec,
      network=ail_network,
      config=dac_config,
      seed=FLAGS.seed,
      batch_size=td3_config.batch_size * td3_config.num_sgd_steps_per_step,
      make_demonstrations=functools.partial(
          helpers.make_demonstration_iterator,
          dataset_name=FLAGS.dataset_name),
      policy_network=td3.get_default_behavior_policy(
          td3_networks,
          action_specs=environment_spec.actions,
          sigma=td3_config.sigma))

  # Create the environment loop used for training.
  train_logger = experiment_utils.make_experiment_logger(
      label='train', steps_key='train_steps')
  train_loop = acme.EnvironmentLoop(
      environment,
      agent,
      counter=counting.Counter(prefix='train'),
      logger=train_logger)

  # Create the evaluation actor and loop.
  # TODO(lukstafi): sigma=0 for eval?
  eval_logger = experiment_utils.make_experiment_logger(
      label='eval', steps_key='eval_steps')
  eval_actor = agent.builder.make_actor(
      random_key=jax.random.PRNGKey(FLAGS.seed),
      policy_network=td3.get_default_behavior_policy(
          td3_networks, action_specs=environment_spec.actions, sigma=0.),
      variable_source=agent)
  eval_env = helpers.make_environment(task=FLAGS.env_name)
  eval_loop = acme.EnvironmentLoop(
      eval_env,
      eval_actor,
      counter=counting.Counter(prefix='eval'),
      logger=eval_logger)

  assert FLAGS.num_steps % FLAGS.eval_every == 0
  for _ in range(FLAGS.num_steps // FLAGS.eval_every):
    eval_loop.run(num_episodes=5)
    train_loop.run(num_steps=FLAGS.eval_every)
  eval_loop.run(num_episodes=5)