def test_parameter_in_apply(self, params):
  _, apply_fn = base.transform(
      lambda: base.get_parameter("w", [], init=jnp.zeros))
  with self.assertRaisesRegex(
      ValueError, "parameters must be created as part of `init`"):
    apply_fn(params)

def test_ema_on_changing_data(self):
  def f():
    return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

  init_fn, _ = base.transform(f)
  params = init_fn(random.PRNGKey(428))

  def g(x):
    return moving_averages.EMAParamsTree(0.2)(x)

  init_fn, apply_fn = base.without_apply_rng(base.transform_with_state(g))
  _, params_state = init_fn(None, params)
  params, params_state = apply_fn(None, params_state, params)
  # Let's modify our params.
  changed_params = tree.map_structure(lambda t: 2. * t, params)
  ema_params, params_state = apply_fn(None, params_state, changed_params)

  # ema_params should be different from changed_params!
  tree.assert_same_structure(changed_params, ema_params)
  for p1, p2 in zip(tree.flatten(params), tree.flatten(ema_params)):
    self.assertEqual(p1.shape, p2.shape)
    with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
      np.testing.assert_allclose(p1, p2, atol=1e-6)

def test_connect_conv_same(self, n):
  batch_size = 2
  input_shape = [16] * n
  input_shape_ = [batch_size] + input_shape + [4]
  state_shape_ = [batch_size] + input_shape + [3]
  data_ = jnp.zeros(input_shape_)
  state_init_ = (jnp.zeros(state_shape_), jnp.zeros(state_shape_))

  def f(data, state_init):
    net = recurrent.ConvNDLSTM(
        n, input_shape=input_shape, output_channels=3, kernel_shape=3)
    return net(data, state_init)

  init_fn, apply_fn = base.transform(f)
  params = init_fn(random.PRNGKey(42), data_, state_init_)
  out, state = apply_fn(params, data_, state_init_)
  expected_output_shape = (batch_size,) + tuple(input_shape) + (3,)
  self.assertEqual(out.shape, expected_output_shape)
  self.assertEqual(state[0].shape, expected_output_shape)
  self.assertEqual(state[1].shape, expected_output_shape)

def test_functionalize(self):
  def inner_fn(x):
    assert x.ndim == 1
    return Bias()(x)

  def outer_fn(x):
    assert x.ndim == 2
    x = Bias()(x)
    inner = base.without_apply_rng(base.transform_with_state(inner_fn))
    inner_p, inner_s = lift.lift(inner.init)(base.next_rng_key(), x[0])
    vmap_inner = jax.vmap(inner.apply, in_axes=(None, None, 0))
    return vmap_inner(inner_p, inner_s, x)[0]

  key = jax.random.PRNGKey(428)
  init_key, apply_key = jax.random.split(key)
  data = np.zeros((3, 2))

  outer = base.transform(outer_fn, apply_rng=True)
  outer_params = outer.init(init_key, data)
  self.assertEqual(outer_params, {
      "bias": {"b": np.ones(())},
      "lifted_init_fn": {"sentinel": np.zeros(())},
      "lifted_init_fn/bias": {"b": np.ones(())},
  })

  out = outer.apply(outer_params, apply_key, data)
  np.testing.assert_equal(out, 2 * np.ones((3, 2)))

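# `Bias` is a fixture defined elsewhere in the test file, not shown in this
# excerpt. A minimal sketch, consistent with the parameter shapes and outputs
# asserted in test_functionalize above (the real fixture may differ); it
# assumes `module.Module` is the base class imported by this file:
class Bias(module.Module):

  def __call__(self, x):
    # A scalar bias initialized to ones, so two stacked Bias layers map
    # zeros -> 2 * ones, as asserted above.
    b = base.get_parameter("b", [], init=jnp.ones)
    return x + b
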
def test_used_inside_transform(self):
  log = []

  def counting_creator(next_creator, name, shape, dtype, init):
    log.append(name)
    return next_creator(name, shape, dtype, init)

  def net():
    with base.custom_creator(counting_creator):
      return MultipleForwardMethods()()

  init_fn, apply_fn = base.transform(net)

  params = init_fn(None)
  self.assertEqual(log, [
      "multiple_forward_methods/~/scalar_module/w",        # __init__
      "multiple_forward_methods/scalar_module/w",          # __call__
      "multiple_forward_methods/~encode/scalar_module/w",  # encode
      "multiple_forward_methods/~decode/scalar_module/w",  # decode
  ])

  del log[:]
  apply_fn(params)
  self.assertEmpty(log)

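# `ScalarModule` and `MultipleForwardMethods` are fixtures defined elsewhere
# in the test file. Minimal sketches, consistent with the parameter names
# asserted in test_used_inside_transform and test_params_nested (the real
# fixtures may differ):
class ScalarModule(module.Module):

  def __call__(self):
    # A single scalar parameter "w", zero-initialized.
    return base.get_parameter("w", [], init=jnp.zeros)


class MultipleForwardMethods(module.Module):

  def __init__(self, name=None):
    super().__init__(name=name)
    self.ctor_out = ScalarModule()()  # Scoped as "~/scalar_module".

  def __call__(self):
    x = self.ctor_out
    x += ScalarModule()()             # Scoped as "scalar_module".
    x += self.decode(self.encode())
    return x

  def encode(self):
    return ScalarModule()()           # Scoped as "~encode/scalar_module".

  def decode(self, x):
    return x + ScalarModule()()       # Scoped as "~decode/scalar_module".
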
def test_convolution(self, with_bias):
  def f():
    data = np.ones([1, 10, 10, 3])
    data[0, :, :, 1] += 1
    data[0, :, :, 2] += 2
    data = jnp.array(data)
    net = depthwise_conv.DepthwiseConv2D(
        channel_multiplier=1,
        kernel_shape=3,
        stride=1,
        padding="VALID",
        with_bias=with_bias,
        data_format="channels_last",
        **create_constant_initializers(1.0, 1.0, with_bias))
    return net(data)

  init_fn, apply_fn = base.transform(f)
  out = apply_fn(init_fn(random.PRNGKey(428)))
  self.assertEqual(out.shape, (1, 8, 8, 3))
  self.assertLen(np.unique(out[0, :, :, 0]), 1)
  self.assertLen(np.unique(out[0, :, :, 1]), 1)
  self.assertLen(np.unique(out[0, :, :, 2]), 1)
  if with_bias:
    self.assertEqual(np.unique(out[0, :, :, 0])[0], 1 * 3.0 * 3.0 + 1)
    self.assertEqual(np.unique(out[0, :, :, 1])[0], 2 * 3.0 * 3.0 + 1)
    self.assertEqual(np.unique(out[0, :, :, 2])[0], 3 * 3.0 * 3.0 + 1)
  else:
    self.assertEqual(np.unique(out[0, :, :, 0])[0], 1 * 3.0 * 3.0)
    self.assertEqual(np.unique(out[0, :, :, 1])[0], 2 * 3.0 * 3.0)
    self.assertEqual(np.unique(out[0, :, :, 2])[0], 3 * 3.0 * 3.0)

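# `create_constant_initializers` is a helper shared by the convolution tests
# but not shown in this excerpt. A minimal sketch, assuming
# `initializers.Constant` from this library (the real helper may differ):
def create_constant_initializers(w, b, with_bias):
  # Returns constant w_init/b_init kwargs; b_init only when a bias is used.
  if with_bias:
    return {"w_init": initializers.Constant(w),
            "b_init": initializers.Constant(b)}
  else:
    return {"w_init": initializers.Constant(w)}
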
def test_computation_padding_valid(self, with_bias):
  expected_out = [[9, 9, 9],
                  [9, 9, 9],
                  [9, 9, 9]]

  def f():
    data = jnp.ones([1, 5, 5, 1])
    net = conv.Conv2D(
        output_channels=1,
        kernel_shape=3,
        stride=1,
        padding="VALID",
        with_bias=with_bias,
        **create_constant_initializers(1.0, 1.0, with_bias))
    return net(data)

  init_fn, apply_fn = base.transform(f)
  out = apply_fn(init_fn(random.PRNGKey(428)))
  self.assertEqual(out.shape, (1, 3, 3, 1))
  out = np.squeeze(out, axis=(0, 3))

  expected_out = np.asarray(expected_out, dtype=float)
  if with_bias:
    expected_out += 1

  np.testing.assert_equal(out, expected_out)

def test_bias_dims_negative_out_of_order(self):
  def f():
    mod = bias.Bias(bias_dims=[-1, -2])
    mod(jnp.ones([1, 2, 3]))
    self.assertEqual(mod.bias_shape, (2, 3))

  params = base.transform(f).init(None)
  self.assertEqual(params["bias"]["b"].shape, (2, 3))

def test_computation_padding_same(self, with_bias):
  expected_out = np.asarray([
      9, 13, 13, 13, 9,
      13, 19, 19, 19, 13,
      13, 19, 19, 19, 13,
      13, 19, 19, 19, 13,
      9, 13, 13, 13, 9,

      13, 19, 19, 19, 13,
      19, 28, 28, 28, 19,
      19, 28, 28, 28, 19,
      19, 28, 28, 28, 19,
      13, 19, 19, 19, 13,

      13, 19, 19, 19, 13,
      19, 28, 28, 28, 19,
      19, 28, 28, 28, 19,
      19, 28, 28, 28, 19,
      13, 19, 19, 19, 13,

      13, 19, 19, 19, 13,
      19, 28, 28, 28, 19,
      19, 28, 28, 28, 19,
      19, 28, 28, 28, 19,
      13, 19, 19, 19, 13,

      9, 13, 13, 13, 9,
      13, 19, 19, 19, 13,
      13, 19, 19, 19, 13,
      13, 19, 19, 19, 13,
      9, 13, 13, 13, 9,
  ], dtype=float).reshape((5, 5, 5))
  if not with_bias:
    expected_out -= 1

  def f():
    data = jnp.ones([1, 5, 5, 5, 1])
    net = conv.Conv3D(
        output_channels=1,
        kernel_shape=3,
        stride=1,
        padding="SAME",
        with_bias=with_bias,
        **create_constant_initializers(1.0, 1.0, with_bias))
    return net(data)

  init_fn, apply_fn = base.transform(f)
  out = apply_fn(init_fn(random.PRNGKey(428)))
  self.assertEqual(out.shape, (1, 5, 5, 5, 1))
  out = np.squeeze(out, axis=(0, 4))
  np.testing.assert_equal(out, expected_out)

def testIncorrectN(self, n):
  init_fn, _ = base.transform(
      lambda: conv.ConvND(n, output_channels=1, kernel_shape=3))
  with self.assertRaisesRegex(
      ValueError,
      "only support convolution operations for num_spatial_dims=1, 2 or 3"):
    init_fn(None)

def test_computation_padding_valid(self, with_bias):
  expected_out = np.asarray([
      1, 2, 3, 2, 1,
      2, 4, 6, 4, 2,
      3, 6, 9, 6, 3,
      2, 4, 6, 4, 2,
      1, 2, 3, 2, 1,

      2, 4, 6, 4, 2,
      4, 8, 12, 8, 4,
      6, 12, 18, 12, 6,
      4, 8, 12, 8, 4,
      2, 4, 6, 4, 2,

      3, 6, 9, 6, 3,
      6, 12, 18, 12, 6,
      9, 18, 27, 18, 9,
      6, 12, 18, 12, 6,
      3, 6, 9, 6, 3,

      2, 4, 6, 4, 2,
      4, 8, 12, 8, 4,
      6, 12, 18, 12, 6,
      4, 8, 12, 8, 4,
      2, 4, 6, 4, 2,

      1, 2, 3, 2, 1,
      2, 4, 6, 4, 2,
      3, 6, 9, 6, 3,
      2, 4, 6, 4, 2,
      1, 2, 3, 2, 1,
  ], dtype=float).reshape((5, 5, 5))
  if with_bias:
    expected_out += 1

  def f():
    data = jnp.ones([1, 3, 3, 3, 1])
    net = conv.Conv3DTranspose(
        output_channels=1,
        kernel_shape=3,
        stride=1,
        padding="VALID",
        with_bias=with_bias,
        **create_constant_initializers(1.0, 1.0, with_bias))
    return net(data)

  init_fn, apply_fn = base.transform(f)
  out = apply_fn(init_fn(random.PRNGKey(428)))
  self.assertEqual(out.shape, (1, 5, 5, 5, 1))
  out = np.squeeze(out, axis=(0, 4))
  np.testing.assert_equal(out, expected_out)

def test_flatten(self):
  def f():
    return reshape.Flatten(preserve_dims=2)(jnp.zeros([2, 3, 4, 5]))

  init_fn, apply_fn = base.transform(f)
  params = init_fn(None)
  self.assertEqual(apply_fn(params).shape, (2, 3, 20))

def test_tree_update_stats(self):
  def f():
    return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

  init_fn, _ = base.transform(f)
  params = init_fn(random.PRNGKey(428))

  def g(x):
    """This should never update internal stats."""
    return moving_averages.EMAParamsTree(0.2)(x, update_stats=False)

  init_fn, apply_fn_g = base.transform(g, state=True)
  _, params_state = init_fn(None, params)

  # Let's modify our params.
  changed_params = tree.map_structure(lambda t: 2. * t, params)
  ema_params, params_state = apply_fn_g(None, params_state, changed_params)
  ema_params2, params_state = apply_fn_g(None, params_state, changed_params)

  # ema_params should equal ema_params2, since update_stats=False!
  for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
    self.assertEqual(p1.shape, p2.shape)
    np.testing.assert_allclose(p1, p2)

  def h(x):
    """This will behave like normal."""
    return moving_averages.EMAParamsTree(0.2)(x, update_stats=True)

  init_fn, apply_fn_h = base.transform(h, state=True)
  _, params_state = init_fn(None, params)
  params, params_state = apply_fn_h(None, params_state, params)

  # Let's modify our params.
  changed_params = tree.map_structure(lambda t: 2. * t, params)
  ema_params, params_state = apply_fn_h(None, params_state, changed_params)
  ema_params2, params_state = apply_fn_h(None, params_state, changed_params)

  # ema_params should differ from ema_params2, since update_stats=True!
  for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
    self.assertEqual(p1.shape, p2.shape)
    with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
      np.testing.assert_allclose(p1, p2, atol=1e-6)

def test_params_nested(self):
  init_fn, _ = base.transform(
      lambda: MultipleForwardMethods(name="outer")())  # pylint: disable=unnecessary-lambda
  params = init_fn(None)
  self.assertEqual(params,
                   {"outer/~/scalar_module": {"w": jnp.zeros([])},
                    "outer/scalar_module": {"w": jnp.zeros([])},
                    "outer/~encode/scalar_module": {"w": jnp.zeros([])},
                    "outer/~decode/scalar_module": {"w": jnp.zeros([])}})

def test_invalid_rng(self):
  f = base.transform(lambda: None, apply_rng=True)
  with self.assertRaisesRegex(
      ValueError, "Init must be called with an RNG as the first argument"):
    f.init("nonsense")
  with self.assertRaisesRegex(
      ValueError, "Apply must be called with an RNG as the second argument"):
    f.apply({}, "nonsense")

def test_invalid_multiple_wildcard(self):
  def f():
    mod = reshape.Reshape(output_shape=[-1, -1])
    return mod(np.ones([1, 2, 3]))

  init_fn, _ = base.transform(f)
  with self.assertRaises(ValueError):
    init_fn(None)

def test_sequential(self):
  def f():
    seq = basic.Sequential([basic.Linear(2), jax.nn.relu])
    return seq(jnp.zeros([3, 2]))

  init_fn, apply_fn = base.transform(f)
  params = init_fn(random.PRNGKey(428))
  self.assertEqual(apply_fn(params).shape, (3, 2))

def test_params(self):
  def f():
    w = base.get_parameter("w", [], init=jnp.zeros)
    return w

  init_fn, _ = base.transform(f)
  params = init_fn(None)
  self.assertEqual(params, {"~": {"w": jnp.zeros([])}})

def test_invalid_type(self):
  def f():
    mod = reshape.Reshape(output_shape=[7, "string"])
    return mod(np.ones([1, 2, 3]))

  init_fn, _ = base.transform(f)
  with self.assertRaises(TypeError):
    init_fn(None)

def test_linear_rank3(self):
  def f():
    return basic.Linear(output_size=2)(jnp.zeros((2, 5, 6)))

  init_fn, apply_fn = base.transform(f)
  params = init_fn(random.PRNGKey(428))
  self.assertEqual(params.linear.w.shape, (6, 2))
  self.assertEqual(params.linear.b.shape, (2,))
  self.assertEqual(apply_fn(params).shape, (2, 5, 2))

def test_linear_without_bias_has_zero_in_null_space(self):
  def f():
    return basic.Linear(output_size=6, with_bias=False)(jnp.zeros((5, 6)))

  init_fn, apply_fn = base.transform(f)
  params = init_fn(random.PRNGKey(428))
  self.assertEqual(params.linear.w.shape, (6, 6))
  self.assertFalse(hasattr(params.linear, "b"))
  np.testing.assert_array_almost_equal(apply_fn(params), jnp.zeros((5, 6)))

def test_reshape(self, preserve_dims, expected_output_shape):
  def f(inputs):
    return reshape.Reshape(
        output_shape=(-1, D), preserve_dims=preserve_dims)(inputs)

  init_fn, apply_fn = base.transform(f)
  params = init_fn(None, jnp.ones([B, H, W, C, D]))
  outputs = apply_fn(params, np.ones([B, H, W, C, D]))
  self.assertEqual(outputs.shape, expected_output_shape)

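# `B`, `H`, `W`, `C` and `D` are module-level size constants defined outside
# this excerpt. Any fixed positive integers work; hypothetical example values:
B, H, W, C, D = 2, 3, 4, 5, 6
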
def test_transforms_with_filter(self):
  # Note, to make sense of this test (for x = 1):
  #
  #   out = (w0 + b0) * w1 + b1
  #       = w0 * w1 + b0 * w1 + b1
  #   dout/dw0 = w1
  #   dout/dw1 = w0 + b0
  #
  # with w0 = 1.0, b0 = 1.5, w1 = 3.0, b1 = 4.5.
  init_fn, apply_fn = base.transform(get_net)
  inputs = jnp.ones((1, 1))
  params = init_fn(jax.random.PRNGKey(428), inputs)

  df_fn = jax_fn_with_filter(
      jax_fn=jax.grad,
      f=apply_fn,
      predicate=lambda module_name, name, _: name == "w")
  df = df_fn(params, inputs)
  self.assertEqual(
      to_set(df),
      set([("first_layer/w", (3.0,)), ("second_layer/w", (2.5,))]))

  fn = jax_fn_with_filter(
      jax_fn=jax.value_and_grad,
      f=apply_fn,
      predicate=lambda module_name, name, _: name == "w")
  v = fn(params, inputs)
  self.assertEqual(v[0], jnp.array([12.0]))
  self.assertEqual(to_set(df), to_set(v[1]))

  def get_stacked_net(x):
    y = get_net(x)
    return jnp.stack([y, 2.0 * y])

  _, apply_fn = base.transform(get_stacked_net)
  jf_fn = jax_fn_with_filter(
      jax_fn=jax.jacobian,
      f=apply_fn,
      predicate=lambda module_name, name, _: name == "w")
  jf = jf_fn(params, inputs)
  self.assertEqual(
      to_set(jf),
      set([("first_layer/w", (3.0, 6.0)), ("second_layer/w", (2.5, 5.0))]))

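# `get_net`, `to_set` and `jax_fn_with_filter` are fixtures defined elsewhere
# in the test file. Minimal sketches consistent with the assertions above;
# they assume `initializers.Constant` and the `filtering.partition`/
# `filtering.merge` helpers from this library (`hk.data_structures` in the
# public API), and reduce the network output to a scalar so that `jax.grad`
# is well defined. The real fixtures may differ:
def get_net(x):
  # out = (x * w0 + b0) * w1 + b1 with the constants from the note above.
  x = basic.Linear(1, name="first_layer",
                   w_init=initializers.Constant(1.0),
                   b_init=initializers.Constant(1.5))(x)
  x = basic.Linear(1, name="second_layer",
                   w_init=initializers.Constant(3.0),
                   b_init=initializers.Constant(4.5))(x)
  return jnp.sum(x)


def to_set(params):
  # Flattens {module: {name: value}} into {("module/name", values)} pairs.
  return {("{}/{}".format(mod, name), tuple(np.asarray(v).flatten().tolist()))
          for mod, bundle in params.items()
          for name, v in bundle.items()}


def jax_fn_with_filter(jax_fn, f, predicate):
  # Applies `jax_fn` (e.g. jax.grad) to `f`, differentiating only with
  # respect to the parameters selected by `predicate`.
  def wrapper(params, *args):
    selected, rest = filtering.partition(predicate, params)

    def inner(selected):
      return f(filtering.merge(selected, rest), *args)

    return jax_fn(inner)(selected)
  return wrapper
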
def test_naked_parameter_in_tilde_collection(self):
  def net():
    w1 = base.get_parameter("w1", [], init=jnp.zeros)
    w2 = base.get_parameter("w2", [], init=jnp.ones)
    self.assertIsNot(w1, w2)

  init_fn, _ = base.transform(net)
  params = init_fn(None)
  self.assertEqual(params,
                   {"~": {"w1": jnp.zeros([]), "w2": jnp.ones([])}})

def test_inline_use(self):
  def f():
    return ScalarModule()()

  f = base.transform(f)

  rng = jax.random.PRNGKey(42)
  params = f.init(rng)
  w = f.apply(params)
  self.assertEqual(w, 0)

def test_grad_and_jit(self):
  def f(x):
    g = stateful.grad(SquareModule())(x)
    return g

  x = jnp.array(3.)
  f = base.transform(f, state=True)
  params, state = jax.jit(f.init)(None, x)
  g, state = jax.jit(f.apply)(params, state, x)
  np.testing.assert_allclose(g, 2 * x, rtol=1e-3)

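# `SquareModule` is a fixture defined elsewhere in the test file. A minimal
# sketch, assuming it records its output in Haiku state (which is why the
# test above transforms with `state=True`); the real fixture may differ:
class SquareModule(module.Module):

  def __call__(self, x):
    assert x.ndim == 0
    y = x ** 2  # So stateful.grad returns dy/dx = 2 * x, as asserted above.
    base.set_state("y", y)
    return y
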
def test_unable_to_mutate_name(self):
  def mutates_name(next_creator, name, shape, dtype, init):
    next_creator(name + "_foo", shape, dtype, init)

  init_fn, _ = base.transform(
      lambda: base.get_parameter("w", [], init=jnp.ones))
  with self.assertRaisesRegex(ValueError,
                              "Modifying .*name.* not supported"):
    with base.custom_creator(mutates_name):
      init_fn(None)

def test_inline_use(self):
  def f():
    w = base.get_parameter("w", [], init=jnp.zeros)
    return w

  f = base.transform(f)

  rng = jax.random.PRNGKey(42)
  params = f.init(rng)
  w = f.apply(params)
  self.assertEqual(w, 0)

def test_with_rng(self, seed):
  key = jax.random.PRNGKey(seed)
  unrelated_key = jax.random.PRNGKey(seed * 2 + 1)
  _, next_key = jax.random.split(key)
  expected_output = jax.random.uniform(next_key, ())

  def without_decorator():
    return jax.random.uniform(base.next_rng_key(), ())

  without_decorator = base.transform(without_decorator, apply_rng=True)
  without_decorator_out = without_decorator.apply(None, unrelated_key).item()

  def with_decorator():
    with base.with_rng(key):
      return jax.random.uniform(base.next_rng_key(), ())

  with_decorator = base.transform(with_decorator, apply_rng=True)
  with_decorator_out = with_decorator.apply(None, unrelated_key).item()

  self.assertNotEqual(without_decorator_out, expected_output)
  self.assertEqual(with_decorator_out, expected_output)