def test_parallel_in_out_mc(self, same_inputs, batch_size):
    """Batched and unbatched MC NNGP estimates agree for nested inputs.

    Inputs are nested tuples `(a, (b, c))` that mirror the nesting of the
    `stax.parallel` network below. When `same_inputs` is set, the second
    input is `None`.
    """
    key_x1, key_x2, key_net = random.split(random.PRNGKey(0), 3)

    def nested_inputs(key, n):
      # Draw three (n, 5) arrays and arrange them as `(a, (b, c))`.
      a, b, c = random.normal(key, (3, n, 5))
      return a, (b, c)

    x1 = nested_inputs(key_x1, 2)
    x2 = None if same_inputs else nested_inputs(key_x2, 4)

    def net(n_out):
      # Outer `parallel` matches the `(a, (b, c))` input nesting; each leaf
      # gets its own Dense layer with a distinct output width.
      return stax.parallel(
          stax.Dense(n_out),
          stax.parallel(stax.Dense(n_out + 1), stax.Dense(n_out + 2)))

    init_fn, apply_fn, _ = net(WIDTH)

    # Settings shared by both estimators; only `batch_size` differs.
    mc_kwargs = dict(n_samples=4, trace_axes=(-1,))
    kernel_fn_unbatched = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key_net, **mc_kwargs)
    kernel_fn_batched = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key_net, batch_size=batch_size, **mc_kwargs)

    # Check NNGP.
    self.assertAllClose(kernel_fn_batched(x1, x2, 'nngp'),
                        kernel_fn_unbatched(x1, x2, 'nngp'))
  def test_monte_carlo_generator(self, batch_size, device_count,
                                 store_on_device, get):
    """A multi-`n_samples` MC kernel generator matches per-`n` estimators.

    `monte_carlo_kernel_fn` is built with a list of sample counts, and the
    resulting function is iterated to get one estimate per count. For each
    count, the estimate is compared against:
    - a kernel function built for that single count;
    - the generator's estimate on a second input pair from
      `_get_inputs_and_model`.
    The estimate for the largest count is then loosely compared to the
    analytic `stax` kernel.

    When `get` is None, the MC kernel functions are called without a `get`
    argument. In that case the analytic kernel is requested for
    `('nngp', 'ntk')`.
    """
    test_utils.stub_out_pmap(batching, device_count)

    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(8, 1)
    x3, x4, _, _, _, _ = _get_inputs_and_model(8, 1)

    # The MC kernel functions take `get` as an optional trailing positional
    # argument; omit it entirely when `get` is None.
    get_args = () if get is None else (get,)
    analytic_get = ('nngp', 'ntk') if get is None else get

    log_n_max = 4
    n_samples = [2**k for k in range(log_n_max)]
    sample_generator = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples, batch_size, device_count,
        store_on_device, vmap_axes=0)

    samples_12 = sample_generator(x1, x2, *get_args)
    samples_34 = sample_generator(x3, x4, *get_args)

    count = 0
    for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
      sample_fn = monte_carlo.monte_carlo_kernel_fn(
          init_fn, apply_fn, key, n, batch_size, device_count,
          store_on_device, vmap_axes=0)
      sample_12 = sample_fn(x1, x2, *get_args)
      sample_34 = sample_fn(x3, x4, *get_args)
      self.assertAllClose(s_12, sample_12)
      self.assertAllClose(s_12, s_34)
      self.assertAllClose(s_12, sample_34)
      count += 1

    # Both generators must have yielded an estimate for every requested count.
    self.assertEqual(log_n_max, count)

    ker_analytic_12 = stax_kernel_fn(x1, x2, analytic_get)
    ker_analytic_34 = stax_kernel_fn(x3, x4, analytic_get)

    # `s_12` is the last estimate yielded (largest `n`, here 8 samples), so
    # only a loose match against the analytic kernel is asserted.
    self.assertAllClose(ker_analytic_12, s_12, atol=2., rtol=2.)
    self.assertAllClose(ker_analytic_12, ker_analytic_34)
  def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                   store_on_device):
    """A 200-sample Monte Carlo NNGP estimate is close to the analytic NNGP."""
    test_utils.stub_out_pmap(batching, device_count)

    on_tpu = jax.default_backend() == 'tpu'
    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        WIDTH, 256, on_tpu)

    n_samples = 200
    mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples, batch_size, device_count,
        store_on_device)

    nngp_empirical = mc_kernel_fn(x1, x2, 'nngp')
    nngp_analytic = stax_kernel_fn(x1, x2, 'nngp')

    tolerance = 2e-2
    test_utils.assert_close_matrices(self, nngp_analytic, nngp_empirical,
                                     tolerance)