def main(_):
  key = jax.random.PRNGKey(FLAGS.seed)
  key_demonstrations, key_learner = jax.random.split(key, 2)

  # Create an environment and grab the spec.
  environment = gym_helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)

  # Get a demonstrations dataset with next_actions extra.
  transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                      FLAGS.num_demonstrations)
  double_transitions = rlds.transformations.batch(
      transitions, size=2, shift=1, drop_remainder=True)
  transitions = double_transitions.map(_add_next_action_extras)
  demonstrations = tfds.JaxInMemoryRandomSampleIterator(
      transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

  # Create the networks to optimize.
  networks = td3.make_networks(environment_spec)

  # Create the learner.
  learner = td3.TD3Learner(
      networks=networks,
      random_key=key_learner,
      discount=FLAGS.discount,
      iterator=demonstrations,
      policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
      critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      use_sarsa_target=FLAGS.use_sarsa_target,
      bc_alpha=FLAGS.bc_alpha,
      num_sgd_steps_per_step=1)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    del key
    return networks.policy_network.apply(params, observation)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
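# Hedged sketch: the TD3-BC script above maps `_add_next_action_extras` over
# pairs of consecutive transitions, but the helper itself is not shown in this
# listing. Assuming the standard Acme `types.Transition` container (from
# `acme import types`), it would look roughly like this: keep the first
# transition of each pair and stash the second transition's action as a
# SARSA-style `next_action` extra for the SARSA target.
def _add_next_action_extras(
    double_transitions: types.Transition) -> types.Transition:
  return types.Transition(
      observation=double_transitions.observation[0],
      action=double_transitions.action[0],
      reward=double_transitions.reward[0],
      discount=double_transitions.discount[0],
      next_observation=double_transitions.next_observation[0],
      extras={'next_action': double_transitions.action[1]})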
def make_demonstration_iterator(batch_size: int,
                                dataset_name: str,
                                seed: int = 0):
  dataset = tfds.get_tfds_dataset(dataset_name)
  return tfds.JaxInMemoryRandomSampleIterator(dataset, jax.random.PRNGKey(seed),
                                              batch_size)
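# Illustrative usage of make_demonstration_iterator above. The dataset name is
# an assumption (any transitions dataset supported by
# acme.datasets.tfds.get_tfds_dataset would do); pulling next(...) yields one
# randomly sampled, batched types.Transition.
def _demonstration_iterator_example():
  iterator = make_demonstration_iterator(
      batch_size=256, dataset_name='d4rl_mujoco_halfcheetah/v0-medium', seed=0)
  return next(iterator)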
def main(_):
  key = jax.random.PRNGKey(FLAGS.seed)
  key_demonstrations, key_learner = jax.random.split(key, 2)

  # Create an environment and grab the spec.
  environment = gym_helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)

  # Get a demonstrations dataset.
  transitions_iterator = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                               FLAGS.num_demonstrations)
  demonstrations = tfds.JaxInMemoryRandomSampleIterator(
      transitions_iterator, key=key_demonstrations, batch_size=FLAGS.batch_size)

  # Create the networks to optimize.
  networks = cql.make_networks(environment_spec)

  # Create the learner.
  learner = cql.CQLLearner(
      batch_size=FLAGS.batch_size,
      networks=networks,
      random_key=key_learner,
      policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
      critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      fixed_cql_coefficient=FLAGS.fixed_cql_coefficient,
      cql_lagrange_threshold=FLAGS.cql_lagrange_threshold,
      demonstrations=demonstrations,
      num_sgd_steps_per_step=1)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    dist_params = networks.policy_network.apply(params, observation)
    return networks.sample_eval(dist_params, key)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
def main(_):
  # Create an environment and grab the spec.
  environment = gym_helpers.make_environment(task=_ENV_NAME.value)
  spec = specs.make_environment_spec(environment)

  key = jax.random.PRNGKey(_SEED.value)
  key, dataset_key, evaluator_key = jax.random.split(key, 3)

  # Load the dataset and convert the episodes into batched timestep
  # transitions with the desired return horizon.
  dataset = tensorflow_datasets.load(_DATASET_NAME.value)['train']
  dataset = mbop.episodes_to_timestep_batched_transitions(
      dataset, return_horizon=10)
  dataset = tfds.JaxInMemoryRandomSampleIterator(
      dataset, key=dataset_key, batch_size=_BATCH_SIZE.value)

  # Apply normalization to the dataset.
  mean_std = mbop.get_normalization_stats(dataset,
                                          _NUM_NORMALIZATION_BATCHES.value)
  apply_normalization = jax.jit(
      functools.partial(running_statistics.normalize, mean_std=mean_std))
  dataset = (apply_normalization(sample) for sample in dataset)

  # Create the networks.
  networks = mbop.make_networks(
      spec, hidden_layer_sizes=tuple(_HIDDEN_LAYER_SIZES.value))

  # Use the default losses.
  losses = mbop.MBOPLosses()

  def logger_fn(label: str, steps_key: str):
    return loggers.make_default_logger(label, steps_key=steps_key)

  def make_learner(name, logger_fn, counter, rng_key, dataset, network, loss):
    return mbop.make_ensemble_regressor_learner(
        name,
        _NUM_NETWORKS.value,
        logger_fn,
        counter,
        rng_key,
        dataset,
        network,
        loss,
        optax.adam(_LEARNING_RATE.value),
        _NUM_SGD_STEPS_PER_STEP.value,
    )

  learner = mbop.MBOPLearner(
      networks, losses, dataset, key, logger_fn,
      functools.partial(make_learner, 'world_model'),
      functools.partial(make_learner, 'policy_prior'),
      functools.partial(make_learner, 'n_step_return'))

  planning_config = mbop.MPPIConfig()
  assert planning_config.n_trajectories % _NUM_NETWORKS.value == 0, (
      'Number of trajectories must be a multiple of the number of networks.')

  actor_core = mbop.make_ensemble_actor_core(
      networks, planning_config, spec, mean_std, use_round_robin=False)
  evaluator = mbop.make_actor(actor_core, evaluator_key, learner)

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Train the agent.
  while True:
    for _ in range(_EVALUATE_EVERY.value):
      learner.step()
    eval_loop.run(_EVALUATION_EPISODES.value)
def make_demonstrations(batch_size):
  # Note: `transitions_iterator` and `seed` are not defined here; they are
  # assumed to be closed over from the enclosing scope of the script this
  # helper belongs to.
  return acme_tfds.JaxInMemoryRandomSampleIterator(
      transitions_iterator, jax.random.PRNGKey(seed), batch_size)