Example #1
def MultiHeadedAttention(  # pylint: disable=invalid-name
        feature_depth,
        num_heads=8,
        dropout=1.0,
        mode='train'):
    """Transformer-style multi-headed attention.

  Args:
    feature_depth: int:  depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate - keep probability
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
    return stax.serial(
        stax.parallel(stax.Dense(feature_depth, W_init=xavier_uniform()),
                      stax.Dense(feature_depth, W_init=xavier_uniform()),
                      stax.Dense(feature_depth, W_init=xavier_uniform()),
                      stax.Identity),
        PureMultiHeadedAttention(feature_depth,
                                 num_heads=num_heads,
                                 dropout=dropout,
                                 mode=mode),
        stax.Dense(feature_depth, W_init=xavier_uniform()),
    )
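
The PureMultiHeadedAttention layer used above is defined elsewhere in the repository and is not shown in these examples. As a rough illustration of what the three parallel Dense projections feed into, here is a minimal single-head scaled dot-product attention written in plain jax.numpy; the function name and shapes are illustrative assumptions, not the trax API.

import jax.numpy as jnp
from jax.nn import softmax

def scaled_dot_product_attention(queries, keys, values):
    # queries, keys, values: [batch, length, feature_depth] arrays, e.g. the
    # outputs of the three Dense(feature_depth) projections above.
    depth = queries.shape[-1]
    # Similarity of every query position to every key position, scaled by
    # sqrt(depth) to keep the softmax in a well-behaved range.
    scores = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(depth)
    weights = softmax(scores, axis=-1)   # attention weights over key positions
    return jnp.matmul(weights, values)   # weighted sum of the value vectors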
Example #2
def MultiHeadedAttention(feature_depth,
                         num_heads=8,
                         dropout=0.0,
                         mode='train'):
    """Transformer-style multi-headed attention.

  Args:
    feature_depth: int:  depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
    return combinators.Serial(
        combinators.Parallel(
            stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
            stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
            stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
            combinators.Identity()),
        PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
            feature_depth=feature_depth,
            num_heads=num_heads,
            dropout=dropout,
            mode=mode),
        stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
    )
Example #3
def initialize_policy_and_value_nets(rng_key, num_actions,
                                     batch_observations_shape):
    """Setup and initialize the policy and value networks."""
    key1, key2 = jax_random.split(rng_key)

    policy_net_init, policy_net_apply = stax.serial(
        stax.Dense(16),
        stax.Relu,
        stax.Dense(4),
        stax.Relu,
        stax.Dense(num_actions),
        stax.Softmax,
    )

    _, policy_net_params = policy_net_init(key1, batch_observations_shape)

    value_net_init, value_net_apply = stax.serial(
        stax.Dense(16),
        stax.Relu,
        stax.Dense(4),
        stax.Relu,
        stax.Dense(1),  # 1 output unit: the value network predicts a single scalar reward.
    )

    _, value_net_params = value_net_init(key2, batch_observations_shape)

    return ((policy_net_params, policy_net_apply),
            (value_net_params, value_net_apply))
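
A minimal usage sketch for the function above, assuming the stax in these examples follows the jax-style init/apply convention (init_fun(rng, input_shape) returns (output_shape, params); apply_fun(params, inputs) runs the net). The observation dimension, batch size, and action count are made-up values for illustration.

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
obs_dim, num_actions = 4, 2

(policy_net_params, policy_net_apply), (value_net_params, value_net_apply) = (
    initialize_policy_and_value_nets(rng, num_actions,
                                     batch_observations_shape=(-1, obs_dim)))

observations = jnp.zeros((8, obs_dim))
action_probs = policy_net_apply(policy_net_params, observations)  # (8, num_actions)
values = value_net_apply(value_net_params, observations)          # (8, 1)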
Example #4
def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         num_actions,
                         bottom_layers=None):
  """A policy and value net function."""

  # Layers.
  layers = []
  if bottom_layers is not None:
    layers.extend(bottom_layers)

  # Now, on top of the current representation, one head computes action
  # probabilities and the other computes the value function.
  layers.extend([
      stax.FanOut(2),
      stax.parallel(stax.serial(stax.Dense(num_actions), stax.Softmax),
                    stax.Dense(1))
  ])

  net_init, net_apply = stax.serial(*layers)

  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
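
Because the net ends in stax.FanOut(2) followed by stax.parallel, applying it returns a pair: action probabilities from the softmax head and a scalar value estimate from the Dense(1) head. A hedged sketch of how that pair might be unpacked, again assuming jax-style stax (jax.example_libraries.stax, formerly jax.experimental.stax) and made-up shapes:

import jax
import jax.numpy as jnp
from jax.example_libraries import stax  # jax.experimental.stax in older JAX

rng = jax.random.PRNGKey(0)
obs_dim, num_actions = 4, 2
net_params, net_apply = policy_and_value_net(
    rng,
    batch_observations_shape=(-1, obs_dim),
    num_actions=num_actions,
    bottom_layers=[stax.Dense(16), stax.Relu])

observations = jnp.zeros((8, obs_dim))
action_probs, value_preds = net_apply(net_params, observations)
# action_probs: (8, num_actions) with rows summing to 1; value_preds: (8, 1).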
Example #5
def policy_net(rng_key,
               batch_observations_shape,
               num_actions,
               bottom_layers=None):
    """A policy net function."""
    # Use the bottom_layers as the bottom part of the network and just add the
    # required layers on top of it.
    # Copy so we don't mutate the caller's bottom_layers list in place.
    layers = [] if bottom_layers is None else list(bottom_layers)
    layers.extend([stax.Dense(num_actions), stax.Softmax])

    net_init, net_apply = stax.serial(*layers)

    _, net_params = net_init(rng_key, batch_observations_shape)
    return net_params, net_apply
Example #6
def value_net(rng_key,
              batch_observations_shape,
              num_actions,
              bottom_layers=None):
    """A value net function."""
    del num_actions

    # Copy so we don't mutate the caller's bottom_layers list in place.
    layers = [] if bottom_layers is None else list(bottom_layers)
    # A single output unit: the value network predicts one scalar per observation.
    layers.append(stax.Dense(1))

    net_init, net_apply = stax.serial(*layers)

    _, net_params = net_init(rng_key, batch_observations_shape)
    return net_params, net_apply
Example #7
def test_training_loop(self):
  env = gym.make("CartPole-v0")
  # Gym envs are usually wrapped in a TimeLimit wrapper; remove it first.
  env = gym_utils.remove_time_limit_wrapper(env)
  # Re-wrap with a small step limit to keep the test fast.
  env = gym.wrappers.TimeLimit(env, max_episode_steps=2)
  num_epochs = 2
  batch_size = 2
  # Common bottom layer(s).
  bottom_layers = [stax.Dense(1)]
  # Run the training loop.
  _, rewards, val_losses, ppo_objectives = ppo.training_loop(
      env=env,
      epochs=num_epochs,
      policy_net_fun=functools.partial(
          ppo.policy_net, bottom_layers=bottom_layers),
      value_net_fun=functools.partial(
          ppo.value_net, bottom_layers=bottom_layers),
      batch_size=batch_size,
      num_optimizer_steps=1,
      random_seed=0)
  self.assertLen(rewards, num_epochs)
  self.assertLen(val_losses, num_epochs)
  self.assertLen(ppo_objectives, num_epochs)
Example #8
def common_stax_layers():
    return [stax.Dense(16), stax.Relu, stax.Dense(4), stax.Relu]
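
A short sketch tying the pieces together: the list returned by common_stax_layers() can be passed as bottom_layers to the policy_net and value_net constructors from the earlier examples. Call it once per net so each gets its own list object; the observation dimension and action count below are illustrative assumptions.

import jax

rng = jax.random.PRNGKey(0)
obs_dim, num_actions = 4, 2

policy_params, policy_apply = policy_net(
    rng, batch_observations_shape=(-1, obs_dim), num_actions=num_actions,
    bottom_layers=common_stax_layers())
value_params, value_apply = value_net(
    rng, batch_observations_shape=(-1, obs_dim), num_actions=num_actions,
    bottom_layers=common_stax_layers())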