def test_toplevel_submodule_adoption_sharing(self):
    """A module instance shared by two parents owns a single set of params.

    ``a`` is passed to both ``B`` and ``C``; the reference shapes expect one
    ``'a'`` entry at the top level of ``params`` and no nested copy of it
    under ``'b'``.
    """
    dense = functools.partial(nn.Dense, use_bias=False)

    class A(nn.Module):

        @nn.compact
        def __call__(self, x):
            return dense(2)(x)

    class B(nn.Module):
        a: nn.Module

        @nn.compact
        def __call__(self, x):
            return dense(2)(x) + self.a(x)

    class C(nn.Module):
        a: nn.Module
        b: nn.Module

        @nn.compact
        def __call__(self, x):
            return dense(2)(x) + self.b(x) + self.a(x)

    key = random.PRNGKey(0)
    x = jnp.ones((2, 2))
    a = A()
    b = B(a)
    c = C(a, b)
    p = c.init(key, x)
    # jax.tree_util.tree_map exists on every JAX version; the jax.tree_map
    # alias is deprecated in newer releases.
    var_shapes = jax.tree_util.tree_map(jnp.shape, p)
    ref_var_shapes = freeze({
        'params': {
            'Dense_0': {
                'kernel': (2, 2),
            },
            'a': {
                'Dense_0': {
                    'kernel': (2, 2),
                },
            },
            'b': {
                'Dense_0': {
                    'kernel': (2, 2),
                },
            },
        },
    })
    self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
def test_partially_applied_module_constructor_transform(self):
    """nn.vmap accepts a functools.partial of a Module class as its target."""
    k = random.PRNGKey(0)
    x = jnp.ones((3, 4, 4))
    dense = partial(nn.Dense, use_bias=False)
    vmap_dense = nn.vmap(
        dense,
        variable_axes={'params': 0},
        split_rngs={'params': True})(4)
    init_vars = vmap_dense.init(k, x)
    # jax.tree_util.tree_map instead of the deprecated jax.tree_map alias.
    init_vars_shapes = jax.tree_util.tree_map(jnp.shape, init_vars)
    # params are stacked on axis 0 (variable_axes), one (4, 4) kernel for
    # each of the 3 mapped slices of x.
    ref_var_shapes = freeze({
        'params': {
            'kernel': (3, 4, 4),
        },
    })
    self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes))
def train_full(
    update_fn,
    validation_fn,
    mask_update_fn,
    optimizer,
    state,
    max_epochs=1e4,
    convergence_args=None,
    mask_update_args=None,
):
    """Run training with periodic validation, sparsity-mask updates and a
    convergence check.

    Args:
        update_fn: called as ``update_fn(optimizer, state)`` each epoch; must
            return ``((optimizer, state), metrics, output)`` where ``metrics``
            has a ``'loss'`` entry and ``output`` unpacks to
            ``(prediction, dt, theta, coeffs)``.
        validation_fn: called as ``validation_fn(optimizer, state)`` every 25
            epochs; its result is logged as ``'validation_metric'``.
        mask_update_fn: called as ``mask_update_fn(theta, dt)`` to compute a
            new sparsity mask.
        optimizer: initial optimizer passed to ``update_fn``.
        state: initial state passed to ``update_fn``.
        max_epochs: upper bound on the number of epochs.
        convergence_args: keyword arguments for ``Convergence``; ``None``
            means ``{"patience": 200, "delta": 1e-3}``.
        mask_update_args: keyword arguments for ``mask_scheduler``; ``None``
            means ``{"patience": 500, "delta": 1e-5, "periodicity": 200}``.

    Returns:
        The final ``(optimizer, state)`` pair.
    """
    # Build the defaults per call rather than sharing mutable default dicts.
    if convergence_args is None:
        convergence_args = {"patience": 200, "delta": 1e-3}
    if mask_update_args is None:
        mask_update_args = {"patience": 500, "delta": 1e-5, "periodicity": 200}

    logger = Logger()
    converged = Convergence(**convergence_args)
    update_mask = mask_scheduler(**mask_update_args)
    try:
        for epoch in jnp.arange(max_epochs):
            (optimizer, state), metrics, output = update_fn(optimizer, state)
            _, dt, theta, coeffs = output

            if epoch % 1000 == 0:
                print(f"Loss step {epoch}: {metrics['loss']}")

            if epoch % 25 == 0:
                val_metric = validation_fn(optimizer, state)
                metrics = {**metrics, "validation_metric": val_metric}
                logger.write(metrics, epoch)

                apply_sparsity, optimizer = update_mask(val_metric, epoch, optimizer)
                if apply_sparsity:
                    mask = mask_update_fn(theta, dt)
                    # The state is replaced by one holding only the new mask.
                    state = freeze({"vars": {"LeastSquares_0": {"mask": mask}}})

                if converged(epoch, coeffs):
                    mask = mask_update_fn(theta, dt)
                    print(f"Converged at epoch {epoch} with mask {mask[:, None]}.")
                    break
    finally:
        # Close the logger even if an update/validation step raises.
        logger.close()
    return optimizer, state
def test_toplevel_submodule_adoption_pytree_transform(self):
    """Submodules held in a dict attribute are adopted under nn.scan.

    ``init`` runs ``B`` unscanned, leaving each counter at 1; the scanned
    ``apply`` over the 10 rows of ``x`` carries the counters and adds 10,
    hence the expected value 11.
    """
    class A(nn.Module):

        @nn.compact
        def __call__(self, c, x):
            counter = self.variable('counter', 'i', jnp.zeros, ())
            counter.value += 1
            x = nn.Dense(1)(x)
            return c, x

    class B(nn.Module):
        A: Any

        @nn.compact
        def __call__(self, c, x):
            return self.A['foo'](*self.A['bar'](c, x))

    As = {'foo': A(), 'bar': A()}
    b = nn.scan(B,
                in_axes=0,
                variable_carry='counter',
                variable_broadcast='params',
                split_rngs={'params': False})(As)

    key = random.PRNGKey(0)
    x = jnp.ones((10, 2))

    p = B(As).init(key, x, x)
    y, cntrs = b.apply(p, x, x, mutable='counter')
    ref_cntrs = freeze({
        'counter': {
            'A_bar': {
                'i': jnp.array(11.0),
            },
            'A_foo': {
                'i': jnp.array(11.0),
            },
        },
    })
    # assert_allclose raises on mismatch, so mapping it leaf-by-leaf is the
    # check itself (wrapping the all-None result in tree_all/assertTrue was
    # vacuous). jax.tree_multimap no longer exists; tree_map takes many trees.
    jax.tree_util.tree_map(
        lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7),
        cntrs, ref_cntrs)
def test_toplevel_submodule_adoption_pytree(self):
    """Submodules held in a dict attribute are adopted by their parent.

    ``init`` leaves each counter at 1 and one ``apply`` adds 1, hence the
    expected value 2.
    """
    class A(nn.Module):

        @nn.compact
        def __call__(self, c, x):
            counter = self.variable('counter', 'i', jnp.zeros, ())
            counter.value += 1
            x = nn.Dense(1)(x)
            return c, x

    class B(nn.Module):
        A: Any

        @nn.compact
        def __call__(self, c, x):
            return self.A['foo'](*self.A['bar'](c, x))

    As = {'foo': A(), 'bar': A()}
    b = B(As)

    key = random.PRNGKey(0)
    x = jnp.ones((2, 2))

    p = B(As).init(key, x, x)
    y, cntrs = b.apply(p, x, x, mutable='counter')
    ref_cntrs = freeze({
        'counter': {
            'A_bar': {
                'i': jnp.array(2.0),
            },
            'A_foo': {
                'i': jnp.array(2.0),
            },
        },
    })
    # assert_allclose raises on mismatch, so mapping it leaf-by-leaf is the
    # check itself (wrapping the all-None result in tree_all/assertTrue was
    # vacuous). jax.tree_multimap no longer exists; tree_map takes many trees.
    jax.tree_util.tree_map(
        lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7),
        cntrs, ref_cntrs)
def test_module_pass_in_closure(self):
    """A module captured from an enclosing scope and bound in setup() gets
    its params under the attribute name ('foo'), while the captured
    instance itself keeps ``name`` None.
    """
    a = nn.Dense(2)

    class B(nn.Module):

        def setup(self):
            self.foo = a

        def __call__(self, x):
            return self.foo(x)

    variables = B().init(random.PRNGKey(0), jnp.ones((1,)))
    # jax.tree_util.tree_map instead of the deprecated jax.tree_map alias.
    var_shapes = jax.tree_util.tree_map(jnp.shape, variables)
    ref_var_shapes = freeze({
        'params': {
            'foo': {
                'bias': (2,),
                'kernel': (1, 2),
            }
        },
    })
    self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
    self.assertEqual(a.name, None)
def test_toplevel_named_submodule_adoption(self):
    """A submodule built with an explicit name ('foo') and passed as a field
    is stored under the parent's field name 'a', per the reference shapes.
    """
    dense = functools.partial(nn.Dense, use_bias=False)

    class A(nn.Module):

        def setup(self):
            self.dense = dense(4)

        def __call__(self, x):
            return self.dense(x)

    class B(nn.Module):
        a: A

        def setup(self):
            self.proj = dense(6)

        def __call__(self, x):
            return self.proj(self.a(x))

    a = A(name='foo')
    b = B(a=a)
    k = jax.random.PRNGKey(0)
    x = jnp.zeros((5, 5))
    init_vars = b.init(k, x)
    # jax.tree_util.tree_map instead of the deprecated jax.tree_map alias.
    var_shapes = jax.tree_util.tree_map(jnp.shape, init_vars)
    ref_var_shapes = freeze({
        'params': {
            'a': {
                'dense': {
                    'kernel': (5, 4),
                },
            },
            'proj': {
                'kernel': (4, 6),
            },
        },
    })
    self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
def test_same(self, partial_model_pair, hilbert, dtype, machine_pow, skip):
    """Check the cached model against the reference model.

    Both the batched ``conditionals`` and the site-by-site ``_conditional``
    (threading the mutable ``cache`` collection) must reproduce the
    reference model's ``conditionals``.
    """
    batch_size = 3
    reference_model = partial_model_pair[0](hilbert, dtype, machine_pow)
    fast_model = partial_model_pair[1](hilbert, dtype, machine_pow)

    spin_key, model_key = jax.random.split(jax.random.PRNGKey(0))
    spins = hilbert.random_state(spin_key, size=batch_size)
    variables = fast_model.init(
        model_key, spins, 0, method=fast_model._conditional)

    expected = reference_model.apply(
        variables, spins, method=reference_model.conditionals)
    batched = fast_model.apply(
        variables, spins, method=fast_model.conditionals)
    # `FastARNN*.conditionals` must agree with `ARNN*.conditionals`.
    np.testing.assert_allclose(batched, expected)

    stepwise = jnp.zeros_like(expected)
    params = variables["params"]
    cache = variables["cache"]
    for site in range(hilbert.size):
        step_variables = freeze({"params": params, "cache": cache})
        site_probs, updated = fast_model.apply(
            step_variables,
            spins,
            site,
            method=fast_model._conditional,
            mutable=["cache"],
        )
        # Carry the updated cache into the next site.
        cache = updated["cache"]
        stepwise = stepwise.at[:, site, :].set(site_probs)

    # `FastARNN*.conditional`, run site by site, must agree with
    # `ARNN*.conditionals`.
    np.testing.assert_allclose(stepwise, expected)
def convert_pre_linen(params):
    """Rename pre-Linen submodule keys to Linen's per-class numbering.

    Pre-Linen numbered submodules with one counter shared by every class,
    whereas Linen keeps a separate counter per class. For a module whose
    ``@nn.compact`` body calls ``nn.Conv`` and then ``nn.Dense``, pre-Linen
    produced::

        {'Conv_0': { ... }, 'Dense_1': { ... } }

    while Linen produces::

        {'Conv_0': { ... }, 'Dense_0': { ... } }

    Convert a pre-Linen checkpoint with::

        params = convert_pre_linen(pre_linen_params)

    Pre-Linen collections follow the same naming and can be converted the
    same way. They were stored "flat", though, and must be unflattened
    first::

        batch_stats = convert_pre_linen(flax.traverse_util.unflatten_dict({
            tuple(k.split('/')[1:]): v
            for k, v in pre_linen_model_state.as_dict().items()
        }))

    after which Linen variables can be assembled as
    ``{'params': params, 'batch_stats': batch_stats}``.

    Args:
        params: Parameter pytree in pre-Linen format. A tree that already
            follows Linen numbering comes back with the same keys, so this is
            safe to call on any checkpoint loaded for use with Linen.

    Returns:
        A frozen dict (built with ``core.freeze``) with Linen submodule
        naming, even when the input was a plain ``dict``. Inputs that are not
        dicts (leaves) are returned unchanged.
    """
    if not isinstance(params, (dict, core.FrozenDict)):
        return params

    renamed = {}
    per_class_count = {}
    # Visit keys in natural_sort order so per-class indices are assigned in
    # the original submodule order.
    for old_key in natural_sort(params.keys()):
        new_key = old_key
        numbered = MODULE_NUM_RE.match(old_key)
        if numbered:
            prefix = numbered.group(1)
            index = per_class_count.get(prefix, 0)
            new_key = f'{prefix}_{index}'
            per_class_count[prefix] = index + 1
        renamed[new_key] = convert_pre_linen(params[old_key])
    return core.freeze(renamed)
def wrap_params(variables):
    """Return ``variables`` as the ``"params"`` collection of a frozen dict."""
    collections = {"params": variables}
    return freeze(collections)
def init(self, keys, inpt):
    """Initialise variables with ``self.ifun`` and freeze them as ``params``.

    Args:
        keys: mapping of RNG keys; only ``keys["params"]`` is used.
        inpt: example input; only its ``.shape`` is forwarded to
            ``self.ifun``.

    Returns:
        A frozen ``{"params": variables}`` dict, where ``variables`` is the
        second element of the pair returned by ``self.ifun``.
    """
    params_rng = keys["params"]
    _first, variables = self.ifun(params_rng, inpt.shape)
    return freeze(dict(params=variables))
def test_eq(self):
    """Equal DotGetters compare equal, frozen or not."""

    def build():
        # A fresh, independent dict for each DotGetter.
        return DotGetter({'a': 1, 'b': {'c': 2, 'd': 3}})

    first = build()
    second = build()
    self.assertEqual(first, second)
    self.assertEqual(freeze(first), second)
    self.assertEqual(freeze(first), freeze(second))
def update_varcolbykey(var, col_name, target, leaf_only=True):
    """Return a frozen copy of ``var`` with one collection rewritten.

    The ``col_name`` collection is replaced by
    ``dict_replace(var[col_name], target, leaf_only=leaf_only)``; all other
    top-level entries are carried over unchanged.

    Args:
        var: variables mapping (``FrozenDict`` or plain ``dict``) that
            contains ``col_name``.
        col_name: key of the collection to rewrite.
        target: forwarded to ``dict_replace``.
        leaf_only: forwarded to ``dict_replace``.

    Returns:
        A frozen dict in which the rewritten collection is stored last,
        under ``col_name``.

    Raises:
        KeyError: if ``col_name`` is not a key of ``var``.
    """
    # Index + filter instead of ``var.pop``: FrozenDict.pop returns
    # ``(rest, value)`` but dict.pop returns only the value, so the old
    # two-way unpacking broke for plain-dict variables. (The old trailing
    # ``del var`` only unbound the local name and had no effect.)
    col = var[col_name]
    rest = {key: value for key, value in var.items() if key != col_name}
    col = dict_replace(col, target, leaf_only=leaf_only)
    return freeze({**rest, col_name: col})
def load_keras_model(self, checkpoint, prng_key=None):
    """Initialise this model and overwrite its variables with Keras weights.

    Loads a Keras beta-VAE checkpoint, copies the encoder and decoder weights
    into a freshly initialised variable tree, and stores the frozen result in
    ``self.params``.

    Args:
        checkpoint: checkpoint passed to ``models['vae'].load_weights``.
        prng_key: key for the initial ``self.init`` call (all initialised
            values covered by the layouts below are then overwritten).
            Defaults to ``random.PRNGKey(42)``.

    Raises:
        IndexError: if a Keras weight list is shorter than its layout.

    The Keras ``get_weights()`` lists are consumed positionally, so their
    order must match ``encoder_layout`` / ``decoder_layout`` below.
    """
    # Create the Keras beta-VAE; the two loss functions it also returns are
    # not needed here.
    models, _, _ = get_Neural_Network(1e-3, 'softplus', 'chi_sq')
    # Load weights into the Keras model from the given checkpoint.
    models['vae'].load_weights(checkpoint).expect_partial()
    encoder_weights = models['encoder'].get_weights()
    decoder_weights = models['decoder'].get_weights()
    # Recast as JAX arrays to enable autodiff through the model.
    encoder_weights = [jnp.array(w) for w in encoder_weights]
    decoder_weights = [jnp.array(w) for w in decoder_weights]

    # Initialise a variable tree of the right structure to fill in.
    if prng_key is None:
        prng_key = random.PRNGKey(42)
    init_data = jnp.ones((self.xdim, self.zdim))
    # Three-way split (first key unused) keeps subkey1/subkey2 unchanged.
    _, subkey1, subkey2 = random.split(prng_key, 3)
    variables = unfreeze(self.init(subkey1, init_data, z_rng=subkey2))

    def batchnorm(module):
        # Keras lists BatchNorm weights as scale, bias, mean, var; the
        # affine terms go to 'params' and the running stats to 'batch_stats'.
        return [('params', module, 'scale', False),
                ('params', module, 'bias', False),
                ('batch_stats', module, 'mean', False),
                ('batch_stats', module, 'var', False)]

    # Each entry: (collection, module, leaf, swap_last_two_axes), listed in
    # the order of the Keras weight list.
    encoder_layout = [('params', 'Conv_0', 'kernel', False)]
    for i in range(3):
        encoder_layout.append(('params', f'Conv_{i + 1}', 'kernel', False))
        encoder_layout.extend(batchnorm(f'BatchNorm_{i}'))
    encoder_layout += [
        ('params', 'Conv_4', 'kernel', False),
        ('params', 'Dense_0', 'kernel', False),
        ('params', 'Dense_0', 'bias', False),
        ('params', 'Dense_1', 'kernel', False),
        ('params', 'Dense_1', 'bias', False),
    ]

    decoder_layout = []
    for i in range(4):
        decoder_layout.append(('params', f'ConvTranspose_{i}', 'kernel', True))
        decoder_layout.extend(batchnorm(f'BatchNorm_{i}'))
    decoder_layout.append(('params', 'ConvTranspose_4', 'kernel', True))

    def copy_weights(submodule, weights, layout):
        # Index explicitly (rather than zip) so a too-short weight list still
        # raises IndexError instead of silently keeping initial values.
        for index, (collection, module, leaf, swap) in enumerate(layout):
            value = weights[index]
            if swap:
                # ConvTranspose kernels get their last two axes swapped
                # (presumably Keras stores (h, w, out, in) and Flax
                # (h, w, in, out) — TODO confirm).
                value = np.swapaxes(value, 2, 3)
            variables[collection][submodule][module][leaf] = value

    copy_weights('encoder', encoder_weights, encoder_layout)
    copy_weights('decoder', decoder_weights, decoder_layout)
    self.params = freeze(variables)
def test_toplevel_submodule_adoption_transform(self):
    """Adopted submodules work under nn.vmap.

    Covers both the method-decorator and the class-transform forms; each must
    match the output of building the same modules inside a parent's compact
    method (``D``).
    """
    class A(nn.Module):

        @nn.compact
        def __call__(self, x):
            return nn.Dense(3)(x)

    class B(nn.Module):
        A: nn.Module

        @nn.compact
        def __call__(self, x):
            return self.A(x)

    class C(nn.Module):
        A: nn.Module
        B: nn.Module

        @partial(
            nn.vmap,
            variable_axes={'params': 0},
            split_rngs={'params': True})
        @nn.compact
        def __call__(self, x):
            return self.B(x) + self.A(x)

    class Csimple(nn.Module):
        A: nn.Module
        B: nn.Module

        @nn.compact
        def __call__(self, x):
            return self.B(x) + self.A(x)

    class D(nn.Module):

        @nn.compact
        def __call__(self, x):
            a1 = A()
            a2 = A()
            b = B(a1)
            c = C(a2, b)
            return c(x)

    key = random.PRNGKey(0)
    x = jnp.ones((10, 10))
    p1 = D().init(key, x)
    y1 = D().apply(p1, x)

    a1 = A()
    a2 = A()
    b = B(a1)
    # Re-nest D's top-level A_0/A_1 params into the layout C(a2, b) expects.
    p2 = freeze({'params': {
        'A': p1['params']['A_0'],
        'B': {
            'A': p1['params']['A_1'],
        }
    }})

    # Test method wrapper transform.
    y2 = C(a2, b).apply(p2, x)
    np.testing.assert_allclose(y1, y2, atol=1e-7)

    # Test class transform.
    Ctrafo = nn.vmap(Csimple,
                     variable_axes={'params': 0},
                     split_rngs={'params': True})
    y3 = Ctrafo(a2, b).apply(p2, x)
    np.testing.assert_allclose(y1, y3, atol=1e-7)
def init(self, rng, *args, **kwargs):
    """Initialise ``self.transformed`` and freeze the result as ``params``.

    Args:
        rng: mapping of RNG keys; only ``rng["params"]`` is used.
        *args, **kwargs: forwarded unchanged to ``self.transformed.init``.

    Returns:
        A frozen ``{"params": variables}`` dict.
    """
    params_rng = rng["params"]
    initialised = self.transformed.init(params_rng, *args, **kwargs)
    return freeze(dict(params=initialised))
def unflatten_params(flat_params):
    """Turn a flat ``{"a/b/c": value}`` mapping into a nested frozen dict.

    Each key is split on ``"/"`` into a path tuple, then the resulting
    ``{path: value}`` mapping is passed to ``traverse_util.unflatten_dict``.
    """
    by_path = {}
    for flat_key, leaf in flat_params.items():
        by_path[tuple(flat_key.split("/"))] = leaf
    nested = traverse_util.unflatten_dict(by_path)
    return freeze(nested)
def test_frozen_items(self):
    """items() on a frozen dict yields frozen nested values."""
    source = {'a': 1, 'b': {'c': 2}}
    pairs = list(freeze(source).items())
    expected = [('a', 1), ('b', freeze(source['b']))]
    self.assertEqual(pairs, expected)
def test_frozen_dict_hash(self):
    """Frozen dicts that differ only in a nested value hash differently."""
    first = freeze({'a': 1, 'b': {'c': 2}})
    second = freeze({'a': 1, 'b': {'c': 3}})
    self.assertNotEqual(hash(first), hash(second))
def test_frozen_dict_partially_maps(self):
    """Mapping over frozen dicts where the first is a prefix of the second
    hands the second's whole subtree to the function."""
    # jax.tree_multimap was removed from JAX; jax.tree_util.tree_map accepts
    # several trees with the same prefix semantics.
    x = jax.tree_util.tree_map(
        lambda a, b: (a, b),
        freeze({'a': 2}),
        freeze({'a': {'b': 1}}))
    self.assertEqual(unfreeze(x), {'a': (2, {'b': 1})})
def test_frozen_dict_copies(self):
    """freeze() takes a copy, so later mutation of the source does not leak.

    This covers mutation of the source dict and of its nested dicts.
    """
    source = {'a': 1, 'b': {'c': 2}}
    frozen = freeze(source)
    source['a'] = source['a'] + 1
    inner = source['b']
    inner['c'] = inner['c'] + 1
    self.assertEqual(unfreeze(frozen), {'a': 1, 'b': {'c': 2}})