def define_cnn():
  return combinator.Serial([
      convolution.Conv(20, (2, 2)),
      normalization.BatchNorm(),
      core.Relu(),
      reshape.Flatten(),
      core.Dense(10),
      core.Softmax()
  ])
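# A minimal usage sketch for define_cnn(), not part of the original tests:
# it assumes Serial exposes the same init() interface the tests below
# exercise and that Conv expects NHWC inputs; the 28x28x1 MNIST-like shape
# and the name example_run_cnn are hypothetical.
def example_run_cnn(seed):
  net_rng, data_rng = random.split(seed)
  cnn = define_cnn().init(net_rng, state.Shape((-1, 28, 28, 1)))
  images = random.normal(data_rng, [4, 28, 28, 1])
  return cnn(images)  # class probabilities from the final Softmax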
def test_check_grads(self):
  # Numerically check first- and second-order gradients of a Dense layer.
  net_rng, data_rng = random.split(self._seed)
  net_init = core.Dense(100)
  net = net_init.init(net_rng, state.Shape((10,)))
  x = random.normal(data_rng, [10, 10])
  jtu.check_grads(net, (x,), 2, atol=0.03, rtol=0.03)
def test_bias_init(self):
  # A Dense layer initialized with a ones bias should compute exactly
  # np.dot(x, w) + b.
  net_rng, data_rng = random.split(self._seed)
  net_init = core.Dense(100, bias_init=stax.ones)
  net = net_init.init(net_rng, state.Shape((10,)))
  w, b = net.params
  x = random.normal(data_rng, [10, 10])
  np.testing.assert_allclose(np.dot(x, w) + b, np.array(net(x)), atol=1e-05)
def test_spec(self):
  # spec() should map the trailing feature dimension to the layer width
  # while preserving any leading (batch) dimensions.
  net_init = core.Dense(100)
  out_shape = net_init.spec(state.Shape((10,))).shape
  self.assertEqual(out_shape, (100,))
  out_shape = net_init.spec(state.Shape((5, 10))).shape
  self.assertEqual(out_shape, (5, 100))
  out_shape = net_init.spec(state.Shape((-1, 5, 10))).shape
  self.assertEqual(out_shape, (-1, 5, 100))
def test_dense(self):
  # One step of gradient descent on a reconstruction loss should reduce it.
  net_rng = self._seed
  network_init = core.Dense(2)
  network = network_init.init(net_rng, state.Shape((-1, 2)))
  grad_fn = jax.jit(jax.grad(reconstruct_loss))
  x0 = jax.numpy.array([[1.0, 1.0], [2.0, 1.0], [3.0, 0.5]])
  initial_loss = reconstruct_loss(network, x0)
  grads = grad_fn(network, x0)
  self.assertGreater(initial_loss, 0.0)
  # tree_multimap maps over multiple trees at once (merged into tree_map in
  # newer JAX versions).
  network = network.replace(params=jax.tree_util.tree_multimap(
      lambda w, g: w - 0.1 * g, network.params, grads.params))
  final_loss = reconstruct_loss(network, x0)
  self.assertLess(final_loss, initial_loss)
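# reconstruct_loss is defined elsewhere in this file. A plausible sketch,
# consistent with how test_dense uses it (a Dense(2) mapping 2-d inputs back
# onto themselves), is a mean squared reconstruction error; the name
# reconstruct_loss_sketch is hypothetical.
def reconstruct_loss_sketch(network, x):
  return jax.numpy.mean(jax.numpy.square(network(x) - x))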
def define_dnn():
  return combinator.Serial([
      core.Dense(20),
      core.Relu(),
      core.Dropout(0.5),
      core.Dense(10),
      core.Tanh()
  ])
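# A minimal sketch of checking define_dnn()'s output spec through the same
# spec() API test_spec exercises above; the input width 20 is a hypothetical
# choice, and whether Dropout needs an RNG or a training flag at call time
# depends on the library, so this sketch only inspects shapes.
def example_dnn_spec():
  out_shape = define_dnn().spec(state.Shape((-1, 20))).shape
  return out_shape  # expected (-1, 10) from the final Dense(10)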