def get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape):
  """Assemble the initial `ODEState` for the gradient-descent dynamics.

  Args:
    fx_train_or_state_0: either an `ODEState`, whose `fx_train`, `fx_test`,
      `qx_train` and `qx_test` fields are all read from it, or the initial
      train-set outputs (possibly `None`).
    fx_test_0: initial test-set outputs, or `None`. Ignored when
      `fx_train_or_state_0` is an `ODEState`; the state's `fx_test` is used
      instead.
    fx_test_shape: shape that test-set outputs and test momenta are
      broadcast to / created with.

  Returns:
    `ODEState(fx_train, fx_test, qx_train, qx_test)`. A missing `fx_train`
    becomes `np.zeros_like(y_train, dtype)`; a supplied one is broadcast to
    `y_train.shape`. When `momentum is None` both momentum entries are
    `None`. Otherwise a missing `qx_train` is zero-filled and a supplied one
    is broadcast. `qx_test` is `None` whenever `fx_test` is `None` (any
    supplied `qx_test` is dropped in that case), otherwise it is zero-filled
    or broadcast to `fx_test_shape`.

  Raises:
    ValueError: if momentum state variables are supplied while
      `momentum is None`.
  """
  # Unpack the incoming state; a bare array carries no momentum variables.
  qx_train_0 = qx_test_0 = None
  if isinstance(fx_train_or_state_0, ODEState):
    state = fx_train_or_state_0
    fx_train_0, fx_test_0 = state.fx_train, state.fx_test
    qx_train_0, qx_test_0 = state.qx_train, state.qx_test
  else:
    fx_train_0 = fx_train_or_state_0

  fx_train_0 = (np.zeros_like(y_train, dtype) if fx_train_0 is None
                else np.broadcast_to(fx_train_0, y_train.shape))
  if fx_test_0 is not None:
    fx_test_0 = np.broadcast_to(fx_test_0, fx_test_shape)

  if momentum is None:
    # Plain gradient descent: momentum variables must not be supplied.
    if not (qx_train_0 is None and qx_test_0 is None):
      raise ValueError('Got passed momentum state variables, while '
                       '`momentum is None`.')
    return ODEState(fx_train_0, fx_test_0, None, None)  # pytype: disable=wrong-arg-count

  if qx_train_0 is None:
    qx_train_0 = np.zeros_like(y_train, dtype)
  else:
    qx_train_0 = np.broadcast_to(qx_train_0, y_train.shape)

  # Test momentum only exists alongside test outputs.
  if fx_test_0 is None:
    qx_test_0 = None
  elif qx_test_0 is None:
    qx_test_0 = np.zeros(fx_test_shape, dtype)
  else:
    qx_test_0 = np.broadcast_to(qx_test_0, fx_test_shape)

  return ODEState(fx_train_0, fx_test_0, qx_train_0, qx_test_0)  # pytype: disable=wrong-arg-count
def testZeroTimeAgreement(self, train_shape, test_shape, network, out_logits):
  """Check that NTK and NNGP ensemble predictions coincide at `t=0`.

  At time zero the expected mean is all zeros and the expected covariance is
  the NNGP kernel itself, both on the train set (`x_test=None`, and again with
  `x_test=x_train`) and on the test set.
  """
  _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                 train_shape)
  _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)

  predictor = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                    x_train,
                                                    y_train,
                                                    diag_reg=1e-7)

  for label in (None, 'x_test'):
    with self.subTest(x=label):
      inputs = None if label is None else x_test
      at_zero = predictor(t=0.0, x_test=inputs, get=('NTK', 'NNGP'),
                          compute_cov=True)

      if inputs is None:
        nngp_train = kernel_fn(x_train, None, get='nngp')
        reference = (np.zeros_like(y_train, nngp_train.dtype), nngp_train)
      else:
        reference = (np.zeros((test_shape[0], out_logits)),
                     kernel_fn(x_test, None, get='nngp'))

      # One (mean, covariance) pair each for the NTK and the NNGP predictor.
      expected = (reference,) * 2
      self.assertAllClose(expected, at_zero, check_dtypes=False)

      if inputs is None:
        # Passing the train inputs explicitly must match the `None` path.
        at_zero_explicit = predictor(t=0.0, x_test=x_train,
                                     get=('NTK', 'NNGP'), compute_cov=True)
        self.assertAllClose(expected, at_zero_explicit)
def init(x0):
  """Build the starting state for parameters `x0`.

  Returns:
    A 3-tuple of `x0` itself, a zero array with the same shape and dtype as
    `x0`, and a list holding one 1-D zero vector per axis of `x0`. The vector
    for axis `i` has length `x0.shape[i]` and dtype `x0.dtype`, so a 0-d `x0`
    yields an empty list.
  """
  axis_accumulators = [np.zeros((length,), x0.dtype) for length in x0.shape]
  full_accumulator = np.zeros_like(x0)
  return x0, full_accumulator, axis_accumulators
def init(x0):
  """Build the starting state `(x0, m0, u0)` for parameters `x0`.

  `m0` and `u0` are two independent zero arrays with the same shape and
  dtype as `x0`.
  """
  first_acc, second_acc = (np.zeros_like(x0) for _ in range(2))
  return x0, first_acc, second_acc
def init(x0):
  """Build the starting state `(x0, avg_sq_grad, mom)` for parameters `x0`.

  The running average of squared gradients and the momentum buffer both
  start as independent zero arrays shaped and typed like `x0`.
  """
  momentum_buf = np.zeros_like(x0)
  sq_grad_avg = np.zeros_like(x0)
  return x0, sq_grad_avg, momentum_buf
def init(x0):
  """Build the starting state `(x0, g_sq, m)` for parameters `x0`.

  The squared-gradient accumulator and the momentum buffer are independent
  zero arrays with the same shape and dtype as `x0`.
  """
  state = (x0, np.zeros_like(x0), np.zeros_like(x0))
  return state
def init(x0):
  """Build the starting state `(x0, v0)`, where `v0` is a zero array like `x0`."""
  velocity = np.zeros_like(x0)
  return (x0, velocity)
def predict_fn(
    get: Get,
    k_test_train=None,
    nngp_test_test: np.ndarray = None
) -> Dict[str, Union[np.ndarray, Gaussian]]:
  """`test`-set posterior given respective covariance matrices.

  Args:
    get: string, the mode of the Gaussian process, either "nngp" or "ntk", or
      a tuple, or `None`. If `None` then both `nngp` and `ntk` predictions are
      returned. A single string is treated as a 1-tuple.
    k_test_train: test-train kernel. Can be (a) `np.ndarray`, (b) `Kernel`
      namedtuple, (c) `Kernel` object. Must contain the necessary `nngp`
      and/or `ntk` kernels for arguments provided to the returned
      `predict_fn` function. For example, if you request to compute posterior
      test [only] NTK covariance, `k_test_train` must contain both `ntk` and
      `nngp` kernels. If `None`, returns predictions on the training set.
      Note that train-set outputs are always `N(y_train, 0)` and mostly
      returned for API consistency.
    nngp_test_test: A test-test NNGP array. Provide if you want to compute
      test-test posterior covariance. `nngp_test_test=None` means not to
      compute it. If `k_test_train is None`, pass any non-`None` value (e.g.
      `True`) if you want to get non-regularized (`diag_reg=0`) train-train
      posterior covariance. Note that non-regularized train-set outputs will
      always be the zero-variance Gaussian `N(y_train, 0)` and mostly
      returned for API consistency. For regularized train-set posterior
      outputs according to a positive `diag_reg`, pass
      `k_test_train=k_train_train`, and, optionally,
      `nngp_test_test=nngp_train_train`.

  Returns:
    A dictionary mapping each requested kernel name `g` in `get` to the GP
    posterior on the `test` set (or the train set if `k_test_train is None`):
    a `Gaussian(mean, covariance)` namedtuple when `nngp_test_test is not
    None`, otherwise only the posterior `mean` array.

  Raises:
    ValueError: if the NTK test-set posterior covariance is requested but
      `k_train_train` / `k_test_train` lack an `nngp` attribute, or if an
      unknown kernel name is requested together with a covariance.
  """
  if get is None:
    get = ('nngp', 'ntk')
  elif isinstance(get, str):
    # A bare kernel name would otherwise be iterated character by character.
    get = (get,)

  out = {}
  for g in get:
    k_dd = _get_attr(k_train_train, g)
    k_td = None if k_test_train is None else _get_attr(k_test_train, g)

    if k_td is None:
      # Train set predictions.
      y = y_train.astype(k_dd.dtype)
    else:
      # Test set predictions.
      y = np.tensordot(k_td, k_inv_y(g), (odd, first))
      # Move the traced (non-contracted) axes back to their original places.
      y = np.moveaxis(y, range(-len(trace_axes), 0), trace_axes)

    if nngp_test_test is not None:
      if k_td is None:
        # Non-regularized train-set posterior has zero variance.
        out[g] = Gaussian(y, np.zeros_like(k_dd, k_dd.dtype))
      else:
        if (g == 'ntk' and
            (not hasattr(k_train_train, 'nngp') or
             not hasattr(k_test_train, 'nngp'))):
          raise ValueError(
              'If `"ntk" in get`, and `nngp_test_test is not None`, '
              'and `k_test_train is not None`, i.e. you request the '
              'NTK posterior covariance on the test set, you need '
              'both NTK and NNGP train-train and test-train matrices '
              'contained in `k_test_train` and `k_train_train`. '
              'Hence they must be `namedtuple`s with `nngp` and '
              '`ntk` attributes.')

        # Test-train NNGP passed through `solve(g)`; presumably this applies
        # the inverse of the (regularized) train-train kernel `g` - verify.
        k_td_nngp_inv_y = solve(g)(_get_attr(k_test_train, 'nngp'), even)

        if g == 'nngp':
          cov = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
          cov = nngp_test_test - utils.zip_axes(cov)
          out[g] = Gaussian(y, cov)

        elif g == 'ntk':
          term_1 = solve(g)(k_td, even)
          cov = np.tensordot(_get_attr(k_train_train, 'nngp'), term_1,
                             (odd, first))
          cov = np.tensordot(term_1, cov, (first, first))

          term_2 = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
          # Symmetrize the cross term.
          term_2 += np.moveaxis(term_2, first, last)
          cov = utils.zip_axes(cov - term_2) + nngp_test_test
          out[g] = Gaussian(y, cov)

        else:
          raise ValueError(g)

    else:
      out[g] = y

  return out