def define_cnn():
  """Builds a small convolutional classifier as one `combinator.Serial` stack.

  The stack applies, in order:
    * `convolution.Conv(20, (2, 2))` (presumably 20 filters with a 2x2
      kernel; confirm against the Conv signature),
    * `normalization.BatchNorm()`,
    * `core.Relu()`,
    * `reshape.Flatten()`,
    * `core.Dense(10)` (presumably 10 output units),
    * `core.Softmax()`.

  Returns:
    The `combinator.Serial` object wrapping the six layers above.
  """
  layers = [
      convolution.Conv(20, (2, 2)),
      normalization.BatchNorm(),
      core.Relu(),
      reshape.Flatten(),
      core.Dense(10),
      core.Softmax(),
  ]
  return combinator.Serial(layers)
    def test_flatten_call(self):
        """Checks that per-example Flatten under `jax.vmap` yields 100 values.

        The layer is initialized with a (10, 10) input shape. It is then
        applied to a batch of one (10, 10) example and a batch of five
        (20, 5) examples. Each example has 100 elements, so both calls
        should flatten to 100 values per example.
        """
        net_key, data_key = random.split(random.PRNGKey(0))
        flatten = reshape.Flatten().init(net_key, state.Shape((10, 10)))
        batched_flatten = jax.vmap(flatten)

        # The second case uses a per-example shape that differs from the one
        # given at init time; only the element count (100) matches.
        cases = [
            ((1, 10, 10), (1, 100)),
            ((5, 20, 5), (5, 100)),
        ]
        for input_shape, expected_shape in cases:
            inputs = random.normal(data_key, input_shape)
            self.assertEqual(batched_flatten(inputs).shape, expected_shape)
    def test_flatten_shape(self):
        """Checks that `Flatten.spec` collapses each tested shape to one dim.

        Each case pairs an input shape with the expected output shape. The
        expected output has a single dimension whose size is the product of
        the input dimensions: 5*100, 10*10 and 1*2*5*10.
        """
        flatten = reshape.Flatten()
        cases = (
            ((5, 100), (500,)),
            ((10, 10), (100,)),
            ((1, 2, 5, 10), (100,)),
        )
        for in_shape, want in cases:
            got = flatten.spec(state.Shape(in_shape)).shape
            self.assertEqual(got, want)