def test_inconsistent_param_shapes(self):
  def f(scope):
    scope.param('test', nn.initializers.ones, (4,))

  msg = ('Inconsistent shapes between value and initializer for parameter '
         '"test" in "/": (2,), (4,)')
  with self.assertRaisesWithLiteralMatch(ValueError, msg):
    apply(f)(freeze({'params': {'test': np.ones((2,))}}))
def test_variable_is_mutable(self):
  def f(scope, should_be_mutable):
    test = scope.variable('state', 'test', lambda: 1)
    self.assertEqual(test.is_mutable(), should_be_mutable)

  _, variables = apply(f, mutable='state')({}, True)
  apply(f, mutable=False)(variables, False)
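A minimal usage sketch of the behavior this test pins down; the `counter` function is hypothetical and assumes the same flax.core `init`/`apply` API used throughout these snippets:

# Hypothetical sketch (not from the test suite): writes to a 'state' variable
# only succeed while its collection is marked mutable.
def counter(scope):
  count = scope.variable('state', 'count', lambda: 0)
  if count.is_mutable():
    count.value += 1
  return count.value

_, variables = init(counter)(random.PRNGKey(0))        # collections are mutable during init
y, state = apply(counter, mutable='state')(variables)  # increments the counter
y = apply(counter, mutable=False)(state)               # read-only: value stays unchanged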
def test_undefined_param(self):
  def f(scope):
    nn.dense(scope.push('dense'), np.ones((1, 2)), 2)

  with self.assertRaisesWithLiteralMatch(
      ValueError, 'No parameter named "kernel" exists in "/dense".'):
    apply(f)({})
def test_jit_cache(self):
  compiles = 0

  @lift.jit
  def f(scope, x):
    nonlocal compiles
    compiles += 1
    if (scope.is_mutable_collection('intermediates')
        and not scope.is_mutable_collection('params')):
      scope.put_variable('intermediates', 'x', x + 1)
    return nn.dense(scope, x, 1)

  x = np.ones((3, 2))
  _, params = init(f)(random.PRNGKey(0), x)
  init(f)(random.PRNGKey(0), x)
  self.assertEqual(compiles, 1)
  apply(f)(params, x)
  self.assertEqual(compiles, 2)  # apply should cause a compile
  apply(f)(params, x)
  self.assertEqual(compiles, 2)  # applying again should not
  # Edge case where only the implicit return of the jitted function changes:
  # this should not reuse the previously cached apply.
  _, state = apply(f, mutable='intermediates')(params, x)
  self.assertEqual(compiles, 3)  # different mutable collections force a recompile
  self.assertEqual(state['intermediates']['x'].sum(), 3 * 2 * 2)
def test_undefined_param(self):
  def f(scope):
    nn.dense(scope.push('dense'), np.ones((1, 2)), 2)

  msg = r'No parameter named "kernel" exists in "/dense".'
  with self.assertRaisesRegex(errors.ScopeParamNotFoundError, msg):
    apply(f)({})
def test_inconsistent_param_shapes(self):
  def f(scope):
    scope.param('test', nn.initializers.ones, (4,))

  msg = (r'Inconsistent shapes between value and initializer for parameter '
         r'"test" in "/": \(2,\), \(4,\).')
  with self.assertRaisesRegex(errors.ScopeParamShapeError, msg):
    apply(f)(freeze({'params': {'test': np.ones((2,))}}))
def test_undefined_param(self):
  def f(scope):
    dense = lift.vmap(nn.dense,
                      in_axes=(0, None), out_axes=0,
                      variable_axes={'params': 0},
                      split_rngs={'params': True})
    dense(scope.push('dense'), np.ones((3, 2)), 2)

  with self.assertRaisesWithLiteralMatch(
      ValueError, 'No parameter named "kernel" exists in "/vmap(dense)".'):
    apply(f)({})
def test_undefined_param(self):
  def f(scope):
    dense = lift.vmap(nn.dense,
                      in_axes=(0, None), out_axes=0,
                      variable_axes={'params': 0},
                      split_rngs={'params': True})
    dense(scope.push('dense'), np.ones((3, 2)), 2)

  msg = r'No parameter named "kernel" exists in "/vmap\(dense\)".'
  with self.assertRaisesRegex(errors.ScopeParamNotFoundError, msg):
    apply(f)({})
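For contrast with the error path above, a hedged sketch of the happy path of the same `lift.vmap` pattern, under the same assumed flax.core API:

# Hypothetical sketch (assumes the init/lift.vmap API used in these tests):
# mapping nn.dense over a batch axis with one parameter set per row.
def batched_dense(scope, xs):
  dense = lift.vmap(nn.dense,
                    in_axes=(0, None), out_axes=0,
                    variable_axes={'params': 0},  # stack params along axis 0
                    split_rngs={'params': True})  # independent init rng per row
  return dense(scope.push('dense'), xs, 2)

ys, variables = init(batched_dense)(random.PRNGKey(0), np.ones((3, 2)))
# ys.shape == (3, 2); each row was transformed by its own (2, 2) kernel.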
def apply(self, variables: VariableDict, *args,
          rngs: RNGSequences = None,
          method: Callable[..., Any] = None,
          mutable: Union[bool, str, Sequence[str]] = False,
          **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:
  """Applies a module method to variables and returns output and modified variables.

  Args:
    variables: A dictionary containing variables keyed by variable
      collections. See :mod:`flax.core.variables` for more details
      about variables.
    rngs: The rngs for the variable collections.
    method: The method to apply. If provided, applies this method.
      If not provided, applies the ``__call__`` method.
    mutable: Can be bool, str, or list. Specifies which collections should be
      treated as mutable: ``bool``: all/no collections are mutable.
      ``str``: The name of a single mutable collection. ``list``: A
      list of names of mutable collections.

  Returns:
    If ``mutable`` is False, returns output. If any collections are
    mutable, returns ``(output, vars)``, where ``vars`` is a dict of
    the modified collections.
  """
  if method is None:
    method = self.__class__.__call__
  else:
    method = _get_unbound_fn(method)
  fn = lambda scope: method(self.clone(parent=scope), *args, **kwargs)
  return apply(fn, mutable=mutable)(variables, rngs=rngs)
def apply(fn: Callable[..., Any], module: Module,
          mutable: CollectionFilter = False,
          capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False
          ) -> Callable[..., Any]:
  """Creates an apply function to call ``fn`` with a bound module.

  Unlike ``Module.apply`` this function returns a new function with the
  signature ``(variables, *args, rngs=None, **kwargs) -> T`` where ``T`` is
  the return type of ``fn``. If ``mutable`` is not ``False`` the return type
  is a tuple where the second item is a ``FrozenDict`` with the mutated
  variables.

  The apply function that is returned can be directly composed with
  JAX transformations like ``jax.jit``::

    def f(foo, x):
      z = foo.encode(x)
      y = foo.decode(z)
      # ...
      return y

    foo = Foo()
    f_jitted = jax.jit(nn.apply(f, foo))
    f_jitted(variables, x)

  Args:
    fn: The function that should be applied. The first argument passed will
      be a module instance of the ``module`` with variables and RNGs bound
      to it.
    module: The ``Module`` that will be used to bind variables and RNGs to.
      The ``Module`` passed as the first argument to ``fn`` will be a clone
      of module.
    mutable: Can be bool, str, or list. Specifies which collections should be
      treated as mutable: ``bool``: all/no collections are mutable.
      ``str``: The name of a single mutable collection. ``list``: A
      list of names of mutable collections.
    capture_intermediates: If ``True``, captures intermediate return values
      of all Modules inside the "intermediates" collection. By default only
      the return values of all ``__call__`` methods are stored. A function
      can be passed to change the filter behavior. The filter function takes
      the Module instance and method name and returns a bool indicating
      whether the output of that method invocation should be stored.

  Returns:
    The apply function wrapping ``fn``.
  """
  @functools.wraps(fn)
  def scope_fn(scope, *args, **kwargs):
    _context.capture_stack.append(capture_intermediates)
    try:
      return fn(module.clone(parent=scope), *args, **kwargs)
    finally:
      _context.capture_stack.pop()

  if capture_intermediates is True:
    capture_intermediates = capture_call_intermediates
  if capture_intermediates:
    mutable = union_filters(mutable, 'intermediates')
  return core.apply(scope_fn, mutable=mutable)
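A self-contained sketch of the composition pattern from the docstring above; the ``Foo`` definition here is a hypothetical stand-in for the module the docstring leaves elided:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
  def setup(self):
    self.encoder = nn.Dense(2)
    self.decoder = nn.Dense(3)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, z):
    return self.decoder(z)

  def __call__(self, x):
    return self.decode(self.encode(x))

def f(foo, x):
  # `foo` arrives here as a bound clone, so its methods can be called freely.
  z = foo.encode(x)
  y = foo.decode(z)
  return y

foo = Foo()
x = jnp.ones((1, 3))
variables = foo.init(jax.random.PRNGKey(0), x)
f_jitted = jax.jit(nn.apply(f, foo))
y = f_jitted(variables, x)  # variables-first signature, composable with jit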
def apply(self, variables, *args, rngs=None, method=None, mutable=False, **kwargs):
  """Apply module to variables and return output and modified variables."""
  if method is None:
    method = self.__class__.__call__
  else:
    method = get_unbound_fn(method)
  fn = lambda scope: method(self.clone(parent=scope), *args, **kwargs)
  return apply(fn, mutable=mutable)(variables, rngs=rngs)
def apply(self, variables: VariableDict, *args,
          rngs: RNGSequences = None,
          method: Callable[..., Any] = None,
          mutable: Union[bool, str, Sequence[str]] = False,
          capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False,
          **kwargs) -> Union[Any, Tuple[Any, FrozenVariableDict]]:
  """Applies a module method to variables and returns output and modified variables.

  Note that `method` should be set if one would like to call `apply` on a
  different class method than ``__call__``. For instance, suppose a
  Transformer module has a method called `encode`; then the following calls
  `apply` on that method::

    model = models.Transformer(config)
    encoded = model.apply({'params': params}, inputs, method=model.encode)

  Args:
    variables: A dictionary containing variables keyed by variable
      collections. See :mod:`flax.core.variables` for more details
      about variables.
    rngs: The rngs for the variable collections.
    method: The method to apply. If provided, applies this method.
      If not provided, applies the ``__call__`` method.
    mutable: Can be bool, str, or list. Specifies which collections should be
      treated as mutable: ``bool``: all/no collections are mutable.
      ``str``: The name of a single mutable collection. ``list``: A
      list of names of mutable collections.
    capture_intermediates: If ``True``, captures intermediate return values
      of all Modules inside the "intermediates" collection. By default only
      the return values of all ``__call__`` methods are stored. A function
      can be passed to change the filter behavior. The filter function takes
      the Module instance and method name and returns a bool indicating
      whether the output of that method invocation should be stored.

  Returns:
    If ``mutable`` is False, returns output. If any collections are
    mutable, returns ``(output, vars)``, where ``vars`` is a dict of
    the modified collections.
  """
  if method is None:
    method = self.__class__.__call__
  else:
    method = _get_unbound_fn(method)
  fn = lambda scope: method(self.clone(parent=scope), *args, **kwargs)
  if capture_intermediates is True:
    capture_intermediates = capture_call_intermediates
  if capture_intermediates:
    mutable = union_filters(mutable, 'intermediates')
  _context.capture_stack.append(capture_intermediates)
  try:
    return apply(fn, mutable=mutable)(variables, rngs=rngs)
  finally:
    _context.capture_stack.pop()
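A hedged usage sketch of ``capture_intermediates``; the ``MLP`` module and the filter function are hypothetical illustrations, not part of the snippet above:

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(4)(x)
    return nn.Dense(1)(x)

model = MLP()
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x)

# True captures every submodule's __call__ output; 'intermediates' becomes a
# mutated collection, so apply returns an (output, state) tuple.
y, state = model.apply(variables, x, capture_intermediates=True)
# state['intermediates'] holds entries like {'Dense_0': {'__call__': (...,)}}

# A custom filter receives (module_instance, method_name) and returns a bool.
only_dense = lambda mdl, method_name: isinstance(mdl, nn.Dense)
y, state = model.apply(variables, x, capture_intermediates=only_dense)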
def test_flow(self):
  x = jnp.ones((1, 3))
  flow = StackFlow((DenseFlow(),) * 3)
  y, variables = init(flow.forward)(random.PRNGKey(0), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(y.shape, (1, 3))
  self.assertEqual(param_shapes, {
      '0': {'kernel': (3, 3), 'bias': (3,)},
      '1': {'kernel': (3, 3), 'bias': (3,)},
      '2': {'kernel': (3, 3), 'bias': (3,)},
  })
  x_restored = apply(flow.backward)(variables, y)
  self.assertTrue(jnp.allclose(x, x_restored))
def test_weight_std(self):
  x = random.normal(random.PRNGKey(0), (1, 4))
  y, variables = init(mlp)(random.PRNGKey(1), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  self.assertEqual(param_shapes, {
      'hidden_0': {'kernel': (4, 8), 'bias': (8,)},
      'out': {'kernel': (8, 1), 'bias': (1,)},
  })
  self.assertEqual(y.shape, (1, 1))
  self.assertTrue(y.ravel() < 1.)
  y2 = apply(mlp)(variables, x)
  self.assertTrue(jnp.allclose(y, y2))
def test_custom_vjp(self):
  x = random.normal(random.PRNGKey(0), (1, 4))
  y, variables = init(mlp_custom_grad)(random.PRNGKey(1), x)
  param_shapes = unfreeze(jax.tree_map(jnp.shape, variables['params']))
  loss_fn = lambda p, x: jnp.mean(apply(mlp_custom_grad)(p, x) ** 2)
  grad = jax.grad(loss_fn)(variables, x)
  grad_shapes = unfreeze(jax.tree_map(jnp.shape, grad['params']))
  self.assertEqual(y.shape, (1, 1))
  expected_param_shapes = {
      'hidden_0': {'kernel': (4, 8), 'bias': (8,)},
      'out': {'kernel': (8, 1), 'bias': (1,)},
  }
  self.assertEqual(param_shapes, expected_param_shapes)
  self.assertEqual(grad_shapes, expected_param_shapes)
  for g in jax.tree_leaves(grad):
    self.assertTrue(np.all(g == np.sign(g)))
def apply(self, variables: VariableDict, *args,
          rngs: RNGSequences = None,
          method: Callable[..., Any] = None,
          mutable: Union[bool, str, Sequence[str]] = False,
          **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:
  """Applies a module method to variables and returns output and modified variables.

  Note that `method` should be set if one would like to call `apply` on a
  different class method than ``__call__``. For instance, suppose a
  Transformer module has a method called `encode`; then the following calls
  `apply` on that method::

    model = models.Transformer(config)
    encoded = model.apply({'params': params}, inputs, method=model.encode)

  Args:
    variables: A dictionary containing variables keyed by variable
      collections. See :mod:`flax.core.variables` for more details
      about variables.
    rngs: The rngs for the variable collections.
    method: The method to apply. If provided, applies this method.
      If not provided, applies the ``__call__`` method.
    mutable: Can be bool, str, or list. Specifies which collections should be
      treated as mutable: ``bool``: all/no collections are mutable.
      ``str``: The name of a single mutable collection. ``list``: A
      list of names of mutable collections.

  Returns:
    If ``mutable`` is False, returns output. If any collections are
    mutable, returns ``(output, vars)``, where ``vars`` is a dict of
    the modified collections.
  """
  if method is None:
    method = self.__class__.__call__
  else:
    method = _get_unbound_fn(method)
  fn = lambda scope: method(self.clone(parent=scope), *args, **kwargs)
  return apply(fn, mutable=mutable)(variables, rngs=rngs)
    _, pullback = jax.vjp(fn, params, x)
    g_param, g_x = pullback(g)
    g_param = jax.tree_map(jnp.sign, g_param)
    return g_param, g_x

  dense_custom_grad = lift.custom_vjp(
      fwd, backward_fn=bwd, nondiff_argnums=(2,))

  # hidden layers
  for size in sizes[:-1]:
    x = scope.child(dense_custom_grad)(x, size)
    x = act_fn(x)

  # output layer
  return scope.child(dense_custom_grad)(x, sizes[-1])


if __name__ == "__main__":
  x = random.normal(random.PRNGKey(0), (1, 4))
  x = jnp.concatenate([x, x], 0)
  print('shared params: (same inputs, same outputs)')
  y, params = init(mlp_custom_grad)(random.PRNGKey(1), x)
  print(y)
  print(jax.tree_map(jnp.shape, unfreeze(params)))
  print(jax.grad(
      lambda params, x: jnp.mean(apply(mlp_custom_grad)(params, x) ** 2))(
          params, x))
    return jnp.dot(x, expm(kernel)) + bias

  def backward(self, scope: Scope, y: Array):
    kernel, bias = self.params(scope, y.shape[-1])
    return jnp.dot(y - bias, expm(-kernel))


@dataclass
class StackFlow:
  flows: Sequence[Flow]

  def forward(self, scope: Scope, x: Array):
    for i, f in enumerate(self.flows):
      x = scope.child(f.forward, name=str(i))(x)
    return x

  def backward(self, scope: Scope, x: Array):
    for i, f in reversed(tuple(enumerate(self.flows))):
      x = scope.child(f.backward, name=str(i))(x)
    return x


if __name__ == "__main__":
  flow = StackFlow((DenseFlow(),) * 3)
  # forward and backward are interchangeable here, so shape inference and
  # initialization can be done on either the forward or the backward pass
  # of the flow
  y, params = init(flow.forward)(random.PRNGKey(0), jnp.ones((1, 3)))
  print(params)
  x_restore = apply(flow.backward)(params, y)
  print(x_restore)
  def _tied(self, fn, transpose=False):
    if not transpose:
      return fn

    def trans(variables):
      if 'params' not in variables:
        return variables
      params = variables['params']
      params['kernel'] = params['kernel'].T
      return variables

    return lift.transform_module(fn, trans_in_fn=trans, trans_out_fn=trans)


if __name__ == "__main__":
  ae = TiedAutoEncoder(latents=2, features=4)
  x = jnp.ones((1, ae.features))
  x_r, params = init(ae)(random.PRNGKey(0), x)
  print(x, x_r)
  print(params)

  print('init from decoder:')
  z = jnp.ones((1, ae.latents))
  x_r, params = init(ae.decode)(random.PRNGKey(0), z)
  print(apply(ae)(params, x))
  print(params)
    return y, x

  def bwd(features, scope_fn, params, res, g):
    x = res
    fn = lambda params, x: nn.dense(scope_fn(params), x, features)
    _, pullback = jax.vjp(fn, params, x)
    g_param, g_x = pullback(g)
    g_param = jax.tree_map(jnp.sign, g_param)
    return g_param, g_x

  dense_custom_grad = lift.custom_vjp(
      fwd, backward_fn=bwd, nondiff_argnums=(2,))

  # hidden layers
  for size in sizes[:-1]:
    x = scope.child(dense_custom_grad)(x, size)
    x = act_fn(x)

  # output layer
  return scope.child(dense_custom_grad)(x, sizes[-1])


if __name__ == "__main__":
  x = random.normal(random.PRNGKey(0), (1, 4))
  x = jnp.concatenate([x, x], 0)
  print('shared params: (same inputs, same outputs)')
  y, params = init(mlp_custom_grad)(random.PRNGKey(1), x)
  print(y)
  print(jax.tree_map(jnp.shape, unfreeze(params)))
  print(jax.grad(
      lambda params, x: jnp.mean(apply(mlp_custom_grad)(params, x) ** 2))(
          params, x))
def apply_fun(params, *args, **kwargs):
  return apply(core_fun, mutable=mutable)(params, *args, **kwargs)