Ejemplo n.º 1
0
def global_conv_block(num_channels, grids, minval, maxval, downsample_factor):
    """Builds a global convolution branch that is concatenated with its input.

    The branch downsamples the input, applies an exponential global
    convolution, and upsamples the result again. The untouched input and the
    branch output are then concatenated along the last axis, so the input
    occupies one channel of the output.

    Args:
      num_channels: Integer, the number of channels.
      grids: Float numpy array with shape (num_grids,).
      minval: Float, the min value in the uniform sampling for exponential
          width.
      maxval: Float, the max value in the uniform sampling for exponential
          width.
      downsample_factor: Integer, the factor of downsampling. The grids are
          downsampled with step size 2 ** downsample_factor.

    Returns:
      (init_fn, apply_fn) pair.
    """
    # A single layer pair is built and repeated downsample_factor times.
    downsample_path = [linear_interpolation_transpose()] * downsample_factor
    # One output channel is reserved for the input itself, hence the "- 1".
    conv_layer = exponential_global_convolution(
        num_channels=num_channels - 1,
        grids=grids,
        minval=minval,
        maxval=maxval,
        downsample_factor=downsample_factor)
    upsample_path = [linear_interpolation()] * downsample_factor

    branch = stax.serial(*downsample_path, conv_layer, *upsample_path)
    split = stax.FanOut(2)
    skip_and_branch = stax.parallel(stax.Identity, branch)
    merge = stax.FanInConcat(axis=-1)
    return stax.serial(split, skip_and_branch, merge)
Ejemplo n.º 2
0
def _build_unet_shell(layer, num_filters, activation):
    """Wraps `layer` in one level of the U-net structure.

    The input goes through a downsampling block, is split in two, one copy is
    passed through `layer` while the other is kept as is, both are
    concatenated along the last axis, and the result goes through an
    upsampling block.

  *--------------*
  |              |                     *--------*     *----------------------*
  | downsampling |---------------------|        |     |                      |
  |   block      |   *------------*    | concat |-----|   upsampling block   |
  |              |---|    layer   |----|        |     |                      |
  *--------------*   *------------*    *--------*     *----------------------*

    Args:
      layer: (init_fn, apply_fn) pair in the bottom of the U-shape structure.
      num_filters: Integer, the number of filters used for downsampling and
          upsampling.
      activation: String, the activation function to use in the network.

    Returns:
      (init_fn, apply_fn) pair.
    """
    encoder = downsampling_block(num_filters, activation=activation)
    split = stax.FanOut(2)
    # Identity is the skip connection; `layer` is the inner part of the U.
    skip_and_inner = stax.parallel(stax.Identity, layer)
    merge = stax.FanInConcat(axis=-1)
    decoder = upsampling_block(num_filters, activation=activation)
    return stax.serial(encoder, split, skip_and_inner, merge, decoder)
Ejemplo n.º 3
0
 def testFanInConcat(self, input_shapes, axis):
     """Checks shape agreement for a FanInConcat layer on the given axis."""
     layer = stax.FanInConcat(axis)
     init_fn, apply_fn = layer
     _CheckShapeAgreement(self, init_fn, apply_fn, input_shapes)
Ejemplo n.º 4
0
def gen_bann(output_units=10, mode='test', first_layer_width=5,
             second_layer_width=20, W_initializers_str='glorot_normal()',
             b_initializers_str='normal()'):
    """Builds a ten-branch dense network whose branch outputs are concatenated.

    The input is fanned out to ten parallel branches. Branch ``k`` (1-based)
    is ``Dense(first_layer_width) -> Relu -> Dense(k * second_layer_width) ->
    Relu -> Dropout``. The branch outputs are concatenated with
    ``stax.FanInConcat()``, passed through one more dropout layer and a final
    ``Dense(output_units)`` layer.

    Args:
      output_units: Width of the final dense layer.
      mode: Passed unchanged to every ``stax.Dropout`` layer.
      first_layer_width: Width of the first dense layer of every branch.
      second_layer_width: Base width of the second dense layer; branch ``k``
          uses ``k * second_layer_width``.
      W_initializers_str: String holding a Python expression that evaluates to
          the weight initializer. It is evaluated once per dense layer.
      b_initializers_str: String holding a Python expression that evaluates to
          the bias initializer. It is evaluated once per dense layer.

    Returns:
      (init_fn, apply_fn) pair.
    """
    # Dropout rate of each branch, in branch order; the position (1-based) is
    # also the multiplier applied to second_layer_width.
    # NOTE(review): in jax's stax, Dropout's `rate` appears to be the keep
    # probability rather than the drop probability -- confirm this is intended.
    branch_dropout_rates = (0.9, 0.9, 0.8, 0.8, 0.7, 0.7, 0.6, 0.6, 0.5, 0.5)

    # SECURITY NOTE: the initializer strings are passed to eval(). Only call
    # this function with trusted strings. The eval calls are deliberately kept
    # in this function's own scope (plain loop, no helper or comprehension) so
    # the expressions resolve names exactly as before.
    branches = []
    for multiplier, rate in enumerate(branch_dropout_rates, start=1):
        branches.append(
            stax.serial(
                stax.Dense(first_layer_width,
                           W_init=eval(W_initializers_str),
                           b_init=eval(b_initializers_str)),
                stax.Relu,
                stax.Dense(multiplier * second_layer_width,
                           W_init=eval(W_initializers_str),
                           b_init=eval(b_initializers_str)),
                stax.Relu,
                stax.Dropout(rate=rate, mode=mode),
            ))

    # Kept as its own nested serial so the parameter tree layout is unchanged.
    concatenated_branches = stax.serial(
        stax.FanOut(len(branches)),
        stax.parallel(*branches),
        stax.FanInConcat(),
    )

    return stax.serial(
        concatenated_branches,
        stax.Dropout(rate=0.9, mode=mode),
        stax.Dense(output_units,
                   W_init=eval(W_initializers_str),
                   b_init=eval(b_initializers_str)),
    )