def residual(*layers, **kwargs):
  """Constructs a residual version of layers, summing input to layers output."""
  res = kwargs.get('res', stax.Identity)
  if len(layers) > 1:
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(stax.serial(*layers), res),
        stax.FanInSum)
  elif len(layers) == 1:
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(layers[0], res),
        stax.FanInSum)
  else:
    raise ValueError('Empty residual combinator.')
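# Usage sketch (illustrative, not part of this module): builds y = x + f(x),
# assuming `stax` is jax.example_libraries.stax (jax.experimental.stax in
# older releases) and the residual branch preserves the feature depth so
# FanInSum can add the two streams.
#
#   from jax import random
#   init_fn, apply_fn = residual(stax.Dense(64), stax.Relu, stax.Dense(64))
#   out_shape, params = init_fn(random.PRNGKey(0), (-1, 64))
#   x = random.normal(random.PRNGKey(1), (8, 64))
#   y = apply_fn(params, x)  # y == x + Dense(Relu(Dense(x)))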
def Lambda(fn):  # pylint: disable=invalid-name
  """Turn a normal function into a bound, callable Stax layer.

  Args:
    fn: a python function with _named_ args (i.e. no *args) and no kwargs.

  Returns:
    A callable, 'bound' staxlayer that can be assigned to a python variable
    and called like a function with other staxlayers as arguments.  Like
    Bind, wherever this value is placed in the stax tree, it will always
    output the same cached value.
  """
  # fn's args are just symbolic names that we fill with Vars.
  # Note: getfullargspec replaces the deprecated getargspec (removed in
  # Python 3.11).
  num_args = len(inspect.getfullargspec(fn).args)
  if num_args > 1:
    bound_args = Vars(num_args)
    return LambdaBind(stax.serial(
        stax.parallel(*bound_args),  # capture inputs
        _PlaceholderInputs,  # placeholders for input combinators inside fn
        fn(*bound_args)  # feed captured inputs into fn's args
    ))
  elif num_args == 1:
    bound_arg = Var()
    return LambdaBind(stax.serial(
        bound_arg,  # capture input
        _PlaceholderInputs,  # placeholders for input combinators inside fn
        fn(bound_arg)  # feed captured inputs into fn's args
    ))
  else:  # LambdaBind when no args are given.
    return LambdaBind(fn())
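# Hypothetical usage sketch: `Var`, `Vars`, `LambdaBind` and
# `_PlaceholderInputs` are defined elsewhere in this module, so take this as
# illustrative only. The idea is that a plain Python function over
# staxlayers composes like a layer once wrapped:
#
#   ffn = Lambda(lambda x: stax.serial(x, stax.Dense(512), stax.Relu,
#                                      stax.Dense(512)))
#   model = stax.serial(stax.Dense(512), ffn, stax.Softmax)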
def __call__(self, *args):
  """Compose this bound layer with the given staxlayer arguments."""
  if len(args) > 1:
    return stax.serial(stax.parallel(*args), self)
  elif len(args) == 1:
    return stax.serial(args[0], self)
  else:
    return self
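# Hedged sketch of the call-composition this enables: for a bound layer `f`,
#   f(a, b)  ==  stax.serial(stax.parallel(a, b), f)
#   f(a)     ==  stax.serial(a, f)
# e.g. feeding two projections into a (hypothetical) two-argument Lambda:
#
#   add = Lambda(lambda x, y: stax.serial(stax.parallel(x, y), stax.FanInSum))
#   combined = add(stax.Dense(64), stax.Dense(64))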
def MultiHeadedAttention(  # pylint: disable=invalid-name
    feature_depth, num_heads=8, dropout=1.0, mode='train'):
  """Transformer-style multi-headed attention.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout keep probability (1.0 means no dropout)
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  return stax.serial(
      stax.parallel(stax.Dense(feature_depth, W_init=xavier_uniform()),
                    stax.Dense(feature_depth, W_init=xavier_uniform()),
                    stax.Dense(feature_depth, W_init=xavier_uniform()),
                    stax.Identity),
      PureMultiHeadedAttention(feature_depth, num_heads=num_heads,
                               dropout=dropout, mode=mode),
      stax.Dense(feature_depth, W_init=xavier_uniform()),
  )
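# Usage sketch (illustrative only; PureMultiHeadedAttention is defined
# elsewhere in this module). The parallel block expects a 4-tuple input
# (query, key, value, mask): q/k/v are projected with Dense while the mask
# passes through Identity untouched.
#
#   init_fn, apply_fn = MultiHeadedAttention(feature_depth=512, num_heads=8,
#                                            dropout=1.0, mode='train')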
def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         num_actions,
                         bottom_layers=None):
  """A policy and value net function."""

  # Layers.
  layers = []
  if bottom_layers is not None:
    layers.extend(bottom_layers)

  # On top of the features computed so far, one head computes action
  # probabilities and the other computes the value function.
  layers.extend([
      stax.FanOut(2),
      stax.parallel(
          stax.serial(stax.Dense(num_actions), stax.Softmax),
          stax.Dense(1))
  ])

  net_init, net_apply = stax.serial(*layers)
  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
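# Usage sketch (illustrative; the observation shape and bottom layers here
# are made up). `apply_fn` returns the two heads' outputs as a tuple:
#
#   import numpy as onp
#   from jax import random
#
#   params, apply_fn = policy_and_value_net(
#       rng_key=random.PRNGKey(0),
#       batch_observations_shape=(-1, 4),
#       num_actions=2,
#       bottom_layers=[stax.Dense(16), stax.Relu])
#   obs = onp.zeros((8, 4), dtype=onp.float32)
#   action_probs, values = apply_fn(params, obs)  # shapes (8, 2) and (8, 1)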