def test_tree_update_stats(self):
        """EMAParamsTree only advances its statistics when update_stats=True.

        First half: with ``update_stats=False`` two consecutive applies on the
        same input must return identical averages. Second half: with
        ``update_stats=True`` the first apply updates the state, so a second
        apply on the same input must return different averages.
        """
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        init_fn, _ = transform.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            """This should never update internal stats."""
            return moving_averages.EMAParamsTree(0.2)(x, update_stats=False)

        init_fn, apply_fn_g = transform.without_apply_rng(
            transform.transform_with_state(g))
        _, params_state = init_fn(None, params)

        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn_g(None, params_state,
                                              changed_params)
        ema_params2, params_state = apply_fn_g(None, params_state,
                                               changed_params)

        # With update_stats=False the state never moves, so both calls agree.
        # Check structure first so zip() cannot silently drop leaves.
        tree.assert_same_structure(ema_params, ema_params2)
        for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
            self.assertEqual(p1.shape, p2.shape)
            np.testing.assert_allclose(p1, p2)

        def h(x):
            """This will behave like normal."""
            return moving_averages.EMAParamsTree(0.2)(x, update_stats=True)

        init_fn, apply_fn_h = transform.without_apply_rng(
            transform.transform_with_state(h))
        _, params_state = init_fn(None, params)
        params, params_state = apply_fn_h(None, params_state, params)

        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn_h(None, params_state,
                                              changed_params)
        ema_params2, params_state = apply_fn_h(None, params_state,
                                               changed_params)

        # With update_stats=True the first call advanced the state, so
        # ema_params2 must differ from ema_params.
        tree.assert_same_structure(ema_params, ema_params2)
        for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
            self.assertEqual(p1.shape, p2.shape)
            with self.assertRaisesRegex(AssertionError,
                                        "Not equal to tolerance"):
                np.testing.assert_allclose(p1, p2, atol=1e-6)
    def test_ema_on_changing_data(self):
        """The EMA of changed params differs from both old and new params.

        After one warm-up apply, the params are doubled and fed through the
        EMA again. The resulting averages must not equal the previous params
        or the doubled params.
        """
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        init_fn, _ = transform.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            return moving_averages.EMAParamsTree(0.2)(x)

        init_fn, apply_fn = transform.without_apply_rng(
            transform.transform_with_state(g))
        _, params_state = init_fn(None, params)
        params, params_state = apply_fn(None, params_state, params)
        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn(None, params_state, changed_params)

        # ema_params should be different from both the previous params and the
        # changed params (the average lies between them).
        tree.assert_same_structure(changed_params, ema_params)
        tree.assert_same_structure(params, ema_params)
        for old, new, ema in zip(tree.flatten(params),
                                 tree.flatten(changed_params),
                                 tree.flatten(ema_params)):
            self.assertEqual(old.shape, ema.shape)
            self.assertEqual(new.shape, ema.shape)
            for reference in (old, new):
                with self.assertRaisesRegex(AssertionError,
                                            "Not equal to tolerance"):
                    np.testing.assert_allclose(reference, ema, atol=1e-6)
Exemple #3
0
    def test_transparent_lift_with_state_nested(self):
        """Transparently lifted state is stored under the enclosing module's scope."""
        @transform.transform_with_state
        def inner():
            counter = base.get_state("w", [], init=jnp.zeros) + 1
            base.set_state("w", counter)
            return counter

        class Outer(module.Module):
            def __call__(self):
                init_lifted, state_updater = lift.transparent_lift_with_state(
                    inner.init)
                inner_params, inner_state = init_lifted(None)
                result, inner_state = inner.apply(inner_params, inner_state,
                                                  None)
                state_updater.update(inner_state)
                return result, inner_state

        outer = transform.transform_with_state(lambda: Outer()())  # pylint: disable=unnecessary-lambda
        params, state = outer.init(None)
        self.assertEmpty(params)
        self.assertEqual(jax.tree_map(int, state), {"outer/~": {"w": 0}})

        # Each apply bumps the counter by one.
        for step in range(1, 4):
            (w, inner_state), state = outer.apply(params, state, None)
            self.assertEqual(jax.tree_map(int, inner_state),
                             {"~": {"w": step}})
            self.assertEqual(w, step)
            self.assertEmpty(params)
            self.assertEqual(state, {"outer/~": {"w": step}})
Exemple #4
0
 def test_stateful_module(self):
     """CountingModule's count is 0 after init and 10 after one apply."""
     counting = transform.transform_with_state(
         lambda: CountingModule()())  # pylint: disable=unnecessary-lambda
     params, state = counting.init(None)
     self.assertEqual(state, {"counting_module": {"count": 0}})
     state = counting.apply(params, state, None)[1]
     self.assertEqual(state, {"counting_module": {"count": 10}})
Exemple #5
0
    def test_lift_with_state(self):
        """lift_with_state stores the inner transform's state under "lifted"."""
        @transform.transform_with_state
        def inner():
            w = base.get_state("w", [], init=jnp.zeros) + 1
            base.set_state("w", w)
            return w

        @transform.transform_with_state
        def outer():
            init_lifted, updater = lift.lift_with_state(inner.init)
            inner_params, inner_state = init_lifted(None)
            self.assertEmpty(inner_params)
            out, inner_state = inner.apply(inner_params, inner_state, None)
            updater.update(inner_state)
            return out, inner_state

        params, state = outer.init(None)
        self.assertEmpty(params)
        self.assertEqual(jax.tree_map(int, state), {"lifted/~": {"w": 0}})

        # Each apply increments the lifted counter by one.
        for expected in range(1, 4):
            (w, inner_state), state = outer.apply(params, state, None)
            self.assertEqual(jax.tree_map(int, inner_state),
                             {"~": {"w": expected}})
            self.assertEqual(w, expected)
            self.assertEmpty(params)
            self.assertEqual(state, {"lifted/~": {"w": expected}})
Exemple #6
0
    def test_eval_shape(self):
        """stateful.eval_shape reports the output shape without advancing state."""
        def take_first_row(x):
            return x[0, :]

        @transform.transform_with_state
        def f(x):
            counter = CountingModule(op=take_first_row)
            # eval_shape must not advance the module's count.
            shape_struct = stateful.eval_shape(counter, x)
            return counter(x), shape_struct

        rng = jax.random.PRNGKey(42)
        in_shape = (10, 10)
        x = jnp.ones(in_shape)
        params, state = f.init(rng, x)
        self.assertEqual(list(state), ["counting_module"])
        self.assertEqual(list(state["counting_module"]), ["count"])
        np.testing.assert_allclose(state["counting_module"]["count"], 0,
                                   rtol=1e-4)

        (out, shape_struct), state = f.apply(params, state, rng, x)
        # Only the real call counts; the eval_shape call contributed nothing.
        np.testing.assert_allclose(state["counting_module"]["count"], 1,
                                   rtol=1e-4)
        np.testing.assert_allclose(out, take_first_row(x), rtol=1e-4)
        self.assertEqual(shape_struct.shape, (in_shape[1],))
Exemple #7
0
 def test_get_state_no_init(self):
     """get_state without an init leaves the supplied state untouched on apply."""
     apply_fn = transform.transform_with_state(
         lambda: base.get_state("i")).apply
     for value in range(10):
         state_in = {"~": {"i": value}}
         state_out = apply_fn({}, state_in, None)[1]
         self.assertEqual(state_in, state_out)
Exemple #8
0
 def test_without_state_raises_if_state_used_on_apply(self):
     """without_state rejects a function that writes state."""
     def f():
         return base.set_state("~", 1)

     f = transform.without_state(transform.transform_with_state(f))
     rng = jax.random.PRNGKey(42)
     # NOTE(review): set_state also runs during init, so init may raise before
     # apply is reached; confirm the apply path is actually exercised here.
     with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
         params = f.init(rng)
         f.apply(params, rng)
Exemple #9
0
  def test_simple_training_cross_replica_axis_index_groups(self):
    """BatchNorm with axis_index_groups normalises within each device group."""
    ldc = jax.local_device_count()
    if ldc < 2:
      self.skipTest("Cross-replica test requires at least 2 devices.")
    num_groups = ldc // 2
    num_group_devices = ldc // num_groups
    # The groups must tile the devices exactly. E.g. 5 devices would give
    # 2 groups of 2, leaving one device out and breaking the reshape below.
    if num_groups * num_group_devices != ldc:
      self.skipTest(
          f"Cannot split {ldc} devices into {num_groups} equal groups.")
    # for 8 devices this produces [[0, 1], [2, 3], [4, 5], [6, 7]] groups.
    groups = np.arange(ldc).reshape(num_groups, num_group_devices).tolist()

    def f(x, is_training=True):
      return batch_norm.BatchNorm(
          create_scale=False,
          create_offset=False,
          decay_rate=0.9,
          cross_replica_axis="i",
          cross_replica_axis_index_groups=groups,
      )(x, is_training=is_training)

    f = transform.transform_with_state(f)

    inputs = np.arange(ldc * 4).reshape(ldc, 4).astype(np.float32)
    key = np.broadcast_to(jax.random.PRNGKey(42), (ldc, 2))
    params, state = jax.pmap(f.init, axis_name="i")(key, inputs)
    result, _ = jax.pmap(f.apply, axis_name="i")(params, state, key, inputs)

    # Reference: normalise each group's rows with that group's statistics.
    expected = np.empty_like(inputs)
    for g in range(num_groups):
      rows = slice(num_group_devices * g, num_group_devices * (g + 1))
      group_inputs = inputs[rows]
      group_mean = np.mean(group_inputs, axis=0)
      group_std = np.std(group_inputs, axis=0) + 1e-10
      expected[rows] = (group_inputs - group_mean) / group_std

    np.testing.assert_array_almost_equal(result, expected)
Exemple #10
0
  def test_vmap(self):
    """stateful.vmap maps inputs/outputs but keeps module state unbatched."""
    def count_and_apply(inp):
      return CountingModule()(inp)

    @transform.transform_with_state
    def f(x):
      return stateful.vmap(count_and_apply)(x)

    x = jnp.ones([4]) + 1
    params, state = f.init(None, x)

    # Neither params nor state should pick up the mapped axis.
    self.assertEmpty(params)
    (cnt,) = jax.tree_leaves(state)
    self.assertEqual(cnt.ndim, 0)
    self.assertEqual(cnt, 0)

    # Outputs carry the mapped axis; the counter stays a scalar.
    y, state = f.apply(params, state, None, x)
    self.assertEqual(y.shape, (4,))
    np.testing.assert_allclose(y, x ** 2)
    (cnt,) = jax.tree_leaves(state)
    self.assertEqual(cnt.ndim, 0)
    self.assertEqual(cnt, 1)
Exemple #11
0
    def test_argspec(self):
        """init takes (rng); apply takes (params, state, rng)."""
        transformed = transform.transform_with_state(lambda: None)
        init_args = inspect.getfullargspec(transformed.init).args
        apply_args = inspect.getfullargspec(transformed.apply).args

        self.assertEqual(init_args, ["rng"])
        self.assertEqual(apply_args, ["params", "state", "rng"])
Exemple #12
0
 def wrapper(*a, **k):
     """Transform `f(*a, **k)` with state, run init, then apply if requested.

     Uses a PRNGKey built from the enclosing `seed`, or None when `seed` is
     None. Always returns None.
     """
     rng = None if seed is None else random.PRNGKey(seed)
     transformed = transform.transform_with_state(lambda: f(*a, **k))
     params, state = transformed.init(rng)
     if not run_apply:
         return
     transformed.apply(params, state, rng)
Exemple #13
0
    def test_scan_with_state(self, unroll_length):
        """stateful.scan advances module state once per scanned element."""
        @transform.transform_with_state
        def f(xs):
            counter = CountingModule()

            def step(carry, x):
                self.assertEqual(carry, ())
                return carry, counter(x)

            return stateful.scan(step, (), xs)[1]

        rng = jax.random.PRNGKey(42)
        xs = jnp.arange(unroll_length)
        params, state = f.init(rng, xs)
        self.assertEqual(list(state), ["counting_module"])
        self.assertEqual(list(state["counting_module"]), ["count"])
        np.testing.assert_allclose(state["counting_module"]["count"], 0,
                                   rtol=1e-4)

        ys, state = f.apply(params, state, rng, xs)
        np.testing.assert_allclose(state["counting_module"]["count"],
                                   unroll_length, rtol=1e-4)
        np.testing.assert_allclose(ys, xs**2, rtol=1e-4)
Exemple #14
0
 def test_invalid_rng_state(self):
   """Non-PRNG rng arguments are rejected with a descriptive ValueError."""
   f = transform.transform_with_state(lambda: None)
   bad_rng = "nonsense"
   with self.assertRaisesRegex(
       ValueError, "Init must be called with an RNG as the first argument"):
     f.init(bad_rng)
   with self.assertRaisesRegex(
       ValueError, "Apply must be called with an RNG as the third argument"):
     f.apply({}, {"x": {}}, bad_rng)
Exemple #15
0
 def test_get_state_no_shape_raises(self):
     """get_state with an init but no shape/dtype fails in init and apply."""
     transformed = transform.transform_with_state(
         lambda: base.get_state("i", init=jnp.zeros))
     expected_msg = "provide shape and dtype"
     with self.assertRaisesRegex(ValueError, expected_msg):
         transformed.init(None)
     params = {"~": {}}
     state = params  # params and state share one empty mapping
     with self.assertRaisesRegex(ValueError, expected_msg):
         transformed.apply(params, state, None)
Exemple #16
0
 def test_get_state_no_init_raises(self):
     """get_state with neither init nor stored value fails in init and apply."""
     transformed = transform.transform_with_state(
         lambda: base.get_state("i"))
     expected_msg = "set an init function"
     with self.assertRaisesRegex(ValueError, expected_msg):
         transformed.init(None)
     params = {"~": {}}
     state = params  # params and state share one empty mapping
     with self.assertRaisesRegex(ValueError, expected_msg):
         transformed.apply(params, state, None)
Exemple #17
0
    def test_unused_updater(self):
        """Returning a LiftWithStateUpdater without using it is an error."""
        def f() -> lift.LiftWithStateUpdater:
            noop = transform.transform_with_state(lambda: None)
            _, updater = lift.lift_with_state(noop.init)
            return updater

        f = transform.transform_with_state(f)

        with self.assertRaisesRegex(ValueError, "StateUpdater.*must be used"):
            f.init(None)
Exemple #18
0
  def test_named_call(self):
    """A named_call-wrapped SquareModule squares its input under jax.jit."""
    x = jnp.array(2.)
    rng = jax.random.PRNGKey(42)

    def f(inp):
      square = stateful.named_call(SquareModule(), name="square")
      return square(inp)

    init, apply = transform.transform_with_state(f)
    params, state = init(rng, x)
    y, state = jax.jit(apply)(params, state, rng, x)
    self.assertEqual(y, x ** 2)
Exemple #19
0
 def wrapper(*a, **k):
     """Transform `f(*a, **k)` with state, run init, optionally run apply.

     When the enclosing `jax_transform` is set, it wraps both init and apply.
     Returns apply's output when `run_apply` is true, otherwise None.
     """
     rng = None if seed is None else random.PRNGKey(seed)
     init, apply = transform.transform_with_state(lambda: f(*a, **k))
     if jax_transform:
         init = jax_transform(init)
         apply = jax_transform(apply)
     params, state = init(rng)
     if run_apply:
         out, _ = apply(params, state, rng)
         return out
Exemple #20
0
  def test_without_apply_rng_output_type(self):
    """without_apply_rng keeps the Transformed / TransformedWithState type."""
    def f():
      w = base.get_parameter("w", [], init=jnp.zeros)
      return w

    # Keep `f` bound to the plain function so both checks transform the same
    # user function, not an already-transformed object.
    with_state = transform.without_apply_rng(transform.transform_with_state(f))
    self.assertIsInstance(with_state, transform.TransformedWithState)

    without_state = transform.without_apply_rng(transform.transform(f))
    self.assertIsInstance(without_state, transform.Transformed)
Exemple #21
0
 def test_empty_lift_with_state(self, ignore_update):
     """Lifting an empty transform gives empty params and state."""
     empty = transform.transform_with_state(lambda: None)
     init_fn, updater = lift.lift_with_state(empty.init)
     params, state = init_fn(None)
     self.assertEmpty(params)
     self.assertEmpty(state)
     # Exercise both ways of consuming the updater.
     if not ignore_update:
         updater.update({})
     else:
         updater.ignore_update()
Exemple #22
0
  def test_grad_and_jit(self):
    """stateful.grad of SquareModule at x equals 2 * x under jax.jit."""
    @transform.transform_with_state
    def f(x):
      return stateful.grad(SquareModule())(x)

    x = jnp.array(3.)
    params, state = jax.jit(f.init)(None, x)
    g, state = jax.jit(f.apply)(params, state, None, x)
    np.testing.assert_allclose(g, 2 * x, rtol=1e-3)
Exemple #23
0
    def test_without_state(self):
        """without_state drops state from a function that only uses params."""
        def f():
            return base.get_parameter("w", [], init=jnp.zeros)

        transformed = transform.without_state(
            transform.transform_with_state(f))
        params = transformed.init(None)
        self.assertEqual(transformed.apply(params, None), 0)
Exemple #24
0
    def testEmaCrossReplica(self):
        """cross_replica_axis changes VectorQuantizerEMA only with >1 device."""
        embedding_dim = 6
        batch_size = 16
        num_local_devices = jax.local_device_count()
        inputs = np.random.rand(num_local_devices, batch_size, embedding_dim)

        # Defined once; axis_name is bound per iteration via functools.partial.
        def my_function(x, axis_name):
            decay = np.array(0.9, dtype=np.float32)
            vqvae_module = vqvae.VectorQuantizerEMA(
                embedding_dim=embedding_dim,
                num_embeddings=7,
                commitment_cost=0.5,
                decay=decay,
                cross_replica_axis=axis_name,
                dtype=jnp.float32)

            outputs = vqvae_module(x, is_training=True)
            return vqvae_module.embeddings, outputs['perplexity']

        embeddings = {}
        perplexities = {}
        for axis_name in [None, 'i']:
            vqvae_f = transform.transform_with_state(
                functools.partial(my_function, axis_name=axis_name))

            rng = jax.random.PRNGKey(42)
            rng = jnp.broadcast_to(rng, (num_local_devices, rng.shape[0]))

            params, state = jax.pmap(vqvae_f.init, axis_name='i')(rng, inputs)
            update_fn = jax.pmap(vqvae_f.apply, axis_name='i')

            for _ in range(10):
                outputs, state = update_fn(params, state, None, inputs)
            embeddings[axis_name], perplexities[axis_name] = outputs

        # In the single-device case, specifying a cross_replica_axis should have
        # no effect. Otherwise, it should!
        if jax.device_count() != 1:
            self.assertFalse((embeddings[None] == embeddings['i']).all())
            self.assertFalse((perplexities[None] == perplexities['i']).all())
            return

        # assert_allclose rather than exact equality so the test also passes
        # on GPU, presumably because of nondeterministic reductions.
        np.testing.assert_allclose(embeddings[None], embeddings['i'],
                                   rtol=1e-6, atol=1e-6)
        np.testing.assert_allclose(perplexities[None], perplexities['i'],
                                   rtol=1e-6, atol=1e-6)
Exemple #25
0
 def test_cond(self):
   """stateful.cond squares x when x == 2, otherwise squares x + 1."""
   @transform.transform_with_state
   def f(x):
     mod = SquareModule()
     return stateful.cond(x == 2, x, mod, x, lambda x: mod(x + 1))

   cases = ((1, 4), (2, 4), (3, 16))
   for x, y in cases:
     x, y = jnp.array(x), jnp.array(y)
     params, state = f.init(None, x)
     out, state = f.apply(params, state, None, x)
     self.assertEqual(state, {"square_module": {"y": y}})
     self.assertEqual(out, y)
Exemple #26
0
    def test_without_state_raises_if_state_used(self):
        """without_state rejects a function that creates state during init."""
        def bump():
            count = base.get_state("count", (), jnp.int32, jnp.zeros)
            base.set_state("count", count + 1)
            return count

        def f():
            count = None
            for _ in range(10):
                count = bump()
            return count

        transformed = transform.without_state(
            transform.transform_with_state(f))

        with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
            transformed.init(None)
Exemple #27
0
    def test_updater_used_in_different_inner_transform(self, updater_fn):
        """An updater must be consumed in the same init/apply that created it."""
        @transform.transform_with_state
        def f():
            noop = transform.transform_with_state(lambda: None)
            updater = lift.lift_with_state(noop.init)[1]
            other = transform.transform_with_state(lambda: updater_fn(updater))
            other.init(None)

        with self.assertRaisesRegex(
                ValueError, "must be used within the same call to init/apply"):
            f.init(None)
Exemple #28
0
    def test_stateful(self):
        """Ten get/set round trips per call leave "count" at 0 after init and 10 after apply."""
        def bump():
            count = base.get_state("count", (), jnp.int32, jnp.zeros)
            base.set_state("count", count + 1)
            return count

        def f():
            count = None
            for _ in range(10):
                count = bump()
            return count

        transformed = transform.transform_with_state(f)
        params, state = transformed.init(None)
        self.assertEqual(state, {"~": {"count": 0}})
        state = transformed.apply(params, state, None)[1]
        self.assertEqual(state, {"~": {"count": 10}})
Exemple #29
0
  def test_switch(self):
    """stateful.switch runs the branch selected by index i."""
    @transform.transform_with_state
    def f(i, x):
      mod = SquareModule()
      return stateful.switch(
          i, [mod, lambda x: mod(x + 1), lambda x: mod(x + 2)], x)

    cases = ((0, 1, 1), (1, 2, 9), (2, 3, 25))
    for case in cases:
      i, x, y = (jnp.array(v) for v in case)
      params, state = f.init(None, i, x)
      out, state = f.apply(params, state, None, i, x)
      self.assertEqual(state, {"square_module": {"y": y}})
      self.assertEqual(out, y)
Exemple #30
0
    def test_set_then_get(self):
        """get_state after set_state in the same call returns the new value."""
        @transform.transform_with_state
        def net():
            base.set_state("i", 1)
            return base.get_state("i")

        params, state = net.init(None)
        self.assertEqual(state, {"~": {"i": 1}})

        # Whatever state is fed in, the set_state overwrites it.
        for previous in range(10):
            y, state_out = net.apply(params, {"~": {"i": previous}}, None)
            self.assertEqual(y, 1)
            self.assertEqual(state_out, {"~": {"i": 1}})