def test_binary(self, primitive: Optional[Primitive], shape1, shape2,
                    dtype, params):
        # TODO(romann): revisit when bugs below are fixed.
        if primitive == lax.conv_general_dilated_p:
            if jax.default_backend() == 'tpu':
                raise absltest.SkipTest('http://b/235167364')

            elif (jax.default_backend() == 'gpu' and
                  params['batch_group_count'] != 1):
                raise absltest.SkipTest('http://b/235485533')

        if len(shape1) > 3 or len(shape2) > 3:
            test_utils.skip_test(self)

        self._test_primitive(primitive, [shape1, shape2], dtype, params)
  def test_parallel_in(self, same_inputs, kernel_type):
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    input_key1, input_key2, mc_key = random.split(rng, 3)

    x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 2))
    x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 3))

    x1 = (x1_1, x1_2)
    x2 = (x2_1, x2_2)

    N = 2 ** 7

    def net(logits):
      return stax.serial(
          stax.parallel(stax.Dense(N), stax.Dense(N)),
          stax.serial(stax.FanInSum(), stax.Dense(logits)))

    init_fn, apply_fn, kernel_fn = net(N if kernel_type == 'nngp' else 1)

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES, trace_axes=(-1,),
        implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
        vmap_axes=((0, 0), 0, {})
    )
    test_utils.assert_close_matrices(self,
                                     kernel_fn(x1, x2, kernel_type),
                                     kernel_fn_empirical(x1, x2, kernel_type),
                                     rtol)
def main(unused_argv):
  key1, key2, key3 = random.split(random.PRNGKey(1), 3)
  x1 = random.normal(key1, (2, 8, 8, 3))
  x2 = random.normal(key2, (3, 8, 8, 3))

  # A vanilla CNN.
  init_fn, f, _ = stax.serial(
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Flatten(),
      stax.Dense(10)
  )

  _, params = init_fn(key3, x1.shape)
  kwargs = dict(
      f=f,
      trace_axes=(),
      vmap_axes=0,
  )

  # Default, baseline Jacobian contraction.
  jacobian_contraction = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)

  # (3, 2, 10, 10) full `np.ndarray` test-train NTK
  ntk_jc = jacobian_contraction(x2, x1, params)

  # NTK-vector products-based implementation.
  ntk_vector_products = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)

  ntk_vp = ntk_vector_products(x2, x1, params)

  # Structured derivatives-based implementation.
  structured_derivatives = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)

  ntk_sd = structured_derivatives(x2, x1, params)

  # Auto-FLOPs-selecting implementation. Doesn't work correctly on CPU/GPU.
  auto = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.AUTO)

  ntk_auto = auto(x2, x1, params)

  # Check that implementations match
  for ntk1 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
    for ntk2 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
      diff = np.max(np.abs(ntk1 - ntk2))
      print(f'NTK implementation diff {diff}.')
      assert diff < (1e-4 if jax.default_backend() != 'tpu' else 0.1), diff

  print('All NTK implementations match.')
Example #4
  def assert_close(x, y, tol=3e-5):
    if default_backend() == 'tpu':
      # TODO(romann): understand why TPUs have high errors.
      tol = 0.21
    self.assertLess(
        np.max(np.abs(x - y)) / (np.mean(np.abs(x)) + np.mean(np.abs(y))),
        tol)
Example #5
    def test_is_on_cpu(self):
        dtypes = [np.float16, np.float32]
        float64 = jax.dtypes.canonicalize_dtype(np.float64)
        if float64 != np.float32:
            dtypes += [float64]

        for dtype in dtypes:
            with self.subTest(dtype=dtype):

                def x():
                    return random.normal(random.PRNGKey(1), (2, 3), dtype)

                def x_cpu():
                    return device_get(
                        random.normal(random.PRNGKey(1), (2, 3), dtype))

                x_jit = jit(x)
                # x_cpu_jit = jit(x_cpu)
                x_cpu_jit_cpu = jit(x_cpu, backend='cpu')

                self.assertTrue(utils.is_on_cpu(x_cpu()))
                # TODO(mattjj): re-enable this when device_put under jit works
                # self.assertTrue(utils.is_on_cpu(x_cpu_jit()))
                self.assertTrue(utils.is_on_cpu(x_cpu_jit_cpu()))

                if jax.default_backend() == 'cpu':
                    self.assertTrue(utils.is_on_cpu(x()))
                    self.assertTrue(utils.is_on_cpu(x_jit()))
                else:
                    self.assertFalse(utils.is_on_cpu(x()))
                    self.assertFalse(utils.is_on_cpu(x_jit()))
Example #6
def _optimize() -> str:
  """Return contraction order for `np.einsum` based on platform.

  Introduced after https://github.com/google/jax/pull/7512 since TPU seems to
  be more precise in `greedy` mode.
  """
  return 'greedy' if jax.default_backend() == 'tpu' else 'optimal'
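As a quick illustration of how a helper like `_optimize` is typically used (the arrays and the contraction below are made up for this sketch, not taken from the library), the returned string goes into `np.einsum`'s `optimize` argument, which only selects the contraction-path search strategy and does not change the result:

import jax
import jax.numpy as np  # this codebase aliases `jax.numpy` as `np`

# Illustrative operands; shapes chosen arbitrarily for the sketch.
a = np.ones((8, 16))
b = np.ones((16, 32))
c = np.ones((32, 4))

opt = 'greedy' if jax.default_backend() == 'tpu' else 'optimal'  # same as `_optimize()`
out = np.einsum('ij,jk,kl->il', a, b, c, optimize=opt)
print(out.shape)  # (8, 4)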
Example #7
def skip_test(
    self,
    msg: str = 'Skipping large tests for speed.',
    platforms: Tuple[str, ...] = ('cpu',)
):
  if jax.default_backend() in platforms:
    raise parameterized.TestCase.skipTest(self, msg)
Example #8
def double_buffer_on_gpu(ds):
    if jax.default_backend() == "gpu":
        # This keeps two batches per-device in memory at all times, allowing
        # h2d transfers to overlap with execution (see b/173483287 for details).
        return double_buffer(ds)
    else:
        return ds
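`double_buffer` itself is not shown on this page; the sketch below is an assumption about the idea it implements (keep the next batch already transferred to the device while the current one is being consumed), not the library's actual code:

import jax
import numpy as onp


def double_buffer_sketch(ds):
  # Assumed behavior: always keep one batch "in flight" so the host-to-device
  # copy of the next batch overlaps with computation on the current one.
  batch = None
  for next_batch in ds:
    next_batch = jax.device_put(next_batch)  # enqueue the host-to-device copy
    if batch is not None:
      yield batch
    batch = next_batch
  if batch is not None:
    yield batch


# Called the same way as `double_buffer_on_gpu(ds)` above.
host_batches = (onp.ones([4, 32]) for _ in range(3))
for device_batch in double_buffer_sketch(host_batches):
  pass  # feed `device_batch` into a jitted step function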
  def test_nested_parallel(self, same_inputs, kernel_type):
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    (input_key1,
     input_key2,
     input_key3,
     input_key4,
     mask_key,
     mc_key) = random.split(rng, 6)

    x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
    x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
    x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
    x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

    m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

    x1_1 = test_utils.mask(
        x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
    x1_2 = test_utils.mask(
        x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
    if not same_inputs:
      x2_3 = test_utils.mask(
          x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
      x2_4 = test_utils.mask(
          x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

    x1 = (((x1_1, x1_2), x1_3), x1_4)
    x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

    N_in = 2 ** 7

    # We only include dropout on non-TPU backends, because it takes large N to
    # converge on TPU.
    dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

    init_fn, apply_fn, kernel_fn = stax.parallel(
        stax.parallel(
            stax.parallel(stax.Dense(N_in),
                          stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                      stax.Flatten())),
            stax.serial(stax.Conv(N_in + 2, (2, 2)),
                        dropout_or_id,
                        stax.GlobalAvgPool())),
        stax.Conv(N_in + 3, (2,)))

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES,
        implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
        vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                   if platform == 'tpu' else None)
    )

    test_utils.assert_close_matrices(
        self,
        kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
        kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
        rtol)
Example #10
  def test_parallel_out(self, same_inputs, kernel_type):
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    input_key1, mc_key = random.split(rng, 2)

    x1, x2 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 1))

    N = 2 ** 10

    def net(logits):
      return stax.serial(
          stax.Dense(N),
          stax.FanOut(2),
          stax.parallel(stax.Dense(logits), stax.Dense(logits)))

    init_fn, apply_fn, kernel_fn = net(N if kernel_type == 'nngp' else 1)

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES, trace_axes=(-1,),
        implementation=2,
        vmap_axes=(0, [0, 0], {}))

    test_utils.assert_close_matrices(self,
                                     kernel_fn(x1, x2, kernel_type),
                                     kernel_fn_empirical(x1, x2, kernel_type),
                                     rtol)
Example #11
  def _test_activation(self, activation_fn, same_inputs, model, get,
                       rbf_gamma=None):
    if 'conv' in model:
      test_utils.skip_test(self)

    key = random.PRNGKey(1)
    key, split = random.split(key)
    output_dim = 1024 if get == 'nngp' else 1
    b_std = 0.5
    W_std = 2.0
    if activation_fn[2].__name__ == 'Sin':
      W_std = 0.9
    if activation_fn[2].__name__ == 'Rbf':
      W_std = 1.0
      b_std = 0.0

    if model == 'fc':
      rtol = 0.04
      X0_1 = random.normal(key, (4, 2))
      X0_2 = None if same_inputs else random.normal(split, (2, 2))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1

    else:
      rtol = 0.05
      X0_1 = random.normal(key, (2, 4, 4, 3))
      X0_2 = None if same_inputs else random.normal(split, (4, 4, 4, 3))
      affine = stax.Conv(512, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    if default_backend() == 'cpu':
      num_samplings = 200
      rtol *= 2
    else:
      num_samplings = (500 if activation_fn[2].__name__ in ('Sin', 'Rbf')
                       else 300)

    init_fn, apply_fn, kernel_fn = stax.serial(
        *[affine, activation_fn]*depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, split, num_samplings, implementation=2,
        vmap_axes=0
    )
    empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
    test_utils.assert_close_matrices(self, analytic_kernel,
                                     empirical_kernel, rtol)

    # Check match with explicit RBF
    if rbf_gamma is not None and get == 'nngp' and model == 'fc':
      input_dim = X0_1.shape[1]
      _, _, kernel_fn = self._RBF(rbf_gamma / input_dim)
      direct_rbf_kernel = kernel_fn(X0_1, X0_2, get)
      test_utils.assert_close_matrices(self, analytic_kernel,
                                       direct_rbf_kernel, rtol)
Example #12
def _binomial(key, p, n, shape):
    shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
    # reshape to map over axis 0
    p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
    n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
    key = random.split(key, jnp.size(p))
    if jax.default_backend() == "cpu":
        ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
    else:
        ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
    return jnp.reshape(ret, shape)
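The CPU branch above maps sequentially with `lax.map`, while accelerators get a vectorized `vmap`. Below is a self-contained sketch of the same dispatch pattern, with a toy Bernoulli sampler standing in for `_binomial_dispatch` (the sampler and shapes are illustrative, not from the library):

import jax
import jax.numpy as jnp
from jax import lax, random, vmap


def per_key_sample(key, p):
  # Toy per-element kernel; `_binomial_dispatch` plays this role above.
  return random.bernoulli(key, p)


def sample_all(key, probs):
  keys = random.split(key, probs.size)
  if jax.default_backend() == 'cpu':
    # Sequential map keeps compile time and memory small on CPU.
    return lax.map(lambda kp: per_key_sample(*kp), (keys, probs))
  # On GPU/TPU a single vectorized kernel is typically faster.
  return vmap(per_key_sample)(keys, probs)


print(sample_all(random.PRNGKey(0), jnp.linspace(0.1, 0.9, 5)))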
  def test_pjit_inherits_effects(self):
    if jax.default_backend() not in {'gpu', 'tpu'}:
      raise unittest.SkipTest("pjit only supports GPU and TPU backends")

    def f(x):
      effect_p.bind(effect='foo')
      effect_p.bind(effect='bar')
      return x

    f = pjit.pjit(f, in_axis_resources=pjit.PartitionSpec('x'),
                  out_axis_resources=pjit.PartitionSpec('x'))
    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      with maps.Mesh(np.array(jax.devices()), ['x']):
        jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
Example #14
def stub_out_pmap(batch: ModuleType, count: int):
  # If we are using GPU or CPU, stub out pmap with vmap to simulate multiple devices.
  if count > 0:
    class xla_bridge_stub:

      def device_count(self) -> int:
        return count

    platform = jax.default_backend()
    if platform == 'gpu' or platform == 'cpu':
      batch.pmap = _jit_vmap
      batch.xla_bridge = xla_bridge_stub()
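`_jit_vmap` is referenced above but not defined in this snippet. A plausible minimal definition (an assumption, not necessarily the library's implementation) simply jits a vmapped function, so it maps over the leading axis on a single device the way `pmap` would map across devices:

import jax


def _jit_vmap(f, *args, **kwargs):
  # Assumed stand-in for `jax.pmap`: vectorize over the leading axis on one
  # device instead of mapping across devices, then jit the result.
  return jax.jit(jax.vmap(f, *args, **kwargs))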
Example #15
    def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
        rng = jtu.rand_default(self.rng())
        np = rng(shape, dtype)
        if gpu and jax.default_backend() == "cpu":
            raise unittest.SkipTest("Skipping GPU test case on CPU")
        if (not gpu and jax.default_backend() == "gpu"
                and jax.lib._xla_extension_version < 25):
            raise unittest.SkipTest(
                "Mixed CPU/GPU dlpack support requires jaxlib "
                "0.1.68 or newer")
        device = jax.devices("gpu" if gpu else "cpu")[0]
        x = jax.device_put(np, device)
        dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
        self.assertEqual(take_ownership, x.device_buffer.is_deleted())
        y = jax.dlpack.from_dlpack(dlpack)
        self.assertEqual(y.device(), device)
        self.assertAllClose(np.astype(x.dtype), y)

        self.assertRaisesRegex(RuntimeError,
                               "DLPack tensor may be consumed at most once",
                               lambda: jax.dlpack.from_dlpack(dlpack))
  def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                   store_on_device):
    test_utils.stub_out_pmap(batching, device_count)

    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        WIDTH, 256, jax.default_backend() == 'tpu')

    sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 200,
                                               batch_size, device_count,
                                               store_on_device)

    ker_empirical = sample(x1, x2, 'nngp')
    ker_analytic = stax_kernel_fn(x1, x2, 'nngp')

    test_utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)
Example #17
  def test_sparse_inputs(self, act, kernel, do_stabilize):
    if do_stabilize and act != 'relu':
      raise absltest.SkipTest('Stabilization possible only in Relu.')

    key = random.PRNGKey(1)

    input_count = 4
    sparse_count = 2
    input_size = 3
    width = 1024

    # NOTE(schsam): It seems that convergence is slower when inputs are sparse.
    samples = N_SAMPLES

    if default_backend() == 'gpu':
      tol = 5e-4
      samples = 100 * N_SAMPLES
    else:
      tol = {onp.dtype(onp.float32): 5e-2, onp.dtype(onp.float64): 5e-3}

    # a batch of dense inputs
    x_dense = random.normal(key, (input_count, input_size))
    x_sparse = x_dense.at[:sparse_count, :].set(0.)

    activation = (stax.Relu(do_stabilize=do_stabilize) if act == 'relu'
                  else stax.Erf())

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(width),
        activation,
        stax.Dense(1 if kernel == 'ntk' else width))
    exact = kernel_fn(x_sparse, None, kernel)

    mc = nt.monte_carlo_kernel_fn(
        init_fn,
        apply_fn,
        random.split(key, 2)[0],
        samples,
        vmap_axes=0,
        device_count=-1,
        implementation=2
    )(x_sparse, None, kernel)
    mc = np.reshape(mc, exact.shape)

    assert not np.any(np.isnan(exact))
    self.assertAllClose(exact[sparse_count:, sparse_count:],
                        mc[sparse_count:, sparse_count:],
                        rtol=tol, atol=tol)
Example #18
def _check_agreement_with_empirical(
    self,
    net,
    same_inputs,
    use_dropout,
    is_ntk,
    rtol=RTOL,
    atol=ATOL
):
  ((init_fn, apply_fn, kernel_fn),
   input_shape, device_count, channel_axis) = net

  num_samples = N_SAMPLES * 5 if use_dropout else N_SAMPLES
  key = random.PRNGKey(1)
  x1, x2 = _get_inputs(key, same_inputs, input_shape)
  if default_backend() == 'tpu' and use_dropout:
    # Include a test case for TPU + dropout with (parallel + batching).
    batch_size = 2
  else:
    batch_size = 0
  x1_out_shape, params = init_fn(key, x1.shape)
  if same_inputs:
    assert x2 is None
  if x2 is None:
    x2_out_shape = x1_out_shape
  else:
    x2_out_shape, params = init_fn(key, x2.shape)
  del params

  def _get_empirical(n_samples, get):
    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples, device_count=device_count,
        trace_axes=(channel_axis,), batch_size=batch_size,
        implementation=2
    )
    if same_inputs:
      assert x2 is None
    return kernel_fn_empirical(x1, x2, get)

  if is_ntk:
    exact, shape1, shape2 = kernel_fn(x1, x2, ('ntk', 'shape1', 'shape2'))
    empirical = _get_empirical(num_samples, 'ntk')
  else:
    exact, shape1, shape2 = kernel_fn(x1, x2, ('nngp', 'shape1', 'shape2'))
    empirical = _get_empirical(num_samples, 'nngp')
  test_utils.assert_close_matrices(self, exact, empirical, rtol, atol)
  self.assertEqual(shape1, x1_out_shape)
  self.assertEqual(shape2, x2_out_shape)
Example #19
def _subsample_fn(size, subsample_size, rng_key=None):
    assert rng_key is not None, "Missing random key to generate subsample indices."
    if jax.default_backend() == "cpu":
        # ref: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
        rng_keys = random.split(rng_key, subsample_size)

        def body_fn(val, idx):
            i_p1 = size - idx
            i = i_p1 - 1
            j = random.randint(rng_keys[idx], (), 0, i_p1)
            val = val.at[jnp.array([i, j])].set(val[jnp.array([j, i])])
            return val, None

        val, _ = lax.scan(body_fn, jnp.arange(size),
                          jnp.arange(subsample_size))
        return val[-subsample_size:]
    else:
        return random.choice(rng_key, size, (subsample_size, ), replace=False)
Example #20
  def test_parallel_in_out(self, same_inputs, kernel_type):
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    input_key1, input_key2, mc_key = random.split(rng, 3)

    x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 1))
    x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2))

    x1 = (x1_1, x1_2)
    x2 = (x2_1, x2_2)

    N_in = 2 ** 10
    N_out = N_in if kernel_type == 'nngp' else 1

    readin = stax.serial(stax.parallel(stax.Dense(N_in), stax.Dense(N_in)),
                         stax.FanInSum())
    readout = stax.serial(stax.FanOut(3),
                          stax.parallel(stax.Dense(N_out),
                                        stax.Dense(N_out + 1),
                                        stax.Dense(N_out + 2)))
    init_fn, apply_fn, _ = stax.serial(readin, readout)

    K_readin_fn = jit(readin[2])
    K_readout_fn = jit(functools.partial(readout[2], get=kernel_type))

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES, trace_axes=(-1,),
        implementation=2,
        vmap_axes=((0, 0), [0, 0, 0], {})
    )

    test_utils.assert_close_matrices(
        self,
        K_readout_fn(K_readin_fn(x1, x2)),
        kernel_fn_empirical(x1, x2, get=kernel_type),
        rtol)

    # Check both NNGP and NTK (here we just want to make sure we _can_ compute the output).
    K_readin_fn = jit(readin[2])
    K_readout_fn = jit(functools.partial(readout[2], get=('nngp', 'ntk')))

    K_readout_fn(K_readin_fn(x1, x2))
Example #21
  def test_double_buffer(self):
    if jax.default_backend() != "gpu":
      self.skipTest("Only necessary on GPU.")

    n = jax.local_device_count()
    dataset = it.repeat(np.ones([n]))
    iterator = iter(utils.double_buffer(dataset))

    batch_ptrs = []
    while len(batch_ptrs) < 4:
      batch = next(iterator)
      ptrs = [b.unsafe_buffer_pointer() for b in batch.device_buffers]
      batch_ptrs.append(ptrs)
      del batch

    self.assertEqual(batch_ptrs[0], batch_ptrs[2])
    self.assertEqual(batch_ptrs[1], batch_ptrs[3])
    self.assertNotEqual(batch_ptrs[0], batch_ptrs[1])
    self.assertNotEqual(batch_ptrs[2], batch_ptrs[3])
Example #22
def x1_is_x2(x1: np.ndarray,
             x2: Optional[np.ndarray] = None,
             eps: float = 1e-12) -> Union[bool, np.ndarray]:
    if not isinstance(x1, (onp.ndarray, np.ndarray)):
        raise TypeError('`x1` must be an ndarray. A {} is found.'.format(
            type(x1)))

    if x2 is None:
        return True

    if x1 is x2:
        return True

    if x1.shape != x2.shape:
        return False

    if jax.default_backend() == 'tpu':
        eps = 1e-4

    return np.all(np.abs(x1 - x2) < eps)
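A small usage sketch for `x1_is_x2` (it assumes the function above and its imports are in scope; the arrays are illustrative):

from jax import random

x = random.normal(random.PRNGKey(0), (3, 2))
print(x1_is_x2(x))          # True: `x2 is None` is treated as "same inputs".
print(x1_is_x2(x, x))       # True: identical objects short-circuit the check.
print(x1_is_x2(x, x + 1.))  # False (boolean array): difference exceeds `eps`.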
Example #23
  def test_elementwise_numerical(self, same_inputs, model, phi, get):
    if 'conv' in model:
      test_utils.skip_test(self)

    key, split = random.split(random.PRNGKey(1))

    output_dim = 1
    b_std = 0.01
    W_std = 1.0
    rtol = 2e-3
    deg = 25
    if get == 'ntk':
      rtol *= 2
    if default_backend() == 'tpu':
      rtol *= 2

    if model == 'fc':
      X0_1 = random.normal(key, (3, 7))
      X0_2 = None if same_inputs else random.normal(split, (5, 7))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1
    else:
      X0_1 = random.normal(key, (2, 8, 8, 3))
      X0_2 = None if same_inputs else random.normal(split, (3, 8, 8, 3))
      affine = stax.Conv(1024, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    _, _, kernel_fn = stax.serial(*[affine, phi] * depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)

    fn = lambda x: phi[1]((), x)
    _, _, kernel_fn = stax.serial(
        *[affine, stax.ElementwiseNumerical(fn, deg=deg)] * depth, readout)
    numerical_activation_kernel = kernel_fn(X0_1, X0_2, get)

    test_utils.assert_close_matrices(self, analytic_kernel,
                                     numerical_activation_kernel, rtol)
Example #24
def _device_to_host_funcs():
  """Generates device-to-host transfer functions."""
  if jax.default_backend() == "cpu":
    # device-to-host does not incur transfer on the CPU backend.
    return []

  with jax.transfer_guard_host_to_device("allow"):
    device_arrays = [jnp.ones(1) for _ in range(6)]
  return [
      # (function name, is an explicit transfer?, function)
      ("device_to_host_jax_device_get", True,
       lambda: jax.device_get(device_arrays[0])),
      ("device_to_host_np_asarray", False,
       lambda: np.asarray(device_arrays[1])),
      ("device_to_host_copy_to_host_async", False,
       lambda: device_arrays[2].copy_to_host_async()),
      ("device_to_host_np_add", False, lambda: np.add(device_arrays[3], 1)),
      ("device_to_host_str", False, lambda: str(device_arrays[4])),
      ("device_to_host_pickle_dumps", False,
       lambda: pickle.dumps(device_arrays[5])),
  ]
Example #25
def _subsample_fn(size, subsample_size, rng_key=None):
    if rng_key is None:
        raise ValueError(
            "Missing random key to generate subsample indices."
            " Algorithms like HMC/NUTS do not support subsampling."
            " You might want to use SVI or HMCECS instead."
        )
    if jax.default_backend() == "cpu":
        # ref: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
        rng_keys = random.split(rng_key, subsample_size)

        def body_fn(val, idx):
            i_p1 = size - idx
            i = i_p1 - 1
            j = random.randint(rng_keys[idx], (), 0, i_p1)
            val = val.at[jnp.array([i, j])].set(val[jnp.array([j, i])])
            return val, None

        val, _ = lax.scan(body_fn, jnp.arange(size), jnp.arange(subsample_size))
        return val[-subsample_size:]
    else:
        return random.choice(rng_key, size, (subsample_size,), replace=False)
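A brief usage sketch (illustrative sizes; assumes `_subsample_fn` above and its imports are in scope): draw a fixed number of distinct indices from `range(size)`:

from jax import random

idx = _subsample_fn(size=10, subsample_size=4, rng_key=random.PRNGKey(0))
print(sorted(int(i) for i in idx))  # 4 distinct indices drawn from range(10)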
Example #26
def setUpModule():
  if jax.default_backend() not in {'gpu', 'tpu'}:
    raise unittest.SkipTest("pjit only supports GPU and TPU backends")
  jtu.set_spmd_lowering_flag(True)
Example #27
def fori_collect(
    lower,
    upper,
    body_fun,
    init_val,
    transform=identity,
    progbar=True,
    return_last_val=False,
    collection_size=None,
    thinning=1,
    **progbar_opts,
):
    """
    This looping construct works like :func:`~jax.lax.fori_loop` but with the additional
    effect of collecting values from the loop body. In addition, this allows for
    post-processing of these samples via `transform`, and progress bar updates.
    Note that `progbar=False` will be faster, especially when collecting a
    lot of samples. Refer to example usage in :func:`~numpyro.infer.mcmc.hmc`.

    :param int lower: the index to start the collective work. In other words,
        we will skip collecting the first `lower` values.
    :param int upper: number of times to run the loop body.
    :param body_fun: a callable that takes a collection of
        `np.ndarray` and returns a collection with the same shape and
        `dtype`.
    :param init_val: initial value to pass as argument to `body_fun`. Can
        be any Python collection type containing `np.ndarray` objects.
    :param transform: a callable to post-process the values returned by `body_fn`.
    :param progbar: whether to post progress bar updates.
    :param bool return_last_val: If `True`, the last value is also returned.
        This has the same type as `init_val`.
    :param thinning: Positive integer that controls the thinning ratio for retained
        values. Defaults to 1, i.e. no thinning.
    :param int collection_size: Size of the returned collection. If not
        specified, the size will be ``(upper - lower) // thinning``. If the
        size is larger than ``(upper - lower) // thinning``, only the top
        ``(upper - lower) // thinning`` entries will be non-zero.
    :param `**progbar_opts`: optional additional progress bar arguments. A
        `diagnostics_fn` can be supplied which when passed the current value
        from `body_fun` returns a string that is used to update the progress
        bar postfix. Also a `progbar_desc` keyword argument can be supplied
        which is used to label the progress bar.
    :return: collection with the same type as `init_val` with values
        collected along the leading axis of `np.ndarray` objects.
    """
    assert lower <= upper
    assert thinning >= 1
    collection_size = ((upper - lower) // thinning
                       if collection_size is None else collection_size)
    assert collection_size >= (upper - lower) // thinning
    init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
    start_idx = lower + (upper - lower) % thinning
    num_chains = progbar_opts.pop("num_chains", 1)
    # host_callback does not work yet with multi-GPU platforms
    # See: https://github.com/google/jax/issues/6447
    if num_chains > 1 and jax.default_backend() == "gpu":
        warnings.warn(
            "We will disable progress bar because it does not work yet on multi-GPUs platforms.",
            stacklevel=find_stack_level(),
        )
        progbar = False

    @cached_by(fori_collect, body_fun, transform)
    def _body_fn(i, vals):
        val, collection, start_idx, thinning = vals
        val = body_fun(val)
        idx = (i - start_idx) // thinning
        collection = cond(
            idx >= 0,
            collection,
            lambda x: x.at[idx].set(ravel_pytree(transform(val))[0]),
            collection,
            identity,
        )
        return val, collection, start_idx, thinning

    collection = jnp.zeros((collection_size, ) + init_val_flat.shape,
                           dtype=init_val_flat.dtype)
    if not progbar:
        last_val, collection, _, _ = fori_loop(
            0, upper, _body_fn, (init_val, collection, start_idx, thinning))
    elif num_chains > 1:
        progress_bar_fori_loop = progress_bar_factory(upper, num_chains)
        _body_fn_pbar = progress_bar_fori_loop(_body_fn)
        last_val, collection, _, _ = fori_loop(
            0, upper, _body_fn_pbar,
            (init_val, collection, start_idx, thinning))
    else:
        diagnostics_fn = progbar_opts.pop("diagnostics_fn", None)
        progbar_desc = progbar_opts.pop("progbar_desc", lambda x: "")

        vals = (init_val, collection, device_put(start_idx),
                device_put(thinning))
        if upper == 0:
            # special case, only compiling
            jit(_body_fn)(0, vals)
        else:
            with tqdm.trange(upper) as t:
                for i in t:
                    vals = jit(_body_fn)(i, vals)
                    t.set_description(progbar_desc(i), refresh=False)
                    if diagnostics_fn:
                        t.set_postfix_str(diagnostics_fn(vals[0]),
                                          refresh=False)

        last_val, collection, _, _ = vals

    unravel_collection = vmap(unravel_fn)(collection)
    return (unravel_collection,
            last_val) if return_last_val else unravel_collection
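A minimal usage sketch for `fori_collect` (the values are illustrative and assume NumPyro's surrounding imports are in scope): run a trivial counter for `upper` steps, skip the first `lower` iterations, and collect a transformed view of each retained state:

collected = fori_collect(
    lower=2,
    upper=10,
    body_fun=lambda x: x + 1,        # toy state update
    init_val=0,
    transform=lambda x: {'x_sq': x ** 2},
    progbar=False,                   # faster, as noted in the docstring
)
# `collected['x_sq']` has a leading axis of length (upper - lower) // thinning = 8.
print(collected['x_sq'].shape)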
    def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in,
                       fan_in_mode):
        if fan_in_mode in ['FanInSum', 'FanInProd']:
            if axis != 0:
                raise absltest.SkipTest(
                    '`FanInSum` and `FanInProd` are skipped when '
                    'axis != 0.')
            axis = None
        if (fan_in_mode == 'FanInSum'
                or axis == 0) and branch_in == 'dense_after_branch_in':
            raise absltest.SkipTest('`FanInSum` and `FanInConcat(0)` '
                                    'require `is_gaussian`.')

        if ((axis == 1 or fan_in_mode == 'FanInProd')
                and branch_in == 'dense_before_branch_in'):
            raise absltest.SkipTest(
                '`FanInConcat` or `FanInProd` on feature axis requires a dense layer '
                'after concatenation or Hadamard product.')
        if fan_in_mode == 'FanInSum':
            fan_in_layer = stax.FanInSum()
        elif fan_in_mode == 'FanInProd':
            fan_in_layer = stax.FanInProd()
        else:
            fan_in_layer = stax.FanInConcat(axis)

        if n_branches != 2:
            test_utils.skip_test(self)

        key = random.PRNGKey(1)
        X0_1 = np.cos(random.normal(key, (4, 3)))
        X0_2 = None if same_inputs else random.normal(key, (8, 3))

        width = 1024
        n_samples = 256 * 2

        if default_backend() == 'tpu':
            tol = 0.07
        else:
            tol = 0.02

        dense = stax.Dense(width, 1.25, 0.1)
        input_layers = [dense, stax.FanOut(n_branches)]

        branches = []
        for b in range(n_branches):
            branch_layers = [FanInTest._get_phi(b)]
            for i in range(b):
                multiplier = 1 if axis not in (1, -1) else (1 + 0.25 * i)
                branch_layers += [
                    stax.Dense(int(width * multiplier), 1. + 2 * i, 0.5 + i),
                    FanInTest._get_phi(i)
                ]

            if branch_in == 'dense_before_branch_in':
                branch_layers += [dense]
            branches += [stax.serial(*branch_layers)]

        output_layers = [fan_in_layer, stax.Relu()]
        if branch_in == 'dense_after_branch_in':
            output_layers.insert(1, dense)

        nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                           output_layers))

        if get == 'nngp':
            init_fn, apply_fn, kernel_fn = nn
        elif get == 'ntk':
            init_fn, apply_fn, kernel_fn = stax.serial(
                nn, stax.Dense(1, 1.25, 0.5))
        else:
            raise ValueError(get)

        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if axis in (0, -2) else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if axis in (0, -2) else 0,
        )

        exact = kernel_fn(X0_1, X0_2, get=get)
        empirical = kernel_fn_mc(X0_1, X0_2, get=get)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
    def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                         readout, fan_in_mode):
        test_utils.skip_test(self)
        if fan_in_mode in ['FanInSum', 'FanInProd']:
            if axis != 0:
                raise absltest.SkipTest(
                    '`FanInSum` and `FanInProd()` are skipped when '
                    'axis != 0.')
            axis = None
        if (fan_in_mode == 'FanInSum'
                or axis in [0, 1, 2]) and branch_in == 'dense_after_branch_in':
            raise absltest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                                    'require `is_gaussian`.')

        if ((axis == 3 or fan_in_mode == 'FanInProd')
                and branch_in == 'dense_before_branch_in'):
            raise absltest.SkipTest(
                '`FanInConcat` or `FanInProd` on feature axis '
                'requires a dense layer after concatenation '
                'or Hadamard product.')

        if fan_in_mode == 'FanInSum':
            fan_in_layer = stax.FanInSum()
        elif fan_in_mode == 'FanInProd':
            fan_in_layer = stax.FanInProd()
        else:
            fan_in_layer = stax.FanInConcat(axis)

        key = random.PRNGKey(1)
        X0_1 = random.normal(key, (2, 5, 6, 3))
        X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

        if default_backend() == 'tpu':
            width = 2048
            n_samples = 1024
            tol = 0.02
        else:
            width = 1024
            n_samples = 512
            tol = 0.01

        conv = stax.Conv(out_chan=width,
                         filter_shape=(3, 3),
                         padding='SAME',
                         W_std=1.25,
                         b_std=0.1)

        input_layers = [conv, stax.FanOut(n_branches)]

        branches = []
        for b in range(n_branches):
            branch_layers = [FanInTest._get_phi(b)]
            for i in range(b):
                multiplier = 1 if axis not in (3, -1) else (1 + 0.25 * i)
                branch_layers += [
                    stax.Conv(out_chan=int(width * multiplier),
                              filter_shape=(i + 1, 4 - i),
                              padding='SAME',
                              W_std=1.25 + i,
                              b_std=0.1 + i),
                    FanInTest._get_phi(i)
                ]

            if branch_in == 'dense_before_branch_in':
                branch_layers += [conv]
            branches += [stax.serial(*branch_layers)]

        output_layers = [
            fan_in_layer,
            stax.Relu(),
            stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
        ]
        if branch_in == 'dense_after_branch_in':
            output_layers.insert(1, conv)

        nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                           output_layers))

        init_fn, apply_fn, kernel_fn = stax.serial(
            nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if axis in (0, -4) else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if axis in (0, -4) else 0,
        )

        exact = kernel_fn(X0_1, X0_2, get=get)
        empirical = kernel_fn_mc(X0_1, X0_2, get=get)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
    def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant,
                       concat, proj, p, n, transpose):
        if isinstance(concat, int) and concat > n:
            raise absltest.SkipTest('Concatenation axis out of bounds.')

        test_utils.skip_test(self)
        if default_backend() == 'gpu' and n > 3:
            raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

        width = 256
        n_samples = 256
        tol = 0.03
        key = random.PRNGKey(1)

        spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
        filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
        strides = (2, 1, 3, 2, 3)[:n]
        spatial_spec = 'HWDZX'[:n]
        dimension_numbers = ('N' + spatial_spec + 'C', 'OI' + spatial_spec,
                             'N' + spatial_spec + 'C')

        x1 = np.cos(random.normal(key, (2, ) + spatial_shape + (2, )))
        x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

        if same_inputs:
            x2 = None
        else:
            x2 = np.cos(random.normal(key, (4, ) + spatial_shape + (2, )))
            x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

        def get_attn():
            return stax.GlobalSelfAttention(
                n_chan_out=width,
                n_chan_key=width,
                n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
                n_heads=int(np.sqrt(width)),
            ) if proj == 'avg' else stax.Identity()

        conv = stax.ConvTranspose if transpose else stax.Conv

        nn = stax.serial(
            stax.FanOut(3),
            stax.parallel(
                stax.serial(
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='CIRCULAR',
                         W_std=1.5,
                         b_std=0.2),
                    stax.LayerNorm(axis=(1, -1)),
                    stax.Abs(),
                    stax.DotGeneral(rhs=0.9),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=1.2,
                         b_std=0.1),
                ),
                stax.serial(
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='SAME',
                         W_std=0.1,
                         b_std=0.3),
                    stax.Relu(),
                    stax.Dropout(0.7),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=0.9,
                         b_std=1.),
                ),
                stax.serial(
                    get_attn(),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='CIRCULAR',
                         W_std=1.,
                         b_std=0.1),
                    stax.Erf(),
                    stax.Dropout(0.2),
                    stax.DotGeneral(rhs=0.7),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=1.,
                         b_std=0.1),
                )),
            (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
            get_attn(),
            {
                'avg': stax.GlobalAvgPool(),
                'sum': stax.GlobalSumPool(),
                'flatten': stax.Flatten(),
            }[proj],
        )

        if get == 'nngp':
            init_fn, apply_fn, kernel_fn = stax.serial(
                nn, stax.Dense(width, 1., 0.))
        elif get == 'ntk':
            init_fn, apply_fn, kernel_fn = stax.serial(nn,
                                                       stax.Dense(1, 1., 0.))
        else:
            raise ValueError(get)

        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if concat in (0, -n) else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if concat in (0, -n) else 0,
        )

        kernel_fn = jit(kernel_fn, static_argnames='get')
        exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
        empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
        test_utils.assert_close_matrices(self, empirical, exact, tol)