Exemple #1
0
  def test_jit_or_pmap_broadcast(self):
    """Compares `batch._jit_or_pmap_broadcast` wrappers with the raw kernel.

    The wrapped toy kernel is exercised with `device_count` 0, 1 and 2, for
    every combination of the `do_flip` / `do_square` flags, and its outputs
    are compared with a direct call to the unwrapped kernel.
    """

    def kernel_fn(x1, x2, do_flip, keys, do_square, params, _unused=None,
                  p=0.65):
      # Toy kernel: |x1 @ x2|, optionally squared and negated, then scaled
      # by a uniform random draw times `p`. `params` is passed through.
      out = np.abs(np.matmul(x1, x2))
      if do_square:
        out *= out
      if do_flip:
        out = -out

      out *= random.uniform(keys) * p
      return [out, params]

    params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
    x2 = np.arange(0, 10).reshape((10,))
    keys = random.PRNGKey(1)

    # Every (do_flip, do_square) pair, in the order True/True, True/False,
    # False/True, False/False.
    flag_grid = [(flip, square)
                 for flip in (True, False)
                 for square in (True, False)]

    # --- device_count=0: whole output compared directly. ---
    wrapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=0)
    x1 = np.arange(0, 10).reshape((1, 10))
    for flip, square in flag_grid:
      with self.subTest(do_flip=flip, do_square=square, device_count=0):
        expected = kernel_fn(
            x1, x2, flip, keys, square, params, _unused=True, p=0.65)
        actual = wrapped(
            x1, x2, flip, keys, square, params, _unused=True)
        self.assertAllClose(expected, actual, True)

    # --- device_count=1: returned params carry an extra leading axis. ---
    utils.stub_out_pmap(batch, 1)
    x1 = np.arange(0, 10).reshape((1, 10))
    wrapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=1)
    add_leading_axis = partial(np.expand_dims, axis=0)
    for flip, square in flag_grid:
      with self.subTest(do_flip=flip, do_square=square, device_count=1):
        expected = kernel_fn(
            x1, x2, flip, keys, square, params, _unused=False, p=0.65)
        actual = wrapped(
            x1, x2, flip, keys, square, params, _unused=None)
        self.assertAllClose(expected[0], actual[0], True)
        self.assertAllClose(
            tree_map(add_leading_axis, expected[1]), actual[1], True)

    # --- device_count=2: returned params are broadcast to a leading 2. ---
    wrapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=2)
    x1 = np.arange(0, 20).reshape((2, 10))
    utils.stub_out_pmap(batch, 2)

    def repeat_twice(arg):
      return np.broadcast_to(arg, (2,) + arg.shape)

    for flip, square in flag_grid:
      with self.subTest(do_flip=flip, do_square=square, device_count=2):
        expected = kernel_fn(x1, x2, flip, keys, square, params, p=0.2)
        actual = wrapped(
            x1, x2, flip, keys, square, params, _unused=None, p=0.2)
        for row in range(2):
          self.assertAllClose(expected[0][row], actual[0][row], True)
        self.assertAllClose(tree_map(repeat_twice, expected[1]), actual[1],
                            True)
    def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
        """Checks `batch._parallel` against the unbatched kernel.

        `kernel_fn` arrives as a factory: it is called with a PRNG key,
        `train_shape[1:]`, `network` and `use_dropout=False` to build the
        kernel that is then wrapped and compared.
        """
        utils.stub_out_pmap(batch, 2)

        root_key = random.PRNGKey(0)
        model_key, train_key, test_key = random.split(root_key, 3)
        train_data = random.normal(train_key, train_shape)
        test_data = random.normal(test_key, test_shape)

        reference_kernel = kernel_fn(
            model_key, train_shape[1:], network, use_dropout=False)
        parallel_kernel = batch._parallel(reference_kernel)

        _test_kernel_against_batched(
            self, reference_kernel, parallel_kernel, train_data, test_data,
            True)
Exemple #3
0
    def test_sample_once_batch(self, batch_size, device_count, store_on_device,
                               get):
        """Checks that batching options do not change a single-sample kernel.

        Two single-sample kernel functions are built from the same empirical
        kernel and `init_fn`: one with default options, one with
        `batch_size`, `device_count` and `store_on_device`. Their outputs for
        the same inputs, key and `get` must match.
        """
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
        empirical_fn = empirical.empirical_kernel_fn(apply_fn)

        plain_sampler = monte_carlo._sample_once_kernel_fn(
            empirical_fn, init_fn)
        batched_sampler = monte_carlo._sample_once_kernel_fn(
            empirical_fn, init_fn, batch_size, device_count, store_on_device)

        expected = plain_sampler(x1, x2, key, get)
        actual = batched_sampler(x1, x2, key, get)
        self.assertAllClose(expected, actual, True)
Exemple #4
0
    def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                     store_on_device):
        """Compares a Monte Carlo 'nngp' estimate with the analytic kernel.

        The Monte Carlo kernel function is built with a sample-count argument
        of 200 and queried with `'nngp'`; the result must be close to
        `stax_kernel_fn(x1, x2, 'nngp')` under `utils.assert_close_matrices`
        with a tolerance argument of 2e-2.
        """
        utils.stub_out_pmap(batch, device_count)

        # Third model argument is a flag that is true only on a TPU backend.
        on_tpu = xla_bridge.get_backend().platform == 'tpu'
        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            1024, 256, on_tpu)

        mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, key, 200, batch_size, device_count,
            store_on_device)

        nngp_sampled = mc_kernel_fn(x1, x2, 'nngp')
        nngp_exact = stax_kernel_fn(x1, x2, 'nngp')

        utils.assert_close_matrices(self, nngp_exact, nngp_sampled, 2e-2)
Exemple #5
0
    def test_batch_sample_once(self, batch_size, device_count,
                               store_on_device):
        """Checks that `batch.batch` leaves a single-sample kernel unchanged.

        A single-sample kernel function is evaluated once directly and once
        after wrapping a freshly built equivalent in `batch.batch`; both
        results must match.
        """
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fun, apply_fun, _, key = _get_inputs_and_model()
        empirical_ker = empirical.get_ker_fun_empirical(apply_fun)

        # Reference: unbatched single sample.
        plain_sampler = monte_carlo._get_ker_fun_sample_once(
            empirical_ker, init_fun)
        expected = plain_sampler(x1, x2, key)

        # Batched: a second, separately constructed sampler is wrapped.
        fresh_sampler = monte_carlo._get_ker_fun_sample_once(
            empirical_ker, init_fun)
        batched_sampler = batch.batch(
            fresh_sampler, batch_size, device_count, store_on_device)
        actual = batched_sampler(x1, x2, key)

        self.assertAllClose(expected, actual, True)
Exemple #6
0
    def test_batch_sample_once(self, batch_size, device_count, store_on_device,
                               get):
        """Checks that `batch.batch` leaves a single-sample kernel unchanged.

        The single-sample kernel function is built with `device_count=0`,
        wrapped in `batch.batch`, and both versions are evaluated with the
        same inputs, key and `get`. The test is skipped when `get` is None.
        """
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
        empirical_fn = empirical.empirical_kernel_fn(apply_fn)
        plain_sampler = monte_carlo._sample_once_kernel_fn(
            empirical_fn, init_fn, device_count=0)
        batched_sampler = batch.batch(
            plain_sampler, batch_size, device_count, store_on_device)

        # Guard: nothing to compare without an explicit `get`.
        if get is None:
            raise jtu.SkipTest('No default `get` values for this method.')

        expected = plain_sampler(x1, x2, key, get)
        actual = batched_sampler(x1, x2, key, get)
        self.assertAllClose(expected, actual, True)
Exemple #7
0
  def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn):
    """Checks `batch.batch` with `batch_size=2` against the unbatched kernel.

    `kernel_fn` arrives as a factory and is called with a PRNG key,
    `train_shape[1:]` and `network`. The comparison runs twice: with the
    default `store_on_device` and with `store_on_device=False`.
    """
    utils.stub_out_pmap(batch, 2)

    root_key = random.PRNGKey(0)
    model_key, train_key, test_key = random.split(root_key, 3)
    train_data = random.normal(train_key, train_shape)
    test_data = random.normal(test_key, test_shape)

    reference_kernel = kernel_fn(model_key, train_shape[1:], network)

    # First pass leaves `store_on_device` at its default; second forces False.
    for extra_options in ({}, {'store_on_device': False}):
      batched_kernel = batch.batch(
          reference_kernel, batch_size=2, **extra_options)
      _test_kernel_against_batched(
          self, reference_kernel, batched_kernel, train_data, test_data)
Exemple #8
0
  def testComposition(self, train_shape, test_shape, network, name, kernel_fn):
    """Checks both nesting orders of `batch._serial` and `batch._parallel`.

    `kernel_fn` arrives as a factory and is called with a PRNG key,
    `train_shape[1:]` and `network`. The resulting kernel is compared with
    parallel-of-serial and then serial-of-parallel wrappers, each using
    `batch_size=2` for the serial part.
    """
    utils.stub_out_pmap(batch, 2)

    root_key = random.PRNGKey(0)
    model_key, train_key, test_key = random.split(root_key, 3)
    train_data = random.normal(train_key, train_shape)
    test_data = random.normal(test_key, test_shape)

    reference_kernel = kernel_fn(model_key, train_shape[1:], network)

    # Builders are applied lazily so each wrapper is constructed right before
    # it is checked.
    compositions = (
        lambda fn: batch._parallel(batch._serial(fn, batch_size=2)),
        lambda fn: batch._serial(batch._parallel(fn), batch_size=2),
    )
    for compose in compositions:
      composed_kernel = compose(reference_kernel)
      _test_kernel_against_batched(
          self, reference_kernel, composed_kernel, train_data, test_data)
Exemple #9
0
    def test_sample_many_batch(self, batch_size, device_count,
                               store_on_device):
        """Checks that batching does not change a many-sample kernel.

        A many-sample kernel function is built once from the plain
        single-sample function and once from its `batch.batch`-wrapped
        version; both are evaluated with `N_SAMPLES` and must agree.
        """
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fun, apply_fun, _, key = _get_inputs_and_model()
        empirical_ker = empirical.get_ker_fun_empirical(apply_fun)

        single_sampler = monte_carlo._get_ker_fun_sample_once(
            empirical_ker, init_fun)
        multi_sampler = monte_carlo._get_ker_fun_sample_many(single_sampler)

        batched_single = batch.batch(
            single_sampler, batch_size, device_count, store_on_device)
        batched_multi = monte_carlo._get_ker_fun_sample_many(batched_single)

        expected = multi_sampler(x1, x2, key, N_SAMPLES)
        actual = batched_multi(x1, x2, key, N_SAMPLES)
        self.assertAllClose(expected, actual, True)
Exemple #10
0
    def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                         store_on_device):
        """Compares a Monte Carlo 'ntk' estimate with the analytic kernel.

        The Monte Carlo kernel function is built with a sample-count argument
        of 100 and queried with `'ntk'`. Its output is summed over the last
        two axes and divided by the size of the last axis before being
        compared with `stax_kernel_fn(x1, x2, 'ntk')` at tolerance 2e-2.
        """
        utils.stub_out_pmap(batch, device_count)

        # Third model argument is a flag that is true only on a TPU backend.
        on_tpu = xla_bridge.get_backend().platform == 'tpu'
        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            256, 2, on_tpu)

        mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, key, 100, batch_size, device_count,
            store_on_device)

        ntk_raw = mc_kernel_fn(x1, x2, 'ntk')
        # Collapse the two trailing axes, normalising by the last axis size.
        trailing_size = ntk_raw.shape[-1]
        ntk_sampled = np.sum(ntk_raw, axis=(-1, -2)) / trailing_size

        ntk_exact = stax_kernel_fn(x1, x2, 'ntk')

        utils.assert_close_matrices(self, ntk_exact, ntk_sampled, 2e-2)
Exemple #11
0
    def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                     store_on_device):
        """Compares the sampled `.nngp` kernel with the analytic `.nngp`.

        The Monte Carlo kernel function is called with the key and 200, and
        its `.nngp` field is compared with the `.nngp` field of the analytic
        kernel (requested with `compute_ntk=False, compute_nngp=True`) at a
        tolerance argument of 1e-2.
        """
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fun, apply_fun, stax_ker_fun, key = _get_inputs_and_model(
            512, 512)

        # NOTE(review): the positional True/False presumably select
        # nngp / ntk computation -- confirm against the factory signature.
        mc_ker_fun = monte_carlo.get_ker_fun_monte_carlo(
            init_fun, apply_fun, True, False, batch_size, device_count,
            store_on_device)

        nngp_sampled = mc_ker_fun(x1, x2, key, 200).nngp

        analytic_kernels = stax_ker_fun(
            x1, x2, compute_ntk=False, compute_nngp=True)
        nngp_exact = analytic_kernels.nngp

        utils.assert_close_matrices(self, nngp_exact, nngp_sampled, 1e-2)
Exemple #12
0
    def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                         store_on_device):
        """Compares the sampled `.ntk` kernel with the analytic `.ntk`.

        The Monte Carlo kernel function is called with the key and 100; its
        `.ntk` field is summed over the last two axes and divided by the size
        of the last axis, then compared with the analytic `.ntk` (requested
        with `compute_ntk=True, compute_nngp=True`) at tolerance 1e-2.
        """
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fun, apply_fun, stax_ker_fun, key = _get_inputs_and_model(
            512, 2)

        # NOTE(review): the positional False/True presumably select
        # nngp / ntk computation -- confirm against the factory signature.
        mc_ker_fun = monte_carlo.get_ker_fun_monte_carlo(
            init_fun, apply_fun, False, True, batch_size, device_count,
            store_on_device)

        ntk_raw = mc_ker_fun(x1, x2, key, 100).ntk
        # Collapse the two trailing axes, normalising by the last axis size.
        trailing_size = ntk_raw.shape[-1]
        ntk_sampled = np.sum(ntk_raw, axis=(-1, -2)) / trailing_size

        analytic_kernels = stax_ker_fun(
            x1, x2, compute_ntk=True, compute_nngp=True)
        ntk_exact = analytic_kernels.ntk

        utils.assert_close_matrices(self, ntk_exact, ntk_sampled, 1e-2)
Exemple #13
0
    def test_monte_carlo_generator(self, batch_size, device_count,
                                   store_on_device, get):
        """Checks the generator form of `monte_carlo.monte_carlo_kernel_fn`.

        When given a list of sample counts, the kernel function's result is
        iterated once per count. Each yielded value is compared with:
          * the result of a kernel function built for that single count,
          * the values obtained for a second, identically constructed pair
            of inputs (`x3`, `x4`).
        The last yielded value is finally compared, with loose tolerances,
        with the analytic kernel.

        Args:
          batch_size: forwarded to `monte_carlo_kernel_fn`.
          device_count: forwarded to `monte_carlo_kernel_fn` and to
            `utils.stub_out_pmap`.
          store_on_device: forwarded to `monte_carlo_kernel_fn`.
          get: kernel selector; when None it is not forwarded at all, so each
            callee uses its own default.
        """
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            8, 1)
        x3, x4, _, _, _, _ = _get_inputs_and_model(8, 1)

        log_n_max = 4
        n_samples = [2**k for k in range(log_n_max)]
        sample_generator = monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, key, n_samples, batch_size, device_count,
            store_on_device)

        # Forward `get` only when it was supplied. This keeps the `None` case
        # calling with two positional arguments, exactly as a user relying on
        # the default would, while sharing one code path for both cases.
        get_args = () if get is None else (get,)

        samples_12 = sample_generator(x1, x2, *get_args)
        samples_34 = sample_generator(x3, x4, *get_args)

        count = 0
        for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
            sample_fn = monte_carlo.monte_carlo_kernel_fn(
                init_fn, apply_fn, key, n, batch_size, device_count,
                store_on_device)
            sample_12 = sample_fn(x1, x2, *get_args)
            sample_34 = sample_fn(x3, x4, *get_args)
            self.assertAllClose(s_12, sample_12, True)
            self.assertAllClose(s_12, s_34, True)
            self.assertAllClose(s_12, sample_34, True)
            count += 1

        # One yielded value per requested sample count.
        self.assertEqual(log_n_max, count)

        ker_analytic_12 = stax_kernel_fn(x1, x2, *get_args)
        ker_analytic_34 = stax_kernel_fn(x3, x4, *get_args)

        # `s_12` is the last value yielded above (largest sample count).
        # Drop the two trailing axes of its ntk part so it lines up with the
        # analytic result: directly when only 'ntk' was requested, via the
        # `.ntk` field when the result is a namedtuple containing it.
        if get == 'ntk':
            s_12 = np.squeeze(s_12, (-1, -2))
        elif get is None or 'ntk' in get:
            s_12 = s_12._replace(ntk=np.squeeze(s_12.ntk, (-1, -2)))

        self.assertAllClose(ker_analytic_12, s_12, True, 2., 2.)
        self.assertAllClose(ker_analytic_12, ker_analytic_34, True)
Exemple #14
0
 def testAnalyticKernelComposeAutomatic(self, store_on_device):
   """Runs the analytic-kernel composition check with `batch.batch`.

   The batching function handed to `_test_analytic_kernel_composition` is
   `batch.batch` with `batch_size=2` and the given `store_on_device`.
   """
   utils.stub_out_pmap(batch, 2)
   batching_fn = partial(
       batch.batch, batch_size=2, store_on_device=store_on_device)
   self._test_analytic_kernel_composition(batching_fn)
Exemple #15
0
 def testAnalyticKernelComposeParallel(self):
   """Runs the analytic-kernel composition check with `batch._parallel`."""
   utils.stub_out_pmap(batch, 2)
   batching_fn = batch._parallel
   self._test_analytic_kernel_composition(batching_fn)