コード例 #1
0
def test_combine():
    """Check that `combine` merges finite-dimensional distributions.

    With a single FDD, `combine` must hand back an equal normal (and, when
    given an `(fdd, observations)` pair, the observations as well). With two
    FDDs it must produce a normal whose covariance is the block-diagonal of
    the two individual covariances, and concatenate any observations.
    """
    x1 = B.linspace(0, 2, 10)
    x2 = B.linspace(2, 4, 10)

    measure = Measure()
    p1 = GP(EQ(), measure=measure)
    p2 = GP(Matern12(), measure=measure)
    y1 = p1(x1).sample()
    y2 = p2(x2).sample()

    def block_diagonal_reference():
        # Expected result of combining `p1(x1, 1)` and `p2(x2, 2)`.
        return Normal(B.block_diag(p1(x1, 1).var, p2(x2, 2).var))

    # A single FDD, without and with observations attached.
    assert_equal_normals(combine(p1(x1, 1)), p1(x1, 1))
    merged, observations = combine((p1(x1, 1), B.squeeze(y1)))
    assert_equal_normals(merged, p1(x1, 1))
    approx(observations, y1)

    # Two FDDs, without observations.
    merged = combine(p1(x1, 1), p2(x2, 2))
    assert_equal_normals(merged, block_diagonal_reference())

    # Two FDDs with observations: `y1` is passed squeezed, `y2` as sampled.
    merged, observations = combine(
        (p1(x1, 1), B.squeeze(y1)),
        (p2(x2, 2), y2),
    )
    assert_equal_normals(merged, block_diagonal_reference())
    approx(observations, B.concat(y1, y2, axis=0))
コード例 #2
0
    def predict(self, x, latent=False, return_variances=False):
        """Predict.

        Args:
            x (matrix): Input locations to predict at.
            latent (bool, optional): Predict noiseless processes. Defaults
                to `False`.
            return_variances (bool, optional): Return means and variances
                instead. Defaults to `False`.

        Returns:
            tuple: Tuple containing means, lower 95% central credible bound,
                and upper 95% central credible bound if `return_variances` is
                `False`, and means and variances otherwise. Each returned
                array has one column per process in `self.fs`.
        """
        # Flatten instead of squeezing. For a single input location the
        # per-process mean/variance is (presumably) a 1x1 matrix. Squeezing
        # it would give a 0-d scalar, which cannot be stacked along `axis=1`.
        mean = B.stack(*[B.flatten(B.dense(f.mean(x))) for f in self.fs], axis=1)
        var = B.stack(*[B.flatten(f.kernel.elwise(x)) for f in self.fs], axis=1)

        if not latent:
            # Add each process's observation-noise variance, broadcast over
            # rows. Assumes `self.noises` is 1-D with one entry per process.
            var = var + self.noises[None, :]

        if return_variances:
            return mean, var
        else:
            # 1.96 standard deviations give the central 95% interval.
            error = 1.96 * B.sqrt(var)
            return mean, mean - error, mean + error
コード例 #3
0
ファイル: test_shaping.py プロジェクト: wesselb/lab
def test_squeeze(check_lazy_shapes):
    """Exercise `B.squeeze` on tensors and on plain Python sequences."""
    # Each entry is a tensor shape plus the explicit `axis` values to try.
    # `None` means "call without an `axis` keyword".
    tensor_cases = [
        ((3, 4, 5), None),
        ((1, 4, 5), None),
        ((1, 4, 5), (0,)),
        ((3, 1, 5), None),
        ((3, 1, 5), (1,)),
        ((1, 4, 1), None),
        ((1, 4, 1), (0, 2)),
    ]
    for dims, axes in tensor_cases:
        positional = (Tensor(*dims),)
        if axes is None:
            check_function(B.squeeze, positional)
        else:
            check_function(B.squeeze, positional, {"axis": Value(None, *axes)})

    # Test squeezing lists and tuples: a length-one sequence collapses to its
    # element, while longer sequences come back unchanged.
    sequence_cases = [
        ((1,), 1),
        ((1, 2), (1, 2)),
        ([1], 1),
        ([1, 2], [1, 2]),
    ]
    for given, expected in sequence_cases:
        assert B.squeeze(given) == expected
コード例 #4
0
ファイル: multiply.py プロジェクト: wesselb/matrix
def multiply(a: Diagonal, b: AbstractMatrix):
    """Elementwise product of a diagonal matrix with another matrix.

    Only the diagonal of `b` matters, so the result is again `Diagonal`.
    """
    assert_compatible(a, b)
    num_rows, num_cols = B.shape(b)
    # A row or column vector `b` is broadcast against `a`. `B.diag(b)` would
    # not see the broadcasted diagonal, so densify and squeeze `b` instead.
    broadcasting = num_rows == 1 or num_cols == 1
    diagonal_of_b = B.squeeze(B.dense(b)) if broadcasting else B.diag(b)
    return Diagonal(a.diag * diagonal_of_b)
コード例 #5
0
def squeeze(a: AbstractMatrix):
    """Squeeze a matrix by first converting it to a dense array.

    A `ToDenseWarning` is emitted when `a` is structured, since its
    structure is lost in the conversion.
    """
    loses_structure = structured(a)
    if loses_structure:
        warn_upmodule(
            f"Squeezing {a}: converting to dense.",
            category=ToDenseWarning,
        )
    dense_a = B.dense(a)
    return B.squeeze(dense_a)
コード例 #6
0
def _convert(x: B.Numeric):
    """Convert `x` to a NumPy value and drop its singleton dimensions."""
    as_numpy = B.to_numpy(x)
    return B.squeeze(as_numpy)