def test_without_apply_rng_multi_transform(self):
        """Exercise `without_apply_rng` on both multi-transform flavours.

        For the stateful and the stateless variant alike, the wrapped object
        must keep its type, and each of its apply functions is then invoked
        without an rng argument.
        """
        def net(name):
            # Factory for a function that adds a zero-initialised parameter
            # (shape `[]`) called `name` to its input.
            def add_param(x):
                return base.get_parameter(name, [], init=jnp.zeros) + x

            return add_param

        def mod():
            first = net(name='one')
            second = net(name='two')

            def init(x):
                # Chain both sub-networks: `second` consumes `first`'s output.
                return second(first(x))

            return init, (first, second)

        scalar_one = jnp.ones(())

        # Stateful flavour: apply is called as (params, state, input).
        stateful = multi_transform.without_apply_rng(
            multi_transform.multi_transform_with_state(mod))
        self.assertIsInstance(stateful,
                              multi_transform.MultiTransformedWithState)
        params, state = stateful.init(None, scalar_one)
        for apply_fn in stateful.apply:
            apply_fn(params, state, scalar_one)

        # Stateless flavour: apply is called as (params, input).
        stateless = multi_transform.without_apply_rng(
            multi_transform.multi_transform(mod))
        self.assertIsInstance(stateless, multi_transform.MultiTransformed)
        params = stateless.init(None, scalar_one)
        for apply_fn in stateless.apply:
            apply_fn(params, scalar_one)
    def test_tree_update_stats(self):
        """Check the effect of `update_stats` on `EMAParamsTree` outputs.

        The same doubled parameter tree is fed in twice. With
        `update_stats=False` both outputs must agree; with
        `update_stats=True` they must differ.
        """
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        linear_init, _ = transform.transform(f)
        params = linear_init(random.PRNGKey(428))

        def double(leaf):
            return 2. * leaf

        def g(x):
            """This should never update internal stats."""
            ema = moving_averages.EMAParamsTree(0.2)
            return ema(x, update_stats=False)

        init_fn, apply_fn_g = multi_transform.without_apply_rng(
            transform.transform_with_state(g))
        _, ema_state = init_fn(None, params)

        # Feed the doubled params through the frozen average twice.
        changed_params = tree.map_structure(double, params)
        ema_params, ema_state = apply_fn_g(None, ema_state, changed_params)
        ema_params2, ema_state = apply_fn_g(None, ema_state, changed_params)

        # With update_stats=False the second output must equal the first.
        second_leaves = tree.flatten(ema_params2)
        first_leaves = tree.flatten(ema_params)
        for later, earlier in zip(second_leaves, first_leaves):
            self.assertEqual(later.shape, earlier.shape)
            np.testing.assert_allclose(later, earlier)

        def h(x):
            """This will behave like normal."""
            ema = moving_averages.EMAParamsTree(0.2)
            return ema(x, update_stats=True)

        init_fn, apply_fn_h = multi_transform.without_apply_rng(
            transform.transform_with_state(h))
        _, ema_state = init_fn(None, params)
        # One pass with the unmodified params before they are changed.
        params, ema_state = apply_fn_h(None, ema_state, params)

        # Feed the doubled params through the updating average twice.
        changed_params = tree.map_structure(double, params)
        ema_params, ema_state = apply_fn_h(None, ema_state, changed_params)
        ema_params2, ema_state = apply_fn_h(None, ema_state, changed_params)

        # With update_stats=True the second output must differ from the first,
        # so the closeness check is expected to fail for every leaf.
        second_leaves = tree.flatten(ema_params2)
        first_leaves = tree.flatten(ema_params)
        for later, earlier in zip(second_leaves, first_leaves):
            self.assertEqual(later.shape, earlier.shape)
            with self.assertRaisesRegex(AssertionError,
                                        "Not equal to tolerance"):
                np.testing.assert_allclose(later, earlier, atol=1e-6)
    def test_ema_on_changing_data(self):
        """Check that the `EMAParamsTree` output moves when its input changes.

        The average is applied once to the params and once to doubled
        params; the second output must keep the tree structure and must not
        equal the first output.
        """
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        linear_init, _ = transform.transform(f)
        params = linear_init(random.PRNGKey(428))

        def g(x):
            ema = moving_averages.EMAParamsTree(0.2)
            return ema(x)

        init_fn, apply_fn = multi_transform.without_apply_rng(
            transform.transform_with_state(g))
        _, ema_state = init_fn(None, params)
        # From here on `params` is the output of the first averaging step.
        params, ema_state = apply_fn(None, ema_state, params)

        # Double every leaf and run a second averaging step.
        changed_params = tree.map_structure(lambda leaf: 2. * leaf, params)
        ema_params, ema_state = apply_fn(None, ema_state, changed_params)

        # Same structure as the changed input ...
        tree.assert_same_structure(changed_params, ema_params)
        # ... but leaf values must differ from the pre-change `params`, so
        # the closeness check is expected to fail for every leaf.
        before_leaves = tree.flatten(params)
        after_leaves = tree.flatten(ema_params)
        for before, after in zip(before_leaves, after_leaves):
            self.assertEqual(before.shape, after.shape)
            with self.assertRaisesRegex(AssertionError,
                                        "Not equal to tolerance"):
                np.testing.assert_allclose(before, after, atol=1e-6)
Beispiel #4
0
 def outer_fn(x):
     """Bias a rank-2 input, then vmap a lifted inner transform over axis 0.

     `Bias` and `inner_fn` are defined outside this block.
     """
     assert x.ndim == 2
     biased = Bias()(x)
     inner_init, inner_apply = multi_transform.without_apply_rng(
         transform.transform(inner_fn))
     # Initialise the inner params from the first row via the lifted init.
     inner_params = lift.lift(inner_init)(base.next_rng_key(), biased[0])
     # in_axes=(None, 0): params are not mapped, the input is mapped on axis 0.
     batched_apply = jax.vmap(inner_apply, in_axes=(None, 0))
     return batched_apply(inner_params, biased)
Beispiel #5
0
        def build_and_init_stack(module_class):
            """Transform a stack built from `module_class` and run its init.

            A fresh `module_class()` instance is passed through
            `layer_stack.layer_stack(1)`. Only initialisation is executed and
            nothing is returned.
            """
            def stack_fn(x):
                # Instantiate the module before building the stack wrapper.
                layer = module_class()
                stacked = layer_stack.layer_stack(1)(layer)
                return stacked(x)

            init_fn, _ = multi_transform.without_apply_rng(
                transform.transform(stack_fn))
            init_fn(jax.random.PRNGKey(1729), jnp.ones([5]))
    def test_ema_is_identity_on_unchanged_data(self):
        """Check that an EMA fed its own output keeps returning the input.

        The average (decay 0.5) starts from a constant and is repeatedly
        applied to its previous output; every result must stay close to the
        original constant.
        """
        def f(x):
            ema = moving_averages.ExponentialMovingAverage(0.5)
            return ema(x)

        inp_value = 1.0
        init_fn, apply_fn = multi_transform.without_apply_rng(
            transform.transform_with_state(f))
        _, ema_state = init_fn(None, inp_value)

        # As long as the input stays put, so should the output.
        current = inp_value
        for _ in range(10):
            current, ema_state = apply_fn(None, ema_state, current)
            # rtol is loosened to 1e-6 because floating point error creeps up
            # to the default of 1e-7.
            np.testing.assert_allclose(inp_value, current, rtol=1e-6)
    def test_ignore_regex(self):
        """Check that `ignore_regex` exempts matching leaves from averaging.

        With `ignore_regex=".*w"`, the linear layer's `w` must come back
        equal to the changed input, while `b` must come back different.
        """
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        linear_init, _ = transform.transform(f)
        params = linear_init(random.PRNGKey(428))

        def g(x):
            ema = moving_averages.EMAParamsTree(0.2, ignore_regex=".*w")
            return ema(x)

        init_fn, apply_fn = multi_transform.without_apply_rng(
            transform.transform_with_state(g))
        _, ema_state = init_fn(None, params)
        # From here on `params` is the output of the first averaging step.
        params, ema_state = apply_fn(None, ema_state, params)

        # Double every leaf and run a second averaging step.
        changed_params = tree.map_structure(lambda leaf: 2. * leaf, params)
        ema_params, ema_state = apply_fn(None, ema_state, changed_params)

        changed_linear = changed_params["linear"]
        ema_linear = ema_params["linear"]

        # b must differ from the changed input in every element ...
        b_differs = changed_linear["b"] != ema_linear["b"]
        self.assertTrue(b_differs.all())
        # ... whereas w must equal the changed input in every element.
        w_matches = changed_linear["w"] == ema_linear["w"]
        self.assertTrue(w_matches.all())