def __call__(self, query):
    """Computes relative position embedding logits.

    Arguments:
      query: [batch_size, heads, height, width, dim]

    Returns:
      output: [batch_size, heads, height, width, height, width]
    """
    _, _, H, W, _ = query.shape

    rel_pos_emb_w_shape = (2 * W - 1, self.head_ch)
    rel_pos_emb_w = self.param(
        'rel_pos_emb_w',
        initializers.normal(stddev=self.head_ch**-0.5),
        rel_pos_emb_w_shape)

    rel_pos_emb_h_shape = (2 * H - 1, self.head_ch)
    rel_pos_emb_h = self.param(
        'rel_pos_emb_h',
        initializers.normal(stddev=self.head_ch**-0.5),
        rel_pos_emb_h_shape)

    rel_logits_w = self._relative_logits_1d(query, rel_pos_emb_w)
    rel_logits_w = rearrange(rel_logits_w, 'b h H I W V -> b h H W I V')

    rel_logits_h = self._relative_logits_1d(
        rearrange(query, 'b h H W d -> b h W H d'), rel_pos_emb_h)
    rel_logits_h = rearrange(rel_logits_h, 'b h W V H I -> b h H W I V')

    out = rel_logits_h + rel_logits_w
    return out
def generate_data_01():
    batch_size = 8
    input_shape = (batch_size, 4)

    def synth_batches():
        while True:
            images = npr.rand(*input_shape).astype("float32")
            yield images

    batches = synth_batches()
    inputs = next(batches)

    init_func, predict_func = stax.serial(
        HomotopyDense(out_dim=4, W_init=glorot_uniform(), b_init=normal()),
        HomotopyDense(out_dim=1, W_init=glorot_uniform(), b_init=normal()),
        Sigmoid,
    )
    ae_shape, ae_params = init_func(random.PRNGKey(0), input_shape)
    # assert ae_shape == input_shape
    bparam = [np.array([0.0], dtype=np.float64)]
    logits = predict_func(ae_params, inputs, bparam=bparam[0],
                          activation_func=sigmoid)
    loss = np.mean(
        (np.subtract(logits, logits))) + l2_norm(ae_params) + l2_norm(bparam)
    return inputs, logits, ae_params, bparam, init_func, predict_func
class DeepViTConfig:
    num_classes: int = 1000
    depth: int = 32
    mlp_dim: int = 1224
    token_dim: int = 64
    emb_dim: int = 408
    num_heads: int = 12
    dim_head: int = 32
    shared_theta: bool = True
    activation_fn: ModuleDef = nn.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Any = jax.lax.Precision.DEFAULT
    kernel_init: Callable = initializers.xavier_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)
    posemb_init: Callable = initializers.normal(stddev=0.02)
def init_NN(Q):
    layers = []
    num_layers = len(Q)
    for i in range(0, num_layers - 2):
        layers.append(
            Dense(Q[i + 1], W_init=glorot_normal(dtype=np.float64),
                  b_init=normal(dtype=np.float64)))
        layers.append(Tanh)
    layers.append(
        Dense(Q[-1], W_init=glorot_normal(dtype=np.float64),
              b_init=normal(dtype=np.float64)))
    net_init, net_apply = stax.serial(*layers)
    return net_init, net_apply
def vstate(request):
    N = 8
    hi = nk.hilbert.Spin(1 / 2, N)

    ma = nk.models.RBM(
        alpha=1,
        dtype=float,
        hidden_bias_init=normal(),
        visible_bias_init=normal(),
    )

    return nk.vqs.MCState(
        nk.sampler.MetropolisLocal(hi),
        ma,
    )
def FullCovarianceGaussian(conditioning_fn, event_dim, min_scale_diag=1e-4,
                           W_init=glorot_normal(), b_init=normal()):
    """A conditional Gaussian with full covariance matrix.

    The distribution mean and covariance are functions of the conditioning set.
    The covariance is parameterized as the matrix square of the scale, and the
    scale is parameterized as a lower triangular matrix with positive diagonal
    and unrestricted off-diagonal elements. The diagonal elements are ensured
    to be positive by exponentiating them.
    """

    def dist_fn(raw_params):
        loc = raw_params[:event_dim]
        raw_scale = raw_params[event_dim:]
        scale = unflatten_scale(raw_scale, event_dim, min_diag=min_scale_diag)
        cov = scale @ scale.T
        return tfd.MultivariateNormalFullCovariance(loc=loc,
                                                    covariance_matrix=cov)

    param_dim = event_dim + int((event_dim * (event_dim + 1)) / 2)
    return ConditionalDistribution(conditioning_fn, dist_fn, event_dim,
                                   param_dim, W_init=W_init, b_init=b_init)
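# unflatten_scale is not defined in this snippet. The following is a minimal
# sketch of what it plausibly does, assuming the raw vector packs the
# lower-triangular entries row by row and the diagonal is exponentiated (plus
# min_diag) to stay positive; the helper name and argument order are taken
# from the call site above, not from a source implementation.
import jax.numpy as jnp


def unflatten_scale_sketch(raw_scale, event_dim, min_diag=1e-4):
    # Fill the lower triangle (diagonal included) from the flat vector.
    scale = jnp.zeros((event_dim, event_dim))
    scale = scale.at[jnp.tril_indices(event_dim)].set(raw_scale)
    # Exponentiate the diagonal so the scale has a strictly positive diagonal.
    pos_diag = jnp.exp(jnp.diag(scale)) + min_diag
    return scale - jnp.diag(jnp.diag(scale)) + jnp.diag(pos_diag)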
def BiasRealModPhase(b_init=normal()):
    def init_fun(rng, input_shape):
        assert input_shape[-1] % 2 == 0
        input_size = input_shape[-1] // 2
        output_shape = input_shape[:-1]
        k = jax.random.split(rng, 2)
        br = b_init(k[0], (input_size, ))
        bj = b_init(k[1], (input_size, ))
        return output_shape, (br, bj)

    def apply_fun(params, inputs, **kwargs):
        br, bj = params
        xr, xc = jax.numpy.split(inputs, 2, axis=-1)
        biasr = jax.numpy.dot((xr + xc)[:, ], br)
        biasj = jax.numpy.dot((xr - xc)[:, ], bj)
        return 0.5 * biasr + 0.5j * biasj

    return init_fun, apply_fun
def GRU(
    hidden_size,
    W_init=glorot_normal(),
    b_init=normal(),
    initial_state_fn=zeros,
):
    return Rnn(GRUCell(hidden_size, W_init, b_init, initial_state_fn))
def init_fun(rng, input_shape):
    rng, conv_rng, block_rng, serial_rng = jax.random.split(rng, num=4)

    # Primary convolutional layer.
    conv_shape, conv_params = conv_init(conv_rng, (-1, ) + input_shape)

    # Grouping all possible pairs.
    kernel_shape = [
        filter_shape[0], filter_shape[1], conv_channels, pair_channels
    ]
    bias_shape = [1, 1, 1, pair_channels]
    W_init = glorot_normal(in_axis=2, out_axis=3)
    b_init = normal(1e-6)
    k1, k2 = jax.random.split(rng)
    W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
    pair_shape = conv_shape[:2] + (15, ) + (pair_channels, )
    pair_params = (W, b)

    # Convolutional block.
    conv_block_shape, conv_block_params = conv_block_init(block_rng, pair_shape)

    # Forward pass.
    serial_shape, serial_params = serial_init(serial_rng, conv_block_shape)

    params = [conv_params, pair_params, conv_block_params, serial_params]
    return serial_shape, params
def init_GRU_params(rng, input_shape, W_init=glorot_normal(), b_init=normal()):
    """Initialize the GRU layer parameters."""
    batch_size, hidden_dim, input_data_dim = input_shape  # input_data_dim = X, t
    # H0 = b_init(rng, (batch_size, hidden_dim))  # H0 initial guess; depends on batch size
    # H0 = b_init(rng, (1, hidden_dim))
    H0 = b_init(rng, (hidden_dim, ))  # initial hidden-state guess

    k1, k2, k3 = random.split(rng, num=3)
    # W acts on the input X and U on the previous hidden state; the two
    # products are combined by adding them together with the bias.
    reset_W, reset_U, reset_b = (
        W_init(k1, (input_data_dim, hidden_dim)),
        W_init(k2, (hidden_dim, hidden_dim)),
        b_init(k3, (hidden_dim, )),
    )
    k1, k2, k3 = random.split(rng, num=3)
    update_W, update_U, update_b = (
        W_init(k1, (input_data_dim, hidden_dim)),
        W_init(k2, (hidden_dim, hidden_dim)),
        b_init(k3, (hidden_dim, )),
    )
    k1, k2, k3 = random.split(rng, num=3)
    out_W, out_U, out_b = (
        W_init(k1, (input_data_dim, hidden_dim)),
        W_init(k2, (hidden_dim, hidden_dim)),
        b_init(k3, (hidden_dim, )),
    )
    GRU_params = ((update_W, update_U, update_b),
                  (reset_W, reset_U, reset_b),
                  (out_W, out_U, out_b))
    return H0, GRU_params
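# Hedged sketch (not from the source) of the GRU step these parameters imply,
# using the standard update/reset-gate convention: the gates mix the input x
# with the previous hidden state H via the W/U matrices, and the candidate
# state uses the reset-gated hidden state.
import jax.numpy as jnp
from jax.nn import sigmoid


def gru_step_sketch(GRU_params, H, x):
    ((update_W, update_U, update_b),
     (reset_W, reset_U, reset_b),
     (out_W, out_U, out_b)) = GRU_params
    z = sigmoid(jnp.dot(x, update_W) + jnp.dot(H, update_U) + update_b)
    r = sigmoid(jnp.dot(x, reset_W) + jnp.dot(H, reset_U) + reset_b)
    h_tilde = jnp.tanh(jnp.dot(x, out_W) + jnp.dot(r * H, out_U) + out_b)
    return (1.0 - z) * H + z * h_tilde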
def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
                         strides=None, padding='VALID', kernel_init=None,
                         bias_init=normal(1e-6)):
    """Layer construction function for a general transposed-convolution layer."""
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1,) * len(filter_shape)
    strides = strides or one
    kernel_init = kernel_init or glorot_normal(rhs_spec.index('O'),
                                               rhs_spec.index('I'))

    @parametrized
    def conv_transpose(inputs):
        filter_shape_iter = iter(filter_shape)
        kernel_shape = [out_chan if c == 'O' else
                        inputs.shape[lhs_spec.index('C')] if c == 'I' else
                        next(filter_shape_iter) for c in rhs_spec]
        bias_shape = tuple(
            itertools.dropwhile(lambda x: x == 1,
                                [out_chan if c == 'C' else 1 for c in out_spec]))
        kernel = parameter(kernel_shape, kernel_init, 'kernel')
        bias = parameter(bias_shape, bias_init, 'bias')
        return lax.conv_transpose(inputs, kernel, strides, padding,
                                  dimension_numbers=dimension_numbers) + bias

    return conv_transpose
class Encoder(nn.Module):
    num_layers: int
    inner_num_heads: int
    outer_num_heads: int
    inner_expand_ratio: float = 4
    outer_expand_ratio: float = 4
    attn_dropout_rate: float = 0.
    dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros
    pos_embed_init: Callable = initializers.normal(stddev=0.02)

    @nn.compact
    def __call__(self, patch_embeddings, pixel_embeddings, is_training: bool):
        for _ in range(self.num_layers):
            patch_embeddings, pixel_embeddings = EncoderBlock(
                inner_num_heads=self.inner_num_heads,
                outer_num_heads=self.outer_num_heads,
                attn_dropout_rate=self.attn_dropout_rate,
                dropout_rate=self.dropout_rate,
                activation_fn=self.activation_fn,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init)(patch_embeddings, pixel_embeddings)
        output = patch_embeddings
        return output
class Jastrow(nn.Module):
    r"""
    Jastrow wave function :math:`\Psi(s) = \exp(\sum_{ij} s_i W_{ij} s_j)`.

    The W matrix is stored as a non-symmetric matrix and is symmetrized during
    computation by doing :code:`W = W + W.T`.
    """

    dtype: DType = jnp.complex128
    """The dtype of the weights."""
    kernel_init: NNInitFunc = normal()
    """Initializer for the weights."""

    @nn.compact
    def __call__(self, x_in: Array):
        nv = x_in.shape[-1]

        dtype = jnp.promote_types(x_in.dtype, self.dtype)
        x_in = jnp.asarray(x_in, dtype=dtype)

        kernel = self.param("kernel", self.kernel_init, (nv, nv), self.dtype)
        kernel = kernel + kernel.T

        y = jnp.einsum("...i,ij,...j", x_in, kernel, x_in)
        return y
class Gaussian(nn.Module):
    r"""
    Multivariate Gaussian function with mean 0 and parametrised covariance
    matrix :math:`\Sigma_{ij}`.

    The wavefunction is given by the formula:
    :math:`\Psi(x) = \exp(\sum_{ij} x_i \Sigma_{ij} x_j)`.

    The (positive definite) :math:`\Sigma_{ij} = AA^T` matrix is stored as the
    (not necessarily positive definite) matrix A.
    """

    dtype: DType = jnp.float64
    """The dtype of the weights."""
    kernel_init: NNInitFunc = normal(stddev=1.0)
    """Initializer for the weights."""

    @nn.compact
    def __call__(self, x_in: Array):
        nv = x_in.shape[-1]

        dtype = jnp.promote_types(x_in.dtype, self.dtype)
        x_in = jnp.asarray(x_in, dtype=dtype)

        kernel = self.param("kernel", self.kernel_init, (nv, nv), self.dtype)
        kernel = jnp.dot(kernel.T, kernel)

        y = -0.5 * jnp.einsum("...i,ij,...j", x_in, kernel, x_in)
        return y
def LSTMCell(
    hidden_size,
    W_init=glorot_normal(),
    b_init=normal(),
    h_initial_state_fn=zeros,
    c_initial_state_fn=zeros,
    initial_state_seed=0,
):
    """Layer construction function for an LSTM cell.

    Formulation: Zaremba, W., 2015, https://arxiv.org/pdf/1409.2329.pdf
    """

    def initial_state():
        shape = (hidden_size, )
        k1, k2 = jax.random.split(jax.random.PRNGKey(initial_state_seed))
        return LSTMState(h_initial_state_fn(k1, shape),
                         c_initial_state_fn(k2, shape))

    def init(rng, input_shape):
        in_dim, out_dim = input_shape[-1] + hidden_size, 4 * hidden_size
        output_shape = input_shape[:-1] + (hidden_size, )
        k1, k2 = jax.random.split(rng)
        W, b = W_init(k1, (in_dim, out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply(params, inputs, **kwargs):
        prev_state = kwargs.pop("prev_state", initial_state())
        W, b = params
        xh = jnp.concatenate([inputs, prev_state.h], axis=-1)
        gated = jnp.matmul(xh, W) + b
        i, f, o, g = jnp.split(gated, indices_or_sections=4, axis=-1)
        c = sigmoid(f) * prev_state.c + sigmoid(i) * jnp.tanh(g)
        h = sigmoid(o) * jnp.tanh(c)
        return h, LSTMState(h, c)

    return (init, apply, initial_state)
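# Hedged usage sketch (not from the source): initialize the cell on a single
# 16-dimensional input and unroll it for one step. Assumes LSTMState and the
# initializers referenced above are importable from the surrounding module.
import jax
import jax.numpy as jnp

init, apply, initial_state = LSTMCell(hidden_size=32)
out_shape, params = init(jax.random.PRNGKey(0), input_shape=(16,))
h, state = apply(params, jnp.ones((16,)), prev_state=initial_state())
# out_shape == (32,); state.h and state.c both have shape (32,)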
def __init__(self, n_layers, n_hidden):
    """For simplicity, have everything have the same dimension."""
    super().__init__()
    self.cells = ModuleTuple(
        [LSTMCell(n_hidden, n_hidden) for _ in range(n_layers)])
    self.c_0s = ParameterTuple(
        [ParamInit((n_hidden, ), init.normal()) for _ in range(n_layers)])
def test_dense_is_dense_general(self):
    x = jax.random.normal(random.PRNGKey(0), (5, 3))
    dense_module = nn.Dense.partial(
        features=4,
        bias=True,
        bias_init=initializers.normal(),
    )
    y1, _ = dense_module.init(random.PRNGKey(1), x)
    dg_module = nn.DenseGeneral.partial(
        features=4,
        bias=True,
        bias_init=initializers.normal(),
    )
    y2, _ = dg_module.init(random.PRNGKey(1), x)
    onp.testing.assert_allclose(y1, y2)
def DeepRNN(cell_type, hidden_dims, W_init=glorot_normal(), b_init=normal()):
    """Deep RNN cell, a wrapper for a stack of RNNs."""
    cells = [cell_type(h, W_init=W_init, b_init=b_init) for h in hidden_dims]

    def init(key, input_dim):
        keys = jax.random.split(key, num=len(cells))
        in_dims = [input_dim] + hidden_dims[:-1]
        params = []
        for cell, key, dim in zip(cells, keys, in_dims):
            params.append(cell.init(key, dim)[1])
        return [hidden_dims[-1]], params

    def apply(cells_params, inputs, prev_states, **kwargs):
        new_states = []
        for cell, prev_state, params in zip(cells, prev_states, cells_params):
            new_state, new_out = cell.apply(params, inputs, prev_state)
            new_states.append(new_state)
            inputs = new_out
        return new_states, new_out

    def initial_state():
        return [cell.initial_state() for cell in cells]

    return Module(init, apply, initial_state)
def MaskedDense(mask, bias=True, W_init=glorot_normal(), b_init=normal()):
    """
    As in jax.experimental.stax, each layer constructor function returns an
    (init_fun, apply_fun) pair, where `init_fun` takes an rng_key and an input
    shape and returns an (output_shape, params) pair, and `apply_fun` takes
    params, inputs, and an rng_key and applies the layer.

    :param array mask: Mask of shape (input_dim, out_dim) applied to the
        weights of the layer.
    :param bool bias: whether to include a bias term.
    :param array W_init: initialization method for the weights.
    :param array b_init: initialization method for the bias terms.
    :return: an (`init_fun`, `apply_fun`) pair.
    """

    def init_fun(rng_key, input_shape):
        k1, k2 = random.split(rng_key)
        W = W_init(k1, mask.shape)
        if bias:
            b = b_init(k2, mask.shape[-1:])
            params = (W, b)
        else:
            params = W
        return input_shape[:-1] + mask.shape[-1:], params

    def apply_fun(params, inputs, **kwargs):
        if bias:
            W, b = params
            return jnp.dot(inputs, W * mask) + b
        else:
            W = params
            return jnp.dot(inputs, W * mask)

    return init_fun, apply_fun
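# Hedged usage sketch (not from the source): with a lower-triangular mask the
# masked weight W * mask zeroes every connection from input i to output j with
# i < j, so output unit j only depends on input features with index >= j.
import jax.numpy as jnp
from jax import random

mask = jnp.tril(jnp.ones((3, 3)))
init_fun, apply_fun = MaskedDense(mask)
out_shape, params = init_fun(random.PRNGKey(0), input_shape=(10, 3))
out = apply_fun(params, jnp.ones((10, 3)))  # out.shape == (10, 3)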
def vstate(request):
    M = request.param

    # keep this a prime number so we get different sizes on every rank...
    hi = nk.hilbert.Fock(M, 1)

    ma = nk.models.RBM(
        alpha=1,
        dtype=float,
        hidden_bias_init=normal(),
        visible_bias_init=normal(),
    )

    return nk.vqs.MCState(
        nk.sampler.MetropolisLocal(hi),
        ma,
    )
def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
                         strides=None, padding='VALID', W_init=None,
                         b_init=normal(1e-6)):
    """Layer construction function for a general transposed-convolution layer."""
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1, ) * len(filter_shape)
    strides = strides or one
    W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))

    def init_fun(rng, input_shape):
        filter_shape_iter = iter(filter_shape)
        kernel_shape = [
            out_chan if c == 'O' else
            input_shape[lhs_spec.index('C')] if c == 'I' else
            next(filter_shape_iter) for c in rhs_spec
        ]
        output_shape = lax.conv_transpose_shape_tuple(
            input_shape, kernel_shape, strides, padding, dimension_numbers)
        bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
        k1, k2 = random.split(rng)
        W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return lax.conv_transpose(
            inputs, W, strides, padding,
            dimension_numbers=dimension_numbers) + b

    return init_fun, apply_fun
def init_parameters(
    self, init_fun: Optional[NNInitFunc] = None, *, seed: Optional[PRNGKeyT] = None
):
    r"""
    Re-initializes all the parameters with the provided initialization
    function, defaulting to the normal distribution of standard deviation 0.01.

    .. warning::

        The init function will not change the dtype of the parameters, which is
        determined by the model. DO NOT SPECIFY IT INSIDE THE INIT FUNCTION

    Args:
        init_fun: a jax initializer such as :ref:`jax.nn.initializers.normal`.
            Must be a Callable taking 3 inputs, the jax PRNG key, the shape and
            the dtype, and outputting an array with the valid dtype and shape.
            If left unspecified, defaults to
            :code:`jax.nn.initializers.normal(stddev=0.01)`
        seed: Optional seed to be used. The seed is synced across all MPI
            processes. If unspecified, uses a random seed.
    """
    if init_fun is None:
        init_fun = normal(stddev=0.01)

    rng = nkjax.PRNGSeq(nkjax.PRNGKey(seed))

    def new_pars(par):
        return jnp.asarray(
            init_fun(rng.take(1)[0], shape=par.shape, dtype=par.dtype),
            dtype=par.dtype,
        )

    self.parameters = jax.tree_map(new_pars, self.parameters)
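# Hedged usage sketch (not from the source): re-draw the parameters of a
# variational state (e.g. an nk.vqs.MCState like the fixtures above, here
# assumed to be bound to the name vstate) with a wider normal distribution and
# a fixed seed, so the call is reproducible across MPI ranks.
from jax.nn.initializers import normal

vstate.init_parameters(normal(stddev=0.1), seed=1234)
# Parameter shapes and dtypes are unchanged; only the values are re-drawn.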
def RNN(hidden_dim, W_init=glorot_normal(), b_init=normal(),
        activation=jax.nn.relu):
    """Recurrent Neural Network cell."""
    input_to_hidden = Linear(hidden_dim, W_init=W_init)
    hidden_to_hidden = Affine(hidden_dim, W_init=W_init, b_init=b_init)

    def init(key, input_dim):
        output_shape = hidden_dim
        k1, k2 = jax.random.split(key)
        _, input_to_hidden_params = input_to_hidden.init(k1, input_dim)
        _, hidden_to_hidden_params = hidden_to_hidden.init(k2, hidden_dim)
        return [hidden_dim], RNNParams(input_to_hidden_params,
                                       hidden_to_hidden_params)

    def apply(params, inputs, prev_state, **kwargs):
        new_hidden_raw = (
            input_to_hidden.apply(params.input_to_hidden, inputs) +
            hidden_to_hidden.apply(params.hidden_to_hidden, prev_state.hidden))
        new_hidden = activation(new_hidden_raw)
        new_state = RNNState(hidden=new_hidden)
        return new_state, new_hidden

    def initial_state():
        return RNNState(hidden=jnp.zeros([hidden_dim]))

    return Module(init, apply, initial_state)
def __init__(self, out_dim, kernel_init=glorot_normal(), bias_init=normal()):
    self.bias_init = bias_init
    self.kernel_init = kernel_init
    self.out_dim = out_dim
def MLP(layer_dims, W_init=glorot_normal(), b_init=normal(),
        activation=jax.nn.relu, activate_final=False):
    """A multi-layered perceptron."""
    layers = []
    for dim in layer_dims[:-1]:
        layers.append(Dense(dim, W_init=W_init, b_init=b_init,
                            activation=activation))
    if activate_final:
        layers.append(Dense(layer_dims[-1], W_init=W_init, b_init=b_init,
                            activation=activation))
    else:
        layers.append(Affine(layer_dims[-1], W_init=W_init, b_init=b_init))

    def init(key, input_dim):
        keys = jax.random.split(key, num=len(layer_dims))
        input_dims = [input_dim] + layer_dims[:-1]
        params = []
        for layer, key, in_dim in zip(layers, keys, input_dims):
            params.append(layer.init(key, in_dim)[1])
        return layer_dims[-1], MLPParams(params)

    def apply(params, inputs):
        for layer, param in zip(layers, params.layer_params):
            inputs = layer.apply(param, inputs)
        return inputs

    return Module(init, apply)
def test_dense_is_dense_general(self):
    x = jax.random.normal(random.PRNGKey(0), (5, 3))
    dense_module = nn.Dense(
        features=4,
        use_bias=True,
        bias_init=initializers.normal(),
    )
    y1, _ = dense_module.init_with_output(dict(params=random.PRNGKey(1)), x)
    dg_module = nn.DenseGeneral(
        features=4,
        use_bias=True,
        bias_init=initializers.normal(),
    )
    y2, _ = dg_module.init_with_output(dict(params=random.PRNGKey(1)), x)
    np.testing.assert_allclose(y1, y2)
def ConcatSquashLinear(out_dim, W_init=he_normal(), b_init=normal()):
    """y = sigmoid(a*t + c) * (W*x + b) + d*t.

    Note: he_normal only accepts multi-dimensional shapes.
    """

    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2, k3, k4, k5 = random.split(rng, 5)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim, ))
        w_t, w_tb = b_init(k3, (out_dim, )), b_init(k4, (out_dim, ))
        b_t = b_init(k5, (out_dim, ))
        return output_shape, (W, b, w_t, w_tb, b_t)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        W, b, w_t, w_tb, b_t = params
        # W*x + b
        out = np.dot(x, W) + b
        # ... gated by sigmoid(a*t + c)
        out *= jax.nn.sigmoid(w_t * t + w_tb)
        # ... plus d*t
        out += b_t * t
        return (out, t)

    return init_fun, apply_fun
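# Hedged usage sketch (not from the source): the layer consumes and returns an
# (x, t) pair, which is how time-conditioned layers are usually chained inside
# the dynamics function of a continuous normalizing flow. The array alias used
# here is jnp; the module above appears to use np for jax.numpy.
import jax.numpy as jnp
from jax import random

init_fun, apply_fun = ConcatSquashLinear(out_dim=8)
out_shape, params = init_fun(random.PRNGKey(0), input_shape=(32, 4))
x, t = jnp.ones((32, 4)), 0.5
y, t_out = apply_fun(params, (x, t))  # y.shape == (32, 8); t is passed through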
def IgnoreConv2D(out_dim, W_init=he_normal(), b_init=normal(), kernel=3,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def apply_fun(params, inputs, **kwargs):
        # Convolve only the spatial input x and pass the time variable t
        # through untouched.
        x, t = inputs
        out = apply_fun_wrapped(params, x, **kwargs)
        return (out, t)

    # Return the tuple-aware apply_fun so (x, t) inputs are handled; returning
    # apply_fun_wrapped directly would leave the wrapper above unused.
    return init_fun_wrapped, apply_fun
def DenseVMAP(out_dim, W_init=glorot_normal(), b_init=normal()):
    """Layer constructor function for a dense (fully-connected) layer."""

    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2 = jax_random.split(rng)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return jnp.dot(inputs, W) + b

    apply_fun_vmap = vmap(apply_fun, (None, 0))
    return init_fun, apply_fun_vmap


# model_params = [
#     [Dense(25), LayerNorm(), Relu, Reshape((1, 5, 5, 1)),
#      ConvTranspose(16, (6, 6), padding='VALID'), LayerNormConv(), Relu,  # 10x10
#      ConvTranspose(8, (6, 6), padding='VALID'), LayerNormConv(), Relu,   # 15x15
#      ConvTranspose(1, (6, 6), padding='VALID'), LayerNormConv(), Reshape((400,))],  # 20x20
#     [Dense(25), LayerNorm(), Relu, Reshape((1, 5, 5, 1)),
#      Conv(16, (4, 4), padding='same'), LayerNormConv(), Relu,
#      Conv(8, (3, 3), padding='same'), LayerNormConv(), Relu,
#      Conv(1, (3, 3), padding='same'), LayerNormConv(), Reshape((25,)),  # 2 from Conv before
#      Dense(21)]
# ]
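# Hedged usage sketch for DenseVMAP (not from the source): apply_fun is
# vmapped over the leading axis of its inputs, so a batch of vectors is mapped
# row by row while sharing the same (W, b) parameters.
import jax.numpy as jnp
from jax import random

init_fun, apply_fun = DenseVMAP(out_dim=5)
out_shape, params = init_fun(random.PRNGKey(0), input_shape=(3,))
out = apply_fun(params, jnp.ones((64, 3)))  # out.shape == (64, 5)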
def DenseEquivalent(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
    @parametrized
    def dense(inputs):
        kernel = Parameter(
            lambda key: kernel_init(key, (inputs.shape[-1], out_dim)))()
        bias = Parameter(lambda key: bias_init(key, (out_dim,)))()
        return np.dot(inputs, kernel) + bias

    return dense