Example No. 1
    def test_empirical_ntk_diagonal_outputs(self, same_inputs, device_count,
                                            trace_axes, diagonal_axes):
        test_utils.stub_out_pmap(batch, 2)
        rng = random.PRNGKey(0)

        input_key1, input_key2, net_key = random.split(rng, 3)

        init_fn, apply_fn, _ = stax.serial(stax.Dense(256), stax.Relu(),
                                           stax.Dense(10))

        test_x1 = random.normal(input_key1, (50, 4, 4))
        # With `same_inputs=True`, reuse `test_x1` by passing `test_x2=None`.
        test_x2 = None if same_inputs else random.normal(input_key2, (60, 4, 4))

        kernel_fn = empirical.empirical_ntk_fn(apply_fn,
                                               trace_axes=trace_axes,
                                               diagonal_axes=diagonal_axes,
                                               vmap_axes=0,
                                               implementation=2)

        _, params = init_fn(net_key, test_x1.shape)

        true_kernel = kernel_fn(test_x1, test_x2, params)
        batched_fn = batch.batch(kernel_fn,
                                 device_count=device_count,
                                 batch_size=5)
        batch_kernel = batched_fn(test_x1, test_x2, params)
        self.assertAllClose(true_kernel, batch_kernel)
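
The same pattern can be used outside the test harness. The sketch below is a minimal, self-contained version, assuming the public neural_tangents aliases nt.empirical_ntk_fn and nt.batch mirror the internal empirical and batch modules imported by these tests.

    from jax import random
    import neural_tangents as nt
    from neural_tangents import stax

    init_fn, apply_fn, _ = stax.serial(stax.Dense(256), stax.Relu(), stax.Dense(10))

    key1, key2, net_key = random.split(random.PRNGKey(0), 3)
    x1 = random.normal(key1, (50, 10))
    x2 = random.normal(key2, (60, 10))
    _, params = init_fn(net_key, x1.shape)

    # Build the empirical NTK function, then wrap it so the 50x60 kernel is
    # assembled from blocks of 5 inputs instead of one large pass.
    ntk_fn = nt.empirical_ntk_fn(apply_fn, vmap_axes=0)
    batched_ntk_fn = nt.batch(ntk_fn, batch_size=5)

    k_full = ntk_fn(x1, x2, params)
    k_batched = batched_ntk_fn(x1, x2, params)  # should match k_full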
Example No. 2
    def test_parallel_in_out_empirical(self, same_inputs):
        test_utils.stub_out_pmap(batch, 2)
        rng = random.PRNGKey(0)
        input_key1, input_key2, net_key = random.split(rng, 3)

        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
        x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))

        x1 = (x1_1, (x1_2, x1_3))
        x2 = (x2_1, (x2_2, x2_3))

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        # Check NNGP.
        init_fn, apply_fn, _ = net(WIDTH)
        _, params = init_fn(net_key, ((-1, 10), ((-1, 10), (-1, 10))))

        kernel_fn = jit(empirical.empirical_nngp_fn(apply_fn))
        batch_kernel_fn = jit(batch.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)

        # Check NTK.
        init_fn, apply_fn, _ = stax.serial(net(WIDTH), net(1))
        _, params = init_fn(net_key, ((-1, 10), ((-1, 10), (-1, 10))))

        kernel_fn = jit(empirical.empirical_ntk_fn(apply_fn))
        batch_kernel_fn = jit(batch.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)
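
For reference, a minimal sketch of the pytree convention used above: stax.parallel expects inputs (and input shapes) nested the same way as its layers, and produces outputs with the same nesting. The layer widths below are arbitrary.

    from jax import random
    from neural_tangents import stax

    init_fn, apply_fn, _ = stax.parallel(
        stax.Dense(8),
        stax.parallel(stax.Dense(9), stax.Dense(10)))

    key1, key2, key3, net_key = random.split(random.PRNGKey(0), 4)
    x = (random.normal(key1, (4, 10)),
         (random.normal(key2, (4, 10)), random.normal(key3, (4, 10))))

    # Input shapes are given as the same nested structure of shape tuples,
    # with -1 standing in for the batch dimension.
    _, params = init_fn(net_key, ((-1, 10), ((-1, 10), (-1, 10))))
    y = apply_fn(params, x)  # outputs mirror the nesting: (array, (array, array))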
Example No. 3
    def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn,
                      batch_size):
        test_utils.stub_out_pmap(batch, 2)

        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(key, 3)
        key = keys[0]
        self_split = keys[1]
        other_split = keys[2]
        data_self = np.asarray(normal(train_shape, seed=self_split))
        data_other = np.asarray(normal(test_shape, seed=other_split))

        kernel_fn = kernel_fn(key, train_shape[1:], network)

        kernel_batched = batch.batch(kernel_fn, batch_size=batch_size)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

        kernel_batched = batch.batch(kernel_fn,
                                     batch_size=batch_size,
                                     store_on_device=False)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)
Example No. 4
    def test_monte_carlo_generator(self, batch_size, device_count,
                                   store_on_device, get):
        test_utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            8, 1)
        x3, x4, _, _, _, _ = _get_inputs_and_model(8, 1)

        log_n_max = 4
        n_samples = [2**k for k in range(log_n_max)]
        sample_generator = monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, key, n_samples, batch_size, device_count,
            store_on_device)

        if get is None:
            samples_12 = sample_generator(x1, x2)
            samples_34 = sample_generator(x3, x4)

            count = 0
            for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
                sample_fn = monte_carlo.monte_carlo_kernel_fn(
                    init_fn, apply_fn, key, n, batch_size, device_count,
                    store_on_device)
                sample_12 = sample_fn(x1, x2)
                sample_34 = sample_fn(x3, x4)
                self.assertAllClose(s_12, sample_12)
                self.assertAllClose(s_12, s_34)
                self.assertAllClose(s_12, sample_34)
                count += 1

            self.assertEqual(log_n_max, count)

            ker_analytic_12 = stax_kernel_fn(x1, x2, ('nngp', 'ntk'))
            ker_analytic_34 = stax_kernel_fn(x3, x4, ('nngp', 'ntk'))

        else:
            samples_12 = sample_generator(x1, x2, get)
            samples_34 = sample_generator(x3, x4, get)

            count = 0
            for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
                sample_fn = monte_carlo.monte_carlo_kernel_fn(
                    init_fn, apply_fn, key, n, batch_size, device_count,
                    store_on_device)
                sample_12 = sample_fn(x1, x2, get)
                sample_34 = sample_fn(x3, x4, get)
                self.assertAllClose(s_12, sample_12)
                self.assertAllClose(s_12, s_34)
                self.assertAllClose(s_12, sample_34)
                count += 1

            self.assertEqual(log_n_max, count)

            ker_analytic_12 = stax_kernel_fn(x1, x2, get)
            ker_analytic_34 = stax_kernel_fn(x3, x4, get)

        self.assertAllClose(ker_analytic_12, s_12, atol=2., rtol=2.)
        self.assertAllClose(ker_analytic_12, ker_analytic_34)
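
The generator behaviour exercised here can be reproduced in a few lines. The sketch below assumes the public alias nt.monte_carlo_kernel_fn matches the monte_carlo.monte_carlo_kernel_fn used in the test; passing an iterable of sample counts turns the returned kernel function into a generator of progressively refined estimates.

    from jax import random
    import neural_tangents as nt
    from neural_tangents import stax

    init_fn, apply_fn, _ = stax.serial(stax.Dense(64), stax.Relu(), stax.Dense(1))

    x1 = random.normal(random.PRNGKey(1), (8, 10))
    x2 = random.normal(random.PRNGKey(2), (6, 10))

    # With an iterable of sample counts, the returned kernel function yields
    # one estimate per count instead of a single kernel.
    kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, random.PRNGKey(0),
                                         n_samples=[2, 4, 8])
    for n, k in zip([2, 4, 8], kernel_fn(x1, x2, 'nngp')):
        print(n, k.shape)  # (8, 6) at every sample count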
Example No. 5
    def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
        test_utils.stub_out_pmap(batch, 2)
        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)

        kernel_fn = kernel_fn(key, train_shape[1:], network, use_dropout=False)
        kernel_batched = batch._parallel(kernel_fn)

        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other, True)
Example No. 6
    def test_sample_once_batch(self, batch_size, device_count, store_on_device,
                               get):
        test_utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
        kernel_fn = empirical.empirical_kernel_fn(apply_fn)

        sample_once_fn = monte_carlo._sample_once_kernel_fn(kernel_fn, init_fn)
        sample_once_batch_fn = monte_carlo._sample_once_kernel_fn(
            kernel_fn, init_fn, batch_size, device_count, store_on_device)

        one_sample = sample_once_fn(x1, x2, key, get)
        one_sample_batch = sample_once_batch_fn(x1, x2, key, get)
        self.assertAllClose(one_sample, one_sample_batch)
Example No. 7
    def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                     store_on_device):
        test_utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            WIDTH, 256,
            xla_bridge.get_backend().platform == 'tpu')

        sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 200,
                                                   batch_size, device_count,
                                                   store_on_device)

        ker_empirical = sample(x1, x2, 'nngp')
        ker_analytic = stax_kernel_fn(x1, x2, 'nngp')

        test_utils.assert_close_matrices(self, ker_analytic, ker_empirical,
                                         2e-2)
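
Outside the test, the same comparison looks like the following minimal sketch (widths and sample counts are illustrative, and nt.monte_carlo_kernel_fn is assumed to be the public alias of the module used above): the Monte Carlo NNGP estimate should approach the closed-form kernel as the hidden layer widens and more samples are drawn.

    from jax import random
    import jax.numpy as jnp
    import neural_tangents as nt
    from neural_tangents import stax

    # Wide hidden layer so the finite-width sample is close to the analytic kernel.
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(512), stax.Relu(), stax.Dense(1))

    x1 = random.normal(random.PRNGKey(1), (6, 10))
    x2 = random.normal(random.PRNGKey(2), (4, 10))

    mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, random.PRNGKey(0),
                                            n_samples=200)

    nngp_mc = mc_kernel_fn(x1, x2, 'nngp')    # Monte Carlo estimate over 200 draws
    nngp_exact = kernel_fn(x1, x2, 'nngp')    # closed-form infinite-width NNGP
    print(jnp.max(jnp.abs(nngp_mc - nngp_exact)))  # small, but not exactly zero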
Example No. 8
    def test_parallel_in_out(self, same_inputs):
        test_utils.stub_out_pmap(batch, 2)
        rng = random.PRNGKey(0)
        input_key1, input_key2, mc_key = random.split(rng, 3)

        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
        x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))

        x1 = (x1_1, (x1_2, x1_3))
        x2 = (x2_1, (x2_2, x2_3))

        N = WIDTH

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        # Check NNGP.

        readin = net(N)
        readout = net(1)

        K_readin_fn = jit(readin[2])
        K_readout_fn = jit(partial(readout[2], get='nngp'))

        batch_K_readin_fn = batch.batch(K_readin_fn, 2)
        batch_K_readout_fn = batch.batch(K_readout_fn, 2)

        test_utils.assert_close_matrices(
            self, K_readout_fn(K_readin_fn(x1, x2)),
            batch_K_readout_fn(batch_K_readin_fn(x1, x2)), RTOL)

        # Check Both.
        K_readin_fn = jit(readin[2])
        K_readout_fn = jit(partial(readout[2], get=('nngp', 'ntk')))

        batch_K_readin_fn = batch.batch(K_readin_fn, 2)
        batch_K_readout_fn = batch.batch(K_readout_fn, 2)

        get_ntk = utils.nt_tree_fn()(lambda k: k.ntk)

        test_utils.assert_close_matrices(
            self, get_ntk(K_readout_fn(K_readin_fn(x1, x2))),
            get_ntk(batch_K_readout_fn(batch_K_readin_fn(x1, x2))), RTOL)
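
A reduced sketch of the composition pattern above: a stax kernel_fn accepts the Kernel object produced by an upstream kernel_fn, so a network can be split into a read-in and a read-out half and their kernels chained. Layer sizes here are arbitrary.

    from jax import random
    from neural_tangents import stax

    # Read-in half and read-out half of what would otherwise be one serial net.
    _, _, readin_kernel = stax.serial(stax.Dense(32), stax.Relu())
    _, _, readout_kernel = stax.serial(stax.Dense(1))

    x1 = random.normal(random.PRNGKey(1), (5, 10))
    x2 = random.normal(random.PRNGKey(2), (3, 10))

    # The read-out kernel_fn consumes the Kernel produced by the read-in kernel_fn.
    k = readout_kernel(readin_kernel(x1, x2), get='nngp')
    print(k.shape)  # (5, 3)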
Example No. 9
    def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
        test_utils.stub_out_pmap(batch, 2)
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(key, 3)
        key = keys[0]
        self_split = keys[1]
        other_split = keys[2]
        data_self = np.asarray(normal(train_shape, seed=self_split))
        data_other = np.asarray(normal(test_shape, seed=other_split))

        kernel_fn = kernel_fn(key, train_shape[1:], network, use_dropout=False)
        kernel_batched = batch._parallel(kernel_fn)

        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other, True)
Example No. 10
    def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn,
                      batch_size):
        test_utils.stub_out_pmap(batch, 2)

        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)

        kernel_fn = kernel_fn(key, train_shape[1:], network)

        kernel_batched = batch.batch(kernel_fn, batch_size=batch_size)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

        kernel_batched = batch.batch(kernel_fn,
                                     batch_size=batch_size,
                                     store_on_device=False)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)
Example No. 11
    def testComposition(self, train_shape, test_shape, network, name,
                        kernel_fn, batch_size):
        test_utils.stub_out_pmap(batch, 2)

        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)

        kernel_fn = kernel_fn(key, train_shape[1:], network)

        kernel_batched = batch._parallel(
            batch._serial(kernel_fn, batch_size=batch_size))
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

        kernel_batched = batch._serial(batch._parallel(kernel_fn),
                                       batch_size=batch_size)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)
Example No. 12
    def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                         store_on_device):
        test_utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            256, 2,
            xla_bridge.get_backend().platform == 'tpu')

        sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 100,
                                                   batch_size, device_count,
                                                   store_on_device)

        ker_empirical = sample(x1, x2, 'ntk')
        ker_empirical = (np.sum(ker_empirical, axis=(-1, -2)) /
                         ker_empirical.shape[-1])

        ker_analytic = stax_kernel_fn(x1, x2, 'ntk')

        test_utils.assert_close_matrices(self, ker_analytic, ker_empirical,
                                         2e-2)
Example No. 13
    def test_jit_or_pmap_broadcast(self):
        def kernel_fn(x1,
                      x2,
                      do_flip,
                      keys,
                      do_square,
                      params,
                      _unused=None,
                      p=0.65):
            res = np.abs(np.matmul(x1, x2))
            if do_square:
                res *= res
            if do_flip:
                res = -res

            res *= random.uniform(keys) * p
            return [res, params]

        params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
        x2 = np.arange(0, 10).reshape((10, ))
        keys = random.PRNGKey(1)

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=0)
        x1 = np.arange(0, 10).reshape((1, 10))
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=0):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=True,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=True)
                    self.assertAllClose(res_1, res_2)

        test_utils.stub_out_pmap(batch, 1)
        x1 = np.arange(0, 10).reshape((1, 10))
        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=1)
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=1):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=False,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None)
                    self.assertAllClose(res_1[0], res_2[0])
                    self.assertAllClose(
                        tree_map(partial(np.expand_dims, axis=0), res_1[1]),
                        res_2[1])

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=2)
        x1 = np.arange(0, 20).reshape((2, 10))
        test_utils.stub_out_pmap(batch, 2)

        def broadcast(arg):
            return np.broadcast_to(arg, (2, ) + arg.shape)

        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=2):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      p=0.2)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None,
                                              p=0.2)
                    self.assertAllClose(res_1[0][0], res_2[0][0])
                    self.assertAllClose(res_1[0][1], res_2[0][1])
                    self.assertAllClose(tree_map(broadcast, res_1[1]),
                                        res_2[1])
Example No. 14
    def testAnalyticKernelComposeAutomatic(self, store_on_device, batch_size):
        test_utils.stub_out_pmap(batch, 2)
        self._test_analytic_kernel_composition(
            partial(batch.batch,
                    batch_size=batch_size,
                    store_on_device=store_on_device))
Example No. 15
    def testAnalyticKernelComposeParallel(self):
        test_utils.stub_out_pmap(batch, 2)
        self._test_analytic_kernel_composition(batch._parallel)