def global_conv_block(num_channels, grids, minval, maxval, downsample_factor):
  """Global convolution block.

  First downsample the input, then apply global conv, finally upsample and
  concatenate with the input. The input itself is one channel in the output.

  Args:
    num_channels: Integer, the number of channels.
    grids: Float numpy array with shape (num_grids,).
    minval: Float, the min value in the uniform sampling for exponential width.
    maxval: Float, the max value in the uniform sampling for exponential width.
    downsample_factor: Integer, the factor of downsampling. The grids are
        downsampled with step size 2 ** downsample_factor.

  Returns:
    (init_fn, apply_fn) pair.
  """
  layers = []
  layers.extend([linear_interpolation_transpose()] * downsample_factor)
  layers.append(exponential_global_convolution(
      num_channels=num_channels - 1,  # one channel is reserved for input.
      grids=grids,
      minval=minval,
      maxval=maxval,
      downsample_factor=downsample_factor))
  layers.extend([linear_interpolation()] * downsample_factor)
  global_conv_path = stax.serial(*layers)
  return stax.serial(
      stax.FanOut(2),
      stax.parallel(stax.Identity, global_conv_path),
      stax.FanInConcat(axis=-1),
  )
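A minimal usage sketch (not part of the original source): the grid, channel count, width range, and batch size below are illustrative assumptions, and the helpers `linear_interpolation`, `linear_interpolation_transpose`, and `exponential_global_convolution` are assumed to be in scope from the same module.

# Hypothetical usage sketch; all values are illustrative assumptions.
import jax
import numpy as np

grids = np.linspace(-5., 5., 17)  # 17 = 2**2 * 4 + 1, compatible with downsample_factor=2.
init_fn, apply_fn = global_conv_block(
    num_channels=4, grids=grids, minval=0.1, maxval=2.0, downsample_factor=2)

# Assumed input layout: (batch, num_grids, 1), e.g. a density sampled on the grid.
output_shape, params = init_fn(jax.random.PRNGKey(0), (8, len(grids), 1))
features = apply_fn(params, np.zeros((8, len(grids), 1)))
# features has 4 channels: the input itself plus 3 global-convolution channels.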
def _build_unet_shell(layer, num_filters, activation):
  """Builds a shell in the U-net structure.

  *--------------*                    *--------*    *----------------------*
  | downsampling |--------------------|        |    |                      |
  |    block     |   *------------*   | concat |----|   upsampling block   |
  |              |---|   layer    |---|        |    |                      |
  *--------------*   *------------*   *--------*    *----------------------*

  Args:
    layer: (init_fn, apply_fn) pair in the bottom of the U-shape structure.
    num_filters: Integer, the number of filters used for downsampling and
        upsampling.
    activation: String, the activation function to use in the network.

  Returns:
    (init_fn, apply_fn) pair.
  """
  return stax.serial(
      downsampling_block(num_filters, activation=activation),
      stax.FanOut(2),
      stax.parallel(stax.Identity, layer),
      stax.FanInConcat(axis=-1),
      upsampling_block(num_filters, activation=activation)
  )
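An illustrative sketch (not from the original source) of how these shells nest to form a U-net; the innermost bottleneck layer, filter counts, and activation string are assumptions.

# Hypothetical nesting sketch; bottleneck, filter counts, and activation are
# illustrative assumptions.
from jax.example_libraries import stax

bottleneck = stax.Identity  # assumed innermost (init_fn, apply_fn) pair
inner_shell = _build_unet_shell(bottleneck, num_filters=32, activation='swish')
unet_body = _build_unet_shell(inner_shell, num_filters=16, activation='swish')
# Each wrap adds one more downsampling/upsampling level around `layer`.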
def FanInConcat(axis: int = -1) -> InternalLayerMasked:
  """Layer construction function for a fan-in concatenation layer.

  Based on `jax.example_libraries.stax.FanInConcat`.

  Args:
    axis: Specifies the axis along which input tensors should be concatenated.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  init_fn, apply_fn = ostax.FanInConcat(axis)

  def kernel_fn(ks: Kernels, **kwargs) -> Kernel:
    ks, is_reversed = _preprocess_kernels_for_fan_in(ks)

    diagonal_batch = ks[0].diagonal_batch
    diagonal_spatial = ks[0].diagonal_spatial

    shape1, shape2 = ks[0].shape1, ks[0].shape2

    ndim = len(shape1)
    _axis = axis % ndim
    batch_axis = ks[0].batch_axis
    channel_axis = ks[0].channel_axis

    new_shape1 = shape1[:_axis] + shape1[_axis + 1:]
    new_shape2 = shape2[:_axis] + shape2[_axis + 1:]
    for k in ks:
      k_shape1 = k.shape1[:_axis] + k.shape1[_axis + 1:]
      k_shape2 = k.shape2[:_axis] + k.shape2[_axis + 1:]
      if k_shape1 != new_shape1 or k_shape2 != new_shape2:
        raise ValueError(
            'Non-`axis` shapes should be equal in `FanInConcat`.')

    # Check if inputs are independent Gaussians.
    if _axis != channel_axis:
      is_gaussian = all(k.is_gaussian for k in ks)
      if not is_gaussian:
        # TODO(xlc): FanInSum/FanInConcat could allow non-Gaussian inputs, but
        # we need to propagate the mean of the random variables as well.
        raise NotImplementedError(
            '`FanInConcat` layer along the non-channel axis is only '
            'implemented for the case when all input layers are guaranteed '
            'to be mean-zero Gaussian, i.e. having all `is_gaussian` set to '
            '`True`.')
    else:
      # TODO(romann): allow to apply nonlinearity after
      # channelwise concatenation.
      # TODO(romann): support concatenating different channelwise masks.
      is_gaussian = False

    if _axis == batch_axis:
      warnings.warn(f'Concatenation along the batch axis ({_axis}) gives '
                    'inconsistent covariances when batching - '
                    'proceed with caution.')

    spatial_axes = tuple(i for i in range(ndim)
                         if i not in (channel_axis, batch_axis))
    # Change spatial axis according to the kernel `is_reversed`.
    if _axis in spatial_axes and is_reversed:
      _axis = spatial_axes[::-1][spatial_axes.index(_axis)]

    # Map activation tensor axis to the covariance tensor axis.
    tensor_axis_to_kernel_axis = {
        **{
            batch_axis: 0,
            channel_axis: -1,
        },
        **{
            spatial_axis: idx + 1
            for idx, spatial_axis in enumerate(spatial_axes)
        }
    }
    _axis = tensor_axis_to_kernel_axis[_axis]

    widths = [k.shape1[channel_axis] for k in ks]
    cov1 = _concat_kernels([k.cov1 for k in ks], _axis,
                           diagonal_batch, diagonal_spatial, widths)
    cov2 = _concat_kernels([k.cov2 for k in ks], _axis,
                           diagonal_batch, diagonal_spatial, widths)
    nngp = _concat_kernels([k.nngp for k in ks], _axis,
                           False, diagonal_spatial, widths)
    ntk = _concat_kernels([k.ntk for k in ks], _axis,
                          False, diagonal_spatial, widths)

    return Kernel(cov1=cov1,
                  cov2=cov2,
                  nngp=nngp,
                  ntk=ntk,
                  x1_is_x2=ks[0].x1_is_x2,
                  is_gaussian=is_gaussian,
                  is_reversed=is_reversed,
                  is_input=ks[0].is_input,
                  diagonal_batch=diagonal_batch,
                  diagonal_spatial=diagonal_spatial,
                  shape1=None,
                  shape2=None,
                  batch_axis=batch_axis,
                  channel_axis=channel_axis,
                  mask1=None,
                  mask2=None)  # pytype:disable=wrong-keyword-args

  def mask_fn(mask, input_shape):
    return _concat_masks(mask, input_shape, axis)

  return init_fn, apply_fn, kernel_fn, mask_fn
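A hedged usage sketch (not part of the original source) showing where this kernel-aware layer typically sits; it assumes the public `neural_tangents.stax` API, and the widths and input shapes are made up.

# Hypothetical example: a channelwise skip connection analyzed in the
# infinite-width limit. Widths and input shapes are illustrative.
import jax.numpy as jnp
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.FanOut(2),
    stax.parallel(stax.Identity(), stax.Dense(512)),
    stax.FanInConcat(axis=-1),  # concatenate the two branches along channels
    stax.Dense(1))

x1 = jnp.ones((3, 8))
x2 = jnp.ones((4, 8))
kernels = kernel_fn(x1, x2, ('nngp', 'ntk'))  # closed-form NNGP and NTK covariances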
def testFanInConcat(self, input_shapes, axis):
  init_fun, apply_fun = stax.FanInConcat(axis)
  _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)