def test_padding_reverse_causal(self):
    """`pad.reverse_causal` pads only after the input, in every dimension."""
    padding = pad.create(padding=pad.reverse_causal, kernel=4, rate=3, n=2)
    # Same (before, after) pair for both spatial dimensions.
    expected = ((0, 9),) * 2
    self.assertEqual(padding, expected)
def __init__(self,
             num_spatial_dims,
             output_channels,
             kernel_shape,
             stride=1,
             rate=1,
             padding="SAME",
             with_bias=True,
             w_init=None,
             b_init=None,
             data_format="channels_last",
             mask=None,
             name=None):
    """Initializes an N-dimensional convolution module.

    Args:
      num_spatial_dims: How many spatial dimensions the input has. Only 1, 2
        and 3 are supported.
      output_channels: Number of output channels.
      kernel_shape: Kernel extent: a single integer (reused for every spatial
        dimension) or a sequence of length `num_spatial_dims`.
      stride: Optional kernel stride: an integer or a sequence of length
        `num_spatial_dims`. Defaults to 1.
      rate: Optional kernel dilation rate: an integer or a sequence of length
        `num_spatial_dims`. A rate of 1 is a standard ND convolution, and
        `rate > 1` gives a dilated convolution. Defaults to 1.
      padding: Optional padding algorithm. This is either the string "VALID"
        or "SAME" (the string is upper-cased before being stored), or a
        callable, or a sequence of `num_spatial_dims` callables. Each callable
        takes one integer, the effective kernel size, and returns the two
        padding amounts (before, after). See `haiku.pad.*` for examples.
        Defaults to "SAME". See:
        https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      with_bias: Whether to add a bias. Defaults to True.
      w_init: Optional weight initializer. By default, truncated normal.
      b_init: Optional bias initializer. When not given (or falsy),
        `jnp.zeros` is used.
      data_format: Data layout of the input: `channels_first`,
        `channels_last`, `N...C` or `NC...`. Defaults to `channels_last`.
      mask: Optional mask of the weights.
      name: The name of the module.

    Raises:
      ValueError: If `num_spatial_dims` is not in the range [1, 3].
    """
    super().__init__(name=name)

    if not (1 <= num_spatial_dims <= 3):
        raise ValueError(
            "We only support convolution operations for num_spatial_dims=1, 2 or "
            "3, received num_spatial_dims={}.".format(num_spatial_dims))

    ndims = num_spatial_dims
    self._num_spatial_dims = ndims
    self._output_channels = output_channels
    # Scalars are broadcast to one value per spatial dimension.
    self._kernel_shape = utils.replicate(kernel_shape, ndims, "kernel_shape")
    self._with_bias = with_bias
    self._stride = utils.replicate(stride, ndims, "strides")
    self._w_init = w_init
    self._b_init = b_init or jnp.zeros
    self._mask = mask
    # The input (lhs) is never dilated; only the kernel is, by `rate`.
    self._lhs_dilation = utils.replicate(1, ndims, "lhs_dilation")
    self._kernel_dilation = utils.replicate(rate, ndims, "kernel_dilation")
    self._data_format = data_format
    self._channel_index = utils.get_channel_index(data_format)

    # A channel index of -1 selects DIMENSION_NUMBERS (presumably the
    # channels-last layout); any other index selects the NC... table.
    if self._channel_index == -1:
        dn_table = DIMENSION_NUMBERS
    else:
        dn_table = DIMENSION_NUMBERS_NCSPATIAL
    self._dn = dn_table[ndims]

    # String paddings are stored upper-cased. Pad functions are resolved into
    # explicit (before, after) pairs from the kernel shape and dilation.
    self._padding = (
        padding.upper() if isinstance(padding, str) else pad.create(
            padding=padding,
            kernel=self._kernel_shape,
            rate=self._kernel_dilation,
            n=ndims))
def test_padding_full(self):
    """`pad.full` pads equally before and after, in every dimension."""
    padding = pad.create(padding=pad.full, kernel=4, rate=3, n=2)
    # Same (before, after) pair for both spatial dimensions.
    expected = ((9, 9),) * 2
    self.assertEqual(padding, expected)
def test_padding_causal(self):
    """`pad.causal` pads only before the input, in every dimension."""
    padding = pad.create(padding=pad.causal, kernel=4, rate=3, n=2)
    # Same (before, after) pair for both spatial dimensions.
    expected = ((9, 0),) * 2
    self.assertEqual(padding, expected)
def test_padding_same(self):
    """`pad.same` splits padding, with the odd extra unit placed after."""
    padding = pad.create(padding=pad.same, kernel=4, rate=3, n=2)
    # Same (before, after) pair for both spatial dimensions.
    expected = ((4, 5),) * 2
    self.assertEqual(padding, expected)
def test_padding_valid(self):
    """`pad.valid` adds no padding, regardless of kernel size or rate."""
    padding = pad.create(padding=pad.valid, kernel=4, rate=3, n=2)
    expected = ((0, 0),) * 2
    self.assertEqual(padding, expected)
def test_padding_incorrect_input(self, kernel_size, rate):
    """`pad.create` raises TypeError for a kernel/rate of the wrong length.

    With n=3, the expected message says a value must be a scalar or a sequence
    of length 1 or 3. `kernel_size` and `rate` are supplied by the test's
    parameterization (not visible here), presumably with at least one of them
    having an invalid length.
    """
    # The trailing "." is escaped so it matches a literal full stop. Unescaped,
    # it would match any character, so e.g. "length 30" would also pass.
    with self.assertRaisesRegex(
        TypeError,
        r"must be a scalar or sequence of length 1 or sequence of length 3\."
    ):
        pad.create(pad.full, kernel_size, rate, 3)
def test_padding_3d(self):
    """A sequence of pad functions is applied dimension by dimension."""
    pad_fns = (pad.causal, pad.full, pad.full)
    # A scalar rate is shared by all three dimensions.
    padding = pad.create(padding=pad_fns, kernel=(3, 2, 3), rate=1, n=3)
    self.assertEqual(padding, ((2, 0), (1, 1), (2, 2)))
def test_padding_1d(self):
    """With one spatial dimension the result holds a single (before, after) pair."""
    padding = pad.create(padding=pad.full, kernel=3, rate=1, n=1)
    self.assertEqual(padding, ((2, 2),))