def main(_):
  # Create an environment and grab the spec.
  environment = bc_utils.make_environment()
  environment_spec = specs.make_environment_spec(environment)

  # Unwrap the environment to get the demonstrations.
  dataset = bc_utils.make_demonstrations(environment.environment,
                                         FLAGS.batch_size)
  dataset = dataset.as_numpy_iterator()

  # Create the networks to optimize.
  network = bc_utils.make_network(environment_spec)

  key = jax.random.PRNGKey(FLAGS.seed)
  key, key1 = jax.random.split(key, 2)

  def logp_fn(logits, actions):
    logits_actions = jnp.sum(
        jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1)
    logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
    return logits_actions

  loss_fn = bc.logp(logp_fn=logp_fn)

  learner = bc.BCLearner(
      network=network,
      random_key=key1,
      loss_fn=loss_fn,
      optimizer=optax.adam(FLAGS.learning_rate),
      demonstrations=dataset,
      num_sgd_steps_per_step=1)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    dist_params = network.apply(params, observation)
    return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
        key, dist_params)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
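
# main() above reads several absl flags. A minimal sketch of how they might be
# declared (the flag names come from the usage above; the default values and
# help strings here are illustrative assumptions, not the original defaults):
from absl import app
from absl import flags

flags.DEFINE_integer('batch_size', 64, 'Demonstration batch size.')
flags.DEFINE_integer('seed', 0, 'Random seed for networks and evaluation.')
flags.DEFINE_float('learning_rate', 1e-4, 'Adam learning rate.')
flags.DEFINE_float('evaluation_epsilon', 0.,
                   'Epsilon for the epsilon-greedy evaluation policy.')
flags.DEFINE_integer('evaluate_every', 100,
                     'Learner steps between evaluation runs.')
flags.DEFINE_integer('evaluation_episodes', 10, 'Episodes per evaluation run.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
  app.run(main)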
def test_discrete_actions(self, loss_name):
  with chex.fake_pmap_and_jit():
    num_sgd_steps_per_step = 1
    num_steps = 5

    # Create a fake environment to test with.
    environment = fakes.DiscreteEnvironment(
        num_actions=10,
        num_observations=100,
        obs_shape=(10,),
        obs_dtype=np.float32)

    spec = specs.make_environment_spec(environment)
    dataset_demonstration = fakes.transition_dataset(environment)
    dataset_demonstration = dataset_demonstration.map(
        lambda sample: types.Transition(*sample.data))
    dataset_demonstration = dataset_demonstration.batch(8).as_numpy_iterator()

    # Construct the agent.
    network = make_networks(spec, discrete_actions=True)

    def logp_fn(logits, actions):
      # Subtract the max logit for numerical stability before normalizing.
      max_logits = jnp.max(logits, axis=-1, keepdims=True)
      logits = logits - max_logits
      logits_actions = jnp.sum(
          jax.nn.one_hot(actions, spec.actions.num_values) * logits, axis=-1)
      log_prob = logits_actions - special.logsumexp(logits, axis=-1)
      return log_prob

    if loss_name == 'logp':
      loss_fn = bc.logp(logp_fn=logp_fn)
    elif loss_name == 'rcal':
      base_loss_fn = bc.logp(logp_fn=logp_fn)
      loss_fn = bc.rcal(base_loss_fn, discount=0.99, alpha=0.1)
    else:
      raise ValueError(f'Unknown loss name: {loss_name}')

    learner = bc.BCLearner(
        network=network,
        random_key=jax.random.PRNGKey(0),
        loss_fn=loss_fn,
        optimizer=optax.adam(0.01),
        demonstrations=dataset_demonstration,
        num_sgd_steps_per_step=num_sgd_steps_per_step)

    # Train the agent.
    for _ in range(num_steps):
      learner.step()
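
# A quick standalone check (separate from the test above) that the stabilized
# logp_fn computes exactly the per-action log-softmax; logp_reference and the
# array shapes below are local to this sketch:
import jax
import jax.numpy as jnp
from jax.scipy import special

def logp_reference(logits, actions):
  # log-softmax over the action dimension, then pick the taken action's entry.
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  return jnp.take_along_axis(log_probs, actions[..., None], axis=-1)[..., 0]

logits = jax.random.normal(jax.random.PRNGKey(0), (8, 10))
actions = jnp.arange(8) % 10
shifted = logits - jnp.max(logits, axis=-1, keepdims=True)
logp_manual = jnp.sum(jax.nn.one_hot(actions, 10) * shifted,
                      axis=-1) - special.logsumexp(shifted, axis=-1)
assert jnp.allclose(logp_manual, logp_reference(logits, actions), atol=1e-5)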
def make_ensemble_regressor_learner(
    name: str,
    num_networks: int,
    logger_fn: loggers.LoggerFactory,
    counter: counting.Counter,
    rng_key: jnp.ndarray,
    iterator: Iterator[types.Transition],
    base_network: networks_lib.FeedForwardNetwork,
    loss: mbop_losses.TransitionLoss,
    optimizer: optax.GradientTransformation,
    num_sgd_steps_per_step: int,
):
  """Creates an ensemble regressor learner from the base network.

  Args:
    name: Name of the learner used for logging and counters.
    num_networks: Number of networks in the ensemble.
    logger_fn: Constructs a logger for a label.
    counter: Parent counter object.
    rng_key: Random key.
    iterator: An iterator of time-batched transitions used to train the
      networks.
    base_network: Base network for the ensemble.
    loss: Training loss to use.
    optimizer: Optax optimizer.
    num_sgd_steps_per_step: Number of gradient updates per step.

  Returns:
    An ensemble regressor learner.
  """
  mbop_ensemble = ensemble.make_ensemble(base_network, ensemble.apply_all,
                                         num_networks)
  local_counter = counting.Counter(parent=counter, prefix=name)
  local_logger = logger_fn(
      name, local_counter.get_steps_key()) if logger_fn else None

  def loss_fn(apply_fn: Callable[..., networks_lib.NetworkOutput],
              params: networks_lib.Params, key: jnp.ndarray,
              transitions: types.Transition) -> jnp.ndarray:
    del key
    return loss(functools.partial(apply_fn, params), transitions)

  # This is effectively a regressor learner.
  return bc.BCLearner(
      mbop_ensemble,
      rng_key,
      loss_fn,
      optimizer,
      iterator,
      num_sgd_steps_per_step,
      logger=local_logger,
      counter=local_counter)
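
# `loss` is invoked above as loss(functools.partial(apply_fn, params),
# transitions): it receives the network with parameters already bound plus a
# batch of transitions, and must return a scalar. A minimal MSE sketch of such
# a TransitionLoss, assuming (hypothetically) a network that takes
# (observation, action) and regresses the reward; the losses actually shipped
# in mbop_losses may differ:
import jax.numpy as jnp
from acme import types

def mse_reward_loss(network_fn, transitions: types.Transition) -> jnp.ndarray:
  # With an apply_all-style ensemble, predictions carry a leading
  # [num_networks] axis that broadcasts against the batched rewards.
  predictions = network_fn(transitions.observation, transitions.action)
  return jnp.mean(jnp.square(predictions - transitions.reward))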
def test_continuous_actions(self, loss_name):
  with chex.fake_pmap_and_jit():
    num_sgd_steps_per_step = 1
    num_steps = 5

    # Create a fake environment to test with.
    environment = fakes.ContinuousEnvironment(
        episode_length=10, bounded=True, action_dim=6)

    spec = specs.make_environment_spec(environment)
    dataset_demonstration = fakes.transition_dataset(environment)
    dataset_demonstration = dataset_demonstration.map(
        lambda sample: types.Transition(*sample.data))
    dataset_demonstration = dataset_demonstration.batch(8).as_numpy_iterator()

    # Construct the agent.
    network = make_networks(spec)

    if loss_name == 'logp':
      loss_fn = bc.logp(
          logp_fn=lambda dist_params, actions: dist_params.log_prob(actions))
    elif loss_name == 'mse':
      loss_fn = bc.mse(
          sample_fn=lambda dist_params, key: dist_params.sample(seed=key))
    elif loss_name == 'peerbc':
      base_loss_fn = bc.logp(
          logp_fn=lambda dist_params, actions: dist_params.log_prob(actions))
      loss_fn = bc.peerbc(base_loss_fn, zeta=0.1)
    else:
      raise ValueError(f'Unknown loss name: {loss_name}')

    learner = bc.BCLearner(
        network=network,
        random_key=jax.random.PRNGKey(0),
        loss_fn=loss_fn,
        optimizer=optax.adam(0.01),
        demonstrations=dataset_demonstration,
        num_sgd_steps_per_step=num_sgd_steps_per_step)

    # Train the agent.
    for _ in range(num_steps):
      learner.step()
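
# The lambdas above assume only that the network output `dist_params` exposes
# .log_prob(actions) and .sample(seed=key). A minimal standalone sketch with a
# distrax Gaussian standing in for the network output (illustrative only; the
# distribution produced by make_networks may be parameterized differently):
import distrax
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
dist_params = distrax.MultivariateNormalDiag(
    loc=jnp.zeros(6), scale_diag=jnp.ones(6))
actions = jnp.zeros(6)

# Per-example quantities that bc.logp and bc.mse reduce over the batch
# (up to the sign and averaging details handled inside the loss factories).
logp_term = -dist_params.log_prob(actions)
mse_term = jnp.mean(jnp.square(dist_params.sample(seed=key) - actions))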