Example #1
0
 def testIncorrectN(self, n):
   """Building a ConvND with an unsupported `n` must raise ValueError.

   `n` is presumably supplied by a test parameterization not visible here.
   The module is constructed inside the transformed function, so the error
   is expected when `init_fn` runs.
   """
   def make_conv():
     return conv.ConvND(n, output_channels=1, kernel_shape=3)

   init_fn, _ = base.transform(make_conv)
   expected_message = (
       "only support convolution operations for num_spatial_dims=1, 2 or 3")
   with self.assertRaisesRegex(ValueError, expected_message):
     init_fn(None)
Example #2
0
 def testIncorrectN(self, n):
   """Building a ConvND with an unsupported `n` must raise ValueError.

   `n` is presumably supplied by a test parameterization not visible here.
   Construction happens inside the transformed function, so the error is
   expected when `init_fn` runs.
   """
   def make_conv():
     return conv.ConvND(n, output_channels=1, kernel_shape=3)

   init_fn, _ = transform.transform(make_conv)
   expected_message = (
       "convolution operations for `num_spatial_dims` greater than 0")
   with self.assertRaisesRegex(ValueError, expected_message):
     init_fn(None)
Example #3
0
 def f():
   """Apply a VALID-padded, kernel-3, 3-output-channel ConvND to zeros.

   `n` and `input_shape` are free variables from the enclosing scope (not
   visible here).
   """
   zeros = jnp.zeros(input_shape)
   valid_conv = conv.ConvND(
       n, output_channels=3, kernel_shape=3, padding="VALID")
   return valid_conv(zeros)
Example #4
0
 def f():
   """Apply a channels-first, kernel-3, 3-output-channel ConvND to zeros.

   `n` and `input_shape` are free variables from the enclosing scope (not
   visible here).
   """
   zeros = jnp.zeros(input_shape)
   channels_first_conv = conv.ConvND(
       n, output_channels=3, kernel_shape=3, data_format="channels_first")
   return channels_first_conv(zeros)
Example #5
0
    def __call__(self, inputs, state):
        """Run one step of a convolutional LSTM.

        Args:
          inputs: the input for this step, fed to the "input_to_hidden"
            convolution.
          state: a pair ``(prev_h, prev_c)`` holding the previous hidden
            output and cell state.

        Returns:
          ``(h, (h, c))``: the new hidden output, plus the new state pair.
        """
        prev_h, prev_c = state

        def gate_projection(name):
            # Both projections emit all four gates' pre-activations at once.
            return conv.ConvND(num_spatial_dims=self._num_spatial_dims,
                               output_channels=4 * self.output_channels,
                               kernel_shape=self.kernel_shape,
                               name=name)

        gates = gate_projection("input_to_hidden")(inputs)
        gates += gate_projection("hidden_to_hidden")(prev_h)
        in_gate, cell_input, forget_gate, out_gate = jnp.split(
            gates, indices_or_sections=4, axis=-1)

        # Shifting the forget-gate pre-activation by +1 pushes its sigmoid
        # towards 1, i.e. biases the cell towards keeping prev_c.
        forget_gate = jax.nn.sigmoid(forget_gate + 1)
        candidate = jnp.tanh(cell_input)
        next_c = forget_gate * prev_c + jax.nn.sigmoid(in_gate) * candidate
        next_h = jax.nn.sigmoid(out_gate) * jnp.tanh(next_c)
        return next_h, (next_h, next_c)
Example #6
0
 def test_valid_mask_shape(self):
   """ConvND accepts a mask shaped like its weights and yields the expected shape."""
   n = 2
   batch, in_channels, spatial_size = 2, 4, 16
   data = jnp.zeros([batch, in_channels] + [spatial_size] * n)
   # Presumably (kernel, kernel, in_channels, out_channels); it must equal
   # the weight shape, otherwise ConvND rejects the mask.
   mask = jnp.ones([3, 3, in_channels, 3])
   net = conv.ConvND(n, output_channels=3, kernel_shape=3,
                     data_format="channels_first", mask=mask)
   result = net(data)
   self.assertEqual(result.shape, (batch, 3) + (spatial_size,) * n)
Example #7
0
  def test_invalid_mask_shape(self):
    """A mask whose shape differs from the weights raises ValueError on call.

    Only the forward call sits inside `assertRaisesRegex`, so an unrelated
    ValueError during setup cannot make this test pass.
    """
    n = 1
    input_shape = [2, 4] + [16]*n
    data = jnp.zeros(input_shape)
    # [1, 5, 1] is deliberately chosen so it does not match the weight shape.
    net = conv.ConvND(n, output_channels=3, kernel_shape=3,
                      data_format="channels_first", mask=jnp.ones([1, 5, 1]))

    with self.assertRaisesRegex(ValueError, "Mask needs to have the same "
                                            "shape as weights. Shapes are:"):
      net(data)
Example #8
0
  def test_invalid_input_shape(self):
    """Calling ConvND on input of the wrong rank raises ValueError.

    Only the forward call sits inside `assertRaisesRegex`, so an unrelated
    ValueError during setup cannot make this test pass.
    """
    n = 1
    input_shape = [2, 4] + [16]*n
    # List repetition (not arithmetic): [2, 4, 16, 2, 4, 16] has rank 6,
    # while the expected message asks for rank 3 when n=1.
    data = jnp.zeros(input_shape * 2)
    net = conv.ConvND(n, output_channels=3, kernel_shape=3,
                      data_format="channels_first")

    with self.assertRaisesRegex(ValueError, "Input to ConvND needs to have "
                                            "rank 3, but input has shape"):
      net(data)
Example #9
0
 def f():
   """Apply a kernel-3, 3-output-channel ConvND with rate=3 to zeros.

   `rate` is presumably the dilation rate. `n` and `input_shape` are free
   variables from the enclosing scope (not visible here).
   """
   zeros = jnp.zeros(input_shape)
   dilated_conv = conv.ConvND(n, output_channels=3, kernel_shape=3, rate=3)
   return dilated_conv(zeros)