def test_mutate_undefined_collection(self):
  def f(scope):
    scope.put_variable('state', 'test', 123)

  msg = r'Cannot update variable "test" in "/" because collection "state" is immutable.'
  with self.assertRaisesRegex(errors.ModifyScopeVariableError, msg):
    init(f, mutable='params')(random.PRNGKey(0))
def test_jit_cache(self):
  compiles = 0

  @lift.jit
  def f(scope, x):
    nonlocal compiles
    compiles += 1
    if (scope.is_mutable_collection('intermediates')
        and not scope.is_mutable_collection('params')):
      scope.put_variable('intermediates', 'x', x + 1)
    return nn.dense(scope, x, 1)

  x = np.ones((3, 2))
  _, params = init(f)(random.PRNGKey(0), x)
  init(f)(random.PRNGKey(0), x)
  self.assertEqual(compiles, 1)
  apply(f)(params, x)
  self.assertEqual(compiles, 2)  # apply should cause a compile
  apply(f)(params, x)
  self.assertEqual(compiles, 2)  # applying again should not

  # edge case where only the implicit return of the jitted function changes.
  # this should not use the previously cached apply.
  _, state = apply(f, mutable='intermediates')(params, x)
  self.assertEqual(compiles, 3)  # should trigger a new compile
  self.assertEqual(state['intermediates']['x'].sum(), 3 * 2 * 2)
def test_rng(self):
  def f(scope):
    self.assertTrue(scope.has_rng('params'))
    self.assertFalse(scope.has_rng('dropout'))
    rng = scope.make_rng('params')
    self.assertTrue(np.all(rng == random.fold_in(random.PRNGKey(0), 1)))

  init(f)(random.PRNGKey(0))
def test_mutate_undefined_collection(self):
  def f(scope):
    scope.put_variable('state', 'test', 123)

  msg = 'Trying to update variable "test" in "/" but collection "state" is immutable.'
  with self.assertRaisesWithLiteralMatch(ValueError, msg):
    init(f, mutable='params')(random.PRNGKey(0))
def test_aliasing(self):
  def f(scope):
    a = scope.push('a')

    def g(scopes, _):
      scope, a = scopes
      self.assertEqual(a.parent, scope)

    lift.vmap(g, variable_axes={}, split_rngs={})((scope, a), jnp.ones((1,)))

  init(f)(random.PRNGKey(0))
def test_auto_encoder_bind_method(self):
  ae = lambda scope, x: AutoEncoder3.create(
      scope, latents=2, features=4, hidden=3)(x)
  x = jnp.ones((1, 4))

  x_r, variables = init(ae)(random.PRNGKey(0), x)
  self.assertEqual(x.shape, x_r.shape)
  variable_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(variable_shapes, {
      'encode': {
          'hidden': {'kernel': (4, 3), 'bias': (3,)},
          'out': {'kernel': (3, 2), 'bias': (2,)},
      },
      'decode': {
          'hidden': {'kernel': (2, 3), 'bias': (3,)},
          'out': {'kernel': (3, 4), 'bias': (4,)},
      },
  })
def test_auto_encoder_hp_struct(self):
  ae = AutoEncoder(latents=2, features=4, hidden=3)
  x = jnp.ones((1, 4))

  x_r, variables = init(ae)(random.PRNGKey(0), x)
  self.assertEqual(x.shape, x_r.shape)
  variable_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(variable_shapes, {
      'encoder': {
          'hidden': {'kernel': (4, 3), 'bias': (3,)},
          'out': {'kernel': (3, 2), 'bias': (2,)},
      },
      'decoder': {
          'hidden': {'kernel': (2, 3), 'bias': (3,)},
          'out': {'kernel': (3, 4), 'bias': (4,)},
      },
  })
def test_attention(self):
  inputs = jnp.ones((2, 7, 16))
  model = partial(
      multi_head_dot_product_attention,
      num_heads=2,
      batch_axes=(0,),
      attn_fn=with_dropout(softmax_attn, 0.1, deterministic=False))

  rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
  y, variables = jax.jit(init(model))(rngs, inputs, inputs)
  variable_shapes = jax.tree_map(jnp.shape, variables['params'])
  self.assertEqual(y.shape, (2, 7, 16))
  self.assertEqual(unfreeze(variable_shapes), {
      'key': {'kernel': (2, 16, 8)},
      'value': {'kernel': (2, 16, 8)},
      'query': {'kernel': (2, 16, 8)},
      'out': {'bias': (2, 16), 'kernel': (2, 8, 16)},
  })
def test_explicit_dense(self):
  x = jnp.ones((1, 3))
  y, variables = init(explicit_mlp)(random.PRNGKey(0), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(y.shape, (1, 4))
  self.assertEqual(param_shapes, {
      'kernel': (3, 4),
      'bias': (4,),
  })
def test_tied_auto_encoder(self):
  ae = TiedAutoEncoder(latents=2, features=4)
  x = jnp.ones((1, ae.features))
  x_r, variables = init(ae)(random.PRNGKey(0), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(param_shapes, {
      'kernel': (4, 2),
  })
  self.assertEqual(x.shape, x_r.shape)
def test_init_from_decoder(self):
  ae = TiedAutoEncoder(latents=2, features=4)
  z = jnp.ones((1, ae.latents))
  x_r, variables = init(ae.decode)(random.PRNGKey(0), z)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(param_shapes, {
      'kernel': (4, 2),
  })
  self.assertEqual(x_r.shape, (1, 4))
def test_explicit_dense(self):
  x = jnp.ones((1, 4))
  y, variables = init(explicit_mlp)(random.PRNGKey(0), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(y.shape, (1, 1))
  self.assertEqual(param_shapes, {
      'dense_0': ExplicitDense((4, 3), (3,)),
      'dense_1': ExplicitDense((3, 1), (1,)),
  })
def test_flow(self):
  x = jnp.ones((1, 3))
  flow = StackFlow((DenseFlow(),) * 3)
  y, variables = init(flow.forward)(random.PRNGKey(0), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(y.shape, (1, 3))
  self.assertEqual(param_shapes, {
      '0': {'kernel': (3, 3), 'bias': (3,)},
      '1': {'kernel': (3, 3), 'bias': (3,)},
      '2': {'kernel': (3, 3), 'bias': (3,)},
  })

  x_restored = apply(flow.backward)(variables, y)
  self.assertTrue(jnp.allclose(x, x_restored))
def test_vmap_unshared(self):
  x = random.normal(random.PRNGKey(0), (1, 4))
  x = jnp.concatenate([x, x], 0)
  y, variables = init(mlp_vmap)(random.PRNGKey(1), x, share_params=False)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(param_shapes, {
      'hidden_0': {'kernel': (2, 4, 8), 'bias': (2, 8)},
      'out': {'kernel': (2, 8, 1), 'bias': (2, 1)},
  })
  self.assertEqual(y.shape, (2, 1))
  self.assertFalse(jnp.allclose(y[0], y[1]))
def test_weight_std(self):
  x = random.normal(random.PRNGKey(0), (1, 4))
  y, variables = init(mlp)(random.PRNGKey(1), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(param_shapes, {
      'hidden_0': {'kernel': (4, 8), 'bias': (8,)},
      'out': {'kernel': (8, 1), 'bias': (1,)},
  })
  self.assertEqual(y.shape, (1, 1))
  self.assertTrue(y.ravel() < 1.)

  y2 = apply(mlp)(variables, x)
  self.assertTrue(jnp.allclose(y, y2))
def test_semi_explicit_dense(self):
  x = jnp.ones((1, 4))
  y, variables = init(semi_explicit_mlp)(random.PRNGKey(0), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(y.shape, (1, 1))
  self.assertEqual(param_shapes, {
      'dense_0': {'kernel': (4, 3), 'bias': (3,)},
      'dense_1': {'kernel': (3, 1), 'bias': (1,)},
  })
def test_custom_vjp(self):
  x = random.normal(random.PRNGKey(0), (1, 4))
  y, variables = init(mlp_custom_grad)(random.PRNGKey(1), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  loss_fn = lambda p, x: jnp.mean(apply(mlp_custom_grad)(p, x) ** 2)
  grad = jax.grad(loss_fn)(variables, x)
  grad_shapes = unfreeze(jax.tree_map(jnp.shape, grad['params']))
  self.assertEqual(y.shape, (1, 1))
  expected_param_shapes = {
      'hidden_0': {'kernel': (4, 8), 'bias': (8,)},
      'out': {'kernel': (8, 1), 'bias': (1,)},
  }
  self.assertEqual(param_shapes, expected_param_shapes)
  self.assertEqual(grad_shapes, expected_param_shapes)
  for g in jax.tree_leaves(grad):
    self.assertTrue(np.all(g == np.sign(g)))
def init_with_output(
    fn: Callable[..., Any],
    module: Module,
    mutable: CollectionFilter = True
) -> Callable[..., Tuple[Any, FrozenVariableDict]]:
  """Creates an init function that calls ``fn`` with a bound module and also returns the function outputs.

  Unlike ``Module.init_with_output`` this function returns a new function with
  the signature ``(rngs, *args, **kwargs) -> (T, variables)`` where ``T`` is the
  return type of ``fn``. The rngs can be a dict of PRNGKeys or a single
  ``PRNGKey``, which is equivalent to passing a dict with one PRNGKey under the
  name "params".

  The init function that is returned can be directly composed with
  JAX transformations like ``jax.jit``::

    def f(foo, x):
      z = foo.encode(x)
      y = foo.decode(z)
      # ...
      return y

    foo = Foo()
    f_jitted = jax.jit(nn.init_with_output(f, foo))
    y, variables = f_jitted(rng, x)

  Args:
    fn: The function that should be applied. The first argument passed will
      be a module instance of ``module`` with variables and RNGs bound to it.
    module: The ``Module`` used to bind variables and RNGs to. The ``Module``
      passed as the first argument to ``fn`` will be a clone of ``module``.
    mutable: Can be bool, str, or list. Specifies which collections should be
      treated as mutable: ``bool``: all/no collections are mutable. ``str``:
      The name of a single mutable collection. ``list``: A list of names of
      mutable collections.
  Returns:
    The init function wrapping ``fn``.
  """
  @functools.wraps(fn)
  def scope_fn(scope, *args, **kwargs):
    return fn(module.clone(parent=scope), *args, **kwargs)

  return core.init(scope_fn, mutable=mutable)
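A minimal usage sketch of the returned init function (illustrative only; ``Foo`` and ``x`` are assumed placeholders), showing both ways of passing RNGs described in the docstring above:

def f(foo, x):
  return foo(x)

init_fn = init_with_output(f, Foo())
# a single key is treated as {'params': key}
y, variables = init_fn(random.PRNGKey(0), x)
# explicit dict form, e.g. when the bound module also draws a 'dropout' rng
y, variables = init_fn(
    {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}, x)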
def test_big_resnet(self):
  x = random.normal(random.PRNGKey(0), (1, 8, 8, 8))
  y, variables = init(big_resnet)(random.PRNGKey(1), x)
  self.assertEqual(y.shape, (1, 8, 8, 8))
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  batch_stats_shapes = unfreeze(
      jax.tree_map(jnp.shape, variables['batch_stats']))
  self.assertEqual(param_shapes, {
      'conv_1': {'kernel': (10, 5, 3, 3, 8, 8)},
      'conv_2': {'kernel': (10, 5, 3, 3, 8, 8)},
      'bn_1': {'scale': (10, 5, 8), 'bias': (10, 5, 8)},
      'bn_2': {'scale': (10, 5, 8), 'bias': (10, 5, 8)},
  })
  self.assertEqual(batch_stats_shapes, {
      'bn_1': {'var': (10, 5, 8), 'mean': (10, 5, 8)},
      'bn_2': {'var': (10, 5, 8), 'mean': (10, 5, 8)},
  })
      x = nn.relu(x)
  return x


def semi_explicit_mlp(scope, x, sizes=(3, 1)):
  for i, size in enumerate(sizes):
    dense = scope.child(ExplicitDense.create_in_scope,
                        prefix='dense_')(x.shape[-1], size)
    x = dense(x)
    if i + 1 < len(sizes):
      x = nn.relu(x)
  return x


if __name__ == "__main__":
  model = Dense(features=4)
  x = jnp.ones((1, 3))
  y, params = init(model)(random.PRNGKey(0), x)
  print(y)
  print(params)

  print('explicit dense:')
  y, params = init(explicit_mlp)(random.PRNGKey(0), x)
  print(y)
  print(params)

  print('semi-explicit dense:')
  y, params = init(semi_explicit_mlp)(random.PRNGKey(0), x)
  print(y)
        variable_carry='counter',
        variable_in_axes={'param': lift.broadcast},
        variable_out_axes={'param': lift.broadcast},
        split_rngs={'param': False})(scope, (), xs)
  else:
    carry, ys = lift.scan(
        body_fn,
        variable_carry='counter',
        variable_in_axes={'param': 0},
        variable_out_axes={'param': 0},
        split_rngs={'param': True})(scope, (), xs)
  # output layer
  return carry, ys


if __name__ == "__main__":
  x = random.normal(random.PRNGKey(0), (1, 4))
  x = jnp.concatenate([x, x], 0)

  print('unshared params: (outputs should differ, parameters have an extra dim)')
  y, variables = init(mlp_scan)(random.PRNGKey(1), x, share_params=False)
  print(y)
  print(unfreeze(variables))

  print('shared params: (outputs should be the same)')
  y, variables = init(mlp_scan)(random.PRNGKey(1), x, share_params=True)
  print(y)
  print(unfreeze(variables))
  else:
    dense_vmap = lift.vmap(
        nn.dense,
        in_axes=(0, None),
        variable_axes={'params': 0},
        split_rngs={'params': True})

  # hidden layers
  for size in sizes[:-1]:
    x = scope.child(dense_vmap)(x, size)
    x = act_fn(x)

  # output layer
  return scope.child(dense_vmap)(x, sizes[-1])


if __name__ == "__main__":
  x = random.normal(random.PRNGKey(0), (1, 4))
  x = jnp.concatenate([x, x], 0)

  print('shared params: (same inputs, same outputs)')
  y, params = init(mlp_vmap)(random.PRNGKey(1), x, share_params=True)
  print(y)
  print(jax.tree_map(jnp.shape, unfreeze(params)))

  print('unshared params: (same inputs, different outputs, extra dim in params)')
  y, params = init(mlp_vmap)(random.PRNGKey(1), x, share_params=False)
  print(y)
  print(jax.tree_map(jnp.shape, unfreeze(params)))
    return y, x

  def bwd(features, scope_fn, params, res, g):
    x = res
    fn = lambda params, x: nn.dense(scope_fn(params), x, features)
    _, pullback = jax.vjp(fn, params, x)
    g_param, g_x = pullback(g)
    g_param = jax.tree_map(jnp.sign, g_param)
    return g_param, g_x

  dense_custom_grad = lift.custom_vjp(
      fwd, backward_fn=bwd, nondiff_argnums=(2,))

  # hidden layers
  for size in sizes[:-1]:
    x = scope.child(dense_custom_grad)(x, size)
    x = act_fn(x)

  # output layer
  return scope.child(dense_custom_grad)(x, sizes[-1])


if __name__ == "__main__":
  x = random.normal(random.PRNGKey(0), (1, 4))
  x = jnp.concatenate([x, x], 0)

  print('shared params: (same inputs, same outputs)')
  y, params = init(mlp_custom_grad)(random.PRNGKey(1), x)
  print(y)
  print(jax.tree_map(jnp.shape, unfreeze(params)))
  print(jax.grad(
      lambda params, x: jnp.mean(apply(mlp_custom_grad)(params, x) ** 2))(
          params, x))
  conv = partial(nn.conv, bias=False, dtype=dtype)
  norm = partial(norm, dtype=dtype)

  x = scope.child(conv, 'init_conv')(x, 16, (7, 7), padding=((3, 3), (3, 3)))
  x = scope.child(norm, 'init_bn')(x)
  x = act(x)
  x = nn.max_pool(x, (2, 2), (2, 2), 'SAME')

  for i, size in enumerate(block_sizes):
    for j in range(size):
      strides = (1, 1)
      if i > 0 and j == 0:
        strides = (2, 2)
      block_features = features * 2 ** i
      block_scope = scope.push(f'block_{i}_{j}')
      x = residual_block(block_scope, x, conv, norm, act,
                         block_features, strides)
      # we can access the parameters of the sub-module by operating on the scope.
      # Example:
      #   block_scope.get_kind('param')['conv_1']['kernel']
  x = jnp.mean(x, (1, 2))
  x = scope.child(nn.dense, 'out')(x, num_classes)
  return x


if __name__ == "__main__":
  x = random.normal(random.PRNGKey(0), (1, 224, 224, 3))
  y, params = init(resnet)(random.PRNGKey(1), x)
  print(y.shape)
  print(jax.tree_map(jnp.shape, unfreeze(params)))
def test_mutate_undefined_collection(self):
  def f(scope):
    scope.put_variable('test', 'test', 123)

  with self.assertRaisesWithLiteralMatch(
      ValueError, 'Collection is not mutable: "test"'):
    init(f, mutable='params')(random.PRNGKey(0))
    return jnp.dot(x, expm(kernel)) + bias

  def backward(self, scope: Scope, y: Array):
    kernel, bias = self.params(scope, y.shape[-1])
    return jnp.dot(y - bias, expm(-kernel))


@dataclass
class StackFlow:
  flows: Sequence[Flow]

  def forward(self, scope: Scope, x: Array):
    for i, f in enumerate(self.flows):
      x = scope.child(f.forward, name=str(i))(x)
    return x

  def backward(self, scope: Scope, x: Array):
    for i, f in reversed(tuple(enumerate(self.flows))):
      x = scope.child(f.backward, name=str(i))(x)
    return x


if __name__ == "__main__":
  flow = StackFlow((DenseFlow(),) * 3)
  # forward and backward are interchangeable here, so shape inference and
  # initialization can be done on either the forward or the backward pass of the flow.
  y, params = init(flow.forward)(random.PRNGKey(0), jnp.ones((1, 3)))
  print(params)
  x_restore = apply(flow.backward)(params, y)
  print(x_restore)
def _tied(self, fn, transpose=False):
  if not transpose:
    return fn

  def trans(variables):
    if 'param' not in variables:
      return variables
    params = variables['param']
    params['kernel'] = params['kernel'].T
    return variables

  return lift.transform_module(fn, trans_in_fn=trans, trans_out_fn=trans)


if __name__ == "__main__":
  ae = TiedAutoEncoder(latents=2, features=4)
  x = jnp.ones((1, ae.features))
  x_r, params = init(ae)(random.PRNGKey(0), x)
  print(x, x_r)
  print(params)

  print('init from decoder:')
  z = jnp.ones((1, ae.latents))
  x_r, params = init(ae.decode)(random.PRNGKey(0), z)
  print(apply(ae)(params, x))
  print(params)
  # the transformed kind will be immutable inside fn;
  # this way we avoid lost mutations to param.
  # transform also avoids accidental reuse of rngs,
  # and it makes sure that other state is updated correctly (not twice during init!).
  return lift.transform_module(fn, trans_in_fn=std)


def mlp(scope: Scope, x: Array,
        sizes: Sequence[int] = (2, 4, 1),
        act_fn: Callable[[Array], Array] = nn.relu):
  std_dense = weight_std(
      partial(nn.dense, kernel_init=nn.initializers.normal(stddev=1e5)))
  # hidden layers
  for size in sizes[:-1]:
    x = scope.child(std_dense, prefix='hidden_')(x, size)
    # x = act_fn(x)
  # output layer
  return scope.child(nn.dense, 'out')(x, sizes[-1])


if __name__ == "__main__":
  x = random.normal(random.PRNGKey(0), (1, 4))
  y, params = init(mlp)(random.PRNGKey(1), x)
  print(y)
  print(jax.tree_map(jnp.shape, unfreeze(params)))
  for axis in reversed(sorted(batch_axes)):
    attn_fn = lift.vmap(
        attn_fn,
        in_axes=(axis, axis, axis),
        out_axes=axis,
        variable_axes={'params': None},
        split_rngs={'params': False, 'dropout': not broadcast_dropout})

  y = attn_fn(scope, inputs_q, inputs_kv, bias)
  return y.mean(axis=-2)


if __name__ == "__main__":
  inputs = jnp.ones((2, 7, 16))
  y, variables = init(multi_head_dot_product_attention)(
      {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)},
      inputs, inputs,
      num_heads=2,
      batch_axes=(0,),
      attn_fn=with_dropout(softmax_attn, 0.1, deterministic=False))
  print(y.shape)
  print(jax.tree_map(jnp.shape, variables))
def init_fun(rng, *args, **kwargs):
  return init(core_fun)(rng, *args, **kwargs)