def GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None,
                padding='VALID', W_gain=1.0, W_init=stax.randn(1.0),
                b_gain=0.0, b_init=stax.randn(1.0)):
    """Layer construction function for a general convolution layer.

    Uses jax.experimental.stax.GeneralConv as a base.
    """
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1, ) * len(filter_shape)
    strides = strides or one
    init_fun, _ = stax.GeneralConv(dimension_numbers, out_chan, filter_shape,
                                   strides, padding, W_init, b_init)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        norm = inputs.shape[lhs_spec.index('C')]
        norm *= functools.reduce(op.mul, filter_shape)
        norm = W_gain / np.sqrt(norm)
        return norm * lax.conv_general_dilated(inputs, W, strides, padding,
                                               one, one,
                                               dimension_numbers) + b_gain * b

    return init_fun, apply_fun
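# A minimal usage sketch for the GeneralConv layer above (illustrative, not from
# the original source). It follows the jax.experimental.stax convention in which
# init_fun takes an rng key and an input shape and apply_fun maps (params, inputs)
# to outputs; the ('NHWC', 'HWIO', 'NHWC') dimension numbers and the 28x28x1 input
# are assumptions for the example only.
import jax.numpy as np
from jax import random

init_fun, apply_fun = GeneralConv(('NHWC', 'HWIO', 'NHWC'),
                                  out_chan=32, filter_shape=(3, 3))
out_shape, params = init_fun(random.PRNGKey(0), (8, 28, 28, 1))
x = np.ones((8, 28, 28, 1))
y = apply_fun(params, x)  # shape (8, 26, 26, 32) with 'VALID' padding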
def decoder(hidden_dim: int, out_dim: int) -> Tuple[Callable, Callable]:
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )
def decoder(hidden_dim, out_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )
def encoder(hidden_dim, z_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)),
    )
def __init__(self, breakpoints):
    self.N_OUTPUTS = len(breakpoints) + 1
    self.breakpoints = np.hstack(([0], breakpoints, [np.inf]))
    self.terminal_layer = [
        stax.Dense(self.N_OUTPUTS,
                   W_init=stax.randn(1e-7),
                   b_init=stax.randn(1e-7)),
        stax.Exp,
    ]
def encoder(hidden_dim: int, z_dim: int) -> Tuple[Callable, Callable]:
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(
                stax.Dense(z_dim, W_init=stax.randn()),
                stax.Exp,
            ),
        ),
    )
def MaskedDense(mask, bias=True, W_init=glorot(), b_init=randn()):
    """
    As in jax.experimental.stax, each layer constructor function returns
    an (init_fun, apply_fun) pair, where `init_fun` takes an rng key and
    an input shape and returns an (output_shape, params) pair, and
    `apply_fun` takes params, inputs, and an rng key and applies the layer.

    :param array mask: Mask of shape (input_dim, out_dim) applied to the weights of the layer.
    :param bool bias: whether to include bias term.
    :param array W_init: initialization method for the weights.
    :param array b_init: initialization method for the bias terms.
    :return: a (`init_fn`, `update_fn`) pair.
    """
    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        W = W_init(k1, mask.shape)
        if bias:
            b = b_init(k2, mask.shape[-1:])
            params = (W, b)
        else:
            params = W
        return input_shape[:-1] + mask.shape[-1:], params

    def apply_fun(params, inputs, **kwargs):
        if bias:
            W, b = params
            return np.dot(inputs, W * mask) + b
        else:
            W = params
            return np.dot(inputs, W * mask)

    return init_fun, apply_fun
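# A minimal usage sketch for MaskedDense above (illustrative, not from the original
# source). The lower-triangular mask is just an example of a (input_dim, out_dim)
# mask in the autoregressive style; the layer follows the stax (init_fun, apply_fun)
# convention described in the docstring.
import jax.numpy as np
from jax import random

mask = np.tril(np.ones((4, 4)))                            # (input_dim, out_dim) mask
init_fun, apply_fun = MaskedDense(mask)
out_shape, params = init_fun(random.PRNGKey(0), (16, 4))   # batch of 16 inputs
y = apply_fun(params, np.ones((16, 4)))                    # shape (16, 4); masked weight
                                                           # entries are zeroed out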
def new(input_size, output_size, hidden_layers, key):
    _randn_fn = randn()

    def vector_init(shape):
        if isinstance(shape, int):
            shape = (shape, )
        nonlocal key
        key, rng = random.split(key)
        return _randn_fn(rng, shape)

    _glorot_fn = glorot()

    def matrix_init(shape):
        nonlocal key
        key, rng = random.split(key)
        return _glorot_fn(rng, shape)

    input_state = vector_init(input_size)
    hidden_states = []
    for size in hidden_layers:
        hidden_states.append(vector_init(size))
    output_states = vector_init(output_size)
    states = [input_state, *hidden_states, output_states]

    # weights
    fwd_weights, bwd_weights = [], []
    for prev, post in zip(states[:-1], states[1:]):
        fwd_weights.append(matrix_init((prev.shape[0], post.shape[0])))
        bwd_weights.append(matrix_init((post.shape[0], prev.shape[0])))

    return LayeredNet(states, [*fwd_weights, *bwd_weights])
def initialize(cls, key, in_spec, out_chan, filter_shape, strides=None,
               padding='VALID', kernel_init=None, bias_init=stax.randn(1e-6),
               use_bias=True):
    in_shape = in_spec.shape
    shapes, inits, (strides, padding, one) = conv_info(
        in_shape, out_chan, filter_shape, strides=strides, padding=padding,
        kernel_init=kernel_init, bias_init=bias_init)
    info = ConvInfo(strides, padding, one, use_bias)
    _, kernel_shape, bias_shape = shapes
    kernel_init, bias_init = inits
    k1, k2 = random.split(key)
    if use_bias:
        params = ConvParams(
            base.create_parameter(k1, kernel_shape, init=kernel_init),
            base.create_parameter(k2, bias_shape, init=bias_init),
        )
    else:
        params = ConvParams(
            base.create_parameter(k1, kernel_shape, init=kernel_init), None)
    return base.LayerParams(params, info=info)
def Dense(out_dim, W_gain=1.0, W_init=stax.randn(1.0), b_gain=0.0,
          b_init=stax.randn(1.0)):
    """Layer constructor function for a dense (fully-connected) layer.

    Uses jax.experimental.stax.Dense as a base.
    """
    init_fun, _ = stax.Dense(out_dim, W_init, b_init)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        norm = W_gain / np.sqrt(inputs.shape[-1])
        return norm * np.dot(inputs, W) + b_gain * b

    return init_fun, apply_fun
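# A minimal usage sketch for the gain-normalized Dense above (illustrative):
# stax.Dense supplies init_fun, while apply_fun rescales activations by
# W_gain / sqrt(fan_in), which resembles an NTK-style parameterization.
# The layer width, gain, and input sizes here are assumptions for the example.
import jax.numpy as np
from jax import random

init_fun, apply_fun = Dense(128, W_gain=2.0)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 784))
y = apply_fun(params, np.ones((32, 784)))   # shape (32, 128)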
def decoder(hidden_dim, out_dim):
    """Defines the decoder, i.e., the network taking us from latent
    variables back to observations (or at least observation space).

    Network structure:
    z -> dense layer of hidden_dim with softplus activation
      -> dense layer of out_dim with sigmoid activation

    :param hidden_dim: number of nodes in the hidden layer
    :param out_dim: dimensions of the observations
    :return: (init_fun, apply_fun) pair of the decoder: (decoder_init, decode)
    """
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )
def Dense(name, out_dim, W_init=stax.glorot(), b_init=stax.randn()):
    """Layer constructor function for a dense (fully-connected) layer."""
    def init_fun(rng, example_input):
        input_shape = example_input.shape
        k1, k2 = random.split(rng)
        W, b = W_init(k1, (out_dim, input_shape[-1])), b_init(k2, (out_dim, ))
        return W, b

    def apply_fun(params, inputs):
        W, b = params
        return np.dot(W, inputs) + b

    return core.Layer(name, init_fun, apply_fun).bind
def encoder(hidden_dim, z_dim):
    """Defines the encoder, i.e., the network taking us from observations
    to (a distribution of) latent variables. z is following a normal
    distribution, thus needs mean and variance.

    Network structure:
    x -> dense layer of hidden_dim with softplus activation
         --> dense layer of z_dim ( = means/loc of z)
         |-> dense layer of z_dim with (elementwise) exp() as activation func ( = variance of z )
    (note: the exp() as activation function serves solely to ensure positivity of the variance)

    :param hidden_dim: number of nodes in the hidden layer
    :param z_dim: dimension of the latent variable z
    :return: (init_fun, apply_fun) pair of the encoder: (encoder_init, encode)
    """
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)),
    )
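# A minimal VAE-style sketch tying the encoder/decoder pair above together
# (illustrative only, not from the original source). The sizes, the treatment of
# the Exp branch as a scale in the reparameterization step, and the data are all
# assumptions for the example.
import jax.numpy as np
from jax import random

enc_init, encode = encoder(hidden_dim=400, z_dim=20)
dec_init, decode = decoder(hidden_dim=400, out_dim=784)
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
_, enc_params = enc_init(k1, (-1, 784))
_, dec_params = dec_init(k2, (-1, 20))

x = np.ones((8, 784))
z_loc, z_scale = encode(enc_params, x)                  # second output is the Exp branch
z = z_loc + z_scale * random.normal(k3, z_loc.shape)    # reparameterization trick
x_recon = decode(dec_params, z)                         # values in (0, 1) via Sigmoid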
def spec(cls, in_spec, out_chan, filter_shape, strides=None, padding='VALID',
         kernel_init=None, bias_init=stax.randn(1e-6), use_bias=True):
    del use_bias
    in_shape = in_spec.shape
    shapes, _, _ = conv_info(
        in_shape, out_chan, filter_shape, strides=strides, padding=padding,
        kernel_init=kernel_init, bias_init=bias_init)
    return state.Shape(shapes[0], dtype=in_spec.dtype)
def conv_info(in_shape, out_chan, filter_shape, strides=None, padding='VALID',
              kernel_init=None, bias_init=stax.randn(1e-6), transpose=False):
    """Returns parameters and output shape information given input shapes."""
    # Essentially the `stax` implementation
    if len(in_shape) != 3:
        raise ValueError('Need to `jax.vmap` in order to batch')
    in_shape = (1, ) + in_shape
    lhs_spec, rhs_spec, out_spec = DIMENSION_NUMBERS
    one = (1, ) * len(filter_shape)
    strides = strides or one
    kernel_init = kernel_init or stax.glorot(rhs_spec.index('O'),
                                             rhs_spec.index('I'))
    filter_shape_iter = iter(filter_shape)
    kernel_shape = tuple([
        out_chan if c == 'O' else
        in_shape[lhs_spec.index('C')] if c == 'I' else
        next(filter_shape_iter) for c in rhs_spec
    ])
    if transpose:
        out_shape = lax.conv_transpose_shape_tuple(in_shape, kernel_shape,
                                                   strides, padding,
                                                   DIMENSION_NUMBERS)
    else:
        out_shape = lax.conv_general_shape_tuple(in_shape, kernel_shape,
                                                 strides, padding,
                                                 DIMENSION_NUMBERS)
    bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
    out_shape = out_shape[1:]
    shapes = (out_shape, kernel_shape, bias_shape)
    inits = (kernel_init, bias_init)
    return shapes, inits, (strides, padding, one)
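# A minimal sketch of how conv_info might be consumed (illustrative). It assumes
# DIMENSION_NUMBERS = ('NHWC', 'HWIO', 'NHWC') and an unbatched 28x28x1 input,
# since conv_info insists on rank-3 inputs and expects batching via jax.vmap.
shapes, inits, (strides, padding, one) = conv_info(
    in_shape=(28, 28, 1), out_chan=32, filter_shape=(3, 3))
out_shape, kernel_shape, bias_shape = shapes
# With the assumed dimension numbers:
#   out_shape    == (26, 26, 32)   (unbatched, 'VALID' padding)
#   kernel_shape == (3, 3, 1, 32)  ('HWIO')
#   bias_shape   == (32,)          (leading singleton dims dropped)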
def MaskedDense(out_dim, mask, W_init=glorot(), b_init=randn()):
    """
    As in jax.experimental.stax, each layer constructor function returns
    an (init_fun, apply_fun) pair, where `init_fun` takes an rng key and
    an input shape and returns an (output_shape, params) pair, and
    `apply_fun` takes params, inputs, and an rng key and applies the layer.

    :param int out_dim: Number of output dimensions.
    :param array mask: Mask applied to the weights of the layer.
    :param array W_init: initialization method for the weights.
    :param array b_init: initialization method for the bias terms.
    :return: a (`init_fn`, `update_fn`) pair.
    """
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2 = random.split(rng)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return np.dot(inputs, W * mask) + b

    return init_fun, apply_fun
def __init__(self, out_dim, W_init=stax.glorot(), b_init=stax.randn()):
    super(Dense, self).__init__()
    self.out_dim = out_dim
    self.W_init = W_init
    self.b_init = b_init
def testRandnInitShape(self, shape):
    key = random.PRNGKey(0)
    out = stax.randn()(key, shape)
    self.assertEqual(out.shape, shape)
def testRandnInitShape(self, shape):
    out = stax.randn()(shape)
    self.assertEqual(out.shape, shape)
def __init__(self):
    self.terminal_layer = [
        stax.Dense(self.N_OUTPUTS,
                   W_init=stax.randn(1e-10),
                   b_init=stax.randn(1e-10))
    ]