def define_cnn():
    """Build the small convolutional network used as a fixture.

    Returns:
        A ``combinator.Serial`` that chains, in order: a ``Conv(20, (2, 2))``
        layer, ``BatchNorm``, ``Relu``, ``Flatten``, ``Dense(10)`` and
        ``Softmax``.
    """
    layers = [
        convolution.Conv(20, (2, 2)),
        normalization.BatchNorm(),
        core.Relu(),
        reshape.Flatten(),
        core.Dense(10),
        core.Softmax(),
    ]
    return combinator.Serial(layers)
def test_conv_strides_shape(self):
    """Check stride-2 ``VALID`` convolutions on a (28, 28, 1) input.

    For each kernel size, the shape reported by ``spec`` must equal the
    expected value, and the initialised network applied to a sample input
    must produce that same shape. Previously the 2x2 configuration built a
    network but never ran it, so only the 3x3 forward pass was verified.
    """
    data_rng, net_rng = random.split(self._seed)
    in_shape = (28, 28, 1)
    x = random.normal(data_rng, in_shape)
    # (kernel size, expected output shape) for 64 output channels, stride 2.
    cases = (
        ((2, 2), (14, 14, 64)),
        ((3, 3), (13, 13, 64)),
    )
    for kernel, expected in cases:
        # subTest labels a failure with the kernel size that triggered it.
        with self.subTest(kernel=kernel):
            net_init = convolution.Conv(
                64, kernel, strides=(2, 2), padding='VALID')
            out_shape = net_init.spec(state.Shape(in_shape)).shape
            net = net_init.init(net_rng, state.Shape(in_shape))
            self.assertEqual(out_shape, expected)
            # The forward pass must agree with the statically computed spec.
            self.assertEqual(net(x).shape, out_shape)
def test_conv_vmap(self):
    """Check that a conv layer rejects batched input unless wrapped in vmap.

    Both ``spec`` on a 4-D shape and a direct call on a 4-D array are
    expected to raise ``ValueError``; ``jax.vmap`` over the leading axis
    must yield the per-example output shape prefixed by the batch size.
    """
    batched_shape = (10, 28, 28, 1)
    example_shape = (28, 28, 1)

    data_key, init_key = random.split(self._seed)
    batch = random.normal(data_key, batched_shape)
    layer_init = convolution.Conv(64, (2, 2), strides=(2, 2), padding='VALID')

    # Asking for the spec of a batched shape is expected to fail.
    with self.assertRaises(ValueError):
        layer_init.spec(state.Shape(batched_shape)).shape

    per_example_out = layer_init.spec(state.Shape(example_shape)).shape
    layer = layer_init.init(init_key, state.Shape(example_shape))

    # Calling the layer directly on the batched array is expected to fail.
    with self.assertRaises(ValueError):
        layer(batch)

    batched_out = jax.vmap(layer)(batch)
    self.assertEqual(batched_out.shape, (10, ) + per_example_out)