Example #1
    def test_inconsistent_param_shapes(self):
        def f(scope):
            scope.param('test', nn.initializers.ones, (4, ))

        msg = 'Inconsistent shapes between value and initializer for parameter "test" in "/": (2,), (4,)'
        with self.assertRaisesWithLiteralMatch(ValueError, msg):
            apply(f)(freeze({'params': {'test': np.ones((2, ))}}))
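For contrast, a minimal sketch of the consistent case, using the same flax.core ``init``/``apply`` helpers these tests exercise (the imports and the assertion are assumptions, not part of the original test):

import jax
from flax.core import init, apply, nn

def f(scope):
    # The initializer declares shape (4,), so init creates a (4,) parameter.
    return scope.param('test', nn.initializers.ones, (4,))

y, variables = init(f)(jax.random.PRNGKey(0))
assert variables['params']['test'].shape == (4,)
# apply succeeds because the stored value matches the initializer's shape.
apply(f)(variables)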
Example #2
    def test_variable_is_mutable(self):
        def f(scope, should_be_mutable):
            test = scope.variable('state', 'test', lambda: 1)
            self.assertEqual(test.is_mutable(), should_be_mutable)

        _, variables = apply(f, mutable='state')({}, True)
        apply(f, mutable=False)(variables, False)
Example #3
    def test_undefined_param(self):
        def f(scope):
            nn.dense(scope.push('dense'), np.ones((1, 2)), 2)

        with self.assertRaisesWithLiteralMatch(
                ValueError, 'No parameter named "kernel" exists in "/dense".'):
            apply(f)({})
Example #4
    def test_jit_cache(self):
        compiles = 0

        @lift.jit
        def f(scope, x):
            nonlocal compiles
            compiles += 1
            if scope.is_mutable_collection(
                    'intermediates'
            ) and not scope.is_mutable_collection('params'):
                scope.put_variable('intermediates', 'x', x + 1)
            return nn.dense(scope, x, 1)

        x = np.ones((3, 2))
        _, params = init(f)(random.PRNGKey(0), x)
        init(f)(random.PRNGKey(0), x)
        self.assertEqual(compiles, 1)
        apply(f)(params, x)
        self.assertEqual(compiles, 2)  # apply should cause a compile
        apply(f)(params, x)
        self.assertEqual(compiles, 2)  # applying again should not
        # Edge case where only the implicit return of the jitted function
        # changes; this should not reuse the previously cached apply.
        _, state = apply(f, mutable='intermediates')(params, x)
        self.assertEqual(compiles, 3)  # different mutability requires a compile
        self.assertEqual(state['intermediates']['x'].sum(), 3 * 2 * 2)
Example #5
    def test_undefined_param(self):
        def f(scope):
            nn.dense(scope.push('dense'), np.ones((1, 2)), 2)

        msg = r'No parameter named "kernel" exists in "/dense".'
        with self.assertRaisesRegex(errors.ScopeParamNotFoundError, msg):
            apply(f)({})
Example #6
    def test_inconsistent_param_shapes(self):
        def f(scope):
            scope.param('test', nn.initializers.ones, (4, ))

        msg = r'Inconsistent shapes between value and initializer for parameter "test" in "/": \(2,\), \(4,\).'
        with self.assertRaisesRegex(errors.ScopeParamShapeError, msg):
            apply(f)(freeze({'params': {'test': np.ones((2, ))}}))
Example #7
  def test_undefined_param(self):
    def f(scope):
      dense = lift.vmap(nn.dense, 
                        in_axes=(0, None), out_axes=0,
                        variable_axes={'params': 0},
                        split_rngs={'params': True})
      dense(scope.push('dense'), np.ones((3, 2)), 2)

    with self.assertRaisesWithLiteralMatch(ValueError, 'No parameter named "kernel" exists in "/vmap(dense)".'):
      apply(f)({})
Example #8
    def test_undefined_param(self):
        def f(scope):
            dense = lift.vmap(nn.dense,
                              in_axes=(0, None),
                              out_axes=0,
                              variable_axes={'params': 0},
                              split_rngs={'params': True})
            dense(scope.push('dense'), np.ones((3, 2)), 2)

        msg = r'No parameter named "kernel" exists in "/vmap\(dense\)".'
        with self.assertRaisesRegex(errors.ScopeParamNotFoundError, msg):
            apply(f)({})
Example #9
    def apply(self,
              variables: VariableDict,
              *args,
              rngs: RNGSequences = None,
              method: Callable[..., Any] = None,
              mutable: Union[bool, str, Sequence[str]] = False,
              **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:
        """Applies a module method to variables and returns output and modified variables.

    Args:
      variables: A dictionary containing variables keyed by variable
        collections. See :mod:`flax.core.variables` for more details
        about variables.
      rngs: The rngs for the variable collections.
      method: The method to call on the module. If not provided,
        applies the ``__call__`` method.
      mutable: Can be bool, str, or list. Specifies which collections should be
               treated as mutable: ``bool``: all/no collections are mutable.
               ``str``: The name of a single mutable collection. ``list``: A
               list of names of mutable collections.
    Returns:
      If ``mutable`` is False, returns output. If any collections are
      mutable, returns ``(output, vars)``, where ``vars`` is a dict
      of the modified collections.
    """
        if method is None:
            method = self.__class__.__call__
        else:
            method = _get_unbound_fn(method)
        fn = lambda scope: method(self.clone(parent=scope), *args, **kwargs)
        return apply(fn, mutable=mutable)(variables, rngs=rngs)
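As a usage sketch of the contract described in this docstring (the ``Net`` module below is illustrative, not from the source):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(4)(x)
        # BatchNorm writes running statistics into the 'batch_stats' collection.
        return nn.BatchNorm(use_running_average=False)(x)

model = Net()
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x)
# mutable=False (the default) would raise an error here, because the call
# updates 'batch_stats'; naming the collection returns (output, vars) instead.
y, mutated = model.apply(variables, x, mutable=['batch_stats'])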
Example #10
def apply(
    fn: Callable[..., Any],
    module: Module,
    mutable: CollectionFilter = False,
    capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False
) -> Callable[..., Any]:
    """Creates an apply function to call ``fn`` with a bound module.

  Unlike ``Module.apply`` this function returns a new function with the signature
  ``(variables, *args, rngs=None, **kwargs) -> T`` where `T` is the return type
  of ``fn``. If ``mutable`` is not ``False`` the return type is a tuple where the
  second item is a ``FrozenDict`` with the mutated variables.

  The apply function that is returned can be directly composed with
  JAX transformations like ``jax.jit``::

    def f(foo, x):
      z = foo.encode(x)
      y = foo.decode(z)
      # ...
      return y
    
    foo = Foo()
    f_jitted = jax.jit(nn.apply(f, foo))
    f_jitted(variables, x)

  Args:
    fn: The function that should be applied. The first argument passed will
      be an instance of ``module`` with variables and RNGs bound
      to it.
    module: The ``Module`` that will be used to bind variables and RNGs to.
      The ``Module`` passed as the first argument to ``fn`` will be a clone
      of module.
    mutable: Can be bool, str, or list. Specifies which collections should be
      treated as mutable: ``bool``: all/no collections are mutable.
      ``str``: The name of a single mutable collection. ``list``: A
      list of names of mutable collections.
    capture_intermediates: If `True`, captures intermediate return values
      of all Modules inside the "intermediates" collection. By default only
      the return values of all `__call__` methods are stored. A function can
      be passed to change the filter behavior. The filter function takes
      the Module instance and method name and returns a bool indicating
      whether the output of that method invocation should be stored.
  Returns:
    The apply function wrapping ``fn``.
  """
    @functools.wraps(fn)
    def scope_fn(scope, *args, **kwargs):
        _context.capture_stack.append(capture_intermediates)
        try:
            return fn(module.clone(parent=scope), *args, **kwargs)
        finally:
            _context.capture_stack.pop()

    if capture_intermediates is True:
        capture_intermediates = capture_call_intermediates
    if capture_intermediates:
        mutable = union_filters(mutable, 'intermediates')
    return core.apply(scope_fn, mutable=mutable)
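Building on the ``Foo`` example from the docstring, a hedged sketch of the full round trip; this ``Foo`` definition is a stand-in, not from the source:

import jax
import jax.numpy as jnp
import flax.linen as nn

# A minimal stand-in for the docstring's Foo.
class Foo(nn.Module):
    def setup(self):
        self.encoder = nn.Dense(4)
        self.decoder = nn.Dense(3)

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z):
        return self.decoder(z)

def f(foo, x):
    z = foo.encode(x)
    return foo.decode(z)

foo = Foo()
x = jnp.ones((1, 3))
# nn.init is the initialization counterpart of nn.apply.
variables = nn.init(f, foo)(jax.random.PRNGKey(0), x)
f_jitted = jax.jit(nn.apply(f, foo))
y = f_jitted(variables, x)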
Example #11
 def apply(self, variables, *args, rngs=None,
           method=None, mutable=False, **kwargs):
   """Apply module to variables and return output and modified variables."""
   if method is None:
     method = self.__class__.__call__
   else:
     method = get_unbound_fn(method)
   fn = lambda scope: method(self.clone(parent=scope),
                             *args, **kwargs)
   return apply(fn, mutable=mutable)(variables, rngs=rngs)
Example #12
    def apply(self,
              variables: VariableDict,
              *args,
              rngs: RNGSequences = None,
              method: Callable[..., Any] = None,
              mutable: Union[bool, str, Sequence[str]] = False,
              capture_intermediates: Union[bool, Callable[['Module', str],
                                                          bool]] = False,
              **kwargs) -> Union[Any, Tuple[Any, FrozenVariableDict]]:
        """Applies a module method to variables and returns output and modified variables.

    Note that `method` should be set if one would like to call `apply` on a
    different class method than `__call__`. For instance, suppose a Transformer
    module has a method called `encode`; then the following calls `apply` on
    that method::

      model = models.Transformer(config)
      encoded = model.apply({'params': params}, inputs, method=model.encode)

    Args:
      variables: A dictionary containing variables keyed by variable
        collections. See :mod:`flax.core.variables` for more details
        about variables.
      rngs: The rngs for the variable collections.
      method: The method to call on the module. If not provided,
        applies the ``__call__`` method.
      mutable: Can be bool, str, or list. Specifies which collections should be
               treated as mutable: ``bool``: all/no collections are mutable.
               ``str``: The name of a single mutable collection. ``list``: A
               list of names of mutable collections.
      capture_intermediates: If `True`, captures intermediate return values
        of all Modules inside the "intermediates" collection. By default only
        the return values of all `__call__` methods are stored. A function can
        be passed to change the filter behavior. The filter function takes
        the Module instance and method name and returns a bool indicating
        whether the output of that method invocation should be stored.
    Returns:
      If ``mutable`` is False, returns output. If any collections are
      mutable, returns ``(output, vars)``, where ``vars`` is a dict
      of the modified collections.
    """
        if method is None:
            method = self.__class__.__call__
        else:
            method = _get_unbound_fn(method)
        fn = lambda scope: method(self.clone(parent=scope), *args, **kwargs)
        if capture_intermediates is True:
            capture_intermediates = capture_call_intermediates
        if capture_intermediates:
            mutable = union_filters(mutable, 'intermediates')
        _context.capture_stack.append(capture_intermediates)
        try:
            return apply(fn, mutable=mutable)(variables, rngs=rngs)
        finally:
            _context.capture_stack.pop()
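A short usage sketch of the capture_intermediates path described above (the ``MLP`` module is illustrative, not from the source):

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(4)(x)
        return nn.Dense(1)(x)

model = MLP()
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x)
# capture_intermediates=True stores each submodule's __call__ output in the
# 'intermediates' collection, so apply returns (output, modified_vars).
y, state = model.apply(variables, x, capture_intermediates=True)
# Outputs are stored as tuples keyed by module path and method name,
# e.g. state['intermediates']['Dense_0']['__call__'][0].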
Example #13
 def test_flow(self):
   x = jnp.ones((1, 3))
   flow = StackFlow((DenseFlow(),) * 3)
   y, variables = init(flow.forward)(random.PRNGKey(0), x)
   param_shapes = unfreeze(
       jax.tree_map(jnp.shape, variables['params']))
   self.assertEqual(y.shape, (1, 3))
   self.assertEqual(param_shapes, {
       '0': {'kernel': (3, 3), 'bias': (3,)},
       '1': {'kernel': (3, 3), 'bias': (3,)},
       '2': {'kernel': (3, 3), 'bias': (3,)},
   })
   x_restored = apply(flow.backward)(variables, y)
   self.assertTrue(jnp.allclose(x, x_restored))
Example #14
  def test_weight_std(self):
    x = random.normal(random.PRNGKey(0), (1, 4,))
    y, variables = init(mlp)(random.PRNGKey(1), x)

    param_shapes = unfreeze(
        jax.tree_map(jnp.shape, variables['params']))
    self.assertEqual(param_shapes, {
        'hidden_0': {'kernel': (4, 8), 'bias': (8,)},
        'out': {'kernel': (8, 1), 'bias': (1,)},
    })
    self.assertEqual(y.shape, (1, 1))
    self.assertTrue(y.ravel() < 1.)

    y2 = apply(mlp)(variables, x)
    self.assertTrue(jnp.allclose(y, y2))
Example #15
 def test_custom_vjp(self):
   x = random.normal(random.PRNGKey(0), (1, 4))
   y, variables = init(mlp_custom_grad)(random.PRNGKey(1), x)
   param_shapes = unfreeze(
       jax.tree_map(jnp.shape, variables['params']))
   loss_fn = lambda p, x: jnp.mean(apply(mlp_custom_grad)(p, x) ** 2)
   grad = jax.grad(loss_fn)(variables, x)
   grad_shapes = unfreeze(
       jax.tree_map(jnp.shape, grad['params']))
   self.assertEqual(y.shape, (1, 1))
   expected_param_shapes = {
       'hidden_0': {'kernel': (4, 8), 'bias': (8,)},
       'out': {'kernel': (8, 1), 'bias': (1,)},
   }
   self.assertEqual(param_shapes, expected_param_shapes)
   self.assertEqual(grad_shapes, expected_param_shapes)
   for g in jax.tree_leaves(grad):
     self.assertTrue(np.all(g == np.sign(g)))
Example #16
    def apply(self,
              variables: VariableDict,
              *args,
              rngs: RNGSequences = None,
              method: Callable[..., Any] = None,
              mutable: Union[bool, str, Sequence[str]] = False,
              **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:
        """Applies a module method to variables and returns output and modified variables.

    Note that `method` should be set if one would like to call `apply` on a
    different class method than `__call__`. For instance, suppose a Transformer
    module has a method called `encode`; then the following calls `apply` on
    that method::

      model = models.Transformer(config)
      encoded = model.apply({'params': params}, inputs, method=model.encode) 

    Args:
      variables: A dictionary containing variables keyed by variable
        collections. See :mod:`flax.core.variables` for more details
        about variables.
      rngs: The rngs for the variable collections.
      method: The method to call on the module. If not provided,
        applies the ``__call__`` method.
      mutable: Can be bool, str, or list. Specifies which collections should be
               treated as mutable: ``bool``: all/no collections are mutable.
               ``str``: The name of a single mutable collection. ``list``: A
               list of names of mutable collections.
    Returns:
      If ``mutable`` is False, returns output. If any collections are
      mutable, returns ``(output, vars)``, where ``vars`` is a dict
      of the modified collections.
    """
        if method is None:
            method = self.__class__.__call__
        else:
            method = _get_unbound_fn(method)
        fn = lambda scope: method(self.clone(parent=scope), *args, **kwargs)
        return apply(fn, mutable=mutable)(variables, rngs=rngs)
Example #17
        _, pullback = jax.vjp(fn, params, x)
        g_param, g_x = pullback(g)
        g_param = jax.tree_map(jnp.sign, g_param)
        return g_param, g_x

    dense_custom_grad = lift.custom_vjp(fwd,
                                        backward_fn=bwd,
                                        nondiff_argnums=(2, ))

    # hidden layers
    for size in sizes[:-1]:
        x = scope.child(dense_custom_grad)(x, size)
        x = act_fn(x)

    # output layer
    return scope.child(dense_custom_grad)(x, sizes[-1])


if __name__ == "__main__":
    x = random.normal(random.PRNGKey(0), (1, 4))
    x = jnp.concatenate([x, x], 0)

    print('shared params: (same inputs, same outputs)')
    y, params = init(mlp_custom_grad)(random.PRNGKey(1), x)
    print(y)
    print(jax.tree_map(jnp.shape, unfreeze(params)))

    print(
        jax.grad(
            lambda params, x: jnp.mean(apply(mlp_custom_grad)(params, x)**2))(
                params, x))
Example #18
        return jnp.dot(x, expm(kernel)) + bias

    def backward(self, scope: Scope, y: Array):
        kernel, bias = self.params(scope, y.shape[-1])
        return jnp.dot(y - bias, expm(-kernel))


@dataclass
class StackFlow:
    flows: Sequence[Flow]

    def forward(self, scope: Scope, x: Array):
        for i, f in enumerate(self.flows):
            x = scope.child(f.forward, name=str(i))(x)
        return x

    def backward(self, scope: Scope, x: Array):
        for i, f in reversed(tuple(enumerate(self.flows))):
            x = scope.child(f.backward, name=str(i))(x)
        return x


if __name__ == "__main__":
    flow = StackFlow((DenseFlow(), ) * 3)
    # Forward and backward are interchangeable here, so shape inference and
    # initialization can be done on either the forward or the backward pass of the flow.
    y, params = init(flow.forward)(random.PRNGKey(0), jnp.ones((1, 3)))
    print(params)
    x_restore = apply(flow.backward)(params, y)
    print(x_restore)
Example #19
  def _tied(self, fn, transpose=False):
    if not transpose:
      return fn

    def trans(variables):
      if 'params' not in variables:
        return variables
      params = variables['params']
      params['kernel'] = params['kernel'].T
      return variables

    return lift.transform_module(
        fn, trans_in_fn=trans, trans_out_fn=trans)

if __name__ == "__main__":
  ae = TiedAutoEncoder(latents=2, features=4)
  x = jnp.ones((1, ae.features))

  x_r, params = init(ae)(random.PRNGKey(0), x)

  print(x, x_r)
  print(params)


  print('init from decoder:')
  z = jnp.ones((1, ae.latents))
  x_r, params = init(ae.decode)(random.PRNGKey(0), z)

  print(apply(ae)(params, x))
  print(params)
Example #20
    return y, x

  def bwd(features, scope_fn, params, res, g):
    x = res
    fn = lambda params, x: nn.dense(scope_fn(params), x, features)
    _, pullback = jax.vjp(fn, params, x)
    g_param, g_x = pullback(g)
    g_param = jax.tree_map(jnp.sign, g_param)
    return g_param, g_x

  dense_custom_grad = lift.custom_vjp(fwd, backward_fn=bwd, nondiff_argnums=(2,))

  # hidden layers
  for size in sizes[:-1]:
    x = scope.child(dense_custom_grad)(x, size)
    x = act_fn(x)

  # output layer
  return scope.child(dense_custom_grad)(x, sizes[-1])

if __name__ == "__main__":
  x = random.normal(random.PRNGKey(0), (1, 4))
  x = jnp.concatenate([x, x], 0)

  print('shared params: (same inputs, same outputs)')
  y, params = init(mlp_custom_grad)(random.PRNGKey(1), x)
  print(y)
  print(jax.tree_map(jnp.shape, unfreeze(params)))

  print(jax.grad(lambda params, x: jnp.mean(apply(mlp_custom_grad)(params, x) ** 2))(params, x))
Example #21
 def apply_fun(params, *args, **kwargs):
     return apply(core_fun, mutable=mutable)(params, *args, **kwargs)
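This fragment closes over ``core_fun`` and ``mutable`` from an enclosing scope. A self-contained sketch of the same wrapper might look like the following (the factory name make_apply_fun is an assumption, not from the source):

from flax.core import apply

def make_apply_fun(core_fun, mutable=False):
    # Bind the mutability setting once and return a plain apply function.
    def apply_fun(params, *args, **kwargs):
        return apply(core_fun, mutable=mutable)(params, *args, **kwargs)
    return apply_fun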