def MultiHeadedAttention(  # pylint: disable=invalid-name
    feature_depth, num_heads=8, dropout=1.0, mode='train'):
  """Transformer-style multi-headed attention.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate - keep probability
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  return stax.serial(
      stax.parallel(stax.Dense(feature_depth, W_init=xavier_uniform()),
                    stax.Dense(feature_depth, W_init=xavier_uniform()),
                    stax.Dense(feature_depth, W_init=xavier_uniform()),
                    stax.Identity),
      PureMultiHeadedAttention(feature_depth, num_heads=num_heads,
                               dropout=dropout, mode=mode),
      stax.Dense(feature_depth, W_init=xavier_uniform()),
  )
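# Illustrative usage sketch for the stax.serial-based MultiHeadedAttention
# above. It assumes the standard stax (init_fun, apply_fun) contract and that
# the layer consumes a (queries, keys, values, mask) tuple matching its four
# parallel branches; the shapes, mask layout, and rng handling are assumptions
# for illustration only. Module-level names used elsewhere in this file
# (stax, jax_random) are assumed to be in scope.
import jax.numpy as jnp


def _attention_usage_sketch(feature_depth=128, seq_len=10, batch=2):
  init_fun, apply_fun = MultiHeadedAttention(feature_depth, num_heads=8,
                                             mode='eval')
  rng = jax_random.PRNGKey(0)
  x_shape = (batch, seq_len, feature_depth)
  mask_shape = (batch, 1, seq_len, seq_len)  # hypothetical mask layout
  _, params = init_fun(rng, (x_shape, x_shape, x_shape, mask_shape))
  x = jnp.zeros(x_shape)
  mask = jnp.ones(mask_shape)
  # Self-attention: queries, keys and values are the same tensor here.
  return apply_fun(params, (x, x, x, mask), rng=rng)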
def MultiHeadedAttention(
    feature_depth, num_heads=8, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  return combinators.Serial(
      combinators.Parallel(
          stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
          stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
          stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
          combinators.Identity()),
      PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
          feature_depth=feature_depth, num_heads=num_heads,
          dropout=dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
  )
def initialize_policy_and_value_nets(rng_key, num_actions,
                                     batch_observations_shape):
  """Set up and initialize the policy and value networks."""
  key1, key2 = jax_random.split(rng_key)

  policy_net_init, policy_net_apply = stax.serial(
      stax.Dense(16),
      stax.Relu,
      stax.Dense(4),
      stax.Relu,
      stax.Dense(num_actions),
      stax.Softmax,
  )
  _, policy_net_params = policy_net_init(key1, batch_observations_shape)

  value_net_init, value_net_apply = stax.serial(
      stax.Dense(16),
      stax.Relu,
      stax.Dense(4),
      stax.Relu,
      stax.Dense(1),  # 1, since we want to predict the reward with the value network.
  )
  _, value_net_params = value_net_init(key2, batch_observations_shape)

  return ((policy_net_params, policy_net_apply),
          (value_net_params, value_net_apply))
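# Illustrative sketch of calling the initializer above. The observation
# dimensionality and action count are made up (CartPole-like: 4-d observations,
# 2 discrete actions), and the batch sizes are placeholders; module-level
# imports (stax, jax_random) are assumed.
import jax.numpy as jnp


def _policy_and_value_init_sketch():
  rng_key = jax_random.PRNGKey(0)
  batch_observations_shape = (1, 4)  # (batch, observation_dim), sizes illustrative.
  ((policy_params, policy_apply),
   (value_params, value_apply)) = initialize_policy_and_value_nets(
       rng_key, num_actions=2,
       batch_observations_shape=batch_observations_shape)
  observations = jnp.zeros((8, 4))
  action_probs = policy_apply(policy_params, observations)  # shape (8, 2)
  values = value_apply(value_params, observations)          # shape (8, 1)
  return action_probs, values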
def policy_and_value_net(rng_key, batch_observations_shape, num_actions,
                         bottom_layers=None):
  """A policy and value net function."""

  # Layers.
  layers = []
  if bottom_layers is not None:
    layers.extend(bottom_layers)

  # Now, with the current logits, one head computes action probabilities and
  # the other computes the value function.
  layers.extend([stax.FanOut(2), stax.parallel(
      stax.serial(stax.Dense(num_actions), stax.Softmax),
      stax.Dense(1)
  )])

  net_init, net_apply = stax.serial(*layers)
  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
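# Illustrative sketch of the two-headed net above: because of FanOut(2) and the
# parallel heads, the apply function returns an (action probabilities, values)
# pair. The bottom stack, shapes, and action count here are assumptions made
# for the example.
import jax.numpy as jnp


def _policy_and_value_net_sketch():
  rng_key = jax_random.PRNGKey(0)
  net_params, net_apply = policy_and_value_net(
      rng_key,
      batch_observations_shape=(1, 4),
      num_actions=2,
      bottom_layers=[stax.Dense(16), stax.Relu])
  observations = jnp.zeros((8, 4))
  probs, values = net_apply(net_params, observations)
  return probs, values  # probs: (8, 2), values: (8, 1)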
def policy_net(rng_key, batch_observations_shape, num_actions,
               bottom_layers=None):
  """A policy net function."""
  # Use the bottom_layers as the bottom part of the network and just add the
  # required layers on top of it.
  if bottom_layers is None:
    bottom_layers = []
  # Copy the list so we don't mutate the caller's bottom_layers when adding
  # the policy head.
  layers = list(bottom_layers) + [stax.Dense(num_actions), stax.Softmax]

  net_init, net_apply = stax.serial(*layers)
  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
def value_net(rng_key, batch_observations_shape, num_actions,
              bottom_layers=None):
  """A value net function."""
  del num_actions
  if bottom_layers is None:
    bottom_layers = []
  # Copy the list so we don't mutate the caller's bottom_layers when adding
  # the value head.
  layers = list(bottom_layers) + [stax.Dense(1)]

  net_init, net_apply = stax.serial(*layers)
  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
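# Illustrative sketch of building the two heads above over the same
# bottom-layer spec with separate RNG keys; layer sizes, shapes, and the action
# count are assumptions. Each call gets its own copy of the list so the two
# networks stay independent.
import jax.numpy as jnp


def _separate_policy_and_value_sketch():
  key1, key2 = jax_random.split(jax_random.PRNGKey(0))
  bottom = [stax.Dense(16), stax.Relu]
  policy_params, policy_apply = policy_net(
      key1, (1, 4), num_actions=2, bottom_layers=list(bottom))
  value_params, value_apply = value_net(
      key2, (1, 4), num_actions=2, bottom_layers=list(bottom))
  observations = jnp.zeros((8, 4))
  return (policy_apply(policy_params, observations),  # (8, 2)
          value_apply(value_params, observations))    # (8, 1)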
def test_training_loop(self):
  env = gym.make("CartPole-v0")
  # Gym envs are usually wrapped in a TimeLimit wrapper.
  env = gym_utils.remove_time_limit_wrapper(env)
  # Limit this to a small number for tests.
  env = gym.wrappers.TimeLimit(env, max_episode_steps=2)
  num_epochs = 2
  batch_size = 2
  # Common bottom layer(s).
  bottom_layers = [stax.Dense(1)]
  # Run the training loop.
  _, rewards, val_losses, ppo_objectives = ppo.training_loop(
      env=env,
      epochs=num_epochs,
      policy_net_fun=functools.partial(
          ppo.policy_net, bottom_layers=bottom_layers),
      value_net_fun=functools.partial(
          ppo.value_net, bottom_layers=bottom_layers),
      batch_size=batch_size,
      num_optimizer_steps=1,
      random_seed=0)
  self.assertLen(rewards, num_epochs)
  self.assertLen(val_losses, num_epochs)
  self.assertLen(ppo_objectives, num_epochs)
def common_stax_layers():
  return [stax.Dense(16), stax.Relu, stax.Dense(4), stax.Relu]
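# Illustrative wiring of the shared trunk above into the policy and value heads
# via functools.partial, mirroring the pattern in test_training_loop; nothing
# here is prescribed by the functions themselves. Calling common_stax_layers()
# twice yields two independent lists, one per network.
import functools

policy_net_fun = functools.partial(policy_net,
                                   bottom_layers=common_stax_layers())
value_net_fun = functools.partial(value_net,
                                  bottom_layers=common_stax_layers())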