def test_remat(self, flavor="old"):
  def f(x1):
    x2 = jnp.sin(x1)
    x3 = jnp.sin(x2)
    x4 = jnp.sin(x3)
    return x4

  remat_f = jax.remat(f) if flavor == "old" else ad_checkpoint.checkpoint(f)

  # The computation of grad_f computes "sin" 5 times: 3 in the forward pass
  # and 2 more in the backward pass to rematerialize "x2" and "x3".
  arg = np.array(3.)
  # Check that we have a Sin under a conditional.
  f_tf = tf.function(jax2tf.convert(jax.grad(remat_f)), autograph=False)
  f_tf_graph = f_tf.get_concrete_function(arg).graph.as_graph_def()
  if flavor == "old":
    raise unittest.SkipTest("TODO: CSE widget not yet implemented for old-style remat")
  if jax.config.jax_remat_opt_barrier:
    self.assertRegex(
        str(f_tf_graph), r"remat_checkpoint_/XlaOptimizationBarrier")
  elif config.jax_experimental_name_stack:
    self.assertRegex(
        str(f_tf_graph),
        r'transpose/jax2tf_f_/jvp/checkpoint/remat_checkpoint_/cond/branch_1_fun/Sin')
  else:
    self.assertRegex(
        str(f_tf_graph),
        r'remat_checkpoint_/switch_case/indexed_case/Sin')
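
# A minimal standalone sketch (not part of the test above) of the behavior the
# test checks: without remat, JAX saves residuals from every intermediate "sin"
# for the backward pass; under jax.remat only the primal input is saved, so
# x2 and x3 must be recomputed, which is the extra "Sin" the regexes look for.
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

def f3(x1):
  return jnp.sin(jnp.sin(jnp.sin(x1)))

print_saved_residuals(f3, 3.0)             # residuals from the intermediates
print_saved_residuals(jax.remat(f3), 3.0)  # only the argument is saved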
def test_hk_remat(
    self,
    module_fn: descriptors.ModuleFn,
    shape: Shape,
    dtype: DType,
):
  rng = jax.random.PRNGKey(42)
  if jnp.issubdtype(dtype, jnp.integer):
    x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
  else:
    x = jax.random.uniform(rng, shape, dtype)

  def g(x, remat=False):
    mod = module_fn()
    if remat:
      mod = hk.remat(mod)
    out = mod(x)
    if isinstance(out, dict):
      out = out['loss']
    return jnp.mean(out)

  f = hk.transform_with_state(g)

  assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)
  grad_jax_remat = jax.grad(jax.remat(f.apply), has_aux=True)
  grad_hk_remat = jax.grad(functools.partial(f.apply, remat=True), has_aux=True)

  params, state = f.init(rng, x)
  jax.tree_multimap(assert_allclose,
                    grad_jax_remat(params, state, rng, x),
                    grad_hk_remat(params, state, rng, x))
def testAxisIndexRemat(self):
  # https://github.com/google/jax/issues/2716
  n = len(jax.devices())

  def f(key):
    key = random.fold_in(key, jax.lax.axis_index('i'))
    return random.bernoulli(key, p=0.5)

  keys = random.split(random.PRNGKey(0), n)
  jax.pmap(jax.remat(f), axis_name='i')(keys)
def _jax_scan(f, xs, init_value, axis=0, remat=False):
  """Scans f over the given axis of xs.

  In pseudo-python, the scan function would look as follows:

    def scan(f, xs, init_value, axis):
      xs = [xs[..., i, ...] for i in range(xs.shape[axis])]
      cur_value = init_value
      ys = []
      for x in xs:
        y, cur_value = f(x, cur_value)
        ys.append(y)
      return np.stack(ys, axis), cur_value

  Args:
    f: function (x, carry) -> (y, new_carry)
    xs: tensor, x will be xs slices on axis
    init_value: tensor, initial value of the carry-over
    axis: int, the axis on which to slice xs
    remat: whether to re-materialize f

  Returns:
    A pair (ys, last_value) as described above.
  """
  def swapaxes(x):
    transposed_axes = list(range(len(x.shape)))
    transposed_axes[axis] = 0
    transposed_axes[0] = axis
    return jnp.transpose(x, axes=transposed_axes)

  if axis != 0:
    xs = nested_map(swapaxes, xs)

  def transposed_f(c, x):
    y, d = f(x, c)
    return d, y

  if remat:
    last_value, ys = lax.scan(jax.remat(transposed_f), init_value, xs)
  else:
    last_value, ys = lax.scan(transposed_f, init_value, xs)

  if axis != 0:
    ys = nested_map(swapaxes, ys)
  return ys, last_value
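
# A short usage sketch for _jax_scan (not from the original source): a running
# sum along axis 0. The step function cumsum_step is made up for illustration;
# with remat=True, the step is rematerialized during the backward pass instead
# of storing its intermediates.
import jax.numpy as jnp

def cumsum_step(x, carry):
  total = carry + x
  return total, total  # (y, new_carry)

ys, last = _jax_scan(cumsum_step, jnp.arange(5.0), jnp.float32(0.0), remat=True)
# ys == [0., 1., 3., 6., 10.] and last == 10.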
def forward(self, xs):
  rngs = _split_rngs(self.rng, len(self.sublayers))
  accumulator, *context = xs
  stack = context = tuple(context)
  new_state = []
  for layer, w, s, rng in zip(self.sublayers, self.weights, self.state, rngs):
    inputs = cb.inputs_from_stack(stack, layer.n_in)
    if base.N_WEIGHTS_SHARDS > 1:
      # With sharded weights, make sure we don't keep them concatenated
      # in memory on each device by using remat.
      outputs, s = jax.remat(layer.pure_fn)(inputs, w, s, rng)
    else:
      outputs, s = layer.pure_fn(inputs, w, s, rng)
    stack = cb.outputs_onto_stack(outputs, stack, layer.n_in)
    new_state.append(s)

  residual = stack[0] if isinstance(stack, (tuple, list)) else stack
  output = accumulator + residual
  stack = (output,) + context
  self.state = tuple(new_state)
  return stack
def test_create_module_inside_remat(self, jax_remat, inline_hk_remat):
  log = []

  def forward(x):
    def create_and_use_layer(x):
      m = SquareModule(name="layer")
      log.append(m.module_name)
      return m(x)

    if not inline_hk_remat:
      create_and_use_layer = stateful.remat(create_and_use_layer)

    for _ in range(2):
      if inline_hk_remat:
        x = stateful.remat(create_and_use_layer)(x)
      else:
        x = create_and_use_layer(x)
    return x

  def reset():
    del log[:]
    self.assertEmpty(log)

  # Test forward.
  x = jnp.float32(3)
  forward = transform.transform_with_state(forward)
  params, state = forward.init(None, x)
  self.assertEqual(log, ["layer", "layer_1"])
  reset()

  # Test backward.
  for _ in range(3):
    grad_fn = jax.grad(lambda x: forward.apply(params, state, None, x)[0])
    if jax_remat:
      grad_fn = jax.remat(grad_fn)
    self.assertEqual(int(grad_fn(x)), int(4 * (x ** 3)))
    self.assertEqual(log, ["layer", "layer_1"])
    reset()
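
# A plausible stand-in for the SquareModule referenced above (the real
# definition lives in Haiku's test helpers, so this is an assumption): the
# checks only require that each application squares its input, so that two
# applications compute x**4 and the gradient is 4 * x**3.
class SquareModule(hk.Module):

  def __call__(self, x):
    return x ** 2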
from functools import partial

import jax
import jax.numpy as np

from .transforms import transform_params


def _squared_distance(x1, x2, scales=None):
  z1, z2 = (x1, x2) if scales is None else (x1 / scales, x2 / scales)
  return (  # clip_up( FIXME
      np.sum(z1 * z1, axis=1, keepdims=True)
      - 2.0 * z1 @ z2.T
      + np.sum(z2 * z2, axis=1, keepdims=True).T)


_remat_squared_distance = jax.remat(_squared_distance)


def vmap(k, diag=False):
  """Vectorizes a "single" kernel of the form k(params, x1, x2).

  diag: k(params, x): (N,DX) -> (N,)
  full: k(params, x1, x2): (N,DX), (M,DX) -> (N,M)
  """
  if diag:
    # k(params, x)
    return jax.vmap(lambda params, x: k(params, x, x), (None, 0))
  else:
    # k(params, x1, x2); in_axes chosen to match the (N, M) contract above:
    # the inner vmap maps over rows of x2, the outer over rows of x1.
    inside = jax.vmap(lambda params, x1, x2: k(params, x1, x2),
                      (None, None, 0))
    return jax.vmap(inside, (None, 0, None))
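
# A short usage sketch (not from the original source), assuming a simple
# squared-exponential kernel built on the helpers above; the parameter name
# "scales" is hypothetical.
def rbf(params, x1, x2):
  # Lift the single points to 2-D so _remat_squared_distance applies.
  d2 = _remat_squared_distance(x1[None, :], x2[None, :], params["scales"])[0, 0]
  return np.exp(-0.5 * d2)

params = {"scales": np.ones(3)}
x1s, x2s = np.ones((4, 3)), np.zeros((6, 3))
K = vmap(rbf)(params, x1s, x2s)         # full kernel matrix, shape (4, 6)
kd = vmap(rbf, diag=True)(params, x1s)  # kernel diagonal, shape (4,)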