Пример #1
0
    def logpdf(self, x):
        """Evaluate the log-density of the distribution at `x`.

        When `x` is a single column (rank two, one column) containing NaNs,
        the NaN rows are treated as missing: the mean, the variance and `x`
        are restricted to the remaining rows and the log-pdf of that
        restricted distribution is returned. Inputs of any other shape, in
        particular batched inputs, are not checked for missing values.

        Args:
            x (input): Values to compute the log-pdf of.

        Returns:
            tensor: Log-pdf for every input in `x`. If the trailing axis of
                the result has length one, it is dropped, so a single log-pdf
                comes back as a scalar.
        """
        x = B.uprank(x)

        single_column = B.rank(x) == 2 and B.shape(x, 1) == 1
        if single_column:
            observed = B.jit_to_numpy(~B.isnan(x[:, 0]))
            if not B.all(observed):
                # Drop the missing rows and recurse on the restricted
                # distribution. NOTE(review): assumes `Normal` is the class
                # this method belongs to.
                observed_mean = B.take(self.mean, observed)
                observed_var = B.submatrix(self.var, observed)
                observed_x = B.take(x, observed)
                restricted = Normal(observed_mean, observed_var)
                return restricted.logpdf(observed_x)

        # -(log|var| + dim * log_2_pi + quadratic term) / 2, one value per
        # input. `iqf_diag` presumably yields the diagonal of
        # (x - mean)^T var^-1 (x - mean) -- confirm against `lab`.
        log_det = B.logdet(self.var)[..., None]  # Line up with `iqf_diag`.
        constant = B.cast(self.dtype, self.dim) * B.cast(self.dtype, B.log_2_pi)
        quadratic = B.iqf_diag(self.var, B.subtract(x, self.mean))
        logpdfs = -(log_det + constant + quadratic) / 2

        if B.shape(logpdfs, -1) == 1:
            return logpdfs[..., 0]
        return logpdfs
Пример #2
0
def test_predict_noisy(construct_ilmm, x):
    """An ILMM built with `noise_amplification=1000` stays uncertain.

    After conditioning on its own sample, the gap between the predicted
    upper and lower bounds in observation space (`latent=False`) must
    exceed 10 everywhere.
    """
    model = construct_ilmm(noise_amplification=1000)
    sample = model.sample(x)
    posterior = model.condition(x, sample)

    _, lower, upper = posterior.predict(x, latent=False)
    interval_width = upper - lower
    assert B.all(interval_width > 10)
Пример #3
0
def test_predict_noiseless(construct_ilmm, x):
    """An ILMM built with `noise_amplification=1e-10` reproduces its sample.

    After conditioning on its own sample, the latent predictive mean must
    match that sample (atol 1e-3), and the gap between the predicted upper
    and lower bounds must be below 1e-4 everywhere.
    """
    model = construct_ilmm(noise_amplification=1e-10)
    sample = model.sample(x)
    posterior = model.condition(x, sample)

    mean, lower, upper = posterior.predict(x, latent=True)

    # The predictive mean should match the conditioned sample ...
    approx(mean, sample, atol=1e-3)
    # ... with an essentially collapsed interval around it.
    assert B.all(upper - lower < 1e-4)
Пример #4
0
def tuple_equal(x, y):
    """Check tuples for equality.

    Two tuples are equal when they have the same length and every pair of
    corresponding elements has the same shape (as reported by `_shape`) and
    all-equal entries. Shapes are compared first, so the elementwise
    comparison is only attempted between elements of matching shape.

    Args:
        x (tuple): First tuple.
        y (tuple): Second tuple.

    Returns:
        bool: `x` and `y` are equal.
    """
    if len(x) != len(y):
        return False
    # A generator (not a list) lets `all` stop at the first mismatching pair
    # instead of comparing every remaining pair first.
    return all(
        _shape(xi) == _shape(yi) and B.all(xi == yi) for xi, yi in zip(x, y)
    )
Пример #5
0
def test_mo_batched():
    """Check a batched two-output GP built with `cross`.

    Sampling and log-pdf shapes must carry the batch axis. Conditioning on a
    sample must increase that sample's log-pdf, and posterior samples at the
    conditioning inputs must reproduce it.
    """
    n_batch, n_points, noise = 16, 10, 0.1
    x = B.randn(n_batch, n_points, 1)

    with Measure():
        prior = cross(
            GP(1, 2 * EQ().stretch(0.5)),
            GP(2, 2 * EQ().stretch(0.5)),
        )
    y = prior(x).sample()
    prior_logpdf = prior(x, noise).logpdf(y)

    # One log-pdf per batch element; the sample has 2 * n_points rows.
    assert B.shape(prior_logpdf) == (n_batch,)
    assert B.shape(y) == (n_batch, 2 * n_points, 1)

    posterior = prior | (prior(x), y)
    y_post = posterior(x).sample()
    posterior_logpdf = posterior(x, noise).logpdf(y)

    assert B.shape(y_post) == (n_batch, 2 * n_points, 1)
    assert B.shape(posterior_logpdf) == (n_batch,)
    assert B.all(posterior_logpdf > prior_logpdf)
    approx(y, y_post, atol=1e-5)
Пример #6
0
def test_batched():
    """Check joint batched sampling, log-pdf and conditioning on two input sets.

    Shapes must carry the batch axis. Conditioning on both samples must
    increase their joint log-pdf, and posterior samples at the conditioning
    inputs must reproduce them.
    """
    n_batch, noise = 16, 0.1
    x1 = B.randn(n_batch, 10, 1)
    x2 = B.randn(n_batch, 5, 1)

    prior = GP(1, 2 * EQ().stretch(0.5))
    y1, y2 = prior.measure.sample(prior(x1), prior(x2))
    prior_logpdf = prior.measure.logpdf(
        (prior(x1, noise), y1),
        (prior(x2, noise), y2),
    )

    assert B.shape(y1) == (n_batch, 10, 1)
    assert B.shape(y2) == (n_batch, 5, 1)
    assert B.shape(prior_logpdf) == (n_batch,)

    # Note: the posterior's `measure` is read from the posterior itself, not
    # reused from the prior.
    posterior = prior | ((prior(x1), y1), (prior(x2), y2))
    y1_post, y2_post = posterior.measure.sample(posterior(x1), posterior(x2))
    posterior_logpdf = posterior.measure.logpdf(
        (posterior(x1, noise), y1),
        (posterior(x2, noise), y2),
    )

    assert B.shape(y1_post) == (n_batch, 10, 1)
    assert B.shape(y2_post) == (n_batch, 5, 1)
    approx(y1, y1_post, atol=1e-5)
    approx(y2, y2_post, atol=1e-5)
    assert B.shape(posterior_logpdf) == (n_batch,)
    assert B.all(posterior_logpdf > prior_logpdf)
Пример #7
0
 def __eq__(self, other):
     """Compare `noises` elementwise, reduced to a single truth value by `B.all`."""
     noises_match = self.noises == other.noises
     return B.all(noises_match)
Пример #8
0
 def __eq__(self, other):
     """Compare `alpha` elementwise, reduced to a single truth value by `B.all`."""
     alphas_match = self.alpha == other.alpha
     return B.all(alphas_match)
Пример #9
0
 def __eq__(self, other):
     """Compare `self[0]` with `other[0]`, then the `period` parameters.

     If the index-0 comparison is falsy, its result is returned unchanged.
     Otherwise `B.all` over the elementwise comparison of `period` is returned.
     """
     first_equal = self[0] == other[0]
     if not first_equal:
         return first_equal
     return B.all(self.period == other.period)
Пример #10
0
 def __eq__(self, other):
     """Compare `alpha`, then `beta`, each reduced with `B.all`.

     If the `alpha` comparison is falsy, its result is returned unchanged
     and `beta` is not compared.
     """
     alphas_equal = B.all(self.alpha == other.alpha)
     if not alphas_equal:
         return alphas_equal
     return B.all(self.beta == other.beta)
Пример #11
0
 def __eq__(self, other):
     """Compare `scale` (reduced with `B.all`), then `self[0]` with `other[0]`.

     If the `scale` comparison is falsy, its result is returned unchanged and
     the index-0 elements are not compared.
     """
     scales_equal = B.all(self.scale == other.scale)
     if not scales_equal:
         return scales_equal
     return self[0] == other[0]