Code example #1
def NdmSpinPhase(hilbert,
                 alpha,
                 beta,
                 use_hidden_bias=True,
                 use_visible_bias=True):
    r"""
    A fully connected Neural Density Matrix (DBM). This type density matrix is
    obtained purifying a RBM with spin 1/2 hidden units.

    The number of purification hidden units can be chosen arbitrarily.

    The weights are taken to be complex-valued. A complete definition of this
    machine can be found in Eq. 2 of Hartmann, M. J. & Carleo, G.,
    Phys. Rev. Lett. 122, 250502 (2019).

    Args:
        hilbert: Hilbert space of the system.
        alpha: `alpha * hilbert.size` is the number of hidden spins used for
                the pure state part of the density-matrix.
        beta: `beta * hilbert.size` is the number of hidden spins used for the purification.
            beta=0 for example corresponds to a pure state.
        use_hidden_bias: If ``True`` bias on the hidden units is taken.
                         Default ``True``.
    """
    mod_pure = stax.serial(
        DensePureRowCol(alpha * hilbert.size, use_hidden_bias),
        stax.parallel(LogCoshLayer, LogCoshLayer),
        stax.parallel(SumLayer, SumLayer),
        FanInSum2,
    )

    phs_pure = stax.serial(
        DensePureRowCol(alpha * hilbert.size, use_hidden_bias),
        stax.parallel(LogCoshLayer, LogCoshLayer),
        stax.parallel(SumLayer, SumLayer),
        FanInSub2,
    )

    mixing = stax.serial(
        DenseMixingReal(beta * hilbert.size, use_hidden_bias),
        LogCoshLayer,
        SumLayer,
    )

    if use_visible_bias:
        biases = BiasRealModPhase()
        net = stax.serial(
            stax.FanOut(4),
            stax.parallel(mod_pure, phs_pure, mixing, biases),
            stax.FanInSum,
        )
    else:
        net = stax.serial(
            stax.FanOut(3),
            stax.parallel(mod_pure, phs_pure, mixing),
            stax.FanInSum,
        )

    return Jax(hilbert, net, dtype=float, outdtype=complex)
Code example #2
File: slax.py Project: nwilliam868/tensor2tensor
def residual(*layers, **kwargs):
    """Constructs a residual version of layers, summing input to layers output."""
    res = kwargs.get('res', stax.Identity)
    if len(layers) > 1:
        return stax.serial(stax.FanOut(2),
                           stax.parallel(stax.serial(*layers), res),
                           stax.FanInSum)
    elif len(layers) == 1:
        return stax.serial(stax.FanOut(2), stax.parallel(layers[0], res),
                           stax.FanInSum)
    else:
        raise ValueError('Empty residual combinator.')
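For orientation, here is a minimal usage sketch of the `residual` combinator above (my own example, assuming `stax` comes from `jax.example_libraries.stax`, or `jax.experimental.stax` on older JAX releases). The wrapped layers must preserve the feature dimension so that `stax.FanInSum` can add the branch output back onto the input.

from jax import random
from jax.example_libraries import stax  # older JAX: from jax.experimental import stax

# Hypothetical composition: a dense block wrapped by the residual combinator above.
init_fn, apply_fn = stax.serial(
    stax.Dense(128), stax.Relu,
    residual(stax.Dense(128), stax.Relu),  # output = x + Relu(Dense(x))
    stax.Dense(10),
)
out_shape, params = init_fn(random.PRNGKey(0), (-1, 128))  # out_shape == (-1, 10)
logits = apply_fn(params, random.normal(random.PRNGKey(1), (4, 128)))  # shape (4, 10)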
Code example #3
def global_conv_block(num_channels, grids, minval, maxval, downsample_factor):
    """Global convolution block.

  First downsample the input, then apply global conv, finally upsample and
  concatenate with the input. The input itself is one channel in the output.

  Args:
    num_channels: Integer, the number of channels.
    grids: Float numpy array with shape (num_grids,).
    minval: Float, the min value in the uniform sampling for exponential width.
    maxval: Float, the max value in the uniform sampling for exponential width.
    downsample_factor: Integer, the factor of downsampling. The grids are
        downsampled with step size 2 ** downsample_factor.

  Returns:
    (init_fn, apply_fn) pair.
  """
    layers = []
    layers.extend([linear_interpolation_transpose()] * downsample_factor)
    layers.append(
        exponential_global_convolution(
            num_channels=num_channels - 1,  # one channel is reserved for input.
            grids=grids,
            minval=minval,
            maxval=maxval,
            downsample_factor=downsample_factor))
    layers.extend([linear_interpolation()] * downsample_factor)
    global_conv_path = stax.serial(*layers)
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(stax.Identity, global_conv_path),
        stax.FanInConcat(axis=-1),
    )
Code example #4
def network(rng: random.PRNGKey, num_in: int, num_hidden: int) -> Tuple:
    """Factory for producing the dequantization neural network and its
    parameterizations.

    Args:
        rng: Pseudo-random number generator seed.
        num_in: Number of inputs to the network.
        num_hidden: Number of hidden units in the hidden layer.

    Returns:
        out: A tuple containing the network parameters and a callable function
            that returns the neural network output for the given input.

    """
    num_in_sq = num_in**2
    num_in_choose_two = num_in * (num_in - 1) // 2
    params_init, fn = stax.serial(
        stax.Flatten, stax.Dense(num_hidden), stax.Relu,
        stax.Dense(num_hidden), stax.Relu, stax.FanOut(4),
        stax.parallel(
            stax.Dense(num_in),
            stax.Dense(num_in_choose_two),
            stax.serial(stax.Dense(num_in), stax.Softplus),
            stax.serial(stax.Dense(num_in_choose_two), stax.Softplus),
        ))
    _, params = params_init(rng, (-1, num_in_sq))
    return params, fn
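As a rough illustration (names and sizes below are placeholders, not from the original project): the apply function returns four parallel heads, and the two Softplus branches are constrained to be positive.

import jax.numpy as jnp
from jax import random

# Hypothetical usage of `network` above with num_in=3, i.e. 3**2 = 9 input features.
params, fn = network(random.PRNGKey(0), num_in=3, num_hidden=32)
x = jnp.ones((5, 9))
h1, h2, h3, h4 = fn(params, x)  # each of shape (5, 3); h3 and h4 are positive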
Code example #5
def encoder(hidden_dim, z_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()), stax.Softplus,
        stax.FanOut(2),
        stax.parallel(stax.Dense(z_dim, W_init=stax.randn()),
                      stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)),
    )
Code example #6
def FanOut(num):
    """Layer construction function for a fan-out layer.

  Based on `jax.experimental.stax.FanOut`.
  """
    init_fun, apply_fun = stax.FanOut(num)
    ker_fun = lambda kernels: [kernels] * num
    return init_fun, apply_fun, ker_fun
Code example #7
def ResidualBlock(out_channels, kernel_size, stride, padding, input_format):
    double_conv = stax.serial(
        stax.GeneralConv(input_format, out_channels, kernel_size, stride, padding),
        stax.Elu,
    )
    return Module(
        *stax.serial(
            stax.FanOut(2), stax.parallel(double_conv, stax.Identity), stax.FanInSum
        )
    )
Code example #8
File: train.py Project: Justin-Tan/jax_sumo
def decoder(hidden_dim, x_dim=2, activation='Tanh'):
    activation = getattr(stax, activation)
    decoder_init, decode = stax.serial(
        stax.Dense(hidden_dim),
        activation,
        # stax.Dense(hidden_dim), activation,
        stax.FanOut(2),
        stax.parallel(stax.Dense(x_dim), stax.Dense(x_dim)),
    )
    return decoder_init, decode
Code example #9
File: train.py Project: Justin-Tan/jax_sumo
def encoder(hidden_dim, z_dim, activation='Tanh'):
    activation = getattr(stax, activation)
    encoder_init, encode = stax.serial(
        stax.Dense(hidden_dim),
        activation,
        # stax.Dense(hidden_dim), activation,
        stax.FanOut(2),
        stax.parallel(stax.Dense(z_dim), stax.Dense(z_dim)),
    )
    return encoder_init, encode
Code example #10
File: resnet.py Project: skhong0831/tensor2tensor
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutational block."""
    main = stax.serial(stax.BatchNorm(), stax.Relu,
                       stax.Conv(channels, (3, 3), strides, padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(channels, (3, 3), padding='SAME'))
    shortcut = stax.Identity if not channel_mismatch else stax.Conv(
        channels, (3, 3), strides, padding='SAME')
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum)
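A quick shape check of the block above (my sketch; NHWC inputs assumed): with matching channels and unit strides the main path preserves both the spatial size and the channel count, so the Identity shortcut can be summed onto it.

from jax import random

# Hypothetical shape check for WideResnetBlock above.
init_fn, apply_fn = WideResnetBlock(channels=16, strides=(1, 1), channel_mismatch=False)
out_shape, params = init_fn(random.PRNGKey(0), (8, 32, 32, 16))
# out_shape == (8, 32, 32, 16): FanInSum requires the two branch shapes to agree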
Code example #11
    def ConvBlock(self, kernel_size, filters, strides=(2, 2)):
        filters1, filters2, filters3 = filters
        Main = stax.serial(
            stax.Conv(filters1, (1, 1), strides), stax.BatchNorm(), stax.Relu,
            stax.Conv(filters2, (kernel_size, kernel_size), padding='SAME'),
            stax.BatchNorm(), stax.Relu, stax.Conv(filters3, (1, 1)),
            stax.BatchNorm())
        Shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                               stax.BatchNorm())
        return stax.serial(stax.FanOut(2), stax.parallel(Main, Shortcut),
                           stax.FanInSum, stax.Relu)
Code example #12
File: slax.py Project: nwilliam868/tensor2tensor
def multiplex(*args):
    """Helper to form input argument lists of bound variables.

  Args:
    *args: list of bound layers or raw stax Identity layers.

  Returns:
    A layer returning in parallel the bound variables as well as (multiple)
    copies of this layer's input wherever Identity has been specified.
  """
    return stax.serial(stax.FanOut(len(args)), stax.parallel(*args))
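The sketch below (mine, not from slax) shows only the plain-stax behaviour of `multiplex`: every branch receives a copy of the same input, and the result is the list of branch outputs; slax's binding of variables to layers is not reproduced here.

from jax import random
from jax.example_libraries import stax  # older JAX: from jax.experimental import stax

# Hypothetical usage: Identity passes the input through, Dense(8) projects a copy of it.
init_fn, apply_fn = multiplex(stax.Identity, stax.Dense(8))
out_shapes, params = init_fn(random.PRNGKey(0), (-1, 4))    # [(-1, 4), (-1, 8)]
x = random.normal(random.PRNGKey(1), (2, 4))
same_x, projected = apply_fn(params, x)                     # shapes (2, 4) and (2, 8)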
Code example #13
def UNetBlock(filters, kernel_size, inner_block, **kwargs):
    def make_main(input_shape):
        return stax.serial(
            UnbiasedConv(filters, kernel_size, **kwargs),
            inner_block,
            UnbiasedConvTranspose(input_shape[3], kernel_size, **kwargs),
        )

    Main = stax.shape_dependent(make_main)
    return stax.serial(stax.FanOut(2), stax.parallel(Main, stax.Identity),
                       stax.FanInSum)
Code example #14
def jaxRbmSpinPhase(hilbert, alpha):
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(
            stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer,
                        SumLayer),
            stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer,
                        SumLayer),
        ),
        FanInSum2ModPhase,
    )
Code example #15
File: jax.py Project: vlpap/netket
def ndmSpinPhase(hilbert,
                 alpha,
                 beta,
                 use_hidden_bias=True,
                 use_visible_bias=True):
    mod_pure = stax.serial(
        DensePureRowCol(alpha * hilbert.size, use_hidden_bias),
        stax.parallel(LogCoshLayer, LogCoshLayer),
        stax.parallel(SumLayer, SumLayer),
        FanInSum2,
    )

    phs_pure = stax.serial(
        DensePureRowCol(alpha * hilbert.size, use_hidden_bias),
        stax.parallel(LogCoshLayer, LogCoshLayer),
        stax.parallel(SumLayer, SumLayer),
        FanInSub2,
    )

    mixing = stax.serial(
        DenseMixingReal(int(beta * hilbert.size), use_hidden_bias),
        LogCoshLayer,
        SumLayer,
    )

    if use_visible_bias:
        biases = BiasRealModPhase()
        net = stax.serial(
            stax.FanOut(4),
            stax.parallel(mod_pure, phs_pure, mixing, biases),
            stax.FanInSum,
        )
    else:
        net = stax.serial(
            stax.FanOut(3),
            stax.parallel(mod_pure, phs_pure, mixing),
            stax.FanInSum,
        )

    return net
Code example #16
File: resnet.py Project: skhong0831/tensor2tensor
def ConvBlock(kernel_size, filters, strides):
    """ResNet convolutional striding block."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = stax.serial(stax.Conv(filters1, (1, 1),
                                 strides), stax.BatchNorm(), stax.Relu,
                       stax.Conv(filters2, (ks, ks), padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(filters3, (1, 1)), stax.BatchNorm())
    shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                           stax.BatchNorm())
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum, stax.Relu)
Code example #17
File: train_svhn.py Project: adp-anonymous/adp
def wide_resnet_block(num_channels, strides=(1, 1), channel_mismatch=False):
    """Wide ResNet block."""
    pre = stax.serial(stax.BatchNorm(), stax.Relu)
    mid = stax.serial(
        pre, stax.Conv(num_channels, (3, 3), strides, padding='SAME'),
        stax.BatchNorm(), stax.Relu,
        stax.Conv(num_channels, (3, 3), strides=(1, 1), padding='SAME'))
    if channel_mismatch:
        cut = stax.serial(
            pre, stax.Conv(num_channels, (3, 3), strides, padding='SAME'))
    else:
        cut = stax.Identity
    return stax.serial(stax.FanOut(2), stax.parallel(mid, cut), stax.FanInSum)
Code example #18
def encoder(hidden_dim: int, z_dim: int) -> Tuple[Callable, Callable]:
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(
                stax.Dense(z_dim, W_init=stax.randn()),
                stax.Exp,
            ),
        ),
    )
Code example #19
    def IdentityBlock(self, kernel_size, filters):
        filters1, filters2 = filters

        def make_main(input_shape):
            # the number of output channels depends on the number of input channels
            return stax.serial(
                stax.Conv(filters1, (1, 1)), stax.BatchNorm(), stax.Relu,
                stax.Conv(filters2, (kernel_size, kernel_size),
                          padding='SAME'), stax.BatchNorm(), stax.Relu,
                stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())

        Main = stax.shape_dependent(make_main)
        return stax.serial(stax.FanOut(2), stax.parallel(Main, stax.Identity),
                           stax.FanInSum, stax.Relu)
Code example #20
def JaxRbmSpinPhase(hilbert, alpha, dtype=float):
    return Jax(
        hilbert,
        stax.serial(
            stax.FanOut(2),
            stax.parallel(
                stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer,
                            SumLayer),
                stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer,
                            SumLayer),
            ),
            FanInSum2ModPhase,
        ),
        dtype=dtype,
    )
Code example #21
File: resnet.py Project: skhong0831/tensor2tensor
def IdentityBlock(kernel_size, filters):
    """ResNet identical size block."""
    ks = kernel_size
    filters1, filters2 = filters

    def MakeMain(input_shape):
        # the number of output channels depends on the number of input channels
        return stax.serial(stax.Conv(filters1, (1, 1)), stax.BatchNorm(),
                           stax.Relu,
                           stax.Conv(filters2, (ks, ks), padding='SAME'),
                           stax.BatchNorm(), stax.Relu,
                           stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())

    main = stax.shape_dependent(MakeMain)
    return stax.serial(stax.FanOut(2), stax.parallel(main, stax.Identity),
                       stax.FanInSum, stax.Relu)
Code example #22
File: ppo.py Project: epignatelli/helx
def Cnn(n_actions: int, hidden_size: int = 512) -> Module:
    return stax.serial(
        stax.Conv(32, (8, 8), (4, 4), "VALID"),
        stax.Relu,
        stax.Conv(64, (4, 4), (2, 2), "VALID"),
        stax.Relu,
        stax.Conv(64, (3, 3), (1, 1), "VALID"),
        stax.Relu,
        stax.Flatten,
        stax.Dense(hidden_size),
        stax.Relu,
        stax.FanOut(2),
        stax.parallel(
            stax.serial(
                stax.Dense(n_actions),
                stax.Softmax,
            ),  # actor
            stax.serial(stax.Dense(1)),  # critic
        ),
    )
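A hedged usage sketch (frame size, layout, and unpacking are my assumptions, not from the helx project): with 84x84 NHWC frames the three VALID convolutions reduce the spatial size to 7x7, and the two heads return action probabilities and a scalar value estimate.

from jax import random

# Hypothetical usage, treating the return value as a stax (init_fn, apply_fn) pair.
init_fn, apply_fn = Cnn(n_actions=6)
out_shapes, params = init_fn(random.PRNGKey(0), (1, 84, 84, 4))
frames = random.normal(random.PRNGKey(1), (1, 84, 84, 4))
probs, value = apply_fn(params, frames)  # probs: (1, 6) from Softmax, value: (1, 1)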
Code example #23
def wrap_network_with_self_interaction_layer(network, grids, interaction_fn):
    """Wraps a network with self-interaction layer.

  Args:
    network: an (init_fn, apply_fn) pair.
     * init_fn: The init_fn of the neural network. It takes an rng key and
         an input shape and returns an (output_shape, params) pair.
     * apply_fn: The apply_fn of the neural network. It takes params,
         inputs, and an rng key and applies the layer.
    grids: Float numpy array with shape (num_grids,).
    interaction_fn: function that takes displacements and returns a
        float numpy array with the same shape as displacements.

  Returns:
    (init_fn, apply_fn) pair.
  """
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(stax.Identity, network),
        self_interaction_layer(grids, interaction_fn),
    )
Code example #24
def BlockNeuralAutoregressiveNN(input_dim,
                                hidden_factors=[8, 8],
                                residual=None):
    """
    An implementation of a Block Neural Autoregressive neural network.

    **References**

    1. *Block Neural Autoregressive Flow*,
       Nicola De Cao, Ivan Titov, Wilker Aziz

    :param int input_dim: The dimensionality of the input.
    :param list hidden_factors: Hidden layer i has ``hidden_factors[i]`` hidden units per
        input dimension. This corresponds to both :math:`a` and :math:`b` in reference [1].
        The elements of hidden_factors must be integers.
    :param str residual: Type of residual connections to use. One of `None`, `"normal"`, `"gated"`.
    :return: an (`init_fun`, `apply_fun`) pair.
    """
    layers = []
    in_factor = 1
    for hidden_factor in hidden_factors:
        layers.append(BlockMaskedDense(input_dim, in_factor, hidden_factor))
        layers.append(Tanh())
        in_factor = hidden_factor
    layers.append(BlockMaskedDense(input_dim, in_factor, 1))
    arn = stax.serial(*layers)
    if residual is not None:
        FanInResidual = (FanInResidualGated
                         if residual == "gated" else FanInResidualNormal)
        arn = stax.serial(stax.FanOut(2), stax.parallel(arn, stax.Identity),
                          FanInResidual())

    def init_fun(rng, input_shape):
        return arn[0](rng, input_shape)

    def apply_fun(params, inputs, **kwargs):
        out, logdet = arn[1](params, (inputs, None), **kwargs)
        return out, logdet.reshape(inputs.shape)

    return init_fun, apply_fun
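A hedged usage sketch, assuming `BlockNeuralAutoregressiveNN` and the helpers it calls (`BlockMaskedDense`, `Tanh`, the `FanInResidual*` layers) are importable, e.g. from numpyro; per the `apply_fun` above, the output and the per-element log-determinant share the input shape.

import jax.numpy as jnp
from jax import random

# Hypothetical usage of BlockNeuralAutoregressiveNN as defined above.
init_fun, apply_fun = BlockNeuralAutoregressiveNN(input_dim=3, hidden_factors=[4, 4])
_, params = init_fun(random.PRNGKey(0), (2, 3))
out, logdet = apply_fun(params, jnp.ones((2, 3)))  # both of shape (2, 3)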
Code example #25
    def init_fun(rng, input_shape):
        """
        :param rng: rng used to initialize parameters
        :param input_shape: input shape
        """
        # TODO: consider removing permutation so we can move those layer constructions outside
        # init_fun. It seems that we can add a PermuteTransform layer to achieve the same effect.
        nonlocal permutation, net

        if permutation is None:
            # By default set a random permutation of variables, which is
            # important for performance with multiple steps
            rng, rng_perm = random.split(rng)
            permutation = random.shuffle(rng_perm, np.arange(input_dim))

        # Create masks
        masks, mask_skip = create_mask(input_dim=input_dim,
                                       hidden_dims=hidden_dims,
                                       permutation=permutation,
                                       output_dim_multiplier=output_multiplier)

        main_layers = []
        # Create masked layers
        for i, mask in enumerate(masks):
            main_layers.append(MaskedDense(mask))
            if i < len(masks) - 1:
                main_layers.append(nonlinearity)

        if skip_connections:
            net_init, net = stax.serial(
                stax.FanOut(2),
                stax.parallel(stax.serial(*main_layers),
                              MaskedDense(mask_skip, bias=False)),
                stax.FanInSum)
        else:
            net_init, net = stax.serial(*main_layers)

        return net_init(rng, input_shape)
Code example #26
def network_factory(rng: jnp.ndarray, num_in: int, num_out: int,
                    num_hidden: int) -> Tuple:
    """Factory for producing neural networks and their parameterizations.

    Args:
        rng: Pseudo-random number generator seed.
        num_in: Number of inputs to the network.
        num_out: Number of variables to transform by an affine transformation.
            Each variable receives an associated shift and scale.
        num_hidden: Number of hidden units in the hidden layer.

    Returns:
        out: A tuple containing the network parameters and a callable function
            that returns the neural network output for the given input.

    """
    params_init, fn = stax.serial(
        stax.Dense(num_hidden), stax.Relu, stax.Dense(num_hidden), stax.Relu,
        stax.FanOut(2),
        stax.parallel(stax.Dense(num_out),
                      stax.serial(stax.Dense(num_out), stax.Softplus)))
    _, params = params_init(rng, (-1, num_in))
    return params, fn
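For clarity, a small usage sketch (variable names are mine): the apply function returns a `(shift, scale)` pair, with the Softplus head keeping the scale positive.

import jax.numpy as jnp
from jax import random

# Hypothetical usage of network_factory as defined above.
params, fn = network_factory(random.PRNGKey(0), num_in=4, num_out=2, num_hidden=32)
x = jnp.ones((10, 4))
shift, scale = fn(params, x)  # shift: (10, 2); scale: (10, 2), positive via Softplus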
Code example #27
File: vae.py Project: byzhang/d3p
def encoder(hidden_dim, z_dim):
    """Defines the encoder, i.e., the network taking us from observations
        to (a distribution of) latent variables.

    z follows a normal distribution, so the encoder needs to output its mean and variance.

    Network structure:
    x -> dense layer of hidden_dim with softplus activation --> dense layer of z_dim ( = means/loc of z)
                                                            |-> dense layer of z_dim with (elementwise) exp() as activation func ( = variance of z )
    (note: the exp() as activation function serves solely to ensure positivity of the variance)

    :param hidden_dim: number of nodes in the hidden layer
    :param z_dim: dimension of the latent variable z
    :return: (init_fun, apply_fun) pair of the encoder: (encoder_init, encode)
    """
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)),
    )
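To make the docstring concrete, a minimal usage sketch (my example; `stax` is `jax.example_libraries.stax`, or `jax.experimental.stax` on the older releases this project targets):

from jax import random

# Hypothetical usage of the encoder above on flattened 28x28 inputs.
encoder_init, encode = encoder(hidden_dim=50, z_dim=10)
out_shapes, params = encoder_init(random.PRNGKey(0), (-1, 784))  # [(-1, 10), (-1, 10)]
x = random.normal(random.PRNGKey(1), (32, 784))
z_loc, z_var = encode(params, x)  # mean and exp-activated variance, each of shape (32, 10)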
Code example #28
def _build_unet_shell(layer, num_filters, activation):
    """Builds a shell in the U-net structure.

  *--------------*
  |              |                     *--------*     *----------------------*
  | downsampling |---------------------|        |     |                      |
  |   block      |   *------------*    | concat |-----|   upsampling block   |
  |              |---|    layer   |----|        |     |                      |
  *--------------*   *------------*    *--------*     *----------------------*

  Args:
    layer: (init_fn, apply_fn) pair in the bottom of the U-shape structure.
    num_filters: Integer, the number of filters used for downsampling and
        upsampling.
    activation: String, the activation function to use in the network.

  Returns:
    (init_fn, apply_fn) pair.
  """
    return stax.serial(downsampling_block(num_filters, activation=activation),
                       stax.FanOut(2), stax.parallel(stax.Identity, layer),
                       stax.FanInConcat(axis=-1),
                       upsampling_block(num_filters, activation=activation))
Code example #29
def ResNet(hidden_channels, out_channels, depth):
    # time integration module
    backbone = stax.serial(
        stax.GeneralConv(
            ("NCDWH", "IDWHO", "NCDWH"), hidden_channels, (4, 3, 3), (1, 1, 1), "SAME"
        ),
        *[
            ResidualBlock(
                hidden_channels,
                (4, 3, 3),
                (1, 1, 1),
                "SAME",
                ("NCDWH", "IDWHO", "NCDWH"),
            )
            for _ in range(depth)
        ],
        stax.GeneralConv(
            ("NCDWH", "IDWHO", "NCDWH"), out_channels, (4, 3, 3), (1, 1, 1), "SAME"
        ),
        stax.GeneralConv(("NDCWH", "IDWHO", "NDCWH"), 3, (3, 3, 3), (1, 1, 1), "SAME"),
    )

    #  euler scheme
    return stax.serial(stax.FanOut(2), stax.parallel(stax.Identity, backbone), Euler())
Code example #30
File: auto_reg_nn.py Project: ucals/numpyro
def AutoregressiveNN(input_dim,
                     hidden_dims,
                     param_dims=[1, 1],
                     permutation=None,
                     skip_connections=False,
                     nonlinearity=stax.Relu):
    """
    An implementation of a MADE-like auto-regressive neural network.

    Similar to the purely functional layers implemented in jax.experimental.stax,
    `AutoregressiveNN` returns an (`init_fun`, `apply_fun`) pair, where
    `init_fun` takes an rng key and an input shape and returns an
    (output_shape, params) pair, and `apply_fun` takes params and inputs
    and applies the network.

    :param input_dim: the dimensionality of the input
    :type input_dim: int
    :param hidden_dims: the dimensionality of the hidden units per layer
    :type hidden_dims: list[int]
    :param param_dims: shape the output into parameters of dimension (p_n, input_dim) for p_n in param_dims
        when p_n > 1 and dimension (input_dim) when p_n == 1. The default is [1, 1], i.e. output two parameters
        of dimension (input_dim), which is useful for inverse autoregressive flow.
    :type param_dims: list[int]
    :param permutation: an optional permutation that is applied to the inputs and controls the order of the
        autoregressive factorization. In particular, for the identity permutation the autoregressive structure
        is such that the Jacobian is triangular. Defaults to the identity permutation.
    :type permutation: array of ints
    :param skip_connections: whether to add skip connections from the input to the output.
    :type skip_connections: bool
    :param nonlinearity: the nonlinearity to use in the feedforward network, such as ReLU. Note that no
        nonlinearity is applied to the final network output, so the output is an unbounded real number.
        Defaults to ReLU.
    :type nonlinearity: callable
    :return: a tuple (init_fun, apply_fun)

    Reference:

    MADE: Masked Autoencoder for Distribution Estimation [arXiv:1502.03509]
    Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle
    """
    output_multiplier = sum(param_dims)
    all_ones = (jnp.array(param_dims) == 1).all()

    # Calculate the indices on the output corresponding to each parameter
    ends = jnp.cumsum(jnp.array(param_dims), axis=0)
    starts = jnp.concatenate((jnp.zeros(1), ends[:-1]))
    param_slices = [slice(int(s), int(e)) for s, e in zip(starts, ends)]

    # Hidden dimensions must not be less than the input dimension, otherwise it
    # isn't possible to connect to the outputs correctly.
    for h in hidden_dims:
        if h < input_dim:
            raise ValueError(
                'Hidden dimension must not be less than input dimension.')

    if permutation is None:
        permutation = jnp.arange(input_dim)

    # Create masks
    masks, mask_skip = create_mask(input_dim=input_dim,
                                   hidden_dims=hidden_dims,
                                   permutation=permutation,
                                   output_dim_multiplier=output_multiplier)

    main_layers = []
    # Create masked layers
    for i, mask in enumerate(masks):
        main_layers.append(MaskedDense(mask))
        if i < len(masks) - 1:
            main_layers.append(nonlinearity)

    if skip_connections:
        net_init, net = stax.serial(
            stax.FanOut(2),
            stax.parallel(stax.serial(*main_layers),
                          MaskedDense(mask_skip, bias=False)), stax.FanInSum)
    else:
        net_init, net = stax.serial(*main_layers)

    def init_fun(rng_key, input_shape):
        """
        :param rng_key: rng_key used to initialize parameters
        :param input_shape: input shape
        """
        assert input_dim == input_shape[-1]
        return net_init(rng_key, input_shape)

    def apply_fun(params, inputs, **kwargs):
        """
        :param params: layer parameters
        :param inputs: layer inputs
        """
        out = net(params, inputs, **kwargs)

        # reshape output as necessary
        out = jnp.reshape(out,
                          inputs.shape[:-1] + (output_multiplier, input_dim))
        # move param dims to the first dimension
        out = jnp.moveaxis(out, -2, 0)

        if all_ones:
            # Squeeze dimension if all parameters are one dimensional
            out = tuple([out[i] for i in range(output_multiplier)])
        else:
            # If not all ones, then probably don't want to squeeze a single dimension parameter
            out = tuple([out[s] for s in param_slices])

        # if len(param_dims) == 1, we return the array instead of a tuple of arrays
        return out[0] if len(param_dims) == 1 else out

    return init_fun, apply_fun
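A minimal usage sketch of the function above (assuming `MaskedDense` and `create_mask` from the same numpyro module are in scope): with the default `param_dims=[1, 1]` the apply function returns two arrays shaped like the input, e.g. the mean and log-scale of an inverse autoregressive flow.

import jax.numpy as jnp
from jax import random

# Hypothetical usage of AutoregressiveNN as defined above (param_dims defaults to [1, 1]).
init_fun, apply_fun = AutoregressiveNN(input_dim=5, hidden_dims=[10, 10])
out_shape, params = init_fun(random.PRNGKey(0), (-1, 5))
mean, log_scale = apply_fun(params, jnp.ones((3, 5)))  # each of shape (3, 5)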