Ejemplo n.º 1
0
  def test_without_apply_rng_output_type(self):
    """Checks that `without_apply_rng` returns the same kind of object it wraps.

    Wrapping the result of `transform_with_state` must give a
    `TransformedWithState`, and wrapping the result of `transform` must give a
    `Transformed`.
    """
    def f():
      w = base.get_parameter("w", [], init=jnp.zeros)
      return w

    # `f` stays bound to the plain function so that both cases transform the
    # original callable rather than the output of the previous transform.
    f_with_state = transform.without_apply_rng(
        transform.transform_with_state(f))
    self.assertIsInstance(f_with_state, transform.TransformedWithState)

    f_stateless = transform.without_apply_rng(transform.transform(f))
    self.assertIsInstance(f_stateless, transform.Transformed)
Ejemplo n.º 2
0
    def test_tree_update_stats(self):
        """Exercises `EMAParamsTree` with `update_stats` off and then on.

        With `update_stats=False` two consecutive applications on the same
        input must give equal outputs; with `update_stats=True` they must
        differ.
        """
        def linear_net():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        params = transform.transform(linear_net).init(random.PRNGKey(428))

        def frozen_ema(x):
            """This should never update internal stats."""
            return moving_averages.EMAParamsTree(0.2)(x, update_stats=False)

        frozen = transform.without_apply_rng(
            transform.transform_with_state(frozen_ema))
        _, ema_state = frozen.init(None, params)

        # Feed a doubled copy of the parameters through the EMA twice.
        doubled = tree.map_structure(lambda leaf: 2. * leaf, params)
        first_out, ema_state = frozen.apply(None, ema_state, doubled)
        second_out, ema_state = frozen.apply(None, ema_state, doubled)

        # With update_stats=False both outputs must agree leaf by leaf.
        leaf_pairs = zip(tree.flatten(second_out), tree.flatten(first_out))
        for later, earlier in leaf_pairs:
            self.assertEqual(later.shape, earlier.shape)
            np.testing.assert_allclose(later, earlier)

        def live_ema(x):
            """This will behave like normal."""
            return moving_averages.EMAParamsTree(0.2)(x, update_stats=True)

        live = transform.without_apply_rng(
            transform.transform_with_state(live_ema))
        _, ema_state = live.init(None, params)
        params, ema_state = live.apply(None, ema_state, params)

        # Feed a doubled copy of the (EMA'd) parameters through twice.
        doubled = tree.map_structure(lambda leaf: 2. * leaf, params)
        first_out, ema_state = live.apply(None, ema_state, doubled)
        second_out, ema_state = live.apply(None, ema_state, doubled)

        # With update_stats=True the two outputs must differ on every leaf.
        leaf_pairs = zip(tree.flatten(second_out), tree.flatten(first_out))
        for later, earlier in leaf_pairs:
            self.assertEqual(later.shape, earlier.shape)
            with self.assertRaisesRegex(AssertionError,
                                        "Not equal to tolerance"):
                np.testing.assert_allclose(later, earlier, atol=1e-6)
Ejemplo n.º 3
0
 def outer_fn(x):
     """Applies `Bias` to `x`, then maps a lifted `inner_fn` over axis 0.

     `inner_fn` is transformed without an apply rng, its init is lifted and
     called with the first row of the biased input, and its apply is vmapped
     over the leading axis with the parameters shared across rows.
     """
     assert x.ndim == 2
     biased = Bias()(x)
     inner_init, inner_apply = transform.without_apply_rng(
         transform.transform(inner_fn))
     # Parameters are created from a single row (biased[0]) ...
     inner_params = lift.lift(inner_init)(base.next_rng_key(), biased[0])
     # ... and reused for every row (in_axes=None for the params argument).
     apply_per_row = jax.vmap(inner_apply, in_axes=(None, 0))
     return apply_per_row(inner_params, biased)
Ejemplo n.º 4
0
    def test_ema_on_changing_data(self):
        """Checks that the EMA of changing parameters lags behind the change.

        The parameters are passed through `EMAParamsTree` once, doubled, and
        passed through again. The second EMA output must differ both from the
        previous parameters and from the doubled ones.
        """
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        init_fn, _ = transform.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            return moving_averages.EMAParamsTree(0.2)(x)

        init_fn, apply_fn = transform.without_apply_rng(
            transform.transform_with_state(g))
        _, params_state = init_fn(None, params)
        params, params_state = apply_fn(None, params_state, params)
        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn(None, params_state, changed_params)

        tree.assert_same_structure(changed_params, ema_params)
        leaves = zip(tree.flatten(params),
                     tree.flatten(changed_params),
                     tree.flatten(ema_params))
        for old_leaf, changed_leaf, ema_leaf in leaves:
            self.assertEqual(old_leaf.shape, ema_leaf.shape)
            self.assertEqual(changed_leaf.shape, ema_leaf.shape)
            # ema_params should be different from the changed params (the
            # stated intent) and from the previous params (the original check).
            for reference in (old_leaf, changed_leaf):
                with self.assertRaisesRegex(AssertionError,
                                            "Not equal to tolerance"):
                    np.testing.assert_allclose(reference, ema_leaf, atol=1e-6)
  def test_ema_is_identity_on_unchanged_data(self):
    """Feeding the EMA its own output ten times must keep returning the input."""
    def f(x):
      return moving_averages.ExponentialMovingAverage(0.5)(x)

    ema = transform.without_apply_rng(transform.transform_with_state(f))
    inp_value = 1.0
    _, ema_state = ema.init(None, inp_value)

    # As long as the input stays the same, so should the output.
    current = inp_value
    for _ in range(10):
      current, ema_state = ema.apply(None, ema_state, current)
      # rtol is loosened to 1e-6 because floating point error creeps up to
      # 1e-7 (the default).
      np.testing.assert_allclose(inp_value, current, rtol=1e-6)
  def test_ignore_regex(self):
    """Checks `EMAParamsTree(ignore_regex=".*w")` on a doubled Linear layer.

    After the parameters are doubled and passed through the EMA again, `w` in
    the output must equal the doubled `w`, while `b` must differ from the
    doubled `b` everywhere.
    """
    def linear_net():
      return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

    params = transform.transform(linear_net).init(random.PRNGKey(428))

    def ema_fn(x):
      return moving_averages.EMAParamsTree(0.2, ignore_regex=".*w")(x)

    ema = transform.without_apply_rng(transform.transform_with_state(ema_fn))
    _, ema_state = ema.init(None, params)
    params, ema_state = ema.apply(None, ema_state, params)

    # Double every parameter and run the EMA once more.
    doubled = tree.map_structure(lambda leaf: 2. * leaf, params)
    ema_out, ema_state = ema.apply(None, ema_state, doubled)

    # b should have changed ...
    b_differs = doubled.linear.b != ema_out.linear.b
    self.assertTrue(b_differs.all())
    # ... but w should be the same.
    w_matches = doubled.linear.w == ema_out.linear.w
    self.assertTrue(w_matches.all())
Ejemplo n.º 7
0
class TransformTest(parameterized.TestCase):
    """Tests for `transform`, `transform_with_state` and related `base` helpers."""

    @test_utils.transform_and_run
    def test_parameter_reuse(self):
        """Requesting the same parameter twice returns the same object."""
        w1 = base.get_parameter("w", [], init=jnp.zeros)
        w2 = base.get_parameter("w", [], init=jnp.zeros)
        self.assertIs(w1, w2)

    def test_params(self):
        """`init` returns the created parameter under the "~" collection."""
        def f():
            w = base.get_parameter("w", [], init=jnp.zeros)
            return w

        init_fn, _ = transform.transform(f)
        params = init_fn(None)
        self.assertEqual(params, {"~": {"w": jnp.zeros([])}})

    @test_utils.transform_and_run
    def test_naked_get_parameter(self):
        """Same check as `test_parameter_reuse`: repeated lookups share one object."""
        w1 = base.get_parameter("w", [], init=jnp.zeros)
        w2 = base.get_parameter("w", [], init=jnp.zeros)
        self.assertIs(w1, w2)

    def test_naked_parameter_in_tilde_collection(self):
        """Distinct parameter names are distinct objects, both stored under "~"."""
        def net():
            w1 = base.get_parameter("w1", [], init=jnp.zeros)
            w2 = base.get_parameter("w2", [], init=jnp.ones)
            self.assertIsNot(w1, w2)

        init_fn, _ = transform.transform(net)
        params = init_fn(None)
        self.assertEqual(params,
                         {"~": {
                             "w1": jnp.zeros([]),
                             "w2": jnp.ones([])
                         }})

    @parameterized.parameters((None, ), ({}, ), ({"~": {}}, ))
    def test_parameter_in_apply(self, params):
        """`apply` raises ValueError if a parameter is missing from `params`."""
        _, apply_fn = transform.transform(
            lambda: base.get_parameter("w", [], init=jnp.zeros))

        with self.assertRaisesRegex(
                ValueError, "parameters must be created as part of `init`"):
            apply_fn(params, None)

    @test_utils.transform_and_run(seed=None)
    def test_no_rng(self):
        """`next_rng_key` raises ValueError when run with `seed=None`."""
        with self.assertRaisesRegex(ValueError,
                                    "must pass a non-None PRNGKey"):
            base.next_rng_key()

    def test_invalid_rng(self):
        """`transform` rejects a non-key rng in both `init` and `apply`."""
        f = transform.transform(lambda: None)
        with self.assertRaisesRegex(
                ValueError,
                "Init must be called with an RNG as the first argument"):
            f.init("nonsense")
        with self.assertRaisesRegex(
                ValueError,
                "Apply must be called with an RNG as the second argument"):
            f.apply({}, "nonsense")

    def test_invalid_rng_state(self):
        """`transform_with_state` rejects a non-key rng in `init` and `apply`."""
        f = transform.transform_with_state(lambda: None)
        with self.assertRaisesRegex(
                ValueError,
                "Init must be called with an RNG as the first argument"):
            f.init("nonsense")
        with self.assertRaisesRegex(
                ValueError,
                "Apply must be called with an RNG as the third argument"):
            f.apply({}, {"x": {}}, "nonsense")

    @parameterized.parameters(transform.transform,
                              transform.transform_with_state)
    def test_invalid_rng_none_ignored(self, transform_fn):
        """Passing `None` as the rng to `init` and `apply` does not raise."""
        f = transform_fn(lambda: None)
        args = f.init(None)
        # `transform` returns params only; `transform_with_state` returns a
        # (params, state) tuple. Normalise so both can be splatted into apply.
        if not isinstance(args, tuple):
            args = (args, )
        f.apply(*args, None)

    def test_invalid_params(self):
        """`apply` raises TypeError when `params` is a string."""
        f = transform.transform_with_state(lambda: None)
        with self.assertRaisesRegex(TypeError,
                                    "params argument does not appear valid"):
            f.apply("z", {}, None)

    def test_invalid_state(self):
        """`apply` raises TypeError when `state` is a string."""
        f = transform.transform_with_state(lambda: None)
        with self.assertRaisesRegex(TypeError,
                                    "state argument does not appear valid"):
            f.apply({}, "z", None)

    def test_maybe_rng_no_transform(self):
        """`maybe_next_rng_key` raises ValueError outside of a transform."""
        with self.assertRaisesRegex(
                ValueError, "must be used as part of an `hk.transform`"):
            base.maybe_next_rng_key()

    @test_utils.transform_and_run(seed=None)
    def test_maybe_no_rng(self):
        """`maybe_next_rng_key` returns None when run with `seed=None`."""
        self.assertIsNone(base.maybe_next_rng_key())

    def test_maybe_rng_vs_not(self):
        """If we have an rng, then next_rng_key() == maybe_next_rng_key()."""
        rngs = []
        maybes = []

        @test_utils.transform_and_run
        def three():
            for _ in range(3):
                rngs.append(base.next_rng_key())

        @test_utils.transform_and_run
        def maybe_three():
            for _ in range(3):
                maybes.append(base.maybe_next_rng_key())

        three()
        maybe_three()
        # NOTE(review): six keys from a three-iteration loop implies
        # `transform_and_run` executes the body twice (presumably once for
        # init and once for apply) -- confirm against test_utils.
        self.assertLen(rngs, 6)
        self.assertTrue(jnp.all(jnp.array(rngs) == jnp.array(maybes)))

    def test_init_custom_creator(self):
        """A custom creator sees the request and can swap the initializer."""
        def zeros_creator(next_creator, shape, dtype, init, context):
            self.assertEqual(context.full_name, "~/w")
            self.assertEqual(shape, [])
            self.assertEqual(dtype, jnp.float32)
            self.assertEqual(init, jnp.ones)
            return next_creator(shape, dtype, jnp.zeros)

        def f():
            with base.custom_creator(zeros_creator):
                return base.get_parameter("w", [], init=jnp.ones)

        params = transform.transform(f).init(None)
        # The creator replaced `jnp.ones` with `jnp.zeros`.
        self.assertEqual(params, {"~": {"w": jnp.zeros([])}})

    def test_not_triggered_in_apply(self):
        """Custom creators run once per parameter in `init` and never in `apply`."""
        log = []

        def counting_creator(next_creator, shape, dtype, init, context):
            log.append(context.full_name)
            return next_creator(shape, dtype, init)

        def net():
            with base.custom_creator(counting_creator):
                for i in range(4):
                    base.get_parameter("w{}".format(i), [], init=jnp.zeros)

        init_fn, apply_fn = transform.transform(net)

        params = init_fn(None)
        self.assertEqual(log, ["~/w0", "~/w1", "~/w2", "~/w3"])

        # Reset in place: `counting_creator` closes over this same list.
        log.clear()
        apply_fn(params, None)
        self.assertEmpty(log)

    def test_nested_creators(self):
        """Nested custom creators are invoked outermost first."""
        log = []

        def logging_creator(log_msg):
            def _logging_creator(next_creator, shape, dtype, init, context):
                del context
                log.append(log_msg)
                return next_creator(shape, dtype, init)

            return _logging_creator

        def f():
            a, b, c = map(logging_creator, ["a", "b", "c"])
            with base.custom_creator(a), \
                 base.custom_creator(b), \
                 base.custom_creator(c):
                return base.get_parameter("w", [], init=jnp.ones)

        transform.transform(f).init(None)
        self.assertEqual(log, ["a", "b", "c"])

    def test_argspec(self):
        """`init` and `apply` expose their leading positional argument names."""
        init_fn, apply_fn = transform.transform_with_state(lambda: None)
        init_fn_spec = inspect.getfullargspec(init_fn)
        apply_fn_spec = inspect.getfullargspec(apply_fn)

        self.assertEqual(init_fn_spec.args, ["rng"])
        self.assertEqual(apply_fn_spec.args, ["params", "state", "rng"])

    def test_get_state_no_init_raises(self):
        """`get_state` without `init` raises in `init` and in `apply` if absent."""
        init_fn, apply_fn = transform.transform_with_state(
            lambda: base.get_state("i"))
        with self.assertRaisesRegex(ValueError, "set an init function"):
            init_fn(None)
        state = params = {"~": {}}
        with self.assertRaisesRegex(ValueError, "set an init function"):
            apply_fn(params, state, None)

    def test_get_state_no_shape_raises(self):
        """`get_state` with `init` but no shape/dtype raises ValueError."""
        init_fn, apply_fn = transform.transform_with_state(
            lambda: base.get_state("i", init=jnp.zeros))
        with self.assertRaisesRegex(ValueError, "provide shape and dtype"):
            init_fn(None)
        state = params = {"~": {}}
        with self.assertRaisesRegex(ValueError, "provide shape and dtype"):
            apply_fn(params, state, None)

    def test_get_state_no_init(self):
        """`get_state` without `init` works in `apply` when the state exists."""
        _, apply_fn = transform.transform_with_state(
            lambda: base.get_state("i"))
        for i in range(10):
            state_in = {"~": {"i": i}}
            _, state_out = apply_fn({}, state_in, None)
            # Reading state must not modify it.
            self.assertEqual(state_in, state_out)

    def test_set_then_get(self):
        """A value written with `set_state` is what a later `get_state` reads."""
        def net():
            base.set_state("i", 1)
            return base.get_state("i")

        init_fn, apply_fn = transform.transform_with_state(net)
        params, state = init_fn(None)
        self.assertEqual(state, {"~": {"i": 1}})

        for i in range(10):
            # Whatever state is passed in, the function overwrites it with 1.
            state_in = {"~": {"i": i}}
            y, state_out = apply_fn(params, state_in, None)
            self.assertEqual(y, 1)
            self.assertEqual(state_out, {"~": {"i": 1}})

    def test_stateful(self):
        """`init` returns the initial state; `apply` returns the updated state."""
        def f():
            for _ in range(10):
                count = base.get_state("count", (), jnp.int32, jnp.zeros)
                base.set_state("count", count + 1)
            return count

        init_fn, apply_fn = transform.transform_with_state(f)
        params, state = init_fn(None)
        self.assertEqual(state, {"~": {"count": 0}})
        _, state = apply_fn(params, state, None)
        self.assertEqual(state, {"~": {"count": 10}})

    def test_without_state(self):
        """`without_state` gives `init`/`apply` that take and return no state."""
        def f():
            w = base.get_parameter("w", [], init=jnp.zeros)
            return w

        init_fn, apply_fn = transform.without_state(
            transform.transform_with_state(f))
        params = init_fn(None)
        out = apply_fn(params, None)
        self.assertEqual(out, 0)

    def test_without_state_raises_if_state_used(self):
        """`without_state` raises in `init` if the function uses state."""
        def f():
            for _ in range(10):
                count = base.get_state("count", (), jnp.int32, jnp.zeros)
                base.set_state("count", count + 1)
            return count

        init_fn, _ = transform.without_state(transform.transform_with_state(f))

        with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
            init_fn(None)

    def test_inline_use(self):
        """The object returned by `transform` can be used via `.init`/`.apply`."""
        def f():
            w = base.get_parameter("w", [], init=jnp.zeros)
            return w

        f = transform.transform(f)

        rng = jax.random.PRNGKey(42)
        params = f.init(rng)
        w = f.apply(params, None)
        self.assertEqual(w, 0)

    def test_method(self):
        """A transformed method returns its object and reflects updated params."""
        obj = ObjectWithTransform()
        x = jnp.ones([])
        params = obj.forward.init(None, x)
        obj_out, y = obj.forward.apply(params, None, x)
        self.assertEqual(y, 1)
        self.assertIs(obj, obj_out)
        # `jax.tree_map` is deprecated (and removed in recent JAX releases);
        # `jax.tree_util.tree_map` is the long-standing equivalent.
        params = jax.tree_util.tree_map(lambda v: v + 1, params)
        obj_out, y = obj.forward.apply(params, None, x)
        self.assertEqual(y, 2)
        self.assertIs(obj, obj_out)

    def test_trampoline(self):
        """`ObjectWithTransform.trampoline` behaves like `forward` for init/apply."""
        obj = ObjectWithTransform()
        x = jnp.ones([])
        params = obj.trampoline.init(None, x)
        obj_out, y = obj.trampoline.apply(params, None, x)
        self.assertEqual(y, 1)
        self.assertIs(obj, obj_out)

    @parameterized.parameters((42, True), (42, False), (28, True), (28, False))
    def test_prng_sequence(self, seed, wrap_seed):
        """`PRNGSequence` yields the same keys as repeated manual splitting."""
        # Values using our sequence.
        key_or_seed = jax.random.PRNGKey(seed) if wrap_seed else seed
        key_seq = base.PRNGSequence(key_or_seed)
        seq_v1 = jax.random.normal(next(key_seq), [])
        seq_v2 = jax.random.normal(next(key_seq), [])
        # Generate values using manual splitting.
        key = jax.random.PRNGKey(seed)
        key, temp_key = jax.random.split(key)
        raw_v1 = jax.random.normal(temp_key, [])
        _, temp_key = jax.random.split(key)
        raw_v2 = jax.random.normal(temp_key, [])
        self.assertEqual(raw_v1, seq_v1)
        self.assertEqual(raw_v2, seq_v2)

    def test_prng_sequence_invalid_input(self):
        """`PRNGSequence` raises ValueError for a string argument."""
        with self.assertRaisesRegex(ValueError, "not a JAX PRNGKey"):
            base.PRNGSequence("nonsense")

    def test_prng_sequence_wrong_shape(self):
        """`PRNGSequence` raises ValueError for a stacked pair of keys."""
        with self.assertRaisesRegex(
                ValueError, "key did not have expected shape and/or dtype"):
            base.PRNGSequence(jax.random.split(jax.random.PRNGKey(42), 2))

    @parameterized.parameters(42, 28)
    def test_with_rng(self, seed):
        """Inside `with_rng(key)`, `next_rng_key` derives from `key`."""
        key = jax.random.PRNGKey(seed)
        unrelated_key = jax.random.PRNGKey(seed * 2 + 1)
        _, next_key = jax.random.split(key)
        expected_output = jax.random.uniform(next_key, ())

        def without_decorator():
            return jax.random.uniform(base.next_rng_key(), ())

        without_decorator = transform.transform(without_decorator)
        without_decorator_out = without_decorator.apply(None,
                                                        unrelated_key).item()

        def with_decorator():
            with base.with_rng(key):
                return jax.random.uniform(base.next_rng_key(), ())

        with_decorator = transform.transform(with_decorator)
        with_decorator_out = with_decorator.apply(None, unrelated_key).item()

        # Only the `with_rng` variant should ignore the key given to apply.
        self.assertNotEqual(without_decorator_out, expected_output)
        self.assertEqual(with_decorator_out, expected_output)

    def test_new_context(self):
        """A fresh, unused context collects no params and no state."""
        with base.new_context() as ctx:
            pass
        self.assertEmpty(ctx.collect_params())
        self.assertEmpty(ctx.collect_initial_state())
        self.assertEmpty(ctx.collect_state())

    def test_context_copies_input(self):
        """Mutating state inside a context leaves the caller's input untouched."""
        before = {"~": {"w": jnp.array(1.)}}
        with base.new_context(params=before, state=before) as ctx:
            base.get_parameter("w", [], init=jnp.ones)
            base.set_state("w", jnp.array(2.))
        self.assertEqual(ctx.collect_params(), {"~": {"w": jnp.array(1.)}})
        # The initial state is equal to, but not the same object as, the input.
        self.assertIsNot(ctx.collect_initial_state(), before)
        self.assertEqual(ctx.collect_initial_state(), before)
        self.assertEqual(ctx.collect_state(), {"~": {"w": jnp.array(2.)}})
        self.assertEqual(before, {"~": {"w": jnp.array(1.)}})

    def test_without_state_raises_if_state_used_on_apply(self):
        """`without_state` raises when the function calls `set_state`."""
        f = lambda: base.set_state("~", 1)
        f = transform.without_state(transform.transform_with_state(f))
        rng = jax.random.PRNGKey(42)
        # NOTE(review): both calls sit inside the assertion, so the test also
        # passes if `init` alone raises and `apply` is never reached --
        # confirm this is the intended coverage for an "on_apply" test.
        with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
            params = f.init(rng)
            f.apply(params, rng)

    def test_running_init(self):
        """`running_init()` is True during `init` and False during `apply`."""
        l = []
        f = transform.transform(lambda: l.append(transform.running_init()))
        f.init(None)
        f.apply({}, None)
        init_value, apply_value = l  # pylint: disable=unbalanced-tuple-unpacking
        self.assertEqual(init_value, True)
        self.assertEqual(apply_value, False)

    def test_running_init_outside_transform(self):
        """`running_init` raises ValueError outside of a transform."""
        with self.assertRaisesRegex(
                ValueError, "running_init.*used as part of.*transform"):
            transform.running_init()

    @parameterized.parameters(
        None, transform.without_apply_rng, transform.without_state,
        lambda f: transform.without_state(transform.without_apply_rng(f)))
    def test_persists_original_fn(self, without):
        """`get_original_fn` recovers the wrapped function through any wrapper."""
        orig_f = lambda: None
        f = transform.transform(orig_f)
        if without is not None:
            f = without(f)
        self.assertIs(transform.get_original_fn(f), orig_f)
        self.assertIs(transform.get_original_fn(f.init), orig_f)
        self.assertIs(transform.get_original_fn(f.apply), orig_f)
Ejemplo n.º 8
0
 def test_calling_with_duplicate_rng_kwarg(self, transform_fn):
   """Runs the by-name error check on a transformed function taking `rng`.

   NOTE(review): presumably the point is that the user function's own `rng`
   parameter collides with the transform's rng argument -- confirm against
   `assert_raises_by_name_error`.
   """
   def f(rng):
     del rng

   without_rng = transform.without_apply_rng(transform_fn(f))
   self.assert_raises_by_name_error(without_rng)