Example 1
def init_NN(Q):
    layers = []
    num_layers = len(Q)
    for i in range(0, num_layers - 2):
        layers.append(
            Dense(Q[i + 1],
                  W_init=glorot_normal(dtype=np.float64),
                  b_init=normal(dtype=np.float64)))
        layers.append(Tanh)
    layers.append(
        Dense(Q[-1],
              W_init=glorot_normal(dtype=np.float64),
              b_init=normal(dtype=np.float64)))
    net_init, net_apply = stax.serial(*layers)
    return net_init, net_apply
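A minimal usage sketch (not part of the original source), assuming Dense here behaves like jax.example_libraries.stax.Dense and that the layer sizes Q and input below are illustrative; the float64 initializers additionally require enabling x64 in jax:
import numpy as np
import jax.numpy as jnp
from jax import config, random
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Tanh
from jax.nn.initializers import glorot_normal, normal

config.update("jax_enable_x64", True)       # needed for the float64 initializers above

Q = [2, 50, 50, 1]                          # input dim, two hidden widths, output dim
net_init, net_apply = init_NN(Q)
_, params = net_init(random.PRNGKey(0), (-1, Q[0]))
y = net_apply(params, jnp.ones((16, Q[0])))   # -> shape (16, 1)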
Example 2
def Dense(out_dim, W_init=glorot_normal(), b_init=glorot_normal()):
    """(Custom) Layer constructor function for a dense (fully-connected) layer."""
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2 = random.split(rng)
        # NOTE: unlike the original jax stax Dense, the bias here is initialized
        # with shape (input_shape[-1], out_dim) rather than (out_dim,)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(
            k2, (input_shape[-1], out_dim))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return jnp.dot(inputs, W) + b

    return init_fun, apply_fun
Example 3
def ConditionedSqueezeExcitation(ratio=4, W_cond_init=glorot_normal(), W1_init=glorot_normal(), W2_init=glorot_normal(), name='unnamed'):
    # language=rst
    """
    Like squeeze-and-excitation, but takes an extra conditioning input that is mixed in
    before the excitation weights are computed, so the layer can learn which feature
    maps matter given the conditioner.

    :param ratio: Factor by which the concatenated channels are reduced in the bottleneck FC layer
    """
    def init_fun(key, input_shape):
        (H, W, C), (K,) = input_shape
        k1, k2, k3 = random.split(key, 3)

        # Will be shrinking the conditioner down to the size of the number of channels
        W_cond = W_cond_init(k1, (C, K))

        # Going to be concatenating the conditioner
        C_concat = C + C
        assert C_concat%ratio == 0

        # Create the parameters for the squeeze and excite
        W1 = W1_init(k2, (C_concat//ratio, C_concat))
        W2 = W2_init(k3, (C, C_concat//ratio))

        output_shape = (H, W, C)
        params = (W_cond, W1, W2)
        state = ()
        return name, output_shape, params, state

    def apply_fun(params, state, inputs, **kwargs):
        W_cond, W1, W2 = params
        inputs, cond = inputs

        # Apply the SE transforms
        x = np.mean(inputs, axis=(-2, -3))
        x = np.concatenate([x, np.dot(cond, W_cond.T)], axis=-1)
        x = np.dot(x, W1.T)
        x = jax.nn.relu(x)
        x = np.dot(x, W2.T)
        x = jax.nn.sigmoid(x)

        # Scale the input feature maps channel-wise
        if inputs.ndim == 3:
            # Unbatched (H, W, C) input
            out = inputs*x[None, None, :]
        else:
            # Batched (N, H, W, C) input
            out = inputs*x[:, None, None, :]
        return out, state

    return init_fun, apply_fun
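A hypothetical usage sketch (not from the original repo), assuming np refers to jax.numpy as in the code above; the feature maps are batched (N, H, W, C) and the conditioner is (N, K):
import jax
import jax.numpy as np
from jax import random
from jax.nn.initializers import glorot_normal

init_fun, apply_fun = ConditionedSqueezeExcitation(ratio=4, name='cond_se')
N, H, W, C, K = 2, 8, 8, 16, 10
_, out_shape, params, state = init_fun(random.PRNGKey(0), ((H, W, C), (K,)))
feats, cond = np.ones((N, H, W, C)), np.ones((N, K))
out, state = apply_fun(params, state, (feats, cond))   # -> (2, 8, 8, 16), channels rescaled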
Example 4
def GeneralConvTranspose(dimension_numbers,
                         out_chan,
                         filter_shape,
                         strides=None,
                         padding='VALID',
                         W_init=None,
                         b_init=normal(1e-6)):
    """Layer construction function for a general transposed-convolution layer."""
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1, ) * len(filter_shape)
    strides = strides or one
    W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))

    def init_fun(rng, input_shape):
        filter_shape_iter = iter(filter_shape)
        kernel_shape = [
            out_chan if c == 'O' else input_shape[lhs_spec.index('C')]
            if c == 'I' else next(filter_shape_iter) for c in rhs_spec
        ]
        output_shape = lax.conv_transpose_shape_tuple(input_shape,
                                                      kernel_shape, strides,
                                                      padding,
                                                      dimension_numbers)
        bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
        k1, k2 = random.split(rng)
        W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return lax.conv_transpose(
            inputs, W, strides, padding,
            dimension_numbers=dimension_numbers) + b

    return init_fun, apply_fun
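A minimal usage sketch (assumed, not from the original source) for an NHWC transposed convolution with the default stride 1 and VALID padding:
import jax.numpy as jnp
from jax import random

dn = ('NHWC', 'HWIO', 'NHWC')
init_fun, apply_fun = GeneralConvTranspose(dn, out_chan=8, filter_shape=(3, 3))
out_shape, (W, b) = init_fun(random.PRNGKey(0), (4, 16, 16, 3))
y = apply_fun((W, b), jnp.ones((4, 16, 16, 3)))   # VALID transpose, stride 1 -> (4, 18, 18, 8)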
Example 5
def TallAffineDiagCov(flow, out_dim, n_training_importance_samples=32, A_init=glorot_normal(), b_init=normal(), name='unnamed'):
    """ Affine function to go up a dimension.

        Args:
            flow: An (init_fun, forward, inverse) triple for the prior flow over z.
            out_dim: Dimension of the lower-dimensional latent z produced by the forward pass.
            n_training_importance_samples: Number of importance samples used in forward.
            A_init: Initializer for the matrix A of the affine map x = Az + b.
            b_init: Initializer for the bias b.
            name: Layer name.
    """
    _init_fun, _forward, _inverse = flow

    def init_fun(key, input_shape, condition_shape):
        x_shape = input_shape
        output_shape = x_shape[:-1] + (out_dim,)
        keys = random.split(key, 3)

        x_dim = x_shape[-1]
        z_dim = out_dim
        A = A_init(keys[0], (x_shape[-1], out_dim))
        b = b_init(keys[1], (x_shape[-1],))
        flow_name, flow_output_shape, flow_params, flow_state = _init_fun(keys[2], output_shape, condition_shape)
        log_diag_cov = jnp.ones(input_shape[-1])*0.0
        params = ((A, b, log_diag_cov), flow_params)
        state = ((), flow_state)
        return (name, flow_name), flow_output_shape, params, state

    def forward(params, state, log_px, x, condition, **kwargs):
        ((A, b, log_diag_cov), flow_params) = params
        _, flow_state = state

        # Get the terms to compute and sample from the posterior
        sigma = kwargs.get('sigma', 1.0)
        z, log_hx, sigma_ATA_chol = tall_affine_posterior_diag_cov(x, b, A, log_diag_cov, sigma)

        # Importance sample from N(z|\mu(x),\Sigma(x)) and compile the results
        log_pz, z, updated_flow_states = importance_sample_prior(_forward, flow_params, flow_state, z, condition, sigma_ATA_chol, n_training_importance_samples, **kwargs)

        # Compute the final estimate of the integral
        log_px = log_px + log_pz + log_hx

        return log_px, z, ((), updated_flow_states)

    def inverse(params, state, log_pz, z, condition, **kwargs):
        ((A, b, log_diag_cov), flow_params) = params
        _, flow_state = state

        log_pz, z, updated_state = _inverse(flow_params, flow_state, log_pz, z, condition, **kwargs)

        # Compute Az + b
        # Don't need to sample because we already sampled from p(z)!!!!
        x = jnp.dot(z, A.T) + b
        key = kwargs.pop('key', None)
        if(key is not None):

            sigma = kwargs.get('sigma', 1.0)
            noise = random.normal(key, x.shape)*sigma

            x += noise*jnp.exp(0.5*log_diag_cov)

        # Compute N(x|Az + b, \Sigma).  This is just the log partition function.
        log_px = - 0.5*jnp.sum(log_diag_cov) - 0.5*x.shape[-1]*jnp.log(2*jnp.pi)
        return log_pz + log_px, x, ((), updated_state)

    return init_fun, forward, inverse
Example 6
def AAEmbedding(embedding_dims: int = 10, E_init=glorot_normal(), **kwargs):
    """
    Initial n-dimensional embedding of each amino-acid
    """
    def init_fun(rng, input_shape):
        """
        Generates the initial AA embedding matrix.

        `input_shape`:
            one-hot encoded AA sequence -> (n_aa, n_unique_aa)
        `output_dims`:
            embedded sequence -> (n_aa, embedding_dims)
        `emb_matrix`:
            embedding matrix -> (n_unique_aa, embedding_dims)
        """
        k1, _ = random.split(rng)
        emb_matrix = E_init(k1, (input_shape[1], embedding_dims))
        output_dims = (-1, embedding_dims)

        return output_dims, emb_matrix

    def apply_fun(params, inputs, **kwargs):
        """
        Embed a single AA sequence
        """
        emb_matrix = params
        # (n_aa, n_unique_aa) * (n_unique_aa, embedding_dims) => (n_aa, embedding_dims) # noqa: E501
        return np.matmul(inputs, emb_matrix)

    return init_fun, apply_fun
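A hypothetical usage sketch (not from the original source), with np as jax.numpy as above; the input is a one-hot encoded amino-acid sequence and the sizes are illustrative:
import jax.numpy as np
from jax import random

n_aa, n_unique_aa = 30, 20
init_fun, apply_fun = AAEmbedding(embedding_dims=10)
_, emb_matrix = init_fun(random.PRNGKey(0), (n_aa, n_unique_aa))
one_hot_seq = np.eye(n_unique_aa)[np.zeros(n_aa, dtype=int)]   # dummy one-hot sequence
embedded = apply_fun(emb_matrix, one_hot_seq)                  # -> (30, 10)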
Example 7
def LSTMCell(
    hidden_size,
    W_init=glorot_normal(),
    b_init=normal(),
    h_initial_state_fn=zeros,
    c_initial_state_fn=zeros,
    initial_state_seed=0,
):
    """Layer construction function for an LSTM cell.
    Formulation: Zaremba, W., 2015, https://arxiv.org/pdf/1409.2329.pdf"""
    def initial_state():
        shape = (hidden_size, )
        k1, k2 = jax.random.split(jax.random.PRNGKey(initial_state_seed))
        return LSTMState(h_initial_state_fn(k1, shape),
                         c_initial_state_fn(k2, shape))

    def init(rng, input_shape):
        in_dim, out_dim = input_shape[-1] + hidden_size, 4 * hidden_size
        output_shape = input_shape[:-1] + (hidden_size, )
        k1, k2 = jax.random.split(rng)
        W, b = W_init(k1, (in_dim, out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply(params, inputs, **kwargs):
        prev_state = kwargs.pop("prev_state", initial_state())
        W, b = params
        xh = jnp.concatenate([inputs, prev_state.h], axis=-1)
        gated = jnp.matmul(xh, W) + b
        i, f, o, g = jnp.split(gated, indices_or_sections=4, axis=-1)
        c = sigmoid(f) * prev_state.c + sigmoid(i) * jnp.tanh(g)
        h = sigmoid(o) * jnp.tanh(c)
        return h, LSTMState(h, c)

    return (init, apply, initial_state)
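A usage sketch (assumed), which presumes LSTMState is a namedtuple with fields h and c, and that sigmoid and zeros refer to jax.nn.sigmoid and the jax.nn.initializers zero initializer, as the surrounding source suggests; the cell is driven with a single unbatched input vector per step:
import jax.numpy as jnp
from jax import random

init, apply, initial_state = LSTMCell(hidden_size=32)
_, (W, b) = init(random.PRNGKey(0), (16,))              # 16-dimensional input features
state = initial_state()                                 # zero h and c, each of shape (32,)
h, state = apply((W, b), jnp.ones((16,)), prev_state=state)   # h -> (32,)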
Example 8
    def init_fun(rng, input_shape):
        rng, conv_rng, block_rng, serial_rng = jax.random.split(rng, num=4)

        # Primary convolutional layer.
        conv_shape, conv_params = conv_init(conv_rng, (-1, ) + input_shape)

        # Grouping all possible pairs.
        kernel_shape = [
            filter_shape[0], filter_shape[1], conv_channels, pair_channels
        ]
        bias_shape = [1, 1, 1, pair_channels]
        W_init = glorot_normal(in_axis=2, out_axis=3)
        b_init = normal(1e-6)
        k1, k2 = jax.random.split(rng)
        W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
        pair_shape = conv_shape[:2] + (15, ) + (pair_channels, )
        pair_params = (W, b)

        # Convolutional block.
        conv_block_shape, conv_block_params = conv_block_init(
            block_rng, pair_shape)

        # Forward pass.
        serial_shape, serial_params = serial_init(serial_rng, conv_block_shape)
        params = [conv_params, pair_params, conv_block_params, serial_params]
        return serial_shape, params
Example 9
def GRU(
        hidden_size,
        W_init=glorot_normal(),
        b_init=normal(),
        initial_state_fn=zeros,
):
    return Rnn(GRUCell(hidden_size, W_init, b_init, initial_state_fn))
Example 10
def DenseVMAP(out_dim, W_init=glorot_normal(), b_init=normal()):
    """Layer constructor function for a dense (fully-connected) layer."""
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2 = jax_random.split(rng)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return jnp.dot(inputs, W) + b

    apply_fun_vmap = vmap(apply_fun, (None, 0))
    return init_fun, apply_fun_vmap
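A usage sketch (assumed): the returned apply function is vmapped over a leading batch axis of the inputs while the parameters are shared:
import jax.numpy as jnp
from jax import random as jax_random

init_fun, apply_fun = DenseVMAP(out_dim=4)
_, (W, b) = init_fun(jax_random.PRNGKey(0), (7,))   # per-example input dim 7
y = apply_fun((W, b), jnp.ones((32, 7)))            # batched over axis 0 -> (32, 4)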


#model_params = [
#            [Dense(25), LayerNorm(), Relu, Reshape((1, 5, 5, 1)),
#             ConvTranspose(16, (6, 6), padding='VALID'), LayerNormConv(), Relu,  # 10x10
#             ConvTranspose(8, (6, 6), padding='VALID'), LayerNormConv(), Relu,  # 15x15
#             ConvTranspose(1, (6, 6), padding='VALID'), LayerNormConv(), Reshape((400,))],  # 20x20
#            [Dense(25), LayerNorm(), Relu, Reshape((1, 5, 5, 1)),
#             Conv(16, (4, 4), padding='same'), LayerNormConv(), Relu,
#             Conv(8, (3, 3), padding='same'), LayerNormConv(), Relu,
#             Conv(1, (3, 3), padding='same'), LayerNormConv(), Reshape((25,)),  # 2 from Conv before
#             Dense(21)]
#        ]
Example 11
 def __init__(self,
              out_dim,
              kernel_init=glorot_normal(),
              bias_init=normal()):
     self.bias_init = bias_init
     self.kernel_init = kernel_init
     self.out_dim = out_dim
Example 12
def init_GRU_params(rng, input_shape, W_init=glorot_normal(), b_init=normal()):
    """ Initialize the GRU layer """
    batch_size, hiden_dim, input_data_dim = input_shape  #input_data_dim=X,t

    # H0 = b_init(rng, (batch_size, hiden_dim))  # this is the H0 initial guess, that's why is dependent on batch size
    # H0 = b_init(rng, (1, hiden_dim))  # this is the H0 initial guess, that's why is dependent on batch size
    H0 = b_init(rng, (hiden_dim, ))

    # Re-split the rng each time so the reset, update, and output gates get distinct
    # keys (splitting the same rng three times would give all three gates identical weights).
    rng, k1, k2, k3 = random.split(rng, num=4)
    # W acts on the input X and U acts on the previous hidden state; their matrix
    # products are summed together with the bias.
    reset_W, reset_U, reset_b = (
        W_init(k1, (input_data_dim, hiden_dim)),
        W_init(k2, (hiden_dim, hiden_dim)),
        b_init(k3, (hiden_dim, )),
    )

    rng, k1, k2, k3 = random.split(rng, num=4)
    update_W, update_U, update_b = (
        W_init(k1, (input_data_dim, hiden_dim)),
        W_init(k2, (hiden_dim, hiden_dim)),
        b_init(k3, (hiden_dim, )),
    )

    rng, k1, k2, k3 = random.split(rng, num=4)
    out_W, out_U, out_b = (
        W_init(k1, (input_data_dim, hiden_dim)),
        W_init(k2, (hiden_dim, hiden_dim)),
        b_init(k3, (hiden_dim, )),
    )

    GRU_params = ((update_W, update_U, update_b), (reset_W, reset_U, reset_b),
                  (out_W, out_U, out_b))
    return H0, GRU_params
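The function above only builds parameters; the sketch below is hypothetical (not part of the original source) and shows one common way a single GRU step consumes the returned (W, U, b) triples, using the standard GRU update equations:
import jax.numpy as jnp
from jax import nn, random

def gru_step(GRU_params, h_prev, x_t):
    # Standard GRU update using the (input weight, recurrent weight, bias) triples.
    (uW, uU, ub), (rW, rU, rb), (oW, oU, ob) = GRU_params
    z = nn.sigmoid(jnp.dot(x_t, uW) + jnp.dot(h_prev, uU) + ub)          # update gate
    r = nn.sigmoid(jnp.dot(x_t, rW) + jnp.dot(h_prev, rU) + rb)          # reset gate
    h_tilde = jnp.tanh(jnp.dot(x_t, oW) + jnp.dot(r * h_prev, oU) + ob)  # candidate state
    return (1.0 - z) * h_prev + z * h_tilde

H0, GRU_params = init_GRU_params(random.PRNGKey(0), (1, 64, 3))   # (batch, hidden, input) dims
h1 = gru_step(GRU_params, H0, jnp.ones((3,)))                     # one step -> (64,)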
Example 13
def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
                         strides=None, padding='VALID', kernel_init=None,
                         bias_init=normal(1e-6)):
    """Layer construction function for a general transposed-convolution layer."""

    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1,) * len(filter_shape)
    strides = strides or one
    kernel_init = kernel_init or glorot_normal(rhs_spec.index('O'), rhs_spec.index('I'))

    @parametrized
    def conv_transpose(inputs):
        filter_shape_iter = iter(filter_shape)

        kernel_shape = [out_chan if c == 'O' else
                        inputs.shape[lhs_spec.index('C')] if c == 'I' else
                        next(filter_shape_iter) for c in rhs_spec]

        bias_shape = tuple(
            itertools.dropwhile(lambda x: x == 1, [out_chan if c == 'C' else 1 for c in out_spec]))

        kernel = parameter(kernel_shape, kernel_init, 'kernel')
        bias = parameter(bias_shape, bias_init, 'bias')
        return lax.conv_transpose(inputs, kernel, strides, padding,
                                  dimension_numbers=dimension_numbers) + bias

    return conv_transpose
Example 14
def DeepRNN(cell_type, hidden_dims, W_init=glorot_normal(), b_init=normal()):
    """Deep RNN cell, a wrapper for a stack of RNNs."""

    cells = [cell_type(h, W_init=W_init, b_init=b_init) for h in hidden_dims]

    def init(key, input_dim):
        keys = jax.random.split(key, num=len(cells))
        in_dims = [input_dim] + hidden_dims[:-1]
        params = []
        for cell, key, dim in zip(cells, keys, in_dims):
            params.append(cell.init(key, dim)[1])
        return [hidden_dims[-1]], params

    def apply(cells_params, inputs, prev_states, **kwargs):
        new_states = []
        for cell, prev_state, params in zip(cells, prev_states, cells_params):
            new_state, new_out = cell.apply(params, inputs, prev_state)
            new_states.append(new_state)
            inputs = new_out
        return new_states, new_out

    def initial_state():
        return [cell.initial_state() for cell in cells]

    return Module(init, apply, initial_state)
Example 15
def MaskedDense(mask, bias=True, W_init=glorot_normal(), b_init=normal()):
    """
    As in jax.experimental.stax, each layer constructor function returns an
    (init_fun, apply_fun) pair, where `init_fun` takes an rng key and an input
    shape and returns an (output_shape, params) pair, and `apply_fun` takes
    params and inputs and applies the layer.

    :param array mask: Mask of shape (input_dim, out_dim) applied to the weights of the layer.
    :param bool bias: whether to include a bias term.
    :param callable W_init: initialization method for the weights.
    :param callable b_init: initialization method for the bias terms.
    :return: an (`init_fun`, `apply_fun`) pair.
    """
    def init_fun(rng_key, input_shape):
        k1, k2 = random.split(rng_key)
        W = W_init(k1, mask.shape)
        if bias:
            b = b_init(k2, mask.shape[-1:])
            params = (W, b)
        else:
            params = W
        return input_shape[:-1] + mask.shape[-1:], params

    def apply_fun(params, inputs, **kwargs):
        if bias:
            W, b = params
            return jnp.dot(inputs, W * mask) + b
        else:
            W = params
            return jnp.dot(inputs, W * mask)

    return init_fun, apply_fun
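A minimal usage sketch (assumed): with a strictly lower-triangular mask, each output unit only sees earlier input units, as in autoregressive (MADE-style) layers:
import jax.numpy as jnp
from jax import random

mask = jnp.tril(jnp.ones((5, 5)), k=-1)          # (input_dim, out_dim) connectivity mask
init_fun, apply_fun = MaskedDense(mask)
_, (W, b) = init_fun(random.PRNGKey(0), (10, 5))
y = apply_fun((W, b), jnp.ones((10, 5)))         # -> (10, 5)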
Example 16
def FullCovarianceGaussian(conditioning_fn,
                           event_dim,
                           min_scale_diag=1e-4,
                           W_init=glorot_normal(),
                           b_init=normal()):
    """A conditional Gaussian with full covariance matrix.
  
  The distribution mean and covariance are functions of the conditioning set. 
  The covariance is parameterized as the matrix square of the scale, and the
  scale is parameterized as a lower triangular matrix with positive diagonal
  and unrestricted off-diagonal elements. The diagonal elements are ensured
  to be positive by exponentiating them.
  """
    def dist_fn(raw_params):
        loc = raw_params[:event_dim]
        raw_scale = raw_params[event_dim:]
        scale = unflatten_scale(raw_scale, event_dim, min_diag=min_scale_diag)
        cov = scale @ scale.T
        return tfd.MultivariateNormalFullCovariance(loc=loc,
                                                    covariance_matrix=cov)

    param_dim = event_dim + int((event_dim * (event_dim + 1)) / 2)
    return ConditionalDistribution(conditioning_fn,
                                   dist_fn,
                                   event_dim,
                                   param_dim,
                                   W_init=W_init,
                                   b_init=b_init)
Example 17
def RNN(hidden_dim,
        W_init=glorot_normal(),
        b_init=normal(),
        activation=jax.nn.relu):
    """Recurrent Neural Network cell."""

    input_to_hidden = Linear(hidden_dim, W_init=W_init)
    hidden_to_hidden = Affine(hidden_dim, W_init=W_init, b_init=b_init)

    def init(key, input_dim):
        output_shape = hidden_dim
        k1, k2 = jax.random.split(key)
        _, input_to_hidden_params = input_to_hidden.init(k1, input_dim)
        _, hidden_to_hidden_params = hidden_to_hidden.init(k2, hidden_dim)
        return [hidden_dim], RNNParams(input_to_hidden_params,
                                       hidden_to_hidden_params)

    def apply(params, inputs, prev_state, **kwargs):
        new_hidden_raw = (
            input_to_hidden.apply(params.input_to_hidden, inputs) +
            hidden_to_hidden.apply(params.hidden_to_hidden, prev_state.hidden))
        new_hidden = activation(new_hidden_raw)
        new_state = RNNState(hidden=new_hidden)
        return new_state, new_hidden

    def initial_state():
        return RNNState(hidden=jnp.zeros([hidden_dim]))

    return Module(init, apply, initial_state)
Example 18
def MLP(layer_dims,
        W_init=glorot_normal(),
        b_init=normal(),
        activation=jax.nn.relu,
        activate_final=False):
  """A multi-layered perceptron."""

  layers = []
  for dim in layer_dims[:-1]:
    layers.append(Dense(dim, W_init=W_init, b_init=b_init,
                        activation=activation))
  if activate_final:
    layers.append(Dense(layer_dims[-1], W_init=W_init, b_init=b_init,
                        activation=activation))
  else:
    layers.append(Affine(layer_dims[-1], W_init=W_init, b_init=b_init))

  def init(key, input_dim):
    keys = jax.random.split(key, num=len(layer_dims))
    input_dims = [input_dim] + layer_dims[:-1]
    params = []
    for layer, key, in_dim in zip(layers, keys, input_dims):
      params.append(layer.init(key, in_dim)[1])
    return layer_dims[-1], MLPParams(params)

  def apply(params, inputs):
    for layer, param in zip(layers, params.layer_params):
      inputs = layer.apply(param, inputs)
    return inputs

  return Module(init, apply)
Example 19
    def DenseEquivalent(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            kernel = Parameter(lambda key: kernel_init(key, (inputs.shape[-1], out_dim)))()
            bias = Parameter(lambda key: bias_init(key, (out_dim,)))()
            return np.dot(inputs, kernel) + bias

        return dense
Example 20
    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return np.dot(inputs, kernel) + bias

        return dense
Example 21
def DensePurificationComplex(out_pure,
                             out_mix,
                             use_hidden_bias=True,
                             W_init=glorot_normal(),
                             b_init=normal()):
    """Layer constructor function for a complex purification layer."""
    def init_fun(rng, input_shape):
        assert input_shape[-1] % 2 == 0
        output_shape = input_shape[:-1] + (2 * out_pure + out_mix, )

        k = jax.random.split(rng, 7)

        input_size = input_shape[-1] // 2

        # Weights for the pure part
        Wr, Wi = (
            W_init(k[0], (input_size, out_pure)),
            W_init(k[1], (input_size, out_pure)),
        )

        # Weights for the mixing part
        Vr, Vi = (
            W_init(k[2], (input_size, out_mix)),
            W_init(k[3], (input_size, out_mix)),
        )

        if use_hidden_bias:
            br, bi = (b_init(k[4], (out_pure, )), b_init(k[5], (out_pure, )))
            cr = b_init(k[6], (out_mix, ))

            return output_shape, (Wr, Wi, Vr, Vi, br, bi, cr)
        else:
            return output_shape, (Wr, Wi, Vr, Vi)

    @jax.jit
    def apply_fun(params, inputs, **kwargs):
        if use_hidden_bias:
            Wr, Wi, Vr, Vi, br, bi, cr = params
        else:
            Wr, Wi, Vr, Vi = params

        xr, xc = jax.numpy.split(inputs, 2, axis=-1)

        thetar = jax.numpy.dot(xr[:, ], (Wr + 1.0j * Wi))
        thetac = jax.numpy.dot(xc[:, ], (Wr - 1.0j * Wi))

        thetam = jax.numpy.dot(xr[:, ], (Vr + 1.0j * Vi))
        thetam += jax.numpy.dot(xc[:, ], (Vr - 1.0j * Vi))

        if use_hidden_bias:
            thetar += br + 1.0j * bi
            thetac += br - 1.0j * bi
            thetam += 2 * cr

        return jax.numpy.hstack((thetar, thetam, thetac))

    return init_fun, apply_fun
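A usage sketch (assumed): the last input dimension must be even, since it is split into the halves xr and xc; the output is complex with 2*out_pure + out_mix features:
import jax
import jax.numpy as jnp

init_fun, apply_fun = DensePurificationComplex(out_pure=4, out_mix=2)
_, params = init_fun(jax.random.PRNGKey(0), (16, 6))
theta = apply_fun(params, jnp.ones((16, 6)))     # complex output of shape (16, 10)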
Example 22
def init_param(rng, input_units, feature_size, label_size, label_units,
               hidden_units):
    init = glorot_normal()
    k1, k2, k3, k4, k5 = npr.split(rng, num=5)
    A_1 = init_dense(k1, (input_units, hidden_units))
    A_2 = init_dense(k2, (hidden_units, feature_size))
    B = init(k3, (feature_size, label_size))
    C_1 = init_dense(k4, (label_size, label_units))
    c_2 = init(k5, (label_units, 1))
    return Param(A_1, A_2, B, C_1, c_2)
Example 23
    def Jastrow(W_init=glorot_normal()):
        def init_fun(rng, input_shape):
            N = input_shape[-1]
            return input_shape[:-1], W_init(rng, (N, N))

        def apply_fun(W, x, **kwargs):
            return jax.vmap(lambda W, x: jax.numpy.einsum("i,ij,j", x, W, x),
                            in_axes=(None, 0))(W, x)

        return init_fun, apply_fun
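A usage sketch (assumed, treating Jastrow as available at module scope): it evaluates the quadratic form x^T W x for every configuration in a batch:
import jax
import jax.numpy as jnp

init_fun, apply_fun = Jastrow()
_, W = init_fun(jax.random.PRNGKey(0), (-1, 8))
vals = apply_fun(W, jnp.ones((32, 8)))   # one scalar per configuration -> (32,)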
Example 24
def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
    """Layer constructor function for a dense (fully-connected) layer."""

    @parametrized
    def dense(inputs):
        kernel = parameter((inputs.shape[-1], out_dim), kernel_init, name='kernel')
        bias = parameter((out_dim,), bias_init, name='bias')
        return np.dot(inputs, kernel) + bias

    return dense
Example 25
def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
  """Layer constructor function for a dense (fully-connected) layer."""
  def init_fun(rng, input_shape):
    output_shape = input_shape[:-1] + (out_dim,)
    k1, k2 = random.split(rng)
    W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
    return output_shape, (W, b)
  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return np.dot(inputs, W) + b
  return init_fun, apply_fun
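A minimal usage sketch (assumed), with np as jax.numpy as in the source:
import jax.numpy as np
from jax import random

init_fun, apply_fun = Dense(out_dim=3)
_, (W, b) = init_fun(random.PRNGKey(0), (-1, 5))
y = apply_fun((W, b), np.ones((10, 5)))   # -> (10, 3)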
Example 26
def Linear(out_dim, W_init=glorot_normal()):
  """A Linear layer (no bias)."""

  def init(key, input_dim):
    W = W_init(key, (input_dim, out_dim))
    return out_dim, LinearParams(W)

  def apply(params, inputs):
    return jnp.dot(inputs, params.W)

  return Module(init, apply)
Example 27
def DenseNoBias(out_dim, W_init=glorot_normal()):
    """Layer constructor function for a dense (fully-connected) layer but without
  any bias term."""
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        W = W_init(rng, (input_shape[-1], out_dim))
        return output_shape, W

    def apply_fun(W, inputs, **_kwargs):
        return inputs @ W

    return init_fun, apply_fun
Example 28
def Affine(out_dim, W_init=glorot_normal(), b_init=normal()):
  """An affine layer."""

  def init(key, input_dim):
    k1, k2 = jax.random.split(key)
    W, b = W_init(k1, (input_dim, out_dim)), b_init(k2, (out_dim,))
    return out_dim, AffineParams(W, b)

  def apply(params, inputs):
    return jnp.dot(inputs, params.W) + params.b

  return Module(init, apply)
Example 29
def Dense(out_dim,
          W_init=glorot_normal(),
          b_init=normal(),
          activation=jax.nn.relu):
  """A single-layer MLP (Affine layer with an activation)."""
  affine = Affine(out_dim, W_init=W_init, b_init=b_init)

  def init(key, input_dim):
    return affine.init(key, input_dim)

  def apply(params, inputs):
    return activation(affine.apply(params, inputs))

  return Module(init, apply)
Example 30
def GeneralConv(dimension_numbers,
                out_chan,
                filter_shape,
                strides=None,
                padding='VALID',
                W_init=None,
                b_init=normal(1e-6),
                bias=True):
    """Layer construction function for a general convolution layer."""
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1, ) * len(filter_shape)
    strides = strides or one
    W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))

    def init_fun(rng, input_shape):
        filter_shape_iter = iter(filter_shape)
        kernel_shape = [
            out_chan if c == 'O' else input_shape[lhs_spec.index('C')]
            if c == 'I' else next(filter_shape_iter) for c in rhs_spec
        ]
        output_shape = lax.conv_general_shape_tuple(input_shape, kernel_shape,
                                                    strides, padding,
                                                    dimension_numbers)
        bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
        bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
        k1, k2 = random.split(rng)
        W = W_init(k1, kernel_shape)
        if bias:
            b = b_init(k2, bias_shape)
            return output_shape, (W, b)
        else:
            return output_shape, W

    def apply_fun(params, inputs, **kwargs):
        if bias:
            W, b = params
        else:
            W = params
        batchdim = True
        if inputs.ndim == 3:
            batchdim = False
            inputs = np.expand_dims(inputs, 0)
        out = lax.conv_general_dilated(inputs, W, strides, padding, one, one,
                                       dimension_numbers)
        out = out + b if bias else out
        if not batchdim:
            out = out[0]
        return out

    return init_fun, apply_fun
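A usage sketch (assumed) for an NHWC convolution; unbatched (H, W, C) inputs also work because apply_fun temporarily adds a batch dimension:
import jax.numpy as np
from jax import random

dn = ('NHWC', 'HWIO', 'NHWC')
init_fun, apply_fun = GeneralConv(dn, out_chan=8, filter_shape=(3, 3), padding='SAME')
_, (W, b) = init_fun(random.PRNGKey(0), (4, 16, 16, 3))
y = apply_fun((W, b), np.ones((4, 16, 16, 3)))       # -> (4, 16, 16, 8)
y_single = apply_fun((W, b), np.ones((16, 16, 3)))   # -> (16, 16, 8)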