def predict_fn_finite(t, fx_train_0, fx_test_0, k_test_train):
    """Compute train and/or test outputs at finite time(s) `t`.

    Relies on names bound in the enclosing scope (`learning_rate`, `y_train`,
    `k_train_train`, `expm1_fn`, `inv_expm1_fn`, `rhs_shape` and the axis
    bookkeeping tuples `trace_axes`, `last_t_axes`, `non_t_axes`, `odd`,
    `n_non_t_axes`); their exact semantics are defined outside this block.

    Args:
      t: scalar or array-like of times; it is multiplied by `learning_rate`
        and flattened to a column before being passed to the helpers.
      fx_train_0: initial train-set outputs, or `None` to skip train-set
        predictions. When `None`, the residual is formed as `-y_train`.
      fx_test_0: initial test-set outputs, or `None` to skip test-set
        predictions.
      k_test_train: test-train kernel; only read when `fx_test_0` is given.

    Returns:
      `fx_train_t` if only `fx_train_0` is given, `fx_test_t` if only
      `fx_test_0` is given, or the tuple `(fx_train_t, fx_test_t)` if both
      are given.

    Raises:
      ValueError: if both `fx_train_0` and `fx_test_0` are `None`.
    """
    want_train = fx_train_0 is not None
    want_test = fx_test_0 is not None

    if not (want_train or want_test):
        # Without this guard the function would reach its final `return` with
        # no prediction ever computed and fail with an `UnboundLocalError`.
        raise ValueError(
            'Both `fx_train_0` and `fx_test_0` are `None`, i.e. no '
            'predictions will be computed.')

    t = np.array(t) * learning_rate
    t_shape, t_ndim = t.shape, t.ndim
    # One row per requested time so the helpers see a 2D `(n_times, 1)` input.
    t = t.reshape((-1, 1))

    rhs = -y_train if fx_train_0 is None else fx_train_0 - y_train
    # Move the trace axes to the end and flatten all leading axes.
    rhs = np.moveaxis(rhs, trace_axes, last_t_axes).reshape(
        (-1,) + rhs_shape)
    # Target shape: time axes, then every second kernel axis, then `rhs_shape`.
    shape = t_shape + k_train_train.shape[1::2] + rhs_shape

    fx_train_t = None
    fx_test_t = None

    if want_train:
        dfx_train = expm1_fn(rhs, t).reshape(shape)
        dfx_train = np.moveaxis(dfx_train, last_t_axes, trace_axes)
        fx_train_t = fx_train_0 + dfx_train

    if want_test:
        dfx_test = inv_expm1_fn(rhs, t).reshape(shape)
        dfx_test = np.tensordot(k_test_train, dfx_test, (odd, non_t_axes))
        # After the contraction the time axes sit behind the kernel's
        # remaining axes; bring them to the front and restore the trace axes.
        dfx_test = np.moveaxis(
            dfx_test,
            tuple(range(n_non_t_axes, n_non_t_axes + t_ndim)) + last_t_axes,
            tuple(range(t_ndim)) + trace_axes)
        fx_test_t = fx_test_0 + dfx_test

    if want_train and want_test:
        return fx_train_t, fx_test_t
    return fx_train_t if want_train else fx_test_t
def predict_fn_inf(fx_train_0, fx_test_0, k_test_train):
    """Compute train and/or test outputs without a time argument.

    Presumably the infinite-time counterpart of the finite-time predictor
    (judging by the name). Reads `y_train`, `k_train_train`, `solve` and the
    axis tuples `trace_axes`, `last_t_axes`, `odd`, `first` from the
    enclosing scope.

    Args:
      fx_train_0: initial train-set outputs, or `None`. When `None`, the
        residual passed to `solve` is `y_train` itself.
      fx_test_0: initial test-set outputs, or `None` to return only the
        train-set outputs.
      k_test_train: test-train kernel; only read when `fx_test_0` is given.

    Returns:
      * `y_train` cast to `k_train_train.dtype` if `fx_test_0` is `None`;
      * the test outputs alone if `fx_test_0` is given and `fx_train_0` is
        `None`;
      * the tuple `(train outputs, test outputs)` otherwise.
    """
    # Train-set outputs are simply the targets in the kernel's dtype.
    train_out = y_train.astype(k_train_train.dtype)
    if fx_test_0 is None:
        return train_out

    if fx_train_0 is None:
        residual = y_train
    else:
        residual = y_train - fx_train_0

    solved = solve(residual, trace_axes)
    delta_test = np.moveaxis(
        np.tensordot(k_test_train, solved, (odd, first)),
        last_t_axes,
        trace_axes,
    )
    test_out = fx_test_0 + delta_test

    if fx_train_0 is None:
        return test_out
    return train_out, test_out
def dstate_dt(state_t: ODEState, unused_t) -> ODEState:
    """Build the `ODEState` returned for `state_t` (the time is ignored).

    Reads `grad_loss`, `k_train_train`, `k_test_train`, `momentum` and the
    axis tuples `odd`, `non_t_axes`, `last_t_axes`, `trace_axes` from the
    enclosing scope.

    Args:
      state_t: state providing `fx_train`, `fx_test`, `qx_train` and
        `qx_test`.
      unused_t: ignored.

    Returns:
      With `momentum is None`: `ODEState(d_train, d_test)`, where each entry
      is the negated contraction of the respective kernel with
      `grad_loss(state_t.fx_train)` (`d_test` is `None` when
      `state_t.fx_test` is `None`).
      Otherwise: `ODEState(qx_train, qx_test, d_train, d_test)`, where the
      last two entries additionally have `momentum * qx_*` added.
    """
    train_out = state_t.fx_train
    test_out = state_t.fx_test
    train_q = state_t.qx_train
    test_q = state_t.qx_test

    loss_grad = grad_loss(train_out)

    def _neg_kernel_apply(kernel):
        # Contract `kernel` with the loss gradient, restore the trace axes
        # to their original positions and flip the sign.
        contracted = np.tensordot(kernel, loss_grad, (odd, non_t_axes))
        return -np.moveaxis(contracted, last_t_axes, trace_axes)

    d_train = _neg_kernel_apply(k_train_train)
    d_test = None if test_out is None else _neg_kernel_apply(k_test_train)

    if momentum is not None:
        d_train += momentum * train_q
        if test_q is not None:
            d_test += momentum * test_q
        return ODEState(train_q, test_q, d_train, d_test)  # pytype: disable=wrong-arg-count

    return ODEState(d_train, d_test)  # pytype: disable=wrong-arg-count
def outer_product(a):
    """Return the tensor (outer) product of `a` with itself.

    Implemented as `tensordot` with zero contracted axes, so every element
    of `a` is multiplied with every element of `a`.
    """
    operand = a
    return np.tensordot(operand, operand, axes=0)
def predict_fn(
    get: Get,
    k_test_train=None,
    nngp_test_test: np.ndarray = None
) -> Dict[str, Union[np.ndarray, Gaussian]]:
    """`test`-set posterior given respective covariance matrices.

    Reads `k_train_train`, `y_train`, `k_inv_y`, `solve`, `_get_attr`,
    `utils.zip_axes` and the axis tuples `trace_axes`, `odd`, `even`,
    `first`, `last` from the enclosing scope.

    Args:
      get: string, the mode of the Gaussian process, either "nngp" or "ntk",
        or a tuple, or `None`. If `None` then both `nngp` and `ntk`
        predictions are returned. A single string is treated as a
        one-element tuple.
      k_test_train: test-train kernel. Can be (a) `np.ndarray`, (b) `Kernel`
        namedtuple, (c) `Kernel` object. Must contain the necessary `nngp`
        and/or `ntk` kernels for the requested `get`. For example, if you
        request to compute posterior test [only] NTK covariance,
        `k_test_train` must contain both `ntk` and `nngp` kernels. If `None`,
        returns predictions on the training set. Note that train-set outputs
        are always `N(y_train, 0)` and mostly returned for API consistency.
      nngp_test_test: A test-test NNGP array. Provide if you want to compute
        test-test posterior covariance. `nngp_test_test=None` means to not
        compute it. If `k_test_train is None`, pass any non-`None` value
        (e.g. `True`) if you want to get non-regularized (`diag_reg=0`)
        train-train posterior covariance. Note that non-regularized train-set
        outputs will always be the zero-variance Gaussian `N(y_train, 0)` and
        mostly returned for API consistency. For regularized train-set
        posterior outputs according to a positive `diag_reg`, pass
        `k_test_train=k_train_train`, and, optionally,
        `nngp_test_test=nngp_train_train`.

    Returns:
      A dictionary with one entry per requested mode in `get`. Each value is
      the posterior mean array if `nngp_test_test is None`, otherwise a
      `Gaussian` built from the posterior mean and covariance.

    Raises:
      ValueError: if the NTK test-set covariance is requested but
        `k_train_train` or `k_test_train` lacks an `nngp` attribute, or if a
        mode other than "nngp" / "ntk" is requested together with a test-set
        covariance.
    """
    if get is None:
        get = ('nngp', 'ntk')
    elif isinstance(get, str):
        # A bare string such as 'ntk' would otherwise be iterated character
        # by character by the loop below.
        get = (get,)

    out = {}

    for g in get:
        k_dd = _get_attr(k_train_train, g)
        k_td = None if k_test_train is None else _get_attr(k_test_train, g)

        if k_td is None:
            # Train set predictions.
            y = y_train.astype(k_dd.dtype)
        else:
            # Test set predictions.
            y = np.tensordot(k_td, k_inv_y(g), (odd, first))
            y = np.moveaxis(y, range(-len(trace_axes), 0), trace_axes)

        if nngp_test_test is not None:
            if k_td is None:
                # Train-set covariance is identically zero.
                out[g] = Gaussian(y, np.zeros_like(k_dd, k_dd.dtype))
            else:
                if (g == 'ntk' and
                        (not hasattr(k_train_train, 'nngp') or
                         not hasattr(k_test_train, 'nngp'))):
                    raise ValueError(
                        'If `"ntk" in get`, and `nngp_test_test is not None`, '
                        'and `k_test_train is not None`, i.e. you request the '
                        'NTK posterior covariance on the test set, you need '
                        'both NTK and NNGP train-train and test-train matrices '
                        'contained in `k_test_train` and `k_train_train`. '
                        'Hence they must be `namedtuple`s with `nngp` and '
                        '`ntk` attributes.')

                # Both covariance branches need the test-train NNGP kernel
                # passed through the mode-`g` solver.
                k_td_nngp_inv_y = solve(g)(_get_attr(k_test_train, 'nngp'), even)

                if g == 'nngp':
                    cov = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
                    cov = nngp_test_test - utils.zip_axes(cov)
                    out[g] = Gaussian(y, cov)

                elif g == 'ntk':
                    term_1 = solve(g)(k_td, even)
                    cov = np.tensordot(_get_attr(k_train_train, 'nngp'), term_1,
                                       (odd, first))
                    cov = np.tensordot(term_1, cov, (first, first))

                    term_2 = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
                    # Add `term_2` with its `first` axes moved to `last`.
                    term_2 += np.moveaxis(term_2, first, last)
                    cov = utils.zip_axes(cov - term_2) + nngp_test_test
                    out[g] = Gaussian(y, cov)

                else:
                    raise ValueError(g)
        else:
            out[g] = y

    return out