예제 #1
0
    def __init__(
        self,
        num_spatial_dims: int,
        output_channels: int,
        kernel_shape: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "SAME",
        with_bias: bool = True,
        w_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
        data_format: str = "channels_last",
        mask: Optional[jnp.ndarray] = None,
        name: Optional[str] = None,
    ):
        """Initializes the transposed convolution module.

        Args:
          num_spatial_dims: The number of spatial dimensions of the input.
            Must be greater than 0.
          output_channels: Number of output channels.
          kernel_shape: The shape of the kernel. Either an integer or a sequence
            of length ``num_spatial_dims``.
          stride: Optional stride for the kernel. Either an integer or a
            sequence of length ``num_spatial_dims``. Defaults to 1.
          padding: Optional padding algorithm. Either ``VALID``, ``SAME`` or a
            sequence of ``(low, high)`` integer pairs. Stored as given.
            Defaults to ``SAME``. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          with_bias: Whether to add a bias. By default, true.
          w_init: Optional weight initialization. By default, truncated normal.
          b_init: Optional bias initialization. By default, zeros.
          data_format: The data format of the input. Can be either
            ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
            default, ``channels_last``.
          mask: Optional mask of the weights.
          name: The name of the module.

        Raises:
          ValueError: If ``num_spatial_dims`` is not positive.
        """
        super().__init__(name=name)

        if num_spatial_dims <= 0:
            raise ValueError(
                "We only support convolution operations for `num_spatial_dims` "
                f"greater than 0, received num_spatial_dims={num_spatial_dims}."
            )

        self.num_spatial_dims = num_spatial_dims
        self.output_channels = output_channels
        # Scalars are broadcast to one entry per spatial dimension.
        self.kernel_shape = utils.replicate(kernel_shape, num_spatial_dims,
                                            "kernel_shape")
        self.with_bias = with_bias
        self.stride = utils.replicate(stride, num_spatial_dims, "strides")
        self.w_init = w_init
        # A missing bias initializer falls back to zeros.
        self.b_init = b_init or jnp.zeros
        self.mask = mask
        # TODO(tomhennigan) Make use of hk.pad.create here?
        self.padding = padding
        self.data_format = data_format
        self.channel_index = utils.get_channel_index(data_format)
        # transpose=True selects the dimension numbers for a transposed conv.
        self.dimension_numbers = to_dimension_numbers(
            num_spatial_dims,
            channels_last=(self.channel_index == -1),
            transpose=True)
예제 #2
0
 def testIncorrectLength(self):
     """A length-2 list cannot be stretched to 3 entries and must be rejected."""
     two_items = [2, 2]
     expected_msg = (
         r"must be a scalar or sequence of length 1 or sequence of length 3")
     with self.assertRaisesRegex(TypeError, expected_msg):
         utils.replicate(two_items, 3, "value")
예제 #3
0
파일: pad.py 프로젝트: stjordanis/dm-haiku
def create_from_padfn(
    padding: Union[hk.pad.PadFn, Sequence[hk.pad.PadFn]],
    kernel: Union[int, Sequence[int]],
    rate: Union[int, Sequence[int]],
    n: int,
) -> Sequence[Tuple[int, int]]:
    """Generates the padding required for a given padding algorithm.

    Args:
      padding: A callable or a sequence of callables. Each callable takes an
        integer representing the effective kernel size (the kernel size when
        the rate is 1) and returns a sequence of two integers: the padding
        before and the padding after for that dimension. If ``padding`` is a
        sequence it must be of length 1 or ``n``. Plain ``(before, after)``
        tuples are not accepted here, because every entry is called.
      kernel: int or sequence of ints of length ``n``. The size of the kernel
        for each dimension. If it is an int it will be replicated for the non
        channel and batch dimensions.
      rate: int or sequence of ints of length ``n``. The dilation rate for each
        dimension. If it is an int it will be replicated for the non channel
        and batch dimensions.
      n: the number of spatial dimensions.

    Returns:
      A tuple of length ``n`` holding, for each spatial dimension, the value
      returned by the corresponding callable (expected to be of the form
      ``(pad_before, pad_after)``).

    Raises:
      TypeError: If an entry of ``padding`` is not callable, or (raised by
        ``utils.replicate``) if a sequence argument has the wrong length.
    """
    # Replicate in the same order as before (kernel, rate, padding) so that
    # argument errors surface in a stable order.
    kernels = utils.replicate(kernel, n, "kernel")
    rates = utils.replicate(rate, n, "rate")
    pad_fns = utils.replicate(padding, n, "padding")

    paddings = []
    for dim, (pad_fn, k, r) in enumerate(zip(pad_fns, kernels, rates)):
        if not callable(pad_fn):
            raise TypeError(
                f"padding for dimension {dim} must be a callable taking the "
                f"effective kernel size, got {pad_fn!r}.")
        # The effective kernel size includes any holes/gaps introduced by the
        # dilation rate. It's equal to kernel_size when rate == 1.
        effective_kernel_size = (k - 1) * r + 1
        paddings.append(pad_fn(effective_kernel_size))
    return tuple(paddings)
예제 #4
0
def create(
    padding: Paddings,
    kernel: Union[int, Sequence[int]],
    rate: Union[int, Sequence[int]],
    n: int,
):
    """Generates the padding required for a given padding algorithm.

    Args:
      padding: callable or sequence of callables of length 1 or n. The
        callables take an integer representing the effective kernel size
        (kernel size when the rate is 1) and return a list of two integers
        representing the padding before and padding after for that dimension.
      kernel: int or list of ints of length n. The size of the kernel for each
        dimension. If it is an int it will be replicated for the non channel and
        batch dimensions.
      rate: int or list of ints of length n. The dilation rate for each
        dimension. If it is an int it will be replicated for the non channel
        and batch dimensions.
      n: the number of spatial dimensions.

    Returns:
      A tuple of length n containing the padding for each element, as returned
      by the corresponding callable (expected form ``[pad_before, pad_after]``).

    Raises:
      TypeError: If an entry of ``padding`` is not callable, or (raised by
        ``utils.replicate``) if a sequence argument has the wrong length.
    """
    # Same replication order as before: kernel, rate, then padding.
    kernels = utils.replicate(kernel, n, "kernel")
    rates = utils.replicate(rate, n, "rate")
    pad_fns = utils.replicate(padding, n, "padding")

    paddings = []
    for dim, (pad_fn, k, r) in enumerate(zip(pad_fns, kernels, rates)):
        if not callable(pad_fn):
            raise TypeError(
                f"padding for dimension {dim} must be a callable taking the "
                f"effective kernel size, got {pad_fn!r}.")
        # The effective kernel size includes any holes/gaps introduced by the
        # dilation rate. It's equal to kernel_size when rate == 1.
        effective_kernel_size = (k - 1) * r + 1
        paddings.append(pad_fn(effective_kernel_size))

    return tuple(paddings)
예제 #5
0
  def __init__(self,
               num_spatial_dims,
               output_channels,
               kernel_shape,
               stride=1,
               padding="SAME",
               with_bias=True,
               w_init=None,
               b_init=None,
               data_format="channels_last",
               mask=None,
               name=None):
    """Initializes a ConvNDTranspose module.

    Args:
      num_spatial_dims: The number of spatial dimensions of the input. Must be
        1, 2 or 3.
      output_channels: Number of output channels.
      kernel_shape: The shape of the kernel. Either an integer or a sequence of
        length `num_spatial_dims`.
      stride: Optional stride for the kernel. Either an integer or a sequence of
        length `num_spatial_dims`. Defaults to 1.
      padding: Optional padding algorithm. Either "VALID" or "SAME". Stored as
        given. Defaults to "SAME". See:
        https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      with_bias: Whether to add a bias. By default, true.
      w_init: Optional weight initialization. By default, truncated normal.
      b_init: Optional bias initialization. By default, zeros.
      data_format: The data format of the input. Can be either
        `channels_first`, `channels_last`, `N...C` or `NC...`. By default,
        `channels_last`.
      mask: Optional mask of the weights.
      name: The name of the module.

    Raises:
      ValueError: If `num_spatial_dims` is not 1, 2 or 3.
    """
    super(ConvNDTranspose, self).__init__(name=name)
    # Only 1-, 2- and 3-D dimension-number tables are available.
    if not 1 <= num_spatial_dims <= 3:
      raise ValueError(
          "We only support convolution operations for num_spatial_dims=1, 2 or "
          "3, received num_spatial_dims={}.".format(num_spatial_dims))
    self._num_spatial_dims = num_spatial_dims
    self._output_channels = output_channels
    self._kernel_shape = utils.replicate(kernel_shape, num_spatial_dims,
                                         "kernel_shape")
    self._with_bias = with_bias
    self._stride = utils.replicate(stride, num_spatial_dims, "strides")
    self._w_init = w_init
    # A missing bias initializer falls back to zeros.
    self._b_init = b_init or jnp.zeros
    self._mask = mask
    self._padding = padding

    self._data_format = data_format
    self._channel_index = utils.get_channel_index(data_format)
    # Channel index -1 means channels-last; otherwise use the NC... tables.
    if self._channel_index == -1:
      self._dn = DIMENSION_NUMBERS[self._num_spatial_dims]
    else:
      self._dn = DIMENSION_NUMBERS_NCSPATIAL[self._num_spatial_dims]
예제 #6
0
    def __init__(
        self,
        channel_multiplier: int,
        kernel_shape: Union[int, Sequence[int]],
        num_spatial_dims: int,
        data_format: str,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "SAME",
        with_bias: bool = True,
        w_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
        name: Optional[str] = None,
    ):
        """Construct an ND Depthwise Convolution.

        Args:
          channel_multiplier: Multiplicity of output channels. To keep the
            number of output channels the same as the number of input channels,
            set 1.
          kernel_shape: The shape of the kernel. Either an integer or a sequence
            of length ``num_spatial_dims``.
          num_spatial_dims: The number of spatial dimensions of the input data.
          data_format: The data format of the input. Can be either
            ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``.
            Required; this argument has no default. See
            :func:`get_channel_index`.
          stride: Optional stride for the kernel. Either an integer or a
            sequence of length ``num_spatial_dims``. Defaults to 1.
          padding: Optional padding algorithm. Either ``VALID``, ``SAME`` or a
            sequence of ``before, after`` pairs. Defaults to ``SAME``. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          with_bias: Whether to add a bias. By default, true.
          w_init: Optional weight initialization. By default, truncated normal.
          b_init: Optional bias initialization. By default, zeros.
          name: The name of the module.
        """
        super().__init__(name=name)
        self.num_spatial_dims = num_spatial_dims
        self.kernel_shape = utils.replicate(kernel_shape,
                                            self.num_spatial_dims,
                                            "kernel_shape")
        # No dilation is exposed by this module: both input (lhs) and kernel
        # (rhs) dilation are fixed to 1 in every spatial dimension.
        self.lhs_dilation = (1, ) * len(self.kernel_shape)
        self.rhs_dilation = (1, ) * len(self.kernel_shape)
        self.channel_multiplier = channel_multiplier
        self.padding = padding
        self.stride = utils.replicate(stride, self.num_spatial_dims, "strides")
        self.data_format = data_format
        self.channel_index = hk.get_channel_index(data_format)
        self.with_bias = with_bias
        self.w_init = w_init
        # A missing bias initializer falls back to zeros.
        self.b_init = b_init or jnp.zeros
        # Channel index -1 means channels-last; otherwise use the NC... tables.
        if self.channel_index == -1:
            self.dn = DIMENSION_NUMBERS[self.num_spatial_dims]
        else:
            self.dn = DIMENSION_NUMBERS_NCSPATIAL[self.num_spatial_dims]
예제 #7
0
    def __init__(self,
                 channel_multiplier,
                 kernel_shape,
                 stride=1,
                 padding="SAME",
                 with_bias=True,
                 w_init=None,
                 b_init=None,
                 data_format="NHWC",
                 name=None):
        """Construct a 2D Depthwise Convolution.

        Args:
          channel_multiplier: Multiplicity of output channels. To keep the
            number of output channels the same as the number of input channels,
            set 1.
          kernel_shape: The shape of the kernel. Either an integer or a sequence
            of length 2.
          stride: Optional stride for the kernel. Either an integer or a
            sequence of length 2. Defaults to 1.
          padding: Optional padding algorithm. Either "VALID" or "SAME" or
            a callable or sequence of callables of size 2. Any callables must
            take a single integer argument equal to the effective kernel size
            and return a list of two integers representing the padding before
            and after. See haiku.pad.* for more details and example functions.
            Defaults to "SAME". See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
            NOTE(review): the value is stored as given here; any handling of
            callables presumably happens at call time -- confirm.
          with_bias: Whether to add a bias. By default, true.
          w_init: Optional weight initialization. By default, truncated normal.
          b_init: Optional bias initialization. By default, zeros.
          data_format: The data format of the input. Can be either
            `channels_first`, `channels_last`, `N...C` or `NC...`. Defaults to
            `NHWC`, an `N...C` (channels-last) layout.
          name: The name of the module.
        """
        super(DepthwiseConv2D, self).__init__(name=name)
        # This module is fixed to two spatial dimensions.
        num_spatial_dims = 2
        self._kernel_shape = utils.replicate(kernel_shape, num_spatial_dims,
                                             "kernel_shape")
        # No dilation is exposed: input (lhs) and kernel (rhs) dilation are 1.
        self._lhs_dilation = (1, ) * len(self._kernel_shape)
        self._rhs_dilation = (1, ) * len(self._kernel_shape)
        self._channel_multiplier = channel_multiplier
        self._padding = padding
        self._stride = utils.replicate(stride, num_spatial_dims, "strides")
        self._data_format = data_format
        self._channel_index = utils.get_channel_index(data_format)
        self._with_bias = with_bias
        self._w_init = w_init
        # A missing bias initializer falls back to zeros.
        self._b_init = b_init or jnp.zeros
        self._num_spatial_dims = num_spatial_dims
예제 #8
0
    def __init__(
        self,
        channel_multiplier: int,
        kernel_shape: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "SAME",
        with_bias: bool = True,
        w_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
        data_format: str = "NHWC",
        name: Optional[str] = None,
    ):
        """Construct a Separable 2D Depthwise Convolution module.

        The kernel is factorized into two depthwise convolutions: one with a
        ``[kh, 1]`` kernel followed by one with a ``[1, kw]`` kernel.

        Args:
          channel_multiplier: Multiplicity of output channels. To keep the
            number of output channels the same as the number of input channels,
            set 1.
          kernel_shape: The shape of the kernel. Either an integer or a sequence
            of length 2.
          stride: Optional stride for the kernel. Either an integer or a
            sequence of length 2. Defaults to 1. Applied by the first of the two
            factorized convolutions only.
          padding: Optional padding algorithm. Either ``VALID``, ``SAME`` or a
            sequence of ``before, after`` pairs. Defaults to ``SAME``. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          with_bias: Whether to add a bias. By default, true. Only the second
            convolution carries a bias.
          w_init: Optional weight initialization. By default, truncated normal.
          b_init: Optional bias initialization. By default, zeros.
          data_format: The data format of the input. Can be either
            ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``.
            Defaults to ``NHWC``, an ``N...C`` (channels-last) layout.
          name: The name of the module.
        """
        super().__init__(name=name)
        kernel_shape = utils.replicate(kernel_shape, 2, "kernel_shape")
        # First stage: convolve along the first spatial axis; this stage
        # applies the channel multiplier and the stride, and has no bias.
        self._conv1 = DepthwiseConv2D(channel_multiplier=channel_multiplier,
                                      kernel_shape=[kernel_shape[0], 1],
                                      stride=stride,
                                      padding=padding,
                                      with_bias=False,
                                      w_init=w_init,
                                      b_init=b_init,
                                      data_format=data_format)

        # Second stage: convolve along the second spatial axis, keeping the
        # channel count (multiplier 1, stride 1) and adding the optional bias.
        self._conv2 = DepthwiseConv2D(channel_multiplier=1,
                                      kernel_shape=[1, kernel_shape[1]],
                                      stride=1,
                                      padding=padding,
                                      with_bias=with_bias,
                                      w_init=w_init,
                                      b_init=b_init,
                                      data_format=data_format)
예제 #9
0
 def testListLengthOne(self, value):
     """A one-element list is broadcast to all three positions."""
     replicated = utils.replicate([value], 3, "value")
     self.assertLen(replicated, 3)
     self.assertEqual(replicated, tuple([value] * 3))
예제 #10
0
 def testSingleValue(self, value):
     """A bare scalar is repeated once per requested position."""
     expected = tuple(value for _ in range(3))
     replicated = utils.replicate(value, 3, "value")
     self.assertLen(replicated, 3)
     self.assertEqual(replicated, expected)
예제 #11
0
 def testListLengthN(self, value):
     """A list that already has the target length comes back as a tuple."""
     three_items = [value] * 3
     replicated = utils.replicate(three_items, 3, "value")
     self.assertLen(replicated, 3)
     self.assertEqual(replicated, tuple([value] * 3))
예제 #12
0
 def testTupleLengthN(self, value):
     """A tuple that already has the target length is returned unchanged."""
     triple = tuple([value] * 3)
     replicated = utils.replicate(triple, 3, "value")
     self.assertLen(replicated, 3)
     self.assertEqual(replicated, triple)
예제 #13
0
  def __init__(self,
               num_spatial_dims,
               output_channels,
               kernel_shape,
               stride=1,
               rate=1,
               padding="SAME",
               with_bias=True,
               w_init=None,
               b_init=None,
               data_format="channels_last",
               mask=None,
               name=None):
    """Constructs a `ConvND` module.

    Args:
      num_spatial_dims: How many spatial axes the input has; must be 1, 2 or 3.
      output_channels: Number of channels produced by the convolution.
      kernel_shape: Kernel extent, as an integer (used for every spatial axis)
        or a sequence of length `num_spatial_dims`.
      stride: Optional kernel stride, as an integer or a sequence of length
        `num_spatial_dims`. Defaults to 1.
      rate: Optional kernel dilation rate, as an integer or a sequence of
        length `num_spatial_dims`. A rate of 1 gives a standard ND convolution;
        `rate > 1` gives a dilated convolution. Defaults to 1.
      padding: Optional padding algorithm. Either "VALID" or "SAME" (strings
        are upper-cased), or a callable / sequence of callables of size
        `num_spatial_dims`. Callables receive the effective kernel size and
        return the padding before and after; they are resolved eagerly via
        `pad.create`. See haiku.pad.* for example functions. Defaults to
        "SAME". See:
        https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      with_bias: Whether a bias is added. Defaults to true.
      w_init: Optional weight initializer. By default, truncated normal.
      b_init: Optional bias initializer. By default, zeros.
      data_format: Input layout: `channels_first`, `channels_last`, `N...C`
        or `NC...`. Defaults to `channels_last`.
      mask: Optional mask applied to the weights.
      name: Module name.

    Raises:
      ValueError: If `num_spatial_dims` is outside [1, 3].
    """
    super(ConvND, self).__init__(name=name)

    if not 1 <= num_spatial_dims <= 3:
      raise ValueError(
          "We only support convolution operations for num_spatial_dims=1, 2 or "
          "3, received num_spatial_dims={}.".format(num_spatial_dims))

    ndims = num_spatial_dims
    self._num_spatial_dims = ndims
    self._output_channels = output_channels
    self._kernel_shape = utils.replicate(kernel_shape, ndims, "kernel_shape")
    self._with_bias = with_bias
    self._stride = utils.replicate(stride, ndims, "strides")
    self._w_init = w_init
    # Fall back to a zeros initializer when no bias initializer is supplied.
    self._b_init = b_init or jnp.zeros
    self._mask = mask
    self._lhs_dilation = utils.replicate(1, ndims, "lhs_dilation")
    self._kernel_dilation = utils.replicate(rate, ndims, "kernel_dilation")
    self._data_format = data_format
    self._channel_index = utils.get_channel_index(data_format)
    dn_table = (DIMENSION_NUMBERS if self._channel_index == -1
                else DIMENSION_NUMBERS_NCSPATIAL)
    self._dn = dn_table[ndims]

    # String algorithms are normalized; padding callables are turned into
    # explicit (before, after) pairs right away.
    if not isinstance(padding, str):
      resolved_padding = pad.create(
          padding=padding,
          kernel=self._kernel_shape,
          rate=self._kernel_dilation,
          n=ndims)
    else:
      resolved_padding = padding.upper()
    self._padding = resolved_padding
예제 #14
0
파일: conv.py 프로젝트: sooheon/elegy
    def __init__(
        self,
        num_spatial_dims: int,
        output_channels: int,
        kernel_shape: tp.Union[int, tp.Sequence[int]],
        stride: tp.Union[int, tp.Sequence[int]] = 1,
        rate: tp.Union[int, tp.Sequence[int]] = 1,
        padding: tp.Union[str, tp.Sequence[tp.Tuple[int, int]],
                          types.PadFnOrFns] = "SAME",
        with_bias: bool = True,
        w_init: tp.Optional[types.Initializer] = None,
        b_init: tp.Optional[types.Initializer] = None,
        data_format: str = "channels_last",
        mask: tp.Optional[np.ndarray] = None,
        groups: int = 1,
        **kwargs,
    ):
        """
        Initializes the module.

        Args:
            num_spatial_dims: The number of spatial dimensions of the input.
                Must be greater than 0.
            output_channels: Number of output channels.
            kernel_shape: The shape of the kernel. Either an integer or a
                sequence of length ``num_spatial_dims``.
            stride: Optional stride for the kernel. Either an integer or a
                sequence of length ``num_spatial_dims``. Defaults to 1.
            rate: Optional kernel dilation rate. Either an integer or a sequence
                of length ``num_spatial_dims``. 1 corresponds to standard ND
                convolution, ``rate > 1`` corresponds to dilated convolution.
                Defaults to 1.
            padding: Optional padding algorithm. Either ``VALID`` or ``SAME``
                (strings are upper-cased), a sequence of n ``(low, high)``
                integer pairs that give the padding to apply before and after
                each spatial dimension, or a callable or sequence of callables
                of size ``num_spatial_dims``. Any callables must take a single
                integer argument equal to the effective kernel size and return a
                sequence of two integers representing the padding before and
                after. See ``haiku.pad.*`` for more details and example
                functions. Defaults to ``SAME``. See:
                https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
            with_bias: Whether to add a bias. By default, true.
            w_init: Optional weight initialization. By default, truncated
                normal.
            b_init: Optional bias initialization. By default, zeros.
            data_format: The data format of the input. Can be either
                ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``.
                By default, ``channels_last``.
            mask: Optional mask of the weights.
            groups: A positive integer specifying the number of groups in which
                the input is split along the channel axis. Each group is
                convolved separately with ``output_channels / groups`` filters.
                The output is the concatenation of all the groups' results
                along the channel axis. Input channels and filters must both be
                divisible by ``groups``.
            kwargs: Additional keyword arguments passed to Module.

        Raises:
            ValueError: If ``num_spatial_dims`` or ``groups`` is not positive.
        """
        super().__init__(**kwargs)
        if num_spatial_dims <= 0:
            raise ValueError(
                "We only support convolution operations for `num_spatial_dims` "
                f"greater than 0, received num_spatial_dims={num_spatial_dims}."
            )
        # Enforce the documented contract early instead of failing deep inside
        # the convolution call.
        if groups <= 0:
            raise ValueError(
                f"`groups` must be a positive integer, received groups={groups}."
            )

        self.num_spatial_dims = num_spatial_dims
        self.output_channels = output_channels
        self.kernel_shape = hk_utils.replicate(kernel_shape, num_spatial_dims,
                                               "kernel_shape")
        self.with_bias = with_bias
        self.stride = hk_utils.replicate(stride, num_spatial_dims, "strides")
        self.w_init = w_init
        # A missing bias initializer falls back to zeros.
        self.b_init = b_init or jnp.zeros
        self.mask = mask
        # Input (lhs) dilation is fixed to 1; `rate` controls kernel dilation.
        self.lhs_dilation = hk_utils.replicate(1, num_spatial_dims,
                                               "lhs_dilation")
        self.kernel_dilation = hk_utils.replicate(rate, num_spatial_dims,
                                                  "kernel_dilation")
        self.data_format = data_format
        self.channel_index = hk_utils.get_channel_index(data_format)
        self.dimension_numbers = to_dimension_numbers(
            num_spatial_dims,
            channels_last=(self.channel_index == -1),
            transpose=False)
        self.groups = groups

        if isinstance(padding, str):
            self.padding = padding.upper()
        else:
            # Resolve pairs/callables into explicit per-dimension paddings now.
            self.padding = hk.pad.create(
                padding=padding,
                kernel=self.kernel_shape,
                rate=self.kernel_dilation,
                n=self.num_spatial_dims,
            )
예제 #15
0
    def __init__(
        self,
        output_channels: int,
        uniform_init_minval: float,
        uniform_init_maxval: float,
        kernel_shape: Union[int, Sequence[int]],
        num_spatial_dims: int = 2,
        stride: Union[int, Sequence[int]] = 1,
        rate: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn,
                       Sequence[hk.pad.PadFn]] = "SAME",
        with_bias: bool = True,
        w_init: Union[Optional[hk.initializers.Initializer], str] = "uniform",
        b_init: Union[Optional[hk.initializers.Initializer], str] = "uniform",
        data_format: str = "channels_last",
        mask: Optional[jnp.ndarray] = None,
        feature_group_count: int = 1,
        name: Optional[str] = None,
        stochastic_parameters: bool = False,
    ):
        """Initializes the module.

        Args:
          output_channels: Number of output channels.
          uniform_init_minval: TODO(nband). Presumably the lower bound used by
            the ``"uniform"`` initializer option -- confirm.
          uniform_init_maxval: TODO(nband). Presumably the upper bound used by
            the ``"uniform"`` initializer option -- confirm.
          kernel_shape: The shape of the kernel. Either an integer or a sequence
            of length ``num_spatial_dims``.
          num_spatial_dims: The number of spatial dimensions of the input. Must
            be greater than 0. Defaults to 2.
          stride: Optional stride for the kernel. Either an integer or a
            sequence of length ``num_spatial_dims``. Defaults to 1.
          rate: Optional kernel dilation rate. Either an integer or a sequence
            of length ``num_spatial_dims``. 1 corresponds to standard ND
            convolution, ``rate > 1`` corresponds to dilated convolution.
            Defaults to 1.
          padding: Optional padding algorithm. Either ``VALID`` or ``SAME``
            (strings are upper-cased), a sequence of n ``(low, high)`` integer
            pairs that give the padding to apply before and after each spatial
            dimension, or a callable or sequence of callables of size
            ``num_spatial_dims``. Any callables must take a single integer
            argument equal to the effective kernel size and return a sequence
            of two integers representing the padding before and after. See
            ``haiku.pad.*`` for more details and example functions. Defaults to
            ``SAME``. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          with_bias: Whether to add a bias. By default, true.
          w_init: Optional weight initialization, or a string option. Defaults
            to the string ``"uniform"``; stored as given here.
          b_init: Optional bias initialization, or a string option. Defaults to
            the string ``"uniform"``; stored as given here (no zeros fallback).
          data_format: The data format of the input. Can be either
            ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
            default, ``channels_last``.
          mask: Optional mask of the weights.
          feature_group_count: Optional number of groups in group convolution.
            Default value of 1 corresponds to normal dense convolution. If a
            higher value is used, convolutions are applied separately to that
            many groups, then stacked together. This reduces the number of
            parameters and possibly the compute for a given
            ``output_channels``. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          name: The name of the module.
          stochastic_parameters: TODO(nband).

        Raises:
          ValueError: If ``num_spatial_dims`` is not positive.
        """
        super().__init__(name=name)
        if num_spatial_dims <= 0:
            raise ValueError(
                "We only support convolution operations for `num_spatial_dims` "
                f"greater than 0, received num_spatial_dims={num_spatial_dims}."
            )

        self.num_spatial_dims = num_spatial_dims
        self.output_channels = output_channels
        self.kernel_shape = utils.replicate(kernel_shape, num_spatial_dims,
                                            "kernel_shape")
        self.with_bias = with_bias
        self.stride = utils.replicate(stride, num_spatial_dims, "strides")
        self.w_init = w_init
        # Unlike the plain conv modules, b_init is stored verbatim so that the
        # "uniform" string option is preserved.
        self.b_init = b_init
        self.uniform_init_minval = uniform_init_minval
        self.uniform_init_maxval = uniform_init_maxval
        self.mask = mask
        self.feature_group_count = feature_group_count
        # Input (lhs) dilation is fixed to 1; `rate` controls kernel dilation.
        self.lhs_dilation = utils.replicate(1, num_spatial_dims,
                                            "lhs_dilation")
        self.kernel_dilation = utils.replicate(rate, num_spatial_dims,
                                               "kernel_dilation")
        self.data_format = data_format
        self.channel_index = utils.get_channel_index(data_format)
        self.dimension_numbers = to_dimension_numbers(
            num_spatial_dims,
            channels_last=(self.channel_index == -1),
            transpose=False)
        self.stochastic_parameters = stochastic_parameters

        if isinstance(padding, str):
            self.padding = padding.upper()
        else:
            # Resolve pairs/callables into explicit per-dimension paddings now.
            self.padding = hk.pad.create(
                padding=padding,
                kernel=self.kernel_shape,
                rate=self.kernel_dilation,
                n=self.num_spatial_dims,
            )