def mc_sampling(count=10):
    """Average the test predictions of `count` randomly initialised networks.

    For every draw a fresh parameter set is sampled, the empirical NTK is
    evaluated on train/train and test/train inputs, and the resulting
    `predict.gradient_descent_mse` predictor is evaluated at `1.0e8`
    (presumably a training time large enough to approximate convergence --
    confirm against `predict`).

    `train_shape`, `network`, `out_logits`, `data_train`, `data_test` and
    `data_labels` are not defined in this block; they are read from the
    enclosing scope.

    Args:
      count: number of random initialisations to average over.

    Returns:
      The sum of the per-draw test predictions divided by `count`.
    """
    rng = random.PRNGKey(100)
    init_fn, f, _ = _build_network(train_shape[1:], network, out_logits)
    empirical_fn = empirical.empirical_kernel_fn(f)

    def _ntk(x1, x2, params):
        return empirical_fn(x1, x2, params, 'ntk')

    ntk_fn = jit(_ntk)

    def _predict_once(init_key):
        # One Monte Carlo draw: initialise, build kernels, predict on test.
        _, params = init_fn(init_key, train_shape)
        k_train_train = ntk_fn(data_train, None, params)
        k_test_train = ntk_fn(data_test, data_train, params)
        predictor = predict.gradient_descent_mse(
            k_train_train, data_labels, k_test_train)
        f0_train = f(params, data_train)
        f0_test = f(params, data_test)
        _, test_prediction = predictor(1.0e8, f0_train, f0_test)
        return test_prediction

    running_sum = 0.
    for _ in range(count):
        # Thread the key so that every draw uses an independent sub-key.
        rng, init_key = random.split(rng)
        running_sum += _predict_once(init_key)
    return running_sum / count
def mc_sampling(count=10):
    """Estimate mean and covariance of test predictions by Monte Carlo.

    Each of the `count` draws samples new network parameters, evaluates the
    empirical NTK on train/train and test/train inputs and evaluates the
    `predict.gradient_descent_mse` predictor at `1.0e8`.

    `train_shape`, `network`, `out_logits`, `x_train`, `x_test` and `y_train`
    are not defined in this block; they are read from the enclosing scope.

    Args:
      count: number of random initialisations to draw.

    Returns:
      A `(mean_emp, cov_emp)` pair: the mean of the stacked predictions over
      the draw axis, and the `'ijk,ilk->jl'` contraction of the centred
      predictions divided by (number of draws * size of the last axis).
    """
    init_fn, f, _ = _build_network(train_shape[1:], network, out_logits)
    empirical_fn = empirical.empirical_kernel_fn(f)

    def _ntk(x1, x2, params):
        return empirical_fn(x1, x2, params, 'ntk')

    ntk_fn = jit(_ntk)

    def _init_keys():
        # Yields one independent sub-key per draw by repeatedly splitting.
        rng = random.PRNGKey(100)
        for _ in range(count):
            rng, sub_key = random.split(rng)
            yield sub_key

    def _predict_once(init_key):
        _, params = init_fn(init_key, train_shape)
        k_train_train = ntk_fn(x_train, None, params)
        k_test_train = ntk_fn(x_test, x_train, params)
        predictor = predict.gradient_descent_mse(
            k_train_train, y_train, k_test_train)
        f0_train = f(params, x_train)
        f0_test = f(params, x_test)
        _, test_prediction = predictor(1.0e8, f0_train, f0_test)
        return test_prediction

    samples = np.array([_predict_once(k) for k in _init_keys()])

    mean_emp = np.mean(samples, axis=0)
    centered = samples - mean_emp
    normaliser = centered.shape[0] * centered.shape[-1]
    cov_emp = np.einsum(
        'ijk,ilk->jl', centered, centered, optimize=True) / normaliser
    return mean_emp, cov_emp
def testNTKMeanCovPrediction(self, train_shape, test_shape, network,
                             out_logits):
    """Compare analytic NTK ensemble predictions with a Monte Carlo estimate.

    The analytic mean/covariance come from
    `predict.gradient_descent_mse_ensemble`; the Monte Carlo estimate uses
    4096 randomly initialised networks trained via
    `predict.gradient_descent_mse` on their empirical NTK. The `network`
    argument is not used in this body.
    """
    key, x_test, x_train, y_train = self._get_inputs(
        out_logits, test_shape, train_shape)

    dense_kwargs = dict(W_std=1.2, b_std=0.05)
    init_fn, f, kernel_fn = stax.serial(
        stax.Dense(512, **dense_kwargs),
        stax.Erf(),
        stax.Dense(out_logits, **dense_kwargs))

    reg = 1e-6
    predictor = predict.gradient_descent_mse_ensemble(
        kernel_fn, x_train, y_train, diag_reg=reg)
    ts = np.array([1., 5., 10.])

    def analytic_prediction(x_eval, x_reference):
        # Returns the analytic (mean, cov) after checking the covariance has
        # the expected size and no eigenvalue below -1e-8.
        mean, cov = predictor(ts, x_eval, 'ntk', True)
        self.assertEqual(cov.shape[1], x_reference.shape[0])
        smallest_eigenvalue = np.min(np.linalg.eigh(cov)[0])
        self.assertGreater(smallest_eigenvalue, -1e-8)
        return mean, cov

    fx_test_inf, cov_test_inf = analytic_prediction(x_test, x_test)
    fx_train_inf, cov_train_inf = analytic_prediction(None, x_train)

    empirical_fn = empirical.empirical_kernel_fn(f)

    def _ntk(x1, x2, params):
        return empirical_fn(x1, x2, 'ntk', params)

    ntk_fn = jit(_ntk)

    def predict_empirical(init_key):
        # Train/test predictions of a single randomly initialised network.
        _, params = init_fn(init_key, train_shape)
        g_dd = ntk_fn(x_train, None, params)
        g_td = ntk_fn(x_test, x_train, params)
        predict_fn = predict.gradient_descent_mse(g_dd, y_train, diag_reg=reg)
        fx_train_0 = f(params, x_train)
        fx_test_0 = f(params, x_test)
        return predict_fn(ts, fx_train_0, fx_test_0, g_td)

    def moments(samples):
        # Mean over the sample axis and empirical covariance of the residual.
        mean = np.mean(samples, axis=0, keepdims=True)
        return mean, PredictTest._cov_empirical(samples - mean)

    sample_keys = random.split(key, 4096)
    fx_train_samples, fx_test_samples = vmap(predict_empirical)(sample_keys)
    fx_train_mc, cov_train_mc = moments(fx_train_samples)
    fx_test_mc, cov_test_mc = moments(fx_test_samples)

    rtol = 0.05
    comparisons = (
        (fx_train_mc, fx_train_inf),
        (cov_train_mc, cov_train_inf),
        (cov_test_mc, cov_test_inf),
        (fx_test_mc, fx_test_inf),
    )
    for monte_carlo_value, analytic_value in comparisons:
        self._assertAllClose(monte_carlo_value, analytic_value, rtol)
def testGpInference(self):
    """Check `predict.gp_inference` against the MSE ensemble predictor.

    Sweeps over an analytic and an empirical kernel function, every spelling
    of `get`, train-only versus test inputs, and with/without covariance.
    For each combination the ensemble at `t=None` must match the ensemble at
    `t=np.inf` (tolerance 0.08) and, when `gp_inference` is expected to
    succeed, its output must match the ensemble output.
    """
    reg = 1e-5
    # The same key is reused for the inputs, the labels, the parameters and
    # the test points below.
    key = random.PRNGKey(1)
    x_train = random.normal(key, (4, 2))
    init_fn, apply_fn, kernel_fn_analytic = stax.serial(
        stax.Dense(32, 2., 0.5), stax.Relu(), stax.Dense(10, 2., 0.5))
    y_train = random.normal(key, (4, 10))
    for kernel_fn_is_analytic in [True, False]:
        if kernel_fn_is_analytic:
            kernel_fn = kernel_fn_analytic
        else:
            # Wrap the empirical kernel so it has the same
            # `(x1, x2, get)` call signature as the analytic one.
            _, params = init_fn(key, x_train.shape)
            kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)

            def kernel_fn(x1, x2, get):
                return kernel_fn_empirical(x1, x2, get, params)

        for get in [None, 'nngp', 'ntk', ('nngp',), ('ntk',),
                    ('nngp', 'ntk'), ('ntk', 'nngp')]:
            k_dd = kernel_fn(x_train, None, get)
            gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg)
            gd_ensemble = predict.gradient_descent_mse_ensemble(
                kernel_fn, x_train, y_train, diag_reg=reg)
            for x_test in [None, 'x_test']:
                # The loop variable is only a marker; it is replaced by the
                # actual test inputs (or kept as `None`).
                x_test = None if x_test is None else random.normal(key, (8, 2))
                k_td = None if x_test is None else kernel_fn(x_test, x_train, get)
                for compute_cov in [True, False]:
                    with self.subTest(kernel_fn_is_analytic=kernel_fn_is_analytic,
                                      get=get,
                                      x_test=x_test if x_test is None else 'x_test',
                                      compute_cov=compute_cov):
                        if compute_cov:
                            # `True` is passed instead of a kernel when there
                            # are no test inputs.
                            nngp_tt = (True if x_test is None
                                       else kernel_fn(x_test, None, 'nngp'))
                        else:
                            nngp_tt = None
                        out_ens = gd_ensemble(None, x_test, get, compute_cov)
                        out_ens_inf = gd_ensemble(np.inf, x_test, get, compute_cov)
                        self._assertAllClose(out_ens_inf, out_ens, 0.08)
                        # With test inputs and covariance requested, a `get`
                        # that does not contain 'nngp' is expected to raise.
                        if (get is not None and 'nngp' not in get and
                                compute_cov and k_td is not None):
                            with self.assertRaises(ValueError):
                                out_gp_inf = gp_inference(get=get,
                                                          k_test_train=k_td,
                                                          nngp_test_test=nngp_tt)
                        else:
                            out_gp_inf = gp_inference(get=get,
                                                      k_test_train=k_td,
                                                      nngp_test_test=nngp_tt)
                            self.assertAllClose(out_ens, out_gp_inf)
def test_sample_once_batch(self, batch_size, device_count, store_on_device,
                           get):
    """Single-sample kernels must agree with and without internal batching.

    Builds two samplers through `monte_carlo._sample_once_kernel_fn` -- one
    with default settings, one with the batching arguments under test -- and
    compares their outputs for identical inputs and key.
    """
    utils.stub_out_pmap(batch, device_count)

    x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
    empirical_fn = empirical.empirical_kernel_fn(apply_fn)

    plain_sampler = monte_carlo._sample_once_kernel_fn(empirical_fn, init_fn)
    batched_sampler = monte_carlo._sample_once_kernel_fn(
        empirical_fn, init_fn, batch_size, device_count, store_on_device)

    call_args = (x1, x2, key, get)
    expected = plain_sampler(*call_args)
    actual = batched_sampler(*call_args)

    # Third positional argument is passed through unchanged (presumably
    # `check_dtypes` -- confirm against the test base class).
    self.assertAllClose(expected, actual, True)
def test_batch_sample_once(self, batch_size, device_count, store_on_device,
                           get):
    """A single-sample kernel wrapped in `batch.batch` must match the raw one.

    The sampler is created with `device_count=0` and then batched externally.
    The test is skipped when `get` is `None`.
    """
    utils.stub_out_pmap(batch, device_count)

    x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
    empirical_fn = empirical.empirical_kernel_fn(apply_fn)

    plain_sampler = monte_carlo._sample_once_kernel_fn(
        empirical_fn, init_fn, device_count=0)
    batched_sampler = batch.batch(
        plain_sampler, batch_size, device_count, store_on_device)

    if get is None:
        raise jtu.SkipTest('No default `get` values for this method.')

    call_args = (x1, x2, key, get)
    expected = plain_sampler(*call_args)
    actual = batched_sampler(*call_args)

    # Third positional argument is passed through unchanged (presumably
    # `check_dtypes` -- confirm against the test base class).
    self.assertAllClose(expected, actual, True)
def monte_carlo_kernel_fn(
    init_fn: InitFn,
    apply_fn: ApplyFn,
    key: PRNGKey,
    n_samples: Union[int, Iterable[int]],
    batch_size: int = 0,
    device_count: int = -1,
    store_on_device: bool = True,
    trace_axes: Axes = (-1, ),
    diagonal_axes: Axes = ()
) -> MonteCarloKernelFn:
    """Build a Monte Carlo estimator of the NNGP and NTK kernels of a network.

    The returned function is already batched / parallelized, so neither
    `nt.batch` nor `jax.jit` has to be applied to it. `apply_fn` does not need
    to be wrapped in `jax.jit` either: the empirical kernel function is JITted
    internally.

    Args:
      init_fn:
        parameter initializer of the network. In `jax.experimental.stax`
        terms: "takes an rng key and an input shape and returns an
        `(output_shape, params)` pair".
      apply_fn:
        forward pass of the network. In `jax.experimental.stax` terms: "takes
        params, inputs, and an rng key and applies the layer".
      key:
        RNG (`jax.random.PRNGKey`) used to sample random networks. Must have
        shape `(2,)`.
      n_samples:
        number of Monte Carlo samples. Either an integer, or an iterable of
        integers at which the resulting generator yields estimates. For
        example, `n_samples=[2**k for k in range(10)]` yields estimates based
        on 1, 2, 4, ..., 512 samples.
      batch_size:
        compute the kernel in batches of `x1` and `x2` of this size. `0`
        computes the whole kernel at once. Must divide `x1.shape[0]` and
        `x2.shape[0]`.
      device_count:
        number of devices (e.g. GPUs or TPU cores) to compute the kernel on
        in parallel. `-1` uses all available devices, `0` computes
        sequentially on a single device. If not `0`, must divide
        `x1.shape[0]`.
      store_on_device:
        whether to keep the resulting kernel on the device (e.g. GPU or TPU)
        or in CPU RAM, where larger kernels may fit.
      trace_axes:
        output axes to trace the kernel over, i.e. only the trace of the
        covariance along each respective pair of axes is computed. This saves
        space and compute when only the trace is of interest, and improves
        approximation accuracy when the covariance along these pairs of axes
        converges to a `constant * identity matrix` in the limit of interest
        (e.g. infinite width or infinite `n_samples`). A common use case is
        the channel / feature / logit axis, since activation slices along it
        are i.i.d. and the respective covariance converges to a
        constant-diagonal matrix in that limit. Related to "contracting
        dimensions" in XLA terms
        (https://www.tensorflow.org/xla/operation_semantics#dotgeneral).
      diagonal_axes:
        output axes to diagonalize the kernel over, i.e. only the diagonal of
        the covariance along each respective pair of axes is computed. This
        saves space and compute when off-diagonal values along these axes are
        not needed, and improves approximation accuracy when their limiting
        value is known theoretically, e.g. when they vanish in the limit of
        interest. If the on-diagonal values are additionally known to
        converge to the same constant, list these axes in `trace_axes`
        instead to save even more compute and gain even more accuracy. A
        common use case is computing the variance (instead of the covariance)
        along certain axes. Related to "batch dimensions" in XLA terms
        (https://www.tensorflow.org/xla/operation_semantics#dotgeneral).

    Returns:
      If `n_samples` is an integer, a function `kernel_fn(x1, x2, get)`
      returning an MC estimate of the kernel based on `n_samples` samples. If
      `n_samples` is a collection of integers, `kernel_fn(x1, x2, get)`
      returns a generator yielding estimates based on `n` samples for each
      `n in n_samples`.

    Example:
      >>> from jax import random
      >>> import neural_tangents as nt
      >>> from neural_tangents import stax
      >>>
      >>> key1, key2 = random.split(random.PRNGKey(1), 2)
      >>> x_train = random.normal(key1, (20, 32, 32, 3))
      >>> y_train = random.uniform(key1, (20, 10))
      >>> x_test = random.normal(key2, (5, 32, 32, 3))
      >>>
      >>> init_fn, apply_fn, _ = stax.serial(
      ...     stax.Conv(128, (3, 3)),
      ...     stax.Relu(),
      ...     stax.Conv(256, (3, 3)),
      ...     stax.Relu(),
      ...     stax.Conv(512, (3, 3)),
      ...     stax.Flatten(),
      ...     stax.Dense(10)
      ... )
      >>>
      >>> n_samples = 200
      >>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1,
      ...                                      n_samples)
      >>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk'))
      >>> # `kernel` is a tuple of NNGP and NTK MC estimates using `n_samples`.
      >>>
      >>> n_samples = [1, 10, 100, 1000]
      >>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn,
      ...                                                key1, n_samples)
      >>> kernel_samples = kernel_fn_generator(x_train, x_test,
      ...                                      get=('nngp', 'ntk'))
      >>> for n, kernel in zip(n_samples, kernel_samples):
      ...     print(n, kernel)
      >>> # `kernel` is a tuple of NNGP and NTK MC estimates using `n` samples.
    """
    empirical_fn = empirical.empirical_kernel_fn(
        apply_fn,
        trace_axes=trace_axes,
        diagonal_axes=diagonal_axes)

    sample_once = _sample_once_kernel_fn(
        empirical_fn, init_fn, batch_size, device_count, store_on_device)

    sample_counts, as_generator = _canonicalize_n_samples(n_samples)
    return _sample_many_kernel_fn(
        sample_once, key, sample_counts, as_generator)
def _empirical_kernel(key, input_shape, network, out_logits):
    """Initialise a network and return its parameters and empirical kernel.

    Args:
      key: RNG key passed to the network's `init_fn`.
      input_shape: input shape without the leading axis; `(-1,)` is prepended
        before it is handed to `init_fn`.
      network: forwarded to `_build_network`.
      out_logits: forwarded to `_build_network`.

    Returns:
      A `(params, f, kernel_fn)` triple where `f` is the network's apply
      function and `kernel_fn(x1, x2, get)` is the JITted empirical kernel
      with `params` bound and `get` treated as a static argument.
    """
    init_fn, f, _ = _build_network(input_shape, network, out_logits)
    full_input_shape = (-1, ) + input_shape
    _, params = init_fn(key, full_input_shape)
    raw_kernel_fn = empirical.empirical_kernel_fn(f)

    def kernel_fn(x1, x2, get):
        return raw_kernel_fn(x1, x2, params, get)

    # `get` (argument index 2) is marked static for `jit`.
    return params, f, jit(kernel_fn, static_argnums=(2, ))
def monte_carlo_kernel_fn(init_fn,
                          apply_fn,
                          key,
                          n_samples,
                          batch_size=0,
                          device_count=-1,
                          store_on_device=True):
    """Build a Monte Carlo estimator of the NNGP and NTK kernels of a network.

    Args:
      init_fn:
        parameter initializer of the network. In `jax.experimental.stax`
        terms: "takes an rng key and an input shape and returns an
        `(output_shape, params)` pair".
      apply_fn:
        forward pass of the network. In `jax.experimental.stax` terms: "takes
        params, inputs, and an rng key and applies the layer".
      key:
        RNG (`jax.random.PRNGKey`) used to sample random networks. Must have
        shape `(2,)`.
      n_samples:
        number of Monte Carlo samples. Either an integer, or an iterable of
        integers at which the resulting generator yields estimates. For
        example, `n_samples=[2**k for k in range(10)]` yields estimates based
        on 1, 2, 4, ..., 512 samples.
      batch_size:
        compute the kernel in batches of `x1` and `x2` of this size. `0`
        computes the whole kernel at once. Must divide `x1.shape[0]` and
        `x2.shape[0]`.
      device_count:
        number of devices (e.g. GPUs or TPU cores) to compute the kernel on
        in parallel. `-1` uses all available devices, `0` computes
        sequentially on a single device. If not `0`, must divide
        `x1.shape[0]`.
      store_on_device:
        whether to keep the resulting kernel on the device (e.g. GPU or TPU)
        or in CPU RAM, where larger kernels may fit.

    Returns:
      If `n_samples` is an integer, a function `kernel_fn(x1, x2, get)`
      returning an MC estimate of the kernel based on `n_samples` samples. If
      `n_samples` is a collection of integers, `kernel_fn(x1, x2, get)`
      returns a generator yielding estimates based on `n` samples for each
      `n in n_samples`.

    Example:
      >>> from jax import random
      >>> import neural_tangents as nt
      >>> from neural_tangents import stax
      >>>
      >>> key1, key2 = random.split(random.PRNGKey(1), 2)
      >>> x_train = random.normal(key1, (20, 32, 32, 3))
      >>> y_train = random.uniform(key1, (20, 10))
      >>> x_test = random.normal(key2, (5, 32, 32, 3))
      >>>
      >>> init_fn, apply_fn, kernel_fn = stax.serial(
      ...     stax.Conv(128, (3, 3)),
      ...     stax.Relu(),
      ...     stax.Conv(256, (3, 3)),
      ...     stax.Relu(),
      ...     stax.Conv(512, (3, 3)),
      ...     stax.Flatten(),
      ...     stax.Dense(10)
      ... )
      >>>
      >>> n_samples = 200
      >>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1,
      ...                                      n_samples)
      >>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk'))
      >>> # `kernel` is a tuple of NNGP and NTK MC estimates using `n_samples`.
      >>>
      >>> n_samples = [1, 10, 100, 1000]
      >>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn,
      ...                                                key1, n_samples)
      >>> kernel_samples = kernel_fn_generator(x_train, x_test,
      ...                                      get=('nngp', 'ntk'))
      >>> for n, kernel in zip(n_samples, kernel_samples):
      ...     print(n, kernel)
      >>> # `kernel` is a tuple of NNGP and NTK MC estimates using `n` samples.
    """
    empirical_fn = empirical.empirical_kernel_fn(apply_fn)

    sample_once = _sample_once_kernel_fn(
        empirical_fn, init_fn, batch_size, device_count, store_on_device)

    sample_counts, as_generator = _canonicalize_n_samples(n_samples)
    return _sample_many_kernel_fn(
        sample_once, key, sample_counts, as_generator)
def monte_carlo_kernel_fn(init_fn: InitFn,
                          apply_fn: ApplyFn,
                          key: PRNGKey,
                          n_samples: Union[int, Iterable[int]],
                          batch_size: int = 0,
                          device_count: int = -1,
                          store_on_device: bool = True,
                          trace_axes: Axes = (-1, ),
                          diagonal_axes: Axes = (),
                          vmap_axes: VMapAxes = None,
                          implementation: int = 1) -> MonteCarloKernelFn:
    r"""Build a Monte Carlo estimator of the NNGP and NTK kernels of a network.

    The returned function is already batched / parallelized, so neither
    `nt.batch` nor `jax.jit` has to be applied to it. `apply_fn` does not need
    to be wrapped in `jax.jit` either: the empirical kernel function is JITted
    internally.

    Args:
      init_fn:
        parameter initializer of the network. In `jax.experimental.stax`
        terms: "takes an rng key and an input shape and returns an
        `(output_shape, params)` pair".
      apply_fn:
        forward pass of the network. In `jax.experimental.stax` terms: "takes
        params, inputs, and an rng key and applies the layer".
      key:
        RNG (`jax.random.PRNGKey`) used to sample random networks. Must have
        shape `(2,)`.
      n_samples:
        number of Monte Carlo samples. Either an integer, or an iterable of
        integers at which the resulting generator yields estimates. For
        example, `n_samples=[2**k for k in range(10)]` yields estimates based
        on 1, 2, 4, ..., 512 samples.
      batch_size:
        compute the kernel in batches of `x1` and `x2` of this size. `0`
        computes the whole kernel at once. Must divide `x1.shape[0]` and
        `x2.shape[0]`.
      device_count:
        number of devices (e.g. GPUs or TPU cores) to compute the kernel on
        in parallel. `-1` uses all available devices, `0` computes
        sequentially on a single device. If not `0`, must divide
        `x1.shape[0]`.
      store_on_device:
        whether to keep the resulting kernel on the device (e.g. GPU or TPU)
        or in CPU RAM, where larger kernels may fit.
      trace_axes:
        output axes to trace the kernel over, i.e. only the trace of the
        covariance along each respective pair of axes is computed. This saves
        space and compute when only the trace is of interest, and improves
        approximation accuracy when the covariance along these pairs of axes
        converges to a `constant * identity matrix` in the limit of interest
        (e.g. infinite width or infinite `n_samples`). A common use case is
        the channel / feature / logit axis, since activation slices along it
        are i.i.d. and the respective covariance converges to a
        constant-diagonal matrix in that limit. Related to "contracting
        dimensions" in XLA terms
        (https://www.tensorflow.org/xla/operation_semantics#dotgeneral).
      diagonal_axes:
        output axes to diagonalize the kernel over, i.e. only the diagonal of
        the covariance along each respective pair of axes is computed. This
        saves space and compute when off-diagonal values along these axes are
        not needed, and improves approximation accuracy when their limiting
        value is known theoretically, e.g. when they vanish in the limit of
        interest. If the on-diagonal values are additionally known to
        converge to the same constant, list these axes in `trace_axes`
        instead to save even more compute and gain even more accuracy. A
        common use case is computing the variance (instead of the covariance)
        along certain axes. Related to "batch dimensions" in XLA terms
        (https://www.tensorflow.org/xla/operation_semantics#dotgeneral).
      vmap_axes:
        applicable only to NTK. A triple of `(in_axes, out_axes, kwargs_axes)`
        passed to `vmap` to evaluate the empirical NTK in parallel over these
        axes. Providing this argument implies that `f(params, x, **kwargs)`
        equals a concatenation along `out_axes` of `f` applied to slices of
        `x` and `**kwargs` along `in_axes` and `kwargs_axes`, i.e. `f` can be
        evaluated as a `vmap`. This allows Jacobians to be evaluated much
        more efficiently. A value that is not a triple is interpreted as
        `in_axes = out_axes = vmap_axes, kwargs_axes = {}`. A very common use
        case is `vmap_axes=0` for a network with a leading (`0`) batch
        dimension for both inputs and outputs and no interaction between
        batch elements (e.g. no BatchNorm and, for `nt.stax`, no Dropout).
        If batch elements interact, or there is no batch axis at all,
        `vmap_axes` must be `None` to avoid wrong (and potentially silent)
        results.
      implementation:
        applicable only to NTK. `1` or `2`.

        `1` directly instantiates Jacobians and computes their outer product.

        `2` uses implicit differentiation to avoid instantiating whole
        Jacobians at once. The implicit kernel is derived by observing that
        :math:`\Theta = J(X_1) J(X_2)^T = [J(X_1) J(X_2)^T](I)`, i.e. a
        linear function :math:`[J(X_1) J(X_2)^T]` applied to an identity
        matrix :math:`I`. The NTK computation can then be phrased as
        :math:`a(v) = J(X_2)^T v`, computed by a vector-Jacobian product;
        :math:`b(v) = J(X_1) a(v)`, computed by a Jacobian-vector product;
        and :math:`\Theta = [b(v)] / d[v^T](I)`, computed via a `vmap` of
        :math:`b(v)` over the columns of the identity matrix :math:`I`.

        It is best to benchmark each method on your specific task. `1` is
        suggested unless it runs out of memory due to a large number of
        trainable parameters, in which case use `2`.

    Returns:
      If `n_samples` is an integer, a function `kernel_fn(x1, x2, get)`
      returning an MC estimate of the kernel based on `n_samples` samples. If
      `n_samples` is a collection of integers, `kernel_fn(x1, x2, get)`
      returns a generator yielding estimates based on `n` samples for each
      `n in n_samples`.

    Example:
      >>> from jax import random
      >>> import neural_tangents as nt
      >>> from neural_tangents import stax
      >>>
      >>> key1, key2 = random.split(random.PRNGKey(1), 2)
      >>> x_train = random.normal(key1, (20, 32, 32, 3))
      >>> y_train = random.uniform(key1, (20, 10))
      >>> x_test = random.normal(key2, (5, 32, 32, 3))
      >>>
      >>> init_fn, apply_fn, _ = stax.serial(
      ...     stax.Conv(128, (3, 3)),
      ...     stax.Relu(),
      ...     stax.Conv(256, (3, 3)),
      ...     stax.Relu(),
      ...     stax.Conv(512, (3, 3)),
      ...     stax.Flatten(),
      ...     stax.Dense(10)
      ... )
      >>>
      >>> n_samples = 200
      >>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1,
      ...                                      n_samples)
      >>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk'))
      >>> # `kernel` is a tuple of NNGP and NTK MC estimates using `n_samples`.
      >>>
      >>> n_samples = [1, 10, 100, 1000]
      >>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn,
      ...                                                key1, n_samples)
      >>> kernel_samples = kernel_fn_generator(x_train, x_test,
      ...                                      get=('nngp', 'ntk'))
      >>> for n, kernel in zip(n_samples, kernel_samples):
      ...     print(n, kernel)
      >>> # `kernel` is a tuple of NNGP and NTK MC estimates using `n` samples.
    """
    empirical_fn = empirical.empirical_kernel_fn(
        apply_fn,
        trace_axes=trace_axes,
        diagonal_axes=diagonal_axes,
        vmap_axes=vmap_axes,
        implementation=implementation)

    sample_once = _sample_once_kernel_fn(
        empirical_fn, init_fn, batch_size, device_count, store_on_device)

    sample_counts, as_generator = _canonicalize_n_samples(n_samples)
    return _sample_many_kernel_fn(
        sample_once, key, sample_counts, as_generator)
def testPredictND(self):
    """Check output shapes of the predictors for N-dimensional outputs.

    A single `stax.Conv` layer is evaluated on random image-shaped inputs.
    For every combination of `trace_axes`, time array `ts` and presence of
    test inputs, the shapes of the predictions returned by
    `predict.gradient_descent_mse`, `predict.gradient_descent_mse_ensemble`
    and (when `ts` is given) `predict.gradient_descent` are compared with the
    expected shapes. Only shapes are asserted, not values.
    """
    n_chan = 6
    key = random.PRNGKey(1)
    im_shape = (5, 4, 3)
    n_train = 2
    n_test = 2
    x_train = random.normal(key, (n_train, ) + im_shape)
    y_train = random.uniform(key, (n_train, 3, 2, n_chan))
    init_fn, apply_fn, _ = stax.Conv(n_chan, (3, 2), (1, 2))
    _, params = init_fn(key, x_train.shape)
    fx_train_0 = apply_fn(params, x_train)
    for trace_axes in [(), (-1, ), (-2, ), (-3, ), (0, 1), (2, 3), (2, ),
                       (1, 3), (0, -1), (0, 0, -3), (0, 1, 2, 3),
                       (0, 1, -1, 2)]:
        for ts in [None, np.arange(6).reshape((2, 3))]:
            for x in [None, 'x_test']:
                with self.subTest(trace_axes=trace_axes, ts=ts, x=x):
                    # Expected shapes: the time shape is prepended to the
                    # label shape (with `n_test` rows for test outputs).
                    t_shape = ts.shape if ts is not None else ()
                    y_test_shape = t_shape + (n_test, ) + y_train.shape[1:]
                    y_train_shape = t_shape + y_train.shape
                    # The string marker is replaced by actual test inputs.
                    x = x if x is None else random.normal(
                        key, (n_test, ) + im_shape)
                    fx_test_0 = None if x is None else apply_fn(params, x)
                    kernel_fn = empirical.empirical_kernel_fn(
                        apply_fn, trace_axes=trace_axes)
                    # TODO(romann): investigate the SIGTERM error on CPU.
                    # kernel_fn = jit(kernel_fn, static_argnums=(2,))
                    ntk_train_train = kernel_fn(x_train, None, 'ntk', params)
                    if x is not None:
                        ntk_test_train = kernel_fn(x, x_train, 'ntk', params)
                    # NOTE(review): `np.mean(x - y)**2` squares the mean
                    # rather than averaging the squares -- confirm this is
                    # intended; the test only inspects shapes.
                    loss = lambda x, y: 0.5 * np.mean(x - y)**2
                    predict_fn_mse = predict.gradient_descent_mse(
                        ntk_train_train, y_train, trace_axes=trace_axes)
                    predict_fn_mse_ensemble = predict.gradient_descent_mse_ensemble(
                        kernel_fn, x_train, y_train, trace_axes=trace_axes,
                        params=params)
                    if x is None:
                        p_train_mse = predict_fn_mse(ts, fx_train_0)
                    else:
                        p_train_mse, p_test_mse = predict_fn_mse(
                            ts, fx_train_0, fx_test_0, ntk_test_train)
                        self.assertAllClose(y_test_shape, p_test_mse.shape)
                    self.assertAllClose(y_train_shape, p_train_mse.shape)
                    p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble(
                        ts, x, ('nngp', 'ntk'), compute_cov=True)
                    # The ensemble returns train predictions when no test
                    # inputs are given, test predictions otherwise.
                    ref_shape = y_train_shape if x is None else y_test_shape
                    self.assertAllClose(ref_shape, p_ntk_mse_ens.mean.shape)
                    self.assertAllClose(ref_shape, p_nngp_mse_ens.mean.shape)
                    # The generic-loss predictor is only exercised when
                    # explicit times are supplied.
                    if ts is not None:
                        predict_fn = predict.gradient_descent(
                            loss, ntk_train_train, y_train,
                            trace_axes=trace_axes)
                        if x is None:
                            p_train = predict_fn(ts, fx_train_0)
                        else:
                            p_train, p_test = predict_fn(
                                ts, fx_train_0, fx_test_0, ntk_test_train)
                            self.assertAllClose(y_test_shape, p_test.shape)
                        self.assertAllClose(y_train_shape, p_train.shape)