def test_discrete_actions(self, loss_name):
  """Builds a BC learner on a fake discrete environment and steps it.

  Everything runs inside the `chex.fake_pmap_and_jit()` context.

  Args:
    loss_name: Selects the loss under test. 'logp' uses `bc.logp` directly;
      'rcal' wraps that same loss with `bc.rcal`.

  Raises:
    ValueError: If `loss_name` is neither 'logp' nor 'rcal'.
  """
  with chex.fake_pmap_and_jit():
    sgd_steps_per_learner_step = 1
    learner_steps = 5

    # Fake environment that supplies both the spec and the demonstrations.
    fake_env = fakes.DiscreteEnvironment(
        num_actions=10,
        num_observations=100,
        obs_shape=(10,),
        obs_dtype=np.float32)
    spec = specs.make_environment_spec(fake_env)

    # Demonstrations: each sample's data is unpacked into a Transition, then
    # batched by 8 and exposed as a numpy iterator.
    demonstrations = (
        fakes.transition_dataset(fake_env)
        .map(lambda sample: types.Transition(*sample.data))
        .batch(8)
        .as_numpy_iterator())

    # Construct the agent.
    policy_network = make_networks(spec, discrete_actions=True)

    def logp_fn(logits, actions):
      """Returns the log-probability of `actions` under softmax(`logits`)."""
      # Shift by the max over the last axis before logsumexp (presumably for
      # numerical stability; the shift cancels in the final difference).
      shifted = logits - jnp.max(logits, axis=-1, keepdims=True)
      action_mask = jax.nn.one_hot(actions, spec.actions.num_values)
      # The one-hot mask picks out the (shifted) logit of the taken action.
      chosen = jnp.sum(action_mask * shifted, axis=-1)
      return chosen - special.logsumexp(shifted, axis=-1)

    # Reject unknown loss names before building any loss object.
    if loss_name not in ('logp', 'rcal'):
      raise ValueError

    # Both supported losses start from the log-probability loss.
    loss_fn = bc.logp(logp_fn=logp_fn)
    if loss_name == 'rcal':
      loss_fn = bc.rcal(loss_fn, discount=0.99, alpha=0.1)

    learner = bc.BCLearner(
        network=policy_network,
        random_key=jax.random.PRNGKey(0),
        loss_fn=loss_fn,
        optimizer=optax.adam(0.01),
        demonstrations=demonstrations,
        num_sgd_steps_per_step=sgd_steps_per_learner_step)

    # Train the agent; the test passes if no step raises.
    for _ in range(learner_steps):
      learner.step()
def test_continuous_actions(self, loss_name):
  """Builds a BC learner on a fake continuous environment and steps it.

  Everything runs inside the `chex.fake_pmap_and_jit()` context.

  Args:
    loss_name: Selects the loss under test. 'logp' uses `bc.logp`, 'mse' uses
      `bc.mse`, and 'peerbc' wraps the `bc.logp` loss with `bc.peerbc`.

  Raises:
    ValueError: If `loss_name` is not one of 'logp', 'mse' or 'peerbc'.
  """
  with chex.fake_pmap_and_jit():
    sgd_steps_per_learner_step = 1
    learner_steps = 5

    # Fake environment that supplies both the spec and the demonstrations.
    fake_env = fakes.ContinuousEnvironment(
        episode_length=10, bounded=True, action_dim=6)
    env_spec = specs.make_environment_spec(fake_env)

    # Demonstrations: each sample's data is unpacked into a Transition, then
    # batched by 8 and exposed as a numpy iterator.
    demonstrations = (
        fakes.transition_dataset(fake_env)
        .map(lambda sample: types.Transition(*sample.data))
        .batch(8)
        .as_numpy_iterator())

    # Construct the agent.
    policy_network = make_networks(env_spec)

    def log_prob_of(dist_params, actions):
      """Delegates to the `log_prob` method of the network output."""
      return dist_params.log_prob(actions)

    def draw_sample(dist_params, key):
      """Delegates to the `sample` method of the network output."""
      return dist_params.sample(seed=key)

    if loss_name == 'mse':
      loss_fn = bc.mse(sample_fn=draw_sample)
    elif loss_name in ('logp', 'peerbc'):
      # 'peerbc' is the log-probability loss with an extra wrapper on top.
      loss_fn = bc.logp(logp_fn=log_prob_of)
      if loss_name == 'peerbc':
        loss_fn = bc.peerbc(loss_fn, zeta=0.1)
    else:
      raise ValueError

    learner = bc.BCLearner(
        network=policy_network,
        random_key=jax.random.PRNGKey(0),
        loss_fn=loss_fn,
        optimizer=optax.adam(0.01),
        demonstrations=demonstrations,
        num_sgd_steps_per_step=sgd_steps_per_learner_step)

    # Train the agent; the test passes if no step raises.
    for _ in range(learner_steps):
      learner.step()
def inner_wrapper(*args, **kwargs):
  """Calls `fn`, optionally inside `chex.fake_pmap_and_jit()`.

  `fn` is a free variable here, presumably bound by an enclosing decorator
  that is not visible in this chunk (TODO confirm).

  Args:
    *args: Positional arguments forwarded unchanged to `fn`.
    **kwargs: Keyword arguments forwarded unchanged to `fn`.

  Returns:
    Whatever `fn(*args, **kwargs)` returns.
  """
  # Common path: the flag is off, so call straight through.
  if not FLAGS.jaxline_disable_pmap_jit:
    return fn(*args, **kwargs)

  # Flag is on: run the call under chex's fake pmap/jit context.
  with chex.fake_pmap_and_jit():
    return fn(*args, **kwargs)
def inner_wrapper(*args, **kwargs):
  """Calls `fn`, optionally inside `chex.fake_pmap_and_jit()`.

  `fn` is a free variable here, presumably bound by an enclosing decorator
  that is not visible in this chunk (TODO confirm).

  Args:
    *args: Positional arguments forwarded unchanged to `fn`.
    **kwargs: Keyword arguments forwarded unchanged to `fn`.

  Returns:
    Whatever `fn(*args, **kwargs)` returns.
  """
  # Common path: the flag holder's value is falsy, so call straight through.
  if not _JAXLINE_DISABLE_PMAP_JIT.value:
    return fn(*args, **kwargs)

  # Flag is set: run the call under chex's fake pmap/jit context.
  with chex.fake_pmap_and_jit():
    return fn(*args, **kwargs)