示例#1
0
 def test_padding_reverse_causal(self):
     """Checks the padding `pad.create` builds from `pad.reverse_causal`."""
     # Kernel 4, rate 3, 2 spatial dims: expect (0, 9) for each dimension.
     expected = ((0, 9), (0, 9))
     actual = pad.create(padding=pad.reverse_causal, kernel=4, rate=3, n=2)
     self.assertEqual(actual, expected)
示例#2
0
  def __init__(self,
               num_spatial_dims,
               output_channels,
               kernel_shape,
               stride=1,
               rate=1,
               padding="SAME",
               with_bias=True,
               w_init=None,
               b_init=None,
               data_format="channels_last",
               mask=None,
               name=None):
    """Initializes an N-dimensional convolution module.

    Args:
      num_spatial_dims: Number of spatial dimensions of the input. Must be 1,
        2 or 3.
      output_channels: Number of output channels.
      kernel_shape: Kernel shape, given as one integer or as a sequence of
        length `num_spatial_dims`.
      stride: Optional kernel stride, given as one integer or as a sequence of
        length `num_spatial_dims`. Defaults to 1.
      rate: Optional kernel dilation rate, given as one integer or as a
        sequence of length `num_spatial_dims`. A rate of 1 is a standard ND
        convolution; `rate > 1` is a dilated convolution. Defaults to 1.
      padding: Optional padding algorithm: the string "VALID" or "SAME", a
        callable, or a sequence of `num_spatial_dims` callables. Each callable
        takes a single integer (the effective kernel size) and returns two
        integers, the padding before and after. See haiku.pad.* for details
        and ready-made functions. Defaults to "SAME". See:
        https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      with_bias: Whether a bias is added. Defaults to True.
      w_init: Optional weight initialization. By default, truncated normal.
      b_init: Optional bias initialization. Defaults to zeros.
      data_format: Data format of the input; one of `channels_first`,
        `channels_last`, `N...C` or `NC...`. Defaults to `channels_last`.
      mask: Optional mask of the weights.
      name: The name of the module.

    Raises:
      ValueError: If `num_spatial_dims` is outside the range 1 to 3.
    """
    super(ConvND, self).__init__(name=name)

    dims_supported = 1 <= num_spatial_dims <= 3
    if not dims_supported:
      message = (
          "We only support convolution operations for num_spatial_dims=1, 2 or "
          "3, received num_spatial_dims={}.")
      raise ValueError(message.format(num_spatial_dims))

    def _per_dim(value, label):
      # Expands `value` to one entry per spatial dimension.
      return utils.replicate(value, num_spatial_dims, label)

    # Settings stored exactly as supplied (bias falls back to zeros).
    self._num_spatial_dims = num_spatial_dims
    self._output_channels = output_channels
    self._with_bias = with_bias
    self._w_init = w_init
    self._b_init = b_init or jnp.zeros
    self._mask = mask
    self._data_format = data_format

    # Per-dimension geometry of the convolution.
    self._kernel_shape = _per_dim(kernel_shape, "kernel_shape")
    self._stride = _per_dim(stride, "strides")
    self._lhs_dilation = _per_dim(1, "lhs_dilation")
    self._kernel_dilation = _per_dim(rate, "kernel_dilation")

    # A channel index of -1 selects the default table; any other value
    # selects the NCSPATIAL table.
    self._channel_index = utils.get_channel_index(data_format)
    dn_table = (DIMENSION_NUMBERS if self._channel_index == -1
                else DIMENSION_NUMBERS_NCSPATIAL)
    self._dn = dn_table[self._num_spatial_dims]

    # String paddings are only upper-cased; anything else is resolved into
    # explicit padding through `pad.create`.
    if isinstance(padding, str):
      resolved_padding = padding.upper()
    else:
      resolved_padding = pad.create(
          padding=padding,
          kernel=self._kernel_shape,
          rate=self._kernel_dilation,
          n=self._num_spatial_dims)
    self._padding = resolved_padding
示例#3
0
 def test_padding_full(self):
     """Checks the padding `pad.create` builds from `pad.full`."""
     # Kernel 4, rate 3, 2 spatial dims: expect (9, 9) for each dimension.
     expected = ((9, 9), (9, 9))
     actual = pad.create(padding=pad.full, kernel=4, rate=3, n=2)
     self.assertEqual(actual, expected)
示例#4
0
 def test_padding_causal(self):
     """Checks the padding `pad.create` builds from `pad.causal`."""
     # Kernel 4, rate 3, 2 spatial dims: expect (9, 0) for each dimension.
     expected = ((9, 0), (9, 0))
     actual = pad.create(padding=pad.causal, kernel=4, rate=3, n=2)
     self.assertEqual(actual, expected)
示例#5
0
 def test_padding_same(self):
     """Checks the padding `pad.create` builds from `pad.same`."""
     # Kernel 4, rate 3, 2 spatial dims: expect (4, 5) for each dimension.
     expected = ((4, 5), (4, 5))
     actual = pad.create(padding=pad.same, kernel=4, rate=3, n=2)
     self.assertEqual(actual, expected)
示例#6
0
 def test_padding_valid(self):
     """Checks the padding `pad.create` builds from `pad.valid`."""
     # Kernel 4, rate 3, 2 spatial dims: expect (0, 0) for each dimension.
     expected = ((0, 0), (0, 0))
     actual = pad.create(padding=pad.valid, kernel=4, rate=3, n=2)
     self.assertEqual(actual, expected)
示例#7
0
 def test_padding_incorrect_input(self, kernel_size, rate):
     """Checks that `pad.create` raises `TypeError` for the given inputs.

     Args:
       kernel_size: Kernel size forwarded to `pad.create`.
       rate: Dilation rate forwarded to `pad.create`.
     """
     # NOTE(review): the arguments are presumably supplied by a parameterized
     # decorator that is not visible here -- confirm against the full file.
     expected_message = (
         r"must be a scalar or sequence of length 1 or sequence of length 3.")
     with self.assertRaisesRegex(TypeError, expected_message):
         pad.create(padding=pad.full, kernel=kernel_size, rate=rate, n=3)
示例#8
0
 def test_padding_3d(self):
     """Checks `pad.create` with one padding function per spatial dimension."""
     # Three dims, each with its own padding function and kernel size, and a
     # single scalar rate of 1.
     pad_fns = (pad.causal, pad.full, pad.full)
     kernel_shape = (3, 2, 3)
     expected = ((2, 0), (1, 1), (2, 2))
     actual = pad.create(padding=pad_fns, kernel=kernel_shape, rate=1, n=3)
     self.assertEqual(actual, expected)
示例#9
0
 def test_padding_1d(self):
     """Checks the padding `pad.create` builds for one spatial dimension."""
     # Kernel 3, rate 1, 1 spatial dim: expect a one-element tuple of (2, 2).
     expected = ((2, 2),)
     actual = pad.create(padding=pad.full, kernel=3, rate=1, n=1)
     self.assertEqual(actual, expected)