Example #1
        def mc_sampling(count=10):
            empirical_mean = 0.
            key = random.PRNGKey(100)
            init_fn, f, _ = _build_network(train_shape[1:], network,
                                           out_logits)
            _kernel_fn = empirical.empirical_kernel_fn(f)
            kernel_fn = jit(
                lambda x1, x2, params: _kernel_fn(x1, x2, params, 'ntk'))

            for _ in range(count):
                key, split = random.split(key)
                _, params = init_fn(split, train_shape)

                g_dd = kernel_fn(data_train, None, params)
                g_td = kernel_fn(data_test, data_train, params)
                predictor = predict.gradient_descent_mse(
                    g_dd, data_labels, g_td)

                fx_initial_train = f(params, data_train)
                fx_initial_test = f(params, data_test)

                _, fx_pred_test = predictor(1.0e8, fx_initial_train,
                                            fx_initial_test)
                empirical_mean += fx_pred_test
            return empirical_mean / count
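
Each `predictor(1.0e8, ...)` call above approximates the t → ∞ limit of gradient-descent MSE training of the linearized model. For reference, a hedged numpy sketch of that closed-form limit for a well-conditioned `g_dd` (toy arrays, not a call into `neural_tangents`):

```python
import numpy as onp

def infinite_time_test_prediction(g_dd, g_td, y_train, fx0_train, fx0_test):
    # t -> infinity limit of gradient-descent MSE training of the linearized
    # model: f_inf(x_test) = f_0(x_test) + g_td g_dd^{-1} (y_train - f_0(x_train)).
    return fx0_test + g_td @ onp.linalg.solve(g_dd, y_train - fx0_train)
```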
Example #2
        def mc_sampling(count=10):
            key = random.PRNGKey(100)
            init_fn, f, _ = _build_network(train_shape[1:], network,
                                           out_logits)
            _kernel_fn = empirical.empirical_kernel_fn(f)
            kernel_fn = jit(
                lambda x1, x2, params: _kernel_fn(x1, x2, params, 'ntk'))
            collect_test_predict = []
            for _ in range(count):
                key, split = random.split(key)
                _, params = init_fn(split, train_shape)

                g_dd = kernel_fn(x_train, None, params)
                g_td = kernel_fn(x_test, x_train, params)
                predictor = predict.gradient_descent_mse(g_dd, y_train, g_td)

                fx_initial_train = f(params, x_train)
                fx_initial_test = f(params, x_test)

                _, fx_pred_test = predictor(1.0e8, fx_initial_train,
                                            fx_initial_test)
                collect_test_predict.append(fx_pred_test)
            collect_test_predict = np.array(collect_test_predict)
            mean_emp = np.mean(collect_test_predict, axis=0)
            mean_subtracted = collect_test_predict - mean_emp
            cov_emp = np.einsum(
                'ijk,ilk->jl', mean_subtracted, mean_subtracted,
                optimize=True) / (mean_subtracted.shape[0] *
                                  mean_subtracted.shape[-1])
            return mean_emp, cov_emp
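
The `np.einsum('ijk,ilk->jl', ...)` above contracts over both the sample axis `i` and the logit axis `k`. A small self-contained NumPy check of what that covariance estimator computes (toy shapes; `np` here is plain NumPy):

```python
import numpy as np

samples = np.random.randn(100, 5, 3)          # (n_samples, n_test, n_logits)
centered = samples - samples.mean(axis=0)
cov = np.einsum('ijk,ilk->jl', centered, centered) / (
    centered.shape[0] * centered.shape[-1])   # (n_test, n_test)

# Equivalent loop form: average the outer products of centered test
# predictions over samples and logits.
cov_loop = np.zeros((5, 5))
for i in range(centered.shape[0]):
    for k in range(centered.shape[-1]):
        cov_loop += np.outer(centered[i, :, k], centered[i, :, k])
cov_loop /= centered.shape[0] * centered.shape[-1]
assert np.allclose(cov, cov_loop)
```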
Example #3
    def testNTKMeanCovPrediction(self, train_shape, test_shape, network,
                                 out_logits):
        key, x_test, x_train, y_train = self._get_inputs(
            out_logits, test_shape, train_shape)
        init_fn, f, kernel_fn = stax.serial(
            stax.Dense(512, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        reg = 1e-6
        predictor = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                          x_train,
                                                          y_train,
                                                          diag_reg=reg)
        ts = np.array([1., 5., 10.])

        fx_test_inf, cov_test_inf = predictor(ts, x_test, 'ntk', True)
        self.assertEqual(cov_test_inf.shape[1], x_test.shape[0])
        self.assertGreater(np.min(np.linalg.eigh(cov_test_inf)[0]), -1e-8)

        fx_train_inf, cov_train_inf = predictor(ts, None, 'ntk', True)
        self.assertEqual(cov_train_inf.shape[1], x_train.shape[0])
        self.assertGreater(np.min(np.linalg.eigh(cov_train_inf)[0]), -1e-8)

        _kernel_fn = empirical.empirical_kernel_fn(f)
        kernel_fn = jit(
            lambda x1, x2, params: _kernel_fn(x1, x2, 'ntk', params))

        def predict_empirical(key):
            _, params = init_fn(key, train_shape)
            g_dd = kernel_fn(x_train, None, params)
            g_td = kernel_fn(x_test, x_train, params)
            predict_fn = predict.gradient_descent_mse(g_dd,
                                                      y_train,
                                                      diag_reg=reg)
            fx_train_0 = f(params, x_train)
            fx_test_0 = f(params, x_test)
            return predict_fn(ts, fx_train_0, fx_test_0, g_td)

        def predict_mc(count, key):
            key = random.split(key, count)
            fx_train, fx_test = vmap(predict_empirical)(key)
            fx_train_mean = np.mean(fx_train, axis=0, keepdims=True)
            fx_test_mean = np.mean(fx_test, axis=0, keepdims=True)

            fx_train_centered = fx_train - fx_train_mean
            fx_test_centered = fx_test - fx_test_mean

            cov_train = PredictTest._cov_empirical(fx_train_centered)
            cov_test = PredictTest._cov_empirical(fx_test_centered)

            return fx_train_mean, fx_test_mean, cov_train, cov_test

        fx_train_mc, fx_test_mc, cov_train_mc, cov_test_mc = predict_mc(
            4096, key)
        rtol = 0.05
        self._assertAllClose(fx_train_mc, fx_train_inf, rtol)
        self._assertAllClose(cov_train_mc, cov_train_inf, rtol)
        self._assertAllClose(cov_test_mc, cov_test_inf, rtol)
        self._assertAllClose(fx_test_mc, fx_test_inf, rtol)
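
Example #3 vectorizes the per-initialization computation by splitting one PRNG key into many and applying `vmap` over them. The pattern in isolation, with a hypothetical stand-in for `predict_empirical`:

```python
import jax.numpy as jnp
from jax import random, vmap

def per_init_prediction(key):
    # Stand-in for `predict_empirical`: draw toy "parameters", return an output.
    w = random.normal(key, (3,))
    return jnp.sum(w ** 2)

keys = random.split(random.PRNGKey(0), 4096)   # shape (4096, 2): one key per sample
outputs = vmap(per_init_prediction)(keys)      # one prediction per random draw
mc_mean = jnp.mean(outputs, axis=0)            # Monte Carlo mean over initializations
```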
Example #4
  def testGpInference(self):
    reg = 1e-5
    key = random.PRNGKey(1)
    x_train = random.normal(key, (4, 2))
    init_fn, apply_fn, kernel_fn_analytic = stax.serial(
        stax.Dense(32, 2., 0.5),
        stax.Relu(),
        stax.Dense(10, 2., 0.5))
    y_train = random.normal(key, (4, 10))
    for kernel_fn_is_analytic in [True, False]:
      if kernel_fn_is_analytic:
        kernel_fn = kernel_fn_analytic
      else:
        _, params = init_fn(key, x_train.shape)
        kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)
        def kernel_fn(x1, x2, get):
          return kernel_fn_empirical(x1, x2, get, params)

      for get in [None,
                  'nngp', 'ntk',
                  ('nngp',), ('ntk',),
                  ('nngp', 'ntk'), ('ntk', 'nngp')]:
        k_dd = kernel_fn(x_train, None, get)

        gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg)
        gd_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                            x_train,
                                                            y_train,
                                                            diag_reg=reg)
        for x_test in [None, 'x_test']:
          x_test = None if x_test is None else random.normal(key, (8, 2))
          k_td = None if x_test is None else kernel_fn(x_test, x_train, get)

          for compute_cov in [True, False]:
            with self.subTest(kernel_fn_is_analytic=kernel_fn_is_analytic,
                              get=get,
                              x_test=x_test if x_test is None else 'x_test',
                              compute_cov=compute_cov):
              if compute_cov:
                nngp_tt = (True if x_test is None else
                           kernel_fn(x_test, None, 'nngp'))
              else:
                nngp_tt = None

              out_ens = gd_ensemble(None, x_test, get, compute_cov)
              out_ens_inf = gd_ensemble(np.inf, x_test, get, compute_cov)
              self._assertAllClose(out_ens_inf, out_ens, 0.08)

              if (get is not None and
                  'nngp' not in get and
                  compute_cov and
                  k_td is not None):
                with self.assertRaises(ValueError):
                  out_gp_inf = gp_inference(get=get, k_test_train=k_td,
                                            nngp_test_test=nngp_tt)
              else:
                out_gp_inf = gp_inference(get=get, k_test_train=k_td,
                                          nngp_test_test=nngp_tt)
                self.assertAllClose(out_ens, out_gp_inf)
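
For the mean prediction, both `predict.gp_inference` and the `t = np.inf` ensemble above are expected to reduce to kernel ridge regression with the respective kernel. A hedged NumPy sketch of that shared mean (covariance terms and the library's exact `diag_reg` scaling are omitted):

```python
import numpy as np

def posterior_mean(k_test_train, k_train_train, y_train, diag_reg=1e-5):
    # mean = k_td (k_dd + reg * I)^{-1} y ; the precise scaling of `diag_reg`
    # (absolute vs. relative to the kernel scale) is a library detail not
    # reproduced here.
    n = k_train_train.shape[0]
    return k_test_train @ np.linalg.solve(
        k_train_train + diag_reg * np.eye(n), y_train)
```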
Example #5
    def test_sample_once_batch(self, batch_size, device_count, store_on_device,
                               get):
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
        kernel_fn = empirical.empirical_kernel_fn(apply_fn)

        sample_once_fn = monte_carlo._sample_once_kernel_fn(kernel_fn, init_fn)
        sample_once_batch_fn = monte_carlo._sample_once_kernel_fn(
            kernel_fn, init_fn, batch_size, device_count, store_on_device)

        one_sample = sample_once_fn(x1, x2, key, get)
        one_sample_batch = sample_once_batch_fn(x1, x2, key, get)
        self.assertAllClose(one_sample, one_sample_batch, True)
Example #6
    def test_batch_sample_once(self, batch_size, device_count, store_on_device,
                               get):
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
        kernel_fn = empirical.empirical_kernel_fn(apply_fn)
        sample_once_fn = monte_carlo._sample_once_kernel_fn(kernel_fn,
                                                            init_fn,
                                                            device_count=0)
        batch_sample_once_fn = batch.batch(sample_once_fn, batch_size,
                                           device_count, store_on_device)
        if get is None:
            raise jtu.SkipTest('No default `get` values for this method.')
        else:
            one_sample = sample_once_fn(x1, x2, key, get)
            one_batch_sample = batch_sample_once_fn(x1, x2, key, get)
            self.assertAllClose(one_sample, one_batch_sample, True)
Example #7
def monte_carlo_kernel_fn(
    init_fn: InitFn,
    apply_fn: ApplyFn,
    key: PRNGKey,
    n_samples: Union[int, Iterable[int]],
    batch_size: int = 0,
    device_count: int = -1,
    store_on_device: bool = True,
    trace_axes: Axes = (-1, ),
    diagonal_axes: Axes = ()
) -> MonteCarloKernelFn:
    """Return a Monte Carlo sampler of NTK and NNGP kernels of a given function.

  Note that the returned function is appropriately batched / parallelized. You
  don't need to apply the `nt.batch` or `jax.jit` decorators to it. Further,
  you do not need to apply `jax.jit` to the input `apply_fn` function, as the
  resulting empirical kernel function is JITted internally.

  Args:
    init_fn:
      a function initializing parameters of the neural network. From
      `jax.experimental.stax`: "takes an rng key and an input shape and returns
      an `(output_shape, params)` pair".
    apply_fn:
      a function computing the output of the neural network.
      From `jax.experimental.stax`: "takes params, inputs, and an rng key and
      applies the layer".
    key:
      RNG (`jax.random.PRNGKey`) for sampling random networks. Must have
      shape `(2,)`.
    n_samples:
      number of Monte Carlo samples. Can be either an integer or an
      iterable of integers at which the resulting generator will yield
      estimates. Example: use `n_samples=[2**k for k in range(10)]` for the
      generator to yield estimates using 1, 2, 4, ..., 512 Monte Carlo samples.
    batch_size: an integer making the kernel be computed in batches of `x1` and
      `x2` of this size. `0` means computing the whole kernel. Must divide
      `x1.shape[0]` and `x2.shape[0]`.
    device_count:
      an integer making the kernel be computed in parallel across
      this number of devices (e.g. GPUs or TPU cores). `-1` means use all
      available devices. `0` means compute on a single device sequentially. If
      not `0`, must divide `x1.shape[0]`.
    store_on_device:
      a boolean, indicating whether to store the resulting
      kernel on the device (e.g. GPU or TPU), or in the CPU RAM, where larger
      kernels may fit.
    trace_axes:
      output axes to trace the output kernel over, i.e. compute only the trace
      of the covariance along the respective pair of axes (one pair for each
      axis in `trace_axes`). This allows you to save space and compute if you
      are only interested in the respective trace, and can also improve
      approximation accuracy if you know that the covariance along these pairs
      of axes converges to a `constant * identity matrix` in the limit of
      interest (e.g. infinite width or infinite `n_samples`). A common use
      case is the channel / feature / logit axis, since activation slices
      along such an axis are i.i.d.
      and the respective covariance along the respective pair of axes indeed
      converges to a constant-diagonal matrix in the infinite width or infinite
      `n_samples` limit.
      Also related to "contracting dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)
    diagonal_axes:
      output axes to diagonalize the output kernel over, i.e. compute only the
      diagonal of the covariance along the respective pair of axes (one pair for
      each axis in `diagonal_axes`). This allows you to save space and compute
      if off-diagonal values along these axes are not needed, and can also
      improve approximation accuracy if their limiting value is known
      theoretically,
      e.g. if they vanish in the limit of interest (e.g. infinite
      width or infinite `n_samples`). If you further know that on-diagonal
      values converge to the same constant in your limit of interest, you should
      specify these axes in `trace_axes` instead, to save even more compute and
      gain even more accuracy. A common use case is computing the variance
      (instead of covariance) along certain axes.
      Also related to "batch dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  Returns:
    If `n_samples` is an integer, returns a function of signature
    `kernel_fn(x1, x2, get)` that returns an MC estimation of the kernel using
    `n_samples`. If `n_samples` is a collection of integers,
    `kernel_fn(x1, x2, get)` returns a generator that yields estimates using
    `n` samples for `n in n_samples`.

  Example:
    >>> from jax import random
    >>> import neural_tangents as nt
    >>> from neural_tangents import stax
    >>>
    >>> key1, key2 = random.split(random.PRNGKey(1), 2)
    >>> x_train = random.normal(key1, (20, 32, 32, 3))
    >>> y_train = random.uniform(key1, (20, 10))
    >>> x_test = random.normal(key2, (5, 32, 32, 3))
    >>>
    >>> init_fn, apply_fn, _ = stax.serial(
    >>>     stax.Conv(128, (3, 3)),
    >>>     stax.Relu(),
    >>>     stax.Conv(256, (3, 3)),
    >>>     stax.Relu(),
    >>>     stax.Conv(512, (3, 3)),
    >>>     stax.Flatten(),
    >>>     stax.Dense(10)
    >>> )
    >>>
    >>> n_samples = 200
    >>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1, n_samples)
    >>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk'))
    >>> # `kernel` is a tuple of NNGP and NTK MC estimate using `n_samples`.
    >>>
    >>> n_samples = [1, 10, 100, 1000]
    >>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1,
    >>>                                                n_samples)
    >>> kernel_samples = kernel_fn_generator(x_train, x_test,
    >>>                                      get=('nngp', 'ntk'))
    >>> for n, kernel in zip(n_samples, kernel_samples):
    >>>   print(n, kernel)
    >>>   # `kernel` is a tuple of NNGP and NTK MC estimate using `n` samples.
  """
    kernel_fn = empirical.empirical_kernel_fn(apply_fn,
                                              trace_axes=trace_axes,
                                              diagonal_axes=diagonal_axes)

    kernel_fn_sample_once = _sample_once_kernel_fn(kernel_fn, init_fn,
                                                   batch_size, device_count,
                                                   store_on_device)

    n_samples, get_generator = _canonicalize_n_samples(n_samples)
    kernel_fn = _sample_many_kernel_fn(kernel_fn_sample_once, key, n_samples,
                                       get_generator)
    return kernel_fn
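
To make the `trace_axes` description above concrete: a minimal JAX sketch of an empirical NTK with the logit axis traced out, on a toy two-layer network. The averaging over the traced axis is an assumed normalization convention, not taken from the library source:

```python
import jax
import jax.numpy as jnp
from jax import random

def init(key, d_in=4, d_hid=8, d_out=3):
    k1, k2 = random.split(key)
    return [random.normal(k1, (d_in, d_hid)), random.normal(k2, (d_hid, d_out))]

def apply_fn(params, x):
    w1, w2 = params
    return jnp.tanh(x @ w1) @ w2           # outputs of shape (N, d_out)

def ntk_traced(params, x1, x2):
    # Empirical NTK with the covariance traced over the logit axis,
    # i.e. the analogue of `trace_axes=(-1,)`.
    j1 = jax.jacobian(lambda p: apply_fn(p, x1))(params)
    j2 = jax.jacobian(lambda p: apply_fn(p, x2))(params)
    n1, k = apply_fn(params, x1).shape
    n2 = x2.shape[0]
    ntk = jnp.zeros((n1, n2))
    for a, b in zip(jax.tree_util.tree_leaves(j1),
                    jax.tree_util.tree_leaves(j2)):
        ntk += jnp.einsum('ikp,jkp->ij',
                          a.reshape(n1, k, -1), b.reshape(n2, k, -1))
    return ntk / k                         # assumed averaging over the traced axis

key = random.PRNGKey(0)
params = init(key)
x1, x2 = random.normal(key, (5, 4)), random.normal(key, (7, 4))
print(ntk_traced(params, x1, x2).shape)    # (5, 7)
```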
Example #8
def _empirical_kernel(key, input_shape, network, out_logits):
    init_fn, f, _ = _build_network(input_shape, network, out_logits)
    _, params = init_fn(key, (-1, ) + input_shape)
    _kernel_fn = empirical.empirical_kernel_fn(f)
    kernel_fn = lambda x1, x2, get: _kernel_fn(x1, x2, params, get)
    return params, f, jit(kernel_fn, static_argnums=(2, ))
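
A note on the `static_argnums=(2, )` pattern above: the `get` string selects which kernel to compute and cannot be traced by `jit`, so it is declared static and each distinct value compiles its own version. A self-contained toy sketch (the kernel itself is made up; only the jit pattern matters):

```python
import jax.numpy as jnp
from jax import jit

def toy_kernel_fn(x1, x2, get):
    k = x1 @ x2.T
    return k if get == 'nngp' else 2.0 * k   # pretend 'ntk' differs

toy_kernel_fn = jit(toy_kernel_fn, static_argnums=(2,))
x = jnp.ones((4, 3))
toy_kernel_fn(x, x, 'nngp')   # compiled for get='nngp'
toy_kernel_fn(x, x, 'ntk')    # recompiled for get='ntk'
```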
Example #9
def monte_carlo_kernel_fn(init_fn,
                          apply_fn,
                          key,
                          n_samples,
                          batch_size=0,
                          device_count=-1,
                          store_on_device=True):
    """Return a Monte Carlo sampler of NTK and NNGP kernels of a given function.

  Args:
    init_fn: a function initializing parameters of the neural network. From
      `jax.experimental.stax`: "takes an rng key and an input shape and returns
      an `(output_shape, params)` pair".
    apply_fn: a function computing the output of the neural network.
      From `jax.experimental.stax`: "takes params, inputs, and an rng key and
      applies the layer".
    key: RNG (`jax.random.PRNGKey`) for sampling random networks. Must have
      shape `(2,)`.
    n_samples: number of Monte Carlo samples. Can be either an integer or an
      iterable of integers at which the resulting generator will yield
      estimates. Example: use `n_samples=[2**k for k in range(10)]` for the
      generator to yield estimates using 1, 2, 4, ..., 512 Monte Carlo samples.
    batch_size: an integer making the kernel be computed in batches of `x1` and
      `x2` of this size. `0` means computing the whole kernel. Must divide
      `x1.shape[0]` and `x2.shape[0]`.
    device_count: an integer making the kernel be computed in parallel across
      this number of devices (e.g. GPUs or TPU cores). `-1` means use all
      available devices. `0` means compute on a single device sequentially. If
      not `0`, must divide `x1.shape[0]`.
    store_on_device: a boolean, indicating whether to store the resulting
      kernel on the device (e.g. GPU or TPU), or in the CPU RAM, where larger
      kernels may fit.

  Returns:
    If `n_samples` is an integer, returns a function of signature
    `kernel_fn(x1, x2, get)` that returns an MC estimation of the kernel using
    `n_samples`. If `n_samples` is a collection of integers,
    `kernel_fn(x1, x2, get)` returns a generator that yields estimates using
    `n` samples for `n in n_samples`.

  Example:
  ```python
  >>> from jax import random
  >>> import neural_tangents as nt
  >>> from neural_tangents import stax
  >>>
  >>> key1, key2 = random.split(random.PRNGKey(1), 2)
  >>> x_train = random.normal(key1, (20, 32, 32, 3))
  >>> y_train = random.uniform(key1, (20, 10))
  >>> x_test = random.normal(key2, (5, 32, 32, 3))
  >>>
  >>> init_fn, apply_fn, kernel_fn = stax.serial(
  >>>     stax.Conv(128, (3, 3)),
  >>>     stax.Relu(),
  >>>     stax.Conv(256, (3, 3)),
  >>>     stax.Relu(),
  >>>     stax.Conv(512, (3, 3)),
  >>>     stax.Flatten(),
  >>>     stax.Dense(10)
  >>> )
  >>>
  >>> n_samples = 200
  >>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1, n_samples)
  >>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk'))
  >>> # `kernel` is a tuple of NNGP and NTK MC estimate using `n_samples`.
  >>>
  >>> n_samples = [1, 10, 100, 1000]
  >>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1,
  >>>                                                n_samples)
  >>> kernel_samples = kernel_fn_generator(x_train, x_test, get=('nngp', 'ntk'))
  >>> for n, kernel in zip(n_samples, kernel_samples):
  >>>   print(n, kernel)
  >>>   # `kernel` is a tuple of NNGP and NTK MC estimate using `n` samples.
  ```
  """
    kernel_fn = empirical.empirical_kernel_fn(apply_fn)

    kernel_fn_sample_once = _sample_once_kernel_fn(kernel_fn, init_fn,
                                                   batch_size, device_count,
                                                   store_on_device)

    n_samples, get_generator = _canonicalize_n_samples(n_samples)
    kernel_fn = _sample_many_kernel_fn(kernel_fn_sample_once, key, n_samples,
                                       get_generator)
    return kernel_fn
Example #10
def monte_carlo_kernel_fn(init_fn: InitFn,
                          apply_fn: ApplyFn,
                          key: PRNGKey,
                          n_samples: Union[int, Iterable[int]],
                          batch_size: int = 0,
                          device_count: int = -1,
                          store_on_device: bool = True,
                          trace_axes: Axes = (-1, ),
                          diagonal_axes: Axes = (),
                          vmap_axes: VMapAxes = None,
                          implementation: int = 1) -> MonteCarloKernelFn:
    r"""Return a Monte Carlo sampler of NTK and NNGP kernels of a given function.

  Note that the returned function is appropriately batched / parallelized. You
  don't need to apply the `nt.batch` or `jax.jit` decorators to it. Further,
  you do not need to apply `jax.jit` to the input `apply_fn` function, as the
  resulting empirical kernel function is JITted internally.

  Args:
    init_fn:
      a function initializing parameters of the neural network. From
      `jax.experimental.stax`: "takes an rng key and an input shape and returns
      an `(output_shape, params)` pair".
    apply_fn:
      a function computing the output of the neural network.
      From `jax.experimental.stax`: "takes params, inputs, and an rng key and
      applies the layer".
    key:
      RNG (`jax.random.PRNGKey`) for sampling random networks. Must have
      shape `(2,)`.
    n_samples:
      number of Monte Carlo samples. Can be either an integer or an
      iterable of integers at which the resulting generator will yield
      estimates. Example: use `n_samples=[2**k for k in range(10)]` for the
      generator to yield estimates using 1, 2, 4, ..., 512 Monte Carlo samples.
    batch_size: an integer making the kernel be computed in batches of `x1` and
      `x2` of this size. `0` means computing the whole kernel. Must divide
      `x1.shape[0]` and `x2.shape[0]`.
    device_count:
      an integer making the kernel be computed in parallel across
      this number of devices (e.g. GPUs or TPU cores). `-1` means use all
      available devices. `0` means compute on a single device sequentially. If
      not `0`, must divide `x1.shape[0]`.
    store_on_device:
      a boolean, indicating whether to store the resulting
      kernel on the device (e.g. GPU or TPU), or in the CPU RAM, where larger
      kernels may fit.
    trace_axes:
      output axes to trace the output kernel over, i.e. compute only the trace
      of the covariance along the respective pair of axes (one pair for each
      axis in `trace_axes`). This allows you to save space and compute if you
      are only interested in the respective trace, and can also improve
      approximation accuracy if you know that the covariance along these pairs
      of axes converges to a `constant * identity matrix` in the limit of
      interest (e.g. infinite width or infinite `n_samples`). A common use
      case is the channel / feature / logit axis, since activation slices
      along such an axis are i.i.d.
      and the respective covariance along the respective pair of axes indeed
      converges to a constant-diagonal matrix in the infinite width or infinite
      `n_samples` limit.
      Also related to "contracting dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)
    diagonal_axes:
      output axes to diagonalize the output kernel over, i.e. compute only the
      diagonal of the covariance along the respective pair of axes (one pair for
      each axis in `diagonal_axes`). This allows you to save space and compute
      if off-diagonal values along these axes are not needed, and can also
      improve approximation accuracy if their limiting value is known
      theoretically,
      e.g. if they vanish in the limit of interest (e.g. infinite
      width or infinite `n_samples`). If you further know that on-diagonal
      values converge to the same constant in your limit of interest, you should
      specify these axes in `trace_axes` instead, to save even more compute and
      gain even more accuracy. A common use case is computing the variance
      (instead of covariance) along certain axes.
      Also related to "batch dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)
    vmap_axes:
      applicable only to NTK. A triple of `(in_axes, out_axes, kwargs_axes)`
      passed to `vmap` to evaluate the empirical NTK in parallel over these
      axes. Precisely, providing this argument implies that
      `f(params, x, **kwargs)` equals a concatenation along `out_axes` of `f`
      applied to slices of `x` and `**kwargs` along `in_axes` and
      `kwargs_axes`, i.e. `f` can be evaluated as a `vmap`. This allows
      Jacobians to be evaluated much more efficiently. If `vmap_axes` is not a
      triple, it is interpreted as
      `in_axes = out_axes = vmap_axes, kwargs_axes = {}`. For example, a very
      common use case is `vmap_axes=0` for a neural network with a leading
      (`0`) batch dimension, both for inputs and outputs, and no interactions
      between different elements of the batch (e.g. no BatchNorm, and, in the
      case of `nt.stax`, also no Dropout). However, if there is interaction
      between batch elements or no concept of a batch axis at all, `vmap_axes`
      must be set to `None` to avoid wrong (and potentially silent) results.
    implementation:
      applicable only to NTK.

      `1` or `2`.

      `1` directly instantiates Jacobians and computes their outer
      product.

      `2` uses implicit differentiation to avoid instantiating whole
      Jacobians at once. The implicit kernel is derived by observing that:
      :math:`\Theta = J(X_1) J(X_2)^T = [J(X_1) J(X_2)^T](I)`,
      i.e. a linear function :math:`[J(X_1) J(X_2)^T]` applied to an identity
      matrix :math:`I`. This allows the computation of the NTK to be
      phrased as: :math:`a(v) = J(X_2)^T v`, which is computed by a
      vector-Jacobian product; :math:`b(v) = J(X_1) a(v)` which is computed by
      a Jacobian-vector product; and :math:`\Theta = [b(v)] / d[v^T](I)` which
      is computed via a `vmap` of :math:`b(v)` over columns of the identity
      matrix :math:`I`.

      It is best to benchmark each method on your specific task. We suggest
      using `1` unless you run out of memory due to a large number of trainable
      parameters, in which case use `2`.

  Returns:
    If `n_samples` is an integer, returns a function of signature
    `kernel_fn(x1, x2, get)` that returns an MC estimation of the kernel using
    `n_samples`. If `n_samples` is a collection of integers,
    `kernel_fn(x1, x2, get)` returns a generator that yields estimates using
    `n` samples for `n in n_samples`.

  Example:
    >>> from jax import random
    >>> import neural_tangents as nt
    >>> from neural_tangents import stax
    >>>
    >>> key1, key2 = random.split(random.PRNGKey(1), 2)
    >>> x_train = random.normal(key1, (20, 32, 32, 3))
    >>> y_train = random.uniform(key1, (20, 10))
    >>> x_test = random.normal(key2, (5, 32, 32, 3))
    >>>
    >>> init_fn, apply_fn, _ = stax.serial(
    >>>     stax.Conv(128, (3, 3)),
    >>>     stax.Relu(),
    >>>     stax.Conv(256, (3, 3)),
    >>>     stax.Relu(),
    >>>     stax.Conv(512, (3, 3)),
    >>>     stax.Flatten(),
    >>>     stax.Dense(10)
    >>> )
    >>>
    >>> n_samples = 200
    >>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1, n_samples)
    >>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk'))
    >>> # `kernel` is a tuple of NNGP and NTK MC estimate using `n_samples`.
    >>>
    >>> n_samples = [1, 10, 100, 1000]
    >>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1,
    >>>                                                n_samples)
    >>> kernel_samples = kernel_fn_generator(x_train, x_test,
    >>>                                      get=('nngp', 'ntk'))
    >>> for n, kernel in zip(n_samples, kernel_samples):
    >>>   print(n, kernel)
    >>>   # `kernel` is a tuple of NNGP and NTK MC estimate using `n` samples.
  """
    kernel_fn = empirical.empirical_kernel_fn(apply_fn,
                                              trace_axes=trace_axes,
                                              diagonal_axes=diagonal_axes,
                                              vmap_axes=vmap_axes,
                                              implementation=implementation)

    kernel_fn_sample_once = _sample_once_kernel_fn(kernel_fn, init_fn,
                                                   batch_size, device_count,
                                                   store_on_device)

    n_samples, get_generator = _canonicalize_n_samples(n_samples)
    kernel_fn = _sample_many_kernel_fn(kernel_fn_sample_once, key, n_samples,
                                       get_generator)
    return kernel_fn
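
The `implementation=2` description above can be read as a short vjp/jvp composition. A minimal sketch on a toy network (the model, the helper names, and the output shape convention `(N1, K, N2, K)` are illustrative assumptions, not the library's internals):

```python
import jax
import jax.numpy as jnp
from jax import random

def init(key, d_in=4, d_hid=8, d_out=3):
    k1, k2 = random.split(key)
    return [random.normal(k1, (d_in, d_hid)), random.normal(k2, (d_hid, d_out))]

def f(params, x):
    w1, w2 = params
    return jnp.tanh(x @ w1) @ w2

def ntk_implicit(params, x1, x2):
    f1 = lambda p: f(p, x1)                  # outputs of shape (N1, K)
    f2 = lambda p: f(p, x2)                  # outputs of shape (N2, K)
    y2 = f2(params)

    def row(v):                              # v has the shape of f(params, x2)
        _, vjp_fn = jax.vjp(f2, params)
        a, = vjp_fn(v)                       # a = J(X2)^T v  (vector-Jacobian product)
        _, b = jax.jvp(f1, (params,), (a,))  # b = J(X1) a    (Jacobian-vector product)
        return b                             # shape (N1, K)

    n2, k = y2.shape
    eye = jnp.eye(n2 * k).reshape(n2 * k, n2, k)
    theta = jax.vmap(row)(eye)               # vmap b over columns of the identity
    theta = theta.reshape(n2, k, *theta.shape[1:])
    return jnp.moveaxis(theta, (0, 1), (2, 3))   # full NTK, shape (N1, K, N2, K)

key = random.PRNGKey(0)
params = init(key)
x1, x2 = random.normal(key, (5, 4)), random.normal(key, (6, 4))
print(ntk_implicit(params, x1, x2).shape)    # (5, 3, 6, 3)
```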
Example #11
    def testPredictND(self):
        n_chan = 6
        key = random.PRNGKey(1)
        im_shape = (5, 4, 3)
        n_train = 2
        n_test = 2
        x_train = random.normal(key, (n_train, ) + im_shape)
        y_train = random.uniform(key, (n_train, 3, 2, n_chan))
        init_fn, apply_fn, _ = stax.Conv(n_chan, (3, 2), (1, 2))
        _, params = init_fn(key, x_train.shape)
        fx_train_0 = apply_fn(params, x_train)

        for trace_axes in [(), (-1, ), (-2, ), (-3, ), (0, 1), (2, 3), (2, ),
                           (1, 3), (0, -1), (0, 0, -3), (0, 1, 2, 3),
                           (0, 1, -1, 2)]:
            for ts in [None, np.arange(6).reshape((2, 3))]:
                for x in [None, 'x_test']:
                    with self.subTest(trace_axes=trace_axes, ts=ts, x=x):
                        t_shape = ts.shape if ts is not None else ()
                        y_test_shape = t_shape + (n_test, ) + y_train.shape[1:]
                        y_train_shape = t_shape + y_train.shape
                        x = x if x is None else random.normal(
                            key, (n_test, ) + im_shape)
                        fx_test_0 = None if x is None else apply_fn(params, x)

                        kernel_fn = empirical.empirical_kernel_fn(
                            apply_fn, trace_axes=trace_axes)

                        # TODO(romann): investigate the SIGTERM error on CPU.
                        # kernel_fn = jit(kernel_fn, static_argnums=(2,))
                        ntk_train_train = kernel_fn(x_train, None, 'ntk',
                                                    params)
                        if x is not None:
                            ntk_test_train = kernel_fn(x, x_train, 'ntk',
                                                       params)

                        loss = lambda x, y: 0.5 * np.mean(x - y)**2
                        predict_fn_mse = predict.gradient_descent_mse(
                            ntk_train_train, y_train, trace_axes=trace_axes)

                        predict_fn_mse_ensemble = predict.gradient_descent_mse_ensemble(
                            kernel_fn,
                            x_train,
                            y_train,
                            trace_axes=trace_axes,
                            params=params)

                        if x is None:
                            p_train_mse = predict_fn_mse(ts, fx_train_0)
                        else:
                            p_train_mse, p_test_mse = predict_fn_mse(
                                ts, fx_train_0, fx_test_0, ntk_test_train)
                            self.assertAllClose(y_test_shape, p_test_mse.shape)
                        self.assertAllClose(y_train_shape, p_train_mse.shape)

                        p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble(
                            ts, x, ('nngp', 'ntk'), compute_cov=True)
                        ref_shape = y_train_shape if x is None else y_test_shape
                        self.assertAllClose(ref_shape,
                                            p_ntk_mse_ens.mean.shape)
                        self.assertAllClose(ref_shape,
                                            p_nngp_mse_ens.mean.shape)

                        if ts is not None:
                            predict_fn = predict.gradient_descent(
                                loss,
                                ntk_train_train,
                                y_train,
                                trace_axes=trace_axes)

                            if x is None:
                                p_train = predict_fn(ts, fx_train_0)
                            else:
                                p_train, p_test = predict_fn(
                                    ts, fx_train_0, fx_test_0, ntk_test_train)
                                self.assertAllClose(y_test_shape, p_test.shape)
                            self.assertAllClose(y_train_shape, p_train.shape)