Example #1
0
def global_conv_block(num_channels, grids, minval, maxval, downsample_factor):
  """Builds a block that concatenates its input with a global-conv branch.

  The input is fanned out into two branches:
    * an identity branch, which passes the input through unchanged and so
      contributes one channel of the output;
    * a global-convolution branch, which downsamples `downsample_factor`
      times, applies an exponential global convolution producing
      `num_channels - 1` channels, and then upsamples the same number of
      times.
  The identity branch comes first in the concatenation, which is taken along
  the last axis.

  Args:
    num_channels: Integer, the total number of output channels, including the
        channel taken directly from the input.
    grids: Float numpy array with shape (num_grids,).
    minval: Float, the min value in the uniform sampling for exponential width.
    maxval: Float, the max value in the uniform sampling for exponential width.
    downsample_factor: Integer, the number of downsampling (and matching
        upsampling) steps. The grids are downsampled with step size
        2 ** downsample_factor.

  Returns:
    (init_fn, apply_fn) pair.
  """
  # Each factory is invoked exactly once; the resulting layer object is then
  # reused for every resampling step.
  downsample = linear_interpolation_transpose()
  global_conv = exponential_global_convolution(
      num_channels=num_channels - 1,  # The identity branch supplies the rest.
      grids=grids,
      minval=minval,
      maxval=maxval,
      downsample_factor=downsample_factor)
  upsample = linear_interpolation()

  conv_branch = stax.serial(
      *([downsample] * downsample_factor),
      global_conv,
      *([upsample] * downsample_factor))

  fan_out = stax.FanOut(2)
  both_branches = stax.parallel(stax.Identity, conv_branch)
  concat = stax.FanInConcat(axis=-1)
  return stax.serial(fan_out, both_branches, concat)
Example #2
0
def _build_unet_shell(layer, num_filters, activation):
  """Wraps `layer` in one level (shell) of a U-net.

  *--------------*
  |              |                     *--------*     *----------------------*
  | downsampling |---------------------|        |     |                      |
  |   block      |   *------------*    | concat |-----|   upsampling block   |
  |              |---|    layer   |----|        |     |                      |
  *--------------*   *------------*    *--------*     *----------------------*

  The output of the downsampling block is duplicated. One copy skips directly
  to the concatenation, and the other goes through `layer`. The two results
  are concatenated along the last axis and fed into the upsampling block.

  Args:
    layer: (init_fn, apply_fn) pair in the bottom of the U-shape structure.
    num_filters: Integer, the number of filters used for downsampling and
        upsampling.
    activation: String, the activation function to use in the network.

  Returns:
    (init_fn, apply_fn) pair.
  """
  down = downsampling_block(num_filters, activation=activation)
  # Skip connection: identity alongside `layer`, then channel-concatenate.
  skip_and_inner = (
      stax.FanOut(2),
      stax.parallel(stax.Identity, layer),
      stax.FanInConcat(axis=-1),
  )
  up = upsampling_block(num_filters, activation=activation)
  return stax.serial(down, *skip_and_inner, up)
Example #3
0
def FanInConcat(axis: int = -1) -> InternalLayerMasked:
    """Layer construction function for a fan-in concatenation layer.

    Based on `jax.example_libraries.stax.FanInConcat`.

    Args:
      axis: Specifies the axis along which input tensors should be concatenated.

    Returns:
      `(init_fn, apply_fn, kernel_fn, mask_fn)`: the `init_fn` / `apply_fn`
      pair returned by `ostax.FanInConcat(axis)`, a `kernel_fn` that
      concatenates the input kernels, and a `mask_fn` that combines the input
      masks via `_concat_masks`.
    """
    init_fn, apply_fn = ostax.FanInConcat(axis)

    def kernel_fn(ks: Kernels, **kwargs) -> Kernel:
        """Concatenates the covariances of the input kernels along `axis`.

        Raises:
          ValueError: if the inputs' `shape1` / `shape2` differ in any axis
            other than `axis`.
          NotImplementedError: if concatenating along a non-channel axis while
            not every input kernel has `is_gaussian` set.
        """
        ks, is_reversed = _preprocess_kernels_for_fan_in(ks)

        diagonal_batch = ks[0].diagonal_batch
        diagonal_spatial = ks[0].diagonal_spatial

        shape1, shape2 = ks[0].shape1, ks[0].shape2

        ndim = len(shape1)
        _axis = axis % ndim  # Normalize a negative `axis` to [0, ndim).
        batch_axis = ks[0].batch_axis
        channel_axis = ks[0].channel_axis

        # Every input must agree with the first one on all axes except `axis`.
        ref_shape1 = shape1[:_axis] + shape1[_axis + 1:]
        ref_shape2 = shape2[:_axis] + shape2[_axis + 1:]
        for k in ks:
            k_shape1 = k.shape1[:_axis] + k.shape1[_axis + 1:]
            k_shape2 = k.shape2[:_axis] + k.shape2[_axis + 1:]
            if k_shape1 != ref_shape1 or k_shape2 != ref_shape2:
                raise ValueError(
                    'Non-`axis` shapes should be equal in `FanInConcat`.')

        # Check if inputs are independent Gaussians.
        if _axis != channel_axis:
            is_gaussian = all(k.is_gaussian for k in ks)
            if not is_gaussian:
                # TODO(xlc): FanInSum/FanInConcat could allow non-Gaussian inputs, but
                # we need to propagate the mean of the random variables as well.
                raise NotImplementedError(
                    '`FanInConcat` layer along the non-channel axis is only '
                    'implemented for the case if all input layers are '
                    'guaranteed to be mean-zero Gaussian, i.e. having all '
                    '`is_gaussian` set to `True`.')
        else:
            # TODO(romann): allow to apply nonlinearity after
            # channelwise concatenation.
            # TODO(romann): support concatenating different channelwise masks.
            is_gaussian = False

        if _axis == batch_axis:
            warnings.warn(
                f'Concatenation along the batch axis ({_axis}) gives '
                'inconsistent covariances when batching - '
                'proceed with caution.')

        spatial_axes = tuple(i for i in range(ndim)
                             if i not in (channel_axis, batch_axis))
        # When the kernel is reversed, its spatial axes are stored in reverse
        # order, so map the tensor spatial axis to its mirrored position.
        if _axis in spatial_axes and is_reversed:
            _axis = spatial_axes[::-1][spatial_axes.index(_axis)]

        # Map activation tensor axis to the covariance tensor axis:
        # batch -> 0, channel -> -1, i-th spatial axis -> i + 1.
        tensor_axis_to_kernel_axis = {
            batch_axis: 0,
            channel_axis: -1,
            **{
                spatial_axis: idx + 1
                for idx, spatial_axis in enumerate(spatial_axes)
            },
        }

        kernel_axis = tensor_axis_to_kernel_axis[_axis]
        widths = [k.shape1[channel_axis] for k in ks]

        cov1 = _concat_kernels([k.cov1 for k in ks], kernel_axis,
                               diagonal_batch, diagonal_spatial, widths)
        cov2 = _concat_kernels([k.cov2 for k in ks], kernel_axis,
                               diagonal_batch, diagonal_spatial, widths)
        # NOTE(review): `nngp` / `ntk` are always concatenated with
        # `diagonal_batch=False`; presumably because these x1-vs-x2 kernels are
        # never stored batch-diagonal — confirm in `_concat_kernels`.
        nngp = _concat_kernels([k.nngp for k in ks], kernel_axis, False,
                               diagonal_spatial, widths)
        ntk = _concat_kernels([k.ntk for k in ks], kernel_axis, False,
                              diagonal_spatial, widths)

        # Shapes and masks are reset to None here; presumably they are
        # recomputed by the surrounding layer machinery — confirm.
        return Kernel(cov1=cov1,
                      cov2=cov2,
                      nngp=nngp,
                      ntk=ntk,
                      x1_is_x2=ks[0].x1_is_x2,
                      is_gaussian=is_gaussian,
                      is_reversed=is_reversed,
                      is_input=ks[0].is_input,
                      diagonal_batch=diagonal_batch,
                      diagonal_spatial=diagonal_spatial,
                      shape1=None,
                      shape2=None,
                      batch_axis=batch_axis,
                      channel_axis=channel_axis,
                      mask1=None,
                      mask2=None)  # pytype:disable=wrong-keyword-args

    def mask_fn(mask, input_shape):
        """Delegates to `_concat_masks` to combine input masks along `axis`."""
        return _concat_masks(mask, input_shape, axis)

    return init_fn, apply_fn, kernel_fn, mask_fn
Example #4
0
 def testFanInConcat(self, input_shapes, axis):
     """Runs `_CheckShapeAgreement` on `stax.FanInConcat(axis)`'s layer pair."""
     concat_init, concat_apply = stax.FanInConcat(axis)
     _CheckShapeAgreement(self, concat_init, concat_apply, input_shapes)