# Example 1
def test_kron2sum_large_covariance():
    """Rescaling G by a large factor must not change the fitted model.

    G is multiplied by ``scale``; assuming G enters the covariance through
    ``G @ G.T``, the fitted C0 is expected to shrink by ``scale**2`` while the
    lml, C1, beta, beta's covariance, the mean and the total covariance all
    stay (approximately) the same.
    """
    rng = RandomState(0)
    nsamples = 50
    # The draw order matters: every randn call consumes the same RNG stream.
    Y = rng.randn(nsamples, 3)
    A = rng.randn(3, 3)
    A = A @ A.T
    F = rng.randn(nsamples, 2)
    G = rng.randn(nsamples, 4)
    scale = 1e4

    reference = Kron2Sum(Y, A, F, G, restricted=False)
    reference.fit(verbose=False)

    scaled = Kron2Sum(Y, A, F, scale * G, restricted=False)
    scaled.fit(verbose=False)

    tol = dict(rtol=1e-3, atol=1e-5)
    assert_allclose(scaled.lml(), reference.lml())
    assert_allclose(scaled.C0, reference.C0 / (scale**2), **tol)
    assert_allclose(scaled.C1, reference.C1, **tol)
    assert_allclose(scaled.beta, reference.beta, **tol)
    assert_allclose(scaled.beta_covariance, reference.beta_covariance, **tol)
    # The mean is compared with a looser relative tolerance.
    assert_allclose(scaled.mean(), reference.mean(), rtol=1e-2, atol=1e-5)
    assert_allclose(scaled.covariance(), reference.covariance(), **tol)
# Example 2
def test_kron2sum_large_outcome():
    """Regression test: fit Kron2Sum on outcomes simulated from the model.

    vec(Y) is drawn (presumably by the project's ``multivariate_normal``
    helper) with mean ``kron(A, F) @ vec(B)`` and covariance
    ``kron(C0, G @ G.T) + kron(C1, I_n)``. Y is then standardised per column
    before fitting, and selected entries of the fitted quantities are compared
    with stored reference values.

    NOTE(review): despite the name, Y is rescaled to unit standard deviation
    per column (``Y / Y.std(0)``), so the outcome is not "large" when the
    model is fitted -- confirm what the name is meant to convey.
    """

    random = RandomState(2)
    n = 50
    A = random.randn(3, 3)
    A = A @ A.T
    F = random.randn(n, 2)
    G = random.randn(n, 4)
    B = random.randn(2, 3)
    C0 = random.randn(3, 3)
    C0 = C0 @ C0.T
    C1 = random.randn(3, 3)
    C1 = C1 @ C1.T
    # Dense (3n x 3n) covariance of vec(Y) under the simulating model.
    K = kron(C0, (G @ G.T)) + kron(C1, eye(n))
    y = multivariate_normal(random, kron(A, F) @ vec(B), K)
    Y = unvec(y, (n, 3))
    # Standardise each trait (column) to unit standard deviation.
    Y = Y / Y.std(0)

    lmm = Kron2Sum(Y, A, F, G, restricted=False)
    lmm.fit(verbose=False)

    # Reference values, presumably recorded from an earlier run.
    assert_allclose(lmm.lml(), -12.163158697588926)
    assert_allclose(lmm.C0[0, 1], -0.004781646218546575, rtol=1e-3, atol=1e-5)
    assert_allclose(lmm.C1[0, 1], 0.03454122242999587, rtol=1e-3, atol=1e-5)
    assert_allclose(lmm.beta[2], -0.02553979383437496, rtol=1e-3, atol=1e-5)
    assert_allclose(lmm.beta_covariance[0, 1],
                    0.0051326042358990865,
                    rtol=1e-3,
                    atol=1e-5)
    assert_allclose(lmm.mean()[3], 0.3442913781854699, rtol=1e-2, atol=1e-5)
    assert_allclose(lmm.covariance()[0, 1],
                    0.0010745698663887468,
                    rtol=1e-3,
                    atol=1e-5)
# Example 3
    def fit(self, verbose=True):
        """Fit a single-trait Kron2Sum model and cache its variance terms.

        Parameters
        ----------
        verbose : bool
            Forwarded unchanged to ``Kron2Sum.fit``.
        """
        from glimix_core.lmm import Kron2Sum

        # A trait design of [[1]] presumably means a single trait, so C0 and
        # C1 are 1x1 and their [0, 0] entry is the whole variance component.
        model = Kron2Sum(self._y, [[1]], self._M, self._W, restricted=True)
        self._lmm = model
        model.fit(verbose=verbose)
        self._covarparam0, self._covarparam1 = model.C0[0, 0], model.C1[0, 0]
# Example 4
def test_kron2sum_gradient_unrestricted():
    """The analytic lml gradient w.r.t. C0.Lu and C1.Lu must match finite
    differences (as measured by ``check_grad``)."""
    rng = RandomState(2)
    Y = rng.randn(5, 3)
    A = rng.randn(3, 3)
    A = A @ A.T
    F = rng.randn(5, 2)
    G = rng.randn(5, 4)
    lmm = Kron2Sum(Y, A, F, G, restricted=False)
    lmm._cov.C0.Lu = rng.randn(3)
    lmm._cov.C1.Lu = rng.randn(6)

    def set_params(theta):
        # The first 3 entries parametrise C0; the next 6 parametrise C1.
        lmm._cov.C0.Lu = theta[:3]
        lmm._cov.C1.Lu = theta[3:9]

    def func(theta):
        set_params(theta)
        return lmm.lml()

    def grad(theta):
        set_params(theta)
        derivs = lmm.gradient()
        return concatenate((derivs["C0.Lu"], derivs["C1.Lu"]))

    # Drawn last, exactly as the starting point was in the original ordering.
    x0 = rng.randn(9)
    err = check_grad(func, grad, x0, epsilon=1e-8)
    assert_allclose(err, 0, atol=1e-3)
# Example 5
def test_kron2sum_insufficient_sample_size():
    """With only two samples, constructing Kron2Sum must emit a UserWarning."""
    rng = RandomState(0)
    nsamples = 2
    Y = rng.randn(nsamples, 2)
    A = rng.randn(2, 2)
    A = A @ A.T
    # Tuple elements are evaluated left to right, so F is drawn before G.
    F, G = rng.randn(nsamples, 2), rng.randn(nsamples, 6)
    with pytest.warns(UserWarning):
        Kron2Sum(Y, A, F, G)
# Example 6
def test_kron2sum_fit_C1_well_cond_redutant_F_unrestricted():
    """Duplicated covariate columns in F must trigger a UserWarning.

    The function keeps its historical spelling ("redutant") so that selecting
    the test by name keeps working.
    """
    rng = RandomState(0)
    Y = rng.randn(5, 2)
    A = rng.randn(2, 2)
    A = A @ A.T
    F_half = rng.randn(5, 2)
    # Repeating the columns makes F rank-deficient.
    F = concatenate((F_half, F_half), axis=1)
    G = rng.randn(5, 2)
    with pytest.warns(UserWarning):
        Kron2Sum(Y, A, F, G, restricted=False)
# Example 7
def test_lmm_kron_scan():
    """Check Kron2Sum's fast scanner against brute-force log-likelihoods.

    For the null model, and for an alternative with extra effects (A1, F1),
    the scanner's lml is compared with scipy's multivariate-normal
    log-density of vec(Y) evaluated at the covariance scale ``s`` that
    minimises the negative log-density. A final scan with an empty A1 is
    compared against stored reference values.
    """
    random = RandomState(0)
    n = 20
    Y = random.randn(n, 3)
    A = random.randn(3, 3)
    A = A @ A.T
    F = random.randn(n, 2)
    G = random.randn(n, 6)
    lmm = Kron2Sum(Y, A, F, G, restricted=True)
    lmm.fit(verbose=False)
    scan = lmm.get_fast_scanner()

    m = lmm.mean()
    K = lmm.covariance()

    # Negative log-density of vec(Y) with the covariance rescaled by `scale`.
    def func(scale):
        mv = st.multivariate_normal(m, scale * K)
        return -mv.logpdf(vec(Y))

    # NOTE(review): `minimize` is imported outside this view; the call looks
    # like a bounded scalar search over [1e-3, 5.0] with tolerance 1e-5 whose
    # first returned element is the minimiser -- confirm against its source.
    s = minimize(func, 1e-3, 5.0, 1e-5)[0]

    assert_allclose(scan.null_lml(),
                    st.multivariate_normal(m, s * K).logpdf(vec(Y)))
    # The null effect sizes must reproduce the fitted mean.
    assert_allclose(kron(A, F) @ scan.null_beta, m)

    A1 = random.randn(3, 2)
    F1 = random.randn(n, 4)

    r = scan.scan(A1, F1)
    assert_allclose(r["scale"], 0.7365021111700154, rtol=1e-3)

    # Alternative-model mean: background effects plus candidate effects.
    m = kron(A, F) @ vec(r["effsizes0"]) + kron(A1, F1) @ vec(r["effsizes1"])

    # Same objective as above; `m` now holds the alternative-model mean.
    def func(scale):
        mv = st.multivariate_normal(m, scale * K)
        return -mv.logpdf(vec(Y))

    s = minimize(func, 1e-3, 5.0, 1e-5)[0]

    assert_allclose(r["lml"], st.multivariate_normal(m, s * K).logpdf(vec(Y)))

    # A zero-column A1 adds no candidate effects, so effsizes1 must be empty.
    r = scan.scan(empty((3, 0)), F1)
    assert_allclose(r["lml"], -85.36667704747371, rtol=1e-4)
    assert_allclose(r["scale"], 0.8999995537936586, rtol=1e-3)
    assert_allclose(
        r["effsizes0"],
        [
            [0.21489119796865844, 0.6412947101778663, -0.7176143380221816],
            [0.8866722740598517, -0.18731140321348416, -0.26118052682069],
        ],
        rtol=1e-2,
        atol=1e-2,
    )
    assert_allclose(r["effsizes1"], [])
# Example 8
def test_kron2sum_fit_C1_well_cond_redundant_Y_unrestricted():
    """Duplicated outcome columns must warn at construction; the initial lml
    is still pinned to a reference value."""
    rng = RandomState(0)
    Y_half = rng.randn(5, 2)
    # Four traits made of two repeated pairs of columns.
    Y = concatenate((Y_half, Y_half), axis=1)
    A = rng.randn(4, 4)
    A = A @ A.T
    F = rng.randn(5, 2)
    G = rng.randn(5, 2)
    with pytest.warns(UserWarning):
        lmm = Kron2Sum(Y, A, F, G, restricted=False)
    assert_allclose(lmm.lml(), -39.59627521826263)
# Example 9
def test_lmm_kron_scan_unrestricted():
    """After an unrestricted fit, the fast scanner's null model must agree
    with the fitted Kron2Sum (unit null scale, same beta covariance)."""
    rng = RandomState(0)
    nsamples = 15
    Y = rng.randn(nsamples, 3)
    A = rng.randn(3, 3)
    A = A @ A.T
    F = rng.randn(nsamples, 2)
    G = rng.randn(nsamples, 6)

    lmm = Kron2Sum(Y, A, F, G, restricted=False)
    lmm.fit(verbose=False)
    scanner = lmm.get_fast_scanner()

    # Presumably the fitted covariance already carries the optimal scale, so
    # the scanner's null scale should be (close to) one.
    assert_allclose(scanner.null_scale, 1.0, rtol=1e-3)
    assert_allclose(lmm.beta_covariance, scanner.null_beta_covariance, rtol=1e-3)
# Example 10
    def _fit_lmm_multi_trait(self, verbose):
        """Fit a multi-trait Kron2Sum model and push its estimates back.

        The first covariance's kernel is factorised with ``economic_qs``, and
        the resulting ``Q0 * sqrt(S0)`` is used as the Kron2Sum G matrix. The
        fitted model is stored on ``self._glmm``, handed to both covariance
        objects, and its ``B`` is copied onto the mean.

        Parameters
        ----------
        verbose : bool
            Forwarded unchanged to ``Kron2Sum.fit``.
        """
        from numpy import sqrt, asarray
        from glimix_core.lmm import Kron2Sum
        from numpy_sugar.linalg import economic_qs, ddot

        design = asarray(self._M, float)
        qs = economic_qs(self._covariance[0]._K)
        # Assumes economic_qs returns ((Q0, Q1), S0); only Q0 and S0 are used.
        Q0, S0 = qs[0][0], qs[1]
        # Scale each column of Q0 by sqrt of its eigenvalue, so that
        # G @ G.T == Q0 diag(S0) Q0.T.
        G = ddot(Q0, sqrt(S0))
        lmm = Kron2Sum(self._y, self._mean.A, design, G, rank=1, restricted=True)
        lmm.fit(verbose=verbose)
        self._glmm = lmm
        for idx in (0, 1):
            self._covariance[idx]._set_kron2sum(lmm)
        self._mean.B = lmm.B
# Example 11
def test_kron2sum_fit_C1_well_cond_C0_fullrank_unrestricted():
    """Fit an unrestricted Kron2Sum with a full-rank C0 (rank=2, two traits).

    Checks the lml before and after fitting against reference values, and
    that every component of the gradient is ~0 at the optimum.
    """
    random = RandomState(0)
    Y = random.randn(5, 2)
    A = random.randn(2, 2)
    A = A @ A.T
    F = random.randn(5, 2)
    G = random.randn(5, 6)
    lmm = Kron2Sum(Y, A, F, G, rank=2, restricted=False)
    lml0 = lmm.lml()
    lmm.fit(verbose=False)
    lml1 = lmm.lml()
    assert_allclose([lml0, lml1], [-18.201106294121434, -11.853021889285362])
    # Stack every parameter block of the gradient (7 entries in total) and
    # require all of them to vanish. The dict's values are iterated directly,
    # so no local name shadows the `vars` builtin.
    grad = lmm.gradient()
    assert_allclose(concatenate(list(grad.values())), [0] * 7, atol=1e-2)
# Example 12
def test_kron2sum_fit_C1_well_cond_unrestricted():
    """Fit an unrestricted two-trait Kron2Sum with the default C0 rank.

    Checks the lml before and after fitting against reference values, and
    that every component of the gradient is ~0 at the optimum.
    """
    random = RandomState(0)
    Y = random.randn(5, 2)
    A = random.randn(2, 2)
    A = A @ A.T
    F = random.randn(5, 2)
    G = random.randn(5, 6)
    lmm = Kron2Sum(Y, A, F, G, restricted=False)
    lml0 = lmm.lml()
    lmm.fit(verbose=False)
    lml1 = lmm.lml()
    assert_allclose([lml0, lml1], [-17.87016217772149, -11.853022179263597],
                    rtol=1e-5)
    # Stack every parameter block of the gradient (5 entries in total) and
    # require all of them to vanish. The dict's values are iterated directly,
    # so no local name shadows the `vars` builtin.
    grad = lmm.gradient()
    assert_allclose(concatenate(list(grad.values())), [0] * 5, atol=1e-2)
# Example 13
def test_kron2sum_fit_ill_conditioned_unrestricted():
    """Fit an unrestricted three-trait Kron2Sum on 30 random samples.

    Checks the lml before and after fitting against reference values, and
    that every component of the gradient is ~0 at the optimum.
    """
    random = RandomState(0)
    n = 30
    Y = random.randn(n, 3)
    A = random.randn(3, 3)
    A = A @ A.T
    F = random.randn(n, 2)
    G = random.randn(n, 4)
    lmm = Kron2Sum(Y, A, F, G, restricted=False)
    lml0 = lmm.lml()
    lmm.fit(verbose=False)
    lml1 = lmm.lml()
    assert_allclose([lml0, lml1], [-154.73966241953627, -122.97307227633186])
    # Stack every parameter block of the gradient (9 entries in total) and
    # require all of them to vanish. The dict's values are iterated directly,
    # so no local name shadows the `vars` builtin.
    grad = lmm.gradient()
    assert_allclose(concatenate(list(grad.values())), [0] * 9, atol=1e-2)
# Example 14
def _mt_lmm(Y, A, M, QS, verbose):
    """Fit an unrestricted multi-trait Kron2Sum null model.

    Parameters
    ----------
    Y
        Outcomes; anything exposing ``.shape`` and ``.values`` (presumably a
        pandas DataFrame of samples x traits -- confirm with callers).
    A
        Trait design matrix, passed straight to ``Kron2Sum``.
    M
        Covariates; ``M.values`` is passed to ``Kron2Sum``.
    QS
        Output of an economic eigendecomposition (``QS[0][0]`` and ``QS[1]``
        are used), or ``None`` for no background covariance.
    verbose : bool
        Forwarded unchanged to ``Kron2Sum.fit``.

    Returns
    -------
    tuple
        ``(fast_scanner, C0, C1)`` taken from the fitted model.
    """
    # Imported locally like every other dependency of this function, so it
    # does not rely on a module-level `import sys`.
    import sys

    from glimix_core.lmm import Kron2Sum
    from numpy_sugar.linalg import ddot
    from numpy import sqrt, zeros

    if QS is None:
        # A single all-zero column: G @ G.T is zero, so no background term.
        KG = zeros((Y.shape[0], 1))
    else:
        # Scale each eigenvector column by sqrt of its eigenvalue.
        KG = ddot(QS[0][0], sqrt(QS[1]))

    lmm = Kron2Sum(Y.values, A, M.values, KG, restricted=False)
    lmm.fit(verbose=verbose)
    # Flush anything the fit printed before returning to the caller.
    sys.stdout.flush()

    C0 = lmm.C0
    C1 = lmm.C1

    return lmm.get_fast_scanner(), C0, C1
# Example 15
def test_lmm_kron_scan_redundant():
    """Like the basic scan test, but with duplicated columns in G and F1.

    The scanner's null and alternative lml values are compared with scipy's
    multivariate-normal log-density of vec(Y) evaluated at the covariance
    scale ``s`` that minimises the negative log-density.
    """
    random = RandomState(0)
    n = 30
    Y = random.randn(n, 3)
    A = random.randn(3, 3)
    A = A @ A.T
    F = random.randn(n, 2)
    G = random.randn(n, 6)
    # Repeat G's columns to make it rank-deficient.
    G = concatenate([G, G], axis=1)
    lmm = Kron2Sum(Y, A, F, G, restricted=True)
    lmm.fit(verbose=False)
    scan = lmm.get_fast_scanner()

    m = lmm.mean()
    K = lmm.covariance()

    # Negative log-density of vec(Y) with the covariance rescaled by `scale`.
    def func(scale):
        mv = st.multivariate_normal(m, scale * K)
        return -mv.logpdf(vec(Y))

    # NOTE(review): `minimize` is imported outside this view; the call looks
    # like a bounded scalar search over [1e-3, 5.0] with tolerance 1e-5 whose
    # first returned element is the minimiser -- confirm against its source.
    s = minimize(func, 1e-3, 5.0, 1e-5)[0]

    assert_allclose(scan.null_lml(),
                    st.multivariate_normal(m, s * K).logpdf(vec(Y)))
    # The null effect sizes must reproduce the fitted mean.
    assert_allclose(kron(A, F) @ scan.null_beta, m)

    A1 = random.randn(3, 2)
    F1 = random.randn(n, 4)
    # Repeat the candidate covariates too, so the alternative is redundant.
    F1 = concatenate([F1, F1], axis=1)

    r = scan.scan(A1, F1)
    assert_allclose(r["scale"], 0.8843540849467378, rtol=1e-3)

    # Alternative-model mean: background effects plus candidate effects.
    m = kron(A, F) @ vec(r["effsizes0"]) + kron(A1, F1) @ vec(r["effsizes1"])

    # Same objective as above; `m` now holds the alternative-model mean.
    def func(scale):
        mv = st.multivariate_normal(m, scale * K)
        return -mv.logpdf(vec(Y))

    s = minimize(func, 1e-3, 5.0, 1e-5)[0]

    assert_allclose(r["lml"], st.multivariate_normal(m, s * K).logpdf(vec(Y)))
# Example 16
def test_kron2sum_unrestricted_lml():
    """The unrestricted lml must equal the multivariate-normal log-density of
    vec(Y) under the model's own mean and covariance, both at the initial
    parameters and after C0.Lu and then C1.Lu are randomised."""
    rng = RandomState(0)
    Y = rng.randn(5, 3)
    A = rng.randn(3, 3)
    A = A @ A.T
    F = rng.randn(5, 2)
    G = rng.randn(5, 4)
    lmm = Kron2Sum(Y, A, F, G, restricted=False)
    y = vec(lmm._Y)

    def check_lml_matches_density():
        density = st.multivariate_normal(lmm.mean(), lmm.covariance())
        assert_allclose(lmm.lml(), density.logpdf(y))

    check_lml_matches_density()

    lmm._cov.C0.Lu = rng.randn(3)
    check_lml_matches_density()

    lmm._cov.C1.Lu = rng.randn(6)
    check_lml_matches_density()
# Example 17
def test_lmm_kron_scan_with_lmm():
    """Kron2Sum's specialised scanner must match the generic FastScanner.

    The generic scanner is built from the dense covariance of the fitted
    Kron2Sum and the design ``kron(A, F)``; null-model quantities and a scan
    with ``kron(A1, F1)`` must then agree between the two scanners.
    """
    random = RandomState(0)
    n = 15
    Y = random.randn(n, 3)
    A = random.randn(3, 3)
    A = A @ A.T
    F = random.randn(n, 2)
    G = random.randn(n, 6)

    klmm = Kron2Sum(Y, A, F, G, restricted=True)
    klmm.fit(verbose=False)
    kscan = klmm.get_fast_scanner()

    K = klmm.covariance()

    X = kron(A, F)
    QS = economic_qs(K)
    # NOTE(review): the trailing 0.0 is presumably the weight of an extra
    # identity term in FastScanner's covariance -- confirm its meaning.
    scan = FastScanner(vec(Y), X, QS, 0.0)

    # Presumably guards against K having been modified in place by
    # economic_qs/FastScanner -- the two sides are otherwise the same call.
    assert_allclose(klmm.covariance(), K)
    assert_allclose(kscan.null_scale, scan.null_scale)
    assert_allclose(kscan.null_beta, scan.null_beta)
    assert_allclose(kscan.null_lml(), scan.null_lml())
    assert_allclose(kscan.null_beta_covariance, scan.null_beta_covariance)

    A1 = random.randn(3, 2)
    F1 = random.randn(n, 2)
    # The generic scanner takes the candidate design already Kronecker-expanded.
    M = kron(A1, F1)

    kr = kscan.scan(A1, F1)
    r = scan.scan(M)
    assert_allclose(kr["lml"], r["lml"])
    assert_allclose(kr["scale"], r["scale"])
    # Kron2Sum returns effect sizes as matrices; vec() flattens them to match.
    assert_allclose(vec(kr["effsizes0"]), r["effsizes0"])
    assert_allclose(vec(kr["effsizes1"]), r["effsizes1"])
    assert_allclose(vec(kr["effsizes0_se"]), r["effsizes0_se"])
    assert_allclose(vec(kr["effsizes1_se"]), r["effsizes1_se"])
# Example 18
def test_kron2sum_restricted():
    """Regression test for the restricted Kron2Sum model.

    First, a three-trait model: the initial lml, gradient check and basic
    dimensions are verified. Then a single-trait model: the initial lml,
    gradient check, mean, covariance and name are verified; after fitting,
    the gradient must vanish and lml/beta_covariance must match references.
    """
    random = RandomState(0)
    n = 5
    Y = random.randn(n, 3)
    A = random.randn(3, 3)
    A = A @ A.T
    F = random.randn(n, 2)
    G = random.randn(n, 4)
    lmm = Kron2Sum(Y, A, F, G, restricted=True)

    assert_allclose(lmm.lml(), -16.081058762513514)
    assert_allclose(lmm._check_grad(step=1e-7), 0, atol=1e-4)
    assert_equal(lmm.nsamples, n)
    assert_equal(lmm.ntraits, 3)
    assert_equal(lmm.ncovariates, 2)

    # Second scenario: a single trait (Y has one column, A is 1x1).
    n = 5
    Y = random.randn(n, 1)
    A = random.randn(1, 1)
    A = A @ A.T
    F = random.randn(n, 2)
    G = random.randn(n, 4)
    lmm = Kron2Sum(Y, A, F, G, restricted=True)
    lmm.name = "KronSum"

    assert_allclose(lmm.lml(), -3.7547099473445003)
    assert_allclose(lmm._check_grad(step=1e-7), 0, atol=1e-4)
    assert_allclose([lmm.mean()[0], lmm.mean()[1]],
                    [0.06452826276050515, 0.4855196092646256])

    # Reference 5x5 covariance at the initial parameters.
    assert_allclose(
        lmm.covariance(),
        [
            [
                1.9379300845374776,
                -0.02014070399890988,
                -0.7399969689595782,
                -0.1402228534612341,
                -0.4690219904509089,
            ],
            [
                -0.02014070399890988,
                1.4797056135059965,
                0.0916295591269426,
                -0.3210581381149237,
                0.2558220662032061,
            ],
            [
                -0.7399969689595782,
                0.0916295591269426,
                1.6313538475715865,
                -0.07164808824303559,
                0.5063738410283093,
            ],
            [
                -0.1402228534612341,
                -0.3210581381149237,
                -0.07164808824303559,
                3.333140431376828,
                0.3424485007527981,
            ],
            [
                -0.4690219904509089,
                0.2558220662032061,
                0.5063738410283093,
                0.3424485007527981,
                2.023907116315917,
            ],
        ],
    )

    assert_equal(lmm.nsamples, n)
    assert_equal(lmm.ntraits, 1)
    assert_equal(lmm.name, "KronSum")
    lmm.fit(verbose=False)
    # At the optimum both covariance gradients must vanish.
    grad = lmm.gradient()
    assert_allclose(grad["C0.Lu"], [0], atol=1e-4)
    assert_allclose(grad["C1.Lu"], [0], atol=1e-4)
    assert_allclose(lmm.lml(), -0.6930197328322949, rtol=1e-5)

    # NOTE: `A` is rebound here and no longer refers to the trait design.
    A = lmm.beta_covariance
    assert_allclose(
        A,
        [
            [4.831901045051292, -2.1320785310203645],
            [-2.1320785310203645, 0.9438229054009741],
        ],
        atol=1e-5,
        rtol=1e-5,
    )
# Example 19
def test_kron2sum_unrestricted():
    """Regression test for the unrestricted Kron2Sum model.

    Mirrors the restricted test: a three-trait model checks the initial lml,
    gradient and dimensions; a single-trait model checks lml, gradient, mean,
    covariance and name, then fits and checks the vanishing gradient and the
    reference lml and beta covariance.
    """
    random = RandomState(0)
    n = 5
    Y = random.randn(n, 3)
    A = random.randn(3, 3)
    A = A @ A.T
    F = random.randn(n, 2)
    G = random.randn(n, 4)
    lmm = Kron2Sum(Y, A, F, G, restricted=False)

    assert_allclose(lmm.lml(), -21.917751466118062)
    assert_allclose(lmm._check_grad(step=1e-7), 0, atol=1e-4)
    assert_equal(lmm.nsamples, n)
    assert_equal(lmm.ntraits, 3)
    assert_equal(lmm.ncovariates, 2)

    # Second scenario: a single trait (Y has one column, A is 1x1).
    n = 5
    Y = random.randn(n, 1)
    A = random.randn(1, 1)
    A = A @ A.T
    F = random.randn(n, 2)
    G = random.randn(n, 4)
    lmm = Kron2Sum(Y, A, F, G, restricted=False)
    lmm.name = "KronSum"

    assert_allclose(lmm.lml(), -6.293806054115431)
    assert_allclose(lmm._check_grad(step=1e-7), 0, atol=1e-4)
    assert_allclose(
        lmm.mean(),
        [
            0.06452826276050515,
            0.4855196092646256,
            0.1396748241908668,
            -0.264600249205993,
            -0.08238891336460354,
        ],
    )

    # Reference 5x5 covariance at the initial parameters.
    assert_allclose(
        lmm.covariance(),
        [
            [
                1.9379300845374776,
                -0.02014070399890988,
                -0.7399969689595782,
                -0.1402228534612341,
                -0.4690219904509089,
            ],
            [
                -0.02014070399890988,
                1.4797056135059965,
                0.0916295591269426,
                -0.3210581381149237,
                0.2558220662032061,
            ],
            [
                -0.7399969689595782,
                0.0916295591269426,
                1.6313538475715865,
                -0.07164808824303559,
                0.5063738410283093,
            ],
            [
                -0.1402228534612341,
                -0.3210581381149237,
                -0.07164808824303559,
                3.333140431376828,
                0.3424485007527981,
            ],
            [
                -0.4690219904509089,
                0.2558220662032061,
                0.5063738410283093,
                0.3424485007527981,
                2.023907116315917,
            ],
        ],
    )

    assert_equal(lmm.nsamples, n)
    assert_equal(lmm.ntraits, 1)
    assert_equal(lmm.name, "KronSum")
    lmm.fit(verbose=False)
    # At the optimum both covariance gradients must vanish.
    grad = lmm.gradient()
    assert_allclose(grad["C0.Lu"], [0], atol=1e-4)
    assert_allclose(grad["C1.Lu"], [0], atol=1e-4)
    assert_allclose(lmm.lml(), 2.3394131683065957, rtol=1e-5)

    # Expected beta covariance after fitting (`A` is rebound here).
    A = [
        [3.621700765362852, -1.5979882078099437],
        [-1.5979882078099474, 0.7081144405074323],
    ]
    assert_allclose(lmm.beta_covariance, A, atol=1e-5, rtol=1e-5)
# Example 20
def test_kron2sum_interface():
    """Pin every public accessor of an (unfitted) unrestricted Kron2Sum.

    With only two samples the constructor must warn. Covariance, mean, lml,
    value, A, B, X, M, C0, C1, beta, beta_covariance and the size/name
    attributes are then compared with stored reference values (atol=1e-7).
    """
    random = RandomState(2)
    Y = random.randn(2, 3)
    A = random.randn(3, 3)
    A = A @ A.T
    F = random.randn(2, 2)
    G = random.randn(2, 4)
    with pytest.warns(UserWarning):
        lmm = Kron2Sum(Y, A, F, G, restricted=False)

    assert_allclose(
        lmm.covariance(),
        [
            [
                14.086388186569708,
                2.460064191520785,
                14.086373285408515,
                2.460064191520785,
                14.086373285408515,
                2.460064191520785,
            ],
            [
                2.460064191520785,
                24.620081892940938,
                2.460064191520785,
                24.620066991779744,
                2.460064191520785,
                24.620066991779744,
            ],
            [
                14.086373285408515,
                2.460064191520785,
                15.086388186569708,
                2.460064191520785,
                15.086373285408515,
                2.460064191520785,
            ],
            [
                2.460064191520785,
                24.620066991779744,
                2.460064191520785,
                25.620081892940938,
                2.460064191520785,
                25.620066991779744,
            ],
            [
                14.086373285408515,
                2.460064191520785,
                15.086373285408515,
                2.460064191520785,
                16.086388186569707,
                2.460064191520785,
            ],
            [
                2.460064191520785,
                24.620066991779744,
                2.460064191520785,
                25.620066991779744,
                2.460064191520785,
                26.620081892940938,
            ],
        ],
        atol=1e-7,
    )
    # These values appear to coincide with vec(Y) (column-stacked): with two
    # samples and two covariates the mean seems to interpolate the data.
    assert_allclose(
        lmm.mean(),
        [
            -0.4167578796040061,
            1.6402707883933658,
            -0.05626685903041562,
            -1.7934356050990345,
            -2.1361961254498922,
            -0.8417473850096115,
        ],
        atol=1e-7,
    )
    assert_allclose(lmm.lml(), -8.429274310765745, atol=1e-7)
    # value() is expected to be the same objective as lml().
    assert_allclose(lmm.value(), lmm.lml(), atol=1e-7)
    assert_allclose(
        lmm.A,
        [
            [2.9228950357645274, -3.568888742838519, 0.8427306809792194],
            [-3.568888742838519, 6.384613981254706, 0.58138966904566],
            [0.8427306809792194, 0.58138966904566, 1.5420666949903108],
        ],
        atol=1e-7,
    )
    assert_allclose(
        lmm.B,
        [
            [-155.30629394331186, -98.10725907737903, 124.05223343020717],
            [-371.5199454125782, -234.2023411372212, 295.50286962569464],
        ],
        atol=1e-7,
    )
    assert_allclose(
        lmm.X,
        [
            [-0.596159699806467, -0.019130496521151476],
            [1.175001219500291, -0.7478709492938624],
        ],
        atol=1e-7,
    )
    # The reference M equals kron(A, X) as pinned above,
    # e.g. M[0, 0] == A[0, 0] * X[0, 0].
    assert_allclose(
        lmm.M,
        [
            [
                -1.7425122270871933,
                -0.05591643331338421,
                2.127627641573291,
                0.06827461367924896,
                -0.5024020697902709,
                -0.016121856360740573,
            ],
            [
                3.4344052314946665,
                -2.1859482850835352,
                -4.193448625096121,
                2.6690682120308225,
                0.9902095778608936,
                -0.630253794382992,
            ],
            [
                2.127627641573291,
                0.06827461367924896,
                -3.8062495544449777,
                -0.12214083555728823,
                -0.34660109056884186,
                -0.011122273041111408,
            ],
            [
                -4.193448625096121,
                2.6690682120308225,
                7.501929214012888,
                -4.774867319035823,
                0.6831335701335212,
                -0.43480444369882226,
            ],
            [
                -0.5024020697902709,
                -0.016121856360740573,
                -0.34660109056884186,
                -0.011122273041111408,
                -0.9193180179669744,
                -0.029500501543895694,
            ],
            [
                0.9902095778608936,
                -0.630253794382992,
                0.6831335701335212,
                -0.43480444369882226,
                1.8119302471643985,
                -1.1532668830568527,
            ],
        ],
        atol=1e-7,
    )
    assert_allclose(
        lmm.C0,
        [
            [15.191012511337567, 15.191012511337567, 15.191012511337567],
            [15.191012511337567, 15.191012511337567, 15.191012511337567],
            [15.191012511337567, 15.191012511337567, 15.191012511337567],
        ],
        atol=1e-7,
    )
    assert_allclose(
        lmm.C1,
        [
            [1.0000149011611938, 1.0, 1.0],
            [1.0, 2.000014901161194, 2.0],
            [1.0, 2.0, 3.000014901161194],
        ],
        atol=1e-7,
    )
    # The reference beta is vec(B) (column-stacked), matching lmm.B above.
    assert_allclose(
        lmm.beta,
        [
            -155.3062975390569,
            -371.5199529125092,
            -98.10726134921538,
            -234.20234587581533,
            124.0522362217501,
            295.5028754471357,
        ],
        atol=1e-7,
    )
    assert_allclose(
        lmm.beta_covariance,
        [
            [
                180454.5259497599,
                301884.9772469587,
                114011.82157538977,
                190733.14284940425,
                -140134.13859849886,
                -234404.7154325302,
            ],
            [
                301884.97720675723,
                720240.068298505,
                190733.1428240255,
                455058.3880442711,
                -234404.7154008012,
                -559111.09170418,
            ],
            [
                114011.82157538975,
                190733.14284942497,
                72033.11947254256,
                120506.65851010302,
                -88536.94418846158,
                -148098.13529283906,
            ],
            [
                190733.14282400408,
                455058.38804427016,
                120506.6584940551,
                287512.77822660643,
                -148098.13527277583,
                -353253.7315868218,
            ],
            [
                -140134.13859850125,
                -234404.71543202398,
                -88536.94418846308,
                -148098.13529250308,
                108827.80847946613,
                182016.00480204573,
            ],
            [
                -234404.71540133483,
                -559111.0917042121,
                -148098.13527312962,
                -353253.7315868427,
                182016.00477781752,
                434044.7463389145,
            ],
        ],
        atol=1e-7,
    )
    assert_equal(lmm.ncovariates, 2)
    assert_equal(lmm.nsamples, 2)
    assert_equal(lmm.ntraits, 3)
    # Default model name when none is assigned.
    assert_equal(lmm.name, "Kron2Sum")