def _test_cg_gpr(config: ConfigDense, model: GPR, Xnew: tf.Tensor) -> tf.Tensor:
    """
    Sample generation subroutine common to each unit test.

    Draws ``config.num_samples`` function samples at ``Xnew`` in shards of at
    most ``config.shard_size``. Each shard draws from the joint prior over
    ``(f(X), f(Xnew))`` via ``common.sample_joint`` and then adds the update
    functions returned by ``cg_update``. ``cg_update`` is run with a
    preconditioner and at most ``config.num_cond`` iterations.

    :param config: Test configuration. The fields read here are
        ``num_samples``, ``shard_size`` and ``num_cond``.
    :param model: GPR model supplying ``data``, ``kernel``, ``likelihood``
        and ``mean_function``.
    :param Xnew: Test inputs at which samples are generated.
    :return: Per-shard samples concatenated along axis 0. This is presumably
        shaped ``[num_samples, len(Xnew), ...]``; confirm against
        ``common.sample_joint``.
    """
    X, y = model.data
    mean_fn = model.mean_function

    # Prepare preconditioner for CG
    Kff = model.kernel(X, full_cov=True)
    max_rank = config.num_cond // (2 if config.num_cond > 1 else 1)
    preconditioner = get_default_preconditioner(
        Kff, diag=model.likelihood.variance, max_rank=max_rank)

    # The prior mean at the training inputs does not change between shards,
    # so evaluate it once. A missing mean function is treated as zero, which
    # matches the ``Xnew`` handling at the end of this function.
    mean_X = None if mean_fn is None else mean_fn(X)

    count = 0
    L_joint = None  # returned by sample_joint and passed back in on later shards
    samples = []
    while count < config.num_samples:
        # Number of samples to draw in this shard
        size = min(config.shard_size, config.num_samples - count)

        # Generate draws from the joint distribution $p(f(X), f(Xnew))$
        (f, fnew), L_joint = common.sample_joint(
            model.kernel, X, Xnew, num_samples=size, L=L_joint)

        # Solve for update functions (prior draw at X shifted by the mean)
        f_X = f if mean_X is None else f + mean_X
        update_fns = cg_update(model.kernel,
                               X,
                               y,
                               f_X,
                               tol=1e-6,
                               diag=model.likelihood.variance,
                               max_iter=config.num_cond,
                               preconditioner=preconditioner)

        samples.append(fnew + update_fns(Xnew))
        count += size

    samples = tf.concat(samples, axis=0)
    if mean_fn is not None:
        samples += mean_fn(Xnew)
    return samples
def _test_exact_gpr(config: ConfigDense, model: GPR, Xnew: tf.Tensor) -> tf.Tensor:
    """
    Sample generation subroutine common to each unit test.

    Draws ``config.num_samples`` function samples at ``Xnew`` in shards of at
    most ``config.shard_size``. Each shard draws from the joint prior over
    ``(f(X), f(Xnew))`` via ``common.sample_joint`` and then adds the update
    functions returned by ``exact_update``. ``exact_update`` reuses a Cholesky
    factor of ``K(X, X) + noise_variance * I`` that is computed once up front.

    :param config: Test configuration. The fields read here are
        ``num_samples`` and ``shard_size``.
    :param model: GPR model supplying ``data``, ``kernel``, ``likelihood``
        and ``mean_function``.
    :param Xnew: Test inputs at which samples are generated.
    :return: Per-shard samples concatenated along axis 0. This is presumably
        shaped ``[num_samples, len(Xnew), ...]``; confirm against
        ``common.sample_joint``.
    """
    X, y = model.data
    mean_fn = model.mean_function

    # Precompute Cholesky factor of K(X, X) with the likelihood variance added
    # to the diagonal, so it is shared by every shard's update.
    Kyy = model.kernel(X, full_cov=True)
    Kyy = tf.linalg.set_diag(
        Kyy, tf.linalg.diag_part(Kyy) + model.likelihood.variance)
    Lyy = tf.linalg.cholesky(Kyy)

    # The prior mean at the training inputs does not change between shards,
    # so evaluate it once. A missing mean function is treated as zero, which
    # matches the ``Xnew`` handling at the end of this function.
    mean_X = None if mean_fn is None else mean_fn(X)

    count = 0
    L_joint = None  # returned by sample_joint and passed back in on later shards
    samples = []
    while count < config.num_samples:
        # Number of samples to draw in this shard
        size = min(config.shard_size, config.num_samples - count)

        # Generate draws from the joint distribution $p(f(X), f(Xnew))$
        (f, fnew), L_joint = common.sample_joint(
            model.kernel, X, Xnew, num_samples=size, L=L_joint)

        # Solve for update functions (prior draw at X shifted by the mean)
        f_X = f if mean_X is None else f + mean_X
        update_fns = exact_update(model.kernel,
                                  X,
                                  y,
                                  f_X,
                                  L=Lyy,
                                  diag=model.likelihood.variance)

        samples.append(fnew + update_fns(Xnew))
        count += size

    samples = tf.concat(samples, axis=0)
    if mean_fn is not None:
        samples += mean_fn(Xnew)
    return samples