Example #1
    def test_mutate_undefined_collection(self):
        def f(scope):
            scope.put_variable('state', 'test', 123)

        msg = r'Cannot update variable "test" in "/" because collection "state" is immutable.'
        with self.assertRaisesRegex(errors.ModifyScopeVariableError, msg):
            init(f, mutable='params')(random.PRNGKey(0))
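A minimal sketch (not from the source test) of the same mutation succeeding once the 'state' collection is included in the mutable filter passed to init:

    def f(scope):
        scope.put_variable('state', 'test', 123)

    # assumes the flax.core init API used throughout these examples;
    # with 'state' mutable, the variable ends up in the returned collections
    _, variables = init(f, mutable=['params', 'state'])(random.PRNGKey(0))
    assert variables['state']['test'] == 123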
Example #2
    def test_jit_cache(self):
        compiles = 0

        @lift.jit
        def f(scope, x):
            nonlocal compiles
            compiles += 1
            if scope.is_mutable_collection(
                    'intermediates'
            ) and not scope.is_mutable_collection('params'):
                scope.put_variable('intermediates', 'x', x + 1)
            return nn.dense(scope, x, 1)

        x = np.ones((3, 2))
        _, params = init(f)(random.PRNGKey(0), x)
        init(f)(random.PRNGKey(0), x)
        self.assertEqual(compiles, 1)
        apply(f)(params, x)
        self.assertEqual(compiles, 2)  # apply should cause a compile
        apply(f)(params, x)
        self.assertEqual(compiles, 2)  # applying again should not
        # edge case where only the implicit return of the jitted function changes.
        # this should not reuse the previously cached apply.
        _, state = apply(f, mutable='intermediates')(params, x)
        self.assertEqual(compiles, 3)  # expected: a different mutable filter forces a recompile
        self.assertEqual(state['intermediates']['x'].sum(), 3 * 2 * 2)
Example #3
 def test_rng(self):
   def f(scope):
     self.assertTrue(scope.has_rng('params'))
     self.assertFalse(scope.has_rng('dropout'))
     rng = scope.make_rng('params')
     self.assertTrue(np.all(rng == random.fold_in(random.PRNGKey(0), 1)))
   init(f)(random.PRNGKey(0))
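A small sketch (an assumption, mirroring the dict-of-PRNGKeys usage in Example #8) of providing more than one RNG stream to init:

    def g(scope):
        assert scope.has_rng('params')
        assert scope.has_rng('dropout')

    init(g)({'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)})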
Example #4
    def test_mutate_undefined_collection(self):
        def f(scope):
            scope.put_variable('state', 'test', 123)

        msg = 'Trying to update variable "test" in "/" but collection "state" is immutable.'
        with self.assertRaisesWithLiteralMatch(ValueError, msg):
            init(f, mutable='params')(random.PRNGKey(0))
Example #5
  def test_aliasing(self):
    def f(scope):
      a = scope.push('a')

      def g(scopes, _):
        scope, a = scopes
        self.assertEqual(a.parent, scope)
      
      lift.vmap(g, variable_axes={}, split_rngs={})((scope, a), jnp.ones((1,)))

    init(f)(random.PRNGKey(0))
Example #6
    def test_auto_encoder_bind_method(self):
        ae = lambda scope, x: AutoEncoder3.create(
            scope, latents=2, features=4, hidden=3)(x)
        x = jnp.ones((1, 4))

        x_r, variables = init(ae)(random.PRNGKey(0), x)
        self.assertEqual(x.shape, x_r.shape)
        variable_shapes = unfreeze(jax.tree_map(jnp.shape,
                                                variables['params']))
        self.assertEqual(
            variable_shapes, {
                'encode': {
                    'hidden': {
                        'kernel': (4, 3),
                        'bias': (3, )
                    },
                    'out': {
                        'kernel': (3, 2),
                        'bias': (2, )
                    },
                },
                'decode': {
                    'hidden': {
                        'kernel': (2, 3),
                        'bias': (3, )
                    },
                    'out': {
                        'kernel': (3, 4),
                        'bias': (4, )
                    },
                },
            })
Example #7
 def test_auto_encoder_hp_struct(self):
     ae = AutoEncoder(latents=2, features=4, hidden=3)
     x = jnp.ones((1, 4))
     x_r, variables = init(ae)(random.PRNGKey(0), x)
     self.assertEqual(x.shape, x_r.shape)
     variable_shapes = unfreeze(jax.tree_map(jnp.shape,
                                             variables['params']))
     self.assertEqual(
         variable_shapes, {
             'encoder': {
                 'hidden': {
                     'kernel': (4, 3),
                     'bias': (3, )
                 },
                 'out': {
                     'kernel': (3, 2),
                     'bias': (2, )
                 },
             },
             'decoder': {
                 'hidden': {
                     'kernel': (2, 3),
                     'bias': (3, )
                 },
                 'out': {
                     'kernel': (3, 4),
                     'bias': (4, )
                 },
             },
         })
Example #8
    def test_attention(self):
        inputs = jnp.ones((2, 7, 16))
        model = partial(multi_head_dot_product_attention,
                        num_heads=2,
                        batch_axes=(0, ),
                        attn_fn=with_dropout(softmax_attn,
                                             0.1,
                                             deterministic=False))

        rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
        y, variables = jax.jit(init(model))(rngs, inputs, inputs)
        variable_shapes = jax.tree_map(jnp.shape, variables['params'])
        self.assertEqual(y.shape, (2, 7, 16))
        self.assertEqual(
            unfreeze(variable_shapes), {
                'key': {
                    'kernel': (2, 16, 8)
                },
                'value': {
                    'kernel': (2, 16, 8)
                },
                'query': {
                    'kernel': (2, 16, 8)
                },
                'out': {
                    'bias': (2, 16),
                    'kernel': (2, 8, 16)
                },
            })
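A hypothetical continuation (not part of the test): because the attention was built with deterministic=False, applying the initialized model needs a 'dropout' RNG again:

        y2 = apply(model)(variables, inputs, inputs,
                          rngs={'dropout': random.PRNGKey(2)})
        assert y2.shape == (2, 7, 16)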
Example #9
 def test_explicit_dense(self):
     x = jnp.ones((1, 3))
     y, variables = init(explicit_mlp)(random.PRNGKey(0), x)
     param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
     self.assertEqual(y.shape, (1, 4))
     self.assertEqual(param_shapes, {
         'kernel': (3, 4),
         'bias': (4, ),
     })
Example #10
    def test_tied_auto_encoder(self):
        ae = TiedAutoEncoder(latents=2, features=4)
        x = jnp.ones((1, ae.features))
        x_r, variables = init(ae)(random.PRNGKey(0), x)

        param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
        self.assertEqual(param_shapes, {
            'kernel': (4, 2),
        })
        self.assertEqual(x.shape, x_r.shape)
Example #11
    def test_init_from_decoder(self):
        ae = TiedAutoEncoder(latents=2, features=4)
        z = jnp.ones((1, ae.latents))
        x_r, variables = init(ae.decode)(random.PRNGKey(0), z)

        param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
        self.assertEqual(param_shapes, {
            'kernel': (4, 2),
        })
        self.assertEqual(x_r.shape, (1, 4))
Example #12
 def test_explicit_dense(self):
     x = jnp.ones((1, 4))
     y, variables = init(explicit_mlp)(random.PRNGKey(0), x)
     param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
     self.assertEqual(y.shape, (1, 1))
     self.assertEqual(
         param_shapes, {
             'dense_0': ExplicitDense((4, 3), (3, )),
             'dense_1': ExplicitDense((3, 1), (1, ))
         })
Example #13
 def test_flow(self):
   x = jnp.ones((1, 3))
   flow = StackFlow((DenseFlow(),) * 3)
   y, variables = init(flow.forward)(random.PRNGKey(0), x)
   param_shapes = unfreeze(
       jax.tree_map(jnp.shape, variables['params']))
   self.assertEqual(y.shape, (1, 3))
   self.assertEqual(param_shapes, {
       '0': {'kernel': (3, 3), 'bias': (3,)},
       '1': {'kernel': (3, 3), 'bias': (3,)},
       '2': {'kernel': (3, 3), 'bias': (3,)},
   })
   x_restored = apply(flow.backward)(variables, y)
   self.assertTrue(jnp.allclose(x, x_restored))
Example #14
  def test_vmap_unshared(self):
    x = random.normal(random.PRNGKey(0), (1, 4))
    x = jnp.concatenate([x, x], 0)

    y, variables = init(mlp_vmap)(random.PRNGKey(1), x, share_params=False)

    param_shapes = unfreeze(
        jax.tree_map(jnp.shape, variables['params']))
    self.assertEqual(param_shapes, {
        'hidden_0': {'kernel': (2, 4, 8), 'bias': (2, 8)},
        'out': {'kernel': (2, 8, 1), 'bias': (2, 1)},
    })
    self.assertEqual(y.shape, (2, 1))
    self.assertFalse(jnp.allclose(y[0], y[1]))
Example #15
  def test_weight_std(self):
    x = random.normal(random.PRNGKey(0), (1, 4,))
    y, variables = init(mlp)(random.PRNGKey(1), x)

    param_shapes = unfreeze(
        jax.tree_map(jnp.shape, variables['params']))
    self.assertEqual(param_shapes, {
        'hidden_0': {'kernel': (4, 8), 'bias': (8,)},
        'out': {'kernel': (8, 1), 'bias': (1,)},
    })
    self.assertEqual(y.shape, (1, 1))
    self.assertTrue(y.ravel() < 1.)

    y2 = apply(mlp)(variables, x)
    self.assertTrue(jnp.allclose(y, y2))
Example #16
 def test_semi_explicit_dense(self):
     x = jnp.ones((1, 4))
     y, variables = init(semi_explicit_mlp)(random.PRNGKey(0), x)
     param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
     self.assertEqual(y.shape, (1, 1))
     self.assertEqual(
         param_shapes, {
             'dense_0': {
                 'kernel': (4, 3),
                 'bias': (3, )
             },
             'dense_1': {
                 'kernel': (3, 1),
                 'bias': (1, )
             }
         })
Example #17
 def test_custom_vjp(self):
   x = random.normal(random.PRNGKey(0), (1, 4))
   y, variables = init(mlp_custom_grad)(random.PRNGKey(1), x)
   param_shapes = unfreeze(
       jax.tree_map(jnp.shape, variables['params']))
   loss_fn = lambda p, x: jnp.mean(apply(mlp_custom_grad)(p, x) ** 2)
   grad = jax.grad(loss_fn)(variables, x)
   grad_shapes = unfreeze(
       jax.tree_map(jnp.shape, grad['params']))
   self.assertEqual(y.shape, (1, 1))
   expected_param_shapes = {
       'hidden_0': {'kernel': (4, 8), 'bias': (8,)},
       'out': {'kernel': (8, 1), 'bias': (1,)},
   }
   self.assertEqual(param_shapes, expected_param_shapes)
   self.assertEqual(grad_shapes, expected_param_shapes)
   for g in jax.tree_leaves(grad):
     self.assertTrue(np.all(g == np.sign(g)))
Example #18
def init_with_output(
    fn: Callable[..., Any],
    module: Module,
    mutable: CollectionFilter = True
) -> Callable[..., Tuple[Any, FrozenVariableDict]]:
    """Creates an init function to call ``fn`` with a bound module that also returns the function outputs.

  Unlike ``Module.init_with_output``, this function returns a new function with the signature
  ``(rngs, *args, **kwargs) -> (T, variables)`` where ``T`` is the return type of ``fn``.
  The rngs can be a dict of PRNGKeys or a single ``PRNGKey``, which is
  equivalent to passing a dict with one PRNGKey with the name "params".

  The init function that is returned can be directly composed with
  JAX transformations like ``jax.jit``::

    def f(foo, x):
      z = foo.encode(x)
      y = foo.decode(z)
      # ...
      return y
    
    foo = Foo()
    f_jitted = jax.jit(nn.init_with_output(f, foo))
    y, variables = f_jitted(rng, x)

  Args:
    fn: The function that should be applied. The first argument passed will
      be a module instance of ``module`` with variables and RNGs bound
      to it.
    module: The ``Module`` used to bind variables and RNGs. The ``Module``
      passed as the first argument to ``fn`` will be a clone of ``module``.
    mutable: Can be bool, str, or list. Specifies which collections should be
      treated as mutable: ``bool``: all/no collections are mutable.
      ``str``: The name of a single mutable collection. ``list``: A
      list of names of mutable collections.
  Returns:
    The init function wrapping ``fn``.
  """
    @functools.wraps(fn)
    def scope_fn(scope, *args, **kwargs):
        return fn(module.clone(parent=scope), *args, **kwargs)

    return core.init(scope_fn, mutable=mutable)
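A hedged sketch (an assumption, not part of the snippet above) of how an output-discarding counterpart could be layered on top of init_with_output:

def init(
    fn: Callable[..., Any],
    module: Module,
    mutable: CollectionFilter = True
) -> Callable[..., FrozenVariableDict]:
    """Hypothetical wrapper that keeps only the variables returned by ``init_with_output``."""
    init_fn = init_with_output(fn, module, mutable)

    @functools.wraps(fn)
    def init_wrapper(*args, **kwargs):
        # drop the function output, return just the initialized variables
        return init_fn(*args, **kwargs)[1]

    return init_wrapper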
Example #19
 def test_big_resnet(self):
   x = random.normal(random.PRNGKey(0), (1, 8, 8, 8))
   y, variables = init(big_resnet)(random.PRNGKey(1), x)
   self.assertEqual(y.shape, (1, 8, 8, 8))
   param_shapes = unfreeze(
       jax.tree_map(jnp.shape, variables['params']))
   batch_stats_shapes = unfreeze(
       jax.tree_map(jnp.shape, variables['batch_stats']))
   print(param_shapes)
   self.assertEqual(param_shapes, {
       'conv_1': {'kernel': (10, 5, 3, 3, 8, 8)},
       'conv_2': {'kernel': (10, 5, 3, 3, 8, 8)},
       'bn_1': {'scale': (10, 5, 8), 'bias': (10, 5, 8)},
       'bn_2': {'scale': (10, 5, 8), 'bias': (10, 5, 8)}
   })
   self.assertEqual(batch_stats_shapes, {
       'bn_1': {'var': (10, 5, 8), 'mean': (10, 5, 8)},
       'bn_2': {'var': (10, 5, 8), 'mean': (10, 5, 8)}
   })
Example #20
      x = nn.relu(x)
  return x

def semi_explicit_mlp(scope, x, sizes=(3, 1)):
  for i, size in enumerate(sizes):
    dense = scope.child(ExplicitDense.create_in_scope, prefix='dense_')(x.shape[-1], size)
    x = dense(x)
    if i + 1 < len(sizes):
      x = nn.relu(x)
  return x

if __name__ == "__main__":
  model = Dense(features=4)
  x = jnp.ones((1, 3))

  y, params = init(model)(random.PRNGKey(0), x)

  print(y)
  print(params)


  print('explicit dense:')
  y, params = init(explicit_mlp)(random.PRNGKey(0), x)

  print(y)
  print(params)

  print('semi-explicit dense:')
  y, params = init(semi_explicit_mlp)(random.PRNGKey(0), x)

  print(y)
Example #21
                              variable_carry='counter',
                              variable_in_axes={'param': lift.broadcast},
                              variable_out_axes={'param': lift.broadcast},
                              split_rngs={'param': False})(scope, (), xs)
    else:
        carry, ys = lift.scan(body_fn,
                              variable_carry='counter',
                              variable_in_axes={'param': 0},
                              variable_out_axes={'param': 0},
                              split_rngs={'param': True})(scope, (), xs)

    # output layer
    return carry, ys


if __name__ == "__main__":
    x = random.normal(random.PRNGKey(0), (1, 4))
    x = jnp.concatenate([x, x], 0)

    print(
        'unshared params: (outputs should be different, parameters have an extra dim)'
    )
    y, variables = init(mlp_scan)(random.PRNGKey(1), x, share_params=False)
    print(y)
    print(unfreeze(variables))

    print('shared params: (outputs should be the same)')
    y, variables = init(mlp_scan)(random.PRNGKey(1), x, share_params=True)
    print(y)
    print(unfreeze(variables))
Example #22
    else:
        dense_vmap = lift.vmap(nn.dense,
                               in_axes=(0, None),
                               variable_axes={'params': 0},
                               split_rngs={'params': True})

    # hidden layers
    for size in sizes[:-1]:
        x = scope.child(dense_vmap)(x, size)
        x = act_fn(x)

    # output layer
    return scope.child(dense_vmap)(x, sizes[-1])


if __name__ == "__main__":
    x = random.normal(random.PRNGKey(0), (1, 4))
    x = jnp.concatenate([x, x], 0)

    print('shared params: (same inputs, same outputs)')
    y, params = init(mlp_vmap)(random.PRNGKey(1), x, share_params=True)
    print(y)
    print(jax.tree_map(jnp.shape, unfreeze(params)))

    print(
        'unshared params: (same inputs, different outputs, extra dim in params)'
    )
    y, params = init(mlp_vmap)(random.PRNGKey(1), x, share_params=False)
    print(y)
    print(jax.tree_map(jnp.shape, unfreeze(params)))
Example #23
    return y, x

  def bwd(features, scope_fn, params, res, g):
    x = res
    fn = lambda params, x: nn.dense(scope_fn(params), x, features)
    _, pullback = jax.vjp(fn, params, x)
    g_param, g_x = pullback(g)
    g_param = jax.tree_map(jnp.sign, g_param)
    return g_param, g_x

  dense_custom_grad = lift.custom_vjp(fwd, backward_fn=bwd, nondiff_argnums=(2,))

  # hidden layers
  for size in sizes[:-1]:
    x = scope.child(dense_custom_grad)(x, size)
    x = act_fn(x)

  # output layer
  return scope.child(dense_custom_grad)(x, sizes[-1])

if __name__ == "__main__":
  x = random.normal(random.PRNGKey(0), (1, 4))
  x = jnp.concatenate([x, x], 0)

  print('shared params: (same inputs, same outputs)')
  y, params = init(mlp_custom_grad)(random.PRNGKey(1), x)
  print(y)
  print(jax.tree_map(jnp.shape, unfreeze(params)))

  print(jax.grad(lambda params, x: jnp.mean(apply(mlp_custom_grad)(params, x) ** 2))(params, x))
Example #24
    conv = partial(nn.conv, bias=False, dtype=dtype)
    norm = partial(norm, dtype=dtype)

    x = scope.child(conv, 'init_conv')(x, 16, (7, 7), padding=((3, 3), (3, 3)))
    x = scope.child(norm, 'init_bn')(x)
    x = act(x)
    x = nn.max_pool(x, (2, 2), (2, 2), 'SAME')

    for i, size in enumerate(block_sizes):
        for j in range(size):
            strides = (1, 1)
            if i > 0 and j == 0:
                strides = (2, 2)
            block_features = features * 2**i
            block_scope = scope.push(f'block_{i}_{j}')
            x = residual_block(block_scope, x, conv, norm, act, block_features,
                               strides)
            # we can access parameters of the sub module by operating on the scope
            # Example:
            # block_scope.get_kind('param')['conv_1']['kernel']
    x = jnp.mean(x, (1, 2))
    x = scope.child(nn.dense, 'out')(x, num_classes)
    return x


if __name__ == "__main__":
    x = random.normal(random.PRNGKey(0), (1, 224, 224, 3))
    y, params = init(resnet)(random.PRNGKey(1), x)
    print(y.shape)
    print(jax.tree_map(jnp.shape, unfreeze(params)))
Example #25
  def test_mutate_undefined_collection(self):
    def f(scope):
      scope.put_variable('test', 'test', 123)

    with self.assertRaisesWithLiteralMatch(ValueError, 'Collection is not mutable: "test"'):
      init(f, mutable='params')(random.PRNGKey(0))
Example #26
        return jnp.dot(x, expm(kernel)) + bias

    def backward(self, scope: Scope, y: Array):
        kernel, bias = self.params(scope, y.shape[-1])
        return jnp.dot(y - bias, expm(-kernel))


@dataclass
class StackFlow:
    flows: Sequence[Flow]

    def forward(self, scope: Scope, x: Array):
        for i, f in enumerate(self.flows):
            x = scope.child(f.forward, name=str(i))(x)
        return x

    def backward(self, scope: Scope, x: Array):
        for i, f in reversed(tuple(enumerate(self.flows))):
            x = scope.child(f.backward, name=str(i))(x)
        return x


if __name__ == "__main__":
    flow = StackFlow((DenseFlow(), ) * 3)
    # forward and backward are interchangeable here,
    # so shape inference and initialization can be done on either the forward or the backward pass of the flow
    y, params = init(flow.forward)(random.PRNGKey(0), jnp.ones((1, 3)))
    print(params)
    x_restore = apply(flow.backward)(params, y)
    print(x_restore)
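    # hypothetical extra check (not in the original snippet): forward and
    # backward share parameter names and shapes, so the flow can equally be
    # initialized from the backward pass
    _, params_b = init(flow.backward)(random.PRNGKey(0), jnp.ones((1, 3)))
    print(params_b)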
Example #27
  def _tied(self, fn, transpose=False):
    if not transpose:
      return fn

    def trans(variables):
      if 'param' not in variables:
        return variables
      params = variables['param']
      params['kernel'] = params['kernel'].T
      return variables

    return lift.transform_module(
        fn, trans_in_fn=trans, trans_out_fn=trans)

if __name__ == "__main__":
  ae = TiedAutoEncoder(latents=2, features=4)
  x = jnp.ones((1, ae.features))

  x_r, params = init(ae)(random.PRNGKey(0), x)

  print(x, x_r)
  print(params)


  print('init from decoder:')
  z = jnp.ones((1, ae.latents))
  x_r, params = init(ae.decode)(random.PRNGKey(0), z)

  print(apply(ae)(params, x))
  print(params)
Example #28
    # the transformed collection will be immutable inside fn
    # this way we avoid losing mutations to params
    # transform also avoids accidental reuse of rngs
    # and it makes sure that other state is updated correctly (not twice during init!)
    return lift.transform_module(fn, trans_in_fn=std)


def mlp(scope: Scope,
        x: Array,
        sizes: Sequence[int] = (2, 4, 1),
        act_fn: Callable[[Array], Array] = nn.relu):
    std_dense = weight_std(
        partial(nn.dense, kernel_init=nn.initializers.normal(stddev=1e5)))
    # hidden layers
    for size in sizes[:-1]:
        x = scope.child(std_dense, prefix='hidden_')(x, size)
        # x = act_fn(x)

    # output layer
    return scope.child(nn.dense, 'out')(x, sizes[-1])


if __name__ == "__main__":
    x = random.normal(random.PRNGKey(0), (
        1,
        4,
    ))
    y, params = init(mlp)(random.PRNGKey(1), x)
    print(y)
    print(jax.tree_map(jnp.shape, unfreeze(params)))
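    # hypothetical follow-up (not in the original snippet): the weight_std
    # transform only rewrites the 'params' collection, so re-applying the
    # initialized model is deterministic (mirrors the check in Example #15)
    y2 = apply(mlp)(params, x)
    print(jnp.allclose(y, y2))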
Example #29
    for axis in reversed(sorted(batch_axes)):
        attn_fn = lift.vmap(attn_fn,
                            in_axes=(axis, axis, axis),
                            out_axes=axis,
                            variable_axes={'params': None},
                            split_rngs={
                                'params': False,
                                'dropout': not broadcast_dropout
                            })

    y = attn_fn(scope, inputs_q, inputs_kv, bias)
    return y.mean(axis=-2)


if __name__ == "__main__":
    inputs = jnp.ones((2, 7, 16))

    y, variables = init(multi_head_dot_product_attention)(
        {
            'params': random.PRNGKey(0),
            'dropout': random.PRNGKey(1)
        },
        inputs,
        inputs,
        num_heads=2,
        batch_axes=(0, ),
        attn_fn=with_dropout(softmax_attn, 0.1, deterministic=False))

    print(y.shape)
    print(jax.tree_map(jnp.shape, variables))
Example #30
 def init_fun(rng, *args, **kwargs):
     return init(core_fun)(rng, *args, **kwargs)