def define_cnn():
    """Build the small CNN used as a fixture.

    Returns:
        A ``combinator.Serial`` chaining, in order: ``Conv(20, (2, 2))``,
        ``BatchNorm``, ``Relu``, ``Flatten``, ``Dense(10)`` and ``Softmax``.
    """
    layers = [
        convolution.Conv(20, (2, 2)),
        normalization.BatchNorm(),
        core.Relu(),
        reshape.Flatten(),
        core.Dense(10),
        core.Softmax(),
    ]
    return combinator.Serial(layers)
def test_flatten_call(self):
    """Check the output shape of an initialised ``Flatten`` applied via ``jax.vmap``.

    The layer is initialised once with a ``(10, 10)`` input shape and then
    mapped over two batches. The second batch has per-example shape
    ``(20, 5)``, which differs from the shape used at initialisation; both
    are expected to come out with 100 features per example.
    """
    init_key, sample_key = random.split(random.PRNGKey(0))
    flatten = reshape.Flatten().init(init_key, state.Shape((10, 10)))

    # (batch shape fed to the layer, expected shape of the vmapped output)
    cases = (
        ((1, 10, 10), (1, 100)),
        ((5, 20, 5), (5, 100)),
    )
    for batch_shape, expected_shape in cases:
        # The same sampling key is used for every batch; only shapes are checked.
        batch = random.normal(sample_key, batch_shape)
        result = jax.vmap(flatten)(batch)
        self.assertEqual(result.shape, expected_shape)
def test_flatten_shape(self):
    """Check the shape reported by ``Flatten.spec`` for several input shapes.

    For each input shape, the ``.shape`` of the spec result is expected to be
    a 1-tuple holding the product of the input dimensions.
    """
    flatten_init = reshape.Flatten()

    # (input shape passed to spec, expected flattened shape)
    cases = (
        ((5, 100), (500,)),
        ((10, 10), (100,)),
        ((1, 2, 5, 10), (100,)),
    )
    for input_shape, expected_shape in cases:
        spec_result = flatten_init.spec(state.Shape(input_shape))
        self.assertEqual(spec_result.shape, expected_shape)