def test_deconv_strides_shape(self):
  """Checks stride-2 'VALID' Deconv output shapes against `spec` and a real call.

  On a (28, 28, 1) input with 64 output channels, a 2x2 kernel is expected
  to produce (56, 56, 64) and a 3x3 kernel (57, 57, 64). For each case the
  shape reported by `spec` must match both the expected value and the shape
  of the array returned by the initialized network.
  """
  data_rng, net_rng = random.split(self._seed)
  in_shape = (28, 28, 1)
  x = random.normal(data_rng, in_shape)
  # (kernel_size, expected output shape) pairs, checked in order.
  cases = (
      ((2, 2), (56, 56, 64)),
      ((3, 3), (57, 57, 64)),
  )
  for kernel_size, expected_shape in cases:
    net_init = convolution.Deconv(
        64, kernel_size, strides=(2, 2), padding='VALID')
    out_shape = net_init.spec(state.Shape(in_shape)).shape
    net = net_init.init(net_rng, state.Shape(in_shape))
    self.assertEqual(out_shape, expected_shape)
    self.assertEqual(net(x).shape, out_shape)
def test_deconv_vmap(self):
  """Checks that Deconv rejects batched input directly but works under `jax.vmap`.

  The network is specified and initialized for an unbatched (28, 28, 1)
  input. Passing a shape with a leading batch dimension to `spec`, or
  calling the network directly on a batch, is expected to raise ValueError.
  Mapping the network over the leading axis with `jax.vmap` succeeds, and
  the batch size is prepended to the per-example output shape.
  """
  data_rng, net_rng = random.split(self._seed)
  x = random.normal(data_rng, (10, 28, 28, 1))
  net_init = convolution.Deconv(64, (2, 2), strides=(2, 2), padding='VALID')
  # spec() on a shape with an extra leading (batch) dimension must fail; the
  # result is never produced, so it is not bound to a name.
  with self.assertRaises(ValueError):
    net_init.spec(state.Shape((10, 28, 28, 1)))
  out_shape = net_init.spec(state.Shape((28, 28, 1))).shape
  net = net_init.init(net_rng, state.Shape((28, 28, 1)))
  # Calling the unbatched network directly on the batch must also fail...
  with self.assertRaises(ValueError):
    net(x)
  # ...while vmapping over axis 0 yields (batch,) + per-example output shape.
  self.assertEqual(jax.vmap(net)(x).shape, (10,) + out_shape)