def _build_unet_shell(layer, num_filters, activation):
    """Builds one shell of a U-net around an inner layer.

    The downsampled activations are sent down two paths: a skip path that
    passes them through unchanged, and a path through `layer`. The two results
    are concatenated along the last axis and then upsampled.

      *--------------*                      *--------*     *------------------*
      |              |----------------------|        |     |                  |
      | downsampling |    *------------*    | concat |-----| upsampling block |
      |    block     |----|   layer    |----|        |     |                  |
      *--------------*    *------------*    *--------*     *------------------*

    Args:
      layer: (init_fn, apply_fn) pair placed at the bottom of the U-shape.
      num_filters: Integer, the number of filters used by both the
        downsampling and upsampling blocks.
      activation: String, the activation function to use in the network.

    Returns:
      (init_fn, apply_fn) pair.
    """
    down = downsampling_block(num_filters, activation=activation)
    split = stax.FanOut(2)
    # Identity is the skip connection; `layer` is the deeper part of the U.
    branches = stax.parallel(stax.Identity, layer)
    merge = stax.FanInConcat(axis=-1)
    up = upsampling_block(num_filters, activation=activation)
    return stax.serial(down, split, branches, merge, up)
def parallel(*layers: Layer) -> InternalLayer:
    """Combinator for composing layers in parallel.

    The resulting layer is typically placed between a `FanOut` layer and a
    `FanInSum`/`FanInConcat` layer. Based on
    `jax.example_libraries.stax.parallel`.

    Args:
      *layers: a sequence of layers, each with a `(init_fn, apply_fn,
        kernel_fn)` triple.

    Returns:
      A new layer, i.e. an `(init_fn, apply_fn, kernel_fn)` triple, representing
      the parallel composition of `layers`. The returned layer takes a sequence
      of inputs and returns a sequence of outputs whose length equals the
      number of `layers`; the i-th layer is applied to the i-th input.
    """
    init_fns, apply_fns, kernel_fns = zip(*layers)
    stax_init, stax_apply = ostax.parallel(*zip(init_fns, apply_fns))

    def init_fn(rng: random.KeyArray, input_shape: Shapes):
        # Rebuild the result with the same container type as `input_shape`.
        container = type(input_shape)
        return container(stax_init(rng, input_shape))

    def apply_fn(params, inputs, **kwargs):
        # Rebuild the outputs with the same container type as `inputs`.
        container = type(inputs)
        return container(stax_apply(params, inputs, **kwargs))

    @requires(**_get_input_req_attr(kernel_fns, fold=op.and_))
    def kernel_fn(ks: NTTrees[Kernel], **kwargs) -> NTTrees[Kernel]:
        # Pair each incoming kernel with the kernel_fn of its branch.
        per_branch = (fn(k, **kwargs) for fn, k in zip(kernel_fns, ks))
        return type(ks)(per_branch)

    return init_fn, apply_fn, kernel_fn
def global_conv_block(num_channels, grids, minval, maxval, downsample_factor):
    """Global convolution block.

    The input is fanned out into two branches:
      * an identity branch;
      * a branch that downsamples, applies an exponential global convolution,
        and upsamples back.
    The branch outputs are concatenated along the last axis, with the
    unchanged input first.

    Args:
      num_channels: Integer, the total number of output channels. One channel
        is reserved for the input, so the global convolution produces
        `num_channels - 1` channels. This presumes a single-channel input;
        confirm against callers.
      grids: Float numpy array with shape (num_grids,).
      minval: Float, the min value in the uniform sampling for exponential
        width.
      maxval: Float, the max value in the uniform sampling for exponential
        width.
      downsample_factor: Integer, how many downsampling (and matching
        upsampling) layers are stacked. The grids are downsampled with step
        size 2 ** downsample_factor.

    Returns:
      (init_fn, apply_fn) pair.
    """
    # One layer instance, repeated `downsample_factor` times.
    downsample_stage = [linear_interpolation_transpose()] * downsample_factor
    global_conv = exponential_global_convolution(
        num_channels=num_channels - 1,  # one channel is reserved for input.
        grids=grids,
        minval=minval,
        maxval=maxval,
        downsample_factor=downsample_factor,
    )
    upsample_stage = [linear_interpolation()] * downsample_factor
    global_conv_path = stax.serial(
        *downsample_stage, global_conv, *upsample_stage)

    return stax.serial(
        stax.FanOut(2),
        stax.parallel(stax.Identity, global_conv_path),
        stax.FanInConcat(axis=-1),
    )
def ConvBlock(kernel_size, filters, strides=(2, 2)):
    """Residual block whose shortcut is a strided 1x1 convolution.

    Main path: 1x1 conv (strided) -> k x k conv ('SAME') -> 1x1 conv, each
    followed by BatchNorm, with Relu between them. The shortcut projects the
    input with a strided 1x1 conv to `filters[2]` channels. The two paths are
    summed and passed through a final Relu.

    Args:
      kernel_size: Integer size of the middle (square) convolution.
      filters: Three integers, the output channels of the three main-path
        convolutions.
      strides: Strides for the first main-path conv and the shortcut conv.

    Returns:
      (init_fun, apply_fun) pair.
    """
    filters1, filters2, filters3 = filters
    square = (kernel_size, kernel_size)
    main_path = stax.serial(
        Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
        Conv(filters2, square, padding='SAME'), BatchNorm(), Relu,
        Conv(filters3, (1, 1)), BatchNorm(),
    )
    # The strided projection makes the shortcut's shape match the main path.
    shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
    return stax.serial(
        FanOut(2), stax.parallel(main_path, shortcut), FanInSum, Relu)
def encoder(hidden_dim, z_dim):
    """Encoder network that emits two `z_dim`-sized outputs.

    A Dense(hidden_dim) + Softplus trunk is fanned out into two heads:
      * a plain Dense(z_dim) head;
      * a Dense(z_dim) head followed by Exp, so its outputs are positive.

    Returns:
      (init_fun, apply_fun) pair.
    """
    def dense(width):
        return stax.Dense(width, W_init=stax.randn())

    trunk = dense(hidden_dim)
    first_head = dense(z_dim)
    positive_head = stax.serial(dense(z_dim), stax.Exp)
    return stax.serial(
        trunk,
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(first_head, positive_head),
    )
def IdentityBlock(kernel_size, filters):
    """Residual block whose shortcut is the identity.

    Main path: 1x1 conv -> k x k conv ('SAME') -> 1x1 conv, each followed by
    BatchNorm, with Relu between them. The last conv's channel count is taken
    from `input_shape[3]` at init time so the main path can be summed with the
    unchanged input. A final Relu is applied after the sum.

    Args:
      kernel_size: Integer size of the middle (square) convolution.
      filters: Two integers, the output channels of the first two convolutions.

    Returns:
      (init_fun, apply_fun) pair.
    """
    filters1, filters2 = filters

    def build_main(input_shape):
        # The output channel count must match the input so the residual sum works.
        in_channels = input_shape[3]
        return stax.serial(
            Conv(filters1, (1, 1)), BatchNorm(), Relu,
            Conv(filters2, (kernel_size, kernel_size), padding='SAME'),
            BatchNorm(), Relu,
            Conv(in_channels, (1, 1)), BatchNorm(),
        )

    main_path = stax.shape_dependent(build_main)
    return stax.serial(
        FanOut(2), stax.parallel(main_path, Identity), FanInSum, Relu)
def UNetBlock(filters, kernel_size, inner_block, **kwargs):
    """One residual U-Net level.

    Main path: UnbiasedConv(filters) -> `inner_block` ->
    UnbiasedConvTranspose back to `input_shape[3]` channels. The main path is
    summed with the unchanged input.

    NOTE(review): this assumes axis 3 of the input shape is the channel axis
    (e.g. an NHWC layout). TODO: confirm.

    Args:
      filters: Number of channels produced by the downward convolution.
      kernel_size: Kernel size shared by both convolutions.
      inner_block: (init_fun, apply_fun) pair applied between the convolutions.
      **kwargs: Forwarded to both UnbiasedConv and UnbiasedConvTranspose.

    Returns:
      (init_fun, apply_fun) pair.
    """
    def build_main(input_shape):
        down = UnbiasedConv(filters, kernel_size, **kwargs)
        in_channels = input_shape[3]
        up = UnbiasedConvTranspose(in_channels, kernel_size, **kwargs)
        return stax.serial(down, inner_block, up)

    main_path = stax.shape_dependent(build_main)
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(main_path, stax.Identity),
        stax.FanInSum,
    )
def wrap_network_with_self_interaction_layer(network, grids, interaction_fn):
    """Wraps a network with a self-interaction layer.

    The input is fanned out to two branches: an identity branch and `network`.
    The resulting pair (unchanged input, network output) is then fed to
    `self_interaction_layer(grids, interaction_fn)`.

    Args:
      network: an (init_fn, apply_fn) pair.
        * init_fn: The init_fn of the neural network. It takes an rng key and
          an input shape and returns an (output_shape, params) pair.
        * apply_fn: The apply_fn of the neural network. It takes params,
          inputs, and an rng key and applies the layer.
      grids: Float numpy array with shape (num_grids,).
      interaction_fn: function that takes displacements and returns a float
        numpy array with the same shape as the displacements.

    Returns:
      (init_fn, apply_fn) pair.
    """
    fan_out = stax.FanOut(2)
    branches = stax.parallel(stax.Identity, network)
    interaction = self_interaction_layer(grids, interaction_fn)
    return stax.serial(fan_out, branches, interaction)
def BlockNeuralAutoregressiveNN(input_dim, hidden_factors=(8, 8), residual=None):
    """
    An implementation of Block Neural Autoregressive neural network.

    **References**

    1. *Block Neural Autoregressive Flow*,
       Nicola De Cao, Ivan Titov, Wilker Aziz

    :param int input_dim: The dimensionality of the input.
    :param hidden_factors: Hidden layer i has ``hidden_factors[i]`` hidden units
        per input dimension. This corresponds to both :math:`a` and :math:`b` in
        reference [1]. The elements of hidden_factors must be integers.
        Defaults to ``(8, 8)``.
    :type hidden_factors: sequence of int
    :param str residual: Type of residual connections to use. One of `None`,
        `"normal"`, `"gated"`.
    :return: an (`init_fn`, `apply_fn`) pair. ``apply_fn(params, inputs)``
        returns ``(out, logdet)``, with ``logdet`` reshaped to ``inputs.shape``.
    :raises ValueError: if ``residual`` is not one of the supported values.
    """
    # Validate up front. Previously any unrecognised string silently fell
    # through to the "normal" residual.
    if residual not in (None, "normal", "gated"):
        raise ValueError(
            "residual must be one of None, 'normal' or 'gated', got {!r}".format(
                residual))

    layers = []
    in_factor = 1
    for hidden_factor in hidden_factors:
        layers.append(BlockMaskedDense(input_dim, in_factor, hidden_factor))
        layers.append(Tanh())
        in_factor = hidden_factor
    # Final layer maps back to one unit per input dimension.
    layers.append(BlockMaskedDense(input_dim, in_factor, 1))
    arn = stax.serial(*layers)
    if residual is not None:
        FanInResidual = (FanInResidualGated if residual == "gated"
                         else FanInResidualNormal)
        arn = stax.serial(stax.FanOut(2),
                          stax.parallel(arn, stax.Identity),
                          FanInResidual())

    def init_fun(rng, input_shape):
        return arn[0](rng, input_shape)

    def apply_fun(params, inputs, **kwargs):
        # The layers consume and produce an (activations, logdet) pair; the
        # log-determinant accumulator starts as None.
        out, logdet = arn[1](params, (inputs, None), **kwargs)
        return out, logdet.reshape(inputs.shape)

    return init_fun, apply_fun
def image_grid(nrow, ncol, imagevecs, imshape):
    """Reshape a stack of image vectors into an image grid for plotting.

    Consumes `nrow * ncol` images of shape `imshape` in order. Each image is
    transposed and each row's images are laid out in reverse order. The whole
    assembled grid is transposed at the end.
    """
    images = iter(imagevecs.reshape((-1,) + imshape))
    grid_rows = []
    for _ in range(nrow):
        row = [next(images).T for _ in range(ncol)]
        row.reverse()
        grid_rows.append(jnp.hstack(row))
    return jnp.vstack(grid_rows).T


# Encoder: two Dense(512)+Relu layers, then two Dense(10) heads; the second
# head ends in Softplus, so its outputs are positive.
encoder_init, encode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    FanOut(2),
    stax.parallel(Dense(10), stax.serial(Dense(10), Softplus)),
)

# Decoder: maps back to 28 * 28 outputs.
decoder_init, decode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    Dense(28 * 28),
)


if __name__ == "__main__":
    step_size = 0.001
    num_epochs = 100
    batch_size = 32
    nrow, ncol = 10, 10  # sampled image grid size
def AutoregressiveNN(
    input_dim,
    hidden_dims,
    param_dims=(1, 1),
    permutation=None,
    skip_connections=False,
    nonlinearity=stax.Relu,
):
    """
    An implementation of a MADE-like auto-regressive neural network.

    Like the purely functional layers in :mod:`jax.example_libraries.stax`,
    this function returns an ``(init_fun, apply_fun)`` pair:

    * ``init_fun`` takes an rng key and an input shape and returns an
      ``(output_shape, params)`` pair;
    * ``apply_fun`` takes params and inputs and applies the layer.

    :param input_dim: the dimensionality of the input
    :type input_dim: int
    :param hidden_dims: the dimensionality of the hidden units per layer; every
        entry must be at least ``input_dim``
    :type hidden_dims: list[int]
    :param param_dims: shape the output into one parameter per entry p_n.
        If every p_n == 1, each parameter has shape ``(..., input_dim)``.
        Otherwise each parameter keeps a leading axis of size p_n, i.e. shape
        ``(p_n, ..., input_dim)``, even when p_n == 1. The default is
        ``(1, 1)``, i.e. output two parameters of dimension (input_dim), which
        is useful for inverse autoregressive flow. If ``param_dims`` has a
        single entry, that array is returned instead of a tuple.
    :type param_dims: sequence[int]
    :param permutation: an optional permutation that is applied to the inputs
        and controls the order of the autoregressive factorization. In
        particular, for the identity permutation the autoregressive structure
        is such that the Jacobian is triangular. Defaults to the identity
        permutation.
    :type permutation: array of ints
    :param skip_connections: whether to add skip connections from the input to
        the output.
    :type skip_connections: bool
    :param nonlinearity: The nonlinearity to use in the feedforward network,
        such as ReLU. No nonlinearity is applied to the final network output,
        so the output is an unbounded real number. Defaults to ReLU.
    :type nonlinearity: callable.
    :return: a tuple (init_fun, apply_fun)
    :raises ValueError: if any entry of ``hidden_dims`` is smaller than
        ``input_dim``.

    Reference:

    MADE: Masked Autoencoder for Distribution Estimation [arXiv:1502.03509]
    Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle
    """
    output_multiplier = sum(param_dims)
    all_ones = (np.array(param_dims) == 1).all()

    # Calculate the indices on the output corresponding to each parameter
    ends = np.cumsum(np.array(param_dims), axis=0)
    starts = np.concatenate((np.zeros(1), ends[:-1]))
    param_slices = [slice(int(s), int(e)) for s, e in zip(starts, ends)]

    # Hidden dimension must be not less than the input otherwise it isn't
    # possible to connect to the outputs correctly
    for h in hidden_dims:
        if h < input_dim:
            raise ValueError(
                "Hidden dimension must not be less than input dimension.")

    if permutation is None:
        permutation = jnp.arange(input_dim)

    # Create masks
    masks, mask_skip = create_mask(
        input_dim=input_dim,
        hidden_dims=hidden_dims,
        permutation=permutation,
        output_dim_multiplier=output_multiplier,
    )

    main_layers = []
    # Create masked layers; no nonlinearity after the last one.
    for i, mask in enumerate(masks):
        main_layers.append(MaskedDense(mask))
        if i < len(masks) - 1:
            main_layers.append(nonlinearity)

    if skip_connections:
        net_init, net = stax.serial(
            stax.FanOut(2),
            stax.parallel(stax.serial(*main_layers),
                          MaskedDense(mask_skip, bias=False)),
            stax.FanInSum,
        )
    else:
        net_init, net = stax.serial(*main_layers)

    def init_fun(rng_key, input_shape):
        """
        :param rng_key: rng_key used to initialize parameters
        :param input_shape: input shape
        """
        assert input_dim == input_shape[-1]
        return net_init(rng_key, input_shape)

    def apply_fun(params, inputs, **kwargs):
        """
        :param params: layer parameters
        :param inputs: layer inputs
        """
        out = net(params, inputs, **kwargs)

        # reshape output as necessary
        out = jnp.reshape(out, inputs.shape[:-1] + (output_multiplier, input_dim))
        # move param dims to the first dimension
        out = jnp.moveaxis(out, -2, 0)

        if all_ones:
            # Squeeze dimension if all parameters are one dimensional
            out = tuple([out[i] for i in range(output_multiplier)])
        else:
            # If not all ones, then probably don't want to squeeze a single
            # dimension parameter
            out = tuple([out[s] for s in param_slices])

        # if len(param_dims) == 1, we return the array instead of a tuple of arrays
        return out[0] if len(param_dims) == 1 else out

    return init_fun, apply_fun