Example #1
0
def random_variable(
    gp: Prior, params: dict, sample_points: Array, jitter_amount: float = 1e-6
) -> tfd.Distribution:
    """Return the GP prior's joint distribution over ``sample_points``.

    Args:
        gp: Prior providing ``mean_function`` and ``kernel``.
        params: Kernel parameters forwarded to ``gram``.
        sample_points: Evaluation inputs; the size of the first axis sets the
            dimension of the jitter matrix.
        jitter_amount: Scalar multiplying ``I(n)`` before it is added to the
            Gram matrix.

    Returns:
        ``tfd.MultivariateNormalFullCovariance`` whose location is the squeezed
        prior mean and whose covariance is the jittered Gram matrix.
    """
    loc = gp.mean_function(sample_points).squeeze()
    n_points = sample_points.shape[0]
    # Diagonal jitter, presumably for numerical stability of the covariance
    # (assuming I(n) builds an n x n identity).
    jittered = gram(gp.kernel, sample_points, params) + I(n_points) * jitter_amount
    return tfd.MultivariateNormalFullCovariance(loc, jittered)
Example #2
0
def random_variable(
    gp: Prior,
    params: dict,
    sample_points: Dataset,
    jitter_amount: float = 1e-6,
) -> tfd.Distribution:
    """Return the GP prior's joint distribution over a dataset's inputs.

    Args:
        gp: Prior providing ``mean_function`` and ``kernel``.
        params: Kernel parameters forwarded to ``gram``.
        sample_points: Dataset whose ``X`` attribute holds the inputs and whose
            ``n`` attribute sets the size of the jitter matrix.
        jitter_amount: Scalar multiplying ``I(n)`` before it is added to the
            Gram matrix.

    Returns:
        ``tfd.MultivariateNormalFullCovariance`` with the squeezed prior mean as
        location and the jittered Gram matrix as covariance.
    """
    inputs = sample_points.X
    n_points = sample_points.n
    loc = gp.mean_function(inputs).squeeze()
    # Add a scaled identity (assuming I(n) is the n x n identity) to the Gram matrix.
    jittered = gram(gp.kernel, inputs, params) + I(n_points) * jitter_amount
    return tfd.MultivariateNormalFullCovariance(loc, jittered)
Example #3
0
    def build_rv(test_points: Array, jitter_amount: float = 1e-6):
        """Return the spectral model's predictive distribution at ``test_points``.

        Reads ``w``, ``alpha``, ``RT``, ``m`` and ``params`` from the enclosing
        scope (not visible here).

        Args:
            test_points: Inputs at which to predict; the first axis indexes points.
            jitter_amount: Value multiplying ``I(N)`` before it is added to the
                predictive covariance (default matches the former hard-coded 1e-6).

        Returns:
            ``tfd.MultivariateNormalFullCovariance`` over the test points.
        """
        N = test_points.shape[0]
        # Test design matrix [cos(X* w^T), sin(X* w^T)].
        phistar = jnp.matmul(test_points, jnp.transpose(w))
        phistar = jnp.hstack([jnp.cos(phistar), jnp.sin(phistar)])
        pred_mean = jnp.matmul(phistar, alpha)

        # NOTE(review): assumes RT is the LOWER Cholesky factor of A (e.g. from
        # jnp.linalg.cholesky) in the enclosing scope — confirm. solve_triangular
        # defaults to lower=False, which would read only RT's upper triangle
        # (i.e. just its diagonal) and give a wrong solve.
        RtiPhistart = solve_triangular(RT, jnp.transpose(phistar), lower=True)
        PhiRistar = jnp.transpose(RtiPhistart)
        cov = (params["obs_noise"] * params["variance"] / m *
               jnp.matmul(PhiRistar, jnp.transpose(PhiRistar)) + I(N) * jitter_amount)
        return tfd.MultivariateNormalFullCovariance(pred_mean.squeeze(), cov)
Example #4
0
def random_variable(
    gp: ConjugatePosterior,
    params: dict,
    sample_points: Array,
    train_inputs: Array,
    train_outputs: Array,
    jitter_amount: float = 1e-6,
) -> tfd.Distribution:
    """Return the conjugate posterior's predictive distribution at ``sample_points``.

    The location comes from the module's ``mean`` helper and the covariance
    from its ``variance`` helper, both called with the same arguments.

    Args:
        gp: Conjugate posterior to predict with.
        params: Model parameters forwarded to ``mean`` and ``variance``.
        sample_points: Test inputs; the first axis sets the jitter size.
        train_inputs: Training inputs.
        train_outputs: Training targets.
        jitter_amount: Scalar multiplying ``I(n)`` before it is added to the
            predictive covariance.

    Returns:
        ``tfd.MultivariateNormalFullCovariance`` over the test points.
    """
    n_test = sample_points.shape[0]
    shared_args = (gp, params, sample_points, train_inputs, train_outputs)
    # TODO: Return kernel matrices here to avoid replicated computation.
    pred_mean = mean(*shared_args)
    pred_cov = variance(*shared_args)
    jittered_cov = pred_cov + I(n_test) * jitter_amount
    return tfd.MultivariateNormalFullCovariance(pred_mean.squeeze(), jittered_cov)
Example #5
0
def random_variable(
    gp: SpectralPosterior,
    params: dict,
    train_inputs: Array,
    train_outputs: Array,
    test_inputs: Array,
    static_params: dict = None,
    jitter_amount: float = 1e-6,
) -> tfd.Distribution:
    """Return the sparse-spectrum posterior's predictive distribution at ``test_inputs``.

    With ``Phi`` the training feature matrix (``2 * m`` columns built by the
    kernel's ``_build_phi``), this forms
    ``A = (variance / m) * Phi^T Phi + obs_noise * I(2m)``, Cholesky-factorises
    it, and uses triangular solves for the predictive mean and covariance.

    Args:
        gp: Posterior whose ``prior.kernel`` supplies ``num_basis`` and
            ``_build_phi``.
        params: Hyperparameters; reads ``"basis_fns"``, ``"lengthscale"``,
            ``"variance"`` and ``"obs_noise"``.
        train_inputs: Training inputs passed to ``_build_phi``.
        train_outputs: Training targets.
        test_inputs: Inputs at which to predict; the first axis indexes points.
        static_params: Extra entries merged into ``params`` with
            ``concat_dictionaries``.
        jitter_amount: Value multiplying ``I(n_test)`` before it is added to the
            predictive covariance.

    Returns:
        ``tfd.MultivariateNormalFullCovariance`` over the test points.
    """
    params = concat_dictionaries(params, static_params)
    m = gp.prior.kernel.num_basis
    w = params["basis_fns"] / params["lengthscale"]
    phi = gp.prior.kernel._build_phi(train_inputs, params)

    A = (params["variance"] / m) * jnp.matmul(jnp.transpose(phi), phi) + params["obs_noise"] * I(
        2 * m
    )

    # jnp.linalg.cholesky returns the LOWER factor RT with A = RT @ RT.T,
    # so R = RT.T is upper-triangular.
    RT = jnp.linalg.cholesky(A)
    R = jnp.transpose(RT)

    # solve_triangular defaults to lower=False; solves against RT must pass
    # lower=True or only RT's upper triangle (its diagonal) would be read.
    RtiPhit = solve_triangular(RT, jnp.transpose(phi), lower=True)
    Rtiphity = jnp.matmul(RtiPhit, train_outputs)

    # alpha = (variance / m) * A^{-1} Phi^T y, via R^{-1} RT^{-1} = A^{-1}.
    alpha = params["variance"] / m * solve_triangular(R, Rtiphity, lower=False)

    # Test design matrix [cos(X* w^T), sin(X* w^T)].
    phistar = jnp.matmul(test_inputs, jnp.transpose(w))
    phistar = jnp.hstack([jnp.cos(phistar), jnp.sin(phistar)])
    pred_mean = jnp.matmul(phistar, alpha)

    RtiPhistart = solve_triangular(RT, jnp.transpose(phistar), lower=True)
    PhiRistar = jnp.transpose(RtiPhistart)
    cov = (
        params["obs_noise"]
        * params["variance"]
        / m
        * jnp.matmul(PhiRistar, jnp.transpose(PhiRistar))
        + I(test_inputs.shape[0]) * jitter_amount
    )
    return tfd.MultivariateNormalFullCovariance(pred_mean.squeeze(), cov)
Example #6
0
 def log_prob(self, data, covariates, **kwargs):
     """Evaluate the log-density of ``data`` under a full-covariance Gaussian.

     The location is ``self.predict(covariates)`` and the covariance is
     ``self.covariance_matrix``. Extra keyword arguments are accepted and ignored.
     """
     location = self.predict(covariates)
     gaussian = tfp_dists.MultivariateNormalFullCovariance(
         location, self.covariance_matrix)
     return gaussian.log_prob(data)
Example #7
0
 def sample(self, covariates, seed, sample_shape=()):
     """Draw samples from a full-covariance Gaussian centred on the prediction.

     The location is ``self.predict(covariates)`` and the covariance is
     ``self.covariance_matrix``; ``sample_shape`` and ``seed`` are forwarded to
     the distribution's ``sample``.
     """
     location = self.predict(covariates)
     gaussian = tfp_dists.MultivariateNormalFullCovariance(
         location, self.covariance_matrix)
     return gaussian.sample(sample_shape=sample_shape, seed=seed)
Example #8
0
 def build_rv(test_points: Array):
     """Return a full-covariance Gaussian over ``test_points``.

     Uses ``meanf``, ``covf`` and ``jitter_amount`` from the enclosing scope
     (not visible here); ``jitter_amount * I(n)`` is added to the covariance.
     """
     loc = meanf(test_points)
     base_cov = covf(test_points)
     jitter = I(test_points.shape[0]) * jitter_amount
     return tfd.MultivariateNormalFullCovariance(loc.squeeze(), base_cov + jitter)