def test_discrete_actions(self, loss_name):
    """Smoke-test `bc.BCLearner` on a fake discrete-action environment.

    Builds a fake environment, batches its transition dataset, constructs the
    requested loss and runs a handful of learner steps under
    `chex.fake_pmap_and_jit()`.

    Args:
      loss_name: Name of the loss to build, either 'logp' or 'rcal'.

    Raises:
      ValueError: If `loss_name` is neither 'logp' nor 'rcal'.
    """
    with chex.fake_pmap_and_jit():
        sgd_steps = 1
        total_steps = 5

        # Fake environment and the spec derived from it.
        env = fakes.DiscreteEnvironment(
            num_actions=10,
            num_observations=100,
            obs_shape=(10,),
            obs_dtype=np.float32,
        )
        spec = specs.make_environment_spec(env)

        # Demonstrations: re-wrap each sample as a Transition, batch by 8 and
        # expose the result as a numpy iterator.
        demonstrations = (
            fakes.transition_dataset(env)
            .map(lambda sample: types.Transition(*sample.data))
            .batch(8)
            .as_numpy_iterator()
        )

        # Agent network.
        network = make_networks(spec, discrete_actions=True)

        def logp_fn(logits, actions):
            # Shift by the per-row maximum before taking the log-sum-exp.
            shifted = logits - jnp.max(logits, axis=-1, keepdims=True)
            selector = jax.nn.one_hot(actions, spec.actions.num_values)
            chosen = jnp.sum(selector * shifted, axis=-1)
            return chosen - special.logsumexp(shifted, axis=-1)

        # Both supported losses start from the log-prob loss; 'rcal' wraps it.
        if loss_name not in ('logp', 'rcal'):
            raise ValueError
        loss_fn = bc.logp(logp_fn=logp_fn)
        if loss_name == 'rcal':
            loss_fn = bc.rcal(loss_fn, discount=0.99, alpha=0.1)

        learner = bc.BCLearner(
            network=network,
            random_key=jax.random.PRNGKey(0),
            loss_fn=loss_fn,
            optimizer=optax.adam(0.01),
            demonstrations=demonstrations,
            num_sgd_steps_per_step=sgd_steps,
        )

        # Train the agent.
        for _step in range(total_steps):
            learner.step()
def main(_):
    """Train a BC learner on demonstrations, evaluating it periodically.

    Alternates forever between `FLAGS.evaluate_every` learner steps and an
    evaluation run of `FLAGS.evaluation_episodes` episodes; it never returns.

    Args:
      _: Unused positional argument.
    """
    # Environment and its spec.
    environment = bc_utils.make_environment()
    environment_spec = specs.make_environment_spec(environment)

    # The demonstrations are built from the unwrapped environment.
    demonstrations = bc_utils.make_demonstrations(
        environment.environment, FLAGS.batch_size
    ).as_numpy_iterator()

    # Network to optimize.
    network = bc_utils.make_network(environment_spec)

    # One key for the evaluator, one for the learner.
    evaluator_key, learner_key = jax.random.split(
        jax.random.PRNGKey(FLAGS.seed), 2
    )

    def logp_fn(logits, actions):
        selector = jax.nn.one_hot(actions, logits.shape[-1])
        chosen = jnp.sum(selector * logits, axis=-1)
        return chosen - special.logsumexp(logits, axis=-1)

    learner = bc.BCLearner(
        network=network,
        random_key=learner_key,
        loss_fn=bc.logp(logp_fn=logp_fn),
        optimizer=optax.adam(FLAGS.learning_rate),
        demonstrations=demonstrations,
        num_sgd_steps_per_step=1,
    )

    # NOTE(review): `jnp.DeviceArray` may be unavailable on recent JAX
    # releases -- confirm against the JAX version this project pins.
    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        network_output = network.apply(params, observation)
        sampler = rlax.epsilon_greedy(FLAGS.evaluation_epsilon)
        return sampler.sample(key, network_output)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network
    )
    variable_client = variable_utils.VariableClient(
        learner, 'policy', device='cpu'
    )
    evaluator = actors.GenericActor(
        actor_core, evaluator_key, variable_client, backend='cpu'
    )
    eval_loop = acme.EnvironmentLoop(
        environment=environment,
        actor=evaluator,
        logger=loggers.TerminalLogger('evaluation', time_delta=0.),
    )

    def _train_then_evaluate():
        # One round: a burst of learner steps followed by an evaluation run.
        for _step in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)

    # Run rounds indefinitely.
    while True:
        _train_then_evaluate()
def test_continuous_actions(self, loss_name):
    """Smoke-test `bc.BCLearner` on a fake continuous-action environment.

    Builds a fake bounded environment, batches its transition dataset,
    constructs the requested loss and runs a handful of learner steps under
    `chex.fake_pmap_and_jit()`.

    Args:
      loss_name: Name of the loss to build: 'logp', 'mse' or 'peerbc'.

    Raises:
      ValueError: If `loss_name` is not one of the names above.
    """
    with chex.fake_pmap_and_jit():
        sgd_steps = 1
        total_steps = 5

        # Fake environment and the spec derived from it.
        env = fakes.ContinuousEnvironment(
            episode_length=10, bounded=True, action_dim=6
        )
        spec = specs.make_environment_spec(env)

        # Demonstrations: re-wrap each sample as a Transition, batch by 8 and
        # expose the result as a numpy iterator.
        demonstrations = (
            fakes.transition_dataset(env)
            .map(lambda sample: types.Transition(*sample.data))
            .batch(8)
            .as_numpy_iterator()
        )

        # Agent network.
        network = make_networks(spec)

        def _log_prob(dist_params, actions):
            return dist_params.log_prob(actions)

        def _sample(dist_params, key):
            return dist_params.sample(seed=key)

        if loss_name == 'mse':
            loss_fn = bc.mse(sample_fn=_sample)
        elif loss_name in ('logp', 'peerbc'):
            # 'peerbc' wraps the same log-prob loss that 'logp' uses directly.
            loss_fn = bc.logp(logp_fn=_log_prob)
            if loss_name == 'peerbc':
                loss_fn = bc.peerbc(loss_fn, zeta=0.1)
        else:
            raise ValueError

        learner = bc.BCLearner(
            network=network,
            random_key=jax.random.PRNGKey(0),
            loss_fn=loss_fn,
            optimizer=optax.adam(0.01),
            demonstrations=demonstrations,
            num_sgd_steps_per_step=sgd_steps,
        )

        # Train the agent.
        for _step in range(total_steps):
            learner.step()
def add_bc_pretraining(sac_networks: sac.SACNetworks) -> sac.SACNetworks:
    """Returns `sac_networks` with a policy init that performs BC pretraining.

    The replacement policy network keeps the original `apply`, while its init
    ignores its arguments and returns the result of
    `bc.pretraining.train_with_bc`.

    Args:
      sac_networks: The SAC networks whose policy network is to be replaced.

    Returns:
      The result of `dataclasses.replace` on `sac_networks` with the new
      `policy_network`.
    """
    demonstrations_factory = functools.partial(
        helpers.make_demonstration_iterator,
        dataset_name=FLAGS.dataset_name,
    )
    policy = sac_networks.policy_network
    bc_network = bc.pretraining.convert_to_bc_network(policy)
    bc_loss = bc.logp(sac_networks.log_prob)

    def _pretrained_init(*unused_args):
        # Whatever init is called with, the params come from BC training.
        return bc.pretraining.train_with_bc(
            demonstrations_factory, bc_network, bc_loss
        )

    pretrained_policy = networks_lib.FeedForwardNetwork(
        _pretrained_init, policy.apply
    )
    return dataclasses.replace(sac_networks, policy_network=pretrained_policy)
def add_bc_pretraining(ppo_networks: ppo.PPONetworks) -> ppo.PPONetworks:
    """Returns `ppo_networks` with a network init that performs BC pretraining.

    The replacement network keeps the original `apply`, while its init ignores
    its arguments and returns the result of `bc.pretraining.train_with_bc`.

    Args:
      ppo_networks: The PPO networks whose `network` is to be replaced.

    Returns:
      The result of `dataclasses.replace` on `ppo_networks` with the new
      `network`.
    """
    demonstrations_factory = functools.partial(
        helpers.make_demonstration_iterator,
        dataset_name=FLAGS.dataset_name,
    )
    policy_value_network = ppo_networks.network
    bc_network = bc.pretraining.convert_policy_value_to_bc_network(
        policy_value_network
    )
    bc_loss = bc.logp(ppo_networks.log_prob)

    # Note: despite only training the policy network, this will also include
    # the initial value network params.
    def _pretrained_init(*unused_args):
        return bc.pretraining.train_with_bc(
            demonstrations_factory, bc_network, bc_loss
        )

    pretrained_network = networks_lib.FeedForwardNetwork(
        _pretrained_init, policy_value_network.apply
    )
    return dataclasses.replace(ppo_networks, network=pretrained_network)