def global_conv_block(num_channels, grids, minval, maxval, downsample_factor):
  """Global convolution block.

  Downsamples the input, applies a global convolution, upsamples back to the
  original resolution, and concatenates the result with the untouched input
  along the last axis. The input itself occupies one channel of the output.

  Args:
    num_channels: Integer, the number of channels.
    grids: Float numpy array with shape (num_grids,).
    minval: Float, the min value in the uniform sampling for exponential width.
    maxval: Float, the max value in the uniform sampling for exponential width.
    downsample_factor: Integer, the factor of downsampling. The grids are
        downsampled with step size 2 ** downsample_factor.

  Returns:
    (init_fn, apply_fn) pair.
  """
  global_conv_path = stax.serial(
      # Downsample by repeated transposed linear interpolation.
      *([linear_interpolation_transpose()] * downsample_factor),
      exponential_global_convolution(
          num_channels=num_channels - 1,  # one channel is reserved for input.
          grids=grids,
          minval=minval,
          maxval=maxval,
          downsample_factor=downsample_factor),
      # Upsample back with repeated linear interpolation.
      *([linear_interpolation()] * downsample_factor))
  return stax.serial(
      stax.FanOut(2),
      stax.parallel(stax.Identity, global_conv_path),
      stax.FanInConcat(axis=-1),
  )
def _build_unet_shell(layer, num_filters, activation):
  """Builds a shell in the U-net structure.

  The input passes through a downsampling block, then forks into two
  branches: an identity skip connection, and `layer` (the next level down
  in the U-shape). The two branches are concatenated on the last axis and
  fed through an upsampling block:

      downsampling block -> [identity | layer] -> concat -> upsampling block

  Args:
    layer: (init_fn, apply_fn) pair in the bottom of the U-shape structure.
    num_filters: Integer, the number of filters used for downsampling and
        upsampling.
    activation: String, the activation function to use in the network.

  Returns:
    (init_fn, apply_fn) pair.
  """
  encoder = downsampling_block(num_filters, activation=activation)
  decoder = upsampling_block(num_filters, activation=activation)
  skip_around_layer = stax.serial(
      stax.FanOut(2),
      stax.parallel(stax.Identity, layer),
      stax.FanInConcat(axis=-1))
  return stax.serial(encoder, skip_around_layer, decoder)
def testFanInConcat(self, input_shapes, axis):
    """Checks FanInConcat's declared output shape agrees with apply_fn."""
    init_fn, apply_fn = stax.FanInConcat(axis)
    _CheckShapeAgreement(self, init_fn, apply_fn, input_shapes)
def gen_bann(output_units = 10, mode = 'test', first_layer_width = 5,
             second_layer_width = 20,
             W_initializers_str = 'glorot_normal()',
             b_initializers_str = 'normal()'):
  """Builds a modern LeNet-like variant with relu activation.

  The input fans out to 10 parallel branches. Branch i (1-indexed) is
  Dense(first_layer_width) -> Relu -> Dense(i * second_layer_width) ->
  Relu -> Dropout, with the dropout keep rate stepping down from 0.9
  (branches 1-2) to 0.5 (branches 9-10). The branch outputs are
  concatenated, passed through a final dropout, and projected to
  `output_units` with a Dense layer.

  Args:
    output_units: Integer, size of the final Dense output layer.
    mode: String, 'train' or 'test'; forwarded to stax.Dropout.
    first_layer_width: Integer, width of the first Dense layer in each branch.
    second_layer_width: Integer, base width of the second Dense layer;
        branch i uses i * second_layer_width units.
    W_initializers_str: String, a Python expression evaluated to produce the
        weight initializer (e.g. 'glorot_normal()').
    b_initializers_str: String, a Python expression evaluated to produce the
        bias initializer (e.g. 'normal()').

  Returns:
    (init_fn, apply_fn) pair.
  """
  # SECURITY NOTE: eval() executes arbitrary code from the initializer
  # strings -- only call this function with trusted arguments. The original
  # re-ran eval() for every single layer; evaluate once and share the
  # resulting (pure) initializer functions instead.
  w_init = eval(W_initializers_str)
  b_init = eval(b_initializers_str)

  # Dropout keep rates per branch: 0.9 for branches 1-2, 0.8 for 3-4,
  # 0.7 for 5-6, 0.6 for 7-8, 0.5 for 9-10 (exact literals, matching the
  # original hand-unrolled definitions).
  keep_rates = (0.9, 0.9, 0.8, 0.8, 0.7, 0.7, 0.6, 0.6, 0.5, 0.5)
  branches = [
      stax.serial(
          stax.Dense(first_layer_width, W_init=w_init, b_init=b_init),
          stax.Relu,
          stax.Dense(i * second_layer_width, W_init=w_init, b_init=b_init),
          stax.Relu,
          stax.Dropout(rate=rate, mode=mode))
      for i, rate in enumerate(keep_rates, start=1)
  ]

  csb = stax.serial(
      stax.FanOut(len(keep_rates)),
      stax.parallel(*branches),
      stax.FanInConcat())

  # NOTE: the original also built an unused "Additive" graph (two CSB copies
  # fanned out and concatenated, marked "#optional"); it was dead code and
  # has been removed.
  return stax.serial(
      csb,
      stax.Dropout(rate=0.9, mode=mode),
      stax.Dense(output_units, W_init=w_init, b_init=b_init))