コード例 #1
0
def define_cnn():
  """Assemble a small CNN as a sequential stack of layers.

  The layers are applied in the order listed: a convolution configured with
  ``20`` and a ``(2, 2)`` kernel, batch normalization, ReLU, flattening, a
  dense layer configured with ``10``, and a final softmax.

  Returns:
    The ``combinator.Serial`` object wrapping the layer list.
  """
  layers = [
      convolution.Conv(20, (2, 2)),
      normalization.BatchNorm(),
      core.Relu(),
      reshape.Flatten(),
      core.Dense(10),
      core.Softmax(),
  ]
  return combinator.Serial(layers)
コード例 #2
0
    def test_conv_strides_shape(self):
        """Check output shapes of strided ``VALID`` convolutions.

        For every kernel size under test, the shape reported by ``spec`` must
        equal the expected value, and applying the initialized layer to a
        sample input must yield an array of exactly that shape.
        """
        data_rng, net_rng = random.split(self._seed)
        in_shape = (28, 28, 1)
        x = random.normal(data_rng, in_shape)

        # (kernel size, expected output shape) for strides (2, 2) and VALID
        # padding: each spatial axis becomes floor((28 - k) / 2) + 1.
        cases = (
            ((2, 2), (14, 14, 64)),
            ((3, 3), (13, 13, 64)),
        )
        for kernel, expected in cases:
            with self.subTest(kernel=kernel):
                net_init = convolution.Conv(64, kernel,
                                            strides=(2, 2),
                                            padding='VALID')
                out_shape = net_init.spec(state.Shape(in_shape)).shape
                net = net_init.init(net_rng, state.Shape(in_shape))
                self.assertEqual(out_shape, expected)
                # The forward pass must agree with the shape announced by
                # spec for every configuration, not just one of them.
                self.assertEqual(net(x).shape, out_shape)
コード例 #3
0
    def test_conv_vmap(self):
        """Check that a Conv layer handles a batch only through ``jax.vmap``.

        The test asserts that both ``spec`` on a 4-D shape and a direct call
        on a 4-D array raise ``ValueError``, while ``jax.vmap`` over the
        leading axis yields the per-sample output shape prefixed by the
        batch size.
        """
        data_key, net_key = random.split(self._seed)
        sample_shape = (28, 28, 1)
        batched_shape = (10, 28, 28, 1)
        batch = random.normal(data_key, batched_shape)

        conv = convolution.Conv(64, (2, 2),
                                strides=(2, 2),
                                padding='VALID')

        # Asking for the spec of a batched shape is expected to fail.
        with self.assertRaises(ValueError):
            _ = conv.spec(state.Shape(batched_shape)).shape

        per_sample_shape = conv.spec(state.Shape(sample_shape)).shape
        layer = conv.init(net_key, state.Shape(sample_shape))

        # Calling the layer directly on the batch is expected to fail too.
        with self.assertRaises(ValueError):
            layer(batch)

        batched_out = jax.vmap(layer)(batch)
        self.assertEqual(batched_out.shape, (10, ) + per_sample_shape)