def test_sum_pool(self):
    """SumPooling((2, 2)) on a 3x3x1 input should yield an all-zero output.

    The output shape reported by `spec` must match the shape actually
    produced by the initialized layer.
    """
    input_shape = (3, 3, 1)
    pool = pooling.SumPooling((2, 2))
    rng = self._seed
    # Query the expected output shape before building the layer.
    expected_shape = pool.spec(state.Shape(input_shape)).shape
    layer = pool.init(rng, state.Shape(input_shape))

    grid = np.array([[-1, 0, -1],
                     [0, 1, 0],
                     [-1, 0, -1]])
    inputs = np.reshape(grid, input_shape)

    outputs = layer(inputs)
    self.assertEqual(outputs.shape, expected_shape)
    np.testing.assert_equal(outputs, np.zeros(expected_shape))
def test_sum_pool_batched(self):
    """SumPooling rejects a leading batch dim unless mapped with jax.vmap.

    Both `spec` and the initialized layer are expected to raise ValueError
    for an input carrying an extra leading batch axis.  Vectorizing the layer
    with `jax.vmap` over that axis should give the per-example output shape
    with the batch axis prepended, and all-zero values.
    """
    single_shape = (3, 3, 1)
    batch_size = 10
    batched_shape = (batch_size,) + single_shape
    pool = pooling.SumPooling((2, 2))
    rng = self._seed

    # The spec only accepts un-batched shapes.
    with self.assertRaises(ValueError):
        pool.spec(state.Shape(batched_shape)).shape

    expected_single = pool.spec(state.Shape(single_shape)).shape
    expected_batched = (batch_size,) + expected_single
    layer = pool.init(rng, state.Shape(single_shape))

    grid = np.array([[-1, 0, -1],
                     [0, 1, 0],
                     [-1, 0, -1]])
    # Stack `batch_size` copies of the same 3x3 grid along a new axis 0.
    stacked = np.stack([grid] * batch_size)
    inputs = np.reshape(stacked, batched_shape)

    # Calling the layer directly on batched data must fail...
    with self.assertRaises(ValueError):
        layer(inputs)
    # ...while mapping it over the batch axis succeeds.
    outputs = jax.vmap(layer)(inputs)
    self.assertEqual(outputs.shape, expected_batched)
    np.testing.assert_equal(outputs, np.zeros(expected_batched))