Example 1
  def test_jvp(self):
    net_rng, data_rng = random.split(self._seed)

    net_init = core.Dropout(0.5)

    net = net_init.init(net_rng, state.Shape((10,)))

    x = random.normal(data_rng, [10, 10])
    fixed_net = lambda x: net(x, rng=net_rng)
    y, y_tangent = jax.jvp(fixed_net, (x,), (jax.numpy.ones_like(x),))
    exp_tangent = np.where(np.array(y == 0), 0., 2.)
    np.testing.assert_allclose(exp_tangent, y_tangent)
Example 2
    def test_update(self):
        def template(x, init_key=None):
            return Counter(np.zeros(()))(x, init_key=init_key, name='counter')

        net = state.init(template)(self._seed, state.Shape(()))
        self.assertEqual(net(np.ones(())), 1.)

        net2 = state.update(net, np.ones(()))
        self.assertEqual(net2(np.ones(())), 2.)

        net2 = net.update(np.ones(()))
        self.assertEqual(net2(np.ones(())), 2.)
Example 3
    def test_dense_function(self):
        def dense_no_rng(x):
            return nn.Dense(20)(x, name='dense')

        with self.assertRaises(ValueError):
            out_spec = state.spec(dense_no_rng)(state.Shape(50))

        with self.assertRaises(ValueError):
            net = state.init(dense_no_rng)(self._seed, state.Shape(2))

        def dense(x, init_key=None):
            return nn.Dense(20)(x, init_key=init_key, name='dense')

        out_spec = state.spec(dense)(state.Shape(2))
        self.assertEqual(out_spec, state.Shape(20))

        net = state.init(dense)(self._seed, state.Shape(2))
        self.assertTupleEqual(net(np.ones(2)).shape, (20, ))
        onp.testing.assert_allclose(net(np.ones(2)),
                                    dense(np.ones(2), init_key=self._seed),
                                    rtol=1e-5)
Example 4
    def test_avg_pool_batched(self):
        in_shape = (3, 3, 1)
        batch_size = 10
        batch_in_shape = (batch_size, ) + in_shape
        net_init = pooling.AvgPooling((2, 2))
        net_rng = self._seed
        with self.assertRaises(ValueError):
            _ = net_init.spec(state.Shape(batch_in_shape)).shape
        out_shape = net_init.spec(state.Shape(in_shape)).shape
        batch_out_shape = (batch_size, ) + out_shape
        layer = net_init.init(net_rng, state.Shape(in_shape))

        x = np.tile(
            np.array([[-1, 0, -1], [0, 5, 0], [-1, 0, -1]])[None],
            (batch_size, 1, 1))
        x = np.reshape(x, batch_in_shape)
        with self.assertRaises(ValueError):
            layer(x)
        result = jax.vmap(layer)(x)
        self.assertEqual(result.shape, batch_out_shape)
        np.testing.assert_equal(result, np.ones(batch_out_shape))
Example 5
  def test_jit(self):
    net_rng, data_rng = random.split(self._seed)

    net_init = core.Dropout(0.5)

    net = net_init.init(net_rng, state.Shape((10,)))

    j_net = jax.jit(lambda x, rng: net(x, rng=rng))
    x = random.normal(data_rng, [10, 10])
    y = np.array(net(x, rng=net_rng))
    j_y = np.array(j_net(x, net_rng))
    np.testing.assert_allclose(y, j_y)
Example 6
  def test_stateful_layer(self, define_network, in_shape):
    net_rng, data_rng = random.split(self._seed, 2)
    network_init = define_network()
    network = network_init.init(net_rng, state.Shape(in_shape))
    x = random.normal(data_rng, in_shape)
    y, new_network = network.call_and_update(x)
    y1 = new_network(x)
    np.testing.assert_allclose(y, y1)
    for s, s1 in zip(network.state, new_network.state):
      if s:
        self.assertTrue(np.any(np.not_equal(s, s1)))
      else:
        self.assertTupleEqual(s, s1)
Example 7
    def test_no_training(self):
        epsilon = 1e-5
        axis = (0, 1)
        net_rng, data_rng = random.split(self._seed)

        net_init = normalization.BatchNorm(axis, center=False, scale=False)
        in_shape = (5, 6, 7)
        net = net_init.init(net_rng, state.Shape(in_shape))

        x = random.normal(data_rng, (4, ) + in_shape)
        z = x / np.sqrt(1.0 + epsilon)
        y = jax.vmap(lambda x: net(x, training=False))(x)
        np.testing.assert_almost_equal(z, np.array(y), decimal=6)
Example 8
    def test_call_no_batch(self):
        epsilon = 1e-5
        axis = (0, 1)
        net_rng, data_rng = random.split(self._seed)

        net_init = normalization.BatchNorm(axis, epsilon=epsilon)
        in_shape = (5, 6, 7)
        net = net_init.init(net_rng, state.Shape(in_shape))
        x = random.normal(data_rng, in_shape)
        net_y = net(x)
        np.testing.assert_allclose(x, net_y)

        with self.assertRaises(ValueError):
            net_y = net(x[None])
Example 9
    def test_update_in_combinator(self):
        def template(x, init_key=None):
            def increment(x, init_key=None):
                return Counter(np.zeros(()))(x,
                                             init_key=init_key,
                                             name='counter')

            return nn.Serial([increment, increment])(x,
                                                     init_key=init_key,
                                                     name='increment')

        net = state.init(template)(self._seed, state.Shape(()))
        self.assertEqual(net(np.ones(())), 1.)
        net = state.update(net, np.ones(()))
        self.assertEqual(net(np.ones(())), 3.)
Example 10
    def test_training_is_false(self):
        net_rng, data_rng = random.split(self._seed)

        net_init = core.Dropout(0.5)

        net = net_init.init(net_rng, state.Shape((10, )))

        x = random.normal(data_rng, [10, 10])
        y = np.array(net(x, training=False, rng=net_rng))
        np.testing.assert_allclose(x, y)

        # Calling twice produces the same results with different rng.
        net_rng, _ = random.split(net_rng)
        y2 = np.array(net(x, training=False, rng=net_rng))
        np.testing.assert_allclose(x, y2)
Example 11
  def test_update_state(self, define_network, in_shape):
    net_rng, data_rng = random.split(self._seed)
    network_init = define_network()
    network = network_init.init(net_rng, state.Shape(in_shape))
    x = random.normal(data_rng, in_shape)
    next_network = network.update(x, rng=net_rng)
    next_network_1 = network.update(x, rng=net_rng)
    for s, s1 in zip(next_network.state, next_network_1.state):
      if s is None:
        self.assertIsNone(s1)
      else:
        np.testing.assert_allclose(s, s1)
    y = next_network(x, rng=net_rng)
    y1 = next_network_1(x, rng=net_rng)
    np.testing.assert_allclose(y, y1)
Example 12
 def spec(cls, in_spec, window_shape, strides=None, padding='VALID'):
     in_shape = in_spec.shape
     if len(in_shape) > 3:
         raise ValueError('Need to `jax.vmap` in order to batch')
     in_shape = (1, ) + in_shape
     dims = (1, ) + window_shape + (1, )  # NHWC or NHC
     non_spatial_axes = 0, len(window_shape) + 1
     strides = strides or (1, ) * len(window_shape)
     for i in sorted(non_spatial_axes):
         window_shape = window_shape[:i] + (1, ) + window_shape[i:]
         strides = strides[:i] + (1, ) + strides[i:]
     padding = lax.padtype_to_pads(in_shape, window_shape, strides, padding)
     out_shape = lax.reduce_window_shape_tuple(in_shape, dims, strides,
                                               padding)
     out_shape = out_shape[1:]
     return state.Shape(out_shape, dtype=in_spec.dtype)
Example 13
    def test_dropout(self):
        net_rng = self._seed
        network_init = core.Dropout(0.5)
        network = network_init.init(net_rng, state.Shape((-1, 2)))

        grad_fn = jax.jit(jax.grad(reconstruct_loss))

        x0 = jax.numpy.array([[1.0, 1.0], [2.0, 1.0], [3.0, 0.5]])

        initial_loss = reconstruct_loss(network, x0, rng=net_rng)
        grads = grad_fn(network, x0, rng=net_rng)
        self.assertGreater(initial_loss, 0.0)
        network = network.replace(params=jax.tree_util.tree_multimap(
            lambda w, g: w - 0.1 * g, network.params, grads.params))
        final_loss = reconstruct_loss(network, x0, rng=net_rng)
        self.assertEqual(final_loss, initial_loss)
Example 14
  def test_vmap_call_update(self, define_network, in_shape):
    net_rng, data_rng = random.split(self._seed)
    network_init = define_network()
    network = network_init.init(net_rng, state.Shape(in_shape))
    x = random.normal(data_rng, (10,) + in_shape)
    y, next_network = jax.vmap(
        lambda x: network.call_and_update(x, rng=net_rng),
        out_axes=(0, None))(x)
    y1, next_network_1 = jax.vmap(
        lambda x: network.call_and_update(x, rng=net_rng),
        out_axes=(0, None))(x)
    for s, s1 in zip(next_network.state, next_network_1.state):
      if s is None:
        self.assertIsNone(s1)
      else:
        np.testing.assert_allclose(s, s1)
    np.testing.assert_allclose(y, y1)
Example 15
    def test_batch_norm_moving_vars_grads(self):
        net_rng, data_rng = random.split(self._seed)
        axis = (0, 1)
        in_shape = (2, 2, 2)
        network_init = normalization.BatchNorm(axis)
        network = network_init.init(net_rng, state.Shape(in_shape))

        grad_fn = jax.grad(reconstruct_loss, has_aux=True)

        x0 = random.normal(data_rng, (2, ) + in_shape)

        grads, _ = grad_fn(network, x0)
        grads_moving_mean, grads_moving_var = grads.state
        np.testing.assert_almost_equal(np.zeros_like(grads_moving_mean),
                                       grads_moving_mean)
        np.testing.assert_almost_equal(np.zeros_like(grads_moving_var),
                                       grads_moving_var)
Example 16
    def test_multiple_calls_produces_different_results(self):
        net_rng, data_rng = random.split(self._seed)

        net_init = core.Dropout(0.5)

        net = net_init.init(net_rng, state.Shape((10, )))

        x = random.normal(data_rng, [10, 10])
        y = np.array(net(x, rng=net_rng))
        exp_x = np.where(y == 0, x, y * 0.5)
        np.testing.assert_allclose(x, exp_x, atol=1e-05)

        # Calling with different rng produces different masks and results
        net_rng, _ = random.split(net_rng)
        y2 = np.array(net(x, rng=net_rng))
        self.assertGreater(np.sum(np.isclose(y, y2)), 10)
        self.assertLess(np.sum(np.isclose(y, y2)), 90)
Example 17
    def test_grad_of_function_constant(self):
        def template(x):
            return x + np.ones_like(x)

        net = state.init(template)(self._seed, state.Shape(5))

        def loss(net, x):
            return net(x).sum()

        g = jax.grad(loss)(net, np.ones(5))

        def add(x, y):
            return x + y

        net = tree_util.tree_multimap(add, net, g)
        # w_new = w_old + 5
        onp.testing.assert_array_equal(net(np.ones(5)), 2 * np.ones(5))
Example 18
    def test_grad_of_shared_layer(self):
        def template(x, init_key=None):
            layer = state.init(ScalarMul(2 * np.ones(1)),
                               name='scalar_mul')(init_key, x)
            return layer(layer(x)).sum()

        net = state.init(template)(self._seed, state.Shape(()))

        def loss(net, x):
            return net(x)

        g = jax.grad(loss)(net, np.ones(()))

        def add(x, y):
            return x + y

        net = tree_util.tree_multimap(add, net, g)
        onp.testing.assert_array_equal(net(np.ones(())), 36.)
Example 19
    def test_grad_of_function_with_literal(self):
        def template(x, init_key=None):
            # 1.0 behaves like a literal when tracing
            return ScalarMul(1.0)(x, init_key=init_key, name='scalar_mul')

        net = state.init(template)(self._seed, state.Shape(5))

        def loss(net, x):
            return net(x).sum()

        g = jax.grad(loss)(net, np.ones(5))

        def add(x, y):
            return x + y

        net = tree_util.tree_multimap(add, net, g)
        # w_new = w_old + 5
        onp.testing.assert_array_equal(net(np.ones(5)), 6 * np.ones(5))
Example 20
 def spec(cls,
          in_spec,
          out_chan,
          filter_shape,
          strides=None,
          padding='VALID',
          kernel_init=None,
          bias_init=stax.randn(1e-6),
          use_bias=True):
     del use_bias
     in_shape = in_spec.shape
     shapes, _, _ = conv_info(in_shape,
                              out_chan,
                              filter_shape,
                              strides=strides,
                              padding=padding,
                              kernel_init=kernel_init,
                              bias_init=bias_init)
     return state.Shape(shapes[0], dtype=in_spec.dtype)
Example 21
    def test_kwargs_training_rng(self):
        def template(x, rng, training=True, init_key=None):
            k1, k2 = random.split(init_key)
            x = Sampler()(x, rng=rng, name='sampler', init_key=k1)
            return (IsTraining()(
                x, training=training, name='training', init_key=k2) + x)

        net = state.init(template)(self._seed, state.Shape(()),
                                   random.PRNGKey(0))
        x0n = net(np.ones(()), random.PRNGKey(0), training=False)
        x0t = net(np.ones(()), random.PRNGKey(0), training=True)
        x1n = net(np.ones(()), random.PRNGKey(1), training=False)
        x1t = net(np.ones(()), random.PRNGKey(1), training=True)
        # Different seeds generate different results
        # Same seed generates offset based on training flag
        self.assertNotEqual(x0n, x1n)
        self.assertNotEqual(x0t, x1t)
        onp.testing.assert_allclose(x0n, x0t - 1, rtol=1e-6)
        onp.testing.assert_allclose(x1n, x1t - 1, rtol=1e-6)
Example 22
 def spec(cls, in_spec, dim_out):
     if isinstance(dim_out, int):
         dim_out = (dim_out, )
     else:
         dim_out = tuple(dim_out)
     return state.Shape(dim_out, dtype=in_spec.dtype)
Example 23
 def spec(cls, in_spec):
     in_shape = in_spec.shape
     out_shape = (int(np.prod(in_shape)), )
     return state.Shape(out_shape, dtype=in_spec.dtype)
Example 24
 def test_flatten(self):
     layer_params = base.LayerParams(params=(1, 2), state=3)
     layer_init = DummyLayer(layer_params)
     layer = layer_init.init(self._seed, state.Shape((1, 1)), name='foo')
     self.assertTupleEqual((((1, 2), 3), ((), 'foo')), layer.flatten())
Example 25
 def test_init(self):
     layer_params = base.LayerParams((1, 2), 3)
     layer_init = DummyLayer(layer_params)
     layer = layer_init.init(self._seed, state.Shape((1, 1)))
     self.assertTupleEqual((1, 2), layer.params)
     self.assertEqual(3, layer.info)
Example 26
 def test_init_adds_tuple_to_params(self):
     layer_params = base.LayerParams(1, 2)
     layer_init = DummyLayer(layer_params)
     layer = layer_init.init(self._seed, state.Shape((1, 1)))
     self.assertEqual(1, layer.params)
     self.assertEqual(2, layer.info)
Example 27
 def spec(cls, in_spec, dim_out, **kwargs):
   in_shape = in_spec.shape
   out_shape = in_shape[:-1] + (dim_out,)
   return state.Shape(out_shape, dtype=in_spec.dtype)
Example 28
 def test_call_pass_params(self):
     layer_params = base.LayerParams(params=1, state=1)
     layer_init = DummyLayer(layer_params)
     layer = layer_init.init(self._seed, state.Shape((1, 1)))
     outputs = layer(3)
     self.assertTupleEqual((1, 3, {}), outputs)
Example 29
 def test_call_with_needs_rng(self):
     layer_params = base.LayerParams(params=1, state=1, info=2)
     layer_init = DummyLayer(layer_params)
     layer = layer_init.init(self._seed, state.Shape((1, 1)))
     outputs = layer(3, rng=1)
     self.assertTupleEqual((1, 3, {'rng': 1}), outputs)
Example 30
 def test_call_with_has_training(self):
     layer_params = base.LayerParams(params=1, state=1, info=2)
     layer_init = DummyLayer(layer_params)
     layer = layer_init.init(self._seed, state.Shape((1, 1)))
     outputs = layer(3, training=True)
     self.assertTupleEqual((1, 3, {'training': True}), outputs)