def Lpg(hparams):
    # Two-layer head producing a single scalar.
    phi = serial(Dense(16), Dense(1))
    return serial(
        # The commented-out FanOut(6) suggests the block expects a 6-tuple
        # input: the first four components pass through unchanged and the
        # last two are embedded by independent phi heads (in stax, reusing
        # the same (init, apply) pair yields fresh parameters per use).
        # FanOut(6),
        parallel(Identity, Identity, Identity, Identity, phi, phi),
        FanInConcat(),
        LSTMCell(hparams.hidden_size)[0:2],  # keep only (init_fun, apply_fun)
        DiscardHidden(),
        Relu,
        FanOut(2),
        parallel(phi, phi),  # two scalar output heads
    )
Example #2
def residual(*layers, **kwargs):
    """Constructs a residual version of layers, summing input to layers output."""
    res = kwargs.get('res', stax.Identity)
    if len(layers) > 1:
        return stax.serial(stax.FanOut(2),
                           stax.parallel(stax.serial(*layers), res),
                           stax.FanInSum)
    elif len(layers) == 1:
        return stax.serial(stax.FanOut(2), stax.parallel(layers[0], res),
                           stax.FanInSum)
    else:
        raise ValueError('Empty residual combinator.')
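A minimal usage sketch, assuming the modern import path jax.example_libraries.stax and illustrative shapes: wrap two Dense layers in a skip connection. Note that stax.FanInSum requires both branch outputs to have the same shape, so the wrapped layers must preserve the feature dimension (or a reshaping res layer must be passed).

import jax
from jax.example_libraries import stax

block_init, block_apply = residual(stax.Dense(64), stax.Relu, stax.Dense(64))
out_shape, params = block_init(jax.random.PRNGKey(0), (-1, 64))
# block_apply(params, x) computes Dense(Relu(Dense(x))) + x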
Example #3
def MultiHeadedAttention(  # pylint: disable=invalid-name
        feature_depth,
        num_heads=8,
        dropout=1.0,
        mode='train'):
    """Transformer-style multi-headed attention.

  Args:
    feature_depth: int:  depth of embedding
    num_heads: int: number of attention heads
    dropout: float: keep probability for dropout (1.0 means no dropout)
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
    return stax.serial(
        stax.parallel(stax.Dense(feature_depth, W_init=xavier_uniform()),
                      stax.Dense(feature_depth, W_init=xavier_uniform()),
                      stax.Dense(feature_depth, W_init=xavier_uniform()),
                      stax.Identity),
        PureMultiHeadedAttention(feature_depth,
                                 num_heads=num_heads,
                                 dropout=dropout,
                                 mode=mode),
        stax.Dense(feature_depth, W_init=xavier_uniform()),
    )
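The leading stax.parallel implies this layer consumes a 4-tuple input; a plausible reading, not confirmed by the snippet, is (query, key, value, mask), with the mask passed through stax.Identity. A hedged initialization sketch with illustrative shapes (the exact shapes depend on PureMultiHeadedAttention):

import jax

init_fun, apply_fun = MultiHeadedAttention(feature_depth=512, num_heads=8)
input_shapes = ((-1, 10, 512), (-1, 10, 512), (-1, 10, 512), (-1, 1, 10))
out_shape, params = init_fun(jax.random.PRNGKey(0), input_shapes)
# out = apply_fun(params, (q, k, v, mask), rng=jax.random.PRNGKey(1))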
Example #4
def network(rng: random.PRNGKey, num_in: int, num_hidden: int) -> Tuple:
    """Factory for producing the dequantization neural network and its
    parameterizations.

    Args:
        rng: Pseudo-random number generator key.
        num_in: Number of inputs to the network.
        num_hidden: Number of hidden units in the hidden layer.

    Returns:
        out: A tuple containing the network parameters and a callable function
            that returns the neural network output for the given input.

    """
    num_in_sq = num_in**2
    num_in_choose_two = num_in * (num_in - 1) // 2
    params_init, fn = stax.serial(
        stax.Flatten, stax.Dense(num_hidden), stax.Relu,
        stax.Dense(num_hidden), stax.Relu, stax.FanOut(4),
        stax.parallel(
            stax.Dense(num_in),
            stax.Dense(num_in_choose_two),
            stax.serial(stax.Dense(num_in), stax.Softplus),
            stax.serial(stax.Dense(num_in_choose_two), stax.Softplus),
        ))
    _, params = params_init(rng, (-1, num_in_sq))
    return params, fn
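A usage sketch for the factory; the head names below are illustrative assumptions (the two Softplus branches produce positive outputs, suggesting scale-like parameters):

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
params, fn = network(rng, num_in=3, num_hidden=64)
x = jnp.ones((10, 9))            # a batch of flattened 3x3 inputs
diag, off_diag, diag_scale, off_diag_scale = fn(params, x)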
Example #5
def WideResnetBlock(channels,
                    strides=(1, 1),
                    channel_mismatch=False,
                    nonlin=Relu,
                    parameterization='standard',
                    order=None):
    Main = jax_stax.serial(
        nonlin,
        MyConv(channels, (3, 3),
               strides,
               padding='SAME',
               parameterization=parameterization,
               order=order), nonlin,
        MyConv(channels, (3, 3),
               padding='SAME',
               parameterization=parameterization,
               order=order))
    Shortcut = Identity if not channel_mismatch else MyConv(
        channels, (3, 3),
        strides,
        padding='SAME',
        parameterization=parameterization,
        order=order)
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Shortcut),
                           FanInSum)
Example #6
def encoder(hidden_dim, z_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()), stax.Softplus,
        stax.FanOut(2),
        stax.parallel(stax.Dense(z_dim, W_init=stax.randn()),
                      stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)),
    )
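The two parallel heads read naturally as the mean and the Exp-positive scale of a diagonal Gaussian, as in a VAE encoder; a hedged sketch with an assumed input dimension:

import jax

enc_init, encode = encoder(hidden_dim=64, z_dim=2)
out_shape, params = enc_init(jax.random.PRNGKey(0), (-1, 784))
# mu, sigma = encode(params, x)   # sigma > 0 thanks to stax.Exp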
Example #7
def ConvBlock(kernel_size,
              filters,
              strides=(2, 2),
              batchnorm=True,
              parameterization='standard',
              nonlin=Relu):
    ks = kernel_size
    filters1, filters2, filters3 = filters
    if parameterization == 'standard':

        def MyConv(*args, **kwargs):
            return Conv(*args, **kwargs)
    elif parameterization == 'ntk':

        def MyConv(*args, **kwargs):
            return stax.Conv(*args, **kwargs)[:2]
    else:
        raise ValueError('parameterization must be "standard" or "ntk".')

    if batchnorm:
        Main = jax_stax.serial(MyConv(filters1, (1, 1), strides), BatchNorm(),
                               nonlin,
                               MyConv(filters2, (ks, ks), padding='SAME'),
                               BatchNorm(), nonlin, MyConv(filters3, (1, 1)),
                               BatchNorm())
        Shortcut = jax_stax.serial(MyConv(filters3, (1, 1), strides),
                                   BatchNorm())
    else:
        Main = jax_stax.serial(MyConv(filters1, (1, 1), strides), nonlin,
                               MyConv(filters2, (ks, ks), padding='SAME'),
                               nonlin, MyConv(filters3, (1, 1)))
        Shortcut = jax_stax.serial(MyConv(filters3, (1, 1), strides))
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Shortcut),
                           FanInSum, nonlin)
Example #8
def IdentityBlock(kernel_size,
                  filters,
                  batchnorm=True,
                  parameterization='standard',
                  nonlin=Relu):
    ks = kernel_size
    filters1, filters2 = filters
    if parameterization == 'standard':

        def MyConv(*args, **kwargs):
            return Conv(*args, **kwargs)
    elif parameterization == 'ntk':

        def MyConv(*args, **kwargs):
            return stax.Conv(*args, **kwargs)[:2]
    else:
        raise ValueError('parameterization must be "standard" or "ntk".')

    def make_main(input_shape):
        # the number of output channels depends on the number of input channels
        if batchnorm:
            return jax_stax.serial(MyConv(filters1, (1, 1)), BatchNorm(),
                                   nonlin,
                                   MyConv(filters2, (ks, ks), padding='SAME'),
                                   BatchNorm(), nonlin,
                                   MyConv(input_shape[3], (1, 1)), BatchNorm())
        else:
            return jax_stax.serial(MyConv(filters1, (1, 1)), nonlin,
                                   MyConv(filters2, (ks, ks), padding='SAME'),
                                   nonlin, MyConv(input_shape[3], (1, 1)))

    Main = jax_stax.shape_dependent(make_main)
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Identity),
                           FanInSum, nonlin)
Example #9
def Lambda(fn):  # pylint: disable=invalid-name
    """Turn a normal function into a bound, callable Stax layer.

  Args:
    fn: a python function with _named_ args (i.e. no *args) and no kwargs.

  Returns:
    A callable, 'bound' staxlayer that can be assigned to a python variable and
    called like a function with other staxlayers as arguments.  Like Bind,
    wherever this value is placed in the stax tree, it will always output the
    same cached value.
  """
    # fn's args are just symbolic names that we fill with Vars.
    num_args = len(inspect.getfullargspec(fn).args)
    if num_args > 1:
        bound_args = Vars(num_args)
        return LambdaBind(
            stax.serial(
                stax.parallel(*bound_args),  # capture inputs
                _PlaceholderInputs,  # placeholders for input combinators inside fn
                fn(*bound_args)  # feed captured inputs into fn's args
            ))
    elif num_args == 1:
        bound_arg = Var()
        return LambdaBind(
            stax.serial(
                bound_arg,  # capture input
                _PlaceholderInputs,  # placeholders for input combinators inside fn
                fn(bound_arg)  # feed captured inputs into fn's args
            ))
    # LambdaBind when no args are given:
    else:
        return LambdaBind(fn())
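A hypothetical usage sketch; Var, LambdaBind, and _PlaceholderInputs are project-specific helpers not shown here, and the wiring below leans on the multiplex helper from Example #20 further down:

# Hypothetical: a two-input combinator written as a plain function of its
# inputs; the named args are bound to the layer's inputs by Lambda.
add = Lambda(lambda a, b: stax.serial(multiplex(a, b), stax.FanInSum))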
Example #10
    def __call__(self, *args):
        if len(args) > 1:
            return stax.serial(stax.parallel(*args), self)
        elif len(args) == 1:
            return stax.serial(args[0], self)
        else:
            return self
Example #11
def CifarBasicBlockv2(planes,
                      stride=1,
                      option="A",
                      normalization_method=None,
                      use_fixup=False,
                      num_layers=None,
                      w_init=None,
                      actfn=stax.Relu):
    assert not use_fixup, "fixup initialization is not supported"
    Main = stax.serial(
        maybe_use_normalization(normalization_method),
        actfn,
        Conv(planes, (3, 3),
             strides=(stride, stride),
             padding="SAME",
             W_init=w_init,
             bias=False),
        maybe_use_normalization(normalization_method),
        actfn,
        Conv(planes, (3, 3), padding="SAME", W_init=w_init, bias=False),
    )
    Shortcut = Identity
    if stride > 1:
        if option == "A":
            # The ResNet paper uses option A for CIFAR-10.
            Shortcut = LambdaLayer(_shortcut_pad)
        elif option == "B":
            Shortcut = Conv(planes, (1, 1),
                            strides=(stride, stride),
                            W_init=w_init,
                            bias=False)
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum)
Example #12
def global_conv_block(num_channels, grids, minval, maxval, downsample_factor):
    """Global convolution block.

  First downsamples the input, then applies the global convolution, and
  finally upsamples and concatenates with the input. The input itself becomes
  one channel of the output.

  Args:
    num_channels: Integer, the number of channels.
    grids: Float numpy array with shape (num_grids,).
    minval: Float, the min value in the uniform sampling for exponential width.
    maxval: Float, the max value in the uniform sampling for exponential width.
    downsample_factor: Integer, the factor of downsampling. The grids are
        downsampled with step size 2 ** downsample_factor.

  Returns:
    (init_fn, apply_fn) pair.
  """
    layers = []
    layers.extend([linear_interpolation_transpose()] * downsample_factor)
    layers.append(
        exponential_global_convolution(
            num_channels=num_channels - 1,  # one channel is reserved for the input
            grids=grids,
            minval=minval,
            maxval=maxval,
            downsample_factor=downsample_factor))
    layers.extend([linear_interpolation()] * downsample_factor)
    global_conv_path = stax.serial(*layers)
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(stax.Identity, global_conv_path),
        stax.FanInConcat(axis=-1),
    )
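A hedged initialization sketch; the grid size and channel count are illustrative, and the custom interpolation/convolution layers must be in scope:

import jax
import jax.numpy as jnp

grids = jnp.linspace(-5.0, 5.0, 129)
init_fn, apply_fn = global_conv_block(
    num_channels=17, grids=grids, minval=0.1, maxval=2.0, downsample_factor=2)
out_shape, params = init_fn(jax.random.PRNGKey(0), (-1, 129, 1))
# The output concatenates the input channel with 16 global-conv channels.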
Example #13
def ConvBlock(kernel_size, filters, strides=(2, 2)):
    ks = kernel_size
    filters1, filters2, filters3 = filters
    Main = stax.serial(Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
                       Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(),
                       Relu, Conv(filters3, (1, 1)), BatchNorm())
    Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum,
                       Relu)
Example #14
def ResidualBlock(out_channels, kernel_size, stride, padding, input_format):
    double_conv = stax.serial(
        GeneralConv(input_format, out_channels, kernel_size, stride, padding),
        Elu,
        GeneralConv(input_format, out_channels, kernel_size, stride, padding),
    )
    return Module(
        *stax.serial(FanOut(2), stax.parallel(double_conv, Identity), FanInSum)
    )
Example #15
def PolicyNetwork():
    """Policy network for the experiments in:
    https://arxiv.org/abs/2102.12425"""
    return serial(
        helx.nn.rnn.LSTM(256),
        Dense(256),
        Relu,
        FanOut(2),
        parallel(Dense(1), Dense(1)),
    )
Example #16
def decoder(hidden_dim, x_dim=2, activation='Tanh'):
    activation = getattr(stax, activation)
    decoder_init, decode = stax.serial(
        stax.Dense(hidden_dim),
        activation,
        # stax.Dense(hidden_dim), activation,
        stax.FanOut(2),
        stax.parallel(stax.Dense(x_dim), stax.Dense(x_dim)),
    )
    return decoder_init, decode
Example #17
def encoder(hidden_dim, z_dim, activation='Tanh'):
    activation = getattr(stax, activation)
    encoder_init, encode = stax.serial(
        stax.Dense(hidden_dim),
        activation,
        # stax.Dense(hidden_dim), activation,
        stax.FanOut(2),
        stax.parallel(stax.Dense(z_dim), stax.Dense(z_dim)),
    )
    return encoder_init, encode
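Both factories return the raw (init, apply) pair; a short sketch, where reading the two heads as a location and an unconstrained scale parameter is an assumption:

import jax

encoder_init, encode = encoder(hidden_dim=64, z_dim=2, activation='Tanh')
out_shape, params = encoder_init(jax.random.PRNGKey(0), (-1, 2))
# loc, scale_param = encode(params, x)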
Example #18
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutational block."""
    main = stax.serial(stax.BatchNorm(), stax.Relu,
                       stax.Conv(channels, (3, 3), strides, padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(channels, (3, 3), padding='SAME'))
    shortcut = stax.Identity if not channel_mismatch else stax.Conv(
        channels, (3, 3), strides, padding='SAME')
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum)
Example #19
    def ConvBlock(self, kernel_size, filters, strides=(2, 2)):
        filters1, filters2, filters3 = filters
        Main = stax.serial(
            stax.Conv(filters1, (1, 1), strides), stax.BatchNorm(), stax.Relu,
            stax.Conv(filters2, (kernel_size, kernel_size), padding='SAME'),
            stax.BatchNorm(), stax.Relu, stax.Conv(filters3, (1, 1)),
            stax.BatchNorm())
        Shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                               stax.BatchNorm())
        return stax.serial(stax.FanOut(2), stax.parallel(Main, Shortcut),
                           stax.FanInSum, stax.Relu)
Example #20
def multiplex(*args):
    """Helper to form input argument lists of bound variables.

  Args:
    *args: list of bound layers or raw stax Identity layers.

  Returns:
    A layer that returns, in parallel, the bound variables as well as copies
    of this layer's input wherever Identity has been specified.
  """
    return stax.serial(stax.FanOut(len(args)), stax.parallel(*args))
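A sketch of the intended wiring; bound_var stands for a hypothetical bound layer as produced by the Var/Bind machinery described in Example #9:

# Hypothetical: feeds (input copy, cached value, input copy) to a
# downstream layer expecting three parallel inputs.
inputs = multiplex(stax.Identity, bound_var, stax.Identity)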
Example #21
def jaxRbmSpinPhase(hilbert, alpha):
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(
            stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer,
                        SumLayer),
            stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer,
                        SumLayer),
        ),
        FanInSum2ModPhase,
    )
Example #22
def UNetBlock(filters, kernel_size, inner_block, **kwargs):
    def make_main(input_shape):
        return stax.serial(
            UnbiasedConv(filters, kernel_size, **kwargs),
            inner_block,
            UnbiasedConvTranspose(input_shape[3], kernel_size, **kwargs),
        )

    Main = stax.shape_dependent(make_main)
    return stax.serial(stax.FanOut(2), stax.parallel(Main, stax.Identity),
                       stax.FanInSum)
Example #23
def IdentityBlock(kernel_size, filters):
  ks = kernel_size
  filters1, filters2 = filters
  def make_main(input_shape):
    # the number of output channels depends on the number of input channels
    return stax.serial(
        Conv(filters1, (1, 1)), BatchNorm(), Relu,
        Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
        Conv(input_shape[3], (1, 1)), BatchNorm())
  Main = stax.shape_dependent(make_main)
  return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)
Example #24
def SyntheticReturn(features_network):
    """Synthetic return module as described in:
    https://arxiv.org/abs/2102.12425,
    Raposo, D., Synthetic Returns for Long-Term Credit Assignment, 2021."""
    #  sigmoid gate
    g = lambda: serial(Dense(256), Relu, Dense(1), Relu, Dense(1), Sigmoid)
    #  state utility contribution
    c = lambda: serial(Dense(256), Relu, Dense(256), Relu, Dense(1))
    #  state utility baseline
    b = lambda: serial(Dense(256), Relu, Dense(256), Relu, Dense(1))
    return serial(features_network, Flatten, FanOut(3),
                  parallel(g(), c(), b()))
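A usage sketch; per the inline comments the three parallel heads emit (gate, contribution, baseline), and the features_network below is an illustrative stand-in:

import jax

features = serial(Dense(256), Relu)
sr_init, sr_apply = SyntheticReturn(features)
out_shape, params = sr_init(jax.random.PRNGKey(0), (-1, 64))
# g, c, b = sr_apply(params, observation)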
Example #25
def ConvBlock(kernel_size, filters, strides):
    """ResNet convolutional striding block."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = stax.serial(stax.Conv(filters1, (1, 1),
                                 strides), stax.BatchNorm(), stax.Relu,
                       stax.Conv(filters2, (ks, ks), padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(filters3, (1, 1)), stax.BatchNorm())
    shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                           stax.BatchNorm())
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum, stax.Relu)
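An initialization sketch on NHWC inputs with illustrative shapes; stax.FanInSum works here because both branches end in filters3 channels at the same stride:

import jax

block_init, block_apply = ConvBlock(3, (64, 64, 256), strides=(2, 2))
out_shape, params = block_init(jax.random.PRNGKey(0), (-1, 32, 32, 3))
# out_shape == (-1, 16, 16, 256)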
Example #26
def encoder(hidden_dim: int, z_dim: int) -> Tuple[Callable, Callable]:
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(
                stax.Dense(z_dim, W_init=stax.randn()),
                stax.Exp,
            ),
        ),
    )
Example #27
def NdmSpinPhase(hilbert,
                 alpha,
                 beta,
                 use_hidden_bias=True,
                 use_visible_bias=True):
    r"""
    A fully connected Neural Density Matrix (NDM). This type of density matrix
    is obtained by purifying an RBM with spin-1/2 hidden units.

    The number of purification hidden units can be chosen arbitrarily.

    The weights are taken to be complex-valued. A complete definition of this
    machine can be found in Eq. 2 of Hartmann, M. J. & Carleo, G.,
    Phys. Rev. Lett. 122, 250502 (2019).

    Args:
        hilbert: Hilbert space of the system.
        alpha: `alpha * hilbert.size` is the number of hidden spins used for
                the pure state part of the density-matrix.
        beta: `beta * hilbert.size` is the number of hidden spins used for the purification.
            beta=0 for example corresponds to a pure state.
        use_hidden_bias: If ``True``, a bias on the hidden units is used.
                         Default ``True``.
        use_visible_bias: If ``True``, a bias on the visible units is used.
                          Default ``True``.
    """
    mod_pure = stax.serial(
        DensePureRowCol(alpha * hilbert.size, use_hidden_bias),
        stax.parallel(LogCoshLayer, LogCoshLayer),
        stax.parallel(SumLayer, SumLayer),
        FanInSum2,
    )

    phs_pure = stax.serial(
        DensePureRowCol(alpha * hilbert.size, use_hidden_bias),
        stax.parallel(LogCoshLayer, LogCoshLayer),
        stax.parallel(SumLayer, SumLayer),
        FanInSub2,
    )

    mixing = stax.serial(
        DenseMixingReal(beta * hilbert.size, use_hidden_bias),
        LogCoshLayer,
        SumLayer,
    )

    if use_visible_bias:
        biases = BiasRealModPhase()
        net = stax.serial(
            stax.FanOut(4),
            stax.parallel(mod_pure, phs_pure, mixing, biases),
            stax.FanInSum,
        )
    else:
        net = stax.serial(
            stax.FanOut(3),
            stax.parallel(mod_pure, phs_pure, mixing),
            stax.FanInSum,
        )

    return Jax(hilbert, net, dtype=float, outdtype=complex)
Example #28
def wide_resnet_block(num_channels, strides=(1, 1), channel_mismatch=False):
    """Wide ResNet block."""
    # Note: reusing `pre` below does not share parameters; each use of a
    # stax (init, apply) pair gets fresh parameters at init time.
    pre = stax.serial(stax.BatchNorm(), stax.Relu)
    mid = stax.serial(
        pre, stax.Conv(num_channels, (3, 3), strides, padding='SAME'),
        stax.BatchNorm(), stax.Relu,
        stax.Conv(num_channels, (3, 3), strides=(1, 1), padding='SAME'))
    if channel_mismatch:
        cut = stax.serial(
            pre, stax.Conv(num_channels, (3, 3), strides, padding='SAME'))
    else:
        cut = stax.Identity
    return stax.serial(stax.FanOut(2), stax.parallel(mid, cut), stax.FanInSum)
Example #29
    def IdentityBlock(self, kernel_size, filters):
        filters1, filters2 = filters

        def make_main(input_shape):
            # the number of output channels depends on the number of input channels
            return stax.serial(
                stax.Conv(filters1, (1, 1)), stax.BatchNorm(), stax.Relu,
                stax.Conv(filters2, (kernel_size, kernel_size),
                          padding='SAME'), stax.BatchNorm(), stax.Relu,
                stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())

        Main = stax.shape_dependent(make_main)
        return stax.serial(stax.FanOut(2), stax.parallel(Main, stax.Identity),
                           stax.FanInSum, stax.Relu)
Example #30
def ResNet(
    hidden_channels, out_channels, kernel_size, strides, padding, depth, input_format
):
    residual = stax.serial(
        GeneralConv(input_format, hidden_channels, kernel_size, strides, padding),
        *[
            ResidualBlock(hidden_channels, kernel_size, strides, padding, input_format)
            for _ in range(depth)
        ],
        GeneralConv(input_format, out_channels, kernel_size, strides, padding)
    )
    return Module(
        *stax.serial(FanOut(2), stax.parallel(residual, Identity), AddLastItem(1))
    )