Example #1
 def test_freeze(self):
     d = {'a': 1, 'b': {'c': 2, 'd': 3}}
     dg = DotGetter(d)
     self.assertEqual(freeze(dg), freeze(d))
     fd = freeze({'a': 1, 'b': {'c': 2, 'd': 3}})
     fdg = DotGetter(d)
     self.assertEqual(unfreeze(fdg), unfreeze(fd))
Example #2
 def test_convert_pre_linen(self):
     params = checkpoints.convert_pre_linen({
         'mod_0': {
             'submod1_0': {},
             'submod2_1': {},
             'submod1_2': {},
         },
         'mod2_2': {
             'submod2_2_0': {}
         },
         'mod2_11': {
             'submod2_11_0': {}
         },
         'mod2_1': {
             'submod2_1_0': {}
         },
     })
     self.assertDictEqual(
         core.unfreeze(params), {
             'mod_0': {
                 'submod1_0': {},
                 'submod1_1': {},
                 'submod2_0': {},
             },
             'mod2_0': {
                 'submod2_1_0': {}
             },
             'mod2_1': {
                 'submod2_2_0': {}
             },
             'mod2_2': {
                 'submod2_11_0': {}
             },
         })
Example #3
def _log_shape_and_norms(pytree, metrics_logger, key):
  shape_and_norms = jax.tree_map(
      lambda x: (str(x.shape), str(np.linalg.norm(x.reshape(-1)))),
      unfreeze(pytree))
  logging.info(json.dumps(shape_and_norms, sort_keys=True, indent=4))
  if metrics_logger is not None:
    metrics_logger.append_json_object({'key': key, 'value': shape_and_norms})
Example #4
    def test_attention(self):
        inputs = jnp.ones((2, 7, 16))
        model = partial(multi_head_dot_product_attention,
                        num_heads=2,
                        batch_axes=(0, ),
                        attn_fn=with_dropout(softmax_attn,
                                             0.1,
                                             deterministic=False))

        rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
        y, variables = jax.jit(init(model))(rngs, inputs, inputs)
        variable_shapes = jax.tree_map(jnp.shape, variables['params'])
        self.assertEqual(y.shape, (2, 7, 16))
        self.assertEqual(
            unfreeze(variable_shapes), {
                'key': {
                    'kernel': (2, 16, 8)
                },
                'value': {
                    'kernel': (2, 16, 8)
                },
                'query': {
                    'kernel': (2, 16, 8)
                },
                'out': {
                    'bias': (2, 16),
                    'kernel': (2, 8, 16)
                },
            })
Example #5
    def test_auto_encoder_bind_method(self):
        ae = lambda scope, x: AutoEncoder3.create(
            scope, latents=2, features=4, hidden=3)(x)
        x = jnp.ones((1, 4))

        x_r, variables = init(ae)(random.PRNGKey(0), x)
        self.assertEqual(x.shape, x_r.shape)
        variable_shapes = unfreeze(jax.tree_map(jnp.shape,
                                                variables['params']))
        self.assertEqual(
            variable_shapes, {
                'encode': {
                    'hidden': {
                        'kernel': (4, 3),
                        'bias': (3, )
                    },
                    'out': {
                        'kernel': (3, 2),
                        'bias': (2, )
                    },
                },
                'decode': {
                    'hidden': {
                        'kernel': (2, 3),
                        'bias': (3, )
                    },
                    'out': {
                        'kernel': (3, 4),
                        'bias': (4, )
                    },
                },
            })
Example #6
 def test_auto_encoder_hp_struct(self):
     ae = AutoEncoder(latents=2, features=4, hidden=3)
     x = jnp.ones((1, 4))
     x_r, variables = init(ae)(random.PRNGKey(0), x)
     self.assertEqual(x.shape, x_r.shape)
     variable_shapes = unfreeze(jax.tree_map(jnp.shape,
                                             variables['params']))
     self.assertEqual(
         variable_shapes, {
             'encoder': {
                 'hidden': {
                     'kernel': (4, 3),
                     'bias': (3, )
                 },
                 'out': {
                     'kernel': (3, 2),
                     'bias': (2, )
                 },
             },
             'decoder': {
                 'hidden': {
                     'kernel': (2, 3),
                     'bias': (3, )
                 },
                 'out': {
                     'kernel': (3, 4),
                     'bias': (4, )
                 },
             },
         })
Example #7
def update_GCNN_parity(params):
    """Adds biases of parity-flip layers to the corresponding no-flip layers.
    Corrects for changes in GCNN_parity due to PR #1030 in NetKet 3.3.

    Args:
        params: a parameter pytree
    """
    # unfreeze just in case, doesn't break with a plain dict
    params = flatten_dict(unfreeze(params))
    to_remove = []
    for path in params:
        if (
            len(path) > 1
            and path[-2].startswith("equivariant_layers_flip")
            and path[-1] == "bias"
        ):
            alt_path = (
                *path[:-2],
                path[-2].replace("equivariant_layers_flip", "equivariant_layers"),
                path[-1],
            )
            params[alt_path] = params[alt_path] + params[path]
            to_remove.append(path)
    for path in to_remove:
        del params[path]
    return unflatten_dict(params)
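A minimal usage sketch of update_GCNN_parity on a hypothetical parameter tree (layer names and values are invented for illustration; flatten_dict, unflatten_dict and unfreeze are assumed to be the flax helpers used by the function above):

    import jax.numpy as jnp

    old_params = {
        "equivariant_layers_0": {"kernel": jnp.ones((4, 4)), "bias": jnp.ones(4)},
        "equivariant_layers_flip_0": {"kernel": jnp.ones((4, 4)), "bias": jnp.full(4, 2.0)},
    }
    new_params = update_GCNN_parity(old_params)
    # The flip-layer bias (2.0) is added to the matching no-flip bias (1.0), giving 3.0,
    # and the 'bias' entry is removed from 'equivariant_layers_flip_0'.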
Example #8
 def test_custom_vjp(self):
   x = random.normal(random.PRNGKey(0), (1, 4))
   y, variables = init(mlp_custom_grad)(random.PRNGKey(1), x)
   param_shapes = unfreeze(
       jax.tree_map(jnp.shape, variables['params']))
   loss_fn = lambda p, x: jnp.mean(apply(mlp_custom_grad)(p, x) ** 2)
   grad = jax.grad(loss_fn)(variables, x)
   grad_shapes = unfreeze(
       jax.tree_map(jnp.shape, grad['params']))
   self.assertEqual(y.shape, (1, 1))
   expected_param_shapes = {
       'hidden_0': {'kernel': (4, 8), 'bias': (8,)},
       'out': {'kernel': (8, 1), 'bias': (1,)},
   }
   self.assertEqual(param_shapes, expected_param_shapes)
   self.assertEqual(grad_shapes, expected_param_shapes)
   for g in jax.tree_leaves(grad):
     self.assertTrue(np.all(g == np.sign(g)))
Example #9
 def test_explicit_dense(self):
     x = jnp.ones((1, 3))
     y, variables = init(explicit_mlp)(random.PRNGKey(0), x)
     param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
     self.assertEqual(y.shape, (1, 4))
     self.assertEqual(param_shapes, {
         'kernel': (3, 4),
         'bias': (4, ),
     })
Example #10
    def test_init_from_decoder(self):
        ae = TiedAutoEncoder(latents=2, features=4)
        z = jnp.ones((1, ae.latents))
        x_r, variables = init(ae.decode)(random.PRNGKey(0), z)

        param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
        self.assertEqual(param_shapes, {
            'kernel': (4, 2),
        })
        self.assertEqual(x_r.shape, (1, 4))
Example #11
 def test_big_resnet(self):
   x = random.normal(random.PRNGKey(0), (1, 8, 8, 8))
   y, variables = init(big_resnet)(random.PRNGKey(1), x)
   self.assertEqual(y.shape, (1, 8, 8, 8))
   param_shapes = unfreeze(
       jax.tree_map(jnp.shape, variables['params']))
   batch_stats_shapes = unfreeze(
       jax.tree_map(jnp.shape, variables['batch_stats']))
   print(param_shapes)
   self.assertEqual(param_shapes, {
       'conv_1': {'kernel': (10, 5, 3, 3, 8, 8)},
       'conv_2': {'kernel': (10, 5, 3, 3, 8, 8)},
       'bn_1': {'scale': (10, 5, 8), 'bias': (10, 5, 8)},
       'bn_2': {'scale': (10, 5, 8), 'bias': (10, 5, 8)}
   })
   self.assertEqual(batch_stats_shapes, {
       'bn_1': {'var': (10, 5, 8), 'mean': (10, 5, 8)},
       'bn_2': {'var': (10, 5, 8), 'mean': (10, 5, 8)}
   })
Example #12
 def test_explicit_dense(self):
     x = jnp.ones((1, 4))
     y, variables = init(explicit_mlp)(random.PRNGKey(0), x)
     param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
     self.assertEqual(y.shape, (1, 1))
     self.assertEqual(
         param_shapes, {
             'dense_0': ExplicitDense((4, 3), (3, )),
             'dense_1': ExplicitDense((3, 1), (1, ))
         })
Example #13
    def test_tied_auto_encoder(self):
        ae = TiedAutoEncoder(latents=2, features=4)
        x = jnp.ones((1, ae.features))
        x_r, variables = init(ae)(random.PRNGKey(0), x)

        param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
        self.assertEqual(param_shapes, {
            'kernel': (4, 2),
        })
        self.assertEqual(x.shape, x_r.shape)
Example #14
    def load_keras_model(self, checkpoint, prng_key=None):
        # Create the Keras beta-VAE
        keras_nn = get_Neural_Network(1e-3, 'softplus', 'chi_sq')
        models, model_loss_function, reconstruction_loss_function = keras_nn
        # Load weights into keras model from a given checkpoint
        # base_dir = os.path.dirname(os.getcwd())
        # checkpoint = os.path.join(base_dir, 'data', f'epoch_{epoch}', 'Model')
        # assert os.path.exists(checkpoint), "Path does not exist : " + checkpoint
        models['vae'].load_weights(checkpoint)
        decoder_weights = models['decoder'].get_weights()
        decoder_weights = [jnp.array(w) for w in decoder_weights]

        # Initialise
        if prng_key is None:
            prng_key = random.PRNGKey(42)
        init_data = jnp.ones((1, 1, self.zdim))
        params = self.init(prng_key, init_data)

        # Replace weights by the ones from keras
        unfrozen_params = unfreeze(params)
        unfrozen_params['params']['ConvTranspose_0']['kernel'] = np.swapaxes(
            decoder_weights[0], 2, 3)
        unfrozen_params['params']['BatchNorm_0']['scale'] = decoder_weights[1]
        unfrozen_params['params']['BatchNorm_0']['bias'] = decoder_weights[2]
        unfrozen_params['batch_stats']['BatchNorm_0'][
            'mean'] = decoder_weights[3]
        unfrozen_params['batch_stats']['BatchNorm_0']['var'] = decoder_weights[
            4]
        unfrozen_params['params']['ConvTranspose_1']['kernel'] = np.swapaxes(
            decoder_weights[5], 2, 3)
        unfrozen_params['params']['BatchNorm_1']['scale'] = decoder_weights[6]
        unfrozen_params['params']['BatchNorm_1']['bias'] = decoder_weights[7]
        unfrozen_params['batch_stats']['BatchNorm_1'][
            'mean'] = decoder_weights[8]
        unfrozen_params['batch_stats']['BatchNorm_1']['var'] = decoder_weights[
            9]
        unfrozen_params['params']['ConvTranspose_2']['kernel'] = np.swapaxes(
            decoder_weights[10], 2, 3)
        unfrozen_params['params']['BatchNorm_2']['scale'] = decoder_weights[11]
        unfrozen_params['params']['BatchNorm_2']['bias'] = decoder_weights[12]
        unfrozen_params['batch_stats']['BatchNorm_2'][
            'mean'] = decoder_weights[13]
        unfrozen_params['batch_stats']['BatchNorm_2']['var'] = decoder_weights[
            14]
        unfrozen_params['params']['ConvTranspose_3']['kernel'] = np.swapaxes(
            decoder_weights[15], 2, 3)
        unfrozen_params['params']['BatchNorm_3']['scale'] = decoder_weights[16]
        unfrozen_params['params']['BatchNorm_3']['bias'] = decoder_weights[17]
        unfrozen_params['batch_stats']['BatchNorm_3'][
            'mean'] = decoder_weights[18]
        unfrozen_params['batch_stats']['BatchNorm_3']['var'] = decoder_weights[
            19]
        unfrozen_params['params']['ConvTranspose_4']['kernel'] = np.swapaxes(
            decoder_weights[20], 2, 3)
        self.params = freeze(unfrozen_params)
Example #15
def dict_replace(col, target, leaf_only=True):
    col_flat = flatten_dict(unfreeze(col))
    diff = {}
    for keys_flat in col_flat.keys():
        for tar_key, tar_val in target.items():
            if (keys_flat[-1] == tar_key if leaf_only else
                (tar_key in keys_flat)):
                diff[keys_flat] = tar_val
    col_flat.update(diff)
    col = unflatten_dict(col_flat)
    return col
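A usage sketch of dict_replace on a hypothetical tree, overwriting every leaf whose last key is 'bias' (names and shapes are made up for illustration):

    import jax.numpy as jnp

    col = {'dense_0': {'kernel': jnp.ones((3, 2)), 'bias': jnp.ones(2)},
           'dense_1': {'kernel': jnp.ones((2, 1)), 'bias': jnp.ones(1)}}
    new_col = dict_replace(col, {'bias': 0.0})
    # With leaf_only=True (the default) only leaves whose final key equals 'bias'
    # are replaced; leaf_only=False also matches 'bias' anywhere along the key path.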
Example #16
def load_pretrained(variables, url='', default_cfg=None, filter_fn=None):
    if not url:
        assert default_cfg is not None and default_cfg['url']
        url = default_cfg['url']
    state_dict = load_state_dict_from_url(url, transpose=True)

    source_params, source_state = split_state_dict(state_dict)
    if filter_fn is not None:
        # filter after split as we may have modified the split criteria (ie bn running vars)
        source_params = filter_fn(source_params)
        source_state = filter_fn(source_state)

    # FIXME better way to do this?
    var_unfrozen = unfreeze(variables)
    missing_keys = []
    flat_params = flatten_dict(var_unfrozen['params'])
    flat_param_keys = set()
    for k, v in flat_params.items():
        flat_k = '.'.join(k)
        if flat_k in source_params:
            assert source_params[flat_k].shape == v.shape  # pretrained tensor must match model shape
            flat_params[k] = source_params[flat_k]
        else:
            missing_keys.append(flat_k)
        flat_param_keys.add(flat_k)
    unexpected_keys = list(
        set(source_params.keys()).difference(flat_param_keys))
    params = freeze(unflatten_dict(flat_params))

    flat_state = flatten_dict(var_unfrozen['batch_stats'])
    flat_state_keys = set()
    for k, v in flat_state.items():
        flat_k = '.'.join(k)
        if flat_k in source_state:
            assert source_state[flat_k].shape == v.shape  # pretrained tensor must match model shape
            flat_state[k] = source_state[flat_k]
        else:
            missing_keys.append(flat_k)
        flat_state_keys.add(flat_k)
    unexpected_keys.extend(
        list(set(source_state.keys()).difference(flat_state_keys)))
    batch_stats = freeze(unflatten_dict(flat_state))

    if missing_keys:
        print(
            f' WARNING: {len(missing_keys)} keys missing while loading state_dict. {str(missing_keys)}'
        )
    if unexpected_keys:
        print(
            f' WARNING: {len(unexpected_keys)} unexpected keys found while loading state_dict. {str(unexpected_keys)}'
        )

    return dict(params=params, batch_stats=batch_stats)
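The flatten/copy/unflatten pattern above reduces to a short sketch (the 'stem.conv.kernel' key and the shapes are hypothetical; flatten_dict and unflatten_dict are the flax.traverse_util helpers):

    import jax.numpy as jnp
    from flax.core import freeze, unfreeze
    from flax.traverse_util import flatten_dict, unflatten_dict

    variables = freeze({'params': {'stem': {'conv': {'kernel': jnp.zeros((3, 3, 3, 8))}}}})
    source = {'stem.conv.kernel': jnp.ones((3, 3, 3, 8))}  # flat, dot-joined source keys

    flat = flatten_dict(unfreeze(variables)['params'])  # keys are tuples, e.g. ('stem', 'conv', 'kernel')
    for k in flat:
        flat_k = '.'.join(k)
        if flat_k in source:
            flat[k] = source[flat_k]
    params = freeze(unflatten_dict(flat))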
Example #17
def sparse_init(loss_fn,
                flax_module,
                params,
                hps,
                input_shape,
                output_shape,
                rng_key,
                metrics_logger=None,
                log_every=10):
    """Implements SparseInit initializer.

  Args:
    loss_fn: Loss function.
    flax_module: Flax nn.Module class.
    params: The dict of model parameters.
    hps: HParam object. Required hparams are activation_function, hid_sizes,
      and non_zero_connection_weights.
    input_shape: Must agree with batch[0].shape[1:].
    output_shape: Must agree with batch[1].shape[1:].
    rng_key: jax.PRNGKey, used to seed all randomness.
    metrics_logger: Instance of utils.MetricsLogger
    log_every: Print meta loss every k steps.

  Returns:
    A FrozenDict of sparsely initialized model parameters.
  """

    del flax_module, loss_fn, input_shape, output_shape, rng_key, metrics_logger, log_every

    params = unfreeze(params)
    activation_functions = hps.activation_function
    num_hidden_layers = len(hps.hid_sizes)
    if isinstance(hps.activation_function, str):
        activation_functions = [hps.activation_function] * num_hidden_layers
    for i, key in enumerate(params):
        num_units, num_weights = params[key]['kernel'].shape
        mask = np.zeros((num_units, num_weights), dtype=bool)
        for k in range(num_units):
            if num_weights >= hps.non_zero_connection_weights:
                sample = np.random.choice(num_weights,
                                          hps.non_zero_connection_weights,
                                          replace=False)
            else:
                sample = np.random.choice(num_weights,
                                          hps.non_zero_connection_weights)
            mask[k, sample] = True
        params[key]['kernel'] = params[key]['kernel'].at[~mask].set(0.0)
        if i < num_hidden_layers and activation_functions[i] == 'tanh':
            params[key]['bias'] = params[key]['bias'].at[:].set(0.5)
        else:
            params[key]['bias'] = params[key]['bias'].at[:].set(0.0)
    return frozen_dict.freeze(params)
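A minimal call sketch with a toy parameter tree and a hypothetical hps object (only the fields the function actually reads are filled in; the remaining arguments are passed as None because sparse_init discards them):

    from types import SimpleNamespace
    import jax.numpy as jnp

    hps = SimpleNamespace(activation_function='tanh', hid_sizes=[4],
                          non_zero_connection_weights=2)
    params = {'hidden_0': {'kernel': jnp.ones((4, 6)), 'bias': jnp.zeros(4)},
              'output': {'kernel': jnp.ones((1, 4)), 'bias': jnp.zeros(1)}}
    sparse_params = sparse_init(loss_fn=None, flax_module=None, params=params,
                                hps=hps, input_shape=None, output_shape=None,
                                rng_key=None)
    # Every kernel row keeps hps.non_zero_connection_weights randomly chosen entries;
    # hidden-layer biases become 0.5 (tanh), the output bias becomes 0.0.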
Example #18
  def test_vmap_unshared(self):
    x = random.normal(random.PRNGKey(0), (1, 4))
    x = jnp.concatenate([x, x], 0)

    y, variables = init(mlp_vmap)(random.PRNGKey(1), x, share_params=False)

    param_shapes = unfreeze(
        jax.tree_map(jnp.shape, variables['params']))
    self.assertEqual(param_shapes, {
        'hidden_0': {'kernel': (2, 4, 8), 'bias': (2, 8)},
        'out': {'kernel': (2, 8, 1), 'bias': (2, 1)},
    })
    self.assertEqual(y.shape, (2, 1))
    self.assertFalse(jnp.allclose(y[0], y[1]))
Example #19
 def test_flow(self):
   x = jnp.ones((1, 3))
   flow = StackFlow((DenseFlow(),) * 3)
   y, variables = init(flow.forward)(random.PRNGKey(0), x)
   param_shapes = unfreeze(
       jax.tree_map(jnp.shape, variables['params']))
   self.assertEqual(y.shape, (1, 3))
   self.assertEqual(param_shapes, {
       '0': {'kernel': (3, 3), 'bias': (3,)},
       '1': {'kernel': (3, 3), 'bias': (3,)},
       '2': {'kernel': (3, 3), 'bias': (3,)},
   })
   x_restored = apply(flow.backward)(variables, y)
   self.assertTrue(jnp.allclose(x, x_restored))
Example #20
  def test_weight_std(self):
    x = random.normal(random.PRNGKey(0), (1, 4,))
    y, variables = init(mlp)(random.PRNGKey(1), x)

    param_shapes = unfreeze(
        jax.tree_map(jnp.shape, variables['params']))
    self.assertEqual(param_shapes, {
        'hidden_0': {'kernel': (4, 8), 'bias': (8,)},
        'out': {'kernel': (8, 1), 'bias': (1,)},
    })
    self.assertEqual(y.shape, (1, 1))
    self.assertTrue(y.ravel() < 1.)

    y2 = apply(mlp)(variables, x)
    self.assertTrue(jnp.allclose(y, y2))
Example #21
 def test_semi_explicit_dense(self):
     x = jnp.ones((1, 4))
     y, variables = init(semi_explicit_mlp)(random.PRNGKey(0), x)
     param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
     self.assertEqual(y.shape, (1, 1))
     self.assertEqual(
         param_shapes, {
             'dense_0': {
                 'kernel': (4, 3),
                 'bias': (3, )
             },
             'dense_1': {
                 'kernel': (3, 1),
                 'bias': (1, )
             }
         })
Example #22
    def derive_logical_axes(self, optimizer, param_logical_axes):
        """Returns PartitionSpec associated with optimizer states.

    Args:
      optimizer: A flax.Optim optimizer.
      param_logical_axes: Pytree of pjit.PartitionSpec associated with params.
    """
        assert self._hps.shard_optimizer_states
        optimizer_dict = optimizer.state_dict()
        optimizer_logical_axes = jax.tree_map(lambda x: None,
                                              optimizer.state_dict())
        optimizer_logical_axes['target'] = param_logical_axes
        init_state = self.distributed_shampoo.init(None)
        pspec_fn = init_state.pspec_fn
        optimizer_logical_axes['state']['param_states'] = pspec_fn(
            optimizer_dict['target'], param_logical_axes,
            self._hps.statistics_partition_spec)
        return optimizer.restore_state(unfreeze(optimizer_logical_axes))
Example #23
def update_dense_symm(params, names=["dense_symm", "Dense"]):
    """Updates DenseSymm kernels in pre-PR#1030 parameter pytrees to the new
    3D convention.

    Args:
        params: a parameter pytree
        names: layer names to search for; defaults to those used in RBMSymm and GCNN*
    """
    params = unfreeze(params)  # just in case, doesn't break with a plain dict

    def fix_one_kernel(args):
        path, array = args
        if (len(path) > 1 and path[-2] in names and path[-1] == "kernel"
                and array.ndim == 2):
            array = jnp.expand_dims(array, 1)
        return (path, array)

    return unflatten_dict(
        dict(map(fix_one_kernel,
                 flatten_dict(params).items())))
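A usage sketch of update_dense_symm with a hypothetical 2D DenseSymm kernel; the kernel gains a singleton middle axis while other leaves are left unchanged:

    import jax.numpy as jnp

    old = {'dense_symm': {'kernel': jnp.ones((8, 4)), 'bias': jnp.zeros(4)}}
    new = update_dense_symm(old)
    # new['dense_symm']['kernel'].shape == (8, 1, 4); the bias keeps its shape.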
Example #24
    def load_keras_model(self, checkpoint, prng_key=None):
        # Create the Keras beta-VAE
        keras_nn = get_Neural_Network(1e-3, 'softplus', 'chi_sq')
        models, model_loss_function, reconstruction_loss_function = keras_nn
        # Load weights into keras model from the given checkpoint
        models['vae'].load_weights(checkpoint).expect_partial()
        encoder_weights = models['encoder'].get_weights()
        decoder_weights = models['decoder'].get_weights()
        # Recast as JAX device arrays to enable autodiff through the model
        encoder_weights = [jnp.array(w) for w in encoder_weights]
        decoder_weights = [jnp.array(w) for w in decoder_weights]

        # Initialise
        if prng_key is None:
            prng_key = random.PRNGKey(42)
        init_data = jnp.ones((self.xdim, self.zdim))
        key, subkey1, subkey2 = random.split(prng_key, 3)
        params = self.init(subkey1, init_data, z_rng=subkey2)

        # Replace encoder weights
        unfrozen_params = unfreeze(params)
        unfrozen_params['params']['encoder']['Conv_0'][
            'kernel'] = encoder_weights[0]
        unfrozen_params['params']['encoder']['Conv_1'][
            'kernel'] = encoder_weights[1]
        unfrozen_params['params']['encoder']['BatchNorm_0'][
            'scale'] = encoder_weights[2]
        unfrozen_params['params']['encoder']['BatchNorm_0'][
            'bias'] = encoder_weights[3]
        unfrozen_params['batch_stats']['encoder']['BatchNorm_0'][
            'mean'] = encoder_weights[4]
        unfrozen_params['batch_stats']['encoder']['BatchNorm_0'][
            'var'] = encoder_weights[5]
        unfrozen_params['params']['encoder']['Conv_2'][
            'kernel'] = encoder_weights[6]
        unfrozen_params['params']['encoder']['BatchNorm_1'][
            'scale'] = encoder_weights[7]
        unfrozen_params['params']['encoder']['BatchNorm_1'][
            'bias'] = encoder_weights[8]
        unfrozen_params['batch_stats']['encoder']['BatchNorm_1'][
            'mean'] = encoder_weights[9]
        unfrozen_params['batch_stats']['encoder']['BatchNorm_1'][
            'var'] = encoder_weights[10]
        unfrozen_params['params']['encoder']['Conv_3'][
            'kernel'] = encoder_weights[11]
        unfrozen_params['params']['encoder']['BatchNorm_2'][
            'scale'] = encoder_weights[12]
        unfrozen_params['params']['encoder']['BatchNorm_2'][
            'bias'] = encoder_weights[13]
        unfrozen_params['batch_stats']['encoder']['BatchNorm_2'][
            'mean'] = encoder_weights[14]
        unfrozen_params['batch_stats']['encoder']['BatchNorm_2'][
            'var'] = encoder_weights[15]
        unfrozen_params['params']['encoder']['Conv_4'][
            'kernel'] = encoder_weights[16]
        unfrozen_params['params']['encoder']['Dense_0'][
            'kernel'] = encoder_weights[17]
        unfrozen_params['params']['encoder']['Dense_0'][
            'bias'] = encoder_weights[18]
        unfrozen_params['params']['encoder']['Dense_1'][
            'kernel'] = encoder_weights[19]
        unfrozen_params['params']['encoder']['Dense_1'][
            'bias'] = encoder_weights[20]

        # Replace decoder weights
        unfrozen_params['params']['decoder']['ConvTranspose_0'][
            'kernel'] = np.swapaxes(decoder_weights[0], 2, 3)
        unfrozen_params['params']['decoder']['BatchNorm_0'][
            'scale'] = decoder_weights[1]
        unfrozen_params['params']['decoder']['BatchNorm_0'][
            'bias'] = decoder_weights[2]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_0'][
            'mean'] = decoder_weights[3]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_0'][
            'var'] = decoder_weights[4]
        unfrozen_params['params']['decoder']['ConvTranspose_1'][
            'kernel'] = np.swapaxes(decoder_weights[5], 2, 3)
        unfrozen_params['params']['decoder']['BatchNorm_1'][
            'scale'] = decoder_weights[6]
        unfrozen_params['params']['decoder']['BatchNorm_1'][
            'bias'] = decoder_weights[7]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_1'][
            'mean'] = decoder_weights[8]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_1'][
            'var'] = decoder_weights[9]
        unfrozen_params['params']['decoder']['ConvTranspose_2'][
            'kernel'] = np.swapaxes(decoder_weights[10], 2, 3)
        unfrozen_params['params']['decoder']['BatchNorm_2'][
            'scale'] = decoder_weights[11]
        unfrozen_params['params']['decoder']['BatchNorm_2'][
            'bias'] = decoder_weights[12]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_2'][
            'mean'] = decoder_weights[13]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_2'][
            'var'] = decoder_weights[14]
        unfrozen_params['params']['decoder']['ConvTranspose_3'][
            'kernel'] = np.swapaxes(decoder_weights[15], 2, 3)
        unfrozen_params['params']['decoder']['BatchNorm_3'][
            'scale'] = decoder_weights[16]
        unfrozen_params['params']['decoder']['BatchNorm_3'][
            'bias'] = decoder_weights[17]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_3'][
            'mean'] = decoder_weights[18]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_3'][
            'var'] = decoder_weights[19]
        unfrozen_params['params']['decoder']['ConvTranspose_4'][
            'kernel'] = np.swapaxes(decoder_weights[20], 2, 3)
        self.params = freeze(unfrozen_params)
Example #25
 def test_frozen_dict_pop(self):
     xs = {'a': 1, 'b': {'c': 2}}
     b, a = FrozenDict(xs).pop('a')
     self.assertEqual(a, 1)
     self.assertEqual(unfreeze(b), {'b': {'c': 2}})
Example #26
 def test_frozen_dict_copies(self):
     xs = {'a': 1, 'b': {'c': 2}}
     frozen = freeze(xs)
     xs['a'] += 1
     xs['b']['c'] += 1
     self.assertEqual(unfreeze(frozen), {'a': 1, 'b': {'c': 2}})
Example #27
    conv = partial(nn.conv, bias=False, dtype=dtype)
    norm = partial(norm, dtype=dtype)

    x = scope.child(conv, 'init_conv')(x, 16, (7, 7), padding=((3, 3), (3, 3)))
    x = scope.child(norm, 'init_bn')(x)
    x = act(x)
    x = nn.max_pool(x, (2, 2), (2, 2), 'SAME')

    for i, size in enumerate(block_sizes):
        for j in range(size):
            strides = (1, 1)
            if i > 0 and j == 0:
                strides = (2, 2)
            block_features = features * 2**i
            block_scope = scope.push(f'block_{i}_{j}')
            x = residual_block(block_scope, x, conv, norm, act, block_features,
                               strides)
            # we can access parameters of the sub module by operating on the scope
            # Example:
            # block_scope.get_kind('param')['conv_1']['kernel']
    x = jnp.mean(x, (1, 2))
    x = scope.child(nn.dense, 'out')(x, num_classes)
    return x


if __name__ == "__main__":
    x = random.normal(random.PRNGKey(0), (1, 224, 224, 3))
    y, params = init(resnet)(random.PRNGKey(1), x)
    print(y.shape)
    print(jax.tree_map(jnp.shape, unfreeze(params)))
Example #28
 def test_frozen_dict_maps(self):
     xs = {'a': 1, 'b': {'c': 2}}
     frozen = FrozenDict(xs)
     frozen2 = jax.tree_map(lambda x: x + x, frozen)
     self.assertEqual(unfreeze(frozen2), {'a': 2, 'b': {'c': 4}})
Example #29
 def test_frozen_dict_partially_maps(self):
     x = jax.tree_multimap(lambda a, b: (a, b), freeze({'a': 2}),
                           freeze({'a': {
                               'b': 1
                           }}))
     self.assertEqual(unfreeze(x), {'a': (2, {'b': 1})})
Example #30
                              variable_carry='counter',
                              variable_in_axes={'param': lift.broadcast},
                              variable_out_axes={'param': lift.broadcast},
                              split_rngs={'param': False})(scope, (), xs)
    else:
        carry, ys = lift.scan(body_fn,
                              variable_carry='counter',
                              variable_in_axes={'param': 0},
                              variable_out_axes={'param': 0},
                              split_rngs={'param': True})(scope, (), xs)

    # output layer
    return carry, ys


if __name__ == "__main__":
    x = random.normal(random.PRNGKey(0), (1, 4))
    x = jnp.concatenate([x, x], 0)

    print(
        'unshared params: (outputs should be different, parameters have an extra dim)'
    )
    y, variables = init(mlp_scan)(random.PRNGKey(1), x, share_params=False)
    print(y)
    print(unfreeze(variables))

    print('shared params: (outputs should be the same)')
    y, variables = init(mlp_scan)(random.PRNGKey(1), x, share_params=True)
    print(y)
    print(unfreeze(variables))