def testIncorrectN(self, n):
    """Check that initialising a transformed `ConvND` built with `n` fails.

    The module is constructed lazily inside the transformed function, so the
    `ValueError` is expected when `init_fn` is invoked rather than when the
    function is transformed.
    """

    def build_module():
        return conv.ConvND(n, output_channels=1, kernel_shape=3)

    init_fn, _ = base.transform(build_module)
    expected_message = (
        "only support convolution operations for num_spatial_dims=1, 2 or 3")
    with self.assertRaisesRegex(ValueError, expected_message):
        init_fn(None)
def testIncorrectN(self, n):
    """Check that initialising a transformed `ConvND` built with `n` fails.

    The module is constructed lazily inside the transformed function, so the
    `ValueError` is expected when `init_fn` is invoked rather than when the
    function is transformed.
    """

    def build_module():
        return conv.ConvND(n, output_channels=1, kernel_shape=3)

    init_fn, _ = transform.transform(build_module)
    expected_message = (
        "convolution operations for `num_spatial_dims` greater than 0")
    with self.assertRaisesRegex(ValueError, expected_message):
        init_fn(None)
def f():
    """Run a `ConvND` with "VALID" padding over an all-zeros input.

    Reads `input_shape` and `n` from the enclosing scope and returns the
    module's output for a zeros array of that shape.
    """
    zeros = jnp.zeros(input_shape)
    module = conv.ConvND(
        n,
        output_channels=3,
        kernel_shape=3,
        padding="VALID",
    )
    result = module(zeros)
    return result
def f():
    """Run a `ConvND` with `data_format="channels_first"` over zeros.

    Reads `input_shape` and `n` from the enclosing scope and returns the
    module's output for a zeros array of that shape.
    """
    zeros = jnp.zeros(input_shape)
    module = conv.ConvND(
        n,
        output_channels=3,
        kernel_shape=3,
        data_format="channels_first",
    )
    result = module(zeros)
    return result
def __call__(self, inputs, state):
    """Advance the convolutional LSTM by one step.

    Args:
      inputs: Input passed to the "input_to_hidden" convolution.
      state: A pair `(prev_h, prev_c)` holding the previous hidden and cell
        values.

    Returns:
      A pair `(h, (h, c))`: the new hidden value, followed by the next state.
    """
    prev_h, prev_c = state

    # Both convolutions share the same configuration and differ only in name.
    # Each produces four times `output_channels`, one slice per gate.
    shared_kwargs = dict(
        num_spatial_dims=self._num_spatial_dims,
        output_channels=4 * self.output_channels,
        kernel_shape=self.kernel_shape,
    )
    input_to_hidden = conv.ConvND(name="input_to_hidden", **shared_kwargs)
    from_inputs = input_to_hidden(inputs)
    hidden_to_hidden = conv.ConvND(name="hidden_to_hidden", **shared_kwargs)
    from_hidden = hidden_to_hidden(prev_h)
    gates = from_inputs + from_hidden

    # Split the last axis into four equal parts, in the order i, g, f, o.
    in_gate, candidate, forget_gate, out_gate = jnp.split(
        gates, indices_or_sections=4, axis=-1)

    # The constant 1 shifts the forget pre-activation before the sigmoid
    # (presumably so the gate starts out closer to open -- confirm).
    forget_gate = jax.nn.sigmoid(forget_gate + 1)

    retained = forget_gate * prev_c
    written = jax.nn.sigmoid(in_gate) * jnp.tanh(candidate)
    c = retained + written
    h = jax.nn.sigmoid(out_gate) * jnp.tanh(c)
    return h, (h, c)
def test_valid_mask_shape(self):
    """Check the output shape of a masked, channels-first `ConvND`.

    A `[2, 4, 16, 16]` zeros input is passed through a module with three
    output channels and a `[3, 3, 4, 3]` all-ones mask; the result is expected
    to have shape `(2, 3, 16, 16)`.
    """
    n = 2
    spatial_dims = (16,) * n
    zeros = jnp.zeros((2, 4) + spatial_dims)
    all_ones_mask = jnp.ones([3, 3, 4, 3])

    module = conv.ConvND(
        n,
        output_channels=3,
        kernel_shape=3,
        data_format="channels_first",
        mask=all_ones_mask,
    )
    result = module(zeros)

    self.assertEqual(result.shape, (2, 3) + spatial_dims)
def test_invalid_mask_shape(self):
    """Check that a `[1, 5, 1]` mask is rejected with a `ValueError`.

    The module is built and applied to a `[2, 4, 16]` zeros input inside the
    assertion context, so the error may come from either step.
    """
    n = 1
    zeros = jnp.zeros((2, 4) + (16,) * n)
    expected_message = (
        "Mask needs to have the same shape as weights. Shapes are:")

    with self.assertRaisesRegex(ValueError, expected_message):
        module = conv.ConvND(
            n,
            output_channels=3,
            kernel_shape=3,
            data_format="channels_first",
            mask=jnp.ones([1, 5, 1]),
        )
        module(zeros)
def test_invalid_input_shape(self):
    """Check that an input of the wrong rank is rejected with a `ValueError`.

    The shape list `[2, 4, 16]` is repeated twice, giving a six-dimensional
    zeros input for a module built with `n = 1`.
    """
    n = 1
    base_shape = (2, 4) + (16,) * n
    # Repeating the shape tuple doubles the number of dimensions.
    zeros = jnp.zeros(base_shape * 2)
    expected_message = (
        "Input to ConvND needs to have rank 3, but input has shape")

    with self.assertRaisesRegex(ValueError, expected_message):
        module = conv.ConvND(
            n,
            output_channels=3,
            kernel_shape=3,
            data_format="channels_first",
        )
        module(zeros)
def f():
    """Run a `ConvND` with `rate=3` over an all-zeros input.

    Reads `input_shape` and `n` from the enclosing scope and returns the
    module's output for a zeros array of that shape.
    """
    zeros = jnp.zeros(input_shape)
    module = conv.ConvND(
        n,
        output_channels=3,
        kernel_shape=3,
        rate=3,
    )
    result = module(zeros)
    return result