def _get_output_shape(rank, strides, padding, dilations, input_shape,
                      output_size, filter_shape, output_padding):
    """Builds the `output_shape` and `strides` args consumed by `conv_transpose`.

    Each of the `rank` spatial output sizes is obtained from
    `_deconv_output_length`, fed with the matching entries of `input_shape`
    (skipping its rightmost dimension), `filter_shape`, `strides`, `dilations`
    and `output_padding`. `output_size` is appended as the rightmost output
    dimension and every dimension of `input_shape` to the left of the spatial
    ones is carried over unchanged.

    Args:
      rank: Number of spatial dimensions.
      strides: Length-`rank` tuple of strides (it is concatenated with other
        tuples, so it must be a `tuple`).
      padding: Forwarded unchanged to `_deconv_output_length`.
      dilations: Per-dimension dilations, indexed like `strides`.
      input_shape: Shape of the input; its rightmost dimension is skipped and
        the `rank` dimensions before it are treated as spatial.
      output_size: Size of the rightmost dimension of the output.
      filter_shape: Length-`rank` sequence of filter sizes.
      output_padding: `None`, or a value normalized through
        `nn_util_lib.prepare_tuple_argument`.

    Returns:
      output_shape: Concatenation of the leading dimensions of `input_shape`,
        the computed spatial sizes and `output_size`.
      strides: `(1,) + strides + (1,)`.

    Raises:
      ValueError: If any output padding is not strictly smaller than the
        stride paired with it.
    """
    if output_padding is not None:
        output_padding = nn_util_lib.prepare_tuple_argument(
            output_padding, rank, 'output_padding')
        if any(out_pad >= stride
               for stride, out_pad in zip(strides, output_padding)):
            raise ValueError('Stride {} must be greater than output '
                             'padding {}.'.format(strides, output_padding))
    else:
        # `None` entries are handed to `_deconv_output_length` as-is.
        output_padding = (None, ) * rank
    assert len(filter_shape) == rank
    assert len(strides) == rank
    assert len(output_padding) == rank
    # Negative indices walk the spatial dimensions right-aligned; the extra
    # `- 1` on `input_shape` steps over its rightmost dimension.
    spatial_sizes = [
        _deconv_output_length(input_shape[axis - 1],
                              filter_shape[axis],
                              padding=padding,
                              output_padding=output_padding[axis],
                              stride=strides[axis],
                              dilation=dilations[axis])
        for axis in range(-rank, 0)
    ]
    leading_dims = input_shape[:-rank - 1]
    output_shape = prefer_static.concat(
        [leading_dims, spatial_sizes + [output_size]], axis=0)
    padded_strides = (1, ) + strides + (1, )
    return output_shape, padded_strides
    def __init__(
            self,
            input_size,
            output_size,  # keras::Conv::filters
            # Conv specific.
            filter_shape,  # keras::Conv::kernel_size
            rank=2,  # keras::Conv::rank
            strides=1,  # keras::Conv::strides
            padding='VALID',  # keras::Conv::padding; 'CAUSAL' not implemented.
            # keras::Conv::data_format is not implemented
            dilations=1,  # keras::Conv::dilation_rate
            # Weights
            init_kernel_fn=None,  # tfp.experimental.nn.initializers.glorot_uniform()
            init_bias_fn=None,  # tf.initializers.zeros()
            make_kernel_bias_fn=nn_util_lib.make_kernel_bias,
            dtype=tf.float32,
            batch_shape=(),
            # Misc
            activation_fn=None,
            name=None):
        """Constructs layer.

        Note: `data_format` is not supported since all nn layers operate on
        the rightmost column. If your channel dimension is not rightmost, use
        `tf.transpose` before calling this layer. For example, if your channel
        dimension is second from the left, the following code will move it
        rightmost:

        ```python
        inputs = tf.transpose(inputs, tf.concat([
            [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
        ```

        Args:
          input_size: Size of the second from the rightmost dimension of
            `kernel`. In Keras, this argument is inferred from the rightmost
            input shape, i.e., `tf.shape(inputs)[-1]`. This argument is
            required.
          output_size: Size of the rightmost dimension of both `kernel` and
            `bias`. In Keras, this argument is called `filters`.
          filter_shape: Sizes of the `rank` `kernel` dimensions that precede
            `[input_size, output_size]`; normalized with
            `nn_util_lib.prepare_tuple_argument`. In Keras, this argument is
            called `kernel_size`.
          rank: An integer, the rank of the convolution, e.g. "2" for 2D
            convolution. Without batch dimensions this implies
            `kernel.shape.rank == rank + 2`.
            In Keras, this argument has the same name and semantics.
            Default value: `2`.
          strides: An integer or tuple/list of n integers, specifying the
            stride length of the convolution.
            In Keras, this argument has the same name and semantics.
            Default value: `1`.
          padding: One of `"VALID"` or `"SAME"` (case-insensitive).
            In Keras, this argument has the same name and semantics (except we
            don't support `"CAUSAL"`).
            Default value: `'VALID'`.
          dilations: An integer or tuple/list of `rank` integers, specifying
            the dilation rate to use for dilated convolution. Currently,
            specifying any `dilations` value != 1 is incompatible with
            specifying any `strides` value != 1.
            In Keras, this argument is called `dilation_rate`.
            Default value: `1`.
          init_kernel_fn: Forwarded to `make_kernel_bias_fn`.
            Default value: `None` (i.e.,
            `tfp.experimental.nn.initializers.glorot_uniform()`).
          init_bias_fn: Forwarded to `make_kernel_bias_fn`.
            Default value: `None` (i.e., `tf.initializers.zeros()`).
          make_kernel_bias_fn: Callable invoked as
            `make_kernel_bias_fn(kernel_shape, bias_shape, init_kernel_fn,
            init_bias_fn, batch_ndims, batch_ndims, dtype)` and returning the
            `(kernel, bias)` pair.
            Default value: `tfp.experimental.nn.util.make_kernel_bias`.
          dtype: Forwarded to `make_kernel_bias_fn` and to the parent
            constructor.
            Default value: `tf.float32`.
          batch_shape: Leading batch dimensions prepended to both the kernel
            and bias shapes; it is flattened to a vector. `None` is treated as
            "no batch dimensions".
            Default value: `()`.
          activation_fn: Forwarded to the parent constructor.
            Default value: `None`.
          name: Forwarded to the parent constructor.
            Default value: `None` (i.e., `'ConvolutionV2'`).
        """
        filter_shape = nn_util_lib.prepare_tuple_argument(
            filter_shape, rank, arg_name='filter_shape')
        batch_shape = (np.array([], dtype=np.int32) if batch_shape is None else
                       prefer_static.reshape(batch_shape, shape=[-1]))
        batch_ndims = prefer_static.size(batch_shape)
        if tf.get_static_value(batch_ndims) == 0:
            # In this branch, we statically know there are no batch dims.
            kernel_shape = filter_shape + (input_size, output_size)
            bias_shape = [output_size]
            apply_kernel_fn = _make_convolution_fn(rank, strides, padding,
                                                   dilations)
        else:
            # In this branch, there are either static/dynamic batch dims or
            # dynamically no batch dims.
            kernel_shape = prefer_static.concat(
                [batch_shape, filter_shape, [input_size, output_size]], axis=0)
            # One singleton dimension per spatial axis sits between the batch
            # dims and `output_size`. The ones must be integers: `tf.ones(rank)`
            # defaults to float32, which would put floats into a shape vector
            # and mix dtypes with the integer `batch_shape` in the concat.
            bias_shape = prefer_static.concat(
                [batch_shape, [1] * rank, [output_size]], axis=0)
            apply_kernel_fn = lambda x, k: nn_util_lib.convolution_batch(  # pylint: disable=g-long-lambda
                x,
                k,
                rank=rank,
                strides=strides,
                padding=padding,
                data_format='NBHWC',
                dilations=dilations)
        kernel, bias = make_kernel_bias_fn(kernel_shape, bias_shape,
                                           init_kernel_fn, init_bias_fn,
                                           batch_ndims, batch_ndims, dtype)
        self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
        super(ConvolutionV2, self).__init__(kernel=kernel,
                                            bias=bias,
                                            apply_kernel_fn=apply_kernel_fn,
                                            dtype=dtype,
                                            activation_fn=activation_fn,
                                            name=name)
    def __init__(
            self,
            input_size,
            output_size,  # keras::Conv::filters
            # Conv specific.
            filter_shape,  # keras::Conv::kernel_size
            rank=2,  # keras::Conv::rank
            strides=1,  # keras::Conv::strides
            padding='VALID',  # keras::Conv::padding; 'CAUSAL' not implemented.
            # keras::Conv::data_format is not implemented
            dilations=1,  # keras::Conv::dilation_rate
            # Weights
            init_kernel_fn=None,  # tfp.experimental.nn.initializers.glorot_uniform()
            init_bias_fn=None,  # tf.initializers.zeros()
            make_posterior_fn=nn_util_lib.make_kernel_bias_posterior_mvn_diag,
            make_prior_fn=nn_util_lib.make_kernel_bias_prior_spike_and_slab,
            posterior_value_fn=tfd.Distribution.sample,
            unpack_weights_fn=unpack_kernel_and_bias,
            dtype=tf.float32,
            # Penalty.
            penalty_weight=None,
            posterior_penalty_fn=kl_divergence_monte_carlo,
            # Misc
            activation_fn=None,
            seed=None,
            name=None):
        """Constructs layer.

        Note: `data_format` is not supported since all nn layers operate on
        the rightmost column. If your channel dimension is not rightmost, use
        `tf.transpose` before calling this layer. For example, if your channel
        dimension is second from the left, the following code will move it
        rightmost:

        ```python
        inputs = tf.transpose(inputs, tf.concat([
            [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
        ```

        Args:
          input_size: Size of the second from the rightmost dimension of the
            kernel. In Keras, this argument is inferred from the rightmost
            input shape, i.e., `tf.shape(inputs)[-1]`. This argument is
            required.
          output_size: Size of the rightmost dimension of the kernel and the
            only dimension of the bias. In Keras, this argument is called
            `filters`.
          filter_shape: Sizes of the leftmost `rank` dimensions of the kernel;
            normalized with `nn_util_lib.prepare_tuple_argument`. In Keras,
            this argument is called `kernel_size`.
          rank: An integer, the rank of the convolution, e.g. "2" for 2D
            convolution. This argument implies the number of kernel
            dimensions, i.e., `kernel.shape.rank == rank + 2`.
            In Keras, this argument has the same name and semantics.
            Default value: `2`.
          strides: An integer or tuple/list of n integers, specifying the
            stride length of the convolution.
            In Keras, this argument has the same name and semantics.
            Default value: `1`.
          padding: One of `"VALID"` or `"SAME"` (case-insensitive).
            In Keras, this argument has the same name and semantics (except we
            don't support `"CAUSAL"`).
            Default value: `'VALID'`.
          dilations: An integer or tuple/list of `rank` integers, specifying
            the dilation rate to use for dilated convolution. Currently,
            specifying any `dilations` value != 1 is incompatible with
            specifying any `strides` value != 1.
            In Keras, this argument is called `dilation_rate`.
            Default value: `1`.
          init_kernel_fn: Forwarded to `make_posterior_fn` and `make_prior_fn`.
            Default value: `None` (i.e.,
            `tfp.experimental.nn.initializers.glorot_uniform()`).
          init_bias_fn: Forwarded to `make_posterior_fn` and `make_prior_fn`.
            Default value: `None` (i.e., `tf.initializers.zeros()`).
          make_posterior_fn: Callable invoked with the kernel shape, the bias
            shape, both initializers, the batch ndims (twice) and `dtype`; its
            result is used as the layer's posterior.
            Default value:
              `tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag`.
          make_prior_fn: Callable invoked with the same arguments as
            `make_posterior_fn`; its result is used as the layer's prior.
            Default value:
              `tfp.experimental.nn.util.make_kernel_bias_prior_spike_and_slab`.
          posterior_value_fn: Forwarded to the parent constructor.
            Default value: `tfd.Distribution.sample`.
          unpack_weights_fn: Forwarded to the parent constructor.
            Default value: `unpack_kernel_and_bias`.
          dtype: Forwarded to `make_posterior_fn`, `make_prior_fn` and the
            parent constructor.
            Default value: `tf.float32`.
          penalty_weight: Forwarded to the parent constructor.
            Default value: `None` (i.e., weight is `1`).
          posterior_penalty_fn: Forwarded to the parent constructor.
            Default value: `kl_divergence_monte_carlo`.
          activation_fn: Forwarded to the parent constructor.
            Default value: `None`.
          seed: Forwarded to the parent constructor.
            Default value: `None` (i.e., no seed).
          name: Forwarded to the parent constructor.
            Default value: `None` (i.e.,
            `'ConvolutionVariationalFlipoutV2'`).
        """
        filter_shape = nn_util_lib.prepare_tuple_argument(
            filter_shape, rank, arg_name='filter_shape')
        kernel_shape = filter_shape + (input_size, output_size)
        # Keep references to the factories so their variables are tracked.
        self._make_posterior_fn = make_posterior_fn
        self._make_prior_fn = make_prior_fn
        # This layer never carries batch dimensions on kernel or bias.
        batch_ndims = 0
        weight_args = (init_kernel_fn, init_bias_fn, batch_ndims, batch_ndims,
                       dtype)
        posterior = make_posterior_fn(kernel_shape, [output_size],
                                      *weight_args)
        prior = make_prior_fn(kernel_shape, [output_size], *weight_args)
        apply_kernel_fn = _make_convolution_fn(rank, strides, padding,
                                               dilations)
        super(ConvolutionVariationalFlipoutV2, self).__init__(
            posterior=posterior,
            prior=prior,
            apply_kernel_fn=apply_kernel_fn,
            posterior_value_fn=posterior_value_fn,
            unpack_weights_fn=unpack_weights_fn,
            dtype=dtype,
            penalty_weight=penalty_weight,
            posterior_penalty_fn=posterior_penalty_fn,
            activation_fn=activation_fn,
            seed=seed,
            name=name)
    def __init__(
            self,
            input_size,
            output_size,  # keras::Conv::filters
            # Conv specific.
            filter_shape,  # keras::Conv::kernel_size
            rank=2,  # keras::Conv::rank
            strides=1,  # keras::Conv::strides
            padding='VALID',  # keras::Conv::padding; 'CAUSAL' not implemented.
            # keras::Conv::data_format is not implemented
            dilations=1,  # keras::Conv::dilation_rate
            output_padding=None,  # keras::ConvTranspose::output_padding
            # Weights
            init_kernel_fn=None,  # tfp.experimental.nn.initializers.glorot_uniform()
            init_bias_fn=None,  # tf.initializers.zeros()
            make_kernel_bias_fn=nn_util_lib.make_kernel_bias,
            dtype=tf.float32,
            # Misc
            activation_fn=None,
            name=None):
        """Constructs layer.

        Note: `data_format` is not supported since all nn layers operate on
        the rightmost column. If your channel dimension is not rightmost, use
        `tf.transpose` before calling this layer. For example, if your channel
        dimension is second from the left, the following code will move it
        rightmost:

        ```python
        inputs = tf.transpose(inputs, tf.concat([
            [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
        ```

        Args:
          input_size: Size of the rightmost dimension of `kernel` (the kernel
            is stored transposed, i.e., with shape
            `filter_shape + (output_size, input_size)`). In Keras, this
            argument is inferred from the rightmost input shape, i.e.,
            `tf.shape(inputs)[-1]`. This argument is required.
          output_size: Size of the second from the rightmost dimension of
            `kernel` and of the only dimension of `bias`. In Keras, this
            argument is called `filters`.
          filter_shape: Sizes of the leftmost `rank` dimensions of `kernel`;
            normalized with `nn_util_lib.prepare_tuple_argument`. In Keras,
            this argument is called `kernel_size`.
          rank: An integer, the rank of the convolution, e.g. "2" for 2D
            convolution. This argument implies the number of `kernel`
            dimensions, i.e., `kernel.shape.rank == rank + 2`.
            In Keras, this argument has the same name and semantics.
            Default value: `2`.
          strides: An integer or tuple/list of n integers, specifying the
            stride length of the convolution.
            In Keras, this argument has the same name and semantics.
            Default value: `1`.
          padding: One of `"VALID"` or `"SAME"` (case-insensitive).
            In Keras, this argument has the same name and semantics (except we
            don't support `"CAUSAL"`).
            Default value: `'VALID'`.
          dilations: An integer or tuple/list of `rank` integers, specifying
            the dilation rate to use for dilated convolution. Currently,
            specifying any `dilations` value != 1 is incompatible with
            specifying any `strides` value != 1.
            In Keras, this argument is called `dilation_rate`.
            Default value: `1`.
          output_padding: An `int` or length-`rank` tuple/list representing
            the amount of padding along the input spatial dimensions (e.g.,
            depth, height, width). A single `int` indicates the same value for
            all spatial dimensions. The amount of output padding along a given
            dimension must be lower than the stride along that same dimension.
            If set to `None` (default), the output shape is inferred.
            In Keras, this argument has the same name and semantics.
            Default value: `None` (i.e., inferred).
          init_kernel_fn: Forwarded to `make_kernel_bias_fn`.
            Default value: `None` (i.e.,
            `tfp.experimental.nn.initializers.glorot_uniform()`).
          init_bias_fn: Forwarded to `make_kernel_bias_fn`.
            Default value: `None` (i.e., `tf.initializers.zeros()`).
          make_kernel_bias_fn: Callable invoked with the kernel shape, the
            bias shape, both initializers, the batch ndims (twice) and
            `dtype`; returns the `(kernel, bias)` pair.
            Default value: `tfp.experimental.nn.util.make_kernel_bias`.
          dtype: Forwarded to `make_kernel_bias_fn` and to the parent
            constructor.
            Default value: `tf.float32`.
          activation_fn: Forwarded to the parent constructor.
            Default value: `None`.
          name: Forwarded to the parent constructor.
            Default value: `None` (i.e., `'ConvolutionTranspose'`).
        """
        filter_shape = nn_util_lib.prepare_tuple_argument(
            filter_shape, rank, 'filter_shape')
        # Note transpose: the channel pair is (output, input), the reverse of
        # the ordering used by the non-transposed convolution layers.
        channel_dims = (output_size, input_size)
        kernel_shape = filter_shape + channel_dims
        # This layer never carries batch dimensions on kernel or bias.
        batch_ndims = 0
        kernel, bias = make_kernel_bias_fn(kernel_shape, [output_size],
                                           init_kernel_fn, init_bias_fn,
                                           batch_ndims, batch_ndims, dtype)
        apply_kernel_fn = _make_convolution_transpose_fn(
            rank, strides, padding, dilations, filter_shape, output_size,
            output_padding)
        super(ConvolutionTranspose, self).__init__(
            kernel=kernel,
            bias=bias,
            apply_kernel_fn=apply_kernel_fn,
            dtype=dtype,
            activation_fn=activation_fn,
            name=name)