Esempio n. 1
0
 def test_sum_pool(self):
     """Check a 2x2 `SumPooling` layer on a single unbatched 3x3x1 input.

     The output shape must match the shape reported by the layer's spec,
     and the pooled values must all be zero.
     """
     input_shape = (3, 3, 1)
     pool_init = pooling.SumPooling((2, 2))
     rng = self._seed
     expected_shape = pool_init.spec(state.Shape(input_shape)).shape
     pool = pool_init.init(rng, state.Shape(input_shape))

     # Every contiguous 2x2 patch of this pattern sums to zero, which is
     # why an all-zero result is expected below.
     pattern = np.array([[-1, 0, -1], [0, 1, 0], [-1, 0, -1]])
     pooled = pool(pattern.reshape(input_shape))

     self.assertEqual(pooled.shape, expected_shape)
     np.testing.assert_equal(pooled, np.zeros(expected_shape))
Esempio n. 2
0
  def test_sum_pool_batched(self):
    """Check that `SumPooling` handles batches only through `jax.vmap`.

    Both the spec and the initialised layer are expected to raise
    `ValueError` when given an input with a leading batch axis. Mapping the
    layer over the batch with `jax.vmap` must yield an all-zero output of
    shape `(batch,) + per-example output shape`.
    """
    example_shape = (3, 3, 1)
    num_examples = 10
    batched_shape = (num_examples,) + example_shape
    pool_init = pooling.SumPooling((2, 2))
    rng = self._seed

    # The spec itself is expected to reject the batched shape.
    with self.assertRaises(ValueError):
      _ = pool_init.spec(state.Shape(batched_shape)).shape

    example_out_shape = pool_init.spec(state.Shape(example_shape)).shape
    batched_out_shape = (num_examples,) + example_out_shape
    pool = pool_init.init(rng, state.Shape(example_shape))

    # Repeat the same 3x3 pattern along a new leading batch axis, then add
    # the trailing channel axis through the reshape.
    pattern = np.array([[-1, 0, -1], [0, 1, 0], [-1, 0, -1]])
    tiled = np.tile(pattern[np.newaxis], (num_examples, 1, 1))
    batch = tiled.reshape(batched_shape)

    # Calling the layer directly on the batch is expected to fail.
    self.assertRaises(ValueError, pool, batch)

    pooled = jax.vmap(pool)(batch)
    self.assertEqual(pooled.shape, batched_out_shape)
    np.testing.assert_equal(pooled, np.zeros(batched_out_shape))