def test_nested_creators(self):
  log = []

  def logging_creator(log_msg):
    def _logging_creator(next_creator, shape, dtype, init, context):
      del context
      log.append(log_msg)
      return next_creator(shape, dtype, init)
    return _logging_creator

  def f():
    a, b, c = map(logging_creator, ["a", "b", "c"])
    with base.custom_creator(a), \
         base.custom_creator(b), \
         base.custom_creator(c):
      return base.get_parameter("w", [], init=jnp.ones)

  transform.transform(f).init(None)
  self.assertEqual(log, ["a", "b", "c"])

def test_optimize_rng_splitting(self):
  def f():
    k1 = base.next_rng_key()
    k2 = base.next_rng_key()
    return k1, k2

  key = jax.random.PRNGKey(42)
  assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

  # With optimize_rng_use the keys returned should be equal to split(n).
  f_opt = transform.transform(random.optimize_rng_use(f), apply_rng=True)
  jax.tree_multimap(assert_allclose,
                    f_opt.apply({}, key),
                    tuple(jax.random.split(key, 3))[1:])

  # Without optimize_rng_use the keys should be equivalent to splitting in a
  # loop.
  f = transform.transform(f, apply_rng=True)
  jax.tree_multimap(assert_allclose,
                    f.apply({}, key),
                    tuple(split_for_n(key, 2)))

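# The test above relies on a `split_for_n` helper defined elsewhere in the test
# module. A minimal sketch of what it could look like, assuming it mirrors how
# `next_rng_key` splits the key in a loop and yields one subkey per call (an
# assumption; the real helper may differ):
def split_for_n(key, n):
  for _ in range(n):
    key, subkey = jax.random.split(key)
    yield subkey
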
def test_jax_transformed_wrapper(self, jax_transform):
  # Happens in practice if someone asks for a `summary(pmap(train_step))`.
  f = lambda: CallsOtherModule(MultipleParametersModule())()
  f = transform.transform(f)
  rng = jax.random.PRNGKey(42)
  if jax_transform == jax.pmap:
    rng = jnp.broadcast_to(rng, (1, *rng.shape))
  params = jax_transform(f.init)(rng)
  g = jax_transform(lambda params, rng: f.apply(params, rng))
  rows = tabulate_to_list(g, params, rng)
  self.assertNotEmpty(rows)

def test_inline_use(self):
  def f():
    w = base.get_parameter("w", [], init=jnp.zeros)
    return w

  f = transform.transform(f)

  rng = jax.random.PRNGKey(42)
  params = f.init(rng)
  w = f.apply(params)
  self.assertEqual(w, 0)

def test_with_empty_state(self):
  def f():
    w = base.get_parameter("w", [], init=jnp.zeros)
    return w

  init_fn, apply_fn = transform.with_empty_state(transform.transform(f))
  params, state = init_fn(None)
  self.assertEmpty(state)
  out, state = apply_fn(params, state, None)
  self.assertEqual(out, 0)
  self.assertEmpty(state)

def test_linear_without_bias_has_zero_in_null_space(self):
  def f():
    return basic.Linear(output_size=6, with_bias=False)(jnp.zeros((5, 6)))

  init_fn, apply_fn = transform.transform(f)
  params = init_fn(random.PRNGKey(428))
  self.assertEqual(params["linear"]["w"].shape, (6, 6))
  self.assertNotIn("b", params["linear"])
  np.testing.assert_array_almost_equal(apply_fn(params, None),
                                       jnp.zeros((5, 6)))

def test_linear_without_bias_has_zero_in_null_space(self):
  def f():
    return basic.Linear(output_size=6, with_bias=False)(jnp.zeros((5, 6)))

  init_fn, apply_fn = transform.transform(f)
  params = init_fn(random.PRNGKey(428))
  self.assertEqual(params.linear.w.shape, (6, 6))
  self.assertFalse(hasattr(params.linear, "b"))
  np.testing.assert_array_almost_equal(apply_fn(params), jnp.zeros((5, 6)))

def test_with_rng(self, seed):
  key = jax.random.PRNGKey(seed)
  unrelated_key = jax.random.PRNGKey(seed * 2 + 1)
  _, next_key = jax.random.split(key)
  expected_output = jax.random.uniform(next_key, ())

  def without_decorator():
    return jax.random.uniform(base.next_rng_key(), ())
  without_decorator = transform.transform(without_decorator, apply_rng=True)
  without_decorator_out = without_decorator.apply(None, unrelated_key).item()

  def with_decorator():
    with base.with_rng(key):
      return jax.random.uniform(base.next_rng_key(), ())
  with_decorator = transform.transform(with_decorator, apply_rng=True)
  with_decorator_out = with_decorator.apply(None, unrelated_key).item()

  self.assertNotEqual(without_decorator_out, expected_output)
  self.assertEqual(with_decorator_out, expected_output)

def test_bias_dims_custom(self):
  b, d1, d2, d3 = range(1, 5)

  def f():
    mod = bias.Bias(bias_dims=[1, 3])
    out = mod(jnp.ones([b, d1, d2, d3]))
    self.assertEqual(mod.bias_shape, (d1, 1, d3))
    return out

  f = transform.transform(f)
  params = f.init(None)
  out = f.apply(params, None)
  self.assertEqual(params["bias"]["b"].shape, (d1, 1, d3))
  self.assertEqual(out.shape, (b, d1, d2, d3))

def test_connect_conv_transpose_strided(self, n):
  def f():
    input_shape = [2] + [8] * n + [4]
    data = jnp.zeros(input_shape)
    net = conv.ConvNDTranspose(
        n, output_channels=3, kernel_shape=3, stride=3)
    return net(data)

  init_fn, apply_fn = transform.transform(f)
  out = apply_fn(init_fn(random.PRNGKey(428)))
  expected_output_shape = (2,) + (24,) * n + (3,)
  self.assertEqual(out.shape, expected_output_shape)

def test_bf16(self):
  """For all configurations, ensure bf16 outputs from bf16 inputs."""
  def f(x):
    ln = rms_norm.RMSNorm(axis=-1)
    return ln(x)

  fwd = transform.transform(f)
  data = jnp.zeros([2, 3, 4, 5], dtype=jnp.bfloat16)
  params = fwd.init(jax.random.PRNGKey(428), data)
  bf16_params = jax.tree_map(lambda t: t.astype(jnp.bfloat16), params)
  self.assertEqual(fwd.apply(bf16_params, None, data).dtype, jnp.bfloat16)

def test_lift_naming_semantics(self, inner_module):
  @transform.transform
  def fn(x):
    return with_transparent_lift(inner_module)(x)

  x = jnp.ones([10, 10])
  params_with_lift = fn.init(None, x)
  params_without_lift = transform.transform(inner_module).init(None, x)
  jax.tree_map(self.assertAlmostEqual, params_with_lift, params_without_lift)
  fn.apply(params_with_lift, None, x)

def test_dilated_conv(self, n):
  input_shape = [2] + [16] * n + [4]

  def f():
    data = jnp.zeros(input_shape)
    net = conv.ConvND(n, output_channels=3, kernel_shape=3, rate=3)
    return net(data)

  init_fn, apply_fn = transform.transform(f)
  out = apply_fn(init_fn(random.PRNGKey(428)))
  expected_output_shape = (2,) + (16,) * n + (3,)
  self.assertEqual(out.shape, expected_output_shape)

def test_connect_conv_transpose_valid(self, n):
  def f():
    input_shape = [2] + [16] * n + [4]
    data = jnp.zeros(input_shape)
    net = conv.ConvNDTranspose(
        n, output_channels=3, kernel_shape=3, padding="VALID")
    return net(data)

  init_fn, apply_fn = transform.transform(f)
  out = apply_fn(init_fn(random.PRNGKey(428)), None)
  expected_output_shape = (2,) + (18,) * n + (3,)
  self.assertEqual(out.shape, expected_output_shape)

def __call__(self, x):
  x += base.get_parameter("a", shape=[10, 10], init=jnp.zeros)

  def inner_fn(x):
    return InnerModule(name="inner")(x)

  inner_transformed = transform.transform(inner_fn)
  inner_params = lift.transparent_lift(inner_transformed.init)(
      base.next_rng_key(), x)
  x = inner_transformed.apply(inner_params, base.next_rng_key(), x)
  return x

def test_connect_conv_transpose_channels_first(self, n):
  def f():
    input_shape = [2, 4] + [16] * n
    data = jnp.zeros(input_shape)
    net = conv.ConvNDTranspose(
        n, output_channels=3, kernel_shape=3, data_format="channels_first")
    return net(data)

  init_fn, apply_fn = transform.transform(f)
  out = apply_fn(init_fn(random.PRNGKey(428)))
  expected_output_shape = (2, 3) + (16,) * n
  self.assertEqual(out.shape, expected_output_shape)

def test_lift_with_scan(self):
  def inner_fn(x):
    x *= base.get_parameter("w", shape=x.shape, init=jnp.zeros)
    return x

  class Outer(module.Module):

    def __init__(self, allow_reuse):
      super().__init__()
      self._allow_reuse = allow_reuse

    def __call__(self, carry, x):
      x += base.get_parameter("w", shape=[], init=jnp.zeros)
      inner = transform.transform(inner_fn)
      keys = base.next_rng_key() if transform.running_init() else None
      params = lift.lift(inner.init, allow_reuse=self._allow_reuse)(keys, x)
      return carry, inner.apply(params, None, x)

  def model(x, *, allow_reuse):
    return stateful.scan(Outer(allow_reuse), (), x)

  rng = jax.random.PRNGKey(42)
  data = np.zeros((4, 3, 2))

  with self.subTest(name="allow_reuse"):
    init, apply = transform.transform(lambda x: model(x, allow_reuse=True))
    params = init(rng, data)
    _, out = apply(params, None, data)
    np.testing.assert_equal(out, np.zeros_like(data))

  with self.subTest(name="disallow_reuse"):
    init, _ = transform.transform(lambda x: model(x, allow_reuse=False))
    with self.assertRaisesRegex(ValueError, "Key '.*' already exists"):
      _ = init(rng, data)

def test_layer_stack_multi_args(self):
  """Compare layer_stack to the equivalent unrolled stack.

  Similar to `test_layer_stack`, but use a function that takes more than one
  argument.
  """
  num_layers = 20

  def inner_fn(x, y):
    x_out = x + basic.Linear(100, name="linear1")(y)
    y_out = y + basic.Linear(100, name="linear2")(x)
    return x_out, y_out

  def outer_fn_unrolled(x, y):
    for _ in range(num_layers):
      x, y = inner_fn(x, y)
    return x, y

  def outer_fn_layer_stack(x, y):
    stack = layer_stack.layer_stack(num_layers)(inner_fn)
    return stack(x, y)

  unrolled_fn = transform.transform(outer_fn_unrolled)
  layer_stack_fn = transform.transform(outer_fn_layer_stack)

  x = jax.random.uniform(jax.random.PRNGKey(0), [10, 256, 100])
  y = jax.random.uniform(jax.random.PRNGKey(1), [10, 256, 100])

  rng_init = jax.random.PRNGKey(42)
  params = layer_stack_fn.init(rng_init, x, y)

  sliced_params = _slice_layers_params(params)

  unrolled_x, unrolled_y = unrolled_fn.apply(sliced_params, None, x, y)
  layer_stack_x, layer_stack_y = layer_stack_fn.apply(params, None, x, y)

  np.testing.assert_allclose(unrolled_x, layer_stack_x, atol=1e-3)
  np.testing.assert_allclose(unrolled_y, layer_stack_y, atol=1e-3)

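# `_slice_layers_params` is a helper defined elsewhere in the test module. A
# rough sketch of what it might do, assuming the stacked parameters carry a
# leading per-layer dimension and that repeated modules in the unrolled network
# are named "linear1", "linear1_1", "linear1_2", ... (both are assumptions;
# the real helper may differ):
def _slice_layers_params(layers_params):
  sliced = {}
  for mod_name, bundle in layers_params.items():
    inner_name = mod_name.split("/")[-1]
    for param_name, value in bundle.items():
      for i in range(value.shape[0]):  # One slice per stacked layer.
        new_mod = inner_name if i == 0 else f"{inner_name}_{i}"
        sliced.setdefault(new_mod, {})[param_name] = value[i]
  return sliced
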
def test_local_stats(self, resnet_v2, bottleneck):
  def forward_fn(image):
    model = resnet.ResNet(
        [1, 1, 1, 1], 10, resnet_v2=resnet_v2, bottleneck=bottleneck)
    return model(image, is_training=False, test_local_stats=True)

  forward = transform.transform(forward_fn, apply_rng=True)
  rng = jax.random.PRNGKey(42)
  image = jnp.ones([2, 64, 64, 3])
  params = forward.init(rng, image)
  logits = forward.apply(params, None, image)
  self.assertEqual(logits.shape, (2, 10))

def test_naked_parameter_in_tilde_collection(self):
  def net():
    w1 = base.get_parameter("w1", [], init=jnp.zeros)
    w2 = base.get_parameter("w2", [], init=jnp.ones)
    self.assertIsNot(w1, w2)

  init_fn, _ = transform.transform(net)
  params = init_fn(None)
  self.assertEqual(params,
                   {"~": {"w1": jnp.zeros([]), "w2": jnp.ones([])}})

def test_filter(self):
  init_fn, _ = transform.transform(get_net)
  params = init_fn(jax.random.PRNGKey(428), jnp.ones((1, 1)))

  second_layer_params = filtering.filter(
      lambda module_name, *_: module_name == "second_layer", params)
  self.assertEqual(get_names(second_layer_params),
                   set(["second_layer/w", "second_layer/b"]))

  biases = filtering.filter(lambda module_name, name, _: name == "b", params)  # pytype: disable=wrong-arg-types
  self.assertEqual(get_names(biases),
                   set(["first_layer/b", "second_layer/b"]))

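# `get_net` and `get_names` are helpers defined elsewhere in the test module.
# Plausible sketches, assuming `get_net` builds the two-layer network the
# filtering tests expect and `get_names` flattens a params dict into
# "module/parameter" strings (both are assumptions, not the actual helpers):
def get_net(x):
  x = basic.Linear(1, name="first_layer")(x)
  x = basic.Linear(1, name="second_layer")(x)
  return x

def get_names(params):
  return {f"{mod}/{name}" for mod, bundle in params.items() for name in bundle}
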
def test_init_custom_creator(self):
  def zeros_creator(next_creator, shape, dtype, init, context):
    self.assertEqual(context.full_name, "~/w")
    self.assertEqual(shape, [])
    self.assertEqual(dtype, jnp.float32)
    self.assertEqual(init, jnp.ones)
    return next_creator(shape, dtype, jnp.zeros)

  def f():
    with base.custom_creator(zeros_creator):
      return base.get_parameter("w", [], init=jnp.ones)

  params = transform.transform(f).init(None)
  self.assertEqual(params, {"~": {"w": jnp.zeros([])}})

def test_tree_update_stats(self):
  def f():
    return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

  init_fn, _ = transform.transform(f)
  params = init_fn(random.PRNGKey(428))

  def g(x):
    """This should never update internal stats."""
    return moving_averages.EMAParamsTree(0.2)(x, update_stats=False)

  init_fn, apply_fn_g = transform.without_apply_rng(
      transform.transform_with_state(g))
  _, params_state = init_fn(None, params)

  # Let's modify our params.
  changed_params = tree.map_structure(lambda t: 2. * t, params)
  ema_params, params_state = apply_fn_g(None, params_state, changed_params)
  ema_params2, params_state = apply_fn_g(None, params_state, changed_params)

  # ema_params should be the same as ema_params2 with update_stats=False!
  for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
    self.assertEqual(p1.shape, p2.shape)
    np.testing.assert_allclose(p1, p2)

  def h(x):
    """This will behave like normal."""
    return moving_averages.EMAParamsTree(0.2)(x, update_stats=True)

  init_fn, apply_fn_h = transform.without_apply_rng(
      transform.transform_with_state(h))
  _, params_state = init_fn(None, params)
  params, params_state = apply_fn_h(None, params_state, params)

  # Let's modify our params.
  changed_params = tree.map_structure(lambda t: 2. * t, params)
  ema_params, params_state = apply_fn_h(None, params_state, changed_params)
  ema_params2, params_state = apply_fn_h(None, params_state, changed_params)

  # ema_params should now differ from ema_params2, since update_stats=True
  # updates the moving averages on every call.
  for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
    self.assertEqual(p1.shape, p2.shape)
    with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
      np.testing.assert_allclose(p1, p2, atol=1e-6)

def test_bf16(self, create_scale, create_offset, use_fast_variance):
  """For all configurations, ensure bf16 outputs from bf16 inputs."""
  def f(x):
    ln = layer_norm.LayerNorm(axis=-1,
                              create_scale=create_scale,
                              create_offset=create_offset,
                              use_fast_variance=use_fast_variance)
    return ln(x)

  fwd = transform.transform(f)
  data = jnp.zeros([2, 3, 4, 5], dtype=jnp.bfloat16)
  params = fwd.init(jax.random.PRNGKey(428), data)
  bf16_params = jax.tree_map(lambda t: t.astype(jnp.bfloat16), params)
  self.assertEqual(fwd.apply(bf16_params, None, data).dtype, jnp.bfloat16)

def test_partitioning(self):
  init_fn, _ = transform.transform(get_net)
  params = init_fn(jax.random.PRNGKey(428), jnp.ones((1, 1)))

  # Partition by layer.
  first_layer_params, second_layer_params = filtering.partition(
      lambda module_name, *_: module_name == "first_layer", params)
  self.assertEqual(get_names(first_layer_params),
                   set(["first_layer/w", "first_layer/b"]))
  self.assertEqual(get_names(second_layer_params),
                   set(["second_layer/w", "second_layer/b"]))

  # Partition by parameter name (weights vs. biases).
  weights, biases = filtering.partition(
      lambda module_name, name, _: name == "w", params)  # pytype: disable=wrong-arg-types
  self.assertEqual(get_names(weights),
                   set(["first_layer/w", "second_layer/w"]))
  self.assertEqual(get_names(biases),
                   set(["first_layer/b", "second_layer/b"]))

  # Compose regexes.
  regex = compile_regex(["first_layer.*", ".*w"])
  matching, not_matching = filtering.partition(
      lambda module_name, name, _: regex.match(f"{module_name}/{name}"),
      params)
  self.assertEqual(get_names(matching),
                   set(["first_layer/w", "first_layer/b", "second_layer/w"]))
  self.assertEqual(get_names(not_matching), set(["second_layer/b"]))

  matching, not_matching = filtering.partition(
      lambda mod_name, name, _: mod_name == "first_layer" and name != "w",
      params)
  self.assertEqual(get_names(matching), set(["first_layer/b"]))
  self.assertEqual(get_names(not_matching),
                   set(["first_layer/w", "second_layer/w", "second_layer/b"]))

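# `compile_regex` is a small helper from the test module. A minimal sketch,
# assuming it joins the patterns into a single fully anchored alternation
# (an assumption; the real helper may be written differently):
import re

def compile_regex(patterns):
  joined = "|".join(f"(?:{p})" for p in patterns)
  return re.compile(f"(?:{joined})$")  # `.match` already anchors the start.
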
def test_precision(self, precision):
  def f(x):
    return basic.Linear(1)(x, precision=precision)

  f = transform.transform(f)
  rng = jax.random.PRNGKey(42)
  x = np.ones([1, 1])
  params = f.init(rng, x)
  c = jax.xla_computation(lambda x: f.apply(params, None, x))(x)
  hlo = c.as_hlo_text()
  op_line = next(l for l in hlo.split("\n") if "dot(" in l)
  if precision is not None and precision != jax.lax.Precision.DEFAULT:
    name = str(precision).lower()
    self.assertRegex(op_line, f"operand_precision={{{name},{name}}}")
  else:
    self.assertNotIn("operand_precision", op_line)

def test_precision(self, precision, cls):
  def f(x):
    net = cls(2, output_channels=3, kernel_shape=3, padding="VALID")
    return net(x, precision=precision)

  f = transform.transform(f)
  rng = jax.random.PRNGKey(42)
  x = jnp.zeros([2, 16, 16, 4])
  params = f.init(rng, x)
  c = jax.xla_computation(lambda x: f.apply(params, None, x))(x)
  hlo = c.as_hlo_text()
  op_line = next(l for l in hlo.split("\n") if "convolution(" in l)
  if precision is not None and precision != jax.lax.Precision.DEFAULT:
    name = str(precision).lower()
    self.assertRegex(op_line, f"operand_precision={{{name},{name}}}")
  else:
    self.assertNotIn("operand_precision", op_line)

def test_sn_naming_scheme(self):
  sn_name = "this_is_a_wacky_but_valid_name"
  linear_name = "so_is_this"

  def f():
    return basic.Linear(output_size=2, name=linear_name)(jnp.zeros([6, 6]))

  init_fn, _ = transform.transform(f)
  params = init_fn(random.PRNGKey(428))

  def g(x):
    return spectral_norm.SNParamsTree(ignore_regex=".*b", name=sn_name)(x)

  init_fn, _ = transform.transform_with_state(g)
  _, params_state = init_fn(random.PRNGKey(428), params)

  expected_sn_states = [
      "{}/{}__{}".format(sn_name, linear_name, s) for s in ["w"]]
  self.assertSameElements(expected_sn_states, params_state.keys())

def test_ema_naming_scheme(self):
  ema_name = "this_is_a_wacky_but_valid_name"
  linear_name = "so_is_this"

  def f():
    return basic.Linear(output_size=2, name=linear_name)(jnp.zeros([6]))

  init_fn, _ = transform.transform(f)
  params = init_fn(random.PRNGKey(428))

  def g(x):
    return moving_averages.EMAParamsTree(0.2, name=ema_name)(x)

  init_fn, _ = transform.transform_with_state(g)
  _, params_state = init_fn(None, params)

  expected_ema_states = [
      "{}/{}__{}".format(ema_name, linear_name, s) for s in ["w", "b"]]
  self.assertEqual(set(expected_ema_states), set(params_state.keys()))

def test_computation_padding_same(self, with_bias):
  def f():
    data = np.ones([1, 3, 3, 1])
    net = conv.Conv2DTranspose(
        output_channels=1, kernel_shape=3, padding="SAME",
        with_bias=with_bias,
        **create_constant_initializers(1.0, 1.0, with_bias))
    return net(data)

  init_fn, apply_fn = transform.transform(f)
  out = apply_fn(init_fn(random.PRNGKey(428)))

  expected_out = np.array([[4, 6, 4],
                           [6, 9, 6],
                           [4, 6, 4]])
  if with_bias:
    expected_out += 1

  expected_out = np.expand_dims(np.atleast_3d(expected_out), axis=0)
  np.testing.assert_allclose(out, expected_out, rtol=1e-5)

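# `create_constant_initializers` is another helper from the test module. A
# minimal sketch, assuming it maps onto Haiku's `initializers.Constant` and
# only provides a bias initializer when the layer has a bias (an assumption;
# the real helper may differ):
def create_constant_initializers(w, b, with_bias):
  if with_bias:
    return {"w_init": initializers.Constant(w),
            "b_init": initializers.Constant(b)}
  return {"w_init": initializers.Constant(w)}
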