Example #1
    def test_nested_creators(self):
        log = []

        def logging_creator(log_msg):
            def _logging_creator(next_creator, shape, dtype, init, context):
                del context
                log.append(log_msg)
                return next_creator(shape, dtype, init)

            return _logging_creator

        def f():
            a, b, c = map(logging_creator, ["a", "b", "c"])
            with base.custom_creator(a), \
                 base.custom_creator(b), \
                 base.custom_creator(c):
                return base.get_parameter("w", [], init=jnp.ones)

        transform.transform(f).init(None)
        self.assertEqual(log, ["a", "b", "c"])
Example #2
    def test_optimize_rng_splitting(self):
        def f():
            k1 = base.next_rng_key()
            k2 = base.next_rng_key()
            return k1, k2

        key = jax.random.PRNGKey(42)
        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=1e-5)

        # With optimize_rng_use the keys returned should be equal to split(n).
        f_opt = transform.transform(random.optimize_rng_use(f), apply_rng=True)
        jax.tree_multimap(assert_allclose, f_opt.apply({}, key),
                          tuple(jax.random.split(key, 3))[1:])

        # Without optimize_rng_use the keys should be equivalent to splitting in a
        # loop.
        f = transform.transform(f, apply_rng=True)
        jax.tree_multimap(assert_allclose, f.apply({}, key),
                          tuple(split_for_n(key, 2)))
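The `split_for_n` helper above is not shown on this page; a minimal sketch
consistent with how the test uses it (a generator that repeatedly splits the
key and yields one fresh subkey per step, mirroring a loop over
`next_rng_key()`) would be:

    def split_for_n(key, n):
        # Split n times, yielding each subkey in turn.
        for _ in range(n):
            key, subkey = jax.random.split(key)
            yield subkey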
Example #3
  def test_jax_transformed_wrapper(self, jax_transform):
    # Happens in practice if someone asks for a `summary(pmap(train_step))`
    f = lambda: CallsOtherModule(MultipleParametersModule())()
    f = transform.transform(f)
    rng = jax.random.PRNGKey(42)
    if jax_transform == jax.pmap:
      rng = jnp.broadcast_to(rng, (1, *rng.shape))
    params = jax_transform(f.init)(rng)
    g = jax_transform(lambda params, rng: f.apply(params, rng))
    rows = tabulate_to_list(g, params, rng)
    self.assertNotEmpty(rows)
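This test assumes helpers defined elsewhere in the suite: two small modules
and a `tabulate_to_list` function that renders a forward function into summary
rows (e.g. via hk.experimental.eval_summary). Hypothetical stand-ins for the
two modules, consistent with how they are called above:

  class MultipleParametersModule(module.Module):
    def __call__(self):
      w = base.get_parameter("w", [], init=jnp.zeros)
      b = base.get_parameter("b", [], init=jnp.ones)
      return w + b

  class CallsOtherModule(module.Module):
    def __init__(self, other, name=None):
      super().__init__(name=name)
      self.other = other

    def __call__(self):
      return self.other()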
Example #4
  def test_inline_use(self):
    def f():
      w = base.get_parameter("w", [], init=jnp.zeros)
      return w

    f = transform.transform(f)

    rng = jax.random.PRNGKey(42)
    params = f.init(rng)
    w = f.apply(params)
    self.assertEqual(w, 0)
Example #5
    def test_with_empty_state(self):
        def f():
            w = base.get_parameter("w", [], init=jnp.zeros)
            return w

        init_fn, apply_fn = transform.with_empty_state(transform.transform(f))
        params, state = init_fn(None)
        self.assertEmpty(state)
        out, state = apply_fn(params, state, None)
        self.assertEqual(out, 0)
        self.assertEmpty(state)
Example #6
    def test_linear_without_bias_has_zero_in_null_space(self):
        def f():
            return basic.Linear(output_size=6, with_bias=False)(jnp.zeros(
                (5, 6)))

        init_fn, apply_fn = transform.transform(f)
        params = init_fn(random.PRNGKey(428))
        self.assertEqual(params["linear"]["w"].shape, (6, 6))
        self.assertNotIn("b", params["linear"])
        np.testing.assert_array_almost_equal(apply_fn(params, None),
                                             jnp.zeros((5, 6)))
Example #8
  def test_with_rng(self, seed):
    key = jax.random.PRNGKey(seed)
    unrelated_key = jax.random.PRNGKey(seed * 2 + 1)
    _, next_key = jax.random.split(key)
    expected_output = jax.random.uniform(next_key, ())

    def without_decorator():
      return jax.random.uniform(base.next_rng_key(), ())
    without_decorator = transform.transform(without_decorator, apply_rng=True)
    without_decorator_out = without_decorator.apply(None, unrelated_key).item()

    def with_decorator():
      with base.with_rng(key):
        return jax.random.uniform(base.next_rng_key(), ())

    with_decorator = transform.transform(with_decorator, apply_rng=True)
    with_decorator_out = with_decorator.apply(None, unrelated_key).item()

    self.assertNotEqual(without_decorator_out, expected_output)
    self.assertEqual(with_decorator_out, expected_output)
Example #9
 def test_bias_dims_custom(self):
   b, d1, d2, d3 = range(1, 5)
   def f():
     mod = bias.Bias(bias_dims=[1, 3])
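     # bias_dims=[1, 3]: dims 1 and 3 of the input keep their size, dim 2
     # broadcasts (size 1), and the leading batch dim is dropped entirely.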
     out = mod(jnp.ones([b, d1, d2, d3]))
     self.assertEqual(mod.bias_shape, (d1, 1, d3))
     return out
   f = transform.transform(f)
   params = f.init(None)
   out = f.apply(params, None)
   self.assertEqual(params["bias"]["b"].shape, (d1, 1, d3))
   self.assertEqual(out.shape, (b, d1, d2, d3))
Example #10
  def test_connect_conv_transpose_strided(self, n):
    def f():
      input_shape = [2] + [8]*n + [4]
      data = jnp.zeros(input_shape)
      net = conv.ConvNDTranspose(
          n, output_channels=3, kernel_shape=3, stride=3)
      return net(data)

    init_fn, apply_fn = transform.transform(f)
    out = apply_fn(init_fn(random.PRNGKey(428)))
    expected_output_shape = (2,) + (24,)*n + (3,)
    self.assertEqual(out.shape, expected_output_shape)
Example #11
    def test_bf16(self):
        """For all configurations, ensure bf16 outputs from bf16 inputs."""
        def f(x):
            ln = rms_norm.RMSNorm(axis=-1)
            return ln(x)

        fwd = transform.transform(f)
        data = jnp.zeros([2, 3, 4, 5], dtype=jnp.bfloat16)
        params = fwd.init(jax.random.PRNGKey(428), data)
        bf16_params = jax.tree_map(lambda t: t.astype(jnp.bfloat16), params)
        self.assertEqual(
            fwd.apply(bf16_params, None, data).dtype, jnp.bfloat16)
Example #12
    def test_lift_naming_semantics(self, inner_module):
        @transform.transform
        def fn(x):
            return with_transparent_lift(inner_module)(x)

        x = jnp.ones([10, 10])
        params_with_lift = fn.init(None, x)
        params_without_lift = transform.transform(inner_module).init(None, x)
        jax.tree_map(self.assertAlmostEqual, params_with_lift,
                     params_without_lift)

        fn.apply(params_with_lift, None, x)
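`with_transparent_lift` is assumed to be a local helper. A minimal sketch,
assuming it transforms the module function and transparently lifts the inner
init into the outer transform (which is what the parameter-naming assertion
requires):

    def with_transparent_lift(module_fn):
        def wrapped(x):
            inner = transform.transform(module_fn)
            # transparent_lift registers the inner params without adding an
            # extra name scope, so names match the un-lifted init exactly.
            params = lift.transparent_lift(inner.init)(None, x)
            return inner.apply(params, None, x)
        return wrapped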
Example #13
  def test_dilated_conv(self, n):
    input_shape = [2] + [16]*n + [4]

    def f():
      data = jnp.zeros(input_shape)
      net = conv.ConvND(n, output_channels=3, kernel_shape=3, rate=3)
      return net(data)

    init_fn, apply_fn = transform.transform(f)
    out = apply_fn(init_fn(random.PRNGKey(428)))
    expected_output_shape = (2,) + (16,)*n + (3,)
    self.assertEqual(out.shape, expected_output_shape)
Example #14
  def test_connect_conv_transpose_valid(self, n):
    def f():
      input_shape = [2] + [16]*n + [4]
      data = jnp.zeros(input_shape)
      net = conv.ConvNDTranspose(
          n, output_channels=3, kernel_shape=3, padding="VALID")
      return net(data)

    init_fn, apply_fn = transform.transform(f)
    out = apply_fn(init_fn(random.PRNGKey(428)), None)
    expected_output_shape = (2,) + (18,)*n + (3,)
    self.assertEqual(out.shape, expected_output_shape)
Example #15
            def __call__(self, x):
                x += base.get_parameter("a", shape=[10, 10], init=jnp.zeros)

                def inner_fn(x):
                    return InnerModule(name="inner")(x)

                inner_transformed = transform.transform(inner_fn)
                inner_params = lift.transparent_lift(inner_transformed.init)(
                    base.next_rng_key(), x)
                x = inner_transformed.apply(inner_params, base.next_rng_key(),
                                            x)
                return x
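The snippet above is the `__call__` method of an outer module; `InnerModule`
is assumed to be defined elsewhere. A minimal hypothetical stand-in that
creates one parameter for the lifted init to register:

    class InnerModule(module.Module):
        def __call__(self, x):
            return x + base.get_parameter("b", shape=x.shape, init=jnp.zeros)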
Example #16
  def test_connect_conv_transpose_channels_first(self, n):
    def f():
      input_shape = [2, 4] + [16]*n
      data = jnp.zeros(input_shape)
      net = conv.ConvNDTranspose(
          n, output_channels=3, kernel_shape=3, data_format="channels_first")
      return net(data)

    init_fn, apply_fn = transform.transform(f)
    out = apply_fn(init_fn(random.PRNGKey(428)))
    expected_output_shape = (2, 3) + (16,)*n
    self.assertEqual(out.shape, expected_output_shape)
Example #17
    def test_lift_with_scan(self):
        def inner_fn(x):
            x *= base.get_parameter("w", shape=x.shape, init=jnp.zeros)
            return x

        class Outer(module.Module):
            def __init__(self, allow_reuse):
                super().__init__()
                self._allow_reuse = allow_reuse

            def __call__(self, carry, x):
                x += base.get_parameter("w", shape=[], init=jnp.zeros)

                inner = transform.transform(inner_fn)
                keys = (base.next_rng_key()
                        if transform.running_init() else None)
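                # Under scan the lifted inner params get registered more
                # than once; allow_reuse=False turns the repeat into an error.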
                params = lift.lift(inner.init,
                                   allow_reuse=self._allow_reuse)(keys, x)
                return carry, inner.apply(params, None, x)

        def model(x, *, allow_reuse):
            return stateful.scan(Outer(allow_reuse), (), x)

        rng = jax.random.PRNGKey(42)
        data = np.zeros((4, 3, 2))

        with self.subTest(name="allow_reuse"):
            init, apply = transform.transform(
                lambda x: model(x, allow_reuse=True))

            params = init(rng, data)
            _, out = apply(params, None, data)
            np.testing.assert_equal(out, np.zeros_like(data))

        with self.subTest(name="disallow_reuse"):
            init, _ = transform.transform(
                lambda x: model(x, allow_reuse=False))

            with self.assertRaisesRegex(ValueError, "Key '.*' already exists"):
                _ = init(rng, data)
Example #18
    def test_layer_stack_multi_args(self):
        """Compare layers_stack to the equivalent unrolled stack.

    Similar to `test_layer_stack`, but use a function that takes more than one
    argument.
    """
        num_layers = 20

        def inner_fn(x, y):
            x_out = x + basic.Linear(100, name="linear1")(y)
            y_out = y + basic.Linear(100, name="linear2")(x)
            return x_out, y_out

        def outer_fn_unrolled(x, y):
            for _ in range(num_layers):
                x, y = inner_fn(x, y)
            return x, y

        def outer_fn_layer_stack(x, y):
            stack = layer_stack.layer_stack(num_layers)(inner_fn)
            return stack(x, y)

        unrolled_fn = transform.transform(outer_fn_unrolled)
        layer_stack_fn = transform.transform(outer_fn_layer_stack)

        x = jax.random.uniform(jax.random.PRNGKey(0), [10, 256, 100])
        y = jax.random.uniform(jax.random.PRNGKey(1), [10, 256, 100])

        rng_init = jax.random.PRNGKey(42)

        params = layer_stack_fn.init(rng_init, x, y)

        sliced_params = _slice_layers_params(params)

        unrolled_x, unrolled_y = unrolled_fn.apply(sliced_params, None, x, y)
        layer_stack_x, layer_stack_y = layer_stack_fn.apply(params, None, x, y)

        np.testing.assert_allclose(unrolled_x, layer_stack_x, atol=1e-3)
        np.testing.assert_allclose(unrolled_y, layer_stack_y, atol=1e-3)
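`_slice_layers_params` is a helper defined elsewhere in the suite. A
hypothetical sketch of what it must do, assuming the stacked params use the
same module names as the unrolled ones: split each stacked parameter
(leading axis = num_layers) into the per-layer names the unrolled network
creates ("linear1", "linear1_1", ..., "linear2_19"):

    def _slice_layers_params(layers_params):
        sliced = {}
        for mod_name, bundle in layers_params.items():
            num_layers = next(iter(bundle.values())).shape[0]
            for i in range(num_layers):
                name = mod_name if i == 0 else f"{mod_name}_{i}"
                sliced[name] = {k: v[i] for k, v in bundle.items()}
        return sliced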
Example #19
  def test_local_stats(self, resnet_v2, bottleneck):
    def forward_fn(image):
      model = resnet.ResNet([1, 1, 1, 1], 10,
                            resnet_v2=resnet_v2,
                            bottleneck=bottleneck)
      return model(image, is_training=False, test_local_stats=True)

    forward = transform.transform(forward_fn, apply_rng=True)
    rng = jax.random.PRNGKey(42)
    image = jnp.ones([2, 64, 64, 3])
    params = forward.init(rng, image)
    logits = forward.apply(params, None, image)
    self.assertEqual(logits.shape, (2, 10))
Example #20
    def test_naked_parameter_in_tilde_collection(self):
        def net():
            w1 = base.get_parameter("w1", [], init=jnp.zeros)
            w2 = base.get_parameter("w2", [], init=jnp.ones)
            self.assertIsNot(w1, w2)

        init_fn, _ = transform.transform(net)
        params = init_fn(None)
        self.assertEqual(params,
                         {"~": {
                             "w1": jnp.zeros([]),
                             "w2": jnp.ones([])
                         }})
Example #21
    def test_filter(self):
        init_fn, _ = transform.transform(get_net)
        params = init_fn(jax.random.PRNGKey(428), jnp.ones((1, 1)))

        second_layer_params = filtering.filter(
            lambda module_name, *_: module_name == "second_layer", params)
        self.assertEqual(get_names(second_layer_params),
                         set(["second_layer/w", "second_layer/b"]))

        biases = filtering.filter(lambda module_name, name, _: name == "b",
                                  params)  # pytype: disable=wrong-arg-types
        self.assertEqual(get_names(biases),
                         set(["first_layer/b", "second_layer/b"]))
Example #22
    def test_init_custom_creator(self):
        def zeros_creator(next_creator, shape, dtype, init, context):
            self.assertEqual(context.full_name, "~/w")
            self.assertEqual(shape, [])
            self.assertEqual(dtype, jnp.float32)
            self.assertEqual(init, jnp.ones)
            return next_creator(shape, dtype, jnp.zeros)

        def f():
            with base.custom_creator(zeros_creator):
                return base.get_parameter("w", [], init=jnp.ones)

        params = transform.transform(f).init(None)
        self.assertEqual(params, {"~": {"w": jnp.zeros([])}})
Example #23
    def test_tree_update_stats(self):
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        init_fn, _ = transform.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            """This should never update internal stats."""
            return moving_averages.EMAParamsTree(0.2)(x, update_stats=False)

        init_fn, apply_fn_g = transform.without_apply_rng(
            transform.transform_with_state(g))
        _, params_state = init_fn(None, params)

        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn_g(None, params_state,
                                              changed_params)
        ema_params2, params_state = apply_fn_g(None, params_state,
                                               changed_params)

        # ema_params should be the same as ema_params2 with update_stats=False!
        for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
            self.assertEqual(p1.shape, p2.shape)
            np.testing.assert_allclose(p1, p2)

        def h(x):
            """This will behave like normal."""
            return moving_averages.EMAParamsTree(0.2)(x, update_stats=True)

        init_fn, apply_fn_h = transform.without_apply_rng(
            transform.transform_with_state(h))
        _, params_state = init_fn(None, params)
        params, params_state = apply_fn_h(None, params_state, params)

        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn_h(None, params_state,
                                              changed_params)
        ema_params2, params_state = apply_fn_h(None, params_state,
                                               changed_params)

        # ema_params should now differ from ema_params2, since update_stats=True.
        for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
            self.assertEqual(p1.shape, p2.shape)
            with self.assertRaisesRegex(AssertionError,
                                        "Not equal to tolerance"):
                np.testing.assert_allclose(p1, p2, atol=1e-6)
Example #24
    def test_bf16(self, create_scale, create_offset, use_fast_variance):
        """For all configurations, ensure bf16 outputs from bf16 inputs."""
        def f(x):
            ln = layer_norm.LayerNorm(axis=-1,
                                      create_scale=create_scale,
                                      create_offset=create_offset,
                                      use_fast_variance=use_fast_variance)
            return ln(x)

        fwd = transform.transform(f)
        data = jnp.zeros([2, 3, 4, 5], dtype=jnp.bfloat16)
        params = fwd.init(jax.random.PRNGKey(428), data)
        bf16_params = jax.tree_map(lambda t: t.astype(jnp.bfloat16), params)
        self.assertEqual(
            fwd.apply(bf16_params, None, data).dtype, jnp.bfloat16)
Example #25
  def test_partitioning(self):

    init_fn, _ = transform.transform(get_net)
    params = init_fn(jax.random.PRNGKey(428), jnp.ones((1, 1)))

    # parse by layer
    first_layer_params, second_layer_params = filtering.partition(
        lambda module_name, *_: module_name == "first_layer",
        params)
    self.assertEqual(
        get_names(first_layer_params),
        set(["first_layer/w", "first_layer/b"]))
    self.assertEqual(
        get_names(second_layer_params),
        set(["second_layer/w", "second_layer/b"]))

    # parse by variable type
    weights, biases = filtering.partition(
        lambda module_name, name, _: name == "w",
        params)  # pytype: disable=wrong-arg-types
    self.assertEqual(
        get_names(weights),
        set(["first_layer/w", "second_layer/w"]))
    self.assertEqual(
        get_names(biases),
        set(["first_layer/b", "second_layer/b"]))

    # Compose regexes
    regex = compile_regex(["first_layer.*", ".*w"])
    matching, not_matching = filtering.partition(
        lambda module_name, name, _: regex.match(f"{module_name}/{name}"),
        params)
    self.assertEqual(
        get_names(matching),
        set(["first_layer/w", "first_layer/b", "second_layer/w"]))
    self.assertEqual(
        get_names(not_matching),
        set(["second_layer/b"]))

    matching, not_matching = filtering.partition(
        lambda mod_name, name, _: mod_name == "first_layer" and name != "w",
        params)
    self.assertEqual(
        get_names(matching),
        set(["first_layer/b"]))
    self.assertEqual(
        get_names(not_matching),
        set(["first_layer/w", "second_layer/w", "second_layer/b"]))
Example #26
    def test_precision(self, precision):
        def f(x):
            return basic.Linear(1)(x, precision=precision)

        f = transform.transform(f)
        rng = jax.random.PRNGKey(42)
        x = np.ones([1, 1])
        params = f.init(rng, x)
        c = jax.xla_computation(lambda x: f.apply(params, None, x))(x)
        hlo = c.as_hlo_text()
        op_line = next(l for l in hlo.split("\n") if "dot(" in l)
        if precision is not None and precision != jax.lax.Precision.DEFAULT:
            name = str(precision).lower()
            self.assertRegex(op_line, f"operand_precision={{{name},{name}}}")
        else:
            self.assertNotIn("operand_precision", op_line)
Example #27
    def test_precision(self, precision, cls):
        def f(x):
            net = cls(2, output_channels=3, kernel_shape=3, padding="VALID")
            return net(x, precision=precision)

        f = transform.transform(f)
        rng = jax.random.PRNGKey(42)
        x = jnp.zeros([2, 16, 16, 4])
        params = f.init(rng, x)
        c = jax.xla_computation(lambda x: f.apply(params, None, x))(x)
        hlo = c.as_hlo_text()
        op_line = next(l for l in hlo.split("\n") if "convolution(" in l)
        if precision is not None and precision != jax.lax.Precision.DEFAULT:
            name = str(precision).lower()
            self.assertRegex(op_line, f"operand_precision={{{name},{name}}}")
        else:
            self.assertNotIn("operand_precision", op_line)
Example #28
  def test_sn_naming_scheme(self):
    sn_name = "this_is_a_wacky_but_valid_name"
    linear_name = "so_is_this"

    def f():
      return basic.Linear(output_size=2, name=linear_name)(jnp.zeros([6, 6]))

    init_fn, _ = transform.transform(f)
    params = init_fn(random.PRNGKey(428))

    def g(x):
      return spectral_norm.SNParamsTree(ignore_regex=".*b", name=sn_name)(x)
    init_fn, _ = transform.transform_with_state(g)
    _, params_state = init_fn(random.PRNGKey(428), params)

    expected_sn_states = [
        "{}/{}__{}".format(sn_name, linear_name, s) for s in ["w"]]
    self.assertSameElements(expected_sn_states, params_state.keys())
Example #29
  def test_ema_naming_scheme(self):
    ema_name = "this_is_a_wacky_but_valid_name"
    linear_name = "so_is_this"

    def f():
      return basic.Linear(output_size=2, name=linear_name)(jnp.zeros([6]))

    init_fn, _ = transform.transform(f)
    params = init_fn(random.PRNGKey(428))

    def g(x):
      return moving_averages.EMAParamsTree(0.2, name=ema_name)(x)
    init_fn, _ = transform.transform_with_state(g)
    _, params_state = init_fn(None, params)

    expected_ema_states = [
        "{}/{}__{}".format(ema_name, linear_name, s) for s in ["w", "b"]]
    self.assertEqual(set(expected_ema_states), set(params_state.keys()))
Example #30
  def test_computation_padding_same(self, with_bias):
    def f():
      data = np.ones([1, 3, 3, 1])
      net = conv.Conv2DTranspose(
          output_channels=1,
          kernel_shape=3,
          padding="SAME",
          with_bias=with_bias,
          **create_constant_initializers(1.0, 1.0, with_bias))
      return net(data)

    init_fn, apply_fn = transform.transform(f)
    out = apply_fn(init_fn(random.PRNGKey(428)))
    expected_out = np.array([[4, 6, 4], [6, 9, 6], [4, 6, 4]])
    if with_bias:
      expected_out += 1
    expected_out = np.expand_dims(np.atleast_3d(expected_out), axis=0)
    np.testing.assert_allclose(out, expected_out, rtol=1e-5)
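`create_constant_initializers` is assumed to build constant-initializer kwargs
for the conv constructor; a minimal sketch (using Haiku's
initializers.Constant), including b_init only when a bias is created:

  def create_constant_initializers(w_value, b_value, with_bias):
    inits = {"w_init": initializers.Constant(w_value)}
    if with_bias:
      inits["b_init"] = initializers.Constant(b_value)
    return inits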