Ejemplo n.º 1
0
    def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn,
                      batch_size):
        """Compares ``batch.batch``-wrapped kernels with the reference kernel.

        The batched kernel is checked twice: once with the default storage
        setting and once with ``store_on_device=False``.
        """
        test_utils.stub_out_pmap(batch, 2)

        # One int32 seed pair, split into three: the kernel factory seed and
        # one seed for each of the two data arrays.
        root_seed = stateless_uniform(shape=[2],
                                      seed=[0, 0],
                                      minval=None,
                                      maxval=None,
                                      dtype=tf.int32)
        splits = tf_random_split(root_seed, 3)
        key, self_split, other_split = splits[0], splits[1], splits[2]

        data_self = np.asarray(normal(train_shape, seed=self_split))
        data_other = np.asarray(normal(test_shape, seed=other_split))

        kernel_fn = kernel_fn(key, train_shape[1:], network)

        # First pass uses batch.batch's default storage, second forces
        # store_on_device=False.
        for storage_kwargs in ({}, {'store_on_device': False}):
            kernel_batched = batch.batch(kernel_fn,
                                         batch_size=batch_size,
                                         **storage_kwargs)
            _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                         data_self, data_other)
Ejemplo n.º 2
0
    def test_parallel_in_out_empirical(self, same_inputs):
        """Checks batching of empirical NNGP/NTK functions on nested inputs.

        The network takes and returns a nested ``(x, (x, x))`` structure. The
        batched kernel (``batch.batch(..., 2)``) must match the unbatched one.

        Args:
          same_inputs: if True, kernels are evaluated with ``x2=None``, i.e.
            ``x1`` against itself; otherwise against a separate ``x2``.
        """
        test_utils.stub_out_pmap(batch, 2)
        rng = random.PRNGKey(0)
        input_key1, input_key2, net_key = random.split(rng, 3)

        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
        x1 = (x1_1, (x1_2, x1_3))

        # `same_inputs` selects the `x2=None` code path. Without this branch
        # the parameter had no effect and that path was never tested.
        if same_inputs:
            x2 = None
        else:
            x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))
            x2 = (x2_1, (x2_2, x2_3))

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        # Check NNGP.
        init_fn, apply_fn, _ = net(WIDTH)
        _, params = init_fn(net_key, ((-1, 10), ((-1, 10), (-1, 10))))

        kernel_fn = jit(empirical.empirical_nngp_fn(apply_fn))
        batch_kernel_fn = jit(batch.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)

        # Check NTK.
        init_fn, apply_fn, _ = stax.serial(net(WIDTH), net(1))
        _, params = init_fn(net_key, ((-1, 10), ((-1, 10), (-1, 10))))

        kernel_fn = jit(empirical.empirical_ntk_fn(apply_fn))
        batch_kernel_fn = jit(batch.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)
Ejemplo n.º 3
0
    def testPredictOnCPU(self):
        """Checks GP predictions made with a batched kernel function.

        Covers every combination of ``store_on_device``, ``device_count`` and
        ``get``. The infinite-time output of ``gradient_descent_mse_gp`` must
        be the same whether queried with ``None`` or ``np.inf``, and must match
        ``gp_inference``.
        """
        x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
        x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))
        y_train = random.uniform(random.PRNGKey(1), (10, 7))

        _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(),
                                      stax.Flatten(), stax.Dense(1))

        get_options = ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]
        settings = [(on_device, n_devices, get_opt)
                    for on_device in [False, True]
                    for n_devices in [0, 1]
                    for get_opt in get_options]

        for store_on_device, device_count, get in settings:
            with self.subTest(store_on_device=store_on_device,
                              device_count=device_count,
                              get=get):
                batched_kernel_fn = batch.batch(kernel_fn, 2, device_count,
                                                store_on_device)
                # Both predictors receive identical positional arguments.
                shared_args = (batched_kernel_fn, x_train, y_train, x_test,
                               get, 0., True)
                predictor = predict.gradient_descent_mse_gp(*shared_args)
                exact = predict.gp_inference(*shared_args)

                self.assertAllClose(predictor(None), predictor(np.inf), True)
                self.assertAllClose(predictor(None), exact, True)
Ejemplo n.º 4
0
  def testPredictOnCPU(self):
    """Checks ensemble predictions from a batched kernel and their placement.

    For every storage / device / `get` / test-input combination, the
    infinite-time predictions queried with `None` and `np.inf` must agree.
    When test inputs are given, the outputs must be on CPU exactly when
    `store_on_device` is False or the backend platform is 'cpu'.
    """
    x_train = random.normal(random.PRNGKey(1), (4, 4, 4, 2))
    x_test = random.normal(random.PRNGKey(1), (8, 4, 4, 2))
    y_train = random.uniform(random.PRNGKey(1), (4, 2))

    _, _, kernel_fn = stax.serial(
        stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

    get_options = ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]
    cases = [(on_device, n_devices, get_opt, x_label)
             for on_device in [False, True]
             for n_devices in [0, 1]
             for get_opt in get_options
             for x_label in [None, 'x_test']]

    for store_on_device, device_count, get, x_label in cases:
      with self.subTest(
          store_on_device=store_on_device,
          device_count=device_count,
          get=get,
          x=x_label):
        kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                        store_on_device)
        predictor = predict.gradient_descent_mse_ensemble(
            kernel_fn_batched, x_train, y_train)

        x = None if x_label is None else x_test
        predict_none = predictor(None, x, get, compute_cov=True)
        predict_inf = predictor(np.inf, x, get, compute_cov=True)
        self.assertAllClose(predict_none, predict_inf)

        if x is None:
          continue

        if store_on_device:
          on_cpu = xla_bridge.get_backend().platform == 'cpu'
        else:
          on_cpu = True
        self.assertEqual(on_cpu, utils.is_on_cpu(predict_inf))
        self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))
Ejemplo n.º 5
0
    def test_empirical_ntk_diagonal_outputs(self, same_inputs, device_count,
                                            trace_axes, diagonal_axes):
        """Checks that batching preserves an empirical NTK with axis options.

        Args:
          same_inputs: if True, the kernel is computed with ``x2=None`` (the
            inputs against themselves); otherwise a separate ``x2`` is used.
          device_count: forwarded to ``batch.batch``.
          trace_axes: forwarded to ``empirical.empirical_ntk_fn``.
          diagonal_axes: forwarded to ``empirical.empirical_ntk_fn``.
        """
        test_utils.stub_out_pmap(batch, 2)
        rng = random.PRNGKey(0)

        input_key1, input_key2, net_key = random.split(rng, 3)

        init_fn, apply_fn, _ = stax.serial(stax.Dense(256), stax.Relu(),
                                           stax.Dense(10))

        test_x1 = random.normal(input_key1, (50, 4, 4))
        # `x2=None` means "same inputs". A separate x2 is needed only when
        # `same_inputs` is False (the original condition was inverted).
        test_x2 = None
        if not same_inputs:
            test_x2 = random.normal(input_key2, (60, 4, 4))

        kernel_fn = empirical.empirical_ntk_fn(apply_fn,
                                               trace_axes=trace_axes,
                                               diagonal_axes=diagonal_axes,
                                               vmap_axes=0,
                                               implementation=2)

        _, params = init_fn(net_key, test_x1.shape)

        true_kernel = kernel_fn(test_x1, test_x2, params)
        batched_fn = batch.batch(kernel_fn,
                                 device_count=device_count,
                                 batch_size=5)
        batch_kernel = batched_fn(test_x1, test_x2, params)
        self.assertAllClose(true_kernel, batch_kernel)
Ejemplo n.º 6
0
  def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn):
    """Compares `batch.batch(kernel_fn, batch_size=2)` with `kernel_fn`.

    The comparison runs with the default storage setting and again with
    `store_on_device=False`.
    """
    utils.stub_out_pmap(batch, 2)

    key, self_split, other_split = random.split(random.PRNGKey(0), 3)
    data_self = random.normal(self_split, train_shape)
    data_other = random.normal(other_split, test_shape)

    kernel_fn = kernel_fn(key, train_shape[1:], network)

    for storage_kwargs in ({}, dict(store_on_device=False)):
      kernel_batched = batch.batch(kernel_fn, batch_size=2, **storage_kwargs)
      _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                                   data_other)
Ejemplo n.º 7
0
    def test_parallel_in_out(self, same_inputs):
        """Checks batching of analytic kernels on nested (parallel) inputs.

        A read-in network's kernel feeds a read-out network's kernel. Unbatched
        and ``batch.batch``-wrapped versions must agree, first for the NNGP
        alone and then for the NTK when both kernels are requested.

        Args:
          same_inputs: if True, kernels are computed with ``x2=None`` (``x1``
            against itself); otherwise against a separate ``x2``.
        """
        test_utils.stub_out_pmap(batch, 2)
        rng = random.PRNGKey(0)
        input_key1, input_key2, _ = random.split(rng, 3)

        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
        x1 = (x1_1, (x1_2, x1_3))

        # `same_inputs` selects the `x2=None` code path. Without this branch
        # the parameter had no effect and that path was never tested.
        if same_inputs:
            x2 = None
        else:
            x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))
            x2 = (x2_1, (x2_2, x2_3))

        N = WIDTH

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        readin = net(N)
        readout = net(1)

        # Both checks below share the same read-in kernel function.
        K_readin_fn = jit(readin[2])
        batch_K_readin_fn = batch.batch(K_readin_fn, 2)

        # Check NNGP.
        K_readout_fn = jit(partial(readout[2], get='nngp'))
        batch_K_readout_fn = batch.batch(K_readout_fn, 2)

        test_utils.assert_close_matrices(
            self, K_readout_fn(K_readin_fn(x1, x2)),
            batch_K_readout_fn(batch_K_readin_fn(x1, x2)), RTOL)

        # Check Both.
        K_readout_fn = jit(partial(readout[2], get=('nngp', 'ntk')))
        batch_K_readout_fn = batch.batch(K_readout_fn, 2)

        get_ntk = utils.nt_tree_fn()(lambda k: k.ntk)

        test_utils.assert_close_matrices(
            self, get_ntk(K_readout_fn(K_readin_fn(x1, x2))),
            get_ntk(batch_K_readout_fn(batch_K_readin_fn(x1, x2))), RTOL)
Ejemplo n.º 8
0
    def testPredictOnCPU(self):
        """Checks ensemble predictions from a batched kernel and their placement.

        For every storage / device / ``get`` / test-input combination, the
        infinite-time predictions queried with ``None`` and ``np.inf`` must
        agree. When test inputs are given, the outputs must be on CPU exactly
        when ``store_on_device`` is False or the backend platform is 'cpu'.
        """

        def make_key():
            # Draws an int32 seed pair using the fixed seed [1, 1].
            return stateless_uniform(shape=[2],
                                     seed=[1, 1],
                                     minval=None,
                                     maxval=None,
                                     dtype=tf.int32)

        key1, key2, key3 = make_key(), make_key(), make_key()
        x_train = np.asarray(normal((4, 4, 4, 2), seed=key1))
        x_test = np.asarray(normal((8, 4, 4, 2), seed=key2))
        y_train = np.asarray(stateless_uniform(shape=(4, 2), seed=key3))

        _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(),
                                      stax.Flatten(), stax.Dense(1))

        get_options = ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]
        cases = [(on_device, n_devices, get_opt, x_label)
                 for on_device in [False, True]
                 for n_devices in [0, 1]
                 for get_opt in get_options
                 for x_label in [None, 'x_test']]

        for store_on_device, device_count, get, x_label in cases:
            with self.subTest(store_on_device=store_on_device,
                              device_count=device_count,
                              get=get,
                              x=x_label):
                kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                                store_on_device)
                predictor = predict.gradient_descent_mse_ensemble(
                    kernel_fn_batched, x_train, y_train)

                x = None if x_label is None else x_test
                predict_none = predictor(None, x, get, compute_cov=True)
                predict_inf = predictor(np.inf, x, get, compute_cov=True)
                self.assertAllClose(predict_none, predict_inf)

                if x is None:
                    continue

                if store_on_device:
                    on_cpu = xla_bridge.get_backend().platform == 'cpu'
                else:
                    on_cpu = True
                self.assertEqual(on_cpu, utils.is_on_cpu(predict_inf))
                self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))
Ejemplo n.º 9
0
    def test_batch_sample_once(self, batch_size, device_count, store_on_device,
                               get):
        """Checks that batching a single Monte Carlo kernel sample is exact."""
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
        sample_once_fn = monte_carlo._sample_once_kernel_fn(
            empirical.empirical_kernel_fn(apply_fn), init_fn, device_count=0)
        batch_sample_once_fn = batch.batch(sample_once_fn,
                                           batch_size=batch_size,
                                           device_count=device_count,
                                           store_on_device=store_on_device)

        expected = sample_once_fn(x1, x2, key, get)
        actual = batch_sample_once_fn(x1, x2, key, get)
        self.assertAllClose(expected, actual, True)
Ejemplo n.º 10
0
    def test_batch_sample_once(self, batch_size, device_count,
                               store_on_device):
        """Compares one Monte Carlo kernel sample with its batched counterpart."""
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fun, apply_fun, _, key = _get_inputs_and_model()
        ker_fun = empirical.get_ker_fun_empirical(apply_fun)

        reference_fun = monte_carlo._get_ker_fun_sample_once(ker_fun, init_fun)
        reference = reference_fun(x1, x2, key)

        # The batched version wraps a separately constructed sampler.
        fresh_sampler = monte_carlo._get_ker_fun_sample_once(ker_fun, init_fun)
        batched_fun = batch.batch(fresh_sampler, batch_size, device_count,
                                  store_on_device)
        self.assertAllClose(reference, batched_fun(x1, x2, key), True)
Ejemplo n.º 11
0
    def test_batch_sample_once(self, batch_size, device_count, store_on_device,
                               get):
        """Checks that batching a single Monte Carlo kernel sample is exact.

        Skipped when ``get`` is None.
        """
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
        kernel_fn = empirical.empirical_kernel_fn(apply_fn)
        sample_once_fn = monte_carlo._sample_once_kernel_fn(kernel_fn,
                                                            init_fn,
                                                            device_count=0)
        batch_sample_once_fn = batch.batch(sample_once_fn, batch_size,
                                           device_count, store_on_device)

        if get is None:
            raise jtu.SkipTest('No default `get` values for this method.')

        reference = sample_once_fn(x1, x2, key, get)
        batched = batch_sample_once_fn(x1, x2, key, get)
        self.assertAllClose(reference, batched, True)
Ejemplo n.º 12
0
    def test_sample_many_batch(self, batch_size, device_count,
                               store_on_device):
        """Compares ``N_SAMPLES`` Monte Carlo samples with and without batching."""
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fun, apply_fun, _, key = _get_inputs_and_model()
        ker_fun = empirical.get_ker_fun_empirical(apply_fun)

        sample_once_fun = monte_carlo._get_ker_fun_sample_once(
            ker_fun, init_fun)
        sample_many_fun = monte_carlo._get_ker_fun_sample_many(sample_once_fun)

        batched_once_fun = batch.batch(sample_once_fun, batch_size,
                                       device_count, store_on_device)
        sample_many_batch_fun = monte_carlo._get_ker_fun_sample_many(
            batched_once_fun)

        self.assertAllClose(sample_many_fun(x1, x2, key, N_SAMPLES),
                            sample_many_batch_fun(x1, x2, key, N_SAMPLES),
                            True)
Ejemplo n.º 13
0
  def test_kwargs(self, do_batch, mode):
    """Checks that keyword arguments are forwarded consistently by `predict`.

    Builds a network containing `stax.Aggregate`, which receives a `pattern`
    keyword argument. Predictions computed from precomputed kernels
    (`predict.gp_inference`, `predict.gradient_descent_mse`) are compared with
    `predict.gradient_descent_mse_ensemble` driven directly by `kernel_fn`
    and the same keyword arguments.

    Args:
      do_batch: whether to batch the kernel computation. It is honored for
        'mc' (via `batch_size`) and 'analytic' (via `batch.batch`). For
        'empirical' the test is skipped.
      mode: 'mc', 'empirical' or 'analytic'; selects how `kernel_fn` is
        built. Any other value raises `ValueError`.
    """
    rng = random.PRNGKey(1)

    x_train = random.normal(rng, (8, 7, 10))
    x_test = random.normal(rng, (4, 7, 10))
    y_train = random.normal(rng, (8, 1))

    rng_train, rng_test = random.split(rng, 2)

    pattern_train = random.normal(rng, (8, 7, 7))
    pattern_test = random.normal(rng, (4, 7, 7))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(8),
        stax.Relu(),
        stax.Dropout(rate=0.4),
        stax.Aggregate(),
        stax.GlobalAvgPool(),
        stax.Dense(1)
    )

    # Keyword arguments for the train-train (dd), test-train (td) and
    # test-test (tt) kernel calls below. `pattern` appears to be ordered as
    # (pattern for x1, pattern for x2) — matches the x1/x2 order of those calls.
    kw_dd = dict(pattern=(pattern_train, pattern_train))
    kw_td = dict(pattern=(pattern_test, pattern_train))
    kw_tt = dict(pattern=(pattern_test, pattern_test))

    if mode == 'mc':
      kernel_fn = monte_carlo_kernel_fn(init_fn, apply_fn, rng, 2,
                                        batch_size=2 if do_batch else 0)

    elif mode == 'empirical':
      kernel_fn = empirical_kernel_fn(apply_fn)
      if do_batch:
        raise absltest.SkipTest('Batching of empirical kernel is not '
                                'implemented with keyword arguments.')

      # The empirical kernel additionally receives the network parameters and
      # an explicit `get`; the three dicts are mutated in place.
      for kw in (kw_dd, kw_td, kw_tt):
        kw.update(dict(params=init_fn(rng, x_train.shape)[1],
                       get=('nngp', 'ntk')))

      # Per-input `rng` pairs, presumably consumed by the Dropout layer —
      # confirm against stax.Dropout.
      kw_dd.update(dict(rng=(rng_train, None)))
      kw_td.update(dict(rng=(rng_test, rng_train)))
      kw_tt.update(dict(rng=(rng_test, None)))

    elif mode == 'analytic':
      if do_batch:
        kernel_fn = batch.batch(kernel_fn, batch_size=2)

    else:
      raise ValueError(mode)

    k_dd = kernel_fn(x_train, None, **kw_dd)
    k_td = kernel_fn(x_test, x_train, **kw_td)
    k_tt = kernel_fn(x_test, None, **kw_tt)

    # Infinite time NNGP/NTK.
    predict_fn_gp = predict.gp_inference(k_dd, y_train)
    out_gp = predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)

    # `get` must be removed: the ensemble predictor is called below with its
    # own `get=` argument alongside `**kw_tt` / `**kw_dd`.
    if mode == 'empirical':
      for kw in (kw_dd, kw_td, kw_tt):
        kw.pop('get')

    predict_fn_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                                x_train,
                                                                y_train,
                                                                **kw_dd)
    out_ensemble = predict_fn_ensemble(x_test=x_test, compute_cov=True, **kw_tt)
    self.assertAllClose(out_gp, out_ensemble)

    # Finite time NTK test.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.ntk, y_train)
    out_mse = predict_fn_mse(t=1.,
                             fx_train_0=None,
                             fx_test_0=0.,
                             k_test_train=k_td.ntk)
    out_ensemble = predict_fn_ensemble(t=1.,
                                       get='ntk',
                                       x_test=x_test,
                                       compute_cov=False,
                                       **kw_tt)
    self.assertAllClose(out_mse, out_ensemble)

    # Finite time NNGP train.
    # Only train-set predictions are compared (x_test=None / fx_test_0=None).
    predict_fn_mse = predict.gradient_descent_mse(k_dd.nngp, y_train)
    out_mse = predict_fn_mse(t=2.,
                             fx_train_0=0.,
                             fx_test_0=None,
                             k_test_train=k_td.nngp)
    out_ensemble = predict_fn_ensemble(t=2.,
                                       get='nngp',
                                       x_test=None,
                                       compute_cov=False,
                                       **kw_dd)
    self.assertAllClose(out_mse, out_ensemble)