def test_optimized_lstm_cell_matches_regular(self):
  # Create regular LSTMCell.
  rng = random.PRNGKey(0)
  key1, key2 = random.split(rng)
  x = random.normal(key1, (2, 3))
  c0, h0 = nn.LSTMCell.initialize_carry(rng, (2,), 4)
  self.assertEqual(c0.shape, (2, 4))
  self.assertEqual(h0.shape, (2, 4))
  (carry, y), initial_params = nn.LSTMCell.init(key2, (c0, h0), x)
  lstm = nn.Model(nn.LSTMCell, initial_params)

  # Create OptimizedLSTMCell.
  rng = random.PRNGKey(0)
  key1, key2 = random.split(rng)
  x = random.normal(key1, (2, 3))
  c0, h0 = nn.OptimizedLSTMCell.initialize_carry(rng, (2,), 4)
  self.assertEqual(c0.shape, (2, 4))
  self.assertEqual(h0.shape, (2, 4))
  (carry, y_opt), initial_params = nn.OptimizedLSTMCell.partial(
      name='LSTMCell').init(key2, (c0, h0), x)
  lstm_opt = nn.Model(nn.OptimizedLSTMCell.partial(name='LSTMCell'),
                      initial_params)

  onp.testing.assert_allclose(y, y_opt, rtol=1e-6)
  jtu.check_eq(lstm.params, lstm_opt.params)

def test_param_selection(self):
  params = {
      'x': {
          'kernel': 1,
          'bias': 2,
          'y': {
              'kernel': 3,
              'bias': 4,
          },
      },
  }
  names = []

  def filter_fn(name, _):
    names.append(name)  # track names passed to filter_fn for testing
    return 'kernel' in name

  model = nn.Model(None, params)
  traversal = optim.ModelParamTraversal(filter_fn)

  values = list(traversal.iterate(model))
  self.assertEqual(values, [1, 3])
  self.assertEqual(set(names), set([
      '/x/kernel', '/x/bias', '/x/y/kernel', '/x/y/bias']))

  new_model = traversal.update(lambda x: x + x, model)
  expected_params = {
      'x': {
          'kernel': 2,
          'bias': 2,
          'y': {
              'kernel': 6,
              'bias': 4,
          },
      },
  }
  expected_model = nn.Model(None, expected_params)
  self.assertEqual(new_model, expected_model)

def test_nested_model(self):
  x = jnp.array([1.])

  _, inner_initial_params = DummyModule.init(random.PRNGKey(0), x)
  inner_model = nn.Model(DummyModule, inner_initial_params)

  _, initial_params = NestedModel.init(random.PRNGKey(1), x, inner_model)
  model = nn.Model(NestedModel, initial_params)
  y = model(x, inner_model)
  self.assertEqual(y, jnp.array([3.]))

def test_nested_model_capture_outputs(self):
  x = jnp.array([1.])
  _, inner_initial_params = DummyModule.init(random.PRNGKey(0), x)
  inner_model = nn.Model(DummyModule, inner_initial_params)
  _, initial_params = NestedModel.init(random.PRNGKey(1), x, inner_model)
  model = nn.Model(NestedModel, initial_params)
  with nn.capture_module_outputs() as activations:
    model(x, inner_model)
  expected_activations = {
      '/': [x + 2],
      '/dummy_0': [x + 1],
      '/inner_model': [x + 2],
  }
  self.assertEqual(activations.as_dict(), expected_activations)

def test_autoregressive_sampling_with_lstm(self):
  L = 4

  # Set up symmetry orbit.
  orbit = jnp.array([
      jnp.roll(jnp.identity(L, dtype=np.int32), l, axis=1) for l in range(L)
  ])

  # Set up variational wave function.
  rnn = nets.LSTM.partial(L=L, hiddenSize=5)
  _, params = rnn.init_by_shape(random.PRNGKey(0), [(L,)])
  rnnModel = nn.Model(rnn, params)
  rbm = nets.RBM.partial(numHidden=2, bias=False)
  _, params = rbm.init_by_shape(random.PRNGKey(0), [(L,)])
  rbmModel = nn.Model(rbm, params)
  psi = NQS((rnnModel, rbmModel))

  # Set up exact sampler.
  exactSampler = sampler.ExactSampler(L)

  # Set up MCMC sampler.
  mcSampler = sampler.MCMCSampler(random.PRNGKey(0),
                                  jVMC.sampler.propose_spin_flip, (L,),
                                  numChains=777)

  # Compute exact probabilities.
  _, logPsi, pex = exactSampler.sample(psi)

  numSamples = 1000000
  smc, p, _ = mcSampler.sample(psi, numSamples=numSamples)

  self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - p))) < 1e-12)

  if global_defs.usePmap:
    smc = smc.reshape((smc.shape[0] * smc.shape[1], -1))

  self.assertTrue(smc.shape[0] >= numSamples)

  # Compute histogram of sampled configurations.
  smcInt = jax.vmap(state_to_int)(smc)
  pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17))

  self.assertTrue(
      jnp.max(
          jnp.abs(pmc / mcSampler.get_last_number_of_samples() -
                  pex.reshape((-1,))[:16])) < 1e-3)

def create_model(key, flax_module, input_shape, model_kwargs):
  module = flax_module.partial(**model_kwargs)
  with nn.stochastic(key):
    _, initial_params = module.init_by_shape(key,
                                             [(input_shape, jnp.float32)])
    model = nn.Model(module, initial_params)
  return model

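# Hedged usage sketch for `create_model` above, assuming `nn.Dense` as an
# example module; any pre-Linen Flax module class with matching kwargs works
# the same way.
dense_model = create_model(random.PRNGKey(0), nn.Dense, (1, 8),
                           {'features': 2})
y = dense_model(jnp.ones((1, 8)))  # shape (1, 2)
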
def test_optimizer_serialization(self):
  rng = random.PRNGKey(0)
  module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
  _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
  model = nn.Model(module, initial_params)
  optim_def = optim.Momentum(learning_rate=1.)
  optimizer = optim_def.create(model)
  state = serialization.to_state_dict(optimizer)
  expected_state = {
      'target': {
          'params': {
              'kernel': onp.ones((1, 1)),
              'bias': onp.zeros((1,)),
          }
      },
      'state': {
          'step': 0,
          'param_states': {
              'params': {
                  'kernel': {'momentum': onp.zeros((1, 1))},
                  'bias': {'momentum': onp.zeros((1,))},
              }
          }
      },
  }
  self.assertEqual(state, expected_state)
  state = jax.tree_map(lambda x: x + 1, expected_state)
  restored_optimizer = serialization.from_state_dict(optimizer, state)
  optimizer_plus1 = jax.tree_map(lambda x: x + 1, optimizer)
  self.assertEqual(restored_optimizer, optimizer_plus1)

def test_grad_var(self):
  model_size = 10
  example_grads = [{
      'layer1': np.ones(model_size),
      'layer2': 3 * np.ones(model_size)
  }, {
      'layer1': 2 * np.ones(model_size),
      'layer2': np.ones(model_size)
  }]
  eval_config = {'ema_beta': 0.5}

  training_metrics_grabber = utils.TrainingMetricsGrabber.create(
      example_grads[0], eval_config)

  # For the purposes of this test, we create fake optimizers to satisfy the
  # metrics grabber API.
  fake_model = nn.Model(None, example_grads[0])
  new_optimizer = optimizers.GradientDescent(
      learning_rate=None).create(fake_model)
  old_optimizer = optimizers.GradientDescent(
      learning_rate=None).create(fake_model)

  for grad in example_grads:
    training_metrics_grabber = training_metrics_grabber.update(
        grad, old_optimizer, new_optimizer)

  for layer in ['layer1', 'layer2']:
    expected_grad_ema = (
        1 / 4 * np.zeros(model_size) + 1 / 4 * example_grads[0][layer] +
        1 / 2 * example_grads[1][layer])
    self.assertArraysAllClose(
        expected_grad_ema, training_metrics_grabber.state[layer].grad_ema)

def init(key):
  with nn.attention.Cache().mutate() as cache_def:
    _, initial_params = model_def.init_by_shape(
        key, [(input_shape, jnp.float32), (target_shape, jnp.float32)],
        cache=cache_def)
  model = nn.Model(model_def, initial_params)
  return model, cache_def

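# Hedged usage sketch for `init` above (`model_def`, `input_shape`, and
# `target_shape` come from the enclosing scope; the batch size and decode
# length here are assumptions): the returned `cache_def` builds a fresh cache
# for fast autoregressive decoding, mirroring `test_decoding` further below.
model, cache_def = init(random.PRNGKey(0))
cache = cache_def.initialize_cache((2, 16))  # (batch_size, max_decode_length)
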
def test_call_module_method(self):

  class MultiMethod(nn.Module):

    def apply(self, x):
      return x + self.param('bias', x.shape, initializers.ones)

    @nn.module_method
    def l2(self):
      return jnp.sum(self.get_param('bias') ** 2)

  class MultiMethodModel(nn.Module):

    def apply(self, x):
      layer = MultiMethod.shared()
      layer(x)  # init
      return layer.l2()

  self.assertEqual(
      MultiMethod.l2.__qualname__,
      MultiMethod.__qualname__ + '.l2')

  x = jnp.array([1., 2.])

  _, params = MultiMethod.init(random.PRNGKey(0), x)
  model = nn.Model(MultiMethod, params)
  self.assertEqual(model.l2(), 2.)

  y, _ = MultiMethodModel.init(random.PRNGKey(0), x)
  self.assertEqual(y, 2.)

def create_representation_model(encoder_fn,
                                encoder_fn_kwargs,
                                reduce_fn,
                                reduce_fn_kwargs,
                                num_categories,
                                output_features,
                                embed=False,
                                key=random.PRNGKey(0)):
  """Instantiates a RepresentationModel object."""
  module = RepresentationModel.partial(encoder_fn=encoder_fn,
                                       encoder_fn_kwargs=encoder_fn_kwargs,
                                       reduce_fn=reduce_fn,
                                       reduce_fn_kwargs=reduce_fn_kwargs,
                                       num_categories=num_categories,
                                       output_features=output_features,
                                       embed=embed)
  # The partial module already carries all kwargs, so they do not need to be
  # repeated at initialization time.
  _, initial_params = module.init_by_shape(
      key, input_specs=[((1, 1), jnp.float32)])
  model = nn.Model(module, initial_params)
  return model

def test_conv2dlstm(self):
  rng = random.PRNGKey(0)
  key1, key2 = random.split(rng)
  x = random.normal(key1, (2, 4, 4, 3))
  c0, h0 = nn.ConvLSTM.initialize_carry(rng, (2,), (4, 4, 6))
  self.assertEqual(c0.shape, (2, 4, 4, 6))
  self.assertEqual(h0.shape, (2, 4, 4, 6))
  (carry, y), initial_params = nn.ConvLSTM.init(
      key2, (c0, h0), x, features=6, kernel_size=(3, 3))
  lstm = nn.Model(nn.ConvLSTM, initial_params)
  self.assertEqual(carry[0].shape, (2, 4, 4, 6))
  self.assertEqual(carry[1].shape, (2, 4, 4, 6))
  onp.testing.assert_allclose(y, carry[1])
  param_shapes = jax.tree_map(onp.shape, lstm.params)
  self.assertEqual(param_shapes, {
      'hh': {'bias': (6 * 4,), 'kernel': (3, 3, 6, 6 * 4)},
      'ih': {'bias': (6 * 4,), 'kernel': (3, 3, 3, 6 * 4)},
  })

def main(argv):
  key = random.PRNGKey(0)
  train_ds = tfds.load('mnist', split=tfds.Split.TRAIN)
  train_ds = train_ds.cache().shuffle(1000).batch(FLAGS.batch_size)
  test_ds = tfds.as_numpy(
      tfds.load('mnist', split=tfds.Split.TEST, batch_size=-1))

  _, params = VAE.init_by_shape(key, [((1, 784), jnp.float32)])
  vae = nn.Model(VAE, params)

  optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(vae)

  for epoch in range(FLAGS.num_epochs):
    for batch in tfds.as_numpy(train_ds):
      batch['image'] = batch['image'].reshape(-1, 784) / 255.0
      optimizer = train_step(optimizer, batch)

    z = np.random.normal(size=(64, 20))
    metrics, comparison, sample = eval(optimizer.target, test_ds, z)
    save_image(comparison, 'results/reconstruction_' + str(epoch) + '.png',
               nrow=8)
    save_image(sample, 'results/sample_' + str(epoch) + '.png', nrow=8)
    print("eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}".format(
        epoch + 1, metrics['loss'], metrics['bce'], metrics['kld']))

def create_model(key, batch_size, image_size, model_dtype, space_to_depth):
  """Initialize a ResNet-50 model."""
  if space_to_depth:
    input_shape = (batch_size, image_size // 2, image_size // 2, 3 * 2 * 2)
  else:
    input_shape = (batch_size, image_size, image_size, 3)
  model_type = models.FakeResNet if FLAGS.fake_model else models.ResNet
  batchnorm_span = FLAGS.batchnorm_span
  if batchnorm_span is None:
    batchnorm_span = max(batch_size, 64)
  if FLAGS.distributed_batchnorm and (
      batch_size < batchnorm_span <= batch_size * jax.device_count()):
    mllogger.event('model_bn_span', batchnorm_span)
    model_def = model_type.partial(
        num_classes=1000,
        axis_name='batch',
        axis_index_groups=local_replica_groups(batchnorm_span // batch_size),
        dtype=model_dtype,
        conv0_space_to_depth=space_to_depth)
  else:
    mllogger.event('model_bn_span', batch_size)
    model_def = model_type.partial(
        num_classes=1000,
        dtype=model_dtype,
        conv0_space_to_depth=space_to_depth)
  with nn.stateful() as init_state:
    _, params = model_def.init_by_shape(key, [(input_shape, model_dtype)])
    model = nn.Model(model_def, params)
  return model, init_state

def create_model(key, input_shape):

  def inducing_loc_init(key, shape):
    return jnp.linspace(-1.5, 1.5, FLAGS.num_inducing_points)[:, jnp.newaxis]

  kwargs = {}
  for i in range(1, FLAGS.num_layers + 1):
    kwargs['kernel_fn_{}_kwargs'.format(i)] = {
        'amplitude_init': lambda key, shape: jnp.ones(shape),
        'length_scale_init': lambda key, shape: jnp.ones(shape)
    }
    kwargs['inducing_var_{}_kwargs'.format(i)] = {
        'fixed_locations': False,
        'whiten': FLAGS.whiten,
        'inducing_locations_init': inducing_loc_init
    }

  model_def = DeepGPModel.partial(**kwargs)

  with nn.stochastic(key):
    _, params = model_def.init_by_shape(key, [
        (input_shape, jnp.float64),
    ], nn.make_rng(), **kwargs)

  return nn.Model(model_def, params)

def create_model(rng):
  """Creates a model."""
  vocab_size = params['vocab_length']
  _, initial_params = charRNN.init_by_shape(
      rng, [((1, params['seq_length'], vocab_size), jnp.float32)])
  model = nn.Model(charRNN, initial_params)
  return model

def test_gradients_nonhermitian(self):
  dlist = [jax.devices()[0], jax.devices()]

  for ds in dlist:
    global_defs.set_pmap_devices(ds)

    net = nets.CpxRNN.partial(L=3)
    _, params1 = net.init_by_shape(random.PRNGKey(0), [(3,)])
    model = nn.Model(net, params1)

    s = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
    s = jax.ops.index_update(s, jax.ops.index[..., 0, 1], 1)
    s = jax.ops.index_update(s, jax.ops.index[..., 2, 2], 1)

    psi = NQS(model)
    psi0 = psi(s)
    G = psi.gradients(s)
    delta = 1e-5
    params = psi.get_parameters()
    for j in range(G.shape[-1]):
      u = jax.ops.index_update(
          jnp.zeros(G.shape[-1], dtype=jVMC.global_defs.tReal),
          jax.ops.index[j], 1)
      psi.update_parameters(delta * u)
      psi1 = psi(s)
      psi.set_parameters(params)

      # Finite difference gradients.
      Gfd = (psi1 - psi0) / delta

      with self.subTest(i=j):
        self.assertTrue(jnp.max(jnp.abs(Gfd - G[..., j])) < 1e-2)

def test_gradients_cpx(self):
  dlist = [jax.devices()[0], jax.devices()]

  for ds in dlist:
    global_defs.set_pmap_devices(ds)

    rbm = nets.CpxRBM.partial(numHidden=2, bias=True)
    _, params = rbm.init_by_shape(random.PRNGKey(0), [(1, 3)])
    rbmModel = nn.Model(rbm, params)

    s = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
    s = jax.ops.index_update(s, jax.ops.index[..., 0, 1], 1)
    s = jax.ops.index_update(s, jax.ops.index[..., 2, 2], 1)

    psiC = NQS(rbmModel)
    psi0 = psiC(s)
    G = psiC.gradients(s)
    delta = 1e-5
    params = psiC.get_parameters()
    for j in range(G.shape[-1]):
      u = jax.ops.index_update(
          jnp.zeros(G.shape[-1], dtype=global_defs.tReal),
          jax.ops.index[j], 1)
      psiC.update_parameters(delta * u)
      psi1 = psiC(s)
      psiC.set_parameters(params)

      # Finite difference gradients.
      Gfd = (psi1 - psi0) / delta

      with self.subTest(i=j):
        self.assertTrue(jnp.max(jnp.abs(Gfd - G[..., j])) < 1e-2)

def test_truncated_module(self):
  x = jnp.array([1.])
  _, initial_params = NestedModule.init(random.PRNGKey(0), x)
  model = nn.Model(NestedModule, initial_params)
  model = model.truncate_at('/dummy_0')
  y = model(x)
  self.assertEqual(y, [x + 1])

def create_loss(rng, model, train_ds):
  loss_clz = losses.VariationalGaussianLikelihoodLoss
  dist = model(train_ds['index_points'])
  _, params = loss_clz.init(rng, train_ds['y'], dist)
  return nn.Model(loss_clz, params)

def create_model(key, input_shape, model_kwargs):
  model_def = models.TransformerLM.partial(**model_kwargs)
  with nn.attention.Cache().mutate() as cache_def:
    _, initial_params = model_def.init_by_shape(
        key, [(input_shape, jnp.float32)], cache=cache_def)
  model = nn.Model(model_def, initial_params)
  return model, cache_def

def train():
  """Run main training loop."""
  rng = random.PRNGKey(0)

  # Get Zachary's karate club graph dataset.
  node_feats, node_labels, sources, targets = get_karate_club_data()

  # Create model and optimizer.
  _, initial_params = GNN.init(
      rng, node_x=node_feats, edge_x=None, sources=sources, targets=targets)
  model = nn.Model(GNN, initial_params)
  optimizer = optim.Adam(learning_rate=0.01).create(model)

  # Train for 20 iterations.
  for iteration in range(20):
    optimizer, loss = train_step(optimizer, node_feats, sources, targets)
    # Model is stored in `optimizer.target`.
    accuracy = eval_step(optimizer.target, node_feats, sources, targets,
                         node_labels)
    print('iteration: %d, loss: %.4f, accuracy: %.2f' %
          (iteration + 1, loss, accuracy * 100))

def train(
    self,
    epochs=None,
    batch_size=None,
    model_save_path=None,
    display_every=1000,
):
  """Initializes the FNN model and returns the first and second derivatives
  of its output with respect to the training points."""
  dim_x = self.data.geom.dim
  train_data = self.data.train_data()
  # Split the data columns into input points and target tags.
  train_points = device_put(train_data[:, :dim_x])
  train_tag = device_put(train_data[:, dim_x:])

  _, initial_params = FNN.init_by_shape(jax.random.PRNGKey(0),
                                        [((1, 1, 3), jnp.float32)])
  model = nn.Model(FNN, initial_params)
  optimizer_def = flax.optim.Adam(learning_rate=self.learning_rate)
  optimizer = optimizer_def.create(model)

  first_grad = grad(optimizer.target)(train_points)
  second_grad = jax.hessian(optimizer.target)(train_points).diagonal()
  print(first_grad, second_grad)
  return first_grad, second_grad

def create_model(config):
  """Create a model, starting with a pre-trained checkpoint."""
  model_kwargs = dict(config=config.model,)
  model_def = modeling.BertForPreTraining.partial(**model_kwargs)
  if config.init_checkpoint:
    initial_params = import_weights.load_params(
        init_checkpoint=config.init_checkpoint,
        hidden_size=config.model.hidden_size,
        num_attention_heads=config.model.num_attention_heads,
        keep_masked_lm_head=True)
  else:
    with nn.stochastic(jax.random.PRNGKey(0)):
      _, initial_params = model_def.init_by_shape(
          jax.random.PRNGKey(0),
          [((1, config.max_seq_length), jnp.int32),
           ((1, config.max_seq_length), jnp.int32),
           ((1, config.max_seq_length), jnp.int32),
           ((1, config.max_predictions_per_seq), jnp.int32)],
          deterministic=True)

      def fixup_for_tpu(x, i=[0]):
        """HACK to fix incorrect param initialization on TPU."""
        if isinstance(x, jax.ShapeDtypeStruct):
          i[0] += 1
          if len(x.shape) == 2:
            return jnp.zeros(x.shape, x.dtype)
          else:
            return nn.linear.default_kernel_init(
                jax.random.PRNGKey(i[0]), x.shape, x.dtype)
        else:
          return x

      initial_params = jax.tree_map(fixup_for_tpu, initial_params)
  model = nn.Model(model_def, initial_params)
  return model

def test_permutation_invariance(self):
  num_nodes = 4
  num_features = 2
  rng = random.PRNGKey(0)

  # Generate random graph.
  adjacency = random.randint(rng, (num_nodes, num_nodes), 0, 2)
  node_feats = random.normal(rng, (num_nodes, num_features))
  sources, targets = jnp.where(adjacency)

  # Get permuted graph.
  perm = random.permutation(rng, jnp.arange(num_nodes))
  node_feats_perm = node_feats[perm]
  adjacency_perm = adjacency[perm]
  for j in range(len(adjacency)):
    adjacency_perm = jax.ops.index_update(adjacency_perm, j,
                                          adjacency_perm[j][perm])
  sources_perm, targets_perm = jnp.where(adjacency_perm)

  # Create GNN.
  _, initial_params = GNN.init(
      rng, node_x=node_feats, edge_x=None, sources=sources, targets=targets)
  model = nn.Model(GNN, initial_params)

  # Feedforward both original and permuted graph.
  logits = model(node_feats, None, sources, targets)
  logits_perm = model(node_feats_perm, None, sources_perm, targets_perm)

  self.assertAllClose(logits[perm], logits_perm, check_dtypes=False)

def test_decoding(self, spatial_shape, attn_dims):
  bs = 2
  num_heads = 3
  num_features = 4
  rng = random.PRNGKey(0)
  key1, key2 = random.split(rng)
  inputs = random.normal(
      key1, (bs,) + spatial_shape + (num_heads * num_features,))
  module = nn.SelfAttention.partial(
      num_heads=num_heads,
      qkv_features=num_heads * num_features,
      attention_axis=attn_dims,
      causal_mask=True,
      precision=lax.Precision.HIGHEST)

  with nn.attention.Cache().mutate() as cache_def:
    _, initial_params = module.init_by_shape(
        key2, [(inputs.shape, inputs.dtype)], cache=cache_def)
  model = nn.Model(module, initial_params)
  y_ref = jax.jit(lambda f, x: f(x))(model, inputs)

  # Feed the inputs sequentially to simulate decoding.
  cache0 = cache_def.initialize_cache((bs,) + spatial_shape)

  def body_fn(cache, x):
    with cache.mutate() as new_cache:
      y = model(x, cache=new_cache)
    return new_cache, y

  # scan_in_dim supports scanning multiple dims.
  _, y = jax_utils.scan_in_dim(
      body_fn, cache0, inputs, axis=attn_dims, keepdims=True)

  onp.testing.assert_allclose(y_ref, y, atol=1e-5)

def initialize(flax_module_def, initializer, loss_fn, input_shape,
               output_shape, hps, rng, metrics_logger):
  """Run the given initializer.

  We initialize in 3 phases. First we run the default initializer that is
  specified by the model constructor. Next we apply any rescaling as specified
  by hps.layer_rescale_factors. Finally we run the black box initializer
  provided by the initializer arg (the default is noop).

  Args:
    flax_module_def: An uninitialized flax module definition.
    initializer: An initializer defined in init_lib.
    loss_fn: A loss function.
    input_shape: The input shape of a single data example.
    output_shape: The output shape of a single data example.
    hps: A dictionary specifying the model and initializer hparams.
    rng: An rng key to seed the initialization.
    metrics_logger: Used for black box initializers that have learning curves.

  Returns:
    A tuple (model, batch_stats), where model is the initialized
    flax.nn.Model and batch_stats is the collection used for batch norm.
  """
  model_dtype = utils.dtype_from_str(hps.model_dtype)
  # init_by_shape should be passed either a tuple or a list of tuples.
  # For example, for vision tasks input_shape is typically (image_shape).
  # For seq2seq tasks, the shape can be a list of two tuples corresponding to
  # input_sequence_shape for the encoder and output_sequence_shape for the
  # decoder.
  # TODO(gilmer,ankugarg): Support initializers for lists of tuples.
  if isinstance(input_shape, list):  # Typical case for seq2seq models.
    input_specs = [((hps.batch_size, *x), model_dtype) for x in input_shape]
  else:  # Typical case for classification models.
    input_specs = [((hps.batch_size, *input_shape), model_dtype)]
  params_rng, init_rng, dropout_rng = jax.random.split(rng, num=3)

  with nn.stateful() as batch_stats:
    with nn.stochastic(dropout_rng):
      # Using flax_module_def.create can OOM for larger models, so we must
      # create by shape here.
      # TODO(gilmer): Link to the flax issue once the bug reporting process is
      # finalized.
      _, params = flax_module_def.init_by_shape(
          params_rng, input_specs, train=False)
      model = nn.Model(flax_module_def, params)

  if hps.get('layer_rescale_factors'):
    model = model_utils.rescale_layers(model, hps.layer_rescale_factors)

  # We don't pass batch_stats to the initializer; the initializer just runs
  # batch_norm in train mode and does not need to maintain the batch_stats.
  # TODO(gilmer): We hardcode weighted_cross_entropy here, but this will need
  # to change for other models. Maybe have meta_loss_inner as an initializer
  # hyperparam?
  # TODO(gilmer): Instead of passing in weighted_xent, pass in the model and
  # get the loss from that.
  new_model = initializer(loss_fn, model, hps, input_shape, output_shape,
                          init_rng, metrics_logger)

  return new_model, batch_stats

def test_shared_module(self):
  rng = random.PRNGKey(0)
  x = jnp.array([1.])
  _, initial_params = LoopModule.init(rng, x)
  model = nn.Model(LoopModule, initial_params)
  y = model(x)
  self.assertEqual(y, jnp.array([3.]))
  self.assertEqual(model.params, {'dummy': {'bias': jnp.array([1.])}})

def create_model(key, batch_size, image_size, model_dtype):
  input_shape = (batch_size, image_size, image_size, 3)
  module = models.ResNet.partial(num_classes=1000, dtype=model_dtype)
  with nn.stateful() as init_state:
    _, initial_params = module.init_by_shape(key,
                                             [(input_shape, model_dtype)])
    model = nn.Model(module, initial_params)
  return model, init_state

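# Hedged usage sketch for `create_model` above: in the pre-Linen Flax API the
# batch norm statistics live in the returned state collection and are threaded
# through `nn.stateful` at apply time. The batch of ones is a stand-in input.
model, state = create_model(random.PRNGKey(0), 8, 224, jnp.float32)
images = jnp.ones((8, 224, 224, 3), jnp.float32)
with nn.stateful(state) as new_state:
  logits = model(images)
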
def test_model_serialization_to_bytes(self):
  rng = random.PRNGKey(0)
  module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
  _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
  model = nn.Model(module, initial_params)
  serialized_bytes = serialization.to_bytes(model)
  restored_model = serialization.from_bytes(model, serialized_bytes)
  self.assertEqual(restored_model.params, model.params)