Example #1
0
class ResnetTest(parameterized.TestCase):
    """Output-shape and constructor-validation checks for ``resnet.ResNet``."""

    @test_utils.combined_named_parameters(
        test_utils.named_bools("resnet_v2"),
        test_utils.named_bools("bottleneck"),
    )
    @test_utils.transform_and_run
    def test_simple(self, resnet_v2, bottleneck):
        """Logits have shape ``(2, 10)`` for both values of ``is_training``."""
        batch = jnp.ones([2, 64, 64, 3])
        net = resnet.ResNet(
            [1, 1, 1, 1], 10, resnet_v2=resnet_v2, bottleneck=bottleneck)

        for training in (True, False):
            out = net(batch, is_training=training)
            self.assertEqual(out.shape, (2, 10))

    @test_utils.combined_named_parameters(
        test_utils.named_bools("resnet_v2"),
        test_utils.named_bools("bottleneck"),
    )
    def test_local_stats(self, resnet_v2, bottleneck):
        """A forward pass with ``test_local_stats=True`` gives ``(2, 10)`` logits."""

        def forward_fn(inputs):
            net = resnet.ResNet(
                [1, 1, 1, 1], 10, resnet_v2=resnet_v2, bottleneck=bottleneck)
            return net(inputs, is_training=False, test_local_stats=True)

        transformed = transform.transform(forward_fn, apply_rng=True)
        key = jax.random.PRNGKey(42)
        batch = jnp.ones([2, 64, 64, 3])
        weights = transformed.init(key, batch)
        # ``None`` is passed in the RNG slot of ``apply``.
        out = transformed.apply(weights, None, batch)
        self.assertEqual(out.shape, (2, 10))

    @parameterized.parameters(3, 5)
    @test_utils.transform_and_run
    def test_error_incorrect_args_block_list(self, list_length):
        """A block list of length 3 or 5 is rejected with a ``ValueError``."""
        bad_blocks = list(range(list_length))
        expected_error = (
            "blocks_per_group` must be of length 4 not {}".format(list_length))
        config = {"decay_rate": 0.9, "eps": 1e-5}
        with self.assertRaisesRegex(ValueError, expected_error):
            resnet.ResNet(bad_blocks, 10, config)

    @parameterized.parameters(3, 5)
    @test_utils.transform_and_run
    def test_error_incorrect_args_channel_list(self, list_length):
        """A channel list of length 3 or 5 is rejected with a ``ValueError``."""
        bad_channels = list(range(list_length))
        expected_error = (
            "channels_per_group` must be of length 4 not {}".format(
                list_length))
        config = {"decay_rate": 0.9, "eps": 1e-5}
        with self.assertRaisesRegex(ValueError, expected_error):
            resnet.ResNet([1, 1, 1, 1],
                          10,
                          config,
                          channels_per_group=bad_channels)
Example #2
0
class NumpyInputsTest(parameterized.TestCase):
  """Checks that modules agree when fed NumPy values instead of JAX arrays."""

  # NOTE: The boolean groups are listed in the same order as the matching
  # arguments of the test method, so that each generated test name describes
  # the values the test really receives.
  # NOTE(review): assumes `combined_named_parameters` forwards the combined
  # values positionally - confirm against `test_utils`.
  @test_utils.combined_named_parameters(
      descriptors.ALL_MODULES,
      test_utils.named_bools('np_params'),
      test_utils.named_bools('np_inputs'),
      test_utils.named_bools('close_over_params'))
  def test_numpy_and_jax_results_close(
      self,
      module_fn: ModuleFn,
      shape: Tuple[int, ...],
      dtype: jnp.dtype,
      np_params: bool,
      np_inputs: bool,
      close_over_params: bool,
  ):
    """Compares init/apply results before and after `jax.device_get`.

    Args:
      module_fn: Zero-argument callable building the module under test.
      shape: Shape of the all-ones input.
      dtype: Dtype of the all-ones input.
      np_params: Whether params/state go through `jax.device_get` before the
        second apply.
      np_inputs: Whether the rng and input go through `jax.device_get` before
        the second init/apply.
      close_over_params: Whether params/state are bound with
        `functools.partial` before jitting rather than passed as arguments.
    """
    if not (np_params or np_inputs):
      self.skipTest('Pure JAX variants tested elsewhere')

    f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda

    def jitted_apply(params, state, rng, x):
      """Runs the jitted apply, optionally closing over params and state."""
      if close_over_params:
        return jax.jit(functools.partial(f.apply, params, state))(rng, x)
      return jax.jit(f.apply)(params, state, rng, x)

    rng = jax.random.PRNGKey(42)
    x = jnp.ones(shape, dtype)
    params, state = f.init(rng, x)
    # Reference results, computed before any value is fetched to the host.
    out, new_state = jitted_apply(params, state, rng, x)

    if np_inputs:
      rng, x = jax.device_get((rng, x))

      with self.subTest('init'):
        params2, state2 = f.init(rng, x)
        tree_assert_allclose(params, params2)
        tree_assert_allclose(state, state2)

    with self.subTest('apply'):
      if np_params:
        params, state = jax.device_get((params, state))

      out2, new_state2 = jitted_apply(params, state, rng, x)

      tree_assert_allclose(out, out2)
      tree_assert_allclose(new_state, new_state2)
Example #3
0
class JaxToTfTest(parameterized.TestCase):
    """Compares Haiku init/apply functions with their ``jax2tf`` conversions."""

    @test_utils.combined_named_parameters(descriptors.ALL_MODULES,
                                          test_utils.named_bools("init"),
                                          TF_TRANSFORM, JAX_TRANSFORM)
    def test_convert(
        self,
        module_fn: ModuleFn,
        shape,
        dtype,
        init: bool,
        tf_transform,
        jax_transform,
    ):
        """Checks the converted function matches the JAX one within ``atol``.

        Args:
            module_fn: Zero-argument callable building the module under test.
            shape: Shape of the random input.
            dtype: Dtype of the random input.
            init: If true compare ``init``, otherwise compare ``apply``.
            tf_transform: Wrapper applied to the ``jax2tf``-converted function.
            jax_transform: Wrapper applied to the plain JAX function.
        """
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            x = jax.random.uniform(rng, shape, dtype)

        def g(x):
            return module_fn()(x)

        f = hk.transform_with_state(g)

        # Per-module tolerance, falling back to the default one.
        atol = CUSTOM_ATOL.get(descriptors.module_type(module_fn),
                               DEFAULT_ATOL)
        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=atol)

        def to_numpy(tree):
            """Calls ``.numpy()`` on every leaf of ``tree``."""
            return jax.tree_util.tree_map(lambda leaf: leaf.numpy(), tree)

        # ``jax.tree_util.tree_map`` accepts several trees; it replaces the
        # deprecated ``jax.tree_multimap`` / ``jax.tree_map`` aliases.
        if init:
            init_jax = jax_transform(f.init)
            init_tf = tf_transform(jax2tf.convert(f.init))
            jax.tree_util.tree_map(assert_allclose, init_jax(rng, x),
                                   to_numpy(init_tf(rng, x)))

        else:
            params, state = f.init(rng, x)
            apply_jax = jax_transform(f.apply)
            apply_tf = tf_transform(jax2tf.convert(f.apply))
            jax.tree_util.tree_map(assert_allclose,
                                   apply_jax(params, state, rng, x),
                                   to_numpy(apply_tf(params, state, rng, x)))
Example #4
0
class HaikuTransformsTest(parameterized.TestCase):
  """Compares Haiku transforms with plain JAX on the described modules."""

  @staticmethod
  def _random_input(rng, shape, dtype):
    """Samples an input of `shape` and `dtype`; integers use `randint`."""
    if jnp.issubdtype(dtype, jnp.integer):
      return jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    return jax.random.uniform(rng, shape, dtype)

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES,
                                        test_utils.named_bools('init'))
  def test_hk_jit(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
      init: bool,
  ):
    """Checks `hk.jit` on the module agrees with `jax.jit` on init/apply.

    Args:
      module_fn: Zero-argument callable building the module under test.
      shape: Shape of the random input.
      dtype: Dtype of the random input.
      init: If true compare `init`, otherwise compare `apply`.
    """
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, jit=False):
      mod = module_fn()
      if jit:
        mod = hk.jit(mod)
      return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-4)

    # NOTE: We shard init/apply tests since some modules are expensive to jit
    # (e.g. ResNet50 takes ~60s to compile and we compile it twice per test).
    if init:
      jax.tree_util.tree_map(assert_allclose,
                             jax.jit(f.init)(rng, x),
                             f.init(rng, x, jit=True))

    else:
      params, state = f.init(rng, x)
      jax.tree_util.tree_map(assert_allclose,
                             jax.jit(f.apply)(params, state, rng, x),
                             f.apply(params, state, rng, x, jit=True))

  @test_utils.combined_named_parameters(
      # TODO(tomhennigan) Enable once grad for _scan_transpose implemented.
      set(descriptors.ALL_MODULES) - set(descriptors.RECURRENT_MODULES))
  def test_hk_remat(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    """Checks gradients through `hk.remat` match those through `jax.remat`.

    Args:
      module_fn: Zero-argument callable building the module under test.
      shape: Shape of the random input.
      dtype: Dtype of the random input.
    """
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, remat=False):
      mod = module_fn()
      if remat:
        mod = hk.remat(mod)
      out = mod(x)
      # Dict outputs are reduced through their 'loss' entry.
      if isinstance(out, dict):
        out = out['loss']
      return jnp.mean(out)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    # `f.apply` returns `(output, state)`, so the state is treated as aux.
    grad_jax_remat = jax.grad(jax.remat(f.apply), has_aux=True)
    grad_hk_remat = jax.grad(functools.partial(f.apply, remat=True),
                             has_aux=True)

    params, state = f.init(rng, x)
    jax.tree_util.tree_map(assert_allclose,
                           grad_jax_remat(params, state, rng, x),
                           grad_hk_remat(params, state, rng, x))

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES)
  def test_profiler_name_scopes(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    """Checks enabling profiler name scopes leaves apply results unchanged.

    Args:
      module_fn: Zero-argument callable building the module under test.
      shape: Shape of the random input.
      dtype: Dtype of the random input.
    """
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, name_scopes=False):
      hk.experimental.profiler_name_scopes(enabled=name_scopes)
      mod = module_fn()
      return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    params, state = f.init(rng, x)
    try:
      jax.tree_util.tree_map(assert_allclose,
                             f.apply(params, state, rng, x),
                             f.apply(params, state, rng, x, name_scopes=True))
    finally:
      # Reset the flag even when the comparison fails, so that a failure here
      # does not leave name scopes enabled for whatever runs next.
      # TODO(lenamartens): flip to True when default changes
      hk.experimental.profiler_name_scopes(enabled=False)

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES)
  def test_optimize_rng_use_under_jit(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    """Checks `optimize_rng_use` gives matching results with and without jit.

    Args:
      module_fn: Zero-argument callable building the module under test.
      shape: Shape of the random input.
      dtype: Dtype of the random input.
    """
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x):
      return module_fn()(x)

    f = hk.transform_with_state(hk.experimental.optimize_rng_use(g))

    module_type = descriptors.module_type(module_fn)
    atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    params, state = jax.jit(f.init)(rng, x)
    jax.tree_util.tree_map(assert_allclose, (params, state), f.init(rng, x))

    if module_type in (hk.nets.VectorQuantizer, hk.nets.VectorQuantizerEMA):
      # For stochastic modules just test apply runs.
      jax.device_get(jax.jit(f.apply)(params, state, rng, x))

    else:
      jax.tree_util.tree_map(assert_allclose,
                             jax.jit(f.apply)(params, state, rng, x),
                             f.apply(params, state, rng, x))
Example #5
0
class SummariseTest(parameterized.TestCase):
    """Tests for the list of invocations returned by ``get_summary``."""

    def test_empty(self):
        """A function that does nothing yields an empty summary."""

        def noop():
            return None

        self.assertEmpty(get_summary(noop))

    def test_filters_ctor_only(self):
        """Constructing a module without calling it yields an empty summary."""

        def construct_only():
            # NOTE: Just calling ctor.
            return IdentityModule()

        self.assertEmpty(get_summary(construct_only))

    @parameterized.parameters(*range(1, 5))
    def test_one_row_per_method_call(self, num_calls):
        """Each call adds one row and all rows share one method name."""
        x = jnp.ones([])

        def call_repeatedly():
            module = IdentityModule()
            for _ in range(num_calls):
                module(x)

        rows = get_summary(call_repeatedly)
        self.assertLen(rows, num_calls)
        first, *others = rows
        for row in others:
            self.assertEqual(first.context.method_name,
                             row.context.method_name)

    @test_utils.combined_named_parameters(
        test_utils.named_bools("params"),
        test_utils.named_range("num_elems", 8),
    )
    def test_params_or_state(self, params, num_elems):
        """Created params (or state) are listed as ``foo/x<i>`` in order."""
        getter = base.get_parameter if params else base.get_state

        def cls():
            for i in range(num_elems):
                getter(f"x{i}", [], init=jnp.zeros)

        def f():
            return basic.to_module(cls)(name="foo")()

        # Exactly one invocation is expected; unpacking enforces that.
        (invocation,) = get_summary(f)
        details = invocation.module_details
        if params:
            entries = details.params
        else:
            entries = details.state
        expected = [f"foo/x{i}" for i in range(num_elems)]
        self.assertEqual(list(entries), expected)

    def test_jitted_f(self):
        """Summarising a wrapper around a jitted apply yields a single row."""
        calls = []

        def f(x):
            calls.append(None)
            return basic.Linear(1)(x)

        f = transform.transform(f)
        key = jax.random.PRNGKey(42)
        x = jnp.zeros([1, 1])
        params = f.init(key, x)
        # Forget the call recorded during init.
        calls.clear()

        # This layer of indirection (`g`) means summarise cannot unpack `f` and
        # strip our jit.
        jit_apply = jax.jit(f.apply)

        def g(params, x):
            return jit_apply(params, None, x)

        for _ in range(2):
            g(params, x)  # Warm up JIT.
            self.assertLen(calls, 1)

        rows = get_summary(g, params, x)
        self.assertLen(rows, 1)
Example #6
0
class ResnetTest(parameterized.TestCase):
    """Tests for ``resnet.ResNet`` and the named ResNet classes."""

    @test_utils.combined_named_parameters(test_utils.named_bools("resnet_v2"),
                                          test_utils.named_bools("bottleneck"))
    @test_utils.transform_and_run
    def test_simple(self, resnet_v2, bottleneck):
        """Logits have shape ``(2, 10)`` for both values of ``is_training``."""
        image = jnp.ones([2, 64, 64, 3])
        model = resnet.ResNet([1, 1, 1, 1],
                              10,
                              resnet_v2=resnet_v2,
                              bottleneck=bottleneck)

        for is_training in (True, False):
            logits = model(image, is_training=is_training)
            self.assertEqual(logits.shape, (2, 10))

    @test_utils.combined_named_parameters(test_utils.named_bools("resnet_v2"),
                                          test_utils.named_bools("bottleneck"))
    def test_local_stats(self, resnet_v2, bottleneck):
        """A forward pass with ``test_local_stats=True`` gives ``(2, 10)`` logits."""

        def forward_fn(image):
            model = resnet.ResNet([1, 1, 1, 1],
                                  10,
                                  resnet_v2=resnet_v2,
                                  bottleneck=bottleneck)
            return model(image, is_training=False, test_local_stats=True)

        forward = transform.transform(forward_fn)
        rng = jax.random.PRNGKey(42)
        image = jnp.ones([2, 64, 64, 3])
        params = forward.init(rng, image)
        # ``None`` is passed in the RNG slot of ``apply``.
        logits = forward.apply(params, None, image)
        self.assertEqual(logits.shape, (2, 10))

    @parameterized.parameters(3, 5)
    @test_utils.transform_and_run
    def test_error_incorrect_args_block_list(self, list_length):
        """A block list of length 3 or 5 is rejected with a ``ValueError``."""
        block_list = list(range(list_length))
        with self.assertRaisesRegex(
                ValueError,
                "blocks_per_group` must be of length 4 not {}".format(
                    list_length)):
            resnet.ResNet(block_list, 10, {"decay_rate": 0.9, "eps": 1e-5})

    @parameterized.parameters(3, 5)
    @test_utils.transform_and_run
    def test_error_incorrect_args_channel_list(self, list_length):
        """A channel list of length 3 or 5 is rejected with a ``ValueError``."""
        channel_list = list(range(list_length))
        with self.assertRaisesRegex(
                ValueError,
                "channels_per_group` must be of length 4 not {}".format(
                    list_length)):
            resnet.ResNet([1, 1, 1, 1],
                          10, {
                              "decay_rate": 0.9,
                              "eps": 1e-5
                          },
                          channels_per_group=channel_list)

    @test_utils.combined_named_parameters(
        [(i, (getattr(resnet, i), n))
         for i, n in zip(_RESNETS, _RESNET_NUM_PARAMS)],
        test_utils.named_bools("resnet_v2"),
    )
    def test_num_params(self, resnet_class_and_num_params, resnet_v2):
        """The parameter count is within 0.2% of the expected value."""
        resnet_class, expected_num_params = resnet_class_and_num_params

        def model_func(img):
            model = resnet_class(1000, resnet_v2=resnet_v2)
            return model(img, is_training=True)

        model = hk.transform_with_state(model_func)
        image = jnp.ones([2, 64, 64, 3])
        rng = jax.random.PRNGKey(0)
        params, _ = model.init(rng, image)
        # ``size`` gives the element count of each leaf directly, avoiding an
        # array computation just to multiply shape entries together.
        num_params = sum(
            p.size for p in jax.tree_util.tree_leaves(params))
        self.assertGreater(num_params, int(0.998 * expected_num_params))
        self.assertLess(num_params, int(1.002 * expected_num_params))

    @test_utils.combined_named_parameters(
        [(i, (getattr(resnet, i), p))
         for i, p in zip(_RESNETS, _RESNET_HAS_PROJECTION)],
        test_utils.named_bools("resnet_v2"),
    )
    @test_utils.transform_and_run
    def test_has_projection(self, resnet_class_and_has_projection, resnet_v2):
        """Only the first block of a group may carry a ``proj_conv``."""
        resnet_class, has_projection = resnet_class_and_has_projection
        model = resnet_class(1000, resnet_v2=resnet_v2)
        for i, block_group in enumerate(model.block_groups):
            if i == 0:
                # For the first group the expectation depends on the class.
                self.assertEqual(hasattr(block_group.blocks[0], "proj_conv"),
                                 has_projection)
            else:
                self.assertTrue(hasattr(block_group.blocks[0], "proj_conv"))

            for block in block_group.blocks[1:]:
                self.assertFalse(hasattr(block, "proj_conv"))
class HaikuTransformsTest(parameterized.TestCase):
  """Compares Haiku transforms with plain JAX on the described modules."""

  @staticmethod
  def _random_input(rng, shape, dtype):
    """Samples an input of `shape` and `dtype`; integers use `randint`."""
    if jnp.issubdtype(dtype, jnp.integer):
      return jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    return jax.random.uniform(rng, shape, dtype)

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES,
                                        test_utils.named_bools('init'))
  def test_hk_jit(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
      init: bool,
  ):
    """Checks `hk.jit` on the module agrees with `jax.jit` on init/apply.

    Args:
      module_fn: Zero-argument callable building the module under test.
      shape: Shape of the random input.
      dtype: Dtype of the random input.
      init: If true compare `init`, otherwise compare `apply`.
    """
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, jit=False):
      mod = module_fn()
      if jit:
        mod = hk.jit(mod)
      return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    # NOTE: We shard init/apply tests since some modules are expensive to jit
    # (e.g. ResNet50 takes ~60s to compile and we compile it twice per test).
    if init:
      jax.tree_util.tree_map(assert_allclose,
                             jax.jit(f.init)(rng, x),
                             f.init(rng, x, jit=True))

    else:
      params, state = f.init(rng, x)
      jax.tree_util.tree_map(assert_allclose,
                             jax.jit(f.apply)(params, state, rng, x),
                             f.apply(params, state, rng, x, jit=True))

  @test_utils.combined_named_parameters(
      # TODO(tomhennigan) Enable once grad for _scan_transpose implemented.
      set(descriptors.ALL_MODULES) - set(descriptors.RECURRENT_MODULES))
  def test_hk_remat(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    """Checks gradients through `hk.remat` match those through `jax.remat`.

    Args:
      module_fn: Zero-argument callable building the module under test.
      shape: Shape of the random input.
      dtype: Dtype of the random input.
    """
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, remat=False):
      mod = module_fn()
      if remat:
        mod = hk.remat(mod)
      out = mod(x)
      # Dict outputs cannot be averaged directly; reduce via the 'loss' entry.
      if isinstance(out, dict):
        out = out['loss']
      return jnp.mean(out)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    # `f.apply` returns `(output, state)`, so the state is treated as aux.
    grad_jax_remat = jax.grad(jax.remat(f.apply), has_aux=True)
    grad_hk_remat = jax.grad(functools.partial(f.apply, remat=True),
                             has_aux=True)

    params, state = f.init(rng, x)
    jax.tree_util.tree_map(assert_allclose,
                           grad_jax_remat(params, state, rng, x),
                           grad_hk_remat(params, state, rng, x))
Example #8
0
class HaikuTransformsTest(parameterized.TestCase):
    """Compares Haiku transforms with plain JAX on the described modules."""

    @staticmethod
    def _random_input(rng, shape, dtype):
        """Samples an input of `shape` and `dtype`; integers use `randint`."""
        if jnp.issubdtype(dtype, jnp.integer):
            return jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        return jax.random.uniform(rng, shape, dtype)

    @test_utils.combined_named_parameters(descriptors.ALL_MODULES,
                                          test_utils.named_bools('init'))
    def test_hk_jit(
        self,
        module_fn: descriptors.ModuleFn,
        shape: Shape,
        dtype: DType,
        init: bool,
    ):
        """Checks `hk.jit` on the module agrees with `jax.jit` on init/apply.

        Args:
            module_fn: Zero-argument callable building the module under test.
            shape: Shape of the random input.
            dtype: Dtype of the random input.
            init: If true compare `init`, otherwise compare `apply`.
        """
        rng = jax.random.PRNGKey(42)
        x = self._random_input(rng, shape, dtype)

        def g(x, jit=False):
            mod = module_fn()
            if jit:
                mod = hk.jit(mod)
            return mod(x)

        f = hk.transform_with_state(g)

        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=1e-4)

        # NOTE: We shard init/apply tests since some modules are expensive to jit
        # (e.g. ResNet50 takes ~60s to compile and we compile it twice per test).
        if init:
            jax.tree_util.tree_map(assert_allclose,
                                   jax.jit(f.init)(rng, x),
                                   f.init(rng, x, jit=True))

        else:
            params, state = f.init(rng, x)
            jax.tree_util.tree_map(assert_allclose,
                                   jax.jit(f.apply)(params, state, rng, x),
                                   f.apply(params, state, rng, x, jit=True))

    @test_utils.combined_named_parameters(
        # TODO(tomhennigan) Enable once grad for _scan_transpose implemented.
        set(descriptors.ALL_MODULES) - set(descriptors.RECURRENT_MODULES))
    def test_hk_remat(
        self,
        module_fn: descriptors.ModuleFn,
        shape: Shape,
        dtype: DType,
    ):
        """Checks gradients through `hk.remat` match those through `jax.remat`.

        Args:
            module_fn: Zero-argument callable building the module under test.
            shape: Shape of the random input.
            dtype: Dtype of the random input.
        """
        rng = jax.random.PRNGKey(42)
        x = self._random_input(rng, shape, dtype)

        def g(x, remat=False):
            mod = module_fn()
            if remat:
                mod = hk.remat(mod)
            out = mod(x)
            # Dict outputs are reduced through their 'loss' entry.
            if isinstance(out, dict):
                out = out['loss']
            return jnp.mean(out)

        f = hk.transform_with_state(g)

        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=1e-5)

        # `f.apply` returns `(output, state)`; the state is treated as aux.
        grad_jax_remat = jax.grad(jax.remat(f.apply), has_aux=True)
        grad_hk_remat = jax.grad(functools.partial(f.apply, remat=True),
                                 has_aux=True)

        params, state = f.init(rng, x)
        jax.tree_util.tree_map(assert_allclose,
                               grad_jax_remat(params, state, rng, x),
                               grad_hk_remat(params, state, rng, x))

    @test_utils.combined_named_parameters(descriptors.ALL_MODULES)
    def test_profiler_name_scopes(
        self,
        module_fn: descriptors.ModuleFn,
        shape: Shape,
        dtype: DType,
    ):
        """Checks enabling profiler name scopes leaves apply results unchanged.

        Args:
            module_fn: Zero-argument callable building the module under test.
            shape: Shape of the random input.
            dtype: Dtype of the random input.
        """
        if not hasattr(xla.xb, 'parameter'):
            self.skipTest('Need Jaxlib version > 0.1.45')

        rng = jax.random.PRNGKey(42)
        x = self._random_input(rng, shape, dtype)

        def g(x, name_scopes=False):
            hk.experimental.profiler_name_scopes(enabled=name_scopes)
            mod = module_fn()
            return mod(x)

        f = hk.transform_with_state(g)

        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=1e-5)

        params, state = f.init(rng, x)
        try:
            jax.tree_util.tree_map(
                assert_allclose, f.apply(params, state, rng, x),
                f.apply(params, state, rng, x, name_scopes=True))
        finally:
            # Reset the flag even when the comparison fails, so a failure here
            # does not leave name scopes enabled for whatever runs next.
            # TODO(lenamartens): flip to True when default changes
            hk.experimental.profiler_name_scopes(enabled=False)
Example #9
0
class ResnetTest(parameterized.TestCase):
    """Tests for ``resnet.ResNet`` and the named ResNet classes."""

    @test_utils.combined_named_parameters(test_utils.named_bools("resnet_v2"),
                                          test_utils.named_bools("bottleneck"))
    @test_utils.transform_and_run
    def test_simple(self, resnet_v2, bottleneck):
        """Logits have shape ``(2, 10)`` for both values of ``is_training``."""
        image = jnp.ones([2, 64, 64, 3])
        model = resnet.ResNet([1, 1, 1, 1],
                              10,
                              resnet_v2=resnet_v2,
                              bottleneck=bottleneck)

        for is_training in (True, False):
            logits = model(image, is_training=is_training)
            self.assertEqual(logits.shape, (2, 10))

    @test_utils.combined_named_parameters(test_utils.named_bools("resnet_v2"),
                                          test_utils.named_bools("bottleneck"))
    def test_local_stats(self, resnet_v2, bottleneck):
        """A forward pass with ``test_local_stats=True`` gives ``(2, 10)`` logits."""

        def forward_fn(image):
            model = resnet.ResNet([1, 1, 1, 1],
                                  10,
                                  resnet_v2=resnet_v2,
                                  bottleneck=bottleneck)
            return model(image, is_training=False, test_local_stats=True)

        forward = transform.transform(forward_fn)
        rng = jax.random.PRNGKey(42)
        image = jnp.ones([2, 64, 64, 3])
        params = forward.init(rng, image)
        # ``None`` is passed in the RNG slot of ``apply``.
        logits = forward.apply(params, None, image)
        self.assertEqual(logits.shape, (2, 10))

    @parameterized.parameters(3, 5)
    @test_utils.transform_and_run
    def test_error_incorrect_args_block_list(self, list_length):
        """A block list of length 3 or 5 is rejected with a ``ValueError``."""
        block_list = list(range(list_length))
        with self.assertRaisesRegex(
                ValueError,
                "blocks_per_group` must be of length 4 not {}".format(
                    list_length)):
            resnet.ResNet(block_list, 10, {"decay_rate": 0.9, "eps": 1e-5})

    @parameterized.parameters(3, 5)
    @test_utils.transform_and_run
    def test_error_incorrect_args_channel_list(self, list_length):
        """A channel list of length 3 or 5 is rejected with a ``ValueError``."""
        channel_list = list(range(list_length))
        with self.assertRaisesRegex(
                ValueError,
                "channels_per_group` must be of length 4 not {}".format(
                    list_length)):
            resnet.ResNet([1, 1, 1, 1],
                          10, {
                              "decay_rate": 0.9,
                              "eps": 1e-5
                          },
                          channels_per_group=channel_list)

    @test_utils.combined_named_parameters(
        [(i, (getattr(resnet, i), n))
         for i, n in zip(_RESNETS, _RESNET_NUM_PARAMS)],
        test_utils.named_bools("resnet_v2"),
    )
    def test_num_params(self, resnet_class_and_num_params, resnet_v2):
        """The parameter count is within 0.2% of the expected value."""
        resnet_class, expected_num_params = resnet_class_and_num_params

        def model_func(img):
            model = resnet_class(1000, resnet_v2=resnet_v2)
            return model(img, is_training=True)

        model = hk.transform_with_state(model_func)
        image = jnp.ones([2, 64, 64, 3])
        rng = jax.random.PRNGKey(0)
        params, _ = model.init(rng, image)
        # ``size`` is the element count of each leaf.
        num_params = sum(
            p.size for p in jax.tree_util.tree_leaves(params))
        self.assertGreater(num_params, int(0.998 * expected_num_params))
        self.assertLess(num_params, int(1.002 * expected_num_params))

    @test_utils.combined_named_parameters(
        [(i, (getattr(resnet, i), p))
         for i, p in zip(_RESNETS, _RESNET_HAS_PROJECTION)],
        test_utils.named_bools("resnet_v2"),
    )
    @test_utils.transform_and_run
    def test_has_projection(self, resnet_class_and_has_projection, resnet_v2):
        """Only the first block of a group may carry a ``proj_conv``."""
        resnet_class, has_projection = resnet_class_and_has_projection
        model = resnet_class(1000, resnet_v2=resnet_v2)
        for i, block_group in enumerate(model.block_groups):
            if i == 0:
                # For the first group the expectation depends on the class.
                self.assertEqual(hasattr(block_group.blocks[0], "proj_conv"),
                                 has_projection)
            else:
                self.assertTrue(hasattr(block_group.blocks[0], "proj_conv"))

            for block in block_group.blocks[1:]:
                self.assertFalse(hasattr(block, "proj_conv"))

    @test_utils.combined_named_parameters(
        [(i, getattr(resnet, i)) for i in _RESNETS],
        test_utils.named_bools("resnet_v2"),
    )
    def test_logits_config(self, resnet_class, resnet_v2):
        """Logits weights are zeros by default and follow ``logits_config``."""

        def model_func_logits_config_default(img):
            model = resnet_class(1000, resnet_v2=resnet_v2)
            return model(img, is_training=True)

        def model_func_logits_config_modified(img):
            model = resnet_class(1000,
                                 resnet_v2=resnet_v2,
                                 logits_config=dict(w_init=jnp.ones))
            return model(img, is_training=True)

        image = jnp.ones([2, 64, 64, 3])
        rng = jax.random.PRNGKey(0)

        model = hk.transform_with_state(model_func_logits_config_default)
        params, _ = model.init(rng, image)
        logits_keys = [k for k in params if "/logits" in k]
        self.assertLen(logits_keys, 1)

        # Check logits params are zeros
        w_logits = params[logits_keys[0]]["w"]
        np.testing.assert_allclose(jnp.zeros_like(w_logits), w_logits)

        model = hk.transform_with_state(model_func_logits_config_modified)
        params, _ = model.init(rng, image)

        # Check logits params are ones (the key found above is looked up again
        # in the newly initialised params).
        w_logits = params[logits_keys[0]]["w"]
        np.testing.assert_allclose(jnp.ones_like(w_logits), w_logits)

    @test_utils.combined_named_parameters(
        [(i, getattr(resnet, i)) for i in _RESNETS], )
    @test_utils.transform_and_run
    def test_initial_conv_config(self, resnet_cls):
        """Every ``initial_conv_config`` entry shows up on ``net.initial_conv``."""
        config = dict(name="custom_name",
                      output_channels=32,
                      kernel_shape=(3, 3),
                      stride=(1, 1),
                      padding="VALID",
                      with_bias=True)
        net = resnet_cls(1000, initial_conv_config=config)
        for key, value in config.items():
            self.assertEqual(getattr(net.initial_conv, key), value)
Example #10
0
class StatefulTest(parameterized.TestCase):

  @test_utils.transform_and_run
  def test_grad(self):
    """`stateful.grad` of a `SquareModule` evaluates to `2 * x`."""
    point = jnp.array(3.)
    grad_fn = stateful.grad(SquareModule())
    np.testing.assert_allclose(grad_fn(point), 2 * point, rtol=1e-4)

  def test_grad_no_transform(self):
    """`stateful.grad` raises `ValueError` in a test with no transform wrapper."""
    point = jnp.array(3.)
    # NOTE: `msg` is only the text reported if no ValueError is raised; it is
    # not matched against the exception's message.
    with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
      grad_fn = stateful.grad(jnp.square)
      grad_fn(point)

  @test_utils.transform_and_run
  def test_value_and_grad(self):
    """`stateful.value_and_grad` yields `x ** 2` and a gradient of `2 * x`."""
    point = jnp.array(2.)
    value_and_grad_fn = stateful.value_and_grad(SquareModule())
    value, grad = value_and_grad_fn(point)
    self.assertEqual(value, point ** 2)
    np.testing.assert_allclose(grad, 2 * point, rtol=1e-4)

  def test_value_and_grad_no_transform(self):
    """`stateful.value_and_grad` raises `ValueError` with no transform wrapper."""
    point = jnp.array(3.)
    # NOTE: `msg` is only the text reported if no ValueError is raised; it is
    # not matched against the exception's message.
    with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
      value_and_grad_fn = stateful.value_and_grad(jnp.square)
      value_and_grad_fn(point)

  @test_utils.transform_and_run
  def test_grad_aux(self):
    """With `has_aux=True`, `stateful.grad` hands back the very same aux object."""
    sentinel = object()

    def square_with_aux(x):
      return SquareModule()(x), sentinel

    point = jnp.array(3.)
    grad_fn = stateful.grad(square_with_aux, has_aux=True)
    grad, returned_aux = grad_fn(point)
    np.testing.assert_allclose(grad, 2 * point, rtol=1e-4)
    self.assertIs(returned_aux, sentinel)

  @test_utils.transform_and_run
  def test_value_and_grad_aux(self):
    """With `has_aux=True`, value, aux object and gradient are all returned."""
    sentinel = object()

    def square_with_aux(x):
      return SquareModule()(x), sentinel

    point = jnp.array(3.)
    value_and_grad_fn = stateful.value_and_grad(square_with_aux, has_aux=True)
    (value, returned_aux), grad = value_and_grad_fn(point)
    self.assertEqual(value, jnp.power(point, 2))
    np.testing.assert_allclose(grad, 2 * point, rtol=1e-4)
    self.assertIs(returned_aux, sentinel)

  def test_grad_and_jit(self):
    """`stateful.grad` works when both init and apply are wrapped in `jax.jit`."""
    def grad_of_square(x):
      return stateful.grad(SquareModule())(x)

    point = jnp.array(3.)
    transformed = transform.transform_with_state(grad_of_square)
    params, state = jax.jit(transformed.init)(None, point)
    grad, _ = jax.jit(transformed.apply)(params, state, None, point)
    np.testing.assert_allclose(grad, 2 * point, rtol=1e-3)

  def test_value_and_grad_and_jit(self):
    """`stateful.value_and_grad` works when init and apply are jitted."""
    def value_and_grad_of_square(x):
      return stateful.value_and_grad(SquareModule())(x)

    point = jnp.array(3.)
    transformed = transform.transform_with_state(value_and_grad_of_square)
    params, state = jax.jit(transformed.init)(None, point)
    (value, grad), _ = jax.jit(transformed.apply)(params, state, None, point)
    np.testing.assert_allclose(value, point ** 2, rtol=1e-3)
    np.testing.assert_allclose(grad, 2 * point, rtol=1e-3)

  @test_utils.transform_and_run
  def test_jit(self):
    """A `SquareModule` wrapped in `stateful.jit` still returns `x ** 2`."""
    jitted_square = stateful.jit(SquareModule())
    point = jnp.array(2)
    self.assertEqual(jitted_square(point), point ** 2)

  def test_jit_no_transform(self):
    """`stateful.jit` raises `ValueError` in a test with no transform wrapper."""
    point = jnp.array(2)
    # NOTE: `msg` is only the text reported if no ValueError is raised; it is
    # not matched against the exception's message.
    with self.assertRaises(ValueError, msg="Use jax.jit() instead"):
      jitted = stateful.jit(jnp.square)
      jitted(point)

  @test_utils.transform_and_run
  def test_remat(self):
    """Compares forward/backward callback counts with and without `stateful.remat`.

    The inner `test` helper differentiates a `CountingModule` followed by a
    callback primitive, checks value, gradient and the module's `count`, and
    returns how many times each callback list was appended to.
    """
    forward, backward = [], []
    # NOTE(review): `_callback_prim` is defined elsewhere; presumably it runs
    # the first callable on forward evaluation and the second on the backward
    # pass -- confirm against its definition.
    callback = _callback_prim(lambda: forward.append(None),
                              lambda: backward.append(None))

    def test(remat):
      x = jnp.array(3.)
      mod = CountingModule()
      self.assertEqual(mod.count, 0)
      f = lambda x: callback(mod(x))
      if remat:
        f = stateful.remat(f)
      y, g = stateful.value_and_grad(f)(x)
      np.testing.assert_allclose(y, x ** 2, rtol=1e-3)
      np.testing.assert_allclose(g, 2 * x, rtol=1e-3)
      self.assertEqual(mod.count, 1)
      num_forward = len(forward)
      num_backward = len(backward)
      # Empty both logs in place so the next run starts counting from zero.
      del forward[:], backward[:]
      return num_forward, num_backward

    # Sanity check.
    self.assertEqual(test(remat=True), test(remat=True))
    self.assertEqual(test(remat=False), test(remat=False))

    # NOTE: JAX does not guarantee to execute primitives once and only once for
    # a given function (we observe f=2,b=1 without remat and f=5,b=1 with
    # remat), but we do expect that JAX will execute our primitive forward at
    # least one more time with remat than without it.
    num_forward_remat, num_backward_remat = test(remat=True)
    num_forward_no_remat, num_backward_no_remat = test(remat=False)
    self.assertGreater(num_forward_remat, num_forward_no_remat)
    self.assertEqual(num_backward_remat, num_backward_no_remat)

  def test_remat_no_transform(self):
    """`stateful.remat` raises `ValueError` in a test with no transform wrapper."""
    point = jnp.array(3.)
    # NOTE: `msg` is only the text reported if no ValueError is raised; it is
    # not matched against the exception's message.
    with self.assertRaises(ValueError, msg="Use jax.remat() instead"):
      rematted = stateful.remat(jnp.square)
      rematted(point)

  @test_utils.combined_named_parameters(
      test_utils.named_bools("jax_remat"),
      test_utils.named_bools("inline_hk_remat"))
  def test_create_module_inside_remat(self, jax_remat, inline_hk_remat):
    """Modules created inside `stateful.remat` get stable names on every trace.

    A `SquareModule(name="layer")` is constructed twice inside a rematted
    function; the recorded module names must be `["layer", "layer_1"]` both
    at init and on each gradient evaluation.

    Args:
      jax_remat: whether the gradient function is additionally wrapped in
        `jax.remat`.
      inline_hk_remat: whether `stateful.remat` is applied afresh at each call
        site (True) or once up front (False).
    """
    # Records the `module_name` of every module constructed while tracing.
    log = []
    def forward(x):
      def create_and_use_layer(x):
        m = SquareModule(name="layer")
        log.append(m.module_name)
        return m(x)

      if not inline_hk_remat:
        create_and_use_layer = stateful.remat(create_and_use_layer)

      for _ in range(2):
        if inline_hk_remat:
          x = stateful.remat(create_and_use_layer)(x)
        else:
          x = create_and_use_layer(x)
      return x

    def reset():
      # Clear the log in place (closures above hold a reference to it).
      del log[:]
      self.assertEmpty(log)

    # Test forward.
    x = jnp.float32(3)
    forward = transform.transform_with_state(forward)
    params, state = forward.init(None, x)
    self.assertEqual(log, ["layer", "layer_1"])
    reset()

    # Test backward.
    for _ in range(3):
      grad_fn = jax.grad(lambda x: forward.apply(params, state, None, x)[0])
      if jax_remat:
        grad_fn = jax.remat(grad_fn)
      # Two squarings give x ** 4, whose derivative is 4 * x ** 3.
      self.assertEqual(int(grad_fn(x)), int(4 * (x ** 3)))
      self.assertEqual(log, ["layer", "layer_1"])
      reset()

  @parameterized.parameters(True, False)
  def test_cond(self, single_arg):
    """Exercises `stateful.cond` in its four- and five-argument call forms.

    For inputs 1, 2 and 3 the output and the module's recorded state `y` are
    expected to be 4, 4 and 16 respectively.
    """
    def run_cond(x):
      square = SquareModule()
      if single_arg:
        return stateful.cond(x == 2, square, lambda v: square(v + 1), x)
      return stateful.cond(x == 2, x, square, x, lambda v: square(v + 1))

    transformed = transform.transform_with_state(run_cond)
    for raw_input, raw_expected in ((1, 4), (2, 4), (3, 16)):
      inp = jnp.array(raw_input)
      expected = jnp.array(raw_expected)
      params, state = transformed.init(None, inp)
      out, state = transformed.apply(params, state, None, inp)
      self.assertEqual(state, {"square_module": {"y": expected}})
      self.assertEqual(out, expected)

  @test_utils.transform_and_run
  def test_cond_traces_branches_with_same_id_once(self):
    """Passing the same function as both branches invokes it a single time."""
    calls = []

    def branch(x):
      calls.append(None)
      return x ** 2

    stateful.cond(False, branch, branch, 0)
    hk_call_count = len(calls)
    self.assertEqual(hk_call_count, 1)

    # Ensure we are in sync with JAX.
    calls.clear()
    jax.lax.cond(False, branch, branch, 0)
    self.assertEqual(hk_call_count, len(calls))

  def test_cond_no_transform(self):
    """`stateful.cond` raises `ValueError` in a test with no transform wrapper."""
    point = jnp.array(3.)
    # NOTE: `msg` is only the text reported if no ValueError is raised; it is
    # not matched against the exception's message.
    with self.assertRaises(ValueError, msg="Use jax.cond() instead"):
      stateful.cond(point == 2, point, jnp.square, point,
                    lambda v: jnp.square(v + 1))

  def test_switch(self):
    """`stateful.switch` runs the indexed branch and records the module state."""
    def run_switch(i, x):
      square = SquareModule()
      cases = [square, lambda v: square(v + 1), lambda v: square(v + 2)]
      return stateful.switch(i, cases, x)

    transformed = transform.transform_with_state(run_switch)
    # Each row is (branch index, input, expected output).
    for row in ((0, 1, 1), (1, 2, 9), (2, 3, 25)):
      index, inp, expected = map(jnp.array, row)
      params, state = transformed.init(None, index, inp)
      out, state = transformed.apply(params, state, None, index, inp)
      self.assertEqual(state, {"square_module": {"y": expected}})
      self.assertEqual(out, expected)

  @parameterized.parameters(1, 2, 4, 8)
  @test_utils.transform_and_run
  def test_switch_traces_cases_with_same_id_once(self, n):
    """Repeated branch functions in `stateful.switch` are each invoked once.

    The branch list is `[f, g] * n`; both `f` and `g` must be called exactly
    once, and the per-function call counts must match `jax.lax.switch`.

    Args:
      n: how many times the `[f, g]` pair is repeated in the branch list.
    """
    f_witness = []
    g_witness = []

    def f(x):
      f_witness.append(None)
      return x ** 2

    def g(x):
      g_witness.append(None)
      return x ** 2

    stateful.switch(0, [f, g] * n, 2)
    f_hk_call_count = len(f_witness)
    g_hk_call_count = len(g_witness)
    self.assertEqual(f_hk_call_count, 1)
    self.assertEqual(g_hk_call_count, 1)

    # Ensure we are in sync with JAX.
    del f_witness[:], g_witness[:]
    jax.lax.switch(0, [f, g] * n, 2)
    f_jax_call_count = len(f_witness)
    g_jax_call_count = len(g_witness)
    self.assertEqual(f_hk_call_count, f_jax_call_count)
    # Compare g against g (previously compared f's Haiku count to g's JAX
    # count, which only passed because both happen to be 1).
    self.assertEqual(g_hk_call_count, g_jax_call_count)

  def test_switch_no_transform(self):
    """`stateful.switch` raises `ValueError` in a test with no transform wrapper."""
    index = jnp.array(2)
    operand = jnp.array(42.)
    branches = [jnp.square] * 3
    # NOTE: `msg` is only the text reported if no ValueError is raised; it is
    # not matched against the exception's message.
    with self.assertRaises(ValueError, msg="Use jax.switch() instead"):
      stateful.switch(index, branches, operand)

  @test_utils.transform_and_run
  def test_difference_empty(self):
    """Two back-to-back snapshots of internal state differ by no leaves."""
    first_snapshot = stateful.internal_state()
    second_snapshot = stateful.internal_state()
    delta = stateful.difference(first_snapshot, second_snapshot)
    self.assertEmpty(jax.tree_leaves(delta))

  @parameterized.parameters(base.get_parameter, base.get_state)
  @test_utils.transform_and_run(run_apply=False)
  def test_difference_new(self, get_x):
    """A value created between two snapshots appears in their difference.

    "a" is created before the first snapshot and shows up as `None`; "b" is
    created between the snapshots and shows up with its value. Which side
    (params or state) is populated depends on `get_x`.
    """
    get_x("a", [], init=jnp.zeros)
    before = stateful.internal_state()
    new_value = get_x("b", [], init=jnp.zeros)
    after = stateful.internal_state()
    delta = stateful.difference(before, after)

    if get_x == base.get_state:
      expected_state = {"~": {"a": None,
                              "b": base.StatePair(new_value, new_value)}}
      self.assertEmpty(delta.params)
      self.assertEqual(delta.state, expected_state)
    else:
      expected_params = {"~": {"a": None, "b": new_value}}
      self.assertEqual(delta.params, expected_params)
      self.assertEmpty(delta.state)
    self.assertIsNone(delta.rng)

  @test_utils.transform_and_run(run_apply=False)
  def test_difference_update_state(self):
    """Only the state entry changed between snapshots carries a `StatePair`."""
    for state_name in ("a", "b"):
      base.get_state(state_name, [], init=jnp.zeros)
    before = stateful.internal_state()
    base.set_state("b", jnp.ones([]))
    after = stateful.internal_state()

    delta = stateful.difference(before, after)
    expected_state = {"~": {"a": None, "b": base.StatePair(0., 1.)}}
    self.assertEmpty(delta.params)
    self.assertEqual(delta.state, expected_state)
    self.assertIsNone(delta.rng)

  @test_utils.transform_and_run(run_apply=False)
  def test_difference_rng(self):
    """Drawing an RNG key between snapshots makes `difference(...).rng` non-None."""
    before = stateful.internal_state()
    base.next_rng_key()
    after = stateful.internal_state()

    delta = stateful.difference(before, after)
    for untouched in (delta.params, delta.state):
      self.assertEmpty(untouched)
    self.assertIsNotNone(delta.rng)

  def test_scan_no_transform(self):
    """`stateful.scan` raises `ValueError` in a test with no transform wrapper."""
    sequence = jnp.arange(3)
    # NOTE: `msg` is only the text reported if no ValueError is raised; it is
    # not matched against the exception's message.
    with self.assertRaises(ValueError, msg="Use jax.scan() instead"):
      stateful.scan(lambda carry, item: (carry, item), (), sequence)

  @parameterized.parameters(0, 1, 2, 4, 8)
  def test_scan_with_state(self, unroll_length):
    """State updated inside `stateful.scan` is carried out of the scan.

    A `CountingModule` is applied to each of `unroll_length` inputs; its
    `count` state must go from 0 after init to `unroll_length` after apply,
    and the outputs must equal the squared inputs.
    """
    def scan_module(xs):
      counter = CountingModule()

      def step(carry, x):
        self.assertEqual(carry, ())
        return carry, counter(x)

      _, ys = stateful.scan(step, (), xs)
      return ys

    transformed = transform.transform_with_state(scan_module)
    key = jax.random.PRNGKey(42)
    inputs = jnp.arange(unroll_length)

    params, state = transformed.init(key, inputs)
    self.assertEqual(list(state), ["counting_module"])
    module_state = state["counting_module"]
    self.assertEqual(list(module_state), ["count"])
    np.testing.assert_allclose(module_state["count"], 0, rtol=1e-4)

    outputs, state = transformed.apply(params, state, key, inputs)
    np.testing.assert_allclose(state["counting_module"]["count"],
                               unroll_length, rtol=1e-4)
    np.testing.assert_allclose(outputs, inputs ** 2, rtol=1e-4)

  @parameterized.parameters(0, 1, 2, 8)
  @test_utils.transform_and_run
  def test_stateful_scan_with_rng_use(self, iteration_count):
    """`stateful.scan` runs when the body draws more keys than were reserved.

    The body draws 10 keys per step while only 5 are reserved up front; the
    test passes as long as nothing raises.

    Args:
      iteration_count: number of scan steps (`length` passed to the scan).
    """
    # TODO(lenamartens): remove when default changes to > 1.
    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    # Restore the module-level default even if the body below raises, so a
    # failure here cannot leak the override into other tests.
    try:
      def body_fun(c, x):
        for _ in range(10):
          _ = base.next_rng_key()
        return c, x
      base.reserve_rng_keys(5)
      _ = stateful.scan(body_fun, (), (), length=iteration_count)
    finally:
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default

  @parameterized.parameters(0, 1, 2, 8)
  @test_utils.transform_and_run
  def test_stateful_fori_with_rng_use(self, iteration_count):
    """`stateful.fori_loop` runs when the body draws more keys than reserved.

    The body draws 10 keys per iteration while only 5 are reserved up front;
    the test passes as long as nothing raises.

    Args:
      iteration_count: upper bound of the loop (lower bound is 0).
    """
    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    # Restore the module-level default even if the body below raises, so a
    # failure here cannot leak the override into other tests.
    try:
      def body_fun(_, x):
        for _ in range(10):
          _ = base.next_rng_key()
        return x
      base.reserve_rng_keys(5)
      _ = stateful.fori_loop(0, iteration_count, body_fun, 1)
    finally:
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default

  @test_utils.transform_and_run
  def test_stateful_cond_with_rng_use(self):
    """`stateful.cond` accepts branches that draw different numbers of keys.

    The true branch draws one key and the false branch two; the test passes
    as long as neither `cond` call raises.
    """
    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    # Restore the module-level default even if the body below raises, so a
    # failure here cannot leak the override into other tests.
    try:
      # Check that using a different number of keys in different branches
      # does not result in an error.
      def true_branch(x):
        _ = base.next_rng_key()
        return x

      def false_branch(x):
        _ = base.next_rng_key()
        _ = base.next_rng_key()
        return x

      base.reserve_rng_keys(5)
      _ = stateful.cond(True, true_branch, false_branch, 0)
      _ = stateful.cond(False, true_branch, false_branch, 0)
    finally:
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default

  @test_utils.transform_and_run
  def test_stateful_switch_with_rng_use(self):
    """`stateful.switch` accepts branches that draw different numbers of keys.

    Branch `i` draws `i` keys and returns `i`; selecting branches 3 and 0 must
    return 3 and 0 without raising.
    """
    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    # Restore the module-level default even if the body below raises, so a
    # failure here cannot leak the override into other tests.
    try:
      # Check that using a different number of keys in different branches
      # does not result in an error.
      def branch_f(i):
        for _ in range(i):
          _ = base.next_rng_key()
        return i

      base.reserve_rng_keys(5)
      # `i=i` binds the loop value at definition time (avoids late binding).
      branches = [lambda _, i=i: branch_f(i) for i in range(5)]
      self.assertEqual(stateful.switch(3, branches, None), 3)
      self.assertEqual(stateful.switch(0, branches, None), 0)
    finally:
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default

  @parameterized.parameters(*it.product((0, 1, 2, 4, 8), (1, 2, 3)))
  @test_utils.transform_and_run
  def test_fori(self, lower, n):
    """`stateful.fori_loop` applies the module once per index in `[lower, upper)`."""
    upper = lower + n
    counter = CountingModule()

    def body(i, unused_carry):
      return counter(i)

    result = stateful.fori_loop(lower, upper, body, 2)
    self.assertEqual(result, jnp.square(upper - 1))
    self.assertEqual(counter.count, upper - lower)

  @test_utils.transform_and_run
  def test_fori_traced_length(self):
    """`stateful.fori_loop` works when its bounds are tracers under `stateful.jit`."""
    counter = CountingModule()

    def loop(lower, upper):
      return stateful.fori_loop(lower, upper, lambda i, x: counter(i), 2)

    # Because of the jit, lower and upper will be tracers.
    jitted_loop = stateful.jit(loop)
    self.assertEqual(jitted_loop(0, 3), 4)
    self.assertEqual(counter.count, 3)

  def test_vmap(self):
    """`stateful.vmap` maps the output but leaves module state unmapped."""
    def per_example(x):
      return CountingModule()(x)

    def batched(x):
      return stateful.vmap(per_example)(x)

    transformed = transform.transform_with_state(batched)

    inputs = jnp.ones([4]) + 1
    params, state = transformed.init(None, inputs)

    # State should not be mapped.
    self.assertEmpty(params)
    [count] = jax.tree_leaves(state)
    self.assertEqual(count.ndim, 0)
    self.assertEqual(count, 0)

    # The output should be mapped but state should not be.
    outputs, state = transformed.apply(params, state, None, inputs)
    self.assertEqual(outputs.shape, (4,))
    np.testing.assert_allclose(outputs, inputs ** 2)
    [count] = jax.tree_leaves(state)
    self.assertEqual(count.ndim, 0)
    self.assertEqual(count, 1)

  def test_while_loop_rejected_in_init(self):
    """Calling `stateful.while_loop` during `init` raises `ValueError`."""
    def loop():
      stateful.while_loop(lambda x: x.all(), lambda x: not x, 1)

    transformed = transform.transform(loop)
    expected_message = "hk.while_loop does not support initialization"
    with self.assertRaisesRegex(ValueError, expected_message):
      transformed.init(None)

  def test_updating_state_in_cond_fails(self):
    """Using a state-updating module as `cond_fun` of `while_loop` raises."""
    def forward(x):
      counter = CountingModule(op=lambda v: v + 1)
      if not base.params_frozen():
        # Plain call on the init path; the while loop only runs under apply.
        return counter(x)
      stateful.while_loop(counter, lambda v: v, x)

    transformed = transform.transform_with_state(forward)
    inp = jnp.zeros([])
    params, state = transformed.init(None, inp)
    expected_message = (
        "does not support.*set_state.*next_rng_key.*in.*cond_fun`")
    with self.assertRaisesRegex(ValueError, expected_message):
      transformed.apply(params, state, None, inp)

  def test_rng_in_cond_fails(self):
    """Drawing an RNG key inside `cond_fun` of `while_loop` raises."""
    def forward(x):
      counter = CountingModule(op=lambda v: v + 1)
      if not base.params_frozen():
        # Plain call on the init path; the while loop only runs under apply.
        return counter(x)
      stateful.while_loop(lambda _: base.next_rng_key(), lambda v: v, x)

    transformed = transform.transform_with_state(forward)
    inp = jnp.zeros([])
    params, state = transformed.init(None, inp)
    expected_message = (
        "does not support.*set_state.*next_rng_key.*in.*cond_fun`")
    with self.assertRaisesRegex(ValueError, expected_message):
      transformed.apply(params, state, jax.random.PRNGKey(42), inp)

  @parameterized.parameters(0, 1, 2, 4, 8)
  def test_while_loop_with_state(self, iters):
    """State updated in the body of `stateful.while_loop` is carried out.

    Under apply, a `CountingModule` with a `+ 1` op runs `iters` times; both
    its `count` state and the loop result must equal `iters`.
    """
    def forward(x):
      counter = CountingModule(op=lambda v: v + 1)
      if not base.params_frozen():
        # Plain call on the init path; the while loop only runs under apply.
        return counter(x)

      def keep_going(loop_vars):
        return loop_vars[0] < iters

      def step(loop_vars):
        return loop_vars[0] + 1, counter(loop_vars[1])

      _, result = stateful.while_loop(keep_going, step, (0, x))
      return result

    transformed = transform.transform_with_state(forward)
    inp = jnp.zeros([])
    params, state = transformed.init(None, inp)
    self.assertEqual(list(state), ["counting_module"])
    module_state = state["counting_module"]
    self.assertEqual(list(module_state), ["count"])
    np.testing.assert_allclose(module_state["count"], inp, rtol=1e-4)

    out, state = transformed.apply(params, state, None, inp)
    np.testing.assert_allclose(state["counting_module"]["count"], iters,
                               rtol=1e-4)
    np.testing.assert_allclose(out, iters, rtol=1e-4)

  def test_named_call(self):
    """A module wrapped in `stateful.named_call` still computes `x ** 2` under jit."""
    def forward(x):
      named_square = stateful.named_call(SquareModule(), name="square")
      return named_square(x)

    point = jnp.array(2.)
    key = jax.random.PRNGKey(42)
    init_fn, apply_fn = transform.transform_with_state(forward)
    params, state = init_fn(key, point)
    out, _ = jax.jit(apply_fn)(params, state, key, point)
    self.assertEqual(out, point ** 2)

  @parameterized.parameters(jax.jit, jax.grad, jax.vmap, jax.remat)
  def test_named_call_jax_transforms(self, jax_transform):
    """A JAX transform gives the same result with and without `named_call`."""
    fn = jnp.sum
    inp = jnp.array([1.])

    plain_result = jax_transform(fn)(inp)
    named_fn = stateful.named_call(fn, name="test")
    named_result = jax_transform(named_fn)(inp)

    self.assertEqual(plain_result, named_result)

  def test_static_argnums_named_call(self):
    """A named-call function can be jitted with a static first argument."""
    named_fn = stateful.named_call(lambda x, y: y if x else None,
                                   name="test")
    jitted_fn = jax.jit(named_fn, static_argnums=(0,))
    self.assertEqual(jitted_fn(True, 5), 5)

  def test_named_call_non_jaxtype_arg(self):
    """A named-call function accepts a static, non-JAX-type argument under jit."""
    # For the test to fail without the invalid JaxType filter we need to pass
    # in a valid JaxType that forces the invalid Jaxtype to be raised to an
    # abstract value.
    def select(not_a_jaxtype, a_jaxtype):
      # then Jax needs to try and evaluate the abstractified non-JaxType
      return a_jaxtype if not_a_jaxtype else 0

    named_select = stateful.named_call(select, name="test")
    jitted_select = jax.jit(named_select, static_argnums=(0,))
    self.assertEqual(jitted_select("not a Jaxtype", 1), 1)

  @parameterized.parameters("hi", None, object(), object)
  def test_named_call_non_jaxtype_result(self, non_jaxtype):
    """A named-call function may return a non-JAX-type value inside a jit."""
    def fun_with_non_jaxtype_output(x, non_jaxtype):
      return x, non_jaxtype

    def jitted_fun(x, non_jaxtype):
      named_fun = stateful.named_call(fun_with_non_jaxtype_output)
      # The non-jaxtype is returned out of named_call (which is supported),
      # but is not returned out of the jit (which should not be supported).
      result, passed_through = named_fun(x, non_jaxtype)
      self.assertEqual(passed_through, non_jaxtype)
      return result

    compiled = jax.jit(jitted_fun, static_argnums=1)
    self.assertEqual(compiled(0, non_jaxtype), 0)

  def test_named_call_partial_function(self):
    """A `functools.partial` over a named-call function can be jitted."""
    named_fn = stateful.named_call(lambda x, y: y if x else None)
    bound_fn = functools.partial(named_fn, True)
    self.assertEqual(jax.jit(bound_fn)(5), 5)

  def test_named_call_default_name(self):
    """With no explicit name, the wrapped function's name appears in the HLO text."""
    # The function name below is asserted against the HLO text; keep both in
    # sync if it is ever renamed.
    @stateful.named_call
    def naming_things_is_hard(x):
      return x ** 2

    @jax.jit
    def call_twice(x):
      return naming_things_is_hard(x) + naming_things_is_hard(x)

    computation = jax.xla_computation(call_twice)(2)
    hlo_text = computation.as_hlo_text()
    self.assertIn("naming_things_is_hard", hlo_text)

  def test_eval_shape(self):
    """`stateful.eval_shape` reports the output shape without advancing state.

    The module is passed to `eval_shape` and then really called once; its
    `count` state must end at 1 and the reported shape must match the op's
    output shape.
    """
    def first_row(x):
      return x[0, :]

    def forward(x):
      counter = CountingModule(op=first_row)
      # state is not changed in this call
      shape_only = stateful.eval_shape(counter, x)
      return counter(x), shape_only

    transformed = transform.transform_with_state(forward)
    key = jax.random.PRNGKey(42)
    in_shape = (10, 10)
    inp = jnp.ones(in_shape)

    params, state = transformed.init(key, inp)
    self.assertEqual(list(state), ["counting_module"])
    module_state = state["counting_module"]
    self.assertEqual(list(module_state), ["count"])
    np.testing.assert_allclose(module_state["count"], 0, rtol=1e-4)

    (out, shape_struct), state = transformed.apply(params, state, key, inp)
    # Count is only advanced once
    np.testing.assert_allclose(state["counting_module"]["count"], 1,
                               rtol=1e-4)
    np.testing.assert_allclose(out, first_row(inp), rtol=1e-4)
    self.assertEqual(shape_struct.shape, (in_shape[1],))

  def test_eval_shape_no_transform(self):
    """`stateful.eval_shape` raises `ValueError` in a test with no transform wrapper."""
    x = jnp.array(3.)
    # NOTE: `msg` is only the text reported if no ValueError is raised; it is
    # not matched against the exception's message.
    with self.assertRaises(ValueError, msg="Use jax.eval_shape() instead"):
      # `eval_shape` takes the function and its arguments in a single call
      # (it is invoked as `stateful.eval_shape(m, x)` elsewhere in this
      # class); it does not return a callable to apply to `x`.
      stateful.eval_shape(jnp.square, x)

  @test_utils.transform_and_run
  def test_temporary_state_resets_names(self):
    """A module named inside `temporary_internal_state` does not claim that name outside."""
    snapshot = stateful.internal_state()
    with stateful.temporary_internal_state(snapshot):
      inner_module = module.Module(name="foo")
    outer_module = module.Module(name="foo")
    for mod in (inner_module, outer_module):
      self.assertEqual(mod.module_name, "foo")

  @test_utils.transform_and_run(run_apply=False)
  def test_eval_shape_no_leaked_tracers_under_leak_checker(self):
    """`stateful.eval_shape` completes inside `jax.checking_leaks()`."""
    scalar_input = jnp.ones(())
    with jax.checking_leaks():
      stateful.eval_shape(SquareModule(), scalar_input)  # does not crash
Example #11
0
class HaikuTransformsTest(parameterized.TestCase):
    @test_utils.combined_named_parameters(descriptors.ALL_MODULES,
                                          test_utils.named_bools('init'))
    def test_hk_jit(self, module_fn: ModuleFn, shape, dtype, init):
        """`hk.jit` inside the function matches `jax.jit` around init/apply.

        Args:
          module_fn: zero-argument factory for the module under test.
          shape: shape of the random input.
          dtype: dtype of the random input.
          init: compare `init` results when True, `apply` results otherwise.
        """
        rng = jax.random.PRNGKey(42)
        x = (jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
             if jnp.issubdtype(dtype, jnp.integer)
             else jax.random.uniform(rng, shape, dtype))

        def g(x, jit=False):
            mod = module_fn()
            return (hk.jit(mod) if jit else mod)(x)

        f = hk.transform_with_state(g)
        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=1e-4)

        # NOTE: We shard init/apply tests since some modules are expensive to jit
        # (e.g. ResNet50 takes ~60s to compile and we compile it twice per test).
        if init:
            outer_jit = jax.jit(f.init)(rng, x)
            inner_jit = f.init(rng, x, jit=True)
        else:
            params, state = f.init(rng, x)
            outer_jit = jax.jit(f.apply)(params, state, rng, x)
            inner_jit = f.apply(params, state, rng, x, jit=True)
        jax.tree_multimap(assert_allclose, outer_jit, inner_jit)

    @test_utils.combined_named_parameters(descriptors.ALL_MODULES,
                                          test_utils.named_bools('init'))
    def test_hk_scan(self, module_fn: descriptors.ModuleFn, shape, dtype,
                     init):
        """`hk.scan` over a module matches `jax.lax.scan` over its `apply`.

        Args:
          module_fn: zero-argument factory for the module under test.
          shape: shape of a single (unstacked) random input.
          dtype: dtype of the random input.
          init: compare `init` results when True, `apply` results otherwise.
        """
        rng = jax.random.PRNGKey(42)
        x = (jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
             if jnp.issubdtype(dtype, jnp.integer)
             else jax.random.uniform(rng, shape, dtype))

        def single_step(x):
            return module_fn()(x)

        def scanned(xs):
            mod = module_fn()

            def body(carry, x):
                return carry, mod(x)

            _, ys = hk.scan(body, (), xs)
            return ys

        scanned = hk.transform_with_state(scanned)
        single_step = hk.transform_with_state(single_step)

        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=1e-4)
        # Stack eight copies of the input along a new leading axis.
        xs = jnp.broadcast_to(x, (8, ) + x.shape)
        params, state = single_step.init(rng, x)

        if init:
            scan_params, scan_state = scanned.init(rng, xs)
            jax.tree_multimap(assert_allclose, scan_params, params)
            jax.tree_multimap(assert_allclose, scan_state, state)
            return

        def lax_body(state, x):
            y, state = single_step.apply(params, state, rng, x)
            return state, y

        lax_state, lax_ys = jax.lax.scan(lax_body, state, xs)
        scan_ys, scan_state = scanned.apply(params, state, rng, xs)

        jax.tree_multimap(assert_allclose, scan_ys, lax_ys)
        jax.tree_multimap(assert_allclose, scan_state, lax_state)

    @test_utils.combined_named_parameters(
        # TODO(tomhennigan) Enable once grad for _scan_transpose implemented.
        set(descriptors.ALL_MODULES) - set(descriptors.RECURRENT_MODULES))
    def test_hk_remat(self, module_fn: ModuleFn, shape, dtype):
        """Gradients through `hk.remat` match gradients through `jax.remat`.

        Args:
          module_fn: zero-argument factory for the module under test.
          shape: shape of the random input.
          dtype: dtype of the random input.
        """
        rng = jax.random.PRNGKey(42)
        x = (jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
             if jnp.issubdtype(dtype, jnp.integer)
             else jax.random.uniform(rng, shape, dtype))

        def g(x, remat=False):
            mod = module_fn()
            if remat:
                mod = hk.remat(mod)
            out = mod(x)
            # Dict outputs are reduced via their 'loss' entry.
            scalar_source = out['loss'] if isinstance(out, dict) else out
            return jnp.mean(scalar_source)

        f = hk.transform_with_state(g)

        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=1e-5)

        # `apply` returns (output, state); the state rides along as aux.
        grad_jax_remat = jax.grad(jax.remat(f.apply), has_aux=True)
        hk_remat_apply = functools.partial(f.apply, remat=True)
        grad_hk_remat = jax.grad(hk_remat_apply, has_aux=True)

        params, state = f.init(rng, x)
        expected = grad_jax_remat(params, state, rng, x)
        actual = grad_hk_remat(params, state, rng, x)
        jax.tree_multimap(assert_allclose, expected, actual)

    @test_utils.combined_named_parameters(descriptors.ALL_MODULES)
    def test_profiler_name_scopes(self, module_fn: ModuleFn, shape, dtype):
        """Enabling profiler name scopes does not change module outputs.

        Args:
          module_fn: zero-argument factory for the module under test.
          shape: shape of the random input.
          dtype: dtype of the random input.
        """
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            x = jax.random.uniform(rng, shape, dtype)

        def g(x, name_scopes=False):
            hk.experimental.profiler_name_scopes(enabled=name_scopes)
            mod = module_fn()
            return mod(x)

        f = hk.transform_with_state(g)

        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=1e-5)

        # `g` flips a process-wide setting; reset it even when the comparison
        # fails so the enabled flag cannot leak into other tests.
        try:
            params, state = f.init(rng, x)
            jax.tree_multimap(assert_allclose,
                              f.apply(params, state, rng, x),
                              f.apply(params, state, rng, x,
                                      name_scopes=True))
        finally:
            # TODO(lenamartens): flip to True when default changes
            hk.experimental.profiler_name_scopes(enabled=False)

    @test_utils.combined_named_parameters(descriptors.ALL_MODULES)
    def test_optimize_rng_use_under_jit(self, module_fn: ModuleFn, shape,
                                        dtype):
        """`optimize_rng_use` gives matching results with and without `jax.jit`.

        Args:
          module_fn: zero-argument factory for the module under test.
          shape: shape of the random input.
          dtype: dtype of the random input.
        """
        rng = jax.random.PRNGKey(42)
        x = (jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
             if jnp.issubdtype(dtype, jnp.integer)
             else jax.random.uniform(rng, shape, dtype))

        def g(x):
            return module_fn()(x)

        f = hk.transform_with_state(hk.experimental.optimize_rng_use(g))

        module_type = descriptors.module_type(module_fn)
        assert_allclose = functools.partial(
            np.testing.assert_allclose,
            atol=CUSTOM_ATOL.get(module_type, DEFAULT_ATOL))

        params, state = jax.jit(f.init)(rng, x)
        jax.tree_multimap(assert_allclose, (params, state), f.init(rng, x))

        stochastic_types = (hk.nets.VectorQuantizer,
                            hk.nets.VectorQuantizerEMA)
        jitted_out = jax.jit(f.apply)(params, state, rng, x)
        if module_type in stochastic_types:
            # For stochastic modules just test apply runs.
            jax.device_get(jitted_out)
        else:
            jax.tree_multimap(assert_allclose, jitted_out,
                              f.apply(params, state, rng, x))

    @test_utils.combined_named_parameters(descriptors.OPTIONAL_BATCH_MODULES)
    def test_vmap(self, module_fn: ModuleFn, shape, dtype):
        """`hk.vmap` matches `jax.vmap` with params, state and rng unmapped.

        Args:
          module_fn: zero-argument factory for the module under test.
          shape: shape of a single (unbatched) random input.
          dtype: dtype of the random input.
        """
        rng = jax.random.PRNGKey(42)
        x = (jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
             if jnp.issubdtype(dtype, jnp.integer)
             else jax.random.uniform(rng, shape, dtype))

        # Expand our input since we will map over it.
        x = jnp.broadcast_to(x, (2, ) + x.shape)

        def unmapped(x):
            return module_fn()(x)

        def mapped(x):
            return hk.vmap(lambda x: module_fn()(x))(x)  # pylint: disable=unnecessary-lambda

        f = hk.transform_with_state(unmapped)
        f_mapped = hk.transform_with_state(mapped)

        params, state = f_mapped.init(rng, x)

        # JAX vmap with explicitly unmapped params/state/rng. This should be
        # equivalent to `f_mapped.apply(..)` (since by default hk.vmap does not map
        # params/state/rng).
        v_apply = jax.vmap(f.apply,
                           in_axes=(None, None, None, 0),
                           out_axes=(0, None))

        module_type = descriptors.module_type(module_fn)
        assert_allclose = functools.partial(
            np.testing.assert_allclose,
            atol=CUSTOM_ATOL.get(module_type, DEFAULT_ATOL))
        hk_result = f_mapped.apply(params, state, rng, x)
        jax_result = v_apply(params, state, rng, x)
        jax.tree_multimap(assert_allclose, hk_result, jax_result)
Example #12
0
class BatchNormTest(parameterized.TestCase):

  @test_utils.transform_and_run
  def test_basic(self):
    """Training-mode output is constant along the last axis and matches test mode."""
    data = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape([2, 3, 4])

    norm = batch_norm.BatchNorm(True, True, 0.9)
    train_out = norm(data, is_training=True)
    first_channel = train_out[:, :, :1]
    # Input data is symmetrical variance per-channel.
    np.testing.assert_allclose(
        train_out, jnp.broadcast_to(first_channel, train_out.shape))
    # Running through again in test mode produces same output.
    test_out = norm(data, is_training=False)
    np.testing.assert_allclose(test_out, train_out, rtol=2e-2)

  @test_utils.transform_and_run
  def test_simple_training(self):
    """Constant input with a supplied scale/offset normalizes to the offset value."""
    layer = batch_norm.BatchNorm(
        create_scale=False, create_offset=False, decay_rate=0.9)

    inputs = np.ones([2, 3, 3, 5])
    per_channel_scale = np.full((5,), 0.5)
    per_channel_offset = np.full((5,), 2.0)

    result = layer(
        inputs, True, scale=per_channel_scale, offset=per_channel_offset)
    expected = np.full(inputs.shape, 2.0)
    np.testing.assert_equal(result, expected)

  @test_utils.transform_and_run
  def test_simple_training_nchw(self):
    """Same as the simple training check, with `data_format="NCHW"`."""
    layer = batch_norm.BatchNorm(
        create_scale=False,
        create_offset=False,
        decay_rate=0.9,
        data_format="NCHW")

    inputs = np.ones([2, 5, 3, 3])
    # Scale/offset are shaped (5, 1, 1) to line up with the channel axis.
    per_channel_scale = np.full((5, 1, 1), 0.5)
    per_channel_offset = np.full((5, 1, 1), 2.0)

    result = layer(
        inputs, True, scale=per_channel_scale, offset=per_channel_offset)
    expected = np.full(inputs.shape, 2.0)
    np.testing.assert_equal(result, expected)

  @test_utils.transform_and_run
  def test_simple_training_normalized_axes(self):
    """Normalizing over axes 0, 2 and 3 zeroes input that varies only on axis 1."""
    layer = batch_norm.BatchNorm(
        create_scale=False,
        create_offset=False,
        decay_rate=0.9,
        axis=[0, 2, 3])  # Not the second axis.

    # This differs only in the second axis.
    twos = 2.0 * np.ones([5, 3, 3])
    ones = np.ones([5, 3, 3])
    inputs = np.stack([twos, ones], 1)

    result = layer(inputs, True)

    # Despite not all values being identical, treating slices from the first
    # axis separately leads to a fully normalized = equal array.
    expected = np.zeros(inputs.shape)
    np.testing.assert_equal(result, expected)

  @test_utils.transform_and_run
  def test_no_scale_and_offset(self):
    """Constant input with no scale or offset normalizes to all zeros."""
    inputs = jnp.ones([2, 5, 3, 3, 3])
    layer = batch_norm.BatchNorm(
        create_scale=False, create_offset=False, decay_rate=0.9)

    normalized = layer(inputs, True)
    np.testing.assert_equal(normalized, np.zeros_like(inputs))

  @test_utils.transform_and_run
  def test_no_scale_and_init_provided(self):
    """Passing `scale_init` together with `create_scale=False` raises."""
    expected_message = "Cannot set `scale_init` if `create_scale=False`"
    constructor_kwargs = dict(
        create_scale=False,
        create_offset=True,
        decay_rate=0.9,
        scale_init=jnp.ones)
    with self.assertRaisesRegex(ValueError, expected_message):
      batch_norm.BatchNorm(**constructor_kwargs)

  @test_utils.transform_and_run
  def test_no_offset_beta_init_provided(self):
    """Passing `offset_init` together with `create_offset=False` raises."""
    expected_message = "Cannot set `offset_init` if `create_offset=False`"
    constructor_kwargs = dict(
        create_scale=True,
        create_offset=False,
        decay_rate=0.9,
        offset_init=jnp.zeros)
    with self.assertRaisesRegex(ValueError, expected_message):
      batch_norm.BatchNorm(**constructor_kwargs)

  @test_utils.combined_named_parameters(
      test_utils.named_bools("is_training"),
      test_utils.named_bools("test_local_stats"))
  def test_inits_ema_not_is_training(self, is_training, test_local_stats):
    """Checks which moving-average state `init` creates for each mode.

    State is expected to be empty only when `is_training` is False and
    `test_local_stats` is True; otherwise both EMA entries must exist with
    "counter", "average" and "hidden" fields.
    """
    def forward(x):
      net = batch_norm.BatchNorm(True, True, 0.9)
      return net(x, is_training, test_local_stats)

    transformed = transform.transform_with_state(forward)
    _, state = transformed.init(None, jnp.ones([]))

    if test_local_stats and not is_training:
      self.assertEmpty(state)
      return

    self.assertLen(state, 2)
    expected_fields = {"counter", "average", "hidden"}
    for ema_name in ("mean_ema", "var_ema"):
      self.assertEqual(set(state[f"batch_norm/~/{ema_name}"]),
                       expected_fields)
Example #13
0
class StatefulTest(parameterized.TestCase):

  @test_utils.transform_and_run
  def test_grad(self):
    """`stateful.grad` of `SquareModule` evaluated at x is close to 2 * x."""
    point = jnp.array(3.)
    gradient = stateful.grad(SquareModule())(point)
    np.testing.assert_allclose(gradient, 2 * point, rtol=1e-4)

  def test_grad_no_transform(self):
    """`stateful.grad` raises ValueError when no transform is active."""
    value = jnp.array(3.)
    # NOTE: `msg` only customises the failure message shown when nothing is
    # raised; the text of the raised ValueError is not matched against it.
    with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
      stateful.grad(jnp.square)(value)

  @test_utils.transform_and_run
  def test_value_and_grad(self):
    """`stateful.value_and_grad` returns both x ** 2 and roughly 2 * x."""
    point = jnp.array(2.)
    value, gradient = stateful.value_and_grad(SquareModule())(point)
    self.assertEqual(value, point ** 2)
    np.testing.assert_allclose(gradient, 2 * point, rtol=1e-4)

  def test_value_and_grad_no_transform(self):
    """`stateful.value_and_grad` raises ValueError with no active transform."""
    value = jnp.array(3.)
    # NOTE: `msg` only customises the failure message shown when nothing is
    # raised; the text of the raised ValueError is not matched against it.
    with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
      stateful.value_and_grad(jnp.square)(value)

  @test_utils.transform_and_run
  def test_grad_aux(self):
    """With `has_aux=True` the auxiliary object is returned unchanged."""
    sentinel = object()

    def square_with_aux(x):
      return SquareModule()(x), sentinel

    point = jnp.array(3.)
    gradient, returned_aux = stateful.grad(square_with_aux, has_aux=True)(point)
    np.testing.assert_allclose(gradient, 2 * point, rtol=1e-4)
    # Identity, not just equality: the very same object must come back.
    self.assertIs(returned_aux, sentinel)

  @test_utils.transform_and_run
  def test_value_and_grad_aux(self):
    """`value_and_grad(has_aux=True)` yields ((value, aux), grad)."""
    sentinel = object()

    def square_with_aux(x):
      return SquareModule()(x), sentinel

    point = jnp.array(3.)
    value_and_grad_fn = stateful.value_and_grad(square_with_aux, has_aux=True)
    (value, returned_aux), gradient = value_and_grad_fn(point)
    self.assertEqual(value, jnp.power(point, 2))
    np.testing.assert_allclose(gradient, 2 * point, rtol=1e-4)
    # Identity, not just equality: the very same object must come back.
    self.assertIs(returned_aux, sentinel)

  def test_grad_and_jit(self):
    """`stateful.grad` works when both init and apply are wrapped in jit."""
    def grad_of_square(x):
      return stateful.grad(SquareModule())(x)

    point = jnp.array(3.)
    transformed = transform.transform_with_state(grad_of_square)
    params, state = jax.jit(transformed.init)(None, point)
    gradient, state = jax.jit(transformed.apply)(params, state, None, point)
    np.testing.assert_allclose(gradient, 2 * point, rtol=1e-3)

  def test_value_and_grad_and_jit(self):
    """`stateful.value_and_grad` works when init and apply are jitted."""
    def value_and_grad_of_square(x):
      return stateful.value_and_grad(SquareModule())(x)

    point = jnp.array(3.)
    transformed = transform.transform_with_state(value_and_grad_of_square)
    params, state = jax.jit(transformed.init)(None, point)
    outputs, state = jax.jit(transformed.apply)(params, state, None, point)
    value, gradient = outputs
    np.testing.assert_allclose(value, point ** 2, rtol=1e-3)
    np.testing.assert_allclose(gradient, 2 * point, rtol=1e-3)

  @test_utils.transform_and_run
  def test_jit(self):
    """A module wrapped in `stateful.jit` still returns x ** 2."""
    square = SquareModule()
    two = jnp.array(2)
    squared = stateful.jit(square)(two)
    self.assertEqual(squared, two ** 2)

  def test_jit_no_transform(self):
    """`stateful.jit` raises ValueError when no transform is active."""
    value = jnp.array(2)
    # NOTE: `msg` only customises the failure message shown when nothing is
    # raised; the text of the raised ValueError is not matched against it.
    with self.assertRaises(ValueError, msg="Use jax.jit() instead"):
      stateful.jit(jnp.square)(value)

  @test_utils.transform_and_run
  def test_remat(self):
    """`stateful.remat` adds forward executions but keeps values and grads."""
    forward_calls, backward_calls = [], []
    callback = _callback_prim(lambda: forward_calls.append(None),
                              lambda: backward_calls.append(None))

    def run_once(remat):
      # Runs one value_and_grad pass and reports how many times each of the
      # two callbacks fired, as a (forward, backward) pair.
      x = jnp.array(3.)
      mod = CountingModule()
      self.assertEqual(mod.count, 0)
      f = lambda x: callback(mod(x))
      if remat:
        f = stateful.remat(f)
      y, g = stateful.value_and_grad(f)(x)
      np.testing.assert_allclose(y, x ** 2, rtol=1e-3)
      np.testing.assert_allclose(g, 2 * x, rtol=1e-3)
      self.assertEqual(mod.count, 1)
      counts = (len(forward_calls), len(backward_calls))
      # Empty the shared logs so the next run starts counting from zero.
      forward_calls.clear()
      backward_calls.clear()
      return counts

    # Sanity check.
    self.assertEqual(run_once(remat=True), run_once(remat=True))
    self.assertEqual(run_once(remat=False), run_once(remat=False))

    # NOTE: JAX does not guarantee to execute primitives once and only once for
    # a given function (we observe f=2,b=1 without remat and f=5,b=1 with
    # remat), but we do expect that JAX will execute our primitive forward at
    # least one more time with remat than without it.
    remat_forward, remat_backward = run_once(remat=True)
    plain_forward, plain_backward = run_once(remat=False)
    self.assertGreater(remat_forward, plain_forward)
    self.assertEqual(remat_backward, plain_backward)

  def test_remat_no_transform(self):
    """`stateful.remat` raises ValueError when no transform is active."""
    value = jnp.array(3.)
    # NOTE: `msg` only customises the failure message shown when nothing is
    # raised; the text of the raised ValueError is not matched against it.
    with self.assertRaises(ValueError, msg="Use jax.remat() instead"):
      stateful.remat(jnp.square)(value)

  @test_utils.combined_named_parameters(
      test_utils.named_bools("jax_remat"),
      test_utils.named_bools("inline_hk_remat"))
  def test_create_module_inside_remat(self, jax_remat, inline_hk_remat):
    """Modules built inside `stateful.remat` get the same names every trace."""
    created_names = []

    def forward(x):
      def create_and_use_layer(x):
        layer = SquareModule(name="layer")
        created_names.append(layer.module_name)
        return layer(x)

      if inline_hk_remat:
        # Wrap with remat afresh at every call site.
        apply_layer = lambda x: stateful.remat(create_and_use_layer)(x)
      else:
        # Wrap once up front and reuse the wrapped function.
        apply_layer = stateful.remat(create_and_use_layer)

      for _ in range(2):
        x = apply_layer(x)
      return x

    def reset():
      created_names.clear()
      self.assertEmpty(created_names)

    # Test forward.
    x = jnp.float32(3)
    transformed = transform.transform_with_state(forward)
    params, state = transformed.init(None, x)
    self.assertEqual(created_names, ["layer", "layer_1"])
    reset()

    # Test backward.
    for _ in range(3):
      grad_fn = jax.grad(
          lambda x: transformed.apply(params, state, None, x)[0])
      if jax_remat:
        grad_fn = jax.remat(grad_fn)
      # Two squarings give x ** 4, whose derivative is 4 * x ** 3.
      self.assertEqual(int(grad_fn(x)), int(4 * (x ** 3)))
      self.assertEqual(created_names, ["layer", "layer_1"])
      reset()

  @parameterized.parameters(True, False)
  def test_cond(self, single_arg):
    """`stateful.cond` picks the right branch in both calling conventions."""
    def f(x):
      mod = SquareModule()
      shifted = lambda x: mod(x + 1)
      if single_arg:
        # One shared operand passed after both branches.
        return stateful.cond(x == 2, mod, shifted, x)
      # Each branch is preceded by its own operand.
      return stateful.cond(x == 2, x, mod, x, shifted)

    transformed = transform.transform_with_state(f)
    for raw_input, raw_expected in ((1, 4), (2, 4), (3, 16)):
      x = jnp.array(raw_input)
      expected = jnp.array(raw_expected)
      params, state = transformed.init(None, x)
      out, state = transformed.apply(params, state, None, x)
      self.assertEqual(state, {"square_module": {"y": expected}})
      self.assertEqual(out, expected)

  @test_utils.transform_and_run
  def test_cond_traces_branches_with_same_id_once(self):
    """Passing one function as both branches calls it once, as in JAX."""
    calls = []

    def square(x):
      calls.append(None)
      return x ** 2

    stateful.cond(False, square, square, 0)
    hk_call_count = len(calls)
    self.assertEqual(hk_call_count, 1)

    # Ensure we are in sync with JAX.
    calls.clear()
    jax.lax.cond(False, square, square, 0)
    self.assertEqual(hk_call_count, len(calls))

  @test_utils.transform_and_run
  def test_cond_no_args(self):
    """`stateful.cond` accepts branches that take no operands."""
    chosen = stateful.cond(True, lambda: 5, lambda: 4)
    self.assertEqual(chosen, 5)

  @test_utils.transform_and_run
  def test_cond_operand_kwarg(self):
    """The operand may be supplied through the `operand` keyword."""
    result = stateful.cond(True, lambda x: x + 5, lambda x: x + 4, operand=1)
    self.assertEqual(result, 6)

  @test_utils.transform_and_run
  def test_cond_operand_kwarg_and_operands(self):
    """Giving a positional operand and `operand=` together raises."""
    expected_error = "cannot.*pass.*positionally"
    with self.assertRaisesRegex(ValueError, expected_error):
      stateful.cond(True, lambda x: x + 5, lambda x: x + 4, 1, operand=1)

  @test_utils.transform_and_run
  def test_cond_two_args(self):
    """`stateful.cond` forwards two positional operands to the branch."""
    swap = lambda a, b: (b, a)
    keep = lambda a, b: (a, b)
    # The predicate is True, so the swapping branch runs on (2, 1).
    first, second = stateful.cond(True, swap, keep, 2, 1)
    self.assertEqual(first, 1)
    self.assertEqual(second, 2)

  @test_utils.transform_and_run
  def test_cond_three_args(self):
    """`stateful.cond` forwards three positional operands to the branch."""
    reverse = lambda a, b, c: (c, b, a)
    keep = lambda a, b, c: (a, b, c)
    # The predicate is True, so the reversing branch runs on (3, 2, 1).
    first, second, third = stateful.cond(True, reverse, keep, 3, 2, 1)
    self.assertEqual(first, 1)
    self.assertEqual(second, 2)
    self.assertEqual(third, 3)

  def test_cond_no_transform(self):
    """`stateful.cond` raises ValueError when no transform is active."""
    value = jnp.array(3.)
    # NOTE: `msg` only customises the failure message shown when nothing is
    # raised; the text of the raised ValueError is not matched against it.
    with self.assertRaises(ValueError, msg="Use jax.cond() instead"):
      stateful.cond(
          value == 2, value, jnp.square, value, lambda x: jnp.square(x + 1))

  def test_switch(self):
    """`stateful.switch` runs the branch selected by the index."""
    def f(i, x):
      mod = SquareModule()
      # Branch k squares (x + k).
      branches = [mod, lambda x: mod(x + 1), lambda x: mod(x + 2)]
      return stateful.switch(i, branches, x)

    transformed = transform.transform_with_state(f)
    cases = ((0, 1, 1), (1, 2, 9), (2, 3, 25))
    for raw_index, raw_input, raw_expected in cases:
      index = jnp.array(raw_index)
      x = jnp.array(raw_input)
      expected = jnp.array(raw_expected)
      params, state = transformed.init(None, index, x)
      out, state = transformed.apply(params, state, None, index, x)
      self.assertEqual(state, {"square_module": {"y": expected}})
      self.assertEqual(out, expected)

  @parameterized.parameters(1, 2, 4, 8)
  @test_utils.transform_and_run
  def test_switch_traces_cases_with_same_id_once(self, n):
    """Branches repeated in the list are each called once, as in JAX.

    Args:
      n: How many times the `[f, g]` pair is repeated in the branch list.
    """
    f_witness = []
    g_witness = []

    def f(x):
      f_witness.append(None)
      return x ** 2

    def g(x):
      g_witness.append(None)
      return x ** 2

    stateful.switch(0, [f, g] * n, 2)
    f_hk_call_count = len(f_witness)
    g_hk_call_count = len(g_witness)
    self.assertEqual(f_hk_call_count, 1)
    self.assertEqual(g_hk_call_count, 1)

    # Ensure we are in sync with JAX.
    del f_witness[:], g_witness[:]
    jax.lax.switch(0, [f, g] * n, 2)
    f_jax_call_count = len(f_witness)
    g_jax_call_count = len(g_witness)
    self.assertEqual(f_hk_call_count, f_jax_call_count)
    # Compare g against g (previously compared f's count with g's by mistake).
    self.assertEqual(g_hk_call_count, g_jax_call_count)

  def test_switch_no_transform(self):
    """`stateful.switch` raises ValueError when no transform is active."""
    index = jnp.array(2)
    value = jnp.array(42.)
    # NOTE: `msg` only customises the failure message shown when nothing is
    # raised; the text of the raised ValueError is not matched against it.
    with self.assertRaises(ValueError, msg="Use jax.switch() instead"):
      stateful.switch(index, [jnp.square] * 3, value)

  @test_utils.transform_and_run
  def test_difference_empty(self):
    """Two back-to-back state snapshots have a difference with no leaves."""
    before = stateful.internal_state()
    after = stateful.internal_state()
    diff = stateful.difference(before, after)
    # `jax.tree_util.tree_leaves` replaces the deprecated top-level alias
    # `jax.tree_leaves`, which newer JAX releases no longer provide.
    self.assertEmpty(jax.tree_util.tree_leaves(diff))

  @parameterized.parameters(base.get_parameter, base.get_state)
  @test_utils.transform_and_run(run_apply=False)
  def test_difference_new(self, get_x):
    """A value created between snapshots shows up in the difference."""
    # "a" exists in both snapshots; "b" is only created before the second one.
    get_x("a", [], init=jnp.zeros)
    before = stateful.internal_state()
    b = get_x("b", [], init=jnp.zeros)
    after = stateful.internal_state()
    diff = stateful.difference(before, after)

    if get_x == base.get_state:
      expected_state = {"~": {"a": None, "b": base.StatePair(b, b)}}
      self.assertEmpty(diff.params)
      self.assertEqual(diff.state, expected_state)
    else:
      expected_params = {"~": {"a": None, "b": b}}
      self.assertEqual(diff.params, expected_params)
      self.assertEmpty(diff.state)
    self.assertIsNone(diff.rng)

  @test_utils.transform_and_run(run_apply=False)
  def test_difference_update_state(self):
    """Only the state entry set between snapshots appears as changed."""
    for name in ("a", "b"):
      base.get_state(name, [], init=jnp.zeros)

    before = stateful.internal_state()
    base.set_state("b", jnp.ones([]))
    after = stateful.internal_state()

    diff = stateful.difference(before, after)
    expected_state = {"~": {"a": None, "b": base.StatePair(0., 1.)}}
    self.assertEmpty(diff.params)
    self.assertEqual(diff.state, expected_state)
    self.assertIsNone(diff.rng)

  @test_utils.transform_and_run(run_apply=False)
  def test_difference_rng(self):
    """Drawing an RNG key between snapshots is reported in `diff.rng`."""
    snapshot_before = stateful.internal_state()
    base.next_rng_key()
    snapshot_after = stateful.internal_state()

    diff = stateful.difference(snapshot_before, snapshot_after)
    self.assertEmpty(diff.params)
    self.assertEmpty(diff.state)
    self.assertIsNotNone(diff.rng)

  def test_scan_no_transform(self):
    """`stateful.scan` raises ValueError when no transform is active."""
    sequence = jnp.arange(3)
    # NOTE: `msg` only customises the failure message shown when nothing is
    # raised; the text of the raised ValueError is not matched against it.
    with self.assertRaises(ValueError, msg="Use jax.scan() instead"):
      stateful.scan(lambda c, x: (c, x), (), sequence)

  @parameterized.parameters(0, 1, 2, 4, 8)
  def test_scan_with_state(self, unroll_length):
    """State updated inside `stateful.scan` is visible after apply."""
    def f(xs):
      counter = CountingModule()

      def step(carry, x):
        self.assertEqual(carry, ())
        return carry, counter(x)

      _, ys = stateful.scan(step, (), xs)
      return ys

    transformed = transform.transform_with_state(f)
    init_key, apply_key = jax.random.split(jax.random.PRNGKey(42))
    xs = jnp.arange(unroll_length)

    # After init the counter exists and is still at zero.
    params, state = transformed.init(init_key, xs)
    self.assertEqual(list(state), ["counting_module"])
    self.assertEqual(list(state["counting_module"]), ["count"])
    np.testing.assert_allclose(state["counting_module"]["count"], 0, rtol=1e-4)

    # After apply the counter matches the number of scanned elements.
    ys, state = transformed.apply(params, state, apply_key, xs)
    np.testing.assert_allclose(
        state["counting_module"]["count"], unroll_length, rtol=1e-4)
    np.testing.assert_allclose(ys, xs ** 2, rtol=1e-4)

  @parameterized.parameters(0, 1, 2, 8)
  @test_utils.transform_and_run
  def test_stateful_scan_with_rng_use(self, iteration_count):
    """Drawing RNG keys inside a `stateful.scan` body does not raise.

    Args:
      iteration_count: Value passed as `length` to `stateful.scan`.
    """
    # TODO(lenamartens): remove when default changes to > 1.
    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    try:
      def body_fun(c, x):
        # Draw more keys (10) than were explicitly reserved below (5).
        for _ in range(10):
          _ = base.next_rng_key()
        return c, x
      base.reserve_rng_keys(5)
      _ = stateful.scan(body_fun, (), (), length=iteration_count)
    finally:
      # Restore the module-level default even when the body above raises, so a
      # failure here cannot leak the overridden value into other tests.
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default

  @parameterized.parameters(0, 1, 2, 8)
  @test_utils.transform_and_run
  def test_stateful_fori_with_rng_use(self, iteration_count):
    """Drawing RNG keys inside a `stateful.fori_loop` body does not raise.

    Args:
      iteration_count: Upper bound passed to `stateful.fori_loop` (lower is 0).
    """
    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    try:
      def body_fun(_, x):
        # Draw more keys (10) than were explicitly reserved below (5).
        for _ in range(10):
          _ = base.next_rng_key()
        return x
      base.reserve_rng_keys(5)
      _ = stateful.fori_loop(0, iteration_count, body_fun, 1)
    finally:
      # Restore the module-level default even when the body above raises, so a
      # failure here cannot leak the overridden value into other tests.
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default

  @test_utils.transform_and_run
  def test_stateful_cond_with_rng_use(self):
    """Branches of `stateful.cond` may draw different numbers of RNG keys."""
    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    try:
      # The two branches consume a different number of keys (one vs. two); the
      # test passes only if neither `cond` call below raises.
      def true_branch(x):
        _ = base.next_rng_key()
        return x

      def false_branch(x):
        _ = base.next_rng_key()
        _ = base.next_rng_key()
        return x

      base.reserve_rng_keys(5)
      _ = stateful.cond(True, true_branch, false_branch, 0)
      _ = stateful.cond(False, true_branch, false_branch, 0)
    finally:
      # Restore the module-level default even when the body above raises, so a
      # failure here cannot leak the overridden value into other tests.
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default

  @test_utils.transform_and_run
  def test_stateful_switch_with_rng_use(self):
    """Branches of `stateful.switch` may draw different numbers of RNG keys."""
    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    try:
      # Branch `i` draws `i` keys, so every branch consumes a different amount;
      # the test passes only if the `switch` calls below do not raise.
      def branch_f(i):
        for _ in range(i):
          _ = base.next_rng_key()
        return i

      base.reserve_rng_keys(5)
      # `i=i` binds the loop value at definition time (avoids late binding).
      branches = [lambda _, i=i: branch_f(i) for i in range(5)]
      self.assertEqual(stateful.switch(3, branches, None), 3)
      self.assertEqual(stateful.switch(0, branches, None), 0)
    finally:
      # Restore the module-level default even when the body above raises, so a
      # failure here cannot leak the overridden value into other tests.
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default

  @parameterized.parameters(*it.product((0, 1, 2, 4, 8), (1, 2, 3)))
  @test_utils.transform_and_run
  def test_fori(self, lower, n):
    """`stateful.fori_loop` calls the module once per index in [lower, upper)."""
    upper = lower + n
    counter = CountingModule()
    # The body ignores the carry and returns the module applied to the index.
    final = stateful.fori_loop(lower, upper, lambda i, x: counter(i), 2)
    self.assertEqual(final, jnp.square(upper - 1))
    self.assertEqual(counter.count, upper - lower)

  @test_utils.transform_and_run
  def test_fori_traced_length(self):
    """`stateful.fori_loop` works when its bounds are tracers."""
    counter = CountingModule()

    def loop(lower, upper):
      return stateful.fori_loop(lower, upper, lambda i, x: counter(i), 2)

    # Because of the jit, lower and upper will be tracers.
    result = stateful.jit(loop)(0, 3)
    self.assertEqual(result, 4)
    self.assertEqual(counter.count, 3)

  def test_vmap(self):
    """`stateful.vmap` maps the output but leaves module state unmapped."""
    def g(x):
      return CountingModule()(x)

    def f(x):
      return stateful.vmap(g, split_rng=False)(x)

    f = transform.transform_with_state(f)

    x = jnp.ones([4]) + 1
    params, state = f.init(None, x)

    # State should not be mapped.
    self.assertEmpty(params)
    # `jax.tree_util.tree_leaves` replaces the deprecated top-level alias
    # `jax.tree_leaves`, which newer JAX releases no longer provide.
    cnt, = jax.tree_util.tree_leaves(state)
    self.assertEqual(cnt.ndim, 0)
    self.assertEqual(cnt, 0)

    # The output should be mapped but state should not be.
    y, state = f.apply(params, state, None, x)
    self.assertEqual(y.shape, (4,))
    np.testing.assert_allclose(y, x ** 2)
    cnt, = jax.tree_util.tree_leaves(state)
    self.assertEqual(cnt.ndim, 0)
    self.assertEqual(cnt, 1)

  def test_vmap_must_be_called_in_transform(self):
    """Calling a `stateful.vmap`-wrapped function outside a transform raises."""
    vmapped_identity = stateful.vmap(lambda x: x, split_rng=False)
    expected_error = "must be used as part of an.*hk.transform"
    # Building the wrapper above succeeds; only the call is expected to raise.
    with self.assertRaisesRegex(ValueError, expected_error):
      vmapped_identity(0)

  @test_utils.transform_and_run
  def test_vmap_no_in_axes(self):
    """`in_axes=None` is rejected with an error naming the function."""
    # The function's name must stay `fn_name`: the regex below matches it.
    def fn_name(_):
      pass

    expected_error = (
        "fn_name must have at least one non-None value in in_axes")
    with self.assertRaisesRegex(ValueError, expected_error):
      stateful.vmap(fn_name, in_axes=None, split_rng=False)

  @test_utils.transform_and_run
  def test_vmap_in_axes_different_size(self):
    """Mapping axes of different sizes (1 and 2) raises ValueError."""
    array = jnp.ones([1, 2])
    expected_error = "vmap got inconsistent sizes for array axes to be mapped"
    vmapped = stateful.vmap(
        lambda a, b: None, in_axes=(0, 1), split_rng=False)
    with self.assertRaisesRegex(ValueError, expected_error):
      vmapped(array, array)

  @test_utils.transform_and_run
  def test_vmap_no_split_rng(self):
    """With `split_rng=False` every mapped instance draws the same key."""
    key_before = base.next_rng_key()
    draw_key = stateful.vmap(lambda _: base.next_rng_key(), split_rng=False)
    k1, k2, k3, k4 = draw_key(jnp.arange(4))
    key_after = base.next_rng_key()

    # All keys drawn inside the vmap are identical ...
    for left, right in ((k1, k2), (k2, k3), (k3, k4)):
      np.testing.assert_array_equal(left, right)

    # ... but differ from the keys drawn outside, which differ from each other.
    self.assertFalse(np.array_equal(key_before, k1))
    self.assertFalse(np.array_equal(key_after, k1))
    self.assertFalse(np.array_equal(key_before, key_after))

  @test_utils.transform_and_run
  def test_vmap_split_rng(self):
    """With `split_rng=True` every mapped instance draws a distinct key."""
    key_before = base.next_rng_key()
    draw_key = stateful.vmap(lambda _: base.next_rng_key(), split_rng=True)
    k1, k2, k3, k4 = draw_key(jnp.arange(4))
    key_after = base.next_rng_key()

    # Test that none of the keys are equal.
    named_keys = (("k1", k1), ("k2", k2), ("k3", k3), ("k4", k4),
                  ("key_before", key_before), ("key_after", key_after))
    for (a_name, a), (b_name, b) in it.combinations(named_keys, 2):
      keys_match = np.array_equal(a, b)
      self.assertFalse(
          keys_match,
          msg=f"Keys should not be equal, but {a_name} == {b_name}")

  def test_while_loop_rejected_in_init(self):
    """Using `stateful.while_loop` during `init` raises ValueError."""
    def run_loop():
      stateful.while_loop(lambda x: x.all(), lambda x: not x, 1)

    transformed = transform.transform(run_loop)
    expected_error = "hk.while_loop does not support initialization"
    with self.assertRaisesRegex(ValueError, expected_error):
      transformed.init(None)

  def test_updating_state_in_cond_fails(self):
    """Passing a `CountingModule` as `cond_fun` of `while_loop` raises."""
    def f(x):
      counter = CountingModule(op=lambda v: v + 1)
      if not base.params_frozen():
        # Params not frozen yet: call the module directly, without the loop.
        return counter(x)
      # The counting module itself is passed as the loop's `cond_fun`.
      stateful.while_loop(counter, lambda v: v, x)

    transformed = transform.transform_with_state(f)
    x = jnp.zeros([])
    params, state = transformed.init(None, x)
    expected_error = "does not support.*set_state.*next_rng_key.*in.*cond_fun`"
    with self.assertRaisesRegex(ValueError, expected_error):
      transformed.apply(params, state, None, x)

  def test_rng_in_cond_fails(self):
    """Drawing an RNG key inside `cond_fun` of `while_loop` raises."""
    def f(x):
      counter = CountingModule(op=lambda v: v + 1)
      if not base.params_frozen():
        # Params not frozen yet: call the module directly, without the loop.
        return counter(x)
      # Here the `cond_fun` requests an RNG key.
      stateful.while_loop(lambda _: base.next_rng_key(), lambda v: v, x)

    transformed = transform.transform_with_state(f)
    x = jnp.zeros([])
    params, state = transformed.init(None, x)
    expected_error = "does not support.*set_state.*next_rng_key.*in.*cond_fun`"
    with self.assertRaisesRegex(ValueError, expected_error):
      transformed.apply(params, state, jax.random.PRNGKey(42), x)

  @parameterized.parameters(0, 1, 2, 4, 8)
  def test_while_loop_with_state(self, iters):
    """State updated in the body of `while_loop` is visible after apply."""
    def f(x):
      counter = CountingModule(op=lambda v: v + 1)
      if not base.params_frozen():
        # Params not frozen yet: call the module directly, without the loop.
        return counter(x)

      # Loop variables are an (iteration index, value) pair.
      def keep_going(loop_vars):
        return loop_vars[0] < iters

      def step(loop_vars):
        return loop_vars[0] + 1, counter(loop_vars[1])

      _, y = stateful.while_loop(keep_going, step, (0, x))
      return y

    transformed = transform.transform_with_state(f)
    x = jnp.zeros([])
    params, state = transformed.init(None, x)
    self.assertEqual(list(state), ["counting_module"])
    self.assertEqual(list(state["counting_module"]), ["count"])
    np.testing.assert_allclose(state["counting_module"]["count"], x, rtol=1e-4)

    y, state = transformed.apply(params, state, None, x)
    np.testing.assert_allclose(
        state["counting_module"]["count"], iters, rtol=1e-4)
    np.testing.assert_allclose(y, iters, rtol=1e-4)

  def test_named_call(self):
    """A module wrapped in `stateful.named_call` still computes x ** 2."""
    def f(x):
      named_square = stateful.named_call(SquareModule(), name="square")
      return named_square(x)

    x = jnp.array(2.)
    rng = jax.random.PRNGKey(42)
    transformed = transform.transform_with_state(f)
    params, state = transformed.init(rng, x)
    y, state = jax.jit(transformed.apply)(params, state, rng, x)
    self.assertEqual(y, x ** 2)

  @parameterized.parameters(jax.jit, jax.grad, jax.vmap, jax.remat)
  def test_named_call_jax_transforms(self, jax_transform):
    """Wrapping in `named_call` does not change a JAX transform's output."""
    fn = jnp.sum
    inputs = jnp.array([1.])

    expected = jax_transform(fn)(inputs)
    actual = jax_transform(stateful.named_call(fn, name="test"))(inputs)

    self.assertEqual(expected, actual)

  def test_static_argnums_named_call(self):
    """A `named_call` function can be jitted with a static argument."""
    named = stateful.named_call(lambda x, y: y if x else None, name="test")
    # Argument 0 is used in a Python conditional, so it is marked static.
    jitted = jax.jit(named, static_argnums=(0,))
    self.assertEqual(jitted(True, 5), 5)

  def test_named_call_non_jaxtype_arg(self):
    """`named_call` tolerates an argument that is not a valid JAX type."""
    # For the test to fail without the invalid JaxType filter we need to pass
    # in a valid JaxType that forces the invalid Jaxtype to be raised to an
    # abstract value.
    def select(not_a_jaxtype, a_jaxtype):
      # then Jax needs to try and evaluate the abstractified non-JaxType
      return a_jaxtype if not_a_jaxtype else 0

    named = stateful.named_call(select, name="test")
    jitted = jax.jit(named, static_argnums=(0,))
    self.assertEqual(jitted("not a Jaxtype", 1), 1)

  @parameterized.parameters("hi", None, object(), object)
  def test_named_call_non_jaxtype_result(self, non_jaxtype):
    """`named_call` can return a value that is not a valid JAX type."""
    def fun_with_non_jaxtype_output(x, non_jaxtype):
      return x, non_jaxtype

    def jitted_fun(x, non_jaxtype):
      # The non-jaxtype is returned out of named_call (which is supported),
      # but is not returned out of the jit (which should not be supported).
      named_fun = stateful.named_call(fun_with_non_jaxtype_output)
      outputs = named_fun(x, non_jaxtype)
      x, non_jaxtype_out = outputs
      self.assertEqual(non_jaxtype_out, non_jaxtype)
      return x

    compiled = jax.jit(jitted_fun, static_argnums=1)
    self.assertEqual(compiled(0, non_jaxtype), 0)

  def test_named_call_partial_function(self):
    """A `named_call` function works when bound with `functools.partial`."""
    named = stateful.named_call(lambda x, y: y if x else None)
    # Bind the first argument to True before jitting.
    jitted = jax.jit(functools.partial(named, True))
    self.assertEqual(jitted(5), 5)

  def test_named_call_default_name(self):
    """Without `name=`, the wrapped function's own name appears in the HLO."""
    @stateful.named_call
    def naming_things_is_hard(x):
      return x ** 2

    @jax.jit
    def f(x):
      return naming_things_is_hard(x) + naming_things_is_hard(x)

    computation = jax.xla_computation(f)(1.)
    hlo_print_options = jax.xla.xe.HloPrintOptions.short_parsable()
    # Metadata printing is switched on before searching the text for the name.
    hlo_print_options.print_metadata = True
    hlo_module = computation.as_hlo_module()
    hlo_text = hlo_module.to_string(hlo_print_options)
    self.assertIn("naming_things_is_hard", hlo_text)

  def test_eval_shape(self):
    """`stateful.eval_shape` reports the output shape without bumping state."""
    def some_shape_changing_fun(x):
      return x[0, :]

    def f(x):
      counter = CountingModule(op=some_shape_changing_fun)
      # state is not changed in this call
      out_shape_struct = stateful.eval_shape(counter, x)
      return counter(x), out_shape_struct

    transformed = transform.transform_with_state(f)
    key = jax.random.PRNGKey(42)
    in_shape = (10, 10)
    x = jnp.ones(in_shape)

    params, state = transformed.init(key, x)
    self.assertEqual(list(state), ["counting_module"])
    self.assertEqual(list(state["counting_module"]), ["count"])
    np.testing.assert_allclose(state["counting_module"]["count"], 0, rtol=1e-4)

    outputs, state = transformed.apply(params, state, key, x)
    out, shape_struct = outputs
    # Count is only advanced once
    np.testing.assert_allclose(state["counting_module"]["count"], 1, rtol=1e-4)
    np.testing.assert_allclose(out, some_shape_changing_fun(x), rtol=1e-4)
    self.assertEqual(shape_struct.shape, (in_shape[1],))

  def test_eval_shape_no_transform(self):
    """`stateful.eval_shape` raises ValueError when no transform is active."""
    value = jnp.array(3.)
    # NOTE: `msg` only customises the failure message shown when nothing is
    # raised; the text of the raised ValueError is not matched against it.
    with self.assertRaises(ValueError, msg="Use jax.eval_shape() instead"):
      stateful.eval_shape(jnp.square)(value)

  @test_utils.transform_and_run
  def test_temporary_state_resets_names(self):
    """A name used under `temporary_internal_state` is free again afterwards."""
    snapshot = stateful.internal_state()
    with stateful.temporary_internal_state(snapshot):
      inside = module.Module(name="foo")
    outside = module.Module(name="foo")
    # Neither module was given a de-duplicated name.
    self.assertEqual(inside.module_name, "foo")
    self.assertEqual(outside.module_name, "foo")

  @test_utils.transform_and_run(run_apply=False)
  def test_eval_shape_no_leaked_tracers_under_leak_checker(self):
    """`stateful.eval_shape` runs cleanly under `jax.checking_leaks()`."""
    scalar_one = jnp.ones(())
    with jax.checking_leaks():
      stateful.eval_shape(SquareModule(), scalar_one)  # does not crash

  @test_utils.combined_named_parameters(base_test.SIDE_EFFECTING_FUNCTIONS,
                                        HK_OVERLOADED_JAX_PURE_EXPECTING_FNS)
  @test_utils.transform_and_run
  @test_utils.with_guardrails
  def test_safe_use_of_jax(self, haiku_side_effect_fn, hk_jax_fn):
    """Side-effecting calls under the wrapped functions pass the guardrails."""
    test_name = self._testMethodName
    if "reserve_rng_keys_while_loop" in test_name:
      self.skipTest("Expected not to work.")

    # Make `f` identify with the side effecting function included.
    f = hk_jax_fn(lambda x: [haiku_side_effect_fn(), x][1])
    # These functions should not trigger exceptions from our guardrails.
    f(jnp.ones([1]))

  @test_utils.transform_and_run
  def test_vmap_split_rng_with_default(self):
    """Omitting `split_rng` raises unless `require_split_rng` is disabled."""
    expected_error = "hk.vmap.require_split_rng = False"
    with self.assertRaisesRegex(TypeError, expected_error):
      # Intentionally missing split_rng arg.
      stateful.vmap(lambda: None)

    with self.subTest("require_split_rng=0"):
      stateful.vmap.require_split_rng = False
      try:
        # This call should not trigger an error, even though we are missing the
        # split_rng argument which appears required (if you look at the function
        # signature). It only works because require_split_rng is
        # propagated to vmap via a sneaky decorator. This only exists to support
        # users who import code that they cannot edit (e.g. from a read only
        # file system) that is not passing the argument.
        draw_keys = stateful.vmap(base.next_rng_key, axis_size=2)
      finally:
        # Always re-enable the requirement, whether or not the call succeeded.
        stateful.vmap.require_split_rng = True

    # Check that split_rng=False was implied.
    first_key, second_key = draw_keys()
    self.assertTrue((first_key == second_key).all())

  @parameterized.parameters(True, False)
  @test_utils.transform_and_run
  def test_vmap_split_rng_without_default(self, require_split_rng):
    """An explicit `split_rng` is honoured whatever `require_split_rng` is.

    Args:
      require_split_rng: Value assigned to `stateful.vmap.require_split_rng`
        for the duration of the test.
    """
    # Tests that when split_rng is passed explicitly the value of
    # require_split_rng has no impact.
    x = jnp.arange(2)
    draw_key = lambda _: base.next_rng_key()
    stateful.vmap.require_split_rng = require_split_rng
    try:
      k1, k2 = stateful.vmap(draw_key, split_rng=True)(x)
      self.assertTrue((k1 != k2).all())
      k1, k2 = stateful.vmap(draw_key, split_rng=False)(x)
      self.assertTrue((k1 == k2).all())
    finally:
      # Reset the flag to True even when an assertion above fails, so the
      # overridden setting cannot leak into other tests.
      stateful.vmap.require_split_rng = True