def init_NN(Q):
    layers = []
    num_layers = len(Q)
    for i in range(0, num_layers - 2):
        layers.append(
            Dense(Q[i + 1],
                  W_init=glorot_normal(dtype=np.float64),
                  b_init=normal(dtype=np.float64)))
        layers.append(Tanh)
    layers.append(
        Dense(Q[-1],
              W_init=glorot_normal(dtype=np.float64),
              b_init=normal(dtype=np.float64)))
    net_init, net_apply = stax.serial(*layers)
    return net_init, net_apply
def Dense(out_dim, W_init=glorot_normal(), b_init=glorot_normal()):
    """(Custom) Layer constructor function for a dense (fully-connected) layer."""
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2 = random.split(rng)
        # The line below is different from the original jax Dense: the bias is
        # initialised with the same (input_dim, out_dim) shape as the weights.
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(
            k2, (input_shape[-1], out_dim))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return jnp.dot(inputs, W) + b

    return init_fun, apply_fun
def ConditionedSqueezeExcitation(ratio=4,
                                 W_cond_init=glorot_normal(),
                                 W1_init=glorot_normal(),
                                 W2_init=glorot_normal(),
                                 name='unnamed'):
    # language=rst
    """
    Like squeeze excitation, but has an extra input to help form W.
    The purpose is to figure out which feature maps matter given a conditioner.

    :param ratio: How to reduce the number of channels for the FC layer
    """
    def init_fun(key, input_shape):
        (H, W, C), (K,) = input_shape
        k1, k2, k3 = random.split(key, 3)

        # Will be shrinking the conditioner down to the size of the number of channels
        W_cond = W_cond_init(k1, (C, K))

        # Going to be concatenating the conditioner
        C_concat = C + C
        assert C_concat % ratio == 0

        # Create the parameters for the squeeze and excite
        W1 = W1_init(k2, (C_concat//ratio, C_concat))
        W2 = W2_init(k3, (C, C_concat//ratio))

        output_shape = (H, W, C)
        params = (W_cond, W1, W2)
        state = ()
        return name, output_shape, params, state

    def apply_fun(params, state, inputs, **kwargs):
        W_cond, W1, W2 = params
        inputs, cond = inputs

        # Apply the SE transforms
        x = np.mean(inputs, axis=(-2, -3))
        x = np.concatenate([x, np.dot(cond, W_cond.T)], axis=-1)
        x = np.dot(x, W1.T)
        x = jax.nn.relu(x)
        x = np.dot(x, W2.T)
        x = jax.nn.sigmoid(x)

        # Scale the input.  After the channel-wise mean, x is 1-D for an
        # unbatched (H, W, C) input and 2-D for a batched (N, H, W, C) input.
        if(x.ndim == 1):
            out = inputs*x[None, None, :]
        else:
            out = inputs*x[:, None, None, :]
        return out, state

    return init_fun, apply_fun
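# A minimal usage sketch for ConditionedSqueezeExcitation above (an assumption,
# not part of the original source): the layer is initialised with a pair of
# shapes ((H, W, C), (K,)) and applied to a (feature map, conditioner) tuple.
# Assumes `import jax.numpy as np` and `from jax import random`, as in the layer.
init_fun, apply_fun = ConditionedSqueezeExcitation(ratio=4)
_, out_shape, params, state = init_fun(random.PRNGKey(0), ((8, 8, 8), (4,)))
feature_maps = np.ones((2, 8, 8, 8))   # batched (N, H, W, C) input
conditioner = np.ones((2, 4))          # batched (N, K) conditioner
out, state = apply_fun(params, state, (feature_maps, conditioner))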
def GeneralConvTranspose(dimension_numbers,
                         out_chan,
                         filter_shape,
                         strides=None,
                         padding='VALID',
                         W_init=None,
                         b_init=normal(1e-6)):
    """Layer construction function for a general transposed-convolution layer."""
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1, ) * len(filter_shape)
    strides = strides or one
    W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))

    def init_fun(rng, input_shape):
        filter_shape_iter = iter(filter_shape)
        kernel_shape = [
            out_chan if c == 'O' else
            input_shape[lhs_spec.index('C')] if c == 'I' else
            next(filter_shape_iter) for c in rhs_spec
        ]
        output_shape = lax.conv_transpose_shape_tuple(input_shape, kernel_shape,
                                                      strides, padding,
                                                      dimension_numbers)
        bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
        k1, k2 = random.split(rng)
        W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return lax.conv_transpose(
            inputs, W, strides, padding,
            dimension_numbers=dimension_numbers) + b

    return init_fun, apply_fun
def TallAffineDiagCov(flow, out_dim, n_training_importance_samples=32,
                      A_init=glorot_normal(), b_init=normal(), name='unnamed'):
    """ Affine function to go up a dimension

        Args:
            flow    - (init_fun, forward, inverse) triple of the flow applied in the lower-dimensional space
            out_dim - Dimensionality of the lower-dimensional (latent) space
            n_training_importance_samples - Number of importance samples used to estimate the integral
            A_init, b_init - Initializers for the affine map parameters
            name    - Layer name
    """
    _init_fun, _forward, _inverse = flow

    def init_fun(key, input_shape, condition_shape):
        x_shape = input_shape
        output_shape = x_shape[:-1] + (out_dim,)
        keys = random.split(key, 3)

        x_dim = x_shape[-1]
        z_dim = out_dim

        A = A_init(keys[0], (x_shape[-1], out_dim))
        b = b_init(keys[1], (x_shape[-1],))
        flow_name, flow_output_shape, flow_params, flow_state = _init_fun(keys[2], output_shape, condition_shape)
        log_diag_cov = jnp.ones(input_shape[-1])*0.0

        params = ((A, b, log_diag_cov), flow_params)
        state = ((), flow_state)
        return (name, flow_name), flow_output_shape, params, state

    def forward(params, state, log_px, x, condition, **kwargs):
        ((A, b, log_diag_cov), flow_params) = params
        _, flow_state = state

        # Get the terms to compute and sample from the posterior
        sigma = kwargs.get('sigma', 1.0)
        z, log_hx, sigma_ATA_chol = tall_affine_posterior_diag_cov(x, b, A, log_diag_cov, sigma)

        # Importance sample from N(z|\mu(x),\Sigma(x)) and compile the results
        log_pz, z, updated_flow_states = importance_sample_prior(
            _forward, flow_params, flow_state, z, condition, sigma_ATA_chol,
            n_training_importance_samples, **kwargs)

        # Compute the final estimate of the integral
        log_px = log_px + log_pz + log_hx
        return log_px, z, ((), updated_flow_states)

    def inverse(params, state, log_pz, z, condition, **kwargs):
        ((A, b, log_diag_cov), flow_params) = params
        _, flow_state = state

        log_pz, z, updated_state = _inverse(flow_params, flow_state, log_pz, z, condition, **kwargs)

        # Compute Az + b
        # Don't need to sample because we already sampled from p(z)!!!!
        x = jnp.dot(z, A.T) + b

        key = kwargs.pop('key', None)
        if(key is not None):
            sigma = kwargs.get('sigma', 1.0)
            noise = random.normal(key, x.shape)*sigma
            x += noise*jnp.exp(0.5*log_diag_cov)

        # Compute N(x|Az + b, \Sigma).  This is just the log partition function.
        log_px = - 0.5*jnp.sum(log_diag_cov) - 0.5*x.shape[-1]*jnp.log(2*jnp.pi)

        return log_pz + log_px, x, ((), updated_state)

    return init_fun, forward, inverse
def AAEmbedding(embedding_dims: int = 10, E_init=glorot_normal(), **kwargs):
    """
    Initial n-dimensional embedding of each amino-acid
    """
    def init_fun(rng, input_shape):
        """
        Generates the initial AA embedding matrix.

        `input_shape`: one-hot encoded AA sequence -> (n_aa, n_unique_aa)
        `output_dims`: embedded sequence -> (n_aa, embedding_dims)
        `emb_matrix`: embedding matrix -> (n_unique_aa, embedding_dims)
        """
        k1, _ = random.split(rng)
        emb_matrix = E_init(k1, (input_shape[1], embedding_dims))
        output_dims = (-1, embedding_dims)
        return output_dims, emb_matrix

    def apply_fun(params, inputs, **kwargs):
        """
        Embed a single AA sequence
        """
        emb_matrix = params
        # (n_aa, n_unique_aa) * (n_unique_aa, embedding_dims) => (n_aa, embedding_dims)  # noqa: E501
        return np.matmul(inputs, emb_matrix)

    return init_fun, apply_fun
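# A minimal usage sketch for AAEmbedding above (an assumption, not part of the
# original source): embed a one-hot encoded sequence of 5 residues over a
# 20-letter alphabet. Assumes `import jax.numpy as np` and `from jax import random`.
init_fun, apply_fun = AAEmbedding(embedding_dims=10)
output_dims, emb_matrix = init_fun(random.PRNGKey(0), (5, 20))
one_hot_seq = np.eye(20)[:5]                    # hypothetical (n_aa, n_unique_aa) input
embedded = apply_fun(emb_matrix, one_hot_seq)   # -> (5, 10)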
def LSTMCell(
    hidden_size,
    W_init=glorot_normal(),
    b_init=normal(),
    h_initial_state_fn=zeros,
    c_initial_state_fn=zeros,
    initial_state_seed=0,
):
    """Layer construction function for an LSTM cell.
    Formulation: Zaremba, W., 2015, https://arxiv.org/pdf/1409.2329.pdf"""
    def initial_state():
        shape = (hidden_size, )
        k1, k2 = jax.random.split(jax.random.PRNGKey(initial_state_seed))
        return LSTMState(h_initial_state_fn(k1, shape),
                         c_initial_state_fn(k2, shape))

    def init(rng, input_shape):
        in_dim, out_dim = input_shape[-1] + hidden_size, 4 * hidden_size
        output_shape = input_shape[:-1] + (hidden_size, )
        k1, k2 = jax.random.split(rng)
        W, b = W_init(k1, (in_dim, out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply(params, inputs, **kwargs):
        prev_state = kwargs.pop("prev_state", initial_state())
        W, b = params
        xh = jnp.concatenate([inputs, prev_state.h], axis=-1)
        gated = jnp.matmul(xh, W) + b
        i, f, o, g = jnp.split(gated, indices_or_sections=4, axis=-1)
        c = sigmoid(f) * prev_state.c + sigmoid(i) * jnp.tanh(g)
        h = sigmoid(o) * jnp.tanh(c)
        return h, LSTMState(h, c)

    return (init, apply, initial_state)
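# A minimal usage sketch for LSTMCell above (an assumption, not part of the
# original source): run one step of the cell. Assumes `LSTMState` is the
# namedtuple("LSTMState", ["h", "c"]) the cell returns, and that `sigmoid` and
# `zeros` come from jax.nn and jax.nn.initializers respectively.
init, apply, initial_state = LSTMCell(hidden_size=32)
_, params = init(jax.random.PRNGKey(0), (-1, 16))
state = initial_state()
x = jnp.ones((16,))                              # a single 16-dimensional input
h, state = apply(params, x, prev_state=state)    # h has shape (32,)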
def init_fun(rng, input_shape):
    rng, conv_rng, block_rng, serial_rng = jax.random.split(rng, num=4)

    # Primary convolutional layer.
    conv_shape, conv_params = conv_init(conv_rng, (-1, ) + input_shape)

    # Grouping all possible pairs.
    kernel_shape = [
        filter_shape[0], filter_shape[1], conv_channels, pair_channels
    ]
    bias_shape = [1, 1, 1, pair_channels]
    W_init = glorot_normal(in_axis=2, out_axis=3)
    b_init = normal(1e-6)
    k1, k2 = jax.random.split(rng)
    W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
    pair_shape = conv_shape[:2] + (15, ) + (pair_channels, )
    pair_params = (W, b)

    # Convolutional block.
    conv_block_shape, conv_block_params = conv_block_init(
        block_rng, pair_shape)

    # Forward pass.
    serial_shape, serial_params = serial_init(serial_rng, conv_block_shape)

    params = [conv_params, pair_params, conv_block_params, serial_params]
    return serial_shape, params
def GRU(
    hidden_size,
    W_init=glorot_normal(),
    b_init=normal(),
    initial_state_fn=zeros,
):
    return Rnn(GRUCell(hidden_size, W_init, b_init, initial_state_fn))
def DenseVMAP(out_dim, W_init=glorot_normal(), b_init=normal()):
    """Layer constructor function for a dense (fully-connected) layer."""
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2 = jax_random.split(rng)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return jnp.dot(inputs, W) + b

    apply_fun_vmap = vmap(apply_fun, (None, 0))
    return init_fun, apply_fun_vmap

#model_params = [
#    [Dense(25), LayerNorm(), Relu, Reshape((1, 5, 5, 1)),
#     ConvTranspose(16, (6, 6), padding='VALID'), LayerNormConv(), Relu,             # 10x10
#     ConvTranspose(8, (6, 6), padding='VALID'), LayerNormConv(), Relu,              # 15x15
#     ConvTranspose(1, (6, 6), padding='VALID'), LayerNormConv(), Reshape((400,))],  # 20x20
#    [Dense(25), LayerNorm(), Relu, Reshape((1, 5, 5, 1)),
#     Conv(16, (4, 4), padding='same'), LayerNormConv(), Relu,
#     Conv(8, (3, 3), padding='same'), LayerNormConv(), Relu,
#     Conv(1, (3, 3), padding='same'), LayerNormConv(), Reshape((25,)),              # 2 from Conv before
#     Dense(21)]
#]
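# A minimal usage sketch for DenseVMAP above (an assumption, not part of the
# original source): the returned apply function is vmapped over the leading
# axis of the inputs, so it expects a batch. Assumes `jax_random` is jax.random
# and `vmap` is jax.vmap, as the layer itself does.
init_fun, apply_fun = DenseVMAP(8)
_, params = init_fun(jax_random.PRNGKey(0), (-1, 4))
y = apply_fun(params, jnp.ones((16, 4)))   # mapped over the 16 batch rows -> (16, 8)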
def __init__(self, out_dim, kernel_init=glorot_normal(), bias_init=normal()):
    self.bias_init = bias_init
    self.kernel_init = kernel_init
    self.out_dim = out_dim
def init_GRU_params(rng, input_shape, W_init=glorot_normal(), b_init=normal()):
    """ Initialize the GRU layer """
    batch_size, hidden_dim, input_data_dim = input_shape  # input_data_dim = X, t

    # H0 = b_init(rng, (batch_size, hidden_dim))  # this is the H0 initial guess, that's why it depends on batch size
    # H0 = b_init(rng, (1, hidden_dim))           # this is the H0 initial guess, that's why it depends on batch size
    H0 = b_init(rng, (hidden_dim, ))

    # W takes the X data and U takes the previous hidden state,
    # then they are combined by adding together with the bias after the matrix dot.
    # Split the key into nine distinct subkeys so that the reset, update and
    # output gates do not all receive identical initial weights.
    keys = random.split(rng, num=9)

    reset_W, reset_U, reset_b = (
        W_init(keys[0], (input_data_dim, hidden_dim)),
        W_init(keys[1], (hidden_dim, hidden_dim)),
        b_init(keys[2], (hidden_dim, )),
    )

    update_W, update_U, update_b = (
        W_init(keys[3], (input_data_dim, hidden_dim)),
        W_init(keys[4], (hidden_dim, hidden_dim)),
        b_init(keys[5], (hidden_dim, )),
    )

    out_W, out_U, out_b = (
        W_init(keys[6], (input_data_dim, hidden_dim)),
        W_init(keys[7], (hidden_dim, hidden_dim)),
        b_init(keys[8], (hidden_dim, )),
    )

    GRU_params = ((update_W, update_U, update_b),
                  (reset_W, reset_U, reset_b),
                  (out_W, out_U, out_b))
    return H0, GRU_params
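# A hedged sketch of how the parameters produced by init_GRU_params are
# typically consumed in a single GRU step (this step function is an assumption,
# not part of the original source). W acts on the input x, U on the previous
# hidden state h, following the standard GRU update equations.
# Assumes `import jax` and `import jax.numpy as jnp`.
def GRU_step(params, h, x):
    (update_W, update_U, update_b), (reset_W, reset_U, reset_b), (out_W, out_U, out_b) = params
    z = jax.nn.sigmoid(jnp.dot(x, update_W) + jnp.dot(h, update_U) + update_b)   # update gate
    r = jax.nn.sigmoid(jnp.dot(x, reset_W) + jnp.dot(h, reset_U) + reset_b)      # reset gate
    h_tilde = jnp.tanh(jnp.dot(x, out_W) + jnp.dot(r * h, out_U) + out_b)        # candidate state
    return (1.0 - z) * h + z * h_tilde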
def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
                         strides=None, padding='VALID',
                         kernel_init=None, bias_init=normal(1e-6)):
    """Layer construction function for a general transposed-convolution layer."""
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1,) * len(filter_shape)
    strides = strides or one
    kernel_init = kernel_init or glorot_normal(rhs_spec.index('O'), rhs_spec.index('I'))

    @parametrized
    def conv_transpose(inputs):
        filter_shape_iter = iter(filter_shape)
        kernel_shape = [out_chan if c == 'O' else
                        inputs.shape[lhs_spec.index('C')] if c == 'I' else
                        next(filter_shape_iter) for c in rhs_spec]
        bias_shape = tuple(
            itertools.dropwhile(lambda x: x == 1,
                                [out_chan if c == 'C' else 1 for c in out_spec]))
        kernel = parameter(kernel_shape, kernel_init, 'kernel')
        bias = parameter(bias_shape, bias_init, 'bias')
        return lax.conv_transpose(inputs, kernel, strides, padding,
                                  dimension_numbers=dimension_numbers) + bias

    return conv_transpose
def DeepRNN(cell_type, hidden_dims, W_init=glorot_normal(), b_init=normal()):
    """Deep RNN cell, a wrapper for a stack of RNNs."""
    cells = [cell_type(h, W_init=W_init, b_init=b_init) for h in hidden_dims]

    def init(key, input_dim):
        keys = jax.random.split(key, num=len(cells))
        in_dims = [input_dim] + hidden_dims[:-1]
        params = []
        for cell, key, dim in zip(cells, keys, in_dims):
            params.append(cell.init(key, dim)[1])
        return [hidden_dims[-1]], params

    def apply(cells_params, inputs, prev_states, **kwargs):
        new_states = []
        for cell, prev_state, params in zip(cells, prev_states, cells_params):
            new_state, new_out = cell.apply(params, inputs, prev_state)
            new_states.append(new_state)
            inputs = new_out
        return new_states, new_out

    def initial_state():
        return [cell.initial_state() for cell in cells]

    return Module(init, apply, initial_state)
def MaskedDense(mask, bias=True, W_init=glorot_normal(), b_init=normal()):
    """
    As in jax.experimental.stax, each layer constructor function returns
    an (init_fun, apply_fun) pair, where `init_fun` takes an rng_key key and an input shape
    and returns an (output_shape, params) pair, and `apply_fun` takes params, inputs,
    and an rng_key key and applies the layer.

    :param array mask: Mask of shape (input_dim, out_dim) applied to the weights of the layer.
    :param bool bias: whether to include bias term.
    :param array W_init: initialization method for the weights.
    :param array b_init: initialization method for the bias terms.
    :return: a (`init_fun`, `apply_fun`) pair.
    """
    def init_fun(rng_key, input_shape):
        k1, k2 = random.split(rng_key)
        W = W_init(k1, mask.shape)
        if bias:
            b = b_init(k2, mask.shape[-1:])
            params = (W, b)
        else:
            params = W
        return input_shape[:-1] + mask.shape[-1:], params

    def apply_fun(params, inputs, **kwargs):
        if bias:
            W, b = params
            return jnp.dot(inputs, W * mask) + b
        else:
            W = params
            return jnp.dot(inputs, W * mask)

    return init_fun, apply_fun
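# A minimal usage sketch for MaskedDense above (an assumption, not part of the
# original source): a lower-triangular mask gives the layer an autoregressive
# weight structure. Assumes `import jax.numpy as jnp` and `from jax import random`.
mask = jnp.tril(jnp.ones((4, 4)))                         # hypothetical (input_dim, out_dim) mask
init_fun, apply_fun = MaskedDense(mask)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 4))
y = apply_fun(params, jnp.ones((8, 4)))                   # -> (8, 4)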
def FullCovarianceGaussian(conditioning_fn, event_dim, min_scale_diag=1e-4,
                           W_init=glorot_normal(), b_init=normal()):
    """A conditional Gaussian with full covariance matrix.

    The distribution mean and covariance are functions of the conditioning set.
    The covariance is parameterized as the matrix square of the scale, and the
    scale is parameterized as a lower triangular matrix with positive diagonal
    and unrestricted off-diagonal elements. The diagonal elements are ensured
    to be positive by exponentiating them.
    """
    def dist_fn(raw_params):
        loc = raw_params[:event_dim]
        raw_scale = raw_params[event_dim:]
        scale = unflatten_scale(raw_scale, event_dim, min_diag=min_scale_diag)
        cov = scale @ scale.T
        return tfd.MultivariateNormalFullCovariance(loc=loc, covariance_matrix=cov)

    param_dim = event_dim + int((event_dim * (event_dim + 1)) / 2)
    return ConditionalDistribution(conditioning_fn, dist_fn, event_dim, param_dim,
                                   W_init=W_init, b_init=b_init)
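# The snippet above relies on an external `unflatten_scale` helper that is not
# shown. A plausible sketch of what it does, based only on the docstring
# (lower-triangular scale, diagonal made positive by exponentiation) -- an
# assumption, not the original code. Assumes `import jax.numpy as jnp`.
def unflatten_scale(raw_scale, event_dim, min_diag=1e-4):
    # Fill the lower triangle of an (event_dim, event_dim) matrix from the flat
    # vector, then exponentiate the diagonal so it is strictly positive.
    tril = jnp.zeros((event_dim, event_dim))
    tril = tril.at[jnp.tril_indices(event_dim)].set(raw_scale)
    diag = jnp.exp(jnp.diag(tril)) + min_diag
    return jnp.tril(tril, k=-1) + jnp.diag(diag)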
def RNN(hidden_dim, W_init=glorot_normal(), b_init=normal(), activation=jax.nn.relu):
    """Recurrent Neural Network cell."""
    input_to_hidden = Linear(hidden_dim, W_init=W_init)
    hidden_to_hidden = Affine(hidden_dim, W_init=W_init, b_init=b_init)

    def init(key, input_dim):
        output_shape = hidden_dim
        k1, k2 = jax.random.split(key)
        _, input_to_hidden_params = input_to_hidden.init(k1, input_dim)
        _, hidden_to_hidden_params = hidden_to_hidden.init(k2, hidden_dim)
        return [hidden_dim], RNNParams(input_to_hidden_params, hidden_to_hidden_params)

    def apply(params, inputs, prev_state, **kwargs):
        new_hidden_raw = (
            input_to_hidden.apply(params.input_to_hidden, inputs) +
            hidden_to_hidden.apply(params.hidden_to_hidden, prev_state.hidden))
        new_hidden = activation(new_hidden_raw)
        new_state = RNNState(hidden=new_hidden)
        return new_state, new_hidden

    def initial_state():
        return RNNState(hidden=jnp.zeros([hidden_dim]))

    return Module(init, apply, initial_state)
def MLP(layer_dims, W_init=glorot_normal(), b_init=normal(),
        activation=jax.nn.relu, activate_final=False):
    """A multi-layered perceptron."""
    layers = []
    for dim in layer_dims[:-1]:
        layers.append(Dense(dim, W_init=W_init, b_init=b_init, activation=activation))
    if activate_final:
        layers.append(Dense(layer_dims[-1], W_init=W_init, b_init=b_init,
                            activation=activation))
    else:
        layers.append(Affine(layer_dims[-1], W_init=W_init, b_init=b_init))

    def init(key, input_dim):
        keys = jax.random.split(key, num=len(layer_dims))
        input_dims = [input_dim] + layer_dims[:-1]
        params = []
        for layer, key, in_dim in zip(layers, keys, input_dims):
            params.append(layer.init(key, in_dim)[1])
        return layer_dims[-1], MLPParams(params)

    def apply(params, inputs):
        for layer, param in zip(layers, params.layer_params):
            inputs = layer.apply(param, inputs)
        return inputs

    return Module(init, apply)
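# A minimal usage sketch for the MLP module above (an assumption, not part of
# the original source), using the Module(init, apply) interface seen elsewhere
# in this collection. Assumes `import jax` and `import jax.numpy as jnp`.
mlp = MLP([64, 64, 10])
out_dim, params = mlp.init(jax.random.PRNGKey(0), 32)   # 32-dimensional inputs
y = mlp.apply(params, jnp.ones(32))                     # -> shape (10,)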
def DenseEquivalent(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
    @parametrized
    def dense(inputs):
        kernel = Parameter(lambda key: kernel_init(key, (inputs.shape[-1], out_dim)))()
        bias = Parameter(lambda key: bias_init(key, (out_dim,)))()
        return np.dot(inputs, kernel) + bias

    return dense
def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
    @parametrized
    def dense(inputs):
        kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
        bias = parameter((out_dim,), bias_init)
        return np.dot(inputs, kernel) + bias

    return dense
def DensePurificationComplex(out_pure, out_mix, use_hidden_bias=True,
                             W_init=glorot_normal(), b_init=normal()):
    """Layer constructor function for a complex purification layer."""
    def init_fun(rng, input_shape):
        assert input_shape[-1] % 2 == 0
        output_shape = input_shape[:-1] + (2 * out_pure + out_mix, )

        k = jax.random.split(rng, 7)
        input_size = input_shape[-1] // 2

        # Weights for the pure part
        Wr, Wi = (
            W_init(k[0], (input_size, out_pure)),
            W_init(k[1], (input_size, out_pure)),
        )

        # Weights for the mixing part
        Vr, Vi = (
            W_init(k[2], (input_size, out_mix)),
            W_init(k[3], (input_size, out_mix)),
        )

        if use_hidden_bias:
            br, bi = (b_init(k[4], (out_pure, )), b_init(k[5], (out_pure, )))
            cr = b_init(k[6], (out_mix, ))
            return output_shape, (Wr, Wi, Vr, Vi, br, bi, cr)
        else:
            return output_shape, (Wr, Wi, Vr, Vi)

    @jax.jit
    def apply_fun(params, inputs, **kwargs):
        if use_hidden_bias:
            Wr, Wi, Vr, Vi, br, bi, cr = params
        else:
            Wr, Wi, Vr, Vi = params

        xr, xc = jax.numpy.split(inputs, 2, axis=-1)

        thetar = jax.numpy.dot(xr[:, ], (Wr + 1.0j * Wi))
        thetac = jax.numpy.dot(xc[:, ], (Wr - 1.0j * Wi))
        thetam = jax.numpy.dot(xr[:, ], (Vr + 1.0j * Vi))
        thetam += jax.numpy.dot(xc[:, ], (Vr - 1.0j * Vi))

        if use_hidden_bias:
            thetar += br + 1.0j * bi
            thetac += br - 1.0j * bi
            thetam += 2 * cr

        return jax.numpy.hstack((thetar, thetam, thetac))

    return init_fun, apply_fun
def init_param(rng, input_units, feature_size, label_size, label_units, hidden_units):
    init = glorot_normal()
    k1, k2, k3, k4, k5 = npr.split(rng, num=5)
    A_1 = init_dense(k1, (input_units, hidden_units))
    A_2 = init_dense(k2, (hidden_units, feature_size))
    B = init(k3, (feature_size, label_size))
    C_1 = init_dense(k4, (label_size, label_units))
    c_2 = init(k5, (label_units, 1))
    return Param(A_1, A_2, B, C_1, c_2)
def Jastrow(W_init=glorot_normal()):
    def init_fun(rng, input_shape):
        N = input_shape[-1]
        return input_shape[:-1], W_init(rng, (N, N))

    def apply_fun(W, x, **kwargs):
        return jax.vmap(lambda W, x: jax.numpy.einsum("i,ij,j", x, W, x),
                        in_axes=(None, 0))(W, x)

    return init_fun, apply_fun
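# A minimal usage sketch for the Jastrow layer above (an assumption, not part
# of the original source): for each configuration x in the batch it computes
# the quadratic form x^T W x via the einsum. Assumes `import jax` and `import jax.numpy`.
init_fun, apply_fun = Jastrow()
out_shape, W = init_fun(jax.random.PRNGKey(0), (-1, 6))
values = apply_fun(W, jax.numpy.ones((4, 6)))   # one scalar per configuration -> (4,)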
def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()): """Layer constructor function for a dense (fully-connected) layer.""" @parametrized def dense(inputs): kernel = parameter((inputs.shape[-1], out_dim), kernel_init, name='kernel') bias = parameter((out_dim,), bias_init, name='bias') return np.dot(inputs, kernel) + bias return dense
def Dense(out_dim, W_init=glorot_normal(), b_init=normal()): """Layer constructor function for a dense (fully-connected) layer.""" def init_fun(rng, input_shape): output_shape = input_shape[:-1] + (out_dim,) k1, k2 = random.split(rng) W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,)) return output_shape, (W, b) def apply_fun(params, inputs, **kwargs): W, b = params return np.dot(inputs, W) + b return init_fun, apply_fun
def Linear(out_dim, W_init=glorot_normal()):
    """A Linear layer (no bias)."""
    def init(key, input_dim):
        W = W_init(key, (input_dim, out_dim))
        return out_dim, LinearParams(W)

    def apply(params, inputs):
        return jnp.dot(inputs, params.W)

    return Module(init, apply)
def DenseNoBias(out_dim, W_init=glorot_normal()):
    """Layer constructor function for a dense (fully-connected) layer, but without any bias term."""
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        W = W_init(rng, (input_shape[-1], out_dim))
        return output_shape, W

    def apply_fun(W, inputs, **_kwargs):
        return inputs @ W

    return init_fun, apply_fun
def Affine(out_dim, W_init=glorot_normal(), b_init=normal()):
    """An affine layer."""
    def init(key, input_dim):
        k1, k2 = jax.random.split(key)
        W, b = W_init(k1, (input_dim, out_dim)), b_init(k2, (out_dim,))
        return out_dim, AffineParams(W, b)

    def apply(params, inputs):
        return jnp.dot(inputs, params.W) + params.b

    return Module(init, apply)
def Dense(out_dim, W_init=glorot_normal(), b_init=normal(), activation=jax.nn.relu):
    """A single-layer MLP (Affine layer with an activation)."""
    affine = Affine(out_dim, W_init=W_init, b_init=b_init)

    def init(key, input_dim):
        return affine.init(key, input_dim)

    def apply(params, inputs):
        return activation(affine.apply(params, inputs))

    return Module(init, apply)
def GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None,
                padding='VALID', W_init=None, b_init=normal(1e-6), bias=True):
    """Layer construction function for a general convolution layer."""
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1, ) * len(filter_shape)
    strides = strides or one
    W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))

    def init_fun(rng, input_shape):
        filter_shape_iter = iter(filter_shape)
        kernel_shape = [
            out_chan if c == 'O' else
            input_shape[lhs_spec.index('C')] if c == 'I' else
            next(filter_shape_iter) for c in rhs_spec
        ]
        output_shape = lax.conv_general_shape_tuple(input_shape, kernel_shape,
                                                    strides, padding,
                                                    dimension_numbers)
        bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
        bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
        k1, k2 = random.split(rng)
        W = W_init(k1, kernel_shape)
        if bias:
            b = b_init(k2, bias_shape)
            return output_shape, (W, b)
        else:
            return output_shape, W

    def apply_fun(params, inputs, **kwargs):
        if bias:
            W, b = params
        else:
            W = params
        batchdim = True
        if inputs.ndim == 3:
            batchdim = False
            inputs = np.expand_dims(inputs, 0)
        out = lax.conv_general_dilated(inputs, W, strides, padding, one, one,
                                       dimension_numbers)
        out = out + b if bias else out
        if not batchdim:
            out = out[0]
        return out

    return init_fun, apply_fun
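# A minimal usage sketch for GeneralConv above (an assumption, not part of the
# original source), using NHWC/HWIO dimension numbers. The apply function also
# accepts an unbatched (H, W, C) input thanks to the batchdim handling above.
# Assumes `import jax.numpy as np` and `from jax import random`.
init_fun, apply_fun = GeneralConv(('NHWC', 'HWIO', 'NHWC'), out_chan=8,
                                  filter_shape=(3, 3), padding='SAME')
out_shape, params = init_fun(random.PRNGKey(0), (-1, 28, 28, 1))
y_batched = apply_fun(params, np.ones((2, 28, 28, 1)))   # -> (2, 28, 28, 8)
y_single = apply_fun(params, np.ones((28, 28, 1)))       # -> (28, 28, 8)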