示例#1
0
    def test_init(self):
        """Dropout's spec preserves the input shape for 1-D and 2-D inputs."""
        layer = core.Dropout(0.5)
        for shape in ((10, ), (5, 10)):
            out_spec = layer.spec(state.Shape(shape))
            self.assertEqual(out_spec.shape, shape)
示例#2
0
  def test_call(self):
    """With rate 1.0, calling the dropout net returns the input unchanged."""
    net_key, data_key = random.split(self._seed)
    layer = core.Dropout(1.0).init(net_key, state.Shape((10,)))
    inputs = random.normal(data_key, [10, 10])
    outputs = np.array(layer(inputs, rng=net_key))
    np.testing.assert_allclose(inputs, outputs, atol=1e-05)
示例#3
0
    def test_check_grads(self):
        """Gradients through a fixed-mask dropout net pass jax's 2nd-order check."""
        net_key, data_key = random.split(self._seed)
        net = core.Dropout(0.5).init(net_key, state.Shape((10, )))
        inputs = random.normal(data_key, [10, 10])

        # Pin the rng so the dropout mask is deterministic for the check.
        def fixed_net(batch):
            return net(batch, rng=net_key)

        jtu.check_grads(fixed_net, (inputs, ), 2)
示例#4
0
    def test_fix_state_produces_same_results(self):
        """Repeated calls with the same rng yield identical outputs."""
        net_key, data_key = random.split(self._seed)
        net = core.Dropout(0.5).init(net_key, state.Shape((10, )))
        inputs = random.normal(data_key, [10, 10])
        first = np.array(net(inputs, rng=net_key))
        second = np.array(net(inputs, rng=net_key))
        np.testing.assert_allclose(first, second, atol=1e-05)
示例#5
0
    def test_missing_rng_raise_error(self):
        """Omitting the rng on a training-mode call raises ValueError."""
        net_key, data_key = random.split(self._seed)
        net = core.Dropout(1.0).init(net_key, state.Shape((10, )))
        inputs = random.normal(data_key, [10, 10])
        expected_message = 'rng is required when training is True'
        with self.assertRaisesRegex(ValueError, expected_message):
            net(inputs)
示例#6
0
    def test_jvp(self):
        """JVP tangent is 0 on dropped units and 2.0 on kept ones."""
        net_key, data_key = random.split(self._seed)
        net = core.Dropout(0.5).init(net_key, state.Shape((10, )))
        inputs = random.normal(data_key, [10, 10])

        # Pin the rng so the same mask is used for primal and tangent.
        def fixed_net(batch):
            return net(batch, rng=net_key)

        ones = jax.numpy.ones_like(inputs)
        primal, tangent = jax.jvp(fixed_net, (inputs, ), (ones, ))
        # Dropped outputs are exactly 0; kept outputs are scaled, so the
        # derivative w.r.t. the input is the scale factor 2.
        expected = np.where(np.array(primal == 0), 0., 2.)
        np.testing.assert_allclose(expected, tangent)
示例#7
0
    def test_jit(self):
        """A jit-compiled dropout call matches the uncompiled call exactly."""
        net_key, data_key = random.split(self._seed)
        net = core.Dropout(0.5).init(net_key, state.Shape((10, )))

        jitted = jax.jit(lambda batch, rng: net(batch, rng=rng))
        inputs = random.normal(data_key, [10, 10])
        eager_out = np.array(net(inputs, rng=net_key))
        jitted_out = np.array(jitted(inputs, net_key))
        np.testing.assert_allclose(eager_out, jitted_out)
示例#8
0
    def test_training_is_false(self):
        """With training=False, dropout is a no-op regardless of the rng."""
        net_key, data_key = random.split(self._seed)
        net = core.Dropout(0.5).init(net_key, state.Shape((10, )))
        inputs = random.normal(data_key, [10, 10])
        first = np.array(net(inputs, training=False, rng=net_key))
        np.testing.assert_allclose(inputs, first)

        # A different rng must not change eval-mode behavior.
        net_key, _ = random.split(net_key)
        second = np.array(net(inputs, training=False, rng=net_key))
        np.testing.assert_allclose(inputs, second)
示例#9
0
    def test_dropout(self):
        """A gradient step through a dropout-only network leaves the loss unchanged.

        The update applies `w - 0.1 * g` across the network's params; the
        test then asserts the loss is exactly the same, i.e. the step had
        no effect (consistent with dropout contributing no trainable
        parameters — confirm against core.Dropout).
        """
        net_rng = self._seed
        network_init = core.Dropout(0.5)
        network = network_init.init(net_rng, state.Shape((-1, 2)))

        grad_fn = jax.jit(jax.grad(reconstruct_loss))

        x0 = jax.numpy.array([[1.0, 1.0], [2.0, 1.0], [3.0, 0.5]])

        initial_loss = reconstruct_loss(network, x0, rng=net_rng)
        grads = grad_fn(network, x0, rng=net_rng)
        self.assertGreater(initial_loss, 0.0)
        # tree_multimap was deprecated and later removed from jax.tree_util;
        # tree_map accepts multiple trees and is the supported replacement.
        network = network.replace(params=jax.tree_util.tree_map(
            lambda w, g: w - 0.1 * g, network.params, grads.params))
        final_loss = reconstruct_loss(network, x0, rng=net_rng)
        self.assertEqual(final_loss, initial_loss)
示例#10
0
    def test_multiple_calls_produces_different_results(self):
        """Different rngs produce different dropout masks."""
        net_key, data_key = random.split(self._seed)
        net = core.Dropout(0.5).init(net_key, state.Shape((10, )))
        inputs = random.normal(data_key, [10, 10])
        outputs = np.array(net(inputs, rng=net_key))
        # Kept units are scaled by 2, so halving them recovers the input;
        # dropped units are substituted with the original input directly.
        reconstructed = np.where(outputs == 0, inputs, outputs * 0.5)
        np.testing.assert_allclose(inputs, reconstructed, atol=1e-05)

        # A fresh rng yields a different mask: some but not all entries match.
        net_key, _ = random.split(net_key)
        outputs2 = np.array(net(inputs, rng=net_key))
        matches = np.sum(np.isclose(outputs, outputs2))
        self.assertGreater(matches, 10)
        self.assertLess(matches, 90)
def define_dnn():
  """Return a small serial net: Dense(20) -> Relu -> Dropout(0.5) -> Dense(10) -> Tanh."""
  layers = [
      core.Dense(20),
      core.Relu(),
      core.Dropout(0.5),
      core.Dense(10),
      core.Tanh(),
  ]
  return combinator.Serial(layers)