示例#1
0
    def w2(self, other):
        """Compute the 2-Wasserstein distance with respect to another normal
        distribution.

        The squared distance is computed as the sum of two parts: the squared
        Euclidean distance between the means, and
        `tr(S1) + tr(S2) - 2 tr(root(root(S1) S2 root(S1)))`, where `S1` is
        the variance of this distribution, `S2` is the variance of `other`,
        and `root` is `B.root`.

        Args:
            other (:class:`.random.Normal`): Other normal.

        Returns:
            scalar: 2-Wasserstein distance.
        """
        # The root of `self.var` appears on both sides of the product, so
        # compute it once and reuse it.
        var_root = B.root(self.var)
        root = B.root(B.matmul(var_root, other.var, var_root))
        var_part = B.trace(self.var) + B.trace(other.var) - 2 * B.trace(root)
        mean_part = B.sum((self.mean - other.mean) ** 2)
        # The sum of `mean_part` and `var_part` should be positive, but this
        # may not be the case due to numerical errors.
        return B.abs(mean_part + var_part) ** .5
示例#2
0
def test_marginals():
    """Check the output of `marginals` for a prior and a conditioned GP.

    This is a generator-style test: every `yield` hands a check function
    and its arguments to the test runner.
    """
    model = Graph()
    p = GP(EQ(), TensorProductMean(lambda x: x ** 2), graph=model)
    x = np.linspace(0, 5, 10)

    # Prior: the bounds should lie two standard deviations around the mean.
    mean, lower, upper = p(x).marginals()
    var = B.diag(p.kernel(x))
    prior_mean = p.mean(x)[:, 0]
    two_std = 2 * var ** .5
    yield assert_allclose, mean, prior_mean
    yield assert_allclose, lower, prior_mean - two_std
    yield assert_allclose, upper, prior_mean + two_std

    # Conditioned on a sample at `x`: at those same inputs the mean should
    # match the sample and the bounds should nearly coincide.
    y = p(x).sample()
    mean, lower, upper = (p | (x, y))(x).marginals()
    yield assert_allclose, mean, y[:, 0]
    average_width = B.mean(B.abs(upper - lower))
    yield le, average_width, 1e-5

    # At inputs shifted by 100, the conditioned marginals should match the
    # prior mean again, with a total width of four.
    x_far = x + 100
    mean, lower, upper = (p | (x, y))(x_far).marginals()
    yield assert_allclose, mean, p.mean(x_far)[:, 0]
    yield assert_allclose, upper - lower, 4 * np.ones(10)