def __call__(self, query):
        """Computes relative position embedding logits.
        Arguments:
            query:      [batch_size, heads, height, width, dim]
        Returns:
            output:     [batch_size, heads, height, width, height, width]
        """

        _, _, H, W, _ = query.shape

        rel_pos_emb_w_shape = (2 * W - 1, self.head_ch)
        rel_pos_emb_w = self.param(
            'rel_pos_emb_w', initializers.normal(stddev=self.head_ch**-0.5),
            rel_pos_emb_w_shape)

        rel_pos_emb_h_shape = (2 * H - 1, self.head_ch)
        rel_pos_emb_h = self.param(
            'rel_pos_emb_h', initializers.normal(stddev=self.head_ch**-0.5),
            rel_pos_emb_h_shape)

        rel_logits_w = self._relative_logits_1d(query, rel_pos_emb_w)
        rel_logits_w = rearrange(rel_logits_w, 'b h H I W V -> b h H W I V')

        rel_logits_h = self._relative_logits_1d(
            rearrange(query, 'b h H W d -> b h W H d'), rel_pos_emb_h)
        rel_logits_h = rearrange(rel_logits_h, 'b h W V H I -> b h H W I V')
        out = rel_logits_h + rel_logits_w
        return out
Example #2
def generate_data_01():
    batch_size = 8
    input_shape = (batch_size, 4)

    def synth_batches():
        while True:
            images = npr.rand(*input_shape).astype("float32")
            yield images

    batches = synth_batches()
    inputs = next(batches)

    init_func, predict_func = stax.serial(
        HomotopyDense(out_dim=4, W_init=glorot_uniform(), b_init=normal()),
        HomotopyDense(out_dim=1, W_init=glorot_uniform(), b_init=normal()),
        Sigmoid,
    )

    ae_shape, ae_params = init_func(random.PRNGKey(0), input_shape)
    # assert ae_shape == input_shape
    bparam = [np.array([0.0], dtype=np.float64)]
    logits = predict_func(ae_params,
                          inputs,
                          bparam=bparam[0],
                          activation_func=sigmoid)
    # The data-fit term subtracts logits from themselves, so it is identically
    # zero; the loss reduces to the regularizers and is computed but not returned.
    loss = np.mean(
        (np.subtract(logits, logits))) + l2_norm(ae_params) + l2_norm(bparam)

    return inputs, logits, ae_params, bparam, init_func, predict_func
Example #3
class DeepViTConfig:
    num_classes: int = 1000
    depth: int = 32
    mlp_dim: int = 1224
    token_dim: int = 64
    emb_dim: int = 408
    num_heads: int = 12
    dim_head: int = 32
    shared_theta: bool = True
    activation_fn: ModuleDef = nn.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Any = jax.lax.Precision.DEFAULT
    kernel_init: Callable = initializers.xavier_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)
    posemb_init: Callable = initializers.normal(stddev=0.02)
Example #4
def init_NN(Q):
    layers = []
    num_layers = len(Q)
    for i in range(0, num_layers - 2):
        layers.append(
            Dense(Q[i + 1],
                  W_init=glorot_normal(dtype=np.float64),
                  b_init=normal(dtype=np.float64)))
        layers.append(Tanh)
    layers.append(
        Dense(Q[-1],
              W_init=glorot_normal(dtype=np.float64),
              b_init=normal(dtype=np.float64)))
    net_init, net_apply = stax.serial(*layers)
    return net_init, net_apply
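A hedged usage sketch for init_NN above; the widths in Q and the batch size are illustrative, and only the standard stax init/apply contract is assumed.

from jax import random
import jax.numpy as jnp

Q = [2, 64, 64, 1]                       # layer widths: 2 -> 64 -> 64 -> 1
net_init, net_apply = init_NN(Q)
out_shape, params = net_init(random.PRNGKey(0), (-1, Q[0]))
y = net_apply(params, jnp.ones((10, Q[0])))   # y.shape == (10, 1)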
Example #5
def vstate(request):
    N = 8
    hi = nk.hilbert.Spin(1 / 2, N)

    ma = nk.models.RBM(
        alpha=1,
        dtype=float,
        hidden_bias_init=normal(),
        visible_bias_init=normal(),
    )

    return nk.vqs.MCState(
        nk.sampler.MetropolisLocal(hi),
        ma,
    )
Example #6
def FullCovarianceGaussian(conditioning_fn,
                           event_dim,
                           min_scale_diag=1e-4,
                           W_init=glorot_normal(),
                           b_init=normal()):
    """A conditional Gaussian with full covariance matrix.
  
  The distribution mean and covariance are functions of the conditioning set. 
  The covariance is parameterized as the matrix square of the scale, and the
  scale is parameterized as a lower triangular matrix with positive diagonal
  and unrestricted off-diagonal elements. The diagonal elements are ensured
  to be positive by exponentiating them.
  """
    def dist_fn(raw_params):
        loc = raw_params[:event_dim]
        raw_scale = raw_params[event_dim:]
        scale = unflatten_scale(raw_scale, event_dim, min_diag=min_scale_diag)
        cov = scale @ scale.T
        return tfd.MultivariateNormalFullCovariance(loc=loc,
                                                    covariance_matrix=cov)

    param_dim = event_dim + int((event_dim * (event_dim + 1)) / 2)
    return ConditionalDistribution(conditioning_fn,
                                   dist_fn,
                                   event_dim,
                                   param_dim,
                                   W_init=W_init,
                                   b_init=b_init)
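The helper unflatten_scale is not shown in this example; below is a hypothetical sketch of what it might do, following the parameterization the docstring describes (lower-triangular scale with an exponentiated, strictly positive diagonal).

import jax.numpy as jnp

def unflatten_scale(raw_scale, event_dim, min_diag=1e-4):
    # Hypothetical helper: pack the event_dim*(event_dim+1)/2 unconstrained
    # entries into a lower-triangular matrix, then exponentiate the diagonal
    # so the scale (and hence scale @ scale.T) stays positive definite.
    tril = jnp.zeros((event_dim, event_dim)).at[jnp.tril_indices(event_dim)].set(raw_scale)
    diag = jnp.exp(jnp.diagonal(tril)) + min_diag
    return tril - jnp.diag(jnp.diagonal(tril)) + jnp.diag(diag)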
Example #7
def BiasRealModPhase(b_init=normal()):
    def init_fun(rng, input_shape):
        assert input_shape[-1] % 2 == 0
        input_size = input_shape[-1] // 2

        output_shape = input_shape[:-1]

        k = jax.random.split(rng, 2)

        br = b_init(k[0], (input_size, ))
        bj = b_init(k[1], (input_size, ))

        return output_shape, (br, bj)

    def apply_fun(params, inputs, **kwargs):
        br, bj = params

        xr, xc = jax.numpy.split(inputs, 2, axis=-1)

        biasr = jax.numpy.dot(
            (xr + xc)[:, ],
            br,
        )
        biasj = jax.numpy.dot(
            (xr - xc)[:, ],
            bj,
        )

        return 0.5 * biasr + 0.5j * biasj

    return init_fun, apply_fun
Example #8
def GRU(
        hidden_size,
        W_init=glorot_normal(),
        b_init=normal(),
        initial_state_fn=zeros,
):
    return Rnn(GRUCell(hidden_size, W_init, b_init, initial_state_fn))
Example #9
    def init_fun(rng, input_shape):
        rng, conv_rng, block_rng, serial_rng = jax.random.split(rng, num=4)

        # Primary convolutional layer.
        conv_shape, conv_params = conv_init(conv_rng, (-1, ) + input_shape)

        # Grouping all possible pairs.
        kernel_shape = [
            filter_shape[0], filter_shape[1], conv_channels, pair_channels
        ]
        bias_shape = [1, 1, 1, pair_channels]
        W_init = glorot_normal(in_axis=2, out_axis=3)
        b_init = normal(1e-6)
        k1, k2 = jax.random.split(rng)
        W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
        pair_shape = conv_shape[:2] + (15, ) + (pair_channels, )
        pair_params = (W, b)

        # Convolutional block.
        conv_block_shape, conv_block_params = conv_block_init(
            block_rng, pair_shape)

        # Forward pass.
        serial_shape, serial_params = serial_init(serial_rng, conv_block_shape)
        params = [conv_params, pair_params, conv_block_params, serial_params]
        return serial_shape, params
def init_GRU_params(rng, input_shape, W_init=glorot_normal(), b_init=normal()):
    """ Initialize the GRU layer """
    batch_size, hidden_dim, input_data_dim = input_shape  # input_data_dim covers X and t

    # H0 is the initial hidden-state guess; a single vector shared across the batch.
    H0 = b_init(rng, (hidden_dim, ))

    # Split fresh keys for each gate so the three parameter groups are
    # initialized differently. W takes the X data and U takes the previous
    # hidden state; they are combined by adding them together with the bias
    # after the matrix products.
    rng, k1, k2, k3 = random.split(rng, num=4)
    reset_W, reset_U, reset_b = (
        W_init(k1, (input_data_dim, hidden_dim)),
        W_init(k2, (hidden_dim, hidden_dim)),
        b_init(k3, (hidden_dim, )),
    )

    rng, k1, k2, k3 = random.split(rng, num=4)
    update_W, update_U, update_b = (
        W_init(k1, (input_data_dim, hidden_dim)),
        W_init(k2, (hidden_dim, hidden_dim)),
        b_init(k3, (hidden_dim, )),
    )

    rng, k1, k2, k3 = random.split(rng, num=4)
    out_W, out_U, out_b = (
        W_init(k1, (input_data_dim, hidden_dim)),
        W_init(k2, (hidden_dim, hidden_dim)),
        b_init(k3, (hidden_dim, )),
    )

    GRU_params = ((update_W, update_U, update_b), (reset_W, reset_U, reset_b),
                  (out_W, out_U, out_b))
    return H0, GRU_params
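A hedged sketch of the forward step these parameters would feed (the matching apply function is not part of this excerpt); it assumes standard GRU gating with the (update, reset, out) layout returned above.

import jax.numpy as jnp
from jax.nn import sigmoid

def gru_step(GRU_params, h, x):
    (update_W, update_U, update_b), (reset_W, reset_U, reset_b), \
        (out_W, out_U, out_b) = GRU_params
    z = sigmoid(jnp.dot(x, update_W) + jnp.dot(h, update_U) + update_b)    # update gate
    r = sigmoid(jnp.dot(x, reset_W) + jnp.dot(h, reset_U) + reset_b)       # reset gate
    h_tilde = jnp.tanh(jnp.dot(x, out_W) + jnp.dot(r * h, out_U) + out_b)  # candidate state
    return (1.0 - z) * h + z * h_tilde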
Example #11
def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
                         strides=None, padding='VALID', kernel_init=None,
                         bias_init=normal(1e-6)):
    """Layer construction function for a general transposed-convolution layer."""

    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1,) * len(filter_shape)
    strides = strides or one
    kernel_init = kernel_init or glorot_normal(rhs_spec.index('O'), rhs_spec.index('I'))

    @parametrized
    def conv_transpose(inputs):
        filter_shape_iter = iter(filter_shape)

        kernel_shape = [out_chan if c == 'O' else
                        inputs.shape[lhs_spec.index('C')] if c == 'I' else
                        next(filter_shape_iter) for c in rhs_spec]

        bias_shape = tuple(
            itertools.dropwhile(lambda x: x == 1, [out_chan if c == 'C' else 1 for c in out_spec]))

        kernel = parameter(kernel_shape, kernel_init, 'kernel')
        bias = parameter(bias_shape, bias_init, 'bias')
        return lax.conv_transpose(inputs, kernel, strides, padding,
                                  dimension_numbers=dimension_numbers) + bias

    return conv_transpose
class Encoder(nn.Module):
    num_layers: int
    inner_num_heads: int
    outer_num_heads: int
    inner_expand_ratio: float = 4
    outer_expand_ratio: float = 4
    attn_dropout_rate: float = 0.
    dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.zeros
    pos_embed_init: Callable = initializers.normal(stddev=0.02)

    @nn.compact
    def __call__(self, patch_embeddings, pixel_embeddings, is_training: bool):
        for _ in range(self.num_layers):
            patch_embeddings, pixel_embeddings = EncoderBlock(
                inner_num_heads=self.inner_num_heads,
                outer_num_heads=self.outer_num_heads,
                attn_dropout_rate=self.attn_dropout_rate,
                dropout_rate=self.dropout_rate,
                activation_fn=self.activation_fn,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init)(patch_embeddings, pixel_embeddings)

        output = patch_embeddings
        return output
Example #13
class Jastrow(nn.Module):
    r"""
    Jastrow wave function :math:`\Psi(s) = \exp(\sum_{ij} s_i W_{ij} s_j)`.

    The W matrix is stored as a non-symmetric matrix and symmetrized
    during computation as :code:`W = W + W.T`.
    """

    dtype: DType = jnp.complex128
    """The dtype of the weights."""
    kernel_init: NNInitFunc = normal()
    """Initializer for the weights."""

    @nn.compact
    def __call__(self, x_in: Array):
        nv = x_in.shape[-1]

        dtype = jnp.promote_types(x_in.dtype, self.dtype)
        x_in = jnp.asarray(x_in, dtype=dtype)

        kernel = self.param("kernel", self.kernel_init, (nv, nv), self.dtype)
        kernel = kernel + kernel.T
        y = jnp.einsum("...i,ij,...j", x_in, kernel, x_in)

        return y
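A hedged usage sketch mirroring Examples #5 and #20: the Jastrow module plugs into a NetKet variational state like any other Flax model.

import netket as nk
from jax.nn.initializers import normal

hi = nk.hilbert.Spin(1 / 2, 8)
ma = Jastrow(kernel_init=normal(stddev=0.01))
vs = nk.vqs.MCState(nk.sampler.MetropolisLocal(hi), ma)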
Example #14
class Gaussian(nn.Module):
    r"""
    Multivariate Gaussain function with mean 0 and parametrised covariance matrix
    :math:`\Sigma_{ij}`.

    The wavefunction is given by the formula: :math:`\Psi(x) = \exp(\sum_{ij} x_i \Sigma_{ij} x_j)`.
    The (positive definite) :math:`\Sigma_{ij} = AA^T` matrix is stored as
    non-positive definite matrix A.
    """

    dtype: DType = jnp.float64
    """The dtype of the weights."""
    kernel_init: NNInitFunc = normal(stddev=1.0)
    """Initializer for the weights."""
    @nn.compact
    def __call__(self, x_in: Array):
        nv = x_in.shape[-1]

        dtype = jnp.promote_types(x_in.dtype, self.dtype)
        x_in = jnp.asarray(x_in, dtype=dtype)

        kernel = self.param("kernel", self.kernel_init, (nv, nv), self.dtype)

        kernel = jnp.dot(kernel.T, kernel)

        y = -0.5 * jnp.einsum("...i,ij,...j", x_in, kernel, x_in)

        return y
def LSTMCell(
    hidden_size,
    W_init=glorot_normal(),
    b_init=normal(),
    h_initial_state_fn=zeros,
    c_initial_state_fn=zeros,
    initial_state_seed=0,
):
    """Layer construction function for an LSTM cell.
    Formulation: Zaremba, W., 2015, https://arxiv.org/pdf/1409.2329.pdf"""
    def initial_state():
        shape = (hidden_size, )
        k1, k2 = jax.random.split(jax.random.PRNGKey(initial_state_seed))
        return LSTMState(h_initial_state_fn(k1, shape),
                         c_initial_state_fn(k2, shape))

    def init(rng, input_shape):
        in_dim, out_dim = input_shape[-1] + hidden_size, 4 * hidden_size
        output_shape = input_shape[:-1] + (hidden_size, )
        k1, k2 = jax.random.split(rng)
        W, b = W_init(k1, (in_dim, out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply(params, inputs, **kwargs):
        prev_state = kwargs.pop("prev_state", initial_state())
        W, b = params
        xh = jnp.concatenate([inputs, prev_state.h], axis=-1)
        gated = jnp.matmul(xh, W) + b
        i, f, o, g = jnp.split(gated, indices_or_sections=4, axis=-1)
        c = sigmoid(f) * prev_state.c + sigmoid(i) * jnp.tanh(g)
        h = sigmoid(o) * jnp.tanh(c)
        return h, LSTMState(h, c)

    return (init, apply, initial_state)
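A hedged sketch of unrolling the cell over a sequence with jax.lax.scan, using only the (init, apply, initial_state) triple returned above; the sizes are illustrative.

import jax
import jax.numpy as jnp

init, apply, initial_state = LSTMCell(hidden_size=16)
_, params = init(jax.random.PRNGKey(0), input_shape=(8,))

def step(state, x_t):
    h, new_state = apply(params, x_t, prev_state=state)
    return new_state, h

xs = jnp.ones((20, 8))                           # [time, features]
_, hs = jax.lax.scan(step, initial_state(), xs)  # hs: [time, hidden_size]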
Example #16
    def __init__(self, n_layers, n_hidden):
        """For simplicity, every layer has the same dimension."""
        super().__init__()
        self.cells = ModuleTuple(
            [LSTMCell(n_hidden, n_hidden) for _ in range(n_layers)])
        self.c_0s = ParameterTuple(
            [ParamInit((n_hidden, ), init.normal()) for _ in range(n_layers)])
Example #17
    def test_dense_is_dense_general(self):
        x = jax.random.normal(random.PRNGKey(0), (5, 3))
        dense_module = nn.Dense.partial(
            features=4,
            bias=True,
            bias_init=initializers.normal(),
        )
        y1, _ = dense_module.init(random.PRNGKey(1), x)
        dg_module = nn.DenseGeneral.partial(
            features=4,
            bias=True,
            bias_init=initializers.normal(),
        )
        y2, _ = dg_module.init(random.PRNGKey(1), x)

        onp.testing.assert_allclose(y1, y2)
Example #18
def DeepRNN(cell_type, hidden_dims, W_init=glorot_normal(), b_init=normal()):
    """Deep RNN cell, a wrapper for a stack of RNNs."""

    cells = [cell_type(h, W_init=W_init, b_init=b_init) for h in hidden_dims]

    def init(key, input_dim):
        keys = jax.random.split(key, num=len(cells))
        in_dims = [input_dim] + hidden_dims[:-1]
        params = []
        for cell, key, dim in zip(cells, keys, in_dims):
            params.append(cell.init(key, dim)[1])
        return [hidden_dims[-1]], params

    def apply(cells_params, inputs, prev_states, **kwargs):
        new_states = []
        for cell, prev_state, params in zip(cells, prev_states, cells_params):
            new_state, new_out = cell.apply(params, inputs, prev_state)
            new_states.append(new_state)
            inputs = new_out
        return new_states, new_out

    def initial_state():
        return [cell.initial_state() for cell in cells]

    return Module(init, apply, initial_state)
Example #19
def MaskedDense(mask, bias=True, W_init=glorot_normal(), b_init=normal()):
    """
    As in jax.experimental.stax, each layer constructor function returns
    an (init_fun, apply_fun) pair, where `init_fun` takes an rng_key key and
    an input shape and returns an (output_shape, params) pair, and
    `apply_fun` takes params, inputs, and an rng_key key and applies the layer.

    :param array mask: Mask of shape (input_dim, out_dim) applied to the weights of the layer.
    :param bool bias: whether to include bias term.
    :param array W_init: initialization method for the weights.
    :param array b_init: initialization method for the bias terms.
    :return: a (`init_fn`, `update_fn`) pair.
    """
    def init_fun(rng_key, input_shape):
        k1, k2 = random.split(rng_key)
        W = W_init(k1, mask.shape)
        if bias:
            b = b_init(k2, mask.shape[-1:])
            params = (W, b)
        else:
            params = W
        return input_shape[:-1] + mask.shape[-1:], params

    def apply_fun(params, inputs, **kwargs):
        if bias:
            W, b = params
            return jnp.dot(inputs, W * mask) + b
        else:
            W = params
            return jnp.dot(inputs, W * mask)

    return init_fun, apply_fun
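A hedged usage sketch: with a strictly upper-triangular mask, output unit j only mixes inputs with index smaller than j, which is the building block of MADE-style autoregressive layers.

import jax.numpy as jnp
from jax import random

dim = 4
mask = jnp.triu(jnp.ones((dim, dim)), k=1)   # mask[i, j] = 1 only for i < j
init_fun, apply_fun = MaskedDense(mask)
out_shape, params = init_fun(random.PRNGKey(0), (3, dim))
y = apply_fun(params, jnp.ones((3, dim)))    # y.shape == (3, 4)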
Example #20
def vstate(request):
    M = request.param
    # keep this a prime number so we get different sizes on every rank...
    hi = nk.hilbert.Fock(M, 1)

    ma = nk.models.RBM(
        alpha=1,
        dtype=float,
        hidden_bias_init=normal(),
        visible_bias_init=normal(),
    )

    return nk.vqs.MCState(
        nk.sampler.MetropolisLocal(hi),
        ma,
    )
Example #21
def GeneralConvTranspose(dimension_numbers,
                         out_chan,
                         filter_shape,
                         strides=None,
                         padding='VALID',
                         W_init=None,
                         b_init=normal(1e-6)):
    """Layer construction function for a general transposed-convolution layer."""
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1, ) * len(filter_shape)
    strides = strides or one
    W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))

    def init_fun(rng, input_shape):
        filter_shape_iter = iter(filter_shape)
        kernel_shape = [
            out_chan if c == 'O' else input_shape[lhs_spec.index('C')]
            if c == 'I' else next(filter_shape_iter) for c in rhs_spec
        ]
        output_shape = lax.conv_transpose_shape_tuple(input_shape,
                                                      kernel_shape, strides,
                                                      padding,
                                                      dimension_numbers)
        bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
        k1, k2 = random.split(rng)
        W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return lax.conv_transpose(
            inputs, W, strides, padding,
            dimension_numbers=dimension_numbers) + b

    return init_fun, apply_fun
Example #22
    def init_parameters(
        self, init_fun: Optional[NNInitFunc] = None, *, seed: Optional[PRNGKeyT] = None
    ):
        r"""
        Re-initializes all the parameters with the provided initialization function,
        defaulting to the normal distribution of standard deviation 0.01.

        .. warning::

            The init function will not change the dtype of the parameters, which is
            determined by the model. DO NOT SPECIFY IT INSIDE THE INIT FUNCTION

        Args:
            init_fun: a jax initializer such as :ref:`jax.nn.initializers.normal`.
                Must be a Callable taking 3 inputs, the jax PRNG key, the shape and the
                dtype, and outputting an array with the valid dtype and shape. If left
                unspecified, defaults to :code:`jax.nn.initializers.normal(stddev=0.01)`
            seed: Optional seed to be used. The seed is synced across all MPI processes.
                If unspecified, uses a random seed.
        """
        if init_fun is None:
            init_fun = normal(stddev=0.01)

        rng = nkjax.PRNGSeq(nkjax.PRNGKey(seed))

        def new_pars(par):
            return jnp.asarray(
                init_fun(rng.take(1)[0], shape=par.shape, dtype=par.dtype),
                dtype=par.dtype,
            )

        self.parameters = jax.tree_map(new_pars, self.parameters)
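A hedged usage sketch of the method above, assuming vstate is an nk.vqs.MCState like those constructed in Examples #5 and #20.

from jax.nn.initializers import normal

# Re-draw every parameter from a normal with stddev 0.1; the parameter dtypes
# stay whatever the model defined.
vstate.init_parameters(normal(stddev=0.1), seed=1234)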
Example #23
def RNN(hidden_dim,
        W_init=glorot_normal(),
        b_init=normal(),
        activation=jax.nn.relu):
    """Recurrent Neural Network cell."""

    input_to_hidden = Linear(hidden_dim, W_init=W_init)
    hidden_to_hidden = Affine(hidden_dim, W_init=W_init, b_init=b_init)

    def init(key, input_dim):
        output_shape = hidden_dim
        k1, k2 = jax.random.split(key)
        _, input_to_hidden_params = input_to_hidden.init(k1, input_dim)
        _, hidden_to_hidden_params = hidden_to_hidden.init(k2, hidden_dim)
        return [hidden_dim], RNNParams(input_to_hidden_params,
                                       hidden_to_hidden_params)

    def apply(params, inputs, prev_state, **kwargs):
        new_hidden_raw = (
            input_to_hidden.apply(params.input_to_hidden, inputs) +
            hidden_to_hidden.apply(params.hidden_to_hidden, prev_state.hidden))
        new_hidden = activation(new_hidden_raw)
        new_state = RNNState(hidden=new_hidden)
        return new_state, new_hidden

    def initial_state():
        return RNNState(hidden=jnp.zeros([hidden_dim]))

    return Module(init, apply, initial_state)
Example #24
    def __init__(self,
                 out_dim,
                 kernel_init=glorot_normal(),
                 bias_init=normal()):
        self.bias_init = bias_init
        self.kernel_init = kernel_init
        self.out_dim = out_dim
Example #25
def MLP(layer_dims,
        W_init=glorot_normal(),
        b_init=normal(),
        activation=jax.nn.relu,
        activate_final=False):
  """A multi-layered perceptron."""

  layers = []
  for dim in layer_dims[:-1]:
    layers.append(Dense(dim, W_init=W_init, b_init=b_init,
                        activation=activation))
  if activate_final:
    layers.append(Dense(layer_dims[-1], W_init=W_init, b_init=b_init,
                        activation=activation))
  else:
    layers.append(Affine(layer_dims[-1], W_init=W_init, b_init=b_init))

  def init(key, input_dim):
    keys = jax.random.split(key, num=len(layer_dims))
    input_dims = [input_dim] + layer_dims[:-1]
    params = []
    for layer, key, in_dim in zip(layers, keys, input_dims):
      params.append(layer.init(key, in_dim)[1])
    return layer_dims[-1], MLPParams(params)

  def apply(params, inputs):
    for layer, param in zip(layers, params.layer_params):
      inputs = layer.apply(param, inputs)
    return inputs

  return Module(init, apply)
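A hedged usage sketch, assuming the same Module/Dense/Affine helpers used by the related examples from this codebase (Examples #18 and #23); names and sizes are illustrative.

import jax
import jax.numpy as jnp

mlp = MLP([32, 32, 1], activation=jax.nn.tanh)
out_dim, params = mlp.init(jax.random.PRNGKey(0), input_dim=4)
y = mlp.apply(params, jnp.ones([4]))   # output of shape [1]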
Example #26
  def test_dense_is_dense_general(self):
    x = jax.random.normal(random.PRNGKey(0), (5, 3))
    dense_module = nn.Dense(
        features=4,
        use_bias=True,
        bias_init=initializers.normal(),
    )
    y1, _ = dense_module.init_with_output(dict(params=random.PRNGKey(1)), x)
    dg_module = nn.DenseGeneral(
        features=4,
        use_bias=True,
        bias_init=initializers.normal(),
    )
    y2, _ = dg_module.init_with_output(dict(params=random.PRNGKey(1)), x)

    np.testing.assert_allclose(y1, y2)
def ConcatSquashLinear(out_dim, W_init=he_normal(), b_init=normal()):
    """ y = Sigmoid(at + c)(Wx + b) + dt. Note: he_normal only takes multi dim.
    """
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2, k3, k4, k5 = random.split(rng, 5)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim, ))
        w_t, w_tb = b_init(k3, (out_dim, )), b_init(k4, (out_dim, ))
        b_t = b_init(k5, (out_dim, ))
        return output_shape, (W, b, w_t, w_tb, b_t)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        W, b, w_t, w_tb, b_t = params

        # (W.x + b) *
        out = np.dot(x, W) + b
        # sigmoid(a.t + c)  +
        out *= jax.nn.sigmoid(w_t * t + w_tb)
        # d.t
        out += b_t * t

        return (out, t)

    return init_fun, apply_fun
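A hedged usage sketch: the layer consumes and returns an (x, t) pair, so it can be chained with other time-conditioned layers in a stax-style continuous normalizing flow.

import jax.numpy as jnp
from jax import random

init_fun, apply_fun = ConcatSquashLinear(out_dim=8)
out_shape, params = init_fun(random.PRNGKey(0), (16, 4))
x, t = jnp.ones((16, 4)), 0.5
y, t_out = apply_fun(params, (x, t))   # y.shape == (16, 8); t passes through unchanged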
def IgnoreConv2D(out_dim,
                 W_init=he_normal(),
                 b_init=normal(),
                 kernel=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        out = apply_fun_wrapped(params, x, **kwargs)
        return (out, t)

    # Return the (x, t)-aware wrapper so t is threaded through unchanged.
    return init_fun_wrapped, apply_fun
Example #29
def DenseVMAP(out_dim, W_init=glorot_normal(), b_init=normal()):
    """Layer constructor function for a dense (fully-connected) layer."""
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2 = jax_random.split(rng)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return jnp.dot(inputs, W) + b

    apply_fun_vmap = vmap(apply_fun, (None, 0))
    return init_fun, apply_fun_vmap


#model_params = [
#            [Dense(25), LayerNorm(), Relu, Reshape((1, 5, 5, 1)),
#             ConvTranspose(16, (6, 6), padding='VALID'), LayerNormConv(), Relu,  # 10x10
#             ConvTranspose(8, (6, 6), padding='VALID'), LayerNormConv(), Relu,  # 15x15
#             ConvTranspose(1, (6, 6), padding='VALID'), LayerNormConv(), Reshape((400,))],  # 20x20
#            [Dense(25), LayerNorm(), Relu, Reshape((1, 5, 5, 1)),
#             Conv(16, (4, 4), padding='same'), LayerNormConv(), Relu,
#             Conv(8, (3, 3), padding='same'), LayerNormConv(), Relu,
#             Conv(1, (3, 3), padding='same'), LayerNormConv(), Reshape((25,)),  # 2 from Conv before
#             Dense(21)]
#        ]
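A hedged usage sketch for DenseVMAP above: a single shared parameter set is mapped over the leading axis of the inputs via vmap.

import jax.numpy as jnp
from jax import random

init_fun, apply_fun = DenseVMAP(out_dim=3)
_, params = init_fun(random.PRNGKey(0), (5,))
y = apply_fun(params, jnp.ones((7, 5)))   # mapped over the 7 leading entries; y.shape == (7, 3)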
Example #30
    def DenseEquivalent(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            kernel = Parameter(lambda key: kernel_init(key, (inputs.shape[-1], out_dim)))()
            bias = Parameter(lambda key: bias_init(key, (out_dim,)))()
            return np.dot(inputs, kernel) + bias

        return dense