def Lpg(hparams):
    phi = serial(Dense(16), Dense(1))
    return serial(
        # FanOut(6),
        parallel(Identity, Identity, Identity, Identity, phi, phi),
        FanInConcat(),
        LSTMCell(hparams.hidden_size)[0:2],
        DiscardHidden(),
        Relu,
        FanOut(2),
        parallel(phi, phi),
    )
def residual(*layers, **kwargs):
    """Constructs a residual version of layers, summing input to layers output."""
    res = kwargs.get('res', stax.Identity)
    if len(layers) > 1:
        return stax.serial(
            stax.FanOut(2),
            stax.parallel(stax.serial(*layers), res),
            stax.FanInSum)
    elif len(layers) == 1:
        return stax.serial(
            stax.FanOut(2),
            stax.parallel(layers[0], res),
            stax.FanInSum)
    else:
        raise ValueError('Empty residual combinator.')
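# Usage sketch for `residual` above, assuming `stax` is jax.example_libraries.stax
# (the layer names match); the layer sizes are illustrative only.
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

init_fn, apply_fn = residual(stax.Dense(64), stax.Relu, stax.Dense(64))
out_shape, params = init_fn(jax.random.PRNGKey(0), (-1, 64))
y = apply_fn(params, jnp.ones((8, 64)))  # y = x + f(x), shape (8, 64)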
def MultiHeadedAttention(  # pylint: disable=invalid-name
        feature_depth, num_heads=8, dropout=1.0, mode='train'):
    """Transformer-style multi-headed attention.

    Args:
      feature_depth: int: depth of embedding
      num_heads: int: number of attention heads
      dropout: float: dropout rate - keep probability
      mode: str: 'train' or 'eval'

    Returns:
      Multi-headed self-attention layer.
    """
    return stax.serial(
        stax.parallel(
            stax.Dense(feature_depth, W_init=xavier_uniform()),
            stax.Dense(feature_depth, W_init=xavier_uniform()),
            stax.Dense(feature_depth, W_init=xavier_uniform()),
            stax.Identity),
        PureMultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        stax.Dense(feature_depth, W_init=xavier_uniform()),
    )
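# Usage sketch for `MultiHeadedAttention` above. `PureMultiHeadedAttention` and
# `xavier_uniform` are assumed to be defined alongside it in the same codebase.
# Because the first stage is stax.parallel over four streams, the resulting
# layer expects a 4-tuple input: three streams that get projected by Dense
# layers, plus a mask passed through Identity.
attn_init, attn_apply = MultiHeadedAttention(
    feature_depth=512, num_heads=8, dropout=1.0, mode='train')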
def network(rng: random.PRNGKey, num_in: int, num_hidden: int) -> Tuple:
    """Factory for producing the dequantization neural network and its
    parameterizations.

    Args:
        rng: Pseudo-random number generator seed.
        num_in: Number of inputs to the network.
        num_hidden: Number of hidden units in the hidden layer.

    Returns:
        out: A tuple containing the network parameters and a callable function
            that returns the neural network output for the given input.
    """
    num_in_sq = num_in**2
    num_in_choose_two = num_in * (num_in - 1) // 2
    params_init, fn = stax.serial(
        stax.Flatten,
        stax.Dense(num_hidden), stax.Relu,
        stax.Dense(num_hidden), stax.Relu,
        stax.FanOut(4),
        stax.parallel(
            stax.Dense(num_in),
            stax.Dense(num_in_choose_two),
            stax.serial(stax.Dense(num_in), stax.Softplus),
            stax.serial(stax.Dense(num_in_choose_two), stax.Softplus),
        ))
    _, params = params_init(rng, (-1, num_in_sq))
    return params, fn
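# Usage sketch for the dequantization `network` factory above, assuming `stax`
# is jax.example_libraries.stax and `random` is jax.random. With num_in=3 the
# network consumes (flattened) 3x3 inputs and returns four parallel heads; the
# two Softplus heads are constrained to be positive.
import jax
import jax.numpy as jnp

params, fn = network(jax.random.PRNGKey(0), num_in=3, num_hidden=64)
outs = fn(params, jnp.ones((10, 3, 3)))  # list of four arrays, each (10, 3)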
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False,
                    nonlin=Relu, parameterization='standard', order=None):
    Main = jax_stax.serial(
        nonlin,
        MyConv(channels, (3, 3), strides, padding='SAME',
               parameterization=parameterization, order=order),
        nonlin,
        MyConv(channels, (3, 3), padding='SAME',
               parameterization=parameterization, order=order))
    Shortcut = Identity if not channel_mismatch else MyConv(
        channels, (3, 3), strides, padding='SAME',
        parameterization=parameterization, order=order)
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Shortcut), FanInSum)
def encoder(hidden_dim, z_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)),
    )
def ConvBlock(kernel_size, filters, strides=(2, 2), batchnorm=True,
              parameterization='standard', nonlin=Relu):
    ks = kernel_size
    filters1, filters2, filters3 = filters
    if parameterization == 'standard':
        def MyConv(*args, **kwargs):
            return Conv(*args, **kwargs)
    elif parameterization == 'ntk':
        def MyConv(*args, **kwargs):
            return stax.Conv(*args, **kwargs)[:2]
    if batchnorm:
        Main = jax_stax.serial(
            MyConv(filters1, (1, 1), strides), BatchNorm(), nonlin,
            MyConv(filters2, (ks, ks), padding='SAME'), BatchNorm(), nonlin,
            MyConv(filters3, (1, 1)), BatchNorm())
        Shortcut = jax_stax.serial(MyConv(filters3, (1, 1), strides), BatchNorm())
    else:
        Main = jax_stax.serial(
            MyConv(filters1, (1, 1), strides), nonlin,
            MyConv(filters2, (ks, ks), padding='SAME'), nonlin,
            MyConv(filters3, (1, 1)))
        Shortcut = jax_stax.serial(MyConv(filters3, (1, 1), strides))
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Shortcut),
                           FanInSum, nonlin)
def IdentityBlock(kernel_size, filters, batchnorm=True,
                  parameterization='standard', nonlin=Relu):
    ks = kernel_size
    filters1, filters2 = filters
    if parameterization == 'standard':
        def MyConv(*args, **kwargs):
            return Conv(*args, **kwargs)
    elif parameterization == 'ntk':
        def MyConv(*args, **kwargs):
            return stax.Conv(*args, **kwargs)[:2]

    def make_main(input_shape):
        # The number of output channels depends on the number of input channels.
        if batchnorm:
            return jax_stax.serial(
                MyConv(filters1, (1, 1)), BatchNorm(), nonlin,
                MyConv(filters2, (ks, ks), padding='SAME'), BatchNorm(), nonlin,
                MyConv(input_shape[3], (1, 1)), BatchNorm())
        else:
            return jax_stax.serial(
                MyConv(filters1, (1, 1)), nonlin,
                MyConv(filters2, (ks, ks), padding='SAME'), nonlin,
                MyConv(input_shape[3], (1, 1)))

    Main = jax_stax.shape_dependent(make_main)
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Identity),
                           FanInSum, nonlin)
def Lambda(fn):  # pylint: disable=invalid-name
    """Turn a normal function into a bound, callable Stax layer.

    Args:
      fn: a python function with _named_ args (i.e. no *args) and no kwargs.

    Returns:
      A callable, 'bound' staxlayer that can be assigned to a python variable
      and called like a function with other staxlayers as arguments. Like Bind,
      wherever this value is placed in the stax tree, it will always output the
      same cached value.
    """
    # fn's args are just symbolic names that we fill with Vars.
    num_args = len(inspect.getargspec(fn).args)
    if num_args > 1:
        bound_args = Vars(num_args)
        return LambdaBind(
            stax.serial(
                stax.parallel(*bound_args),  # capture inputs
                _PlaceholderInputs,  # placeholders for input combinators inside fn
                fn(*bound_args)  # feed captured inputs into fn's args
            ))
    elif num_args == 1:
        bound_arg = Var()
        return LambdaBind(
            stax.serial(
                bound_arg,  # capture input
                _PlaceholderInputs,  # placeholders for input combinators inside fn
                fn(bound_arg)  # feed captured inputs into fn's args
            ))
    # LambdaBind when no args are given:
    else:
        return LambdaBind(fn())
def __call__(self, *args):
    if len(args) > 1:
        return stax.serial(stax.parallel(*args), self)
    elif len(args) == 1:
        return stax.serial(args[0], self)
    else:
        return self
def CifarBasicBlockv2(planes, stride=1, option="A", normalization_method=None,
                      use_fixup=False, num_layers=None, w_init=None,
                      actfn=stax.Relu):
    assert not use_fixup, "use_fixup is not supported"
    Main = stax.serial(
        maybe_use_normalization(normalization_method),
        actfn,
        Conv(planes, (3, 3), strides=(stride, stride), padding="SAME",
             W_init=w_init, bias=False),
        maybe_use_normalization(normalization_method),
        actfn,
        Conv(planes, (3, 3), padding="SAME", W_init=w_init, bias=False),
    )
    Shortcut = Identity
    if stride > 1:
        if option == "A":
            # For CIFAR10, the ResNet paper uses option A.
            Shortcut = LambdaLayer(_shortcut_pad)
        elif option == "B":
            Shortcut = Conv(planes, (1, 1), strides=(stride, stride),
                            W_init=w_init, bias=False)
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum)
def global_conv_block(num_channels, grids, minval, maxval, downsample_factor):
    """Global convolution block.

    First downsample the input, then apply global conv, finally upsample and
    concatenate with the input. The input itself is one channel in the output.

    Args:
      num_channels: Integer, the number of channels.
      grids: Float numpy array with shape (num_grids,).
      minval: Float, the min value in the uniform sampling for exponential width.
      maxval: Float, the max value in the uniform sampling for exponential width.
      downsample_factor: Integer, the factor of downsampling. The grids are
          downsampled with step size 2 ** downsample_factor.

    Returns:
      (init_fn, apply_fn) pair.
    """
    layers = []
    layers.extend([linear_interpolation_transpose()] * downsample_factor)
    layers.append(
        exponential_global_convolution(
            num_channels=num_channels - 1,  # one channel is reserved for input.
            grids=grids,
            minval=minval,
            maxval=maxval,
            downsample_factor=downsample_factor))
    layers.extend([linear_interpolation()] * downsample_factor)
    global_conv_path = stax.serial(*layers)
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(stax.Identity, global_conv_path),
        stax.FanInConcat(axis=-1),
    )
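# Usage sketch for `global_conv_block` above. The interpolation and
# exponential-convolution layers are project-specific helpers assumed to be
# importable from the same module; the grid and width values below are
# illustrative only.
import numpy as np

init_fn, apply_fn = global_conv_block(
    num_channels=4,
    grids=np.linspace(-5.0, 5.0, 513),
    minval=0.1,
    maxval=2.0,
    downsample_factor=2)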
def ConvBlock(kernel_size, filters, strides=(2, 2)):
    ks = kernel_size
    filters1, filters2, filters3 = filters
    Main = stax.serial(
        Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
        Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
        Conv(filters3, (1, 1)), BatchNorm())
    Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu)
def ResidualBlock(out_channels, kernel_size, stride, padding, input_format):
    double_conv = stax.serial(
        GeneralConv(input_format, out_channels, kernel_size, stride, padding),
        Elu,
        GeneralConv(input_format, out_channels, kernel_size, stride, padding),
    )
    return Module(
        *stax.serial(FanOut(2), stax.parallel(double_conv, Identity), FanInSum)
    )
def PolicyNetwork():
    """Policy network for the experiments in:
    https://arxiv.org/abs/2102.12425"""
    return serial(
        helx.nn.rnn.LSTM(256),
        Dense(256),
        Relu,
        FanOut(2),
        parallel(Dense(1), Dense(1)),
    )
def decoder(hidden_dim, x_dim=2, activation='Tanh'):
    activation = getattr(stax, activation)
    decoder_init, decode = stax.serial(
        stax.Dense(hidden_dim), activation,
        # stax.Dense(hidden_dim), activation,
        stax.FanOut(2),
        stax.parallel(stax.Dense(x_dim), stax.Dense(x_dim)),
    )
    return decoder_init, decode
def encoder(hidden_dim, z_dim, activation='Tanh'):
    activation = getattr(stax, activation)
    encoder_init, encode = stax.serial(
        stax.Dense(hidden_dim), activation,
        # stax.Dense(hidden_dim), activation,
        stax.FanOut(2),
        stax.parallel(stax.Dense(z_dim), stax.Dense(z_dim)),
    )
    return encoder_init, encode
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutional block."""
    main = stax.serial(
        stax.BatchNorm(), stax.Relu,
        stax.Conv(channels, (3, 3), strides, padding='SAME'),
        stax.BatchNorm(), stax.Relu,
        stax.Conv(channels, (3, 3), padding='SAME'))
    shortcut = stax.Identity if not channel_mismatch else stax.Conv(
        channels, (3, 3), strides, padding='SAME')
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum)
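# Usage sketch for `WideResnetBlock` above, assuming `stax` is
# jax.example_libraries.stax and NHWC inputs. With channel_mismatch=False the
# shortcut is Identity, so the input must already have `channels` channels.
import jax
import jax.numpy as jnp

init_fn, apply_fn = WideResnetBlock(64, strides=(1, 1))
out_shape, params = init_fn(jax.random.PRNGKey(0), (-1, 32, 32, 64))
y = apply_fn(params, jnp.ones((2, 32, 32, 64)))  # shape (2, 32, 32, 64)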
def ConvBlock(self, kernel_size, filters, strides=(2, 2)):
    filters1, filters2, filters3 = filters
    Main = stax.serial(
        stax.Conv(filters1, (1, 1), strides), stax.BatchNorm(), stax.Relu,
        stax.Conv(filters2, (kernel_size, kernel_size), padding='SAME'),
        stax.BatchNorm(), stax.Relu,
        stax.Conv(filters3, (1, 1)), stax.BatchNorm())
    Shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                           stax.BatchNorm())
    return stax.serial(stax.FanOut(2), stax.parallel(Main, Shortcut),
                       stax.FanInSum, stax.Relu)
def multiplex(*args):
    """Helper to form input argument lists of bound variables.

    Args:
      *args: list of bound layers or raw stax Identity layers.

    Returns:
      A layer returning in parallel the bound variables as well as (multiple)
      copies of this layer's input wherever Identity has been specified.
    """
    return stax.serial(stax.FanOut(len(args)), stax.parallel(*args))
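# Usage sketch for `multiplex` above, assuming `stax` is
# jax.example_libraries.stax. The layer fans its input out len(args) times and
# routes each copy through the corresponding branch, so the output is a list
# with one entry per argument.
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

init_fn, apply_fn = multiplex(stax.Dense(8), stax.Identity)
out_shapes, params = init_fn(jax.random.PRNGKey(0), (-1, 4))
dense_out, passthrough = apply_fn(params, jnp.ones((2, 4)))  # (2, 8) and (2, 4)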
def jaxRbmSpinPhase(hilbert, alpha):
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(
            stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer, SumLayer),
            stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer, SumLayer),
        ),
        FanInSum2ModPhase,
    )
def UNetBlock(filters, kernel_size, inner_block, **kwargs):
    def make_main(input_shape):
        return stax.serial(
            UnbiasedConv(filters, kernel_size, **kwargs),
            inner_block,
            UnbiasedConvTranspose(input_shape[3], kernel_size, **kwargs),
        )

    Main = stax.shape_dependent(make_main)
    return stax.serial(stax.FanOut(2), stax.parallel(Main, stax.Identity),
                       stax.FanInSum)
def IdentityBlock(kernel_size, filters):
    ks = kernel_size
    filters1, filters2 = filters

    def make_main(input_shape):
        # The number of output channels depends on the number of input channels.
        return stax.serial(
            Conv(filters1, (1, 1)), BatchNorm(), Relu,
            Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
            Conv(input_shape[3], (1, 1)), BatchNorm())

    Main = stax.shape_dependent(make_main)
    return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)
def SyntheticReturn(features_network):
    """Synthetic return module as described in:
    https://arxiv.org/abs/2102.12425, Raposo, D.,
    Synthetic Returns for Long-Term Credit Assignment, 2021."""
    # sigmoid gate
    g = lambda: serial(Dense(256), Relu, Dense(1), Relu, Dense(1), Sigmoid)
    # state utility contribution
    c = lambda: serial(Dense(256), Relu, Dense(256), Relu, Dense(1))
    # state utility baseline
    b = lambda: serial(Dense(256), Relu, Dense(256), Relu, Dense(1))
    return serial(features_network, Flatten, FanOut(3), parallel(g(), c(), b()))
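# Usage sketch for `SyntheticReturn` above. The unqualified names (serial,
# Dense, Relu, ...) are assumed to be the same stax-style combinators imported
# at module level as in the function itself; the feature trunk below is a
# stand-in.
features = serial(Dense(256), Relu)
init_fn, apply_fn = SyntheticReturn(features)
# apply_fn produces three parallel outputs: the sigmoid gate g, the state
# utility contribution c, and the state utility baseline b.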
def ConvBlock(kernel_size, filters, strides):
    """ResNet convolutional striding block."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = stax.serial(
        stax.Conv(filters1, (1, 1), strides), stax.BatchNorm(), stax.Relu,
        stax.Conv(filters2, (ks, ks), padding='SAME'),
        stax.BatchNorm(), stax.Relu,
        stax.Conv(filters3, (1, 1)), stax.BatchNorm())
    shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                           stax.BatchNorm())
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum, stax.Relu)
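# Usage sketch for the ResNet `ConvBlock` above, assuming `stax` is
# jax.example_libraries.stax and NHWC inputs. Both the main path and the
# projection shortcut end with `filters3` channels, so FanInSum is well defined.
import jax
import jax.numpy as jnp

init_fn, apply_fn = ConvBlock(3, (64, 64, 256), strides=(2, 2))
out_shape, params = init_fn(jax.random.PRNGKey(0), (-1, 56, 56, 64))
y = apply_fn(params, jnp.ones((2, 56, 56, 64)))  # shape (2, 28, 28, 256)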
def encoder(hidden_dim: int, z_dim: int) -> Tuple[Callable, Callable]:
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(
                stax.Dense(z_dim, W_init=stax.randn()),
                stax.Exp,
            ),
        ),
    )
def NdmSpinPhase(hilbert, alpha, beta, use_hidden_bias=True,
                 use_visible_bias=True):
    r"""
    A fully connected Neural Density Matrix (NDM). This type of density matrix
    is obtained by purifying an RBM with spin-1/2 hidden units. The number of
    purification hidden units can be chosen arbitrarily. The weights are taken
    to be complex-valued.

    A complete definition of this machine can be found in Eq. 2 of
    Hartmann, M. J. & Carleo, G., Phys. Rev. Lett. 122, 250502 (2019).

    Args:
        hilbert: Hilbert space of the system.
        alpha: `alpha * hilbert.size` is the number of hidden spins used for
            the pure state part of the density-matrix.
        beta: `beta * hilbert.size` is the number of hidden spins used for the
            purification. beta=0 for example corresponds to a pure state.
        use_hidden_bias: If ``True`` bias on the hidden units is taken.
            Default ``True``.
        use_visible_bias: If ``True`` bias on the visible units is taken.
            Default ``True``.
    """
    mod_pure = stax.serial(
        DensePureRowCol(alpha * hilbert.size, use_hidden_bias),
        stax.parallel(LogCoshLayer, LogCoshLayer),
        stax.parallel(SumLayer, SumLayer),
        FanInSum2,
    )
    phs_pure = stax.serial(
        DensePureRowCol(alpha * hilbert.size, use_hidden_bias),
        stax.parallel(LogCoshLayer, LogCoshLayer),
        stax.parallel(SumLayer, SumLayer),
        FanInSub2,
    )
    mixing = stax.serial(
        DenseMixingReal(beta * hilbert.size, use_hidden_bias),
        LogCoshLayer,
        SumLayer,
    )
    if use_visible_bias:
        biases = BiasRealModPhase()
        net = stax.serial(
            stax.FanOut(4),
            stax.parallel(mod_pure, phs_pure, mixing, biases),
            stax.FanInSum,
        )
    else:
        net = stax.serial(
            stax.FanOut(3),
            stax.parallel(mod_pure, phs_pure, mixing),
            stax.FanInSum,
        )
    return Jax(hilbert, net, dtype=float, outdtype=complex)
def wide_resnet_block(num_channels, strides=(1, 1), channel_mismatch=False):
    """Wide ResNet block."""
    pre = stax.serial(stax.BatchNorm(), stax.Relu)
    mid = stax.serial(
        pre,
        stax.Conv(num_channels, (3, 3), strides, padding='SAME'),
        stax.BatchNorm(), stax.Relu,
        stax.Conv(num_channels, (3, 3), strides=(1, 1), padding='SAME'))
    if channel_mismatch:
        cut = stax.serial(
            pre,
            stax.Conv(num_channels, (3, 3), strides, padding='SAME'))
    else:
        cut = stax.Identity
    return stax.serial(stax.FanOut(2), stax.parallel(mid, cut), stax.FanInSum)
def IdentityBlock(self, kernel_size, filters):
    filters1, filters2 = filters

    def make_main(input_shape):
        # The number of output channels depends on the number of input channels.
        return stax.serial(
            stax.Conv(filters1, (1, 1)), stax.BatchNorm(), stax.Relu,
            stax.Conv(filters2, (kernel_size, kernel_size), padding='SAME'),
            stax.BatchNorm(), stax.Relu,
            stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())

    Main = stax.shape_dependent(make_main)
    return stax.serial(stax.FanOut(2), stax.parallel(Main, stax.Identity),
                       stax.FanInSum, stax.Relu)
def ResNet(hidden_channels, out_channels, kernel_size, strides, padding, depth,
           input_format):
    residual = stax.serial(
        GeneralConv(input_format, hidden_channels, kernel_size, strides, padding),
        *[
            ResidualBlock(hidden_channels, kernel_size, strides, padding,
                          input_format)
            for _ in range(depth)
        ],
        GeneralConv(input_format, out_channels, kernel_size, strides, padding)
    )
    return Module(
        *stax.serial(FanOut(2), stax.parallel(residual, Identity), AddLastItem(1))
    )