Example #1
def Lambda(fn):  # pylint: disable=invalid-name
  """Turn a normal function into a bound, callable Stax layer.

  Args:
    fn: a python function with _named_ args (i.e. no *args) and no kwargs.

  Returns:
    A callable, 'bound' staxlayer that can be assigned to a python variable and
    called like a function with other staxlayers as arguments.  Like Bind,
    wherever this value is placed in the stax tree, it will always output the
    same cached value.
  """
  # fn's args are just symbolic names that we fill with Vars.
  num_args = len(inspect.getfullargspec(fn).args)  # getargspec was removed in Python 3.11
  if num_args > 1:
    bound_args = Vars(num_args)
    return LambdaBind(stax.serial(
        stax.parallel(*bound_args),  # capture inputs
        _PlaceholderInputs,  # placeholders for input combinators inside fn
        fn(*bound_args)  # feed captured inputs into fn's args
    ))
  elif num_args == 1:
    bound_arg = Var()
    return LambdaBind(stax.serial(
        bound_arg,  # capture input
        _PlaceholderInputs,  # placeholders for input combinators inside fn
        fn(bound_arg)  # feed captured inputs into fn's args
    ))
  # LambdaBind when no args are given:
  else:
    return LambdaBind(fn())
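`Var`, `Vars`, `LambdaBind`, and `_PlaceholderInputs` are helpers from the surrounding module and are not shown here, so the following is only a hypothetical sketch of the calling convention the docstring describes:

# Hypothetical sketch: x is a symbolic argument that Lambda fills with a Var();
# the resulting bound layer caches its output wherever it appears in the tree.
Project = Lambda(lambda x: stax.serial(x, stax.Dense(64), stax.Relu))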
Example #2
def initialize_policy_and_value_nets(rng_key, num_actions,
                                     batch_observations_shape):
    """Setup and initialize the policy and value networks."""
    key1, key2 = jax_random.split(rng_key)

    policy_net_init, policy_net_apply = stax.serial(
        stax.Dense(16),
        stax.Relu,
        stax.Dense(4),
        stax.Relu,
        stax.Dense(num_actions),
        stax.Softmax,
    )

    _, policy_net_params = policy_net_init(key1, batch_observations_shape)

    value_net_init, value_net_apply = stax.serial(
        stax.Dense(16),
        stax.Relu,
        stax.Dense(4),
        stax.Relu,
        stax.Dense(1),  # Output size 1: the value network predicts a scalar reward.
    )

    _, value_net_params = value_net_init(key2, batch_observations_shape)

    return ((policy_net_params, policy_net_apply),
            (value_net_params, value_net_apply))
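A minimal usage sketch, assuming `stax` comes from `jax.example_libraries` (`jax.experimental.stax` in older JAX), flat 8-dimensional observations, and 2 actions; the -1 in the shape stands for the batch dimension:

import jax.numpy as jnp
from jax import random as jax_random

(policy_params, policy_apply), (value_params, value_apply) = (
    initialize_policy_and_value_nets(jax_random.PRNGKey(0), num_actions=2,
                                     batch_observations_shape=(-1, 8)))

observations = jnp.zeros((32, 8))
action_probs = policy_apply(policy_params, observations)  # shape (32, 2)
values = value_apply(value_params, observations)          # shape (32, 1)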
Example #3
def __call__(self, *args):
  if len(args) > 1:
    return stax.serial(stax.parallel(*args), self)
  elif len(args) == 1:
    return stax.serial(args[0], self)
  else:
    return self
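This `__call__` is what lets a bound layer be invoked with other stax layers as arguments: `layer(a, b)` becomes `serial(parallel(a, b), layer)`. A minimal self-contained sketch of the pattern; the `BoundLayer` wrapper below is hypothetical, standing in for the module's Bind class:

from jax.example_libraries import stax

class BoundLayer(tuple):
  """A (init_fun, apply_fun) pair that can be called on other layers."""

  def __call__(self, *args):
    if len(args) > 1:
      return stax.serial(stax.parallel(*args), self)
    elif len(args) == 1:
      return stax.serial(args[0], self)
    else:
      return self

dense = BoundLayer(stax.Dense(8))
block = dense(stax.Relu)  # equivalent to stax.serial(stax.Relu, dense)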
Example #4
def residual(*layers, **kwargs):
    """Constructs a residual version of layers, summing input to layers output."""
    res = kwargs.get('res', stax.Identity)
    if len(layers) > 1:
        return stax.serial(stax.FanOut(2),
                           stax.parallel(stax.serial(*layers), res),
                           stax.FanInSum)
    elif len(layers) == 1:
        return stax.serial(stax.FanOut(2), stax.parallel(layers[0], res),
                           stax.FanInSum)
    else:
        raise ValueError('Empty residual combinator.')
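For `stax.FanInSum` to work, the serial branch must preserve the input width. A minimal usage sketch, assuming `jax.example_libraries.stax` and 64-dimensional inputs:

import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

init_fun, apply_fun = residual(stax.Dense(64), stax.Relu, stax.Dense(64))
_, params = init_fun(random.PRNGKey(0), (-1, 64))
y = apply_fun(params, jnp.zeros((4, 64)))  # y = x + f(x), shape (4, 64)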
Example #5
def MultiHeadedAttention(  # pylint: disable=invalid-name
        feature_depth,
        num_heads=8,
        dropout=1.0,
        mode='train'):
    """Transformer-style multi-headed attention.

  Args:
    feature_depth: int:  depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate - keep probability
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
    return stax.serial(
        stax.parallel(stax.Dense(feature_depth, W_init=xavier_uniform()),
                      stax.Dense(feature_depth, W_init=xavier_uniform()),
                      stax.Dense(feature_depth, W_init=xavier_uniform()),
                      stax.Identity),
        PureMultiHeadedAttention(feature_depth,
                                 num_heads=num_heads,
                                 dropout=dropout,
                                 mode=mode),
        stax.Dense(feature_depth, W_init=xavier_uniform()),
    )
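Here `PureMultiHeadedAttention` and `xavier_uniform` come from the surrounding module (`xavier_uniform` matches `jax.nn.initializers.xavier_uniform`). The three parallel `Dense` layers presumably project the query, key, and value, while `stax.Identity` passes the fourth input (e.g. an attention mask) through unchanged. Assuming those names are in scope:

# Sketch only: builds a standard (init_fun, apply_fun) stax pair.
init_fun, apply_fun = MultiHeadedAttention(feature_depth=512, num_heads=8,
                                           dropout=1.0, mode='train')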
Example #6
def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         num_actions,
                         bottom_layers=None):
  """A policy and value net function."""

  # Layers.
  layers = []
  if bottom_layers is not None:
    layers.extend(bottom_layers)

  # From the shared representation, one head computes action probabilities and
  # the other computes the value function.
  layers.extend([stax.FanOut(2), stax.parallel(
      stax.serial(stax.Dense(num_actions), stax.Softmax),
      stax.Dense(1)
  )])

  net_init, net_apply = stax.serial(*layers)

  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
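A minimal usage sketch, again assuming `jax.example_libraries.stax`, 8-dimensional observations, and 2 actions; `stax.parallel` returns its heads' outputs as a list, so the result unpacks into probabilities and values:

import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

net_params, net_apply = policy_and_value_net(
    random.PRNGKey(0), (-1, 8), num_actions=2,
    bottom_layers=[stax.Dense(32), stax.Relu])
probs, values = net_apply(net_params, jnp.zeros((16, 8)))
# probs: (16, 2) action probabilities; values: (16, 1) state-value estimates.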
Example #7
def policy_net(rng_key,
               batch_observations_shape,
               num_actions,
               bottom_layers=None):
    """A policy net function."""
    # Use bottom_layers as the bottom part of the network and add the required
    # layers on top. Concatenate rather than extend() so the caller's list is
    # not mutated.
    if bottom_layers is None:
        bottom_layers = []
    bottom_layers = bottom_layers + [stax.Dense(num_actions), stax.Softmax]

    net_init, net_apply = stax.serial(*bottom_layers)

    _, net_params = net_init(rng_key, batch_observations_shape)
    return net_params, net_apply
Example #8
def value_net(rng_key,
              batch_observations_shape,
              num_actions,
              bottom_layers=None):
    """A value net function."""
    del num_actions  # Unused: the value head always outputs a single scalar.

    if bottom_layers is None:
        bottom_layers = []
    # Concatenate rather than extend() so the caller's list is not mutated.
    bottom_layers = bottom_layers + [stax.Dense(1)]

    net_init, net_apply = stax.serial(*bottom_layers)

    _, net_params = net_init(rng_key, batch_observations_shape)
    return net_params, net_apply
Example #9
def repeat(layer, num_repeats):
  """Repeats layers serially num_repeats times."""
  if num_repeats < 1:
    raise ValueError('Repeat combinator num_repeats must be >= 1.')
  layers = num_repeats * (layer,)
  return stax.serial(*layers)
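Note that `stax.serial` initializes every repetition separately, so the repeated blocks do not share parameters even though the same layer tuple is passed each time. For example:

from jax.example_libraries import stax

# Three Dense(64)+Relu blocks in sequence, each with its own parameters.
mlp_init, mlp_apply = repeat(stax.serial(stax.Dense(64), stax.Relu), 3)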