def test_freeze(self):
  d = {'a': 1, 'b': {'c': 2, 'd': 3}}
  dg = DotGetter(d)
  self.assertEqual(freeze(dg), freeze(d))
  fd = freeze({'a': 1, 'b': {'c': 2, 'd': 3}})
  fdg = DotGetter(fd)  # wrap the FrozenDict, mirroring dg above
  self.assertEqual(unfreeze(fdg), unfreeze(fd))
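
# A minimal sketch (not from the source) of the freeze/unfreeze round trip
# the test above relies on; assumes flax.core's freeze/unfreeze.
from flax.core import freeze, unfreeze

d = {'a': 1, 'b': {'c': 2, 'd': 3}}
fd = freeze(d)             # returns an immutable FrozenDict
assert unfreeze(fd) == d   # unfreeze returns a plain, mutable dict again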
def test_convert_pre_linen(self):
  params = checkpoints.convert_pre_linen({
      'mod_0': {
          'submod1_0': {},
          'submod2_1': {},
          'submod1_2': {},
      },
      'mod2_2': {'submod2_2_0': {}},
      'mod2_11': {'submod2_11_0': {}},
      'mod2_1': {'submod2_1_0': {}},
  })
  self.assertDictEqual(
      core.unfreeze(params), {
          'mod_0': {
              'submod1_0': {},
              'submod1_1': {},
              'submod2_0': {},
          },
          'mod2_0': {'submod2_1_0': {}},
          'mod2_1': {'submod2_2_0': {}},
          'mod2_2': {'submod2_11_0': {}},
      })
def _log_shape_and_norms(pytree, metrics_logger, key):
  shape_and_norms = jax.tree_map(
      lambda x: (str(x.shape), str(np.linalg.norm(x.reshape(-1)))),
      unfreeze(pytree))
  logging.info(json.dumps(shape_and_norms, sort_keys=True, indent=4))
  if metrics_logger is not None:
    metrics_logger.append_json_object({'key': key, 'value': shape_and_norms})
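
# Usage sketch (illustrative only; the pytree and key are hypothetical): log
# the shape and flattened L2 norm of every leaf of a params pytree. Passing
# metrics_logger=None keeps just the logging.info output.
_log_shape_and_norms({'dense': {'kernel': jnp.ones((2, 3))}},
                     metrics_logger=None, key='init_params')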
def test_attention(self):
  inputs = jnp.ones((2, 7, 16))
  model = partial(
      multi_head_dot_product_attention,
      num_heads=2,
      batch_axes=(0,),
      attn_fn=with_dropout(softmax_attn, 0.1, deterministic=False))
  rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
  y, variables = jax.jit(init(model))(rngs, inputs, inputs)
  variable_shapes = jax.tree_map(jnp.shape, variables['params'])
  self.assertEqual(y.shape, (2, 7, 16))
  self.assertEqual(
      unfreeze(variable_shapes), {
          'key': {'kernel': (2, 16, 8)},
          'value': {'kernel': (2, 16, 8)},
          'query': {'kernel': (2, 16, 8)},
          'out': {'bias': (2, 16), 'kernel': (2, 8, 16)},
      })
def test_auto_encoder_bind_method(self):
  ae = lambda scope, x: AutoEncoder3.create(
      scope, latents=2, features=4, hidden=3)(x)
  x = jnp.ones((1, 4))
  x_r, variables = init(ae)(random.PRNGKey(0), x)
  self.assertEqual(x.shape, x_r.shape)
  variable_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(
      variable_shapes, {
          'encode': {
              'hidden': {'kernel': (4, 3), 'bias': (3,)},
              'out': {'kernel': (3, 2), 'bias': (2,)},
          },
          'decode': {
              'hidden': {'kernel': (2, 3), 'bias': (3,)},
              'out': {'kernel': (3, 4), 'bias': (4,)},
          },
      })
def test_auto_encoder_hp_struct(self):
  ae = AutoEncoder(latents=2, features=4, hidden=3)
  x = jnp.ones((1, 4))
  x_r, variables = init(ae)(random.PRNGKey(0), x)
  self.assertEqual(x.shape, x_r.shape)
  variable_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(
      variable_shapes, {
          'encoder': {
              'hidden': {'kernel': (4, 3), 'bias': (3,)},
              'out': {'kernel': (3, 2), 'bias': (2,)},
          },
          'decoder': {
              'hidden': {'kernel': (2, 3), 'bias': (3,)},
              'out': {'kernel': (3, 4), 'bias': (4,)},
          },
      })
def update_GCNN_parity(params):
  """Adds biases of parity-flip layers to the corresponding no-flip layers.

  Corrects for changes in GCNN_parity due to PR #1030 in NetKet 3.3.

  Args:
    params: a parameter pytree
  """
  # unfreeze just in case, doesn't break with a plain dict
  params = flatten_dict(unfreeze(params))
  to_remove = []
  for path in params:
    if (len(path) > 1 and path[-2].startswith("equivariant_layers_flip")
        and path[-1] == "bias"):
      alt_path = (
          *path[:-2],
          path[-2].replace("equivariant_layers_flip", "equivariant_layers"),
          path[-1],
      )
      # fold the flip-layer bias into its no-flip counterpart
      params[alt_path] = params[alt_path] + params[path]
      to_remove.append(path)
  for path in to_remove:
    del params[path]
  return unflatten_dict(params)
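
# Illustrative only (toy pytree; the layer names just match the prefixes the
# function looks for): the flip-layer bias is folded into the matching
# no-flip layer and then removed.
import jax.numpy as jnp

toy = {
    'equivariant_layers_0': {'bias': jnp.ones(3)},
    'equivariant_layers_flip_0': {'bias': jnp.full(3, 2.0)},
}
merged = update_GCNN_parity(toy)
assert 'equivariant_layers_flip_0' not in merged
assert jnp.allclose(merged['equivariant_layers_0']['bias'], 3.0)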
def test_custom_vjp(self):
  x = random.normal(random.PRNGKey(0), (1, 4))
  y, variables = init(mlp_custom_grad)(random.PRNGKey(1), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  loss_fn = lambda p, x: jnp.mean(apply(mlp_custom_grad)(p, x) ** 2)
  grad = jax.grad(loss_fn)(variables, x)
  grad_shapes = unfreeze(jax.tree_map(jnp.shape, grad['params']))
  self.assertEqual(y.shape, (1, 1))
  expected_param_shapes = {
      'hidden_0': {'kernel': (4, 8), 'bias': (8,)},
      'out': {'kernel': (8, 1), 'bias': (1,)},
  }
  self.assertEqual(param_shapes, expected_param_shapes)
  self.assertEqual(grad_shapes, expected_param_shapes)
  for g in jax.tree_leaves(grad):
    self.assertTrue(np.all(g == np.sign(g)))
def test_explicit_dense(self):
  x = jnp.ones((1, 3))
  y, variables = init(explicit_mlp)(random.PRNGKey(0), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(y.shape, (1, 4))
  self.assertEqual(param_shapes, {
      'kernel': (3, 4),
      'bias': (4,),
  })
def test_init_from_decoder(self):
  ae = TiedAutoEncoder(latents=2, features=4)
  z = jnp.ones((1, ae.latents))
  x_r, variables = init(ae.decode)(random.PRNGKey(0), z)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(param_shapes, {
      'kernel': (4, 2),
  })
  self.assertEqual(x_r.shape, (1, 4))
def test_big_resnet(self):
  x = random.normal(random.PRNGKey(0), (1, 8, 8, 8))
  y, variables = init(big_resnet)(random.PRNGKey(1), x)
  self.assertEqual(y.shape, (1, 8, 8, 8))
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  batch_stats_shapes = unfreeze(
      jax.tree_map(jnp.shape, variables['batch_stats']))
  self.assertEqual(param_shapes, {
      'conv_1': {'kernel': (10, 5, 3, 3, 8, 8)},
      'conv_2': {'kernel': (10, 5, 3, 3, 8, 8)},
      'bn_1': {'scale': (10, 5, 8), 'bias': (10, 5, 8)},
      'bn_2': {'scale': (10, 5, 8), 'bias': (10, 5, 8)},
  })
  self.assertEqual(batch_stats_shapes, {
      'bn_1': {'var': (10, 5, 8), 'mean': (10, 5, 8)},
      'bn_2': {'var': (10, 5, 8), 'mean': (10, 5, 8)},
  })
def test_explicit_dense(self):
  x = jnp.ones((1, 4))
  y, variables = init(explicit_mlp)(random.PRNGKey(0), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(y.shape, (1, 1))
  self.assertEqual(
      param_shapes, {
          'dense_0': ExplicitDense((4, 3), (3,)),
          'dense_1': ExplicitDense((3, 1), (1,)),
      })
def test_tied_auto_encoder(self):
  ae = TiedAutoEncoder(latents=2, features=4)
  x = jnp.ones((1, ae.features))
  x_r, variables = init(ae)(random.PRNGKey(0), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(param_shapes, {
      'kernel': (4, 2),
  })
  self.assertEqual(x.shape, x_r.shape)
def load_keras_model(self, checkpoint, prng_key=None):
  # Create the Keras beta-VAE
  keras_nn = get_Neural_Network(1e-3, 'softplus', 'chi_sq')
  models, model_loss_function, reconstruction_loss_function = keras_nn

  # Load weights into the Keras model from a given checkpoint
  # base_dir = os.path.dirname(os.getcwd())
  # checkpoint = os.path.join(base_dir, 'data', f'epoch_{epoch}', 'Model')
  # assert os.path.exists(checkpoint), "Path does not exist : " + checkpoint
  models['vae'].load_weights(checkpoint)
  decoder_weights = [jnp.array(w) for w in models['decoder'].get_weights()]

  # Initialise
  if prng_key is None:
    prng_key = random.PRNGKey(42)
  init_data = jnp.ones((1, 1, self.zdim))
  params = self.init(prng_key, init_data)

  # Replace the initialised weights by the ones from Keras. The decoder
  # weight list is laid out as five entries per block: conv-transpose
  # kernel, then batch-norm scale, bias, moving mean and moving variance.
  unfrozen_params = unfreeze(params)
  for i in range(4):
    w = decoder_weights[5 * i:5 * (i + 1)]
    unfrozen_params['params'][f'ConvTranspose_{i}']['kernel'] = (
        np.swapaxes(w[0], 2, 3))
    unfrozen_params['params'][f'BatchNorm_{i}']['scale'] = w[1]
    unfrozen_params['params'][f'BatchNorm_{i}']['bias'] = w[2]
    unfrozen_params['batch_stats'][f'BatchNorm_{i}']['mean'] = w[3]
    unfrozen_params['batch_stats'][f'BatchNorm_{i}']['var'] = w[4]
  unfrozen_params['params']['ConvTranspose_4']['kernel'] = np.swapaxes(
      decoder_weights[20], 2, 3)
  self.params = freeze(unfrozen_params)
def dict_replace(col, target, leaf_only=True):
  col_flat = flatten_dict(unfreeze(col))
  diff = {}
  for keys_flat in col_flat.keys():
    for tar_key, tar_val in target.items():
      if (keys_flat[-1] == tar_key if leaf_only else tar_key in keys_flat):
        diff[keys_flat] = tar_val
  col_flat.update(diff)
  return unflatten_dict(col_flat)
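
# Usage sketch for dict_replace (toy shapes, not from the source): zero out
# every leaf named 'bias' anywhere in a nested parameter collection.
import jax.numpy as jnp

col = {'dense_0': {'kernel': jnp.ones((4, 3)), 'bias': jnp.ones(3)}}
col = dict_replace(col, {'bias': jnp.zeros(3)}, leaf_only=True)
assert (col['dense_0']['bias'] == 0).all()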
def load_pretrained(variables, url='', default_cfg=None, filter_fn=None):
  if not url:
    assert default_cfg is not None and default_cfg['url']
    url = default_cfg['url']
  state_dict = load_state_dict_from_url(url, transpose=True)

  source_params, source_state = split_state_dict(state_dict)
  if filter_fn is not None:
    # filter after split as we may have modified the split criteria
    # (i.e. bn running vars)
    source_params = filter_fn(source_params)
    source_state = filter_fn(source_state)

  # FIXME better way to do this?
  var_unfrozen = unfreeze(variables)
  missing_keys = []
  flat_params = flatten_dict(var_unfrozen['params'])
  flat_param_keys = set()
  for k, v in flat_params.items():
    flat_k = '.'.join(k)
    if flat_k in source_params:
      # check the source tensor against the model's expected shape
      assert source_params[flat_k].shape == v.shape
      flat_params[k] = source_params[flat_k]
    else:
      missing_keys.append(flat_k)
    flat_param_keys.add(flat_k)
  unexpected_keys = list(set(source_params.keys()).difference(flat_param_keys))
  params = freeze(unflatten_dict(flat_params))

  flat_state = flatten_dict(var_unfrozen['batch_stats'])
  flat_state_keys = set()
  for k, v in flat_state.items():
    flat_k = '.'.join(k)
    if flat_k in source_state:
      assert source_state[flat_k].shape == v.shape
      flat_state[k] = source_state[flat_k]
    else:
      missing_keys.append(flat_k)
    flat_state_keys.add(flat_k)
  unexpected_keys.extend(
      list(set(source_state.keys()).difference(flat_state_keys)))
  batch_stats = freeze(unflatten_dict(flat_state))

  if missing_keys:
    print(f' WARNING: {len(missing_keys)} keys missing while loading '
          f'state_dict. {str(missing_keys)}')
  if unexpected_keys:
    print(f' WARNING: {len(unexpected_keys)} unexpected keys found while '
          f'loading state_dict. {str(unexpected_keys)}')

  return dict(params=params, batch_stats=batch_stats)
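
# The key-matching idea above in isolation (illustrative only; toy dicts):
# flatten the nested Flax variables, join each path with '.', and look the
# joined name up in the flat source state_dict. Assumes flax.traverse_util.
from flax.traverse_util import flatten_dict, unflatten_dict

flat = flatten_dict({'stem': {'conv': {'kernel': 0}}})
source = {'stem.conv.kernel': 42}
for path in flat:
  name = '.'.join(path)
  if name in source:
    flat[path] = source[name]
assert unflatten_dict(flat)['stem']['conv']['kernel'] == 42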
def sparse_init(loss_fn,
                flax_module,
                params,
                hps,
                input_shape,
                output_shape,
                rng_key,
                metrics_logger=None,
                log_every=10):
  """Implements the SparseInit initializer.

  Args:
    loss_fn: Loss function.
    flax_module: Flax nn.Module class.
    params: The dict of model parameters.
    hps: HParam object. Required hparams are meta_learning_rate,
      meta_batch_size, meta_steps, and epsilon.
    input_shape: Must agree with batch[0].shape[1:].
    output_shape: Must agree with batch[1].shape[1:].
    rng_key: jax.PRNGKey, used to seed all randomness.
    metrics_logger: Instance of utils.MetricsLogger.
    log_every: Print meta loss every k steps.

  Returns:
    A Flax model with sparse initialization.
  """
  del flax_module, loss_fn, input_shape, output_shape, rng_key
  del metrics_logger, log_every
  params = unfreeze(params)
  activation_functions = hps.activation_function
  num_hidden_layers = len(hps.hid_sizes)
  if isinstance(hps.activation_function, str):
    activation_functions = [hps.activation_function] * num_hidden_layers
  for i, key in enumerate(params):
    num_units, num_weights = params[key]['kernel'].shape
    mask = np.zeros((num_units, num_weights), dtype=bool)
    for k in range(num_units):
      if num_weights >= hps.non_zero_connection_weights:
        sample = np.random.choice(
            num_weights, hps.non_zero_connection_weights, replace=False)
      else:
        sample = np.random.choice(num_weights,
                                  hps.non_zero_connection_weights)
      mask[k, sample] = True
    params[key]['kernel'] = params[key]['kernel'].at[~mask].set(0.0)
    if i < num_hidden_layers and activation_functions[i] == 'tanh':
      params[key]['bias'] = params[key]['bias'].at[:].set(0.5)
    else:
      params[key]['bias'] = params[key]['bias'].at[:].set(0.0)
  return frozen_dict.freeze(params)
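
# Standalone illustration of the per-unit sparse mask built above (toy sizes;
# nnz=2 stands in for hps.non_zero_connection_weights and is an arbitrary
# choice for the demo).
import numpy as np

num_units, num_weights, nnz = 4, 6, 2
mask = np.zeros((num_units, num_weights), dtype=bool)
for row in range(num_units):
  mask[row, np.random.choice(num_weights, nnz, replace=False)] = True
kernel = np.ones((num_units, num_weights))
kernel[~mask] = 0.0  # keep exactly nnz nonzero incoming weights per unit
assert (kernel.sum(axis=1) == nnz).all()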
def test_vmap_unshared(self):
  x = random.normal(random.PRNGKey(0), (1, 4))
  x = jnp.concatenate([x, x], 0)
  y, variables = init(mlp_vmap)(random.PRNGKey(1), x, share_params=False)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(param_shapes, {
      'hidden_0': {'kernel': (2, 4, 8), 'bias': (2, 8)},
      'out': {'kernel': (2, 8, 1), 'bias': (2, 1)},
  })
  self.assertEqual(y.shape, (2, 1))
  self.assertFalse(jnp.allclose(y[0], y[1]))
def test_flow(self):
  x = jnp.ones((1, 3))
  flow = StackFlow((DenseFlow(),) * 3)
  y, variables = init(flow.forward)(random.PRNGKey(0), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(y.shape, (1, 3))
  self.assertEqual(param_shapes, {
      '0': {'kernel': (3, 3), 'bias': (3,)},
      '1': {'kernel': (3, 3), 'bias': (3,)},
      '2': {'kernel': (3, 3), 'bias': (3,)},
  })
  x_restored = apply(flow.backward)(variables, y)
  self.assertTrue(jnp.allclose(x, x_restored))
def test_weight_std(self):
  x = random.normal(random.PRNGKey(0), (1, 4))
  y, variables = init(mlp)(random.PRNGKey(1), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(param_shapes, {
      'hidden_0': {'kernel': (4, 8), 'bias': (8,)},
      'out': {'kernel': (8, 1), 'bias': (1,)},
  })
  self.assertEqual(y.shape, (1, 1))
  self.assertTrue(y.ravel() < 1.)
  y2 = apply(mlp)(variables, x)
  self.assertTrue(jnp.allclose(y, y2))
def test_semi_explicit_dense(self):
  x = jnp.ones((1, 4))
  y, variables = init(semi_explicit_mlp)(random.PRNGKey(0), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(y.shape, (1, 1))
  self.assertEqual(
      param_shapes, {
          'dense_0': {'kernel': (4, 3), 'bias': (3,)},
          'dense_1': {'kernel': (3, 1), 'bias': (1,)},
      })
def derive_logical_axes(self, optimizer, param_logical_axes):
  """Returns the PartitionSpec pytree associated with the optimizer states.

  Args:
    optimizer: A flax.optim optimizer.
    param_logical_axes: Pytree of pjit.PartitionSpec associated with params.
  """
  assert self._hps.shard_optimizer_states
  optimizer_dict = optimizer.state_dict()
  optimizer_logical_axes = jax.tree_map(lambda x: None, optimizer_dict)
  optimizer_logical_axes['target'] = param_logical_axes

  init_state = self.distributed_shampoo.init(None)
  pspec_fn = init_state.pspec_fn
  optimizer_logical_axes['state']['param_states'] = pspec_fn(
      optimizer_dict['target'], param_logical_axes,
      self._hps.statistics_partition_spec)

  return optimizer.restore_state(unfreeze(optimizer_logical_axes))
def update_dense_symm(params, names=("dense_symm", "Dense")):
  """Updates DenseSymm kernels in pre-PR#1030 parameter pytrees to the new
  3D convention.

  Args:
    params: a parameter pytree
    names: layer names to search for; defaults to those used in RBMSymm
      and GCNN*
  """
  params = unfreeze(params)  # just in case, doesn't break with a plain dict

  def fix_one_kernel(args):
    path, array = args
    if (len(path) > 1 and path[-2] in names and path[-1] == "kernel"
        and array.ndim == 2):
      array = jnp.expand_dims(array, 1)
    return (path, array)

  return unflatten_dict(
      dict(map(fix_one_kernel, flatten_dict(params).items())))
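
# Illustrative call (toy 2D kernel, hypothetical layer name): a pre-PR#1030
# DenseSymm kernel stored as a 2D array gains the new singleton middle
# dimension expected by the 3D convention.
old = {'dense_symm': {'kernel': jnp.ones((4, 8))}}
new = update_dense_symm(old)
assert new['dense_symm']['kernel'].shape == (4, 1, 8)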
def load_keras_model(self, checkpoint, prng_key=None):
  # Create the Keras beta-VAE
  keras_nn = get_Neural_Network(1e-3, 'softplus', 'chi_sq')
  models, model_loss_function, reconstruction_loss_function = keras_nn

  # Load weights into the Keras model from the given checkpoint
  models['vae'].load_weights(checkpoint).expect_partial()
  encoder_weights = models['encoder'].get_weights()
  decoder_weights = models['decoder'].get_weights()

  # Recast as JAX device arrays to enable autodiff through the model
  encoder_weights = [jnp.array(w) for w in encoder_weights]
  decoder_weights = [jnp.array(w) for w in decoder_weights]

  # Initialise
  if prng_key is None:
    prng_key = random.PRNGKey(42)
  init_data = jnp.ones((self.xdim, self.zdim))
  key, subkey1, subkey2 = random.split(prng_key, 3)
  params = self.init(subkey1, init_data, z_rng=subkey2)
  unfrozen_params = unfreeze(params)

  # Replace the encoder weights: two stand-alone conv kernels, then three
  # (BatchNorm_i, Conv_{i+2}) pairs of five weights each, then two dense
  # layers.
  enc_p = unfrozen_params['params']['encoder']
  enc_s = unfrozen_params['batch_stats']['encoder']
  enc_p['Conv_0']['kernel'] = encoder_weights[0]
  enc_p['Conv_1']['kernel'] = encoder_weights[1]
  for i in range(3):
    w = encoder_weights[2 + 5 * i:2 + 5 * (i + 1)]
    enc_p[f'BatchNorm_{i}']['scale'] = w[0]
    enc_p[f'BatchNorm_{i}']['bias'] = w[1]
    enc_s[f'BatchNorm_{i}']['mean'] = w[2]
    enc_s[f'BatchNorm_{i}']['var'] = w[3]
    enc_p[f'Conv_{i + 2}']['kernel'] = w[4]
  enc_p['Dense_0']['kernel'] = encoder_weights[17]
  enc_p['Dense_0']['bias'] = encoder_weights[18]
  enc_p['Dense_1']['kernel'] = encoder_weights[19]
  enc_p['Dense_1']['bias'] = encoder_weights[20]

  # Replace the decoder weights: five entries per block (conv-transpose
  # kernel, then batch-norm scale, bias, moving mean, moving variance).
  dec_p = unfrozen_params['params']['decoder']
  dec_s = unfrozen_params['batch_stats']['decoder']
  for i in range(4):
    w = decoder_weights[5 * i:5 * (i + 1)]
    dec_p[f'ConvTranspose_{i}']['kernel'] = np.swapaxes(w[0], 2, 3)
    dec_p[f'BatchNorm_{i}']['scale'] = w[1]
    dec_p[f'BatchNorm_{i}']['bias'] = w[2]
    dec_s[f'BatchNorm_{i}']['mean'] = w[3]
    dec_s[f'BatchNorm_{i}']['var'] = w[4]
  dec_p['ConvTranspose_4']['kernel'] = np.swapaxes(decoder_weights[20], 2, 3)

  self.params = freeze(unfrozen_params)
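
# Why np.swapaxes(w, 2, 3) above (a sketch of the layout assumption): Keras
# Conv2DTranspose kernels are stored as (kh, kw, out_channels, in_channels),
# while Flax's ConvTranspose expects (kh, kw, in_channels, out_channels).
import numpy as np

keras_kernel = np.zeros((3, 3, 16, 8))     # (kh, kw, out, in) from Keras
flax_kernel = np.swapaxes(keras_kernel, 2, 3)
assert flax_kernel.shape == (3, 3, 8, 16)  # (kh, kw, in, out) for Flax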
def test_frozen_dict_pop(self):
  xs = {'a': 1, 'b': {'c': 2}}
  b, a = FrozenDict(xs).pop('a')
  self.assertEqual(a, 1)
  self.assertEqual(unfreeze(b), {'b': {'c': 2}})
def test_frozen_dict_copies(self):
  xs = {'a': 1, 'b': {'c': 2}}
  frozen = freeze(xs)
  xs['a'] += 1
  xs['b']['c'] += 1
  self.assertEqual(unfreeze(frozen), {'a': 1, 'b': {'c': 2}})
  conv = partial(nn.conv, bias=False, dtype=dtype)
  norm = partial(norm, dtype=dtype)

  x = scope.child(conv, 'init_conv')(x, 16, (7, 7), padding=((3, 3), (3, 3)))
  x = scope.child(norm, 'init_bn')(x)
  x = act(x)
  x = nn.max_pool(x, (2, 2), (2, 2), 'SAME')
  for i, size in enumerate(block_sizes):
    for j in range(size):
      strides = (1, 1)
      if i > 0 and j == 0:
        strides = (2, 2)
      block_features = features * 2 ** i
      block_scope = scope.push(f'block_{i}_{j}')
      x = residual_block(block_scope, x, conv, norm, act, block_features,
                         strides)
      # we can access parameters of the sub module by operating on the scope
      # Example:
      #   block_scope.get_kind('param')['conv_1']['kernel']
  x = jnp.mean(x, (1, 2))
  x = scope.child(nn.dense, 'out')(x, num_classes)
  return x


if __name__ == "__main__":
  x = random.normal(random.PRNGKey(0), (1, 224, 224, 3))
  y, params = init(resnet)(random.PRNGKey(1), x)
  print(y.shape)
  print(jax.tree_map(jnp.shape, unfreeze(params)))
def test_frozen_dict_maps(self):
  xs = {'a': 1, 'b': {'c': 2}}
  frozen = FrozenDict(xs)
  frozen2 = jax.tree_map(lambda x: x + x, frozen)
  self.assertEqual(unfreeze(frozen2), {'a': 2, 'b': {'c': 4}})
def test_frozen_dict_partially_maps(self):
  x = jax.tree_multimap(
      lambda a, b: (a, b),
      freeze({'a': 2}),
      freeze({'a': {'b': 1}}))
  self.assertEqual(unfreeze(x), {'a': (2, {'b': 1})})
        variable_carry='counter',
        variable_in_axes={'param': lift.broadcast},
        variable_out_axes={'param': lift.broadcast},
        split_rngs={'param': False})(scope, (), xs)
  else:
    carry, ys = lift.scan(
        body_fn,
        variable_carry='counter',
        variable_in_axes={'param': 0},
        variable_out_axes={'param': 0},
        split_rngs={'param': True})(scope, (), xs)
  # output layer
  return carry, ys


if __name__ == "__main__":
  x = random.normal(random.PRNGKey(0), (1, 4))
  x = jnp.concatenate([x, x], 0)

  print('unshared params: (outputs should be different, '
        'parameters have an extra dim)')
  y, variables = init(mlp_scan)(random.PRNGKey(1), x, share_params=False)
  print(y)
  print(unfreeze(variables))

  print('shared params: (outputs should be the same)')
  y, variables = init(mlp_scan)(random.PRNGKey(1), x, share_params=True)
  print(y)
  print(unfreeze(variables))