예제 #1
0
  def test_optimize_rng_use_under_jit(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    """Checks an `optimize_rng_use`-wrapped module agrees with and without jit.

    Args:
      module_fn: Zero-argument callable that constructs the module under test;
        the constructed module is called on the generated input.
      shape: Shape of the randomly generated input.
      dtype: Dtype of the input. Integer dtypes get random integers in
        `[0, prod(shape))`; every other dtype gets uniform samples.
    """
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
      x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
      x = jax.random.uniform(rng, shape, dtype)

    def g(x):
      return module_fn()(x)

    f = hk.transform_with_state(hk.experimental.optimize_rng_use(g))

    # Tolerance may be overridden per module type.
    module_type = descriptors.module_type(module_fn)
    atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    # `jax.tree_multimap` is deprecated in JAX; `jax.tree_util.tree_map`
    # accepts several trees and is its direct replacement.
    params, state = jax.jit(f.init)(rng, x)
    jax.tree_util.tree_map(assert_allclose, (params, state), f.init(rng, x))

    if module_type in (hk.nets.VectorQuantizer, hk.nets.VectorQuantizerEMA):
      # For stochastic modules just test apply runs.
      jax.device_get(jax.jit(f.apply)(params, state, rng, x))

    else:
      jax.tree_util.tree_map(assert_allclose,
                             jax.jit(f.apply)(params, state, rng, x),
                             f.apply(params, state, rng, x))
예제 #2
0
  def test_jit(self, module_fn: ModuleFn, shape, dtype):
    """Checks that init and apply give matching results with and without jit.

    Args:
      module_fn: Zero-argument callable that constructs the module under test;
        the constructed module is called on the generated input.
      shape: Shape of the randomly generated input.
      dtype: Dtype of the input. Integer dtypes get random integers in
        `[0, prod(shape))`; every other dtype gets uniform samples.
    """
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
      x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
      x = jax.random.uniform(rng, shape, dtype)

    def g(x):
      return module_fn()(x)

    f = hk.transform_with_state(g)

    # Tolerance may be overridden per module type.
    atol = CUSTOM_ATOL.get(descriptors.module_type(module_fn), DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    # `jax.tree_multimap` is deprecated in JAX; `jax.tree_util.tree_map`
    # accepts several trees and is its direct replacement.

    # Ensure initialization under jit is the same.
    jax.tree_util.tree_map(assert_allclose,
                           f.init(rng, x),
                           jax.jit(f.init)(rng, x))

    # Ensure application under jit is the same.
    params, state = f.init(rng, x)
    jax.tree_util.tree_map(assert_allclose,
                           f.apply(params, state, rng, x),
                           jax.jit(f.apply)(params, state, rng, x))
예제 #3
0
    def test_vmap(self, module_fn: ModuleFn, shape, dtype):
        """Checks `hk.vmap` inside a transform matches `jax.vmap` over apply.

        Args:
          module_fn: Zero-argument callable that constructs the module under
            test; the constructed module is called on the generated input.
          shape: Shape of a single (unbatched) random input.
          dtype: Dtype of the input. Integer dtypes get random integers in
            `[0, prod(shape))`; every other dtype gets uniform samples.
        """
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            x = jax.random.uniform(rng, shape, dtype)

        # Expand our input since we will map over it.
        x = jnp.broadcast_to(x, (2, ) + x.shape)

        f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda
        f_mapped = hk.transform_with_state(
            lambda x: hk.vmap(lambda x: module_fn()(x))(x))  # pylint: disable=unnecessary-lambda

        params, state = f_mapped.init(rng, x)

        # JAX vmap with explicitly unmapped params/state/rng. This should be
        # equivalent to `f_mapped.apply(..)` (since by default hk.vmap does not map
        # params/state/rng).
        v_apply = jax.vmap(f.apply,
                           in_axes=(None, None, None, 0),
                           out_axes=(0, None))

        # Tolerance may be overridden per module type.
        module_type = descriptors.module_type(module_fn)
        atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=atol)
        # `jax.tree_multimap` is deprecated in JAX; `jax.tree_util.tree_map`
        # accepts several trees and is its direct replacement.
        jax.tree_util.tree_map(assert_allclose,
                               f_mapped.apply(params, state, rng, x),
                               v_apply(params, state, rng, x))
예제 #4
0
 def test_info_and_html(self, module_fn: ModuleFn, shape, dtype):
   """Builds model info for the module and renders it as an HTML page.

   Args:
     module_fn: Zero-argument callable that constructs the module under test.
     shape: Shape of the all-ones input fed to the module.
     dtype: Dtype of the all-ones input.
   """
   inputs = jnp.ones(shape, dtype)

   def forward():
     return module_fn()(inputs)

   transformed = hk.transform_with_state(forward)
   key = jax.random.PRNGKey(42)
   params, state = transformed.init(key)

   make_info = jaxpr_info.make_model_info(transformed.apply)
   model_info = make_info(params, state, key)

   # Modules whose type is named 'Sequential' are exempt from the
   # non-empty expressions check.
   is_sequential = descriptors.module_type(module_fn).__name__ == 'Sequential'
   if not is_sequential:
     self.assertNotEmpty(model_info.expressions)

   html_page = jaxpr_info.as_html_page(model_info)
   self.assertIsNotNone(html_page)
예제 #5
0
  def test_strict_promotion(self, module_fn: ModuleFn, shape, dtype):
    """Runs init and apply on an all-ones input and checks apply returns.

    The test is skipped for the vector-quantizer modules.

    NOTE(review): the strict dtype-promotion mode implied by the name is not
    configured in this method -- presumably it is enabled elsewhere; confirm.

    Args:
      module_fn: Zero-argument callable that constructs the module under test.
      shape: Shape of the all-ones input.
      dtype: Dtype of the all-ones input.
    """
    module_cls = descriptors.module_type(module_fn)
    skipped_modules = (hk.nets.VectorQuantizer, hk.nets.VectorQuantizerEMA)
    if module_cls in skipped_modules:
      self.skipTest('Requires: https://github.com/google/jax/pull/2901')

    def forward(inputs):
      return module_fn()(inputs)

    transformed = hk.transform_with_state(forward)
    key = jax.random.PRNGKey(42)
    ones = jnp.ones(shape, dtype)
    params, state = transformed.init(key, ones)
    outputs = transformed.apply(params, state, key, ones)
    self.assertIsNotNone(outputs)
예제 #6
0
def main(argv):
  """Writes a JSON summary file for every module descriptor.

  For each entry of `descriptors.ALL_MODULES` the summary produced by
  `checkpoint_utils.summarize` is written, indented by two spaces and followed
  by a trailing newline, to `<FLAGS.base_dir>/<file name>.json`.

  Args:
    argv: Argument list; more than one element is rejected.

  Raises:
    app.UsageError: If `argv` has more than one element.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  for descriptor in descriptors.ALL_MODULES:
    file_name = descriptors.to_file_name(descriptor) + ".json"
    summary = checkpoint_utils.summarize(descriptor)
    # Explicit encoding keeps the output independent of the platform locale.
    with open(os.path.join(FLAGS.base_dir, file_name), "w",
              encoding="utf-8") as fp:
      json.dump(summary, fp, indent=2)
      fp.write("\n")
예제 #7
0
    def test_checkpoint_format(self, name, module_fn: ModuleFn, shape, dtype):
        """Compares a module's summary against its stored JSON checkpoint file.

        Args:
          name: Name passed to the `ModuleDescriptor`.
          module_fn: Module factory passed to the `ModuleDescriptor`.
          shape: Input shape passed to the `ModuleDescriptor`.
          dtype: Input dtype passed to the `ModuleDescriptor`.

        Raises:
          ValueError: If the checkpoint file does not exist. The message
            contains the JSON that the file was expected to hold.
        """
        descriptor = descriptors.ModuleDescriptor(name, module_fn, shape,
                                                  dtype)
        expected = checkpoint_utils.summarize(descriptor)
        file_path = os.path.join(
            "haiku/_src/integration/checkpoints/",
            descriptors.to_file_name(descriptor) + ".json")
        if not os.path.exists(file_path):
            expected_json = json.dumps(expected, indent=2)
            raise ValueError(f"Missing checkpoint file: {file_path}\n\n"
                             f"Expected:\n\n{expected_json}")

        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default locale encoding.
        with open(file_path, "r", encoding="utf-8") as fp:
            actual = json.load(fp)

        self.assertEqual(expected, actual, msg=HOW_TO_REGENERATE)
예제 #8
0
    def test_convert(
        self,
        module_fn: ModuleFn,
        shape,
        dtype,
        init: bool,
        tf_transform,
        jax_transform,
    ):
        """Checks that a jax2tf-converted init or apply matches the JAX result.

        Args:
          module_fn: Zero-argument callable that constructs the module under
            test; the constructed module is called on the generated input.
          shape: Shape of the randomly generated input.
          dtype: Dtype of the input. Integer dtypes get random integers in
            `[0, prod(shape))`; every other dtype gets uniform samples.
          init: If true, compare the `init` functions; otherwise compare the
            `apply` functions using parameters from an unconverted `init`.
          tf_transform: Callable applied to the `jax2tf.convert(...)` result
            before it is called.
          jax_transform: Callable applied to the JAX function before it is
            called.
        """
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            x = jax.random.uniform(rng, shape, dtype)

        def g(x):
            return module_fn()(x)

        f = hk.transform_with_state(g)

        # Tolerance may be overridden per module type.
        atol = CUSTOM_ATOL.get(descriptors.module_type(module_fn),
                               DEFAULT_ATOL)
        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=atol)

        # `jax.tree_multimap` and the top-level `jax.tree_map` alias are
        # deprecated in JAX; `jax.tree_util.tree_map` replaces both and
        # accepts several trees.
        def get(t):
            # Call `.numpy()` on every leaf of the converted function's output.
            return jax.tree_util.tree_map(lambda leaf: leaf.numpy(), t)

        if init:
            init_jax = jax_transform(f.init)
            init_tf = tf_transform(jax2tf.convert(f.init))
            jax.tree_util.tree_map(assert_allclose, init_jax(rng, x),
                                   get(init_tf(rng, x)))

        else:
            params, state = f.init(rng, x)
            apply_jax = jax_transform(f.apply)
            apply_tf = tf_transform(jax2tf.convert(f.apply))
            jax.tree_util.tree_map(assert_allclose,
                                   apply_jax(params, state, rng, x),
                                   get(apply_tf(params, state, rng, x)))