Exemplo n.º 1
0
    def test_tree_update_stats(self):
        """EMAParamsTree only advances its running average when update_stats=True.

        First, two applications with update_stats=False on the same input must
        agree. Second, two applications with update_stats=True on the same
        input must disagree, because the first call changed the state.
        """
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        init_fn, _ = base.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            """This should never update internal stats."""
            return moving_averages.EMAParamsTree(0.2)(x, update_stats=False)

        init_fn, apply_fn_g = base.without_apply_rng(
            base.transform_with_state(g))
        _, params_state = init_fn(None, params)

        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn_g(None, params_state,
                                              changed_params)
        ema_params2, params_state = apply_fn_g(None, params_state,
                                               changed_params)

        # With update_stats=False the state is not advanced, so ema_params and
        # ema_params2 must be identical. Check the structure first so that the
        # zip below cannot silently skip leaves.
        tree.assert_same_structure(ema_params, ema_params2)
        for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
            self.assertEqual(p1.shape, p2.shape)
            np.testing.assert_allclose(p1, p2)

        def h(x):
            """This will behave like normal."""
            return moving_averages.EMAParamsTree(0.2)(x, update_stats=True)

        init_fn, apply_fn_h = base.without_apply_rng(
            base.transform_with_state(h))
        _, params_state = init_fn(None, params)
        params, params_state = apply_fn_h(None, params_state, params)

        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn_h(None, params_state,
                                              changed_params)
        ema_params2, params_state = apply_fn_h(None, params_state,
                                               changed_params)

        # With update_stats=True the first call updates the state, so
        # ema_params2 must differ from ema_params.
        tree.assert_same_structure(ema_params, ema_params2)
        for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
            self.assertEqual(p1.shape, p2.shape)
            # assert_allclose raising is how "the arrays differ" is asserted.
            with self.assertRaisesRegex(AssertionError,
                                        "Not equal to tolerance"):
                np.testing.assert_allclose(p1, p2, atol=1e-6)
Exemplo n.º 2
0
 def test_get_state_no_init_raises(self):
   """get_state("i") without an init fails in both init and apply."""
   transformed = base.transform_with_state(lambda: base.get_state("i"))
   expected_msg = "set an init function"
   with self.assertRaisesRegex(ValueError, expected_msg):
     transformed.init(None)
   # A single empty mapping is shared as both the params and the state.
   params = {"~": {}}
   state = params
   with self.assertRaisesRegex(ValueError, expected_msg):
     transformed.apply(params, state, None)
Exemplo n.º 3
0
 def outer_fn(x):
     """Add a Bias to x, then vmap the inner transform over x's first axis."""
     assert x.ndim == 2
     biased = Bias()(x)
     inner = base.without_apply_rng(base.transform_with_state(inner_fn))
     # Initialise the inner transform on a single row of the input. Judging by
     # the name, lift.lift presumably exposes the inner params/state to the
     # outer transform (confirm against lift's docs).
     inner_params, inner_state = lift.lift(inner.init)(base.next_rng_key(),
                                                       biased[0])
     # Params and state are broadcast (None); the input is mapped over axis 0.
     batched_apply = jax.vmap(inner.apply, in_axes=(None, None, 0))
     outputs = batched_apply(inner_params, inner_state, biased)
     # apply returns (output, state); only the output is returned.
     return outputs[0]
Exemplo n.º 4
0
    def test_argspec(self):
        """init and apply expose the expected positional argument names."""
        init_fn, apply_fn = base.transform_with_state(lambda: None)
        init_args = inspect.getfullargspec(init_fn).args
        apply_args = inspect.getfullargspec(apply_fn).args

        self.assertEqual(init_args, ["rng"])
        self.assertEqual(apply_args, ["params", "state", "rng"])
Exemplo n.º 5
0
    def test_ema_on_changing_data(self):
        """EMAParamsTree output matches neither the old nor the new params.

        The average is first computed on the original params. After the
        params are doubled and fed in, the average should equal neither the
        params seen on the first step nor the doubled params just supplied.
        """
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        init_fn, _ = base.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            return moving_averages.EMAParamsTree(0.2)(x)

        init_fn, apply_fn = base.without_apply_rng(
            base.transform_with_state(g))
        _, params_state = init_fn(None, params)
        params, params_state = apply_fn(None, params_state, params)
        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn(None, params_state, changed_params)

        # ema_params mixes both inputs, so it should differ from the params of
        # the first step AND from the changed params.
        tree.assert_same_structure(changed_params, ema_params)
        for reference in (params, changed_params):
            leaf_pairs = zip(tree.flatten(reference), tree.flatten(ema_params))
            for p1, p2 in leaf_pairs:
                self.assertEqual(p1.shape, p2.shape)
                # assert_allclose raising is how "the arrays differ" is
                # asserted.
                with self.assertRaisesRegex(AssertionError,
                                            "Not equal to tolerance"):
                    np.testing.assert_allclose(p1, p2, atol=1e-6)
Exemplo n.º 6
0
 def test_stateful_module(self):
     """CountingModule's count is 0 after init and 10 after one apply."""
     def forward():
         return CountingModule()()

     init_fn, apply_fn = base.transform_with_state(forward)
     params, initial_state = init_fn(None)
     self.assertEqual(initial_state, {"counting_module": {"count": 0}})
     _, final_state = apply_fn(params, initial_state, None)
     self.assertEqual(final_state, {"counting_module": {"count": 10}})
Exemplo n.º 7
0
 def wrapper(*a, **k):
     """Transform ``f`` with state, run its init, then optionally its apply.

     The key built from ``seed`` (None when ``seed`` is None) is used for both
     init and apply. The apply result is discarded; nothing is returned.
     """
     if seed is None:
         rng = None
     else:
         rng = random.PRNGKey(seed)
     transformed = base.transform_with_state(lambda: f(*a, **k))
     params, state = transformed.init(rng)
     if not run_apply:
         return
     transformed.apply(params, state, rng)
Exemplo n.º 8
0
 def test_get_state_no_shape_raises(self):
     """get_state with an init but no shape/dtype fails in init and apply."""
     def forward():
         return base.get_state("i", init=jnp.zeros)

     transformed = base.transform_with_state(forward)
     expected_msg = "provide shape and dtype"
     with self.assertRaisesRegex(ValueError, expected_msg):
         transformed.init(None)
     # A single empty mapping is shared as both the params and the state.
     params = {"~": {}}
     state = params
     with self.assertRaisesRegex(ValueError, expected_msg):
         transformed.apply(params, state, None)
Exemplo n.º 9
0
 def test_invalid_rng_state(self):
   """A string in the RNG slot is rejected by both init and apply."""
   init_fn, apply_fn = base.transform_with_state(lambda: None)
   bogus_rng = "nonsense"
   with self.assertRaisesRegex(
       ValueError, "Init must be called with an RNG as the first argument"):
     init_fn(bogus_rng)
   with self.assertRaisesRegex(
       ValueError, "Apply must be called with an RNG as the third argument"):
     apply_fn({}, {"x": {}}, bogus_rng)
Exemplo n.º 10
0
    def test_without_state(self):
        """without_state turns a stateless transform into (init, apply) w/o state."""
        def f():
            return base.get_parameter("w", [], init=jnp.zeros)

        stateless = base.without_state(base.transform_with_state(f))
        init_fn, apply_fn = stateless
        # The parameter is zero-initialised, so apply returns 0.
        self.assertEqual(apply_fn(init_fn(None), None), 0)
Exemplo n.º 11
0
    def test_grad_and_jit(self):
        """stateful.grad composes with jax.jit over both init and apply."""
        def grad_of_square(x):
            return stateful.grad(SquareModule())(x)

        x = jnp.array(3.)
        transformed = base.transform_with_state(grad_of_square)
        params, state = jax.jit(transformed.init)(None, x)
        g, state = jax.jit(transformed.apply)(params, state, None, x)
        # The expected value 2 * x is the derivative of x**2, which is what
        # SquareModule presumably computes.
        np.testing.assert_allclose(g, 2 * x, rtol=1e-3)
Exemplo n.º 12
0
    def test_without_state_raises_if_state_used(self):
        """without_state refuses a function that reads/writes state."""
        def f():
            for _ in range(10):
                current = base.get_state("count", (), jnp.int32, jnp.zeros)
                base.set_state("count", current + 1)
            return current

        stateless_init, _ = base.without_state(base.transform_with_state(f))

        with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
            stateless_init(None)
Exemplo n.º 13
0
    def test_stateful(self):
        """Ten get/set rounds in one apply take the counter from 0 to 10."""
        def f():
            for _ in range(10):
                current = base.get_state("count", (), jnp.int32, jnp.zeros)
                base.set_state("count", current + 1)
            return current

        init_fn, apply_fn = base.transform_with_state(f)
        params, initial_state = init_fn(None)
        # The state returned by init holds the zero initial value.
        self.assertEqual(initial_state, {"~": {"count": 0}})
        _, updated_state = apply_fn(params, initial_state, None)
        self.assertEqual(updated_state, {"~": {"count": 10}})
Exemplo n.º 14
0
    def test_cond(self):
        """stateful.cond runs mod(x) when x == 2, otherwise mod(x + 1)."""
        def f(x):
            mod = SquareModule()
            return stateful.cond(x == 2, x, mod, x, lambda x: mod(x + 1))

        transformed = base.transform_with_state(f)
        # (input, expected output) pairs; the expected output is also the
        # value stored in the module's "y" state.
        cases = ((1, 4), (2, 4), (3, 16))
        for raw_x, raw_y in cases:
            x = jnp.array(raw_x)
            expected = jnp.array(raw_y)
            params, state = transformed.init(None, x)
            out, state = transformed.apply(params, state, None, x)
            self.assertEqual(state, {"square_module": {"y": expected}})
            self.assertEqual(out, expected)
Exemplo n.º 15
0
    def test_set_then_get(self):
        """A set_state value is what a later get_state in the same call sees."""
        def net():
            base.set_state("i", 1)
            return base.get_state("i")

        init_fn, apply_fn = base.transform_with_state(net)
        params, state = init_fn(None)
        self.assertEqual(state, {"~": {"i": 1}})

        # Whatever state comes in, net overwrites "i" with 1 before reading.
        for incoming in range(10):
            y, state_out = apply_fn(params, {"~": {"i": incoming}}, None)
            self.assertEqual(y, 1)
            self.assertEqual(state_out, {"~": {"i": 1}})
Exemplo n.º 16
0
    def test_ema_is_identity_on_unchanged_data(self):
        """Feeding a constant into an EMA keeps returning that constant."""
        def f(x):
            return moving_averages.ExponentialMovingAverage(0.5)(x)

        constant = 1.0
        init_fn, apply_fn = base.without_apply_rng(
            base.transform_with_state(f))
        _, ema_state = init_fn(None, constant)

        # Each step feeds back the previous output; it must stay at `constant`.
        averaged = constant
        for _ in range(10):
            averaged, ema_state = apply_fn(None, ema_state, averaged)
            # Accumulated float error exceeds assert_allclose's default rtol
            # (1e-7), hence the looser 1e-6.
            np.testing.assert_allclose(constant, averaged, rtol=1e-6)
Exemplo n.º 17
0
    def test_ema_naming_scheme(self):
        """EMA state keys follow '<ema name>/<module name>__<param name>'."""
        ema_name = "this_is_a_wacky_but_valid_name"
        linear_name = "so_is_this"

        def f():
            linear = basic.Linear(output_size=2, name=linear_name)
            return linear(jnp.zeros([6]))

        init_fn, _ = base.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            return moving_averages.EMAParamsTree(0.2, name=ema_name)(x)

        init_fn, _ = base.transform_with_state(g)
        _, params_state = init_fn(None, params)

        expected_ema_states = {
            "{}/{}__{}".format(ema_name, linear_name, suffix)
            for suffix in ("w", "b")
        }
        self.assertEqual(expected_ema_states, set(params_state.keys()))
Exemplo n.º 18
0
    def test_sn_naming_scheme(self):
        """Spectral-norm state keys follow '<sn name>/<module name>__<param>'.

        Parameters matching ignore_regex (here the bias "b") get no state.
        """
        sn_name = "this_is_a_wacky_but_valid_name"
        linear_name = "so_is_this"

        def f():
            linear = basic.Linear(output_size=2, name=linear_name)
            return linear(jnp.zeros([6, 6]))

        init_fn, _ = base.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            sn_tree = spectral_norm.SNParamsTree(ignore_regex=".*b",
                                                 name=sn_name)
            return sn_tree(x)

        init_fn, _ = base.transform_with_state(g)
        _, params_state = init_fn(random.PRNGKey(428), params)

        expected_sn_states = ["{}/{}__{}".format(sn_name, linear_name, "w")]
        self.assertSameElements(expected_sn_states, params_state.keys())
Exemplo n.º 19
0
    def test_ignore_regex(self):
        """Params matching ignore_regex bypass the EMA; others are averaged."""
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        init_fn, _ = base.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            return moving_averages.EMAParamsTree(0.2, ignore_regex=".*w")(x)

        init_fn, apply_fn = base.without_apply_rng(
            base.transform_with_state(g))
        _, params_state = init_fn(None, params)
        params, params_state = apply_fn(None, params_state, params)
        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn(None, params_state, changed_params)

        changed_linear = changed_params.linear
        ema_linear = ema_params.linear
        # b is averaged, so it lags behind the doubled value...
        self.assertTrue((changed_linear.b != ema_linear.b).all())
        # ...while w matches ignore_regex and comes out equal to the input.
        self.assertTrue((changed_linear.w == ema_linear.w).all())
Exemplo n.º 20
0
 def test_without_state_raises_if_state_used(self):
     """without_state refuses a module (CountingModule) that uses state."""
     def forward():
         return CountingModule()()

     stateless_init, _ = base.without_state(
         base.transform_with_state(forward))
     with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
         stateless_init(None)
Exemplo n.º 21
0
 def test_without_state(self):
     """without_state works for a stateless module; ScalarModule yields 0."""
     def forward():
         return ScalarModule()()

     init_fn, apply_fn = base.without_state(base.transform_with_state(forward))
     self.assertEqual(apply_fn(init_fn(None), None), 0)
Exemplo n.º 22
0
 def test_get_state_no_init(self):
     """get_state without an init leaves the supplied state unchanged."""
     transformed = base.transform_with_state(lambda: base.get_state("i"))
     apply_fn = transformed[1]
     for value in range(10):
         state_in = {"~": {"i": value}}
         state_out = apply_fn({}, state_in, None)[1]
         self.assertEqual(state_in, state_out)