Example #1
    def test_toplevel_submodule_adoption_sharing(self):
        dense = functools.partial(nn.Dense, use_bias=False)

        class A(nn.Module):
            @nn.compact
            def __call__(self, x):
                return dense(2)(x)

        class B(nn.Module):
            a: nn.Module

            @nn.compact
            def __call__(self, x):
                return dense(2)(x) + self.a(x)

        class C(nn.Module):
            a: nn.Module
            b: nn.Module

            @nn.compact
            def __call__(self, x):
                return dense(2)(x) + self.b(x) + self.a(x)

        key = random.PRNGKey(0)
        x = jnp.ones((2, 2))
        a = A()
        b = B(a)
        c = C(a, b)
        p = c.init(key, x)
        var_shapes = jax.tree_util.tree_map(jnp.shape, p)
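        # The single A instance is shared by b and c, so its kernel appears only once, under 'a'.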
        ref_var_shapes = freeze({
            'params': {
                'Dense_0': {
                    'kernel': (2, 2),
                },
                'a': {
                    'Dense_0': {
                        'kernel': (2, 2),
                    },
                },
                'b': {
                    'Dense_0': {
                        'kernel': (2, 2),
                    },
                },
            },
        })
        self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
Example #2
 def test_partially_applied_module_constructor_transform(self):
   k = random.PRNGKey(0)
   x = jnp.ones((3, 4, 4))
   dense = partial(nn.Dense, use_bias=False)
   vmap_dense = nn.vmap(
       dense,
       variable_axes={'params': 0},
       split_rngs={'params': True})(4)
   init_vars = vmap_dense.init(k, x)
   init_vars_shapes = jax.tree_util.tree_map(jnp.shape, init_vars)
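   # variable_axes={'params': 0} stacks one independent (4, 4) kernel per mapped slice.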
   ref_var_shapes = freeze({
       'params': {
           'kernel': (3, 4, 4),
       },
   })
   self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes))
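Applying the mapped module then follows the usual pattern; a minimal usage sketch, reusing x, init_vars, and vmap_dense from above (the output shape follows from mapping Dense(4) over the leading axis of size 3):

    y = vmap_dense.apply(init_vars, x)  # one independent Dense(4) per leading slice
    assert y.shape == (3, 4, 4)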
Example #3
def train_full(
    update_fn,
    validation_fn,
    mask_update_fn,
    optimizer,
    state,
    max_epochs=1e4,
    convergence_args={
        "patience": 200,
        "delta": 1e-3
    },
    mask_update_args={
        "patience": 500,
        "delta": 1e-5,
        "periodicity": 200
    },
):

    logger = Logger()
    converged = Convergence(**convergence_args)
    update_mask = mask_scheduler(**mask_update_args)

    for epoch in range(int(max_epochs)):
        (optimizer, state), metrics, output = update_fn(optimizer, state)
        prediction, dt, theta, coeffs = output

        if epoch % 1000 == 0:
            print(f"Loss step {epoch}: {metrics['loss']}")

        if epoch % 25 == 0:
            val_metric = validation_fn(optimizer, state)
            metrics = {**metrics, "validation_metric": val_metric}
            logger.write(metrics, epoch)
            apply_sparsity, optimizer = update_mask(val_metric, epoch,
                                                    optimizer)

            if apply_sparsity:
                mask = mask_update_fn(theta, dt)
                state = freeze({"vars": {"LeastSquares_0": {"mask": mask}}})

        if converged(epoch, coeffs):
            mask = mask_update_fn(theta, dt)
            print(f"Converged at epoch {epoch} with mask {mask[:, None]}.")
            break

    logger.close()
    return optimizer, state
Example #4
    def test_toplevel_submodule_adoption_pytree_transform(self):
        class A(nn.Module):
            @nn.compact
            def __call__(self, c, x):
                counter = self.variable('counter', 'i', jnp.zeros, ())
                counter.value += 1
                x = nn.Dense(1)(x)
                return c, x

        class B(nn.Module):
            A: Any

            @nn.compact
            def __call__(self, c, x):
                return self.A['foo'](*self.A['bar'](c, x))

        As = {'foo': A(), 'bar': A()}
        b = nn.scan(B,
                    in_axes=0,
                    variable_carry='counter',
                    variable_broadcast='params',
                    split_rngs={'params': False})(As)

        key = random.PRNGKey(0)
        x = jnp.ones((10, 2))

        p = B(As).init(key, x, x)
        y, cntrs = b.apply(p, x, x, mutable='counter')
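        # init() on the unscanned B(As) ran each A once; the scan over 10 rows adds
        # 10 more calls per A, so each carried counter ends at 11.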
        ref_cntrs = freeze({
            'counter': {
                'A_bar': {
                    'i': jnp.array(11.0),
                },
                'A_foo': {
                    'i': jnp.array(11.0),
                },
            },
        })
        self.assertTrue(
            jax.tree_util.tree_all(
                jax.tree_util.tree_map(
                    lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7),
                    cntrs, ref_cntrs)))
Example #5
    def test_toplevel_submodule_adoption_pytree(self):
        class A(nn.Module):
            @nn.compact
            def __call__(self, c, x):
                counter = self.variable('counter', 'i', jnp.zeros, ())
                counter.value += 1
                x = nn.Dense(1)(x)
                return c, x

        class B(nn.Module):
            A: Any

            @nn.compact
            def __call__(self, c, x):
                return self.A['foo'](*self.A['bar'](c, x))

        As = {'foo': A(), 'bar': A()}
        b = B(As)

        key = random.PRNGKey(0)
        x = jnp.ones((2, 2))

        p = B(As).init(key, x, x)
        y, cntrs = b.apply(p, x, x, mutable='counter')
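        # init() ran each A once and apply() ran each once more, so each counter ends at 2.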
        ref_cntrs = freeze({
            'counter': {
                'A_bar': {
                    'i': jnp.array(2.0),
                },
                'A_foo': {
                    'i': jnp.array(2.0),
                },
            },
        })
        self.assertTrue(
            jax.tree_util.tree_all(
                jax.tree_util.tree_map(
                    lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7),
                    cntrs, ref_cntrs)))
Example #6
  def test_module_pass_in_closure(self):
    a = nn.Dense(2)
        
    class B(nn.Module):
      def setup(self):
        self.foo = a

      def __call__(self, x):
        return self.foo(x)

    variables = B().init(random.PRNGKey(0), jnp.ones((1,)))
    var_shapes = jax.tree_util.tree_map(jnp.shape, variables)
    ref_var_shapes = freeze({
      'params': {
          'foo': {
              'bias': (2,),
              'kernel': (1, 2),
          }
      },
    })
    self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
    self.assertEqual(a.name, None)
Example #7
    def test_toplevel_named_submodule_adoption(self):
        dense = functools.partial(nn.Dense, use_bias=False)

        class A(nn.Module):
            def setup(self):
                self.dense = dense(4)

            def __call__(self, x):
                return self.dense(x)

        class B(nn.Module):
            a: A

            def setup(self):
                self.proj = dense(6)

            def __call__(self, x):
                return self.proj(self.a(x))

        a = A(name='foo')
        b = B(a=a)
        k = jax.random.PRNGKey(0)
        x = jnp.zeros((5, 5))
        init_vars = b.init(k, x)
        var_shapes = jax.tree_util.tree_map(jnp.shape, init_vars)
        ref_var_shapes = freeze({
            'params': {
                'a': {
                    'dense': {
                        'kernel': (5, 4),
                    },
                },
                'proj': {
                    'kernel': (4, 6),
                },
            },
        })
        self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
Example #8
    def test_same(self, partial_model_pair, hilbert, dtype, machine_pow, skip):
        batch_size = 3

        model1 = partial_model_pair[0](hilbert, dtype, machine_pow)
        model2 = partial_model_pair[1](hilbert, dtype, machine_pow)

        key_spins, key_model = jax.random.split(jax.random.PRNGKey(0))
        spins = hilbert.random_state(key_spins, size=batch_size)
        variables = model2.init(key_model,
                                spins,
                                0,
                                method=model2._conditional)

        p1 = model1.apply(variables, spins, method=model1.conditionals)
        p2 = model2.apply(variables, spins, method=model2.conditionals)

        # Results from `FastARNN*.conditionals` should be the same as those from `ARNN*.conditionals`
        np.testing.assert_allclose(p2, p1)

        p3 = jnp.zeros_like(p1)
        params = variables["params"]
        cache = variables["cache"]
        for i in range(hilbert.size):
            variables = freeze({"params": params, "cache": cache})
            p_i, mutables = model2.apply(
                variables,
                spins,
                i,
                method=model2._conditional,
                mutable=["cache"],
            )
            cache = mutables["cache"]
            p3 = p3.at[:, i, :].set(p_i)

        # Results from `FastARNN*.conditional` should be the same as those from `ARNN*.conditionals`
        np.testing.assert_allclose(p3, p1)
Example #9
def convert_pre_linen(params):
    """Converts a pre-Linen parameter pytree.

  In pre-Linen API submodules were numbered incrementally, independent of the
  submodule class. With Linen this behavior has changed to keep separate
  submodule counts per module class.

  Consider the following module:

    class Model(nn.Module):
      @nn.compact
      def __call__(self, x):
        x = nn.Conv(1, 1)(x)
        x = nn.Dense(1)(x)
        return x

  In pre-Linen the resulting params would have had the structure:
    {'Conv_0': { ... }, 'Dense_1': { ... } }

  With Linen the resulting params would instead have had the structure:
    {'Conv_0': { ... }, 'Dense_0': { ... } }

  To convert from pre-Linen format to Linen simply call:

    params = convert_pre_linen(pre_linen_params)

  Note that you can also use this utility to convert pre-Linen collections
  because they're following the same module naming. Note though that collections
  were "flat" in pre-Linen and first need to be unflattened before they can be
  used with this function:

    batch_stats = convert_pre_linen(flax.traverse_util.unflatten_dict({
        tuple(k.split('/')[1:]): v
        for k, v in pre_linen_model_state.as_dict().items()
    }))

  Then Linen variables can be defined from these converted collections:

    variables = {'params': params, 'batch_stats': batch_stats}

  Args:
    params: Parameter pytree in pre-Linen format. If the pytree is already in
      Linen format, then the returned pytree is unchanged (i.e. this function
      can safely be called on any loaded checkpoint for use with Linen).

  Returns:
    Parameter pytree with Linen submodule naming.
  """
    if not isinstance(params, (dict, core.FrozenDict)):
        return params
    params_renamed = {}
    counts = {}
    names = natural_sort(params.keys())
    for name in names:
        value = params[name]
        match = MODULE_NUM_RE.match(name)
        if match:
            module = match.group(1)
            num = counts.get(module, 0)
            name = f'{module}_{num}'
            counts[module] = num + 1
        params_renamed[name] = convert_pre_linen(value)

    return core.freeze(params_renamed)
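A minimal sketch of the renaming on a toy pytree (values are hypothetical; the behavior follows the docstring above):

    pre_linen = {'Conv_0': {'kernel': 1}, 'Dense_1': {'kernel': 2}}
    convert_pre_linen(pre_linen)
    # roughly: FrozenDict({'Conv_0': {'kernel': 1}, 'Dense_0': {'kernel': 2}})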
Example #10
 def wrap_params(variables):
     return freeze({"params": variables})
Example #11
 def init(self, keys, inpt):
     _, variables = self.ifun(keys["params"], inpt.shape)
     return freeze({"params": variables})
Example #12
 def test_eq(self):
     dg1 = DotGetter({'a': 1, 'b': {'c': 2, 'd': 3}})
     dg2 = DotGetter({'a': 1, 'b': {'c': 2, 'd': 3}})
     self.assertEqual(dg1, dg2)
     self.assertEqual(freeze(dg1), dg2)
     self.assertEqual(freeze(dg1), freeze(dg2))
Example #13
def update_varcolbykey(var, col_name, target, leaf_only=True):
    wocol, col = var.pop(col_name)
    col = dict_replace(col, target, leaf_only=leaf_only)
    del var
    return freeze({**wocol, col_name: col})
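Note that FrozenDict.pop does not mutate: it returns the remaining dict together with the popped collection, which is why the function above rebuilds and re-freezes the result. A minimal sketch:

    rest, cache = freeze({'params': {'w': 1}, 'cache': {'c': 2}}).pop('cache')
    # rest  == FrozenDict({'params': {'w': 1}})
    # cache == FrozenDict({'c': 2})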
Example #14
    def load_keras_model(self, checkpoint, prng_key=None):
        # Create the Keras beta-VAE
        keras_nn = get_Neural_Network(1e-3, 'softplus', 'chi_sq')
        models, model_loss_function, reconstruction_loss_function = keras_nn
        # Load weights into keras model from the given checkpoint
        models['vae'].load_weights(checkpoint).expect_partial()
        encoder_weights = models['encoder'].get_weights()
        decoder_weights = models['decoder'].get_weights()
        # Recast as JAX device arrays to enable autodiff through the model
        encoder_weights = [jnp.array(w) for w in encoder_weights]
        decoder_weights = [jnp.array(w) for w in decoder_weights]

        # Initialise
        if prng_key is None:
            prng_key = random.PRNGKey(42)
        init_data = jnp.ones((self.xdim, self.zdim))
        key, subkey1, subkey2 = random.split(prng_key, 3)
        params = self.init(subkey1, init_data, z_rng=subkey2)

        # Replace encoder weights
        unfrozen_params = unfreeze(params)
        unfrozen_params['params']['encoder']['Conv_0'][
            'kernel'] = encoder_weights[0]
        unfrozen_params['params']['encoder']['Conv_1'][
            'kernel'] = encoder_weights[1]
        unfrozen_params['params']['encoder']['BatchNorm_0'][
            'scale'] = encoder_weights[2]
        unfrozen_params['params']['encoder']['BatchNorm_0'][
            'bias'] = encoder_weights[3]
        unfrozen_params['batch_stats']['encoder']['BatchNorm_0'][
            'mean'] = encoder_weights[4]
        unfrozen_params['batch_stats']['encoder']['BatchNorm_0'][
            'var'] = encoder_weights[5]
        unfrozen_params['params']['encoder']['Conv_2'][
            'kernel'] = encoder_weights[6]
        unfrozen_params['params']['encoder']['BatchNorm_1'][
            'scale'] = encoder_weights[7]
        unfrozen_params['params']['encoder']['BatchNorm_1'][
            'bias'] = encoder_weights[8]
        unfrozen_params['batch_stats']['encoder']['BatchNorm_1'][
            'mean'] = encoder_weights[9]
        unfrozen_params['batch_stats']['encoder']['BatchNorm_1'][
            'var'] = encoder_weights[10]
        unfrozen_params['params']['encoder']['Conv_3'][
            'kernel'] = encoder_weights[11]
        unfrozen_params['params']['encoder']['BatchNorm_2'][
            'scale'] = encoder_weights[12]
        unfrozen_params['params']['encoder']['BatchNorm_2'][
            'bias'] = encoder_weights[13]
        unfrozen_params['batch_stats']['encoder']['BatchNorm_2'][
            'mean'] = encoder_weights[14]
        unfrozen_params['batch_stats']['encoder']['BatchNorm_2'][
            'var'] = encoder_weights[15]
        unfrozen_params['params']['encoder']['Conv_4'][
            'kernel'] = encoder_weights[16]
        unfrozen_params['params']['encoder']['Dense_0'][
            'kernel'] = encoder_weights[17]
        unfrozen_params['params']['encoder']['Dense_0'][
            'bias'] = encoder_weights[18]
        unfrozen_params['params']['encoder']['Dense_1'][
            'kernel'] = encoder_weights[19]
        unfrozen_params['params']['encoder']['Dense_1'][
            'bias'] = encoder_weights[20]

        # Replace decoder weights
        unfrozen_params['params']['decoder']['ConvTranspose_0'][
            'kernel'] = np.swapaxes(decoder_weights[0], 2, 3)
        unfrozen_params['params']['decoder']['BatchNorm_0'][
            'scale'] = decoder_weights[1]
        unfrozen_params['params']['decoder']['BatchNorm_0'][
            'bias'] = decoder_weights[2]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_0'][
            'mean'] = decoder_weights[3]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_0'][
            'var'] = decoder_weights[4]
        unfrozen_params['params']['decoder']['ConvTranspose_1'][
            'kernel'] = np.swapaxes(decoder_weights[5], 2, 3)
        unfrozen_params['params']['decoder']['BatchNorm_1'][
            'scale'] = decoder_weights[6]
        unfrozen_params['params']['decoder']['BatchNorm_1'][
            'bias'] = decoder_weights[7]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_1'][
            'mean'] = decoder_weights[8]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_1'][
            'var'] = decoder_weights[9]
        unfrozen_params['params']['decoder']['ConvTranspose_2'][
            'kernel'] = np.swapaxes(decoder_weights[10], 2, 3)
        unfrozen_params['params']['decoder']['BatchNorm_2'][
            'scale'] = decoder_weights[11]
        unfrozen_params['params']['decoder']['BatchNorm_2'][
            'bias'] = decoder_weights[12]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_2'][
            'mean'] = decoder_weights[13]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_2'][
            'var'] = decoder_weights[14]
        unfrozen_params['params']['decoder']['ConvTranspose_3'][
            'kernel'] = np.swapaxes(decoder_weights[15], 2, 3)
        unfrozen_params['params']['decoder']['BatchNorm_3'][
            'scale'] = decoder_weights[16]
        unfrozen_params['params']['decoder']['BatchNorm_3'][
            'bias'] = decoder_weights[17]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_3'][
            'mean'] = decoder_weights[18]
        unfrozen_params['batch_stats']['decoder']['BatchNorm_3'][
            'var'] = decoder_weights[19]
        unfrozen_params['params']['decoder']['ConvTranspose_4'][
            'kernel'] = np.swapaxes(decoder_weights[20], 2, 3)
        self.params = freeze(unfrozen_params)
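The np.swapaxes calls account for a kernel layout difference: Keras Conv2DTranspose stores kernels as (H, W, out_features, in_features), while Flax ConvTranspose expects (H, W, in_features, out_features). A minimal sketch with illustrative shapes:

    w_keras = np.zeros((3, 3, 8, 16))    # Keras Conv2DTranspose: (H, W, out, in)
    w_flax = np.swapaxes(w_keras, 2, 3)  # Flax ConvTranspose: (H, W, in, out)
    assert w_flax.shape == (3, 3, 16, 8)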
Example #15
  def test_toplevel_submodule_adoption_transform(self):
    class A(nn.Module):
      @nn.compact
      def __call__(self, x):
        return nn.Dense(3)(x)
    class B(nn.Module):
      A: nn.Module
      @nn.compact
      def __call__(self, x):
        return self.A(x)
    class C(nn.Module):
      A: nn.Module
      B: nn.Module
      @partial(
          nn.vmap,
          variable_axes={'params': 0},
          split_rngs={'params': True})
      @nn.compact
      def __call__(self, x):
        return self.B(x) + self.A(x)
    class Csimple(nn.Module):
      A: nn.Module
      B: nn.Module
      @nn.compact
      def __call__(self, x):
        return self.B(x) + self.A(x)
    class D(nn.Module):
      @nn.compact
      def __call__(self, x):
        a1 = A()
        a2 = A()
        b = B(a1)
        c = C(a2, b)
        return c(x)

    key = random.PRNGKey(0)
    x = jnp.ones((10, 10))
    p1 = D().init(key, x)
    y1 = D().apply(p1, x)

    a1 = A()
    a2 = A()
    b = B(a1)
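    # Rewire the params from D's auto-named inner modules (A_0, A_1) to the
    # layout expected when the same modules are adopted at the top level.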
    p2 = freeze({'params': {
        'A': p1['params']['A_0'],
        'B': {
            'A': p1['params']['A_1'],
        }
    }})

    # Test method wrapper transform.
    y2 = C(a2, b).apply(p2, x)
    np.testing.assert_allclose(y1, y2, atol=1e-7)
    # Test class transform.
    Ctrafo = nn.vmap(Csimple,
                     variable_axes={'params': 0},
                     split_rngs={'params': True})

    y3 = Ctrafo(a2, b).apply(p2, x)
    np.testing.assert_allclose(y1, y3, atol=1e-7)
Example #16
 def init(self, rng, *args, **kwargs):
     variables = self.transformed.init(rng["params"], *args, **kwargs)
     return freeze({"params": variables})
Example #17
def unflatten_params(flat_params):
    return freeze(
        traverse_util.unflatten_dict(
            {tuple(k.split("/")): v
             for k, v in flat_params.items()}))
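A round-trip sketch with toy values (assumes '/'-joined keys, as consumed by the function above):

    unflatten_params({'dense/kernel': 1, 'dense/bias': 2})
    # FrozenDict({'dense': {'kernel': 1, 'bias': 2}})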
Example #18
    def test_frozen_items(self):
        xs = {'a': 1, 'b': {'c': 2}}
        items = list(freeze(xs).items())

        self.assertEqual(items, [('a', 1), ('b', freeze(xs['b']))])
Example #19
 def test_frozen_dict_hash(self):
     xs = {'a': 1, 'b': {'c': 2}}
     ys = {'a': 1, 'b': {'c': 3}}
     self.assertNotEqual(hash(freeze(xs)), hash(freeze(ys)))
Example #20
 def test_frozen_dict_partially_maps(self):
     x = jax.tree_util.tree_map(lambda a, b: (a, b), freeze({'a': 2}),
                                freeze({'a': {'b': 1}}))
     self.assertEqual(unfreeze(x), {'a': (2, {'b': 1})})
Example #21
 def test_frozen_dict_copies(self):
     xs = {'a': 1, 'b': {'c': 2}}
     frozen = freeze(xs)
     xs['a'] += 1
     xs['b']['c'] += 1
     self.assertEqual(unfreeze(frozen), {'a': 1, 'b': {'c': 2}})
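Taken together, these tests pin down the FrozenDict contract that freeze relies on. A minimal sketch of the core behaviors, assuming flax.core.freeze and unfreeze:

    from flax.core import freeze, unfreeze

    fd = freeze({'a': 1, 'b': {'c': 2}})  # behaves as a copy: later mutation of the source is not visible
    assert unfreeze(fd) == {'a': 1, 'b': {'c': 2}}  # round-trips back to a plain nested dict
    # fd['a'] = 2  # would raise: FrozenDict is immutable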