Example #1
  def test_remat(self, flavor="old"):
    def f(x1):
      x2 = jnp.sin(x1)
      x3 = jnp.sin(x2)
      x4 = jnp.sin(x3)
      return x4
    remat_f = jax.remat(f) if flavor == "old" else ad_checkpoint.checkpoint(f)

    # The computation of grad_f computes "sin" 5 times: 3 in the forward pass
    # and 2 more to rematerialize "x2" and "x3" in the backward pass.
    arg = np.array(3.)
    # Check that we have a Sin under a conditional
    f_tf = tf.function(jax2tf.convert(jax.grad(remat_f)), autograph=False)
    f_tf_graph = f_tf.get_concrete_function(arg).graph.as_graph_def()
    if flavor == "old":
      raise unittest.SkipTest("TODO: CSE widget not yet implemented for old-style remat")
    if jax.config.jax_remat_opt_barrier:
      self.assertRegex(
          str(f_tf_graph), r"remat_checkpoint_/XlaOptimizationBarrier")
    elif config.jax_experimental_name_stack:
      self.assertRegex(str(f_tf_graph),
                       r'transpose/jax2tf_f_/jvp/checkpoint/remat_checkpoint_/cond/branch_1_fun/Sin')
    else:
      self.assertRegex(str(f_tf_graph),
                       r'remat_checkpoint_/switch_case/indexed_case/Sin')
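
A minimal standalone sketch (not part of the test above) of the recomputation the comment describes: in the jaxpr of grad(remat(f)), sin appears in the forward pass and again inside the rematerialized computation that the backward pass replays.

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(jnp.sin(jnp.sin(x)))

# The printed jaxpr contains a remat/checkpoint call whose body recomputes
# the sines whose cosines the backward pass needs.
print(jax.make_jaxpr(jax.grad(jax.remat(f)))(3.0))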
Example #2
  def test_hk_remat(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
      x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
      x = jax.random.uniform(rng, shape, dtype)

    def g(x, remat=False):
      mod = module_fn()
      if remat:
        mod = hk.remat(mod)
      out = mod(x)
      if isinstance(out, dict):
        out = out['loss']
      return jnp.mean(out)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    grad_jax_remat = jax.grad(jax.remat(f.apply), has_aux=True)
    grad_hk_remat = jax.grad(functools.partial(f.apply, remat=True),
                             has_aux=True)

    params, state = f.init(rng, x)
    jax.tree_util.tree_map(assert_allclose,
                           grad_jax_remat(params, state, rng, x),
                           grad_hk_remat(params, state, rng, x))
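
For comparison, a minimal sketch (hypothetical hk.Linear net, not from the test) of the two remat placements being checked against each other: hk.remat applied to the module inside the transform, versus jax.remat applied to the transformed apply function.

import functools
import haiku as hk
import jax
import jax.numpy as jnp

def net(x, remat=False):
    mod = hk.Linear(4)
    if remat:
        mod = hk.remat(mod)
    return jnp.mean(mod(x))

net_t = hk.transform(net)
x = jnp.ones((2, 3))
params = net_t.init(jax.random.PRNGKey(0), x)
# Gradients from the two placements should agree numerically.
g_outside = jax.grad(jax.remat(net_t.apply))(params, None, x)
g_inside = jax.grad(functools.partial(net_t.apply, remat=True))(params, None, x)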
Example #3
  def testAxisIndexRemat(self):
    # Regression test for https://github.com/google/jax/issues/2716:
    # jax.lax.axis_index must work under jax.remat inside pmap.
    n = len(jax.devices())

    def f(key):
      key = random.fold_in(key, jax.lax.axis_index('i'))
      return random.bernoulli(key, p=0.5)

    keys = random.split(random.PRNGKey(0), n)
    jax.pmap(jax.remat(f), axis_name='i')(keys)
Example #4
def _jax_scan(f, xs, init_value, axis=0, remat=False):
    """Scans the f over the given axis of xs.

  In pseudo-python, the scan function would look as follows:

  def scan(f, xs, init_value, axis):
    xs  = [xs[..., i, ...] for i in range(xs.shape[axis])]
    cur_value = init_value
    ys = []
    for x in xs:
      y, cur_value = f(x, cur_value)
      ys.append(y)
    return np.stack(ys, axis), cur_value

  Args:
    f: function (x, carry) -> (y, new_carry)
    xs: tensor, x will be xs slices on axis
    init_value: tensor, initial value of the carry-over
    axis: int, the axis on which to slice xs
    remat: whether to re-materialize f

  Returns:
    A pair (ys, last_value) as described above.
  """
    def swapaxes(x):
        transposed_axes = list(range(len(x.shape)))
        transposed_axes[axis] = 0
        transposed_axes[0] = axis
        return jnp.transpose(x, axes=transposed_axes)

    if axis != 0:
        xs = nested_map(swapaxes, xs)

    def transposed_f(c, x):
        y, d = f(x, c)
        return d, y

    if remat:
        last_value, ys = lax.scan(jax.remat(transposed_f), init_value, xs)
    else:
        last_value, ys = lax.scan(transposed_f, init_value, xs)
    if axis != 0:
        ys = nested_map(swapaxes, ys)
    return ys, last_value
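
A minimal usage sketch (not from the source) of the f(x, carry) -> (y, new_carry) contract: a running sum along axis 0. With remat=True, f's intermediates are recomputed in the backward pass of a surrounding jax.grad instead of being stored.

import jax.numpy as jnp

def running_sum(x, carry):
    new_carry = carry + x
    return new_carry, new_carry  # y is the running total so far

ys, last = _jax_scan(running_sum, jnp.arange(5.0), jnp.float32(0.0))
# ys == [0., 1., 3., 6., 10.], last == 10.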
Example #5
  def forward(self, xs):
    rngs = _split_rngs(self.rng, len(self.sublayers))
    accumulator, *context = xs
    stack = context = tuple(context)
    new_state = []
    for layer, w, s, rng in zip(self.sublayers, self.weights, self.state, rngs):
      inputs = cb.inputs_from_stack(stack, layer.n_in)
      if base.N_WEIGHTS_SHARDS > 1:
        # With sharded weights, make sure we don't keep them concatenated
        # in memory on each device by using remat.
        outputs, s = jax.remat(layer.pure_fn)(inputs, w, s, rng)
      else:
        outputs, s = layer.pure_fn(inputs, w, s, rng)
      stack = cb.outputs_onto_stack(outputs, stack, layer.n_in)
      new_state.append(s)
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack

    output = accumulator + residual
    stack = (output,) + context
    self.state = tuple(new_state)
    return stack
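
A standalone sketch of the pattern above (plain jnp arrays, no trax machinery): wrapping a layer's pure function in jax.remat means its intermediates, including any gathered weights, are recomputed during the backward pass rather than kept live from the forward pass.

import jax
import jax.numpy as jnp

def pure_fn(inputs, w):
    return jnp.tanh(inputs @ w)

def loss(inputs, w):
    # remat: the forward activations of pure_fn are recomputed when
    # differentiating, reducing peak memory at the cost of extra FLOPs.
    return jnp.sum(jax.remat(pure_fn)(inputs, w))

grads = jax.grad(loss, argnums=1)(jnp.ones((4, 4)), jnp.ones((4, 4)))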
Example #6
    def test_create_module_inside_remat(self, jax_remat, inline_hk_remat):
        log = []

        def forward(x):
            def create_and_use_layer(x):
                m = SquareModule(name="layer")
                log.append(m.module_name)
                return m(x)

            if not inline_hk_remat:
                create_and_use_layer = stateful.remat(create_and_use_layer)

            for _ in range(2):
                if inline_hk_remat:
                    x = stateful.remat(create_and_use_layer)(x)
                else:
                    x = create_and_use_layer(x)
            return x

        def reset():
            del log[:]
            self.assertEmpty(log)

        # Test forward.
        x = jnp.float32(3)
        forward = transform.transform_with_state(forward)
        params, state = forward.init(None, x)
        self.assertEqual(log, ["layer", "layer_1"])
        reset()

        # Test backward.
        for _ in range(3):
            grad_fn = jax.grad(
                lambda x: forward.apply(params, state, None, x)[0])
            if jax_remat:
                grad_fn = jax.remat(grad_fn)
            self.assertEqual(int(grad_fn(x)), int(4 * (x**3)))
            self.assertEqual(log, ["layer", "layer_1"])
            reset()
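
For reference, a plausible stand-in for the SquareModule used above (hypothetical; the real definition lives elsewhere in the test file). Each layer squares its input, so two layers compute x**4 and the asserted gradient is 4 * x**3; writing state is what makes transform_with_state appropriate.

import haiku as hk

class SquareModule(hk.Module):
    def __call__(self, x):
        hk.set_state("y", x ** 2)  # assumed state write, so state round-trips
        return x ** 2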
Example #7
from functools import partial

import jax
import jax.numpy as np

from .transforms import transform_params


def _squared_distance(x1, x2, scales=None):
    z1, z2 = (x1, x2) if scales is None else (x1 / scales, x2 / scales)
    return (  # clip_up(   FIXME
        np.sum(z1 * z1, axis=1, keepdims=True) - 2.0 * z1 @ z2.T +
        np.sum(z2 * z2, axis=1, keepdims=True).T)


# Rematerialized variant: recomputes the pairwise matrix in the backward pass
# instead of storing it, trading compute for memory.
_remat_squared_distance = jax.remat(_squared_distance)


def vmap(k, diag=False):
    """
    Vectorize a "single" kernel of the form k(params, x1, x2)

    diag: k(params, x): (N,DX) -> (N,)
    full: k(params, x1, x2): (N,DX), (M,DX) -> (N,M)
    """
    if diag:
        # k(params, x)
        return jax.vmap(lambda params, x: k(params, x, x), (None, 0))
    else:
        # k(params, x1, x2)
        # Map the single-pair kernel over rows of x1, then over rows of x2;
        # the in_axes/out_axes below are assumed, following the standard
        # nested-vmap pattern for building an (N, M) Gram matrix.
        inside = jax.vmap(lambda params, x1, x2: k(params, x1, x2),
                          (None, 0, None))
        return jax.vmap(inside, (None, None, 0), out_axes=1)
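
A usage sketch (hypothetical rbf kernel, not from this module) of vmap: lifting a single-pair kernel to the full Gram matrix and to its diagonal.

def rbf(lengthscale, x1, x2):
    d = (x1 - x2) / lengthscale
    return np.exp(-0.5 * np.sum(d * d))

x = np.ones((5, 3))
k_full = vmap(rbf)(1.0, x, x)          # (5, 5) Gram matrix
k_diag = vmap(rbf, diag=True)(1.0, x)  # (5,) diagonal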