def Lambda(fn):  # pylint: disable=invalid-name
  """Turn a normal function into a bound, callable Stax layer.

  Args:
    fn: a python function with named args (i.e. no *args) and no kwargs.

  Returns:
    A callable, 'bound' staxlayer that can be assigned to a python variable
    and called like a function with other staxlayers as arguments. Like Bind,
    wherever this value is placed in the stax tree, it will always output the
    same cached value.
  """
  # fn's args are just symbolic names that we fill with Vars.
  # NOTE: inspect.getargspec was removed in Python 3.11; getfullargspec is the
  # supported replacement and exposes the same .args list for plain functions.
  num_args = len(inspect.getfullargspec(fn).args)
  if num_args > 1:
    bound_args = Vars(num_args)
    return LambdaBind(stax.serial(
        stax.parallel(*bound_args),  # capture inputs
        _PlaceholderInputs,  # placeholders for input combinators inside fn
        fn(*bound_args)  # feed captured inputs into fn's args
    ))
  elif num_args == 1:
    bound_arg = Var()
    return LambdaBind(stax.serial(
        bound_arg,  # capture input
        _PlaceholderInputs,  # placeholders for input combinators inside fn
        fn(bound_arg)  # feed captured inputs into fn's args
    ))
  # LambdaBind when no args are given:
  else:
    return LambdaBind(fn())
def initialize_policy_and_value_nets(rng_key, num_actions,
                                     batch_observations_shape):
  """Setup and initialize the policy and value networks."""
  policy_key, value_key = jax_random.split(rng_key)

  # Policy network: small MLP ending in a softmax over the action space.
  policy_net_init, policy_net_apply = stax.serial(
      stax.Dense(16), stax.Relu,
      stax.Dense(4), stax.Relu,
      stax.Dense(num_actions), stax.Softmax)
  _, policy_net_params = policy_net_init(policy_key, batch_observations_shape)

  # Value network: same trunk shape, single scalar output since we predict
  # the reward with it.
  value_net_init, value_net_apply = stax.serial(
      stax.Dense(16), stax.Relu,
      stax.Dense(4), stax.Relu,
      stax.Dense(1))
  _, value_net_params = value_net_init(value_key, batch_observations_shape)

  return ((policy_net_params, policy_net_apply),
          (value_net_params, value_net_apply))
def __call__(self, *args): if len(args) > 1: return stax.serial(stax.parallel(*args), self) elif len(args) == 1: return stax.serial(args[0], self) else: return self
def residual(*layers, res=None):
  """Constructs a residual version of layers, summing input to layers output.

  Args:
    *layers: one or more stax layers forming the main branch.
    res: optional stax layer applied on the residual (skip) branch; defaults
      to stax.Identity. Declared keyword-only so misspelled keyword
      arguments raise a TypeError instead of being silently ignored by
      the previous **kwargs signature.

  Returns:
    A stax layer computing res(x) + serial(*layers)(x).

  Raises:
    ValueError: if no layers are given.
  """
  if not layers:
    raise ValueError('Empty residual combinator.')
  if res is None:
    # Resolve the default at call time, matching the old kwargs.get behavior.
    res = stax.Identity
  main = layers[0] if len(layers) == 1 else stax.serial(*layers)
  return stax.serial(stax.FanOut(2), stax.parallel(main, res), stax.FanInSum)
def MultiHeadedAttention(  # pylint: disable=invalid-name
    feature_depth, num_heads=8, dropout=1.0, mode='train'):
  """Transformer-style multi-headed attention.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate - keep probability
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  # All projections share the same shape and initializer.
  def projection():
    return stax.Dense(feature_depth, W_init=xavier_uniform())

  return stax.serial(
      # Q, K, V projections; the mask passes through untouched.
      stax.parallel(projection(), projection(), projection(), stax.Identity),
      PureMultiHeadedAttention(feature_depth, num_heads=num_heads,
                               dropout=dropout, mode=mode),
      projection(),  # final output projection
  )
def policy_and_value_net(rng_key, batch_observations_shape, num_actions,
                         bottom_layers=None):
  """A policy and value net function."""
  # Two heads over the shared trunk: action probabilities and a scalar value.
  policy_head = stax.serial(stax.Dense(num_actions), stax.Softmax)
  value_head = stax.Dense(1)

  layers = list(bottom_layers) if bottom_layers is not None else []
  layers += [stax.FanOut(2), stax.parallel(policy_head, value_head)]

  net_init, net_apply = stax.serial(*layers)
  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
def policy_net(rng_key, batch_observations_shape, num_actions,
               bottom_layers=None):
  """A policy net function.

  Uses bottom_layers as the bottom part of the network and adds the
  required action-probability layers on top of it.
  """
  # Build a fresh list: the previous implementation extended the caller's
  # bottom_layers in place, mutating a shared argument across calls.
  layers = list(bottom_layers) if bottom_layers is not None else []
  layers.extend([stax.Dense(num_actions), stax.Softmax])
  net_init, net_apply = stax.serial(*layers)
  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
def value_net(rng_key, batch_observations_shape, num_actions,
              bottom_layers=None):
  """A value net function."""
  del num_actions  # unused: the value head always outputs a single scalar
  # Build a fresh list: the previous implementation extended the caller's
  # bottom_layers in place, mutating a shared argument across calls.
  layers = list(bottom_layers) if bottom_layers is not None else []
  layers.append(stax.Dense(1))  # 1 output to predict the reward/value
  net_init, net_apply = stax.serial(*layers)
  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
def repeat(layer, num_repeats):
  """Repeats layers serially num_repeats times."""
  if num_repeats < 1:
    raise ValueError('Repeat combinator num_repeats must be >= 1.')
  return stax.serial(*([layer] * num_repeats))