def test_deconv_strides_shape(self):
        """Deconv with stride 2 and VALID padding upsamples a 28x28x1 input.

        For each kernel size, the output shape reported by ``spec`` must equal
        the expected shape and the shape actually produced by the initialized
        network. The expected spatial sizes (56 for a 2x2 kernel, 57 for a
        3x3 kernel) agree with ``(in - 1) * stride + kernel``.
        """
        data_rng, net_rng = random.split(self._seed)
        in_shape = (28, 28, 1)
        x = random.normal(data_rng, in_shape)

        # (kernel_size, expected output shape) pairs; the same net_rng is
        # reused for every case.
        cases = (
            ((2, 2), (56, 56, 64)),
            ((3, 3), (57, 57, 64)),
        )
        for kernel_size, expected_shape in cases:
            layer = convolution.Deconv(64, kernel_size,
                                       strides=(2, 2),
                                       padding='VALID')
            predicted = layer.spec(state.Shape(in_shape)).shape
            net = layer.init(net_rng, state.Shape(in_shape))
            self.assertEqual(predicted, expected_shape)
            self.assertEqual(net(x).shape, predicted)

    def test_deconv_vmap(self):
        """Deconv accepts a single unbatched example; batches go through vmap.

        Both ``spec`` and the initialized network raise ``ValueError`` when
        given an input with an extra leading batch dimension, while
        ``jax.vmap`` over that batch yields one output per example, each with
        the shape ``spec`` reports for a single example.
        """
        data_rng, net_rng = random.split(self._seed)
        batch_size = 10
        x = random.normal(data_rng, (batch_size, 28, 28, 1))

        net_init = convolution.Deconv(64, (2, 2),
                                      strides=(2, 2),
                                      padding='VALID')
        # A rank-4 (batched) input shape must be rejected by spec. The
        # result is never reached, so it is deliberately not bound.
        with self.assertRaises(ValueError):
            net_init.spec(state.Shape((batch_size, 28, 28, 1))).shape

        out_shape = net_init.spec(state.Shape((28, 28, 1))).shape
        net = net_init.init(net_rng, state.Shape((28, 28, 1)))
        # Calling the network directly on the batch must also fail...
        with self.assertRaises(ValueError):
            net(x)
        # ...whereas mapping it over the leading axis succeeds.
        self.assertEqual(jax.vmap(net)(x).shape, (batch_size,) + out_shape)