def f():
    """Build a single-output-channel Conv3D and apply it to a tensor of ones.

    The module uses kernel_shape=3, stride=1 and "VALID" padding. Its
    initializers come from ``create_constant_initializers(1.0, 1.0, with_bias)``.
    ``with_bias`` is not defined in this function; it is read from an
    outer scope.

    Returns:
      The module's output for a ``jnp.ones([1, 5, 5, 5, 1])`` input.
    """
    inputs = jnp.ones([1, 5, 5, 5, 1])
    initializers = create_constant_initializers(1.0, 1.0, with_bias)
    conv3d = conv.Conv3D(
        output_channels=1,
        kernel_shape=3,
        stride=1,
        padding="VALID",
        with_bias=with_bias,
        **initializers,
    )
    return conv3d(inputs)
def test_invalid_input_shape(self):
    """Applying Conv3D to a rank-7 input must raise a rank ValueError.

    The expected message states that ConvND needs a rank-5 input and reports
    the offending shape.
    """
    with_bias = True
    # Rank-7 input: two more dimensions than the rank 5 the error demands.
    data = jnp.ones([1, 5, 5, 5, 1, 9, 9])
    net = conv.Conv3D(
        output_channels=1,
        kernel_shape=3,
        stride=1,
        padding="VALID",
        with_bias=with_bias,
        **create_constant_initializers(1.0, 1.0, with_bias))
    # Only the call sits inside the context manager. An unrelated ValueError
    # from building the input or the module can no longer satisfy the check.
    with self.assertRaisesRegex(ValueError, "Input to ConvND needs to have "
                                "rank 5, but input has shape"):
        net(data)