Example #1
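Checks that an empirical NTK computed with custom `trace_axes` and `diagonal_axes` is unchanged when the kernel function is wrapped with `batching.batch` over stubbed-out devices.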
    def test_empirical_ntk_diagonal_outputs(self, same_inputs, device_count,
                                            trace_axes, diagonal_axes):
        test_utils.stub_out_pmap(batching, 2)
        rng = random.PRNGKey(0)

        input_key1, input_key2, net_key = random.split(rng, 3)

        init_fn, apply_fn, _ = stax.serial(stax.Dense(5), stax.Relu(),
                                           stax.Dense(3))

        test_x1 = random.normal(input_key1, (12, 4, 4))
        test_x2 = None
        if not same_inputs:
            test_x2 = random.normal(input_key2, (9, 4, 4))

        kernel_fn = nt.empirical_ntk_fn(apply_fn,
                                        trace_axes=trace_axes,
                                        diagonal_axes=diagonal_axes,
                                        vmap_axes=0,
                                        implementation=2)

        _, params = init_fn(net_key, test_x1.shape)

        true_kernel = kernel_fn(test_x1, test_x2, params)
        batched_fn = batching.batch(kernel_fn,
                                    device_count=device_count,
                                    batch_size=3)
        batch_kernel = batched_fn(test_x1, test_x2, params)
        self.assertAllClose(true_kernel, batch_kernel)
Example #2
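Exercises `monte_carlo_kernel_fn` in generator mode: passing a list of sample counts yields a generator of kernel estimates, each of which should match a one-shot call with the same number of samples; the final estimate is also compared (loosely) against the analytic kernel.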
  def test_monte_carlo_generator(self, batch_size, device_count,
                                 store_on_device, get):
    test_utils.stub_out_pmap(batching, device_count)

    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(8, 1)
    x3, x4, _, _, _, _ = _get_inputs_and_model(8, 1)

    log_n_max = 4
    n_samples = [2**k for k in range(log_n_max)]
    sample_generator = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples, batch_size, device_count,
        store_on_device, vmap_axes=0)

    if get is None:
      samples_12 = sample_generator(x1, x2)
      samples_34 = sample_generator(x3, x4)

      count = 0
      for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
        sample_fn = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key,
                                                      n, batch_size,
                                                      device_count,
                                                      store_on_device,
                                                      vmap_axes=0)
        sample_12 = sample_fn(x1, x2)
        sample_34 = sample_fn(x3, x4)
        self.assertAllClose(s_12, sample_12)
        self.assertAllClose(s_12, s_34)
        self.assertAllClose(s_12, sample_34)
        count += 1

      self.assertEqual(log_n_max, count)

      ker_analytic_12 = stax_kernel_fn(x1, x2, ('nngp', 'ntk'))
      ker_analytic_34 = stax_kernel_fn(x3, x4, ('nngp', 'ntk'))

    else:
      samples_12 = sample_generator(x1, x2, get)
      samples_34 = sample_generator(x3, x4, get)

      count = 0
      for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
        sample_fn = monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, key, n, batch_size,
            device_count, store_on_device, vmap_axes=0)
        sample_12 = sample_fn(x1, x2, get)
        sample_34 = sample_fn(x3, x4, get)
        self.assertAllClose(s_12, sample_12)
        self.assertAllClose(s_12, s_34)
        self.assertAllClose(s_12, sample_34)
        count += 1

      self.assertEqual(log_n_max, count)

      ker_analytic_12 = stax_kernel_fn(x1, x2, get)
      ker_analytic_34 = stax_kernel_fn(x3, x4, get)

    self.assertAllClose(ker_analytic_12, s_12, atol=2., rtol=2.)
    self.assertAllClose(ker_analytic_12, ker_analytic_34)
Example #3
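Verifies that a kernel wrapped with `batching._parallel` agrees with the unbatched kernel function.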
    def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
        test_utils.stub_out_pmap(batching, 2)
        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)

        kernel_fn = kernel_fn(key, train_shape[1:], network, use_dropout=False)
        kernel_batched = batching._parallel(kernel_fn)

        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other, True)
Example #4
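Checks that batching a single Monte Carlo sample (`monte_carlo._sample_once_kernel_fn`) leaves the result unchanged.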
  def test_batch_sample_once(self, batch_size, device_count, store_on_device,
                             get):
    test_utils.stub_out_pmap(batching, device_count)

    x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
    kernel_fn = nt.empirical_kernel_fn(apply_fn)
    sample_once_fn = monte_carlo._sample_once_kernel_fn(
        kernel_fn, init_fn, device_count=0)
    batch_sample_once_fn = batching.batch(sample_once_fn, batch_size,
                                          device_count, store_on_device)
    one_sample = sample_once_fn(x1, x2, key, get)
    one_batch_sample = batch_sample_once_fn(x1, x2, key, get)
    self.assertAllClose(one_sample, one_batch_sample)
Example #5
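Compares a 200-sample Monte Carlo estimate of the NNGP kernel against the analytic `stax` kernel.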
  def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                   store_on_device):
    test_utils.stub_out_pmap(batching, device_count)

    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        WIDTH, 256, jax.default_backend() == 'tpu')

    sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 200,
                                               batch_size, device_count,
                                               store_on_device)

    ker_empirical = sample(x1, x2, 'nngp')
    ker_analytic = stax_kernel_fn(x1, x2, 'nngp')

    test_utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)
Example #6
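Checks that batching handles pytree (nested tuple) inputs and outputs for analytic NNGP and NTK kernels built with `stax.parallel`.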
    def test_parallel_in_out(self, same_inputs):
        test_utils.stub_out_pmap(batching, 2)
        rng = random.PRNGKey(0)
        input_key1, input_key2 = random.split(rng, 2)

        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 1))

        x1 = (x1_1, (x1_2, x1_3))

        if same_inputs:
            x2 = None
        else:
            x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 1))
            x2 = (x2_1, (x2_2, x2_3))

        N = WIDTH

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        # Check NNGP.

        readin = net(N)
        readout = net(1)

        K_readin_fn = jit(readin[2])
        K_readout_fn = jit(partial(readout[2], get='nngp'))

        batch_K_readin_fn = batching.batch(K_readin_fn, 2)
        batch_K_readout_fn = batching.batch(K_readout_fn, 2)

        test_utils.assert_close_matrices(
            self, K_readout_fn(K_readin_fn(x1, x2)),
            batch_K_readout_fn(batch_K_readin_fn(x1, x2)), RTOL)

        # Check Both.
        K_readin_fn = jit(readin[2])
        K_readout_fn = jit(partial(readout[2], get=('nngp', 'ntk')))

        batch_K_readin_fn = batching.batch(K_readin_fn, 2)
        batch_K_readout_fn = batching.batch(K_readout_fn, 2)

        test_utils.assert_close_matrices(
            self, K_readout_fn(K_readin_fn(x1, x2)),
            batch_K_readout_fn(batch_K_readin_fn(x1, x2)), RTOL)
Example #7
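Verifies `batching.batch` against the unbatched kernel, both with and without storing intermediate results on device.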
    def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn,
                      batch_size):
        test_utils.stub_out_pmap(batching, 2)

        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)

        kernel_fn = kernel_fn(key, train_shape[1:], network)

        kernel_batched = batching.batch(kernel_fn, batch_size=batch_size)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

        kernel_batched = batching.batch(kernel_fn,
                                        batch_size=batch_size,
                                        store_on_device=False)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)
Example #8
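Checks that composing `batching._serial` and `batching._parallel` in either order agrees with the unbatched kernel.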
    def testComposition(self, train_shape, test_shape, network, name,
                        kernel_fn, batch_size):
        test_utils.stub_out_pmap(batching, 2)

        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)

        kernel_fn = kernel_fn(key, train_shape[1:], network)

        kernel_batched = batching._parallel(
            batching._serial(kernel_fn, batch_size=batch_size))
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

        kernel_batched = batching._serial(batching._parallel(kernel_fn),
                                          batch_size=batch_size)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)
Example #9
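The pytree input/output check from Example #6, repeated for empirical NNGP and NTK kernels.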
    def test_parallel_in_out_empirical(self, same_inputs):
        test_utils.stub_out_pmap(batching, 2)
        rng = random.PRNGKey(0)
        input_key1, input_key2, net_key = random.split(rng, 3)

        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 1))
        x1 = (x1_1, (x1_2, x1_3))

        if same_inputs:
            x2 = None
        else:
            x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 1))
            x2 = (x2_1, (x2_2, x2_3))

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        # Check NNGP.
        init_fn, apply_fn, _ = net(WIDTH)
        _, params = init_fn(net_key, ((-1, 1), ((-1, 1), (-1, 1))))

        kernel_fn = jit(nt.empirical_nngp_fn(apply_fn))
        batch_kernel_fn = jit(batching.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)

        # Check NTK.
        init_fn, apply_fn, _ = stax.serial(net(WIDTH), net(1))
        _, params = init_fn(net_key, ((-1, 1), ((-1, 1), (-1, 1))))

        kernel_fn = jit(nt.empirical_ntk_fn(apply_fn))
        batch_kernel_fn = jit(batching.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)
Example #10
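Exercises `batching._jit_or_pmap_broadcast` with device counts 0, 1 and 2, checking that positional arguments, keyword arguments and auxiliary outputs are handled (and, under pmap, broadcast) correctly.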
    def test_jit_or_pmap_broadcast(self):
        def kernel_fn(x1,
                      x2,
                      do_flip,
                      keys,
                      do_square,
                      params,
                      _unused=None,
                      p=0.65):
            res = np.abs(np.matmul(x1, x2))
            if do_square:
                res *= res
            if do_flip:
                res = -res

            res *= random.uniform(keys) * p
            return [res, params]

        params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
        x2 = np.arange(0, 10).reshape((10, ))
        keys = random.PRNGKey(1)

        kernel_fn_pmapped = batching._jit_or_pmap_broadcast(kernel_fn,
                                                            device_count=0)
        x1 = np.arange(0, 10).reshape((1, 10))
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=0):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=True,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=True)
                    self.assertAllClose(res_1, res_2)

        test_utils.stub_out_pmap(batching, 1)
        x1 = np.arange(0, 10).reshape((1, 10))
        kernel_fn_pmapped = batching._jit_or_pmap_broadcast(kernel_fn,
                                                            device_count=1)
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=1):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=False,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None)
                    self.assertAllClose(res_1[0], res_2[0])
                    self.assertAllClose(
                        tree_map(partial(np.expand_dims, axis=0), res_1[1]),
                        res_2[1])

        kernel_fn_pmapped = batching._jit_or_pmap_broadcast(kernel_fn,
                                                            device_count=2)
        x1 = np.arange(0, 20).reshape((2, 10))
        test_utils.stub_out_pmap(batching, 2)

        def broadcast(arg):
            return np.broadcast_to(arg, (2, ) + arg.shape)

        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=2):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      p=0.2)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None,
                                              p=0.2)
                    self.assertAllClose(res_1[0][0], res_2[0][0])
                    self.assertAllClose(res_1[0][1], res_2[0][1])
                    self.assertAllClose(tree_map(broadcast, res_1[1]),
                                        res_2[1])
Example #11
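Tests composition of analytic kernels under `batching.batch`.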
    def testAnalyticKernelComposeAutomatic(self, store_on_device, batch_size):
        test_utils.stub_out_pmap(batching, 2)
        self._test_analytic_kernel_composition(
            partial(batching.batch,
                    batch_size=batch_size,
                    store_on_device=store_on_device))
Example #12
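Tests composition of analytic kernels under `batching._parallel`.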
    def testAnalyticKernelComposeParallel(self):
        test_utils.stub_out_pmap(batching, 2)
        self._test_analytic_kernel_composition(batching._parallel)