Esempio n. 1
0
 def test_without_state_raises_if_state_used_on_apply(self):
     """A state-setting function wrapped in `without_state` must raise.

     The function under test only calls `base.set_state`; running it through
     `init` followed by `apply` is expected to produce a ValueError that
     points the user at `transform_with_state`.
     """
     def sets_state():
         return base.set_state("~", 1)

     stateful = transform.transform_with_state(sets_state)
     stateless = transform.without_state(stateful)
     key = jax.random.PRNGKey(42)
     # NOTE(review): both calls share one `assertRaisesRegex` block, so if
     # `init` already raises, `apply` is never reached -- confirm which call
     # this test is really meant to exercise.
     with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
         init_params = stateless.init(key)
         stateless.apply(init_params, key)
Esempio n. 2
0
    def test_without_state(self):
        """`without_state` yields a usable (init, apply) pair for stateless fns.

        The wrapped function only creates a scalar parameter "w" initialised
        with `jnp.zeros`, so applying it should return a value equal to 0.
        """
        def scalar_param():
            return base.get_parameter("w", [], init=jnp.zeros)

        stateful = transform.transform_with_state(scalar_param)
        init_fn, apply_fn = transform.without_state(stateful)

        # No RNG is supplied to either call.
        result = apply_fn(init_fn(None), None)
        self.assertEqual(result, 0)
Esempio n. 3
0
    def test_without_state_raises_if_state_used(self):
        """`init` of a `without_state`-wrapped fn raises when state is used.

        The wrapped function reads and increments a "count" state entry ten
        times; the resulting ValueError must mention `transform_with_state`.
        """
        def counter():
            value = None
            for _ in range(10):
                value = base.get_state("count", (), jnp.int32, jnp.zeros)
                base.set_state("count", value + 1)
            return value

        stateful = transform.transform_with_state(counter)
        init_fn, _ = transform.without_state(stateful)

        expected_message = "use.*transform_with_state"
        with self.assertRaisesRegex(ValueError, expected_message):
            init_fn(None)
Esempio n. 4
0
 def test_without_state_raises_if_state_used(self):
     """`init` must raise when `CountingModule` runs under `without_state`.

     The expected ValueError message has to point at `transform_with_state`.
     """
     def run_counting_module():
         # Construct the module inside the transformed function, then call it.
         return CountingModule()()

     stateful = transform.transform_with_state(run_counting_module)
     init_fn, _ = transform.without_state(stateful)

     expected_message = "use.*transform_with_state"
     with self.assertRaisesRegex(ValueError, expected_message):
         init_fn(None)
Esempio n. 5
0
 def test_without_state(self):
     """`ScalarModule` under `without_state` initialises and applies cleanly.

     The output of applying with the freshly initialised parameters is
     expected to equal 0.
     """
     def run_scalar_module():
         # Construct the module inside the transformed function, then call it.
         return ScalarModule()()

     stateful = transform.transform_with_state(run_scalar_module)
     init_fn, apply_fn = transform.without_state(stateful)

     # No RNG is supplied to either call.
     result = apply_fn(init_fn(None), None)
     self.assertEqual(result, 0)
Esempio n. 6
0
class TransformTest(parameterized.TestCase):
    """Tests for `transform`, `transform_with_state` and related `base` APIs."""

    @test_utils.transform_and_run
    def test_parameter_reuse(self):
        """Requesting the same parameter name twice returns the same object."""
        w1 = base.get_parameter("w", [], init=jnp.zeros)
        w2 = base.get_parameter("w", [], init=jnp.zeros)
        self.assertIs(w1, w2)

    def test_params(self):
        """`init` returns the created parameter under the "~" collection."""
        def f():
            w = base.get_parameter("w", [], init=jnp.zeros)
            return w

        init_fn, _ = transform.transform(f)
        params = init_fn(None)
        self.assertEqual(params, {"~": {"w": jnp.zeros([])}})

    @test_utils.transform_and_run
    def test_naked_get_parameter(self):
        """`get_parameter` outside any module returns one shared object."""
        w1 = base.get_parameter("w", [], init=jnp.zeros)
        w2 = base.get_parameter("w", [], init=jnp.zeros)
        self.assertIs(w1, w2)

    def test_naked_parameter_in_tilde_collection(self):
        """Distinct naked parameters are all collected under "~"."""
        def net():
            w1 = base.get_parameter("w1", [], init=jnp.zeros)
            w2 = base.get_parameter("w2", [], init=jnp.ones)
            self.assertIsNot(w1, w2)

        init_fn, _ = transform.transform(net)
        params = init_fn(None)
        self.assertEqual(params,
                         {"~": {
                             "w1": jnp.zeros([]),
                             "w2": jnp.ones([])
                         }})

    @parameterized.parameters((None, ), ({}, ), ({"~": {}}, ))
    def test_parameter_in_apply(self, params):
        """`apply` rejects creating a parameter missing from `params`."""
        _, apply_fn = transform.transform(
            lambda: base.get_parameter("w", [], init=jnp.zeros))

        with self.assertRaisesRegex(
                ValueError, "parameters must be created as part of `init`"):
            apply_fn(params, None)

    @test_utils.transform_and_run(seed=None)
    def test_no_rng(self):
        """`next_rng_key` raises when the transform was run without a key."""
        with self.assertRaisesRegex(ValueError,
                                    "must pass a non-None PRNGKey"):
            base.next_rng_key()

    def test_invalid_rng(self):
        """`transform` init/apply reject a non-key value in the RNG slot."""
        f = transform.transform(lambda: None)
        with self.assertRaisesRegex(
                ValueError,
                "Init must be called with an RNG as the first argument"):
            f.init("nonsense")
        with self.assertRaisesRegex(
                ValueError,
                "Apply must be called with an RNG as the second argument"):
            f.apply({}, "nonsense")

    def test_invalid_rng_state(self):
        """`transform_with_state` init/apply reject a non-key RNG value."""
        f = transform.transform_with_state(lambda: None)
        with self.assertRaisesRegex(
                ValueError,
                "Init must be called with an RNG as the first argument"):
            f.init("nonsense")
        with self.assertRaisesRegex(
                ValueError,
                "Apply must be called with an RNG as the third argument"):
            f.apply({}, {"x": {}}, "nonsense")

    @parameterized.parameters(transform.transform,
                              transform.transform_with_state)
    def test_invalid_rng_none_ignored(self, transform_fn):
        """Passing `None` as the RNG is accepted by both transform flavours."""
        f = transform_fn(lambda: None)
        args = f.init(None)
        # `transform` returns params only; `transform_with_state` returns a
        # (params, state) tuple. Normalise so both can be splatted into apply.
        if not isinstance(args, tuple):
            args = (args, )
        f.apply(*args, None)

    def test_invalid_params(self):
        """`apply` raises TypeError when `params` is not a valid structure."""
        f = transform.transform_with_state(lambda: None)
        with self.assertRaisesRegex(TypeError,
                                    "params argument does not appear valid"):
            f.apply("z", {}, None)

    def test_invalid_state(self):
        """`apply` raises TypeError when `state` is not a valid structure."""
        f = transform.transform_with_state(lambda: None)
        with self.assertRaisesRegex(TypeError,
                                    "state argument does not appear valid"):
            f.apply({}, "z", None)

    def test_maybe_rng_no_transform(self):
        """`maybe_next_rng_key` raises when called outside a transform."""
        with self.assertRaisesRegex(
                ValueError, "must be used as part of an `hk.transform`"):
            base.maybe_next_rng_key()

    @test_utils.transform_and_run(seed=None)
    def test_maybe_no_rng(self):
        """`maybe_next_rng_key` returns None when no key was provided."""
        self.assertIsNone(base.maybe_next_rng_key())

    def test_maybe_rng_vs_not(self):
        """If we have an rng, then next_rng_key() == maybe_next_rng_key()."""
        rngs = []
        maybes = []

        @test_utils.transform_and_run
        def three():
            for _ in range(3):
                rngs.append(base.next_rng_key())

        @test_utils.transform_and_run
        def maybe_three():
            for _ in range(3):
                maybes.append(base.maybe_next_rng_key())

        three()
        maybe_three()
        # NOTE(review): six keys from a three-iteration loop presumably means
        # `transform_and_run` executes the function twice (init and apply) --
        # confirm against `test_utils`.
        self.assertLen(rngs, 6)
        self.assertTrue(jnp.all(jnp.array(rngs) == jnp.array(maybes)))

    def test_init_custom_creator(self):
        """A custom creator sees the request and may override the init."""
        def zeros_creator(next_creator, shape, dtype, init, context):
            self.assertEqual(context.full_name, "~/w")
            self.assertEqual(shape, [])
            self.assertEqual(dtype, jnp.float32)
            self.assertEqual(init, jnp.ones)
            # Swap the requested `jnp.ones` initialiser for `jnp.zeros`.
            return next_creator(shape, dtype, jnp.zeros)

        def f():
            with base.custom_creator(zeros_creator):
                return base.get_parameter("w", [], init=jnp.ones)

        params = transform.transform(f).init(None)
        self.assertEqual(params, {"~": {"w": jnp.zeros([])}})

    def test_not_triggered_in_apply(self):
        """Custom creators run during `init` but not during `apply`."""
        log = []

        def counting_creator(next_creator, shape, dtype, init, context):
            log.append(context.full_name)
            return next_creator(shape, dtype, init)

        def net():
            with base.custom_creator(counting_creator):
                for i in range(4):
                    base.get_parameter("w{}".format(i), [], init=jnp.zeros)

        init_fn, apply_fn = transform.transform(net)

        params = init_fn(None)
        self.assertEqual(log, ["~/w0", "~/w1", "~/w2", "~/w3"])

        # Reset the log in place (the closure above holds a reference to it).
        log.clear()
        apply_fn(params, None)
        self.assertEmpty(log)

    def test_nested_creators(self):
        """Nested custom creators are invoked outermost first."""
        log = []

        def logging_creator(log_msg):
            def _logging_creator(next_creator, shape, dtype, init, context):
                del context
                log.append(log_msg)
                return next_creator(shape, dtype, init)

            return _logging_creator

        def f():
            a, b, c = map(logging_creator, ["a", "b", "c"])
            with base.custom_creator(a), \
                 base.custom_creator(b), \
                 base.custom_creator(c):
                return base.get_parameter("w", [], init=jnp.ones)

        transform.transform(f).init(None)
        self.assertEqual(log, ["a", "b", "c"])

    def test_argspec(self):
        """init/apply expose the documented positional argument names."""
        init_fn, apply_fn = transform.transform_with_state(lambda: None)
        init_fn_spec = inspect.getfullargspec(init_fn)
        apply_fn_spec = inspect.getfullargspec(apply_fn)

        self.assertEqual(init_fn_spec.args, ["rng"])
        self.assertEqual(apply_fn_spec.args, ["params", "state", "rng"])

    def test_get_state_no_init_raises(self):
        """`get_state` without an init fn raises when the state is missing."""
        init_fn, apply_fn = transform.transform_with_state(
            lambda: base.get_state("i"))
        with self.assertRaisesRegex(ValueError, "set an init function"):
            init_fn(None)
        state = params = {"~": {}}
        with self.assertRaisesRegex(ValueError, "set an init function"):
            apply_fn(params, state, None)

    def test_get_state_no_shape_raises(self):
        """`get_state` with an init fn but no shape/dtype raises."""
        init_fn, apply_fn = transform.transform_with_state(
            lambda: base.get_state("i", init=jnp.zeros))
        with self.assertRaisesRegex(ValueError, "provide shape and dtype"):
            init_fn(None)
        state = params = {"~": {}}
        with self.assertRaisesRegex(ValueError, "provide shape and dtype"):
            apply_fn(params, state, None)

    def test_get_state_no_init(self):
        """`get_state` without init works in `apply` if the state exists."""
        _, apply_fn = transform.transform_with_state(
            lambda: base.get_state("i"))
        for i in range(10):
            state_in = {"~": {"i": i}}
            _, state_out = apply_fn({}, state_in, None)
            self.assertEqual(state_in, state_out)

    def test_set_then_get(self):
        """A value written with `set_state` is what `get_state` then reads."""
        def net():
            base.set_state("i", 1)
            return base.get_state("i")

        init_fn, apply_fn = transform.transform_with_state(net)
        params, state = init_fn(None)
        self.assertEqual(state, {"~": {"i": 1}})

        # Whatever state is fed in, the function overwrites it with 1.
        for i in range(10):
            state_in = {"~": {"i": i}}
            y, state_out = apply_fn(params, state_in, None)
            self.assertEqual(y, 1)
            self.assertEqual(state_out, {"~": {"i": 1}})

    def test_stateful(self):
        """State updates accumulate within `apply` but not within `init`."""
        def f():
            for _ in range(10):
                count = base.get_state("count", (), jnp.int32, jnp.zeros)
                base.set_state("count", count + 1)
            return count

        init_fn, apply_fn = transform.transform_with_state(f)
        params, state = init_fn(None)
        self.assertEqual(state, {"~": {"count": 0}})
        _, state = apply_fn(params, state, None)
        self.assertEqual(state, {"~": {"count": 10}})

    def test_without_state(self):
        """`without_state` gives a working init/apply pair for stateless fns."""
        def f():
            w = base.get_parameter("w", [], init=jnp.zeros)
            return w

        init_fn, apply_fn = transform.without_state(
            transform.transform_with_state(f))
        params = init_fn(None)
        out = apply_fn(params, None)
        self.assertEqual(out, 0)

    def test_without_state_raises_if_state_used(self):
        """`init` of a `without_state`-wrapped fn raises when state is used."""
        def f():
            for _ in range(10):
                count = base.get_state("count", (), jnp.int32, jnp.zeros)
                base.set_state("count", count + 1)
            return count

        init_fn, _ = transform.without_state(transform.transform_with_state(f))

        with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
            init_fn(None)

    def test_inline_use(self):
        """The transformed object can be used via `.init` / `.apply`."""
        def f():
            w = base.get_parameter("w", [], init=jnp.zeros)
            return w

        f = transform.transform(f)

        rng = jax.random.PRNGKey(42)
        params = f.init(rng)
        w = f.apply(params, None)
        self.assertEqual(w, 0)

    def test_method(self):
        """A transformed method receives its owning object and the params."""
        obj = ObjectWithTransform()
        x = jnp.ones([])
        params = obj.forward.init(None, x)
        obj_out, y = obj.forward.apply(params, None, x)
        self.assertEqual(y, 1)
        self.assertIs(obj, obj_out)
        # `jax.tree_util.tree_map` is the stable spelling; the top-level
        # `jax.tree_map` alias is deprecated and absent from newer JAX.
        params = jax.tree_util.tree_map(lambda v: v + 1, params)
        obj_out, y = obj.forward.apply(params, None, x)
        self.assertEqual(y, 2)
        self.assertIs(obj, obj_out)

    def test_trampoline(self):
        """The `trampoline` transform on the object behaves like `forward`."""
        obj = ObjectWithTransform()
        x = jnp.ones([])
        params = obj.trampoline.init(None, x)
        obj_out, y = obj.trampoline.apply(params, None, x)
        self.assertEqual(y, 1)
        self.assertIs(obj, obj_out)

    @parameterized.parameters((42, True), (42, False), (28, True), (28, False))
    def test_prng_sequence(self, seed, wrap_seed):
        """`PRNGSequence` yields the same keys as manual `jax.random.split`.

        Args:
          seed: Integer seed for the sequence.
          wrap_seed: If True, pass a `PRNGKey` built from `seed`; otherwise
            pass the raw integer.
        """
        # Values using our sequence.
        key_or_seed = jax.random.PRNGKey(seed) if wrap_seed else seed
        key_seq = base.PRNGSequence(key_or_seed)
        seq_v1 = jax.random.normal(next(key_seq), [])
        seq_v2 = jax.random.normal(next(key_seq), [])
        # Generate values using manual splitting.
        key = jax.random.PRNGKey(seed)
        key, temp_key = jax.random.split(key)
        raw_v1 = jax.random.normal(temp_key, [])
        _, temp_key = jax.random.split(key)
        raw_v2 = jax.random.normal(temp_key, [])
        self.assertEqual(raw_v1, seq_v1)
        self.assertEqual(raw_v2, seq_v2)

    def test_prng_sequence_invalid_input(self):
        """`PRNGSequence` rejects input that is not a JAX PRNGKey."""
        with self.assertRaisesRegex(ValueError, "not a JAX PRNGKey"):
            base.PRNGSequence("nonsense")

    def test_prng_sequence_wrong_shape(self):
        """`PRNGSequence` rejects a stacked pair of keys."""
        with self.assertRaisesRegex(
                ValueError, "key did not have expected shape and/or dtype"):
            base.PRNGSequence(jax.random.split(jax.random.PRNGKey(42), 2))

    @parameterized.parameters(42, 28)
    def test_with_rng(self, seed):
        """`with_rng` overrides the key passed to `apply` inside its scope."""
        key = jax.random.PRNGKey(seed)
        unrelated_key = jax.random.PRNGKey(seed * 2 + 1)
        _, next_key = jax.random.split(key)
        expected_output = jax.random.uniform(next_key, ())

        def without_decorator():
            return jax.random.uniform(base.next_rng_key(), ())

        without_decorator = transform.transform(without_decorator)
        without_decorator_out = without_decorator.apply(None,
                                                        unrelated_key).item()

        def with_decorator():
            with base.with_rng(key):
                return jax.random.uniform(base.next_rng_key(), ())

        with_decorator = transform.transform(with_decorator)
        with_decorator_out = with_decorator.apply(None, unrelated_key).item()

        self.assertNotEqual(without_decorator_out, expected_output)
        self.assertEqual(with_decorator_out, expected_output)

    def test_new_context(self):
        """A fresh context collects no params, initial state or state."""
        with base.new_context() as ctx:
            pass
        self.assertEmpty(ctx.collect_params())
        self.assertEmpty(ctx.collect_initial_state())
        self.assertEmpty(ctx.collect_state())

    def test_context_copies_input(self):
        """`new_context` must not mutate the params/state dicts it is given."""
        before = {"~": {"w": jnp.array(1.)}}
        with base.new_context(params=before, state=before) as ctx:
            base.get_parameter("w", [], init=jnp.ones)
            base.set_state("w", jnp.array(2.))
        self.assertEqual(ctx.collect_params(), {"~": {"w": jnp.array(1.)}})
        self.assertIsNot(ctx.collect_initial_state(), before)
        self.assertEqual(ctx.collect_initial_state(), before)
        self.assertEqual(ctx.collect_state(), {"~": {"w": jnp.array(2.)}})
        # The caller's dict still holds the original value.
        self.assertEqual(before, {"~": {"w": jnp.array(1.)}})

    def test_without_state_raises_if_state_used_on_apply(self):
        """A state-setting fn wrapped in `without_state` must raise."""
        f = lambda: base.set_state("~", 1)
        f = transform.without_state(transform.transform_with_state(f))
        rng = jax.random.PRNGKey(42)
        # NOTE(review): both calls share one `assertRaisesRegex` block, so if
        # `init` already raises, `apply` is never reached -- confirm which
        # call this test is really meant to exercise.
        with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
            params = f.init(rng)
            f.apply(params, rng)

    def test_running_init(self):
        """`running_init` is truthy inside `init` and falsy inside `apply`."""
        recorded = []
        f = transform.transform(
            lambda: recorded.append(transform.running_init()))
        f.init(None)
        f.apply({}, None)
        init_value, apply_value = recorded  # pylint: disable=unbalanced-tuple-unpacking
        self.assertTrue(init_value)
        self.assertFalse(apply_value)

    def test_running_init_outside_transform(self):
        """`running_init` raises when called outside a transform."""
        with self.assertRaisesRegex(
                ValueError, "running_init.*used as part of.*transform"):
            transform.running_init()

    @parameterized.parameters(
        None, transform.without_apply_rng, transform.without_state,
        lambda f: transform.without_state(transform.without_apply_rng(f)))
    def test_persists_original_fn(self, without):
        """`get_original_fn` recovers the user fn through every wrapper.

        Args:
          without: Optional wrapper applied to the transformed function, or
            None to test the bare result of `transform`.
        """
        orig_f = lambda: None
        f = transform.transform(orig_f)
        if without is not None:
            f = without(f)
        self.assertIs(transform.get_original_fn(f), orig_f)
        self.assertIs(transform.get_original_fn(f.init), orig_f)
        self.assertIs(transform.get_original_fn(f.apply), orig_f)