def define_cnn():
  """Build a small convolutional classifier as one Serial combinator.

  Layers, applied in order: Conv(20, (2, 2)) (presumably 20 filters with a
  2x2 kernel), BatchNorm, ReLU, Flatten, Dense(10), and a final Softmax.

  Returns:
    A ``combinator.Serial`` wrapping the layer list.
  """
  layers = [
      convolution.Conv(20, (2, 2)),
      normalization.BatchNorm(),
      core.Relu(),
      reshape.Flatten(),
      core.Dense(10),
      core.Softmax(),
  ]
  return combinator.Serial(layers)
# Example #2
# 0
    def test_check_grads(self):
        """Numerically check the gradients of a Dense(100) layer.

        The layer is initialised for input shape (10,) and evaluated on a
        random [10, 10] input. ``jtu.check_grads`` is called with order 2
        (presumably derivatives up to second order) and 0.03 absolute and
        relative tolerance.
        """
        rng_for_net, rng_for_data = random.split(self._seed)
        dense = core.Dense(100).init(rng_for_net, state.Shape((10, )))
        inputs = random.normal(rng_for_data, [10, 10])
        jtu.check_grads(dense, (inputs, ), 2, atol=0.03, rtol=0.03)
# Example #3
# 0
  def test_bias_init(self):
    """Check a Dense(100) layer built with ``bias_init=stax.ones``.

    The layer's forward pass on a random [10, 10] input must match
    ``np.dot(x, w) + b`` computed from its own params, to atol 1e-5.
    """
    rng_net, rng_data = random.split(self._seed)
    layer = core.Dense(100, bias_init=stax.ones).init(
        rng_net, state.Shape((10,)))
    weights, bias = layer.params
    batch = random.normal(rng_data, [10, 10])
    expected = np.dot(batch, weights) + bias
    np.testing.assert_allclose(expected, np.array(layer(batch)), atol=1e-05)
# Example #4
# 0
    def test_spec(self):
        """Check the output shapes that ``Dense(100).spec`` reports.

        In every case the last dimension becomes 100 and the leading
        dimensions pass through unchanged. This includes a leading -1
        (presumably an unknown batch size).
        """
        layer_init = core.Dense(100)
        cases = (
            ((10, ), (100, )),
            ((5, 10), (5, 100)),
            ((-1, 5, 10), (-1, 5, 100)),
        )
        for in_shape, want in cases:
            got = layer_init.spec(state.Shape(in_shape)).shape
            self.assertEqual(got, want)
# Example #5
# 0
    def test_dense(self):
        """Check that one gradient-descent step lowers a Dense(2) layer's loss.

        The layer is initialised for input shape (-1, 2) and scored with
        ``reconstruct_loss`` on a fixed 3x2 input. Every parameter is then
        updated by ``-0.1 * grad``, and the loss must decrease.
        """
        net_rng = self._seed
        network_init = core.Dense(2)
        network = network_init.init(net_rng, state.Shape((-1, 2)))

        grad_fn = jax.jit(jax.grad(reconstruct_loss))

        x0 = jax.numpy.array([[1.0, 1.0], [2.0, 1.0], [3.0, 0.5]])

        initial_loss = reconstruct_loss(network, x0)
        grads = grad_fn(network, x0)
        self.assertGreater(initial_loss, 0.0)
        # ``tree_multimap`` was deprecated and later removed from JAX.
        # ``tree_map`` accepts several trees in lockstep, so it is the
        # drop-in replacement for applying the SGD update leaf-by-leaf.
        network = network.replace(params=jax.tree_util.tree_map(
            lambda w, g: w - 0.1 * g, network.params, grads.params))
        final_loss = reconstruct_loss(network, x0)
        self.assertLess(final_loss, initial_loss)
def define_dnn():
  """Build a small fully-connected network as one Serial combinator.

  Layers, applied in order: Dense(20), ReLU, Dropout(0.5), Dense(10), Tanh.

  Returns:
    A ``combinator.Serial`` wrapping the layer list.
  """
  hidden = [core.Dense(20), core.Relu(), core.Dropout(0.5)]
  head = [core.Dense(10), core.Tanh()]
  return combinator.Serial(hidden + head)