Example #1
from jax.experimental import stax  # jax.example_libraries.stax in newer JAX releases


def residual(*layers, **kwargs):
    """Constructs a residual version of layers, summing the input to the layers' output."""
    res = kwargs.get('res', stax.Identity)  # shortcut branch; defaults to the identity layer
    if len(layers) > 1:
        return stax.serial(stax.FanOut(2),
                           stax.parallel(stax.serial(*layers), res),
                           stax.FanInSum)
    elif len(layers) == 1:
        return stax.serial(stax.FanOut(2), stax.parallel(layers[0], res),
                           stax.FanInSum)
    else:
        raise ValueError('Empty residual combinator.')
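
A quick usage sketch, assuming only the public stax API; the layer widths, input shape, and variable names below are illustrative, not from the source project.

from jax import random
from jax.experimental import stax

# Residual block around a small Dense stack; input and output depth must match
# so that FanInSum can add the shortcut back in.
block_init, block_apply = residual(stax.Dense(64), stax.Relu, stax.Dense(64))

rng = random.PRNGKey(0)
out_shape, params = block_init(rng, (-1, 64))
x = random.normal(rng, (8, 64))
y = block_apply(params, x)  # y.shape == (8, 64)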
Example #2
def Lambda(fn):  # pylint: disable=invalid-name
  """Turn a normal function into a bound, callable Stax layer.

  Args:
    fn: a python function with _named_ args (i.e. no *args) and no kwargs.

  Returns:
    A callable, 'bound' staxlayer that can be assigned to a python variable and
    called like a function with other staxlayers as arguments.  Like Bind,
    wherever this value is placed in the stax tree, it will always output the
    same cached value.
  """
  # fn's args are just symbolic names that we fill with Vars.
  num_args = len(inspect.getfullargspec(fn).args)  # getargspec was removed in Python 3.11
  if num_args > 1:
    bound_args = Vars(num_args)
    return LambdaBind(stax.serial(
        stax.parallel(*bound_args),  # capture inputs
        _PlaceholderInputs,  # placeholders for input combinators inside fn
        fn(*bound_args)  # feed captured inputs into fn's args
    ))
  elif num_args == 1:
    bound_arg = Var()
    return LambdaBind(stax.serial(
        bound_arg,  # capture input
        _PlaceholderInputs,  # placeholders for input combinators inside fn
        fn(bound_arg)  # feed captured inputs into fn's args
    ))
  # LambdaBind when no args are given:
  else:
    return LambdaBind(fn())
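
Lambda leans on project-internal machinery (Var, LambdaBind, _PlaceholderInputs) that is not shown here, so it cannot run from this snippet alone. As a rough, self-contained illustration of the same idea, the hypothetical simple_lambda below drops the caching/Bind semantics and simply forwards the supplied layers into fn's named arguments:

import inspect

from jax.experimental import stax


def simple_lambda(fn):
  """Hypothetical, simplified stand-in: call fn with the layers supplied later."""
  num_args = len(inspect.getfullargspec(fn).args)

  def combinator(*layers):
    assert len(layers) == num_args, 'expected %d staxlayers' % num_args
    return fn(*layers)

  return combinator


# Build a reusable residual combinator from a plain function of one named arg.
Residual = simple_lambda(
    lambda body: stax.serial(stax.FanOut(2),
                             stax.parallel(body, stax.Identity),
                             stax.FanInSum))
block = Residual(stax.serial(stax.Dense(64), stax.Relu, stax.Dense(64)))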
Example #3
  def __call__(self, *args):
    if len(args) > 1:
      return stax.serial(stax.parallel(*args), self)
    elif len(args) == 1:
      return stax.serial(args[0], self)
    else:
      return self
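
The __call__ above is a method on a wrapper class that is not included in the excerpt. A minimal self-contained sketch of the same pattern, with a hypothetical BoundLayer standing in for that class:

import collections

from jax.experimental import stax


class BoundLayer(collections.namedtuple('BoundLayer', ['init_fun', 'apply_fun'])):
  """Acts as a normal (init_fun, apply_fun) stax pair, but is also callable:
  layers passed to it are run in parallel and their outputs piped into it."""

  def __call__(self, *args):
    if len(args) > 1:
      return stax.serial(stax.parallel(*args), self)
    elif len(args) == 1:
      return stax.serial(args[0], self)
    else:
      return self


# Example: a callable concatenation layer fed by two Dense branches.
Concat = BoundLayer(*stax.FanInConcat(axis=-1))
two_branch_block = Concat(stax.Dense(16), stax.Dense(16))  # expects a pair of inputs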
Example #4
def MultiHeadedAttention(  # pylint: disable=invalid-name
    feature_depth,
    num_heads=8,
    dropout=1.0,
    mode='train'):
  """Transformer-style multi-headed attention.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: keep probability for dropout (1.0 means no dropout)
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  return stax.serial(
      # Project queries, keys, and values in parallel; the fourth input
      # (e.g. an attention mask) passes through unchanged.
      stax.parallel(stax.Dense(feature_depth, W_init=xavier_uniform()),
                    stax.Dense(feature_depth, W_init=xavier_uniform()),
                    stax.Dense(feature_depth, W_init=xavier_uniform()),
                    stax.Identity),
      PureMultiHeadedAttention(feature_depth,
                               num_heads=num_heads,
                               dropout=dropout,
                               mode=mode),
      stax.Dense(feature_depth, W_init=xavier_uniform()),  # final output projection
  )
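
PureMultiHeadedAttention is internal to the source project, so the function above cannot run from this excerpt alone. The stax.parallel projection block it wraps can be exercised on its own; a sketch with illustrative shapes, where glorot_uniform is the jax.nn.initializers name for xavier_uniform:

import jax.numpy as jnp
from jax import random
from jax.experimental import stax
from jax.nn.initializers import glorot_uniform  # a.k.a. xavier_uniform

d = 128
init_fun, apply_fun = stax.parallel(
    stax.Dense(d, W_init=glorot_uniform()),  # query projection
    stax.Dense(d, W_init=glorot_uniform()),  # key projection
    stax.Dense(d, W_init=glorot_uniform()),  # value projection
    stax.Identity)                           # fourth input (e.g. a mask) passes through

rng = random.PRNGKey(0)
input_shapes = [(-1, 10, d), (-1, 10, d), (-1, 10, d), (-1, 1, 10)]
_, params = init_fun(rng, input_shapes)

x = random.normal(rng, (2, 10, d))
mask = jnp.ones((2, 1, 10))
q, k, v, m = apply_fun(params, (x, x, x, mask))  # q, k, v: (2, 10, 128)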
Example #5
from jax.experimental import stax  # jax.example_libraries.stax in newer JAX releases


def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         num_actions,
                         bottom_layers=None):
  """A policy and value net function."""

  # Layers.
  layers = []
  if bottom_layers is not None:
    layers.extend(bottom_layers)

  # From the representation produced above, one head computes action
  # probabilities and the other computes the value function.
  layers.extend([stax.FanOut(2), stax.parallel(
      stax.serial(stax.Dense(num_actions), stax.Softmax),
      stax.Dense(1)
  )])

  net_init, net_apply = stax.serial(*layers)

  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
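
A usage sketch with made-up sizes (flat 8-dimensional observations, 4 actions); the bottom layers here are illustrative:

from jax import random
from jax.experimental import stax

rng = random.PRNGKey(0)
params, apply_fun = policy_and_value_net(
    rng,
    batch_observations_shape=(-1, 8),
    num_actions=4,
    bottom_layers=[stax.Dense(32), stax.Relu])

observations = random.normal(rng, (16, 8))
probs, value = apply_fun(params, observations)  # probs: (16, 4), value: (16, 1)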